ALwrity Prompts - AI Integration Plan
This commit is contained in:
@@ -3,6 +3,7 @@ import sys
|
||||
import time
|
||||
import datetime
|
||||
import base64
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
@@ -12,8 +13,8 @@ import logging
|
||||
from ...api_key_manager import APIKeyManager
|
||||
|
||||
try:
|
||||
import google.generativeai as genai
|
||||
from google.generativeai import types
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except ImportError:
|
||||
genai = None
|
||||
logging.getLogger('gemini_image_generator').warning(
|
||||
@@ -30,6 +31,24 @@ logging.basicConfig(
|
||||
)
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
|
||||
# Imagen fallback configuration.
# Every setting is read from environment variables once, when this module is
# imported; change the environment and restart to alter them.


def _read_imagen_max_images(default: int = 1) -> int:
    """Return the number of images to request from Imagen per call.

    Reads IMAGEN_MAX_IMAGES from the environment. A missing or non-integer
    value yields *default* instead of raising, so a typo in the environment
    cannot prevent this module from importing. The result is clamped to the
    1-4 range that the Imagen API documents for ``number_of_images``.
    """
    raw_value = os.getenv('IMAGEN_MAX_IMAGES')
    if raw_value is None:
        return default
    try:
        requested = int(raw_value)
    except ValueError:
        logger.warning(
            f"Invalid IMAGEN_MAX_IMAGES value {raw_value!r}; using {default}"
        )
        return default
    return max(1, min(4, requested))


IMAGEN_FALLBACK_CONFIG = {
    # Master switch: when False the Imagen fallback is never attempted.
    'enabled': os.getenv('IMAGEN_FALLBACK_ENABLED', 'true').lower() == 'true',
    # Whether Gemini failures should trigger the fallback automatically.
    'auto_fallback': os.getenv('IMAGEN_AUTO_FALLBACK', 'true').lower() == 'true',
    # Imagen model id passed to generate_images().
    'preferred_model': os.getenv('IMAGEN_MODEL', 'imagen-4.0-generate-001'),
    # Requested aspect ratio -> aspect ratio sent to Imagen. Ratios missing
    # from this table are replaced with "1:1" by the fallback generator.
    'fallback_aspect_ratios': {
        '1:1': '1:1',
        '3:4': '3:4',
        '4:3': '4:3',
        '9:16': '9:16',
        '16:9': '16:9',
    },
    # Images requested per call (defaults to 1, clamped to 1-4).
    'max_images': _read_imagen_max_images(),
}

# Log configuration on startup
logger.info(f"🔄 Imagen fallback configuration: {IMAGEN_FALLBACK_CONFIG}")
|
||||
|
||||
# With image generation in Gemini, your imagination is the limit.
|
||||
# Follow Google AI best practices for detailed prompts and iterative refinement.
|
||||
|
||||
@@ -173,13 +192,137 @@ def _ensure_client() -> Optional[object]:
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
if not api_key or genai is None:
|
||||
if not api_key:
|
||||
logger.warning("No Gemini API key found")
|
||||
if genai is None:
|
||||
logger.warning("Google Generative AI library not available")
|
||||
return None
|
||||
try:
|
||||
return genai.Client(api_key=api_key)
|
||||
except Exception:
|
||||
logger.info("Creating Gemini client...")
|
||||
# Create a client using the correct API pattern
|
||||
# The API key is passed directly to the Client constructor
|
||||
client = genai.Client(api_key=api_key)
|
||||
logger.info("Gemini client created successfully")
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Gemini client: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
|
||||
def _generate_imagen_images_base64(prompt: str, aspect_ratio: str = "1:1") -> List[str]:
    """
    Generate images with the Imagen API as a fallback method.

    Follows the Imagen guide: https://ai.google.dev/gemini-api/docs/imagen

    Args:
        prompt: Text prompt for image generation. It is passed through
            _optimize_prompt_for_imagen() before being sent.
        aspect_ratio: Desired aspect ratio (1:1, 3:4, 4:3, 9:16, 16:9).
            Values missing from IMAGEN_FALLBACK_CONFIG['fallback_aspect_ratios']
            fall back to "1:1".

    Returns:
        List of base64-encoded image payloads (documented by the original
        author as PNG). The list is empty when the SDK or API key is
        missing, when the API returns no usable image, or when any error
        occurs; this function never raises.
    """
    logger = logging.getLogger('gemini_image_generator')
    logger.info("🔄 Falling back to Imagen API for image generation")

    # The guarded import at module top leaves genai as None when the SDK is
    # not installed; report that plainly instead of via an AttributeError.
    if genai is None:
        logger.error("Google GenAI library not available for Imagen fallback")
        return []

    try:
        # Imagen is reached with the same API key as Gemini.
        api_key = APIKeyManager().get_api_key("gemini")
        if not api_key:
            logger.error("No API key available for Imagen fallback")
            return []

        client = genai.Client(api_key=api_key)

        # Map aspect ratio to Imagen format using configuration
        imagen_aspect_ratio = IMAGEN_FALLBACK_CONFIG['fallback_aspect_ratios'].get(
            aspect_ratio, "1:1"
        )
        model_name = IMAGEN_FALLBACK_CONFIG['preferred_model']

        # Optimize prompt for Imagen (remove Gemini-specific formatting)
        imagen_prompt = _optimize_prompt_for_imagen(prompt)

        logger.info(f"Generating Imagen images with prompt: {imagen_prompt[:100]}...")
        logger.info(f"Using aspect ratio: {imagen_aspect_ratio}")
        logger.info(f"Using model: {model_name}")

        # Note: sample_image_size is not supported in current library version
        config_params = {
            'number_of_images': IMAGEN_FALLBACK_CONFIG['max_images'],
            'aspect_ratio': imagen_aspect_ratio,
        }

        response = client.models.generate_images(
            model=model_name,
            prompt=imagen_prompt,
            config=types.GenerateImagesConfig(**config_params),
        )

        images_b64: List[str] = []
        # generated_images may be missing or None when nothing was produced.
        for generated_image in getattr(response, 'generated_images', None) or []:
            image = getattr(generated_image, 'image', None)
            image_bytes = getattr(image, 'image_bytes', None)
            if not image_bytes:
                # Skip entries without payload; str(None) must never be
                # returned to callers as if it were image data.
                continue
            if isinstance(image_bytes, (bytes, bytearray)):
                images_b64.append(base64.b64encode(image_bytes).decode('utf-8'))
            else:
                # Payload already delivered as a base64 string
                images_b64.append(str(image_bytes))

        if not images_b64:
            logger.warning("Imagen fallback returned no images")
            return []

        logger.info(f"✅ Imagen fallback successful! Generated {len(images_b64)} images")
        return images_b64

    except Exception as e:
        # Top-level boundary for the fallback path: log with traceback and
        # degrade to "no images" so callers keep their List[str] contract.
        logger.exception(f"❌ Imagen fallback failed: {e}")
        return []
|
||||
|
||||
|
||||
def _optimize_prompt_for_imagen(prompt: str) -> str:
|
||||
"""
|
||||
Optimize prompt for Imagen API by removing Gemini-specific formatting
|
||||
and enhancing it with Imagen best practices.
|
||||
|
||||
Based on Imagen prompt guide: https://ai.google.dev/gemini-api/docs/imagen
|
||||
"""
|
||||
# Remove Gemini-specific formatting
|
||||
prompt = prompt.replace('\n\nEnhanced prompt:', '')
|
||||
prompt = prompt.replace('\n\nAspect ratio:', '')
|
||||
|
||||
# Clean up extra whitespace
|
||||
prompt = ' '.join(prompt.split())
|
||||
|
||||
# Add Imagen-specific enhancements if not present
|
||||
if 'professional' in prompt.lower() and 'linkedin' in prompt.lower():
|
||||
# Enhance for LinkedIn professional content
|
||||
prompt += ", high quality, professional photography, business appropriate"
|
||||
|
||||
if 'digital transformation' in prompt.lower() or 'technology' in prompt.lower():
|
||||
# Enhance for tech content
|
||||
prompt += ", modern, innovative, clean design, corporate aesthetic"
|
||||
|
||||
# Ensure prompt doesn't exceed Imagen's 480 token limit
|
||||
if len(prompt) > 400: # Leave some buffer
|
||||
prompt = prompt[:400] + "..."
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def generate_gemini_images_base64(
|
||||
prompt: str,
|
||||
*,
|
||||
@@ -190,17 +333,23 @@ def generate_gemini_images_base64(
|
||||
aspect_ratio: str = "9:16",
|
||||
max_retries: int = 2,
|
||||
initial_retry_delay: float = 1.0,
|
||||
enable_imagen_fallback: bool = True,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Return list of base64 PNG images generated from a prompt.
|
||||
|
||||
Primary method: Gemini API for image generation
|
||||
Fallback method: Imagen API when Gemini fails (quota limits, API errors, etc.)
|
||||
|
||||
Implements best practices per Gemini docs: send text prompt, parse inline image parts,
|
||||
and return base64 data suitable for API responses. No Streamlit, no printing.
|
||||
|
||||
Docs: https://ai.google.dev/gemini-api/docs/image-generation
|
||||
Docs:
|
||||
- Gemini: https://ai.google.dev/gemini-api/docs/image-generation
|
||||
- Imagen: https://ai.google.dev/gemini-api/docs/imagen
|
||||
"""
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
logger.info("Generating image (base64) with Gemini")
|
||||
logger.info("Generating image (base64) with Gemini (with Imagen fallback)")
|
||||
|
||||
if enhance_prompt and keywords:
|
||||
pg = AIPromptGenerator()
|
||||
@@ -215,9 +364,13 @@ def generate_gemini_images_base64(
|
||||
if aspect_ratio:
|
||||
prompt = f"{prompt}\n\nAspect ratio: {aspect_ratio}"
|
||||
|
||||
# Try Gemini first
|
||||
client = _ensure_client()
|
||||
if client is None:
|
||||
logger.warning("Gemini client not available or API key missing")
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
logger.info("Falling back to Imagen API")
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
retry = 0
|
||||
@@ -225,9 +378,10 @@ def generate_gemini_images_base64(
|
||||
while retry <= max_retries:
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.5-flash-image-preview",
|
||||
model="gemini-2.0-flash-exp-image-generation",
|
||||
contents=[prompt],
|
||||
)
|
||||
|
||||
images_b64: List[str] = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
if getattr(part, 'inline_data', None) is not None:
|
||||
@@ -239,16 +393,47 @@ def generate_gemini_images_base64(
|
||||
else:
|
||||
# Some SDKs may already present base64 str
|
||||
images_b64.append(str(raw))
|
||||
return images_b64
|
||||
|
||||
if images_b64:
|
||||
logger.info(f"✅ Gemini generated {len(images_b64)} images successfully")
|
||||
return images_b64
|
||||
else:
|
||||
logger.warning("Gemini returned no images, falling back to Imagen")
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
logger.warning(f"Gemini image gen error: {msg}")
|
||||
|
||||
# Check if this is a quota/API error that warrants fallback
|
||||
if any(error_type in msg.lower() for error_type in [
|
||||
'quota', 'resource_exhausted', 'rate_limit', 'billing', 'api_key', '403', '429'
|
||||
]):
|
||||
logger.info("Gemini quota/API error detected, falling back to Imagen")
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
# For other errors, retry if possible
|
||||
if "503" in msg and retry < max_retries:
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
retry += 1
|
||||
continue
|
||||
|
||||
# Final fallback for any other errors
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
logger.info("Final fallback to Imagen due to Gemini error")
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
# If all retries exhausted, fall back to Imagen
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
logger.info("All Gemini retries exhausted, falling back to Imagen")
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
|
||||
def generate_gemini_image(
|
||||
@@ -260,9 +445,12 @@ def generate_gemini_image(
|
||||
max_retries=2,
|
||||
initial_retry_delay=1.0,
|
||||
aspect_ratio="9:16",
|
||||
enable_imagen_fallback=True,
|
||||
):
|
||||
"""
|
||||
Backward-compatible wrapper that generates a single image file on disk and returns path.
|
||||
Now includes Imagen fallback for improved reliability.
|
||||
|
||||
Prefer generate_gemini_images_base64 in new code paths.
|
||||
"""
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
@@ -275,20 +463,31 @@ def generate_gemini_image(
|
||||
aspect_ratio=aspect_ratio,
|
||||
max_retries=max_retries,
|
||||
initial_retry_delay=initial_retry_delay,
|
||||
enable_imagen_fallback=enable_imagen_fallback,
|
||||
)
|
||||
if not images:
|
||||
return None
|
||||
|
||||
# Persist first image to file for legacy callers
|
||||
img_b64 = images[0]
|
||||
img_bytes = base64.b64decode(img_b64)
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
out_name = f'gemini-native-image-{datetime.datetime.now().strftime("%Y%m%d-%H%M%S")}.png'
|
||||
|
||||
# Update filename to indicate which API was used
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
if 'imagen' in prompt.lower() or 'fallback' in prompt.lower():
|
||||
out_name = f'imagen-fallback-image-{timestamp}.png'
|
||||
else:
|
||||
out_name = f'gemini-native-image-{timestamp}.png'
|
||||
|
||||
try:
|
||||
img.save(out_name)
|
||||
# Also call save_generated_image to reuse existing pipeline
|
||||
save_generated_image({"artifacts": [{"base64": img_b64}]})
|
||||
logger.info(f"✅ Image saved successfully: {out_name}")
|
||||
return out_name
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to save image: {e}")
|
||||
return None
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user