Files
ALwrity/backend/services/subscription/usage_tracking_service.py

365 lines
16 KiB
Python

"""
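Usage Tracking Service - refactored into modular components.
This file serves as a facade that delegates read, reporting, enforcement and
alert operations to specialized modules in the usage_tracking_modules package;
the core write path (UsageTrackingService.track_api_usage) still lives here.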
Modules:
- historical_usage: Functions for aggregating historical usage data
- usage_stats: Functions for getting user usage statistics
- usage_trends: Functions for usage trend analysis
- limit_enforcement: Functions for enforcing usage limits
- alerts: Functions for usage alerts
"""
from typing import Dict, Any, Tuple, Optional
from sqlalchemy.orm import Session
from sqlalchemy import text
from loguru import logger
from datetime import datetime, timedelta
import time
from models.subscription_models import (
APIProvider, UsageStatus, UserSubscription,
UsageSummary, APIUsageLog, UsageAlert
)
from services.subscription.pricing_service import PricingService
from services.subscription.provider_detection import detect_actual_provider
from services.subscription.usage_tracking_helpers import (
build_provider_breakdown,
build_default_usage_percentages,
calculate_final_total_cost,
maybe_persist_reconciled_costs,
build_usage_trends_response,
build_billing_periods,
query_usage_summaries,
self_heal_summaries_from_logs,
reset_usage_summary_counters,
)
# Import clear_dashboard_cache lazily to avoid circular import
def _clear_dashboard_cache_for_user(user_id: str):
    """Invalidate the cached dashboard data for ``user_id``.

    ``clear_dashboard_cache`` is imported inside the function body (lazily)
    so that loading this module does not trigger a circular import.

    Returns:
        Whatever ``clear_dashboard_cache`` returns.
    """
    from services.subscription.cache import clear_dashboard_cache as clear_cache_fn

    outcome = clear_cache_fn(user_id)
    return outcome
from .usage_tracking_modules import (
get_all_historical_usage,
get_current_period_usage,
get_usage_for_period,
get_user_usage_stats,
get_usage_trends,
enforce_usage_limits,
check_usage_alerts,
create_usage_alert,
)
class UsageTrackingService:
    """Service for tracking API usage and managing billing information.

    Most read/report/enforcement methods are thin wrappers that forward to
    functions in ``usage_tracking_modules`` together with this service's DB
    session and ``PricingService``. The write path -- ``track_api_usage`` and
    ``_update_usage_summary`` -- is implemented directly in this class.
    """

    def __init__(self, db: Session):
        """Bind the service to a SQLAlchemy session.

        Args:
            db: Session used for every query, insert, commit and rollback
                performed by this service; also handed to ``PricingService``.
        """
        self.db = db
        self.pricing_service = PricingService(db)
        # TTL cache (30s) for enforcement results to cut DB chatter.
        # key: f"{user_id}:{provider}", value: {'result': (bool, str, dict), 'expires_at': datetime}
        # NOTE(review): no method in this class reads or writes it; presumably
        # used by callers or the enforcement module -- verify before removing.
        self._enforce_cache: Dict[str, Dict[str, Any]] = {}

    def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
        """Resolve which billing period a user's usage is read from / written to.

        Resolution order:

        1. An explicit ``billing_period`` argument is used as-is.
        2. If the user has a subscription with ``current_period_start``, its
           ``"%Y-%m"`` month is used when a ``UsageSummary`` row exists for it.
           Otherwise the current calendar month is used if *that* month has a
           row (backward compatibility with calendar-month data). If neither
           has data, the subscription month is used.
        3. Without a usable subscription period, the user's latest
           ``UsageSummary.billing_period`` (descending sort) is used, and
           finally the current calendar month.

        Returns:
            Dict with ``billing_period`` (str), ``lookup_periods`` (a
            one-element list holding the same period), and ``period_start`` /
            ``period_end``. The latter come from the subscription when the
            period was requested explicitly or taken from the subscription,
            and are ``None`` for the calendar-month / latest-summary fallbacks.
        """
        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id
        ).first()

        # Caller explicitly requested a billing period: honour it.
        if billing_period:
            return {
                "billing_period": billing_period,
                "lookup_periods": [billing_period],
                "period_start": subscription.current_period_start if subscription else None,
                "period_end": subscription.current_period_end if subscription else None,
            }

        def has_summary(period: str) -> bool:
            # True when at least one UsageSummary row exists for (user, period).
            return self.db.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == period
            ).first() is not None

        # NOTE(review): local time (datetime.now()), whereas updated_at
        # elsewhere uses datetime.utcnow() -- confirm which clock billing
        # months should follow near a month boundary.
        calendar_period = datetime.now().strftime("%Y-%m")

        if subscription and subscription.current_period_start:
            subscription_period = subscription.current_period_start.strftime("%Y-%m")
            subscription_keys = {
                "billing_period": subscription_period,
                "lookup_periods": [subscription_period],
                "period_start": subscription.current_period_start,
                "period_end": subscription.current_period_end,
            }
            if has_summary(subscription_period):
                return subscription_keys
            # No data for the subscription month: use calendar-month data if
            # it exists (backward compatibility with older rows).
            if calendar_period != subscription_period and has_summary(calendar_period):
                logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {subscription_period} has no data)")
                return {
                    "billing_period": calendar_period,
                    "lookup_periods": [calendar_period],
                    "period_start": None,
                    "period_end": None,
                }
            # No data in either period: track under the subscription month.
            return subscription_keys

        # No usable subscription period: reuse the most recent period with data.
        latest_summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id
        ).order_by(UsageSummary.billing_period.desc()).first()
        if latest_summary:
            logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period} for user {user_id}")
            return {
                "billing_period": latest_summary.billing_period,
                "lookup_periods": [latest_summary.billing_period],
                "period_start": None,
                "period_end": None,
            }

        # Last fallback to calendar month for free tier / no subscription.
        return {
            "billing_period": calendar_period,
            "lookup_periods": [calendar_period],
            "period_start": None,
            "period_end": None,
        }

    # ------------------------------------------------------------------
    # Delegates to usage_tracking_modules. Inside each method the bare name
    # resolves to the module-level function imported at file top, not to
    # the method of the same name.
    # ------------------------------------------------------------------

    def get_user_usage_stats(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
        """Get comprehensive usage statistics for a user (optionally for one period)."""
        return get_user_usage_stats(user_id, billing_period, self.db, self.pricing_service)

    def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
        """Get usage data aggregated across all of the user's billing periods."""
        return get_all_historical_usage(user_id, self.db, self.pricing_service)

    def get_current_period_usage(self, user_id: str) -> Dict[str, Any]:
        """Get current billing period usage with per-period limit percentages."""
        return get_current_period_usage(user_id, self.db, self.pricing_service)

    def get_usage_for_period(self, user_id: str, billing_period: str) -> Dict[str, Any]:
        """Get usage for a specific billing period."""
        return get_usage_for_period(user_id, billing_period, self.db, self.pricing_service)

    def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
        """Get usage trends over the last ``months`` periods (self-healing from logs)."""
        return get_usage_trends(user_id, months, self.db)

    async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
                                   tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
        """Enforce usage limits before making an API call.

        Returns the ``(bool, str, dict)`` result of the module-level
        ``enforce_usage_limits`` unchanged.
        """
        return enforce_usage_limits(user_id, provider, tokens_requested, self.db, self.pricing_service)

    async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
        """Check whether usage alerts should be raised for ``billing_period``."""
        # NOTE(review): called without await; correct only if the module-level
        # check_usage_alerts is synchronous -- confirm.
        check_usage_alerts(user_id, provider, billing_period, self.db, self.pricing_service)

    async def _create_usage_alert(self, user_id: str, provider: APIProvider,
                                  threshold: int, current_usage: int, limit: int,
                                  billing_period: str):
        """Create a usage alert record via the alerts module."""
        create_usage_alert(user_id, provider, threshold, current_usage, limit, billing_period, self.db)

    # ------------------------------------------------------------------
    # Core write path
    # ------------------------------------------------------------------

    async def track_api_usage(self, user_id: str, provider: APIProvider,
                              endpoint: str, method: str, model_used: str = None,
                              tokens_input: int = 0, tokens_output: int = 0,
                              response_time: float = 0.0, status_code: int = 200,
                              request_size: int = None, response_size: int = None,
                              user_agent: str = None, ip_address: str = None,
                              error_message: str = None, retry_count: int = 0,
                              **kwargs) -> Dict[str, Any]:
        """Track an API usage event and update billing information.

        Prices the call with ``PricingService.calculate_api_cost``, adds an
        ``APIUsageLog`` row, updates the period's ``UsageSummary``, runs the
        alert check, commits, and finally clears the user's dashboard cache
        (a cache-clear failure is logged as a warning, not raised).

        Args:
            model_used: Model actually used. When falsy, a per-provider default
                is used for pricing only; the log row stores ``model_used``.
            status_code: Values >= 400 are counted as errors in the summary.
            **kwargs: Forwarded verbatim to ``calculate_api_cost``.

        Returns:
            ``{"success": True, "cost", "tokens", "billing_period"}`` on
            success; on any exception the session is rolled back and
            ``{"success": False, "error": str(e)}`` is returned.
        """
        try:
            # Default model per provider, used for pricing when model_used is
            # not supplied.
            default_models = {
                APIProvider.GEMINI: "gemini-2.5-flash",  # Flash as default (cost-effective)
                APIProvider.OPENAI: "gpt-4o-mini",  # Mini as default (cost-effective)
                APIProvider.ANTHROPIC: "claude-3.5-sonnet",  # Sonnet as default
                APIProvider.MISTRAL: "openai/gpt-oss-120b:groq",  # HuggingFace default model
                APIProvider.WAVESPEED: "openai/gpt-oss-120b",  # WaveSpeed default model
            }
            # HuggingFace is stored as APIProvider.MISTRAL; like every other
            # provider, an explicit model_used wins over the default.
            model_name = model_used or default_models.get(provider, f"{provider.value}-default")
            tokens_total = (tokens_input or 0) + (tokens_output or 0)

            cost_data = self.pricing_service.calculate_api_cost(
                provider=provider,
                model_name=model_name,
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                request_count=1,
                **kwargs
            )

            period_keys = self._get_authoritative_billing_period_keys(user_id)
            billing_period = period_keys["billing_period"]

            # Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.)
            actual_provider_name = detect_actual_provider(
                provider_enum=provider,
                model_name=model_used,
                endpoint=endpoint
            )

            usage_log = APIUsageLog(
                user_id=user_id,
                provider=provider,
                endpoint=endpoint,
                method=method,
                model_used=model_used,
                actual_provider_name=actual_provider_name,
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                tokens_total=tokens_total,
                cost_input=cost_data['cost_input'],
                cost_output=cost_data['cost_output'],
                cost_total=cost_data['cost_total'],
                response_time=response_time,
                status_code=status_code,
                request_size=request_size,
                response_size=response_size,
                user_agent=user_agent,
                ip_address=ip_address,
                error_message=error_message,
                retry_count=retry_count,
                billing_period=billing_period
            )
            self.db.add(usage_log)

            await self._update_usage_summary(
                user_id=user_id,
                provider=provider,
                tokens_used=tokens_total,
                cost=cost_data['cost_total'],
                billing_period=billing_period,
                response_time=response_time,
                is_error=status_code >= 400
            )

            await self._check_usage_alerts(user_id, provider, billing_period)

            self.db.commit()

            # Invalidate dashboard cache so header stats update immediately.
            # Best-effort: the usage is already committed at this point.
            try:
                _clear_dashboard_cache_for_user(user_id)
            except Exception as cache_err:
                logger.warning(f"Failed to clear dashboard cache: {cache_err}")

            return {
                "success": True,
                "cost": cost_data['cost_total'],
                "tokens": tokens_total,
                "billing_period": billing_period
            }

        except Exception as e:
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"Failed to track API usage: {e}")
            self.db.rollback()
            return {
                "success": False,
                "error": str(e)
            }

    async def _update_usage_summary(self, user_id: str, provider: APIProvider,
                                    tokens_used: int, cost: float,
                                    billing_period: str,
                                    response_time: float = 0.0,
                                    is_error: bool = False):
        """Update or create the user's ``UsageSummary`` for ``billing_period``.

        Changes are only added to the session; the caller is responsible for
        committing.

        Args:
            tokens_used: Tokens to add to the total and per-provider counters.
            cost: Cost to add to the total and per-provider counters.
            response_time: Folded into ``avg_response_time`` when positive.
            is_error: Increments ``error_count`` when True.
        """
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()

        if not summary:
            summary = UsageSummary(
                user_id=user_id,
                billing_period=billing_period,
                usage_status=UsageStatus.ACTIVE,
                total_calls=0,
                total_tokens=0,
                total_cost=0.0
            )
            self.db.add(summary)

        # Aggregate counters.
        summary.total_calls = (summary.total_calls or 0) + 1
        summary.total_tokens = (summary.total_tokens or 0) + tokens_used
        summary.total_cost = (summary.total_cost or 0.0) + cost

        # Per-provider counters are named "<provider.value>_calls/_tokens/_cost".
        provider_name = provider.value
        calls_attr = f"{provider_name}_calls"
        setattr(summary, calls_attr, (getattr(summary, calls_attr, 0) or 0) + 1)

        tokens_attr = f"{provider_name}_tokens"
        if hasattr(summary, tokens_attr):
            setattr(summary, tokens_attr, (getattr(summary, tokens_attr, 0) or 0) + tokens_used)

        cost_attr = f"{provider_name}_cost"
        if hasattr(summary, cost_attr):
            setattr(summary, cost_attr, (getattr(summary, cost_attr, 0.0) or 0.0) + cost)

        # Always >= 1 here because of the increment above.
        total_calls = summary.total_calls

        # Incremental mean of response time over total_calls. The truthiness
        # check also tolerates response_time=None instead of raising TypeError.
        if response_time and response_time > 0:
            current_avg = summary.avg_response_time or 0.0
            summary.avg_response_time = ((current_avg * (total_calls - 1)) + response_time) / total_calls

        if is_error:
            summary.error_count = (summary.error_count or 0) + 1
        # Recompute the rate on every call, not only on errors: total_calls
        # grows on successful calls too, so otherwise the rate stayed pinned
        # at its last error-time value (e.g. 100% after one early failure).
        if summary.error_count:
            summary.error_rate = (summary.error_count / total_calls) * 100

        summary.updated_at = datetime.utcnow()