refactor(phase1): extract image generation helpers, edit, face_swap into separate modules + fix subscription bugs
Extracted from main_image_generation.py (1002->591 lines): - image_generation/helpers.py: _validate_image_operation, _track_image_operation_usage - image_generation/edit.py: generate_image_edit (with _get_edit_provider) - image_generation/face_swap.py: generate_face_swap (with _get_face_swap_provider) Main image_generation.py now imports and re-exports from these modules. All existing imports (api/images.py, step4_asset_routes.py, studio services) continue to work unchanged. Bug fixes included: 1. generate_image_edit: Added missing 'return result' (was returning None!) 2. generate_image_edit: Added missing _track_image_operation_usage call 3. generate_face_swap: Removed duplicate dead tracking code after return statement
This commit is contained in:
105
backend/services/llm_providers/image_generation/face_swap.py
Normal file
105
backend/services/llm_providers/image_generation/face_swap.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Face swap operations — generate_face_swap and related helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import FaceSwapOptions, FaceSwapProvider, ImageGenerationResult
|
||||
from .wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.face_swap")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Generate face swap with pre-flight validation and usage tracking.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
num_operations=1,
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025)
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None,
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap failed", "message": str(api_error)}
|
||||
)
|
||||
Reference in New Issue
Block a user