Files
Kunthawat Greethong c35fa52117 Base code
2026-01-08 22:39:53 +07:00

204 lines
8.7 KiB
Python

"""
Polling utilities for WaveSpeed API.
"""
import time
from typing import Any, Dict, Optional, Callable
import requests
from fastapi import HTTPException
from requests import exceptions as requests_exceptions
from utils.logger_utils import get_service_logger
logger = get_service_logger("wavespeed.polling")
class WaveSpeedPolling:
    """Polling utilities for WaveSpeed API predictions.

    Wraps the ``GET {base_url}/predictions/{id}/result`` endpoint and provides
    a blocking loop that waits for a prediction to reach a terminal status.
    All failures are reported as ``fastapi.HTTPException`` with a dict
    ``detail`` describing the problem.
    """

    # Upper bound, in seconds, on the delay between retries after an error.
    _MAX_BACKOFF_SECONDS = 30.0
    # Largest exponent used for exponential backoff. Without a cap,
    # ``interval * 2 ** n`` raises OverflowError once ``n`` exceeds ~1023,
    # which is reachable because transient errors are retried without limit.
    _MAX_BACKOFF_EXPONENT = 64
    # Minimum spacing, in seconds, between periodic "still polling" log lines.
    _PROGRESS_LOG_INTERVAL_SECONDS = 30.0

    def __init__(self, api_key: str, base_url: str):
        """Initialize polling utilities.

        Args:
            api_key: WaveSpeed API key
            base_url: WaveSpeed API base URL
        """
        self.api_key = api_key
        self.base_url = base_url

    def _get_headers(self) -> Dict[str, str]:
        """Get HTTP headers (bearer authorization) for API requests."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @staticmethod
    def _backoff_delay(interval_seconds: float, consecutive_errors: int) -> float:
        """Return the exponential backoff delay for the given error count.

        The delay doubles with every consecutive error, starting at
        ``interval_seconds`` and never exceeding ``_MAX_BACKOFF_SECONDS``.
        The exponent is clamped so arbitrarily large error counts cannot
        overflow the float multiplication.
        """
        exponent = min(
            max(consecutive_errors - 1, 0),
            WaveSpeedPolling._MAX_BACKOFF_EXPONENT,
        )
        return min(
            WaveSpeedPolling._MAX_BACKOFF_SECONDS,
            interval_seconds * (2 ** exponent),
        )

    @staticmethod
    def _timeout_error(
        prediction_id: str,
        timeout_seconds: float,
        status: Optional[str],
    ) -> "HTTPException":
        """Build the 504 exception raised when the polling deadline passes."""
        return HTTPException(
            status_code=504,
            detail={
                "error": "WaveSpeed task timed out",
                "prediction_id": prediction_id,
                "timeout_seconds": timeout_seconds,
                "current_status": status,
                "message": f"Task did not complete within {timeout_seconds} seconds. Status: {status}",
            },
        )

    def get_prediction_result(self, prediction_id: str, timeout: int = 30) -> Dict[str, Any]:
        """
        Fetch the current status/result for a prediction.

        Issues a single GET request and returns the ``data`` member of the
        JSON response body.

        Args:
            prediction_id: The prediction ID to query
            timeout: Timeout in seconds passed to ``requests.get``

        Returns:
            The ``data`` object of the WaveSpeed response

        Raises:
            HTTPException: 504 if the request times out; 502 if the request
                fails, the response status is not 200, the body is not valid
                JSON, or the body has no ``data``.
        """
        # Strip a trailing slash so a base URL such as "https://host/api/"
        # does not yield a "//predictions" path.
        url = f"{self.base_url.rstrip('/')}/predictions/{prediction_id}/result"
        headers = self._get_headers()

        try:
            response = requests.get(url, headers=headers, timeout=timeout)
        except requests_exceptions.Timeout as exc:
            raise HTTPException(
                status_code=504,
                detail={
                    "error": "WaveSpeed polling request timed out",
                    "prediction_id": prediction_id,
                    "resume_available": True,
                    "exception": str(exc),
                },
            ) from exc
        except requests_exceptions.RequestException as exc:
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed polling request failed",
                    "prediction_id": prediction_id,
                    "resume_available": True,
                    "exception": str(exc),
                },
            ) from exc

        if response.status_code != 200:
            # Non-200 status - log and raise, keeping the upstream status code
            # in the detail so callers can tell 4xx from 5xx.
            logger.error(f"[WaveSpeed] Polling failed: {response.status_code} {response.text}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed prediction polling failed",
                    "status_code": response.status_code,
                    "response": response.text,
                },
            )

        try:
            payload = response.json()
        except ValueError as exc:
            # A 200 response with a non-JSON body (e.g. an HTML error page from
            # a proxy) must surface as an HTTPException like every other
            # failure, not as a raw decoding error.
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed polling response was not valid JSON",
                    "prediction_id": prediction_id,
                    "resume_available": True,
                    "exception": str(exc),
                },
            ) from exc

        result = payload.get("data") if isinstance(payload, dict) else None
        if not result:
            raise HTTPException(status_code=502, detail={"error": "WaveSpeed polling response missing data"})
        return result

    def poll_until_complete(
        self,
        prediction_id: str,
        timeout_seconds: Optional[int] = None,
        interval_seconds: float = 1.0,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> Dict[str, Any]:
        """
        Poll WaveSpeed until the job completes or fails.

        Polling errors with a 5xx status are treated as transient and retried
        with exponential backoff for as long as the deadline allows; other
        errors are retried up to a fixed number of consecutive attempts.

        Args:
            prediction_id: The prediction ID to poll for
            timeout_seconds: Optional timeout in seconds. If None, polls
                indefinitely until completion/failure. The deadline is
                enforced on both successful and failed polling attempts.
            interval_seconds: Seconds to wait between polling attempts
                (default: 1.0); also the base of the error backoff
            progress_callback: Optional callback function(progress: float,
                message: str) for progress updates

        Returns:
            Dict containing the completed result

        Raises:
            HTTPException: If the task fails, polling fails, or times out
                (if timeout_seconds is set)
        """
        start_time = time.time()
        consecutive_errors = 0
        max_consecutive_errors = 6  # safety guard for non-transient errors
        status: Optional[str] = None  # last status reported by WaveSpeed, if any
        last_progress_log = 0.0  # elapsed seconds at the last periodic log line

        while True:
            try:
                result = self.get_prediction_result(prediction_id)
            except HTTPException as exc:
                consecutive_errors += 1

                detail = exc.detail or {}
                # Determine underlying status code (WaveSpeed vs proxy). Only a
                # dict detail can carry the upstream code; anything else falls
                # back to the exception's own status.
                status_code = exc.status_code
                if isinstance(detail, dict):
                    detail.setdefault("prediction_id", prediction_id)
                    detail.setdefault("resume_available", True)
                    detail.setdefault("error", "WaveSpeed polling failed")
                    try:
                        status_code = int(detail.get("status_code", exc.status_code))
                    except (TypeError, ValueError):
                        status_code = exc.status_code

                transient = 500 <= status_code < 600

                # For non-transient (typically 4xx) errors, apply safety cap
                if not transient and consecutive_errors >= max_consecutive_errors:
                    logger.error(
                        f"[WaveSpeed] Too many polling errors ({consecutive_errors}) for {prediction_id}, "
                        f"status_code={status_code}. Giving up."
                    )
                    raise HTTPException(status_code=exc.status_code, detail=detail) from exc

                # Honor the caller's deadline even while the endpoint keeps
                # failing; otherwise a persistent outage would poll forever.
                elapsed = time.time() - start_time
                if timeout_seconds is not None and elapsed > timeout_seconds:
                    logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
                    raise self._timeout_error(prediction_id, timeout_seconds, status) from exc

                backoff = self._backoff_delay(interval_seconds, consecutive_errors)
                if timeout_seconds is not None:
                    # Never sleep past the deadline.
                    backoff = min(backoff, max(0.0, timeout_seconds - elapsed))

                if transient:
                    logger.warning(
                        f"[WaveSpeed] Transient polling error {consecutive_errors} for {prediction_id}: "
                        f"{status_code}. Backing off {backoff:.1f}s"
                    )
                else:
                    logger.warning(
                        f"[WaveSpeed] Polling error {consecutive_errors}/{max_consecutive_errors} for {prediction_id}: "
                        f"{status_code}. Backing off {backoff:.1f}s"
                    )
                time.sleep(backoff)
                continue

            consecutive_errors = 0  # Reset error counter on success

            status = result.get("status")

            if status == "completed":
                elapsed = time.time() - start_time
                logger.info(f"[WaveSpeed] Prediction {prediction_id} completed in {elapsed:.1f}s")
                return result

            if status == "failed":
                error_msg = result.get("error", "Unknown error")
                logger.error(f"[WaveSpeed] Prediction {prediction_id} failed: {error_msg}")
                raise HTTPException(
                    status_code=502,
                    detail={
                        "error": "WaveSpeed task failed",
                        "prediction_id": prediction_id,
                        "message": error_msg,
                        "details": result,
                    },
                )

            elapsed = time.time() - start_time

            # Check timeout only if specified
            if timeout_seconds is not None and elapsed > timeout_seconds:
                logger.error(f"[WaveSpeed] Prediction {prediction_id} timed out after {timeout_seconds}s")
                raise self._timeout_error(prediction_id, timeout_seconds, status)

            # Log progress periodically. Tracking the time of the last log line
            # (rather than testing ``int(elapsed) % 30``) avoids logging during
            # the first second and missing or repeating the 30s marks.
            if elapsed - last_progress_log >= self._PROGRESS_LOG_INTERVAL_SECONDS:
                logger.info(f"[WaveSpeed] Polling {prediction_id}: status={status}, elapsed={elapsed:.0f}s")
                last_progress_log = elapsed

            if progress_callback:
                # Map elapsed time to progress (20-80% range during polling),
                # assuming completion takes timeout_seconds or 120s by default.
                estimated_total = timeout_seconds or 120
                progress = min(80.0, 20.0 + (elapsed / estimated_total) * 60.0)
                progress_callback(progress, f"Video generation in progress... ({elapsed:.0f}s)")

            time.sleep(interval_seconds)