Files
ALwrity/backend/services/api_key_manager.py

572 lines
24 KiB
Python

"""Enhanced API Key Manager service for ALwrity backend."""
# This file contains the core business logic moved from lib/utils/api_key_manager/
# It includes the OnboardingProgress class and related functionality
import os
import json
from datetime import datetime
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, asdict
from enum import Enum
from loguru import logger
from dotenv import load_dotenv
class StepStatus(Enum):
    """Lifecycle state of a single onboarding step.

    The string values are what gets written to the progress file
    (``step.status.value``) and what is passed back to ``StepStatus(...)``
    when the file is loaded, so they must stay stable.
    """

    # Not started yet.
    PENDING = "pending"
    # Currently being worked on.
    IN_PROGRESS = "in_progress"
    # Finished by the user.
    COMPLETED = "completed"
    # Deliberately passed over; treated like COMPLETED for progression checks.
    SKIPPED = "skipped"
@dataclass
class StepData:
    """State of one onboarding step.

    ``validation_errors`` may be omitted (or passed as ``None``); it is
    normalised to a fresh empty list per instance in ``__post_init__`` so
    instances never share a mutable default.
    """

    # 1-based position of the step in the onboarding flow.
    step_number: int
    # Short human-readable name of the step.
    title: str
    # One-line explanation of what the step configures.
    description: str
    # Current lifecycle state of the step.
    status: StepStatus
    # ISO-8601 timestamp string set when the step is completed or skipped.
    completed_at: Optional[str] = None
    # Arbitrary payload supplied when the step is marked completed.
    data: Optional[Dict[str, Any]] = None
    # Annotated Optional because the declared default really is None; after
    # __post_init__ it is always a list.
    validation_errors: Optional[List[str]] = None

    def __post_init__(self) -> None:
        """Replace a missing ``validation_errors`` with a new empty list."""
        if self.validation_errors is None:
            self.validation_errors = []
class OnboardingProgress:
    """Manages onboarding progress with persistence and validation.

    Progress is loaded from ``progress_file`` when the object is created and
    written back after every state change. Persistence is best-effort: errors
    while reading or writing the file are logged and never raised.
    """

    # A step in either of these states no longer blocks progression.
    _FINISHED_STATUSES = (StepStatus.COMPLETED, StepStatus.SKIPPED)
    # Step numbers that can_complete_onboarding() checks.
    _REQUIRED_STEPS = (1, 2, 3, 6)

    def __init__(self, progress_file: Optional[str] = None):
        """Create the default step list, then overlay any saved progress.

        Args:
            progress_file: Path of the JSON file used for persistence.
                Defaults to ``.onboarding_progress.json`` (relative to the
                current working directory).
        """
        now = datetime.now().isoformat()
        self.steps = self._initialize_steps()
        self.current_step = 1
        self.started_at = now
        self.last_updated = now
        self.is_completed = False
        self.completed_at = None
        self.progress_file = progress_file or ".onboarding_progress.json"
        # Load existing progress if available
        self.load_progress()

    def _initialize_steps(self) -> List[StepData]:
        """Initialize the 6-step onboarding process."""
        return [
            StepData(1, "AI LLM Providers", "Configure AI language model providers", StepStatus.PENDING),
            StepData(2, "Website Analysis", "Set up website analysis and crawling", StepStatus.PENDING),
            StepData(3, "AI Research", "Configure AI research capabilities", StepStatus.PENDING),
            StepData(4, "Personalization", "Set up personalization features", StepStatus.PENDING),
            StepData(5, "Integrations", "Configure ALwrity integrations", StepStatus.PENDING),
            StepData(6, "Complete Setup", "Finalize and complete onboarding", StepStatus.PENDING),
        ]

    def _is_finished(self, step: StepData) -> bool:
        """Return True when the step is completed or skipped."""
        return step.status in self._FINISHED_STATUSES

    def _all_steps_finished(self) -> bool:
        """Return True when every step is completed or skipped."""
        return all(self._is_finished(step) for step in self.steps)

    def _advance_after_step(self, step_number: int, caller: str, outcome: str) -> None:
        """Update completion flags and ``current_step`` after a step finishes, then save.

        Shared by mark_step_completed() and mark_step_skipped(), which differ
        only in the log prefix (``caller``) and the verb (``outcome``).
        """
        total_steps = len(self.steps)
        if self._all_steps_finished():
            # Everything is done: flag onboarding complete and park on the last step.
            self.is_completed = True
            self.completed_at = datetime.now().isoformat()
            self.current_step = total_steps
            logger.info(f"[{caller}] All steps completed, marking onboarding as complete")
        else:
            # Move to the following step, never past the last one.
            self.current_step = min(step_number + 1, total_steps)
        logger.info(f"[{caller}] Step {step_number} {outcome}, new current_step: {self.current_step}, is_completed: {self.is_completed}")
        self.save_progress()

    def get_step_data(self, step_number: int) -> Optional[StepData]:
        """Get data for a specific step, or None if no step has that number."""
        return next((step for step in self.steps if step.step_number == step_number), None)

    def mark_step_completed(self, step_number: int, data: Optional[Dict[str, Any]] = None):
        """Mark a step as completed, store its payload and persist.

        Args:
            step_number: Number of the step to complete.
            data: Optional payload stored on the step (replaces any previous value).

        An unknown ``step_number`` is logged as an error and otherwise ignored.
        """
        logger.info(f"[mark_step_completed] Marking step {step_number} as completed")
        step = self.get_step_data(step_number)
        if step is None:
            logger.error(f"[mark_step_completed] Step {step_number} not found")
            return
        step.status = StepStatus.COMPLETED
        step.completed_at = datetime.now().isoformat()
        step.data = data
        self.last_updated = datetime.now().isoformat()
        self._advance_after_step(step_number, "mark_step_completed", "completed")
        logger.info(f"Step {step_number} marked as completed")

    def mark_step_in_progress(self, step_number: int):
        """Mark a step as in progress, make it the current step and persist.

        An unknown ``step_number`` is logged as an error and otherwise ignored.
        """
        step = self.get_step_data(step_number)
        if step is None:
            logger.error(f"[mark_step_in_progress] Step {step_number} not found")
            return
        step.status = StepStatus.IN_PROGRESS
        self.current_step = step_number
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info(f"Step {step_number} marked as in progress")

    def mark_step_skipped(self, step_number: int):
        """Mark a step as skipped and persist.

        A skipped step counts as finished for progression and completion.
        An unknown ``step_number`` is logged as an error and otherwise ignored.
        """
        step = self.get_step_data(step_number)
        if step is None:
            logger.error(f"[mark_step_skipped] Step {step_number} not found")
            return
        step.status = StepStatus.SKIPPED
        step.completed_at = datetime.now().isoformat()
        self.last_updated = datetime.now().isoformat()
        self._advance_after_step(step_number, "mark_step_skipped", "skipped")
        logger.info(f"Step {step_number} marked as skipped")

    def can_proceed_to_step(self, step_number: int) -> bool:
        """Check if user can proceed to a specific step.

        Step 1 is always accessible; any other step requires every
        lower-numbered step to be completed or skipped.
        """
        if step_number == 1:
            return True  # First step is always accessible
        return all(
            self._is_finished(step)
            for step in self.steps
            if step.step_number < step_number
        )

    def can_complete_onboarding(self) -> bool:
        """Check if onboarding can be completed.

        Only steps 1, 2, 3 and 6 are checked; each must be completed or
        skipped. A required step number that does not exist is ignored.
        """
        for step_num in self._REQUIRED_STEPS:
            step = self.get_step_data(step_num)
            if step is not None and not self._is_finished(step):
                return False
        return True

    def get_completion_percentage(self) -> float:
        """Get the share of completed-or-skipped steps as a percentage (0-100)."""
        if not self.steps:
            # Avoid dividing by zero if the step list was emptied externally.
            return 0.0
        finished = sum(1 for step in self.steps if self._is_finished(step))
        return (finished / len(self.steps)) * 100

    def get_next_incomplete_step(self) -> Optional[int]:
        """Get the number of the first step that is neither completed nor skipped."""
        return next(
            (step.step_number for step in self.steps if not self._is_finished(step)),
            None,
        )

    def get_resume_step(self) -> int:
        """Get the step to resume from.

        Returns the first unfinished step, or 1 when every step is finished.
        """
        logger.info("[get_resume_step] Checking resume step...")
        logger.info(f"[get_resume_step] Current step: {self.current_step}")
        logger.info(f"[get_resume_step] Steps status: {[f'{s.step_number}:{s.status.value}' for s in self.steps]}")
        next_step = self.get_next_incomplete_step()
        if next_step is not None:
            logger.info(f"[get_resume_step] Found incomplete step: {next_step}")
            return next_step
        logger.warning("[get_resume_step] No incomplete steps found, defaulting to step 1")
        return 1  # Default to first step

    def complete_onboarding(self):
        """Flag onboarding as completed and persist.

        NOTE(review): the flag is set without consulting
        can_complete_onboarding(). On the next load, _fix_corrupted_state()
        clears it again unless *every* step is completed or skipped, which is
        stricter than the required-steps rule in can_complete_onboarding().
        Confirm which rule is intended.
        """
        self.is_completed = True
        self.completed_at = datetime.now().isoformat()
        self.last_updated = datetime.now().isoformat()
        self.save_progress()
        logger.info("Onboarding completed successfully")

    def save_progress(self):
        """Save progress to ``progress_file`` as JSON; failures are logged, not raised."""
        try:
            progress_data = {
                "steps": [
                    {
                        "step_number": step.step_number,
                        "title": step.title,
                        "description": step.description,
                        "status": step.status.value,  # Convert enum to string
                        "completed_at": step.completed_at,
                        "data": step.data,
                        "validation_errors": step.validation_errors,
                    }
                    for step in self.steps
                ],
                "current_step": self.current_step,
                "started_at": self.started_at,
                "last_updated": self.last_updated,
                "is_completed": self.is_completed,
                "completed_at": self.completed_at,
            }
            # json.dump escapes non-ASCII by default, so an explicit encoding
            # does not change the bytes written; it only removes the
            # dependency on the platform locale.
            with open(self.progress_file, 'w', encoding="utf-8") as f:
                json.dump(progress_data, f, indent=2)
            logger.debug(f"Progress saved to {self.progress_file}")
        except Exception as e:
            logger.error(f"Error saving progress: {str(e)}")

    def load_progress(self):
        """Load progress from ``progress_file`` if it exists.

        Saved step entries are matched to the default steps by number;
        unknown step numbers are ignored. Failures are logged, not raised.
        """
        try:
            if not os.path.exists(self.progress_file):
                return
            with open(self.progress_file, 'r', encoding="utf-8") as f:
                progress_data = json.load(f)

            # Restore step data
            for step_data in progress_data.get("steps", []):
                step_num = step_data.get("step_number")
                step = self.get_step_data(step_num) if step_num else None
                if step is None:
                    continue
                raw_status = step_data.get("status", "pending")
                try:
                    step.status = StepStatus(raw_status)
                except ValueError:
                    # One bad value must not abort the load halfway and leave
                    # the remaining steps and top-level fields unrestored.
                    logger.warning(f"[load_progress] Unknown status {raw_status!r} for step {step_num}, treating as pending")
                    step.status = StepStatus.PENDING
                step.completed_at = step_data.get("completed_at")
                step.data = step_data.get("data")
                # A JSON null would otherwise leave None where a list is expected.
                step.validation_errors = step_data.get("validation_errors") or []

            # Restore other data
            self.current_step = progress_data.get("current_step", 1)
            self.started_at = progress_data.get("started_at", self.started_at)
            self.last_updated = progress_data.get("last_updated", self.last_updated)
            self.is_completed = progress_data.get("is_completed", False)
            self.completed_at = progress_data.get("completed_at")

            # Fix any corrupted state
            self._fix_corrupted_state()
            logger.info("Progress loaded from file")
        except Exception as e:
            logger.error(f"Error loading progress: {str(e)}")

    def _fix_corrupted_state(self):
        """Make ``is_completed`` / ``current_step`` agree with the step statuses."""
        if self._all_steps_finished():
            # If all steps are completed, ensure is_completed is True and current_step is valid
            if not self.is_completed:
                logger.info("[_fix_corrupted_state] All steps completed but is_completed was False, fixing...")
                self.is_completed = True
                self.completed_at = datetime.now().isoformat()
            # Ensure current_step doesn't exceed total steps
            if self.current_step > len(self.steps):
                logger.info(f"[_fix_corrupted_state] Current step {self.current_step} exceeds total steps {len(self.steps)}, fixing...")
                self.current_step = len(self.steps)
            self.save_progress()
        elif self.is_completed:
            # If not all steps are completed, ensure is_completed is False
            logger.info("[_fix_corrupted_state] Not all steps completed but is_completed was True, fixing...")
            self.is_completed = False
            self.completed_at = None
            self.save_progress()

    def reset_progress(self):
        """Reset all progress to a fresh state and persist it."""
        now = datetime.now().isoformat()
        self.steps = self._initialize_steps()
        self.current_step = 1
        self.started_at = now
        self.last_updated = now
        self.is_completed = False
        self.completed_at = None
        self.save_progress()
        logger.info("Progress reset successfully")
