ALwrity onboarding final step

This commit is contained in:
ajaysi
2025-10-10 23:19:28 +05:30
parent e3daebec16
commit b1ebe1034e
38 changed files with 4867 additions and 770 deletions

View File

@@ -165,10 +165,10 @@ class OnboardingManager:
raise HTTPException(status_code=500, detail=str(e))
@self.app.post("/api/onboarding/api-keys")
async def api_key_save(request: APIKeyRequest):
async def api_key_save(request: APIKeyRequest, current_user: dict = Depends(get_current_user)):
"""Save an API key for a provider."""
try:
return await save_api_key(request)
return await save_api_key(request, current_user)
except Exception as e:
logger.error(f"Error in api_key_save: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -95,6 +95,10 @@ class RouterManager:
from routers.error_logging import router as error_logging_router
self.include_router_safely(error_logging_router, "error_logging")
# Frontend environment manager router
from routers.frontend_env_manager import router as frontend_env_router
self.include_router_safely(frontend_env_router, "frontend_env_manager")
logger.info("✅ Core routers included successfully")
return True

View File

@@ -15,7 +15,20 @@ class APIKeyManagementService:
"""Service for handling API key management operations."""
def __init__(self):
# Initialize APIKeyManager with database support
self.api_key_manager = APIKeyManager()
# Ensure database service is available
if not hasattr(self.api_key_manager, 'use_database'):
self.api_key_manager.use_database = True
try:
from services.onboarding_database_service import OnboardingDatabaseService
self.api_key_manager.db_service = OnboardingDatabaseService()
logger.info("Database service initialized for APIKeyManager")
except Exception as e:
logger.warning(f"Database service not available: {e}")
self.api_key_manager.use_database = False
self.api_key_manager.db_service = None
# Simple cache for API keys
self._api_keys_cache = None
self._cache_timestamp = 0
@@ -75,9 +88,16 @@ class APIKeyManagementService:
logger.error(f"Error getting API keys for onboarding: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
async def save_api_key(self, provider: str, api_key: str, description: str = None) -> Dict[str, Any]:
async def save_api_key(self, provider: str, api_key: str, description: str = None, current_user: dict = None) -> Dict[str, Any]:
"""Save an API key for a provider."""
try:
logger.info(f"📝 save_api_key called for provider: {provider}")
# Set user_id on the API key manager if available
if current_user and current_user.get('id'):
self.api_key_manager.user_id = current_user['id']
logger.info(f"Set user_id on APIKeyManager: {current_user['id']}")
success = self.api_key_manager.save_api_key(provider, api_key)
if success:

View File

@@ -35,11 +35,11 @@ async def get_api_keys_for_onboarding():
raise HTTPException(status_code=500, detail="Internal server error")
async def save_api_key(request: APIKeyRequest):
async def save_api_key(request: APIKeyRequest, current_user: dict = None):
try:
from api.onboarding_utils.api_key_management_service import APIKeyManagementService
api_service = APIKeyManagementService()
return await api_service.save_api_key(request.provider, request.api_key, request.description)
return await api_service.save_api_key(request.provider, request.api_key, request.description, current_user)
except Exception as e:
logger.error(f"Error saving API key: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")

View File

@@ -110,6 +110,16 @@ async def rate_limit_middleware(request: Request, call_next):
"""Rate limiting middleware using modular utilities."""
return await rate_limiter.rate_limit_middleware(request, call_next)
# API key injection middleware for production (user-specific keys)
@app.middleware("http")
async def inject_user_api_keys(request: Request, call_next):
"""
Inject user-specific API keys into environment for the request duration.
This allows existing code using os.getenv() to work in production.
"""
from middleware.api_key_injection_middleware import api_key_injection_middleware
return await api_key_injection_middleware(request, call_next)
# Health check endpoints using modular utilities
@app.get("/health")
async def health():

View File

@@ -0,0 +1,26 @@
-- Migration: Add persona_data table for onboarding step 4
-- Created: 2025-10-10
-- Description: Adds table to store persona generation data from onboarding step 4
CREATE TABLE IF NOT EXISTS persona_data (
id SERIAL PRIMARY KEY,
session_id INTEGER NOT NULL,
core_persona JSONB,
platform_personas JSONB,
quality_metrics JSONB,
selected_platforms JSONB,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (session_id) REFERENCES onboarding_sessions(id) ON DELETE CASCADE
);
-- Add index for better query performance
CREATE INDEX IF NOT EXISTS idx_persona_data_session_id ON persona_data(session_id);
CREATE INDEX IF NOT EXISTS idx_persona_data_created_at ON persona_data(created_at);
-- Add comment to table
COMMENT ON TABLE persona_data IS 'Stores persona generation data from onboarding step 4';
COMMENT ON COLUMN persona_data.core_persona IS 'Core persona data (demographics, psychographics, etc.)';
COMMENT ON COLUMN persona_data.platform_personas IS 'Platform-specific personas (LinkedIn, Twitter, etc.)';
COMMENT ON COLUMN persona_data.quality_metrics IS 'Quality assessment metrics';
COMMENT ON COLUMN persona_data.selected_platforms IS 'Array of selected platforms';

View File

@@ -22,3 +22,6 @@ WORDPRESS_REDIRECT_URI=
# Development Settings
DISABLE_AUTH=false
# local development
DEPLOY_ENV=local

View File

@@ -0,0 +1,114 @@
"""
API Key Injection Middleware
Temporarily injects user-specific API keys into os.environ for the duration of the request.
This allows existing code that uses os.getenv('GEMINI_API_KEY') to work without modification.
IMPORTANT: This is a compatibility layer. For new code, use UserAPIKeyContext directly.
"""
import os
from fastapi import Request
from loguru import logger
from typing import Callable
from services.user_api_key_context import user_api_keys
class APIKeyInjectionMiddleware:
"""
Middleware that injects user-specific API keys into environment variables
for the duration of each request.
"""
def __init__(self):
self.original_keys = {}
async def __call__(self, request: Request, call_next: Callable):
"""
Inject user-specific API keys before processing request,
restore original values after request completes.
"""
# Try to extract user_id from Authorization header
user_id = None
auth_header = request.headers.get('Authorization')
if auth_header and auth_header.startswith('Bearer '):
try:
from middleware.auth_middleware import clerk_auth
token = auth_header.replace('Bearer ', '')
user = await clerk_auth.verify_token(token)
if user:
# Try different possible keys for user_id
user_id = user.get('user_id') or user.get('clerk_user_id') or user.get('id')
logger.debug(f"[API Key Injection] Extracted user_id: {user_id}")
except Exception as e:
logger.debug(f"[API Key Injection] Could not extract user from token: {e}")
if not user_id:
# No authenticated user, proceed without injection
return await call_next(request)
# Check if we're in production mode
is_production = os.getenv('DEPLOY_ENV', 'local') == 'production'
if not is_production:
# Local mode - keys already in .env, no injection needed
return await call_next(request)
# Get user-specific API keys from database
with user_api_keys(user_id) as user_keys:
if not user_keys:
logger.warning(f"No API keys found for user {user_id}")
return await call_next(request)
# Save original environment values
original_keys = {}
keys_to_inject = {
'gemini': 'GEMINI_API_KEY',
'exa': 'EXA_API_KEY',
'copilotkit': 'COPILOTKIT_API_KEY',
'openai': 'OPENAI_API_KEY',
'anthropic': 'ANTHROPIC_API_KEY',
'tavily': 'TAVILY_API_KEY',
'serper': 'SERPER_API_KEY',
'firecrawl': 'FIRECRAWL_API_KEY',
}
# Inject user-specific keys into environment
for provider, env_var in keys_to_inject.items():
if provider in user_keys and user_keys[provider]:
# Save original value (if any)
original_keys[env_var] = os.environ.get(env_var)
# Inject user-specific key
os.environ[env_var] = user_keys[provider]
logger.debug(f"[PRODUCTION] Injected {env_var} for user {user_id}")
try:
# Process request with user-specific keys in environment
response = await call_next(request)
return response
finally:
# CRITICAL: Restore original environment values
for env_var, original_value in original_keys.items():
if original_value is None:
# Key didn't exist before, remove it
os.environ.pop(env_var, None)
else:
# Restore original value
os.environ[env_var] = original_value
logger.debug(f"[PRODUCTION] Cleaned up environment for user {user_id}")
async def api_key_injection_middleware(request: Request, call_next: Callable):
"""
Middleware function that injects user-specific API keys into environment.
Usage in app.py:
app.middleware("http")(api_key_injection_middleware)
"""
middleware = APIKeyInjectionMiddleware()
return await middleware(request, call_next)

View File

@@ -16,6 +16,7 @@ class OnboardingSession(Base):
api_keys = relationship('APIKey', back_populates='session', cascade="all, delete-orphan")
website_analyses = relationship('WebsiteAnalysis', back_populates='session', cascade="all, delete-orphan")
research_preferences = relationship('ResearchPreferences', back_populates='session', cascade="all, delete-orphan", uselist=False)
persona_data = relationship('PersonaData', back_populates='session', cascade="all, delete-orphan", uselist=False)
def __repr__(self):
return f"<OnboardingSession(id={self.id}, user_id={self.user_id}, step={self.current_step}, progress={self.progress})>"
@@ -143,4 +144,40 @@ class ResearchPreferences(Base):
'recommended_settings': self.recommended_settings,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
}
}
class PersonaData(Base):
"""Stores persona generation data from onboarding step 4."""
__tablename__ = 'persona_data'
id = Column(Integer, primary_key=True, autoincrement=True)
session_id = Column(Integer, ForeignKey('onboarding_sessions.id', ondelete='CASCADE'), nullable=False)
# Persona generation results
core_persona = Column(JSON) # Core persona data (demographics, psychographics, etc.)
platform_personas = Column(JSON) # Platform-specific personas (LinkedIn, Twitter, etc.)
quality_metrics = Column(JSON) # Quality assessment metrics
selected_platforms = Column(JSON) # Array of selected platforms
# Metadata
created_at = Column(DateTime, default=func.now())
updated_at = Column(DateTime, default=func.now(), onupdate=func.now())
# Relationships
session = relationship('OnboardingSession', back_populates='persona_data')
def __repr__(self):
return f"<PersonaData(id={self.id}, session_id={self.session_id})>"
def to_dict(self):
"""Convert to dictionary for API responses."""
return {
'id': self.id,
'session_id': self.session_id,
'core_persona': self.core_persona,
'platform_personas': self.platform_personas,
'quality_metrics': self.quality_metrics,
'selected_platforms': self.selected_platforms,
'created_at': self.created_at.isoformat() if self.created_at else None,
'updated_at': self.updated_at.isoformat() if self.updated_at else None
}

