Save local changes (GSC/Bing integrations) before merging PR #354
This commit is contained in:
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
@@ -9,6 +11,9 @@ from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from .image_generation.base import ImageEditOptions
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
try:
|
||||
@@ -22,21 +27,36 @@ logger = get_service_logger("image_editing.facade")
|
||||
|
||||
|
||||
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
|
||||
"HF_IMAGE_EDIT_MODEL",
|
||||
"Qwen/Qwen-Image-Edit",
|
||||
"WAVESPEED_IMAGE_EDIT_MODEL",
|
||||
"qwen-edit-plus",
|
||||
)
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
|
||||
"""
|
||||
Select the appropriate image editing provider.
|
||||
|
||||
Priority:
|
||||
1. Explicitly requested provider
|
||||
2. WaveSpeed (if API key available) - Preferred for quality/speed
|
||||
3. Hugging Face (fallback)
|
||||
"""
|
||||
if explicit:
|
||||
return explicit
|
||||
# Default to huggingface for image editing (best support for image-to-image)
|
||||
return explicit.lower()
|
||||
|
||||
# Check for WaveSpeed API key first (Preferred provider)
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
# Default to huggingface if WaveSpeed not available
|
||||
return "huggingface"
|
||||
|
||||
|
||||
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get InferenceClient for the specified provider."""
|
||||
"""Get the client for the specified provider."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider(api_key=api_key)
|
||||
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
@@ -44,7 +64,7 @@ def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
|
||||
# Use fal-ai provider for fast inference
|
||||
# Use fal-ai provider for fast inference via HF Inference API
|
||||
return InferenceClient(provider="fal-ai", api_key=api_key)
|
||||
|
||||
raise ValueError(f"Unknown image editing provider: {provider_name}")
|
||||
@@ -86,6 +106,8 @@ def edit_image(
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
# Note: get_db() is a generator, so we need to use next() to get the session
|
||||
# and ensure we close it in the finally block
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
@@ -99,6 +121,9 @@ def edit_image(
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
@@ -119,6 +144,69 @@ def edit_image(
|
||||
# Get provider client
|
||||
client = _get_provider_client(provider_name, opts.get("api_key"))
|
||||
|
||||
if provider_name == "wavespeed":
|
||||
# Handle WaveSpeed provider
|
||||
try:
|
||||
# Convert inputs to base64 for WaveSpeed
|
||||
image_b64 = base64.b64encode(input_image_bytes).decode('utf-8')
|
||||
mask_b64 = None
|
||||
if mask_bytes:
|
||||
mask_b64 = base64.b64encode(mask_bytes).decode('utf-8')
|
||||
|
||||
# Determine operation type based on prompt/mask
|
||||
operation = "general_edit" # Default
|
||||
if not prompt and mask_b64:
|
||||
operation = "remove_bg" # Heuristic: mask but no prompt implies removal/in-painting
|
||||
elif prompt and not mask_b64:
|
||||
operation = "style_transfer" # Heuristic: prompt but no mask implies style transfer
|
||||
elif opts.get("operation"):
|
||||
operation = opts.get("operation")
|
||||
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_b64,
|
||||
prompt=prompt.strip(),
|
||||
operation=operation,
|
||||
mask_base64=mask_b64,
|
||||
model=model,
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts
|
||||
)
|
||||
|
||||
logger.info(f"[Image Editing] Calling WaveSpeed edit with model={model}")
|
||||
result = client.edit(edit_options)
|
||||
|
||||
# TRACK USAGE after successful WaveSpeed call
|
||||
if user_id:
|
||||
try:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
|
||||
# Estimate cost (WaveSpeed default: $0.02)
|
||||
estimated_cost = result.metadata.get("estimated_cost", 0.02) if result.metadata else 0.02
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="wavespeed",
|
||||
model=result.model or model,
|
||||
operation_type="image-editing",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-editing",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Editing]"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ WaveSpeed editing failed: {e}", exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed editing failed: {str(e)}")
|
||||
|
||||
# Hugging Face (Fallback)
|
||||
# Prepare parameters for image-to-image
|
||||
params: Dict[str, Any] = {}
|
||||
if opts.get("guidance_scale") is not None:
|
||||
@@ -170,6 +258,29 @@ def edit_image(
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
|
||||
|
||||
# TRACK USAGE after successful HF call
|
||||
if user_id:
|
||||
try:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
|
||||
# Estimate cost (HF/Fal-ai default: $0.05)
|
||||
estimated_cost = 0.05
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
operation_type="image-editing",
|
||||
result_bytes=edited_image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-editing",
|
||||
metadata={"provider": "fal-ai"},
|
||||
log_prefix="[Image Editing]"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=edited_image_bytes,
|
||||
width=edited_image.width,
|
||||
|
||||
Reference in New Issue
Block a user