Release Candidate: Production Release with Multi-Tenant & Onboarding Enhancements
This commit is contained in:
@@ -149,7 +149,7 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
try:
|
||||
path = request.url.path
|
||||
except Exception:
|
||||
pass
|
||||
path = ""
|
||||
|
||||
db = None
|
||||
try:
|
||||
@@ -159,8 +159,16 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
|
||||
# Safe User-Agent access
|
||||
user_agent = None
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
user_agent = request.headers.get('user-agent')
|
||||
except:
|
||||
pass
|
||||
|
||||
# Detect if this is an API call that should be rate limited
|
||||
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
|
||||
api_provider = api_monitor.detect_api_provider(path, user_agent)
|
||||
if not api_provider:
|
||||
return None
|
||||
|
||||
@@ -236,9 +244,28 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
user_id = None
|
||||
try:
|
||||
# PRIORITY 1: Check request.state.user_id (set by API key injection middleware)
|
||||
if hasattr(request.state, 'user_id') and request.state.user_id:
|
||||
user_id = request.state.user_id
|
||||
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
|
||||
if hasattr(request.state, 'user_id'):
|
||||
# Directly check and convert without accessing attribute if None
|
||||
raw_user_id = request.state.user_id
|
||||
|
||||
# Defensive check for Depends object or other complex types
|
||||
if raw_user_id is not None:
|
||||
# If it's a string, use it
|
||||
if isinstance(raw_user_id, str):
|
||||
user_id = raw_user_id
|
||||
# If it has a dependency attribute (likely a Depends object), ignore it
|
||||
elif hasattr(raw_user_id, 'dependency'):
|
||||
logger.warning(f"Monitoring: request.state.user_id is a Depends object, ignoring.")
|
||||
user_id = None
|
||||
# Try to convert to string if it's a simple type
|
||||
else:
|
||||
try:
|
||||
user_id = str(raw_user_id)
|
||||
except:
|
||||
user_id = None
|
||||
|
||||
if user_id:
|
||||
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
|
||||
|
||||
# PRIORITY 2: Check query parameters
|
||||
elif hasattr(request, 'query_params') and 'user_id' in request.query_params:
|
||||
@@ -247,20 +274,23 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
user_id = request.path_params['user_id']
|
||||
|
||||
# PRIORITY 3: Check headers for user identification
|
||||
elif 'x-user-id' in request.headers:
|
||||
user_id = request.headers['x-user-id']
|
||||
elif 'x-user-email' in request.headers:
|
||||
user_id = request.headers['x-user-email'] # Use email as user identifier
|
||||
elif 'x-session-id' in request.headers:
|
||||
user_id = request.headers['x-session-id'] # Use session as fallback
|
||||
|
||||
# Check for authorization header with user info
|
||||
elif 'authorization' in request.headers:
|
||||
# Auth middleware should have set request.state.user_id
|
||||
# If not, this indicates an authentication failure (likely expired token)
|
||||
# Log at debug level to reduce noise - expired tokens are expected
|
||||
# But we can try to decode token if we really needed to, but let's rely on auth middleware
|
||||
pass
|
||||
elif hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
try:
|
||||
if request.headers.get('x-user-id'):
|
||||
user_id = request.headers.get('x-user-id')
|
||||
elif request.headers.get('x-user-email'):
|
||||
user_id = request.headers.get('x-user-email')
|
||||
elif request.headers.get('x-session-id'):
|
||||
user_id = request.headers.get('x-session-id')
|
||||
|
||||
# Check for authorization header with user info
|
||||
elif request.headers.get('authorization'):
|
||||
# Auth middleware should have set request.state.user_id
|
||||
# If not, this indicates an authentication failure (likely expired token)
|
||||
# Log at debug level to reduce noise - expired tokens are expected
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Error accessing request headers: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting user ID: {e}")
|
||||
@@ -269,7 +299,11 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
# Get database session if user identified
|
||||
db = None
|
||||
if user_id:
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database session for user {user_id}: {e}")
|
||||
db = None
|
||||
|
||||
# Capture request body for usage tracking (read once, safely)
|
||||
request_body = None
|
||||
@@ -291,29 +325,52 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
request_body = None
|
||||
|
||||
# Check usage limits before processing
|
||||
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
|
||||
if limit_response:
|
||||
if db: db.close()
|
||||
return limit_response
|
||||
# Skip for OPTIONS requests
|
||||
try:
|
||||
if request.method != "OPTIONS":
|
||||
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
|
||||
if limit_response:
|
||||
if db: db.close()
|
||||
return limit_response
|
||||
except Exception as e:
|
||||
logger.error(f"Error in usage limits middleware: {e}")
|
||||
# Continue processing if usage check fails (fail open)
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Capture response body for usage tracking
|
||||
# Extract response body safely for usage tracking
|
||||
response_body = None
|
||||
try:
|
||||
if hasattr(response, 'body'):
|
||||
response_body = response.body.decode('utf-8') if response.body else None
|
||||
elif hasattr(response, '_content'):
|
||||
response_body = response._content.decode('utf-8') if response._content else None
|
||||
except:
|
||||
pass
|
||||
|
||||
if hasattr(response, 'body'):
|
||||
response_body = response.body.decode('utf-8') if response.body else None
|
||||
elif hasattr(response, '_content'):
|
||||
response_body = response._content.decode('utf-8') if response._content else None
|
||||
|
||||
# Track API usage if this is an API call to external providers
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
|
||||
|
||||
# Safe URL path access
|
||||
try:
|
||||
path = request.url.path
|
||||
except:
|
||||
path = ""
|
||||
|
||||
# Safe User-Agent access - handle case where headers might be a Depends object
|
||||
user_agent = None
|
||||
try:
|
||||
# Defensive check: ensure request.headers is a valid headers object
|
||||
# Some dependency injection failures replace request attributes with Depends objects
|
||||
if hasattr(request, 'headers'):
|
||||
headers_obj = request.headers
|
||||
# Check if it has a 'get' method (like a dict or Headers object)
|
||||
if hasattr(headers_obj, 'get') and callable(headers_obj.get):
|
||||
user_agent = headers_obj.get('user-agent')
|
||||
except:
|
||||
pass
|
||||
|
||||
api_provider = api_monitor.detect_api_provider(path, user_agent)
|
||||
if api_provider and user_id:
|
||||
logger.info(f"Detected API call: {request.url.path} -> {api_provider.value} for user: {user_id}")
|
||||
try:
|
||||
@@ -326,7 +383,7 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
await usage_service.track_api_usage(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=request.url.path,
|
||||
endpoint=path,
|
||||
method=request.method,
|
||||
model_used=usage_metrics.get('model_used'),
|
||||
tokens_input=usage_metrics.get('tokens_input', 0),
|
||||
@@ -335,7 +392,7 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
status_code=status_code,
|
||||
request_size=len(request_body) if request_body else None,
|
||||
response_size=len(response_body) if response_body else None,
|
||||
user_agent=request.headers.get('user-agent'),
|
||||
user_agent=user_agent,
|
||||
ip_address=request.client.host if request.client else None,
|
||||
search_count=usage_metrics.get('search_count', 0),
|
||||
image_count=usage_metrics.get('image_count', 0),
|
||||
|
||||
Reference in New Issue
Block a user