""" Onboarding Database Service Provides database-backed storage for onboarding progress with user isolation. This replaces the JSON file-based storage with proper database persistence. """ from typing import Dict, Any, Optional, List import os import json from datetime import datetime from loguru import logger from sqlalchemy.orm import Session from sqlalchemy.exc import SQLAlchemyError from sqlalchemy import text from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData from services.database import get_db class OnboardingDatabaseService: """Database service for onboarding with user isolation.""" def __init__(self, db: Session = None): """Initialize with optional database session.""" self.db = db # Cache for schema feature detection self._brand_cols_checked: bool = False self._brand_cols_available: bool = False self._research_persona_cols_checked: bool = False self._research_persona_cols_available: bool = False # --- Feature flags and schema detection helpers --- def _brand_feature_enabled(self) -> bool: """Check if writing brand-related columns is enabled via env flag.""" return os.getenv('ENABLE_WEBSITE_BRAND_COLUMNS', 'true').lower() in {'1', 'true', 'yes', 'on'} def _ensure_research_persona_columns(self, session_db: Session) -> None: """Ensure research_persona columns exist in persona_data table (runtime migration).""" if self._research_persona_cols_checked: return try: # Check if columns exist using PRAGMA (SQLite) or information_schema (PostgreSQL) db_url = str(session_db.bind.url) if session_db.bind else "" if 'sqlite' in db_url.lower(): # SQLite: Use PRAGMA to check columns result = session_db.execute(text("PRAGMA table_info(persona_data)")) cols = {row[1] for row in result} # Column name is at index 1 if 'research_persona' not in cols: logger.info("Adding missing column research_persona to persona_data table") session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona JSON")) session_db.commit() if 'research_persona_generated_at' not in cols: logger.info("Adding missing column research_persona_generated_at to persona_data table") session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona_generated_at TIMESTAMP")) session_db.commit() self._research_persona_cols_available = True else: # PostgreSQL: Try to query the columns (will fail if they don't exist) try: session_db.execute(text("SELECT research_persona, research_persona_generated_at FROM persona_data LIMIT 0")) self._research_persona_cols_available = True except Exception: # Columns don't exist, add them logger.info("Adding missing columns research_persona and research_persona_generated_at to persona_data table") try: session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona JSONB")) session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona_generated_at TIMESTAMP")) session_db.commit() self._research_persona_cols_available = True except Exception as alter_err: logger.error(f"Failed to add research_persona columns: {alter_err}") session_db.rollback() raise except Exception as e: logger.error(f"Error ensuring research_persona columns: {e}") session_db.rollback() raise finally: self._research_persona_cols_checked = True def _ensure_brand_column_detection(self, session_db: Session) -> None: """Detect at runtime whether brand columns exist and cache the result.""" if self._brand_cols_checked: return try: # This works across SQLite/Postgres; LIMIT 0 avoids scanning session_db.execute(text('SELECT 
            self._brand_cols_available = True
        except Exception:
            # Leave the session usable after a failed probe (otherwise the
            # transaction can be left in a pending-rollback state)
            session_db.rollback()
            self._brand_cols_available = False
        finally:
            self._brand_cols_checked = True

    def _maybe_update_brand_columns(
        self,
        session_db: Session,
        session_id: int,
        brand_analysis: Any,
        content_strategy_insights: Any,
    ) -> None:
        """Safely update brand columns using raw SQL if feature enabled and columns exist."""
        if not self._brand_feature_enabled():
            return
        self._ensure_brand_column_detection(session_db)
        if not self._brand_cols_available:
            return
        try:
            session_db.execute(
                text('''
                    UPDATE website_analyses
                    SET brand_analysis = :brand_analysis,
                        content_strategy_insights = :content_strategy_insights
                    WHERE session_id = :session_id
                '''),
                {
                    'brand_analysis': json.dumps(brand_analysis) if brand_analysis is not None else None,
                    'content_strategy_insights': json.dumps(content_strategy_insights) if content_strategy_insights is not None else None,
                    'session_id': session_id,
                }
            )
        except Exception as e:
            logger.warning(f"Skipped updating brand columns (not critical): {e}")

    def _maybe_attach_brand_columns(self, session_db: Session, session_id: int, result: Dict[str, Any]) -> None:
        """Optionally read brand columns and attach to result if available."""
        if not self._brand_feature_enabled():
            return
        self._ensure_brand_column_detection(session_db)
        if not self._brand_cols_available:
            return
        try:
            row = session_db.execute(
                text('''
                    SELECT brand_analysis, content_strategy_insights
                    FROM website_analyses
                    WHERE session_id = :session_id
                    LIMIT 1
                '''),
                {'session_id': session_id}
            ).mappings().first()
            if row:
                brand = row.get('brand_analysis')
                insights = row.get('content_strategy_insights')
                # If stored as TEXT in SQLite, try to parse JSON
                if isinstance(brand, str):
                    try:
                        brand = json.loads(brand)
                    except Exception:
                        pass
                if isinstance(insights, str):
                    try:
                        insights = json.loads(insights)
                    except Exception:
                        pass
                result['brand_analysis'] = brand
                result['content_strategy_insights'] = insights
        except Exception as e:
            logger.warning(f"Skipped reading brand columns (not critical): {e}")

    def get_or_create_session(self, user_id: str, db: Session = None) -> OnboardingSession:
        """Get existing onboarding session or create new one for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            # Try to get existing session for this user
            session = session_db.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
            if session:
                logger.info(f"Found existing onboarding session for user {user_id}")
                return session
            # Create new session
            session = OnboardingSession(
                user_id=user_id,
                current_step=1,
                progress=0.0,
                started_at=datetime.now()
            )
            session_db.add(session)
            session_db.commit()
            session_db.refresh(session)
            logger.info(f"Created new onboarding session for user {user_id}")
            return session
        except SQLAlchemyError as e:
            logger.error(f"Database error in get_or_create_session: {e}")
            session_db.rollback()
            raise

    def get_session_by_user(self, user_id: str, db: Session = None) -> Optional[OnboardingSession]:
        """Get onboarding session for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            return session_db.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
        except SQLAlchemyError as e:
            logger.error(f"Error getting session: {e}")
            return None

    def update_step(self, user_id: str, step_number: int, db: Session = None) -> bool:
        """Update current step for user."""
        session_db = db or self.db
        if not session_db:
