Fixed issue with Gemini API

This commit is contained in:
ajaysi
2024-04-22 10:09:07 +05:30
parent 180f28a493
commit 357cba36e4
15 changed files with 188 additions and 186 deletions

View File

@@ -1,79 +0,0 @@
"""
"""
import os
import logging
from pathlib import Path
import google.generativeai as genai
logging.basicConfig(level=logging.INFO, format='%(asctime)s-%(levelname)s-%(module)s-%(lineno)d-%(message)s')
from dotenv import load_dotenv
load_dotenv(Path('../../.env'))
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
) # for exponential backoff
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_get_img_info(prompt, img_path):
""" Get image details from arxiv papers. """
logging.info(f"Get image details from Gemini Pro.")
try:
genai.configure(api_key=os.getenv("GEMINI_API_KEY"))
except Exception as e:
logging.error(f"Could not load gemini API key: {e}")
raise e
# Set up the model
generation_config = {
"temperature": 0.9,
"top_p": 1,
"top_k": 1,
"max_output_tokens": 1096,
}
safety_settings = [{
"category": "HARM_CATEGORY_HARASSMENT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_HATE_SPEECH",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"threshold": "BLOCK_NONE"
},
{
"category": "HARM_CATEGORY_DANGEROUS_CONTENT",
"threshold": "BLOCK_NONE"
},]
try:
model = genai.GenerativeModel(model_name="gemini-pro-vision",
generation_config=generation_config,
safety_settings=safety_settings)
except Exception as e:
logging.error(f"Could not create GenerativeModel: {e}")
raise e
# Validate that an image is present
if not (img := Path(img_path)).exists():
raise FileNotFoundError(f"Could not find image: {img}")
image_parts = [{
"mime_type": "image/png",
"data": Path(img_path).read_bytes()
},]
prompt_parts = [f"{prompt}", image_parts[0],]
try:
response = model.generate_content(prompt_parts)
return response.text
except Exception as e:
logging.error(f"Gemini is blocking this request: {response.prompt_feedback.block_reason}")
logging.error(f"Gemini Vision, Failed to give image Details: {e}\n{response.prompt_feedback}")
raise e

View File

@@ -132,7 +132,7 @@ def ai_essay_generator(essay_title, selected_essay_type, selected_education_leve
load_dotenv(Path('../.env'))
genai.configure(api_key=os.getenv('GEMINI_API_KEY'))
# Initialize the generative model
model = genai.GenerativeModel('gemini-1.0-pro')
model = genai.GenerativeModel('gemini-pro')
# Generate prompts
try:

View File

@@ -19,7 +19,7 @@ from tenacity import (
)
#@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_text_response(prompt, temperature, top_p, n, max_tokens):
""" Common functiont to get response from gemini pro Text. """
#FIXME: Include : https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/System_instructions_REST.ipynb

View File

@@ -0,0 +1,41 @@
from PIL import Image
import requests
# Ensure you sign up for an account to obtain an API key:
# https://platform.stability.ai/
# Your API key can be found here after account creation:
# https://platform.stability.ai/account/keys
def generate_stable_diffusion_image(prompt):
"""
Generate images using Stable Diffusion API based on a given prompt.
Args:
prompt (str): The prompt to generate the image.
image_dir (str): The directory where the image will be saved.
Raises:
Warning: If the adult content classifier is triggered.
Exception: For any issues during image generation or saving.
"""
api_key = os.getenv('STABILITY_API_KEY')
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={
"authorization": f"Bearer {api_key}",
"accept": "image/*"
},
files={"none": ''},
data={
"prompt": prompt,
"output_format": "webp",
},
)
if response.status_code == 200:
with open("./dog-wearing-glasses.jpeg", 'wb') as file:
file.write(response.content)
else:
raise Exception(str(response.json()))

View File

