Add brand analysis columns to onboarding database and migration scripts
This commit is contained in:
@@ -170,8 +170,36 @@ class OnboardingProgress:
|
||||
required_steps = [1, 2, 3, 6] # Steps 1, 2, 3, and 6 are required
|
||||
for step_num in required_steps:
|
||||
step = self.get_step_data(step_num)
|
||||
if step and step.status not in [StepStatus.COMPLETED, StepStatus.SKIPPED]:
|
||||
return False
|
||||
if step and step.status in [StepStatus.COMPLETED, StepStatus.SKIPPED]:
|
||||
continue
|
||||
|
||||
# DB-aware fallback for steps 2 and 3
|
||||
try:
|
||||
from services.onboarding_database_service import OnboardingDatabaseService
|
||||
from services.database import get_db
|
||||
db = next(get_db())
|
||||
db_service = OnboardingDatabaseService(db)
|
||||
if step_num == 2:
|
||||
w = db_service.get_website_analysis(self.user_id, db)
|
||||
if w and (w.get('website_url') or w.get('writing_style')):
|
||||
# Mark as completed to normalize state
|
||||
try:
|
||||
self.mark_step_completed(2, {'source': 'db-fallback'})
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
if step_num == 3:
|
||||
p = db_service.get_research_preferences(self.user_id, db)
|
||||
if p and p.get('research_depth'):
|
||||
try:
|
||||
self.mark_step_completed(3, {'source': 'db-fallback'})
|
||||
except Exception:
|
||||
pass
|
||||
continue
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return False
|
||||
return True
|
||||
|
||||
def get_completion_percentage(self) -> float:
|
||||
|
||||
@@ -5,10 +5,13 @@ This replaces the JSON file-based storage with proper database persistence.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy import text
|
||||
|
||||
from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData
|
||||
from services.database import get_db
|
||||
@@ -20,6 +23,85 @@ class OnboardingDatabaseService:
|
||||
def __init__(self, db: Session = None):
    """Set up the service with an optional default session.

    Args:
        db: Session stored on the instance as ``self.db``; may be None.
    """
    # Schema feature-detection cache: nothing has been probed yet, so the
    # brand columns are treated as unavailable until a probe says otherwise.
    self._brand_cols_available: bool = False
    self._brand_cols_checked: bool = False

    # Session kept on the instance for later use.
    self.db = db
|
||||
|
||||
# --- Feature flags and schema detection helpers ---
|
||||
def _brand_feature_enabled(self) -> bool:
    """Tell whether the env flag allows using the brand-related columns.

    Reads ``ENABLE_WEBSITE_BRAND_COLUMNS`` (treated as ``'true'`` when unset)
    and compares its lower-cased value with the accepted truthy spellings.

    Returns:
        True when the flag value is one of ``1``, ``true``, ``yes`` or ``on``
        (case-insensitive); False for anything else.
    """
    truthy_spellings = {'1', 'true', 'yes', 'on'}
    raw_flag = os.getenv('ENABLE_WEBSITE_BRAND_COLUMNS', 'true')
    return raw_flag.lower() in truthy_spellings
|
||||
|
||||
def _ensure_brand_column_detection(self, session_db: Session) -> None:
    """Detect at runtime whether the brand columns exist and cache the result.

    The probe selects ``brand_analysis`` and ``content_strategy_insights``
    from ``website_analyses`` with ``LIMIT 0``. If the statement succeeds,
    ``_brand_cols_available`` is set to True; if it raises for any reason it
    is set to False. In both cases ``_brand_cols_checked`` is set, so the
    probe runs at most once per service instance.

    Args:
        session_db: Session used to execute the probe statement.

    Note:
        When the probe fails the session is rolled back. This discards any
        uncommitted work on ``session_db``.
        NOTE(review): the call sites visible in this file probe only after a
        commit or on a read path - confirm this holds for other callers.
    """
    if self._brand_cols_checked:
        return
    try:
        # This works across SQLite/Postgres; LIMIT 0 only validates the
        # column names and returns no rows.
        session_db.execute(text('SELECT brand_analysis, content_strategy_insights FROM website_analyses LIMIT 0'))
        self._brand_cols_available = True
    except Exception:
        self._brand_cols_available = False
        # A failed statement can leave the surrounding transaction aborted
        # (PostgreSQL refuses every further statement until a rollback), so
        # reset the session instead of handing a broken one back to callers.
        try:
            session_db.rollback()
        except Exception:
            # Detection is best-effort and must never raise to the caller.
            pass
    finally:
        self._brand_cols_checked = True