class APIKeyManager:
    """Enhanced manager for handling API keys with setup instructions.

    Keys are held in memory in ``self.api_keys`` (provider name -> key or
    None), read from environment variables / the backend ``.env`` file on
    construction, and written back to that ``.env`` file when saved.
    """

    # Single source of truth for provider name -> environment variable name.
    # Used both when loading keys and when persisting them.
    _PROVIDER_ENV_VARS = {
        "openai": "OPENAI_API_KEY",
        "gemini": "GEMINI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "mistral": "MISTRAL_API_KEY",
        "tavily": "TAVILY_API_KEY",
        "serper": "SERPER_API_KEY",
        "metaphor": "METAPHOR_API_KEY",  # legacy mapping for Exa, kept for backward compatibility
        "exa": "EXA_API_KEY",
        "firecrawl": "FIRECRAWL_API_KEY",
        "stability": "STABILITY_API_KEY",
        "copilotkit": "COPILOTKIT_API_KEY",
    }

    def __init__(self):
        """Initialise every known provider to None, then load keys from the environment."""
        self.api_keys = dict.fromkeys(self._PROVIDER_ENV_VARS)
        self.load_api_keys()
        # Enhanced provider setup instructions
        self.api_key_groups = {
            "Create": {
                "GEMINI_API_KEY": {
                    "url": "https://makersuite.google.com/app/apikey",
                    "description": "Google's Gemini AI for content generation",
                    "setup_steps": [
                        "Visit Google AI Studio",
                        "Create a Google Cloud account",
                        "Enable Gemini API",
                        "Generate API key"
                    ]
                },
                "OPENAI_API_KEY": {
                    "url": "https://platform.openai.com/api-keys",
                    "description": "OpenAI's GPT models for content creation",
                    "setup_steps": [
                        "Go to OpenAI platform",
                        "Create an account",
                        "Navigate to API keys",
                        "Create new API key"
                    ]
                },
                "MISTRAL_API_KEY": {
                    "url": "https://console.mistral.ai/api-keys/",
                    "description": "Mistral AI for efficient content generation",
                    "setup_steps": [
                        "Visit Mistral AI website",
                        "Sign up for an account",
                        "Access API section",
                        "Generate API key"
                    ]
                },
                "ANTHROPIC_API_KEY": {
                    "url": "https://console.anthropic.com/",
                    "description": "Anthropic's Claude models for content creation",
                    "setup_steps": [
                        "Visit Anthropic console",
                        "Create an account",
                        "Navigate to API keys",
                        "Generate API key"
                    ]
                }
            },
            "Research": {
                "TAVILY_API_KEY": {
                    "url": "https://tavily.com/#api",
                    "description": "Powers intelligent web research features",
                    "setup_steps": [
                        "Go to Tavily's website",
                        "Create an account",
                        "Access your API dashboard",
                        "Generate a new API key"
                    ]
                },
                "SERPER_API_KEY": {
                    "url": "https://serper.dev/signup",
                    "description": "Enables Google search functionality",
                    "setup_steps": [
                        "Visit Serper.dev",
                        "Sign up for an account",
                        "Go to API section",
                        "Create your API key"
                    ]
                }
            },
            "Deep Search": {
                "EXA_API_KEY": {
                    "url": "https://dashboard.exa.ai/login",
                    "description": "Exa (formerly Metaphor) for advanced web search",
                    "setup_steps": [
                        "Visit the Exa AI dashboard",
                        "Sign up for a free account",
                        "Navigate to API Keys section",
                        "Create a new API key"
                    ]
                },
                "FIRECRAWL_API_KEY": {
                    "url": "https://www.firecrawl.dev/account",
                    "description": "Enables web content extraction",
                    "setup_steps": [
                        "Visit Firecrawl website",
                        "Sign up for an account",
                        "Access API dashboard",
                        "Create your API key"
                    ]
                }
            },
            "Integrations": {
                "STABILITY_API_KEY": {
                    "url": "https://platform.stability.ai/",
                    "description": "Enables AI image generation",
                    "setup_steps": [
                        "Access Stability AI platform",
                        "Create an account",
                        "Navigate to API settings",
                        "Generate your API key"
                    ]
                }
            },
            "UI": {
                "COPILOTKIT_API_KEY": {
                    "url": "https://copilotkit.ai",
                    "description": "CopilotKit public API key for in-app assistant",
                    "setup_steps": [
                        "Sign up or log in to CopilotKit",
                        "Navigate to API Keys",
                        "Generate a public API key (ck_pub_...)"
                    ]
                }
            }
        }

    @staticmethod
    def _env_file_path() -> str:
        """Return the path of ``.env`` in the directory one level above this file's directory."""
        backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        return os.path.join(backend_dir, ".env")

    def save_api_key(self, provider: str, api_key: str) -> bool:
        """Save an API key for a provider.

        Args:
            provider: Provider name; must be one of the keys of ``self.api_keys``.
            api_key: The key value to store.

        Returns:
            True when the provider is known (the key is stored in memory and
            a best-effort write to ``.env`` is attempted), False for an
            unknown provider or on an unexpected error.
        """
        try:
            if provider not in self.api_keys:
                logger.error(f"Unknown provider: {provider}")
                return False
            self.api_keys[provider] = api_key
            self._save_to_env_file(provider, api_key)
            logger.info(f"API key saved for {provider}")
            return True
        except Exception as e:
            logger.error(f"Error saving API key: {str(e)}")
            return False

    def get_api_key(self, provider: str) -> Optional[str]:
        """Get API key for a provider (None if unset or the provider is unknown)."""
        return self.api_keys.get(provider)

    def get_all_keys(self) -> Dict[str, str]:
        """Get all configured API keys (every entry whose value is not None)."""
        return {name: key for name, key in self.api_keys.items() if key is not None}

    def load_api_keys(self):
        """Load API keys from environment variables.

        The backend ``.env`` file is (re)loaded first with ``override=True``,
        so values in the file replace variables already in the process
        environment. Only non-empty values overwrite ``self.api_keys``.
        """
        load_dotenv(self._env_file_path(), override=True)
        for provider, env_var in self._PROVIDER_ENV_VARS.items():
            api_key = os.getenv(env_var)
            if api_key:
                self.api_keys[provider] = api_key

    def get_provider_setup_info(self, provider: str) -> Optional[Dict[str, Any]]:
        """Get setup information for a specific provider.

        The provider name is matched against each env-var name in
        ``self.api_key_groups`` lower-cased with the ``_api_key`` / ``_key``
        suffix removed. Returns None when nothing matches.
        """
        for group_name, providers in self.api_key_groups.items():
            for env_var, info in providers.items():
                candidate = env_var.lower().replace('_api_key', '').replace('_key', '')
                if candidate == provider:
                    return {
                        "provider": provider,
                        "group": group_name,
                        "url": info["url"],
                        "description": info["description"],
                        "setup_steps": info["setup_steps"],
                    }
        return None

    def get_all_providers_info(self) -> Dict[str, Any]:
        """Get the setup groups, the providers with a non-empty key, and the provider count."""
        return {
            "groups": self.api_key_groups,
            "configured_providers": [name for name, key in self.api_keys.items() if key],
            "total_providers": len(self.api_keys),
        }

    def _save_to_env_file(self, provider: str, api_key: str):
        """Save API key to the process environment and the backend ``.env`` file.

        Best-effort: any failure is logged and swallowed so that saving a key
        in memory still succeeds.
        """
        try:
            env_var = self._PROVIDER_ENV_VARS.get(provider)
            if not env_var:
                return

            # Update environment variable. This relies on the module-level
            # `os` import: a function-local `import os` placed after this
            # line makes `os` local to the whole function, so this statement
            # raises UnboundLocalError and nothing is ever written.
            os.environ[env_var] = api_key

            # Update .env file - use backend directory path
            env_path = self._env_file_path()
            if os.path.exists(env_path):
                with open(env_path, 'r', encoding="utf-8") as f:
                    lines = f.readlines()
            else:
                lines = []

            assignment = f"{env_var}={api_key}\n"
            prefix = f"{env_var}="
            key_found = False
            updated_lines = []
            for line in lines:
                if line.startswith(prefix):
                    updated_lines.append(assignment)
                    key_found = True
                else:
                    updated_lines.append(line)
            if not key_found:
                # A file whose last line lacks a trailing newline would
                # otherwise get the new assignment glued onto that line.
                if updated_lines and not updated_lines[-1].endswith("\n"):
                    updated_lines[-1] += "\n"
                updated_lines.append(assignment)

            with open(env_path, 'w', encoding="utf-8") as f:
                f.writelines(updated_lines)

            # Reload environment variables from the file that was just
            # written, not from whatever .env the default search would find.
            load_dotenv(env_path, override=True)
            logger.debug(f"API key saved to .env file for {provider}")
        except Exception as e:
            logger.error(f"Error saving to .env file: {str(e)}")
