ALwrity onboarding final step
This commit is contained in:
@@ -165,10 +165,10 @@ class OnboardingManager:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@self.app.post("/api/onboarding/api-keys")
|
||||
async def api_key_save(request: APIKeyRequest):
|
||||
async def api_key_save(request: APIKeyRequest, current_user: dict = Depends(get_current_user)):
|
||||
"""Save an API key for a provider."""
|
||||
try:
|
||||
return await save_api_key(request)
|
||||
return await save_api_key(request, current_user)
|
||||
except Exception as e:
|
||||
logger.error(f"Error in api_key_save: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -95,6 +95,10 @@ class RouterManager:
|
||||
from routers.error_logging import router as error_logging_router
|
||||
self.include_router_safely(error_logging_router, "error_logging")
|
||||
|
||||
# Frontend environment manager router
|
||||
from routers.frontend_env_manager import router as frontend_env_router
|
||||
self.include_router_safely(frontend_env_router, "frontend_env_manager")
|
||||
|
||||
logger.info("✅ Core routers included successfully")
|
||||
return True
|
||||
|
||||
|
||||
@@ -15,7 +15,20 @@ class APIKeyManagementService:
|
||||
"""Service for handling API key management operations."""
|
||||
|
||||
def __init__(self):
|
||||
# Initialize APIKeyManager with database support
|
||||
self.api_key_manager = APIKeyManager()
|
||||
# Ensure database service is available
|
||||
if not hasattr(self.api_key_manager, 'use_database'):
|
||||
self.api_key_manager.use_database = True
|
||||
try:
|
||||
from services.onboarding_database_service import OnboardingDatabaseService
|
||||
self.api_key_manager.db_service = OnboardingDatabaseService()
|
||||
logger.info("Database service initialized for APIKeyManager")
|
||||
except Exception as e:
|
||||
logger.warning(f"Database service not available: {e}")
|
||||
self.api_key_manager.use_database = False
|
||||
self.api_key_manager.db_service = None
|
||||
|
||||
# Simple cache for API keys
|
||||
self._api_keys_cache = None
|
||||
self._cache_timestamp = 0
|
||||
@@ -75,9 +88,16 @@ class APIKeyManagementService:
|
||||
logger.error(f"Error getting API keys for onboarding: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
async def save_api_key(self, provider: str, api_key: str, description: str = None) -> Dict[str, Any]:
|
||||
async def save_api_key(self, provider: str, api_key: str, description: str = None, current_user: dict = None) -> Dict[str, Any]:
|
||||
"""Save an API key for a provider."""
|
||||
try:
|
||||
logger.info(f"📝 save_api_key called for provider: {provider}")
|
||||
|
||||
# Set user_id on the API key manager if available
|
||||
if current_user and current_user.get('id'):
|
||||
self.api_key_manager.user_id = current_user['id']
|
||||
logger.info(f"Set user_id on APIKeyManager: {current_user['id']}")
|
||||
|
||||
success = self.api_key_manager.save_api_key(provider, api_key)
|
||||
|
||||
if success:
|
||||
|
||||
@@ -35,11 +35,11 @@ async def get_api_keys_for_onboarding():
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
|
||||
async def save_api_key(request: APIKeyRequest):
|
||||
async def save_api_key(request: APIKeyRequest, current_user: dict = None):
|
||||
try:
|
||||
from api.onboarding_utils.api_key_management_service import APIKeyManagementService
|
||||
api_service = APIKeyManagementService()
|
||||
return await api_service.save_api_key(request.provider, request.api_key, request.description)
|
||||
return await api_service.save_api_key(request.provider, request.api_key, request.description, current_user)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving API key: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@@ -110,6 +110,16 @@ async def rate_limit_middleware(request: Request, call_next):
|
||||
"""Rate limiting middleware using modular utilities."""
|
||||
return await rate_limiter.rate_limit_middleware(request, call_next)
|
||||
|
||||
# API key injection middleware for production (user-specific keys)
|
||||
@app.middleware("http")
|
||||
async def inject_user_api_keys(request: Request, call_next):
|
||||
"""
|
||||
Inject user-specific API keys into environment for the request duration.
|
||||
This allows existing code using os.getenv() to work in production.
|
||||
"""
|
||||
from middleware.api_key_injection_middleware import api_key_injection_middleware
|
||||
return await api_key_injection_middleware(request, call_next)
|
||||
|
||||
# Health check endpoints using modular utilities
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
|
||||
26
backend/database/migrations/add_persona_data_table.sql
Normal file
26
backend/database/migrations/add_persona_data_table.sql
Normal file
@@ -0,0 +1,26 @@
|
||||
-- Migration: Add persona_data table for onboarding step 4
|
||||
-- Created: 2025-10-10
|
||||
-- Description: Adds table to store persona generation data from onboarding step 4
|
||||
|
||||
CREATE TABLE IF NOT EXISTS persona_data (
|
||||
id SERIAL PRIMARY KEY,
|
||||
session_id INTEGER NOT NULL,
|
||||
core_persona JSONB,
|
||||
platform_personas JSONB,
|
||||
quality_metrics JSONB,
|
||||
selected_platforms JSONB,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY (session_id) REFERENCES onboarding_sessions(id) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
-- Add index for better query performance
|
||||
CREATE INDEX IF NOT EXISTS idx_persona_data_session_id ON persona_data(session_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_persona_data_created_at ON persona_data(created_at);
|
||||
|
||||
-- Add comment to table
|
||||
COMMENT ON TABLE persona_data IS 'Stores persona generation data from onboarding step 4';
|
||||
COMMENT ON COLUMN persona_data.core_persona IS 'Core persona data (demographics, psychographics, etc.)';
|
||||
COMMENT ON COLUMN persona_data.platform_personas IS 'Platform-specific personas (LinkedIn, Twitter, etc.)';
|
||||
COMMENT ON COLUMN persona_data.quality_metrics IS 'Quality assessment metrics';
|
||||
COMMENT ON COLUMN persona_data.selected_platforms IS 'Array of selected platforms';
|
||||
@@ -22,3 +22,6 @@ WORDPRESS_REDIRECT_URI=
|
||||
|
||||
# Development Settings
|
||||
DISABLE_AUTH=false
|
||||
|
||||
# local development
|
||||
DEPLOY_ENV=local
|
||||
|
||||
114
backend/middleware/api_key_injection_middleware.py
Normal file
114
backend/middleware/api_key_injection_middleware.py
Normal file
@@ -0,0 +1,114 @@
|
||||
"""
|
||||
API Key Injection Middleware
|
||||
|
||||
Temporarily injects user-specific API keys into os.environ for the duration of the request.
|
||||
This allows existing code that uses os.getenv('GEMINI_API_KEY') to work without modification.
|
||||
|
||||
IMPORTANT: This is a compatibility layer. For new code, use UserAPIKeyContext directly.
|
||||
"""
|
||||
|
||||
import os
|
||||
from fastapi import Request
|
||||
from loguru import logger
|
||||
from typing import Callable
|
||||
from services.user_api_key_context import user_api_keys
|
||||
|
||||
|
||||
class APIKeyInjectionMiddleware:
|
||||
"""
|
||||
Middleware that injects user-specific API keys into environment variables
|
||||
for the duration of each request.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.original_keys = {}
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable):
|
||||
"""
|
||||
Inject user-specific API keys before processing request,
|
||||
restore original values after request completes.
|
||||
"""
|
||||
|
||||
# Try to extract user_id from Authorization header
|
||||
user_id = None
|
||||
auth_header = request.headers.get('Authorization')
|
||||
|
||||
if auth_header and auth_header.startswith('Bearer '):
|
||||
try:
|
||||
from middleware.auth_middleware import clerk_auth
|
||||
token = auth_header.replace('Bearer ', '')
|
||||
user = await clerk_auth.verify_token(token)
|
||||
if user:
|
||||
# Try different possible keys for user_id
|
||||
user_id = user.get('user_id') or user.get('clerk_user_id') or user.get('id')
|
||||
logger.debug(f"[API Key Injection] Extracted user_id: {user_id}")
|
||||
except Exception as e:
|
||||
logger.debug(f"[API Key Injection] Could not extract user from token: {e}")
|
||||
|
||||
if not user_id:
|
||||
# No authenticated user, proceed without injection
|
||||
return await call_next(request)
|
||||
|
||||
# Check if we're in production mode
|
||||
is_production = os.getenv('DEPLOY_ENV', 'local') == 'production'
|
||||
|
||||
if not is_production:
|
||||
# Local mode - keys already in .env, no injection needed
|
||||
return await call_next(request)
|
||||
|
||||
# Get user-specific API keys from database
|
||||
with user_api_keys(user_id) as user_keys:
|
||||
if not user_keys:
|
||||
logger.warning(f"No API keys found for user {user_id}")
|
||||
return await call_next(request)
|
||||
|
||||
# Save original environment values
|
||||
original_keys = {}
|
||||
keys_to_inject = {
|
||||
'gemini': 'GEMINI_API_KEY',
|
||||
'exa': 'EXA_API_KEY',
|
||||
'copilotkit': 'COPILOTKIT_API_KEY',
|
||||
'openai': 'OPENAI_API_KEY',
|
||||
'anthropic': 'ANTHROPIC_API_KEY',
|
||||
'tavily': 'TAVILY_API_KEY',
|
||||
'serper': 'SERPER_API_KEY',
|
||||
'firecrawl': 'FIRECRAWL_API_KEY',
|
||||
}
|
||||
|
||||
# Inject user-specific keys into environment
|
||||
for provider, env_var in keys_to_inject.items():
|
||||
if provider in user_keys and user_keys[provider]:
|
||||
# Save original value (if any)
|
||||
original_keys[env_var] = os.environ.get(env_var)
|
||||
# Inject user-specific key
|
||||
os.environ[env_var] = user_keys[provider]
|
||||
logger.debug(f"[PRODUCTION] Injected {env_var} for user {user_id}")
|
||||
|
||||
try:
|
||||
# Process request with user-specific keys in environment
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
finally:
|
||||
# CRITICAL: Restore original environment values
|
||||
for env_var, original_value in original_keys.items():
|
||||
if original_value is None:
|
||||
# Key didn't exist before, remove it
|
||||
os.environ.pop(env_var, None)
|
||||
else:
|
||||
# Restore original value
|
||||
os.environ[env_var] = original_value
|
||||
|
||||
logger.debug(f"[PRODUCTION] Cleaned up environment for user {user_id}")
|
||||
|
||||
|
||||
async def api_key_injection_middleware(request: Request, call_next: Callable):
|
||||
"""
|
||||
Middleware function that injects user-specific API keys into environment.
|
||||
|
||||
Usage in app.py:
|
||||
app.middleware("http")(api_key_injection_middleware)
|
||||
"""
|
||||
middleware = APIKeyInjectionMiddleware()
|
||||
return await middleware(request, call_next)
|
||||
|
||||
@@ -16,6 +16,7 @@ class OnboardingSession(Base):
|
||||
api_keys = relationship('APIKey', back_populates='session', cascade="all, delete-orphan")
|
||||
website_analyses = relationship('WebsiteAnalysis', back_populates='session', cascade="all, delete-orphan")
|
||||
research_preferences = relationship('ResearchPreferences', back_populates='session', cascade="all, delete-orphan", uselist=False)
|
||||
persona_data = relationship('PersonaData', back_populates='session', cascade="all, delete-orphan", uselist=False)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OnboardingSession(id={self.id}, user_id={self.user_id}, step={self.current_step}, progress={self.progress})>"
|
||||
@@ -143,4 +144,40 @@ class ResearchPreferences(Base):
|
||||
'recommended_settings': self.recommended_settings,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
}
|
||||
|
||||
class PersonaData(Base):
|
||||
"""Stores persona generation data from onboarding step 4."""
|
||||
__tablename__ = 'persona_data'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
session_id = Column(Integer, ForeignKey('onboarding_sessions.id', ondelete='CASCADE'), nullable=False)
|
||||
|
||||
# Persona generation results
|
||||
core_persona = Column(JSON) # Core persona data (demographics, psychographics, etc.)
|
||||
platform_personas = Column(JSON) # Platform-specific personas (LinkedIn, Twitter, etc.)
|
||||
quality_metrics = Column(JSON) # Quality assessment metrics
|
||||
selected_platforms = Column(JSON) # Array of selected platforms
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=func.now())
|
||||
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
|
||||
|
||||
# Relationships
|
||||
session = relationship('OnboardingSession', back_populates='persona_data')
|
||||
|
||||
def __repr__(self):
|
||||
return f"<PersonaData(id={self.id}, session_id={self.session_id})>"
|
||||
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary for API responses."""
|
||||
return {
|
||||
'id': self.id,
|
||||
'session_id': self.session_id,
|
||||
'core_persona': self.core_persona,
|
||||
'platform_personas': self.platform_personas,
|
||||
'quality_metrics': self.quality_metrics,
|
||||
'selected_platforms': self.selected_platforms,
|
||||
'created_at': self.created_at.isoformat() if self.created_at else None,
|
||||
'updated_at': self.updated_at.isoformat() if self.updated_at else None
|
||||
}
|
||||
110
backend/routers/frontend_env_manager.py
Normal file
110
backend/routers/frontend_env_manager.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Frontend Environment Manager
|
||||
Handles updating frontend environment variables (for development purposes).
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/api/frontend-env",
|
||||
tags=["Frontend Environment"],
|
||||
)
|
||||
|
||||
class FrontendEnvUpdateRequest(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
description: Optional[str] = None
|
||||
|
||||
@router.post("/update")
|
||||
async def update_frontend_env(request: FrontendEnvUpdateRequest):
|
||||
"""
|
||||
Update frontend environment variable (for development purposes).
|
||||
This writes to the frontend/.env file.
|
||||
"""
|
||||
try:
|
||||
# Get the frontend directory path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
frontend_dir = backend_dir.parent / "frontend"
|
||||
env_path = frontend_dir / ".env"
|
||||
|
||||
# Ensure the frontend directory exists
|
||||
if not frontend_dir.exists():
|
||||
raise HTTPException(status_code=404, detail="Frontend directory not found")
|
||||
|
||||
# Read existing .env file
|
||||
env_lines = []
|
||||
if env_path.exists():
|
||||
with open(env_path, 'r') as f:
|
||||
env_lines = f.readlines()
|
||||
|
||||
# Update or add the environment variable
|
||||
key_found = False
|
||||
updated_lines = []
|
||||
for line in env_lines:
|
||||
if line.startswith(f"{request.key}="):
|
||||
updated_lines.append(f"{request.key}={request.value}\n")
|
||||
key_found = True
|
||||
else:
|
||||
updated_lines.append(line)
|
||||
|
||||
if not key_found:
|
||||
# Add comment if description provided
|
||||
if request.description:
|
||||
updated_lines.append(f"# {request.description}\n")
|
||||
updated_lines.append(f"{request.key}={request.value}\n")
|
||||
|
||||
# Write back to .env file
|
||||
with open(env_path, 'w') as f:
|
||||
f.writelines(updated_lines)
|
||||
|
||||
logger.info(f"Updated frontend environment variable: {request.key}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"Environment variable {request.key} updated successfully",
|
||||
"key": request.key,
|
||||
"value": request.value
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating frontend environment: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to update environment variable: {str(e)}")
|
||||
|
||||
@router.get("/status")
|
||||
async def get_frontend_env_status():
|
||||
"""
|
||||
Get status of frontend environment file.
|
||||
"""
|
||||
try:
|
||||
# Get the frontend directory path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
frontend_dir = backend_dir.parent / "frontend"
|
||||
env_path = frontend_dir / ".env"
|
||||
|
||||
if not env_path.exists():
|
||||
return {
|
||||
"exists": False,
|
||||
"path": str(env_path),
|
||||
"message": "Frontend .env file does not exist"
|
||||
}
|
||||
|
||||
# Read and return basic info about the .env file
|
||||
with open(env_path, 'r') as f:
|
||||
content = f.read()
|
||||
|
||||
return {
|
||||
"exists": True,
|
||||
"path": str(env_path),
|
||||
"size": len(content),
|
||||
"lines": len(content.splitlines()),
|
||||
"message": "Frontend .env file exists"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking frontend environment status: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Failed to check environment status: {str(e)}")
|
||||
124
backend/scripts/create_persona_data_table.py
Normal file
124
backend/scripts/create_persona_data_table.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""
|
||||
Script to create the persona_data table for onboarding step 4.
|
||||
This migration adds support for storing persona generation data.
|
||||
|
||||
Usage:
|
||||
python backend/scripts/create_persona_data_table.py
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add backend directory to path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import inspect
|
||||
|
||||
def create_persona_data_table():
|
||||
"""Create the persona_data table."""
|
||||
try:
|
||||
# Import after path is set
|
||||
from services.database import engine
|
||||
from models.onboarding import Base as OnboardingBase, PersonaData
|
||||
|
||||
logger.info("🔍 Checking if persona_data table exists...")
|
||||
|
||||
# Check if table already exists
|
||||
inspector = inspect(engine)
|
||||
existing_tables = inspector.get_table_names()
|
||||
|
||||
if 'persona_data' in existing_tables:
|
||||
logger.info("✅ persona_data table already exists")
|
||||
return True
|
||||
|
||||
logger.info("📊 Creating persona_data table...")
|
||||
|
||||
# Create only the persona_data table
|
||||
PersonaData.__table__.create(bind=engine, checkfirst=True)
|
||||
|
||||
logger.info("✅ persona_data table created successfully")
|
||||
|
||||
# Verify creation
|
||||
inspector = inspect(engine)
|
||||
existing_tables = inspector.get_table_names()
|
||||
|
||||
if 'persona_data' in existing_tables:
|
||||
logger.info("✅ Verification successful - persona_data table exists")
|
||||
|
||||
# Show table structure
|
||||
columns = inspector.get_columns('persona_data')
|
||||
logger.info(f"📋 Table structure ({len(columns)} columns):")
|
||||
for col in columns:
|
||||
logger.info(f" - {col['name']}: {col['type']}")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error("❌ Table creation verification failed")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating persona_data table: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
def check_onboarding_tables():
|
||||
"""Check all onboarding-related tables."""
|
||||
try:
|
||||
from services.database import engine
|
||||
from sqlalchemy import inspect
|
||||
|
||||
inspector = inspect(engine)
|
||||
existing_tables = inspector.get_table_names()
|
||||
|
||||
onboarding_tables = [
|
||||
'onboarding_sessions',
|
||||
'api_keys',
|
||||
'website_analyses',
|
||||
'research_preferences',
|
||||
'persona_data'
|
||||
]
|
||||
|
||||
logger.info("📋 Onboarding Tables Status:")
|
||||
for table in onboarding_tables:
|
||||
status = "✅" if table in existing_tables else "❌"
|
||||
logger.info(f" {status} {table}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking tables: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("=" * 60)
|
||||
logger.info("Persona Data Table Migration")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Check existing tables
|
||||
check_onboarding_tables()
|
||||
|
||||
logger.info("")
|
||||
|
||||
# Create persona_data table
|
||||
if create_persona_data_table():
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info("✅ Migration completed successfully!")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Check again to confirm
|
||||
logger.info("")
|
||||
check_onboarding_tables()
|
||||
|
||||
sys.exit(0)
|
||||
else:
|
||||
logger.error("")
|
||||
logger.error("=" * 60)
|
||||
logger.error("❌ Migration failed!")
|
||||
logger.error("=" * 60)
|
||||
sys.exit(1)
|
||||
|
||||
338
backend/scripts/verify_onboarding_data.py
Normal file
338
backend/scripts/verify_onboarding_data.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Database Verification Script for Onboarding Data
|
||||
Verifies that all onboarding steps data is properly saved to the database.
|
||||
|
||||
Usage:
|
||||
python backend/scripts/verify_onboarding_data.py [user_id]
|
||||
|
||||
Example:
|
||||
python backend/scripts/verify_onboarding_data.py user_33Gz1FPI86VDXhRY8QN4ragRFGN
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# Add backend directory to path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import inspect, text
|
||||
from typing import Optional
|
||||
import json
|
||||
|
||||
def get_user_id_from_args() -> Optional[str]:
|
||||
"""Get user_id from command line arguments."""
|
||||
if len(sys.argv) > 1:
|
||||
return sys.argv[1]
|
||||
return None
|
||||
|
||||
def verify_table_exists(table_name: str, inspector) -> bool:
|
||||
"""Check if a table exists in the database."""
|
||||
tables = inspector.get_table_names()
|
||||
exists = table_name in tables
|
||||
|
||||
if exists:
|
||||
logger.info(f"✅ Table '{table_name}' exists")
|
||||
# Show column count
|
||||
columns = inspector.get_columns(table_name)
|
||||
logger.info(f" Columns: {len(columns)}")
|
||||
else:
|
||||
logger.error(f"❌ Table '{table_name}' does NOT exist")
|
||||
|
||||
return exists
|
||||
|
||||
def verify_onboarding_session(user_id: str, db):
|
||||
"""Verify onboarding session data."""
|
||||
try:
|
||||
from models.onboarding import OnboardingSession
|
||||
|
||||
session = db.query(OnboardingSession).filter(
|
||||
OnboardingSession.user_id == user_id
|
||||
).first()
|
||||
|
||||
if session:
|
||||
logger.info(f"✅ Onboarding Session found for user: {user_id}")
|
||||
logger.info(f" Session ID: {session.id}")
|
||||
logger.info(f" Current Step: {session.current_step}")
|
||||
logger.info(f" Progress: {session.progress}%")
|
||||
logger.info(f" Started At: {session.started_at}")
|
||||
logger.info(f" Updated At: {session.updated_at}")
|
||||
return session.id
|
||||
else:
|
||||
logger.error(f"❌ No onboarding session found for user: {user_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying onboarding session: {e}")
|
||||
return None
|
||||
|
||||
def verify_api_keys(session_id: int, user_id: str, db):
|
||||
"""Verify API keys data (Step 1)."""
|
||||
try:
|
||||
from models.onboarding import APIKey
|
||||
|
||||
api_keys = db.query(APIKey).filter(
|
||||
APIKey.session_id == session_id
|
||||
).all()
|
||||
|
||||
if api_keys:
|
||||
logger.info(f"✅ Step 1 (API Keys): Found {len(api_keys)} API key(s)")
|
||||
for key in api_keys:
|
||||
# Mask the key for security
|
||||
masked_key = f"{key.key[:8]}...{key.key[-4:]}" if len(key.key) > 12 else "***"
|
||||
logger.info(f" - Provider: {key.provider}")
|
||||
logger.info(f" Key: {masked_key}")
|
||||
logger.info(f" Created: {key.created_at}")
|
||||
else:
|
||||
logger.warning(f"⚠️ Step 1 (API Keys): No API keys found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying API keys: {e}")
|
||||
|
||||
def verify_website_analysis(session_id: int, user_id: str, db):
|
||||
"""Verify website analysis data (Step 2)."""
|
||||
try:
|
||||
from models.onboarding import WebsiteAnalysis
|
||||
|
||||
analysis = db.query(WebsiteAnalysis).filter(
|
||||
WebsiteAnalysis.session_id == session_id
|
||||
).first()
|
||||
|
||||
if analysis:
|
||||
logger.info(f"✅ Step 2 (Website Analysis): Data found")
|
||||
logger.info(f" Website URL: {analysis.website_url}")
|
||||
logger.info(f" Analysis Date: {analysis.analysis_date}")
|
||||
logger.info(f" Status: {analysis.status}")
|
||||
|
||||
if analysis.writing_style:
|
||||
logger.info(f" Writing Style: {len(analysis.writing_style)} attributes")
|
||||
if analysis.content_characteristics:
|
||||
logger.info(f" Content Characteristics: {len(analysis.content_characteristics)} attributes")
|
||||
if analysis.target_audience:
|
||||
logger.info(f" Target Audience: {len(analysis.target_audience)} attributes")
|
||||
else:
|
||||
logger.warning(f"⚠️ Step 2 (Website Analysis): No data found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying website analysis: {e}")
|
||||
|
||||
def verify_research_preferences(session_id: int, user_id: str, db):
|
||||
"""Verify research preferences data (Step 3)."""
|
||||
try:
|
||||
from models.onboarding import ResearchPreferences
|
||||
|
||||
prefs = db.query(ResearchPreferences).filter(
|
||||
ResearchPreferences.session_id == session_id
|
||||
).first()
|
||||
|
||||
if prefs:
|
||||
logger.info(f"✅ Step 3 (Research Preferences): Data found")
|
||||
logger.info(f" Research Depth: {prefs.research_depth}")
|
||||
logger.info(f" Content Types: {prefs.content_types}")
|
||||
logger.info(f" Auto Research: {prefs.auto_research}")
|
||||
logger.info(f" Factual Content: {prefs.factual_content}")
|
||||
else:
|
||||
logger.warning(f"⚠️ Step 3 (Research Preferences): No data found")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying research preferences: {e}")
|
||||
|
||||
def verify_persona_data(session_id: int, user_id: str, db):
|
||||
"""Verify persona data (Step 4) - THE NEW FIX!"""
|
||||
try:
|
||||
from models.onboarding import PersonaData
|
||||
|
||||
persona = db.query(PersonaData).filter(
|
||||
PersonaData.session_id == session_id
|
||||
).first()
|
||||
|
||||
if persona:
|
||||
logger.info(f"✅ Step 4 (Persona Generation): Data found ⭐")
|
||||
|
||||
if persona.core_persona:
|
||||
logger.info(f" Core Persona: Present")
|
||||
if isinstance(persona.core_persona, dict):
|
||||
logger.info(f" Attributes: {len(persona.core_persona)} fields")
|
||||
|
||||
if persona.platform_personas:
|
||||
logger.info(f" Platform Personas: Present")
|
||||
if isinstance(persona.platform_personas, dict):
|
||||
platforms = list(persona.platform_personas.keys())
|
||||
logger.info(f" Platforms: {', '.join(platforms)}")
|
||||
|
||||
if persona.quality_metrics:
|
||||
logger.info(f" Quality Metrics: Present")
|
||||
if isinstance(persona.quality_metrics, dict):
|
||||
logger.info(f" Metrics: {len(persona.quality_metrics)} fields")
|
||||
|
||||
if persona.selected_platforms:
|
||||
logger.info(f" Selected Platforms: {persona.selected_platforms}")
|
||||
|
||||
logger.info(f" Created At: {persona.created_at}")
|
||||
logger.info(f" Updated At: {persona.updated_at}")
|
||||
else:
|
||||
logger.error(f"❌ Step 4 (Persona Generation): No data found - THIS IS THE BUG WE FIXED!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying persona data: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
def show_raw_sql_query_example(user_id: str):
|
||||
"""Show example SQL queries for manual verification."""
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info("📋 Raw SQL Queries for Manual Verification:")
|
||||
logger.info("=" * 60)
|
||||
|
||||
queries = [
|
||||
("Onboarding Session",
|
||||
f"SELECT * FROM onboarding_sessions WHERE user_id = '{user_id}';"),
|
||||
|
||||
("API Keys",
|
||||
f"""SELECT ak.* FROM api_keys ak
|
||||
JOIN onboarding_sessions os ON ak.session_id = os.id
|
||||
WHERE os.user_id = '{user_id}';"""),
|
||||
|
||||
("Website Analysis",
|
||||
f"""SELECT wa.website_url, wa.analysis_date, wa.status
|
||||
FROM website_analyses wa
|
||||
JOIN onboarding_sessions os ON wa.session_id = os.id
|
||||
WHERE os.user_id = '{user_id}';"""),
|
||||
|
||||
("Research Preferences",
|
||||
f"""SELECT rp.research_depth, rp.content_types, rp.auto_research
|
||||
FROM research_preferences rp
|
||||
JOIN onboarding_sessions os ON rp.session_id = os.id
|
||||
WHERE os.user_id = '{user_id}';"""),
|
||||
|
||||
("Persona Data (NEW!)",
|
||||
f"""SELECT pd.* FROM persona_data pd
|
||||
JOIN onboarding_sessions os ON pd.session_id = os.id
|
||||
WHERE os.user_id = '{user_id}';"""),
|
||||
]
|
||||
|
||||
for title, query in queries:
|
||||
logger.info(f"\n{title}:")
|
||||
logger.info(f" {query}")
|
||||
|
||||
def count_all_records(db):
|
||||
"""Count records in all onboarding tables."""
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info("📊 Overall Database Statistics:")
|
||||
logger.info("=" * 60)
|
||||
|
||||
try:
|
||||
from models.onboarding import (
|
||||
OnboardingSession, APIKey, WebsiteAnalysis,
|
||||
ResearchPreferences, PersonaData
|
||||
)
|
||||
|
||||
counts = {
|
||||
"Onboarding Sessions": db.query(OnboardingSession).count(),
|
||||
"API Keys": db.query(APIKey).count(),
|
||||
"Website Analyses": db.query(WebsiteAnalysis).count(),
|
||||
"Research Preferences": db.query(ResearchPreferences).count(),
|
||||
"Persona Data": db.query(PersonaData).count(),
|
||||
}
|
||||
|
||||
for table, count in counts.items():
|
||||
logger.info(f" {table}: {count} record(s)")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error counting records: {e}")
|
||||
|
||||
def main():
|
||||
"""Main verification function."""
|
||||
logger.info("=" * 60)
|
||||
logger.info("🔍 Onboarding Database Verification")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Get user_id
|
||||
user_id = get_user_id_from_args()
|
||||
|
||||
if not user_id:
|
||||
logger.warning("⚠️ No user_id provided. Will show overall statistics only.")
|
||||
logger.info("Usage: python backend/scripts/verify_onboarding_data.py <user_id>")
|
||||
|
||||
try:
|
||||
from services.database import SessionLocal, engine
|
||||
from sqlalchemy import inspect
|
||||
|
||||
# Check tables exist
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info("1️⃣ Verifying Database Tables:")
|
||||
logger.info("=" * 60)
|
||||
|
||||
inspector = inspect(engine)
|
||||
tables = [
|
||||
'onboarding_sessions',
|
||||
'api_keys',
|
||||
'website_analyses',
|
||||
'research_preferences',
|
||||
'persona_data'
|
||||
]
|
||||
|
||||
all_exist = True
|
||||
for table in tables:
|
||||
if not verify_table_exists(table, inspector):
|
||||
all_exist = False
|
||||
|
||||
if not all_exist:
|
||||
logger.error("")
|
||||
logger.error("❌ Some tables are missing! Run migrations first.")
|
||||
return False
|
||||
|
||||
# Count all records
|
||||
db = SessionLocal()
|
||||
try:
|
||||
count_all_records(db)
|
||||
|
||||
# If user_id provided, show detailed data
|
||||
if user_id:
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f"2️⃣ Verifying Data for User: {user_id}")
|
||||
logger.info("=" * 60)
|
||||
|
||||
# Verify session
|
||||
session_id = verify_onboarding_session(user_id, db)
|
||||
|
||||
if session_id:
|
||||
logger.info("")
|
||||
# Verify each step's data
|
||||
verify_api_keys(session_id, user_id, db)
|
||||
logger.info("")
|
||||
verify_website_analysis(session_id, user_id, db)
|
||||
logger.info("")
|
||||
verify_research_preferences(session_id, user_id, db)
|
||||
logger.info("")
|
||||
verify_persona_data(session_id, user_id, db)
|
||||
|
||||
# Show SQL examples
|
||||
show_raw_sql_query_example(user_id)
|
||||
|
||||
logger.info("")
|
||||
logger.info("=" * 60)
|
||||
logger.info("✅ Verification Complete!")
|
||||
logger.info("=" * 60)
|
||||
|
||||
return True
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Verification failed: {e}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
success = main()
|
||||
sys.exit(0 if success else 1)
|
||||
|
||||
@@ -35,14 +35,31 @@ class StepData:
|
||||
class OnboardingProgress:
|
||||
"""Manages onboarding progress with persistence and validation."""
|
||||
|
||||
def __init__(self, progress_file: Optional[str] = None):
|
||||
def __init__(self, progress_file: Optional[str] = None, user_id: Optional[str] = None):
|
||||
self.steps = self._initialize_steps()
|
||||
self.current_step = 1
|
||||
self.started_at = datetime.now().isoformat()
|
||||
self.last_updated = datetime.now().isoformat()
|
||||
self.is_completed = False
|
||||
self.completed_at = None
|
||||
self.progress_file = progress_file or ".onboarding_progress.json"
|
||||
self.user_id = user_id # Add user_id for database isolation
|
||||
|
||||
# Use user-specific file for backward compatibility
|
||||
if user_id:
|
||||
self.progress_file = progress_file or f".onboarding_progress_{user_id}.json"
|
||||
else:
|
||||
self.progress_file = progress_file or ".onboarding_progress.json"
|
||||
|
||||
# Initialize database service for dual persistence
|
||||
try:
|
||||
from services.onboarding_database_service import OnboardingDatabaseService
|
||||
self.db_service = OnboardingDatabaseService()
|
||||
self.use_database = True
|
||||
logger.info(f"Database service initialized for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Database service not available, using file only: {e}")
|
||||
self.db_service = None
|
||||
self.use_database = False
|
||||
|
||||
# Load existing progress if available
|
||||
self.load_progress()
|
||||
@@ -192,8 +209,9 @@ class OnboardingProgress:
|
||||
logger.info("Onboarding completed successfully")
|
||||
|
||||
def save_progress(self):
|
||||
"""Save progress to file."""
|
||||
"""Save progress to both file and database (dual persistence)."""
|
||||
try:
|
||||
# Save to JSON file (backward compatibility)
|
||||
progress_data = {
|
||||
"steps": [{
|
||||
"step_number": step.step_number,
|
||||
@@ -215,6 +233,65 @@ class OnboardingProgress:
|
||||
json.dump(progress_data, f, indent=2)
|
||||
|
||||
logger.debug(f"Progress saved to {self.progress_file}")
|
||||
|
||||
# Also save to database if available and user_id is set
|
||||
if self.use_database and self.db_service and self.user_id:
|
||||
try:
|
||||
from services.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Update session progress
|
||||
self.db_service.update_step(self.user_id, self.current_step, db)
|
||||
|
||||
# Calculate progress percentage
|
||||
completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
|
||||
progress_pct = (completed_count / len(self.steps)) * 100
|
||||
self.db_service.update_progress(self.user_id, progress_pct, db)
|
||||
|
||||
# Save step-specific data to appropriate tables
|
||||
for step in self.steps:
|
||||
if step.status == StepStatus.COMPLETED and step.data:
|
||||
if step.step_number == 1: # API Keys
|
||||
api_keys = step.data.get('api_keys', {})
|
||||
for provider, key in api_keys.items():
|
||||
if key:
|
||||
# Save to database (for user isolation in production)
|
||||
self.db_service.save_api_key(self.user_id, provider, key, db)
|
||||
|
||||
# Also save to .env file ONLY in local development
|
||||
# This allows local developers to have keys in .env for convenience
|
||||
# In production, keys are fetched from database per user
|
||||
is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
|
||||
if is_local:
|
||||
try:
|
||||
from services.api_key_manager import APIKeyManager
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key_manager.save_api_key(provider, key)
|
||||
logger.info(f"[LOCAL] API key for {provider} saved to .env file")
|
||||
except Exception as env_error:
|
||||
logger.warning(f"[LOCAL] Failed to save {provider} API key to .env file: {env_error}")
|
||||
else:
|
||||
logger.info(f"[PRODUCTION] API key for {provider} saved to database only (user: {self.user_id})")
|
||||
|
||||
# Log database save confirmation
|
||||
logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
|
||||
elif step.step_number == 2: # Website Analysis
|
||||
self.db_service.save_website_analysis(self.user_id, step.data, db)
|
||||
logger.info(f"✅ DATABASE: Website analysis saved to database for user {self.user_id}")
|
||||
elif step.step_number == 3: # Research Preferences
|
||||
self.db_service.save_research_preferences(self.user_id, step.data, db)
|
||||
logger.info(f"✅ DATABASE: Research preferences saved to database for user {self.user_id}")
|
||||
elif step.step_number == 4: # Persona Generation
|
||||
self.db_service.save_persona_data(self.user_id, step.data, db)
|
||||
logger.info(f"✅ DATABASE: Persona data saved to database for user {self.user_id}")
|
||||
|
||||
logger.info(f"Progress also saved to database for user {self.user_id}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as db_error:
|
||||
logger.warning(f"Failed to save to database, JSON file still saved: {db_error}")
|
||||
# Don't fail if database save fails - JSON is still working
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving progress: {str(e)}")
|
||||
|
||||
@@ -423,8 +500,34 @@ class APIKeyManager:
|
||||
try:
|
||||
if provider in self.api_keys:
|
||||
self.api_keys[provider] = api_key
|
||||
self._save_to_env_file(provider, api_key)
|
||||
logger.info(f"API key saved for {provider}")
|
||||
|
||||
# Save to database if available and user_id is set
|
||||
if hasattr(self, 'use_database') and self.use_database and hasattr(self, 'db_service') and self.db_service and hasattr(self, 'user_id') and self.user_id:
|
||||
try:
|
||||
from services.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
self.db_service.save_api_key(self.user_id, provider, api_key, db)
|
||||
logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as db_error:
|
||||
logger.warning(f"Failed to save {provider} API key to database: {db_error}")
|
||||
|
||||
# Also save to .env file in local mode
|
||||
is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
|
||||
if is_local:
|
||||
# Special handling for CopilotKit - save to frontend/.env
|
||||
if provider == 'copilotkit':
|
||||
self._save_to_frontend_env(api_key)
|
||||
logger.info(f"[LOCAL] CopilotKit API key saved to frontend/.env file")
|
||||
else:
|
||||
# Save other keys to backend/.env
|
||||
self._save_to_env_file(provider, api_key)
|
||||
logger.info(f"[LOCAL] API key for {provider} saved to backend/.env file")
|
||||
else:
|
||||
logger.info(f"[PRODUCTION] API key for {provider} saved to memory only (database handles persistence)")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Unknown provider: {provider}")
|
||||
@@ -490,8 +593,50 @@ class APIKeyManager:
|
||||
"total_providers": len(self.api_keys)
|
||||
}
|
||||
|
||||
def _save_to_frontend_env(self, api_key: str):
|
||||
"""Save CopilotKit API key to frontend/.env file."""
|
||||
try:
|
||||
# Get the frontend directory path
|
||||
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
frontend_dir = os.path.join(os.path.dirname(backend_dir), "frontend")
|
||||
env_path = os.path.join(frontend_dir, ".env")
|
||||
|
||||
# Read existing .env file
|
||||
if os.path.exists(env_path):
|
||||
with open(env_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
else:
|
||||
lines = []
|
||||
|
||||
# Update or add REACT_APP_COPILOTKIT_API_KEY
|
||||
key_found = False
|
||||
updated_lines = []
|
||||
env_var = "REACT_APP_COPILOTKIT_API_KEY"
|
||||
|
||||
for line in lines:
|
||||
if line.startswith(f"{env_var}="):
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
key_found = True
|
||||
else:
|
||||
updated_lines.append(line)
|
||||
|
||||
if not key_found:
|
||||
# Ensure the file ends with a newline before adding new key
|
||||
if updated_lines and not updated_lines[-1].endswith('\n'):
|
||||
updated_lines[-1] += '\n'
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
|
||||
# Write back to frontend .env file
|
||||
with open(env_path, 'w', encoding='utf-8') as f:
|
||||
f.writelines(updated_lines)
|
||||
|
||||
logger.debug(f"CopilotKit API key saved to frontend .env file")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to frontend .env file: {str(e)}")
|
||||
|
||||
def _save_to_env_file(self, provider: str, api_key: str):
|
||||
"""Save API key to .env file."""
|
||||
"""Save API key to backend .env file."""
|
||||
try:
|
||||
env_mapping = {
|
||||
"openai": "OPENAI_API_KEY",
|
||||
@@ -513,11 +658,10 @@ class APIKeyManager:
|
||||
os.environ[env_var] = api_key
|
||||
|
||||
# Update .env file - use backend directory path
|
||||
import os
|
||||
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
env_path = os.path.join(backend_dir, ".env")
|
||||
if os.path.exists(env_path):
|
||||
with open(env_path, 'r') as f:
|
||||
with open(env_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
else:
|
||||
lines = []
|
||||
@@ -532,13 +676,23 @@ class APIKeyManager:
|
||||
updated_lines.append(line)
|
||||
|
||||
if not key_found:
|
||||
# Ensure the file ends with a newline before adding new key
|
||||
if updated_lines and not updated_lines[-1].endswith('\n'):
|
||||
updated_lines[-1] += '\n'
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
|
||||
with open(env_path, 'w') as f:
|
||||
with open(env_path, 'w', encoding='utf-8') as f:
|
||||
f.writelines(updated_lines)
|
||||
|
||||
# Reload environment variables
|
||||
load_dotenv(override=True)
|
||||
# Reload environment variables into current process
|
||||
load_dotenv(env_path, override=True)
|
||||
|
||||
# Verify the key is now in environment
|
||||
loaded_key = os.environ.get(env_var)
|
||||
if loaded_key == api_key:
|
||||
logger.info(f"✅ {env_var} loaded into environment (available for immediate use)")
|
||||
else:
|
||||
logger.warning(f"⚠️ {env_var} written to .env but not in environment yet")
|
||||
|
||||
logger.debug(f"API key saved to .env file for {provider}")
|
||||
except Exception as e:
|
||||
@@ -555,13 +709,17 @@ def get_onboarding_progress() -> OnboardingProgress:
|
||||
return get_onboarding_progress._instance
|
||||
|
||||
def get_onboarding_progress_for_user(user_id: str) -> OnboardingProgress:
|
||||
"""Get or create a per-user onboarding progress instance persisted to a user-specific file."""
|
||||
"""Get or create a per-user onboarding progress instance with database persistence."""
|
||||
global _user_onboarding_progress_cache
|
||||
safe_user_id = ''.join([c if c.isalnum() or c in ('-', '_') else '_' for c in str(user_id)])
|
||||
if safe_user_id in _user_onboarding_progress_cache:
|
||||
return _user_onboarding_progress_cache[safe_user_id]
|
||||
|
||||
# Create user-specific progress file for backward compatibility
|
||||
progress_file = f".onboarding_progress_{safe_user_id}.json"
|
||||
instance = OnboardingProgress(progress_file=progress_file)
|
||||
|
||||
# Pass user_id to enable database persistence
|
||||
instance = OnboardingProgress(progress_file=progress_file, user_id=user_id)
|
||||
_user_onboarding_progress_cache[safe_user_id] = instance
|
||||
return instance
|
||||
|
||||
|
||||
418
backend/services/onboarding_database_service.py
Normal file
418
backend/services/onboarding_database_service.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""
|
||||
Onboarding Database Service
|
||||
Provides database-backed storage for onboarding progress with user isolation.
|
||||
This replaces the JSON file-based storage with proper database persistence.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData
|
||||
from services.database import get_db
|
||||
|
||||
|
||||
class OnboardingDatabaseService:
|
||||
"""Database service for onboarding with user isolation."""
|
||||
|
||||
def __init__(self, db: Session = None):
|
||||
"""Initialize with optional database session."""
|
||||
self.db = db
|
||||
|
||||
def get_or_create_session(self, user_id: str, db: Session = None) -> OnboardingSession:
|
||||
"""Get existing onboarding session or create new one for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
# Try to get existing session for this user
|
||||
session = session_db.query(OnboardingSession).filter(
|
||||
OnboardingSession.user_id == user_id
|
||||
).first()
|
||||
|
||||
if session:
|
||||
logger.info(f"Found existing onboarding session for user {user_id}")
|
||||
return session
|
||||
|
||||
# Create new session
|
||||
session = OnboardingSession(
|
||||
user_id=user_id,
|
||||
current_step=1,
|
||||
progress=0.0,
|
||||
started_at=datetime.now()
|
||||
)
|
||||
session_db.add(session)
|
||||
session_db.commit()
|
||||
session_db.refresh(session)
|
||||
|
||||
logger.info(f"Created new onboarding session for user {user_id}")
|
||||
return session
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Database error in get_or_create_session: {e}")
|
||||
session_db.rollback()
|
||||
raise
|
||||
|
||||
def get_session_by_user(self, user_id: str, db: Session = None) -> Optional[OnboardingSession]:
|
||||
"""Get onboarding session for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
return session_db.query(OnboardingSession).filter(
|
||||
OnboardingSession.user_id == user_id
|
||||
).first()
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error getting session: {e}")
|
||||
return None
|
||||
|
||||
def update_step(self, user_id: str, step_number: int, db: Session = None) -> bool:
|
||||
"""Update current step for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
session.current_step = step_number
|
||||
session.updated_at = datetime.now()
|
||||
session_db.commit()
|
||||
|
||||
logger.info(f"Updated user {user_id} to step {step_number}")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error updating step: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
||||
def update_progress(self, user_id: str, progress: float, db: Session = None) -> bool:
|
||||
"""Update progress percentage for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
session.progress = progress
|
||||
session.updated_at = datetime.now()
|
||||
session_db.commit()
|
||||
|
||||
logger.info(f"Updated user {user_id} progress to {progress}%")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error updating progress: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
||||
def save_api_key(self, user_id: str, provider: str, api_key: str, db: Session = None) -> bool:
|
||||
"""Save API key for user with isolation."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
# Get user's onboarding session
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
|
||||
# Check if key already exists for this provider and session
|
||||
existing_key = session_db.query(APIKey).filter(
|
||||
APIKey.session_id == session.id,
|
||||
APIKey.provider == provider
|
||||
).first()
|
||||
|
||||
if existing_key:
|
||||
# Update existing key
|
||||
existing_key.key = api_key
|
||||
existing_key.updated_at = datetime.now()
|
||||
logger.info(f"Updated {provider} API key for user {user_id}")
|
||||
else:
|
||||
# Create new key
|
||||
new_key = APIKey(
|
||||
session_id=session.id,
|
||||
provider=provider,
|
||||
key=api_key
|
||||
)
|
||||
session_db.add(new_key)
|
||||
logger.info(f"Created new {provider} API key for user {user_id}")
|
||||
|
||||
session_db.commit()
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error saving API key: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
||||
def get_api_keys(self, user_id: str, db: Session = None) -> Dict[str, str]:
|
||||
"""Get all API keys for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_session_by_user(user_id, session_db)
|
||||
if not session:
|
||||
return {}
|
||||
|
||||
keys = session_db.query(APIKey).filter(
|
||||
APIKey.session_id == session.id
|
||||
).all()
|
||||
|
||||
return {key.provider: key.key for key in keys}
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error getting API keys: {e}")
|
||||
return {}
|
||||
|
||||
def save_website_analysis(self, user_id: str, analysis_data: Dict[str, Any], db: Session = None) -> bool:
|
||||
"""Save website analysis for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
|
||||
# Check if analysis already exists
|
||||
existing = session_db.query(WebsiteAnalysis).filter(
|
||||
WebsiteAnalysis.session_id == session.id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing
|
||||
existing.website_url = analysis_data.get('website_url', existing.website_url)
|
||||
existing.writing_style = analysis_data.get('writing_style')
|
||||
existing.content_characteristics = analysis_data.get('content_characteristics')
|
||||
existing.target_audience = analysis_data.get('target_audience')
|
||||
existing.content_type = analysis_data.get('content_type')
|
||||
existing.recommended_settings = analysis_data.get('recommended_settings')
|
||||
existing.crawl_result = analysis_data.get('crawl_result')
|
||||
existing.style_patterns = analysis_data.get('style_patterns')
|
||||
existing.style_guidelines = analysis_data.get('style_guidelines')
|
||||
existing.status = analysis_data.get('status', 'completed')
|
||||
existing.updated_at = datetime.now()
|
||||
logger.info(f"Updated website analysis for user {user_id}")
|
||||
else:
|
||||
# Create new
|
||||
analysis = WebsiteAnalysis(
|
||||
session_id=session.id,
|
||||
website_url=analysis_data.get('website_url', ''),
|
||||
writing_style=analysis_data.get('writing_style'),
|
||||
content_characteristics=analysis_data.get('content_characteristics'),
|
||||
target_audience=analysis_data.get('target_audience'),
|
||||
content_type=analysis_data.get('content_type'),
|
||||
recommended_settings=analysis_data.get('recommended_settings'),
|
||||
crawl_result=analysis_data.get('crawl_result'),
|
||||
style_patterns=analysis_data.get('style_patterns'),
|
||||
style_guidelines=analysis_data.get('style_guidelines'),
|
||||
status=analysis_data.get('status', 'completed')
|
||||
)
|
||||
session_db.add(analysis)
|
||||
logger.info(f"Created website analysis for user {user_id}")
|
||||
|
||||
session_db.commit()
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error saving website analysis: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
||||
def get_website_analysis(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
|
||||
"""Get website analysis for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_session_by_user(user_id, session_db)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
analysis = session_db.query(WebsiteAnalysis).filter(
|
||||
WebsiteAnalysis.session_id == session.id
|
||||
).first()
|
||||
|
||||
return analysis.to_dict() if analysis else None
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error getting website analysis: {e}")
|
||||
return None
|
||||
|
||||
def save_research_preferences(self, user_id: str, preferences: Dict[str, Any], db: Session = None) -> bool:
|
||||
"""Save research preferences for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
|
||||
# Check if preferences already exist
|
||||
existing = session_db.query(ResearchPreferences).filter(
|
||||
ResearchPreferences.session_id == session.id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing
|
||||
existing.research_depth = preferences.get('research_depth', existing.research_depth)
|
||||
existing.content_types = preferences.get('content_types', existing.content_types)
|
||||
existing.auto_research = preferences.get('auto_research', existing.auto_research)
|
||||
existing.factual_content = preferences.get('factual_content', existing.factual_content)
|
||||
existing.writing_style = preferences.get('writing_style')
|
||||
existing.content_characteristics = preferences.get('content_characteristics')
|
||||
existing.target_audience = preferences.get('target_audience')
|
||||
existing.recommended_settings = preferences.get('recommended_settings')
|
||||
existing.updated_at = datetime.now()
|
||||
logger.info(f"Updated research preferences for user {user_id}")
|
||||
else:
|
||||
# Create new
|
||||
prefs = ResearchPreferences(
|
||||
session_id=session.id,
|
||||
research_depth=preferences.get('research_depth', 'standard'),
|
||||
content_types=preferences.get('content_types', []),
|
||||
auto_research=preferences.get('auto_research', True),
|
||||
factual_content=preferences.get('factual_content', True),
|
||||
writing_style=preferences.get('writing_style'),
|
||||
content_characteristics=preferences.get('content_characteristics'),
|
||||
target_audience=preferences.get('target_audience'),
|
||||
recommended_settings=preferences.get('recommended_settings')
|
||||
)
|
||||
session_db.add(prefs)
|
||||
logger.info(f"Created research preferences for user {user_id}")
|
||||
|
||||
session_db.commit()
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error saving research preferences: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
||||
def save_persona_data(self, user_id: str, persona_data: Dict[str, Any], db: Session = None) -> bool:
|
||||
"""Save persona data for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
|
||||
# Check if persona data already exists for this user
|
||||
existing = session_db.query(PersonaData).filter(
|
||||
PersonaData.session_id == session.id
|
||||
).first()
|
||||
|
||||
if existing:
|
||||
# Update existing persona data
|
||||
existing.core_persona = persona_data.get('corePersona')
|
||||
existing.platform_personas = persona_data.get('platformPersonas')
|
||||
existing.quality_metrics = persona_data.get('qualityMetrics')
|
||||
existing.selected_platforms = persona_data.get('selectedPlatforms', [])
|
||||
existing.updated_at = datetime.utcnow()
|
||||
logger.info(f"Updated persona data for user {user_id}")
|
||||
else:
|
||||
# Create new persona data record
|
||||
persona = PersonaData(
|
||||
session_id=session.id,
|
||||
core_persona=persona_data.get('corePersona'),
|
||||
platform_personas=persona_data.get('platformPersonas'),
|
||||
quality_metrics=persona_data.get('qualityMetrics'),
|
||||
selected_platforms=persona_data.get('selectedPlatforms', [])
|
||||
)
|
||||
session_db.add(persona)
|
||||
logger.info(f"Created persona data for user {user_id}")
|
||||
|
||||
session_db.commit()
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error saving persona data: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
||||
def get_research_preferences(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
|
||||
"""Get research preferences for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_session_by_user(user_id, session_db)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
prefs = session_db.query(ResearchPreferences).filter(
|
||||
ResearchPreferences.session_id == session.id
|
||||
).first()
|
||||
|
||||
return prefs.to_dict() if prefs else None
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error getting research preferences: {e}")
|
||||
return None
|
||||
|
||||
def mark_onboarding_complete(self, user_id: str, db: Session = None) -> bool:
|
||||
"""Mark onboarding as complete for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_or_create_session(user_id, session_db)
|
||||
session.current_step = 6 # Final step
|
||||
session.progress = 100.0
|
||||
session.updated_at = datetime.now()
|
||||
session_db.commit()
|
||||
|
||||
logger.info(f"Marked onboarding complete for user {user_id}")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error marking onboarding complete: {e}")
|
||||
session_db.rollback()
|
||||
return False
|
||||
|
||||
def get_onboarding_status(self, user_id: str, db: Session = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive onboarding status for user."""
|
||||
session_db = db or self.db
|
||||
if not session_db:
|
||||
raise ValueError("Database session required")
|
||||
|
||||
try:
|
||||
session = self.get_session_by_user(user_id, session_db)
|
||||
|
||||
if not session:
|
||||
# User hasn't started onboarding yet
|
||||
return {
|
||||
"is_completed": False,
|
||||
"current_step": 1,
|
||||
"progress": 0.0,
|
||||
"started_at": None,
|
||||
"updated_at": None
|
||||
}
|
||||
|
||||
return {
|
||||
"is_completed": session.current_step >= 6 and session.progress >= 100.0,
|
||||
"current_step": session.current_step,
|
||||
"progress": session.progress,
|
||||
"started_at": session.started_at.isoformat() if session.started_at else None,
|
||||
"updated_at": session.updated_at.isoformat() if session.updated_at else None
|
||||
}
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error getting onboarding status: {e}")
|
||||
return {
|
||||
"is_completed": False,
|
||||
"current_step": 1,
|
||||
"progress": 0.0,
|
||||
"started_at": None,
|
||||
"updated_at": None
|
||||
}
|
||||
|
||||
150
backend/services/user_api_key_context.py
Normal file
150
backend/services/user_api_key_context.py
Normal file
@@ -0,0 +1,150 @@
|
||||
"""
|
||||
User API Key Context Manager
|
||||
Provides user-specific API keys to backend services.
|
||||
|
||||
In development: Uses .env file
|
||||
In production: Fetches from database per user
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict
|
||||
from loguru import logger
|
||||
from contextlib import contextmanager
|
||||
|
||||
class UserAPIKeyContext:
|
||||
"""
|
||||
Context manager for user-specific API keys.
|
||||
|
||||
Usage:
|
||||
with UserAPIKeyContext(user_id) as api_keys:
|
||||
gemini_key = api_keys.get('gemini')
|
||||
exa_key = api_keys.get('exa')
|
||||
# Use keys for this specific user
|
||||
"""
|
||||
|
||||
def __init__(self, user_id: Optional[str] = None):
|
||||
"""
|
||||
Initialize with optional user_id.
|
||||
|
||||
Args:
|
||||
user_id: User ID to fetch keys for. If None, uses .env keys (local mode)
|
||||
"""
|
||||
self.user_id = user_id
|
||||
self.keys: Dict[str, str] = {}
|
||||
self._is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
|
||||
|
||||
def __enter__(self):
|
||||
"""Load API keys when entering context."""
|
||||
if self._is_local:
|
||||
# Local mode: Use .env file
|
||||
self.keys = self._load_from_env()
|
||||
logger.debug(f"[LOCAL] Loaded API keys from .env file")
|
||||
elif self.user_id:
|
||||
# Production mode: Fetch from database
|
||||
self.keys = self._load_from_database(self.user_id)
|
||||
logger.debug(f"[PRODUCTION] Loaded API keys from database for user {self.user_id}")
|
||||
else:
|
||||
logger.warning("No user_id provided in production mode - using empty keys")
|
||||
self.keys = {}
|
||||
|
||||
return self.keys
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Clean up when exiting context."""
|
||||
self.keys.clear()
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
def _load_from_env(self) -> Dict[str, str]:
|
||||
"""Load API keys from environment variables (.env file)."""
|
||||
return {
|
||||
'gemini': os.getenv('GEMINI_API_KEY', ''),
|
||||
'exa': os.getenv('EXA_API_KEY', ''),
|
||||
'copilotkit': os.getenv('COPILOTKIT_API_KEY', ''),
|
||||
'openai': os.getenv('OPENAI_API_KEY', ''),
|
||||
'anthropic': os.getenv('ANTHROPIC_API_KEY', ''),
|
||||
'tavily': os.getenv('TAVILY_API_KEY', ''),
|
||||
'serper': os.getenv('SERPER_API_KEY', ''),
|
||||
'firecrawl': os.getenv('FIRECRAWL_API_KEY', ''),
|
||||
}
|
||||
|
||||
def _load_from_database(self, user_id: str) -> Dict[str, str]:
|
||||
"""Load API keys from database for specific user."""
|
||||
try:
|
||||
from services.onboarding_database_service import OnboardingDatabaseService
|
||||
from services.database import SessionLocal
|
||||
|
||||
db_service = OnboardingDatabaseService()
|
||||
db = SessionLocal()
|
||||
try:
|
||||
keys = db_service.get_api_keys(user_id, db)
|
||||
logger.info(f"Loaded {len(keys)} API keys from database for user {user_id}")
|
||||
return keys
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load API keys from database for user {user_id}: {e}")
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def get_user_key(user_id: Optional[str], provider: str) -> Optional[str]:
|
||||
"""
|
||||
Convenience method to get a single API key for a user.
|
||||
|
||||
Args:
|
||||
user_id: User ID (None for development mode)
|
||||
provider: Provider name (e.g., 'gemini', 'exa')
|
||||
|
||||
Returns:
|
||||
API key string or None
|
||||
"""
|
||||
with UserAPIKeyContext(user_id) as keys:
|
||||
return keys.get(provider)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def user_api_keys(user_id: Optional[str] = None):
|
||||
"""
|
||||
Context manager function for easier usage.
|
||||
|
||||
Usage:
|
||||
from services.user_api_key_context import user_api_keys
|
||||
|
||||
with user_api_keys(user_id) as keys:
|
||||
gemini_key = keys.get('gemini')
|
||||
"""
|
||||
context = UserAPIKeyContext(user_id)
|
||||
try:
|
||||
yield context.__enter__()
|
||||
finally:
|
||||
context.__exit__(None, None, None)
|
||||
|
||||
|
||||
# Convenience function for FastAPI dependency injection
|
||||
def get_user_api_keys(user_id: str) -> Dict[str, str]:
|
||||
"""
|
||||
Get user-specific API keys for use in FastAPI endpoints.
|
||||
|
||||
Args:
|
||||
user_id: User ID from current_user
|
||||
|
||||
Returns:
|
||||
Dictionary of API keys for this user
|
||||
"""
|
||||
with UserAPIKeyContext(user_id) as keys:
|
||||
return keys
|
||||
|
||||
|
||||
def get_gemini_key(user_id: Optional[str] = None) -> Optional[str]:
|
||||
"""Get Gemini API key for user."""
|
||||
return UserAPIKeyContext.get_user_key(user_id, 'gemini')
|
||||
|
||||
|
||||
def get_exa_key(user_id: Optional[str] = None) -> Optional[str]:
|
||||
"""Get Exa API key for user."""
|
||||
return UserAPIKeyContext.get_user_key(user_id, 'exa')
|
||||
|
||||
|
||||
def get_copilotkit_key(user_id: Optional[str] = None) -> Optional[str]:
|
||||
"""Get CopilotKit API key for user."""
|
||||
return UserAPIKeyContext.get_user_key(user_id, 'copilotkit')
|
||||
|
||||
Reference in New Issue
Block a user