# Source: ALwrity/backend/api/story_writer/router.py
# Snapshot: 2025-11-17 17:38:23 +05:30 (1327 lines, 54 KiB, Python)

"""
Story Writer API Router
Main router for story generation operations including premise, outline,
content generation, and full story creation.
"""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Any, Dict, Union, List, Optional
from loguru import logger
from middleware.auth_middleware import get_current_user
from models.story_models import (
StoryGenerationRequest,
StorySetupGenerationRequest,
StorySetupGenerationResponse,
StorySetupOption,
StoryStartRequest,
StoryPremiseResponse,
StoryOutlineResponse,
StoryScene,
StoryContentResponse,
StoryFullGenerationResponse,
StoryContinueRequest,
StoryContinueResponse,
StoryImageGenerationRequest,
StoryImageGenerationResponse,
StoryImageResult,
StoryAudioGenerationRequest,
StoryAudioGenerationResponse,
StoryAudioResult,
StoryVideoGenerationRequest,
StoryVideoGenerationResponse,
StoryVideoResult,
TaskStatus,
)
from services.story_writer.story_service import StoryWriterService
from .task_manager import task_manager
from .cache_manager import cache_manager
from uuid import uuid4
from pydantic import BaseModel
from pathlib import Path
from .utils.auth import require_authenticated_user
from .utils.media_utils import resolve_media_file
from .utils.hd_video import (
generate_hd_video_payload,
generate_hd_video_scene_payload,
)
# Router for the story writer endpoints; every route registered on it below
# is served under the /api/story prefix and tagged "Story Writer".
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
# Single module-level StoryWriterService instance used by the text-generation
# handlers in this file (setup, premise, outline, start, continue).
service = StoryWriterService()
@router.get("/health")
async def health() -> Dict[str, Any]:
    """Report that the story writer router is reachable.

    Returns:
        A static payload: ``status`` is ``"ok"`` and ``service`` is
        ``"story_writer"``.
    """
    payload: Dict[str, Any] = {"status": "ok", "service": "story_writer"}
    return payload
