Gemini AI common code and utils
This commit is contained in:
@@ -6,27 +6,29 @@
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from google.api_core import retry
|
||||
import google.generativeai as genai
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from ...gpt_providers.text_generation.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
def generate_with_retry(model, prompt):
|
||||
def generate_with_retry(prompt, system_prompt=None):
|
||||
"""
|
||||
Generates content from the model with retry handling for errors.
|
||||
Generates content using the llm_text_gen function with retry handling for errors.
|
||||
|
||||
Parameters:
|
||||
model (GenerativeModel): The generative model to use for content generation.
|
||||
prompt (str): The prompt to generate content from.
|
||||
system_prompt (str, optional): Custom system prompt to use instead of the default one.
|
||||
|
||||
Returns:
|
||||
str: The generated content.
|
||||
"""
|
||||
try:
|
||||
# FIXME: Need a progress bar here.
|
||||
return model.generate_content(prompt, request_options={'retry':retry.Retry()})
|
||||
# Use llm_text_gen instead of directly calling the model
|
||||
return llm_text_gen(prompt, system_prompt)
|
||||
except Exception as e:
|
||||
print(f"Error generating content: {e}")
|
||||
logger.error(f"Error generating content: {e}")
|
||||
return ""
|
||||
|
||||
|
||||
@@ -38,8 +40,15 @@ def ai_story(persona, story_setting, character_input,
|
||||
|
||||
Parameters:
|
||||
persona (str): The persona statement for the author.
|
||||
story_genre (str): The genre of the story.
|
||||
characters (str): The characters in the story.
|
||||
story_setting (str): The setting of the story.
|
||||
character_input (str): The characters in the story.
|
||||
plot_elements (str): The plot elements of the story.
|
||||
writing_style (str): The writing style of the story.
|
||||
story_tone (str): The tone of the story.
|
||||
narrative_pov (str): The narrative point of view.
|
||||
audience_age_group (str): The target audience age group.
|
||||
content_rating (str): The content rating of the story.
|
||||
ending_preference (str): The preferred ending of the story.
|
||||
"""
|
||||
st.info(f"""
|
||||
You have chosen to create a story set in **{story_setting}**.
|
||||
@@ -170,20 +179,16 @@ def ai_story(persona, story_setting, character_input,
|
||||
|
||||
{guidelines}
|
||||
'''
|
||||
|
||||
genai.configure(api_key=os.getenv('GEMINI_API_KEY'))
|
||||
# Initialize the generative model
|
||||
model = genai.GenerativeModel('gemini-1.5-flash')
|
||||
|
||||
# Generate prompts
|
||||
try:
|
||||
premise = generate_with_retry(model, premise_prompt).text
|
||||
premise = generate_with_retry(premise_prompt)
|
||||
st.info(f"The premise of the story is: {premise}")
|
||||
except Exception as err:
|
||||
st.error(f"Premise Generation Error: {err}")
|
||||
return
|
||||
|
||||
outline = generate_with_retry(model, outline_prompt.format(premise=premise)).text
|
||||
outline = generate_with_retry(outline_prompt.format(premise=premise))
|
||||
with st.expander("Click to Checkout the outline, writing still in progress.."):
|
||||
st.markdown(f"The Outline of the story is: {outline}\n\n")
|
||||
|
||||
@@ -193,16 +198,16 @@ def ai_story(persona, story_setting, character_input,
|
||||
|
||||
# Generate starting draft
|
||||
try:
|
||||
starting_draft = generate_with_retry(model,
|
||||
starting_prompt.format(premise=premise, outline=outline)).text
|
||||
starting_draft = generate_with_retry(
|
||||
starting_prompt.format(premise=premise, outline=outline))
|
||||
except Exception as err:
|
||||
st.error(f"Failed to Generate Story draft: {err}")
|
||||
return
|
||||
|
||||
try:
|
||||
draft = starting_draft
|
||||
continuation = generate_with_retry(model,
|
||||
continuation_prompt.format(premise=premise, outline=outline, story_text=draft)).text
|
||||
continuation = generate_with_retry(
|
||||
continuation_prompt.format(premise=premise, outline=outline, story_text=draft))
|
||||
except Exception as err:
|
||||
st.error(f"Failed to write the initial draft: {err}")
|
||||
|
||||
@@ -217,8 +222,8 @@ def ai_story(persona, story_setting, character_input,
|
||||
while 'IAMDONE' not in continuation:
|
||||
try:
|
||||
status.update(label=f"Writing in progress... Current draft length: {len(draft)} characters")
|
||||
continuation = generate_with_retry(model,
|
||||
continuation_prompt.format(premise=premise, outline=outline, story_text=draft)).text
|
||||
continuation = generate_with_retry(
|
||||
continuation_prompt.format(premise=premise, outline=outline, story_text=draft))
|
||||
draft += '\n\n' + continuation
|
||||
except Exception as err:
|
||||
st.error(f"Failed to continually write the story: {err}")
|
||||
@@ -230,3 +235,4 @@ def ai_story(persona, story_setting, character_input,
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"Main Story writing: An error occurred: {e}")
|
||||
return ""
|
||||
|
||||
Reference in New Issue
Block a user