Add comprehensive usage-based subscription system with API tracking

Co-authored-by: ajay.calsoft <ajay.calsoft@gmail.com>
This commit is contained in:
Cursor Agent
2025-09-04 17:18:27 +00:00
parent d57f7feb4a
commit e0a6150ed1
13 changed files with 3619 additions and 10 deletions

View File

@@ -18,6 +18,7 @@ from models.enhanced_strategy_models import Base as EnhancedStrategyBase
# Monitoring models now use the same base as enhanced strategy models
from models.monitoring_models import Base as MonitoringBase
from models.persona_models import Base as PersonaBase
from models.subscription_models import Base as SubscriptionBase
# Database configuration
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
@@ -59,7 +60,8 @@ def init_database():
EnhancedStrategyBase.metadata.create_all(bind=engine)
MonitoringBase.metadata.create_all(bind=engine)
PersonaBase.metadata.create_all(bind=engine)
logger.info("Database initialized successfully with all models including personas")
SubscriptionBase.metadata.create_all(bind=engine)
logger.info("Database initialized successfully with all models including subscription system")
except SQLAlchemyError as e:
logger.error(f"Error initializing database: {str(e)}")
raise

View File

@@ -0,0 +1,433 @@
"""
Pricing Service for API Usage Tracking
Manages API pricing, cost calculation, and subscription limits.
"""
from typing import Dict, Any, Optional, List, Tuple
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from loguru import logger
from models.subscription_models import (
APIProviderPricing, SubscriptionPlan, UserSubscription,
UsageSummary, APIUsageLog, APIProvider, SubscriptionTier
)
class PricingService:
"""Service for managing API pricing and cost calculations."""
def __init__(self, db: Session):
self.db = db
self._pricing_cache = {}
self._plans_cache = {}
def initialize_default_pricing(self):
"""Initialize default pricing for all API providers."""
# Gemini API Pricing (as of January 2025)
gemini_pricing = [
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.0-flash-lite",
"cost_per_input_token": 0.000000375, # $0.075 per 1M input tokens (up to 128k context)
"cost_per_output_token": 0.0000003, # $0.30 per 1M output tokens
"description": "Gemini 2.0 Flash Lite - Fast and efficient model"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash",
"cost_per_input_token": 0.000000625, # $0.125 per 1M input tokens (up to 1M context)
"cost_per_output_token": 0.000000375, # $0.375 per 1M output tokens
"description": "Gemini 2.5 Flash - Balanced performance and cost"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-pro",
"cost_per_input_token": 0.00000125, # $1.25 per 1M input tokens (up to 200k context)
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens
"description": "Gemini 2.5 Pro - Most capable model"
}
]
# OpenAI Pricing (estimated, will be updated)
openai_pricing = [
{
"provider": APIProvider.OPENAI,
"model_name": "gpt-4o",
"cost_per_input_token": 0.0000025, # $2.50 per 1M input tokens
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens
"description": "GPT-4o - Latest OpenAI model"
},
{
"provider": APIProvider.OPENAI,
"model_name": "gpt-4o-mini",
"cost_per_input_token": 0.00000015, # $0.15 per 1M input tokens
"cost_per_output_token": 0.0000006, # $0.60 per 1M output tokens
"description": "GPT-4o Mini - Cost-effective model"
}
]
# Anthropic Pricing (estimated, will be updated)
anthropic_pricing = [
{
"provider": APIProvider.ANTHROPIC,
"model_name": "claude-3.5-sonnet",
"cost_per_input_token": 0.000003, # $3.00 per 1M input tokens
"cost_per_output_token": 0.000015, # $15.00 per 1M output tokens
"description": "Claude 3.5 Sonnet - Anthropic's flagship model"
}
]
# Search API Pricing (estimated)
search_pricing = [
{
"provider": APIProvider.TAVILY,
"model_name": "tavily-search",
"cost_per_request": 0.001, # $0.001 per search
"description": "Tavily AI Search API"
},
{
"provider": APIProvider.SERPER,
"model_name": "serper-search",
"cost_per_request": 0.001, # $0.001 per search
"description": "Serper Google Search API"
},
{
"provider": APIProvider.METAPHOR,
"model_name": "metaphor-search",
"cost_per_request": 0.003, # $0.003 per search
"description": "Metaphor/Exa AI Search API"
},
{
"provider": APIProvider.FIRECRAWL,
"model_name": "firecrawl-extract",
"cost_per_page": 0.002, # $0.002 per page crawled
"description": "Firecrawl Web Extraction API"
},
{
"provider": APIProvider.STABILITY,
"model_name": "stable-diffusion",
"cost_per_image": 0.04, # $0.04 per image
"description": "Stability AI Image Generation"
}
]
# Combine all pricing data
all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + search_pricing
# Insert pricing data
for pricing_data in all_pricing:
existing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == pricing_data["provider"],
APIProviderPricing.model_name == pricing_data["model_name"]
).first()
if not existing:
pricing = APIProviderPricing(**pricing_data)
self.db.add(pricing)
self.db.commit()
logger.info("Default API pricing initialized")
def initialize_default_plans(self):
"""Initialize default subscription plans."""
plans = [
{
"name": "Free",
"tier": SubscriptionTier.FREE,
"price_monthly": 0.0,
"price_yearly": 0.0,
"gemini_calls_limit": 100,
"openai_calls_limit": 0,
"anthropic_calls_limit": 0,
"mistral_calls_limit": 50,
"tavily_calls_limit": 20,
"serper_calls_limit": 20,
"metaphor_calls_limit": 10,
"firecrawl_calls_limit": 10,
"stability_calls_limit": 5,
"gemini_tokens_limit": 100000,
"monthly_cost_limit": 0.0,
"features": ["basic_content_generation", "limited_research"],
"description": "Perfect for trying out ALwrity"
},
{
"name": "Basic",
"tier": SubscriptionTier.BASIC,
"price_monthly": 29.0,
"price_yearly": 290.0,
"gemini_calls_limit": 1000,
"openai_calls_limit": 500,
"anthropic_calls_limit": 200,
"mistral_calls_limit": 500,
"tavily_calls_limit": 200,
"serper_calls_limit": 200,
"metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100,
"stability_calls_limit": 50,
"gemini_tokens_limit": 1000000,
"openai_tokens_limit": 500000,
"anthropic_tokens_limit": 200000,
"mistral_tokens_limit": 500000,
"monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams"
},
{
"name": "Pro",
"tier": SubscriptionTier.PRO,
"price_monthly": 79.0,
"price_yearly": 790.0,
"gemini_calls_limit": 5000,
"openai_calls_limit": 2500,
"anthropic_calls_limit": 1000,
"mistral_calls_limit": 2500,
"tavily_calls_limit": 1000,
"serper_calls_limit": 1000,
"metaphor_calls_limit": 500,
"firecrawl_calls_limit": 500,
"stability_calls_limit": 200,
"gemini_tokens_limit": 5000000,
"openai_tokens_limit": 2500000,
"anthropic_tokens_limit": 1000000,
"mistral_tokens_limit": 2500000,
"monthly_cost_limit": 150.0,
"features": ["unlimited_content_generation", "premium_research", "advanced_analytics", "priority_support"],
"description": "Perfect for growing businesses"
},
{
"name": "Enterprise",
"tier": SubscriptionTier.ENTERPRISE,
"price_monthly": 199.0,
"price_yearly": 1990.0,
"gemini_calls_limit": 0, # Unlimited
"openai_calls_limit": 0,
"anthropic_calls_limit": 0,
"mistral_calls_limit": 0,
"tavily_calls_limit": 0,
"serper_calls_limit": 0,
"metaphor_calls_limit": 0,
"firecrawl_calls_limit": 0,
"stability_calls_limit": 0,
"gemini_tokens_limit": 0,
"openai_tokens_limit": 0,
"anthropic_tokens_limit": 0,
"mistral_tokens_limit": 0,
"monthly_cost_limit": 500.0,
"features": ["unlimited_everything", "white_label", "dedicated_support", "custom_integrations"],
"description": "For large organizations with high-volume needs"
}
]
for plan_data in plans:
existing = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.name == plan_data["name"]
).first()
if not existing:
plan = SubscriptionPlan(**plan_data)
self.db.add(plan)
self.db.commit()
logger.info("Default subscription plans initialized")
def calculate_api_cost(self, provider: APIProvider, model_name: str,
tokens_input: int = 0, tokens_output: int = 0,
request_count: int = 1, **kwargs) -> Dict[str, float]:
"""Calculate cost for an API call."""
# Get pricing for the provider and model
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == model_name,
APIProviderPricing.is_active == True
).first()
if not pricing:
logger.warning(f"No pricing found for {provider.value}:{model_name}, using default estimates")
# Use default estimates
cost_input = tokens_input * 0.000001 # $1 per 1M tokens default
cost_output = tokens_output * 0.000001
cost_total = (cost_input + cost_output) * request_count
else:
# Calculate based on actual pricing
cost_input = tokens_input * pricing.cost_per_input_token
cost_output = tokens_output * pricing.cost_per_output_token
cost_request = request_count * pricing.cost_per_request
# Handle special cases for non-LLM APIs
cost_search = kwargs.get('search_count', 0) * pricing.cost_per_search
cost_image = kwargs.get('image_count', 0) * pricing.cost_per_image
cost_page = kwargs.get('page_count', 0) * pricing.cost_per_page
cost_total = cost_input + cost_output + cost_request + cost_search + cost_image + cost_page
# Round to 6 decimal places for precision
return {
'cost_input': round(cost_input, 6),
'cost_output': round(cost_output, 6),
'cost_total': round(cost_total, 6)
}
def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get usage limits for a user based on their subscription."""
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if not subscription:
# Return free tier limits
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE
).first()
if free_plan:
return self._plan_to_limits_dict(free_plan)
return None
return self._plan_to_limits_dict(subscription.plan)
def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
"""Convert subscription plan to limits dictionary."""
return {
'plan_name': plan.name,
'tier': plan.tier.value,
'limits': {
'gemini_calls': plan.gemini_calls_limit,
'openai_calls': plan.openai_calls_limit,
'anthropic_calls': plan.anthropic_calls_limit,
'mistral_calls': plan.mistral_calls_limit,
'tavily_calls': plan.tavily_calls_limit,
'serper_calls': plan.serper_calls_limit,
'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit,
'stability_calls': plan.stability_calls_limit,
'gemini_tokens': plan.gemini_tokens_limit,
'openai_tokens': plan.openai_tokens_limit,
'anthropic_tokens': plan.anthropic_tokens_limit,
'mistral_tokens': plan.mistral_tokens_limit,
'monthly_cost': plan.monthly_cost_limit
},
'features': plan.features or []
}
def check_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits."""
# Get user limits
limits = self.get_user_limits(user_id)
if not limits:
return False, "No subscription plan found", {}
# Get current usage for this billing period
current_period = datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
# Check call limits
provider_name = provider.value
current_calls = getattr(usage, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0 and current_calls >= call_limit:
return False, f"API call limit reached for {provider_name}", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0
}
# Check token limits for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0)
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
return False, f"Token limit would be exceeded for {provider_name}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100
}
# Check cost limits
cost_limit = limits['limits'].get('monthly_cost', 0)
if cost_limit > 0 and usage.total_cost >= cost_limit:
return False, "Monthly cost limit reached", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
}
# Calculate usage percentages for warnings
call_usage_pct = (current_calls / max(call_limit, 1)) * 100 if call_limit > 0 else 0
cost_usage_pct = (usage.total_cost / max(cost_limit, 1)) * 100 if cost_limit > 0 else 0
return True, "Within limits", {
'current_calls': current_calls,
'call_limit': call_limit,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
}
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
"""Estimate token count for text based on provider."""
# Get pricing info for token estimation
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.is_active == True
).first()
if pricing and pricing.tokens_per_word:
# Use provider-specific conversion
word_count = len(text.split())
return int(word_count * pricing.tokens_per_word)
else:
# Use default estimation (roughly 1.3 tokens per word for most models)
word_count = len(text.split())
return int(word_count * 1.3)
def get_pricing_info(self, provider: APIProvider, model_name: str = None) -> Optional[Dict[str, Any]]:
"""Get pricing information for a provider/model."""
query = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.is_active == True
)
if model_name:
query = query.filter(APIProviderPricing.model_name == model_name)
pricing = query.first()
if not pricing:
return None
return {
'provider': pricing.provider.value,
'model_name': pricing.model_name,
'cost_per_input_token': pricing.cost_per_input_token,
'cost_per_output_token': pricing.cost_per_output_token,
'cost_per_request': pricing.cost_per_request,
'cost_per_search': pricing.cost_per_search,
'cost_per_image': pricing.cost_per_image,
'cost_per_page': pricing.cost_per_page,
'description': pricing.description
}