|
||||
|
||||
def _maybe_update_brand_columns(self, session_db: Session, session_id: int, brand_analysis: Any, content_strategy_insights: Any) -> None:
    """Write the brand payloads with raw SQL when the feature is usable.

    Nothing is written when the env flag disables the feature or when the
    schema probe reports that the brand columns are missing. Any exception
    raised while building or executing the UPDATE is logged as a warning
    and not re-raised. This method does not commit.

    Args:
        session_db: Session used for the probe and the UPDATE.
        session_id: Value matched against ``website_analyses.session_id``.
        brand_analysis: Payload stored as JSON text; None is stored as NULL.
        content_strategy_insights: Payload stored as JSON text; None is
            stored as NULL.
    """
    if not self._brand_feature_enabled():
        return

    self._ensure_brand_column_detection(session_db)
    if not self._brand_cols_available:
        return

    def _to_json_or_null(payload: Any) -> Any:
        # None must reach the database as NULL, not as the JSON text "null".
        return None if payload is None else json.dumps(payload)

    try:
        statement = text('''
            UPDATE website_analyses
            SET brand_analysis = :brand_analysis,
                content_strategy_insights = :content_strategy_insights
            WHERE session_id = :session_id
        ''')
        bind_params = {
            'brand_analysis': _to_json_or_null(brand_analysis),
            'content_strategy_insights': _to_json_or_null(content_strategy_insights),
            'session_id': session_id,
        }
        session_db.execute(statement, bind_params)
    except Exception as e:
        logger.warning(f"Skipped updating brand columns (not critical): {e}")
|
||||
|
||||
def _maybe_attach_brand_columns(self, session_db: Session, session_id: int, result: Dict[str, Any]) -> None:
    """Read the brand columns and add them to ``result`` when available.

    Nothing happens when the env flag disables the feature, when the schema
    probe reports the columns as missing, or when no row matches
    ``session_id``. Any exception raised while reading is logged as a
    warning and not re-raised.

    Args:
        session_db: Session used for the probe and the SELECT.
        session_id: Value matched against ``website_analyses.session_id``.
        result: Dict updated in place with the ``brand_analysis`` and
            ``content_strategy_insights`` keys.
    """
    if not self._brand_feature_enabled():
        return

    self._ensure_brand_column_detection(session_db)
    if not self._brand_cols_available:
        return

    try:
        query = text('''
            SELECT brand_analysis, content_strategy_insights
            FROM website_analyses WHERE session_id = :session_id LIMIT 1
        ''')
        row = session_db.execute(query, {'session_id': session_id}).mappings().first()
        if not row:
            return

        decoded = {}
        for column in ('brand_analysis', 'content_strategy_insights'):
            value = row.get(column)
            # If stored as TEXT (e.g. in SQLite), try to parse the JSON;
            # a value that does not parse is passed through unchanged.
            if isinstance(value, str):
                try:
                    value = json.loads(value)
                except Exception:
                    pass
            decoded[column] = value

        # Assign only after both columns were read and decoded.
        for column, value in decoded.items():
            result[column] = value
    except Exception as e:
        logger.warning(f"Skipped reading brand columns (not critical): {e}")
