472 lines
20 KiB
Python
472 lines
20 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from typing import Any, Dict, Optional
|
|
|
|
import requests
|
|
from fastapi import HTTPException
|
|
from requests import exceptions as requests_exceptions
|
|
|
|
from services.onboarding.api_key_manager import APIKeyManager
|
|
from utils.logger_utils import get_service_logger
|
|
|
|
# Module-level logger used by every WaveSpeedClient method below; it is
# obtained from the project's get_service_logger helper under the name
# "wavespeed.client".
logger = get_service_logger("wavespeed.client")
|
|
|
|
|
|
class WaveSpeedClient:
    """
    Thin HTTP client for the WaveSpeed AI API.

    Handles authentication, submission, and polling helpers.

    Failures while talking to WaveSpeed (transport errors, non-200 statuses,
    malformed or unexpected response bodies) are raised as
    ``fastapi.HTTPException`` with status 502, or 504 for timeouts.
    """

    BASE_URL = "https://api.wavespeed.ai/api/v3"

    # Optional speech parameters copied from **kwargs into the request payload
    # when the caller supplies them; any other keyword is ignored.
    _SPEECH_OPTIONAL_PARAMS = (
        "english_normalization",
        "sample_rate",
        "bitrate",
        "channel",
        "format",
        "language_boost",
    )

    def __init__(self, api_key: Optional[str] = None):
        """
        Create a client.

        Args:
            api_key: Explicit API key. When omitted or empty, the key is read
                from ``APIKeyManager().get_api_key("wavespeed")``.

        Raises:
            RuntimeError: If no API key is available from either source.
        """
        manager = APIKeyManager()
        self.api_key = api_key or manager.get_api_key("wavespeed")
        if not self.api_key:
            raise RuntimeError("WAVESPEED_API_KEY is not configured. Please add it to your environment.")

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _headers(self) -> Dict[str, str]:
        """Return the JSON content-type and bearer-token headers for POST requests."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    @staticmethod
    def _is_http_url(value: Any) -> bool:
        """Return True when *value* is a string starting with http:// or https://."""
        return isinstance(value, str) and value.startswith(("http://", "https://"))

    @staticmethod
    def _first_output(outputs: Any) -> Any:
        """Return the first element of *outputs* if it is a non-empty list, else None."""
        if isinstance(outputs, list) and outputs:
            return outputs[0]
        return None

    @staticmethod
    def _unwrap_data(body: Any) -> Dict[str, Any]:
        """
        Return the ``data`` object of a parsed response, or the body itself
        when there is no (truthy) ``data`` key. Anything that is not a dict
        collapses to an empty dict so callers can safely use ``.get``.
        """
        if not isinstance(body, dict):
            return {}
        data = body.get("data") or body
        return data if isinstance(data, dict) else {}

    @staticmethod
    def _parse_json(response: Any) -> Dict[str, Any]:
        """
        Decode a response body as a JSON object.

        Returns an empty dict when the body is valid JSON but not an object.

        Raises:
            HTTPException: 502 when the body is not valid JSON.
        """
        try:
            body = response.json()
        except ValueError as exc:  # requests' JSON decode errors derive from ValueError
            logger.error(f"[WaveSpeed] Response is not valid JSON: {response.text}")
            raise HTTPException(
                status_code=502,
                detail={"error": "WaveSpeed returned a non-JSON response"},
            ) from exc
        return body if isinstance(body, dict) else {}

    def _post_json(
        self,
        url: str,
        payload: Dict[str, Any],
        timeout: int,
        *,
        log_label: str,
        error: str,
    ) -> Any:
        """
        POST *payload* as JSON and return the response when its status is 200.

        Args:
            url: Fully-qualified endpoint URL.
            payload: JSON-serialisable request body.
            timeout: Request timeout in seconds.
            log_label: Operation name used in log lines (e.g. "Submission").
            error: Value of the ``error`` field in the raised exception detail.

        Raises:
            HTTPException: 504 when the request times out; 502 for any other
                transport failure or a non-200 status.
        """
        try:
            response = requests.post(url, headers=self._headers(), json=payload, timeout=timeout)
        except requests_exceptions.Timeout as exc:
            logger.error(f"[WaveSpeed] {log_label} request timed out: {exc}")
            raise HTTPException(
                status_code=504,
                detail={"error": error, "exception": str(exc)},
            ) from exc
        except requests_exceptions.RequestException as exc:
            logger.error(f"[WaveSpeed] {log_label} request failed: {exc}")
            raise HTTPException(
                status_code=502,
                detail={"error": error, "exception": str(exc)},
            ) from exc

        if response.status_code != 200:
            logger.error(f"[WaveSpeed] {log_label} failed: {response.status_code} {response.text}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": error,
                    "status_code": response.status_code,
                    "response": response.text,
                },
            )
        return response

    def _fetch_output_url(self, url: str, timeout: int, *, what: str, error_detail: str) -> Any:
        """
        GET a result URL (sent without the API-key header) and return the
        response when its status is 200.

        Args:
            url: Output URL reported by WaveSpeed.
            timeout: Request timeout in seconds.
            what: Short noun for log lines ("prompt" or "audio").
            error_detail: ``detail`` string of the raised exception.

        Raises:
            HTTPException: 502 on transport failure or a non-200 status.
        """
        try:
            response = requests.get(url, timeout=timeout)
        except requests_exceptions.RequestException as exc:
            logger.error(f"[WaveSpeed] Failed to fetch {what} from URL: {exc}")
            raise HTTPException(status_code=502, detail=error_detail) from exc

        if response.status_code != 200:
            logger.error(f"[WaveSpeed] Failed to fetch {what} from URL: {response.status_code}")
            raise HTTPException(status_code=502, detail=error_detail)
        return response

    def _extract_prompt(self, outputs: Any, timeout: int) -> str:
        """
        Turn the ``outputs`` of a prompt-optimizer result into prompt text.

        The first output may be the text itself, a URL whose body is the text,
        or a dict carrying the text under ``text``, ``prompt`` or ``output``.

        Raises:
            HTTPException: 502 when no non-empty string can be extracted or
                the URL cannot be fetched.
        """
        first_output = self._first_output(outputs)
        optimized_prompt = None

        if isinstance(first_output, str):
            if self._is_http_url(first_output):
                logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
                url_response = self._fetch_output_url(
                    first_output,
                    timeout,
                    what="prompt",
                    error_detail="Failed to fetch optimized prompt from WaveSpeed URL",
                )
                optimized_prompt = url_response.text.strip()
            else:
                # The output is already the optimized text.
                optimized_prompt = first_output
        elif isinstance(first_output, dict):
            optimized_prompt = (
                first_output.get("text") or first_output.get("prompt") or first_output.get("output")
            )

        # Reject empty values and non-string payloads (the caller returns a str).
        if not optimized_prompt or not isinstance(optimized_prompt, str):
            logger.error(f"[WaveSpeed] Could not extract optimized prompt from outputs: {outputs}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed prompt optimizer output format not recognized",
            )
        return optimized_prompt

    def _download_audio(self, outputs: Any, timeout: int) -> bytes:
        """
        Resolve the audio URL from speech ``outputs`` and download its bytes.

        The first output may be the URL itself or a dict carrying it under
        ``url`` or ``output``.

        Raises:
            HTTPException: 502 when no http(s) URL is found or the download fails.
        """
        first_output = self._first_output(outputs)
        audio_url = None
        if isinstance(first_output, str):
            audio_url = first_output
        elif isinstance(first_output, dict):
            audio_url = first_output.get("url") or first_output.get("output")

        if not self._is_http_url(audio_url):
            logger.error(f"[WaveSpeed] Invalid audio URL in outputs: {outputs}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed speech generator output format not recognized",
            )

        logger.info(f"[WaveSpeed] Fetching audio from URL: {audio_url}")
        audio_response = self._fetch_output_url(
            audio_url,
            timeout,
            what="audio",
            error_detail="Failed to fetch generated audio from WaveSpeed URL",
        )
        audio_bytes = audio_response.content
        logger.info(f"[WaveSpeed] Speech generated successfully (size: {len(audio_bytes)} bytes)")
        return audio_bytes

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def submit_image_to_video(
        self,
        model_path: str,
        payload: Dict[str, Any],
        timeout: int = 30,
    ) -> str:
        """
        Submit an image-to-video generation request.

        Args:
            model_path: Model path appended to ``BASE_URL``.
            payload: JSON body sent to the model endpoint.
            timeout: Request timeout in seconds (default: 30).

        Returns:
            The prediction ID for polling.

        Raises:
            HTTPException: 502/504 when the request fails, or 502 when the
                response carries no ``data.id``.
        """
        url = f"{self.BASE_URL}/{model_path}"
        logger.info(f"[WaveSpeed] Submitting request to {url}")
        response = self._post_json(
            url,
            payload,
            timeout,
            log_label="Submission",
            error="WaveSpeed image-to-video submission failed",
        )

        data = self._parse_json(response).get("data")
        if not isinstance(data, dict) or "id" not in data:
            logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}")
            raise HTTPException(
                status_code=502,
                detail={"error": "WaveSpeed response missing prediction id"},
            )

        prediction_id = data["id"]
        logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
        return prediction_id

    def get_prediction_result(self, prediction_id: str, timeout: int = 120) -> Dict[str, Any]:
        """
        Fetch the current status/result for a prediction.

        Args:
            prediction_id: ID returned by a submission call.
            timeout: Request timeout in seconds (default: 120).

        Returns:
            The ``data`` object of the WaveSpeed response.

        Raises:
            HTTPException: 504 when the request times out; 502 for other
                transport failures, a non-200 status, or a missing/invalid
                ``data`` object. Transport failures set ``resume_available``
                in the detail.
        """
        url = f"{self.BASE_URL}/predictions/{prediction_id}/result"
        try:
            response = requests.get(url, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=timeout)
        except requests_exceptions.Timeout as exc:
            raise HTTPException(
                status_code=504,
                detail={
                    "error": "WaveSpeed polling request timed out",
                    "prediction_id": prediction_id,
                    "resume_available": True,
                    "exception": str(exc),
                },
            ) from exc
        except requests_exceptions.RequestException as exc:
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed polling request failed",
                    "prediction_id": prediction_id,
                    "resume_available": True,
                    "exception": str(exc),
                },
            ) from exc

        if response.status_code != 200:
            logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed prediction polling failed",
                    "status_code": response.status_code,
                    "response": response.text,
                },
            )

        result = self._parse_json(response).get("data")
        if not result or not isinstance(result, dict):
            raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
        return result

    def poll_until_complete(
        self,
        prediction_id: str,
        timeout_seconds: int = 240,
        interval_seconds: float = 1.0,
    ) -> Dict[str, Any]:
        """
        Poll WaveSpeed until the job completes, fails, or times out.

        Args:
            prediction_id: ID of the prediction to watch.
            timeout_seconds: Overall polling budget in seconds (default: 240).
            interval_seconds: Pause between status requests (default: 1.0).

        Returns:
            The result dict whose ``status`` is ``"completed"``.

        Raises:
            HTTPException: 502 when the job reports ``"failed"``; 504 when the
                budget is exhausted; any polling error is re-raised with
                ``prediction_id`` and ``resume_available`` added to its detail.
        """
        # Monotonic clock: the deadline must not move if the system clock is adjusted.
        start_time = time.monotonic()
        while True:
            try:
                result = self.get_prediction_result(prediction_id)
            except HTTPException as exc:
                detail = exc.detail or {}
                if isinstance(detail, dict):
                    # Work on a copy so the caught exception is left untouched.
                    detail = dict(detail)
                    detail.setdefault("prediction_id", prediction_id)
                    detail.setdefault("resume_available", True)
                    detail.setdefault("error", "WaveSpeed polling failed")
                raise HTTPException(status_code=exc.status_code, detail=detail) from exc

            status = result.get("status")
            if status == "completed":
                logger.info(f"[WaveSpeed] Prediction {prediction_id} completed.")
                return result

            if status == "failed":
                logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {result.get('error')}")
                raise HTTPException(
                    status_code=502,
                    detail={
                        "error": "WaveSpeed animation failed",
                        "prediction_id": prediction_id,
                        "details": result.get("error"),
                    },
                )

            elapsed = time.monotonic() - start_time
            if elapsed >= timeout_seconds:
                logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
                raise HTTPException(
                    status_code=504,
                    detail={
                        "error": "WaveSpeed animation timed out",
                        "prediction_id": prediction_id,
                        "details": result,
                    },
                )

            logger.debug(f"[WaveSpeed] Prediction {prediction_id} status={status}. Waiting...")
            # Never sleep past the deadline; the final check happens right at it.
            time.sleep(max(0.0, min(interval_seconds, timeout_seconds - elapsed)))

    def optimize_prompt(
        self,
        text: str,
        mode: str = "image",
        style: str = "default",
        image: Optional[str] = None,
        enable_sync_mode: bool = True,
        timeout: int = 30,
    ) -> str:
        """
        Optimize a prompt using WaveSpeed prompt optimizer.

        Args:
            text: The prompt text to optimize
            mode: "image" or "video" (default: "image")
            style: "default", "artistic", "photographic", "technical", "anime", "realistic" (default: "default")
            image: Base64-encoded image for context (optional)
            enable_sync_mode: If True, wait for result and return it directly (default: True)
            timeout: Request timeout in seconds (default: 30)

        Returns:
            Optimized prompt text

        Raises:
            HTTPException: 502/504 when a request fails, the job fails or
                times out, or the outputs cannot be interpreted.
        """
        model_path = "wavespeed-ai/prompt-optimizer"
        url = f"{self.BASE_URL}/{model_path}"

        payload = {
            "text": text,
            "mode": mode,
            "style": style,
            "enable_sync_mode": enable_sync_mode,
        }
        if image:
            payload["image"] = image

        logger.info(f"[WaveSpeed] Optimizing prompt via {url} (mode={mode}, style={style})")
        response = self._post_json(
            url,
            payload,
            timeout,
            log_label="Prompt optimization",
            error="WaveSpeed prompt optimization failed",
        )
        data = self._unwrap_data(self._parse_json(response))

        if enable_sync_mode:
            # Sync mode - the result should be directly in outputs.
            outputs = data.get("outputs") or []
            if not outputs:
                logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
                raise HTTPException(
                    status_code=502,
                    detail="WaveSpeed prompt optimizer returned no outputs",
                )
        else:
            # Async mode - the response carries a prediction ID to poll.
            prediction_id = data.get("id")
            if not prediction_id:
                logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
                raise HTTPException(
                    status_code=502,
                    detail="WaveSpeed response missing prediction id for async mode",
                )
            result = self.poll_until_complete(prediction_id, timeout_seconds=60, interval_seconds=0.5)
            outputs = result.get("outputs") or []
            if not outputs:
                raise HTTPException(status_code=502, detail="WaveSpeed prompt optimizer returned no outputs")

        optimized_prompt = self._extract_prompt(outputs, timeout)
        logger.info(f"[WaveSpeed] Prompt optimized successfully (length: {len(optimized_prompt)} chars)")
        return optimized_prompt

    def generate_speech(
        self,
        text: str,
        voice_id: str,
        speed: float = 1.0,
        volume: float = 1.0,
        pitch: float = 0.0,
        emotion: str = "happy",
        enable_sync_mode: bool = True,
        timeout: int = 60,
        **kwargs
    ) -> bytes:
        """
        Generate speech audio using Minimax Speech 02 HD via WaveSpeed.

        Args:
            text: Text to convert to speech (max 10000 characters)
            voice_id: Voice ID (e.g., "Wise_Woman", "Friendly_Person", etc.)
            speed: Speech speed (0.5-2.0, default: 1.0)
            volume: Speech volume (0.1-10.0, default: 1.0)
            pitch: Speech pitch (-12 to 12, default: 0.0)
            emotion: Emotion ("happy", "sad", "angry", etc., default: "happy")
            enable_sync_mode: If True, wait for result and return it directly (default: True)
            timeout: Request timeout in seconds (default: 60)
            **kwargs: Additional parameters (sample_rate, bitrate, format, etc.);
                only the names in ``_SPEECH_OPTIONAL_PARAMS`` are forwarded.

        Returns:
            bytes: Generated audio bytes

        Raises:
            HTTPException: 502/504 when a request fails, the job fails or
                times out, or no downloadable audio URL is returned.
        """
        model_path = "minimax/speech-02-hd"
        url = f"{self.BASE_URL}/{model_path}"

        payload = {
            "text": text,
            "voice_id": voice_id,
            "speed": speed,
            "volume": volume,
            "pitch": pitch,
            "emotion": emotion,
            "enable_sync_mode": enable_sync_mode,
        }
        # Add optional parameters the caller actually supplied.
        payload.update(
            {param: kwargs[param] for param in self._SPEECH_OPTIONAL_PARAMS if param in kwargs}
        )

        logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
        response = self._post_json(
            url,
            payload,
            timeout,
            log_label="Speech generation",
            error="WaveSpeed speech generation failed",
        )
        data = self._unwrap_data(self._parse_json(response))

        if enable_sync_mode:
            # Sync mode - the result should be directly in outputs.
            outputs = data.get("outputs") or []
            if not outputs:
                logger.error(f"[WaveSpeed] No outputs in sync mode response: {response.text}")
                raise HTTPException(
                    status_code=502,
                    detail="WaveSpeed speech generator returned no outputs",
                )
        else:
            # Async mode - the response carries a prediction ID to poll.
            prediction_id = data.get("id")
            if not prediction_id:
                logger.error(f"[WaveSpeed] No prediction ID in async response: {response.text}")
                raise HTTPException(
                    status_code=502,
                    detail="WaveSpeed response missing prediction id for async mode",
                )
            result = self.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=0.5)
            outputs = result.get("outputs") or []
            if not outputs:
                raise HTTPException(status_code=502, detail="WaveSpeed speech generator returned no outputs")

        return self._download_audio(outputs, timeout)
|
|
|