# ---------------------------
# Story Setup Generation Endpoints
# ---------------------------
@router.post("/generate-setup", response_model=StorySetupGenerationResponse)
async def generate_story_setup(
    request: StorySetupGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> StorySetupGenerationResponse:
    """Generate story setup options from a user's story idea.

    Args:
        request: Payload carrying the free-text ``story_idea``.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        A response wrapping one ``StorySetupOption`` per option returned by
        the service, with ``success=True``.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            the story idea is empty or whitespace, 500 on any other failure.
    """
    try:
        # Authentication is checked before any input validation.
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        idea = request.story_idea
        if not (idea and idea.strip()):
            raise HTTPException(status_code=400, detail="Story idea is required")

        logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")

        raw_options = service.generate_story_setup_options(
            story_idea=idea,
            user_id=user_id
        )

        # The service hands back plain dicts; wrap each one in the response model.
        option_models = []
        for raw_option in raw_options:
            option_models.append(StorySetupOption(**raw_option))

        return StorySetupGenerationResponse(options=option_models, success=True)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate story setup options: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# ---------------------------
# Premise Generation Endpoints
# ---------------------------
@router.post("/generate-premise", response_model=StoryPremiseResponse)
async def generate_premise(
    request: StoryGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryPremiseResponse:
    """Generate a story premise from the request's story parameters.

    Args:
        request: Story parameters (persona, setting, characters, plot, style,
            tone, POV, audience, rating, ending preference).
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        A response carrying the generated premise with ``success=True``.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 500 on any
            other failure.
    """
    # Request attributes forwarded verbatim to the service as keyword arguments.
    forwarded_fields = (
        "persona",
        "story_setting",
        "character_input",
        "plot_elements",
        "writing_style",
        "story_tone",
        "narrative_pov",
        "audience_age_group",
        "content_rating",
        "ending_preference",
    )
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        logger.info(f"[StoryWriter] Generating premise for user {user_id}")

        service_kwargs = {name: getattr(request, name) for name in forwarded_fields}
        generated = service.generate_premise(user_id=user_id, **service_kwargs)

        return StoryPremiseResponse(premise=generated, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate premise: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# ---------------------------
# Outline Generation Endpoints
# ---------------------------
@router.post("/generate-outline", response_model=StoryOutlineResponse)
async def generate_outline(
    request: StoryStartRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    use_structured: bool = True
) -> StoryOutlineResponse:
    """Generate a story outline from a premise.

    Args:
        request: Story parameters plus the ``premise`` to outline.
        current_user: Authenticated user dict injected by ``get_current_user``.
        use_structured: Forwarded to the service as ``use_structured_output``.

    Returns:
        A structured response (list of scenes, ``is_structured=True``) when
        the service returns a list, otherwise the outline stringified with
        ``is_structured=False``.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            the premise is empty or whitespace, 500 on any other failure.
    """
    # Request attributes forwarded verbatim to the service as keyword arguments.
    forwarded_fields = (
        "premise",
        "persona",
        "story_setting",
        "character_input",
        "plot_elements",
        "writing_style",
        "story_tone",
        "narrative_pov",
        "audience_age_group",
        "content_rating",
        "ending_preference",
    )
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        if not (request.premise and request.premise.strip()):
            raise HTTPException(status_code=400, detail="Premise is required")

        logger.info(f"[StoryWriter] Generating outline for user {user_id} (structured={use_structured})")
        logger.info(f"[StoryWriter] Outline generation parameters: audience_age_group={request.audience_age_group}, writing_style={request.writing_style}, story_tone={request.story_tone}")

        service_kwargs = {name: getattr(request, name) for name in forwarded_fields}
        outline = service.generate_outline(
            user_id=user_id,
            use_structured_output=use_structured,
            **service_kwargs
        )

        # Anything that is not a list is treated as a plain-text outline.
        if not isinstance(outline, list):
            return StoryOutlineResponse(outline=str(outline), success=True, is_structured=False)

        # Dict entries become StoryScene models; other entries pass through as-is.
        scene_models = []
        for entry in outline:
            scene_models.append(StoryScene(**entry) if isinstance(entry, dict) else entry)
        return StoryOutlineResponse(outline=scene_models, success=True, is_structured=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate outline: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# ---------------------------
# Story Content Generation Endpoints
# ---------------------------
@router.post("/generate-start", response_model=StoryContentResponse)
async def generate_story_start(
    request: StoryStartRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryContentResponse:
    """Generate the starting section of a story.

    For short stories the generated text is also checked for completeness:
    at 900 words or more the response is flagged ``is_complete=True``.

    Args:
        request: Story parameters plus ``premise`` and ``outline`` (a string
            or a list of scenes). ``story_length`` is optional; a missing or
            empty value is treated as ``'Medium'``.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        The generated story text, the premise, the outline rendered as text,
        and the completeness flag.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            premise or outline is empty, 500 on any other failure.
    """
    import re

    def _scene_field(scene: Any, key: str, default: Any) -> Any:
        # Scenes may be dicts or objects; read the field either way.
        if isinstance(scene, dict):
            return scene.get(key, default)
        return getattr(scene, key, default)

    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")
        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
        if not request.premise or not request.premise.strip():
            raise HTTPException(status_code=400, detail="Premise is required")
        if not request.outline or (isinstance(request.outline, str) and not request.outline.strip()):
            raise HTTPException(status_code=400, detail="Outline is required")

        logger.info(f"[StoryWriter] Generating story start for user {user_id}")

        # Outline is either a string or a list of scenes. Convert every
        # StoryScene individually so a mixed list of models and dicts works
        # (previously only the first element was inspected).
        outline_data = request.outline
        if isinstance(outline_data, list):
            outline_data = [
                scene.dict() if isinstance(scene, StoryScene) else scene
                for scene in outline_data
            ]

        # Resolve the length BEFORE the generation call so a missing/None value
        # cannot crash on .lower() after the work has already been done.
        story_length = getattr(request, 'story_length', None) or 'Medium'
        story_length_lower = story_length.lower()
        # Match "1000" only as a standalone number: a plain substring test
        # also matches "10000" and would classify long stories as short.
        is_short_story = (
            "short" in story_length_lower
            or re.search(r"(?<!\d)1000(?!\d)", story_length_lower) is not None
        )

        story_start = service.generate_story_start(
            premise=request.premise,
            outline=outline_data,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id
        )

        # Short stories are expected to be finished in one call; verify by word count.
        is_complete = False
        if is_short_story:
            word_count = len(story_start.split()) if story_start else 0
            # Short story should be ~1000 words; 900 is the acceptance floor.
            if word_count >= 900:
                is_complete = True
                logger.info(f"[StoryWriter] Short story generated with {word_count} words. Marking as complete.")
            else:
                logger.warning(f"[StoryWriter] Short story generated with only {word_count} words. May need continuation.")

        # Render a structured outline as readable text for the response.
        outline_response = outline_data
        if isinstance(outline_data, list):
            outline_response = "\n".join(
                f"Scene {_scene_field(scene, 'scene_number', i + 1)}: "
                f"{_scene_field(scene, 'title', 'Untitled')}\n"
                f" {_scene_field(scene, 'description', '')}"
                for i, scene in enumerate(outline_data)
            )

        return StoryContentResponse(
            story=story_start,
            premise=request.premise,
            outline=str(outline_response),
            is_complete=is_complete,  # True only for short stories that met the word floor
            success=True
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to generate story start: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/continue", response_model=StoryContinueResponse)
async def continue_story(
    request: StoryContinueRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryContinueResponse:
    """Continue writing a medium or long story.

    Word-count targets: long stories aim for 10000 words, everything else
    (medium) for 4500; the hard ceiling is the target plus 5%. When the
    existing text already meets the ceiling, or sits within 50 words above
    the target, the endpoint returns ``"IAMDONE"`` without calling the
    service. After a continuation, the story is marked complete when the
    service emitted ``IAMDONE``, when the combined text reaches the ceiling,
    or when it sits within 100 words above the target.

    Args:
        request: Story parameters plus ``story_text`` so far and the outline.
            A missing or empty ``story_length`` is treated as ``'Medium'``.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        The continuation text and a completeness flag.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            the story text is empty or the story is a short story, 500 on
            any other failure.
    """
    import re

    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")
        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
        if not request.story_text or not request.story_text.strip():
            raise HTTPException(status_code=400, detail="Story text is required")

        logger.info(f"[StoryWriter] Continuing story for user {user_id}")

        # Outline is either a string or a list of scenes. Convert every
        # StoryScene individually so a mixed list of models and dicts works.
        outline_data = request.outline
        if isinstance(outline_data, list):
            outline_data = [
                scene.dict() if isinstance(scene, StoryScene) else scene
                for scene in outline_data
            ]

        # A missing/None length falls back to Medium instead of crashing on .lower().
        story_length = getattr(request, 'story_length', None) or 'Medium'
        story_length_lower = story_length.lower()
        # Match "1000" only as a standalone number: a plain substring test
        # also matches "10000" and would block continuation of long stories.
        is_short_story = (
            "short" in story_length_lower
            or re.search(r"(?<!\d)1000(?!\d)", story_length_lower) is not None
        )

        # Short stories are produced in a single call and are never continued.
        if is_short_story:
            logger.warning("[StoryWriter] Attempted to continue a short story. Short stories should be complete in one call.")
            raise HTTPException(
                status_code=400,
                detail="Short stories are generated in a single call and should be complete. If the story is incomplete, please regenerate it from the beginning."
            )

        current_word_count = len(request.story_text.split())

        # Long: ~10000 words; Medium (default): ~4500 words.
        if "long" in story_length_lower or "10000" in story_length_lower:
            target_total_words = 10000
        else:
            target_total_words = 4500
        # Hard ceiling with a 5% buffer (10500 / 4725); computed once and reused.
        buffer_target = int(target_total_words * 1.05)

        # Already at or past the ceiling: nothing more to write.
        if current_word_count >= buffer_target:
            logger.info(f"[StoryWriter] Word count ({current_word_count}) already at or past buffer target ({buffer_target}) for {story_length} story. Story is complete.")
            return StoryContinueResponse(
                continuation="IAMDONE",
                is_complete=True,
                success=True
            )

        # Just over the target (within 50 words): also treat as complete.
        if current_word_count >= target_total_words and (current_word_count - target_total_words) < 50:
            logger.info(f"[StoryWriter] Word count ({current_word_count}) is very close to target ({target_total_words}). Story is complete.")
            return StoryContinueResponse(
                continuation="IAMDONE",
                is_complete=True,
                success=True
            )

        continuation = service.continue_story(
            premise=request.premise,
            outline=outline_data,
            story_text=request.story_text,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id
        )
        # Guard against a None result so the marker check below cannot crash.
        continuation = continuation or ""

        is_complete = 'IAMDONE' in continuation.upper()

        # The service did not signal completion: decide from the combined word count.
        if not is_complete and continuation:
            new_word_count = len((request.story_text + '\n\n' + continuation).split())
            reached_limit = False
            if new_word_count >= buffer_target:
                logger.info(f"[StoryWriter] Word count ({new_word_count}) now exceeds buffer target ({buffer_target}). Story is complete.")
                reached_limit = True
            elif new_word_count >= target_total_words and (new_word_count - target_total_words) < 100:
                logger.info(f"[StoryWriter] Word count ({new_word_count}) is at or very close to target ({target_total_words}). Story is complete.")
                reached_limit = True
            if reached_limit:
                # The marker is known to be absent in this branch, so append it.
                continuation = continuation.rstrip() + '\n\nIAMDONE'
                is_complete = True

        return StoryContinueResponse(
            continuation=continuation,
            is_complete=is_complete,
            success=True
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to continue story: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------------------------
# Full Story Generation Endpoints (Async)
# ---------------------------
@router.post("/generate-full", response_model=Dict[str, Any])
async def generate_full_story(
    request: StoryGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
    max_iterations: int = 10
) -> Dict[str, Any]:
    """Start full story generation as a background task.

    A cache lookup keyed on the request body runs first; on a hit, a task is
    created and immediately marked completed with the cached result.

    Args:
        request: Story parameters for the full generation.
        background_tasks: FastAPI background task queue.
        current_user: Authenticated user dict injected by ``get_current_user``.
        max_iterations: Added to the request data handed to the task.

    Returns:
        ``{"task_id", "cached": True}`` on a cache hit, otherwise
        ``{"task_id", "status": "pending", "message"}``.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 500 on any
            other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        # NOTE(review): the cache key is built from the request body only;
        # neither user_id nor max_iterations is part of it here - confirm
        # against cache_manager that this is intended.
        lookup_key = cache_manager.get_cache_key(request.dict())
        cached_payload = cache_manager.get_cached_result(lookup_key)
        if cached_payload:
            logger.info(f"[StoryWriter] Returning cached result for user {user_id}")
            finished_task_id = task_manager.create_task("story_generation")
            task_manager.update_task_status(
                finished_task_id,
                "completed",
                progress=100.0,
                result=cached_payload,
                message="Returned cached result"
            )
            return {"task_id": finished_task_id, "cached": True}

        # Cache miss: register a task and run the generation in the background.
        task_id = task_manager.create_task("story_generation")
        job_payload = {**request.dict(), "max_iterations": max_iterations}
        background_tasks.add_task(
            task_manager.execute_story_generation_task,
            task_id=task_id,
            request_data=job_payload,
            user_id=user_id
        )

        logger.info(f"[StoryWriter] Created task {task_id} for full story generation (user {user_id})")

        response: Dict[str, Any] = {
            "task_id": task_id,
            "status": "pending",
            "message": "Story generation started. Use /task/{task_id}/status to check progress."
        }
        return response
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to start story generation: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# ---------------------------
# Task Management Endpoints
# ---------------------------
@router.get("/task/{task_id}/status", response_model=TaskStatus)
async def get_task_status(
    task_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> TaskStatus:
    """Return the current status of a task.

    Args:
        task_id: Identifier issued when the task was created.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        The task manager's status record wrapped in ``TaskStatus``.

    Raises:
        HTTPException: 401 when unauthenticated, 404 when the task manager
            has no record for ``task_id``, 500 on any other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        record = task_manager.get_task_status(task_id)
        if record:
            return TaskStatus(**record)

        raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get task status: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/task/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Return the result of a completed task as a generic payload.

    Task results differ in shape (a full-story payload, or for video-only
    tasks a dict like ``{"video": {...}, "success": True}``), so no response
    model is enforced; callers branch on the keys present.

    Args:
        task_id: Identifier issued when the task was created.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        For dict results, a copy with ``success`` defaulted to ``True`` and
        ``task_id`` set; otherwise the result wrapped under ``"result"``.

    Raises:
        HTTPException: 401 when unauthenticated, 404 when the task or its
            result is missing, 400 when the task is not completed, 500 on
            any other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        record = task_manager.get_task_status(task_id)
        if not record:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

        state = record["status"]
        if state != "completed":
            raise HTTPException(
                status_code=400,
                detail=f"Task {task_id} is not completed. Status: {state}"
            )

        outcome = record.get("result")
        if not outcome:
            raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")

        # Non-dict results are wrapped so the response is always a JSON object.
        if not isinstance(outcome, dict):
            return {"result": outcome, "success": True, "task_id": task_id}

        # Copy so the stored task record is not mutated.
        body = dict(outcome)
        body.setdefault("success", True)
        body["task_id"] = task_id
        return body
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get task result: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
class HDVideoRequest(BaseModel):
    """Request body for the ``/hd-video`` endpoint.

    Only ``prompt`` is required; every tuning field defaults to ``None``.
    """

    # Text prompt for the video.
    prompt: str
    # Provider name; defaults to "huggingface".
    provider: str = "huggingface"
    # Optional generation settings, passed along with the request object.
    model: Optional[str] = None
    num_frames: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
@router.post("/hd-video")
async def generate_hd_video(
    request: HDVideoRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Generate an HD AI animation from a text prompt.

    Authenticates the caller and delegates to ``generate_hd_video_payload``;
    the shape of the returned payload is defined by that helper (not visible
    in this file).

    Args:
        request: Prompt and optional generation settings.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        The payload produced by ``generate_hd_video_payload``.

    Raises:
        HTTPException: Whatever the helpers raise is passed through; any
            other failure becomes a 500.
    """
    try:
        user_id = require_authenticated_user(current_user)
        return generate_hd_video_payload(request, user_id)
    except HTTPException:
        raise
    except Exception as e:
        # loguru does not treat exc_info=True as a traceback request: extra
        # kwargs trigger str.format() on the message, which raises when the
        # exception text contains braces. logger.exception logs at ERROR
        # level with the traceback and performs no formatting.
        logger.exception(f"[StoryWriter] Failed to generate HD video: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
class HDVideoSceneRequest(BaseModel):
    """Request body for the ``/hd-video-scene`` endpoint.

    The scene and story fields are required; the image URL and every tuning
    field default to ``None``.
    """

    # Scene to render, plus the story-wide context passed alongside it.
    scene_number: int
    scene_data: Dict[str, Any]
    story_context: Dict[str, Any]
    all_scenes: List[Dict[str, Any]]
    # Optional URL of an image for this scene.
    scene_image_url: Optional[str] = None
    # Provider name; defaults to "huggingface".
    provider: str = "huggingface"
    # Optional generation settings, passed along with the request object.
    model: Optional[str] = None
    num_frames: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
    request: HDVideoSceneRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Generate HD AI video for a single scene.

    Authenticates the caller and delegates to
    ``generate_hd_video_scene_payload``; prompt construction and the shape of
    the returned payload are defined by that helper (not visible in this
    file).

    Args:
        request: Scene data, story context and optional generation settings.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        The payload produced by ``generate_hd_video_scene_payload``.

    Raises:
        HTTPException: Whatever the helpers raise is passed through; any
            other failure becomes a 500.
    """
    try:
        user_id = require_authenticated_user(current_user)
        return generate_hd_video_scene_payload(request, user_id)
    except HTTPException:
        raise
    except Exception as e:
        # loguru does not treat exc_info=True as a traceback request: extra
        # kwargs trigger str.format() on the message, which raises when the
        # exception text contains braces. logger.exception logs at ERROR
        # level with the traceback and performs no formatting.
        logger.exception(f"[StoryWriter] Failed to generate HD video for scene: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# ---------------------------
# Image Generation Endpoints
# ---------------------------
@router.post("/generate-images", response_model=StoryImageGenerationResponse)
async def generate_scene_images(
    request: StoryImageGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryImageGenerationResponse:
    """Generate one image per story scene.

    Args:
        request: Scenes plus provider, model and optional width/height
            (each dimension falls back to 1024 when not set).
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        A response holding one ``StoryImageResult`` per service result.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            no scenes are supplied, 500 on any other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        logger.info(f"[StoryWriter] Generating images for {len(request.scenes)} scenes for user {user_id}")

        # Imported here rather than at module level, as in the original.
        from services.story_writer.image_generation_service import StoryImageGenerationService

        generator = StoryImageGenerationService()

        # The service works on plain dicts, so unwrap any StoryScene models.
        scene_dicts = []
        for scene in request.scenes:
            scene_dicts.append(scene.dict() if isinstance(scene, StoryScene) else scene)

        raw_results = generator.generate_scene_images(
            scenes=scene_dicts,
            user_id=user_id,
            provider=request.provider,
            width=request.width or 1024,
            height=request.height or 1024,
            model=request.model
        )

        # Map each raw result dict onto the response model, defaulting missing fields.
        images = []
        for item in raw_results:
            images.append(
                StoryImageResult(
                    scene_number=item.get("scene_number", 0),
                    scene_title=item.get("scene_title", "Untitled"),
                    image_filename=item.get("image_filename", ""),
                    image_url=item.get("image_url", ""),
                    width=item.get("width", 1024),
                    height=item.get("height", 1024),
                    provider=item.get("provider", "unknown"),
                    model=item.get("model"),
                    seed=item.get("seed"),
                    error=item.get("error")
                )
            )

        return StoryImageGenerationResponse(images=images, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate images: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/images/{image_filename}")
async def serve_scene_image(
    image_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user)
):
    """Serve a generated story scene image as a PNG file response.

    Args:
        image_filename: Name of the image inside the image service's
            output directory.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        A ``FileResponse`` for the resolved file with media type
        ``image/png``.

    Raises:
        HTTPException: Whatever the auth or path-resolution helpers raise is
            passed through; any other failure becomes a 500.
    """
    try:
        require_authenticated_user(current_user)

        from fastapi.responses import FileResponse
        from services.story_writer.image_generation_service import StoryImageGenerationService

        # resolve_media_file maps the requested name onto the service's output directory.
        storage_dir = StoryImageGenerationService().output_dir
        resolved = resolve_media_file(storage_dir, image_filename)

        return FileResponse(
            path=str(resolved),
            media_type="image/png",
            filename=image_filename
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve image: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# ---------------------------
# Audio Generation Endpoints
# ---------------------------
@router.post("/generate-audio", response_model=StoryAudioGenerationResponse)
async def generate_scene_audio(
    request: StoryAudioGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryAudioGenerationResponse:
    """Generate audio narration for each story scene.

    Args:
        request: Scenes plus optional provider, lang, slow and rate; unset
            values fall back to ``"gtts"``, ``"en"``, ``False`` and ``150``.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        A response holding one ``StoryAudioResult`` per service result.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            no scenes are supplied, 500 on any other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        logger.info(f"[StoryWriter] Generating audio for {len(request.scenes)} scenes for user {user_id}")

        # Imported here rather than at module level, as in the original.
        from services.story_writer.audio_generation_service import StoryAudioGenerationService

        narrator = StoryAudioGenerationService()

        # The service works on plain dicts, so unwrap any StoryScene models.
        scene_dicts = []
        for scene in request.scenes:
            scene_dicts.append(scene.dict() if isinstance(scene, StoryScene) else scene)

        raw_results = narrator.generate_scene_audio_list(
            scenes=scene_dicts,
            user_id=user_id,
            provider=request.provider or "gtts",
            lang=request.lang or "en",
            slow=request.slow or False,
            rate=request.rate or 150
        )

        # audio_filename/audio_url are coerced from None to "" because the
        # response model requires strings for them.
        audio_files = [
            StoryAudioResult(
                scene_number=item.get("scene_number", 0),
                scene_title=item.get("scene_title", "Untitled"),
                audio_filename=item.get("audio_filename") or "",
                audio_url=item.get("audio_url") or "",
                provider=item.get("provider", "unknown"),
                file_size=item.get("file_size", 0),
                error=item.get("error")
            )
            for item in raw_results
        ]

        return StoryAudioGenerationResponse(audio_files=audio_files, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate audio: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/audio/{audio_filename}")
async def serve_scene_audio(
    audio_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user)
):
    """Serve a generated story scene audio file as an MPEG file response.

    Args:
        audio_filename: Name of the audio file inside the audio service's
            output directory.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        A ``FileResponse`` for the resolved file with media type
        ``audio/mpeg``.

    Raises:
        HTTPException: Whatever the auth or path-resolution helpers raise is
            passed through; any other failure becomes a 500.
    """
    try:
        require_authenticated_user(current_user)

        from fastapi.responses import FileResponse
        from services.story_writer.audio_generation_service import StoryAudioGenerationService

        # resolve_media_file maps the requested name onto the service's output directory.
        storage_dir = StoryAudioGenerationService().output_dir
        resolved = resolve_media_file(storage_dir, audio_filename)

        return FileResponse(
            path=str(resolved),
            media_type="audio/mpeg",
            filename=audio_filename
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve audio: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
# ---------------------------
# Video Generation Endpoints
# ---------------------------
@router.post("/generate-video", response_model=StoryVideoGenerationResponse)
async def generate_story_video(
    request: StoryVideoGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> StoryVideoGenerationResponse:
    """Generate a video from story scenes, images, and audio.

    Each image/audio URL is reduced to a bare filename and looked up in the
    corresponding service's output directory. Scenes whose image or audio
    file is missing are skipped with a warning; the video is built from the
    remaining scenes.

    Args:
        request: Scenes with parallel ``image_urls`` and ``audio_urls`` lists,
            plus optional ``story_title``, ``fps`` and ``transition_duration``
            (defaults ``"Story"``, ``24`` and ``0.5``).
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        A response wrapping the ``StoryVideoResult`` built from the service
        result.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            no scenes are supplied, the list lengths differ, or no scene has
            both assets on disk; 500 on any other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")
        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
        if not request.scenes or len(request.scenes) == 0:
            raise HTTPException(status_code=400, detail="At least one scene is required")
        if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
            raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")

        logger.info(f"[StoryWriter] Generating video for {len(request.scenes)} scenes for user {user_id}")

        # Imported here rather than at module level, as in the original.
        from services.story_writer.video_generation_service import StoryVideoGenerationService
        from services.story_writer.image_generation_service import StoryImageGenerationService
        from services.story_writer.audio_generation_service import StoryAudioGenerationService
        from pathlib import Path

        def _asset_filename(url: str) -> str:
            """Reduce a URL or relative path to a bare filename.

            e.g. "/api/story/images/scene_1_image.png?v=2" -> "scene_1_image.png"
            """
            # Drop the query string first so a "/" inside it cannot be
            # mistaken for a path separator.
            last_segment = url.split('?', 1)[0].split('/')[-1]
            # Path(...).name strips any remaining platform-specific directory
            # components (e.g. backslash-separated ones on Windows).
            return Path(last_segment).name

        video_service = StoryVideoGenerationService()
        image_service = StoryImageGenerationService()
        audio_service = StoryAudioGenerationService()

        # The services work on plain dicts, so unwrap any StoryScene models.
        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]

        image_paths = []
        audio_paths = []
        valid_scenes = []
        for scene, image_url, audio_url in zip(scenes_data, request.image_urls, request.audio_urls):
            image_path = image_service.output_dir / _asset_filename(image_url)
            audio_path = audio_service.output_dir / _asset_filename(audio_url)

            # is_file() rather than exists(): an empty or ".." filename
            # resolves to a directory, which must not count as a valid asset.
            if not image_path.is_file():
                logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
                continue
            if not audio_path.is_file():
                logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
                continue

            # The three lists are appended together, so they always stay aligned.
            image_paths.append(str(image_path))
            audio_paths.append(str(audio_path))
            valid_scenes.append(scene)

        if not valid_scenes:
            raise HTTPException(status_code=400, detail="No valid image or audio files were found")

        # Build the video only from scenes that have both image and audio.
        video_result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request.story_title or "Story",
            fps=request.fps or 24,
            transition_duration=request.transition_duration or 0.5
        )

        # Map the raw result dict onto the response model, defaulting missing fields.
        video_model = StoryVideoResult(
            video_filename=video_result.get("video_filename", ""),
            video_url=video_result.get("video_url", ""),
            duration=video_result.get("duration", 0.0),
            fps=video_result.get("fps", 24),
            file_size=video_result.get("file_size", 0),
            num_scenes=video_result.get("num_scenes", 0),
            error=video_result.get("error")
        )
        return StoryVideoGenerationResponse(
            video=video_model,
            success=True
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to generate video: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
    request: StoryVideoGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """Start video generation as a background task.

    The frontend can poll /api/story/task/{task_id}/status for progress
    messages.

    Args:
        request: Scenes with parallel ``image_urls`` and ``audio_urls`` lists.
        background_tasks: FastAPI background task queue.
        current_user: Authenticated user dict injected by ``get_current_user``.

    Returns:
        ``{"task_id", "status": "pending", "message"}`` for the queued task.

    Raises:
        HTTPException: 401 when the user is missing or has no id, 400 when
            no scenes are supplied or the list lengths differ, 500 on any
            other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        # Scenes, image URLs and audio URLs are parallel lists and must align.
        scene_count = len(request.scenes)
        if scene_count != len(request.image_urls) or scene_count != len(request.audio_urls):
            raise HTTPException(status_code=400, detail="Number of scenes, image URLs, and audio URLs must match")

        new_task_id = task_manager.create_task("story_video_generation")
        background_tasks.add_task(
            _execute_video_generation_task,
            task_id=new_task_id,
            request=request,
            user_id=user_id
        )

        accepted: Dict[str, Any] = {
            "task_id": new_task_id,
            "status": "pending",
            "message": "Video generation started",
        }
        return accepted
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to start async video generation: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
    """Background task to generate story video with progress mapped to task manager.

    Each scene's image and audio URL is reduced to a filename and looked up in
    the image / audio services' output directories. Scenes whose files are
    missing are skipped with a warning; the remaining assets are passed to
    `StoryVideoGenerationService.generate_story_video`. The outcome is recorded
    on the task: "completed" with the result, or "failed" with the error text.
    Exceptions raised during generation are caught, logged and recorded on the
    task rather than propagated.

    Args:
        task_id: Task manager ID that receives status and progress updates.
        request: Video generation request; `scenes`, `image_urls` and
            `audio_urls` are consumed in lockstep.
        user_id: Passed through to the video service.
    """
    def media_filename(url: str) -> str:
        """Return the last path segment of *url* with any query string removed."""
        # Strip the query first so a '/' inside it is not treated as a path separator.
        return url.split('?', 1)[0].rsplit('/', 1)[-1]

    try:
        # Imports live inside the try so an import failure marks the task as
        # failed instead of leaving it pending forever.
        from services.story_writer.video_generation_service import StoryVideoGenerationService
        from services.story_writer.image_generation_service import StoryImageGenerationService
        from services.story_writer.audio_generation_service import StoryAudioGenerationService

        task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
        video_service = StoryVideoGenerationService()
        image_service = StoryImageGenerationService()
        audio_service = StoryAudioGenerationService()

        # Prepare assets
        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
        image_paths, audio_paths, valid_scenes = [], [], []
        for scene, image_url, audio_url in zip(scenes_data, request.image_urls, request.audio_urls):
            image_path = image_service.output_dir / media_filename(image_url)
            audio_path = audio_service.output_dir / media_filename(audio_url)
            # is_file() rather than exists(): an empty or odd filename can resolve
            # to a directory, which is not a usable asset.
            if not image_path.is_file():
                logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
                continue
            if not audio_path.is_file():
                logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
                continue
            image_paths.append(str(image_path))
            audio_paths.append(str(audio_path))
            valid_scenes.append(scene)

        # The three lists grow together, so checking one covers all of them.
        if not image_paths:
            raise RuntimeError("No valid or mismatched image/audio assets for video generation.")

        # Map service progress (0-100) to task progress (5-95)
        def progress_callback(sub_progress: float, msg: str):
            overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
            task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)

        result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request.story_title or "Story",
            fps=request.fps or 24,
            transition_duration=request.transition_duration or 0.5,
            progress_callback=progress_callback
        )
        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="Video generation complete!",
            result={"video": result, "success": True}
        )
    except Exception as e:
        # loguru has no `exc_info` option: passing it as a keyword makes loguru
        # run str.format() on the message, which raises if the error text
        # contains braces. logger.exception() attaches the traceback safely.
        logger.exception(f"[StoryWriter] Async video generation failed: {e}")
        task_manager.update_task_status(task_id, "failed", error=str(e), message=f"Video generation failed: {e}")
@router.post("/generate-complete-video", response_model=Dict[str, Any])
async def generate_complete_story_video(
    request: StoryGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Start the full story-to-video pipeline (outline → images → audio → video) in the background.

    Creates a task, schedules `execute_complete_video_generation` with the
    request converted to a dict, and returns immediately.

    Returns:
        A dict with `task_id`, `status` ("pending") and a `message`.

    Raises:
        HTTPException: 401 when there is no authenticated user or no user ID,
            500 for any other failure while scheduling.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        logger.info(f"[StoryWriter] Starting complete video generation for user {user_id}")

        # Register the task first so the ID can be handed to the worker.
        new_task_id = task_manager.create_task("complete_video_generation")
        background_tasks.add_task(
            execute_complete_video_generation,
            task_id=new_task_id,
            request_data=request.dict(),
            user_id=user_id,
        )
    except HTTPException:
        # Deliberate client-facing errors pass through untouched.
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to start complete video generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "task_id": new_task_id,
        "status": "pending",
        "message": "Complete video generation started",
    }
def execute_complete_video_generation(
    task_id: str,
    request_data: Dict[str, Any],
    user_id: str
):
    """
    Execute complete video generation workflow synchronously.

    This function runs in a background task and performs blocking operations.
    It's not async because it calls synchronous methods from the services.

    Pipeline: premise → structured outline → scene images → scene audio →
    video. Progress and the final result (or the error text) are written to
    the task manager under `task_id`; exceptions are caught, logged and
    recorded on the task rather than propagated.

    Args:
        task_id: Task manager ID that receives status and progress updates.
        request_data: Dict form of the generation request. The story keys
            (`persona`, `story_setting`, `character_input`, `plot_elements`,
            `writing_style`, `story_tone`, `narrative_pov`,
            `audience_age_group`, `content_rating`, `ending_preference`) are
            required. Image, audio and video settings are optional and fall
            back to the defaults used below.
        user_id: Passed through to every service call.
    """
    # Keys forwarded unchanged to both premise and outline generation.
    story_fields = (
        "persona",
        "story_setting",
        "character_input",
        "plot_elements",
        "writing_style",
        "story_tone",
        "narrative_pov",
        "audience_age_group",
        "content_rating",
        "ending_preference",
    )

    def scaled_progress(start: float):
        """Build a callback mapping sub-progress (0-100) onto a 20-point band from *start*."""
        def callback(sub_progress: float, message: str):
            # Clamp so an out-of-range value cannot leave the stage's band.
            bounded = max(0.0, min(100.0, sub_progress))
            task_manager.update_task_status(
                task_id, "processing", progress=start + bounded * 0.2, message=message
            )
        return callback

    try:
        # Imports and service construction live inside the try so that a
        # failure here marks the task as failed instead of leaving it pending.
        from services.story_writer.story_service import StoryWriterService
        from services.story_writer.image_generation_service import StoryImageGenerationService
        from services.story_writer.audio_generation_service import StoryAudioGenerationService
        from services.story_writer.video_generation_service import StoryVideoGenerationService

        story_service = StoryWriterService()
        image_service = StoryImageGenerationService()
        audio_service = StoryAudioGenerationService()
        video_service = StoryVideoGenerationService()

        task_manager.update_task_status(task_id, "processing", progress=5.0, message="Starting complete video generation...")

        # Step 1: Generate premise
        task_manager.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
        story_kwargs = {field: request_data[field] for field in story_fields}
        premise = story_service.generate_premise(user_id=user_id, **story_kwargs)

        # Step 2: Generate structured outline
        task_manager.update_task_status(task_id, "processing", progress=20.0, message="Generating structured outline with scenes...")
        outline_scenes = story_service.generate_outline(
            premise=premise,
            user_id=user_id,
            use_structured_output=True,
            **story_kwargs
        )
        if not isinstance(outline_scenes, list):
            raise RuntimeError("Failed to generate structured outline")

        # Step 3: Generate images for all scenes
        # Progress range: 30-50% (20% total for image generation)
        task_manager.update_task_status(task_id, "processing", progress=30.0, message="Generating images for scenes...")
        image_results = image_service.generate_scene_images(
            scenes=outline_scenes,
            user_id=user_id,
            provider=request_data.get("image_provider"),
            width=request_data.get("image_width", 1024),
            height=request_data.get("image_height", 1024),
            model=request_data.get("image_model"),
            progress_callback=scaled_progress(30.0)
        )

        # Step 4: Generate audio for all scenes
        # Progress range: 50-70% (20% total for audio generation)
        task_manager.update_task_status(task_id, "processing", progress=50.0, message="Generating audio narration for scenes...")
        audio_results = audio_service.generate_scene_audio_list(
            scenes=outline_scenes,
            user_id=user_id,
            provider=request_data.get("audio_provider", "gtts"),
            lang=request_data.get("audio_lang", "en"),
            slow=request_data.get("audio_slow", False),
            rate=request_data.get("audio_rate", 150),
            progress_callback=scaled_progress(50.0)
        )

        # Step 5: Prepare image and audio paths
        task_manager.update_task_status(task_id, "processing", progress=70.0, message="Preparing video assets...")
        image_paths = []
        audio_paths = []
        valid_scenes = []
        for scene in outline_scenes:
            scene_number = scene.get("scene_number", 0)
            image_result = next((img for img in image_results if img.get("scene_number") == scene_number), None)
            audio_result = next((aud for aud in audio_results if aud.get("scene_number") == scene_number), None)
            # A scene is only usable when both assets exist and neither reported an error.
            if not image_result or not audio_result:
                continue
            if image_result.get("error") or audio_result.get("error"):
                continue
            image_path = image_result.get("image_path")
            audio_path = audio_result.get("audio_path")
            if image_path and audio_path:
                image_paths.append(image_path)
                audio_paths.append(audio_path)
                valid_scenes.append(scene)

        if len(image_paths) == 0 or len(audio_paths) == 0:
            raise RuntimeError(f"No valid images or audio files were generated. Images: {len(image_paths)}, Audio: {len(audio_paths)}")
        if len(image_paths) != len(audio_paths):
            raise RuntimeError(f"Mismatch between image and audio counts. Images: {len(image_paths)}, Audio: {len(audio_paths)}")

        # Step 6: Generate video
        # Progress range: 75-95% (20% total for video generation)
        task_manager.update_task_status(task_id, "processing", progress=75.0, message="Composing video from scenes...")
        # `or` (not a .get default) so a None/empty story_setting still yields a title.
        story_title = (request_data.get("story_setting") or "Story")[:50]
        video_result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=story_title,
            fps=request_data.get("video_fps", 24),
            transition_duration=request_data.get("video_transition_duration", 0.5),
            progress_callback=scaled_progress(75.0)
        )

        # Prepare result
        result = {
            "premise": premise,
            "outline_scenes": outline_scenes,
            "images": image_results,
            "audio_files": audio_results,
            "video": video_result,
            "success": True
        }
        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="Complete video generation finished!",
            result=result
        )
        logger.info(f"[StoryWriter] Complete video generation task {task_id} completed successfully")
    except Exception as e:
        error_msg = str(e)
        # loguru has no `exc_info` option: passing it as a keyword makes loguru
        # run str.format() on the message, which raises if the error text
        # contains braces. logger.exception() attaches the traceback safely.
        logger.exception(f"[StoryWriter] Complete video generation task {task_id} failed: {error_msg}")
        task_manager.update_task_status(
            task_id,
            "failed",
            error=error_msg,
            message=f"Complete video generation failed: {error_msg}"
        )
@router.get("/videos/{video_filename}")
async def serve_story_video(
    video_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user)
):
    """
    Return a generated story video as an MP4 file response.

    The file is located by passing the video service's output directory and
    the requested filename to `resolve_media_file`.
    NOTE(review): presumably that helper validates the filename and raises an
    HTTPException when the file is missing; confirm in utils.media_utils.

    Raises:
        HTTPException: re-raised unchanged when raised by the auth check or
            file resolution; 500 for any other failure.
    """
    try:
        require_authenticated_user(current_user)

        from services.story_writer.video_generation_service import StoryVideoGenerationService
        from fastapi.responses import FileResponse

        output_dir = StoryVideoGenerationService().output_dir
        resolved_path = resolve_media_file(output_dir, video_filename)
        response = FileResponse(
            path=str(resolved_path),
            media_type="video/mp4",
            filename=video_filename,
        )
        return response
    except HTTPException:
        # Keep the status code chosen by the helper that raised it.
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to serve video: {e}")
        raise HTTPException(status_code=500, detail=str(e))
# ---------------------------
# Cache Management Endpoints
# ---------------------------
@router.get("/cache/stats")
async def get_cache_stats(
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Get cache statistics.

    Returns:
        `{"success": True, "stats": ...}` where `stats` is whatever
        `cache_manager.get_cache_stats()` returns.

    Raises:
        HTTPException: 401 when there is no authenticated user, 500 for any
            other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")
        stats = cache_manager.get_cache_stats()
        return {"success": True, "stats": stats}
    except HTTPException:
        # HTTPException is an Exception subclass; without this clause the 401
        # above would be caught below and reported as a 500.
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to get cache stats: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@router.post("/cache/clear")
async def clear_cache(
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Clear the story generation cache.

    Returns:
        `{"success": True}` merged with the mapping returned by
        `cache_manager.clear_cache()`.

    Raises:
        HTTPException: 401 when there is no authenticated user, 500 for any
            other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")
        result = cache_manager.clear_cache()
        return {"success": True, **result}
    except HTTPException:
        # HTTPException is an Exception subclass; without this clause the 401
        # above would be caught below and reported as a 500.
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to clear cache: {e}")
        raise HTTPException(status_code=500, detail=str(e))