Added image generation to blog writer
This commit is contained in:
15
backend/services/llm_providers/image_generation/__init__.py
Normal file
15
backend/services/llm_providers/image_generation/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from .hf_provider import HuggingFaceImageProvider
|
||||
from .gemini_provider import GeminiImageProvider
|
||||
from .stability_provider import StabilityImageProvider
|
||||
|
||||
__all__ = [
|
||||
"ImageGenerationOptions",
|
||||
"ImageGenerationResult",
|
||||
"ImageGenerationProvider",
|
||||
"HuggingFaceImageProvider",
|
||||
"GeminiImageProvider",
|
||||
"StabilityImageProvider",
|
||||
]
|
||||
|
||||
|
||||
37
backend/services/llm_providers/image_generation/base.py
Normal file
37
backend/services/llm_providers/image_generation/base.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, Protocol
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationOptions:
|
||||
prompt: str
|
||||
negative_prompt: Optional[str] = None
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
model: Optional[str] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationResult:
|
||||
image_bytes: bytes
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class ImageGenerationProvider(Protocol):
|
||||
"""Protocol for image generation providers."""
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
...
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.gemini")
|
||||
|
||||
|
||||
class GeminiImageProvider(ImageGenerationProvider):
|
||||
"""Google Gemini/Imagen backed image generation.
|
||||
|
||||
NOTE: Implementation should call the actual Gemini Images API used in the codebase.
|
||||
Here we keep a minimal interface and expect the underlying client to be wired
|
||||
similarly to other providers and return a PIL image or raw bytes.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
api_key = os.getenv("GOOGLE_API_KEY")
|
||||
if not api_key:
|
||||
logger.warning("GOOGLE_API_KEY not set. Gemini image generation may fail at runtime.")
|
||||
logger.info("GeminiImageProvider initialized")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
# Placeholder implementation to be replaced by real Gemini/Imagen call.
|
||||
# For now, generate a 1x1 transparent PNG to maintain interface consistency
|
||||
img = Image.new("RGBA", (max(1, options.width), max(1, options.height)), (0, 0, 0, 0))
|
||||
with io.BytesIO() as buf:
|
||||
img.save(buf, format="PNG")
|
||||
png = buf.getvalue()
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=png,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
provider="gemini",
|
||||
model=os.getenv("GEMINI_IMAGE_MODEL"),
|
||||
seed=options.seed,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from PIL import Image
|
||||
from huggingface_hub import InferenceClient
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.huggingface")
|
||||
|
||||
|
||||
DEFAULT_HF_MODEL = os.getenv(
|
||||
"HF_IMAGE_MODEL",
|
||||
"black-forest-labs/FLUX.1-Krea-dev",
|
||||
)
|
||||
|
||||
|
||||
class HuggingFaceImageProvider(ImageGenerationProvider):
|
||||
"""Hugging Face Inference Providers (fal-ai) backed image generation.
|
||||
|
||||
API doc: https://huggingface.co/docs/inference-providers/en/tasks/text-to-image
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, provider: str = "fal-ai") -> None:
|
||||
self.api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not self.api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image generation")
|
||||
self.provider = provider
|
||||
self.client = InferenceClient(provider=self.provider, api_key=self.api_key)
|
||||
logger.info("HuggingFaceImageProvider initialized (provider=%s)", self.provider)
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
model = options.model or DEFAULT_HF_MODEL
|
||||
params: Dict[str, Any] = {}
|
||||
if options.guidance_scale is not None:
|
||||
params["guidance_scale"] = options.guidance_scale
|
||||
if options.steps is not None:
|
||||
params["num_inference_steps"] = options.steps
|
||||
if options.negative_prompt:
|
||||
params["negative_prompt"] = options.negative_prompt
|
||||
if options.seed is not None:
|
||||
params["seed"] = options.seed
|
||||
|
||||
# The HF InferenceClient returns a PIL Image
|
||||
logger.debug("HF generate: model=%s width=%s height=%s params=%s", model, options.width, options.height, params)
|
||||
img: Image.Image = self.client.text_to_image(
|
||||
options.prompt,
|
||||
model=model,
|
||||
width=options.width,
|
||||
height=options.height,
|
||||
**params,
|
||||
)
|
||||
|
||||
with io.BytesIO() as buf:
|
||||
img.save(buf, format="PNG")
|
||||
image_bytes = buf.getvalue()
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
metadata={"provider": self.provider},
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,79 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
import requests
|
||||
from PIL import Image
|
||||
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.stability")
|
||||
|
||||
|
||||
DEFAULT_STABILITY_MODEL = os.getenv("STABILITY_MODEL", "stable-diffusion-xl-1024-v1-0")
|
||||
|
||||
|
||||
class StabilityImageProvider(ImageGenerationProvider):
|
||||
"""Stability AI Images API provider (simple text-to-image).
|
||||
|
||||
This uses the v1 text-to-image endpoint format. Adjust to match your existing
|
||||
Stability integration if different.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None) -> None:
|
||||
self.api_key = api_key or os.getenv("STABILITY_API_KEY")
|
||||
if not self.api_key:
|
||||
logger.warning("STABILITY_API_KEY not set. Stability generation may fail at runtime.")
|
||||
logger.info("StabilityImageProvider initialized")
|
||||
|
||||
def generate(self, options: ImageGenerationOptions) -> ImageGenerationResult:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Accept": "application/json",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload: Dict[str, Any] = {
|
||||
"text_prompts": [
|
||||
{"text": options.prompt, "weight": 1.0},
|
||||
],
|
||||
"cfg_scale": options.guidance_scale or 7.0,
|
||||
"steps": options.steps or 30,
|
||||
"width": options.width,
|
||||
"height": options.height,
|
||||
"seed": options.seed,
|
||||
}
|
||||
if options.negative_prompt:
|
||||
payload["text_prompts"].append({"text": options.negative_prompt, "weight": -1.0})
|
||||
|
||||
model = options.model or DEFAULT_STABILITY_MODEL
|
||||
url = f"https://api.stability.ai/v1/generation/{model}/text-to-image"
|
||||
|
||||
logger.debug("Stability generate: model=%s payload_keys=%s", model, list(payload.keys()))
|
||||
resp = requests.post(url, headers=headers, json=payload, timeout=60)
|
||||
resp.raise_for_status()
|
||||
data = resp.json()
|
||||
|
||||
# Expecting data["artifacts"][0]["base64"]
|
||||
import base64
|
||||
|
||||
artifact = (data.get("artifacts") or [{}])[0]
|
||||
b64 = artifact.get("base64", "")
|
||||
image_bytes = base64.b64decode(b64)
|
||||
|
||||
# Confirm dimensions by loading once (optional)
|
||||
img = Image.open(io.BytesIO(image_bytes))
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=img.width,
|
||||
height=img.height,
|
||||
provider="stability",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
)
|
||||
|
||||
|
||||
73
backend/services/llm_providers/main_image_generation.py
Normal file
73
backend/services/llm_providers/main_image_generation.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
)
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.facade")
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
if gpt_provider.startswith("hf"):
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
# Fallback to huggingface to enable a path if configured
|
||||
return "huggingface"
|
||||
|
||||
|
||||
def _get_provider(provider_name: str):
|
||||
if provider_name == "huggingface":
|
||||
return HuggingFaceImageProvider()
|
||||
if provider_name == "gemini":
|
||||
return GeminiImageProvider()
|
||||
if provider_name == "stability":
|
||||
return StabilityImageProvider()
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None) -> ImageGenerationResult:
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
|
||||
image_options = ImageGenerationOptions(
|
||||
prompt=prompt,
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
width=int(opts.get("width", 1024)),
|
||||
height=int(opts.get("height", 1024)),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
model=opts.get("model"),
|
||||
extra=opts,
|
||||
)
|
||||
|
||||
# Normalize obvious model/provider mismatches
|
||||
model_lower = (image_options.model or "").lower()
|
||||
if provider_name == "stability" and (model_lower.startswith("black-forest-labs/") or model_lower.startswith("runwayml/") or model_lower.startswith("stabilityai/flux")):
|
||||
logger.info("Remapping provider to huggingface for model=%s", image_options.model)
|
||||
provider_name = "huggingface"
|
||||
|
||||
if provider_name == "huggingface" and not image_options.model:
|
||||
# Provide a sensible default HF model if none specified
|
||||
image_options.model = "black-forest-labs/FLUX.1-Krea-dev"
|
||||
|
||||
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
||||
provider = _get_provider(provider_name)
|
||||
return provider.generate(image_options)
|
||||
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=120), stop=stop_after_attempt(6))
|
||||
def generate_dalle3_images(img_prompt, image_dir, size="1024x1024", quality="hd", n=1):
|
||||
"""
|
||||
Generates images using the DALL-E 3 model based on a given text prompt.
|
||||
|
||||
Args:
|
||||
img_prompt (str): Text prompt to generate the image.
|
||||
image_dir (str): Directory where the generated image will be saved.
|
||||
size (str, optional): Size of the generated images. Defaults to "1024x1024".
|
||||
quality (str, optional): Quality of the generated images. Defaults to "hd".
|
||||
n (int, optional): Number of images to generate. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
str: Path to the saved image.
|
||||
|
||||
Raises:
|
||||
SystemExit: If an error occurs in image generation or saving.
|
||||
"""
|
||||
try:
|
||||
logger.info("Generating Dall-e-3 image for the blog.")
|
||||
client = OpenAI()
|
||||
|
||||
img_generation_response = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=img_prompt,
|
||||
size=size,
|
||||
quality=quality,
|
||||
n=n
|
||||
)
|
||||
# Save the generated image locally.
|
||||
try:
|
||||
img_path = save_generated_image(img_generation_response, image_dir)
|
||||
return img_path
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to Save generated image: {err}")
|
||||
|
||||
except openai.OpenAIError as e:
|
||||
logger.error(f"Dalle-3 image generation error: HTTP Status {e.http_status}, Error: {e.error}")
|
||||
sys.exit("Exiting due to Dalle-3 image generation error.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate images with Dalle3: {e}")
|
||||
sys.exit("Exiting due to a general error in image generation.")
|
||||
@@ -1,53 +0,0 @@
|
||||
from openai import OpenAI
|
||||
from loguru import logger
|
||||
import sys
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
) # for exponential backoff
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=120), stop=stop_after_attempt(6))
|
||||
def generate_dalle3_images(img_prompt, image_dir, size="1024x1024", quality="hd", n=1):
|
||||
"""
|
||||
Generates images using the DALL-E 3 model based on a given text prompt.
|
||||
|
||||
Args:
|
||||
img_prompt (str): Text prompt to generate the image.
|
||||
image_dir (str): Directory where the generated image will be saved.
|
||||
size (str, optional): Size of the generated images. Defaults to "1024x1024".
|
||||
quality (str, optional): Quality of the generated images. Defaults to "hd".
|
||||
n (int, optional): Number of images to generate. Defaults to 1.
|
||||
|
||||
Returns:
|
||||
str: Path to the saved image.
|
||||
|
||||
Raises:
|
||||
SystemExit: If an error occurs in image generation or saving.
|
||||
"""
|
||||
try:
|
||||
logger.info("Generating Dall-e-3 image for the blog.")
|
||||
client = OpenAI()
|
||||
|
||||
img_generation_response = client.images.generate(
|
||||
model="dall-e-3",
|
||||
prompt=img_prompt,
|
||||
size=size,
|
||||
quality=quality,
|
||||
n=n
|
||||
)
|
||||
|
||||
img_path = save_generated_image(img_generation_response, image_dir)
|
||||
return img_path
|
||||
|
||||
except openai.OpenAIError as e:
|
||||
logger.error(f"Dalle-3 image generation error: HTTP Status {e.http_status}, Error: {e.error}")
|
||||
sys.exit("Exiting due to Dalle-3 image generation error.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate images with Dalle3: {e}")
|
||||
sys.exit("Exiting due to a general error in image generation.")
|
||||
@@ -1,583 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import datetime
|
||||
import base64
|
||||
import random
|
||||
from typing import List, Optional, Tuple
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import logging
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except ImportError:
|
||||
genai = None
|
||||
logging.getLogger('gemini_image_generator').warning(
|
||||
"Google genai library not available. Install with: pip install google-generativeai"
|
||||
)
|
||||
|
||||
|
||||
from .save_image import save_generated_image
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
|
||||
# Imagen fallback configuration
|
||||
IMAGEN_FALLBACK_CONFIG = {
|
||||
'enabled': os.getenv('IMAGEN_FALLBACK_ENABLED', 'true').lower() == 'true', # Master switch for Imagen fallback
|
||||
'auto_fallback': os.getenv('IMAGEN_AUTO_FALLBACK', 'true').lower() == 'true', # Automatically fall back on Gemini failures
|
||||
'preferred_model': os.getenv('IMAGEN_MODEL', 'imagen-4.0-generate-001'), # Fast model for quick generation
|
||||
'fallback_aspect_ratios': {
|
||||
'1:1': '1:1',
|
||||
'3:4': '3:4',
|
||||
'4:3': '4:3',
|
||||
'9:16': '9:16',
|
||||
'16:9': '16:9'
|
||||
},
|
||||
'max_images': int(os.getenv('IMAGEN_MAX_IMAGES', '1')), # Generate 1 image for LinkedIn posts
|
||||
}
|
||||
|
||||
# Log configuration on startup
|
||||
logger.info(f"🔄 Imagen fallback configuration: {IMAGEN_FALLBACK_CONFIG}")
|
||||
|
||||
# With image generation in Gemini, your imagination is the limit.
|
||||
# Follow Google AI best practices for detailed prompts and iterative refinement.
|
||||
|
||||
# Generate images using Gemini
|
||||
# Gemini 2.0 Flash Experimental supports the ability to output text and inline images.
|
||||
# This lets you use Gemini to conversationally edit images or generate outputs with interwoven text (for example, generating a blog post with text and images in a single turn).
|
||||
# Note: Make sure to include responseModalities: ["Text", "Image"] in your generation configuration for text and image output with gemini-2.0-flash-exp-image-generation. Image only is not allowed.
|
||||
|
||||
|
||||
class AIPromptGenerator:
|
||||
"""
|
||||
Generates enhanced AI image prompts based on user keywords,
|
||||
following the guidelines of the Imagen documentation.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.photography_styles = ["photo", "photograph"]
|
||||
self.art_styles = ["painting", "sketch", "drawing", "illustration", "digital art", "render"]
|
||||
self.art_techniques = ["technical pencil drawing", "charcoal drawing", "color pencil drawing", "pastel painting", "digital art", "art deco (poster)", "impressionist painting", "renaissance painting", "pop art"]
|
||||
self.camera_proximity = ["close-up", "zoomed out", "taken from far away"]
|
||||
self.camera_position = ["aerial", "from below"]
|
||||
self.lighting = ["natural lighting", "dramatic lighting", "warm lighting", "cold lighting", "studio lighting", "golden hour lighting"]
|
||||
self.camera_settings = ["motion blur", "soft focus", "bokeh", "portrait"]
|
||||
self.lens_types = ["35mm lens", "50mm lens", "fisheye lens", "wide angle lens", "macro lens", "telephoto lens"]
|
||||
self.film_types = ["black and white film", "polaroid"]
|
||||
self.materials = ["made of cheese", "made of paper", "made of neon tubes", "metallic", "glass", "wooden", "stone"]
|
||||
self.shapes = ["in the shape of a bird", "angular", "curved", "geometric"]
|
||||
self.quality_modifiers_general = ["high-quality", "beautiful", "stylized", "detailed", "epic", "grand"]
|
||||
self.quality_modifiers_photo = ["4K", "HDR", "studio photo", "professional photo", "photorealistic"]
|
||||
self.quality_modifiers_art = ["by a professional artist", "intricate details", "masterpiece"]
|
||||
self.aspect_ratios = ["1:1 aspect ratio", "4:3 aspect ratio", "3:4 aspect ratio", "16:9 aspect ratio", "9:16 aspect ratio"]
|
||||
self.photorealistic_modifiers = {
|
||||
"portraits": ["prime lens", "zoom lens", "24-35mm", "black and white film", "film noir", "shallow depth of field", "duotone (mention two colors)"],
|
||||
"objects": ["macro lens", "60-105mm", "high detail", "precise focusing", "controlled lighting"],
|
||||
"motion": ["telephoto zoom lens", "100-400mm", "fast shutter speed", "action shot", "movement tracking"],
|
||||
"wide-angle": ["wide-angle lens", "10-24mm", "long exposure", "sharp focus", "smooth water or clouds", "astro photography"]
|
||||
}
|
||||
|
||||
def generate_prompt(self, keywords):
|
||||
"""
|
||||
Generates an enhanced AI image prompt based on user-provided keywords.
|
||||
|
||||
Args:
|
||||
keywords (list): A list of keywords describing the desired image.
|
||||
|
||||
Returns:
|
||||
str: An enhanced AI image prompt.
|
||||
"""
|
||||
if not keywords:
|
||||
return "A beautiful image."
|
||||
|
||||
prompt_parts = []
|
||||
subject = " ".join(keywords)
|
||||
prompt_parts.append(subject)
|
||||
|
||||
# Add context and background (optional)
|
||||
context_options = ["in a detailed background", "outdoors", "indoors", "in a studio", "with a blurred background"]
|
||||
if random.random() < 0.6: # Add context with a probability
|
||||
prompt_parts.append(random.choice(context_options))
|
||||
|
||||
# Add style (optional)
|
||||
style_options = self.photography_styles + [f"{art} of" for art in self.art_styles]
|
||||
if random.random() < 0.7:
|
||||
prompt_parts.insert(0, random.choice(style_options))
|
||||
if prompt_parts[0].startswith("painting of") or prompt_parts[0].startswith("sketch of") or prompt_parts[0].startswith("drawing of"):
|
||||
if random.random() < 0.5:
|
||||
prompt_parts.append(f"in the style of {random.choice(self.art_techniques)}")
|
||||
|
||||
# Add photography modifiers (if photography style is chosen)
|
||||
if any(style in prompt_parts[0] for style in self.photography_styles):
|
||||
if random.random() < 0.4:
|
||||
prompt_parts.append(random.choice(self.camera_proximity))
|
||||
if random.random() < 0.3:
|
||||
prompt_parts.append(random.choice(self.camera_position))
|
||||
if random.random() < 0.5:
|
||||
prompt_parts.append(random.choice(self.lighting))
|
||||
if random.random() < 0.3:
|
||||
prompt_parts.append(random.choice(self.camera_settings))
|
||||
if random.random() < 0.2:
|
||||
prompt_parts.append(random.choice(self.lens_types))
|
||||
if random.random() < 0.1:
|
||||
prompt_parts.append(random.choice(self.film_types))
|
||||
|
||||
# Add shapes and materials (optional)
|
||||
if random.random() < 0.3:
|
||||
prompt_parts.append(random.choice(self.materials))
|
||||
if random.random() < 0.2:
|
||||
prompt_parts.append(random.choice(self.shapes))
|
||||
|
||||
# Add quality modifiers (optional)
|
||||
if random.random() < 0.6:
|
||||
quality_options = self.quality_modifiers_general
|
||||
if any(style in prompt_parts[0] for style in self.photography_styles):
|
||||
quality_options += self.quality_modifiers_photo
|
||||
else:
|
||||
quality_options += self.quality_modifiers_art
|
||||
prompt_parts.append(random.choice(list(set(quality_options)))) # Avoid duplicates
|
||||
|
||||
# Add aspect ratio (optional)
|
||||
if random.random() < 0.2:
|
||||
prompt_parts.append(random.choice(self.aspect_ratios))
|
||||
|
||||
return ", ".join(prompt_parts)
|
||||
|
||||
def generate_photorealistic_prompt(self, keywords, focus=""):
|
||||
"""
|
||||
Generates an enhanced AI image prompt specifically for photorealistic images.
|
||||
|
||||
Args:
|
||||
keywords (list): A list of keywords describing the desired image.
|
||||
focus (str, optional): The focus of the photorealistic image (e.g., "portraits", "objects", "motion", "wide-angle"). Defaults to "".
|
||||
|
||||
Returns:
|
||||
str: An enhanced photorealistic AI image prompt.
|
||||
"""
|
||||
if not keywords:
|
||||
return "A photorealistic image."
|
||||
|
||||
prompt_parts = ["A photo of", "photorealistic"]
|
||||
prompt_parts.append(" ".join(keywords))
|
||||
|
||||
if focus and focus in self.photorealistic_modifiers:
|
||||
modifiers = self.photorealistic_modifiers[focus]
|
||||
if modifiers:
|
||||
num_modifiers = random.randint(1, min(3, len(modifiers)))
|
||||
selected_modifiers = random.sample(modifiers, num_modifiers)
|
||||
prompt_parts.extend(selected_modifiers)
|
||||
|
||||
# Add general quality modifiers
|
||||
if random.random() < 0.5:
|
||||
prompt_parts.append(random.choice(self.quality_modifiers_photo))
|
||||
|
||||
# Add lighting
|
||||
if random.random() < 0.4:
|
||||
prompt_parts.append(random.choice(self.lighting))
|
||||
|
||||
return ", ".join(prompt_parts)
|
||||
|
||||
def _ensure_client() -> Optional[object]:
|
||||
"""Create a Gemini client if available and API key is configured."""
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini")
|
||||
if not api_key or genai is None:
|
||||
if not api_key:
|
||||
logger.warning("No Gemini API key found")
|
||||
if genai is None:
|
||||
logger.warning("Google Generative AI library not available")
|
||||
return None
|
||||
try:
|
||||
logger.info("Creating Gemini client...")
|
||||
# Create a client using the correct API pattern
|
||||
# The API key is passed directly to the Client constructor
|
||||
client = genai.Client(api_key=api_key)
|
||||
logger.info("Gemini client created successfully")
|
||||
return client
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to create Gemini client: {e}")
|
||||
import traceback
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
|
||||
def _generate_imagen_images_base64(prompt: str, aspect_ratio: str = "1:1") -> List[str]:
|
||||
"""
|
||||
Generate images using Imagen API as a fallback method.
|
||||
|
||||
This function implements the Imagen API following the official documentation:
|
||||
https://ai.google.dev/gemini-api/docs/imagen
|
||||
|
||||
Args:
|
||||
prompt: Text prompt for image generation
|
||||
aspect_ratio: Desired aspect ratio (1:1, 3:4, 4:3, 9:16, 16:9)
|
||||
|
||||
Returns:
|
||||
List of base64-encoded PNG images
|
||||
"""
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
logger.info("🔄 Falling back to Imagen API for image generation")
|
||||
|
||||
try:
|
||||
# Get API key for Imagen (can use same Gemini API key)
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("gemini") # Imagen uses same API key
|
||||
|
||||
if not api_key:
|
||||
logger.error("No API key available for Imagen fallback")
|
||||
return []
|
||||
|
||||
# Create Imagen client
|
||||
client = genai.Client(api_key=api_key)
|
||||
|
||||
# Map aspect ratio to Imagen format using configuration
|
||||
imagen_aspect_ratio = IMAGEN_FALLBACK_CONFIG['fallback_aspect_ratios'].get(aspect_ratio, "1:1")
|
||||
|
||||
# Optimize prompt for Imagen (remove Gemini-specific formatting)
|
||||
imagen_prompt = _optimize_prompt_for_imagen(prompt)
|
||||
|
||||
logger.info(f"Generating Imagen images with prompt: {imagen_prompt[:100]}...")
|
||||
logger.info(f"Using aspect ratio: {imagen_aspect_ratio}")
|
||||
logger.info(f"Using model: {IMAGEN_FALLBACK_CONFIG['preferred_model']}")
|
||||
|
||||
# Generate images using configured Imagen model
|
||||
# Note: sample_image_size is not supported in current library version
|
||||
config_params = {
|
||||
'number_of_images': IMAGEN_FALLBACK_CONFIG['max_images'],
|
||||
'aspect_ratio': imagen_aspect_ratio,
|
||||
}
|
||||
|
||||
# Add additional configuration options if needed
|
||||
# config_params['guidance_scale'] = 7.5 # Optional: control image generation quality
|
||||
# config_params['person_generation'] = 'allow_adult' # Optional: control person generation
|
||||
|
||||
response = client.models.generate_images(
|
||||
model=IMAGEN_FALLBACK_CONFIG['preferred_model'],
|
||||
prompt=imagen_prompt,
|
||||
config=types.GenerateImagesConfig(**config_params)
|
||||
)
|
||||
|
||||
# Extract base64 images from response
|
||||
images_b64: List[str] = []
|
||||
for generated_image in response.generated_images:
|
||||
if hasattr(generated_image, 'image') and hasattr(generated_image.image, 'image_bytes'):
|
||||
# Convert image bytes to base64
|
||||
image_bytes = generated_image.image.image_bytes
|
||||
if isinstance(image_bytes, bytes):
|
||||
images_b64.append(base64.b64encode(image_bytes).decode('utf-8'))
|
||||
else:
|
||||
# If already base64 string
|
||||
images_b64.append(str(image_bytes))
|
||||
|
||||
if images_b64:
|
||||
logger.info(f"✅ Imagen fallback successful! Generated {len(images_b64)} images")
|
||||
return images_b64
|
||||
else:
|
||||
logger.warning("Imagen fallback returned no images")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Imagen fallback failed: {e}")
|
||||
import traceback
|
||||
logger.error(f"Imagen error traceback: {traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
|
||||
def _optimize_prompt_for_imagen(prompt: str) -> str:
|
||||
"""
|
||||
Optimize prompt for Imagen API by removing Gemini-specific formatting
|
||||
and enhancing it with Imagen best practices.
|
||||
|
||||
Based on Imagen prompt guide: https://ai.google.dev/gemini-api/docs/imagen
|
||||
"""
|
||||
# Remove Gemini-specific formatting
|
||||
prompt = prompt.replace('\n\nEnhanced prompt:', '')
|
||||
prompt = prompt.replace('\n\nAspect ratio:', '')
|
||||
|
||||
# Clean up extra whitespace
|
||||
prompt = ' '.join(prompt.split())
|
||||
|
||||
# Add Imagen-specific enhancements if not present
|
||||
if 'professional' in prompt.lower() and 'linkedin' in prompt.lower():
|
||||
# Enhance for LinkedIn professional content
|
||||
prompt += ", high quality, professional photography, business appropriate"
|
||||
|
||||
if 'digital transformation' in prompt.lower() or 'technology' in prompt.lower():
|
||||
# Enhance for tech content
|
||||
prompt += ", modern, innovative, clean design, corporate aesthetic"
|
||||
|
||||
# Ensure prompt doesn't exceed Imagen's 480 token limit
|
||||
if len(prompt) > 400: # Leave some buffer
|
||||
prompt = prompt[:400] + "..."
|
||||
|
||||
return prompt
|
||||
|
||||
|
||||
def generate_gemini_images_base64(
|
||||
prompt: str,
|
||||
*,
|
||||
keywords: Optional[list] = None,
|
||||
style: Optional[str] = None,
|
||||
focus: Optional[str] = None,
|
||||
enhance_prompt: bool = True,
|
||||
aspect_ratio: str = "9:16",
|
||||
max_retries: int = 2,
|
||||
initial_retry_delay: float = 1.0,
|
||||
enable_imagen_fallback: bool = True,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Return list of base64 PNG images generated from a prompt.
|
||||
|
||||
Primary method: Gemini API for image generation
|
||||
Fallback method: Imagen API when Gemini fails (quota limits, API errors, etc.)
|
||||
|
||||
Implements best practices per Gemini docs: send text prompt, parse inline image parts,
|
||||
and return base64 data suitable for API responses. No Streamlit, no printing.
|
||||
|
||||
Docs:
|
||||
- Gemini: https://ai.google.dev/gemini-api/docs/image-generation
|
||||
- Imagen: https://ai.google.dev/gemini-api/docs/imagen
|
||||
"""
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
logger.info("Generating image (base64) with Gemini (with Imagen fallback)")
|
||||
|
||||
if enhance_prompt and keywords:
|
||||
pg = AIPromptGenerator()
|
||||
enhanced = (
|
||||
pg.generate_photorealistic_prompt(keywords, focus)
|
||||
if style == "photorealistic" and focus
|
||||
else pg.generate_prompt(keywords)
|
||||
)
|
||||
prompt = f"{prompt}\n\nEnhanced prompt: {enhanced}"
|
||||
|
||||
# Optional hint in-text for aspect ratio; API doesn't take ratio param directly
|
||||
if aspect_ratio:
|
||||
prompt = f"{prompt}\n\nAspect ratio: {aspect_ratio}"
|
||||
|
||||
# Try Gemini first
|
||||
client = _ensure_client()
|
||||
if client is None:
|
||||
logger.warning("Gemini client not available or API key missing")
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
logger.info("Falling back to Imagen API")
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
retry = 0
|
||||
delay = initial_retry_delay
|
||||
while retry <= max_retries:
|
||||
try:
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash-exp-image-generation",
|
||||
contents=[prompt],
|
||||
)
|
||||
|
||||
images_b64: List[str] = []
|
||||
for part in response.candidates[0].content.parts:
|
||||
if getattr(part, 'inline_data', None) is not None:
|
||||
# part.inline_data.data is bytes (base64 decoded by SDK?)
|
||||
# Standardize to base64 string for API consumers
|
||||
raw = part.inline_data.data
|
||||
if isinstance(raw, bytes):
|
||||
images_b64.append(base64.b64encode(raw).decode('utf-8'))
|
||||
else:
|
||||
# Some SDKs may already present base64 str
|
||||
images_b64.append(str(raw))
|
||||
|
||||
if images_b64:
|
||||
logger.info(f"✅ Gemini generated {len(images_b64)} images successfully")
|
||||
return images_b64
|
||||
else:
|
||||
logger.warning("Gemini returned no images, falling back to Imagen")
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
msg = str(e)
|
||||
logger.warning(f"Gemini image gen error: {msg}")
|
||||
|
||||
# Check if this is a quota/API error that warrants fallback
|
||||
if any(error_type in msg.lower() for error_type in [
|
||||
'quota', 'resource_exhausted', 'rate_limit', 'billing', 'api_key', '403', '429'
|
||||
]):
|
||||
logger.info("Gemini quota/API error detected, falling back to Imagen")
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
# For other errors, retry if possible
|
||||
if "503" in msg and retry < max_retries:
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
retry += 1
|
||||
continue
|
||||
|
||||
# Final fallback for any other errors
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
logger.info("Final fallback to Imagen due to Gemini error")
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
# If all retries exhausted, fall back to Imagen
|
||||
if enable_imagen_fallback and IMAGEN_FALLBACK_CONFIG['enabled']:
|
||||
logger.info("All Gemini retries exhausted, falling back to Imagen")
|
||||
return _generate_imagen_images_base64(prompt, aspect_ratio)
|
||||
return []
|
||||
|
||||
|
||||
def generate_gemini_image(
|
||||
prompt,
|
||||
keywords=None,
|
||||
style=None,
|
||||
focus=None,
|
||||
enhance_prompt=True,
|
||||
max_retries=2,
|
||||
initial_retry_delay=1.0,
|
||||
aspect_ratio="9:16",
|
||||
enable_imagen_fallback=True,
|
||||
):
|
||||
"""
|
||||
Backward-compatible wrapper that generates a single image file on disk and returns path.
|
||||
Now includes Imagen fallback for improved reliability.
|
||||
|
||||
Prefer generate_gemini_images_base64 in new code paths.
|
||||
"""
|
||||
logger = logging.getLogger('gemini_image_generator')
|
||||
images = generate_gemini_images_base64(
|
||||
prompt,
|
||||
keywords=keywords,
|
||||
style=style,
|
||||
focus=focus,
|
||||
enhance_prompt=enhance_prompt,
|
||||
aspect_ratio=aspect_ratio,
|
||||
max_retries=max_retries,
|
||||
initial_retry_delay=initial_retry_delay,
|
||||
enable_imagen_fallback=enable_imagen_fallback,
|
||||
)
|
||||
if not images:
|
||||
return None
|
||||
|
||||
# Persist first image to file for legacy callers
|
||||
img_b64 = images[0]
|
||||
img_bytes = base64.b64decode(img_b64)
|
||||
img = Image.open(BytesIO(img_bytes))
|
||||
|
||||
# Update filename to indicate which API was used
|
||||
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
|
||||
if 'imagen' in prompt.lower() or 'fallback' in prompt.lower():
|
||||
out_name = f'imagen-fallback-image-{timestamp}.png'
|
||||
else:
|
||||
out_name = f'gemini-native-image-{timestamp}.png'
|
||||
|
||||
try:
|
||||
img.save(out_name)
|
||||
# Also call save_generated_image to reuse existing pipeline
|
||||
save_generated_image({"artifacts": [{"base64": img_b64}]})
|
||||
logger.info(f"✅ Image saved successfully: {out_name}")
|
||||
return out_name
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Failed to save image: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def edit_image(image_path, prompt, max_retries=2, initial_retry_delay=1.0):
|
||||
"""
|
||||
- Image editing (text and image to image)
|
||||
Example prompt: "Edit this image to make it look like a cartoon"
|
||||
Example prompt: [image of a cat] + [image of a pillow] + "Create a cross stitch of my cat on this pillow."
|
||||
|
||||
- Multi-turn image editing (chat)
|
||||
Example prompts: [upload an image of a blue car.] "Turn this car into a convertible." "Now change the color to yellow."
|
||||
|
||||
Image editing with Gemini
|
||||
To perform image editing, add an image as input.
|
||||
The following example demonstrats uploading base64 encoded images.
|
||||
For multiple images and larger payloads, check the image input section.
|
||||
|
||||
Args:
|
||||
image_path (str): The path to the image to edit.
|
||||
prompt (str): The prompt to edit the image with.
|
||||
max_retries (int, optional): Maximum number of retry attempts for handling 503 errors. Defaults to 3.
|
||||
initial_retry_delay (int, optional): Initial delay in seconds before retrying. Defaults to 2.
|
||||
|
||||
Returns:
|
||||
str: The path to the edited image.
|
||||
"""
|
||||
import PIL.Image
|
||||
image = PIL.Image.open(image_path)
|
||||
|
||||
retry_count = 0
|
||||
retry_delay = initial_retry_delay
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
client = _ensure_client()
|
||||
if client is None:
|
||||
return None
|
||||
text_input = (prompt)
|
||||
|
||||
logger.info("Sending request to Gemini API for image editing")
|
||||
response = client.models.generate_content(
|
||||
model="gemini-2.0-flash-exp-image-generation",
|
||||
contents=[text_input, image],
|
||||
config=types.GenerateContentConfig(
|
||||
response_modalities=['Text', 'Image']
|
||||
)
|
||||
)
|
||||
logger.info("Received response from Gemini API for image editing")
|
||||
|
||||
edited_img_name = None
|
||||
for part in response.candidates[0].content.parts:
|
||||
if getattr(part, 'inline_data', None) is not None:
|
||||
logger.info("Received edited image data from Gemini")
|
||||
edited_image = Image.open(BytesIO(part.inline_data.data))
|
||||
|
||||
# Save the edited image
|
||||
edited_img_name = f'edited-{os.path.basename(image_path)}'
|
||||
try:
|
||||
logger.info(f"Saving edited image to: {edited_img_name}")
|
||||
edited_image.save(edited_img_name)
|
||||
|
||||
# Create a dictionary with the expected format for save_generated_image
|
||||
img_response = {
|
||||
"artifacts": [
|
||||
{
|
||||
"base64": base64.b64encode(open(edited_img_name, "rb").read()).decode('utf-8')
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
# Call save_generated_image with the correct format
|
||||
save_generated_image(img_response)
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to save edited image: {err}")
|
||||
|
||||
logger.info(f"Image editing completed. Edited image name: {edited_img_name}")
|
||||
return edited_img_name
|
||||
except Exception as err:
|
||||
error_message = str(err)
|
||||
logger.error(f"Error in edit_image: {err}")
|
||||
# Retry on transient 503
|
||||
if "503" in error_message and retry_count < max_retries:
|
||||
retry_count += 1
|
||||
logger.info(f"Retrying in {retry_delay} seconds (attempt {retry_count}/{max_retries})")
|
||||
time.sleep(retry_delay)
|
||||
# Exponential backoff
|
||||
retry_delay *= 2
|
||||
else:
|
||||
return None
|
||||
# If we've exhausted all retries
|
||||
return None
|
||||
|
||||
|
||||
@@ -1,69 +0,0 @@
|
||||
# Ensure you sign up for an account to obtain an API key:
|
||||
# https://platform.stability.ai/
|
||||
# Your API key can be found here after account creation:
|
||||
# https://platform.stability.ai/account/keys
|
||||
|
||||
import os
|
||||
import requests
|
||||
import base64
|
||||
from PIL import Image
|
||||
from io import BytesIO
|
||||
import streamlit as st
|
||||
from loguru import logger
|
||||
|
||||
# Import APIKeyManager
|
||||
from ...onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
def save_generated_image(data):
|
||||
"""Save the generated image to a file."""
|
||||
# Implementation for saving image
|
||||
pass
|
||||
|
||||
def generate_stable_diffusion_image(prompt):
|
||||
engine_id = "stable-diffusion-xl-1024-v1-0"
|
||||
api_host = os.getenv('API_HOST', 'https://api.stability.ai')
|
||||
|
||||
# Use APIKeyManager instead of direct environment variable access
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key = api_key_manager.get_api_key("stability")
|
||||
|
||||
if api_key is None:
|
||||
st.warning("Missing Stability API key. Please configure it in the onboarding process.")
|
||||
return None
|
||||
|
||||
response = requests.post(
|
||||
f"{api_host}/v1/generation/{engine_id}/text-to-image",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Accept": "application/json",
|
||||
"Authorization": f"Bearer {api_key}"
|
||||
},
|
||||
json={
|
||||
"text_prompts": [
|
||||
{
|
||||
"text": prompt
|
||||
}
|
||||
],
|
||||
"cfg_scale": 7,
|
||||
"height": 1024,
|
||||
"width": 1024,
|
||||
"samples": 1,
|
||||
"steps": 30,
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise Exception("Non-200 response: " + str(response.text))
|
||||
|
||||
data = response.json()
|
||||
img_path = save_generated_image(data)
|
||||
|
||||
for i, image in enumerate(data["artifacts"]):
|
||||
# Decode base64 image data
|
||||
img_data = base64.b64decode(image["base64"])
|
||||
# Open image using PIL
|
||||
img = Image.open(BytesIO(img_data))
|
||||
# Display the image
|
||||
img.show()
|
||||
|
||||
return img_path
|
||||
@@ -1,51 +0,0 @@
|
||||
from loguru import logger
|
||||
import sys
|
||||
from PIL import Image
|
||||
from openai import OpenAI
|
||||
|
||||
def gen_new_from_given_img(img_path, image_dir, num_img=1, img_size="1024x1024", response_format="url"):
|
||||
"""
|
||||
Generates variations of a given image using OpenAI's image variation API.
|
||||
|
||||
This function takes an existing image, processes it, and generates a specified number of new images based on it.
|
||||
These generated images are variations of the original, providing creative flexibility.
|
||||
|
||||
Args:
|
||||
img_path (str): Path to the original image file.
|
||||
image_dir (str): Directory where the generated images will be saved.
|
||||
num_img (int, optional): Number of image variations to generate. Defaults to 1.
|
||||
img_size (str, optional): Size of the generated images. Defaults to "1024x1024".
|
||||
response_format (str, optional): Format in which the generated images are returned. Defaults to "url".
|
||||
|
||||
Returns:
|
||||
str: Path to the saved image variation.
|
||||
|
||||
Raises:
|
||||
SystemExit: If a critical error occurs that prevents successful execution.
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting image variation generation for: {img_path}")
|
||||
|
||||
# Convert and prepare the image
|
||||
png = Image.open(img_path).convert('RGBA')
|
||||
background = Image.new('RGBA', png.size, (255, 255, 255))
|
||||
alpha_composite = Image.alpha_composite(background, png)
|
||||
alpha_composite.save(img_path, 'PNG', quality=80)
|
||||
logger.info("Image prepared for variation generation.")
|
||||
|
||||
client = OpenAI()
|
||||
variation_response = client.images.create_variation(
|
||||
image=open(img_path, "rb", encoding="utf-8"),
|
||||
n=num_img,
|
||||
size=img_size,
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
# Saving the generated image
|
||||
generated_image_path = save_generated_image(variation_response, image_dir)
|
||||
logger.info(f"Image variation generated and saved to: {generated_image_path}")
|
||||
return generated_image_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error occurred during image variation generation: {e}")
|
||||
sys.exit(f"Exiting due to critical error: {e}")
|
||||
@@ -1,162 +0,0 @@
|
||||
#########################################################
|
||||
#
|
||||
# This module will generate images for the blogs using APIs
|
||||
# from Dall-E and other free resources. Given a prompt, the
|
||||
# images will be stored in local directory.
|
||||
# Required: openai API key.
|
||||
#
|
||||
#########################################################
|
||||
|
||||
# imports
|
||||
import os
|
||||
import sys
|
||||
import datetime
|
||||
import streamlit as st
|
||||
|
||||
import openai # OpenAI Python library to make API calls
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("text_to_image_generation")
|
||||
|
||||
#from .gen_dali2_images
|
||||
from .gen_dali3_images import generate_dalle3_images
|
||||
from .gen_stabl_diff_img import generate_stable_diffusion_image
|
||||
from ..text_generation.main_text_generation import llm_text_gen
|
||||
from .gen_gemini_images import generate_gemini_image
|
||||
|
||||
def generate_image(user_prompt, title=None, description=None, tags=None, content=None, aspect_ratio="16:9"):
|
||||
"""
|
||||
The generation API endpoint creates an image based on a text prompt.
|
||||
|
||||
Required inputs:
|
||||
prompt (str): A text description of the desired image(s). The maximum length is 1000 characters.
|
||||
|
||||
Optional inputs:
|
||||
--> image_engine: dalle2, dalle3, stable diffusion are supported.
|
||||
--> num_images (int): The number of images to generate. Must be between 1 and 10. Defaults to 1.
|
||||
--> size (str): The size of the generated images. Must be one of "256x256", "512x512", or "1024x1024".
|
||||
Smaller images are faster. Defaults to "1024x1024".
|
||||
-->response_format (str): The format in which the generated images are returned.
|
||||
Must be one of "url" or "b64_json". Defaults to "url".
|
||||
--> user (str): A unique identifier representing your end-user, which will help OpenAI to monitor and detect abuse.
|
||||
--> aspect_ratio (str): The aspect ratio for the generated image. Must be one of "16:9", "4:3", or "1:1". Defaults to "16:9".
|
||||
"""
|
||||
# FIXME: Need to remove default value to match sidebar input.
|
||||
image_engine = 'Gemini-AI'
|
||||
image_stored_at = None
|
||||
|
||||
if user_prompt:
|
||||
try:
|
||||
# Use enhanced prompt generator with all available parameters
|
||||
img_prompt = generate_enhanced_img_prompt(user_prompt, title, description, tags, content)
|
||||
|
||||
# Add aspect ratio to the prompt
|
||||
if aspect_ratio:
|
||||
img_prompt += f"\n\nAspect ratio: {aspect_ratio}"
|
||||
|
||||
if 'Dalle3' in image_engine:
|
||||
logger.info(f"Calling Dalle3 text-to-image with prompt: {img_prompt}")
|
||||
image_stored_at = generate_dalle3_images(img_prompt)
|
||||
elif 'Stability-AI' in image_engine:
|
||||
logger.info(f"Calling Stable diffusion text-to-image with prompt: \n{img_prompt}")
|
||||
image_stored_at = generate_stable_diffusion_image(img_prompt)
|
||||
elif 'Gemini-AI' in image_engine:
|
||||
logger.info(f"Calling Gemini text-to-image with prompt: \n{img_prompt}")
|
||||
image_stored_at = generate_gemini_image(img_prompt, aspect_ratio=aspect_ratio)
|
||||
return image_stored_at
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to generate Image: {err}")
|
||||
st.warning(f"Failed to generate Image: {err}")
|
||||
else:
|
||||
logger.error("Skipping Image creation, No prompt provided.")
|
||||
|
||||
|
||||
def generate_img_prompt(user_prompt):
|
||||
"""
|
||||
Given prompt, this functions generated a prompt for image generation.
|
||||
"""
|
||||
prompt = f"""
|
||||
As an expert prompt generator for AI text to image models and artist, I will provide you with 'user text' for creating images.
|
||||
Your task is to create a prompt for a highly relevant image from given 'user text'.
|
||||
\n
|
||||
Choose from various art styles, utilize light & shadow effects etc.
|
||||
Make sure to avoid common image generation mistakes.
|
||||
Reply with only one answer, no descrition and in plaintext.
|
||||
Make sure your prompt is detailed and creative descriptions that will inspire unique and interesting images from the AI.
|
||||
|
||||
\n\nuser text:
|
||||
'''{user_prompt}'''"""
|
||||
|
||||
response = llm_text_gen(prompt)
|
||||
return response
|
||||
|
||||
|
||||
def generate_enhanced_img_prompt(user_prompt, title=None, description=None, tags=None, content=None):
|
||||
"""
|
||||
Given user prompt and additional context (title, description, tags, content),
|
||||
this function generates an enhanced prompt for better image generation.
|
||||
|
||||
Args:
|
||||
user_prompt (str): Base prompt from the user
|
||||
title (str, optional): Blog title or content title
|
||||
description (str, optional): Blog or content description/summary
|
||||
tags (list, optional): List of tags related to the content
|
||||
content (str, optional): Actual content or excerpt
|
||||
|
||||
Returns:
|
||||
str: Enhanced prompt for image generation
|
||||
"""
|
||||
# Start with the base prompt
|
||||
context_parts = [user_prompt]
|
||||
|
||||
# Add relevant context if available
|
||||
if title:
|
||||
context_parts.append(f"Title: {title}")
|
||||
|
||||
if description:
|
||||
context_parts.append(f"Description: {description}")
|
||||
|
||||
if tags and len(tags) > 0:
|
||||
tag_text = ", ".join(tags[:5]) # Limit to 5 tags to avoid too much noise
|
||||
context_parts.append(f"Tags: {tag_text}")
|
||||
|
||||
# Create a combined context
|
||||
combined_context = "\n".join(context_parts)
|
||||
|
||||
# Add some content excerpt if available (limited to avoid token limits)
|
||||
content_excerpt = ""
|
||||
if content:
|
||||
# Just use the first few hundred characters as excerpt
|
||||
content_excerpt = content[:300] + "..." if len(content) > 300 else content
|
||||
|
||||
# Create the prompt for LLM
|
||||
prompt = f"""
|
||||
As an expert prompt engineer for AI image generation models, create a detailed, creative prompt
|
||||
for generating a high-quality, relevant image based on the following context:
|
||||
|
||||
{combined_context}
|
||||
|
||||
Additional content excerpt:
|
||||
{content_excerpt}
|
||||
|
||||
Your task is to:
|
||||
1. Analyze the context and content to understand the main theme and subject
|
||||
2. Create a rich, detailed prompt for image generation (50-75 words)
|
||||
3. Include specific visual details, art style, mood, lighting, composition
|
||||
4. Make sure the prompt is highly relevant to the original context
|
||||
5. Avoid prohibited content or anything that violates image generation guidelines
|
||||
|
||||
Reply with ONLY the final prompt. No explanations or other text.
|
||||
"""
|
||||
|
||||
# Generate the enhanced prompt
|
||||
try:
|
||||
enhanced_prompt = llm_text_gen(prompt)
|
||||
logger.info(f"Generated enhanced image prompt: {enhanced_prompt[:100]}...")
|
||||
return enhanced_prompt
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating enhanced prompt: {e}")
|
||||
# Fall back to the simple prompt generation if enhanced fails
|
||||
return generate_img_prompt(user_prompt)
|
||||
@@ -1,39 +0,0 @@
|
||||
import base64
|
||||
import datetime
|
||||
import os
|
||||
import requests
|
||||
from PIL import Image
|
||||
import logging
|
||||
|
||||
def save_generated_image(img_generation_response):
|
||||
"""
|
||||
Save generated images for blog, ensuring unique names for SEO.
|
||||
"""
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Get image save directory with fallback to a local directory
|
||||
image_save_dir = os.getenv('IMG_SAVE_DIR', 'generated_images')
|
||||
|
||||
# Create the directory if it doesn't exist
|
||||
if not os.path.exists(image_save_dir):
|
||||
logger.info(f"Creating image save directory: {image_save_dir}")
|
||||
os.makedirs(image_save_dir, exist_ok=True)
|
||||
|
||||
generated_image_name = f"generated_image_{datetime.datetime.now():%Y-%m-%d-%H-%M-%S}.webp"
|
||||
generated_image_filepath = os.path.join(image_save_dir, generated_image_name)
|
||||
|
||||
try:
|
||||
for i, image in enumerate(img_generation_response["artifacts"]):
|
||||
with open(generated_image_filepath, "wb") as f:
|
||||
f.write(base64.b64decode(image["base64"]))
|
||||
except requests.exceptions.RequestException as e:
|
||||
logger.error(f"Failed to get generated image content: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving image: {e}")
|
||||
return None
|
||||
|
||||
logger.info(f"Saved image at path: {generated_image_filepath}")
|
||||
|
||||
return generated_image_filepath
|
||||
Reference in New Issue
Block a user