702 lines
23 KiB
Python
702 lines
23 KiB
Python
"""Middleware for Stability AI operations."""
|
|
|
|
import time
|
|
import asyncio
|
|
import os
|
|
from typing import Dict, Any, Optional, List
|
|
from collections import defaultdict, deque
|
|
from fastapi import Request, HTTPException
|
|
from fastapi.responses import JSONResponse
|
|
import json
|
|
from loguru import logger
|
|
from datetime import datetime, timedelta
|
|
|
|
|
|
class RateLimitMiddleware:
    """Rate limiting middleware for Stability AI API calls.

    Implements a per-client sliding-window counter. Clients that exceed
    the window limit are blocked for ``block_seconds`` and receive 429
    responses until the block expires.
    """

    def __init__(self, requests_per_window: int = 150, window_seconds: int = 10,
                 block_seconds: int = 60):
        """Initialize rate limiter.

        Args:
            requests_per_window: Maximum requests per time window
            window_seconds: Time window in seconds
            block_seconds: How long (seconds) a client stays blocked after
                exceeding the limit
        """
        self.requests_per_window = requests_per_window
        self.window_seconds = window_seconds
        self.block_seconds = block_seconds
        # client_id -> timestamps of requests inside the current window.
        self.request_times: Dict[str, deque] = defaultdict(deque)
        # client_id -> epoch seconds until which the client is blocked.
        self.blocked_until: Dict[str, float] = {}

    async def __call__(self, request: Request, call_next):
        """Process request with rate limiting.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response (429 JSONResponse when the client is rate limited)
        """
        # Skip rate limiting for non-Stability endpoints
        if not request.url.path.startswith("/api/stability"):
            return await call_next(request)

        # Get client identifier (API-key prefix or IP address)
        client_id = self._get_client_id(request)
        current_time = time.time()

        # Honor an active block, or clear it once it has expired.
        blocked = self.blocked_until.get(client_id)
        if blocked is not None:
            if current_time < blocked:
                remaining = int(blocked - current_time)
                return JSONResponse(
                    status_code=429,
                    content={
                        "error": "Rate limit exceeded",
                        "retry_after": remaining,
                        "message": f"You have been timed out for {remaining} seconds"
                    }
                )
            # Timeout expired, remove block
            del self.blocked_until[client_id]

        # Drop request timestamps that fell out of the sliding window
        request_times = self.request_times[client_id]
        cutoff = current_time - self.window_seconds
        while request_times and request_times[0] < cutoff:
            request_times.popleft()

        # Check rate limit
        if len(request_times) >= self.requests_per_window:
            # Block the client. The message now reflects the configured
            # limits instead of a hard-coded "150 requests within a 10
            # second period" (identical text under the defaults).
            self.blocked_until[client_id] = current_time + self.block_seconds
            return JSONResponse(
                status_code=429,
                content={
                    "error": "Rate limit exceeded",
                    "retry_after": self.block_seconds,
                    "message": (
                        f"You have exceeded the rate limit of "
                        f"{self.requests_per_window} requests within a "
                        f"{self.window_seconds} second period"
                    )
                }
            )

        # Record this request and forward it
        request_times.append(current_time)
        response = await call_next(request)

        # Standard rate-limit headers so clients can self-throttle
        response.headers["X-RateLimit-Limit"] = str(self.requests_per_window)
        response.headers["X-RateLimit-Remaining"] = str(self.requests_per_window - len(request_times))
        response.headers["X-RateLimit-Reset"] = str(int(current_time + self.window_seconds))

        return response

    def _get_client_id(self, request: Request) -> str:
        """Get client identifier for rate limiting.

        Args:
            request: FastAPI request

        Returns:
            Client identifier (short API-key prefix, client IP, or "unknown")
        """
        # Prefer the bearer token; keep only an 8-char prefix so full
        # API keys are never stored in memory.
        auth_header = request.headers.get("authorization", "")
        if auth_header.startswith("Bearer "):
            return auth_header[7:15]

        # Fall back to IP address
        return request.client.host if request.client else "unknown"