# Module-level state for the onboarding accessors defined below.
# NOTE(review): `_onboarding_progress` is not read or written by any code
# visible in this file; get_onboarding_progress() keeps its singleton as a
# function attribute instead. Presumably retained for external importers —
# confirm before removing.
_onboarding_progress = None
# Per-user OnboardingProgress instances, keyed by the sanitized user id that
# get_onboarding_progress_for_user() computes.
_user_onboarding_progress_cache: Dict[str, OnboardingProgress] = {}
def get_onboarding_progress() -> OnboardingProgress:
    """Return the shared OnboardingProgress, creating it on the first call.

    The instance is cached as the ``_instance`` attribute of this function,
    so later calls return the same object.
    """
    try:
        return get_onboarding_progress._instance
    except AttributeError:
        # Nothing cached yet: build the instance and remember it.
        get_onboarding_progress._instance = OnboardingProgress()
    return get_onboarding_progress._instance
def get_onboarding_progress_for_user(user_id: str) -> OnboardingProgress:
    """Return the OnboardingProgress for one user, creating and caching it if needed.

    The user id is converted to a string and every character that is not
    alphanumeric, ``-`` or ``_`` is replaced by ``_``. The result is both the
    cache key and part of the progress file name
    (``.onboarding_progress_<id>.json``).

    NOTE(review): ids that differ only in replaced characters (e.g. ``a.b``
    and ``a_b``) map to the same key and therefore share one progress file.
    """
    global _user_onboarding_progress_cache

    cache_key = ''.join(
        ch if (ch.isalnum() or ch in ('-', '_')) else '_'
        for ch in str(user_id)
    )

    try:
        return _user_onboarding_progress_cache[cache_key]
    except KeyError:
        pass  # Not cached yet; fall through and build it.

    user_progress = OnboardingProgress(
        progress_file=f".onboarding_progress_{cache_key}.json"
    )
    _user_onboarding_progress_cache[cache_key] = user_progress
    return user_progress
def get_api_key_manager() -> APIKeyManager:
    """Return the shared APIKeyManager, creating it on the first call.

    The instance is cached as the ``_instance`` attribute of this function,
    so later calls return the same object.
    """
    try:
        return get_api_key_manager._instance
    except AttributeError:
        # Nothing cached yet: build the manager and remember it.
        get_api_key_manager._instance = APIKeyManager()
    return get_api_key_manager._instance