AI story writer enhancements, text to video and voice generation, subscription management, and more.
This commit is contained in:
@@ -134,6 +134,12 @@ def generate(
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.generate] ✅ Successfully tracked usage: user {user_id} -> stability -> {new_calls} calls")
|
||||
|
||||
@@ -148,6 +154,7 @@ def generate(
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
@@ -437,6 +444,12 @@ def edit(
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[images.edit] ✅ Successfully tracked usage: user {user_id} -> image_edit -> {new_calls} calls")
|
||||
|
||||
@@ -451,6 +464,7 @@ def edit(
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Images: {current_image_gen_calls} / {image_gen_limit if image_gen_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
|
||||
@@ -5,12 +5,19 @@ Main router for story generation operations including premise, outline,
|
||||
content generation, and full story creation.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, BackgroundTasks
|
||||
from typing import Any, Dict, Union, List, Optional
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request
|
||||
from loguru import logger
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
|
||||
from models.story_models import (
|
||||
AnimateSceneRequest,
|
||||
AnimateSceneVoiceoverRequest,
|
||||
AnimateSceneResponse,
|
||||
ResumeSceneAnimationRequest,
|
||||
StoryGenerationRequest,
|
||||
StorySetupGenerationRequest,
|
||||
StorySetupGenerationResponse,
|
||||
@@ -34,24 +41,66 @@ from models.story_models import (
|
||||
StoryVideoResult,
|
||||
TaskStatus,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
from services.database import get_db
|
||||
from services.llm_providers.main_video_generation import track_video_usage
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
from .task_manager import task_manager
|
||||
from .cache_manager import cache_manager
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_scene_animation_operation
|
||||
from services.wavespeed.kling_animation import animate_scene_image, resume_scene_animation
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
from uuid import uuid4
|
||||
from pydantic import BaseModel
|
||||
from pathlib import Path
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from .cache_manager import cache_manager
|
||||
from .routes import cache_routes, media_generation, story_content, story_setup, story_tasks, video_generation
|
||||
from .task_manager import task_manager
|
||||
from .utils.auth import require_authenticated_user
|
||||
from .utils.media_utils import resolve_media_file
|
||||
from .utils.hd_video import (
|
||||
generate_hd_video_payload,
|
||||
generate_hd_video_scene_payload,
|
||||
)
|
||||
from .utils.hd_video import generate_hd_video_payload, generate_hd_video_scene_payload
|
||||
from .utils.media_utils import load_story_image_bytes, load_story_audio_bytes, resolve_media_file
|
||||
from urllib.parse import quote
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/story", tags=["Story Writer"])
|
||||
|
||||
# Include modular routers (order preserved roughly by workflow)
|
||||
router.include_router(story_setup.router)
|
||||
router.include_router(story_content.router)
|
||||
router.include_router(story_tasks.router)
|
||||
router.include_router(media_generation.router)
|
||||
router.include_router(video_generation.router)
|
||||
router.include_router(cache_routes.router)
|
||||
|
||||
service = StoryWriterService()
|
||||
scene_logger = get_service_logger("api.story_writer.scene_animation")
|
||||
AI_VIDEO_SUBDIR = Path("AI_Videos")
|
||||
|
||||
|
||||
def _build_authenticated_media_url(request: Request, path: str) -> str:
|
||||
"""Append the caller's auth token to a media URL so <video>/<img> tags can access it."""
|
||||
if not path:
|
||||
return path
|
||||
|
||||
token: Optional[str] = None
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.replace("Bearer ", "").strip()
|
||||
elif "token" in request.query_params:
|
||||
token = request.query_params["token"]
|
||||
|
||||
if token:
|
||||
separator = "&" if "?" in path else "?"
|
||||
path = f"{path}{separator}token={quote(token)}"
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def _guess_mime_from_url(url: str, fallback: str) -> str:
|
||||
if not url:
|
||||
return fallback
|
||||
mime, _ = mimetypes.guess_type(url)
|
||||
return mime or fallback
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
@@ -558,6 +607,22 @@ async def get_task_result(
|
||||
logger.error(f"[StoryWriter] Failed to get task result: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
class PromptOptimizeRequest(BaseModel):
    """Request payload for the prompt-optimization endpoint."""

    # Prompt text to optimize; required.
    text: str = Field(
        ...,
        description="The prompt text to optimize",
    )
    # Target medium; the pattern restricts it to "image" or "video".
    mode: Optional[str] = Field(
        default="image",
        pattern="^(image|video)$",
        description="Optimization mode: 'image' or 'video'",
    )
    # Stylistic preset; the pattern restricts it to the listed names.
    style: Optional[str] = Field(
        default="default",
        pattern="^(default|artistic|photographic|technical|anime|realistic)$",
        description="Style: 'default', 'artistic', 'photographic', 'technical', 'anime', or 'realistic'",
    )
    # Optional base64-encoded reference image.
    image: Optional[str] = Field(
        default=None,
        description="Base64-encoded image for context (optional)",
    )
|
||||
|
||||
|
||||
class PromptOptimizeResponse(BaseModel):
    """Response payload for the prompt-optimization endpoint."""

    # Prompt text produced by the optimizer.
    optimized_prompt: str
    # Success flag for the optimization call.
    success: bool
|
||||
|
||||
|
||||
class HDVideoRequest(BaseModel):
|
||||
prompt: str
|
||||
provider: str = "huggingface"
|
||||
@@ -692,6 +757,51 @@ async def generate_scene_images(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/optimize-prompt", response_model=PromptOptimizeResponse)
async def optimize_prompt(
    request: PromptOptimizeRequest,
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> PromptOptimizeResponse:
    """Optimize an image or video prompt using the WaveSpeed prompt optimizer.

    Args:
        request: Prompt text plus optional mode, style and base64 image.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        The optimized prompt wrapped in a ``PromptOptimizeResponse``.

    Raises:
        HTTPException: 401 when the caller is missing or has no user id,
            400 when the prompt text is blank, 500 for any other failure.
    """
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        prompt_text = request.text.strip() if request.text else ""
        if not prompt_text:
            raise HTTPException(status_code=400, detail="Prompt text is required")

        logger.info(f"[StoryWriter] Optimizing prompt for user {user_id} (mode={request.mode}, style={request.style})")

        from fastapi.concurrency import run_in_threadpool
        from services.wavespeed.client import WaveSpeedClient

        client = WaveSpeedClient()
        # The optimizer is invoked synchronously (sync mode, 30s timeout).
        # Run it in the worker threadpool so this async handler does not
        # block the event loop for other requests while it waits.
        optimized_prompt = await run_in_threadpool(
            client.optimize_prompt,
            text=prompt_text,
            mode=request.mode or "image",
            style=request.style or "default",
            image=request.image,  # Optional base64 image
            enable_sync_mode=True,
            timeout=30,
        )

        logger.info(f"[StoryWriter] Prompt optimized successfully for user {user_id}")

        return PromptOptimizeResponse(
            optimized_prompt=optimized_prompt,
            success=True
        )

    except HTTPException:
        # Keep the status codes raised above instead of masking them as 500.
        raise
    except Exception as e:
        logger.error(f"[StoryWriter] Failed to optimize prompt: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/images/{image_filename}")
|
||||
async def serve_scene_image(
|
||||
image_filename: str,
|
||||
@@ -793,32 +903,376 @@ async def generate_scene_audio(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/audio/{audio_filename}")
|
||||
async def serve_scene_audio(
|
||||
audio_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Serve a generated story scene audio file."""
|
||||
# Audio serving endpoint is handled by routes/media_generation.py
|
||||
# No duplicate endpoint needed here
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Scene Animation Endpoints
|
||||
# ---------------------------
|
||||
|
||||
|
||||
@router.post("/animate-scene-preview", response_model=AnimateSceneResponse)
|
||||
async def animate_scene_preview(
|
||||
request_obj: Request,
|
||||
request: AnimateSceneRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimateSceneResponse:
|
||||
"""
|
||||
Animate a single scene image using WaveSpeed Kling v2.5 Turbo Std.
|
||||
"""
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", ""))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
duration = request.duration or 5
|
||||
if duration not in (5, 10):
|
||||
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds.")
|
||||
|
||||
scene_logger.info(
|
||||
"[AnimateScene] User=%s scene=%s duration=%s image_url=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
duration,
|
||||
request.image_url,
|
||||
)
|
||||
|
||||
image_bytes = load_story_image_bytes(request.image_url)
|
||||
if not image_bytes:
|
||||
scene_logger.warning("[AnimateScene] Missing image bytes for user=%s scene=%s", user_id, request.scene_number)
|
||||
raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
animation_result = animate_scene_image(
|
||||
image_bytes=image_bytes,
|
||||
scene_data=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
user_id=user_id,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
audio_service = StoryAudioGenerationService()
|
||||
audio_path = resolve_media_file(audio_service.output_dir, audio_filename)
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
|
||||
ai_video_dir.mkdir(parents=True, exist_ok=True)
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
|
||||
return FileResponse(
|
||||
path=str(audio_path),
|
||||
media_type="audio/mpeg",
|
||||
filename=audio_filename
|
||||
save_result = video_service.save_scene_video(
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
scene_number=request.scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
video_filename = save_result["video_filename"]
|
||||
video_url = _build_authenticated_media_url(
|
||||
request_obj, f"/api/story/videos/ai/{video_filename}"
|
||||
)
|
||||
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=animation_result["provider"],
|
||||
model_name=animation_result["model_name"],
|
||||
prompt=animation_result["prompt"],
|
||||
video_bytes=animation_result["video_bytes"],
|
||||
cost_override=animation_result["cost"],
|
||||
)
|
||||
if usage_info:
|
||||
scene_logger.warning(
|
||||
"[AnimateScene] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
|
||||
user_id,
|
||||
usage_info.get("previous_calls"),
|
||||
usage_info.get("current_calls"),
|
||||
usage_info.get("video_limit_display"),
|
||||
usage_info.get("cost_per_video", 0.0),
|
||||
usage_info.get("total_video_cost", 0.0),
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to serve audio: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
scene_logger.info(
|
||||
"[AnimateScene] ✅ Completed user=%s scene=%s duration=%s cost=$%.2f video=%s",
|
||||
user_id,
|
||||
request.scene_number,
|
||||
animation_result["duration"],
|
||||
animation_result["cost"],
|
||||
video_url,
|
||||
)
|
||||
|
||||
return AnimateSceneResponse(
|
||||
success=True,
|
||||
scene_number=request.scene_number,
|
||||
video_filename=video_filename,
|
||||
video_url=video_url,
|
||||
duration=animation_result["duration"],
|
||||
cost=animation_result["cost"],
|
||||
prompt_used=animation_result["prompt"],
|
||||
provider=animation_result["provider"],
|
||||
prediction_id=animation_result.get("prediction_id"),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/animate-scene-resume", response_model=AnimateSceneResponse)
async def resume_scene_animation_endpoint(
    request_obj: Request,
    request: ResumeSceneAnimationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimateSceneResponse:
    """Resume downloading a WaveSpeed animation when the initial call timed out.

    Fetches the result for ``request.prediction_id``, saves the video bytes
    under the AI video directory, records video usage and returns the
    authenticated URL of the saved file.

    Args:
        request_obj: Raw request, used to forward the caller's token into
            the returned video URL.
        request: Scene number, prediction id and optional duration.
        current_user: Authenticated user injected by ``get_current_user``.

    Raises:
        HTTPException: 401 when the caller is missing or has no user id.
    """
    if not current_user:
        raise HTTPException(status_code=401, detail="Authentication required")

    user_id = str(current_user.get("id", ""))
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

    scene_logger.info(
        "[AnimateScene] Resume requested user=%s scene=%s prediction=%s",
        user_id,
        request.scene_number,
        request.prediction_id,
    )

    from fastapi.concurrency import run_in_threadpool

    # resume_scene_animation is invoked synchronously; run it in the worker
    # threadpool so the event loop stays free while the result is fetched.
    animation_result = await run_in_threadpool(
        resume_scene_animation,
        prediction_id=request.prediction_id,
        duration=request.duration or 5,
        user_id=user_id,
    )

    base_dir = Path(__file__).parent.parent.parent
    ai_video_dir = base_dir / "story_videos" / AI_VIDEO_SUBDIR
    ai_video_dir.mkdir(parents=True, exist_ok=True)
    video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))

    save_result = video_service.save_scene_video(
        video_bytes=animation_result["video_bytes"],
        scene_number=request.scene_number,
        user_id=user_id,
    )
    video_filename = save_result["video_filename"]
    video_url = _build_authenticated_media_url(
        request_obj, f"/api/story/videos/ai/{video_filename}"
    )

    usage_info = track_video_usage(
        user_id=user_id,
        provider=animation_result["provider"],
        model_name=animation_result["model_name"],
        prompt=animation_result["prompt"],
        video_bytes=animation_result["video_bytes"],
        cost_override=animation_result["cost"],
    )
    if usage_info:
        scene_logger.warning(
            "[AnimateScene] (Resume) Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
            user_id,
            usage_info.get("previous_calls"),
            usage_info.get("current_calls"),
            usage_info.get("video_limit_display"),
            usage_info.get("cost_per_video", 0.0),
            usage_info.get("total_video_cost", 0.0),
        )

    scene_logger.info(
        "[AnimateScene] ✅ Resume completed user=%s scene=%s prediction=%s video=%s",
        user_id,
        request.scene_number,
        request.prediction_id,
        video_url,
    )

    return AnimateSceneResponse(
        success=True,
        scene_number=request.scene_number,
        video_filename=video_filename,
        video_url=video_url,
        duration=animation_result["duration"],
        cost=animation_result["cost"],
        prompt_used=animation_result["prompt"],
        provider=animation_result["provider"],
        prediction_id=animation_result.get("prediction_id"),
    )
|
||||
|
||||
|
||||
@router.post("/animate-scene-voiceover", response_model=Dict[str, Any])
async def animate_scene_voiceover_endpoint(
    request_obj: Request,
    request: AnimateSceneVoiceoverRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Animate a scene using WaveSpeed InfiniteTalk (image + audio) asynchronously.
    Returns task_id for polling since InfiniteTalk can take up to 10 minutes.

    Args:
        request_obj: Raw request, used to read the caller's bearer token.
        request: Scene data plus the image and audio URLs to animate.
        background_tasks: FastAPI background task queue that runs the job.
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        A dict with ``task_id``, ``status`` ("pending") and ``message``.

    Raises:
        HTTPException: 401 when the caller is missing or has no user id,
            404 when the scene image or audio cannot be loaded.
    """
    if not current_user:
        raise HTTPException(status_code=401, detail="Authentication required")

    user_id = str(current_user.get("id", ""))
    if not user_id:
        raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

    scene_logger.info(
        "[AnimateSceneVoiceover] User=%s scene=%s resolution=%s (async)",
        user_id,
        request.scene_number,
        request.resolution or "720p",
    )

    image_bytes = load_story_image_bytes(request.image_url)
    if not image_bytes:
        raise HTTPException(status_code=404, detail="Scene image not found. Generate images first.")

    audio_bytes = load_story_audio_bytes(request.audio_url)
    if not audio_bytes:
        raise HTTPException(status_code=404, detail="Scene audio not found. Generate audio first.")

    # Pre-flight subscription check; the session is always closed afterwards.
    db = next(get_db())
    try:
        pricing_service = PricingService(db)
        validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
    finally:
        db.close()

    # Extract token for authenticated URL building (if needed)
    bearer_prefix = "Bearer "
    auth_token = None
    auth_header = request_obj.headers.get("Authorization")
    if auth_header and auth_header.startswith(bearer_prefix):
        # Strip only the leading scheme; str.replace() would also remove any
        # later "Bearer " sequence inside the credential itself.
        auth_token = auth_header[len(bearer_prefix):].strip()

    # Create async task
    task_id = task_manager.create_task("scene_voiceover_animation")
    background_tasks.add_task(
        _execute_voiceover_animation_task,
        task_id=task_id,
        request=request,
        user_id=user_id,
        image_bytes=image_bytes,
        audio_bytes=audio_bytes,
        auth_token=auth_token,
    )

    return {
        "task_id": task_id,
        "status": "pending",
        "message": "InfiniteTalk animation started. This may take up to 10 minutes.",
    }
|
||||
|
||||
|
||||
def _execute_voiceover_animation_task(
    task_id: str,
    request: AnimateSceneVoiceoverRequest,
    user_id: str,
    image_bytes: bytes,
    audio_bytes: bytes,
    auth_token: Optional[str] = None,
):
    """Background job: build the InfiniteTalk video for a scene and report progress.

    Progress and the final outcome are written to ``task_manager`` under
    ``task_id``. On success the stored result is the dict form of an
    ``AnimateSceneResponse``; on any failure the task is marked "failed".

    Args:
        task_id: Task identifier to update.
        request: Original voiceover animation request.
        user_id: Id of the user the video is generated and billed for.
        image_bytes: Scene image content.
        audio_bytes: Scene audio content.
        auth_token: Optional token appended to the returned video URL.
    """
    try:
        task_manager.update_task_status(
            task_id, "processing", progress=5.0, message="Submitting to WaveSpeed InfiniteTalk..."
        )

        image_mime = _guess_mime_from_url(request.image_url, "image/png")
        audio_mime = _guess_mime_from_url(request.audio_url, "audio/mpeg")
        animation_result = animate_scene_with_voiceover(
            image_bytes=image_bytes,
            audio_bytes=audio_bytes,
            scene_data=request.scene_data,
            story_context=request.story_context,
            user_id=user_id,
            resolution=request.resolution or "720p",
            prompt_override=request.prompt,
            image_mime=image_mime,
            audio_mime=audio_mime,
        )

        task_manager.update_task_status(
            task_id, "processing", progress=80.0, message="Saving video file..."
        )

        # Persist the generated video under <base>/story_videos/AI_Videos.
        ai_video_dir = Path(__file__).parent.parent.parent / "story_videos" / AI_VIDEO_SUBDIR
        ai_video_dir.mkdir(parents=True, exist_ok=True)
        video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
        saved = video_service.save_scene_video(
            video_bytes=animation_result["video_bytes"],
            scene_number=request.scene_number,
            user_id=user_id,
        )
        video_filename = saved["video_filename"]

        # Build authenticated URL if token provided, otherwise return plain URL
        video_url = f"/api/story/videos/ai/{video_filename}"
        if auth_token:
            video_url = f"{video_url}?token={quote(auth_token)}"

        usage_info = track_video_usage(
            user_id=user_id,
            provider=animation_result["provider"],
            model_name=animation_result["model_name"],
            prompt=animation_result["prompt"],
            video_bytes=animation_result["video_bytes"],
            cost_override=animation_result["cost"],
        )
        if usage_info:
            scene_logger.warning(
                "[AnimateSceneVoiceover] Video usage tracked user=%s: %s → %s / %s (cost +$%.2f, total=$%.2f)",
                user_id,
                usage_info.get("previous_calls"),
                usage_info.get("current_calls"),
                usage_info.get("video_limit_display"),
                usage_info.get("cost_per_video", 0.0),
                usage_info.get("total_video_cost", 0.0),
            )

        scene_logger.info(
            "[AnimateSceneVoiceover] ✅ Completed user=%s scene=%s cost=$%.2f video=%s",
            user_id,
            request.scene_number,
            animation_result["cost"],
            video_url,
        )

        result = AnimateSceneResponse(
            success=True,
            scene_number=request.scene_number,
            video_filename=video_filename,
            video_url=video_url,
            duration=animation_result["duration"],
            cost=animation_result["cost"],
            prompt_used=animation_result["prompt"],
            provider=animation_result["provider"],
            prediction_id=animation_result.get("prediction_id"),
        )

        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="InfiniteTalk animation complete!",
            result=result.dict(),
        )
    except HTTPException as exc:
        # The detail may be a plain string or a dict carrying an "error" key.
        detail = exc.detail
        if isinstance(detail, str):
            error_msg = str(detail)
        elif isinstance(detail, dict):
            error_msg = detail.get("error", "Animation failed")
        else:
            error_msg = "Animation failed"
        scene_logger.error(f"[AnimateSceneVoiceover] Failed: {error_msg}")
        task_manager.update_task_status(
            task_id,
            "failed",
            error=error_msg,
            message=f"InfiniteTalk animation failed: {error_msg}",
        )
    except Exception as exc:
        error_msg = str(exc)
        scene_logger.error(f"[AnimateSceneVoiceover] Error: {error_msg}", exc_info=True)
        task_manager.update_task_status(
            task_id,
            "failed",
            error=error_msg,
            message=f"InfiniteTalk animation error: {error_msg}",
        )
|
||||
|
||||
|
||||
# ---------------------------
|
||||
@@ -1260,19 +1714,25 @@ def execute_complete_video_generation(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/videos/{video_filename}")
|
||||
async def serve_story_video(
|
||||
# Regular video serving endpoint is handled by routes/video_generation.py
|
||||
# Only AI videos need a separate endpoint here
|
||||
|
||||
|
||||
@router.get("/videos/ai/{video_filename}")
|
||||
async def serve_ai_story_video(
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""Serve a generated story video file."""
|
||||
"""Serve a generated AI scene animation video."""
|
||||
try:
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
video_service = StoryVideoGenerationService()
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()
|
||||
video_service = StoryVideoGenerationService(output_dir=str(ai_video_dir))
|
||||
video_path = resolve_media_file(video_service.output_dir, video_filename)
|
||||
|
||||
return FileResponse(
|
||||
@@ -1284,7 +1744,7 @@ async def serve_story_video(
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryWriter] Failed to serve video: {e}")
|
||||
logger.error(f"[StoryWriter] Failed to serve AI video: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
||||
21
backend/api/story_writer/routes/__init__.py
Normal file
21
backend/api/story_writer/routes/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Collection of modular routers for Story Writer endpoints.
|
||||
Each module focuses on a related set of routes to keep the primary
|
||||
`router.py` concise and easier to maintain.
|
||||
"""
|
||||
|
||||
from . import story_setup
|
||||
from . import story_content
|
||||
from . import story_tasks
|
||||
from . import media_generation
|
||||
from . import video_generation
|
||||
from . import cache_routes
|
||||
|
||||
__all__ = [
|
||||
"story_setup",
|
||||
"story_content",
|
||||
"story_tasks",
|
||||
"media_generation",
|
||||
"video_generation",
|
||||
"cache_routes",
|
||||
]
|
||||
42
backend/api/story_writer/routes/cache_routes.py
Normal file
42
backend/api/story_writer/routes/cache_routes.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
from ..cache_manager import cache_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/cache/stats")
async def get_cache_stats(
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Get cache statistics.

    Args:
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        ``{"success": True, "stats": <cache_manager.get_cache_stats()>}``.

    Raises:
        HTTPException: Re-raised unchanged if one occurs; any other error is
            reported as a 500 with the exception text as detail.
    """
    try:
        require_authenticated_user(current_user)
        stats = cache_manager.get_cache_stats()
        return {"success": True, "stats": stats}
    except HTTPException:
        # Keep the original status (e.g. an auth failure, if
        # require_authenticated_user raises one) instead of masking it as 500.
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get cache stats: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/cache/clear")
async def clear_cache(
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Clear the story generation cache.

    Args:
        current_user: Authenticated user injected by ``get_current_user``.

    Returns:
        ``{"success": True}`` merged with the mapping returned by
        ``cache_manager.clear_cache()``.

    Raises:
        HTTPException: Re-raised unchanged if one occurs; any other error is
            reported as a 500 with the exception text as detail.
    """
    try:
        require_authenticated_user(current_user)
        result = cache_manager.clear_cache()
        return {"success": True, **result}
    except HTTPException:
        # Keep the original status (e.g. an auth failure, if
        # require_authenticated_user raises one) instead of masking it as 500.
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to clear cache: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
289
backend/api/story_writer/routes/media_generation.py
Normal file
289
backend/api/story_writer/routes/media_generation.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from models.story_models import (
|
||||
StoryImageGenerationRequest,
|
||||
StoryImageGenerationResponse,
|
||||
StoryImageResult,
|
||||
RegenerateImageRequest,
|
||||
RegenerateImageResponse,
|
||||
StoryAudioGenerationRequest,
|
||||
StoryAudioGenerationResponse,
|
||||
StoryAudioResult,
|
||||
GenerateAIAudioRequest,
|
||||
GenerateAIAudioResponse,
|
||||
StoryScene,
|
||||
)
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.media_utils import resolve_media_file
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
image_service = StoryImageGenerationService()
|
||||
audio_service = StoryAudioGenerationService()
|
||||
|
||||
|
||||
@router.post("/generate-images", response_model=StoryImageGenerationResponse)
async def generate_scene_images(
    request: StoryImageGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryImageGenerationResponse:
    """Generate one image per story scene and return the per-scene results.

    Raises:
        HTTPException: 400 when no scenes are supplied; 500 for unexpected
            failures. HTTPExceptions raised inside are passed through.
    """
    try:
        user_id = require_authenticated_user(current_user)

        scenes = request.scenes
        if not scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        logger.info(f"[StoryWriter] Generating images for {len(request.scenes)} scenes for user {user_id}")

        # The service takes plain dicts, so convert StoryScene models first.
        scenes_data = []
        for scene in scenes:
            if isinstance(scene, StoryScene):
                scenes_data.append(scene.dict())
            else:
                scenes_data.append(scene)

        image_results = image_service.generate_scene_images(
            scenes=scenes_data,
            user_id=user_id,
            provider=request.provider,
            width=request.width or 1024,
            height=request.height or 1024,
            model=request.model,
        )

        # Wrap each raw result dict in the response model, filling defaults.
        image_models: List[StoryImageResult] = []
        for result in image_results:
            image_models.append(
                StoryImageResult(
                    scene_number=result.get("scene_number", 0),
                    scene_title=result.get("scene_title", "Untitled"),
                    image_filename=result.get("image_filename", ""),
                    image_url=result.get("image_url", ""),
                    width=result.get("width", 1024),
                    height=result.get("height", 1024),
                    provider=result.get("provider", "unknown"),
                    model=result.get("model"),
                    seed=result.get("seed"),
                    error=result.get("error"),
                )
            )

        return StoryImageGenerationResponse(images=image_models, success=True)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate images: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/regenerate-images", response_model=RegenerateImageResponse)
async def regenerate_scene_image(
    request: RegenerateImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> RegenerateImageResponse:
    """Regenerate a single scene image using a direct prompt (no AI prompt generation)."""
    # Unlike the batch endpoint, unexpected failures here are not turned into
    # a 500: they are reported in-band as a response with success=False and
    # the exception text in `error`. HTTPException still propagates.
    try:
        user_id = require_authenticated_user(current_user)

        prompt_text = request.prompt.strip() if request.prompt else ""
        if not prompt_text:
            raise HTTPException(status_code=400, detail="Prompt is required")

        logger.info(
            f"[StoryWriter] Regenerating image for scene {request.scene_number} "
            f"({request.scene_title}) for user {user_id}"
        )

        requested_width = request.width or 1024
        requested_height = request.height or 1024

        result = image_service.regenerate_scene_image(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            prompt=prompt_text,
            user_id=user_id,
            provider=request.provider,
            width=requested_width,
            height=requested_height,
            model=request.model,
        )

        # Values missing from the service result fall back to what was requested.
        return RegenerateImageResponse(
            scene_number=result.get("scene_number", request.scene_number),
            scene_title=result.get("scene_title", request.scene_title),
            image_filename=result.get("image_filename", ""),
            image_url=result.get("image_url", ""),
            width=result.get("width", requested_width),
            height=result.get("height", requested_height),
            provider=result.get("provider", "unknown"),
            model=result.get("model"),
            seed=result.get("seed"),
            success=True,
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to regenerate image: {exc}")
        failure_fields = dict(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            image_filename="",
            image_url="",
            width=request.width or 1024,
            height=request.height or 1024,
            provider=request.provider or "unknown",
            success=False,
            error=str(exc),
        )
        return RegenerateImageResponse(**failure_fields)
|
||||
|
||||
|
||||
@router.get("/images/{image_filename}")
async def serve_scene_image(
    image_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """Serve a generated story scene image.

    Supports authentication via Authorization header or token query parameter.
    Query parameter is useful for HTML elements like <img> that cannot send custom headers.
    """
    # The filename is resolved against the image service's output directory by
    # resolve_media_file; the file is always sent with an image/png media type.
    try:
        require_authenticated_user(current_user)
        resolved_path = resolve_media_file(image_service.output_dir, image_filename)
        return FileResponse(
            path=str(resolved_path),
            media_type="image/png",
            filename=image_filename,
        )
    except HTTPException:
        # Auth / lookup errors already carry the right status code.
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve image: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-audio", response_model=StoryAudioGenerationResponse)
async def generate_scene_audio(
    request: StoryAudioGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryAudioGenerationResponse:
    """Generate audio narration for story scenes."""
    # Flow: authenticate -> validate -> delegate to audio_service -> wrap every
    # raw result dict in a StoryAudioResult. Unset request options fall back
    # to provider "gtts", language "en", slow=False and rate 150.
    try:
        user_id = require_authenticated_user(current_user)

        scene_inputs = request.scenes
        if not scene_inputs:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        logger.info(f"[StoryWriter] Generating audio for {len(scene_inputs)} scenes for user {user_id}")

        scenes_data = []
        for scene in scene_inputs:
            scenes_data.append(scene.dict() if isinstance(scene, StoryScene) else scene)

        audio_results = audio_service.generate_scene_audio_list(
            scenes=scenes_data,
            user_id=user_id,
            provider=request.provider or "gtts",
            lang=request.lang or "en",
            slow=request.slow or False,
            rate=request.rate or 150,
        )

        # `or ""` (rather than a .get default) also maps explicit None values
        # for filename/url to an empty string.
        audio_models: List[StoryAudioResult] = [
            StoryAudioResult(
                scene_number=result.get("scene_number", 0),
                scene_title=result.get("scene_title", "Untitled"),
                audio_filename=result.get("audio_filename") or "",
                audio_url=result.get("audio_url") or "",
                provider=result.get("provider", "unknown"),
                file_size=result.get("file_size", 0),
                error=result.get("error"),
            )
            for result in audio_results
        ]

        return StoryAudioGenerationResponse(audio_files=audio_models, success=True)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate audio: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-ai-audio", response_model=GenerateAIAudioResponse)
async def generate_ai_audio(
    request: GenerateAIAudioRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> GenerateAIAudioResponse:
    """Generate AI audio for a single scene using WaveSpeed Minimax Speech 02 HD."""
    # Unexpected failures are reported in-band (success=False plus the error
    # text) instead of as an HTTP 500; HTTPException still propagates.
    try:
        user_id = require_authenticated_user(current_user)

        narration = request.text.strip() if request.text else ""
        if not narration:
            raise HTTPException(status_code=400, detail="Text is required")

        logger.info(
            f"[StoryWriter] Generating AI audio for scene {request.scene_number} "
            f"({request.scene_title}) for user {user_id}"
        )

        chosen_voice = request.voice_id or "Wise_Woman"

        result = audio_service.generate_ai_audio(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            text=narration,
            user_id=user_id,
            voice_id=chosen_voice,
            speed=request.speed or 1.0,
            volume=request.volume or 1.0,
            pitch=request.pitch or 0.0,
            emotion=request.emotion or "happy",
        )

        # text_length falls back to the length of the raw (unstripped) request text.
        return GenerateAIAudioResponse(
            scene_number=result.get("scene_number", request.scene_number),
            scene_title=result.get("scene_title", request.scene_title),
            audio_filename=result.get("audio_filename", ""),
            audio_url=result.get("audio_url", ""),
            provider=result.get("provider", "wavespeed"),
            model=result.get("model", "minimax/speech-02-hd"),
            voice_id=result.get("voice_id", chosen_voice),
            text_length=result.get("text_length", len(request.text)),
            file_size=result.get("file_size", 0),
            cost=result.get("cost", 0.0),
            success=True,
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate AI audio: {exc}")
        failure_fields = dict(
            scene_number=request.scene_number,
            scene_title=request.scene_title,
            audio_filename="",
            audio_url="",
            provider="wavespeed",
            model="minimax/speech-02-hd",
            voice_id=request.voice_id or "Wise_Woman",
            text_length=len(request.text) if request.text else 0,
            file_size=0,
            cost=0.0,
            success=False,
            error=str(exc),
        )
        return GenerateAIAudioResponse(**failure_fields)
|
||||
|
||||
|
||||
@router.get("/audio/{audio_filename}")
async def serve_scene_audio(
    audio_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Serve a generated story scene audio file."""
    # The filename is resolved against the audio service's output directory by
    # resolve_media_file; the file is always sent with an audio/mpeg media type.
    # NOTE(review): unlike the image endpoint this one only accepts header
    # auth (get_current_user) - confirm <audio> elements never hit it directly.
    try:
        require_authenticated_user(current_user)
        resolved_path = resolve_media_file(audio_service.output_dir, audio_filename)
        return FileResponse(
            path=str(resolved_path),
            media_type="audio/mpeg",
            filename=audio_filename,
        )
    except HTTPException:
        # Auth / lookup errors already carry the right status code.
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to serve audio: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
195
backend/api/story_writer/routes/story_content.py
Normal file
195
backend/api/story_writer/routes/story_content.py
Normal file
@@ -0,0 +1,195 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
StoryStartRequest,
|
||||
StoryContentResponse,
|
||||
StoryScene,
|
||||
StoryContinueRequest,
|
||||
StoryContinueResponse,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
|
||||
|
||||
@router.post("/generate-start", response_model=StoryContentResponse)
async def generate_story_start(
    request: StoryStartRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContentResponse:
    """Generate the starting section of a story."""
    # Returns the generated opening plus the premise and a text rendering of
    # the outline. For short stories the opening is the whole story, so it is
    # flagged complete once it reaches 900 words.
    # Raises HTTPException 400 for a missing premise/outline and 500 for any
    # unexpected failure.
    try:
        user_id = require_authenticated_user(current_user)

        if not request.premise or not request.premise.strip():
            raise HTTPException(status_code=400, detail="Premise is required")
        if not request.outline or (isinstance(request.outline, str) and not request.outline.strip()):
            raise HTTPException(status_code=400, detail="Outline is required")

        logger.info(f"[StoryWriter] Generating story start for user {user_id}")

        # The service works on plain dicts, so a structured outline of
        # StoryScene models is dumped first.
        outline_data: Any = request.outline
        if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
            outline_data = [scene.dict() for scene in outline_data]

        # A null/empty story_length falls back to "Medium" instead of crashing
        # on .lower() after the (expensive) generation call has been made.
        story_length = getattr(request, "story_length", None) or "Medium"
        story_length_lower = story_length.lower()
        # "10000" contains "1000" as a substring, so the long check must win;
        # otherwise long stories would be classified as short.
        is_long_story = "long" in story_length_lower or "10000" in story_length_lower
        is_short_story = not is_long_story and (
            "short" in story_length_lower or "1000" in story_length_lower
        )

        story_start = story_service.generate_story_start(
            premise=request.premise,
            outline=outline_data,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id,
        )

        is_complete = False
        if is_short_story:
            word_count = len(story_start.split()) if story_start else 0
            if word_count >= 900:
                is_complete = True
                logger.info(
                    f"[StoryWriter] Short story generated with {word_count} words. Marking as complete."
                )
            else:
                logger.warning(
                    f"[StoryWriter] Short story generated with only {word_count} words. May need continuation."
                )

        # The response carries the outline as text: a structured outline is
        # flattened to one "Scene N: title / description" entry per scene.
        outline_response = outline_data
        if isinstance(outline_response, list):
            outline_response = "\n".join(
                f"Scene {scene.get('scene_number', i + 1)}: "
                f"{scene.get('title', 'Untitled')}\n  {scene.get('description', '')}"
                for i, scene in enumerate(outline_response)
            )

        return StoryContentResponse(
            story=story_start,
            premise=request.premise,
            outline=str(outline_response),
            is_complete=is_complete,
            success=True,
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate story start: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/continue", response_model=StoryContinueResponse)
async def continue_story(
    request: StoryContinueRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryContinueResponse:
    """Continue writing a story."""
    # Word targets: 10000 for long stories, 4500 otherwise, with a 5% buffer.
    # The literal marker "IAMDONE" in the continuation signals completion.
    # Raises HTTPException 400 for missing story text or when asked to
    # continue a short story, and 500 for any unexpected failure.
    try:
        user_id = require_authenticated_user(current_user)

        if not request.story_text or not request.story_text.strip():
            raise HTTPException(status_code=400, detail="Story text is required")

        logger.info(f"[StoryWriter] Continuing story for user {user_id}")

        outline_data: Any = request.outline
        if isinstance(outline_data, list) and outline_data and isinstance(outline_data[0], StoryScene):
            outline_data = [scene.dict() for scene in outline_data]

        # A null/empty story_length falls back to "Medium" instead of crashing on .lower().
        story_length = getattr(request, "story_length", None) or "Medium"
        story_length_lower = story_length.lower()
        # "10000" contains "1000" as a substring, so the long check must win;
        # otherwise long stories would be rejected below as "short".
        is_long_story = "long" in story_length_lower or "10000" in story_length_lower
        is_short_story = not is_long_story and (
            "short" in story_length_lower or "1000" in story_length_lower
        )

        if is_short_story:
            logger.warning(
                "[StoryWriter] Attempted to continue a short story. Short stories should be complete in one call."
            )
            raise HTTPException(
                status_code=400,
                detail="Short stories are generated in a single call and should be complete. "
                "If the story is incomplete, please regenerate it from the beginning.",
            )

        current_word_count = len(request.story_text.split())
        target_total_words = 10000 if is_long_story else 4500
        buffer_target = int(target_total_words * 1.05)

        # Skip generation entirely when the story is already past the buffer
        # or sits within 50 words above the target.
        if current_word_count >= buffer_target or (
            current_word_count >= target_total_words
            and (current_word_count - target_total_words) < 50
        ):
            logger.info(
                f"[StoryWriter] Word count ({current_word_count}) already at or near target ({target_total_words})."
            )
            return StoryContinueResponse(continuation="IAMDONE", is_complete=True, success=True)

        continuation = story_service.continue_story(
            premise=request.premise,
            outline=outline_data,
            story_text=request.story_text,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            story_length=story_length,
            user_id=user_id,
        )

        is_complete = "IAMDONE" in continuation.upper()
        if not is_complete and continuation:
            new_story_text = request.story_text + "\n\n" + continuation
            new_word_count = len(new_story_text.split())

            reached_buffer = new_word_count >= buffer_target
            near_target = (
                new_word_count >= target_total_words
                and (new_word_count - target_total_words) < 100
            )

            if reached_buffer:
                logger.info(
                    f"[StoryWriter] Word count ({new_word_count}) now exceeds buffer target ({buffer_target})."
                )
            elif near_target:
                logger.info(
                    f"[StoryWriter] Word count ({new_word_count}) is at or very close to target ({target_total_words})."
                )

            if reached_buffer or near_target:
                # This branch only runs when the marker is absent, so it can
                # be appended unconditionally.
                continuation = continuation.rstrip() + "\n\nIAMDONE"
                is_complete = True

        return StoryContinueResponse(continuation=continuation, is_complete=is_complete, success=True)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to continue story: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
141
backend/api/story_writer/routes/story_setup.py
Normal file
141
backend/api/story_writer/routes/story_setup.py
Normal file
@@ -0,0 +1,141 @@
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
StorySetupGenerationRequest,
|
||||
StorySetupGenerationResponse,
|
||||
StorySetupOption,
|
||||
StoryGenerationRequest,
|
||||
StoryOutlineResponse,
|
||||
StoryScene,
|
||||
StoryStartRequest,
|
||||
StoryPremiseResponse,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
|
||||
|
||||
@router.post("/generate-setup", response_model=StorySetupGenerationResponse)
async def generate_story_setup(
    request: StorySetupGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StorySetupGenerationResponse:
    """Generate 3 story setup options from a user's story idea."""
    # Each option dict returned by the service is expanded into a
    # StorySetupOption. HTTPException propagates; anything else becomes a 500.
    try:
        user_id = require_authenticated_user(current_user)

        idea = request.story_idea
        if not idea or not idea.strip():
            raise HTTPException(status_code=400, detail="Story idea is required")

        logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")

        raw_options = story_service.generate_story_setup_options(
            story_idea=idea,
            user_id=user_id,
        )

        setup_options = []
        for option in raw_options:
            setup_options.append(StorySetupOption(**option))

        return StorySetupGenerationResponse(options=setup_options, success=True)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate story setup options: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-premise", response_model=StoryPremiseResponse)
async def generate_premise(
    request: StoryGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryPremiseResponse:
    """Generate a story premise."""
    # The request's story parameters are forwarded one-to-one to the service.
    # HTTPException propagates; anything else becomes a 500.
    try:
        user_id = require_authenticated_user(current_user)
        logger.info(f"[StoryWriter] Generating premise for user {user_id}")

        story_params = dict(
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
        )
        premise = story_service.generate_premise(user_id=user_id, **story_params)

        return StoryPremiseResponse(premise=premise, success=True)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate premise: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-outline", response_model=StoryOutlineResponse)
async def generate_outline(
    request: StoryStartRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    use_structured: bool = True,
) -> StoryOutlineResponse:
    """Generate a story outline from a premise."""
    # `use_structured` is forwarded to the service as `use_structured_output`.
    # A list result is returned as StoryScene models (is_structured=True);
    # anything else is returned as its string form (is_structured=False).
    # Raises HTTPException 400 for a missing premise and 500 for any
    # unexpected failure.
    try:
        user_id = require_authenticated_user(current_user)

        if not request.premise or not request.premise.strip():
            raise HTTPException(status_code=400, detail="Premise is required")

        logger.info(
            f"[StoryWriter] Generating outline for user {user_id} (structured={use_structured})"
        )
        # loguru formats messages with str.format-style braces, not printf
        # "%s" placeholders, so the values are interpolated via an f-string
        # like every other log call in this module.
        logger.info(
            f"[StoryWriter] Outline params: audience_age_group={request.audience_age_group}, "
            f"writing_style={request.writing_style}, story_tone={request.story_tone}"
        )

        outline = story_service.generate_outline(
            premise=request.premise,
            persona=request.persona,
            story_setting=request.story_setting,
            character_input=request.character_input,
            plot_elements=request.plot_elements,
            writing_style=request.writing_style,
            story_tone=request.story_tone,
            narrative_pov=request.narrative_pov,
            audience_age_group=request.audience_age_group,
            content_rating=request.content_rating,
            ending_preference=request.ending_preference,
            user_id=user_id,
            use_structured_output=use_structured,
        )

        if isinstance(outline, list):
            # Dict entries are promoted to StoryScene; other entries pass through.
            scenes: List[StoryScene] = [
                StoryScene(**scene) if isinstance(scene, dict) else scene for scene in outline
            ]
            return StoryOutlineResponse(outline=scenes, success=True, is_structured=True)

        return StoryOutlineResponse(outline=str(outline), success=True, is_structured=False)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate outline: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
130
backend/api/story_writer/routes/story_tasks.py
Normal file
130
backend/api/story_writer/routes/story_tasks.py
Normal file
@@ -0,0 +1,130 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
StoryGenerationRequest,
|
||||
TaskStatus,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..cache_manager import cache_manager
|
||||
from ..task_manager import task_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
story_service = StoryWriterService()
|
||||
|
||||
|
||||
@router.post("/generate-full", response_model=Dict[str, Any])
async def generate_full_story(
    request: StoryGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
    max_iterations: int = 10,
) -> Dict[str, Any]:
    """Generate a complete story asynchronously."""
    # Returns a task id immediately. On a cache hit the task is created
    # already completed with the cached result; otherwise generation is
    # scheduled as a background task. The cache key is derived from the
    # request body only (max_iterations is not part of it).
    try:
        user_id = require_authenticated_user(current_user)

        cache_key = cache_manager.get_cache_key(request.dict())
        cached_result = cache_manager.get_cached_result(cache_key)

        task_id = task_manager.create_task("story_generation")

        if cached_result:
            logger.info(f"[StoryWriter] Returning cached result for user {user_id}")
            task_manager.update_task_status(
                task_id,
                "completed",
                progress=100.0,
                result=cached_result,
                message="Returned cached result",
            )
            return {"task_id": task_id, "cached": True}

        request_data = request.dict()
        request_data["max_iterations"] = max_iterations

        background_tasks.add_task(
            task_manager.execute_story_generation_task,
            task_id=task_id,
            request_data=request_data,
            user_id=user_id,
        )

        logger.info(f"[StoryWriter] Created task {task_id} for full story generation (user {user_id})")

        response: Dict[str, Any] = {}
        response["task_id"] = task_id
        response["status"] = "pending"
        response["message"] = "Story generation started. Use /task/{task_id}/status to check progress."
        return response

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to start story generation: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/status", response_model=TaskStatus)
async def get_task_status(
    task_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> TaskStatus:
    """Get the status of a story generation task."""
    # Raises HTTPException 404 when the task manager has no entry for the id.
    # NOTE(review): the caller is authenticated, but nothing here ties the
    # task to the requesting user - confirm whether ownership is checked elsewhere.
    try:
        require_authenticated_user(current_user)

        status_record = task_manager.get_task_status(task_id)
        if status_record:
            return TaskStatus(**status_record)

        raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get task status: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/result")
async def get_task_result(
    task_id: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """Get the result of a completed story generation task."""
    # Raises HTTPException 404 for an unknown task or an empty result, and
    # 400 when the task exists but has not reached the "completed" status.
    try:
        require_authenticated_user(current_user)

        status_record = task_manager.get_task_status(task_id)
        if not status_record:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")

        current_state = status_record["status"]
        if current_state != "completed":
            raise HTTPException(
                status_code=400,
                detail=f"Task {task_id} is not completed. Status: {current_state}",
            )

        result = status_record.get("result")
        if not result:
            raise HTTPException(status_code=404, detail=f"No result found for task {task_id}")

        if not isinstance(result, dict):
            # Non-dict results are wrapped so the response is always an object.
            return {"result": result, "success": True, "task_id": task_id}

        # Work on a shallow copy so the stored task result is not mutated.
        payload = dict(result)
        if "success" not in payload:
            payload["success"] = True
        payload["task_id"] = task_id
        return payload

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to get task result: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
511
backend/api/story_writer/routes/video_generation.py
Normal file
511
backend/api/story_writer/routes/video_generation.py
Normal file
@@ -0,0 +1,511 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from models.story_models import (
|
||||
StoryVideoGenerationRequest,
|
||||
StoryVideoGenerationResponse,
|
||||
StoryVideoResult,
|
||||
StoryScene,
|
||||
StoryGenerationRequest,
|
||||
)
|
||||
from services.story_writer.video_generation_service import StoryVideoGenerationService
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
from ..task_manager import task_manager
|
||||
from ..utils.auth import require_authenticated_user
|
||||
from ..utils.hd_video import (
|
||||
generate_hd_video_payload,
|
||||
generate_hd_video_scene_payload,
|
||||
)
|
||||
from ..utils.media_utils import resolve_media_file
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
video_service = StoryVideoGenerationService()
|
||||
image_service = StoryImageGenerationService()
|
||||
audio_service = StoryAudioGenerationService()
|
||||
story_service = StoryWriterService()
|
||||
|
||||
|
||||
# Request body for HD video generation from a single free-form prompt.
# Only `prompt` is required; `provider` defaults to "huggingface" and every
# tuning field defaults to None.
class HDVideoRequest(BaseModel):
    prompt: str
    provider: str = "huggingface"
    model: Optional[str] = None
    num_frames: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
|
||||
|
||||
|
||||
# Request body for HD video generation of one scene within a story.
# The scene number, the scene's own data, the story context and the full
# scene list are required; `provider` defaults to "huggingface" and every
# tuning field defaults to None.
class HDVideoSceneRequest(BaseModel):
    scene_number: int
    scene_data: Dict[str, Any]
    story_context: Dict[str, Any]
    all_scenes: List[Dict[str, Any]]
    provider: str = "huggingface"
    model: Optional[str] = None
    num_frames: Optional[int] = None
    guidance_scale: Optional[float] = None
    num_inference_steps: Optional[int] = None
    negative_prompt: Optional[str] = None
    seed: Optional[int] = None
|
||||
|
||||
|
||||
@router.post("/generate-video", response_model=StoryVideoGenerationResponse)
async def generate_story_video(
    request: StoryVideoGenerationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryVideoGenerationResponse:
    """Generate a video from story scenes, images, and audio."""
    # Per scene, an animated video is preferred over the static image and AI
    # audio over the free audio. A scene is used only when both a visual and
    # an audio file are found on disk; otherwise it is skipped with a warning.
    # Raises HTTPException 400 for empty/mismatched inputs or when no scene
    # has usable media, and 500 for any unexpected failure.
    try:
        user_id = require_authenticated_user(current_user)

        if not request.scenes:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        if len(request.scenes) != len(request.image_urls) or len(request.scenes) != len(request.audio_urls):
            raise HTTPException(
                status_code=400,
                detail="Number of scenes, image URLs, and audio URLs must match",
            )

        logger.info(f"[StoryWriter] Generating video for {len(request.scenes)} scenes for user {user_id}")

        scenes_data = [scene.dict() if isinstance(scene, StoryScene) else scene for scene in request.scenes]

        # These four lists are index-aligned: entry i of each describes the
        # same scene. For every scene exactly one of video_paths[i] and
        # image_paths[i] is set; the other is None.
        video_paths: List[Optional[str]] = []  # Animated videos (preferred)
        image_paths: List[Optional[str]] = []  # Static images (fallback)
        audio_paths: List[str] = []
        valid_scenes: List[Dict[str, Any]] = []

        # Resolve the directory holding the animated scene videos.
        base_dir = Path(__file__).parent.parent.parent.parent
        ai_video_dir = (base_dir / "story_videos" / "AI_Videos").resolve()

        video_urls = request.video_urls or [None] * len(request.scenes)
        ai_audio_urls = request.ai_audio_urls or [None] * len(request.scenes)

        for idx, (scene, image_url, audio_url) in enumerate(zip(scenes_data, request.image_urls, request.audio_urls)):
            # ---- visual: prefer the animated video, fall back to the image ----
            video_url = video_urls[idx] if idx < len(video_urls) else None
            video_path = None
            image_path = None

            if video_url:
                # Extract filename from animated video URL (e.g., /api/story/videos/ai/filename.mp4)
                video_filename = video_url.split("/")[-1].split("?")[0]
                candidate = ai_video_dir / video_filename
                # is_file() rather than exists(): an empty filename would
                # otherwise resolve to the directory itself.
                if candidate.is_file():
                    logger.info(f"[StoryWriter] Using animated video for scene {scene.get('scene_number', idx+1)}: {video_filename}")
                    video_path = candidate
                else:
                    logger.warning(f"[StoryWriter] Animated video not found: {candidate}, falling back to image")

            if video_path is None:
                image_filename = image_url.split("/")[-1].split("?")[0] if image_url else ""
                candidate = image_service.output_dir / image_filename
                if not image_filename or not candidate.is_file():
                    logger.warning(f"[StoryWriter] Image not found: {candidate} (from URL: {image_url})")
                    continue
                image_path = candidate

            # ---- audio: prefer AI audio, fall back to the free audio ----
            ai_audio_url = ai_audio_urls[idx] if idx < len(ai_audio_urls) else None
            audio_path = None

            if ai_audio_url:
                audio_filename = ai_audio_url.split("/")[-1].split("?")[0]
                candidate = audio_service.output_dir / audio_filename
                if candidate.is_file():
                    logger.info(f"[StoryWriter] Using AI audio for scene {scene.get('scene_number', idx+1)}: {audio_filename}")
                    audio_path = candidate
                else:
                    logger.warning(f"[StoryWriter] AI audio not found: {candidate}, falling back to free audio")

            if audio_path is None:
                audio_filename = audio_url.split("/")[-1].split("?")[0] if audio_url else ""
                candidate = audio_service.output_dir / audio_filename
                if not audio_filename or not candidate.is_file():
                    logger.warning(f"[StoryWriter] Audio not found: {candidate} (from URL: {audio_url})")
                    continue
                audio_path = candidate

            # Append to all four lists together, and only once both media
            # files are known to exist, so the lists can never drift apart
            # when a scene is skipped or a video falls back to its image.
            video_paths.append(str(video_path) if video_path is not None else None)
            image_paths.append(str(image_path) if image_path is not None else None)
            audio_paths.append(str(audio_path))
            valid_scenes.append(scene)

        if len(valid_scenes) == 0 or len(audio_paths) == 0:
            raise HTTPException(status_code=400, detail="No valid video/image or audio files were found")
        if len(valid_scenes) != len(audio_paths):
            raise HTTPException(
                status_code=400,
                detail="Number of valid scenes and audio files must match",
            )

        video_result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,  # Can contain None for scenes with animated videos
            video_paths=video_paths,  # Can contain None for scenes with static images
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request.story_title or "Story",
            fps=request.fps or 24,
            transition_duration=request.transition_duration or 0.5,
        )

        video_model = StoryVideoResult(
            video_filename=video_result.get("video_filename", ""),
            video_url=video_result.get("video_url", ""),
            duration=video_result.get("duration", 0.0),
            fps=video_result.get("fps", 24),
            file_size=video_result.get("file_size", 0),
            num_scenes=video_result.get("num_scenes", 0),
            error=video_result.get("error"),
        )

        return StoryVideoGenerationResponse(video=video_model, success=True)

    except HTTPException:
        raise
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate video: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/generate-video-async", response_model=Dict[str, Any])
async def generate_story_video_async(
    request: StoryVideoGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Queue story video generation as a background task and return its task id.

    The request must carry at least one scene and the same number of image
    URLs and audio URLs as scenes; otherwise a 400 is raised.
    Frontend can poll /api/story/task/{task_id}/status to show progress messages.

    Returns:
        Dict with ``task_id``, ``status`` ("pending") and a ``message``.

    Raises:
        HTTPException: 400 on invalid input, 500 on any unexpected failure.
    """
    try:
        user_id = require_authenticated_user(current_user)

        scene_count = len(request.scenes) if request.scenes else 0
        if scene_count == 0:
            raise HTTPException(status_code=400, detail="At least one scene is required")

        # Chained comparison short-circuits exactly like the two-part `or` test.
        if not (scene_count == len(request.image_urls) == len(request.audio_urls)):
            raise HTTPException(
                status_code=400,
                detail="Number of scenes, image URLs, and audio URLs must match",
            )

        new_task_id = task_manager.create_task("story_video_generation")
        background_tasks.add_task(
            _execute_video_generation_task,
            task_id=new_task_id,
            request=request,
            user_id=user_id,
        )
        return {
            "task_id": new_task_id,
            "status": "pending",
            "message": "Video generation started",
        }
    except HTTPException:
        # Validation/auth errors already carry the right status code.
        raise
    except Exception as err:
        logger.error(f"[StoryWriter] Failed to start async video generation: {err}")
        raise HTTPException(status_code=500, detail=str(err))
|
||||
|
||||
|
||||
def _execute_video_generation_task(task_id: str, request: StoryVideoGenerationRequest, user_id: str):
    """
    Background task: compose a story video and mirror progress into the task manager.

    Scenes whose image or audio file is missing from the services' output
    directories are skipped with a warning. Any exception marks the task as
    "failed" instead of propagating.
    """

    def _filename_from_url(url: str) -> str:
        # Last path segment with any query string removed.
        return url.split("/")[-1].split("?")[0]

    try:
        task_manager.update_task_status(task_id, "processing", progress=2.0, message="Initializing video generation...")

        scenes_data = [
            scene.dict() if isinstance(scene, StoryScene) else scene
            for scene in request.scenes
        ]
        image_paths: List[str] = []
        audio_paths: List[str] = []
        valid_scenes: List[Dict[str, Any]] = []

        for scene, image_url, audio_url in zip(scenes_data, request.image_urls, request.audio_urls):
            image_name = _filename_from_url(image_url)
            audio_name = _filename_from_url(audio_url)
            image_file = image_service.output_dir / image_name
            audio_file = audio_service.output_dir / audio_name

            if not image_file.exists():
                logger.warning(f"[StoryWriter] Image not found: {image_file} (from URL: {image_url})")
                continue
            if not audio_file.exists():
                logger.warning(f"[StoryWriter] Audio not found: {audio_file} (from URL: {audio_url})")
                continue

            image_paths.append(str(image_file))
            audio_paths.append(str(audio_file))
            valid_scenes.append(scene)

        assets_usable = bool(image_paths) and bool(audio_paths) and len(image_paths) == len(audio_paths)
        if not assets_usable:
            raise RuntimeError("No valid or mismatched image/audio assets for video generation.")

        def progress_callback(sub_progress: float, msg: str):
            # Clamp the service's 0-100 progress, then map it onto 5-95 overall.
            clamped = max(0.0, min(100.0, sub_progress))
            task_manager.update_task_status(
                task_id,
                "processing",
                progress=5.0 + clamped * 0.9,
                message=msg,
            )

        result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request.story_title or "Story",
            fps=request.fps or 24,
            transition_duration=request.transition_duration or 0.5,
            progress_callback=progress_callback,
        )

        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="Video generation complete!",
            result={"video": result, "success": True},
        )
    except Exception as err:
        logger.error(f"[StoryWriter] Async video generation failed: {err}", exc_info=True)
        task_manager.update_task_status(task_id, "failed", error=str(err), message=f"Video generation failed: {err}")
|
||||
|
||||
|
||||
@router.post("/hd-video")
async def generate_hd_video(
    request: HDVideoRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Authenticate the caller and delegate to ``generate_hd_video_payload``.

    Raises:
        HTTPException: re-raised unchanged; any other failure becomes a 500.
    """
    try:
        authenticated_id = require_authenticated_user(current_user)
        payload = generate_hd_video_payload(request, authenticated_id)
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"[StoryWriter] Failed to generate HD video: {err}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(err))
    return payload
|
||||
|
||||
|
||||
@router.post("/hd-video-scene")
async def generate_hd_video_scene(
    request: HDVideoSceneRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Authenticate the caller and delegate to ``generate_hd_video_scene_payload``.

    Raises:
        HTTPException: re-raised unchanged; any other failure becomes a 500.
    """
    try:
        authenticated_id = require_authenticated_user(current_user)
        payload = generate_hd_video_scene_payload(request, authenticated_id)
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"[StoryWriter] Failed to generate HD video for scene: {err}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(err))
    return payload
|
||||
|
||||
|
||||
@router.post("/generate-complete-video", response_model=Dict[str, Any])
async def generate_complete_story_video(
    request: StoryGenerationRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
    """
    Start the complete story-to-video workflow as a background task.

    Returns:
        Dict with the new ``task_id``, ``status`` ("pending") and a ``message``.

    Raises:
        HTTPException: re-raised unchanged; any other failure becomes a 500.
    """
    try:
        user_id = require_authenticated_user(current_user)
        logger.info(f"[StoryWriter] Starting complete video generation for user {user_id}")

        new_task_id = task_manager.create_task("complete_video_generation")
        # The worker receives a plain dict snapshot of the request.
        background_tasks.add_task(
            execute_complete_video_generation,
            task_id=new_task_id,
            request_data=request.dict(),
            user_id=user_id,
        )

        response = {
            "task_id": new_task_id,
            "status": "pending",
            "message": "Complete video generation started",
        }
        return response
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"[StoryWriter] Failed to start complete video generation: {err}")
        raise HTTPException(status_code=500, detail=str(err))
|
||||
|
||||
|
||||
def execute_complete_video_generation(
    task_id: str,
    request_data: Dict[str, Any],
    user_id: str,
):
    """
    Run the whole story-to-video pipeline for one background task.

    Stages, in order: premise, structured outline, scene images, scene audio,
    asset matching, video composition. Progress and the final result (or the
    failure) are written to the task manager; nothing is returned and no
    exception escapes.
    """

    def _report(progress: float, message: str) -> None:
        task_manager.update_task_status(task_id, "processing", progress=progress, message=message)

    def _stage_callback(stage_start: float):
        # Each service stage contributes 20% of the overall progress bar.
        def _callback(sub_progress: float, message: str):
            _report(stage_start + (sub_progress * 0.2), message)

        return _callback

    def _result_for(results, scene_number):
        # First entry whose scene_number matches, or None.
        return next((item for item in results if item.get("scene_number") == scene_number), None)

    try:
        _report(5.0, "Starting complete video generation...")

        _report(10.0, "Generating story premise...")
        # Premise and outline take the same story parameters; a missing key
        # raises KeyError here and fails the task.
        story_params = {
            key: request_data[key]
            for key in (
                "persona",
                "story_setting",
                "character_input",
                "plot_elements",
                "writing_style",
                "story_tone",
                "narrative_pov",
                "audience_age_group",
                "content_rating",
                "ending_preference",
            )
        }
        premise = story_service.generate_premise(user_id=user_id, **story_params)

        _report(20.0, "Generating structured outline with scenes...")
        outline_scenes = story_service.generate_outline(
            premise=premise,
            user_id=user_id,
            use_structured_output=True,
            **story_params,
        )
        if not isinstance(outline_scenes, list):
            raise RuntimeError("Failed to generate structured outline")

        _report(30.0, "Generating images for scenes...")
        image_results = image_service.generate_scene_images(
            scenes=outline_scenes,
            user_id=user_id,
            provider=request_data.get("image_provider"),
            width=request_data.get("image_width", 1024),
            height=request_data.get("image_height", 1024),
            model=request_data.get("image_model"),
            progress_callback=_stage_callback(30.0),
        )

        _report(50.0, "Generating audio narration for scenes...")
        audio_results = audio_service.generate_scene_audio_list(
            scenes=outline_scenes,
            user_id=user_id,
            provider=request_data.get("audio_provider", "gtts"),
            lang=request_data.get("audio_lang", "en"),
            slow=request_data.get("audio_slow", False),
            rate=request_data.get("audio_rate", 150),
            progress_callback=_stage_callback(50.0),
        )

        _report(70.0, "Preparing video assets...")
        image_paths: List[str] = []
        audio_paths: List[str] = []
        valid_scenes: List[Dict[str, Any]] = []

        for scene in outline_scenes:
            scene_number = scene.get("scene_number", 0)
            image_result = _result_for(image_results, scene_number)
            audio_result = _result_for(audio_results, scene_number)

            # Keep a scene only when both assets exist and neither errored.
            if not (image_result and audio_result):
                continue
            if image_result.get("error") or audio_result.get("error"):
                continue

            image_path = image_result.get("image_path")
            audio_path = audio_result.get("audio_path")
            if not (image_path and audio_path):
                continue

            image_paths.append(image_path)
            audio_paths.append(audio_path)
            valid_scenes.append(scene)

        if len(image_paths) == 0 or len(audio_paths) == 0:
            raise RuntimeError(
                f"No valid images or audio files were generated. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
            )
        if len(image_paths) != len(audio_paths):
            raise RuntimeError(
                f"Mismatch between image and audio counts. Images: {len(image_paths)}, Audio: {len(audio_paths)}"
            )

        _report(75.0, "Composing video from scenes...")
        video_result = video_service.generate_story_video(
            scenes=valid_scenes,
            image_paths=image_paths,
            audio_paths=audio_paths,
            user_id=user_id,
            story_title=request_data.get("story_setting", "Story")[:50],
            fps=request_data.get("video_fps", 24),
            transition_duration=request_data.get("video_transition_duration", 0.5),
            progress_callback=_stage_callback(75.0),
        )

        task_manager.update_task_status(
            task_id,
            "completed",
            progress=100.0,
            message="Complete video generation finished!",
            result={
                "premise": premise,
                "outline_scenes": outline_scenes,
                "images": image_results,
                "audio_files": audio_results,
                "video": video_result,
                "success": True,
            },
        )
        logger.info(f"[StoryWriter] Complete video generation task {task_id} completed successfully")

    except Exception as exc:
        error_msg = str(exc)
        logger.error(f"[StoryWriter] Complete video generation task {task_id} failed: {error_msg}", exc_info=True)
        task_manager.update_task_status(
            task_id,
            "failed",
            error=error_msg,
            message=f"Complete video generation failed: {error_msg}",
        )
|
||||
|
||||
|
||||
@router.get("/videos/{video_filename}")
async def serve_story_video(
    video_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """
    Stream a generated story video to an authenticated caller.

    The file is looked up through ``resolve_media_file`` under the video
    service's output directory and returned as ``video/mp4``.

    Raises:
        HTTPException: re-raised unchanged; any other failure becomes a 500.
    """
    try:
        require_authenticated_user(current_user)
        resolved_path = resolve_media_file(video_service.output_dir, video_filename)
        response = FileResponse(
            path=str(resolved_path),
            media_type="video/mp4",
            filename=video_filename,
        )
    except HTTPException:
        raise
    except Exception as err:
        logger.error(f"[StoryWriter] Failed to serve video: {err}")
        raise HTTPException(status_code=500, detail=str(err))
    return response
|
||||
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from uuid import uuid4
|
||||
|
||||
from .media_utils import load_story_image_bytes
|
||||
|
||||
|
||||
def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""Handles synchronous HD video generation."""
|
||||
@@ -57,8 +55,8 @@ def generate_hd_video_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
|
||||
def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Handles per-scene HD video generation including prompt enhancement,
|
||||
subscription validation, and optional image conditioning.
|
||||
Handles per-scene HD video generation including prompt enhancement
|
||||
and subscription validation.
|
||||
"""
|
||||
from services.database import get_db as get_db_validation
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
@@ -71,7 +69,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
|
||||
scene_number = request.scene_number
|
||||
logger.info(f"[StoryWriter] Generating HD video for scene {scene_number} for user {user_id}")
|
||||
|
||||
# Step 1: Validate API key
|
||||
hf_token = APIKeyManager().get_api_key("hf_token")
|
||||
if not hf_token:
|
||||
logger.error("[StoryWriter] Pre-flight: HF token not configured - blocking video generation")
|
||||
@@ -83,7 +80,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
|
||||
},
|
||||
)
|
||||
|
||||
# Step 2: Subscription limits
|
||||
db_validation = next(get_db_validation())
|
||||
try:
|
||||
pricing_service = PricingService(db_validation)
|
||||
@@ -93,7 +89,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
|
||||
finally:
|
||||
db_validation.close()
|
||||
|
||||
# Stage 1: Prompt enhancement
|
||||
enhanced_prompt = enhance_scene_prompt_for_video(
|
||||
current_scene=request.scene_data,
|
||||
story_context=request.story_context,
|
||||
@@ -102,15 +97,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
|
||||
)
|
||||
logger.info(f"[StoryWriter] Generated enhanced prompt ({len(enhanced_prompt)} chars) for scene {scene_number}")
|
||||
|
||||
# Stage 2: Optional image reference
|
||||
scene_image_bytes: Optional[bytes] = None
|
||||
if getattr(request, "scene_image_url", None):
|
||||
scene_image_bytes = load_story_image_bytes(request.scene_image_url)
|
||||
if scene_image_bytes:
|
||||
logger.info(f"[StoryWriter] Using scene image reference for scene {scene_number}")
|
||||
else:
|
||||
logger.warning(f"[StoryWriter] Scene image could not be loaded for scene {scene_number}, falling back to text-only video")
|
||||
|
||||
kwargs: Dict[str, Any] = {}
|
||||
if getattr(request, "model", None):
|
||||
kwargs["model"] = request.model
|
||||
@@ -129,7 +115,6 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
|
||||
prompt=enhanced_prompt,
|
||||
provider=getattr(request, "provider", None) or "huggingface",
|
||||
user_id=user_id,
|
||||
input_image_bytes=scene_image_bytes,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -151,4 +136,3 @@ def generate_hd_video_scene_payload(request: Any, user_id: str) -> Dict[str, Any
|
||||
"model": getattr(request, "model", None) or "tencent/HunyuanVideo",
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ from loguru import logger
|
||||
BASE_DIR = Path(__file__).resolve().parents[3] # backend/
|
||||
STORY_IMAGES_DIR = (BASE_DIR / "story_images").resolve()
|
||||
STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
STORY_AUDIO_DIR = (BASE_DIR / "story_audio").resolve()
|
||||
STORY_AUDIO_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def load_story_image_bytes(image_url: str) -> Optional[bytes]:
|
||||
@@ -48,6 +50,41 @@ def load_story_image_bytes(image_url: str) -> Optional[bytes]:
|
||||
return None
|
||||
|
||||
|
||||
def load_story_audio_bytes(audio_url: str) -> Optional[bytes]:
    """
    Resolve an authenticated story audio URL (e.g., /api/story/audio/<file>) to raw bytes.

    Args:
        audio_url: Absolute or relative URL containing the ``/api/story/audio/`` prefix.

    Returns:
        The file contents, or None when the URL is empty or unsupported, the
        resolved path falls outside ``STORY_AUDIO_DIR``, the file is missing,
        or reading fails. This function never raises.
    """
    if not audio_url:
        return None

    try:
        parsed = urlparse(audio_url)
        path = parsed.path if parsed.scheme else audio_url
        prefix = "/api/story/audio/"

        _, matched, remainder = path.partition(prefix)
        if not matched:
            logger.warning(f"[StoryWriter] Unsupported audio URL for video reference: {audio_url}")
            return None

        filename = remainder.split("?", 1)[0].strip()
        if not filename:
            return None

        file_path = (STORY_AUDIO_DIR / filename).resolve()
        # Compare path components, not string prefixes: a plain startswith()
        # check would accept a sibling such as ".../story_audio_evil/x".
        try:
            file_path.relative_to(STORY_AUDIO_DIR)
        except ValueError:
            logger.error(f"[StoryWriter] Attempted path traversal when resolving audio: {audio_url}")
            return None

        # Only regular files can be read; directories count as "not found".
        if not file_path.is_file():
            logger.warning(f"[StoryWriter] Referenced scene audio not found on disk: {file_path}")
            return None

        return file_path.read_bytes()
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to load reference audio for video gen: {exc}")
        return None
|
||||
|
||||
|
||||
def resolve_media_file(base_dir: Path, filename: str) -> Path:
|
||||
"""
|
||||
Returns a safe resolved path for a media file stored under base_dir.
|
||||
@@ -62,8 +99,50 @@ def resolve_media_file(base_dir: Path, filename: str) -> Path:
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied")
|
||||
|
||||
if not resolved.exists():
|
||||
alternate = _find_alternate_media_file(base_dir, filename)
|
||||
if alternate:
|
||||
logger.warning(
|
||||
"[StoryWriter] Requested media file '%s' missing; serving closest match '%s'",
|
||||
filename,
|
||||
alternate.name,
|
||||
)
|
||||
return alternate
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail=f"File not found: {filename}")
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
def _find_alternate_media_file(base_dir: Path, filename: str) -> Optional[Path]:
    """
    Attempt to find the most recent media file that matches the original name prefix.

    This helps when files are regenerated with new UUID/hash suffixes but the frontend still
    references an older filename.

    Args:
        base_dir: Directory to search (not searched recursively).
        filename: Requested file name; everything before its last ``_`` is
            the prefix a candidate must share, and the extension must match.

    Returns:
        The matching file with the newest modification time, or None when the
        name has no extension or no ``_``, nothing matches, or the search fails.
    """
    # Local import keeps this block self-contained.
    from glob import escape as _glob_escape

    try:
        base_dir = base_dir.resolve()
    except Exception:
        return None

    requested = Path(filename)
    stem, suffix = requested.stem, requested.suffix
    if not suffix or "_" not in stem:
        return None

    prefix = stem.rsplit("_", 1)[0]
    # The name comes from the request: escape glob metacharacters so that a
    # name like "*_x.mp4" cannot match every file in the directory.
    pattern = f"{_glob_escape(prefix)}_*{_glob_escape(suffix)}"

    try:
        candidates = [p for p in base_dir.glob(pattern) if p.is_file()]
        return max(candidates, key=lambda p: p.stat().st_mtime, default=None)
    except Exception as exc:
        logger.debug(f"[StoryWriter] Failed to search alternate media files for {filename}: {exc}")
        return None
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ Provides endpoints for subscription management and usage monitoring.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc, func
|
||||
from typing import Dict, Any, Optional, List
|
||||
@@ -116,6 +117,7 @@ async def get_subscription_plans(
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
@@ -134,7 +136,7 @@ async def get_subscription_plans(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
|
||||
logger.warning("Missing column detected in subscription plans query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
@@ -241,6 +243,7 @@ async def get_user_subscription(
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -340,6 +343,7 @@ async def get_subscription_status(
|
||||
"stability_calls": free_plan.stability_calls_limit,
|
||||
"video_calls": getattr(free_plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(free_plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(free_plan, 'audio_calls_limit', 0),
|
||||
"monthly_cost": free_plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -405,7 +409,7 @@ async def get_subscription_status(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str):
|
||||
if 'no such column' in error_str and ('exa_calls_limit' in error_str or 'video_calls_limit' in error_str or 'image_edit_calls_limit' in error_str or 'audio_calls_limit' in error_str):
|
||||
# Try to fix schema and retry once
|
||||
logger.warning("Missing column detected in subscription status query, attempting schema fix...")
|
||||
try:
|
||||
@@ -499,6 +503,7 @@ async def get_subscription_status(
|
||||
"stability_calls": plan.stability_calls_limit,
|
||||
"video_calls": getattr(plan, 'video_calls_limit', 0),
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0),
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0),
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
}
|
||||
@@ -988,7 +993,7 @@ async def get_dashboard_data(
|
||||
|
||||
except (sqlite3.OperationalError, Exception) as e:
|
||||
error_str = str(e).lower()
|
||||
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str):
|
||||
if 'no such column' in error_str and ('exa_calls' in error_str or 'exa_cost' in error_str or 'video_calls' in error_str or 'video_cost' in error_str or 'image_edit_calls' in error_str or 'image_edit_cost' in error_str or 'audio_calls' in error_str or 'audio_cost' in error_str):
|
||||
logger.warning("Missing column detected in dashboard query, attempting schema fix...")
|
||||
try:
|
||||
import services.subscription.schema_utils as schema_utils
|
||||
@@ -1271,4 +1276,235 @@ async def get_usage_logs(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting usage logs: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get usage logs: {str(e)}")
|
||||
|
||||
|
||||
class PreflightOperationRequest(BaseModel):
    """One operation the client wants validated and priced before running it."""

    # Provider identifier as sent by the client (e.g. "video", "audio").
    provider: str
    # Optional model name, used to look up pricing.
    model: Optional[str] = None
    # Requested size of the operation; treated as 0 when omitted.
    tokens_requested: Optional[int] = 0
    # Free-form label for the operation, echoed back in the result.
    operation_type: str
    # Display name override for the provider; falls back to `provider`.
    actual_provider_name: Optional[str] = None
|
||||
|
||||
|
||||
class PreflightCheckRequest(BaseModel):
    """Body of the pre-flight check endpoint: the batch of operations to validate."""

    # Operations are validated together and priced individually.
    operations: List[PreflightOperationRequest]
|
||||
|
||||
|
||||
@router.post("/preflight-check")
async def preflight_check(
    request: PreflightCheckRequest,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user)
) -> Dict[str, Any]:
    """
    Pre-flight check for operations with cost estimation.

    - Validates if operations are allowed based on subscription limits
    - Estimates cost for operations
    - Returns usage information and remaining quota

    Operations with an unknown provider are skipped with a warning. No
    caching is implemented here; the ``cached`` field is always False.

    Returns:
        ``{"success": True, "data": {...}}`` where ``data`` holds
        ``can_proceed``, ``estimated_cost``/``total_cost``, per-operation
        results and an overall video usage summary. A ``remaining`` value of
        ``None`` means the limit is unlimited (limit == 0).

    Raises:
        HTTPException: 401 without a user id, 400 when no operation is
            valid, 500 on any unexpected failure.
    """

    def _remaining(limit, current):
        # None marks "unlimited": float('inf') cannot be encoded by the JSON
        # response (json.dumps with allow_nan=False) and would turn into a 500.
        return max(0, limit - current) if limit > 0 else None

    try:
        raw_user_id = current_user.get('id')
        # str(None) would yield the truthy string "None" and bypass the check.
        user_id = '' if raw_user_id is None else str(raw_user_id)
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        # Ensure schema columns exist (best effort; the check itself may fail).
        try:
            ensure_subscription_plan_columns(db)
            ensure_usage_summaries_columns(db)
        except Exception as schema_err:
            logger.warning(f"Schema check failed: {schema_err}")

        pricing_service = PricingService(db)

        # Convert request operations to internal format. The requested model is
        # kept in a parallel list so it stays aligned when operations are skipped.
        operations_to_validate = []
        requested_models = []
        for op in request.operations:
            try:
                # Map provider string to APIProvider enum
                provider_str = op.provider.lower()
                if provider_str == "huggingface":
                    provider_enum = APIProvider.MISTRAL  # Maps to HuggingFace
                elif provider_str == "video":
                    provider_enum = APIProvider.VIDEO
                elif provider_str == "image_edit":
                    provider_enum = APIProvider.IMAGE_EDIT
                elif provider_str == "stability":
                    provider_enum = APIProvider.STABILITY
                elif provider_str == "audio":
                    provider_enum = APIProvider.AUDIO
                else:
                    try:
                        provider_enum = APIProvider(provider_str)
                    except ValueError:
                        logger.warning(f"Unknown provider: {provider_str}, skipping")
                        continue

                operations_to_validate.append({
                    'provider': provider_enum,
                    'tokens_requested': op.tokens_requested or 0,
                    'actual_provider_name': op.actual_provider_name or op.provider,
                    'operation_type': op.operation_type
                })
                requested_models.append(op.model)
            except Exception as e:
                logger.warning(f"Error processing operation {op.operation_type}: {e}")
                continue

        if not operations_to_validate:
            raise HTTPException(status_code=400, detail="No valid operations provided")

        # Perform pre-flight validation
        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=operations_to_validate
        )

        # Limits and the current usage row do not depend on the operation:
        # fetch them once and reuse them below.
        limits = pricing_service.get_user_limits(user_id)
        usage_summary = None
        if limits:
            usage_summary = db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == pricing_service.get_current_billing_period(user_id)
            ).first()

        # Get pricing and cost estimation for each operation
        operation_results = []
        total_cost = 0.0
        flat_rate_providers = (APIProvider.VIDEO, APIProvider.IMAGE_EDIT, APIProvider.STABILITY)

        for op, model_name in zip(operations_to_validate, requested_models):
            provider = op['provider']
            tokens_requested = op['tokens_requested']
            op_result = {
                'provider': op['actual_provider_name'],
                'operation_type': op['operation_type'],
                'cost': 0.0,
                'allowed': can_proceed,
                'limit_info': None,
                'message': None
            }

            # Pricing is only looked up when the client named a model.
            if model_name:
                pricing_info = pricing_service.get_pricing_for_provider_model(provider, model_name)

                if pricing_info:
                    if provider in flat_rate_providers:
                        cost = pricing_info.get('cost_per_request', 0.0) or pricing_info.get('cost_per_image', 0.0) or 0.0
                    elif provider == APIProvider.AUDIO or tokens_requested > 0:
                        # Audio is priced per character (1 character = 1 token);
                        # LLM calls get a rough input-token estimate.
                        cost = (pricing_info.get('cost_per_input_token', 0.0) or 0.0) * (tokens_requested / 1000.0)
                    else:
                        cost = pricing_info.get('cost_per_request', 0.0) or 0.0

                    op_result['cost'] = round(cost, 4)
                    total_cost += cost
                else:
                    # Use default cost if pricing not found
                    if provider == APIProvider.VIDEO:
                        op_result['cost'] = 0.10  # Default video cost
                        total_cost += 0.10
                    elif provider == APIProvider.IMAGE_EDIT:
                        op_result['cost'] = 0.05  # Default image edit cost
                        total_cost += 0.05
                    elif provider == APIProvider.STABILITY:
                        op_result['cost'] = 0.04  # Default image generation cost
                        total_cost += 0.04
                    elif provider == APIProvider.AUDIO:
                        # Default audio cost: $0.05 per 1,000 characters
                        cost = (tokens_requested / 1000.0) * 0.05
                        op_result['cost'] = round(cost, 4)
                        total_cost += cost

            # Get limit information
            if error_details and not can_proceed:
                usage_info = error_details.get('usage_info', {})
                if usage_info:
                    op_result['message'] = message
                    op_result['limit_info'] = {
                        'current_usage': usage_info.get('current_usage', 0),
                        'limit': usage_info.get('limit', 0),
                        'remaining': max(0, usage_info.get('limit', 0) - usage_info.get('current_usage', 0))
                    }
            elif limits and usage_summary:
                # Current usage for this provider
                if provider == APIProvider.VIDEO:
                    usage_attr, limit_key = 'video_calls', 'video_calls'
                elif provider == APIProvider.IMAGE_EDIT:
                    usage_attr, limit_key = 'image_edit_calls', 'image_edit_calls'
                elif provider == APIProvider.STABILITY:
                    usage_attr, limit_key = 'stability_calls', 'stability_calls'
                elif provider == APIProvider.AUDIO:
                    usage_attr, limit_key = 'audio_calls', 'audio_calls'
                else:
                    # For LLM providers, use token limits
                    usage_attr = limit_key = f"{provider.value}_tokens"

                current = getattr(usage_summary, usage_attr, 0) or 0
                limit = limits['limits'].get(limit_key, 0) or 0
                op_result['limit_info'] = {
                    'current_usage': current,
                    'limit': limit,
                    'remaining': _remaining(limit, current)
                }

            operation_results.append(op_result)

        response_data = {
            'can_proceed': can_proceed,
            'estimated_cost': round(total_cost, 4),
            'operations': operation_results,
            'total_cost': round(total_cost, 4),
            'usage_summary': None,
            'cached': False  # TODO: Track if result was cached
        }

        if usage_summary and limits:
            # The overall summary reports video limits.
            video_current = getattr(usage_summary, 'video_calls', 0) or 0
            video_limit = limits['limits'].get('video_calls', 0) or 0

            response_data['usage_summary'] = {
                'current_calls': video_current,
                'limit': video_limit,
                'remaining': _remaining(video_limit, video_current)
            }

        return {
            "success": True,
            "data": response_data
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in pre-flight check: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")
|
||||
@@ -97,7 +97,14 @@ def setup_clean_logging():
|
||||
def video_generation_filter(record):
    """Return True when a log record belongs to the video-generation pipeline.

    Args:
        record: Mapping-style log record; the "message", "name" and
            "extra" keys are read, each with a safe default when absent.

    Returns:
        bool: True if the message carries one of the video tags, the logger
        name is one of the video-generation modules, or the record's
        ``extra["service"]`` is ``"video_generation_service"``.
    """
    msg = record.get("message", "")
    name = record.get("name", "")
    service = record.get("extra", {}).get("service")
    # A stale early return from the previous implementation used to sit here,
    # which made every condition after the first two unreachable.
    return (
        "[StoryVideoGeneration]" in msg
        or "services.story_writer.video_generation_service" in name
        or "[video_gen]" in msg
        or service == "video_generation_service"
        or "services.llm_providers.main_video_generation" in name
    )
|
||||
logger.add(
|
||||
sys.stdout.write,
|
||||
level="INFO",
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException, Depends, status
|
||||
from fastapi import HTTPException, Depends, status, Request, Query
|
||||
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
||||
from loguru import logger
|
||||
from dotenv import load_dotenv
|
||||
@@ -259,3 +259,63 @@ async def get_optional_user(
|
||||
except Exception as e:
|
||||
logger.warning(f"Optional authentication failed: {e}")
|
||||
return None
|
||||
|
||||
async def get_current_user_with_query_token(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict[str, Any]:
    """Authenticate the caller from the Authorization header or a ``token`` query parameter.

    Intended for media endpoints (audio, video, images) fetched by HTML
    elements such as <audio> or <img>, which cannot attach custom headers.
    The header wins when present; the query parameter is consulted only
    when no header credentials were supplied.

    Args:
        request: FastAPI request object (source of the ``token`` query parameter).
        credentials: HTTP authorization credentials parsed from the header, if any.

    Returns:
        The user dictionary produced by ``clerk_auth.verify_token``.

    Raises:
        HTTPException: 401 when no token is supplied, when verification
            yields no user, or when verification raises unexpectedly.
    """

    def _unauthorized(detail: str) -> HTTPException:
        # Every failure path answers with the same 401 + Bearer challenge.
        return HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=detail,
            headers={"WWW-Authenticate": "Bearer"},
        )

    try:
        if credentials:
            # Header credentials take precedence over the query string.
            candidate = credentials.credentials
        else:
            candidate = request.query_params.get("token") or None

        if not candidate:
            logger.warning("No credentials provided (neither header nor query parameter)")
            raise _unauthorized("Not authenticated")

        verified_user = await clerk_auth.verify_token(candidate)
        if verified_user:
            return verified_user

        # Token verification failed (likely expired) - debug level keeps the log quiet
        logger.debug("Token verification failed (likely expired token)")
        raise _unauthorized("Authentication failed")

    except HTTPException:
        # Already a well-formed 401; let it propagate untouched.
        raise
    except Exception as e:
        logger.error(f"Authentication error: {e}")
        raise _unauthorized("Authentication failed")
|
||||
|
||||
@@ -207,6 +207,32 @@ class StoryImageGenerationResponse(BaseModel):
|
||||
task_id: Optional[str] = Field(None, description="Task ID for async operations")
|
||||
|
||||
|
||||
class RegenerateImageRequest(BaseModel):
    """Request model for regenerating a single scene image with a direct prompt."""

    # --- Required: which scene, and the literal prompt to render ---
    scene_number: int = Field(
        default=...,
        description="Scene number to regenerate image for",
    )
    scene_title: str = Field(default=..., description="Scene title")
    prompt: str = Field(
        default=...,
        description="Direct prompt to use for image generation (no AI prompt generation)",
    )

    # --- Optional rendering knobs ---
    provider: Optional[str] = Field(
        default=None,
        description="Image generation provider (gemini, huggingface, stability)",
    )
    width: Optional[int] = Field(default=1024, description="Image width")
    height: Optional[int] = Field(default=1024, description="Image height")
    model: Optional[str] = Field(
        default=None,
        description="Model to use for image generation",
    )
|
||||
|
||||
|
||||
class RegenerateImageResponse(BaseModel):
    """Response model for regenerated image."""

    # --- Scene identity ---
    scene_number: int = Field(default=..., description="Scene number")
    scene_title: str = Field(default=..., description="Scene title")

    # --- Where the regenerated image lives and its dimensions ---
    image_filename: str = Field(default=..., description="Generated image filename")
    image_url: str = Field(default=..., description="Image URL")
    width: int = Field(default=..., description="Image width")
    height: int = Field(default=..., description="Image height")

    # --- Generation provenance ---
    provider: str = Field(default=..., description="Provider used")
    model: Optional[str] = Field(default=None, description="Model used")
    seed: Optional[int] = Field(default=None, description="Seed used")

    # --- Outcome ---
    success: bool = Field(
        default=True,
        description="Whether the generation was successful",
    )
    error: Optional[str] = Field(
        default=None,
        description="Error message if generation failed",
    )
|
||||
|
||||
|
||||
class StoryAudioGenerationRequest(BaseModel):
|
||||
"""Request model for audio generation."""
|
||||
scenes: List[StoryScene] = Field(..., description="List of scenes to generate audio for")
|
||||
@@ -234,11 +260,41 @@ class StoryAudioGenerationResponse(BaseModel):
|
||||
task_id: Optional[str] = Field(None, description="Task ID for async operations")
|
||||
|
||||
|
||||
class GenerateAIAudioRequest(BaseModel):
    """Request model for generating AI audio for a single scene."""
    scene_number: int = Field(..., description="Scene number to generate audio for")
    scene_title: str = Field(..., description="Scene title")
    text: str = Field(..., description="Text to convert to speech")
    voice_id: Optional[str] = Field("Wise_Woman", description="Voice ID for AI audio generation")
    # The ge/le bounds below enforce the ranges already advertised in each
    # description, so an out-of-range value is rejected at validation time
    # instead of failing later inside the paid provider call.
    speed: Optional[float] = Field(1.0, ge=0.5, le=2.0, description="Speech speed (0.5-2.0)")
    volume: Optional[float] = Field(1.0, ge=0.1, le=10.0, description="Speech volume (0.1-10.0)")
    pitch: Optional[float] = Field(0.0, ge=-12.0, le=12.0, description="Speech pitch (-12 to 12)")
    emotion: Optional[str] = Field("happy", description="Emotion for speech")
|
||||
|
||||
|
||||
class GenerateAIAudioResponse(BaseModel):
    """Response model for AI audio generation."""

    # --- Scene identity ---
    scene_number: int = Field(default=..., description="Scene number")
    scene_title: str = Field(default=..., description="Scene title")

    # --- Stored audio artifact ---
    audio_filename: str = Field(default=..., description="Generated audio filename")
    audio_url: str = Field(default=..., description="Audio URL")

    # --- Generation provenance ---
    provider: str = Field(default=..., description="Provider used (wavespeed)")
    model: str = Field(default=..., description="Model used (minimax/speech-02-hd)")
    voice_id: str = Field(default=..., description="Voice ID used")

    # --- Size and billing figures ---
    text_length: int = Field(default=..., description="Number of characters in text")
    file_size: int = Field(default=..., description="Audio file size in bytes")
    cost: float = Field(default=..., description="Cost of generation")

    # --- Outcome ---
    success: bool = Field(
        default=True,
        description="Whether the generation was successful",
    )
    error: Optional[str] = Field(
        default=None,
        description="Error message if generation failed",
    )
|
||||
|
||||
|
||||
class StoryVideoGenerationRequest(BaseModel):
|
||||
"""Request model for video generation."""
|
||||
scenes: List[StoryScene] = Field(..., description="List of scenes to generate video for")
|
||||
image_urls: List[str] = Field(..., description="List of image URLs for each scene")
|
||||
audio_urls: List[str] = Field(..., description="List of audio URLs for each scene")
|
||||
video_urls: Optional[List[Optional[str]]] = Field(None, description="Optional list of animated video URLs (preferred over images)")
|
||||
ai_audio_urls: Optional[List[Optional[str]]] = Field(None, description="Optional list of AI audio URLs (preferred over free audio)")
|
||||
story_title: Optional[str] = Field(default="Story", description="Title of the story")
|
||||
fps: Optional[int] = Field(default=24, description="Frames per second for video")
|
||||
transition_duration: Optional[float] = Field(default=0.5, description="Duration of transitions between scenes")
|
||||
@@ -260,3 +316,39 @@ class StoryVideoGenerationResponse(BaseModel):
|
||||
video: StoryVideoResult = Field(..., description="Generated video")
|
||||
success: bool = Field(default=True, description="Whether the generation was successful")
|
||||
task_id: Optional[str] = Field(None, description="Task ID for async operations")
|
||||
|
||||
|
||||
class AnimateSceneRequest(BaseModel):
    """Request model for per-scene animation preview."""

    # Scene to animate plus the free-form payloads that accompany it.
    scene_number: int = Field(default=..., description="Scene number to animate")
    scene_data: Dict[str, Any] = Field(
        default=...,
        description="Scene data payload",
    )
    story_context: Dict[str, Any] = Field(
        default=...,
        description="Story-wide context used for prompts",
    )
    # Source still image and requested clip length.
    image_url: str = Field(
        default=...,
        description="Relative URL to the generated scene image",
    )
    duration: int = Field(
        default=5,
        description="Animation duration (5 or 10 seconds)",
    )
|
||||
|
||||
|
||||
class AnimateSceneVoiceoverRequest(AnimateSceneRequest):
    """Request model for WaveSpeed InfiniteTalk animation."""

    # Adds the audio track and output options on top of the base animation fields.
    audio_url: str = Field(
        default=...,
        description="Relative URL to the generated scene audio",
    )
    resolution: Optional[str] = Field(
        default="720p",
        description="Output resolution ('480p' or '720p')",
    )
    prompt: Optional[str] = Field(
        default=None,
        description="Optional positive prompt override",
    )
|
||||
|
||||
|
||||
class AnimateSceneResponse(BaseModel):
    """Response model for scene animation preview."""

    # --- Outcome and scene identity ---
    success: bool = Field(
        default=True,
        description="Whether the animation succeeded",
    )
    scene_number: int = Field(default=..., description="Scene number animated")

    # --- Stored video artifact ---
    video_filename: str = Field(default=..., description="Stored video filename")
    video_url: str = Field(
        default=...,
        description="API URL to access the animated video",
    )
    duration: int = Field(default=..., description="Duration of the animation")

    # --- Billing and provenance ---
    cost: float = Field(default=..., description="Cost billed for the animation")
    prompt_used: str = Field(
        default=...,
        description="Animation prompt passed to the model",
    )
    provider: str = Field(
        default="wavespeed",
        description="Underlying provider used",
    )
    prediction_id: Optional[str] = Field(
        default=None,
        description="WaveSpeed prediction ID for resume operations",
    )
|
||||
|
||||
|
||||
class ResumeSceneAnimationRequest(BaseModel):
    """Request model to resume scene animation download."""

    # Identifies the provider-side job and the scene it belongs to.
    prediction_id: str = Field(
        default=...,
        description="WaveSpeed prediction ID to resume from",
    )
    scene_number: int = Field(
        default=...,
        description="Scene number being resumed",
    )
    duration: int = Field(
        default=5,
        description="Animation duration (5 or 10 seconds)",
    )
|
||||
|
||||
@@ -37,6 +37,7 @@ class APIProvider(enum.Enum):
|
||||
EXA = "exa"
|
||||
VIDEO = "video"
|
||||
IMAGE_EDIT = "image_edit"
|
||||
AUDIO = "audio"
|
||||
|
||||
class BillingCycle(enum.Enum):
|
||||
MONTHLY = "monthly"
|
||||
@@ -72,6 +73,7 @@ class SubscriptionPlan(Base):
|
||||
exa_calls_limit = Column(Integer, default=0) # Exa neural search
|
||||
video_calls_limit = Column(Integer, default=0) # AI video generation
|
||||
image_edit_calls_limit = Column(Integer, default=0) # AI image editing
|
||||
audio_calls_limit = Column(Integer, default=0) # AI audio generation (text-to-speech)
|
||||
|
||||
# Token Limits (for LLM providers)
|
||||
gemini_tokens_limit = Column(Integer, default=0)
|
||||
@@ -191,6 +193,7 @@ class UsageSummary(Base):
|
||||
exa_calls = Column(Integer, default=0)
|
||||
video_calls = Column(Integer, default=0) # AI video generation
|
||||
image_edit_calls = Column(Integer, default=0) # AI image editing
|
||||
audio_calls = Column(Integer, default=0) # AI audio generation (text-to-speech)
|
||||
|
||||
# Token Usage
|
||||
gemini_tokens = Column(Integer, default=0)
|
||||
@@ -211,6 +214,7 @@ class UsageSummary(Base):
|
||||
exa_cost = Column(Float, default=0.0)
|
||||
video_cost = Column(Float, default=0.0) # AI video generation
|
||||
image_edit_cost = Column(Float, default=0.0) # AI image editing
|
||||
audio_cost = Column(Float, default=0.0) # AI audio generation (text-to-speech)
|
||||
|
||||
# Totals
|
||||
total_calls = Column(Integer, default=0)
|
||||
|
||||
301
backend/services/llm_providers/main_audio_generation.py
Normal file
301
backend/services/llm_providers/main_audio_generation.py
Normal file
@@ -0,0 +1,301 @@
|
||||
"""
|
||||
Main Audio Generation Service for ALwrity Backend.
|
||||
|
||||
This service provides AI-powered text-to-speech functionality using WaveSpeed Minimax Speech 02 HD.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from typing import Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("audio_generation")
|
||||
|
||||
|
||||
class AudioGenerationResult:
    """Plain value holder describing one finished text-to-speech generation."""

    def __init__(
        self,
        audio_bytes: bytes,
        provider: str,
        model: str,
        voice_id: str,
        text_length: int,
        file_size: int,
    ):
        """Store the generated audio together with its provenance and size figures.

        Args:
            audio_bytes: Raw audio payload.
            provider: Name of the provider that produced the audio.
            model: Model identifier used for the generation.
            voice_id: Voice identifier used for the generation.
            text_length: Character count of the source text.
            file_size: Size value supplied by the caller for the audio payload.
        """
        # The payload itself and where it came from.
        self.audio_bytes, self.provider, self.model = audio_bytes, provider, model
        # Voice selection and size figures.
        self.voice_id, self.text_length, self.file_size = voice_id, text_length, file_size
|
||||
|
||||
|
||||
def generate_audio(
    text: str,
    voice_id: str = "Wise_Woman",
    speed: float = 1.0,
    volume: float = 1.0,
    pitch: float = 0.0,
    emotion: str = "happy",
    user_id: Optional[str] = None,
    **kwargs
) -> AudioGenerationResult:
    """
    Generate audio using AI text-to-speech with subscription tracking.

    The call runs in three phases: (1) a strict subscription/limit check,
    (2) the WaveSpeed speech request, and (3) best-effort usage tracking.
    A failure in phase 3 is logged and rolled back but never fails the request.

    Args:
        text: Text to convert to speech (max 10000 characters)
        voice_id: Voice ID (default: "Wise_Woman")
        speed: Speech speed (0.5-2.0, default: 1.0)
        volume: Speech volume (0.1-10.0, default: 1.0)
        pitch: Speech pitch (-12 to 12, default: 0.0)
        emotion: Emotion (default: "happy")
        user_id: User ID for subscription checking (required)
        **kwargs: Additional parameters (sample_rate, bitrate, format, etc.)

    Returns:
        AudioGenerationResult: Generated audio result

    Raises:
        RuntimeError: If user_id is missing or the subscription check itself fails.
        HTTPException: 429 when the subscription limit is exceeded, 502 when the
            WaveSpeed call fails, 500 for any other unexpected error.
    """
    try:
        logger.info("[audio_gen] Starting audio generation")
        logger.debug(f"[audio_gen] Text length: {len(text)} characters, voice: {voice_id}")

        # SUBSCRIPTION CHECK - Required and strict enforcement
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        # Calculate cost based on character count (every character is 1 token)
        # Pricing: $0.05 per 1,000 characters
        character_count = len(text)
        cost_per_1000_chars = 0.05
        estimated_cost = (character_count / 1000.0) * cost_per_1000_chars

        try:
            from services.database import get_db
            from services.subscription import PricingService
            from models.subscription_models import UsageSummary, APIProvider

            db = next(get_db())
            try:
                pricing_service = PricingService(db)

                # Check limits using sync method from pricing service (strict enforcement)
                # Use AUDIO provider for audio generation
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=character_count,  # Use character count as "tokens" for audio
                    actual_provider_name="wavespeed"  # Actual provider is WaveSpeed
                )

                if not can_proceed:
                    logger.warning(f"[audio_gen] Subscription limit exceeded for user {user_id}: {message}")
                    error_detail = {
                        'error': message,
                        'message': message,
                        'provider': 'wavespeed',
                        'usage_info': usage_info if usage_info else {}
                    }
                    raise HTTPException(status_code=429, detail=error_detail)
            finally:
                db.close()
        except HTTPException:
            raise
        except RuntimeError:
            raise
        except Exception as sub_error:
            logger.error(f"[audio_gen] Subscription check failed for user {user_id}: {sub_error}")
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        # Generate audio using WaveSpeed
        try:
            client = WaveSpeedClient()
            audio_bytes = client.generate_speech(
                text=text,
                voice_id=voice_id,
                speed=speed,
                volume=volume,
                pitch=pitch,
                emotion=emotion,
                enable_sync_mode=True,
                **kwargs
            )

            logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes")

        except HTTPException:
            raise
        except Exception as api_error:
            logger.error(f"[audio_gen] Audio generation API failed: {api_error}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "Audio generation failed",
                    "message": str(api_error)
                }
            )

        # TRACK USAGE after successful API call
        if audio_bytes:
            logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
            try:
                db_track = next(get_db())
                try:
                    from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
                    from services.subscription import PricingService

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    # Get or create usage summary
                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()

                    if not summary:
                        summary = UsageSummary(
                            user_id=user_id,
                            billing_period=current_period
                        )
                        db_track.add(summary)
                        db_track.flush()

                    # Get current values before update
                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0

                    # Update audio calls and cost
                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + estimated_cost

                    # Use direct SQL UPDATE for dynamic attributes.
                    # The import is aliased on purpose: a bare
                    # `from sqlalchemy import text` rebinds this function's
                    # `text` parameter, which broke `text.encode(...)` below
                    # and silently rolled back every tracking attempt.
                    from sqlalchemy import text as sql_text
                    update_query = sql_text("""
                        UPDATE usage_summaries
                        SET audio_calls = :new_calls,
                            audio_cost = :new_cost
                        WHERE user_id = :user_id AND billing_period = :period
                    """)
                    db_track.execute(update_query, {
                        'new_calls': new_calls,
                        'new_cost': new_cost,
                        'user_id': user_id,
                        'period': current_period
                    })

                    # Update total cost
                    summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    # Create usage log
                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint="/audio-generation/wavespeed",
                        method="POST",
                        model_used="minimax/speech-02-hd",
                        tokens_input=character_count,
                        tokens_output=0,
                        tokens_total=character_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=estimated_cost,
                        response_time=0.0,
                        status_code=200,
                        request_size=len(text.encode("utf-8")),
                        response_size=len(audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)

                    # Get plan details for unified log
                    limits = pricing.get_user_limits(user_id)
                    plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
                    tier = limits.get('tier', 'unknown') if limits else 'unknown'
                    audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
                    # Only show ∞ for Enterprise tier when limit is 0 (unlimited)
                    audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'

                    # Get related stats for unified log
                    current_image_calls = getattr(summary, "stability_calls", 0) or 0
                    image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
                    current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
                    image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
                    current_video_calls = getattr(summary, "video_calls", 0) or 0
                    video_limit = limits['limits'].get("video_calls", 0) if limits else 0

                    db_track.commit()
                    logger.info(f"[audio_gen] ✅ Successfully tracked usage: user {user_id} -> audio -> {new_calls} calls, ${estimated_cost:.4f}")

                    # UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
                    print(f"""
[SUBSCRIPTION] Audio Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: wavespeed
├─ Actual Provider: wavespeed
├─ Model: minimax/speech-02-hd
├─ Voice: {voice_id}
├─ Calls: {current_calls_before} → {new_calls} / {audio_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Characters: {character_count}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
                    sys.stdout.flush()

                except Exception as track_error:
                    # Tracking is best-effort: the audio was already generated,
                    # so log, undo the partial transaction and carry on.
                    logger.error(f"[audio_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[audio_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return AudioGenerationResult(
            audio_bytes=audio_bytes,
            provider="wavespeed",
            model="minimax/speech-02-hd",
            voice_id=voice_id,
            text_length=character_count,
            file_size=len(audio_bytes),
        )

    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[audio_gen] Error generating audio: {e}")
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Audio generation failed",
                "message": str(e)
            }
        )
|
||||
|
||||
@@ -515,6 +515,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
|
||||
import sys
|
||||
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
|
||||
@@ -571,6 +577,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
@@ -819,6 +827,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
@@ -838,6 +852,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
|
||||
@@ -10,6 +10,7 @@ from __future__ import annotations
|
||||
import os
|
||||
import base64
|
||||
import io
|
||||
import sys
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -22,11 +23,11 @@ except ImportError:
|
||||
InferenceClient = None
|
||||
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
from services.subscription import PricingService
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("video_generation_service")
|
||||
|
||||
|
||||
class VideoProviderNotImplemented(Exception):
    """Marker exception for a video provider that has no implementation."""
|
||||
|
||||
@@ -48,44 +49,80 @@ def _get_api_key(provider: str) -> Optional[str]:
|
||||
def _coerce_video_bytes(output: Any) -> bytes:
    """
    Normalizes the different return shapes that huggingface_hub may emit for video tasks.
    According to HF docs, text_to_video() should return bytes directly.

    Shapes are tried in this order: raw bytes-like, file-like (``read()``),
    an object exposing ``.video`` or ``.bytes``, a dict (``"video"`` first,
    then a few other common keys), a string (data URI or bare base64), and
    finally anything implementing ``__bytes__``.

    Args:
        output: Whatever the inference client returned.

    Returns:
        bytes: The raw video payload.

    Raises:
        TypeError: If no supported shape yields bytes, or a string payload
            cannot be base64-decoded.
    """
    logger.debug(f"[video_gen] _coerce_video_bytes received type: {type(output)}")

    # Most common case: bytes directly
    if isinstance(output, (bytes, bytearray, memoryview)):
        logger.debug(f"[video_gen] Output is bytes: {len(output)} bytes")
        return bytes(output)

    # Handle file-like objects
    if hasattr(output, "read"):
        logger.debug("[video_gen] Output has read() method, reading...")
        data = output.read()
        if isinstance(data, (bytes, bytearray, memoryview)):
            return bytes(data)
        raise TypeError(f"File-like object returned non-bytes: {type(data)}")

    # Objects with direct attribute access
    if hasattr(output, "video"):
        logger.debug("[video_gen] Output has 'video' attribute")
        data = getattr(output, "video")
        if isinstance(data, (bytes, bytearray, memoryview)):
            return bytes(data)
        if hasattr(data, "read"):
            return bytes(data.read())

    if hasattr(output, "bytes"):
        logger.debug("[video_gen] Output has 'bytes' attribute")
        data = getattr(output, "bytes")
        if isinstance(data, (bytes, bytearray, memoryview)):
            return bytes(data)
        if hasattr(data, "read"):
            return bytes(data.read())

    # Dict handling - but this shouldn't happen with text_to_video()
    if isinstance(output, dict):
        logger.warning(f"[video_gen] Received dict output (unexpected): keys={list(output.keys())}")
        # Try to get video key safely - use .get() to avoid KeyError
        data = output.get("video")
        if data is not None:
            if isinstance(data, (bytes, bytearray, memoryview)):
                return bytes(data)
            if hasattr(data, "read"):
                return bytes(data.read())
        # Try other common keys
        for key in ["data", "content", "file", "result", "output"]:
            data = output.get(key)
            if data is not None:
                if isinstance(data, (bytes, bytearray, memoryview)):
                    return bytes(data)
                if hasattr(data, "read"):
                    return bytes(data.read())
        raise TypeError(f"Dict output has no recognized video key. Keys: {list(output.keys())}")

    # String handling (base64)
    if isinstance(output, str):
        logger.debug("[video_gen] Output is string, attempting base64 decode")
        if output.startswith("data:"):
            # Data URI: everything after the first comma is the payload.
            _, encoded = output.split(",", 1)
            return base64.b64decode(encoded)
        try:
            return base64.b64decode(output)
        except Exception as exc:
            raise TypeError(f"Unable to decode string video payload: {exc}") from exc

    # Fallback: try to use output directly
    logger.warning(f"[video_gen] Unexpected output type: {type(output)}, attempting direct conversion")
    try:
        if hasattr(output, "__bytes__"):
            return bytes(output)
    except Exception:
        # Conversion is a last resort; fall through to the explicit TypeError.
        pass

    raise TypeError(f"Unsupported video payload type: {type(output)}. Output: {str(output)[:200]}")
|
||||
|
||||
|
||||
def _generate_with_huggingface(
|
||||
@@ -96,7 +133,6 @@ def _generate_with_huggingface(
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
model: str = "tencent/HunyuanVideo",
|
||||
input_image_bytes: Optional[bytes] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generates video bytes using Hugging Face's InferenceClient.
|
||||
@@ -109,7 +145,6 @@ def _generate_with_huggingface(
|
||||
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
|
||||
|
||||
client = InferenceClient(
|
||||
model=model,
|
||||
provider="fal-ai",
|
||||
token=token,
|
||||
)
|
||||
@@ -126,26 +161,25 @@ def _generate_with_huggingface(
|
||||
params["seed"] = seed
|
||||
|
||||
logger.info(
|
||||
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=%s",
|
||||
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=text-to-video",
|
||||
model,
|
||||
num_frames,
|
||||
num_inference_steps,
|
||||
"image-to-video" if input_image_bytes else "text-to-video",
|
||||
)
|
||||
|
||||
try:
|
||||
call_kwargs = {**params, "model": model}
|
||||
if input_image_bytes:
|
||||
video_output = client.image_to_video(
|
||||
image=input_image_bytes,
|
||||
prompt=prompt,
|
||||
**call_kwargs,
|
||||
)
|
||||
else:
|
||||
video_output = client.text_to_video(
|
||||
prompt,
|
||||
**call_kwargs,
|
||||
)
|
||||
logger.info("[video_gen] Calling client.text_to_video()...")
|
||||
video_output = client.text_to_video(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
**params,
|
||||
)
|
||||
|
||||
logger.info(f"[video_gen] text_to_video() returned type: {type(video_output)}")
|
||||
if isinstance(video_output, dict):
|
||||
logger.info(f"[video_gen] Dict keys: {list(video_output.keys())}")
|
||||
elif hasattr(video_output, "__dict__"):
|
||||
logger.info(f"[video_gen] Object attributes: {dir(video_output)}")
|
||||
|
||||
video_bytes = _coerce_video_bytes(video_output)
|
||||
|
||||
@@ -158,6 +192,15 @@ def _generate_with_huggingface(
|
||||
logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
|
||||
except KeyError as e:
|
||||
error_msg = str(e)
|
||||
logger.error(f"[video_gen] HF KeyError: {error_msg}", exc_info=True)
|
||||
logger.error(f"[video_gen] This suggests the API response format is unexpected. Check logs above for response type.")
|
||||
raise HTTPException(status_code=502, detail={
|
||||
"error": f"Hugging Face API returned unexpected response format: {error_msg}",
|
||||
"error_type": "KeyError",
|
||||
"hint": "The API response may have changed. Check server logs for details."
|
||||
})
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
error_type = type(e).__name__
|
||||
@@ -179,7 +222,6 @@ def ai_video_generate(
|
||||
prompt: str,
|
||||
provider: str = "huggingface",
|
||||
user_id: Optional[str] = None,
|
||||
input_image_bytes: Optional[bytes] = None,
|
||||
**kwargs,
|
||||
) -> bytes:
|
||||
"""
|
||||
@@ -187,7 +229,6 @@ def ai_video_generate(
|
||||
|
||||
- provider: 'huggingface' (default), 'gemini' (veo3 stub), 'openai' (sora stub)
|
||||
- kwargs: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
|
||||
- input_image_bytes: optional bytes for image-to-video flows (uses image as motion anchor)
|
||||
|
||||
Returns raw video bytes (mp4/webm depending on provider).
|
||||
"""
|
||||
@@ -200,7 +241,6 @@ def ai_video_generate(
|
||||
# PRE-FLIGHT VALIDATION: Validate video generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
@@ -227,7 +267,6 @@ def ai_video_generate(
|
||||
if provider == "huggingface":
|
||||
video_bytes = _generate_with_huggingface(
|
||||
prompt=prompt,
|
||||
input_image_bytes=input_image_bytes,
|
||||
**kwargs,
|
||||
)
|
||||
elif provider == "gemini":
|
||||
@@ -237,112 +276,14 @@ def ai_video_generate(
|
||||
else:
|
||||
raise RuntimeError(f"Unknown video provider: {provider}")
|
||||
|
||||
# Track usage AFTER successful generation
|
||||
db_track = next(get_db())
|
||||
try:
|
||||
from models.subscription_models import APIProvider, UsageSummary, APIUsageLog
|
||||
from datetime import datetime
|
||||
from services.subscription import PricingService
|
||||
|
||||
# Create pricing service for tracking (uses same DB session)
|
||||
pricing_service_track = PricingService(db_track)
|
||||
|
||||
# Get current billing period
|
||||
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
usage_summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not usage_summary:
|
||||
usage_summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(usage_summary)
|
||||
db_track.commit()
|
||||
|
||||
# Calculate cost using pricing service
|
||||
cost_info = pricing_service_track.get_pricing_for_provider_model(
|
||||
APIProvider.VIDEO,
|
||||
model_name
|
||||
)
|
||||
cost_per_video = cost_info.get('cost_per_request', 0.10) if cost_info else 0.10
|
||||
|
||||
# Get "before" state for unified log
|
||||
current_video_calls_before = getattr(usage_summary, 'video_calls', 0) or 0
|
||||
current_video_cost = getattr(usage_summary, 'video_cost', 0.0) or 0.0
|
||||
|
||||
# Increment video_calls and track cost
|
||||
new_video_calls = current_video_calls_before + 1
|
||||
usage_summary.video_calls = new_video_calls
|
||||
usage_summary.video_cost = current_video_cost + cost_per_video
|
||||
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
|
||||
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
|
||||
|
||||
# Get plan details for unified log (before commit, in case commit fails)
|
||||
limits = pricing_service_track.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get image and image editing stats for unified log
|
||||
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Create usage log entry for audit trail
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.VIDEO,
|
||||
endpoint=f"/video-generation/{provider}",
|
||||
method="POST",
|
||||
model_used=model_name,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost_per_video,
|
||||
response_time=0.0, # Could track actual time if needed
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode('utf-8')),
|
||||
response_size=len(video_bytes),
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[video_gen] ✅ Successfully tracked usage: user {user_id} -> 1 video call, ${cost_per_video:.4f} cost")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Flush immediately to ensure it's visible in console/logs
|
||||
import sys
|
||||
log_message = f"""
|
||||
[SUBSCRIPTION] Video Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: video
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model_name or 'default'}
|
||||
├─ Calls: {current_video_calls_before} → {new_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
"""
|
||||
print(log_message, flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
# Don't fail video generation if tracking fails - video is already generated
|
||||
finally:
|
||||
db_track.close()
|
||||
|
||||
track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
prompt=prompt,
|
||||
video_bytes=video_bytes,
|
||||
)
|
||||
|
||||
return video_bytes
|
||||
|
||||
except HTTPException:
|
||||
@@ -353,3 +294,139 @@ def ai_video_generate(
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
|
||||
|
||||
def track_video_usage(
    *,
    user_id: str,
    provider: str,
    model_name: str,
    prompt: str,
    video_bytes: bytes,
    cost_override: Optional[float] = None,
) -> Optional[Dict[str, Any]]:
    """
    Track subscription usage for any video generation (text-to-video or image-to-video).

    Increments the user's video counters on the UsageSummary row for the current
    billing period (creating the row when it does not exist), writes an
    APIUsageLog audit entry, commits both, and emits a unified subscription log.

    Parameters:
        user_id (str): User whose usage summary is updated.
        provider (str): Provider name; used in the logged endpoint and log message.
        model_name (str): Model name used for the pricing lookup and the audit entry.
        prompt (str): Generation prompt; only its UTF-8 byte length is recorded.
        video_bytes (bytes): Generated video; only its length is recorded.
        cost_override (Optional[float]): When not None, used as the per-video cost
            instead of the pricing-service value (or the 0.10 fallback).

    Returns:
        Optional[Dict[str, Any]]: A dict with keys previous_calls, current_calls,
        video_limit, video_limit_display, cost_per_video and total_video_cost.
        None when tracking failed; the failure is logged and rolled back rather
        than raised.
    """
    from datetime import datetime

    from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
    from services.database import get_db
    # Imported locally like the other project imports above, so this function does
    # not depend on a module-level PricingService name being present.
    from services.subscription import PricingService

    db_track = next(get_db())
    try:
        logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
        pricing_service_track = PricingService(db_track)
        current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
        logger.debug(f"[video_gen] Billing period: {current_period}")

        usage_summary = (
            db_track.query(UsageSummary)
            .filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period,
            )
            .first()
        )

        if not usage_summary:
            logger.debug(f"[video_gen] Creating new UsageSummary for user={user_id}, period={current_period}")
            usage_summary = UsageSummary(
                user_id=user_id,
                billing_period=current_period,
            )
            db_track.add(usage_summary)
            db_track.commit()
            db_track.refresh(usage_summary)
        else:
            logger.debug(f"[video_gen] Found existing UsageSummary: video_calls={getattr(usage_summary, 'video_calls', 0)}")

        cost_info = pricing_service_track.get_pricing_for_provider_model(
            APIProvider.VIDEO,
            model_name,
        )
        default_cost = 0.10
        if cost_info and cost_info.get("cost_per_request") is not None:
            default_cost = cost_info["cost_per_request"]
        cost_per_video = cost_override if cost_override is not None else default_cost
        logger.debug(f"[video_gen] Cost per video: ${cost_per_video} (override={cost_override}, default={default_cost})")

        current_video_calls_before = getattr(usage_summary, "video_calls", 0) or 0
        current_video_cost = getattr(usage_summary, "video_cost", 0.0) or 0.0
        usage_summary.video_calls = current_video_calls_before + 1
        usage_summary.video_cost = current_video_cost + cost_per_video
        usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
        usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
        # Ensure the object is in the session
        db_track.add(usage_summary)
        logger.debug(f"[video_gen] Updated usage_summary: video_calls={current_video_calls_before} → {usage_summary.video_calls}")

        # The values below only feed the log message and the return dict. They are
        # read None-safely so that a missing "limits" entry or a None limit cannot
        # raise before the commit and discard the usage update.
        limits = pricing_service_track.get_user_limits(user_id) or {}
        plan_limits = limits.get("limits") or {}
        plan_name = limits.get("plan_name", "unknown")
        tier = limits.get("tier", "unknown")
        video_limit = plan_limits.get("video_calls", 0) or 0
        current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
        image_limit = plan_limits.get("stability_calls", 0) or 0
        current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
        image_edit_limit = plan_limits.get("image_edit_calls", 0) or 0
        current_audio_calls = getattr(usage_summary, "audio_calls", 0) or 0
        audio_limit = plan_limits.get("audio_calls", 0) or 0
        # Only show ∞ for Enterprise tier when limit is 0 (unlimited)
        audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'

        usage_log = APIUsageLog(
            user_id=user_id,
            provider=APIProvider.VIDEO,
            endpoint=f"/video-generation/{provider}",
            method="POST",
            model_used=model_name,
            tokens_input=0,
            tokens_output=0,
            tokens_total=0,
            cost_input=0.0,
            cost_output=0.0,
            cost_total=cost_per_video,
            response_time=0.0,
            status_code=200,
            request_size=len(prompt.encode("utf-8")),
            response_size=len(video_bytes),
            billing_period=current_period,
        )
        db_track.add(usage_log)
        logger.debug("[video_gen] Flushing changes before commit...")
        db_track.flush()
        logger.debug("[video_gen] Committing usage tracking changes...")
        db_track.commit()
        db_track.refresh(usage_summary)
        logger.debug(f"[video_gen] Commit successful. Final video_calls: {usage_summary.video_calls}, video_cost: {usage_summary.video_cost}")

        video_limit_display = video_limit if video_limit > 0 else '∞'

        log_message = f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {current_video_calls_before} → {usage_summary.video_calls} / {video_limit_display}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else '∞'}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
"""
        logger.info(log_message)
        return {
            "previous_calls": current_video_calls_before,
            "current_calls": usage_summary.video_calls,
            "video_limit": video_limit,
            "video_limit_display": video_limit_display,
            "cost_per_video": cost_per_video,
            "total_video_cost": usage_summary.video_cost,
        }
    except Exception as track_error:
        # Best effort: the video already exists, so a tracking failure is logged
        # and rolled back instead of being raised. The traceback is requested once.
        logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
        logger.error(f"[video_gen] Exception type: {type(track_error).__name__}")
        db_track.rollback()
        return None
    finally:
        db_track.close()
|
||||
|
||||
|
||||
|
||||
@@ -414,7 +414,8 @@ class APIKeyManager:
|
||||
'SERPER_API_KEY',
|
||||
'METAPHOR_API_KEY',
|
||||
'FIRECRAWL_API_KEY',
|
||||
'STABILITY_API_KEY'
|
||||
'STABILITY_API_KEY',
|
||||
'WAVESPEED_API_KEY',
|
||||
]
|
||||
|
||||
for provider in providers:
|
||||
|
||||
@@ -288,4 +288,90 @@ class StoryAudioGenerationService:
|
||||
|
||||
logger.info(f"[StoryAudioGeneration] Generated {len(audio_results)} audio files out of {total_scenes} scenes")
|
||||
return audio_results
|
||||
|
||||
def generate_ai_audio(
    self,
    scene_number: int,
    scene_title: str,
    text: str,
    user_id: str,
    voice_id: str = "Wise_Woman",
    speed: float = 1.0,
    volume: float = 1.0,
    pitch: float = 0.0,
    emotion: str = "happy",
) -> Dict[str, Any]:
    """
    Produce AI narration for one scene via main_audio_generation and save it to disk.

    Parameters:
        scene_number (int): Scene number.
        scene_title (str): Scene title.
        text (str): Narration text; surrounding whitespace is stripped before synthesis.
        user_id (str): Clerk user ID, forwarded to the audio generation service.
        voice_id (str): Voice ID for AI audio generation (default: "Wise_Woman").
        speed (float): Speech speed (0.5-2.0, default: 1.0).
        volume (float): Speech volume (0.1-10.0, default: 1.0).
        pitch (float): Speech pitch (-12 to 12, default: 0.0).
        emotion (str): Emotion for speech (default: "happy").

    Returns:
        Dict[str, Any]: Scene info, saved file path/filename/URL, provider, model,
        voice, text length, file size and the estimated cost.

    Raises:
        ValueError: If text is empty or whitespace only.
        HTTPException: Propagated unchanged from the generation service.
        RuntimeError: Wraps any other failure during generation or saving.
    """
    narration = text.strip() if text else ""
    if not narration:
        raise ValueError(f"Scene {scene_number} ({scene_title}) requires non-empty text")

    try:
        logger.info(f"[StoryAudioGeneration] Generating AI audio for scene {scene_number}: {scene_title}")
        logger.debug(f"[StoryAudioGeneration] Text length: {len(text)} characters, voice: {voice_id}")

        # Imported at call time, inside the try, so an import failure is wrapped too.
        from services.llm_providers.main_audio_generation import generate_audio

        synthesis = generate_audio(
            text=narration,
            voice_id=voice_id,
            speed=speed,
            volume=volume,
            pitch=pitch,
            emotion=emotion,
            user_id=user_id,
        )

        # Persist the synthesized bytes under the service's output directory.
        audio_filename = self._generate_audio_filename(scene_number, scene_title)
        audio_path = self.output_dir / audio_filename
        with open(audio_path, "wb") as audio_file:
            audio_file.write(synthesis.audio_bytes)

        logger.info(f"[StoryAudioGeneration] Saved AI audio to: {audio_path} ({synthesis.file_size} bytes)")

        # Cost reported back to the caller: 0.05 per 1,000 characters.
        cost_per_1000_chars = 0.05
        character_count = synthesis.text_length
        cost = (character_count / 1000.0) * cost_per_1000_chars

        metadata = {
            "scene_number": scene_number,
            "scene_title": scene_title,
            "audio_path": str(audio_path),
            "audio_filename": audio_filename,
            "audio_url": f"/api/story/audio/{audio_filename}",
            "provider": synthesis.provider,
            "model": synthesis.model,
            "voice_id": synthesis.voice_id,
            "text_length": synthesis.text_length,
            "file_size": synthesis.file_size,
            "cost": cost,
        }
        return metadata

    except HTTPException:
        # Let HTTP errors (e.g., 429 subscription limit) reach the caller untouched.
        raise
    except Exception as e:
        logger.error(f"[StoryAudioGeneration] Error generating AI audio for scene {scene_number}: {e}")
        raise RuntimeError(f"Failed to generate AI audio for scene {scene_number}: {str(e)}") from e
|
||||
|
||||
|
||||
@@ -193,4 +193,82 @@ class StoryImageGenerationService:
|
||||
|
||||
logger.info(f"[StoryImageGeneration] Generated {len(image_results)} images out of {total_scenes} scenes")
|
||||
return image_results
|
||||
|
||||
def regenerate_scene_image(
    self,
    scene_number: int,
    scene_title: str,
    prompt: str,
    user_id: str,
    provider: Optional[str] = None,
    width: int = 1024,
    height: int = 1024,
    model: Optional[str] = None
) -> Dict[str, Any]:
    """
    Re-create the image of one scene from a caller-supplied prompt (no AI prompt generation).

    Parameters:
        scene_number (int): Scene number.
        scene_title (str): Scene title.
        prompt (str): Prompt passed to image generation; surrounding whitespace is stripped.
        user_id (str): Clerk user ID, forwarded to the image generation service.
        provider (str, optional): Image generation provider (gemini, huggingface, stability).
        width (int): Image width (default: 1024).
        height (int): Image height (default: 1024).
        model (str, optional): Model to use for image generation.

    Returns:
        Dict[str, Any]: Scene info, saved file path/filename/URL, and the width,
        height, provider, model and seed reported by the generation result.

    Raises:
        ValueError: If prompt is empty or whitespace only.
        HTTPException: Propagated unchanged from the generation service.
        RuntimeError: Wraps any other failure during generation or saving.
    """
    direct_prompt = prompt.strip() if prompt else ""
    if not direct_prompt:
        raise ValueError(f"Scene {scene_number} ({scene_title}) requires a non-empty prompt")

    try:
        logger.info(f"[StoryImageGeneration] Regenerating image for scene {scene_number}: {scene_title}")
        logger.debug(f"[StoryImageGeneration] Using direct prompt: {prompt[:100]}...")

        # Hand the prompt straight to main_image_generation with the requested options.
        generated: ImageGenerationResult = generate_image(
            prompt=direct_prompt,
            options=dict(provider=provider, width=width, height=height, model=model),
            user_id=user_id
        )

        # Write the returned bytes under the service's output directory.
        image_filename = self._generate_image_filename(scene_number, scene_title)
        image_path = self.output_dir / image_filename
        with open(image_path, "wb") as image_file:
            image_file.write(generated.image_bytes)

        logger.info(f"[StoryImageGeneration] Saved regenerated image to: {image_path}")

        metadata = {
            "scene_number": scene_number,
            "scene_title": scene_title,
            "image_path": str(image_path),
            "image_filename": image_filename,
            "image_url": f"/api/story/images/{image_filename}",
            "width": generated.width,
            "height": generated.height,
            "provider": generated.provider,
            "model": generated.model,
            "seed": generated.seed,
        }
        return metadata

    except HTTPException:
        # Let HTTP errors (e.g., 429 subscription limit) reach the caller untouched.
        raise
    except Exception as e:
        logger.error(f"[StoryImageGeneration] Error regenerating image for scene {scene_number}: {e}")
        raise RuntimeError(f"Failed to regenerate image for scene {scene_number}: {str(e)}") from e
|
||||
|
||||
|
||||
@@ -220,35 +220,41 @@ class StoryVideoGenerationService:
|
||||
def generate_story_video(
|
||||
self,
|
||||
scenes: List[Dict[str, Any]],
|
||||
image_paths: List[str],
|
||||
image_paths: List[Optional[str]],
|
||||
audio_paths: List[str],
|
||||
user_id: str,
|
||||
story_title: str = "Story",
|
||||
fps: int = 24,
|
||||
transition_duration: float = 0.5,
|
||||
progress_callback: Optional[callable] = None
|
||||
progress_callback: Optional[callable] = None,
|
||||
video_paths: Optional[List[Optional[str]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate a complete story video from multiple scenes.
|
||||
|
||||
Parameters:
|
||||
scenes (List[Dict[str, Any]]): List of scene data.
|
||||
image_paths (List[str]): List of image file paths for each scene.
|
||||
image_paths (List[Optional[str]]): List of image file paths (None if scene has animated video).
|
||||
audio_paths (List[str]): List of audio file paths for each scene.
|
||||
user_id (str): Clerk user ID for subscription checking.
|
||||
story_title (str): Title of the story (default: "Story").
|
||||
fps (int): Frames per second for video (default: 24).
|
||||
transition_duration (float): Duration of transitions between scenes in seconds (default: 0.5).
|
||||
progress_callback (callable, optional): Callback function for progress updates.
|
||||
video_paths (Optional[List[Optional[str]]]): List of animated video file paths (None if scene has static image).
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Video metadata including file path, URL, and story info.
|
||||
"""
|
||||
if not scenes or not image_paths or not audio_paths:
|
||||
raise ValueError("Scenes, image paths, and audio paths are required")
|
||||
if not scenes or not audio_paths:
|
||||
raise ValueError("Scenes and audio paths are required")
|
||||
|
||||
if len(scenes) != len(image_paths) or len(scenes) != len(audio_paths):
|
||||
raise ValueError("Number of scenes, image paths, and audio paths must match")
|
||||
if len(scenes) != len(audio_paths):
|
||||
raise ValueError("Number of scenes and audio paths must match")
|
||||
|
||||
video_paths = video_paths or [None] * len(scenes)
|
||||
if len(video_paths) != len(scenes):
|
||||
video_paths = video_paths + [None] * (len(scenes) - len(video_paths))
|
||||
|
||||
try:
|
||||
logger.info(f"[StoryVideoGeneration] Generating story video for {len(scenes)} scenes")
|
||||
@@ -293,36 +299,59 @@ class StoryVideoGenerationService:
|
||||
scene_clips = []
|
||||
total_duration = 0.0
|
||||
|
||||
for idx, (scene, image_path, audio_path) in enumerate(zip(scenes, image_paths, audio_paths)):
|
||||
# Import VideoFileClip for animated videos
|
||||
try:
|
||||
from moviepy import VideoFileClip
|
||||
except ImportError:
|
||||
VideoFileClip = None
|
||||
|
||||
for idx, (scene, image_path, audio_path, video_path) in enumerate(zip(scenes, image_paths, audio_paths, video_paths)):
|
||||
try:
|
||||
scene_number = scene.get("scene_number", idx + 1)
|
||||
scene_title = scene.get("title", "Untitled")
|
||||
|
||||
logger.info(f"[StoryVideoGeneration] Processing scene {scene_number}/{len(scenes)}: {scene_title}")
|
||||
|
||||
# Load image and audio
|
||||
image_file = Path(image_path)
|
||||
audio_file = Path(audio_path)
|
||||
|
||||
if not image_file.exists():
|
||||
logger.warning(f"[StoryVideoGeneration] Image not found: {image_path}, skipping scene {scene_number}")
|
||||
continue
|
||||
if not audio_file.exists():
|
||||
logger.warning(f"[StoryVideoGeneration] Audio not found: {audio_path}, skipping scene {scene_number}")
|
||||
continue
|
||||
|
||||
# Load audio to get duration
|
||||
# Load audio
|
||||
audio_clip = AudioFileClip(str(audio_file))
|
||||
audio_duration = audio_clip.duration
|
||||
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
# Prefer animated video if available
|
||||
if video_path and Path(video_path).exists():
|
||||
logger.info(f"[StoryVideoGeneration] Using animated video for scene {scene_number}: {video_path}")
|
||||
# Load animated video
|
||||
if VideoFileClip is None:
|
||||
raise RuntimeError("VideoFileClip not available - MoviePy may not be fully installed")
|
||||
video_clip = VideoFileClip(str(video_path))
|
||||
# Replace audio with the preferred audio (AI or free)
|
||||
video_clip = video_clip.with_audio(audio_clip)
|
||||
# Match duration to audio if needed
|
||||
if video_clip.duration > audio_duration:
|
||||
video_clip = video_clip.subclip(0, audio_duration)
|
||||
elif video_clip.duration < audio_duration:
|
||||
# Loop the video if it's shorter than audio
|
||||
loops_needed = int(audio_duration / video_clip.duration) + 1
|
||||
video_clip = concatenate_videoclips([video_clip] * loops_needed).subclip(0, audio_duration)
|
||||
video_clip = video_clip.with_audio(audio_clip)
|
||||
elif image_path and Path(image_path).exists():
|
||||
# Fall back to static image
|
||||
logger.info(f"[StoryVideoGeneration] Using static image for scene {scene_number}: {image_path}")
|
||||
image_file = Path(image_path)
|
||||
# Create image clip (MoviePy v2: use with_* API)
|
||||
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
|
||||
image_clip = image_clip.with_fps(fps)
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
else:
|
||||
logger.warning(f"[StoryVideoGeneration] No video or image found for scene {scene_number}, skipping")
|
||||
continue
|
||||
|
||||
# Set audio to image clip
|
||||
video_clip = image_clip.with_audio(audio_clip)
|
||||
scene_clips.append(video_clip)
|
||||
|
||||
total_duration += audio_duration
|
||||
|
||||
# Call progress callback if provided
|
||||
|
||||
@@ -19,10 +19,18 @@ import re
|
||||
|
||||
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
|
||||
from models.subscription_models import APIProvider
|
||||
from services.database import get_db
|
||||
from .usage_tracking_service import UsageTrackingService
|
||||
from .pricing_service import PricingService
|
||||
|
||||
|
||||
def _get_db_session():
    """
    Return the first value produced by services.database.get_db().

    get_db is imported inside the function on every call instead of relying on
    the module-level name, because Uvicorn's reloader can sometimes clear
    module-level imports.
    """
    from services.database import get_db

    session_source = get_db()
    return next(session_source)
|
||||
|
||||
class DatabaseAPIMonitor:
|
||||
"""Database-backed API monitoring with usage tracking and subscription management."""
|
||||
|
||||
@@ -145,8 +153,9 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
db = None
|
||||
try:
|
||||
db = next(get_db())
|
||||
db = _get_db_session()
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
|
||||
# Detect if this is an API call that should be rate limited
|
||||
@@ -203,14 +212,15 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
# Don't block requests if usage checking fails
|
||||
return None
|
||||
finally:
|
||||
db.close()
|
||||
if db is not None:
|
||||
db.close()
|
||||
|
||||
async def monitoring_middleware(request: Request, call_next):
|
||||
"""Enhanced FastAPI middleware for monitoring API calls with usage tracking."""
|
||||
start_time = time.time()
|
||||
|
||||
# Get database session
|
||||
db = next(get_db())
|
||||
db = _get_db_session()
|
||||
|
||||
# Extract request details - Enhanced user identification
|
||||
user_id = None
|
||||
@@ -340,8 +350,9 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
|
||||
async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
|
||||
"""Get current monitoring statistics."""
|
||||
db = next(get_db())
|
||||
db = None
|
||||
try:
|
||||
db = _get_db_session()
|
||||
# Placeholder to match old API; heavy stats handled elsewhere
|
||||
return {
|
||||
'timestamp': datetime.utcnow().isoformat(),
|
||||
@@ -354,12 +365,14 @@ async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
|
||||
'system_health': {'status': 'healthy', 'error_rate': 0.0}
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
if db is not None:
|
||||
db.close()
|
||||
|
||||
async def get_lightweight_stats() -> Dict[str, Any]:
|
||||
"""Get lightweight stats for dashboard header."""
|
||||
db = next(get_db())
|
||||
db = None
|
||||
try:
|
||||
db = _get_db_session()
|
||||
# Minimal viable placeholder values
|
||||
now = datetime.utcnow()
|
||||
return {
|
||||
@@ -371,4 +384,5 @@ async def get_lightweight_stats() -> Dict[str, Any]:
|
||||
'timestamp': now.isoformat()
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
if db is not None:
|
||||
db.close()
|
||||
|
||||
@@ -420,3 +420,54 @@ def validate_video_generation_operations(
|
||||
'message': f"Failed to validate video generation: {str(e)}"
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def validate_scene_animation_operation(
    pricing_service: PricingService,
    user_id: str,
) -> None:
    """
    Pre-flight check for the per-scene animation workflow, run before API calls.

    Asks pricing_service.check_comprehensive_limits whether a single VIDEO
    operation (actual provider 'wavespeed', type 'scene_animation') is allowed
    for the user.

    Raises:
        HTTPException: 429 with the limit message and usage info when the check
            refuses the operation; 500 when the check itself fails unexpectedly.
    """
    try:
        animation_operation = {
            'provider': APIProvider.VIDEO,
            'tokens_requested': 0,
            'actual_provider_name': 'wavespeed',
            'operation_type': 'scene_animation',
        }

        can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
            user_id=user_id,
            operations=[animation_operation],
        )

        if can_proceed:
            logger.info(f"[Pre-flight Validator] ✅ Scene animation validated for user {user_id}")
            return

        logger.error(f"[Pre-flight Validator] Scene animation blocked for user {user_id}: {message}")
        usage_info = error_details.get('usage_info', {}) if error_details else {}
        provider = usage_info.get('provider', 'video') if usage_info else 'video'
        # When no usage info was supplied, the raw error details are returned instead.
        raise HTTPException(
            status_code=429,
            detail={
                'error': message,
                'message': message,
                'provider': provider,
                'usage_info': usage_info or error_details,
            }
        )

    except HTTPException:
        # The 429 raised above must not be converted into a 500 below.
        raise
    except Exception as e:
        logger.error(f"[Pre-flight Validator] Error validating scene animation: {e}", exc_info=True)
        failure_message = f"Failed to validate scene animation: {str(e)}"
        raise HTTPException(
            status_code=500,
            detail={
                'error': failure_message,
                'message': failure_message,
            },
        )
|
||||
@@ -307,6 +307,41 @@ class PricingService:
|
||||
"model_name": "default",
|
||||
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
|
||||
"description": "AI Video Generation default pricing"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "kling-v2.5-turbo-std-5s",
|
||||
"cost_per_request": 0.21,
|
||||
"description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (5 seconds)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "kling-v2.5-turbo-std-10s",
|
||||
"cost_per_request": 0.42,
|
||||
"description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (10 seconds)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "wavespeed-ai/infinitetalk",
|
||||
"cost_per_request": 0.30,
|
||||
"description": "WaveSpeed InfiniteTalk (image + audio to talking avatar video)"
|
||||
},
|
||||
# Audio Generation Pricing (Minimax Speech 02 HD via WaveSpeed)
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
"model_name": "minimax/speech-02-hd",
|
||||
"cost_per_input_token": 0.00005, # $0.05 per 1,000 characters (every character is 1 token)
|
||||
"cost_per_output_token": 0.0, # No output tokens for audio
|
||||
"cost_per_request": 0.0, # Pricing is per character, not per request
|
||||
"description": "AI Audio Generation (Text-to-Speech) - Minimax Speech 02 HD via WaveSpeed"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
"model_name": "default",
|
||||
"cost_per_input_token": 0.00005, # $0.05 per 1,000 characters default
|
||||
"cost_per_output_token": 0.0,
|
||||
"cost_per_request": 0.0,
|
||||
"description": "AI Audio Generation default pricing"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -358,6 +393,7 @@ class PricingService:
|
||||
"exa_calls_limit": 100,
|
||||
"video_calls_limit": 0, # No video generation for free tier
|
||||
"image_edit_calls_limit": 10, # 10 AI image editing calls/month
|
||||
"audio_calls_limit": 20, # 20 AI audio generation calls/month
|
||||
"gemini_tokens_limit": 100000,
|
||||
"monthly_cost_limit": 0.0,
|
||||
"features": ["basic_content_generation", "limited_research"],
|
||||
@@ -381,6 +417,7 @@ class PricingService:
|
||||
"exa_calls_limit": 500,
|
||||
"video_calls_limit": 20, # 20 videos/month for basic plan
|
||||
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
|
||||
"audio_calls_limit": 50, # 50 AI audio generation calls/month
|
||||
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
@@ -406,6 +443,7 @@ class PricingService:
|
||||
"exa_calls_limit": 2000,
|
||||
"video_calls_limit": 50, # 50 videos/month for pro plan
|
||||
"image_edit_calls_limit": 100, # 100 AI image editing calls/month
|
||||
"audio_calls_limit": 200, # 200 AI audio generation calls/month
|
||||
"gemini_tokens_limit": 5000000,
|
||||
"openai_tokens_limit": 2500000,
|
||||
"anthropic_tokens_limit": 1000000,
|
||||
@@ -431,6 +469,7 @@ class PricingService:
|
||||
"exa_calls_limit": 0, # Unlimited
|
||||
"video_calls_limit": 0, # Unlimited for enterprise
|
||||
"image_edit_calls_limit": 0, # Unlimited image editing for enterprise
|
||||
"audio_calls_limit": 0, # Unlimited audio generation for enterprise
|
||||
"gemini_tokens_limit": 0,
|
||||
"openai_tokens_limit": 0,
|
||||
"anthropic_tokens_limit": 0,
|
||||
@@ -651,6 +690,7 @@ class PricingService:
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
|
||||
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
|
||||
'audio_calls': getattr(plan, 'audio_calls_limit', 0), # Support missing column
|
||||
# Token limits
|
||||
'gemini_tokens': plan.gemini_tokens_limit,
|
||||
'openai_tokens': plan.openai_tokens_limit,
|
||||
|
||||
@@ -31,6 +31,7 @@ def ensure_subscription_plan_columns(db: Session) -> None:
|
||||
"exa_calls_limit": "INTEGER DEFAULT 0",
|
||||
"video_calls_limit": "INTEGER DEFAULT 0",
|
||||
"image_edit_calls_limit": "INTEGER DEFAULT 0",
|
||||
"audio_calls_limit": "INTEGER DEFAULT 0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
@@ -84,6 +85,8 @@ def ensure_usage_summaries_columns(db: Session) -> None:
|
||||
"video_cost": "REAL DEFAULT 0.0",
|
||||
"image_edit_calls": "INTEGER DEFAULT 0",
|
||||
"image_edit_cost": "REAL DEFAULT 0.0",
|
||||
"audio_calls": "INTEGER DEFAULT 0",
|
||||
"audio_cost": "REAL DEFAULT 0.0",
|
||||
}
|
||||
|
||||
for col_name, ddl in required_columns.items():
|
||||
|
||||
1
backend/services/wavespeed/__init__.py
Normal file
1
backend/services/wavespeed/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
471
backend/services/wavespeed/client.py
Normal file
471
backend/services/wavespeed/client.py
Normal file
@@ -0,0 +1,471 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from requests import exceptions as requests_exceptions
|
||||
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.client")
|
||||
|
||||
|
||||
class WaveSpeedClient:
    """
    Thin HTTP client for the WaveSpeed AI API.

    Handles authentication, job submission, result polling, and downloading of
    generated artifacts. Upstream problems (transport errors, non-200 statuses,
    malformed payloads) are reported as ``HTTPException`` with status 502, or
    504 when a request or the overall polling window times out.
    """

    BASE_URL = "https://api.wavespeed.ai/api/v3"

    def __init__(self, api_key: Optional[str] = None):
        """
        Store the API key, falling back to the key manager's "wavespeed" entry.

        Raises:
            RuntimeError: when no key is supplied and none is configured.
        """
        manager = APIKeyManager()
        self.api_key = api_key or manager.get_api_key("wavespeed")
        if not self.api_key:
            raise RuntimeError("WAVESPEED_API_KEY is not configured. Please add it to your environment.")

    def _headers(self) -> Dict[str, str]:
        """Return the JSON + bearer-token headers used for POST requests."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    # ------------------------------------------------------------------ #
    # Private helpers shared by the public operations
    # ------------------------------------------------------------------ #

    def _post_json(
        self,
        url: str,
        payload: Dict[str, Any],
        timeout: int,
        *,
        log_label: str,
        error_message: str,
    ) -> "requests.Response":
        """POST ``payload`` as JSON and return the response, or raise 502/504."""
        try:
            response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
        except requests_exceptions.Timeout as exc:
            logger.error(f"[WaveSpeed] {log_label} timed out: {exc}")
            raise HTTPException(
                status_code=504,
                detail={"error": error_message, "exception": str(exc)},
            ) from exc
        except requests_exceptions.RequestException as exc:
            logger.error(f"[WaveSpeed] {log_label} failed: {exc}")
            raise HTTPException(
                status_code=502,
                detail={"error": error_message, "exception": str(exc)},
            ) from exc

        if response.status_code != 200:
            logger.error(f"[WaveSpeed] {log_label} failed: {response.status_code} {response.text}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": error_message,
                    "status_code": response.status_code,
                    "response": response.text,
                },
            )
        return response

    @staticmethod
    def _parse_json(response: Any, error_message: str) -> Dict[str, Any]:
        """Decode a response body as a JSON object, or raise a 502."""
        try:
            body = response.json()
        except ValueError as exc:  # requests' JSON errors subclass ValueError
            logger.error(f"[WaveSpeed] Response body is not valid JSON: {response.text}")
            raise HTTPException(status_code=502, detail={"error": error_message}) from exc
        if not isinstance(body, dict):
            logger.error(f"[WaveSpeed] Response body is not a JSON object: {response.text}")
            raise HTTPException(status_code=502, detail={"error": error_message})
        return body

    def _download(self, url: str, timeout: int, *, what: str, failure_detail: str) -> "requests.Response":
        """GET a generated artifact (no auth header) and return the response, or raise a 502."""
        try:
            fetched = requests.get(url, timeout=timeout)
        except requests_exceptions.RequestException as exc:
            logger.error(f"[WaveSpeed] Failed to fetch {what} from URL: {exc}")
            raise HTTPException(status_code=502, detail=failure_detail) from exc
        if fetched.status_code != 200:
            logger.error(f"[WaveSpeed] Failed to fetch {what} from URL: {fetched.status_code}")
            raise HTTPException(status_code=502, detail=failure_detail)
        return fetched

    def _collect_outputs(
        self,
        data: Any,
        response_text: str,
        enable_sync_mode: bool,
        *,
        poll_timeout_seconds: int,
        empty_detail: str,
    ) -> Any:
        """
        Return the non-empty ``outputs`` of a job.

        In sync mode they are read straight from the submission response; in
        async mode the prediction is polled until it completes first.
        """
        if not isinstance(data, dict):
            data = {}  # treated as "no outputs" / "no id" below

        if enable_sync_mode:
            outputs = data.get("outputs") or []
            if not outputs:
                logger.error(f"[WaveSpeed] No outputs in sync mode response: {response_text}")
                raise HTTPException(status_code=502, detail=empty_detail)
            return outputs

        prediction_id = data.get("id")
        if not prediction_id:
            logger.error(f"[WaveSpeed] No prediction ID in async response: {response_text}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed response missing prediction id for async mode",
            )

        result = self.poll_until_complete(
            prediction_id,
            timeout_seconds=poll_timeout_seconds,
            interval_seconds=0.5,
        )
        outputs = result.get("outputs") or []
        if not outputs:
            raise HTTPException(status_code=502, detail=empty_detail)
        return outputs

    @staticmethod
    def _first_output_url(outputs: Any) -> Optional[str]:
        """Return the first output as an http(s) URL, or None if it is not one."""
        first = outputs[0] if isinstance(outputs, list) and outputs else None
        if isinstance(first, dict):
            first = first.get("url") or first.get("output")
        if isinstance(first, str) and first.startswith(("http://", "https://")):
            return first
        return None

    def _first_output_text(self, outputs: Any, timeout: int) -> Any:
        """Return the first output as text, fetching it when it is an http(s) URL."""
        first = outputs[0] if isinstance(outputs, list) and outputs else None
        if isinstance(first, str):
            if first.startswith(("http://", "https://")):
                logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first}")
                fetched = self._download(
                    first,
                    timeout,
                    what="prompt",
                    failure_detail="Failed to fetch optimized prompt from WaveSpeed URL",
                )
                return fetched.text.strip()
            return first  # already the text itself
        if isinstance(first, dict):
            return first.get("text") or first.get("prompt") or first.get("output")
        return None

    # ------------------------------------------------------------------ #
    # Public operations
    # ------------------------------------------------------------------ #

    def submit_image_to_video(
        self,
        model_path: str,
        payload: Dict[str, Any],
        timeout: int = 30,
    ) -> str:
        """
        Submit an image-to-video generation request.

        Returns the prediction ID for polling.

        Raises:
            HTTPException: 502/504 when the request fails or the response
                carries no prediction id.
        """
        url = f"{self.BASE_URL}/{model_path}"
        logger.info(f"[WaveSpeed] Submitting request to {url}")
        response = self._post_json(
            url,
            payload,
            timeout,
            log_label="Submission",
            error_message="WaveSpeed image-to-video submission failed",
        )

        data = self._parse_json(response, "WaveSpeed response missing prediction id").get("data")
        if not isinstance(data, dict) or "id" not in data:
            logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}")
            raise HTTPException(
                status_code=502,
                detail={"error": "WaveSpeed response missing prediction id"},
            )

        prediction_id = data["id"]
        logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
        return prediction_id

    def get_prediction_result(self, prediction_id: str, timeout: int = 120) -> Dict[str, Any]:
        """
        Fetch the current status/result for a prediction.

        Raises:
            HTTPException: 504 on request timeout, 502 on any other failure.
                Transport failures carry ``resume_available`` in the detail.
        """
        url = f"{self.BASE_URL}/predictions/{prediction_id}/result"
        try:
            response = requests.get(url, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=timeout)
        except requests_exceptions.Timeout as exc:
            raise HTTPException(
                status_code=504,
                detail={
                    "error": "WaveSpeed polling request timed out",
                    "prediction_id": prediction_id,
                    "resume_available": True,
                    "exception": str(exc),
                },
            ) from exc
        except requests_exceptions.RequestException as exc:
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed polling request failed",
                    "prediction_id": prediction_id,
                    "resume_available": True,
                    "exception": str(exc),
                },
            ) from exc

        if response.status_code != 200:
            logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed prediction polling failed",
                    "status_code": response.status_code,
                    "response": response.text,
                },
            )

        result = self._parse_json(response, "WaveSpeed polling response missing data").get("data")
        if not isinstance(result, dict) or not result:
            raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
        return result

    def poll_until_complete(
        self,
        prediction_id: str,
        timeout_seconds: int = 240,
        interval_seconds: float = 1.0,
    ) -> Dict[str, Any]:
        """
        Poll WaveSpeed until the job completes, fails, or times out.

        Returns the final prediction payload once its status is "completed".

        Raises:
            HTTPException: 502 when the job reports "failed", 504 when
                ``timeout_seconds`` elapses, or the (enriched) error from a
                failed status request.
        """
        # Monotonic clock: elapsed time must not be affected by wall-clock jumps.
        started = time.monotonic()
        while True:
            try:
                result = self.get_prediction_result(prediction_id)
            except HTTPException as exc:
                detail = exc.detail or {}
                if isinstance(detail, dict):
                    # Make sure callers can always offer a "resume" for this job.
                    detail.setdefault("prediction_id", prediction_id)
                    detail.setdefault("resume_available", True)
                    detail.setdefault("error", "WaveSpeed polling failed")
                raise HTTPException(status_code=exc.status_code, detail=detail) from exc

            status = result.get("status")
            if status == "completed":
                logger.info(f"[WaveSpeed] Prediction {prediction_id} completed.")
                return result
            if status == "failed":
                logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {result.get('error')}")
                raise HTTPException(
                    status_code=502,
                    detail={
                        "error": "WaveSpeed animation failed",
                        "prediction_id": prediction_id,
                        "details": result.get("error"),
                    },
                )

            elapsed = time.monotonic() - started
            if elapsed > timeout_seconds:
                logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
                raise HTTPException(
                    status_code=504,
                    detail={
                        "error": "WaveSpeed animation timed out",
                        "prediction_id": prediction_id,
                        "details": result,
                    },
                )

            logger.debug(f"[WaveSpeed] Prediction {prediction_id} status={status}. Waiting...")
            time.sleep(interval_seconds)

    def optimize_prompt(
        self,
        text: str,
        mode: str = "image",
        style: str = "default",
        image: Optional[str] = None,
        enable_sync_mode: bool = True,
        timeout: int = 30,
    ) -> str:
        """
        Optimize a prompt using WaveSpeed prompt optimizer.

        Args:
            text: The prompt text to optimize
            mode: "image" or "video" (default: "image")
            style: "default", "artistic", "photographic", "technical", "anime", "realistic" (default: "default")
            image: Base64-encoded image for context (optional)
            enable_sync_mode: If True, wait for result and return it directly (default: True)
            timeout: Request timeout in seconds (default: 30)

        Returns:
            Optimized prompt text

        Raises:
            HTTPException: 502/504 when the request fails or no usable text
                can be extracted from the outputs.
        """
        model_path = "wavespeed-ai/prompt-optimizer"
        url = f"{self.BASE_URL}/{model_path}"

        payload = {
            "text": text,
            "mode": mode,
            "style": style,
            "enable_sync_mode": enable_sync_mode,
        }
        if image:
            payload["image"] = image

        logger.info(f"[WaveSpeed] Optimizing prompt via {url} (mode={mode}, style={style})")
        response = self._post_json(
            url,
            payload,
            timeout,
            log_label="Prompt optimization",
            error_message="WaveSpeed prompt optimization failed",
        )

        response_json = self._parse_json(response, "WaveSpeed prompt optimization failed")
        data = response_json.get("data") or response_json

        # Sync mode carries outputs in the response; async mode polls for them.
        outputs = self._collect_outputs(
            data,
            response.text,
            enable_sync_mode,
            poll_timeout_seconds=60,
            empty_detail="WaveSpeed prompt optimizer returned no outputs",
        )

        # outputs[0] is either the text itself, a URL to fetch, or a dict wrapper.
        optimized_prompt = self._first_output_text(outputs, timeout)
        if not isinstance(optimized_prompt, str) or not optimized_prompt:
            logger.error(f"[WaveSpeed] Could not extract optimized prompt from outputs: {outputs}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed prompt optimizer output format not recognized",
            )

        logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
        return optimized_prompt

    def generate_speech(
        self,
        text: str,
        voice_id: str,
        speed: float = 1.0,
        volume: float = 1.0,
        pitch: float = 0.0,
        emotion: str = "happy",
        enable_sync_mode: bool = True,
        timeout: int = 60,
        **kwargs
    ) -> bytes:
        """
        Generate speech audio using Minimax Speech 02 HD via WaveSpeed.

        Args:
            text: Text to convert to speech (max 10000 characters)
            voice_id: Voice ID (e.g., "Wise_Woman", "Friendly_Person", etc.)
            speed: Speech speed (0.5-2.0, default: 1.0)
            volume: Speech volume (0.1-10.0, default: 1.0)
            pitch: Speech pitch (-12 to 12, default: 0.0)
            emotion: Emotion ("happy", "sad", "angry", etc., default: "happy")
            enable_sync_mode: If True, wait for result and return it directly (default: True)
            timeout: Request timeout in seconds (default: 60)
            **kwargs: Additional parameters (sample_rate, bitrate, format, etc.)

        Returns:
            bytes: Generated audio bytes

        Raises:
            HTTPException: 502/504 when the request fails, the outputs hold no
                http(s) audio URL, or the audio cannot be downloaded.
        """
        model_path = "minimax/speech-02-hd"
        url = f"{self.BASE_URL}/{model_path}"

        payload = {
            "text": text,
            "voice_id": voice_id,
            "speed": speed,
            "volume": volume,
            "pitch": pitch,
            "emotion": emotion,
            "enable_sync_mode": enable_sync_mode,
        }

        # Only these extra keyword arguments are forwarded; others are ignored.
        optional_params = (
            "english_normalization",
            "sample_rate",
            "bitrate",
            "channel",
            "format",
            "language_boost",
        )
        for param in optional_params:
            if param in kwargs:
                payload[param] = kwargs[param]

        logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
        response = self._post_json(
            url,
            payload,
            timeout,
            log_label="Speech generation",
            error_message="WaveSpeed speech generation failed",
        )

        response_json = self._parse_json(response, "WaveSpeed speech generation failed")
        data = response_json.get("data") or response_json

        outputs = self._collect_outputs(
            data,
            response.text,
            enable_sync_mode,
            poll_timeout_seconds=120,
            empty_detail="WaveSpeed speech generator returned no outputs",
        )

        audio_url = self._first_output_url(outputs)
        if audio_url is None:
            logger.error(f"[WaveSpeed] Invalid audio URL in outputs: {outputs}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed speech generator output format not recognized",
            )

        logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
        audio_bytes = self._download(
            audio_url,
            timeout,
            what="audio",
            failure_detail="Failed to fetch generated audio from WaveSpeed URL",
        ).content
        logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
        return audio_bytes
|
||||
|
||||
122
backend/services/wavespeed/infinitetalk.py
Normal file
122
backend/services/wavespeed/infinitetalk.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from .client import WaveSpeedClient
|
||||
from .kling_animation import generate_animation_prompt
|
||||
|
||||
# API route of the InfiniteTalk model, and the model name reported in results.
INFINITALK_MODEL_PATH = "wavespeed-ai/infinitetalk"
INFINITALK_MODEL_NAME = "wavespeed-ai/infinitetalk"
# Flat cost reported per generation: $0.30 per 5 seconds at the 720p tier.
INFINITALK_DEFAULT_COST = 0.30
# Upload size caps enforced before anything is sent to WaveSpeed.
MAX_IMAGE_BYTES = 10 * 1024 ** 2  # 10MB
MAX_AUDIO_BYTES = 50 * 1024 ** 2  # 50MB safety cap
|
||||
|
||||
|
||||
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
|
||||
encoded = base64.b64encode(content_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
|
||||
def animate_scene_with_voiceover(
    *,
    image_bytes: bytes,
    audio_bytes: bytes,
    scene_data: Dict[str, Any],
    story_context: Dict[str, Any],
    user_id: str,
    resolution: str = "720p",
    prompt_override: Optional[str] = None,
    image_mime: str = "image/png",
    audio_mime: str = "audio/mpeg",
    client: Optional[WaveSpeedClient] = None,
) -> Dict[str, Any]:
    """
    Animate a scene image with narration audio using WaveSpeed InfiniteTalk.

    The image and audio are sent inline as base64 data URIs, the prediction is
    polled for up to 600 seconds, and the resulting video is downloaded.

    Args:
        image_bytes: Scene image (at most ``MAX_IMAGE_BYTES``).
        audio_bytes: Narration audio (at most ``MAX_AUDIO_BYTES``).
        scene_data: Scene fields used for prompt generation and logging.
        story_context: Story fields used for prompt generation.
        user_id: Passed to prompt generation and included in the log line.
        resolution: "480p" or "720p".
        prompt_override: When truthy, used instead of a generated prompt.
        image_mime: MIME type written into the image data URI.
        audio_mime: MIME type written into the audio data URI.
        client: Optional pre-built client; a new one is created when omitted.

    Returns:
        Dict with ``video_bytes``, ``prompt``, ``duration``, ``model_name``,
        ``cost``, ``provider``, ``source_video_url`` and ``prediction_id``.

    Raises:
        HTTPException: 404 for missing inputs, 400 for oversized inputs or an
            unsupported resolution, 502/504 for upstream failures.
    """

    if not image_bytes:
        raise HTTPException(status_code=404, detail="Scene image bytes missing for animation.")
    if not audio_bytes:
        raise HTTPException(status_code=404, detail="Scene audio bytes missing for animation.")

    if len(image_bytes) > MAX_IMAGE_BYTES:
        raise HTTPException(
            status_code=400,
            detail="Scene image exceeds 10MB limit required by WaveSpeed InfiniteTalk.",
        )
    if len(audio_bytes) > MAX_AUDIO_BYTES:
        raise HTTPException(
            status_code=400,
            detail="Scene audio exceeds 50MB limit allowed for InfiniteTalk requests.",
        )

    if resolution not in {"480p", "720p"}:
        raise HTTPException(status_code=400, detail="Resolution must be '480p' or '720p'.")

    animation_prompt = prompt_override or generate_animation_prompt(scene_data, story_context, user_id)

    payload = {
        "image": _as_data_uri(image_bytes, image_mime),
        "audio": _as_data_uri(audio_bytes, audio_mime),
        "resolution": resolution,
    }
    if animation_prompt:
        payload["prompt"] = animation_prompt

    client = client or WaveSpeedClient()
    prediction_id = client.submit_image_to_video(INFINITALK_MODEL_PATH, payload, timeout=60)

    try:
        result = client.poll_until_complete(prediction_id, timeout_seconds=600, interval_seconds=1.0)
    except HTTPException as exc:
        # Enrich dict details in place so the re-raised error lets callers resume.
        if isinstance(exc.detail, dict):
            exc.detail.setdefault("prediction_id", prediction_id)
            exc.detail.setdefault("resume_available", True)
        raise

    outputs = result.get("outputs") or []
    if not outputs:
        raise HTTPException(status_code=502, detail="WaveSpeed InfiniteTalk completed but returned no outputs.")

    # Outputs are normally URL strings; also accept a dict wrapper around the URL.
    video_url = outputs[0]
    if isinstance(video_url, dict):
        video_url = video_url.get("url") or video_url.get("output")
    if not isinstance(video_url, str) or not video_url:
        raise HTTPException(
            status_code=502,
            detail={
                "error": "WaveSpeed InfiniteTalk output format not recognized",
                "prediction_id": prediction_id,
            },
        )

    try:
        video_response = requests.get(video_url, timeout=180)
    except requests.RequestException as exc:
        # The job itself succeeded, so expose the id for a later retry.
        raise HTTPException(
            status_code=502,
            detail={
                "error": "Failed to download InfiniteTalk video",
                "prediction_id": prediction_id,
                "exception": str(exc),
            },
        ) from exc
    if video_response.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={
                "error": "Failed to download InfiniteTalk video",
                "status_code": video_response.status_code,
                "response": video_response.text[:200],
            },
        )

    metadata = result.get("metadata") or {}
    duration = metadata.get("duration_seconds") or metadata.get("duration") or 0

    # loguru formats messages with str.format, so "{}" placeholders are required;
    # "%s" placeholders would be logged literally.
    logger.info(
        "[InfiniteTalk] Generated talking avatar video user={} scene={} resolution={} size={} bytes",
        user_id,
        scene_data.get("scene_number"),
        resolution,
        len(video_response.content),
    )

    return {
        "video_bytes": video_response.content,
        "prompt": animation_prompt,
        "duration": duration or 5,  # default when metadata has no duration
        "model_name": INFINITALK_MODEL_NAME,
        "cost": INFINITALK_DEFAULT_COST,
        "provider": "wavespeed",
        "source_video_url": video_url,
        "prediction_id": prediction_id,
    }
|
||||
|
||||
|
||||
360
backend/services/wavespeed/kling_animation.py
Normal file
360
backend/services/wavespeed/kling_animation.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from .client import WaveSpeedClient
|
||||
|
||||
try:
|
||||
import imghdr
|
||||
except ModuleNotFoundError: # Python 3.13 removed imghdr
|
||||
imghdr = None
|
||||
|
||||
logger = get_service_logger("wavespeed.kling_animation")
|
||||
|
||||
# API route of the Kling v2.5 Turbo Std image-to-video model.
KLING_MODEL_PATH = "kwaivgi/kling-v2.5-turbo-std/image-to-video"
# Model names for the 5-second and 10-second clip variants.
KLING_MODEL_5S = "kling-v2.5-turbo-std-5s"
KLING_MODEL_10S = "kling-v2.5-turbo-std-10s"
MAX_IMAGE_BYTES = 10 * 1024 ** 2  # 10 MB limit per docs
|
||||
|
||||
|
||||
def _detect_image_mime(image_bytes: bytes) -> str:
|
||||
if imghdr:
|
||||
detected = imghdr.what(None, h=image_bytes)
|
||||
if detected == "jpeg":
|
||||
return "image/jpeg"
|
||||
if detected == "png":
|
||||
return "image/png"
|
||||
if detected == "gif":
|
||||
return "image/gif"
|
||||
|
||||
header = image_bytes[:8]
|
||||
if header.startswith(b"\x89PNG"):
|
||||
return "image/png"
|
||||
if header[:2] == b"\xff\xd8":
|
||||
return "image/jpeg"
|
||||
if header[:3] in (b"GIF", b"GIF"):
|
||||
return "image/gif"
|
||||
|
||||
return "image/png"
|
||||
|
||||
|
||||
def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str, Any]) -> str:
|
||||
title = (scene_data.get("title") or "Scene").strip()
|
||||
description = (scene_data.get("description") or "").strip()
|
||||
image_prompt = (scene_data.get("image_prompt") or "").strip()
|
||||
tone = (story_context.get("story_tone") or "story").strip()
|
||||
setting = (story_context.get("story_setting") or "the scene").strip()
|
||||
|
||||
parts = [
|
||||
f"{title} cinematic motion shot.",
|
||||
description[:220] if description else "",
|
||||
f"Camera glides with subtle parallax over {setting}.",
|
||||
f"Maintain a {tone} mood with natural lighting accents.",
|
||||
f"Honor the original illustration details: {image_prompt[:200]}." if image_prompt else "",
|
||||
"5-second sequence, gentle push-in, flowing cloth and atmospheric particles.",
|
||||
]
|
||||
fallback_prompt = " ".join(filter(None, parts))
|
||||
return fallback_prompt.strip()
|
||||
|
||||
|
||||
def _load_llm_json_response(response_text: Any) -> Dict[str, Any]:
|
||||
"""Normalize responses from llm_text_gen (dict or JSON string)."""
|
||||
if isinstance(response_text, dict):
|
||||
return response_text
|
||||
if isinstance(response_text, str):
|
||||
return json.loads(response_text)
|
||||
raise ValueError(f"Unexpected response type: {type(response_text)}")
|
||||
|
||||
|
||||
def _generate_text_prompt(
    *,
    prompt: str,
    system_prompt: str,
    user_id: str,
    fallback_prompt: str,
) -> str:
    """
    Ask the LLM for a plain-text animation prompt, with a deterministic fallback.

    Calls ``llm_text_gen`` without a JSON schema. A dict reply is searched for a
    non-blank string under "animation_prompt", "prompt" or "text" (in that
    order) and is otherwise serialized to JSON; any other reply is stringified.
    ``fallback_prompt`` is returned when the call fails or the text is blank.

    Raises:
        HTTPException: only when the LLM call fails with status 429; every
            other failure is logged and replaced by ``fallback_prompt``.
    """
    # NOTE(review): the %-style logging arguments below assume a stdlib-compatible
    # logger; confirm what get_service_logger returns.
    try:
        llm_reply = llm_text_gen(
            prompt=prompt.strip(),
            system_prompt=system_prompt,
            user_id=user_id,
        )
    except HTTPException as exc:
        if exc.status_code != 429:
            logger.warning(
                "[AnimateScene] Text-mode prompt generation failed (%s). Using deterministic fallback.",
                exc.detail,
            )
            return fallback_prompt
        # Limit errors must reach the caller unchanged.
        raise
    except Exception as exc:
        logger.error(
            "[AnimateScene] Unexpected error generating text prompt: %s",
            exc,
            exc_info=True,
        )
        return fallback_prompt

    if not isinstance(llm_reply, dict):
        return str(llm_reply).strip() or fallback_prompt

    for key in ("animation_prompt", "prompt", "text"):
        value = llm_reply.get(key)
        if isinstance(value, str):
            trimmed = value.strip()
            if trimmed:
                return trimmed

    # As a last resort, stringify the dict
    serialized = json.dumps(llm_reply, ensure_ascii=False)
    return serialized.strip() or fallback_prompt
|
||||
|
||||
|
||||
def generate_animation_prompt(
    scene_data: Dict[str, Any],
    story_context: Dict[str, Any],
    user_id: str,
) -> str:
    """Generate an animation-focused prompt using llm_text_gen.

    A structured (JSON) request is tried first. If that request fails or its
    result cannot be parsed, a plain-text request is attempted through
    ``_generate_text_prompt``. Any other unexpected error returns the
    deterministic prompt from ``_build_fallback_prompt``.

    Args:
        scene_data: Scene fields; ``description``, ``image_prompt`` and
            ``title`` are read here.
        story_context: Story fields; ``story_tone`` and ``story_setting`` are
            read here.
        user_id: Identifier forwarded to ``llm_text_gen``.

    Returns:
        A non-empty, stripped animation prompt, or a fallback prompt.

    Raises:
        HTTPException: Re-raised only when its status code is 429.
    """
    fallback_prompt = _build_fallback_prompt(scene_data, story_context)
    system_prompt = (
        "You are an expert cinematic animation director. "
        "You transform static illustrated scenes into short cinematic motion clips. "
        "Describe motion, camera behavior, atmosphere, and pacing."
    )

    # ``or ""`` also covers keys that are present but set to None; otherwise
    # the literal text "None" would be interpolated into the prompt below.
    description = scene_data.get("description") or ""
    image_prompt = scene_data.get("image_prompt") or ""
    title = scene_data.get("title") or ""
    tone = story_context.get("story_tone") or ""
    setting = story_context.get("story_setting") or ""

    prompt = f"""
Create a concise animation prompt (2-3 sentences) for a 5-second cinematic clip.

Scene Title: {title}
Description: {description}
Existing Image Prompt: {image_prompt}
Story Tone: {tone}
Setting: {setting}

Focus on:
- Motion of characters/objects
- Camera movement (pan, zoom, dolly, orbit)
- Atmosphere, lighting, and emotion
- Timing cues appropriate for a {tone or "story"} scene

Respond with JSON: {{"animation_prompt": "<prompt>"}}
"""

    try:
        response = llm_text_gen(
            prompt=prompt.strip(),
            system_prompt=system_prompt,
            user_id=user_id,
            json_struct={
                "type": "object",
                "properties": {
                    "animation_prompt": {
                        "type": "string",
                        "description": "A cinematic motion prompt for the WaveSpeed image-to-video model.",
                    }
                },
                "required": ["animation_prompt"],
            },
        )
        structured = _load_llm_json_response(response)
        animation_prompt = structured.get("animation_prompt")
        if not animation_prompt or not isinstance(animation_prompt, str):
            raise ValueError("Missing animation_prompt in structured response")
        cleaned_prompt = animation_prompt.strip()
        if not cleaned_prompt:
            raise ValueError("animation_prompt is empty after trimming")
        return cleaned_prompt
    except HTTPException as exc:
        if exc.status_code == 429:
            raise
        logger.warning(
            "[AnimateScene] Structured LLM prompt generation failed (%s). Falling back to text parsing.",
            exc.detail,
        )
        return _generate_text_prompt(
            prompt=prompt,
            system_prompt=system_prompt,
            user_id=user_id,
            fallback_prompt=fallback_prompt,
        )
    except (json.JSONDecodeError, ValueError, KeyError) as exc:
        # Covers both invalid JSON and the ValueErrors raised above for a
        # missing or empty ``animation_prompt``.
        logger.warning(
            "[AnimateScene] Failed to parse structured animation prompt (%s). Falling back to text parsing.",
            exc,
        )
        return _generate_text_prompt(
            prompt=prompt,
            system_prompt=system_prompt,
            user_id=user_id,
            fallback_prompt=fallback_prompt,
        )
    except Exception as exc:
        logger.error(
            "[AnimateScene] Unexpected error generating animation prompt: %s",
            exc,
            exc_info=True,
        )
        return fallback_prompt
|
||||
|
||||
|
||||
def animate_scene_image(
    *,
    image_bytes: bytes,
    scene_data: Dict[str, Any],
    story_context: Dict[str, Any],
    user_id: str,
    duration: int = 5,
    guidance_scale: float = 0.5,
    negative_prompt: Optional[str] = None,
    client: Optional[WaveSpeedClient] = None,
) -> Dict[str, Any]:
    """Animate a scene image using WaveSpeed Kling v2.5 Turbo Std.

    Builds an animation prompt, submits the image-to-video job, polls until it
    completes, then downloads the resulting video.

    Args:
        image_bytes: Raw image data; must not exceed ``MAX_IMAGE_BYTES``.
        scene_data: Scene fields used to build the animation prompt.
        story_context: Story fields used to build the animation prompt.
        user_id: Identifier forwarded to the prompt generator.
        duration: Clip length in seconds; only 5 or 10 are accepted.
        guidance_scale: Clamped into the range [0.0, 1.0] before submission.
        negative_prompt: Optional negative prompt; stripped before sending.
        client: Optional pre-built client; a new ``WaveSpeedClient`` is
            created when omitted.

    Returns:
        Dict with video bytes, prompt used, model name, duration, cost,
        provider, source video URL and prediction id.

    Raises:
        HTTPException: 400 for an invalid duration or oversized image; 502
            when WaveSpeed returns no outputs or the video download fails;
            polling errors are re-raised with their original status code.
    """
    if duration not in (5, 10):
        raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds for scene animation.")

    if len(image_bytes) > MAX_IMAGE_BYTES:
        raise HTTPException(
            status_code=400,
            detail="Scene image exceeds 10MB limit required by WaveSpeed."
        )

    guidance_scale = max(0.0, min(1.0, guidance_scale))
    animation_prompt = generate_animation_prompt(scene_data, story_context, user_id)
    image_b64 = base64.b64encode(image_bytes).decode("utf-8")

    payload = {
        "duration": duration,
        "guidance_scale": guidance_scale,
        "image": image_b64,
        "prompt": animation_prompt,
    }
    if negative_prompt:
        payload["negative_prompt"] = negative_prompt.strip()

    client = client or WaveSpeedClient()
    prediction_id = client.submit_image_to_video(KLING_MODEL_PATH, payload)
    try:
        result = client.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
    except HTTPException as exc:
        detail = exc.detail or {}
        if isinstance(detail, dict):
            # Work on a copy so the caught exception's payload is not mutated.
            detail = dict(detail)
            # The job was already submitted; expose its id so the caller can
            # fetch the result later through the resume flow.
            detail.setdefault("prediction_id", prediction_id)
            detail.setdefault("resume_available", True)
            detail.setdefault("message", "WaveSpeed request is still processing. Use resume endpoint to fetch the video once ready.")
        raise HTTPException(status_code=exc.status_code, detail=detail) from exc

    outputs = result.get("outputs") or []
    if not outputs:
        raise HTTPException(status_code=502, detail="WaveSpeed completed but returned no outputs.")

    video_url = outputs[0]
    try:
        video_response = requests.get(video_url, timeout=60)
    except requests.RequestException as exc:
        # Network failures (timeouts, connection errors) would otherwise
        # escape without the prediction id needed to resume the download.
        raise HTTPException(
            status_code=502,
            detail={
                "error": "Failed to download animation video",
                "message": str(exc),
                "prediction_id": prediction_id,
                "resume_available": True,
            },
        ) from exc

    if video_response.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={
                "error": "Failed to download animation video",
                "status_code": video_response.status_code,
                "response": video_response.text[:200],
                # Same key the resume path reports, so a failed download can
                # be retried without resubmitting the job.
                "prediction_id": prediction_id,
                "resume_available": True,
            },
        )

    model_name = KLING_MODEL_5S if duration == 5 else KLING_MODEL_10S
    # Fixed per-clip cost keyed on duration (currency not shown here).
    cost = 0.21 if duration == 5 else 0.42

    return {
        "video_bytes": video_response.content,
        "prompt": animation_prompt,
        "duration": duration,
        "model_name": model_name,
        "cost": cost,
        "provider": "wavespeed",
        "source_video_url": video_url,
        "prediction_id": prediction_id,
    }
|
||||
|
||||
|
||||
def resume_scene_animation(
    *,
    prediction_id: str,
    duration: int,
    user_id: str,
    client: Optional[WaveSpeedClient] = None,
) -> Dict[str, Any]:
    """Resume a previously submitted animation by fetching the completed result.

    Args:
        prediction_id: Identifier of the previously submitted prediction.
        duration: Clip length in seconds; only 5 or 10 are accepted. It
            selects the reported model name and cost.
        user_id: Accepted for interface parity; not used in this function.
        client: Optional pre-built client; a new ``WaveSpeedClient`` is
            created when omitted.

    Returns:
        Dict with video bytes, prompt, model name, duration, cost, provider,
        source video URL and prediction id.

    Raises:
        HTTPException: 400 for an invalid duration; 409 when the prediction
            status is not ``"completed"``; 502 when there are no outputs or
            the video download fails.
    """
    if duration not in (5, 10):
        raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds for scene animation.")

    client = client or WaveSpeedClient()
    result = client.get_prediction_result(prediction_id, timeout=120)
    status = result.get("status")
    if status != "completed":
        raise HTTPException(
            status_code=409,
            detail={
                "error": "WaveSpeed prediction is not completed yet",
                "prediction_id": prediction_id,
                "status": status,
            },
        )

    outputs = result.get("outputs") or []
    if not outputs:
        raise HTTPException(status_code=502, detail="WaveSpeed completed but returned no outputs.")

    video_url = outputs[0]
    try:
        video_response = requests.get(video_url, timeout=120)
    except requests.RequestException as exc:
        # Network failures (timeouts, connection errors) are reported in the
        # same shape as a non-200 download so the caller can retry.
        raise HTTPException(
            status_code=502,
            detail={
                "error": "Failed to download animation video during resume",
                "message": str(exc),
                "prediction_id": prediction_id,
            },
        ) from exc

    if video_response.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={
                "error": "Failed to download animation video during resume",
                "status_code": video_response.status_code,
                "response": video_response.text[:200],
                "prediction_id": prediction_id,
            },
        )

    # The prompt is taken from the prediction result when present.
    animation_prompt = result.get("prompt") or ""
    model_name = KLING_MODEL_5S if duration == 5 else KLING_MODEL_10S
    # Fixed per-clip cost keyed on duration (currency not shown here).
    cost = 0.21 if duration == 5 else 0.42

    logger.info("[AnimateScene] Resumed download for prediction=%s", prediction_id)

    return {
        "video_bytes": video_response.content,
        "prompt": animation_prompt,
        "duration": duration,
        "model_name": model_name,
        "cost": cost,
        "provider": "wavespeed",
        "source_video_url": video_url,
        "prediction_id": prediction_id,
    }
|
||||
|
||||
Binary file not shown.
BIN
backend/story_audio/scene_2_The_Star_Recipe_Begins_68356250.mp3
Normal file
BIN
backend/story_audio/scene_2_The_Star_Recipe_Begins_68356250.mp3
Normal file
Binary file not shown.
BIN
backend/story_audio/scene_2_The_Star_Recipe_Begins_ed9941a3.mp3
Normal file
BIN
backend/story_audio/scene_2_The_Star_Recipe_Begins_ed9941a3.mp3
Normal file
Binary file not shown.
Binary file not shown.
BIN
backend/story_audio/scene_4_Collecting_Wishes_c38d9001.mp3
Normal file
BIN
backend/story_audio/scene_4_Collecting_Wishes_c38d9001.mp3
Normal file
Binary file not shown.
BIN
backend/story_audio/scene_5_The_Gravity_Mixer_e6255f00.mp3
Normal file
BIN
backend/story_audio/scene_5_The_Gravity_Mixer_e6255f00.mp3
Normal file
Binary file not shown.
BIN
backend/story_audio/scene_6_The_Glowing_Mixture_c0163e9c.mp3
Normal file
BIN
backend/story_audio/scene_6_The_Glowing_Mixture_c0163e9c.mp3
Normal file
Binary file not shown.
BIN
backend/story_audio/scene_7_A_New_Star_Is_Born_c3f3f2c4.mp3
Normal file
BIN
backend/story_audio/scene_7_A_New_Star_Is_Born_c3f3f2c4.mp3
Normal file
Binary file not shown.
Binary file not shown.
|
After Width: | Height: | Size: 1.2 MiB |
Reference in New Issue
Block a user