Base code
This commit is contained in:
134
backend/alwrity_utils/rate_limiter.py
Normal file
134
backend/alwrity_utils/rate_limiter.py
Normal file
@@ -0,0 +1,134 @@
|
||||
"""
|
||||
Rate Limiting Module
|
||||
Handles rate limiting middleware and request tracking.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional
|
||||
from fastapi import Request, Response
|
||||
from fastapi.responses import JSONResponse
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Manages rate limiting for ALwrity backend."""
|
||||
|
||||
def __init__(self, window_seconds: int = 60, max_requests: int = 1000): # Increased for development
|
||||
self.window_seconds = window_seconds
|
||||
self.max_requests = max_requests
|
||||
self.request_counts: Dict[str, List[float]] = defaultdict(list)
|
||||
|
||||
# Endpoints exempt from rate limiting
|
||||
self.exempt_paths = [
|
||||
"/stream/strategies",
|
||||
"/stream/strategic-intelligence",
|
||||
"/stream/keyword-research",
|
||||
"/latest-strategy",
|
||||
"/ai-analytics",
|
||||
"/gap-analysis",
|
||||
"/calendar-events",
|
||||
# Research endpoints - exempt from rate limiting
|
||||
"/api/research",
|
||||
"/api/blog-writer",
|
||||
"/api/blog-writer/research",
|
||||
"/api/blog-writer/research/",
|
||||
"/api/blog/research/status",
|
||||
"/calendar-generation/progress",
|
||||
"/health",
|
||||
"/health/database",
|
||||
]
|
||||
# Prefixes to exempt entire route families (keep empty; rely on specific exemptions only)
|
||||
self.exempt_prefixes = []
|
||||
|
||||
def is_exempt_path(self, path: str) -> bool:
|
||||
"""Check if a path is exempt from rate limiting."""
|
||||
return any(exempt_path == path or exempt_path in path for exempt_path in self.exempt_paths) or any(
|
||||
path.startswith(prefix) for prefix in self.exempt_prefixes
|
||||
)
|
||||
|
||||
def clean_old_requests(self, client_ip: str, current_time: float) -> None:
|
||||
"""Clean old requests from the tracking dictionary."""
|
||||
self.request_counts[client_ip] = [
|
||||
req_time for req_time in self.request_counts[client_ip]
|
||||
if current_time - req_time < self.window_seconds
|
||||
]
|
||||
|
||||
def is_rate_limited(self, client_ip: str, current_time: float) -> bool:
|
||||
"""Check if a client has exceeded the rate limit."""
|
||||
self.clean_old_requests(client_ip, current_time)
|
||||
return len(self.request_counts[client_ip]) >= self.max_requests
|
||||
|
||||
def add_request(self, client_ip: str, current_time: float) -> None:
|
||||
"""Add a request to the tracking dictionary."""
|
||||
self.request_counts[client_ip].append(current_time)
|
||||
|
||||
def get_rate_limit_response(self) -> JSONResponse:
|
||||
"""Get a rate limit exceeded response."""
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"detail": "Too many requests",
|
||||
"retry_after": self.window_seconds
|
||||
},
|
||||
headers={
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": "*",
|
||||
"Access-Control-Allow-Headers": "*"
|
||||
}
|
||||
)
|
||||
|
||||
async def rate_limit_middleware(self, request: Request, call_next) -> Response:
|
||||
"""Rate limiting middleware with exemptions for streaming endpoints."""
|
||||
try:
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
current_time = time.time()
|
||||
path = request.url.path
|
||||
|
||||
# Check if path is exempt from rate limiting
|
||||
if self.is_exempt_path(path):
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
# Check rate limit
|
||||
if self.is_rate_limited(client_ip, current_time):
|
||||
logger.warning(f"Rate limit exceeded for {client_ip}")
|
||||
return self.get_rate_limit_response()
|
||||
|
||||
# Add current request
|
||||
self.add_request(client_ip, current_time)
|
||||
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in rate limiting middleware: {e}")
|
||||
# Continue without rate limiting if there's an error
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
def get_rate_limit_status(self, client_ip: str) -> Dict[str, any]:
|
||||
"""Get current rate limit status for a client."""
|
||||
current_time = time.time()
|
||||
self.clean_old_requests(client_ip, current_time)
|
||||
|
||||
request_count = len(self.request_counts[client_ip])
|
||||
remaining_requests = max(0, self.max_requests - request_count)
|
||||
|
||||
return {
|
||||
"client_ip": client_ip,
|
||||
"requests_in_window": request_count,
|
||||
"max_requests": self.max_requests,
|
||||
"remaining_requests": remaining_requests,
|
||||
"window_seconds": self.window_seconds,
|
||||
"is_limited": request_count >= self.max_requests
|
||||
}
|
||||
|
||||
def reset_rate_limit(self, client_ip: Optional[str] = None) -> Dict[str, any]:
|
||||
"""Reset rate limit for a specific client or all clients."""
|
||||
if client_ip:
|
||||
self.request_counts[client_ip] = []
|
||||
return {"message": f"Rate limit reset for {client_ip}"}
|
||||
else:
|
||||
self.request_counts.clear()
|
||||
return {"message": "Rate limit reset for all clients"}
|
||||
Reference in New Issue
Block a user