Extract useful LLM provider improvements from PRs #423-#429
huggingface_provider.py:
- Add retry logic with _should_retry_hf_error and _is_non_retryable_hf_error
- Update default models from :groq to :cerebras (HF_FALLBACK_MODELS)
- Add fallback_models parameter to huggingface_text_response
- Add get_available_models with updated model list

main_text_generation.py:
- Add GPT_PROVIDER and TEXTGEN_AI_MODELS env var support
- Add preferred_provider and flow_type parameters to llm_text_gen
- Add HF_MODEL_MAPPING for short model name resolution
- Add flow_type logging tag for better observability

sif_agents.py:
- Add LOW_COST_SHARED_REMOTE_MODELS for SIF agents
- Update SharedLLMWrapper to use preferred_hf_models and flow_type

These changes preserve the modular textgen_utils structure while incorporating the useful routing and retry logic improvements from the pending PRs.
This commit is contained in:
@@ -32,9 +32,12 @@ class SharedLLMWrapper:
|
||||
def generate(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the shared LLM provider."""
|
||||
try:
|
||||
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
|
||||
# but we could map them if needed.
|
||||
return llm_text_gen(prompt, user_id=self.user_id)
|
||||
return llm_text_gen(
|
||||
prompt,
|
||||
user_id=self.user_id,
|
||||
preferred_hf_models=LOW_COST_SHARED_REMOTE_MODELS,
|
||||
flow_type="sif_agent",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"SharedLLMWrapper failed to generate text: {e}")
|
||||
return f"[ERROR: Shared LLM generation failed for user {self.user_id}]"
|
||||
@@ -44,6 +47,12 @@ class SharedLLMWrapper:
|
||||
|
||||
_local_llm_cache = {}
|
||||
|
||||
LOW_COST_SHARED_REMOTE_MODELS = [
|
||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
]
|
||||
|
||||
LOCAL_LLM_FALLBACKS = [
|
||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
|
||||
@@ -76,6 +76,7 @@ logger = get_service_logger("huggingface_provider")
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
@@ -90,10 +91,10 @@ except ImportError:
|
||||
logger.warn("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
HF_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
]
|
||||
|
||||
|
||||
@@ -102,18 +103,19 @@ def _candidate_model_variants(model: str):
|
||||
if not model:
|
||||
return
|
||||
|
||||
# Try configured model first (supports provider suffixes like ":groq")
|
||||
yield model
|
||||
|
||||
# Fallback to base repo id when provider suffix is not recognized by the router
|
||||
if ":" in model:
|
||||
base_model = model.split(":", 1)[0]
|
||||
if base_model:
|
||||
yield base_model
|
||||
|
||||
|
||||
def _fallback_model_sequence(model: str):
|
||||
sequence = [model] + HF_FALLBACK_MODELS
|
||||
def _fallback_model_sequence(model: str, fallback_models: list = None):
|
||||
if fallback_models:
|
||||
sequence = [model] + fallback_models
|
||||
else:
|
||||
sequence = [model]
|
||||
seen = set()
|
||||
for preferred_model in sequence:
|
||||
for candidate in _candidate_model_variants(preferred_model):
|
||||
@@ -121,6 +123,27 @@ def _fallback_model_sequence(model: str):
|
||||
seen.add(candidate)
|
||||
yield candidate
|
||||
|
||||
|
||||
def _is_non_retryable_hf_error(exc: Exception) -> bool:
    """Return True for deterministic HF failures that a retry cannot fix.

    An error is treated as non-retryable when any of the following holds:

    - it is an instance of ``NotFoundError`` (when that class is available
      in this module);
    - its ``status_code`` attribute is 401, 402, 403 or 404;
    - its lower-cased message contains one of the known markers for
      not-found, billing, or authorization failures.

    Args:
        exc: The exception raised by the Hugging Face call.

    Returns:
        bool: True if the failure should not be retried, False otherwise.
    """
    msg = str(exc).lower()
    status = getattr(exc, "status_code", None)

    # NotFoundError presumably comes from the optional OpenAI client import
    # (verify against the import block); when that import failed the name is
    # undefined, and a bare reference would raise NameError inside the retry
    # predicate. Look it up defensively instead.
    not_found_cls = globals().get("NotFoundError")
    if isinstance(not_found_cls, type) and isinstance(exc, not_found_cls):
        return True

    # 404 is checked alongside 401/402/403 so a not-found response is caught
    # even when its message carries neither "404" nor "not found".
    if status in (401, 402, 403, 404):
        return True

    # Substring markers, matched against the lower-cased message.
    non_retryable_markers = (
        "not found",
        "404",
        "402",
        "depleted",
        "credits",
        "unauthorized",
        "401",
        "forbidden",
        "403",
    )
    return any(marker in msg for marker in non_retryable_markers)
|
||||
|
||||
|
||||
def _should_retry_hf_error(exc: Exception) -> bool:
    """Retry predicate: True unless the error is classified as non-retryable.

    Args:
        exc: The exception raised by the Hugging Face call.

    Returns:
        bool: False when ``_is_non_retryable_hf_error`` flags the error,
        True otherwise.
    """
    if _is_non_retryable_hf_error(exc):
        return False
    return True
|
||||
|
||||
def get_huggingface_api_key() -> str:
|
||||
"""Get Hugging Face API key with proper error handling."""
|
||||
api_key = os.getenv('HF_TOKEN')
|
||||
@@ -137,10 +160,15 @@ def get_huggingface_api_key() -> str:
|
||||
|
||||
return api_key
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_hf_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
)
|
||||
def huggingface_text_response(
|
||||
prompt: str,
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
model: str = "openai/gpt-oss-120b:cerebras",
|
||||
fallback_models: list = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
top_p: float = 0.9,
|
||||
@@ -154,7 +182,8 @@ def huggingface_text_response(
|
||||
|
||||
Args:
|
||||
prompt (str): The input prompt for the AI model
|
||||
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:groq")
|
||||
model (str): Hugging Face model identifier (default: "openai/gpt-oss-120b:cerebras")
|
||||
fallback_models (list, optional): Explicit fallback models to try
|
||||
temperature (float): Controls randomness (0.0-1.0)
|
||||
max_tokens (int): Maximum tokens in response
|
||||
top_p (float): Nucleus sampling parameter (0.0-1.0)
|
||||
@@ -166,16 +195,10 @@ def huggingface_text_response(
|
||||
Raises:
|
||||
Exception: If API key is missing or API call fails
|
||||
|
||||
Best Practices:
|
||||
- Use appropriate temperature for your use case (0.7 for creative, 0.1-0.3 for factual)
|
||||
- Set max_tokens based on expected response length
|
||||
- Use system_prompt to guide model behavior
|
||||
- Handle errors gracefully in calling functions
|
||||
|
||||
Example:
|
||||
result = huggingface_text_response(
|
||||
prompt="Write a blog post about AI",
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
model="openai/gpt-oss-120b:cerebras",
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
system_prompt="You are a professional content writer."
|
||||
@@ -439,12 +462,11 @@ def huggingface_structured_json_response(
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Hugging Face API call failed: {e}")
|
||||
# If 422 Unprocessable Entity (often due to response_format not supported), retry without it
|
||||
if "422" in str(e) or "not supported" in str(e).lower() or isinstance(e, NotFoundError):
|
||||
logger.info("Retrying without response_format...")
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
@@ -463,14 +485,12 @@ def huggingface_structured_json_response(
|
||||
if response is None:
|
||||
raise last_error or e
|
||||
response_text = response.choices[0].message.content
|
||||
# ... (same parsing logic would apply, simplified here for brevity)
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except:
|
||||
# Regex fallback
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
return json.loads(json_match.group())
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
raise e
|
||||
|
||||
@@ -491,12 +511,12 @@ def get_available_models() -> list:
|
||||
list: List of available model identifiers
|
||||
"""
|
||||
return [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"microsoft/Phi-3-medium-4k-instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"microsoft/Phi-3-medium-4k-instruct:cerebras",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:cerebras"
|
||||
]
|
||||
|
||||
def validate_model(model: str) -> bool:
|
||||
|
||||
@@ -16,12 +16,33 @@ from .huggingface_provider import huggingface_text_response, huggingface_structu
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
|
||||
HF_MODEL_MAPPING = {
|
||||
"gpt-oss": "openai/gpt-oss-120b:cerebras",
|
||||
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
|
||||
"gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
|
||||
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"llama": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:cerebras",
|
||||
}
|
||||
|
||||
HF_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
]
|
||||
|
||||
|
||||
def llm_text_gen(
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
json_struct: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None,
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
preferred_provider: Optional[str] = None,
|
||||
flow_type: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
@@ -31,6 +52,9 @@ def llm_text_gen(
|
||||
system_prompt (str, optional): Custom system prompt to use instead of the default one.
|
||||
json_struct (dict, optional): JSON schema structure for structured responses.
|
||||
user_id (str): Clerk user ID for subscription checking (required).
|
||||
preferred_hf_models (list, optional): Preferred HuggingFace models.
|
||||
preferred_provider (str, optional): Preferred provider (google, huggingface).
|
||||
flow_type (str, optional): Flow type for logging (e.g., 'sif_agent', 'premium_tool').
|
||||
|
||||
Returns:
|
||||
str: Generated text based on the prompt.
|
||||
@@ -39,7 +63,10 @@ def llm_text_gen(
|
||||
RuntimeError: If subscription limits are exceeded or user_id is missing.
|
||||
"""
|
||||
try:
|
||||
logger.info("[llm_text_gen] Starting text generation")
|
||||
resolved_flow_type = flow_type or ("sif_agent" if preferred_hf_models else "premium_tool")
|
||||
flow_tag = f"flow_type={resolved_flow_type}"
|
||||
|
||||
logger.info(f"[llm_text_gen][{flow_tag}] Starting text generation")
|
||||
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
|
||||
|
||||
# Set default values for LLM parameters
|
||||
@@ -53,17 +80,50 @@ def llm_text_gen(
|
||||
frequency_penalty = 0.0
|
||||
presence_penalty = 0.0
|
||||
|
||||
provider_cfg = tenant_provider_config_resolver.resolve(
|
||||
modality="text",
|
||||
user_id=user_id,
|
||||
)
|
||||
selected_provider = (provider_cfg.selected_providers or [None])[0]
|
||||
if selected_provider in ["gemini", "google"]:
|
||||
gpt_provider = "google"
|
||||
model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001"
|
||||
elif selected_provider == "huggingface":
|
||||
gpt_provider = "huggingface"
|
||||
model = provider_cfg.model_policy.get("default_model") or "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
# Check for GPT_PROVIDER environment variable
|
||||
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
||||
provider_list = [p.strip() for p in env_provider.split(',') if p.strip()]
|
||||
|
||||
# Check for TEXTGEN_AI_MODELS environment variable
|
||||
textgen_models_env = os.getenv('TEXTGEN_AI_MODELS', '').strip()
|
||||
model_list = [m.strip() for m in textgen_models_env.split(',') if m.strip()] if textgen_models_env else []
|
||||
|
||||
# Determine provider based on env vars or tenant config
|
||||
if provider_list:
|
||||
primary_provider = provider_list[0]
|
||||
if primary_provider in ['gemini', 'google']:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif primary_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
elif preferred_provider:
|
||||
if preferred_provider in ['gemini', 'google']:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif preferred_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
else:
|
||||
# Fall back to tenant config
|
||||
provider_cfg = tenant_provider_config_resolver.resolve(
|
||||
modality="text",
|
||||
user_id=user_id,
|
||||
)
|
||||
selected_provider = (provider_cfg.selected_providers or [None])[0]
|
||||
if selected_provider in ["gemini", "google"]:
|
||||
gpt_provider = "google"
|
||||
model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001"
|
||||
elif selected_provider == "huggingface":
|
||||
gpt_provider = "huggingface"
|
||||
model = provider_cfg.model_policy.get("default_model") or "openai/gpt-oss-120b:cerebras"
|
||||
|
||||
# Map short model names to full paths for HF
|
||||
if model_list and gpt_provider == "huggingface":
|
||||
if "/" in model_list[0]:
|
||||
model = model_list[0]
|
||||
else:
|
||||
model = HF_MODEL_MAPPING.get(model_list[0], model_list[0])
|
||||
|
||||
# Default blog characteristics
|
||||
blog_tone = "Professional"
|
||||
@@ -96,7 +156,7 @@ def llm_text_gen(
|
||||
|
||||
if gpt_provider == "huggingface" and preferred_hf_models:
|
||||
model = preferred_hf_models[0]
|
||||
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
|
||||
logger.info(f"[llm_text_gen][{flow_tag}] Using preferred HF model: {model}")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
|
||||
@@ -304,7 +364,7 @@ def llm_text_gen(
|
||||
elif fallback_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
fallback_model = HF_FALLBACK_MODELS[0]
|
||||
|
||||
if fallback_provider == "google":
|
||||
if json_struct:
|
||||
|
||||
Reference in New Issue
Block a user