Added video studio router and endpoints. Added research router and endpoints. Added youtube router and endpoints. Added onboarding utils router and endpoints. Added onboarding utils service. Added onboarding utils models. Added onboarding utils routes. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils. Added onboarding utils utils.
This commit is contained in:
File diff suppressed because it is too large
Load Diff
1
backend/services/wavespeed/generators/__init__.py
Normal file
1
backend/services/wavespeed/generators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""WaveSpeed API generators for different content types."""
|
||||
374
backend/services/wavespeed/generators/image.py
Normal file
374
backend/services/wavespeed/generators/image.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Image generation generator for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import time
|
||||
import requests
|
||||
from typing import Optional
|
||||
from requests import exceptions as requests_exceptions
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.image")
|
||||
|
||||
|
||||
class ImageGenerator:
|
||||
"""Image generation generator."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize image generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 120,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image using WaveSpeed AI models (Ideogram V3 or Qwen Image).
|
||||
|
||||
Args:
|
||||
model: Model to use ("ideogram-v3-turbo" or "qwen-image")
|
||||
prompt: Text prompt for image generation
|
||||
width: Image width (default: 1024)
|
||||
height: Image height (default: 1024)
|
||||
num_inference_steps: Number of inference steps
|
||||
guidance_scale: Guidance scale for generation
|
||||
negative_prompt: Negative prompt (what to avoid)
|
||||
seed: Random seed for reproducibility
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 120)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes
|
||||
"""
|
||||
# Map model names to WaveSpeed API paths
|
||||
model_paths = {
|
||||
"ideogram-v3-turbo": "ideogram-ai/ideogram-v3-turbo",
|
||||
"qwen-image": "wavespeed-ai/qwen-image/text-to-image",
|
||||
}
|
||||
|
||||
model_path = model_paths.get(model)
|
||||
if not model_path:
|
||||
raise ValueError(f"Unsupported image model: {model}. Supported: {list(model_paths.keys())}")
|
||||
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if num_inference_steps is not None:
|
||||
payload["num_inference_steps"] = num_inference_steps
|
||||
if guidance_scale is not None:
|
||||
payload["guidance_scale"] = guidance_scale
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Add any extra parameters
|
||||
for key, value in kwargs.items():
|
||||
if key not in payload:
|
||||
payload[key] = value
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating image via {url} (model={model}, prompt_length={len(prompt)})")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Image generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status - if "created" or "processing", we need to poll even in sync mode
|
||||
status = data.get("status", "").lower()
|
||||
outputs = data.get("outputs") or []
|
||||
prediction_id = data.get("id")
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
# If we have outputs and status is "completed", use them directly
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed] Got immediate results from sync mode (status: {status})")
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout)
|
||||
|
||||
# Sync mode returned "created" or "processing" status - need to poll
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] Sync mode returned status '{status}' but no prediction ID: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed sync mode returned async response without prediction ID",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Sync mode returned status '{status}' with no outputs. "
|
||||
f"Falling back to polling (prediction_id: {prediction_id})"
|
||||
)
|
||||
|
||||
# Async mode OR sync mode that returned "created"/"processing" - poll for result
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id",
|
||||
)
|
||||
|
||||
# Poll for result (use longer timeout for image generation)
|
||||
logger.info(f"[WaveSpeed] Polling for image generation result (prediction_id: {prediction_id}, status: {status})")
|
||||
result = self.polling.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed image generator returned no outputs")
|
||||
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=60)
|
||||
|
||||
def generate_character_image(
|
||||
self,
|
||||
prompt: str,
|
||||
reference_image_bytes: bytes,
|
||||
style: str = "Auto",
|
||||
aspect_ratio: str = "16:9",
|
||||
rendering_speed: str = "Default",
|
||||
timeout: Optional[int] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image using Ideogram Character API to maintain character consistency.
|
||||
Creates variations of a reference character image while respecting the base appearance.
|
||||
|
||||
Note: This API is always async and requires polling for results.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt describing the scene/context for the character
|
||||
reference_image_bytes: Reference image bytes (base avatar)
|
||||
style: Character style type ("Auto", "Fiction", or "Realistic")
|
||||
aspect_ratio: Aspect ratio ("1:1", "16:9", "9:16", "4:3", "3:4")
|
||||
rendering_speed: Rendering speed ("Default", "Turbo", "Quality")
|
||||
timeout: Total timeout in seconds for submission + polling (default: 180)
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes with consistent character
|
||||
"""
|
||||
import base64
|
||||
|
||||
# Encode reference image to base64
|
||||
image_base64 = base64.b64encode(reference_image_bytes).decode('utf-8')
|
||||
# Add data URI prefix
|
||||
image_data_uri = f"data:image/png;base64,{image_base64}"
|
||||
|
||||
url = f"{self.base_url}/ideogram-ai/ideogram-character"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"image": image_data_uri,
|
||||
"style": style,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"rendering_speed": rendering_speed,
|
||||
}
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating character image via Ideogram Character (prompt_length={len(prompt)})")
|
||||
|
||||
# Retry on transient connection failures
|
||||
max_retries = 2
|
||||
retry_delay = 2.0
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
timeout=(30, 30)
|
||||
)
|
||||
break
|
||||
except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e:
|
||||
if attempt < max_retries:
|
||||
logger.warning(f"[WaveSpeed] Connection attempt {attempt + 1}/{max_retries + 1} failed, retrying in {retry_delay}s: {e}")
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
else:
|
||||
error_type = "Connection timeout" if isinstance(e, requests_exceptions.ConnectTimeout) else "Connection error"
|
||||
logger.error(f"[WaveSpeed] {error_type} to Ideogram Character API after {max_retries + 1} attempts: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504 if isinstance(e, requests_exceptions.ConnectTimeout) else 502,
|
||||
detail={
|
||||
"error": f"{error_type} to WaveSpeed Ideogram Character API",
|
||||
"message": "Unable to establish connection to the image generation service after multiple attempts. Please check your network connection and try again.",
|
||||
"exception": str(e),
|
||||
"retry_recommended": True,
|
||||
},
|
||||
)
|
||||
except requests_exceptions.Timeout as e:
|
||||
logger.error(f"[WaveSpeed] Request timeout to Ideogram Character API: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "Request timeout to WaveSpeed Ideogram Character API",
|
||||
"message": "The image generation request took too long. Please try again.",
|
||||
"exception": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Character image generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Ideogram Character generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Extract prediction ID
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character response missing prediction id",
|
||||
)
|
||||
|
||||
# Ideogram Character API is always async - check status and poll if needed
|
||||
outputs = data.get("outputs") or []
|
||||
status = data.get("status", "unknown")
|
||||
|
||||
logger.info(f"[WaveSpeed] Ideogram Character task created: prediction_id={prediction_id}, status={status}")
|
||||
|
||||
# If status is already completed, use outputs directly (unlikely but possible)
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed] Got immediate results from Ideogram Character")
|
||||
else:
|
||||
# Always need to poll for results (API is async)
|
||||
logger.info(f"[WaveSpeed] Polling for Ideogram Character result (status: {status}, prediction_id: {prediction_id})")
|
||||
polling_timeout = timeout if timeout else None
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=polling_timeout,
|
||||
interval_seconds=0.5,
|
||||
)
|
||||
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"[WaveSpeed] Unexpected result type: {type(result)}, value: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character returned unexpected response format",
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
status = result.get("status", "unknown")
|
||||
|
||||
if status != "completed":
|
||||
error_msg = "Unknown error"
|
||||
if isinstance(result, dict):
|
||||
error_msg = result.get("error") or result.get("message") or str(result.get("details", "Unknown error"))
|
||||
else:
|
||||
error_msg = str(result)
|
||||
|
||||
logger.error(f"[WaveSpeed] Ideogram Character task did not complete: status={status}, error={error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Ideogram Character task failed",
|
||||
"status": status,
|
||||
"message": error_msg,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract image URL from outputs
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs after polling: status={status}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character returned no outputs",
|
||||
)
|
||||
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=60)
|
||||
|
||||
def _extract_image_url(self, outputs: list) -> str:
|
||||
"""Extract image URL from outputs."""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
return image_url
|
||||
|
||||
def _download_image(self, image_url: str, timeout: int = 60) -> bytes:
|
||||
"""Download image from URL."""
|
||||
logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
if image_response.status_code == 200:
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
|
||||
return image_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated image from WaveSpeed URL",
|
||||
)
|
||||
164
backend/services/wavespeed/generators/prompt.py
Normal file
164
backend/services/wavespeed/generators/prompt.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Prompt optimization generator for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.prompt")
|
||||
|
||||
|
||||
class PromptGenerator:
|
||||
"""Prompt optimization generator."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize prompt generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def optimize_prompt(
|
||||
self,
|
||||
text: str,
|
||||
mode: str = "image",
|
||||
style: str = "default",
|
||||
image: Optional[str] = None,
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Optimize a prompt using WaveSpeed prompt optimizer.
|
||||
|
||||
Args:
|
||||
text: The prompt text to optimize
|
||||
mode: "image" or "video" (default: "image")
|
||||
style: "default", "artistic", "photographic", "technical", "anime", "realistic" (default: "default")
|
||||
image: Base64-encoded image for context (optional)
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 30)
|
||||
|
||||
Returns:
|
||||
Optimized prompt text
|
||||
"""
|
||||
model_path = "wavespeed-ai/prompt-optimizer"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"mode": mode,
|
||||
"style": style,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
if image:
|
||||
payload["image"] = image
|
||||
|
||||
logger.info(f"[WaveSpeed] Optimizing prompt via {url} (mode={mode}, style={style})")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Prompt optimization failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed prompt optimization failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer returned no outputs",
|
||||
)
|
||||
|
||||
# Extract optimized prompt from outputs
|
||||
optimized_prompt = self._extract_prompt_from_outputs(outputs, timeout)
|
||||
if not optimized_prompt:
|
||||
logger.error(f"[WaveSpeed] Could not extract optimized prompt from outputs: {outputs}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer output format not recognized",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
# Async mode - return prediction ID for polling
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(prediction_id, timeout_seconds=60, interval_seconds=0.5)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed prompt optimizer returned no outputs")
|
||||
|
||||
# Extract optimized prompt from outputs
|
||||
optimized_prompt = self._extract_prompt_from_outputs(outputs, timeout)
|
||||
if not optimized_prompt:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer output format not recognized",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
def _extract_prompt_from_outputs(self, outputs: list, timeout: int) -> Optional[str]:
|
||||
"""Extract optimized prompt from outputs, handling URLs and direct text."""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
return None
|
||||
|
||||
first_output = outputs[0]
|
||||
|
||||
# If it's a string that looks like a URL, fetch it
|
||||
if isinstance(first_output, str):
|
||||
if first_output.startswith("http://") or first_output.startswith("https://"):
|
||||
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
|
||||
url_response = requests.get(first_output, timeout=timeout)
|
||||
if url_response.status_code == 200:
|
||||
return url_response.text.strip()
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch optimized prompt from WaveSpeed URL",
|
||||
)
|
||||
else:
|
||||
# It's already the text
|
||||
return first_output
|
||||
elif isinstance(first_output, dict):
|
||||
return first_output.get("text") or first_output.get("prompt") or first_output.get("output")
|
||||
|
||||
return None
|
||||
223
backend/services/wavespeed/generators/speech.py
Normal file
223
backend/services/wavespeed/generators/speech.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Speech generation generator for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import time
|
||||
import requests
|
||||
from typing import Optional
|
||||
from requests import exceptions as requests_exceptions
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.speech")
|
||||
|
||||
|
||||
class SpeechGenerator:
|
||||
"""Speech generation generator."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize speech generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def generate_speech(
|
||||
self,
|
||||
text: str,
|
||||
voice_id: str,
|
||||
speed: float = 1.0,
|
||||
volume: float = 1.0,
|
||||
pitch: float = 0.0,
|
||||
emotion: str = "happy",
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 120,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate speech audio using Minimax Speech 02 HD via WaveSpeed.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech (max 10000 characters)
|
||||
voice_id: Voice ID (e.g., "Wise_Woman", "Friendly_Person", etc.)
|
||||
speed: Speech speed (0.5-2.0, default: 1.0)
|
||||
volume: Speech volume (0.1-10.0, default: 1.0)
|
||||
pitch: Speech pitch (-12 to 12, default: 0.0)
|
||||
emotion: Emotion ("happy", "sad", "angry", etc., default: "happy")
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 60)
|
||||
**kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
|
||||
|
||||
Returns:
|
||||
bytes: Generated audio bytes
|
||||
"""
|
||||
model_path = "minimax/speech-02-hd"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"voice_id": voice_id,
|
||||
"speed": speed,
|
||||
"volume": volume,
|
||||
"pitch": pitch,
|
||||
"emotion": emotion,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
optional_params = [
|
||||
"english_normalization",
|
||||
"sample_rate",
|
||||
"bitrate",
|
||||
"channel",
|
||||
"format",
|
||||
"language_boost",
|
||||
]
|
||||
for param in optional_params:
|
||||
if param in kwargs:
|
||||
payload[param] = kwargs[param]
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
|
||||
|
||||
# Retry on transient connection issues
|
||||
max_retries = 2
|
||||
retry_delay = 2.0
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
timeout=(30, 60), # connect, read
|
||||
)
|
||||
break
|
||||
except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e:
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Speech connection attempt {attempt + 1}/{max_retries + 1} failed, "
|
||||
f"retrying in {retry_delay}s: {e}"
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
logger.error(f"[WaveSpeed] Speech connection failed after {max_retries + 1} attempts: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "Connection to WaveSpeed speech API timed out",
|
||||
"message": "Unable to reach the speech service. Please try again.",
|
||||
"exception": str(e),
|
||||
"retry_recommended": True,
|
||||
},
|
||||
)
|
||||
except requests_exceptions.Timeout as e:
|
||||
logger.error(f"[WaveSpeed] Speech request timeout: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed speech request timed out",
|
||||
"message": "The speech generation request took too long. Please try again.",
|
||||
"exception": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Speech generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed speech generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator returned no outputs",
|
||||
)
|
||||
|
||||
audio_url = self._extract_audio_url(outputs)
|
||||
return self._download_audio(audio_url, timeout)
|
||||
|
||||
# Async mode - return prediction ID for polling
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=0.5)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed speech generator returned no outputs")
|
||||
|
||||
audio_url = self._extract_audio_url(outputs)
|
||||
return self._download_audio(audio_url, timeout)
|
||||
|
||||
def _extract_audio_url(self, outputs: list) -> str:
|
||||
"""Extract audio URL from outputs."""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
audio_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
audio_url = first_output.get("url") or first_output.get("output")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
return audio_url
|
||||
|
||||
def _download_audio(self, audio_url: str, timeout: int) -> bytes:
|
||||
"""Download audio from URL."""
|
||||
logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
|
||||
audio_response = requests.get(audio_url, timeout=timeout)
|
||||
if audio_response.status_code == 200:
|
||||
audio_bytes = audio_response.content
|
||||
logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
|
||||
return audio_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch audio from URL: {audio_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated audio from WaveSpeed URL",
|
||||
)
|
||||
1330
backend/services/wavespeed/generators/video.py
Normal file
1330
backend/services/wavespeed/generators/video.py
Normal file
File diff suppressed because it is too large
Load Diff
253
backend/services/wavespeed/hunyuan_avatar.py
Normal file
253
backend/services/wavespeed/hunyuan_avatar.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""
|
||||
Hunyuan Avatar Service
|
||||
|
||||
Service for creating talking avatars using Hunyuan Avatar model.
|
||||
Reference: https://wavespeed.ai/models/wavespeed-ai/hunyuan-avatar
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from .client import WaveSpeedClient
|
||||
|
||||
HUNYUAN_AVATAR_MODEL_PATH = "wavespeed-ai/hunyuan-avatar"
|
||||
HUNYUAN_AVATAR_MODEL_NAME = "wavespeed-ai/hunyuan-avatar"
|
||||
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
MAX_AUDIO_BYTES = 50 * 1024 * 1024 # 50MB safety cap
|
||||
MAX_DURATION_SECONDS = 120 # 2 minutes maximum
|
||||
MIN_DURATION_SECONDS = 5 # Minimum billable duration
|
||||
|
||||
|
||||
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
|
||||
"""Convert bytes to data URI."""
|
||||
encoded = base64.b64encode(content_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
|
||||
def calculate_hunyuan_avatar_cost(resolution: str, duration: float) -> float:
|
||||
"""
|
||||
Calculate cost for Hunyuan Avatar video.
|
||||
|
||||
Pricing:
|
||||
- 480p: $0.15 per 5 seconds
|
||||
- 720p: $0.30 per 5 seconds
|
||||
- Minimum charge: 5 seconds
|
||||
- Maximum billable: 120 seconds
|
||||
|
||||
Args:
|
||||
resolution: Output resolution (480p or 720p)
|
||||
duration: Video duration in seconds
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
# Clamp duration to valid range
|
||||
actual_duration = max(MIN_DURATION_SECONDS, min(duration, MAX_DURATION_SECONDS))
|
||||
|
||||
# Calculate cost per 5 seconds
|
||||
cost_per_5_seconds = 0.15 if resolution == "480p" else 0.30
|
||||
|
||||
# Round up to nearest 5 seconds
|
||||
billable_5_second_blocks = (actual_duration + 4) // 5 # Ceiling division
|
||||
|
||||
return cost_per_5_seconds * billable_5_second_blocks
|
||||
|
||||
|
||||
def create_hunyuan_avatar(
|
||||
*,
|
||||
image_bytes: bytes,
|
||||
audio_bytes: bytes,
|
||||
resolution: str = "480p",
|
||||
prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
user_id: str = "video_studio",
|
||||
image_mime: str = "image/png",
|
||||
audio_mime: str = "audio/mpeg",
|
||||
client: Optional[WaveSpeedClient] = None,
|
||||
progress_callback: Optional[callable] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Create talking avatar video using Hunyuan Avatar.
|
||||
|
||||
Reference: https://wavespeed.ai/docs/docs-api/wavespeed-ai/hunyuan-avatar
|
||||
|
||||
Args:
|
||||
image_bytes: Portrait image as bytes
|
||||
audio_bytes: Audio file as bytes
|
||||
resolution: Output resolution (480p or 720p, default: 480p)
|
||||
prompt: Optional text to guide expression or style
|
||||
seed: Optional random seed (-1 for random)
|
||||
user_id: User ID for tracking
|
||||
image_mime: MIME type of image
|
||||
audio_mime: MIME type of audio
|
||||
client: Optional WaveSpeedClient instance
|
||||
progress_callback: Optional progress callback function
|
||||
|
||||
Returns:
|
||||
Dictionary with video_bytes, prompt, duration, model_name, cost, etc.
|
||||
"""
|
||||
if not image_bytes:
|
||||
raise HTTPException(status_code=400, detail="Image bytes are required for Hunyuan Avatar.")
|
||||
if not audio_bytes:
|
||||
raise HTTPException(status_code=400, detail="Audio bytes are required for Hunyuan Avatar.")
|
||||
|
||||
if len(image_bytes) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image exceeds {MAX_IMAGE_BYTES / (1024 * 1024):.0f}MB limit required by Hunyuan Avatar.",
|
||||
)
|
||||
if len(audio_bytes) > MAX_AUDIO_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Audio exceeds {MAX_AUDIO_BYTES / (1024 * 1024):.0f}MB limit allowed for Hunyuan Avatar requests.",
|
||||
)
|
||||
|
||||
if resolution not in {"480p", "720p"}:
|
||||
raise HTTPException(status_code=400, detail="Resolution must be '480p' or '720p'.")
|
||||
|
||||
# Build payload
|
||||
payload: Dict[str, Any] = {
|
||||
"image": _as_data_uri(image_bytes, image_mime),
|
||||
"audio": _as_data_uri(audio_bytes, audio_mime),
|
||||
"resolution": resolution,
|
||||
}
|
||||
|
||||
if prompt:
|
||||
payload["prompt"] = prompt.strip()
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
client = client or WaveSpeedClient()
|
||||
|
||||
# Progress callback: submission
|
||||
if progress_callback:
|
||||
progress_callback(10.0, "Submitting Hunyuan Avatar request to WaveSpeed...")
|
||||
|
||||
prediction_id = client.submit_image_to_video(HUNYUAN_AVATAR_MODEL_PATH, payload, timeout=60)
|
||||
|
||||
try:
|
||||
# Poll for completion
|
||||
if progress_callback:
|
||||
progress_callback(20.0, f"Polling for completion (prediction_id: {prediction_id})...")
|
||||
|
||||
result = client.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=600, # 10 minutes max
|
||||
interval_seconds=0.5, # Poll every 0.5 seconds
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Hunyuan Avatar completed but returned no outputs",
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": f"Invalid video URL format: {video_url}",
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
)
|
||||
|
||||
# Progress callback: downloading video
|
||||
if progress_callback:
|
||||
progress_callback(90.0, "Downloading generated video...")
|
||||
|
||||
# Download video
|
||||
try:
|
||||
video_response = requests.get(video_url, timeout=180)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download Hunyuan Avatar video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": f"Failed to download video: {str(e)}",
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
if len(video_bytes) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Downloaded video is empty",
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
)
|
||||
|
||||
# Estimate duration (we don't get exact duration from API, so estimate from audio or use default)
|
||||
# For now, we'll use a default estimate - in production, you might want to analyze the audio file
|
||||
estimated_duration = 10.0 # Default estimate
|
||||
|
||||
# Calculate cost
|
||||
cost = calculate_hunyuan_avatar_cost(resolution, estimated_duration)
|
||||
|
||||
# Get video dimensions from resolution
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (854, 480))
|
||||
|
||||
# Extract metadata
|
||||
metadata = result.get("metadata", {})
|
||||
metadata.update({
|
||||
"has_nsfw_contents": result.get("has_nsfw_contents", []),
|
||||
"created_at": result.get("created_at"),
|
||||
"resolution": resolution,
|
||||
"max_duration": MAX_DURATION_SECONDS,
|
||||
})
|
||||
|
||||
logger.info(
|
||||
f"[Hunyuan Avatar] ✅ Generated video: {len(video_bytes)} bytes, "
|
||||
f"resolution={resolution}, cost=${cost:.2f}"
|
||||
)
|
||||
|
||||
# Progress callback: completed
|
||||
if progress_callback:
|
||||
progress_callback(100.0, "Avatar generation completed!")
|
||||
|
||||
return {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt or "",
|
||||
"duration": estimated_duration,
|
||||
"model_name": HUNYUAN_AVATAR_MODEL_NAME,
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"resolution": resolution,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"metadata": metadata,
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
}
|
||||
203
backend/services/wavespeed/polling.py
Normal file
203
backend/services/wavespeed/polling.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Polling utilities for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from requests import exceptions as requests_exceptions
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.polling")
|
||||
|
||||
|
||||
class WaveSpeedPolling:
|
||||
"""Polling utilities for WaveSpeed API predictions."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
"""Initialize polling utilities.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def get_prediction_result(self, prediction_id: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch the current status/result for a prediction.
|
||||
Matches the example pattern: simple GET request, check status_code == 200, return data.
|
||||
"""
|
||||
url = f"{self.base_url}/predictions/{prediction_id}/result"
|
||||
headers = self._get_headers()
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=timeout)
|
||||
except requests_exceptions.Timeout as exc:
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed polling request timed out",
|
||||
"prediction_id": prediction_id,
|
||||
"resume_available": True,
|
||||
"exception": str(exc),
|
||||
},
|
||||
) from exc
|
||||
except requests_exceptions.RequestException as exc:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed polling request failed",
|
||||
"prediction_id": prediction_id,
|
||||
"resume_available": True,
|
||||
"exception": str(exc),
|
||||
},
|
||||
) from exc
|
||||
|
||||
# Match example pattern: check status_code == 200, then get data
|
||||
if response.status_code == 200:
|
||||
result = response.json().get("data")
|
||||
if not result:
|
||||
raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
|
||||
return result
|
||||
else:
|
||||
# Non-200 status - log and raise error (matching example's break behavior)
|
||||
logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed prediction polling failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
def poll_until_complete(
|
||||
self,
|
||||
prediction_id: str,
|
||||
timeout_seconds: Optional[int] = None,
|
||||
interval_seconds: float = 1.0,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll WaveSpeed until the job completes or fails.
|
||||
Matches the example pattern: simple polling loop until status is "completed" or "failed".
|
||||
|
||||
Args:
|
||||
prediction_id: The prediction ID to poll for
|
||||
timeout_seconds: Optional timeout in seconds. If None, polls indefinitely until completion/failure.
|
||||
interval_seconds: Seconds to wait between polling attempts (default: 1.0, faster than 2.0)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
Dict containing the completed result
|
||||
|
||||
Raises:
|
||||
HTTPException: If the task fails, polling fails, or times out (if timeout_seconds is set)
|
||||
"""
|
||||
start_time = time.time()
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 6 # safety guard for non-transient errors
|
||||
|
||||
while True:
|
||||
try:
|
||||
result = self.get_prediction_result(prediction_id)
|
||||
consecutive_errors = 0 # Reset error counter on success
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
detail.setdefault("error", detail.get("error", "WaveSpeed polling failed"))
|
||||
|
||||
# Determine underlying status code (WaveSpeed vs proxy)
|
||||
status_code = detail.get("status_code", exc.status_code)
|
||||
|
||||
# Treat 5xx as transient: keep polling indefinitely with backoff
|
||||
if 500 <= int(status_code) < 600:
|
||||
consecutive_errors += 1
|
||||
backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1)))
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Transient polling error {consecutive_errors} for {prediction_id}: "
|
||||
f"{status_code}. Backing off {backoff:.1f}s"
|
||||
)
|
||||
time.sleep(backoff)
|
||||
continue
|
||||
|
||||
# For non-transient (typically 4xx) errors, apply safety cap
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error(
|
||||
f"[WaveSpeed] Too many polling errors ({consecutive_errors}) for {prediction_id}, "
|
||||
f"status_code={status_code}. Giving up."
|
||||
)
|
||||
raise HTTPException(status_code=exc.status_code, detail=detail) from exc
|
||||
|
||||
backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1)))
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Polling error {consecutive_errors}/{max_consecutive_errors} for {prediction_id}: "
|
||||
f"{status_code}. Backing off {backoff:.1f}s"
|
||||
)
|
||||
time.sleep(backoff)
|
||||
continue
|
||||
|
||||
# Extract status from result (matching example pattern)
|
||||
status = result.get("status")
|
||||
|
||||
if status == "completed":
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"[WaveSpeed] Prediction {prediction_id} completed in {elapsed:.1f}s")
|
||||
return result
|
||||
|
||||
if status == "failed":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed task failed",
|
||||
"prediction_id": prediction_id,
|
||||
"message": error_msg,
|
||||
"details": result,
|
||||
},
|
||||
)
|
||||
|
||||
# Check timeout only if specified
|
||||
if timeout_seconds is not None:
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed task timed out",
|
||||
"prediction_id": prediction_id,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
"current_status": status,
|
||||
"message": f"Task did not complete within {timeout_seconds} seconds. Status: {status}",
|
||||
},
|
||||
)
|
||||
|
||||
# Log progress periodically (every 30 seconds)
|
||||
elapsed = time.time() - start_time
|
||||
if int(elapsed) % 30 == 0 and elapsed > 0:
|
||||
logger.info(f"[WaveSpeed] Polling {prediction_id}: status={status}, elapsed={elapsed:.0f}s")
|
||||
|
||||
# Call progress callback if provided
|
||||
if progress_callback:
|
||||
# Map elapsed time to progress (20-80% range during polling)
|
||||
# Assume typical completion time is timeout_seconds or 120s default
|
||||
estimated_total = timeout_seconds or 120
|
||||
progress = min(80.0, 20.0 + (elapsed / estimated_total) * 60.0)
|
||||
progress_callback(progress, f"Video generation in progress... ({elapsed:.0f}s)")
|
||||
|
||||
# Poll faster (1.0s instead of 2.0s) to match example's responsiveness
|
||||
time.sleep(interval_seconds)
|
||||
Reference in New Issue
Block a user