ALwrity Facebook Writer CopilotKit Implementation Plan
This commit is contained in:
@@ -2,11 +2,11 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import datetime
|
||||
import streamlit as st
|
||||
import base64
|
||||
from typing import List, Optional, Tuple
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
from loguru import logger
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
import logging
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...api_key_manager import APIKeyManager
|
||||
@@ -16,7 +16,9 @@ try:
|
||||
from google.generativeai import types
|
||||
except ImportError:
|
||||
genai = None
|
||||
logger.warning("Google genai library not available. Install with: pip install google-generativeai")
|
||||
logging.getLogger('gemini_image_generator').warning(
|
||||
"Google genai library not available. Install with: pip install google-generativeai"
|
||||
)
|
||||
|
||||
|
||||
from .save_image import save_generated_image
|
||||
@@ -28,9 +30,8 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
|
||||
# With image generation in Gemini, your imagination is the limit.
|
||||
# If what you see doesn't quite match what you had in mind, try adding more details to the prompt.
|
||||
# The more specific you are, the better Gemini can create images that reflect your vision.
|
||||
# With image generation in Gemini, your imagination is the limit.
|
||||
# Follow Google AI best practices for detailed prompts and iterative refinement.
|
||||
|
||||
# Generate images using Gemini
|
||||
# Gemini 2.0 Flash Experimental supports the ability to output text and inline images.
|
||||
@@ -167,161 +168,131 @@ class AIPromptGenerator:
|
||||
|
||||
return ", ".join(prompt_parts)
|
||||
|
||||
|
||||
def generate_gemini_image(prompt, keywords=None, style=None, focus=None, enhance_prompt=True, max_retries=3, initial_retry_delay=2, aspect_ratio="16:9"):
|
||||
"""
|
||||
Generate an image using Gemini's image generation capabilities.
|
||||
|
||||
Args:
|
||||
prompt (str): The text prompt for image generation
|
||||
keywords (list, optional): Keywords to enhance the prompt
|
||||
style (str, optional): Style of the image (photorealistic, artistic, etc.)
|
||||
focus (str, optional): Focus area for photorealistic images
|
||||
enhance_prompt (bool, optional): Whether to enhance the prompt with AI
|
||||
max_retries (int, optional): Maximum number of retry attempts
|
||||
initial_retry_delay (int, optional): Initial delay between retries
|
||||
aspect_ratio (str, optional): Aspect ratio for the generated image
|
||||
|
||||
Returns:
|
||||
str: The path to the generated image.
|
||||
"""
|
||||
logger.info(f"Generating image with prompt: '{prompt[:100]}...'")
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
def _ensure_client() -> Optional[object]:
|
||||
"""Create a Gemini client if available and API key is configured."""
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
|
||||
if not api_key:
|
||||
error_msg = "Gemini API key not found. Please configure it in the onboarding process."
|
||||
logger.error(error_msg)
|
||||
st.error(f"🔑 {error_msg}")
|
||||
if not api_key or genai is None:
|
||||
return None
|
||||
|
||||
# Enhance the prompt if requested
|
||||
try:
|
||||
return genai.Client(api_key=api_key)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def generate_gemini_images_base64(
|
||||
prompt: str,
|
||||
*,
|
||||
keywords: Optional[list] = None,
|
||||
style: Optional[str] = None,
|
||||
focus: Optional[str] = None,
|
||||
enhance_prompt: bool = True,
|
||||
aspect_ratio: str = "9:16",
|
||||
max_retries: int = 2,
|
||||
initial_retry_delay: float = 1.0,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Return list of base64 PNG images generated from a prompt.
|
||||
|
||||
Implements best practices per Gemini docs: send text prompt, parse inline image parts,
|
||||
and return base64 data suitable for API responses. No Streamlit, no printing.
|
||||
|
||||
Docs: https://ai.google.dev/gemini-api/docs/image-generation
|
||||
"""
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
logger.info("Generating image (base64) with Gemini")
|
||||
|
||||
if enhance_prompt and keywords:
|
||||
prompt_generator = AIPromptGenerator()
|
||||
if style == "photorealistic" and focus:
|
||||
logger.info(f"Generating photorealistic prompt with focus: {focus}")
|
||||
enhanced_prompt = prompt_generator.generate_photorealistic_prompt(keywords, focus)
|
||||
else:
|
||||
logger.info("Generating enhanced prompt")
|
||||
enhanced_prompt = prompt_generator.generate_prompt(keywords)
|
||||
|
||||
# Combine the enhanced prompt with the original prompt
|
||||
prompt = f"{prompt}\n\nEnhanced prompt: {enhanced_prompt}"
|
||||
logger.info(f"Final prompt: '{prompt[:100]}...'")
|
||||
|
||||
# Add aspect ratio to the prompt
|
||||
pg = AIPromptGenerator()
|
||||
enhanced = (
|
||||
pg.generate_photorealistic_prompt(keywords, focus)
|
||||
if style == "photorealistic" and focus
|
||||
else pg.generate_prompt(keywords)
|
||||
)
|
||||
prompt = f"{prompt}\n\nEnhanced prompt: {enhanced}"
|
||||
|
||||
# Optional hint in-text for aspect ratio; API doesn't take ratio param directly
|
||||
if aspect_ratio:
|
||||
prompt += f"\n\nPlease generate the image with {aspect_ratio} aspect ratio."
|
||||
|
||||
retry_count = 0
|
||||
retry_delay = initial_retry_delay
|
||||
|
||||
while retry_count <= max_retries:
|
||||
prompt = f"{prompt}\n\nAspect ratio: {aspect_ratio}"
|
||||
|
||||
client = _ensure_client()
|
||||
if client is None:
|
||||
logger.warning("Gemini client not available or API key missing")
|
||||
return []
|
||||
|
||||
retry = 0
|
||||
delay = initial_retry_delay
|
||||
while retry <= max_retries:
|
||||
try:
|
||||
client = genai.Client(api_key=api_key)
|
||||
contents = (prompt)
|
||||
|
||||
logger.info("Sending request to Gemini API")
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash-exp-image-generation",
|
||||
contents=contents,
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['Text', 'Image']
|
||||
)
|
||||
model="gemini-2.5-flash-image-preview",
|
||||
contents=[prompt],
|
||||
)
|
||||
logger.info("Received response from Gemini API")
|
||||
|
||||
img_name = None
|
||||
images_b64: List[str] = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.text is not None:
|
||||
logger.info(f"Received text response: '{part.text[:100]}...'")
|
||||
print(part.text)
|
||||
elif part.inline_data is not None:
|
||||
logger.info("Received image data from Gemini")
|
||||
image = Image.open(BytesIO((part.inline_data.data)))
|
||||
|
||||
# Resize image to match aspect ratio if needed
|
||||
if aspect_ratio:
|
||||
current_width, current_height = image.size
|
||||
target_width = current_width
|
||||
target_height = current_height
|
||||
|
||||
# Calculate target dimensions based on aspect ratio
|
||||
if aspect_ratio == "16:9":
|
||||
target_height = int(current_width * 9/16)
|
||||
elif aspect_ratio == "9:16":
|
||||
target_width = int(current_height * 9/16)
|
||||
elif aspect_ratio == "4:3":
|
||||
target_height = int(current_width * 3/4)
|
||||
elif aspect_ratio == "3:4":
|
||||
target_width = int(current_height * 3/4)
|
||||
elif aspect_ratio == "1:1":
|
||||
target_size = min(current_width, current_height)
|
||||
target_width = target_size
|
||||
target_height = target_size
|
||||
|
||||
logger.info(f"Resizing image from {current_width}x{current_height} to {target_width}x{target_height}")
|
||||
|
||||
# Create a new image with the target dimensions
|
||||
resized_image = Image.new('RGB', (target_width, target_height), (255, 255, 255))
|
||||
|
||||
# Calculate position to paste the original image
|
||||
paste_x = (target_width - current_width) // 2
|
||||
paste_y = (target_height - current_height) // 2
|
||||
|
||||
# Paste the original image onto the new canvas
|
||||
resized_image.paste(image, (paste_x, paste_y))
|
||||
image = resized_image
|
||||
|
||||
if part.text is not None:
|
||||
img_name = f'{part.text}-gemini-native-image.png'
|
||||
if getattr(part, 'inline_data', None) is not None:
|
||||
# part.inline_data.data is bytes (base64 decoded by SDK?)
|
||||
# Standardize to base64 string for API consumers
|
||||
raw = part.inline_data.data
|
||||
if isinstance(raw, bytes):
|
||||
images_b64.append(base64.b64encode(raw).decode('utf-8'))
|
||||
else:
|
||||
img_name = f'gemini-native-image-{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}.png'
|
||||
try:
|
||||
logger.info(f"Saving image to: {img_name}")
|
||||
image.save(img_name)
|
||||
|
||||
# Create a dictionary with the expected format for save_generated_image
|
||||
img_response = {
|
||||
"artifacts": [
|
||||
{
|
||||
"base64": base64.b64encode(open(img_name, "rb").read()).decode('utf-8')
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Call save_generated_image with the correct format
|
||||
save_generated_image(img_response)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to save image: {err}")
|
||||
st.error(f"Failed to save image: {err}")
|
||||
|
||||
logger.info(f"Image generation completed. Image name: {img_name}")
|
||||
return img_name
|
||||
except Exception as err:
|
||||
error_message = str(err)
|
||||
logger.error(f"Error in generate_gemini_image: {err}")
|
||||
|
||||
# Check if this is a 503 UNAVAILABLE error
|
||||
if "503 UNAVAILABLE" in error_message and retry_count < max_retries:
|
||||
retry_count += 1
|
||||
logger.info(f"Model is overloaded. Retrying in {retry_delay} seconds (attempt {retry_count}/{max_retries})")
|
||||
st.warning(f"The image generation service is currently busy. Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
# Exponential backoff
|
||||
retry_delay *= 2
|
||||
else:
|
||||
st.error(f"Error generating image: {err}")
|
||||
return None
|
||||
|
||||
# If we've exhausted all retries
|
||||
st.error("The image generation service is currently unavailable. Please try again later.")
|
||||
return None
|
||||
# Some SDKs may already present base64 str
|
||||
images_b64.append(str(raw))
|
||||
return images_b64
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
logger.warning(f"Gemini image gen error: {msg}")
|
||||
if "503" in msg and retry < max_retries:
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
retry += 1
|
||||
continue
|
||||
return []
|
||||
|
||||
|
||||
def edit_image(image_path, prompt, max_retries=3, initial_retry_delay=2):
|
||||
def generate_gemini_image(
|
||||
prompt,
|
||||
keywords=None,
|
||||
style=None,
|
||||
focus=None,
|
||||
enhance_prompt=True,
|
||||
max_retries=2,
|
||||
initial_retry_delay=1.0,
|
||||
aspect_ratio="9:16",
|
||||
):
|
||||
"""
|
||||
Backward-compatible wrapper that generates a single image file on disk and returns path.
|
||||
Prefer generate_gemini_images_base64 in new code paths.
|
||||
"""
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
images = generate_gemini_images_base64(
|
||||
prompt,
|
||||
keywords=keywords,
|
||||
style=style,
|
||||
focus=focus,
|
||||
enhance_prompt=enhance_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
max_retries=max_retries,
|
||||
initial_retry_delay=initial_retry_delay,
|
||||
)
|
||||
if not images:
|
||||
return None
|
||||
# Persist first image to file for legacy callers
|
||||
img_b64 = images[0]
|
||||
img_bytes = base64.b64decode(img_b64)
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
out_name = f'gemini-native-image-{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}.png'
|
||||
try:
|
||||
img.save(out_name)
|
||||
# Also call save_generated_image to reuse existing pipeline
|
||||
save_generated_image({"artifacts": [{"base64": img_b64}]})
|
||||
return out_name
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def edit_image(image_path, prompt, max_retries=2, initial_retry_delay=1.0):
|
||||
"""
|
||||
- Image editing (text and image to image)
|
||||
Example prompt: "Edit this image to make it look like a cartoon"
|
||||
@@ -352,7 +323,9 @@ def edit_image(image_path, prompt, max_retries=3, initial_retry_delay=2):
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
client = genai.Client()
|
||||
client = _ensure_client()
|
||||
if client is None:
|
||||
return None
|
||||
text_input = (prompt)
|
||||
|
||||
logger.info("Sending request to Gemini API for image editing")
|
||||
@@ -367,13 +340,9 @@ def edit_image(image_path, prompt, max_retries=3, initial_retry_delay=2):
|
||||
|
||||
edited_img_name = None
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.text is not None:
|
||||
logger.info(f"Received text response: '{part.text[:100]}...'")
|
||||
st.write(part.text)
|
||||
elif part.inline_data is not None:
|
||||
if getattr(part, 'inline_data', None) is not None:
|
||||
logger.info("Received edited image data from Gemini")
|
||||
edited_image = Image.open(BytesIO(part.inline_data.data))
|
||||
edited_image.show()
|
||||
|
||||
# Save the edited image
|
||||
edited_img_name = f'edited-{os.path.basename(image_path)}'
|
||||
@@ -394,28 +363,22 @@ def edit_image(image_path, prompt, max_retries=3, initial_retry_delay=2):
|
||||
save_generated_image(img_response)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to save edited image: {err}")
|
||||
st.error(f"Failed to save edited image: {err}")
|
||||
|
||||
logger.info(f"Image editing completed. Edited image name: {edited_img_name}")
|
||||
return edited_img_name
|
||||
except Exception as err:
|
||||
error_message = str(err)
|
||||
logger.error(f"Error in edit_image: {err}")
|
||||
|
||||
# Check if this is a 503 UNAVAILABLE error
|
||||
if "503 UNAVAILABLE" in error_message and retry_count < max_retries:
|
||||
# Retry on transient 503
|
||||
if "503" in error_message and retry_count < max_retries:
|
||||
retry_count += 1
|
||||
logger.info(f"Model is overloaded. Retrying in {retry_delay} seconds (attempt {retry_count}/{max_retries})")
|
||||
st.warning(f"The image editing service is currently busy. Retrying in {retry_delay} seconds...")
|
||||
logger.info(f"Retrying in {retry_delay} seconds (attempt {retry_count}/{max_retries})")
|
||||
time.sleep(retry_delay)
|
||||
# Exponential backoff
|
||||
retry_delay *= 2
|
||||
else:
|
||||
st.error(f"Error editing image: {err}")
|
||||
return None
|
||||
|
||||
# If we've exhausted all retries
|
||||
st.error("The image editing service is currently unavailable. Please try again later.")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user