Files
ALwrity/backend/services/image_studio/edit_service.py
ajaysi ca725b77e7 refactor(phase2): add provider-aware tracking and fill missing subscription usage tracking
Changes:
1. helpers.py (_track_image_operation_usage): Map provider name to DB columns
   dynamically (stability→stability_calls, wavespeed→wavespeed_calls, etc.)
   instead of hardcoding stability_calls/stability_cost.

2. upscale_service.py: Added _track_image_operation_usage() call after
   successful Stability upscale completion.

3. control_service.py: Added _track_image_operation_usage() call after
   successful Stability control operation completion.

4. edit_service.py: Added _track_image_operation_usage() call after
   successful Stability edit operation (remove_background, inpaint,
   outpaint, search_replace, search_recolor, relight).

Previously only Create Studio and Face Swap tracked usage. Now all five
studios correctly decrement subscription limits.
2026-05-14 09:11:51 +05:30

802 lines
32 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Edit Studio service for AI-powered image editing and transformations."""
from __future__ import annotations
import asyncio
import base64
import io
from dataclasses import dataclass, field
from typing import Any, Dict, Literal, Optional
from PIL import Image
from services.llm_providers.main_image_editing import edit_image as huggingface_edit_image
from services.llm_providers.main_image_generation import generate_image_edit
from services.stability_service import StabilityAIService
from utils.logger_utils import get_service_logger
# Logger for this module, obtained from the project's service-logger helper
# under the "image_studio.edit" service name.
logger = get_service_logger("image_studio.edit")
# Closed set of operation identifiers accepted by Edit Studio. Every value is
# also a key of EditStudioService.SUPPORTED_OPERATIONS, which process_edit
# checks at runtime before dispatching.
EditOperationType = Literal[
    # Operations executed through Stability AI
    "remove_background", "inpaint", "outpaint",
    "search_replace", "search_recolor", "relight",
    # Free-form prompt edit (WaveSpeed, with Hugging Face as legacy fallback)
    "general_edit",
]
@dataclass
class EditStudioRequest:
    """Normalized request payload for Edit Studio operations.

    Which optional fields are used depends on ``operation``; the per-operation
    field map lives in ``EditStudioService.SUPPORTED_OPERATIONS``. Image fields
    accept either raw base64 or a ``data:...;base64,`` URL.
    """
    # Primary image to edit; process_edit rejects the request when it is empty.
    image_base64: str
    # Operation to run; must be a key of EditStudioService.SUPPORTED_OPERATIONS.
    operation: EditOperationType
    # Edit instruction; required for inpaint, search_replace, search_recolor and general_edit.
    prompt: Optional[str] = None
    # Forwarded to Stability inpaint/outpaint and to general edits.
    negative_prompt: Optional[str] = None
    # Optional mask image used by inpaint, search_replace, search_recolor and general_edit.
    mask_base64: Optional[str] = None
    # Object to locate for search_replace (required for that operation).
    search_prompt: Optional[str] = None
    # Elements to select for search_recolor (required for that operation).
    select_prompt: Optional[str] = None
    # Background reference for relight; relight needs this and/or lighting_image_base64.
    background_image_base64: Optional[str] = None
    # Lighting reference for relight.
    lighting_image_base64: Optional[str] = None
    # Outpaint expansion per side (presumably pixels — confirm with Stability docs); None is sent as 0.
    expand_left: Optional[int] = None
    expand_right: Optional[int] = None
    expand_up: Optional[int] = None
    expand_down: Optional[int] = None
    # General edits: "wavespeed" forces WaveSpeed; otherwise passed to the legacy
    # Hugging Face editor (which defaults to "huggingface").
    provider: Optional[str] = None
    # General edits: WaveSpeed model id or legacy model; auto-selected when omitted.
    model: Optional[str] = None
    # Forwarded to Stability inpaint/outpaint and echoed in the response metadata.
    style_preset: Optional[str] = None
    # General-edit generation controls.
    guidance_scale: Optional[float] = None
    steps: Optional[int] = None
    seed: Optional[int] = None
    # Requested output format; sent to Stability and used for the returned data URL.
    output_format: str = "png"
    # Extra operation-specific knobs, e.g. "grow_mask" for inpaint (default 5).
    options: Dict[str, Any] = field(default_factory=dict)
class EditStudioService:
    """Service layer orchestrating Edit Studio operations.

    Stability AI backs the mask/search/outpaint/relight operations. The
    prompt-based ``general_edit`` is routed to a WaveSpeed editing model
    (explicit or auto-recommended), with Hugging Face as the legacy fallback.
    """

    SUPPORTED_OPERATIONS: Dict[EditOperationType, Dict[str, Any]] = {
        "remove_background": {
            "label": "Remove Background",
            "description": "Isolate the main subject and remove the background.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": False,
                "mask": False,
                "negative_prompt": False,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "inpaint": {
            "label": "Inpaint & Fix",
            "description": "Edit specific regions using prompts and optional masks.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,
                "negative_prompt": True,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "outpaint": {
            "label": "Outpaint",
            "description": "Extend the canvas in any direction with smart fill.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": False,
                "mask": False,
                "negative_prompt": True,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": True,
            },
        },
        "search_replace": {
            "label": "Search & Replace",
            "description": "Locate objects via search prompt and replace them. Optional mask for precise control.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,  # Optional mask for precise region selection
                "negative_prompt": False,
                "search_prompt": True,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "search_recolor": {
            "label": "Search & Recolor",
            "description": "Select elements via prompt and recolor them. Optional mask for exact region selection.",
            "provider": "stability",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,  # Optional mask for precise region selection
                "negative_prompt": False,
                "search_prompt": False,
                "select_prompt": True,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
        "relight": {
            "label": "Replace Background & Relight",
            "description": "Swap backgrounds and relight using reference images.",
            "provider": "stability",
            "async": True,
            "fields": {
                "prompt": False,
                "mask": False,
                "negative_prompt": False,
                "search_prompt": False,
                "select_prompt": False,
                "background": True,
                "lighting": True,
                "expansion": False,
            },
        },
        "general_edit": {
            "label": "Prompt-based Edit",
            "description": "Free-form editing powered by Hugging Face image-to-image models. Optional mask for selective editing.",
            "provider": "huggingface",
            "async": False,
            "fields": {
                "prompt": True,
                "mask": True,  # Optional mask for selective region editing
                "negative_prompt": True,
                "search_prompt": False,
                "select_prompt": False,
                "background": False,
                "lighting": False,
                "expansion": False,
            },
        },
    }

    # Flat per-call cost recorded by usage tracking for every Stability edit.
    # NOTE(review): Stability prices its edit endpoints differently; confirm
    # whether a per-operation cost table is needed here.
    _STABILITY_EDIT_COST = 0.04

    # Floor applied to a model's cost before it is inverted for scoring, so
    # free (cost 0) or mis-configured models cannot raise ZeroDivisionError.
    _MIN_SCORING_COST = 0.001

    def __init__(self):
        logger.info("[Edit Studio] Initialized edit service")

    # ------------------------------------------------------------------ #
    # Encoding helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _decode_base64_image(value: Optional[str]) -> Optional[bytes]:
        """Decode a base64 string or ``data:`` URL into raw bytes.

        Returns ``None`` for an empty/``None`` value.

        Raises:
            ValueError: If the payload cannot be decoded.
        """
        if not value:
            return None
        try:
            # Handle data URLs (data:image/png;base64,...)
            if value.startswith("data:"):
                _, b64data = value.split(",", 1)
            else:
                b64data = value
            return base64.b64decode(b64data)
        except Exception as exc:
            logger.error(f"[Edit Studio] Failed to decode base64 image: {exc}")
            raise ValueError("Invalid base64 image payload") from exc

    @staticmethod
    def _image_bytes_to_metadata(image_bytes: bytes) -> Dict[str, Any]:
        """Return ``{"width", "height"}`` for the encoded image in *image_bytes*."""
        with Image.open(io.BytesIO(image_bytes)) as img:
            return {
                "width": img.width,
                "height": img.height,
            }

    @staticmethod
    def _bytes_to_base64(image_bytes: bytes, output_format: str = "png") -> str:
        """Encode raw image bytes as a ``data:image/<format>;base64,...`` URL.

        ``jpg`` is normalized to the registered MIME subtype ``jpeg``, so
        clients receive a valid ``image/jpeg`` data URL.
        """
        subtype = (output_format or "png").lower()
        if subtype == "jpg":
            subtype = "jpeg"
        b64 = base64.b64encode(image_bytes).decode("utf-8")
        return f"data:image/{subtype};base64,{b64}"

    def list_operations(self) -> Dict[str, Dict[str, Any]]:
        """Expose supported operations for UI rendering.

        Returns a deep copy so callers cannot mutate the shared class-level
        ``SUPPORTED_OPERATIONS`` table.
        """
        import copy

        return copy.deepcopy(self.SUPPORTED_OPERATIONS)

    # ------------------------------------------------------------------ #
    # Model catalogue / recommendation
    # ------------------------------------------------------------------ #

    def get_available_models(
        self,
        operation: Optional[str] = None,
        tier: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Get available WaveSpeed editing models.

        Args:
            operation: Filter by operation type (e.g., "general_edit")
            tier: Filter by tier ("budget", "mid", "premium")

        Returns:
            Dictionary with ``models`` (list of model descriptors) and ``total``.
        """
        from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

        provider = WaveSpeedEditProvider()
        all_models = provider.get_available_models()

        # Filter by operation if specified
        if operation:
            filtered = provider.get_models_by_operation(operation)
            all_models = {k: v for k, v in all_models.items() if k in filtered}

        # Filter by tier if specified
        if tier:
            filtered = provider.get_models_by_tier(tier)
            all_models = {k: v for k, v in all_models.items() if k in filtered}

        # Format for API response
        models_list = [
            {
                "id": model_id,
                "name": model_info.get("name", model_id),
                "description": model_info.get("description", ""),
                "cost": model_info.get("cost", 0.02),
                "cost_8k": model_info.get("cost_8k"),  # Optional
                "tier": model_info.get("tier", "mid"),
                "max_resolution": model_info.get("max_resolution", [2048, 2048]),
                "capabilities": model_info.get("capabilities", []),
                "use_cases": self._get_use_cases_for_model(model_id, model_info),
                "features": self._get_features_for_model(model_info),
                "supports_multi_image": model_info.get("supports_multi_image", False),
                "supports_controlnet": model_info.get("supports_controlnet", False),
                "languages": model_info.get("languages", ["en"]),
            }
            for model_id, model_info in all_models.items()
        ]

        return {
            "models": models_list,
            "total": len(models_list),
        }

    @classmethod
    def _score_model(
        cls,
        model_info: Dict[str, Any],
        *,
        prioritize_cost: bool,
        prioritize_quality: bool,
        user_tier: Optional[str],
    ) -> float:
        """Return the heuristic ranking score used by ``recommend_model``.

        Combines inverted cost (cheaper scores higher), maximum resolution
        (larger scores higher) and a tier bonus keyed on the user's
        subscription tier. Higher is better.
        """
        cost = model_info.get("cost", 0.02)
        tier = model_info.get("tier", "mid")
        max_res = model_info.get("max_resolution", (2048, 2048))
        max_resolution = max(max_res[0], max_res[1])

        # Invert cost for scoring; free/non-positive costs are floored instead
        # of dividing by zero.
        effective_cost = cost if cost > 0 else cls._MIN_SCORING_COST
        score = (1.0 / effective_cost) * (100 if prioritize_cost else 50)

        # Quality scoring (higher resolution = better)
        score += max_resolution / (10 if prioritize_quality else 20)

        # Tier preference based on user tier
        if user_tier == "free":
            score += {"budget": 50, "mid": 20}.get(tier, 0)
        elif user_tier in ("pro", "enterprise"):
            score += {"premium": 50, "mid": 30}.get(tier, 0)
        return score

    def recommend_model(
        self,
        operation: str,
        image_resolution: Optional[Dict[str, int]] = None,
        user_tier: Optional[str] = None,
        preferences: Optional[Dict[str, Any]] = None,
    ) -> Dict[str, Any]:
        """Recommend the best WaveSpeed model for an operation and context.

        Args:
            operation: Operation type (e.g., "general_edit")
            image_resolution: Dict with "width" and "height"
            user_tier: User subscription tier ("free", "pro", "enterprise")
            preferences: Dict with "prioritize_cost" or "prioritize_quality"

        Returns:
            Dictionary with ``recommended_model``, ``reason`` and up to three
            ``alternatives``.

        Raises:
            ValueError: If the provider exposes no models at all.
        """
        from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

        provider = WaveSpeedEditProvider()
        # NOTE(review): assumes get_models_by_operation returns a dict of
        # model_id -> info (``.items()`` is used below) — confirm.
        available_models = provider.get_models_by_operation(operation)
        if not available_models:
            # Fallback to all models if operation doesn't match
            available_models = provider.get_available_models()

        # Keep only models whose max resolution covers the input image.
        if image_resolution:
            max_dimension = max(image_resolution.get("width", 0), image_resolution.get("height", 0))
            filtered = {}
            for model_id, model_info in available_models.items():
                max_res = model_info.get("max_resolution", (2048, 2048))
                if max_dimension <= max(max_res[0], max_res[1]):
                    filtered[model_id] = model_info
            available_models = filtered

        if not available_models:
            # No models match, return first available
            all_models = provider.get_available_models()
            if not all_models:
                raise ValueError("No models available")
            return {
                "recommended_model": next(iter(all_models)),
                "reason": "No specific match found, using default model",
                "alternatives": [],
            }

        prioritize_cost = bool(preferences and preferences.get("prioritize_cost", False))
        prioritize_quality = bool(preferences and preferences.get("prioritize_quality", False))

        scored_models = [
            (
                model_id,
                model_info,
                self._score_model(
                    model_info,
                    prioritize_cost=prioritize_cost,
                    prioritize_quality=prioritize_quality,
                    user_tier=user_tier,
                ),
            )
            for model_id, model_info in available_models.items()
        ]
        # Sort by score (highest first); stable, so ties keep provider order.
        scored_models.sort(key=lambda x: x[2], reverse=True)

        recommended_id, recommended_info, _ = scored_models[0]

        # Build reason
        reasons = []
        if prioritize_cost:
            reasons.append("Lowest cost option")
        if prioritize_quality:
            reasons.append("Best quality")
        if image_resolution:
            reasons.append(f"Supports {image_resolution.get('width')}×{image_resolution.get('height')} resolution")
        if user_tier == "free" and recommended_info.get("tier") == "budget":
            reasons.append("Budget-friendly for free tier")
        reason = ", ".join(reasons) if reasons else "Best match for your requirements"

        # Get alternatives (up to three runners-up)
        alternatives = []
        for model_id, model_info, _score in scored_models[1:4]:
            alt_reason = f"Alternative: {model_info.get('tier', 'mid').title()} tier"
            if model_info.get("cost", 0) < recommended_info.get("cost", 0):
                alt_reason += ", lower cost"
            elif model_info.get("cost", 0) > recommended_info.get("cost", 0):
                alt_reason += ", higher quality"
            alternatives.append({
                "model_id": model_id,
                "name": model_info.get("name", model_id),
                "cost": model_info.get("cost", 0.02),
                "reason": alt_reason,
            })

        return {
            "recommended_model": recommended_id,
            "reason": reason,
            "alternatives": alternatives,
        }

    def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
        """Map a model's capabilities to human-readable use cases.

        Duplicates are removed while preserving first-seen order, so the API
        response is stable across runs (``set`` ordering is not).
        """
        use_cases_map = {
            "general_edit": ["Quick edits", "Style changes", "Background replacement"],
            "style_transfer": ["Apply artistic styles", "Style transformations"],
            "text_edit": ["Add text to images", "Edit text in images"],
            "multi_image": ["Batch editing", "Consistent character work"],
            "high_res": ["Professional work", "Print materials", "4K/8K editing"],
            "professional": ["Marketing campaigns", "Brand assets"],
            "typography": ["Text-heavy edits", "Typography generation"],
            "portrait_retouching": ["Portrait edits", "Beauty retouching"],
            "fashion_edit": ["Fashion photography", "Outfit changes"],
            "product_edit": ["E-commerce", "Product photography"],
        }

        use_cases = []
        for cap in model_info.get("capabilities", []):
            use_cases.extend(use_cases_map.get(cap, []))

        return list(dict.fromkeys(use_cases)) if use_cases else ["General image editing"]

    def _get_features_for_model(self, model_info: Dict[str, Any]) -> list:
        """Build a human-readable feature list for a model descriptor."""
        features = []

        if model_info.get("supports_multi_image"):
            max_images = model_info.get("api_params", {}).get("max_images", 0)
            if max_images:
                features.append(f"Multi-image ({max_images} images)")
            else:
                features.append("Multi-image support")

        if model_info.get("supports_controlnet"):
            features.append("ControlNet support")

        languages = model_info.get("languages", [])
        if len(languages) > 1:
            features.append(f"Multilingual ({', '.join(languages)})")
        elif "multilingual" in languages:
            features.append("Multilingual support")

        max_res = model_info.get("max_resolution", (2048, 2048))
        if max(max_res) >= 4096:
            features.append("4K/8K support")
        elif max(max_res) >= 2048:
            features.append("2K support")

        api_params = model_info.get("api_params", {})
        if api_params.get("supports_guidance_scale"):
            features.append("Guidance scale control")

        return features if features else ["Standard editing"]

    # ------------------------------------------------------------------ #
    # Main entry point
    # ------------------------------------------------------------------ #

    @staticmethod
    def _run_preflight_validation(user_id: str) -> None:
        """Check the user's image-editing subscription limits before any work.

        Editing uses ``validate_image_editing_operations`` (different from
        generation) because editing may have different subscription limits.

        Raises:
            HTTPException: Propagated from the validator when limits are exceeded.
        """
        from services.database import get_db
        from services.subscription import PricingService
        from services.subscription.preflight_validator import validate_image_editing_operations
        from fastapi import HTTPException

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            logger.info(f"[Edit Studio] 🛂 Running pre-flight validation for user {user_id}")
            validate_image_editing_operations(
                pricing_service=pricing_service,
                user_id=user_id,
            )
            logger.info("[Edit Studio] ✅ Pre-flight validation passed")
        except HTTPException:
            logger.error("[Edit Studio] ❌ Pre-flight validation failed")
            raise
        finally:
            db.close()

    def _track_stability_usage(self, user_id: str, operation: str, result_bytes: bytes) -> None:
        """Record a completed Stability edit against the user's subscription usage."""
        from services.llm_providers.main_image_generation import _track_image_operation_usage

        _track_image_operation_usage(
            user_id=user_id,
            provider="stability",
            model=f"edit-{operation}",
            operation_type="image-edit",
            result_bytes=result_bytes,
            cost=self._STABILITY_EDIT_COST,
            endpoint="/image-studio/edit/process",
            log_prefix="[Edit Studio]"
        )

    async def process_edit(
        self,
        request: EditStudioRequest,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Run one Edit Studio operation and return a normalized response.

        Args:
            request: Normalized edit request.
            user_id: Authenticated user. When given, subscription limits are
                validated up front and Stability usage is tracked afterwards.

        Returns:
            Dict with ``success``, ``operation``, ``provider`` (the provider
            that actually produced the image), ``image_base64`` (data URL),
            ``width``, ``height`` and ``metadata``.

        Raises:
            ValueError: Missing/invalid image payload, unsupported operation,
                or missing operation-specific inputs.
        """
        if user_id:
            self._run_preflight_validation(user_id)
        else:
            logger.warning("[Edit Studio] ⚠️ No user_id provided - skipping pre-flight validation")

        image_bytes = self._decode_base64_image(request.image_base64)
        if not image_bytes:
            raise ValueError("Primary image payload is required")

        mask_bytes = self._decode_base64_image(request.mask_base64)
        background_bytes = self._decode_base64_image(request.background_image_base64)
        lighting_bytes = self._decode_base64_image(request.lighting_image_base64)

        operation = request.operation
        # f-strings match the other log calls in this module.
        logger.info(f"[Edit Studio] Processing operation='{operation}' for user={user_id}")

        if operation not in self.SUPPORTED_OPERATIONS:
            raise ValueError(f"Unsupported edit operation: {operation}")

        provider_used = self.SUPPORTED_OPERATIONS[operation]["provider"]
        model_used: Optional[str] = None

        if provider_used == "stability":
            image_bytes = await self._handle_stability_edit(
                operation=operation,
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                background_bytes=background_bytes,
                lighting_bytes=lighting_bytes,
            )
            if user_id:
                self._track_stability_usage(user_id, operation, image_bytes)
        else:
            # General edits may run on WaveSpeed or Hugging Face; report the
            # provider that was actually used rather than the static table entry.
            image_bytes, provider_used, model_used = await self._run_general_edit(
                request=request,
                image_bytes=image_bytes,
                mask_bytes=mask_bytes,
                user_id=user_id,
            )

        metadata = self._image_bytes_to_metadata(image_bytes)
        metadata.update(
            {
                "operation": operation,
                "style_preset": request.style_preset,
                "provider": provider_used,
            }
        )
        if model_used:
            metadata["model"] = model_used

        response = {
            "success": True,
            "operation": operation,
            "provider": metadata["provider"],
            "image_base64": self._bytes_to_base64(image_bytes, request.output_format),
            "width": metadata["width"],
            "height": metadata["height"],
            "metadata": metadata,
        }

        logger.info(f"[Edit Studio] ✅ Operation '{operation}' completed")
        return response

    # ------------------------------------------------------------------ #
    # Stability AI
    # ------------------------------------------------------------------ #

    @classmethod
    def _validate_stability_request(
        cls,
        operation: str,
        request: EditStudioRequest,
        background_bytes: Optional[bytes],
        lighting_bytes: Optional[bytes],
    ) -> None:
        """Raise ``ValueError`` when a Stability operation lacks required inputs.

        Runs before any Stability client is created, so invalid requests fail
        fast without a network round-trip.
        """
        if cls.SUPPORTED_OPERATIONS.get(operation, {}).get("provider") != "stability":
            raise ValueError(f"Unsupported Stability operation: {operation}")
        if operation == "inpaint" and not request.prompt:
            raise ValueError("Prompt is required for inpainting")
        if operation == "outpaint" and not any(
            (request.expand_left, request.expand_right, request.expand_up, request.expand_down)
        ):
            # Stability's outpaint endpoint needs at least one non-zero direction.
            raise ValueError(
                "At least one expansion direction (expand_left, expand_right, "
                "expand_up, expand_down) must be non-zero for outpaint"
            )
        if operation == "search_replace" and not (request.prompt and request.search_prompt):
            raise ValueError("Both prompt and search_prompt are required for search & replace")
        if operation == "search_recolor" and not (request.prompt and request.select_prompt):
            raise ValueError("Both prompt and select_prompt are required for search & recolor")
        if operation == "relight" and not background_bytes and not lighting_bytes:
            raise ValueError("At least one reference (background or lighting) is required for relight")

    async def _handle_stability_edit(
        self,
        operation: EditOperationType,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        background_bytes: Optional[bytes],
        lighting_bytes: Optional[bytes],
    ) -> bytes:
        """Execute a Stability AI edit workflow and return the result image bytes.

        Raises:
            ValueError: Unsupported operation or missing required inputs.
        """
        self._validate_stability_request(operation, request, background_bytes, lighting_bytes)

        stability_service = StabilityAIService()
        async with stability_service:
            if operation == "remove_background":
                result = await stability_service.remove_background(
                    image=image_bytes,
                    output_format=request.output_format,
                )
            elif operation == "inpaint":
                result = await stability_service.inpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    mask=mask_bytes,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    style_preset=request.style_preset,
                    grow_mask=request.options.get("grow_mask", 5),
                )
            elif operation == "outpaint":
                result = await stability_service.outpaint(
                    image=image_bytes,
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt,
                    output_format=request.output_format,
                    left=request.expand_left or 0,
                    right=request.expand_right or 0,
                    up=request.expand_up or 0,
                    down=request.expand_down or 0,
                    style_preset=request.style_preset,
                )
            elif operation == "search_replace":
                result = await stability_service.search_and_replace(
                    image=image_bytes,
                    prompt=request.prompt,
                    search_prompt=request.search_prompt,
                    mask=mask_bytes,  # Optional mask for precise region selection
                    output_format=request.output_format,
                )
            elif operation == "search_recolor":
                result = await stability_service.search_and_recolor(
                    image=image_bytes,
                    prompt=request.prompt,
                    select_prompt=request.select_prompt,
                    mask=mask_bytes,  # Optional mask for precise region selection
                    output_format=request.output_format,
                )
            elif operation == "relight":
                result = await stability_service.replace_background_and_relight(
                    subject_image=image_bytes,
                    background_reference=background_bytes,
                    light_reference=lighting_bytes,
                    output_format=request.output_format,
                )
                # Relight is asynchronous: a dict carrying an "id" means the
                # image must be fetched by polling.
                if isinstance(result, dict) and result.get("id"):
                    result = await self._poll_stability_result(
                        stability_service,
                        generation_id=result["id"],
                        output_format=request.output_format,
                    )
            else:
                raise ValueError(f"Unsupported Stability operation: {operation}")

        return self._extract_image_bytes(result)

    # ------------------------------------------------------------------ #
    # General (prompt-based) edits
    # ------------------------------------------------------------------ #

    async def _handle_general_edit(
        self,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        user_id: Optional[str],
    ) -> bytes:
        """Execute a general edit and return only the edited image bytes.

        Backward-compatible wrapper around ``_run_general_edit``; use that
        method when the provider/model that served the edit is also needed.
        """
        result_bytes, _provider, _model = await self._run_general_edit(
            request=request,
            image_bytes=image_bytes,
            mask_bytes=mask_bytes,
            user_id=user_id,
        )
        return result_bytes

    async def _run_general_edit(
        self,
        request: EditStudioRequest,
        image_bytes: bytes,
        mask_bytes: Optional[bytes],
        user_id: Optional[str],
    ) -> tuple[bytes, str, Optional[str]]:
        """Execute a general edit via WaveSpeed (unified entry) or Hugging Face (legacy).

        WaveSpeed is used when ``request.provider == "wavespeed"``, when
        ``request.model`` is a WaveSpeed editing model (e.g. qwen-edit-plus,
        nano-banana-pro-edit-ultra, seedream-v4.5-edit), or when no model was
        given and auto-recommendation picks one. Otherwise the legacy Hugging
        Face editor is used.

        Returns:
            ``(image_bytes, provider_used, model_used)``; ``model_used`` may be
            ``None`` when the downstream provider chooses the model itself.

        Raises:
            ValueError: If ``request.prompt`` is empty.
        """
        if not request.prompt:
            raise ValueError("Prompt is required for general edits")

        from services.llm_providers.image_generation.wavespeed_edit_provider import WaveSpeedEditProvider

        wavespeed_models = set(WaveSpeedEditProvider().get_available_models().keys())

        # Local copy: auto-selection must not mutate the caller's request.
        model = request.model
        is_wavespeed = request.provider == "wavespeed" or bool(model and model in wavespeed_models)

        # Auto-detect: If no model specified and operation is general_edit, recommend one
        if not model and not is_wavespeed and request.operation == "general_edit":
            try:
                with Image.open(io.BytesIO(image_bytes)) as img:
                    image_resolution = {"width": img.width, "height": img.height}
                recommendation = self.recommend_model(
                    operation=request.operation,
                    image_resolution=image_resolution,
                    preferences={"prioritize_cost": True},  # Default to cost-optimized
                )
                recommended_model = recommendation.get("recommended_model")
                if recommended_model and recommended_model in wavespeed_models:
                    logger.info(f"[Edit Studio] Auto-selected model: {recommended_model} (reason: {recommendation.get('reason')})")
                    model = recommended_model
                    is_wavespeed = True
            except Exception as e:
                # Best-effort: any failure here falls back to the legacy path.
                logger.warning(f"[Edit Studio] Auto-detection failed: {e}, falling back to HuggingFace")

        if is_wavespeed:
            logger.info(f"[Edit Studio] Using WaveSpeed unified entry for model={model}")
            with Image.open(io.BytesIO(image_bytes)) as img:
                width, height = img.width, img.height
            edit_options = {
                "mask_base64": base64.b64encode(mask_bytes).decode("utf-8") if mask_bytes else None,
                "negative_prompt": request.negative_prompt,
                "width": width,
                "height": height,
                "guidance_scale": request.guidance_scale,
                "steps": request.steps,
                "seed": request.seed,
            }
            # Unified entry point is synchronous - run it in a worker thread.
            result = await asyncio.to_thread(
                generate_image_edit,
                image_base64=base64.b64encode(image_bytes).decode("utf-8"),
                prompt=request.prompt,
                operation=request.operation or "general_edit",
                model=model,  # Will auto-select if None
                options=edit_options,
                user_id=user_id,
            )
            return result.image_bytes, "wavespeed", model

        # Fall back to HuggingFace for backward compatibility
        logger.info("[Edit Studio] Using HuggingFace (legacy) for general edit")
        hf_provider = request.provider or "huggingface"
        options = {
            "provider": hf_provider,
            "model": model,
            "guidance_scale": request.guidance_scale,
            "steps": request.steps,
            "seed": request.seed,
        }
        # huggingface edit is synchronous - run in thread
        result = await asyncio.to_thread(
            huggingface_edit_image,
            image_bytes,
            request.prompt,
            options,
            user_id,
            mask_bytes,  # Optional mask for selective editing
        )
        return result.image_bytes, hf_provider, model

    # ------------------------------------------------------------------ #
    # Stability response handling
    # ------------------------------------------------------------------ #

    @staticmethod
    def _extract_image_bytes(result: Any) -> bytes:
        """Normalize a Stability response into raw image bytes.

        Accepts raw bytes, a dict with an ``artifacts``/``data``/``images``
        list of ``{"base64"|"b64_json": ...}`` entries, or a dict carrying the
        image as a top-level base64 ``image`` field (Stability v2 JSON).

        Raises:
            RuntimeError: If no image payload can be found.
        """
        if isinstance(result, bytes):
            return result
        if isinstance(result, dict):
            artifacts = result.get("artifacts") or result.get("data") or result.get("images") or []
            for artifact in artifacts:
                if isinstance(artifact, dict):
                    if artifact.get("base64"):
                        return base64.b64decode(artifact["base64"])
                    if artifact.get("b64_json"):
                        return base64.b64decode(artifact["b64_json"])
            inline_image = result.get("image")
            if isinstance(inline_image, str) and inline_image:
                return base64.b64decode(inline_image)
        raise RuntimeError("Unable to extract image bytes from provider response")

    async def _poll_stability_result(
        self,
        stability_service: StabilityAIService,
        generation_id: str,
        output_format: str,
        timeout_seconds: int = 240,
        interval_seconds: float = 2.0,
    ) -> bytes:
        """Poll a Stability async generation until its image is ready.

        The timeout is measured on the event loop's monotonic clock, so time
        spent inside each status request counts toward it. ``output_format``
        is accepted for interface compatibility but not used; results are
        requested with ``accept_type="*/*"``.

        Raises:
            RuntimeError: If the generation fails or the timeout elapses.
        """
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout_seconds
        while loop.time() < deadline:
            result = await stability_service.get_generation_result(
                generation_id=generation_id,
                accept_type="*/*",
            )
            if isinstance(result, bytes):
                return result
            if isinstance(result, dict):
                state = (result.get("state") or result.get("status") or "").lower()
                if state in {"succeeded", "success", "ready", "completed"}:
                    return self._extract_image_bytes(result)
                if state in {"failed", "error"}:
                    raise RuntimeError(f"Stability generation failed: {result}")
            await asyncio.sleep(interval_seconds)
        raise RuntimeError("Timed out waiting for Stability generation result")