View File

@@ -0,0 +1,428 @@
"""
Comprehensive Exception Handling and Logging for Subscription System
Provides robust error handling, logging, and monitoring for the usage-based subscription system.
"""
import traceback
import json
from datetime import datetime
from typing import Dict, Any, Optional, Union, List
from enum import Enum
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from models.subscription_models import APIProvider, UsageAlert
class SubscriptionErrorType(Enum):
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
PRICING_ERROR = "pricing_error"
TRACKING_ERROR = "tracking_error"
DATABASE_ERROR = "database_error"
API_PROVIDER_ERROR = "api_provider_error"
AUTHENTICATION_ERROR = "authentication_error"
BILLING_ERROR = "billing_error"
CONFIGURATION_ERROR = "configuration_error"
class SubscriptionErrorSeverity(Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class SubscriptionException(Exception):
"""Base exception for subscription system errors."""
def __init__(
self,
message: str,
error_type: SubscriptionErrorType,
severity: SubscriptionErrorSeverity = SubscriptionErrorSeverity.MEDIUM,
user_id: str = None,
provider: APIProvider = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
self.message = message
self.error_type = error_type
self.severity = severity
self.user_id = user_id
self.provider = provider
self.context = context or {}
self.original_error = original_error
self.timestamp = datetime.utcnow()
super().__init__(message)
def to_dict(self) -> Dict[str, Any]:
"""Convert exception to dictionary for logging/storage."""
return {
"message": self.message,
"error_type": self.error_type.value,
"severity": self.severity.value,
"user_id": self.user_id,
"provider": self.provider.value if self.provider else None,
"context": self.context,
"timestamp": self.timestamp.isoformat(),
"original_error": str(self.original_error) if self.original_error else None,
"traceback": traceback.format_exc() if self.original_error else None
}
class UsageLimitExceededException(SubscriptionException):
"""Exception raised when usage limits are exceeded."""
def __init__(
self,
message: str,
user_id: str,
provider: APIProvider,
limit_type: str,
current_usage: Union[int, float],
limit_value: Union[int, float],
context: Dict[str, Any] = None
):
context = context or {}
context.update({
"limit_type": limit_type,
"current_usage": current_usage,
"limit_value": limit_value,
"usage_percentage": (current_usage / max(limit_value, 1)) * 100
})
super().__init__(
message=message,
error_type=SubscriptionErrorType.USAGE_LIMIT_EXCEEDED,
severity=SubscriptionErrorSeverity.HIGH,
user_id=user_id,
provider=provider,
context=context
)
class PricingException(SubscriptionException):
"""Exception raised for pricing calculation errors."""
def __init__(
self,
message: str,
provider: APIProvider = None,
model_name: str = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
context = context or {}
if model_name:
context["model_name"] = model_name
super().__init__(
message=message,
error_type=SubscriptionErrorType.PRICING_ERROR,
severity=SubscriptionErrorSeverity.MEDIUM,
provider=provider,
context=context,
original_error=original_error
)
class TrackingException(SubscriptionException):
"""Exception raised for usage tracking errors."""
def __init__(
self,
message: str,
user_id: str = None,
provider: APIProvider = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
super().__init__(
message=message,
error_type=SubscriptionErrorType.TRACKING_ERROR,
severity=SubscriptionErrorSeverity.MEDIUM,
user_id=user_id,
provider=provider,
context=context,
original_error=original_error
)
class SubscriptionExceptionHandler:
"""Comprehensive exception handler for the subscription system."""
def __init__(self, db: Session = None):
self.db = db
self._setup_logging()
def _setup_logging(self):
"""Setup structured logging for subscription errors."""
# Configure loguru for subscription-specific logging
logger.add(
"logs/subscription_errors.log",
rotation="1 day",
retention="30 days",
level="ERROR",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
filter=lambda record: "subscription" in record["name"].lower()
)
logger.add(
"logs/usage_tracking.log",
rotation="1 day",
retention="90 days",
level="INFO",
format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
filter=lambda record: "usage_tracking" in str(record["message"]).lower()
)
def handle_exception(
self,
error: Union[Exception, SubscriptionException],
context: Dict[str, Any] = None,
log_level: str = "error"
) -> Dict[str, Any]:
"""Handle and log subscription system exceptions."""
context = context or {}
# Convert regular exceptions to SubscriptionException
if not isinstance(error, SubscriptionException):
error = SubscriptionException(
message=str(error),
error_type=self._classify_error(error),
severity=self._determine_severity(error),
context=context,
original_error=error
)
# Log the error
error_data = error.to_dict()
error_data.update(context)
log_message = f"Subscription Error: {error.message}"
if log_level == "critical":
logger.critical(log_message, extra={"error_data": error_data})
elif log_level == "error":
logger.error(log_message, extra={"error_data": error_data})
elif log_level == "warning":
logger.warning(log_message, extra={"error_data": error_data})
else:
logger.info(log_message, extra={"error_data": error_data})
# Store critical errors in database for alerting
if error.severity in [SubscriptionErrorSeverity.HIGH, SubscriptionErrorSeverity.CRITICAL]:
self._store_error_alert(error)
# Return formatted error response
return self._format_error_response(error)
def _classify_error(self, error: Exception) -> SubscriptionErrorType:
"""Classify an exception into a subscription error type."""
error_str = str(error).lower()
error_type_name = type(error).__name__.lower()
if "limit" in error_str or "exceeded" in error_str:
return SubscriptionErrorType.USAGE_LIMIT_EXCEEDED
elif "pricing" in error_str or "cost" in error_str:
return SubscriptionErrorType.PRICING_ERROR
elif "tracking" in error_str or "usage" in error_str:
return SubscriptionErrorType.TRACKING_ERROR
elif "database" in error_str or "sql" in error_type_name:
return SubscriptionErrorType.DATABASE_ERROR
elif "api" in error_str or "provider" in error_str:
return SubscriptionErrorType.API_PROVIDER_ERROR
elif "auth" in error_str or "permission" in error_str:
return SubscriptionErrorType.AUTHENTICATION_ERROR
elif "billing" in error_str or "payment" in error_str:
return SubscriptionErrorType.BILLING_ERROR
else:
return SubscriptionErrorType.CONFIGURATION_ERROR
def _determine_severity(self, error: Exception) -> SubscriptionErrorSeverity:
"""Determine the severity of an error."""
error_str = str(error).lower()
error_type = type(error)
# Critical errors
if isinstance(error, (SQLAlchemyError, ConnectionError)):
return SubscriptionErrorSeverity.CRITICAL
# High severity errors
if "limit exceeded" in error_str or "unauthorized" in error_str:
return SubscriptionErrorSeverity.HIGH
# Medium severity errors
if "pricing" in error_str or "tracking" in error_str:
return SubscriptionErrorSeverity.MEDIUM
# Default to low
return SubscriptionErrorSeverity.LOW
def _store_error_alert(self, error: SubscriptionException):
"""Store critical errors as alerts in the database."""
if not self.db or not error.user_id:
return
try:
alert = UsageAlert(
user_id=error.user_id,
alert_type="system_error",
threshold_percentage=0,
provider=error.provider,
title=f"System Error: {error.error_type.value}",
message=error.message,
severity=error.severity.value,
billing_period=datetime.now().strftime("%Y-%m")
)
self.db.add(alert)
self.db.commit()
except Exception as e:
logger.error(f"Failed to store error alert: {e}")
def _format_error_response(self, error: SubscriptionException) -> Dict[str, Any]:
"""Format error for API response."""
response = {
"success": False,
"error": {
"type": error.error_type.value,
"message": error.message,
"severity": error.severity.value,
"timestamp": error.timestamp.isoformat()
}
}
# Add context for debugging (non-sensitive info only)
if error.context:
safe_context = {
k: v for k, v in error.context.items()
if k not in ["password", "token", "key", "secret"]
}
response["error"]["context"] = safe_context
# Add user-friendly message based on error type
user_messages = {
SubscriptionErrorType.USAGE_LIMIT_EXCEEDED:
"You have reached your usage limit. Please upgrade your plan or wait for the next billing cycle.",
SubscriptionErrorType.PRICING_ERROR:
"There was an issue calculating the cost for this request. Please try again.",
SubscriptionErrorType.TRACKING_ERROR:
"Unable to track usage for this request. Please contact support if this persists.",
SubscriptionErrorType.DATABASE_ERROR:
"A database error occurred. Please try again later.",
SubscriptionErrorType.API_PROVIDER_ERROR:
"There was an issue with the API provider. Please try again.",
SubscriptionErrorType.AUTHENTICATION_ERROR:
"Authentication failed. Please check your credentials.",
SubscriptionErrorType.BILLING_ERROR:
"There was a billing-related error. Please contact support.",
SubscriptionErrorType.CONFIGURATION_ERROR:
"System configuration error. Please contact support."
}
response["error"]["user_message"] = user_messages.get(
error.error_type,
"An unexpected error occurred. Please try again or contact support."
)
return response
# Utility functions for common error scenarios
def handle_usage_limit_error(
user_id: str,
provider: APIProvider,
limit_type: str,
current_usage: Union[int, float],
limit_value: Union[int, float],
db: Session = None
) -> Dict[str, Any]:
"""Handle usage limit exceeded errors."""
handler = SubscriptionExceptionHandler(db)
error = UsageLimitExceededException(
message=f"Usage limit exceeded for {limit_type}",
user_id=user_id,
provider=provider,
limit_type=limit_type,
current_usage=current_usage,
limit_value=limit_value
)
return handler.handle_exception(error, log_level="warning")
def handle_pricing_error(
message: str,
provider: APIProvider = None,
model_name: str = None,
original_error: Exception = None,
db: Session = None
) -> Dict[str, Any]:
"""Handle pricing calculation errors."""
handler = SubscriptionExceptionHandler(db)
error = PricingException(
message=message,
provider=provider,
model_name=model_name,
original_error=original_error
)
return handler.handle_exception(error)
def handle_tracking_error(
message: str,
user_id: str = None,
provider: APIProvider = None,
original_error: Exception = None,
db: Session = None
) -> Dict[str, Any]:
"""Handle usage tracking errors."""
handler = SubscriptionExceptionHandler(db)
error = TrackingException(
message=message,
user_id=user_id,
provider=provider,
original_error=original_error
)
return handler.handle_exception(error)
def log_usage_event(
user_id: str,
provider: APIProvider,
action: str,
details: Dict[str, Any] = None
):
"""Log usage events for monitoring and debugging."""
details = details or {}
log_data = {
"user_id": user_id,
"provider": provider.value,
"action": action,
"timestamp": datetime.utcnow().isoformat(),
**details
}
logger.info(f"Usage Tracking: {action}", extra={"usage_data": log_data})
# Decorator for automatic exception handling
def handle_subscription_errors(db: Session = None):
"""Decorator to automatically handle subscription-related exceptions."""
def decorator(func):
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except SubscriptionException as e:
handler = SubscriptionExceptionHandler(db)
return handler.handle_exception(e)
except Exception as e:
handler = SubscriptionExceptionHandler(db)
return handler.handle_exception(e)
return wrapper
return decorator

View File

@@ -0,0 +1,460 @@
"""
Usage Tracking Service
Comprehensive tracking of API usage, costs, and subscription limits.
"""
import asyncio
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from loguru import logger
import json
from models.subscription_models import (
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
UserSubscription, UsageStatus
)
from services.pricing_service import PricingService
class UsageTrackingService:
"""Service for tracking API usage and managing subscription limits."""
def __init__(self, db: Session):
self.db = db
self.pricing_service = PricingService(db)
async def track_api_usage(self, user_id: str, provider: APIProvider,
endpoint: str, method: str, model_used: str = None,
tokens_input: int = 0, tokens_output: int = 0,
response_time: float = 0.0, status_code: int = 200,
request_size: int = None, response_size: int = None,
user_agent: str = None, ip_address: str = None,
error_message: str = None, retry_count: int = 0,
**kwargs) -> Dict[str, Any]:
"""Track an API usage event and update billing information."""
try:
# Calculate costs
cost_data = self.pricing_service.calculate_api_cost(
provider=provider,
model_name=model_used or f"{provider.value}-default",
tokens_input=tokens_input,
tokens_output=tokens_output,
request_count=1,
**kwargs
)
# Create usage log entry
billing_period = datetime.now().strftime("%Y-%m")
usage_log = APIUsageLog(
user_id=user_id,
provider=provider,
endpoint=endpoint,
method=method,
model_used=model_used,
tokens_input=tokens_input,
tokens_output=tokens_output,
tokens_total=tokens_input + tokens_output,
cost_input=cost_data['cost_input'],
cost_output=cost_data['cost_output'],
cost_total=cost_data['cost_total'],
response_time=response_time,
status_code=status_code,
request_size=request_size,
response_size=response_size,
user_agent=user_agent,
ip_address=ip_address,
error_message=error_message,
retry_count=retry_count,
billing_period=billing_period
)
self.db.add(usage_log)
# Update usage summary
await self._update_usage_summary(
user_id=user_id,
provider=provider,
tokens_used=tokens_input + tokens_output,
cost=cost_data['cost_total'],
billing_period=billing_period,
response_time=response_time,
is_error=status_code >= 400
)
# Check for usage alerts
await self._check_usage_alerts(user_id, provider, billing_period)
self.db.commit()
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
return {
'usage_logged': True,
'cost': cost_data['cost_total'],
'tokens_used': tokens_input + tokens_output,
'billing_period': billing_period
}
except Exception as e:
logger.error(f"Error tracking API usage: {str(e)}")
self.db.rollback()
return {
'usage_logged': False,
'error': str(e)
}
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
tokens_used: int, cost: float, billing_period: str,
response_time: float, is_error: bool):
"""Update the usage summary for a user."""
# Get or create usage summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=billing_period
)
self.db.add(summary)
# Update provider-specific counters
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
setattr(summary, f"{provider_name}_calls", current_calls + 1)
# Update token usage for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
# Update cost
current_cost = getattr(summary, f"{provider_name}_cost", 0.0)
setattr(summary, f"{provider_name}_cost", current_cost + cost)
# Update totals
summary.total_calls += 1
summary.total_tokens += tokens_used
summary.total_cost += cost
# Update performance metrics
if summary.total_calls > 0:
# Update average response time
total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time
summary.avg_response_time = total_response_time / summary.total_calls
# Update error rate
if is_error:
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) + 1
summary.error_rate = (error_count / summary.total_calls) * 100
else:
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100)
summary.error_rate = (error_count / summary.total_calls) * 100
# Update usage status based on limits
await self._update_usage_status(summary)
summary.updated_at = datetime.utcnow()
async def _update_usage_status(self, summary: UsageSummary):
"""Update usage status based on subscription limits."""
limits = self.pricing_service.get_user_limits(summary.user_id)
if not limits:
return
# Check various limits and determine status
max_usage_percentage = 0.0
# Check cost limit
cost_limit = limits['limits'].get('monthly_cost', 0)
if cost_limit > 0:
cost_usage_pct = (summary.total_cost / cost_limit) * 100
max_usage_percentage = max(max_usage_percentage, cost_usage_pct)
# Check call limits for each provider
for provider in APIProvider:
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
call_usage_pct = (current_calls / call_limit) * 100
max_usage_percentage = max(max_usage_percentage, call_usage_pct)
# Update status based on highest usage percentage
if max_usage_percentage >= 100:
summary.usage_status = UsageStatus.LIMIT_REACHED
elif max_usage_percentage >= 80:
summary.usage_status = UsageStatus.WARNING
else:
summary.usage_status = UsageStatus.ACTIVE
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
"""Check if usage alerts should be sent."""
# Get current usage
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
return
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
if not limits:
return
# Check for alert thresholds (80%, 90%, 100%)
thresholds = [80, 90, 100]
for threshold in thresholds:
# Check if alert already sent for this threshold
existing_alert = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.threshold_percentage == threshold,
UsageAlert.provider == provider,
UsageAlert.is_sent == True
).first()
if existing_alert:
continue
# Check if threshold is reached
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
usage_percentage = (current_calls / call_limit) * 100
if usage_percentage >= threshold:
await self._create_usage_alert(
user_id=user_id,
provider=provider,
threshold=threshold,
current_usage=current_calls,
limit=call_limit,
billing_period=billing_period
)
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
threshold: int, current_usage: int, limit: int,
billing_period: str):
"""Create a usage alert."""
# Determine alert type and severity
if threshold >= 100:
alert_type = "limit_reached"
severity = "error"
title = f"API Limit Reached - {provider.value.title()}"
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
elif threshold >= 90:
alert_type = "usage_warning"
severity = "warning"
title = f"API Usage Warning - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
else:
alert_type = "usage_warning"
severity = "info"
title = f"API Usage Notice - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
alert = UsageAlert(
user_id=user_id,
alert_type=alert_type,
threshold_percentage=threshold,
provider=provider,
title=title,
message=message,
severity=severity,
billing_period=billing_period
)
self.db.add(alert)
logger.info(f"Created usage alert for {user_id}: {title}")
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
"""Get comprehensive usage statistics for a user."""
if not billing_period:
billing_period = datetime.now().strftime("%Y-%m")
# Get usage summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
# Get recent alerts
alerts = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.is_read == False
).order_by(UsageAlert.created_at.desc()).limit(10).all()
if not summary:
# No usage this period
return {
'billing_period': billing_period,
'usage_status': 'active',
'total_calls': 0,
'total_tokens': 0,
'total_cost': 0.0,
'limits': limits,
'provider_breakdown': {},
'alerts': [],
'usage_percentages': {}
}
# Calculate usage percentages
usage_percentages = {}
if limits:
for provider in APIProvider:
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
usage_percentages[f"{provider_name}_calls"] = (current_calls / call_limit) * 100
else:
usage_percentages[f"{provider_name}_calls"] = 0
# Cost usage percentage
cost_limit = limits['limits'].get('monthly_cost', 0)
if cost_limit > 0:
usage_percentages['cost'] = (summary.total_cost / cost_limit) * 100
else:
usage_percentages['cost'] = 0
# Provider breakdown
provider_breakdown = {}
for provider in APIProvider:
provider_name = provider.value
provider_breakdown[provider_name] = {
'calls': getattr(summary, f"{provider_name}_calls", 0),
'tokens': getattr(summary, f"{provider_name}_tokens", 0),
'cost': getattr(summary, f"{provider_name}_cost", 0.0)
}
return {
'billing_period': billing_period,
'usage_status': summary.usage_status.value,
'total_calls': summary.total_calls,
'total_tokens': summary.total_tokens,
'total_cost': summary.total_cost,
'avg_response_time': summary.avg_response_time,
'error_rate': summary.error_rate,
'limits': limits,
'provider_breakdown': provider_breakdown,
'alerts': [
{
'id': alert.id,
'type': alert.alert_type,
'title': alert.title,
'message': alert.message,
'severity': alert.severity,
'created_at': alert.created_at.isoformat()
}
for alert in alerts
],
'usage_percentages': usage_percentages,
'last_updated': summary.updated_at.isoformat()
}
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
"""Get usage trends over time."""
# Calculate billing periods
end_date = datetime.now()
periods = []
for i in range(months):
period_date = end_date - timedelta(days=30 * i)
periods.append(period_date.strftime("%Y-%m"))
periods.reverse() # Oldest first
# Get usage summaries for these periods
summaries = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(periods)
).order_by(UsageSummary.billing_period).all()
# Create trends data
trends = {
'periods': periods,
'total_calls': [],
'total_cost': [],
'total_tokens': [],
'provider_trends': {}
}
summary_dict = {s.billing_period: s for s in summaries}
for period in periods:
summary = summary_dict.get(period)
if summary:
trends['total_calls'].append(summary.total_calls)
trends['total_cost'].append(summary.total_cost)
trends['total_tokens'].append(summary.total_tokens)
# Provider-specific trends
for provider in APIProvider:
provider_name = provider.value
if provider_name not in trends['provider_trends']:
trends['provider_trends'][provider_name] = {
'calls': [],
'cost': [],
'tokens': []
}
trends['provider_trends'][provider_name]['calls'].append(
getattr(summary, f"{provider_name}_calls", 0)
)
trends['provider_trends'][provider_name]['cost'].append(
getattr(summary, f"{provider_name}_cost", 0.0)
)
trends['provider_trends'][provider_name]['tokens'].append(
getattr(summary, f"{provider_name}_tokens", 0)
)
else:
# No data for this period
trends['total_calls'].append(0)
trends['total_cost'].append(0.0)
trends['total_tokens'].append(0)
for provider in APIProvider:
provider_name = provider.value
if provider_name not in trends['provider_trends']:
trends['provider_trends'][provider_name] = {
'calls': [],
'cost': [],
'tokens': []
}
trends['provider_trends'][provider_name]['calls'].append(0)
trends['provider_trends'][provider_name]['cost'].append(0.0)
trends['provider_trends'][provider_name]['tokens'].append(0)
return trends
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Enforce usage limits before making an API call."""
return self.pricing_service.check_usage_limits(
user_id=user_id,
provider=provider,
tokens_requested=tokens_requested
)