from __future__ import annotations import json import time from typing import Any, Dict, Optional import requests from fastapi import HTTPException from requests import exceptions as requests_exceptions from services.onboarding.api_key_manager import APIKeyManager from utils.logger_utils import get_service_logger logger = get_service_logger("wavespeed.client") class WaveSpeedClient: """ Thin HTTP client for the WaveSpeed AI API. Handles authentication, submission, and polling helpers. """ BASE_URL = "https://api.wavespeed.ai/api/v3" def __init__(self, api_key: Optional[str] = None): manager = APIKeyManager() self.api_key = api_key or manager.get_api_key("wavespeed") if not self.api_key: raise RuntimeError("WAVESPEED_API_KEY is not configured. Please add it to your environment.") def _headers(self) -> Dict[str, str]: return { "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}", } def submit_image_to_video( self, model_path: str, payload: Dict[str, Any], timeout: int = 30, ) -> str: """ Submit an image-to-video generation request. Returns the prediction ID for polling. 
""" url = f"{self.BASE_URL}/{model_path}" logger.info(f"[WaveSpeed] Submitting request to {url}") response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout) if response.status_code != 200: logger.error(f"[WaveSpeed] Submission failed: {response.status_code} {response.text}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed image-to-video submission failed", "status_code": response.status_code, "response": response.text, }, ) data = response.json().get("data") if not data or "id" not in data: logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}") raise HTTPException( status_code=502, detail={"error": "WaveSpeed response missing prediction id"}, ) prediction_id = data["id"] logger.info(f"[WaveSpeed] Submitted request: {prediction_id}") return prediction_id def get_prediction_result(self, prediction_id: str, timeout: int = 30) -> Dict[str, Any]: """ Fetch the current status/result for a prediction. Matches the example pattern: simple GET request, check status_code == 200, return data. 
""" url = f"{self.BASE_URL}/predictions/{prediction_id}/result" headers = {"Authorization": f"Bearer {self.api_key}"} try: response = requests.get(url, headers=headers, timeout=timeout) except requests_exceptions.Timeout as exc: raise HTTPException( status_code=504, detail={ "error": "WaveSpeed polling request timed out", "prediction_id": prediction_id, "resume_available": True, "exception": str(exc), }, ) from exc except requests_exceptions.RequestException as exc: raise HTTPException( status_code=502, detail={ "error": "WaveSpeed polling request failed", "prediction_id": prediction_id, "resume_available": True, "exception": str(exc), }, ) from exc # Match example pattern: check status_code == 200, then get data if response.status_code == 200: result = response.json().get("data") if not result: raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"}) return result else: # Non-200 status - log and raise error (matching example's break behavior) logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed prediction polling failed", "status_code": response.status_code, "response": response.text, }, ) def poll_until_complete( self, prediction_id: str, timeout_seconds: Optional[int] = None, interval_seconds: float = 1.0, ) -> Dict[str, Any]: """ Poll WaveSpeed until the job completes or fails. Matches the example pattern: simple polling loop until status is "completed" or "failed". Args: prediction_id: The prediction ID to poll for timeout_seconds: Optional timeout in seconds. If None, polls indefinitely until completion/failure. 
interval_seconds: Seconds to wait between polling attempts (default: 1.0, faster than 2.0) Returns: Dict containing the completed result Raises: HTTPException: If the task fails, polling fails, or times out (if timeout_seconds is set) """ start_time = time.time() consecutive_errors = 0 max_consecutive_errors = 6 # safety guard for non-transient errors while True: try: result = self.get_prediction_result(prediction_id) consecutive_errors = 0 # Reset error counter on success except HTTPException as exc: detail = exc.detail or {} if isinstance(detail, dict): detail.setdefault("prediction_id", prediction_id) detail.setdefault("resume_available", True) detail.setdefault("error", detail.get("error", "WaveSpeed polling failed")) # Determine underlying status code (WaveSpeed vs proxy) status_code = detail.get("status_code", exc.status_code) # Treat 5xx as transient: keep polling indefinitely with backoff if 500 <= int(status_code) < 600: consecutive_errors += 1 backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1))) logger.warning( f"[WaveSpeed] Transient polling error {consecutive_errors} for {prediction_id}: " f"{status_code}. Backing off {backoff:.1f}s" ) time.sleep(backoff) continue # For non-transient (typically 4xx) errors, apply safety cap consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: logger.error( f"[WaveSpeed] Too many polling errors ({consecutive_errors}) for {prediction_id}, " f"status_code={status_code}. Giving up." ) raise HTTPException(status_code=exc.status_code, detail=detail) from exc backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1))) logger.warning( f"[WaveSpeed] Polling error {consecutive_errors}/{max_consecutive_errors} for {prediction_id}: " f"{status_code}. 
Backing off {backoff:.1f}s" ) time.sleep(backoff) continue # Extract status from result (matching example pattern) status = result.get("status") if status == "completed": elapsed = time.time() - start_time logger.info(f"[WaveSpeed] Prediction {prediction_id} completed in {elapsed:.1f}s") return result if status == "failed": error_msg = result.get("error", "Unknown error") logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {error_msg}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed task failed", "prediction_id": prediction_id, "message": error_msg, "details": result, }, ) # Check timeout only if specified if timeout_seconds is not None: elapsed = time.time() - start_time if elapsed > timeout_seconds: logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s") raise HTTPException( status_code=504, detail={ "error": "WaveSpeed task timed out", "prediction_id": prediction_id, "timeout_seconds": timeout_seconds, "current_status": status, "message": f"Task did not complete within {timeout_seconds} seconds. Status: {status}", }, ) # Log progress periodically (every 30 seconds) elapsed = time.time() - start_time if int(elapsed) % 30 == 0 and elapsed > 0: logger.info(f"[WaveSpeed] Polling {prediction_id}: status={status}, elapsed={elapsed:.0f}s") # Poll faster (1.0s instead of 2.0s) to match example's responsiveness time.sleep(interval_seconds) def optimize_prompt( self, text: str, mode: str = "image", style: str = "default", image: Optional[str] = None, enable_sync_mode: bool = True, timeout: int = 30, ) -> str: """ Optimize a prompt using WaveSpeed prompt optimizer. 
Args: text: The prompt text to optimize mode: "image" or "video" (default: "image") style: "default", "artistic", "photographic", "technical", "anime", "realistic" (default: "default") image: Base64-encoded image for context (optional) enable_sync_mode: If True, wait for result and return it directly (default: True) timeout: Request timeout in seconds (default: 30) Returns: Optimized prompt text """ model_path = "wavespeed-ai/prompt-optimizer" url = f"{self.BASE_URL}/{model_path}" payload = { "text": text, "mode": mode, "style": style, "enable_sync_mode": enable_sync_mode, } if image: payload["image"] = image logger.info(f"[WaveSpeed] Optimizing prompt via {url} (mode={mode}, style={style})") response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout) if response.status_code != 200: logger.error(f"[WaveSpeed] Prompt optimization failed: {response.status_code} {response.text}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed prompt optimization failed", "status_code": response.status_code, "response": response.text, }, ) response_json = response.json() data = response_json.get("data") or response_json # Handle sync mode - result should be directly in outputs if enable_sync_mode: outputs = data.get("outputs") or [] if not outputs: logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}") raise HTTPException( status_code=502, detail="WaveSpeed prompt optimizer returned no outputs", ) # Extract optimized prompt from outputs # In sync mode, outputs[0] should be the optimized text directly (or a URL to fetch) optimized_prompt = None if isinstance(outputs, list) and len(outputs) > 0: first_output = outputs[0] # If it's a string that looks like a URL, fetch it if isinstance(first_output, str): if first_output.startswith("http://") or first_output.startswith("https://"): logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}") url_response = requests.get(first_output, timeout=timeout) if 
url_response.status_code == 200: optimized_prompt = url_response.text.strip() else: logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}") raise HTTPException( status_code=502, detail="Failed to fetch optimized prompt from WaveSpeed URL", ) else: # It's already the text optimized_prompt = first_output elif isinstance(first_output, dict): optimized_prompt = first_output.get("text") or first_output.get("prompt") or first_output.get("output") if not optimized_prompt: logger.error(f"[WaveSpeed] Could not extract optimized prompt from outputs: {outputs}") raise HTTPException( status_code=502, detail="WaveSpeed prompt optimizer output format not recognized", ) logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)") return optimized_prompt # Async mode - return prediction ID for polling prediction_id = data.get("id") if not prediction_id: logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}") raise HTTPException( status_code=502, detail="WaveSpeed response missing prediction id for async mode", ) # Poll for result result = self.poll_until_complete(prediction_id, timeout_seconds=60, interval_seconds=0.5) outputs = result.get("outputs") or [] if not outputs: raise HTTPException(status_code=502, detail="WaveSpeed prompt optimizer returned no outputs") # Extract optimized prompt from outputs # In async mode, outputs[0] is typically a URL that needs to be fetched optimized_prompt = None if isinstance(outputs, list) and len(outputs) > 0: first_output = outputs[0] # In async mode, it's usually a URL to fetch if isinstance(first_output, str): if first_output.startswith("http://") or first_output.startswith("https://"): logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}") url_response = requests.get(first_output, timeout=timeout) if url_response.status_code == 200: optimized_prompt = url_response.text.strip() else: logger.error(f"[WaveSpeed] Failed to 
fetch prompt from URL: {url_response.status_code}") raise HTTPException( status_code=502, detail="Failed to fetch optimized prompt from WaveSpeed URL", ) else: # If it's already text (shouldn't happen in async mode, but handle it) optimized_prompt = first_output elif isinstance(first_output, dict): optimized_prompt = first_output.get("text") or first_output.get("prompt") or first_output.get("output") if not optimized_prompt: raise HTTPException( status_code=502, detail="WaveSpeed prompt optimizer output format not recognized", ) logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)") return optimized_prompt def generate_image( self, model: str, prompt: str, width: int = 1024, height: int = 1024, num_inference_steps: Optional[int] = None, guidance_scale: Optional[float] = None, negative_prompt: Optional[str] = None, seed: Optional[int] = None, enable_sync_mode: bool = True, timeout: int = 120, **kwargs ) -> bytes: """ Generate image using WaveSpeed AI models (Ideogram V3 or Qwen Image). Args: model: Model to use ("ideogram-v3-turbo" or "qwen-image") prompt: Text prompt for image generation width: Image width (default: 1024) height: Image height (default: 1024) num_inference_steps: Number of inference steps guidance_scale: Guidance scale for generation negative_prompt: Negative prompt (what to avoid) seed: Random seed for reproducibility enable_sync_mode: If True, wait for result and return it directly (default: True) timeout: Request timeout in seconds (default: 120) **kwargs: Additional parameters Returns: bytes: Generated image bytes """ # Map model names to WaveSpeed API paths model_paths = { "ideogram-v3-turbo": "ideogram-ai/ideogram-v3-turbo", "qwen-image": "wavespeed-ai/qwen-image/text-to-image", } model_path = model_paths.get(model) if not model_path: raise ValueError(f"Unsupported image model: {model}. 
Supported: {list(model_paths.keys())}") url = f"{self.BASE_URL}/{model_path}" payload = { "prompt": prompt, "width": width, "height": height, "enable_sync_mode": enable_sync_mode, } # Add optional parameters if num_inference_steps is not None: payload["num_inference_steps"] = num_inference_steps if guidance_scale is not None: payload["guidance_scale"] = guidance_scale if negative_prompt: payload["negative_prompt"] = negative_prompt if seed is not None: payload["seed"] = seed # Add any extra parameters for key, value in kwargs.items(): if key not in payload: payload[key] = value logger.info(f"[WaveSpeed] Generating image via {url} (model={model}, prompt_length={len(prompt)})") response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout) if response.status_code != 200: logger.error(f"[WaveSpeed] Image generation failed: {response.status_code} {response.text}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed image generation failed", "status_code": response.status_code, "response": response.text, }, ) response_json = response.json() data = response_json.get("data") or response_json # Handle sync mode - result should be directly in outputs if enable_sync_mode: outputs = data.get("outputs") or [] if not outputs: logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}") raise HTTPException( status_code=502, detail="WaveSpeed image generator returned no outputs", ) # Extract image URL from outputs image_url = None if isinstance(outputs, list) and len(outputs) > 0: first_output = outputs[0] if isinstance(first_output, str): image_url = first_output elif isinstance(first_output, dict): image_url = first_output.get("url") or first_output.get("output") if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")): logger.error(f"[WaveSpeed] Invalid image URL in outputs: {outputs}") raise HTTPException( status_code=502, detail="WaveSpeed image generator output format not 
recognized", ) # Fetch image bytes from URL logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}") image_response = requests.get(image_url, timeout=timeout) if image_response.status_code == 200: image_bytes = image_response.content logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)") return image_bytes else: logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}") raise HTTPException( status_code=502, detail="Failed to fetch generated image from WaveSpeed URL", ) # Async mode - poll for result prediction_id = data.get("id") if not prediction_id: logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}") raise HTTPException( status_code=502, detail="WaveSpeed response missing prediction id for async mode", ) # Poll for result result = self.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0) outputs = result.get("outputs") or [] if not outputs: raise HTTPException(status_code=502, detail="WaveSpeed image generator returned no outputs") # Extract image URL and fetch image_url = None if isinstance(outputs, list) and len(outputs) > 0: first_output = outputs[0] if isinstance(first_output, str): image_url = first_output elif isinstance(first_output, dict): image_url = first_output.get("url") or first_output.get("output") if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")): raise HTTPException( status_code=502, detail="WaveSpeed image generator output format not recognized", ) # Fetch image bytes logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}") # Use reasonable timeout for downloading the final image (60s should be enough) # The timeout parameter is for polling, not for downloading image_response = requests.get(image_url, timeout=60) if image_response.status_code == 200: image_bytes = image_response.content logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} 
bytes)") return image_bytes else: logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}") raise HTTPException( status_code=502, detail="Failed to fetch generated image from WaveSpeed URL", ) def generate_character_image( self, prompt: str, reference_image_bytes: bytes, style: str = "Auto", aspect_ratio: str = "16:9", rendering_speed: str = "Default", timeout: Optional[int] = None, ) -> bytes: """ Generate image using Ideogram Character API to maintain character consistency. Creates variations of a reference character image while respecting the base appearance. Note: This API is always async and requires polling for results. Args: prompt: Text prompt describing the scene/context for the character reference_image_bytes: Reference image bytes (base avatar) style: Character style type ("Auto", "Fiction", or "Realistic") aspect_ratio: Aspect ratio ("1:1", "16:9", "9:16", "4:3", "3:4") rendering_speed: Rendering speed ("Default", "Turbo", "Quality") timeout: Total timeout in seconds for submission + polling (default: 180) Returns: bytes: Generated image bytes with consistent character """ import base64 # Encode reference image to base64 image_base64 = base64.b64encode(reference_image_bytes).decode('utf-8') # Add data URI prefix image_data_uri = f"data:image/png;base64,{image_base64}" url = f"{self.BASE_URL}/ideogram-ai/ideogram-character" # Note: enable_sync_mode is not a valid parameter for Ideogram Character API # The API is always async and requires polling payload = { "prompt": prompt, "image": image_data_uri, "style": style, "aspect_ratio": aspect_ratio, "rendering_speed": rendering_speed, } logger.info(f"[WaveSpeed] Generating character image via Ideogram Character (prompt_length={len(prompt)})") # POST request should return quickly with just the task ID # Use reasonable timeouts for the initial submission # Connection timeout: 30s (increased for reliability - network may be slow) # Read timeout: 30s (should be enough to get task 
ID response) # Retry logic for transient connection failures max_retries = 2 retry_delay = 2.0 # seconds for attempt in range(max_retries + 1): try: response = requests.post( url, headers=self._headers(), json=payload, timeout=(30, 30) # (connect_timeout, read_timeout) - increased for network reliability ) break # Success, exit retry loop except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e: if attempt < max_retries: logger.warning(f"[WaveSpeed] Connection attempt {attempt + 1}/{max_retries + 1} failed, retrying in {retry_delay}s: {e}") time.sleep(retry_delay) retry_delay *= 2 # Exponential backoff continue else: # Final attempt failed error_type = "Connection timeout" if isinstance(e, requests_exceptions.ConnectTimeout) else "Connection error" logger.error(f"[WaveSpeed] {error_type} to Ideogram Character API after {max_retries + 1} attempts: {e}") raise HTTPException( status_code=504 if isinstance(e, requests_exceptions.ConnectTimeout) else 502, detail={ "error": f"{error_type} to WaveSpeed Ideogram Character API", "message": "Unable to establish connection to the image generation service after multiple attempts. Please check your network connection and try again.", "exception": str(e), "retry_recommended": True, }, ) except requests_exceptions.Timeout as e: logger.error(f"[WaveSpeed] Request timeout to Ideogram Character API: {e}") raise HTTPException( status_code=504, detail={ "error": "Request timeout to WaveSpeed Ideogram Character API", "message": "The image generation request took too long. 
Please try again.", "exception": str(e), }, ) if response.status_code != 200: logger.error(f"[WaveSpeed] Character image generation failed: {response.status_code} {response.text}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed Ideogram Character generation failed", "status_code": response.status_code, "response": response.text, }, ) response_json = response.json() data = response_json.get("data") or response_json # Extract prediction ID prediction_id = data.get("id") if not prediction_id: logger.error(f"[WaveSpeed] No prediction ID in response: {response.text}") raise HTTPException( status_code=502, detail="WaveSpeed Ideogram Character response missing prediction id", ) # Ideogram Character API is always async - check status and poll if needed outputs = data.get("outputs") or [] status = data.get("status", "unknown") logger.info(f"[WaveSpeed] Ideogram Character task created: prediction_id={prediction_id}, status={status}") # If status is already completed, use outputs directly (unlikely but possible) if outputs and status == "completed": logger.info(f"[WaveSpeed] Got immediate results from Ideogram Character") else: # Always need to poll for results (API is async) logger.info(f"[WaveSpeed] Polling for Ideogram Character result (status: {status}, prediction_id: {prediction_id})") # Poll until complete - use timeout if provided, otherwise poll indefinitely # Match example pattern exactly: simple while True loop, check status, break on completed/failed polling_timeout = timeout if timeout else None # None means poll indefinitely result = self.poll_until_complete( prediction_id, timeout_seconds=polling_timeout, interval_seconds=0.5, # Poll every 0.5s (closer to example's 0.1s) ) # Safely extract outputs and status if not isinstance(result, dict): logger.error(f"[WaveSpeed] Unexpected result type: {type(result)}, value: {result}") raise HTTPException( status_code=502, detail="WaveSpeed Ideogram Character returned unexpected response format", ) 
outputs = result.get("outputs") or [] status = result.get("status", "unknown") if status != "completed": # Safely extract error message error_msg = "Unknown error" if isinstance(result, dict): error_msg = result.get("error") or result.get("message") or str(result.get("details", "Unknown error")) else: error_msg = str(result) logger.error(f"[WaveSpeed] Ideogram Character task did not complete: status={status}, error={error_msg}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed Ideogram Character task failed", "status": status, "message": error_msg, } ) # Extract image URL from outputs if not outputs: logger.error(f"[WaveSpeed] No outputs after polling: status={status}") raise HTTPException( status_code=502, detail="WaveSpeed Ideogram Character returned no outputs", ) image_url = None if isinstance(outputs, list) and len(outputs) > 0: first_output = outputs[0] if isinstance(first_output, str): image_url = first_output elif isinstance(first_output, dict): image_url = first_output.get("url") or first_output.get("image_url") if not image_url: logger.error(f"[WaveSpeed] No image URL in outputs: {outputs}") raise HTTPException( status_code=502, detail="WaveSpeed Ideogram Character response missing image URL", ) # Download image logger.info(f"[WaveSpeed] Downloading character image from: {image_url}") image_response = requests.get(image_url, timeout=60) if image_response.status_code != 200: logger.error(f"[WaveSpeed] Failed to download image: {image_response.status_code}") raise HTTPException( status_code=502, detail="Failed to download generated character image", ) image_bytes = image_response.content logger.info(f"[WaveSpeed] ✅ Successfully generated character image: {len(image_bytes)} bytes") return image_bytes def generate_speech( self, text: str, voice_id: str, speed: float = 1.0, volume: float = 1.0, pitch: float = 0.0, emotion: str = "happy", enable_sync_mode: bool = True, timeout: int = 120, **kwargs ) -> bytes: """ Generate speech audio using 
Minimax Speech 02 HD via WaveSpeed. Args: text: Text to convert to speech (max 10000 characters) voice_id: Voice ID (e.g., "Wise_Woman", "Friendly_Person", etc.) speed: Speech speed (0.5-2.0, default: 1.0) volume: Speech volume (0.1-10.0, default: 1.0) pitch: Speech pitch (-12 to 12, default: 0.0) emotion: Emotion ("happy", "sad", "angry", etc., default: "happy") enable_sync_mode: If True, wait for result and return it directly (default: True) timeout: Request timeout in seconds (default: 60) **kwargs: Additional parameters (sample_rate, bitrate, format, etc.) Returns: bytes: Generated audio bytes """ model_path = "minimax/speech-02-hd" url = f"{self.BASE_URL}/{model_path}" payload = { "text": text, "voice_id": voice_id, "speed": speed, "volume": volume, "pitch": pitch, "emotion": emotion, "enable_sync_mode": enable_sync_mode, } # Add optional parameters optional_params = [ "english_normalization", "sample_rate", "bitrate", "channel", "format", "language_boost", ] for param in optional_params: if param in kwargs: payload[param] = kwargs[param] logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})") # Retry on transient connection issues max_retries = 2 retry_delay = 2.0 last_error = None for attempt in range(max_retries + 1): try: response = requests.post( url, headers=self._headers(), json=payload, timeout=(30, 60), # connect, read ) break except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e: last_error = e if attempt < max_retries: logger.warning( f"[WaveSpeed] Speech connection attempt {attempt + 1}/{max_retries + 1} failed, " f"retrying in {retry_delay}s: {e}" ) time.sleep(retry_delay) retry_delay *= 2 continue logger.error(f"[WaveSpeed] Speech connection failed after {max_retries + 1} attempts: {e}") raise HTTPException( status_code=504, detail={ "error": "Connection to WaveSpeed speech API timed out", "message": "Unable to reach the speech service. 
Please try again.", "exception": str(e), "retry_recommended": True, }, ) except requests_exceptions.Timeout as e: last_error = e logger.error(f"[WaveSpeed] Speech request timeout: {e}") raise HTTPException( status_code=504, detail={ "error": "WaveSpeed speech request timed out", "message": "The speech generation request took too long. Please try again.", "exception": str(e), }, ) if response.status_code != 200: logger.error(f"[WaveSpeed] Speech generation failed: {response.status_code} {response.text}") raise HTTPException( status_code=502, detail={ "error": "WaveSpeed speech generation failed", "status_code": response.status_code, "response": response.text, }, ) response_json = response.json() data = response_json.get("data") or response_json # Handle sync mode - result should be directly in outputs if enable_sync_mode: outputs = data.get("outputs") or [] if not outputs: logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}") raise HTTPException( status_code=502, detail="WaveSpeed speech generator returned no outputs", ) # Extract audio URL from outputs audio_url = None if isinstance(outputs, list) and len(outputs) > 0: first_output = outputs[0] if isinstance(first_output, str): audio_url = first_output elif isinstance(first_output, dict): audio_url = first_output.get("url") or first_output.get("output") if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")): logger.error(f"[WaveSpeed] Invalid audio URL in outputs: {outputs}") raise HTTPException( status_code=502, detail="WaveSpeed speech generator output format not recognized", ) # Fetch audio bytes from URL logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}") audio_response = requests.get(audio_url, timeout=timeout) if audio_response.status_code == 200: audio_bytes = audio_response.content logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)") return audio_bytes else: logger.error(f"[WaveSpeed] Failed to 
fetch audio from URL: {audio_response.status_code}") raise HTTPException( status_code=502, detail="Failed to fetch generated audio from WaveSpeed URL", ) # Async mode - return prediction ID for polling prediction_id = data.get("id") if not prediction_id: logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}") raise HTTPException( status_code=502, detail="WaveSpeed response missing prediction id for async mode", ) # Poll for result result = self.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=0.5) outputs = result.get("outputs") or [] if not outputs: raise HTTPException(status_code=502, detail="WaveSpeed speech generator returned no outputs") # Extract audio URL and fetch audio_url = None if isinstance(outputs, list) and len(outputs) > 0: first_output = outputs[0] if isinstance(first_output, str): audio_url = first_output elif isinstance(first_output, dict): audio_url = first_output.get("url") or first_output.get("output") if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")): raise HTTPException( status_code=502, detail="WaveSpeed speech generator output format not recognized", ) # Fetch audio bytes logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}") audio_response = requests.get(audio_url, timeout=timeout) if audio_response.status_code == 200: audio_bytes = audio_response.content logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)") return audio_bytes else: logger.error(f"[WaveSpeed] Failed to fetch audio from URL: {audio_response.status_code}") raise HTTPException( status_code=502, detail="Failed to fetch generated audio from WaveSpeed URL", ) def submit_text_to_video( self, model_path: str, payload: Dict[str, Any], timeout: int = 60, ) -> str: """ Submit a text-to-video generation request to WaveSpeed. 
def generate_text_video(
    self,
    prompt: str,
    resolution: str = "720p",  # 480p, 720p, 1080p
    duration: int = 5,  # 5 or 10 seconds
    audio_base64: Optional[str] = None,  # Optional audio for lip-sync
    negative_prompt: Optional[str] = None,
    seed: Optional[int] = None,
    enable_prompt_expansion: bool = True,
    enable_sync_mode: bool = False,
    timeout: int = 180,
    download_timeout: int = 180,
) -> Dict[str, Any]:
    """
    Generate video from text prompt using WAN 2.5 text-to-video.

    Args:
        prompt: Text prompt describing the video
        resolution: Output resolution (480p, 720p, 1080p)
        duration: Video duration in seconds (5 or 10)
        audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB) for lip-sync
        negative_prompt: Optional negative prompt
        seed: Optional random seed for reproducibility
        enable_prompt_expansion: Enable prompt optimizer
        enable_sync_mode: If True, wait for result and return it directly
        timeout: Timeout in seconds for the submission request; in async mode
            it is also the maximum time spent polling for completion
        download_timeout: Timeout in seconds for downloading the finished video

    Returns:
        Dictionary with the keys ``video_bytes``, ``prompt``, ``duration``,
        ``model_name``, ``cost``, ``provider``, ``source_video_url``,
        ``prediction_id``, ``resolution``, ``width``, ``height`` and ``metadata``.

    Raises:
        HTTPException: 400 for an unsupported resolution or duration; 502 when
            submission fails, the response is malformed or has no usable
            output URL, or the video download fails. Errors raised while
            polling are re-raised with ``prediction_id`` and
            ``resume_available`` added to a dict ``detail``.
    """
    model_path = "alibaba/wan-2.5/text-to-video"

    # Validate resolution
    valid_resolutions = ["480p", "720p", "1080p"]
    if resolution not in valid_resolutions:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid resolution: {resolution}. Must be one of: {valid_resolutions}",
        )

    # Validate duration
    if duration not in [5, 10]:
        raise HTTPException(
            status_code=400,
            detail="Duration must be 5 or 10 seconds",
        )

    # Build payload; optional fields are only sent when provided
    payload = {
        "prompt": prompt,
        "resolution": resolution,
        "duration": duration,
        "enable_prompt_expansion": enable_prompt_expansion,
        "enable_sync_mode": enable_sync_mode,
    }
    if audio_base64:
        payload["audio"] = audio_base64
    if negative_prompt:
        payload["negative_prompt"] = negative_prompt
    if seed is not None:
        payload["seed"] = seed

    logger.info(
        f"[WaveSpeed] Generating text-to-video: resolution={resolution}, "
        f"duration={duration}s, prompt_length={len(prompt)}, sync_mode={enable_sync_mode}"
    )

    if enable_sync_mode:
        # Sync mode: the submission response itself is expected to carry the outputs.
        url = f"{self.BASE_URL}/{model_path}"
        response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)

        if response.status_code != 200:
            logger.error(f"[WaveSpeed] Text-to-video submission failed: {response.status_code} {response.text}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed text-to-video submission failed",
                    "status_code": response.status_code,
                    "response": response.text[:500],
                },
            )

        try:
            response_json = response.json()
        except ValueError as exc:
            logger.error(f"[WaveSpeed] Sync mode response is not valid JSON: {response.text[:500]}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed text-to-video returned an invalid response in sync mode",
            ) from exc

        # The result is read from "data" when present, otherwise from the top
        # level. Anything that is not a JSON object is treated as "no outputs".
        result = response_json.get("data") or response_json if isinstance(response_json, dict) else {}
        if not isinstance(result, dict):
            result = {}

        outputs = result.get("outputs") or []
        if not outputs:
            logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text[:500]}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed text-to-video returned no outputs in sync mode",
            )

        prediction_id = result.get("id", "sync_mode")
        invalid_url_log = "Invalid video URL format in sync mode"
        download_log = "Downloading video from sync mode URL"
        download_error = "Failed to download WAN 2.5 video from sync mode"
    else:
        # Async mode - submit and poll
        prediction_id = self.submit_text_to_video(model_path, payload, timeout=timeout)

        try:
            result = self.poll_until_complete(
                prediction_id,
                timeout_seconds=timeout,
                interval_seconds=2.0,
            )
        except HTTPException as e:
            # Attach the prediction id so the caller can resume polling later.
            detail = e.detail or {}
            if isinstance(detail, dict):
                detail.setdefault("prediction_id", prediction_id)
                detail.setdefault("resume_available", True)
            raise HTTPException(status_code=e.status_code, detail=detail) from e

        if not isinstance(result, dict):
            result = {}

        outputs = result.get("outputs") or []
        if not outputs:
            raise HTTPException(
                status_code=502,
                detail="WAN 2.5 text-to-video completed but returned no outputs",
            )

        invalid_url_log = "Invalid video URL format"
        download_log = "Downloading video from"
        download_error = "Failed to download WAN 2.5 video"

    # Extract video URL: only the first entry of a list of outputs is used, and
    # it must be an absolute http(s) URL (same scheme check as the audio path).
    video_url = outputs[0] if isinstance(outputs, (list, tuple)) else None
    if not (isinstance(video_url, str) and video_url.startswith(("http://", "https://"))):
        logger.error(f"[WaveSpeed] {invalid_url_log}: {video_url}")
        raise HTTPException(
            status_code=502,
            detail=f"Invalid video URL format: {video_url}",
        )

    # Download video
    logger.info(f"[WaveSpeed] {download_log}: {video_url}")
    video_response = requests.get(video_url, timeout=download_timeout)
    if video_response.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={
                "error": download_error,
                "status_code": video_response.status_code,
                "response": video_response.text[:200],
            },
        )

    video_bytes = video_response.content
    metadata = result.get("metadata") or {}

    # Calculate cost (same pricing as image-to-video): per-second rate times duration
    pricing = {
        "480p": 0.05,
        "720p": 0.10,
        "1080p": 0.15,
    }
    cost = pricing.get(resolution, 0.10) * duration

    # Get video dimensions
    resolution_dims = {
        "480p": (854, 480),
        "720p": (1280, 720),
        "1080p": (1920, 1080),
    }
    width, height = resolution_dims.get(resolution, (1280, 720))

    logger.info(
        f"[WaveSpeed] ✅ Generated text-to-video: {len(video_bytes)} bytes, "
        f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
    )

    return {
        "video_bytes": video_bytes,
        "prompt": prompt,
        "duration": float(duration),
        "model_name": model_path,
        "cost": cost,
        "provider": "wavespeed",
        "source_video_url": video_url,
        "prediction_id": prediction_id,
        "resolution": resolution,
        "width": width,
        "height": height,
        "metadata": metadata,
    }