AI Researcher and Video Studio implementation complete

This commit is contained in:
ajaysi
2026-01-05 15:49:51 +05:30
parent b134e9dc7e
commit 0b63ae7fc1
200 changed files with 39535 additions and 1375 deletions

View File

@@ -1,4 +1,12 @@
from .base import ImageGenerationOptions, ImageGenerationResult, ImageGenerationProvider
from .base import (
ImageGenerationOptions,
ImageGenerationResult,
ImageGenerationProvider,
ImageEditOptions,
ImageEditProvider,
FaceSwapOptions,
FaceSwapProvider,
)
from .hf_provider import HuggingFaceImageProvider
from .gemini_provider import GeminiImageProvider
from .stability_provider import StabilityImageProvider
@@ -8,6 +16,10 @@ __all__ = [
"ImageGenerationOptions",
"ImageGenerationResult",
"ImageGenerationProvider",
"ImageEditOptions",
"ImageEditProvider",
"FaceSwapOptions",
"FaceSwapProvider",
"HuggingFaceImageProvider",
"GeminiImageProvider",
"StabilityImageProvider",

View File

@@ -28,6 +28,50 @@ class ImageGenerationResult:
metadata: Optional[Dict[str, Any]] = None
@dataclass
class ImageEditOptions:
"""Options for image editing operations."""
image_base64: str
prompt: str
operation: str # "general_edit", "inpaint", "outpaint", "remove_background", etc.
mask_base64: Optional[str] = None
negative_prompt: Optional[str] = None
model: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
guidance_scale: Optional[float] = None
steps: Optional[int] = None
seed: Optional[int] = None
extra: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API calls."""
result = {
"image_base64": self.image_base64,
"prompt": self.prompt,
"operation": self.operation,
}
if self.mask_base64:
result["mask_base64"] = self.mask_base64
if self.negative_prompt:
result["negative_prompt"] = self.negative_prompt
if self.model:
result["model"] = self.model
if self.width:
result["width"] = self.width
if self.height:
result["height"] = self.height
if self.guidance_scale is not None:
result["guidance_scale"] = self.guidance_scale
if self.steps:
result["steps"] = self.steps
if self.seed is not None:
result["seed"] = self.seed
if self.extra:
result.update(self.extra)
return result
class ImageGenerationProvider(Protocol):
"""Protocol for image generation providers."""
@@ -35,3 +79,44 @@ class ImageGenerationProvider(Protocol):
...
@dataclass
class FaceSwapOptions:
"""Options for face swap operations."""
base_image_base64: str # Image to swap face into
face_image_base64: str # Face to swap
model: Optional[str] = None
target_face_index: Optional[int] = None # For multi-face images (0 = largest)
target_gender: Optional[str] = None # "all", "female", "male" (for some models)
extra: Optional[Dict[str, Any]] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary for API calls."""
result = {
"base_image_base64": self.base_image_base64,
"face_image_base64": self.face_image_base64,
}
if self.model:
result["model"] = self.model
if self.target_face_index is not None:
result["target_face_index"] = self.target_face_index
if self.target_gender:
result["target_gender"] = self.target_gender
if self.extra:
result.update(self.extra)
return result
class ImageEditProvider(Protocol):
"""Protocol for image editing providers."""
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
...
class FaceSwapProvider(Protocol):
"""Protocol for face swap providers."""
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
...

View File

@@ -0,0 +1,691 @@
"""WaveSpeed AI image editing provider (14 editing models)."""
import io
import os
import requests
from typing import Optional
from PIL import Image
from fastapi import HTTPException
from .base import ImageEditProvider, ImageEditOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("wavespeed.edit_provider")
class WaveSpeedEditProvider(ImageEditProvider):
"""WaveSpeed AI image editing provider supporting 14 editing models.
REUSES: WaveSpeedClient, model registry pattern, result format
"""
# Model registry - populated with WaveSpeed editing models
SUPPORTED_MODELS = {
"qwen-edit": {
"model_path": "wavespeed-ai/qwen-image/edit",
"name": "Qwen Image Edit",
"description": "20B MMDiT image-to-image model offering precise bilingual (Chinese & English) text edits while preserving style. Single-image editing with style preservation.",
"cost": 0.02, # Same as Plus version
"max_resolution": (1536, 1536), # Based on docs: similar to Plus
"capabilities": ["general_edit", "style_transfer", "text_edit"],
"tier": "budget",
"supports_multi_image": False, # Single image only (uses "image" not "images")
"supports_controlnet": False, # Not mentioned in docs
"languages": ["en", "zh"],
"api_params": {
"uses_size": True, # Uses "size" parameter (width*height)
"uses_aspect_ratio": False,
"uses_resolution": False,
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
"default_output_format": "jpeg", # Per API docs: default is "jpeg"
"supports_seed": True, # Per API docs: seed parameter supported
}
},
"qwen-edit-plus": {
"model_path": "wavespeed-ai/qwen-image/edit-plus",
"name": "Qwen Image Edit Plus",
"description": "20B MMDiT image editor with multi-image editing, single-image consistency and native ControlNet support. Bilingual (CN/EN) text editing, appearance-level and semantic-level edits.",
"cost": 0.02,
"max_resolution": (1536, 1536), # Based on docs: 256-1536 per dimension
"capabilities": ["general_edit", "style_transfer", "text_edit", "multi_image"],
"tier": "budget",
"supports_multi_image": True, # Up to 3 reference images
"supports_controlnet": True,
"languages": ["en", "zh"],
"api_params": {
"uses_size": True, # Uses "size" parameter (width*height)
"uses_aspect_ratio": False,
"uses_resolution": False,
"uses_image_singular": False, # Uses "images" (array)
"supports_seed": True, # Seed parameter supported (default for Qwen models)
}
},
"nano-banana-pro-edit-ultra": {
"model_path": "google/nano-banana-pro/edit-ultra",
"name": "Google Nano Banana Pro Edit Ultra",
"description": "High-resolution image editing with 4K/8K native output. Natural language instructions, multilingual text support. Premium quality editing for professional marketing and high-res work.",
"cost": 0.15, # 4K - from enhancement proposal
"cost_8k": 0.18, # 8K - from enhancement proposal
"max_resolution": (8192, 8192), # 8K support
"capabilities": ["general_edit", "high_res", "professional", "typography"],
"tier": "premium",
"supports_multi_image": True, # Up to 14 reference images
"supports_controlnet": False,
"languages": ["en", "multilingual"],
"api_params": {
"uses_size": False, # Uses aspect_ratio and resolution instead
"uses_aspect_ratio": True, # "1:1", "16:9", etc.
"uses_resolution": True, # "4k" or "8k"
"max_images": 14,
"default_output_format": "png", # Per API docs: default is "png"
"supports_seed": False, # Per API docs: no seed parameter
}
},
"seedream-v4.5-edit": {
"model_path": "bytedance/seedream-v4.5/edit",
"name": "Bytedance Seedream V4.5 Edit",
"description": "Preserves facial features, lighting, and color tone from reference images, delivering professional, high-fidelity edits up to 4K with strong prompt adherence. Reference-faithful editing with multi-image support.",
"cost": 0.04, # Per generated image
"max_resolution": (4096, 4096), # 4K support (1024-4096 per dimension)
"capabilities": ["general_edit", "portrait_retouching", "fashion_edit", "product_edit", "multi_image"],
"tier": "mid",
"supports_multi_image": True, # Up to 10 reference images
"supports_controlnet": False,
"languages": ["en"],
"api_params": {
"uses_size": True, # Uses "size" parameter (width*height format, 1024-4096 per dimension)
"uses_aspect_ratio": False,
"uses_resolution": False,
"max_images": 10,
"default_output_format": "png",
"supports_seed": False, # No seed parameter in API docs (Seedream V4.5)
}
},
"flux-kontext-pro": {
"model_path": "wavespeed-ai/flux-kontext-pro",
"name": "FLUX Kontext Pro",
"description": "FLUX.1 Kontext [pro] offers improved prompt adherence and accurate typography generation for consistent, high-quality edits at speed. Typography-focused editing with improved prompt adherence.",
"cost": 0.04, # From enhancement proposal
"max_resolution": (2048, 2048), # Estimated, not specified in docs
"capabilities": ["general_edit", "typography", "text_edit", "style_transfer"],
"tier": "mid",
"supports_multi_image": False, # Single image only (uses "image" not "images")
"supports_controlnet": False,
"languages": ["en"],
"api_params": {
"uses_size": False, # Uses aspect_ratio instead
"uses_aspect_ratio": True, # Aspect ratio as string (e.g., "16:9", "1:1")
"uses_resolution": False,
"uses_image_singular": True, # Uses "image" (singular) not "images" (array)
"supports_guidance_scale": True, # Has guidance_scale parameter (default 3.5, range 1-20)
"default_guidance_scale": 3.5, # Per API docs
"supports_seed": False, # No seed parameter in API docs
}
},
# TODO: Add remaining 9 models once docs are provided
}
def __init__(self, api_key: Optional[str] = None):
"""Initialize WaveSpeed edit provider.
Args:
api_key: WaveSpeed API key (falls back to env var if not provided)
"""
self.api_key = api_key or os.getenv("WAVESPEED_API_KEY")
if not self.api_key:
raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")
# REUSE: Same client as generation provider
self.client = WaveSpeedClient(api_key=self.api_key)
logger.info("[WaveSpeed Edit Provider] Initialized with %d models",
len(self.SUPPORTED_MODELS))
def _validate_options(self, options: ImageEditOptions) -> None:
"""Validate editing options.
Args:
options: Image editing options
Raises:
ValueError: If options are invalid
"""
model = options.model or list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None
if not model:
raise ValueError("No model specified and no default model available")
if model not in self.SUPPORTED_MODELS:
raise ValueError(
f"Unsupported model: {model}. "
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
)
model_info = self.SUPPORTED_MODELS[model]
max_width, max_height = model_info.get("max_resolution", (4096, 4096))
if options.width and options.width > max_width:
raise ValueError(
f"Width {options.width} exceeds maximum {max_width} for model {model}"
)
if options.height and options.height > max_height:
raise ValueError(
f"Height {options.height} exceeds maximum {max_height} for model {model}"
)
if not options.prompt or len(options.prompt.strip()) == 0:
raise ValueError("Prompt cannot be empty")
if not options.image_base64:
raise ValueError("Image base64 cannot be empty")
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
"""Edit image using WaveSpeed AI models.
Args:
options: Image editing options
Returns:
ImageGenerationResult with edited image
Raises:
ValueError: If options are invalid
RuntimeError: If editing fails
"""
# Validate options
self._validate_options(options)
# Determine model
model = options.model or (list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None)
if not model:
raise ValueError("No model available for editing")
model_info = self.SUPPORTED_MODELS[model]
model_path = model_info["model_path"]
logger.info("[WaveSpeed Edit] Starting edit: model=%s, operation=%s, prompt=%s",
model, options.operation, options.prompt[:100])
try:
# Prepare extra parameters based on model capabilities
extra_params = options.extra or {}
# Add model-specific parameters if needed
api_params = model_info.get("api_params", {})
if api_params.get("uses_resolution", False):
# For Nano Banana: determine resolution from dimensions or use default
if options.width and options.height:
if options.width >= 4096 or options.height >= 4096:
extra_params["resolution"] = "8k"
else:
extra_params["resolution"] = "4k"
elif "resolution" not in extra_params:
extra_params["resolution"] = "4k" # Default to 4K
if api_params.get("uses_aspect_ratio", False) and not extra_params.get("aspect_ratio"):
# Calculate aspect ratio if dimensions provided
if options.width and options.height:
aspect_ratio = self._calculate_aspect_ratio(options.width, options.height)
if aspect_ratio:
extra_params["aspect_ratio"] = aspect_ratio
# Call WaveSpeed API for editing
result = self._call_wavespeed_edit_api(
model_path=model_path,
image_base64=options.image_base64,
prompt=options.prompt,
operation=options.operation,
mask_base64=options.mask_base64,
negative_prompt=options.negative_prompt,
width=options.width,
height=options.height,
guidance_scale=options.guidance_scale,
steps=options.steps,
seed=options.seed,
extra=extra_params
)
# Extract image bytes from result
if isinstance(result, bytes):
image_bytes = result
elif isinstance(result, dict) and "image" in result:
image_bytes = result["image"]
elif isinstance(result, dict) and "image_bytes" in result:
image_bytes = result["image_bytes"]
else:
raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")
# Load image to get dimensions
image = Image.open(io.BytesIO(image_bytes))
width, height = image.size
# Calculate estimated cost - handle resolution-based pricing
estimated_cost = model_info.get("cost", 0.02)
if api_params.get("uses_resolution", False):
# Check if 8K was requested
resolution = extra_params.get("resolution", "4k")
if resolution == "8k" and "cost_8k" in model_info:
estimated_cost = model_info["cost_8k"]
logger.info("[WaveSpeed Edit] ✅ Successfully edited image: %d bytes, %dx%d",
len(image_bytes), width, height)
# REUSE: Same result format as generation
return ImageGenerationResult(
image_bytes=image_bytes,
width=width,
height=height,
provider="wavespeed",
model=model,
seed=options.seed,
metadata={
"provider": "wavespeed",
"model": model,
"model_name": model_info.get("name", model),
"operation": options.operation,
"prompt": options.prompt,
"negative_prompt": options.negative_prompt,
"estimated_cost": estimated_cost,
"tier": model_info.get("tier", "mid"),
}
)
except Exception as e:
logger.error("[WaveSpeed Edit] ❌ Error editing image: %s", str(e), exc_info=True)
raise RuntimeError(f"WaveSpeed edit failed: {str(e)}")
def _call_wavespeed_edit_api(
self,
model_path: str,
image_base64: str,
prompt: str,
operation: str,
mask_base64: Optional[str] = None,
negative_prompt: Optional[str] = None,
width: Optional[int] = None,
height: Optional[int] = None,
guidance_scale: Optional[float] = None,
steps: Optional[int] = None,
seed: Optional[int] = None,
extra: Optional[dict] = None
) -> bytes:
"""Call WaveSpeed API for image editing.
REUSES: Same pattern as ImageGenerator.generate_image()
Args:
model_path: Full model path (e.g., "wavespeed-ai/qwen-image/edit-plus")
image_base64: Base64-encoded input image
prompt: Edit instruction prompt
operation: Type of operation
mask_base64: Optional mask for inpainting
negative_prompt: Optional negative prompt
width: Optional target width
height: Optional target height
guidance_scale: Optional guidance scale (not used by all models)
steps: Optional number of steps (not used by all models)
seed: Optional seed
extra: Optional extra parameters
Returns:
Edited image bytes
Raises:
RuntimeError: If API call fails
"""
import requests
from fastapi import HTTPException
# Build URL - REUSES same pattern as ImageGenerator
url = f"{self.client.BASE_URL}/{model_path}"
# Prepare images array - WaveSpeed expects array of image strings
# Format: base64 strings or data URIs (data:image/png;base64,...)
# For Qwen Image Edit Plus: supports up to 3 reference images
images = []
# Add main image - check if it's already a data URI or just base64
if image_base64.startswith("data:image"):
# Already a data URI
images.append(image_base64)
else:
# Assume it's base64, convert to data URI
# Try to detect format from base64 or default to PNG
images.append(f"data:image/png;base64,{image_base64}")
# If mask is provided, add it as second image
# Note: Some models may need mask in different format - will adjust per model
if mask_base64:
if mask_base64.startswith("data:image"):
images.append(mask_base64)
else:
images.append(f"data:image/png;base64,{mask_base64}")
# Get model info to determine API parameter structure
model_info = self.SUPPORTED_MODELS.get(model_path.split("/")[-1] if "/" in model_path else model_path)
if not model_info:
# Fallback: try to find model by matching path
for model_id, info in self.SUPPORTED_MODELS.items():
if info["model_path"] == model_path:
model_info = info
break
if not model_info:
raise ValueError(f"Model info not found for: {model_path}")
api_params = model_info.get("api_params", {})
# Build payload - following WaveSpeed API structure
# Note: output_format default varies by model (PNG for most, but can be JPEG)
default_output_format = api_params.get("default_output_format", "png")
# Some models use "image" (singular) instead of "images" (array)
uses_image_singular = api_params.get("uses_image_singular", False)
payload = {
"prompt": prompt,
"enable_sync_mode": True, # Use sync mode for immediate results
"enable_base64_output": False, # Get URL, then download
"output_format": default_output_format,
}
# Add image(s) based on model API format
if uses_image_singular:
# Models like Qwen Edit (basic) use "image" (singular)
# Use first image only (single image editing)
if images:
payload["image"] = images[0]
else:
raise ValueError("At least one image is required")
else:
# Models like Qwen Edit Plus, Nano Banana use "images" (array)
payload["images"] = images
# Allow override of output_format from extra params
if extra and "output_format" in extra:
payload["output_format"] = extra["output_format"]
# Model-specific parameter handling
if api_params.get("uses_size", True):
# Models like Qwen Edit Plus use "size" parameter (width*height format)
if width and height:
payload["size"] = f"{width}*{height}"
elif width:
payload["size"] = f"{width}*{width}" # Square if only width provided
elif height:
payload["size"] = f"{height}*{height}" # Square if only height provided
if api_params.get("uses_aspect_ratio", False):
# Models like Nano Banana and FLUX Kontext Pro use "aspect_ratio" parameter
if width and height:
# Calculate aspect ratio from dimensions
aspect_ratio = self._calculate_aspect_ratio(width, height)
if aspect_ratio:
payload["aspect_ratio"] = aspect_ratio
elif extra and "aspect_ratio" in extra:
payload["aspect_ratio"] = extra["aspect_ratio"]
if api_params.get("uses_resolution", False):
# Models like Nano Banana use "resolution" parameter ("4k" or "8k")
if extra and "resolution" in extra:
payload["resolution"] = extra["resolution"]
else:
# Default to 4K, or 8K if dimensions suggest high-res
if width and height and (width >= 4096 or height >= 4096):
payload["resolution"] = "8k"
else:
payload["resolution"] = "4k" # Default to 4K per API docs
# Add optional parameters (model-agnostic)
# Guidance scale: Only add if model supports it (e.g., FLUX Kontext Pro)
if api_params.get("supports_guidance_scale", False):
default_guidance = api_params.get("default_guidance_scale", 3.5)
if guidance_scale is not None:
# Clamp to valid range (1-20 per FLUX Kontext Pro docs)
payload["guidance_scale"] = max(1, min(20, guidance_scale))
elif extra and "guidance_scale" in extra:
payload["guidance_scale"] = max(1, min(20, extra["guidance_scale"]))
else:
payload["guidance_scale"] = default_guidance
# Seed parameter: Only add if model supports it
if api_params.get("supports_seed", True): # Default to True for backward compatibility
if seed is not None:
payload["seed"] = seed
else:
payload["seed"] = -1 # Random seed (per API docs default)
# Add any extra parameters
if extra:
# Filter out parameters we've already handled
handled_params = {"aspect_ratio", "resolution", "size", "seed", "guidance_scale"}
for key, value in extra.items():
if key not in handled_params:
payload[key] = value
logger.info(f"[WaveSpeed Edit] Submitting edit request to {url} (model={model_path}, prompt_length={len(prompt)})")
# Make API call - REUSES same pattern as ImageGenerator
try:
response = requests.post(
url,
headers=self.client._headers(),
json=payload,
timeout=120
)
if response.status_code != 200:
logger.error(f"[WaveSpeed Edit] API call failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed image editing failed",
"status_code": response.status_code,
"response": response.text[:500],
},
)
response_json = response.json()
data = response_json.get("data") or response_json
# Check status
status = data.get("status", "").lower()
outputs = data.get("outputs") or []
prediction_id = data.get("id")
logger.debug(
f"[WaveSpeed Edit] Response: status='{status}', outputs_count={len(outputs)}, "
f"prediction_id={prediction_id}"
)
# Handle sync mode - result should be directly in outputs
if outputs and status == "completed":
logger.info(f"[WaveSpeed Edit] Got immediate results from sync mode")
image_url = self._extract_image_url(outputs)
return self._download_image(image_url, timeout=120)
# Sync mode returned "created" or "processing" - need to poll
if not prediction_id:
logger.error(f"[WaveSpeed Edit] Sync mode returned status '{status}' but no prediction ID")
raise HTTPException(
status_code=502,
detail="WaveSpeed sync mode returned async response without prediction ID",
)
logger.info(
f"[WaveSpeed Edit] Sync mode returned status '{status}' with no outputs. "
f"Polling for result (prediction_id: {prediction_id})"
)
# Poll for result - REUSES polling utility
result = self.client.poll_until_complete(
prediction_id,
timeout_seconds=180,
interval_seconds=2.0,
)
outputs = result.get("outputs") or []
if not outputs:
raise HTTPException(
status_code=502,
detail="WaveSpeed edit returned no outputs after polling"
)
# Extract image URL from outputs - REUSE helper method
image_url = self._extract_image_url(outputs)
return self._download_image(image_url, timeout=120)
except HTTPException:
raise
except Exception as e:
logger.error(f"[WaveSpeed Edit] Unexpected error: {str(e)}", exc_info=True)
raise RuntimeError(f"WaveSpeed edit API call failed: {str(e)}")
def _extract_image_url(self, outputs: list) -> str:
"""Extract image URL from outputs - REUSES same pattern as ImageGenerator.
Args:
outputs: Array of output URLs or objects
Returns:
Image URL string
Raises:
HTTPException: If output format is invalid
"""
if not isinstance(outputs, list) or len(outputs) == 0:
raise HTTPException(
status_code=502,
detail="WaveSpeed edit returned no outputs",
)
first_output = outputs[0]
if isinstance(first_output, str):
image_url = first_output
elif isinstance(first_output, dict):
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
else:
raise HTTPException(
status_code=502,
detail="WaveSpeed edit output format not recognized",
)
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
raise HTTPException(
status_code=502,
detail="WaveSpeed edit returned invalid image URL",
)
return image_url
def _download_image(self, image_url: str, timeout: int = 120) -> bytes:
"""Download image from URL - REUSES same pattern as ImageGenerator.
Args:
image_url: URL to download from
timeout: Request timeout in seconds
Returns:
Image bytes
Raises:
HTTPException: If download fails
"""
logger.info(f"[WaveSpeed Edit] Downloading edited image from: {image_url}")
image_response = requests.get(image_url, timeout=timeout)
if image_response.status_code != 200:
logger.error(f"[WaveSpeed Edit] Failed to download image: {image_response.status_code}")
raise HTTPException(
status_code=502,
detail=f"Failed to download edited image: {image_response.status_code}"
)
logger.info(f"[WaveSpeed Edit] Successfully downloaded image ({len(image_response.content)} bytes)")
return image_response.content
def _calculate_aspect_ratio(self, width: int, height: int) -> Optional[str]:
"""Calculate aspect ratio string from dimensions.
Args:
width: Image width
height: Image height
Returns:
Aspect ratio string (e.g., "16:9") or None if not standard
"""
# Common aspect ratios (includes FLUX Kontext Pro supported ratios)
ratios = {
(1, 1): "1:1",
(3, 2): "3:2",
(2, 3): "2:3",
(3, 4): "3:4",
(4, 3): "4:3",
(4, 5): "4:5",
(5, 4): "5:4",
(9, 16): "9:16",
(16, 9): "16:9",
(21, 9): "21:9",
(9, 21): "9:21", # FLUX Kontext Pro also supports 9:21
}
# Calculate GCD to simplify ratio
def gcd(a, b):
while b:
a, b = b, a % b
return a
divisor = gcd(width, height)
simplified = (width // divisor, height // divisor)
# Check if it matches a standard ratio (with some tolerance)
for (w, h), ratio_str in ratios.items():
# Allow small tolerance for rounding
if abs(simplified[0] / simplified[1] - w / h) < 0.01:
return ratio_str
# If no match, return None (model may not support custom aspect ratios)
return None
@classmethod
def get_available_models(cls) -> dict:
"""Get available editing models and their information.
Returns:
Dictionary of available models
"""
return cls.SUPPORTED_MODELS
@classmethod
def get_models_by_tier(cls, tier: str) -> dict:
"""Get models filtered by tier (budget, mid, premium).
Args:
tier: Tier name ("budget", "mid", "premium")
Returns:
Dictionary of models in the specified tier
"""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if model_info.get("tier") == tier
}
@classmethod
def get_models_by_operation(cls, operation: str) -> dict:
"""Get models that support a specific operation.
Args:
operation: Operation type (e.g., "inpaint", "outpaint", "general_edit")
Returns:
Dictionary of models supporting the operation
"""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if operation in model_info.get("capabilities", [])
}

