From 3a92c4af1a74c1d6c914e4b1f7942b8565d275c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D9=8A?= Date: Mon, 30 Mar 2026 08:09:28 +0530 Subject: [PATCH] Use tenant sessions for API key context and add startup key readiness check --- backend/services/startup_health.py | 59 ++++++++++++++++++++++++ backend/services/user_api_key_context.py | 8 ++-- 2 files changed, 64 insertions(+), 3 deletions(-) diff --git a/backend/services/startup_health.py b/backend/services/startup_health.py index 7ae07dcb..7efd93f6 100644 --- a/backend/services/startup_health.py +++ b/backend/services/startup_health.py @@ -15,6 +15,7 @@ from services.database import ( init_database, default_engine, ) +from services.user_api_key_context import get_user_api_keys _REQUIRED_SCHEMA: Dict[str, List[str]] = { "onboarding_sessions": ["id", "user_id", "updated_at"], @@ -144,6 +145,62 @@ def _check_db_access(checks: List[Dict[str, Any]], errors: List[str], warnings: return candidate_user +def _check_production_api_key_loading( + checks: List[Dict[str, Any]], + errors: List[str], + warnings: List[str], +) -> None: + deploy_env = os.getenv("DEPLOY_ENV", "local").strip().lower() + if deploy_env == "local": + _record_check(checks, "production_api_key_loading", True, "skipped in local deploy mode") + return + + test_tenant_id = os.getenv("ALWRITY_STARTUP_TEST_TENANT_ID", "").strip() + if not test_tenant_id: + message = ( + "Missing ALWRITY_STARTUP_TEST_TENANT_ID for production API key startup check." + ) + errors.append(message) + _record_check(checks, "production_api_key_loading", False, message) + return + + try: + keys = get_user_api_keys(test_tenant_id) + except Exception as exc: + errors.append( + f"Failed to load API keys for startup test tenant '{test_tenant_id}': {exc}" + ) + _record_check(checks, "production_api_key_loading", False, str(exc)) + return + + if not isinstance(keys, dict): + errors.append( + f"API key loader returned invalid payload type for startup test tenant '{test_tenant_id}'." + ) + _record_check(checks, "production_api_key_loading", False, "invalid payload type") + return + + non_empty_keys = [provider for provider, value in keys.items() if value] + if not non_empty_keys: + errors.append( + f"No API keys could be loaded for startup test tenant '{test_tenant_id}'." + ) + _record_check(checks, "production_api_key_loading", False, "no non-empty keys loaded") + return + + warning = None + if len(non_empty_keys) < len(keys): + warning = ( + f"Startup test tenant '{test_tenant_id}' has {len(non_empty_keys)}/{len(keys)} non-empty API keys." + ) + warnings.append(warning) + + detail = f"loaded {len(non_empty_keys)} non-empty keys for tenant {test_tenant_id}" + if warning: + detail = f"{detail}; {warning}" + _record_check(checks, "production_api_key_loading", True, detail) + + def run_startup_health_routine() -> Dict[str, Any]: checks: List[Dict[str, Any]] = [] errors: List[str] = [] @@ -152,6 +209,8 @@ def run_startup_health_routine() -> Dict[str, Any]: _check_workspace_root(checks, errors) if not errors: _check_db_access(checks, errors, warnings) + if not errors: + _check_production_api_key_loading(checks, errors, warnings) status = "healthy" if not errors else "failed" report = { diff --git a/backend/services/user_api_key_context.py b/backend/services/user_api_key_context.py index 7da02923..4b35feaf 100644 --- a/backend/services/user_api_key_context.py +++ b/backend/services/user_api_key_context.py @@ -71,10 +71,13 @@ class UserAPIKeyContext: """Load API keys from database for specific user.""" try: from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService - from services.database import SessionLocal + from services.database import get_session_for_user integration_service = OnboardingDataIntegrationService() - db = SessionLocal() + db = get_session_for_user(user_id) + if not db: + logger.error(f"Failed to create DB session for user {user_id}") + return {} try: integrated_data = integration_service.get_integrated_data_sync(user_id, db) keys = integrated_data.get('api_keys_data', {}) @@ -153,4 +156,3 @@ def get_tavily_key(user_id: Optional[str] = None) -> Optional[str]: def get_copilotkit_key(user_id: Optional[str] = None) -> Optional[str]: """Get CopilotKit API key for user.""" return UserAPIKeyContext.get_user_key(user_id, 'copilotkit') -