692 lines
29 KiB
Python
692 lines
29 KiB
Python
"""WaveSpeed AI image editing provider (14 editing models)."""
|
|
|
|
import io
|
|
import os
|
|
import requests
|
|
from typing import Optional
|
|
from PIL import Image
|
|
from fastapi import HTTPException
|
|
|
|
from .base import ImageEditProvider, ImageEditOptions, ImageGenerationResult
|
|
from services.wavespeed.client import WaveSpeedClient
|
|
from utils.logger_utils import get_service_logger
|
|
|
|
|
|
logger = get_service_logger("wavespeed.edit_provider")
|
|
|
|
|
|
class WaveSpeedEditProvider(ImageEditProvider):
    """Image editing provider backed by WaveSpeed AI models.

    Shares the WaveSpeedClient, the model-registry layout and the
    ImageGenerationResult format with the generation provider.

    Note: the registry below currently defines five models; the remaining
    entries are still pending (see the TODO at the end of SUPPORTED_MODELS).
    """

    # Registry of editing models, keyed by the public model id.
    #
    # Insertion order matters: when a caller does not name a model, the
    # first key of this dict is used as the default.
    #
    # Per-model fields:
    #   model_path      path appended to the client's base URL
    #   cost / cost_8k  estimated cost used in result metadata
    #   max_resolution  (max_width, max_height) enforced during validation
    #   capabilities    tags matched by get_models_by_operation()
    #   tier            tag matched by get_models_by_tier()
    #   api_params      flags that select the shape of the request payload
    SUPPORTED_MODELS = {
        "qwen-edit": {
            "model_path": "wavespeed-ai/qwen-image/edit",
            "name": "Qwen Image Edit",
            "description": (
                "20B MMDiT image-to-image model offering precise bilingual "
                "(Chinese & English) text edits while preserving style. "
                "Single-image editing with style preservation."
            ),
            "cost": 0.02,  # same price as the Plus variant
            "max_resolution": (1536, 1536),  # per docs: similar to Plus
            "capabilities": ["general_edit", "style_transfer", "text_edit"],
            "tier": "budget",
            "supports_multi_image": False,  # single image only
            "supports_controlnet": False,  # not mentioned in the docs
            "languages": ["en", "zh"],
            "api_params": {
                "uses_size": True,  # "size" parameter, "width*height"
                "uses_aspect_ratio": False,
                "uses_resolution": False,
                "uses_image_singular": True,  # payload key is "image", not "images"
                "default_output_format": "jpeg",  # API default per docs
                "supports_seed": True,  # seed parameter supported per docs
            },
        },
        "qwen-edit-plus": {
            "model_path": "wavespeed-ai/qwen-image/edit-plus",
            "name": "Qwen Image Edit Plus",
            "description": (
                "20B MMDiT image editor with multi-image editing, single-image "
                "consistency and native ControlNet support. Bilingual (CN/EN) text "
                "editing, appearance-level and semantic-level edits."
            ),
            "cost": 0.02,
            "max_resolution": (1536, 1536),  # per docs: 256-1536 per dimension
            "capabilities": ["general_edit", "style_transfer", "text_edit", "multi_image"],
            "tier": "budget",
            "supports_multi_image": True,  # up to 3 reference images
            "supports_controlnet": True,
            "languages": ["en", "zh"],
            "api_params": {
                "uses_size": True,  # "size" parameter, "width*height"
                "uses_aspect_ratio": False,
                "uses_resolution": False,
                "uses_image_singular": False,  # payload key is "images" (array)
                "supports_seed": True,  # seed supported (default for Qwen models)
            },
        },
        "nano-banana-pro-edit-ultra": {
            "model_path": "google/nano-banana-pro/edit-ultra",
            "name": "Google Nano Banana Pro Edit Ultra",
            "description": (
                "High-resolution image editing with 4K/8K native output. Natural "
                "language instructions, multilingual text support. Premium quality "
                "editing for professional marketing and high-res work."
            ),
            "cost": 0.15,  # 4K price, from the enhancement proposal
            "cost_8k": 0.18,  # 8K price, from the enhancement proposal
            "max_resolution": (8192, 8192),  # 8K support
            "capabilities": ["general_edit", "high_res", "professional", "typography"],
            "tier": "premium",
            "supports_multi_image": True,  # up to 14 reference images
            "supports_controlnet": False,
            "languages": ["en", "multilingual"],
            "api_params": {
                "uses_size": False,  # aspect_ratio + resolution are used instead
                "uses_aspect_ratio": True,  # "1:1", "16:9", etc.
                "uses_resolution": True,  # "4k" or "8k"
                "max_images": 14,
                "default_output_format": "png",  # API default per docs
                "supports_seed": False,  # no seed parameter per docs
            },
        },
        "seedream-v4.5-edit": {
            "model_path": "bytedance/seedream-v4.5/edit",
            "name": "Bytedance Seedream V4.5 Edit",
            "description": (
                "Preserves facial features, lighting, and color tone from reference "
                "images, delivering professional, high-fidelity edits up to 4K with "
                "strong prompt adherence. Reference-faithful editing with "
                "multi-image support."
            ),
            "cost": 0.04,  # per generated image
            "max_resolution": (4096, 4096),  # 4K support (1024-4096 per dimension)
            "capabilities": [
                "general_edit",
                "portrait_retouching",
                "fashion_edit",
                "product_edit",
                "multi_image",
            ],
            "tier": "mid",
            "supports_multi_image": True,  # up to 10 reference images
            "supports_controlnet": False,
            "languages": ["en"],
            "api_params": {
                "uses_size": True,  # "size" parameter, "width*height", 1024-4096 per dimension
                "uses_aspect_ratio": False,
                "uses_resolution": False,
                "max_images": 10,
                "default_output_format": "png",
                "supports_seed": False,  # no seed parameter in the API docs
            },
        },
        "flux-kontext-pro": {
            "model_path": "wavespeed-ai/flux-kontext-pro",
            "name": "FLUX Kontext Pro",
            "description": (
                "FLUX.1 Kontext [pro] offers improved prompt adherence and accurate "
                "typography generation for consistent, high-quality edits at speed. "
                "Typography-focused editing with improved prompt adherence."
            ),
            "cost": 0.04,  # from the enhancement proposal
            "max_resolution": (2048, 2048),  # estimate; not specified in the docs
            "capabilities": ["general_edit", "typography", "text_edit", "style_transfer"],
            "tier": "mid",
            "supports_multi_image": False,  # single image only
            "supports_controlnet": False,
            "languages": ["en"],
            "api_params": {
                "uses_size": False,  # aspect_ratio is used instead
                "uses_aspect_ratio": True,  # string such as "16:9" or "1:1"
                "uses_resolution": False,
                "uses_image_singular": True,  # payload key is "image", not "images"
                "supports_guidance_scale": True,  # default 3.5, range 1-20
                "default_guidance_scale": 3.5,  # per API docs
                "supports_seed": False,  # no seed parameter in the API docs
            },
        },
        # TODO: Add remaining 9 models once docs are provided
    }
