AI Researcher and Video Studio implementation complete
This commit is contained in:
@@ -1,4 +1,12 @@
|
||||
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
|
||||
from .base import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
ImageGenerationProvider,
|
||||
ImageEditOptions,
|
||||
ImageEditProvider,
|
||||
FaceSwapOptions,
|
||||
FaceSwapProvider,
|
||||
)
|
||||
from .hf_provider import HuggingFaceImageProvider
|
||||
from .gemini_provider import GeminiImageProvider
|
||||
from .stability_provider import StabilityImageProvider
|
||||
@@ -8,6 +16,10 @@ __all__ = [
|
||||
"ImageGenerationOptions",
|
||||
"ImageGenerationResult",
|
||||
"ImageGenerationProvider",
|
||||
"ImageEditOptions",
|
||||
"ImageEditProvider",
|
||||
"FaceSwapOptions",
|
||||
"FaceSwapProvider",
|
||||
"HuggingFaceImageProvider",
|
||||
"GeminiImageProvider",
|
||||
"StabilityImageProvider",
|
||||
|
||||
@@ -28,6 +28,50 @@ class ImageGenerationResult:
|
||||
metadata: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageEditOptions:
|
||||
"""Options for image editing operations."""
|
||||
image_base64: str
|
||||
prompt: str
|
||||
operation: str # "general_edit", "inpaint", "outpaint", "remove_background", etc.
|
||||
mask_base64: Optional[str] = None
|
||||
negative_prompt: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
width: Optional[int] = None
|
||||
height: Optional[int] = None
|
||||
guidance_scale: Optional[float] = None
|
||||
steps: Optional[int] = None
|
||||
seed: Optional[int] = None
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for API calls."""
|
||||
result = {
|
||||
"image_base64": self.image_base64,
|
||||
"prompt": self.prompt,
|
||||
"operation": self.operation,
|
||||
}
|
||||
if self.mask_base64:
|
||||
result["mask_base64"] = self.mask_base64
|
||||
if self.negative_prompt:
|
||||
result["negative_prompt"] = self.negative_prompt
|
||||
if self.model:
|
||||
result["model"] = self.model
|
||||
if self.width:
|
||||
result["width"] = self.width
|
||||
if self.height:
|
||||
result["height"] = self.height
|
||||
if self.guidance_scale is not None:
|
||||
result["guidance_scale"] = self.guidance_scale
|
||||
if self.steps:
|
||||
result["steps"] = self.steps
|
||||
if self.seed is not None:
|
||||
result["seed"] = self.seed
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
class ImageGenerationProvider(Protocol):
|
||||
"""Protocol for image generation providers."""
|
||||
|
||||
@@ -35,3 +79,44 @@ class ImageGenerationProvider(Protocol):
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class FaceSwapOptions:
|
||||
"""Options for face swap operations."""
|
||||
base_image_base64: str # Image to swap face into
|
||||
face_image_base64: str # Face to swap
|
||||
model: Optional[str] = None
|
||||
target_face_index: Optional[int] = None # For multi-face images (0 = largest)
|
||||
target_gender: Optional[str] = None # "all", "female", "male" (for some models)
|
||||
extra: Optional[Dict[str, Any]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for API calls."""
|
||||
result = {
|
||||
"base_image_base64": self.base_image_base64,
|
||||
"face_image_base64": self.face_image_base64,
|
||||
}
|
||||
if self.model:
|
||||
result["model"] = self.model
|
||||
if self.target_face_index is not None:
|
||||
result["target_face_index"] = self.target_face_index
|
||||
if self.target_gender:
|
||||
result["target_gender"] = self.target_gender
|
||||
if self.extra:
|
||||
result.update(self.extra)
|
||||
return result
|
||||
|
||||
|
||||
class ImageEditProvider(Protocol):
|
||||
"""Protocol for image editing providers."""
|
||||
|
||||
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
|
||||
...
|
||||
|
||||
|
||||
class FaceSwapProvider(Protocol):
|
||||
"""Protocol for face swap providers."""
|
||||
|
||||
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
|
||||
...
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,691 @@
|
||||
"""WaveSpeed AI image editing provider (14 editing models)."""
|
||||
|
||||
import io
|
||||
import os
|
||||
import requests
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import ImageEditProvider, ImageEditOptions, ImageGenerationResult
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
logger = get_service_logger("wavespeed.edit_provider")
|
||||
|
||||
|
||||
class WaveSpeedEditProvider(ImageEditProvider):
|
||||
"""WaveSpeed AI image editing provider supporting 14 editing models.
|
||||
|
||||
REUSES: WaveSpeedClient, model registry pattern, result format
|
||||
"""
|
||||
|
||||
# Model registry - populated with WaveSpeed editing models
|
||||
SUPPORTED_MODELS = {
|
||||
"qwen-edit": {
|
||||
"model_path": "wavespeed-ai/qwen-image/edit",
|
||||
"name": "Qwen Image Edit",
|
||||
"description": "20B MMDiT image-to-image model offering precise bilingual (Chinese & English) text edits while preserving style. Single-image editing with style preservation.",
|
||||
"cost": 0.02, # Same as Plus version
|
||||
"max_resolution": (1536, 1536), # Based on docs: similar to Plus
|
||||
"capabilities": ["general_edit", "style_transfer", "text_edit"],
|
||||
"tier": "budget",
|
||||
"supports_multi_image": False, # Single image only (uses "image" not "images")
|
||||
"supports_controlnet": False, # Not mentioned in docs
|
||||
"languages": ["en", "zh"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
|
||||
"default_output_format": "jpeg", # Per API docs: default is "jpeg"
|
||||
"supports_seed": True, # Per API docs: seed parameter supported
|
||||
}
|
||||
},
|
||||
"qwen-edit-plus": {
|
||||
"model_path": "wavespeed-ai/qwen-image/edit-plus",
|
||||
"name": "Qwen Image Edit Plus",
|
||||
"description": "20B MMDiT image editor with multi-image editing, single-image consistency and native ControlNet support. Bilingual (CN/EN) text editing, appearance-level and semantic-level edits.",
|
||||
"cost": 0.02,
|
||||
"max_resolution": (1536, 1536), # Based on docs: 256-1536 per dimension
|
||||
"capabilities": ["general_edit", "style_transfer", "text_edit", "multi_image"],
|
||||
"tier": "budget",
|
||||
"supports_multi_image": True, # Up to 3 reference images
|
||||
"supports_controlnet": True,
|
||||
"languages": ["en", "zh"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": False, # Uses "images" (array)
|
||||
"supports_seed": True, # Seed parameter supported (default for Qwen models)
|
||||
}
|
||||
},
|
||||
"nano-banana-pro-edit-ultra": {
|
||||
"model_path": "google/nano-banana-pro/edit-ultra",
|
||||
"name": "Google Nano Banana Pro Edit Ultra",
|
||||
"description": "High-resolution image editing with 4K/8K native output. Natural language instructions, multilingual text support. Premium quality editing for professional marketing and high-res work.",
|
||||
"cost": 0.15, # 4K - from enhancement proposal
|
||||
"cost_8k": 0.18, # 8K - from enhancement proposal
|
||||
"max_resolution": (8192, 8192), # 8K support
|
||||
"capabilities": ["general_edit", "high_res", "professional", "typography"],
|
||||
"tier": "premium",
|
||||
"supports_multi_image": True, # Up to 14 reference images
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en", "multilingual"],
|
||||
"api_params": {
|
||||
"uses_size": False, # Uses aspect_ratio and resolution instead
|
||||
"uses_aspect_ratio": True, # "1:1", "16:9", etc.
|
||||
"uses_resolution": True, # "4k" or "8k"
|
||||
"max_images": 14,
|
||||
"default_output_format": "png", # Per API docs: default is "png"
|
||||
"supports_seed": False, # Per API docs: no seed parameter
|
||||
}
|
||||
},
|
||||
"seedream-v4.5-edit": {
|
||||
"model_path": "bytedance/seedream-v4.5/edit",
|
||||
"name": "Bytedance Seedream V4.5 Edit",
|
||||
"description": "Preserves facial features, lighting, and color tone from reference images, delivering professional, high-fidelity edits up to 4K with strong prompt adherence. Reference-faithful editing with multi-image support.",
|
||||
"cost": 0.04, # Per generated image
|
||||
"max_resolution": (4096, 4096), # 4K support (1024-4096 per dimension)
|
||||
"capabilities": ["general_edit", "portrait_retouching", "fashion_edit", "product_edit", "multi_image"],
|
||||
"tier": "mid",
|
||||
"supports_multi_image": True, # Up to 10 reference images
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en"],
|
||||
"api_params": {
|
||||
"uses_size": True, # Uses "size" parameter (width*height format, 1024-4096 per dimension)
|
||||
"uses_aspect_ratio": False,
|
||||
"uses_resolution": False,
|
||||
"max_images": 10,
|
||||
"default_output_format": "png",
|
||||
"supports_seed": False, # No seed parameter in API docs (Seedream V4.5)
|
||||
}
|
||||
},
|
||||
"flux-kontext-pro": {
|
||||
"model_path": "wavespeed-ai/flux-kontext-pro",
|
||||
"name": "FLUX Kontext Pro",
|
||||
"description": "FLUX.1 Kontext [pro] offers improved prompt adherence and accurate typography generation for consistent, high-quality edits at speed. Typography-focused editing with improved prompt adherence.",
|
||||
"cost": 0.04, # From enhancement proposal
|
||||
"max_resolution": (2048, 2048), # Estimated, not specified in docs
|
||||
"capabilities": ["general_edit", "typography", "text_edit", "style_transfer"],
|
||||
"tier": "mid",
|
||||
"supports_multi_image": False, # Single image only (uses "image" not "images")
|
||||
"supports_controlnet": False,
|
||||
"languages": ["en"],
|
||||
"api_params": {
|
||||
"uses_size": False, # Uses aspect_ratio instead
|
||||
"uses_aspect_ratio": True, # Aspect ratio as string (e.g., "16:9", "1:1")
|
||||
"uses_resolution": False,
|
||||
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
|
||||
"supports_guidance_scale": True, # Has guidance_scale parameter (default 3.5, range 1-20)
|
||||
"default_guidance_scale": 3.5, # Per API docs
|
||||
"supports_seed": False, # No seed parameter in API docs
|
||||
}
|
||||
},
|
||||
# TODO: Add remaining 9 models once docs are provided
|
||||
}
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
"""Initialize WaveSpeed edit provider.
|
||||
|
||||
Args:
|
||||
api_key: WaveSpeed API key (falls back to env var if not provided)
|
||||
"""
|
||||
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
|
||||
|
||||
# REUSE: Same client as generation provider
|
||||
self.client = WaveSpeedClient(api_key=self.api_key)
|
||||
logger.info("[WaveSpeed Edit Provider] Initialized with %d models",
|
||||
len(self.SUPPORTED_MODELS))
|
||||
|
||||
def _validate_options(self, options: ImageEditOptions) -> None:
|
||||
"""Validate editing options.
|
||||
|
||||
Args:
|
||||
options: Image editing options
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
model = options.model or list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None
|
||||
|
||||
if not model:
|
||||
raise ValueError("No model specified and no default model available")
|
||||
|
||||
if model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
max_width, max_height = model_info.get("max_resolution", (4096, 4096))
|
||||
|
||||
if options.width and options.width > max_width:
|
||||
raise ValueError(
|
||||
f"Width {options.width} exceeds maximum {max_width} for model {model}"
|
||||
)
|
||||
|
||||
if options.height and options.height > max_height:
|
||||
raise ValueError(
|
||||
f"Height {options.height} exceeds maximum {max_height} for model {model}"
|
||||
)
|
||||
|
||||
if not options.prompt or len(options.prompt.strip()) == 0:
|
||||
raise ValueError("Prompt cannot be empty")
|
||||
|
||||
if not options.image_base64:
|
||||
raise ValueError("Image base64 cannot be empty")
|
||||
|
||||
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
|
||||
"""Edit image using WaveSpeed AI models.
|
||||
|
||||
Args:
|
||||
options: Image editing options
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
ValueError: If options are invalid
|
||||
RuntimeError: If editing fails
|
||||
"""
|
||||
# Validate options
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model = options.model or (list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None)
|
||||
if not model:
|
||||
raise ValueError("No model available for editing")
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model]
|
||||
model_path = model_info["model_path"]
|
||||
|
||||
logger.info("[WaveSpeed Edit] Starting edit: model=%s, operation=%s, prompt=%s",
|
||||
model, options.operation, options.prompt[:100])
|
||||
|
||||
try:
|
||||
# Prepare extra parameters based on model capabilities
|
||||
extra_params = options.extra or {}
|
||||
|
||||
# Add model-specific parameters if needed
|
||||
api_params = model_info.get("api_params", {})
|
||||
if api_params.get("uses_resolution", False):
|
||||
# For Nano Banana: determine resolution from dimensions or use default
|
||||
if options.width and options.height:
|
||||
if options.width >= 4096 or options.height >= 4096:
|
||||
extra_params["resolution"] = "8k"
|
||||
else:
|
||||
extra_params["resolution"] = "4k"
|
||||
elif "resolution" not in extra_params:
|
||||
extra_params["resolution"] = "4k" # Default to 4K
|
||||
|
||||
if api_params.get("uses_aspect_ratio", False) and not extra_params.get("aspect_ratio"):
|
||||
# Calculate aspect ratio if dimensions provided
|
||||
if options.width and options.height:
|
||||
aspect_ratio = self._calculate_aspect_ratio(options.width, options.height)
|
||||
if aspect_ratio:
|
||||
extra_params["aspect_ratio"] = aspect_ratio
|
||||
|
||||
# Call WaveSpeed API for editing
|
||||
result = self._call_wavespeed_edit_api(
|
||||
model_path=model_path,
|
||||
image_base64=options.image_base64,
|
||||
prompt=options.prompt,
|
||||
operation=options.operation,
|
||||
mask_base64=options.mask_base64,
|
||||
negative_prompt=options.negative_prompt,
|
||||
width=options.width,
|
||||
height=options.height,
|
||||
guidance_scale=options.guidance_scale,
|
||||
steps=options.steps,
|
||||
seed=options.seed,
|
||||
extra=extra_params
|
||||
)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
image_bytes = result
|
||||
elif isinstance(result, dict) and "image" in result:
|
||||
image_bytes = result["image"]
|
||||
elif isinstance(result, dict) and "image_bytes" in result:
|
||||
image_bytes = result["image_bytes"]
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
|
||||
|
||||
# Load image to get dimensions
|
||||
image = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = image.size
|
||||
|
||||
# Calculate estimated cost - handle resolution-based pricing
|
||||
estimated_cost = model_info.get("cost", 0.02)
|
||||
if api_params.get("uses_resolution", False):
|
||||
# Check if 8K was requested
|
||||
resolution = extra_params.get("resolution", "4k")
|
||||
if resolution == "8k" and "cost_8k" in model_info:
|
||||
estimated_cost = model_info["cost_8k"]
|
||||
|
||||
logger.info("[WaveSpeed Edit] ✅ Successfully edited image: %d bytes, %dx%d",
|
||||
len(image_bytes), width, height)
|
||||
|
||||
# REUSE: Same result format as generation
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=model,
|
||||
seed=options.seed,
|
||||
metadata={
|
||||
"provider": "wavespeed",
|
||||
"model": model,
|
||||
"model_name": model_info.get("name", model),
|
||||
"operation": options.operation,
|
||||
"prompt": options.prompt,
|
||||
"negative_prompt": options.negative_prompt,
|
||||
"estimated_cost": estimated_cost,
|
||||
"tier": model_info.get("tier", "mid"),
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("[WaveSpeed Edit] ❌ Error editing image: %s", str(e), exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed edit failed: {str(e)}")
|
||||
|
||||
def _call_wavespeed_edit_api(
|
||||
self,
|
||||
model_path: str,
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str,
|
||||
mask_base64: Optional[str] = None,
|
||||
negative_prompt: Optional[str] = None,
|
||||
width: Optional[int] = None,
|
||||
height: Optional[int] = None,
|
||||
guidance_scale: Optional[float] = None,
|
||||
steps: Optional[int] = None,
|
||||
seed: Optional[int] = None,
|
||||
extra: Optional[dict] = None
|
||||
) -> bytes:
|
||||
"""Call WaveSpeed API for image editing.
|
||||
|
||||
REUSES: Same pattern as ImageGenerator.generate_image()
|
||||
|
||||
Args:
|
||||
model_path: Full model path (e.g., "wavespeed-ai/qwen-image/edit-plus")
|
||||
image_base64: Base64-encoded input image
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of operation
|
||||
mask_base64: Optional mask for inpainting
|
||||
negative_prompt: Optional negative prompt
|
||||
width: Optional target width
|
||||
height: Optional target height
|
||||
guidance_scale: Optional guidance scale (not used by all models)
|
||||
steps: Optional number of steps (not used by all models)
|
||||
seed: Optional seed
|
||||
extra: Optional extra parameters
|
||||
|
||||
Returns:
|
||||
Edited image bytes
|
||||
|
||||
Raises:
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
# Build URL - REUSES same pattern as ImageGenerator
|
||||
url = f"{self.client.BASE_URL}/{model_path}"
|
||||
|
||||
# Prepare images array - WaveSpeed expects array of image strings
|
||||
# Format: base64 strings or data URIs (data:image/png;base64,...)
|
||||
# For Qwen Image Edit Plus: supports up to 3 reference images
|
||||
images = []
|
||||
|
||||
# Add main image - check if it's already a data URI or just base64
|
||||
if image_base64.startswith("data:image"):
|
||||
# Already a data URI
|
||||
images.append(image_base64)
|
||||
else:
|
||||
# Assume it's base64, convert to data URI
|
||||
# Try to detect format from base64 or default to PNG
|
||||
images.append(f"data:image/png;base64,{image_base64}")
|
||||
|
||||
# If mask is provided, add it as second image
|
||||
# Note: Some models may need mask in different format - will adjust per model
|
||||
if mask_base64:
|
||||
if mask_base64.startswith("data:image"):
|
||||
images.append(mask_base64)
|
||||
else:
|
||||
images.append(f"data:image/png;base64,{mask_base64}")
|
||||
|
||||
# Get model info to determine API parameter structure
|
||||
model_info = self.SUPPORTED_MODELS.get(model_path.split("/")[-1] if "/" in model_path else model_path)
|
||||
if not model_info:
|
||||
# Fallback: try to find model by matching path
|
||||
for model_id, info in self.SUPPORTED_MODELS.items():
|
||||
if info["model_path"] == model_path:
|
||||
model_info = info
|
||||
break
|
||||
|
||||
if not model_info:
|
||||
raise ValueError(f"Model info not found for: {model_path}")
|
||||
|
||||
api_params = model_info.get("api_params", {})
|
||||
|
||||
# Build payload - following WaveSpeed API structure
|
||||
# Note: output_format default varies by model (PNG for most, but can be JPEG)
|
||||
default_output_format = api_params.get("default_output_format", "png")
|
||||
|
||||
# Some models use "image" (singular) instead of "images" (array)
|
||||
uses_image_singular = api_params.get("uses_image_singular", False)
|
||||
|
||||
payload = {
|
||||
"prompt": prompt,
|
||||
"enable_sync_mode": True, # Use sync mode for immediate results
|
||||
"enable_base64_output": False, # Get URL, then download
|
||||
"output_format": default_output_format,
|
||||
}
|
||||
|
||||
# Add image(s) based on model API format
|
||||
if uses_image_singular:
|
||||
# Models like Qwen Edit (basic) use "image" (singular)
|
||||
# Use first image only (single image editing)
|
||||
if images:
|
||||
payload["image"] = images[0]
|
||||
else:
|
||||
raise ValueError("At least one image is required")
|
||||
else:
|
||||
# Models like Qwen Edit Plus, Nano Banana use "images" (array)
|
||||
payload["images"] = images
|
||||
|
||||
# Allow override of output_format from extra params
|
||||
if extra and "output_format" in extra:
|
||||
payload["output_format"] = extra["output_format"]
|
||||
|
||||
# Model-specific parameter handling
|
||||
if api_params.get("uses_size", True):
|
||||
# Models like Qwen Edit Plus use "size" parameter (width*height format)
|
||||
if width and height:
|
||||
payload["size"] = f"{width}*{height}"
|
||||
elif width:
|
||||
payload["size"] = f"{width}*{width}" # Square if only width provided
|
||||
elif height:
|
||||
payload["size"] = f"{height}*{height}" # Square if only height provided
|
||||
|
||||
if api_params.get("uses_aspect_ratio", False):
|
||||
# Models like Nano Banana and FLUX Kontext Pro use "aspect_ratio" parameter
|
||||
if width and height:
|
||||
# Calculate aspect ratio from dimensions
|
||||
aspect_ratio = self._calculate_aspect_ratio(width, height)
|
||||
if aspect_ratio:
|
||||
payload["aspect_ratio"] = aspect_ratio
|
||||
elif extra and "aspect_ratio" in extra:
|
||||
payload["aspect_ratio"] = extra["aspect_ratio"]
|
||||
|
||||
if api_params.get("uses_resolution", False):
|
||||
# Models like Nano Banana use "resolution" parameter ("4k" or "8k")
|
||||
if extra and "resolution" in extra:
|
||||
payload["resolution"] = extra["resolution"]
|
||||
else:
|
||||
# Default to 4K, or 8K if dimensions suggest high-res
|
||||
if width and height and (width >= 4096 or height >= 4096):
|
||||
payload["resolution"] = "8k"
|
||||
else:
|
||||
payload["resolution"] = "4k" # Default to 4K per API docs
|
||||
|
||||
# Add optional parameters (model-agnostic)
|
||||
# Guidance scale: Only add if model supports it (e.g., FLUX Kontext Pro)
|
||||
if api_params.get("supports_guidance_scale", False):
|
||||
default_guidance = api_params.get("default_guidance_scale", 3.5)
|
||||
if guidance_scale is not None:
|
||||
# Clamp to valid range (1-20 per FLUX Kontext Pro docs)
|
||||
payload["guidance_scale"] = max(1, min(20, guidance_scale))
|
||||
elif extra and "guidance_scale" in extra:
|
||||
payload["guidance_scale"] = max(1, min(20, extra["guidance_scale"]))
|
||||
else:
|
||||
payload["guidance_scale"] = default_guidance
|
||||
|
||||
# Seed parameter: Only add if model supports it
|
||||
if api_params.get("supports_seed", True): # Default to True for backward compatibility
|
||||
if seed is not None:
|
||||
payload["seed"] = seed
|
||||
else:
|
||||
payload["seed"] = -1 # Random seed (per API docs default)
|
||||
|
||||
# Add any extra parameters
|
||||
if extra:
|
||||
# Filter out parameters we've already handled
|
||||
handled_params = {"aspect_ratio", "resolution", "size", "seed", "guidance_scale"}
|
||||
for key, value in extra.items():
|
||||
if key not in handled_params:
|
||||
payload[key] = value
|
||||
|
||||
logger.info(f"[WaveSpeed Edit] Submitting edit request to {url} (model={model_path}, prompt_length={len(prompt)})")
|
||||
|
||||
# Make API call - REUSES same pattern as ImageGenerator
|
||||
try:
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.client._headers(),
|
||||
json=payload,
|
||||
timeout=120
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed Edit] API call failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed image editing failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text[:500],
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status
|
||||
status = data.get("status", "").lower()
|
||||
outputs = data.get("outputs") or []
|
||||
prediction_id = data.get("id")
|
||||
|
||||
logger.debug(
|
||||
f"[WaveSpeed Edit] Response: status='{status}', outputs_count={len(outputs)}, "
|
||||
f"prediction_id={prediction_id}"
|
||||
)
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if outputs and status == "completed":
|
||||
logger.info(f"[WaveSpeed Edit] Got immediate results from sync mode")
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=120)
|
||||
|
||||
# Sync mode returned "created" or "processing" - need to poll
|
||||
if not prediction_id:
|
||||
logger.error(f"[WaveSpeed Edit] Sync mode returned status '{status}' but no prediction ID")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed sync mode returned async response without prediction ID",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[WaveSpeed Edit] Sync mode returned status '{status}' with no outputs. "
|
||||
f"Polling for result (prediction_id: {prediction_id})"
|
||||
)
|
||||
|
||||
# Poll for result - REUSES polling utility
|
||||
result = self.client.poll_until_complete(
|
||||
prediction_id,
|
||||
timeout_seconds=180,
|
||||
interval_seconds=2.0,
|
||||
)
|
||||
|
||||
outputs = result.get("outputs") or []
|
||||
if not outputs:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned no outputs after polling"
|
||||
)
|
||||
|
||||
# Extract image URL from outputs - REUSE helper method
|
||||
image_url = self._extract_image_url(outputs)
|
||||
return self._download_image(image_url, timeout=120)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[WaveSpeed Edit] Unexpected error: {str(e)}", exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed edit API call failed: {str(e)}")
|
||||
|
||||
def _extract_image_url(self, outputs: list) -> str:
|
||||
"""Extract image URL from outputs - REUSES same pattern as ImageGenerator.
|
||||
|
||||
Args:
|
||||
outputs: Array of output URLs or objects
|
||||
|
||||
Returns:
|
||||
Image URL string
|
||||
|
||||
Raises:
|
||||
HTTPException: If output format is invalid
|
||||
"""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned no outputs",
|
||||
)
|
||||
|
||||
first_output = outputs[0]
|
||||
if isinstance(first_output, str):
|
||||
image_url = first_output
|
||||
elif isinstance(first_output, dict):
|
||||
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit output format not recognized",
|
||||
)
|
||||
|
||||
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed edit returned invalid image URL",
|
||||
)
|
||||
|
||||
return image_url
|
||||
|
||||
def _download_image(self, image_url: str, timeout: int = 120) -> bytes:
|
||||
"""Download image from URL - REUSES same pattern as ImageGenerator.
|
||||
|
||||
Args:
|
||||
image_url: URL to download from
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Image bytes
|
||||
|
||||
Raises:
|
||||
HTTPException: If download fails
|
||||
"""
|
||||
logger.info(f"[WaveSpeed Edit] Downloading edited image from: {image_url}")
|
||||
image_response = requests.get(image_url, timeout=timeout)
|
||||
|
||||
if image_response.status_code != 200:
|
||||
logger.error(f"[WaveSpeed Edit] Failed to download image: {image_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail=f"Failed to download edited image: {image_response.status_code}"
|
||||
)
|
||||
|
||||
logger.info(f"[WaveSpeed Edit] Successfully downloaded image ({len(image_response.content)} bytes)")
|
||||
return image_response.content
|
||||
|
||||
def _calculate_aspect_ratio(self, width: int, height: int) -> Optional[str]:
|
||||
"""Calculate aspect ratio string from dimensions.
|
||||
|
||||
Args:
|
||||
width: Image width
|
||||
height: Image height
|
||||
|
||||
Returns:
|
||||
Aspect ratio string (e.g., "16:9") or None if not standard
|
||||
"""
|
||||
# Common aspect ratios (includes FLUX Kontext Pro supported ratios)
|
||||
ratios = {
|
||||
(1, 1): "1:1",
|
||||
(3, 2): "3:2",
|
||||
(2, 3): "2:3",
|
||||
(3, 4): "3:4",
|
||||
(4, 3): "4:3",
|
||||
(4, 5): "4:5",
|
||||
(5, 4): "5:4",
|
||||
(9, 16): "9:16",
|
||||
(16, 9): "16:9",
|
||||
(21, 9): "21:9",
|
||||
(9, 21): "9:21", # FLUX Kontext Pro also supports 9:21
|
||||
}
|
||||
|
||||
# Calculate GCD to simplify ratio
|
||||
def gcd(a, b):
|
||||
while b:
|
||||
a, b = b, a % b
|
||||
return a
|
||||
|
||||
divisor = gcd(width, height)
|
||||
simplified = (width // divisor, height // divisor)
|
||||
|
||||
# Check if it matches a standard ratio (with some tolerance)
|
||||
for (w, h), ratio_str in ratios.items():
|
||||
# Allow small tolerance for rounding
|
||||
if abs(simplified[0] / simplified[1] - w / h) < 0.01:
|
||||
return ratio_str
|
||||
|
||||
# If no match, return None (model may not support custom aspect ratios)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available editing models and their information.
|
||||
|
||||
Returns:
|
||||
Dictionary of available models
|
||||
"""
|
||||
return cls.SUPPORTED_MODELS
|
||||
|
||||
@classmethod
|
||||
def get_models_by_tier(cls, tier: str) -> dict:
|
||||
"""Get models filtered by tier (budget, mid, premium).
|
||||
|
||||
Args:
|
||||
tier: Tier name ("budget", "mid", "premium")
|
||||
|
||||
Returns:
|
||||
Dictionary of models in the specified tier
|
||||
"""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if model_info.get("tier") == tier
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models_by_operation(cls, operation: str) -> dict:
|
||||
"""Get models that support a specific operation.
|
||||
|
||||
Args:
|
||||
operation: Operation type (e.g., "inpaint", "outpaint", "general_edit")
|
||||
|
||||
Returns:
|
||||
Dictionary of models supporting the operation
|
||||
"""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if operation in model_info.get("capabilities", [])
|
||||
}
|
||||
@@ -0,0 +1,367 @@
|
||||
"""WaveSpeed Face Swap Provider for Image Studio."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
from services.llm_providers.image_generation.base import (
|
||||
FaceSwapOptions,
|
||||
FaceSwapProvider,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("llm_providers.wavespeed_face_swap")
|
||||
|
||||
|
||||
class WaveSpeedFaceSwapProvider:
|
||||
"""WaveSpeed provider for face swap operations."""
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"image-face-swap-pro": {
|
||||
"model_path": "wavespeed-ai/image-face-swap-pro",
|
||||
"name": "Image Face Swap Pro",
|
||||
"description": "Instant online AI face swap for photos with no watermark, delivering realistic, shareable results in seconds.",
|
||||
"cost": 0.025,
|
||||
"tier": "mid",
|
||||
"capabilities": ["face_swap", "realistic_blending"],
|
||||
"features": ["Enhanced blending", "Realistic results", "Watermark-free"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"output_format": "jpeg",
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
"image-head-swap": {
|
||||
"model_path": "wavespeed-ai/image-head-swap",
|
||||
"name": "Image Head Swap",
|
||||
"description": "Instant online AI head & face swap for photos with no watermark. Replaces entire head (face + hair + outline) while preserving body, pose and background.",
|
||||
"cost": 0.025,
|
||||
"tier": "mid",
|
||||
"capabilities": ["head_swap", "full_head_replacement", "realistic_blending"],
|
||||
"features": ["Full head replacement", "Hair included", "Pose preservation", "Watermark-free"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"output_format": "jpeg",
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
"akool-face-swap": {
|
||||
"model_path": "akool/image-face-swap",
|
||||
"name": "Akool Image Face Swap",
|
||||
"description": "Powerful AI-powered face swapping with multi-face replacement for group photos. Seamlessly replaces faces with natural lighting and skin tone matching.",
|
||||
"cost": 0.16,
|
||||
"tier": "premium",
|
||||
"capabilities": ["face_swap", "multi_face", "realistic_blending", "face_enhancement"],
|
||||
"features": ["Multi-face swapping (up to 5)", "Face enhancement", "Group photos", "High-quality blending"],
|
||||
"max_faces": 5, # Supports 1-5 faces
|
||||
"api_params": {
|
||||
"uses_source_target_arrays": True, # Uses source_image and target_image arrays
|
||||
"supports_face_enhance": True,
|
||||
"supports_base64": True,
|
||||
"supports_sync": False, # May need polling
|
||||
},
|
||||
},
|
||||
"infinite-you": {
|
||||
"model_path": "wavespeed-ai/infinite-you",
|
||||
"name": "InfiniteYou",
|
||||
"description": "High-quality face swapping powered by ByteDance's zero-shot identity preservation technology. Maintains facial identity characteristics with exceptional realism.",
|
||||
"cost": 0.03,
|
||||
"tier": "mid",
|
||||
"capabilities": ["face_swap", "identity_preservation", "realistic_blending"],
|
||||
"features": ["Zero-shot learning", "Identity preservation", "High-quality results", "Fast processing"],
|
||||
"max_faces": 1,
|
||||
"api_params": {
|
||||
"uses_source_target_names": True, # Uses source_image and target_image (not image/face_image)
|
||||
"target_is_base": True, # target_image is the base image (where face will be swapped)
|
||||
"source_is_face": True, # source_image is the face to swap in
|
||||
"supports_seed": True, # Supports seed parameter
|
||||
"supports_base64": True,
|
||||
"supports_sync": True,
|
||||
},
|
||||
},
|
||||
# Placeholder for additional models (will be added as docs are provided)
|
||||
# "image-face-swap": {...}, # Basic version ($0.01)
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.client = WaveSpeedClient()
|
||||
|
||||
def _validate_options(self, options: FaceSwapOptions) -> None:
|
||||
"""Validate face swap options."""
|
||||
if not options.base_image_base64:
|
||||
raise ValueError("base_image_base64 is required")
|
||||
if not options.face_image_base64:
|
||||
raise ValueError("face_image_base64 is required")
|
||||
|
||||
# Validate model
|
||||
if options.model and options.model not in self.SUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Unsupported model: {options.model}. "
|
||||
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
||||
)
|
||||
|
||||
def _extract_image_url(self, data_url: str) -> str:
|
||||
"""Extract image URL from data URL or return as-is if already a URL."""
|
||||
if data_url.startswith("data:image"):
|
||||
# It's a data URL, we'll need to upload it
|
||||
return data_url
|
||||
return data_url
|
||||
|
||||
def _upload_image_if_needed(self, image_data: str) -> str:
|
||||
"""Upload image if it's a base64 data URL, otherwise return URL."""
|
||||
if image_data.startswith("data:image"):
|
||||
# Extract base64 data
|
||||
header, encoded = image_data.split(",", 1)
|
||||
image_bytes = base64.b64decode(encoded)
|
||||
|
||||
# Upload to temporary storage (or use WaveSpeed upload endpoint if available)
|
||||
# For now, we'll return the data URL and let the API handle it
|
||||
# In production, you might want to upload to S3/CloudFlare first
|
||||
return image_data
|
||||
return image_data
|
||||
|
||||
def _call_wavespeed_face_swap_api(
|
||||
self, options: FaceSwapOptions, model_info: Dict[str, Any]
|
||||
) -> ImageGenerationResult:
|
||||
"""Call WaveSpeed face swap API."""
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
|
||||
model_path = model_info["model_path"]
|
||||
api_params = model_info.get("api_params", {})
|
||||
uses_source_target_arrays = api_params.get("uses_source_target_arrays", False)
|
||||
|
||||
# Prepare images - extract base64 if data URI
|
||||
base_image = options.base_image_base64
|
||||
if base_image.startswith("data:image"):
|
||||
# Keep as data URI - API should accept it
|
||||
pass
|
||||
elif not base_image.startswith("http"):
|
||||
# Assume it's base64, convert to data URI
|
||||
base_image = f"data:image/png;base64,{base_image}"
|
||||
|
||||
face_image = options.face_image_base64
|
||||
if face_image.startswith("data:image"):
|
||||
# Keep as data URI
|
||||
pass
|
||||
elif not face_image.startswith("http"):
|
||||
# Assume it's base64, convert to data URI
|
||||
face_image = f"data:image/png;base64,{face_image}"
|
||||
|
||||
# Build API payload - handle different API formats
|
||||
uses_source_target_names = api_params.get("uses_source_target_names", False)
|
||||
|
||||
if uses_source_target_arrays:
|
||||
# Akool format: uses source_image and target_image as arrays
|
||||
# For single face swap: source_image is the new face, target_image is reference from main image
|
||||
# Since we only have one face_image, we'll use it as source and the base_image as target reference
|
||||
payload = {
|
||||
"image": base_image,
|
||||
"source_image": [face_image], # Array of source faces (1-5) - the new face to swap in
|
||||
"target_image": [base_image], # Array of target faces (1-5) - reference from main image
|
||||
"face_enhance": api_params.get("supports_face_enhance", True), # Default to True for Akool
|
||||
"enable_base64_output": True,
|
||||
}
|
||||
|
||||
# Allow override from extra params
|
||||
if options.extra:
|
||||
if "source_image" in options.extra:
|
||||
payload["source_image"] = options.extra["source_image"]
|
||||
if "target_image" in options.extra:
|
||||
payload["target_image"] = options.extra["target_image"]
|
||||
if "face_enhance" in options.extra:
|
||||
payload["face_enhance"] = options.extra["face_enhance"]
|
||||
elif uses_source_target_names:
|
||||
# InfiniteYou format: uses source_image and target_image (single values, different names)
|
||||
# target_image = base image (where face will be swapped)
|
||||
# source_image = face image (face to swap in)
|
||||
payload = {
|
||||
"target_image": base_image, # Base image where face will be swapped
|
||||
"source_image": face_image, # Face to swap in
|
||||
"enable_base64_output": True,
|
||||
}
|
||||
|
||||
# Add seed if supported
|
||||
if api_params.get("supports_seed", False):
|
||||
seed = options.extra.get("seed") if options.extra else None
|
||||
payload["seed"] = seed if seed is not None else -1 # Default to -1 (random)
|
||||
|
||||
# Allow override from extra params
|
||||
if options.extra:
|
||||
if "source_image" in options.extra:
|
||||
payload["source_image"] = options.extra["source_image"]
|
||||
if "target_image" in options.extra:
|
||||
payload["target_image"] = options.extra["target_image"]
|
||||
if "seed" in options.extra and api_params.get("supports_seed", False):
|
||||
payload["seed"] = options.extra["seed"]
|
||||
else:
|
||||
# Standard format: uses image and face_image (single values)
|
||||
payload = {
|
||||
"image": base_image,
|
||||
"face_image": face_image,
|
||||
"output_format": api_params.get("output_format", "jpeg"),
|
||||
"enable_base64_output": True, # Always get base64 for our use case
|
||||
"enable_sync_mode": True, # Use sync mode for immediate results
|
||||
}
|
||||
|
||||
# Add any extra parameters (filter out already handled ones)
|
||||
if options.extra:
|
||||
handled_keys = {"source_image", "target_image", "face_enhance", "output_format", "enable_sync_mode", "seed"}
|
||||
for key, value in options.extra.items():
|
||||
if key not in handled_keys:
|
||||
payload[key] = value
|
||||
|
||||
url = f"{self.client.BASE_URL}/{model_path}"
|
||||
headers = self.client._headers()
|
||||
|
||||
logger.info(f"[Face Swap] Calling WaveSpeed API: {url}")
|
||||
logger.debug(f"[Face Swap] Payload keys: {list(payload.keys())}")
|
||||
|
||||
try:
|
||||
# Call API
|
||||
response = requests.post(url, headers=headers, json=payload, timeout=120)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"[Face Swap] API call failed: {response.status_code} {response.text}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "WaveSpeed face swap failed",
|
||||
"status_code": response.status_code,
|
||||
"response": response.text,
|
||||
},
|
||||
)
|
||||
|
||||
response_json = response.json()
|
||||
data = response_json.get("data") or response_json
|
||||
|
||||
# Check status - Akool uses different status values
|
||||
status = data.get("status", "").lower()
|
||||
# Akool uses "output" (singular), others use "outputs" (plural)
|
||||
outputs = data.get("outputs") or data.get("output") or []
|
||||
# Normalize to list if it's a single value
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs] if outputs else []
|
||||
|
||||
prediction_id = data.get("id")
|
||||
|
||||
# Handle completed status - Akool uses "succeeded", others use "completed"
|
||||
is_completed = status in ["completed", "succeeded"]
|
||||
|
||||
# Handle sync mode - result should be directly in outputs
|
||||
if outputs and is_completed:
|
||||
logger.info(f"[Face Swap] Got immediate results (status: {status})")
|
||||
# Extract image URL or base64
|
||||
output = outputs[0]
|
||||
if output.startswith("data:image") or output.startswith("http"):
|
||||
if output.startswith("http"):
|
||||
# Download from URL
|
||||
import requests
|
||||
img_response = requests.get(output, timeout=60)
|
||||
img_response.raise_for_status()
|
||||
image_bytes = img_response.content
|
||||
else:
|
||||
# Extract base64 from data URI
|
||||
image_bytes = base64.b64decode(output.split(",", 1)[1])
|
||||
else:
|
||||
# Assume it's base64 string
|
||||
image_bytes = base64.b64decode(output)
|
||||
elif prediction_id:
|
||||
# Need to poll
|
||||
logger.info(f"[Face Swap] Polling for result (prediction_id: {prediction_id}, status: {status})")
|
||||
result = self.client.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=1.0)
|
||||
# Check both outputs and output fields
|
||||
outputs = result.get("outputs") or result.get("output") or []
|
||||
if not isinstance(outputs, list):
|
||||
outputs = [outputs] if outputs else []
|
||||
if not outputs:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed face swap returned no outputs")
|
||||
output = outputs[0]
|
||||
if output.startswith("http"):
|
||||
import requests
|
||||
img_response = requests.get(output, timeout=60)
|
||||
img_response.raise_for_status()
|
||||
image_bytes = img_response.content
|
||||
elif output.startswith("data:image"):
|
||||
image_bytes = base64.b64decode(output.split(",", 1)[1])
|
||||
else:
|
||||
image_bytes = base64.b64decode(output)
|
||||
else:
|
||||
raise HTTPException(status_code=502, detail="WaveSpeed face swap response missing outputs and prediction ID")
|
||||
|
||||
# Get image dimensions
|
||||
img = Image.open(io.BytesIO(image_bytes))
|
||||
width, height = img.size
|
||||
|
||||
logger.info(f"[Face Swap] ✅ Successfully swapped face: {len(image_bytes)} bytes, {width}x{height}")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=image_bytes,
|
||||
width=width,
|
||||
height=height,
|
||||
provider="wavespeed",
|
||||
model=options.model or model_path,
|
||||
metadata={
|
||||
"model_path": model_path,
|
||||
"status": status,
|
||||
"created_at": data.get("created_at"),
|
||||
},
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Face Swap] API call failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Face swap failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
|
||||
"""Swap face in image using WaveSpeed models."""
|
||||
self._validate_options(options)
|
||||
|
||||
# Determine model
|
||||
model_id = options.model
|
||||
if not model_id:
|
||||
# Default to first available model
|
||||
model_id = list(self.SUPPORTED_MODELS.keys())[0]
|
||||
logger.info(f"[Face Swap] No model specified, using default: {model_id}")
|
||||
|
||||
model_info = self.SUPPORTED_MODELS[model_id]
|
||||
|
||||
# Call API
|
||||
return self._call_wavespeed_face_swap_api(options, model_info)
|
||||
|
||||
@classmethod
|
||||
def get_available_models(cls) -> dict:
|
||||
"""Get available face swap models and their information."""
|
||||
return cls.SUPPORTED_MODELS
|
||||
|
||||
@classmethod
|
||||
def get_models_by_tier(cls, tier: str) -> dict:
|
||||
"""Get models filtered by tier (budget, mid, premium)."""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if model_info.get("tier") == tier
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_models_by_capability(cls, capability: str) -> dict:
|
||||
"""Get models that support a specific capability."""
|
||||
return {
|
||||
model_id: model_info
|
||||
for model_id, model_info in cls.SUPPORTED_MODELS.items()
|
||||
if capability in model_info.get("capabilities", [])
|
||||
}
|
||||
@@ -8,11 +8,16 @@ from typing import Optional, Dict, Any
|
||||
from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
ImageEditOptions,
|
||||
ImageEditProvider,
|
||||
HuggingFaceImageProvider,
|
||||
GeminiImageProvider,
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
|
||||
@@ -47,6 +52,249 @@ def _get_provider(provider_name: str):
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||
"""Get editing provider instance.
|
||||
|
||||
Args:
|
||||
provider_name: Provider name ("wavespeed", "stability", etc.)
|
||||
|
||||
Returns:
|
||||
ImageEditProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider()
|
||||
# TODO: Add Stability edit provider if needed
|
||||
# elif provider_name == "stability":
|
||||
# return StabilityEditProvider()
|
||||
else:
|
||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""
|
||||
Reusable pre-flight validation helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for subscription checking
|
||||
operation_type: Type of operation (for logging)
|
||||
num_operations: Number of operations to validate (default: 1)
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails (subscription limits exceeded, etc.)
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for tracking
|
||||
provider: Provider name (e.g., "wavespeed", "stability")
|
||||
model: Model name used
|
||||
operation_type: Type of operation (for logging)
|
||||
result_bytes: Generated/processed image bytes
|
||||
cost: Cost of the operation
|
||||
prompt: Optional prompt text (for request size calculation)
|
||||
endpoint: API endpoint path (for logging)
|
||||
metadata: Optional additional metadata
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking information (current_calls, cost, etc.)
|
||||
"""
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {
|
||||
"current_calls": new_calls,
|
||||
"cost": cost,
|
||||
"total_cost": new_cost,
|
||||
}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
@@ -55,32 +303,13 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
options: Image generation options (provider, model, width, height, etc.)
|
||||
user_id: User ID for subscription checking (optional, but required for validation)
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image generation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
# PRE-FLIGHT VALIDATION: Reuse extracted helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-generation",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Generation]"
|
||||
)
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
|
||||
@@ -114,151 +343,39 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
provider = _get_provider(provider_name)
|
||||
result = provider.generate(image_options)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
has_image_bytes = bool(result.image_bytes) if result else False
|
||||
image_bytes_len = len(result.image_bytes) if (result and result.image_bytes) else 0
|
||||
logger.info(f"[Image Generation] Checking tracking conditions: user_id={user_id}, has_result={bool(result)}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
|
||||
# TRACK USAGE after successful API call - Reuse extracted helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Generation] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get cost from result metadata or calculate
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
|
||||
# Calculate cost from result metadata or estimate
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
if provider_name == "wavespeed":
|
||||
if result.model and "qwen" in result.model.lower():
|
||||
estimated_cost = 0.05
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
if provider_name == "wavespeed":
|
||||
if result.model and "qwen" in result.model.lower():
|
||||
estimated_cost = 0.05
|
||||
else:
|
||||
estimated_cost = 0.10 # ideogram-v3-turbo default
|
||||
elif provider_name == "stability":
|
||||
estimated_cost = 0.04
|
||||
else:
|
||||
estimated_cost = 0.05 # Default estimate
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + estimated_cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Create usage log
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint="/image-generation",
|
||||
method="POST",
|
||||
model_used=result.model or "unknown",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode("utf-8")),
|
||||
response_size=len(result.image_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"[Image Generation] ✅ Successfully tracked usage: user {user_id} -> image -> {new_calls} calls, ${estimated_cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider_name}
|
||||
├─ Actual Provider: {provider_name}
|
||||
├─ Model: {result.model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
estimated_cost = 0.10 # ideogram-v3-turbo default
|
||||
elif provider_name == "stability":
|
||||
estimated_cost = 0.04
|
||||
else:
|
||||
estimated_cost = 0.05 # Default estimate
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=result.model or "unknown",
|
||||
operation_type="image-generation",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Generation]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
|
||||
|
||||
@@ -290,32 +407,13 @@ def generate_character_image(
|
||||
Returns:
|
||||
bytes: Generated image bytes with consistent character
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image generation before API call
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Character Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1,
|
||||
)
|
||||
logger.info(f"[Character Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with character image generation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Character Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
logger.warning(f"[Character Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
# PRE-FLIGHT VALIDATION: Reuse extracted helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="character-image-generation",
|
||||
num_operations=1,
|
||||
log_prefix="[Character Image Generation]"
|
||||
)
|
||||
|
||||
# Generate character image via WaveSpeed
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
@@ -332,132 +430,26 @@ def generate_character_image(
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
has_image_bytes = bool(image_bytes) if image_bytes else False
|
||||
image_bytes_len = len(image_bytes) if image_bytes else 0
|
||||
logger.info(f"[Character Image Generation] Checking tracking conditions: user_id={user_id}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
|
||||
# TRACK USAGE after successful API call - Reuse extracted helper
|
||||
if user_id and image_bytes:
|
||||
logger.info(f"[Character Image Generation] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
from services.database import get_db as get_db_track
|
||||
db_track = next(get_db_track())
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Character image cost (same as ideogram-v3-turbo)
|
||||
estimated_cost = 0.10
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + estimated_cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Create usage log
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.STABILITY, # Image generation uses STABILITY provider
|
||||
endpoint="/image-generation/character",
|
||||
method="POST",
|
||||
model_used="ideogram-character",
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=estimated_cost,
|
||||
response_time=0.0,
|
||||
status_code=200,
|
||||
request_size=len(prompt.encode("utf-8")),
|
||||
response_size=len(image_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG
|
||||
print(f"""
|
||||
[SUBSCRIPTION] Image Generation (Character)
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: wavespeed
|
||||
├─ Actual Provider: wavespeed
|
||||
├─ Model: ideogram-character
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
logger.info(f"[Character Image Generation] ✅ Successfully tracked usage: user {user_id} -> {new_calls} calls, ${estimated_cost:.4f}")
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"[Character Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[Character Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
|
||||
|
||||
# Character image cost (same as ideogram-v3-turbo)
|
||||
estimated_cost = 0.10
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="wavespeed",
|
||||
model="ideogram-character",
|
||||
operation_type="character-image-generation",
|
||||
result_bytes=image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation/character",
|
||||
metadata=None,
|
||||
log_prefix="[Character Image Generation]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Character Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(image_bytes) if image_bytes else 0} bytes")
|
||||
|
||||
@@ -476,3 +468,210 @@ def generate_character_image(
|
||||
)
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate edited image - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
# If model is specified and starts with "wavespeed", use wavespeed provider
|
||||
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider (REUSES provider pattern)
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Image editing failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate face swap - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
image_base64=base_image_base64, # Use base image for validation
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
# Get model cost
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None, # Face swap doesn't use prompts
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Face swap failed",
|
||||
"message": str(api_error)
|
||||
}
|
||||
)
|
||||
|
||||
# 6. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
# Get cost from result metadata or estimate
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
# Fallback: estimate based on provider/model
|
||||
if provider_name == "wavespeed":
|
||||
# Default WaveSpeed edit cost
|
||||
estimated_cost = 0.02 # Default for most editing models
|
||||
else:
|
||||
estimated_cost = 0.05 # Default estimate
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=result.model or model or "unknown",
|
||||
operation_type="image-edit",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation/edit",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user