"""
Enhanced FastAPI Monitoring Middleware

Database-backed monitoring for API calls, errors, performance metrics, and usage tracking.
Includes comprehensive subscription-based usage monitoring and cost tracking.
"""

import asyncio
import json
import re
import time
from collections import defaultdict, deque
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from loguru import logger
from sqlalchemy import and_, func
from sqlalchemy.orm import Session

from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
from models.subscription_models import APIProvider
from services.database import get_db
from services.usage_tracking_service import UsageTrackingService
from services.pricing_service import PricingService


class DatabaseAPIMonitor:
    """Database-backed API monitoring with usage tracking and subscription management.

    Individual requests (``APIRequest``) and per-endpoint aggregates
    (``APIEndpointStats``) are written through the SQLAlchemy session passed
    to each method. Cache hit/miss counters in ``self.cache_stats`` live in
    memory only, so they are per-process and reset on restart.

    NOTE(review): the ``Session`` calls below are synchronous and therefore run
    on the event-loop thread inside these ``async`` methods.
    """

    def __init__(self):
        # Process-local cache counters; advanced by add_request() only after a
        # successful commit and reported (as a copy) by get_stats().
        self.cache_stats = {
            'hits': 0,
            'misses': 0,
            'hit_rate': 0.0
        }

        # API provider detection patterns, applied with re.search to the
        # lower-cased path and user agent. Insertion order matters: the first
        # provider with any matching pattern wins.
        # NOTE(review): short unanchored patterns such as r'gpt' and r'exa' also
        # match unrelated text (e.g. "/api/examples"), which would subject those
        # requests to provider usage limits -- consider tightening them.
        self.provider_patterns = {
            APIProvider.GEMINI: [r'/gemini', r'gemini', r'google.*ai'],
            APIProvider.OPENAI: [r'/openai', r'openai', r'gpt'],
            APIProvider.ANTHROPIC: [r'/anthropic', r'claude', r'anthropic'],
            APIProvider.MISTRAL: [r'/mistral', r'mistral'],
            APIProvider.TAVILY: [r'/tavily', r'tavily'],
            APIProvider.SERPER: [r'/serper', r'serper', r'google.*search'],
            APIProvider.METAPHOR: [r'/metaphor', r'/exa', r'metaphor', r'exa'],
            APIProvider.FIRECRAWL: [r'/firecrawl', r'firecrawl'],
            APIProvider.STABILITY: [r'/stability', r'stable.*diffusion', r'stability']
        }

    def detect_api_provider(self, path: str, user_agent: str = None) -> Optional[APIProvider]:
        """Detect which API provider a request targets.

        Args:
            path: Request URL path (``None`` is treated as empty).
            user_agent: Optional ``User-Agent`` header value.

        Returns:
            The first provider, in ``provider_patterns`` order, with a pattern
            matching the path or the user agent; ``None`` if nothing matches.
        """
        path_lower = (path or '').lower()
        user_agent_lower = (user_agent or '').lower()

        for provider, patterns in self.provider_patterns.items():
            for pattern in patterns:
                if re.search(pattern, path_lower) or re.search(pattern, user_agent_lower):
                    return provider

        return None

    def extract_usage_metrics(self, request_body: str = None, response_body: str = None) -> Dict[str, Any]:
        """Extract usage metrics from request/response bodies.

        Bodies may be JSON text (``str``/``bytes``) or already-parsed objects;
        only JSON objects (dicts) are inspected. The request and the response
        are processed independently, so a malformed request payload does not
        discard metrics from a valid response (and vice versa).

        Args:
            request_body: Outgoing request payload, or ``None``.
            response_body: Received response payload, or ``None``.

        Returns:
            Dict with ``tokens_input``, ``tokens_output``, ``model_used``,
            ``search_count``, ``image_count`` and ``page_count``. Token counts
            are word-based estimates unless the response carries a ``usage``
            object with integer counts, which then take precedence.
        """
        metrics = {
            'tokens_input': 0,
            'tokens_output': 0,
            'model_used': None,
            'search_count': 0,
            'image_count': 0,
            'page_count': 0
        }

        request_data = self._parse_json_body(request_body)
        if isinstance(request_data, dict):
            try:
                self._apply_request_metrics(metrics, request_data)
            except (KeyError, TypeError, AttributeError) as e:
                logger.debug(f"Could not extract usage metrics: {e}")

        response_data = self._parse_json_body(response_body)
        if isinstance(response_data, dict):
            try:
                self._apply_response_metrics(metrics, response_data)
            except (KeyError, TypeError, AttributeError) as e:
                logger.debug(f"Could not extract usage metrics: {e}")

        return metrics

    @staticmethod
    def _parse_json_body(body: Any) -> Any:
        """Decode *body* from JSON when it is text/bytes; return it unchanged otherwise, None on failure."""
        if not body:
            return None
        if isinstance(body, (str, bytes, bytearray)):
            try:
                return json.loads(body)
            except ValueError as e:  # covers JSONDecodeError and UnicodeDecodeError
                logger.debug(f"Could not extract usage metrics: {e}")
                return None
        return body

    @staticmethod
    def _content_to_text(content: Any) -> str:
        """Flatten a message ``content`` value (string, list of text parts, or None) to plain text."""
        if content is None:
            return ''
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            # Multi-part content, e.g. [{"type": "text", "text": "..."}, ...];
            # only textual parts contribute to the estimate.
            texts = []
            for part in content:
                if isinstance(part, str):
                    texts.append(part)
                elif isinstance(part, dict) and isinstance(part.get('text'), str):
                    texts.append(part['text'])
            return ' '.join(texts)
        return str(content)

    def _apply_request_metrics(self, metrics: Dict[str, Any], request_data: Dict[str, Any]) -> None:
        """Fill model, input-token estimate and request-type counters from a request payload."""
        if 'model' in request_data:
            metrics['model_used'] = request_data['model']

        # Estimate input tokens from prompt/content
        if 'prompt' in request_data:
            metrics['tokens_input'] = self._estimate_tokens(request_data['prompt'])
        elif 'messages' in request_data:
            messages = request_data['messages'] or []
            total_content = ' '.join(
                self._content_to_text(msg.get('content'))
                for msg in messages
                if isinstance(msg, dict)
            )
            metrics['tokens_input'] = self._estimate_tokens(total_content)
        elif 'input' in request_data:
            metrics['tokens_input'] = self._estimate_tokens(str(request_data['input']))

        # Count specific request types
        if 'query' in request_data or 'search' in request_data:
            metrics['search_count'] = 1
        if 'image' in request_data or 'generate_image' in request_data:
            metrics['image_count'] = 1
        if 'url' in request_data or 'crawl' in request_data:
            metrics['page_count'] = 1

    def _apply_response_metrics(self, metrics: Dict[str, Any], response_data: Dict[str, Any]) -> None:
        """Fill the output-token estimate, then override token counts with any reported ``usage``."""
        if 'text' in response_data:
            metrics['tokens_output'] = self._estimate_tokens(response_data['text'])
        elif 'content' in response_data:
            metrics['tokens_output'] = self._estimate_tokens(self._content_to_text(response_data['content']))
        elif response_data.get('choices'):
            choices = response_data['choices']
            choice = choices[0] if isinstance(choices, list) else None
            message = choice.get('message') if isinstance(choice, dict) else None
            if isinstance(message, dict) and 'content' in message:
                metrics['tokens_output'] = self._estimate_tokens(self._content_to_text(message['content']))

        # Actual token usage reported by the API wins over the estimates.
        # OpenAI-style keys are checked first, then Anthropic-style ones.
        usage = response_data.get('usage')
        if isinstance(usage, dict):
            for key in ('prompt_tokens', 'input_tokens'):
                value = usage.get(key)
                if isinstance(value, int):
                    metrics['tokens_input'] = value
                    break
            for key in ('completion_tokens', 'output_tokens'):
                value = usage.get(key)
                if isinstance(value, int):
                    metrics['tokens_output'] = value
                    break

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count for text (rough approximation: 1.3 tokens per whitespace-separated word)."""
        if not text:
            return 0
        word_count = len(str(text).split())
        return int(word_count * 1.3)

    async def add_request(self, db: Session, path: str, method: str, status_code: int,
                          duration: float, user_id: str = None, cache_hit: bool = None,
                          request_size: int = None, response_size: int = None,
                          user_agent: str = None, ip_address: str = None,
                          request_body: str = None, response_body: str = None):
        """Persist one request, its provider usage, and the endpoint aggregates.

        Database errors are logged and the session is rolled back instead of
        being propagated, so monitoring problems do not fail the monitored
        request. In-memory cache counters advance only after a successful commit.

        Args:
            db: Open SQLAlchemy session (committed or rolled back here, not closed).
            path, method, status_code: Basic request/response identification.
            duration: Elapsed handling time in seconds.
            user_id: Caller identity; provider usage is tracked only when set.
            cache_hit: True/False from the cache header, or None when unknown.
            request_size, response_size: Body sizes, or None.
            user_agent, ip_address: Client details.
            request_body, response_body: Decoded bodies used for usage metrics.
        """
        try:
            db.add(APIRequest(
                path=path,
                method=method,
                status_code=status_code,
                duration=duration,
                user_id=user_id,
                cache_hit=cache_hit,
                request_size=request_size,
                response_size=response_size,
                user_agent=user_agent,
                ip_address=ip_address
            ))

            # Track API usage if this is an API call to external providers
            api_provider = self.detect_api_provider(path, user_agent)
            if api_provider and user_id:
                await self._track_provider_usage(
                    db, api_provider, user_id, path, method, status_code, duration,
                    request_size, response_size, user_agent, ip_address,
                    request_body, response_body
                )

            self._update_endpoint_stats(db, f"{method} {path}", status_code, duration, cache_hit)

            db.commit()

            self._update_cache_stats(cache_hit)

        except Exception as e:
            logger.error(f"❌ Error storing API request: {str(e)}")
            db.rollback()

    async def _track_provider_usage(self, db: Session, api_provider: APIProvider, user_id: str,
                                    path: str, method: str, status_code: int, duration: float,
                                    request_size: Optional[int], response_size: Optional[int],
                                    user_agent: Optional[str], ip_address: Optional[str],
                                    request_body: Optional[str], response_body: Optional[str]) -> None:
        """Record provider usage through UsageTrackingService; failures are logged, not raised."""
        try:
            usage_metrics = self.extract_usage_metrics(request_body, response_body)

            usage_service = UsageTrackingService(db)
            await usage_service.track_api_usage(
                user_id=user_id,
                provider=api_provider,
                endpoint=path,
                method=method,
                model_used=usage_metrics.get('model_used'),
                tokens_input=usage_metrics.get('tokens_input', 0),
                tokens_output=usage_metrics.get('tokens_output', 0),
                response_time=duration,
                status_code=status_code,
                request_size=request_size,
                response_size=response_size,
                user_agent=user_agent,
                ip_address=ip_address,
                search_count=usage_metrics.get('search_count', 0),
                image_count=usage_metrics.get('image_count', 0),
                page_count=usage_metrics.get('page_count', 0)
            )

            logger.info(f"Tracked usage for {user_id}: {api_provider.value} - {usage_metrics.get('tokens_input', 0)}+{usage_metrics.get('tokens_output', 0)} tokens")

        except Exception as usage_error:
            # Don't fail the main request if usage tracking fails
            logger.error(f"Error tracking API usage: {usage_error}")

    @staticmethod
    def _update_endpoint_stats(db: Session, endpoint_key: str, status_code: int,
                               duration: float, cache_hit: Optional[bool]) -> None:
        """Create or update the ``APIEndpointStats`` row for *endpoint_key* (not committed here)."""
        endpoint_stats = db.query(APIEndpointStats).filter(
            APIEndpointStats.endpoint == endpoint_key
        ).first()

        if not endpoint_stats:
            endpoint_stats = APIEndpointStats(endpoint=endpoint_key)
            db.add(endpoint_stats)

        # Counters may be None (e.g. on a freshly created row), hence the `or 0`.
        endpoint_stats.total_requests = (endpoint_stats.total_requests or 0) + 1
        endpoint_stats.total_duration = (endpoint_stats.total_duration or 0.0) + duration
        endpoint_stats.avg_duration = endpoint_stats.total_duration / endpoint_stats.total_requests
        endpoint_stats.last_called = datetime.utcnow()

        if status_code >= 400:
            endpoint_stats.total_errors = (endpoint_stats.total_errors or 0) + 1

        if cache_hit is not None:
            if cache_hit:
                endpoint_stats.cache_hits = (endpoint_stats.cache_hits or 0) + 1
            else:
                endpoint_stats.cache_misses = (endpoint_stats.cache_misses or 0) + 1

            # The counter that was not incremented may still be None; adding it
            # directly would raise TypeError and roll back the whole record.
            hits = endpoint_stats.cache_hits or 0
            misses = endpoint_stats.cache_misses or 0
            total_cache_requests = hits + misses
            if total_cache_requests > 0:
                endpoint_stats.cache_hit_rate = (hits / total_cache_requests) * 100

        if endpoint_stats.min_duration is None or duration < endpoint_stats.min_duration:
            endpoint_stats.min_duration = duration
        if endpoint_stats.max_duration is None or duration > endpoint_stats.max_duration:
            endpoint_stats.max_duration = duration

    def _update_cache_stats(self, cache_hit: Optional[bool]) -> None:
        """Advance the in-memory hit/miss counters and recompute the hit rate (percent)."""
        if cache_hit is None:
            return
        if cache_hit:
            self.cache_stats['hits'] += 1
        else:
            self.cache_stats['misses'] += 1

        total_cache_requests = self.cache_stats['hits'] + self.cache_stats['misses']
        if total_cache_requests > 0:
            self.cache_stats['hit_rate'] = (self.cache_stats['hits'] / total_cache_requests) * 100

    async def get_stats(self, db: Session, minutes: int = 5) -> Dict[str, Any]:
        """Get current monitoring statistics from the database.

        Args:
            db: Open SQLAlchemy session.
            minutes: Length of the "recent" window ending now (naive UTC).

        Returns:
            Dict with ``timestamp``, ``overview`` counts, ``cache_performance``
            (a snapshot copy of the in-memory counters), ``top_endpoints`` (10
            most requested), ``recent_errors`` (10 latest status >= 400 within
            the window) and ``system_health``. On failure a reduced dict with
            an ``error`` message is returned instead of raising.
        """
        try:
            now = datetime.utcnow()
            since = now - timedelta(minutes=minutes)

            recent_requests = db.query(APIRequest).filter(
                APIRequest.timestamp >= since
            ).count()

            recent_errors = db.query(APIRequest).filter(
                and_(
                    APIRequest.timestamp >= since,
                    APIRequest.status_code >= 400
                )
            ).count()

            top_endpoints = db.query(APIEndpointStats).order_by(
                APIEndpointStats.total_requests.desc()
            ).limit(10).all()

            recent_error_details = db.query(APIRequest).filter(
                and_(
                    APIRequest.timestamp >= since,
                    APIRequest.status_code >= 400
                )
            ).order_by(APIRequest.timestamp.desc()).limit(10).all()

            total_requests = db.query(APIRequest).count()
            total_errors = db.query(APIRequest).filter(APIRequest.status_code >= 400).count()

            error_rate = (recent_errors / max(recent_requests, 1)) * 100

            return {
                'timestamp': now.isoformat(),
                'overview': {
                    'total_requests': total_requests,
                    'total_errors': total_errors,
                    'recent_requests': recent_requests,
                    'recent_errors': recent_errors
                },
                # Copy so callers cannot mutate the live counters.
                'cache_performance': dict(self.cache_stats),
                'top_endpoints': [
                    {
                        'endpoint': endpoint.endpoint,
                        'count': endpoint.total_requests or 0,
                        'avg_time': round(endpoint.avg_duration or 0.0, 3),
                        'errors': endpoint.total_errors or 0,
                        'last_called': endpoint.last_called.isoformat() if endpoint.last_called else None,
                        'cache_hit_rate': round(endpoint.cache_hit_rate or 0.0, 2)
                    }
                    for endpoint in top_endpoints
                ],
                'recent_errors': [
                    {
                        'timestamp': error.timestamp.isoformat() if error.timestamp else None,
                        'path': error.path,
                        'method': error.method,
                        'status_code': error.status_code,
                        'duration': error.duration
                    }
                    for error in recent_error_details
                ],
                'system_health': {
                    'status': 'healthy' if recent_errors < 5 else 'warning',
                    'error_rate': round(error_rate, 2)
                }
            }

        except Exception as e:
            logger.error(f"❌ Error getting monitoring stats: {str(e)}")
            return {
                'timestamp': datetime.utcnow().isoformat(),
                'error': str(e),
                'overview': {'total_requests': 0, 'total_errors': 0, 'recent_requests': 0, 'recent_errors': 0},
                'system_health': {'status': 'unknown', 'error_rate': 0.0}
            }

    async def get_lightweight_stats(self, db: Session) -> Dict[str, Any]:
        """Get lightweight stats for the dashboard header (last 5 minutes).

        Status is "healthy" with 0 recent errors, "warning" with 1-2 and
        "critical" with 3 or more. On failure an "unknown" status dict is
        returned instead of raising.
        """
        try:
            now = datetime.utcnow()
            since = now - timedelta(minutes=5)

            recent_requests = db.query(APIRequest).filter(
                APIRequest.timestamp >= since
            ).count()

            recent_errors = db.query(APIRequest).filter(
                and_(
                    APIRequest.timestamp >= since,
                    APIRequest.status_code >= 400
                )
            ).count()

            if recent_errors == 0:
                status = "healthy"
                icon = "🟢"
            elif recent_errors < 3:
                status = "warning"
                icon = "🟡"
            else:
                status = "critical"
                icon = "🔴"

            return {
                'status': status,
                'icon': icon,
                'recent_requests': recent_requests,
                'recent_errors': recent_errors,
                'error_rate': round((recent_errors / max(recent_requests, 1)) * 100, 1),
                'timestamp': now.isoformat()
            }

        except Exception as e:
            logger.error(f"❌ Error getting lightweight stats: {str(e)}")
            return {
                'status': 'unknown',
                'icon': '⚪',
                'recent_requests': 0,
                'recent_errors': 0,
                'error_rate': 0.0,
                'timestamp': datetime.utcnow().isoformat()
            }


# Global monitor instance
api_monitor = DatabaseAPIMonitor()

# Endpoints excluded from monitoring (matched as path suffixes)
EXCLUDED_ENDPOINTS = [
    "/api/content-planning/monitoring/lightweight-stats",
    "/api/content-planning/monitoring/api-stats",
    "/api/content-planning/monitoring/cache-stats",
    "/api/content-planning/monitoring/health"
]


def should_monitor_endpoint(path: str) -> bool:
    """Return False when *path* ends with one of ``EXCLUDED_ENDPOINTS``."""
    return not any(path.endswith(excluded) for excluded in EXCLUDED_ENDPOINTS)


def _decode_body(raw: Any) -> Optional[str]:
    """Decode a raw body to UTF-8 text; None when empty, not bytes-like, or not valid UTF-8."""
    if not raw:
        return None
    if isinstance(raw, str):
        return raw
    if not isinstance(raw, (bytes, bytearray, memoryview)):
        return None
    try:
        return bytes(raw).decode('utf-8')
    except UnicodeDecodeError:
        return None


async def _read_request_body(request: Request) -> Tuple[Optional[str], Optional[int]]:
    """Read the request body once; return (UTF-8 text or None, size in bytes or None).

    NOTE(review): reading the body inside an HTTP middleware relies on
    Starlette caching it (``request._body``) for the downstream handler; some
    older Starlette versions could stall the endpoint -- confirm the pinned version.
    """
    try:
        raw = getattr(request, '_body', None)
        if raw is None:
            raw = await request.body()
    except Exception as e:
        logger.debug(f"Could not read request body: {e}")
        return None, None

    if not raw:
        return None, None
    return _decode_body(raw), len(raw)


def _read_response_body(response: Any) -> Tuple[Optional[str], Optional[int]]:
    """Return an already-rendered response body as (UTF-8 text or None, size in bytes or None).

    Responses without a ``body``/``_content`` attribute (e.g. streaming
    responses) yield (None, None); no body stream is consumed here.
    NOTE(review): responses returned by ``call_next`` are presumably streaming
    wrappers, so response-side usage metrics may always be empty -- verify.
    """
    try:
        if hasattr(response, 'body'):
            raw = response.body
        elif hasattr(response, '_content'):
            raw = response._content
        else:
            return None, None
        if not raw:
            return None, None
        return _decode_body(raw), len(raw)
    except Exception as e:
        logger.debug(f"Could not read response body: {e}")
        return None, None


def _extract_user_id(request: Request) -> Optional[str]:
    """Best-effort user id: ``user_id`` query param, then path param, then ``X-User-Id`` header."""
    try:
        if 'user_id' in request.query_params:
            return request.query_params['user_id']
        if 'user_id' in request.path_params:
            return request.path_params['user_id']
        if 'x-user-id' in request.headers:
            return request.headers['x-user-id']
        # Identification via the Authorization header is not implemented; it
        # would need to be wired to the application's auth system.
    except Exception as e:
        logger.debug(f"Could not extract user id: {e}")
    return None


async def _record_request(**fields: Any) -> None:
    """Open a short-lived DB session and pass *fields* to ``api_monitor.add_request``.

    Failures to obtain a session are logged, so they never fail the request.
    """
    db = None
    try:
        db = next(get_db())
        await api_monitor.add_request(db=db, **fields)
    except Exception as e:
        logger.error(f"❌ Error storing API request: {str(e)}")
    finally:
        if db is not None:
            db.close()


async def check_usage_limits_middleware(request: Request, user_id: str, request_body: str = None) -> Optional[JSONResponse]:
    """Check usage limits before processing request.

    Args:
        request: Incoming request (path and ``User-Agent`` select the provider).
        user_id: Caller identity; no check is done without one.
        request_body: Already-read body text; read from *request* when None.

    Returns:
        A 429 ``JSONResponse`` when ``enforce_usage_limits`` refuses the call,
        otherwise None. Errors while checking are logged and the request is
        allowed through (fail-open).
    """
    if not user_id:
        return None

    db = None
    try:
        api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
        if not api_provider:
            return None

        if request_body is None:
            request_body, _ = await _read_request_body(request)

        tokens_requested = 0
        if request_body:
            usage_metrics = api_monitor.extract_usage_metrics(request_body)
            tokens_requested = usage_metrics.get('tokens_input', 0)

        # Session is opened only for provider calls that actually need a check.
        db = next(get_db())
        usage_service = UsageTrackingService(db)
        can_proceed, message, usage_info = await usage_service.enforce_usage_limits(
            user_id=user_id,
            provider=api_provider,
            tokens_requested=tokens_requested
        )

        if not can_proceed:
            logger.warning(f"Usage limit exceeded for {user_id}: {message}")
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Usage limit exceeded",
                    "message": message,
                    "usage_info": usage_info,
                    "provider": api_provider.value
                }
            )

        # Warn if approaching limits
        info = usage_info or {}
        if info.get('call_usage_percentage', 0) >= 80 or info.get('cost_usage_percentage', 0) >= 80:
            logger.warning(f"User {user_id} approaching usage limits: {usage_info}")

        return None

    except Exception as e:
        # Don't block requests if usage checking fails
        logger.error(f"Error checking usage limits: {e}")
        return None
    finally:
        if db is not None:
            db.close()


async def monitoring_middleware(request: Request, call_next):
    """Enhanced FastAPI middleware for monitoring API calls with usage tracking.

    Flow: skip excluded endpoints; read the request body once; enforce usage
    limits (possibly returning 429); run the handler; then record the request
    using a DB session opened only for the recording step (no pooled
    connection is held while the handler runs). Handler exceptions are
    recorded as 500s and turned into a JSON 500 response.
    """
    start_time = time.perf_counter()  # monotonic, unaffected by clock changes

    if not should_monitor_endpoint(request.url.path):
        return await call_next(request)

    user_id = _extract_user_id(request)
    request_body, request_size = await _read_request_body(request)

    limit_response = await check_usage_limits_middleware(request, user_id, request_body)
    if limit_response is not None:
        return limit_response

    common_fields = dict(
        path=request.url.path,
        method=request.method,
        user_id=user_id,
        request_size=request_size,
        user_agent=request.headers.get('user-agent'),
        ip_address=request.client.host if request.client else None,
        request_body=request_body,
    )

    # Only the handler call is guarded, so bookkeeping below can never
    # re-record a successful request as a 500.
    try:
        response = await call_next(request)
    except Exception as e:
        duration = time.perf_counter() - start_time
        await _record_request(
            status_code=500,
            duration=duration,
            cache_hit=False,
            response_size=None,
            response_body=None,
            **common_fields
        )
        logger.exception(f"❌ API Error: {request.method} {request.url.path} - {str(e)}")
        return JSONResponse(
            status_code=500,
            content={"error": "Internal server error", "monitor_id": int(time.time())}
        )

    duration = time.perf_counter() - start_time
    response_body, response_size = _read_response_body(response)

    cache_hit = None
    cache_header = response.headers.get('x-cache-status') if hasattr(response, 'headers') else None
    if cache_header:
        cache_hit = cache_header.lower() == 'hit'

    await _record_request(
        status_code=response.status_code,
        duration=duration,
        cache_hit=cache_hit,
        response_size=response_size,
        response_body=response_body,
        **common_fields
    )

    response.headers['x-response-time'] = f"{duration:.3f}s"
    response.headers['x-monitor-id'] = f"{int(time.time())}"
    return response


async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
    """Get current monitoring statistics for the last *minutes* minutes."""
    db = next(get_db())
    try:
        return await api_monitor.get_stats(db, minutes)
    finally:
        db.close()


async def get_lightweight_stats() -> Dict[str, Any]:
    """Get lightweight stats for dashboard header."""
    db = next(get_db())
    try:
        return await api_monitor.get_lightweight_stats(db)
    finally:
        db.close()