157 lines
6.2 KiB
Python
157 lines
6.2 KiB
Python
"""
|
|
Provider Detection Utility
|
|
Detects the actual provider (WaveSpeed, Google, HuggingFace, etc.) from model names and endpoints.
|
|
"""
|
|
|
|
from typing import Optional
|
|
from models.subscription_models import APIProvider
|
|
from loguru import logger
|
|
|
|
# Ordered (provider, keywords) rules consumed by ``_match_provider``. Rules are tried
# top to bottom; the first rule with ANY keyword occurring as a substring of the
# lower-cased input wins, so specific keywords must come before generic ones that
# would also match them.
# NOTE(review): short keywords such as "hf" and "sd-" match anywhere in the string
# and may yield false positives; consider token-aware matching if that bites.
_VIDEO_MODEL_RULES = (
    ("wavespeed", ("wan-2.5", "seedance", "infinitetalk", "wavespeed", "alibaba")),
    ("huggingface", ("huggingface", "hf", "tencent", "hunyuan")),
    ("google", ("veo", "gemini")),  # future
    ("openai", ("sora", "openai")),  # future
)
_VIDEO_ENDPOINT_RULES = (
    ("wavespeed", ("wavespeed",)),
    ("huggingface", ("huggingface", "hf")),
    ("google", ("google", "gemini")),
    ("openai", ("openai",)),
)

# Brand names are checked before the generic "tts-1" / "tts" keywords. "tts" is a
# substring of "tts-1", so testing it first would classify OpenAI's tts-1 models
# (and any "openai"/"elevenlabs" TTS model name) as Google.
_AUDIO_MODEL_RULES = (
    ("wavespeed", ("minimax", "speech-02", "wavespeed")),
    ("google", ("google", "gemini")),
    ("openai", ("openai",)),
    ("elevenlabs", ("elevenlabs",)),  # future
    ("openai", ("tts-1",)),
    ("google", ("tts",)),
)
_AUDIO_ENDPOINT_RULES = (
    ("wavespeed", ("wavespeed",)),
    ("google", ("google",)),
    ("openai", ("openai",)),
)

_IMAGE_MODEL_RULES = (
    ("wavespeed", ("qwen", "ideogram", "flux", "wavespeed")),  # WaveSpeed OSS models
    ("stability", ("stability", "stable-diffusion", "sd-")),
    ("huggingface", ("huggingface", "hf", "runway")),
)
_IMAGE_ENDPOINT_RULES = (
    ("wavespeed", ("wavespeed",)),
    ("stability", ("stability",)),
    ("huggingface", ("huggingface", "hf")),
)

_IMAGE_EDIT_MODEL_RULES = (
    ("wavespeed", ("qwen", "flux", "kontext", "wavespeed")),  # WaveSpeed OSS models
    ("stability", ("stability", "stable-diffusion")),
)
_IMAGE_EDIT_ENDPOINT_RULES = (
    ("wavespeed", ("wavespeed",)),
    ("stability", ("stability",)),
)


def _match_provider(text: Optional[str], rules) -> Optional[str]:
    """Return the provider of the first rule whose keywords occur in ``text``.

    Args:
        text: Model name or endpoint to inspect; may be ``None`` or empty.
        rules: Ordered sequence of ``(provider, keywords)`` pairs.

    Returns:
        The matching provider name, or ``None`` when ``text`` is empty or no
        rule matches. Matching is a case-insensitive substring test.
    """
    if not text:
        return None
    lowered = text.lower()
    for provider, keywords in rules:
        if any(keyword in lowered for keyword in keywords):
            return provider
    return None


def _detect_media_provider(
    model_name: Optional[str],
    endpoint: Optional[str],
    model_rules,
    endpoint_rules,
    default: str,
) -> str:
    """Resolve a media provider: model-name rules first, then endpoint rules, then ``default``."""
    return (
        _match_provider(model_name, model_rules)
        or _match_provider(endpoint, endpoint_rules)
        or default
    )


def detect_actual_provider(provider_enum: APIProvider, model_name: Optional[str] = None, endpoint: Optional[str] = None) -> str:
    """
    Detect the actual provider name from provider enum, model name, and endpoint.

    LLM enums map directly to a vendor name, and search-API enums return their
    own value. For media enums (VIDEO, AUDIO, STABILITY, IMAGE_EDIT), the model
    name is checked first, then the endpoint, and finally a per-enum default
    applies. Any other enum logs a warning and returns ``provider_enum.value``.

    Args:
        provider_enum: The APIProvider enum value (e.g., APIProvider.VIDEO, APIProvider.GEMINI)
        model_name: The model name (e.g., "alibaba/wan-2.5/text-to-video", "gemini-2.5-flash")
        endpoint: The API endpoint (e.g., "/video-generation/wavespeed", "/image-generation/stability")

    Returns:
        Actual provider name: "wavespeed", "google", "huggingface", "stability", "openai", "anthropic", etc.
    """
    # LLM providers: the enum alone determines the vendor.
    if provider_enum == APIProvider.GEMINI:
        return "google"
    if provider_enum == APIProvider.OPENAI:
        return "openai"
    if provider_enum == APIProvider.ANTHROPIC:
        return "anthropic"
    if provider_enum == APIProvider.MISTRAL:
        # MISTRAL enum is used for HuggingFace models
        return "huggingface"

    # Search APIs: the enum value already is the provider name.
    if provider_enum in (
        APIProvider.TAVILY,
        APIProvider.SERPER,
        APIProvider.METAPHOR,
        APIProvider.FIRECRAWL,
        APIProvider.EXA,
    ):
        return provider_enum.value

    # Media generation: detect from model name, then endpoint, then default.
    if provider_enum == APIProvider.VIDEO:
        # Default for video: WaveSpeed (most common)
        return _detect_media_provider(
            model_name, endpoint, _VIDEO_MODEL_RULES, _VIDEO_ENDPOINT_RULES, "wavespeed"
        )
    if provider_enum == APIProvider.AUDIO:
        # Default for audio: WaveSpeed (most common)
        return _detect_media_provider(
            model_name, endpoint, _AUDIO_MODEL_RULES, _AUDIO_ENDPOINT_RULES, "wavespeed"
        )
    if provider_enum == APIProvider.STABILITY:
        # Default for image generation: Stability (legacy). WaveSpeed OSS model
        # names (qwen/ideogram/flux) are already caught by _IMAGE_MODEL_RULES.
        return _detect_media_provider(
            model_name, endpoint, _IMAGE_MODEL_RULES, _IMAGE_ENDPOINT_RULES, "stability"
        )
    if provider_enum == APIProvider.IMAGE_EDIT:
        # Default for image editing: WaveSpeed (OSS-first strategy)
        return _detect_media_provider(
            model_name, endpoint, _IMAGE_EDIT_MODEL_RULES, _IMAGE_EDIT_ENDPOINT_RULES, "wavespeed"
        )

    # Fallback: use enum value
    logger.warning(f"Could not detect actual provider for {provider_enum.value}, using enum value")
    return provider_enum.value
|