Base code

Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions

View File

@@ -0,0 +1 @@
# Makes the middleware directory a Python package

View File

@@ -0,0 +1,123 @@
"""
API Key Injection Middleware
Temporarily injects user-specific API keys into os.environ for the duration of the request.
This allows existing code that uses os.getenv('GEMINI_API_KEY') to work without modification.
IMPORTANT: This is a compatibility layer. For new code, use UserAPIKeyContext directly.
Because os.environ is process-global, concurrently handled requests can briefly observe
each other's injected keys.
"""
import os
from fastapi import Request
from loguru import logger
from typing import Callable
from services.user_api_key_context import user_api_keys
class APIKeyInjectionMiddleware:
"""
Middleware that injects user-specific API keys into environment variables
for the duration of each request.
"""
    # No per-instance state: everything request-scoped lives in __call__'s
    # locals, so a single shared instance is safe to reuse across requests.
async def __call__(self, request: Request, call_next: Callable):
"""
Inject user-specific API keys before processing request,
restore original values after request completes.
"""
# Try to extract user_id from Authorization header
user_id = None
auth_header = request.headers.get('Authorization')
if auth_header and auth_header.startswith('Bearer '):
try:
from middleware.auth_middleware import clerk_auth
token = auth_header.replace('Bearer ', '')
user = await clerk_auth.verify_token(token)
if user:
# Try different possible keys for user_id
user_id = user.get('user_id') or user.get('clerk_user_id') or user.get('id')
if user_id:
logger.info(f"[API Key Injection] Extracted user_id: {user_id}")
# Store user_id in request.state for monitoring middleware
request.state.user_id = user_id
else:
logger.warning(f"[API Key Injection] User object missing ID: {user}")
else:
# Token verification failed (likely expired) - log at debug level to reduce noise
logger.debug("[API Key Injection] Token verification failed (likely expired token)")
except Exception as e:
logger.error(f"[API Key Injection] Could not extract user from token: {e}")
if not user_id:
# No authenticated user, proceed without injection
return await call_next(request)
# Check if we're in production mode
is_production = os.getenv('DEPLOY_ENV', 'local') == 'production'
if not is_production:
# Local mode - keys already in .env, no injection needed
return await call_next(request)
# Get user-specific API keys from database
with user_api_keys(user_id) as user_keys:
if not user_keys:
logger.warning(f"No API keys found for user {user_id}")
return await call_next(request)
# Save original environment values
original_keys = {}
keys_to_inject = {
'gemini': 'GEMINI_API_KEY',
'exa': 'EXA_API_KEY',
'copilotkit': 'COPILOTKIT_API_KEY',
'openai': 'OPENAI_API_KEY',
'anthropic': 'ANTHROPIC_API_KEY',
'tavily': 'TAVILY_API_KEY',
'serper': 'SERPER_API_KEY',
'firecrawl': 'FIRECRAWL_API_KEY',
}
# Inject user-specific keys into environment
for provider, env_var in keys_to_inject.items():
if provider in user_keys and user_keys[provider]:
# Save original value (if any)
original_keys[env_var] = os.environ.get(env_var)
# Inject user-specific key
os.environ[env_var] = user_keys[provider]
logger.debug(f"[PRODUCTION] Injected {env_var} for user {user_id}")
try:
# Process request with user-specific keys in environment
response = await call_next(request)
return response
finally:
# CRITICAL: Restore original environment values
for env_var, original_value in original_keys.items():
if original_value is None:
# Key didn't exist before, remove it
os.environ.pop(env_var, None)
else:
# Restore original value
os.environ[env_var] = original_value
logger.debug(f"[PRODUCTION] Cleaned up environment for user {user_id}")
# Shared instance; the class keeps no per-request state (see note above)
_middleware = APIKeyInjectionMiddleware()
async def api_key_injection_middleware(request: Request, call_next: Callable):
    """
    Middleware function that injects user-specific API keys into environment.
    Usage in app.py:
        app.middleware("http")(api_key_injection_middleware)
    """
    return await _middleware(request, call_next)
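A minimal wiring sketch for the usage the docstring above describes. The module path middleware.api_key_injection is an assumption (the file path is not shown in this commit); only the app.middleware("http") registration line comes from the docstring itself.

from fastapi import FastAPI
from middleware.api_key_injection import api_key_injection_middleware  # assumed module path

app = FastAPI()
# Register the injector so user keys are in os.environ before handlers run
app.middleware("http")(api_key_injection_middleware)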

View File

@@ -0,0 +1,348 @@
"""Authentication middleware for ALwrity backend."""
import os
from typing import Optional, Dict, Any
from fastapi import HTTPException, Depends, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from loguru import logger
from dotenv import load_dotenv
# Try to import fastapi-clerk-auth, fallback to custom implementation
try:
from fastapi_clerk_auth import ClerkHTTPBearer, ClerkConfig
CLERK_AUTH_AVAILABLE = True
except ImportError:
CLERK_AUTH_AVAILABLE = False
logger.warning("fastapi-clerk-auth not available, using custom implementation")
# Load environment variables from the correct path
# Get the backend directory path (parent of middleware directory)
_backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
_env_path = os.path.join(_backend_dir, ".env")
load_dotenv(_env_path, override=False) # Don't override if already loaded
# Initialize security scheme
security = HTTPBearer(auto_error=False)
class ClerkAuthMiddleware:
"""Clerk authentication middleware using fastapi-clerk-auth or custom implementation."""
def __init__(self):
"""Initialize Clerk authentication middleware."""
self.clerk_secret_key = os.getenv('CLERK_SECRET_KEY', '').strip()
# Check for both backend and frontend naming conventions
publishable_key = (
os.getenv('CLERK_PUBLISHABLE_KEY') or
os.getenv('REACT_APP_CLERK_PUBLISHABLE_KEY', '')
)
self.clerk_publishable_key = publishable_key.strip() if publishable_key else None
self.disable_auth = os.getenv('DISABLE_AUTH', 'false').lower() == 'true'
# Cache for PyJWKClient to avoid repeated JWKS fetches
self._jwks_client_cache = {}
self._jwks_url_cache = None
if not self.clerk_secret_key and not self.disable_auth:
logger.warning("CLERK_SECRET_KEY not found, authentication may fail")
# Initialize fastapi-clerk-auth if available
if CLERK_AUTH_AVAILABLE and not self.disable_auth:
try:
if self.clerk_secret_key and self.clerk_publishable_key:
                    # Derive an initial JWKS URL from the publishable key. This assumes
                    # the pk_test_/pk_live_ prefix is followed by the instance slug;
                    # verify_token below re-derives the JWKS URL from the token's issuer
                    # claim at request time, so this is only a startup default.
parts = self.clerk_publishable_key.replace('pk_test_', '').replace('pk_live_', '').split('.')
if len(parts) >= 1:
# Extract the domain from publishable key or use default
# Clerk URLs are typically: https://<instance>.clerk.accounts.dev
instance = parts[0]
jwks_url = f"https://{instance}.clerk.accounts.dev/.well-known/jwks.json"
# Create Clerk configuration with JWKS URL
clerk_config = ClerkConfig(
secret_key=self.clerk_secret_key,
jwks_url=jwks_url
)
# Create ClerkHTTPBearer instance for dependency injection
self.clerk_bearer = ClerkHTTPBearer(clerk_config)
logger.info(f"fastapi-clerk-auth initialized successfully with JWKS URL: {jwks_url}")
else:
logger.warning("Could not extract instance from publishable key")
self.clerk_bearer = None
else:
logger.warning("CLERK_SECRET_KEY or CLERK_PUBLISHABLE_KEY not found")
self.clerk_bearer = None
except Exception as e:
logger.error(f"Failed to initialize fastapi-clerk-auth: {e}")
self.clerk_bearer = None
else:
self.clerk_bearer = None
logger.info(f"ClerkAuthMiddleware initialized - Auth disabled: {self.disable_auth}, fastapi-clerk-auth: {CLERK_AUTH_AVAILABLE}")
async def verify_token(self, token: str) -> Optional[Dict[str, Any]]:
"""Verify Clerk JWT using fastapi-clerk-auth or custom implementation."""
try:
if self.disable_auth:
logger.info("Authentication disabled, returning mock user")
return {
'id': 'mock_user_id',
'email': 'mock@example.com',
'first_name': 'Mock',
'last_name': 'User',
'clerk_user_id': 'mock_clerk_user_id'
}
if not self.clerk_secret_key:
logger.error("CLERK_SECRET_KEY not configured")
return None
# Use fastapi-clerk-auth if available
if self.clerk_bearer:
try:
# Decode and verify the JWT token
import jwt
from jwt import PyJWKClient
                    # Decode the claims without verifying the signature to read the
                    # issuer, which determines the JWKS URL for this token
unverified_claims = jwt.decode(token, options={"verify_signature": False})
issuer = unverified_claims.get('iss', '')
# Construct JWKS URL from issuer
jwks_url = f"{issuer}/.well-known/jwks.json"
# Use cached PyJWKClient to avoid repeated JWKS fetches
if jwks_url not in self._jwks_client_cache:
logger.info(f"Creating new PyJWKClient for {jwks_url} with caching enabled")
# Create client with caching enabled (cache_keys=True keeps keys in memory)
self._jwks_client_cache[jwks_url] = PyJWKClient(
jwks_url,
cache_keys=True,
max_cached_keys=16
)
jwks_client = self._jwks_client_cache[jwks_url]
signing_key = jwks_client.get_signing_key_from_jwt(token)
# Verify and decode the token with clock skew tolerance
# Add 300 seconds (5 minutes) leeway to handle clock skew and token refresh delays
decoded_token = jwt.decode(
token,
signing_key.key,
algorithms=["RS256"],
options={"verify_signature": True, "verify_exp": True},
leeway=300 # Allow 5 minutes leeway for token refresh during navigation
)
# Extract user information
user_id = decoded_token.get('sub')
email = decoded_token.get('email')
first_name = decoded_token.get('first_name') or decoded_token.get('given_name')
last_name = decoded_token.get('last_name') or decoded_token.get('family_name')
if user_id:
logger.info(f"Token verified successfully using fastapi-clerk-auth for user: {email} (ID: {user_id})")
return {
'id': user_id,
'email': email,
'first_name': first_name,
'last_name': last_name,
'clerk_user_id': user_id
}
else:
logger.warning("No user ID found in verified token")
return None
except Exception as e:
# Expired tokens are expected - log at debug level to reduce noise
error_msg = str(e).lower()
if 'expired' in error_msg or 'signature has expired' in error_msg:
logger.debug(f"Token expired (expected): {e}")
else:
logger.warning(f"fastapi-clerk-auth verification error: {e}")
return None
else:
# Fallback to custom implementation (not secure for production)
logger.warning("Using fallback JWT decoding without signature verification")
try:
import jwt
                    # Decode the JWT without verifying the signature to get claims.
                    # NOT secure for production - development only. Note that when
                    # verify_signature is disabled, PyJWT skips claim checks such as
                    # exp by default, so the leeway below is best-effort at most.
decoded_token = jwt.decode(
token,
options={"verify_signature": False},
leeway=300 # Allow 5 minutes leeway for token refresh
)
# Extract user information from the token
user_id = decoded_token.get('sub') or decoded_token.get('user_id')
email = decoded_token.get('email')
first_name = decoded_token.get('first_name')
last_name = decoded_token.get('last_name')
if not user_id:
logger.warning("No user ID found in token")
return None
logger.info(f"Token decoded successfully (fallback) for user: {email} (ID: {user_id})")
return {
'id': user_id,
'email': email,
'first_name': first_name,
'last_name': last_name,
'clerk_user_id': user_id
}
except Exception as e:
logger.warning(f"Fallback JWT decode error: {e}")
return None
except Exception as e:
logger.error(f"Token verification error: {e}")
return None
# Initialize middleware
clerk_auth = ClerkAuthMiddleware()
async def get_current_user(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict[str, Any]:
"""Get current authenticated user."""
try:
if not credentials:
# CRITICAL: Log as ERROR since this is a security issue - authenticated endpoint accessed without credentials
endpoint_path = f"{request.method} {request.url.path}"
logger.error(
f"🔒 AUTHENTICATION ERROR: No credentials provided for authenticated endpoint: {endpoint_path} "
f"(client_ip={request.client.host if request.client else 'unknown'})"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
token = credentials.credentials
user = await clerk_auth.verify_token(token)
if not user:
# Token verification failed - log with endpoint context for debugging
endpoint_path = f"{request.method} {request.url.path}"
logger.error(
f"🔒 AUTHENTICATION ERROR: Token verification failed for endpoint: {endpoint_path} "
f"(client_ip={request.client.host if request.client else 'unknown'})"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication failed",
headers={"WWW-Authenticate": "Bearer"},
)
return user
except HTTPException:
raise
except Exception as e:
endpoint_path = f"{request.method} {request.url.path}"
logger.error(
f"🔒 AUTHENTICATION ERROR: Unexpected error during authentication for endpoint: {endpoint_path}: {e}",
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication failed",
headers={"WWW-Authenticate": "Bearer"},
)
async def get_optional_user(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Optional[Dict[str, Any]]:
"""Get current user if authenticated, otherwise return None."""
try:
if not credentials:
return None
token = credentials.credentials
user = await clerk_auth.verify_token(token)
return user
except Exception as e:
logger.warning(f"Optional authentication failed: {e}")
return None
async def get_current_user_with_query_token(
request: Request,
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security)
) -> Dict[str, Any]:
"""Get current authenticated user from either Authorization header or query parameter.
    This is useful for media endpoints (audio, video, images) that need to be accessed
    by HTML elements like <audio> or <img> which cannot send custom headers. Note that
    tokens passed in query strings can end up in proxy and access logs, so short-lived
    tokens are preferable here.
Args:
request: FastAPI request object
credentials: HTTP authorization credentials from header
Returns:
User dictionary with authentication info
Raises:
HTTPException: If authentication fails
"""
try:
# Try to get token from Authorization header first
token_to_verify = None
if credentials:
token_to_verify = credentials.credentials
else:
# Fall back to query parameter if no header
query_token = request.query_params.get("token")
if query_token:
token_to_verify = query_token
if not token_to_verify:
# CRITICAL: Log as ERROR since this is a security issue
endpoint_path = f"{request.method} {request.url.path}"
logger.error(
f"🔒 AUTHENTICATION ERROR: No credentials provided (neither header nor query parameter) "
f"for authenticated endpoint: {endpoint_path} "
f"(client_ip={request.client.host if request.client else 'unknown'})"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)
user = await clerk_auth.verify_token(token_to_verify)
if not user:
# Token verification failed - log with endpoint context
endpoint_path = f"{request.method} {request.url.path}"
logger.error(
f"🔒 AUTHENTICATION ERROR: Token verification failed for endpoint: {endpoint_path} "
f"(client_ip={request.client.host if request.client else 'unknown'})"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication failed",
headers={"WWW-Authenticate": "Bearer"},
)
return user
except HTTPException:
raise
except Exception as e:
endpoint_path = f"{request.method} {request.url.path}"
logger.error(
f"🔒 AUTHENTICATION ERROR: Unexpected error during authentication for endpoint: {endpoint_path}: {e}",
exc_info=True
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication failed",
headers={"WWW-Authenticate": "Bearer"},
)
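A sketch of consuming these dependencies from routes. The router, paths, and response bodies are illustrative; the user dict fields match what verify_token returns above, and the middleware.auth_middleware import path matches the one used by the API key injection middleware.

from typing import Any, Dict
from fastapi import APIRouter, Depends
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token

router = APIRouter()

@router.get("/api/profile")
async def profile(user: Dict[str, Any] = Depends(get_current_user)):
    # 401 is raised automatically when the Bearer token is missing or invalid
    return {"email": user["email"], "clerk_user_id": user["clerk_user_id"]}

@router.get("/api/media/audio/{clip_id}")
async def audio(clip_id: str, user: Dict[str, Any] = Depends(get_current_user_with_query_token)):
    # <audio src="/api/media/audio/123?token=..."> cannot set headers, so the
    # token may arrive as a query parameter instead
    return {"clip_id": clip_id, "owner": user["clerk_user_id"]}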

View File

@@ -0,0 +1,332 @@
"""
Intelligent Logging Middleware for AI SEO Tools
Provides structured logging, file saving, and monitoring capabilities
for all SEO tool operations with performance tracking.
"""
import json
import aiofiles
from datetime import datetime
from functools import wraps
from typing import Dict, Any, Callable, Optional
from pathlib import Path
from loguru import logger
import os
import time
# Logging configuration
LOG_BASE_DIR = "logs"
os.makedirs(LOG_BASE_DIR, exist_ok=True)
# Ensure subdirectories exist
for subdir in ["seo_tools", "api_calls", "errors", "performance"]:
os.makedirs(f"{LOG_BASE_DIR}/{subdir}", exist_ok=True)
class PerformanceLogger:
"""Performance monitoring and logging for SEO operations"""
def __init__(self):
self.performance_data = {}
    async def log_performance(self, operation: str, duration: float, metadata: Optional[Dict[str, Any]] = None):
"""Log performance metrics for operations"""
performance_log = {
"operation": operation,
"duration_seconds": duration,
"timestamp": datetime.utcnow().isoformat(),
"metadata": metadata or {}
}
await save_to_file(f"{LOG_BASE_DIR}/performance/metrics.jsonl", performance_log)
# Log performance warnings for slow operations
if duration > 30: # More than 30 seconds
logger.warning(f"Slow operation detected: {operation} took {duration:.2f} seconds")
elif duration > 10: # More than 10 seconds
logger.info(f"Operation {operation} took {duration:.2f} seconds")
performance_logger = PerformanceLogger()
async def save_to_file(filepath: str, data: Dict[str, Any]) -> None:
"""
Asynchronously save structured data to a JSONL file
Args:
filepath: Path to the log file
data: Dictionary data to save
"""
try:
# Ensure directory exists
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
# Convert data to JSON string
json_line = json.dumps(data, default=str) + "\n"
# Write asynchronously
async with aiofiles.open(filepath, "a", encoding="utf-8") as file:
await file.write(json_line)
except Exception as e:
logger.error(f"Failed to save log to {filepath}: {e}")
def log_api_call(func: Callable) -> Callable:
"""
Decorator for logging API calls with performance tracking
Automatically logs request/response data, timing, and errors
for SEO tool endpoints.
"""
@wraps(func)
async def wrapper(*args, **kwargs):
start_time = time.time()
operation_name = func.__name__
# Extract request data
request_data = {}
for arg in args:
if hasattr(arg, 'dict'): # Pydantic model
request_data.update(arg.dict())
# Log API call start
call_log = {
"operation": operation_name,
"timestamp": datetime.utcnow().isoformat(),
"request_data": request_data,
"status": "started"
}
logger.info(f"API Call Started: {operation_name}")
try:
# Execute the function
result = await func(*args, **kwargs)
execution_time = time.time() - start_time
# Log successful completion
call_log.update({
"status": "completed",
"execution_time": execution_time,
"success": getattr(result, 'success', True),
"completion_timestamp": datetime.utcnow().isoformat()
})
await save_to_file(f"{LOG_BASE_DIR}/api_calls/successful.jsonl", call_log)
await performance_logger.log_performance(operation_name, execution_time, request_data)
logger.info(f"API Call Completed: {operation_name} in {execution_time:.2f}s")
return result
except Exception as e:
execution_time = time.time() - start_time
# Log error
error_log = call_log.copy()
error_log.update({
"status": "failed",
"execution_time": execution_time,
"error_type": type(e).__name__,
"error_message": str(e),
"completion_timestamp": datetime.utcnow().isoformat()
})
await save_to_file(f"{LOG_BASE_DIR}/api_calls/failed.jsonl", error_log)
logger.error(f"API Call Failed: {operation_name} after {execution_time:.2f}s - {e}")
# Re-raise the exception
raise
return wrapper
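# Illustrative usage (hypothetical model and handler, added for documentation):
#
#     from pydantic import BaseModel
#
#     class AuditRequest(BaseModel):
#         url: str
#
#     @log_api_call
#     async def analyze_page(request: AuditRequest):
#         # request.dict() is what the decorator captures as request_data
#         return {"success": True, "url": request.url}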
class SEOToolsLogger:
"""Centralized logger for SEO tools with intelligent categorization"""
@staticmethod
async def log_tool_usage(tool_name: str, input_data: Dict[str, Any],
output_data: Dict[str, Any], success: bool = True):
"""Log SEO tool usage with input/output tracking"""
usage_log = {
"tool": tool_name,
"timestamp": datetime.utcnow().isoformat(),
"input_data": input_data,
"output_data": output_data,
"success": success,
"input_size": len(str(input_data)),
"output_size": len(str(output_data))
}
await save_to_file(f"{LOG_BASE_DIR}/seo_tools/usage.jsonl", usage_log)
@staticmethod
    async def log_ai_analysis(tool_name: str, prompt: str, response: str,
                              model_used: str, tokens_used: Optional[int] = None):
"""Log AI analysis operations with token tracking"""
ai_log = {
"tool": tool_name,
"timestamp": datetime.utcnow().isoformat(),
"model": model_used,
"prompt_length": len(prompt),
"response_length": len(response),
"tokens_used": tokens_used,
"prompt_preview": prompt[:200] + "..." if len(prompt) > 200 else prompt,
"response_preview": response[:200] + "..." if len(response) > 200 else response
}
await save_to_file(f"{LOG_BASE_DIR}/seo_tools/ai_analysis.jsonl", ai_log)
@staticmethod
    async def log_external_api_call(api_name: str, endpoint: str, response_code: int,
                                    response_time: float, request_data: Optional[Dict[str, Any]] = None):
"""Log external API calls (PageSpeed, etc.)"""
api_log = {
"api": api_name,
"endpoint": endpoint,
"response_code": response_code,
"response_time": response_time,
"timestamp": datetime.utcnow().isoformat(),
"request_data": request_data or {},
"success": 200 <= response_code < 300
}
await save_to_file(f"{LOG_BASE_DIR}/seo_tools/external_apis.jsonl", api_log)
@staticmethod
async def log_crawling_operation(url: str, pages_crawled: int, errors_found: int,
crawl_depth: int, duration: float):
"""Log web crawling operations"""
crawl_log = {
"url": url,
"pages_crawled": pages_crawled,
"errors_found": errors_found,
"crawl_depth": crawl_depth,
"duration": duration,
"timestamp": datetime.utcnow().isoformat(),
"pages_per_second": pages_crawled / duration if duration > 0 else 0
}
await save_to_file(f"{LOG_BASE_DIR}/seo_tools/crawling.jsonl", crawl_log)
class LogAnalyzer:
"""Analyze logs to provide insights and monitoring"""
@staticmethod
async def get_performance_summary(hours: int = 24) -> Dict[str, Any]:
"""Get performance summary for the last N hours"""
try:
performance_file = f"{LOG_BASE_DIR}/performance/metrics.jsonl"
if not os.path.exists(performance_file):
return {"error": "No performance data available"}
# Read recent performance data
cutoff_time = datetime.utcnow().timestamp() - (hours * 3600)
operations = []
async with aiofiles.open(performance_file, "r") as file:
async for line in file:
try:
data = json.loads(line.strip())
log_time = datetime.fromisoformat(data["timestamp"]).timestamp()
if log_time >= cutoff_time:
operations.append(data)
except (json.JSONDecodeError, KeyError):
continue
if not operations:
return {"message": f"No operations in the last {hours} hours"}
# Calculate statistics
durations = [op["duration_seconds"] for op in operations]
operation_counts = {}
for op in operations:
op_name = op["operation"]
operation_counts[op_name] = operation_counts.get(op_name, 0) + 1
return {
"total_operations": len(operations),
"average_duration": sum(durations) / len(durations),
"max_duration": max(durations),
"min_duration": min(durations),
"operations_by_type": operation_counts,
"time_period_hours": hours
}
except Exception as e:
logger.error(f"Error analyzing performance logs: {e}")
return {"error": str(e)}
@staticmethod
async def get_error_summary(hours: int = 24) -> Dict[str, Any]:
"""Get error summary for the last N hours"""
try:
error_file = f"{LOG_BASE_DIR}/seo_tools/errors.jsonl"
if not os.path.exists(error_file):
return {"message": "No errors recorded"}
cutoff_time = datetime.utcnow().timestamp() - (hours * 3600)
errors = []
async with aiofiles.open(error_file, "r") as file:
async for line in file:
try:
data = json.loads(line.strip())
log_time = datetime.fromisoformat(data["timestamp"]).timestamp()
if log_time >= cutoff_time:
errors.append(data)
except (json.JSONDecodeError, KeyError):
continue
if not errors:
return {"message": f"No errors in the last {hours} hours"}
# Analyze errors
error_types = {}
functions_with_errors = {}
for error in errors:
error_type = error.get("error_type", "Unknown")
                # Failed-call records store the endpoint name under "operation"
                function = error.get("operation") or error.get("function", "Unknown")
error_types[error_type] = error_types.get(error_type, 0) + 1
functions_with_errors[function] = functions_with_errors.get(function, 0) + 1
return {
"total_errors": len(errors),
"error_types": error_types,
"functions_with_errors": functions_with_errors,
"recent_errors": errors[-5:], # Last 5 errors
"time_period_hours": hours
}
except Exception as e:
logger.error(f"Error analyzing error logs: {e}")
return {"error": str(e)}
# Initialize global logger instance
seo_logger = SEOToolsLogger()
log_analyzer = LogAnalyzer()
# Configure loguru for structured logging
# Commented out to prevent conflicts with main logging configuration
# logger.add(
# f"{LOG_BASE_DIR}/application.log",
# rotation="1 day",
# retention="30 days",
# level="INFO",
# format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
# serialize=True
# )
# logger.add(
# f"{LOG_BASE_DIR}/errors.log",
# rotation="1 day",
# retention="30 days",
# level="ERROR",
# format="{time:YYYY-MM-DD HH:mm:ss} | {level} | {name}:{function}:{line} | {message}",
# serialize=True
# )
logger.info("Logging middleware initialized successfully")

View File

@@ -0,0 +1,702 @@
"""Middleware for Stability AI operations."""
import time
import json
from typing import Dict, Any, Optional, List
from collections import defaultdict, deque
from fastapi import Request
from fastapi.responses import JSONResponse
from loguru import logger
from datetime import datetime
class RateLimitMiddleware:
"""Rate limiting middleware for Stability AI API calls."""
def __init__(self, requests_per_window: int = 150, window_seconds: int = 10):
"""Initialize rate limiter.
Args:
requests_per_window: Maximum requests per time window
window_seconds: Time window in seconds
"""
self.requests_per_window = requests_per_window
self.window_seconds = window_seconds
self.request_times: Dict[str, deque] = defaultdict(lambda: deque())
self.blocked_until: Dict[str, float] = {}
async def __call__(self, request: Request, call_next):
"""Process request with rate limiting.
Args:
request: FastAPI request
call_next: Next middleware/endpoint
Returns:
Response
"""
# Skip rate limiting for non-Stability endpoints
if not request.url.path.startswith("/api/stability"):
return await call_next(request)
# Get client identifier (IP address or API key)
client_id = self._get_client_id(request)
current_time = time.time()
# Check if client is currently blocked
if client_id in self.blocked_until:
if current_time < self.blocked_until[client_id]:
remaining = int(self.blocked_until[client_id] - current_time)
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"retry_after": remaining,
"message": f"You have been timed out for {remaining} seconds"
}
)
else:
# Timeout expired, remove block
del self.blocked_until[client_id]
# Clean old requests outside the window
request_times = self.request_times[client_id]
while request_times and request_times[0] < current_time - self.window_seconds:
request_times.popleft()
# Check rate limit
if len(request_times) >= self.requests_per_window:
# Rate limit exceeded, block for 60 seconds
self.blocked_until[client_id] = current_time + 60
return JSONResponse(
status_code=429,
content={
"error": "Rate limit exceeded",
"retry_after": 60,
"message": "You have exceeded the rate limit of 150 requests within a 10 second period"
}
)
# Add current request time
request_times.append(current_time)
# Process request
response = await call_next(request)
# Add rate limit headers
response.headers["X-RateLimit-Limit"] = str(self.requests_per_window)
response.headers["X-RateLimit-Remaining"] = str(self.requests_per_window - len(request_times))
response.headers["X-RateLimit-Reset"] = str(int(current_time + self.window_seconds))
return response
def _get_client_id(self, request: Request) -> str:
"""Get client identifier for rate limiting.
Args:
request: FastAPI request
Returns:
Client identifier
"""
# Try to get API key from authorization header
auth_header = request.headers.get("authorization", "")
if auth_header.startswith("Bearer "):
return auth_header[7:15] # Use first 8 chars of API key
# Fall back to IP address
return request.client.host if request.client else "unknown"
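# Behavior sketch (hypothetical numbers, for documentation): with
# RateLimitMiddleware(requests_per_window=2, window_seconds=10), two requests
# from one client pass and carry X-RateLimit-* headers; a third request inside
# the same 10-second window receives the 429 JSON body above and starts a
# 60-second block keyed by the _get_client_id result (API key prefix or IP).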
class MonitoringMiddleware:
"""Monitoring middleware for Stability AI operations."""
def __init__(self):
"""Initialize monitoring middleware."""
self.request_stats = defaultdict(lambda: {
"count": 0,
"total_time": 0,
"errors": 0,
"last_request": None
})
self.active_requests = {}
async def __call__(self, request: Request, call_next):
"""Process request with monitoring.
Args:
request: FastAPI request
call_next: Next middleware/endpoint
Returns:
Response
"""
# Skip monitoring for non-Stability endpoints
if not request.url.path.startswith("/api/stability"):
return await call_next(request)
start_time = time.time()
request_id = f"{int(start_time * 1000)}_{id(request)}"
# Extract operation info
operation = self._extract_operation(request.url.path)
# Log request start
self.active_requests[request_id] = {
"operation": operation,
"start_time": start_time,
"path": request.url.path,
"method": request.method
}
try:
# Process request
response = await call_next(request)
# Calculate processing time
processing_time = time.time() - start_time
# Update stats
stats = self.request_stats[operation]
stats["count"] += 1
stats["total_time"] += processing_time
stats["last_request"] = datetime.utcnow().isoformat()
# Add monitoring headers
response.headers["X-Processing-Time"] = str(round(processing_time, 3))
response.headers["X-Operation"] = operation
response.headers["X-Request-ID"] = request_id
# Log successful request
logger.info(f"Stability AI request completed: {operation} in {processing_time:.3f}s")
return response
except Exception as e:
# Update error stats
self.request_stats[operation]["errors"] += 1
# Log error
logger.error(f"Stability AI request failed: {operation} - {str(e)}")
raise
finally:
# Clean up active request
self.active_requests.pop(request_id, None)
def _extract_operation(self, path: str) -> str:
"""Extract operation name from request path.
Args:
path: Request path
Returns:
Operation name
"""
path_parts = path.split("/")
if len(path_parts) >= 4:
if "generate" in path_parts:
return f"generate_{path_parts[-1]}"
elif "edit" in path_parts:
return f"edit_{path_parts[-1]}"
elif "upscale" in path_parts:
return f"upscale_{path_parts[-1]}"
elif "control" in path_parts:
return f"control_{path_parts[-1]}"
elif "3d" in path_parts:
return f"3d_{path_parts[-1]}"
elif "audio" in path_parts:
return f"audio_{path_parts[-1]}"
return "unknown"
def get_stats(self) -> Dict[str, Any]:
"""Get monitoring statistics.
Returns:
Monitoring statistics
"""
stats = {}
for operation, data in self.request_stats.items():
avg_time = data["total_time"] / data["count"] if data["count"] > 0 else 0
error_rate = (data["errors"] / data["count"]) * 100 if data["count"] > 0 else 0
stats[operation] = {
"total_requests": data["count"],
"total_errors": data["errors"],
"error_rate_percent": round(error_rate, 2),
"average_processing_time": round(avg_time, 3),
"last_request": data["last_request"]
}
stats["active_requests"] = len(self.active_requests)
stats["total_operations"] = len(self.request_stats)
return stats
class ContentModerationMiddleware:
"""Content moderation middleware for Stability AI requests."""
def __init__(self):
"""Initialize content moderation middleware."""
self.blocked_terms = self._load_blocked_terms()
self.warning_terms = self._load_warning_terms()
async def __call__(self, request: Request, call_next):
"""Process request with content moderation.
Args:
request: FastAPI request
call_next: Next middleware/endpoint
Returns:
Response
"""
# Skip moderation for non-generation endpoints
if not self._should_moderate(request.url.path):
return await call_next(request)
# Extract and check prompt content
prompt = await self._extract_prompt(request)
if prompt:
moderation_result = self._moderate_content(prompt)
if moderation_result["blocked"]:
return JSONResponse(
status_code=403,
content={
"error": "Content moderation",
"message": "Your request was flagged by our content moderation system",
"issues": moderation_result["issues"]
}
)
if moderation_result["warnings"]:
logger.warning(f"Content warnings for prompt: {moderation_result['warnings']}")
# Process request
response = await call_next(request)
# Add content moderation headers
if prompt:
response.headers["X-Content-Moderated"] = "true"
return response
def _should_moderate(self, path: str) -> bool:
"""Check if path should be moderated.
Args:
path: Request path
Returns:
True if should be moderated
"""
moderated_paths = ["/generate/", "/edit/", "/control/", "/audio/"]
return any(mod_path in path for mod_path in moderated_paths)
async def _extract_prompt(self, request: Request) -> Optional[str]:
"""Extract prompt from request.
Args:
request: FastAPI request
Returns:
Extracted prompt or None
"""
try:
if request.method == "POST":
# For form data, we'd need to parse the form
# This is a simplified version
body = await request.body()
if b"prompt=" in body:
# Extract prompt from form data (simplified)
body_str = body.decode('utf-8', errors='ignore')
if "prompt=" in body_str:
start = body_str.find("prompt=") + 7
end = body_str.find("&", start)
if end == -1:
end = len(body_str)
                        # Form values are URL-encoded; decode before moderation
                        from urllib.parse import unquote_plus
                        return unquote_plus(body_str[start:end])
        except Exception:
            pass
return None
def _moderate_content(self, prompt: str) -> Dict[str, Any]:
"""Moderate content for policy violations.
Args:
prompt: Text prompt to moderate
Returns:
Moderation result
"""
issues = []
warnings = []
prompt_lower = prompt.lower()
# Check for blocked terms
for term in self.blocked_terms:
if term in prompt_lower:
issues.append(f"Contains blocked term: {term}")
# Check for warning terms
for term in self.warning_terms:
if term in prompt_lower:
warnings.append(f"Contains flagged term: {term}")
return {
"blocked": len(issues) > 0,
"issues": issues,
"warnings": warnings
}
def _load_blocked_terms(self) -> List[str]:
"""Load blocked terms from configuration.
Returns:
List of blocked terms
"""
# In production, this would load from a configuration file or database
return [
# Add actual blocked terms here
]
def _load_warning_terms(self) -> List[str]:
"""Load warning terms from configuration.
Returns:
List of warning terms
"""
# In production, this would load from a configuration file or database
return [
# Add actual warning terms here
]
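# Illustrative result (hypothetical term list, for documentation): with
# blocked_terms = ["forbidden"], _moderate_content("a forbidden prompt") returns
# {"blocked": True, "issues": ["Contains blocked term: forbidden"], "warnings": []};
# warning terms alone never block a request, they are only logged.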
class CachingMiddleware:
"""Caching middleware for Stability AI responses."""
def __init__(self, cache_duration: int = 3600):
"""Initialize caching middleware.
Args:
cache_duration: Cache duration in seconds
"""
self.cache_duration = cache_duration
self.cache: Dict[str, Dict[str, Any]] = {}
self.cache_times: Dict[str, float] = {}
async def __call__(self, request: Request, call_next):
"""Process request with caching.
Args:
request: FastAPI request
call_next: Next middleware/endpoint
Returns:
Response (cached or fresh)
"""
# Skip caching for non-cacheable endpoints
if not self._should_cache(request):
return await call_next(request)
# Generate cache key
cache_key = await self._generate_cache_key(request)
# Check cache
if self._is_cached(cache_key):
logger.info(f"Returning cached result for {cache_key}")
cached_data = self.cache[cache_key]
return JSONResponse(
content=cached_data["content"],
headers={**cached_data["headers"], "X-Cache-Hit": "true"}
)
# Process request
response = await call_next(request)
# Cache successful responses
if response.status_code == 200 and self._should_cache_response(response):
await self._cache_response(cache_key, response)
return response
def _should_cache(self, request: Request) -> bool:
"""Check if request should be cached.
Args:
request: FastAPI request
Returns:
True if should be cached
"""
# Only cache GET requests and certain POST operations
if request.method == "GET":
return True
        # Cache a small set of read-only informational POST endpoints
        cacheable_paths = ["/models/info", "/supported-formats", "/health"]
return any(path in request.url.path for path in cacheable_paths)
def _should_cache_response(self, response) -> bool:
"""Check if response should be cached.
Args:
response: FastAPI response
Returns:
True if should be cached
"""
# Don't cache large binary responses
content_length = response.headers.get("content-length")
if content_length and int(content_length) > 1024 * 1024: # 1MB
return False
return True
async def _generate_cache_key(self, request: Request) -> str:
"""Generate cache key for request.
Args:
request: FastAPI request
Returns:
Cache key
"""
import hashlib
key_parts = [
request.method,
request.url.path,
str(sorted(request.query_params.items()))
]
# For POST requests, include body hash
if request.method == "POST":
body = await request.body()
if body:
key_parts.append(hashlib.md5(body).hexdigest())
key_string = "|".join(key_parts)
return hashlib.sha256(key_string.encode()).hexdigest()
def _is_cached(self, cache_key: str) -> bool:
"""Check if key is cached and not expired.
Args:
cache_key: Cache key
Returns:
True if cached and valid
"""
if cache_key not in self.cache:
return False
cache_time = self.cache_times.get(cache_key, 0)
return time.time() - cache_time < self.cache_duration
async def _cache_response(self, cache_key: str, response) -> None:
"""Cache response data.
Args:
cache_key: Cache key
response: Response to cache
"""
try:
# Only cache JSON responses for now
if response.headers.get("content-type", "").startswith("application/json"):
self.cache[cache_key] = {
"content": json.loads(response.body),
"headers": dict(response.headers)
}
self.cache_times[cache_key] = time.time()
        except Exception:
            # Ignore cache errors (e.g. streaming responses without a .body attribute)
            pass
def clear_cache(self) -> None:
"""Clear all cached data."""
self.cache.clear()
self.cache_times.clear()
def get_cache_stats(self) -> Dict[str, Any]:
"""Get cache statistics.
Returns:
Cache statistics
"""
current_time = time.time()
expired_keys = [
key for key, cache_time in self.cache_times.items()
if current_time - cache_time > self.cache_duration
]
return {
"total_entries": len(self.cache),
"expired_entries": len(expired_keys),
"cache_hit_rate": "N/A", # Would need request tracking
"memory_usage": sum(len(str(data)) for data in self.cache.values())
}
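# Cache-key sketch (for documentation): two POSTs to the same path with the same
# query params and byte-identical bodies hash to the same SHA-256 key in
# _generate_cache_key, so the second request can be answered from cache until
# cache_duration (default 3600 s) elapses.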
class RequestLoggingMiddleware:
"""Logging middleware for Stability AI requests."""
def __init__(self):
"""Initialize logging middleware."""
self.request_log = []
self.max_log_entries = 1000
async def __call__(self, request: Request, call_next):
"""Process request with logging.
Args:
request: FastAPI request
call_next: Next middleware/endpoint
Returns:
Response
"""
# Skip logging for non-Stability endpoints
if not request.url.path.startswith("/api/stability"):
return await call_next(request)
start_time = time.time()
request_id = f"{int(start_time * 1000)}_{id(request)}"
# Log request details
log_entry = {
"request_id": request_id,
"timestamp": datetime.utcnow().isoformat(),
"method": request.method,
"path": request.url.path,
"query_params": dict(request.query_params),
"client_ip": request.client.host if request.client else "unknown",
"user_agent": request.headers.get("user-agent", "unknown")
}
try:
# Process request
response = await call_next(request)
# Calculate processing time
processing_time = time.time() - start_time
# Update log entry
log_entry.update({
"status_code": response.status_code,
"processing_time": round(processing_time, 3),
"response_size": len(response.body) if hasattr(response, 'body') else 0,
"success": True
})
return response
except Exception as e:
# Log error
log_entry.update({
"error": str(e),
"success": False,
"processing_time": round(time.time() - start_time, 3)
})
raise
finally:
# Add to log
self._add_log_entry(log_entry)
def _add_log_entry(self, entry: Dict[str, Any]) -> None:
"""Add entry to request log.
Args:
entry: Log entry
"""
self.request_log.append(entry)
# Keep only recent entries
if len(self.request_log) > self.max_log_entries:
self.request_log = self.request_log[-self.max_log_entries:]
def get_recent_logs(self, limit: int = 100) -> List[Dict[str, Any]]:
"""Get recent log entries.
Args:
limit: Maximum number of entries to return
Returns:
Recent log entries
"""
return self.request_log[-limit:]
def get_log_summary(self) -> Dict[str, Any]:
"""Get summary of logged requests.
Returns:
Log summary statistics
"""
if not self.request_log:
return {"total_requests": 0}
total_requests = len(self.request_log)
successful_requests = sum(1 for entry in self.request_log if entry.get("success", False))
# Calculate average processing time
processing_times = [
entry["processing_time"] for entry in self.request_log
if "processing_time" in entry
]
avg_processing_time = sum(processing_times) / len(processing_times) if processing_times else 0
# Get operation breakdown
operations = defaultdict(int)
for entry in self.request_log:
operation = entry.get("path", "unknown").split("/")[-1]
operations[operation] += 1
return {
"total_requests": total_requests,
"successful_requests": successful_requests,
"error_rate_percent": round((1 - successful_requests / total_requests) * 100, 2),
"average_processing_time": round(avg_processing_time, 3),
"operations_breakdown": dict(operations),
"time_range": {
"start": self.request_log[0]["timestamp"],
"end": self.request_log[-1]["timestamp"]
}
}
# Global middleware instances
rate_limiter = RateLimitMiddleware()
monitoring = MonitoringMiddleware()
caching = CachingMiddleware()
request_logging = RequestLoggingMiddleware()
def get_middleware_stats() -> Dict[str, Any]:
"""Get statistics from all middleware components.
Returns:
Combined middleware statistics
"""
return {
"rate_limiting": {
"active_blocks": len(rate_limiter.blocked_until),
"requests_per_window": rate_limiter.requests_per_window,
"window_seconds": rate_limiter.window_seconds
},
"monitoring": monitoring.get_stats(),
"caching": caching.get_cache_stats(),
"logging": request_logging.get_log_summary()
}
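One possible wiring from app.py, assuming these instances are what gets registered; the stats route is illustrative. Each class implements __call__(request, call_next), so the instances can be passed directly to app.middleware("http").

from fastapi import FastAPI

app = FastAPI()
# Starlette wraps middleware in reverse registration order, so the one added
# last (rate_limiter) sits outermost and runs first on each request
app.middleware("http")(request_logging)
app.middleware("http")(caching)
app.middleware("http")(monitoring)
app.middleware("http")(rate_limiter)

@app.get("/api/stability/middleware-stats")
async def middleware_stats():
    return get_middleware_stats()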