Files

162 lines
6.5 KiB
Python

"""
Video extension operations.
"""
import requests
from typing import Optional, Callable
from fastapi import HTTPException
from utils.logger_utils import get_service_logger
from .base import VideoBase
# Module-level logger used for all WaveSpeed video-extension log messages below.
logger = get_service_logger("wavespeed.generators.video.extension")
class VideoExtension(VideoBase):
    """Video extension operations.

    Uses ``self.base_url``, ``self._get_headers()`` and ``self.polling``,
    which are not defined here — presumably provided by ``VideoBase``
    (TODO confirm).
    """

    # WaveSpeed video-extend endpoint paths, relative to ``self.base_url``.
    _WAN_25_EXTEND_PATH = "alibaba/wan-2.5/video-extend"
    _WAN_22_SPICY_EXTEND_PATH = "wavespeed-ai/wan-2.2-spicy/video-extend"
    _SEEDANCE_15_PRO_EXTEND_PATH = "bytedance/seedance-v1.5-pro/video-extend"

    # Accepted ``model`` values (short aliases and full endpoint paths) mapped
    # to the endpoint path that serves them.
    _EXTEND_MODEL_PATHS = {
        "wan-2.5": _WAN_25_EXTEND_PATH,
        _WAN_25_EXTEND_PATH: _WAN_25_EXTEND_PATH,
        "wan-2.2-spicy": _WAN_22_SPICY_EXTEND_PATH,
        _WAN_22_SPICY_EXTEND_PATH: _WAN_22_SPICY_EXTEND_PATH,
        "seedance-1.5-pro": _SEEDANCE_15_PRO_EXTEND_PATH,
        _SEEDANCE_15_PRO_EXTEND_PATH: _SEEDANCE_15_PRO_EXTEND_PATH,
    }

    def extend_video(
        self,
        video: str,  # Base64-encoded video or URL
        prompt: str,
        model: str = "wan-2.5",  # "wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro"
        audio: Optional[str] = None,  # Optional audio URL (WAN 2.5 only)
        negative_prompt: Optional[str] = None,  # WAN 2.5 only
        resolution: str = "720p",
        duration: int = 5,
        enable_prompt_expansion: bool = False,  # WAN 2.5 only
        generate_audio: bool = True,  # Seedance 1.5 Pro only
        camera_fixed: bool = False,  # Seedance 1.5 Pro only
        seed: Optional[int] = None,
        enable_sync_mode: bool = False,
        timeout: int = 300,
        progress_callback: Optional[Callable[[float, str], None]] = None,
    ) -> bytes:
        """
        Extend video duration using WAN 2.5, WAN 2.2 Spicy, or Seedance 1.5 Pro video-extend.

        Submits a video-extend task to WaveSpeed, polls it until completion,
        then downloads and returns the first output video.

        Args:
            video: Base64-encoded video data URI or public URL
            prompt: Text prompt describing how to extend the video
            model: Model to use ("wan-2.5", "wan-2.2-spicy", or "seedance-1.5-pro";
                the full endpoint paths are accepted too). Unrecognized values
                fall back to WAN 2.5 and log a warning.
            audio: Optional audio URL to guide generation (WAN 2.5 only; ignored otherwise)
            negative_prompt: Optional negative prompt (WAN 2.5 only; ignored otherwise)
            resolution: Output resolution (varies by model; forwarded as-is)
            duration: Duration of extended video in seconds (varies by model; forwarded as-is)
            enable_prompt_expansion: Enable prompt optimizer (WAN 2.5 only)
            generate_audio: Generate audio for extended video (Seedance 1.5 Pro only)
            camera_fixed: Fix camera position (Seedance 1.5 Pro only)
            seed: Random seed for reproducibility (-1 for random); omitted
                from the request when None
            enable_sync_mode: Currently NOT forwarded to WaveSpeed — the task
                is always submitted and then polled. Kept for interface
                compatibility.
            timeout: Seconds; used as the HTTP timeout for submission and
                download, and as the overall polling timeout (default: 300)
            progress_callback: Optional callback function(progress: float, message: str),
                passed through to the poller for progress updates
        Returns:
            bytes: Extended video bytes
        Raises:
            HTTPException: 502 if submission fails (network error, non-200
                status, non-JSON body, missing prediction id), if the result
                has no usable output URL, or if the download fails or is empty.
                Errors raised by the poller propagate unchanged.
        """
        model_path = self._resolve_extend_model_path(model)
        url = f"{self.base_url}/{model_path}"

        # Base payload (common to all models)
        payload = {
            "video": video,
            "prompt": prompt,
            "resolution": resolution,
            "duration": duration,
        }
        # Model-specific parameters; WAN 2.2 Spicy takes only the common ones.
        if model_path == self._WAN_25_EXTEND_PATH:
            payload["enable_prompt_expansion"] = enable_prompt_expansion
            if audio:
                payload["audio"] = audio
            if negative_prompt:
                payload["negative_prompt"] = negative_prompt
        elif model_path == self._SEEDANCE_15_PRO_EXTEND_PATH:
            payload["generate_audio"] = generate_audio
            payload["camera_fixed"] = camera_fixed
        # Seed (all models support it); omitted entirely when not provided.
        if seed is not None:
            payload["seed"] = seed
        # NOTE(review): enable_sync_mode is accepted but deliberately not sent;
        # confirm whether the video-extend endpoints support it before wiring it up.

        logger.info(f"[WaveSpeed] Extending video via {url} (duration={duration}s, resolution={resolution})")

        prediction_id = self._submit_extend_task(url, payload, timeout)
        logger.info(f"[WaveSpeed] Video extend task submitted: {prediction_id}")

        # Poll for result
        result = self.polling.poll_until_complete(
            prediction_id,
            timeout_seconds=timeout,
            interval_seconds=2.0,
            progress_callback=progress_callback,
        )

        video_url = self._extract_extend_video_url(result)
        return self._download_extended_video(video_url, timeout)

    @classmethod
    def _resolve_extend_model_path(cls, model: str) -> str:
        """Map a ``model`` alias or endpoint path to its video-extend endpoint path.

        Unknown values fall back to WAN 2.5 (the historical behavior) but are
        logged so that typos do not silently select the wrong model.
        """
        model_path = cls._EXTEND_MODEL_PATHS.get(model)
        if model_path is None:
            logger.warning(
                f"[WaveSpeed] Unknown video extend model {model!r}; "
                f"defaulting to {cls._WAN_25_EXTEND_PATH}"
            )
            model_path = cls._WAN_25_EXTEND_PATH
        return model_path

    def _submit_extend_task(self, url: str, payload: dict, timeout: int) -> str:
        """Submit a video-extend task and return its WaveSpeed prediction id.

        Raises:
            HTTPException: 502 on network error, non-200 status, a non-JSON
                body, or a response that carries no prediction id.
        """
        headers = self._get_headers()
        try:
            response = requests.post(url, headers=headers, json=payload, timeout=timeout)
        except requests.RequestException as exc:
            logger.error(f"[WaveSpeed] Video extend submission request error: {exc}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed video extend submission failed",
                    "response": str(exc),
                },
            ) from exc

        if response.status_code != 200:
            logger.error(f"[WaveSpeed] Video extend submission failed: {response.status_code} {response.text}")
            raise HTTPException(
                status_code=502,
                detail={
                    "error": "WaveSpeed video extend submission failed",
                    "status_code": response.status_code,
                    "response": response.text,
                },
            )

        try:
            response_json = response.json()
        except ValueError as exc:  # requests' JSON decode errors subclass ValueError
            logger.error(f"[WaveSpeed] Video extend response is not valid JSON: {response.text}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed video extend response is not valid JSON",
            ) from exc

        # The prediction may be wrapped in a non-empty "data" object or sit at
        # the top level; anything that is not a dict yields no id.
        body = response_json if isinstance(response_json, dict) else {}
        data = body.get("data")
        if not isinstance(data, dict) or not data:
            data = body
        prediction_id = data.get("id")
        if not prediction_id:
            logger.error(f"[WaveSpeed] No prediction ID in video extend response: {response.text}")
            raise HTTPException(
                status_code=502,
                detail="WaveSpeed video extend response missing prediction id",
            )
        return prediction_id

    @staticmethod
    def _extract_extend_video_url(result: object) -> str:
        """Return the URL of the first output in a completed poll result.

        Outputs may be a list of URL strings or a list of objects carrying a
        ``url`` / ``video_url`` key; a bare string is treated as one URL.

        Raises:
            HTTPException: 502 if there are no outputs or no URL can be found.
        """
        outputs = (result.get("outputs") if isinstance(result, dict) else None) or []
        if isinstance(outputs, str):
            outputs = [outputs]
        if not outputs:
            raise HTTPException(status_code=502, detail="WaveSpeed video extend returned no outputs")

        first = outputs[0]
        video_url = None
        if isinstance(first, str):
            video_url = first
        elif isinstance(first, dict):
            video_url = first.get("url") or first.get("video_url")
        if not video_url:
            raise HTTPException(status_code=502, detail="WaveSpeed video extend output format not recognized")
        return video_url

    @staticmethod
    def _download_extended_video(video_url: str, timeout: int) -> bytes:
        """Download the extended video and return its raw bytes.

        Raises:
            HTTPException: 502 on network error, non-200 status, or empty body.
        """
        logger.info(f"[WaveSpeed] Downloading extended video from: {video_url}")
        try:
            video_response = requests.get(video_url, timeout=timeout)
        except requests.RequestException as exc:
            logger.error(f"[WaveSpeed] Failed to download extended video: {exc}")
            raise HTTPException(
                status_code=502,
                detail="Failed to download extended video from WaveSpeed",
            ) from exc

        if video_response.status_code != 200:
            logger.error(f"[WaveSpeed] Failed to download extended video: {video_response.status_code}")
            raise HTTPException(
                status_code=502,
                detail="Failed to download extended video from WaveSpeed",
            )

        video_bytes = video_response.content
        if not video_bytes:
            # A 200 with no content is not a usable video; don't report success.
            logger.error("[WaveSpeed] Failed to download extended video: empty response body")
            raise HTTPException(
                status_code=502,
                detail="Failed to download extended video from WaveSpeed",
            )
        logger.info(f"[WaveSpeed] Video extension completed successfully (size: {len(video_bytes)} bytes)")
        return video_bytes