diff --git a/backend/api/onboarding_utils/endpoints_core.py b/backend/api/onboarding_utils/endpoints_core.py index 1cd1dafd..de3afa3e 100644 --- a/backend/api/onboarding_utils/endpoints_core.py +++ b/backend/api/onboarding_utils/endpoints_core.py @@ -45,16 +45,25 @@ async def initialize_onboarding(current_user: Dict[str, Any] = Depends(get_curre next_step = progress.get_next_incomplete_step() - # Derive a resilient current_step from DB if progress looks unset (production refresh) + # Derive a resilient current_step and is_completed from DB if file-based progress is absent/outdated derived_current_step = progress.current_step + derived_is_completed = progress.is_completed try: # Only derive if we're at the initial state - if not progress.is_completed and (progress.current_step in (1, 0)): + if (progress.current_step in (1, 0)) or not progress.is_completed: from services.onboarding_database_service import OnboardingDatabaseService from services.database import SessionLocal db = SessionLocal() try: db_service = OnboardingDatabaseService() + # If a DB session exists, prefer that state for completion + session_row = db_service.get_session_by_user(user_id, db) + if session_row: + # Trust explicit completion state from DB if available + if (getattr(session_row, 'current_step', 0) or 0) >= 6 or (getattr(session_row, 'progress', 0.0) or 0.0) >= 100.0: + derived_current_step = max(derived_current_step, 6) + derived_is_completed = True + # If website analysis exists -> at least step 2 completed website = db_service.get_website_analysis(user_id, db) if website and (website.get('website_url') or website.get('writing_style') or website.get('status') == 'completed'): @@ -63,10 +72,12 @@ async def initialize_onboarding(current_user: Dict[str, Any] = Depends(get_curre prefs = db_service.get_research_preferences(user_id, db) if prefs and (prefs.get('research_depth') or prefs.get('content_types')): derived_current_step = max(derived_current_step, 3) - # If persona data exists, bump to step 4 + # If persona data exists, bump to step 5 (personalization done) persona = db_service.get_persona_data(user_id, db) if persona and (persona.get('corePersona') or persona.get('platformPersonas')): - derived_current_step = max(derived_current_step, 4) + derived_current_step = max(derived_current_step, 5) + # If DB session did not explicitly mark completion but all major data exists, + # do not auto-complete; leave final step to the user. finally: db.close() except Exception: @@ -82,7 +93,7 @@ async def initialize_onboarding(current_user: Dict[str, Any] = Depends(get_curre "clerk_user_id": user_id, }, "onboarding": { - "is_completed": progress.is_completed, + "is_completed": derived_is_completed, "current_step": derived_current_step, "completion_percentage": progress.get_completion_percentage(), "next_step": next_step, diff --git a/backend/api/onboarding_utils/onboarding_completion_service.py b/backend/api/onboarding_utils/onboarding_completion_service.py index a2f01b7c..52c577e9 100644 --- a/backend/api/onboarding_utils/onboarding_completion_service.py +++ b/backend/api/onboarding_utils/onboarding_completion_service.py @@ -34,8 +34,8 @@ class OnboardingCompletionService: detail=f"Cannot complete onboarding. The following steps must be completed first: {missing_steps_str}" ) - # Validate API keys are configured - self._validate_api_keys() + # Validate API keys are configured (DB-aware) + self._validate_api_keys(user_id) # Generate writing persona from onboarding data only if not already present persona_generated = await self._generate_persona_from_onboarding(user_id) @@ -81,6 +81,15 @@ class OnboardingCompletionService: # DB-aware fallbacks for migration period try: if db_service: + if step_num == 1: + # Treat as completed if user has any API key in DB + keys = db_service.get_api_keys(user_id, db) + if keys and any(v for v in keys.values()): + try: + progress.mark_step_completed(1, {'source': 'db-fallback'}) + except Exception: + pass + continue if step_num == 2: # Treat as completed if website analysis exists in DB website = db_service.get_website_analysis(user_id, db) @@ -129,14 +138,52 @@ class OnboardingCompletionService: return missing_steps - def _validate_api_keys(self): - """Validate that API keys are configured.""" - api_manager = get_api_key_manager() - api_keys = api_manager.get_all_keys() - if not api_keys: + def _validate_api_keys(self, user_id: str): + """Validate that API keys are configured for the current user. + + Priority: + 1) Check database for per-user keys (production, user isolation) + 2) Fallback to in-memory/env keys via APIKeyManager (development/local) + """ + try: + # Prefer per-user DB keys in production + db = None + try: + db = next(get_db()) + db_service = OnboardingDatabaseService(db) + user_keys = db_service.get_api_keys(user_id, db) + if user_keys and any(v for v in user_keys.values()): + return + except Exception: + # DB lookup failed - continue to env fallback + pass + finally: + try: + if db and hasattr(db, 'close'): + db.close() + except Exception: + pass + + # Fallback to env/in-memory + api_manager = get_api_key_manager() + # Ensure latest env is loaded (middleware may have injected per-request keys) + try: + api_manager.load_api_keys() + except Exception: + pass + api_keys = api_manager.get_all_keys() + if not api_keys: + raise HTTPException( + status_code=400, + detail="Cannot complete onboarding. At least one AI provider API key must be configured." + ) + except HTTPException: + raise + except Exception: + # On unexpected error, fail closed with clear message raise HTTPException( status_code=400, - detail="Cannot complete onboarding. At least one AI provider API key must be configured." + detail="Cannot complete onboarding. API key validation failed." ) async def _generate_persona_from_onboarding(self, user_id: str) -> bool: