SEO Dashboard Fixes and content planning refactoring

ajaysi
2025-10-29 17:10:48 +05:30
parent 5866f49325
commit 4431cd9848
92 changed files with 7046 additions and 1940 deletions


@@ -0,0 +1,185 @@
# Subscription Services Package
## Overview
This package consolidates all subscription, billing, and usage tracking related services and middleware into a single, well-organized module. This follows the same architectural pattern as the onboarding package for consistency and maintainability.
## Package Structure
```
backend/services/subscription/
├── __init__.py # Package exports
├── pricing_service.py # API pricing and cost calculations
├── usage_tracking_service.py # Usage tracking and limits
├── exception_handler.py # Exception handling
├── monitoring_middleware.py # API monitoring with usage tracking
└── README.md # This documentation
```
## Services
### PricingService
- **File**: `pricing_service.py`
- **Purpose**: Manages API pricing, cost calculation, and subscription limits
- **Key Features**:
  - Dynamic pricing based on API provider and model
  - Cost calculation for input/output tokens
  - Subscription limit enforcement
  - Billing period management
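
The billing period management listed above can be pictured with a minimal sketch. It assumes fixed-length periods (30 days for monthly, 365 days for yearly) and a `YYYY-MM` period key, which is how `pricing_service.py` in this commit appears to behave. The function names are illustrative, not the package's API.

```python
from datetime import datetime, timedelta


def next_period_end(start: datetime, cycle: str) -> datetime:
    """Return the end of a billing period that begins at `start`.

    Assumption: "yearly" means 365 days and anything else means 30 days.
    """
    days = 365 if cycle == "yearly" else 30
    return start + timedelta(days=days)


def billing_period_key(moment: datetime) -> str:
    """Return the YYYY-MM key used to group usage summaries."""
    return moment.strftime("%Y-%m")


start = datetime(2025, 10, 29)
print(next_period_end(start, "monthly"))
print(billing_period_key(start))
```

Fixed-length periods drift against calendar months over time, so a calendar-aware calculation may be preferable where exact renewal dates matter.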
### UsageTrackingService
- **File**: `usage_tracking_service.py`
- **Purpose**: Comprehensive tracking of API usage, costs, and subscription limits
- **Key Features**:
  - Real-time usage tracking
  - Cost calculation and billing
  - Usage limit enforcement with TTL caching
  - Usage alerts and notifications
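
As a rough illustration of limit enforcement and alerting, the sketch below decides whether a request may proceed and whether a warning is due. The 80% warning threshold mirrors the check in `monitoring_middleware.py`. The function name and the "zero means unlimited" rule are assumptions made for this example.

```python
from typing import Tuple


def evaluate_usage(current: float, limit: float, warn_at: float = 80.0) -> Tuple[bool, bool, float]:
    """Return (can_proceed, should_warn, usage_percentage).

    Assumption: a limit of 0 or less means "unlimited".
    """
    if limit <= 0:
        return True, False, 0.0
    percentage = current * 100 / limit
    can_proceed = current < limit
    # Only warn while the user can still make requests.
    should_warn = can_proceed and percentage >= warn_at
    return can_proceed, should_warn, percentage


print(evaluate_usage(85, 100))
print(evaluate_usage(100, 100))
```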
### SubscriptionExceptionHandler
- **File**: `exception_handler.py`
- **Purpose**: Centralized exception handling for subscription-related errors
- **Key Features**:
  - Custom exception types
  - Error handling decorators
  - Consistent error responses
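
The decorator idea can be sketched in a few lines. This is a simplified stand-in rather than the package's `handle_subscription_errors`: `LimitExceeded`, `to_error_response`, and `generate` are hypothetical names, and the response shape only loosely follows the one produced in `exception_handler.py`.

```python
import functools
from typing import Any, Callable, Dict


class LimitExceeded(Exception):
    """Hypothetical stand-in for a subscription error type."""


def to_error_response(func: Callable[..., Any]) -> Callable[..., Any]:
    """Convert any exception raised by `func` into a uniform error dict."""

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            return {
                "success": False,
                "error": {"type": type(exc).__name__, "message": str(exc)},
            }

    return wrapper


@to_error_response
def generate(user_quota: int) -> Dict[str, Any]:
    if user_quota <= 0:
        raise LimitExceeded("Usage limit exceeded for api_calls")
    return {"success": True}


print(generate(0))
print(generate(5))
```

Returning a dict instead of re-raising keeps every caller on one response shape, at the cost of hiding the exception from outer handlers.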
### Monitoring Middleware
- **File**: `monitoring_middleware.py`
- **Purpose**: FastAPI middleware for API monitoring and usage tracking
- **Key Features**:
  - Request/response monitoring
  - Usage tracking integration
  - Performance metrics
  - Database API monitoring
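
One way the usage tracking integration can attribute a request to a provider is regex matching on the request path, skipping internal route families. The sketch below assumes a small pattern table keyed by plain strings; the middleware in this commit keys its table by the `APIProvider` enum and also inspects the user agent.

```python
import re
from typing import Dict, List, Optional

# Hypothetical pattern table for illustration only.
PROVIDER_PATTERNS: Dict[str, List[str]] = {
    "gemini": [r"gemini", r"google.*ai"],
    "openai": [r"openai", r"gpt"],
    "tavily": [r"tavily"],
}

# Internal route families that should never count as provider usage.
IGNORED_PREFIXES = ("/api/onboarding/", "/api/subscription/")


def detect_provider(path: str) -> Optional[str]:
    """Return the first provider whose pattern matches the path, else None."""
    lowered = path.lower()
    if lowered.startswith(IGNORED_PREFIXES):
        return None
    for provider, patterns in PROVIDER_PATTERNS.items():
        if any(re.search(pattern, lowered) for pattern in patterns):
            return provider
    return None


print(detect_provider("/api/gemini/generate"))
print(detect_provider("/api/subscription/usage"))
```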
## Usage
### Import Pattern
Always use the consolidated package for subscription-related imports:
```python
# ✅ Correct - Use consolidated package
from services.subscription import PricingService, UsageTrackingService
from services.subscription import SubscriptionExceptionHandler
from services.subscription import check_usage_limits_middleware
# ❌ Incorrect - Old scattered imports
from services.pricing_service import PricingService
from services.usage_tracking_service import UsageTrackingService
from middleware.monitoring_middleware import check_usage_limits_middleware
```
### Service Initialization
```python
from services.subscription import PricingService, UsageTrackingService
from services.database import get_db
# Get database session
db = next(get_db())
# Initialize services
pricing_service = PricingService(db)
usage_service = UsageTrackingService(db)
```
### Middleware Registration
```python
from fastapi import FastAPI
from services.subscription import monitoring_middleware

app = FastAPI()

# Register middleware in FastAPI app
app.middleware("http")(monitoring_middleware)
```
## Database Models
The subscription services use the following database models (defined in `backend/models/subscription_models.py`):
- `APIProvider` - API provider enumeration
- `SubscriptionPlan` - Subscription plan definitions
- `UserSubscription` - User subscription records
- `UsageSummary` - Usage summary by billing period
- `APIUsageLog` - Individual API usage logs
- `APIProviderPricing` - Pricing configuration
- `UsageAlert` - Usage limit alerts
- `SubscriptionTier` - Subscription tier definitions
- `BillingCycle` - Billing cycle enumeration
- `UsageStatus` - Usage status enumeration
## Key Features
### 1. Database-Only Persistence
- All data stored in database tables
- No file-based storage
- User-isolated data access
### 2. TTL Caching
- In-memory caching for performance
- 30-second TTL for usage limit checks
- 10-minute TTL for dashboard data
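
A TTL cache of this kind can be very small. The sketch below is an assumption about the general shape, not the package's implementation: `TTLCache` is a hypothetical class, and the `user_id:provider` key format follows the comment in `pricing_service.py`.

```python
import time
from typing import Any, Callable, Dict, Optional, Tuple


class TTLCache:
    """Tiny in-memory cache whose entries expire after `ttl_seconds`."""

    def __init__(self, ttl_seconds: float, clock: Callable[[], float] = time.monotonic) -> None:
        self._ttl = ttl_seconds
        self._clock = clock  # injectable so tests can control time
        self._entries: Dict[str, Tuple[float, Any]] = {}

    def get(self, key: str) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        expires_at, value = entry
        if self._clock() >= expires_at:
            del self._entries[key]  # drop the stale entry
            return None
        return value

    def set(self, key: str, value: Any) -> None:
        self._entries[key] = (self._clock() + self._ttl, value)


limits_cache = TTLCache(ttl_seconds=30)
limits_cache.set("user-1:gemini", (True, "ok"))
print(limits_cache.get("user-1:gemini"))
```

Because the cache lives in process memory, each worker process keeps its own copy, so a limit change can take up to one TTL to become visible everywhere.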
### 3. Real-time Monitoring
- Live API usage tracking
- Performance metrics collection
- Error rate monitoring
### 4. Flexible Pricing
- Per-provider pricing configuration
- Model-specific pricing
- Dynamic cost calculation
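
To make the per-provider, per-model pricing concrete, here is a minimal cost calculation. The prices mirror the `gemini-2.5-pro` defaults seeded in this commit; the `PRICING` table, the `calculate_cost` name, and the six-decimal rounding are assumptions made for the example.

```python
from decimal import Decimal, ROUND_HALF_UP

# Per-token prices in USD, keyed by (provider, model).
PRICING = {
    ("gemini", "gemini-2.5-pro"): {
        "input": Decimal("0.00000125"),  # $1.25 per 1M input tokens
        "output": Decimal("0.00001"),    # $10.00 per 1M output tokens
    },
}


def calculate_cost(provider: str, model: str, tokens_in: int, tokens_out: int) -> Decimal:
    """Return the request cost in USD, rounded to 6 decimal places."""
    price = PRICING[(provider, model)]
    total = price["input"] * tokens_in + price["output"] * tokens_out
    return total.quantize(Decimal("0.000001"), rounding=ROUND_HALF_UP)


print(calculate_cost("gemini", "gemini-2.5-pro", 1000, 500))
```

`Decimal` avoids the rounding noise that binary floats introduce when many tiny per-token amounts are summed.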
## Error Handling
The package provides comprehensive error handling:
```python
from services.subscription import (
    SubscriptionException,
    UsageLimitExceededException,
    PricingException,
    TrackingException,
)

try:
    # Subscription operation
    pass
except UsageLimitExceededException as e:
    # Handle usage limit exceeded
    pass
except PricingException as e:
    # Handle pricing error
    pass
## Configuration
The services use environment variables for configuration:
- `SUBSCRIPTION_DASHBOARD_NOCACHE` - Bypass dashboard cache
- `ENABLE_ALPHA` - Enable alpha features (default: false)
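
A small helper can read these flags. How the services actually parse the values is not shown in this README, so the accepted truthy strings below are an assumption, and `env_flag` is a hypothetical helper.

```python
import os

_TRUTHY = {"1", "true", "yes", "on"}


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment.

    Assumption: "1", "true", "yes" and "on" (any case) count as enabled.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in _TRUTHY


ENABLE_ALPHA = env_flag("ENABLE_ALPHA", default=False)
DASHBOARD_NOCACHE = env_flag("SUBSCRIPTION_DASHBOARD_NOCACHE")
print(ENABLE_ALPHA, DASHBOARD_NOCACHE)
```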
## Migration from Old Structure
This package consolidates the following previously scattered files:
- `backend/services/pricing_service.py` → `subscription/pricing_service.py`
- `backend/services/usage_tracking_service.py` → `subscription/usage_tracking_service.py`
- `backend/services/subscription_exception_handler.py` → `subscription/exception_handler.py`
- `backend/middleware/monitoring_middleware.py` → `subscription/monitoring_middleware.py`
## Benefits
1. **Single Package**: All subscription logic in one location
2. **Clear Ownership**: Easy to find subscription-related code
3. **Better Organization**: Follows same pattern as onboarding
4. **Easier Maintenance**: Single source of truth for billing logic
5. **Consistent Architecture**: Matches onboarding consolidation
## Related Packages
- `services.onboarding` - Onboarding and user setup
- `models.subscription_models` - Database models
- `api.subscription_api` - API endpoints


@@ -0,0 +1,40 @@
# Subscription Services Package
# Consolidated subscription-related services and middleware
from .pricing_service import PricingService
from .usage_tracking_service import UsageTrackingService
from .exception_handler import (
SubscriptionException,
SubscriptionExceptionHandler,
UsageLimitExceededException,
PricingException,
TrackingException,
handle_usage_limit_error,
handle_pricing_error,
handle_tracking_error,
)
from .monitoring_middleware import (
DatabaseAPIMonitor,
check_usage_limits_middleware,
monitoring_middleware,
get_monitoring_stats,
get_lightweight_stats,
)
__all__ = [
"PricingService",
"UsageTrackingService",
"SubscriptionException",
"SubscriptionExceptionHandler",
"UsageLimitExceededException",
"PricingException",
"TrackingException",
"handle_usage_limit_error",
"handle_pricing_error",
"handle_tracking_error",
"DatabaseAPIMonitor",
"check_usage_limits_middleware",
"monitoring_middleware",
"get_monitoring_stats",
"get_lightweight_stats",
]


@@ -0,0 +1,412 @@
"""
Comprehensive Exception Handling and Logging for Subscription System
Provides robust error handling, logging, and monitoring for the usage-based subscription system.
"""
import traceback
import json
from datetime import datetime
from typing import Dict, Any, Optional, Union, List
from enum import Enum
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from models.subscription_models import APIProvider, UsageAlert
class SubscriptionErrorType(Enum):
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
PRICING_ERROR = "pricing_error"
TRACKING_ERROR = "tracking_error"
DATABASE_ERROR = "database_error"
API_PROVIDER_ERROR = "api_provider_error"
AUTHENTICATION_ERROR = "authentication_error"
BILLING_ERROR = "billing_error"
CONFIGURATION_ERROR = "configuration_error"
class SubscriptionErrorSeverity(Enum):
LOW = "low"
MEDIUM = "medium"
HIGH = "high"
CRITICAL = "critical"
class SubscriptionException(Exception):
"""Base exception for subscription system errors."""
def __init__(
self,
message: str,
error_type: SubscriptionErrorType,
severity: SubscriptionErrorSeverity = SubscriptionErrorSeverity.MEDIUM,
user_id: str = None,
provider: APIProvider = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
self.message = message
self.error_type = error_type
self.severity = severity
self.user_id = user_id
self.provider = provider
self.context = context or {}
self.original_error = original_error
self.timestamp = datetime.utcnow()
super().__init__(message)
def to_dict(self) -> Dict[str, Any]:
"""Convert exception to dictionary for logging/storage."""
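return {
"message": self.message,
"error_type": self.error_type.value,
"severity": self.severity.value,
"user_id": self.user_id,
"provider": self.provider.value if self.provider else None,
"context": self.context,
"timestamp": self.timestamp.isoformat(),
"original_error": str(self.original_error) if self.original_error else None,
# Format the stored exception's own traceback; format_exc() only reflects
# the exception currently being handled and is meaningless outside an except block.
"traceback": "".join(traceback.format_exception(
type(self.original_error), self.original_error, self.original_error.__traceback__
)) if self.original_error else None
}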
class UsageLimitExceededException(SubscriptionException):
"""Exception raised when usage limits are exceeded."""
def __init__(
self,
message: str,
user_id: str,
provider: APIProvider,
limit_type: str,
current_usage: Union[int, float],
limit_value: Union[int, float],
context: Dict[str, Any] = None
):
context = context or {}
context.update({
"limit_type": limit_type,
"current_usage": current_usage,
"limit_value": limit_value,
"usage_percentage": (current_usage / max(limit_value, 1)) * 100
})
super().__init__(
message=message,
error_type=SubscriptionErrorType.USAGE_LIMIT_EXCEEDED,
severity=SubscriptionErrorSeverity.HIGH,
user_id=user_id,
provider=provider,
context=context
)
class PricingException(SubscriptionException):
"""Exception raised for pricing calculation errors."""
def __init__(
self,
message: str,
provider: APIProvider = None,
model_name: str = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
context = context or {}
if model_name:
context["model_name"] = model_name
super().__init__(
message=message,
error_type=SubscriptionErrorType.PRICING_ERROR,
severity=SubscriptionErrorSeverity.MEDIUM,
provider=provider,
context=context,
original_error=original_error
)
class TrackingException(SubscriptionException):
"""Exception raised for usage tracking errors."""
def __init__(
self,
message: str,
user_id: str = None,
provider: APIProvider = None,
context: Dict[str, Any] = None,
original_error: Exception = None
):
super().__init__(
message=message,
error_type=SubscriptionErrorType.TRACKING_ERROR,
severity=SubscriptionErrorSeverity.MEDIUM,
user_id=user_id,
provider=provider,
context=context,
original_error=original_error
)
class SubscriptionExceptionHandler:
"""Comprehensive exception handler for the subscription system."""
def __init__(self, db: Session = None):
self.db = db
# Keep the service logger; the return value was previously discarded
self.logger = self._setup_logging()
def _setup_logging(self):
"""Setup structured logging for subscription errors."""
from utils.logger_utils import get_service_logger
return get_service_logger("subscription_exception_handler")
def handle_exception(
self,
error: Union[Exception, SubscriptionException],
context: Dict[str, Any] = None,
log_level: str = "error"
) -> Dict[str, Any]:
"""Handle and log subscription system exceptions."""
context = context or {}
# Convert regular exceptions to SubscriptionException
if not isinstance(error, SubscriptionException):
error = SubscriptionException(
message=str(error),
error_type=self._classify_error(error),
severity=self._determine_severity(error),
context=context,
original_error=error
)
# Log the error
error_data = error.to_dict()
error_data.update(context)
log_message = f"Subscription Error: {error.message}"
# Attach structured data with bind(): passing keyword arguments to logger.error()
# makes loguru run str.format() on the message, which fails if it contains braces.
level = log_level.upper() if log_level in ("critical", "error", "warning") else "INFO"
logger.bind(error_data=error_data).log(level, log_message)
# Store critical errors in database for alerting
if error.severity in [SubscriptionErrorSeverity.HIGH, SubscriptionErrorSeverity.CRITICAL]:
self._store_error_alert(error)
# Return formatted error response
return self._format_error_response(error)
def _classify_error(self, error: Exception) -> SubscriptionErrorType:
"""Classify an exception into a subscription error type."""
error_str = str(error).lower()
error_type_name = type(error).__name__.lower()
if "limit" in error_str or "exceeded" in error_str:
return SubscriptionErrorType.USAGE_LIMIT_EXCEEDED
elif "pricing" in error_str or "cost" in error_str:
return SubscriptionErrorType.PRICING_ERROR
elif "tracking" in error_str or "usage" in error_str:
return SubscriptionErrorType.TRACKING_ERROR
elif "database" in error_str or "sql" in error_type_name:
return SubscriptionErrorType.DATABASE_ERROR
elif "api" in error_str or "provider" in error_str:
return SubscriptionErrorType.API_PROVIDER_ERROR
elif "auth" in error_str or "permission" in error_str:
return SubscriptionErrorType.AUTHENTICATION_ERROR
elif "billing" in error_str or "payment" in error_str:
return SubscriptionErrorType.BILLING_ERROR
else:
return SubscriptionErrorType.CONFIGURATION_ERROR
def _determine_severity(self, error: Exception) -> SubscriptionErrorSeverity:
"""Determine the severity of an error."""
error_str = str(error).lower()
error_type = type(error)
# Critical errors
if isinstance(error, (SQLAlchemyError, ConnectionError)):
return SubscriptionErrorSeverity.CRITICAL
# High severity errors
if "limit exceeded" in error_str or "unauthorized" in error_str:
return SubscriptionErrorSeverity.HIGH
# Medium severity errors
if "pricing" in error_str or "tracking" in error_str:
return SubscriptionErrorSeverity.MEDIUM
# Default to low
return SubscriptionErrorSeverity.LOW
def _store_error_alert(self, error: SubscriptionException):
"""Store critical errors as alerts in the database."""
if not self.db or not error.user_id:
return
try:
alert = UsageAlert(
user_id=error.user_id,
alert_type="system_error",
threshold_percentage=0,
provider=error.provider,
title=f"System Error: {error.error_type.value}",
message=error.message,
severity=error.severity.value,
billing_period=datetime.now().strftime("%Y-%m")
)
self.db.add(alert)
self.db.commit()
except Exception as e:
# Roll back so a failed insert does not leave the session unusable
self.db.rollback()
logger.error(f"Failed to store error alert: {e}")
def _format_error_response(self, error: SubscriptionException) -> Dict[str, Any]:
"""Format error for API response."""
response = {
"success": False,
"error": {
"type": error.error_type.value,
"message": error.message,
"severity": error.severity.value,
"timestamp": error.timestamp.isoformat()
}
}
# Add context for debugging (non-sensitive info only)
if error.context:
safe_context = {
k: v for k, v in error.context.items()
if k not in ["password", "token", "key", "secret"]
}
response["error"]["context"] = safe_context
# Add user-friendly message based on error type
user_messages = {
SubscriptionErrorType.USAGE_LIMIT_EXCEEDED:
"You have reached your usage limit. Please upgrade your plan or wait for the next billing cycle.",
SubscriptionErrorType.PRICING_ERROR:
"There was an issue calculating the cost for this request. Please try again.",
SubscriptionErrorType.TRACKING_ERROR:
"Unable to track usage for this request. Please contact support if this persists.",
SubscriptionErrorType.DATABASE_ERROR:
"A database error occurred. Please try again later.",
SubscriptionErrorType.API_PROVIDER_ERROR:
"There was an issue with the API provider. Please try again.",
SubscriptionErrorType.AUTHENTICATION_ERROR:
"Authentication failed. Please check your credentials.",
SubscriptionErrorType.BILLING_ERROR:
"There was a billing-related error. Please contact support.",
SubscriptionErrorType.CONFIGURATION_ERROR:
"System configuration error. Please contact support."
}
response["error"]["user_message"] = user_messages.get(
error.error_type,
"An unexpected error occurred. Please try again or contact support."
)
return response
# Utility functions for common error scenarios
def handle_usage_limit_error(
user_id: str,
provider: APIProvider,
limit_type: str,
current_usage: Union[int, float],
limit_value: Union[int, float],
db: Session = None
) -> Dict[str, Any]:
"""Handle usage limit exceeded errors."""
handler = SubscriptionExceptionHandler(db)
error = UsageLimitExceededException(
message=f"Usage limit exceeded for {limit_type}",
user_id=user_id,
provider=provider,
limit_type=limit_type,
current_usage=current_usage,
limit_value=limit_value
)
return handler.handle_exception(error, log_level="warning")
def handle_pricing_error(
message: str,
provider: APIProvider = None,
model_name: str = None,
original_error: Exception = None,
db: Session = None
) -> Dict[str, Any]:
"""Handle pricing calculation errors."""
handler = SubscriptionExceptionHandler(db)
error = PricingException(
message=message,
provider=provider,
model_name=model_name,
original_error=original_error
)
return handler.handle_exception(error)
def handle_tracking_error(
message: str,
user_id: str = None,
provider: APIProvider = None,
original_error: Exception = None,
db: Session = None
) -> Dict[str, Any]:
"""Handle usage tracking errors."""
handler = SubscriptionExceptionHandler(db)
error = TrackingException(
message=message,
user_id=user_id,
provider=provider,
original_error=original_error
)
return handler.handle_exception(error)
def log_usage_event(
user_id: str,
provider: APIProvider,
action: str,
details: Dict[str, Any] = None
):
"""Log usage events for monitoring and debugging."""
details = details or {}
log_data = {
"user_id": user_id,
"provider": provider.value,
"action": action,
"timestamp": datetime.utcnow().isoformat(),
**details
}
logger.info(f"Usage Tracking: {action}", extra={"usage_data": log_data})
# Decorator for automatic exception handling
def handle_subscription_errors(db: Session = None):
    """Decorator to automatically handle subscription-related exceptions."""
    import functools

    def decorator(func):
        @functools.wraps(func)  # preserve the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                # handle_exception() wraps plain exceptions in SubscriptionException,
                # so a single branch covers both cases.
                handler = SubscriptionExceptionHandler(db)
                return handler.handle_exception(e)
        return wrapper
    return decorator


@@ -0,0 +1,373 @@
"""
Enhanced FastAPI Monitoring Middleware
Database-backed monitoring for API calls, errors, performance metrics, and usage tracking.
Includes comprehensive subscription-based usage monitoring and cost tracking.
"""
from fastapi import Request, Response
from fastapi.responses import JSONResponse
import time
import json
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional
from collections import defaultdict, deque
import asyncio
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
import re
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
from models.subscription_models import APIProvider
from services.database import get_db
from .usage_tracking_service import UsageTrackingService
from .pricing_service import PricingService
class DatabaseAPIMonitor:
"""Database-backed API monitoring with usage tracking and subscription management."""
def __init__(self):
self.cache_stats = {
'hits': 0,
'misses': 0,
'hit_rate': 0.0
}
# API provider detection patterns - Updated to match actual endpoints
self.provider_patterns = {
APIProvider.GEMINI: [
r'gemini', r'google.*ai'
],
APIProvider.OPENAI: [r'openai', r'gpt', r'chatgpt'],
APIProvider.ANTHROPIC: [r'anthropic', r'claude'],
APIProvider.MISTRAL: [r'mistral'],
APIProvider.TAVILY: [r'tavily'],
APIProvider.SERPER: [r'serper'],
APIProvider.METAPHOR: [r'metaphor', r'/exa'],
APIProvider.FIRECRAWL: [r'firecrawl']
}
def detect_api_provider(self, path: str, user_agent: str = None) -> Optional[APIProvider]:
"""Detect which API provider is being used based on request details."""
path_lower = path.lower()
user_agent_lower = (user_agent or '').lower()
# Permanently ignore internal route families that must not accrue or check provider usage
if path_lower.startswith('/api/onboarding/') or path_lower.startswith('/api/subscription/'):
return None
for provider, patterns in self.provider_patterns.items():
for pattern in patterns:
if re.search(pattern, path_lower) or re.search(pattern, user_agent_lower):
return provider
return None
def extract_usage_metrics(self, request_body: str = None, response_body: str = None) -> Dict[str, Any]:
"""Extract usage metrics from request/response bodies."""
metrics = {
'tokens_input': 0,
'tokens_output': 0,
'model_used': None,
'search_count': 0,
'image_count': 0,
'page_count': 0
}
try:
# Try to parse request body for input tokens/content
if request_body:
request_data = json.loads(request_body) if isinstance(request_body, str) else request_body
# Extract model information
if 'model' in request_data:
metrics['model_used'] = request_data['model']
# Estimate input tokens from prompt/content
if 'prompt' in request_data:
metrics['tokens_input'] = self._estimate_tokens(request_data['prompt'])
elif 'messages' in request_data:
total_content = ' '.join([msg.get('content', '') for msg in request_data['messages']])
metrics['tokens_input'] = self._estimate_tokens(total_content)
elif 'input' in request_data:
metrics['tokens_input'] = self._estimate_tokens(str(request_data['input']))
# Count specific request types
if 'query' in request_data or 'search' in request_data:
metrics['search_count'] = 1
if 'image' in request_data or 'generate_image' in request_data:
metrics['image_count'] = 1
if 'url' in request_data or 'crawl' in request_data:
metrics['page_count'] = 1
# Try to parse response body for output tokens
if response_body:
response_data = json.loads(response_body) if isinstance(response_body, str) else response_body
# Extract output content and estimate tokens
if 'text' in response_data:
metrics['tokens_output'] = self._estimate_tokens(response_data['text'])
elif 'content' in response_data:
metrics['tokens_output'] = self._estimate_tokens(str(response_data['content']))
elif 'choices' in response_data and response_data['choices']:
choice = response_data['choices'][0]
if 'message' in choice and 'content' in choice['message']:
metrics['tokens_output'] = self._estimate_tokens(choice['message']['content'])
# Extract actual token usage if provided by API
if 'usage' in response_data:
usage = response_data['usage']
if 'prompt_tokens' in usage:
metrics['tokens_input'] = usage['prompt_tokens']
if 'completion_tokens' in usage:
metrics['tokens_output'] = usage['completion_tokens']
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.debug(f"Could not extract usage metrics: {e}")
return metrics
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count for text (rough approximation)."""
if not text:
return 0
# Rough estimation: 1.3 tokens per word on average
word_count = len(str(text).split())
return int(word_count * 1.3)
async def check_usage_limits_middleware(request: Request, user_id: str, request_body: str = None) -> Optional[JSONResponse]:
    """Check usage limits before processing request."""
    if not user_id:
        return None
    # No special whitelist; onboarding/subscription are ignored by provider detection
    db = None  # bound before the try so the finally block is safe if get_db() fails
    try:
        db = next(get_db())
        api_monitor = DatabaseAPIMonitor()
        # Detect if this is an API call that should be rate limited
        api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
        if not api_provider:
            return None
        # Use provided request body or read it if not provided
        if request_body is None:
            try:
                if hasattr(request, '_body'):
                    raw_body = request._body
                else:
                    # Try to read body (this might not work in all cases)
                    raw_body = await request.body()
                # The body is bytes at this point; decode it before JSON parsing
                request_body = raw_body.decode('utf-8') if raw_body else None
            except Exception as body_error:
                logger.debug(f"Could not read request body: {body_error}")
        # Estimate tokens needed
        tokens_requested = 0
        if request_body:
            usage_metrics = api_monitor.extract_usage_metrics(request_body)
            tokens_requested = usage_metrics.get('tokens_input', 0)
        # Check limits
        usage_service = UsageTrackingService(db)
        can_proceed, message, usage_info = await usage_service.enforce_usage_limits(
            user_id=user_id,
            provider=api_provider,
            tokens_requested=tokens_requested
        )
        if not can_proceed:
            logger.warning(f"Usage limit exceeded for {user_id}: {message}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Usage limit exceeded",
                    "message": message,
                    "usage_info": usage_info,
                    "provider": api_provider.value
                }
            )
        # Warn if approaching limits
        if usage_info.get('call_usage_percentage', 0) >= 80 or usage_info.get('cost_usage_percentage', 0) >= 80:
            logger.warning(f"User {user_id} approaching usage limits: {usage_info}")
        return None
    except Exception as e:
        logger.error(f"Error checking usage limits: {e}")
        # Don't block requests if usage checking fails
        return None
    finally:
        if db is not None:
            db.close()
async def monitoring_middleware(request: Request, call_next):
"""Enhanced FastAPI middleware for monitoring API calls with usage tracking."""
start_time = time.time()
# Get database session
db = next(get_db())
# Extract request details - Enhanced user identification
user_id = None
try:
# PRIORITY 1: Check request.state.user_id (set by API key injection middleware)
if hasattr(request.state, 'user_id') and request.state.user_id:
user_id = request.state.user_id
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
# PRIORITY 2: Check query parameters
elif hasattr(request, 'query_params') and 'user_id' in request.query_params:
user_id = request.query_params['user_id']
elif hasattr(request, 'path_params') and 'user_id' in request.path_params:
user_id = request.path_params['user_id']
# PRIORITY 3: Check headers for user identification
elif 'x-user-id' in request.headers:
user_id = request.headers['x-user-id']
elif 'x-user-email' in request.headers:
user_id = request.headers['x-user-email'] # Use email as user identifier
elif 'x-session-id' in request.headers:
user_id = request.headers['x-session-id'] # Use session as fallback
# Check for authorization header with user info
elif 'authorization' in request.headers:
# Auth middleware should have set request.state.user_id
# If not, this indicates an authentication failure that should be logged
user_id = None
logger.warning("Monitoring: Auth header present but no user_id in state - authentication may have failed")
# Final fallback: None (skip usage limits for truly anonymous/unauthenticated)
else:
user_id = None
except Exception as e:
logger.debug(f"Error extracting user ID: {e}")
user_id = None # On error, skip usage limits
# Capture request body for usage tracking (read once, safely)
request_body = None
try:
# Only read body for POST/PUT/PATCH requests to avoid issues
if request.method in ['POST', 'PUT', 'PATCH']:
if hasattr(request, '_body') and request._body:
request_body = request._body.decode('utf-8')
else:
# Read body only if it hasn't been read yet
try:
body = await request.body()
request_body = body.decode('utf-8') if body else None
except Exception as body_error:
logger.debug(f"Could not read request body: {body_error}")
request_body = None
except Exception as e:
logger.debug(f"Error capturing request body: {e}")
request_body = None
# Check usage limits before processing
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
if limit_response:
return limit_response
try:
response = await call_next(request)
status_code = response.status_code
duration = time.time() - start_time
# Capture response body for usage tracking
response_body = None
try:
if hasattr(response, 'body'):
response_body = response.body.decode('utf-8') if response.body else None
elif hasattr(response, '_content'):
response_body = response._content.decode('utf-8') if response._content else None
except Exception:
pass
# Track API usage if this is an API call to external providers
api_monitor = DatabaseAPIMonitor()
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
if api_provider and user_id:
logger.info(f"Detected API call: {request.url.path} -> {api_provider.value} for user: {user_id}")
try:
# Extract usage metrics
usage_metrics = api_monitor.extract_usage_metrics(request_body, response_body)
# Track usage with the usage tracking service
usage_service = UsageTrackingService(db)
await usage_service.track_api_usage(
user_id=user_id,
provider=api_provider,
endpoint=request.url.path,
method=request.method,
model_used=usage_metrics.get('model_used'),
tokens_input=usage_metrics.get('tokens_input', 0),
tokens_output=usage_metrics.get('tokens_output', 0),
response_time=duration,
status_code=status_code,
request_size=len(request_body) if request_body else None,
response_size=len(response_body) if response_body else None,
user_agent=request.headers.get('user-agent'),
ip_address=request.client.host if request.client else None,
search_count=usage_metrics.get('search_count', 0),
image_count=usage_metrics.get('image_count', 0),
page_count=usage_metrics.get('page_count', 0)
)
except Exception as usage_error:
logger.error(f"Error tracking API usage: {usage_error}")
# Don't fail the main request if usage tracking fails
return response
except Exception as e:
duration = time.time() - start_time
status_code = 500
# Store minimal error info
logger.error(f"API Error: {request.method} {request.url.path} - {str(e)}")
return JSONResponse(
status_code=500,
content={"error": "Internal server error"}
)
finally:
db.close()
async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
"""Get current monitoring statistics."""
db = next(get_db())
try:
# Placeholder to match old API; heavy stats handled elsewhere
return {
'timestamp': datetime.utcnow().isoformat(),
'overview': {
'recent_requests': 0,
'recent_errors': 0,
},
'cache_performance': {'hits': 0, 'misses': 0, 'hit_rate': 0.0},
'recent_errors': [],
'system_health': {'status': 'healthy', 'error_rate': 0.0}
}
finally:
db.close()
async def get_lightweight_stats() -> Dict[str, Any]:
"""Get lightweight stats for dashboard header."""
db = next(get_db())
try:
# Minimal viable placeholder values
now = datetime.utcnow()
return {
'status': 'healthy',
'icon': '🟢',
'recent_requests': 0,
'recent_errors': 0,
'error_rate': 0.0,
'timestamp': now.isoformat()
}
finally:
db.close()


@@ -0,0 +1,594 @@
"""
Pricing Service for API Usage Tracking
Manages API pricing, cost calculation, and subscription limits.
"""
from typing import Dict, Any, Optional, List, Tuple
from decimal import Decimal, ROUND_HALF_UP
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from loguru import logger
from models.subscription_models import (
APIProviderPricing, SubscriptionPlan, UserSubscription,
UsageSummary, APIUsageLog, APIProvider, SubscriptionTier
)
class PricingService:
"""Service for managing API pricing and cost calculations."""
def __init__(self, db: Session):
self.db = db
self._pricing_cache = {}
self._plans_cache = {}
# Lightweight in-process cache for limit checks
# key: f"{user_id}:{provider}", value: { 'result': (bool, str, dict), 'expires_at': datetime }
self._limits_cache: Dict[str, Dict[str, Any]] = {}
# ------------------- Billing period helpers -------------------
def _compute_next_period_end(self, start: datetime, cycle: str) -> datetime:
"""Compute the next period end given a start and billing cycle."""
try:
cycle_value = cycle.value if hasattr(cycle, 'value') else str(cycle)
except Exception:
cycle_value = str(cycle)
if cycle_value == 'yearly':
return start + timedelta(days=365)
return start + timedelta(days=30)
def _ensure_subscription_current(self, subscription) -> bool:
"""Auto-advance subscription period if expired and auto_renew is enabled."""
if not subscription:
return False
now = datetime.utcnow()
try:
if subscription.current_period_end and subscription.current_period_end < now:
if getattr(subscription, 'auto_renew', False):
subscription.current_period_start = now
subscription.current_period_end = self._compute_next_period_end(now, subscription.billing_cycle)
# Keep status active if model enum else string
try:
# Look up ACTIVE on the enum class rather than through an existing member
subscription.status = type(subscription.status).ACTIVE # type: ignore[attr-defined]
except Exception:
setattr(subscription, 'status', 'active')
self.db.commit()
else:
return False
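except Exception as e:
# Log the failure instead of swallowing it silently
logger.error(f"Failed to advance subscription period: {e}")
self.db.rollback()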
return True
def get_current_billing_period(self, user_id: str) -> Optional[str]:
"""Return current billing period key (YYYY-MM) after ensuring subscription is current."""
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
# Ensure subscription is current (advance if auto_renew)
self._ensure_subscription_current(subscription)
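# Summaries are keyed by calendar month (YYYY-MM, server-local time),
# independent of the subscription's own current_period_start/end dates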
return datetime.now().strftime("%Y-%m")
def initialize_default_pricing(self):
"""Initialize default pricing for all API providers."""
# Gemini API Pricing (Updated as of September 2025 - Official Google AI Pricing)
# Source: https://ai.google.dev/gemini-api/docs/pricing
gemini_pricing = [
# Gemini 2.5 Pro - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-pro",
"cost_per_input_token": 0.00000125, # $1.25 per 1M input tokens (prompts <= 200k tokens)
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens (prompts <= 200k tokens)
"description": "Gemini 2.5 Pro - State-of-the-art multipurpose model for coding and complex reasoning"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-pro-large",
"cost_per_input_token": 0.0000025, # $2.50 per 1M input tokens (prompts > 200k tokens)
"cost_per_output_token": 0.000015, # $15.00 per 1M output tokens (prompts > 200k tokens)
"description": "Gemini 2.5 Pro - Large context model for prompts > 200k tokens"
},
# Gemini 2.5 Flash - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash",
"cost_per_input_token": 0.0000003, # $0.30 per 1M input tokens (text/image/video)
"cost_per_output_token": 0.0000025, # $2.50 per 1M output tokens
"description": "Gemini 2.5 Flash - Hybrid reasoning model with 1M token context window"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash-audio",
"cost_per_input_token": 0.000001, # $1.00 per 1M input tokens (audio)
"cost_per_output_token": 0.0000025, # $2.50 per 1M output tokens
"description": "Gemini 2.5 Flash - Audio input model"
},
# Gemini 2.5 Flash-Lite - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash-lite",
"cost_per_input_token": 0.0000001, # $0.10 per 1M input tokens (text/image/video)
"cost_per_output_token": 0.0000004, # $0.40 per 1M output tokens
"description": "Gemini 2.5 Flash-Lite - Smallest and most cost-effective model for at-scale usage"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-2.5-flash-lite-audio",
"cost_per_input_token": 0.0000003, # $0.30 per 1M input tokens (audio)
"cost_per_output_token": 0.0000004, # $0.40 per 1M output tokens
"description": "Gemini 2.5 Flash-Lite - Audio input model"
},
# Gemini 1.5 Flash - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash",
"cost_per_input_token": 0.000000075, # $0.075 per 1M input tokens (prompts <= 128k tokens)
"cost_per_output_token": 0.0000003, # $0.30 per 1M output tokens (prompts <= 128k tokens)
"description": "Gemini 1.5 Flash - Fast multimodal model with 1M token context window"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash-large",
"cost_per_input_token": 0.00000015, # $0.15 per 1M input tokens (prompts > 128k tokens)
"cost_per_output_token": 0.0000006, # $0.60 per 1M output tokens (prompts > 128k tokens)
"description": "Gemini 1.5 Flash - Large context model for prompts > 128k tokens"
},
# Gemini 1.5 Flash-8B - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash-8b",
"cost_per_input_token": 0.0000000375, # $0.0375 per 1M input tokens (prompts <= 128k tokens)
"cost_per_output_token": 0.00000015, # $0.15 per 1M output tokens (prompts <= 128k tokens)
"description": "Gemini 1.5 Flash-8B - Smallest model for lower intelligence use cases"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-flash-8b-large",
"cost_per_input_token": 0.000000075, # $0.075 per 1M input tokens (prompts > 128k tokens)
"cost_per_output_token": 0.0000003, # $0.30 per 1M output tokens (prompts > 128k tokens)
"description": "Gemini 1.5 Flash-8B - Large context model for prompts > 128k tokens"
},
# Gemini 1.5 Pro - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-pro",
"cost_per_input_token": 0.00000125, # $1.25 per 1M input tokens (prompts <= 128k tokens)
"cost_per_output_token": 0.000005, # $5.00 per 1M output tokens (prompts <= 128k tokens)
"description": "Gemini 1.5 Pro - Highest intelligence model with 2M token context window"
},
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-1.5-pro-large",
"cost_per_input_token": 0.0000025, # $2.50 per 1M input tokens (prompts > 128k tokens)
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens (prompts > 128k tokens)
"description": "Gemini 1.5 Pro - Large context model for prompts > 128k tokens"
},
# Gemini Embedding - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-embedding",
"cost_per_input_token": 0.00000015, # $0.15 per 1M input tokens
"cost_per_output_token": 0.0, # No output tokens for embeddings
"description": "Gemini Embedding - Newest embeddings model with higher rate limits"
},
# Grounding with Google Search - Standard Tier
{
"provider": APIProvider.GEMINI,
"model_name": "gemini-grounding-search",
"cost_per_request": 0.035, # $35 per 1,000 requests (after free tier)
"cost_per_input_token": 0.0, # No additional token cost for grounding
"cost_per_output_token": 0.0, # No additional token cost for grounding
"description": "Grounding with Google Search - 1,500 RPD free, then $35/1K requests"
}
]
# OpenAI Pricing (estimated, will be updated)
openai_pricing = [
{
"provider": APIProvider.OPENAI,
"model_name": "gpt-4o",
"cost_per_input_token": 0.0000025, # $2.50 per 1M input tokens
"cost_per_output_token": 0.00001, # $10.00 per 1M output tokens
"description": "GPT-4o - Latest OpenAI model"
},
{
"provider": APIProvider.OPENAI,
"model_name": "gpt-4o-mini",
"cost_per_input_token": 0.00000015, # $0.15 per 1M input tokens
"cost_per_output_token": 0.0000006, # $0.60 per 1M output tokens
"description": "GPT-4o Mini - Cost-effective model"
}
]
# Anthropic Pricing (estimated, will be updated)
anthropic_pricing = [
{
"provider": APIProvider.ANTHROPIC,
"model_name": "claude-3.5-sonnet",
"cost_per_input_token": 0.000003, # $3.00 per 1M input tokens
"cost_per_output_token": 0.000015, # $15.00 per 1M output tokens
"description": "Claude 3.5 Sonnet - Anthropic's flagship model"
}
]
# Search, web extraction, and image generation API pricing (estimated)
search_pricing = [
{
"provider": APIProvider.TAVILY,
"model_name": "tavily-search",
"cost_per_request": 0.001, # $0.001 per search
"description": "Tavily AI Search API"
},
{
"provider": APIProvider.SERPER,
"model_name": "serper-search",
"cost_per_request": 0.001, # $0.001 per search
"description": "Serper Google Search API"
},
{
"provider": APIProvider.METAPHOR,
"model_name": "metaphor-search",
"cost_per_request": 0.003, # $0.003 per search
"description": "Metaphor/Exa AI Search API"
},
{
"provider": APIProvider.FIRECRAWL,
"model_name": "firecrawl-extract",
"cost_per_page": 0.002, # $0.002 per page crawled
"description": "Firecrawl Web Extraction API"
},
{
"provider": APIProvider.STABILITY,
"model_name": "stable-diffusion",
"cost_per_image": 0.04, # $0.04 per image
"description": "Stability AI Image Generation"
}
]
# Combine all pricing data
all_pricing = gemini_pricing + openai_pricing + anthropic_pricing + search_pricing
# Insert pricing data
for pricing_data in all_pricing:
existing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == pricing_data["provider"],
APIProviderPricing.model_name == pricing_data["model_name"]
).first()
if not existing:
pricing = APIProviderPricing(**pricing_data)
self.db.add(pricing)
self.db.commit()
logger.debug("Default API pricing initialized")
def initialize_default_plans(self):
"""Initialize default subscription plans."""
plans = [
{
"name": "Free",
"tier": SubscriptionTier.FREE,
"price_monthly": 0.0,
"price_yearly": 0.0,
"gemini_calls_limit": 100,
# NOTE: check_usage_limits only enforces limits greater than 0, so the zeros below
# are not enforced as "no access" for OpenAI or Anthropic on the Free plan
"openai_calls_limit": 0,
"anthropic_calls_limit": 0,
"mistral_calls_limit": 50,
"tavily_calls_limit": 20,
"serper_calls_limit": 20,
"metaphor_calls_limit": 10,
"firecrawl_calls_limit": 10,
"stability_calls_limit": 5,
"gemini_tokens_limit": 100000,
"monthly_cost_limit": 0.0, # 0 also skips the monthly cost check
"features": ["basic_content_generation", "limited_research"],
"description": "Perfect for trying out ALwrity"
},
{
"name": "Basic",
"tier": SubscriptionTier.BASIC,
"price_monthly": 29.0,
"price_yearly": 290.0,
"gemini_calls_limit": 1000,
"openai_calls_limit": 500,
"anthropic_calls_limit": 200,
"mistral_calls_limit": 500,
"tavily_calls_limit": 200,
"serper_calls_limit": 200,
"metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100,
"stability_calls_limit": 50,
"gemini_tokens_limit": 1000000,
"openai_tokens_limit": 500000,
"anthropic_tokens_limit": 200000,
"mistral_tokens_limit": 500000,
"monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams"
},
{
"name": "Pro",
"tier": SubscriptionTier.PRO,
"price_monthly": 79.0,
"price_yearly": 790.0,
"gemini_calls_limit": 5000,
"openai_calls_limit": 2500,
"anthropic_calls_limit": 1000,
"mistral_calls_limit": 2500,
"tavily_calls_limit": 1000,
"serper_calls_limit": 1000,
"metaphor_calls_limit": 500,
"firecrawl_calls_limit": 500,
"stability_calls_limit": 200,
"gemini_tokens_limit": 5000000,
"openai_tokens_limit": 2500000,
"anthropic_tokens_limit": 1000000,
"mistral_tokens_limit": 2500000,
"monthly_cost_limit": 150.0,
"features": ["unlimited_content_generation", "premium_research", "advanced_analytics", "priority_support"],
"description": "Perfect for growing businesses"
},
{
"name": "Enterprise",
"tier": SubscriptionTier.ENTERPRISE,
"price_monthly": 199.0,
"price_yearly": 1990.0,
"gemini_calls_limit": 0, # Unlimited
"openai_calls_limit": 0,
"anthropic_calls_limit": 0,
"mistral_calls_limit": 0,
"tavily_calls_limit": 0,
"serper_calls_limit": 0,
"metaphor_calls_limit": 0,
"firecrawl_calls_limit": 0,
"stability_calls_limit": 0,
"gemini_tokens_limit": 0,
"openai_tokens_limit": 0,
"anthropic_tokens_limit": 0,
"mistral_tokens_limit": 0,
"monthly_cost_limit": 500.0,
"features": ["unlimited_everything", "white_label", "dedicated_support", "custom_integrations"],
"description": "For large organizations with high-volume needs"
}
]
for plan_data in plans:
existing = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.name == plan_data["name"]
).first()
if not existing:
plan = SubscriptionPlan(**plan_data)
self.db.add(plan)
self.db.commit()
logger.debug("Default subscription plans initialized")
def calculate_api_cost(self, provider: APIProvider, model_name: str,
tokens_input: int = 0, tokens_output: int = 0,
request_count: int = 1, **kwargs) -> Dict[str, float]:
"""Calculate cost for an API call."""
# Get pricing for the provider and model
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.model_name == model_name,
APIProviderPricing.is_active == True
).first()
if not pricing:
logger.warning(f"No pricing found for {provider.value}:{model_name}, using default estimates")
# Use default estimates
cost_input = tokens_input * 0.000001 # $1 per 1M tokens default
cost_output = tokens_output * 0.000001
cost_total = (cost_input + cost_output) * request_count
else:
# Calculate based on actual pricing; rates left unset (None) count as zero
cost_input = tokens_input * (pricing.cost_per_input_token or 0.0)
cost_output = tokens_output * (pricing.cost_per_output_token or 0.0)
cost_request = request_count * (pricing.cost_per_request or 0.0)
# Handle special cases for non-LLM APIs
cost_search = kwargs.get('search_count', 0) * (pricing.cost_per_search or 0.0)
cost_image = kwargs.get('image_count', 0) * (pricing.cost_per_image or 0.0)
cost_page = kwargs.get('page_count', 0) * (pricing.cost_per_page or 0.0)
cost_total = cost_input + cost_output + cost_request + cost_search + cost_image + cost_page
# Round to 6 decimal places for precision
return {
'cost_input': round(cost_input, 6),
'cost_output': round(cost_output, 6),
'cost_total': round(cost_total, 6)
}
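As a worked example of the per-token arithmetic, here is a sketch that prices one call using the `gemini-2.5-flash` rates from the default table. It uses `Decimal` so the six-place rounding is exact. The method shown works in floats, so this is an illustration of the formula rather than the service's implementation, and `estimate_cost` is a hypothetical helper name.

```python
from decimal import Decimal, ROUND_HALF_UP
from typing import Dict

# Rates copied from the gemini-2.5-flash entry in the default pricing table
COST_PER_INPUT_TOKEN = Decimal("0.0000003")   # $0.30 per 1M input tokens
COST_PER_OUTPUT_TOKEN = Decimal("0.0000025")  # $2.50 per 1M output tokens
SIX_PLACES = Decimal("0.000001")


def estimate_cost(tokens_input: int, tokens_output: int) -> Dict[str, Decimal]:
    """Price one call and round each figure to six decimal places."""
    cost_input = tokens_input * COST_PER_INPUT_TOKEN
    cost_output = tokens_output * COST_PER_OUTPUT_TOKEN
    return {
        "cost_input": cost_input.quantize(SIX_PLACES, rounding=ROUND_HALF_UP),
        "cost_output": cost_output.quantize(SIX_PLACES, rounding=ROUND_HALF_UP),
        "cost_total": (cost_input + cost_output).quantize(SIX_PLACES, rounding=ROUND_HALF_UP),
    }


cost = estimate_cost(tokens_input=1200, tokens_output=350)
```

For 1,200 input tokens and 350 output tokens this gives $0.000360 + $0.000875 = $0.001235.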
def get_user_limits(self, user_id: str) -> Optional[Dict[str, Any]]:
"""Get usage limits for a user based on their subscription."""
subscription = self.db.query(UserSubscription).filter(
UserSubscription.user_id == user_id,
UserSubscription.is_active == True
).first()
if not subscription:
# Return free tier limits
free_plan = self.db.query(SubscriptionPlan).filter(
SubscriptionPlan.tier == SubscriptionTier.FREE
).first()
if free_plan:
return self._plan_to_limits_dict(free_plan)
return None
# Ensure current period before returning limits
self._ensure_subscription_current(subscription)
return self._plan_to_limits_dict(subscription.plan)
def _plan_to_limits_dict(self, plan: SubscriptionPlan) -> Dict[str, Any]:
"""Convert subscription plan to limits dictionary."""
return {
'plan_name': plan.name,
'tier': plan.tier.value,
'limits': {
'gemini_calls': plan.gemini_calls_limit,
'openai_calls': plan.openai_calls_limit,
'anthropic_calls': plan.anthropic_calls_limit,
'mistral_calls': plan.mistral_calls_limit,
'tavily_calls': plan.tavily_calls_limit,
'serper_calls': plan.serper_calls_limit,
'metaphor_calls': plan.metaphor_calls_limit,
'firecrawl_calls': plan.firecrawl_calls_limit,
'stability_calls': plan.stability_calls_limit,
'gemini_tokens': plan.gemini_tokens_limit,
'openai_tokens': plan.openai_tokens_limit,
'anthropic_tokens': plan.anthropic_tokens_limit,
'mistral_tokens': plan.mistral_tokens_limit,
'monthly_cost': plan.monthly_cost_limit
},
'features': plan.features or []
}
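The `limits` mapping is consulted with the convention that a limit of 0 means "not enforced". A minimal sketch of that convention follows, assuming a plain dict in place of the plan model; `within_call_limit` is a hypothetical helper, not part of the service.

```python
from typing import Dict, Tuple


def within_call_limit(limits: Dict[str, int], provider: str, current_calls: int) -> Tuple[bool, str]:
    """Return (allowed, reason); a limit of 0 or a missing key is treated as unlimited."""
    call_limit = limits.get(f"{provider}_calls", 0)
    if call_limit > 0 and current_calls >= call_limit:
        return False, f"API call limit reached for {provider}"
    return True, "Within limits"


basic_limits = {"gemini_calls": 1000, "openai_calls": 500}
at_limit = within_call_limit(basic_limits, "gemini", 1000)
below_limit = within_call_limit(basic_limits, "gemini", 999)
unlimited = within_call_limit({"gemini_calls": 0}, "gemini", 1_000_000)
```

One consequence of this convention is that a plan cannot express "zero calls allowed" with a 0 limit; that would need a separate flag or sentinel.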
def check_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Check if user can make an API call within their limits."""
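# Short TTL cache to reduce DB reads under sustained traffic.
# The key omits tokens_requested, so a cached result may be up to 30s stale
# and may have been computed for a different request size.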
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._limits_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
return tuple(cached['result']) # type: ignore
# Get user limits
limits = self.get_user_limits(user_id)
if not limits:
return False, "No subscription plan found", {}
# Get current usage for this billing period
current_period = self.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not usage:
# First usage this period, create summary
usage = UsageSummary(
user_id=user_id,
billing_period=current_period
)
self.db.add(usage)
self.db.commit()
# Check call limits
provider_name = provider.value
current_calls = getattr(usage, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0 and current_calls >= call_limit:
result = (False, f"API call limit reached for {provider_name}", {
'current_calls': current_calls,
'limit': call_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Check token limits for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(usage, f"{provider_name}_tokens", 0)
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0)
if token_limit > 0 and (current_tokens + tokens_requested) > token_limit:
result = (False, f"Token limit would be exceeded for {provider_name}", {
'current_tokens': current_tokens,
'requested_tokens': tokens_requested,
'limit': token_limit,
'usage_percentage': ((current_tokens + tokens_requested) / token_limit) * 100
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Check cost limits
cost_limit = limits['limits'].get('monthly_cost', 0)
if cost_limit > 0 and usage.total_cost >= cost_limit:
result = (False, "Monthly cost limit reached", {
'current_cost': usage.total_cost,
'limit': cost_limit,
'usage_percentage': 100.0
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
# Calculate usage percentages for warnings
# Divide by the real limit: max(limit, 1) would understate usage for cost limits below 1.0
call_usage_pct = (current_calls / call_limit) * 100 if call_limit > 0 else 0
cost_usage_pct = (usage.total_cost / cost_limit) * 100 if cost_limit > 0 else 0
result = (True, "Within limits", {
'current_calls': current_calls,
'call_limit': call_limit,
'call_usage_percentage': call_usage_pct,
'current_cost': usage.total_cost,
'cost_limit': cost_limit,
'cost_usage_percentage': cost_usage_pct
})
self._limits_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
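The expiry-based caching used for limit checks can be sketched on its own. This is an illustrative standalone class under assumed names (`TTLCache`, `get`, `put`); the service itself keeps a plain dict on the instance. Time is passed in explicitly so the behaviour is easy to test.

```python
from datetime import datetime, timedelta
from typing import Any, Dict, Optional


class TTLCache:
    """Tiny in-process cache whose entries expire after a fixed time-to-live."""

    def __init__(self, ttl_seconds: int = 30):
        self._ttl = timedelta(seconds=ttl_seconds)
        self._entries: Dict[str, Dict[str, Any]] = {}

    def get(self, key: str, now: datetime) -> Optional[Any]:
        entry = self._entries.get(key)
        if entry and entry["expires_at"] > now:
            return entry["result"]
        return None  # missing or expired

    def put(self, key: str, result: Any, now: datetime) -> None:
        self._entries[key] = {"result": result, "expires_at": now + self._ttl}


cache = TTLCache(ttl_seconds=30)
t0 = datetime(2025, 10, 29, 12, 0, 0)
cache.put("user-1:gemini", (True, "Within limits", {}), t0)
fresh = cache.get("user-1:gemini", t0 + timedelta(seconds=10))
stale = cache.get("user-1:gemini", t0 + timedelta(seconds=31))
```

A cache like this trades accuracy for fewer database reads: a user can exceed a limit by however many calls arrive inside one TTL window.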
def estimate_tokens(self, text: str, provider: APIProvider) -> int:
"""Estimate token count for text based on provider."""
# Get pricing info for token estimation
pricing = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.is_active == True
).first()
if pricing and pricing.tokens_per_word:
# Use provider-specific conversion
word_count = len(text.split())
return int(word_count * pricing.tokens_per_word)
else:
# Use default estimation (roughly 1.3 tokens per word for most models)
word_count = len(text.split())
return int(word_count * 1.3)
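The word-count heuristic is simple enough to try by hand. A minimal sketch without the database lookup, where `tokens_per_word` is passed as a parameter and defaults to the 1.3 fallback:

```python
def estimate_tokens(text: str, tokens_per_word: float = 1.3) -> int:
    """Estimate tokens as whitespace-separated words times a conversion factor."""
    word_count = len(text.split())
    return int(word_count * tokens_per_word)


estimate = estimate_tokens("Write a short product description for a travel mug")
empty_estimate = estimate_tokens("")
```

The sample prompt has 9 words, so the estimate is `int(9 * 1.3)`, which is 11. This is only a rough guide; real tokenizers can differ substantially, especially for code or non-English text.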
def get_pricing_info(self, provider: APIProvider, model_name: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Get pricing information for a provider/model."""
query = self.db.query(APIProviderPricing).filter(
APIProviderPricing.provider == provider,
APIProviderPricing.is_active == True
)
if model_name:
query = query.filter(APIProviderPricing.model_name == model_name)
pricing = query.first()
if not pricing:
return None
return {
'provider': pricing.provider.value,
'model_name': pricing.model_name,
'cost_per_input_token': pricing.cost_per_input_token,
'cost_per_output_token': pricing.cost_per_output_token,
'cost_per_request': pricing.cost_per_request,
'cost_per_search': pricing.cost_per_search,
'cost_per_image': pricing.cost_per_image,
'cost_per_page': pricing.cost_per_page,
'description': pricing.description
}

@@ -0,0 +1,525 @@
"""
Usage Tracking Service
Comprehensive tracking of API usage, costs, and subscription limits.
"""
import asyncio
from typing import Dict, Any, Optional, List, Tuple
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from loguru import logger
import json
from models.subscription_models import (
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
UserSubscription, UsageStatus
)
from .pricing_service import PricingService
class UsageTrackingService:
"""Service for tracking API usage and managing subscription limits."""
def __init__(self, db: Session):
self.db = db
self.pricing_service = PricingService(db)
# TTL cache (30s) for enforcement results to cut DB chatter
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
async def track_api_usage(self, user_id: str, provider: APIProvider,
endpoint: str, method: str, model_used: str = None,
tokens_input: int = 0, tokens_output: int = 0,
response_time: float = 0.0, status_code: int = 200,
request_size: int = None, response_size: int = None,
user_agent: str = None, ip_address: str = None,
error_message: str = None, retry_count: int = 0,
**kwargs) -> Dict[str, Any]:
"""Track an API usage event and update billing information."""
try:
# Calculate costs
# Use specific model names instead of generic defaults
default_models = {
"gemini": "gemini-2.5-flash", # Use Flash as default (cost-effective)
"openai": "gpt-4o-mini", # Use Mini as default (cost-effective)
"anthropic": "claude-3.5-sonnet" # Use Sonnet as default
}
model_name = model_used or default_models.get(provider.value, f"{provider.value}-default")
cost_data = self.pricing_service.calculate_api_cost(
provider=provider,
model_name=model_name,
tokens_input=tokens_input,
tokens_output=tokens_output,
request_count=1,
**kwargs
)
# Create usage log entry
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage_log = APIUsageLog(
user_id=user_id,
provider=provider,
endpoint=endpoint,
method=method,
model_used=model_used,
tokens_input=tokens_input,
tokens_output=tokens_output,
tokens_total=(tokens_input or 0) + (tokens_output or 0),
cost_input=cost_data['cost_input'],
cost_output=cost_data['cost_output'],
cost_total=cost_data['cost_total'],
response_time=response_time,
status_code=status_code,
request_size=request_size,
response_size=response_size,
user_agent=user_agent,
ip_address=ip_address,
error_message=error_message,
retry_count=retry_count,
billing_period=billing_period
)
self.db.add(usage_log)
# Update usage summary
await self._update_usage_summary(
user_id=user_id,
provider=provider,
tokens_used=(tokens_input or 0) + (tokens_output or 0),
cost=cost_data['cost_total'],
billing_period=billing_period,
response_time=response_time,
is_error=status_code >= 400
)
# Check for usage alerts
await self._check_usage_alerts(user_id, provider, billing_period)
self.db.commit()
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
return {
'usage_logged': True,
'cost': cost_data['cost_total'],
'tokens_used': (tokens_input or 0) + (tokens_output or 0),
'billing_period': billing_period
}
except Exception as e:
logger.error(f"Error tracking API usage: {str(e)}")
self.db.rollback()
return {
'usage_logged': False,
'error': str(e)
}
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
tokens_used: int, cost: float, billing_period: str,
response_time: float, is_error: bool):
"""Update the usage summary for a user."""
# Get or create usage summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=billing_period
)
self.db.add(summary)
# Update provider-specific counters.
# A summary created above has not been flushed yet, so column defaults are not
# applied and its counters may still be None; fall back to zero before adding.
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
setattr(summary, f"{provider_name}_calls", current_calls + 1)
# Update token usage for LLM providers
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
current_tokens = getattr(summary, f"{provider_name}_tokens", 0) or 0
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
# Update cost
current_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
setattr(summary, f"{provider_name}_cost", current_cost + cost)
# Update totals
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_used
summary.total_cost = (summary.total_cost or 0.0) + cost
# Update performance metrics
if summary.total_calls > 0:
# Update the running average response time
previous_calls = summary.total_calls - 1
total_response_time = (summary.avg_response_time or 0.0) * previous_calls + response_time
summary.avg_response_time = total_response_time / summary.total_calls
# Update error rate: recover the previous error count from the stored percentage.
# round() rather than int(): float error can leave the product just below a whole number.
error_count = round((summary.error_rate or 0.0) * previous_calls / 100)
if is_error:
error_count += 1
summary.error_rate = (error_count / summary.total_calls) * 100
# Update usage status based on limits
await self._update_usage_status(summary)
summary.updated_at = datetime.utcnow()
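The incremental metrics update can be exercised without a database. This is a sketch under assumed names (`RunningMetrics`, `record`), with a dataclass standing in for the `UsageSummary` row: the average is rebuilt from the previous average and call count, and the error count is recovered from the stored percentage.

```python
from dataclasses import dataclass


@dataclass
class RunningMetrics:
    total_calls: int = 0
    avg_response_time: float = 0.0
    error_rate: float = 0.0  # percentage, 0 to 100

    def record(self, response_time: float, is_error: bool) -> None:
        previous_calls = self.total_calls
        self.total_calls += 1
        # Running mean: undo the previous division, add the new sample, divide again
        total_time = self.avg_response_time * previous_calls + response_time
        self.avg_response_time = total_time / self.total_calls
        # Recover the integer error count; round() absorbs float error in the percentage
        error_count = round(self.error_rate * previous_calls / 100)
        if is_error:
            error_count += 1
        self.error_rate = error_count / self.total_calls * 100


metrics = RunningMetrics()
for response_time, is_error in [(0.2, False), (0.4, True), (0.6, False), (0.8, True)]:
    metrics.record(response_time, is_error)
```

Storing an explicit error counter alongside the rate would avoid the round trip through a percentage altogether; the sketch keeps the derived form to mirror the fields used here.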
async def _update_usage_status(self, summary: UsageSummary):
"""Update usage status based on subscription limits."""
limits = self.pricing_service.get_user_limits(summary.user_id)
if not limits:
return
# Check various limits and determine status
max_usage_percentage = 0.0
# Check cost limit
cost_limit = limits['limits'].get('monthly_cost', 0)
if cost_limit > 0:
cost_usage_pct = (summary.total_cost / cost_limit) * 100
max_usage_percentage = max(max_usage_percentage, cost_usage_pct)
# Check call limits for each provider
for provider in APIProvider:
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
call_usage_pct = (current_calls / call_limit) * 100
max_usage_percentage = max(max_usage_percentage, call_usage_pct)
# Update status based on highest usage percentage
if max_usage_percentage >= 100:
summary.usage_status = UsageStatus.LIMIT_REACHED
elif max_usage_percentage >= 80:
summary.usage_status = UsageStatus.WARNING
else:
summary.usage_status = UsageStatus.ACTIVE
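The status decision reduces to taking the highest usage percentage across all limits and comparing it against the 80% and 100% thresholds. A standalone sketch follows; the string values are placeholders for the `UsageStatus` enum members, whose actual values are not shown here.

```python
from typing import Iterable


def usage_status(usage_percentages: Iterable[float]) -> str:
    """Map the highest usage percentage to a status label."""
    highest = max(usage_percentages, default=0.0)
    if highest >= 100:
        return "limit_reached"
    if highest >= 80:
        return "warning"
    return "active"


status_ok = usage_status([12.5, 40.0])
status_warn = usage_status([12.5, 85.0])
status_blocked = usage_status([100.0, 3.0])
status_empty = usage_status([])
```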
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
"""Check if usage alerts should be sent."""
# Get current usage
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
return
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
if not limits:
return
# Check for alert thresholds (80%, 90%, 100%)
thresholds = [80, 90, 100]
for threshold in thresholds:
# Check if alert already sent for this threshold
existing_alert = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.threshold_percentage == threshold,
UsageAlert.provider == provider,
UsageAlert.is_sent == True
).first()
if existing_alert:
continue
# Check if threshold is reached
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0)
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
if call_limit > 0:
usage_percentage = (current_calls / call_limit) * 100
if usage_percentage >= threshold:
await self._create_usage_alert(
user_id=user_id,
provider=provider,
threshold=threshold,
current_usage=current_calls,
limit=call_limit,
billing_period=billing_period
)
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
threshold: int, current_usage: int, limit: int,
billing_period: str):
"""Create a usage alert."""
# Determine alert type and severity
if threshold >= 100:
alert_type = "limit_reached"
severity = "error"
title = f"API Limit Reached - {provider.value.title()}"
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
elif threshold >= 90:
alert_type = "usage_warning"
severity = "warning"
title = f"API Usage Warning - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
else:
alert_type = "usage_warning"
severity = "info"
title = f"API Usage Notice - {provider.value.title()}"
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
alert = UsageAlert(
user_id=user_id,
alert_type=alert_type,
threshold_percentage=threshold,
provider=provider,
title=title,
message=message,
severity=severity,
billing_period=billing_period
)
self.db.add(alert)
logger.info(f"Created usage alert for {user_id}: {title}")
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
"""Get comprehensive usage statistics for a user."""
if not billing_period:
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get usage summary
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
# Get user limits
limits = self.pricing_service.get_user_limits(user_id)
# Get recent alerts
alerts = self.db.query(UsageAlert).filter(
UsageAlert.user_id == user_id,
UsageAlert.billing_period == billing_period,
UsageAlert.is_read == False
).order_by(UsageAlert.created_at.desc()).limit(10).all()
if not summary:
# No usage this period - return complete structure with zeros
provider_breakdown = {}
usage_percentages = {}
# Initialize provider breakdown with zeros
for provider in APIProvider:
provider_name = provider.value
provider_breakdown[provider_name] = {
'calls': 0,
'tokens': 0,
'cost': 0.0
}
usage_percentages[f"{provider_name}_calls"] = 0
usage_percentages['cost'] = 0
return {
'billing_period': billing_period,
'usage_status': 'active',
'total_calls': 0,
'total_tokens': 0,
'total_cost': 0.0,
'avg_response_time': 0.0,
'error_rate': 0.0,
'last_updated': datetime.now().isoformat(),
'limits': limits,
'provider_breakdown': provider_breakdown,
'alerts': [],
'usage_percentages': usage_percentages
}
# Calculate usage percentages
usage_percentages = {}
if limits:
for provider in APIProvider:
provider_name = provider.value
current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) or 0
if call_limit > 0:
usage_percentages[f"{provider_name}_calls"] = (current_calls / call_limit) * 100
else:
usage_percentages[f"{provider_name}_calls"] = 0
# Cost usage percentage
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
total_cost = summary.total_cost or 0
if cost_limit > 0:
usage_percentages['cost'] = (total_cost / cost_limit) * 100
else:
usage_percentages['cost'] = 0
# Provider breakdown
provider_breakdown = {}
for provider in APIProvider:
provider_name = provider.value
provider_breakdown[provider_name] = {
'calls': getattr(summary, f"{provider_name}_calls", 0) or 0,
'tokens': getattr(summary, f"{provider_name}_tokens", 0) or 0,
'cost': getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
}
return {
'billing_period': billing_period,
'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
'total_calls': summary.total_calls or 0,
'total_tokens': summary.total_tokens or 0,
'total_cost': summary.total_cost or 0.0,
'avg_response_time': summary.avg_response_time or 0.0,
'error_rate': summary.error_rate or 0.0,
'limits': limits,
'provider_breakdown': provider_breakdown,
'alerts': [
{
'id': alert.id,
'type': alert.alert_type,
'title': alert.title,
'message': alert.message,
'severity': alert.severity,
'created_at': alert.created_at.isoformat()
}
for alert in alerts
],
'usage_percentages': usage_percentages,
'last_updated': summary.updated_at.isoformat()
}
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
"""Get usage trends over time."""
# Calculate billing periods
end_date = datetime.now()
periods = []
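# Step back by calendar month; fixed 30-day steps can repeat or skip a month near month ends
month_index = end_date.year * 12 + (end_date.month - 1)
for i in range(months):
year, month_zero_based = divmod(month_index - i, 12)
periods.append(f"{year:04d}-{month_zero_based + 1:02d}")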
periods.reverse() # Oldest first
# Get usage summaries for these periods
summaries = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(periods)
).order_by(UsageSummary.billing_period).all()
# Create trends data
trends = {
'periods': periods,
'total_calls': [],
'total_cost': [],
'total_tokens': [],
'provider_trends': {}
}
summary_dict = {s.billing_period: s for s in summaries}
for period in periods:
summary = summary_dict.get(period)
if summary:
trends['total_calls'].append(summary.total_calls or 0)
trends['total_cost'].append(summary.total_cost or 0.0)
trends['total_tokens'].append(summary.total_tokens or 0)
# Provider-specific trends
for provider in APIProvider:
provider_name = provider.value
if provider_name not in trends['provider_trends']:
trends['provider_trends'][provider_name] = {
'calls': [],
'cost': [],
'tokens': []
}
trends['provider_trends'][provider_name]['calls'].append(
getattr(summary, f"{provider_name}_calls", 0) or 0
)
trends['provider_trends'][provider_name]['cost'].append(
getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
)
trends['provider_trends'][provider_name]['tokens'].append(
getattr(summary, f"{provider_name}_tokens", 0) or 0
)
else:
# No data for this period
trends['total_calls'].append(0)
trends['total_cost'].append(0.0)
trends['total_tokens'].append(0)
for provider in APIProvider:
provider_name = provider.value
if provider_name not in trends['provider_trends']:
trends['provider_trends'][provider_name] = {
'calls': [],
'cost': [],
'tokens': []
}
trends['provider_trends'][provider_name]['calls'].append(0)
trends['provider_trends'][provider_name]['cost'].append(0.0)
trends['provider_trends'][provider_name]['tokens'].append(0)
return trends
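Billing period keys are calendar months, so listing the last N periods is integer arithmetic on year and month. A minimal sketch, with `last_n_periods` as a hypothetical helper name:

```python
from datetime import date
from typing import List


def last_n_periods(today: date, months: int) -> List[str]:
    """Return the last `months` YYYY-MM keys, oldest first, ending with today's month."""
    month_index = today.year * 12 + (today.month - 1)
    periods = []
    for offset in range(months - 1, -1, -1):
        year, month_zero_based = divmod(month_index - offset, 12)
        periods.append(f"{year:04d}-{month_zero_based + 1:02d}")
    return periods


periods = last_n_periods(date(2025, 3, 31), 6)
```

March 31 is a useful test date: subtracting 30 days from it lands on March 1, so day-based stepping would list March twice and drop a month from the other end of the window.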
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
"""Enforce usage limits before making an API call."""
# Check short-lived cache first (30s)
cache_key = f"{user_id}:{provider.value}"
now = datetime.utcnow()
cached = self._enforce_cache.get(cache_key)
if cached and cached.get('expires_at') and cached['expires_at'] > now:
return tuple(cached['result']) # type: ignore
result = self.pricing_service.check_usage_limits(
user_id=user_id,
provider=provider,
tokens_requested=tokens_requested
)
self._enforce_cache[cache_key] = {
'result': result,
'expires_at': now + timedelta(seconds=30)
}
return result
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
"""Reset usage status for the current billing period (after plan change)."""
try:
billing_period = datetime.now().strftime("%Y-%m")
summary = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == billing_period
).first()
if not summary:
# Nothing to reset
return {"reset": False, "reason": "no_summary"}
# Clear LIMIT_REACHED so the user can resume; keep counters intact
summary.usage_status = UsageStatus.ACTIVE
summary.updated_at = datetime.utcnow()
self.db.commit()
return {"reset": True}
except Exception as e:
self.db.rollback()
logger.error(f"Error resetting usage status: {e}")
return {"reset": False, "error": str(e)}