|
|
|
|
|
|
class MonitoringMiddleware:
    """Monitoring middleware for Stability AI operations.

    Tracks per-operation request counts, cumulative processing time and
    error counts, plus a snapshot of in-flight requests.
    """

    def __init__(self):
        """Initialize monitoring middleware."""
        # Per-operation aggregates. "count" covers every request,
        # successful or failed, so error rates are meaningful.
        self.request_stats = defaultdict(lambda: {
            "count": 0,
            "total_time": 0,
            "errors": 0,
            "last_request": None
        })
        # request_id -> metadata for requests currently being processed.
        self.active_requests = {}

    async def __call__(self, request: Request, call_next):
        """Process request with monitoring.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response
        """
        # Skip monitoring for non-Stability endpoints
        if not request.url.path.startswith("/api/stability"):
            return await call_next(request)

        start_time = time.time()
        request_id = f"{int(start_time * 1000)}_{id(request)}"

        # Extract operation name from the URL path
        operation = self._extract_operation(request.url.path)

        # Register the in-flight request so get_stats() can report it
        self.active_requests[request_id] = {
            "operation": operation,
            "start_time": start_time,
            "path": request.url.path,
            "method": request.method
        }

        try:
            # Process request
            response = await call_next(request)

            processing_time = time.time() - start_time

            # Update per-operation stats
            stats = self.request_stats[operation]
            stats["count"] += 1
            stats["total_time"] += processing_time
            # NOTE(review): datetime.utcnow() is deprecated in Python 3.12+;
            # kept for output compatibility — confirm before migrating to
            # timezone-aware timestamps.
            stats["last_request"] = datetime.utcnow().isoformat()

            # Add monitoring headers
            response.headers["X-Processing-Time"] = str(round(processing_time, 3))
            response.headers["X-Operation"] = operation
            response.headers["X-Request-ID"] = request_id

            logger.info(f"Stability AI request completed: {operation} in {processing_time:.3f}s")

            return response

        except Exception as e:
            # Fix: failed requests previously did not increment "count",
            # so error_rate_percent reported 0% when every request failed
            # and could exceed 100% on mixed traffic.
            stats = self.request_stats[operation]
            stats["count"] += 1
            stats["errors"] += 1
            stats["total_time"] += time.time() - start_time
            stats["last_request"] = datetime.utcnow().isoformat()

            logger.error(f"Stability AI request failed: {operation} - {str(e)}")

            raise

        finally:
            # Clean up in-flight bookkeeping regardless of outcome
            self.active_requests.pop(request_id, None)

    def _extract_operation(self, path: str) -> str:
        """Extract operation name from request path.

        Args:
            path: Request path

        Returns:
            Operation name like "generate_core", or "unknown" when the
            path does not match a known category
        """
        path_parts = path.split("/")

        # Expect at least /api/stability/<category>/<operation>
        if len(path_parts) >= 4:
            if "generate" in path_parts:
                return f"generate_{path_parts[-1]}"
            elif "edit" in path_parts:
                return f"edit_{path_parts[-1]}"
            elif "upscale" in path_parts:
                return f"upscale_{path_parts[-1]}"
            elif "control" in path_parts:
                return f"control_{path_parts[-1]}"
            elif "3d" in path_parts:
                return f"3d_{path_parts[-1]}"
            elif "audio" in path_parts:
                return f"audio_{path_parts[-1]}"

        return "unknown"

    def get_stats(self) -> Dict[str, Any]:
        """Get monitoring statistics.

        Returns:
            Per-operation statistics plus counts of active requests and
            distinct operations seen
        """
        stats = {}

        for operation, data in self.request_stats.items():
            # Guard against division by zero for operations never completed
            avg_time = data["total_time"] / data["count"] if data["count"] > 0 else 0
            error_rate = (data["errors"] / data["count"]) * 100 if data["count"] > 0 else 0

            stats[operation] = {
                "total_requests": data["count"],
                "total_errors": data["errors"],
                "error_rate_percent": round(error_rate, 2),
                "average_processing_time": round(avg_time, 3),
                "last_request": data["last_request"]
            }

        stats["active_requests"] = len(self.active_requests)
        stats["total_operations"] = len(self.request_stats)

        return stats
