AI story writer enhancements, text to video and voice generation, subscription management, and more.

This commit is contained in:
ajaysi
2025-11-19 09:55:32 +05:30
parent bf7493c366
commit e96525347b
64 changed files with 10367 additions and 400 deletions

View File

@@ -0,0 +1,301 @@
"""
Main Audio Generation Service for ALwrity Backend.
This service provides AI-powered text-to-speech functionality using WaveSpeed Minimax Speech 02 HD.
"""
from __future__ import annotations
import sys
from typing import Optional, Dict, Any
from datetime import datetime
from loguru import logger
from fastapi import HTTPException
from services.wavespeed.client import WaveSpeedClient
from services.onboarding.api_key_manager import APIKeyManager
from utils.logger_utils import get_service_logger
logger = get_service_logger("audio_generation")
class AudioGenerationResult:
"""Result of audio generation."""
def __init__(
self,
audio_bytes: bytes,
provider: str,
model: str,
voice_id: str,
text_length: int,
file_size: int,
):
self.audio_bytes = audio_bytes
self.provider = provider
self.model = model
self.voice_id = voice_id
self.text_length = text_length
self.file_size = file_size
def generate_audio(
text: str,
voice_id: str = "Wise_Woman",
speed: float = 1.0,
volume: float = 1.0,
pitch: float = 0.0,
emotion: str = "happy",
user_id: Optional[str] = None,
**kwargs
) -> AudioGenerationResult:
"""
Generate audio using AI text-to-speech with subscription tracking.
Args:
text: Text to convert to speech (max 10000 characters)
voice_id: Voice ID (default: "Wise_Woman")
speed: Speech speed (0.5-2.0, default: 1.0)
volume: Speech volume (0.1-10.0, default: 1.0)
pitch: Speech pitch (-12 to 12, default: 0.0)
emotion: Emotion (default: "happy")
user_id: User ID for subscription checking (required)
**kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
Returns:
AudioGenerationResult: Generated audio result
Raises:
RuntimeError: If subscription limits are exceeded or user_id is missing.
"""
try:
logger.info("[audio_gen] Starting audio generation")
logger.debug(f"[audio_gen] Text length: {len(text)} characters, voice: {voice_id}")
# SUBSCRIPTION CHECK - Required and strict enforcement
if not user_id:
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
# Calculate cost based on character count (every character is 1 token)
# Pricing: $0.05 per 1,000 characters
character_count = len(text)
cost_per_1000_chars = 0.05
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
try:
from services.database import get_db
from services.subscription import PricingService
from models.subscription_models import UsageSummary, APIProvider
db = next(get_db())
try:
pricing_service = PricingService(db)
# Check limits using sync method from pricing service (strict enforcement)
# Use AUDIO provider for audio generation
can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id,
provider=APIProvider.AUDIO,
tokens_requested=character_count, # Use character count as "tokens" for audio
actual_provider_name="wavespeed" # Actual provider is WaveSpeed
)
if not can_proceed:
logger.warning(f"[audio_gen] Subscription limit exceeded for user {user_id}: {message}")
error_detail = {
'error': message,
'message': message,
'provider': 'wavespeed',
'usage_info': usage_info if usage_info else {}
}
raise HTTPException(status_code=429, detail=error_detail)
# Get current usage for limit checking
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
finally:
db.close()
except HTTPException:
raise
except RuntimeError:
raise
except Exception as sub_error:
logger.error(f"[audio_gen] Subscription check failed for user {user_id}: {sub_error}")
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
# Generate audio using WaveSpeed
try:
client = WaveSpeedClient()
audio_bytes = client.generate_speech(
text=text,
voice_id=voice_id,
speed=speed,
volume=volume,
pitch=pitch,
emotion=emotion,
enable_sync_mode=True,
**kwargs
)
logger.info(f"[audio_gen] ✅ API call successful, generated {len(audio_bytes)} bytes")
except HTTPException:
raise
except Exception as api_error:
logger.error(f"[audio_gen] Audio generation API failed: {api_error}")
raise HTTPException(
status_code=502,
detail={
"error": "Audio generation failed",
"message": str(api_error)
}
)
# TRACK USAGE after successful API call
if audio_bytes:
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get current values before update
current_calls_before = getattr(summary, "audio_calls", 0) or 0
current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
# Update audio calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text
update_query = text("""
UPDATE usage_summaries
SET audio_calls = :new_calls,
audio_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Create usage log
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.AUDIO,
endpoint="/audio-generation/wavespeed",
method="POST",
model_used="minimax/speech-02-hd",
tokens_input=character_count,
tokens_output=0,
tokens_total=character_count,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(text.encode("utf-8")),
response_size=len(audio_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# Get related stats for unified log
current_image_calls = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[audio_gen] ✅ Successfully tracked usage: user {user_id} -> audio -> {new_calls} calls, ${estimated_cost:.4f}")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Audio Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: wavespeed
├─ Actual Provider: wavespeed
├─ Model: minimax/speech-02-hd
├─ Voice: {voice_id}
├─ Calls: {current_calls_before}{new_calls} / {audio_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Characters: {character_count}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[audio_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[audio_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return AudioGenerationResult(
audio_bytes=audio_bytes,
provider="wavespeed",
model="minimax/speech-02-hd",
voice_id=voice_id,
text_length=character_count,
file_size=len(audio_bytes),
)
except HTTPException:
raise
except RuntimeError:
raise
except Exception as e:
logger.error(f"[audio_gen] Error generating audio: {e}")
raise HTTPException(
status_code=500,
detail={
"error": "Audio generation failed",
"message": str(e)
}
)

View File

@@ -515,6 +515,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
@@ -571,6 +577,8 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
@@ -819,6 +827,12 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
@@ -838,6 +852,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:

View File

@@ -10,6 +10,7 @@ from __future__ import annotations
import os
import base64
import io
import sys
from typing import Any, Dict, Optional, Union
from fastapi import HTTPException
@@ -22,11 +23,11 @@ except ImportError:
InferenceClient = None
from ..onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService
from utils.logger_utils import get_service_logger
logger = get_service_logger("video_generation_service")
class VideoProviderNotImplemented(Exception):
pass
@@ -48,44 +49,80 @@ def _get_api_key(provider: str) -> Optional[str]:
def _coerce_video_bytes(output: Any) -> bytes:
"""
Normalizes the different return shapes that huggingface_hub may emit for video tasks.
Depending on the provider/library version we may get:
- raw bytes
- an object with `.video` or `.bytes` attributes (plus optional `.save`)
- a dict containing a `video` key with bytes/base64 data
According to HF docs, text_to_video() should return bytes directly.
"""
data: Union[bytes, bytearray, memoryview, io.BufferedIOBase, None] = None
logger.debug(f"[video_gen] _coerce_video_bytes received type: {type(output)}")
# Most common case: bytes directly
if isinstance(output, (bytes, bytearray, memoryview)):
logger.debug(f"[video_gen] Output is bytes: {len(output)} bytes")
return bytes(output)
# Handle file-like objects
if hasattr(output, "read"):
logger.debug("[video_gen] Output has read() method, reading...")
data = output.read()
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
raise TypeError(f"File-like object returned non-bytes: {type(data)}")
# Objects with direct attribute access
if hasattr(output, "video"):
logger.debug("[video_gen] Output has 'video' attribute")
data = getattr(output, "video")
elif hasattr(output, "bytes"):
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
if hasattr(output, "bytes"):
logger.debug("[video_gen] Output has 'bytes' attribute")
data = getattr(output, "bytes")
elif isinstance(output, dict) and "video" in output:
data = output["video"]
else:
data = output
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
# Handle file-like responses
if hasattr(data, "read"):
data = data.read()
# Dict handling - but this shouldn't happen with text_to_video()
if isinstance(output, dict):
logger.warning(f"[video_gen] Received dict output (unexpected): keys={list(output.keys())}")
# Try to get video key safely - use .get() to avoid KeyError
data = output.get("video")
if data is not None:
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
# Try other common keys
for key in ["data", "content", "file", "result", "output"]:
data = output.get(key)
if data is not None:
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if hasattr(data, "read"):
return bytes(data.read())
raise TypeError(f"Dict output has no recognized video key. Keys: {list(output.keys())}")
if isinstance(data, (bytes, bytearray, memoryview)):
return bytes(data)
if isinstance(data, str):
# Expecting data URI or raw base64 string
if data.startswith("data:"):
_, encoded = data.split(",", 1)
# String handling (base64)
if isinstance(output, str):
logger.debug("[video_gen] Output is string, attempting base64 decode")
if output.startswith("data:"):
_, encoded = output.split(",", 1)
return base64.b64decode(encoded)
try:
return base64.b64decode(data)
return base64.b64decode(output)
except Exception as exc:
raise TypeError(f"Unable to decode string video payload: {exc}") from exc
raise TypeError(f"Unsupported video payload type: {type(data)}")
# Fallback: try to use output directly
logger.warning(f"[video_gen] Unexpected output type: {type(output)}, attempting direct conversion")
try:
if hasattr(output, "__bytes__"):
return bytes(output)
except Exception:
pass
raise TypeError(f"Unsupported video payload type: {type(output)}. Output: {str(output)[:200]}")
def _generate_with_huggingface(
@@ -96,7 +133,6 @@ def _generate_with_huggingface(
negative_prompt: Optional[str] = None,
seed: Optional[int] = None,
model: str = "tencent/HunyuanVideo",
input_image_bytes: Optional[bytes] = None,
) -> bytes:
"""
Generates video bytes using Hugging Face's InferenceClient.
@@ -109,7 +145,6 @@ def _generate_with_huggingface(
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
client = InferenceClient(
model=model,
provider="fal-ai",
token=token,
)
@@ -126,26 +161,25 @@ def _generate_with_huggingface(
params["seed"] = seed
logger.info(
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=%s",
"[video_gen] HuggingFace request model=%s frames=%s steps=%s mode=text-to-video",
model,
num_frames,
num_inference_steps,
"image-to-video" if input_image_bytes else "text-to-video",
)
try:
call_kwargs = {**params, "model": model}
if input_image_bytes:
video_output = client.image_to_video(
image=input_image_bytes,
prompt=prompt,
**call_kwargs,
)
else:
video_output = client.text_to_video(
prompt,
**call_kwargs,
)
logger.info("[video_gen] Calling client.text_to_video()...")
video_output = client.text_to_video(
prompt=prompt,
model=model,
**params,
)
logger.info(f"[video_gen] text_to_video() returned type: {type(video_output)}")
if isinstance(video_output, dict):
logger.info(f"[video_gen] Dict keys: {list(video_output.keys())}")
elif hasattr(video_output, "__dict__"):
logger.info(f"[video_gen] Object attributes: {dir(video_output)}")
video_bytes = _coerce_video_bytes(video_output)
@@ -158,6 +192,15 @@ def _generate_with_huggingface(
logger.info(f"[video_gen] Successfully generated video: {len(video_bytes)} bytes")
return video_bytes
except KeyError as e:
error_msg = str(e)
logger.error(f"[video_gen] HF KeyError: {error_msg}", exc_info=True)
logger.error(f"[video_gen] This suggests the API response format is unexpected. Check logs above for response type.")
raise HTTPException(status_code=502, detail={
"error": f"Hugging Face API returned unexpected response format: {error_msg}",
"error_type": "KeyError",
"hint": "The API response may have changed. Check server logs for details."
})
except Exception as e:
error_msg = str(e)
error_type = type(e).__name__
@@ -179,7 +222,6 @@ def ai_video_generate(
prompt: str,
provider: str = "huggingface",
user_id: Optional[str] = None,
input_image_bytes: Optional[bytes] = None,
**kwargs,
) -> bytes:
"""
@@ -187,7 +229,6 @@ def ai_video_generate(
- provider: 'huggingface' (default), 'gemini' (veo3 stub), 'openai' (sora stub)
- kwargs: num_frames, guidance_scale, num_inference_steps, negative_prompt, seed, model
- input_image_bytes: optional bytes for image-to-video flows (uses image as motion anchor)
Returns raw video bytes (mp4/webm depending on provider).
"""
@@ -200,7 +241,6 @@ def ai_video_generate(
# PRE-FLIGHT VALIDATION: Validate video generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
from fastapi import HTTPException
@@ -227,7 +267,6 @@ def ai_video_generate(
if provider == "huggingface":
video_bytes = _generate_with_huggingface(
prompt=prompt,
input_image_bytes=input_image_bytes,
**kwargs,
)
elif provider == "gemini":
@@ -237,112 +276,14 @@ def ai_video_generate(
else:
raise RuntimeError(f"Unknown video provider: {provider}")
# Track usage AFTER successful generation
db_track = next(get_db())
try:
from models.subscription_models import APIProvider, UsageSummary, APIUsageLog
from datetime import datetime
from services.subscription import PricingService
# Create pricing service for tracking (uses same DB session)
pricing_service_track = PricingService(db_track)
# Get current billing period
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
usage_summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage_summary:
usage_summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(usage_summary)
db_track.commit()
# Calculate cost using pricing service
cost_info = pricing_service_track.get_pricing_for_provider_model(
APIProvider.VIDEO,
model_name
)
cost_per_video = cost_info.get('cost_per_request', 0.10) if cost_info else 0.10
# Get "before" state for unified log
current_video_calls_before = getattr(usage_summary, 'video_calls', 0) or 0
current_video_cost = getattr(usage_summary, 'video_cost', 0.0) or 0.0
# Increment video_calls and track cost
new_video_calls = current_video_calls_before + 1
usage_summary.video_calls = new_video_calls
usage_summary.video_cost = current_video_cost + cost_per_video
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
# Get plan details for unified log (before commit, in case commit fails)
limits = pricing_service_track.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get image and image editing stats for unified log
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Create usage log entry for audit trail
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.VIDEO,
endpoint=f"/video-generation/{provider}",
method="POST",
model_used=model_name,
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost_per_video,
response_time=0.0, # Could track actual time if needed
status_code=200,
request_size=len(prompt.encode('utf-8')),
response_size=len(video_bytes),
billing_period=current_period
)
db_track.add(usage_log)
db_track.commit()
logger.info(f"[video_gen] ✅ Successfully tracked usage: user {user_id} -> 1 video call, ${cost_per_video:.4f} cost")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Flush immediately to ensure it's visible in console/logs
import sys
log_message = f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {current_video_calls_before}{new_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
"""
print(log_message, flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
db_track.rollback()
# Don't fail video generation if tracking fails - video is already generated
finally:
db_track.close()
track_video_usage(
user_id=user_id,
provider=provider,
model_name=model_name,
prompt=prompt,
video_bytes=video_bytes,
)
return video_bytes
except HTTPException:
@@ -353,3 +294,139 @@ def ai_video_generate(
raise HTTPException(status_code=500, detail={"error": str(e)})
def track_video_usage(
*,
user_id: str,
provider: str,
model_name: str,
prompt: str,
video_bytes: bytes,
cost_override: Optional[float] = None,
) -> Dict[str, Any]:
"""
Track subscription usage for any video generation (text-to-video or image-to-video).
"""
from datetime import datetime
from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
from services.database import get_db
db_track = next(get_db())
try:
logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
pricing_service_track = PricingService(db_track)
current_period = pricing_service_track.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[video_gen] Billing period: {current_period}")
usage_summary = (
db_track.query(UsageSummary)
.filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period,
)
.first()
)
if not usage_summary:
logger.debug(f"[video_gen] Creating new UsageSummary for user={user_id}, period={current_period}")
usage_summary = UsageSummary(
user_id=user_id,
billing_period=current_period,
)
db_track.add(usage_summary)
db_track.commit()
db_track.refresh(usage_summary)
else:
logger.debug(f"[video_gen] Found existing UsageSummary: video_calls={getattr(usage_summary, 'video_calls', 0)}")
cost_info = pricing_service_track.get_pricing_for_provider_model(
APIProvider.VIDEO,
model_name,
)
default_cost = 0.10
if cost_info and cost_info.get("cost_per_request") is not None:
default_cost = cost_info["cost_per_request"]
cost_per_video = cost_override if cost_override is not None else default_cost
logger.debug(f"[video_gen] Cost per video: ${cost_per_video} (override={cost_override}, default={default_cost})")
current_video_calls_before = getattr(usage_summary, "video_calls", 0) or 0
current_video_cost = getattr(usage_summary, "video_cost", 0.0) or 0.0
usage_summary.video_calls = current_video_calls_before + 1
usage_summary.video_cost = current_video_cost + cost_per_video
usage_summary.total_calls = (usage_summary.total_calls or 0) + 1
usage_summary.total_cost = (usage_summary.total_cost or 0.0) + cost_per_video
# Ensure the object is in the session
db_track.add(usage_summary)
logger.debug(f"[video_gen] Updated usage_summary: video_calls={current_video_calls_before}{usage_summary.video_calls}")
limits = pricing_service_track.get_user_limits(user_id)
plan_name = limits.get("plan_name", "unknown") if limits else "unknown"
tier = limits.get("tier", "unknown") if limits else "unknown"
video_limit = limits["limits"].get("video_calls", 0) if limits else 0
current_image_calls = getattr(usage_summary, "stability_calls", 0) or 0
image_limit = limits["limits"].get("stability_calls", 0) if limits else 0
current_image_edit_calls = getattr(usage_summary, "image_edit_calls", 0) or 0
image_edit_limit = limits["limits"].get("image_edit_calls", 0) if limits else 0
current_audio_calls = getattr(usage_summary, "audio_calls", 0) or 0
audio_limit = limits["limits"].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.VIDEO,
endpoint=f"/video-generation/{provider}",
method="POST",
model_used=model_name,
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost_per_video,
response_time=0.0,
status_code=200,
request_size=len(prompt.encode("utf-8")),
response_size=len(video_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
logger.debug(f"[video_gen] Flushing changes before commit...")
db_track.flush()
logger.debug(f"[video_gen] Committing usage tracking changes...")
db_track.commit()
db_track.refresh(usage_summary)
logger.debug(f"[video_gen] Commit successful. Final video_calls: {usage_summary.video_calls}, video_cost: {usage_summary.video_cost}")
video_limit_display = video_limit if video_limit > 0 else ''
log_message = f"""
[SUBSCRIPTION] Video Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: video
├─ Actual Provider: {provider}
├─ Model: {model_name or 'default'}
├─ Calls: {current_video_calls_before}{usage_summary.video_calls} / {video_limit_display}
├─ Images: {current_image_calls} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
"""
logger.info(log_message)
return {
"previous_calls": current_video_calls_before,
"current_calls": usage_summary.video_calls,
"video_limit": video_limit,
"video_limit_display": video_limit_display,
"cost_per_video": cost_per_video,
"total_video_cost": usage_summary.video_cost,
}
except Exception as track_error:
logger.error(f"[video_gen] Error tracking usage: {track_error}", exc_info=True)
logger.error(f"[video_gen] Exception type: {type(track_error).__name__}", exc_info=True)
db_track.rollback()
finally:
db_track.close()

View File

@@ -414,7 +414,8 @@ class APIKeyManager:
'SERPER_API_KEY',
'METAPHOR_API_KEY',
'FIRECRAWL_API_KEY',
'STABILITY_API_KEY'
'STABILITY_API_KEY',
'WAVESPEED_API_KEY',
]
for provider in providers:

View File

@@ -288,4 +288,90 @@ class StoryAudioGenerationService:
logger.info(f"[StoryAudioGeneration] Generated {len(audio_results)} audio files out of {total_scenes} scenes")
return audio_results
def generate_ai_audio(
self,
scene_number: int,
scene_title: str,
text: str,
user_id: str,
voice_id: str = "Wise_Woman",
speed: float = 1.0,
volume: float = 1.0,
pitch: float = 0.0,
emotion: str = "happy",
) -> Dict[str, Any]:
"""
Generate AI audio for a single scene using main_audio_generation.
Parameters:
scene_number (int): Scene number.
scene_title (str): Scene title.
text (str): Text to convert to speech.
user_id (str): Clerk user ID for subscription checking.
voice_id (str): Voice ID for AI audio generation (default: "Wise_Woman").
speed (float): Speech speed (0.5-2.0, default: 1.0).
volume (float): Speech volume (0.1-10.0, default: 1.0).
pitch (float): Speech pitch (-12 to 12, default: 0.0).
emotion (str): Emotion for speech (default: "happy").
Returns:
Dict[str, Any]: Audio metadata including file path, URL, and scene info.
"""
if not text or not text.strip():
raise ValueError(f"Scene {scene_number} ({scene_title}) requires non-empty text")
try:
logger.info(f"[StoryAudioGeneration] Generating AI audio for scene {scene_number}: {scene_title}")
logger.debug(f"[StoryAudioGeneration] Text length: {len(text)} characters, voice: {voice_id}")
# Import main_audio_generation
from services.llm_providers.main_audio_generation import generate_audio
# Generate audio using main_audio_generation service
result = generate_audio(
text=text.strip(),
voice_id=voice_id,
speed=speed,
volume=volume,
pitch=pitch,
emotion=emotion,
user_id=user_id,
)
# Save audio to file
audio_filename = self._generate_audio_filename(scene_number, scene_title)
audio_path = self.output_dir / audio_filename
with open(audio_path, "wb") as f:
f.write(result.audio_bytes)
logger.info(f"[StoryAudioGeneration] Saved AI audio to: {audio_path} ({result.file_size} bytes)")
# Calculate cost (for response)
character_count = result.text_length
cost_per_1000_chars = 0.05
cost = (character_count / 1000.0) * cost_per_1000_chars
# Return audio metadata
return {
"scene_number": scene_number,
"scene_title": scene_title,
"audio_path": str(audio_path),
"audio_filename": audio_filename,
"audio_url": f"/api/story/audio/{audio_filename}",
"provider": result.provider,
"model": result.model,
"voice_id": result.voice_id,
"text_length": result.text_length,
"file_size": result.file_size,
"cost": cost,
}
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit)
raise
except Exception as e:
logger.error(f"[StoryAudioGeneration] Error generating AI audio for scene {scene_number}: {e}")
raise RuntimeError(f"Failed to generate AI audio for scene {scene_number}: {str(e)}") from e

View File

@@ -193,4 +193,82 @@ class StoryImageGenerationService:
logger.info(f"[StoryImageGeneration] Generated {len(image_results)} images out of {total_scenes} scenes")
return image_results
def regenerate_scene_image(
self,
scene_number: int,
scene_title: str,
prompt: str,
user_id: str,
provider: Optional[str] = None,
width: int = 1024,
height: int = 1024,
model: Optional[str] = None
) -> Dict[str, Any]:
"""
Regenerate an image for a single scene using a direct prompt (no AI prompt generation).
Parameters:
scene_number (int): Scene number.
scene_title (str): Scene title.
prompt (str): Direct prompt to use for image generation.
user_id (str): Clerk user ID for subscription checking.
provider (str, optional): Image generation provider (gemini, huggingface, stability).
width (int): Image width (default: 1024).
height (int): Image height (default: 1024).
model (str, optional): Model to use for image generation.
Returns:
Dict[str, Any]: Image metadata including file path, URL, and scene info.
"""
if not prompt or not prompt.strip():
raise ValueError(f"Scene {scene_number} ({scene_title}) requires a non-empty prompt")
try:
logger.info(f"[StoryImageGeneration] Regenerating image for scene {scene_number}: {scene_title}")
logger.debug(f"[StoryImageGeneration] Using direct prompt: {prompt[:100]}...")
# Generate image using main_image_generation service with the direct prompt
image_options = {
"provider": provider,
"width": width,
"height": height,
"model": model,
}
result: ImageGenerationResult = generate_image(
prompt=prompt.strip(),
options=image_options,
user_id=user_id
)
# Save image to file
image_filename = self._generate_image_filename(scene_number, scene_title)
image_path = self.output_dir / image_filename
with open(image_path, "wb") as f:
f.write(result.image_bytes)
logger.info(f"[StoryImageGeneration] Saved regenerated image to: {image_path}")
# Return image metadata
return {
"scene_number": scene_number,
"scene_title": scene_title,
"image_path": str(image_path),
"image_filename": image_filename,
"image_url": f"/api/story/images/{image_filename}",
"width": result.width,
"height": result.height,
"provider": result.provider,
"model": result.model,
"seed": result.seed,
}
except HTTPException:
# Re-raise HTTPExceptions (e.g., 429 subscription limit)
raise
except Exception as e:
logger.error(f"[StoryImageGeneration] Error regenerating image for scene {scene_number}: {e}")
raise RuntimeError(f"Failed to regenerate image for scene {scene_number}: {str(e)}") from e

View File

@@ -220,35 +220,41 @@ class StoryVideoGenerationService:
def generate_story_video(
self,
scenes: List[Dict[str, Any]],
image_paths: List[str],
image_paths: List[Optional[str]],
audio_paths: List[str],
user_id: str,
story_title: str = "Story",
fps: int = 24,
transition_duration: float = 0.5,
progress_callback: Optional[callable] = None
progress_callback: Optional[callable] = None,
video_paths: Optional[List[Optional[str]]] = None
) -> Dict[str, Any]:
"""
Generate a complete story video from multiple scenes.
Parameters:
scenes (List[Dict[str, Any]]): List of scene data.
image_paths (List[str]): List of image file paths for each scene.
image_paths (List[Optional[str]]): List of image file paths (None if scene has animated video).
audio_paths (List[str]): List of audio file paths for each scene.
user_id (str): Clerk user ID for subscription checking.
story_title (str): Title of the story (default: "Story").
fps (int): Frames per second for video (default: 24).
transition_duration (float): Duration of transitions between scenes in seconds (default: 0.5).
progress_callback (callable, optional): Callback function for progress updates.
video_paths (Optional[List[Optional[str]]]): List of animated video file paths (None if scene has static image).
Returns:
Dict[str, Any]: Video metadata including file path, URL, and story info.
"""
if not scenes or not image_paths or not audio_paths:
raise ValueError("Scenes, image paths, and audio paths are required")
if not scenes or not audio_paths:
raise ValueError("Scenes and audio paths are required")
if len(scenes) != len(image_paths) or len(scenes) != len(audio_paths):
raise ValueError("Number of scenes, image paths, and audio paths must match")
if len(scenes) != len(audio_paths):
raise ValueError("Number of scenes and audio paths must match")
video_paths = video_paths or [None] * len(scenes)
if len(video_paths) != len(scenes):
video_paths = video_paths + [None] * (len(scenes) - len(video_paths))
try:
logger.info(f"[StoryVideoGeneration] Generating story video for {len(scenes)} scenes")
@@ -293,36 +299,59 @@ class StoryVideoGenerationService:
scene_clips = []
total_duration = 0.0
for idx, (scene, image_path, audio_path) in enumerate(zip(scenes, image_paths, audio_paths)):
# Import VideoFileClip for animated videos
try:
from moviepy import VideoFileClip
except ImportError:
VideoFileClip = None
for idx, (scene, image_path, audio_path, video_path) in enumerate(zip(scenes, image_paths, audio_paths, video_paths)):
try:
scene_number = scene.get("scene_number", idx + 1)
scene_title = scene.get("title", "Untitled")
logger.info(f"[StoryVideoGeneration] Processing scene {scene_number}/{len(scenes)}: {scene_title}")
# Load image and audio
image_file = Path(image_path)
audio_file = Path(audio_path)
if not image_file.exists():
logger.warning(f"[StoryVideoGeneration] Image not found: {image_path}, skipping scene {scene_number}")
continue
if not audio_file.exists():
logger.warning(f"[StoryVideoGeneration] Audio not found: {audio_path}, skipping scene {scene_number}")
continue
# Load audio to get duration
# Load audio
audio_clip = AudioFileClip(str(audio_file))
audio_duration = audio_clip.duration
# Create image clip (MoviePy v2: use with_* API)
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
image_clip = image_clip.with_fps(fps)
# Prefer animated video if available
if video_path and Path(video_path).exists():
logger.info(f"[StoryVideoGeneration] Using animated video for scene {scene_number}: {video_path}")
# Load animated video
if VideoFileClip is None:
raise RuntimeError("VideoFileClip not available - MoviePy may not be fully installed")
video_clip = VideoFileClip(str(video_path))
# Replace audio with the preferred audio (AI or free)
video_clip = video_clip.with_audio(audio_clip)
# Match duration to audio if needed
if video_clip.duration > audio_duration:
video_clip = video_clip.subclip(0, audio_duration)
elif video_clip.duration < audio_duration:
# Loop the video if it's shorter than audio
loops_needed = int(audio_duration / video_clip.duration) + 1
video_clip = concatenate_videoclips([video_clip] * loops_needed).subclip(0, audio_duration)
video_clip = video_clip.with_audio(audio_clip)
elif image_path and Path(image_path).exists():
# Fall back to static image
logger.info(f"[StoryVideoGeneration] Using static image for scene {scene_number}: {image_path}")
image_file = Path(image_path)
# Create image clip (MoviePy v2: use with_* API)
image_clip = ImageClip(str(image_file)).with_duration(audio_duration)
image_clip = image_clip.with_fps(fps)
# Set audio to image clip
video_clip = image_clip.with_audio(audio_clip)
else:
logger.warning(f"[StoryVideoGeneration] No video or image found for scene {scene_number}, skipping")
continue
# Set audio to image clip
video_clip = image_clip.with_audio(audio_clip)
scene_clips.append(video_clip)
total_duration += audio_duration
# Call progress callback if provided

View File

@@ -19,10 +19,18 @@ import re
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
from models.subscription_models import APIProvider
from services.database import get_db
from .usage_tracking_service import UsageTrackingService
from .pricing_service import PricingService
def _get_db_session():
"""
Get a database session with lazy import to survive hot reloads.
Uvicorn's reloader can sometimes clear module-level imports.
"""
from services.database import get_db
return next(get_db())
class DatabaseAPIMonitor:
"""Database-backed API monitoring with usage tracking and subscription management."""
@@ -145,8 +153,9 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
except Exception:
pass
db = None
try:
db = next(get_db())
db = _get_db_session()
api_monitor = DatabaseAPIMonitor()
# Detect if this is an API call that should be rate limited
@@ -203,14 +212,15 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
# Don't block requests if usage checking fails
return None
finally:
db.close()
if db is not None:
db.close()
async def monitoring_middleware(request: Request, call_next):
"""Enhanced FastAPI middleware for monitoring API calls with usage tracking."""
start_time = time.time()
# Get database session
db = next(get_db())
db = _get_db_session()
# Extract request details - Enhanced user identification
user_id = None
@@ -340,8 +350,9 @@ async def monitoring_middleware(request: Request, call_next):
async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
"""Get current monitoring statistics."""
db = next(get_db())
db = None
try:
db = _get_db_session()
# Placeholder to match old API; heavy stats handled elsewhere
return {
'timestamp': datetime.utcnow().isoformat(),
@@ -354,12 +365,14 @@ async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
'system_health': {'status': 'healthy', 'error_rate': 0.0}
}
finally:
db.close()
if db is not None:
db.close()
async def get_lightweight_stats() -> Dict[str, Any]:
"""Get lightweight stats for dashboard header."""
db = next(get_db())
db = None
try:
db = _get_db_session()
# Minimal viable placeholder values
now = datetime.utcnow()
return {
@@ -371,4 +384,5 @@ async def get_lightweight_stats() -> Dict[str, Any]:
'timestamp': now.isoformat()
}
finally:
db.close()
if db is not None:
db.close()

View File

@@ -420,3 +420,54 @@ def validate_video_generation_operations(
'message': f"Failed to validate video generation: {str(e)}"
}
)
def validate_scene_animation_operation(
pricing_service: PricingService,
user_id: str,
) -> None:
"""
Validate the per-scene animation workflow before API calls.
"""
try:
operations_to_validate = [
{
'provider': APIProvider.VIDEO,
'tokens_requested': 0,
'actual_provider_name': 'wavespeed',
'operation_type': 'scene_animation',
}
]
can_proceed, message, error_details = pricing_service.check_comprehensive_limits(
user_id=user_id,
operations=operations_to_validate,
)
if not can_proceed:
logger.error(f"[Pre-flight Validator] Scene animation blocked for user {user_id}: {message}")
usage_info = error_details.get('usage_info', {}) if error_details else {}
provider = usage_info.get('provider', 'video') if usage_info else 'video'
raise HTTPException(
status_code=429,
detail={
'error': message,
'message': message,
'provider': provider,
'usage_info': usage_info if usage_info else error_details,
}
)
logger.info(f"[Pre-flight Validator] ✅ Scene animation validated for user {user_id}")
except HTTPException:
raise
except Exception as e:
logger.error(f"[Pre-flight Validator] Error validating scene animation: {e}", exc_info=True)
raise HTTPException(
status_code=500,
detail={
'error': f"Failed to validate scene animation: {str(e)}",
'message': f"Failed to validate scene animation: {str(e)}",
},
)

View File

@@ -307,6 +307,41 @@ class PricingService:
"model_name": "default",
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
"description": "AI Video Generation default pricing"
},
{
"provider": APIProvider.VIDEO,
"model_name": "kling-v2.5-turbo-std-5s",
"cost_per_request": 0.21,
"description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (5 seconds)"
},
{
"provider": APIProvider.VIDEO,
"model_name": "kling-v2.5-turbo-std-10s",
"cost_per_request": 0.42,
"description": "WaveSpeed Kling v2.5 Turbo Std Image-to-Video (10 seconds)"
},
{
"provider": APIProvider.VIDEO,
"model_name": "wavespeed-ai/infinitetalk",
"cost_per_request": 0.30,
"description": "WaveSpeed InfiniteTalk (image + audio to talking avatar video)"
},
# Audio Generation Pricing (Minimax Speech 02 HD via WaveSpeed)
{
"provider": APIProvider.AUDIO,
"model_name": "minimax/speech-02-hd",
"cost_per_input_token": 0.00005, # $0.05 per 1,000 characters (every character is 1 token)
"cost_per_output_token": 0.0, # No output tokens for audio
"cost_per_request": 0.0, # Pricing is per character, not per request
"description": "AI Audio Generation (Text-to-Speech) - Minimax Speech 02 HD via WaveSpeed"
},
{
"provider": APIProvider.AUDIO,
"model_name": "default",
"cost_per_input_token": 0.00005, # $0.05 per 1,000 characters default
"cost_per_output_token": 0.0,
"cost_per_request": 0.0,
"description": "AI Audio Generation default pricing"
}
]
@@ -358,6 +393,7 @@ class PricingService:
"exa_calls_limit": 100,
"video_calls_limit": 0, # No video generation for free tier
"image_edit_calls_limit": 10, # 10 AI image editing calls/month
"audio_calls_limit": 20, # 20 AI audio generation calls/month
"gemini_tokens_limit": 100000,
"monthly_cost_limit": 0.0,
"features": ["basic_content_generation", "limited_research"],
@@ -381,6 +417,7 @@ class PricingService:
"exa_calls_limit": 500,
"video_calls_limit": 20, # 20 videos/month for basic plan
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
"audio_calls_limit": 50, # 50 AI audio generation calls/month
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
@@ -406,6 +443,7 @@ class PricingService:
"exa_calls_limit": 2000,
"video_calls_limit": 50, # 50 videos/month for pro plan
"image_edit_calls_limit": 100, # 100 AI image editing calls/month
"audio_calls_limit": 200, # 200 AI audio generation calls/month
"gemini_tokens_limit": 5000000,
"openai_tokens_limit": 2500000,
"anthropic_tokens_limit": 1000000,
@@ -431,6 +469,7 @@ class PricingService:
"exa_calls_limit": 0, # Unlimited
"video_calls_limit": 0, # Unlimited for enterprise
"image_edit_calls_limit": 0, # Unlimited image editing for enterprise
"audio_calls_limit": 0, # Unlimited audio generation for enterprise
"gemini_tokens_limit": 0,
"openai_tokens_limit": 0,
"anthropic_tokens_limit": 0,
@@ -651,6 +690,7 @@ class PricingService:
'stability_calls': plan.stability_calls_limit,
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
'audio_calls': getattr(plan, 'audio_calls_limit', 0), # Support missing column
# Token limits
'gemini_tokens': plan.gemini_tokens_limit,
'openai_tokens': plan.openai_tokens_limit,

View File

@@ -31,6 +31,7 @@ def ensure_subscription_plan_columns(db: Session) -> None:
"exa_calls_limit": "INTEGER DEFAULT 0",
"video_calls_limit": "INTEGER DEFAULT 0",
"image_edit_calls_limit": "INTEGER DEFAULT 0",
"audio_calls_limit": "INTEGER DEFAULT 0",
}
for col_name, ddl in required_columns.items():
@@ -84,6 +85,8 @@ def ensure_usage_summaries_columns(db: Session) -> None:
"video_cost": "REAL DEFAULT 0.0",
"image_edit_calls": "INTEGER DEFAULT 0",
"image_edit_cost": "REAL DEFAULT 0.0",
"audio_calls": "INTEGER DEFAULT 0",
"audio_cost": "REAL DEFAULT 0.0",
}
for col_name, ddl in required_columns.items():

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,471 @@
from __future__ import annotations
import json
import time
from typing import Any, Dict, Optional
import requests
from fastapi import HTTPException
from requests import exceptions as requests_exceptions
from services.onboarding.api_key_manager import APIKeyManager
from utils.logger_utils import get_service_logger
logger = get_service_logger("wavespeed.client")
class WaveSpeedClient:
"""
Thin HTTP client for the WaveSpeed AI API.
Handles authentication, submission, and polling helpers.
"""
BASE_URL = "https://api.wavespeed.ai/api/v3"
def __init__(self, api_key: Optional[str] = None):
manager = APIKeyManager()
self.api_key = api_key or manager.get_api_key("wavespeed")
if not self.api_key:
raise RuntimeError("WAVESPEED_API_KEY is not configured. Please add it to your environment.")
def _headers(self) -> Dict[str, str]:
return {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}",
}
def submit_image_to_video(
self,
model_path: str,
payload: Dict[str, Any],
timeout: int = 30,
) -> str:
"""
Submit an image-to-video generation request.
Returns the prediction ID for polling.
"""
url = f"{self.BASE_URL}/{model_path}"
logger.info(f"[WaveSpeed] Submitting request to {url}")
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Submission failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed image-to-video submission failed",
"status_code": response.status_code,
"response": response.text,
},
)
data = response.json().get("data")
if not data or "id" not in data:
logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}")
raise HTTPException(
status_code=502,
detail={"error": "WaveSpeed response missing prediction id"},
)
prediction_id = data["id"]
logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
return prediction_id
def get_prediction_result(self, prediction_id: str, timeout: int = 120) -> Dict[str, Any]:
"""
Fetch the current status/result for a prediction.
"""
url = f"{self.BASE_URL}/predictions/{prediction_id}/result"
try:
response = requests.get(url, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=timeout)
except requests_exceptions.Timeout as exc:
raise HTTPException(
status_code=504,
detail={
"error": "WaveSpeed polling request timed out",
"prediction_id": prediction_id,
"resume_available": True,
"exception": str(exc),
},
) from exc
except requests_exceptions.RequestException as exc:
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed polling request failed",
"prediction_id": prediction_id,
"resume_available": True,
"exception": str(exc),
},
) from exc
if response.status_code != 200:
logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed prediction polling failed",
"status_code": response.status_code,
"response": response.text,
},
)
result = response.json().get("data")
if not result:
raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
return result
def poll_until_complete(
self,
prediction_id: str,
timeout_seconds: int = 240,
interval_seconds: float = 1.0,
) -> Dict[str, Any]:
"""
Poll WaveSpeed until the job completes, fails, or times out.
"""
start_time = time.time()
while True:
try:
result = self.get_prediction_result(prediction_id)
except HTTPException as exc:
detail = exc.detail or {}
if isinstance(detail, dict):
detail.setdefault("prediction_id", prediction_id)
detail.setdefault("resume_available", True)
detail.setdefault("error", detail.get("error", "WaveSpeed polling failed"))
raise HTTPException(status_code=exc.status_code, detail=detail) from exc
status = result.get("status")
if status == "completed":
logger.info(f"[WaveSpeed] Prediction {prediction_id} completed.")
return result
if status == "failed":
logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {result.get('error')}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed animation failed",
"prediction_id": prediction_id,
"details": result.get("error"),
},
)
elapsed = time.time() - start_time
if elapsed > timeout_seconds:
logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
raise HTTPException(
status_code=504,
detail={
"error": "WaveSpeed animation timed out",
"prediction_id": prediction_id,
"details": result,
},
)
logger.debug(f"[WaveSpeed] Prediction {prediction_id} status={status}. Waiting...")
time.sleep(interval_seconds)
def optimize_prompt(
self,
text: str,
mode: str = "image",
style: str = "default",
image: Optional[str] = None,
enable_sync_mode: bool = True,
timeout: int = 30,
) -> str:
"""
Optimize a prompt using WaveSpeed prompt optimizer.
Args:
text: The prompt text to optimize
mode: "image" or "video" (default: "image")
style: "default", "artistic", "photographic", "technical", "anime", "realistic" (default: "default")
image: Base64-encoded image for context (optional)
enable_sync_mode: If True, wait for result and return it directly (default: True)
timeout: Request timeout in seconds (default: 30)
Returns:
Optimized prompt text
"""
model_path = "wavespeed-ai/prompt-optimizer"
url = f"{self.BASE_URL}/{model_path}"
payload = {
"text": text,
"mode": mode,
"style": style,
"enable_sync_mode": enable_sync_mode,
}
if image:
payload["image"] = image
logger.info(f"[WaveSpeed] Optimizing prompt via {url} (mode={mode}, style={style})")
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Prompt optimization failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed prompt optimization failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
# Handle sync mode - result should be directly in outputs
if enable_sync_mode:
outputs = data.get("outputs") or []
if not outputs:
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed prompt optimizer returned no outputs",
)
# Extract optimized prompt from outputs
# In sync mode, outputs[0] should be the optimized text directly (or a URL to fetch)
optimized_prompt = None
if isinstance(outputs, list) and len(outputs) > 0:
first_output = outputs[0]
# If it's a string that looks like a URL, fetch it
if isinstance(first_output, str):
if first_output.startswith("http://") or first_output.startswith("https://"):
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
url_response = requests.get(first_output, timeout=timeout)
if url_response.status_code == 200:
optimized_prompt = url_response.text.strip()
else:
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to fetch optimized prompt from WaveSpeed URL",
)
else:
# It's already the text
optimized_prompt = first_output
elif isinstance(first_output, dict):
optimized_prompt = first_output.get("text") or first_output.get("prompt") or first_output.get("output")
if not optimized_prompt:
logger.error(f"[WaveSpeed] Could not extract optimized prompt from outputs: {outputs}")
raise HTTPException(
status_code=502,
detail="WaveSpeed prompt optimizer output format not recognized",
)
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
return optimized_prompt
# Async mode - return prediction ID for polling
prediction_id = data.get("id")
if not prediction_id:
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed response missing prediction id for async mode",
)
# Poll for result
result = self.poll_until_complete(prediction_id, timeout_seconds=60, interval_seconds=0.5)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed prompt optimizer returned no outputs")
# Extract optimized prompt from outputs
# In async mode, outputs[0] is typically a URL that needs to be fetched
optimized_prompt = None
if isinstance(outputs, list) and len(outputs) > 0:
first_output = outputs[0]
# In async mode, it's usually a URL to fetch
if isinstance(first_output, str):
if first_output.startswith("http://") or first_output.startswith("https://"):
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
url_response = requests.get(first_output, timeout=timeout)
if url_response.status_code == 200:
optimized_prompt = url_response.text.strip()
else:
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to fetch optimized prompt from WaveSpeed URL",
)
else:
# If it's already text (shouldn't happen in async mode, but handle it)
optimized_prompt = first_output
elif isinstance(first_output, dict):
optimized_prompt = first_output.get("text") or first_output.get("prompt") or first_output.get("output")
if not optimized_prompt:
raise HTTPException(
status_code=502,
detail="WaveSpeed prompt optimizer output format not recognized",
)
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
return optimized_prompt
def generate_speech(
self,
text: str,
voice_id: str,
speed: float = 1.0,
volume: float = 1.0,
pitch: float = 0.0,
emotion: str = "happy",
enable_sync_mode: bool = True,
timeout: int = 60,
**kwargs
) -> bytes:
"""
Generate speech audio using Minimax Speech 02 HD via WaveSpeed.
Args:
text: Text to convert to speech (max 10000 characters)
voice_id: Voice ID (e.g., "Wise_Woman", "Friendly_Person", etc.)
speed: Speech speed (0.5-2.0, default: 1.0)
volume: Speech volume (0.1-10.0, default: 1.0)
pitch: Speech pitch (-12 to 12, default: 0.0)
emotion: Emotion ("happy", "sad", "angry", etc., default: "happy")
enable_sync_mode: If True, wait for result and return it directly (default: True)
timeout: Request timeout in seconds (default: 60)
**kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
Returns:
bytes: Generated audio bytes
"""
model_path = "minimax/speech-02-hd"
url = f"{self.BASE_URL}/{model_path}"
payload = {
"text": text,
"voice_id": voice_id,
"speed": speed,
"volume": volume,
"pitch": pitch,
"emotion": emotion,
"enable_sync_mode": enable_sync_mode,
}
# Add optional parameters
optional_params = [
"english_normalization",
"sample_rate",
"bitrate",
"channel",
"format",
"language_boost",
]
for param in optional_params:
if param in kwargs:
payload[param] = kwargs[param]
logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
if response.status_code != 200:
logger.error(f"[WaveSpeed] Speech generation failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed speech generation failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
# Handle sync mode - result should be directly in outputs
if enable_sync_mode:
outputs = data.get("outputs") or []
if not outputs:
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed speech generator returned no outputs",
)
# Extract audio URL from outputs
audio_url = None
if isinstance(outputs, list) and len(outputs) > 0:
first_output = outputs[0]
if isinstance(first_output, str):
audio_url = first_output
elif isinstance(first_output, dict):
audio_url = first_output.get("url") or first_output.get("output")
if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")):
logger.error(f"[WaveSpeed] Invalid audio URL in outputs: {outputs}")
raise HTTPException(
status_code=502,
detail="WaveSpeed speech generator output format not recognized",
)
# Fetch audio bytes from URL
logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
audio_response = requests.get(audio_url, timeout=timeout)
if audio_response.status_code == 200:
audio_bytes = audio_response.content
logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
return audio_bytes
else:
logger.error(f"[WaveSpeed] Failed to fetch audio from URL: {audio_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to fetch generated audio from WaveSpeed URL",
)
# Async mode - return prediction ID for polling
prediction_id = data.get("id")
if not prediction_id:
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
raise HTTPException(
status_code=502,
detail="WaveSpeed response missing prediction id for async mode",
)
# Poll for result
result = self.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=0.5)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed speech generator returned no outputs")
# Extract audio URL and fetch
audio_url = None
if isinstance(outputs, list) and len(outputs) > 0:
first_output = outputs[0]
if isinstance(first_output, str):
audio_url = first_output
elif isinstance(first_output, dict):
audio_url = first_output.get("url") or first_output.get("output")
if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")):
raise HTTPException(
status_code=502,
detail="WaveSpeed speech generator output format not recognized",
)
# Fetch audio bytes
logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
audio_response = requests.get(audio_url, timeout=timeout)
if audio_response.status_code == 200:
audio_bytes = audio_response.content
logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
return audio_bytes
else:
logger.error(f"[WaveSpeed] Failed to fetch audio from URL: {audio_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to fetch generated audio from WaveSpeed URL",
)

View File

@@ -0,0 +1,122 @@
from __future__ import annotations
import base64
from typing import Any, Dict, Optional
import requests
from fastapi import HTTPException
from loguru import logger
from .client import WaveSpeedClient
from .kling_animation import generate_animation_prompt
INFINITALK_MODEL_PATH = "wavespeed-ai/infinitetalk"
INFINITALK_MODEL_NAME = "wavespeed-ai/infinitetalk"
INFINITALK_DEFAULT_COST = 0.30 # $0.30 per 5 seconds at 720p tier
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB
MAX_AUDIO_BYTES = 50 * 1024 * 1024 # 50MB safety cap
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
encoded = base64.b64encode(content_bytes).decode("utf-8")
return f"data:{mime_type};base64,{encoded}"
def animate_scene_with_voiceover(
*,
image_bytes: bytes,
audio_bytes: bytes,
scene_data: Dict[str, Any],
story_context: Dict[str, Any],
user_id: str,
resolution: str = "720p",
prompt_override: Optional[str] = None,
image_mime: str = "image/png",
audio_mime: str = "audio/mpeg",
client: Optional[WaveSpeedClient] = None,
) -> Dict[str, Any]:
"""
Animate a scene image with narration audio using WaveSpeed InfiniteTalk.
Returns dict with video bytes, prompt used, model name, and cost.
"""
if not image_bytes:
raise HTTPException(status_code=404, detail="Scene image bytes missing for animation.")
if not audio_bytes:
raise HTTPException(status_code=404, detail="Scene audio bytes missing for animation.")
if len(image_bytes) > MAX_IMAGE_BYTES:
raise HTTPException(
status_code=400,
detail="Scene image exceeds 10MB limit required by WaveSpeed InfiniteTalk.",
)
if len(audio_bytes) > MAX_AUDIO_BYTES:
raise HTTPException(
status_code=400,
detail="Scene audio exceeds 50MB limit allowed for InfiniteTalk requests.",
)
if resolution not in {"480p", "720p"}:
raise HTTPException(status_code=400, detail="Resolution must be '480p' or '720p'.")
animation_prompt = prompt_override or generate_animation_prompt(scene_data, story_context, user_id)
payload = {
"image": _as_data_uri(image_bytes, image_mime),
"audio": _as_data_uri(audio_bytes, audio_mime),
"resolution": resolution,
}
if animation_prompt:
payload["prompt"] = animation_prompt
client = client or WaveSpeedClient()
prediction_id = client.submit_image_to_video(INFINITALK_MODEL_PATH, payload, timeout=60)
try:
result = client.poll_until_complete(prediction_id, timeout_seconds=600, interval_seconds=1.0)
except HTTPException as exc:
detail = exc.detail or {}
if isinstance(detail, dict):
detail.setdefault("prediction_id", prediction_id)
detail.setdefault("resume_available", True)
raise
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed InfiniteTalk completed but returned no outputs.")
video_url = outputs[0]
video_response = requests.get(video_url, timeout=180)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={
"error": "Failed to download InfiniteTalk video",
"status_code": video_response.status_code,
"response": video_response.text[:200],
},
)
metadata = result.get("metadata") or {}
duration = metadata.get("duration_seconds") or metadata.get("duration") or 0
logger.info(
"[InfiniteTalk] Generated talking avatar video user=%s scene=%s resolution=%s size=%s bytes",
user_id,
scene_data.get("scene_number"),
resolution,
len(video_response.content),
)
return {
"video_bytes": video_response.content,
"prompt": animation_prompt,
"duration": duration or 5,
"model_name": INFINITALK_MODEL_NAME,
"cost": INFINITALK_DEFAULT_COST,
"provider": "wavespeed",
"source_video_url": video_url,
"prediction_id": prediction_id,
}

View File

@@ -0,0 +1,360 @@
from __future__ import annotations
import base64
import json
from typing import Any, Dict, Optional
import requests
from fastapi import HTTPException
from services.llm_providers.main_text_generation import llm_text_gen
from utils.logger_utils import get_service_logger
from .client import WaveSpeedClient
try:
import imghdr
except ModuleNotFoundError: # Python 3.13 removed imghdr
imghdr = None
logger = get_service_logger("wavespeed.kling_animation")
KLING_MODEL_PATH = "kwaivgi/kling-v2.5-turbo-std/image-to-video"
KLING_MODEL_5S = "kling-v2.5-turbo-std-5s"
KLING_MODEL_10S = "kling-v2.5-turbo-std-10s"
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10 MB limit per docs
def _detect_image_mime(image_bytes: bytes) -> str:
if imghdr:
detected = imghdr.what(None, h=image_bytes)
if detected == "jpeg":
return "image/jpeg"
if detected == "png":
return "image/png"
if detected == "gif":
return "image/gif"
header = image_bytes[:8]
if header.startswith(b"\x89PNG"):
return "image/png"
if header[:2] == b"\xff\xd8":
return "image/jpeg"
if header[:3] in (b"GIF", b"GIF"):
return "image/gif"
return "image/png"
def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str, Any]) -> str:
title = (scene_data.get("title") or "Scene").strip()
description = (scene_data.get("description") or "").strip()
image_prompt = (scene_data.get("image_prompt") or "").strip()
tone = (story_context.get("story_tone") or "story").strip()
setting = (story_context.get("story_setting") or "the scene").strip()
parts = [
f"{title} cinematic motion shot.",
description[:220] if description else "",
f"Camera glides with subtle parallax over {setting}.",
f"Maintain a {tone} mood with natural lighting accents.",
f"Honor the original illustration details: {image_prompt[:200]}." if image_prompt else "",
"5-second sequence, gentle push-in, flowing cloth and atmospheric particles.",
]
fallback_prompt = " ".join(filter(None, parts))
return fallback_prompt.strip()
def _load_llm_json_response(response_text: Any) -> Dict[str, Any]:
"""Normalize responses from llm_text_gen (dict or JSON string)."""
if isinstance(response_text, dict):
return response_text
if isinstance(response_text, str):
return json.loads(response_text)
raise ValueError(f"Unexpected response type: {type(response_text)}")
def _generate_text_prompt(
*,
prompt: str,
system_prompt: str,
user_id: str,
fallback_prompt: str,
) -> str:
"""Fallback text generation when structured JSON parsing fails."""
try:
response = llm_text_gen(
prompt=prompt.strip(),
system_prompt=system_prompt,
user_id=user_id,
)
except HTTPException as exc:
if exc.status_code == 429:
raise
logger.warning(
"[AnimateScene] Text-mode prompt generation failed (%s). Using deterministic fallback.",
exc.detail,
)
return fallback_prompt
except Exception as exc:
logger.error(
"[AnimateScene] Unexpected error generating text prompt: %s",
exc,
exc_info=True,
)
return fallback_prompt
if isinstance(response, dict):
candidates = [
response.get("animation_prompt"),
response.get("prompt"),
response.get("text"),
]
for candidate in candidates:
if isinstance(candidate, str) and candidate.strip():
return candidate.strip()
# As a last resort, stringify the dict
response_text = json.dumps(response, ensure_ascii=False)
else:
response_text = str(response)
cleaned = response_text.strip()
return cleaned or fallback_prompt
def generate_animation_prompt(
scene_data: Dict[str, Any],
story_context: Dict[str, Any],
user_id: str,
) -> str:
"""
Generate an animation-focused prompt using llm_text_gen, falling back to a deterministic prompt if LLM fails.
"""
fallback_prompt = _build_fallback_prompt(scene_data, story_context)
system_prompt = (
"You are an expert cinematic animation director. "
"You transform static illustrated scenes into short cinematic motion clips. "
"Describe motion, camera behavior, atmosphere, and pacing."
)
description = scene_data.get("description", "")
image_prompt = scene_data.get("image_prompt", "")
title = scene_data.get("title", "")
tone = story_context.get("story_tone") or story_context.get("story_tone", "")
setting = story_context.get("story_setting") or story_context.get("story_setting", "")
prompt = f"""
Create a concise animation prompt (2-3 sentences) for a 5-second cinematic clip.
Scene Title: {title}
Description: {description}
Existing Image Prompt: {image_prompt}
Story Tone: {tone}
Setting: {setting}
Focus on:
- Motion of characters/objects
- Camera movement (pan, zoom, dolly, orbit)
- Atmosphere, lighting, and emotion
- Timing cues appropriate for a {tone or "story"} scene
Respond with JSON: {{"animation_prompt": "<prompt>"}}
"""
try:
response = llm_text_gen(
prompt=prompt.strip(),
system_prompt=system_prompt,
user_id=user_id,
json_struct={
"type": "object",
"properties": {
"animation_prompt": {
"type": "string",
"description": "A cinematic motion prompt for the WaveSpeed image-to-video model.",
}
},
"required": ["animation_prompt"],
},
)
structured = _load_llm_json_response(response)
animation_prompt = structured.get("animation_prompt")
if not animation_prompt or not isinstance(animation_prompt, str):
raise ValueError("Missing animation_prompt in structured response")
cleaned_prompt = animation_prompt.strip()
if not cleaned_prompt:
raise ValueError("animation_prompt is empty after trimming")
return cleaned_prompt
except HTTPException as exc:
if exc.status_code == 429:
raise
logger.warning(
"[AnimateScene] Structured LLM prompt generation failed (%s). Falling back to text parsing.",
exc.detail,
)
return _generate_text_prompt(
prompt=prompt,
system_prompt=system_prompt,
user_id=user_id,
fallback_prompt=fallback_prompt,
)
except (json.JSONDecodeError, ValueError, KeyError) as exc:
logger.warning(
"[AnimateScene] Failed to parse structured animation prompt (%s). Falling back to text parsing.",
exc,
)
return _generate_text_prompt(
prompt=prompt,
system_prompt=system_prompt,
user_id=user_id,
fallback_prompt=fallback_prompt,
)
except Exception as exc:
logger.error(
"[AnimateScene] Unexpected error generating animation prompt: %s",
exc,
exc_info=True,
)
return fallback_prompt
def animate_scene_image(
*,
image_bytes: bytes,
scene_data: Dict[str, Any],
story_context: Dict[str, Any],
user_id: str,
duration: int = 5,
guidance_scale: float = 0.5,
negative_prompt: Optional[str] = None,
client: Optional[WaveSpeedClient] = None,
) -> Dict[str, Any]:
"""
Animate a scene image using WaveSpeed Kling v2.5 Turbo Std.
Returns dict with video bytes, prompt used, model name, duration, and cost.
"""
if duration not in (5, 10):
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds for scene animation.")
if len(image_bytes) > MAX_IMAGE_BYTES:
raise HTTPException(
status_code=400,
detail="Scene image exceeds 10MB limit required by WaveSpeed."
)
guidance_scale = max(0.0, min(1.0, guidance_scale))
animation_prompt = generate_animation_prompt(scene_data, story_context, user_id)
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
payload = {
"duration": duration,
"guidance_scale": guidance_scale,
"image": image_b64,
"prompt": animation_prompt,
}
if negative_prompt:
payload["negative_prompt"] = negative_prompt.strip()
client = client or WaveSpeedClient()
prediction_id = client.submit_image_to_video(KLING_MODEL_PATH, payload)
try:
result = client.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
except HTTPException as exc:
detail = exc.detail or {}
if isinstance(detail, dict):
detail.setdefault("prediction_id", prediction_id)
detail.setdefault("resume_available", True)
detail.setdefault("message", "WaveSpeed request is still processing. Use resume endpoint to fetch the video once ready.")
raise HTTPException(status_code=exc.status_code, detail=detail)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed completed but returned no outputs.")
video_url = outputs[0]
video_response = requests.get(video_url, timeout=60)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={
"error": "Failed to download animation video",
"status_code": video_response.status_code,
"response": video_response.text[:200],
},
)
model_name = KLING_MODEL_5S if duration == 5 else KLING_MODEL_10S
cost = 0.21 if duration == 5 else 0.42
return {
"video_bytes": video_response.content,
"prompt": animation_prompt,
"duration": duration,
"model_name": model_name,
"cost": cost,
"provider": "wavespeed",
"source_video_url": video_url,
"prediction_id": prediction_id,
}
def resume_scene_animation(
*,
prediction_id: str,
duration: int,
user_id: str,
client: Optional[WaveSpeedClient] = None,
) -> Dict[str, Any]:
"""
Resume a previously submitted animation by fetching the completed result.
"""
if duration not in (5, 10):
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds for scene animation.")
client = client or WaveSpeedClient()
result = client.get_prediction_result(prediction_id, timeout=120)
status = result.get("status")
if status != "completed":
raise HTTPException(
status_code=409,
detail={
"error": "WaveSpeed prediction is not completed yet",
"prediction_id": prediction_id,
"status": status,
},
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed completed but returned no outputs.")
video_url = outputs[0]
video_response = requests.get(video_url, timeout=120)
if video_response.status_code != 200:
raise HTTPException(
status_code=502,
detail={
"error": "Failed to download animation video during resume",
"status_code": video_response.status_code,
"response": video_response.text[:200],
"prediction_id": prediction_id,
},
)
animation_prompt = result.get("prompt") or ""
model_name = KLING_MODEL_5S if duration == 5 else KLING_MODEL_10S
cost = 0.21 if duration == 5 else 0.42
logger.info("[AnimateScene] Resumed download for prediction=%s", prediction_id)
return {
"video_bytes": video_response.content,
"prompt": animation_prompt,
"duration": duration,
"model_name": model_name,
"cost": cost,
"provider": "wavespeed",
"source_video_url": video_url,
"prediction_id": prediction_id,
}