AI Image Studio, AI podcast Maker, AI product Marketing
This commit is contained in:
@@ -6,6 +6,11 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
TalkingAvatarRequest,
|
||||
)
|
||||
from .templates import PlatformTemplates, TemplateManager
|
||||
|
||||
__all__ = [
|
||||
@@ -20,6 +25,9 @@ __all__ = [
|
||||
"ControlStudioRequest",
|
||||
"SocialOptimizerService",
|
||||
"SocialOptimizerRequest",
|
||||
"TransformStudioService",
|
||||
"TransformImageToVideoRequest",
|
||||
"TalkingAvatarRequest",
|
||||
"PlatformTemplates",
|
||||
"TemplateManager",
|
||||
]
|
||||
|
||||
155
backend/services/image_studio/infinitetalk_adapter.py
Normal file
155
backend/services/image_studio/infinitetalk_adapter.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""InfiniteTalk adapter for Transform Studio."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.infinitetalk")
|
||||
|
||||
|
||||
class InfiniteTalkService:
|
||||
"""Adapter for InfiniteTalk in Transform Studio context."""
|
||||
|
||||
def __init__(self, client: Optional[WaveSpeedClient] = None):
|
||||
"""Initialize InfiniteTalk service adapter."""
|
||||
self.client = client or WaveSpeedClient()
|
||||
logger.info("[InfiniteTalk Adapter] Service initialized")
|
||||
|
||||
def calculate_cost(self, resolution: str, duration: float) -> float:
|
||||
"""Calculate cost for InfiniteTalk video.
|
||||
|
||||
Args:
|
||||
resolution: Output resolution (480p or 720p)
|
||||
duration: Video duration in seconds
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
# InfiniteTalk pricing: $0.03/s (480p) or $0.06/s (720p)
|
||||
# Minimum charge: 5 seconds
|
||||
cost_per_second = 0.03 if resolution == "480p" else 0.06
|
||||
actual_duration = max(5.0, duration) # Minimum 5 seconds
|
||||
return cost_per_second * actual_duration
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
image_base64: str,
|
||||
audio_base64: str,
|
||||
resolution: str = "720p",
|
||||
prompt: Optional[str] = None,
|
||||
mask_image_base64: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
user_id: str = "transform_studio",
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar video using InfiniteTalk.
|
||||
|
||||
Args:
|
||||
image_base64: Person image in base64 or data URI
|
||||
audio_base64: Audio file in base64 or data URI
|
||||
resolution: Output resolution (480p or 720p)
|
||||
prompt: Optional prompt for expression/style
|
||||
mask_image_base64: Optional mask for animatable regions
|
||||
seed: Optional random seed
|
||||
user_id: User ID for tracking
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
# Validate resolution
|
||||
if resolution not in ["480p", "720p"]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Resolution must be '480p' or '720p' for InfiniteTalk"
|
||||
)
|
||||
|
||||
# Decode image
|
||||
import base64
|
||||
try:
|
||||
if image_base64.startswith("data:"):
|
||||
if "," not in image_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = image_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
|
||||
image_mime = mime_parts.strip() or "image/png"
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
image_mime = "image/png"
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode image: {str(e)}"
|
||||
)
|
||||
|
||||
# Decode audio
|
||||
try:
|
||||
if audio_base64.startswith("data:"):
|
||||
if "," not in audio_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = audio_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
|
||||
audio_mime = mime_parts.strip() or "audio/mpeg"
|
||||
audio_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
audio_mime = "audio/mpeg"
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode audio: {str(e)}"
|
||||
)
|
||||
|
||||
# Call existing InfiniteTalk function (run in thread since it's synchronous)
|
||||
# Note: We pass empty dicts for scene_data and story_context since
|
||||
# Transform Studio doesn't have story context
|
||||
try:
|
||||
result = await asyncio.to_thread(
|
||||
animate_scene_with_voiceover,
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
scene_data={}, # Empty for Transform Studio
|
||||
story_context={}, # Empty for Transform Studio
|
||||
user_id=user_id,
|
||||
resolution=resolution,
|
||||
prompt_override=prompt,
|
||||
image_mime=image_mime,
|
||||
audio_mime=audio_mime,
|
||||
client=self.client,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[InfiniteTalk Adapter] Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"InfiniteTalk generation failed: {str(e)}"
|
||||
)
|
||||
|
||||
# Calculate actual cost based on duration
|
||||
actual_cost = self.calculate_cost(resolution, result.get("duration", 5.0))
|
||||
|
||||
# Update result with actual cost and additional metadata
|
||||
result["cost"] = actual_cost
|
||||
result["resolution"] = resolution
|
||||
|
||||
# Get video dimensions from resolution
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
result["width"] = width
|
||||
result["height"] = height
|
||||
|
||||
logger.info(
|
||||
f"[InfiniteTalk Adapter] ✅ Generated talking avatar: "
|
||||
f"resolution={resolution}, duration={result.get('duration', 5.0)}s, cost=${actual_cost:.2f}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -7,6 +7,11 @@ from .edit_service import EditStudioService, EditStudioRequest
|
||||
from .upscale_service import UpscaleStudioService, UpscaleStudioRequest
|
||||
from .control_service import ControlStudioService, ControlStudioRequest
|
||||
from .social_optimizer_service import SocialOptimizerService, SocialOptimizerRequest
|
||||
from .transform_service import (
|
||||
TransformStudioService,
|
||||
TransformImageToVideoRequest,
|
||||
TalkingAvatarRequest,
|
||||
)
|
||||
from .templates import Platform, TemplateCategory, ImageTemplate
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
@@ -24,6 +29,7 @@ class ImageStudioManager:
|
||||
self.upscale_service = UpscaleStudioService()
|
||||
self.control_service = ControlStudioService()
|
||||
self.social_optimizer_service = SocialOptimizerService()
|
||||
self.transform_service = TransformStudioService()
|
||||
logger.info("[Image Studio Manager] Initialized successfully")
|
||||
|
||||
# ====================
|
||||
@@ -339,4 +345,35 @@ class ImageStudioManager:
|
||||
}
|
||||
|
||||
return specs.get(platform, {})
|
||||
|
||||
# ====================
|
||||
# TRANSFORM STUDIO
|
||||
# ====================
|
||||
|
||||
async def transform_image_to_video(
|
||||
self,
|
||||
request: TransformImageToVideoRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Transform image to video using WAN 2.5."""
|
||||
logger.info("[Image Studio] Transform image-to-video request from user: %s", user_id)
|
||||
return await self.transform_service.transform_image_to_video(request, user_id=user_id or "anonymous")
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
request: TalkingAvatarRequest,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar using InfiniteTalk."""
|
||||
logger.info("[Image Studio] Talking avatar request from user: %s", user_id)
|
||||
return await self.transform_service.create_talking_avatar(request, user_id=user_id or "anonymous")
|
||||
|
||||
def estimate_transform_cost(
|
||||
self,
|
||||
operation: str,
|
||||
resolution: str,
|
||||
duration: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for transform operation."""
|
||||
return self.transform_service.estimate_cost(operation, resolution, duration)
|
||||
|
||||
|
||||
379
backend/services/image_studio/transform_service.py
Normal file
379
backend/services/image_studio/transform_service.py
Normal file
@@ -0,0 +1,379 @@
|
||||
"""Transform Studio service for image-to-video and talking avatar generation."""
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
from dataclasses import dataclass
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from .wan25_service import WAN25Service
|
||||
from .infinitetalk_adapter import InfiniteTalkService
|
||||
from services.llm_providers.main_video_generation import track_video_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.file_storage import save_file_safely, sanitize_filename
|
||||
|
||||
logger = get_service_logger("image_studio.transform")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransformImageToVideoRequest:
|
||||
"""Request for WAN 2.5 image-to-video."""
|
||||
image_base64: str
|
||||
prompt: str
|
||||
audio_base64: Optional[str] = None
|
||||
resolution: str = "720p" # 480p, 720p, 1080p
|
||||
duration: int = 5 # 5 or 10 seconds
|
||||
negative_prompt: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
enable_prompt_expansion: bool = True
|
||||
|
||||
|
||||
@dataclass
|
||||
class TalkingAvatarRequest:
|
||||
"""Request for InfiniteTalk talking avatar."""
|
||||
image_base64: str
|
||||
audio_base64: str
|
||||
resolution: str = "720p" # 480p or 720p
|
||||
prompt: Optional[str] = None
|
||||
mask_image_base64: Optional[str] = None
|
||||
seed: Optional[int] = None
|
||||
|
||||
|
||||
class TransformStudioService:
|
||||
"""Service for Transform Studio operations."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Transform Studio service."""
|
||||
self.wan25_service = WAN25Service()
|
||||
self.infinitetalk_service = InfiniteTalkService()
|
||||
|
||||
# Video output directory
|
||||
# __file__ is: backend/services/image_studio/transform_service.py
|
||||
# We need: backend/transform_videos
|
||||
base_dir = Path(__file__).parent.parent.parent.parent
|
||||
self.output_dir = base_dir / "transform_videos"
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Verify directory was created
|
||||
if not self.output_dir.exists():
|
||||
raise RuntimeError(f"Failed to create transform_videos directory: {self.output_dir}")
|
||||
|
||||
logger.info(f"[Transform Studio] Initialized with output directory: {self.output_dir}")
|
||||
|
||||
def _save_video_file(
|
||||
self,
|
||||
video_bytes: bytes,
|
||||
operation_type: str,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Save video file to disk.
|
||||
|
||||
Args:
|
||||
video_bytes: Video content as bytes
|
||||
operation_type: Type of operation (e.g., "image-to-video", "talking-avatar")
|
||||
user_id: User ID for directory organization
|
||||
|
||||
Returns:
|
||||
Dictionary with filename, file_path, and file_url
|
||||
"""
|
||||
# Create user-specific directory
|
||||
user_dir = self.output_dir / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename
|
||||
filename = f"{operation_type}_{uuid.uuid4().hex[:8]}.mp4"
|
||||
filename = sanitize_filename(filename)
|
||||
|
||||
# Save file
|
||||
file_path, error = save_file_safely(
|
||||
content=video_bytes,
|
||||
directory=user_dir,
|
||||
filename=filename,
|
||||
max_file_size=500 * 1024 * 1024 # 500MB max for videos
|
||||
)
|
||||
|
||||
if error:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to save video file: {error}"
|
||||
)
|
||||
|
||||
file_url = f"/api/image-studio/videos/{user_id}/{filename}"
|
||||
|
||||
return {
|
||||
"filename": filename,
|
||||
"file_path": str(file_path),
|
||||
"file_url": file_url,
|
||||
"file_size": len(video_bytes),
|
||||
}
|
||||
|
||||
async def transform_image_to_video(
|
||||
self,
|
||||
request: TransformImageToVideoRequest,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Transform image to video using WAN 2.5.
|
||||
|
||||
Args:
|
||||
request: Transform request
|
||||
user_id: User ID for tracking and file organization
|
||||
|
||||
Returns:
|
||||
Dictionary with video URL, metadata, and cost
|
||||
"""
|
||||
logger.info(
|
||||
f"[Transform Studio] Image-to-video request from user {user_id}: "
|
||||
f"resolution={request.resolution}, duration={request.duration}s"
|
||||
)
|
||||
|
||||
# Generate video using WAN 2.5
|
||||
result = await self.wan25_service.generate_video(
|
||||
image_base64=request.image_base64,
|
||||
prompt=request.prompt,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
enable_prompt_expansion=request.enable_prompt_expansion,
|
||||
)
|
||||
|
||||
# Save video to disk
|
||||
save_result = self._save_video_file(
|
||||
video_bytes=result["video_bytes"],
|
||||
operation_type="image-to-video",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Track usage
|
||||
try:
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=result["provider"],
|
||||
model_name=result["model_name"],
|
||||
prompt=result["prompt"],
|
||||
video_bytes=result["video_bytes"],
|
||||
cost_override=result["cost"],
|
||||
)
|
||||
logger.info(
|
||||
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
|
||||
f"{usage_info.get('video_limit_display', '∞')} videos, "
|
||||
f"cost=${result['cost']:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
|
||||
|
||||
# Save to asset library
|
||||
try:
|
||||
from services.database import get_db
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="image_studio",
|
||||
filename=save_result["filename"],
|
||||
file_url=save_result["file_url"],
|
||||
file_path=save_result["file_path"],
|
||||
file_size=save_result["file_size"],
|
||||
mime_type="video/mp4",
|
||||
title=f"Transform: Image-to-Video ({request.resolution})",
|
||||
description=f"Generated video using WAN 2.5: {request.prompt[:100]}",
|
||||
prompt=result["prompt"],
|
||||
tags=["image_studio", "transform", "video", "image-to-video", request.resolution],
|
||||
provider=result["provider"],
|
||||
model=result["model_name"],
|
||||
cost=result["cost"],
|
||||
asset_metadata={
|
||||
"resolution": request.resolution,
|
||||
"duration": result["duration"],
|
||||
"operation": "image-to-video",
|
||||
"width": result["width"],
|
||||
"height": result["height"],
|
||||
}
|
||||
)
|
||||
logger.info(f"[Transform Studio] Video saved to asset library")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": save_result["file_url"],
|
||||
"video_base64": None, # Don't include base64 for large videos
|
||||
"duration": result["duration"],
|
||||
"resolution": result["resolution"],
|
||||
"width": result["width"],
|
||||
"height": result["height"],
|
||||
"file_size": save_result["file_size"],
|
||||
"cost": result["cost"],
|
||||
"provider": result["provider"],
|
||||
"model": result["model_name"],
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
async def create_talking_avatar(
|
||||
self,
|
||||
request: TalkingAvatarRequest,
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create talking avatar using InfiniteTalk.
|
||||
|
||||
Args:
|
||||
request: Talking avatar request
|
||||
user_id: User ID for tracking and file organization
|
||||
|
||||
Returns:
|
||||
Dictionary with video URL, metadata, and cost
|
||||
"""
|
||||
logger.info(
|
||||
f"[Transform Studio] Talking avatar request from user {user_id}: "
|
||||
f"resolution={request.resolution}"
|
||||
)
|
||||
|
||||
# Generate video using InfiniteTalk
|
||||
result = await self.infinitetalk_service.create_talking_avatar(
|
||||
image_base64=request.image_base64,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
prompt=request.prompt,
|
||||
mask_image_base64=request.mask_image_base64,
|
||||
seed=request.seed,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Save video to disk
|
||||
save_result = self._save_video_file(
|
||||
video_bytes=result["video_bytes"],
|
||||
operation_type="talking-avatar",
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Track usage
|
||||
try:
|
||||
usage_info = track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=result["provider"],
|
||||
model_name=result["model_name"],
|
||||
prompt=result.get("prompt", ""),
|
||||
video_bytes=result["video_bytes"],
|
||||
cost_override=result["cost"],
|
||||
)
|
||||
logger.info(
|
||||
f"[Transform Studio] Usage tracked: {usage_info.get('current_calls', 0)} / "
|
||||
f"{usage_info.get('video_limit_display', '∞')} videos, "
|
||||
f"cost=${result['cost']:.2f}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to track usage: {e}")
|
||||
|
||||
# Save to asset library
|
||||
try:
|
||||
from services.database import get_db
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="video",
|
||||
source_module="image_studio",
|
||||
filename=save_result["filename"],
|
||||
file_url=save_result["file_url"],
|
||||
file_path=save_result["file_path"],
|
||||
file_size=save_result["file_size"],
|
||||
mime_type="video/mp4",
|
||||
title=f"Transform: Talking Avatar ({request.resolution})",
|
||||
description="Generated talking avatar video using InfiniteTalk",
|
||||
prompt=result.get("prompt", ""),
|
||||
tags=["image_studio", "transform", "video", "talking-avatar", request.resolution],
|
||||
provider=result["provider"],
|
||||
model=result["model_name"],
|
||||
cost=result["cost"],
|
||||
asset_metadata={
|
||||
"resolution": request.resolution,
|
||||
"duration": result.get("duration", 5.0),
|
||||
"operation": "talking-avatar",
|
||||
"width": result.get("width", 1280),
|
||||
"height": result.get("height", 720),
|
||||
}
|
||||
)
|
||||
logger.info(f"[Transform Studio] Video saved to asset library")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Transform Studio] Failed to save to asset library: {e}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"video_url": save_result["file_url"],
|
||||
"video_base64": None, # Don't include base64 for large videos
|
||||
"duration": result.get("duration", 5.0),
|
||||
"resolution": result.get("resolution", request.resolution),
|
||||
"width": result.get("width", 1280),
|
||||
"height": result.get("height", 720),
|
||||
"file_size": save_result["file_size"],
|
||||
"cost": result["cost"],
|
||||
"provider": result["provider"],
|
||||
"model": result["model_name"],
|
||||
"metadata": result.get("metadata", {}),
|
||||
}
|
||||
|
||||
def estimate_cost(
|
||||
self,
|
||||
operation: str,
|
||||
resolution: str,
|
||||
duration: Optional[int] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Estimate cost for transform operation.
|
||||
|
||||
Args:
|
||||
operation: Operation type ("image-to-video" or "talking-avatar")
|
||||
resolution: Output resolution
|
||||
duration: Video duration in seconds (for image-to-video)
|
||||
|
||||
Returns:
|
||||
Cost estimation details
|
||||
"""
|
||||
if operation == "image-to-video":
|
||||
if duration is None:
|
||||
duration = 5
|
||||
cost = self.wan25_service.calculate_cost(resolution, duration)
|
||||
return {
|
||||
"estimated_cost": cost,
|
||||
"breakdown": {
|
||||
"base_cost": 0.0,
|
||||
"per_second": self.wan25_service.calculate_cost(resolution, 1),
|
||||
"duration": duration,
|
||||
"total": cost,
|
||||
},
|
||||
"currency": "USD",
|
||||
"provider": "wavespeed",
|
||||
"model": "alibaba/wan-2.5/image-to-video",
|
||||
}
|
||||
elif operation == "talking-avatar":
|
||||
# InfiniteTalk minimum is 5 seconds
|
||||
estimated_duration = duration or 5.0
|
||||
cost = self.infinitetalk_service.calculate_cost(resolution, estimated_duration)
|
||||
return {
|
||||
"estimated_cost": cost,
|
||||
"breakdown": {
|
||||
"base_cost": 0.0,
|
||||
"per_second": self.infinitetalk_service.calculate_cost(resolution, 1.0),
|
||||
"duration": estimated_duration,
|
||||
"total": cost,
|
||||
},
|
||||
"currency": "USD",
|
||||
"provider": "wavespeed",
|
||||
"model": "wavespeed-ai/infinitetalk",
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unknown operation: {operation}")
|
||||
|
||||
295
backend/services/image_studio/wan25_service.py
Normal file
295
backend/services/image_studio/wan25_service.py
Normal file
@@ -0,0 +1,295 @@
|
||||
"""WAN 2.5 service for Alibaba image-to-video generation via WaveSpeed."""
|
||||
|
||||
import base64
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_studio.wan25")
|
||||
|
||||
WAN25_MODEL_PATH = "alibaba/wan-2.5/image-to-video"
|
||||
WAN25_MODEL_NAME = "alibaba/wan-2.5/image-to-video"
|
||||
|
||||
# Pricing per second (from WaveSpeed docs)
|
||||
PRICING = {
|
||||
"480p": 0.05, # $0.05 per second
|
||||
"720p": 0.10, # $0.10 per second
|
||||
"1080p": 0.15, # $0.15 per second
|
||||
}
|
||||
|
||||
MAX_IMAGE_BYTES = 10 * 1024 * 1024 # 10MB (recommended)
|
||||
MAX_AUDIO_BYTES = 15 * 1024 * 1024 # 15MB (API limit)
|
||||
MIN_AUDIO_DURATION = 3 # seconds
|
||||
MAX_AUDIO_DURATION = 30 # seconds
|
||||
|
||||
|
||||
def _as_data_uri(content_bytes: bytes, mime_type: str) -> str:
|
||||
"""Convert bytes to data URI."""
|
||||
encoded = base64.b64encode(content_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{encoded}"
|
||||
|
||||
|
||||
def _decode_base64_image(image_base64: str) -> tuple[bytes, str]:
|
||||
"""Decode base64 image, handling data URIs."""
|
||||
if image_base64.startswith("data:"):
|
||||
# Extract mime type and base64 data
|
||||
if "," not in image_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = image_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "image/png"
|
||||
mime_type = mime_parts.strip()
|
||||
if not mime_type:
|
||||
mime_type = "image/png"
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
# Assume it's raw base64
|
||||
image_bytes = base64.b64decode(image_base64)
|
||||
mime_type = "image/png" # Default
|
||||
|
||||
return image_bytes, mime_type
|
||||
|
||||
|
||||
def _decode_base64_audio(audio_base64: str) -> tuple[bytes, str]:
|
||||
"""Decode base64 audio, handling data URIs."""
|
||||
if audio_base64.startswith("data:"):
|
||||
if "," not in audio_base64:
|
||||
raise ValueError("Invalid data URI format: missing comma separator")
|
||||
header, encoded = audio_base64.split(",", 1)
|
||||
mime_parts = header.split(":")[1].split(";")[0] if ":" in header else "audio/mpeg"
|
||||
mime_type = mime_parts.strip()
|
||||
if not mime_type:
|
||||
mime_type = "audio/mpeg"
|
||||
audio_bytes = base64.b64decode(encoded)
|
||||
else:
|
||||
audio_bytes = base64.b64decode(audio_base64)
|
||||
mime_type = "audio/mpeg" # Default
|
||||
|
||||
return audio_bytes, mime_type
|
||||
|
||||
|
||||
class WAN25Service:
|
||||
"""Service for Alibaba WAN 2.5 image-to-video generation."""
|
||||
|
||||
def __init__(self, client: Optional[WaveSpeedClient] = None):
|
||||
"""Initialize WAN 2.5 service."""
|
||||
self.client = client or WaveSpeedClient()
|
||||
logger.info("[WAN 2.5] Service initialized")
|
||||
|
||||
def calculate_cost(self, resolution: str, duration: int) -> float:
|
||||
"""Calculate cost for video generation.
|
||||
|
||||
Args:
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
|
||||
Returns:
|
||||
Cost in USD
|
||||
"""
|
||||
cost_per_second = PRICING.get(resolution, PRICING["720p"])
|
||||
return cost_per_second * duration
|
||||
|
||||
async def generate_video(
|
||||
self,
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
audio_base64: Optional[str] = None,
|
||||
resolution: str = "720p",
|
||||
duration: int = 5,
|
||||
negative_prompt: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
enable_prompt_expansion: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate video using WAN 2.5.
|
||||
|
||||
Args:
|
||||
image_base64: Image in base64 or data URI format
|
||||
prompt: Text prompt describing the video
|
||||
audio_base64: Optional audio file (wav/mp3, 3-30s, ≤15MB)
|
||||
resolution: Output resolution (480p, 720p, 1080p)
|
||||
duration: Video duration in seconds (5 or 10)
|
||||
negative_prompt: Optional negative prompt
|
||||
seed: Optional random seed for reproducibility
|
||||
enable_prompt_expansion: Enable prompt optimizer
|
||||
|
||||
Returns:
|
||||
Dictionary with video bytes, metadata, and cost
|
||||
"""
|
||||
# Validate resolution
|
||||
if resolution not in PRICING:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid resolution: {resolution}. Must be one of: {list(PRICING.keys())}"
|
||||
)
|
||||
|
||||
# Validate duration
|
||||
if duration not in [5, 10]:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid duration: {duration}. Must be 5 or 10 seconds"
|
||||
)
|
||||
|
||||
# Validate prompt
|
||||
if not prompt or not prompt.strip():
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Prompt is required and cannot be empty"
|
||||
)
|
||||
|
||||
# Decode image
|
||||
try:
|
||||
image_bytes, image_mime = _decode_base64_image(image_base64)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode image: {str(e)}"
|
||||
)
|
||||
|
||||
# Validate image size
|
||||
if len(image_bytes) > MAX_IMAGE_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Image exceeds {MAX_IMAGE_BYTES / (1024*1024):.0f}MB limit"
|
||||
)
|
||||
|
||||
# Build payload
|
||||
payload = {
|
||||
"image": _as_data_uri(image_bytes, image_mime),
|
||||
"prompt": prompt,
|
||||
"resolution": resolution,
|
||||
"duration": duration,
|
||||
"enable_prompt_expansion": enable_prompt_expansion,
|
||||
}
|
||||
|
||||
# Add optional audio
|
||||
if audio_base64:
|
||||
try:
|
||||
audio_bytes, audio_mime = _decode_base64_audio(audio_base64)
|
||||
|
||||
# Validate audio size
|
||||
if len(audio_bytes) > MAX_AUDIO_BYTES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Audio exceeds {MAX_AUDIO_BYTES / (1024*1024):.0f}MB limit"
|
||||
)
|
||||
|
||||
# Note: Audio duration validation would require audio analysis
|
||||
# For now, we rely on API to handle it (API keeps first 5s/10s if longer)
|
||||
|
||||
payload["audio"] = _as_data_uri(audio_bytes, audio_mime)
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to decode audio: {str(e)}"
|
||||
)
|
||||
|
||||
# Add optional parameters
|
||||
if negative_prompt:
|
||||
payload["negative_prompt"] = negative_prompt
|
||||
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
|
||||
# Submit to WaveSpeed
|
||||
logger.info(
|
||||
f"[WAN 2.5] Submitting video generation request: resolution={resolution}, duration={duration}s"
|
||||
)
|
||||
|
||||
try:
|
||||
prediction_id = self.client.submit_image_to_video(
|
||||
WAN25_MODEL_PATH,
|
||||
payload,
|
||||
timeout=60
|
||||
)
|
||||
except HTTPException as e:
|
||||
logger.error(f"[WAN 2.5] Submission failed: {e.detail}")
|
||||
raise
|
||||
|
||||
# Poll for completion
|
||||
logger.info(f"[WAN 2.5] Polling for completion: prediction_id={prediction_id}")
|
||||
|
||||
try:
|
||||
# WAN 2.5 typically takes 1-2 minutes
|
||||
result = self.client.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=180, # 3 minutes max
|
||||
interval_seconds=2.0
|
||||
)
|
||||
except HTTPException as e:
|
||||
detail = e.detail or {}
|
||||
if isinstance(detail, dict):
|
||||
detail.setdefault("prediction_id", prediction_id)
|
||||
detail.setdefault("resume_available", True)
|
||||
raise HTTPException(status_code=e.status_code, detail=detail)
|
||||
|
||||
# Extract video URL
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WAN 2.5 completed but returned no outputs"
|
||||
)
|
||||
|
||||
video_url = outputs[0]
|
||||
if not isinstance(video_url, str) or not video_url.startswith("http"):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Invalid video URL format: {video_url}"
|
||||
)
|
||||
|
||||
# Download video (run synchronous request in thread)
|
||||
logger.info(f"[WAN 2.5] Downloading video from: {video_url}")
|
||||
video_response = await asyncio.to_thread(
|
||||
requests.get,
|
||||
video_url,
|
||||
timeout=180
|
||||
)
|
||||
|
||||
if video_response.status_code != 200:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Failed to download WAN 2.5 video",
|
||||
"status_code": video_response.status_code,
|
||||
"response": video_response.text[:200],
|
||||
}
|
||||
)
|
||||
|
||||
video_bytes = video_response.content
|
||||
metadata = result.get("metadata") or {}
|
||||
|
||||
# Calculate cost
|
||||
cost = self.calculate_cost(resolution, duration)
|
||||
|
||||
# Get video dimensions from resolution
|
||||
resolution_dims = {
|
||||
"480p": (854, 480),
|
||||
"720p": (1280, 720),
|
||||
"1080p": (1920, 1080),
|
||||
}
|
||||
width, height = resolution_dims.get(resolution, (1280, 720))
|
||||
|
||||
logger.info(
|
||||
f"[WAN 2.5] ✅ Generated video: {len(video_bytes)} bytes, "
|
||||
f"resolution={resolution}, duration={duration}s, cost=${cost:.2f}"
|
||||
)
|
||||
|
||||
return {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": float(duration),
|
||||
"model_name": WAN25_MODEL_NAME,
|
||||
"cost": cost,
|
||||
"provider": "wavespeed",
|
||||
"source_video_url": video_url,
|
||||
"prediction_id": prediction_id,
|
||||
"resolution": resolution,
|
||||
"width": width,
|
||||
"height": height,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user