Base code

This commit is contained in:
Kunthawat Greethong
2026-01-08 22:39:53 +07:00
parent 697115c61a
commit c35fa52117
2169 changed files with 626670 additions and 0 deletions


@@ -0,0 +1,204 @@
# Onboarding Services Package
This package contains all onboarding-related services and utilities for ALwrity. All onboarding data is stored in the database with proper user isolation, replacing the previous file-based JSON storage system.
## Architecture
### Database-First Design
- **Primary Storage**: PostgreSQL database with proper foreign keys and relationships
- **User Isolation**: Each user's onboarding data is completely separate
- **No File Storage**: Removed all JSON file operations for production scalability
- **Local Development**: API keys still written to `.env` for developer convenience
### Service Structure
```
backend/services/onboarding/
├── __init__.py # Package exports
├── database_service.py # Core database operations
├── progress_service.py # Progress tracking and step management
├── data_service.py # Data validation and processing
├── api_key_manager.py # API key management + progress tracking
└── README.md # This documentation
```
## Services
### 1. OnboardingDatabaseService (`database_service.py`)
**Purpose**: Core database operations for onboarding data with user isolation.
**Key Features**:
- User-specific session management
- API key storage and retrieval
- Website analysis persistence
- Research preferences management
- Persona data storage
- Brand analysis support (feature-flagged)
**Main Methods**:
- `get_or_create_session(user_id)` - Get or create user session
- `save_api_key(user_id, provider, key)` - Store API keys
- `save_website_analysis(user_id, data)` - Store website analysis
- `save_research_preferences(user_id, prefs)` - Store research settings
- `save_persona_data(user_id, data)` - Store persona information
### 2. OnboardingProgressService (`progress_service.py`)
**Purpose**: High-level progress tracking and step management.
**Key Features**:
- Database-only progress tracking
- Step completion validation
- Progress percentage calculation
- Onboarding completion management
**Main Methods**:
- `get_onboarding_status(user_id)` - Get current status
- `update_step(user_id, step_number)` - Update current step
- `update_progress(user_id, percentage)` - Update progress
- `complete_onboarding(user_id)` - Mark as complete
### 3. OnboardingDataService (`data_service.py`)
**Purpose**: Extract and use onboarding data for AI personalization.
**Key Features**:
- Personalized AI input generation
- Website analysis data extraction
- Research preferences integration
- Default fallback data
**Main Methods**:
- `get_personalized_ai_inputs(user_id)` - Generate personalized inputs
- `get_user_website_analysis(user_id)` - Get website data
- `get_user_research_preferences(user_id)` - Get research settings
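When no onboarding data exists, the data service falls back to safe defaults rather than failing. A minimal standalone sketch of that fallback pattern (the dict shapes, keys, and default values here are illustrative, not the service's exact output):

```python
from typing import Any, Dict, Optional

def personalized_inputs(analysis: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    # Stand-in for get_personalized_ai_inputs: use real analysis data when
    # onboarding exists, otherwise return conservative defaults.
    if analysis is None:
        return {"writing_style": "professional", "target_audience": ["professionals"]}
    return {
        "writing_style": analysis.get("writing_style", {}).get("tone", "professional"),
        "target_audience": analysis.get("target_audience", {}).get("demographics", ["professionals"]),
    }

print(personalized_inputs(None))  # defaults
print(personalized_inputs({"writing_style": {"tone": "casual"}}))  # real tone, default audience
```

The same `dict.get` chains with defaults appear throughout `data_service.py`, so a partially populated analysis degrades gracefully field by field.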
### 4. OnboardingProgress + APIKeyManager (`api_key_manager.py`)
**Purpose**: Combined API key management and progress tracking with database persistence.
**Key Features**:
- Database-only progress persistence (no JSON files)
- API key management with environment integration
- Step-by-step progress tracking
- User-specific progress instances
**Main Classes**:
- `OnboardingProgress` - Progress tracking with database persistence
- `APIKeyManager` - API key management
- `StepData` - Individual step data structure
- `StepStatus` - Step status enumeration
## Database Schema
### Core Tables
- `onboarding_sessions` - User session tracking
- `api_keys` - User-specific API key storage
- `website_analyses` - Website analysis data
- `research_preferences` - User research settings
- `persona_data` - Generated persona information
### Relationships
- All data tables reference `onboarding_sessions.id`
- User isolation via `user_id` foreign key
- Proper cascade deletion and updates
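The cascade relationship can be illustrated with a self-contained SQLite sketch. The column lists below are hypothetical placeholders; the real schema lives in the project's SQLAlchemy models:

```python
import sqlite3

# Hypothetical DDL mirroring the documented relationship: child tables
# reference onboarding_sessions.id with ON DELETE CASCADE.
DDL = """
CREATE TABLE onboarding_sessions (
    id INTEGER PRIMARY KEY,
    user_id TEXT NOT NULL UNIQUE,
    current_step INTEGER DEFAULT 1
);
CREATE TABLE api_keys (
    id INTEGER PRIMARY KEY,
    session_id INTEGER NOT NULL
        REFERENCES onboarding_sessions(id) ON DELETE CASCADE,
    provider TEXT NOT NULL,
    key TEXT NOT NULL
);
"""

conn = sqlite3.connect(":memory:")
conn.execute("PRAGMA foreign_keys = ON")  # SQLite enforces FKs only when enabled
conn.executescript(DDL)
conn.execute("INSERT INTO onboarding_sessions (id, user_id) VALUES (1, 'user123')")
conn.execute("INSERT INTO api_keys (session_id, provider, key) VALUES (1, 'openai', 'sk-...')")

# Deleting the session removes its api_keys rows via the cascade
conn.execute("DELETE FROM onboarding_sessions WHERE id = 1")
remaining = conn.execute("SELECT COUNT(*) FROM api_keys").fetchone()[0]
print(remaining)  # 0
```

In production the same guarantee comes from PostgreSQL foreign keys, so deleting a user's session cannot leave orphaned key or analysis rows.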
## Usage Examples
### Basic Progress Tracking
```python
from services.onboarding import OnboardingProgress

# Get user-specific progress
progress = OnboardingProgress(user_id="user123")

# Mark step 1 as completed
progress.mark_step_completed(1, {"api_keys": {"openai": "sk-..."}})

# Get progress summary
summary = progress.get_progress_summary()
```
### Database Operations
```python
from services.onboarding import OnboardingDatabaseService
from services.database import SessionLocal

db = SessionLocal()
service = OnboardingDatabaseService(db)

# Save an API key (the session is passed explicitly)
service.save_api_key("user123", "openai", "sk-...", db)

# Get website analysis
analysis = service.get_website_analysis("user123", db)

db.close()
```
### Progress Service
```python
from services.onboarding import OnboardingProgressService

service = OnboardingProgressService()

# Get status
status = service.get_onboarding_status("user123")

# Update step and progress
service.update_step("user123", 2)
service.update_progress("user123", 50.0)
```
## Migration from File-Based Storage
### What Was Removed
- JSON file operations (`.onboarding_progress*.json`)
- File-based progress persistence
- Dual persistence system (file + database)
### What Was Kept
- Database persistence (enhanced)
- Local development `.env` API key writing
- All existing functionality and APIs
### Benefits
- **Production Ready**: No ephemeral file storage
- **Scalable**: Database-backed with proper indexing
- **User Isolated**: Complete data separation
- **Maintainable**: Single source of truth
## Environment Variables
### Required
- Database connection (via `services.database`)
- User authentication system
### Optional
- `ENABLE_WEBSITE_BRAND_COLUMNS=true` - Enable brand analysis features
- `DEPLOY_ENV=local` - Enable local `.env` API key writing
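The `DEPLOY_ENV` gate mirrors the check used in `api_key_manager.py`: `.env` writing happens only when the variable is unset or explicitly `local`:

```python
import os

def should_write_env_file() -> bool:
    # Mirrors the gating in OnboardingProgress.save_progress: keys go to
    # .env only when DEPLOY_ENV is unset or 'local'; otherwise database only.
    return os.getenv("DEPLOY_ENV", "local") == "local"

os.environ["DEPLOY_ENV"] = "production"
print(should_write_env_file())  # False

os.environ["DEPLOY_ENV"] = "local"
print(should_write_env_file())  # True
```

Because the default is `local`, a developer machine with no `DEPLOY_ENV` set gets `.env` convenience writes automatically, while production deployments must set the variable to opt out.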
## Error Handling
All services include comprehensive error handling:
- Database connection failures
- User not found scenarios
- Invalid data validation
- Graceful fallbacks to defaults
## Performance Considerations
- Database queries are optimized with proper indexing
- User-specific caching where appropriate
- Minimal database calls through efficient service design
- Connection pooling via SQLAlchemy
## Testing
Each service can be tested independently:
- Unit tests for individual methods
- Integration tests with database
- Mock database sessions for isolated testing
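Mocking the SQLAlchemy session keeps unit tests database-free. A sketch using `unittest.mock` (the query chain and `get_session_by_user` stand-in below are illustrative, not the service's exact internals):

```python
from unittest.mock import MagicMock

# A MagicMock stands in for the SQLAlchemy session; the query chain is
# configured to return None, simulating "no onboarding session for this user".
db = MagicMock()
db.query.return_value.filter.return_value.first.return_value = None

def get_session_by_user(user_id: str, db):
    # Hypothetical stand-in for OnboardingDatabaseService.get_session_by_user
    return db.query("OnboardingSession").filter(True).first()

result = get_session_by_user("user123", db)
print(result)  # None
db.query.assert_called_once()
```

The same pattern extends to the write paths: configure the mock's return values, call the service method, then assert on `db.add` / `db.commit` calls instead of inspecting real rows.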
## Future Enhancements
- Real-time progress updates via WebSocket
- Progress analytics and reporting
- Bulk user operations
- Advanced validation rules
- Progress recovery mechanisms


@@ -0,0 +1,35 @@
"""
Onboarding Services Package
This package contains all onboarding-related services and utilities.
All onboarding data is stored in the database with proper user isolation.
Services:
- OnboardingDatabaseService: Core database operations for onboarding data
- OnboardingProgressService: Progress tracking and step management
- OnboardingDataService: Data validation and processing
- OnboardingProgress: Progress tracking with database persistence (from api_key_manager)
Architecture:
- Database-first: All data stored in PostgreSQL with proper foreign keys
- User isolation: Each user's data is completely separate
- No file storage: Removed all JSON file operations for production scalability
- Local development: API keys still written to .env for convenience
"""
# Import all public classes for easy access
from .database_service import OnboardingDatabaseService
from .progress_service import OnboardingProgressService
from .data_service import OnboardingDataService
from .api_key_manager import (
    OnboardingProgress,
    APIKeyManager,
    get_onboarding_progress,
    get_user_onboarding_progress,
    get_onboarding_progress_for_user,
)

__all__ = [
    'OnboardingDatabaseService',
    'OnboardingProgressService',
    'OnboardingDataService',
    'OnboardingProgress',
    'APIKeyManager',
    'get_onboarding_progress',
    'get_user_onboarding_progress',
    'get_onboarding_progress_for_user',
]


@@ -0,0 +1,495 @@
"""
API Key Manager with Database-Only Onboarding Progress
Manages API keys and onboarding progress with database persistence only.
Removed all file-based JSON storage for production scalability.
"""
import os
import json
from typing import Dict, Any, Optional, List
from datetime import datetime
from loguru import logger
from enum import Enum
from services.database import get_db_session
class StepStatus(Enum):
    """Onboarding step status."""
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    SKIPPED = "skipped"
    FAILED = "failed"
class StepData:
    """Data structure for an onboarding step."""
    def __init__(self, step_number: int, title: str, description: str, status: StepStatus = StepStatus.PENDING):
        self.step_number = step_number
        self.title = title
        self.description = description
        self.status = status
        self.completed_at = None
        self.data = None
        self.validation_errors = []
class OnboardingProgress:
    """Manages onboarding progress with database persistence only."""
    def __init__(self, user_id: Optional[str] = None):
        self.steps = self._initialize_steps()
        self.current_step = 1
        self.started_at = datetime.now().isoformat()
        self.last_updated = datetime.now().isoformat()
        self.is_completed = False
        self.completed_at = None
        self.user_id = user_id  # user_id enables database isolation
        # Initialize database service for persistence
        try:
            from .database_service import OnboardingDatabaseService
            self.db_service = OnboardingDatabaseService()
            self.use_database = True
            logger.info(f"Database service initialized for user {user_id}")
        except Exception as e:
            logger.error(f"Database service not available: {e}")
            self.db_service = None
            self.use_database = False
            raise Exception(f"Database service required but not available: {e}")
        # Load existing progress from the database if available
        if self.use_database and self.user_id:
            self.load_progress_from_db()
    def _initialize_steps(self) -> List[StepData]:
        """Initialize the 6-step onboarding process."""
        return [
            StepData(1, "AI LLM Providers", "Configure AI language model providers", StepStatus.PENDING),
            StepData(2, "Website Analysis", "Set up website analysis and crawling", StepStatus.PENDING),
            StepData(3, "AI Research", "Configure AI research capabilities", StepStatus.PENDING),
            StepData(4, "Personalization", "Set up personalization features", StepStatus.PENDING),
            StepData(5, "Integrations", "Configure ALwrity integrations", StepStatus.PENDING),
            StepData(6, "Complete Setup", "Finalize and complete onboarding", StepStatus.PENDING)
        ]

    def get_step_data(self, step_number: int) -> Optional[StepData]:
        """Get data for a specific step."""
        for step in self.steps:
            if step.step_number == step_number:
                return step
        return None
    def mark_step_completed(self, step_number: int, data: Optional[Dict[str, Any]] = None):
        """Mark a step as completed."""
        logger.info(f"[mark_step_completed] Marking step {step_number} as completed")
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.COMPLETED
            step.completed_at = datetime.now().isoformat()
            step.data = data
            self.last_updated = datetime.now().isoformat()
            # Check if all steps are now completed
            all_completed = all(s.status in [StepStatus.COMPLETED, StepStatus.SKIPPED] for s in self.steps)
            if all_completed:
                # All steps are done: mark onboarding as complete
                self.is_completed = True
                self.completed_at = datetime.now().isoformat()
                self.current_step = len(self.steps)  # Set to the last step number
                logger.info("[mark_step_completed] All steps completed, marking onboarding as complete")
            else:
                # Only advance current_step if there are more steps to go
                self.current_step = step_number + 1
                # Ensure current_step doesn't exceed total steps
                if self.current_step > len(self.steps):
                    self.current_step = len(self.steps)
            logger.info(f"[mark_step_completed] Step {step_number} completed, new current_step: {self.current_step}, is_completed: {self.is_completed}")
            self.save_progress()
            logger.info(f"Step {step_number} marked as completed")
        else:
            logger.error(f"[mark_step_completed] Step {step_number} not found")
    def mark_step_in_progress(self, step_number: int):
        """Mark a step as in progress."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.IN_PROGRESS
            self.current_step = step_number
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} marked as in progress")
        else:
            logger.error(f"Step {step_number} not found")

    def mark_step_skipped(self, step_number: int):
        """Mark a step as skipped."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.SKIPPED
            step.completed_at = datetime.now().isoformat()
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} marked as skipped")
        else:
            logger.error(f"Step {step_number} not found")

    def mark_step_failed(self, step_number: int, error_message: str):
        """Mark a step as failed with an error message."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.FAILED
            step.validation_errors.append(error_message)
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.error(f"Step {step_number} marked as failed: {error_message}")
        else:
            logger.error(f"Step {step_number} not found")
    def get_progress_summary(self) -> Dict[str, Any]:
        """Get the current progress summary."""
        completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
        skipped_count = sum(1 for s in self.steps if s.status == StepStatus.SKIPPED)
        failed_count = sum(1 for s in self.steps if s.status == StepStatus.FAILED)
        return {
            "total_steps": len(self.steps),
            "completed_steps": completed_count,
            "skipped_steps": skipped_count,
            "failed_steps": failed_count,
            "current_step": self.current_step,
            "is_completed": self.is_completed,
            "progress_percentage": (completed_count + skipped_count) / len(self.steps) * 100
        }

    def get_next_step(self) -> Optional[StepData]:
        """Get the next step to work on."""
        for step in self.steps:
            if step.status == StepStatus.PENDING:
                return step
        return None

    def get_completed_steps(self) -> List[StepData]:
        """Get all completed steps."""
        return [step for step in self.steps if step.status == StepStatus.COMPLETED]

    def get_failed_steps(self) -> List[StepData]:
        """Get all failed steps."""
        return [step for step in self.steps if step.status == StepStatus.FAILED]
    def reset_step(self, step_number: int):
        """Reset a step to pending status."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.PENDING
            step.completed_at = None
            step.data = None
            step.validation_errors = []
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} reset to pending")
        else:
            logger.error(f"Step {step_number} not found")

    def reset_all_steps(self):
        """Reset all steps to pending status."""
        for step in self.steps:
            step.status = StepStatus.PENDING
            step.completed_at = None
            step.data = None
            step.validation_errors = []
        self.current_step = 1
        self.is_completed = False
        self.completed_at = None
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info("All steps reset to pending")

    def complete_onboarding(self):
        """Mark onboarding as complete."""
        self.is_completed = True
        self.completed_at = datetime.now().isoformat()
        self.current_step = len(self.steps)
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info("Onboarding completed successfully")
    def save_progress(self):
        """Save progress to the database."""
        if not self.use_database or not self.db_service or not self.user_id:
            logger.error("Cannot save progress: database service not available or user_id not set")
            return
        try:
            from services.database import SessionLocal
            db = SessionLocal()
            try:
                # Update session progress
                self.db_service.update_step(self.user_id, self.current_step, db)
                # Calculate the progress percentage
                completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
                progress_pct = (completed_count / len(self.steps)) * 100
                self.db_service.update_progress(self.user_id, progress_pct, db)
                # Save step-specific data to the appropriate tables
                for step in self.steps:
                    if step.status == StepStatus.COMPLETED and step.data:
                        if step.step_number == 1:  # API keys
                            api_keys = step.data.get('api_keys', {})
                            for provider, key in api_keys.items():
                                if key:
                                    # Save to the database (user isolation in production)
                                    self.db_service.save_api_key(self.user_id, provider, key, db)
                                    # Write to .env ONLY in local development, for developer
                                    # convenience; in production, keys are fetched from the
                                    # database per user
                                    is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
                                    if is_local:
                                        try:
                                            from services.api_key_manager import APIKeyManager
                                            api_key_manager = APIKeyManager()
                                            api_key_manager.save_api_key(provider, key)
                                            logger.info(f"[LOCAL] API key for {provider} saved to .env file")
                                        except Exception as env_error:
                                            logger.warning(f"[LOCAL] Failed to save {provider} API key to .env file: {env_error}")
                                    else:
                                        logger.info(f"[PRODUCTION] API key for {provider} saved to database only (user: {self.user_id})")
                                    # Log database save confirmation
                                    logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
                        elif step.step_number == 2:  # Website analysis
                            # Transform the frontend payload to match the database schema.
                            # Frontend sends: { website: "url", analysis: {...} }
                            # Database expects: { website_url: "url", ...analysis (flattened) }
                            analysis_for_db = {}
                            if step.data:
                                # Accept either 'website' or 'website_url'
                                website_url = step.data.get('website') or step.data.get('website_url')
                                if website_url:
                                    analysis_for_db['website_url'] = website_url
                                # Flatten a nested 'analysis' object if present
                                if 'analysis' in step.data and isinstance(step.data['analysis'], dict):
                                    analysis_for_db.update(step.data['analysis'])
                                # Carry over any other top-level fields
                                for key, value in step.data.items():
                                    if key not in ['website', 'website_url', 'analysis']:
                                        analysis_for_db[key] = value
                            # Ensure status is set
                            if 'status' not in analysis_for_db:
                                analysis_for_db['status'] = 'completed'
                            self.db_service.save_website_analysis(self.user_id, analysis_for_db, db)
                            logger.info(f"✅ DATABASE: Website analysis saved to database for user {self.user_id}")
                        elif step.step_number == 3:  # Research preferences
                            self.db_service.save_research_preferences(self.user_id, step.data, db)
                            logger.info(f"✅ DATABASE: Research preferences saved to database for user {self.user_id}")
                        elif step.step_number == 4:  # Persona generation
                            self.db_service.save_persona_data(self.user_id, step.data, db)
                            logger.info(f"✅ DATABASE: Persona data saved to database for user {self.user_id}")
                logger.info(f"Progress saved to database for user {self.user_id}")
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error saving progress to database: {str(e)}")
            raise
    def load_progress_from_db(self):
        """Load progress from the database."""
        if not self.use_database or not self.db_service or not self.user_id:
            logger.warning("Cannot load progress: database service not available or user_id not set")
            return
        try:
            from services.database import SessionLocal
            db = SessionLocal()
            try:
                # Get session data
                session = self.db_service.get_session_by_user(self.user_id, db)
                if not session:
                    logger.info(f"No existing onboarding session found for user {self.user_id}, starting fresh")
                    return
                # Restore session data
                self.current_step = session.current_step or 1
                self.started_at = session.started_at.isoformat() if session.started_at else self.started_at
                self.last_updated = session.last_updated.isoformat() if session.last_updated else self.last_updated
                self.is_completed = session.is_completed or False
                self.completed_at = session.completed_at.isoformat() if session.completed_at else None
                # Load step-specific data from the database
                self._load_step_data_from_db(db)
                # Fix any corrupted state
                self._fix_corrupted_state()
                logger.info(f"Progress loaded from database for user {self.user_id}")
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error loading progress from database: {str(e)}")
            # Don't fail if database loading fails; start fresh instead
    def _load_step_data_from_db(self, db):
        """Load step-specific data from the database tables."""
        try:
            # API keys (step 1)
            api_keys = self.db_service.get_api_keys(self.user_id, db)
            if api_keys:
                step1 = self.get_step_data(1)
                if step1:
                    step1.status = StepStatus.COMPLETED
                    step1.data = {'api_keys': api_keys}
                    step1.completed_at = datetime.now().isoformat()
            # Website analysis (step 2)
            website_analysis = self.db_service.get_website_analysis(self.user_id, db)
            if website_analysis:
                step2 = self.get_step_data(2)
                if step2:
                    step2.status = StepStatus.COMPLETED
                    step2.data = website_analysis
                    step2.completed_at = datetime.now().isoformat()
            # Research preferences (step 3)
            research_prefs = self.db_service.get_research_preferences(self.user_id, db)
            if research_prefs:
                step3 = self.get_step_data(3)
                if step3:
                    step3.status = StepStatus.COMPLETED
                    step3.data = research_prefs
                    step3.completed_at = datetime.now().isoformat()
            # Persona data (step 4)
            persona_data = self.db_service.get_persona_data(self.user_id, db)
            if persona_data:
                step4 = self.get_step_data(4)
                if step4:
                    step4.status = StepStatus.COMPLETED
                    step4.data = persona_data
                    step4.completed_at = datetime.now().isoformat()
            logger.info("Step data loaded from database")
        except Exception as e:
            logger.error(f"Error loading step data from database: {str(e)}")
    def _fix_corrupted_state(self):
        """Fix any corrupted progress state."""
        # Check if all steps are completed
        all_steps_completed = all(s.status in [StepStatus.COMPLETED, StepStatus.SKIPPED] for s in self.steps)
        if all_steps_completed:
            self.is_completed = True
            self.completed_at = self.completed_at or datetime.now().isoformat()
            self.current_step = len(self.steps)
        else:
            # Point current_step at the first incomplete step
            for step in self.steps:
                if step.status == StepStatus.PENDING:
                    self.current_step = step.step_number
                    break
class APIKeyManager:
    """Manages API keys for different providers."""
    def __init__(self):
        self.api_keys = {}
        self._load_from_env()

    def _load_from_env(self):
        """Load API keys from environment variables."""
        providers = [
            'GEMINI_API_KEY',
            'HF_TOKEN',
            'TAVILY_API_KEY',
            'SERPER_API_KEY',
            'METAPHOR_API_KEY',
            'FIRECRAWL_API_KEY',
            'STABILITY_API_KEY',
            'WAVESPEED_API_KEY',
        ]
        for provider in providers:
            key = os.getenv(provider)
            if key:
                # Normalize provider names to lowercase ('GEMINI_API_KEY' -> 'gemini')
                provider_name = provider.replace('_API_KEY', '').lower()
                self.api_keys[provider_name] = key
                logger.info(f"Loaded {provider_name} API key from environment")

    def get_api_key(self, provider: str) -> Optional[str]:
        """Get the API key for a provider."""
        return self.api_keys.get(provider.lower())

    def save_api_key(self, provider: str, api_key: str):
        """Save an API key to the environment and in memory."""
        provider_lower = provider.lower()
        self.api_keys[provider_lower] = api_key
        # Update the corresponding environment variable
        env_var = f"{provider.upper()}_API_KEY"
        os.environ[env_var] = api_key
        logger.info(f"Saved {provider} API key")

    def has_api_key(self, provider: str) -> bool:
        """Check whether a non-empty API key exists for a provider."""
        return provider.lower() in self.api_keys and bool(self.api_keys[provider.lower()])

    def get_all_keys(self) -> Dict[str, str]:
        """Get a copy of all API keys."""
        return self.api_keys.copy()

    def remove_api_key(self, provider: str):
        """Remove the API key for a provider."""
        provider_lower = provider.lower()
        if provider_lower in self.api_keys:
            del self.api_keys[provider_lower]
            # Remove from the environment as well
            env_var = f"{provider.upper()}_API_KEY"
            if env_var in os.environ:
                del os.environ[env_var]
            logger.info(f"Removed {provider} API key")
# Global instances
_user_onboarding_progress_cache = {}

def get_user_onboarding_progress(user_id: str) -> OnboardingProgress:
    """Get the user-specific onboarding progress instance."""
    safe_user_id = ''.join([c if c.isalnum() or c in ('-', '_') else '_' for c in str(user_id)])
    if safe_user_id in _user_onboarding_progress_cache:
        return _user_onboarding_progress_cache[safe_user_id]
    # Pass user_id to enable database persistence
    instance = OnboardingProgress(user_id=user_id)
    _user_onboarding_progress_cache[safe_user_id] = instance
    return instance

def get_onboarding_progress_for_user(user_id: str) -> OnboardingProgress:
    """Alias of get_user_onboarding_progress, kept for compatibility."""
    return get_user_onboarding_progress(user_id)

def get_onboarding_progress():
    """Get the global onboarding progress instance."""
    if not hasattr(get_onboarding_progress, '_instance'):
        get_onboarding_progress._instance = OnboardingProgress()
    return get_onboarding_progress._instance

def get_api_key_manager() -> APIKeyManager:
    """Get the global API key manager instance."""
    if not hasattr(get_api_key_manager, '_instance'):
        get_api_key_manager._instance = APIKeyManager()
    return get_api_key_manager._instance


@@ -0,0 +1,291 @@
"""
Onboarding Data Service
Extracts real user data from onboarding to personalize AI inputs
"""
from typing import Dict, Any, List, Optional
from sqlalchemy.orm import Session
from loguru import logger
from datetime import datetime
import json
from services.database import get_db_session
from models.onboarding import OnboardingSession, WebsiteAnalysis, ResearchPreferences
class OnboardingDataService:
    """Service to extract and use real onboarding data for AI personalization."""
    def __init__(self, db: Optional[Session] = None):
        """Initialize the onboarding data service."""
        self.db = db
        logger.info("OnboardingDataService initialized")

    def get_user_website_analysis(self, user_id: int) -> Optional[Dict[str, Any]]:
        """
        Get website analysis data for a specific user.

        Args:
            user_id: User ID to get data for

        Returns:
            Website analysis data, or None if not found
        """
        try:
            session = self.db or get_db_session()
            # Find the onboarding session for this user
            onboarding_session = session.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
            if not onboarding_session:
                logger.warning(f"No onboarding session found for user {user_id}")
                return None
            # Get the website analysis for this session
            website_analysis = session.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == onboarding_session.id
            ).first()
            if not website_analysis:
                logger.warning(f"No website analysis found for user {user_id}")
                return None
            return website_analysis.to_dict()
        except Exception as e:
            logger.error(f"Error getting website analysis for user {user_id}: {str(e)}")
            return None
    def get_user_research_preferences(self, user_id: int) -> Optional[Dict[str, Any]]:
        """
        Get research preferences for a specific user.

        Args:
            user_id: User ID to get data for

        Returns:
            Research preferences data, or None if not found
        """
        try:
            session = self.db or get_db_session()
            # Find the onboarding session for this user
            onboarding_session = session.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
            if not onboarding_session:
                logger.warning(f"No onboarding session found for user {user_id}")
                return None
            # Get the research preferences for this session
            research_prefs = session.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == onboarding_session.id
            ).first()
            if not research_prefs:
                logger.warning(f"No research preferences found for user {user_id}")
                return None
            return research_prefs.to_dict()
        except Exception as e:
            logger.error(f"Error getting research preferences for user {user_id}: {str(e)}")
            return None
    def get_personalized_ai_inputs(self, user_id: int) -> Dict[str, Any]:
        """
        Get personalized AI inputs based on the user's onboarding data.

        Args:
            user_id: User ID to get personalized data for

        Returns:
            Personalized data for AI analysis
        """
        try:
            logger.info(f"Getting personalized AI inputs for user {user_id}")
            # Get website analysis and research preferences
            website_analysis = self.get_user_website_analysis(user_id)
            research_prefs = self.get_user_research_preferences(user_id)
            if not website_analysis:
                logger.warning(f"No onboarding data found for user {user_id}, using defaults")
                return self._get_default_ai_inputs()
            # Extract real data from the website analysis
            writing_style = website_analysis.get('writing_style', {})
            target_audience = website_analysis.get('target_audience', {})
            content_type = website_analysis.get('content_type', {})
            recommended_settings = website_analysis.get('recommended_settings', {})
            # Build personalized AI inputs
            personalized_inputs = {
                "website_analysis": {
                    "website_url": website_analysis.get('website_url', ''),
                    "content_types": self._extract_content_types(content_type),
                    "writing_style": writing_style.get('tone', 'professional'),
                    "target_audience": target_audience.get('demographics', ['professionals']),
                    "industry_focus": target_audience.get('industry_focus', 'general'),
                    "expertise_level": target_audience.get('expertise_level', 'intermediate')
                },
                "competitor_analysis": {
                    "top_performers": self._generate_competitor_suggestions(target_audience),
                    "industry": target_audience.get('industry_focus', 'general'),
                    "target_demographics": target_audience.get('demographics', [])
                },
                "gap_analysis": {
                    "content_gaps": self._identify_content_gaps(content_type, writing_style),
                    "target_keywords": self._generate_target_keywords(target_audience),
                    "content_opportunities": self._identify_opportunities(content_type)
                },
                "keyword_analysis": {
                    "high_value_keywords": self._generate_high_value_keywords(target_audience),
                    "content_topics": self._generate_content_topics(content_type),
                    "search_intent": self._analyze_search_intent(target_audience)
                }
            }
            # Add research preferences if available
            if research_prefs:
                personalized_inputs["research_preferences"] = {
                    "research_depth": research_prefs.get('research_depth', 'Standard'),
                    "content_types": research_prefs.get('content_types', []),
                    "auto_research": research_prefs.get('auto_research', True),
                    "factual_content": research_prefs.get('factual_content', True)
                }
            logger.info(f"✅ Generated personalized AI inputs for user {user_id}")
            return personalized_inputs
        except Exception as e:
            logger.error(f"Error generating personalized AI inputs for user {user_id}: {str(e)}")
            return self._get_default_ai_inputs()
    def _extract_content_types(self, content_type: Dict[str, Any]) -> List[str]:
        """Extract content types from the content type analysis."""
        types = []
        if content_type.get('primary_type'):
            types.append(content_type['primary_type'])
        if content_type.get('secondary_types'):
            types.extend(content_type['secondary_types'])
        return types if types else ['blog', 'article']

    def _generate_competitor_suggestions(self, target_audience: Dict[str, Any]) -> List[str]:
        """Generate competitor suggestions based on the target audience."""
        industry = target_audience.get('industry_focus', 'general')
        # Suggest industry-specific competitors
        if industry == 'technology':
            return ['techcrunch.com', 'wired.com', 'theverge.com']
        elif industry == 'marketing':
            return ['hubspot.com', 'marketingland.com', 'moz.com']
        else:
            return ['competitor1.com', 'competitor2.com', 'competitor3.com']

    def _identify_content_gaps(self, content_type: Dict[str, Any], writing_style: Dict[str, Any]) -> List[str]:
        """Identify content gaps based on the current content type and style."""
        gaps = []
        primary_type = content_type.get('primary_type', 'blog')
        if primary_type == 'blog':
            gaps.extend(['Video tutorials', 'Case studies', 'Infographics'])
        elif primary_type == 'video':
            gaps.extend(['Blog posts', 'Whitepapers', 'Webinars'])
        # Add style-based gaps
        tone = writing_style.get('tone', 'professional')
        if tone == 'professional':
            gaps.append('Personal stories')
        elif tone == 'casual':
            gaps.append('Expert interviews')
        return gaps

    def _generate_target_keywords(self, target_audience: Dict[str, Any]) -> List[str]:
        """Generate target keywords based on the audience analysis."""
        industry = target_audience.get('industry_focus', 'general')
        if industry == 'technology':
            return ['AI tools', 'Digital transformation', 'Tech trends']
        elif industry == 'marketing':
            return ['Content marketing', 'SEO strategies', 'Social media']
        else:
            return ['Industry insights', 'Best practices', 'Expert tips']
    def _identify_opportunities(self, content_type: Dict[str, Any]) -> List[str]:
        """Identify content opportunities based on the current content type."""
        opportunities = []
        purpose = content_type.get('purpose', 'informational')
        if purpose == 'informational':
            opportunities.extend(['How-to guides', 'Tutorials', 'Educational content'])
        elif purpose == 'promotional':
            opportunities.extend(['Case studies', 'Testimonials', 'Success stories'])
        return opportunities

    def _generate_high_value_keywords(self, target_audience: Dict[str, Any]) -> List[str]:
        """Generate high-value keywords based on the audience analysis."""
        industry = target_audience.get('industry_focus', 'general')
        if industry == 'technology':
            return ['AI marketing', 'Content automation', 'Digital strategy']
        elif industry == 'marketing':
            return ['Content marketing', 'SEO optimization', 'Social media strategy']
        else:
            return ['Industry trends', 'Best practices', 'Expert insights']

    def _generate_content_topics(self, content_type: Dict[str, Any]) -> List[str]:
        """Generate content topics based on the content type analysis."""
        topics = []
        primary_type = content_type.get('primary_type', 'blog')
        if primary_type == 'blog':
            topics.extend(['Industry trends', 'How-to guides', 'Expert insights'])
        elif primary_type == 'video':
            topics.extend(['Tutorials', 'Product demos', 'Expert interviews'])
        return topics

    def _analyze_search_intent(self, target_audience: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze search intent based on the target audience."""
        expertise = target_audience.get('expertise_level', 'intermediate')
        if expertise == 'beginner':
            return {'intent': 'educational', 'focus': 'basic concepts'}
        elif expertise == 'intermediate':
            return {'intent': 'practical', 'focus': 'implementation'}
        else:
            return {'intent': 'advanced', 'focus': 'strategic insights'}
def _get_default_ai_inputs(self) -> Dict[str, Any]:
"""Get default AI inputs when no onboarding data is available."""
return {
"website_analysis": {
"content_types": ["blog", "video", "social"],
"writing_style": "professional",
"target_audience": ["professionals"],
"industry_focus": "general",
"expertise_level": "intermediate"
},
"competitor_analysis": {
"top_performers": ["competitor1.com", "competitor2.com"],
"industry": "general",
"target_demographics": ["professionals"]
},
"gap_analysis": {
"content_gaps": ["AI content", "Video tutorials", "Case studies"],
"target_keywords": ["Industry insights", "Best practices"],
"content_opportunities": ["How-to guides", "Tutorials"]
},
"keyword_analysis": {
"high_value_keywords": ["AI marketing", "Content automation", "Digital strategy"],
"content_topics": ["Industry trends", "Expert insights"],
"search_intent": {"intent": "practical", "focus": "implementation"}
}
}
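
The industry-keyed `if/elif` heuristics above (`_generate_competitor_suggestions`, `_generate_target_keywords`, and friends) all follow one pattern: look up an industry, fall back to a generic list. A hypothetical table-driven sketch of the same idea (the names and tables here are illustrative, not the shipped implementation):

```python
# Hypothetical table-driven variant of the industry heuristics above.
# Dict lookups with a shared fallback replace the if/elif chains; adding
# an industry becomes a data change rather than a code change.
from typing import Dict, List

COMPETITORS: Dict[str, List[str]] = {
    'technology': ['techcrunch.com', 'wired.com', 'theverge.com'],
    'marketing': ['hubspot.com', 'marketingland.com', 'moz.com'],
}
DEFAULT_COMPETITORS: List[str] = ['competitor1.com', 'competitor2.com', 'competitor3.com']

def suggest_competitors(industry: str) -> List[str]:
    """Return known competitors for an industry, or the generic fallback."""
    return COMPETITORS.get(industry, DEFAULT_COMPETITORS)

print(suggest_competitors('technology'))
print(suggest_competitors('finance'))
```

The same table-plus-fallback shape would apply to the keyword and topic helpers as well.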

View File

@@ -0,0 +1,666 @@
"""
Onboarding Database Service
Provides database-backed storage for onboarding progress with user isolation.
This replaces the JSON file-based storage with proper database persistence.
"""
from typing import Dict, Any, Optional, List
import os
import json
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import text
from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData
from services.database import get_db
class OnboardingDatabaseService:
"""Database service for onboarding with user isolation."""
def __init__(self, db: Session = None):
"""Initialize with optional database session."""
self.db = db
# Cache for schema feature detection
self._brand_cols_checked: bool = False
self._brand_cols_available: bool = False
self._research_persona_cols_checked: bool = False
self._research_persona_cols_available: bool = False
# --- Feature flags and schema detection helpers ---
def _brand_feature_enabled(self) -> bool:
"""Check if writing brand-related columns is enabled via env flag."""
return os.getenv('ENABLE_WEBSITE_BRAND_COLUMNS', 'true').lower() in {'1', 'true', 'yes', 'on'}
def _ensure_research_persona_columns(self, session_db: Session) -> None:
"""Ensure research_persona columns exist in persona_data table (runtime migration)."""
if self._research_persona_cols_checked:
return
try:
# Check if columns exist using PRAGMA (SQLite) or information_schema (PostgreSQL)
db_url = str(session_db.bind.url) if session_db.bind else ""
if 'sqlite' in db_url.lower():
# SQLite: Use PRAGMA to check columns
result = session_db.execute(text("PRAGMA table_info(persona_data)"))
cols = {row[1] for row in result} # Column name is at index 1
if 'research_persona' not in cols:
logger.info("Adding missing column research_persona to persona_data table")
session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona JSON"))
session_db.commit()
if 'research_persona_generated_at' not in cols:
logger.info("Adding missing column research_persona_generated_at to persona_data table")
session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona_generated_at TIMESTAMP"))
session_db.commit()
self._research_persona_cols_available = True
else:
# PostgreSQL: Try to query the columns (will fail if they don't exist)
try:
session_db.execute(text("SELECT research_persona, research_persona_generated_at FROM persona_data LIMIT 0"))
self._research_persona_cols_available = True
except Exception:
# Columns don't exist, add them
logger.info("Adding missing columns research_persona and research_persona_generated_at to persona_data table")
try:
session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona JSONB"))
session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona_generated_at TIMESTAMP"))
session_db.commit()
self._research_persona_cols_available = True
except Exception as alter_err:
logger.error(f"Failed to add research_persona columns: {alter_err}")
session_db.rollback()
raise
except Exception as e:
logger.error(f"Error ensuring research_persona columns: {e}")
session_db.rollback()
raise
finally:
self._research_persona_cols_checked = True
def _ensure_brand_column_detection(self, session_db: Session) -> None:
"""Detect at runtime whether brand columns exist and cache the result."""
if self._brand_cols_checked:
return
try:
# This works across SQLite/Postgres; LIMIT 0 avoids scanning
session_db.execute(text('SELECT brand_analysis, content_strategy_insights FROM website_analyses LIMIT 0'))
self._brand_cols_available = True
except Exception:
self._brand_cols_available = False
finally:
self._brand_cols_checked = True
def _maybe_update_brand_columns(self, session_db: Session, session_id: int, brand_analysis: Any, content_strategy_insights: Any) -> None:
"""Safely update brand columns using raw SQL if feature enabled and columns exist."""
if not self._brand_feature_enabled():
return
self._ensure_brand_column_detection(session_db)
if not self._brand_cols_available:
return
try:
session_db.execute(
text('''
UPDATE website_analyses
SET brand_analysis = :brand_analysis,
content_strategy_insights = :content_strategy_insights
WHERE session_id = :session_id
'''),
{
'brand_analysis': json.dumps(brand_analysis) if brand_analysis is not None else None,
'content_strategy_insights': json.dumps(content_strategy_insights) if content_strategy_insights is not None else None,
'session_id': session_id,
}
)
except Exception as e:
logger.warning(f"Skipped updating brand columns (not critical): {e}")
def _maybe_attach_brand_columns(self, session_db: Session, session_id: int, result: Dict[str, Any]) -> None:
"""Optionally read brand columns and attach to result if available."""
if not self._brand_feature_enabled():
return
self._ensure_brand_column_detection(session_db)
if not self._brand_cols_available:
return
try:
row = session_db.execute(
text('''
SELECT brand_analysis, content_strategy_insights
FROM website_analyses WHERE session_id = :session_id LIMIT 1
'''),
{'session_id': session_id}
).mappings().first()
if row:
brand = row.get('brand_analysis')
insights = row.get('content_strategy_insights')
# If stored as TEXT in SQLite, try to parse JSON
if isinstance(brand, str):
try:
brand = json.loads(brand)
except Exception:
pass
if isinstance(insights, str):
try:
insights = json.loads(insights)
except Exception:
pass
result['brand_analysis'] = brand
result['content_strategy_insights'] = insights
except Exception as e:
logger.warning(f"Skipped reading brand columns (not critical): {e}")
def get_or_create_session(self, user_id: str, db: Session = None) -> OnboardingSession:
"""Get existing onboarding session or create new one for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
# Try to get existing session for this user
session = session_db.query(OnboardingSession).filter(
OnboardingSession.user_id == user_id
).first()
if session:
logger.info(f"Found existing onboarding session for user {user_id}")
return session
# Create new session
session = OnboardingSession(
user_id=user_id,
current_step=1,
progress=0.0,
started_at=datetime.now()
)
session_db.add(session)
session_db.commit()
session_db.refresh(session)
logger.info(f"Created new onboarding session for user {user_id}")
return session
except SQLAlchemyError as e:
logger.error(f"Database error in get_or_create_session: {e}")
session_db.rollback()
raise
def get_session_by_user(self, user_id: str, db: Session = None) -> Optional[OnboardingSession]:
"""Get onboarding session for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
return session_db.query(OnboardingSession).filter(
OnboardingSession.user_id == user_id
).first()
except SQLAlchemyError as e:
logger.error(f"Error getting session: {e}")
return None
def update_step(self, user_id: str, step_number: int, db: Session = None) -> bool:
"""Update current step for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
session.current_step = step_number
session.updated_at = datetime.now()
session_db.commit()
logger.info(f"Updated user {user_id} to step {step_number}")
return True
except SQLAlchemyError as e:
logger.error(f"Error updating step: {e}")
session_db.rollback()
return False
def update_progress(self, user_id: str, progress: float, db: Session = None) -> bool:
"""Update progress percentage for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
session.progress = progress
session.updated_at = datetime.now()
session_db.commit()
logger.info(f"Updated user {user_id} progress to {progress}%")
return True
except SQLAlchemyError as e:
logger.error(f"Error updating progress: {e}")
session_db.rollback()
return False
def save_api_key(self, user_id: str, provider: str, api_key: str, db: Session = None) -> bool:
"""Save API key for user with isolation."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
# Get user's onboarding session
session = self.get_or_create_session(user_id, session_db)
# Check if key already exists for this provider and session
existing_key = session_db.query(APIKey).filter(
APIKey.session_id == session.id,
APIKey.provider == provider
).first()
if existing_key:
# Update existing key
existing_key.key = api_key
existing_key.updated_at = datetime.now()
logger.info(f"Updated {provider} API key for user {user_id}")
else:
# Create new key
new_key = APIKey(
session_id=session.id,
provider=provider,
key=api_key
)
session_db.add(new_key)
logger.info(f"Created new {provider} API key for user {user_id}")
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving API key: {e}")
session_db.rollback()
return False
def get_api_keys(self, user_id: str, db: Session = None) -> Dict[str, str]:
"""Get all API keys for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return {}
keys = session_db.query(APIKey).filter(
APIKey.session_id == session.id
).all()
return {key.provider: key.key for key in keys}
except SQLAlchemyError as e:
logger.error(f"Error getting API keys: {e}")
return {}
def save_website_analysis(self, user_id: str, analysis_data: Dict[str, Any], db: Session = None) -> bool:
"""Save website analysis for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
# Normalize payload. Step 2 sometimes sends { website, analysis: {...} }
# while DB expects flattened fields. Support both shapes.
incoming = analysis_data or {}
nested = incoming.get('analysis') if isinstance(incoming.get('analysis'), dict) else None
normalized = {
'website_url': incoming.get('website') or incoming.get('website_url') or '',
'writing_style': (nested or incoming).get('writing_style'),
'content_characteristics': (nested or incoming).get('content_characteristics'),
'target_audience': (nested or incoming).get('target_audience'),
'content_type': (nested or incoming).get('content_type'),
'recommended_settings': (nested or incoming).get('recommended_settings'),
'brand_analysis': (nested or incoming).get('brand_analysis'),
'content_strategy_insights': (nested or incoming).get('content_strategy_insights'),
'crawl_result': (nested or incoming).get('crawl_result'),
'style_patterns': (nested or incoming).get('style_patterns'),
'style_guidelines': (nested or incoming).get('style_guidelines'),
'status': (nested or incoming).get('status', incoming.get('status', 'completed')),
}
# Check if analysis already exists
existing = session_db.query(WebsiteAnalysis).filter(
WebsiteAnalysis.session_id == session.id
).first()
if existing:
# Update existing - only update website_url if normalized value is not empty
# This prevents overwriting a valid URL with an empty string when step.data
# doesn't include the website field
normalized_url = normalized.get('website_url', '').strip() if normalized.get('website_url') else ''
if normalized_url:
existing.website_url = normalized_url
# If normalized_url is empty, keep existing.website_url unchanged
existing.writing_style = normalized.get('writing_style')
existing.content_characteristics = normalized.get('content_characteristics')
existing.target_audience = normalized.get('target_audience')
existing.content_type = normalized.get('content_type')
existing.recommended_settings = normalized.get('recommended_settings')
existing.crawl_result = normalized.get('crawl_result')
existing.style_patterns = normalized.get('style_patterns')
existing.style_guidelines = normalized.get('style_guidelines')
existing.status = normalized.get('status', 'completed')
existing.updated_at = datetime.now()
logger.info(f"Updated website analysis for user {user_id}")
else:
# Create new
analysis = WebsiteAnalysis(
session_id=session.id,
website_url=normalized.get('website_url', ''),
writing_style=normalized.get('writing_style'),
content_characteristics=normalized.get('content_characteristics'),
target_audience=normalized.get('target_audience'),
content_type=normalized.get('content_type'),
recommended_settings=normalized.get('recommended_settings'),
crawl_result=normalized.get('crawl_result'),
style_patterns=normalized.get('style_patterns'),
style_guidelines=normalized.get('style_guidelines'),
status=normalized.get('status', 'completed')
)
session_db.add(analysis)
logger.info(f"Created website analysis for user {user_id}")
session_db.commit()
# Optional brand column update via raw SQL (feature-flagged)
self._maybe_update_brand_columns(
session_db=session_db,
session_id=session.id,
brand_analysis=normalized.get('brand_analysis'),
content_strategy_insights=normalized.get('content_strategy_insights')
)
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving website analysis: {e}")
session_db.rollback()
return False
def get_website_analysis(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
"""Get website analysis for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return None
analysis = session_db.query(WebsiteAnalysis).filter(
WebsiteAnalysis.session_id == session.id
).first()
result = analysis.to_dict() if analysis else None
if result:
# Optionally include brand fields without touching ORM mapping
self._maybe_attach_brand_columns(session_db, session.id, result)
return result
except SQLAlchemyError as e:
logger.error(f"Error getting website analysis: {e}")
return None
def save_research_preferences(self, user_id: str, preferences: Dict[str, Any], db: Session = None) -> bool:
"""Save research preferences for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
# Check if preferences already exist
existing = session_db.query(ResearchPreferences).filter(
ResearchPreferences.session_id == session.id
).first()
if existing:
# Update existing
existing.research_depth = preferences.get('research_depth', existing.research_depth)
existing.content_types = preferences.get('content_types', existing.content_types)
existing.auto_research = preferences.get('auto_research', existing.auto_research)
existing.factual_content = preferences.get('factual_content', existing.factual_content)
                existing.writing_style = preferences.get('writing_style', existing.writing_style)
                existing.content_characteristics = preferences.get('content_characteristics', existing.content_characteristics)
                existing.target_audience = preferences.get('target_audience', existing.target_audience)
                existing.recommended_settings = preferences.get('recommended_settings', existing.recommended_settings)
existing.updated_at = datetime.now()
logger.info(f"Updated research preferences for user {user_id}")
else:
# Create new
prefs = ResearchPreferences(
session_id=session.id,
research_depth=preferences.get('research_depth', 'standard'),
content_types=preferences.get('content_types', []),
auto_research=preferences.get('auto_research', True),
factual_content=preferences.get('factual_content', True),
writing_style=preferences.get('writing_style'),
content_characteristics=preferences.get('content_characteristics'),
target_audience=preferences.get('target_audience'),
recommended_settings=preferences.get('recommended_settings')
)
session_db.add(prefs)
logger.info(f"Created research preferences for user {user_id}")
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving research preferences: {e}")
session_db.rollback()
return False
def save_persona_data(self, user_id: str, persona_data: Dict[str, Any], db: Session = None) -> bool:
"""Save persona data for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
# Check if persona data already exists for this user
existing = session_db.query(PersonaData).filter(
PersonaData.session_id == session.id
).first()
if existing:
# Update existing persona data
existing.core_persona = persona_data.get('corePersona')
existing.platform_personas = persona_data.get('platformPersonas')
existing.quality_metrics = persona_data.get('qualityMetrics')
existing.selected_platforms = persona_data.get('selectedPlatforms', [])
                existing.updated_at = datetime.now()  # datetime.now() for consistency with the rest of the package
logger.info(f"Updated persona data for user {user_id}")
else:
# Create new persona data record
persona = PersonaData(
session_id=session.id,
core_persona=persona_data.get('corePersona'),
platform_personas=persona_data.get('platformPersonas'),
quality_metrics=persona_data.get('qualityMetrics'),
selected_platforms=persona_data.get('selectedPlatforms', [])
)
session_db.add(persona)
logger.info(f"Created persona data for user {user_id}")
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving persona data: {e}")
session_db.rollback()
return False
def get_research_preferences(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
"""Get research preferences for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return None
prefs = session_db.query(ResearchPreferences).filter(
ResearchPreferences.session_id == session.id
).first()
return prefs.to_dict() if prefs else None
except SQLAlchemyError as e:
logger.error(f"Error getting research preferences: {e}")
return None
def get_competitor_analysis(self, user_id: str, db: Session = None) -> Optional[List[Dict[str, Any]]]:
"""Get competitor analysis data for user from onboarding."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
from models.onboarding import CompetitorAnalysis
session = self.get_session_by_user(user_id, session_db)
if not session:
return None
# Query CompetitorAnalysis table
competitor_records = session_db.query(CompetitorAnalysis).filter(
CompetitorAnalysis.session_id == session.id
).all()
if not competitor_records:
return None
# Convert to list of dicts
competitors = []
for record in competitor_records:
analysis_data = record.analysis_data or {}
competitors.append({
"url": record.competitor_url,
"domain": record.competitor_domain or record.competitor_url,
"title": analysis_data.get("title", record.competitor_domain or ""),
"summary": analysis_data.get("summary", ""),
"relevance_score": analysis_data.get("relevance_score", 0.5),
"highlights": analysis_data.get("highlights", []),
"favicon": analysis_data.get("favicon"),
"image": analysis_data.get("image"),
"published_date": analysis_data.get("published_date"),
"author": analysis_data.get("author"),
"competitive_insights": analysis_data.get("competitive_analysis", {}),
"content_insights": analysis_data.get("content_insights", {})
})
return competitors
except SQLAlchemyError as e:
logger.error(f"Error getting competitor analysis: {e}")
return None
def get_persona_data(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
"""Get persona data for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
# Ensure research_persona columns exist before querying
self._ensure_research_persona_columns(session_db)
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return None
persona = session_db.query(PersonaData).filter(
PersonaData.session_id == session.id
).first()
if not persona:
return None
# Return persona data in the expected format
return {
'corePersona': persona.core_persona,
'platformPersonas': persona.platform_personas,
'qualityMetrics': persona.quality_metrics,
'selectedPlatforms': persona.selected_platforms
}
except SQLAlchemyError as e:
logger.error(f"Error getting persona data: {e}")
return None
def mark_onboarding_complete(self, user_id: str, db: Session = None) -> bool:
"""Mark onboarding as complete for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
session.current_step = 6 # Final step
session.progress = 100.0
session.updated_at = datetime.now()
session_db.commit()
logger.info(f"Marked onboarding complete for user {user_id}")
return True
except SQLAlchemyError as e:
logger.error(f"Error marking onboarding complete: {e}")
session_db.rollback()
return False
def get_onboarding_status(self, user_id: str, db: Session = None) -> Dict[str, Any]:
"""Get comprehensive onboarding status for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
# User hasn't started onboarding yet
return {
"is_completed": False,
"current_step": 1,
"progress": 0.0,
"started_at": None,
"updated_at": None
}
return {
"is_completed": session.current_step >= 6 and session.progress >= 100.0,
"current_step": session.current_step,
"progress": session.progress,
"started_at": session.started_at.isoformat() if session.started_at else None,
"updated_at": session.updated_at.isoformat() if session.updated_at else None
}
except SQLAlchemyError as e:
logger.error(f"Error getting onboarding status: {e}")
return {
"is_completed": False,
"current_step": 1,
"progress": 0.0,
"started_at": None,
"updated_at": None
}
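
The payload normalization inside `save_website_analysis` is the subtlest part of this file: Step 2 may send `{ website, analysis: {...} }` or an already-flattened dict, and both must collapse to one flat shape. A standalone sketch of that collapse (simplified to three fields; the real method handles the full set):

```python
# Standalone sketch of the payload normalization in save_website_analysis.
# Accepts either {"website": ..., "analysis": {...}} or a flat dict and
# produces the same flattened shape for both.
from typing import Any, Dict

def normalize_analysis_payload(incoming: Dict[str, Any]) -> Dict[str, Any]:
    # Prefer the nested "analysis" dict when present, else read top-level keys
    nested = incoming.get('analysis') if isinstance(incoming.get('analysis'), dict) else None
    src = nested or incoming
    return {
        'website_url': incoming.get('website') or incoming.get('website_url') or '',
        'writing_style': src.get('writing_style'),
        'status': src.get('status', incoming.get('status', 'completed')),
    }

nested_shape = {'website': 'https://example.com', 'analysis': {'writing_style': 'professional'}}
flat_shape = {'website_url': 'https://example.com', 'writing_style': 'professional'}
# Both shapes normalize identically
print(normalize_analysis_payload(nested_shape) == normalize_analysis_payload(flat_shape))
```

This is also why the update branch only writes `website_url` when the normalized value is non-empty: a step payload without the `website` field normalizes to `''`, and that must not clobber a previously saved URL.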

View File

@@ -0,0 +1,163 @@
"""
Database-only Onboarding Progress Service
Replaces file-based progress tracking with database-only implementation.
"""
from typing import Dict, Any, List, Optional
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from services.database import SessionLocal
from .database_service import OnboardingDatabaseService
class OnboardingProgressService:
"""Database-only onboarding progress management."""
def __init__(self):
self.db_service = OnboardingDatabaseService()
def get_onboarding_status(self, user_id: str) -> Dict[str, Any]:
"""Get current onboarding status from database only."""
try:
db = SessionLocal()
try:
# Get session data
session = self.db_service.get_session_by_user(user_id, db)
if not session:
return {
"is_completed": False,
"current_step": 1,
"completion_percentage": 0.0,
"started_at": None,
"last_updated": None,
"completed_at": None
}
# Check if onboarding is complete
# Consider complete if either the final step is reached OR progress hit 100%
# This guards against partial writes where one field persisted but the other didn't.
is_completed = (session.current_step >= 6) or (session.progress >= 100.0)
return {
"is_completed": is_completed,
"current_step": session.current_step,
"completion_percentage": session.progress,
"started_at": session.started_at.isoformat() if session.started_at else None,
"last_updated": session.updated_at.isoformat() if session.updated_at else None,
"completed_at": session.updated_at.isoformat() if is_completed else None
}
finally:
db.close()
except Exception as e:
logger.error(f"Error getting onboarding status: {e}")
return {
"is_completed": False,
"current_step": 1,
"completion_percentage": 0.0,
"started_at": None,
"last_updated": None,
"completed_at": None
}
def update_step(self, user_id: str, step_number: int) -> bool:
"""Update current step in database."""
try:
db = SessionLocal()
try:
success = self.db_service.update_step(user_id, step_number, db)
if success:
logger.info(f"Updated user {user_id} to step {step_number}")
return success
finally:
db.close()
except Exception as e:
logger.error(f"Error updating step: {e}")
return False
def update_progress(self, user_id: str, progress_percentage: float) -> bool:
"""Update progress percentage in database."""
try:
db = SessionLocal()
try:
success = self.db_service.update_progress(user_id, progress_percentage, db)
if success:
logger.info(f"Updated user {user_id} progress to {progress_percentage}%")
return success
finally:
db.close()
except Exception as e:
logger.error(f"Error updating progress: {e}")
return False
def complete_onboarding(self, user_id: str) -> bool:
"""Mark onboarding as complete in database."""
try:
db = SessionLocal()
try:
success = self.db_service.mark_onboarding_complete(user_id, db)
if success:
logger.info(f"Marked onboarding complete for user {user_id}")
return success
finally:
db.close()
except Exception as e:
logger.error(f"Error completing onboarding: {e}")
return False
def reset_onboarding(self, user_id: str) -> bool:
"""Reset onboarding progress in database."""
try:
db = SessionLocal()
try:
                # Reset to step 1, 0% progress; propagate failures from both writes
                success = self.db_service.update_step(user_id, 1, db)
                if success:
                    success = self.db_service.update_progress(user_id, 0.0, db)
                if success:
                    logger.info(f"Reset onboarding for user {user_id}")
                return success
finally:
db.close()
except Exception as e:
logger.error(f"Error resetting onboarding: {e}")
return False
def get_completion_data(self, user_id: str) -> Dict[str, Any]:
"""Get completion data for validation."""
try:
db = SessionLocal()
try:
# Get all relevant data for completion validation
session = self.db_service.get_session_by_user(user_id, db)
api_keys = self.db_service.get_api_keys(user_id, db)
website_analysis = self.db_service.get_website_analysis(user_id, db)
research_preferences = self.db_service.get_research_preferences(user_id, db)
persona_data = self.db_service.get_persona_data(user_id, db)
return {
"session": session,
"api_keys": api_keys,
"website_analysis": website_analysis,
"research_preferences": research_preferences,
"persona_data": persona_data
}
finally:
db.close()
except Exception as e:
logger.error(f"Error getting completion data: {e}")
return {}
# Global instance
_onboarding_progress_service = None
def get_onboarding_progress_service() -> OnboardingProgressService:
"""Get the global onboarding progress service instance."""
global _onboarding_progress_service
if _onboarding_progress_service is None:
_onboarding_progress_service = OnboardingProgressService()
return _onboarding_progress_service
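
The completion rule used in `get_onboarding_status` can be isolated as a pure predicate: a session counts as complete when either the final step was reached or progress hit 100%, which guards against partial writes where only one of the two fields persisted. A minimal sketch (the final-step constant of 6 is taken from the code above):

```python
# Pure-function sketch of the completion rule in get_onboarding_status.
# OR semantics deliberately tolerate partial writes: if either field
# reached its terminal value, the session is treated as complete.
def is_onboarding_complete(current_step: int, progress: float, final_step: int = 6) -> bool:
    return current_step >= final_step or progress >= 100.0

print(is_onboarding_complete(6, 42.0))   # final step reached, progress lagging
print(is_onboarding_complete(3, 100.0))  # progress persisted, step did not
print(is_onboarding_complete(5, 99.9))   # neither terminal value reached
```

Extracting the rule this way also makes it trivial to keep `OnboardingDatabaseService` and `OnboardingProgressService` agreeing on what "complete" means.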