Base code
This commit is contained in:
1
backend/services/wavespeed/generators/__init__.py
Normal file
1
backend/services/wavespeed/generators/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""WaveSpeed API generators for different content types."""
|
||||
374
backend/services/wavespeed/generators/image.py
Normal file
374
backend/services/wavespeed/generators/image.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Image generation generator for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import time
|
||||
import requests
|
||||
from typing import Optional
|
||||
from requests import exceptions as requests_exceptions
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.image")
|
||||
|
||||
|
||||
class ImageGenerator:
|
||||
"""Image generation generator."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize image generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
num_inference_steps: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 120,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image using WaveSpeed AI models (Ideogram V3 or Qwen Image).
|
||||
|
||||
Args:
|
||||
model: Model to use ("ideogram-v3-turbo" or "qwen-image")
|
||||
prompt: Text prompt for image generation
|
||||
width: Image width (default: 1024)
|
||||
height: Image height (default: 1024)
|
||||
num_inference_steps: Number of inference steps
|
||||
guidance_scale: Guidance scale for generation
|
||||
negative_prompt: Negative prompt (what to avoid)
|
||||
seed: Random seed for reproducibility
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 120)
|
||||
**kwargs: Additional parameters
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes
|
||||
"""
|
||||
# Map model names to WaveSpeed API paths
|
||||
model_paths = {
|
||||
"ideogram-v3-turbo": "ideogram-ai/ideogram-v3-turbo",
|
||||
"qwen-image": "wavespeed-ai/qwen-image/text-to-image",
|
||||
}
|
||||
|
||||
model_path = model_paths.get(model)
|
||||
if not model_path:
|
||||
raise ValueError(f"Unsupported image model: {model}. Supported: {list(model_paths.keys())}")
|
||||
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
if num_inference_steps is not None:
|
||||
payload["num_inference_steps"] = num_inference_steps
|
||||
if guidance_scale is not None:
|
||||
payload["guidance_scale"] = guidance_scale
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Add any extra parameters
|
||||
for key, value in kwargs.items():
|
||||
if key not in payload:
|
||||
payload[key] = value
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating image via {url} (model={model}, prompt_length={len(prompt)})")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Image generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status - if "created" or "processing", we need to poll even in sync mode
|
||||
status = data.get("status", "").lower()
|
||||
outputs = data.get("outputs") or []
|
||||
prediction_id = data.get("id")
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
# If we have outputs and status is "completed", use them directly
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed] Got immediate results from sync mode (status: {status})")
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout)
|
||||
|
||||
# Sync mode returned "created" or "processing" status - need to poll
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] Sync mode returned status '{status}' but no prediction ID: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed sync mode returned async response without prediction ID",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Sync mode returned status '{status}' with no outputs. "
|
||||
f"Falling back to polling (prediction_id: {prediction_id})"
|
||||
)
|
||||
|
||||
# Async mode OR sync mode that returned "created"/"processing" - poll for result
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id",
|
||||
)
|
||||
|
||||
# Poll for result (use longer timeout for image generation)
|
||||
logger.info(f"[WaveSpeed] Polling for image generation result (prediction_id: {prediction_id}, status: {status})")
|
||||
result = self.polling.poll_until_complete(prediction_id, timeout_seconds=240, interval_seconds=1.0)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed image generator returned no outputs")
|
||||
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=60)
|
||||
|
||||
def generate_character_image(
|
||||
self,
|
||||
prompt: str,
|
||||
reference_image_bytes: bytes,
|
||||
style: str = "Auto",
|
||||
aspect_ratio: str = "16:9",
|
||||
rendering_speed: str = "Default",
|
||||
timeout: Optional[int] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image using Ideogram Character API to maintain character consistency.
|
||||
Creates variations of a reference character image while respecting the base appearance.
|
||||
|
||||
Note: This API is always async and requires polling for results.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt describing the scene/context for the character
|
||||
reference_image_bytes: Reference image bytes (base avatar)
|
||||
style: Character style type ("Auto", "Fiction", or "Realistic")
|
||||
aspect_ratio: Aspect ratio ("1:1", "16:9", "9:16", "4:3", "3:4")
|
||||
rendering_speed: Rendering speed ("Default", "Turbo", "Quality")
|
||||
timeout: Total timeout in seconds for submission + polling (default: 180)
|
||||
|
||||
Returns:
|
||||
bytes: Generated image bytes with consistent character
|
||||
"""
|
||||
import base64
|
||||
|
||||
# Encode reference image to base64
|
||||
image_base64 = base64.b64encode(reference_image_bytes).decode('utf-8')
|
||||
# Add data URI prefix
|
||||
image_data_uri = f"data:image/png;base64,{image_base64}"
|
||||
|
||||
url = f"{self.base_url}/ideogram-ai/ideogram-character"
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"image": image_data_uri,
|
||||
"style": style,
|
||||
"aspect_ratio": aspect_ratio,
|
||||
"rendering_speed": rendering_speed,
|
||||
}
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating character image via Ideogram Character (prompt_length={len(prompt)})")
|
||||
|
||||
# Retry on transient connection failures
|
||||
max_retries = 2
|
||||
retry_delay = 2.0
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
timeout=(30, 30)
|
||||
)
|
||||
break
|
||||
except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e:
|
||||
if attempt < max_retries:
|
||||
logger.warning(f"[WaveSpeed] Connection attempt {attempt + 1}/{max_retries + 1} failed, retrying in {retry_delay}s: {e}")
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
else:
|
||||
error_type = "Connection timeout" if isinstance(e, requests_exceptions.ConnectTimeout) else "Connection error"
|
||||
logger.error(f"[WaveSpeed] {error_type} to Ideogram Character API after {max_retries + 1} attempts: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504 if isinstance(e, requests_exceptions.ConnectTimeout) else 502,
|
||||
detail={
|
||||
"error": f"{error_type} to WaveSpeed Ideogram Character API",
|
||||
"message": "Unable to establish connection to the image generation service after multiple attempts. Please check your network connection and try again.",
|
||||
"exception": str(e),
|
||||
"retry_recommended": True,
|
||||
},
|
||||
)
|
||||
except requests_exceptions.Timeout as e:
|
||||
logger.error(f"[WaveSpeed] Request timeout to Ideogram Character API: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "Request timeout to WaveSpeed Ideogram Character API",
|
||||
"message": "The image generation request took too long. Please try again.",
|
||||
"exception": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Character image generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Ideogram Character generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Extract prediction ID
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character response missing prediction id",
|
||||
)
|
||||
|
||||
# Ideogram Character API is always async - check status and poll if needed
|
||||
outputs = data.get("outputs") or []
|
||||
status = data.get("status", "unknown")
|
||||
|
||||
logger.info(f"[WaveSpeed] Ideogram Character task created: prediction_id={prediction_id}, status={status}")
|
||||
|
||||
# If status is already completed, use outputs directly (unlikely but possible)
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed] Got immediate results from Ideogram Character")
|
||||
else:
|
||||
# Always need to poll for results (API is async)
|
||||
logger.info(f"[WaveSpeed] Polling for Ideogram Character result (status: {status}, prediction_id: {prediction_id})")
|
||||
polling_timeout = timeout if timeout else None
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=polling_timeout,
|
||||
interval_seconds=0.5,
|
||||
)
|
||||
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"[WaveSpeed] Unexpected result type: {type(result)}, value: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character returned unexpected response format",
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
status = result.get("status", "unknown")
|
||||
|
||||
if status != "completed":
|
||||
error_msg = "Unknown error"
|
||||
if isinstance(result, dict):
|
||||
error_msg = result.get("error") or result.get("message") or str(result.get("details", "Unknown error"))
|
||||
else:
|
||||
error_msg = str(result)
|
||||
|
||||
logger.error(f"[WaveSpeed] Ideogram Character task did not complete: status={status}, error={error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Ideogram Character task failed",
|
||||
"status": status,
|
||||
"message": error_msg,
|
||||
}
|
||||
)
|
||||
|
||||
# Extract image URL from outputs
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs after polling: status={status}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Ideogram Character returned no outputs",
|
||||
)
|
||||
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=60)
|
||||
|
||||
def _extract_image_url(self, outputs: list) -> str:
|
||||
"""Extract image URL from outputs."""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed image generator output format not recognized",
|
||||
)
|
||||
|
||||
return image_url
|
||||
|
||||
def _download_image(self, image_url: str, timeout: int = 60) -> bytes:
|
||||
"""Download image from URL."""
|
||||
logger.info(f"[WaveSpeed] Fetching image from URL: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
if image_response.status_code == 200:
|
||||
image_bytes = image_response.content
|
||||
logger.info(f"[WaveSpeed] Image generated successfully (size: {len(image_bytes)} bytes)")
|
||||
return image_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch image from URL: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated image from WaveSpeed URL",
|
||||
)
|
||||
164
backend/services/wavespeed/generators/prompt.py
Normal file
164
backend/services/wavespeed/generators/prompt.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Prompt optimization generator for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.prompt")
|
||||
|
||||
|
||||
class PromptGenerator:
|
||||
"""Prompt optimization generator."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize prompt generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def optimize_prompt(
|
||||
self,
|
||||
text: str,
|
||||
mode: str = "image",
|
||||
style: str = "default",
|
||||
image: Optional[str] = None,
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Optimize a prompt using WaveSpeed prompt optimizer.
|
||||
|
||||
Args:
|
||||
text: The prompt text to optimize
|
||||
mode: "image" or "video" (default: "image")
|
||||
style: "default", "artistic", "photographic", "technical", "anime", "realistic" (default: "default")
|
||||
image: Base64-encoded image for context (optional)
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 30)
|
||||
|
||||
Returns:
|
||||
Optimized prompt text
|
||||
"""
|
||||
model_path = "wavespeed-ai/prompt-optimizer"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"mode": mode,
|
||||
"style": style,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
if image:
|
||||
payload["image"] = image
|
||||
|
||||
logger.info(f"[WaveSpeed] Optimizing prompt via {url} (mode={mode}, style={style})")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Prompt optimization failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed prompt optimization failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer returned no outputs",
|
||||
)
|
||||
|
||||
# Extract optimized prompt from outputs
|
||||
optimized_prompt = self._extract_prompt_from_outputs(outputs, timeout)
|
||||
if not optimized_prompt:
|
||||
logger.error(f"[WaveSpeed] Could not extract optimized prompt from outputs: {outputs}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer output format not recognized",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
# Async mode - return prediction ID for polling
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(prediction_id, timeout_seconds=60, interval_seconds=0.5)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed prompt optimizer returned no outputs")
|
||||
|
||||
# Extract optimized prompt from outputs
|
||||
optimized_prompt = self._extract_prompt_from_outputs(outputs, timeout)
|
||||
if not optimized_prompt:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer output format not recognized",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
|
||||
return optimized_prompt
|
||||
|
||||
def _extract_prompt_from_outputs(self, outputs: list, timeout: int) -> Optional[str]:
|
||||
"""Extract optimized prompt from outputs, handling URLs and direct text."""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
return None
|
||||
|
||||
first_output = outputs[0]
|
||||
|
||||
# If it's a string that looks like a URL, fetch it
|
||||
if isinstance(first_output, str):
|
||||
if first_output.startswith("http://") or first_output.startswith("https://"):
|
||||
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
|
||||
url_response = requests.get(first_output, timeout=timeout)
|
||||
if url_response.status_code == 200:
|
||||
return url_response.text.strip()
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch optimized prompt from WaveSpeed URL",
|
||||
)
|
||||
else:
|
||||
# It's already the text
|
||||
return first_output
|
||||
elif isinstance(first_output, dict):
|
||||
return first_output.get("text") or first_output.get("prompt") or first_output.get("output")
|
||||
|
||||
return None
|
||||
223
backend/services/wavespeed/generators/speech.py
Normal file
223
backend/services/wavespeed/generators/speech.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
Speech generation generator for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import time
|
||||
import requests
|
||||
from typing import Optional
|
||||
from requests import exceptions as requests_exceptions
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.speech")
|
||||
|
||||
|
||||
class SpeechGenerator:
|
||||
"""Speech generation generator."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize speech generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def generate_speech(
|
||||
self,
|
||||
text: str,
|
||||
voice_id: str,
|
||||
speed: float = 1.0,
|
||||
volume: float = 1.0,
|
||||
pitch: float = 0.0,
|
||||
emotion: str = "happy",
|
||||
enable_sync_mode: bool = True,
|
||||
timeout: int = 120,
|
||||
**kwargs
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate speech audio using Minimax Speech 02 HD via WaveSpeed.
|
||||
|
||||
Args:
|
||||
text: Text to convert to speech (max 10000 characters)
|
||||
voice_id: Voice ID (e.g., "Wise_Woman", "Friendly_Person", etc.)
|
||||
speed: Speech speed (0.5-2.0, default: 1.0)
|
||||
volume: Speech volume (0.1-10.0, default: 1.0)
|
||||
pitch: Speech pitch (-12 to 12, default: 0.0)
|
||||
emotion: Emotion ("happy", "sad", "angry", etc., default: "happy")
|
||||
enable_sync_mode: If True, wait for result and return it directly (default: True)
|
||||
timeout: Request timeout in seconds (default: 60)
|
||||
**kwargs: Additional parameters (sample_rate, bitrate, format, etc.)
|
||||
|
||||
Returns:
|
||||
bytes: Generated audio bytes
|
||||
"""
|
||||
model_path = "minimax/speech-02-hd"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"voice_id": voice_id,
|
||||
"speed": speed,
|
||||
"volume": volume,
|
||||
"pitch": pitch,
|
||||
"emotion": emotion,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
optional_params = [
|
||||
"english_normalization",
|
||||
"sample_rate",
|
||||
"bitrate",
|
||||
"channel",
|
||||
"format",
|
||||
"language_boost",
|
||||
]
|
||||
for param in optional_params:
|
||||
if param in kwargs:
|
||||
payload[param] = kwargs[param]
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
|
||||
|
||||
# Retry on transient connection issues
|
||||
max_retries = 2
|
||||
retry_delay = 2.0
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self._get_headers(),
|
||||
json=payload,
|
||||
timeout=(30, 60), # connect, read
|
||||
)
|
||||
break
|
||||
except (requests_exceptions.ConnectTimeout, requests_exceptions.ConnectionError) as e:
|
||||
if attempt < max_retries:
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Speech connection attempt {attempt + 1}/{max_retries + 1} failed, "
|
||||
f"retrying in {retry_delay}s: {e}"
|
||||
)
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 2
|
||||
continue
|
||||
logger.error(f"[WaveSpeed] Speech connection failed after {max_retries + 1} attempts: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "Connection to WaveSpeed speech API timed out",
|
||||
"message": "Unable to reach the speech service. Please try again.",
|
||||
"exception": str(e),
|
||||
"retry_recommended": True,
|
||||
},
|
||||
)
|
||||
except requests_exceptions.Timeout as e:
|
||||
logger.error(f"[WaveSpeed] Speech request timeout: {e}")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed speech request timed out",
|
||||
"message": "The speech generation request took too long. Please try again.",
|
||||
"exception": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Speech generation failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed speech generation failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if enable_sync_mode:
|
||||
outputs = data.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator returned no outputs",
|
||||
)
|
||||
|
||||
audio_url = self._extract_audio_url(outputs)
|
||||
return self._download_audio(audio_url, timeout)
|
||||
|
||||
# Async mode - return prediction ID for polling
|
||||
prediction_id = data.get("id")
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed response missing prediction id for async mode",
|
||||
)
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=0.5)
|
||||
outputs = result.get("outputs") or []
|
||||
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed speech generator returned no outputs")
|
||||
|
||||
audio_url = self._extract_audio_url(outputs)
|
||||
return self._download_audio(audio_url, timeout)
|
||||
|
||||
def _extract_audio_url(self, outputs: list) -> str:
|
||||
"""Extract audio URL from outputs."""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
audio_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
audio_url = first_output.get("url") or first_output.get("output")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
if not audio_url or not (audio_url.startswith("http://") or audio_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed speech generator output format not recognized",
|
||||
)
|
||||
|
||||
return audio_url
|
||||
|
||||
def _download_audio(self, audio_url: str, timeout: int) -> bytes:
|
||||
"""Download audio from URL."""
|
||||
logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
|
||||
audio_response = requests.get(audio_url, timeout=timeout)
|
||||
if audio_response.status_code == 200:
|
||||
audio_bytes = audio_response.content
|
||||
logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
|
||||
return audio_bytes
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch audio from URL: {audio_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch generated audio from WaveSpeed URL",
|
||||
)
|
||||
1330
backend/services/wavespeed/generators/video.py
Normal file
1330
backend/services/wavespeed/generators/video.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user