Files
ALwrity/backend/services/llm_providers/main_image_editing.py

302 lines
12 KiB
Python

from __future__ import annotations
import os
import io
import base64
import logging
from typing import Optional, Dict, Any
from PIL import Image
from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
)
from .image_generation.base import ImageEditOptions
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from utils.logger_utils import get_service_logger
# huggingface_hub is optional: only the Hugging Face fallback provider needs it,
# and _get_provider_client checks HF_HUB_AVAILABLE before using InferenceClient.
try:
    from huggingface_hub import InferenceClient
except ImportError:
    HF_HUB_AVAILABLE = False
else:
    HF_HUB_AVAILABLE = True

logger = get_service_logger("image_editing.facade")

# Model used when the caller passes none; override with WAVESPEED_IMAGE_EDIT_MODEL.
DEFAULT_IMAGE_EDIT_MODEL = os.getenv("WAVESPEED_IMAGE_EDIT_MODEL", "qwen-edit-plus")
def _select_provider(explicit: Optional[str]) -> str:
    """Return the lower-case name of the image editing provider to use.

    Resolution order:
      1. ``explicit`` (lower-cased) when it is a non-empty string.
      2. ``"wavespeed"`` when ``WAVESPEED_API_KEY`` is set to a non-empty
         value (preferred for quality/speed).
      3. ``"huggingface"`` otherwise (fallback).
    """
    if not explicit:
        wavespeed_configured = bool(os.getenv("WAVESPEED_API_KEY"))
        return "wavespeed" if wavespeed_configured else "huggingface"
    return explicit.lower()
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
    """Build the client object for ``provider_name``.

    Args:
        provider_name: Lower-case provider id (as returned by ``_select_provider``).
        api_key: Optional explicit API key. For Hugging Face it falls back to
            the ``HF_TOKEN`` environment variable.

    Returns:
        A ``WaveSpeedEditProvider`` for ``"wavespeed"``, or a
        ``huggingface_hub.InferenceClient`` routed through the fal-ai provider
        for ``"huggingface"``.

    Raises:
        ValueError: if ``provider_name`` is not a known provider.
        RuntimeError: if Hugging Face is requested but ``huggingface_hub`` is
            not installed or no token is available.
    """
    if provider_name == "wavespeed":
        return WaveSpeedEditProvider(api_key=api_key)

    # Validate the name before checking the optional dependency, so an unknown
    # provider is always reported as such regardless of which libraries are installed.
    if provider_name != "huggingface":
        raise ValueError(f"Unknown image editing provider: {provider_name}")

    if not HF_HUB_AVAILABLE:
        raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")

    token = api_key or os.getenv("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
    # Use fal-ai provider for fast inference via HF Inference API
    return InferenceClient(provider="fal-ai", api_key=token)
def _run_preflight_validation(user_id: str) -> None:
    """Check that ``user_id`` may perform an image edit, before any provider API call.

    Raises:
        fastapi.HTTPException: re-raised unchanged when the validator rejects the
            request; any other error during validation is wrapped in a 500.
    """
    # Imported lazily: these are only needed when a user_id is supplied.
    from services.database import get_db
    from services.subscription import PricingService
    from services.subscription.preflight_validator import validate_image_editing_operations
    from fastapi import HTTPException

    logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
    # get_db() is a generator, so take one session with next() and close it in `finally`.
    db = next(get_db())
    try:
        pricing_service = PricingService(db)
        # Raises HTTPException immediately if validation fails - frontend gets immediate response
        validate_image_editing_operations(
            pricing_service=pricing_service,
            user_id=user_id,
        )
        logger.info(f"[Image Editing] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image editing")
    except HTTPException as http_ex:
        # Re-raise unchanged so no provider API call is made.
        logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
        raise
    except Exception as e:
        logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
        raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}") from e
    finally:
        db.close()


def _select_wavespeed_operation(opts: Dict[str, Any], has_mask: bool) -> str:
    """Pick the WaveSpeed ``operation`` for this request.

    An explicit ``opts["operation"]`` always wins. Otherwise a masked request
    is a ``"general_edit"`` and an unmasked one a ``"style_transfer"``.
    """
    # There is no "mask but no prompt" case here: edit_image rejects blank
    # prompts before any provider code runs.
    explicit = opts.get("operation")
    if explicit:
        return explicit
    return "general_edit" if has_mask else "style_transfer"


