refactor: Combine API key checks into a single function

This commit is contained in:
ajaysi (aider)
2024-10-06 14:36:02 +05:30
parent 80e777d568
commit 7b219c8cea
2 changed files with 23 additions and 30 deletions

View File

@@ -216,8 +216,7 @@ def main():
setup_ui()
setup_environment_paths()
sidebar_configuration()
check_api_keys()
check_llm_environs()
check_all_api_keys()
setup_tabs()
modify_prompts_sidebar()

View File

@@ -4,10 +4,11 @@ from dotenv import load_dotenv
@st.cache_data
def check_api_keys():
def check_all_api_keys():
"""
Checks if the required API keys are present in the environment variables.
Checks if all required API keys are present in the environment variables.
Prompts the user to enter missing keys and saves them in the .env file.
This includes general API keys and the LLM provider key.
"""
api_keys = {
"METAPHOR_API_KEY": "https://dashboard.exa.ai/login",
@@ -16,30 +17,11 @@ def check_api_keys():
"STABILITY_API_KEY": "https://platform.stability.ai/",
"FIRECRAWL_API_KEY": "https://www.firecrawl.dev/account"
}
missing_keys = {
key: url for key, url in api_keys.items() if os.getenv(key) is None
}
if missing_keys:
st.warning(f"API keys not found: {', '.join(missing_keys)}. Please provide them below. Restart the app after saving the keys.")
with st.form(key='api_keys_form'):
for key, url in missing_keys.items():
st.text_input(f"{key}: 👉[Get it here]({url})👈", type="password", key=key)
if st.form_submit_button("Save Keys"):
with open(".env", "a") as env_file:
for key in missing_keys:
key_value = st.session_state[key]
env_file.write(f"{key}={key_value}\n")
st.success("API keys saved successfully! Please restart the application.")
st.stop()
return False
return True
@st.cache_data
def check_llm_environs():
"""
Ensures that the LLM provider and corresponding API key are set.
Prompts the user to select a provider and enter the API key if missing.
"""
gpt_provider = os.getenv("GPT_PROVIDER")
supported_providers = {
'google': "GEMINI_API_KEY",
@@ -57,13 +39,25 @@ def check_llm_environs():
except IOError as e:
st.error(f"Failed to write GPT_PROVIDER to .env file: {e}")
st.success(f"GPT Provider set to {gpt_provider}")
api_key_var = supported_providers[gpt_provider.lower()]
if not os.getenv(api_key_var):
api_key = st.text_input(f"Enter {api_key_var}:")
if api_key:
os.environ[api_key_var] = api_key
with open(".env", "a") as env_file:
env_file.write(f"{api_key_var}={api_key}\n")
st.success(f"{api_key_var} added successfully!")
missing_keys[api_key_var] = ''
if missing_keys:
st.warning(f"API keys not found: {', '.join(missing_keys)}. Please provide them below. Restart the app after saving the keys.")
with st.form(key='api_keys_form'):
for key, url in missing_keys.items():
if url:
st.text_input(f"{key}: 👉[Get it here]({url})👈", type="password", key=key)
else:
st.text_input(f"{key}:", type="password", key=key)
if st.form_submit_button("Save Keys"):
with open(".env", "a") as env_file:
for key in missing_keys:
key_value = st.session_state[key]
env_file.write(f"{key}={key_value}\n")
st.success("API keys saved successfully! Please restart the application.")
st.stop()
return False
return True