Merge_PR_417_centralized_text_routing_policy
This commit is contained in:
@@ -54,6 +54,7 @@ from typing import Optional, Dict, Any
|
|||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from utils.logger_utils import get_service_logger
|
from utils.logger_utils import get_service_logger
|
||||||
|
from .routing_policy import PREMIUM_DEFAULT_MODEL, SIF_LOW_COST_MODEL_DEFAULTS
|
||||||
|
|
||||||
# Use service-specific logger to avoid conflicts
|
# Use service-specific logger to avoid conflicts
|
||||||
logger = get_service_logger("huggingface_provider")
|
logger = get_service_logger("huggingface_provider")
|
||||||
@@ -77,10 +78,17 @@ except ImportError:
|
|||||||
|
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
HF_FALLBACK_MODELS = [
|
HF_FALLBACK_MODELS = [
|
||||||
|
<<<<<<< HEAD
|
||||||
"openai/gpt-oss-120b:cerebras",
|
"openai/gpt-oss-120b:cerebras",
|
||||||
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
|
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
|
||||||
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||||
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||||
|
=======
|
||||||
|
PREMIUM_DEFAULT_MODEL,
|
||||||
|
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||||
|
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||||
|
SIF_LOW_COST_MODEL_DEFAULTS[0],
|
||||||
|
>>>>>>> pr-417
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -219,8 +227,12 @@ def _get_hf_client(api_key: str):
|
|||||||
>>>>>>> pr-416
|
>>>>>>> pr-416
|
||||||
def huggingface_text_response(
|
def huggingface_text_response(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
|
<<<<<<< HEAD
|
||||||
model: str = "openai/gpt-oss-120b:cerebras",
|
model: str = "openai/gpt-oss-120b:cerebras",
|
||||||
fallback_models: Optional[List[str]] = None,
|
fallback_models: Optional[List[str]] = None,
|
||||||
|
=======
|
||||||
|
model: str = PREMIUM_DEFAULT_MODEL,
|
||||||
|
>>>>>>> pr-417
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 2048,
|
max_tokens: int = 2048,
|
||||||
top_p: float = 0.9,
|
top_p: float = 0.9,
|
||||||
@@ -256,7 +268,11 @@ def huggingface_text_response(
|
|||||||
Example:
|
Example:
|
||||||
result = huggingface_text_response(
|
result = huggingface_text_response(
|
||||||
prompt="Write a blog post about AI",
|
prompt="Write a blog post about AI",
|
||||||
|
<<<<<<< HEAD
|
||||||
model="openai/gpt-oss-120b:cerebras",
|
model="openai/gpt-oss-120b:cerebras",
|
||||||
|
=======
|
||||||
|
model=PREMIUM_DEFAULT_MODEL,
|
||||||
|
>>>>>>> pr-417
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
max_tokens=2048,
|
max_tokens=2048,
|
||||||
system_prompt="You are a professional content writer."
|
system_prompt="You are a professional content writer."
|
||||||
@@ -369,8 +385,12 @@ def huggingface_text_response(
|
|||||||
def huggingface_structured_json_response(
|
def huggingface_structured_json_response(
|
||||||
prompt: str,
|
prompt: str,
|
||||||
schema: Dict[str, Any],
|
schema: Dict[str, Any],
|
||||||
|
<<<<<<< HEAD
|
||||||
model: str = "openai/gpt-oss-120b:cerebras",
|
model: str = "openai/gpt-oss-120b:cerebras",
|
||||||
fallback_models: Optional[List[str]] = None,
|
fallback_models: Optional[List[str]] = None,
|
||||||
|
=======
|
||||||
|
model: str = PREMIUM_DEFAULT_MODEL,
|
||||||
|
>>>>>>> pr-417
|
||||||
temperature: float = 0.7,
|
temperature: float = 0.7,
|
||||||
max_tokens: int = 8192,
|
max_tokens: int = 8192,
|
||||||
system_prompt: Optional[str] = None,
|
system_prompt: Optional[str] = None,
|
||||||
@@ -616,12 +636,12 @@ def get_available_models() -> list:
|
|||||||
list: List of available model identifiers
|
list: List of available model identifiers
|
||||||
"""
|
"""
|
||||||
return [
|
return [
|
||||||
"openai/gpt-oss-120b:groq",
|
PREMIUM_DEFAULT_MODEL,
|
||||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||||
"Qwen/Qwen2.5-VL-7B-Instruct",
|
"Qwen/Qwen2.5-VL-7B-Instruct",
|
||||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||||
"microsoft/Phi-3-medium-4k-instruct:groq",
|
"microsoft/Phi-3-medium-4k-instruct:groq",
|
||||||
"mistralai/Mistral-7B-Instruct-v0.3:groq"
|
SIF_LOW_COST_MODEL_DEFAULTS[0]
|
||||||
]
|
]
|
||||||
|
|
||||||
def validate_model(model: str) -> bool:
|
def validate_model(model: str) -> bool:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from loguru import logger
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||||
|
<<<<<<< HEAD
|
||||||
from .tenant_provider_config import get_available_text_providers, get_tenant_api_key
|
from .tenant_provider_config import get_available_text_providers, get_tenant_api_key
|
||||||
from .routing_observability import emit_routing_event
|
from .routing_observability import emit_routing_event
|
||||||
|
|
||||||
@@ -128,6 +129,13 @@ def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[st
|
|||||||
|
|
||||||
resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()]
|
resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()]
|
||||||
return resolved or ["openai/gpt-oss-120b:groq"]
|
return resolved or ["openai/gpt-oss-120b:groq"]
|
||||||
|
=======
|
||||||
|
from .routing_policy import (
|
||||||
|
PREMIUM_DEFAULT_MODEL,
|
||||||
|
SIF_LOW_COST_MODEL_DEFAULTS,
|
||||||
|
resolve_text_provider_alias,
|
||||||
|
)
|
||||||
|
>>>>>>> pr-417
|
||||||
|
|
||||||
|
|
||||||
def llm_text_gen(
|
def llm_text_gen(
|
||||||
@@ -173,7 +181,11 @@ def llm_text_gen(
|
|||||||
=======
|
=======
|
||||||
gpt_provider = "google"
|
gpt_provider = "google"
|
||||||
model = "gemini-2.0-flash-001"
|
model = "gemini-2.0-flash-001"
|
||||||
|
<<<<<<< HEAD
|
||||||
>>>>>>> pr-416
|
>>>>>>> pr-416
|
||||||
|
=======
|
||||||
|
hf_low_cost_default_model = SIF_LOW_COST_MODEL_DEFAULTS[0]
|
||||||
|
>>>>>>> pr-417
|
||||||
temperature = 0.7
|
temperature = 0.7
|
||||||
max_tokens = 4000
|
max_tokens = 4000
|
||||||
top_p = 0.9
|
top_p = 0.9
|
||||||
@@ -185,7 +197,17 @@ def llm_text_gen(
|
|||||||
|
|
||||||
# Check for GPT_PROVIDER environment variable
|
# Check for GPT_PROVIDER environment variable
|
||||||
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
||||||
|
<<<<<<< HEAD
|
||||||
provider_list = [p.strip() for p in env_provider.split(',') if p.strip()]
|
provider_list = [p.strip() for p in env_provider.split(',') if p.strip()]
|
||||||
|
=======
|
||||||
|
resolved_env_provider = resolve_text_provider_alias(env_provider)
|
||||||
|
if resolved_env_provider == "google":
|
||||||
|
gpt_provider = "google"
|
||||||
|
model = "gemini-2.0-flash-001"
|
||||||
|
elif resolved_env_provider == "huggingface":
|
||||||
|
gpt_provider = "huggingface"
|
||||||
|
model = PREMIUM_DEFAULT_MODEL
|
||||||
|
>>>>>>> pr-417
|
||||||
|
|
||||||
# Determine if we're in strict mode (single provider) or fallback mode (multiple providers)
|
# Determine if we're in strict mode (single provider) or fallback mode (multiple providers)
|
||||||
strict_provider_mode = len(provider_list) == 1
|
strict_provider_mode = len(provider_list) == 1
|
||||||
@@ -335,19 +357,33 @@ def llm_text_gen(
|
|||||||
model = "gemini-2.0-flash-001"
|
model = "gemini-2.0-flash-001"
|
||||||
elif "huggingface" in available_providers:
|
elif "huggingface" in available_providers:
|
||||||
gpt_provider = "huggingface"
|
gpt_provider = "huggingface"
|
||||||
|
<<<<<<< HEAD
|
||||||
model = "openai/gpt-oss-120b:cerebras"
|
model = "openai/gpt-oss-120b:cerebras"
|
||||||
elif "wavespeed" in available_providers:
|
elif "wavespeed" in available_providers:
|
||||||
gpt_provider = "wavespeed"
|
gpt_provider = "wavespeed"
|
||||||
model = "openai/gpt-oss-120b"
|
model = "openai/gpt-oss-120b"
|
||||||
|
=======
|
||||||
|
model = PREMIUM_DEFAULT_MODEL
|
||||||
|
>>>>>>> pr-417
|
||||||
else:
|
else:
|
||||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||||
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
||||||
else:
|
else:
|
||||||
# Environment variable was set, validate it's supported
|
# Environment variable was set, validate it's supported
|
||||||
if gpt_provider not in available_providers:
|
if gpt_provider not in available_providers:
|
||||||
|
<<<<<<< HEAD
|
||||||
if strict_provider_mode:
|
if strict_provider_mode:
|
||||||
# Strict mode: fail if specified provider not available
|
# Strict mode: fail if specified provider not available
|
||||||
raise RuntimeError(f"Provider {gpt_provider} not available. Available: {available_providers}")
|
raise RuntimeError(f"Provider {gpt_provider} not available. Available: {available_providers}")
|
||||||
|
=======
|
||||||
|
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
|
||||||
|
if "google" in available_providers:
|
||||||
|
gpt_provider = "google"
|
||||||
|
model = "gemini-2.0-flash-001"
|
||||||
|
elif "huggingface" in available_providers:
|
||||||
|
gpt_provider = "huggingface"
|
||||||
|
model = PREMIUM_DEFAULT_MODEL
|
||||||
|
>>>>>>> pr-417
|
||||||
else:
|
else:
|
||||||
# Fallback mode: try other providers
|
# Fallback mode: try other providers
|
||||||
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
|
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
|
||||||
@@ -367,6 +403,7 @@ def llm_text_gen(
|
|||||||
provider_sequence = _resolve_provider_sequence(preferred_provider, env_provider_raw, available_providers)
|
provider_sequence = _resolve_provider_sequence(preferred_provider, env_provider_raw, available_providers)
|
||||||
>>>>>>> pr-416
|
>>>>>>> pr-416
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
if not provider_sequence:
|
if not provider_sequence:
|
||||||
logger.error("[llm_text_gen] No configured providers available for tenant.")
|
logger.error("[llm_text_gen] No configured providers available for tenant.")
|
||||||
raise RuntimeError("No LLM providers available for tenant.")
|
raise RuntimeError("No LLM providers available for tenant.")
|
||||||
@@ -387,6 +424,11 @@ def llm_text_gen(
|
|||||||
pinned_provider,
|
pinned_provider,
|
||||||
len(model_sequence) == 1,
|
len(model_sequence) == 1,
|
||||||
)
|
)
|
||||||
|
=======
|
||||||
|
if gpt_provider == "huggingface" and preferred_hf_models is not None:
|
||||||
|
model = preferred_hf_models[0] if preferred_hf_models else hf_low_cost_default_model
|
||||||
|
logger.info(f"[llm_text_gen] Using SIF low-cost HF model: {model}")
|
||||||
|
>>>>>>> pr-417
|
||||||
|
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
logger.info(f"[llm_text_gen][{flow_tag}] Using provider={gpt_provider}, model={model}")
|
logger.info(f"[llm_text_gen][{flow_tag}] Using provider={gpt_provider}, model={model}")
|
||||||
@@ -714,11 +756,15 @@ def llm_text_gen(
|
|||||||
elif fallback_provider == "huggingface":
|
elif fallback_provider == "huggingface":
|
||||||
provider_enum = APIProvider.MISTRAL
|
provider_enum = APIProvider.MISTRAL
|
||||||
actual_provider_name = "huggingface"
|
actual_provider_name = "huggingface"
|
||||||
|
<<<<<<< HEAD
|
||||||
fallback_model = preferred_hf_models[0] if preferred_hf_models else "openai/gpt-oss-120b:cerebras"
|
fallback_model = preferred_hf_models[0] if preferred_hf_models else "openai/gpt-oss-120b:cerebras"
|
||||||
elif fallback_provider == "wavespeed":
|
elif fallback_provider == "wavespeed":
|
||||||
provider_enum = APIProvider.WAVESPEED
|
provider_enum = APIProvider.WAVESPEED
|
||||||
actual_provider_name = "wavespeed"
|
actual_provider_name = "wavespeed"
|
||||||
fallback_model = "openai/gpt-oss-120b"
|
fallback_model = "openai/gpt-oss-120b"
|
||||||
|
=======
|
||||||
|
fallback_model = preferred_hf_models[0] if preferred_hf_models else PREMIUM_DEFAULT_MODEL
|
||||||
|
>>>>>>> pr-417
|
||||||
|
|
||||||
if fallback_provider == "google":
|
if fallback_provider == "google":
|
||||||
=======
|
=======
|
||||||
@@ -750,12 +796,16 @@ def llm_text_gen(
|
|||||||
response_text = huggingface_structured_json_response(
|
response_text = huggingface_structured_json_response(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
schema=json_struct,
|
schema=json_struct,
|
||||||
|
<<<<<<< HEAD
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
model=fallback_model,
|
model=fallback_model,
|
||||||
fallback_models=hf_fallback_models,
|
fallback_models=hf_fallback_models,
|
||||||
=======
|
=======
|
||||||
model=candidate_model,
|
model=candidate_model,
|
||||||
>>>>>>> pr-416
|
>>>>>>> pr-416
|
||||||
|
=======
|
||||||
|
model=fallback_model,
|
||||||
|
>>>>>>> pr-417
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
system_prompt=system_instructions,
|
system_prompt=system_instructions,
|
||||||
@@ -764,9 +814,13 @@ def llm_text_gen(
|
|||||||
else:
|
else:
|
||||||
response_text = huggingface_text_response(
|
response_text = huggingface_text_response(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
|
<<<<<<< HEAD
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
model=fallback_model,
|
model=fallback_model,
|
||||||
fallback_models=hf_fallback_models,
|
fallback_models=hf_fallback_models,
|
||||||
|
=======
|
||||||
|
model=fallback_model,
|
||||||
|
>>>>>>> pr-417
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
@@ -852,6 +906,7 @@ def llm_text_gen(
|
|||||||
|
|
||||||
def check_gpt_provider(gpt_provider: str) -> bool:
|
def check_gpt_provider(gpt_provider: str) -> bool:
|
||||||
"""Check if the specified GPT provider is supported."""
|
"""Check if the specified GPT provider is supported."""
|
||||||
|
<<<<<<< HEAD
|
||||||
<<<<<<< HEAD
|
<<<<<<< HEAD
|
||||||
supported_providers = ["google", "huggingface", "wavespeed"]
|
supported_providers = ["google", "huggingface", "wavespeed"]
|
||||||
return gpt_provider in supported_providers
|
return gpt_provider in supported_providers
|
||||||
@@ -862,6 +917,11 @@ def check_gpt_provider(gpt_provider: str) -> bool:
|
|||||||
supported_providers = {"google", "huggingface"}
|
supported_providers = {"google", "huggingface"}
|
||||||
return all(p in supported_providers for p in providers if p)
|
return all(p in supported_providers for p in providers if p)
|
||||||
>>>>>>> pr-416
|
>>>>>>> pr-416
|
||||||
|
=======
|
||||||
|
supported_providers = ["google", "huggingface"]
|
||||||
|
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
|
||||||
|
return resolved_provider in supported_providers
|
||||||
|
>>>>>>> pr-417
|
||||||
|
|
||||||
def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
|
def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
|
||||||
"""Get API key for the specified provider, preferring tenant-scoped keys."""
|
"""Get API key for the specified provider, preferring tenant-scoped keys."""
|
||||||
@@ -871,10 +931,15 @@ def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[st
|
|||||||
provider_mapping = {
|
provider_mapping = {
|
||||||
"google": "gemini",
|
"google": "gemini",
|
||||||
"huggingface": "hf_token",
|
"huggingface": "hf_token",
|
||||||
|
<<<<<<< HEAD
|
||||||
"wavespeed": "wavespeed"
|
"wavespeed": "wavespeed"
|
||||||
|
=======
|
||||||
|
"wavespeed": "hf_token"
|
||||||
|
>>>>>>> pr-417
|
||||||
}
|
}
|
||||||
|
|
||||||
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
|
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
|
||||||
|
mapped_provider = provider_mapping.get(resolved_provider, resolved_provider)
|
||||||
return api_key_manager.get_api_key(mapped_provider)
|
return api_key_manager.get_api_key(mapped_provider)
|
||||||
=======
|
=======
|
||||||
return get_tenant_api_key(user_id, gpt_provider)
|
return get_tenant_api_key(user_id, gpt_provider)
|
||||||
|
|||||||
30
backend/services/llm_providers/routing_policy.py
Normal file
30
backend/services/llm_providers/routing_policy.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Routing policy for LLM provider aliases and model defaults."""
|
||||||
|
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
|
||||||
|
# Premium text generation defaults
|
||||||
|
PREMIUM_DEFAULT_PROVIDER = "huggingface"
|
||||||
|
PREMIUM_DEFAULT_MODEL = "openai/gpt-oss-120b:groq"
|
||||||
|
|
||||||
|
# SIF low-cost defaults for text generation
|
||||||
|
SIF_LOW_COST_MODEL_DEFAULTS: List[str] = [
|
||||||
|
"mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Canonical provider aliases for text routing
|
||||||
|
PROVIDER_ALIAS_MAPPING: Dict[str, str] = {
|
||||||
|
"gemini": "google",
|
||||||
|
"google": "google",
|
||||||
|
"hf_response_api": "huggingface",
|
||||||
|
"huggingface": "huggingface",
|
||||||
|
"hf": "huggingface",
|
||||||
|
# Text-only alias: route wavespeed GPT_PROVIDER to premium HF text route.
|
||||||
|
"wavespeed": PREMIUM_DEFAULT_PROVIDER,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_text_provider_alias(provider: str) -> str:
|
||||||
|
"""Resolve a GPT provider alias into a canonical text provider."""
|
||||||
|
return PROVIDER_ALIAS_MAPPING.get((provider or "").lower(), "")
|
||||||
|
|
||||||
Reference in New Issue
Block a user