|
|
|
|
def __init__(self, api_key: Optional[str] = None):
    """Create the provider and the WaveSpeed client it talks through.

    Args:
        api_key: WaveSpeed API key. When omitted (or empty) the
            WAVESPEED_API_KEY environment variable is used instead.

    Raises:
        ValueError: If no API key is available from either source.
    """
    resolved_key = api_key or os.getenv("WAVESPEED_API_KEY")
    self.api_key = resolved_key
    if not resolved_key:
        raise ValueError("WaveSpeed API key not found. Set WAVESPEED_API_KEY environment variable.")

    # Same client class the generation provider uses.
    self.client = WaveSpeedClient(api_key=resolved_key)

    model_count = len(self.SUPPORTED_MODELS)
    logger.info("[WaveSpeed Edit Provider] Initialized with %d models", model_count)
|
|
|
|
def _validate_options(self, options: ImageEditOptions) -> None:
|
|
"""Validate editing options.
|
|
|
|
Args:
|
|
options: Image editing options
|
|
|
|
Raises:
|
|
ValueError: If options are invalid
|
|
"""
|
|
model = options.model or list(self.SUPPORTED_MODELS.keys())[0] if self.SUPPORTED_MODELS else None
|
|
|
|
if not model:
|
|
raise ValueError("No model specified and no default model available")
|
|
|
|
if model not in self.SUPPORTED_MODELS:
|
|
raise ValueError(
|
|
f"Unsupported model: {model}. "
|
|
f"Supported models: {list(self.SUPPORTED_MODELS.keys())}"
|
|
)
|
|
|
|
model_info = self.SUPPORTED_MODELS[model]
|
|
max_width, max_height = model_info.get("max_resolution", (4096, 4096))
|
|
|
|
if options.width and options.width > max_width:
|
|
raise ValueError(
|
|
f"Width {options.width} exceeds maximum {max_width} for model {model}"
|
|
)
|
|
|
|
if options.height and options.height > max_height:
|
|
raise ValueError(
|
|
f"Height {options.height} exceeds maximum {max_height} for model {model}"
|
|
)
|
|
|
|
if not options.prompt or len(options.prompt.strip()) == 0:
|
|
raise ValueError("Prompt cannot be empty")
|
|
|
|
if not options.image_base64:
|
|
raise ValueError("Image base64 cannot be empty")
|
|
|
|
def edit(self, options: ImageEditOptions) -> ImageGenerationResult:
    """Edit an image with one of the registered WaveSpeed models.

    The options are validated, model-specific request parameters
    (``resolution`` / ``aspect_ratio``) are derived from the requested
    dimensions, the edit API is called, and the returned bytes are wrapped
    in an ImageGenerationResult.

    Args:
        options: Image editing options. ``options.extra`` is only read;
            derived parameters are written to a private copy so the
            caller's dict is never modified.

    Returns:
        ImageGenerationResult holding the edited image bytes, the decoded
        image dimensions and metadata (including an estimated cost).

    Raises:
        ValueError: If the options fail validation or no model is available.
        RuntimeError: If the API call or the result handling fails. The
            original exception is attached as ``__cause__``.
    """
    self._validate_options(options)

    # Fall back to the first registered model when none was requested.
    model = options.model or (
        next(iter(self.SUPPORTED_MODELS)) if self.SUPPORTED_MODELS else None
    )
    if not model:
        raise ValueError("No model available for editing")

    model_info = self.SUPPORTED_MODELS[model]
    model_path = model_info["model_path"]

    logger.info("[WaveSpeed Edit] Starting edit: model=%s, operation=%s, prompt=%s",
                model, options.operation, options.prompt[:100])

    try:
        # Copy before mutating: the derived values below must not leak back
        # into the caller's ``options.extra`` (which may be reused).
        extra_params = dict(options.extra or {})
        api_params = model_info.get("api_params", {})

        if api_params.get("uses_resolution", False):
            if options.width and options.height:
                # Explicit dimensions decide the resolution bucket here.
                wants_8k = options.width >= 4096 or options.height >= 4096
                extra_params["resolution"] = "8k" if wants_8k else "4k"
            elif "resolution" not in extra_params:
                extra_params["resolution"] = "4k"  # Default to 4K

        if api_params.get("uses_aspect_ratio", False) and not extra_params.get("aspect_ratio"):
            # Only derive a ratio when the caller did not provide one.
            if options.width and options.height:
                aspect_ratio = self._calculate_aspect_ratio(options.width, options.height)
                if aspect_ratio:
                    extra_params["aspect_ratio"] = aspect_ratio

        result = self._call_wavespeed_edit_api(
            model_path=model_path,
            image_base64=options.image_base64,
            prompt=options.prompt,
            operation=options.operation,
            mask_base64=options.mask_base64,
            negative_prompt=options.negative_prompt,
            width=options.width,
            height=options.height,
            guidance_scale=options.guidance_scale,
            steps=options.steps,
            seed=options.seed,
            extra=extra_params,
        )

        # The API helper returns raw bytes; dict shapes are tolerated too.
        if isinstance(result, bytes):
            image_bytes = result
        elif isinstance(result, dict) and "image" in result:
            image_bytes = result["image"]
        elif isinstance(result, dict) and "image_bytes" in result:
            image_bytes = result["image_bytes"]
        else:
            raise ValueError(f"Unexpected response format from WaveSpeed API: {type(result)}")

        # Decode only to read the dimensions; the context manager releases
        # the decoder's resources as soon as the size is known.
        with Image.open(io.BytesIO(image_bytes)) as image:
            width, height = image.size

        # Resolution-based pricing: use the 8K price when 8K was requested.
        estimated_cost = model_info.get("cost", 0.02)
        if api_params.get("uses_resolution", False):
            resolution = extra_params.get("resolution", "4k")
            if resolution == "8k" and "cost_8k" in model_info:
                estimated_cost = model_info["cost_8k"]

        logger.info("[WaveSpeed Edit] ✅ Successfully edited image: %d bytes, %dx%d",
                    len(image_bytes), width, height)

        # Same result format as image generation.
        return ImageGenerationResult(
            image_bytes=image_bytes,
            width=width,
            height=height,
            provider="wavespeed",
            model=model,
            seed=options.seed,
            metadata={
                "provider": "wavespeed",
                "model": model,
                "model_name": model_info.get("name", model),
                "operation": options.operation,
                "prompt": options.prompt,
                "negative_prompt": options.negative_prompt,
                "estimated_cost": estimated_cost,
                "tier": model_info.get("tier", "mid"),
            },
        )

    except Exception as e:
        logger.error("[WaveSpeed Edit] ❌ Error editing image: %s", str(e), exc_info=True)
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"WaveSpeed edit failed: {str(e)}") from e
|
|
|
|
def _call_wavespeed_edit_api(
|
|
self,
|
|
model_path: str,
|
|
image_base64: str,
|
|
prompt: str,
|
|
operation: str,
|
|
mask_base64: Optional[str] = None,
|
|
negative_prompt: Optional[str] = None,
|
|
width: Optional[int] = None,
|
|
height: Optional[int] = None,
|
|
guidance_scale: Optional[float] = None,
|
|
steps: Optional[int] = None,
|
|
seed: Optional[int] = None,
|
|
extra: Optional[dict] = None
|
|
) -> bytes:
|
|
"""Call WaveSpeed API for image editing.
|
|
|
|
REUSES: Same pattern as ImageGenerator.generate_image()
|
|
|
|
Args:
|
|
model_path: Full model path (e.g., "wavespeed-ai/qwen-image/edit-plus")
|
|
image_base64: Base64-encoded input image
|
|
prompt: Edit instruction prompt
|
|
operation: Type of operation
|
|
mask_base64: Optional mask for inpainting
|
|
negative_prompt: Optional negative prompt
|
|
width: Optional target width
|
|
height: Optional target height
|
|
guidance_scale: Optional guidance scale (not used by all models)
|
|
steps: Optional number of steps (not used by all models)
|
|
seed: Optional seed
|
|
extra: Optional extra parameters
|
|
|
|
Returns:
|
|
Edited image bytes
|
|
|
|
Raises:
|
|
RuntimeError: If API call fails
|
|
"""
|
|
import requests
|
|
from fastapi import HTTPException
|
|
|
|
# Build URL - REUSES same pattern as ImageGenerator
|
|
url = f"{self.client.BASE_URL}/{model_path}"
|
|
|
|
# Prepare images array - WaveSpeed expects array of image strings
|
|
# Format: base64 strings or data URIs (data:image/png;base64,...)
|
|
# For Qwen Image Edit Plus: supports up to 3 reference images
|
|
images = []
|
|
|
|
# Add main image - check if it's already a data URI or just base64
|
|
if image_base64.startswith("data:image"):
|
|
# Already a data URI
|
|
images.append(image_base64)
|
|
else:
|
|
# Assume it's base64, convert to data URI
|
|
# Try to detect format from base64 or default to PNG
|
|
images.append(f"data:image/png;base64,{image_base64}")
|
|
|
|
# If mask is provided, add it as second image
|
|
# Note: Some models may need mask in different format - will adjust per model
|
|
if mask_base64:
|
|
if mask_base64.startswith("data:image"):
|
|
images.append(mask_base64)
|
|
else:
|
|
images.append(f"data:image/png;base64,{mask_base64}")
|
|
|
|
# Get model info to determine API parameter structure
|
|
model_info = self.SUPPORTED_MODELS.get(model_path.split("/")[-1] if "/" in model_path else model_path)
|
|
if not model_info:
|
|
# Fallback: try to find model by matching path
|
|
for model_id, info in self.SUPPORTED_MODELS.items():
|
|
if info["model_path"] == model_path:
|
|
model_info = info
|
|
break
|
|
|
|
if not model_info:
|
|
raise ValueError(f"Model info not found for: {model_path}")
|
|
|
|
api_params = model_info.get("api_params", {})
|
|
|
|
# Build payload - following WaveSpeed API structure
|
|
# Note: output_format default varies by model (PNG for most, but can be JPEG)
|
|
default_output_format = api_params.get("default_output_format", "png")
|
|
|
|
# Some models use "image" (singular) instead of "images" (array)
|
|
uses_image_singular = api_params.get("uses_image_singular", False)
|
|
|
|
payload = {
|
|
"prompt": prompt,
|
|
"enable_sync_mode": True, # Use sync mode for immediate results
|
|
"enable_base64_output": False, # Get URL, then download
|
|
"output_format": default_output_format,
|
|
}
|
|
|
|
# Add image(s) based on model API format
|
|
if uses_image_singular:
|
|
# Models like Qwen Edit (basic) use "image" (singular)
|
|
# Use first image only (single image editing)
|
|
if images:
|
|
payload["image"] = images[0]
|
|
else:
|
|
raise ValueError("At least one image is required")
|
|
else:
|
|
# Models like Qwen Edit Plus, Nano Banana use "images" (array)
|
|
payload["images"] = images
|
|
|
|
# Allow override of output_format from extra params
|
|
if extra and "output_format" in extra:
|
|
payload["output_format"] = extra["output_format"]
|
|
|
|
# Model-specific parameter handling
|
|
if api_params.get("uses_size", True):
|
|
# Models like Qwen Edit Plus use "size" parameter (width*height format)
|
|
if width and height:
|
|
payload["size"] = f"{width}*{height}"
|
|
elif width:
|
|
payload["size"] = f"{width}*{width}" # Square if only width provided
|
|
elif height:
|
|
payload["size"] = f"{height}*{height}" # Square if only height provided
|
|
|
|
if api_params.get("uses_aspect_ratio", False):
|
|
# Models like Nano Banana and FLUX Kontext Pro use "aspect_ratio" parameter
|
|
if width and height:
|
|
# Calculate aspect ratio from dimensions
|
|
aspect_ratio = self._calculate_aspect_ratio(width, height)
|
|
if aspect_ratio:
|
|
payload["aspect_ratio"] = aspect_ratio
|
|
elif extra and "aspect_ratio" in extra:
|
|
payload["aspect_ratio"] = extra["aspect_ratio"]
|
|
|
|
if api_params.get("uses_resolution", False):
|
|
# Models like Nano Banana use "resolution" parameter ("4k" or "8k")
|
|
if extra and "resolution" in extra:
|
|
payload["resolution"] = extra["resolution"]
|
|
else:
|
|
# Default to 4K, or 8K if dimensions suggest high-res
|
|
if width and height and (width >= 4096 or height >= 4096):
|
|
payload["resolution"] = "8k"
|
|
else:
|
|
payload["resolution"] = "4k" # Default to 4K per API docs
|
|
|
|
# Add optional parameters (model-agnostic)
|
|
# Guidance scale: Only add if model supports it (e.g., FLUX Kontext Pro)
|
|
if api_params.get("supports_guidance_scale", False):
|
|
default_guidance = api_params.get("default_guidance_scale", 3.5)
|
|
if guidance_scale is not None:
|
|
# Clamp to valid range (1-20 per FLUX Kontext Pro docs)
|
|
payload["guidance_scale"] = max(1, min(20, guidance_scale))
|
|
elif extra and "guidance_scale" in extra:
|
|
payload["guidance_scale"] = max(1, min(20, extra["guidance_scale"]))
|
|
else:
|
|
payload["guidance_scale"] = default_guidance
|
|
|
|
# Seed parameter: Only add if model supports it
|
|
if api_params.get("supports_seed", True): # Default to True for backward compatibility
|
|
if seed is not None:
|
|
payload["seed"] = seed
|
|
else:
|
|
payload["seed"] = -1 # Random seed (per API docs default)
|
|
|
|
# Add any extra parameters
|
|
if extra:
|
|
# Filter out parameters we've already handled
|
|
handled_params = {"aspect_ratio", "resolution", "size", "seed", "guidance_scale"}
|
|
for key, value in extra.items():
|
|
if key not in handled_params:
|
|
payload[key] = value
|
|
|
|
logger.info(f"[WaveSpeed Edit] Submitting edit request to {url} (model={model_path}, prompt_length={len(prompt)})")
|
|
|
|
# Make API call - REUSES same pattern as ImageGenerator
|
|
try:
|
|
response = requests.post(
|
|
url,
|
|
headers=self.client._headers(),
|
|
json=payload,
|
|
timeout=120
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
logger.error(f"[WaveSpeed Edit] API call failed: {response.status_code} {response.text}")
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail={
|
|
"error": "WaveSpeed image editing failed",
|
|
"status_code": response.status_code,
|
|
"response": response.text[:500],
|
|
},
|
|
)
|
|
|
|
response_json = response.json()
|
|
data = response_json.get("data") or response_json
|
|
|
|
# Check status
|
|
status = data.get("status", "").lower()
|
|
outputs = data.get("outputs") or []
|
|
prediction_id = data.get("id")
|
|
|
|
logger.debug(
|
|
f"[WaveSpeed Edit] Response: status='{status}', outputs_count={len(outputs)}, "
|
|
f"prediction_id={prediction_id}"
|
|
)
|
|
|
|
# Handle sync mode - result should be directly in outputs
|
|
if outputs and status == "completed":
|
|
logger.info(f"[WaveSpeed Edit] Got immediate results from sync mode")
|
|
image_url = self._extract_image_url(outputs)
|
|
return self._download_image(image_url, timeout=120)
|
|
|
|
# Sync mode returned "created" or "processing" - need to poll
|
|
if not prediction_id:
|
|
logger.error(f"[WaveSpeed Edit] Sync mode returned status '{status}' but no prediction ID")
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail="WaveSpeed sync mode returned async response without prediction ID",
|
|
)
|
|
|
|
logger.info(
|
|
f"[WaveSpeed Edit] Sync mode returned status '{status}' with no outputs. "
|
|
f"Polling for result (prediction_id: {prediction_id})"
|
|
)
|
|
|
|
# Poll for result - REUSES polling utility
|
|
result = self.client.poll_until_complete(
|
|
prediction_id,
|
|
timeout_seconds=180,
|
|
interval_seconds=2.0,
|
|
)
|
|
|
|
outputs = result.get("outputs") or []
|
|
if not outputs:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail="WaveSpeed edit returned no outputs after polling"
|
|
)
|
|
|
|
# Extract image URL from outputs - REUSE helper method
|
|
image_url = self._extract_image_url(outputs)
|
|
return self._download_image(image_url, timeout=120)
|
|
|
|
except HTTPException:
|
|
raise
|
|
except Exception as e:
|
|
logger.error(f"[WaveSpeed Edit] Unexpected error: {str(e)}", exc_info=True)
|
|
raise RuntimeError(f"WaveSpeed edit API call failed: {str(e)}")
|
|
|
|
def _extract_image_url(self, outputs: list) -> str:
|
|
"""Extract image URL from outputs - REUSES same pattern as ImageGenerator.
|
|
|
|
Args:
|
|
outputs: Array of output URLs or objects
|
|
|
|
Returns:
|
|
Image URL string
|
|
|
|
Raises:
|
|
HTTPException: If output format is invalid
|
|
"""
|
|
if not isinstance(outputs, list) or len(outputs) == 0:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail="WaveSpeed edit returned no outputs",
|
|
)
|
|
|
|
first_output = outputs[0]
|
|
if isinstance(first_output, str):
|
|
image_url = first_output
|
|
elif isinstance(first_output, dict):
|
|
image_url = first_output.get("url") or first_output.get("image_url") or first_output.get("output")
|
|
else:
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail="WaveSpeed edit output format not recognized",
|
|
)
|
|
|
|
if not image_url or not (image_url.startswith("http://") or image_url.startswith("https://")):
|
|
raise HTTPException(
|
|
status_code=502,
|
|
detail="WaveSpeed edit returned invalid image URL",
|
|
)
|
|
|
|
return image_url
|
|
|
|
def _download_image(self, image_url: str, timeout: int = 120) -> bytes:
    """Fetch the edited image from the URL returned by the API.

    Args:
        image_url: URL to download from.
        timeout: Request timeout in seconds, passed to ``requests.get``.

    Returns:
        The body of the HTTP response as bytes.

    Raises:
        HTTPException: (502) If the response status is not 200.
    """
    logger.info(f"[WaveSpeed Edit] Downloading edited image from: {image_url}")
    download = requests.get(image_url, timeout=timeout)

    # Happy path first: a 200 response carries the image bytes.
    if download.status_code == 200:
        logger.info(f"[WaveSpeed Edit] Successfully downloaded image ({len(download.content)} bytes)")
        return download.content

    logger.error(f"[WaveSpeed Edit] Failed to download image: {download.status_code}")
    raise HTTPException(
        status_code=502,
        detail=f"Failed to download edited image: {download.status_code}"
    )
