- Fix text selection menu not showing: wire contentRef via inputRef on multiline TextField - Fix blog title not truncating: add min-w-0 for flex item overflow - Fix outline generation 500: escape curly braces in f-string prompt template - Fix content generation 'NoneType not callable': replace SessionLocal() with get_session_for_user(), add db param to MediumBlogGenerator, fix signature mismatch in database_task_manager - Fix writing assistant suggest 500: add auth + user_id to API endpoint and service, replace sync requests with httpx.AsyncClient - Fix hallucination detector 404: explicitly include router in main.py and app.py - Fix missing error_data in task failure responses - Hide CopilotKit web inspector button - Remove hardcoded fallback suggestions from SmartTypingAssist - Fix stale closure refs in SmartTypingAssist handleTypingChange - Add two-column editor layout, stats bar, section hover menu - Various subscription, billing, and research module improvements
365 lines
16 KiB
Python
365 lines
16 KiB
Python
"""
|
|
Usage Tracking Service - Refactored into modular components.
|
|
|
|
This file now serves as a facade that delegates to specialized modules
|
|
in the usage_tracking_modules package.
|
|
|
|
Modules:
|
|
- historical_usage: Functions for aggregating historical usage data
|
|
- usage_stats: Functions for getting user usage statistics
|
|
- usage_trends: Functions for usage trend analysis
|
|
- limit_enforcement: Functions for enforcing usage limits
|
|
- alerts: Functions for usage alerts
|
|
"""
|
|
|
|
from typing import Dict, Any, Tuple, Optional
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy import text
|
|
from loguru import logger
|
|
from datetime import datetime, timedelta
|
|
import time
|
|
|
|
from models.subscription_models import (
|
|
APIProvider, UsageStatus, UserSubscription,
|
|
UsageSummary, APIUsageLog, UsageAlert
|
|
)
|
|
from services.subscription.pricing_service import PricingService
|
|
from services.subscription.provider_detection import detect_actual_provider
|
|
from services.subscription.usage_tracking_helpers import (
|
|
build_provider_breakdown,
|
|
build_default_usage_percentages,
|
|
calculate_final_total_cost,
|
|
maybe_persist_reconciled_costs,
|
|
build_usage_trends_response,
|
|
build_billing_periods,
|
|
query_usage_summaries,
|
|
self_heal_summaries_from_logs,
|
|
reset_usage_summary_counters,
|
|
)
|
|
# Import clear_dashboard_cache lazily to avoid circular import
|
|
def _clear_dashboard_cache_for_user(user_id: str):
    """Call ``clear_dashboard_cache`` for ``user_id`` and return its result.

    The import is performed at call time, not at module load time, so that
    importing this module does not create a circular import with
    ``api.subscription.cache``.
    """
    from api.subscription.cache import clear_dashboard_cache

    outcome = clear_dashboard_cache(user_id)
    return outcome
