Refine HF fallback policy controls and SIF low-cost routing

This commit is contained in:
ي
2026-03-12 15:03:47 +05:30
parent b410ece4ca
commit bf191374a5
3 changed files with 85 additions and 20 deletions

View File

@@ -34,7 +34,11 @@ class SharedLLMWrapper:
try: try:
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults, # We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
# but we could map them if needed. # but we could map them if needed.
return llm_text_gen(prompt, user_id=self.user_id) return llm_text_gen(
prompt,
user_id=self.user_id,
preferred_hf_models=REMOTE_LOW_COST_HF_MODELS,
)
except Exception as e: except Exception as e:
logger.error(f"SharedLLMWrapper failed to generate text: {e}") logger.error(f"SharedLLMWrapper failed to generate text: {e}")
return f"[ERROR: Shared LLM generation failed for user {self.user_id}]" return f"[ERROR: Shared LLM generation failed for user {self.user_id}]"
@@ -44,6 +48,13 @@ class SharedLLMWrapper:
_local_llm_cache = {} _local_llm_cache = {}
REMOTE_LOW_COST_HF_MODELS = [
"Qwen/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-0.5B-Instruct",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
]
LOCAL_LLM_FALLBACKS = [ LOCAL_LLM_FALLBACKS = [
"Qwen/Qwen2.5-1.5B-Instruct", "Qwen/Qwen2.5-1.5B-Instruct",
"Qwen/Qwen2.5-0.5B-Instruct", "Qwen/Qwen2.5-0.5B-Instruct",

View File

@@ -51,7 +51,7 @@ import sys
from pathlib import Path from pathlib import Path
import json import json
import re import re
from typing import Optional, Dict, Any from typing import Optional, Dict, Any, List, Iterable
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -97,7 +97,7 @@ HF_FALLBACK_MODELS = [
] ]
def _candidate_model_variants(model: str): def _candidate_model_variants(model: str, allow_model_variant_fallback: bool = True):
"""Yield model ids to try for a single logical model preference.""" """Yield model ids to try for a single logical model preference."""
if not model: if not model:
return return
@@ -106,17 +106,31 @@ def _candidate_model_variants(model: str):
yield model yield model
# Fallback to base repo id when provider suffix is not recognized by the router # Fallback to base repo id when provider suffix is not recognized by the router
if ":" in model: if allow_model_variant_fallback and ":" in model:
base_model = model.split(":", 1)[0] base_model = model.split(":", 1)[0]
if base_model: if base_model:
yield base_model yield base_model
def _fallback_model_sequence(model: str): def _fallback_model_sequence(
sequence = [model] + HF_FALLBACK_MODELS model: str,
fallback_models: Optional[List[str]] = None,
allow_model_variant_fallback: bool = True,
):
sequence: Iterable[str]
if fallback_models is None:
# Safe default only when caller doesn't provide explicit policy.
sequence = [model] + HF_FALLBACK_MODELS
else:
# Caller owns fallback policy fully. Empty list means only requested model.
sequence = [model] + list(fallback_models)
seen = set() seen = set()
for preferred_model in sequence: for preferred_model in sequence:
for candidate in _candidate_model_variants(preferred_model): for candidate in _candidate_model_variants(
preferred_model,
allow_model_variant_fallback=allow_model_variant_fallback,
):
if candidate and candidate not in seen: if candidate and candidate not in seen:
seen.add(candidate) seen.add(candidate)
yield candidate yield candidate
@@ -144,7 +158,9 @@ def huggingface_text_response(
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 2048, max_tokens: int = 2048,
top_p: float = 0.9, top_p: float = 0.9,
system_prompt: Optional[str] = None system_prompt: Optional[str] = None,
fallback_models: Optional[List[str]] = None,
allow_model_variant_fallback: bool = True,
) -> str: ) -> str:
""" """
Generate text response using Hugging Face Inference Providers API. Generate text response using Hugging Face Inference Providers API.
@@ -233,7 +249,11 @@ def huggingface_text_response(
response = None response = None
last_error = None last_error = None
for candidate_model in _fallback_model_sequence(model): for candidate_model in _fallback_model_sequence(
model=model,
fallback_models=fallback_models,
allow_model_variant_fallback=allow_model_variant_fallback,
):
try: try:
response = client.chat.completions.create( response = client.chat.completions.create(
model=candidate_model, model=candidate_model,
@@ -277,7 +297,9 @@ def huggingface_structured_json_response(
model: str = "openai/gpt-oss-120b:groq", model: str = "openai/gpt-oss-120b:groq",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 8192, max_tokens: int = 8192,
system_prompt: Optional[str] = None system_prompt: Optional[str] = None,
fallback_models: Optional[List[str]] = None,
allow_model_variant_fallback: bool = True,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Generate structured JSON response using Hugging Face Inference Providers API. Generate structured JSON response using Hugging Face Inference Providers API.
@@ -387,7 +409,11 @@ def huggingface_structured_json_response(
try: try:
response = None response = None
last_error = None last_error = None
for candidate_model in _fallback_model_sequence(model): for candidate_model in _fallback_model_sequence(
model=model,
fallback_models=fallback_models,
allow_model_variant_fallback=allow_model_variant_fallback,
):
try: try:
response = client.chat.completions.create( response = client.chat.completions.create(
model=candidate_model, model=candidate_model,
@@ -444,7 +470,11 @@ def huggingface_structured_json_response(
logger.info("Retrying without response_format...") logger.info("Retrying without response_format...")
response = None response = None
last_error = None last_error = None
for candidate_model in _fallback_model_sequence(model): for candidate_model in _fallback_model_sequence(
model=model,
fallback_models=fallback_models,
allow_model_variant_fallback=allow_model_variant_fallback,
):
try: try:
response = client.chat.completions.create( response = client.chat.completions.create(
model=candidate_model, model=candidate_model,

View File

@@ -15,6 +15,10 @@ from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
PREMIUM_HF_MINIMAL_FALLBACK_MODELS = [
"openai/gpt-oss-120b:groq",
]
def llm_text_gen( def llm_text_gen(
prompt: str, prompt: str,
@@ -103,9 +107,21 @@ def llm_text_gen(
else: else:
raise RuntimeError("No supported providers available.") raise RuntimeError("No supported providers available.")
if gpt_provider == "huggingface" and preferred_hf_models: hf_fallback_models: Optional[List[str]] = None
model = preferred_hf_models[0] hf_allow_model_variant_fallback = True
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}") if gpt_provider == "huggingface":
if preferred_hf_models is not None:
if preferred_hf_models:
model = preferred_hf_models[0]
hf_fallback_models = preferred_hf_models[1:]
logger.info(f"[llm_text_gen] Using caller-provided HF policy starting model: {model}")
else:
# Explicit empty policy: only requested model (plus optional variant handling).
hf_fallback_models = []
logger.info("[llm_text_gen] Using caller-provided HF policy with no fallback models")
else:
# Premium/default path: minimal safe fallback chain to avoid excessive model hopping.
hf_fallback_models = PREMIUM_HF_MINIMAL_FALLBACK_MODELS
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}") logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
@@ -251,7 +267,9 @@ def llm_text_gen(
model=model, model=model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
fallback_models=hf_fallback_models,
allow_model_variant_fallback=hf_allow_model_variant_fallback,
) )
else: else:
response_text = huggingface_text_response( response_text = huggingface_text_response(
@@ -260,7 +278,9 @@ def llm_text_gen(
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
system_prompt=system_instructions system_prompt=system_instructions,
fallback_models=hf_fallback_models,
allow_model_variant_fallback=hf_allow_model_variant_fallback,
) )
else: else:
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}") logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
@@ -343,7 +363,9 @@ def llm_text_gen(
model="mistralai/Mistral-7B-Instruct-v0.3:groq", model="mistralai/Mistral-7B-Instruct-v0.3:groq",
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
fallback_models=PREMIUM_HF_MINIMAL_FALLBACK_MODELS,
allow_model_variant_fallback=True,
) )
else: else:
response_text = huggingface_text_response( response_text = huggingface_text_response(
@@ -352,7 +374,9 @@ def llm_text_gen(
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
system_prompt=system_instructions system_prompt=system_instructions,
fallback_models=PREMIUM_HF_MINIMAL_FALLBACK_MODELS,
allow_model_variant_fallback=True,
) )
# TRACK USAGE after successful fallback call # TRACK USAGE after successful fallback call