Base code
This commit is contained in:
203
backend/services/wavespeed/polling.py
Normal file
203
backend/services/wavespeed/polling.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
Polling utilities for WaveSpeed API.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from requests import exceptions as requests_exceptions
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.polling")
|
||||
|
||||
|
||||
class WaveSpeedPolling:
|
||||
"""Polling utilities for WaveSpeed API predictions."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
"""Initialize polling utilities.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
|
||||
def _get_headers(self) -> Dict[str, str]:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {"Authorization": f"Bearer {self.api_key}"}
|
||||
|
||||
def get_prediction_result(self, prediction_id: str, timeout: int = 30) -> Dict[str, Any]:
|
||||
"""
|
||||
Fetch the current status/result for a prediction.
|
||||
Matches the example pattern: simple GET request, check status_code == 200, return data.
|
||||
"""
|
||||
url = f"{self.base_url}/predictions/{prediction_id}/result"
|
||||
headers = self._get_headers()
|
||||
|
||||
try:
|
||||
response = requests.get(url, headers=headers, timeout=timeout)
|
||||
except requests_exceptions.Timeout as exc:
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed polling request timed out",
|
||||
"prediction_id": prediction_id,
|
||||
"resume_available": True,
|
||||
"exception": str(exc),
|
||||
},
|
||||
) from exc
|
||||
except requests_exceptions.RequestException as exc:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed polling request failed",
|
||||
"prediction_id": prediction_id,
|
||||
"resume_available": True,
|
||||
"exception": str(exc),
|
||||
},
|
||||
) from exc
|
||||
|
||||
# Match example pattern: check status_code == 200, then get data
|
||||
if response.status_code == 200:
|
||||
result = response.json().get("data")
|
||||
if not result:
|
||||
raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
|
||||
return result
|
||||
else:
|
||||
# Non-200 status - log and raise error (matching example's break behavior)
|
||||
logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed prediction polling failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
def poll_until_complete(
|
||||
self,
|
||||
prediction_id: str,
|
||||
timeout_seconds: Optional[int] = None,
|
||||
interval_seconds: float = 1.0,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Poll WaveSpeed until the job completes or fails.
|
||||
Matches the example pattern: simple polling loop until status is "completed" or "failed".
|
||||
|
||||
Args:
|
||||
prediction_id: The prediction ID to poll for
|
||||
timeout_seconds: Optional timeout in seconds. If None, polls indefinitely until completion/failure.
|
||||
interval_seconds: Seconds to wait between polling attempts (default: 1.0, faster than 2.0)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
Dict containing the completed result
|
||||
|
||||
Raises:
|
||||
HTTPException: If the task fails, polling fails, or times out (if timeout_seconds is set)
|
||||
"""
|
||||
start_time = time.time()
|
||||
consecutive_errors = 0
|
||||
max_consecutive_errors = 6 # safety guard for non-transient errors
|
||||
|
||||
while True:
|
||||
try:
|
||||
result = self.get_prediction_result(prediction_id)
|
||||
consecutive_errors = 0 # Reset error counter on success
|
||||
except HTTPException as exc:
|
||||
detail = exc.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
detail.setdefault("error", detail.get("error", "WaveSpeed polling failed"))
|
||||
|
||||
# Determine underlying status code (WaveSpeed vs proxy)
|
||||
status_code = detail.get("status_code", exc.status_code)
|
||||
|
||||
# Treat 5xx as transient: keep polling indefinitely with backoff
|
||||
if 500 <= int(status_code) < 600:
|
||||
consecutive_errors += 1
|
||||
backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1)))
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Transient polling error {consecutive_errors} for {prediction_id}: "
|
||||
f"{status_code}. Backing off {backoff:.1f}s"
|
||||
)
|
||||
time.sleep(backoff)
|
||||
continue
|
||||
|
||||
# For non-transient (typically 4xx) errors, apply safety cap
|
||||
consecutive_errors += 1
|
||||
if consecutive_errors >= max_consecutive_errors:
|
||||
logger.error(
|
||||
f"[WaveSpeed] Too many polling errors ({consecutive_errors}) for {prediction_id}, "
|
||||
f"status_code={status_code}. Giving up."
|
||||
)
|
||||
raise HTTPException(status_code=exc.status_code, detail=detail) from exc
|
||||
|
||||
backoff = min(30.0, interval_seconds * (2 ** (consecutive_errors - 1)))
|
||||
logger.warning(
|
||||
f"[WaveSpeed] Polling error {consecutive_errors}/{max_consecutive_errors} for {prediction_id}: "
|
||||
f"{status_code}. Backing off {backoff:.1f}s"
|
||||
)
|
||||
time.sleep(backoff)
|
||||
continue
|
||||
|
||||
# Extract status from result (matching example pattern)
|
||||
status = result.get("status")
|
||||
|
||||
if status == "completed":
|
||||
elapsed = time.time() - start_time
|
||||
logger.info(f"[WaveSpeed] Prediction {prediction_id} completed in {elapsed:.1f}s")
|
||||
return result
|
||||
|
||||
if status == "failed":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {error_msg}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed task failed",
|
||||
"prediction_id": prediction_id,
|
||||
"message": error_msg,
|
||||
"details": result,
|
||||
},
|
||||
)
|
||||
|
||||
# Check timeout only if specified
|
||||
if timeout_seconds is not None:
|
||||
elapsed = time.time() - start_time
|
||||
if elapsed > timeout_seconds:
|
||||
logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
|
||||
raise HTTPException(
|
||||
status_code=504,
|
||||
detail={
|
||||
"error": "WaveSpeed task timed out",
|
||||
"prediction_id": prediction_id,
|
||||
"timeout_seconds": timeout_seconds,
|
||||
"current_status": status,
|
||||
"message": f"Task did not complete within {timeout_seconds} seconds. Status: {status}",
|
||||
},
|
||||
)
|
||||
|
||||
# Log progress periodically (every 30 seconds)
|
||||
elapsed = time.time() - start_time
|
||||
if int(elapsed) % 30 == 0 and elapsed > 0:
|
||||
logger.info(f"[WaveSpeed] Polling {prediction_id}: status={status}, elapsed={elapsed:.0f}s")
|
||||
|
||||
# Call progress callback if provided
|
||||
if progress_callback:
|
||||
# Map elapsed time to progress (20-80% range during polling)
|
||||
# Assume typical completion time is timeout_seconds or 120s default
|
||||
estimated_total = timeout_seconds or 120
|
||||
progress = min(80.0, 20.0 + (elapsed / estimated_total) * 60.0)
|
||||
progress_callback(progress, f"Video generation in progress... ({elapsed:.0f}s)")
|
||||
|
||||
# Poll faster (1.0s instead of 2.0s) to match example's responsiveness
|
||||
time.sleep(interval_seconds)
|
||||
Reference in New Issue
Block a user