@@ -20,11 +20,12 @@ logger.add(sys.stdout,
format="<level>{level}</level>|<green>{file}:{line}:{function}</green>| {message}"
)
from .gpt_providers.openai_gpt_provider import generate_dalle2_images, generate_dalle3_images, openai_chatgpt
from .stabl_diff_img2html import generate_stable_diffusion_image
#from .gen_dali2_images
from .gen_dali3_images import generate_dalle3_images
from .gen_stabl_diff_img import generate_stable_diffusion_image
def generate_image(user_prompt, image_dir, image_engine="dalle3"):
def generate_image(user_prompt, image_engine="dalle3"):
"""
The generation API endpoint creates an image based on a text prompt.
@@ -40,18 +41,14 @@ def generate_image(user_prompt, image_dir, image_engine="dalle3"):
Must be one of "url" or "b64_json". Defaults to "url".
--> user (str): A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
"""
logger.info(f"Generated blog images will be stored at: {image_dir=}")
img_prompt = generate_img_prompt(user_prompt)
# call the OpenAI API to generate image from prompt.
logger.info(f"Calling openai.image.generate with prompt: {img_prompt}")
logger.info(f"Calling image.generate with prompt: {img_prompt}")
if 'dalle2' in image_engine:
image_stored_at = generate_dalle2_images(img_prompt, image_dir)
elif 'dalle3' in image_engine:
image_stored_at = generate_dalle3_images(img_prompt, image_dir)
elif 'stable_diffusion' in image_engine:
image_stored_at = generate_stable_diffusion_image(img_prompt, image_dir)
if 'Dalle3' in image_engine:
image_stored_at = generate_dalle3_images(img_prompt)
elif 'Stable Diffusion' in image_engine:
image_stored_at = generate_stable_diffusion_image(img_prompt)
return image_stored_at
@@ -72,5 +69,5 @@ def generate_img_prompt(user_prompt):
Advice for creating prompt for image from the given text(no more than 150 words).
Reply with only one answer and no descrition. Generate image prompt for the below text.
Text: {user_prompt}"""
response = openai_chatgpt(prompt)
response = (prompt)
return response

View File

@@ -1,66 +0,0 @@
import os
import io
import warnings
from PIL import Image
from stability_sdk import client
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
# Set the host URL environment variable. Ensure it doesn't have 'https' or a trailing slash.
os.environ['STABILITY_HOST'] = 'grpc.stability.ai:443'
# Ensure you sign up for an account to obtain an API key:
# https://platform.stability.ai/
# Your API key can be found here after account creation:
# https://platform.stability.ai/account/keys
def generate_stable_diffusion_image(prompt, image_dir):
"""
Generate images using Stable Diffusion API based on a given prompt.
Args:
prompt (str): The prompt to generate the image.
image_dir (str): The directory where the image will be saved.
Raises:
Warning: If the adult content classifier is triggered.
Exception: For any issues during image generation or saving.
"""
try:
# Initialize the StabilityInference client with the API key and other settings.
stability_api = client.StabilityInference(
key=os.environ['STABILITY_KEY'], # Reference to the API key.
verbose=True, # Enable verbose mode for debug messages.
engine="stable-diffusion-xl-1024-v1-0", # Engine used for generation.
)
# Generating the image with specified parameters.
answers = stability_api.generate(
prompt=prompt,
seed=4253978046, # Deterministic seed for reproducible results.
steps=50, # Number of inference steps.
cfg_scale=7.0, # Strength of prompt matching.
width=1024, height=1024, # Image dimensions.
samples=1, # Number of images to generate.
sampler=generation.SAMPLER_K_DPMPP_2M # Denoising sampler selection.
)
# Process responses and save images.
for resp in answers:
for artifact in resp.artifacts:
if artifact.finish_reason == generation.FILTER:
warnings.warn(
"Request activated safety filters. Modify the prompt and retry."
)
if artifact.type == generation.ARTIFACT_IMAGE:
img = Image.open(io.BytesIO(artifact.binary))
img_name = os.path.join(image_dir, f"{artifact.seed}.png")
img.show()
img.save(img_name) # Save the image with the seed in the filename.
except Exception as e:
raise Exception(f"Error during image generation or saving: {e}")
# Example usage:
# generate_stable_diffusion_image("A futuristic cityscape", "/path/to/save/images/")

View File

@@ -5,6 +5,8 @@ from datetime import datetime
from prompt_toolkit.shortcuts import checkboxlist_dialog, message_dialog, input_dialog
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator, ValidationError
from prompt_toolkit.shortcuts import radiolist_dialog
from lib.ai_web_researcher.gpt_online_researcher import gpt_web_researcher
@@ -13,6 +15,7 @@ from lib.ai_writers.keywords_to_blog import write_blog_from_keywords
from lib.ai_writers.speech_to_blog.main_audio_to_blog import generate_audio_blog
from lib.gpt_providers.text_generation.ai_story_writer import ai_story_generator
from lib.gpt_providers.text_generation.ai_essay_writer import ai_essay_generator
from lib.gpt_providers.text_to_image_generation.generate_image_from_prompt import generate_image
def blog_from_audio():
@@ -258,6 +261,95 @@ def blog_tools():
).run()
def image_generator():
""" Generate image from given text """
print("Enter your long string below---")
img_prompt = prompt("Enter text to create image from:: ")
img_models = WordCompleter(['Stability-Stable-Diffusion', 'Dalle2', 'Dalle3'], ignore_case=True)
print("Choose between:: Stable-Diffusion, Dalle2, Dalle3")
img_model = prompt('Choose the image model to use for generation: ', completer=img_models, validator=ModelTypeValidator())
print(f"{img_prompt}----{img_model}")
try:
generate_image(img_prompt, img_model)
except Exception as err:
print(f"Failed to generate image: {err}")
class ModelTypeValidator(Validator):
def validate(self, document):
if document.text.lower() not in ['stability-stable-diffusion', 'dalle2', 'dalle3']:
raise ValidationError(message='Please choose a valid Text to image model.')
def image_to_text_writer():
""" IMage to Text Content Generation"""
os.system("clear" if os.name == "posix" else "cls")
text = "_______________________________________________________________________\n"
text += "\n⚠️ Alert! 💥❓💥\n"
text += "Provide Inputs Below to Continue..\n"
text += "_______________________________________________________________________\n\n"
print(text)
print("Make sure the file path is correct and the file is one of the following image types: PNG, JPEG, WEBP, HEIC, HEIF.\n")
file_location = prompt('⚠️ Enter the image file location: ', validator=FileTypeValidator())
if file_location:
writing_completer = WordCompleter(['Blog', 'Food Recipe', 'Alt Text', 'Marketing Copy'], ignore_case=True)
print("Choose between 'Blog', 'Food Recipe', 'Alt Text', 'Marketing Copy'")
writing_type = prompt('Select the type of writing: ', completer=writing_completer, validator=WritingTypeValidator())
prompt_gemini = None
if writing_type.lower() == 'blog':
prompt_gemini = "Given an image of a product and its target audience, write an engaging marketing description",
elif writing_type.lower() == 'food recipe':
prompt_gemini = """I have the ingredients above. Not sure what to cook for lunch.
Show me a list of foods with the recipes.
Accurately identify the baked good in the image and provide an appropriate and recipe consistent with your analysis.
Write a short, engaging blog post based on this picture.
It should include a description of the meal in the photo and talk about my journey meal prepping.
"""
elif writing_type.lower() == 'alt text':
prompt_gemini = """Given an image from my blog, generate 3 different ALT texts.
The image alt text should be of maximum 2 lines. It should be descriptive and SEO optimised."""
elif writing_type.lower() == 'marketing copy':
prompt_gemini = "Given an image of a product and its target audience, write an engaging marketing description"
print("TBD/FIXME: Will be taken up soon..")
class WritingTypeValidator(Validator):
def validate(self, document):
writing_type = document.text.strip().lower()
if writing_type not in ['blog', 'food recipe', 'alt text', 'marketing copy']:
raise ValidationError(message="Please select a valid writing type: Blog, Food Recipe, Alt Text, or Marketing Copy.")
class FileTypeValidator(Validator):
def validate(self, document):
file_path = document.text.strip()
if not os.path.exists(file_path):
raise ValidationError(message="File does not exist.")
elif not self.is_valid_file_type(file_path):
raise ValidationError(message="Unsupported file type or MIME type. Please select an image file.")
def is_valid_file_type(self, file_path):
# Define supported MIME types for image files
supported_types = ['image/png', 'image/jpeg', 'image/webp', 'image/heic', 'image/heif']
file_mime_type = self.get_file_mime_type(file_path)
return file_mime_type in supported_types
def get_file_mime_type(self, file_path):
# Placeholder function to get the MIME type of the file
# You can use libraries like magic or mimetypes for this purpose
# Example:
# import magic
# mime = magic.Magic(mime=True)
# return mime.from_file(file_path)
return 'image/png' # Placeholder value for demonstration
def competitor_analysis():
text = "_______________________________________________________________________\n"