Compare commits

..

1 Commits

Author SHA1 Message Date
ي
11966cf341 Adjust missing API-key logging in injection middleware 2026-03-31 07:33:42 +05:30
3 changed files with 58 additions and 66 deletions

View File

@@ -8,6 +8,7 @@ IMPORTANT: This is a compatibility layer. For new code, use UserAPIKeyContext di
""" """
import os import os
import time
from fastapi import Request from fastapi import Request
from loguru import logger from loguru import logger
from typing import Callable from typing import Callable
@@ -20,8 +21,61 @@ class APIKeyInjectionMiddleware:
for the duration of each request. for the duration of each request.
""" """
# Shared across middleware instances (module currently instantiates per request)
_missing_keys_log_timestamps = {}
def __init__(self): def __init__(self):
self.original_keys = {} self.original_keys = {}
@staticmethod
def _should_skip_missing_key_warning(request: Request) -> bool:
"""
Optionally suppress missing-key warnings for non-AI/internal routes.
Controlled by API_KEY_INJECTION_SKIP_NON_AI_WARNINGS (default: true).
"""
skip_non_ai_warnings = os.getenv('API_KEY_INJECTION_SKIP_NON_AI_WARNINGS', 'true').lower() in ('1', 'true', 'yes')
if not skip_non_ai_warnings:
return False
path_lower = (request.url.path or '').lower()
return (
path_lower.startswith('/api/subscription/')
or path_lower.startswith('/api/onboarding/')
or path_lower.endswith('/status')
or path_lower.endswith('/health')
or path_lower == '/health'
or path_lower == '/status'
)
def _log_missing_keys_non_blocking(self, request: Request, user_id: str) -> None:
"""
Log missing API keys without interrupting request flow.
- Defaults to debug-level logging.
- Optional warn once-per-user-per-interval via env:
API_KEY_INJECTION_MISSING_KEYS_LOG_MODE=warn_once
API_KEY_INJECTION_MISSING_KEYS_LOG_INTERVAL_SECONDS=900
"""
try:
if self._should_skip_missing_key_warning(request):
logger.debug(f"[API Key Injection] Missing keys for user {user_id} on non-AI route; skipping warning")
return
log_mode = os.getenv('API_KEY_INJECTION_MISSING_KEYS_LOG_MODE', 'debug').lower()
if log_mode != 'warn_once':
logger.debug(f"No API keys found for user {user_id}")
return
interval_seconds = int(os.getenv('API_KEY_INJECTION_MISSING_KEYS_LOG_INTERVAL_SECONDS', '900'))
now = time.time()
last_logged_at = self._missing_keys_log_timestamps.get(user_id, 0)
if (now - last_logged_at) >= max(interval_seconds, 1):
logger.warning(f"No API keys found for user {user_id}")
self._missing_keys_log_timestamps[user_id] = now
else:
logger.debug(f"No API keys found for user {user_id} (warning suppressed by interval)")
except Exception as log_error:
# Logging should never block request processing
logger.debug(f"[API Key Injection] Failed to log missing keys state for user {user_id}: {log_error}")
async def __call__(self, request: Request, call_next: Callable): async def __call__(self, request: Request, call_next: Callable):
""" """
@@ -68,7 +122,7 @@ class APIKeyInjectionMiddleware:
# Get user-specific API keys from database # Get user-specific API keys from database
with user_api_keys(user_id) as user_keys: with user_api_keys(user_id) as user_keys:
if not user_keys: if not user_keys:
logger.warning(f"No API keys found for user {user_id}") self._log_missing_keys_non_blocking(request, user_id)
return await call_next(request) return await call_next(request)
# Save original environment values # Save original environment values
@@ -120,4 +174,3 @@ async def api_key_injection_middleware(request: Request, call_next: Callable):
""" """
middleware = APIKeyInjectionMiddleware() middleware = APIKeyInjectionMiddleware()
return await middleware(request, call_next) return await middleware(request, call_next)

View File