def _edit_with_wavespeed(
    client: Any,
    input_image_bytes: bytes,
    prompt: str,
    opts: Dict[str, Any],
    model: str,
    user_id: Optional[str],
    mask_bytes: Optional[bytes],
) -> ImageGenerationResult:
    """Edit via WaveSpeed (image and mask sent as base64) and record usage on success.

    Raises:
        RuntimeError: wrapping any failure while building or running the edit.
    """
    try:
        image_b64 = base64.b64encode(input_image_bytes).decode("utf-8")
        mask_b64 = base64.b64encode(mask_bytes).decode("utf-8") if mask_bytes else None
        operation = _select_wavespeed_operation(opts, has_mask=mask_b64 is not None)

        edit_options = ImageEditOptions(
            image_base64=image_b64,
            prompt=prompt.strip(),
            operation=operation,
            mask_base64=mask_b64,
            model=model,
            guidance_scale=opts.get("guidance_scale"),
            steps=opts.get("steps"),
            seed=opts.get("seed"),
            extra=opts,
        )
        logger.info(f"[Image Editing] Calling WaveSpeed edit with model={model}")
        result = client.edit(edit_options)

        # Usage tracking is best-effort: a tracking failure must not fail the edit.
        if user_id:
            try:
                from services.llm_providers.main_image_generation import _track_image_operation_usage

                # Use the provider-reported cost when present, else the WaveSpeed default ($0.02).
                estimated_cost = result.metadata.get("estimated_cost", 0.02) if result.metadata else 0.02
                _track_image_operation_usage(
                    user_id=user_id,
                    provider="wavespeed",
                    model=result.model or model,
                    operation_type="image-editing",
                    result_bytes=result.image_bytes,
                    cost=estimated_cost,
                    prompt=prompt,
                    endpoint="/image-editing",
                    metadata=result.metadata,
                    log_prefix="[Image Editing]",
                )
            except Exception as track_error:
                logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
        return result
    except Exception as e:
        logger.error(f"[Image Editing] ❌ WaveSpeed editing failed: {e}", exc_info=True)
        raise RuntimeError(f"WaveSpeed editing failed: {str(e)}") from e


def _edit_with_huggingface(
    client: Any,
    input_image_bytes: bytes,
    prompt: str,
    opts: Dict[str, Any],
    model: str,
    user_id: Optional[str],
    mask_bytes: Optional[bytes],
) -> ImageGenerationResult:
    """Edit via the Hugging Face InferenceClient (fal-ai) and record usage on success.

    Raises:
        RuntimeError: wrapping any failure while decoding, editing or encoding the image.
    """
    # Translate generic option names into image_to_image keyword arguments,
    # forwarding only the ones the caller actually set.
    option_to_param = (
        ("guidance_scale", "guidance_scale"),
        ("steps", "num_inference_steps"),
        ("seed", "seed"),
    )
    params: Dict[str, Any] = {
        param: opts[key]
        for key, param in option_to_param
        if opts.get(key) is not None
    }

    try:
        # Decode the input so its size can be reported and the mask matched to it.
        input_image = Image.open(io.BytesIO(input_image_bytes))
        width, height = input_image.width, input_image.height

        # A bad mask is not fatal: log it and edit the whole image instead.
        mask_image = None
        if mask_bytes:
            try:
                mask_image = Image.open(io.BytesIO(mask_bytes)).convert("L")  # Convert to grayscale
                if mask_image.size != input_image.size:
                    logger.warning(f"[Image Editing] Mask size {mask_image.size} doesn't match image size {input_image.size}, resizing mask")
                    mask_image = mask_image.resize(input_image.size, Image.Resampling.LANCZOS)
            except Exception as e:
                logger.warning(f"[Image Editing] Failed to process mask image: {e}, continuing without mask")
                mask_image = None

        # Docs: https://huggingface.co/docs/inference-providers/en/guides/image-editor
        # Mask support depends on the model - some models may ignore it.
        if mask_image is not None:
            params["mask_image"] = mask_image
            logger.info("[Image Editing] Using mask for selective editing")

        edited_image: Image.Image = client.image_to_image(
            image=input_image,
            prompt=prompt.strip(),
            model=model,
            **params,
        )

        # Convert edited image back to PNG bytes
        with io.BytesIO() as buf:
            edited_image.save(buf, format="PNG")
            edited_image_bytes = buf.getvalue()
        logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")

        # Usage tracking is best-effort: a tracking failure must not fail the edit.
        if user_id:
            try:
                from services.llm_providers.main_image_generation import _track_image_operation_usage

                # Flat cost estimate for HF/fal-ai calls ($0.05).
                estimated_cost = 0.05
                _track_image_operation_usage(
                    user_id=user_id,
                    provider="huggingface",
                    model=model,
                    operation_type="image-editing",
                    result_bytes=edited_image_bytes,
                    cost=estimated_cost,
                    prompt=prompt,
                    endpoint="/image-editing",
                    metadata={"provider": "fal-ai"},
                    log_prefix="[Image Editing]",
                )
            except Exception as track_error:
                logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")

        return ImageGenerationResult(
            image_bytes=edited_image_bytes,
            width=edited_image.width,
            height=edited_image.height,
            provider="huggingface",
            model=model,
            seed=opts.get("seed"),
            metadata={
                "provider": "fal-ai",
                "operation": "image_editing",
                "original_width": width,
                "original_height": height,
            },
        )
    except Exception as e:
        logger.error(f"[Image Editing] ❌ Error editing image: {e}", exc_info=True)
        raise RuntimeError(f"Image editing failed: {str(e)}") from e


