refactor(phase1): extract image generation helpers, edit, face_swap into separate modules + fix subscription bugs
Extracted from main_image_generation.py (1002->591 lines): - image_generation/helpers.py: _validate_image_operation, _track_image_operation_usage - image_generation/edit.py: generate_image_edit (with _get_edit_provider) - image_generation/face_swap.py: generate_face_swap (with _get_face_swap_provider) Main image_generation.py now imports and re-exports from these modules. All existing imports (api/images.py, step4_asset_routes.py, studio services) continue to work unchanged. Bug fixes included: 1. generate_image_edit: Added missing 'return result' (was returning None!) 2. generate_image_edit: Added missing _track_image_operation_usage call 3. generate_face_swap: Removed duplicate dead tracking code after return statement
This commit is contained in:
120
backend/services/llm_providers/image_generation/edit.py
Normal file
120
backend/services/llm_providers/image_generation/edit.py
Normal file
@@ -0,0 +1,120 @@
|
|||||||
|
"""Image editing operations — generate_image_edit and related helpers."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from .base import ImageEditOptions, ImageGenerationResult, ImageEditProvider
|
||||||
|
from .wavespeed_edit_provider import WaveSpeedEditProvider
|
||||||
|
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||||
|
from utils.logger_utils import get_service_logger
|
||||||
|
|
||||||
|
logger = get_service_logger("image_generation.edit")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||||
|
"""Get editing provider instance by name."""
|
||||||
|
if provider_name == "wavespeed":
|
||||||
|
return WaveSpeedEditProvider()
|
||||||
|
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_image_edit(
|
||||||
|
image_base64: str,
|
||||||
|
prompt: str,
|
||||||
|
operation: str = "general_edit",
|
||||||
|
model: Optional[str] = None,
|
||||||
|
options: Optional[Dict[str, Any]] = None,
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
) -> ImageGenerationResult:
|
||||||
|
"""Generate edited image with pre-flight validation and usage tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
image_base64: Base64-encoded input image (or data URI)
|
||||||
|
prompt: Edit instruction prompt
|
||||||
|
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||||
|
model: Model ID to use (default: auto-select based on provider)
|
||||||
|
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||||
|
user_id: User ID for validation and tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageGenerationResult with edited image
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If validation fails or editing fails
|
||||||
|
ValueError: If options are invalid
|
||||||
|
"""
|
||||||
|
# 1. REUSE: Validation helper
|
||||||
|
_validate_image_operation(
|
||||||
|
user_id=user_id,
|
||||||
|
operation_type="image-edit",
|
||||||
|
num_operations=1,
|
||||||
|
log_prefix="[Image Edit]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Determine provider from model or default to wavespeed
|
||||||
|
opts = options or {}
|
||||||
|
provider_name = opts.get("provider", "wavespeed")
|
||||||
|
|
||||||
|
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
||||||
|
provider_name = "wavespeed"
|
||||||
|
|
||||||
|
# 3. Get provider
|
||||||
|
try:
|
||||||
|
provider = _get_edit_provider(provider_name)
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||||
|
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
||||||
|
|
||||||
|
# 4. Prepare edit options
|
||||||
|
edit_options = ImageEditOptions(
|
||||||
|
image_base64=image_base64,
|
||||||
|
prompt=prompt,
|
||||||
|
operation=operation,
|
||||||
|
mask_base64=opts.get("mask_base64"),
|
||||||
|
negative_prompt=opts.get("negative_prompt"),
|
||||||
|
model=model,
|
||||||
|
width=opts.get("width"),
|
||||||
|
height=opts.get("height"),
|
||||||
|
guidance_scale=opts.get("guidance_scale"),
|
||||||
|
steps=opts.get("steps"),
|
||||||
|
seed=opts.get("seed"),
|
||||||
|
extra=opts.get("extra"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 5. Edit image
|
||||||
|
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||||
|
try:
|
||||||
|
result = provider.edit(edit_options)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail={"error": "Image editing failed", "message": str(e)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 6. REUSE: Tracking helper
|
||||||
|
if user_id and result and result.image_bytes:
|
||||||
|
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
|
||||||
|
estimated_cost = 0.0
|
||||||
|
if result.metadata and "estimated_cost" in result.metadata:
|
||||||
|
estimated_cost = float(result.metadata["estimated_cost"])
|
||||||
|
else:
|
||||||
|
estimated_cost = 0.02 if provider_name == "wavespeed" else 0.05
|
||||||
|
|
||||||
|
_track_image_operation_usage(
|
||||||
|
user_id=user_id,
|
||||||
|
provider=provider_name,
|
||||||
|
model=result.model or model or "unknown",
|
||||||
|
operation_type="image-edit",
|
||||||
|
result_bytes=result.image_bytes,
|
||||||
|
cost=estimated_cost,
|
||||||
|
prompt=prompt,
|
||||||
|
endpoint="/image-generation/edit",
|
||||||
|
metadata=result.metadata,
|
||||||
|
log_prefix="[Image Edit]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||||
|
|
||||||
|
# 7. Return result
|
||||||
|
return result
|
||||||
105
backend/services/llm_providers/image_generation/face_swap.py
Normal file
105
backend/services/llm_providers/image_generation/face_swap.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""Face swap operations — generate_face_swap and related helpers."""
|
||||||
|
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from .base import FaceSwapOptions, FaceSwapProvider, ImageGenerationResult
|
||||||
|
from .wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||||
|
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||||
|
from utils.logger_utils import get_service_logger
|
||||||
|
|
||||||
|
logger = get_service_logger("image_generation.face_swap")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||||
|
"""Get face swap provider by name."""
|
||||||
|
if provider_name == "wavespeed":
|
||||||
|
return WaveSpeedFaceSwapProvider()
|
||||||
|
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def generate_face_swap(
|
||||||
|
base_image_base64: str,
|
||||||
|
face_image_base64: str,
|
||||||
|
model: Optional[str] = None,
|
||||||
|
options: Optional[Dict[str, Any]] = None,
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
) -> ImageGenerationResult:
|
||||||
|
"""Generate face swap with pre-flight validation and usage tracking.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
base_image_base64: Base64-encoded base image (or data URI)
|
||||||
|
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||||
|
model: Model ID to use (default: auto-select)
|
||||||
|
options: Additional options (target_face_index, target_gender, etc.)
|
||||||
|
user_id: User ID for validation and tracking
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
ImageGenerationResult with swapped face image
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If validation fails or face swap fails
|
||||||
|
ValueError: If options are invalid
|
||||||
|
"""
|
||||||
|
# 1. REUSE: Validation helper
|
||||||
|
_validate_image_operation(
|
||||||
|
user_id=user_id,
|
||||||
|
operation_type="face-swap",
|
||||||
|
num_operations=1,
|
||||||
|
log_prefix="[Face Swap]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Get provider (default to wavespeed)
|
||||||
|
provider_name = "wavespeed"
|
||||||
|
provider = _get_face_swap_provider(provider_name)
|
||||||
|
|
||||||
|
# 3. Prepare options
|
||||||
|
face_swap_options = FaceSwapOptions(
|
||||||
|
base_image_base64=base_image_base64,
|
||||||
|
face_image_base64=face_image_base64,
|
||||||
|
model=model,
|
||||||
|
target_face_index=options.get("target_face_index") if options else None,
|
||||||
|
target_gender=options.get("target_gender") if options else None,
|
||||||
|
extra=options,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. Swap face
|
||||||
|
try:
|
||||||
|
result = provider.swap_face(face_swap_options)
|
||||||
|
|
||||||
|
# 5. REUSE: Tracking helper
|
||||||
|
if user_id and result and result.image_bytes:
|
||||||
|
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||||
|
|
||||||
|
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||||
|
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||||
|
estimated_cost = model_info.get("cost", 0.025)
|
||||||
|
|
||||||
|
_track_image_operation_usage(
|
||||||
|
user_id=user_id,
|
||||||
|
provider=provider_name,
|
||||||
|
model=model_id,
|
||||||
|
operation_type="face-swap",
|
||||||
|
result_bytes=result.image_bytes,
|
||||||
|
cost=estimated_cost,
|
||||||
|
prompt=None,
|
||||||
|
endpoint="/image-studio/face-swap/process",
|
||||||
|
metadata={
|
||||||
|
"base_image_size": len(base_image_base64),
|
||||||
|
"face_image_size": len(face_image_base64),
|
||||||
|
},
|
||||||
|
log_prefix="[Face Swap]"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as api_error:
|
||||||
|
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=502,
|
||||||
|
detail={"error": "Face swap failed", "message": str(api_error)}
|
||||||
|
)
|
||||||
180
backend/services/llm_providers/image_generation/helpers.py
Normal file
180
backend/services/llm_providers/image_generation/helpers.py
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
"""Shared helpers for image generation operations — validation and usage tracking."""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from utils.logger_utils import get_service_logger
|
||||||
|
|
||||||
|
logger = get_service_logger("image_generation.helpers")
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_image_operation(
|
||||||
|
user_id: Optional[str],
|
||||||
|
operation_type: str = "image-generation",
|
||||||
|
num_operations: int = 1,
|
||||||
|
log_prefix: str = "[Image Generation]"
|
||||||
|
) -> None:
|
||||||
|
"""Reusable pre-flight validation helper for all image operations."""
|
||||||
|
if not user_id:
|
||||||
|
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||||
|
return
|
||||||
|
|
||||||
|
from services.database import get_session_for_user
|
||||||
|
from services.subscription import PricingService
|
||||||
|
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||||
|
db = get_session_for_user(user_id)
|
||||||
|
try:
|
||||||
|
pricing_service = PricingService(db)
|
||||||
|
validate_image_generation_operations(
|
||||||
|
pricing_service=pricing_service,
|
||||||
|
user_id=user_id,
|
||||||
|
num_images=num_operations
|
||||||
|
)
|
||||||
|
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id}")
|
||||||
|
except HTTPException:
|
||||||
|
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _track_image_operation_usage(
|
||||||
|
user_id: str,
|
||||||
|
provider: str,
|
||||||
|
model: str,
|
||||||
|
operation_type: str,
|
||||||
|
result_bytes: bytes,
|
||||||
|
cost: float,
|
||||||
|
prompt: Optional[str] = None,
|
||||||
|
endpoint: str = "/image-generation",
|
||||||
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
log_prefix: str = "[Image Generation]",
|
||||||
|
response_time: float = 0.0
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Reusable usage tracking helper for all image operations."""
|
||||||
|
try:
|
||||||
|
from services.database import get_session_for_user
|
||||||
|
db_track = get_session_for_user(user_id)
|
||||||
|
try:
|
||||||
|
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||||
|
from services.subscription.provider_detection import detect_actual_provider
|
||||||
|
from services.subscription import PricingService
|
||||||
|
|
||||||
|
pricing = PricingService(db_track)
|
||||||
|
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||||
|
|
||||||
|
summary = db_track.query(UsageSummary).filter(
|
||||||
|
UsageSummary.user_id == user_id,
|
||||||
|
UsageSummary.billing_period == current_period
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not summary:
|
||||||
|
summary = UsageSummary(user_id=user_id, billing_period=current_period)
|
||||||
|
db_track.add(summary)
|
||||||
|
db_track.flush()
|
||||||
|
|
||||||
|
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||||
|
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||||
|
|
||||||
|
new_calls = current_calls_before + 1
|
||||||
|
new_cost = current_cost_before + cost
|
||||||
|
|
||||||
|
from sqlalchemy import text as sql_text
|
||||||
|
update_query = sql_text("""
|
||||||
|
UPDATE usage_summaries
|
||||||
|
SET stability_calls = :new_calls,
|
||||||
|
stability_cost = :new_cost
|
||||||
|
WHERE user_id = :user_id AND billing_period = :period
|
||||||
|
""")
|
||||||
|
db_track.execute(update_query, {
|
||||||
|
'new_calls': new_calls,
|
||||||
|
'new_cost': new_cost,
|
||||||
|
'user_id': user_id,
|
||||||
|
'period': current_period
|
||||||
|
})
|
||||||
|
|
||||||
|
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||||
|
summary.total_calls = (summary.total_calls or 0) + 1
|
||||||
|
summary.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
api_provider = APIProvider.STABILITY
|
||||||
|
actual_provider = detect_actual_provider(
|
||||||
|
provider_enum=api_provider,
|
||||||
|
model_name=model,
|
||||||
|
endpoint=endpoint
|
||||||
|
)
|
||||||
|
|
||||||
|
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||||
|
usage_log = APIUsageLog(
|
||||||
|
user_id=user_id,
|
||||||
|
provider=api_provider,
|
||||||
|
endpoint=endpoint,
|
||||||
|
method="POST",
|
||||||
|
model_used=model or "unknown",
|
||||||
|
actual_provider_name=actual_provider,
|
||||||
|
tokens_input=0,
|
||||||
|
tokens_output=0,
|
||||||
|
tokens_total=0,
|
||||||
|
cost_input=0.0,
|
||||||
|
cost_output=0.0,
|
||||||
|
cost_total=cost,
|
||||||
|
response_time=response_time,
|
||||||
|
status_code=200,
|
||||||
|
request_size=request_size,
|
||||||
|
response_size=len(result_bytes),
|
||||||
|
billing_period=current_period,
|
||||||
|
)
|
||||||
|
db_track.add(usage_log)
|
||||||
|
|
||||||
|
limits = pricing.get_user_limits(user_id)
|
||||||
|
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||||
|
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||||
|
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||||
|
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||||
|
|
||||||
|
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||||
|
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||||
|
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||||
|
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||||
|
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||||
|
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||||
|
|
||||||
|
db_track.commit()
|
||||||
|
logger.info(f"{log_prefix} ✅ Tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||||
|
|
||||||
|
operation_name = operation_type.replace("-", " ").title()
|
||||||
|
print(f"""
|
||||||
|
[SUBSCRIPTION] {operation_name}
|
||||||
|
├─ User: {user_id}
|
||||||
|
├─ Plan: {plan_name} ({tier})
|
||||||
|
├─ Provider: {provider}
|
||||||
|
├─ Actual Provider: {provider}
|
||||||
|
├─ Model: {model or 'unknown'}
|
||||||
|
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||||
|
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||||
|
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||||
|
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||||
|
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||||
|
└─ Status: ✅ Allowed & Tracked
|
||||||
|
""", flush=True)
|
||||||
|
sys.stdout.flush()
|
||||||
|
|
||||||
|
return {"current_calls": new_calls, "cost": cost, "total_cost": new_cost}
|
||||||
|
|
||||||
|
except Exception as track_error:
|
||||||
|
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||||
|
import traceback
|
||||||
|
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||||
|
db_track.rollback()
|
||||||
|
return {}
|
||||||
|
finally:
|
||||||
|
db_track.close()
|
||||||
|
except Exception as usage_error:
|
||||||
|
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||||
|
import traceback
|
||||||
|
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||||
|
return {}
|
||||||
@@ -18,9 +18,9 @@ from .image_generation import (
|
|||||||
StabilityImageProvider,
|
StabilityImageProvider,
|
||||||
WaveSpeedImageProvider,
|
WaveSpeedImageProvider,
|
||||||
)
|
)
|
||||||
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
from .image_generation.helpers import _validate_image_operation, _track_image_operation_usage
|
||||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
from .image_generation.edit import generate_image_edit
|
||||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
from .image_generation.face_swap import generate_face_swap
|
||||||
from utils.logger_utils import get_service_logger
|
from utils.logger_utils import get_service_logger
|
||||||
from .tenant_provider_config import tenant_provider_config_resolver
|
from .tenant_provider_config import tenant_provider_config_resolver
|
||||||
|
|
||||||
@@ -53,259 +53,6 @@ def _get_provider(provider_name: str, user_id: Optional[str] = None):
|
|||||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||||
|
|
||||||
|
|
||||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
|
||||||
"""Get face swap provider by name."""
|
|
||||||
if provider_name == "wavespeed":
|
|
||||||
return WaveSpeedFaceSwapProvider()
|
|
||||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
|
||||||
|
|
||||||
|
|
||||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
|
||||||
"""Get editing provider instance.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider_name: Provider name ("wavespeed", "stability", etc.)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ImageEditProvider instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If provider is not supported
|
|
||||||
"""
|
|
||||||
if provider_name == "wavespeed":
|
|
||||||
return WaveSpeedEditProvider()
|
|
||||||
# TODO: Add Stability edit provider if needed
|
|
||||||
# elif provider_name == "stability":
|
|
||||||
# return StabilityEditProvider()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_image_operation(
|
|
||||||
user_id: Optional[str],
|
|
||||||
operation_type: str = "image-generation",
|
|
||||||
num_operations: int = 1,
|
|
||||||
log_prefix: str = "[Image Generation]"
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Reusable pre-flight validation helper for all image operations.
|
|
||||||
|
|
||||||
Extracted from generate_image() to be reused across all image operation functions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID for subscription checking
|
|
||||||
operation_type: Type of operation (for logging)
|
|
||||||
num_operations: Number of operations to validate (default: 1)
|
|
||||||
log_prefix: Logging prefix for operation-specific logs
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If validation fails (subscription limits exceeded, etc.)
|
|
||||||
"""
|
|
||||||
if not user_id:
|
|
||||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
|
||||||
return
|
|
||||||
|
|
||||||
from services.database import get_session_for_user
|
|
||||||
from services.subscription import PricingService
|
|
||||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
|
||||||
db = get_session_for_user(user_id)
|
|
||||||
try:
|
|
||||||
pricing_service = PricingService(db)
|
|
||||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
|
||||||
validate_image_generation_operations(
|
|
||||||
pricing_service=pricing_service,
|
|
||||||
user_id=user_id,
|
|
||||||
num_images=num_operations
|
|
||||||
)
|
|
||||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
|
|
||||||
except HTTPException as http_ex:
|
|
||||||
# Re-raise immediately - don't proceed with API call
|
|
||||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
db.close()
|
|
||||||
|
|
||||||
|
|
||||||
def _track_image_operation_usage(
|
|
||||||
user_id: str,
|
|
||||||
provider: str,
|
|
||||||
model: str,
|
|
||||||
operation_type: str,
|
|
||||||
result_bytes: bytes,
|
|
||||||
cost: float,
|
|
||||||
prompt: Optional[str] = None,
|
|
||||||
endpoint: str = "/image-generation",
|
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
|
||||||
log_prefix: str = "[Image Generation]",
|
|
||||||
response_time: float = 0.0
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Reusable usage tracking helper for all image operations.
|
|
||||||
|
|
||||||
Extracted from generate_image() to be reused across all image operation functions.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User ID for tracking
|
|
||||||
provider: Provider name (e.g., "wavespeed", "stability")
|
|
||||||
model: Model name used
|
|
||||||
operation_type: Type of operation (for logging)
|
|
||||||
result_bytes: Generated/processed image bytes
|
|
||||||
cost: Cost of the operation
|
|
||||||
prompt: Optional prompt text (for request size calculation)
|
|
||||||
endpoint: API endpoint path (for logging)
|
|
||||||
metadata: Optional additional metadata
|
|
||||||
log_prefix: Logging prefix for operation-specific logs
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with tracking information (current_calls, cost, etc.)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from services.database import get_session_for_user
|
|
||||||
db_track = get_session_for_user(user_id)
|
|
||||||
try:
|
|
||||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
|
||||||
from services.subscription.provider_detection import detect_actual_provider
|
|
||||||
from services.subscription import PricingService
|
|
||||||
|
|
||||||
pricing = PricingService(db_track)
|
|
||||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
|
||||||
|
|
||||||
# Get or create usage summary
|
|
||||||
summary = db_track.query(UsageSummary).filter(
|
|
||||||
UsageSummary.user_id == user_id,
|
|
||||||
UsageSummary.billing_period == current_period
|
|
||||||
).first()
|
|
||||||
|
|
||||||
if not summary:
|
|
||||||
summary = UsageSummary(
|
|
||||||
user_id=user_id,
|
|
||||||
billing_period=current_period
|
|
||||||
)
|
|
||||||
db_track.add(summary)
|
|
||||||
db_track.flush()
|
|
||||||
|
|
||||||
# Get current values before update
|
|
||||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
|
||||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
|
||||||
|
|
||||||
# Update image calls and cost
|
|
||||||
new_calls = current_calls_before + 1
|
|
||||||
new_cost = current_cost_before + cost
|
|
||||||
|
|
||||||
# Use direct SQL UPDATE for dynamic attributes
|
|
||||||
from sqlalchemy import text as sql_text
|
|
||||||
update_query = sql_text("""
|
|
||||||
UPDATE usage_summaries
|
|
||||||
SET stability_calls = :new_calls,
|
|
||||||
stability_cost = :new_cost
|
|
||||||
WHERE user_id = :user_id AND billing_period = :period
|
|
||||||
""")
|
|
||||||
db_track.execute(update_query, {
|
|
||||||
'new_calls': new_calls,
|
|
||||||
'new_cost': new_cost,
|
|
||||||
'user_id': user_id,
|
|
||||||
'period': current_period
|
|
||||||
})
|
|
||||||
|
|
||||||
# Update total cost
|
|
||||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
|
||||||
summary.total_calls = (summary.total_calls or 0) + 1
|
|
||||||
summary.updated_at = datetime.utcnow()
|
|
||||||
|
|
||||||
# Determine API provider based on actual provider
|
|
||||||
api_provider = APIProvider.STABILITY # Default for image generation
|
|
||||||
|
|
||||||
# Detect actual provider name (WaveSpeed, Stability, HuggingFace, etc.)
|
|
||||||
actual_provider = detect_actual_provider(
|
|
||||||
provider_enum=api_provider,
|
|
||||||
model_name=model,
|
|
||||||
endpoint=endpoint
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create usage log
|
|
||||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
|
||||||
usage_log = APIUsageLog(
|
|
||||||
user_id=user_id,
|
|
||||||
provider=api_provider,
|
|
||||||
endpoint=endpoint,
|
|
||||||
method="POST",
|
|
||||||
model_used=model or "unknown",
|
|
||||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, Stability, etc.)
|
|
||||||
tokens_input=0,
|
|
||||||
tokens_output=0,
|
|
||||||
tokens_total=0,
|
|
||||||
cost_input=0.0,
|
|
||||||
cost_output=0.0,
|
|
||||||
cost_total=cost,
|
|
||||||
response_time=response_time, # Use actual response time
|
|
||||||
status_code=200,
|
|
||||||
request_size=request_size,
|
|
||||||
response_size=len(result_bytes),
|
|
||||||
billing_period=current_period,
|
|
||||||
)
|
|
||||||
db_track.add(usage_log)
|
|
||||||
|
|
||||||
# Get plan details for unified log
|
|
||||||
limits = pricing.get_user_limits(user_id)
|
|
||||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
|
||||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
|
||||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
|
||||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
|
||||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
|
||||||
|
|
||||||
# Get related stats for unified log
|
|
||||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
|
||||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
|
||||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
|
||||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
|
||||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
|
||||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
|
||||||
|
|
||||||
db_track.commit()
|
|
||||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
|
||||||
|
|
||||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
|
||||||
operation_name = operation_type.replace("-", " ").title()
|
|
||||||
print(f"""
|
|
||||||
[SUBSCRIPTION] {operation_name}
|
|
||||||
├─ User: {user_id}
|
|
||||||
├─ Plan: {plan_name} ({tier})
|
|
||||||
├─ Provider: {provider}
|
|
||||||
├─ Actual Provider: {provider}
|
|
||||||
├─ Model: {model or 'unknown'}
|
|
||||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
|
||||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
|
||||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
|
||||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
|
||||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
|
||||||
└─ Status: ✅ Allowed & Tracked
|
|
||||||
""", flush=True)
|
|
||||||
sys.stdout.flush()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"current_calls": new_calls,
|
|
||||||
"cost": cost,
|
|
||||||
"total_cost": new_cost,
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as track_error:
|
|
||||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
|
||||||
import traceback
|
|
||||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
|
||||||
db_track.rollback()
|
|
||||||
return {}
|
|
||||||
finally:
|
|
||||||
db_track.close()
|
|
||||||
except Exception as usage_error:
|
|
||||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
|
||||||
import traceback
|
|
||||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||||
"""Generate image with pre-flight validation.
|
"""Generate image with pre-flight validation.
|
||||||
|
|
||||||
@@ -500,165 +247,7 @@ def generate_character_image(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_image_edit(
|
|
||||||
image_base64: str,
|
|
||||||
prompt: str,
|
|
||||||
operation: str = "general_edit",
|
|
||||||
model: Optional[str] = None,
|
|
||||||
options: Optional[Dict[str, Any]] = None,
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
) -> ImageGenerationResult:
|
|
||||||
"""
|
|
||||||
Generate edited image - REUSES validation and tracking helpers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
image_base64: Base64-encoded input image (or data URI)
|
|
||||||
prompt: Edit instruction prompt
|
|
||||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
|
||||||
model: Model ID to use (default: auto-select based on provider)
|
|
||||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
|
||||||
user_id: User ID for validation and tracking
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ImageGenerationResult with edited image
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If validation fails or editing fails
|
|
||||||
ValueError: If options are invalid
|
|
||||||
"""
|
|
||||||
# 1. REUSE: Validation helper
|
|
||||||
_validate_image_operation(
|
|
||||||
user_id=user_id,
|
|
||||||
operation_type="image-edit",
|
|
||||||
num_operations=1,
|
|
||||||
log_prefix="[Image Edit]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Determine provider from model or default to wavespeed
|
|
||||||
opts = options or {}
|
|
||||||
provider_name = opts.get("provider", "wavespeed")
|
|
||||||
|
|
||||||
# If model is specified and starts with "wavespeed", use wavespeed provider
|
|
||||||
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
|
||||||
provider_name = "wavespeed"
|
|
||||||
|
|
||||||
# 3. Get provider (REUSES provider pattern)
|
|
||||||
try:
|
|
||||||
provider = _get_edit_provider(provider_name)
|
|
||||||
except ValueError as e:
|
|
||||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
|
||||||
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
|
||||||
|
|
||||||
# 4. Prepare edit options
|
|
||||||
edit_options = ImageEditOptions(
|
|
||||||
image_base64=image_base64,
|
|
||||||
prompt=prompt,
|
|
||||||
operation=operation,
|
|
||||||
mask_base64=opts.get("mask_base64"),
|
|
||||||
negative_prompt=opts.get("negative_prompt"),
|
|
||||||
model=model,
|
|
||||||
width=opts.get("width"),
|
|
||||||
height=opts.get("height"),
|
|
||||||
guidance_scale=opts.get("guidance_scale"),
|
|
||||||
steps=opts.get("steps"),
|
|
||||||
seed=opts.get("seed"),
|
|
||||||
extra=opts.get("extra"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 5. Edit image
|
|
||||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
|
||||||
try:
|
|
||||||
result = provider.edit(edit_options)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=502,
|
|
||||||
detail={
|
|
||||||
"error": "Image editing failed",
|
|
||||||
"message": str(e)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def generate_face_swap(
|
|
||||||
base_image_base64: str,
|
|
||||||
face_image_base64: str,
|
|
||||||
model: Optional[str] = None,
|
|
||||||
options: Optional[Dict[str, Any]] = None,
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
) -> ImageGenerationResult:
|
|
||||||
"""
|
|
||||||
Generate face swap - REUSES validation and tracking helpers.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
base_image_base64: Base64-encoded base image (or data URI)
|
|
||||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
|
||||||
model: Model ID to use (default: auto-select)
|
|
||||||
options: Additional options (target_face_index, target_gender, etc.)
|
|
||||||
user_id: User ID for validation and tracking
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ImageGenerationResult with swapped face image
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If validation fails or face swap fails
|
|
||||||
ValueError: If options are invalid
|
|
||||||
"""
|
|
||||||
# 1. REUSE: Validation helper
|
|
||||||
_validate_image_operation(
|
|
||||||
user_id=user_id,
|
|
||||||
operation_type="face-swap",
|
|
||||||
image_base64=base_image_base64, # Use base image for validation
|
|
||||||
log_prefix="[Face Swap]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Get provider (default to wavespeed)
|
|
||||||
provider_name = "wavespeed"
|
|
||||||
provider = _get_face_swap_provider(provider_name)
|
|
||||||
|
|
||||||
# 3. Prepare options
|
|
||||||
face_swap_options = FaceSwapOptions(
|
|
||||||
base_image_base64=base_image_base64,
|
|
||||||
face_image_base64=face_image_base64,
|
|
||||||
model=model,
|
|
||||||
target_face_index=options.get("target_face_index") if options else None,
|
|
||||||
target_gender=options.get("target_gender") if options else None,
|
|
||||||
extra=options,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 4. Swap face
|
|
||||||
try:
|
|
||||||
result = provider.swap_face(face_swap_options)
|
|
||||||
|
|
||||||
# 5. REUSE: Tracking helper
|
|
||||||
if user_id and result and result.image_bytes:
|
|
||||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
|
||||||
|
|
||||||
# Get model cost
|
|
||||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
|
||||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
|
||||||
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
|
|
||||||
|
|
||||||
# Reuse tracking helper
|
|
||||||
_track_image_operation_usage(
|
|
||||||
user_id=user_id,
|
|
||||||
provider=provider_name,
|
|
||||||
model=model_id,
|
|
||||||
operation_type="face-swap",
|
|
||||||
result_bytes=result.image_bytes,
|
|
||||||
cost=estimated_cost,
|
|
||||||
prompt=None, # Face swap doesn't use prompts
|
|
||||||
endpoint="/image-studio/face-swap/process",
|
|
||||||
metadata={
|
|
||||||
"base_image_size": len(base_image_base64),
|
|
||||||
"face_image_size": len(face_image_base64),
|
|
||||||
},
|
|
||||||
log_prefix="[Face Swap]"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
except HTTPException:
|
except HTTPException:
|
||||||
raise
|
raise
|
||||||
|
|||||||
Reference in New Issue
Block a user