|
|
|
|
def _calculate_aspect_ratio(self, width: int, height: int) -> Optional[str]:
|
|
"""Calculate aspect ratio string from dimensions.
|
|
|
|
Args:
|
|
width: Image width
|
|
height: Image height
|
|
|
|
Returns:
|
|
Aspect ratio string (e.g., "16:9") or None if not standard
|
|
"""
|
|
# Common aspect ratios (includes FLUX Kontext Pro supported ratios)
|
|
ratios = {
|
|
(1, 1): "1:1",
|
|
(3, 2): "3:2",
|
|
(2, 3): "2:3",
|
|
(3, 4): "3:4",
|
|
(4, 3): "4:3",
|
|
(4, 5): "4:5",
|
|
(5, 4): "5:4",
|
|
(9, 16): "9:16",
|
|
(16, 9): "16:9",
|
|
(21, 9): "21:9",
|
|
(9, 21): "9:21", # FLUX Kontext Pro also supports 9:21
|
|
}
|
|
|
|
# Calculate GCD to simplify ratio
|
|
def gcd(a, b):
|
|
while b:
|
|
a, b = b, a % b
|
|
return a
|
|
|
|
divisor = gcd(width, height)
|
|
simplified = (width // divisor, height // divisor)
|
|
|
|
# Check if it matches a standard ratio (with some tolerance)
|
|
for (w, h), ratio_str in ratios.items():
|
|
# Allow small tolerance for rounding
|
|
if abs(simplified[0] / simplified[1] - w / h) < 0.01:
|
|
return ratio_str
|
|
|
|
# If no match, return None (model may not support custom aspect ratios)
|
|
return None
|
|
|
|
@classmethod
def get_available_models(cls) -> dict:
    """Expose the registry of editing models.

    Returns:
        The class-level SUPPORTED_MODELS mapping itself (not a copy),
        keyed by model id.
    """
    registry = cls.SUPPORTED_MODELS
    return registry
|
|
|
|
@classmethod
def get_models_by_tier(cls, tier: str) -> dict:
    """Select the registered models whose "tier" equals ``tier``.

    Args:
        tier: Tier name ("budget", "mid", "premium").

    Returns:
        New dict mapping model id to its registry entry, in registry
        order; empty when no model belongs to the tier.
    """
    matching = {}
    for key, details in cls.SUPPORTED_MODELS.items():
        if details.get("tier") != tier:
            continue
        matching[key] = details
    return matching
|
|
|
|
@classmethod
def get_models_by_operation(cls, operation: str) -> dict:
    """Select the registered models that list ``operation`` as a capability.

    Args:
        operation: Operation type (e.g., "inpaint", "outpaint", "general_edit").

    Returns:
        New dict mapping model id to its registry entry, in registry
        order; models without a "capabilities" entry never match.
    """
    selected = {}
    for key, details in cls.SUPPORTED_MODELS.items():
        supported_ops = details.get("capabilities", [])
        if operation in supported_ops:
            selected[key] = details
    return selected
|