Base code

backend/services/onboarding/README.md (new file, 204 lines)
@@ -0,0 +1,204 @@
# Onboarding Services Package

This package contains all onboarding-related services and utilities for ALwrity. All onboarding data is stored in the database with proper user isolation, replacing the previous file-based JSON storage system.

## Architecture

### Database-First Design
- **Primary Storage**: PostgreSQL database with proper foreign keys and relationships
- **User Isolation**: Each user's onboarding data is completely separate
- **No File Storage**: Removed all JSON file operations for production scalability
- **Local Development**: API keys still written to `.env` for developer convenience

### Service Structure

```
backend/services/onboarding/
├── __init__.py           # Package exports
├── database_service.py   # Core database operations
├── progress_service.py   # Progress tracking and step management
├── data_service.py       # Onboarding data extraction for AI personalization
├── api_key_manager.py    # API key management + progress tracking
└── README.md             # This documentation
```

## Services

### 1. OnboardingDatabaseService (`database_service.py`)
**Purpose**: Core database operations for onboarding data with user isolation.

**Key Features**:
- User-specific session management
- API key storage and retrieval
- Website analysis persistence
- Research preferences management
- Persona data storage
- Brand analysis support (feature-flagged)

**Main Methods**:
- `get_or_create_session(user_id)` - Get or create user session
- `save_api_key(user_id, provider, key)` - Store API keys
- `save_website_analysis(user_id, data)` - Store website analysis
- `save_research_preferences(user_id, prefs)` - Store research settings
- `save_persona_data(user_id, data)` - Store persona information

### 2. OnboardingProgressService (`progress_service.py`)
**Purpose**: High-level progress tracking and step management.

**Key Features**:
- Database-only progress tracking
- Step completion validation
- Progress percentage calculation
- Onboarding completion management

**Main Methods**:
- `get_onboarding_status(user_id)` - Get current status
- `update_step(user_id, step_number)` - Update current step
- `update_progress(user_id, percentage)` - Update progress percentage
- `complete_onboarding(user_id)` - Mark onboarding as complete

### 3. OnboardingDataService (`data_service.py`)
**Purpose**: Extract and use onboarding data for AI personalization.

**Key Features**:
- Personalized AI input generation
- Website analysis data extraction
- Research preferences integration
- Default fallback data

**Main Methods**:
- `get_personalized_ai_inputs(user_id)` - Generate personalized inputs
- `get_user_website_analysis(user_id)` - Get website data
- `get_user_research_preferences(user_id)` - Get research settings

### 4. OnboardingProgress + APIKeyManager (`api_key_manager.py`)
**Purpose**: Combined API key management and progress tracking with database persistence.

**Key Features**:
- Database-only progress persistence (no JSON files)
- API key management with environment integration
- Step-by-step progress tracking
- User-specific progress instances

**Main Classes**:
- `OnboardingProgress` - Progress tracking with database persistence
- `APIKeyManager` - API key management
- `StepData` - Individual step data structure
- `StepStatus` - Step status enumeration

## Database Schema

### Core Tables
- `onboarding_sessions` - User session tracking
- `api_keys` - User-specific API key storage
- `website_analyses` - Website analysis data
- `research_preferences` - User research settings
- `persona_data` - Generated persona information

### Relationships
- All data tables reference `onboarding_sessions.id`
- User isolation via `user_id` foreign key
- Proper cascade deletion and updates
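
The models live in `models/onboarding.py`; below is a condensed sketch of the expected relationships (illustrative only — column lists are abbreviated and the real definitions may differ):

```python
from sqlalchemy import Column, DateTime, Float, ForeignKey, Integer, String
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class OnboardingSession(Base):
    __tablename__ = "onboarding_sessions"
    id = Column(Integer, primary_key=True)
    user_id = Column(String, index=True)       # isolation key: one session per user
    current_step = Column(Integer, default=1)
    progress = Column(Float, default=0.0)
    started_at = Column(DateTime)

class APIKey(Base):
    __tablename__ = "api_keys"
    id = Column(Integer, primary_key=True)
    # Every data table hangs off the session row; deletes cascade from it
    session_id = Column(Integer, ForeignKey("onboarding_sessions.id", ondelete="CASCADE"))
    provider = Column(String)
    key = Column(String)

# website_analyses, research_preferences, and persona_data follow the same
# pattern: a primary key plus a session_id foreign key back to the session.
```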

## Usage Examples

### Basic Progress Tracking
```python
from services.onboarding import OnboardingProgress

# Get user-specific progress
progress = OnboardingProgress(user_id="user123")

# Mark step as completed
progress.mark_step_completed(1, {"api_keys": {"openai": "sk-..."}})

# Get progress summary
summary = progress.get_progress_summary()
```
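
In server code, prefer the cached accessor: `get_user_onboarding_progress` keeps one instance per user. The zero-argument `get_onboarding_progress()` global exists for backwards compatibility, but it has no `user_id`, so its `save_progress()` logs an error instead of persisting:

```python
from services.onboarding import get_user_onboarding_progress

progress = get_user_onboarding_progress("user123")  # cached per user
```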

### Database Operations
```python
from services.onboarding import OnboardingDatabaseService
from services.database import SessionLocal

db = SessionLocal()
service = OnboardingDatabaseService(db)

# Save API key
service.save_api_key("user123", "openai", "sk-...")

# Get website analysis
analysis = service.get_website_analysis("user123", db)
```

### Progress Service
```python
from services.onboarding import OnboardingProgressService

service = OnboardingProgressService()

# Get status
status = service.get_onboarding_status("user123")

# Update progress
service.update_step("user123", 2)
service.update_progress("user123", 50.0)
```
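
### Personalized AI Inputs
A minimal sketch (note that the `OnboardingDataService` methods are typed to take an integer `user_id`, unlike the string ids above; the service falls back to default inputs when no onboarding data exists):
```python
from services.onboarding import OnboardingDataService
from services.database import SessionLocal

db = SessionLocal()
service = OnboardingDataService(db)

# Returns _get_default_ai_inputs() if the user has no onboarding data
inputs = service.get_personalized_ai_inputs(42)
print(inputs["website_analysis"]["industry_focus"])
```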

## Migration from File-Based Storage

### What Was Removed
- JSON file operations (`.onboarding_progress*.json`)
- File-based progress persistence
- Dual persistence system (file + database)

### What Was Kept
- Database persistence (enhanced)
- Local development `.env` API key writing
- All existing functionality and APIs

### Benefits
- **Production Ready**: No ephemeral file storage
- **Scalable**: Database-backed with proper indexing
- **User Isolated**: Complete data separation
- **Maintainable**: Single source of truth

## Environment Variables

### Required
- Database connection (via `services.database`)
- User authentication system

### Optional
- `ENABLE_WEBSITE_BRAND_COLUMNS=true` - Enable brand analysis features
- `DEPLOY_ENV=local` - Enable local `.env` API key writing
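
Both flags are read permissively in the code (see `_brand_feature_enabled` in `database_service.py` and `save_progress` in `api_key_manager.py`); roughly:

```python
import os

# Brand analysis columns default to enabled; any of 1/true/yes/on turns them on
brand_enabled = os.getenv('ENABLE_WEBSITE_BRAND_COLUMNS', 'true').lower() in {'1', 'true', 'yes', 'on'}

# .env key writing defaults to local behavior when DEPLOY_ENV is unset
is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
```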

## Error Handling

All services include comprehensive error handling:
- Database connection failures
- User not found scenarios
- Invalid data validation
- Graceful fallbacks to defaults
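
In practice the read paths log the failure and return `None` (or `{}`), so callers branch on the result rather than wrapping every call in try/except:

```python
analysis = data_service.get_user_website_analysis(user_id)
if analysis is None:
    # The service already logged the reason; fall back to generic behavior
    analysis = {}
```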

## Performance Considerations

- Database queries are optimized with proper indexing
- User-specific caching where appropriate
- Minimal database calls through efficient service design
- Connection pooling via SQLAlchemy

## Testing

Each service can be tested independently:
- Unit tests for individual methods
- Integration tests with database
- Mock database sessions for isolated testing
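
A sketch of the mocked-session style (assumes pytest and `unittest.mock`; it exercises the documented contract that `save_api_key` returns `False` and rolls back on database errors):

```python
from unittest.mock import MagicMock
from sqlalchemy.exc import SQLAlchemyError

from services.onboarding import OnboardingDatabaseService

def test_save_api_key_returns_false_on_db_error():
    db = MagicMock()
    db.query.side_effect = SQLAlchemyError("connection lost")

    service = OnboardingDatabaseService(db)
    assert service.save_api_key("user123", "openai", "sk-test") is False
    db.rollback.assert_called()
```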

## Future Enhancements

- Real-time progress updates via WebSocket
- Progress analytics and reporting
- Bulk user operations
- Advanced validation rules
- Progress recovery mechanisms

backend/services/onboarding/__init__.py (new file, 35 lines)
@@ -0,0 +1,35 @@
"""
Onboarding Services Package

This package contains all onboarding-related services and utilities.
All onboarding data is stored in the database with proper user isolation.

Services:
- OnboardingDatabaseService: Core database operations for onboarding data
- OnboardingProgressService: Progress tracking and step management
- OnboardingDataService: Onboarding data extraction for AI personalization
- OnboardingProgress: Progress tracking with database persistence (from api_key_manager)

Architecture:
- Database-first: All data stored in PostgreSQL with proper foreign keys
- User isolation: Each user's data is completely separate
- No file storage: Removed all JSON file operations for production scalability
- Local development: API keys still written to .env for convenience
"""

# Import all public classes for easy access
from .database_service import OnboardingDatabaseService
from .progress_service import OnboardingProgressService
from .data_service import OnboardingDataService
from .api_key_manager import (
    OnboardingProgress,
    APIKeyManager,
    get_onboarding_progress,
    get_user_onboarding_progress,
    get_onboarding_progress_for_user,
)

__all__ = [
    'OnboardingDatabaseService',
    'OnboardingProgressService',
    'OnboardingDataService',
    'OnboardingProgress',
    'APIKeyManager',
    'get_onboarding_progress',
    'get_user_onboarding_progress',
    'get_onboarding_progress_for_user'
]

backend/services/onboarding/api_key_manager.py (new file, 495 lines)
@@ -0,0 +1,495 @@
"""
API Key Manager with Database-Only Onboarding Progress
Manages API keys and onboarding progress with database persistence only.
Removed all file-based JSON storage for production scalability.
"""

import os
from typing import Dict, Any, Optional, List
from datetime import datetime
from loguru import logger
from enum import Enum


class StepStatus(Enum):
    """Onboarding step status."""
    PENDING = "pending"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    SKIPPED = "skipped"
    FAILED = "failed"


class StepData:
    """Data structure for onboarding step."""

    def __init__(self, step_number: int, title: str, description: str, status: StepStatus = StepStatus.PENDING):
        self.step_number = step_number
        self.title = title
        self.description = description
        self.status = status
        self.completed_at: Optional[str] = None
        self.data: Optional[Dict[str, Any]] = None
        self.validation_errors: List[str] = []


class OnboardingProgress:
    """Manages onboarding progress with database persistence only."""

    def __init__(self, user_id: Optional[str] = None):
        self.steps = self._initialize_steps()
        self.current_step = 1
        self.started_at = datetime.now().isoformat()
        self.last_updated = datetime.now().isoformat()
        self.is_completed = False
        self.completed_at = None
        self.user_id = user_id  # user_id scopes all database reads/writes

        # Initialize database service for persistence
        try:
            from .database_service import OnboardingDatabaseService
            self.db_service = OnboardingDatabaseService()
            self.use_database = True
            logger.info(f"Database service initialized for user {user_id}")
        except Exception as e:
            logger.error(f"Database service not available: {e}")
            self.db_service = None
            self.use_database = False
            raise RuntimeError(f"Database service required but not available: {e}")

        # Load existing progress from database if available
        if self.use_database and self.user_id:
            self.load_progress_from_db()

    def _initialize_steps(self) -> List[StepData]:
        """Initialize the 6-step onboarding process."""
        return [
            StepData(1, "AI LLM Providers", "Configure AI language model providers", StepStatus.PENDING),
            StepData(2, "Website Analysis", "Set up website analysis and crawling", StepStatus.PENDING),
            StepData(3, "AI Research", "Configure AI research capabilities", StepStatus.PENDING),
            StepData(4, "Personalization", "Set up personalization features", StepStatus.PENDING),
            StepData(5, "Integrations", "Configure ALwrity integrations", StepStatus.PENDING),
            StepData(6, "Complete Setup", "Finalize and complete onboarding", StepStatus.PENDING)
        ]

    def get_step_data(self, step_number: int) -> Optional[StepData]:
        """Get data for a specific step."""
        for step in self.steps:
            if step.step_number == step_number:
                return step
        return None

    def mark_step_completed(self, step_number: int, data: Optional[Dict[str, Any]] = None):
        """Mark a step as completed."""
        logger.info(f"[mark_step_completed] Marking step {step_number} as completed")
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.COMPLETED
            step.completed_at = datetime.now().isoformat()
            step.data = data
            self.last_updated = datetime.now().isoformat()

            # Check if all steps are now completed
            all_completed = all(s.status in [StepStatus.COMPLETED, StepStatus.SKIPPED] for s in self.steps)

            if all_completed:
                # If all steps are completed, mark onboarding as complete
                self.is_completed = True
                self.completed_at = datetime.now().isoformat()
                self.current_step = len(self.steps)  # Set to last step number
                logger.info("[mark_step_completed] All steps completed, marking onboarding as complete")
            else:
                # Only increment current_step if there are more steps to go
                self.current_step = step_number + 1
                # Ensure current_step doesn't exceed total steps
                if self.current_step > len(self.steps):
                    self.current_step = len(self.steps)

            logger.info(f"[mark_step_completed] Step {step_number} completed, new current_step: {self.current_step}, is_completed: {self.is_completed}")
            self.save_progress()
            logger.info(f"Step {step_number} marked as completed")
        else:
            logger.error(f"[mark_step_completed] Step {step_number} not found")

    def mark_step_in_progress(self, step_number: int):
        """Mark a step as in progress."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.IN_PROGRESS
            self.current_step = step_number
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} marked as in progress")
        else:
            logger.error(f"Step {step_number} not found")

    def mark_step_skipped(self, step_number: int):
        """Mark a step as skipped."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.SKIPPED
            step.completed_at = datetime.now().isoformat()
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} marked as skipped")
        else:
            logger.error(f"Step {step_number} not found")

    def mark_step_failed(self, step_number: int, error_message: str):
        """Mark a step as failed with error message."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.FAILED
            step.validation_errors.append(error_message)
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.error(f"Step {step_number} marked as failed: {error_message}")
        else:
            logger.error(f"Step {step_number} not found")

    def get_progress_summary(self) -> Dict[str, Any]:
        """Get current progress summary."""
        completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
        skipped_count = sum(1 for s in self.steps if s.status == StepStatus.SKIPPED)
        failed_count = sum(1 for s in self.steps if s.status == StepStatus.FAILED)

        return {
            "total_steps": len(self.steps),
            "completed_steps": completed_count,
            "skipped_steps": skipped_count,
            "failed_steps": failed_count,
            "current_step": self.current_step,
            "is_completed": self.is_completed,
            "progress_percentage": (completed_count + skipped_count) / len(self.steps) * 100
        }

    def get_next_step(self) -> Optional[StepData]:
        """Get the next step to work on."""
        for step in self.steps:
            if step.status == StepStatus.PENDING:
                return step
        return None

    def get_completed_steps(self) -> List[StepData]:
        """Get all completed steps."""
        return [step for step in self.steps if step.status == StepStatus.COMPLETED]

    def get_failed_steps(self) -> List[StepData]:
        """Get all failed steps."""
        return [step for step in self.steps if step.status == StepStatus.FAILED]

    def reset_step(self, step_number: int):
        """Reset a step to pending status."""
        step = self.get_step_data(step_number)
        if step:
            step.status = StepStatus.PENDING
            step.completed_at = None
            step.data = None
            step.validation_errors = []
            self.last_updated = datetime.now().isoformat()
            self.save_progress()
            logger.info(f"Step {step_number} reset to pending")
        else:
            logger.error(f"Step {step_number} not found")

    def reset_all_steps(self):
        """Reset all steps to pending status."""
        for step in self.steps:
            step.status = StepStatus.PENDING
            step.completed_at = None
            step.data = None
            step.validation_errors = []

        self.current_step = 1
        self.is_completed = False
        self.completed_at = None
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info("All steps reset to pending")

    def complete_onboarding(self):
        """Mark onboarding as complete."""
        self.is_completed = True
        self.completed_at = datetime.now().isoformat()
        self.current_step = len(self.steps)
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info("Onboarding completed successfully")

    def save_progress(self):
        """Save progress to database."""
        if not self.use_database or not self.db_service or not self.user_id:
            logger.error("Cannot save progress: database service not available or user_id not set")
            return

        try:
            from services.database import SessionLocal
            db = SessionLocal()
            try:
                # Update session progress
                self.db_service.update_step(self.user_id, self.current_step, db)

                # Calculate progress percentage
                completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
                progress_pct = (completed_count / len(self.steps)) * 100
                self.db_service.update_progress(self.user_id, progress_pct, db)

                # Save step-specific data to appropriate tables
                for step in self.steps:
                    if step.status == StepStatus.COMPLETED and step.data:
                        if step.step_number == 1:  # API Keys
                            api_keys = step.data.get('api_keys', {})
                            for provider, key in api_keys.items():
                                if key:
                                    # Save to database (for user isolation in production)
                                    self.db_service.save_api_key(self.user_id, provider, key, db)

                                    # Also save to .env file ONLY in local development.
                                    # This allows local developers to have keys in .env for convenience.
                                    # In production, keys are fetched from the database per user.
                                    is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
                                    if is_local:
                                        try:
                                            from services.api_key_manager import APIKeyManager
                                            api_key_manager = APIKeyManager()
                                            api_key_manager.save_api_key(provider, key)
                                            logger.info(f"[LOCAL] API key for {provider} saved to .env file")
                                        except Exception as env_error:
                                            logger.warning(f"[LOCAL] Failed to save {provider} API key to .env file: {env_error}")
                                    else:
                                        logger.info(f"[PRODUCTION] API key for {provider} saved to database only (user: {self.user_id})")

                                    # Log database save confirmation
                                    logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
                        elif step.step_number == 2:  # Website Analysis
                            # Transform frontend data structure to match database schema.
                            # Frontend sends: { website: "url", analysis: {...} }
                            # Database expects: { website_url: "url", ...analysis (flattened) }
                            analysis_for_db = {}
                            if step.data:
                                # Extract website_url from 'website' or 'website_url' field
                                website_url = step.data.get('website') or step.data.get('website_url')
                                if website_url:
                                    analysis_for_db['website_url'] = website_url
                                # Flatten nested 'analysis' object if it exists
                                if 'analysis' in step.data and isinstance(step.data['analysis'], dict):
                                    analysis_for_db.update(step.data['analysis'])
                                # Also include any other top-level fields (except 'website' and 'analysis')
                                for key, value in step.data.items():
                                    if key not in ['website', 'website_url', 'analysis']:
                                        analysis_for_db[key] = value
                                # Ensure status is set
                                if 'status' not in analysis_for_db:
                                    analysis_for_db['status'] = 'completed'

                            self.db_service.save_website_analysis(self.user_id, analysis_for_db, db)
                            logger.info(f"✅ DATABASE: Website analysis saved to database for user {self.user_id}")
                        elif step.step_number == 3:  # Research Preferences
                            self.db_service.save_research_preferences(self.user_id, step.data, db)
                            logger.info(f"✅ DATABASE: Research preferences saved to database for user {self.user_id}")
                        elif step.step_number == 4:  # Persona Generation
                            self.db_service.save_persona_data(self.user_id, step.data, db)
                            logger.info(f"✅ DATABASE: Persona data saved to database for user {self.user_id}")

                logger.info(f"Progress saved to database for user {self.user_id}")
            finally:
                db.close()

        except Exception as e:
            logger.error(f"Error saving progress to database: {str(e)}")
            raise

    def load_progress_from_db(self):
        """Load progress from database."""
        if not self.use_database or not self.db_service or not self.user_id:
            logger.warning("Cannot load progress: database service not available or user_id not set")
            return

        try:
            from services.database import SessionLocal
            db = SessionLocal()
            try:
                # Get session data
                session = self.db_service.get_session_by_user(self.user_id, db)
                if not session:
                    logger.info(f"No existing onboarding session found for user {self.user_id}, starting fresh")
                    return

                # Restore session data
                self.current_step = session.current_step or 1
                self.started_at = session.started_at.isoformat() if session.started_at else self.started_at
                self.last_updated = session.last_updated.isoformat() if session.last_updated else self.last_updated
                self.is_completed = session.is_completed or False
                self.completed_at = session.completed_at.isoformat() if session.completed_at else None

                # Load step-specific data from database
                self._load_step_data_from_db(db)

                # Fix any corrupted state
                self._fix_corrupted_state()

                logger.info(f"Progress loaded from database for user {self.user_id}")
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error loading progress from database: {str(e)}")
            # Don't fail if database loading fails - start fresh

    def _load_step_data_from_db(self, db):
        """Load step-specific data from database tables."""
        try:
            # Per-step completion timestamps aren't stored in the database,
            # so the load time is used as an approximation below.

            # Load API keys (step 1)
            api_keys = self.db_service.get_api_keys(self.user_id, db)
            if api_keys:
                step1 = self.get_step_data(1)
                if step1:
                    step1.status = StepStatus.COMPLETED
                    step1.data = {'api_keys': api_keys}
                    step1.completed_at = datetime.now().isoformat()

            # Load website analysis (step 2)
            website_analysis = self.db_service.get_website_analysis(self.user_id, db)
            if website_analysis:
                step2 = self.get_step_data(2)
                if step2:
                    step2.status = StepStatus.COMPLETED
                    step2.data = website_analysis
                    step2.completed_at = datetime.now().isoformat()

            # Load research preferences (step 3)
            research_prefs = self.db_service.get_research_preferences(self.user_id, db)
            if research_prefs:
                step3 = self.get_step_data(3)
                if step3:
                    step3.status = StepStatus.COMPLETED
                    step3.data = research_prefs
                    step3.completed_at = datetime.now().isoformat()

            # Load persona data (step 4)
            persona_data = self.db_service.get_persona_data(self.user_id, db)
            if persona_data:
                step4 = self.get_step_data(4)
                if step4:
                    step4.status = StepStatus.COMPLETED
                    step4.data = persona_data
                    step4.completed_at = datetime.now().isoformat()

            logger.info("Step data loaded from database")
        except Exception as e:
            logger.error(f"Error loading step data from database: {str(e)}")

    def _fix_corrupted_state(self):
        """Fix any corrupted progress state."""
        # Check if all steps are completed
        all_steps_completed = all(s.status in [StepStatus.COMPLETED, StepStatus.SKIPPED] for s in self.steps)

        if all_steps_completed:
            self.is_completed = True
            self.completed_at = self.completed_at or datetime.now().isoformat()
            self.current_step = len(self.steps)
        else:
            # Find the first incomplete step
            for step in self.steps:
                if step.status == StepStatus.PENDING:
                    self.current_step = step.step_number
                    break


class APIKeyManager:
    """Manages API keys for different providers."""

    def __init__(self):
        self.api_keys = {}
        self._load_from_env()

    def _load_from_env(self):
        """Load API keys from environment variables."""
        providers = [
            'GEMINI_API_KEY',
            'HF_TOKEN',
            'TAVILY_API_KEY',
            'SERPER_API_KEY',
            'METAPHOR_API_KEY',
            'FIRECRAWL_API_KEY',
            'STABILITY_API_KEY',
            'WAVESPEED_API_KEY',
        ]

        for provider in providers:
            key = os.getenv(provider)
            if key:
                # Convert provider name to lowercase for consistency
                # (note: HF_TOKEN has no _API_KEY suffix, so it maps to 'hf_token')
                provider_name = provider.replace('_API_KEY', '').lower()
                self.api_keys[provider_name] = key
                logger.info(f"Loaded {provider_name} API key from environment")

    def get_api_key(self, provider: str) -> Optional[str]:
        """Get API key for a provider."""
        return self.api_keys.get(provider.lower())

    def save_api_key(self, provider: str, api_key: str):
        """Save API key to environment and memory."""
        provider_lower = provider.lower()
        self.api_keys[provider_lower] = api_key

        # Update environment variable
        env_var = f"{provider.upper()}_API_KEY"
        os.environ[env_var] = api_key

        logger.info(f"Saved {provider} API key")

    def has_api_key(self, provider: str) -> bool:
        """Check if API key exists for provider."""
        return provider.lower() in self.api_keys and bool(self.api_keys[provider.lower()])

    def get_all_keys(self) -> Dict[str, str]:
        """Get all API keys."""
        return self.api_keys.copy()

    def remove_api_key(self, provider: str):
        """Remove API key for provider."""
        provider_lower = provider.lower()
        if provider_lower in self.api_keys:
            del self.api_keys[provider_lower]

            # Remove from environment
            env_var = f"{provider.upper()}_API_KEY"
            if env_var in os.environ:
                del os.environ[env_var]

            logger.info(f"Removed {provider} API key")


# Global instances
_user_onboarding_progress_cache = {}


def get_user_onboarding_progress(user_id: str) -> OnboardingProgress:
    """Get user-specific onboarding progress instance."""
    global _user_onboarding_progress_cache
    # Sanitize the id so it is safe to use as a cache key
    safe_user_id = ''.join([c if c.isalnum() or c in ('-', '_') else '_' for c in str(user_id)])
    if safe_user_id in _user_onboarding_progress_cache:
        return _user_onboarding_progress_cache[safe_user_id]

    # Pass user_id to enable database persistence
    instance = OnboardingProgress(user_id=user_id)
    _user_onboarding_progress_cache[safe_user_id] = instance
    return instance


def get_onboarding_progress_for_user(user_id: str) -> OnboardingProgress:
    """Get user-specific onboarding progress instance (alias for compatibility)."""
    return get_user_onboarding_progress(user_id)


def get_onboarding_progress():
    """Get the global onboarding progress instance (legacy; without a user_id it cannot persist to the database)."""
    if not hasattr(get_onboarding_progress, '_instance'):
        get_onboarding_progress._instance = OnboardingProgress()
    return get_onboarding_progress._instance


def get_api_key_manager() -> APIKeyManager:
    """Get the global API key manager instance."""
    if not hasattr(get_api_key_manager, '_instance'):
        get_api_key_manager._instance = APIKeyManager()
    return get_api_key_manager._instance

backend/services/onboarding/data_service.py (new file, 291 lines)
@@ -0,0 +1,291 @@
"""
Onboarding Data Service
Extracts real user data from onboarding to personalize AI inputs
"""

from typing import Dict, Any, List, Optional
from sqlalchemy.orm import Session
from loguru import logger

from services.database import get_db_session
from models.onboarding import OnboardingSession, WebsiteAnalysis, ResearchPreferences


class OnboardingDataService:
    """Service to extract and use real onboarding data for AI personalization."""

    def __init__(self, db: Optional[Session] = None):
        """Initialize the onboarding data service."""
        self.db = db
        logger.info("OnboardingDataService initialized")

    def get_user_website_analysis(self, user_id: int) -> Optional[Dict[str, Any]]:
        """
        Get website analysis data for a specific user.

        Args:
            user_id: User ID to get data for

        Returns:
            Website analysis data or None if not found
        """
        try:
            session = self.db or get_db_session()

            # Find onboarding session for user
            onboarding_session = session.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()

            if not onboarding_session:
                logger.warning(f"No onboarding session found for user {user_id}")
                return None

            # Get website analysis for this session
            website_analysis = session.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == onboarding_session.id
            ).first()

            if not website_analysis:
                logger.warning(f"No website analysis found for user {user_id}")
                return None

            return website_analysis.to_dict()

        except Exception as e:
            logger.error(f"Error getting website analysis for user {user_id}: {str(e)}")
            return None

    def get_user_research_preferences(self, user_id: int) -> Optional[Dict[str, Any]]:
        """
        Get research preferences for a specific user.

        Args:
            user_id: User ID to get data for

        Returns:
            Research preferences data or None if not found
        """
        try:
            session = self.db or get_db_session()

            # Find onboarding session for user
            onboarding_session = session.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()

            if not onboarding_session:
                logger.warning(f"No onboarding session found for user {user_id}")
                return None

            # Get research preferences for this session
            research_prefs = session.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == onboarding_session.id
            ).first()

            if not research_prefs:
                logger.warning(f"No research preferences found for user {user_id}")
                return None

            return research_prefs.to_dict()

        except Exception as e:
            logger.error(f"Error getting research preferences for user {user_id}: {str(e)}")
            return None

    def get_personalized_ai_inputs(self, user_id: int) -> Dict[str, Any]:
        """
        Get personalized AI inputs based on user's onboarding data.

        Args:
            user_id: User ID to get personalized data for

        Returns:
            Personalized data for AI analysis
        """
        try:
            logger.info(f"Getting personalized AI inputs for user {user_id}")

            # Get website analysis and research preferences
            website_analysis = self.get_user_website_analysis(user_id)
            research_prefs = self.get_user_research_preferences(user_id)

            if not website_analysis:
                logger.warning(f"No onboarding data found for user {user_id}, using defaults")
                return self._get_default_ai_inputs()

            # Extract real data from website analysis
            writing_style = website_analysis.get('writing_style', {})
            target_audience = website_analysis.get('target_audience', {})
            content_type = website_analysis.get('content_type', {})

            # Build personalized AI inputs
            personalized_inputs = {
                "website_analysis": {
                    "website_url": website_analysis.get('website_url', ''),
                    "content_types": self._extract_content_types(content_type),
                    "writing_style": writing_style.get('tone', 'professional'),
                    "target_audience": target_audience.get('demographics', ['professionals']),
                    "industry_focus": target_audience.get('industry_focus', 'general'),
                    "expertise_level": target_audience.get('expertise_level', 'intermediate')
                },
                "competitor_analysis": {
                    "top_performers": self._generate_competitor_suggestions(target_audience),
                    "industry": target_audience.get('industry_focus', 'general'),
                    "target_demographics": target_audience.get('demographics', [])
                },
                "gap_analysis": {
                    "content_gaps": self._identify_content_gaps(content_type, writing_style),
                    "target_keywords": self._generate_target_keywords(target_audience),
                    "content_opportunities": self._identify_opportunities(content_type)
                },
                "keyword_analysis": {
                    "high_value_keywords": self._generate_high_value_keywords(target_audience),
                    "content_topics": self._generate_content_topics(content_type),
                    "search_intent": self._analyze_search_intent(target_audience)
                }
            }

            # Add research preferences if available
            if research_prefs:
                personalized_inputs["research_preferences"] = {
                    "research_depth": research_prefs.get('research_depth', 'Standard'),
                    "content_types": research_prefs.get('content_types', []),
                    "auto_research": research_prefs.get('auto_research', True),
                    "factual_content": research_prefs.get('factual_content', True)
                }

            logger.info(f"✅ Generated personalized AI inputs for user {user_id}")
            return personalized_inputs

        except Exception as e:
            logger.error(f"Error generating personalized AI inputs for user {user_id}: {str(e)}")
            return self._get_default_ai_inputs()

    def _extract_content_types(self, content_type: Dict[str, Any]) -> List[str]:
        """Extract content types from content type analysis."""
        types = []
        if content_type.get('primary_type'):
            types.append(content_type['primary_type'])
        if content_type.get('secondary_types'):
            types.extend(content_type['secondary_types'])
        return types if types else ['blog', 'article']

    def _generate_competitor_suggestions(self, target_audience: Dict[str, Any]) -> List[str]:
        """Generate competitor suggestions based on target audience."""
        industry = target_audience.get('industry_focus', 'general')

        # Generate industry-specific competitors
        if industry == 'technology':
            return ['techcrunch.com', 'wired.com', 'theverge.com']
        elif industry == 'marketing':
            return ['hubspot.com', 'marketingland.com', 'moz.com']
        else:
            return ['competitor1.com', 'competitor2.com', 'competitor3.com']

    def _identify_content_gaps(self, content_type: Dict[str, Any], writing_style: Dict[str, Any]) -> List[str]:
        """Identify content gaps based on current content type and style."""
        gaps = []
        primary_type = content_type.get('primary_type', 'blog')

        if primary_type == 'blog':
            gaps.extend(['Video tutorials', 'Case studies', 'Infographics'])
        elif primary_type == 'video':
            gaps.extend(['Blog posts', 'Whitepapers', 'Webinars'])

        # Add style-based gaps
        tone = writing_style.get('tone', 'professional')
        if tone == 'professional':
            gaps.append('Personal stories')
        elif tone == 'casual':
            gaps.append('Expert interviews')

        return gaps

    def _generate_target_keywords(self, target_audience: Dict[str, Any]) -> List[str]:
        """Generate target keywords based on audience analysis."""
        industry = target_audience.get('industry_focus', 'general')

        if industry == 'technology':
            return ['AI tools', 'Digital transformation', 'Tech trends']
        elif industry == 'marketing':
            return ['Content marketing', 'SEO strategies', 'Social media']
        else:
            return ['Industry insights', 'Best practices', 'Expert tips']

    def _identify_opportunities(self, content_type: Dict[str, Any]) -> List[str]:
        """Identify content opportunities based on current content type."""
        opportunities = []
        purpose = content_type.get('purpose', 'informational')

        if purpose == 'informational':
            opportunities.extend(['How-to guides', 'Tutorials', 'Educational content'])
        elif purpose == 'promotional':
            opportunities.extend(['Case studies', 'Testimonials', 'Success stories'])

        return opportunities

    def _generate_high_value_keywords(self, target_audience: Dict[str, Any]) -> List[str]:
        """Generate high-value keywords based on audience analysis."""
        industry = target_audience.get('industry_focus', 'general')

        if industry == 'technology':
            return ['AI marketing', 'Content automation', 'Digital strategy']
        elif industry == 'marketing':
            return ['Content marketing', 'SEO optimization', 'Social media strategy']
        else:
            return ['Industry trends', 'Best practices', 'Expert insights']

    def _generate_content_topics(self, content_type: Dict[str, Any]) -> List[str]:
        """Generate content topics based on content type analysis."""
        topics = []
        primary_type = content_type.get('primary_type', 'blog')

        if primary_type == 'blog':
            topics.extend(['Industry trends', 'How-to guides', 'Expert insights'])
        elif primary_type == 'video':
            topics.extend(['Tutorials', 'Product demos', 'Expert interviews'])

        return topics

    def _analyze_search_intent(self, target_audience: Dict[str, Any]) -> Dict[str, Any]:
        """Analyze search intent based on target audience."""
        expertise = target_audience.get('expertise_level', 'intermediate')

        if expertise == 'beginner':
            return {'intent': 'educational', 'focus': 'basic concepts'}
        elif expertise == 'intermediate':
            return {'intent': 'practical', 'focus': 'implementation'}
        else:
            return {'intent': 'advanced', 'focus': 'strategic insights'}

    def _get_default_ai_inputs(self) -> Dict[str, Any]:
        """Get default AI inputs when no onboarding data is available."""
        return {
            "website_analysis": {
                "content_types": ["blog", "video", "social"],
                "writing_style": "professional",
                "target_audience": ["professionals"],
                "industry_focus": "general",
                "expertise_level": "intermediate"
            },
            "competitor_analysis": {
                "top_performers": ["competitor1.com", "competitor2.com"],
                "industry": "general",
                "target_demographics": ["professionals"]
            },
            "gap_analysis": {
                "content_gaps": ["AI content", "Video tutorials", "Case studies"],
                "target_keywords": ["Industry insights", "Best practices"],
                "content_opportunities": ["How-to guides", "Tutorials"]
            },
            "keyword_analysis": {
                "high_value_keywords": ["AI marketing", "Content automation", "Digital strategy"],
                "content_topics": ["Industry trends", "Expert insights"],
                "search_intent": {"intent": "practical", "focus": "implementation"}
            }
        }

backend/services/onboarding/database_service.py (new file, 666 lines)
@@ -0,0 +1,666 @@
"""
Onboarding Database Service
Provides database-backed storage for onboarding progress with user isolation.
This replaces the JSON file-based storage with proper database persistence.
"""

from typing import Dict, Any, Optional, List
import os
import json
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy import text

from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData
from services.database import get_db


class OnboardingDatabaseService:
    """Database service for onboarding with user isolation."""

    def __init__(self, db: Session = None):
        """Initialize with optional database session."""
        self.db = db
        # Cache for schema feature detection
        self._brand_cols_checked: bool = False
        self._brand_cols_available: bool = False
        self._research_persona_cols_checked: bool = False
        self._research_persona_cols_available: bool = False

    # --- Feature flags and schema detection helpers ---
    def _brand_feature_enabled(self) -> bool:
        """Check if writing brand-related columns is enabled via env flag."""
        return os.getenv('ENABLE_WEBSITE_BRAND_COLUMNS', 'true').lower() in {'1', 'true', 'yes', 'on'}

    def _ensure_research_persona_columns(self, session_db: Session) -> None:
        """Ensure research_persona columns exist in persona_data table (runtime migration)."""
        if self._research_persona_cols_checked:
            return

        try:
            # Check if columns exist using PRAGMA (SQLite) or a probe query (PostgreSQL)
            db_url = str(session_db.bind.url) if session_db.bind else ""

            if 'sqlite' in db_url.lower():
                # SQLite: Use PRAGMA to check columns
                result = session_db.execute(text("PRAGMA table_info(persona_data)"))
                cols = {row[1] for row in result}  # Column name is at index 1

                if 'research_persona' not in cols:
                    logger.info("Adding missing column research_persona to persona_data table")
                    session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona JSON"))
                    session_db.commit()

                if 'research_persona_generated_at' not in cols:
                    logger.info("Adding missing column research_persona_generated_at to persona_data table")
                    session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona_generated_at TIMESTAMP"))
                    session_db.commit()

                self._research_persona_cols_available = True
            else:
                # PostgreSQL: Try to query the columns (will fail if they don't exist)
                try:
                    session_db.execute(text("SELECT research_persona, research_persona_generated_at FROM persona_data LIMIT 0"))
                    self._research_persona_cols_available = True
                except Exception:
                    # Columns don't exist, add them
                    logger.info("Adding missing columns research_persona and research_persona_generated_at to persona_data table")
                    try:
                        session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona JSONB"))
                        session_db.execute(text("ALTER TABLE persona_data ADD COLUMN research_persona_generated_at TIMESTAMP"))
                        session_db.commit()
                        self._research_persona_cols_available = True
                    except Exception as alter_err:
                        logger.error(f"Failed to add research_persona columns: {alter_err}")
                        session_db.rollback()
                        raise
        except Exception as e:
            logger.error(f"Error ensuring research_persona columns: {e}")
            session_db.rollback()
            raise
        finally:
            self._research_persona_cols_checked = True

    def _ensure_brand_column_detection(self, session_db: Session) -> None:
        """Detect at runtime whether brand columns exist and cache the result."""
        if self._brand_cols_checked:
            return
        try:
            # This works across SQLite/Postgres; LIMIT 0 avoids scanning
            session_db.execute(text('SELECT brand_analysis, content_strategy_insights FROM website_analyses LIMIT 0'))
            self._brand_cols_available = True
        except Exception:
            self._brand_cols_available = False
        finally:
            self._brand_cols_checked = True

    def _maybe_update_brand_columns(self, session_db: Session, session_id: int, brand_analysis: Any, content_strategy_insights: Any) -> None:
        """Safely update brand columns using raw SQL if feature enabled and columns exist."""
        if not self._brand_feature_enabled():
            return
        self._ensure_brand_column_detection(session_db)
        if not self._brand_cols_available:
            return
        try:
            session_db.execute(
                text('''
                    UPDATE website_analyses
                    SET brand_analysis = :brand_analysis,
                        content_strategy_insights = :content_strategy_insights
                    WHERE session_id = :session_id
                '''),
                {
                    'brand_analysis': json.dumps(brand_analysis) if brand_analysis is not None else None,
                    'content_strategy_insights': json.dumps(content_strategy_insights) if content_strategy_insights is not None else None,
                    'session_id': session_id,
                }
            )
        except Exception as e:
            logger.warning(f"Skipped updating brand columns (not critical): {e}")

    def _maybe_attach_brand_columns(self, session_db: Session, session_id: int, result: Dict[str, Any]) -> None:
        """Optionally read brand columns and attach to result if available."""
        if not self._brand_feature_enabled():
            return
        self._ensure_brand_column_detection(session_db)
        if not self._brand_cols_available:
            return
        try:
            row = session_db.execute(
                text('''
                    SELECT brand_analysis, content_strategy_insights
                    FROM website_analyses WHERE session_id = :session_id LIMIT 1
                '''),
                {'session_id': session_id}
            ).mappings().first()
            if row:
                brand = row.get('brand_analysis')
                insights = row.get('content_strategy_insights')
                # If stored as TEXT in SQLite, try to parse JSON
                if isinstance(brand, str):
                    try:
                        brand = json.loads(brand)
                    except Exception:
                        pass
                if isinstance(insights, str):
                    try:
                        insights = json.loads(insights)
                    except Exception:
                        pass
                result['brand_analysis'] = brand
                result['content_strategy_insights'] = insights
        except Exception as e:
            logger.warning(f"Skipped reading brand columns (not critical): {e}")

    def get_or_create_session(self, user_id: str, db: Session = None) -> OnboardingSession:
        """Get existing onboarding session or create new one for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            # Try to get existing session for this user
            session = session_db.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()

            if session:
                logger.info(f"Found existing onboarding session for user {user_id}")
                return session

            # Create new session
            session = OnboardingSession(
                user_id=user_id,
                current_step=1,
                progress=0.0,
                started_at=datetime.now()
            )
            session_db.add(session)
            session_db.commit()
            session_db.refresh(session)

            logger.info(f"Created new onboarding session for user {user_id}")
            return session

        except SQLAlchemyError as e:
            logger.error(f"Database error in get_or_create_session: {e}")
            session_db.rollback()
            raise

    def get_session_by_user(self, user_id: str, db: Session = None) -> Optional[OnboardingSession]:
        """Get onboarding session for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            return session_db.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
        except SQLAlchemyError as e:
            logger.error(f"Error getting session: {e}")
            return None

    def update_step(self, user_id: str, step_number: int, db: Session = None) -> bool:
        """Update current step for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_or_create_session(user_id, session_db)
            session.current_step = step_number
            session.updated_at = datetime.now()
            session_db.commit()

            logger.info(f"Updated user {user_id} to step {step_number}")
            return True

        except SQLAlchemyError as e:
            logger.error(f"Error updating step: {e}")
            session_db.rollback()
            return False

    def update_progress(self, user_id: str, progress: float, db: Session = None) -> bool:
        """Update progress percentage for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_or_create_session(user_id, session_db)
            session.progress = progress
            session.updated_at = datetime.now()
            session_db.commit()

            logger.info(f"Updated user {user_id} progress to {progress}%")
            return True

        except SQLAlchemyError as e:
            logger.error(f"Error updating progress: {e}")
            session_db.rollback()
            return False

    def save_api_key(self, user_id: str, provider: str, api_key: str, db: Session = None) -> bool:
        """Save API key for user with isolation."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            # Get user's onboarding session
            session = self.get_or_create_session(user_id, session_db)

            # Check if key already exists for this provider and session
            existing_key = session_db.query(APIKey).filter(
                APIKey.session_id == session.id,
                APIKey.provider == provider
            ).first()

            if existing_key:
                # Update existing key
                existing_key.key = api_key
                existing_key.updated_at = datetime.now()
                logger.info(f"Updated {provider} API key for user {user_id}")
            else:
                # Create new key
                new_key = APIKey(
                    session_id=session.id,
                    provider=provider,
                    key=api_key
                )
                session_db.add(new_key)
                logger.info(f"Created new {provider} API key for user {user_id}")

            session_db.commit()
            return True

        except SQLAlchemyError as e:
            logger.error(f"Error saving API key: {e}")
            session_db.rollback()
            return False

    def get_api_keys(self, user_id: str, db: Session = None) -> Dict[str, str]:
        """Get all API keys for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return {}

            keys = session_db.query(APIKey).filter(
                APIKey.session_id == session.id
            ).all()

            return {key.provider: key.key for key in keys}

        except SQLAlchemyError as e:
            logger.error(f"Error getting API keys: {e}")
            return {}
def save_website_analysis(self, user_id: str, analysis_data: Dict[str, Any], db: Session = None) -> bool:
|
||||
"""Save website analysis for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
# Normalize payload. Step 2 sometimes sends { website, analysis: {...} }
|
||||
# while DB expects flattened fields. Support both shapes.
|
||||
incoming = analysis_data or {}
|
||||
nested = incoming.get('analysis') if isinstance(incoming.get('analysis'), dict) else None
|
||||
normalized = {
|
||||
'website_url': incoming.get('website') or incoming.get('website_url') or '',
|
||||
'writing_style': (nested or incoming).get('writing_style'),
|
||||
'content_characteristics': (nested or incoming).get('content_characteristics'),
|
||||
'target_audience': (nested or incoming).get('target_audience'),
|
||||
'content_type': (nested or incoming).get('content_type'),
|
||||
'recommended_settings': (nested or incoming).get('recommended_settings'),
|
||||
'brand_analysis': (nested or incoming).get('brand_analysis'),
|
||||
'content_strategy_insights': (nested or incoming).get('content_strategy_insights'),
|
||||
'crawl_result': (nested or incoming).get('crawl_result'),
|
||||
'style_patterns': (nested or incoming).get('style_patterns'),
|
||||
'style_guidelines': (nested or incoming).get('style_guidelines'),
|
||||
'status': (nested or incoming).get('status', incoming.get('status', 'completed')),
|
||||
}
|
||||
|
||||
# Check if analysis already exists
|
||||
existing = session_db.query(WebsiteAnalysis).filter(
|
||||
WebsiteAnalysis.session_id == session.id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing - only update website_url if normalized value is not empty
|
||||
# This prevents overwriting a valid URL with an empty string when step.data
|
||||
# doesn't include the website field
|
||||
normalized_url = normalized.get('website_url', '').strip() if normalized.get('website_url') else ''
|
||||
if normalized_url:
|
||||
existing.website_url = normalized_url
|
||||
# If normalized_url is empty, keep existing.website_url unchanged
|
||||
existing.writing_style = normalized.get('writing_style')
|
||||
existing.content_characteristics = normalized.get('content_characteristics')
|
||||
existing.target_audience = normalized.get('target_audience')
|
||||
existing.content_type = normalized.get('content_type')
|
||||
existing.recommended_settings = normalized.get('recommended_settings')
|
||||
existing.crawl_result = normalized.get('crawl_result')
|
||||
existing.style_patterns = normalized.get('style_patterns')
|
||||
existing.style_guidelines = normalized.get('style_guidelines')
|
||||
existing.status = normalized.get('status', 'completed')
|
||||
existing.updated_at = datetime.now()
|
||||
logger.info(f"Updated website analysis for user {user_id}")
|
||||
else:
|
||||
# Create new
|
||||
analysis = WebsiteAnalysis(
|
||||
session_id=session.id,
|
||||
website_url=normalized.get('website_url', ''),
|
||||
writing_style=normalized.get('writing_style'),
|
||||
content_characteristics=normalized.get('content_characteristics'),
|
||||
target_audience=normalized.get('target_audience'),
|
||||
content_type=normalized.get('content_type'),
|
||||
recommended_settings=normalized.get('recommended_settings'),
|
||||
crawl_result=normalized.get('crawl_result'),
|
||||
style_patterns=normalized.get('style_patterns'),
|
||||
style_guidelines=normalized.get('style_guidelines'),
|
||||
status=normalized.get('status', 'completed')
|
||||
)
|
||||
session_db.add(analysis)
|
||||
logger.info(f"Created website analysis for user {user_id}")
|
||||
|
||||
session_db.commit()
|
||||
|
||||
# Optional brand column update via raw SQL (feature-flagged)
|
||||
self._maybe_update_brand_columns(
|
||||
session_db=session_db,
|
||||
session_id=session.id,
|
||||
brand_analysis=normalized.get('brand_analysis'),
|
||||
content_strategy_insights=normalized.get('content_strategy_insights')
|
||||
)
|
||||
session_db.commit()
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error saving website analysis: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
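    # Illustrative example of the two payload shapes save_website_analysis accepts
    # (values are hypothetical); both normalize to the same flattened row:
    #
    #   svc.save_website_analysis("user_123", {
    #       "website": "https://example.com",
    #       "analysis": {"writing_style": {"tone": "casual"}, "status": "completed"},
    #   }, db)
    #
    #   svc.save_website_analysis("user_123", {
    #       "website_url": "https://example.com",
    #       "writing_style": {"tone": "casual"},
    #       "status": "completed",
    #   }, db)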
    def get_website_analysis(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
        """Get website analysis for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None

            analysis = session_db.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()

            result = analysis.to_dict() if analysis else None
            if result:
                # Optionally include brand fields without touching the ORM mapping
                self._maybe_attach_brand_columns(session_db, session.id, result)
            return result

        except SQLAlchemyError as e:
            logger.error(f"Error getting website analysis: {e}")
            return None
    def save_research_preferences(self, user_id: str, preferences: Dict[str, Any], db: Session = None) -> bool:
        """Save research preferences for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_or_create_session(user_id, session_db)

            # Check if preferences already exist
            existing = session_db.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == session.id
            ).first()

            if existing:
                # Update existing. Core settings keep their current values when a key
                # is absent; analysis-derived fields are overwritten (None if absent),
                # mirroring save_website_analysis.
                existing.research_depth = preferences.get('research_depth', existing.research_depth)
                existing.content_types = preferences.get('content_types', existing.content_types)
                existing.auto_research = preferences.get('auto_research', existing.auto_research)
                existing.factual_content = preferences.get('factual_content', existing.factual_content)
                existing.writing_style = preferences.get('writing_style')
                existing.content_characteristics = preferences.get('content_characteristics')
                existing.target_audience = preferences.get('target_audience')
                existing.recommended_settings = preferences.get('recommended_settings')
                existing.updated_at = datetime.now()
                logger.info(f"Updated research preferences for user {user_id}")
            else:
                # Create new
                prefs = ResearchPreferences(
                    session_id=session.id,
                    research_depth=preferences.get('research_depth', 'standard'),
                    content_types=preferences.get('content_types', []),
                    auto_research=preferences.get('auto_research', True),
                    factual_content=preferences.get('factual_content', True),
                    writing_style=preferences.get('writing_style'),
                    content_characteristics=preferences.get('content_characteristics'),
                    target_audience=preferences.get('target_audience'),
                    recommended_settings=preferences.get('recommended_settings')
                )
                session_db.add(prefs)
                logger.info(f"Created research preferences for user {user_id}")

            session_db.commit()
            return True

        except SQLAlchemyError as e:
            logger.error(f"Error saving research preferences: {e}")
            session_db.rollback()
            return False
    def save_persona_data(self, user_id: str, persona_data: Dict[str, Any], db: Session = None) -> bool:
        """Save persona data for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_or_create_session(user_id, session_db)

            # Check if persona data already exists for this user
            existing = session_db.query(PersonaData).filter(
                PersonaData.session_id == session.id
            ).first()

            if existing:
                # Update existing persona data
                existing.core_persona = persona_data.get('corePersona')
                existing.platform_personas = persona_data.get('platformPersonas')
                existing.quality_metrics = persona_data.get('qualityMetrics')
                existing.selected_platforms = persona_data.get('selectedPlatforms', [])
                existing.updated_at = datetime.utcnow()
                logger.info(f"Updated persona data for user {user_id}")
            else:
                # Create new persona data record
                persona = PersonaData(
                    session_id=session.id,
                    core_persona=persona_data.get('corePersona'),
                    platform_personas=persona_data.get('platformPersonas'),
                    quality_metrics=persona_data.get('qualityMetrics'),
                    selected_platforms=persona_data.get('selectedPlatforms', [])
                )
                session_db.add(persona)
                logger.info(f"Created persona data for user {user_id}")

            session_db.commit()
            return True

        except SQLAlchemyError as e:
            logger.error(f"Error saving persona data: {e}")
            session_db.rollback()
            return False
    def get_research_preferences(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
        """Get research preferences for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None

            prefs = session_db.query(ResearchPreferences).filter(
                ResearchPreferences.session_id == session.id
            ).first()

            return prefs.to_dict() if prefs else None

        except SQLAlchemyError as e:
            logger.error(f"Error getting research preferences: {e}")
            return None
    def get_competitor_analysis(self, user_id: str, db: Session = None) -> Optional[List[Dict[str, Any]]]:
        """Get competitor analysis data for user from onboarding."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            from models.onboarding import CompetitorAnalysis

            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None

            # Query the CompetitorAnalysis table
            competitor_records = session_db.query(CompetitorAnalysis).filter(
                CompetitorAnalysis.session_id == session.id
            ).all()

            if not competitor_records:
                return None

            # Convert to a list of dicts
            competitors = []
            for record in competitor_records:
                analysis_data = record.analysis_data or {}
                competitors.append({
                    "url": record.competitor_url,
                    "domain": record.competitor_domain or record.competitor_url,
                    "title": analysis_data.get("title", record.competitor_domain or ""),
                    "summary": analysis_data.get("summary", ""),
                    "relevance_score": analysis_data.get("relevance_score", 0.5),
                    "highlights": analysis_data.get("highlights", []),
                    "favicon": analysis_data.get("favicon"),
                    "image": analysis_data.get("image"),
                    "published_date": analysis_data.get("published_date"),
                    "author": analysis_data.get("author"),
                    "competitive_insights": analysis_data.get("competitive_analysis", {}),
                    "content_insights": analysis_data.get("content_insights", {})
                })

            return competitors

        except SQLAlchemyError as e:
            logger.error(f"Error getting competitor analysis: {e}")
            return None
    def get_persona_data(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
        """Get persona data for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        # Ensure research_persona columns exist before querying
        self._ensure_research_persona_columns(session_db)

        try:
            session = self.get_session_by_user(user_id, session_db)
            if not session:
                return None

            persona = session_db.query(PersonaData).filter(
                PersonaData.session_id == session.id
            ).first()

            if not persona:
                return None

            # Return persona data in the expected format
            return {
                'corePersona': persona.core_persona,
                'platformPersonas': persona.platform_personas,
                'qualityMetrics': persona.quality_metrics,
                'selectedPlatforms': persona.selected_platforms
            }

        except SQLAlchemyError as e:
            logger.error(f"Error getting persona data: {e}")
            return None
    def mark_onboarding_complete(self, user_id: str, db: Session = None) -> bool:
        """Mark onboarding as complete for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_or_create_session(user_id, session_db)
            session.current_step = 6  # Final step
            session.progress = 100.0
            session.updated_at = datetime.now()
            session_db.commit()

            logger.info(f"Marked onboarding complete for user {user_id}")
            return True

        except SQLAlchemyError as e:
            logger.error(f"Error marking onboarding complete: {e}")
            session_db.rollback()
            return False
    def get_onboarding_status(self, user_id: str, db: Session = None) -> Dict[str, Any]:
        """Get comprehensive onboarding status for user."""
        session_db = db or self.db
        if not session_db:
            raise ValueError("Database session required")

        try:
            session = self.get_session_by_user(user_id, session_db)

            if not session:
                # User hasn't started onboarding yet
                return {
                    "is_completed": False,
                    "current_step": 1,
                    "progress": 0.0,
                    "started_at": None,
                    "updated_at": None
                }

            return {
                "is_completed": session.current_step >= 6 and session.progress >= 100.0,
                "current_step": session.current_step,
                "progress": session.progress,
                "started_at": session.started_at.isoformat() if session.started_at else None,
                "updated_at": session.updated_at.isoformat() if session.updated_at else None
            }

        except SQLAlchemyError as e:
            logger.error(f"Error getting onboarding status: {e}")
            return {
                "is_completed": False,
                "current_step": 1,
                "progress": 0.0,
                "started_at": None,
                "updated_at": None
            }
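
# Minimal usage sketch (illustrative; "user_123" is a placeholder): every method
# takes an explicit Session, falling back to self.db, and raises ValueError when
# neither is available, so callers typically own the session lifecycle:
#
#   from services.database import SessionLocal
#
#   db = SessionLocal()
#   try:
#       svc = OnboardingDatabaseService()
#       keys = svc.get_api_keys("user_123", db)
#       status = svc.get_onboarding_status("user_123", db)
#   finally:
#       db.close()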
163
backend/services/onboarding/progress_service.py
Normal file
@@ -0,0 +1,163 @@
"""
Database-only Onboarding Progress Service
Replaces file-based progress tracking with a database-only implementation.
"""

from typing import Dict, Any, List, Optional
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError

from services.database import SessionLocal
from .database_service import OnboardingDatabaseService


class OnboardingProgressService:
    """Database-only onboarding progress management."""

    def __init__(self):
        self.db_service = OnboardingDatabaseService()

    def get_onboarding_status(self, user_id: str) -> Dict[str, Any]:
        """Get current onboarding status from database only."""
        try:
            db = SessionLocal()
            try:
                # Get session data
                session = self.db_service.get_session_by_user(user_id, db)
                if not session:
                    return {
                        "is_completed": False,
                        "current_step": 1,
                        "completion_percentage": 0.0,
                        "started_at": None,
                        "last_updated": None,
                        "completed_at": None
                    }

                # Check if onboarding is complete.
                # Consider it complete if either the final step is reached OR progress hit 100%.
                # This guards against partial writes where one field persisted but the other didn't.
                is_completed = (session.current_step >= 6) or (session.progress >= 100.0)

                return {
                    "is_completed": is_completed,
                    "current_step": session.current_step,
                    "completion_percentage": session.progress,
                    "started_at": session.started_at.isoformat() if session.started_at else None,
                    "last_updated": session.updated_at.isoformat() if session.updated_at else None,
                    "completed_at": session.updated_at.isoformat() if is_completed else None
                }

            finally:
                db.close()

        except Exception as e:
            logger.error(f"Error getting onboarding status: {e}")
            return {
                "is_completed": False,
                "current_step": 1,
                "completion_percentage": 0.0,
                "started_at": None,
                "last_updated": None,
                "completed_at": None
            }

    def update_step(self, user_id: str, step_number: int) -> bool:
        """Update current step in database."""
        try:
            db = SessionLocal()
            try:
                success = self.db_service.update_step(user_id, step_number, db)
                if success:
                    logger.info(f"Updated user {user_id} to step {step_number}")
                return success
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error updating step: {e}")
            return False

    def update_progress(self, user_id: str, progress_percentage: float) -> bool:
        """Update progress percentage in database."""
        try:
            db = SessionLocal()
            try:
                success = self.db_service.update_progress(user_id, progress_percentage, db)
                if success:
                    logger.info(f"Updated user {user_id} progress to {progress_percentage}%")
                return success
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error updating progress: {e}")
            return False

    def complete_onboarding(self, user_id: str) -> bool:
        """Mark onboarding as complete in database."""
        try:
            db = SessionLocal()
            try:
                success = self.db_service.mark_onboarding_complete(user_id, db)
                if success:
                    logger.info(f"Marked onboarding complete for user {user_id}")
                return success
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error completing onboarding: {e}")
            return False

    def reset_onboarding(self, user_id: str) -> bool:
        """Reset onboarding progress in database."""
        try:
            db = SessionLocal()
            try:
                # Reset to step 1, 0% progress
                success = self.db_service.update_step(user_id, 1, db)
                if success:
                    self.db_service.update_progress(user_id, 0.0, db)
                    logger.info(f"Reset onboarding for user {user_id}")
                return success
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error resetting onboarding: {e}")
            return False

    def get_completion_data(self, user_id: str) -> Dict[str, Any]:
        """Get completion data for validation."""
        try:
            db = SessionLocal()
            try:
                # Get all relevant data for completion validation
                session = self.db_service.get_session_by_user(user_id, db)
                api_keys = self.db_service.get_api_keys(user_id, db)
                website_analysis = self.db_service.get_website_analysis(user_id, db)
                research_preferences = self.db_service.get_research_preferences(user_id, db)
                persona_data = self.db_service.get_persona_data(user_id, db)

                return {
                    "session": session,
                    "api_keys": api_keys,
                    "website_analysis": website_analysis,
                    "research_preferences": research_preferences,
                    "persona_data": persona_data
                }
            finally:
                db.close()
        except Exception as e:
            logger.error(f"Error getting completion data: {e}")
            return {}


# Global instance
_onboarding_progress_service = None


def get_onboarding_progress_service() -> OnboardingProgressService:
    """Get the global onboarding progress service instance."""
    global _onboarding_progress_service
    if _onboarding_progress_service is None:
        _onboarding_progress_service = OnboardingProgressService()
    return _onboarding_progress_service
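
# Minimal usage sketch (illustrative; "user_123" is a placeholder). Each call
# opens and closes its own SessionLocal session internally:
#
#   service = get_onboarding_progress_service()
#   service.update_step("user_123", 2)
#   service.update_progress("user_123", 33.0)
#   status = service.get_onboarding_status("user_123")
#   if status["completion_percentage"] >= 100.0:
#       service.complete_onboarding("user_123")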