Files
ALwrity/backend/api/story_writer/routes/media_generation.py

418 lines
17 KiB
Python

from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from models.story_models import (
StoryImageGenerationRequest,
StoryImageGenerationResponse,
StoryImageResult,
RegenerateImageRequest,
RegenerateImageResponse,
StoryAudioGenerationRequest,
StoryAudioGenerationResponse,
StoryAudioResult,
GenerateAIAudioRequest,
GenerateAIAudioResponse,
StoryScene,
)
from services.database import get_db
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from utils.asset_tracker import save_asset_to_library
from ..utils.auth import require_authenticated_user
from ..utils.media_utils import resolve_media_file, resolve_story_media_path
# Router on which the story-writer media endpoints below are registered.
router = APIRouter()
# Module-level service instances: created once at import time and shared by
# every request handler in this module.
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
def _scene_for_image_result(scenes_data: List[Any], scene_number: Any) -> Optional[Dict[str, Any]]:
    """Return the request scene that an image result refers to, or ``None``.

    A scene whose own ``scene_number`` field equals ``scene_number`` is
    preferred, so a request that carries a non-contiguous subset of scenes
    (for example only scene 3) is still matched to the right prompt.
    When no scene declares that number, ``scene_number`` is treated as a
    1-based position in ``scenes_data``.

    NOTE(review): this assumes the image service echoes each scene's own
    ``scene_number`` in its result - confirm against the service.
    """
    if scene_number is None:
        return None
    for scene in scenes_data:
        if isinstance(scene, dict) and scene.get("scene_number") == scene_number:
            return scene
    # Positional fallback for scenes that carry no explicit number.
    if isinstance(scene_number, int) and 0 < scene_number <= len(scenes_data):
        candidate = scenes_data[scene_number - 1]
        if isinstance(candidate, dict):
            return candidate
    return None


