Base code
This commit is contained in:
9
backend/api/story_writer/__init__.py
Normal file
9
backend/api/story_writer/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Story Writer API
|
||||
|
||||
API endpoints for story generation functionality.
|
||||
"""
|
||||
|
||||
from .router import router
|
||||
|
||||
__all__ = ['router']
|
||||
70
backend/api/story_writer/cache_manager.py
Normal file
70
backend/api/story_writer/cache_manager.py
Normal file
@@ -0,0 +1,70 @@
|
||||
"""
|
||||
Cache Management System for Story Writer API
|
||||
|
||||
Handles story generation cache operations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""Manages cache operations for story generation data."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the cache manager."""
|
||||
self.cache: Dict[str, Dict[str, Any]] = {}
|
||||
logger.info("[StoryWriter] CacheManager initialized")
|
||||
|
||||
def get_cache_key(self, request_data: Dict[str, Any]) -> str:
|
||||
"""Generate a cache key from request data."""
|
||||
import hashlib
|
||||
import json
|
||||
|
||||
# Create a normalized version of the request for caching
|
||||
cache_data = {
|
||||
"persona": request_data.get("persona", ""),
|
||||
"story_setting": request_data.get("story_setting", ""),
|
||||
"character_input": request_data.get("character_input", ""),
|
||||
"plot_elements": request_data.get("plot_elements", ""),
|
||||
"writing_style": request_data.get("writing_style", ""),
|
||||
"story_tone": request_data.get("story_tone", ""),
|
||||
"narrative_pov": request_data.get("narrative_pov", ""),
|
||||
"audience_age_group": request_data.get("audience_age_group", ""),
|
||||
"content_rating": request_data.get("content_rating", ""),
|
||||
"ending_preference": request_data.get("ending_preference", ""),
|
||||
}
|
||||
|
||||
cache_str = json.dumps(cache_data, sort_keys=True)
|
||||
return hashlib.md5(cache_str.encode()).hexdigest()
|
||||
|
||||
def get_cached_result(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get a cached result if available."""
|
||||
if cache_key in self.cache:
|
||||
logger.debug(f"[StoryWriter] Cache hit for key: {cache_key}")
|
||||
return self.cache[cache_key]
|
||||
logger.debug(f"[StoryWriter] Cache miss for key: {cache_key}")
|
||||
return None
|
||||
|
||||
def cache_result(self, cache_key: str, result: Dict[str, Any]):
|
||||
"""Cache a result."""
|
||||
self.cache[cache_key] = result
|
||||
logger.debug(f"[StoryWriter] Cached result for key: {cache_key}")
|
||||
|
||||
def clear_cache(self):
|
||||
"""Clear all cached results."""
|
||||
count = len(self.cache)
|
||||
self.cache.clear()
|
||||
logger.info(f"[StoryWriter] Cleared {count} cached entries")
|
||||
return {"status": "success", "message": f"Cleared {count} cached entries"}
|
||||
|
||||
def get_cache_stats(self) -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
return {
|
||||
"total_entries": len(self.cache),
|
||||
"cache_keys": list(self.cache.keys())
|
||||
}
|
||||
|
||||
|
||||
# Global cache manager instance
|
||||
cache_manager = CacheManager()
|
||||
37
backend/api/story_writer/router.py
Normal file
37
backend/api/story_writer/router.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
Story Writer API Router
|
||||
|
||||
Main router for story generation operations. This file serves as the entry point
|
||||
and includes modular sub-routers for different functionality areas.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .routes import (
|
||||
cache_routes,
|
||||
media_generation,
|
||||
scene_animation,
|
||||
story_content,
|
||||
story_setup,
|
||||
story_tasks,
|
||||
video_generation,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
|
||||
|
||||
# Include modular routers (order preserved roughly by workflow)
|
||||
router.include_router(story_setup.router)
|
||||
router.include_router(story_content.router)
|
||||
router.include_router(story_tasks.router)
|
||||
router.include_router(media_generation.router)
|
||||
router.include_router(scene_animation.router)
|
||||
router.include_router(video_generation.router)
|
||||
router.include_router(cache_routes.router)
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health() -> Dict[str, Any]:
|
||||
"""Health check endpoint."""
|
||||
return {"status": "ok", "service": "story_writer"}
|
||||
23
backend/api/story_writer/routes/__init__.py
Normal file
23
backend/api/story_writer/routes/__init__.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""
|
||||
Collection of modular routers for Story Writer endpoints.
|
||||
Each module focuses on a related set of routes to keep the primary
|
||||
`router.py` concise and easier to maintain.
|
||||
"""
|
||||
|
||||
from . import cache_routes
|
||||
from . import media_generation
|
||||
from . import scene_animation
|
||||
from . import story_content
|
||||
from . import story_setup
|
||||
from . import story_tasks
|
||||
from . import video_generation
|
||||
|
||||
__all__ = [
|
||||
"cache_routes",
|
||||
"media_generation",
|
||||
"scene_animation",
|
||||
"story_content",
|
||||
"story_setup",
|
||||
"story_tasks",
|
||||
"video_generation",
|
||||
]
|
||||
42
backend/api/story_writer/routes/cache_routes.py
Normal file
42
backend/api/story_writer/routes/cache_routes.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
from ..cache_manager import cache_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/cache/stats")
|
||||
async def get_cache_stats(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get cache statistics."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
stats = cache_manager.get_cache_stats()
|
||||
return {"success": True, "stats": stats}
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to get cache stats: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_cache(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Clear the story generation cache."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
result = cache_manager.clear_cache()
|
||||
return {"success": True, **result}
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to clear cache: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
416
backend/api/story_writer/routes/media_generation.py
Normal file
416
backend/api/story_writer/routes/media_generation.py
Normal file
@@ -0,0 +1,416 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from models.story_models import (
|
||||
StoryImageGenerationRequest,
|
||||
StoryImageGenerationResponse,
|
||||
StoryImageResult,
|
||||
RegenerateImageRequest,
|
||||
RegenerateImageResponse,
|
||||
StoryAudioGenerationRequest,
|
||||
StoryAudioGenerationResponse,
|
||||
StoryAudioResult,
|
||||
GenerateAIAudioRequest,
|
||||
GenerateAIAudioResponse,
|
||||
StoryScene,
|
||||
)
|
||||
from services.database import get_db
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.media_utils import resolve_media_file
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
image_service = StoryImageGenerationService()
|
||||
audio_service = StoryAudioGenerationService()
|
||||
|
||||
|
||||
@router.post("/generate-images", response_model=StoryImageGenerationResponse)
|
||||
async def generate_scene_images(
|
||||
request: StoryImageGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StoryImageGenerationResponse:
|
||||
"""Generate images for story scenes."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.scenes or len(request.scenes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one scene is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Generating images for {len(request.scenes)} scenes for user {user_id}")
|
||||
|
||||
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
|
||||
image_results = image_service.generate_scene_images(
|
||||
scenes=scenes_data,
|
||||
user_id=user_id,
|
||||
provider=request.provider,
|
||||
width=request.width or 1024,
|
||||
height=request.height or 1024,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
image_models: List[StoryImageResult] = [
|
||||
StoryImageResult(
|
||||
scene_number=result.get("scene_number", 0),
|
||||
scene_title=result.get("scene_title", "Untitled"),
|
||||
image_filename=result.get("image_filename", ""),
|
||||
image_url=result.get("image_url", ""),
|
||||
width=result.get("width", 1024),
|
||||
height=result.get("height", 1024),
|
||||
provider=result.get("provider", "unknown"),
|
||||
model=result.get("model"),
|
||||
seed=result.get("seed"),
|
||||
error=result.get("error"),
|
||||
)
|
||||
for result in image_results
|
||||
]
|
||||
|
||||
# Save assets to library
|
||||
for result in image_results:
|
||||
if not result.get("error") and result.get("image_url"):
|
||||
try:
|
||||
scene_number = result.get("scene_number", 0)
|
||||
# Safely get prompt from scenes_data with bounds checking
|
||||
prompt = None
|
||||
if scene_number > 0 and scene_number <= len(scenes_data):
|
||||
prompt = scenes_data[scene_number - 1].get("image_prompt")
|
||||
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="story_writer",
|
||||
filename=result.get("image_filename", ""),
|
||||
file_url=result.get("image_url", ""),
|
||||
file_path=result.get("image_path"),
|
||||
file_size=result.get("file_size"),
|
||||
mime_type="image/png",
|
||||
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
|
||||
description=f"Story scene image for scene {scene_number}",
|
||||
prompt=prompt,
|
||||
tags=["story_writer", "scene", f"scene_{scene_number}"],
|
||||
provider=result.get("provider"),
|
||||
model=result.get("model"),
|
||||
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save image asset to library: {e}")
|
||||
|
||||
return StoryImageGenerationResponse(images=image_models, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate images: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/regenerate-images", response_model=RegenerateImageResponse)
|
||||
async def regenerate_scene_image(
|
||||
request: RegenerateImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> RegenerateImageResponse:
|
||||
"""Regenerate a single scene image using a direct prompt (no AI prompt generation)."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.prompt or not request.prompt.strip():
|
||||
raise HTTPException(status_code=400, detail="Prompt is required")
|
||||
|
||||
logger.info(
|
||||
f"[StoryWriter] Regenerating image for scene {request.scene_number} "
|
||||
f"({request.scene_title}) for user {user_id}"
|
||||
)
|
||||
|
||||
result = image_service.regenerate_scene_image(
|
||||
scene_number=request.scene_number,
|
||||
scene_title=request.scene_title,
|
||||
prompt=request.prompt.strip(),
|
||||
user_id=user_id,
|
||||
provider=request.provider,
|
||||
width=request.width or 1024,
|
||||
height=request.height or 1024,
|
||||
model=request.model,
|
||||
)
|
||||
|
||||
return RegenerateImageResponse(
|
||||
scene_number=result.get("scene_number", request.scene_number),
|
||||
scene_title=result.get("scene_title", request.scene_title),
|
||||
image_filename=result.get("image_filename", ""),
|
||||
image_url=result.get("image_url", ""),
|
||||
width=result.get("width", request.width or 1024),
|
||||
height=result.get("height", request.height or 1024),
|
||||
provider=result.get("provider", "unknown"),
|
||||
model=result.get("model"),
|
||||
seed=result.get("seed"),
|
||||
success=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to regenerate image: {exc}")
|
||||
return RegenerateImageResponse(
|
||||
scene_number=request.scene_number,
|
||||
scene_title=request.scene_title,
|
||||
image_filename="",
|
||||
image_url="",
|
||||
width=request.width or 1024,
|
||||
height=request.height or 1024,
|
||||
provider=request.provider or "unknown",
|
||||
success=False,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/images/{image_filename}")
|
||||
async def serve_scene_image(
|
||||
image_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
|
||||
):
|
||||
"""Serve a generated story scene image.
|
||||
|
||||
Supports authentication via Authorization header or token query parameter.
|
||||
Query parameter is useful for HTML elements like <img> that cannot send custom headers.
|
||||
"""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
image_path = resolve_media_file(image_service.output_dir, image_filename)
|
||||
return FileResponse(path=str(image_path), media_type="image/png", filename=image_filename)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to serve image: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-audio", response_model=StoryAudioGenerationResponse)
|
||||
async def generate_scene_audio(
|
||||
request: StoryAudioGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> StoryAudioGenerationResponse:
|
||||
"""Generate audio narration for story scenes."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.scenes or len(request.scenes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one scene is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Generating audio for {len(request.scenes)} scenes for user {user_id}")
|
||||
|
||||
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
|
||||
audio_results = audio_service.generate_scene_audio_list(
|
||||
scenes=scenes_data,
|
||||
user_id=user_id,
|
||||
provider=request.provider or "gtts",
|
||||
lang=request.lang or "en",
|
||||
slow=request.slow or False,
|
||||
rate=request.rate or 150,
|
||||
)
|
||||
|
||||
audio_models: List[StoryAudioResult] = []
|
||||
for result in audio_results:
|
||||
audio_url = result.get("audio_url") or ""
|
||||
audio_filename = result.get("audio_filename") or ""
|
||||
|
||||
audio_models.append(
|
||||
StoryAudioResult(
|
||||
scene_number=result.get("scene_number", 0),
|
||||
scene_title=result.get("scene_title", "Untitled"),
|
||||
audio_filename=audio_filename,
|
||||
audio_url=audio_url,
|
||||
provider=result.get("provider", "unknown"),
|
||||
file_size=result.get("file_size", 0),
|
||||
error=result.get("error"),
|
||||
)
|
||||
)
|
||||
|
||||
# Save assets to library
|
||||
if not result.get("error") and audio_url:
|
||||
try:
|
||||
scene_number = result.get("scene_number", 0)
|
||||
# Safely get prompt from scenes_data with bounds checking
|
||||
prompt = None
|
||||
if scene_number > 0 and scene_number <= len(scenes_data):
|
||||
prompt = scenes_data[scene_number - 1].get("text")
|
||||
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="audio",
|
||||
source_module="story_writer",
|
||||
filename=audio_filename,
|
||||
file_url=audio_url,
|
||||
file_path=result.get("audio_path"),
|
||||
file_size=result.get("file_size"),
|
||||
mime_type="audio/mpeg",
|
||||
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
|
||||
description=f"Story scene audio narration for scene {scene_number}",
|
||||
prompt=prompt,
|
||||
tags=["story_writer", "audio", "narration", f"scene_{scene_number}"],
|
||||
provider=result.get("provider"),
|
||||
model=result.get("model"),
|
||||
cost=result.get("cost"),
|
||||
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save audio asset to library: {e}")
|
||||
|
||||
return StoryAudioGenerationResponse(audio_files=audio_models, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate audio: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-ai-audio", response_model=GenerateAIAudioResponse)
|
||||
async def generate_ai_audio(
|
||||
request: GenerateAIAudioRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> GenerateAIAudioResponse:
|
||||
"""Generate AI audio for a single scene using WaveSpeed Minimax Speech 02 HD."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.text or not request.text.strip():
|
||||
raise HTTPException(status_code=400, detail="Text is required")
|
||||
|
||||
logger.info(
|
||||
f"[StoryWriter] Generating AI audio for scene {request.scene_number} "
|
||||
f"({request.scene_title}) for user {user_id}"
|
||||
)
|
||||
|
||||
result = audio_service.generate_ai_audio(
|
||||
scene_number=request.scene_number,
|
||||
scene_title=request.scene_title,
|
||||
text=request.text.strip(),
|
||||
user_id=user_id,
|
||||
voice_id=request.voice_id or "Wise_Woman",
|
||||
speed=request.speed or 1.0,
|
||||
volume=request.volume or 1.0,
|
||||
pitch=request.pitch or 0.0,
|
||||
emotion=request.emotion or "happy",
|
||||
)
|
||||
|
||||
return GenerateAIAudioResponse(
|
||||
scene_number=result.get("scene_number", request.scene_number),
|
||||
scene_title=result.get("scene_title", request.scene_title),
|
||||
audio_filename=result.get("audio_filename", ""),
|
||||
audio_url=result.get("audio_url", ""),
|
||||
provider=result.get("provider", "wavespeed"),
|
||||
model=result.get("model", "minimax/speech-02-hd"),
|
||||
voice_id=result.get("voice_id", request.voice_id or "Wise_Woman"),
|
||||
text_length=result.get("text_length", len(request.text)),
|
||||
file_size=result.get("file_size", 0),
|
||||
cost=result.get("cost", 0.0),
|
||||
success=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate AI audio: {exc}")
|
||||
return GenerateAIAudioResponse(
|
||||
scene_number=request.scene_number,
|
||||
scene_title=request.scene_title,
|
||||
audio_filename="",
|
||||
audio_url="",
|
||||
provider="wavespeed",
|
||||
model="minimax/speech-02-hd",
|
||||
voice_id=request.voice_id or "Wise_Woman",
|
||||
text_length=len(request.text) if request.text else 0,
|
||||
file_size=0,
|
||||
cost=0.0,
|
||||
success=False,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
|
||||
@router.get("/audio/{audio_filename}")
|
||||
async def serve_scene_audio(
|
||||
audio_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Serve a generated story scene audio file."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
|
||||
return FileResponse(path=str(audio_path), media_type="audio/mpeg", filename=audio_filename)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to serve audio: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
class PromptOptimizeRequest(BaseModel):
|
||||
text: str = Field(..., description="The prompt text to optimize")
|
||||
mode: Optional[str] = Field(default="image", pattern="^(image|video)$", description="Optimization mode: 'image' or 'video'")
|
||||
style: Optional[str] = Field(
|
||||
default="default",
|
||||
pattern="^(default|artistic|photographic|technical|anime|realistic)$",
|
||||
description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'"
|
||||
)
|
||||
image: Optional[str] = Field(None, description="Base64-encoded image for context (optional)")
|
||||
|
||||
|
||||
class PromptOptimizeResponse(BaseModel):
|
||||
optimized_prompt: str
|
||||
success: bool
|
||||
|
||||
|
||||
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
|
||||
async def optimize_prompt(
|
||||
request: PromptOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> PromptOptimizeResponse:
|
||||
"""Optimize an image prompt using WaveSpeed prompt optimizer."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.text or not request.text.strip():
|
||||
raise HTTPException(status_code=400, detail="Prompt text is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
|
||||
client = WaveSpeedClient()
|
||||
optimized_prompt = client.optimize_prompt(
|
||||
text=request.text.strip(),
|
||||
mode=request.mode or "image",
|
||||
style=request.style or "default",
|
||||
image=request.image, # Optional base64 image
|
||||
enable_sync_mode=True,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")
|
||||
|
||||
return PromptOptimizeResponse(
|
||||
optimized_prompt=optimized_prompt,
|
||||
success=True
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to optimize prompt: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
484
backend/api/story_writer/routes/scene_animation.py
Normal file
484
backend/api/story_writer/routes/scene_animation.py
Normal file
@@ -0,0 +1,484 @@
|
||||
"""
|
||||
Scene Animation Routes
|
||||
|
||||
Handles scene animation endpoints using WaveSpeed Kling and InfiniteTalk.
|
||||
"""
|
||||
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from urllib.parse import quote
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
AnimateSceneRequest,
|
||||
AnimateSceneResponse,
|
||||
AnimateSceneVoiceoverRequest,
|
||||
ResumeSceneAnimationRequest,
|
||||
)
|
||||
from services.database import get_db
|
||||
from services.llm_providers.main_video_generation import track_video_usage
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_scene_animation_operation
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
from services.wavespeed.kling_animation import animate_scene_image, resume_scene_animation
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from ..task_manager import task_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.media_utils import load_story_audio_bytes, load_story_image_bytes
|
||||
|
||||
router = APIRouter()
|
||||
scene_logger = get_service_logger("api.story_writer.scene_animation")
|
||||
AI_VIDEO_SUBDIR = Path("AI_Videos")
|
||||
|
||||
|
||||
def _build_authenticated_media_url(request: Request, path: str) -> str:
|
||||
"""Append the caller's auth token to a media URL so <video>/<img> tags can access it."""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
token: Optional[str] = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.replace("Bearer ", "").strip()
|
||||
elif "token" in request.query_params:
|
||||
token = request.query_params["token"]
|
||||
|
||||
if token:
|
||||
separator = "&" if "?" in path else "?"
|
||||
path = f"{path}{separator}token={quote(token)}"
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def _guess_mime_from_url(url: str, fallback: str) -> str:
|
||||
"""Guess MIME type from URL."""
|
||||
if not url:
|
||||
return fallback
|
||||
mime, _ = mimetypes.guess_type(url)
|
||||
return mime or fallback
|
||||
|
||||
|
||||
@router.post("/animate-scene-preview", response_model=AnimateSceneResponse)
|
||||
async def animate_scene_preview(
|
||||
request_obj: Request,
|
||||
request: AnimateSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimateSceneResponse:
|
||||
"""
|
||||
Animate a single scene image using WaveSpeed Kling v2.5 Turbo Std.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
duration = request.duration or 5
|
||||
if duration not in (5, 10):
|
||||
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds.")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] User=%s scene=%s duration=%s image_url=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
duration,
|
||||
request.image_url,
|
||||
)
|
||||
|
||||
image_bytes = load_story_image_bytes(request.image_url)
|
||||
if not image_bytes:
|
||||
scene_logger.warning("[AnimateScene] Missing image bytes for user=%s scene=%s", user_id, request.scene_number)
|
||||
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
animation_result = animate_scene_image(
|
||||
image_bytes=image_bytes,
|
||||
scene_data=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
user_id=user_id,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
video_url = _build_authenticated_media_url(
|
||||
request_obj, f"/api/story/videos/ai/{video_filename}"
|
||||
)
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateScene] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] ✅ Completed user=%s scene=%s duration=%s cost=$%.2f video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
animation_result["duration"],
|
||||
animation_result["cost"],
|
||||
video_url,
|
||||
)
|
||||
|
||||
# Save video asset to library
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="story_writer",
|
||||
filename=video_filename,
|
||||
file_url=video_url,
|
||||
file_path=str(ai_video_dir / video_filename),
|
||||
file_size=len(animation_result["video_bytes"]),
|
||||
mime_type="video/mp4",
|
||||
title=f"Scene {request.scene_number} Animation",
|
||||
description=f"Animated scene {request.scene_number} from story",
|
||||
prompt=animation_result["prompt"],
|
||||
tags=["story_writer", "video", "animation", f"scene_{request.scene_number}"],
|
||||
provider=animation_result["provider"],
|
||||
model=animation_result.get("model_name"),
|
||||
cost=animation_result["cost"],
|
||||
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
return AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/animate-scene-resume", response_model=AnimateSceneResponse)
|
||||
async def resume_scene_animation_endpoint(
|
||||
request_obj: Request,
|
||||
request: ResumeSceneAnimationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimateSceneResponse:
|
||||
"""Resume downloading a WaveSpeed animation when the initial call timed out."""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] Resume requested user=%s scene=%s prediction=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.prediction_id,
|
||||
)
|
||||
|
||||
animation_result = resume_scene_animation(
|
||||
prediction_id=request.prediction_id,
|
||||
duration=request.duration or 5,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
video_url = _build_authenticated_media_url(
|
||||
request_obj, f"/api/story/videos/ai/{video_filename}"
|
||||
)
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateScene] (Resume) Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] ✅ Resume completed user=%s scene=%s prediction=%s video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.prediction_id,
|
||||
video_url,
|
||||
)
|
||||
|
||||
return AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/animate-scene-voiceover", response_model=Dict[str, Any])
|
||||
async def animate_scene_voiceover_endpoint(
|
||||
request_obj: Request,
|
||||
request: AnimateSceneVoiceoverRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Animate a scene using WaveSpeed InfiniteTalk (image + audio) asynchronously.
|
||||
Returns task_id for polling since InfiniteTalk can take up to 10 minutes.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateSceneVoiceover] User=%s scene=%s resolution=%s (async)",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
request.resolution or "720p",
|
||||
)
|
||||
|
||||
image_bytes = load_story_image_bytes(request.image_url)
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
|
||||
|
||||
audio_bytes = load_story_audio_bytes(request.audio_url)
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene audio not found. Generate audio first.")
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Extract token for authenticated URL building (if needed)
|
||||
auth_token = None
|
||||
auth_header = request_obj.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.replace("Bearer ", "").strip()
|
||||
|
||||
# Create async task
|
||||
task_id = task_manager.create_task("scene_voiceover_animation")
|
||||
background_tasks.add_task(
|
||||
_execute_voiceover_animation_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
user_id=user_id,
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
auth_token=auth_token,
|
||||
)
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "InfiniteTalk animation started. This may take up to 10 minutes.",
|
||||
}
|
||||
|
||||
|
||||
def _execute_voiceover_animation_task(
|
||||
task_id: str,
|
||||
request: AnimateSceneVoiceoverRequest,
|
||||
user_id: str,
|
||||
image_bytes: bytes,
|
||||
audio_bytes: bytes,
|
||||
auth_token: Optional[str] = None,
|
||||
):
|
||||
"""Background task to generate InfiniteTalk video with progress updates."""
|
||||
try:
|
||||
task_manager.update_task_status(
|
||||
task_id, "processing", progress=5.0, message="Submitting to WaveSpeed InfiniteTalk..."
|
||||
)
|
||||
|
||||
animation_result = animate_scene_with_voiceover(
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
scene_data=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
user_id=user_id,
|
||||
resolution=request.resolution or "720p",
|
||||
prompt_override=request.prompt,
|
||||
image_mime=_guess_mime_from_url(request.image_url, "image/png"),
|
||||
audio_mime=_guess_mime_from_url(request.audio_url, "audio/mpeg"),
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id, "processing", progress=80.0, message="Saving video file..."
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
# Build authenticated URL if token provided, otherwise return plain URL
|
||||
video_url = f"/api/story/videos/ai/{video_filename}"
|
||||
if auth_token:
|
||||
video_url = f"{video_url}?token={quote(auth_token)}"
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateSceneVoiceover] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateSceneVoiceover] ✅ Completed user=%s scene=%s cost=$%.2f video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
animation_result["cost"],
|
||||
video_url,
|
||||
)
|
||||
|
||||
# Save video asset to library
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="story_writer",
|
||||
filename=video_filename,
|
||||
file_url=video_url,
|
||||
file_path=str(ai_video_dir / video_filename),
|
||||
file_size=len(animation_result["video_bytes"]),
|
||||
mime_type="video/mp4",
|
||||
title=f"Scene {request.scene_number} Animation (Voiceover)",
|
||||
description=f"Animated scene {request.scene_number} with voiceover from story",
|
||||
prompt=animation_result["prompt"],
|
||||
tags=["story_writer", "video", "animation", "voiceover", f"scene_{request.scene_number}"],
|
||||
provider=animation_result["provider"],
|
||||
model=animation_result.get("model_name"),
|
||||
cost=animation_result["cost"],
|
||||
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
result = AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="InfiniteTalk animation complete!",
|
||||
result=result.dict(),
|
||||
)
|
||||
except HTTPException as exc:
|
||||
error_msg = str(exc.detail) if isinstance(exc.detail, str) else exc.detail.get("error", "Animation failed") if isinstance(exc.detail, dict) else "Animation failed"
|
||||
scene_logger.error(f"[AnimateSceneVoiceover] Failed: {error_msg}")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"InfiniteTalk animation failed: {error_msg}",
|
||||
)
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
scene_logger.error(f"[AnimateSceneVoiceover] Error: {error_msg}", exc_info=True)
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"InfiniteTalk animation error: {error_msg}",
|
||||
)
|
||||
|
||||
297
backend/api/story_writer/routes/story_content.py
Normal file
297
backend/api/story_writer/routes/story_content.py
Normal file
@@ -0,0 +1,297 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
StoryStartRequest,
|
||||
StoryContentResponse,
|
||||
StoryScene,
|
||||
StoryContinueRequest,
|
||||
StoryContinueResponse,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
scene_approval_store: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {}
|
||||
APPROVAL_TTL_SECONDS = 60 * 60 * 24
|
||||
MAX_APPROVALS_PER_USER = 200
|
||||
|
||||
|
||||
def _cleanup_user_approvals(user_id: str) -> None:
|
||||
user_store = scene_approval_store.get(user_id)
|
||||
if not user_store:
|
||||
return
|
||||
now = datetime.utcnow()
|
||||
for project_id in list(user_store.keys()):
|
||||
scenes = user_store.get(project_id, {})
|
||||
for scene_id in list(scenes.keys()):
|
||||
timestamp = scenes[scene_id].get("timestamp")
|
||||
if isinstance(timestamp, datetime):
|
||||
if (now - timestamp).total_seconds() > APPROVAL_TTL_SECONDS:
|
||||
scenes.pop(scene_id, None)
|
||||
if not scenes:
|
||||
user_store.pop(project_id, None)
|
||||
if not user_store:
|
||||
scene_approval_store.pop(user_id, None)
|
||||
|
||||
|
||||
def _enforce_capacity(user_id: str) -> None:
|
||||
user_store = scene_approval_store.get(user_id)
|
||||
if not user_store:
|
||||
return
|
||||
entries: List[tuple[datetime, str, str]] = []
|
||||
for project_id, scenes in user_store.items():
|
||||
for scene_id, meta in scenes.items():
|
||||
timestamp = meta.get("timestamp")
|
||||
if isinstance(timestamp, datetime):
|
||||
entries.append((timestamp, project_id, scene_id))
|
||||
if len(entries) <= MAX_APPROVALS_PER_USER:
|
||||
return
|
||||
entries.sort(key=lambda item: item[0])
|
||||
to_remove = len(entries) - MAX_APPROVALS_PER_USER
|
||||
for i in range(to_remove):
|
||||
_, project_id, scene_id = entries[i]
|
||||
scenes = user_store.get(project_id)
|
||||
if not scenes:
|
||||
continue
|
||||
scenes.pop(scene_id, None)
|
||||
if not scenes:
|
||||
user_store.pop(project_id, None)
|
||||
|
||||
|
||||
def _get_user_store(user_id: str) -> Dict[str, Dict[str, Dict[str, Any]]]:
|
||||
_cleanup_user_approvals(user_id)
|
||||
return scene_approval_store.setdefault(user_id, {})
|
||||
|
||||
|
||||
@router.post("/generate-start", response_model=StoryContentResponse)
|
||||
async def generate_story_start(
|
||||
request: StoryStartRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryContentResponse:
|
||||
"""Generate the starting section of a story."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.premise or not request.premise.strip():
|
||||
raise HTTPException(status_code=400, detail="Premise is required")
|
||||
if not request.outline or (isinstance(request.outline, str) and not request.outline.strip()):
|
||||
raise HTTPException(status_code=400, detail="Outline is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Generating story start for user {user_id}")
|
||||
|
||||
outline_data: Any = request.outline
|
||||
if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
|
||||
outline_data = [scene.dict() for scene in outline_data]
|
||||
|
||||
story_length = getattr(request, "story_length", "Medium")
|
||||
story_start = story_service.generate_story_start(
|
||||
premise=request.premise,
|
||||
outline=outline_data,
|
||||
persona=request.persona,
|
||||
story_setting=request.story_setting,
|
||||
character_input=request.character_input,
|
||||
plot_elements=request.plot_elements,
|
||||
writing_style=request.writing_style,
|
||||
story_tone=request.story_tone,
|
||||
narrative_pov=request.narrative_pov,
|
||||
audience_age_group=request.audience_age_group,
|
||||
content_rating=request.content_rating,
|
||||
ending_preference=request.ending_preference,
|
||||
story_length=story_length,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
story_length_lower = story_length.lower()
|
||||
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
|
||||
is_complete = False
|
||||
if is_short_story:
|
||||
word_count = len(story_start.split()) if story_start else 0
|
||||
if word_count >= 900:
|
||||
is_complete = True
|
||||
logger.info(
|
||||
f"[StoryWriter] Short story generated with {word_count} words. Marking as complete."
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[StoryWriter] Short story generated with only {word_count} words. May need continuation."
|
||||
)
|
||||
|
||||
outline_response = outline_data
|
||||
if isinstance(outline_response, list):
|
||||
outline_response = "\n".join(
|
||||
[
|
||||
f"Scene {scene.get('scene_number', i + 1)}: "
|
||||
f"{scene.get('title', 'Untitled')}\n {scene.get('description', '')}"
|
||||
for i, scene in enumerate(outline_response)
|
||||
]
|
||||
)
|
||||
|
||||
return StoryContentResponse(
|
||||
story=story_start,
|
||||
premise=request.premise,
|
||||
outline=str(outline_response),
|
||||
is_complete=is_complete,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate story start: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/continue", response_model=StoryContinueResponse)
|
||||
async def continue_story(
|
||||
request: StoryContinueRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryContinueResponse:
|
||||
"""Continue writing a story."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.story_text or not request.story_text.strip():
|
||||
raise HTTPException(status_code=400, detail="Story text is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Continuing story for user {user_id}")
|
||||
|
||||
outline_data: Any = request.outline
|
||||
if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
|
||||
outline_data = [scene.dict() for scene in outline_data]
|
||||
|
||||
story_length = getattr(request, "story_length", "Medium")
|
||||
story_length_lower = story_length.lower()
|
||||
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
|
||||
if is_short_story:
|
||||
logger.warning(
|
||||
"[StoryWriter] Attempted to continue a short story. Short stories should be complete in one call."
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Short stories are generated in a single call and should be complete. "
|
||||
"If the story is incomplete, please regenerate it from the beginning.",
|
||||
)
|
||||
|
||||
current_word_count = len(request.story_text.split()) if request.story_text else 0
|
||||
if "long" in story_length_lower or "10000" in story_length_lower:
|
||||
target_total_words = 10000
|
||||
else:
|
||||
target_total_words = 4500
|
||||
buffer_target = int(target_total_words * 1.05)
|
||||
|
||||
if current_word_count >= buffer_target or (
|
||||
current_word_count >= target_total_words
|
||||
and (current_word_count - target_total_words) < 50
|
||||
):
|
||||
logger.info(
|
||||
f"[StoryWriter] Word count ({current_word_count}) already at or near target ({target_total_words})."
|
||||
)
|
||||
return StoryContinueResponse(continuation="IAMDONE", is_complete=True, success=True)
|
||||
|
||||
continuation = story_service.continue_story(
|
||||
premise=request.premise,
|
||||
outline=outline_data,
|
||||
story_text=request.story_text,
|
||||
persona=request.persona,
|
||||
story_setting=request.story_setting,
|
||||
character_input=request.character_input,
|
||||
plot_elements=request.plot_elements,
|
||||
writing_style=request.writing_style,
|
||||
story_tone=request.story_tone,
|
||||
narrative_pov=request.narrative_pov,
|
||||
audience_age_group=request.audience_age_group,
|
||||
content_rating=request.content_rating,
|
||||
ending_preference=request.ending_preference,
|
||||
story_length=story_length,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
is_complete = "IAMDONE" in continuation.upper()
|
||||
if not is_complete and continuation:
|
||||
new_story_text = request.story_text + "\n\n" + continuation
|
||||
new_word_count = len(new_story_text.split())
|
||||
if new_word_count >= buffer_target:
|
||||
logger.info(
|
||||
f"[StoryWriter] Word count ({new_word_count}) now exceeds buffer target ({buffer_target})."
|
||||
)
|
||||
if "IAMDONE" not in continuation.upper():
|
||||
continuation = continuation.rstrip() + "\n\nIAMDONE"
|
||||
is_complete = True
|
||||
elif new_word_count >= target_total_words and (
|
||||
new_word_count - target_total_words
|
||||
) < 100:
|
||||
logger.info(
|
||||
f"[StoryWriter] Word count ({new_word_count}) is at or very close to target ({target_total_words})."
|
||||
)
|
||||
if "IAMDONE" not in continuation.upper():
|
||||
continuation = continuation.rstrip() + "\n\nIAMDONE"
|
||||
is_complete = True
|
||||
|
||||
return StoryContinueResponse(continuation=continuation, is_complete=is_complete, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to continue story: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
class SceneApprovalRequest(BaseModel):
|
||||
project_id: str = Field(..., min_length=1)
|
||||
scene_id: str = Field(..., min_length=1)
|
||||
approved: bool = True
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("/script/approve")
|
||||
async def approve_script_scene(
|
||||
request: SceneApprovalRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Persist scene approval metadata for auditing."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
if not request.project_id.strip() or not request.scene_id.strip():
|
||||
raise HTTPException(status_code=400, detail="project_id and scene_id are required")
|
||||
|
||||
notes = request.notes.strip() if request.notes else None
|
||||
user_store = _get_user_store(user_id)
|
||||
project_store = user_store.setdefault(request.project_id, {})
|
||||
timestamp = datetime.utcnow()
|
||||
project_store[request.scene_id] = {
|
||||
"approved": request.approved,
|
||||
"notes": notes,
|
||||
"user_id": user_id,
|
||||
"timestamp": timestamp,
|
||||
}
|
||||
_enforce_capacity(user_id)
|
||||
logger.info(
|
||||
"[StoryWriter] Scene approval recorded user=%s project=%s scene=%s approved=%s",
|
||||
user_id,
|
||||
request.project_id,
|
||||
request.scene_id,
|
||||
request.approved,
|
||||
)
|
||||
return {
|
||||
"success": True,
|
||||
"project_id": request.project_id,
|
||||
"scene_id": request.scene_id,
|
||||
"approved": request.approved,
|
||||
"timestamp": timestamp.isoformat(),
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to approve scene: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
141
backend/api/story_writer/routes/story_setup.py
Normal file
141
backend/api/story_writer/routes/story_setup.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
StorySetupGenerationRequest,
|
||||
StorySetupGenerationResponse,
|
||||
StorySetupOption,
|
||||
StoryGenerationRequest,
|
||||
StoryOutlineResponse,
|
||||
StoryScene,
|
||||
StoryStartRequest,
|
||||
StoryPremiseResponse,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
|
||||
|
||||
@router.post("/generate-setup", response_model=StorySetupGenerationResponse)
|
||||
async def generate_story_setup(
|
||||
request: StorySetupGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StorySetupGenerationResponse:
|
||||
"""Generate 3 story setup options from a user's story idea."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.story_idea or not request.story_idea.strip():
|
||||
raise HTTPException(status_code=400, detail="Story idea is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")
|
||||
|
||||
options = story_service.generate_story_setup_options(
|
||||
story_idea=request.story_idea,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
setup_options = [StorySetupOption(**option) for option in options]
|
||||
return StorySetupGenerationResponse(options=setup_options, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate story setup options: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-premise", response_model=StoryPremiseResponse)
|
||||
async def generate_premise(
|
||||
request: StoryGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryPremiseResponse:
|
||||
"""Generate a story premise."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
logger.info(f"[StoryWriter] Generating premise for user {user_id}")
|
||||
|
||||
premise = story_service.generate_premise(
|
||||
persona=request.persona,
|
||||
story_setting=request.story_setting,
|
||||
character_input=request.character_input,
|
||||
plot_elements=request.plot_elements,
|
||||
writing_style=request.writing_style,
|
||||
story_tone=request.story_tone,
|
||||
narrative_pov=request.narrative_pov,
|
||||
audience_age_group=request.audience_age_group,
|
||||
content_rating=request.content_rating,
|
||||
ending_preference=request.ending_preference,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return StoryPremiseResponse(premise=premise, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate premise: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-outline", response_model=StoryOutlineResponse)
|
||||
async def generate_outline(
|
||||
request: StoryStartRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
use_structured: bool = True,
|
||||
) -> StoryOutlineResponse:
|
||||
"""Generate a story outline from a premise."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.premise or not request.premise.strip():
|
||||
raise HTTPException(status_code=400, detail="Premise is required")
|
||||
|
||||
logger.info(
|
||||
f"[StoryWriter] Generating outline for user {user_id} (structured={use_structured})"
|
||||
)
|
||||
logger.info(
|
||||
"[StoryWriter] Outline params: audience_age_group=%s, writing_style=%s, story_tone=%s",
|
||||
request.audience_age_group,
|
||||
request.writing_style,
|
||||
request.story_tone,
|
||||
)
|
||||
|
||||
outline = story_service.generate_outline(
|
||||
premise=request.premise,
|
||||
persona=request.persona,
|
||||
story_setting=request.story_setting,
|
||||
character_input=request.character_input,
|
||||
plot_elements=request.plot_elements,
|
||||
writing_style=request.writing_style,
|
||||
story_tone=request.story_tone,
|
||||
narrative_pov=request.narrative_pov,
|
||||
audience_age_group=request.audience_age_group,
|
||||
content_rating=request.content_rating,
|
||||
ending_preference=request.ending_preference,
|
||||
user_id=user_id,
|
||||
use_structured_output=use_structured,
|
||||
)
|
||||
|
||||
if isinstance(outline, list):
|
||||
scenes: List[StoryScene] = [
|
||||
StoryScene(**scene) if isinstance(scene, dict) else scene for scene in outline
|
||||
]
|
||||
return StoryOutlineResponse(outline=scenes, success=True, is_structured=True)
|
||||
|
||||
return StoryOutlineResponse(outline=str(outline), success=True, is_structured=False)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate outline: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
130
backend/api/story_writer/routes/story_tasks.py
Normal file
130
backend/api/story_writer/routes/story_tasks.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
StoryGenerationRequest,
|
||||
TaskStatus,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..cache_manager import cache_manager
|
||||
from ..task_manager import task_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
|
||||
|
||||
@router.post("/generate-full", response_model=Dict[str, Any])
|
||||
async def generate_full_story(
|
||||
request: StoryGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
max_iterations: int = 10,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate a complete story asynchronously."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
cache_key = cache_manager.get_cache_key(request.dict())
|
||||
cached_result = cache_manager.get_cached_result(cache_key)
|
||||
if cached_result:
|
||||
logger.info(f"[StoryWriter] Returning cached result for user {user_id}")
|
||||
task_id = task_manager.create_task("story_generation")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
result=cached_result,
|
||||
message="Returned cached result",
|
||||
)
|
||||
return {"task_id": task_id, "cached": True}
|
||||
|
||||
task_id = task_manager.create_task("story_generation")
|
||||
request_data = request.dict()
|
||||
request_data["max_iterations"] = max_iterations
|
||||
|
||||
background_tasks.add_task(
|
||||
task_manager.execute_story_generation_task,
|
||||
task_id=task_id,
|
||||
request_data=request_data,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Created task {task_id} for full story generation (user {user_id})")
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "Story generation started. Use /task/{task_id}/status to check progress.",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to start story generation: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/status", response_model=TaskStatus)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> TaskStatus:
|
||||
"""Get the status of a story generation task."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
task_status = task_manager.get_task_status(task_id)
|
||||
if not task_status:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
|
||||
return TaskStatus(**task_status)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to get task status: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/result")
|
||||
async def get_task_result(
|
||||
task_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Get the result of a completed story generation task."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
task_status = task_manager.get_task_status(task_id)
|
||||
if not task_status:
|
||||
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
|
||||
if task_status["status"] != "completed":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Task {task_id} is not completed. Status: {task_status['status']}",
|
||||
)
|
||||
|
||||
result = task_status.get("result")
|
||||
if not result:
|
||||
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
|
||||
|
||||
if isinstance(result, dict):
|
||||
payload = {**result}
|
||||
payload.setdefault("success", True)
|
||||
payload["task_id"] = task_id
|
||||
return payload
|
||||
|
||||
return {"result": result, "success": True, "task_id": task_id}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to get task result: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
539
backend/api/story_writer/routes/video_generation.py
Normal file
539
backend/api/story_writer/routes/video_generation.py
Normal file
@@ -0,0 +1,539 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from models.story_models import (
|
||||
StoryVideoGenerationRequest,
|
||||
StoryVideoGenerationResponse,
|
||||
StoryVideoResult,
|
||||
StoryScene,
|
||||
StoryGenerationRequest,
|
||||
)
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..task_manager import task_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.hd_video import (
|
||||
generate_hd_video_payload,
|
||||
generate_hd_video_scene_payload,
|
||||
)
|
||||
from ..utils.media_utils import resolve_media_file
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
video_service = StoryVideoGenerationService()
|
||||
image_service = StoryImageGenerationService()
|
||||
audio_service = StoryAudioGenerationService()
|
||||
story_service = StoryWriterService()
|
||||
|
||||
|
||||
class HDVideoRequest(BaseModel):
|
||||
prompt: str
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
|
||||
class HDVideoSceneRequest(BaseModel):
|
||||
scene_number: int
|
||||
scene_data: Dict[str, Any]
|
||||
story_context: Dict[str, Any]
|
||||
all_scenes: List[Dict[str, Any]]
|
||||
provider: str = "huggingface"
|
||||
model: str | None = None
|
||||
num_frames: int | None = None
|
||||
guidance_scale: float | None = None
|
||||
num_inference_steps: int | None = None
|
||||
negative_prompt: str | None = None
|
||||
seed: int | None = None
|
||||
|
||||
|
||||
@router.post("/generate-video", response_model=StoryVideoGenerationResponse)
|
||||
async def generate_story_video(
|
||||
request: StoryVideoGenerationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryVideoGenerationResponse:
|
||||
"""Generate a video from story scenes, images, and audio."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.scenes or len(request.scenes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one scene is required")
|
||||
|
||||
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Number of scenes, image URLs, and audio URLs must match",
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Generating video for {len(request.scenes)} scenes for user {user_id}")
|
||||
|
||||
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
|
||||
video_paths: List[Optional[str]] = [] # Animated videos (preferred)
|
||||
image_paths: List[Optional[str]] = [] # Static images (fallback)
|
||||
audio_paths: List[str] = []
|
||||
valid_scenes: List[Dict[str, Any]] = []
|
||||
|
||||
# Resolve video/audio directories
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
|
||||
|
||||
video_urls = request.video_urls or [None] * len(request.scenes)
|
||||
ai_audio_urls = request.ai_audio_urls or [None] * len(request.scenes)
|
||||
|
||||
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
|
||||
# Prefer animated video if available
|
||||
video_url = video_urls[idx] if idx < len(video_urls) else None
|
||||
video_path = None
|
||||
image_path = None
|
||||
|
||||
if video_url:
|
||||
# Extract filename from animated video URL (e.g., /api/story/videos/ai/filename.mp4)
|
||||
video_filename = video_url.split("/")[-1].split("?")[0]
|
||||
video_path = ai_video_dir / video_filename
|
||||
if video_path.exists():
|
||||
logger.info(f"[StoryWriter] Using animated video for scene {scene.get('scene_number', idx+1)}: {video_filename}")
|
||||
video_paths.append(str(video_path))
|
||||
image_paths.append(None)
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] Animated video not found: {video_path}, falling back to image")
|
||||
video_paths.append(None)
|
||||
video_path = None
|
||||
|
||||
# Fall back to image if no animated video
|
||||
if not video_path:
|
||||
image_filename = image_url.split("/")[-1].split("?")[0]
|
||||
image_path = image_service.output_dir / image_filename
|
||||
if image_path.exists():
|
||||
video_paths.append(None)
|
||||
image_paths.append(str(image_path))
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
|
||||
continue
|
||||
|
||||
# Prefer AI audio if available, otherwise use free audio
|
||||
ai_audio_url = ai_audio_urls[idx] if idx < len(ai_audio_urls) else None
|
||||
audio_filename = None
|
||||
audio_path = None
|
||||
|
||||
if ai_audio_url:
|
||||
audio_filename = ai_audio_url.split("/")[-1].split("?")[0]
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
if audio_path.exists():
|
||||
logger.info(f"[StoryWriter] Using AI audio for scene {scene.get('scene_number', idx+1)}: {audio_filename}")
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] AI audio not found: {audio_path}, falling back to free audio")
|
||||
audio_path = None
|
||||
|
||||
# Fall back to free audio if no AI audio
|
||||
if not audio_path:
|
||||
audio_filename = audio_url.split("/")[-1].split("?")[0]
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
if not audio_path.exists():
|
||||
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
|
||||
continue
|
||||
|
||||
audio_paths.append(str(audio_path))
|
||||
valid_scenes.append(scene)
|
||||
|
||||
if len(valid_scenes) == 0 or len(audio_paths) == 0:
|
||||
raise HTTPException(status_code=400, detail="No valid video/image or audio files were found")
|
||||
if len(valid_scenes) != len(audio_paths):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Number of valid scenes and audio files must match",
|
||||
)
|
||||
|
||||
video_result = video_service.generate_story_video(
|
||||
scenes=valid_scenes,
|
||||
image_paths=image_paths, # Can contain None for scenes with animated videos
|
||||
video_paths=video_paths, # Can contain None for scenes with static images
|
||||
audio_paths=audio_paths,
|
||||
user_id=user_id,
|
||||
story_title=request.story_title or "Story",
|
||||
fps=request.fps or 24,
|
||||
transition_duration=request.transition_duration or 0.5,
|
||||
)
|
||||
|
||||
video_model = StoryVideoResult(
|
||||
video_filename=video_result.get("video_filename", ""),
|
||||
video_url=video_result.get("video_url", ""),
|
||||
duration=video_result.get("duration", 0.0),
|
||||
fps=video_result.get("fps", 24),
|
||||
file_size=video_result.get("file_size", 0),
|
||||
num_scenes=video_result.get("num_scenes", 0),
|
||||
error=video_result.get("error"),
|
||||
)
|
||||
|
||||
return StoryVideoGenerationResponse(video=video_model, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate video: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-video-async", response_model=Dict[str, Any])
|
||||
async def generate_story_video_async(
|
||||
request: StoryVideoGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a video asynchronously with progress updates via task manager.
|
||||
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.scenes or len(request.scenes) == 0:
|
||||
raise HTTPException(status_code=400, detail="At least one scene is required")
|
||||
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Number of scenes, image URLs, and audio URLs must match",
|
||||
)
|
||||
|
||||
task_id = task_manager.create_task("story_video_generation")
|
||||
background_tasks.add_task(
|
||||
_execute_video_generation_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
user_id=user_id,
|
||||
)
|
||||
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to start async video generation: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
|
||||
"""Background task to generate story video with progress mapped to task manager."""
|
||||
try:
|
||||
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
|
||||
|
||||
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
|
||||
image_paths: List[str] = []
|
||||
audio_paths: List[str] = []
|
||||
valid_scenes: List[Dict[str, Any]] = []
|
||||
|
||||
for scene, image_url, audio_url in zip(scenes_data, request.image_urls, request.audio_urls):
|
||||
image_filename = image_url.split("/")[-1].split("?")[0]
|
||||
audio_filename = audio_url.split("/")[-1].split("?")[0]
|
||||
image_path = image_service.output_dir / image_filename
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
if not image_path.exists():
|
||||
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
|
||||
continue
|
||||
if not audio_path.exists():
|
||||
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
|
||||
continue
|
||||
image_paths.append(str(image_path))
|
||||
audio_paths.append(str(audio_path))
|
||||
valid_scenes.append(scene)
|
||||
|
||||
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
|
||||
raise RuntimeError("No valid or mismatched image/audio assets for video generation.")
|
||||
|
||||
def progress_callback(sub_progress: float, msg: str):
|
||||
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
|
||||
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
|
||||
|
||||
result = video_service.generate_story_video(
|
||||
scenes=valid_scenes,
|
||||
image_paths=image_paths,
|
||||
audio_paths=audio_paths,
|
||||
user_id=user_id,
|
||||
story_title=request.story_title or "Story",
|
||||
fps=request.fps or 24,
|
||||
transition_duration=request.transition_duration or 0.5,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="Video generation complete!",
|
||||
result={"video": result, "success": True},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Async video generation failed: {exc}", exc_info=True)
|
||||
task_manager.update_task_status(task_id, "failed", error=str(exc), message=f"Video generation failed: {exc}")
|
||||
|
||||
|
||||
@router.post("/hd-video")
|
||||
async def generate_hd_video(
|
||||
request: HDVideoRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_payload(request, user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video: {exc}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/hd-video-scene")
|
||||
async def generate_hd_video_scene(
|
||||
request: HDVideoSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
return generate_hd_video_scene_payload(request, user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {exc}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-complete-video", response_model=Dict[str, Any])
|
||||
async def generate_complete_story_video(
|
||||
request: StoryGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate a complete story video workflow asynchronously."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
logger.info(f"[StoryWriter] Starting complete video generation for user {user_id}")
|
||||
|
||||
task_id = task_manager.create_task("complete_video_generation")
|
||||
background_tasks.add_task(
|
||||
execute_complete_video_generation,
|
||||
task_id=task_id,
|
||||
request_data=request.dict(),
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
return {
|
||||
"task_id": task_id,
|
||||
"status": "pending",
|
||||
"message": "Complete video generation started",
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to start complete video generation: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
def execute_complete_video_generation(
|
||||
task_id: str,
|
||||
request_data: Dict[str, Any],
|
||||
user_id: str,
|
||||
):
|
||||
"""
|
||||
Execute complete video generation workflow synchronously.
|
||||
Runs in a background task and performs blocking operations.
|
||||
"""
|
||||
try:
|
||||
task_manager.update_task_status(task_id, "processing", progress=5.0, message="Starting complete video generation...")
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
|
||||
premise = story_service.generate_premise(
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=20.0, message="Generating structured outline with scenes...")
|
||||
outline_scenes = story_service.generate_outline(
|
||||
premise=premise,
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id,
|
||||
use_structured_output=True,
|
||||
)
|
||||
|
||||
if not isinstance(outline_scenes, list):
|
||||
raise RuntimeError("Failed to generate structured outline")
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=30.0, message="Generating images for scenes...")
|
||||
|
||||
def image_progress_callback(sub_progress: float, message: str):
|
||||
overall_progress = 30.0 + (sub_progress * 0.2)
|
||||
task_manager.update_task_status(task_id, "processing", progress=overall_progress, message=message)
|
||||
|
||||
image_results = image_service.generate_scene_images(
|
||||
scenes=outline_scenes,
|
||||
user_id=user_id,
|
||||
provider=request_data.get("image_provider"),
|
||||
width=request_data.get("image_width", 1024),
|
||||
height=request_data.get("image_height", 1024),
|
||||
model=request_data.get("image_model"),
|
||||
progress_callback=image_progress_callback,
|
||||
)
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=50.0, message="Generating audio narration for scenes...")
|
||||
|
||||
def audio_progress_callback(sub_progress: float, message: str):
|
||||
overall_progress = 50.0 + (sub_progress * 0.2)
|
||||
task_manager.update_task_status(task_id, "processing", progress=overall_progress, message=message)
|
||||
|
||||
audio_results = audio_service.generate_scene_audio_list(
|
||||
scenes=outline_scenes,
|
||||
user_id=user_id,
|
||||
provider=request_data.get("audio_provider", "gtts"),
|
||||
lang=request_data.get("audio_lang", "en"),
|
||||
slow=request_data.get("audio_slow", False),
|
||||
rate=request_data.get("audio_rate", 150),
|
||||
progress_callback=audio_progress_callback,
|
||||
)
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=70.0, message="Preparing video assets...")
|
||||
image_paths: List[str] = []
|
||||
audio_paths: List[str] = []
|
||||
valid_scenes: List[Dict[str, Any]] = []
|
||||
|
||||
for scene in outline_scenes:
|
||||
scene_number = scene.get("scene_number", 0)
|
||||
image_result = next((img for img in image_results if img.get("scene_number") == scene_number), None)
|
||||
audio_result = next((aud for aud in audio_results if aud.get("scene_number") == scene_number), None)
|
||||
|
||||
if image_result and audio_result and not image_result.get("error") and not audio_result.get("error"):
|
||||
image_path = image_result.get("image_path")
|
||||
audio_path = audio_result.get("audio_path")
|
||||
if image_path and audio_path:
|
||||
image_paths.append(image_path)
|
||||
audio_paths.append(audio_path)
|
||||
valid_scenes.append(scene)
|
||||
|
||||
if len(image_paths) == 0 or len(audio_paths) == 0:
|
||||
raise RuntimeError(
|
||||
f"No valid images or audio files were generated. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
|
||||
)
|
||||
if len(image_paths) != len(audio_paths):
|
||||
raise RuntimeError(
|
||||
f"Mismatch between image and audio counts. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
|
||||
)
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=75.0, message="Composing video from scenes...")
|
||||
|
||||
def video_progress_callback(sub_progress: float, message: str):
|
||||
overall_progress = 75.0 + (sub_progress * 0.2)
|
||||
task_manager.update_task_status(task_id, "processing", progress=overall_progress, message=message)
|
||||
|
||||
video_result = video_service.generate_story_video(
|
||||
scenes=valid_scenes,
|
||||
image_paths=image_paths,
|
||||
audio_paths=audio_paths,
|
||||
user_id=user_id,
|
||||
story_title=request_data.get("story_setting", "Story")[:50],
|
||||
fps=request_data.get("video_fps", 24),
|
||||
transition_duration=request_data.get("video_transition_duration", 0.5),
|
||||
progress_callback=video_progress_callback,
|
||||
)
|
||||
|
||||
result = {
|
||||
"premise": premise,
|
||||
"outline_scenes": outline_scenes,
|
||||
"images": image_results,
|
||||
"audio_files": audio_results,
|
||||
"video": video_result,
|
||||
"success": True,
|
||||
}
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="Complete video generation finished!",
|
||||
result=result,
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Complete video generation task {task_id} completed successfully")
|
||||
|
||||
except Exception as exc:
|
||||
error_msg = str(exc)
|
||||
logger.error(f"[StoryWriter] Complete video generation task {task_id} failed: {error_msg}", exc_info=True)
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"Complete video generation failed: {error_msg}",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/videos/{video_filename}")
|
||||
async def serve_story_video(
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Serve a generated story video file."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
video_path = resolve_media_file(video_service.output_dir, video_filename)
|
||||
return FileResponse(path=str(video_path), media_type="video/mp4", filename=video_filename)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to serve video: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/videos/ai/{video_filename}")
|
||||
async def serve_ai_story_video(
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Serve a generated AI scene animation video."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
|
||||
video_service_ai = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
video_path = resolve_media_file(video_service_ai.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to serve AI video: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
253
backend/api/story_writer/task_manager.py
Normal file
253
backend/api/story_writer/task_manager.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Task Management System for Story Writer API
|
||||
|
||||
Handles background task execution, status tracking, and progress updates
|
||||
for story generation operations.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""Manages background tasks for story generation."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the task manager."""
|
||||
self.task_storage: Dict[str, Dict[str, Any]] = {}
|
||||
logger.info("[StoryWriter] TaskManager initialized")
|
||||
|
||||
def cleanup_old_tasks(self):
|
||||
"""Remove tasks older than 1 hour to prevent memory leaks."""
|
||||
current_time = datetime.now()
|
||||
tasks_to_remove = []
|
||||
|
||||
for task_id, task_data in self.task_storage.items():
|
||||
created_at = task_data.get("created_at")
|
||||
if created_at and (current_time - created_at).total_seconds() > 3600: # 1 hour
|
||||
tasks_to_remove.append(task_id)
|
||||
|
||||
for task_id in tasks_to_remove:
|
||||
del self.task_storage[task_id]
|
||||
logger.debug(f"[StoryWriter] Cleaned up old task: {task_id}")
|
||||
|
||||
def create_task(self, task_type: str = "story_generation") -> str:
|
||||
"""Create a new task and return its ID."""
|
||||
task_id = str(uuid.uuid4())
|
||||
|
||||
self.task_storage[task_id] = {
|
||||
"status": "pending",
|
||||
"created_at": datetime.now(),
|
||||
"result": None,
|
||||
"error": None,
|
||||
"progress_messages": [],
|
||||
"task_type": task_type,
|
||||
"progress": 0.0
|
||||
}
|
||||
|
||||
logger.info(f"[StoryWriter] Created task: {task_id} (type: {task_type})")
|
||||
return task_id
|
||||
|
||||
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get the status of a task."""
|
||||
self.cleanup_old_tasks()
|
||||
|
||||
if task_id not in self.task_storage:
|
||||
# Log at DEBUG level - task not found is expected when tasks expire or are cleaned up
|
||||
# This prevents log spam from frontend polling for expired/completed tasks
|
||||
logger.debug(f"[StoryWriter] Task not found: {task_id} (may have expired or been cleaned up)")
|
||||
return None
|
||||
|
||||
task = self.task_storage[task_id]
|
||||
response = {
|
||||
"task_id": task_id,
|
||||
"status": task["status"],
|
||||
"progress": task.get("progress", 0.0),
|
||||
"message": task.get("progress_messages", [])[-1] if task.get("progress_messages") else None,
|
||||
"created_at": task["created_at"].isoformat() if task.get("created_at") else None,
|
||||
"updated_at": task.get("updated_at", task.get("created_at")).isoformat() if task.get("updated_at") or task.get("created_at") else None,
|
||||
}
|
||||
|
||||
if task["status"] == "completed" and task.get("result"):
|
||||
response["result"] = task["result"]
|
||||
|
||||
if task["status"] == "failed" and task.get("error"):
|
||||
response["error"] = task["error"]
|
||||
|
||||
return response
|
||||
|
||||
def update_task_status(
|
||||
self,
|
||||
task_id: str,
|
||||
status: str,
|
||||
progress: Optional[float] = None,
|
||||
message: Optional[str] = None,
|
||||
result: Optional[Dict[str, Any]] = None,
|
||||
error: Optional[str] = None
|
||||
):
|
||||
"""Update the status of a task."""
|
||||
if task_id not in self.task_storage:
|
||||
logger.warning(f"[StoryWriter] Cannot update non-existent task: {task_id}")
|
||||
return
|
||||
|
||||
task = self.task_storage[task_id]
|
||||
task["status"] = status
|
||||
task["updated_at"] = datetime.now()
|
||||
|
||||
if progress is not None:
|
||||
task["progress"] = progress
|
||||
|
||||
if message:
|
||||
if "progress_messages" not in task:
|
||||
task["progress_messages"] = []
|
||||
task["progress_messages"].append(message)
|
||||
logger.info(f"[StoryWriter] Task {task_id}: {message} (progress: {progress}%)")
|
||||
|
||||
if result is not None:
|
||||
task["result"] = result
|
||||
|
||||
if error is not None:
|
||||
task["error"] = error
|
||||
logger.error(f"[StoryWriter] Task {task_id} error: {error}")
|
||||
|
||||
async def execute_story_generation_task(
|
||||
self,
|
||||
task_id: str,
|
||||
request_data: Dict[str, Any],
|
||||
user_id: str
|
||||
):
|
||||
"""Execute story generation task asynchronously."""
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
service = StoryWriterService()
|
||||
|
||||
try:
|
||||
self.update_task_status(task_id, "processing", progress=0.0, message="Starting story generation...")
|
||||
|
||||
# Step 1: Generate premise
|
||||
self.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
|
||||
premise = service.generate_premise(
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Step 2: Generate outline
|
||||
self.update_task_status(task_id, "processing", progress=30.0, message="Generating story outline...")
|
||||
outline = service.generate_outline(
|
||||
premise=premise,
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Step 3: Generate story start
|
||||
self.update_task_status(task_id, "processing", progress=50.0, message="Writing story beginning...")
|
||||
story_start = service.generate_story_start(
|
||||
premise=premise,
|
||||
outline=outline,
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Step 4: Continue story
|
||||
self.update_task_status(task_id, "processing", progress=70.0, message="Continuing story generation...")
|
||||
story_text = story_start
|
||||
max_iterations = request_data.get("max_iterations", 10)
|
||||
iteration = 0
|
||||
|
||||
while 'IAMDONE' not in story_text and iteration < max_iterations:
|
||||
iteration += 1
|
||||
progress = 70.0 + (iteration / max_iterations) * 25.0
|
||||
self.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=min(progress, 95.0),
|
||||
message=f"Writing continuation {iteration}/{max_iterations}..."
|
||||
)
|
||||
|
||||
continuation = service.continue_story(
|
||||
premise=premise,
|
||||
outline=outline,
|
||||
story_text=story_text,
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
character_input=request_data["character_input"],
|
||||
plot_elements=request_data["plot_elements"],
|
||||
writing_style=request_data["writing_style"],
|
||||
story_tone=request_data["story_tone"],
|
||||
narrative_pov=request_data["narrative_pov"],
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if continuation:
|
||||
story_text += '\n\n' + continuation
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] Empty continuation at iteration {iteration}")
|
||||
break
|
||||
|
||||
# Clean up and finalize
|
||||
final_story = story_text.replace('IAMDONE', '').strip()
|
||||
|
||||
result = {
|
||||
"premise": premise,
|
||||
"outline": outline,
|
||||
"story": final_story,
|
||||
"is_complete": 'IAMDONE' in story_text or iteration >= max_iterations,
|
||||
"iterations": iteration
|
||||
}
|
||||
|
||||
self.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="Story generation completed!",
|
||||
result=result
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] Task {task_id} completed successfully")
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"[StoryWriter] Task {task_id} failed: {error_msg}")
|
||||
self.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=error_msg,
|
||||
message=f"Story generation failed: {error_msg}"
|
||||
)
|
||||
|
||||
|
||||
# Global task manager instance
|
||||
task_manager = TaskManager()
|
||||
8
backend/api/story_writer/utils/__init__.py
Normal file
8
backend/api/story_writer/utils/__init__.py
Normal file
@@ -0,0 +1,8 @@
|
||||
"""
|
||||
Utility helpers for Story Writer API routes.
|
||||
|
||||
Grouped here to keep the main router lean while reusing common logic
|
||||
such as authentication guards, media resolution, and HD video helpers.
|
||||
"""
|
||||
|
||||
|
||||
23
backend/api/story_writer/utils/auth.py
Normal file
23
backend/api/story_writer/utils/auth.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
|
||||
"""
|
||||
Validates the current user dictionary provided by Clerk middleware and
|
||||
returns the normalized user_id. Raises HTTP 401 if authentication fails.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", "")).strip()
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid user ID in authentication token",
|
||||
)
|
||||
|
||||
return user_id
|
||||
|
||||
|
||||
146
backend/api/story_writer/utils/hd_video.py
Normal file
146
backend/api/story_writer/utils/hd_video.py
Normal file
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from uuid import uuid4
|
||||
|
||||
|
||||
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""Handles synchronous HD video generation."""
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
output_dir = video_service.output_dir
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if getattr(request, "model", None):
|
||||
kwargs["model"] = request.model
|
||||
if getattr(request, "num_frames", None):
|
||||
kwargs["num_frames"] = request.num_frames
|
||||
if getattr(request, "guidance_scale", None) is not None:
|
||||
kwargs["guidance_scale"] = request.guidance_scale
|
||||
if getattr(request, "num_inference_steps", None):
|
||||
kwargs["num_inference_steps"] = request.num_inference_steps
|
||||
if getattr(request, "negative_prompt", None):
|
||||
kwargs["negative_prompt"] = request.negative_prompt
|
||||
if getattr(request, "seed", None) is not None:
|
||||
kwargs["seed"] = request.seed
|
||||
|
||||
logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
|
||||
result = ai_video_generate(
|
||||
prompt=request.prompt,
|
||||
operation_type="text-to-video",
|
||||
provider=getattr(request, "provider", None) or "huggingface",
|
||||
user_id=user_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Extract video bytes from result dict
|
||||
video_bytes = result["video_bytes"]
|
||||
|
||||
filename = f"hd_{uuid4().hex}.mp4"
|
||||
file_path = output_dir / filename
|
||||
with open(file_path, "wb") as fh:
|
||||
fh.write(video_bytes)
|
||||
|
||||
logger.info(f"[StoryWriter] HD video saved to {file_path}")
|
||||
return {
|
||||
"success": True,
|
||||
"video_filename": filename,
|
||||
"video_url": f"/api/story/videos/{filename}",
|
||||
"provider": getattr(request, "provider", None) or "huggingface",
|
||||
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
|
||||
}
|
||||
|
||||
|
||||
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Handles per-scene HD video generation including prompt enhancement
|
||||
and subscription validation.
|
||||
"""
|
||||
from services.database import get_db as get_db_validation
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
|
||||
from services.llm_providers.main_video_generation import ai_video_generate
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
|
||||
scene_number = request.scene_number
|
||||
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
|
||||
|
||||
hf_token = APIKeyManager().get_api_key("hf_token")
|
||||
if not hf_token:
|
||||
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Hugging Face API token is not configured. Please configure your HF token in settings.",
|
||||
"message": "Hugging Face API token is not configured. Please configure your HF token in settings.",
|
||||
},
|
||||
)
|
||||
|
||||
db_validation = next(get_db_validation())
|
||||
try:
|
||||
pricing_service = PricingService(db_validation)
|
||||
logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
|
||||
validate_video_generation_operations(pricing_service=pricing_service, user_id=user_id)
|
||||
logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
|
||||
finally:
|
||||
db_validation.close()
|
||||
|
||||
enhanced_prompt = enhance_scene_prompt_for_video(
|
||||
current_scene=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
all_scenes=request.all_scenes,
|
||||
user_id=user_id,
|
||||
)
|
||||
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if getattr(request, "model", None):
|
||||
kwargs["model"] = request.model
|
||||
if getattr(request, "num_frames", None):
|
||||
kwargs["num_frames"] = request.num_frames
|
||||
if getattr(request, "guidance_scale", None) is not None:
|
||||
kwargs["guidance_scale"] = request.guidance_scale
|
||||
if getattr(request, "num_inference_steps", None):
|
||||
kwargs["num_inference_steps"] = request.num_inference_steps
|
||||
if getattr(request, "negative_prompt", None):
|
||||
kwargs["negative_prompt"] = request.negative_prompt
|
||||
if getattr(request, "seed", None) is not None:
|
||||
kwargs["seed"] = request.seed
|
||||
|
||||
result = ai_video_generate(
|
||||
prompt=enhanced_prompt,
|
||||
operation_type="text-to-video",
|
||||
provider=getattr(request, "provider", None) or "huggingface",
|
||||
user_id=user_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Extract video bytes from result dict
|
||||
video_bytes = result["video_bytes"]
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=video_bytes,
|
||||
scene_number=scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
logger.info(f"[StoryWriter] HD video saved for scene {scene_number}: {save_result.get('video_filename')}")
|
||||
return {
|
||||
"success": True,
|
||||
"scene_number": scene_number,
|
||||
"video_filename": save_result.get("video_filename"),
|
||||
"video_url": save_result.get("video_url"),
|
||||
"prompt_used": enhanced_prompt,
|
||||
"provider": getattr(request, "provider", None) or "huggingface",
|
||||
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
|
||||
}
|
||||
|
||||
148
backend/api/story_writer/utils/media_utils.py
Normal file
148
backend/api/story_writer/utils/media_utils.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from loguru import logger
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
|
||||
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
|
||||
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
STORY_AUDIO_DIR = (BASE_DIR / "story_audio").resolve()
|
||||
STORY_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
|
||||
"""
|
||||
Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.
|
||||
Returns None if the file cannot be located.
|
||||
"""
|
||||
if not image_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(image_url)
|
||||
path = parsed.path if parsed.scheme else image_url
|
||||
prefix = "/api/story/images/"
|
||||
if prefix not in path:
|
||||
logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
|
||||
return None
|
||||
|
||||
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
file_path = (STORY_IMAGES_DIR / filename).resolve()
|
||||
if not str(file_path).startswith(str(STORY_IMAGES_DIR)):
|
||||
logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
|
||||
return None
|
||||
|
||||
if not file_path.exists():
|
||||
logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
|
||||
return None
|
||||
|
||||
return file_path.read_bytes()
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def load_story_audio_bytes(audio_url: str) -> Optional[bytes]:
|
||||
"""
|
||||
Resolve an authenticated story audio URL (e.g., /api/story/audio/<file>) to raw bytes.
|
||||
Returns None if the file cannot be located.
|
||||
"""
|
||||
if not audio_url:
|
||||
return None
|
||||
|
||||
try:
|
||||
parsed = urlparse(audio_url)
|
||||
path = parsed.path if parsed.scheme else audio_url
|
||||
prefix = "/api/story/audio/"
|
||||
if prefix not in path:
|
||||
logger.warning(f"[StoryWriter] Unsupported audio URL for video reference: {audio_url}")
|
||||
return None
|
||||
|
||||
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
file_path = (STORY_AUDIO_DIR / filename).resolve()
|
||||
if not str(file_path).startswith(str(STORY_AUDIO_DIR)):
|
||||
logger.error(f"[StoryWriter] Attempted path traversal when resolving audio: {audio_url}")
|
||||
return None
|
||||
|
||||
if not file_path.exists():
|
||||
logger.warning(f"[StoryWriter] Referenced scene audio not found on disk: {file_path}")
|
||||
return None
|
||||
|
||||
return file_path.read_bytes()
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to load reference audio for video gen: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def resolve_media_file(base_dir: Path, filename: str) -> Path:
|
||||
"""
|
||||
Returns a safe resolved path for a media file stored under base_dir.
|
||||
Guards against directory traversal and ensures the file exists.
|
||||
"""
|
||||
filename = filename.split("?")[0].strip()
|
||||
resolved = (base_dir / filename).resolve()
|
||||
|
||||
try:
|
||||
resolved.relative_to(base_dir.resolve())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
if not resolved.exists():
|
||||
alternate = _find_alternate_media_file(base_dir, filename)
|
||||
if alternate:
|
||||
logger.warning(
|
||||
"[StoryWriter] Requested media file '%s' missing; serving closest match '%s'",
|
||||
filename,
|
||||
alternate.name,
|
||||
)
|
||||
return alternate
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
def _find_alternate_media_file(base_dir: Path, filename: str) -> Optional[Path]:
|
||||
"""
|
||||
Attempt to find the most recent media file that matches the original name prefix.
|
||||
|
||||
This helps when files are regenerated with new UUID/hash suffixes but the frontend still
|
||||
references an older filename.
|
||||
"""
|
||||
try:
|
||||
base_dir = base_dir.resolve()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
stem = Path(filename).stem
|
||||
suffix = Path(filename).suffix
|
||||
|
||||
if not suffix or "_" not in stem:
|
||||
return None
|
||||
|
||||
prefix = stem.rsplit("_", 1)[0]
|
||||
pattern = f"{prefix}_*{suffix}"
|
||||
|
||||
try:
|
||||
candidates = sorted(
|
||||
(p for p in base_dir.glob(pattern) if p.is_file()),
|
||||
key=lambda p: p.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.debug(f"[StoryWriter] Failed to search alternate media files for {filename}: {exc}")
|
||||
return None
|
||||
|
||||
return candidates[0] if candidates else None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user