View File

@@ -0,0 +1,110 @@
"""
Frontend Environment Manager
Handles updating frontend environment variables (for development purposes).
"""
from fastapi import APIRouter, HTTPException, Depends
from pydantic import BaseModel
from typing import Dict, Any, Optional
from loguru import logger
import os
from pathlib import Path
router = APIRouter(
prefix="/api/frontend-env",
tags=["Frontend Environment"],
)
class FrontendEnvUpdateRequest(BaseModel):
key: str
value: str
description: Optional[str] = None
@router.post("/update")
async def update_frontend_env(request: FrontendEnvUpdateRequest):
"""
Update frontend environment variable (for development purposes).
This writes to the frontend/.env file.
"""
try:
# Get the frontend directory path
backend_dir = Path(__file__).parent.parent
frontend_dir = backend_dir.parent / "frontend"
env_path = frontend_dir / ".env"
# Ensure the frontend directory exists
if not frontend_dir.exists():
raise HTTPException(status_code=404, detail="Frontend directory not found")
# Read existing .env file
env_lines = []
if env_path.exists():
with open(env_path, 'r') as f:
env_lines = f.readlines()
# Update or add the environment variable
key_found = False
updated_lines = []
for line in env_lines:
if line.startswith(f"{request.key}="):
updated_lines.append(f"{request.key}={request.value}\n")
key_found = True
else:
updated_lines.append(line)
if not key_found:
# Add comment if description provided
if request.description:
updated_lines.append(f"# {request.description}\n")
updated_lines.append(f"{request.key}={request.value}\n")
# Write back to .env file
with open(env_path, 'w') as f:
f.writelines(updated_lines)
logger.info(f"Updated frontend environment variable: {request.key}")
return {
"success": True,
"message": f"Environment variable {request.key} updated successfully",
"key": request.key,
"value": request.value
}
except Exception as e:
logger.error(f"Error updating frontend environment: {e}")
raise HTTPException(status_code=500, detail=f"Failed to update environment variable: {str(e)}")
@router.get("/status")
async def get_frontend_env_status():
"""
Get status of frontend environment file.
"""
try:
# Get the frontend directory path
backend_dir = Path(__file__).parent.parent
frontend_dir = backend_dir.parent / "frontend"
env_path = frontend_dir / ".env"
if not env_path.exists():
return {
"exists": False,
"path": str(env_path),
"message": "Frontend .env file does not exist"
}
# Read and return basic info about the .env file
with open(env_path, 'r') as f:
content = f.read()
return {
"exists": True,
"path": str(env_path),
"size": len(content),
"lines": len(content.splitlines()),
"message": "Frontend .env file exists"
}
except Exception as e:
logger.error(f"Error checking frontend environment status: {e}")
raise HTTPException(status_code=500, detail=f"Failed to check environment status: {str(e)}")

