AI story writer enhancements, text to video and voice generation, subscription management, and more.

This commit is contained in:
ajaysi
2025-11-19 09:55:32 +05:30
parent bf7493c366
commit e96525347b
64 changed files with 10367 additions and 400 deletions

View File

@@ -134,6 +134,12 @@ def generate(
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
db_track.commit()
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
@@ -148,6 +154,7 @@ def generate(
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
@@ -437,6 +444,12 @@ def edit(
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
db_track.commit()
logger.info(f"[images.edit] ✅ Successfully tracked usage: user {user_id} -> image_edit -> {new_calls} calls")
@@ -451,6 +464,7 @@ def edit(
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Images: {current_image_gen_calls} / {image_gen_limit if image_gen_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:

View File

@@ -5,12 +5,19 @@ Main router for story generation operations including premise, outline,
content generation, and full story creation.
"""
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
from typing import Any, Dict, Union, List, Optional
import mimetypes
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
from loguru import logger
from middleware.auth_middleware import get_current_user
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from models.story_models import (
AnimateSceneRequest,
AnimateSceneVoiceoverRequest,
AnimateSceneResponse,
ResumeSceneAnimationRequest,
StoryGenerationRequest,
StorySetupGenerationRequest,
StorySetupGenerationResponse,
@@ -34,24 +41,66 @@ from models.story_models import (
StoryVideoResult,
TaskStatus,
)
from pydantic import BaseModel, Field
from services.database import get_db
from services.llm_providers.main_video_generation import track_video_usage
from services.story_writer.story_service import StoryWriterService
from .task_manager import task_manager
from .cache_manager import cache_manager
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_scene_animation_operation
from services.wavespeed.kling_animation import animate_scene_image, resume_scene_animation
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
from uuid import uuid4
from pydantic import BaseModel
from pathlib import Path
from utils.logger_utils import get_service_logger
from .cache_manager import cache_manager
from .routes import cache_routes, media_generation, story_content, story_setup, story_tasks, video_generation
from .task_manager import task_manager
from .utils.auth import require_authenticated_user
from .utils.media_utils import resolve_media_file
from .utils.hd_video import (
generate_hd_video_payload,
generate_hd_video_scene_payload,
)
from .utils.hd_video import generate_hd_video_payload, generate_hd_video_scene_payload
from .utils.media_utils import load_story_image_bytes, load_story_audio_bytes, resolve_media_file
from urllib.parse import quote
# Parent router for the Story Writer API; every route registered on it (directly
# or through an included sub-router) is served under the "/api/story" prefix.
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
# Include modular routers (order preserved roughly by workflow)
router.include_router(story_setup.router)
router.include_router(story_content.router)
router.include_router(story_tasks.router)
router.include_router(media_generation.router)
router.include_router(video_generation.router)
router.include_router(cache_routes.router)
# Module-level story service instance.
service = StoryWriterService()
# Dedicated logger used by the scene-animation endpoints.
scene_logger = get_service_logger("api.story_writer.scene_animation")
# Sub-directory (under "story_videos") where AI scene animations are stored.
AI_VIDEO_SUBDIR = Path("AI_Videos")
def _build_authenticated_media_url(request: Request, path: str) -> str:
    """Append the caller's auth token to a media URL so <video>/<img> tags can access it.

    The token is carried as a ``token`` query parameter. It is taken from the
    ``Authorization: Bearer ...`` header when present, otherwise from the
    request's own ``token`` query parameter.

    Args:
        request: Incoming request whose headers / query parameters hold the token.
        path: Media URL (or path) to decorate.

    Returns:
        ``path`` unchanged when it is empty or no token could be found;
        otherwise ``path`` with a URL-quoted ``token`` parameter appended
        (using ``&`` when ``path`` already contains a query string).
    """
    if not path:
        return path

    bearer_prefix = "Bearer "
    token: Optional[str] = None
    auth_header = request.headers.get("Authorization")
    if auth_header and auth_header.startswith(bearer_prefix):
        # Drop only the leading scheme. str.replace() would also delete any
        # later "Bearer " occurrence inside the credential and corrupt it.
        token = auth_header[len(bearer_prefix):].strip()
    elif "token" in request.query_params:
        token = request.query_params["token"]

    if token:
        separator = "&" if "?" in path else "?"
        path = f"{path}{separator}token={quote(token)}"
    return path
def _guess_mime_from_url(url: str, fallback: str) -> str:
    """Guess a MIME type from the file extension of a media URL.

    Args:
        url: Media URL or path; may carry a query string or fragment.
        fallback: MIME type to return when ``url`` is empty or its extension
            is not recognised.

    Returns:
        The guessed MIME type, or ``fallback``.
    """
    if not url:
        return fallback
    # Media URLs in this module get "?token=..." appended; left in place, that
    # suffix (or a "#fragment") hides the extension from mimetypes and the
    # fallback would be returned even for a recognisable file type.
    bare_url = url.partition("#")[0].partition("?")[0]
    mime, _ = mimetypes.guess_type(bare_url)
    return mime or fallback
@router.get("/health")
@@ -558,6 +607,22 @@ async def get_task_result(
logger.error(f"[StoryWriter] Failed to get task result: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Request body accepted by the POST /optimize-prompt endpoint.
class PromptOptimizeRequest(BaseModel):
    # Required prompt text; the endpoint rejects blank / whitespace-only text with a 400.
    text: str = Field(..., description="The prompt text to optimize")
    # Regex-limited to "image" or "video"; the endpoint falls back to "image" when this is falsy (e.g. None).
    mode: Optional[str] = Field(default="image", pattern="^(image|video)$", description="Optimization mode: 'image' or 'video'")
    # Regex-limited to the six listed styles; the endpoint falls back to "default" when this is falsy (e.g. None).
    style: Optional[str] = Field(
        default="default",
        pattern="^(default|artistic|photographic|technical|anime|realistic)$",
        description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'"
    )
    # Optional context image; the endpoint forwards it to the optimizer unchanged.
    image: Optional[str] = Field(None, description="Base64-encoded image for context (optional)")
# Response body returned by the POST /optimize-prompt endpoint.
class PromptOptimizeResponse(BaseModel):
    # Value returned by the WaveSpeed client's optimize_prompt() call.
    optimized_prompt: str
    # Set to True by the endpoint on its success path; failures are raised as HTTP errors instead.
    success: bool
class HDVideoRequest(BaseModel):
prompt: str
provider: str = "huggingface"
@@ -692,6 +757,51 @@ async def generate_scene_images(
raise HTTPException(status_code=500, detail=str(e))
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
async def optimize_prompt(
    request: PromptOptimizeRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> PromptOptimizeResponse:
    """Optimize an image prompt using WaveSpeed prompt optimizer.

    Responds with 401 when the caller is missing or has no id, 400 when the
    prompt text is blank, and 500 (carrying the error text) for any other
    failure. HTTP errors raised inside the handler are passed through as-is.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        cleaned_text = (request.text or "").strip()
        if not cleaned_text:
            raise HTTPException(status_code=400, detail="Prompt text is required")

        logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")

        # Imported lazily, at the point of use.
        from services.wavespeed.client import WaveSpeedClient

        wavespeed = WaveSpeedClient()
        optimizer_args = {
            "text": cleaned_text,
            "mode": request.mode or "image",
            "style": request.style or "default",
            "image": request.image,  # Optional base64 image
            "enable_sync_mode": True,
            "timeout": 30,
        }
        improved_prompt = wavespeed.optimize_prompt(**optimizer_args)

        logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")
        return PromptOptimizeResponse(optimized_prompt=improved_prompt, success=True)
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"[StoryWriter] Failed to optimize prompt: {err}")
        raise HTTPException(status_code=500, detail=str(err))
@router.get("/images/{image_filename}")
async def serve_scene_image(
image_filename: str,
@@ -793,32 +903,376 @@ async def generate_scene_audio(
raise HTTPException(status_code=500, detail=str(e))
@router.get("/audio/{audio_filename}")
async def serve_scene_audio(
audio_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Serve a generated story scene audio file."""
# Audio serving endpoint is handled by routes/media_generation.py
# No duplicate endpoint needed here
# ---------------------------
# Scene Animation Endpoints
# ---------------------------
@router.post("/animate-scene-preview", response_model=AnimateSceneResponse)
async def animate_scene_preview(
request_obj: Request,
request: AnimateSceneRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimateSceneResponse:
"""
Animate a single scene image using WaveSpeed Kling v2.5 Turbo Std.
"""
if not current_user:
raise HTTPException(status_code=401, detail="Authentication required")
user_id = str(current_user.get("id", ""))
if not user_id:
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
duration = request.duration or 5
if duration not in (5, 10):
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds.")
scene_logger.info(
"[AnimateScene] User=%s scene=%s duration=%s image_url=%s",
user_id,
request.scene_number,
duration,
request.image_url,
)
image_bytes = load_story_image_bytes(request.image_url)
if not image_bytes:
scene_logger.warning("[AnimateScene] Missing image bytes for user=%s scene=%s", user_id, request.scene_number)
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
db = next(get_db())
try:
require_authenticated_user(current_user)
pricing_service = PricingService(db)
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
finally:
db.close()
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from fastapi.responses import FileResponse
animation_result = animate_scene_image(
image_bytes=image_bytes,
scene_data=request.scene_data,
story_context=request.story_context,
user_id=user_id,
duration=duration,
)
audio_service = StoryAudioGenerationService()
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
base_dir = Path(__file__).parent.parent.parent
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
ai_video_dir.mkdir(parents=True, exist_ok=True)
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
return FileResponse(
path=str(audio_path),
media_type="audio/mpeg",
filename=audio_filename
save_result = video_service.save_scene_video(
video_bytes=animation_result["video_bytes"],
scene_number=request.scene_number,
user_id=user_id,
)
video_filename = save_result["video_filename"]
video_url = _build_authenticated_media_url(
request_obj, f"/api/story/videos/ai/{video_filename}"
)
usage_info = track_video_usage(
user_id=user_id,
provider=animation_result["provider"],
model_name=animation_result["model_name"],
prompt=animation_result["prompt"],
video_bytes=animation_result["video_bytes"],
cost_override=animation_result["cost"],
)
if usage_info:
scene_logger.warning(
"[AnimateScene] Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
user_id,
usage_info.get("previous_calls"),
usage_info.get("current_calls"),
usage_info.get("video_limit_display"),
usage_info.get("cost_per_video", 0.0),
usage_info.get("total_video_cost", 0.0),
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to serve audio: {e}")
raise HTTPException(status_code=500, detail=str(e))
scene_logger.info(
"[AnimateScene] ✅ Completed user=%s scene=%s duration=%s cost=$%.2f video=%s",
user_id,
request.scene_number,
animation_result["duration"],
animation_result["cost"],
video_url,
)
return AnimateSceneResponse(
success=True,
scene_number=request.scene_number,
video_filename=video_filename,
video_url=video_url,
duration=animation_result["duration"],
cost=animation_result["cost"],
prompt_used=animation_result["prompt"],
provider=animation_result["provider"],
prediction_id=animation_result.get("prediction_id"),
)
@router.post("/animate-scene-resume", response_model=AnimateSceneResponse)
async def resume_scene_animation_endpoint(
    request_obj: Request,
    request: ResumeSceneAnimationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimateSceneResponse:
    """Resume downloading a WaveSpeed animation when the initial call timed out.

    Fetches the result for ``request.prediction_id``, saves the video under
    ``story_videos/AI_Videos``, records video usage, and returns the same
    response shape as the initial animation endpoint.

    Args:
        request_obj: Raw request; used to build a token-carrying media URL.
        request: Prediction id, scene number and optional duration
            (5 is used when the duration is falsy).
        current_user: Authenticated user payload.

    Returns:
        AnimateSceneResponse describing the saved video.

    Raises:
        HTTPException: 401 when the caller is missing or has no id.
            Errors from the animation / storage helpers are not caught here.
    """
    if not current_user:
        raise HTTPException(status_code=401, detail="Authentication required")

    user_id = str(current_user.get("id", ""))
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

    scene_logger.info(
        "[AnimateScene] Resume requested user=%s scene=%s prediction=%s",
        user_id,
        request.scene_number,
        request.prediction_id,
    )

    animation_result = resume_scene_animation(
        prediction_id=request.prediction_id,
        duration=request.duration or 5,
        user_id=user_id,
    )

    base_dir = Path(__file__).parent.parent.parent
    ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
    ai_video_dir.mkdir(parents=True, exist_ok=True)
    video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))

    save_result = video_service.save_scene_video(
        video_bytes=animation_result["video_bytes"],
        scene_number=request.scene_number,
        user_id=user_id,
    )
    video_filename = save_result["video_filename"]
    # Keep the token-free path separately: it is the only form safe to log.
    video_path = f"/api/story/videos/ai/{video_filename}"
    video_url = _build_authenticated_media_url(request_obj, video_path)

    usage_info = track_video_usage(
        user_id=user_id,
        provider=animation_result["provider"],
        model_name=animation_result["model_name"],
        prompt=animation_result["prompt"],
        video_bytes=animation_result["video_bytes"],
        cost_override=animation_result["cost"],
    )
    if usage_info:
        scene_logger.warning(
            "[AnimateScene] (Resume) Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
            user_id,
            usage_info.get("previous_calls"),
            usage_info.get("current_calls"),
            usage_info.get("video_limit_display"),
            usage_info.get("cost_per_video", 0.0),
            usage_info.get("total_video_cost", 0.0),
        )

    # Log the bare path, never video_url: the URL has the caller's bearer
    # token appended and must not end up in log files.
    scene_logger.info(
        "[AnimateScene] ✅ Resume completed user=%s scene=%s prediction=%s video=%s",
        user_id,
        request.scene_number,
        request.prediction_id,
        video_path,
    )

    return AnimateSceneResponse(
        success=True,
        scene_number=request.scene_number,
        video_filename=video_filename,
        video_url=video_url,
        duration=animation_result["duration"],
        cost=animation_result["cost"],
        prompt_used=animation_result["prompt"],
        provider=animation_result["provider"],
        prediction_id=animation_result.get("prediction_id"),
    )
@router.post("/animate-scene-voiceover", response_model=Dict[str, Any])
async def animate_scene_voiceover_endpoint(
    request_obj: Request,
    request: AnimateSceneVoiceoverRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Queue a WaveSpeed InfiniteTalk (image + audio) scene animation.

    The work runs as a background task; the response only carries a task_id
    for polling, since InfiniteTalk can take up to 10 minutes.

    Responds with 401 when the caller is missing or has no id, and 404 when
    the scene image or scene audio cannot be loaded.
    """
    if not current_user:
        raise HTTPException(status_code=401, detail="Authentication required")

    user_id = str(current_user.get("id", ""))
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

    scene_logger.info(
        "[AnimateSceneVoiceover] User=%s scene=%s resolution=%s (async)",
        user_id,
        request.scene_number,
        request.resolution or "720p",
    )

    # Both media inputs must exist before any quota check or task creation.
    scene_image = load_story_image_bytes(request.image_url)
    if not scene_image:
        raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")

    scene_audio = load_story_audio_bytes(request.audio_url)
    if not scene_audio:
        raise HTTPException(status_code=404, detail="Scene audio not found. Generate audio first.")

    # Pre-flight subscription validation; the session is always closed.
    session = next(get_db())
    try:
        pricing = PricingService(session)
        validate_scene_animation_operation(pricing_service=pricing, user_id=user_id)
    finally:
        session.close()

    # Capture the bearer token now so the background task can build the media URL later.
    header_value = request_obj.headers.get("Authorization")
    if header_value and header_value.startswith("Bearer "):
        bearer_token = header_value.replace("Bearer ", "").strip()
    else:
        bearer_token = None

    new_task_id = task_manager.create_task("scene_voiceover_animation")
    background_tasks.add_task(
        _execute_voiceover_animation_task,
        task_id=new_task_id,
        request=request,
        user_id=user_id,
        image_bytes=scene_image,
        audio_bytes=scene_audio,
        auth_token=bearer_token,
    )

    response_payload: Dict[str, Any] = {
        "task_id": new_task_id,
        "status": "pending",
        "message": "InfiniteTalk animation started. This may take up to 10 minutes.",
    }
    return response_payload
def _execute_voiceover_animation_task(
    task_id: str,
    request: AnimateSceneVoiceoverRequest,
    user_id: str,
    image_bytes: bytes,
    audio_bytes: bytes,
    auth_token: Optional[str] = None,
):
    """Background task to generate InfiniteTalk video with progress updates.

    Runs the animation, saves the video under ``story_videos/AI_Videos``,
    records video usage, and stores an ``AnimateSceneResponse`` dict as the
    task result. Exceptions are not propagated: any failure is recorded on
    the task with status "failed".

    Args:
        task_id: Task-manager id that receives the status updates.
        request: Original voiceover animation request.
        user_id: Id of the user the video is generated and billed for.
        image_bytes: Scene image content.
        audio_bytes: Scene audio content.
        auth_token: Optional bearer token; when given it is appended to the
            returned video URL as a ``token`` query parameter.
    """
    try:
        task_manager.update_task_status(
            task_id, "processing", progress=5.0, message="Submitting to WaveSpeed InfiniteTalk..."
        )

        animation_result = animate_scene_with_voiceover(
            image_bytes=image_bytes,
            audio_bytes=audio_bytes,
            scene_data=request.scene_data,
            story_context=request.story_context,
            user_id=user_id,
            resolution=request.resolution or "720p",
            prompt_override=request.prompt,
            image_mime=_guess_mime_from_url(request.image_url, "image/png"),
            audio_mime=_guess_mime_from_url(request.audio_url, "audio/mpeg"),
        )

        task_manager.update_task_status(
            task_id, "processing", progress=80.0, message="Saving video file..."
        )

        base_dir = Path(__file__).parent.parent.parent
        ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
        ai_video_dir.mkdir(parents=True, exist_ok=True)
        video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))

        save_result = video_service.save_scene_video(
            video_bytes=animation_result["video_bytes"],
            scene_number=request.scene_number,
            user_id=user_id,
        )
        video_filename = save_result["video_filename"]

        # Build authenticated URL if token provided, otherwise return plain URL.
        # The token-free path is kept separately because it is the only form
        # that is safe to log.
        video_path = f"/api/story/videos/ai/{video_filename}"
        video_url = video_path
        if auth_token:
            video_url = f"{video_path}?token={quote(auth_token)}"

        usage_info = track_video_usage(
            user_id=user_id,
            provider=animation_result["provider"],
            model_name=animation_result["model_name"],
            prompt=animation_result["prompt"],
            video_bytes=animation_result["video_bytes"],
            cost_override=animation_result["cost"],
        )
        if usage_info:
            scene_logger.warning(
                "[AnimateSceneVoiceover] Video usage tracked user=%s: %s%s / %s (cost +$%.2f, total=$%.2f)",
                user_id,
                usage_info.get("previous_calls"),
                usage_info.get("current_calls"),
                usage_info.get("video_limit_display"),
                usage_info.get("cost_per_video", 0.0),
                usage_info.get("total_video_cost", 0.0),
            )

        # Log the bare path, never video_url: the URL may carry the caller's
        # bearer token and must not end up in log files.
        scene_logger.info(
            "[AnimateSceneVoiceover] ✅ Completed user=%s scene=%s cost=$%.2f video=%s",
            user_id,
            request.scene_number,
            animation_result["cost"],
            video_path,
        )

        result = AnimateSceneResponse(
            success=True,
            scene_number=request.scene_number,
            video_filename=video_filename,
            video_url=video_url,
            duration=animation_result["duration"],
            cost=animation_result["cost"],
            prompt_used=animation_result["prompt"],
            provider=animation_result["provider"],
            prediction_id=animation_result.get("prediction_id"),
        )

        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="InfiniteTalk animation complete!",
            result=result.dict(),
        )
    except HTTPException as exc:
        # HTTPException.detail may be a plain string or a structured dict;
        # reduce it to a single message for the task record.
        detail = exc.detail
        if isinstance(detail, str):
            error_msg = str(detail)
        elif isinstance(detail, dict):
            error_msg = detail.get("error", "Animation failed")
        else:
            error_msg = "Animation failed"
        scene_logger.error(f"[AnimateSceneVoiceover] Failed: {error_msg}")
        task_manager.update_task_status(
            task_id,
            "failed",
            error=error_msg,
            message=f"InfiniteTalk animation failed: {error_msg}",
        )
    except Exception as exc:
        error_msg = str(exc)
        scene_logger.error(f"[AnimateSceneVoiceover] Error: {error_msg}", exc_info=True)
        task_manager.update_task_status(
            task_id,
            "failed",
            error=error_msg,
            message=f"InfiniteTalk animation error: {error_msg}",
        )
# ---------------------------
@@ -1260,19 +1714,25 @@ def execute_complete_video_generation(
)
@router.get("/videos/{video_filename}")
async def serve_story_video(
# Regular video serving endpoint is handled by routes/video_generation.py
# Only AI videos need a separate endpoint here
@router.get("/videos/ai/{video_filename}")
async def serve_ai_story_video(
video_filename: str,
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""Serve a generated story video file."""
"""Serve a generated AI scene animation video."""
try:
require_authenticated_user(current_user)
from services.story_writer.video_generation_service import StoryVideoGenerationService
from fastapi.responses import FileResponse
video_service = StoryVideoGenerationService()
base_dir = Path(__file__).parent.parent.parent
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
video_path = resolve_media_file(video_service.output_dir, video_filename)
return FileResponse(
@@ -1284,7 +1744,7 @@ async def serve_story_video(
except HTTPException:
raise
except Exception as e:
logger.error(f"[StoryWriter] Failed to serve video: {e}")
logger.error(f"[StoryWriter] Failed to serve AI video: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -0,0 +1,21 @@
"""
Collection of modular routers for Story Writer endpoints.
Each module focuses on a related set of routes to keep the primary
`router.py` concise and easier to maintain.
"""
from . import story_setup
from . import story_content
from . import story_tasks
from . import media_generation
from . import video_generation
from . import cache_routes
# Public names of this package: one entry per route submodule imported above.
__all__ = [
"story_setup",
"story_content",
"story_tasks",
"media_generation",
"video_generation",
"cache_routes",
]

View File

@@ -0,0 +1,42 @@
from typing import Any, Dict
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from middleware.auth_middleware import get_current_user
from ..cache_manager import cache_manager
from ..utils.auth import require_authenticated_user
# Prefix-less sub-router holding the cache endpoints defined in this module.
router = APIRouter()
@router.get("/cache/stats")
async def get_cache_stats(
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Get cache statistics.

    Returns:
        ``{"success": True, "stats": <cache manager statistics>}``.

    Raises:
        HTTPException: Re-raised unchanged when raised inside the handler
            (for example by the authentication check); status 500 with the
            error text for any other failure.
    """
    try:
        require_authenticated_user(current_user)
        stats = cache_manager.get_cache_stats()
        return {"success": True, "stats": stats}
    except HTTPException:
        # Without this clause the blanket handler below would turn an
        # authentication failure into a 500.
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get cache stats: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/cache/clear")
async def clear_cache(
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Clear the story generation cache.

    Returns:
        ``{"success": True}`` merged with the mapping returned by the cache
        manager's ``clear_cache()``.

    Raises:
        HTTPException: Re-raised unchanged when raised inside the handler
            (for example by the authentication check); status 500 with the
            error text for any other failure.
    """
    try:
        require_authenticated_user(current_user)
        result = cache_manager.clear_cache()
        return {"success": True, **result}
    except HTTPException:
        # Without this clause the blanket handler below would turn an
        # authentication failure into a 500.
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to clear cache: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,289 @@
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from models.story_models import (
StoryImageGenerationRequest,
StoryImageGenerationResponse,
StoryImageResult,
RegenerateImageRequest,
RegenerateImageResponse,
StoryAudioGenerationRequest,
StoryAudioGenerationResponse,
StoryAudioResult,
GenerateAIAudioRequest,
GenerateAIAudioResponse,
StoryScene,
)
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from ..utils.auth import require_authenticated_user
from ..utils.media_utils import resolve_media_file
# Prefix-less sub-router holding the image / audio endpoints defined in this module.
router = APIRouter()
# Service instances created once at import time and shared by every handler below;
# their output_dir attributes are what the file-serving endpoints resolve against.
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
@router.post("/generate-images", response_model=StoryImageGenerationResponse)
async def generate_scene_images(
    request: StoryImageGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryImageGenerationResponse:
    """Generate one image per story scene.

    Responds with 400 when no scenes are supplied and 500 (carrying the error
    text) on unexpected failures; HTTP errors raised inside are passed through.
    Width and height default to 1024 when not provided.
    """
    try:
        user_id = require_authenticated_user(current_user)

        scene_count = len(request.scenes) if request.scenes else 0
        if scene_count == 0:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        logger.info(f"[StoryWriter] Generating images for {scene_count} scenes for user {user_id}")

        # The service works on plain dicts, so model instances are converted first.
        scene_payloads = []
        for scene in request.scenes:
            scene_payloads.append(scene.dict() if isinstance(scene, StoryScene) else scene)

        raw_results = image_service.generate_scene_images(
            scenes=scene_payloads,
            user_id=user_id,
            provider=request.provider,
            width=request.width or 1024,
            height=request.height or 1024,
            model=request.model,
        )

        images: List[StoryImageResult] = []
        for entry in raw_results:
            images.append(
                StoryImageResult(
                    scene_number=entry.get("scene_number", 0),
                    scene_title=entry.get("scene_title", "Untitled"),
                    image_filename=entry.get("image_filename", ""),
                    image_url=entry.get("image_url", ""),
                    width=entry.get("width", 1024),
                    height=entry.get("height", 1024),
                    provider=entry.get("provider", "unknown"),
                    model=entry.get("model"),
                    seed=entry.get("seed"),
                    error=entry.get("error"),
                )
            )

        return StoryImageGenerationResponse(images=images, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate images: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/regenerate-images", response_model=RegenerateImageResponse)
async def regenerate_scene_image(
    request: RegenerateImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> RegenerateImageResponse:
    """Regenerate a single scene image using a direct prompt (no AI prompt generation).

    A blank prompt yields a 400. Unexpected failures are not raised: they are
    reported in a response with ``success=False`` and the error text, while
    HTTP errors raised inside are passed through.
    """
    # Width and height default to 1024 on both the success and failure paths.
    default_width = request.width or 1024
    default_height = request.height or 1024

    try:
        user_id = require_authenticated_user(current_user)

        prompt_text = request.prompt.strip() if request.prompt else ""
        if not prompt_text:
            raise HTTPException(status_code=400, detail="Prompt is required")

        logger.info(
            f"[StoryWriter] Regenerating image for scene {request.scene_number} "
            f"({request.scene_title}) for user {user_id}"
        )

        generated = image_service.regenerate_scene_image(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            prompt=prompt_text,
            user_id=user_id,
            provider=request.provider,
            width=default_width,
            height=default_height,
            model=request.model,
        )

        # Anything the service omitted falls back to what the caller sent.
        response_fields = {
            "scene_number": generated.get("scene_number", request.scene_number),
            "scene_title": generated.get("scene_title", request.scene_title),
            "image_filename": generated.get("image_filename", ""),
            "image_url": generated.get("image_url", ""),
            "width": generated.get("width", default_width),
            "height": generated.get("height", default_height),
            "provider": generated.get("provider", "unknown"),
            "model": generated.get("model"),
            "seed": generated.get("seed"),
            "success": True,
        }
        return RegenerateImageResponse(**response_fields)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to regenerate image: {exc}")
        return RegenerateImageResponse(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            image_filename="",
            image_url="",
            width=default_width,
            height=default_height,
            provider=request.provider or "unknown",
            success=False,
            error=str(exc),
        )
@router.get("/images/{image_filename}")
async def serve_scene_image(
    image_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """Serve a generated story scene image.

    Authentication is accepted either through the Authorization header or
    through a ``token`` query parameter; the latter exists for HTML elements
    such as <img> that cannot send custom headers.
    """
    try:
        require_authenticated_user(current_user)
        resolved_path = resolve_media_file(image_service.output_dir, image_filename)
        response = FileResponse(
            path=str(resolved_path),
            media_type="image/png",
            filename=image_filename,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve image: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return response
@router.post("/generate-audio", response_model=StoryAudioGenerationResponse)
async def generate_scene_audio(
    request: StoryAudioGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryAudioGenerationResponse:
    """Generate audio narration for story scenes.

    Responds with 400 when no scenes are supplied and 500 (carrying the error
    text) on unexpected failures; HTTP errors raised inside are passed through.
    Unset options fall back to provider "gtts", language "en", normal speed
    and a rate of 150.
    """
    try:
        user_id = require_authenticated_user(current_user)

        scene_count = len(request.scenes) if request.scenes else 0
        if scene_count == 0:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        logger.info(f"[StoryWriter] Generating audio for {scene_count} scenes for user {user_id}")

        # The service works on plain dicts, so model instances are converted first.
        scene_payloads = []
        for scene in request.scenes:
            scene_payloads.append(scene.dict() if isinstance(scene, StoryScene) else scene)

        generated = audio_service.generate_scene_audio_list(
            scenes=scene_payloads,
            user_id=user_id,
            provider=request.provider or "gtts",
            lang=request.lang or "en",
            slow=request.slow or False,
            rate=request.rate or 150,
        )

        audio_files: List[StoryAudioResult] = [
            StoryAudioResult(
                scene_number=item.get("scene_number", 0),
                scene_title=item.get("scene_title", "Untitled"),
                audio_filename=item.get("audio_filename") or "",
                audio_url=item.get("audio_url") or "",
                provider=item.get("provider", "unknown"),
                file_size=item.get("file_size", 0),
                error=item.get("error"),
            )
            for item in generated
        ]
        return StoryAudioGenerationResponse(audio_files=audio_files, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate audio: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-ai-audio", response_model=GenerateAIAudioResponse)
async def generate_ai_audio(
    request: GenerateAIAudioRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> GenerateAIAudioResponse:
    """Generate AI audio for a single scene using WaveSpeed Minimax Speech 02 HD.

    Blank text yields a 400. Unexpected failures are not raised: they are
    reported in a response with ``success=False`` and the error text, while
    HTTP errors raised inside are passed through.
    """
    # The same voice default is used for the service call and both responses.
    voice = request.voice_id or "Wise_Woman"

    try:
        user_id = require_authenticated_user(current_user)

        narration = request.text.strip() if request.text else ""
        if not narration:
            raise HTTPException(status_code=400, detail="Text is required")

        logger.info(
            f"[StoryWriter] Generating AI audio for scene {request.scene_number} "
            f"({request.scene_title}) for user {user_id}"
        )

        generated = audio_service.generate_ai_audio(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            text=narration,
            user_id=user_id,
            voice_id=voice,
            speed=request.speed or 1.0,
            volume=request.volume or 1.0,
            pitch=request.pitch or 0.0,
            emotion=request.emotion or "happy",
        )

        # Anything the service omitted falls back to request values / defaults.
        response_fields = {
            "scene_number": generated.get("scene_number", request.scene_number),
            "scene_title": generated.get("scene_title", request.scene_title),
            "audio_filename": generated.get("audio_filename", ""),
            "audio_url": generated.get("audio_url", ""),
            "provider": generated.get("provider", "wavespeed"),
            "model": generated.get("model", "minimax/speech-02-hd"),
            "voice_id": generated.get("voice_id", voice),
            "text_length": generated.get("text_length", len(request.text)),
            "file_size": generated.get("file_size", 0),
            "cost": generated.get("cost", 0.0),
            "success": True,
        }
        return GenerateAIAudioResponse(**response_fields)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate AI audio: {exc}")
        return GenerateAIAudioResponse(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            audio_filename="",
            audio_url="",
            provider="wavespeed",
            model="minimax/speech-02-hd",
            voice_id=voice,
            text_length=len(request.text) if request.text else 0,
            file_size=0,
            cost=0.0,
            success=False,
            error=str(exc),
        )
@router.get("/audio/{audio_filename}")
async def serve_scene_audio(
    audio_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Serve a generated story scene audio file.

    The file is resolved inside the audio service's output directory and
    returned with the "audio/mpeg" media type. Unexpected failures become a
    500; HTTP errors raised inside are passed through.
    """
    try:
        require_authenticated_user(current_user)
        resolved_path = resolve_media_file(audio_service.output_dir, audio_filename)
        response = FileResponse(
            path=str(resolved_path),
            media_type="audio/mpeg",
            filename=audio_filename,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve audio: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return response

View File

@@ -0,0 +1,195 @@
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from middleware.auth_middleware import get_current_user
from models.story_models import (
StoryStartRequest,
StoryContentResponse,
StoryScene,
StoryContinueRequest,
StoryContinueResponse,
)
from services.story_writer.story_service import StoryWriterService
from ..utils.auth import require_authenticated_user
# Router collecting the story-content endpoints defined in this module.
router = APIRouter()
# Created once at import time and shared by every handler below.
story_service = StoryWriterService()
@router.post("/generate-start", response_model=StoryContentResponse)
async def generate_story_start(
    request: StoryStartRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContentResponse:
    """Generate the starting section of a story.

    Validates the premise and outline, delegates the writing to
    ``story_service.generate_story_start`` and, for short stories, marks the
    story complete when the generated text has at least 900 words.

    Returns:
        The generated text plus the premise and a string rendering of the
        outline (structured outlines are flattened to "Scene N: title"
        lines).

    Raises:
        HTTPException: 400 when premise or outline is missing/blank; 500 for
            any unexpected failure.
    """
    try:
        user_id = require_authenticated_user(current_user)

        if not request.premise or not request.premise.strip():
            raise HTTPException(status_code=400, detail="Premise is required")
        if not request.outline or (
            isinstance(request.outline, str) and not request.outline.strip()
        ):
            raise HTTPException(status_code=400, detail="Outline is required")

        logger.info(f"[StoryWriter] Generating story start for user {user_id}")

        # Structured outlines arrive as StoryScene models; the service and the
        # formatting below work on plain dicts.
        outline_data: Any = request.outline
        if (
            isinstance(outline_data, list)
            and outline_data
            and isinstance(outline_data[0], StoryScene)
        ):
            outline_data = [scene.dict() for scene in outline_data]

        # `getattr` only covers a missing attribute. An explicit null would
        # still reach `.lower()` below and fail *after* the expensive
        # generation call, so normalise it up front.
        story_length = getattr(request, "story_length", None) or "Medium"

        story_start = story_service.generate_story_start(
            premise=request.premise,
            outline=outline_data,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id,
        )

        story_length_lower = story_length.lower()
        is_short_story = "short" in story_length_lower or "1000" in story_length_lower

        is_complete = False
        if is_short_story:
            word_count = len(story_start.split()) if story_start else 0
            if word_count >= 900:
                is_complete = True
                logger.info(
                    f"[StoryWriter] Short story generated with {word_count} words. Marking as complete."
                )
            else:
                logger.warning(
                    f"[StoryWriter] Short story generated with only {word_count} words. May need continuation."
                )

        # The response model carries the outline as text, so flatten scenes.
        outline_response = outline_data
        if isinstance(outline_response, list):
            outline_response = "\n".join(
                f"Scene {scene.get('scene_number', i + 1)}: "
                f"{scene.get('title', 'Untitled')}\n {scene.get('description', '')}"
                for i, scene in enumerate(outline_response)
            )

        return StoryContentResponse(
            story=story_start,
            premise=request.premise,
            outline=str(outline_response),
            is_complete=is_complete,
            success=True,
        )
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate story start: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/continue", response_model=StoryContinueResponse)
async def continue_story(
    request: StoryContinueRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContinueResponse:
    """Continue writing a medium or long story.

    Short stories are rejected, because they are produced in a single call.
    Word-count targets are 10000 for long stories and 4500 otherwise. When
    the text is already at the target the service is not called and an
    "IAMDONE" marker is returned immediately; when the new continuation
    pushes the story to the target, the marker is appended to it.

    Raises:
        HTTPException: 400 when the story text is blank or the story is a
            short story; 500 for any unexpected failure.
    """
    try:
        user_id = require_authenticated_user(current_user)

        if not request.story_text or not request.story_text.strip():
            raise HTTPException(status_code=400, detail="Story text is required")

        logger.info(f"[StoryWriter] Continuing story for user {user_id}")

        outline_data: Any = request.outline
        if (
            isinstance(outline_data, list)
            and outline_data
            and isinstance(outline_data[0], StoryScene)
        ):
            outline_data = [scene.dict() for scene in outline_data]

        # `getattr` only covers a missing attribute; an explicit null would
        # crash on `.lower()`, so fall back to the default for that too.
        story_length = getattr(request, "story_length", None) or "Medium"
        story_length_lower = story_length.lower()

        if "short" in story_length_lower or "1000" in story_length_lower:
            logger.warning(
                "[StoryWriter] Attempted to continue a short story. Short stories should be complete in one call."
            )
            raise HTTPException(
                status_code=400,
                detail="Short stories are generated in a single call and should be complete. "
                "If the story is incomplete, please regenerate it from the beginning.",
            )

        current_word_count = len(request.story_text.split()) if request.story_text else 0
        if "long" in story_length_lower or "10000" in story_length_lower:
            target_total_words = 10000
        else:
            target_total_words = 4500
        # 5% overshoot allowance above the nominal target.
        buffer_target = int(target_total_words * 1.05)

        # Skip the generation call when the story is already long enough.
        if current_word_count >= buffer_target or (
            current_word_count >= target_total_words
            and (current_word_count - target_total_words) < 50
        ):
            logger.info(
                f"[StoryWriter] Word count ({current_word_count}) already at or near target ({target_total_words})."
            )
            return StoryContinueResponse(continuation="IAMDONE", is_complete=True, success=True)

        continuation = story_service.continue_story(
            premise=request.premise,
            outline=outline_data,
            story_text=request.story_text,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id,
        )

        is_complete = "IAMDONE" in continuation.upper()
        if not is_complete and continuation:
            new_story_text = request.story_text + "\n\n" + continuation
            new_word_count = len(new_story_text.split())

            reached_buffer = new_word_count >= buffer_target
            near_target = (
                new_word_count >= target_total_words
                and (new_word_count - target_total_words) < 100
            )

            if reached_buffer:
                logger.info(
                    f"[StoryWriter] Word count ({new_word_count}) now exceeds buffer target ({buffer_target})."
                )
            elif near_target:
                logger.info(
                    f"[StoryWriter] Word count ({new_word_count}) is at or very close to target ({target_total_words})."
                )

            if reached_buffer or near_target:
                # The marker is known to be absent here (is_complete was
                # False), so it can be appended without re-checking.
                continuation = continuation.rstrip() + "\n\nIAMDONE"
                is_complete = True

        return StoryContinueResponse(continuation=continuation, is_complete=is_complete, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to continue story: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,141 @@
from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from middleware.auth_middleware import get_current_user
from models.story_models import (
StorySetupGenerationRequest,
StorySetupGenerationResponse,
StorySetupOption,
StoryGenerationRequest,
StoryOutlineResponse,
StoryScene,
StoryStartRequest,
StoryPremiseResponse,
)
from services.story_writer.story_service import StoryWriterService
from ..utils.auth import require_authenticated_user
# Router collecting the setup / premise / outline endpoints of this module.
router = APIRouter()
# Created once at import time and shared by every handler below.
story_service = StoryWriterService()
@router.post("/generate-setup", response_model=StorySetupGenerationResponse)
async def generate_story_setup(
    request: StorySetupGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StorySetupGenerationResponse:
    """Turn a free-form story idea into story setup options.

    The idea is passed to ``story_service.generate_story_setup_options`` and
    each returned option is wrapped in a ``StorySetupOption`` model.

    Raises:
        HTTPException: 400 when the story idea is missing or blank; 500 for
            any unexpected failure.
    """
    try:
        user_id = require_authenticated_user(current_user)

        idea = request.story_idea
        if not (idea and idea.strip()):
            raise HTTPException(status_code=400, detail="Story idea is required")

        logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")

        raw_options = story_service.generate_story_setup_options(
            story_idea=request.story_idea,
            user_id=user_id,
        )

        setup_options = []
        for raw_option in raw_options:
            setup_options.append(StorySetupOption(**raw_option))

        return StorySetupGenerationResponse(options=setup_options, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate story setup options: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-premise", response_model=StoryPremiseResponse)
async def generate_premise(
    request: StoryGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryPremiseResponse:
    """Create a story premise from the story parameters in the request.

    Raises:
        HTTPException: re-raised unchanged when raised inside the handler;
            any other error is logged and reported as a 500.
    """
    # Request fields forwarded verbatim, under the same names, to the service.
    forwarded_fields = (
        "persona",
        "story_setting",
        "character_input",
        "plot_elements",
        "writing_style",
        "story_tone",
        "narrative_pov",
        "audience_age_group",
        "content_rating",
        "ending_preference",
    )
    try:
        user_id = require_authenticated_user(current_user)
        logger.info(f"[StoryWriter] Generating premise for user {user_id}")

        premise_inputs = {name: getattr(request, name) for name in forwarded_fields}
        premise_text = story_service.generate_premise(user_id=user_id, **premise_inputs)

        return StoryPremiseResponse(premise=premise_text, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate premise: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-outline", response_model=StoryOutlineResponse)
async def generate_outline(
    request: StoryStartRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    use_structured: bool = True,
) -> StoryOutlineResponse:
    """Generate a story outline from a premise.

    Args:
        request: Story parameters; ``premise`` must be non-blank.
        current_user: Authenticated user injected by the auth dependency.
        use_structured: Forwarded to the service as
            ``use_structured_output``.

    Returns:
        A structured response (list of ``StoryScene``) when the service
        returns a list, otherwise the outline rendered as text.

    Raises:
        HTTPException: 400 when the premise is missing or blank; 500 for any
            unexpected failure.
    """
    try:
        user_id = require_authenticated_user(current_user)

        if not request.premise or not request.premise.strip():
            raise HTTPException(status_code=400, detail="Premise is required")

        logger.info(
            f"[StoryWriter] Generating outline for user {user_id} (structured={use_structured})"
        )
        # loguru formats messages with str.format(), not printf-style "%s";
        # with "%s" placeholders the values were silently dropped and the
        # literal "%s" was logged.
        logger.info(
            "[StoryWriter] Outline params: audience_age_group={}, writing_style={}, story_tone={}",
            request.audience_age_group,
            request.writing_style,
            request.story_tone,
        )

        outline = story_service.generate_outline(
            premise=request.premise,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            user_id=user_id,
            use_structured_output=use_structured,
        )

        if isinstance(outline, list):
            # Dict items become StoryScene models; other items pass through.
            scenes: List[StoryScene] = [
                StoryScene(**scene) if isinstance(scene, dict) else scene
                for scene in outline
            ]
            return StoryOutlineResponse(outline=scenes, success=True, is_structured=True)

        return StoryOutlineResponse(outline=str(outline), success=True, is_structured=False)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate outline: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,130 @@
from typing import Any, Dict
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from loguru import logger
from middleware.auth_middleware import get_current_user
from models.story_models import (
StoryGenerationRequest,
TaskStatus,
)
from services.story_writer.story_service import StoryWriterService
from ..cache_manager import cache_manager
from ..task_manager import task_manager
from ..utils.auth import require_authenticated_user
# Router collecting the full-story generation and task polling endpoints.
router = APIRouter()
# Created once at import time.
story_service = StoryWriterService()
@router.post("/generate-full", response_model=Dict[str, Any])
async def generate_full_story(
    request: StoryGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
    max_iterations: int = 10,
) -> Dict[str, Any]:
    """Start full story generation as a background task.

    If ``cache_manager`` already holds a result for this request, a task is
    created and immediately marked completed with that result. Otherwise a
    new task is created and ``task_manager.execute_story_generation_task`` is
    scheduled with the request data plus ``max_iterations``.

    Returns:
        A dict containing the ``task_id``; cached hits carry
        ``"cached": True``, fresh tasks carry a pending status and message.

    Raises:
        HTTPException: 500 for any unexpected failure.
    """
    try:
        user_id = require_authenticated_user(current_user)

        lookup_key = cache_manager.get_cache_key(request.dict())
        cached_story = cache_manager.get_cached_result(lookup_key)

        if not cached_story:
            task_id = task_manager.create_task("story_generation")

            task_payload = request.dict()
            task_payload["max_iterations"] = max_iterations

            background_tasks.add_task(
                task_manager.execute_story_generation_task,
                task_id=task_id,
                request_data=task_payload,
                user_id=user_id,
            )
            logger.info(f"[StoryWriter] Created task {task_id} for full story generation (user {user_id})")

            return {
                "task_id": task_id,
                "status": "pending",
                "message": "Story generation started. Use /task/{task_id}/status to check progress.",
            }

        # Cache hit: expose the stored result through an already-completed
        # task so the client can use the same polling flow.
        logger.info(f"[StoryWriter] Returning cached result for user {user_id}")
        task_id = task_manager.create_task("story_generation")
        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            result=cached_story,
            message="Returned cached result",
        )
        return {"task_id": task_id, "cached": True}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to start story generation: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/task/{task_id}/status", response_model=TaskStatus)
async def get_task_status(
    task_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> TaskStatus:
    """Report the current state of a task known to ``task_manager``.

    Raises:
        HTTPException: 404 when the task manager has no entry for
            ``task_id``; 500 for any unexpected failure.
    """
    try:
        require_authenticated_user(current_user)

        snapshot = task_manager.get_task_status(task_id)
        if snapshot:
            return TaskStatus(**snapshot)

        raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get task status: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.get("/task/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Fetch the stored result of a completed task.

    Dict results are returned as a copy with ``success`` defaulted to True
    and ``task_id`` set; any other result is wrapped under ``"result"``.

    Raises:
        HTTPException: 404 when the task or its result is missing, 400 when
            the task has not completed, 500 for any unexpected failure.
    """
    try:
        require_authenticated_user(current_user)

        snapshot = task_manager.get_task_status(task_id)
        if not snapshot:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

        if task_status_is_pending := snapshot["status"] != "completed":
            raise HTTPException(
                status_code=400,
                detail=f"Task {task_id} is not completed. Status: {snapshot['status']}",
            )

        outcome = snapshot.get("result")
        if not outcome:
            raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")

        if not isinstance(outcome, dict):
            return {"result": outcome, "success": True, "task_id": task_id}

        # Work on a shallow copy so the stored task result is not mutated.
        enriched = dict(outcome)
        if "success" not in enriched:
            enriched["success"] = True
        enriched["task_id"] = task_id
        return enriched
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get task result: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))

View File

@@ -0,0 +1,511 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger
from pydantic import BaseModel
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from models.story_models import (
StoryVideoGenerationRequest,
StoryVideoGenerationResponse,
StoryVideoResult,
StoryScene,
StoryGenerationRequest,
)
from services.story_writer.video_generation_service import StoryVideoGenerationService
from services.story_writer.image_generation_service import StoryImageGenerationService
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from services.story_writer.story_service import StoryWriterService
from ..task_manager import task_manager
from ..utils.auth import require_authenticated_user
from ..utils.hd_video import (
generate_hd_video_payload,
generate_hd_video_scene_payload,
)
from ..utils.media_utils import resolve_media_file
# Router collecting the story video endpoints defined in this module.
router = APIRouter()
# Service instances are created once at import time and shared by the
# handlers and background tasks below.
video_service = StoryVideoGenerationService()
image_service = StoryImageGenerationService()
audio_service = StoryAudioGenerationService()
story_service = StoryWriterService()
class HDVideoRequest(BaseModel):
    # Request body for the /hd-video endpoint; the whole object is handed to
    # generate_hd_video_payload.

    # Text prompt (required).
    prompt: str
    # Provider name; defaults to "huggingface".
    provider: str = "huggingface"
    # Optional generation settings; None when the client omits them.
    model: Optional[str] = None
    num_frames: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
class HDVideoSceneRequest(BaseModel):
    # Request body for the /hd-video-scene endpoint; the whole object is
    # handed to generate_hd_video_scene_payload.

    # Scene identification and context (all required).
    scene_number: int
    scene_data: Dict[str, Any]
    story_context: Dict[str, Any]
    all_scenes: List[Dict[str, Any]]
    # Provider name; defaults to "huggingface".
    provider: str = "huggingface"
    # Optional generation settings; None when the client omits them.
    model: Optional[str] = None
    num_frames: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
@router.post("/generate-video", response_model=StoryVideoGenerationResponse)
async def generate_story_video(
    request: StoryVideoGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryVideoGenerationResponse:
    """Generate a video from story scenes, images, and audio.

    For every scene the handler picks a visual (the animated video when its
    file exists, otherwise the static image) and an audio track (the AI audio
    when its file exists, otherwise the regular audio). Scenes for which
    either asset cannot be found on disk are skipped.

    The four lists passed to the video service are index-aligned: entry ``i``
    of ``video_paths``/``image_paths``/``audio_paths`` always belongs to
    ``valid_scenes[i]``, and exactly one of ``video_paths[i]`` and
    ``image_paths[i]`` is set.

    Raises:
        HTTPException: 400 when no scenes are supplied, when the URL lists do
            not match the number of scenes, or when no scene has usable
            assets; 500 for any unexpected failure.
    """
    try:
        user_id = require_authenticated_user(current_user)

        if not request.scenes or len(request.scenes) == 0:
            raise HTTPException(status_code=400, detail="At least one scene is required")
        if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
            raise HTTPException(
                status_code=400,
                detail="Number of scenes, image URLs, and audio URLs must match",
            )

        logger.info(f"[StoryWriter] Generating video for {len(request.scenes)} scenes for user {user_id}")

        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]

        video_paths: List[Optional[str]] = []  # Animated videos (preferred)
        image_paths: List[Optional[str]] = []  # Static images (fallback)
        audio_paths: List[str] = []
        valid_scenes: List[Dict[str, Any]] = []

        # Resolve the directory holding animated scene videos.
        base_dir = Path(__file__).parent.parent.parent.parent
        ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()

        video_urls = request.video_urls or [None] * len(request.scenes)
        ai_audio_urls = request.ai_audio_urls or [None] * len(request.scenes)

        for idx, (scene, image_url, audio_url) in enumerate(
            zip(scenes_data, request.image_urls, request.audio_urls)
        ):
            # --- Visual: prefer the animated video, fall back to the image.
            video_url = video_urls[idx] if idx < len(video_urls) else None
            video_path = None
            image_path = None

            if video_url:
                # Keep only the file name from the URL
                # (e.g. /api/story/videos/ai/filename.mp4).
                video_filename = video_url.split("/")[-1].split("?")[0]
                candidate_video = ai_video_dir / video_filename
                if candidate_video.exists():
                    logger.info(f"[StoryWriter] Using animated video for scene {scene.get('scene_number', idx+1)}: {video_filename}")
                    video_path = candidate_video
                else:
                    logger.warning(f"[StoryWriter] Animated video not found: {candidate_video}, falling back to image")

            if video_path is None:
                image_filename = image_url.split("/")[-1].split("?")[0]
                candidate_image = image_service.output_dir / image_filename
                if not candidate_image.exists():
                    logger.warning(f"[StoryWriter] Image not found: {candidate_image} (from URL: {image_url})")
                    continue
                image_path = candidate_image

            # --- Audio: prefer the AI audio, fall back to the regular audio.
            ai_audio_url = ai_audio_urls[idx] if idx < len(ai_audio_urls) else None
            audio_path = None

            if ai_audio_url:
                audio_filename = ai_audio_url.split("/")[-1].split("?")[0]
                candidate_audio = audio_service.output_dir / audio_filename
                if candidate_audio.exists():
                    logger.info(f"[StoryWriter] Using AI audio for scene {scene.get('scene_number', idx+1)}: {audio_filename}")
                    audio_path = candidate_audio
                else:
                    logger.warning(f"[StoryWriter] AI audio not found: {candidate_audio}, falling back to free audio")

            if audio_path is None:
                audio_filename = audio_url.split("/")[-1].split("?")[0]
                candidate_audio = audio_service.output_dir / audio_filename
                if not candidate_audio.exists():
                    logger.warning(f"[StoryWriter] Audio not found: {candidate_audio} (from URL: {audio_url})")
                    continue
                audio_path = candidate_audio

            # Append to all four lists only once the scene is fully resolved.
            # Appending earlier left video_paths/image_paths out of step with
            # audio_paths/valid_scenes whenever a scene was skipped or an
            # animated video was missing.
            video_paths.append(str(video_path) if video_path is not None else None)
            image_paths.append(str(image_path) if image_path is not None else None)
            audio_paths.append(str(audio_path))
            valid_scenes.append(scene)

        if len(valid_scenes) == 0 or len(audio_paths) == 0:
            raise HTTPException(status_code=400, detail="No valid video/image or audio files were found")
        if len(valid_scenes) != len(audio_paths):
            raise HTTPException(
                status_code=400,
                detail="Number of valid scenes and audio files must match",
            )

        video_result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,  # Can contain None for scenes with animated videos
            video_paths=video_paths,  # Can contain None for scenes with static images
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request.story_title or "Story",
            fps=request.fps or 24,
            transition_duration=request.transition_duration or 0.5,
        )

        video_model = StoryVideoResult(
            video_filename=video_result.get("video_filename", ""),
            video_url=video_result.get("video_url", ""),
            duration=video_result.get("duration", 0.0),
            fps=video_result.get("fps", 24),
            file_size=video_result.get("file_size", 0),
            num_scenes=video_result.get("num_scenes", 0),
            error=video_result.get("error"),
        )
        return StoryVideoGenerationResponse(video=video_model, success=True)
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate video: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
    request: StoryVideoGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Schedule video generation as a background task and return its task id.
    Progress is written to the task manager by the background task, so the
    frontend can poll /api/story/task/{task_id}/status for progress messages.
    """
    try:
        user_id = require_authenticated_user(current_user)

        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        scene_count = len(request.scenes)
        # Every scene needs exactly one image URL and one audio URL.
        if not (scene_count == len(request.image_urls) == len(request.audio_urls)):
            raise HTTPException(
                status_code=400,
                detail="Number of scenes, image URLs, and audio URLs must match",
            )

        new_task_id = task_manager.create_task("story_video_generation")
        background_tasks.add_task(
            _execute_video_generation_task,
            task_id=new_task_id,
            request=request,
            user_id=user_id,
        )

        return {"task_id": new_task_id, "status": "pending", "message": "Video generation started"}
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to start async video generation: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
    """Background task to generate story video with progress mapped to task manager.

    Resolves each scene's image and audio file on disk, skips scenes with a
    missing asset, runs the video service and records the outcome on the
    task. Service progress (0-100) is mapped onto the 5-95 range of the task.
    Failures are recorded on the task (status "failed") rather than raised.

    NOTE(review): unlike the synchronous /generate-video handler, this task
    does not look at ``video_urls`` / ``ai_audio_urls`` - confirm whether
    that is intentional.
    """
    try:
        task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")

        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]

        image_paths: List[str] = []
        audio_paths: List[str] = []
        valid_scenes: List[Dict[str, Any]] = []

        for scene, image_url, audio_url in zip(scenes_data, request.image_urls, request.audio_urls):
            # Keep only the file name from each URL (drop path and query).
            image_filename = image_url.split("/")[-1].split("?")[0]
            audio_filename = audio_url.split("/")[-1].split("?")[0]
            image_path = image_service.output_dir / image_filename
            audio_path = audio_service.output_dir / audio_filename

            if not image_path.exists():
                logger.warning(f"[StoryWriter] Image not found: {image_path} (from URL: {image_url})")
                continue
            if not audio_path.exists():
                logger.warning(f"[StoryWriter] Audio not found: {audio_path} (from URL: {audio_url})")
                continue

            image_paths.append(str(image_path))
            audio_paths.append(str(audio_path))
            valid_scenes.append(scene)

        if not image_paths or not audio_paths or len(image_paths) != len(audio_paths):
            raise RuntimeError("No valid or mismatched image/audio assets for video generation.")

        def progress_callback(sub_progress: float, msg: str):
            # Clamp to 0-100, then map onto 5-95 of the overall task.
            overall = 5.0 + max(0.0, min(100.0, sub_progress)) * 0.9
            task_manager.update_task_status(task_id, "processing", progress=overall, message=msg)

        result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request.story_title or "Story",
            fps=request.fps or 24,
            transition_duration=request.transition_duration or 0.5,
            progress_callback=progress_callback,
        )

        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="Video generation complete!",
            result={"video": result, "success": True},
        )
    except Exception as exc:
        # loguru has no `exc_info` flag: passing it as a keyword logged no
        # traceback and made loguru run str.format() on the message, which
        # raises when the exception text contains braces - leaving the task
        # stuck in "processing". logger.exception() attaches the traceback
        # and takes the message as-is.
        logger.exception(f"[StoryWriter] Async video generation failed: {exc}")
        task_manager.update_task_status(task_id, "failed", error=str(exc), message=f"Video generation failed: {exc}")
@router.post("/hd-video")
async def generate_hd_video(
    request: HDVideoRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Generate an HD video from a prompt via ``generate_hd_video_payload``.

    Raises:
        HTTPException: re-raised unchanged when raised by authentication or
            the payload helper; any other error is logged with its traceback
            and reported as a 500.
    """
    try:
        user_id = require_authenticated_user(current_user)
        return generate_hd_video_payload(request, user_id)
    except HTTPException:
        raise
    except Exception as exc:
        # loguru has no `exc_info` flag; as a keyword it logged no traceback
        # and triggered str.format() on the message, which raises when the
        # exception text contains braces. logger.exception() avoids both.
        logger.exception(f"[StoryWriter] Failed to generate HD video: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
    request: HDVideoSceneRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Generate an HD video for one scene via ``generate_hd_video_scene_payload``.

    Raises:
        HTTPException: re-raised unchanged when raised by authentication or
            the payload helper; any other error is logged with its traceback
            and reported as a 500.
    """
    try:
        user_id = require_authenticated_user(current_user)
        return generate_hd_video_scene_payload(request, user_id)
    except HTTPException:
        raise
    except Exception as exc:
        # loguru has no `exc_info` flag; as a keyword it logged no traceback
        # and triggered str.format() on the message, which raises when the
        # exception text contains braces. logger.exception() avoids both.
        logger.exception(f"[StoryWriter] Failed to generate HD video for scene: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
@router.post("/generate-complete-video", response_model=Dict[str, Any])
async def generate_complete_story_video(
    request: StoryGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Kick off the complete story-to-video workflow in the background.

    Creates a task, schedules ``execute_complete_video_generation`` with the
    request data and returns the task id with a pending status.

    Raises:
        HTTPException: 500 for any unexpected failure while scheduling.
    """
    try:
        user_id = require_authenticated_user(current_user)
        logger.info(f"[StoryWriter] Starting complete video generation for user {user_id}")

        new_task_id = task_manager.create_task("complete_video_generation")
        background_tasks.add_task(
            execute_complete_video_generation,
            task_id=new_task_id,
            request_data=request.dict(),
            user_id=user_id,
        )

        acknowledgement: Dict[str, Any] = {
            "task_id": new_task_id,
            "status": "pending",
            "message": "Complete video generation started",
        }
        return acknowledgement
    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to start complete video generation: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
def execute_complete_video_generation(
    task_id: str,
    request_data: Dict[str, Any],
    user_id: str,
):
    """
    Execute complete video generation workflow synchronously.
    Runs in a background task and performs blocking operations.

    Steps and their share of the task progress: premise (10%), structured
    outline (20%), scene images (30-50%), scene audio (50-70%), asset
    matching (70%), video composition (75-95%), done (100%). Only scenes
    that have both an error-free image and an error-free audio result are
    used for the video. Failures are recorded on the task (status "failed")
    rather than raised.

    Args:
        task_id: Task to report progress and the final result on.
        request_data: Story generation request as a dict; the ten story
            parameter keys listed below are required, the image/audio/video
            option keys are optional.
        user_id: User the service calls are made for.
    """
    # Story parameters shared by premise and outline generation.
    story_param_keys = (
        "persona",
        "story_setting",
        "character_input",
        "plot_elements",
        "writing_style",
        "story_tone",
        "narrative_pov",
        "audience_age_group",
        "content_rating",
        "ending_preference",
    )

    def stage_progress(stage_start: float):
        """Build a callback mapping a stage's own progress onto a 20-point band."""

        def report(sub_progress: float, message: str):
            overall_progress = stage_start + (sub_progress * 0.2)
            task_manager.update_task_status(task_id, "processing", progress=overall_progress, message=message)

        return report

    try:
        task_manager.update_task_status(task_id, "processing", progress=5.0, message="Starting complete video generation...")
        task_manager.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")

        story_params = {key: request_data[key] for key in story_param_keys}

        premise = story_service.generate_premise(user_id=user_id, **story_params)

        task_manager.update_task_status(task_id, "processing", progress=20.0, message="Generating structured outline with scenes...")
        outline_scenes = story_service.generate_outline(
            premise=premise,
            user_id=user_id,
            use_structured_output=True,
            **story_params,
        )
        if not isinstance(outline_scenes, list):
            raise RuntimeError("Failed to generate structured outline")

        task_manager.update_task_status(task_id, "processing", progress=30.0, message="Generating images for scenes...")
        image_results = image_service.generate_scene_images(
            scenes=outline_scenes,
            user_id=user_id,
            provider=request_data.get("image_provider"),
            width=request_data.get("image_width", 1024),
            height=request_data.get("image_height", 1024),
            model=request_data.get("image_model"),
            progress_callback=stage_progress(30.0),
        )

        task_manager.update_task_status(task_id, "processing", progress=50.0, message="Generating audio narration for scenes...")
        audio_results = audio_service.generate_scene_audio_list(
            scenes=outline_scenes,
            user_id=user_id,
            provider=request_data.get("audio_provider", "gtts"),
            lang=request_data.get("audio_lang", "en"),
            slow=request_data.get("audio_slow", False),
            rate=request_data.get("audio_rate", 150),
            progress_callback=stage_progress(50.0),
        )

        task_manager.update_task_status(task_id, "processing", progress=70.0, message="Preparing video assets...")
        image_paths: List[str] = []
        audio_paths: List[str] = []
        valid_scenes: List[Dict[str, Any]] = []

        for scene in outline_scenes:
            scene_number = scene.get("scene_number", 0)
            # First result with a matching scene number, if any.
            image_result = next((img for img in image_results if img.get("scene_number") == scene_number), None)
            audio_result = next((aud for aud in audio_results if aud.get("scene_number") == scene_number), None)

            if not image_result or not audio_result:
                continue
            if image_result.get("error") or audio_result.get("error"):
                continue

            image_path = image_result.get("image_path")
            audio_path = audio_result.get("audio_path")
            if image_path and audio_path:
                # Appended together so the three lists stay index-aligned.
                image_paths.append(image_path)
                audio_paths.append(audio_path)
                valid_scenes.append(scene)

        if len(image_paths) == 0 or len(audio_paths) == 0:
            raise RuntimeError(
                f"No valid images or audio files were generated. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
            )
        if len(image_paths) != len(audio_paths):
            raise RuntimeError(
                f"Mismatch between image and audio counts. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
            )

        task_manager.update_task_status(task_id, "processing", progress=75.0, message="Composing video from scenes...")
        # The key is always present (it is required above), so a `.get`
        # default never applies; `or` also covers a None or empty setting,
        # which previously crashed on slicing after all assets were built.
        story_title = (request_data.get("story_setting") or "Story")[:50]
        video_result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=story_title,
            fps=request_data.get("video_fps", 24),
            transition_duration=request_data.get("video_transition_duration", 0.5),
            progress_callback=stage_progress(75.0),
        )

        result = {
            "premise": premise,
            "outline_scenes": outline_scenes,
            "images": image_results,
            "audio_files": audio_results,
            "video": video_result,
            "success": True,
        }
        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="Complete video generation finished!",
            result=result,
        )
        logger.info(f"[StoryWriter] Complete video generation task {task_id} completed successfully")
    except Exception as exc:
        error_msg = str(exc)
        # loguru has no `exc_info` flag: as a keyword it logged no traceback
        # and made loguru run str.format() on the message, which raises when
        # the error text contains braces - so the task was never marked
        # failed. logger.exception() attaches the traceback and takes the
        # message as-is.
        logger.exception(f"[StoryWriter] Complete video generation task {task_id} failed: {error_msg}")
        task_manager.update_task_status(
            task_id,
            "failed",
            error=error_msg,
            message=f"Complete video generation failed: {error_msg}",
        )
@router.get("/videos/{video_filename}")
async def serve_story_video(
    video_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """
    Return a generated story video as an MP4 file download.

    The caller must pass `require_authenticated_user`; the file is looked up
    under `video_service.output_dir` through `resolve_media_file`.

    Raises:
        HTTPException: re-raised unchanged when authentication or file
            resolution raises one; any other failure is logged and
            reported as a 500 carrying the error text.
    """
    try:
        require_authenticated_user(current_user)
        media_root = video_service.output_dir
        resolved_video = resolve_media_file(media_root, video_filename)
        response = FileResponse(
            path=str(resolved_video),
            media_type="video/mp4",
            filename=video_filename,
        )
    except Exception as exc:
        # HTTP errors already carry the right status code; pass them through.
        if isinstance(exc, HTTPException):
            raise
        logger.error(f"[StoryWriter] Failed to serve video: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
    return response

View File

@@ -1,13 +1,11 @@
from __future__ import annotations
from typing import Any, Dict, Optional
from typing import Any, Dict
from fastapi import HTTPException
from loguru import logger
from uuid import uuid4
from .media_utils import load_story_image_bytes
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""Handles synchronous HD video generation."""
@@ -57,8 +55,8 @@ def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
"""
Handles per-scene HD video generation including prompt enhancement,
subscription validation, and optional image conditioning.
Handles per-scene HD video generation including prompt enhancement
and subscription validation.
"""
from services.database import get_db as get_db_validation
from services.onboarding.api_key_manager import APIKeyManager
@@ -71,7 +69,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
scene_number = request.scene_number
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
# Step 1: Validate API key
hf_token = APIKeyManager().get_api_key("hf_token")
if not hf_token:
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
@@ -83,7 +80,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
},
)
# Step 2: Subscription limits
db_validation = next(get_db_validation())
try:
pricing_service = PricingService(db_validation)
@@ -93,7 +89,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
finally:
db_validation.close()
# Stage 1: Prompt enhancement
enhanced_prompt = enhance_scene_prompt_for_video(
current_scene=request.scene_data,
story_context=request.story_context,
@@ -102,15 +97,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
)
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
# Stage 2: Optional image reference
scene_image_bytes: Optional[bytes] = None
if getattr(request, "scene_image_url", None):
scene_image_bytes = load_story_image_bytes(request.scene_image_url)
if scene_image_bytes:
logger.info(f"[StoryWriter] Using scene image reference for scene {scene_number}")
else:
logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_number}, falling back to text-only video")
kwargs: Dict[str, Any] = {}
if getattr(request, "model", None):
kwargs["model"] = request.model
@@ -129,7 +115,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
prompt=enhanced_prompt,
provider=getattr(request, "provider", None) or "huggingface",
user_id=user_id,
input_image_bytes=scene_image_bytes,
**kwargs,
)
@@ -151,4 +136,3 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
}

View File

@@ -11,6 +11,8 @@ from loguru import logger
# Media directories live directly under the backend root and are resolved to
# absolute paths once, at import time.
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
# Created eagerly (import-time side effect) so the directory exists before first use.
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
STORY_AUDIO_DIR = (BASE_DIR / "story_audio").resolve()
# Same eager creation for the audio directory.
STORY_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
@@ -48,6 +50,41 @@ def load_story_image_bytes(image_url: str) -> Optional[bytes]:
return None
def load_story_audio_bytes(audio_url: str) -> Optional[bytes]:
    """
    Resolve an authenticated story audio URL (e.g., /api/story/audio/<file>) to raw bytes.

    Args:
        audio_url: Absolute URL or bare path containing the
            ``/api/story/audio/`` prefix. Any query string is ignored.

    Returns:
        The file contents, or None when the URL is empty or unsupported, the
        resolved path falls outside ``STORY_AUDIO_DIR``, the file is missing,
        or reading fails. This function never raises.
    """
    if not audio_url:
        return None

    try:
        parsed = urlparse(audio_url)
        path = parsed.path if parsed.scheme else audio_url
        prefix = "/api/story/audio/"
        if prefix not in path:
            logger.warning(f"[StoryWriter] Unsupported audio URL for video reference: {audio_url}")
            return None

        filename = path.split(prefix, 1)[1].split("?", 1)[0].strip()
        if not filename:
            return None

        file_path = (STORY_AUDIO_DIR / filename).resolve()

        # Containment must be checked on path components, not on the string
        # form: a plain startswith() would accept a sibling directory such as
        # ".../story_audio_backup/x" because it shares the textual prefix.
        try:
            file_path.relative_to(STORY_AUDIO_DIR)
        except ValueError:
            logger.error(f"[StoryWriter] Attempted path traversal when resolving audio: {audio_url}")
            return None

        # is_file() also rejects directories (e.g. a filename of "."), which
        # exists() would let through to read_bytes().
        if not file_path.is_file():
            logger.warning(f"[StoryWriter] Referenced scene audio not found on disk: {file_path}")
            return None

        return file_path.read_bytes()
    except Exception as exc:
        # Best-effort loader: callers treat None as "no audio reference".
        logger.error(f"[StoryWriter] Failed to load reference audio for video gen: {exc}")
        return None
def resolve_media_file(base_dir: Path, filename: str) -> Path:
"""
Returns a safe resolved path for a media file stored under base_dir.
@@ -62,8 +99,50 @@ def resolve_media_file(base_dir: Path, filename: str) -> Path:
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
if not resolved.exists():
alternate = _find_alternate_media_file(base_dir, filename)
if alternate:
logger.warning(
"[StoryWriter] Requested media file '%s' missing; serving closest match '%s'",
filename,
alternate.name,
)
return alternate
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
return resolved
def _find_alternate_media_file(base_dir: Path, filename: str) -> Optional[Path]:
    """
    Attempt to find the most recent media file that matches the original name prefix.

    This helps when files are regenerated with new UUID/hash suffixes but the frontend still
    references an older filename. The prefix is everything before the last underscore of the
    file stem; candidates are files in ``base_dir`` named ``<prefix>_*<suffix>``.

    Args:
        base_dir: Directory to search (non-recursive).
        filename: The requested file name that could not be found.

    Returns:
        The matching file with the newest modification time, or None when the name has no
        extension, no underscore, an empty prefix, nothing matches, or the search fails.
    """
    import glob  # local import: only needed to escape the pattern

    try:
        base_dir = base_dir.resolve()
    except Exception:
        return None

    requested = Path(filename)
    stem, suffix = requested.stem, requested.suffix
    if not suffix or "_" not in stem:
        return None

    prefix = stem.rsplit("_", 1)[0]
    if not prefix:
        # A stem like "_abc" would otherwise match every "_*" file in the directory.
        return None

    # The filename comes from the request, so neutralise glob metacharacters
    # ("*", "?", "[") to keep the match limited to the literal prefix/suffix.
    pattern = f"{glob.escape(prefix)}_*{glob.escape(suffix)}"

    try:
        # max() picks the newest file directly; no need to sort every candidate.
        return max(
            (p for p in base_dir.glob(pattern) if p.is_file()),
            key=lambda p: p.stat().st_mtime,
            default=None,
        )
    except Exception as exc:
        logger.debug(f"[StoryWriter] Failed to search alternate media files for {filename}: {exc}")
        return None

View File

@@ -4,6 +4,7 @@ Provides endpoints for subscription management and usage monitoring.
"""
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from sqlalchemy import desc, func
from typing import Dict, Any, Optional, List
@@ -116,6 +117,7 @@ async def get_subscription_plans(
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
"gemini_tokens": plan.gemini_tokens_limit,
"openai_tokens": plan.openai_tokens_limit,
"anthropic_tokens": plan.anthropic_tokens_limit,
@@ -134,7 +136,7 @@ async def get_subscription_plans(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
@@ -241,6 +243,7 @@ async def get_user_subscription(
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -340,6 +343,7 @@ async def get_subscription_status(
"stability_calls": free_plan.stability_calls_limit,
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
"monthly_cost": free_plan.monthly_cost_limit
}
}
@@ -405,7 +409,7 @@ async def get_subscription_status(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
# Try to fix schema and retry once
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
try:
@@ -499,6 +503,7 @@ async def get_subscription_status(
"stability_calls": plan.stability_calls_limit,
"video_calls": getattr(plan, 'video_calls_limit', 0),
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
"monthly_cost": plan.monthly_cost_limit
}
}
@@ -988,7 +993,7 @@ async def get_dashboard_data(
except (sqlite3.OperationalError, Exception) as e:
error_str = str(e).lower()
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str or 'audio_calls' in error_str or 'audio_cost' in error_str):
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
try:
import services.subscription.schema_utils as schema_utils
@@ -1271,4 +1276,235 @@ async def get_usage_logs(
raise
except Exception as e:
logger.error(f"Error getting usage logs: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
class PreflightOperationRequest(BaseModel):
    """Request model for pre-flight check operation."""
    # Provider key; preflight_check lower-cases it and maps it to an APIProvider member.
    provider: str
    # Model name used for the pricing lookup; optional.
    model: Optional[str] = None
    # Size of the request used for token/character based cost estimates.
    tokens_requested: Optional[int] = 0
    # Caller-defined label, echoed back in the per-operation result.
    operation_type: str
    # Display name reported back for the provider; falls back to `provider` when omitted.
    actual_provider_name: Optional[str] = None
class PreflightCheckRequest(BaseModel):
    """Request model for pre-flight check."""
    # Operations to validate and price together in a single pre-flight call.
    operations: List[PreflightOperationRequest]
def _resolve_preflight_provider(provider_name: str) -> Optional[Any]:
    """Map a client-supplied provider string to an APIProvider member, or None if unknown."""
    provider_str = provider_name.lower()
    if provider_str == "huggingface":
        return APIProvider.MISTRAL  # Maps to HuggingFace
    if provider_str == "video":
        return APIProvider.VIDEO
    if provider_str == "image_edit":
        return APIProvider.IMAGE_EDIT
    if provider_str == "stability":
        return APIProvider.STABILITY
    if provider_str == "audio":
        return APIProvider.AUDIO
    try:
        return APIProvider(provider_str)
    except ValueError:
        logger.warning(f"Unknown provider: {provider_str}, skipping")
        return None


def _estimate_preflight_cost(
    pricing_service: Any,
    provider: Any,
    model_name: Optional[str],
    tokens_requested: int,
) -> float:
    """Estimate the (unrounded) cost of one operation; 0.0 when no model was supplied."""
    if not model_name:
        return 0.0

    pricing_info = pricing_service.get_pricing_for_provider_model(provider, model_name)
    if pricing_info:
        if provider in (APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY):
            # Flat per-call pricing.
            return pricing_info.get('cost_per_request', 0.0) or pricing_info.get('cost_per_image', 0.0) or 0.0
        if provider == APIProvider.AUDIO or tokens_requested > 0:
            # Audio is billed per character (one character counts as one token);
            # other providers get a rough token-based estimate.
            return (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (tokens_requested / 1000.0)
        return pricing_info.get('cost_per_request', 0.0) or 0.0

    # Pricing not found for this model: fall back to built-in defaults.
    if provider == APIProvider.VIDEO:
        return 0.10  # Default video cost
    if provider == APIProvider.IMAGE_EDIT:
        return 0.05  # Default image edit cost
    if provider == APIProvider.STABILITY:
        return 0.04  # Default image generation cost
    if provider == APIProvider.AUDIO:
        # Default audio cost: $0.05 per 1,000 characters
        return (tokens_requested / 1000.0) * 0.05
    return 0.0


def _remaining_quota(limit: int, current: int) -> Optional[int]:
    """Return the remaining quota, or None when the limit is 0 (treated as unlimited)."""
    # None instead of float('inf'): infinity is not valid JSON and cannot be
    # serialised in the HTTP response.
    return max(0, limit - current) if limit > 0 else None


def _build_preflight_limit_info(provider: Any, usage_summary: Any, limits: Dict[str, Any]) -> Dict[str, Any]:
    """Build the current-usage / limit / remaining dict for one provider."""
    if provider == APIProvider.VIDEO:
        key = 'video_calls'
    elif provider == APIProvider.IMAGE_EDIT:
        key = 'image_edit_calls'
    elif provider == APIProvider.STABILITY:
        key = 'stability_calls'
    elif provider == APIProvider.AUDIO:
        key = 'audio_calls'
    else:
        # For LLM providers, use token limits
        key = f"{provider.value}_tokens"

    current = getattr(usage_summary, key, 0) or 0
    limit = limits['limits'].get(key, 0) or 0
    return {
        'current_usage': current,
        'limit': limit,
        'remaining': _remaining_quota(limit, current),
    }


@router.post("/preflight-check")
async def preflight_check(
    request: PreflightCheckRequest,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Pre-flight check for operations with cost estimation.

    Lightweight endpoint that:
    - Validates if operations are allowed based on subscription limits
    - Estimates cost for operations
    - Returns usage information and remaining quota

    Operations with an unknown provider are skipped. ``remaining`` is ``None``
    when the corresponding limit is 0 (unlimited). The ``cached`` flag in the
    response is currently always ``False``.

    Raises:
        HTTPException: 401 when the token carries no user id, 400 when no
            operation could be mapped to a known provider, 500 on any other
            failure.
    """
    try:
        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        # Ensure schema columns exist
        try:
            ensure_subscription_plan_columns(db)
            ensure_usage_summaries_columns(db)
        except Exception as schema_err:
            logger.warning(f"Schema check failed: {schema_err}")

        pricing_service = PricingService(db)

        # Convert request operations to internal format. The requested model is
        # kept in a parallel list so it stays aligned with its operation even
        # when earlier operations are skipped.
        operations_to_validate: List[Dict[str, Any]] = []
        requested_models: List[Optional[str]] = []
        for op in request.operations:
            try:
                provider_enum = _resolve_preflight_provider(op.provider)
                if provider_enum is None:
                    continue

                operations_to_validate.append({
                    'provider': provider_enum,
                    'tokens_requested': op.tokens_requested or 0,
                    'actual_provider_name': op.actual_provider_name or op.provider,
                    'operation_type': op.operation_type
                })
                requested_models.append(op.model)
            except Exception as e:
                logger.warning(f"Error processing operation {op.operation_type}: {e}")
                continue

        if not operations_to_validate:
            raise HTTPException(status_code=400, detail="No valid operations provided")

        # Perform pre-flight validation
        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        # Limits and the current usage row are the same for every operation:
        # fetch them once instead of once per operation.
        limits = pricing_service.get_user_limits(user_id)
        usage_summary = None
        if limits:
            usage_summary = db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
            ).first()

        # Get pricing and cost estimation for each operation
        operation_results = []
        total_cost = 0.0

        for op, model_name in zip(operations_to_validate, requested_models):
            cost = _estimate_preflight_cost(
                pricing_service,
                op['provider'],
                model_name,
                op['tokens_requested'],
            )
            total_cost += cost

            op_result = {
                'provider': op['actual_provider_name'],
                'operation_type': op['operation_type'],
                'cost': round(cost, 4),
                'allowed': can_proceed,
                'limit_info': None,
                'message': None
            }

            # Get limit information
            if error_details and not can_proceed:
                # Blocked: report the usage figures that caused the rejection.
                usage_info = error_details.get('usage_info', {})
                if usage_info:
                    op_result['message'] = message
                    op_result['limit_info'] = {
                        'current_usage': usage_info.get('current_usage', 0),
                        'limit': usage_info.get('limit', 0),
                        'remaining': max(0, usage_info.get('limit', 0) - usage_info.get('current_usage', 0))
                    }
            elif limits and usage_summary:
                # Allowed: report current usage for this provider.
                op_result['limit_info'] = _build_preflight_limit_info(
                    op['provider'], usage_summary, limits
                )

            operation_results.append(op_result)

        response_data = {
            'can_proceed': can_proceed,
            'estimated_cost': round(total_cost, 4),
            'operations': operation_results,
            'total_cost': round(total_cost, 4),
            'usage_summary': None,
            'cached': False  # TODO: Track if result was cached
        }

        if usage_summary and limits:
            # For video generation, show video limits
            video_current = getattr(usage_summary, 'video_calls', 0) or 0
            video_limit = limits['limits'].get('video_calls', 0) or 0
            response_data['usage_summary'] = {
                'current_calls': video_current,
                'limit': video_limit,
                'remaining': _remaining_quota(video_limit, video_current)
            }

        return {
            "success": True,
            "data": response_data
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in pre-flight check: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")