View File

@@ -0,0 +1,367 @@
"""WaveSpeed Face Swap Provider for Image Studio."""
from __future__ import annotations
import base64
import io
from typing import Optional, Dict, Any
from PIL import Image
from services.llm_providers.image_generation.base import (
FaceSwapOptions,
FaceSwapProvider,
ImageGenerationResult,
)
from services.wavespeed.client import WaveSpeedClient
from utils.logger_utils import get_service_logger
logger = get_service_logger("llm_providers.wavespeed_face_swap")
class WaveSpeedFaceSwapProvider:
"""WaveSpeed provider for face swap operations."""
SUPPORTED_MODELS = {
"image-face-swap-pro": {
"model_path": "wavespeed-ai/image-face-swap-pro",
"name": "Image Face Swap Pro",
"description": "Instant online AI face swap for photos with no watermark, delivering realistic, shareable results in seconds.",
"cost": 0.025,
"tier": "mid",
"capabilities": ["face_swap", "realistic_blending"],
"features": ["Enhanced blending", "Realistic results", "Watermark-free"],
"max_faces": 1,
"api_params": {
"output_format": "jpeg",
"supports_base64": True,
"supports_sync": True,
},
},
"image-head-swap": {
"model_path": "wavespeed-ai/image-head-swap",
"name": "Image Head Swap",
"description": "Instant online AI head & face swap for photos with no watermark. Replaces entire head (face + hair + outline) while preserving body, pose and background.",
"cost": 0.025,
"tier": "mid",
"capabilities": ["head_swap", "full_head_replacement", "realistic_blending"],
"features": ["Full head replacement", "Hair included", "Pose preservation", "Watermark-free"],
"max_faces": 1,
"api_params": {
"output_format": "jpeg",
"supports_base64": True,
"supports_sync": True,
},
},
"akool-face-swap": {
"model_path": "akool/image-face-swap",
"name": "Akool Image Face Swap",
"description": "Powerful AI-powered face swapping with multi-face replacement for group photos. Seamlessly replaces faces with natural lighting and skin tone matching.",
"cost": 0.16,
"tier": "premium",
"capabilities": ["face_swap", "multi_face", "realistic_blending", "face_enhancement"],
"features": ["Multi-face swapping (up to 5)", "Face enhancement", "Group photos", "High-quality blending"],
"max_faces": 5, # Supports 1-5 faces
"api_params": {
"uses_source_target_arrays": True, # Uses source_image and target_image arrays
"supports_face_enhance": True,
"supports_base64": True,
"supports_sync": False, # May need polling
},
},
"infinite-you": {
"model_path": "wavespeed-ai/infinite-you",
"name": "InfiniteYou",
"description": "High-quality face swapping powered by ByteDance's zero-shot identity preservation technology. Maintains facial identity characteristics with exceptional realism.",
"cost": 0.03,
"tier": "mid",
"capabilities": ["face_swap", "identity_preservation", "realistic_blending"],
"features": ["Zero-shot learning", "Identity preservation", "High-quality results", "Fast processing"],
"max_faces": 1,
"api_params": {
"uses_source_target_names": True, # Uses source_image and target_image (not image/face_image)
"target_is_base": True, # target_image is the base image (where face will be swapped)
"source_is_face": True, # source_image is the face to swap in
"supports_seed": True, # Supports seed parameter
"supports_base64": True,
"supports_sync": True,
},
},
# Placeholder for additional models (will be added as docs are provided)
# "image-face-swap": {...}, # Basic version ($0.01)
}
def __init__(self):
self.client = WaveSpeedClient()
def _validate_options(self, options: FaceSwapOptions) -> None:
"""Validate face swap options."""
if not options.base_image_base64:
raise ValueError("base_image_base64 is required")
if not options.face_image_base64:
raise ValueError("face_image_base64 is required")
# Validate model
if options.model and options.model not in self.SUPPORTED_MODELS:
raise ValueError(
f"Unsupported model: {options.model}. "
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
)
def _extract_image_url(self, data_url: str) -> str:
"""Extract image URL from data URL or return as-is if already a URL."""
if data_url.startswith("data:image"):
# It's a data URL, we'll need to upload it
return data_url
return data_url
def _upload_image_if_needed(self, image_data: str) -> str:
"""Upload image if it's a base64 data URL, otherwise return URL."""
if image_data.startswith("data:image"):
# Extract base64 data
header, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
# Upload to temporary storage (or use WaveSpeed upload endpoint if available)
# For now, we'll return the data URL and let the API handle it
# In production, you might want to upload to S3/CloudFlare first
return image_data
return image_data
def _call_wavespeed_face_swap_api(
self, options: FaceSwapOptions, model_info: Dict[str, Any]
) -> ImageGenerationResult:
"""Call WaveSpeed face swap API."""
import requests
from fastapi import HTTPException
model_path = model_info["model_path"]
api_params = model_info.get("api_params", {})
uses_source_target_arrays = api_params.get("uses_source_target_arrays", False)
# Prepare images - extract base64 if data URI
base_image = options.base_image_base64
if base_image.startswith("data:image"):
# Keep as data URI - API should accept it
pass
elif not base_image.startswith("http"):
# Assume it's base64, convert to data URI
base_image = f"data:image/png;base64,{base_image}"
face_image = options.face_image_base64
if face_image.startswith("data:image"):
# Keep as data URI
pass
elif not face_image.startswith("http"):
# Assume it's base64, convert to data URI
face_image = f"data:image/png;base64,{face_image}"
# Build API payload - handle different API formats
uses_source_target_names = api_params.get("uses_source_target_names", False)
if uses_source_target_arrays:
# Akool format: uses source_image and target_image as arrays
# For single face swap: source_image is the new face, target_image is reference from main image
# Since we only have one face_image, we'll use it as source and the base_image as target reference
payload = {
"image": base_image,
"source_image": [face_image], # Array of source faces (1-5) - the new face to swap in
"target_image": [base_image], # Array of target faces (1-5) - reference from main image
"face_enhance": api_params.get("supports_face_enhance", True), # Default to True for Akool
"enable_base64_output": True,
}
# Allow override from extra params
if options.extra:
if "source_image" in options.extra:
payload["source_image"] = options.extra["source_image"]
if "target_image" in options.extra:
payload["target_image"] = options.extra["target_image"]
if "face_enhance" in options.extra:
payload["face_enhance"] = options.extra["face_enhance"]
elif uses_source_target_names:
# InfiniteYou format: uses source_image and target_image (single values, different names)
# target_image = base image (where face will be swapped)
# source_image = face image (face to swap in)
payload = {
"target_image": base_image, # Base image where face will be swapped
"source_image": face_image, # Face to swap in
"enable_base64_output": True,
}
# Add seed if supported
if api_params.get("supports_seed", False):
seed = options.extra.get("seed") if options.extra else None
payload["seed"] = seed if seed is not None else -1 # Default to -1 (random)
# Allow override from extra params
if options.extra:
if "source_image" in options.extra:
payload["source_image"] = options.extra["source_image"]
if "target_image" in options.extra:
payload["target_image"] = options.extra["target_image"]
if "seed" in options.extra and api_params.get("supports_seed", False):
payload["seed"] = options.extra["seed"]
else:
# Standard format: uses image and face_image (single values)
payload = {
"image": base_image,
"face_image": face_image,
"output_format": api_params.get("output_format", "jpeg"),
"enable_base64_output": True, # Always get base64 for our use case
"enable_sync_mode": True, # Use sync mode for immediate results
}
# Add any extra parameters (filter out already handled ones)
if options.extra:
handled_keys = {"source_image", "target_image", "face_enhance", "output_format", "enable_sync_mode", "seed"}
for key, value in options.extra.items():
if key not in handled_keys:
payload[key] = value
url = f"{self.client.BASE_URL}/{model_path}"
headers = self.client._headers()
logger.info(f"[Face Swap] Calling WaveSpeed API: {url}")
logger.debug(f"[Face Swap] Payload keys: {list(payload.keys())}")
try:
# Call API
response = requests.post(url, headers=headers, json=payload, timeout=120)
if response.status_code != 200:
logger.error(f"[Face Swap] API call failed: {response.status_code} {response.text}")
raise HTTPException(
status_code=502,
detail={
"error": "WaveSpeed face swap failed",
"status_code": response.status_code,
"response": response.text,
},
)
response_json = response.json()
data = response_json.get("data") or response_json
# Check status - Akool uses different status values
status = data.get("status", "").lower()
# Akool uses "output" (singular), others use "outputs" (plural)
outputs = data.get("outputs") or data.get("output") or []
# Normalize to list if it's a single value
if not isinstance(outputs, list):
outputs = [outputs] if outputs else []
prediction_id = data.get("id")
# Handle completed status - Akool uses "succeeded", others use "completed"
is_completed = status in ["completed", "succeeded"]
# Handle sync mode - result should be directly in outputs
if outputs and is_completed:
logger.info(f"[Face Swap] Got immediate results (status: {status})")
# Extract image URL or base64
output = outputs[0]
if output.startswith("data:image") or output.startswith("http"):
if output.startswith("http"):
# Download from URL
import requests
img_response = requests.get(output, timeout=60)
img_response.raise_for_status()
image_bytes = img_response.content
else:
# Extract base64 from data URI
image_bytes = base64.b64decode(output.split(",", 1)[1])
else:
# Assume it's base64 string
image_bytes = base64.b64decode(output)
elif prediction_id:
# Need to poll
logger.info(f"[Face Swap] Polling for result (prediction_id: {prediction_id}, status: {status})")
result = self.client.poll_until_complete(prediction_id, timeout_seconds=120, interval_seconds=1.0)
# Check both outputs and output fields
outputs = result.get("outputs") or result.get("output") or []
if not isinstance(outputs, list):
outputs = [outputs] if outputs else []
if not outputs:
raise HTTPException(status_code=502, detail="WaveSpeed face swap returned no outputs")
output = outputs[0]
if output.startswith("http"):
import requests
img_response = requests.get(output, timeout=60)
img_response.raise_for_status()
image_bytes = img_response.content
elif output.startswith("data:image"):
image_bytes = base64.b64decode(output.split(",", 1)[1])
else:
image_bytes = base64.b64decode(output)
else:
raise HTTPException(status_code=502, detail="WaveSpeed face swap response missing outputs and prediction ID")
# Get image dimensions
img = Image.open(io.BytesIO(image_bytes))
width, height = img.size
logger.info(f"[Face Swap] ✅ Successfully swapped face: {len(image_bytes)} bytes, {width}x{height}")
return ImageGenerationResult(
image_bytes=image_bytes,
width=width,
height=height,
provider="wavespeed",
model=options.model or model_path,
metadata={
"model_path": model_path,
"status": status,
"created_at": data.get("created_at"),
},
)
except HTTPException:
raise
except Exception as e:
logger.error(f"[Face Swap] API call failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=502,
detail={
"error": "Face swap failed",
"message": str(e)
}
)
def swap_face(self, options: FaceSwapOptions) -> ImageGenerationResult:
"""Swap face in image using WaveSpeed models."""
self._validate_options(options)
# Determine model
model_id = options.model
if not model_id:
# Default to first available model
model_id = list(self.SUPPORTED_MODELS.keys())[0]
logger.info(f"[Face Swap] No model specified, using default: {model_id}")
model_info = self.SUPPORTED_MODELS[model_id]
# Call API
return self._call_wavespeed_face_swap_api(options, model_info)
@classmethod
def get_available_models(cls) -> dict:
"""Get available face swap models and their information."""
return cls.SUPPORTED_MODELS
@classmethod
def get_models_by_tier(cls, tier: str) -> dict:
"""Get models filtered by tier (budget, mid, premium)."""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if model_info.get("tier") == tier
}
@classmethod
def get_models_by_capability(cls, capability: str) -> dict:
"""Get models that support a specific capability."""
return {
model_id: model_info
for model_id, model_info in cls.SUPPORTED_MODELS.items()
if capability in model_info.get("capabilities", [])
}