View File

@@ -0,0 +1,124 @@
"""
Script to create the persona_data table for onboarding step 4.
This migration adds support for storing persona generation data.
Usage:
python backend/scripts/create_persona_data_table.py
"""
import sys
import os
from pathlib import Path
# Add backend directory to path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from loguru import logger
from sqlalchemy import inspect
def create_persona_data_table():
"""Create the persona_data table."""
try:
# Import after path is set
from services.database import engine
from models.onboarding import Base as OnboardingBase, PersonaData
logger.info("🔍 Checking if persona_data table exists...")
# Check if table already exists
inspector = inspect(engine)
existing_tables = inspector.get_table_names()
if 'persona_data' in existing_tables:
logger.info("✅ persona_data table already exists")
return True
logger.info("📊 Creating persona_data table...")
# Create only the persona_data table
PersonaData.__table__.create(bind=engine, checkfirst=True)
logger.info("✅ persona_data table created successfully")
# Verify creation
inspector = inspect(engine)
existing_tables = inspector.get_table_names()
if 'persona_data' in existing_tables:
logger.info("✅ Verification successful - persona_data table exists")
# Show table structure
columns = inspector.get_columns('persona_data')
logger.info(f"📋 Table structure ({len(columns)} columns):")
for col in columns:
logger.info(f" - {col['name']}: {col['type']}")
return True
else:
logger.error("❌ Table creation verification failed")
return False
except Exception as e:
logger.error(f"❌ Error creating persona_data table: {e}")
import traceback
logger.error(traceback.format_exc())
return False
def check_onboarding_tables():
"""Check all onboarding-related tables."""
try:
from services.database import engine
from sqlalchemy import inspect
inspector = inspect(engine)
existing_tables = inspector.get_table_names()
onboarding_tables = [
'onboarding_sessions',
'api_keys',
'website_analyses',
'research_preferences',
'persona_data'
]
logger.info("📋 Onboarding Tables Status:")
for table in onboarding_tables:
status = "" if table in existing_tables else ""
logger.info(f" {status} {table}")
return True
except Exception as e:
logger.error(f"Error checking tables: {e}")
return False
if __name__ == "__main__":
logger.info("=" * 60)
logger.info("Persona Data Table Migration")
logger.info("=" * 60)
# Check existing tables
check_onboarding_tables()
logger.info("")
# Create persona_data table
if create_persona_data_table():
logger.info("")
logger.info("=" * 60)
logger.info("✅ Migration completed successfully!")
logger.info("=" * 60)
# Check again to confirm
logger.info("")
check_onboarding_tables()
sys.exit(0)
else:
logger.error("")
logger.error("=" * 60)
logger.error("❌ Migration failed!")
logger.error("=" * 60)
sys.exit(1)

View File

