# File: moreminimore-marketing/backend/services/onboarding/database_service.py
# Commit: c35fa52117 "Base code" by Kunthawat Greethong, 2026-01-08 22:39:53 +07:00
# Size: 667 lines, 29 KiB, Python

"""
Onboarding Database Service
Provides database-backed storage for onboarding progress with user isolation.
This replaces the JSON file-based storage with proper database persistence.
"""
from typing import Dict, Any, Optional, List
import os
import json
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import text
from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData
from services.database import get_db
class OnboardingDatabaseService:
    """Database service for onboarding with user isolation.

    Every public method takes a ``user_id`` and resolves it to that user's
    ``OnboardingSession`` row; related records (API keys, website analysis,
    research preferences, persona data, competitor analysis) are then looked
    up by that session's id.

    A SQLAlchemy session may be supplied at construction time or per call via
    the ``db`` argument; the per-call session wins. Public methods raise
    ``ValueError("Database session required")`` when neither is available.

    Error contract, as implemented below: ``save_*`` / ``update_*`` /
    ``mark_*`` return ``False`` on ``SQLAlchemyError``; ``get_*`` return
    ``None`` (or an empty/default value); ``get_or_create_session`` re-raises.
    The session is rolled back in every one of those cases so that it stays
    usable for the caller.
    """

    # Step number that represents a finished onboarding flow.
    FINAL_STEP = 6

    # Payload keys copied 1:1 onto the WebsiteAnalysis ORM object.
    _ANALYSIS_ORM_FIELDS = (
        'writing_style',
        'content_characteristics',
        'target_audience',
        'content_type',
        'recommended_settings',
        'crawl_result',
        'style_patterns',
        'style_guidelines',
    )

    # Payload keys written through raw SQL only (see _maybe_update_brand_columns).
    _ANALYSIS_BRAND_FIELDS = (
        'brand_analysis',
        'content_strategy_insights',
    )

    def __init__(self, db: Optional[Session] = None):
        """Initialize with optional database session.

        Args:
            db: Default session used when a method is called without ``db``.
        """
        self.db = db
        # Cache for schema feature detection (one probe per service instance).
        self._brand_cols_checked: bool = False
        self._brand_cols_available: bool = False
        self._research_persona_cols_checked: bool = False
        self._research_persona_cols_available: bool = False

    # --- Internal helpers ---

    def _resolve_db(self, db: Optional[Session]) -> Session:
        """Return the session to use for a call.

        Args:
            db: Per-call session, or ``None`` to fall back to ``self.db``.

        Raises:
            ValueError: If neither a per-call nor a default session exists.
        """
        session_db = db if db is not None else self.db
        if session_db is None:
            raise ValueError("Database session required")
        return session_db

    @staticmethod
    def _safe_rollback(session_db: Session) -> None:
        """Roll back ``session_db`` without letting a rollback failure escape.

        Used on error paths where the original error (or a fallback return
        value) matters more than a secondary failure while cleaning up.
        """
        try:
            session_db.rollback()
        except Exception as rollback_err:
            logger.warning(f"Rollback failed: {rollback_err}")

    @staticmethod
    def _decode_json_text(value: Any) -> Any:
        """Parse ``value`` as JSON when it is a string; otherwise return it as is.

        Strings that are not valid JSON are returned unchanged.
        """
        if not isinstance(value, str):
            return value
        try:
            return json.loads(value)
        except (ValueError, TypeError):
            return value

    @staticmethod
    def _default_status() -> Dict[str, Any]:
        """Return the status payload used when no onboarding data is available."""
        return {
            "is_completed": False,
            "current_step": 1,
            "progress": 0.0,
            "started_at": None,
            "updated_at": None
        }

    @classmethod
    def _normalize_website_analysis(cls, analysis_data: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Flatten a website-analysis payload into one dict of known keys.

        Step 2 sometimes sends ``{website, analysis: {...}}`` while the DB
        expects flattened fields, so both shapes are supported: when
        ``analysis`` is a non-empty dict its values are used, otherwise the
        top-level payload is.

        Returns:
            Dict with ``website_url`` (whitespace-stripped when it is a
            string, ``''`` when absent), every key in ``_ANALYSIS_ORM_FIELDS``
            and ``_ANALYSIS_BRAND_FIELDS`` (``None`` when absent), and
            ``status`` (defaulting to ``'completed'``).
        """
        incoming = analysis_data or {}
        nested = incoming.get('analysis')
        source = nested if isinstance(nested, dict) and nested else incoming

        website_url = incoming.get('website') or incoming.get('website_url') or ''
        if isinstance(website_url, str):
            website_url = website_url.strip()

        normalized: Dict[str, Any] = {'website_url': website_url}
        for field in cls._ANALYSIS_ORM_FIELDS + cls._ANALYSIS_BRAND_FIELDS:
            normalized[field] = source.get(field)
        normalized['status'] = source.get('status', incoming.get('status', 'completed'))
        return normalized

    # --- Feature flags and schema detection helpers ---

    def _brand_feature_enabled(self) -> bool:
        """Check if writing brand-related columns is enabled via env flag.

        Reads ``ENABLE_WEBSITE_BRAND_COLUMNS`` (default ``'true'``); the values
        ``1``, ``true``, ``yes`` and ``on`` (case-insensitive) enable it.
        """
        return os.getenv('ENABLE_WEBSITE_BRAND_COLUMNS', 'true').lower() in {'1', 'true', 'yes', 'on'}

    def _ensure_research_persona_columns(self, session_db: Session) -> None:
        """Ensure research_persona columns exist in persona_data table (runtime migration).

        The check runs at most once per service instance, whether or not it
        succeeds.

        Raises:
            Exception: Any error raised while inspecting or altering the
                table is logged, the session is rolled back, and the error
                is re-raised.
        """
        if self._research_persona_cols_checked:
            return
        try:
            # Decide by dialect name rather than by searching the URL text, so
            # a URL that merely contains "sqlite" (e.g. in a database name)
            # is not mistaken for a SQLite backend.
            bind = session_db.bind
            dialect_name = bind.dialect.name if bind is not None else ""
            if dialect_name == 'sqlite':
                # SQLite: Use PRAGMA to check columns
                result = session_db.execute(text("PRAGMA table_info(persona_data)"))
                cols = {row[1] for row in result}  # Column name is at index 1
                if 'research_persona' not in cols:
                    logger.info("Adding missing column research_persona to persona_data table")
                    session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona JSON"))
                    session_db.commit()
                if 'research_persona_generated_at' not in cols:
                    logger.info("Adding missing column research_persona_generated_at to persona_data table")
                    session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona_generated_at TIMESTAMP"))
                    session_db.commit()
                self._research_persona_cols_available = True
            else:
                # PostgreSQL: Try to query the columns (will fail if they don't exist)
                try:
                    session_db.execute(text("SELECT research_persona, research_persona_generated_at FROM persona_data LIMIT 0"))
                    self._research_persona_cols_available = True
                except Exception:
                    # The failed SELECT leaves the PostgreSQL transaction
                    # aborted; without a rollback every following statement,
                    # including the ALTER TABLE below, is rejected.
                    session_db.rollback()
                    logger.info("Adding missing columns research_persona and research_persona_generated_at to persona_data table")
                    try:
                        # IF NOT EXISTS: the SELECT also fails when only one
                        # of the two columns is missing.
                        session_db.execute(text("ALTER TABLE persona_data ADD COLUMN IF NOT EXISTS research_persona JSONB"))
                        session_db.execute(text("ALTER TABLE persona_data ADD COLUMN IF NOT EXISTS research_persona_generated_at TIMESTAMP"))
                        session_db.commit()
                        self._research_persona_cols_available = True
                    except Exception as alter_err:
                        logger.error(f"Failed to add research_persona columns: {alter_err}")
                        session_db.rollback()
                        raise
        except Exception as e:
            logger.error(f"Error ensuring research_persona columns: {e}")
            self._safe_rollback(session_db)
            raise
        finally:
            self._research_persona_cols_checked = True

    def _ensure_brand_column_detection(self, session_db: Session) -> None:
        """Detect at runtime whether brand columns exist and cache the result.

        Never raises: any failure of the probe query is treated as "columns
        not available".
        """
        if self._brand_cols_checked:
            return
        try:
            # This works across SQLite/Postgres; LIMIT 0 avoids scanning
            session_db.execute(text('SELECT brand_analysis, content_strategy_insights FROM website_analyses LIMIT 0'))
            self._brand_cols_available = True
        except Exception:
            self._brand_cols_available = False
            # A failed probe can leave the transaction aborted; roll back so
            # the caller's next statement on this session is not rejected.
            self._safe_rollback(session_db)
        finally:
            self._brand_cols_checked = True

    def _maybe_update_brand_columns(self, session_db: Session, session_id: int, brand_analysis: Any, content_strategy_insights: Any) -> None:
        """Safely update brand columns using raw SQL if feature enabled and columns exist.

        Values are JSON-encoded before being written; ``None`` is written as
        SQL NULL. The statement is executed but not committed here. Failures
        are logged as warnings and never raised.
        """
        if not self._brand_feature_enabled():
            return
        self._ensure_brand_column_detection(session_db)
        if not self._brand_cols_available:
            return
        try:
            params = {
                'brand_analysis': json.dumps(brand_analysis) if brand_analysis is not None else None,
                'content_strategy_insights': json.dumps(content_strategy_insights) if content_strategy_insights is not None else None,
                'session_id': session_id,
            }
            session_db.execute(
                text('''
                    UPDATE website_analyses
                    SET brand_analysis = :brand_analysis,
                        content_strategy_insights = :content_strategy_insights
                    WHERE session_id = :session_id
                '''),
                params
            )
        except SQLAlchemyError as e:
            # Clear the failed statement so the session stays usable.
            self._safe_rollback(session_db)
            logger.warning(f"Skipped updating brand columns (not critical): {e}")
        except Exception as e:
            # e.g. a value that json.dumps cannot serialize; nothing reached the DB.
            logger.warning(f"Skipped updating brand columns (not critical): {e}")

    def _maybe_attach_brand_columns(self, session_db: Session, session_id: int, result: Dict[str, Any]) -> None:
        """Optionally read brand columns and attach to result if available.

        On success ``result`` gains the keys ``brand_analysis`` and
        ``content_strategy_insights``. Failures are logged as warnings and
        leave ``result`` as it was.
        """
        if not self._brand_feature_enabled():
            return
        self._ensure_brand_column_detection(session_db)
        if not self._brand_cols_available:
            return
        try:
            row = session_db.execute(
                text('''
                    SELECT brand_analysis, content_strategy_insights
                    FROM website_analyses WHERE session_id = :session_id LIMIT 1
                '''),
                {'session_id': session_id}
            ).mappings().first()
            if row:
                # If stored as TEXT in SQLite, try to parse JSON
                result['brand_analysis'] = self._decode_json_text(row.get('brand_analysis'))
                result['content_strategy_insights'] = self._decode_json_text(row.get('content_strategy_insights'))
        except SQLAlchemyError as e:
            # Clear the failed statement so the session stays usable.
            self._safe_rollback(session_db)
            logger.warning(f"Skipped reading brand columns (not critical): {e}")
        except Exception as e:
            logger.warning(f"Skipped reading brand columns (not critical): {e}")

    # --- Session lifecycle ---

    def get_or_create_session(self, user_id: str, db: Optional[Session] = None) -> OnboardingSession:
        """Get existing onboarding session or create new one for user.

        A new session starts at step 1 with 0.0 progress and is committed
        immediately.

        Raises:
            ValueError: If no database session is available.
            SQLAlchemyError: Re-raised after logging and rolling back.
        """
        session_db = self._resolve_db(db)
        try:
            # Try to get existing session for this user
            session = session_db.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
            if session:
                logger.info(f"Found existing onboarding session for user {user_id}")
                return session
            # Create new session
            session = OnboardingSession(
                user_id=user_id,
                current_step=1,
                progress=0.0,
                started_at=datetime.now()
            )
            session_db.add(session)
            session_db.commit()
            session_db.refresh(session)
            logger.info(f"Created new onboarding session for user {user_id}")
            return session
        except SQLAlchemyError as e:
            logger.error(f"Database error in get_or_create_session: {e}")
            self._safe_rollback(session_db)
            raise

    def get_session_by_user(self, user_id: str, db: Optional[Session] = None) -> Optional[OnboardingSession]:
        """Get onboarding session for user.

        Returns:
            The session row, or ``None`` when the user has none or the query
            fails.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            return session_db.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
        except SQLAlchemyError as e:
            logger.error(f"Error getting session: {e}")
            self._safe_rollback(session_db)
            return None

    def update_step(self, user_id: str, step_number: int, db: Optional[Session] = None) -> bool:
        """Update current step for user.

        Returns:
            ``True`` when committed, ``False`` on a database error.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_or_create_session(user_id, session_db)
            session.current_step = step_number
            session.updated_at = datetime.now()
            session_db.commit()
            logger.info(f"Updated user {user_id} to step {step_number}")
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error updating step: {e}")
            self._safe_rollback(session_db)
            return False

    def update_progress(self, user_id: str, progress: float, db: Optional[Session] = None) -> bool:
        """Update progress percentage for user.

        Returns:
            ``True`` when committed, ``False`` on a database error.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_or_create_session(user_id, session_db)
            session.progress = progress
            session.updated_at = datetime.now()
            session_db.commit()
            logger.info(f"Updated user {user_id} progress to {progress}%")
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error updating progress: {e}")
            self._safe_rollback(session_db)
            return False

    # --- API keys ---

    def save_api_key(self, user_id: str, provider: str, api_key: str, db: Optional[Session] = None) -> bool:
        """Save API key for user with isolation.

        Creates the key for ``provider`` within the user's session, or
        overwrites the stored value when one already exists. The key value is
        never logged.

        Returns:
            ``True`` when committed, ``False`` on a database error.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            # Get user's onboarding session
            session = self.get_or_create_session(user_id, session_db)
            # Check if key already exists for this provider and session
            existing_key = session_db.query(APIKey).filter(
                APIKey.session_id == session.id,
                APIKey.provider == provider
            ).first()
            if existing_key:
                # Update existing key
                existing_key.key = api_key
                existing_key.updated_at = datetime.now()
                logger.info(f"Updated {provider} API key for user {user_id}")
            else:
                # Create new key
                new_key = APIKey(
                    session_id=session.id,
                    provider=provider,
                    key=api_key
                )
                session_db.add(new_key)
                logger.info(f"Created new {provider} API key for user {user_id}")
            session_db.commit()
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error saving API key: {e}")
            self._safe_rollback(session_db)
            return False

    def get_api_keys(self, user_id: str, db: Optional[Session] = None) -> Dict[str, str]:
        """Get all API keys for user.

        Returns:
            Mapping of provider name to the stored key value; empty when the
            user has no session or the query fails.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return {}
            keys = session_db.query(APIKey).filter(
                APIKey.session_id == session.id
            ).all()
            return {key.provider: key.key for key in keys}
        except SQLAlchemyError as e:
            logger.error(f"Error getting API keys: {e}")
            self._safe_rollback(session_db)
            return {}

    # --- Website analysis ---

    def save_website_analysis(self, user_id: str, analysis_data: Dict[str, Any], db: Optional[Session] = None) -> bool:
        """Save website analysis for user.

        Accepts either a flat payload or ``{website, analysis: {...}}`` (see
        ``_normalize_website_analysis``). The ORM-mapped fields are committed
        first; the brand columns are then written best-effort through raw SQL
        and committed separately, so a brand-column failure cannot undo the
        main record.

        Returns:
            ``True`` when the main record was committed, ``False`` on a
            database error.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_or_create_session(user_id, session_db)
            normalized = self._normalize_website_analysis(analysis_data)
            # Check if analysis already exists
            existing = session_db.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()
            if existing:
                # Only update website_url if the normalized value is not
                # empty. This prevents overwriting a valid URL with an empty
                # string when step.data doesn't include the website field.
                if normalized['website_url']:
                    existing.website_url = normalized['website_url']
                for field in self._ANALYSIS_ORM_FIELDS:
                    setattr(existing, field, normalized[field])
                existing.status = normalized['status']
                existing.updated_at = datetime.now()
                logger.info(f"Updated website analysis for user {user_id}")
            else:
                # Create new
                analysis = WebsiteAnalysis(
                    session_id=session.id,
                    website_url=normalized['website_url'],
                    status=normalized['status'],
                    **{field: normalized[field] for field in self._ANALYSIS_ORM_FIELDS}
                )
                session_db.add(analysis)
                logger.info(f"Created website analysis for user {user_id}")
            session_db.commit()
            # Optional brand column update via raw SQL (feature-flagged)
            self._maybe_update_brand_columns(
                session_db=session_db,
                session_id=session.id,
                brand_analysis=normalized['brand_analysis'],
                content_strategy_insights=normalized['content_strategy_insights']
            )
            session_db.commit()
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error saving website analysis: {e}")
            self._safe_rollback(session_db)
            return False

    def get_website_analysis(self, user_id: str, db: Optional[Session] = None) -> Optional[Dict[str, Any]]:
        """Get website analysis for user.

        Returns:
            ``WebsiteAnalysis.to_dict()`` for the user's record, with the
            brand columns attached when they are enabled and present;
            ``None`` when there is no session or record, or the query fails.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            analysis = session_db.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()
            if not analysis:
                return None
            result = analysis.to_dict()
            if result:
                # Optionally include brand fields without touching ORM mapping
                self._maybe_attach_brand_columns(session_db, session.id, result)
            return result
        except SQLAlchemyError as e:
            logger.error(f"Error getting website analysis: {e}")
            self._safe_rollback(session_db)
            return None

    # --- Research preferences ---

    def save_research_preferences(self, user_id: str, preferences: Dict[str, Any], db: Optional[Session] = None) -> bool:
        """Save research preferences for user.

        On update, ``research_depth``, ``content_types``, ``auto_research``
        and ``factual_content`` keep their stored value when the key is
        absent from ``preferences``; the style-related fields are always
        overwritten (with ``None`` when absent).

        Returns:
            ``True`` when committed, ``False`` on a database error.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_or_create_session(user_id, session_db)
            # Check if preferences already exist
            existing = session_db.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == session.id
            ).first()
            if existing:
                # Update existing
                existing.research_depth = preferences.get('research_depth', existing.research_depth)
                existing.content_types = preferences.get('content_types', existing.content_types)
                existing.auto_research = preferences.get('auto_research', existing.auto_research)
                existing.factual_content = preferences.get('factual_content', existing.factual_content)
                existing.writing_style = preferences.get('writing_style')
                existing.content_characteristics = preferences.get('content_characteristics')
                existing.target_audience = preferences.get('target_audience')
                existing.recommended_settings = preferences.get('recommended_settings')
                existing.updated_at = datetime.now()
                logger.info(f"Updated research preferences for user {user_id}")
            else:
                # Create new
                prefs = ResearchPreferences(
                    session_id=session.id,
                    research_depth=preferences.get('research_depth', 'standard'),
                    content_types=preferences.get('content_types', []),
                    auto_research=preferences.get('auto_research', True),
                    factual_content=preferences.get('factual_content', True),
                    writing_style=preferences.get('writing_style'),
                    content_characteristics=preferences.get('content_characteristics'),
                    target_audience=preferences.get('target_audience'),
                    recommended_settings=preferences.get('recommended_settings')
                )
                session_db.add(prefs)
                logger.info(f"Created research preferences for user {user_id}")
            session_db.commit()
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error saving research preferences: {e}")
            self._safe_rollback(session_db)
            return False

    def get_research_preferences(self, user_id: str, db: Optional[Session] = None) -> Optional[Dict[str, Any]]:
        """Get research preferences for user.

        Returns:
            ``ResearchPreferences.to_dict()`` for the user's record, or
            ``None`` when there is no session or record, or the query fails.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            prefs = session_db.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == session.id
            ).first()
            return prefs.to_dict() if prefs else None
        except SQLAlchemyError as e:
            logger.error(f"Error getting research preferences: {e}")
            self._safe_rollback(session_db)
            return None

    # --- Persona data ---

    def save_persona_data(self, user_id: str, persona_data: Dict[str, Any], db: Optional[Session] = None) -> bool:
        """Save persona data for user.

        Reads the camelCase keys ``corePersona``, ``platformPersonas``,
        ``qualityMetrics`` and ``selectedPlatforms`` from ``persona_data``.

        Returns:
            ``True`` when committed, ``False`` on a database error.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_or_create_session(user_id, session_db)
            # Check if persona data already exists for this user
            existing = session_db.query(PersonaData).filter(
                PersonaData.session_id == session.id
            ).first()
            if existing:
                # Update existing persona data
                existing.core_persona = persona_data.get('corePersona')
                existing.platform_personas = persona_data.get('platformPersonas')
                existing.quality_metrics = persona_data.get('qualityMetrics')
                existing.selected_platforms = persona_data.get('selectedPlatforms', [])
                # Same clock as every other updated_at written by this
                # service, so timestamps are comparable across tables.
                existing.updated_at = datetime.now()
                logger.info(f"Updated persona data for user {user_id}")
            else:
                # Create new persona data record
                persona = PersonaData(
                    session_id=session.id,
                    core_persona=persona_data.get('corePersona'),
                    platform_personas=persona_data.get('platformPersonas'),
                    quality_metrics=persona_data.get('qualityMetrics'),
                    selected_platforms=persona_data.get('selectedPlatforms', [])
                )
                session_db.add(persona)
                logger.info(f"Created persona data for user {user_id}")
            session_db.commit()
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error saving persona data: {e}")
            self._safe_rollback(session_db)
            return False

    def get_persona_data(self, user_id: str, db: Optional[Session] = None) -> Optional[Dict[str, Any]]:
        """Get persona data for user.

        Returns:
            Dict with keys ``corePersona``, ``platformPersonas``,
            ``qualityMetrics`` and ``selectedPlatforms``, or ``None`` when
            there is no session or record, or the query fails.

        Raises:
            ValueError: If no database session is available.
            Exception: Errors from the one-time column migration propagate.
        """
        session_db = self._resolve_db(db)
        # Ensure research_persona columns exist before querying
        self._ensure_research_persona_columns(session_db)
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            persona = session_db.query(PersonaData).filter(
                PersonaData.session_id == session.id
            ).first()
            if not persona:
                return None
            # Return persona data in the expected format
            return {
                'corePersona': persona.core_persona,
                'platformPersonas': persona.platform_personas,
                'qualityMetrics': persona.quality_metrics,
                'selectedPlatforms': persona.selected_platforms
            }
        except SQLAlchemyError as e:
            logger.error(f"Error getting persona data: {e}")
            self._safe_rollback(session_db)
            return None

    # --- Competitor analysis ---

    def get_competitor_analysis(self, user_id: str, db: Optional[Session] = None) -> Optional[List[Dict[str, Any]]]:
        """Get competitor analysis data for user from onboarding.

        Returns:
            One dict per stored competitor record, built from the record's
            URL/domain and its ``analysis_data`` payload; ``None`` when there
            is no session, no records, or the query fails.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            from models.onboarding import CompetitorAnalysis
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None
            # Query CompetitorAnalysis table
            competitor_records = session_db.query(CompetitorAnalysis).filter(
                CompetitorAnalysis.session_id == session.id
            ).all()
            if not competitor_records:
                return None
            # Convert to list of dicts
            competitors = []
            for record in competitor_records:
                analysis_data = record.analysis_data or {}
                competitors.append({
                    "url": record.competitor_url,
                    "domain": record.competitor_domain or record.competitor_url,
                    "title": analysis_data.get("title", record.competitor_domain or ""),
                    "summary": analysis_data.get("summary", ""),
                    "relevance_score": analysis_data.get("relevance_score", 0.5),
                    "highlights": analysis_data.get("highlights", []),
                    "favicon": analysis_data.get("favicon"),
                    "image": analysis_data.get("image"),
                    "published_date": analysis_data.get("published_date"),
                    "author": analysis_data.get("author"),
                    # Stored under "competitive_analysis", exposed as "competitive_insights".
                    "competitive_insights": analysis_data.get("competitive_analysis", {}),
                    "content_insights": analysis_data.get("content_insights", {})
                })
            return competitors
        except SQLAlchemyError as e:
            logger.error(f"Error getting competitor analysis: {e}")
            self._safe_rollback(session_db)
            return None

    # --- Completion and status ---

    def mark_onboarding_complete(self, user_id: str, db: Optional[Session] = None) -> bool:
        """Mark onboarding as complete for user.

        Sets the session to ``FINAL_STEP`` with 100.0 progress.

        Returns:
            ``True`` when committed, ``False`` on a database error.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_or_create_session(user_id, session_db)
            session.current_step = self.FINAL_STEP  # Final step
            session.progress = 100.0
            session.updated_at = datetime.now()
            session_db.commit()
            logger.info(f"Marked onboarding complete for user {user_id}")
            return True
        except SQLAlchemyError as e:
            logger.error(f"Error marking onboarding complete: {e}")
            self._safe_rollback(session_db)
            return False

    def get_onboarding_status(self, user_id: str, db: Optional[Session] = None) -> Dict[str, Any]:
        """Get comprehensive onboarding status for user.

        Returns:
            Dict with ``is_completed``, ``current_step``, ``progress``,
            ``started_at`` and ``updated_at`` (ISO-8601 strings or ``None``).
            When the user has no session, or the query fails, the
            "not started" defaults are returned.

        Raises:
            ValueError: If no database session is available.
        """
        session_db = self._resolve_db(db)
        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                # User hasn't started onboarding yet
                return self._default_status()
            # Treat NULL step/progress as "not reached" instead of failing
            # the comparison with a TypeError.
            is_completed = (
                (session.current_step or 0) >= self.FINAL_STEP
                and (session.progress or 0.0) >= 100.0
            )
            return {
                "is_completed": is_completed,
                "current_step": session.current_step,
                "progress": session.progress,
                "started_at": session.started_at.isoformat() if session.started_at else None,
                "updated_at": session.updated_at.isoformat() if session.updated_at else None
            }
        except SQLAlchemyError as e:
            logger.error(f"Error getting onboarding status: {e}")
            self._safe_rollback(session_db)
            return self._default_status()