Files
ALwrity/backend/services/image_studio/face_swap_service.py

267 lines
10 KiB
Python

"""Face Swap Studio service for AI-powered face swapping."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from typing import Any, Dict, Optional
from PIL import Image
from services.llm_providers.main_image_generation import generate_face_swap
from utils.logger_utils import get_service_logger
logger = get_service_logger("image_studio.face_swap")
@dataclass
class FaceSwapStudioRequest:
"""Request model for face swap operations."""
base_image_base64: str
face_image_base64: str
model: Optional[str] = None
target_face_index: Optional[int] = None
target_gender: Optional[str] = None
options: Optional[Dict[str, Any]] = None
class FaceSwapService:
"""Service for face swap operations."""
def __init__(self):
pass
def get_available_models(
self,
tier: Optional[str] = None,
) -> Dict[str, Any]:
"""Get available WaveSpeed face swap models.
Args:
tier: Filter by tier ("budget", "mid", "premium")
Returns:
Dictionary with models and metadata
"""
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
provider = WaveSpeedFaceSwapProvider()
all_models = provider.get_available_models()
# Filter by tier if specified
if tier:
filtered = provider.get_models_by_tier(tier)
all_models = {k: v for k, v in all_models.items() if k in filtered}
# Format for API response
models_list = []
for model_id, model_info in all_models.items():
models_list.append({
"id": model_id,
"name": model_info.get("name", model_id),
"description": model_info.get("description", ""),
"cost": model_info.get("cost", 0.025),
"tier": model_info.get("tier", "mid"),
"capabilities": model_info.get("capabilities", []),
"use_cases": self._get_use_cases_for_model(model_id, model_info),
"features": model_info.get("features", []),
"max_faces": model_info.get("max_faces", 1),
})
return {
"models": models_list,
"total": len(models_list),
}
def recommend_model(
self,
base_image_resolution: Optional[Dict[str, int]] = None,
face_image_resolution: Optional[Dict[str, int]] = None,
user_tier: Optional[str] = None,
preferences: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Recommend best model for face swap.
Args:
base_image_resolution: Dict with "width" and "height" of base image
face_image_resolution: Dict with "width" and "height" of face image
user_tier: User subscription tier ("free", "pro", "enterprise")
preferences: Dict with "prioritize_cost" or "prioritize_quality"
Returns:
Dictionary with recommended model and alternatives
"""
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
provider = WaveSpeedFaceSwapProvider()
available_models = provider.get_available_models()
if not available_models:
raise ValueError("No models available")
# Apply preferences
prioritize_cost = preferences and preferences.get("prioritize_cost", False)
prioritize_quality = preferences and preferences.get("prioritize_quality", False)
# Score models
scored_models = []
for model_id, model_info in available_models.items():
score = 0
cost = model_info.get("cost", 0.025)
tier = model_info.get("tier", "mid")
# Cost scoring (lower is better)
if prioritize_cost:
score += (1.0 / cost) * 100
else:
score += (1.0 / cost) * 50
# Quality scoring (higher cost = better quality for face swap)
if prioritize_quality:
score += cost * 20
else:
score += cost * 10
# Tier preference based on user tier
if user_tier == "free":
if tier == "budget":
score += 50
elif tier == "mid":
score += 20
elif user_tier in ["pro", "enterprise"]:
if tier == "premium":
score += 50
elif tier == "mid":
score += 30
scored_models.append((model_id, model_info, score))
# Sort by score (highest first)
scored_models.sort(key=lambda x: x[2], reverse=True)
# Get recommended model
recommended_id, recommended_info, recommended_score = scored_models[0]
# Build reason
reasons = []
if prioritize_cost:
reasons.append("Lowest cost option")
if prioritize_quality:
reasons.append("Best quality")
if user_tier == "free" and recommended_info.get("tier") == "budget":
reasons.append("Budget-friendly for free tier")
reason = ", ".join(reasons) if reasons else "Best match for your requirements"
# Get alternatives (top 2-3)
alternatives = []
for model_id, model_info, score in scored_models[1:4]:
alt_reason = f"Alternative: {model_info.get('tier', 'mid').title()} tier"
if model_info.get("cost", 0) < recommended_info.get("cost", 0):
alt_reason += ", lower cost"
elif model_info.get("cost", 0) > recommended_info.get("cost", 0):
alt_reason += ", higher quality"
alternatives.append({
"model_id": model_id,
"name": model_info.get("name", model_id),
"cost": model_info.get("cost", 0.025),
"reason": alt_reason,
})
return {
"recommended_model": recommended_id,
"reason": reason,
"alternatives": alternatives,
}
def _get_use_cases_for_model(self, model_id: str, model_info: Dict[str, Any]) -> list:
"""Get use cases for a model based on its capabilities."""
use_cases_map = {
"face_swap": ["Portrait editing", "Fun swaps", "Social media"],
"head_swap": ["Casting and concept design", "Privacy and anonymization", "Photo exploration"],
"full_head_replacement": ["Full head replacement", "Hair included", "Casting mockups"],
"realistic_blending": ["Professional work", "Marketing", "Entertainment"],
"multi_face": ["Group photos", "Family photos", "Team photos", "Creative projects", "Content creation"],
"face_enhancement": ["High-quality results", "Professional work", "Marketing campaigns"],
"identity_preservation": ["Character consistency", "Brand identity"],
}
capabilities = model_info.get("capabilities", [])
use_cases = []
for cap in capabilities:
if cap in use_cases_map:
use_cases.extend(use_cases_map[cap])
return list(set(use_cases)) if use_cases else ["General face swapping"]
async def process_face_swap(
self,
request: FaceSwapStudioRequest,
user_id: Optional[str] = None,
) -> Dict[str, Any]:
"""Process face swap request.
Args:
request: Face swap request
user_id: User ID for tracking
Returns:
Dictionary with result image and metadata
"""
# Auto-detect model if not specified
selected_model = request.model
if not selected_model:
try:
# Get image dimensions for recommendation
base_img = Image.open(io.BytesIO(base64.b64decode(request.base_image_base64.split(",", 1)[1] if "," in request.base_image_base64 else request.base_image_base64)))
face_img = Image.open(io.BytesIO(base64.b64decode(request.face_image_base64.split(",", 1)[1] if "," in request.face_image_base64 else request.face_image_base64)))
base_resolution = {"width": base_img.width, "height": base_img.height}
face_resolution = {"width": face_img.width, "height": face_img.height}
recommendation = self.recommend_model(
base_image_resolution=base_resolution,
face_image_resolution=face_resolution,
preferences={"prioritize_cost": True},
)
selected_model = recommendation.get("recommended_model")
logger.info(f"[Face Swap] Auto-selected model: {selected_model} (reason: {recommendation.get('reason')})")
except Exception as e:
logger.warning(f"[Face Swap] Auto-detection failed: {e}, using default model")
# Use first available model as fallback
from services.llm_providers.image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
provider = WaveSpeedFaceSwapProvider()
all_models = provider.get_available_models()
if all_models:
selected_model = list(all_models.keys())[0]
# Prepare options
options = request.options or {}
if request.target_face_index is not None:
options["target_face_index"] = request.target_face_index
if request.target_gender:
options["target_gender"] = request.target_gender
# Call unified entry point
result = generate_face_swap(
base_image_base64=request.base_image_base64,
face_image_base64=request.face_image_base64,
model=selected_model,
options=options,
user_id=user_id,
)
# Convert result to base64
result_base64 = base64.b64encode(result.image_bytes).decode("utf-8")
result_data_url = f"data:image/png;base64,{result_base64}"
return {
"success": True,
"image_base64": result_data_url,
"width": result.width,
"height": result.height,
"provider": result.provider,
"model": result.model,
"metadata": result.metadata or {},
}