@@ -0,0 +1,338 @@
"""
Database Verification Script for Onboarding Data
Verifies that all onboarding steps data is properly saved to the database.
Usage:
python backend/scripts/verify_onboarding_data.py [user_id]
Example:
python backend/scripts/verify_onboarding_data.py user_33Gz1FPI86VDXhRY8QN4ragRFGN
"""
import sys
import os
from pathlib import Path
# Add backend directory to path
backend_dir = Path(__file__).parent.parent
sys.path.insert(0, str(backend_dir))
from loguru import logger
from sqlalchemy import inspect, text
from typing import Optional
import json
def get_user_id_from_args() -> Optional[str]:
"""Get user_id from command line arguments."""
if len(sys.argv) > 1:
return sys.argv[1]
return None
def verify_table_exists(table_name: str, inspector) -> bool:
"""Check if a table exists in the database."""
tables = inspector.get_table_names()
exists = table_name in tables
if exists:
logger.info(f"✅ Table '{table_name}' exists")
# Show column count
columns = inspector.get_columns(table_name)
logger.info(f" Columns: {len(columns)}")
else:
logger.error(f"❌ Table '{table_name}' does NOT exist")
return exists
def verify_onboarding_session(user_id: str, db):
"""Verify onboarding session data."""
try:
from models.onboarding import OnboardingSession
session = db.query(OnboardingSession).filter(
OnboardingSession.user_id == user_id
).first()
if session:
logger.info(f"✅ Onboarding Session found for user: {user_id}")
logger.info(f" Session ID: {session.id}")
logger.info(f" Current Step: {session.current_step}")
logger.info(f" Progress: {session.progress}%")
logger.info(f" Started At: {session.started_at}")
logger.info(f" Updated At: {session.updated_at}")
return session.id
else:
logger.error(f"❌ No onboarding session found for user: {user_id}")
return None
except Exception as e:
logger.error(f"Error verifying onboarding session: {e}")
return None
def verify_api_keys(session_id: int, user_id: str, db):
"""Verify API keys data (Step 1)."""
try:
from models.onboarding import APIKey
api_keys = db.query(APIKey).filter(
APIKey.session_id == session_id
).all()
if api_keys:
logger.info(f"✅ Step 1 (API Keys): Found {len(api_keys)} API key(s)")
for key in api_keys:
# Mask the key for security
masked_key = f"{key.key[:8]}...{key.key[-4:]}" if len(key.key) > 12 else "***"
logger.info(f" - Provider: {key.provider}")
logger.info(f" Key: {masked_key}")
logger.info(f" Created: {key.created_at}")
else:
logger.warning(f"⚠️ Step 1 (API Keys): No API keys found")
except Exception as e:
logger.error(f"Error verifying API keys: {e}")
def verify_website_analysis(session_id: int, user_id: str, db):
"""Verify website analysis data (Step 2)."""
try:
from models.onboarding import WebsiteAnalysis
analysis = db.query(WebsiteAnalysis).filter(
WebsiteAnalysis.session_id == session_id
).first()
if analysis:
logger.info(f"✅ Step 2 (Website Analysis): Data found")
logger.info(f" Website URL: {analysis.website_url}")
logger.info(f" Analysis Date: {analysis.analysis_date}")
logger.info(f" Status: {analysis.status}")
if analysis.writing_style:
logger.info(f" Writing Style: {len(analysis.writing_style)} attributes")
if analysis.content_characteristics:
logger.info(f" Content Characteristics: {len(analysis.content_characteristics)} attributes")
if analysis.target_audience:
logger.info(f" Target Audience: {len(analysis.target_audience)} attributes")
else:
logger.warning(f"⚠️ Step 2 (Website Analysis): No data found")
except Exception as e:
logger.error(f"Error verifying website analysis: {e}")
def verify_research_preferences(session_id: int, user_id: str, db):
"""Verify research preferences data (Step 3)."""
try:
from models.onboarding import ResearchPreferences
prefs = db.query(ResearchPreferences).filter(
ResearchPreferences.session_id == session_id
).first()
if prefs:
logger.info(f"✅ Step 3 (Research Preferences): Data found")
logger.info(f" Research Depth: {prefs.research_depth}")
logger.info(f" Content Types: {prefs.content_types}")
logger.info(f" Auto Research: {prefs.auto_research}")
logger.info(f" Factual Content: {prefs.factual_content}")
else:
logger.warning(f"⚠️ Step 3 (Research Preferences): No data found")
except Exception as e:
logger.error(f"Error verifying research preferences: {e}")
def verify_persona_data(session_id: int, user_id: str, db):
"""Verify persona data (Step 4) - THE NEW FIX!"""
try:
from models.onboarding import PersonaData
persona = db.query(PersonaData).filter(
PersonaData.session_id == session_id
).first()
if persona:
logger.info(f"✅ Step 4 (Persona Generation): Data found ⭐")
if persona.core_persona:
logger.info(f" Core Persona: Present")
if isinstance(persona.core_persona, dict):
logger.info(f" Attributes: {len(persona.core_persona)} fields")
if persona.platform_personas:
logger.info(f" Platform Personas: Present")
if isinstance(persona.platform_personas, dict):
platforms = list(persona.platform_personas.keys())
logger.info(f" Platforms: {', '.join(platforms)}")
if persona.quality_metrics:
logger.info(f" Quality Metrics: Present")
if isinstance(persona.quality_metrics, dict):
logger.info(f" Metrics: {len(persona.quality_metrics)} fields")
if persona.selected_platforms:
logger.info(f" Selected Platforms: {persona.selected_platforms}")
logger.info(f" Created At: {persona.created_at}")
logger.info(f" Updated At: {persona.updated_at}")
else:
logger.error(f"❌ Step 4 (Persona Generation): No data found - THIS IS THE BUG WE FIXED!")
except Exception as e:
logger.error(f"Error verifying persona data: {e}")
import traceback
logger.error(traceback.format_exc())
def show_raw_sql_query_example(user_id: str):
"""Show example SQL queries for manual verification."""
logger.info("")
logger.info("=" * 60)
logger.info("📋 Raw SQL Queries for Manual Verification:")
logger.info("=" * 60)
queries = [
("Onboarding Session",
f"SELECT * FROM onboarding_sessions WHERE user_id = '{user_id}';"),
("API Keys",
f"""SELECT ak.* FROM api_keys ak
JOIN onboarding_sessions os ON ak.session_id = os.id
WHERE os.user_id = '{user_id}';"""),
("Website Analysis",
f"""SELECT wa.website_url, wa.analysis_date, wa.status
FROM website_analyses wa
JOIN onboarding_sessions os ON wa.session_id = os.id
WHERE os.user_id = '{user_id}';"""),
("Research Preferences",
f"""SELECT rp.research_depth, rp.content_types, rp.auto_research
FROM research_preferences rp
JOIN onboarding_sessions os ON rp.session_id = os.id
WHERE os.user_id = '{user_id}';"""),
("Persona Data (NEW!)",
f"""SELECT pd.* FROM persona_data pd
JOIN onboarding_sessions os ON pd.session_id = os.id
WHERE os.user_id = '{user_id}';"""),
]
for title, query in queries:
logger.info(f"\n{title}:")
logger.info(f" {query}")
def count_all_records(db):
"""Count records in all onboarding tables."""
logger.info("")
logger.info("=" * 60)
logger.info("📊 Overall Database Statistics:")
logger.info("=" * 60)
try:
from models.onboarding import (
OnboardingSession, APIKey, WebsiteAnalysis,
ResearchPreferences, PersonaData
)
counts = {
"Onboarding Sessions": db.query(OnboardingSession).count(),
"API Keys": db.query(APIKey).count(),
"Website Analyses": db.query(WebsiteAnalysis).count(),
"Research Preferences": db.query(ResearchPreferences).count(),
"Persona Data": db.query(PersonaData).count(),
}
for table, count in counts.items():
logger.info(f" {table}: {count} record(s)")
except Exception as e:
logger.error(f"Error counting records: {e}")
def main():
"""Main verification function."""
logger.info("=" * 60)
logger.info("🔍 Onboarding Database Verification")
logger.info("=" * 60)
# Get user_id
user_id = get_user_id_from_args()
if not user_id:
logger.warning("⚠️ No user_id provided. Will show overall statistics only.")
logger.info("Usage: python backend/scripts/verify_onboarding_data.py <user_id>")
try:
from services.database import SessionLocal, engine
from sqlalchemy import inspect
# Check tables exist
logger.info("")
logger.info("=" * 60)
logger.info("1⃣ Verifying Database Tables:")
logger.info("=" * 60)
inspector = inspect(engine)
tables = [
'onboarding_sessions',
'api_keys',
'website_analyses',
'research_preferences',
'persona_data'
]
all_exist = True
for table in tables:
if not verify_table_exists(table, inspector):
all_exist = False
if not all_exist:
logger.error("")
logger.error("❌ Some tables are missing! Run migrations first.")
return False
# Count all records
db = SessionLocal()
try:
count_all_records(db)
# If user_id provided, show detailed data
if user_id:
logger.info("")
logger.info("=" * 60)
logger.info(f"2⃣ Verifying Data for User: {user_id}")
logger.info("=" * 60)
# Verify session
session_id = verify_onboarding_session(user_id, db)
if session_id:
logger.info("")
# Verify each step's data
verify_api_keys(session_id, user_id, db)
logger.info("")
verify_website_analysis(session_id, user_id, db)
logger.info("")
verify_research_preferences(session_id, user_id, db)
logger.info("")
verify_persona_data(session_id, user_id, db)
# Show SQL examples
show_raw_sql_query_example(user_id)
logger.info("")
logger.info("=" * 60)
logger.info("✅ Verification Complete!")
logger.info("=" * 60)
return True
finally:
db.close()
except Exception as e:
logger.error(f"❌ Verification failed: {e}")
import traceback
logger.error(traceback.format_exc())
return False
if __name__ == "__main__":
success = main()
sys.exit(0 if success else 1)

