Base code

This commit is contained in:
Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions

View File

@@ -0,0 +1,9 @@
"""
Story Writer API
API endpoints for story generation functionality.
"""
from .router import router
__all__ = ['router']

View File

@@ -0,0 +1,70 @@
"""
Cache Management System for Story Writer API
Handles story generation cache operations.
"""
from typing import Any, Dict, Optional
from loguru import logger
class CacheManager:
"""Manages cache operations for story generation data."""
def __init__(self):
"""Initialize the cache manager."""
self.cache: Dict[str, Dict[str, Any]] = {}
logger.info("[StoryWriter] CacheManager initialized")
def get_cache_key(self, request_data: Dict[str, Any]) -> str:
"""Generate a cache key from request data."""
import hashlib
import json
# Create a normalized version of the request for caching
cache_data = {
"persona": request_data.get("persona", ""),
"story_setting": request_data.get("story_setting", ""),
"character_input": request_data.get("character_input", ""),
"plot_elements": request_data.get("plot_elements", ""),
"writing_style": request_data.get("writing_style", ""),
"story_tone": request_data.get("story_tone", ""),
"narrative_pov": request_data.get("narrative_pov", ""),
"audience_age_group": request_data.get("audience_age_group", ""),
"content_rating": request_data.get("content_rating", ""),
"ending_preference": request_data.get("ending_preference", ""),
}
cache_str = json.dumps(cache_data, sort_keys=True)
return hashlib.md5(cache_str.encode()).hexdigest()
def get_cached_result(self, cache_key: str) -> Optional[Dict[str, Any]]:
"""Get a cached result if available."""
if cache_key in self.cache:
logger.debug(f"[StoryWriter] Cache hit for key: {cache_key}")
return self.cache[cache_key]
logger.debug(f"[StoryWriter] Cache miss for key: {cache_key}")
return None
def cache_result(self, cache_key: str, result: Dict[str, Any]):
"""Cache a result."""
self.cache[cache_key] = result
logger.debug(f"[StoryWriter] Cached result for key: {cache_key}")
def clear_cache(self):
"""Clear all cached results."""
count = len(self.cache)
self.cache.clear()
logger.info(f"[StoryWriter] Cleared {count} cached entries")
return {"status": "success", "message": f"Cleared {count} cached entries"}
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics."""
return {
"total_entries": len(self.cache),
"cache_keys": list(self.cache.keys())
}
# Global cache manager instance
cache_manager = CacheManager()

View File

@@ -0,0 +1,37 @@
"""
Story Writer API Router
Main router for story generation operations. This file serves as the entry point
and includes modular sub-routers for different functionality areas.
"""
from typing import Any, Dict
from fastapi import APIRouter
from .routes import (
cache_routes,
media_generation,
scene_animation,
story_content,
story_setup,
story_tasks,
video_generation,
)
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
# Include modular routers (order preserved roughly by workflow)
router.include_router(story_setup.router)
router.include_router(story_content.router)
router.include_router(story_tasks.router)
router.include_router(media_generation.router)
router.include_router(scene_animation.router)
router.include_router(video_generation.router)
router.include_router(cache_routes.router)
@router.get("/health")
async def health() -> Dict[str, Any]:
"""Health check endpoint."""
return {"status": "ok", "service": "story_writer"}

View File

@@ -0,0 +1,23 @@
"""
Collection of modular routers for Story Writer endpoints.
Each module focuses on a related set of routes to keep the primary
`router.py` concise and easier to maintain.
"""
from . import cache_routes
from . import media_generation
from . import scene_animation
from . import story_content
from . import story_setup
from . import story_tasks
from . import video_generation
__all__ = [
"cache_routes",
"media_generation",
"scene_animation",
"story_content",
"story_setup",
"story_tasks",
"video_generation",
]

View File

@@ -0,0 +1,42 @@
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from middleware.auth_middleware import get_current_user
from ..cache_manager import cache_manager
from ..utils.auth import require_authenticated_user
router = APIRouter()
@router.get("/cache/stats")
async def get_cache_stats(
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""Get cache statistics."""
try:
require_authenticated_user(current_user)
stats = cache_manager.get_cache_stats()
return {"success": True, "stats": stats}
except Exception as exc:
logger.error(f"[StoryWriter] Failed to get cache stats: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/cache/clear")
async def clear_cache(
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""Clear the story generation cache."""
try:
require_authenticated_user(current_user)
result = cache_manager.clear_cache()
return {"success": True, **result}
except Exception as exc:
logger.error(f"[StoryWriter] Failed to clear cache: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,416 @@
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from models.story_models import (
StoryImageGenerationRequest,
StoryImageGenerationResponse,
StoryImageResult,
RegenerateImageRequest,
RegenerateImageResponse,
StoryAudioGenerationRequest,
StoryAudioGenerationResponse,
StoryAudioResult,
GenerateAIAudioRequest,
GenerateAIAudioResponse,
StoryScene,
)
from services.database import get_db
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from utils.asset_tracker import save_asset_to_library
from ..utils.auth import require_authenticated_user
from ..utils.media_utils import resolve_media_file
router = APIRouter()
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
@router.post("/generate-images", response_model=StoryImageGenerationResponse)
async def generate_scene_images(
request: StoryImageGenerationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StoryImageGenerationResponse:
"""Generate images for story scenes."""
try:
user_id = require_authenticated_user(current_user)
if not request.scenes or len(request.scenes) == 0:
raise HTTPException(status_code=400, detail="At least one scene is required")
logger.info(f"[StoryWriter] Generating images for {len(request.scenes)} scenes for user {user_id}")
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
image_results = image_service.generate_scene_images(
scenes=scenes_data,
user_id=user_id,
provider=request.provider,
width=request.width or 1024,
height=request.height or 1024,
model=request.model,
)
image_models: List[StoryImageResult] = [
StoryImageResult(
scene_number=result.get("scene_number", 0),
scene_title=result.get("scene_title", "Untitled"),
image_filename=result.get("image_filename", ""),
image_url=result.get("image_url", ""),
width=result.get("width", 1024),
height=result.get("height", 1024),
provider=result.get("provider", "unknown"),
model=result.get("model"),
seed=result.get("seed"),
error=result.get("error"),
)
for result in image_results
]
# Save assets to library
for result in image_results:
if not result.get("error") and result.get("image_url"):
try:
scene_number = result.get("scene_number", 0)
# Safely get prompt from scenes_data with bounds checking
prompt = None
if scene_number > 0 and scene_number <= len(scenes_data):
prompt = scenes_data[scene_number - 1].get("image_prompt")
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
source_module="story_writer",
filename=result.get("image_filename", ""),
file_url=result.get("image_url", ""),
file_path=result.get("image_path"),
file_size=result.get("file_size"),
mime_type="image/png",
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
description=f"Story scene image for scene {scene_number}",
prompt=prompt,
tags=["story_writer", "scene", f"scene_{scene_number}"],
provider=result.get("provider"),
model=result.get("model"),
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save image asset to library: {e}")
return StoryImageGenerationResponse(images=image_models, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate images: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/regenerate-images", response_model=RegenerateImageResponse)
async def regenerate_scene_image(
request: RegenerateImageRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> RegenerateImageResponse:
"""Regenerate a single scene image using a direct prompt (no AI prompt generation)."""
try:
user_id = require_authenticated_user(current_user)
if not request.prompt or not request.prompt.strip():
raise HTTPException(status_code=400, detail="Prompt is required")
logger.info(
f"[StoryWriter] Regenerating image for scene {request.scene_number} "
f"({request.scene_title}) for user {user_id}"
)
result = image_service.regenerate_scene_image(
scene_number=request.scene_number,
scene_title=request.scene_title,
prompt=request.prompt.strip(),
user_id=user_id,
provider=request.provider,
width=request.width or 1024,
height=request.height or 1024,
model=request.model,
)
return RegenerateImageResponse(
scene_number=result.get("scene_number", request.scene_number),
scene_title=result.get("scene_title", request.scene_title),
image_filename=result.get("image_filename", ""),
image_url=result.get("image_url", ""),
width=result.get("width", request.width or 1024),
height=result.get("height", request.height or 1024),
provider=result.get("provider", "unknown"),
model=result.get("model"),
seed=result.get("seed"),
success=True,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to regenerate image: {exc}")
return RegenerateImageResponse(
scene_number=request.scene_number,
scene_title=request.scene_title,
image_filename="",
image_url="",
width=request.width or 1024,
height=request.height or 1024,
provider=request.provider or "unknown",
success=False,
error=str(exc),
)
@router.get("/images/{image_filename}")
async def serve_scene_image(
image_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
"""Serve a generated story scene image.
Supports authentication via Authorization header or token query parameter.
Query parameter is useful for HTML elements like <img> that cannot send custom headers.
"""
try:
require_authenticated_user(current_user)
image_path = resolve_media_file(image_service.output_dir, image_filename)
return FileResponse(path=str(image_path), media_type="image/png", filename=image_filename)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to serve image: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-audio", response_model=StoryAudioGenerationResponse)
async def generate_scene_audio(
request: StoryAudioGenerationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
) -> StoryAudioGenerationResponse:
"""Generate audio narration for story scenes."""
try:
user_id = require_authenticated_user(current_user)
if not request.scenes or len(request.scenes) == 0:
raise HTTPException(status_code=400, detail="At least one scene is required")
logger.info(f"[StoryWriter] Generating audio for {len(request.scenes)} scenes for user {user_id}")
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
audio_results = audio_service.generate_scene_audio_list(
scenes=scenes_data,
user_id=user_id,
provider=request.provider or "gtts",
lang=request.lang or "en",
slow=request.slow or False,
rate=request.rate or 150,
)
audio_models: List[StoryAudioResult] = []
for result in audio_results:
audio_url = result.get("audio_url") or ""
audio_filename = result.get("audio_filename") or ""
audio_models.append(
StoryAudioResult(
scene_number=result.get("scene_number", 0),
scene_title=result.get("scene_title", "Untitled"),
audio_filename=audio_filename,
audio_url=audio_url,
provider=result.get("provider", "unknown"),
file_size=result.get("file_size", 0),
error=result.get("error"),
)
)
# Save assets to library
if not result.get("error") and audio_url:
try:
scene_number = result.get("scene_number", 0)
# Safely get prompt from scenes_data with bounds checking
prompt = None
if scene_number > 0 and scene_number <= len(scenes_data):
prompt = scenes_data[scene_number - 1].get("text")
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="audio",
source_module="story_writer",
filename=audio_filename,
file_url=audio_url,
file_path=result.get("audio_path"),
file_size=result.get("file_size"),
mime_type="audio/mpeg",
title=f"Scene {scene_number}: {result.get('scene_title', 'Untitled')}",
description=f"Story scene audio narration for scene {scene_number}",
prompt=prompt,
tags=["story_writer", "audio", "narration", f"scene_{scene_number}"],
provider=result.get("provider"),
model=result.get("model"),
cost=result.get("cost"),
asset_metadata={"scene_number": scene_number, "scene_title": result.get("scene_title"), "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save audio asset to library: {e}")
return StoryAudioGenerationResponse(audio_files=audio_models, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate audio: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-ai-audio", response_model=GenerateAIAudioResponse)
async def generate_ai_audio(
request: GenerateAIAudioRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> GenerateAIAudioResponse:
"""Generate AI audio for a single scene using WaveSpeed Minimax Speech 02 HD."""
try:
user_id = require_authenticated_user(current_user)
if not request.text or not request.text.strip():
raise HTTPException(status_code=400, detail="Text is required")
logger.info(
f"[StoryWriter] Generating AI audio for scene {request.scene_number} "
f"({request.scene_title}) for user {user_id}"
)
result = audio_service.generate_ai_audio(
scene_number=request.scene_number,
scene_title=request.scene_title,
text=request.text.strip(),
user_id=user_id,
voice_id=request.voice_id or "Wise_Woman",
speed=request.speed or 1.0,
volume=request.volume or 1.0,
pitch=request.pitch or 0.0,
emotion=request.emotion or "happy",
)
return GenerateAIAudioResponse(
scene_number=result.get("scene_number", request.scene_number),
scene_title=result.get("scene_title", request.scene_title),
audio_filename=result.get("audio_filename", ""),
audio_url=result.get("audio_url", ""),
provider=result.get("provider", "wavespeed"),
model=result.get("model", "minimax/speech-02-hd"),
voice_id=result.get("voice_id", request.voice_id or "Wise_Woman"),
text_length=result.get("text_length", len(request.text)),
file_size=result.get("file_size", 0),
cost=result.get("cost", 0.0),
success=True,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate AI audio: {exc}")
return GenerateAIAudioResponse(
scene_number=request.scene_number,
scene_title=request.scene_title,
audio_filename="",
audio_url="",
provider="wavespeed",
model="minimax/speech-02-hd",
voice_id=request.voice_id or "Wise_Woman",
text_length=len(request.text) if request.text else 0,
file_size=0,
cost=0.0,
success=False,
error=str(exc),
)
@router.get("/audio/{audio_filename}")
async def serve_scene_audio(
audio_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Serve a generated story scene audio file."""
try:
require_authenticated_user(current_user)
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
return FileResponse(path=str(audio_path), media_type="audio/mpeg", filename=audio_filename)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to serve audio: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
class PromptOptimizeRequest(BaseModel):
text: str = Field(..., description="The prompt text to optimize")
mode: Optional[str] = Field(default="image", pattern="^(image|video)$", description="Optimization mode: 'image' or 'video'")
style: Optional[str] = Field(
default="default",
pattern="^(default|artistic|photographic|technical|anime|realistic)$",
description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'"
)
image: Optional[str] = Field(None, description="Base64-encoded image for context (optional)")
class PromptOptimizeResponse(BaseModel):
optimized_prompt: str
success: bool
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
async def optimize_prompt(
request: PromptOptimizeRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> PromptOptimizeResponse:
"""Optimize an image prompt using WaveSpeed prompt optimizer."""
try:
user_id = require_authenticated_user(current_user)
if not request.text or not request.text.strip():
raise HTTPException(status_code=400, detail="Prompt text is required")
logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")
from services.wavespeed.client import WaveSpeedClient
client = WaveSpeedClient()
optimized_prompt = client.optimize_prompt(
text=request.text.strip(),
mode=request.mode or "image",
style=request.style or "default",
image=request.image, # Optional base64 image
enable_sync_mode=True,
timeout=30
)
logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")
return PromptOptimizeResponse(
optimized_prompt=optimized_prompt,
success=True
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to optimize prompt: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,484 @@
"""
Scene Animation Routes
Handles scene animation endpoints using WaveSpeed Kling and InfiniteTalk.
"""
import mimetypes
from pathlib import Path
from typing import Any, Dict, Optional
from urllib.parse import quote
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from loguru import logger
from sqlalchemy.orm import Session
from middleware.auth_middleware import get_current_user
from models.story_models import (
AnimateSceneRequest,
AnimateSceneResponse,
AnimateSceneVoiceoverRequest,
ResumeSceneAnimationRequest,
)
from services.database import get_db
from services.llm_providers.main_video_generation import track_video_usage
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_scene_animation_operation
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
from services.wavespeed.kling_animation import animate_scene_image, resume_scene_animation
from utils.asset_tracker import save_asset_to_library
from utils.logger_utils import get_service_logger
from ..task_manager import task_manager
from ..utils.auth import require_authenticated_user
from ..utils.media_utils import load_story_audio_bytes, load_story_image_bytes
router = APIRouter()
scene_logger = get_service_logger("api.story_writer.scene_animation")
AI_VIDEO_SUBDIR = Path("AI_Videos")
def _build_authenticated_media_url(request: Request, path: str) -> str:
"""Append the caller's auth token to a media URL so <video>/<img> tags can access it."""
if not path:
return path
token: Optional[str] = None
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.replace("Bearer ", "").strip()
elif "token" in request.query_params:
token = request.query_params["token"]
if token:
separator = "&" if "?" in path else "?"
path = f"{path}{separator}token={quote(token)}"
return path
def _guess_mime_from_url(url: str, fallback: str) -> str:
"""Guess MIME type from URL."""
if not url:
return fallback
mime, _ = mimetypes.guess_type(url)
return mime or fallback
@router.post("/animate-scene-preview", response_model=AnimateSceneResponse)
async def animate_scene_preview(
request_obj: Request,
request: AnimateSceneRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimateSceneResponse:
"""
Animate a single scene image using WaveSpeed Kling v2.5 Turbo Std.
"""
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get("id", ""))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
duration = request.duration or 5
if duration not in (5, 10):
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds.")
scene_logger.info(
"[AnimateScene] User=%s scene=%s duration=%s image_url=%s",
user_id,
request.scene_number,
duration,
request.image_url,
)
image_bytes = load_story_image_bytes(request.image_url)
if not image_bytes:
scene_logger.warning("[AnimateScene] Missing image bytes for user=%s scene=%s", user_id, request.scene_number)
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
db = next(get_db())
try:
pricing_service = PricingService(db)
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
finally:
db.close()
animation_result = animate_scene_image(
image_bytes=image_bytes,
scene_data=request.scene_data,
story_context=request.story_context,
user_id=user_id,
duration=duration,
)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
ai_video_dir.mkdir(parents=True, exist_ok=True)
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
save_result = video_service.save_scene_video(
video_bytes=animation_result["video_bytes"],
scene_number=request.scene_number,
user_id=user_id,
)
video_filename = save_result["video_filename"]
video_url = _build_authenticated_media_url(
request_obj, f"/api/story/videos/ai/{video_filename}"
)
usage_info = track_video_usage(
user_id=user_id,
provider=animation_result["provider"],
model_name=animation_result["model_name"],
prompt=animation_result["prompt"],
video_bytes=animation_result["video_bytes"],
cost_override=animation_result["cost"],
)
if usage_info:
scene_logger.warning(
"[AnimateScene] Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
user_id,
usage_info.get("previous_calls"),
usage_info.get("current_calls"),
usage_info.get("video_limit_display"),
usage_info.get("cost_per_video", 0.0),
usage_info.get("total_video_cost", 0.0),
)
scene_logger.info(
"[AnimateScene] ✅ Completed user=%s scene=%s duration=%s cost=$%.2f video=%s",
user_id,
request.scene_number,
animation_result["duration"],
animation_result["cost"],
video_url,
)
# Save video asset to library
db = next(get_db())
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="story_writer",
filename=video_filename,
file_url=video_url,
file_path=str(ai_video_dir / video_filename),
file_size=len(animation_result["video_bytes"]),
mime_type="video/mp4",
title=f"Scene {request.scene_number} Animation",
description=f"Animated scene {request.scene_number} from story",
prompt=animation_result["prompt"],
tags=["story_writer", "video", "animation", f"scene_{request.scene_number}"],
provider=animation_result["provider"],
model=animation_result.get("model_name"),
cost=animation_result["cost"],
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
finally:
db.close()
return AnimateSceneResponse(
success=True,
scene_number=request.scene_number,
video_filename=video_filename,
video_url=video_url,
duration=animation_result["duration"],
cost=animation_result["cost"],
prompt_used=animation_result["prompt"],
provider=animation_result["provider"],
prediction_id=animation_result.get("prediction_id"),
)
@router.post("/animate-scene-resume", response_model=AnimateSceneResponse)
async def resume_scene_animation_endpoint(
request_obj: Request,
request: ResumeSceneAnimationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimateSceneResponse:
"""Resume downloading a WaveSpeed animation when the initial call timed out."""
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get("id", ""))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
scene_logger.info(
"[AnimateScene] Resume requested user=%s scene=%s prediction=%s",
user_id,
request.scene_number,
request.prediction_id,
)
animation_result = resume_scene_animation(
prediction_id=request.prediction_id,
duration=request.duration or 5,
user_id=user_id,
)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
ai_video_dir.mkdir(parents=True, exist_ok=True)
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
save_result = video_service.save_scene_video(
video_bytes=animation_result["video_bytes"],
scene_number=request.scene_number,
user_id=user_id,
)
video_filename = save_result["video_filename"]
video_url = _build_authenticated_media_url(
request_obj, f"/api/story/videos/ai/{video_filename}"
)
usage_info = track_video_usage(
user_id=user_id,
provider=animation_result["provider"],
model_name=animation_result["model_name"],
prompt=animation_result["prompt"],
video_bytes=animation_result["video_bytes"],
cost_override=animation_result["cost"],
)
if usage_info:
scene_logger.warning(
"[AnimateScene] (Resume) Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
user_id,
usage_info.get("previous_calls"),
usage_info.get("current_calls"),
usage_info.get("video_limit_display"),
usage_info.get("cost_per_video", 0.0),
usage_info.get("total_video_cost", 0.0),
)
scene_logger.info(
"[AnimateScene] ✅ Resume completed user=%s scene=%s prediction=%s video=%s",
user_id,
request.scene_number,
request.prediction_id,
video_url,
)
return AnimateSceneResponse(
success=True,
scene_number=request.scene_number,
video_filename=video_filename,
video_url=video_url,
duration=animation_result["duration"],
cost=animation_result["cost"],
prompt_used=animation_result["prompt"],
provider=animation_result["provider"],
prediction_id=animation_result.get("prediction_id"),
)
@router.post("/animate-scene-voiceover", response_model=Dict[str, Any])
async def animate_scene_voiceover_endpoint(
request_obj: Request,
request: AnimateSceneVoiceoverRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""
Animate a scene using WaveSpeed InfiniteTalk (image + audio) asynchronously.
Returns task_id for polling since InfiniteTalk can take up to 10 minutes.
"""
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get("id", ""))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
scene_logger.info(
"[AnimateSceneVoiceover] User=%s scene=%s resolution=%s (async)",
user_id,
request.scene_number,
request.resolution or "720p",
)
image_bytes = load_story_image_bytes(request.image_url)
if not image_bytes:
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
audio_bytes = load_story_audio_bytes(request.audio_url)
if not audio_bytes:
raise HTTPException(status_code=404, detail="Scene audio not found. Generate audio first.")
db = next(get_db())
try:
pricing_service = PricingService(db)
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
finally:
db.close()
# Extract token for authenticated URL building (if needed)
auth_token = None
auth_header = request_obj.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.replace("Bearer ", "").strip()
# Create async task
task_id = task_manager.create_task("scene_voiceover_animation")
background_tasks.add_task(
_execute_voiceover_animation_task,
task_id=task_id,
request=request,
user_id=user_id,
image_bytes=image_bytes,
audio_bytes=audio_bytes,
auth_token=auth_token,
)
return {
"task_id": task_id,
"status": "pending",
"message": "InfiniteTalk animation started. This may take up to 10 minutes.",
}
def _execute_voiceover_animation_task(
task_id: str,
request: AnimateSceneVoiceoverRequest,
user_id: str,
image_bytes: bytes,
audio_bytes: bytes,
auth_token: Optional[str] = None,
):
"""Background task to generate InfiniteTalk video with progress updates."""
try:
task_manager.update_task_status(
task_id, "processing", progress=5.0, message="Submitting to WaveSpeed InfiniteTalk..."
)
animation_result = animate_scene_with_voiceover(
image_bytes=image_bytes,
audio_bytes=audio_bytes,
scene_data=request.scene_data,
story_context=request.story_context,
user_id=user_id,
resolution=request.resolution or "720p",
prompt_override=request.prompt,
image_mime=_guess_mime_from_url(request.image_url, "image/png"),
audio_mime=_guess_mime_from_url(request.audio_url, "audio/mpeg"),
)
task_manager.update_task_status(
task_id, "processing", progress=80.0, message="Saving video file..."
)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
ai_video_dir.mkdir(parents=True, exist_ok=True)
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
save_result = video_service.save_scene_video(
video_bytes=animation_result["video_bytes"],
scene_number=request.scene_number,
user_id=user_id,
)
video_filename = save_result["video_filename"]
# Build authenticated URL if token provided, otherwise return plain URL
video_url = f"/api/story/videos/ai/{video_filename}"
if auth_token:
video_url = f"{video_url}?token={quote(auth_token)}"
usage_info = track_video_usage(
user_id=user_id,
provider=animation_result["provider"],
model_name=animation_result["model_name"],
prompt=animation_result["prompt"],
video_bytes=animation_result["video_bytes"],
cost_override=animation_result["cost"],
)
if usage_info:
scene_logger.warning(
"[AnimateSceneVoiceover] Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
user_id,
usage_info.get("previous_calls"),
usage_info.get("current_calls"),
usage_info.get("video_limit_display"),
usage_info.get("cost_per_video", 0.0),
usage_info.get("total_video_cost", 0.0),
)
scene_logger.info(
"[AnimateSceneVoiceover] ✅ Completed user=%s scene=%s cost=$%.2f video=%s",
user_id,
request.scene_number,
animation_result["cost"],
video_url,
)
# Save video asset to library
db = next(get_db())
try:
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="video",
source_module="story_writer",
filename=video_filename,
file_url=video_url,
file_path=str(ai_video_dir / video_filename),
file_size=len(animation_result["video_bytes"]),
mime_type="video/mp4",
title=f"Scene {request.scene_number} Animation (Voiceover)",
description=f"Animated scene {request.scene_number} with voiceover from story",
prompt=animation_result["prompt"],
tags=["story_writer", "video", "animation", "voiceover", f"scene_{request.scene_number}"],
provider=animation_result["provider"],
model=animation_result.get("model_name"),
cost=animation_result["cost"],
asset_metadata={"scene_number": request.scene_number, "duration": animation_result["duration"], "status": "completed"}
)
except Exception as e:
logger.warning(f"[StoryWriter] Failed to save video asset to library: {e}")
finally:
db.close()
result = AnimateSceneResponse(
success=True,
scene_number=request.scene_number,
video_filename=video_filename,
video_url=video_url,
duration=animation_result["duration"],
cost=animation_result["cost"],
prompt_used=animation_result["prompt"],
provider=animation_result["provider"],
prediction_id=animation_result.get("prediction_id"),
)
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
message="InfiniteTalk animation complete!",
result=result.dict(),
)
except HTTPException as exc:
error_msg = str(exc.detail) if isinstance(exc.detail, str) else exc.detail.get("error", "Animation failed") if isinstance(exc.detail, dict) else "Animation failed"
scene_logger.error(f"[AnimateSceneVoiceover] Failed: {error_msg}")
task_manager.update_task_status(
task_id,
"failed",
error=error_msg,
message=f"InfiniteTalk animation failed: {error_msg}",
)
except Exception as exc:
error_msg = str(exc)
scene_logger.error(f"[AnimateSceneVoiceover] Error: {error_msg}", exc_info=True)
task_manager.update_task_status(
task_id,
"failed",
error=error_msg,
message=f"InfiniteTalk animation error: {error_msg}",
)

View File

@@ -0,0 +1,297 @@
from datetime import datetime
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from pydantic import BaseModel, Field
from middleware.auth_middleware import get_current_user
from models.story_models import (
StoryStartRequest,
StoryContentResponse,
StoryScene,
StoryContinueRequest,
StoryContinueResponse,
)
from services.story_writer.story_service import StoryWriterService
from ..utils.auth import require_authenticated_user
router = APIRouter()
story_service = StoryWriterService()
scene_approval_store: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]] = {}
APPROVAL_TTL_SECONDS = 60 * 60 * 24
MAX_APPROVALS_PER_USER = 200
def _cleanup_user_approvals(user_id: str) -> None:
user_store = scene_approval_store.get(user_id)
if not user_store:
return
now = datetime.utcnow()
for project_id in list(user_store.keys()):
scenes = user_store.get(project_id, {})
for scene_id in list(scenes.keys()):
timestamp = scenes[scene_id].get("timestamp")
if isinstance(timestamp, datetime):
if (now - timestamp).total_seconds() > APPROVAL_TTL_SECONDS:
scenes.pop(scene_id, None)
if not scenes:
user_store.pop(project_id, None)
if not user_store:
scene_approval_store.pop(user_id, None)
def _enforce_capacity(user_id: str) -> None:
user_store = scene_approval_store.get(user_id)
if not user_store:
return
entries: List[tuple[datetime, str, str]] = []
for project_id, scenes in user_store.items():
for scene_id, meta in scenes.items():
timestamp = meta.get("timestamp")
if isinstance(timestamp, datetime):
entries.append((timestamp, project_id, scene_id))
if len(entries) <= MAX_APPROVALS_PER_USER:
return
entries.sort(key=lambda item: item[0])
to_remove = len(entries) - MAX_APPROVALS_PER_USER
for i in range(to_remove):
_, project_id, scene_id = entries[i]
scenes = user_store.get(project_id)
if not scenes:
continue
scenes.pop(scene_id, None)
if not scenes:
user_store.pop(project_id, None)
def _get_user_store(user_id: str) -> Dict[str, Dict[str, Dict[str, Any]]]:
_cleanup_user_approvals(user_id)
return scene_approval_store.setdefault(user_id, {})
@router.post("/generate-start", response_model=StoryContentResponse)
async def generate_story_start(
request: StoryStartRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContentResponse:
"""Generate the starting section of a story."""
try:
user_id = require_authenticated_user(current_user)
if not request.premise or not request.premise.strip():
raise HTTPException(status_code=400, detail="Premise is required")
if not request.outline or (isinstance(request.outline, str) and not request.outline.strip()):
raise HTTPException(status_code=400, detail="Outline is required")
logger.info(f"[StoryWriter] Generating story start for user {user_id}")
outline_data: Any = request.outline
if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
outline_data = [scene.dict() for scene in outline_data]
story_length = getattr(request, "story_length", "Medium")
story_start = story_service.generate_story_start(
premise=request.premise,
outline=outline_data,
persona=request.persona,
story_setting=request.story_setting,
character_input=request.character_input,
plot_elements=request.plot_elements,
writing_style=request.writing_style,
story_tone=request.story_tone,
narrative_pov=request.narrative_pov,
audience_age_group=request.audience_age_group,
content_rating=request.content_rating,
ending_preference=request.ending_preference,
story_length=story_length,
user_id=user_id,
)
story_length_lower = story_length.lower()
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
is_complete = False
if is_short_story:
word_count = len(story_start.split()) if story_start else 0
if word_count >= 900:
is_complete = True
logger.info(
f"[StoryWriter] Short story generated with {word_count} words. Marking as complete."
)
else:
logger.warning(
f"[StoryWriter] Short story generated with only {word_count} words. May need continuation."
)
outline_response = outline_data
if isinstance(outline_response, list):
outline_response = "\n".join(
[
f"Scene {scene.get('scene_number', i + 1)}: "
f"{scene.get('title', 'Untitled')}\n {scene.get('description', '')}"
for i, scene in enumerate(outline_response)
]
)
return StoryContentResponse(
story=story_start,
premise=request.premise,
outline=str(outline_response),
is_complete=is_complete,
success=True,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate story start: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/continue", response_model=StoryContinueResponse)
async def continue_story(
request: StoryContinueRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContinueResponse:
"""Continue writing a story."""
try:
user_id = require_authenticated_user(current_user)
if not request.story_text or not request.story_text.strip():
raise HTTPException(status_code=400, detail="Story text is required")
logger.info(f"[StoryWriter] Continuing story for user {user_id}")
outline_data: Any = request.outline
if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
outline_data = [scene.dict() for scene in outline_data]
story_length = getattr(request, "story_length", "Medium")
story_length_lower = story_length.lower()
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
if is_short_story:
logger.warning(
"[StoryWriter] Attempted to continue a short story. Short stories should be complete in one call."
)
raise HTTPException(
status_code=400,
detail="Short stories are generated in a single call and should be complete. "
"If the story is incomplete, please regenerate it from the beginning.",
)
current_word_count = len(request.story_text.split()) if request.story_text else 0
if "long" in story_length_lower or "10000" in story_length_lower:
target_total_words = 10000
else:
target_total_words = 4500
buffer_target = int(target_total_words * 1.05)
if current_word_count >= buffer_target or (
current_word_count >= target_total_words
and (current_word_count - target_total_words) < 50
):
logger.info(
f"[StoryWriter] Word count ({current_word_count}) already at or near target ({target_total_words})."
)
return StoryContinueResponse(continuation="IAMDONE", is_complete=True, success=True)
continuation = story_service.continue_story(
premise=request.premise,
outline=outline_data,
story_text=request.story_text,
persona=request.persona,
story_setting=request.story_setting,
character_input=request.character_input,
plot_elements=request.plot_elements,
writing_style=request.writing_style,
story_tone=request.story_tone,
narrative_pov=request.narrative_pov,
audience_age_group=request.audience_age_group,
content_rating=request.content_rating,
ending_preference=request.ending_preference,
story_length=story_length,
user_id=user_id,
)
is_complete = "IAMDONE" in continuation.upper()
if not is_complete and continuation:
new_story_text = request.story_text + "\n\n" + continuation
new_word_count = len(new_story_text.split())
if new_word_count >= buffer_target:
logger.info(
f"[StoryWriter] Word count ({new_word_count}) now exceeds buffer target ({buffer_target})."
)
if "IAMDONE" not in continuation.upper():
continuation = continuation.rstrip() + "\n\nIAMDONE"
is_complete = True
elif new_word_count >= target_total_words and (
new_word_count - target_total_words
) < 100:
logger.info(
f"[StoryWriter] Word count ({new_word_count}) is at or very close to target ({target_total_words})."
)
if "IAMDONE" not in continuation.upper():
continuation = continuation.rstrip() + "\n\nIAMDONE"
is_complete = True
return StoryContinueResponse(continuation=continuation, is_complete=is_complete, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to continue story: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
class SceneApprovalRequest(BaseModel):
project_id: str = Field(..., min_length=1)
scene_id: str = Field(..., min_length=1)
approved: bool = True
notes: Optional[str] = None
@router.post("/script/approve")
async def approve_script_scene(
request: SceneApprovalRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""Persist scene approval metadata for auditing."""
try:
user_id = require_authenticated_user(current_user)
if not request.project_id.strip() or not request.scene_id.strip():
raise HTTPException(status_code=400, detail="project_id and scene_id are required")
notes = request.notes.strip() if request.notes else None
user_store = _get_user_store(user_id)
project_store = user_store.setdefault(request.project_id, {})
timestamp = datetime.utcnow()
project_store[request.scene_id] = {
"approved": request.approved,
"notes": notes,
"user_id": user_id,
"timestamp": timestamp,
}
_enforce_capacity(user_id)
logger.info(
"[StoryWriter] Scene approval recorded user=%s project=%s scene=%s approved=%s",
user_id,
request.project_id,
request.scene_id,
request.approved,
)
return {
"success": True,
"project_id": request.project_id,
"scene_id": request.scene_id,
"approved": request.approved,
"timestamp": timestamp.isoformat(),
}
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to approve scene: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,141 @@
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from middleware.auth_middleware import get_current_user
from models.story_models import (
StorySetupGenerationRequest,
StorySetupGenerationResponse,
StorySetupOption,
StoryGenerationRequest,
StoryOutlineResponse,
StoryScene,
StoryStartRequest,
StoryPremiseResponse,
)
from services.story_writer.story_service import StoryWriterService
from ..utils.auth import require_authenticated_user
router = APIRouter()
story_service = StoryWriterService()
@router.post("/generate-setup", response_model=StorySetupGenerationResponse)
async def generate_story_setup(
request: StorySetupGenerationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StorySetupGenerationResponse:
"""Generate 3 story setup options from a user's story idea."""
try:
user_id = require_authenticated_user(current_user)
if not request.story_idea or not request.story_idea.strip():
raise HTTPException(status_code=400, detail="Story idea is required")
logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")
options = story_service.generate_story_setup_options(
story_idea=request.story_idea,
user_id=user_id,
)
setup_options = [StorySetupOption(**option) for option in options]
return StorySetupGenerationResponse(options=setup_options, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate story setup options: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-premise", response_model=StoryPremiseResponse)
async def generate_premise(
request: StoryGenerationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryPremiseResponse:
"""Generate a story premise."""
try:
user_id = require_authenticated_user(current_user)
logger.info(f"[StoryWriter] Generating premise for user {user_id}")
premise = story_service.generate_premise(
persona=request.persona,
story_setting=request.story_setting,
character_input=request.character_input,
plot_elements=request.plot_elements,
writing_style=request.writing_style,
story_tone=request.story_tone,
narrative_pov=request.narrative_pov,
audience_age_group=request.audience_age_group,
content_rating=request.content_rating,
ending_preference=request.ending_preference,
user_id=user_id,
)
return StoryPremiseResponse(premise=premise, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate premise: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-outline", response_model=StoryOutlineResponse)
async def generate_outline(
request: StoryStartRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
use_structured: bool = True,
) -> StoryOutlineResponse:
"""Generate a story outline from a premise."""
try:
user_id = require_authenticated_user(current_user)
if not request.premise or not request.premise.strip():
raise HTTPException(status_code=400, detail="Premise is required")
logger.info(
f"[StoryWriter] Generating outline for user {user_id} (structured={use_structured})"
)
logger.info(
"[StoryWriter] Outline params: audience_age_group=%s, writing_style=%s, story_tone=%s",
request.audience_age_group,
request.writing_style,
request.story_tone,
)
outline = story_service.generate_outline(
premise=request.premise,
persona=request.persona,
story_setting=request.story_setting,
character_input=request.character_input,
plot_elements=request.plot_elements,
writing_style=request.writing_style,
story_tone=request.story_tone,
narrative_pov=request.narrative_pov,
audience_age_group=request.audience_age_group,
content_rating=request.content_rating,
ending_preference=request.ending_preference,
user_id=user_id,
use_structured_output=use_structured,
)
if isinstance(outline, list):
scenes: List[StoryScene] = [
StoryScene(**scene) if isinstance(scene, dict) else scene for scene in outline
]
return StoryOutlineResponse(outline=scenes, success=True, is_structured=True)
return StoryOutlineResponse(outline=str(outline), success=True, is_structured=False)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate outline: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,130 @@
from typing import Any, Dict
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from loguru import logger
from middleware.auth_middleware import get_current_user
from models.story_models import (
StoryGenerationRequest,
TaskStatus,
)
from services.story_writer.story_service import StoryWriterService
from ..cache_manager import cache_manager
from ..task_manager import task_manager
from ..utils.auth import require_authenticated_user
router = APIRouter()
story_service = StoryWriterService()
@router.post("/generate-full", response_model=Dict[str, Any])
async def generate_full_story(
request: StoryGenerationRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
max_iterations: int = 10,
) -> Dict[str, Any]:
"""Generate a complete story asynchronously."""
try:
user_id = require_authenticated_user(current_user)
cache_key = cache_manager.get_cache_key(request.dict())
cached_result = cache_manager.get_cached_result(cache_key)
if cached_result:
logger.info(f"[StoryWriter] Returning cached result for user {user_id}")
task_id = task_manager.create_task("story_generation")
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
result=cached_result,
message="Returned cached result",
)
return {"task_id": task_id, "cached": True}
task_id = task_manager.create_task("story_generation")
request_data = request.dict()
request_data["max_iterations"] = max_iterations
background_tasks.add_task(
task_manager.execute_story_generation_task,
task_id=task_id,
request_data=request_data,
user_id=user_id,
)
logger.info(f"[StoryWriter] Created task {task_id} for full story generation (user {user_id})")
return {
"task_id": task_id,
"status": "pending",
"message": "Story generation started. Use /task/{task_id}/status to check progress.",
}
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to start story generation: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.get("/task/{task_id}/status", response_model=TaskStatus)
async def get_task_status(
task_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> TaskStatus:
"""Get the status of a story generation task."""
try:
require_authenticated_user(current_user)
task_status = task_manager.get_task_status(task_id)
if not task_status:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
return TaskStatus(**task_status)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to get task status: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.get("/task/{task_id}/result")
async def get_task_result(
task_id: str,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""Get the result of a completed story generation task."""
try:
require_authenticated_user(current_user)
task_status = task_manager.get_task_status(task_id)
if not task_status:
raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
if task_status["status"] != "completed":
raise HTTPException(
status_code=400,
detail=f"Task {task_id} is not completed. Status: {task_status['status']}",
)
result = task_status.get("result")
if not result:
raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")
if isinstance(result, dict):
payload = {**result}
payload.setdefault("success", True)
payload["task_id"] = task_id
return payload
return {"result": result, "success": True, "task_id": task_id}
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to get task result: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,539 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from models.story_models import (
StoryVideoGenerationRequest,
StoryVideoGenerationResponse,
StoryVideoResult,
StoryScene,
StoryGenerationRequest,
)
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from services.story_writer.story_service import StoryWriterService
from ..task_manager import task_manager
from ..utils.auth import require_authenticated_user
from ..utils.hd_video import (
generate_hd_video_payload,
generate_hd_video_scene_payload,
)
from ..utils.media_utils import resolve_media_file
router = APIRouter()
video_service = StoryVideoGenerationService()
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
story_service = StoryWriterService()
class HDVideoRequest(BaseModel):
prompt: str
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
class HDVideoSceneRequest(BaseModel):
scene_number: int
scene_data: Dict[str, Any]
story_context: Dict[str, Any]
all_scenes: List[Dict[str, Any]]
provider: str = "huggingface"
model: str | None = None
num_frames: int | None = None
guidance_scale: float | None = None
num_inference_steps: int | None = None
negative_prompt: str | None = None
seed: int | None = None
@router.post("/generate-video", response_model=StoryVideoGenerationResponse)
async def generate_story_video(
request: StoryVideoGenerationRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryVideoGenerationResponse:
"""Generate a video from story scenes, images, and audio."""
try:
user_id = require_authenticated_user(current_user)
if not request.scenes or len(request.scenes) == 0:
raise HTTPException(status_code=400, detail="At least one scene is required")
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
raise HTTPException(
status_code=400,
detail="Number of scenes, image URLs, and audio URLs must match",
)
logger.info(f"[StoryWriter] Generating video for {len(request.scenes)} scenes for user {user_id}")
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
video_paths: List[Optional[str]] = [] # Animated videos (preferred)
image_paths: List[Optional[str]] = [] # Static images (fallback)
audio_paths: List[str] = []
valid_scenes: List[Dict[str, Any]] = []
# Resolve video/audio directories
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
video_urls = request.video_urls or [None] * len(request.scenes)
ai_audio_urls = request.ai_audio_urls or [None] * len(request.scenes)
for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
# Prefer animated video if available
video_url = video_urls[idx] if idx < len(video_urls) else None
video_path = None
image_path = None
if video_url:
# Extract filename from animated video URL (e.g., /api/story/videos/ai/filename.mp4)
video_filename = video_url.split("/")[-1].split("?")[0]
video_path = ai_video_dir / video_filename
if video_path.exists():
logger.info(f"[StoryWriter] Using animated video for scene {scene.get('scene_number', idx+1)}: {video_filename}")
video_paths.append(str(video_path))
image_paths.append(None)
else:
logger.warning(f"[StoryWriter] Animated video not found: {video_path}, falling back to image")
video_paths.append(None)
video_path = None
# Fall back to image if no animated video
if not video_path:
image_filename = image_url.split("/")[-1].split("?")[0]
image_path = image_service.output_dir / image_filename
if image_path.exists():
video_paths.append(None)
image_paths.append(str(image_path))
else:
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
continue
# Prefer AI audio if available, otherwise use free audio
ai_audio_url = ai_audio_urls[idx] if idx < len(ai_audio_urls) else None
audio_filename = None
audio_path = None
if ai_audio_url:
audio_filename = ai_audio_url.split("/")[-1].split("?")[0]
audio_path = audio_service.output_dir / audio_filename
if audio_path.exists():
logger.info(f"[StoryWriter] Using AI audio for scene {scene.get('scene_number', idx+1)}: {audio_filename}")
else:
logger.warning(f"[StoryWriter] AI audio not found: {audio_path}, falling back to free audio")
audio_path = None
# Fall back to free audio if no AI audio
if not audio_path:
audio_filename = audio_url.split("/")[-1].split("?")[0]
audio_path = audio_service.output_dir / audio_filename
if not audio_path.exists():
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
continue
audio_paths.append(str(audio_path))
valid_scenes.append(scene)
if len(valid_scenes) == 0 or len(audio_paths) == 0:
raise HTTPException(status_code=400, detail="No valid video/image or audio files were found")
if len(valid_scenes) != len(audio_paths):
raise HTTPException(
status_code=400,
detail="Number of valid scenes and audio files must match",
)
video_result = video_service.generate_story_video(
scenes=valid_scenes,
image_paths=image_paths, # Can contain None for scenes with animated videos
video_paths=video_paths, # Can contain None for scenes with static images
audio_paths=audio_paths,
user_id=user_id,
story_title=request.story_title or "Story",
fps=request.fps or 24,
transition_duration=request.transition_duration or 0.5,
)
video_model = StoryVideoResult(
video_filename=video_result.get("video_filename", ""),
video_url=video_result.get("video_url", ""),
duration=video_result.get("duration", 0.0),
fps=video_result.get("fps", 24),
file_size=video_result.get("file_size", 0),
num_scenes=video_result.get("num_scenes", 0),
error=video_result.get("error"),
)
return StoryVideoGenerationResponse(video=video_model, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate video: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
request: StoryVideoGenerationRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""
Generate a video asynchronously with progress updates via task manager.
Frontend can poll /api/story/task/{task_id}/status to show progress messages.
"""
try:
user_id = require_authenticated_user(current_user)
if not request.scenes or len(request.scenes) == 0:
raise HTTPException(status_code=400, detail="At least one scene is required")
if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
raise HTTPException(
status_code=400,
detail="Number of scenes, image URLs, and audio URLs must match",
)
task_id = task_manager.create_task("story_video_generation")
background_tasks.add_task(
_execute_video_generation_task,
task_id=task_id,
request=request,
user_id=user_id,
)
return {"task_id": task_id, "status": "pending", "message": "Video generation started"}
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to start async video generation: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
"""Background task to generate story video with progress mapped to task manager."""
try:
task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")
scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]
image_paths: List[str] = []
audio_paths: List[str] = []
valid_scenes: List[Dict[str, Any]] = []
for scene, image_url, audio_url in zip(scenes_data, request.image_urls, request.audio_urls):
image_filename = image_url.split("/")[-1].split("?")[0]
audio_filename = audio_url.split("/")[-1].split("?")[0]
image_path = image_service.output_dir / image_filename
audio_path = audio_service.output_dir / audio_filename
if not image_path.exists():
logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
continue
if not audio_path.exists():
logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
continue
image_paths.append(str(image_path))
audio_paths.append(str(audio_path))
valid_scenes.append(scene)
if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
raise RuntimeError("Video generation aborted: image and audio assets are missing or their counts are mismatched.")
def progress_callback(sub_progress: float, msg: str):
# Map the service's 0-100 sub-progress into the task's 5-95% band
overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)
result = video_service.generate_story_video(
scenes=valid_scenes,
image_paths=image_paths,
audio_paths=audio_paths,
user_id=user_id,
story_title=request.story_title or "Story",
fps=request.fps or 24,
transition_duration=request.transition_duration or 0.5,
progress_callback=progress_callback,
)
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
message="Video generation complete!",
result={"video": result, "success": True},
)
except Exception as exc:
logger.error(f"[StoryWriter] Async video generation failed: {exc}", exc_info=True)
task_manager.update_task_status(task_id, "failed", error=str(exc), message=f"Video generation failed: {exc}")
@router.post("/hd-video")
async def generate_hd_video(
request: HDVideoRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_payload(request, user_id)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate HD video: {exc}", exc_info=True)
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
request: HDVideoSceneRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
try:
user_id = require_authenticated_user(current_user)
return generate_hd_video_scene_payload(request, user_id)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate HD video for scene: {exc}", exc_info=True)
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-complete-video", response_model=Dict[str, Any])
async def generate_complete_story_video(
request: StoryGenerationRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""Generate a complete story video workflow asynchronously."""
try:
user_id = require_authenticated_user(current_user)
logger.info(f"[StoryWriter] Starting complete video generation for user {user_id}")
task_id = task_manager.create_task("complete_video_generation")
background_tasks.add_task(
execute_complete_video_generation,
task_id=task_id,
request_data=request.dict(),
user_id=user_id,
)
return {
"task_id": task_id,
"status": "pending",
"message": "Complete video generation started",
}
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to start complete video generation: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
def execute_complete_video_generation(
task_id: str,
request_data: Dict[str, Any],
user_id: str,
):
"""
Execute complete video generation workflow synchronously.
Runs in a background task and performs blocking operations.
"""
try:
task_manager.update_task_status(task_id, "processing", progress=5.0, message="Starting complete video generation...")
task_manager.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
premise = story_service.generate_premise(
persona=request_data["persona"],
story_setting=request_data["story_setting"],
character_input=request_data["character_input"],
plot_elements=request_data["plot_elements"],
writing_style=request_data["writing_style"],
story_tone=request_data["story_tone"],
narrative_pov=request_data["narrative_pov"],
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id,
)
task_manager.update_task_status(task_id, "processing", progress=20.0, message="Generating structured outline with scenes...")
outline_scenes = story_service.generate_outline(
premise=premise,
persona=request_data["persona"],
story_setting=request_data["story_setting"],
character_input=request_data["character_input"],
plot_elements=request_data["plot_elements"],
writing_style=request_data["writing_style"],
story_tone=request_data["story_tone"],
narrative_pov=request_data["narrative_pov"],
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id,
use_structured_output=True,
)
if not isinstance(outline_scenes, list):
raise RuntimeError("Failed to generate structured outline")
task_manager.update_task_status(task_id, "processing", progress=30.0, message="Generating images for scenes...")
def image_progress_callback(sub_progress: float, message: str):
overall_progress = 30.0 + (sub_progress * 0.2)
task_manager.update_task_status(task_id, "processing", progress=overall_progress, message=message)
image_results = image_service.generate_scene_images(
scenes=outline_scenes,
user_id=user_id,
provider=request_data.get("image_provider"),
width=request_data.get("image_width", 1024),
height=request_data.get("image_height", 1024),
model=request_data.get("image_model"),
progress_callback=image_progress_callback,
)
task_manager.update_task_status(task_id, "processing", progress=50.0, message="Generating audio narration for scenes...")
def audio_progress_callback(sub_progress: float, message: str):
overall_progress = 50.0 + (sub_progress * 0.2)
task_manager.update_task_status(task_id, "processing", progress=overall_progress, message=message)
audio_results = audio_service.generate_scene_audio_list(
scenes=outline_scenes,
user_id=user_id,
provider=request_data.get("audio_provider", "gtts"),
lang=request_data.get("audio_lang", "en"),
slow=request_data.get("audio_slow", False),
rate=request_data.get("audio_rate", 150),
progress_callback=audio_progress_callback,
)
task_manager.update_task_status(task_id, "processing", progress=70.0, message="Preparing video assets...")
image_paths: List[str] = []
audio_paths: List[str] = []
valid_scenes: List[Dict[str, Any]] = []
for scene in outline_scenes:
scene_number = scene.get("scene_number", 0)
image_result = next((img for img in image_results if img.get("scene_number") == scene_number), None)
audio_result = next((aud for aud in audio_results if aud.get("scene_number") == scene_number), None)
if image_result and audio_result and not image_result.get("error") and not audio_result.get("error"):
image_path = image_result.get("image_path")
audio_path = audio_result.get("audio_path")
if image_path and audio_path:
image_paths.append(image_path)
audio_paths.append(audio_path)
valid_scenes.append(scene)
if len(image_paths) == 0 or len(audio_paths) == 0:
raise RuntimeError(
f"No valid images or audio files were generated. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
)
if len(image_paths) != len(audio_paths):
raise RuntimeError(
f"Mismatch between image and audio counts. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
)
task_manager.update_task_status(task_id, "processing", progress=75.0, message="Composing video from scenes...")
def video_progress_callback(sub_progress: float, message: str):
overall_progress = 75.0 + (sub_progress * 0.2)
task_manager.update_task_status(task_id, "processing", progress=overall_progress, message=message)
video_result = video_service.generate_story_video(
scenes=valid_scenes,
image_paths=image_paths,
audio_paths=audio_paths,
user_id=user_id,
story_title=request_data.get("story_setting", "Story")[:50],
fps=request_data.get("video_fps", 24),
transition_duration=request_data.get("video_transition_duration", 0.5),
progress_callback=video_progress_callback,
)
result = {
"premise": premise,
"outline_scenes": outline_scenes,
"images": image_results,
"audio_files": audio_results,
"video": video_result,
"success": True,
}
task_manager.update_task_status(
task_id,
"completed",
progress=100.0,
message="Complete video generation finished!",
result=result,
)
logger.info(f"[StoryWriter] Complete video generation task {task_id} completed successfully")
except Exception as exc:
error_msg = str(exc)
logger.error(f"[StoryWriter] Complete video generation task {task_id} failed: {error_msg}", exc_info=True)
task_manager.update_task_status(
task_id,
"failed",
error=error_msg,
message=f"Complete video generation failed: {error_msg}",
)
@router.get("/videos/{video_filename}")
async def serve_story_video(
video_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Serve a generated story video file."""
try:
require_authenticated_user(current_user)
video_path = resolve_media_file(video_service.output_dir, video_filename)
return FileResponse(path=str(video_path), media_type="video/mp4", filename=video_filename)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to serve video: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.get("/videos/ai/{video_filename}")
async def serve_ai_story_video(
video_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Serve a generated AI scene animation video."""
try:
require_authenticated_user(current_user)
base_dir = Path(__file__).parent.parent.parent.parent
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
video_service_ai = StoryVideoGenerationService(output_dir=str(ai_video_dir))
video_path = resolve_media_file(video_service_ai.output_dir, video_filename)
return FileResponse(
path=str(video_path),
media_type="video/mp4",
filename=video_filename
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to serve AI video: {exc}")
raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,253 @@
"""
Task Management System for Story Writer API
Handles background task execution, status tracking, and progress updates
for story generation operations.
"""
import asyncio
import uuid
from datetime import datetime
from typing import Any, Dict, Optional
from loguru import logger
class TaskManager:
"""Manages background tasks for story generation."""
def __init__(self):
"""Initialize the task manager."""
self.task_storage: Dict[str, Dict[str, Any]] = {}
logger.info("[StoryWriter] TaskManager initialized")
def cleanup_old_tasks(self):
"""Remove tasks older than 1 hour to prevent memory leaks."""
current_time = datetime.now()
tasks_to_remove = []
for task_id, task_data in self.task_storage.items():
created_at = task_data.get("created_at")
if created_at and (current_time - created_at).total_seconds() > 3600: # 1 hour
tasks_to_remove.append(task_id)
for task_id in tasks_to_remove:
del self.task_storage[task_id]
logger.debug(f"[StoryWriter] Cleaned up old task: {task_id}")
def create_task(self, task_type: str = "story_generation") -> str:
"""Create a new task and return its ID."""
task_id = str(uuid.uuid4())
self.task_storage[task_id] = {
"status": "pending",
"created_at": datetime.now(),
"result": None,
"error": None,
"progress_messages": [],
"task_type": task_type,
"progress": 0.0
}
logger.info(f"[StoryWriter] Created task: {task_id} (type: {task_type})")
return task_id
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
"""Get the status of a task."""
self.cleanup_old_tasks()
if task_id not in self.task_storage:
# Log at DEBUG level - task not found is expected when tasks expire or are cleaned up
# This prevents log spam from frontend polling for expired/completed tasks
logger.debug(f"[StoryWriter] Task not found: {task_id} (may have expired or been cleaned up)")
return None
task = self.task_storage[task_id]
response = {
"task_id": task_id,
"status": task["status"],
"progress": task.get("progress", 0.0),
"message": task.get("progress_messages", [])[-1] if task.get("progress_messages") else None,
"created_at": task["created_at"].isoformat() if task.get("created_at") else None,
"updated_at": task.get("updated_at", task.get("created_at")).isoformat() if task.get("updated_at") or task.get("created_at") else None,
}
if task["status"] == "completed" and task.get("result"):
response["result"] = task["result"]
if task["status"] == "failed" and task.get("error"):
response["error"] = task["error"]
return response
def update_task_status(
self,
task_id: str,
status: str,
progress: Optional[float] = None,
message: Optional[str] = None,
result: Optional[Dict[str, Any]] = None,
error: Optional[str] = None
):
"""Update the status of a task."""
if task_id not in self.task_storage:
logger.warning(f"[StoryWriter] Cannot update non-existent task: {task_id}")
return
task = self.task_storage[task_id]
task["status"] = status
task["updated_at"] = datetime.now()
if progress is not None:
task["progress"] = progress
if message:
if "progress_messages" not in task:
task["progress_messages"] = []
task["progress_messages"].append(message)
logger.info(f"[StoryWriter] Task {task_id}: {message} (progress: {progress}%)")
if result is not None:
task["result"] = result
if error is not None:
task["error"] = error
logger.error(f"[StoryWriter] Task {task_id} error: {error}")
async def execute_story_generation_task(
self,
task_id: str,
request_data: Dict[str, Any],
user_id: str
):
"""Execute story generation task asynchronously."""
from services.story_writer.story_service import StoryWriterService
service = StoryWriterService()
try:
self.update_task_status(task_id, "processing", progress=0.0, message="Starting story generation...")
# Step 1: Generate premise
self.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
premise = service.generate_premise(
persona=request_data["persona"],
story_setting=request_data["story_setting"],
character_input=request_data["character_input"],
plot_elements=request_data["plot_elements"],
writing_style=request_data["writing_style"],
story_tone=request_data["story_tone"],
narrative_pov=request_data["narrative_pov"],
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
)
# Step 2: Generate outline
self.update_task_status(task_id, "processing", progress=30.0, message="Generating story outline...")
outline = service.generate_outline(
premise=premise,
persona=request_data["persona"],
story_setting=request_data["story_setting"],
character_input=request_data["character_input"],
plot_elements=request_data["plot_elements"],
writing_style=request_data["writing_style"],
story_tone=request_data["story_tone"],
narrative_pov=request_data["narrative_pov"],
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
)
# Step 3: Generate story start
self.update_task_status(task_id, "processing", progress=50.0, message="Writing story beginning...")
story_start = service.generate_story_start(
premise=premise,
outline=outline,
persona=request_data["persona"],
story_setting=request_data["story_setting"],
character_input=request_data["character_input"],
plot_elements=request_data["plot_elements"],
writing_style=request_data["writing_style"],
story_tone=request_data["story_tone"],
narrative_pov=request_data["narrative_pov"],
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
)
# Step 4: Continue story
self.update_task_status(task_id, "processing", progress=70.0, message="Continuing story generation...")
story_text = story_start
max_iterations = request_data.get("max_iterations", 10)
iteration = 0
while 'IAMDONE' not in story_text and iteration < max_iterations:
iteration += 1
progress = 70.0 + (iteration / max_iterations) * 25.0
self.update_task_status(
task_id,
"processing",
progress=min(progress, 95.0),
message=f"Writing continuation {iteration}/{max_iterations}..."
)
continuation = service.continue_story(
premise=premise,
outline=outline,
story_text=story_text,
persona=request_data["persona"],
story_setting=request_data["story_setting"],
character_input=request_data["character_input"],
plot_elements=request_data["plot_elements"],
writing_style=request_data["writing_style"],
story_tone=request_data["story_tone"],
narrative_pov=request_data["narrative_pov"],
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
)
if continuation:
story_text += '\n\n' + continuation
else:
logger.warning(f"[StoryWriter] Empty continuation at iteration {iteration}")
break
# Clean up and finalize
final_story = story_text.replace('IAMDONE', '').strip()
result = {
"premise": premise,
"outline": outline,
"story": final_story,
"is_complete": 'IAMDONE' in story_text or iteration >= max_iterations,
"iterations": iteration
}
self.update_task_status(
task_id,
"completed",
progress=100.0,
message="Story generation completed!",
result=result
)
logger.info(f"[StoryWriter] Task {task_id} completed successfully")
except Exception as e:
error_msg = str(e)
logger.error(f"[StoryWriter] Task {task_id} failed: {error_msg}")
self.update_task_status(
task_id,
"failed",
error=error_msg,
message=f"Story generation failed: {error_msg}"
)
# Global task manager instance
task_manager = TaskManager()
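# Hedged usage sketch of the shared task_manager (comments only; the real call sites live in
# the route modules). Method names and response fields match the class defined above.
#
#   task_id = task_manager.create_task("story_video_generation")
#   task_manager.update_task_status(task_id, "processing", progress=10.0, message="Working...")
#   info = task_manager.get_task_status(task_id)   # {"task_id": ..., "status": "processing", "progress": 10.0, ...}
#   task_manager.update_task_status(task_id, "completed", progress=100.0, result={"ok": True})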

View File

@@ -0,0 +1,8 @@
"""
Utility helpers for Story Writer API routes.
Grouped here to keep the main router lean while reusing common logic
such as authentication guards, media resolution, and HD video helpers.
"""

View File

@@ -0,0 +1,23 @@
from typing import Any, Dict
from fastapi import HTTPException, status
def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
"""
Validates the current user dictionary provided by Clerk middleware and
returns the normalized user_id. Raises HTTP 401 if authentication fails.
"""
if not current_user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
user_id = str(current_user.get("id", "")).strip()
if not user_id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid user ID in authentication token",
)
return user_id
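# Typical usage inside a route (sketch only, based on the story endpoints above; get_current_user
# is the Clerk dependency those routes already inject, and "/example" is a hypothetical path):
#
#   @router.get("/example")
#   async def example(current_user: Dict[str, Any] = Depends(get_current_user)):
#       user_id = require_authenticated_user(current_user)
#       ...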

View File

@@ -0,0 +1,146 @@
from __future__ import annotations
from typing import Any, Dict
from fastapi import HTTPException
from loguru import logger
from uuid import uuid4
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""Handles synchronous HD video generation."""
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
video_service = StoryVideoGenerationService()
output_dir = video_service.output_dir
output_dir.mkdir(parents=True, exist_ok=True)
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
logger.info(f"[StoryWriter] Generating HD video via {getattr(request, 'provider', 'huggingface')} for user {user_id}")
result = ai_video_generate(
prompt=request.prompt,
operation_type="text-to-video",
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
**kwargs,
)
# Extract video bytes from result dict
video_bytes = result["video_bytes"]
filename = f"hd_{uuid4().hex}.mp4"
file_path = output_dir / filename
with open(file_path, "wb") as fh:
fh.write(video_bytes)
logger.info(f"[StoryWriter] HD video saved to {file_path}")
return {
"success": True,
"video_filename": filename,
"video_url": f"/api/story/videos/{filename}",
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""
Handles per-scene HD video generation including prompt enhancement
and subscription validation.
"""
from services.database import get_db as get_db_validation
from services.onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
from services.story_writer.prompt_enhancer_service import enhance_scene_prompt_for_video
from services.llm_providers.main_video_generation import ai_video_generate
from services.story_writer.video_generation_service import StoryVideoGenerationService
scene_number = request.scene_number
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
hf_token = APIKeyManager().get_api_key("hf_token")
if not hf_token:
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
raise HTTPException(
status_code=400,
detail={
"error": "Hugging Face API token is not configured. Please configure your HF token in settings.",
"message": "Hugging Face API token is not configured. Please configure your HF token in settings.",
},
)
db_validation = next(get_db_validation())
try:
pricing_service = PricingService(db_validation)
logger.info(f"[StoryWriter] Pre-flight: Checking video generation limits for user {user_id}...")
validate_video_generation_operations(pricing_service=pricing_service, user_id=user_id)
logger.info("[StoryWriter] Pre-flight: ✅ Video generation limits validated - proceeding")
finally:
db_validation.close()
enhanced_prompt = enhance_scene_prompt_for_video(
current_scene=request.scene_data,
story_context=request.story_context,
all_scenes=request.all_scenes,
user_id=user_id,
)
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
if getattr(request, "num_frames", None):
kwargs["num_frames"] = request.num_frames
if getattr(request, "guidance_scale", None) is not None:
kwargs["guidance_scale"] = request.guidance_scale
if getattr(request, "num_inference_steps", None):
kwargs["num_inference_steps"] = request.num_inference_steps
if getattr(request, "negative_prompt", None):
kwargs["negative_prompt"] = request.negative_prompt
if getattr(request, "seed", None) is not None:
kwargs["seed"] = request.seed
result = ai_video_generate(
prompt=enhanced_prompt,
operation_type="text-to-video",
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
**kwargs,
)
# Extract video bytes from result dict
video_bytes = result["video_bytes"]
video_service = StoryVideoGenerationService()
save_result = video_service.save_scene_video(
video_bytes=video_bytes,
scene_number=scene_number,
user_id=user_id,
)
logger.info(f"[StoryWriter] HD video saved for scene {scene_number}: {save_result.get('video_filename')}")
return {
"success": True,
"scene_number": scene_number,
"video_filename": save_result.get("video_filename"),
"video_url": save_result.get("video_url"),
"prompt_used": enhanced_prompt,
"provider": getattr(request, "provider", None) or "huggingface",
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}

View File

@@ -0,0 +1,148 @@
from __future__ import annotations
from pathlib import Path
from typing import Optional
from urllib.parse import urlparse
from fastapi import HTTPException, status
from loguru import logger
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
STORY_AUDIO_DIR = (BASE_DIR / "story_audio").resolve()
STORY_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
"""
Resolve an authenticated story image URL (e.g., /api/story/images/<file>) to raw bytes.
Returns None if the file cannot be located.
"""
if not image_url:
return None
try:
parsed = urlparse(image_url)
path = parsed.path if parsed.scheme else image_url
prefix = "/api/story/images/"
if prefix not in path:
logger.warning(f"[StoryWriter] Unsupported image URL for video reference: {image_url}")
return None
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
if not filename:
return None
file_path = (STORY_IMAGES_DIR / filename).resolve()
if not str(file_path).startswith(str(STORY_IMAGES_DIR)):
logger.error(f"[StoryWriter] Attempted path traversal when resolving image: {image_url}")
return None
if not file_path.exists():
logger.warning(f"[StoryWriter] Referenced scene image not found on disk: {file_path}")
return None
return file_path.read_bytes()
except Exception as exc:
logger.error(f"[StoryWriter] Failed to load reference image for video gen: {exc}")
return None
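# Example (hypothetical filename): a URL such as "/api/story/images/scene_1_ab12cd.png" resolves
# to STORY_IMAGES_DIR / "scene_1_ab12cd.png"; the bytes are returned only if that file exists
# under the directory, otherwise None. load_story_audio_bytes below is symmetric for audio files.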
def load_story_audio_bytes(audio_url: str) -> Optional[bytes]:
"""
Resolve an authenticated story audio URL (e.g., /api/story/audio/<file>) to raw bytes.
Returns None if the file cannot be located.
"""
if not audio_url:
return None
try:
parsed = urlparse(audio_url)
path = parsed.path if parsed.scheme else audio_url
prefix = "/api/story/audio/"
if prefix not in path:
logger.warning(f"[StoryWriter] Unsupported audio URL for video reference: {audio_url}")
return None
filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
if not filename:
return None
file_path = (STORY_AUDIO_DIR / filename).resolve()
if not str(file_path).startswith(str(STORY_AUDIO_DIR)):
logger.error(f"[StoryWriter] Attempted path traversal when resolving audio: {audio_url}")
return None
if not file_path.exists():
logger.warning(f"[StoryWriter] Referenced scene audio not found on disk: {file_path}")
return None
return file_path.read_bytes()
except Exception as exc:
logger.error(f"[StoryWriter] Failed to load reference audio for video gen: {exc}")
return None
def resolve_media_file(base_dir: Path, filename: str) -> Path:
"""
Returns a safe resolved path for a media file stored under base_dir.
Guards against directory traversal and ensures the file exists.
"""
filename = filename.split("?")[0].strip()
resolved = (base_dir / filename).resolve()
try:
resolved.relative_to(base_dir.resolve())
except ValueError:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
if not resolved.exists():
alternate = _find_alternate_media_file(base_dir, filename)
if alternate:
logger.warning(
f"[StoryWriter] Requested media file '{filename}' missing; serving closest match '{alternate.name}'"
)
return alternate
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
return resolved
def _find_alternate_media_file(base_dir: Path, filename: str) -> Optional[Path]:
"""
Attempt to find the most recent media file that matches the original name prefix.
This helps when files are regenerated with new UUID/hash suffixes but the frontend still
references an older filename.
"""
try:
base_dir = base_dir.resolve()
except Exception:
return None
stem = Path(filename).stem
suffix = Path(filename).suffix
if not suffix or "_" not in stem:
return None
prefix = stem.rsplit("_", 1)[0]
pattern = f"{prefix}_*{suffix}"
try:
candidates = sorted(
(p for p in base_dir.glob(pattern) if p.is_file()),
key=lambda p: p.stat().st_mtime,
reverse=True,
)
except Exception as exc:
logger.debug(f"[StoryWriter] Failed to search alternate media files for {filename}: {exc}")
return None
return candidates[0] if candidates else None
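# Example of the fallback matching (hypothetical filenames): if "scene_3_9f2c.mp4" is requested
# but the asset was regenerated as "scene_3_7a1b.mp4", the derived glob "scene_3_*.mp4" finds the
# newer file and the most recently modified candidate is returned.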