ALwrity onboarding final step
This commit is contained in:
@@ -35,14 +35,31 @@ class StepData:
|
||||
class OnboardingProgress:
|
||||
"""Manages onboarding progress with persistence and validation."""
|
||||
|
||||
def __init__(self, progress_file: Optional[str] = None):
|
||||
def __init__(self, progress_file: Optional[str] = None, user_id: Optional[str] = None):
|
||||
self.steps = self._initialize_steps()
|
||||
self.current_step = 1
|
||||
self.started_at = datetime.now().isoformat()
|
||||
self.last_updated = datetime.now().isoformat()
|
||||
self.is_completed = False
|
||||
self.completed_at = None
|
||||
self.progress_file = progress_file or ".onboarding_progress.json"
|
||||
self.user_id = user_id # Add user_id for database isolation
|
||||
|
||||
# Use user-specific file for backward compatibility
|
||||
if user_id:
|
||||
self.progress_file = progress_file or f".onboarding_progress_{user_id}.json"
|
||||
else:
|
||||
self.progress_file = progress_file or ".onboarding_progress.json"
|
||||
|
||||
# Initialize database service for dual persistence
|
||||
try:
|
||||
from services.onboarding_database_service import OnboardingDatabaseService
|
||||
self.db_service = OnboardingDatabaseService()
|
||||
self.use_database = True
|
||||
logger.info(f"Database service initialized for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Database service not available, using file only: {e}")
|
||||
self.db_service = None
|
||||
self.use_database = False
|
||||
|
||||
# Load existing progress if available
|
||||
self.load_progress()
|
||||
@@ -192,8 +209,9 @@ class OnboardingProgress:
|
||||
logger.info("Onboarding completed successfully")
|
||||
|
||||
def save_progress(self):
|
||||
"""Save progress to file."""
|
||||
"""Save progress to both file and database (dual persistence)."""
|
||||
try:
|
||||
# Save to JSON file (backward compatibility)
|
||||
progress_data = {
|
||||
"steps": [{
|
||||
"step_number": step.step_number,
|
||||
@@ -215,6 +233,65 @@ class OnboardingProgress:
|
||||
json.dump(progress_data, f, indent=2)
|
||||
|
||||
logger.debug(f"Progress saved to {self.progress_file}")
|
||||
|
||||
# Also save to database if available and user_id is set
|
||||
if self.use_database and self.db_service and self.user_id:
|
||||
try:
|
||||
from services.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
# Update session progress
|
||||
self.db_service.update_step(self.user_id, self.current_step, db)
|
||||
|
||||
# Calculate progress percentage
|
||||
completed_count = sum(1 for s in self.steps if s.status == StepStatus.COMPLETED)
|
||||
progress_pct = (completed_count / len(self.steps)) * 100
|
||||
self.db_service.update_progress(self.user_id, progress_pct, db)
|
||||
|
||||
# Save step-specific data to appropriate tables
|
||||
for step in self.steps:
|
||||
if step.status == StepStatus.COMPLETED and step.data:
|
||||
if step.step_number == 1: # API Keys
|
||||
api_keys = step.data.get('api_keys', {})
|
||||
for provider, key in api_keys.items():
|
||||
if key:
|
||||
# Save to database (for user isolation in production)
|
||||
self.db_service.save_api_key(self.user_id, provider, key, db)
|
||||
|
||||
# Also save to .env file ONLY in local development
|
||||
# This allows local developers to have keys in .env for convenience
|
||||
# In production, keys are fetched from database per user
|
||||
is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
|
||||
if is_local:
|
||||
try:
|
||||
from services.api_key_manager import APIKeyManager
|
||||
api_key_manager = APIKeyManager()
|
||||
api_key_manager.save_api_key(provider, key)
|
||||
logger.info(f"[LOCAL] API key for {provider} saved to .env file")
|
||||
except Exception as env_error:
|
||||
logger.warning(f"[LOCAL] Failed to save {provider} API key to .env file: {env_error}")
|
||||
else:
|
||||
logger.info(f"[PRODUCTION] API key for {provider} saved to database only (user: {self.user_id})")
|
||||
|
||||
# Log database save confirmation
|
||||
logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
|
||||
elif step.step_number == 2: # Website Analysis
|
||||
self.db_service.save_website_analysis(self.user_id, step.data, db)
|
||||
logger.info(f"✅ DATABASE: Website analysis saved to database for user {self.user_id}")
|
||||
elif step.step_number == 3: # Research Preferences
|
||||
self.db_service.save_research_preferences(self.user_id, step.data, db)
|
||||
logger.info(f"✅ DATABASE: Research preferences saved to database for user {self.user_id}")
|
||||
elif step.step_number == 4: # Persona Generation
|
||||
self.db_service.save_persona_data(self.user_id, step.data, db)
|
||||
logger.info(f"✅ DATABASE: Persona data saved to database for user {self.user_id}")
|
||||
|
||||
logger.info(f"Progress also saved to database for user {self.user_id}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as db_error:
|
||||
logger.warning(f"Failed to save to database, JSON file still saved: {db_error}")
|
||||
# Don't fail if database save fails - JSON is still working
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving progress: {str(e)}")
|
||||
|
||||
@@ -423,8 +500,34 @@ class APIKeyManager:
|
||||
try:
|
||||
if provider in self.api_keys:
|
||||
self.api_keys[provider] = api_key
|
||||
self._save_to_env_file(provider, api_key)
|
||||
logger.info(f"API key saved for {provider}")
|
||||
|
||||
# Save to database if available and user_id is set
|
||||
if hasattr(self, 'use_database') and self.use_database and hasattr(self, 'db_service') and self.db_service and hasattr(self, 'user_id') and self.user_id:
|
||||
try:
|
||||
from services.database import SessionLocal
|
||||
db = SessionLocal()
|
||||
try:
|
||||
self.db_service.save_api_key(self.user_id, provider, api_key, db)
|
||||
logger.info(f"✅ DATABASE: API key for {provider} saved to database for user {self.user_id}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as db_error:
|
||||
logger.warning(f"Failed to save {provider} API key to database: {db_error}")
|
||||
|
||||
# Also save to .env file in local mode
|
||||
is_local = os.getenv('DEPLOY_ENV', 'local') == 'local'
|
||||
if is_local:
|
||||
# Special handling for CopilotKit - save to frontend/.env
|
||||
if provider == 'copilotkit':
|
||||
self._save_to_frontend_env(api_key)
|
||||
logger.info(f"[LOCAL] CopilotKit API key saved to frontend/.env file")
|
||||
else:
|
||||
# Save other keys to backend/.env
|
||||
self._save_to_env_file(provider, api_key)
|
||||
logger.info(f"[LOCAL] API key for {provider} saved to backend/.env file")
|
||||
else:
|
||||
logger.info(f"[PRODUCTION] API key for {provider} saved to memory only (database handles persistence)")
|
||||
|
||||
return True
|
||||
else:
|
||||
logger.error(f"Unknown provider: {provider}")
|
||||
@@ -490,8 +593,50 @@ class APIKeyManager:
|
||||
"total_providers": len(self.api_keys)
|
||||
}
|
||||
|
||||
def _save_to_frontend_env(self, api_key: str):
|
||||
"""Save CopilotKit API key to frontend/.env file."""
|
||||
try:
|
||||
# Get the frontend directory path
|
||||
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
frontend_dir = os.path.join(os.path.dirname(backend_dir), "frontend")
|
||||
env_path = os.path.join(frontend_dir, ".env")
|
||||
|
||||
# Read existing .env file
|
||||
if os.path.exists(env_path):
|
||||
with open(env_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
else:
|
||||
lines = []
|
||||
|
||||
# Update or add REACT_APP_COPILOTKIT_API_KEY
|
||||
key_found = False
|
||||
updated_lines = []
|
||||
env_var = "REACT_APP_COPILOTKIT_API_KEY"
|
||||
|
||||
for line in lines:
|
||||
if line.startswith(f"{env_var}="):
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
key_found = True
|
||||
else:
|
||||
updated_lines.append(line)
|
||||
|
||||
if not key_found:
|
||||
# Ensure the file ends with a newline before adding new key
|
||||
if updated_lines and not updated_lines[-1].endswith('\n'):
|
||||
updated_lines[-1] += '\n'
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
|
||||
# Write back to frontend .env file
|
||||
with open(env_path, 'w', encoding='utf-8') as f:
|
||||
f.writelines(updated_lines)
|
||||
|
||||
logger.debug(f"CopilotKit API key saved to frontend .env file")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving to frontend .env file: {str(e)}")
|
||||
|
||||
def _save_to_env_file(self, provider: str, api_key: str):
|
||||
"""Save API key to .env file."""
|
||||
"""Save API key to backend .env file."""
|
||||
try:
|
||||
env_mapping = {
|
||||
"openai": "OPENAI_API_KEY",
|
||||
@@ -513,11 +658,10 @@ class APIKeyManager:
|
||||
os.environ[env_var] = api_key
|
||||
|
||||
# Update .env file - use backend directory path
|
||||
import os
|
||||
backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
env_path = os.path.join(backend_dir, ".env")
|
||||
if os.path.exists(env_path):
|
||||
with open(env_path, 'r') as f:
|
||||
with open(env_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
lines = f.readlines()
|
||||
else:
|
||||
lines = []
|
||||
@@ -532,13 +676,23 @@ class APIKeyManager:
|
||||
updated_lines.append(line)
|
||||
|
||||
if not key_found:
|
||||
# Ensure the file ends with a newline before adding new key
|
||||
if updated_lines and not updated_lines[-1].endswith('\n'):
|
||||
updated_lines[-1] += '\n'
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
|
||||
with open(env_path, 'w') as f:
|
||||
with open(env_path, 'w', encoding='utf-8') as f:
|
||||
f.writelines(updated_lines)
|
||||
|
||||
# Reload environment variables
|
||||
load_dotenv(override=True)
|
||||
# Reload environment variables into current process
|
||||
load_dotenv(env_path, override=True)
|
||||
|
||||
# Verify the key is now in environment
|
||||
loaded_key = os.environ.get(env_var)
|
||||
if loaded_key == api_key:
|
||||
logger.info(f"✅ {env_var} loaded into environment (available for immediate use)")
|
||||
else:
|
||||
logger.warning(f"⚠️ {env_var} written to .env but not in environment yet")
|
||||
|
||||
logger.debug(f"API key saved to .env file for {provider}")
|
||||
except Exception as e:
|
||||
@@ -555,13 +709,17 @@ def get_onboarding_progress() -> OnboardingProgress:
|
||||
return get_onboarding_progress._instance
|
||||
|
||||
def get_onboarding_progress_for_user(user_id: str) -> OnboardingProgress:
|
||||
"""Get or create a per-user onboarding progress instance persisted to a user-specific file."""
|
||||
"""Get or create a per-user onboarding progress instance with database persistence."""
|
||||
global _user_onboarding_progress_cache
|
||||
safe_user_id = ''.join([c if c.isalnum() or c in ('-', '_') else '_' for c in str(user_id)])
|
||||
if safe_user_id in _user_onboarding_progress_cache:
|
||||
return _user_onboarding_progress_cache[safe_user_id]
|
||||
|
||||
# Create user-specific progress file for backward compatibility
|
||||
progress_file = f".onboarding_progress_{safe_user_id}.json"
|
||||
instance = OnboardingProgress(progress_file=progress_file)
|
||||
|
||||
# Pass user_id to enable database persistence
|
||||
instance = OnboardingProgress(progress_file=progress_file, user_id=user_id)
|
||||
_user_onboarding_progress_cache[safe_user_id] = instance
|
||||
return instance
|
||||
|
||||
|
||||
Reference in New Issue
Block a user