|
|
|
|
|
|
class ContentModerationMiddleware:
    """Content moderation middleware for Stability AI requests.

    Extracts the text prompt from generation-style requests and checks it
    against configured blocked/warning term lists before forwarding.
    """

    def __init__(self):
        """Initialize content moderation middleware."""
        # Term lists are loaded once at construction time.
        self.blocked_terms = self._load_blocked_terms()
        self.warning_terms = self._load_warning_terms()

    async def __call__(self, request: Request, call_next):
        """Process request with content moderation.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response (403 JSONResponse when the prompt is blocked)
        """
        # Skip moderation for non-generation endpoints
        if not self._should_moderate(request.url.path):
            return await call_next(request)

        # Extract and check prompt content
        prompt = await self._extract_prompt(request)

        if prompt:
            moderation_result = self._moderate_content(prompt)

            if moderation_result["blocked"]:
                return JSONResponse(
                    status_code=403,
                    content={
                        "error": "Content moderation",
                        "message": "Your request was flagged by our content moderation system",
                        "issues": moderation_result["issues"]
                    }
                )

            if moderation_result["warnings"]:
                logger.warning(f"Content warnings for prompt: {moderation_result['warnings']}")

        # Process request
        response = await call_next(request)

        # Mark responses whose prompt went through moderation
        if prompt:
            response.headers["X-Content-Moderated"] = "true"

        return response

    def _should_moderate(self, path: str) -> bool:
        """Check if path should be moderated.

        Args:
            path: Request path

        Returns:
            True if should be moderated
        """
        moderated_paths = ["/generate/", "/edit/", "/control/", "/audio/"]
        return any(mod_path in path for mod_path in moderated_paths)

    async def _extract_prompt(self, request: Request) -> Optional[str]:
        """Extract prompt from request.

        Args:
            request: FastAPI request

        Returns:
            Extracted (URL-decoded) prompt or None
        """
        try:
            if request.method == "POST":
                # NOTE(review): request.body() buffers the payload; with
                # BaseHTTPMiddleware-style dispatch, downstream handlers may
                # not be able to re-read the stream — confirm against the
                # app's middleware setup.
                body = await request.body()
                if b"prompt=" in body:
                    # Simplified form-data parsing: slice out the prompt field
                    body_str = body.decode('utf-8', errors='ignore')
                    start = body_str.find("prompt=") + 7
                    end = body_str.find("&", start)
                    if end == -1:
                        end = len(body_str)
                    # Fix: decode percent-escapes and '+' so moderation sees
                    # the actual text (previously matched against the raw
                    # URL-encoded form value).
                    from urllib.parse import unquote_plus
                    return unquote_plus(body_str[start:end])
        except Exception:
            # Best-effort extraction: moderation is simply skipped when the
            # body cannot be parsed. (Was a bare except, which also swallowed
            # KeyboardInterrupt/SystemExit.)
            pass

        return None

    def _moderate_content(self, prompt: str) -> Dict[str, Any]:
        """Moderate content for policy violations.

        Args:
            prompt: Text prompt to moderate

        Returns:
            Dict with keys "blocked" (bool), "issues" and "warnings" (lists)
        """
        issues = []
        warnings = []

        # Matching is case-insensitive via a lowered copy of the prompt
        prompt_lower = prompt.lower()

        # Check for blocked terms
        for term in self.blocked_terms:
            if term in prompt_lower:
                issues.append(f"Contains blocked term: {term}")

        # Check for warning terms
        for term in self.warning_terms:
            if term in prompt_lower:
                warnings.append(f"Contains flagged term: {term}")

        return {
            "blocked": len(issues) > 0,
            "issues": issues,
            "warnings": warnings
        }

    def _load_blocked_terms(self) -> List[str]:
        """Load blocked terms from configuration.

        Returns:
            List of blocked terms (lowercase expected, since matching is
            done against a lowered prompt)
        """
        # In production, this would load from a configuration file or database
        return [
            # Add actual blocked terms here
        ]

    def _load_warning_terms(self) -> List[str]:
        """Load warning terms from configuration.

        Returns:
            List of warning terms (lowercase expected)
        """
        # In production, this would load from a configuration file or database
        return [
            # Add actual warning terms here
        ]