@@ -15,7 +15,6 @@ from services.database import (
init_database, init_database,
default_engine, default_engine,
) )
from services.user_api_key_context import get_user_api_keys
_REQUIRED_SCHEMA: Dict[str, List[str]] = { _REQUIRED_SCHEMA: Dict[str, List[str]] = {
"onboarding_sessions": ["id", "user_id", "updated_at"], "onboarding_sessions": ["id", "user_id", "updated_at"],
@@ -145,62 +144,6 @@ def _check_db_access(checks: List[Dict[str, Any]], errors: List[str], warnings:
return candidate_user return candidate_user
def _check_production_api_key_loading(
checks: List[Dict[str, Any]],
errors: List[str],
warnings: List[str],
) -> None:
deploy_env = os.getenv("DEPLOY_ENV", "local").strip().lower()
if deploy_env == "local":
_record_check(checks, "production_api_key_loading", True, "skipped in local deploy mode")
return
test_tenant_id = os.getenv("ALWRITY_STARTUP_TEST_TENANT_ID", "").strip()
if not test_tenant_id:
message = (
"Missing ALWRITY_STARTUP_TEST_TENANT_ID for production API key startup check."
)
errors.append(message)
_record_check(checks, "production_api_key_loading", False, message)
return
try:
keys = get_user_api_keys(test_tenant_id)
except Exception as exc:
errors.append(
f"Failed to load API keys for startup test tenant '{test_tenant_id}': {exc}"
)
_record_check(checks, "production_api_key_loading", False, str(exc))
return
if not isinstance(keys, dict):
errors.append(
f"API key loader returned invalid payload type for startup test tenant '{test_tenant_id}'."
)
_record_check(checks, "production_api_key_loading", False, "invalid payload type")
return
non_empty_keys = [provider for provider, value in keys.items() if value]
if not non_empty_keys:
errors.append(
f"No API keys could be loaded for startup test tenant '{test_tenant_id}'."
)
_record_check(checks, "production_api_key_loading", False, "no non-empty keys loaded")
return
warning = None
if len(non_empty_keys) < len(keys):
warning = (
f"Startup test tenant '{test_tenant_id}' has {len(non_empty_keys)}/{len(keys)} non-empty API keys."
)
warnings.append(warning)
detail = f"loaded {len(non_empty_keys)} non-empty keys for tenant {test_tenant_id}"
if warning:
detail = f"{detail}; {warning}"
_record_check(checks, "production_api_key_loading", True, detail)
def run_startup_health_routine() -> Dict[str, Any]: def run_startup_health_routine() -> Dict[str, Any]:
checks: List[Dict[str, Any]] = [] checks: List[Dict[str, Any]] = []
errors: List[str] = [] errors: List[str] = []
@@ -209,8 +152,6 @@ def run_startup_health_routine() -> Dict[str, Any]:
_check_workspace_root(checks, errors) _check_workspace_root(checks, errors)
if not errors: if not errors:
_check_db_access(checks, errors, warnings) _check_db_access(checks, errors, warnings)
if not errors:
_check_production_api_key_loading(checks, errors, warnings)
status = "healthy" if not errors else "failed" status = "healthy" if not errors else "failed"
report = { report = {

View File

@@ -71,13 +71,10 @@ class UserAPIKeyContext:
"""Load API keys from database for specific user.""" """Load API keys from database for specific user."""
try: try:
from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService
from services.database import get_session_for_user from services.database import SessionLocal
integration_service = OnboardingDataIntegrationService() integration_service = OnboardingDataIntegrationService()
db = get_session_for_user(user_id) db = SessionLocal()
if not db:
logger.error(f"Failed to create DB session for user {user_id}")
return {}
try: try:
integrated_data = integration_service.get_integrated_data_sync(user_id, db) integrated_data = integration_service.get_integrated_data_sync(user_id, db)
keys = integrated_data.get('api_keys_data', {}) keys = integrated_data.get('api_keys_data', {})
@@ -156,3 +153,4 @@ def get_tavily_key(user_id: Optional[str] = None) -> Optional[str]:
def get_copilotkit_key(user_id: Optional[str] = None) -> Optional[str]: def get_copilotkit_key(user_id: Optional[str] = None) -> Optional[str]:
"""Get CopilotKit API key for user.""" """Get CopilotKit API key for user."""
return UserAPIKeyContext.get_user_key(user_id, 'copilotkit') return UserAPIKeyContext.get_user_key(user_id, 'copilotkit')