78 lines
2.6 KiB
Python
78 lines
2.6 KiB
Python
"""
|
|
Model Utilities Module
|
|
|
|
This module contains model-related utility functions extracted from main_text_generation.py
|
|
to resolve merge conflicts and improve maintainability.
|
|
"""
|
|
|
|
import os
|
|
from typing import Optional, List
|
|
|
|
from ..routing_policy import PREMIUM_DEFAULT_MODEL
|
|
|
|
|
|
def _map_logical_model_to_provider_model(provider: str, model_name: str) -> str:
|
|
"""Map logical model aliases/full names to provider-specific model IDs."""
|
|
raw = (model_name or "").strip()
|
|
if not raw:
|
|
return raw
|
|
|
|
# Full provider path supplied explicitly; use as-is.
|
|
if "/" in raw:
|
|
return raw
|
|
|
|
key = raw.lower()
|
|
|
|
hf_map = {
|
|
"gpt-oss": "openai/gpt-oss-120b:cerebras",
|
|
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
|
|
"gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
|
|
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
|
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
|
"llama": "meta-llama/Llama-3.1-8B-Instruct:groq",
|
|
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:groq",
|
|
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:groq",
|
|
}
|
|
|
|
wavespeed_map = {
|
|
"gpt-oss": "openai/gpt-oss-120b",
|
|
"gpt-oss-120b": "openai/gpt-oss-120b",
|
|
"gpt-oss-20b": "openai/gpt-oss-20b",
|
|
"mistral": "mistralai/Mistral-7B-Instruct-v0.3",
|
|
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
|
|
"llama": "meta-llama/Llama-3.1-8B-Instruct",
|
|
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
|
|
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct",
|
|
}
|
|
|
|
if provider in {"huggingface", "hf", "hf_response_api"}:
|
|
return hf_map.get(key, raw)
|
|
if provider == "wavespeed":
|
|
return wavespeed_map.get(key, raw)
|
|
|
|
return raw
|
|
|
|
|
|
def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[str]] = None) -> List[str]:
    """Resolve the ordered list of model IDs to use for a provider.

    Resolution order:
      1. ``provider == "google"`` always yields ``["gemini-2.0-flash-001"]``.
      2. The non-blank entries of ``preferred_hf_models``, each mapped to a
         provider-specific ID.
      3. The entries of the comma-separated ``TEXTGEN_AI_MODELS`` environment
         variable, each mapped to a provider-specific ID.
      4. ``[PREMIUM_DEFAULT_MODEL]``.

    Args:
        provider: Provider name used to select the alias mapping.
        preferred_hf_models: Optional caller-supplied model names that take
            precedence over the environment configuration.

    Returns:
        A non-empty list of model IDs.
    """
    if provider == "google":
        return ["gemini-2.0-flash-001"]

    if preferred_hf_models:
        preferred = [
            _map_logical_model_to_provider_model(provider, m)
            for m in preferred_hf_models
            if m and m.strip()
        ]
        # A list holding only blank entries must not produce an empty (or
        # blank) model sequence; fall through to the configured models.
        if preferred:
            return preferred

    # _parse_csv_env already yields trimmed, non-empty items.
    models_env = _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", ""))
    resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env]
    return resolved or [PREMIUM_DEFAULT_MODEL]
|
|
|
|
|
|
def _parse_csv_env(value: Optional[str]) -> List[str]:
|
|
"""Parse CSV environment variable into list of values."""
|
|
if not value:
|
|
return []
|
|
return [v.strip() for v in str(value).split(",") if v.strip()]
|