|
|
|
|
|
|
class CachingMiddleware:
    """Caching middleware for Stability AI responses.

    In-memory TTL cache keyed by a hash of the request. Only small JSON
    responses are stored; expired entries are evicted lazily on lookup.
    """

    def __init__(self, cache_duration: int = 3600):
        """Initialize caching middleware.

        Args:
            cache_duration: Cache duration in seconds
        """
        self.cache_duration = cache_duration
        # cache_key -> {"content": <parsed JSON>, "headers": <dict>}
        self.cache: Dict[str, Dict[str, Any]] = {}
        # cache_key -> insertion time (epoch seconds), used for expiry.
        self.cache_times: Dict[str, float] = {}

    async def __call__(self, request: Request, call_next):
        """Process request with caching.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response (cached or fresh)
        """
        # Skip caching for non-cacheable endpoints
        if not self._should_cache(request):
            return await call_next(request)

        # Generate cache key
        cache_key = await self._generate_cache_key(request)

        # Serve from cache when a fresh entry exists
        if self._is_cached(cache_key):
            logger.info(f"Returning cached result for {cache_key}")
            cached_data = self.cache[cache_key]

            return JSONResponse(
                content=cached_data["content"],
                headers={**cached_data["headers"], "X-Cache-Hit": "true"}
            )

        # Process request
        response = await call_next(request)

        # Cache successful responses
        if response.status_code == 200 and self._should_cache_response(response):
            await self._cache_response(cache_key, response)

        return response

    def _should_cache(self, request: Request) -> bool:
        """Check if request should be cached.

        Args:
            request: FastAPI request

        Returns:
            True if should be cached
        """
        # Only cache GET requests and certain POST operations
        if request.method == "GET":
            return True

        # Cache deterministic metadata endpoints
        cacheable_paths = ["/models/info", "/supported-formats", "/health"]
        return any(path in request.url.path for path in cacheable_paths)

    def _should_cache_response(self, response) -> bool:
        """Check if response should be cached.

        Args:
            response: FastAPI response

        Returns:
            True if should be cached
        """
        # Don't cache large binary responses
        content_length = response.headers.get("content-length")
        if content_length and int(content_length) > 1024 * 1024:  # 1MB
            return False

        return True

    async def _generate_cache_key(self, request: Request) -> str:
        """Generate cache key for request.

        Args:
            request: FastAPI request

        Returns:
            Cache key (hex SHA-256 digest)
        """
        import hashlib

        key_parts = [
            request.method,
            request.url.path,
            str(sorted(request.query_params.items()))
        ]

        # For POST requests, include body hash so different payloads get
        # different keys. MD5 is fine here: the key is not security-sensitive.
        # NOTE(review): request.body() buffers the payload; with
        # BaseHTTPMiddleware-style dispatch downstream handlers may not be
        # able to re-read the stream — confirm against the app setup.
        if request.method == "POST":
            body = await request.body()
            if body:
                key_parts.append(hashlib.md5(body).hexdigest())

        key_string = "|".join(key_parts)
        return hashlib.sha256(key_string.encode()).hexdigest()

    def _is_cached(self, cache_key: str) -> bool:
        """Check if key is cached and not expired, evicting stale entries.

        Args:
            cache_key: Cache key

        Returns:
            True if cached and valid
        """
        if cache_key not in self.cache:
            return False

        cache_time = self.cache_times.get(cache_key, 0)
        if time.time() - cache_time < self.cache_duration:
            return True

        # Fix: expired entries were previously left in memory forever;
        # evict them on lookup so the cache cannot grow unbounded.
        self.cache.pop(cache_key, None)
        self.cache_times.pop(cache_key, None)
        return False

    async def _cache_response(self, cache_key: str, response) -> None:
        """Cache response data.

        Args:
            cache_key: Cache key
            response: Response to cache
        """
        try:
            # Only cache JSON responses with a materialized body; streaming
            # responses returned by call_next have no .body attribute and
            # are skipped explicitly instead of via a swallowed error.
            if (hasattr(response, "body")
                    and response.headers.get("content-type", "").startswith("application/json")):
                self.cache[cache_key] = {
                    "content": json.loads(response.body),
                    "headers": dict(response.headers)
                }
                self.cache_times[cache_key] = time.time()
        except Exception:
            # Cache population is best-effort. (Was a bare except, which
            # also swallowed SystemExit/KeyboardInterrupt.)
            pass

    def clear_cache(self) -> None:
        """Clear all cached data."""
        self.cache.clear()
        self.cache_times.clear()

    def get_cache_stats(self) -> Dict[str, Any]:
        """Get cache statistics.

        Returns:
            Cache statistics
        """
        current_time = time.time()
        # Entries past their TTL that have not yet been evicted by lookup
        expired_keys = [
            key for key, cache_time in self.cache_times.items()
            if current_time - cache_time > self.cache_duration
        ]

        return {
            "total_entries": len(self.cache),
            "expired_entries": len(expired_keys),
            "cache_hit_rate": "N/A",  # Would need request tracking
            "memory_usage": sum(len(str(data)) for data in self.cache.values())
        }
