"""
API Key Manager with Database-Only Onboarding Progress.

Manages API keys and onboarding progress with database persistence only.
Removed all file-based JSON storage for production scalability.
"""

import json
import os
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Optional

from loguru import logger

from services.database import get_db_session


class StepStatus(Enum):
    """Onboarding step status."""

    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    SKIPPED = "skipped"
    FAILED = "failed"


class StepData:
    """Data structure for a single onboarding step."""

    def __init__(self, step_number: int, title: str, description: str,
                 status: StepStatus = StepStatus.PENDING):
        self.step_number = step_number
        self.title = title
        self.description = description
        self.status = status
        self.completed_at = None  # ISO-8601 timestamp string once completed/skipped
        self.data = None  # payload passed to mark_step_completed / loaded from DB
        self.validation_errors = []  # messages appended by mark_step_failed


class OnboardingProgress:
    """Manages onboarding progress with database persistence only.

    Step state is held in memory and written through to the database via
    ``OnboardingDatabaseService`` on every mutation (``save_progress``).
    """

    # Statuses that count as "done" for completion and progress purposes.
    _FINISHED_STATUSES = (StepStatus.COMPLETED, StepStatus.SKIPPED)

    def __init__(self, user_id: Optional[str] = None):
        """Create progress state and, when ``user_id`` is given, load it from the DB.

        Raises:
            Exception: if the database service cannot be imported/constructed.
        """
        self.steps = self._initialize_steps()
        self.current_step = 1
        self.started_at = datetime.now().isoformat()
        self.last_updated = datetime.now().isoformat()
        self.is_completed = False
        self.completed_at = None
        self.user_id = user_id  # Add user_id for database isolation

        # Initialize database service for persistence
        try:
            from .database_service import OnboardingDatabaseService
            self.db_service = OnboardingDatabaseService()
            self.use_database = True
            logger.info(f"Database service initialized for user {user_id}")
        except Exception as e:
            logger.error(f"Database service not available: {e}")
            self.db_service = None
            self.use_database = False
            # Chain the cause so the original import/construction error is visible.
            raise Exception(f"Database service required but not available: {e}") from e

        # Load existing progress from database if available
        if self.use_database and self.user_id:
            self.load_progress_from_db()

    def _initialize_steps(self) -> List[StepData]:
        """Initialize the 6-step onboarding process."""
        return [
            StepData(1, "AI LLM Providers", "Configure AI language model providers", StepStatus.PENDING),
            StepData(2, "Website Analysis", "Set up website analysis and crawling", StepStatus.PENDING),
            StepData(3, "AI Research", "Configure AI research capabilities", StepStatus.PENDING),
            StepData(4, "Personalization", "Set up personalization features", StepStatus.PENDING),
            StepData(5, "Integrations", "Configure ALwrity integrations", StepStatus.PENDING),
            StepData(6, "Complete Setup", "Finalize and complete onboarding", StepStatus.PENDING),
        ]

    def get_step_data(self, step_number: int) -> Optional[StepData]:
        """Return the step with ``step_number``, or None if there is none."""
        for step in self.steps:
            if step.step_number == step_number:
                return step
        return None

    def mark_step_completed(self, step_number: int, data: Optional[Dict[str, Any]] = None):
        """Mark a step as completed, advance ``current_step`` and persist.

        If every step is now completed or skipped, the whole onboarding is
        marked complete and ``current_step`` is pinned to the last step.
        """
        logger.info(f"[mark_step_completed] Marking step {step_number} as completed")
        step = self.get_step_data(step_number)
        if not step:
            logger.error(f"[mark_step_completed] Step {step_number} not found")
            return

        step.status = StepStatus.COMPLETED
        step.completed_at = datetime.now().isoformat()
        step.data = data
        self.last_updated = datetime.now().isoformat()

        if all(s.status in self._FINISHED_STATUSES for s in self.steps):
            # If all steps are completed, mark onboarding as complete
            self.is_completed = True
            self.completed_at = datetime.now().isoformat()
            self.current_step = len(self.steps)  # Set to last step number
            logger.info("[mark_step_completed] All steps completed, marking onboarding as complete")
        else:
            # Advance, but never past the final step.
            self.current_step = min(step_number + 1, len(self.steps))

        logger.info(f"[mark_step_completed] Step {step_number} completed, new current_step: {self.current_step}, is_completed: {self.is_completed}")
        self.save_progress()
        logger.info(f"Step {step_number} marked as completed")

    def mark_step_in_progress(self, step_number: int):
        """Mark a step as in progress and make it the current step."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.IN_PROGRESS
            self.current_step = step_number
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} marked as in progress")
        else:
            logger.error(f"Step {step_number} not found")

    def mark_step_skipped(self, step_number: int):
        """Mark a step as skipped."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.SKIPPED
            step.completed_at = datetime.now().isoformat()
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} marked as skipped")
        else:
            logger.error(f"Step {step_number} not found")

    def mark_step_failed(self, step_number: int, error_message: str):
        """Mark a step as failed, recording ``error_message`` in its validation errors."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.FAILED
            step.validation_errors.append(error_message)
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.error(f"Step {step_number} marked as failed: {error_message}")
        else:
            logger.error(f"Step {step_number} not found")

    def get_progress_summary(self) -> Dict[str, Any]:
        """Return counts per status plus a percentage (completed + skipped)."""
        completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
        skipped_count = sum(1 for s in self.steps if s.status == StepStatus.SKIPPED)
        failed_count = sum(1 for s in self.steps if s.status == StepStatus.FAILED)
        return {
            "total_steps": len(self.steps),
            "completed_steps": completed_count,
            "skipped_steps": skipped_count,
            "failed_steps": failed_count,
            "current_step": self.current_step,
            "is_completed": self.is_completed,
            "progress_percentage": (completed_count + skipped_count) / len(self.steps) * 100,
        }

    def get_next_step(self) -> Optional[StepData]:
        """Return the first PENDING step, or None if none remain."""
        for step in self.steps:
            if step.status == StepStatus.PENDING:
                return step
        return None

    def get_completed_steps(self) -> List[StepData]:
        """Get all completed steps."""
        return [step for step in self.steps if step.status == StepStatus.COMPLETED]

    def get_failed_steps(self) -> List[StepData]:
        """Get all failed steps."""
        return [step for step in self.steps if step.status == StepStatus.FAILED]

    def reset_step(self, step_number: int):
        """Reset a step to pending status, clearing its data and errors."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.PENDING
            step.completed_at = None
            step.data = None
            step.validation_errors = []
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} reset to pending")
        else:
            logger.error(f"Step {step_number} not found")

    def reset_all_steps(self):
        """Reset all steps to pending status and clear overall completion."""
        for step in self.steps:
            step.status = StepStatus.PENDING
            step.completed_at = None
            step.data = None
            step.validation_errors = []
        self.current_step = 1
        self.is_completed = False
        self.completed_at = None
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info("All steps reset to pending")

    def complete_onboarding(self):
        """Mark onboarding as complete regardless of individual step states."""
        self.is_completed = True
        self.completed_at = datetime.now().isoformat()
        self.current_step = len(self.steps)
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info("Onboarding completed successfully")

    def save_progress(self):
        """Save progress (current step, percentage, per-step data) to the database.

        Logs and returns without saving when the DB service or user_id is
        missing. Any error raised while saving is logged and re-raised.
        """
        if not self.use_database or not self.db_service or not self.user_id:
            logger.error("Cannot save progress: database service not available or user_id not set")
            return

        try:
            from services.database import SessionLocal
            db = SessionLocal()
            try:
                # Update session progress
                self.db_service.update_step(self.user_id, self.current_step, db)

                # Skipped steps count as finished, matching get_progress_summary
                # and the completion check in mark_step_completed.
                finished_count = sum(1 for s in self.steps if s.status in self._FINISHED_STATUSES)
                progress_pct = (finished_count / len(self.steps)) * 100
                self.db_service.update_progress(self.user_id, progress_pct, db)

                # Save step-specific data to appropriate tables
                for step in self.steps:
                    if step.status == StepStatus.COMPLETED and step.data:
                        self._save_step_data(step, db)

                logger.info(f"Progress saved to database for user {self.user_id}")
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error saving progress to database: {str(e)}")
            raise

    def _save_step_data(self, step: StepData, db) -> None:
        """Dispatch one completed step's payload to its DB table (steps 1-4 only)."""
        if step.step_number == 1:  # API Keys
            self._save_api_keys(step.data.get('api_keys', {}), db)
        elif step.step_number == 2:  # Website Analysis
            analysis_for_db = self._website_analysis_for_db(step.data)
            self.db_service.save_website_analysis(self.user_id, analysis_for_db, db)
            logger.info(f"✅ DATABASE: Website analysis saved to database for user {self.user_id}")
        elif step.step_number == 3:  # Research Preferences
            self.db_service.save_research_preferences(self.user_id, step.data, db)
            logger.info(f"✅ DATABASE: Research preferences saved to database for user {self.user_id}")
        elif step.step_number == 4:  # Persona Generation
            self.db_service.save_persona_data(self.user_id, step.data, db)
            logger.info(f"✅ DATABASE: Persona data saved to database for user {self.user_id}")

    def _save_api_keys(self, api_keys: Dict[str, Any], db) -> None:
        """Save each non-empty provider key to the DB; in local mode also to the local key manager."""
        # In production, keys are fetched from the database per user; locally
        # (DEPLOY_ENV unset or 'local') they are also handed to APIKeyManager
        # for developer convenience.
        is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
        for provider, key in api_keys.items():
            if not key:
                continue
            # Save to database (for user isolation in production)
            self.db_service.save_api_key(self.user_id, provider, key, db)
            if is_local:
                try:
                    from services.api_key_manager import APIKeyManager
                    api_key_manager = APIKeyManager()
                    api_key_manager.save_api_key(provider, key)
                    logger.info(f"[LOCAL] API key for {provider} saved to .env file")
                except Exception as env_error:
                    # Best effort: the DB write above already succeeded.
                    logger.warning(f"[LOCAL] Failed to save {provider} API key to .env file: {env_error}")
            else:
                logger.info(f"[PRODUCTION] API key for {provider} saved to database only (user: {self.user_id})")
            logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")

    @staticmethod
    def _website_analysis_for_db(data: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Flatten the frontend website-analysis payload into the DB shape.

        Frontend sends ``{website: "url", analysis: {...}, ...}``; the database
        expects ``{website_url: "url", **analysis, ...}``. Other top-level keys
        are copied last (so they override same-named analysis keys) and
        ``status`` defaults to ``'completed'``.
        """
        analysis_for_db: Dict[str, Any] = {}
        if data:
            # Extract website_url from 'website' or 'website_url' field
            website_url = data.get('website') or data.get('website_url')
            if website_url:
                analysis_for_db['website_url'] = website_url

            # Flatten nested 'analysis' object if it exists
            nested = data.get('analysis') if 'analysis' in data else None
            if isinstance(nested, dict):
                analysis_for_db.update(nested)

            # Also include any other top-level fields
            for key, value in data.items():
                if key not in ('website', 'website_url', 'analysis'):
                    analysis_for_db[key] = value

        if 'status' not in analysis_for_db:
            analysis_for_db['status'] = 'completed'
        return analysis_for_db

    def load_progress_from_db(self):
        """Load session and step data from the database.

        Failures are logged and swallowed so the instance starts fresh.
        """
        if not self.use_database or not self.db_service or not self.user_id:
            logger.warning("Cannot load progress: database service not available or user_id not set")
            return

        try:
            from services.database import SessionLocal
            db = SessionLocal()
            try:
                session = self.db_service.get_session_by_user(self.user_id, db)
                if not session:
                    logger.info(f"No existing onboarding session found for user {self.user_id}, starting fresh")
                    return

                # Restore session data
                self.current_step = session.current_step or 1
                self.started_at = session.started_at.isoformat() if session.started_at else self.started_at
                self.last_updated = session.last_updated.isoformat() if session.last_updated else self.last_updated
                self.is_completed = session.is_completed or False
                self.completed_at = session.completed_at.isoformat() if session.completed_at else None

                # Load step-specific data from database
                self._load_step_data_from_db(db)

                # Fix any corrupted state
                self._fix_corrupted_state()

                logger.info(f"Progress loaded from database for user {self.user_id}")
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error loading progress from database: {str(e)}")
            # Don't fail if database loading fails - start fresh

    def _load_step_data_from_db(self, db):
        """Mark steps 1-4 COMPLETED when their DB table has data for this user.

        NOTE: ``completed_at`` is set to the load time, not the original
        completion time (the loaders used here do not expose it).
        """
        try:
            # (step number, DB loader, optional transform of the loaded value)
            loaders = (
                (1, self.db_service.get_api_keys, lambda value: {'api_keys': value}),
                (2, self.db_service.get_website_analysis, None),
                (3, self.db_service.get_research_preferences, None),
                (4, self.db_service.get_persona_data, None),
            )
            for step_number, load, wrap in loaders:
                loaded = load(self.user_id, db)
                if not loaded:
                    continue
                step = self.get_step_data(step_number)
                if step:
                    step.status = StepStatus.COMPLETED
                    step.data = wrap(loaded) if wrap else loaded
                    step.completed_at = datetime.now().isoformat()

            logger.info("Step data loaded from database")
        except Exception as e:
            logger.error(f"Error loading step data from database: {str(e)}")

    def _fix_corrupted_state(self):
        """Reconcile overall completion / current_step with per-step statuses."""
        all_steps_completed = all(s.status in self._FINISHED_STATUSES for s in self.steps)
        if all_steps_completed:
            self.is_completed = True
            self.completed_at = self.completed_at or datetime.now().isoformat()
            self.current_step = len(self.steps)
        else:
            # Point current_step at the first pending step, if any.
            for step in self.steps:
                if step.status == StepStatus.PENDING:
                    self.current_step = step.step_number
                    break


class APIKeyManager:
    """Manages API keys for different providers (process environment + in-memory map)."""

    # Environment variables scanned at construction time. Most follow the
    # "<PROVIDER>_API_KEY" pattern; HF_TOKEN does not.
    _ENV_VARS = (
        'GEMINI_API_KEY',
        'HF_TOKEN',
        'TAVILY_API_KEY',
        'SERPER_API_KEY',
        'METAPHOR_API_KEY',
        'FIRECRAWL_API_KEY',
        'STABILITY_API_KEY',
        'WAVESPEED_API_KEY',
    )

    def __init__(self):
        self.api_keys = {}
        self._load_from_env()

    def _load_from_env(self):
        """Load API keys from environment variables."""
        for env_var in self._ENV_VARS:
            key = os.getenv(env_var)
            if key:
                # Convert provider name to lowercase for consistency
                # ('GEMINI_API_KEY' -> 'gemini', 'HF_TOKEN' -> 'hf_token').
                provider_name = env_var.replace('_API_KEY', '').lower()
                self.api_keys[provider_name] = key
                logger.info(f"Loaded {provider_name} API key from environment")

    def _env_var_name(self, provider: str) -> str:
        """Map a provider name back to the environment variable it loads from.

        Without this, 'hf_token' would map to 'HF_TOKEN_API_KEY', which
        _load_from_env never reads.
        """
        candidate = provider.upper()
        if candidate in self._ENV_VARS:
            return candidate
        return f"{candidate}_API_KEY"

    def get_api_key(self, provider: str) -> Optional[str]:
        """Get API key for a provider."""
        return self.api_keys.get(provider.lower())

    def save_api_key(self, provider: str, api_key: str):
        """Save API key to the process environment and memory."""
        self.api_keys[provider.lower()] = api_key
        os.environ[self._env_var_name(provider)] = api_key
        logger.info(f"Saved {provider} API key")

    def has_api_key(self, provider: str) -> bool:
        """Check if a non-empty API key exists for provider."""
        return bool(self.api_keys.get(provider.lower()))

    def get_all_keys(self) -> Dict[str, str]:
        """Get a copy of all API keys."""
        return self.api_keys.copy()

    def remove_api_key(self, provider: str):
        """Remove API key for provider from memory and the process environment."""
        provider_lower = provider.lower()
        if provider_lower in self.api_keys:
            del self.api_keys[provider_lower]
            os.environ.pop(self._env_var_name(provider), None)
            logger.info(f"Removed {provider} API key")


# Global instances
# Keyed by the exact user id. The previous character-sanitized key let
# distinct ids (e.g. "a.b" and "a_b") share one user's progress object.
_user_onboarding_progress_cache: Dict[str, OnboardingProgress] = {}


def get_user_onboarding_progress(user_id: str) -> OnboardingProgress:
    """Get (creating and caching on first use) the progress instance for ``user_id``."""
    cache_key = str(user_id)
    instance = _user_onboarding_progress_cache.get(cache_key)
    if instance is None:
        # Pass user_id to enable database persistence
        instance = OnboardingProgress(user_id=user_id)
        _user_onboarding_progress_cache[cache_key] = instance
    return instance


def get_onboarding_progress_for_user(user_id: str) -> OnboardingProgress:
    """Get user-specific onboarding progress instance (alias for compatibility)."""
    return get_user_onboarding_progress(user_id)


def get_onboarding_progress():
    """Get the global (user-less) onboarding progress instance."""
    if not hasattr(get_onboarding_progress, '_instance'):
        get_onboarding_progress._instance = OnboardingProgress()
    return get_onboarding_progress._instance


def get_api_key_manager() -> APIKeyManager:
    """Get the global API key manager instance."""
    if not hasattr(get_api_key_manager, '_instance'):
        get_api_key_manager._instance = APIKeyManager()
    return get_api_key_manager._instance