Add brand analysis columns to onboarding database and migration scripts

This commit is contained in:
ajaysi
2025-10-11 17:05:42 +05:30
parent b1ebe1034e
commit 1df12a64a2
25 changed files with 2415 additions and 90 deletions

View File

@@ -170,8 +170,36 @@ class OnboardingProgress:
required_steps = [1, 2, 3, 6] # Steps 1, 2, 3, and 6 are required
for step_num in required_steps:
step = self.get_step_data(step_num)
if step and step.status not in [StepStatus.COMPLETED, StepStatus.SKIPPED]:
return False
if step and step.status in [StepStatus.COMPLETED, StepStatus.SKIPPED]:
continue
# DB-aware fallback for steps 2 and 3
try:
from services.onboarding_database_service import OnboardingDatabaseService
from services.database import get_db
db = next(get_db())
db_service = OnboardingDatabaseService(db)
if step_num == 2:
w = db_service.get_website_analysis(self.user_id, db)
if w and (w.get('website_url') or w.get('writing_style')):
# Mark as completed to normalize state
try:
self.mark_step_completed(2, {'source': 'db-fallback'})
except Exception:
pass
continue
if step_num == 3:
p = db_service.get_research_preferences(self.user_id, db)
if p and p.get('research_depth'):
try:
self.mark_step_completed(3, {'source': 'db-fallback'})
except Exception:
pass
continue
except Exception:
pass
return False
return True
def get_completion_percentage(self) -> float:

View File

