Gemini AI common code and utils
This commit is contained in:
@@ -19,6 +19,17 @@ def render_ai_providers(api_key_manager: APIKeyManager) -> Dict[str, Any]:
|
||||
"""Render the AI providers setup step."""
|
||||
logger.info("[render_ai_providers] Starting AI providers setup")
|
||||
try:
|
||||
# Initialize wizard state if not already initialized
|
||||
if 'wizard_state' not in st.session_state:
|
||||
st.session_state.wizard_state = {
|
||||
'current_step': 1,
|
||||
'total_steps': 6,
|
||||
'progress': 0,
|
||||
'completed_steps': set(),
|
||||
'last_updated': datetime.now()
|
||||
}
|
||||
logger.info("[render_ai_providers] Initialized wizard state")
|
||||
|
||||
# Store API key manager in session state for update_progress
|
||||
st.session_state['api_key_manager'] = api_key_manager
|
||||
|
||||
@@ -209,6 +220,15 @@ def render_ai_providers(api_key_manager: APIKeyManager) -> Dict[str, Any]:
|
||||
'google': google_key if validate_api_key(google_key) else None
|
||||
}
|
||||
|
||||
# Save API keys to .env file
|
||||
if validate_api_key(openai_key):
|
||||
api_key_manager.save_api_key("openai", openai_key)
|
||||
logger.info("[render_ai_providers] OpenAI API key saved to .env file")
|
||||
|
||||
if validate_api_key(google_key):
|
||||
api_key_manager.save_api_key("gemini", google_key)
|
||||
logger.info("[render_ai_providers] Google Gemini API key saved to .env file")
|
||||
|
||||
# Update progress and move to next step
|
||||
st.session_state['current_step'] = 2 # Set the next step explicitly
|
||||
update_progress()
|
||||
|
||||
@@ -91,18 +91,15 @@ def render_final_setup(api_key_manager: APIKeyManager) -> Dict[str, Any]:
|
||||
logger.info("[render_final_setup] User clicked complete setup")
|
||||
try:
|
||||
# Verify all required API keys are present and valid
|
||||
is_valid, missing_keys, impact_messages = check_all_api_keys(api_key_manager)
|
||||
is_valid = check_all_api_keys(api_key_manager)
|
||||
|
||||
if not is_valid:
|
||||
st.error("⚠️ Some required API keys are missing")
|
||||
st.markdown("### Missing API Keys and Impact")
|
||||
|
||||
# Display impact messages in a structured way
|
||||
for message in impact_messages:
|
||||
if message.startswith("⚠️"):
|
||||
st.error(message)
|
||||
else:
|
||||
st.warning(message)
|
||||
# Display impact messages
|
||||
st.warning("⚠️ Missing AI Provider: At least one AI provider (OpenAI, Google Gemini, Anthropic Claude, or Mistral) is required.")
|
||||
st.warning("⚠️ Missing Research Provider: At least one research provider (SerpAPI, Tavily, Metaphor, or Firecrawl) is required.")
|
||||
|
||||
st.markdown("""
|
||||
<div style='background-color: #fff3cd; color: #856404; padding: 1rem; border-radius: 0.25rem; margin-top: 1rem;'>
|
||||
|
||||
@@ -133,16 +133,74 @@ class APIKeyManager:
|
||||
except Exception as e:
|
||||
logger.error(f"[APIKeyManager.load_api_keys] Error loading API keys: {str(e)}")
|
||||
|
||||
def save_api_key(self, provider: str, key: str):
|
||||
"""Save an API key."""
|
||||
logger.info(f"[APIKeyManager.save_api_key] Saving API key for provider: {provider}")
|
||||
def save_api_key(self, provider: str, api_key: str) -> bool:
|
||||
"""
|
||||
Save an API key for a provider.
|
||||
|
||||
Args:
|
||||
provider: The provider name (e.g., 'openai', 'gemini')
|
||||
api_key: The API key value
|
||||
|
||||
Returns:
|
||||
bool: True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
self.api_keys[provider] = key
|
||||
# Save to environment variable
|
||||
os.environ[f"{provider.upper()}_API_KEY"] = key
|
||||
logger.info(f"[APIKeyManager.save_api_key] Successfully saved API key for {provider}")
|
||||
logger.info(f"[APIKeyManager] Saving API key for {provider}")
|
||||
|
||||
# Map provider to environment variable name
|
||||
env_var_map = {
|
||||
'openai': 'OPENAI_API_KEY',
|
||||
'gemini': 'GEMINI_API_KEY',
|
||||
'mistral': 'MISTRAL_API_KEY',
|
||||
'anthropic': 'ANTHROPIC_API_KEY',
|
||||
'serpapi': 'SERPAPI_API_KEY',
|
||||
'tavily': 'TAVILY_API_KEY',
|
||||
'metaphor': 'METAPHOR_API_KEY',
|
||||
'firecrawl': 'FIRECRAWL_API_KEY'
|
||||
}
|
||||
|
||||
env_var = env_var_map.get(provider)
|
||||
if not env_var:
|
||||
logger.error(f"[APIKeyManager] Unknown provider: {provider}")
|
||||
return False
|
||||
|
||||
# Update the in-memory dictionary
|
||||
self.api_keys[provider] = api_key
|
||||
|
||||
# Update environment variable
|
||||
os.environ[env_var] = api_key
|
||||
|
||||
# Read existing .env file content
|
||||
env_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))), '.env')
|
||||
try:
|
||||
with open(env_path, 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
except FileNotFoundError:
|
||||
lines = []
|
||||
|
||||
# Update or add the API key
|
||||
key_found = False
|
||||
updated_lines = []
|
||||
for line in lines:
|
||||
if line.startswith(f"{env_var}="):
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
key_found = True
|
||||
else:
|
||||
updated_lines.append(line)
|
||||
|
||||
if not key_found:
|
||||
updated_lines.append(f"{env_var}={api_key}\n")
|
||||
|
||||
# Write back to .env file
|
||||
with open(env_path, 'w', encoding='utf-8') as f:
|
||||
f.writelines(updated_lines)
|
||||
|
||||
logger.info(f"[APIKeyManager] Successfully saved API key for {provider}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[APIKeyManager.save_api_key] Error saving API key: {str(e)}")
|
||||
logger.error(f"[APIKeyManager] Error saving API key for {provider}: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_api_key(self, provider: str) -> Optional[str]:
|
||||
"""Get an API key."""
|
||||
|
||||
Reference in New Issue
Block a user