Refine HF fallback policy controls and SIF low-cost routing

This commit is contained in:
ي
2026-03-12 15:03:47 +05:30
parent b410ece4ca
commit bf191374a5
3 changed files with 85 additions and 20 deletions

View File

@@ -34,7 +34,11 @@ class SharedLLMWrapper:
try:
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
# but we could map them if needed.
return llm_text_gen(prompt, user_id=self.user_id)
return llm_text_gen(
prompt,
user_id=self.user_id,
preferred_hf_models=REMOTE_LOW_COST_HF_MODELS,
)
except Exception as e:
logger.error(f"SharedLLMWrapper failed to generate text: {e}")
return f"[ERROR: Shared LLM generation failed for user {self.user_id}]"
@@ -44,6 +48,13 @@ class SharedLLMWrapper:
_local_llm_cache = {}
# Ordered Hugging Face model ids that the SharedLLMWrapper generation call
# passes to llm_text_gen as `preferred_hf_models`.
# NOTE(review): per the llm_text_gen hunk in this same commit, the first entry
# becomes the requested model and the remaining entries become its fallback
# chain, so the order here is the priority order — confirm before reordering.
REMOTE_LOW_COST_HF_MODELS = [
"Qwen/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-0.5B-Instruct",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
LOCAL_LLM_FALLBACKS = [
"Qwen/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-0.5B-Instruct",

View File

@@ -51,7 +51,7 @@ import sys
from pathlib import Path
import json
import re
from typing import Optional, Dict, Any
from typing import Optional, Dict, Any, List, Iterable
from dotenv import load_dotenv
@@ -97,7 +97,7 @@ HF_FALLBACK_MODELS = [
]
def _candidate_model_variants(model: str):
def _candidate_model_variants(model: str, allow_model_variant_fallback: bool = True):
"""Yield model ids to try for a single logical model preference."""
if not model:
return
@@ -106,17 +106,31 @@ def _candidate_model_variants(model: str):
yield model
# Fallback to base repo id when provider suffix is not recognized by the router
if ":" in model:
if allow_model_variant_fallback and ":" in model:
base_model = model.split(":", 1)[0]
if base_model:
yield base_model
# NOTE(review): this span is a rendered diff hunk. The pre-change and
# post-change lines of _fallback_model_sequence appear together (two `def`
# headers, two inner `for` headers) and indentation was lost in rendering.
#
# Post-change contract, as far as the lines below show: yield each distinct
# model id to try, in order — the requested `model` first, then the fallback
# chain — expanding every entry through _candidate_model_variants.
#   fallback_models is None -> the chain is HF_FALLBACK_MODELS
#   fallback_models is a list -> the chain is exactly that list ([] = no fallback)
#   allow_model_variant_fallback is forwarded to _candidate_model_variants,
#   where it gates the retry with the base repo id for ids containing ":".
def _fallback_model_sequence(model: str):
sequence = [model] + HF_FALLBACK_MODELS
def _fallback_model_sequence(
model: str,
fallback_models: Optional[List[str]] = None,
allow_model_variant_fallback: bool = True,
):
sequence: Iterable[str]
if fallback_models is None:
# Safe default only when caller doesn't provide explicit policy.
sequence = [model] + HF_FALLBACK_MODELS
else:
# Caller owns fallback policy fully. Empty list means only requested model.
sequence = [model] + list(fallback_models)
# Remember ids already yielded so a model that shows up more than once
# (e.g. the requested model is also listed as a fallback) is tried only once.
seen = set()
for preferred_model in sequence:
for candidate in _candidate_model_variants(preferred_model):
for candidate in _candidate_model_variants(
preferred_model,
allow_model_variant_fallback=allow_model_variant_fallback,
):
if candidate and candidate not in seen:
seen.add(candidate)
yield candidate
@@ -144,7 +158,9 @@ def huggingface_text_response(
temperature: float = 0.7,
max_tokens: int = 2048,
top_p: float = 0.9,
system_prompt: Optional[str] = None
system_prompt: Optional[str] = None,
fallback_models: Optional[List[str]] = None,
allow_model_variant_fallback: bool = True,
) -> str:
"""
Generate text response using Hugging Face Inference Providers API.
@@ -233,7 +249,11 @@ def huggingface_text_response(
response = None
last_error = None
for candidate_model in _fallback_model_sequence(model):
for candidate_model in _fallback_model_sequence(
model=model,
fallback_models=fallback_models,
allow_model_variant_fallback=allow_model_variant_fallback,
):
try:
response = client.chat.completions.create(
model=candidate_model,
@@ -277,7 +297,9 @@ def huggingface_structured_json_response(
model: str = "openai/gpt-oss-120b:groq",
temperature: float = 0.7,
max_tokens: int = 8192,
system_prompt: Optional[str] = None
system_prompt: Optional[str] = None,
fallback_models: Optional[List[str]] = None,
allow_model_variant_fallback: bool = True,
) -> Dict[str, Any]:
"""
Generate structured JSON response using Hugging Face Inference Providers API.
@@ -387,7 +409,11 @@ def huggingface_structured_json_response(
try:
response = None
last_error = None
for candidate_model in _fallback_model_sequence(model):
for candidate_model in _fallback_model_sequence(
model=model,
fallback_models=fallback_models,
allow_model_variant_fallback=allow_model_variant_fallback,
):
try:
response = client.chat.completions.create(
model=candidate_model,
@@ -444,7 +470,11 @@ def huggingface_structured_json_response(
logger.info("Retrying without response_format...")
response = None
last_error = None
for candidate_model in _fallback_model_sequence(model):
for candidate_model in _fallback_model_sequence(
model=model,
fallback_models=fallback_models,
allow_model_variant_fallback=allow_model_variant_fallback,
):
try:
response = client.chat.completions.create(
model=candidate_model,

View File

@@ -15,6 +15,10 @@ from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
# Fallback chain handed to the Hugging Face provider calls as `fallback_models`
# when llm_text_gen receives no `preferred_hf_models` (i.e. it is None), and on
# the Hugging Face fallback calls later in llm_text_gen. Deliberately minimal
# to avoid excessive model hopping.
PREMIUM_HF_MINIMAL_FALLBACK_MODELS = [
"openai/gpt-oss-120b:groq",
]
def llm_text_gen(
prompt: str,
@@ -103,10 +107,22 @@ def llm_text_gen(
else:
raise RuntimeError("No supported providers available.")
if gpt_provider == "huggingface" and preferred_hf_models:
model = preferred_hf_models[0]
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
hf_fallback_models: Optional[List[str]] = None
hf_allow_model_variant_fallback = True
if gpt_provider == "huggingface":
if preferred_hf_models is not None:
if preferred_hf_models:
model = preferred_hf_models[0]
hf_fallback_models = preferred_hf_models[1:]
logger.info(f"[llm_text_gen] Using caller-provided HF policy starting model: {model}")
else:
# Explicit empty policy: only requested model (plus optional variant handling).
hf_fallback_models = []
logger.info("[llm_text_gen] Using caller-provided HF policy with no fallback models")
else:
# Premium/default path: minimal safe fallback chain to avoid excessive model hopping.
hf_fallback_models = PREMIUM_HF_MINIMAL_FALLBACK_MODELS
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
# Map provider name to APIProvider enum (define at function scope for usage tracking)
@@ -251,7 +267,9 @@ def llm_text_gen(
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions
system_prompt=system_instructions,
fallback_models=hf_fallback_models,
allow_model_variant_fallback=hf_allow_model_variant_fallback,
)
else:
response_text = huggingface_text_response(
@@ -260,7 +278,9 @@ def llm_text_gen(
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
system_prompt=system_instructions,
fallback_models=hf_fallback_models,
allow_model_variant_fallback=hf_allow_model_variant_fallback,
)
else:
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
@@ -343,7 +363,9 @@ def llm_text_gen(
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions
system_prompt=system_instructions,
fallback_models=PREMIUM_HF_MINIMAL_FALLBACK_MODELS,
allow_model_variant_fallback=True,
)
else:
response_text = huggingface_text_response(
@@ -352,7 +374,9 @@ def llm_text_gen(
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
system_prompt=system_instructions,
fallback_models=PREMIUM_HF_MINIMAL_FALLBACK_MODELS,
allow_model_variant_fallback=True,
)
# TRACK USAGE after successful fallback call