@@ -5,10 +5,13 @@ This replaces the JSON file-based storage with proper database persistence.
"""
from typing import Dict, Any, Optional, List
import os
import json
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import text
from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData
from services.database import get_db
@@ -20,6 +23,85 @@ class OnboardingDatabaseService:
def __init__(self, db: Session = None):
"""Initialize with optional database session."""
self.db = db
# Cache for schema feature detection
self._brand_cols_checked: bool = False
self._brand_cols_available: bool = False
# --- Feature flags and schema detection helpers ---
def _brand_feature_enabled(self) -> bool:
"""Check if writing brand-related columns is enabled via env flag."""
return os.getenv('ENABLE_WEBSITE_BRAND_COLUMNS', 'true').lower() in {'1', 'true', 'yes', 'on'}
def _ensure_brand_column_detection(self, session_db: Session) -> None:
"""Detect at runtime whether brand columns exist and cache the result."""
if self._brand_cols_checked:
return
try:
# This works across SQLite/Postgres; LIMIT 0 avoids scanning
session_db.execute(text('SELECT brand_analysis, content_strategy_insights FROM website_analyses LIMIT 0'))
self._brand_cols_available = True
except Exception:
self._brand_cols_available = False
finally:
self._brand_cols_checked = True
def _maybe_update_brand_columns(self, session_db: Session, session_id: int, brand_analysis: Any, content_strategy_insights: Any) -> None:
"""Safely update brand columns using raw SQL if feature enabled and columns exist."""
if not self._brand_feature_enabled():
return
self._ensure_brand_column_detection(session_db)
if not self._brand_cols_available:
return
try:
session_db.execute(
text('''
UPDATE website_analyses
SET brand_analysis = :brand_analysis,
content_strategy_insights = :content_strategy_insights
WHERE session_id = :session_id
'''),
{
'brand_analysis': json.dumps(brand_analysis) if brand_analysis is not None else None,
'content_strategy_insights': json.dumps(content_strategy_insights) if content_strategy_insights is not None else None,
'session_id': session_id,
}
)
except Exception as e:
logger.warning(f"Skipped updating brand columns (not critical): {e}")
def _maybe_attach_brand_columns(self, session_db: Session, session_id: int, result: Dict[str, Any]) -> None:
"""Optionally read brand columns and attach to result if available."""
if not self._brand_feature_enabled():
return
self._ensure_brand_column_detection(session_db)
if not self._brand_cols_available:
return
try:
row = session_db.execute(
text('''
SELECT brand_analysis, content_strategy_insights
FROM website_analyses WHERE session_id = :session_id LIMIT 1
'''),
{'session_id': session_id}
).mappings().first()
if row:
brand = row.get('brand_analysis')
insights = row.get('content_strategy_insights')
# If stored as TEXT in SQLite, try to parse JSON
if isinstance(brand, str):
try:
brand = json.loads(brand)
except Exception:
pass
if isinstance(insights, str):
try:
insights = json.loads(insights)
except Exception:
pass
result['brand_analysis'] = brand
result['content_strategy_insights'] = insights
except Exception as e:
logger.warning(f"Skipped reading brand columns (not critical): {e}")
def get_or_create_session(self, user_id: str, db: Session = None) -> OnboardingSession:
"""Get existing onboarding session or create new one for user."""
@@ -178,6 +260,24 @@ class OnboardingDatabaseService:
try:
session = self.get_or_create_session(user_id, session_db)
# Normalize payload. Step 2 sometimes sends { website, analysis: {...} }
# while DB expects flattened fields. Support both shapes.
incoming = analysis_data or {}
nested = incoming.get('analysis') if isinstance(incoming.get('analysis'), dict) else None
normalized = {
'website_url': incoming.get('website') or incoming.get('website_url') or '',
'writing_style': (nested or incoming).get('writing_style'),
'content_characteristics': (nested or incoming).get('content_characteristics'),
'target_audience': (nested or incoming).get('target_audience'),
'content_type': (nested or incoming).get('content_type'),
'recommended_settings': (nested or incoming).get('recommended_settings'),
'brand_analysis': (nested or incoming).get('brand_analysis'),
'content_strategy_insights': (nested or incoming).get('content_strategy_insights'),
'crawl_result': (nested or incoming).get('crawl_result'),
'style_patterns': (nested or incoming).get('style_patterns'),
'style_guidelines': (nested or incoming).get('style_guidelines'),
'status': (nested or incoming).get('status', incoming.get('status', 'completed')),
}
# Check if analysis already exists
existing = session_db.query(WebsiteAnalysis).filter(
@@ -186,37 +286,46 @@ class OnboardingDatabaseService:
if existing:
# Update existing
existing.website_url = analysis_data.get('website_url', existing.website_url)
existing.writing_style = analysis_data.get('writing_style')
existing.content_characteristics = analysis_data.get('content_characteristics')
existing.target_audience = analysis_data.get('target_audience')
existing.content_type = analysis_data.get('content_type')
existing.recommended_settings = analysis_data.get('recommended_settings')
existing.crawl_result = analysis_data.get('crawl_result')
existing.style_patterns = analysis_data.get('style_patterns')
existing.style_guidelines = analysis_data.get('style_guidelines')
existing.status = analysis_data.get('status', 'completed')
existing.website_url = normalized.get('website_url', existing.website_url)
existing.writing_style = normalized.get('writing_style')
existing.content_characteristics = normalized.get('content_characteristics')
existing.target_audience = normalized.get('target_audience')
existing.content_type = normalized.get('content_type')
existing.recommended_settings = normalized.get('recommended_settings')
existing.crawl_result = normalized.get('crawl_result')
existing.style_patterns = normalized.get('style_patterns')
existing.style_guidelines = normalized.get('style_guidelines')
existing.status = normalized.get('status', 'completed')
existing.updated_at = datetime.now()
logger.info(f"Updated website analysis for user {user_id}")
else:
# Create new
analysis = WebsiteAnalysis(
session_id=session.id,
website_url=analysis_data.get('website_url', ''),
writing_style=analysis_data.get('writing_style'),
content_characteristics=analysis_data.get('content_characteristics'),
target_audience=analysis_data.get('target_audience'),
content_type=analysis_data.get('content_type'),
recommended_settings=analysis_data.get('recommended_settings'),
crawl_result=analysis_data.get('crawl_result'),
style_patterns=analysis_data.get('style_patterns'),
style_guidelines=analysis_data.get('style_guidelines'),
status=analysis_data.get('status', 'completed')
website_url=normalized.get('website_url', ''),
writing_style=normalized.get('writing_style'),
content_characteristics=normalized.get('content_characteristics'),
target_audience=normalized.get('target_audience'),
content_type=normalized.get('content_type'),
recommended_settings=normalized.get('recommended_settings'),
crawl_result=normalized.get('crawl_result'),
style_patterns=normalized.get('style_patterns'),
style_guidelines=normalized.get('style_guidelines'),
status=normalized.get('status', 'completed')
)
session_db.add(analysis)
logger.info(f"Created website analysis for user {user_id}")
session_db.commit()
# Optional brand column update via raw SQL (feature-flagged)
self._maybe_update_brand_columns(
session_db=session_db,
session_id=session.id,
brand_analysis=normalized.get('brand_analysis'),
content_strategy_insights=normalized.get('content_strategy_insights')
)
session_db.commit()
return True
except SQLAlchemyError as e:
@@ -239,7 +348,11 @@ class OnboardingDatabaseService:
WebsiteAnalysis.session_id == session.id
).first()
return analysis.to_dict() if analysis else None
result = analysis.to_dict() if analysis else None
if result:
# Optionally include brand fields without touching ORM mapping
self._maybe_attach_brand_columns(session_db, session.id, result)
return result
except SQLAlchemyError as e:
logger.error(f"Error getting website analysis: {e}")
@@ -358,6 +471,36 @@ class OnboardingDatabaseService:
logger.error(f"Error getting research preferences: {e}")
return None
def get_persona_data(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
"""Get persona data for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return None
persona = session_db.query(PersonaData).filter(
PersonaData.session_id == session.id
).first()
if not persona:
return None
# Return persona data in the expected format
return {
'corePersona': persona.core_persona,
'platformPersonas': persona.platform_personas,
'qualityMetrics': persona.quality_metrics,
'selectedPlatforms': persona.selected_platforms
}
except SQLAlchemyError as e:
logger.error(f"Error getting persona data: {e}")
return None
def mark_onboarding_complete(self, user_id: str, db: Session = None) -> bool:
"""Mark onboarding as complete for user."""
session_db = db or self.db