Files

78 lines
2.6 KiB
Python

"""
Model Utilities Module
This module contains model-related utility functions extracted from main_text_generation.py
to resolve merge conflicts and improve maintainability.
"""
import os
from typing import Optional, List
from ..routing_policy import PREMIUM_DEFAULT_MODEL
def _map_logical_model_to_provider_model(provider: str, model_name: str) -> str:
"""Map logical model aliases/full names to provider-specific model IDs."""
raw = (model_name or "").strip()
if not raw:
return raw
# Full provider path supplied explicitly; use as-is.
if "/" in raw:
return raw
key = raw.lower()
hf_map = {
"gpt-oss": "openai/gpt-oss-120b:cerebras",
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
"gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"llama": "meta-llama/Llama-3.1-8B-Instruct:groq",
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:groq",
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:groq",
}
wavespeed_map = {
"gpt-oss": "openai/gpt-oss-120b",
"gpt-oss-120b": "openai/gpt-oss-120b",
"gpt-oss-20b": "openai/gpt-oss-20b",
"mistral": "mistralai/Mistral-7B-Instruct-v0.3",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
"llama": "meta-llama/Llama-3.1-8B-Instruct",
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct",
}
if provider in {"huggingface", "hf", "hf_response_api"}:
return hf_map.get(key, raw)
if provider == "wavespeed":
return wavespeed_map.get(key, raw)
return raw
def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[str]] = None) -> List[str]:
"""Resolve model sequence for a given provider."""
models_env = _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", ""))
if provider == "google":
return ["gemini-2.0-flash-001"]
if preferred_hf_models:
return [_map_logical_model_to_provider_model(provider, m) for m in preferred_hf_models if m]
if not models_env:
return [PREMIUM_DEFAULT_MODEL]
resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()]
return resolved or [PREMIUM_DEFAULT_MODEL]
def _parse_csv_env(value: Optional[str]) -> List[str]:
"""Parse CSV environment variable into list of values."""
if not value:
return []
return [v.strip() for v in str(value).split(",") if v.strip()]