Base code
backend/middleware/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
# Makes the middleware directory a Python package
backend/middleware/api_key_injection_middleware.py (new file, 123 lines)
@@ -0,0 +1,123 @@
"""
API Key Injection Middleware

Temporarily injects user-specific API keys into os.environ for the duration of the request.
This allows existing code that uses os.getenv('GEMINI_API_KEY') to work without modification.

IMPORTANT: This is a compatibility layer. For new code, use UserAPIKeyContext directly.
"""

import os
from fastapi import Request
from loguru import logger
from typing import Callable
from services.user_api_key_context import user_api_keys


class APIKeyInjectionMiddleware:
    """
    Middleware that injects user-specific API keys into environment variables
    for the duration of each request.
    """

    def __init__(self):
        self.original_keys = {}

    async def __call__(self, request: Request, call_next: Callable):
        """
        Inject user-specific API keys before processing the request,
        restore original values after the request completes.
        """

        # Try to extract user_id from Authorization header
        user_id = None
        auth_header = request.headers.get('Authorization')

        if auth_header and auth_header.startswith('Bearer '):
            try:
                from middleware.auth_middleware import clerk_auth
                token = auth_header.replace('Bearer ', '')
                user = await clerk_auth.verify_token(token)
                if user:
                    # Try different possible keys for user_id
                    user_id = user.get('user_id') or user.get('clerk_user_id') or user.get('id')
                    if user_id:
                        logger.info(f"[API Key Injection] Extracted user_id: {user_id}")

                        # Store user_id in request.state for monitoring middleware
                        request.state.user_id = user_id
                    else:
                        logger.warning(f"[API Key Injection] User object missing ID: {user}")
                else:
                    # Token verification failed (likely expired) - log at debug level to reduce noise
                    logger.debug("[API Key Injection] Token verification failed (likely expired token)")
            except Exception as e:
                logger.error(f"[API Key Injection] Could not extract user from token: {e}")

        if not user_id:
            # No authenticated user, proceed without injection
            return await call_next(request)

        # Check if we're in production mode
        is_production = os.getenv('DEPLOY_ENV', 'local') == 'production'

        if not is_production:
            # Local mode - keys already in .env, no injection needed
            return await call_next(request)

        # Get user-specific API keys from database
        with user_api_keys(user_id) as user_keys:
            if not user_keys:
                logger.warning(f"No API keys found for user {user_id}")
                return await call_next(request)

            # Save original environment values
            original_keys = {}
            keys_to_inject = {
                'gemini': 'GEMINI_API_KEY',
                'exa': 'EXA_API_KEY',
                'copilotkit': 'COPILOTKIT_API_KEY',
                'openai': 'OPENAI_API_KEY',
                'anthropic': 'ANTHROPIC_API_KEY',
                'tavily': 'TAVILY_API_KEY',
                'serper': 'SERPER_API_KEY',
                'firecrawl': 'FIRECRAWL_API_KEY',
            }

            # Inject user-specific keys into environment
            for provider, env_var in keys_to_inject.items():
                if provider in user_keys and user_keys[provider]:
                    # Save original value (if any)
                    original_keys[env_var] = os.environ.get(env_var)
                    # Inject user-specific key
                    os.environ[env_var] = user_keys[provider]
                    logger.debug(f"[PRODUCTION] Injected {env_var} for user {user_id}")

            try:
                # Process request with user-specific keys in environment
                response = await call_next(request)
                return response

            finally:
                # CRITICAL: Restore original environment values
                for env_var, original_value in original_keys.items():
                    if original_value is None:
                        # Key didn't exist before, remove it
                        os.environ.pop(env_var, None)
                    else:
                        # Restore original value
                        os.environ[env_var] = original_value

                logger.debug(f"[PRODUCTION] Cleaned up environment for user {user_id}")


async def api_key_injection_middleware(request: Request, call_next: Callable):
    """
    Middleware function that injects user-specific API keys into environment.

    Usage in app.py:
        app.middleware("http")(api_key_injection_middleware)
    """
    middleware = APIKeyInjectionMiddleware()
    return await middleware(request, call_next)
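The save/inject/restore discipline above can be exercised in isolation. The sketch below is a minimal, self-contained approximation of that pattern; fetch_user_keys is a hypothetical stand-in for the user_api_keys context manager from services.user_api_key_context, which this commit does not show.

import os
from contextlib import contextmanager

@contextmanager
def fetch_user_keys(user_id: str):
    # Hypothetical stand-in: yields a provider -> key mapping for the user.
    yield {"gemini": f"key-for-{user_id}"}

def run_with_injected_keys(user_id: str, fn):
    """Apply the same save/inject/restore discipline as the middleware."""
    with fetch_user_keys(user_id) as user_keys:
        original = {}
        for provider, env_var in {"gemini": "GEMINI_API_KEY"}.items():
            if user_keys.get(provider):
                original[env_var] = os.environ.get(env_var)  # save (may be None)
                os.environ[env_var] = user_keys[provider]    # inject
        try:
            return fn()
        finally:
            for env_var, value in original.items():
                if value is None:
                    os.environ.pop(env_var, None)  # key did not exist before
                else:
                    os.environ[env_var] = value    # restore prior value

print(run_with_injected_keys("u1", lambda: os.getenv("GEMINI_API_KEY")))
# -> key-for-u1; afterwards GEMINI_API_KEY is back to its previous state

Note that because os.environ is process-global, concurrent requests for different users can still observe each other's injected keys; the module's own docstring frames this as a compatibility layer rather than a final design.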

backend/middleware/auth_middleware.py (new file, 348 lines)
@@ -0,0 +1,348 @@
"""Authentication middleware for ALwrity backend."""

import os
from typing import Optional, Dict, Any
from fastapi import HTTPException, Depends, status, Request, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from loguru import logger
from dotenv import load_dotenv

# Try to import fastapi-clerk-auth, fall back to custom implementation
try:
    from fastapi_clerk_auth import ClerkHTTPBearer, ClerkConfig
    CLERK_AUTH_AVAILABLE = True
except ImportError:
    CLERK_AUTH_AVAILABLE = False
    logger.warning("fastapi-clerk-auth not available, using custom implementation")

# Load environment variables from the correct path
# Get the backend directory path (parent of middleware directory)
_backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_env_path = os.path.join(_backend_dir, ".env")
load_dotenv(_env_path, override=False)  # Don't override if already loaded

# Initialize security scheme
security = HTTPBearer(auto_error=False)


class ClerkAuthMiddleware:
    """Clerk authentication middleware using fastapi-clerk-auth or custom implementation."""

    def __init__(self):
        """Initialize Clerk authentication middleware."""
        self.clerk_secret_key = os.getenv('CLERK_SECRET_KEY', '').strip()
        # Check for both backend and frontend naming conventions
        publishable_key = (
            os.getenv('CLERK_PUBLISHABLE_KEY') or
            os.getenv('REACT_APP_CLERK_PUBLISHABLE_KEY', '')
        )
        self.clerk_publishable_key = publishable_key.strip() if publishable_key else None
        self.disable_auth = os.getenv('DISABLE_AUTH', 'false').lower() == 'true'

        # Cache for PyJWKClient to avoid repeated JWKS fetches
        self._jwks_client_cache = {}
        self._jwks_url_cache = None

        if not self.clerk_secret_key and not self.disable_auth:
            logger.warning("CLERK_SECRET_KEY not found, authentication may fail")

        # Initialize fastapi-clerk-auth if available
        if CLERK_AUTH_AVAILABLE and not self.disable_auth:
            try:
                if self.clerk_secret_key and self.clerk_publishable_key:
                    # Extract instance from publishable key for JWKS URL
                    # Format: pk_test_<instance>.<domain> or pk_live_<instance>.<domain>
                    parts = self.clerk_publishable_key.replace('pk_test_', '').replace('pk_live_', '').split('.')
                    if len(parts) >= 1:
                        # Extract the domain from publishable key or use default
                        # Clerk URLs are typically: https://<instance>.clerk.accounts.dev
                        instance = parts[0]
                        jwks_url = f"https://{instance}.clerk.accounts.dev/.well-known/jwks.json"

                        # Create Clerk configuration with JWKS URL
                        clerk_config = ClerkConfig(
                            secret_key=self.clerk_secret_key,
                            jwks_url=jwks_url
                        )
                        # Create ClerkHTTPBearer instance for dependency injection
                        self.clerk_bearer = ClerkHTTPBearer(clerk_config)
                        logger.info(f"fastapi-clerk-auth initialized successfully with JWKS URL: {jwks_url}")
                    else:
                        logger.warning("Could not extract instance from publishable key")
                        self.clerk_bearer = None
                else:
                    logger.warning("CLERK_SECRET_KEY or CLERK_PUBLISHABLE_KEY not found")
                    self.clerk_bearer = None
            except Exception as e:
                logger.error(f"Failed to initialize fastapi-clerk-auth: {e}")
                self.clerk_bearer = None
        else:
            self.clerk_bearer = None

        logger.info(f"ClerkAuthMiddleware initialized - Auth disabled: {self.disable_auth}, fastapi-clerk-auth: {CLERK_AUTH_AVAILABLE}")

    async def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
        """Verify Clerk JWT using fastapi-clerk-auth or custom implementation."""
        try:
            if self.disable_auth:
                logger.info("Authentication disabled, returning mock user")
                return {
                    'id': 'mock_user_id',
                    'email': 'mock@example.com',
                    'first_name': 'Mock',
                    'last_name': 'User',
                    'clerk_user_id': 'mock_clerk_user_id'
                }

            if not self.clerk_secret_key:
                logger.error("CLERK_SECRET_KEY not configured")
                return None

            # Use fastapi-clerk-auth if available
            if self.clerk_bearer:
                try:
                    # Decode and verify the JWT token
                    import jwt
                    from jwt import PyJWKClient

                    # Inspect the token header (unverified)
                    unverified_header = jwt.get_unverified_header(token)

                    # Decode token to get issuer for JWKS URL
                    unverified_claims = jwt.decode(token, options={"verify_signature": False})
                    issuer = unverified_claims.get('iss', '')

                    # Construct JWKS URL from issuer
                    jwks_url = f"{issuer}/.well-known/jwks.json"

                    # Use cached PyJWKClient to avoid repeated JWKS fetches
                    if jwks_url not in self._jwks_client_cache:
                        logger.info(f"Creating new PyJWKClient for {jwks_url} with caching enabled")
                        # Create client with caching enabled (cache_keys=True keeps keys in memory)
                        self._jwks_client_cache[jwks_url] = PyJWKClient(
                            jwks_url,
                            cache_keys=True,
                            max_cached_keys=16
                        )

                    jwks_client = self._jwks_client_cache[jwks_url]
                    signing_key = jwks_client.get_signing_key_from_jwt(token)

                    # Verify and decode the token with clock skew tolerance
                    # Add 300 seconds (5 minutes) leeway to handle clock skew and token refresh delays
                    decoded_token = jwt.decode(
                        token,
                        signing_key.key,
                        algorithms=["RS256"],
                        options={"verify_signature": True, "verify_exp": True},
                        leeway=300  # Allow 5 minutes leeway for token refresh during navigation
                    )

                    # Extract user information
                    user_id = decoded_token.get('sub')
                    email = decoded_token.get('email')
                    first_name = decoded_token.get('first_name') or decoded_token.get('given_name')
                    last_name = decoded_token.get('last_name') or decoded_token.get('family_name')

                    if user_id:
                        logger.info(f"Token verified successfully using fastapi-clerk-auth for user: {email} (ID: {user_id})")
                        return {
                            'id': user_id,
                            'email': email,
                            'first_name': first_name,
                            'last_name': last_name,
                            'clerk_user_id': user_id
                        }
                    else:
                        logger.warning("No user ID found in verified token")
                        return None
                except Exception as e:
                    # Expired tokens are expected - log at debug level to reduce noise
                    error_msg = str(e).lower()
                    if 'expired' in error_msg or 'signature has expired' in error_msg:
                        logger.debug(f"Token expired (expected): {e}")
                    else:
                        logger.warning(f"fastapi-clerk-auth verification error: {e}")
                    return None
            else:
                # Fallback to custom implementation (not secure for production)
                logger.warning("Using fallback JWT decoding without signature verification")
                try:
                    import jwt
                    # Decode the JWT without verification to get claims
                    # This is NOT secure for production - only for development
                    # Add leeway to handle clock skew
                    decoded_token = jwt.decode(
                        token,
                        options={"verify_signature": False},
                        leeway=300  # Allow 5 minutes leeway for token refresh
                    )

                    # Extract user information from the token
                    user_id = decoded_token.get('sub') or decoded_token.get('user_id')
                    email = decoded_token.get('email')
                    first_name = decoded_token.get('first_name')
                    last_name = decoded_token.get('last_name')

                    if not user_id:
                        logger.warning("No user ID found in token")
                        return None

                    logger.info(f"Token decoded successfully (fallback) for user: {email} (ID: {user_id})")
                    return {
                        'id': user_id,
                        'email': email,
                        'first_name': first_name,
                        'last_name': last_name,
                        'clerk_user_id': user_id
                    }

                except Exception as e:
                    logger.warning(f"Fallback JWT decode error: {e}")
                    return None

        except Exception as e:
            logger.error(f"Token verification error: {e}")
            return None


# Initialize middleware
clerk_auth = ClerkAuthMiddleware()


async def get_current_user(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict[str, Any]:
    """Get current authenticated user."""
    try:
        if not credentials:
            # CRITICAL: Log as ERROR since this is a security issue - authenticated endpoint accessed without credentials
            endpoint_path = f"{request.method} {request.url.path}"
            logger.error(
                f"🔒 AUTHENTICATION ERROR: No credentials provided for authenticated endpoint: {endpoint_path} "
                f"(client_ip={request.client.host if request.client else 'unknown'})"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Not authenticated",
                headers={"WWW-Authenticate": "Bearer"},
            )

        token = credentials.credentials
        user = await clerk_auth.verify_token(token)
        if not user:
            # Token verification failed - log with endpoint context for debugging
            endpoint_path = f"{request.method} {request.url.path}"
            logger.error(
                f"🔒 AUTHENTICATION ERROR: Token verification failed for endpoint: {endpoint_path} "
                f"(client_ip={request.client.host if request.client else 'unknown'})"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication failed",
                headers={"WWW-Authenticate": "Bearer"},
            )

        return user

    except HTTPException:
        raise
    except Exception as e:
        endpoint_path = f"{request.method} {request.url.path}"
        logger.error(
            f"🔒 AUTHENTICATION ERROR: Unexpected error during authentication for endpoint: {endpoint_path}: {e}",
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication failed",
            headers={"WWW-Authenticate": "Bearer"},
        )


async def get_optional_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict[str, Any]]:
    """Get current user if authenticated, otherwise return None."""
    try:
        if not credentials:
            return None

        token = credentials.credentials
        user = await clerk_auth.verify_token(token)
        return user

    except Exception as e:
        logger.warning(f"Optional authentication failed: {e}")
        return None


async def get_current_user_with_query_token(
    request: Request,
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict[str, Any]:
    """Get current authenticated user from either Authorization header or query parameter.

    This is useful for media endpoints (audio, video, images) that need to be accessed
    by HTML elements like <audio> or <img> which cannot send custom headers.

    Args:
        request: FastAPI request object
        credentials: HTTP authorization credentials from header

    Returns:
        User dictionary with authentication info

    Raises:
        HTTPException: If authentication fails
    """
    try:
        # Try to get token from Authorization header first
        token_to_verify = None
        if credentials:
            token_to_verify = credentials.credentials
        else:
            # Fall back to query parameter if no header
            query_token = request.query_params.get("token")
            if query_token:
                token_to_verify = query_token

        if not token_to_verify:
            # CRITICAL: Log as ERROR since this is a security issue
            endpoint_path = f"{request.method} {request.url.path}"
            logger.error(
                f"🔒 AUTHENTICATION ERROR: No credentials provided (neither header nor query parameter) "
                f"for authenticated endpoint: {endpoint_path} "
                f"(client_ip={request.client.host if request.client else 'unknown'})"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Not authenticated",
                headers={"WWW-Authenticate": "Bearer"},
            )

        user = await clerk_auth.verify_token(token_to_verify)
        if not user:
            # Token verification failed - log with endpoint context
            endpoint_path = f"{request.method} {request.url.path}"
            logger.error(
                f"🔒 AUTHENTICATION ERROR: Token verification failed for endpoint: {endpoint_path} "
                f"(client_ip={request.client.host if request.client else 'unknown'})"
            )
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Authentication failed",
                headers={"WWW-Authenticate": "Bearer"},
            )

        return user

    except HTTPException:
        raise
    except Exception as e:
        endpoint_path = f"{request.method} {request.url.path}"
        logger.error(
            f"🔒 AUTHENTICATION ERROR: Unexpected error during authentication for endpoint: {endpoint_path}: {e}",
            exc_info=True
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication failed",
            headers={"WWW-Authenticate": "Bearer"},
        )
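The verified path in verify_token resolves the JWKS endpoint from the token's iss claim and caches one PyJWKClient per URL. A condensed sketch of that verification flow, assuming PyJWT 2.x is installed (the caching dict here is a module-level simplification of the instance cache above):

from typing import Dict
import jwt
from jwt import PyJWKClient

_jwks_clients: Dict[str, PyJWKClient] = {}

def verify_clerk_jwt(token: str, leeway_seconds: int = 300) -> dict:
    """Resolve the issuer's JWKS, then verify signature and expiry."""
    # Unverified decode only to discover the issuer; no trust is placed in it yet.
    issuer = jwt.decode(token, options={"verify_signature": False}).get("iss", "")
    jwks_url = f"{issuer}/.well-known/jwks.json"
    client = _jwks_clients.setdefault(jwks_url, PyJWKClient(jwks_url, cache_keys=True))
    signing_key = client.get_signing_key_from_jwt(token)
    return jwt.decode(
        token,
        signing_key.key,
        algorithms=["RS256"],
        options={"verify_signature": True, "verify_exp": True},
        leeway=leeway_seconds,  # tolerate clock skew, as in verify_token above
    )

The leeway mirrors the 300-second tolerance the middleware uses for token refresh during navigation; everything else (issuer discovery, key caching, RS256-only) follows the code above.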

backend/middleware/logging_middleware.py (new file, 332 lines)
@@ -0,0 +1,332 @@
"""
Intelligent Logging Middleware for AI SEO Tools

Provides structured logging, file saving, and monitoring capabilities
for all SEO tool operations with performance tracking.
"""

import json
import asyncio
import aiofiles
from datetime import datetime
from functools import wraps
from typing import Dict, Any, Callable
from pathlib import Path
from loguru import logger
import os
import time

# Logging configuration
LOG_BASE_DIR = "logs"
os.makedirs(LOG_BASE_DIR, exist_ok=True)

# Ensure subdirectories exist
for subdir in ["seo_tools", "api_calls", "errors", "performance"]:
    os.makedirs(f"{LOG_BASE_DIR}/{subdir}", exist_ok=True)


class PerformanceLogger:
    """Performance monitoring and logging for SEO operations"""

    def __init__(self):
        self.performance_data = {}

    async def log_performance(self, operation: str, duration: float, metadata: Dict[str, Any] = None):
        """Log performance metrics for operations"""
        performance_log = {
            "operation": operation,
            "duration_seconds": duration,
            "timestamp": datetime.utcnow().isoformat(),
            "metadata": metadata or {}
        }

        await save_to_file(f"{LOG_BASE_DIR}/performance/metrics.jsonl", performance_log)

        # Log performance warnings for slow operations
        if duration > 30:  # More than 30 seconds
            logger.warning(f"Slow operation detected: {operation} took {duration:.2f} seconds")
        elif duration > 10:  # More than 10 seconds
            logger.info(f"Operation {operation} took {duration:.2f} seconds")


performance_logger = PerformanceLogger()


async def save_to_file(filepath: str, data: Dict[str, Any]) -> None:
    """
    Asynchronously save structured data to a JSONL file

    Args:
        filepath: Path to the log file
        data: Dictionary data to save
    """
    try:
        # Ensure directory exists
        Path(filepath).parent.mkdir(parents=True, exist_ok=True)

        # Convert data to JSON string
        json_line = json.dumps(data, default=str) + "\n"

        # Write asynchronously
        async with aiofiles.open(filepath, "a", encoding="utf-8") as file:
            await file.write(json_line)

    except Exception as e:
        logger.error(f"Failed to save log to {filepath}: {e}")
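save_to_file appends one JSON document per line (JSONL), which is exactly the shape LogAnalyzer later re-reads line by line. A minimal round-trip sketch, assuming aiofiles is installed (the file name demo.jsonl is illustrative):

import asyncio
import json
import aiofiles

async def main():
    # Append two records, one JSON document per line.
    async with aiofiles.open("demo.jsonl", "a", encoding="utf-8") as f:
        for record in ({"op": "audit", "ms": 12}, {"op": "crawl", "ms": 340}):
            await f.write(json.dumps(record) + "\n")

    # Read them back the same way the analyzers below do.
    async with aiofiles.open("demo.jsonl", "r", encoding="utf-8") as f:
        async for line in f:
            print(json.loads(line.strip()))

asyncio.run(main())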
def log_api_call(func: Callable) -> Callable:
    """
    Decorator for logging API calls with performance tracking

    Automatically logs request/response data, timing, and errors
    for SEO tool endpoints.
    """
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = time.time()
        operation_name = func.__name__

        # Extract request data
        request_data = {}
        for arg in args:
            if hasattr(arg, 'dict'):  # Pydantic model
                request_data.update(arg.dict())

        # Log API call start
        call_log = {
            "operation": operation_name,
            "timestamp": datetime.utcnow().isoformat(),
            "request_data": request_data,
            "status": "started"
        }

        logger.info(f"API Call Started: {operation_name}")

        try:
            # Execute the function
            result = await func(*args, **kwargs)

            execution_time = time.time() - start_time

            # Log successful completion
            call_log.update({
                "status": "completed",
                "execution_time": execution_time,
                "success": getattr(result, 'success', True),
                "completion_timestamp": datetime.utcnow().isoformat()
            })

            await save_to_file(f"{LOG_BASE_DIR}/api_calls/successful.jsonl", call_log)
            await performance_logger.log_performance(operation_name, execution_time, request_data)

            logger.info(f"API Call Completed: {operation_name} in {execution_time:.2f}s")

            return result

        except Exception as e:
            execution_time = time.time() - start_time

            # Log error
            error_log = call_log.copy()
            error_log.update({
                "status": "failed",
                "execution_time": execution_time,
                "error_type": type(e).__name__,
                "error_message": str(e),
                "completion_timestamp": datetime.utcnow().isoformat()
            })

            await save_to_file(f"{LOG_BASE_DIR}/api_calls/failed.jsonl", error_log)

            logger.error(f"API Call Failed: {operation_name} after {execution_time:.2f}s - {e}")

            # Re-raise the exception
            raise

    return wrapper


class SEOToolsLogger:
    """Centralized logger for SEO tools with intelligent categorization"""

    @staticmethod
    async def log_tool_usage(tool_name: str, input_data: Dict[str, Any],
                             output_data: Dict[str, Any], success: bool = True):
        """Log SEO tool usage with input/output tracking"""
        usage_log = {
            "tool": tool_name,
            "timestamp": datetime.utcnow().isoformat(),
            "input_data": input_data,
            "output_data": output_data,
            "success": success,
            "input_size": len(str(input_data)),
            "output_size": len(str(output_data))
        }

        await save_to_file(f"{LOG_BASE_DIR}/seo_tools/usage.jsonl", usage_log)

    @staticmethod
    async def log_ai_analysis(tool_name: str, prompt: str, response: str,
                              model_used: str, tokens_used: int = None):
        """Log AI analysis operations with token tracking"""
        ai_log = {
            "tool": tool_name,
            "timestamp": datetime.utcnow().isoformat(),
            "model": model_used,
            "prompt_length": len(prompt),
            "response_length": len(response),
            "tokens_used": tokens_used,
            "prompt_preview": prompt[:200] + "..." if len(prompt) > 200 else prompt,
            "response_preview": response[:200] + "..." if len(response) > 200 else response
        }

        await save_to_file(f"{LOG_BASE_DIR}/seo_tools/ai_analysis.jsonl", ai_log)

    @staticmethod
    async def log_external_api_call(api_name: str, endpoint: str, response_code: int,
                                    response_time: float, request_data: Dict[str, Any] = None):
        """Log external API calls (PageSpeed, etc.)"""
        api_log = {
            "api": api_name,
            "endpoint": endpoint,
            "response_code": response_code,
            "response_time": response_time,
            "timestamp": datetime.utcnow().isoformat(),
            "request_data": request_data or {},
            "success": 200 <= response_code < 300
        }

        await save_to_file(f"{LOG_BASE_DIR}/seo_tools/external_apis.jsonl", api_log)

    @staticmethod
    async def log_crawling_operation(url: str, pages_crawled: int, errors_found: int,
                                     crawl_depth: int, duration: float):
        """Log web crawling operations"""
        crawl_log = {
            "url": url,
            "pages_crawled": pages_crawled,
            "errors_found": errors_found,
            "crawl_depth": crawl_depth,
            "duration": duration,
            "timestamp": datetime.utcnow().isoformat(),
            "pages_per_second": pages_crawled / duration if duration > 0 else 0
        }

        await save_to_file(f"{LOG_BASE_DIR}/seo_tools/crawling.jsonl", crawl_log)


class LogAnalyzer:
    """Analyze logs to provide insights and monitoring"""

    @staticmethod
    async def get_performance_summary(hours: int = 24) -> Dict[str, Any]:
        """Get performance summary for the last N hours"""
        try:
            performance_file = f"{LOG_BASE_DIR}/performance/metrics.jsonl"
            if not os.path.exists(performance_file):
                return {"error": "No performance data available"}

            # Read recent performance data
            cutoff_time = datetime.utcnow().timestamp() - (hours * 3600)
            operations = []

            async with aiofiles.open(performance_file, "r") as file:
                async for line in file:
                    try:
                        data = json.loads(line.strip())
                        log_time = datetime.fromisoformat(data["timestamp"]).timestamp()
                        if log_time >= cutoff_time:
                            operations.append(data)
                    except (json.JSONDecodeError, KeyError):
                        continue

            if not operations:
                return {"message": f"No operations in the last {hours} hours"}

            # Calculate statistics
            durations = [op["duration_seconds"] for op in operations]
            operation_counts = {}
            for op in operations:
                op_name = op["operation"]
                operation_counts[op_name] = operation_counts.get(op_name, 0) + 1

            return {
                "total_operations": len(operations),
                "average_duration": sum(durations) / len(durations),
                "max_duration": max(durations),
                "min_duration": min(durations),
                "operations_by_type": operation_counts,
                "time_period_hours": hours
            }

        except Exception as e:
            logger.error(f"Error analyzing performance logs: {e}")
            return {"error": str(e)}

    @staticmethod
    async def get_error_summary(hours: int = 24) -> Dict[str, Any]:
        """Get error summary for the last N hours"""
        try:
            error_file = f"{LOG_BASE_DIR}/seo_tools/errors.jsonl"
            if not os.path.exists(error_file):
                return {"message": "No errors recorded"}

            cutoff_time = datetime.utcnow().timestamp() - (hours * 3600)
            errors = []

            async with aiofiles.open(error_file, "r") as file:
                async for line in file:
                    try:
                        data = json.loads(line.strip())
                        log_time = datetime.fromisoformat(data["timestamp"]).timestamp()
                        if log_time >= cutoff_time:
                            errors.append(data)
                    except (json.JSONDecodeError, KeyError):
                        continue

            if not errors:
                return {"message": f"No errors in the last {hours} hours"}

            # Analyze errors
            error_types = {}
            functions_with_errors = {}

            for error in errors:
                error_type = error.get("error_type", "Unknown")
                function = error.get("function", "Unknown")

                error_types[error_type] = error_types.get(error_type, 0) + 1
                functions_with_errors[function] = functions_with_errors.get(function, 0) + 1

            return {
                "total_errors": len(errors),
                "error_types": error_types,
                "functions_with_errors": functions_with_errors,
                "recent_errors": errors[-5:],  # Last 5 errors
                "time_period_hours": hours
            }

        except Exception as e:
            logger.error(f"Error analyzing error logs: {e}")
            return {"error": str(e)}


# Initialize global logger instances
seo_logger = SEOToolsLogger()
log_analyzer = LogAnalyzer()

# Configure loguru for structured logging
# Commented out to prevent conflicts with main logging configuration
# logger.add(
#     f"{LOG_BASE_DIR}/application.log",
#     rotation="1 day",
#     retention="30 days",
#     level="INFO",
#     format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
#     serialize=True
# )

# logger.add(
#     f"{LOG_BASE_DIR}/errors.log",
#     rotation="1 day",
#     retention="30 days",
#     level="ERROR",
#     format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
#     serialize=True
# )

logger.info("Logging middleware initialized successfully")
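Because log_api_call wraps coroutine functions, it can decorate any async endpoint or service call. A hypothetical usage sketch, assuming the backend package root is on sys.path so this module imports as middleware.logging_middleware:

import asyncio
from middleware.logging_middleware import log_api_call, log_analyzer

@log_api_call
async def analyze_page(url: str) -> dict:
    # Real analysis would happen here; timing, success, and failures
    # are written to logs/api_calls/*.jsonl by the decorator.
    return {"url": url, "score": 87}

async def main():
    result = await analyze_page("https://example.com")
    # Summarize everything recorded in the last 24 hours.
    summary = await log_analyzer.get_performance_summary(hours=24)
    print(result)
    print(summary)

asyncio.run(main())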

backend/middleware/stability_middleware.py (new file, 702 lines)
@@ -0,0 +1,702 @@
"""Middleware for Stability AI operations."""

import time
import asyncio
import os
from typing import Dict, Any, Optional, List
from collections import defaultdict, deque
from fastapi import Request, HTTPException
from fastapi.responses import JSONResponse
import json
from loguru import logger
from datetime import datetime, timedelta


class RateLimitMiddleware:
    """Rate limiting middleware for Stability AI API calls."""

    def __init__(self, requests_per_window: int = 150, window_seconds: int = 10):
        """Initialize rate limiter.

        Args:
            requests_per_window: Maximum requests per time window
            window_seconds: Time window in seconds
        """
        self.requests_per_window = requests_per_window
        self.window_seconds = window_seconds
        self.request_times: Dict[str, deque] = defaultdict(lambda: deque())
        self.blocked_until: Dict[str, float] = {}

    async def __call__(self, request: Request, call_next):
        """Process request with rate limiting.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response
        """
        # Skip rate limiting for non-Stability endpoints
        if not request.url.path.startswith("/api/stability"):
            return await call_next(request)

        # Get client identifier (IP address or API key)
        client_id = self._get_client_id(request)
        current_time = time.time()

        # Check if client is currently blocked
        if client_id in self.blocked_until:
            if current_time < self.blocked_until[client_id]:
                remaining = int(self.blocked_until[client_id] - current_time)
                return JSONResponse(
                    status_code=429,
                    content={
                        "error": "Rate limit exceeded",
                        "retry_after": remaining,
                        "message": f"You have been timed out for {remaining} seconds"
                    }
                )
            else:
                # Timeout expired, remove block
                del self.blocked_until[client_id]

        # Clean old requests outside the window
        request_times = self.request_times[client_id]
        while request_times and request_times[0] < current_time - self.window_seconds:
            request_times.popleft()

        # Check rate limit
        if len(request_times) >= self.requests_per_window:
            # Rate limit exceeded, block for 60 seconds
            self.blocked_until[client_id] = current_time + 60
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Rate limit exceeded",
                    "retry_after": 60,
                    "message": "You have exceeded the rate limit of 150 requests within a 10 second period"
                }
            )

        # Add current request time
        request_times.append(current_time)

        # Process request
        response = await call_next(request)

        # Add rate limit headers
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_window)
        response.headers["X-RateLimit-Remaining"] = str(self.requests_per_window - len(request_times))
        response.headers["X-RateLimit-Reset"] = str(int(current_time + self.window_seconds))

        return response

    def _get_client_id(self, request: Request) -> str:
        """Get client identifier for rate limiting.

        Args:
            request: FastAPI request

        Returns:
            Client identifier
        """
        # Try to get API key from authorization header
        auth_header = request.headers.get("authorization", "")
        if auth_header.startswith("Bearer "):
            return auth_header[7:15]  # Use first 8 chars of API key

        # Fall back to IP address
        return request.client.host if request.client else "unknown"
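The limiter above is a classic sliding-window counter: timestamps go into a per-client deque, entries older than the window are evicted, and the deque length is compared against the quota. The same core logic, reduced to a self-contained sketch with a small quota so the behavior is visible:

import time
from collections import defaultdict, deque

class SlidingWindowLimiter:
    def __init__(self, limit: int = 5, window: float = 1.0):
        self.limit, self.window = limit, window
        self.hits = defaultdict(deque)  # client_id -> recent request timestamps

    def allow(self, client_id: str) -> bool:
        now = time.time()
        q = self.hits[client_id]
        while q and q[0] < now - self.window:  # evict timestamps outside window
            q.popleft()
        if len(q) >= self.limit:
            return False
        q.append(now)
        return True

limiter = SlidingWindowLimiter(limit=3, window=1.0)
print([limiter.allow("c1") for _ in range(5)])  # [True, True, True, False, False]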
class MonitoringMiddleware:
    """Monitoring middleware for Stability AI operations."""

    def __init__(self):
        """Initialize monitoring middleware."""
        self.request_stats = defaultdict(lambda: {
            "count": 0,
            "total_time": 0,
            "errors": 0,
            "last_request": None
        })
        self.active_requests = {}

    async def __call__(self, request: Request, call_next):
        """Process request with monitoring.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response
        """
        # Skip monitoring for non-Stability endpoints
        if not request.url.path.startswith("/api/stability"):
            return await call_next(request)

        start_time = time.time()
        request_id = f"{int(start_time * 1000)}_{id(request)}"

        # Extract operation info
        operation = self._extract_operation(request.url.path)

        # Log request start
        self.active_requests[request_id] = {
            "operation": operation,
            "start_time": start_time,
            "path": request.url.path,
            "method": request.method
        }

        try:
            # Process request
            response = await call_next(request)

            # Calculate processing time
            processing_time = time.time() - start_time

            # Update stats
            stats = self.request_stats[operation]
            stats["count"] += 1
            stats["total_time"] += processing_time
            stats["last_request"] = datetime.utcnow().isoformat()

            # Add monitoring headers
            response.headers["X-Processing-Time"] = str(round(processing_time, 3))
            response.headers["X-Operation"] = operation
            response.headers["X-Request-ID"] = request_id

            # Log successful request
            logger.info(f"Stability AI request completed: {operation} in {processing_time:.3f}s")

            return response

        except Exception as e:
            # Update error stats
            self.request_stats[operation]["errors"] += 1

            # Log error
            logger.error(f"Stability AI request failed: {operation} - {str(e)}")

            raise

        finally:
            # Clean up active request
            self.active_requests.pop(request_id, None)

    def _extract_operation(self, path: str) -> str:
        """Extract operation name from request path.

        Args:
            path: Request path

        Returns:
            Operation name
        """
        path_parts = path.split("/")

        if len(path_parts) >= 4:
            if "generate" in path_parts:
                return f"generate_{path_parts[-1]}"
            elif "edit" in path_parts:
                return f"edit_{path_parts[-1]}"
            elif "upscale" in path_parts:
                return f"upscale_{path_parts[-1]}"
            elif "control" in path_parts:
                return f"control_{path_parts[-1]}"
            elif "3d" in path_parts:
                return f"3d_{path_parts[-1]}"
            elif "audio" in path_parts:
                return f"audio_{path_parts[-1]}"

        return "unknown"

    def get_stats(self) -> Dict[str, Any]:
        """Get monitoring statistics.

        Returns:
            Monitoring statistics
        """
        stats = {}

        for operation, data in self.request_stats.items():
            avg_time = data["total_time"] / data["count"] if data["count"] > 0 else 0
            error_rate = (data["errors"] / data["count"]) * 100 if data["count"] > 0 else 0

            stats[operation] = {
                "total_requests": data["count"],
                "total_errors": data["errors"],
                "error_rate_percent": round(error_rate, 2),
                "average_processing_time": round(avg_time, 3),
                "last_request": data["last_request"]
            }

        stats["active_requests"] = len(self.active_requests)
        stats["total_operations"] = len(self.request_stats)

        return stats


class ContentModerationMiddleware:
    """Content moderation middleware for Stability AI requests."""

    def __init__(self):
        """Initialize content moderation middleware."""
        self.blocked_terms = self._load_blocked_terms()
        self.warning_terms = self._load_warning_terms()

    async def __call__(self, request: Request, call_next):
        """Process request with content moderation.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response
        """
        # Skip moderation for non-generation endpoints
        if not self._should_moderate(request.url.path):
            return await call_next(request)

        # Extract and check prompt content
        prompt = await self._extract_prompt(request)

        if prompt:
            moderation_result = self._moderate_content(prompt)

            if moderation_result["blocked"]:
                return JSONResponse(
                    status_code=403,
                    content={
                        "error": "Content moderation",
                        "message": "Your request was flagged by our content moderation system",
                        "issues": moderation_result["issues"]
                    }
                )

            if moderation_result["warnings"]:
                logger.warning(f"Content warnings for prompt: {moderation_result['warnings']}")

        # Process request
        response = await call_next(request)

        # Add content moderation headers
        if prompt:
            response.headers["X-Content-Moderated"] = "true"

        return response

    def _should_moderate(self, path: str) -> bool:
        """Check if path should be moderated.

        Args:
            path: Request path

        Returns:
            True if should be moderated
        """
        moderated_paths = ["/generate/", "/edit/", "/control/", "/audio/"]
        return any(mod_path in path for mod_path in moderated_paths)

    async def _extract_prompt(self, request: Request) -> Optional[str]:
        """Extract prompt from request.

        Args:
            request: FastAPI request

        Returns:
            Extracted prompt or None
        """
        try:
            if request.method == "POST":
                # For form data, we'd need to parse the form
                # This is a simplified version
                body = await request.body()
                if b"prompt=" in body:
                    # Extract prompt from form data (simplified)
                    body_str = body.decode('utf-8', errors='ignore')
                    if "prompt=" in body_str:
                        start = body_str.find("prompt=") + 7
                        end = body_str.find("&", start)
                        if end == -1:
                            end = len(body_str)
                        return body_str[start:end]
        except Exception:
            pass

        return None

    def _moderate_content(self, prompt: str) -> Dict[str, Any]:
        """Moderate content for policy violations.

        Args:
            prompt: Text prompt to moderate

        Returns:
            Moderation result
        """
        issues = []
        warnings = []

        prompt_lower = prompt.lower()

        # Check for blocked terms
        for term in self.blocked_terms:
            if term in prompt_lower:
                issues.append(f"Contains blocked term: {term}")

        # Check for warning terms
        for term in self.warning_terms:
            if term in prompt_lower:
                warnings.append(f"Contains flagged term: {term}")

        return {
            "blocked": len(issues) > 0,
            "issues": issues,
            "warnings": warnings
        }

    def _load_blocked_terms(self) -> List[str]:
        """Load blocked terms from configuration.

        Returns:
            List of blocked terms
        """
        # In production, this would load from a configuration file or database
        return [
            # Add actual blocked terms here
        ]

    def _load_warning_terms(self) -> List[str]:
        """Load warning terms from configuration.

        Returns:
            List of warning terms
        """
        # In production, this would load from a configuration file or database
        return [
            # Add actual warning terms here
        ]


class CachingMiddleware:
    """Caching middleware for Stability AI responses."""

    def __init__(self, cache_duration: int = 3600):
        """Initialize caching middleware.

        Args:
            cache_duration: Cache duration in seconds
        """
        self.cache_duration = cache_duration
        self.cache: Dict[str, Dict[str, Any]] = {}
        self.cache_times: Dict[str, float] = {}

    async def __call__(self, request: Request, call_next):
        """Process request with caching.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response (cached or fresh)
        """
        # Skip caching for non-cacheable endpoints
        if not self._should_cache(request):
            return await call_next(request)

        # Generate cache key
        cache_key = await self._generate_cache_key(request)

        # Check cache
        if self._is_cached(cache_key):
            logger.info(f"Returning cached result for {cache_key}")
            cached_data = self.cache[cache_key]

            return JSONResponse(
                content=cached_data["content"],
                headers={**cached_data["headers"], "X-Cache-Hit": "true"}
            )

        # Process request
        response = await call_next(request)

        # Cache successful responses
        if response.status_code == 200 and self._should_cache_response(response):
            await self._cache_response(cache_key, response)

        return response

    def _should_cache(self, request: Request) -> bool:
        """Check if request should be cached.

        Args:
            request: FastAPI request

        Returns:
            True if should be cached
        """
        # Only cache GET requests and certain POST operations
        if request.method == "GET":
            return True

        # Cache deterministic operations (those with seeds)
        cacheable_paths = ["/models/info", "/supported-formats", "/health"]
        return any(path in request.url.path for path in cacheable_paths)

    def _should_cache_response(self, response) -> bool:
        """Check if response should be cached.

        Args:
            response: FastAPI response

        Returns:
            True if should be cached
        """
        # Don't cache large binary responses
        content_length = response.headers.get("content-length")
        if content_length and int(content_length) > 1024 * 1024:  # 1MB
            return False

        return True

    async def _generate_cache_key(self, request: Request) -> str:
        """Generate cache key for request.

        Args:
            request: FastAPI request

        Returns:
            Cache key
        """
        import hashlib

        key_parts = [
            request.method,
            request.url.path,
            str(sorted(request.query_params.items()))
        ]

        # For POST requests, include body hash
        if request.method == "POST":
            body = await request.body()
            if body:
                key_parts.append(hashlib.md5(body).hexdigest())

        key_string = "|".join(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()

    def _is_cached(self, cache_key: str) -> bool:
        """Check if key is cached and not expired.

        Args:
            cache_key: Cache key

        Returns:
            True if cached and valid
        """
        if cache_key not in self.cache:
            return False

        cache_time = self.cache_times.get(cache_key, 0)
        return time.time() - cache_time < self.cache_duration

    async def _cache_response(self, cache_key: str, response) -> None:
        """Cache response data.

        Args:
            cache_key: Cache key
            response: Response to cache
        """
        try:
            # Only cache JSON responses for now
            if response.headers.get("content-type", "").startswith("application/json"):
                self.cache[cache_key] = {
                    "content": json.loads(response.body),
                    "headers": dict(response.headers)
                }
                self.cache_times[cache_key] = time.time()
        except Exception:
            # Ignore cache errors
            pass

    def clear_cache(self) -> None:
        """Clear all cached data."""
        self.cache.clear()
        self.cache_times.clear()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Cache statistics
        """
        current_time = time.time()
        expired_keys = [
            key for key, cache_time in self.cache_times.items()
            if current_time - cache_time > self.cache_duration
        ]

        return {
            "total_entries": len(self.cache),
            "expired_entries": len(expired_keys),
            "cache_hit_rate": "N/A",  # Would need request tracking
            "memory_usage": sum(len(str(data)) for data in self.cache.values())
        }
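The cache key above is a SHA-256 digest over method, path, sorted query parameters, and (for POST) an MD5 of the raw body, so equivalent requests collide onto one cache entry regardless of parameter order. The same derivation, reproduced outside the middleware:

import hashlib

def cache_key(method: str, path: str, params: dict, body: bytes = b"") -> str:
    parts = [method, path, str(sorted(params.items()))]
    if method == "POST" and body:
        parts.append(hashlib.md5(body).hexdigest())  # body fingerprint, as above
    return hashlib.sha256("|".join(parts).encode()).hexdigest()

# Parameter order does not change the key:
assert cache_key("GET", "/models/info", {"a": 1, "b": 2}) == \
       cache_key("GET", "/models/info", {"b": 2, "a": 1})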
class RequestLoggingMiddleware:
    """Logging middleware for Stability AI requests."""

    def __init__(self):
        """Initialize logging middleware."""
        self.request_log = []
        self.max_log_entries = 1000

    async def __call__(self, request: Request, call_next):
        """Process request with logging.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response
        """
        # Skip logging for non-Stability endpoints
        if not request.url.path.startswith("/api/stability"):
            return await call_next(request)

        start_time = time.time()
        request_id = f"{int(start_time * 1000)}_{id(request)}"

        # Log request details
        log_entry = {
            "request_id": request_id,
            "timestamp": datetime.utcnow().isoformat(),
            "method": request.method,
            "path": request.url.path,
            "query_params": dict(request.query_params),
            "client_ip": request.client.host if request.client else "unknown",
            "user_agent": request.headers.get("user-agent", "unknown")
        }

        try:
            # Process request
            response = await call_next(request)

            # Calculate processing time
            processing_time = time.time() - start_time

            # Update log entry
            log_entry.update({
                "status_code": response.status_code,
                "processing_time": round(processing_time, 3),
                "response_size": len(response.body) if hasattr(response, 'body') else 0,
                "success": True
            })

            return response

        except Exception as e:
            # Log error
            log_entry.update({
                "error": str(e),
                "success": False,
                "processing_time": round(time.time() - start_time, 3)
            })
            raise

        finally:
            # Add to log
            self._add_log_entry(log_entry)

    def _add_log_entry(self, entry: Dict[str, Any]) -> None:
        """Add entry to request log.

        Args:
            entry: Log entry
        """
        self.request_log.append(entry)

        # Keep only recent entries
        if len(self.request_log) > self.max_log_entries:
            self.request_log = self.request_log[-self.max_log_entries:]

    def get_recent_logs(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get recent log entries.

        Args:
            limit: Maximum number of entries to return

        Returns:
            Recent log entries
        """
        return self.request_log[-limit:]

    def get_log_summary(self) -> Dict[str, Any]:
        """Get summary of logged requests.

        Returns:
            Log summary statistics
        """
        if not self.request_log:
            return {"total_requests": 0}

        total_requests = len(self.request_log)
        successful_requests = sum(1 for entry in self.request_log if entry.get("success", False))

        # Calculate average processing time
        processing_times = [
            entry["processing_time"] for entry in self.request_log
            if "processing_time" in entry
        ]
        avg_processing_time = sum(processing_times) / len(processing_times) if processing_times else 0

        # Get operation breakdown
        operations = defaultdict(int)
        for entry in self.request_log:
            operation = entry.get("path", "unknown").split("/")[-1]
            operations[operation] += 1

        return {
            "total_requests": total_requests,
            "successful_requests": successful_requests,
            "error_rate_percent": round((1 - successful_requests / total_requests) * 100, 2),
            "average_processing_time": round(avg_processing_time, 3),
            "operations_breakdown": dict(operations),
            "time_range": {
                "start": self.request_log[0]["timestamp"],
                "end": self.request_log[-1]["timestamp"]
            }
        }


# Global middleware instances
rate_limiter = RateLimitMiddleware()
monitoring = MonitoringMiddleware()
caching = CachingMiddleware()
request_logging = RequestLoggingMiddleware()


def get_middleware_stats() -> Dict[str, Any]:
    """Get statistics from all middleware components.

    Returns:
        Combined middleware statistics
    """
    return {
        "rate_limiting": {
            "active_blocks": len(rate_limiter.blocked_until),
            "requests_per_window": rate_limiter.requests_per_window,
            "window_seconds": rate_limiter.window_seconds
        },
        "monitoring": monitoring.get_stats(),
        "caching": caching.get_cache_stats(),
        "logging": request_logging.get_log_summary()
    }
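This commit does not show how app.py wires these singletons in, so the registration below is illustrative only. Each instance is an async callable taking (request, call_next), which is the shape app.middleware("http") expects; note that with that decorator the last-registered function runs first on the way in:

from fastapi import FastAPI
from middleware.stability_middleware import (
    rate_limiter, monitoring, caching, request_logging, get_middleware_stats,
)

app = FastAPI()

# Illustrative order, outermost last: logging wraps caching wraps monitoring
# wraps rate limiting.
for mw in (rate_limiter, monitoring, caching, request_logging):
    app.middleware("http")(mw)

@app.get("/api/stability/middleware-stats")
async def middleware_stats():
    # Combined counters from all four components, per get_middleware_stats above.
    return get_middleware_stats()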