def edit_image(
    input_image_bytes: bytes,
    prompt: str,
    options: Optional[Dict[str, Any]] = None,
    user_id: Optional[str] = None,
    mask_bytes: Optional[bytes] = None,
) -> ImageGenerationResult:
    """Edit an image from a natural-language prompt, with pre-flight validation.

    Args:
        input_image_bytes: Input image as bytes (PNG/JPEG).
        prompt: Natural language prompt describing the desired edit
            (e.g. "Turn the cat into a tiger"). Must not be blank.
        options: Optional settings. Keys read here: ``provider``, ``model``,
            ``api_key``, ``guidance_scale``, ``steps``, ``seed`` and, for
            WaveSpeed, ``operation`` (always honored when set). The whole dict
            is also forwarded to WaveSpeed as ``extra``.
        user_id: User ID for subscription pre-flight validation and usage
            tracking. When omitted, both are skipped (should not happen in
            production).
        mask_bytes: Optional mask image for selective editing
            (grayscale, white=edit, black=preserve).

    Returns:
        ImageGenerationResult with edited image bytes and metadata.

    Raises:
        fastapi.HTTPException: if pre-flight validation rejects the request.
        ValueError: if the image or prompt is empty, or the provider is unknown.
        RuntimeError: if the provider cannot be set up or the edit fails.

    Best Practices for Prompts:
        - Use clear, specific language describing desired changes
        - Describe what should change and what should remain
        - Examples: "Turn the cat into a tiger", "Change background to forest",
          "Make it look like a watercolor painting"

    Note: Mask support depends on the specific model. Some models may ignore the mask parameter.
    """
    # PRE-FLIGHT VALIDATION: must happen BEFORE any provider API call.
    if user_id:
        _run_preflight_validation(user_id)
    else:
        logger.warning("[Image Editing] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")

    # Validate input
    if not input_image_bytes:
        raise ValueError("input_image_bytes is required")
    if not prompt or not prompt.strip():
        raise ValueError("prompt is required for image editing")

    opts = options or {}
    provider_name = _select_provider(opts.get("provider"))
    # NOTE(review): the default model comes from WAVESPEED_IMAGE_EDIT_MODEL
    # ("qwen-edit-plus" unless overridden) and is used for the Hugging Face path
    # too; confirm that is a valid HF model id or require callers to pass `model`.
    model = opts.get("model") or DEFAULT_IMAGE_EDIT_MODEL
    logger.info(f"[Image Editing] Editing image via provider={provider_name} model={model}")

    # Setup errors (unknown provider, missing dependency/token) propagate as-is.
    client = _get_provider_client(provider_name, opts.get("api_key"))

    if provider_name == "wavespeed":
        return _edit_with_wavespeed(client, input_image_bytes, prompt, opts, model, user_id, mask_bytes)
    # _get_provider_client only returns for "wavespeed" or "huggingface".
    return _edit_with_huggingface(client, input_image_bytes, prompt, opts, model, user_id, mask_bytes)