raise ValueError("Database session required") try: session = self.get_or_create_session(user_id, session_db) session.current_step = step_number session.updated_at = datetime.now() session_db.commit() logger.info(f"Updated user {user_id} to step {step_number}") return True except SQLAlchemyError as e: logger.error(f"Error updating step: {e}") session_db.rollback() return False def update_progress(self, user_id: str, progress: float, db: Session = None) -> bool: """Update progress percentage for user.""" session_db = db or self.db if not session_db: raise ValueError("Database session required") try: session = self.get_or_create_session(user_id, session_db) session.progress = progress session.updated_at = datetime.now() session_db.commit() logger.info(f"Updated user {user_id} progress to {progress}%") return True except SQLAlchemyError as e: logger.error(f"Error updating progress: {e}") session_db.rollback() return False def save_api_key(self, user_id: str, provider: str, api_key: str, db: Session = None) -> bool: """Save API key for user with isolation.""" session_db = db or self.db if not session_db: raise ValueError("Database session required") try: # Get user's onboarding session session = self.get_or_create_session(user_id, session_db) # Check if key already exists for this provider and session existing_key = session_db.query(APIKey).filter( APIKey.session_id == session.id, APIKey.provider == provider ).first() if existing_key: # Update existing key existing_key.key = api_key existing_key.updated_at = datetime.now() logger.info(f"Updated {provider} API key for user {user_id}") else: # Create new key new_key = APIKey( session_id=session.id, provider=provider, key=api_key ) session_db.add(new_key) logger.info(f"Created new {provider} API key for user {user_id}") session_db.commit() return True except SQLAlchemyError as e: logger.error(f"Error saving API key: {e}") session_db.rollback() return False def get_api_keys(self, user_id: str, db: Session = None) -> Dict[str, str]: """Get all API keys for user.""" session_db = db or self.db if not session_db: raise ValueError("Database session required") try: session = self.get_session_by_user(user_id, session_db) if not session: return {} keys = session_db.query(APIKey).filter( APIKey.session_id == session.id ).all() return {key.provider: key.key for key in keys} except SQLAlchemyError as e: logger.error(f"Error getting API keys: {e}") return {} def save_website_analysis(self, user_id: str, analysis_data: Dict[str, Any], db: Session = None) -> bool: """Save website analysis for user.""" session_db = db or self.db if not session_db: raise ValueError("Database session required") try: session = self.get_or_create_session(user_id, session_db) # Normalize payload. Step 2 sometimes sends { website, analysis: {...} } # while DB expects flattened fields. Support both shapes. 
            incoming = analysis_data or {}
            nested = incoming.get('analysis') if isinstance(incoming.get('analysis'), dict) else None
            normalized = {
                'website_url': incoming.get('website') or incoming.get('website_url') or '',
                'writing_style': (nested or incoming).get('writing_style'),
                'content_characteristics': (nested or incoming).get('content_characteristics'),
                'target_audience': (nested or incoming).get('target_audience'),
                'content_type': (nested or incoming).get('content_type'),
                'recommended_settings': (nested or incoming).get('recommended_settings'),
                'brand_analysis': (nested or incoming).get('brand_analysis'),
                'content_strategy_insights': (nested or incoming).get('content_strategy_insights'),
                'crawl_result': (nested or incoming).get('crawl_result'),
                'style_patterns': (nested or incoming).get('style_patterns'),
                'style_guidelines': (nested or incoming).get('style_guidelines'),
                'status': (nested or incoming).get('status', incoming.get('status', 'completed')),
            }
            # Check if analysis already exists
            existing = session_db.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()
            if existing:
                # Update existing - only update website_url if normalized value is not empty.
                # This prevents overwriting a valid URL with an empty string when step.data
                # doesn't include the website field.
                normalized_url = (normalized.get('website_url') or '').strip()
                if normalized_url:
                    existing.website_url = normalized_url
                # If normalized_url is empty, keep existing.website_url unchanged
                existing.writing_style = normalized.get('writing_style')
                existing.content_characteristics = normalized.get('content_characteristics')
                existing.target_audience = normalized.get('target_audience')
                existing.content_type = normalized.get('content_type')
                existing.recommended_settings = normalized.get('recommended_settings')
                existing.crawl_result = normalized.get('crawl_result')
                existing.style_patterns = normalized.get('style_patterns')
                existing.style_guidelines = normalized.get('style_guidelines')
                existing.status = normalized.get('status', 'completed')
                existing.updated_at = datetime.now()
                logger.info(f"Updated website analysis for user {user_id}")
            else:
                # Create new
                analysis = WebsiteAnalysis(
                    session_id=session.id,
                    website_url=normalized.get('website_url', ''),
                    writing_style=normalized.get('writing_style'),
                    content_characteristics=normalized.get('content_characteristics'),
                    target_audience=normalized.get('target_audience'),
                    content_type=normalized.get('content_type'),
                    recommended_settings=normalized.get('recommended_settings'),
                    crawl_result=normalized.get('crawl_result'),
                    style_patterns=normalized.get('style_patterns'),
                    style_guidelines=normalized.get('style_guidelines'),
                    status=normalized.get('status', 'completed')
                )
                session_db.add(analysis)
                logger.info(f"Created website analysis for user {user_id}")
            session_db.commit()
            # Optional brand column update via raw SQL (feature-flagged)
            self._maybe_update_brand_columns(
                session_db=session_db,
                session_id=session.id,
                brand_analysis=normalized.get('brand_analysis'),
                content_strategy_insights=normalized.get('content_strategy_insights')
            )
            session_db.commit()
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error saving website analysis: {e}")
            session_db.rollback()
            return False

    def get_website_analysis(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
        """Get website analysis for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            analysis = session_db.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()
            result = analysis.to_dict() if analysis else None
            if result:
                # Optionally include brand fields without touching ORM mapping
                self._maybe_attach_brand_columns(session_db, session.id, result)
            return result
        except SQLAlchemyError as e:
            logger.error(f"Error getting website analysis: {e}")
            return None

    def save_research_preferences(self, user_id: str, preferences: Dict[str, Any], db: Session = None) -> bool:
        """Save research preferences for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            session = self.get_or_create_session(user_id, session_db)
            # Check if preferences already exist
            existing = session_db.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == session.id
            ).first()
            if existing:
                # Update existing
                existing.research_depth = preferences.get('research_depth', existing.research_depth)
                existing.content_types = preferences.get('content_types', existing.content_types)
                existing.auto_research = preferences.get('auto_research', existing.auto_research)
                existing.factual_content = preferences.get('factual_content', existing.factual_content)
                existing.writing_style = preferences.get('writing_style')
                existing.content_characteristics = preferences.get('content_characteristics')
                existing.target_audience = preferences.get('target_audience')
                existing.recommended_settings = preferences.get('recommended_settings')
                existing.updated_at = datetime.now()
                logger.info(f"Updated research preferences for user {user_id}")
            else:
                # Create new
                prefs = ResearchPreferences(
                    session_id=session.id,
                    research_depth=preferences.get('research_depth', 'standard'),
                    content_types=preferences.get('content_types', []),
                    auto_research=preferences.get('auto_research', True),
                    factual_content=preferences.get('factual_content', True),
                    writing_style=preferences.get('writing_style'),
                    content_characteristics=preferences.get('content_characteristics'),
                    target_audience=preferences.get('target_audience'),
                    recommended_settings=preferences.get('recommended_settings')
                )
                session_db.add(prefs)
                logger.info(f"Created research preferences for user {user_id}")
            session_db.commit()
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error saving research preferences: {e}")
            session_db.rollback()
            return False

    def save_persona_data(self, user_id: str, persona_data: Dict[str, Any], db: Session = None) -> bool:
        """Save persona data for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            session = self.get_or_create_session(user_id, session_db)
            # Check if persona data already exists for this user
            existing = session_db.query(PersonaData).filter(
                PersonaData.session_id == session.id
            ).first()
            if existing:
                # Update existing persona data
                existing.core_persona = persona_data.get('corePersona')
                existing.platform_personas = persona_data.get('platformPersonas')
                existing.quality_metrics = persona_data.get('qualityMetrics')
                existing.selected_platforms = persona_data.get('selectedPlatforms', [])
                existing.updated_at = datetime.now()
                logger.info(f"Updated persona data for user {user_id}")
            else:
                # Create new persona data record
                persona = PersonaData(
                    session_id=session.id,
                    core_persona=persona_data.get('corePersona'),
                    platform_personas=persona_data.get('platformPersonas'),
                    quality_metrics=persona_data.get('qualityMetrics'),
                    selected_platforms=persona_data.get('selectedPlatforms', [])
                )
                session_db.add(persona)
                logger.info(f"Created persona data for user {user_id}")
            session_db.commit()
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error saving persona data: {e}")
            session_db.rollback()
            return False

    def get_research_preferences(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
        """Get research preferences for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            prefs = session_db.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == session.id
            ).first()
            return prefs.to_dict() if prefs else None
        except SQLAlchemyError as e:
            logger.error(f"Error getting research preferences: {e}")
            return None

    def get_competitor_analysis(self, user_id: str, db: Session = None) -> Optional[List[Dict[str, Any]]]:
        """Get competitor analysis data for user from onboarding."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            from models.onboarding import CompetitorAnalysis
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            # Query CompetitorAnalysis table
            competitor_records = session_db.query(CompetitorAnalysis).filter(
                CompetitorAnalysis.session_id == session.id
            ).all()
            if not competitor_records:
                return None
            # Convert to list of dicts
            competitors = []
            for record in competitor_records:
                analysis_data = record.analysis_data or {}
                competitors.append({
                    "url": record.competitor_url,
                    "domain": record.competitor_domain or record.competitor_url,
                    "title": analysis_data.get("title", record.competitor_domain or ""),
                    "summary": analysis_data.get("summary", ""),
                    "relevance_score": analysis_data.get("relevance_score", 0.5),
                    "highlights": analysis_data.get("highlights", []),
                    "favicon": analysis_data.get("favicon"),
                    "image": analysis_data.get("image"),
                    "published_date": analysis_data.get("published_date"),
                    "author": analysis_data.get("author"),
                    "competitive_insights": analysis_data.get("competitive_analysis", {}),
                    "content_insights": analysis_data.get("content_insights", {})
                })
            return competitors
        except SQLAlchemyError as e:
            logger.error(f"Error getting competitor analysis: {e}")
            return None

    def get_persona_data(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
        """Get persona data for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        # Ensure research_persona columns exist before querying
        self._ensure_research_persona_columns(session_db)
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            persona = session_db.query(PersonaData).filter(
                PersonaData.session_id == session.id
            ).first()
            if not persona:
                return None
            # Return persona data in the expected format
            return {
                'corePersona': persona.core_persona,
                'platformPersonas': persona.platform_personas,
                'qualityMetrics': persona.quality_metrics,
                'selectedPlatforms': persona.selected_platforms
            }
        except SQLAlchemyError as e:
            logger.error(f"Error getting persona data: {e}")
            return None

    def mark_onboarding_complete(self, user_id: str, db: Session = None) -> bool:
        """Mark onboarding as complete for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")
        try:
            session = self.get_or_create_session(user_id, session_db)
            session.current_step = 6  # Final step
            session.progress = 100.0
            session.updated_at = datetime.now()
            session_db.commit()
            logger.info(f"Marked onboarding complete for user {user_id}")
for user {user_id}") return True except SQLAlchemyError as e: logger.error(f"Error marking onboarding complete: {e}") session_db.rollback() return False def get_onboarding_status(self, user_id: str, db: Session = None) -> Dict[str, Any]: """Get comprehensive onboarding status for user.""" session_db = db or self.db if not session_db: raise ValueError("Database session required") try: session = self.get_session_by_user(user_id, session_db) if not session: # User hasn't started onboarding yet return { "is_completed": False, "current_step": 1, "progress": 0.0, "started_at": None, "updated_at": None } return { "is_completed": session.current_step >= 6 and session.progress >= 100.0, "current_step": session.current_step, "progress": session.progress, "started_at": session.started_at.isoformat() if session.started_at else None, "updated_at": session.updated_at.isoformat() if session.updated_at else None } except SQLAlchemyError as e: logger.error(f"Error getting onboarding status: {e}") return { "is_completed": False, "current_step": 1, "progress": 0.0, "started_at": None, "updated_at": None }