"""
Image Generation Service for Story Writer

Generates images for story scenes using the existing image generation service.
"""
|
|
|
|
import base64
import os
import uuid
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional

from fastapi import HTTPException
from sqlalchemy.orm import Session

from services.llm_providers.main_image_generation import generate_image
from services.llm_providers.image_generation import ImageGenerationResult
from utils.logger_utils import get_service_logger
from services.user_workspace_manager import UserWorkspaceManager
|
|
|
|
# Module-level logger used by every method of StoryImageGenerationService.
logger = get_service_logger("story_writer.image_generation")
|
|
|
|
|
|
class StoryImageGenerationService:
    """Service for generating images for story scenes."""

    def __init__(self, output_dir: Optional[str] = None):
        """
        Initialize the image generation service.

        Parameters:
            output_dir (str, optional): Directory to save generated images.
                When omitted, ``data/media/story_images`` under
                ``Path(__file__).resolve().parents[3]`` is used.
        """
        if output_dir:
            self.output_dir = Path(output_dir)
        else:
            # Default to root/data/media/story_images directory
            base_dir = Path(__file__).resolve().parents[3]
            self.output_dir = base_dir / "data" / "media" / "story_images"

        # Create output directory if it doesn't exist
        self.output_dir.mkdir(parents=True, exist_ok=True)
        logger.info(f"[StoryImageGeneration] Initialized with output directory: {self.output_dir}")

    def _get_user_image_dir(self, user_id: str, db: Optional[Session] = None) -> Path:
        """
        Get the image directory for a specific user.

        Looks up the user's workspace through ``UserWorkspaceManager`` and
        returns (creating it if needed) ``<workspace_path>/media/story_images``.
        Falls back to ``self.output_dir`` when ``db`` or ``user_id`` is
        missing, when no workspace is found, or when the lookup raises.
        """
        if db and user_id:
            try:
                workspace = UserWorkspaceManager(db).get_user_workspace(user_id)
                if workspace:
                    # Use media/story_images inside user workspace
                    user_image_dir = Path(workspace['workspace_path']) / "media" / "story_images"
                    user_image_dir.mkdir(parents=True, exist_ok=True)
                    return user_image_dir
            except Exception as e:
                # Deliberate best-effort: any lookup failure degrades to the shared directory.
                logger.warning(f"[StoryImageGeneration] Failed to resolve user workspace path for {user_id}: {e}")

        return self.output_dir

    def _generate_image_filename(self, scene_number: int, scene_title: str) -> str:
        """
        Generate a unique filename for a scene image.

        The first 30 characters of the title are kept; every character that
        is not alphanumeric, ``-`` or ``_`` becomes ``_``. An 8-character
        random hex suffix keeps repeated generations from colliding.
        A ``None`` title is treated as ``"Untitled"`` instead of raising.
        """
        # scene.get("title", "Untitled") can still hand us None for an explicit null title.
        title = "Untitled" if scene_title is None else str(scene_title)
        clean_title = "".join(c if c.isalnum() or c in "-_" else "_" for c in title[:30])
        unique_id = uuid.uuid4().hex[:8]
        return f"scene_{scene_number}_{clean_title}_{unique_id}.png"

    def _refine_image_prompt_with_bible(
        self,
        image_prompt: str,
        scene: Dict[str, Any],
        anime_bible: Optional[Dict[str, Any]] = None,
    ) -> str:
        """
        Lightweight image prompt refinement using the anime story bible.

        Takes the existing scene image_prompt and appends visual_style,
        world, and cast hints from the bible as a comma-separated suffix.
        This is deterministic and avoids extra LLM calls.

        Parameters:
            image_prompt (str): Base prompt. Returned unchanged when empty or not a string.
            scene (Dict[str, Any]): Scene data; accepted for interface
                compatibility and not read by this method.
            anime_bible (Dict[str, Any], optional): Bible with optional
                ``visual_style``, ``world`` and ``main_cast`` entries.
                Entries of an unexpected type are ignored.

        Returns:
            str: The stripped prompt followed by the hints, or the original
            prompt when the bible contributes nothing.
        """
        if not image_prompt or not isinstance(image_prompt, str):
            return image_prompt

        if not anime_bible or not isinstance(anime_bible, dict):
            return image_prompt

        visual_style = anime_bible.get("visual_style")
        if not isinstance(visual_style, dict):
            # Same tolerance already applied to world/main_cast: ignore malformed sections.
            visual_style = {}
        world = anime_bible.get("world")
        main_cast = anime_bible.get("main_cast")

        parts: List[str] = []

        style_preset = visual_style.get("style_preset")
        if style_preset:
            parts.append(f"{style_preset} anime illustration style")

        labelled_fields = (
            ("camera_style", "framing and camera style"),
            ("color_mood", "color mood"),
            ("lighting", "lighting"),
            ("line_style", "line style"),
        )
        for key, label in labelled_fields:
            value = visual_style.get(key)
            if value:
                parts.append(f"{label}: {value}")

        extra_tags = visual_style.get("extra_tags") or []
        if isinstance(extra_tags, (list, tuple)):
            # Only the first six entries are considered; empty ones among them are dropped.
            extra_text = ", ".join(str(tag) for tag in extra_tags[:6] if tag)
            if extra_text:
                parts.append(extra_text)

        setting = world.get("setting") if isinstance(world, dict) else None
        if setting:
            parts.append(f"world setting: {setting}")

        if isinstance(main_cast, list):
            # str() keeps the join safe when a name is not a string.
            names = [
                str(member["name"])
                for member in main_cast
                if isinstance(member, dict) and member.get("name")
            ]
            if names:
                joined = ", ".join(names[:4])
                parts.append(f"keep character designs consistent for: {joined}")

        if not parts:
            return image_prompt

        return image_prompt.strip() + ", " + ", ".join(parts)

    def _generate_and_store(
        self,
        *,
        prompt: str,
        scene_number: int,
        scene_title: str,
        user_id: str,
        provider: Optional[str],
        width: int,
        height: int,
        model: Optional[str],
    ) -> Dict[str, Any]:
        """Generate one image, write it under ``self.output_dir`` and return its metadata dict."""
        result: ImageGenerationResult = generate_image(
            prompt=prompt,
            options={
                "provider": provider,
                "width": width,
                "height": height,
                "model": model,
            },
            user_id=user_id,
        )

        # NOTE(review): images always land in self.output_dir; _get_user_image_dir is
        # not consulted here. Confirm against the /api/story/images endpoint before
        # switching to per-user storage.
        image_filename = self._generate_image_filename(scene_number, scene_title)
        image_path = self.output_dir / image_filename
        image_path.write_bytes(result.image_bytes)

        return {
            "scene_number": scene_number,
            "scene_title": scene_title,
            "image_path": str(image_path),
            "image_filename": image_filename,
            "image_url": f"/api/story/images/{image_filename}",  # API endpoint to serve images
            "width": result.width,
            "height": result.height,
            "provider": result.provider,
            "model": result.model,
            "seed": result.seed,
        }

    def generate_scene_image(
        self,
        scene: Dict[str, Any],
        user_id: str,
        provider: Optional[str] = None,
        width: int = 1024,
        height: int = 1024,
        model: Optional[str] = None,
        anime_bible: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """
        Generate an image for a single story scene.

        Parameters:
            scene (Dict[str, Any]): Scene data with image_prompt.
            user_id (str): Clerk user ID for subscription checking.
            provider (str, optional): Image generation provider (gemini, huggingface, stability).
            width (int): Image width (default: 1024).
            height (int): Image height (default: 1024).
            model (str, optional): Model to use for image generation.
            anime_bible (Dict[str, Any], optional): Story bible used to enrich
                the prompt; a refinement failure is logged and the
                unrefined prompt is used.

        Returns:
            Dict[str, Any]: Image metadata including file path, URL, and scene info.

        Raises:
            ValueError: If the scene has no usable (non-blank string) image_prompt.
            HTTPException: Re-raised unchanged from the generation call.
            RuntimeError: Any other failure while generating or saving.
        """
        scene_number = scene.get("scene_number", 0)
        scene_title = scene.get("title", "Untitled")
        image_prompt = scene.get("image_prompt", "")

        if anime_bible:
            try:
                image_prompt = self._refine_image_prompt_with_bible(
                    image_prompt=image_prompt,
                    scene=scene,
                    anime_bible=anime_bible,
                )
            except Exception as e:
                logger.warning(f"[StoryImageGeneration] Failed to refine image prompt with bible: {e}")

        # Blank prompts are rejected here just as regenerate_scene_image rejects them.
        if not isinstance(image_prompt, str) or not image_prompt.strip():
            raise ValueError(f"Scene {scene_number} ({scene_title}) has no image_prompt")

        try:
            logger.info(f"[StoryImageGeneration] Generating image for scene {scene_number}: {scene_title}")
            logger.debug(f"[StoryImageGeneration] Image prompt: {image_prompt[:100]}...")

            metadata = self._generate_and_store(
                prompt=image_prompt,
                scene_number=scene_number,
                scene_title=scene_title,
                user_id=user_id,
                provider=provider,
                width=width,
                height=height,
                model=model,
            )

            logger.info(f"[StoryImageGeneration] Saved image to: {metadata['image_path']}")
            return metadata

        except HTTPException:
            # Re-raise HTTPExceptions (e.g., 429 subscription limit)
            raise
        except Exception as e:
            logger.error(f"[StoryImageGeneration] Error generating image for scene {scene_number}: {e}")
            raise RuntimeError(f"Failed to generate image for scene {scene_number}: {str(e)}") from e

    def generate_scene_images(
        self,
        scenes: List[Dict[str, Any]],
        user_id: str,
        provider: Optional[str] = None,
        width: int = 1024,
        height: int = 1024,
        model: Optional[str] = None,
        progress_callback: Optional[Callable[[float, str], None]] = None,
        db: Optional[Session] = None,
        anime_bible: Optional[Dict[str, Any]] = None,
    ) -> List[Dict[str, Any]]:
        """
        Generate images for multiple story scenes.

        A failing scene does not abort the batch: it is logged and recorded
        as an entry with ``error`` set and ``image_path``/``image_url`` of
        ``None``, so the result has one entry per input scene.

        Parameters:
            scenes (List[Dict[str, Any]]): List of scene data with image_prompts.
            user_id (str): Clerk user ID for subscription checking.
            provider (str, optional): Image generation provider (gemini, huggingface, stability).
            width (int): Image width (default: 1024).
            height (int): Image height (default: 1024).
            model (str, optional): Model to use for image generation.
            progress_callback (callable, optional): Called as
                ``progress_callback(percent, message)`` after each
                successfully generated scene. Exceptions it raises are
                logged and ignored.
            db (Session, optional): Database session. Currently not used by this method.
            anime_bible (Dict[str, Any], optional): Story bible forwarded to
                ``generate_scene_image`` for prompt refinement.

        Returns:
            List[Dict[str, Any]]: List of image metadata for each scene.

        Raises:
            ValueError: If ``scenes`` is empty.
        """
        if not scenes:
            raise ValueError("No scenes provided for image generation")

        total_scenes = len(scenes)
        logger.info(f"[StoryImageGeneration] Generating images for {total_scenes} scenes")

        image_results: List[Dict[str, Any]] = []
        succeeded = 0

        for idx, scene in enumerate(scenes):
            try:
                image_result = self.generate_scene_image(
                    scene=scene,
                    user_id=user_id,
                    provider=provider,
                    width=width,
                    height=height,
                    model=model,
                    anime_bible=anime_bible,
                )
            except Exception as e:
                # Continue with next scene instead of failing completely.
                # HTTPException is caught here too, so the batch keeps going.
                logger.error(f"[StoryImageGeneration] Failed to generate image for scene {idx + 1}: {e}")
                image_results.append({
                    "scene_number": scene.get("scene_number", idx + 1),
                    "scene_title": scene.get("title", "Untitled"),
                    "error": str(e),
                    "image_path": None,
                    "image_url": None,
                })
                continue

            image_results.append(image_result)
            succeeded += 1

            # The callback runs outside the generation try-block so a faulty
            # callback cannot add a second, "failed" entry for a saved scene.
            if progress_callback:
                progress = ((idx + 1) / total_scenes) * 100
                try:
                    progress_callback(progress, f"Generated image for scene {scene.get('scene_number', idx + 1)}")
                except Exception as e:
                    logger.warning(f"[StoryImageGeneration] Progress callback failed for scene {idx + 1}: {e}")

            logger.info(f"[StoryImageGeneration] Generated image {idx + 1}/{total_scenes}")

        # Report real successes; image_results also contains the error entries.
        logger.info(f"[StoryImageGeneration] Generated {succeeded} images out of {total_scenes} scenes")
        return image_results

    def regenerate_scene_image(
        self,
        scene_number: int,
        scene_title: str,
        prompt: str,
        user_id: str,
        provider: Optional[str] = None,
        width: int = 1024,
        height: int = 1024,
        model: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Regenerate an image for a single scene using a direct prompt (no AI prompt generation).

        Parameters:
            scene_number (int): Scene number.
            scene_title (str): Scene title.
            prompt (str): Direct prompt to use for image generation; it is
                stripped of surrounding whitespace before use.
            user_id (str): Clerk user ID for subscription checking.
            provider (str, optional): Image generation provider (gemini, huggingface, stability).
            width (int): Image width (default: 1024).
            height (int): Image height (default: 1024).
            model (str, optional): Model to use for image generation.

        Returns:
            Dict[str, Any]: Image metadata including file path, URL, and scene info.

        Raises:
            ValueError: If ``prompt`` is empty or blank.
            HTTPException: Re-raised unchanged from the generation call.
            RuntimeError: Any other failure while generating or saving.
        """
        if not prompt or not prompt.strip():
            raise ValueError(f"Scene {scene_number} ({scene_title}) requires a non-empty prompt")

        try:
            logger.info(f"[StoryImageGeneration] Regenerating image for scene {scene_number}: {scene_title}")
            logger.debug(f"[StoryImageGeneration] Using direct prompt: {prompt[:100]}...")

            metadata = self._generate_and_store(
                prompt=prompt.strip(),
                scene_number=scene_number,
                scene_title=scene_title,
                user_id=user_id,
                provider=provider,
                width=width,
                height=height,
                model=model,
            )

            logger.info(f"[StoryImageGeneration] Saved regenerated image to: {metadata['image_path']}")
            return metadata

        except HTTPException:
            # Re-raise HTTPExceptions (e.g., 429 subscription limit)
            raise
        except Exception as e:
            logger.error(f"[StoryImageGeneration] Error regenerating image for scene {scene_number}: {e}")
            raise RuntimeError(f"Failed to regenerate image for scene {scene_number}: {str(e)}") from e
|