Merge_PR_417_centralized_text_routing_policy

This commit is contained in:
ajaysi
2026-03-12 16:08:40 +05:30
3 changed files with 118 additions and 3 deletions

View File

@@ -12,6 +12,7 @@ from loguru import logger
from fastapi import HTTPException
from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
<<<<<<< HEAD
from .tenant_provider_config import get_available_text_providers, get_tenant_api_key
from .routing_observability import emit_routing_event
@@ -128,6 +129,13 @@ def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[st
resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()]
return resolved or ["openai/gpt-oss-120b:groq"]
=======
from .routing_policy import (
PREMIUM_DEFAULT_MODEL,
SIF_LOW_COST_MODEL_DEFAULTS,
resolve_text_provider_alias,
)
>>>>>>> pr-417
def llm_text_gen(
@@ -173,7 +181,11 @@ def llm_text_gen(
=======
gpt_provider = "google"
model = "gemini-2.0-flash-001"
<<<<<<< HEAD
>>>>>>> pr-416
=======
hf_low_cost_default_model = SIF_LOW_COST_MODEL_DEFAULTS[0]
>>>>>>> pr-417
temperature = 0.7
max_tokens = 4000
top_p = 0.9
@@ -185,7 +197,17 @@ def llm_text_gen(
# Check for GPT_PROVIDER environment variable
env_provider = os.getenv('GPT_PROVIDER', '').lower()
<<<<<<< HEAD
provider_list = [p.strip() for p in env_provider.split(',') if p.strip()]
=======
resolved_env_provider = resolve_text_provider_alias(env_provider)
if resolved_env_provider == "google":
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif resolved_env_provider == "huggingface":
gpt_provider = "huggingface"
model = PREMIUM_DEFAULT_MODEL
>>>>>>> pr-417
# Determine if we're in strict mode (single provider) or fallback mode (multiple providers)
strict_provider_mode = len(provider_list) == 1
@@ -335,19 +357,33 @@ def llm_text_gen(
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
<<<<<<< HEAD
model = "openai/gpt-oss-120b:cerebras"
elif "wavespeed" in available_providers:
gpt_provider = "wavespeed"
model = "openai/gpt-oss-120b"
=======
model = PREMIUM_DEFAULT_MODEL
>>>>>>> pr-417
else:
logger.error("[llm_text_gen] No API keys found for supported providers.")
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
else:
# Environment variable was set, validate it's supported
if gpt_provider not in available_providers:
<<<<<<< HEAD
if strict_provider_mode:
# Strict mode: fail if specified provider not available
raise RuntimeError(f"Provider {gpt_provider} not available. Available: {available_providers}")
=======
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
if "google" in available_providers:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = PREMIUM_DEFAULT_MODEL
>>>>>>> pr-417
else:
# Fallback mode: try other providers
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
@@ -367,6 +403,7 @@ def llm_text_gen(
provider_sequence = _resolve_provider_sequence(preferred_provider, env_provider_raw, available_providers)
>>>>>>> pr-416
<<<<<<< HEAD
if not provider_sequence:
logger.error("[llm_text_gen] No configured providers available for tenant.")
raise RuntimeError("No LLM providers available for tenant.")
@@ -387,6 +424,11 @@ def llm_text_gen(
pinned_provider,
len(model_sequence) == 1,
)
=======
if gpt_provider == "huggingface" and preferred_hf_models is not None:
model = preferred_hf_models[0] if preferred_hf_models else hf_low_cost_default_model
logger.info(f"[llm_text_gen] Using SIF low-cost HF model: {model}")
>>>>>>> pr-417
<<<<<<< HEAD
logger.info(f"[llm_text_gen][{flow_tag}] Using provider={gpt_provider}, model={model}")
@@ -714,11 +756,15 @@ def llm_text_gen(
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
<<<<<<< HEAD
fallback_model = preferred_hf_models[0] if preferred_hf_models else "openai/gpt-oss-120b:cerebras"
elif fallback_provider == "wavespeed":
provider_enum = APIProvider.WAVESPEED
actual_provider_name = "wavespeed"
fallback_model = "openai/gpt-oss-120b"
=======
fallback_model = preferred_hf_models[0] if preferred_hf_models else PREMIUM_DEFAULT_MODEL
>>>>>>> pr-417
if fallback_provider == "google":
=======
@@ -750,12 +796,16 @@ def llm_text_gen(
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
<<<<<<< HEAD
<<<<<<< HEAD
model=fallback_model,
fallback_models=hf_fallback_models,
=======
model=candidate_model,
>>>>>>> pr-416
=======
model=fallback_model,
>>>>>>> pr-417
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions,
@@ -764,9 +814,13 @@ def llm_text_gen(
else:
response_text = huggingface_text_response(
prompt=prompt,
<<<<<<< HEAD
<<<<<<< HEAD
model=fallback_model,
fallback_models=hf_fallback_models,
=======
model=fallback_model,
>>>>>>> pr-417
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
@@ -852,6 +906,7 @@ def llm_text_gen(
def check_gpt_provider(gpt_provider: str) -> bool:
"""Check if the specified GPT provider is supported."""
<<<<<<< HEAD
<<<<<<< HEAD
supported_providers = ["google", "huggingface", "wavespeed"]
return gpt_provider in supported_providers
@@ -862,6 +917,11 @@ def check_gpt_provider(gpt_provider: str) -> bool:
supported_providers = {"google", "huggingface"}
return all(p in supported_providers for p in providers if p)
>>>>>>> pr-416
=======
supported_providers = ["google", "huggingface"]
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
return resolved_provider in supported_providers
>>>>>>> pr-417
def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
"""Get API key for the specified provider, preferring tenant-scoped keys."""
@@ -871,10 +931,15 @@ def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[st
provider_mapping = {
"google": "gemini",
"huggingface": "hf_token",
<<<<<<< HEAD
"wavespeed": "wavespeed"
=======
"wavespeed": "hf_token"
>>>>>>> pr-417
}
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
resolved_provider = resolve_text_provider_alias(gpt_provider) or gpt_provider
mapped_provider = provider_mapping.get(resolved_provider, resolved_provider)
return api_key_manager.get_api_key(mapped_provider)
=======
return get_tenant_api_key(user_id, gpt_provider)