Add comprehensive usage-based subscription system with API tracking

Co-authored-by: ajay.calsoft <ajay.calsoft@gmail.com>
This commit is contained in:
Cursor Agent
2025-09-04 17:18:27 +00:00
parent d57f7feb4a
commit e0a6150ed1
13 changed files with 3619 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
"""
Enhanced FastAPI Monitoring Middleware
Database-backed monitoring for API calls, errors, and performance metrics.
Database-backed monitoring for API calls, errors, performance metrics, and usage tracking.
Includes comprehensive subscription-based usage monitoring and cost tracking.
"""
from fastapi import Request, Response
@@ -14,12 +15,16 @@ import asyncio
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
import re
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
from models.subscription_models import APIProvider
from services.database import get_db
from services.usage_tracking_service import UsageTrackingService
from services.pricing_service import PricingService
class DatabaseAPIMonitor:
"""Database-backed API monitoring."""
"""Database-backed API monitoring with usage tracking and subscription management."""
def __init__(self):
self.cache_stats = {
@@ -27,12 +32,109 @@ class DatabaseAPIMonitor:
'misses': 0,
'hit_rate': 0.0
}
# API provider detection patterns
self.provider_patterns = {
APIProvider.GEMINI: [r'/gemini', r'gemini', r'google.*ai'],
APIProvider.OPENAI: [r'/openai', r'openai', r'gpt'],
APIProvider.ANTHROPIC: [r'/anthropic', r'claude', r'anthropic'],
APIProvider.MISTRAL: [r'/mistral', r'mistral'],
APIProvider.TAVILY: [r'/tavily', r'tavily'],
APIProvider.SERPER: [r'/serper', r'serper', r'google.*search'],
APIProvider.METAPHOR: [r'/metaphor', r'/exa', r'metaphor', r'exa'],
APIProvider.FIRECRAWL: [r'/firecrawl', r'firecrawl'],
APIProvider.STABILITY: [r'/stability', r'stable.*diffusion', r'stability']
}
def detect_api_provider(self, path: str, user_agent: str = None) -> Optional[APIProvider]:
"""Detect which API provider is being used based on request details."""
path_lower = path.lower()
user_agent_lower = (user_agent or '').lower()
for provider, patterns in self.provider_patterns.items():
for pattern in patterns:
if re.search(pattern, path_lower) or re.search(pattern, user_agent_lower):
return provider
return None
def extract_usage_metrics(self, request_body: str = None, response_body: str = None) -> Dict[str, Any]:
"""Extract usage metrics from request/response bodies."""
metrics = {
'tokens_input': 0,
'tokens_output': 0,
'model_used': None,
'search_count': 0,
'image_count': 0,
'page_count': 0
}
try:
# Try to parse request body for input tokens/content
if request_body:
request_data = json.loads(request_body) if isinstance(request_body, str) else request_body
# Extract model information
if 'model' in request_data:
metrics['model_used'] = request_data['model']
# Estimate input tokens from prompt/content
if 'prompt' in request_data:
metrics['tokens_input'] = self._estimate_tokens(request_data['prompt'])
elif 'messages' in request_data:
total_content = ' '.join([msg.get('content', '') for msg in request_data['messages']])
metrics['tokens_input'] = self._estimate_tokens(total_content)
elif 'input' in request_data:
metrics['tokens_input'] = self._estimate_tokens(str(request_data['input']))
# Count specific request types
if 'query' in request_data or 'search' in request_data:
metrics['search_count'] = 1
if 'image' in request_data or 'generate_image' in request_data:
metrics['image_count'] = 1
if 'url' in request_data or 'crawl' in request_data:
metrics['page_count'] = 1
# Try to parse response body for output tokens
if response_body:
response_data = json.loads(response_body) if isinstance(response_body, str) else response_body
# Extract output content and estimate tokens
if 'text' in response_data:
metrics['tokens_output'] = self._estimate_tokens(response_data['text'])
elif 'content' in response_data:
metrics['tokens_output'] = self._estimate_tokens(str(response_data['content']))
elif 'choices' in response_data and response_data['choices']:
choice = response_data['choices'][0]
if 'message' in choice and 'content' in choice['message']:
metrics['tokens_output'] = self._estimate_tokens(choice['message']['content'])
# Extract actual token usage if provided by API
if 'usage' in response_data:
usage = response_data['usage']
if 'prompt_tokens' in usage:
metrics['tokens_input'] = usage['prompt_tokens']
if 'completion_tokens' in usage:
metrics['tokens_output'] = usage['completion_tokens']
except (json.JSONDecodeError, KeyError, TypeError) as e:
logger.debug(f"Could not extract usage metrics: {e}")
return metrics
def _estimate_tokens(self, text: str) -> int:
"""Estimate token count for text (rough approximation)."""
if not text:
return 0
# Rough estimation: 1.3 tokens per word on average
word_count = len(str(text).split())
return int(word_count * 1.3)
async def add_request(self, db: Session, path: str, method: str, status_code: int,
duration: float, user_id: str = None, cache_hit: bool = None,
request_size: int = None, response_size: int = None,
user_agent: str = None, ip_address: str = None):
"""Add a request to database monitoring."""
user_agent: str = None, ip_address: str = None,
request_body: str = None, response_body: str = None):
"""Add a request to database monitoring with usage tracking."""
try:
# Store individual request
api_request = APIRequest(
@@ -49,6 +151,38 @@ class DatabaseAPIMonitor:
)
db.add(api_request)
# Track API usage if this is an API call to external providers
api_provider = self.detect_api_provider(path, user_agent)
if api_provider and user_id:
try:
# Extract usage metrics
usage_metrics = self.extract_usage_metrics(request_body, response_body)
# Track usage with the usage tracking service
usage_service = UsageTrackingService(db)
await usage_service.track_api_usage(
user_id=user_id,
provider=api_provider,
endpoint=path,
method=method,
model_used=usage_metrics.get('model_used'),
tokens_input=usage_metrics.get('tokens_input', 0),
tokens_output=usage_metrics.get('tokens_output', 0),
response_time=duration,
status_code=status_code,
request_size=request_size,
response_size=response_size,
user_agent=user_agent,
ip_address=ip_address,
search_count=usage_metrics.get('search_count', 0),
image_count=usage_metrics.get('image_count', 0),
page_count=usage_metrics.get('page_count', 0)
)
logger.info(f"Tracked usage for {user_id}: {api_provider.value} - {usage_metrics.get('tokens_input', 0)}+{usage_metrics.get('tokens_output', 0)} tokens")
except Exception as usage_error:
logger.error(f"Error tracking API usage: {usage_error}")
# Don't fail the main request if usage tracking fails
# Update endpoint stats
endpoint_key = f"{method} {path}"
endpoint_stats = db.query(APIEndpointStats).filter(
@@ -249,8 +383,73 @@ def should_monitor_endpoint(path: str) -> bool:
"""Check if an endpoint should be monitored."""
return not any(path.endswith(excluded) for excluded in EXCLUDED_ENDPOINTS)
async def check_usage_limits_middleware(request: Request, user_id: str) -> Optional[JSONResponse]:
"""Check usage limits before processing request."""
if not user_id:
return None
try:
db = next(get_db())
api_monitor = DatabaseAPIMonitor()
# Detect if this is an API call that should be rate limited
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
if not api_provider:
return None
# Get request body to estimate tokens
request_body = None
try:
if hasattr(request, '_body'):
request_body = request._body
else:
# Try to read body (this might not work in all cases)
body = await request.body()
request_body = body.decode('utf-8') if body else None
except:
pass
# Estimate tokens needed
tokens_requested = 0
if request_body:
usage_metrics = api_monitor.extract_usage_metrics(request_body)
tokens_requested = usage_metrics.get('tokens_input', 0)
# Check limits
usage_service = UsageTrackingService(db)
can_proceed, message, usage_info = await usage_service.enforce_usage_limits(
user_id=user_id,
provider=api_provider,
tokens_requested=tokens_requested
)
if not can_proceed:
logger.warning(f"Usage limit exceeded for {user_id}: {message}")
return JSONResponse(
status_code=429,
content={
"error": "Usage limit exceeded",
"message": message,
"usage_info": usage_info,
"provider": api_provider.value
}
)
# Warn if approaching limits
if usage_info.get('call_usage_percentage', 0) >= 80 or usage_info.get('cost_usage_percentage', 0) >= 80:
logger.warning(f"User {user_id} approaching usage limits: {usage_info}")
return None
except Exception as e:
logger.error(f"Error checking usage limits: {e}")
# Don't block requests if usage checking fails
return None
finally:
db.close()
async def monitoring_middleware(request: Request, call_next):
"""Enhanced FastAPI middleware for monitoring API calls."""
"""Enhanced FastAPI middleware for monitoring API calls with usage tracking."""
start_time = time.time()
# Skip monitoring for excluded endpoints
@@ -265,6 +464,29 @@ async def monitoring_middleware(request: Request, call_next):
user_id = request.query_params['user_id']
elif hasattr(request, 'path_params') and 'user_id' in request.path_params:
user_id = request.path_params['user_id']
# Also check headers for user identification
elif 'x-user-id' in request.headers:
user_id = request.headers['x-user-id']
# Check for authorization header with user info
elif 'authorization' in request.headers:
# This would need to be implemented based on your auth system
pass
except:
pass
# Check usage limits before processing
limit_response = await check_usage_limits_middleware(request, user_id)
if limit_response:
return limit_response
# Capture request body for usage tracking
request_body = None
try:
if hasattr(request, '_body'):
request_body = request._body.decode('utf-8') if request._body else None
else:
body = await request.body()
request_body = body.decode('utf-8') if body else None
except:
pass
@@ -276,6 +498,16 @@ async def monitoring_middleware(request: Request, call_next):
status_code = response.status_code
duration = time.time() - start_time
# Capture response body for usage tracking
response_body = None
try:
if hasattr(response, 'body'):
response_body = response.body.decode('utf-8') if response.body else None
elif hasattr(response, '_content'):
response_body = response._content.decode('utf-8') if response._content else None
except:
pass
# Check for cache-related headers
cache_hit = None
if hasattr(response, 'headers'):
@@ -283,7 +515,7 @@ async def monitoring_middleware(request: Request, call_next):
if cache_header:
cache_hit = cache_header.lower() == 'hit'
# Store in database
# Store in database with enhanced tracking
await api_monitor.add_request(
db=db,
path=request.url.path,
@@ -292,8 +524,12 @@ async def monitoring_middleware(request: Request, call_next):
duration=duration,
user_id=user_id,
cache_hit=cache_hit,
request_size=len(request_body) if request_body else None,
response_size=len(response_body) if response_body else None,
user_agent=request.headers.get('user-agent'),
ip_address=request.client.host if request.client else None
ip_address=request.client.host if request.client else None,
request_body=request_body,
response_body=response_body
)
# Add monitoring headers
@@ -306,7 +542,7 @@ async def monitoring_middleware(request: Request, call_next):
duration = time.time() - start_time
status_code = 500
# Store error in database
# Store error in database with enhanced tracking
await api_monitor.add_request(
db=db,
path=request.url.path,
@@ -315,8 +551,12 @@ async def monitoring_middleware(request: Request, call_next):
duration=duration,
user_id=user_id,
cache_hit=False,
request_size=len(request_body) if request_body else None,
response_size=None,
user_agent=request.headers.get('user-agent'),
ip_address=request.client.host if request.client else None
ip_address=request.client.host if request.client else None,
request_body=request_body,
response_body=None
)
logger.error(f"❌ API Error: {request.method} {request.url.path} - {str(e)}")