From ca725b77e75afbb8745b5d01e100a5623a07e623 Mon Sep 17 00:00:00 2001 From: ajaysi Date: Sat, 9 May 2026 08:51:06 +0530 Subject: [PATCH] refactor(phase2): add provider-aware tracking and fill missing subscription usage tracking MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Changes: 1. helpers.py (_track_image_operation_usage): Map provider name to DB columns dynamically (stability→stability_calls, wavespeed→wavespeed_calls, etc.) instead of hardcoding stability_calls/stability_cost. 2. upscale_service.py: Added _track_image_operation_usage() call after successful Stability upscale completion. 3. control_service.py: Added _track_image_operation_usage() call after successful Stability control operation completion. 4. edit_service.py: Added _track_image_operation_usage() call after successful Stability edit operation (remove_background, inpaint, outpaint, search_replace, search_recolor, relight). Previously only Create Studio and Face Swap tracked usage. Now all five studios correctly decrement subscription limits. --- .../services/image_studio/control_service.py | 15 ++++++++ backend/services/image_studio/edit_service.py | 13 +++++++ .../services/image_studio/upscale_service.py | 14 +++++++ .../llm_providers/image_generation/helpers.py | 38 ++++++++++++++----- 4 files changed, 71 insertions(+), 9 deletions(-) diff --git a/backend/services/image_studio/control_service.py b/backend/services/image_studio/control_service.py index 75604adf..243d5057 100644 --- a/backend/services/image_studio/control_service.py +++ b/backend/services/image_studio/control_service.py @@ -237,6 +237,21 @@ class ControlStudioService: image_bytes = self._extract_image_bytes(result) metadata = self._image_bytes_to_metadata(image_bytes) + + # Track usage + if user_id: + from services.llm_providers.main_image_generation import _track_image_operation_usage + _track_image_operation_usage( + user_id=user_id, + provider="stability", + model=f"control-{operation}", + operation_type="image-control", + result_bytes=image_bytes, + cost=0.04, + endpoint="/image-studio/control/process", + log_prefix="[Control Studio]" + ) + metadata.update( { "operation": operation, diff --git a/backend/services/image_studio/edit_service.py b/backend/services/image_studio/edit_service.py index 339bde28..3858f841 100644 --- a/backend/services/image_studio/edit_service.py +++ b/backend/services/image_studio/edit_service.py @@ -514,6 +514,19 @@ class EditStudioService: background_bytes=background_bytes, lighting_bytes=lighting_bytes, ) + # Track usage for Stability operations + if user_id: + from services.llm_providers.main_image_generation import _track_image_operation_usage + _track_image_operation_usage( + user_id=user_id, + provider="stability", + model=f"edit-{operation}", + operation_type="image-edit", + result_bytes=image_bytes, + cost=0.04, + endpoint="/image-studio/edit/process", + log_prefix="[Edit Studio]" + ) else: image_bytes = await self._handle_general_edit( request=request, diff --git a/backend/services/image_studio/upscale_service.py b/backend/services/image_studio/upscale_service.py index 4b2707a5..82fd2c53 100644 --- a/backend/services/image_studio/upscale_service.py +++ b/backend/services/image_studio/upscale_service.py @@ -88,6 +88,20 @@ class UpscaleStudioService: image_bytes = self._extract_image_bytes(result) metadata = self._image_metadata(image_bytes) + # Track usage + if user_id: + from services.llm_providers.main_image_generation import _track_image_operation_usage + _track_image_operation_usage( + user_id=user_id, + provider="stability", + model=f"upscale-{mode}", + operation_type="image-upscale", + result_bytes=image_bytes, + cost=0.04, + endpoint="/image-studio/upscale", + log_prefix="[Upscale Studio]" + ) + return { "success": True, "mode": mode, diff --git a/backend/services/llm_providers/image_generation/helpers.py b/backend/services/llm_providers/image_generation/helpers.py index 3afe40d4..1c6bc3d7 100644 --- a/backend/services/llm_providers/image_generation/helpers.py +++ b/backend/services/llm_providers/image_generation/helpers.py @@ -77,17 +77,27 @@ def _track_image_operation_usage( db_track.add(summary) db_track.flush() - current_calls_before = getattr(summary, "stability_calls", 0) or 0 - current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0 + # Map provider to DB column names + provider_column_map = { + "stability": ("stability_calls", "stability_cost"), + "wavespeed": ("wavespeed_calls", "wavespeed_cost"), + "gemini": ("gemini_calls", "gemini_cost"), + "openai": ("openai_calls", "openai_cost"), + "huggingface": ("total_calls", "total_cost"), # no dedicated columns + } + calls_col, cost_col = provider_column_map.get(provider, ("total_calls", "total_cost")) + + current_calls_before = getattr(summary, calls_col, 0) or 0 + current_cost_before = getattr(summary, cost_col, 0.0) or 0.0 new_calls = current_calls_before + 1 new_cost = current_cost_before + cost from sqlalchemy import text as sql_text - update_query = sql_text(""" + update_query = sql_text(f""" UPDATE usage_summaries - SET stability_calls = :new_calls, - stability_cost = :new_cost + SET {calls_col} = :new_calls, + {cost_col} = :new_cost WHERE user_id = :user_id AND billing_period = :period """) db_track.execute(update_query, { @@ -101,7 +111,17 @@ def _track_image_operation_usage( summary.total_calls = (summary.total_calls or 0) + 1 summary.updated_at = datetime.utcnow() - api_provider = APIProvider.STABILITY + # Map provider to APIProvider enum + provider_api_map = { + "stability": APIProvider.STABILITY, + "wavespeed": APIProvider.WAVESPEED, + "gemini": APIProvider.GEMINI, + "openai": APIProvider.OPENAI, + "image_edit": APIProvider.IMAGE_EDIT, + "video": APIProvider.VIDEO, + "audio": APIProvider.AUDIO, + } + api_provider = provider_api_map.get(provider, APIProvider.STABILITY) actual_provider = detect_actual_provider( provider_enum=api_provider, model_name=model, @@ -133,8 +153,8 @@ def _track_image_operation_usage( limits = pricing.get_user_limits(user_id) plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown' tier = limits.get('tier', 'unknown') if limits else 'unknown' - image_limit = limits['limits'].get("stability_calls", 0) if limits else 0 - image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞' + provider_limit = limits['limits'].get(calls_col, 0) if limits else 0 + provider_limit_display = provider_limit if (provider_limit > 0 or tier != 'enterprise') else '∞' current_audio_calls = getattr(summary, "audio_calls", 0) or 0 audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0 @@ -154,7 +174,7 @@ def _track_image_operation_usage( ├─ Provider: {provider} ├─ Actual Provider: {provider} ├─ Model: {model or 'unknown'} -├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display} +├─ Calls: {current_calls_before} → {new_calls} / {provider_limit_display} ├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f} ├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'} ├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}