Base code

This commit is contained in:
Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions

View File

@@ -0,0 +1,134 @@
"""
Rate Limiting Module
Handles rate limiting middleware and request tracking.
"""
import time
from collections import defaultdict
from typing import Dict, List, Optional
from fastapi import Request, Response
from fastapi.responses import JSONResponse
from loguru import logger
class RateLimiter:
"""Manages rate limiting for ALwrity backend."""
def __init__(self, window_seconds: int = 60, max_requests: int = 1000): # Increased for development
self.window_seconds = window_seconds
self.max_requests = max_requests
self.request_counts: Dict[str, List[float]] = defaultdict(list)
# Endpoints exempt from rate limiting
self.exempt_paths = [
"/stream/strategies",
"/stream/strategic-intelligence",
"/stream/keyword-research",
"/latest-strategy",
"/ai-analytics",
"/gap-analysis",
"/calendar-events",
# Research endpoints - exempt from rate limiting
"/api/research",
"/api/blog-writer",
"/api/blog-writer/research",
"/api/blog-writer/research/",
"/api/blog/research/status",
"/calendar-generation/progress",
"/health",
"/health/database",
]
# Prefixes to exempt entire route families (keep empty; rely on specific exemptions only)
self.exempt_prefixes = []
def is_exempt_path(self, path: str) -> bool:
"""Check if a path is exempt from rate limiting."""
return any(exempt_path == path or exempt_path in path for exempt_path in self.exempt_paths) or any(
path.startswith(prefix) for prefix in self.exempt_prefixes
)
def clean_old_requests(self, client_ip: str, current_time: float) -> None:
"""Clean old requests from the tracking dictionary."""
self.request_counts[client_ip] = [
req_time for req_time in self.request_counts[client_ip]
if current_time - req_time < self.window_seconds
]
def is_rate_limited(self, client_ip: str, current_time: float) -> bool:
"""Check if a client has exceeded the rate limit."""
self.clean_old_requests(client_ip, current_time)
return len(self.request_counts[client_ip]) >= self.max_requests
def add_request(self, client_ip: str, current_time: float) -> None:
"""Add a request to the tracking dictionary."""
self.request_counts[client_ip].append(current_time)
def get_rate_limit_response(self) -> JSONResponse:
"""Get a rate limit exceeded response."""
return JSONResponse(
status_code=429,
content={
"detail": "Too many requests",
"retry_after": self.window_seconds
},
headers={
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": "*",
"Access-Control-Allow-Headers": "*"
}
)
async def rate_limit_middleware(self, request: Request, call_next) -> Response:
"""Rate limiting middleware with exemptions for streaming endpoints."""
try:
client_ip = request.client.host if request.client else "unknown"
current_time = time.time()
path = request.url.path
# Check if path is exempt from rate limiting
if self.is_exempt_path(path):
response = await call_next(request)
return response
# Check rate limit
if self.is_rate_limited(client_ip, current_time):
logger.warning(f"Rate limit exceeded for {client_ip}")
return self.get_rate_limit_response()
# Add current request
self.add_request(client_ip, current_time)
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error in rate limiting middleware: {e}")
# Continue without rate limiting if there's an error
response = await call_next(request)
return response
def get_rate_limit_status(self, client_ip: str) -> Dict[str, any]:
"""Get current rate limit status for a client."""
current_time = time.time()
self.clean_old_requests(client_ip, current_time)
request_count = len(self.request_counts[client_ip])
remaining_requests = max(0, self.max_requests - request_count)
return {
"client_ip": client_ip,
"requests_in_window": request_count,
"max_requests": self.max_requests,
"remaining_requests": remaining_requests,
"window_seconds": self.window_seconds,
"is_limited": request_count >= self.max_requests
}
def reset_rate_limit(self, client_ip: Optional[str] = None) -> Dict[str, any]:
"""Reset rate limit for a specific client or all clients."""
if client_ip:
self.request_counts[client_ip] = []
return {"message": f"Rate limit reset for {client_ip}"}
else:
self.request_counts.clear()
return {"message": "Rate limit reset for all clients"}