|
|
|
|
from .usage_tracking_modules import (
|
|
get_all_historical_usage,
|
|
get_current_period_usage,
|
|
get_usage_for_period,
|
|
get_user_usage_stats,
|
|
get_usage_trends,
|
|
enforce_usage_limits,
|
|
check_usage_alerts,
|
|
create_usage_alert,
|
|
)
|
|
|
|
|
|
class UsageTrackingService:
    """Service for tracking API usage and managing billing information.

    Most public methods are thin facades over the functions imported from
    ``usage_tracking_modules``. Recording a usage event
    (:meth:`track_api_usage`) and maintaining the per-period ``UsageSummary``
    row (:meth:`_update_usage_summary`) are implemented in this class.
    """

    def __init__(self, db: Session):
        """Bind the service to a database session.

        Args:
            db: Session used for every query and write performed by this
                service; it is also passed to ``PricingService``.
        """
        self.db = db
        self.pricing_service = PricingService(db)
        # TTL cache (30s) for enforcement results to cut DB chatter
        # key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
        # NOTE(review): no method of this class reads or writes this dict.
        self._enforce_cache: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _period_keys(billing_period: Optional[str], period_start: Any = None,
                     period_end: Any = None) -> Dict[str, Any]:
        """Build the lookup-key dict returned by the billing-period resolver."""
        return {
            "billing_period": billing_period,
            "lookup_periods": [billing_period],
            "period_start": period_start,
            "period_end": period_end,
        }

    def _has_usage_summary(self, user_id: str, billing_period: str) -> bool:
        """Return True when a ``UsageSummary`` row exists for the user and period."""
        row = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()
        return row is not None

    def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
        """Resolve which billing period (``"%Y-%m"`` string) to use for a user.

        Resolution order:

        1. An explicitly requested ``billing_period`` is returned as-is,
           together with the subscription's period bounds when a
           subscription row exists.
        2. If the user's subscription has a ``current_period_start``:
           the subscription period when it already has a ``UsageSummary``;
           otherwise the current calendar month when *that* has a summary
           (backward compatibility with calendar-month data); otherwise the
           subscription period.
        3. Without a usable subscription: the most recent billing period
           found in ``UsageSummary`` for the user.
        4. Finally, the current calendar month.

        Args:
            user_id: User whose billing period is being resolved.
            billing_period: Optional period to force.

        Returns:
            Dict with ``billing_period``, ``lookup_periods`` (a one-element
            list holding the same period), ``period_start`` and
            ``period_end``. The bounds are ``None`` whenever the result is
            not based on the subscription row.
        """
        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id
        ).first()
        sub_start = subscription.current_period_start if subscription else None
        sub_end = subscription.current_period_end if subscription else None

        # If caller explicitly requested a billing period, use it
        if billing_period:
            return self._period_keys(billing_period, sub_start, sub_end)

        calendar_period = datetime.now().strftime("%Y-%m")

        if sub_start:
            subscription_period = sub_start.strftime("%Y-%m")

            # Subscription period already has data: it is authoritative.
            if self._has_usage_summary(user_id, subscription_period):
                return self._period_keys(subscription_period, sub_start, sub_end)

            # No data for subscription period, check calendar period (backward compatibility)
            if calendar_period != subscription_period and self._has_usage_summary(user_id, calendar_period):
                logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {subscription_period} has no data)")
                return self._period_keys(calendar_period)

            # No data in either period, use subscription period
            return self._period_keys(subscription_period, sub_start, sub_end)

        # No subscription, check for any existing data
        latest_summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id
        ).order_by(UsageSummary.billing_period.desc()).first()

        if latest_summary:
            logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period} for user {user_id}")
            return self._period_keys(latest_summary.billing_period)

        # Last fallback to calendar month for free tier / no subscription
        return self._period_keys(calendar_period)

    # Delegate to modular functions
    def get_user_usage_stats(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
        """Get comprehensive usage statistics for a user.

        Delegates to the module-level ``get_user_usage_stats`` with this
        service's session and pricing service.
        """
        return get_user_usage_stats(user_id, billing_period, self.db, self.pricing_service)

    def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
        """Get ALL historical usage data aggregated across all billing periods.

        Delegates to the module-level ``get_all_historical_usage``.
        """
        return get_all_historical_usage(user_id, self.db, self.pricing_service)

    def get_current_period_usage(self, user_id: str) -> Dict[str, Any]:
        """Get current billing period usage with correct per-period limit percentages.

        Delegates to the module-level ``get_current_period_usage``.
        """
        return get_current_period_usage(user_id, self.db, self.pricing_service)

    def get_usage_for_period(self, user_id: str, billing_period: str) -> Dict[str, Any]:
        """Get usage for a specific billing period.

        Delegates to the module-level ``get_usage_for_period``.
        """
        return get_usage_for_period(user_id, billing_period, self.db, self.pricing_service)

    def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
        """Get usage trends over time with self-healing from logs.

        Delegates to the module-level ``get_usage_trends``; note that only
        the session (not the pricing service) is passed.
        """
        return get_usage_trends(user_id, months, self.db)

    async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
                                   tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
        """Enforce usage limits before making an API call.

        Returns whatever the module-level ``enforce_usage_limits`` returns.
        """
        # NOTE(review): the delegate is called without ``await``; this is
        # only correct if it is a plain (non-coroutine) function - confirm.
        return enforce_usage_limits(user_id, provider, tokens_requested, self.db, self.pricing_service)

    async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
        """Check if usage alerts should be sent.

        Delegates to the module-level ``check_usage_alerts``; its return
        value is discarded.
        """
        # NOTE(review): called without ``await`` - confirm the delegate is
        # synchronous, otherwise the check never actually runs.
        check_usage_alerts(user_id, provider, billing_period, self.db, self.pricing_service)

    async def _create_usage_alert(self, user_id: str, provider: APIProvider,
                                  threshold: int, current_usage: int, limit: int,
                                  billing_period: str):
        """Create a usage alert.

        Delegates to the module-level ``create_usage_alert``; its return
        value is discarded.
        """
        # NOTE(review): called without ``await`` - confirm the delegate is
        # synchronous.
        create_usage_alert(user_id, provider, threshold, current_usage, limit, billing_period, self.db)

    @staticmethod
    def _resolve_model_name(provider: APIProvider, model_used: Optional[str]) -> str:
        """Pick the model name used for pricing: the caller's model or a per-provider default."""
        if model_used:
            return model_used

        # Use specific model names instead of generic defaults
        default_models = {
            APIProvider.GEMINI: "gemini-2.5-flash",  # Use Flash as default (cost-effective)
            APIProvider.OPENAI: "gpt-4o-mini",  # Use Mini as default (cost-effective)
            APIProvider.ANTHROPIC: "claude-3.5-sonnet",  # Use Sonnet as default
            APIProvider.MISTRAL: "openai/gpt-oss-120b:groq",  # HuggingFace default model (stored as MISTRAL)
            APIProvider.WAVESPEED: "openai/gpt-oss-120b",  # WaveSpeed default model
        }
        return default_models.get(provider, f"{provider.value}-default")

    # Keep the track_api_usage method here as it's the core functionality
    async def track_api_usage(self, user_id: str, provider: APIProvider,
                              endpoint: str, method: str, model_used: Optional[str] = None,
                              tokens_input: int = 0, tokens_output: int = 0,
                              response_time: float = 0.0, status_code: int = 200,
                              request_size: Optional[int] = None, response_size: Optional[int] = None,
                              user_agent: Optional[str] = None, ip_address: Optional[str] = None,
                              error_message: Optional[str] = None, retry_count: int = 0,
                              **kwargs) -> Dict[str, Any]:
        """Track an API usage event and update billing information.

        Computes the call's cost, adds an ``APIUsageLog`` row, updates the
        user's ``UsageSummary`` for the resolved billing period, runs the
        usage-alert check, commits, and then clears the user's dashboard
        cache (a cache-clearing failure is only logged as a warning).

        Args:
            user_id: User the call is billed to.
            provider: Provider enum the call was made against.
            endpoint: Endpoint that was called.
            method: HTTP method of the call.
            model_used: Model reported by the caller; when falsy a
                per-provider default is used for pricing only (the log row
                and provider detection still receive the original value).
            tokens_input: Input token count (``None`` is counted as 0 in
                totals).
            tokens_output: Output token count (``None`` is counted as 0 in
                totals).
            response_time: Response time of the call; values ``<= 0`` do not
                affect the summary's average.
            status_code: Status code; ``>= 400`` is counted as an error.
            request_size, response_size, user_agent, ip_address,
            error_message, retry_count: Stored verbatim on the log row.
            **kwargs: Forwarded to ``PricingService.calculate_api_cost``.

        Returns:
            ``{"success": True, "cost", "tokens", "billing_period"}`` on
            success. On any exception the error is logged, the session is
            rolled back and ``{"success": False, "error": str(e)}`` is
            returned; the exception is not re-raised.
        """
        try:
            # Calculate costs
            model_name = self._resolve_model_name(provider, model_used)
            cost_data = self.pricing_service.calculate_api_cost(
                provider=provider,
                model_name=model_name,
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                request_count=1,
                **kwargs
            )
            total_cost = cost_data['cost_total']
            # Treat missing token counts as zero when totalling.
            tokens_total = (tokens_input or 0) + (tokens_output or 0)

            # Create usage log entry
            period_keys = self._get_authoritative_billing_period_keys(user_id)
            billing_period = period_keys["billing_period"]

            # Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.)
            actual_provider_name = detect_actual_provider(
                provider_enum=provider,
                model_name=model_used,
                endpoint=endpoint
            )

            usage_log = APIUsageLog(
                user_id=user_id,
                provider=provider,
                endpoint=endpoint,
                method=method,
                model_used=model_used,
                actual_provider_name=actual_provider_name,  # Track actual provider
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                tokens_total=tokens_total,
                cost_input=cost_data['cost_input'],
                cost_output=cost_data['cost_output'],
                cost_total=total_cost,
                response_time=response_time,
                status_code=status_code,
                request_size=request_size,
                response_size=response_size,
                user_agent=user_agent,
                ip_address=ip_address,
                error_message=error_message,
                retry_count=retry_count,
                billing_period=billing_period
            )
            self.db.add(usage_log)

            # Update usage summary
            await self._update_usage_summary(
                user_id=user_id,
                provider=provider,
                tokens_used=tokens_total,
                cost=total_cost,
                billing_period=billing_period,
                response_time=response_time,
                is_error=status_code >= 400
            )

            # Check for usage alerts
            await self._check_usage_alerts(user_id, provider, billing_period)

            self.db.commit()

            # Invalidate dashboard cache so header stats update immediately
            try:
                _clear_dashboard_cache_for_user(user_id)
            except Exception as cache_err:
                logger.warning(f"Failed to clear dashboard cache: {cache_err}")

            return {
                "success": True,
                "cost": total_cost,
                "tokens": tokens_total,
                "billing_period": billing_period
            }

        except Exception as e:
            logger.error(f"Failed to track API usage: {e}")
            self.db.rollback()
            return {
                "success": False,
                "error": str(e)
            }

    @staticmethod
    def _apply_usage_to_summary(summary: Any, provider_name: str, tokens_used: int,
                                cost: float, response_time: float = 0.0,
                                is_error: bool = False) -> None:
        """Fold one API call into ``summary``'s counters, average and error rate (in place)."""
        # Overall counters; ``None`` values are treated as zero.
        summary.total_calls = (summary.total_calls or 0) + 1
        summary.total_tokens = (summary.total_tokens or 0) + tokens_used
        summary.total_cost = (summary.total_cost or 0.0) + cost
        call_count = summary.total_calls or 1

        # Provider-specific call counter (always written).
        calls_attr = f"{provider_name}_calls"
        setattr(summary, calls_attr, (getattr(summary, calls_attr, 0) or 0) + 1)

        # Provider-specific tokens/cost are only touched when the summary
        # actually has such an attribute.
        for suffix, amount, zero in (("tokens", tokens_used, 0), ("cost", cost, 0.0)):
            attr = f"{provider_name}_{suffix}"
            if hasattr(summary, attr):
                setattr(summary, attr, (getattr(summary, attr, zero) or zero) + amount)

        # Update response time (rolling average). Calls reporting a
        # non-positive response time leave the average untouched.
        if response_time > 0:
            previous_avg = summary.avg_response_time or 0.0
            summary.avg_response_time = ((previous_avg * (call_count - 1)) + response_time) / call_count

        # Update error rate. It is recomputed on EVERY call so that it does
        # not stay stale (too high) after subsequent successful calls.
        if is_error:
            summary.error_count = (summary.error_count or 0) + 1
        summary.error_rate = ((summary.error_count or 0) / call_count) * 100

    async def _update_usage_summary(self, user_id: str, provider: APIProvider,
                                    tokens_used: int, cost: float,
                                    billing_period: str,
                                    response_time: float = 0.0,
                                    is_error: bool = False):
        """Update or create usage summary for the billing period.

        Looks up the ``UsageSummary`` row for ``(user_id, billing_period)``,
        creating and adding a zeroed one when none exists, then applies this
        call's counters and stamps ``updated_at``. The session is neither
        flushed nor committed here.

        Args:
            user_id: Owner of the summary row.
            provider: Provider enum; ``provider.value`` is the prefix of the
                per-provider ``*_calls`` / ``*_tokens`` / ``*_cost``
                attributes.
            tokens_used: Total tokens consumed by the call.
            cost: Total cost of the call.
            billing_period: Period the usage is attributed to.
            response_time: Response time folded into the rolling average
                when positive.
            is_error: Whether the call counts towards ``error_count``.
        """
        # Get or create summary
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()

        if not summary:
            summary = UsageSummary(
                user_id=user_id,
                billing_period=billing_period,
                usage_status=UsageStatus.ACTIVE,
                total_calls=0,
                total_tokens=0,
                total_cost=0.0
            )
            self.db.add(summary)

        self._apply_usage_to_summary(
            summary, provider.value, tokens_used, cost,
            response_time=response_time, is_error=is_error,
        )

        summary.updated_at = datetime.utcnow()
|