diff --git a/backend/services/subscription/usage_tracking_helpers/__init__.py b/backend/services/subscription/usage_tracking_helpers/__init__.py new file mode 100644 index 00000000..a61c757b --- /dev/null +++ b/backend/services/subscription/usage_tracking_helpers/__init__.py @@ -0,0 +1,29 @@ +"""Helper utilities for usage tracking service.""" + +from .usage_reset_helpers import reset_usage_summary_counters +from .usage_stats_helpers import ( + build_default_usage_percentages, + build_empty_usage_response, + build_provider_breakdown, + calculate_final_total_cost, + maybe_persist_reconciled_costs, +) +from .usage_trends_helpers import ( + build_billing_periods, + build_usage_trends_response, + query_usage_summaries, + self_heal_summaries_from_logs, +) + +__all__ = [ + "build_default_usage_percentages", + "build_empty_usage_response", + "build_provider_breakdown", + "calculate_final_total_cost", + "maybe_persist_reconciled_costs", + "build_billing_periods", + "query_usage_summaries", + "self_heal_summaries_from_logs", + "build_usage_trends_response", + "reset_usage_summary_counters", +] diff --git a/backend/services/subscription/usage_tracking_helpers/usage_reset_helpers.py b/backend/services/subscription/usage_tracking_helpers/usage_reset_helpers.py new file mode 100644 index 00000000..132ce390 --- /dev/null +++ b/backend/services/subscription/usage_tracking_helpers/usage_reset_helpers.py @@ -0,0 +1,73 @@ +"""Helpers extracted from UsageTrackingService.reset_current_billing_period.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from models.subscription_models import UsageStatus + + +_CALL_FIELDS = [ + "gemini_calls", + "openai_calls", + "anthropic_calls", + "mistral_calls", + "wavespeed_calls", + "tavily_calls", + "serper_calls", + "metaphor_calls", + "firecrawl_calls", + "stability_calls", + "exa_calls", + "video_calls", + "audio_calls", + "image_edit_calls", +] + +_TOKEN_FIELDS = [ + "gemini_tokens", + "openai_tokens", + "anthropic_tokens", + "mistral_tokens", + "wavespeed_tokens", +] + +_COST_FIELDS = [ + "gemini_cost", + "openai_cost", + "anthropic_cost", + "mistral_cost", + "wavespeed_cost", + "tavily_cost", + "serper_cost", + "metaphor_cost", + "firecrawl_cost", + "stability_cost", + "exa_cost", + "video_cost", + "image_edit_cost", + "audio_cost", +] + + +def reset_usage_summary_counters(summary: Any) -> None: + """Reset all known usage counters to baseline values.""" + summary.usage_status = UsageStatus.ACTIVE + + for field in _CALL_FIELDS: + if hasattr(summary, field): + setattr(summary, field, 0) + + for field in _TOKEN_FIELDS: + if hasattr(summary, field): + setattr(summary, field, 0) + + for field in _COST_FIELDS: + if hasattr(summary, field): + setattr(summary, field, 0.0) + + summary.total_calls = 0 + summary.total_tokens = 0 + summary.total_cost = 0.0 + summary.updated_at = datetime.utcnow() diff --git a/backend/services/subscription/usage_tracking_helpers/usage_stats_helpers.py b/backend/services/subscription/usage_tracking_helpers/usage_stats_helpers.py new file mode 100644 index 00000000..5de62213 --- /dev/null +++ b/backend/services/subscription/usage_tracking_helpers/usage_stats_helpers.py @@ -0,0 +1,250 @@ +"""Helper utilities extracted from UsageTrackingService.get_user_usage_stats.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any, Dict, Iterable, Tuple + +from loguru import logger + +from models.subscription_models import APIProvider, APIUsageLog + + +def _safe_getattr(obj: Any, attr: str, default: Any) -> Any: + """Return object attribute using a defensive default fallback.""" + value = getattr(obj, attr, default) + return default if value is None else value + + +def build_default_usage_percentages(providers: Iterable[APIProvider]) -> Dict[str, float]: + """Create zeroed usage percentages for all provider call limits plus cost.""" + usage_percentages = {f"{provider.value}_calls": 0 for provider in providers} + usage_percentages["cost"] = 0 + return usage_percentages + + +def build_empty_usage_response( + *, + billing_period: str, + limits: Dict[str, Any] | None, + providers: Iterable[APIProvider], +) -> Dict[str, Any]: + """Build a no-usage response payload with complete shape.""" + provider_breakdown = { + provider.value: {"calls": 0, "tokens": 0, "cost": 0.0} for provider in providers + } + + return { + "billing_period": billing_period, + "usage_status": "active", + "total_calls": 0, + "total_tokens": 0, + "total_cost": 0.0, + "avg_response_time": 0.0, + "error_rate": 0.0, + "last_updated": datetime.now().isoformat(), + "limits": limits, + "provider_breakdown": provider_breakdown, + "alerts": [], + "usage_percentages": build_default_usage_percentages(providers), + } + + +def _resolve_cost_from_logs( + *, + db: Any, + user_id: str, + billing_period: str, + provider: APIProvider, + current_calls: int, + current_cost: float, + debug_label: str, +) -> float: + """Backfill provider cost from usage logs only when needed.""" + if current_calls <= 0 or current_cost != 0.0: + return current_cost + + logs = ( + db.query(APIUsageLog) + .filter( + APIUsageLog.user_id == user_id, + APIUsageLog.provider == provider, + APIUsageLog.billing_period == billing_period, + ) + .all() + ) + if not logs: + return current_cost + + calculated_cost = sum(float(log.cost_total or 0.0) for log in logs) + logger.info( + f"[UsageStats] Calculated {debug_label} cost from {len(logs)} logs: ${calculated_cost:.6f}" + ) + return calculated_cost + + +def build_provider_breakdown( + *, + db: Any, + user_id: str, + billing_period: str, + summary: Any, +) -> Tuple[Dict[str, Dict[str, float | int]], Dict[str, float], Dict[str, int]]: + """Build provider breakdown while preserving existing backfill behavior.""" + provider_breakdown: Dict[str, Dict[str, float | int]] = {} + + gemini_calls = int(_safe_getattr(summary, "gemini_calls", 0) or 0) + gemini_tokens = int(_safe_getattr(summary, "gemini_tokens", 0) or 0) + gemini_cost = float(_safe_getattr(summary, "gemini_cost", 0.0) or 0.0) + gemini_cost = _resolve_cost_from_logs( + db=db, + user_id=user_id, + billing_period=billing_period, + provider=APIProvider.GEMINI, + current_calls=gemini_calls, + current_cost=gemini_cost, + debug_label="gemini", + ) + provider_breakdown["gemini"] = {"calls": gemini_calls, "tokens": gemini_tokens, "cost": gemini_cost} + + mistral_calls = int(_safe_getattr(summary, "mistral_calls", 0) or 0) + mistral_tokens = int(_safe_getattr(summary, "mistral_tokens", 0) or 0) + mistral_cost = float(_safe_getattr(summary, "mistral_cost", 0.0) or 0.0) + mistral_cost = _resolve_cost_from_logs( + db=db, + user_id=user_id, + billing_period=billing_period, + provider=APIProvider.MISTRAL, + current_calls=mistral_calls, + current_cost=mistral_cost, + debug_label="mistral (HuggingFace)", + ) + provider_breakdown["huggingface"] = { + "calls": mistral_calls, + "tokens": mistral_tokens, + "cost": mistral_cost, + } + + mapped_providers = { + "video": ("video_calls", "video_cost", APIProvider.VIDEO), + "audio": ("audio_calls", "audio_cost", APIProvider.AUDIO), + "image": ("stability_calls", "stability_cost", APIProvider.STABILITY), + "image_edit": ("image_edit_calls", "image_edit_cost", APIProvider.IMAGE_EDIT), + } + resolved_costs = { + "gemini_cost": gemini_cost, + "mistral_cost": mistral_cost, + } + + for key, (calls_attr, cost_attr, provider_enum) in mapped_providers.items(): + calls = int(_safe_getattr(summary, calls_attr, 0) or 0) + cost = float(_safe_getattr(summary, cost_attr, 0.0) or 0.0) + cost = _resolve_cost_from_logs( + db=db, + user_id=user_id, + billing_period=billing_period, + provider=provider_enum, + current_calls=calls, + current_cost=cost, + debug_label=key, + ) + provider_breakdown[key] = {"calls": calls, "tokens": 0, "cost": cost} + resolved_costs[cost_attr] = cost + + wavespeed_logs = ( + db.query(APIUsageLog) + .filter( + APIUsageLog.user_id == user_id, + APIUsageLog.billing_period == billing_period, + APIUsageLog.actual_provider_name == "wavespeed", + ) + .all() + ) + if wavespeed_logs: + wavespeed_calls = len(wavespeed_logs) + wavespeed_tokens = sum((log.tokens_total or 0) for log in wavespeed_logs) + wavespeed_cost = sum(float(log.cost_total or 0.0) for log in wavespeed_logs) + provider_breakdown["wavespeed"] = { + "calls": wavespeed_calls, + "tokens": wavespeed_tokens, + "cost": wavespeed_cost, + } + logger.info( + f"[UsageStats] Calculated WaveSpeed usage: {wavespeed_calls} calls, ${wavespeed_cost:.6f}" + ) + else: + provider_breakdown["wavespeed"] = {"calls": 0, "tokens": 0, "cost": 0.0} + + for search_provider in ("tavily", "serper", "exa"): + calls = int(_safe_getattr(summary, f"{search_provider}_calls", 0) or 0) + cost = float(_safe_getattr(summary, f"{search_provider}_cost", 0.0) or 0.0) + provider_breakdown[search_provider] = {"calls": calls, "tokens": 0, "cost": cost} + resolved_costs[f"{search_provider}_cost"] = cost + + core_counts = { + "gemini_calls": gemini_calls, + "mistral_calls": mistral_calls, + } + + return provider_breakdown, resolved_costs, core_counts + + +def calculate_final_total_cost(summary_total_cost: float, resolved_costs: Dict[str, float]) -> Tuple[float, float]: + """Return calculated and chosen total cost values.""" + calculated_total_cost = float( + resolved_costs.get("gemini_cost", 0.0) + + resolved_costs.get("mistral_cost", 0.0) + + resolved_costs.get("video_cost", 0.0) + + resolved_costs.get("audio_cost", 0.0) + + resolved_costs.get("stability_cost", 0.0) + + resolved_costs.get("image_edit_cost", 0.0) + + resolved_costs.get("tavily_cost", 0.0) + + resolved_costs.get("serper_cost", 0.0) + + resolved_costs.get("exa_cost", 0.0) + ) + final_total_cost = calculated_total_cost if calculated_total_cost > (summary_total_cost or 0.0) else (summary_total_cost or 0.0) + return calculated_total_cost, final_total_cost + + +def maybe_persist_reconciled_costs( + *, + db: Any, + summary: Any, + summary_total_cost: float, + calculated_total_cost: float, + final_total_cost: float, + resolved_costs: Dict[str, float], +) -> None: + """Persist summary cost reconciliation when calculated values are more complete.""" + if not ( + calculated_total_cost > 0 + and (summary_total_cost == 0.0 or calculated_total_cost > summary_total_cost) + ): + return + + logger.info( + "[UsageStats] Updating summary costs (was {}): total_cost={:.6f}, gemini_cost={:.6f}, " + "mistral_cost={:.6f}, video_cost={:.6f}, audio_cost={:.6f}, image_cost={:.6f}".format( + summary_total_cost, + final_total_cost, + resolved_costs.get("gemini_cost", 0.0), + resolved_costs.get("mistral_cost", 0.0), + resolved_costs.get("video_cost", 0.0), + resolved_costs.get("audio_cost", 0.0), + resolved_costs.get("stability_cost", 0.0), + ) + ) + + summary.total_cost = final_total_cost + summary.gemini_cost = resolved_costs.get("gemini_cost", 0.0) + summary.mistral_cost = resolved_costs.get("mistral_cost", 0.0) + + for summary_attr in ("video_cost", "audio_cost", "stability_cost", "image_edit_cost"): + if hasattr(summary, summary_attr): + setattr(summary, summary_attr, resolved_costs.get(summary_attr, 0.0)) + + try: + db.commit() + except Exception as e: + logger.error(f"[UsageStats] Error updating summary costs: {e}") + db.rollback() diff --git a/backend/services/subscription/usage_tracking_helpers/usage_trends_helpers.py b/backend/services/subscription/usage_tracking_helpers/usage_trends_helpers.py new file mode 100644 index 00000000..8f2b0c7a --- /dev/null +++ b/backend/services/subscription/usage_tracking_helpers/usage_trends_helpers.py @@ -0,0 +1,171 @@ +"""Helpers extracted from UsageTrackingService.get_usage_trends.""" + +from __future__ import annotations + +from datetime import datetime, timedelta +from typing import Any, Dict, List + +from loguru import logger +from sqlalchemy import func + +from models.subscription_models import APIProvider, APIUsageLog, UsageStatus, UsageSummary + + +def build_billing_periods(months: int) -> List[str]: + """Build billing period keys (YYYY-MM) from oldest to newest.""" + end_date = datetime.now() + periods: List[str] = [] + for i in range(months): + period_date = end_date - timedelta(days=30 * i) + periods.append(period_date.strftime("%Y-%m")) + periods.reverse() + return periods + + +def query_usage_summaries(db: Any, user_id: str, periods: List[str]) -> Dict[str, Any]: + """Load usage summaries for requested periods keyed by billing period.""" + summaries = ( + db.query(UsageSummary) + .filter(UsageSummary.user_id == user_id, UsageSummary.billing_period.in_(periods)) + .all() + ) + return {summary.billing_period: summary for summary in summaries} + + +def self_heal_summaries_from_logs(db: Any, user_id: str, periods: List[str], summary_dict: Dict[str, Any]) -> None: + """Backfill/create usage summaries from aggregated API usage logs.""" + try: + log_stats = ( + db.query( + APIUsageLog.billing_period, + APIUsageLog.provider, + func.count(APIUsageLog.id).label("calls"), + func.sum(APIUsageLog.cost_total).label("cost"), + func.sum(APIUsageLog.tokens_total).label("tokens"), + ) + .filter(APIUsageLog.user_id == user_id, APIUsageLog.billing_period.in_(periods)) + .group_by(APIUsageLog.billing_period, APIUsageLog.provider) + .all() + ) + + log_data_by_period: Dict[str, Dict[str, Dict[str, float | int]]] = {} + for period, provider_enum, calls, cost, tokens in log_stats: + if period not in log_data_by_period: + log_data_by_period[period] = {} + + provider_name = provider_enum.value if hasattr(provider_enum, "value") else str(provider_enum).lower() + if "." in provider_name: + provider_name = provider_name.split(".")[-1].lower() + + if provider_name not in log_data_by_period[period]: + log_data_by_period[period][provider_name] = {"calls": 0, "cost": 0.0, "tokens": 0} + + log_data_by_period[period][provider_name]["calls"] += calls or 0 + log_data_by_period[period][provider_name]["cost"] += float(cost or 0.0) + log_data_by_period[period][provider_name]["tokens"] += tokens or 0 + + for period in periods: + period_logs = log_data_by_period.get(period, {}) + summary = summary_dict.get(period) + + if not summary and period_logs: + logger.info(f"[UsageStats] Self-healing: Creating missing summary for {period}") + summary = UsageSummary( + user_id=user_id, + billing_period=period, + usage_status=UsageStatus.ACTIVE, + total_calls=0, + total_cost=0.0, + total_tokens=0, + ) + db.add(summary) + summary_dict[period] = summary + + if summary and period_logs: + total_calls_calc = 0 + total_cost_calc = 0.0 + total_tokens_calc = 0 + + for provider_name, data in period_logs.items(): + total_calls_calc += int(data["calls"]) + total_cost_calc += float(data["cost"]) + total_tokens_calc += int(data["tokens"]) + + calls_attr = f"{provider_name}_calls" + cost_attr = f"{provider_name}_cost" + tokens_attr = f"{provider_name}_tokens" + + if hasattr(summary, calls_attr): + current_val = getattr(summary, calls_attr, 0) + if current_val < data["calls"]: + setattr(summary, calls_attr, data["calls"]) + + if hasattr(summary, cost_attr): + current_val = getattr(summary, cost_attr, 0.0) + if (float(data["cost"]) - current_val) > 0.000001: + setattr(summary, cost_attr, data["cost"]) + + if hasattr(summary, tokens_attr): + current_val = getattr(summary, tokens_attr, 0) + if current_val < data["tokens"]: + setattr(summary, tokens_attr, data["tokens"]) + + if (summary.total_cost or 0.0) < total_cost_calc: + logger.info( + f"[UsageStats] Self-healing cost for {period}: {summary.total_cost} -> {total_cost_calc}" + ) + summary.total_cost = total_cost_calc + if (summary.total_calls or 0) < total_calls_calc: + summary.total_calls = total_calls_calc + if (summary.total_tokens or 0) < total_tokens_calc: + summary.total_tokens = total_tokens_calc + + db.commit() + except Exception as e: + logger.error(f"Failed to self-heal usage trends: {e}") + db.rollback() + + +def build_usage_trends_response(periods: List[str], summary_dict: Dict[str, Any]) -> Dict[str, Any]: + """Build trends response payload from summaries.""" + trends = { + "periods": periods, + "total_calls": [], + "total_cost": [], + "total_tokens": [], + "provider_trends": {}, + } + + for provider in APIProvider: + provider_name = provider.value + trends["provider_trends"][provider_name] = {"calls": [], "cost": [], "tokens": []} + + for period in periods: + summary = summary_dict.get(period) + if summary: + trends["total_calls"].append(summary.total_calls or 0) + trends["total_cost"].append(summary.total_cost or 0.0) + trends["total_tokens"].append(summary.total_tokens or 0) + + for provider in APIProvider: + provider_name = provider.value + trends["provider_trends"][provider_name]["calls"].append( + getattr(summary, f"{provider_name}_calls", 0) or 0 + ) + trends["provider_trends"][provider_name]["cost"].append( + getattr(summary, f"{provider_name}_cost", 0.0) or 0.0 + ) + trends["provider_trends"][provider_name]["tokens"].append( + getattr(summary, f"{provider_name}_tokens", 0) or 0 + ) + else: + trends["total_calls"].append(0) + trends["total_cost"].append(0.0) + trends["total_tokens"].append(0) + for provider in APIProvider: + provider_name = provider.value + trends["provider_trends"][provider_name]["calls"].append(0) + trends["provider_trends"][provider_name]["cost"].append(0.0) + trends["provider_trends"][provider_name]["tokens"].append(0) + + return trends diff --git a/backend/services/subscription/usage_tracking_service.py b/backend/services/subscription/usage_tracking_service.py index 8cc476a8..5c9cf7e2 100644 --- a/backend/services/subscription/usage_tracking_service.py +++ b/backend/services/subscription/usage_tracking_service.py @@ -10,7 +10,7 @@ import asyncio from typing import Dict, Any, List, Tuple from datetime import datetime, timedelta from sqlalchemy.orm import Session -from sqlalchemy import func, desc +from sqlalchemy import desc from loguru import logger import json @@ -20,6 +20,18 @@ from models.subscription_models import ( ) from .pricing_service import PricingService from .provider_detection import detect_actual_provider +from .usage_tracking_helpers import ( + build_billing_periods, + build_default_usage_percentages, + build_empty_usage_response, + build_provider_breakdown, + build_usage_trends_response, + calculate_final_total_cost, + maybe_persist_reconciled_costs, + query_usage_summaries, + reset_usage_summary_counters, + self_heal_summaries_from_logs, +) class UsageTrackingService: """Service for tracking API usage and managing subscription limits.""" @@ -356,6 +368,10 @@ class UsageTrackingService: def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]: """Get comprehensive usage statistics for a user.""" + + if not user_id: + logger.error("get_user_usage_stats called without user_id") + raise ValueError("user_id is required") requested_billing_period = billing_period period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period) @@ -401,269 +417,54 @@ class UsageTrackingService: pass if not summary: # Still no summary after attempt - # No usage this period - return complete structure with zeros - provider_breakdown = {} - usage_percentages = {} - - # Initialize provider breakdown with zeros - for provider in APIProvider: - provider_name = provider.value - provider_breakdown[provider_name] = { - 'calls': 0, - 'tokens': 0, - 'cost': 0.0 - } - usage_percentages[f"{provider_name}_calls"] = 0 - - usage_percentages['cost'] = 0 - - return { - 'billing_period': billing_period, - 'usage_status': 'active', - 'total_calls': 0, - 'total_tokens': 0, - 'total_cost': 0.0, - 'avg_response_time': 0.0, - 'error_rate': 0.0, - 'last_updated': datetime.now().isoformat(), - 'limits': limits, - 'provider_breakdown': provider_breakdown, - 'alerts': [], - 'usage_percentages': usage_percentages - } + return build_empty_usage_response( + billing_period=billing_period, + limits=limits, + providers=APIProvider, + ) # Provider breakdown - calculate costs first, then use for percentages # Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum) - provider_breakdown = {} - - # Gemini - gemini_calls = getattr(summary, "gemini_calls", 0) or 0 - gemini_tokens = getattr(summary, "gemini_tokens", 0) or 0 - gemini_cost = getattr(summary, "gemini_cost", 0.0) or 0.0 - - # If gemini cost is 0 but there are calls, calculate from usage logs - if gemini_calls > 0 and gemini_cost == 0.0: - gemini_logs = self.db.query(APIUsageLog).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.provider == APIProvider.GEMINI, - APIUsageLog.billing_period == billing_period - ).all() - if gemini_logs: - gemini_cost = sum(float(log.cost_total or 0.0) for log in gemini_logs) - logger.info(f"[UsageStats] Calculated gemini cost from {len(gemini_logs)} logs: ${gemini_cost:.6f}") - - provider_breakdown['gemini'] = { - 'calls': gemini_calls, - 'tokens': gemini_tokens, - 'cost': gemini_cost - } - - # HuggingFace (stored as MISTRAL in database) - mistral_calls = getattr(summary, "mistral_calls", 0) or 0 - mistral_tokens = getattr(summary, "mistral_tokens", 0) or 0 - mistral_cost = getattr(summary, "mistral_cost", 0.0) or 0.0 - - # If mistral (HuggingFace) cost is 0 but there are calls, calculate from usage logs - if mistral_calls > 0 and mistral_cost == 0.0: - mistral_logs = self.db.query(APIUsageLog).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.provider == APIProvider.MISTRAL, - APIUsageLog.billing_period == billing_period - ).all() - if mistral_logs: - mistral_cost = sum(float(log.cost_total or 0.0) for log in mistral_logs) - logger.info(f"[UsageStats] Calculated mistral (HuggingFace) cost from {len(mistral_logs)} logs: ${mistral_cost:.6f}") - - provider_breakdown['huggingface'] = { - 'calls': mistral_calls, - 'tokens': mistral_tokens, - 'cost': mistral_cost - } - - # Add other providers (Video, Audio, Image, Image Edit) for comprehensive breakdown - # Video (WaveSpeed, HuggingFace, etc.) - video_calls = getattr(summary, "video_calls", 0) or 0 - video_cost = getattr(summary, "video_cost", 0.0) or 0.0 - if video_calls > 0 and video_cost == 0.0: - video_logs = self.db.query(APIUsageLog).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.provider == APIProvider.VIDEO, - APIUsageLog.billing_period == billing_period - ).all() - if video_logs: - video_cost = sum(float(log.cost_total or 0.0) for log in video_logs) - - provider_breakdown['video'] = { - 'calls': video_calls, - 'tokens': 0, - 'cost': video_cost - } - - # Audio (WaveSpeed, etc.) - audio_calls = getattr(summary, "audio_calls", 0) or 0 - audio_cost = getattr(summary, "audio_cost", 0.0) or 0.0 - if audio_calls > 0 and audio_cost == 0.0: - audio_logs = self.db.query(APIUsageLog).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.provider == APIProvider.AUDIO, - APIUsageLog.billing_period == billing_period - ).all() - if audio_logs: - audio_cost = sum(float(log.cost_total or 0.0) for log in audio_logs) - - provider_breakdown['audio'] = { - 'calls': audio_calls, - 'tokens': 0, - 'cost': audio_cost - } - - # Image Generation (Stability/WaveSpeed) - stability_calls = getattr(summary, "stability_calls", 0) or 0 - stability_cost = getattr(summary, "stability_cost", 0.0) or 0.0 - if stability_calls > 0 and stability_cost == 0.0: - stability_logs = self.db.query(APIUsageLog).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.provider == APIProvider.STABILITY, - APIUsageLog.billing_period == billing_period - ).all() - if stability_logs: - stability_cost = sum(float(log.cost_total or 0.0) for log in stability_logs) - - provider_breakdown['image'] = { - 'calls': stability_calls, - 'tokens': 0, - 'cost': stability_cost - } - - # Image Editing (WaveSpeed) - image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0 - image_edit_cost = getattr(summary, "image_edit_cost", 0.0) or 0.0 - if image_edit_calls > 0 and image_edit_cost == 0.0: - image_edit_logs = self.db.query(APIUsageLog).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.provider == APIProvider.IMAGE_EDIT, - APIUsageLog.billing_period == billing_period - ).all() - if image_edit_logs: - image_edit_cost = sum(float(log.cost_total or 0.0) for log in image_edit_logs) - - provider_breakdown['image_edit'] = { - 'calls': image_edit_calls, - 'tokens': 0, - 'cost': image_edit_cost - } - - # WaveSpeed (aggregated across Video, Audio, Image, Image Edit) - # Query APIUsageLog directly to get accurate WaveSpeed-specific usage - wavespeed_logs = self.db.query(APIUsageLog).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.billing_period == billing_period, - APIUsageLog.actual_provider_name == "wavespeed" - ).all() - - if wavespeed_logs: - wavespeed_calls = len(wavespeed_logs) - wavespeed_tokens = sum((log.tokens_total or 0) for log in wavespeed_logs) - wavespeed_cost = sum(float(log.cost_total or 0.0) for log in wavespeed_logs) - - provider_breakdown['wavespeed'] = { - 'calls': wavespeed_calls, - 'tokens': wavespeed_tokens, - 'cost': wavespeed_cost - } - logger.info(f"[UsageStats] Calculated WaveSpeed usage: {wavespeed_calls} calls, ${wavespeed_cost:.6f}") - else: - provider_breakdown['wavespeed'] = { - 'calls': 0, - 'tokens': 0, - 'cost': 0.0 - } - - # Search APIs - tavily_calls = getattr(summary, "tavily_calls", 0) or 0 - tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0 - provider_breakdown['tavily'] = { - 'calls': tavily_calls, - 'tokens': 0, - 'cost': tavily_cost - } - - serper_calls = getattr(summary, "serper_calls", 0) or 0 - serper_cost = getattr(summary, "serper_cost", 0.0) or 0.0 - provider_breakdown['serper'] = { - 'calls': serper_calls, - 'tokens': 0, - 'cost': serper_cost - } - - exa_calls = getattr(summary, "exa_calls", 0) or 0 - exa_cost = getattr(summary, "exa_cost", 0.0) or 0.0 - provider_breakdown['exa'] = { - 'calls': exa_calls, - 'tokens': 0, - 'cost': exa_cost - } - - # Calculate total cost from provider breakdown if summary total_cost is 0 - calculated_total_cost = ( - gemini_cost + mistral_cost + video_cost + audio_cost + - stability_cost + image_edit_cost + tavily_cost + serper_cost + exa_cost + provider_breakdown, resolved_costs, core_counts = build_provider_breakdown( + db=self.db, + user_id=user_id, + billing_period=billing_period, + summary=summary, ) + summary_total_cost = summary.total_cost or 0.0 - - # Determine the best cost value to use - # If summary cost is 0 but we have calculated cost, use calculated cost - # If summary cost exists but is less than calculated cost (out of sync), use calculated cost - if calculated_total_cost > summary_total_cost: - final_total_cost = calculated_total_cost - else: - final_total_cost = summary_total_cost - - # If we found a discrepancy (summary cost is 0 or less than calculated), update the DB - if calculated_total_cost > 0 and (summary_total_cost == 0.0 or calculated_total_cost > summary_total_cost): - logger.info(f"[UsageStats] Updating summary costs (was {summary_total_cost}): total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}, video_cost={video_cost:.6f}, audio_cost={audio_cost:.6f}, image_cost={stability_cost:.6f}") - summary.total_cost = final_total_cost - summary.gemini_cost = gemini_cost - summary.mistral_cost = mistral_cost - # Update other provider costs if they exist - if hasattr(summary, 'video_cost'): - summary.video_cost = video_cost - if hasattr(summary, 'audio_cost'): - summary.audio_cost = audio_cost - if hasattr(summary, 'stability_cost'): - summary.stability_cost = stability_cost - if hasattr(summary, 'image_edit_cost'): - summary.image_edit_cost = image_edit_cost - try: - self.db.commit() - except Exception as e: - logger.error(f"[UsageStats] Error updating summary costs: {e}") - self.db.rollback() + calculated_total_cost, final_total_cost = calculate_final_total_cost( + summary_total_cost=summary_total_cost, + resolved_costs=resolved_costs, + ) + + maybe_persist_reconciled_costs( + db=self.db, + summary=summary, + summary_total_cost=summary_total_cost, + calculated_total_cost=calculated_total_cost, + final_total_cost=final_total_cost, + resolved_costs=resolved_costs, + ) # Calculate usage percentages - only for Gemini and HuggingFace # Use the calculated costs for accurate percentages - usage_percentages = {} + usage_percentages = build_default_usage_percentages(APIProvider) if limits: # Gemini gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0 if gemini_call_limit > 0: - usage_percentages['gemini_calls'] = (gemini_calls / gemini_call_limit) * 100 - else: - usage_percentages['gemini_calls'] = 0 + usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100 # HuggingFace (stored as mistral in database) mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0 if mistral_call_limit > 0: - usage_percentages['mistral_calls'] = (mistral_calls / mistral_call_limit) * 100 - else: - usage_percentages['mistral_calls'] = 0 + usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100 # Cost usage percentage - use final_total_cost (calculated from logs if needed) cost_limit = limits['limits'].get('monthly_cost', 0) or 0 if cost_limit > 0: usage_percentages['cost'] = (final_total_cost / cost_limit) * 100 - else: - usage_percentages['cost'] = 0 return { 'billing_period': billing_period, @@ -692,171 +493,10 @@ class UsageTrackingService: def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]: """Get usage trends over time with self-healing from logs.""" - - # Calculate billing periods - end_date = datetime.now() - periods = [] - for i in range(months): - period_date = end_date - timedelta(days=30 * i) - periods.append(period_date.strftime("%Y-%m")) - - periods.reverse() # Oldest first - - # 1. Fetch existing summaries - summaries = self.db.query(UsageSummary).filter( - UsageSummary.user_id == user_id, - UsageSummary.billing_period.in_(periods) - ).all() - summary_dict = {s.billing_period: s for s in summaries} - - # 2. Fetch aggregated logs for self-healing - # Group by (billing_period, provider) to fix provider breakdowns too - try: - log_stats = self.db.query( - APIUsageLog.billing_period, - APIUsageLog.provider, - func.count(APIUsageLog.id).label('calls'), - func.sum(APIUsageLog.cost_total).label('cost'), - func.sum(APIUsageLog.tokens_total).label('tokens') - ).filter( - APIUsageLog.user_id == user_id, - APIUsageLog.billing_period.in_(periods) - ).group_by(APIUsageLog.billing_period, APIUsageLog.provider).all() - - # Organize log stats by period -> provider - log_data_by_period = {} - for period, provider_enum, calls, cost, tokens in log_stats: - if period not in log_data_by_period: - log_data_by_period[period] = {} - - # Handle provider enum or string - provider_name = provider_enum.value if hasattr(provider_enum, 'value') else str(provider_enum).lower() - # Normalize provider names (e.g. 'GEMINI' -> 'gemini') - if '.' in provider_name: - provider_name = provider_name.split('.')[-1].lower() - - if provider_name not in log_data_by_period[period]: - log_data_by_period[period][provider_name] = {'calls': 0, 'cost': 0.0, 'tokens': 0} - - log_data_by_period[period][provider_name]['calls'] += (calls or 0) - log_data_by_period[period][provider_name]['cost'] += float(cost or 0.0) - log_data_by_period[period][provider_name]['tokens'] += (tokens or 0) - - # 3. Update/Create Summaries based on logs - for period in periods: - period_logs = log_data_by_period.get(period, {}) - summary = summary_dict.get(period) - - # If no summary exists but logs do, create one - if not summary and period_logs: - logger.info(f"[UsageStats] Self-healing: Creating missing summary for {period}") - summary = UsageSummary( - user_id=user_id, - billing_period=period, - usage_status=UsageStatus.ACTIVE, - total_calls=0, - total_cost=0.0, - total_tokens=0 - ) - self.db.add(summary) - summary_dict[period] = summary - - if summary and period_logs: - total_calls_calc = 0 - total_cost_calc = 0.0 - total_tokens_calc = 0 - - for prov, data in period_logs.items(): - total_calls_calc += data['calls'] - total_cost_calc += data['cost'] - total_tokens_calc += data['tokens'] - - # Update provider specific fields if logs > summary - calls_attr = f"{prov}_calls" - cost_attr = f"{prov}_cost" - tokens_attr = f"{prov}_tokens" - - if hasattr(summary, calls_attr): - current_val = getattr(summary, calls_attr, 0) - if current_val < data['calls']: - setattr(summary, calls_attr, data['calls']) - - if hasattr(summary, cost_attr): - current_val = getattr(summary, cost_attr, 0.0) - # Use significant difference to avoid float noise - if (data['cost'] - current_val) > 0.000001: - setattr(summary, cost_attr, data['cost']) - - if hasattr(summary, tokens_attr): - current_val = getattr(summary, tokens_attr, 0) - if current_val < data['tokens']: - setattr(summary, tokens_attr, data['tokens']) - - # Update totals if under-reported - if (summary.total_cost or 0.0) < total_cost_calc: - logger.info(f"[UsageStats] Self-healing cost for {period}: {summary.total_cost} -> {total_cost_calc}") - summary.total_cost = total_cost_calc - if (summary.total_calls or 0) < total_calls_calc: - summary.total_calls = total_calls_calc - if (summary.total_tokens or 0) < total_tokens_calc: - summary.total_tokens = total_tokens_calc - - self.db.commit() - except Exception as e: - logger.error(f"Failed to self-heal usage trends: {e}") - self.db.rollback() - - # 4. Construct Return Data - trends = { - 'periods': periods, - 'total_calls': [], - 'total_cost': [], - 'total_tokens': [], - 'provider_trends': {} - } - - # Initialize provider trends structure - for provider in APIProvider: - provider_name = provider.value - trends['provider_trends'][provider_name] = { - 'calls': [], - 'cost': [], - 'tokens': [] - } - - for period in periods: - summary = summary_dict.get(period) - - if summary: - trends['total_calls'].append(summary.total_calls or 0) - trends['total_cost'].append(summary.total_cost or 0.0) - trends['total_tokens'].append(summary.total_tokens or 0) - - # Provider-specific trends - for provider in APIProvider: - provider_name = provider.value - trends['provider_trends'][provider_name]['calls'].append( - getattr(summary, f"{provider_name}_calls", 0) or 0 - ) - trends['provider_trends'][provider_name]['cost'].append( - getattr(summary, f"{provider_name}_cost", 0.0) or 0.0 - ) - trends['provider_trends'][provider_name]['tokens'].append( - getattr(summary, f"{provider_name}_tokens", 0) or 0 - ) - else: - # No data for this period - trends['total_calls'].append(0) - trends['total_cost'].append(0.0) - trends['total_tokens'].append(0) - - for provider in APIProvider: - provider_name = provider.value - trends['provider_trends'][provider_name]['calls'].append(0) - trends['provider_trends'][provider_name]['cost'].append(0.0) - trends['provider_trends'][provider_name]['tokens'].append(0) - - return trends + periods = build_billing_periods(months) + summary_dict = query_usage_summaries(self.db, user_id, periods) + self_heal_summaries_from_logs(self.db, user_id, periods, summary_dict) + return build_usage_trends_response(periods, summary_dict) async def enforce_usage_limits(self, user_id: str, provider: APIProvider, tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]: @@ -890,70 +530,11 @@ class UsageTrackingService: ).first() if not summary: - # Nothing to reset return {"reset": False, "reason": "no_summary"} - # CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan - # Clear LIMIT_REACHED status - summary.usage_status = UsageStatus.ACTIVE - - # Reset all LLM provider call counters - summary.gemini_calls = 0 - summary.openai_calls = 0 - summary.anthropic_calls = 0 - summary.mistral_calls = 0 - summary.wavespeed_calls = 0 - - # Reset all LLM provider token counters - summary.gemini_tokens = 0 - summary.openai_tokens = 0 - summary.anthropic_tokens = 0 - summary.mistral_tokens = 0 - summary.wavespeed_tokens = 0 - - # Reset search/research provider counters - summary.tavily_calls = 0 - summary.serper_calls = 0 - summary.metaphor_calls = 0 - summary.firecrawl_calls = 0 - - # Reset image generation counters - summary.stability_calls = 0 - summary.exa_calls = 0 - - # Reset video generation counters - summary.video_calls = 0 - - # Reset audio generation counters - summary.audio_calls = 0 - - # Reset image editing counters - summary.image_edit_calls = 0 - - # Reset cost counters - summary.gemini_cost = 0.0 - summary.openai_cost = 0.0 - summary.anthropic_cost = 0.0 - summary.mistral_cost = 0.0 - summary.wavespeed_cost = 0.0 - summary.tavily_cost = 0.0 - summary.serper_cost = 0.0 - summary.metaphor_cost = 0.0 - summary.firecrawl_cost = 0.0 - summary.stability_cost = 0.0 - summary.exa_cost = 0.0 - summary.video_cost = 0.0 - summary.image_edit_cost = 0.0 - summary.audio_cost = 0.0 - - # Reset totals - summary.total_calls = 0 - summary.total_tokens = 0 - summary.total_cost = 0.0 - - summary.updated_at = datetime.utcnow() + reset_usage_summary_counters(summary) self.db.commit() - + logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal") return {"reset": True, "counters_reset": True} except Exception as e: