AI story writer enhancements, text to video and voice generation, subscription management, and more.
This commit is contained in:
1
backend/services/wavespeed/__init__.py
Normal file
1
backend/services/wavespeed/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
471
backend/services/wavespeed/client.py
Normal file
471
backend/services/wavespeed/client.py
Normal file
@@ -0,0 +1,471 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from requests import exceptions as requests_exceptions
|
||||
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.client")
|
||||
|
||||
|
||||
class WaveSpeedClient:
|
||||
"""
|
||||
Thin HTTP client for the WaveSpeed AI API.
|
||||
Handles authentication, submission, and polling helpers.
|
||||
"""
|
||||
|
||||
BASE_URL = "https://api.wavespeed.ai/api/v3"
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
manager = APIKeyManager()
|
||||
self.api_key = api_key or manager.get_api_key("wavespeed")
|
||||
if not self.api_key:
|
||||
raise RuntimeError("WAVESPEED_API_KEY is not configured. Please add it to your environment.")
|
||||
|
||||
def _headers(self) -> Dict[str, str]:
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def submit_image_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Submit an image-to-video generation request.
|
||||
|
||||
Returns the prediction ID for polling.
|
||||
"""
|
||||
url = f"{self.BASE_URL}/{model_path}"
|
||||
logger.info(f"[WaveSpeed] Submitting request to {url}")
|
||||
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image-to-video submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json().get("data")
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
|
||||
return prediction_id
|
||||
|
||||
def get_prediction_result(self, prediction_id: str, timeout: int = 120) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch the current status/result for a prediction.
|
||||
"""
|
||||
url = f"{self.BASE_URL}/predictions/{prediction_id}/result"
|
||||
try:
|
||||
response = requests.get(url, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=timeout)
|
||||
except requests_exceptions.Timeout as exc:
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed polling request timed out",
|
||||
"prediction_id": prediction_id,
|
||||
"resume_available": True,
|
||||
"exception": str(exc),
|
||||
},
|
||||
) from exc
|
||||
except requests_exceptions.RequestException as exc:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed polling request failed",
|
||||
"prediction_id": prediction_id,
|
||||
"resume_available": True,
|
||||
"exception": str(exc),
|
||||
},
|
||||
) from exc
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed prediction polling failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
result = response.json().get("data")
|
||||
if not result:
|
||||
raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
|
||||
return result
|
||||
|
||||
def poll_until_complete(
|
||||
self,
|
||||
prediction_id: str,
|
||||
timeout_seconds: int = 240,
|
||||
interval_seconds: float = 1.0,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll WaveSpeed until the job completes, fails, or times out.
|
||||
"""
|
||||
start_time = time.time()
|
||||
while True:
|
||||
try:
|
||||
result = self.get_prediction_result(prediction_id)
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
detail.setdefault("error", detail.get("error", "WaveSpeed polling failed"))
|
||||
raise HTTPException(status_code=exc.status_code, detail=detail) from exc
|
||||
status = result.get("status")
|
||||
if status == "completed":
|
||||
logger.info(f"[WaveSpeed] Prediction {prediction_id} completed.")
|
||||
return result
|
||||
if status == "failed":
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {result.get('error')}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed animation failed",
|
||||
"prediction_id": prediction_id,
|
||||
"details": result.get("error"),
|
||||
},
|
||||
)
|
||||
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed animation timed out",
|
||||
"prediction_id": prediction_id,
|
||||
"details": result,
|
||||
},
|
||||
)
|
||||
|
||||
logger.debug(f"[WaveSpeed] Prediction {prediction_id} status={status}. Waiting...")
|
||||
time.sleep(interval_seconds)
|
||||
|
||||
def optimize_prompt(
|
||||
self,
|
||||
text: str,
|
||||
mode: str = "image",
|
||||
style: str = "default",
|
||||
image: Optional[str] = None,
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Optimize a prompt using WaveSpeed prompt optimizer.
|
||||
|
||||
Args:
|
||||
text: The prompt text to optimize
|
||||
mode: "image" or "video" (default: "image")
|
||||
style: "default", "artistic", "photographic", "technical", "anime", "realistic" (default: "default")
|
||||
image: Base64-encoded image for context (optional)
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 30)
|
||||
|
||||
Returns:
|
||||
Optimized prompt text
|
||||
"""
|
||||
model_path = "wavespeed-ai/prompt-optimizer"
|
||||
url = f"{self.BASE_URL}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"mode": mode,
|
||||
"style": style,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
if image:
|
||||
payload["image"] = image
|
||||
|
||||
logger.info(f"[WaveSpeed] Optimizing prompt via {url} (mode={mode}, style={style})")
|
||||
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Prompt optimization failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed prompt optimization failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer returned no outputs",
|
||||
)
|
||||
|
||||
# Extract optimized prompt from outputs
|
||||
# In sync mode, outputs[0] should be the optimized text directly (or a URL to fetch)
|
||||
optimized_prompt = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
|
||||
# If it's a string that looks like a URL, fetch it
|
||||
if isinstance(first_output, str):
|
||||
if first_output.startswith("http://") or first_output.startswith("https://"):
|
||||
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
|
||||
url_response = requests.get(first_output, timeout=timeout)
|
||||
if url_response.status_code == 200:
|
||||
optimized_prompt = url_response.text.strip()
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch optimized prompt from WaveSpeed URL",
|
||||
)
|
||||
else:
|
||||
# It's already the text
|
||||
optimized_prompt = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
optimized_prompt = first_output.get("text") or first_output.get("prompt") or first_output.get("output")
|
||||
|
||||
if not optimized_prompt:
|
||||
logger.error(f"[WaveSpeed] Could not extract optimized prompt from outputs: {outputs}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer output format not recognized",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
# Async mode - return prediction ID for polling
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.poll_until_complete(prediction_id, timeout_seconds=60, interval_seconds=0.5)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed prompt optimizer returned no outputs")
|
||||
|
||||
# Extract optimized prompt from outputs
|
||||
# In async mode, outputs[0] is typically a URL that needs to be fetched
|
||||
optimized_prompt = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
|
||||
# In async mode, it's usually a URL to fetch
|
||||
if isinstance(first_output, str):
|
||||
if first_output.startswith("http://") or first_output.startswith("https://"):
|
||||
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
|
||||
url_response = requests.get(first_output, timeout=timeout)
|
||||
if url_response.status_code == 200:
|
||||
optimized_prompt = url_response.text.strip()
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch optimized prompt from WaveSpeed URL",
|
||||
)
|
||||
else:
|
||||
# If it's already text (shouldn't happen in async mode, but handle it)
|
||||
optimized_prompt = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
optimized_prompt = first_output.get("text") or first_output.get("prompt") or first_output.get("output")
|
||||
|
||||
if not optimized_prompt:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer output format not recognized",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
def generate_speech(
|
||||
self,
|
||||
text: str,
|
||||
voice_id: str,
|
||||
speed: float = 1.0,
|
||||
volume: float = 1.0,
|
||||
pitch: float = 0.0,
|
||||
emotion: str = "happy",
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 60,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate speech audio using Minimax Speech 02 HD via WaveSpeed.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech (max 10000 characters)
|
||||
voice_id: Voice ID (e.g., "Wise_Woman", "Friendly_Person", etc.)
|
||||
speed: Speech speed (0.5-2.0, default: 1.0)
|
||||
volume: Speech volume (0.1-10.0, default: 1.0)
|
||||
pitch: Speech pitch (-12 to 12, default: 0.0)
|
||||
emotion: Emotion ("happy", "sad", "angry", etc., default: "happy")
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 60)
|
||||
**kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
|
||||
|
||||
Returns:
|
||||
bytes: Generated audio bytes
|
||||
"""
|
||||
model_path = "minimax/speech-02-hd"
|
||||
url = f"{self.BASE_URL}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"voice_id": voice_id,
|
||||
"speed": speed,
|
||||
"volume": volume,
|
||||
"pitch": pitch,
|
||||
"emotion": emotion,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
optional_params = [
|
||||
"english_normalization",
|
||||
"sample_rate",
|
||||
"bitrate",
|
||||
"channel",
|
||||
"format",
|
||||
"language_boost",
|
||||
]
|
||||
for param in optional_params:
|
||||
if param in kwargs:
|
||||
payload[param] = kwargs[param]
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
|
||||
response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Speech generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed speech generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator returned no outputs",
|
||||
)
|
||||
|
||||
# Extract audio URL from outputs
|
||||
audio_url = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
audio_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
audio_url = first_output.get("url") or first_output.get("output")
|
||||
|
||||
if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")):
|
||||
logger.error(f"[WaveSpeed] Invalid audio URL in outputs: {outputs}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
# Fetch audio bytes from URL
|
||||
logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
|
||||
audio_response = requests.get(audio_url, timeout=timeout)
|
||||
if audio_response.status_code == 200:
|
||||
audio_bytes = audio_response.content
|
||||
logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
|
||||
return audio_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch audio from URL: {audio_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated audio from WaveSpeed URL",
|
||||
)
|
||||
|
||||
# Async mode - return prediction ID for polling
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=0.5)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed speech generator returned no outputs")
|
||||
|
||||
# Extract audio URL and fetch
|
||||
audio_url = None
|
||||
if isinstance(outputs, list) and len(outputs) > 0:
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
audio_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
audio_url = first_output.get("url") or first_output.get("output")
|
||||
|
||||
if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
# Fetch audio bytes
|
||||
logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
|
||||
audio_response = requests.get(audio_url, timeout=timeout)
|
||||
if audio_response.status_code == 200:
|
||||
audio_bytes = audio_response.content
|
||||
logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
|
||||
return audio_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch audio from URL: {audio_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated audio from WaveSpeed URL",
|
||||
)
|
||||
|
||||
122
backend/services/wavespeed/infinitetalk.py
Normal file
122
backend/services/wavespeed/infinitetalk.py
Normal file
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from .client import WaveSpeedClient
|
||||
from .kling_animation import generate_animation_prompt
|
||||
|
||||
INFINITALK_MODEL_PATH = "wavespeed-ai/infinitetalk"
|
||||
INFINITALK_MODEL_NAME = "wavespeed-ai/infinitetalk"
|
||||
INFINITALK_DEFAULT_COST = 0.30 # $0.30 per 5 seconds at 720p tier
|
||||
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
MAX_AUDIO_BYTES = 50 * 1024 * 1024 # 50MB safety cap
|
||||
|
||||
|
||||
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
|
||||
encoded = base64.b64encode(content_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
|
||||
def animate_scene_with_voiceover(
|
||||
*,
|
||||
image_bytes: bytes,
|
||||
audio_bytes: bytes,
|
||||
scene_data: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
user_id: str,
|
||||
resolution: str = "720p",
|
||||
prompt_override: Optional[str] = None,
|
||||
image_mime: str = "image/png",
|
||||
audio_mime: str = "audio/mpeg",
|
||||
client: Optional[WaveSpeedClient] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Animate a scene image with narration audio using WaveSpeed InfiniteTalk.
|
||||
Returns dict with video bytes, prompt used, model name, and cost.
|
||||
"""
|
||||
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene image bytes missing for animation.")
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=404, detail="Scene audio bytes missing for animation.")
|
||||
|
||||
if len(image_bytes) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Scene image exceeds 10MB limit required by WaveSpeed InfiniteTalk.",
|
||||
)
|
||||
if len(audio_bytes) > MAX_AUDIO_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Scene audio exceeds 50MB limit allowed for InfiniteTalk requests.",
|
||||
)
|
||||
|
||||
if resolution not in {"480p", "720p"}:
|
||||
raise HTTPException(status_code=400, detail="Resolution must be '480p' or '720p'.")
|
||||
|
||||
animation_prompt = prompt_override or generate_animation_prompt(scene_data, story_context, user_id)
|
||||
|
||||
payload = {
|
||||
"image": _as_data_uri(image_bytes, image_mime),
|
||||
"audio": _as_data_uri(audio_bytes, audio_mime),
|
||||
"resolution": resolution,
|
||||
}
|
||||
if animation_prompt:
|
||||
payload["prompt"] = animation_prompt
|
||||
|
||||
client = client or WaveSpeedClient()
|
||||
prediction_id = client.submit_image_to_video(INFINITALK_MODEL_PATH, payload, timeout=60)
|
||||
|
||||
try:
|
||||
result = client.poll_until_complete(prediction_id, timeout_seconds=600, interval_seconds=1.0)
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed InfiniteTalk completed but returned no outputs.")
|
||||
|
||||
video_url = outputs[0]
|
||||
video_response = requests.get(video_url, timeout=180)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download InfiniteTalk video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
},
|
||||
)
|
||||
|
||||
metadata = result.get("metadata") or {}
|
||||
duration = metadata.get("duration_seconds") or metadata.get("duration") or 0
|
||||
|
||||
logger.info(
|
||||
"[InfiniteTalk] Generated talking avatar video user=%s scene=%s resolution=%s size=%s bytes",
|
||||
user_id,
|
||||
scene_data.get("scene_number"),
|
||||
resolution,
|
||||
len(video_response.content),
|
||||
)
|
||||
|
||||
return {
|
||||
"video_bytes": video_response.content,
|
||||
"prompt": animation_prompt,
|
||||
"duration": duration or 5,
|
||||
"model_name": INFINITALK_MODEL_NAME,
|
||||
"cost": INFINITALK_DEFAULT_COST,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
|
||||
|
||||
360
backend/services/wavespeed/kling_animation.py
Normal file
360
backend/services/wavespeed/kling_animation.py
Normal file
@@ -0,0 +1,360 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from .client import WaveSpeedClient
|
||||
|
||||
try:
|
||||
import imghdr
|
||||
except ModuleNotFoundError: # Python 3.13 removed imghdr
|
||||
imghdr = None
|
||||
|
||||
logger = get_service_logger("wavespeed.kling_animation")
|
||||
|
||||
KLING_MODEL_PATH = "kwaivgi/kling-v2.5-turbo-std/image-to-video"
|
||||
KLING_MODEL_5S = "kling-v2.5-turbo-std-5s"
|
||||
KLING_MODEL_10S = "kling-v2.5-turbo-std-10s"
|
||||
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10 MB limit per docs
|
||||
|
||||
|
||||
def _detect_image_mime(image_bytes: bytes) -> str:
|
||||
if imghdr:
|
||||
detected = imghdr.what(None, h=image_bytes)
|
||||
if detected == "jpeg":
|
||||
return "image/jpeg"
|
||||
if detected == "png":
|
||||
return "image/png"
|
||||
if detected == "gif":
|
||||
return "image/gif"
|
||||
|
||||
header = image_bytes[:8]
|
||||
if header.startswith(b"\x89PNG"):
|
||||
return "image/png"
|
||||
if header[:2] == b"\xff\xd8":
|
||||
return "image/jpeg"
|
||||
if header[:3] in (b"GIF", b"GIF"):
|
||||
return "image/gif"
|
||||
|
||||
return "image/png"
|
||||
|
||||
|
||||
def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str, Any]) -> str:
|
||||
title = (scene_data.get("title") or "Scene").strip()
|
||||
description = (scene_data.get("description") or "").strip()
|
||||
image_prompt = (scene_data.get("image_prompt") or "").strip()
|
||||
tone = (story_context.get("story_tone") or "story").strip()
|
||||
setting = (story_context.get("story_setting") or "the scene").strip()
|
||||
|
||||
parts = [
|
||||
f"{title} cinematic motion shot.",
|
||||
description[:220] if description else "",
|
||||
f"Camera glides with subtle parallax over {setting}.",
|
||||
f"Maintain a {tone} mood with natural lighting accents.",
|
||||
f"Honor the original illustration details: {image_prompt[:200]}." if image_prompt else "",
|
||||
"5-second sequence, gentle push-in, flowing cloth and atmospheric particles.",
|
||||
]
|
||||
fallback_prompt = " ".join(filter(None, parts))
|
||||
return fallback_prompt.strip()
|
||||
|
||||
|
||||
def _load_llm_json_response(response_text: Any) -> Dict[str, Any]:
|
||||
"""Normalize responses from llm_text_gen (dict or JSON string)."""
|
||||
if isinstance(response_text, dict):
|
||||
return response_text
|
||||
if isinstance(response_text, str):
|
||||
return json.loads(response_text)
|
||||
raise ValueError(f"Unexpected response type: {type(response_text)}")
|
||||
|
||||
|
||||
def _generate_text_prompt(
|
||||
*,
|
||||
prompt: str,
|
||||
system_prompt: str,
|
||||
user_id: str,
|
||||
fallback_prompt: str,
|
||||
) -> str:
|
||||
"""Fallback text generation when structured JSON parsing fails."""
|
||||
try:
|
||||
response = llm_text_gen(
|
||||
prompt=prompt.strip(),
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code == 429:
|
||||
raise
|
||||
logger.warning(
|
||||
"[AnimateScene] Text-mode prompt generation failed (%s). Using deterministic fallback.",
|
||||
exc.detail,
|
||||
)
|
||||
return fallback_prompt
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[AnimateScene] Unexpected error generating text prompt: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return fallback_prompt
|
||||
|
||||
if isinstance(response, dict):
|
||||
candidates = [
|
||||
response.get("animation_prompt"),
|
||||
response.get("prompt"),
|
||||
response.get("text"),
|
||||
]
|
||||
for candidate in candidates:
|
||||
if isinstance(candidate, str) and candidate.strip():
|
||||
return candidate.strip()
|
||||
# As a last resort, stringify the dict
|
||||
response_text = json.dumps(response, ensure_ascii=False)
|
||||
else:
|
||||
response_text = str(response)
|
||||
|
||||
cleaned = response_text.strip()
|
||||
return cleaned or fallback_prompt
|
||||
|
||||
|
||||
def generate_animation_prompt(
|
||||
scene_data: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
user_id: str,
|
||||
) -> str:
|
||||
"""
|
||||
Generate an animation-focused prompt using llm_text_gen, falling back to a deterministic prompt if LLM fails.
|
||||
"""
|
||||
fallback_prompt = _build_fallback_prompt(scene_data, story_context)
|
||||
system_prompt = (
|
||||
"You are an expert cinematic animation director. "
|
||||
"You transform static illustrated scenes into short cinematic motion clips. "
|
||||
"Describe motion, camera behavior, atmosphere, and pacing."
|
||||
)
|
||||
|
||||
description = scene_data.get("description", "")
|
||||
image_prompt = scene_data.get("image_prompt", "")
|
||||
title = scene_data.get("title", "")
|
||||
tone = story_context.get("story_tone") or story_context.get("story_tone", "")
|
||||
setting = story_context.get("story_setting") or story_context.get("story_setting", "")
|
||||
|
||||
prompt = f"""
|
||||
Create a concise animation prompt (2-3 sentences) for a 5-second cinematic clip.
|
||||
|
||||
Scene Title: {title}
|
||||
Description: {description}
|
||||
Existing Image Prompt: {image_prompt}
|
||||
Story Tone: {tone}
|
||||
Setting: {setting}
|
||||
|
||||
Focus on:
|
||||
- Motion of characters/objects
|
||||
- Camera movement (pan, zoom, dolly, orbit)
|
||||
- Atmosphere, lighting, and emotion
|
||||
- Timing cues appropriate for a {tone or "story"} scene
|
||||
|
||||
Respond with JSON: {{"animation_prompt": "<prompt>"}}
|
||||
"""
|
||||
|
||||
try:
|
||||
response = llm_text_gen(
|
||||
prompt=prompt.strip(),
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
json_struct={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"animation_prompt": {
|
||||
"type": "string",
|
||||
"description": "A cinematic motion prompt for the WaveSpeed image-to-video model.",
|
||||
}
|
||||
},
|
||||
"required": ["animation_prompt"],
|
||||
},
|
||||
)
|
||||
structured = _load_llm_json_response(response)
|
||||
animation_prompt = structured.get("animation_prompt")
|
||||
if not animation_prompt or not isinstance(animation_prompt, str):
|
||||
raise ValueError("Missing animation_prompt in structured response")
|
||||
cleaned_prompt = animation_prompt.strip()
|
||||
if not cleaned_prompt:
|
||||
raise ValueError("animation_prompt is empty after trimming")
|
||||
return cleaned_prompt
|
||||
except HTTPException as exc:
|
||||
if exc.status_code == 429:
|
||||
raise
|
||||
logger.warning(
|
||||
"[AnimateScene] Structured LLM prompt generation failed (%s). Falling back to text parsing.",
|
||||
exc.detail,
|
||||
)
|
||||
return _generate_text_prompt(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
fallback_prompt=fallback_prompt,
|
||||
)
|
||||
except (json.JSONDecodeError, ValueError, KeyError) as exc:
|
||||
logger.warning(
|
||||
"[AnimateScene] Failed to parse structured animation prompt (%s). Falling back to text parsing.",
|
||||
exc,
|
||||
)
|
||||
return _generate_text_prompt(
|
||||
prompt=prompt,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
fallback_prompt=fallback_prompt,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"[AnimateScene] Unexpected error generating animation prompt: %s",
|
||||
exc,
|
||||
exc_info=True,
|
||||
)
|
||||
return fallback_prompt
|
||||
|
||||
|
||||
def animate_scene_image(
|
||||
*,
|
||||
image_bytes: bytes,
|
||||
scene_data: Dict[str, Any],
|
||||
story_context: Dict[str, Any],
|
||||
user_id: str,
|
||||
duration: int = 5,
|
||||
guidance_scale: float = 0.5,
|
||||
negative_prompt: Optional[str] = None,
|
||||
client: Optional[WaveSpeedClient] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Animate a scene image using WaveSpeed Kling v2.5 Turbo Std.
|
||||
Returns dict with video bytes, prompt used, model name, duration, and cost.
|
||||
"""
|
||||
if duration not in (5, 10):
|
||||
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds for scene animation.")
|
||||
|
||||
if len(image_bytes) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Scene image exceeds 10MB limit required by WaveSpeed."
|
||||
)
|
||||
|
||||
guidance_scale = max(0.0, min(1.0, guidance_scale))
|
||||
animation_prompt = generate_animation_prompt(scene_data, story_context, user_id)
|
||||
image_b64 = base64.b64encode(image_bytes).decode("utf-8")
|
||||
|
||||
payload = {
|
||||
"duration": duration,
|
||||
"guidance_scale": guidance_scale,
|
||||
"image": image_b64,
|
||||
"prompt": animation_prompt,
|
||||
}
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt.strip()
|
||||
|
||||
client = client or WaveSpeedClient()
|
||||
prediction_id = client.submit_image_to_video(KLING_MODEL_PATH, payload)
|
||||
try:
|
||||
result = client.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
detail.setdefault("message", "WaveSpeed request is still processing. Use resume endpoint to fetch the video once ready.")
|
||||
raise HTTPException(status_code=exc.status_code, detail=detail)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed completed but returned no outputs.")
|
||||
|
||||
video_url = outputs[0]
|
||||
video_response = requests.get(video_url, timeout=60)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download animation video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
},
|
||||
)
|
||||
|
||||
model_name = KLING_MODEL_5S if duration == 5 else KLING_MODEL_10S
|
||||
cost = 0.21 if duration == 5 else 0.42
|
||||
|
||||
return {
|
||||
"video_bytes": video_response.content,
|
||||
"prompt": animation_prompt,
|
||||
"duration": duration,
|
||||
"model_name": model_name,
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
|
||||
|
||||
def resume_scene_animation(
|
||||
*,
|
||||
prediction_id: str,
|
||||
duration: int,
|
||||
user_id: str,
|
||||
client: Optional[WaveSpeedClient] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Resume a previously submitted animation by fetching the completed result.
|
||||
"""
|
||||
if duration not in (5, 10):
|
||||
raise HTTPException(status_code=400, detail="Duration must be 5 or 10 seconds for scene animation.")
|
||||
|
||||
client = client or WaveSpeedClient()
|
||||
result = client.get_prediction_result(prediction_id, timeout=120)
|
||||
status = result.get("status")
|
||||
if status != "completed":
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail={
|
||||
"error": "WaveSpeed prediction is not completed yet",
|
||||
"prediction_id": prediction_id,
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed completed but returned no outputs.")
|
||||
|
||||
video_url = outputs[0]
|
||||
video_response = requests.get(video_url, timeout=120)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download animation video during resume",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
|
||||
animation_prompt = result.get("prompt") or ""
|
||||
model_name = KLING_MODEL_5S if duration == 5 else KLING_MODEL_10S
|
||||
cost = 0.21 if duration == 5 else 0.42
|
||||
|
||||
logger.info("[AnimateScene] Resumed download for prediction=%s", prediction_id)
|
||||
|
||||
return {
|
||||
"video_bytes": video_response.content,
|
||||
"prompt": animation_prompt,
|
||||
"duration": duration,
|
||||
"model_name": model_name,
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user