"Extract_text_generation_utilities_into_modular_structure"
This commit is contained in:
77
backend/services/llm_providers/textgen_utils/model_utils.py
Normal file
77
backend/services/llm_providers/textgen_utils/model_utils.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Model Utilities Module
|
||||
|
||||
This module contains model-related utility functions extracted from main_text_generation.py
|
||||
to resolve merge conflicts and improve maintainability.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, List
|
||||
|
||||
from ..routing_policy import PREMIUM_DEFAULT_MODEL
|
||||
|
||||
|
||||
def _map_logical_model_to_provider_model(provider: str, model_name: str) -> str:
|
||||
"""Map logical model aliases/full names to provider-specific model IDs."""
|
||||
raw = (model_name or "").strip()
|
||||
if not raw:
|
||||
return raw
|
||||
|
||||
# Full provider path supplied explicitly; use as-is.
|
||||
if "/" in raw:
|
||||
return raw
|
||||
|
||||
key = raw.lower()
|
||||
|
||||
hf_map = {
|
||||
"gpt-oss": "openai/gpt-oss-120b:cerebras",
|
||||
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
|
||||
"gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
|
||||
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"llama": "meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:groq",
|
||||
}
|
||||
|
||||
wavespeed_map = {
|
||||
"gpt-oss": "openai/gpt-oss-120b",
|
||||
"gpt-oss-120b": "openai/gpt-oss-120b",
|
||||
"gpt-oss-20b": "openai/gpt-oss-20b",
|
||||
"mistral": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"llama": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
}
|
||||
|
||||
if provider in {"huggingface", "hf", "hf_response_api"}:
|
||||
return hf_map.get(key, raw)
|
||||
if provider == "wavespeed":
|
||||
return wavespeed_map.get(key, raw)
|
||||
|
||||
return raw
|
||||
|
||||
|
||||
def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[str]] = None) -> List[str]:
|
||||
"""Resolve model sequence for a given provider."""
|
||||
models_env = _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", ""))
|
||||
|
||||
if provider == "google":
|
||||
return ["gemini-2.0-flash-001"]
|
||||
|
||||
if preferred_hf_models:
|
||||
return [_map_logical_model_to_provider_model(provider, m) for m in preferred_hf_models if m]
|
||||
|
||||
if not models_env:
|
||||
return [PREMIUM_DEFAULT_MODEL]
|
||||
|
||||
resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()]
|
||||
return resolved or [PREMIUM_DEFAULT_MODEL]
|
||||
|
||||
|
||||
def _parse_csv_env(value: Optional[str]) -> List[str]:
|
||||
"""Parse CSV environment variable into list of values."""
|
||||
if not value:
|
||||
return []
|
||||
return [v.strip() for v in str(value).split(",") if v.strip()]
|
||||
Reference in New Issue
Block a user