135 lines
5.2 KiB
Python
135 lines
5.2 KiB
Python
"""
Rate Limiting Module

Handles rate limiting middleware and request tracking.
"""
|
|
|
|
import time
from collections import defaultdict
from typing import Any, Dict, List, Optional

from fastapi import Request, Response
from fastapi.responses import JSONResponse
from loguru import logger
|
|
|
|
|
|
class RateLimiter:
    """Manages rate limiting for ALwrity backend.

    Request timestamps are kept per client IP in ``request_counts``. A client
    is limited once it has ``max_requests`` timestamps younger than
    ``window_seconds``. Paths matching ``exempt_paths`` / ``exempt_prefixes``
    bypass the limiter entirely and are not counted.
    """

    def __init__(self, window_seconds: int = 60, max_requests: int = 1000):  # Increased for development
        """Create a limiter.

        Args:
            window_seconds: Length of the counting window, in seconds.
            max_requests: Number of requests allowed per client per window.
        """
        self.window_seconds = window_seconds
        self.max_requests = max_requests
        # client IP -> timestamps (time.time() values) of counted requests.
        self.request_counts: Dict[str, List[float]] = defaultdict(list)

        # Endpoints exempt from rate limiting
        self.exempt_paths = [
            "/stream/strategies",
            "/stream/strategic-intelligence",
            "/stream/keyword-research",
            "/latest-strategy",
            "/ai-analytics",
            "/gap-analysis",
            "/calendar-events",
            # Research endpoints - exempt from rate limiting
            "/api/research",
            "/api/blog-writer",
            "/api/blog-writer/research",
            "/api/blog-writer/research/",
            "/api/blog/research/status",
            "/calendar-generation/progress",
            "/health",
            "/health/database",
        ]
        # Prefixes to exempt entire route families (keep empty; rely on specific exemptions only)
        self.exempt_prefixes = []

    def is_exempt_path(self, path: str) -> bool:
        """Check if a path is exempt from rate limiting.

        Args:
            path: The request URL path.

        Returns:
            True when any entry of ``exempt_paths`` occurs anywhere inside
            ``path`` or ``path`` starts with any entry of ``exempt_prefixes``.
        """
        # NOTE(review): this is a substring match, so "/health" also exempts
        # any path that merely contains "/health". Kept as-is because routes
        # may be mounted under a prefix; confirm that this is intended.
        if any(exempt_path in path for exempt_path in self.exempt_paths):
            return True
        # str.startswith accepts a tuple; an empty tuple never matches.
        return path.startswith(tuple(self.exempt_prefixes))

    def clean_old_requests(self, client_ip: str, current_time: float) -> None:
        """Clean old requests from the tracking dictionary.

        Timestamps that are ``window_seconds`` or more older than
        ``current_time`` are dropped. A client left with no timestamps is
        removed from ``request_counts`` so the mapping does not accumulate
        empty entries.

        Args:
            client_ip: Key of the client to prune.
            current_time: Reference time, in the same units as the stored
                timestamps.
        """
        # Use .get so pruning an unknown client does not create an entry.
        timestamps = self.request_counts.get(client_ip)
        if timestamps is None:
            return

        recent = [
            req_time for req_time in timestamps
            if current_time - req_time < self.window_seconds
        ]
        if recent:
            self.request_counts[client_ip] = recent
        else:
            del self.request_counts[client_ip]

    def is_rate_limited(self, client_ip: str, current_time: float) -> bool:
        """Check if a client has exceeded the rate limit.

        Prunes the client's expired timestamps first.

        Args:
            client_ip: Key of the client to check.
            current_time: Reference time used for pruning.

        Returns:
            True when the client already has ``max_requests`` or more
            timestamps inside the window.
        """
        self.clean_old_requests(client_ip, current_time)
        # .get avoids the defaultdict inserting an empty list for this client.
        return len(self.request_counts.get(client_ip, ())) >= self.max_requests

    def add_request(self, client_ip: str, current_time: float) -> None:
        """Add a request to the tracking dictionary.

        Args:
            client_ip: Key of the client that made the request.
            current_time: Timestamp to record for the request.
        """
        self.request_counts[client_ip].append(current_time)

    def get_rate_limit_response(self) -> JSONResponse:
        """Get a rate limit exceeded response.

        Returns:
            A 429 JSON response whose body carries ``detail`` and
            ``retry_after`` (the window length), with a matching
            ``Retry-After`` header and permissive CORS headers.
        """
        return JSONResponse(
            status_code=429,
            content={
                "detail": "Too many requests",
                "retry_after": self.window_seconds,
            },
            headers={
                # Standard header so HTTP clients can back off without
                # parsing the JSON body.
                "Retry-After": str(self.window_seconds),
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": "*",
                "Access-Control-Allow-Headers": "*",
            },
        )

    async def rate_limit_middleware(self, request: Request, call_next) -> Response:
        """Rate limiting middleware with exemptions for streaming endpoints.

        Exempt paths are passed straight through and not counted. Other
        requests are rejected with the 429 response when the client is over
        the limit, otherwise recorded and passed through.

        Only the rate-limiting bookkeeping is guarded: if it fails, the error
        is logged and the request proceeds unlimited. ``call_next`` is invoked
        exactly once and exceptions it raises propagate to the caller instead
        of being logged as limiter errors and triggering a second dispatch.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that forwards the request.

        Returns:
            The downstream response, or the 429 response when limited.
        """
        try:
            client_ip = request.client.host if request.client else "unknown"
            current_time = time.time()
            path = request.url.path

            # Check if path is exempt from rate limiting
            if not self.is_exempt_path(path):
                # Check rate limit
                if self.is_rate_limited(client_ip, current_time):
                    logger.warning(f"Rate limit exceeded for {client_ip}")
                    return self.get_rate_limit_response()

                # Add current request
                self.add_request(client_ip, current_time)
        except Exception as e:
            # Continue without rate limiting if there's an error
            logger.error(f"Error in rate limiting middleware: {e}")

        return await call_next(request)

    def get_rate_limit_status(self, client_ip: str) -> Dict[str, Any]:
        """Get current rate limit status for a client.

        Prunes the client's expired timestamps using the current wall-clock
        time. Querying a client that has no recorded requests does not add it
        to ``request_counts``.

        Args:
            client_ip: Key of the client to report on.

        Returns:
            A dict with ``client_ip``, ``requests_in_window``,
            ``max_requests``, ``remaining_requests``, ``window_seconds`` and
            ``is_limited``.
        """
        current_time = time.time()
        self.clean_old_requests(client_ip, current_time)

        request_count = len(self.request_counts.get(client_ip, ()))
        remaining_requests = max(0, self.max_requests - request_count)

        return {
            "client_ip": client_ip,
            "requests_in_window": request_count,
            "max_requests": self.max_requests,
            "remaining_requests": remaining_requests,
            "window_seconds": self.window_seconds,
            "is_limited": request_count >= self.max_requests,
        }

    def reset_rate_limit(self, client_ip: Optional[str] = None) -> Dict[str, Any]:
        """Reset rate limit for a specific client or all clients.

        Args:
            client_ip: Client to reset. Any falsy value (``None`` or an empty
                string) resets every client.

        Returns:
            A dict with a human-readable ``message`` describing what was reset.
        """
        if client_ip:
            # Remove the entry outright rather than leaving an empty list.
            self.request_counts.pop(client_ip, None)
            return {"message": f"Rate limit reset for {client_ip}"}

        self.request_counts.clear()
        return {"message": "Rate limit reset for all clients"}
|