@router.post("/generate-images", response_model=StoryImageGenerationResponse)
async def generate_scene_images(
    request: StoryImageGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> StoryImageGenerationResponse:
    """Generate images for story scenes.

    Delegates generation to ``image_service.generate_scene_images`` (width and
    height default to 1024 when not supplied), converts every result into a
    ``StoryImageResult`` and records each successful image in the asset
    library.

    Returns:
        A ``StoryImageGenerationResponse`` with one entry per service result
        and ``success=True``. Per-scene failures are carried in each entry's
        ``error`` field.

    Raises:
        HTTPException: 400 when the request contains no scenes; any
            ``HTTPException`` raised further down is propagated unchanged;
            500 (with the error text as detail) for any other failure.
    """
    try:
        user_id = require_authenticated_user(current_user)
        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")
        logger.info(f"[StoryWriter] Generating images for {len(request.scenes)} scenes for user {user_id}")

        # The service receives plain dicts; scenes that are not StoryScene
        # instances are passed through unchanged.
        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
        image_results = image_service.generate_scene_images(
            scenes=scenes_data,
            user_id=user_id,
            provider=request.provider,
            width=request.width or 1024,
            height=request.height or 1024,
            model=request.model,
            db=db,
        )
        image_models: List[StoryImageResult] = [
            StoryImageResult(
                scene_number=result.get("scene_number", 0),
                scene_title=result.get("scene_title", "Untitled"),
                image_filename=result.get("image_filename", ""),
                image_url=result.get("image_url", ""),
                width=result.get("width", 1024),
                height=result.get("height", 1024),
                provider=result.get("provider", "unknown"),
                model=result.get("model"),
                seed=result.get("seed"),
                error=result.get("error"),
            )
            for result in image_results
        ]

        # Save assets to library. This is best-effort: a failure to record one
        # asset is logged and must not fail the whole generation request.
        for result in image_results:
            if result.get("error") or not result.get("image_url"):
                continue
            try:
                scene_number = result.get("scene_number", 0)
                source_scene = _scene_for_image_result(scenes_data, scene_number)
                prompt = source_scene.get("image_prompt") if source_scene else None
                save_asset_to_library(
                    db=db,
                    user_id=user_id,
                    asset_type="image",
                    source_module="story_writer",
                    filename=result.get("image_filename", ""),
                    file_url=result.get("image_url", ""),
                    file_path=result.get("image_path"),
                    file_size=result.get("file_size"),
                    mime_type="image/png",
                    title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
                    description=f"Story scene image for scene {scene_number}",
                    prompt=prompt,
                    tags=["story_writer", "scene", f"scene_{scene_number}"],
                    provider=result.get("provider"),
                    model=result.get("model"),
                    asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
                )
            except Exception as e:
                logger.warning(f"[StoryWriter] Failed to save image asset to library: {e}")

        return StoryImageGenerationResponse(images=image_models, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate images: {exc}")
        # Chain the original error so the root cause stays on the traceback.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/regenerate-images", response_model=RegenerateImageResponse)
async def regenerate_scene_image(
    request: RegenerateImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> RegenerateImageResponse:
    """Regenerate one scene image from a caller-supplied prompt.

    The prompt is sent to the image service as given (whitespace-trimmed);
    no AI prompt generation step is involved.

    Returns:
        A ``RegenerateImageResponse``. On success it carries the service
        result with ``success=True``. When the service fails with a
        non-HTTP error the endpoint still returns a response, with empty
        file fields, ``success=False`` and the error text in ``error``.

    Raises:
        HTTPException: 400 when the prompt is missing or blank; any
            ``HTTPException`` raised further down is propagated unchanged.
    """
    try:
        user_id = require_authenticated_user(current_user)

        cleaned_prompt = request.prompt.strip() if request.prompt else ""
        if not cleaned_prompt:
            raise HTTPException(status_code=400, detail="Prompt is required")

        logger.info(
            f"[StoryWriter] Regenerating image for scene {request.scene_number} "
            f"({request.scene_title}) for user {user_id}"
        )

        # Dimensions fall back to 1024 when the request leaves them unset.
        target_width = request.width or 1024
        target_height = request.height or 1024

        regenerated = image_service.regenerate_scene_image(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            prompt=cleaned_prompt,
            user_id=user_id,
            provider=request.provider,
            width=target_width,
            height=target_height,
            model=request.model,
        )

        return RegenerateImageResponse(
            scene_number=regenerated.get("scene_number", request.scene_number),
            scene_title=regenerated.get("scene_title", request.scene_title),
            image_filename=regenerated.get("image_filename", ""),
            image_url=regenerated.get("image_url", ""),
            width=regenerated.get("width", target_width),
            height=regenerated.get("height", target_height),
            provider=regenerated.get("provider", "unknown"),
            model=regenerated.get("model"),
            seed=regenerated.get("seed"),
            success=True,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to regenerate image: {exc}")
        # Report the failure in-band instead of raising, echoing the request.
        failure_fields = {
            "scene_number": request.scene_number,
            "scene_title": request.scene_title,
            "image_filename": "",
            "image_url": "",
            "width": request.width or 1024,
            "height": request.height or 1024,
            "provider": request.provider or "unknown",
            "success": False,
            "error": str(exc),
        }
        return RegenerateImageResponse(**failure_fields)
@router.get("/images/{image_filename}")
async def serve_scene_image(
    image_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """Serve a generated story scene image.

    Supports authentication via Authorization header or token query parameter.
    Query parameter is useful for HTML elements like <img> that cannot send custom headers.

    The file is located by passing the image service's output directory and
    the requested filename to ``resolve_media_file`` and is returned as
    ``image/png``.

    Raises:
        HTTPException: any ``HTTPException`` raised while authenticating or
            resolving the file is propagated unchanged; 500 (with the error
            text as detail) for any other failure.
    """
    try:
        require_authenticated_user(current_user)
        image_path = resolve_media_file(image_service.output_dir, image_filename)
        return FileResponse(path=str(image_path), media_type="image/png", filename=image_filename)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve image: {exc}")
        # Chain the original error so the root cause stays on the traceback.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
def _scene_for_audio_result(scenes_data: List[Any], scene_number: Any) -> Optional[Dict[str, Any]]:
    """Return the request scene that an audio result refers to, or ``None``.

    A scene whose own ``scene_number`` field equals ``scene_number`` is
    preferred, so a request that carries a non-contiguous subset of scenes
    is still matched to the right narration text. When no scene declares
    that number, ``scene_number`` is treated as a 1-based position in
    ``scenes_data``.

    NOTE(review): this assumes the audio service echoes each scene's own
    ``scene_number`` in its result - confirm against the service.
    """
    if scene_number is None:
        return None
    for scene in scenes_data:
        if isinstance(scene, dict) and scene.get("scene_number") == scene_number:
            return scene
    # Positional fallback for scenes that carry no explicit number.
    if isinstance(scene_number, int) and 0 < scene_number <= len(scenes_data):
        candidate = scenes_data[scene_number - 1]
        if isinstance(candidate, dict):
            return candidate
    return None


@router.post("/generate-audio", response_model=StoryAudioGenerationResponse)
async def generate_scene_audio(
    request: StoryAudioGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> StoryAudioGenerationResponse:
    """Generate audio narration for story scenes.

    Delegates generation to ``audio_service.generate_scene_audio_list``
    (defaults: provider ``"gtts"``, language ``"en"``, ``slow=False``,
    rate 150), converts every result into a ``StoryAudioResult`` and records
    each successful audio file in the asset library.

    Returns:
        A ``StoryAudioGenerationResponse`` with one entry per service result
        and ``success=True``. Per-scene failures are carried in each entry's
        ``error`` field.

    Raises:
        HTTPException: 400 when the request contains no scenes; any
            ``HTTPException`` raised further down is propagated unchanged;
            500 (with the error text as detail) for any other failure.
    """
    try:
        user_id = require_authenticated_user(current_user)
        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")
        logger.info(f"[StoryWriter] Generating audio for {len(request.scenes)} scenes for user {user_id}")

        # The service receives plain dicts; scenes that are not StoryScene
        # instances are passed through unchanged.
        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
        audio_results = audio_service.generate_scene_audio_list(
            scenes=scenes_data,
            user_id=user_id,
            provider=request.provider or "gtts",
            lang=request.lang or "en",
            slow=request.slow or False,
            rate=request.rate or 150,
        )

        audio_models: List[StoryAudioResult] = []
        for result in audio_results:
            # Normalise missing/None values to empty strings for the response.
            audio_url = result.get("audio_url") or ""
            audio_filename = result.get("audio_filename") or ""
            audio_models.append(
                StoryAudioResult(
                    scene_number=result.get("scene_number", 0),
                    scene_title=result.get("scene_title", "Untitled"),
                    audio_filename=audio_filename,
                    audio_url=audio_url,
                    provider=result.get("provider", "unknown"),
                    file_size=result.get("file_size", 0),
                    error=result.get("error"),
                )
            )

            # Save assets to library. This is best-effort: a failure to record
            # one asset is logged and must not fail the whole request.
            if result.get("error") or not audio_url:
                continue
            try:
                scene_number = result.get("scene_number", 0)
                source_scene = _scene_for_audio_result(scenes_data, scene_number)
                prompt = source_scene.get("text") if source_scene else None
                save_asset_to_library(
                    db=db,
                    user_id=user_id,
                    asset_type="audio",
                    source_module="story_writer",
                    filename=audio_filename,
                    file_url=audio_url,
                    file_path=result.get("audio_path"),
                    file_size=result.get("file_size"),
                    mime_type="audio/mpeg",
                    title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
                    description=f"Story scene audio narration for scene {scene_number}",
                    prompt=prompt,
                    tags=["story_writer", "audio", "narration", f"scene_{scene_number}"],
                    provider=result.get("provider"),
                    model=result.get("model"),
                    cost=result.get("cost"),
                    asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
                )
            except Exception as e:
                logger.warning(f"[StoryWriter] Failed to save audio asset to library: {e}")

        return StoryAudioGenerationResponse(audio_files=audio_models, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate audio: {exc}")
        # Chain the original error so the root cause stays on the traceback.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
@router.post("/generate-ai-audio", response_model=GenerateAIAudioResponse)
async def generate_ai_audio(
    request: GenerateAIAudioRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> GenerateAIAudioResponse:
    """Generate AI audio for a single scene using WaveSpeed Minimax Speech 02 HD.

    Unset request options fall back to voice ``"Wise_Woman"``, speed 1.0,
    volume 1.0, pitch 0.0 and emotion ``"happy"``.

    Returns:
        A ``GenerateAIAudioResponse``. On success it carries the service
        result with ``success=True``. When the service fails with a
        non-HTTP error the endpoint still returns a response, with empty
        file fields, ``success=False`` and the error text in ``error``.

    Raises:
        HTTPException: 400 when the text is missing or blank; any
            ``HTTPException`` raised further down is propagated unchanged.
    """
    try:
        user_id = require_authenticated_user(current_user)

        narration_text = request.text.strip() if request.text else ""
        if not narration_text:
            raise HTTPException(status_code=400, detail="Text is required")

        logger.info(
            f"[StoryWriter] Generating AI audio for scene {request.scene_number} "
            f"({request.scene_title}) for user {user_id}"
        )

        chosen_voice = request.voice_id or "Wise_Woman"
        generated = audio_service.generate_ai_audio(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            text=narration_text,
            user_id=user_id,
            voice_id=chosen_voice,
            speed=request.speed or 1.0,
            volume=request.volume or 1.0,
            pitch=request.pitch or 0.0,
            emotion=request.emotion or "happy",
        )

        return GenerateAIAudioResponse(
            scene_number=generated.get("scene_number", request.scene_number),
            scene_title=generated.get("scene_title", request.scene_title),
            audio_filename=generated.get("audio_filename", ""),
            audio_url=generated.get("audio_url", ""),
            provider=generated.get("provider", "wavespeed"),
            model=generated.get("model", "minimax/speech-02-hd"),
            voice_id=generated.get("voice_id", chosen_voice),
            # The fallback length is measured on the untrimmed request text.
            text_length=generated.get("text_length", len(request.text)),
            file_size=generated.get("file_size", 0),
            cost=generated.get("cost", 0.0),
            success=True,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate AI audio: {exc}")
        # Report the failure in-band instead of raising, echoing the request.
        failure_fields = {
            "scene_number": request.scene_number,
            "scene_title": request.scene_title,
            "audio_filename": "",
            "audio_url": "",
            "provider": "wavespeed",
            "model": "minimax/speech-02-hd",
            "voice_id": request.voice_id or "Wise_Woman",
            "text_length": len(request.text) if request.text else 0,
            "file_size": 0,
            "cost": 0.0,
            "success": False,
            "error": str(exc),
        }
        return GenerateAIAudioResponse(**failure_fields)
@router.get("/audio/{audio_filename}")
async def serve_scene_audio(
    audio_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Serve a generated story scene audio file.

    The file is located by passing the audio service's output directory and
    the requested filename to ``resolve_media_file`` and is returned as
    ``audio/mpeg``.

    NOTE(review): ``serve_scene_image`` authenticates with
    ``get_current_user_with_query_token`` while this endpoint uses
    ``get_current_user``. If the audio is played through an HTML ``<audio>``
    element (which cannot send custom headers), the query-token variant is
    probably needed here too - confirm with the frontend before changing.

    Raises:
        HTTPException: any ``HTTPException`` raised while authenticating or
            resolving the file is propagated unchanged; 500 (with the error
            text as detail) for any other failure.
    """
    try:
        require_authenticated_user(current_user)
        audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
        return FileResponse(path=str(audio_path), media_type="audio/mpeg", filename=audio_filename)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve audio: {exc}")
        # Chain the original error so the root cause stays on the traceback.
        raise HTTPException(status_code=500, detail=str(exc)) from exc
class PromptOptimizeRequest(BaseModel):
    """Request body accepted by the ``/optimize-prompt`` endpoint."""

    # Required prompt text; the endpoint additionally rejects blank text.
    text: str = Field(..., description="The prompt text to optimize")
    # Restricted by the pattern to exactly "image" or "video"; defaults to "image".
    mode: Optional[str] = Field(default="image", pattern="^(image|video)$", description="Optimization mode: 'image' or 'video'")
    # Restricted by the pattern to one of the six listed styles; defaults to "default".
    style: Optional[str] = Field(
        default="default",
        pattern="^(default|artistic|photographic|technical|anime|realistic)$",
        description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'"
    )
    # Optional; forwarded to the optimizer unchanged when present.
    image: Optional[str] = Field(None, description="Base64-encoded image for context (optional)")
class PromptOptimizeResponse(BaseModel):
    """Response body returned by the ``/optimize-prompt`` endpoint."""

    # Prompt text produced by the optimizer.
    optimized_prompt: str
    # The endpoint sets this to True on its success path; failures are
    # reported by raising HTTPException rather than through this flag.
    success: bool
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
async def optimize_prompt(
    request: PromptOptimizeRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> PromptOptimizeResponse:
    """Optimize an image or video prompt using the WaveSpeed prompt optimizer.

    The prompt is whitespace-trimmed before it is sent; ``mode`` defaults to
    ``"image"`` and ``style`` to ``"default"`` when unset.

    Returns:
        A ``PromptOptimizeResponse`` holding the optimized prompt and
        ``success=True``.

    Raises:
        HTTPException: 400 when the prompt text is missing or blank; any
            ``HTTPException`` raised further down is propagated unchanged;
            500 (with the error text as detail) for any other failure.
    """
    try:
        user_id = require_authenticated_user(current_user)
        if not request.text or not request.text.strip():
            raise HTTPException(status_code=400, detail="Prompt text is required")
        logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")

        # Imported here, as in the original code, rather than at module level.
        from fastapi.concurrency import run_in_threadpool
        from services.wavespeed.client import WaveSpeedClient

        client = WaveSpeedClient()
        # The optimizer call is synchronous and may take up to `timeout`
        # seconds; running it on a worker thread keeps the event loop free to
        # serve other requests in the meantime.
        optimized_prompt = await run_in_threadpool(
            client.optimize_prompt,
            text=request.text.strip(),
            mode=request.mode or "image",
            style=request.style or "default",
            image=request.image,  # Optional base64 image
            enable_sync_mode=True,
            timeout=30,
        )
        logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")
        return PromptOptimizeResponse(
            optimized_prompt=optimized_prompt,
            success=True
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to optimize prompt: {exc}")
        # Chain the original error so the root cause stays on the traceback.
        raise HTTPException(status_code=500, detail=str(exc)) from exc