Extract useful LLM provider improvements from PRs #423-#429

huggingface_provider.py:
- Add retry logic with _should_retry_hf_error and _is_non_retryable_hf_error
- Update default models from :groq to :cerebras (HF_FALLBACK_MODELS)
- Add fallback_models parameter to huggingface_text_response
- Add get_available_models with updated model list

main_text_generation.py:
- Add GPT_PROVIDER and TEXTGEN_AI_MODELS env var support
- Add preferred_provider and flow_type parameters to llm_text_gen
- Add HF_MODEL_MAPPING for short model name resolution
- Add flow_type logging tag for better observability

sif_agents.py:
- Add LOW_COST_SHARED_REMOTE_MODELS for SIF agents
- Update SharedLLMWrapper to use preferred_hf_models and flow_type

These changes preserve the modular textgen_utils structure while incorporating
the useful routing and retry logic improvements from the pending PRs.
This commit is contained in:
ajaysi
2026-03-22 11:16:48 +05:30
parent 16be2b21f4
commit a26fa84263
3 changed files with 134 additions and 45 deletions

View File

@@ -76,6 +76,7 @@ logger = get_service_logger("huggingface_provider")
from tenacity import (
retry,
retry_if_exception,
stop_after_attempt,
wait_random_exponential,
)
@@ -90,10 +91,10 @@ except ImportError:
logger.warn("OpenAI library not available. Install with: pip install openai")
HF_FALLBACK_MODELS = [
"openai/gpt-oss-120b:groq",
"moonshotai/Kimi-K2-Instruct-0905:groq",
"meta-llama/Llama-3.1-8B-Instruct:groq",
"mistralai/Mistral-7B-Instruct-v0.3:groq",
"openai/gpt-oss-120b:cerebras",
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
]
@@ -102,18 +103,19 @@ def _candidate_model_variants(model: str):
if not model:
return
# Try configured model first (supports provider suffixes like ":groq")
yield model
# Fallback to base repo id when provider suffix is not recognized by the router
if ":" in model:
base_model = model.split(":", 1)[0]
if base_model:
yield base_model
def _fallback_model_sequence(model: str):
sequence = [model] + HF_FALLBACK_MODELS
def _fallback_model_sequence(model: str, fallback_models: list = None):
if fallback_models:
sequence = [model] + fallback_models
else:
sequence = [model]
seen = set()
for preferred_model in sequence:
for candidate in _candidate_model_variants(preferred_model):
@@ -121,6 +123,27 @@ def _fallback_model_sequence(model: str):
seen.add(candidate)
yield candidate
def _is_non_retryable_hf_error(exc: Exception) -> bool:
"""Skip retries for deterministic HF failures (e.g., unknown model ids, billing)."""
msg = str(exc).lower()
status = getattr(exc, "status_code", None)
if isinstance(exc, NotFoundError) or "not found" in msg or "404" in msg:
return True
if status == 402 or "402" in msg or "depleted" in msg or "credits" in msg:
return True
if status == 401 or "unauthorized" in msg or "401" in msg:
return True
if status == 403 or "forbidden" in msg or "403" in msg:
return True
return False
def _should_retry_hf_error(exc: Exception) -> bool:
return not _is_non_retryable_hf_error(exc)
def get_huggingface_api_key() -> str:
"""Get Hugging Face API key with proper error handling."""
api_key = os.getenv('HF_TOKEN')
@@ -137,10 +160,15 @@ def get_huggingface_api_key() -> str:
return api_key
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
@retry(
retry=retry_if_exception(_should_retry_hf_error),
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
)
def huggingface_text_response(
prompt: str,
model: str = "openai/gpt-oss-120b:groq",
model: str = "openai/gpt-oss-120b:cerebras",
fallback_models: list = None,
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
@@ -154,7 +182,8 @@ def huggingface_text_response(
Args:
prompt (str): The input prompt for the AI model
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:groq")
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:cerebras")
fallback_models (list, optional): Explicit fallback models to try
temperature (float): Controls randomness (0.0-1.0)
max_tokens (int): Maximum tokens in response
top_p (float): Nucleus sampling parameter (0.0-1.0)
@@ -166,16 +195,10 @@ def huggingface_text_response(
Raises:
Exception: If API key is missing or API call fails
Best Practices:
- Use appropriate temperature for your use case (0.7 for creative, 0.1-0.3 for factual)
- Set max_tokens based on expected response length
- Use system_prompt to guide model behavior
- Handle errors gracefully in calling functions
Example:
result = huggingface_text_response(
prompt="Write a blog post about AI",
model="openai/gpt-oss-120b:groq",
model="openai/gpt-oss-120b:cerebras",
temperature=0.7,
max_tokens=2048,
system_prompt="You are a professional content writer."
@@ -439,12 +462,11 @@ def huggingface_structured_json_response(
except Exception as e:
logger.error(f"❌ Hugging Face API call failed: {e}")
# If 422 Unprocessable Entity (often due to response_format not supported), retry without it
if "422" in str(e) or "not supported" in str(e).lower() or isinstance(e, NotFoundError):
logger.info("Retrying without response_format...")
response = None
last_error = None
for candidate_model in _fallback_model_sequence(model):
for candidate_model in _fallback_model_sequence(model, fallback_models):
try:
response = client.chat.completions.create(
model=candidate_model,
@@ -463,14 +485,12 @@ def huggingface_structured_json_response(
if response is None:
raise last_error or e
response_text = response.choices[0].message.content
# ... (same parsing logic would apply, simplified here for brevity)
try:
return json.loads(response_text)
except:
# Regex fallback
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
return json.loads(json_match.group())
return json.loads(json_match.group())
return {"error": "Failed to parse JSON response", "raw_response": response_text}
raise e
@@ -491,12 +511,12 @@ def get_available_models() -> list:
list: List of available model identifiers
"""
return [
"openai/gpt-oss-120b:groq",
"moonshotai/Kimi-K2-Instruct-0905:groq",
"openai/gpt-oss-120b:cerebras",
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
"Qwen/Qwen2.5-VL-7B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct:groq",
"microsoft/Phi-3-medium-4k-instruct:groq",
"mistralai/Mistral-7B-Instruct-v0.3:groq"
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
"microsoft/Phi-3-medium-4k-instruct:cerebras",
"mistralai/Mistral-7B-Instruct-v0.3:cerebras"
]
def validate_model(model: str) -> bool: