Add centralized text routing policy and premium HF defaults
This commit is contained in:
@@ -70,6 +70,7 @@ else:
|
||||
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .routing_policy import PREMIUM_DEFAULT_MODEL, SIF_LOW_COST_MODEL_DEFAULTS
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("huggingface_provider")
|
||||
@@ -90,10 +91,10 @@ except ImportError:
|
||||
logger.warn("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
HF_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
PREMIUM_DEFAULT_MODEL,
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
SIF_LOW_COST_MODEL_DEFAULTS[0],
|
||||
]
|
||||
|
||||
|
||||
@@ -140,7 +141,7 @@ def get_huggingface_api_key() -> str:
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def huggingface_text_response(
|
||||
prompt: str,
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
model: str = PREMIUM_DEFAULT_MODEL,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
top_p: float = 0.9,
|
||||
@@ -175,7 +176,7 @@ def huggingface_text_response(
|
||||
Example:
|
||||
result = huggingface_text_response(
|
||||
prompt="Write a blog post about AI",
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
model=PREMIUM_DEFAULT_MODEL,
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
system_prompt="You are a professional content writer."
|
||||
@@ -274,7 +275,7 @@ def huggingface_text_response(
|
||||
def huggingface_structured_json_response(
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
model: str = PREMIUM_DEFAULT_MODEL,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 8192,
|
||||
system_prompt: Optional[str] = None
|
||||
@@ -491,12 +492,12 @@ def get_available_models() -> list:
|
||||
list: List of available model identifiers
|
||||
"""
|
||||
return [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
PREMIUM_DEFAULT_MODEL,
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"microsoft/Phi-3-medium-4k-instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
SIF_LOW_COST_MODEL_DEFAULTS[0]
|
||||
]
|
||||
|
||||
def validate_model(model: str) -> bool:
|
||||
|
||||
@@ -14,6 +14,11 @@ from ..onboarding.api_key_manager import APIKeyManager
|
||||
|
||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
from .routing_policy import (
|
||||
PREMIUM_DEFAULT_MODEL,
|
||||
SIF_LOW_COST_MODEL_DEFAULTS,
|
||||
resolve_text_provider_alias,
|
||||
)
|
||||
|
||||
|
||||
def llm_text_gen(
|
||||
@@ -45,6 +50,7 @@ def llm_text_gen(
|
||||
# Set default values for LLM parameters
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
hf_low_cost_default_model = SIF_LOW_COST_MODEL_DEFAULTS[0]
|
||||
temperature = 0.7
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
@@ -55,12 +61,13 @@ def llm_text_gen(
|
||||
|
||||
# Check for GPT_PROVIDER environment variable
|
||||
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
||||
if env_provider in ['gemini', 'google']:
|
||||
resolved_env_provider = resolve_text_provider_alias(env_provider)
|
||||
if resolved_env_provider == "google":
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
elif resolved_env_provider == "huggingface":
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
model = PREMIUM_DEFAULT_MODEL
|
||||
|
||||
# Default blog characteristics
|
||||
blog_tone = "Professional"
|
||||
@@ -86,7 +93,7 @@ def llm_text_gen(
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
model = PREMIUM_DEFAULT_MODEL
|
||||
else:
|
||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
||||
@@ -99,13 +106,13 @@ def llm_text_gen(
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
model = PREMIUM_DEFAULT_MODEL
|
||||
else:
|
||||
raise RuntimeError("No supported providers available.")
|
||||
|
||||
if gpt_provider == "huggingface" and preferred_hf_models:
|
||||
model = preferred_hf_models[0]
|
||||
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
|
||||
if gpt_provider == "huggingface" and preferred_hf_models is not None:
|
||||
model = preferred_hf_models[0] if preferred_hf_models else hf_low_cost_default_model
|
||||
logger.info(f"[llm_text_gen] Using SIF low-cost HF model: {model}")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
|
||||
@@ -313,7 +320,7 @@ def llm_text_gen(
|
||||
elif fallback_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
fallback_model = preferred_hf_models[0] if preferred_hf_models else PREMIUM_DEFAULT_MODEL
|
||||
|
||||
if fallback_provider == "google":
|
||||
if json_struct:
|
||||
@@ -340,7 +347,7 @@ def llm_text_gen(
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
model=fallback_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
@@ -348,7 +355,7 @@ def llm_text_gen(
|
||||
else:
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
model=fallback_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
@@ -389,7 +396,8 @@ def llm_text_gen(
|
||||
def check_gpt_provider(gpt_provider: str) -> bool:
|
||||
"""Check if the specified GPT provider is supported."""
|
||||
supported_providers = ["google", "huggingface"]
|
||||
return gpt_provider in supported_providers
|
||||
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
|
||||
return resolved_provider in supported_providers
|
||||
|
||||
def get_api_key(gpt_provider: str) -> Optional[str]:
|
||||
"""Get API key for the specified provider."""
|
||||
@@ -397,10 +405,12 @@ def get_api_key(gpt_provider: str) -> Optional[str]:
|
||||
api_key_manager = APIKeyManager()
|
||||
provider_mapping = {
|
||||
"google": "gemini",
|
||||
"huggingface": "hf_token"
|
||||
"huggingface": "hf_token",
|
||||
"wavespeed": "hf_token"
|
||||
}
|
||||
|
||||
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
|
||||
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
|
||||
mapped_provider = provider_mapping.get(resolved_provider, resolved_provider)
|
||||
return api_key_manager.get_api_key(mapped_provider)
|
||||
except Exception as e:
|
||||
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
|
||||
|
||||
30
backend/services/llm_providers/routing_policy.py
Normal file
30
backend/services/llm_providers/routing_policy.py
Normal file
@@ -0,0 +1,30 @@
|
||||
"""Routing policy for LLM provider aliases and model defaults."""
|
||||
|
||||
from typing import Dict, List
|
||||
|
||||
|
||||
# Premium text generation defaults
|
||||
PREMIUM_DEFAULT_PROVIDER = "huggingface"
|
||||
PREMIUM_DEFAULT_MODEL = "openai/gpt-oss-120b:groq"
|
||||
|
||||
# SIF low-cost defaults for text generation
|
||||
SIF_LOW_COST_MODEL_DEFAULTS: List[str] = [
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
]
|
||||
|
||||
# Canonical provider aliases for text routing
|
||||
PROVIDER_ALIAS_MAPPING: Dict[str, str] = {
|
||||
"gemini": "google",
|
||||
"google": "google",
|
||||
"hf_response_api": "huggingface",
|
||||
"huggingface": "huggingface",
|
||||
"hf": "huggingface",
|
||||
# Text-only alias: route wavespeed GPT_PROVIDER to premium HF text route.
|
||||
"wavespeed": PREMIUM_DEFAULT_PROVIDER,
|
||||
}
|
||||
|
||||
|
||||
def resolve_text_provider_alias(provider: str) -> str:
|
||||
"""Resolve a GPT provider alias into a canonical text provider."""
|
||||
return PROVIDER_ALIAS_MAPPING.get((provider or "").lower(), "")
|
||||
|
||||
Reference in New Issue
Block a user