Add comprehensive usage-based subscription system with API tracking
Co-authored-by: ajay.calsoft <ajay.calsoft@gmail.com>
This commit is contained in:
@@ -18,6 +18,7 @@ from models.enhanced_strategy_models import Base as EnhancedStrategyBase
|
||||
# Monitoring models now use the same base as enhanced strategy models
|
||||
from models.monitoring_models import Base as MonitoringBase
|
||||
from models.persona_models import Base as PersonaBase
|
||||
from models.subscription_models import Base as SubscriptionBase
|
||||
|
||||
# Database configuration
|
||||
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./alwrity.db')
|
||||
@@ -59,7 +60,8 @@ def init_database():
|
||||
EnhancedStrategyBase.metadata.create_all(bind=engine)
|
||||
MonitoringBase.metadata.create_all(bind=engine)
|
||||
PersonaBase.metadata.create_all(bind=engine)
|
||||
logger.info("Database initialized successfully with all models including personas")
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
logger.info("Database initialized successfully with all models including subscription system")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error initializing database: {str(e)}")
|
||||
raise
|
||||
|
||||
433
backend/services/pricing_service.py
Normal file
433
backend/services/pricing_service.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Pricing Service for API Usage Tracking
|
||||
Manages API pricing, cost calculation, and subscription limits.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from decimal import Decimal, ROUND_HALF_UP
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import (
|
||||
APIProviderPricing, SubscriptionPlan, UserSubscription,
|
||||
UsageSummary, APIUsageLog, APIProvider, SubscriptionTier
|
||||
)
|
||||
|
||||
class PricingService:
    """Service for managing API pricing, cost calculation and subscription limits.

    All reads and writes go through the SQLAlchemy ``Session`` passed to the
    constructor.
    """

    # Per-token rate used when no pricing row matches ($1 per 1M tokens).
    _DEFAULT_COST_PER_TOKEN = 0.000001
    # Tokens-per-word ratio used when a provider has no specific ratio stored.
    _DEFAULT_TOKENS_PER_WORD = 1.3

    def __init__(self, db: Session):
        """Store the database session and create the (currently empty) caches."""
        self.db = db
        self._pricing_cache = {}
        self._plans_cache = {}

    @staticmethod
    def _number(value):
        """Return ``value``, or 0 when it is ``None`` (e.g. an unset column)."""
        return 0 if value is None else value

    def initialize_default_pricing(self):
        """Insert the default pricing rows for every provider/model pair.

        A row is only added when no row with the same provider and model name
        exists, so calling this repeatedly does not create duplicates. The
        session is rolled back and the error re-raised if anything fails.
        """

        # Gemini API Pricing (as of January 2025)
        gemini_pricing = [
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.0-flash-lite",
                # Fixed: was 0.000000375, i.e. 5x the documented rate.
                "cost_per_input_token": 0.000000075,  # $0.075 per 1M input tokens (up to 128k context)
                "cost_per_output_token": 0.0000003,  # $0.30 per 1M output tokens
                "description": "Gemini 2.0 Flash Lite - Fast and efficient model"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-flash",
                # Fixed: was 0.000000625, i.e. 5x the documented rate.
                "cost_per_input_token": 0.000000125,  # $0.125 per 1M input tokens (up to 1M context)
                "cost_per_output_token": 0.000000375,  # $0.375 per 1M output tokens
                "description": "Gemini 2.5 Flash - Balanced performance and cost"
            },
            {
                "provider": APIProvider.GEMINI,
                "model_name": "gemini-2.5-pro",
                "cost_per_input_token": 0.00000125,  # $1.25 per 1M input tokens (up to 200k context)
                "cost_per_output_token": 0.00001,  # $10.00 per 1M output tokens
                "description": "Gemini 2.5 Pro - Most capable model"
            }
        ]

        # OpenAI Pricing (estimated, will be updated)
        openai_pricing = [
            {
                "provider": APIProvider.OPENAI,
                "model_name": "gpt-4o",
                "cost_per_input_token": 0.0000025,  # $2.50 per 1M input tokens
                "cost_per_output_token": 0.00001,  # $10.00 per 1M output tokens
                "description": "GPT-4o - Latest OpenAI model"
            },
            {
                "provider": APIProvider.OPENAI,
                "model_name": "gpt-4o-mini",
                "cost_per_input_token": 0.00000015,  # $0.15 per 1M input tokens
                "cost_per_output_token": 0.0000006,  # $0.60 per 1M output tokens
                "description": "GPT-4o Mini - Cost-effective model"
            }
        ]

        # Anthropic Pricing (estimated, will be updated)
        anthropic_pricing = [
            {
                "provider": APIProvider.ANTHROPIC,
                "model_name": "claude-3.5-sonnet",
                "cost_per_input_token": 0.000003,  # $3.00 per 1M input tokens
                "cost_per_output_token": 0.000015,  # $15.00 per 1M output tokens
                "description": "Claude 3.5 Sonnet - Anthropic's flagship model"
            }
        ]

        # Search API Pricing (estimated)
        search_pricing = [
            {
                "provider": APIProvider.TAVILY,
                "model_name": "tavily-search",
                "cost_per_request": 0.001,  # $0.001 per search
                "description": "Tavily AI Search API"
            },
            {
                "provider": APIProvider.SERPER,
                "model_name": "serper-search",
                "cost_per_request": 0.001,  # $0.001 per search
                "description": "Serper Google Search API"
            },
            {
                "provider": APIProvider.METAPHOR,
                "model_name": "metaphor-search",
                "cost_per_request": 0.003,  # $0.003 per search
                "description": "Metaphor/Exa AI Search API"
            },
            {
                "provider": APIProvider.FIRECRAWL,
                "model_name": "firecrawl-extract",
                "cost_per_page": 0.002,  # $0.002 per page crawled
                "description": "Firecrawl Web Extraction API"
            },
            {
                "provider": APIProvider.STABILITY,
                "model_name": "stable-diffusion",
                "cost_per_image": 0.04,  # $0.04 per image
                "description": "Stability AI Image Generation"
            }
        ]

        # Combine all pricing data
        all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + search_pricing

        try:
            # Insert only the provider/model pairs that are not stored yet.
            for pricing_data in all_pricing:
                existing = self.db.query(APIProviderPricing).filter(
                    APIProviderPricing.provider == pricing_data["provider"],
                    APIProviderPricing.model_name == pricing_data["model_name"]
                ).first()

                if not existing:
                    self.db.add(APIProviderPricing(**pricing_data))

            self.db.commit()
        except Exception:
            # Leave the session usable for the caller, then surface the error.
            self.db.rollback()
            raise

        logger.info("Default API pricing initialized")

    def initialize_default_plans(self):
        """Insert the default subscription plans (Free, Basic, Pro, Enterprise).

        A plan is only added when no plan with the same name exists. The
        session is rolled back and the error re-raised if anything fails.
        """

        # NOTE(review): check_usage_limits() only enforces limits that are > 0,
        # i.e. 0 means "unlimited" (as used by the Enterprise plan). The Free
        # plan's 0 OpenAI/Anthropic call limits and 0.0 monthly cost limit are
        # therefore NOT enforced - confirm whether that is intended.
        plans = [
            {
                "name": "Free",
                "tier": SubscriptionTier.FREE,
                "price_monthly": 0.0,
                "price_yearly": 0.0,
                "gemini_calls_limit": 100,
                "openai_calls_limit": 0,
                "anthropic_calls_limit": 0,
                "mistral_calls_limit": 50,
                "tavily_calls_limit": 20,
                "serper_calls_limit": 20,
                "metaphor_calls_limit": 10,
                "firecrawl_calls_limit": 10,
                "stability_calls_limit": 5,
                "gemini_tokens_limit": 100000,
                "monthly_cost_limit": 0.0,
                "features": ["basic_content_generation", "limited_research"],
                "description": "Perfect for trying out ALwrity"
            },
            {
                "name": "Basic",
                "tier": SubscriptionTier.BASIC,
                "price_monthly": 29.0,
                "price_yearly": 290.0,
                "gemini_calls_limit": 1000,
                "openai_calls_limit": 500,
                "anthropic_calls_limit": 200,
                "mistral_calls_limit": 500,
                "tavily_calls_limit": 200,
                "serper_calls_limit": 200,
                "metaphor_calls_limit": 100,
                "firecrawl_calls_limit": 100,
                "stability_calls_limit": 50,
                "gemini_tokens_limit": 1000000,
                "openai_tokens_limit": 500000,
                "anthropic_tokens_limit": 200000,
                "mistral_tokens_limit": 500000,
                "monthly_cost_limit": 50.0,
                "features": ["full_content_generation", "advanced_research", "basic_analytics"],
                "description": "Great for individuals and small teams"
            },
            {
                "name": "Pro",
                "tier": SubscriptionTier.PRO,
                "price_monthly": 79.0,
                "price_yearly": 790.0,
                "gemini_calls_limit": 5000,
                "openai_calls_limit": 2500,
                "anthropic_calls_limit": 1000,
                "mistral_calls_limit": 2500,
                "tavily_calls_limit": 1000,
                "serper_calls_limit": 1000,
                "metaphor_calls_limit": 500,
                "firecrawl_calls_limit": 500,
                "stability_calls_limit": 200,
                "gemini_tokens_limit": 5000000,
                "openai_tokens_limit": 2500000,
                "anthropic_tokens_limit": 1000000,
                "mistral_tokens_limit": 2500000,
                "monthly_cost_limit": 150.0,
                "features": ["unlimited_content_generation", "premium_research", "advanced_analytics", "priority_support"],
                "description": "Perfect for growing businesses"
            },
            {
                "name": "Enterprise",
                "tier": SubscriptionTier.ENTERPRISE,
                "price_monthly": 199.0,
                "price_yearly": 1990.0,
                "gemini_calls_limit": 0,  # Unlimited
                "openai_calls_limit": 0,
                "anthropic_calls_limit": 0,
                "mistral_calls_limit": 0,
                "tavily_calls_limit": 0,
                "serper_calls_limit": 0,
                "metaphor_calls_limit": 0,
                "firecrawl_calls_limit": 0,
                "stability_calls_limit": 0,
                "gemini_tokens_limit": 0,
                "openai_tokens_limit": 0,
                "anthropic_tokens_limit": 0,
                "mistral_tokens_limit": 0,
                "monthly_cost_limit": 500.0,
                "features": ["unlimited_everything", "white_label", "dedicated_support", "custom_integrations"],
                "description": "For large organizations with high-volume needs"
            }
        ]

        try:
            for plan_data in plans:
                existing = self.db.query(SubscriptionPlan).filter(
                    SubscriptionPlan.name == plan_data["name"]
                ).first()

                if not existing:
                    self.db.add(SubscriptionPlan(**plan_data))

            self.db.commit()
        except Exception:
            # Leave the session usable for the caller, then surface the error.
            self.db.rollback()
            raise

        logger.info("Default subscription plans initialized")

    def calculate_api_cost(self, provider: APIProvider, model_name: str,
                           tokens_input: int = 0, tokens_output: int = 0,
                           request_count: int = 1, **kwargs) -> Dict[str, float]:
        """Calculate the cost of an API call.

        Args:
            provider: Provider whose pricing row is looked up.
            model_name: Model whose pricing row is looked up.
            tokens_input: Number of input tokens consumed.
            tokens_output: Number of output tokens produced.
            request_count: Number of requests, billed at ``cost_per_request``.
            **kwargs: Optional ``search_count``, ``image_count`` and
                ``page_count`` for the non-LLM per-unit rates.

        Returns:
            Dict with ``cost_input``, ``cost_output`` and ``cost_total``, each
            rounded to 6 decimal places.
        """

        # Get pricing for the provider and model
        pricing = self.db.query(APIProviderPricing).filter(
            APIProviderPricing.provider == provider,
            APIProviderPricing.model_name == model_name,
            APIProviderPricing.is_active == True  # noqa: E712 - SQLAlchemy column expression
        ).first()

        if not pricing:
            logger.warning(f"No pricing found for {provider.value}:{model_name}, using default estimates")
            # Use default estimates. Token counts are totals, so (as in the
            # priced branch) they are not multiplied by request_count.
            cost_input = tokens_input * self._DEFAULT_COST_PER_TOKEN
            cost_output = tokens_output * self._DEFAULT_COST_PER_TOKEN
            cost_total = cost_input + cost_output
        else:
            # Rates that were never set may be None; treat them as 0.
            rate = self._number
            cost_input = tokens_input * rate(pricing.cost_per_input_token)
            cost_output = tokens_output * rate(pricing.cost_per_output_token)
            cost_request = request_count * rate(pricing.cost_per_request)

            # Handle special cases for non-LLM APIs
            cost_search = kwargs.get('search_count', 0) * rate(pricing.cost_per_search)
            cost_image = kwargs.get('image_count', 0) * rate(pricing.cost_per_image)
            cost_page = kwargs.get('page_count', 0) * rate(pricing.cost_per_page)

            cost_total = cost_input + cost_output + cost_request + cost_search + cost_image + cost_page

        # Round to 6 decimal places for precision
        return {
            'cost_input': round(cost_input, 6),
            'cost_output': round(cost_output, 6),
            'cost_total': round(cost_total, 6)
        }

    def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
        """Return the usage limits that apply to ``user_id``.

        Uses the plan of the user's active subscription; when there is none,
        falls back to the plan whose tier is FREE. Returns ``None`` when no
        such plan exists either.
        """

        subscription = self.db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id,
            UserSubscription.is_active == True  # noqa: E712 - SQLAlchemy column expression
        ).first()

        if subscription:
            return self._plan_to_limits_dict(subscription.plan)

        # No active subscription: return free tier limits
        free_plan = self.db.query(SubscriptionPlan).filter(
            SubscriptionPlan.tier == SubscriptionTier.FREE
        ).first()
        if free_plan:
            return self._plan_to_limits_dict(free_plan)
        return None

    def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
        """Convert a subscription plan into the limits dictionary used by callers."""
        return {
            'plan_name': plan.name,
            'tier': plan.tier.value,
            'limits': {
                'gemini_calls': plan.gemini_calls_limit,
                'openai_calls': plan.openai_calls_limit,
                'anthropic_calls': plan.anthropic_calls_limit,
                'mistral_calls': plan.mistral_calls_limit,
                'tavily_calls': plan.tavily_calls_limit,
                'serper_calls': plan.serper_calls_limit,
                'metaphor_calls': plan.metaphor_calls_limit,
                'firecrawl_calls': plan.firecrawl_calls_limit,
                'stability_calls': plan.stability_calls_limit,
                'gemini_tokens': plan.gemini_tokens_limit,
                'openai_tokens': plan.openai_tokens_limit,
                'anthropic_tokens': plan.anthropic_tokens_limit,
                'mistral_tokens': plan.mistral_tokens_limit,
                'monthly_cost': plan.monthly_cost_limit
            },
            'features': plan.features or []
        }

    def check_usage_limits(self, user_id: str, provider: APIProvider,
                           tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
        """Check whether the user may make an API call within their limits.

        Creates (and commits) the usage summary row for the current
        ``YYYY-MM`` billing period when it does not exist yet. A limit of 0
        (or an unset limit) is treated as "no limit".

        Returns:
            ``(allowed, message, details)`` where ``details`` describes the
            usage figures relevant to the decision.
        """

        # Get user limits
        limits = self.get_user_limits(user_id)
        if not limits:
            return False, "No subscription plan found", {}
        plan_limits = limits['limits']

        # Get current usage for this billing period
        current_period = datetime.now().strftime("%Y-%m")
        usage = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == current_period
        ).first()

        if not usage:
            # First usage this period, create summary
            usage = UsageSummary(
                user_id=user_id,
                billing_period=current_period
            )
            self.db.add(usage)
            self.db.commit()

        # Check call limits. Counters and limits may be None when unset, so
        # they are normalised to 0 before comparing.
        provider_name = provider.value
        current_calls = self._number(getattr(usage, f"{provider_name}_calls", 0))
        call_limit = self._number(plan_limits.get(f"{provider_name}_calls", 0))

        if call_limit > 0 and current_calls >= call_limit:
            return False, f"API call limit reached for {provider_name}", {
                'current_calls': current_calls,
                'limit': call_limit,
                'usage_percentage': 100.0
            }

        # Check token limits for LLM providers
        if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
            current_tokens = self._number(getattr(usage, f"{provider_name}_tokens", 0))
            token_limit = self._number(plan_limits.get(f"{provider_name}_tokens", 0))

            if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
                return False, f"Token limit would be exceeded for {provider_name}", {
                    'current_tokens': current_tokens,
                    'requested_tokens': tokens_requested,
                    'limit': token_limit,
                    'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100
                }

        # Check cost limits
        total_cost = self._number(usage.total_cost)
        cost_limit = self._number(plan_limits.get('monthly_cost', 0))
        if cost_limit > 0 and total_cost >= cost_limit:
            return False, "Monthly cost limit reached", {
                'current_cost': total_cost,
                'limit': cost_limit,
                'usage_percentage': 100.0
            }

        # Calculate usage percentages for warnings. The divisors are already
        # known to be > 0, so fractional cost limits are handled correctly.
        call_usage_pct = (current_calls / call_limit) * 100 if call_limit > 0 else 0
        cost_usage_pct = (total_cost / cost_limit) * 100 if cost_limit > 0 else 0

        return True, "Within limits", {
            'current_calls': current_calls,
            'call_limit': call_limit,
            'call_usage_percentage': call_usage_pct,
            'current_cost': total_cost,
            'cost_limit': cost_limit,
            'cost_usage_percentage': cost_usage_pct
        }

    def estimate_tokens(self, text: str, provider: APIProvider) -> int:
        """Estimate the token count of ``text`` from its whitespace-split word count.

        Uses the ``tokens_per_word`` ratio of the first active pricing row of
        ``provider`` when one is set, otherwise roughly 1.3 tokens per word.
        Empty or ``None`` text yields 0.
        """

        # Get pricing info for token estimation
        pricing = self.db.query(APIProviderPricing).filter(
            APIProviderPricing.provider == provider,
            APIProviderPricing.is_active == True  # noqa: E712 - SQLAlchemy column expression
        ).first()

        tokens_per_word = self._DEFAULT_TOKENS_PER_WORD
        if pricing and pricing.tokens_per_word:
            # Use provider-specific conversion
            tokens_per_word = pricing.tokens_per_word

        word_count = len((text or "").split())
        return int(word_count * tokens_per_word)

    def get_pricing_info(self, provider: APIProvider, model_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
        """Return pricing information for a provider, optionally narrowed to a model.

        Returns ``None`` when no active pricing row matches.
        """

        query = self.db.query(APIProviderPricing).filter(
            APIProviderPricing.provider == provider,
            APIProviderPricing.is_active == True  # noqa: E712 - SQLAlchemy column expression
        )

        if model_name:
            query = query.filter(APIProviderPricing.model_name == model_name)

        pricing = query.first()
        if not pricing:
            return None

        return {
            'provider': pricing.provider.value,
            'model_name': pricing.model_name,
            'cost_per_input_token': pricing.cost_per_input_token,
            'cost_per_output_token': pricing.cost_per_output_token,
            'cost_per_request': pricing.cost_per_request,
            'cost_per_search': pricing.cost_per_search,
            'cost_per_image': pricing.cost_per_image,
            'cost_per_page': pricing.cost_per_page,
            'description': pricing.description
        }
|
||||
428
backend/services/subscription_exception_handler.py
Normal file
428
backend/services/subscription_exception_handler.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""
|
||||
Comprehensive Exception Handling and Logging for Subscription System
|
||||
Provides robust error handling, logging, and monitoring for the usage-based subscription system.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional, Union, List
|
||||
from enum import Enum
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from models.subscription_models import APIProvider, UsageAlert
|
||||
|
||||
class SubscriptionErrorType(Enum):
    """Categories of errors raised or classified by the subscription system."""

    # Limits and accounting
    USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
    PRICING_ERROR = "pricing_error"
    TRACKING_ERROR = "tracking_error"

    # Infrastructure and upstream services
    DATABASE_ERROR = "database_error"
    API_PROVIDER_ERROR = "api_provider_error"

    # Account-level problems
    AUTHENTICATION_ERROR = "authentication_error"
    BILLING_ERROR = "billing_error"

    # Everything else
    CONFIGURATION_ERROR = "configuration_error"
|
||||
|
||||
class SubscriptionErrorSeverity(Enum):
    """Severity levels attached to subscription errors, from least to most severe."""

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
|
||||
|
||||
class SubscriptionException(Exception):
    """Base exception for subscription system errors.

    Carries structured metadata (type, severity, user, provider, context and
    an optional underlying error) alongside the message, plus the UTC time
    at which the exception object was created.
    """

    def __init__(
        self,
        message: str,
        error_type: SubscriptionErrorType,
        severity: SubscriptionErrorSeverity = SubscriptionErrorSeverity.MEDIUM,
        user_id: Optional[str] = None,
        provider: Optional[APIProvider] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """Store the error metadata; ``context`` defaults to an empty dict."""
        self.message = message
        self.error_type = error_type
        self.severity = severity
        self.user_id = user_id
        self.provider = provider
        self.context = context or {}
        self.original_error = original_error
        self.timestamp = datetime.utcnow()

        super().__init__(message)

    def to_dict(self) -> Dict[str, Any]:
        """Convert the exception to a dictionary for logging/storage.

        ``original_error`` and ``traceback`` are ``None`` when there is no
        underlying error.
        """
        cause = self.original_error
        if cause is None:
            cause_text = None
            cause_traceback = None
        else:
            cause_text = str(cause)
            # Format the traceback of the wrapped error itself.
            # traceback.format_exc() would report whatever exception happens
            # to be handled at call time (or "NoneType: None" outside an
            # except block) rather than the original error.
            cause_traceback = "".join(
                traceback.format_exception(type(cause), cause, cause.__traceback__)
            )

        return {
            "message": self.message,
            "error_type": self.error_type.value,
            "severity": self.severity.value,
            "user_id": self.user_id,
            "provider": self.provider.value if self.provider else None,
            "context": self.context,
            "timestamp": self.timestamp.isoformat(),
            "original_error": cause_text,
            "traceback": cause_traceback
        }
|
||||
|
||||
class UsageLimitExceededException(SubscriptionException):
    """Exception raised when usage limits are exceeded.

    Always has error type USAGE_LIMIT_EXCEEDED and severity HIGH. The limit
    details are added to the exception's context.
    """

    def __init__(
        self,
        message: str,
        user_id: str,
        provider: APIProvider,
        limit_type: str,
        current_usage: Union[int, float],
        limit_value: Union[int, float],
        context: Optional[Dict[str, Any]] = None
    ):
        """Build the exception.

        Args:
            message: Human-readable error message.
            user_id: User whose limit was exceeded.
            provider: Provider the limit applies to.
            limit_type: Name of the limit that was hit.
            current_usage: Usage at the time of the error.
            limit_value: Configured limit.
            context: Extra context; copied, never modified.
        """
        # A positive limit is used as-is so fractional limits (e.g. a cost
        # limit below 1) give a correct percentage; 1 only guards against
        # division by zero for limits <= 0.
        divisor = limit_value if limit_value > 0 else 1

        # Work on a copy so the caller's dict is not mutated.
        merged_context = dict(context or {})
        merged_context.update({
            "limit_type": limit_type,
            "current_usage": current_usage,
            "limit_value": limit_value,
            "usage_percentage": (current_usage / divisor) * 100
        })

        super().__init__(
            message=message,
            error_type=SubscriptionErrorType.USAGE_LIMIT_EXCEEDED,
            severity=SubscriptionErrorSeverity.HIGH,
            user_id=user_id,
            provider=provider,
            context=merged_context
        )
|
||||
|
||||
class PricingException(SubscriptionException):
    """Exception raised for pricing calculation errors.

    Always has error type PRICING_ERROR and severity MEDIUM.
    """

    def __init__(
        self,
        message: str,
        provider: Optional[APIProvider] = None,
        model_name: Optional[str] = None,
        context: Optional[Dict[str, Any]] = None,
        original_error: Optional[Exception] = None
    ):
        """Build the exception.

        Args:
            message: Human-readable error message.
            provider: Provider the pricing error relates to, if any.
            model_name: Added to the context as ``model_name`` when truthy.
            context: Extra context; copied, never modified.
            original_error: Underlying exception, if any.
        """
        # Work on a copy so the caller's dict is not mutated.
        merged_context = dict(context or {})
        if model_name:
            merged_context["model_name"] = model_name

        super().__init__(
            message=message,
            error_type=SubscriptionErrorType.PRICING_ERROR,
            severity=SubscriptionErrorSeverity.MEDIUM,
            provider=provider,
            context=merged_context,
            original_error=original_error
        )
|
||||
|
||||
class TrackingException(SubscriptionException):
    """Exception raised for usage tracking errors.

    Always has error type TRACKING_ERROR and severity MEDIUM; every other
    argument is forwarded unchanged to ``SubscriptionException``.
    """

    def __init__(
        self,
        message: str,
        user_id: str = None,
        provider: APIProvider = None,
        context: Dict[str, Any] = None,
        original_error: Exception = None
    ):
        """Forward the arguments to the base class with the fixed type and severity."""
        fixed_fields = {
            "error_type": SubscriptionErrorType.TRACKING_ERROR,
            "severity": SubscriptionErrorSeverity.MEDIUM,
        }
        super().__init__(
            message=message,
            user_id=user_id,
            provider=provider,
            context=context,
            original_error=original_error,
            **fixed_fields
        )
|
||||
|
||||
class SubscriptionExceptionHandler:
    """Comprehensive exception handler for the subscription system.

    Logs errors through loguru, stores HIGH/CRITICAL errors as ``UsageAlert``
    rows when a database session and user id are available, and formats an
    API-style error response.
    """

    # Set once the log sinks are registered. Handlers are created per call by
    # the module-level helpers, so without this guard every instantiation
    # would add another pair of sinks and duplicate every log line.
    _logging_configured = False

    def __init__(self, db: Optional[Session] = None):
        """Store the optional database session and make sure logging is set up."""
        self.db = db
        self._setup_logging()

    @staticmethod
    def _is_subscription_record(record) -> bool:
        """Return True when the record's logger name mentions "subscription"."""
        # record["name"] can be None, so fall back to an empty string.
        return "subscription" in (record["name"] or "").lower()

    @staticmethod
    def _is_usage_tracking_record(record) -> bool:
        """Return True for usage-tracking log messages.

        Accepts both "usage_tracking" and "usage tracking": usage events in
        this module are logged as "Usage Tracking: <action>", which the
        underscore-only check never matched.
        """
        text = str(record["message"]).lower()
        return "usage_tracking" in text or "usage tracking" in text

    def _setup_logging(self):
        """Register the subscription log sinks once per process."""
        if SubscriptionExceptionHandler._logging_configured:
            return

        # Configure loguru for subscription-specific logging
        logger.add(
            "logs/subscription_errors.log",
            rotation="1 day",
            retention="30 days",
            level="ERROR",
            format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
            filter=self._is_subscription_record
        )

        logger.add(
            "logs/usage_tracking.log",
            rotation="1 day",
            retention="90 days",
            level="INFO",
            format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {message}",
            filter=self._is_usage_tracking_record
        )

        SubscriptionExceptionHandler._logging_configured = True

    def handle_exception(
        self,
        error: Union[Exception, SubscriptionException],
        context: Optional[Dict[str, Any]] = None,
        log_level: str = "error"
    ) -> Dict[str, Any]:
        """Handle and log a subscription system exception.

        Args:
            error: The exception; anything that is not a
                ``SubscriptionException`` is wrapped in one.
            context: Extra data merged into the logged error data.
            log_level: "critical", "error" or "warning"; any other value
                logs at info level.

        Returns:
            The formatted error response dictionary.
        """

        context = context or {}

        # Convert regular exceptions to SubscriptionException
        if not isinstance(error, SubscriptionException):
            error = SubscriptionException(
                message=str(error),
                error_type=self._classify_error(error),
                severity=self._determine_severity(error),
                context=context,
                original_error=error
            )

        # Log the error
        error_data = error.to_dict()
        error_data.update(context)

        # Attach the structured data with bind() and pass the message as a
        # positional argument: loguru runs str.format on the message whenever
        # arguments are supplied, so an interpolated message containing
        # braces (e.g. JSON) would raise inside the logging call.
        bound_logger = logger.bind(error_data=error_data)
        log_methods = {
            "critical": bound_logger.critical,
            "error": bound_logger.error,
            "warning": bound_logger.warning,
        }
        log_methods.get(log_level, bound_logger.info)("Subscription Error: {}", error.message)

        # Store critical errors in database for alerting
        if error.severity in [SubscriptionErrorSeverity.HIGH, SubscriptionErrorSeverity.CRITICAL]:
            self._store_error_alert(error)

        # Return formatted error response
        return self._format_error_response(error)

    def _classify_error(self, error: Exception) -> SubscriptionErrorType:
        """Classify an exception by keywords in its message and type name.

        The first matching rule wins; anything unmatched is a
        CONFIGURATION_ERROR.
        """

        error_str = str(error).lower()
        error_type_name = type(error).__name__.lower()

        if "limit" in error_str or "exceeded" in error_str:
            return SubscriptionErrorType.USAGE_LIMIT_EXCEEDED
        elif "pricing" in error_str or "cost" in error_str:
            return SubscriptionErrorType.PRICING_ERROR
        elif "tracking" in error_str or "usage" in error_str:
            return SubscriptionErrorType.TRACKING_ERROR
        elif "database" in error_str or "sql" in error_type_name:
            return SubscriptionErrorType.DATABASE_ERROR
        elif "api" in error_str or "provider" in error_str:
            return SubscriptionErrorType.API_PROVIDER_ERROR
        elif "auth" in error_str or "permission" in error_str:
            return SubscriptionErrorType.AUTHENTICATION_ERROR
        elif "billing" in error_str or "payment" in error_str:
            return SubscriptionErrorType.BILLING_ERROR
        else:
            return SubscriptionErrorType.CONFIGURATION_ERROR

    def _determine_severity(self, error: Exception) -> SubscriptionErrorSeverity:
        """Determine the severity of an error from its type and message."""

        error_str = str(error).lower()

        # Critical errors
        if isinstance(error, (SQLAlchemyError, ConnectionError)):
            return SubscriptionErrorSeverity.CRITICAL

        # High severity errors
        if "limit exceeded" in error_str or "unauthorized" in error_str:
            return SubscriptionErrorSeverity.HIGH

        # Medium severity errors
        if "pricing" in error_str or "tracking" in error_str:
            return SubscriptionErrorSeverity.MEDIUM

        # Default to low
        return SubscriptionErrorSeverity.LOW

    def _store_error_alert(self, error: SubscriptionException):
        """Store the error as a ``UsageAlert`` row; failures are logged, not raised.

        Does nothing when there is no database session or no user id.
        """

        if not self.db or not error.user_id:
            return

        try:
            alert = UsageAlert(
                user_id=error.user_id,
                alert_type="system_error",
                threshold_percentage=0,
                provider=error.provider,
                title=f"System Error: {error.error_type.value}",
                message=error.message,
                severity=error.severity.value,
                billing_period=datetime.now().strftime("%Y-%m")
            )

            self.db.add(alert)
            self.db.commit()

        except Exception as e:
            logger.error(f"Failed to store error alert: {e}")
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; keep this best-effort so error reporting itself
            # never raises.
            try:
                self.db.rollback()
            except Exception as rollback_error:
                logger.error(f"Failed to roll back after alert error: {rollback_error}")

    def _format_error_response(self, error: SubscriptionException) -> Dict[str, Any]:
        """Format the error as an API response dictionary.

        Context entries whose key is exactly "password", "token", "key" or
        "secret" are left out of the response.
        """

        response = {
            "success": False,
            "error": {
                "type": error.error_type.value,
                "message": error.message,
                "severity": error.severity.value,
                "timestamp": error.timestamp.isoformat()
            }
        }

        # Add context for debugging (non-sensitive info only)
        if error.context:
            sensitive_keys = {"password", "token", "key", "secret"}
            response["error"]["context"] = {
                k: v for k, v in error.context.items()
                if k not in sensitive_keys
            }

        # Add user-friendly message based on error type
        user_messages = {
            SubscriptionErrorType.USAGE_LIMIT_EXCEEDED:
                "You have reached your usage limit. Please upgrade your plan or wait for the next billing cycle.",
            SubscriptionErrorType.PRICING_ERROR:
                "There was an issue calculating the cost for this request. Please try again.",
            SubscriptionErrorType.TRACKING_ERROR:
                "Unable to track usage for this request. Please contact support if this persists.",
            SubscriptionErrorType.DATABASE_ERROR:
                "A database error occurred. Please try again later.",
            SubscriptionErrorType.API_PROVIDER_ERROR:
                "There was an issue with the API provider. Please try again.",
            SubscriptionErrorType.AUTHENTICATION_ERROR:
                "Authentication failed. Please check your credentials.",
            SubscriptionErrorType.BILLING_ERROR:
                "There was a billing-related error. Please contact support.",
            SubscriptionErrorType.CONFIGURATION_ERROR:
                "System configuration error. Please contact support."
        }

        response["error"]["user_message"] = user_messages.get(
            error.error_type,
            "An unexpected error occurred. Please try again or contact support."
        )

        return response
|
||||
|
||||
# Utility functions for common error scenarios
|
||||
def handle_usage_limit_error(
    user_id: str,
    provider: APIProvider,
    limit_type: str,
    current_usage: Union[int, float],
    limit_value: Union[int, float],
    db: Session = None
) -> Dict[str, Any]:
    """Build a ``UsageLimitExceededException`` and handle it at warning level.

    Returns the formatted error response produced by the handler.
    """
    exception_handler = SubscriptionExceptionHandler(db)
    limit_error = UsageLimitExceededException(
        message=f"Usage limit exceeded for {limit_type}",
        user_id=user_id,
        provider=provider,
        limit_type=limit_type,
        current_usage=current_usage,
        limit_value=limit_value
    )
    response = exception_handler.handle_exception(limit_error, log_level="warning")
    return response
|
||||
|
||||
def handle_pricing_error(
    message: str,
    provider: APIProvider = None,
    model_name: str = None,
    original_error: Exception = None,
    db: Session = None
) -> Dict[str, Any]:
    """Build a ``PricingException`` and handle it at the default log level.

    Returns the formatted error response produced by the handler.
    """
    exception_handler = SubscriptionExceptionHandler(db)
    pricing_error = PricingException(
        message=message,
        provider=provider,
        model_name=model_name,
        original_error=original_error
    )
    response = exception_handler.handle_exception(pricing_error)
    return response
|
||||
|
||||
def handle_tracking_error(
    message: str,
    user_id: str = None,
    provider: APIProvider = None,
    original_error: Exception = None,
    db: Session = None
) -> Dict[str, Any]:
    """Build a ``TrackingException`` and handle it at the default log level.

    Returns the formatted error response produced by the handler.
    """
    exception_handler = SubscriptionExceptionHandler(db)
    tracking_error = TrackingException(
        message=message,
        user_id=user_id,
        provider=provider,
        original_error=original_error
    )
    response = exception_handler.handle_exception(tracking_error)
    return response
|
||||
|
||||
def log_usage_event(
    user_id: str,
    provider: APIProvider,
    action: str,
    details: Optional[Dict[str, Any]] = None
):
    """Log a usage event at info level for monitoring and debugging.

    Args:
        user_id: User the event belongs to.
        provider: Provider involved; its ``value`` is logged.
        action: Short action name, included in the log message.
        details: Extra fields merged into the logged data (entries here
            override the standard fields of the same name).
    """
    log_data = {
        "user_id": user_id,
        "provider": provider.value,
        "action": action,
        "timestamp": datetime.utcnow().isoformat(),
        **(details or {})
    }

    # Attach the structured data with bind() and pass the action as a
    # positional argument: loguru runs str.format on the message whenever
    # arguments are supplied, so an interpolated action containing braces
    # would raise inside the logging call.
    logger.bind(usage_data=log_data).info("Usage Tracking: {}", action)
|
||||
|
||||
# Decorator for automatic exception handling
|
||||
def handle_subscription_errors(db: Session = None):
    """Build a decorator that routes any exception through the subscription handler.

    The decorated function is called normally. If it raises any ``Exception``
    (``SubscriptionException`` included), the exception is passed to
    ``SubscriptionExceptionHandler(db).handle_exception`` and that call's
    result is returned in place of the function's own return value.

    Only the synchronous call is guarded: for a coroutine function the call
    merely creates the coroutine, so exceptions raised while it is awaited
    are not intercepted here.

    Args:
        db: Session passed to ``SubscriptionExceptionHandler`` when an
            exception has to be handled.

    Returns:
        A decorator preserving the wrapped function's name, docstring and
        other metadata.
    """
    # Function-scope import keeps this block self-contained regardless of
    # what the module imports at the top.
    from functools import wraps

    def decorator(func):
        @wraps(func)  # keep __name__/__doc__ for logging, routing and introspection
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as exc:
                # Subscription-specific and generic exceptions were handled
                # identically, so a single branch covers both.
                handler = SubscriptionExceptionHandler(db)
                return handler.handle_exception(exc)

        return wrapper

    return decorator
|
||||
460
backend/services/usage_tracking_service.py
Normal file
460
backend/services/usage_tracking_service.py
Normal file
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
Usage Tracking Service
|
||||
Comprehensive tracking of API usage, costs, and subscription limits.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, Optional, List, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
from models.subscription_models import (
|
||||
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
|
||||
UserSubscription, UsageStatus
|
||||
)
|
||||
from services.pricing_service import PricingService
|
||||
|
||||
class UsageTrackingService:
    """Track API usage, costs and subscription limits for users.

    Every tracked call is stored as an ``APIUsageLog`` row and folded into the
    per-user, per-billing-period ``UsageSummary``. ``UsageAlert`` rows are
    created when a provider's call count crosses one of the alert thresholds.
    Billing periods are ``"YYYY-MM"`` strings.
    """

    # Percentages of a provider's call limit at which an alert is created.
    ALERT_THRESHOLDS = (80, 90, 100)

    def __init__(self, db: Session):
        """Bind the service to a session and create its ``PricingService``."""
        self.db = db
        self.pricing_service = PricingService(db)

    @staticmethod
    def _recent_billing_periods(months: int,
                                end_date: Optional[datetime] = None) -> List[str]:
        """Return the last ``months`` billing periods, oldest first.

        Periods are computed with calendar-month arithmetic so each month
        appears exactly once. Stepping back in 30-day increments can repeat
        or skip a month (e.g. starting from the 31st of March).

        Args:
            months: Number of periods to return; values <= 0 give ``[]``.
            end_date: Date inside the most recent period; defaults to
                ``datetime.now()``.
        """
        anchor = end_date or datetime.now()
        # Count months from year 0 so year roll-over is a plain divmod.
        newest_index = anchor.year * 12 + (anchor.month - 1)

        periods = []
        for offset in range(months - 1, -1, -1):
            year, month_zero_based = divmod(newest_index - offset, 12)
            periods.append(f"{year:04d}-{month_zero_based + 1:02d}")
        return periods

    @staticmethod
    def _next_error_rate(previous_rate: float, total_calls: int,
                         is_error: bool) -> float:
        """Return the error rate (percent) after one more call.

        Args:
            previous_rate: Error rate in percent before this call.
            total_calls: Call count including this call; must be > 0.
            is_error: Whether this call counts as an error.
        """
        # The stored rate is a float, so the reconstructed count can land a
        # hair below an integer (e.g. 0.9999999999999999); round() recovers
        # it where int() would truncate and drop an error.
        prior_errors = round(previous_rate * (total_calls - 1) / 100)
        error_count = prior_errors + (1 if is_error else 0)
        return (error_count / total_calls) * 100

    async def track_api_usage(self, user_id: str, provider: APIProvider,
                              endpoint: str, method: str, model_used: Optional[str] = None,
                              tokens_input: int = 0, tokens_output: int = 0,
                              response_time: float = 0.0, status_code: int = 200,
                              request_size: Optional[int] = None, response_size: Optional[int] = None,
                              user_agent: Optional[str] = None, ip_address: Optional[str] = None,
                              error_message: Optional[str] = None, retry_count: int = 0,
                              **kwargs) -> Dict[str, Any]:
        """Track an API usage event and update billing information.

        Computes the call's cost via the pricing service, stores an
        ``APIUsageLog`` row, updates the usage summary, checks alert
        thresholds and commits.

        Args:
            user_id: User who made the call.
            provider: Provider that was called.
            endpoint: Endpoint that was called.
            method: HTTP method of the call.
            model_used: Model name; when falsy, ``"<provider value>-default"``
                is used for pricing only.
            tokens_input: Input token count.
            tokens_output: Output token count.
            response_time: Response time of the call.
            status_code: HTTP status; values >= 400 count as errors.
            request_size: Stored on the log row as-is.
            response_size: Stored on the log row as-is.
            user_agent: Stored on the log row as-is.
            ip_address: Stored on the log row as-is.
            error_message: Stored on the log row as-is.
            retry_count: Stored on the log row as-is.
            **kwargs: Forwarded to ``PricingService.calculate_api_cost``.

        Returns:
            On success: ``usage_logged`` (True), ``cost``, ``tokens_used`` and
            ``billing_period``. On any exception the session is rolled back
            and ``{'usage_logged': False, 'error': <message>}`` is returned.
        """
        try:
            # Calculate costs
            cost_data = self.pricing_service.calculate_api_cost(
                provider=provider,
                model_name=model_used or f"{provider.value}-default",
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                request_count=1,
                **kwargs
            )

            # NOTE(review): the billing period uses local time while
            # summary.updated_at uses UTC - confirm this is intended.
            billing_period = datetime.now().strftime("%Y-%m")
            tokens_total = tokens_input + tokens_output

            # Create usage log entry
            usage_log = APIUsageLog(
                user_id=user_id,
                provider=provider,
                endpoint=endpoint,
                method=method,
                model_used=model_used,
                tokens_input=tokens_input,
                tokens_output=tokens_output,
                tokens_total=tokens_total,
                cost_input=cost_data['cost_input'],
                cost_output=cost_data['cost_output'],
                cost_total=cost_data['cost_total'],
                response_time=response_time,
                status_code=status_code,
                request_size=request_size,
                response_size=response_size,
                user_agent=user_agent,
                ip_address=ip_address,
                error_message=error_message,
                retry_count=retry_count,
                billing_period=billing_period
            )
            self.db.add(usage_log)

            # Update usage summary
            await self._update_usage_summary(
                user_id=user_id,
                provider=provider,
                tokens_used=tokens_total,
                cost=cost_data['cost_total'],
                billing_period=billing_period,
                response_time=response_time,
                is_error=status_code >= 400
            )

            # Check for usage alerts
            await self._check_usage_alerts(user_id, provider, billing_period)

            self.db.commit()

            logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")

            return {
                'usage_logged': True,
                'cost': cost_data['cost_total'],
                'tokens_used': tokens_total,
                'billing_period': billing_period
            }

        except Exception as e:
            # Tracking is best-effort: report the failure to the caller
            # instead of raising, and discard the partial transaction.
            logger.error(f"Error tracking API usage: {str(e)}")
            self.db.rollback()
            return {
                'usage_logged': False,
                'error': str(e)
            }

    async def _update_usage_summary(self, user_id: str, provider: APIProvider,
                                    tokens_used: int, cost: float, billing_period: str,
                                    response_time: float, is_error: bool):
        """Fold one call into the user's usage summary for ``billing_period``.

        Creates the summary if none exists, bumps the provider-specific and
        total counters, refreshes the running average response time and error
        rate, then re-evaluates the usage status.
        """
        # Get or create usage summary
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()

        if not summary:
            summary = UsageSummary(
                user_id=user_id,
                billing_period=billing_period
            )
            self.db.add(summary)

        # Counters are read with "or 0" because an attribute may still be
        # None on a summary that has not been flushed yet (presumably the
        # model relies on column defaults, which apply at INSERT time).
        provider_name = provider.value

        # Update provider-specific counters
        calls_attr = f"{provider_name}_calls"
        setattr(summary, calls_attr, (getattr(summary, calls_attr, 0) or 0) + 1)

        # Update token usage for LLM providers
        if provider in (APIProvider.GEMINI, APIProvider.OPENAI,
                        APIProvider.ANTHROPIC, APIProvider.MISTRAL):
            tokens_attr = f"{provider_name}_tokens"
            setattr(summary, tokens_attr, (getattr(summary, tokens_attr, 0) or 0) + tokens_used)

        # Update cost
        cost_attr = f"{provider_name}_cost"
        setattr(summary, cost_attr, (getattr(summary, cost_attr, 0) or 0) + cost)

        # Update totals
        summary.total_calls = (summary.total_calls or 0) + 1
        summary.total_tokens = (summary.total_tokens or 0) + tokens_used
        summary.total_cost = (summary.total_cost or 0) + cost

        # Update performance metrics
        if summary.total_calls > 0:
            previous_calls = summary.total_calls - 1

            # Running average: re-expand the old mean, add the new sample.
            total_response_time = (summary.avg_response_time or 0.0) * previous_calls + response_time
            summary.avg_response_time = total_response_time / summary.total_calls

            summary.error_rate = self._next_error_rate(
                summary.error_rate or 0.0, summary.total_calls, is_error
            )

        # Update usage status based on limits
        await self._update_usage_status(summary)

        summary.updated_at = datetime.utcnow()

    async def _update_usage_status(self, summary: UsageSummary):
        """Set ``summary.usage_status`` from the highest limit utilisation.

        Compares total cost against the monthly cost limit and each
        provider's call count against its call limit. Limits <= 0 are
        ignored. The status is left untouched when the pricing service
        returns no limits for the user.
        """
        limits = self.pricing_service.get_user_limits(summary.user_id)
        if not limits:
            return

        plan_limits = limits['limits']
        max_usage_percentage = 0.0

        # Check cost limit
        cost_limit = plan_limits.get('monthly_cost', 0) or 0
        if cost_limit > 0:
            cost_usage_pct = ((summary.total_cost or 0) / cost_limit) * 100
            max_usage_percentage = max(max_usage_percentage, cost_usage_pct)

        # Check call limits for each provider
        for provider in APIProvider:
            calls_key = f"{provider.value}_calls"
            current_calls = getattr(summary, calls_key, 0) or 0
            call_limit = plan_limits.get(calls_key, 0) or 0

            if call_limit > 0:
                call_usage_pct = (current_calls / call_limit) * 100
                max_usage_percentage = max(max_usage_percentage, call_usage_pct)

        # Update status based on highest usage percentage
        if max_usage_percentage >= 100:
            summary.usage_status = UsageStatus.LIMIT_REACHED
        elif max_usage_percentage >= 80:
            summary.usage_status = UsageStatus.WARNING
        else:
            summary.usage_status = UsageStatus.ACTIVE

    async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
        """Create alerts for each call-limit threshold the provider has crossed.

        At most one alert is created per user, provider, billing period and
        threshold. Does nothing when there is no summary, no limits, or the
        provider has no positive call limit.
        """
        # Get current usage
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()
        if not summary:
            return

        # Get user limits
        limits = self.pricing_service.get_user_limits(user_id)
        if not limits:
            return

        # These values do not depend on the threshold, so compute them once.
        calls_key = f"{provider.value}_calls"
        current_calls = getattr(summary, calls_key, 0) or 0
        call_limit = limits['limits'].get(calls_key, 0) or 0
        if call_limit <= 0:
            return

        usage_percentage = (current_calls / call_limit) * 100

        for threshold in self.ALERT_THRESHOLDS:
            if usage_percentage < threshold:
                continue

            # Any existing alert for this threshold suppresses a new one,
            # whether or not it has been sent yet. Alerts are created here
            # without setting is_sent, so requiring is_sent would re-create
            # the same alert on every call until something marks it sent.
            existing_alert = self.db.query(UsageAlert).filter(
                UsageAlert.user_id == user_id,
                UsageAlert.billing_period == billing_period,
                UsageAlert.threshold_percentage == threshold,
                UsageAlert.provider == provider
            ).first()
            if existing_alert:
                continue

            await self._create_usage_alert(
                user_id=user_id,
                provider=provider,
                threshold=threshold,
                current_usage=current_calls,
                limit=call_limit,
                billing_period=billing_period
            )

    async def _create_usage_alert(self, user_id: str, provider: APIProvider,
                                  threshold: int, current_usage: int, limit: int,
                                  billing_period: str):
        """Add a ``UsageAlert`` to the session for a crossed threshold.

        Severity is ``error`` at >= 100%, ``warning`` at >= 90% and ``info``
        below that. The alert is only added to the session; committing is
        left to the caller.
        """
        # Determine alert type and severity
        if threshold >= 100:
            alert_type = "limit_reached"
            severity = "error"
            title = f"API Limit Reached - {provider.value.title()}"
            message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
        elif threshold >= 90:
            alert_type = "usage_warning"
            severity = "warning"
            title = f"API Usage Warning - {provider.value.title()}"
            message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
        else:
            alert_type = "usage_warning"
            severity = "info"
            title = f"API Usage Notice - {provider.value.title()}"
            message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."

        alert = UsageAlert(
            user_id=user_id,
            alert_type=alert_type,
            threshold_percentage=threshold,
            provider=provider,
            title=title,
            message=message,
            severity=severity,
            billing_period=billing_period
        )

        self.db.add(alert)
        logger.info(f"Created usage alert for {user_id}: {title}")

    def get_user_usage_stats(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
        """Get comprehensive usage statistics for a user.

        Args:
            user_id: User to report on.
            billing_period: ``"YYYY-MM"`` period; defaults to the current
                local month.

        Returns:
            A dict with totals, limits, per-provider breakdown, up to ten
            unread alerts (newest first) and usage percentages. When no
            summary exists for the period, zeroed totals with empty
            breakdown, alerts and percentages are returned.
        """
        if not billing_period:
            billing_period = datetime.now().strftime("%Y-%m")

        # Get usage summary
        summary = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period == billing_period
        ).first()

        # Get user limits
        limits = self.pricing_service.get_user_limits(user_id)

        if not summary:
            # No usage this period
            return {
                'billing_period': billing_period,
                'usage_status': 'active',
                'total_calls': 0,
                'total_tokens': 0,
                'total_cost': 0.0,
                'limits': limits,
                'provider_breakdown': {},
                'alerts': [],
                'usage_percentages': {}
            }

        # Get recent alerts (only needed once we know there is usage)
        alerts = self.db.query(UsageAlert).filter(
            UsageAlert.user_id == user_id,
            UsageAlert.billing_period == billing_period,
            UsageAlert.is_read == False
        ).order_by(UsageAlert.created_at.desc()).limit(10).all()

        # Calculate usage percentages
        usage_percentages = {}
        if limits:
            plan_limits = limits['limits']

            for provider in APIProvider:
                calls_key = f"{provider.value}_calls"
                current_calls = getattr(summary, calls_key, 0) or 0
                call_limit = plan_limits.get(calls_key, 0) or 0

                if call_limit > 0:
                    usage_percentages[calls_key] = (current_calls / call_limit) * 100
                else:
                    usage_percentages[calls_key] = 0

            # Cost usage percentage
            cost_limit = plan_limits.get('monthly_cost', 0) or 0
            if cost_limit > 0:
                usage_percentages['cost'] = ((summary.total_cost or 0) / cost_limit) * 100
            else:
                usage_percentages['cost'] = 0

        # Provider breakdown
        provider_breakdown = {
            provider.value: {
                'calls': getattr(summary, f"{provider.value}_calls", 0),
                'tokens': getattr(summary, f"{provider.value}_tokens", 0),
                'cost': getattr(summary, f"{provider.value}_cost", 0.0)
            }
            for provider in APIProvider
        }

        return {
            'billing_period': billing_period,
            'usage_status': summary.usage_status.value,
            'total_calls': summary.total_calls,
            'total_tokens': summary.total_tokens,
            'total_cost': summary.total_cost,
            'avg_response_time': summary.avg_response_time,
            'error_rate': summary.error_rate,
            'limits': limits,
            'provider_breakdown': provider_breakdown,
            'alerts': [
                {
                    'id': alert.id,
                    'type': alert.alert_type,
                    'title': alert.title,
                    'message': alert.message,
                    'severity': alert.severity,
                    'created_at': alert.created_at.isoformat()
                }
                for alert in alerts
            ],
            'usage_percentages': usage_percentages,
            'last_updated': summary.updated_at.isoformat()
        }

    def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
        """Get usage trends over the last ``months`` billing periods.

        Returns:
            A dict with ``periods`` (oldest first) and parallel lists
            ``total_calls``, ``total_cost`` and ``total_tokens``, plus
            ``provider_trends`` mapping each provider value to ``calls``,
            ``cost`` and ``tokens`` lists. Periods without a summary are
            reported as zeros.
        """
        # Calculate billing periods (oldest first, one entry per month)
        periods = self._recent_billing_periods(months)

        # Get usage summaries for these periods
        summaries = self.db.query(UsageSummary).filter(
            UsageSummary.user_id == user_id,
            UsageSummary.billing_period.in_(periods)
        ).order_by(UsageSummary.billing_period).all()
        summary_by_period = {s.billing_period: s for s in summaries}

        # Create trends data
        trends = {
            'periods': periods,
            'total_calls': [],
            'total_cost': [],
            'total_tokens': [],
            'provider_trends': {}
        }

        for period in periods:
            summary = summary_by_period.get(period)
            has_data = summary is not None

            trends['total_calls'].append(summary.total_calls if has_data else 0)
            trends['total_cost'].append(summary.total_cost if has_data else 0.0)
            trends['total_tokens'].append(summary.total_tokens if has_data else 0)

            # Provider-specific trends
            for provider in APIProvider:
                provider_name = provider.value
                series = trends['provider_trends'].setdefault(
                    provider_name, {'calls': [], 'cost': [], 'tokens': []}
                )

                if has_data:
                    series['calls'].append(getattr(summary, f"{provider_name}_calls", 0))
                    series['cost'].append(getattr(summary, f"{provider_name}_cost", 0.0))
                    series['tokens'].append(getattr(summary, f"{provider_name}_tokens", 0))
                else:
                    # No data for this period
                    series['calls'].append(0)
                    series['cost'].append(0.0)
                    series['tokens'].append(0)

        return trends

    async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
                                   tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
        """Check usage limits before making an API call.

        Delegates to ``PricingService.check_usage_limits`` and returns its
        result unchanged.
        """
        return self.pricing_service.check_usage_limits(
            user_id=user_id,
            provider=provider,
            tokens_requested=tokens_requested
        )
|
||||
Reference in New Issue
Block a user