""" Usage Tracking Service Comprehensive tracking of API usage, costs, and subscription limits. """ # Ensure Optional is available in global scope for dynamic imports from typing import Optional import asyncio from typing import Dict, Any, List, Tuple from datetime import datetime, timedelta from sqlalchemy.orm import Session from sqlalchemy import desc from loguru import logger import json from api.subscription.cache import clear_dashboard_cache from models.subscription_models import ( APIUsageLog, UsageSummary, APIProvider, UsageAlert, UserSubscription, UsageStatus ) from .pricing_service import PricingService from .provider_detection import detect_actual_provider from .usage_tracking_helpers import ( build_billing_periods, build_default_usage_percentages, build_empty_usage_response, build_provider_breakdown, build_usage_trends_response, calculate_final_total_cost, maybe_persist_reconciled_costs, query_usage_summaries, reset_usage_summary_counters, self_heal_summaries_from_logs, ) class UsageTrackingService: """Service for tracking API usage and managing subscription limits.""" def __init__(self, db: Session): self.db = db self.pricing_service = PricingService(db) # TTL cache (30s) for enforcement results to cut DB chatter # key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime } self._enforce_cache: Dict[str, Dict[str, Any]] = {} def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]: """Return authoritative billing period lookup keys. Always uses calendar month for consistency.""" subscription = self.db.query(UserSubscription).filter( UserSubscription.user_id == user_id ).first() # If caller explicitly requested a billing period, use it if billing_period: return { "billing_period": billing_period, "lookup_periods": [billing_period], "period_start": subscription.current_period_start if subscription else None, "period_end": subscription.current_period_end if subscription else None, } # ALWAYS use current calendar month for billing period to ensure consistency # This prevents data loss when subscription spans month boundaries current_period = datetime.now().strftime("%Y-%m") return { "billing_period": current_period, "lookup_periods": [current_period], "period_start": subscription.current_period_start if subscription else None, "period_end": subscription.current_period_end if subscription else None, } async def track_api_usage(self, user_id: str, provider: APIProvider, endpoint: str, method: str, model_used: str = None, tokens_input: int = 0, tokens_output: int = 0, response_time: float = 0.0, status_code: int = 200, request_size: int = None, response_size: int = None, user_agent: str = None, ip_address: str = None, error_message: str = None, retry_count: int = 0, **kwargs) -> Dict[str, Any]: """Track an API usage event and update billing information.""" try: # Calculate costs # Use specific model names instead of generic defaults default_models = { APIProvider.GEMINI: "gemini-2.5-flash", # Use Flash as default (cost-effective) APIProvider.OPENAI: "gpt-4o-mini", # Use Mini as default (cost-effective) APIProvider.ANTHROPIC: "claude-3.5-sonnet", # Use Sonnet as default APIProvider.MISTRAL: "openai/gpt-oss-120b:groq", # HuggingFace default model APIProvider.WAVESPEED: "openai/gpt-oss-120b" # WaveSpeed default model } # For HuggingFace (stored as MISTRAL), use the actual model name or default if provider == APIProvider.MISTRAL: # HuggingFace models - try to match the actual model name from model_used if model_used: model_name = model_used else: model_name = default_models.get(APIProvider.MISTRAL, "openai/gpt-oss-120b:groq") else: model_name = model_used or default_models.get(provider, f"{provider.value}-default") cost_data = self.pricing_service.calculate_api_cost( provider=provider, model_name=model_name, tokens_input=tokens_input, tokens_output=tokens_output, request_count=1, **kwargs ) # Create usage log entry period_keys = self._get_authoritative_billing_period_keys(user_id) billing_period = period_keys["billing_period"] # Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.) actual_provider_name = detect_actual_provider( provider_enum=provider, model_name=model_used, endpoint=endpoint ) usage_log = APIUsageLog( user_id=user_id, provider=provider, endpoint=endpoint, method=method, model_used=model_used, actual_provider_name=actual_provider_name, # Track actual provider tokens_input=tokens_input, tokens_output=tokens_output, tokens_total=(tokens_input or 0) + (tokens_output or 0), cost_input=cost_data['cost_input'], cost_output=cost_data['cost_output'], cost_total=cost_data['cost_total'], response_time=response_time, status_code=status_code, request_size=request_size, response_size=response_size, user_agent=user_agent, ip_address=ip_address, error_message=error_message, retry_count=retry_count, billing_period=billing_period ) self.db.add(usage_log) # Update usage summary await self._update_usage_summary( user_id=user_id, provider=provider, tokens_used=(tokens_input or 0) + (tokens_output or 0), cost=cost_data['cost_total'], billing_period=billing_period, response_time=response_time, is_error=status_code >= 400 ) # Check for usage alerts await self._check_usage_alerts(user_id, provider, billing_period) self.db.commit() # Invalidate dashboard cache so header stats update immediately try: clear_dashboard_cache(user_id) except Exception as cache_err: logger.debug(f"Could not clear dashboard cache: {cache_err}") logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}") return { 'usage_logged': True, 'cost': cost_data['cost_total'], 'tokens_used': (tokens_input or 0) + (tokens_output or 0), 'billing_period': billing_period } except Exception as e: logger.error(f"Error tracking API usage: {str(e)}") self.db.rollback() return { 'usage_logged': False, 'error': str(e) } async def _update_usage_summary(self, user_id: str, provider: APIProvider, tokens_used: int, cost: float, billing_period: str, response_time: float, is_error: bool): """Update the usage summary for a user.""" # Get or create usage summary period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period) summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() if not summary: logger.info(f"[UsageTracking] Creating new UsageSummary for user={user_id}, period={period_keys['billing_period']}") summary = UsageSummary( user_id=user_id, billing_period=period_keys["billing_period"] ) self.db.add(summary) else: logger.debug(f"[UsageTracking] Found existing UsageSummary for user={user_id}, period={summary.billing_period}, calls={summary.total_calls}") # Update provider-specific counters provider_name = provider.value current_calls = getattr(summary, f"{provider_name}_calls", 0) setattr(summary, f"{provider_name}_calls", current_calls + 1) # Update token usage for LLM providers if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL, APIProvider.WAVESPEED]: current_tokens = getattr(summary, f"{provider_name}_tokens", 0) setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used) # Update cost current_cost = getattr(summary, f"{provider_name}_cost", 0.0) setattr(summary, f"{provider_name}_cost", current_cost + cost) # Update totals summary.total_calls += 1 summary.total_tokens += tokens_used summary.total_cost += cost # Update performance metrics if summary.total_calls > 0: # Update average response time total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time summary.avg_response_time = total_response_time / summary.total_calls # Update error rate if is_error: error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) + 1 summary.error_rate = (error_count / summary.total_calls) * 100 else: error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) summary.error_rate = (error_count / summary.total_calls) * 100 # Update usage status based on limits await self._update_usage_status(summary) summary.updated_at = datetime.utcnow() async def _update_usage_status(self, summary: UsageSummary): """Update usage status based on subscription limits.""" limits = self.pricing_service.get_user_limits(summary.user_id) if not limits: return # Check various limits and determine status max_usage_percentage = 0.0 # Check cost limit cost_limit = limits['limits'].get('monthly_cost', 0) if cost_limit > 0: cost_usage_pct = (summary.total_cost / cost_limit) * 100 max_usage_percentage = max(max_usage_percentage, cost_usage_pct) # Check call limits for each provider for provider in APIProvider: provider_name = provider.value current_calls = getattr(summary, f"{provider_name}_calls", 0) call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if call_limit > 0: call_usage_pct = (current_calls / call_limit) * 100 max_usage_percentage = max(max_usage_percentage, call_usage_pct) # Update status based on highest usage percentage if max_usage_percentage >= 100: summary.usage_status = UsageStatus.LIMIT_REACHED elif max_usage_percentage >= 80: summary.usage_status = UsageStatus.WARNING else: summary.usage_status = UsageStatus.ACTIVE async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str): """Check if usage alerts should be sent.""" # Get current usage period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period) summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() if not summary: return # Get user limits limits = self.pricing_service.get_user_limits(user_id) if not limits: return # Check for alert thresholds (80%, 90%, 100%) thresholds = [80, 90, 100] for threshold in thresholds: # Check if alert already sent for this threshold existing_alert = self.db.query(UsageAlert).filter( UsageAlert.user_id == user_id, UsageAlert.billing_period == billing_period, UsageAlert.threshold_percentage == threshold, UsageAlert.provider == provider, UsageAlert.is_sent == True ).first() if existing_alert: continue # Check if threshold is reached provider_name = provider.value current_calls = getattr(summary, f"{provider_name}_calls", 0) call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if call_limit > 0: usage_percentage = (current_calls / call_limit) * 100 if usage_percentage >= threshold: await self._create_usage_alert( user_id=user_id, provider=provider, threshold=threshold, current_usage=current_calls, limit=call_limit, billing_period=billing_period ) async def _create_usage_alert(self, user_id: str, provider: APIProvider, threshold: int, current_usage: int, limit: int, billing_period: str): """Create a usage alert.""" # Determine alert type and severity if threshold >= 100: alert_type = "limit_reached" severity = "error" title = f"API Limit Reached - {provider.value.title()}" message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period." elif threshold >= 90: alert_type = "usage_warning" severity = "warning" title = f"API Usage Warning - {provider.value.title()}" message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)." else: alert_type = "usage_warning" severity = "info" title = f"API Usage Notice - {provider.value.title()}" message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)." alert = UsageAlert( user_id=user_id, alert_type=alert_type, threshold_percentage=threshold, provider=provider, title=title, message=message, severity=severity, billing_period=billing_period ) self.db.add(alert) logger.info(f"Created usage alert for {user_id}: {title}") def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]: """Get comprehensive usage statistics for a user.""" if not user_id: logger.error("get_user_usage_stats called without user_id") raise ValueError("user_id is required") requested_billing_period = billing_period period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period) billing_period = period_keys["billing_period"] logger.debug(f"[get_user_usage_stats] user={user_id}, billing_period={billing_period}, lookup_periods={period_keys['lookup_periods']}") # Get usage summary summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() if summary: logger.debug(f"[get_user_usage_stats] Found summary: period={summary.billing_period}, calls={summary.total_calls}, cost={summary.total_cost}") else: logger.debug(f"[get_user_usage_stats] No summary found for user={user_id}, period={billing_period}") # Get user limits limits = self.pricing_service.get_user_limits(user_id) # Get recent alerts alerts = self.db.query(UsageAlert).filter( UsageAlert.user_id == user_id, UsageAlert.billing_period == billing_period, UsageAlert.is_read == False ).order_by(UsageAlert.created_at.desc()).limit(10).all() if not summary: # If no summary exists for current period, we should initialize it # This handles the "start of month" case where a user logs in but hasn't made calls yet if not requested_billing_period: logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}") summary = UsageSummary( user_id=user_id, billing_period=billing_period, usage_status=UsageStatus.ACTIVE, total_calls=0, total_tokens=0, total_cost=0.0 ) try: self.db.add(summary) self.db.commit() self.db.refresh(summary) except Exception as e: logger.error(f"Failed to initialize summary: {e}") self.db.rollback() # Fallback to zero-struct return if DB write fails pass if not summary: # Still no summary after attempt return build_empty_usage_response( billing_period=billing_period, limits=limits, providers=APIProvider, ) # Provider breakdown - calculate costs first, then use for percentages # Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum) provider_breakdown, resolved_costs, core_counts = build_provider_breakdown( db=self.db, user_id=user_id, billing_period=billing_period, summary=summary, ) summary_total_cost = summary.total_cost or 0.0 calculated_total_cost, final_total_cost = calculate_final_total_cost( summary_total_cost=summary_total_cost, resolved_costs=resolved_costs, ) maybe_persist_reconciled_costs( db=self.db, summary=summary, summary_total_cost=summary_total_cost, calculated_total_cost=calculated_total_cost, final_total_cost=final_total_cost, resolved_costs=resolved_costs, ) # Calculate usage percentages - only for Gemini and HuggingFace # Use the calculated costs for accurate percentages usage_percentages = build_default_usage_percentages(APIProvider) if limits: # Gemini gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0 if gemini_call_limit > 0: usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100 # HuggingFace (stored as mistral in database) mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0 if mistral_call_limit > 0: usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100 # Cost usage percentage - use final_total_cost (calculated from logs if needed) cost_limit = limits['limits'].get('monthly_cost', 0) or 0 if cost_limit > 0: usage_percentages['cost'] = (final_total_cost / cost_limit) * 100 return { 'billing_period': billing_period, 'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status), 'total_calls': summary.total_calls or 0, 'total_tokens': summary.total_tokens or 0, 'total_cost': final_total_cost, 'avg_response_time': summary.avg_response_time or 0.0, 'error_rate': summary.error_rate or 0.0, 'limits': limits, 'provider_breakdown': provider_breakdown, 'alerts': [ { 'id': alert.id, 'type': alert.alert_type, 'title': alert.title, 'message': alert.message, 'severity': alert.severity, 'created_at': alert.created_at.isoformat() } for alert in alerts ], 'usage_percentages': usage_percentages, 'last_updated': summary.updated_at.isoformat() } def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]: """Get usage trends over time with self-healing from logs.""" periods = build_billing_periods(months) summary_dict = query_usage_summaries(self.db, user_id, periods) self_heal_summaries_from_logs(self.db, user_id, periods, summary_dict) return build_usage_trends_response(periods, summary_dict) async def enforce_usage_limits(self, user_id: str, provider: APIProvider, tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]: """Enforce usage limits before making an API call.""" # Check short-lived cache first (30s) cache_key = f"{user_id}:{provider.value}" now = datetime.utcnow() cached = self._enforce_cache.get(cache_key) if cached and cached.get('expires_at') and cached['expires_at'] > now: return tuple(cached['result']) # type: ignore result = self.pricing_service.check_usage_limits( user_id=user_id, provider=provider, tokens_requested=tokens_requested ) self._enforce_cache[cache_key] = { 'result': result, 'expires_at': now + timedelta(seconds=30) } return result async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]: """Reset usage status and counters for the current billing period (after plan renewal/change).""" period_keys = self._get_authoritative_billing_period_keys(user_id) billing_period = period_keys["billing_period"] summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() if not summary: return {"reset": False, "reason": "no_summary"} try: reset_usage_summary_counters(summary) self.db.commit() # Invalidate dashboard cache so header stats update after reset try: clear_dashboard_cache(user_id) except Exception as cache_err: logger.debug(f"Could not clear dashboard cache: {cache_err}") logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal") return {"reset": True, "counters_reset": True} except Exception as e: self.db.rollback() logger.error(f"Error resetting usage status: {e}") return {"reset": False, "error": str(e)}