View File

@@ -8,11 +8,16 @@ from typing import Optional, Dict, Any
from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
ImageEditOptions,
ImageEditProvider,
HuggingFaceImageProvider,
GeminiImageProvider,
StabilityImageProvider,
WaveSpeedImageProvider,
)
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
from utils.logger_utils import get_service_logger
@@ -47,6 +52,249 @@ def _get_provider(provider_name: str):
raise ValueError(f"Unknown image provider: {provider_name}")
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
"""Get face swap provider by name."""
if provider_name == "wavespeed":
return WaveSpeedFaceSwapProvider()
raise ValueError(f"Unknown face swap provider: {provider_name}")
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
"""Get editing provider instance.
Args:
provider_name: Provider name ("wavespeed", "stability", etc.)
Returns:
ImageEditProvider instance
Raises:
ValueError: If provider is not supported
"""
if provider_name == "wavespeed":
return WaveSpeedEditProvider()
# TODO: Add Stability edit provider if needed
# elif provider_name == "stability":
# return StabilityEditProvider()
else:
raise ValueError(f"Unknown edit provider: {provider_name}")
def _validate_image_operation(
user_id: Optional[str],
operation_type: str = "image-generation",
num_operations: int = 1,
log_prefix: str = "[Image Generation]"
) -> None:
"""
Reusable pre-flight validation helper for all image operations.
Extracted from generate_image() to be reused across all image operation functions.
Args:
user_id: User ID for subscription checking
operation_type: Type of operation (for logging)
num_operations: Number of operations to validate (default: 1)
log_prefix: Logging prefix for operation-specific logs
Raises:
HTTPException: If validation fails (subscription limits exceeded, etc.)
"""
if not user_id:
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
return
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=num_operations
)
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
def _track_image_operation_usage(
user_id: str,
provider: str,
model: str,
operation_type: str,
result_bytes: bytes,
cost: float,
prompt: Optional[str] = None,
endpoint: str = "/image-generation",
metadata: Optional[Dict[str, Any]] = None,
log_prefix: str = "[Image Generation]"
) -> Dict[str, Any]:
"""
Reusable usage tracking helper for all image operations.
Extracted from generate_image() to be reused across all image operation functions.
Args:
user_id: User ID for tracking
provider: Provider name (e.g., "wavespeed", "stability")
model: Model name used
operation_type: Type of operation (for logging)
result_bytes: Generated/processed image bytes
cost: Cost of the operation
prompt: Optional prompt text (for request size calculation)
endpoint: API endpoint path (for logging)
metadata: Optional additional metadata
log_prefix: Logging prefix for operation-specific logs
Returns:
Dictionary with tracking information (current_calls, cost, etc.)
"""
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get current values before update
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
# Update image calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Determine API provider based on actual provider
api_provider = APIProvider.STABILITY # Default for image generation
# Create usage log
request_size = len(prompt.encode("utf-8")) if prompt else 0
usage_log = APIUsageLog(
user_id=user_id,
provider=api_provider,
endpoint=endpoint,
method="POST",
model_used=model or "unknown",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=cost,
response_time=0.0,
status_code=200,
request_size=request_size,
response_size=len(result_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else ''
# Get related stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
operation_name = operation_type.replace("-", " ").title()
print(f"""
[SUBSCRIPTION] {operation_name}
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider}
├─ Actual Provider: {provider}
├─ Model: {model or 'unknown'}
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
return {
"current_calls": new_calls,
"cost": cost,
"total_cost": new_cost,
}
except Exception as track_error:
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
db_track.rollback()
return {}
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
return {}
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
"""Generate image with pre-flight validation.
@@ -55,32 +303,13 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
options: Image generation options (provider, model, width, height, etc.)
user_id: User ID for subscription checking (optional, but required for validation)
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"[Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
logger.info(f"[Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image generation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
else:
logger.warning(f"[Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
# PRE-FLIGHT VALIDATION: Reuse extracted helper
_validate_image_operation(
user_id=user_id,
operation_type="image-generation",
num_operations=1,
log_prefix="[Image Generation]"
)
opts = options or {}
provider_name = _select_provider(opts.get("provider"))
@@ -114,151 +343,39 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
provider = _get_provider(provider_name)
result = provider.generate(image_options)
# TRACK USAGE after successful API call
has_image_bytes = bool(result.image_bytes) if result else False
image_bytes_len = len(result.image_bytes) if (result and result.image_bytes) else 0
logger.info(f"[Image Generation] Checking tracking conditions: user_id={user_id}, has_result={bool(result)}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
# TRACK USAGE after successful API call - Reuse extracted helper
if user_id and result and result.image_bytes:
logger.info(f"[Image Generation] ✅ API call successful, tracking usage for user {user_id}")
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Get cost from result metadata or calculate
estimated_cost = 0.0
if result.metadata and "estimated_cost" in result.metadata:
estimated_cost = float(result.metadata["estimated_cost"])
# Calculate cost from result metadata or estimate
estimated_cost = 0.0
if result.metadata and "estimated_cost" in result.metadata:
estimated_cost = float(result.metadata["estimated_cost"])
else:
# Fallback: estimate based on provider/model
if provider_name == "wavespeed":
if result.model and "qwen" in result.model.lower():
estimated_cost = 0.05
else:
# Fallback: estimate based on provider/model
if provider_name == "wavespeed":
if result.model and "qwen" in result.model.lower():
estimated_cost = 0.05
else:
estimated_cost = 0.10 # ideogram-v3-turbo default
elif provider_name == "stability":
estimated_cost = 0.04
else:
estimated_cost = 0.05 # Default estimate
# Get current values before update
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
# Update image calls and cost
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Determine API provider based on actual provider
api_provider = APIProvider.STABILITY # Default for image generation
# Create usage log
usage_log = APIUsageLog(
user_id=user_id,
provider=api_provider,
endpoint="/image-generation",
method="POST",
model_used=result.model or "unknown",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(prompt.encode("utf-8")),
response_size=len(result.image_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else ''
# Get related stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
logger.info(f"[Image Generation] ✅ Successfully tracked usage: user {user_id} -> image -> {new_calls} calls, ${estimated_cost:.4f}")
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
print(f"""
[SUBSCRIPTION] Image Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider_name}
├─ Actual Provider: {provider_name}
├─ Model: {result.model or 'unknown'}
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
except Exception as track_error:
logger.error(f"[Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"[Image Generation] Full traceback: {traceback.format_exc()}")
estimated_cost = 0.10 # ideogram-v3-turbo default
elif provider_name == "stability":
estimated_cost = 0.04
else:
estimated_cost = 0.05 # Default estimate
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider=provider_name,
model=result.model or "unknown",
operation_type="image-generation",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-generation",
metadata=result.metadata,
log_prefix="[Image Generation]"
)
else:
logger.warning(f"[Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
@@ -290,32 +407,13 @@ def generate_character_image(
Returns:
bytes: Generated image bytes with consistent character
"""
# PRE-FLIGHT VALIDATION: Validate image generation before API call
if user_id:
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
from fastapi import HTTPException
logger.info(f"[Character Image Generation] 🔍 Starting pre-flight validation for user_id={user_id}")
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=1,
)
logger.info(f"[Character Image Generation] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with character image generation")
except HTTPException as http_ex:
# Re-raise immediately - don't proceed with API call
logger.error(f"[Character Image Generation] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
finally:
db.close()
else:
logger.warning(f"[Character Image Generation] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
# PRE-FLIGHT VALIDATION: Reuse extracted helper
_validate_image_operation(
user_id=user_id,
operation_type="character-image-generation",
num_operations=1,
log_prefix="[Character Image Generation]"
)
# Generate character image via WaveSpeed
from services.wavespeed.client import WaveSpeedClient
@@ -332,132 +430,26 @@ def generate_character_image(
timeout=timeout,
)
# TRACK USAGE after successful API call
has_image_bytes = bool(image_bytes) if image_bytes else False
image_bytes_len = len(image_bytes) if image_bytes else 0
logger.info(f"[Character Image Generation] Checking tracking conditions: user_id={user_id}, has_image_bytes={has_image_bytes}, image_bytes_len={image_bytes_len}")
# TRACK USAGE after successful API call - Reuse extracted helper
if user_id and image_bytes:
logger.info(f"[Character Image Generation] ✅ API call successful, tracking usage for user {user_id}")
try:
from services.database import get_db as get_db_track
db_track = next(get_db_track())
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get or create usage summary
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush()
# Character image cost (same as ideogram-v3-turbo)
estimated_cost = 0.10
current_calls_before = getattr(summary, "stability_calls", 0) or 0
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
new_calls = current_calls_before + 1
new_cost = current_cost_before + estimated_cost
# Use direct SQL UPDATE for dynamic attributes
from sqlalchemy import text as sql_text
update_query = sql_text("""
UPDATE usage_summaries
SET stability_calls = :new_calls,
stability_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_query, {
'new_calls': new_calls,
'new_cost': new_cost,
'user_id': user_id,
'period': current_period
})
# Update total cost
summary.total_cost = (summary.total_cost or 0.0) + estimated_cost
summary.total_calls = (summary.total_calls or 0) + 1
summary.updated_at = datetime.utcnow()
# Create usage log
usage_log = APIUsageLog(
user_id=user_id,
provider=APIProvider.STABILITY, # Image generation uses STABILITY provider
endpoint="/image-generation/character",
method="POST",
model_used="ideogram-character",
tokens_input=0,
tokens_output=0,
tokens_total=0,
cost_input=0.0,
cost_output=0.0,
cost_total=estimated_cost,
response_time=0.0,
status_code=200,
request_size=len(prompt.encode("utf-8")),
response_size=len(image_bytes),
billing_period=current_period,
)
db_track.add(usage_log)
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else ''
# Get related stats
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
db_track.commit()
# UNIFIED SUBSCRIPTION LOG
print(f"""
[SUBSCRIPTION] Image Generation (Character)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: wavespeed
├─ Actual Provider: wavespeed
├─ Model: ideogram-character
├─ Calls: {current_calls_before}{new_calls} / {image_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
sys.stdout.flush()
logger.info(f"[Character Image Generation] ✅ Successfully tracked usage: user {user_id} -> {new_calls} calls, ${estimated_cost:.4f}")
except Exception as track_error:
logger.error(f"[Character Image Generation] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
import traceback
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
db_track.rollback()
finally:
db_track.close()
except Exception as usage_error:
logger.error(f"[Character Image Generation] ❌ Failed to track usage: {usage_error}", exc_info=True)
import traceback
logger.error(f"[Character Image Generation] Full traceback: {traceback.format_exc()}")
# Character image cost (same as ideogram-v3-turbo)
estimated_cost = 0.10
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider="wavespeed",
model="ideogram-character",
operation_type="character-image-generation",
result_bytes=image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-generation/character",
metadata=None,
log_prefix="[Character Image Generation]"
)
else:
logger.warning(f"[Character Image Generation] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(image_bytes) if image_bytes else 0} bytes")
@@ -476,3 +468,210 @@ def generate_character_image(
)
def generate_image_edit(
image_base64: str,
prompt: str,
operation: str = "general_edit",
model: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None
) -> ImageGenerationResult:
"""
Generate edited image - REUSES validation and tracking helpers.
Args:
image_base64: Base64-encoded input image (or data URI)
prompt: Edit instruction prompt
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
model: Model ID to use (default: auto-select based on provider)
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
user_id: User ID for validation and tracking
Returns:
ImageGenerationResult with edited image
Raises:
HTTPException: If validation fails or editing fails
ValueError: If options are invalid
"""
# 1. REUSE: Validation helper
_validate_image_operation(
user_id=user_id,
operation_type="image-edit",
num_operations=1,
log_prefix="[Image Edit]"
)
# 2. Determine provider from model or default to wavespeed
opts = options or {}
provider_name = opts.get("provider", "wavespeed")
# If model is specified and starts with "wavespeed", use wavespeed provider
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
provider_name = "wavespeed"
# 3. Get provider (REUSES provider pattern)
try:
provider = _get_edit_provider(provider_name)
except ValueError as e:
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
raise ValueError(f"Unsupported edit provider: {provider_name}")
# 4. Prepare edit options
edit_options = ImageEditOptions(
image_base64=image_base64,
prompt=prompt,
operation=operation,
mask_base64=opts.get("mask_base64"),
negative_prompt=opts.get("negative_prompt"),
model=model,
width=opts.get("width"),
height=opts.get("height"),
guidance_scale=opts.get("guidance_scale"),
steps=opts.get("steps"),
seed=opts.get("seed"),
extra=opts.get("extra"),
)
# 5. Edit image
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
try:
result = provider.edit(edit_options)
except Exception as e:
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=502,
detail={
"error": "Image editing failed",
"message": str(e)
}
)
def generate_face_swap(
base_image_base64: str,
face_image_base64: str,
model: Optional[str] = None,
options: Optional[Dict[str, Any]] = None,
user_id: Optional[str] = None
) -> ImageGenerationResult:
"""
Generate face swap - REUSES validation and tracking helpers.
Args:
base_image_base64: Base64-encoded base image (or data URI)
face_image_base64: Base64-encoded face image to swap (or data URI)
model: Model ID to use (default: auto-select)
options: Additional options (target_face_index, target_gender, etc.)
user_id: User ID for validation and tracking
Returns:
ImageGenerationResult with swapped face image
Raises:
HTTPException: If validation fails or face swap fails
ValueError: If options are invalid
"""
# 1. REUSE: Validation helper
_validate_image_operation(
user_id=user_id,
operation_type="face-swap",
image_base64=base_image_base64, # Use base image for validation
log_prefix="[Face Swap]"
)
# 2. Get provider (default to wavespeed)
provider_name = "wavespeed"
provider = _get_face_swap_provider(provider_name)
# 3. Prepare options
face_swap_options = FaceSwapOptions(
base_image_base64=base_image_base64,
face_image_base64=face_image_base64,
model=model,
target_face_index=options.get("target_face_index") if options else None,
target_gender=options.get("target_gender") if options else None,
extra=options,
)
# 4. Swap face
try:
result = provider.swap_face(face_swap_options)
# 5. REUSE: Tracking helper
if user_id and result and result.image_bytes:
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
# Get model cost
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider=provider_name,
model=model_id,
operation_type="face-swap",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=None, # Face swap doesn't use prompts
endpoint="/image-studio/face-swap/process",
metadata={
"base_image_size": len(base_image_base64),
"face_image_size": len(face_image_base64),
},
log_prefix="[Face Swap]"
)
else:
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
return result
except HTTPException:
raise
except Exception as api_error:
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
raise HTTPException(
status_code=502,
detail={
"error": "Face swap failed",
"message": str(api_error)
}
)
# 6. REUSE: Tracking helper
if user_id and result and result.image_bytes:
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
# Get cost from result metadata or estimate
estimated_cost = 0.0
if result.metadata and "estimated_cost" in result.metadata:
estimated_cost = float(result.metadata["estimated_cost"])
else:
# Fallback: estimate based on provider/model
if provider_name == "wavespeed":
# Default WaveSpeed edit cost
estimated_cost = 0.02 # Default for most editing models
else:
estimated_cost = 0.05 # Default estimate
# Reuse tracking helper
_track_image_operation_usage(
user_id=user_id,
provider=provider_name,
model=result.model or model or "unknown",
operation_type="image-edit",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-generation/edit",
metadata=result.metadata,
log_prefix="[Image Edit]"
)
else:
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result.image_bytes else 0} bytes")
return result