Changes: 1. helpers.py (_track_image_operation_usage): Map provider name to DB columns dynamically (stability→stability_calls, wavespeed→wavespeed_calls, etc.) instead of hardcoding stability_calls/stability_cost. 2. upscale_service.py: Added _track_image_operation_usage() call after successful Stability upscale completion. 3. control_service.py: Added _track_image_operation_usage() call after successful Stability control operation completion. 4. edit_service.py: Added _track_image_operation_usage() call after successful Stability edit operation (remove_background, inpaint, outpaint, search_replace, search_recolor, relight). Previously only Create Studio and Face Swap tracked usage. Now all five studios correctly decrement subscription limits.
802 lines
32 KiB
Python
802 lines
32 KiB
Python
"""Edit Studio service for AI-powered image editing and transformations."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import io
|
||
from dataclasses import dataclass, field
|
||
from typing import Any, Dict, Literal, Optional
|
||
|
||
from PIL import Image
|
||
|
||
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
|
||
from services.llm_providers.main_image_generation import generate_image_edit
|
||
from services.stability_service import StabilityAIService
|
||
from utils.logger_utils import get_service_logger
|
||
|
||
|
||
# Module-level logger obtained from get_service_logger under the name
# "image_studio.edit"; every log call in this module goes through it.
logger = get_service_logger("image_studio.edit")
|
||
|
||
|
||
# Closed set of operation identifiers an Edit Studio request may carry.
# The members mirror the keys of EditStudioService.SUPPORTED_OPERATIONS.
EditOperationType = Literal[
    "remove_background", "inpaint", "outpaint",
    "search_replace", "search_recolor", "relight",
    "general_edit",
]
|
||
|
||
|
||
@dataclass
class EditStudioRequest:
    """Normalized request payload for Edit Studio operations.

    Only ``image_base64`` and ``operation`` are mandatory; every other field
    defaults to ``None`` (or ``"png"`` / an empty dict) and is consulted only
    by the operations that need it.
    """

    # --- required -----------------------------------------------------------
    image_base64: str  # primary image, raw base64 or a data URL
    operation: EditOperationType

    # --- text prompts -------------------------------------------------------
    prompt: Optional[str] = None
    negative_prompt: Optional[str] = None

    # --- auxiliary inputs ---------------------------------------------------
    mask_base64: Optional[str] = None
    search_prompt: Optional[str] = None  # used by "search_replace"
    select_prompt: Optional[str] = None  # used by "search_recolor"
    background_image_base64: Optional[str] = None  # used by "relight"
    lighting_image_base64: Optional[str] = None  # used by "relight"

    # --- canvas expansion amounts (used by "outpaint") ----------------------
    expand_left: Optional[int] = None
    expand_right: Optional[int] = None
    expand_up: Optional[int] = None
    expand_down: Optional[int] = None

    # --- provider / model selection and generation knobs --------------------
    provider: Optional[str] = None
    model: Optional[str] = None
    style_preset: Optional[str] = None
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None

    # --- output -------------------------------------------------------------
    output_format: str = "png"
    # Free-form extras (e.g. "grow_mask" for inpainting); a fresh dict is
    # created per instance so requests never share state.
    options: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
class EditStudioService:
    """Service layer orchestrating Edit Studio operations.

    Each request is routed according to ``SUPPORTED_OPERATIONS``: operations
    whose provider is ``"stability"`` go through ``StabilityAIService``; the
    remaining ``"general_edit"`` operation goes through the WaveSpeed unified
    entry point, falling back to the Hugging Face editor. Provider output is
    normalized into a single response dictionary by ``process_edit``.
    """

    # Fallback cost used when a model entry has no usable "cost" value.
    _DEFAULT_MODEL_COST = 0.02
    # Floor applied before inverting cost, so zero-cost models score very
    # high instead of raising ZeroDivisionError.
    _MIN_SCORING_COST = 0.001
    # Fallback (width, height) used when a model entry has no "max_resolution".
    _DEFAULT_MAX_RESOLUTION = (2048, 2048)

    # Score bonus per (user subscription tier, model tier) pair.
    _TIER_BONUSES: Dict[str, Dict[str, int]] = {
        "free": {"budget": 50, "mid": 20},
        "pro": {"premium": 50, "mid": 30},
        "enterprise": {"premium": 50, "mid": 30},
    }

    # Human-readable use cases advertised for each model capability tag.
    _USE_CASES_BY_CAPABILITY: Dict[str, list] = {
        "general_edit": ["Quick edits", "Style changes", "Background replacement"],
        "style_transfer": ["Apply artistic styles", "Style transformations"],
        "text_edit": ["Add text to images", "Edit text in images"],
        "multi_image": ["Batch editing", "Consistent character work"],
        "high_res": ["Professional work", "Print materials", "4K/8K editing"],
        "professional": ["Marketing campaigns", "Brand assets"],
        "typography": ["Text-heavy edits", "Typography generation"],
        "portrait_retouching": ["Portrait edits", "Beauty retouching"],
        "fashion_edit": ["Fashion photography", "Outfit changes"],
        "product_edit": ["E-commerce", "Product photography"],
    }

    # Operation catalogue exposed to the UI. "fields" tells the client which
    # inputs each operation uses; "provider" also drives routing in
    # process_edit.
    SUPPORTED_OPERATIONS: Dict[EditOperationType, Dict[str, Any]] = {
        "remove_background": {
            "label": "Remove Background",
            "description": "Isolate the main subject and remove the background.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": False,
                "mask": False,
                "negative_prompt": False,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "inpaint": {
            "label": "Inpaint & Fix",
            "description": "Edit specific regions using prompts and optional masks.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,
                "negative_prompt": True,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "outpaint": {
            "label": "Outpaint",
            "description": "Extend the canvas in any direction with smart fill.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": False,
                "mask": False,
                "negative_prompt": True,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": True,
            },
        },
        "search_replace": {
            "label": "Search & Replace",
            "description": "Locate objects via search prompt and replace them. Optional mask for precise control.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,  # Optional mask for precise region selection
                "negative_prompt": False,
                "search_prompt": True,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "search_recolor": {
            "label": "Search & Recolor",
            "description": "Select elements via prompt and recolor them. Optional mask for exact region selection.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,  # Optional mask for precise region selection
                "negative_prompt": False,
                "search_prompt": False,
                "select_prompt": True,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "relight": {
            "label": "Replace Background & Relight",
            "description": "Swap backgrounds and relight using reference images.",
            "provider": "stability",
            "async": True,
            "fields": {
                "prompt": False,
                "mask": False,
                "negative_prompt": False,
                "search_prompt": False,
                "select_prompt": False,
                "background": True,
                "lighting": True,
                "expansion": False,
            },
        },
        "general_edit": {
            "label": "Prompt-based Edit",
            "description": "Free-form editing powered by Hugging Face image-to-image models. Optional mask for selective editing.",
            "provider": "huggingface",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,  # Optional mask for selective region editing
                "negative_prompt": True,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
    }

    def __init__(self):
        logger.info("[Edit Studio] Initialized edit service")

    # ------------------------------------------------------------------
    # Encoding helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 (or data URL) string to bytes.

        Args:
            value: Raw base64 text, a ``data:...;base64,...`` URL, or a falsy
                value when the image was not supplied.

        Returns:
            The decoded bytes, or ``None`` when ``value`` is falsy.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None

        try:
            # Handle data URLs (data:image/png;base64,...): keep only the
            # part after the first comma.
            payload = value.split(",", 1)[1] if value.startswith("data:") else value
            return base64.b64decode(payload)
        except Exception as exc:
            # Boundary conversion: any failure is reported uniformly.
            logger.error(f"[Edit Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Extract width/height metadata from image bytes."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {"width": img.width, "height": img.height}

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Convert raw bytes to a ``data:image/<output_format>;base64,`` URL."""
        encoded = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{output_format};base64,{encoded}"

    # ------------------------------------------------------------------
    # Catalogue / model discovery
    # ------------------------------------------------------------------

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering."""
        return self.SUPPORTED_OPERATIONS

    def get_available_models(
        self,
        operation: Optional[str] = None,
        tier: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Get available WaveSpeed editing models.

        Args:
            operation: Filter by operation type (e.g., "general_edit")
            tier: Filter by tier ("budget", "mid", "premium")

        Returns:
            Dictionary with ``"models"`` (list of model descriptors) and
            ``"total"`` (its length).
        """
        from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

        provider = WaveSpeedEditProvider()
        models = provider.get_available_models()

        # Each filter keeps only the ids the provider reports for it.
        if operation:
            allowed = provider.get_models_by_operation(operation)
            models = {mid: info for mid, info in models.items() if mid in allowed}
        if tier:
            allowed = provider.get_models_by_tier(tier)
            models = {mid: info for mid, info in models.items() if mid in allowed}

        # Format for API response
        models_list = [
            {
                "id": model_id,
                "name": info.get("name", model_id),
                "description": info.get("description", ""),
                "cost": info.get("cost", 0.02),
                "cost_8k": info.get("cost_8k"),  # Optional
                "tier": info.get("tier", "mid"),
                "max_resolution": info.get("max_resolution", [2048, 2048]),
                "capabilities": info.get("capabilities", []),
                "use_cases": self._get_use_cases_for_model(model_id, info),
                "features": self._get_features_for_model(info),
                "supports_multi_image": info.get("supports_multi_image", False),
                "supports_controlnet": info.get("supports_controlnet", False),
                "languages": info.get("languages", ["en"]),
            }
            for model_id, info in models.items()
        ]

        return {"models": models_list, "total": len(models_list)}

    @classmethod
    def _model_cost(cls, model_info: Dict[str, Any]) -> float:
        """Return the model's cost, or the default when missing or ``None``."""
        cost = model_info.get("cost")
        return cls._DEFAULT_MODEL_COST if cost is None else cost

    @classmethod
    def _max_supported_side(cls, model_info: Dict[str, Any]) -> int:
        """Return the larger of the model's maximum width and height."""
        max_res = model_info.get("max_resolution", cls._DEFAULT_MAX_RESOLUTION)
        return max(max_res[0], max_res[1])

    @classmethod
    def _score_model(
        cls,
        model_info: Dict[str, Any],
        *,
        prioritize_cost: bool,
        prioritize_quality: bool,
        user_tier: Optional[str],
    ) -> float:
        """Score one model for ``recommend_model`` (higher is better).

        The score is the sum of an inverse-cost term, a resolution term and
        a bonus for the user tier / model tier pairing.
        """
        # Floor the cost so a zero-cost entry cannot divide by zero.
        cost = max(cls._model_cost(model_info), cls._MIN_SCORING_COST)
        score = (1.0 / cost) * (100 if prioritize_cost else 50)
        # Higher resolution counts as higher quality.
        score += cls._max_supported_side(model_info) / (10 if prioritize_quality else 20)
        score += cls._TIER_BONUSES.get(user_tier, {}).get(model_info.get("tier", "mid"), 0)
        return score

    def recommend_model(
        self,
        operation: str,
        image_resolution: Optional[Dict[str, int]] = None,
        user_tier: Optional[str] = None,
        preferences: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Recommend best model for given operation and context.

        Args:
            operation: Operation type (e.g., "general_edit")
            image_resolution: Dict with "width" and "height"
            user_tier: User subscription tier ("free", "pro", "enterprise")
            preferences: Dict with "prioritize_cost" or "prioritize_quality"

        Returns:
            Dictionary with ``"recommended_model"``, ``"reason"`` and
            ``"alternatives"`` (up to three runner-up models).

        Raises:
            ValueError: If the provider exposes no models at all.
        """
        from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

        provider = WaveSpeedEditProvider()
        candidates = provider.get_models_by_operation(operation)
        if not candidates:
            # Fallback to all models if operation doesn't match
            candidates = provider.get_available_models()

        if image_resolution:
            # Keep only models whose largest side covers the image's.
            largest_side = max(
                image_resolution.get("width") or 0,
                image_resolution.get("height") or 0,
            )
            candidates = {
                model_id: info
                for model_id, info in candidates.items()
                if largest_side <= self._max_supported_side(info)
            }

        if not candidates:
            # Nothing matches: fall back to the provider's first model.
            all_models = provider.get_available_models()
            if not all_models:
                raise ValueError("No models available")
            return {
                "recommended_model": next(iter(all_models)),
                "reason": "No specific match found, using default model",
                "alternatives": [],
            }

        prioritize_cost = bool(preferences and preferences.get("prioritize_cost", False))
        prioritize_quality = bool(preferences and preferences.get("prioritize_quality", False))

        scored_models = [
            (
                model_id,
                info,
                self._score_model(
                    info,
                    prioritize_cost=prioritize_cost,
                    prioritize_quality=prioritize_quality,
                    user_tier=user_tier,
                ),
            )
            for model_id, info in candidates.items()
        ]
        # Highest score first; the sort is stable, so ties keep provider order.
        scored_models.sort(key=lambda entry: entry[2], reverse=True)

        recommended_id, recommended_info, _ = scored_models[0]

        reasons = []
        if prioritize_cost:
            reasons.append("Lowest cost option")
        if prioritize_quality:
            reasons.append("Best quality")
        if image_resolution:
            reasons.append(f"Supports {image_resolution.get('width')}×{image_resolution.get('height')} resolution")
        if user_tier == "free" and recommended_info.get("tier") == "budget":
            reasons.append("Budget-friendly for free tier")
        reason = ", ".join(reasons) if reasons else "Best match for your requirements"

        # Alternatives: the next three models by score.
        recommended_cost = self._model_cost(recommended_info)
        alternatives = []
        for model_id, info, _ in scored_models[1:4]:
            alt_cost = self._model_cost(info)
            alt_reason = f"Alternative: {info.get('tier', 'mid').title()} tier"
            if alt_cost < recommended_cost:
                alt_reason += ", lower cost"
            elif alt_cost > recommended_cost:
                alt_reason += ", higher quality"
            alternatives.append({
                "model_id": model_id,
                "name": info.get("name", model_id),
                "cost": alt_cost,
                "reason": alt_reason,
            })

        return {
            "recommended_model": recommended_id,
            "reason": reason,
            "alternatives": alternatives,
        }

    def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
        """Get use cases for a model based on its capabilities.

        Use cases are returned in capability order with duplicates removed.
        ``model_id`` is accepted for interface compatibility and not used.
        """
        use_cases = []
        for capability in model_info.get("capabilities", []):
            if capability in self._USE_CASES_BY_CAPABILITY:
                use_cases.extend(self._USE_CASES_BY_CAPABILITY[capability])

        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is the same on every run (a set would not guarantee that).
        return list(dict.fromkeys(use_cases)) if use_cases else ["General image editing"]

    def _get_features_for_model(self, model_info: Dict[str, Any]) -> list:
        """Get feature list for a model.

        Returns ``["Standard editing"]`` when no feature applies.
        """
        features = []
        api_params = model_info.get("api_params", {})

        if model_info.get("supports_multi_image"):
            max_images = api_params.get("max_images", 0)
            features.append(
                f"Multi-image ({max_images} images)" if max_images else "Multi-image support"
            )

        if model_info.get("supports_controlnet"):
            features.append("ControlNet support")

        languages = model_info.get("languages", [])
        if len(languages) > 1:
            features.append(f"Multilingual ({', '.join(languages)})")
        elif "multilingual" in languages:
            features.append("Multilingual support")

        largest_side = max(model_info.get("max_resolution", (2048, 2048)))
        if largest_side >= 4096:
            features.append("4K/8K support")
        elif largest_side >= 2048:
            features.append("2K support")

        if api_params.get("supports_guidance_scale"):
            features.append("Guidance scale control")

        return features if features else ["Standard editing"]

    # ------------------------------------------------------------------
    # Request processing
    # ------------------------------------------------------------------

    @staticmethod
    def _run_preflight_validation(user_id: str) -> None:
        """Run the subscription pre-flight check for image editing.

        Any ``HTTPException`` raised by the validator is logged and
        re-raised; the database session is always closed.
        """
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_editing_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Edit Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_editing_operations(
                pricing_service=pricing_service,
                user_id=user_id,
            )
            logger.info("[Edit Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Edit Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            db.close()

    async def process_edit(
        self,
        request: EditStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Process edit request and return normalized response.

        Args:
            request: The normalized edit request.
            user_id: Owner of the request. When omitted, pre-flight
                validation and usage tracking are skipped.

        Returns:
            Dict with ``success``, ``operation``, ``provider``,
            ``image_base64`` (data URL), ``width``, ``height``, ``metadata``.

        Raises:
            ValueError: Missing/invalid image payload or unsupported operation.
            fastapi.HTTPException: Re-raised from pre-flight validation.
        """
        # Pre-flight validation: Use specific validator for editing operations
        # Note: Editing uses validate_image_editing_operations (different from generation)
        # This is intentional as editing may have different subscription limits
        if user_id:
            self._run_preflight_validation(user_id)
        else:
            logger.warning("[Edit Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        image_bytes = self._decode_base64_image(request.image_base64)
        if not image_bytes:
            raise ValueError("Primary image payload is required")

        mask_bytes = self._decode_base64_image(request.mask_base64)
        background_bytes = self._decode_base64_image(request.background_image_base64)
        lighting_bytes = self._decode_base64_image(request.lighting_image_base64)

        operation = request.operation
        # f-string for consistency with the other log calls in this class.
        logger.info(f"[Edit Studio] Processing operation='{operation}' for user={user_id}")

        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported edit operation: {operation}")

        provider_name = self.SUPPORTED_OPERATIONS[operation]["provider"]

        # Route on the catalogue's provider column so the table stays the
        # single source of truth for which operations go to Stability.
        if provider_name == "stability":
            image_bytes = await self._handle_stability_edit(
                operation=operation,
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                background_bytes=background_bytes,
                lighting_bytes=lighting_bytes,
            )
            # Track usage for Stability operations
            if user_id:
                from services.llm_providers.main_image_generation import _track_image_operation_usage

                _track_image_operation_usage(
                    user_id=user_id,
                    provider="stability",
                    model=f"edit-{operation}",
                    operation_type="image-edit",
                    result_bytes=image_bytes,
                    cost=0.04,
                    endpoint="/image-studio/edit/process",
                    log_prefix="[Edit Studio]",
                )
        else:
            image_bytes = await self._handle_general_edit(
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                user_id=user_id,
            )

        metadata = self._image_bytes_to_metadata(image_bytes)
        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": provider_name,
            }
        )

        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }

        logger.info(f"[Edit Studio] ✅ Operation '{operation}' completed")
        return response

    @staticmethod
    def _validate_stability_inputs(
        operation: EditOperationType,
        request: EditStudioRequest,
        background_bytes: Optional[bytes],
        lighting_bytes: Optional[bytes],
    ) -> None:
        """Raise ``ValueError`` if a Stability operation lacks a required input."""
        if operation == "inpaint" and not request.prompt:
            raise ValueError("Prompt is required for inpainting")
        if operation == "search_replace" and not (request.prompt and request.search_prompt):
            raise ValueError("Both prompt and search_prompt are required for search & replace")
        if operation == "search_recolor" and not (request.prompt and request.select_prompt):
            raise ValueError("Both prompt and select_prompt are required for search & recolor")
        if operation == "relight" and not background_bytes and not lighting_bytes:
            raise ValueError("At least one reference (background or lighting) is required for relight")

    async def _handle_stability_edit(
        self,
        operation: EditOperationType,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        background_bytes: Optional[bytes],
        lighting_bytes: Optional[bytes],
    ) -> bytes:
        """Execute Stability AI edit workflows.

        Returns:
            The edited image as raw bytes.

        Raises:
            ValueError: Required input missing, or operation not handled here.
            RuntimeError: The provider response contains no image, or the
                asynchronous relight job fails or times out.
        """
        # Validate before opening a provider session so bad requests fail fast.
        self._validate_stability_inputs(operation, request, background_bytes, lighting_bytes)

        stability_service = StabilityAIService()

        async with stability_service:
            if operation == "remove_background":
                result = await stability_service.remove_background(
                    image=image_bytes,
                    output_format=request.output_format,
                )
            elif operation == "inpaint":
                result = await stability_service.inpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    mask=mask_bytes,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                    grow_mask=request.options.get("grow_mask", 5),
                )
            elif operation == "outpaint":
                result = await stability_service.outpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    left=request.expand_left or 0,
                    right=request.expand_right or 0,
                    up=request.expand_up or 0,
                    down=request.expand_down or 0,
                    style_preset=request.style_preset,
                )
            elif operation == "search_replace":
                result = await stability_service.search_and_replace(
                    image=image_bytes,
                    prompt=request.prompt,
                    search_prompt=request.search_prompt,
                    mask=mask_bytes,  # Optional mask for precise region selection
                    output_format=request.output_format,
                )
            elif operation == "search_recolor":
                result = await stability_service.search_and_recolor(
                    image=image_bytes,
                    prompt=request.prompt,
                    select_prompt=request.select_prompt,
                    mask=mask_bytes,  # Optional mask for precise region selection
                    output_format=request.output_format,
                )
            elif operation == "relight":
                result = await stability_service.replace_background_and_relight(
                    subject_image=image_bytes,
                    background_reference=background_bytes,
                    light_reference=lighting_bytes,
                    output_format=request.output_format,
                )
                # A dict carrying an "id" is treated as a pending job handle
                # and polled until the result is available.
                if isinstance(result, dict) and result.get("id"):
                    result = await self._poll_stability_result(
                        stability_service,
                        generation_id=result["id"],
                        output_format=request.output_format,
                    )
            else:
                raise ValueError(f"Unsupported Stability operation: {operation}")

        return self._extract_image_bytes(result)

    async def _handle_general_edit(
        self,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        user_id: Optional[str],
    ) -> bytes:
        """Execute general editing - routes to WaveSpeed (unified entry) or HuggingFace (legacy).

        If model is a WaveSpeed model (qwen-edit-plus, nano-banana-pro-edit-ultra, seedream-v4.5-edit),
        uses unified entry point. Otherwise falls back to HuggingFace for backward compatibility.

        Note: when no model is given and one is auto-selected, it is written
        back to ``request.model``.

        Raises:
            ValueError: If ``request.prompt`` is empty.
        """
        if not request.prompt:
            raise ValueError("Prompt is required for general edits")

        # Check if model is a WaveSpeed editing model
        from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

        provider = WaveSpeedEditProvider()
        wavespeed_models = set(provider.get_available_models().keys())

        # WaveSpeed is used when asked for explicitly or when the model is one of its ids.
        is_wavespeed = bool(
            request.provider == "wavespeed"
            or (request.model and request.model in wavespeed_models)
        )

        # Auto-detect: If no model specified and operation is general_edit, recommend one
        if not request.model and not is_wavespeed and request.operation == "general_edit":
            try:
                recommendation = self.recommend_model(
                    operation=request.operation,
                    image_resolution=self._image_bytes_to_metadata(image_bytes),
                    preferences={"prioritize_cost": True},  # Default to cost-optimized
                )
                recommended_model = recommendation.get("recommended_model")
                if recommended_model and recommended_model in wavespeed_models:
                    logger.info(f"[Edit Studio] Auto-selected model: {recommended_model} (reason: {recommendation.get('reason')})")
                    request.model = recommended_model
                    is_wavespeed = True
            except Exception as e:
                # Best effort: any failure here simply means the legacy path is used.
                logger.warning(f"[Edit Studio] Auto-detection failed: {e}, falling back to HuggingFace")

        if is_wavespeed:
            # Use unified entry point for WaveSpeed models
            logger.info(f"[Edit Studio] Using WaveSpeed unified entry for model={request.model}")

            dimensions = self._image_bytes_to_metadata(image_bytes)
            edit_options = {
                "mask_base64": base64.b64encode(mask_bytes).decode("utf-8") if mask_bytes else None,
                "negative_prompt": request.negative_prompt,
                "width": dimensions["width"],
                "height": dimensions["height"],
                "guidance_scale": request.guidance_scale,
                "steps": request.steps,
                "seed": request.seed,
            }

            # Call unified entry point (synchronous, so run in thread)
            result = await asyncio.to_thread(
                generate_image_edit,
                image_base64=base64.b64encode(image_bytes).decode("utf-8"),
                prompt=request.prompt,
                operation=request.operation or "general_edit",
                model=request.model,  # Will auto-select if None
                options=edit_options,
                user_id=user_id,
            )
            return result.image_bytes

        # Fall back to HuggingFace for backward compatibility
        logger.info("[Edit Studio] Using HuggingFace (legacy) for general edit")

        options = {
            "provider": request.provider or "huggingface",
            "model": request.model,
            "guidance_scale": request.guidance_scale,
            "steps": request.steps,
            "seed": request.seed,
        }

        # huggingface edit is synchronous - run in thread
        result = await asyncio.to_thread(
            huggingface_edit_image,
            image_bytes,
            request.prompt,
            options,
            user_id,
            mask_bytes,  # Optional mask for selective editing
        )
        return result.image_bytes

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize Stability responses into raw image bytes.

        Accepts raw ``bytes``/``bytearray`` or a dict whose ``"artifacts"``,
        ``"data"`` or ``"images"`` list has an entry with a ``"base64"`` or
        ``"b64_json"`` payload; the first such entry wins.

        Raises:
            RuntimeError: If no image payload can be found.
        """
        if isinstance(result, (bytes, bytearray)):
            return bytes(result)

        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if not isinstance(artifact, dict):
                    continue
                encoded = artifact.get("base64") or artifact.get("b64_json")
                if encoded:
                    return base64.b64decode(encoded)

        raise RuntimeError("Unable to extract image bytes from provider response")

    async def _poll_stability_result(
        self,
        stability_service: StabilityAIService,
        generation_id: str,
        output_format: str,
        timeout_seconds: int = 240,
        interval_seconds: float = 2.0,
    ) -> bytes:
        """Poll Stability async endpoint until result is ready.

        Args:
            stability_service: Service used to query the generation.
            generation_id: Identifier of the pending generation.
            output_format: Accepted for interface compatibility; not used.
            timeout_seconds: Wall-clock budget for the whole polling loop.
            interval_seconds: Pause between two polls.

        Returns:
            The generated image as raw bytes.

        Raises:
            RuntimeError: The generation reports failure or the timeout expires.
        """
        # Deadline on the event loop clock, so time spent waiting on the
        # provider counts toward the timeout as well as time spent sleeping.
        clock = asyncio.get_running_loop().time
        deadline = clock() + timeout_seconds

        while clock() < deadline:
            result = await stability_service.get_generation_result(
                generation_id=generation_id,
                accept_type="*/*",
            )

            if isinstance(result, bytes):
                return result

            if isinstance(result, dict):
                state = (result.get("state") or result.get("status") or "").lower()
                if state in {"succeeded", "success", "ready", "completed"}:
                    return self._extract_image_bytes(result)
                if state in {"failed", "error"}:
                    raise RuntimeError(f"Stability generation failed: {result}")

            await asyncio.sleep(interval_seconds)

        raise RuntimeError("Timed out waiting for Stability generation result")
|
||
|
||
|