|
||||
|
||||
def get_or_create_session(self, user_id: str, db: Session = None) -> OnboardingSession:
|
||||
"""Get existing onboarding session or create new one for user."""
|
||||
@@ -178,6 +260,24 @@ class OnboardingDatabaseService:
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
# Normalize payload. Step 2 sometimes sends { website, analysis: {...} }
|
||||
# while DB expects flattened fields. Support both shapes.
|
||||
incoming = analysis_data or {}
|
||||
nested = incoming.get('analysis') if isinstance(incoming.get('analysis'), dict) else None
|
||||
normalized = {
|
||||
'website_url': incoming.get('website') or incoming.get('website_url') or '',
|
||||
'writing_style': (nested or incoming).get('writing_style'),
|
||||
'content_characteristics': (nested or incoming).get('content_characteristics'),
|
||||
'target_audience': (nested or incoming).get('target_audience'),
|
||||
'content_type': (nested or incoming).get('content_type'),
|
||||
'recommended_settings': (nested or incoming).get('recommended_settings'),
|
||||
'brand_analysis': (nested or incoming).get('brand_analysis'),
|
||||
'content_strategy_insights': (nested or incoming).get('content_strategy_insights'),
|
||||
'crawl_result': (nested or incoming).get('crawl_result'),
|
||||
'style_patterns': (nested or incoming).get('style_patterns'),
|
||||
'style_guidelines': (nested or incoming).get('style_guidelines'),
|
||||
'status': (nested or incoming).get('status', incoming.get('status', 'completed')),
|
||||
}
|
||||
|
||||
# Check if analysis already exists
|
||||
existing = session_db.query(WebsiteAnalysis).filter(
|
||||
@@ -186,37 +286,46 @@ class OnboardingDatabaseService:
|
||||
|
||||
if existing:
|
||||
# Update existing
|
||||
existing.website_url = analysis_data.get('website_url', existing.website_url)
|
||||
existing.writing_style = analysis_data.get('writing_style')
|
||||
existing.content_characteristics = analysis_data.get('content_characteristics')
|
||||
existing.target_audience = analysis_data.get('target_audience')
|
||||
existing.content_type = analysis_data.get('content_type')
|
||||
existing.recommended_settings = analysis_data.get('recommended_settings')
|
||||
existing.crawl_result = analysis_data.get('crawl_result')
|
||||
existing.style_patterns = analysis_data.get('style_patterns')
|
||||
existing.style_guidelines = analysis_data.get('style_guidelines')
|
||||
existing.status = analysis_data.get('status', 'completed')
|
||||
existing.website_url = normalized.get('website_url', existing.website_url)
|
||||
existing.writing_style = normalized.get('writing_style')
|
||||
existing.content_characteristics = normalized.get('content_characteristics')
|
||||
existing.target_audience = normalized.get('target_audience')
|
||||
existing.content_type = normalized.get('content_type')
|
||||
existing.recommended_settings = normalized.get('recommended_settings')
|
||||
existing.crawl_result = normalized.get('crawl_result')
|
||||
existing.style_patterns = normalized.get('style_patterns')
|
||||
existing.style_guidelines = normalized.get('style_guidelines')
|
||||
existing.status = normalized.get('status', 'completed')
|
||||
existing.updated_at = datetime.now()
|
||||
logger.info(f"Updated website analysis for user {user_id}")
|
||||
else:
|
||||
# Create new
|
||||
analysis = WebsiteAnalysis(
|
||||
session_id=session.id,
|
||||
website_url=analysis_data.get('website_url', ''),
|
||||
writing_style=analysis_data.get('writing_style'),
|
||||
content_characteristics=analysis_data.get('content_characteristics'),
|
||||
target_audience=analysis_data.get('target_audience'),
|
||||
content_type=analysis_data.get('content_type'),
|
||||
recommended_settings=analysis_data.get('recommended_settings'),
|
||||
crawl_result=analysis_data.get('crawl_result'),
|
||||
style_patterns=analysis_data.get('style_patterns'),
|
||||
style_guidelines=analysis_data.get('style_guidelines'),
|
||||
status=analysis_data.get('status', 'completed')
|
||||
website_url=normalized.get('website_url', ''),
|
||||
writing_style=normalized.get('writing_style'),
|
||||
content_characteristics=normalized.get('content_characteristics'),
|
||||
target_audience=normalized.get('target_audience'),
|
||||
content_type=normalized.get('content_type'),
|
||||
recommended_settings=normalized.get('recommended_settings'),
|
||||
crawl_result=normalized.get('crawl_result'),
|
||||
style_patterns=normalized.get('style_patterns'),
|
||||
style_guidelines=normalized.get('style_guidelines'),
|
||||
status=normalized.get('status', 'completed')
|
||||
)
|
||||
session_db.add(analysis)
|
||||
logger.info(f"Created website analysis for user {user_id}")
|
||||
|
||||
session_db.commit()
|
||||
|
||||
# Optional brand column update via raw SQL (feature-flagged)
|
||||
self._maybe_update_brand_columns(
|
||||
session_db=session_db,
|
||||
session_id=session.id,
|
||||
brand_analysis=normalized.get('brand_analysis'),
|
||||
content_strategy_insights=normalized.get('content_strategy_insights')
|
||||
)
|
||||
session_db.commit()
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
@@ -239,7 +348,11 @@ class OnboardingDatabaseService:
|
||||
WebsiteAnalysis.session_id == session.id
|
||||
).first()
|
||||
|
||||
return analysis.to_dict() if analysis else None
|
||||
result = analysis.to_dict() if analysis else None
|
||||
if result:
|
||||
# Optionally include brand fields without touching ORM mapping
|
||||
self._maybe_attach_brand_columns(session_db, session.id, result)
|
||||
return result
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error getting website analysis: {e}")
|
||||
@@ -358,6 +471,36 @@ class OnboardingDatabaseService:
|
||||
logger.error(f"Error getting research preferences: {e}")
|
||||
return None
|
||||
|
||||
def get_persona_data(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
    """Fetch the stored persona data for a user.

    Args:
        user_id: Identifier passed to ``get_session_by_user``.
        db: Session to use; falls back to ``self.db`` when falsy.

    Returns:
        A dict with the keys ``corePersona``, ``platformPersonas``,
        ``qualityMetrics`` and ``selectedPlatforms``, or None when the user
        has no onboarding session, no persona row exists, or a
        ``SQLAlchemyError`` occurs (the error is logged).

    Raises:
        ValueError: If neither ``db`` nor ``self.db`` provides a session.
    """
    active_db = db if db else self.db
    if not active_db:
        raise ValueError("Database session required")

    # Response key -> attribute on the persona row.
    field_map = (
        ('corePersona', 'core_persona'),
        ('platformPersonas', 'platform_personas'),
        ('qualityMetrics', 'quality_metrics'),
        ('selectedPlatforms', 'selected_platforms'),
    )

    try:
        onboarding = self.get_session_by_user(user_id, active_db)
        persona_row = None
        if onboarding:
            persona_query = active_db.query(PersonaData).filter(
                PersonaData.session_id == onboarding.id
            )
            persona_row = persona_query.first()

        if not persona_row:
            return None

        # Return persona data in the expected format
        return {key: getattr(persona_row, attr) for key, attr in field_map}
    except SQLAlchemyError as e:
        logger.error(f"Error getting persona data: {e}")
        return None
|
||||
|
||||
def mark_onboarding_complete(self, user_id: str, db: Session = None) -> bool:
|
||||
"""Mark onboarding as complete for user."""
|
||||
session_db = db or self.db
|
||||
|
||||
Reference in New Issue
Block a user