diff --git a/alwrity.py b/alwrity.py index a1194d5c..b5ca17d5 100644 --- a/alwrity.py +++ b/alwrity.py @@ -216,8 +216,7 @@ def main(): setup_ui() setup_environment_paths() sidebar_configuration() - check_api_keys() - check_llm_environs() + check_all_api_keys() setup_tabs() modify_prompts_sidebar() diff --git a/lib/utils/api_key_manager.py b/lib/utils/api_key_manager.py index d15ddbdf..5730cc2a 100644 --- a/lib/utils/api_key_manager.py +++ b/lib/utils/api_key_manager.py @@ -4,10 +4,11 @@ from dotenv import load_dotenv @st.cache_data -def check_api_keys(): +def check_all_api_keys(): """ - Checks if the required API keys are present in the environment variables. + Checks if all required API keys are present in the environment variables. Prompts the user to enter missing keys and saves them in the .env file. + This includes general API keys and the LLM provider key. """ api_keys = { "METAPHOR_API_KEY": "https://dashboard.exa.ai/login", @@ -16,30 +17,11 @@ def check_api_keys(): "STABILITY_API_KEY": "https://platform.stability.ai/", "FIRECRAWL_API_KEY": "https://www.firecrawl.dev/account" } + missing_keys = { key: url for key, url in api_keys.items() if os.getenv(key) is None } - if missing_keys: - st.warning(f"API keys not found: {', '.join(missing_keys)}. Please provide them below. Restart the app after saving the keys.") - with st.form(key='api_keys_form'): - for key, url in missing_keys.items(): - st.text_input(f"{key}: 👉[Get it here]({url})👈", type="password", key=key) - if st.form_submit_button("Save Keys"): - with open(".env", "a") as env_file: - for key in missing_keys: - key_value = st.session_state[key] - env_file.write(f"{key}={key_value}\n") - st.success("API keys saved successfully! Please restart the application.") - st.stop() - return False - return True -@st.cache_data -def check_llm_environs(): - """ - Ensures that the LLM provider and corresponding API key are set. - Prompts the user to select a provider and enter the API key if missing. - """ gpt_provider = os.getenv("GPT_PROVIDER") supported_providers = { 'google': "GEMINI_API_KEY", @@ -57,13 +39,25 @@ def check_llm_environs(): except IOError as e: st.error(f"Failed to write GPT_PROVIDER to .env file: {e}") st.success(f"GPT Provider set to {gpt_provider}") + api_key_var = supported_providers[gpt_provider.lower()] if not os.getenv(api_key_var): - api_key = st.text_input(f"Enter {api_key_var}:") - if api_key: - os.environ[api_key_var] = api_key - with open(".env", "a") as env_file: - env_file.write(f"{api_key_var}={api_key}\n") - st.success(f"{api_key_var} added successfully!") + missing_keys[api_key_var] = '' + + if missing_keys: + st.warning(f"API keys not found: {', '.join(missing_keys)}. Please provide them below. Restart the app after saving the keys.") + with st.form(key='api_keys_form'): + for key, url in missing_keys.items(): + if url: + st.text_input(f"{key}: 👉[Get it here]({url})👈", type="password", key=key) + else: + st.text_input(f"{key}:", type="password", key=key) + if st.form_submit_button("Save Keys"): + with open(".env", "a") as env_file: + for key in missing_keys: + key_value = st.session_state[key] + env_file.write(f"{key}={key_value}\n") + st.success("API keys saved successfully! Please restart the application.") + st.stop() return False return True