Extract useful LLM provider improvements from PRs #423-#429

huggingface_provider.py:
- Add retry logic with _should_retry_hf_error and _is_non_retryable_hf_error
- Update default models from :groq to :cerebras (HF_FALLBACK_MODELS)
- Add fallback_models parameter to huggingface_text_response
- Add get_available_models with updated model list

main_text_generation.py:
- Add GPT_PROVIDER and TEXTGEN_AI_MODELS env var support
- Add preferred_provider and flow_type parameters to llm_text_gen
- Add HF_MODEL_MAPPING for short model name resolution
- Add flow_type logging tag for better observability

sif_agents.py:
- Add LOW_COST_SHARED_REMOTE_MODELS for SIF agents
- Update SharedLLMWrapper to use preferred_hf_models and flow_type

These changes preserve the modular textgen_utils structure while incorporating
the useful routing and retry logic improvements from the pending PRs.
This commit is contained in:
ajaysi
2026-03-22 11:16:48 +05:30
parent 16be2b21f4
commit a26fa84263
3 changed files with 134 additions and 45 deletions

View File

@@ -16,12 +16,33 @@ from .huggingface_provider import huggingface_text_response, huggingface_structu
from .tenant_provider_config import tenant_provider_config_resolver
HF_MODEL_MAPPING = {
"gpt-oss": "openai/gpt-oss-120b:cerebras",
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
"gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"llama": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:cerebras",
}
HF_FALLBACK_MODELS = [
"openai/gpt-oss-120b:cerebras",
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
]
def llm_text_gen(
prompt: str,
system_prompt: Optional[str] = None,
json_struct: Optional[Dict[str, Any]] = None,
user_id: str = None,
preferred_hf_models: Optional[List[str]] = None,
preferred_provider: Optional[str] = None,
flow_type: Optional[str] = None,
) -> str:
"""
Generate text using Language Model (LLM) based on the provided prompt.
@@ -31,6 +52,9 @@ def llm_text_gen(
system_prompt (str, optional): Custom system prompt to use instead of the default one.
json_struct (dict, optional): JSON schema structure for structured responses.
user_id (str): Clerk user ID for subscription checking (required).
preferred_hf_models (list, optional): Preferred HuggingFace models.
preferred_provider (str, optional): Preferred provider (google, huggingface).
flow_type (str, optional): Flow type for logging (e.g., 'sif_agent', 'premium_tool').
Returns:
str: Generated text based on the prompt.
@@ -39,7 +63,10 @@ def llm_text_gen(
RuntimeError: If subscription limits are exceeded or user_id is missing.
"""
try:
logger.info("[llm_text_gen] Starting text generation")
resolved_flow_type = flow_type or ("sif_agent" if preferred_hf_models else "premium_tool")
flow_tag = f"flow_type={resolved_flow_type}"
logger.info(f"[llm_text_gen][{flow_tag}] Starting text generation")
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
# Set default values for LLM parameters
@@ -53,17 +80,50 @@ def llm_text_gen(
frequency_penalty = 0.0
presence_penalty = 0.0
provider_cfg = tenant_provider_config_resolver.resolve(
modality="text",
user_id=user_id,
)
selected_provider = (provider_cfg.selected_providers or [None])[0]
if selected_provider in ["gemini", "google"]:
gpt_provider = "google"
model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001"
elif selected_provider == "huggingface":
gpt_provider = "huggingface"
model = provider_cfg.model_policy.get("default_model") or "mistralai/Mistral-7B-Instruct-v0.3:groq"
# Check for GPT_PROVIDER environment variable
env_provider = os.getenv('GPT_PROVIDER', '').lower()
provider_list = [p.strip() for p in env_provider.split(',') if p.strip()]
# Check for TEXTGEN_AI_MODELS environment variable
textgen_models_env = os.getenv('TEXTGEN_AI_MODELS', '').strip()
model_list = [m.strip() for m in textgen_models_env.split(',') if m.strip()] if textgen_models_env else []
# Determine provider based on env vars or tenant config
if provider_list:
primary_provider = provider_list[0]
if primary_provider in ['gemini', 'google']:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif primary_provider in ['hf_response_api', 'huggingface', 'hf']:
gpt_provider = "huggingface"
model = "openai/gpt-oss-120b:cerebras"
elif preferred_provider:
if preferred_provider in ['gemini', 'google']:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif preferred_provider in ['hf_response_api', 'huggingface', 'hf']:
gpt_provider = "huggingface"
model = "openai/gpt-oss-120b:cerebras"
else:
# Fall back to tenant config
provider_cfg = tenant_provider_config_resolver.resolve(
modality="text",
user_id=user_id,
)
selected_provider = (provider_cfg.selected_providers or [None])[0]
if selected_provider in ["gemini", "google"]:
gpt_provider = "google"
model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001"
elif selected_provider == "huggingface":
gpt_provider = "huggingface"
model = provider_cfg.model_policy.get("default_model") or "openai/gpt-oss-120b:cerebras"
# Map short model names to full paths for HF
if model_list and gpt_provider == "huggingface":
if "/" in model_list[0]:
model = model_list[0]
else:
model = HF_MODEL_MAPPING.get(model_list[0], model_list[0])
# Default blog characteristics
blog_tone = "Professional"
@@ -96,7 +156,7 @@ def llm_text_gen(
if gpt_provider == "huggingface" and preferred_hf_models:
model = preferred_hf_models[0]
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
logger.info(f"[llm_text_gen][{flow_tag}] Using preferred HF model: {model}")
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
@@ -304,7 +364,7 @@ def llm_text_gen(
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
fallback_model = HF_FALLBACK_MODELS[0]
if fallback_provider == "google":
if json_struct: