AI Researcher and Video Studio implementation complete
This commit is contained in:
@@ -1,4 +1,12 @@
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from .base import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
ImageGenerationProvider,
|
||||
ImageEditOptions,
|
||||
ImageEditProvider,
|
||||
FaceSwapOptions,
|
||||
FaceSwapProvider,
|
||||
)
|
||||
from .hf_provider import HuggingFaceImageProvider
|
||||
from .gemini_provider import GeminiImageProvider
|
||||
from .stability_provider import StabilityImageProvider
|
||||
@@ -8,6 +16,10 @@ __all__ = [
|
||||
"ImageGenerationOptions",
|
||||
"ImageGenerationResult",
|
||||
"ImageGenerationProvider",
|
||||
"ImageEditOptions",
|
||||
"ImageEditProvider",
|
||||
"FaceSwapOptions",
|
||||
"FaceSwapProvider",
|
||||
"HuggingFaceImageProvider",
|
||||
"GeminiImageProvider",
|
||||
"StabilityImageProvider",
|
||||
|
||||
@@ -28,6 +28,50 @@ class ImageGenerationResult:
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageEditOptions:
|
||||
"""Options for image editing operations."""
|
||||
image_base64: str
|
||||
prompt: str
|
||||
operation: str # "general_edit", "inpaint", "outpaint", "remove_background", etc.
|
||||
mask_base64: Optional[str] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for API calls."""
|
||||
result = {
|
||||
"image_base64": self.image_base64,
|
||||
"prompt": self.prompt,
|
||||
"operation": self.operation,
|
||||
}
|
||||
if self.mask_base64:
|
||||
result["mask_base64"] = self.mask_base64
|
||||
if self.negative_prompt:
|
||||
result["negative_prompt"] = self.negative_prompt
|
||||
if self.model:
|
||||
result["model"] = self.model
|
||||
if self.width:
|
||||
result["width"] = self.width
|
||||
if self.height:
|
||||
result["height"] = self.height
|
||||
if self.guidance_scale is not None:
|
||||
result["guidance_scale"] = self.guidance_scale
|
||||
if self.steps:
|
||||
result["steps"] = self.steps
|
||||
if self.seed is not None:
|
||||
result["seed"] = self.seed
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
class ImageGenerationProvider(Protocol):
|
||||
"""Protocol for image generation providers."""
|
||||
|
||||
@@ -35,3 +79,44 @@ class ImageGenerationProvider(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceSwapOptions:
|
||||
"""Options for face swap operations."""
|
||||
base_image_base64: str # Image to swap face into
|
||||
face_image_base64: str # Face to swap
|
||||
model: Optional[str] = None
|
||||
target_face_index: Optional[int] = None # For multi-face images (0 = largest)
|
||||
target_gender: Optional[str] = None # "all", "female", "male" (for some models)
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for API calls."""
|
||||
result = {
|
||||
"base_image_base64": self.base_image_base64,
|
||||
"face_image_base64": self.face_image_base64,
|
||||
}
|
||||
if self.model:
|
||||
result["model"] = self.model
|
||||
if self.target_face_index is not None:
|
||||
result["target_face_index"] = self.target_face_index
|
||||
if self.target_gender:
|
||||
result["target_gender"] = self.target_gender
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
class ImageEditProvider(Protocol):
|
||||
"""Protocol for image editing providers."""
|
||||
|
||||
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
|
||||
...
|
||||
|
||||
|
||||
class FaceSwapProvider(Protocol):
|
||||
"""Protocol for face swap providers."""
|
||||
|
||||
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
|
||||
...
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,691 @@
|
||||
"""WaveSpeed AI image editing provider (14 editing models)."""
|
||||
|
||||
import io
|
||||
import os
|
||||
import requests
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import ImageEditProvider, ImageEditOptions, ImageGenerationResult
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("wavespeed.edit_provider")
|
||||
|
||||
|
||||
class WaveSpeedEditProvider(ImageEditProvider):
|
||||
"""WaveSpeed AI image editing provider supporting 14 editing models.
|
||||
|
||||
REUSES: WaveSpeedClient, model registry pattern, result format
|
||||
"""
|
||||
|
||||
# Model registry - populated with WaveSpeed editing models
|
||||
SUPPORTED_MODELS = {
|
||||
"qwen-edit": {
|
||||
"model_path": "wavespeed-ai/qwen-image/edit",
|
||||
"name": "Qwen Image Edit",
|
||||
"description": "20B MMDiT image-to-image model offering precise bilingual (Chinese & English) text edits while preserving style. Single-image editing with style preservation.",
|
||||
"cost": 0.02, # Same as Plus version
|
||||
"max_resolution": (1536, 1536), # Based on docs: similar to Plus
|
||||
"capabilities": ["general_edit", "style_transfer", "text_edit"],
|
||||
"tier": "budget",
|
||||
"supports_multi_image": False, # Single image only (uses "image" not "images")
|
||||
"supports_controlnet": False, # Not mentioned in docs
|
||||
"languages": ["en", "zh"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
|
||||
"default_output_format": "jpeg", # Per API docs: default is "jpeg"
|
||||
"supports_seed": True, # Per API docs: seed parameter supported
|
||||
}
|
||||
},
|
||||
"qwen-edit-plus": {
|
||||
"model_path": "wavespeed-ai/qwen-image/edit-plus",
|
||||
"name": "Qwen Image Edit Plus",
|
||||
"description": "20B MMDiT image editor with multi-image editing, single-image consistency and native ControlNet support. Bilingual (CN/EN) text editing, appearance-level and semantic-level edits.",
|
||||
"cost": 0.02,
|
||||
"max_resolution": (1536, 1536), # Based on docs: 256-1536 per dimension
|
||||
"capabilities": ["general_edit", "style_transfer", "text_edit", "multi_image"],
|
||||
"tier": "budget",
|
||||
"supports_multi_image": True, # Up to 3 reference images
|
||||
"supports_controlnet": True,
|
||||
"languages": ["en", "zh"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": False, # Uses "images" (array)
|
||||
"supports_seed": True, # Seed parameter supported (default for Qwen models)
|
||||
}
|
||||
},
|
||||
"nano-banana-pro-edit-ultra": {
|
||||
"model_path": "google/nano-banana-pro/edit-ultra",
|
||||
"name": "Google Nano Banana Pro Edit Ultra",
|
||||
"description": "High-resolution image editing with 4K/8K native output. Natural language instructions, multilingual text support. Premium quality editing for professional marketing and high-res work.",
|
||||
"cost": 0.15, # 4K - from enhancement proposal
|
||||
"cost_8k": 0.18, # 8K - from enhancement proposal
|
||||
"max_resolution": (8192, 8192), # 8K support
|
||||
"capabilities": ["general_edit", "high_res", "professional", "typography"],
|
||||
"tier": "premium",
|
||||
"supports_multi_image": True, # Up to 14 reference images
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en", "multilingual"],
|
||||
"api_params": {
|
||||
"uses_size": False, # Uses aspect_ratio and resolution instead
|
||||
"uses_aspect_ratio": True, # "1:1", "16:9", etc.
|
||||
"uses_resolution": True, # "4k" or "8k"
|
||||
"max_images": 14,
|
||||
"default_output_format": "png", # Per API docs: default is "png"
|
||||
"supports_seed": False, # Per API docs: no seed parameter
|
||||
}
|
||||
},
|
||||
"seedream-v4.5-edit": {
|
||||
"model_path": "bytedance/seedream-v4.5/edit",
|
||||
"name": "Bytedance Seedream V4.5 Edit",
|
||||
"description": "Preserves facial features, lighting, and color tone from reference images, delivering professional, high-fidelity edits up to 4K with strong prompt adherence. Reference-faithful editing with multi-image support.",
|
||||
"cost": 0.04, # Per generated image
|
||||
"max_resolution": (4096, 4096), # 4K support (1024-4096 per dimension)
|
||||
"capabilities": ["general_edit", "portrait_retouching", "fashion_edit", "product_edit", "multi_image"],
|
||||
"tier": "mid",
|
||||
"supports_multi_image": True, # Up to 10 reference images
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height format, 1024-4096 per dimension)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"max_images": 10,
|
||||
"default_output_format": "png",
|
||||
"supports_seed": False, # No seed parameter in API docs (Seedream V4.5)
|
||||
}
|
||||
},
|
||||
"flux-kontext-pro": {
|
||||
"model_path": "wavespeed-ai/flux-kontext-pro",
|
||||
"name": "FLUX Kontext Pro",
|
||||
"description": "FLUX.1 Kontext [pro] offers improved prompt adherence and accurate typography generation for consistent, high-quality edits at speed. Typography-focused editing with improved prompt adherence.",
|
||||
"cost": 0.04, # From enhancement proposal
|
||||
"max_resolution": (2048, 2048), # Estimated, not specified in docs
|
||||
"capabilities": ["general_edit", "typography", "text_edit", "style_transfer"],
|
||||
"tier": "mid",
|
||||
"supports_multi_image": False, # Single image only (uses "image" not "images")
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en"],
|
||||
"api_params": {
|
||||
"uses_size": False, # Uses aspect_ratio instead
|
||||
"uses_aspect_ratio": True, # Aspect ratio as string (e.g., "16:9", "1:1")
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
|
||||
"supports_guidance_scale": True, # Has guidance_scale parameter (default 3.5, range 1-20)
|
||||
"default_guidance_scale": 3.5, # Per API docs
|
||||
"supports_seed": False, # No seed parameter in API docs
|
||||
}
|
||||
},
|
||||
# TODO: Add remaining 9 models once docs are provided
|
||||
}
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize WaveSpeed edit provider.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key (falls back to env var if not provided)
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
|
||||
|
||||
# REUSE: Same client as generation provider
|
||||
self.client = WaveSpeedClient(api_key=self.api_key)
|
||||
logger.info("[WaveSpeed Edit Provider] Initialized with %d models",
|
||||
len(self.SUPPORTED_MODELS))
|
||||
|
||||
def _validate_options(self, options: ImageEditOptions) -> None:
|
||||
"""Validate editing options.
|
||||
|
||||
Args:
|
||||
options: Image editing options
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
model = options.model or list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None
|
||||
|
||||
if not model:
|
||||
raise ValueError("No model specified and no default model available")
|
||||
|
||||
if model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
max_width, max_height = model_info.get("max_resolution", (4096, 4096))
|
||||
|
||||
if options.width and options.width > max_width:
|
||||
raise ValueError(
|
||||
f"Width {options.width} exceeds maximum {max_width} for model {model}"
|
||||
)
|
||||
|
||||
if options.height and options.height > max_height:
|
||||
raise ValueError(
|
||||
f"Height {options.height} exceeds maximum {max_height} for model {model}"
|
||||
)
|
||||
|
||||
if not options.prompt or len(options.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
if not options.image_base64:
|
||||
raise ValueError("Image base64 cannot be empty")
|
||||
|
||||
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
|
||||
"""Edit image using WaveSpeed AI models.
|
||||
|
||||
Args:
|
||||
options: Image editing options
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
RuntimeError: If editing fails
|
||||
"""
|
||||
# Validate options
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model = options.model or (list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None)
|
||||
if not model:
|
||||
raise ValueError("No model available for editing")
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
model_path = model_info["model_path"]
|
||||
|
||||
logger.info("[WaveSpeed Edit] Starting edit: model=%s, operation=%s, prompt=%s",
|
||||
model, options.operation, options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare extra parameters based on model capabilities
|
||||
extra_params = options.extra or {}
|
||||
|
||||
# Add model-specific parameters if needed
|
||||
api_params = model_info.get("api_params", {})
|
||||
if api_params.get("uses_resolution", False):
|
||||
# For Nano Banana: determine resolution from dimensions or use default
|
||||
if options.width and options.height:
|
||||
if options.width >= 4096 or options.height >= 4096:
|
||||
extra_params["resolution"] = "8k"
|
||||
else:
|
||||
extra_params["resolution"] = "4k"
|
||||
elif "resolution" not in extra_params:
|
||||
extra_params["resolution"] = "4k" # Default to 4K
|
||||
|
||||
if api_params.get("uses_aspect_ratio", False) and not extra_params.get("aspect_ratio"):
|
||||
# Calculate aspect ratio if dimensions provided
|
||||
if options.width and options.height:
|
||||
aspect_ratio = self._calculate_aspect_ratio(options.width, options.height)
|
||||
if aspect_ratio:
|
||||
extra_params["aspect_ratio"] = aspect_ratio
|
||||
|
||||
# Call WaveSpeed API for editing
|
||||
result = self._call_wavespeed_edit_api(
|
||||
model_path=model_path,
|
||||
image_base64=options.image_base64,
|
||||
prompt=options.prompt,
|
||||
operation=options.operation,
|
||||
mask_base64=options.mask_base64,
|
||||
negative_prompt=options.negative_prompt,
|
||||
width=options.width,
|
||||
height=options.height,
|
||||
guidance_scale=options.guidance_scale,
|
||||
steps=options.steps,
|
||||
seed=options.seed,
|
||||
extra=extra_params
|
||||
)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
elif isinstance(result, dict) and "image_bytes" in result:
|
||||
image_bytes = result["image_bytes"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
# Load image to get dimensions
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
|
||||
# Calculate estimated cost - handle resolution-based pricing
|
||||
estimated_cost = model_info.get("cost", 0.02)
|
||||
if api_params.get("uses_resolution", False):
|
||||
# Check if 8K was requested
|
||||
resolution = extra_params.get("resolution", "4k")
|
||||
if resolution == "8k" and "cost_8k" in model_info:
|
||||
estimated_cost = model_info["cost_8k"]
|
||||
|
||||
logger.info("[WaveSpeed Edit] ✅ Successfully edited image: %d bytes, %dx%d",
|
||||
len(image_bytes), width, height)
|
||||
|
||||
# REUSE: Same result format as generation
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
metadata={
|
||||
"provider": "wavespeed",
|
||||
"model": model,
|
||||
"model_name": model_info.get("name", model),
|
||||
"operation": options.operation,
|
||||
"prompt": options.prompt,
|
||||
"negative_prompt": options.negative_prompt,
|
||||
"estimated_cost": estimated_cost,
|
||||
"tier": model_info.get("tier", "mid"),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[WaveSpeed Edit] ❌ Error editing image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed edit failed: {str(e)}")
|
||||
|
||||
def _call_wavespeed_edit_api(
|
||||
self,
|
||||
model_path: str,
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str,
|
||||
mask_base64: Optional[str] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
steps: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
extra: Optional[dict] = None
|
||||
) -> bytes:
|
||||
"""Call WaveSpeed API for image editing.
|
||||
|
||||
REUSES: Same pattern as ImageGenerator.generate_image()
|
||||
|
||||
Args:
|
||||
model_path: Full model path (e.g., "wavespeed-ai/qwen-image/edit-plus")
|
||||
image_base64: Base64-encoded input image
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of operation
|
||||
mask_base64: Optional mask for inpainting
|
||||
negative_prompt: Optional negative prompt
|
||||
width: Optional target width
|
||||
height: Optional target height
|
||||
guidance_scale: Optional guidance scale (not used by all models)
|
||||
steps: Optional number of steps (not used by all models)
|
||||
seed: Optional seed
|
||||
extra: Optional extra parameters
|
||||
|
||||
Returns:
|
||||
Edited image bytes
|
||||
|
||||
Raises:
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
# Build URL - REUSES same pattern as ImageGenerator
|
||||
url = f"{self.client.BASE_URL}/{model_path}"
|
||||
|
||||
# Prepare images array - WaveSpeed expects array of image strings
|
||||
# Format: base64 strings or data URIs (data:image/png;base64,...)
|
||||
# For Qwen Image Edit Plus: supports up to 3 reference images
|
||||
images = []
|
||||
|
||||
# Add main image - check if it's already a data URI or just base64
|
||||
if image_base64.startswith("data:image"):
|
||||
# Already a data URI
|
||||
images.append(image_base64)
|
||||
else:
|
||||
# Assume it's base64, convert to data URI
|
||||
# Try to detect format from base64 or default to PNG
|
||||
images.append(f"data:image/png;base64,{image_base64}")
|
||||
|
||||
# If mask is provided, add it as second image
|
||||
# Note: Some models may need mask in different format - will adjust per model
|
||||
if mask_base64:
|
||||
if mask_base64.startswith("data:image"):
|
||||
images.append(mask_base64)
|
||||
else:
|
||||
images.append(f"data:image/png;base64,{mask_base64}")
|
||||
|
||||
# Get model info to determine API parameter structure
|
||||
model_info = self.SUPPORTED_MODELS.get(model_path.split("/")[-1] if "/" in model_path else model_path)
|
||||
if not model_info:
|
||||
# Fallback: try to find model by matching path
|
||||
for model_id, info in self.SUPPORTED_MODELS.items():
|
||||
if info["model_path"] == model_path:
|
||||
model_info = info
|
||||
break
|
||||
|
||||
if not model_info:
|
||||
raise ValueError(f"Model info not found for: {model_path}")
|
||||
|
||||
api_params = model_info.get("api_params", {})
|
||||
|
||||
# Build payload - following WaveSpeed API structure
|
||||
# Note: output_format default varies by model (PNG for most, but can be JPEG)
|
||||
default_output_format = api_params.get("default_output_format", "png")
|
||||
|
||||
# Some models use "image" (singular) instead of "images" (array)
|
||||
uses_image_singular = api_params.get("uses_image_singular", False)
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"enable_sync_mode": True, # Use sync mode for immediate results
|
||||
"enable_base64_output": False, # Get URL, then download
|
||||
"output_format": default_output_format,
|
||||
}
|
||||
|
||||
# Add image(s) based on model API format
|
||||
if uses_image_singular:
|
||||
# Models like Qwen Edit (basic) use "image" (singular)
|
||||
# Use first image only (single image editing)
|
||||
if images:
|
||||
payload["image"] = images[0]
|
||||
else:
|
||||
raise ValueError("At least one image is required")
|
||||
else:
|
||||
# Models like Qwen Edit Plus, Nano Banana use "images" (array)
|
||||
payload["images"] = images
|
||||
|
||||
# Allow override of output_format from extra params
|
||||
if extra and "output_format" in extra:
|
||||
payload["output_format"] = extra["output_format"]
|
||||
|
||||
# Model-specific parameter handling
|
||||
if api_params.get("uses_size", True):
|
||||
# Models like Qwen Edit Plus use "size" parameter (width*height format)
|
||||
if width and height:
|
||||
payload["size"] = f"{width}*{height}"
|
||||
elif width:
|
||||
payload["size"] = f"{width}*{width}" # Square if only width provided
|
||||
elif height:
|
||||
payload["size"] = f"{height}*{height}" # Square if only height provided
|
||||
|
||||
if api_params.get("uses_aspect_ratio", False):
|
||||
# Models like Nano Banana and FLUX Kontext Pro use "aspect_ratio" parameter
|
||||
if width and height:
|
||||
# Calculate aspect ratio from dimensions
|
||||
aspect_ratio = self._calculate_aspect_ratio(width, height)
|
||||
if aspect_ratio:
|
||||
payload["aspect_ratio"] = aspect_ratio
|
||||
elif extra and "aspect_ratio" in extra:
|
||||
payload["aspect_ratio"] = extra["aspect_ratio"]
|
||||
|
||||
if api_params.get("uses_resolution", False):
|
||||
# Models like Nano Banana use "resolution" parameter ("4k" or "8k")
|
||||
if extra and "resolution" in extra:
|
||||
payload["resolution"] = extra["resolution"]
|
||||
else:
|
||||
# Default to 4K, or 8K if dimensions suggest high-res
|
||||
if width and height and (width >= 4096 or height >= 4096):
|
||||
payload["resolution"] = "8k"
|
||||
else:
|
||||
payload["resolution"] = "4k" # Default to 4K per API docs
|
||||
|
||||
# Add optional parameters (model-agnostic)
|
||||
# Guidance scale: Only add if model supports it (e.g., FLUX Kontext Pro)
|
||||
if api_params.get("supports_guidance_scale", False):
|
||||
default_guidance = api_params.get("default_guidance_scale", 3.5)
|
||||
if guidance_scale is not None:
|
||||
# Clamp to valid range (1-20 per FLUX Kontext Pro docs)
|
||||
payload["guidance_scale"] = max(1, min(20, guidance_scale))
|
||||
elif extra and "guidance_scale" in extra:
|
||||
payload["guidance_scale"] = max(1, min(20, extra["guidance_scale"]))
|
||||
else:
|
||||
payload["guidance_scale"] = default_guidance
|
||||
|
||||
# Seed parameter: Only add if model supports it
|
||||
if api_params.get("supports_seed", True): # Default to True for backward compatibility
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
else:
|
||||
payload["seed"] = -1 # Random seed (per API docs default)
|
||||
|
||||
# Add any extra parameters
|
||||
if extra:
|
||||
# Filter out parameters we've already handled
|
||||
handled_params = {"aspect_ratio", "resolution", "size", "seed", "guidance_scale"}
|
||||
for key, value in extra.items():
|
||||
if key not in handled_params:
|
||||
payload[key] = value
|
||||
|
||||
logger.info(f"[WaveSpeed Edit] Submitting edit request to {url} (model={model_path}, prompt_length={len(prompt)})")
|
||||
|
||||
# Make API call - REUSES same pattern as ImageGenerator
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.client._headers(),
|
||||
json=payload,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed Edit] API call failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image editing failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status
|
||||
status = data.get("status", "").lower()
|
||||
outputs = data.get("outputs") or []
|
||||
prediction_id = data.get("id")
|
||||
|
||||
logger.debug(
|
||||
f"[WaveSpeed Edit] Response: status='{status}', outputs_count={len(outputs)}, "
|
||||
f"prediction_id={prediction_id}"
|
||||
)
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed Edit] Got immediate results from sync mode")
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=120)
|
||||
|
||||
# Sync mode returned "created" or "processing" - need to poll
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed Edit] Sync mode returned status '{status}' but no prediction ID")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed sync mode returned async response without prediction ID",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed Edit] Sync mode returned status '{status}' with no outputs. "
|
||||
f"Polling for result (prediction_id: {prediction_id})"
|
||||
)
|
||||
|
||||
# Poll for result - REUSES polling utility
|
||||
result = self.client.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=180,
|
||||
interval_seconds=2.0,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned no outputs after polling"
|
||||
)
|
||||
|
||||
# Extract image URL from outputs - REUSE helper method
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=120)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[WaveSpeed Edit] Unexpected error: {str(e)}", exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed edit API call failed: {str(e)}")
|
||||
|
||||
def _extract_image_url(self, outputs: list) -> str:
|
||||
"""Extract image URL from outputs - REUSES same pattern as ImageGenerator.
|
||||
|
||||
Args:
|
||||
outputs: Array of output URLs or objects
|
||||
|
||||
Returns:
|
||||
Image URL string
|
||||
|
||||
Raises:
|
||||
HTTPException: If output format is invalid
|
||||
"""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned no outputs",
|
||||
)
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit output format not recognized",
|
||||
)
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned invalid image URL",
|
||||
)
|
||||
|
||||
return image_url
|
||||
|
||||
def _download_image(self, image_url: str, timeout: int = 120) -> bytes:
|
||||
"""Download image from URL - REUSES same pattern as ImageGenerator.
|
||||
|
||||
Args:
|
||||
image_url: URL to download from
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If download fails
|
||||
"""
|
||||
logger.info(f"[WaveSpeed Edit] Downloading edited image from: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
|
||||
if image_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed Edit] Failed to download image: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to download edited image: {image_response.status_code}"
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed Edit] Successfully downloaded image ({len(image_response.content)} bytes)")
|
||||
return image_response.content
|
||||
|
||||
def _calculate_aspect_ratio(self, width: int, height: int) -> Optional[str]:
|
||||
"""Calculate aspect ratio string from dimensions.
|
||||
|
||||
Args:
|
||||
width: Image width
|
||||
height: Image height
|
||||
|
||||
Returns:
|
||||
Aspect ratio string (e.g., "16:9") or None if not standard
|
||||
"""
|
||||
# Common aspect ratios (includes FLUX Kontext Pro supported ratios)
|
||||
ratios = {
|
||||
(1, 1): "1:1",
|
||||
(3, 2): "3:2",
|
||||
(2, 3): "2:3",
|
||||
(3, 4): "3:4",
|
||||
(4, 3): "4:3",
|
||||
(4, 5): "4:5",
|
||||
(5, 4): "5:4",
|
||||
(9, 16): "9:16",
|
||||
(16, 9): "16:9",
|
||||
(21, 9): "21:9",
|
||||
(9, 21): "9:21", # FLUX Kontext Pro also supports 9:21
|
||||
}
|
||||
|
||||
# Calculate GCD to simplify ratio
|
||||
def gcd(a, b):
|
||||
while b:
|
||||
a, b = b, a % b
|
||||
return a
|
||||
|
||||
divisor = gcd(width, height)
|
||||
simplified = (width // divisor, height // divisor)
|
||||
|
||||
# Check if it matches a standard ratio (with some tolerance)
|
||||
for (w, h), ratio_str in ratios.items():
|
||||
# Allow small tolerance for rounding
|
||||
if abs(simplified[0] / simplified[1] - w / h) < 0.01:
|
||||
return ratio_str
|
||||
|
||||
# If no match, return None (model may not support custom aspect ratios)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available editing models and their information.
|
||||
|
||||
Returns:
|
||||
Dictionary of available models
|
||||
"""
|
||||
return cls.SUPPORTED_MODELS
|
||||
|
||||
@classmethod
|
||||
def get_models_by_tier(cls, tier: str) -> dict:
|
||||
"""Get models filtered by tier (budget, mid, premium).
|
||||
|
||||
Args:
|
||||
tier: Tier name ("budget", "mid", "premium")
|
||||
|
||||
Returns:
|
||||
Dictionary of models in the specified tier
|
||||
"""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if model_info.get("tier") == tier
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models_by_operation(cls, operation: str) -> dict:
|
||||
"""Get models that support a specific operation.
|
||||
|
||||
Args:
|
||||
operation: Operation type (e.g., "inpaint", "outpaint", "general_edit")
|
||||
|
||||
Returns:
|
||||
Dictionary of models supporting the operation
|
||||
"""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if operation in model_info.get("capabilities", [])
|
||||
}
|
||||
@@ -0,0 +1,367 @@
|
||||
"""WaveSpeed Face Swap Provider for Image Studio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.image_generation.base import (
|
||||
FaceSwapOptions,
|
||||
FaceSwapProvider,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("llm_providers.wavespeed_face_swap")
|
||||
|
||||
|
||||
class WaveSpeedFaceSwapProvider:
|
||||
"""WaveSpeed provider for face swap operations."""
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"image-face-swap-pro": {
|
||||
"model_path": "wavespeed-ai/image-face-swap-pro",
|
||||
"name": "Image Face Swap Pro",
|
||||
"description": "Instant online AI face swap for photos with no watermark, delivering realistic, shareable results in seconds.",
|
||||
"cost": 0.025,
|
||||
"tier": "mid",
|
||||
"capabilities": ["face_swap", "realistic_blending"],
|
||||
"features": ["Enhanced blending", "Realistic results", "Watermark-free"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"output_format": "jpeg",
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
"image-head-swap": {
|
||||
"model_path": "wavespeed-ai/image-head-swap",
|
||||
"name": "Image Head Swap",
|
||||
"description": "Instant online AI head & face swap for photos with no watermark. Replaces entire head (face + hair + outline) while preserving body, pose and background.",
|
||||
"cost": 0.025,
|
||||
"tier": "mid",
|
||||
"capabilities": ["head_swap", "full_head_replacement", "realistic_blending"],
|
||||
"features": ["Full head replacement", "Hair included", "Pose preservation", "Watermark-free"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"output_format": "jpeg",
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
"akool-face-swap": {
|
||||
"model_path": "akool/image-face-swap",
|
||||
"name": "Akool Image Face Swap",
|
||||
"description": "Powerful AI-powered face swapping with multi-face replacement for group photos. Seamlessly replaces faces with natural lighting and skin tone matching.",
|
||||
"cost": 0.16,
|
||||
"tier": "premium",
|
||||
"capabilities": ["face_swap", "multi_face", "realistic_blending", "face_enhancement"],
|
||||
"features": ["Multi-face swapping (up to 5)", "Face enhancement", "Group photos", "High-quality blending"],
|
||||
"max_faces": 5, # Supports 1-5 faces
|
||||
"api_params": {
|
||||
"uses_source_target_arrays": True, # Uses source_image and target_image arrays
|
||||
"supports_face_enhance": True,
|
||||
"supports_base64": True,
|
||||
"supports_sync": False, # May need polling
|
||||
},
|
||||
},
|
||||
"infinite-you": {
|
||||
"model_path": "wavespeed-ai/infinite-you",
|
||||
"name": "InfiniteYou",
|
||||
"description": "High-quality face swapping powered by ByteDance's zero-shot identity preservation technology. Maintains facial identity characteristics with exceptional realism.",
|
||||
"cost": 0.03,
|
||||
"tier": "mid",
|
||||
"capabilities": ["face_swap", "identity_preservation", "realistic_blending"],
|
||||
"features": ["Zero-shot learning", "Identity preservation", "High-quality results", "Fast processing"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"uses_source_target_names": True, # Uses source_image and target_image (not image/face_image)
|
||||
"target_is_base": True, # target_image is the base image (where face will be swapped)
|
||||
"source_is_face": True, # source_image is the face to swap in
|
||||
"supports_seed": True, # Supports seed parameter
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
# Placeholder for additional models (will be added as docs are provided)
|
||||
# "image-face-swap": {...}, # Basic version ($0.01)
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.client = WaveSpeedClient()
|
||||
|
||||
def _validate_options(self, options: FaceSwapOptions) -> None:
|
||||
"""Validate face swap options."""
|
||||
if not options.base_image_base64:
|
||||
raise ValueError("base_image_base64 is required")
|
||||
if not options.face_image_base64:
|
||||
raise ValueError("face_image_base64 is required")
|
||||
|
||||
# Validate model
|
||||
if options.model and options.model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {options.model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
def _extract_image_url(self, data_url: str) -> str:
|
||||
"""Extract image URL from data URL or return as-is if already a URL."""
|
||||
if data_url.startswith("data:image"):
|
||||
# It's a data URL, we'll need to upload it
|
||||
return data_url
|
||||
return data_url
|
||||
|
||||
def _upload_image_if_needed(self, image_data: str) -> str:
|
||||
"""Upload image if it's a base64 data URL, otherwise return URL."""
|
||||
if image_data.startswith("data:image"):
|
||||
# Extract base64 data
|
||||
header, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
|
||||
# Upload to temporary storage (or use WaveSpeed upload endpoint if available)
|
||||
# For now, we'll return the data URL and let the API handle it
|
||||
# In production, you might want to upload to S3/CloudFlare first
|
||||
return image_data
|
||||
return image_data
|
||||
|
||||
def _call_wavespeed_face_swap_api(
|
||||
self, options: FaceSwapOptions, model_info: Dict[str, Any]
|
||||
) -> ImageGenerationResult:
|
||||
"""Call WaveSpeed face swap API."""
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
model_path = model_info["model_path"]
|
||||
api_params = model_info.get("api_params", {})
|
||||
uses_source_target_arrays = api_params.get("uses_source_target_arrays", False)
|
||||
|
||||
# Prepare images - extract base64 if data URI
|
||||
base_image = options.base_image_base64
|
||||
if base_image.startswith("data:image"):
|
||||
# Keep as data URI - API should accept it
|
||||
pass
|
||||
elif not base_image.startswith("http"):
|
||||
# Assume it's base64, convert to data URI
|
||||
base_image = f"data:image/png;base64,{base_image}"
|
||||
|
||||
face_image = options.face_image_base64
|
||||
if face_image.startswith("data:image"):
|
||||
# Keep as data URI
|
||||
pass
|
||||
elif not face_image.startswith("http"):
|
||||
# Assume it's base64, convert to data URI
|
||||
face_image = f"data:image/png;base64,{face_image}"
|
||||
|
||||
# Build API payload - handle different API formats
|
||||
uses_source_target_names = api_params.get("uses_source_target_names", False)
|
||||
|
||||
if uses_source_target_arrays:
|
||||
# Akool format: uses source_image and target_image as arrays
|
||||
# For single face swap: source_image is the new face, target_image is reference from main image
|
||||
# Since we only have one face_image, we'll use it as source and the base_image as target reference
|
||||
payload = {
|
||||
"image": base_image,
|
||||
"source_image": [face_image], # Array of source faces (1-5) - the new face to swap in
|
||||
"target_image": [base_image], # Array of target faces (1-5) - reference from main image
|
||||
"face_enhance": api_params.get("supports_face_enhance", True), # Default to True for Akool
|
||||
"enable_base64_output": True,
|
||||
}
|
||||
|
||||
# Allow override from extra params
|
||||
if options.extra:
|
||||
if "source_image" in options.extra:
|
||||
payload["source_image"] = options.extra["source_image"]
|
||||
if "target_image" in options.extra:
|
||||
payload["target_image"] = options.extra["target_image"]
|
||||
if "face_enhance" in options.extra:
|
||||
payload["face_enhance"] = options.extra["face_enhance"]
|
||||
elif uses_source_target_names:
|
||||
# InfiniteYou format: uses source_image and target_image (single values, different names)
|
||||
# target_image = base image (where face will be swapped)
|
||||
# source_image = face image (face to swap in)
|
||||
payload = {
|
||||
"target_image": base_image, # Base image where face will be swapped
|
||||
"source_image": face_image, # Face to swap in
|
||||
"enable_base64_output": True,
|
||||
}
|
||||
|
||||
# Add seed if supported
|
||||
if api_params.get("supports_seed", False):
|
||||
seed = options.extra.get("seed") if options.extra else None
|
||||
payload["seed"] = seed if seed is not None else -1 # Default to -1 (random)
|
||||
|
||||
# Allow override from extra params
|
||||
if options.extra:
|
||||
if "source_image" in options.extra:
|
||||
payload["source_image"] = options.extra["source_image"]
|
||||
if "target_image" in options.extra:
|
||||
payload["target_image"] = options.extra["target_image"]
|
||||
if "seed" in options.extra and api_params.get("supports_seed", False):
|
||||
payload["seed"] = options.extra["seed"]
|
||||
else:
|
||||
# Standard format: uses image and face_image (single values)
|
||||
payload = {
|
||||
"image": base_image,
|
||||
"face_image": face_image,
|
||||
"output_format": api_params.get("output_format", "jpeg"),
|
||||
"enable_base64_output": True, # Always get base64 for our use case
|
||||
"enable_sync_mode": True, # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
# Add any extra parameters (filter out already handled ones)
|
||||
if options.extra:
|
||||
handled_keys = {"source_image", "target_image", "face_enhance", "output_format", "enable_sync_mode", "seed"}
|
||||
for key, value in options.extra.items():
|
||||
if key not in handled_keys:
|
||||
payload[key] = value
|
||||
|
||||
url = f"{self.client.BASE_URL}/{model_path}"
|
||||
headers = self.client._headers()
|
||||
|
||||
logger.info(f"[Face Swap] Calling WaveSpeed API: {url}")
|
||||
logger.debug(f"[Face Swap] Payload keys: {list(payload.keys())}")
|
||||
|
||||
try:
|
||||
# Call API
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[Face Swap] API call failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed face swap failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status - Akool uses different status values
|
||||
status = data.get("status", "").lower()
|
||||
# Akool uses "output" (singular), others use "outputs" (plural)
|
||||
outputs = data.get("outputs") or data.get("output") or []
|
||||
# Normalize to list if it's a single value
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs] if outputs else []
|
||||
|
||||
prediction_id = data.get("id")
|
||||
|
||||
# Handle completed status - Akool uses "succeeded", others use "completed"
|
||||
is_completed = status in ["completed", "succeeded"]
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if outputs and is_completed:
|
||||
logger.info(f"[Face Swap] Got immediate results (status: {status})")
|
||||
# Extract image URL or base64
|
||||
output = outputs[0]
|
||||
if output.startswith("data:image") or output.startswith("http"):
|
||||
if output.startswith("http"):
|
||||
# Download from URL
|
||||
import requests
|
||||
img_response = requests.get(output, timeout=60)
|
||||
img_response.raise_for_status()
|
||||
image_bytes = img_response.content
|
||||
else:
|
||||
# Extract base64 from data URI
|
||||
image_bytes = base64.b64decode(output.split(",", 1)[1])
|
||||
else:
|
||||
# Assume it's base64 string
|
||||
image_bytes = base64.b64decode(output)
|
||||
elif prediction_id:
|
||||
# Need to poll
|
||||
logger.info(f"[Face Swap] Polling for result (prediction_id: {prediction_id}, status: {status})")
|
||||
result = self.client.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=1.0)
|
||||
# Check both outputs and output fields
|
||||
outputs = result.get("outputs") or result.get("output") or []
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs] if outputs else []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed face swap returned no outputs")
|
||||
output = outputs[0]
|
||||
if output.startswith("http"):
|
||||
import requests
|
||||
img_response = requests.get(output, timeout=60)
|
||||
img_response.raise_for_status()
|
||||
image_bytes = img_response.content
|
||||
elif output.startswith("data:image"):
|
||||
image_bytes = base64.b64decode(output.split(",", 1)[1])
|
||||
else:
|
||||
image_bytes = base64.b64decode(output)
|
||||
else:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed face swap response missing outputs and prediction ID")
|
||||
|
||||
# Get image dimensions
|
||||
img = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = img.size
|
||||
|
||||
logger.info(f"[Face Swap] ✅ Successfully swapped face: {len(image_bytes)} bytes, {width}x{height}")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=options.model or model_path,
|
||||
metadata={
|
||||
"model_path": model_path,
|
||||
"status": status,
|
||||
"created_at": data.get("created_at"),
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Face Swap] API call failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Face swap failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
|
||||
"""Swap face in image using WaveSpeed models."""
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model_id = options.model
|
||||
if not model_id:
|
||||
# Default to first available model
|
||||
model_id = list(self.SUPPORTED_MODELS.keys())[0]
|
||||
logger.info(f"[Face Swap] No model specified, using default: {model_id}")
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model_id]
|
||||
|
||||
# Call API
|
||||
return self._call_wavespeed_face_swap_api(options, model_info)
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available face swap models and their information."""
|
||||
return cls.SUPPORTED_MODELS
|
||||
|
||||
@classmethod
|
||||
def get_models_by_tier(cls, tier: str) -> dict:
|
||||
"""Get models filtered by tier (budget, mid, premium)."""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if model_info.get("tier") == tier
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models_by_capability(cls, capability: str) -> dict:
|
||||
"""Get models that support a specific capability."""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if capability in model_info.get("capabilities", [])
|
||||
}
|
||||
Reference in New Issue
Block a user