View File

@@ -35,14 +35,31 @@ class StepData:
class OnboardingProgress:
"""Manages onboarding progress with persistence and validation."""
def __init__(self, progress_file: Optional[str] = None):
def __init__(self, progress_file: Optional[str] = None, user_id: Optional[str] = None):
self.steps = self._initialize_steps()
self.current_step = 1
self.started_at = datetime.now().isoformat()
self.last_updated = datetime.now().isoformat()
self.is_completed = False
self.completed_at = None
self.progress_file = progress_file or ".onboarding_progress.json"
self.user_id = user_id # Add user_id for database isolation
# Use user-specific file for backward compatibility
if user_id:
self.progress_file = progress_file or f".onboarding_progress_{user_id}.json"
else:
self.progress_file = progress_file or ".onboarding_progress.json"
# Initialize database service for dual persistence
try:
from services.onboarding_database_service import OnboardingDatabaseService
self.db_service = OnboardingDatabaseService()
self.use_database = True
logger.info(f"Database service initialized for user {user_id}")
except Exception as e:
logger.warning(f"Database service not available, using file only: {e}")
self.db_service = None
self.use_database = False
# Load existing progress if available
self.load_progress()
@@ -192,8 +209,9 @@ class OnboardingProgress:
logger.info("Onboarding completed successfully")
def save_progress(self):
"""Save progress to file."""
"""Save progress to both file and database (dual persistence)."""
try:
# Save to JSON file (backward compatibility)
progress_data = {
"steps": [{
"step_number": step.step_number,
@@ -215,6 +233,65 @@ class OnboardingProgress:
json.dump(progress_data, f, indent=2)
logger.debug(f"Progress saved to {self.progress_file}")
# Also save to database if available and user_id is set
if self.use_database and self.db_service and self.user_id:
try:
from services.database import SessionLocal
db = SessionLocal()
try:
# Update session progress
self.db_service.update_step(self.user_id, self.current_step, db)
# Calculate progress percentage
completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
progress_pct = (completed_count / len(self.steps)) * 100
self.db_service.update_progress(self.user_id, progress_pct, db)
# Save step-specific data to appropriate tables
for step in self.steps:
if step.status == StepStatus.COMPLETED and step.data:
if step.step_number == 1: # API Keys
api_keys = step.data.get('api_keys', {})
for provider, key in api_keys.items():
if key:
# Save to database (for user isolation in production)
self.db_service.save_api_key(self.user_id, provider, key, db)
# Also save to .env file ONLY in local development
# This allows local developers to have keys in .env for convenience
# In production, keys are fetched from database per user
is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
if is_local:
try:
from services.api_key_manager import APIKeyManager
api_key_manager = APIKeyManager()
api_key_manager.save_api_key(provider, key)
logger.info(f"[LOCAL] API key for {provider} saved to .env file")
except Exception as env_error:
logger.warning(f"[LOCAL] Failed to save {provider} API key to .env file: {env_error}")
else:
logger.info(f"[PRODUCTION] API key for {provider} saved to database only (user: {self.user_id})")
# Log database save confirmation
logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
elif step.step_number == 2: # Website Analysis
self.db_service.save_website_analysis(self.user_id, step.data, db)
logger.info(f"✅ DATABASE: Website analysis saved to database for user {self.user_id}")
elif step.step_number == 3: # Research Preferences
self.db_service.save_research_preferences(self.user_id, step.data, db)
logger.info(f"✅ DATABASE: Research preferences saved to database for user {self.user_id}")
elif step.step_number == 4: # Persona Generation
self.db_service.save_persona_data(self.user_id, step.data, db)
logger.info(f"✅ DATABASE: Persona data saved to database for user {self.user_id}")
logger.info(f"Progress also saved to database for user {self.user_id}")
finally:
db.close()
except Exception as db_error:
logger.warning(f"Failed to save to database, JSON file still saved: {db_error}")
# Don't fail if database save fails - JSON is still working
except Exception as e:
logger.error(f"Error saving progress: {str(e)}")
@@ -423,8 +500,34 @@ class APIKeyManager:
try:
if provider in self.api_keys:
self.api_keys[provider] = api_key
self._save_to_env_file(provider, api_key)
logger.info(f"API key saved for {provider}")
# Save to database if available and user_id is set
if hasattr(self, 'use_database') and self.use_database and hasattr(self, 'db_service') and self.db_service and hasattr(self, 'user_id') and self.user_id:
try:
from services.database import SessionLocal
db = SessionLocal()
try:
self.db_service.save_api_key(self.user_id, provider, api_key, db)
logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
finally:
db.close()
except Exception as db_error:
logger.warning(f"Failed to save {provider} API key to database: {db_error}")
# Also save to .env file in local mode
is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
if is_local:
# Special handling for CopilotKit - save to frontend/.env
if provider == 'copilotkit':
self._save_to_frontend_env(api_key)
logger.info(f"[LOCAL] CopilotKit API key saved to frontend/.env file")
else:
# Save other keys to backend/.env
self._save_to_env_file(provider, api_key)
logger.info(f"[LOCAL] API key for {provider} saved to backend/.env file")
else:
logger.info(f"[PRODUCTION] API key for {provider} saved to memory only (database handles persistence)")
return True
else:
logger.error(f"Unknown provider: {provider}")
@@ -490,8 +593,50 @@ class APIKeyManager:
"total_providers": len(self.api_keys)
}
def _save_to_frontend_env(self, api_key: str):
"""Save CopilotKit API key to frontend/.env file."""
try:
# Get the frontend directory path
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
frontend_dir = os.path.join(os.path.dirname(backend_dir), "frontend")
env_path = os.path.join(frontend_dir, ".env")
# Read existing .env file
if os.path.exists(env_path):
with open(env_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
else:
lines = []
# Update or add REACT_APP_COPILOTKIT_API_KEY
key_found = False
updated_lines = []
env_var = "REACT_APP_COPILOTKIT_API_KEY"
for line in lines:
if line.startswith(f"{env_var}="):
updated_lines.append(f"{env_var}={api_key}\n")
key_found = True
else:
updated_lines.append(line)
if not key_found:
# Ensure the file ends with a newline before adding new key
if updated_lines and not updated_lines[-1].endswith('\n'):
updated_lines[-1] += '\n'
updated_lines.append(f"{env_var}={api_key}\n")
# Write back to frontend .env file
with open(env_path, 'w', encoding='utf-8') as f:
f.writelines(updated_lines)
logger.debug(f"CopilotKit API key saved to frontend .env file")
except Exception as e:
logger.error(f"Error saving to frontend .env file: {str(e)}")
def _save_to_env_file(self, provider: str, api_key: str):
"""Save API key to .env file."""
"""Save API key to backend .env file."""
try:
env_mapping = {
"openai": "OPENAI_API_KEY",
@@ -513,11 +658,10 @@ class APIKeyManager:
os.environ[env_var] = api_key
# Update .env file - use backend directory path
import os
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
env_path = os.path.join(backend_dir, ".env")
if os.path.exists(env_path):
with open(env_path, 'r') as f:
with open(env_path, 'r', encoding='utf-8', errors='ignore') as f:
lines = f.readlines()
else:
lines = []
@@ -532,13 +676,23 @@ class APIKeyManager:
updated_lines.append(line)
if not key_found:
# Ensure the file ends with a newline before adding new key
if updated_lines and not updated_lines[-1].endswith('\n'):
updated_lines[-1] += '\n'
updated_lines.append(f"{env_var}={api_key}\n")
with open(env_path, 'w') as f:
with open(env_path, 'w', encoding='utf-8') as f:
f.writelines(updated_lines)
# Reload environment variables
load_dotenv(override=True)
# Reload environment variables into current process
load_dotenv(env_path, override=True)
# Verify the key is now in environment
loaded_key = os.environ.get(env_var)
if loaded_key == api_key:
logger.info(f"{env_var} loaded into environment (available for immediate use)")
else:
logger.warning(f"⚠️ {env_var} written to .env but not in environment yet")
logger.debug(f"API key saved to .env file for {provider}")
except Exception as e:
@@ -555,13 +709,17 @@ def get_onboarding_progress() -> OnboardingProgress:
return get_onboarding_progress._instance
def get_onboarding_progress_for_user(user_id: str) -> OnboardingProgress:
"""Get or create a per-user onboarding progress instance persisted to a user-specific file."""
"""Get or create a per-user onboarding progress instance with database persistence."""
global _user_onboarding_progress_cache
safe_user_id = ''.join([c if c.isalnum() or c in ('-', '_') else '_' for c in str(user_id)])
if safe_user_id in _user_onboarding_progress_cache:
return _user_onboarding_progress_cache[safe_user_id]
# Create user-specific progress file for backward compatibility
progress_file = f".onboarding_progress_{safe_user_id}.json"
instance = OnboardingProgress(progress_file=progress_file)
# Pass user_id to enable database persistence
instance = OnboardingProgress(progress_file=progress_file, user_id=user_id)
_user_onboarding_progress_cache[safe_user_id] = instance
return instance

View File

@@ -0,0 +1,418 @@
"""
Onboarding Database Service
Provides database-backed storage for onboarding progress with user isolation.
This replaces the JSON file-based storage with proper database persistence.
"""
from typing import Dict, Any, Optional, List
from datetime import datetime
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError
from models.onboarding import OnboardingSession, APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData
from services.database import get_db
class OnboardingDatabaseService:
"""Database service for onboarding with user isolation."""
def __init__(self, db: Session = None):
"""Initialize with optional database session."""
self.db = db
def get_or_create_session(self, user_id: str, db: Session = None) -> OnboardingSession:
"""Get existing onboarding session or create new one for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
# Try to get existing session for this user
session = session_db.query(OnboardingSession).filter(
OnboardingSession.user_id == user_id
).first()
if session:
logger.info(f"Found existing onboarding session for user {user_id}")
return session
# Create new session
session = OnboardingSession(
user_id=user_id,
current_step=1,
progress=0.0,
started_at=datetime.now()
)
session_db.add(session)
session_db.commit()
session_db.refresh(session)
logger.info(f"Created new onboarding session for user {user_id}")
return session
except SQLAlchemyError as e:
logger.error(f"Database error in get_or_create_session: {e}")
session_db.rollback()
raise
def get_session_by_user(self, user_id: str, db: Session = None) -> Optional[OnboardingSession]:
"""Get onboarding session for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
return session_db.query(OnboardingSession).filter(
OnboardingSession.user_id == user_id
).first()
except SQLAlchemyError as e:
logger.error(f"Error getting session: {e}")
return None
def update_step(self, user_id: str, step_number: int, db: Session = None) -> bool:
"""Update current step for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
session.current_step = step_number
session.updated_at = datetime.now()
session_db.commit()
logger.info(f"Updated user {user_id} to step {step_number}")
return True
except SQLAlchemyError as e:
logger.error(f"Error updating step: {e}")
session_db.rollback()
return False
def update_progress(self, user_id: str, progress: float, db: Session = None) -> bool:
"""Update progress percentage for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
session.progress = progress
session.updated_at = datetime.now()
session_db.commit()
logger.info(f"Updated user {user_id} progress to {progress}%")
return True
except SQLAlchemyError as e:
logger.error(f"Error updating progress: {e}")
session_db.rollback()
return False
def save_api_key(self, user_id: str, provider: str, api_key: str, db: Session = None) -> bool:
"""Save API key for user with isolation."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
# Get user's onboarding session
session = self.get_or_create_session(user_id, session_db)
# Check if key already exists for this provider and session
existing_key = session_db.query(APIKey).filter(
APIKey.session_id == session.id,
APIKey.provider == provider
).first()
if existing_key:
# Update existing key
existing_key.key = api_key
existing_key.updated_at = datetime.now()
logger.info(f"Updated {provider} API key for user {user_id}")
else:
# Create new key
new_key = APIKey(
session_id=session.id,
provider=provider,
key=api_key
)
session_db.add(new_key)
logger.info(f"Created new {provider} API key for user {user_id}")
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving API key: {e}")
session_db.rollback()
return False
def get_api_keys(self, user_id: str, db: Session = None) -> Dict[str, str]:
"""Get all API keys for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return {}
keys = session_db.query(APIKey).filter(
APIKey.session_id == session.id
).all()
return {key.provider: key.key for key in keys}
except SQLAlchemyError as e:
logger.error(f"Error getting API keys: {e}")
return {}
def save_website_analysis(self, user_id: str, analysis_data: Dict[str, Any], db: Session = None) -> bool:
"""Save website analysis for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
# Check if analysis already exists
existing = session_db.query(WebsiteAnalysis).filter(
WebsiteAnalysis.session_id == session.id
).first()
if existing:
# Update existing
existing.website_url = analysis_data.get('website_url', existing.website_url)
existing.writing_style = analysis_data.get('writing_style')
existing.content_characteristics = analysis_data.get('content_characteristics')
existing.target_audience = analysis_data.get('target_audience')
existing.content_type = analysis_data.get('content_type')
existing.recommended_settings = analysis_data.get('recommended_settings')
existing.crawl_result = analysis_data.get('crawl_result')
existing.style_patterns = analysis_data.get('style_patterns')
existing.style_guidelines = analysis_data.get('style_guidelines')
existing.status = analysis_data.get('status', 'completed')
existing.updated_at = datetime.now()
logger.info(f"Updated website analysis for user {user_id}")
else:
# Create new
analysis = WebsiteAnalysis(
session_id=session.id,
website_url=analysis_data.get('website_url', ''),
writing_style=analysis_data.get('writing_style'),
content_characteristics=analysis_data.get('content_characteristics'),
target_audience=analysis_data.get('target_audience'),
content_type=analysis_data.get('content_type'),
recommended_settings=analysis_data.get('recommended_settings'),
crawl_result=analysis_data.get('crawl_result'),
style_patterns=analysis_data.get('style_patterns'),
style_guidelines=analysis_data.get('style_guidelines'),
status=analysis_data.get('status', 'completed')
)
session_db.add(analysis)
logger.info(f"Created website analysis for user {user_id}")
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving website analysis: {e}")
session_db.rollback()
return False
def get_website_analysis(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
"""Get website analysis for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return None
analysis = session_db.query(WebsiteAnalysis).filter(
WebsiteAnalysis.session_id == session.id
).first()
return analysis.to_dict() if analysis else None
except SQLAlchemyError as e:
logger.error(f"Error getting website analysis: {e}")
return None
def save_research_preferences(self, user_id: str, preferences: Dict[str, Any], db: Session = None) -> bool:
"""Save research preferences for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
# Check if preferences already exist
existing = session_db.query(ResearchPreferences).filter(
ResearchPreferences.session_id == session.id
).first()
if existing:
# Update existing
existing.research_depth = preferences.get('research_depth', existing.research_depth)
existing.content_types = preferences.get('content_types', existing.content_types)
existing.auto_research = preferences.get('auto_research', existing.auto_research)
existing.factual_content = preferences.get('factual_content', existing.factual_content)
existing.writing_style = preferences.get('writing_style')
existing.content_characteristics = preferences.get('content_characteristics')
existing.target_audience = preferences.get('target_audience')
existing.recommended_settings = preferences.get('recommended_settings')
existing.updated_at = datetime.now()
logger.info(f"Updated research preferences for user {user_id}")
else:
# Create new
prefs = ResearchPreferences(
session_id=session.id,
research_depth=preferences.get('research_depth', 'standard'),
content_types=preferences.get('content_types', []),
auto_research=preferences.get('auto_research', True),
factual_content=preferences.get('factual_content', True),
writing_style=preferences.get('writing_style'),
content_characteristics=preferences.get('content_characteristics'),
target_audience=preferences.get('target_audience'),
recommended_settings=preferences.get('recommended_settings')
)
session_db.add(prefs)
logger.info(f"Created research preferences for user {user_id}")
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving research preferences: {e}")
session_db.rollback()
return False
def save_persona_data(self, user_id: str, persona_data: Dict[str, Any], db: Session = None) -> bool:
"""Save persona data for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
# Check if persona data already exists for this user
existing = session_db.query(PersonaData).filter(
PersonaData.session_id == session.id
).first()
if existing:
# Update existing persona data
existing.core_persona = persona_data.get('corePersona')
existing.platform_personas = persona_data.get('platformPersonas')
existing.quality_metrics = persona_data.get('qualityMetrics')
existing.selected_platforms = persona_data.get('selectedPlatforms', [])
existing.updated_at = datetime.utcnow()
logger.info(f"Updated persona data for user {user_id}")
else:
# Create new persona data record
persona = PersonaData(
session_id=session.id,
core_persona=persona_data.get('corePersona'),
platform_personas=persona_data.get('platformPersonas'),
quality_metrics=persona_data.get('qualityMetrics'),
selected_platforms=persona_data.get('selectedPlatforms', [])
)
session_db.add(persona)
logger.info(f"Created persona data for user {user_id}")
session_db.commit()
return True
except SQLAlchemyError as e:
logger.error(f"Error saving persona data: {e}")
session_db.rollback()
return False
def get_research_preferences(self, user_id: str, db: Session = None) -> Optional[Dict[str, Any]]:
"""Get research preferences for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
return None
prefs = session_db.query(ResearchPreferences).filter(
ResearchPreferences.session_id == session.id
).first()
return prefs.to_dict() if prefs else None
except SQLAlchemyError as e:
logger.error(f"Error getting research preferences: {e}")
return None
def mark_onboarding_complete(self, user_id: str, db: Session = None) -> bool:
"""Mark onboarding as complete for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_or_create_session(user_id, session_db)
session.current_step = 6 # Final step
session.progress = 100.0
session.updated_at = datetime.now()
session_db.commit()
logger.info(f"Marked onboarding complete for user {user_id}")
return True
except SQLAlchemyError as e:
logger.error(f"Error marking onboarding complete: {e}")
session_db.rollback()
return False
def get_onboarding_status(self, user_id: str, db: Session = None) -> Dict[str, Any]:
"""Get comprehensive onboarding status for user."""
session_db = db or self.db
if not session_db:
raise ValueError("Database session required")
try:
session = self.get_session_by_user(user_id, session_db)
if not session:
# User hasn't started onboarding yet
return {
"is_completed": False,
"current_step": 1,
"progress": 0.0,
"started_at": None,
"updated_at": None
}
return {
"is_completed": session.current_step >= 6 and session.progress >= 100.0,
"current_step": session.current_step,
"progress": session.progress,
"started_at": session.started_at.isoformat() if session.started_at else None,
"updated_at": session.updated_at.isoformat() if session.updated_at else None
}
except SQLAlchemyError as e:
logger.error(f"Error getting onboarding status: {e}")
return {
"is_completed": False,
"current_step": 1,
"progress": 0.0,
"started_at": None,
"updated_at": None
}

View File

@@ -0,0 +1,150 @@
"""
User API Key Context Manager
Provides user-specific API keys to backend services.
In development: Uses .env file
In production: Fetches from database per user
"""
import os
from typing import Optional, Dict
from loguru import logger
from contextlib import contextmanager
class UserAPIKeyContext:
"""
Context manager for user-specific API keys.
Usage:
with UserAPIKeyContext(user_id) as api_keys:
gemini_key = api_keys.get('gemini')
exa_key = api_keys.get('exa')
# Use keys for this specific user
"""
def __init__(self, user_id: Optional[str] = None):
"""
Initialize with optional user_id.
Args:
user_id: User ID to fetch keys for. If None, uses .env keys (local mode)
"""
self.user_id = user_id
self.keys: Dict[str, str] = {}
self._is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
def __enter__(self):
"""Load API keys when entering context."""
if self._is_local:
# Local mode: Use .env file
self.keys = self._load_from_env()
logger.debug(f"[LOCAL] Loaded API keys from .env file")
elif self.user_id:
# Production mode: Fetch from database
self.keys = self._load_from_database(self.user_id)
logger.debug(f"[PRODUCTION] Loaded API keys from database for user {self.user_id}")
else:
logger.warning("No user_id provided in production mode - using empty keys")
self.keys = {}
return self.keys
def __exit__(self, exc_type, exc_val, exc_tb):
"""Clean up when exiting context."""
self.keys.clear()
return False # Don't suppress exceptions
def _load_from_env(self) -> Dict[str, str]:
"""Load API keys from environment variables (.env file)."""
return {
'gemini': os.getenv('GEMINI_API_KEY', ''),
'exa': os.getenv('EXA_API_KEY', ''),
'copilotkit': os.getenv('COPILOTKIT_API_KEY', ''),
'openai': os.getenv('OPENAI_API_KEY', ''),
'anthropic': os.getenv('ANTHROPIC_API_KEY', ''),
'tavily': os.getenv('TAVILY_API_KEY', ''),
'serper': os.getenv('SERPER_API_KEY', ''),
'firecrawl': os.getenv('FIRECRAWL_API_KEY', ''),
}
def _load_from_database(self, user_id: str) -> Dict[str, str]:
"""Load API keys from database for specific user."""
try:
from services.onboarding_database_service import OnboardingDatabaseService
from services.database import SessionLocal
db_service = OnboardingDatabaseService()
db = SessionLocal()
try:
keys = db_service.get_api_keys(user_id, db)
logger.info(f"Loaded {len(keys)} API keys from database for user {user_id}")
return keys
finally:
db.close()
except Exception as e:
logger.error(f"Failed to load API keys from database for user {user_id}: {e}")
return {}
@staticmethod
def get_user_key(user_id: Optional[str], provider: str) -> Optional[str]:
"""
Convenience method to get a single API key for a user.
Args:
user_id: User ID (None for development mode)
provider: Provider name (e.g., 'gemini', 'exa')
Returns:
API key string or None
"""
with UserAPIKeyContext(user_id) as keys:
return keys.get(provider)
@contextmanager
def user_api_keys(user_id: Optional[str] = None):
"""
Context manager function for easier usage.
Usage:
from services.user_api_key_context import user_api_keys
with user_api_keys(user_id) as keys:
gemini_key = keys.get('gemini')
"""
context = UserAPIKeyContext(user_id)
try:
yield context.__enter__()
finally:
context.__exit__(None, None, None)
# Convenience function for FastAPI dependency injection
def get_user_api_keys(user_id: str) -> Dict[str, str]:
"""
Get user-specific API keys for use in FastAPI endpoints.
Args:
user_id: User ID from current_user
Returns:
Dictionary of API keys for this user
"""
with UserAPIKeyContext(user_id) as keys:
return keys
def get_gemini_key(user_id: Optional[str] = None) -> Optional[str]:
"""Get Gemini API key for user."""
return UserAPIKeyContext.get_user_key(user_id, 'gemini')
def get_exa_key(user_id: Optional[str] = None) -> Optional[str]:
"""Get Exa API key for user."""
return UserAPIKeyContext.get_user_key(user_id, 'exa')
def get_copilotkit_key(user_id: Optional[str] = None) -> Optional[str]:
"""Get CopilotKit API key for user."""
return UserAPIKeyContext.get_user_key(user_id, 'copilotkit')