AI Researcher and Video Studio implementation complete
This commit is contained in:
9
backend/services/wavespeed/generators/video/__init__.py
Normal file
9
backend/services/wavespeed/generators/video/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Video generation generator for WaveSpeed API.
|
||||
|
||||
Modular implementation with separate modules for different video operations.
|
||||
"""
|
||||
|
||||
from .generator import VideoGenerator
|
||||
|
||||
__all__ = ["VideoGenerator"]
|
||||
244
backend/services/wavespeed/generators/video/audio.py
Normal file
244
backend/services/wavespeed/generators/video/audio.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Video audio generation operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.audio")
|
||||
|
||||
|
||||
class VideoAudio(VideoBase):
|
||||
"""Video audio generation operations."""
|
||||
|
||||
def hunyuan_video_foley(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: Optional[str] = None, # Optional text prompt describing desired sounds
|
||||
seed: int = -1, # Random seed (-1 for random)
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate realistic Foley and ambient audio from video using Hunyuan Video Foley.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
prompt: Optional text prompt describing desired sounds (e.g., "ocean waves, seagulls")
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Video with generated audio
|
||||
|
||||
Raises:
|
||||
HTTPException: If the audio generation fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/hunyuan-video-foley"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
if prompt:
|
||||
payload["prompt"] = prompt
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Hunyuan Video Foley request via {url} "
|
||||
f"(has_prompt={prompt is not None}, seed={seed})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Hunyuan Video Foley submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Hunyuan Video Foley submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in Hunyuan Video Foley response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Hunyuan Video Foley response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Hunyuan Video Foley task submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Hunyuan Video Foley returned no outputs")
|
||||
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Hunyuan Video Foley output format not recognized")
|
||||
|
||||
logger.info(f"[WaveSpeed] Downloading video with audio from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download video with audio: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download video with audio from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Hunyuan Video Foley completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for Hunyuan Video Foley",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
|
||||
def think_sound(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: Optional[str] = None, # Optional text prompt describing desired sounds
|
||||
seed: int = -1, # Random seed (-1 for random)
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate realistic sound effects and audio tracks from video using Think Sound.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
prompt: Optional text prompt describing desired sounds (e.g., "engine roaring, footsteps on gravel")
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Video with generated audio
|
||||
|
||||
Raises:
|
||||
HTTPException: If the audio generation fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/think-sound"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"seed": seed,
|
||||
}
|
||||
|
||||
if prompt:
|
||||
payload["prompt"] = prompt
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Think Sound request via {url} "
|
||||
f"(has_prompt={prompt is not None}, seed={seed})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Think Sound submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed Think Sound submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in Think Sound response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed Think Sound response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Think Sound task submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Think Sound returned no outputs")
|
||||
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed Think Sound output format not recognized")
|
||||
|
||||
logger.info(f"[WaveSpeed] Downloading video with audio from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download video with audio: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download video with audio from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Think Sound completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for Think Sound",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
127
backend/services/wavespeed/generators/video/background.py
Normal file
127
backend/services/wavespeed/generators/video/background.py
Normal file
@@ -0,0 +1,127 @@
|
||||
"""
|
||||
Video background removal operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.background")
|
||||
|
||||
|
||||
class VideoBackground(VideoBase):
|
||||
"""Video background removal operations."""
|
||||
|
||||
def remove_background(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
background_image: Optional[str] = None, # Base64-encoded image or URL (optional)
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Remove or replace video background using Video Background Remover.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
background_image: Optional base64-encoded image data URI or public URL (replacement background)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Video with background removed/replaced
|
||||
|
||||
Raises:
|
||||
HTTPException: If the background removal fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/video-background-remover"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
}
|
||||
|
||||
if background_image:
|
||||
payload["background_image"] = background_image
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Video background removal request via {url} "
|
||||
f"(has_background={background_image is not None})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video background removal submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video background removal submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in video background removal response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed video background removal response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Video background removal task submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video background removal returned no outputs")
|
||||
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video background removal output format not recognized")
|
||||
|
||||
logger.info(f"[WaveSpeed] Downloading processed video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download processed video: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download processed video from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video background removal completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for video background removal",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
84
backend/services/wavespeed/generators/video/base.py
Normal file
84
backend/services/wavespeed/generators/video/base.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""
|
||||
Base functionality for video operations.
|
||||
|
||||
Shared utilities for HTTP requests, video download, and common operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.base")
|
||||
|
||||
|
||||
class VideoBase:
|
||||
"""Base class for video operations with shared functionality."""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize video base.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.polling = polling
|
||||
|
||||
def _get_headers(self) -> dict:
|
||||
"""Get HTTP headers for API requests."""
|
||||
return {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
}
|
||||
|
||||
def _download_video(self, video_url: str, timeout: int = 180) -> bytes:
|
||||
"""Download video from URL.
|
||||
|
||||
Args:
|
||||
video_url: URL to download video from
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
bytes: Video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If download fails
|
||||
"""
|
||||
logger.info(f"[WaveSpeed] Downloading video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
}
|
||||
)
|
||||
|
||||
return video_response.content
|
||||
|
||||
def _extract_video_url(self, outputs: list) -> Optional[str]:
|
||||
"""Extract video URL from outputs array.
|
||||
|
||||
Args:
|
||||
outputs: Array of outputs (can be strings or dicts)
|
||||
|
||||
Returns:
|
||||
Optional[str]: Video URL if found, None otherwise
|
||||
"""
|
||||
if not outputs:
|
||||
return None
|
||||
|
||||
output = outputs[0]
|
||||
if isinstance(output, str):
|
||||
return output if output.startswith("http") else None
|
||||
elif isinstance(output, dict):
|
||||
return output.get("url") or output.get("video_url")
|
||||
|
||||
return None
|
||||
109
backend/services/wavespeed/generators/video/enhancement.py
Normal file
109
backend/services/wavespeed/generators/video/enhancement.py
Normal file
@@ -0,0 +1,109 @@
|
||||
"""
|
||||
Video enhancement operations (upscaling).
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.enhancement")
|
||||
|
||||
|
||||
class VideoEnhancement(VideoBase):
|
||||
"""Video enhancement operations."""
|
||||
|
||||
def upscale_video(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
target_resolution: str = "1080p",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Upscale video using FlashVSR.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL
|
||||
target_resolution: Target resolution ("720p", "1080p", "2k", "4k")
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300 for long videos)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Upscaled video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the upscaling fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/flashvsr"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"video": video,
|
||||
"target_resolution": target_resolution,
|
||||
}
|
||||
|
||||
logger.info(f"[WaveSpeed] Upscaling video via {url} (target={target_resolution})")
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] FlashVSR submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed FlashVSR submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in FlashVSR response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed FlashVSR response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] FlashVSR task submitted: {prediction_id}")
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0, # Longer interval for upscaling (slower process)
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed FlashVSR returned no outputs")
|
||||
|
||||
video_url = outputs[0] if isinstance(outputs[0], str) else outputs[0].get("url")
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed FlashVSR output format not recognized")
|
||||
|
||||
# Download the upscaled video
|
||||
logger.info(f"[WaveSpeed] Downloading upscaled video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download upscaled video: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download upscaled video from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video upscaling completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
161
backend/services/wavespeed/generators/video/extension.py
Normal file
161
backend/services/wavespeed/generators/video/extension.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""
|
||||
Video extension operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.extension")
|
||||
|
||||
|
||||
class VideoExtension(VideoBase):
|
||||
"""Video extension operations."""
|
||||
|
||||
def extend_video(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: str,
|
||||
model: str = "wan-2.5", # "wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro"
|
||||
audio: Optional[str] = None, # Optional audio URL (WAN 2.5 only)
|
||||
negative_prompt: Optional[str] = None, # WAN 2.5 only
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
enable_prompt_expansion: bool = False, # WAN 2.5 only
|
||||
generate_audio: bool = True, # Seedance 1.5 Pro only
|
||||
camera_fixed: bool = False, # Seedance 1.5 Pro only
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Extend video duration using WAN 2.5, WAN 2.2 Spicy, or Seedance 1.5 Pro video-extend.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL
|
||||
prompt: Text prompt describing how to extend the video
|
||||
model: Model to use ("wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro")
|
||||
audio: Optional audio URL to guide generation (WAN 2.5 only)
|
||||
negative_prompt: Optional negative prompt (WAN 2.5 only)
|
||||
resolution: Output resolution (varies by model)
|
||||
duration: Duration of extended video in seconds (varies by model)
|
||||
enable_prompt_expansion: Enable prompt optimizer (WAN 2.5 only)
|
||||
generate_audio: Generate audio for extended video (Seedance 1.5 Pro only)
|
||||
camera_fixed: Fix camera position (Seedance 1.5 Pro only)
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Extended video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the extension fails
|
||||
"""
|
||||
# Determine model path
|
||||
if model in ("wan-2.2-spicy", "wavespeed-ai/wan-2.2-spicy/video-extend"):
|
||||
model_path = "wavespeed-ai/wan-2.2-spicy/video-extend"
|
||||
elif model in ("seedance-1.5-pro", "bytedance/seedance-v1.5-pro/video-extend"):
|
||||
model_path = "bytedance/seedance-v1.5-pro/video-extend"
|
||||
else:
|
||||
# Default to WAN 2.5
|
||||
model_path = "alibaba/wan-2.5/video-extend"
|
||||
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Base payload (common to all models)
|
||||
payload = {
|
||||
"video": video,
|
||||
"prompt": prompt,
|
||||
"resolution": resolution,
|
||||
"duration": duration,
|
||||
}
|
||||
|
||||
# Model-specific parameters
|
||||
if model_path == "alibaba/wan-2.5/video-extend":
|
||||
# WAN 2.5 specific
|
||||
payload["enable_prompt_expansion"] = enable_prompt_expansion
|
||||
if audio:
|
||||
payload["audio"] = audio
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
elif model_path == "bytedance/seedance-v1.5-pro/video-extend":
|
||||
# Seedance 1.5 Pro specific
|
||||
payload["generate_audio"] = generate_audio
|
||||
payload["camera_fixed"] = camera_fixed
|
||||
|
||||
# Seed (all models support it)
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
logger.info(f"[WaveSpeed] Extending video via {url} (duration={duration}s, resolution={resolution})")
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video extend submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video extend submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
prediction_id = data.get("id")
|
||||
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed] No prediction ID in video extend response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed video extend response missing prediction id",
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed] Video extend task submitted: {prediction_id}")
|
||||
|
||||
# Poll for result
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video extend returned no outputs")
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed video extend output format not recognized")
|
||||
|
||||
# Download the extended video
|
||||
logger.info(f"[WaveSpeed] Downloading extended video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Failed to download extended video: {video_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to download extended video from WaveSpeed",
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video extension completed successfully (size: {len(video_bytes)} bytes)")
|
||||
|
||||
return video_bytes
|
||||
283
backend/services/wavespeed/generators/video/face_swap.py
Normal file
283
backend/services/wavespeed/generators/video/face_swap.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Face swap operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.face_swap")
|
||||
|
||||
|
||||
class VideoFaceSwap(VideoBase):
|
||||
"""Face swap operations."""
|
||||
|
||||
def face_swap(
|
||||
self,
|
||||
image: str, # Base64-encoded image or URL
|
||||
video: str, # Base64-encoded video or URL
|
||||
prompt: Optional[str] = None,
|
||||
resolution: str = "480p",
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Perform face/character swap using MoCha (wavespeed-ai/wan-2.1/mocha).
|
||||
|
||||
Args:
|
||||
image: Base64-encoded image data URI or public URL (reference character)
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
prompt: Optional prompt to guide the swap
|
||||
resolution: Output resolution ("480p" or "720p")
|
||||
seed: Random seed for reproducibility (-1 for random)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Face-swapped video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the face swap fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/wan-2.1/mocha"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"image": image,
|
||||
"video": video,
|
||||
}
|
||||
|
||||
if prompt:
|
||||
payload["prompt"] = prompt
|
||||
|
||||
if resolution in ("480p", "720p"):
|
||||
payload["resolution"] = resolution
|
||||
else:
|
||||
payload["resolution"] = "480p" # Default
|
||||
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
else:
|
||||
payload["seed"] = -1 # Random seed
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Face swap request via {url} "
|
||||
f"(resolution={payload['resolution']}, seed={payload['seed']})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Face swap submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed face swap submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected face swap response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Face swap submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
# Poll until complete
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extract video URL from result
|
||||
outputs = result.get("outputs", [])
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap completed but no output video found"},
|
||||
)
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap output format not recognized"},
|
||||
)
|
||||
|
||||
# Download video
|
||||
logger.info(f"[WaveSpeed] Downloading face-swapped video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": f"Failed to download face-swapped video: {video_response.status_code}"},
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Face swap completed: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
else:
|
||||
# Return prediction ID for async polling
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for face swap",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
|
||||
def video_face_swap(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
face_image: str, # Base64-encoded image or URL
|
||||
target_gender: str = "all",
|
||||
target_index: int = 0,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Perform face swap using Video Face Swap (wavespeed-ai/video-face-swap).
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
face_image: Base64-encoded image data URI or public URL (reference face)
|
||||
target_gender: Filter which faces to swap ("all", "female", "male")
|
||||
target_index: Select which face to swap (0 = largest, 1 = second largest, etc.)
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 300)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Face-swapped video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the face swap fails
|
||||
"""
|
||||
model_path = "wavespeed-ai/video-face-swap"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"face_image": face_image,
|
||||
}
|
||||
|
||||
if target_gender in ("all", "female", "male"):
|
||||
payload["target_gender"] = target_gender
|
||||
else:
|
||||
payload["target_gender"] = "all" # Default
|
||||
|
||||
if 0 <= target_index <= 10:
|
||||
payload["target_index"] = target_index
|
||||
else:
|
||||
payload["target_index"] = 0 # Default
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Video face swap request via {url} "
|
||||
f"(target_gender={payload['target_gender']}, target_index={payload['target_index']})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video face swap submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video face swap submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected video face swap response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Video face swap submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
# Poll until complete
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extract video URL from result
|
||||
outputs = result.get("outputs", [])
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video face swap completed but no output video found"},
|
||||
)
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video face swap output format not recognized"},
|
||||
)
|
||||
|
||||
# Download video
|
||||
logger.info(f"[WaveSpeed] Downloading face-swapped video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": f"Failed to download face-swapped video: {video_response.status_code}"},
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video face swap completed: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
else:
|
||||
# Return prediction ID for async polling
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for video face swap",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
333
backend/services/wavespeed/generators/video/generation.py
Normal file
333
backend/services/wavespeed/generators/video/generation.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Video generation operations (text-to-video and image-to-video).
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Any, Dict, Optional
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.generation")
|
||||
|
||||
|
||||
class VideoGeneration(VideoBase):
|
||||
"""Video generation operations."""
|
||||
|
||||
def submit_image_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""
|
||||
Submit an image-to-video generation request.
|
||||
|
||||
Returns the prediction ID for polling.
|
||||
"""
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
logger.info(f"[WaveSpeed] Submitting request to {url}")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image-to-video submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json().get("data")
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected submission response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Submitted request: {prediction_id}")
|
||||
return prediction_id
|
||||
|
||||
def submit_text_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 60,
|
||||
) -> str:
|
||||
"""
|
||||
Submit a text-to-video generation request to WaveSpeed.
|
||||
|
||||
Args:
|
||||
model_path: Model path (e.g., "alibaba/wan-2.5/text-to-video")
|
||||
payload: Request payload with prompt, resolution, duration, optional audio
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Prediction ID for polling
|
||||
"""
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
logger.info(f"[WaveSpeed] Submitting text-to-video request to {url}")
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Text-to-video submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed text-to-video submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
data = response.json().get("data")
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected text-to-video response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Submitted text-to-video request: {prediction_id}")
|
||||
return prediction_id
|
||||
|
||||
def generate_text_video(
|
||||
self,
|
||||
prompt: str,
|
||||
resolution: str = "720p", # 480p, 720p, 1080p
|
||||
duration: int = 5, # 5 or 10 seconds
|
||||
audio_base64: Optional[str] = None, # Optional audio for lip-sync
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_prompt_expansion: bool = True,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 180,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate video from text prompt using WAN 2.5 text-to-video.
|
||||
|
||||
Args:
|
||||
prompt: Text prompt describing the video
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB) for lip-sync
|
||||
negative_prompt: Optional negative prompt
|
||||
seed: Optional random seed for reproducibility
|
||||
enable_prompt_expansion: Enable prompt optimizer
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
model_path = "alibaba/wan-2.5/text-to-video"
|
||||
|
||||
# Validate resolution
|
||||
valid_resolutions = ["480p", "720p", "1080p"]
|
||||
if resolution not in valid_resolutions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid resolution: {resolution}. Must be one of: {valid_resolutions}"
|
||||
)
|
||||
|
||||
# Validate duration
|
||||
if duration not in [5, 10]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Duration must be 5 or 10 seconds"
|
||||
)
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"resolution": resolution,
|
||||
"duration": duration,
|
||||
"enable_prompt_expansion": enable_prompt_expansion,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
}
|
||||
|
||||
# Add optional audio
|
||||
if audio_base64:
|
||||
payload["audio"] = audio_base64
|
||||
|
||||
# Add optional parameters
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Submit request
|
||||
logger.info(
|
||||
f"[WaveSpeed] Generating text-to-video: resolution={resolution}, "
|
||||
f"duration={duration}s, prompt_length={len(prompt)}, sync_mode={enable_sync_mode}"
|
||||
)
|
||||
|
||||
# For sync mode, submit and get result directly
|
||||
if enable_sync_mode:
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Text-to-video submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed text-to-video submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status - if "created" or "processing", we need to poll even in sync mode
|
||||
status = data.get("status", "").lower()
|
||||
outputs = data.get("outputs") or []
|
||||
prediction_id = data.get("id")
|
||||
|
||||
logger.debug(
|
||||
f"[WaveSpeed] Sync mode response: status='{status}', outputs_count={len(outputs)}, "
|
||||
f"prediction_id={prediction_id}"
|
||||
)
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if status == "completed" and outputs:
|
||||
# Sync mode returned completed result - use it directly
|
||||
logger.info(f"[WaveSpeed] Got immediate video results from sync mode (status: {status})")
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
logger.error(f"[WaveSpeed] Invalid video URL format in sync mode: {video_url}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}",
|
||||
)
|
||||
|
||||
video_bytes = self._download_video(video_url)
|
||||
metadata = data.get("metadata") or {}
|
||||
else:
|
||||
# Sync mode returned "created", "processing", or incomplete status - need to poll
|
||||
if not prediction_id:
|
||||
logger.error(
|
||||
f"[WaveSpeed] Sync mode returned status '{status}' but no prediction ID. "
|
||||
f"Response: {response.text[:500]}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed text-to-video sync mode returned async response without prediction ID",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Sync mode returned status '{status}' with {len(outputs)} output(s). "
|
||||
f"Falling back to polling (prediction_id: {prediction_id})"
|
||||
)
|
||||
|
||||
# Poll for completion
|
||||
try:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
)
|
||||
except HTTPException as e:
|
||||
detail = e.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise HTTPException(status_code=e.status_code, detail=detail)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
logger.error(f"[WaveSpeed] Polling completed but no outputs: {result}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed text-to-video completed but returned no outputs",
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
logger.error(f"[WaveSpeed] Invalid video URL format after polling: {video_url}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}",
|
||||
)
|
||||
|
||||
video_bytes = self._download_video(video_url)
|
||||
metadata = result.get("metadata") or {}
|
||||
else:
|
||||
# Async mode - submit and poll
|
||||
prediction_id = self.submit_text_to_video(model_path, payload, timeout=timeout)
|
||||
|
||||
# Poll for completion
|
||||
try:
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0
|
||||
)
|
||||
except HTTPException as e:
|
||||
detail = e.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise HTTPException(status_code=e.status_code, detail=detail)
|
||||
|
||||
# Extract video URL
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WAN 2.5 text-to-video completed but returned no outputs"
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}"
|
||||
)
|
||||
|
||||
video_bytes = self._download_video(video_url)
|
||||
metadata = result.get("metadata") or {}
|
||||
# prediction_id is already set from earlier in the function
|
||||
|
||||
# Calculate cost (same pricing as image-to-video)
|
||||
pricing = {
|
||||
"480p": 0.05,
|
||||
"720p": 0.10,
|
||||
"1080p": 0.15,
|
||||
}
|
||||
cost = pricing.get(resolution, 0.10) * duration
|
||||
|
||||
# Get video dimensions
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
"1080p": (1920, 1080),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] ✅ Generated text-to-video: {len(video_bytes)} bytes, "
|
||||
f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
|
||||
)
|
||||
|
||||
return {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": float(duration),
|
||||
"model_name": "alibaba/wan-2.5/text-to-video",
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
"resolution": resolution,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"metadata": metadata,
|
||||
}
|
||||
263
backend/services/wavespeed/generators/video/generator.py
Normal file
263
backend/services/wavespeed/generators/video/generator.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""
|
||||
Main VideoGenerator class that composes all video operation modules.
|
||||
|
||||
This class maintains backward compatibility with the original monolithic VideoGenerator
|
||||
by delegating to specialized modules for different video operations.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional, Callable
|
||||
|
||||
from .base import VideoBase
|
||||
from .generation import VideoGeneration
|
||||
from .enhancement import VideoEnhancement
|
||||
from .extension import VideoExtension
|
||||
from .face_swap import VideoFaceSwap
|
||||
from .translation import VideoTranslation
|
||||
from .background import VideoBackground
|
||||
from .audio import VideoAudio
|
||||
|
||||
|
||||
class VideoGenerator(VideoBase):
|
||||
"""
|
||||
Video generation generator for WaveSpeed API.
|
||||
|
||||
This class composes multiple specialized modules to provide all video operations
|
||||
while maintaining a single unified interface for backward compatibility.
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: str, polling):
|
||||
"""Initialize video generator.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key
|
||||
base_url: WaveSpeed API base URL
|
||||
polling: WaveSpeedPolling instance for async operations
|
||||
"""
|
||||
super().__init__(api_key, base_url, polling)
|
||||
|
||||
# Initialize specialized modules
|
||||
self._generation = VideoGeneration(api_key, base_url, polling)
|
||||
self._enhancement = VideoEnhancement(api_key, base_url, polling)
|
||||
self._extension = VideoExtension(api_key, base_url, polling)
|
||||
self._face_swap = VideoFaceSwap(api_key, base_url, polling)
|
||||
self._translation = VideoTranslation(api_key, base_url, polling)
|
||||
self._background = VideoBackground(api_key, base_url, polling)
|
||||
self._audio = VideoAudio(api_key, base_url, polling)
|
||||
|
||||
# Generation methods (delegated to VideoGeneration)
|
||||
def submit_image_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 30,
|
||||
) -> str:
|
||||
"""Submit an image-to-video generation request."""
|
||||
return self._generation.submit_image_to_video(model_path, payload, timeout)
|
||||
|
||||
def submit_text_to_video(
|
||||
self,
|
||||
model_path: str,
|
||||
payload: Dict[str, Any],
|
||||
timeout: int = 60,
|
||||
) -> str:
|
||||
"""Submit a text-to-video generation request to WaveSpeed."""
|
||||
return self._generation.submit_text_to_video(model_path, payload, timeout)
|
||||
|
||||
def generate_text_video(
|
||||
self,
|
||||
prompt: str,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
audio_base64: Optional[str] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_prompt_expansion: bool = True,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 180,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate video from text prompt using WAN 2.5 text-to-video."""
|
||||
return self._generation.generate_text_video(
|
||||
prompt=prompt,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
audio_base64=audio_base64,
|
||||
negative_prompt=negative_prompt,
|
||||
seed=seed,
|
||||
enable_prompt_expansion=enable_prompt_expansion,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# Enhancement methods (delegated to VideoEnhancement)
|
||||
def upscale_video(
|
||||
self,
|
||||
video: str,
|
||||
target_resolution: str = "1080p",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Upscale video using FlashVSR."""
|
||||
return self._enhancement.upscale_video(
|
||||
video=video,
|
||||
target_resolution=target_resolution,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extension methods (delegated to VideoExtension)
|
||||
def extend_video(
|
||||
self,
|
||||
video: str,
|
||||
prompt: str,
|
||||
model: str = "wan-2.5",
|
||||
audio: Optional[str] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
enable_prompt_expansion: bool = False,
|
||||
generate_audio: bool = True,
|
||||
camera_fixed: bool = False,
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Extend video duration using WAN 2.5, WAN 2.2 Spicy, or Seedance 1.5 Pro video-extend."""
|
||||
return self._extension.extend_video(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
audio=audio,
|
||||
negative_prompt=negative_prompt,
|
||||
resolution=resolution,
|
||||
duration=duration,
|
||||
enable_prompt_expansion=enable_prompt_expansion,
|
||||
generate_audio=generate_audio,
|
||||
camera_fixed=camera_fixed,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Face swap methods (delegated to VideoFaceSwap)
|
||||
def face_swap(
|
||||
self,
|
||||
image: str,
|
||||
video: str,
|
||||
prompt: Optional[str] = None,
|
||||
resolution: str = "480p",
|
||||
seed: Optional[int] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Perform face/character swap using MoCha (wavespeed-ai/wan-2.1/mocha)."""
|
||||
return self._face_swap.face_swap(
|
||||
image=image,
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
resolution=resolution,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
def video_face_swap(
|
||||
self,
|
||||
video: str,
|
||||
face_image: str,
|
||||
target_gender: str = "all",
|
||||
target_index: int = 0,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Perform face swap using Video Face Swap (wavespeed-ai/video-face-swap)."""
|
||||
return self._face_swap.video_face_swap(
|
||||
video=video,
|
||||
face_image=face_image,
|
||||
target_gender=target_gender,
|
||||
target_index=target_index,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Translation methods (delegated to VideoTranslation)
|
||||
def video_translate(
|
||||
self,
|
||||
video: str,
|
||||
output_language: str = "English",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 600,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Translate video to target language using HeyGen Video Translate."""
|
||||
return self._translation.video_translate(
|
||||
video=video,
|
||||
output_language=output_language,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Background methods (delegated to VideoBackground)
|
||||
def remove_background(
|
||||
self,
|
||||
video: str,
|
||||
background_image: Optional[str] = None,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Remove or replace video background using Video Background Remover."""
|
||||
return self._background.remove_background(
|
||||
video=video,
|
||||
background_image=background_image,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Audio methods (delegated to VideoAudio)
|
||||
def hunyuan_video_foley(
|
||||
self,
|
||||
video: str,
|
||||
prompt: Optional[str] = None,
|
||||
seed: int = -1,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Generate realistic Foley and ambient audio from video using Hunyuan Video Foley."""
|
||||
return self._audio.hunyuan_video_foley(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
def think_sound(
|
||||
self,
|
||||
video: str,
|
||||
prompt: Optional[str] = None,
|
||||
seed: int = -1,
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 300,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""Generate realistic sound effects and audio tracks from video using Think Sound."""
|
||||
return self._audio.think_sound(
|
||||
video=video,
|
||||
prompt=prompt,
|
||||
seed=seed,
|
||||
enable_sync_mode=enable_sync_mode,
|
||||
timeout=timeout,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
133
backend/services/wavespeed/generators/video/translation.py
Normal file
133
backend/services/wavespeed/generators/video/translation.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Video translation operations.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Optional, Callable
|
||||
from fastapi import HTTPException
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .base import VideoBase
|
||||
|
||||
logger = get_service_logger("wavespeed.generators.video.translation")
|
||||
|
||||
|
||||
class VideoTranslation(VideoBase):
|
||||
"""Video translation operations."""
|
||||
|
||||
def video_translate(
|
||||
self,
|
||||
video: str, # Base64-encoded video or URL
|
||||
output_language: str = "English",
|
||||
enable_sync_mode: bool = False,
|
||||
timeout: int = 600,
|
||||
progress_callback: Optional[Callable[[float, str], None]] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Translate video to target language using HeyGen Video Translate.
|
||||
|
||||
Args:
|
||||
video: Base64-encoded video data URI or public URL (source video)
|
||||
output_language: Target language for translation (default: "English")
|
||||
enable_sync_mode: If True, wait for result and return it directly
|
||||
timeout: Request timeout in seconds (default: 600)
|
||||
progress_callback: Optional callback function(progress: float, message: str) for progress updates
|
||||
|
||||
Returns:
|
||||
bytes: Translated video bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If the video translation fails
|
||||
"""
|
||||
model_path = "heygen/video-translate"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"video": video,
|
||||
"output_language": output_language,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed] Video translate request via {url} "
|
||||
f"(output_language={output_language})"
|
||||
)
|
||||
|
||||
# Submit the task
|
||||
response = requests.post(url, headers=self._get_headers(), json=payload, timeout=timeout)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed] Video translate submission failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed video translate submission failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
if not data or "id" not in data:
|
||||
logger.error(f"[WaveSpeed] Unexpected video translate response: {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "WaveSpeed response missing prediction id"},
|
||||
)
|
||||
|
||||
prediction_id = data["id"]
|
||||
logger.info(f"[WaveSpeed] Video translate submitted: {prediction_id}")
|
||||
|
||||
if enable_sync_mode:
|
||||
# Poll until complete
|
||||
result = self.polling.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=timeout,
|
||||
interval_seconds=2.0,
|
||||
progress_callback=progress_callback,
|
||||
)
|
||||
|
||||
# Extract video URL from result
|
||||
outputs = result.get("outputs", [])
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video translate completed but no output video found"},
|
||||
)
|
||||
|
||||
# Handle outputs - can be array of strings or array of objects
|
||||
video_url = None
|
||||
if isinstance(outputs[0], str):
|
||||
video_url = outputs[0]
|
||||
elif isinstance(outputs[0], dict):
|
||||
video_url = outputs[0].get("url") or outputs[0].get("video_url")
|
||||
|
||||
if not video_url:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Video translate output format not recognized"},
|
||||
)
|
||||
|
||||
# Download video
|
||||
logger.info(f"[WaveSpeed] Downloading translated video from: {video_url}")
|
||||
video_response = requests.get(video_url, timeout=timeout)
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": f"Failed to download translated video: {video_response.status_code}"},
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
logger.info(f"[WaveSpeed] Video translate completed: {len(video_bytes)} bytes")
|
||||
return video_bytes
|
||||
else:
|
||||
# Return prediction ID for async polling
|
||||
raise HTTPException(
|
||||
status_code=501,
|
||||
detail={
|
||||
"error": "Async mode not yet implemented for video translate",
|
||||
"prediction_id": prediction_id,
|
||||
},
|
||||
)
|
||||
Reference in New Issue
Block a user