|
|
|
|
|
|
class RequestLoggingMiddleware:
    """Logging middleware for Stability AI requests.

    Keeps a bounded in-memory list of per-request records and exposes
    recent entries plus aggregate summary statistics.
    """

    def __init__(self):
        """Initialize logging middleware."""
        # Bounded list of recent request records (newest last).
        self.request_log = []
        self.max_log_entries = 1000

    async def __call__(self, request: Request, call_next):
        """Process request with logging.

        Args:
            request: FastAPI request
            call_next: Next middleware/endpoint

        Returns:
            Response
        """
        path = request.url.path
        # Only Stability endpoints are logged.
        if not path.startswith("/api/stability"):
            return await call_next(request)

        started = time.time()

        # Record request metadata up front; outcome fields are filled in
        # after the downstream handler runs (or fails).
        entry = {
            "request_id": f"{int(started * 1000)}_{id(request)}",
            "timestamp": datetime.utcnow().isoformat(),
            "method": request.method,
            "path": path,
            "query_params": dict(request.query_params),
            "client_ip": request.client.host if request.client else "unknown",
            "user_agent": request.headers.get("user-agent", "unknown"),
        }

        try:
            response = await call_next(request)
        except Exception as exc:
            # Failed request: capture the error and elapsed time, then
            # let the exception propagate to outer handlers.
            entry.update({
                "error": str(exc),
                "success": False,
                "processing_time": round(time.time() - started, 3),
            })
            raise
        else:
            entry.update({
                "status_code": response.status_code,
                "processing_time": round(time.time() - started, 3),
                "response_size": len(response.body) if hasattr(response, 'body') else 0,
                "success": True,
            })
            return response
        finally:
            # Every logged request ends up in the ring, pass or fail.
            self._add_log_entry(entry)

    def _add_log_entry(self, entry: Dict[str, Any]) -> None:
        """Add entry to request log, trimming to the newest entries.

        Args:
            entry: Log entry
        """
        self.request_log.append(entry)
        if len(self.request_log) > self.max_log_entries:
            # Drop the oldest overflow in place.
            del self.request_log[:-self.max_log_entries]

    def get_recent_logs(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get recent log entries.

        Args:
            limit: Maximum number of entries to return

        Returns:
            Recent log entries (newest last)
        """
        return self.request_log[-limit:]

    def get_log_summary(self) -> Dict[str, Any]:
        """Get summary of logged requests.

        Returns:
            Log summary statistics
        """
        log = self.request_log
        if not log:
            return {"total_requests": 0}

        total = len(log)
        ok = sum(1 for record in log if record.get("success", False))

        # Average over entries that actually recorded a processing time.
        durations = [record["processing_time"] for record in log
                     if "processing_time" in record]
        mean_duration = sum(durations) / len(durations) if durations else 0

        # Count requests per trailing path segment (operation name).
        by_operation = defaultdict(int)
        for record in log:
            by_operation[record.get("path", "unknown").split("/")[-1]] += 1

        return {
            "total_requests": total,
            "successful_requests": ok,
            "error_rate_percent": round((1 - ok / total) * 100, 2),
            "average_processing_time": round(mean_duration, 3),
            "operations_breakdown": dict(by_operation),
            "time_range": {
                "start": log[0]["timestamp"],
                "end": log[-1]["timestamp"],
            },
        }
|
|
|
|
|
|
# Global middleware instances
# Module-level singletons so the app wiring and the stats helper below
# read from the same state.
rate_limiter = RateLimitMiddleware()
monitoring = MonitoringMiddleware()
caching = CachingMiddleware()
request_logging = RequestLoggingMiddleware()
|
|
|
|
|
|
def get_middleware_stats() -> Dict[str, Any]:
    """Get statistics from all middleware components.

    Returns:
        Combined middleware statistics keyed by component name
    """
    # Rate limiting exposes plain attributes rather than a stats method.
    rate_limit_info = {
        "active_blocks": len(rate_limiter.blocked_until),
        "requests_per_window": rate_limiter.requests_per_window,
        "window_seconds": rate_limiter.window_seconds,
    }

    return {
        "rate_limiting": rate_limit_info,
        "monitoring": monitoring.get_stats(),
        "caching": caching.get_cache_stats(),
        "logging": request_logging.get_log_summary(),
    }