Add centralized text routing policy and premium HF defaults

This commit is contained in:
ي
2026-03-12 15:03:22 +05:30
parent b410ece4ca
commit d4528fbc74
3 changed files with 62 additions and 21 deletions

View File

@@ -14,6 +14,11 @@ from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
from .routing_policy import (
PREMIUM_DEFAULT_MODEL,
SIF_LOW_COST_MODEL_DEFAULTS,
resolve_text_provider_alias,
)
def llm_text_gen(
@@ -45,6 +50,7 @@ def llm_text_gen(
# Set default values for LLM parameters
gpt_provider = "google" # Default to Google Gemini
model = "gemini-2.0-flash-001"
hf_low_cost_default_model = SIF_LOW_COST_MODEL_DEFAULTS[0]
temperature = 0.7
max_tokens = 4000
top_p = 0.9
@@ -55,12 +61,13 @@ def llm_text_gen(
# Check for GPT_PROVIDER environment variable
env_provider = os.getenv('GPT_PROVIDER', '').lower()
if env_provider in ['gemini', 'google']:
resolved_env_provider = resolve_text_provider_alias(env_provider)
if resolved_env_provider == "google":
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
elif resolved_env_provider == "huggingface":
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
model = PREMIUM_DEFAULT_MODEL
# Default blog characteristics
blog_tone = "Professional"
@@ -86,7 +93,7 @@ def llm_text_gen(
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
model = PREMIUM_DEFAULT_MODEL
else:
logger.error("[llm_text_gen] No API keys found for supported providers.")
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
@@ -99,13 +106,13 @@ def llm_text_gen(
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
model = PREMIUM_DEFAULT_MODEL
else:
raise RuntimeError("No supported providers available.")
if gpt_provider == "huggingface" and preferred_hf_models:
model = preferred_hf_models[0]
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
if gpt_provider == "huggingface" and preferred_hf_models is not None:
model = preferred_hf_models[0] if preferred_hf_models else hf_low_cost_default_model
logger.info(f"[llm_text_gen] Using SIF low-cost HF model: {model}")
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
@@ -313,7 +320,7 @@ def llm_text_gen(
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
fallback_model = preferred_hf_models[0] if preferred_hf_models else PREMIUM_DEFAULT_MODEL
if fallback_provider == "google":
if json_struct:
@@ -340,7 +347,7 @@ def llm_text_gen(
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
model=fallback_model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions
@@ -348,7 +355,7 @@ def llm_text_gen(
else:
response_text = huggingface_text_response(
prompt=prompt,
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
model=fallback_model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
@@ -389,7 +396,8 @@ def llm_text_gen(
def check_gpt_provider(gpt_provider: str) -> bool:
"""Check if the specified GPT provider is supported."""
supported_providers = ["google", "huggingface"]
return gpt_provider in supported_providers
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
return resolved_provider in supported_providers
def get_api_key(gpt_provider: str) -> Optional[str]:
"""Get API key for the specified provider."""
@@ -397,10 +405,12 @@ def get_api_key(gpt_provider: str) -> Optional[str]:
api_key_manager = APIKeyManager()
provider_mapping = {
"google": "gemini",
"huggingface": "hf_token"
"huggingface": "hf_token",
"wavespeed": "hf_token"
}
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
mapped_provider = provider_mapping.get(resolved_provider, resolved_provider)
return api_key_manager.get_api_key(mapped_provider)
except Exception as e:
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")