AI Analysis and Content Strategy fixes. Enhanced Strategy Routes refactoring.
This commit is contained in:
156
backend/services/subscription/provider_detection.py
Normal file
156
backend/services/subscription/provider_detection.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Provider Detection Utility
|
||||
Detects the actual provider (WaveSpeed, Google, HuggingFace, etc.) from model names and endpoints.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from models.subscription_models import APIProvider
|
||||
from loguru import logger
|
||||
|
||||
def detect_actual_provider(provider_enum: APIProvider, model_name: Optional[str] = None, endpoint: Optional[str] = None) -> str:
|
||||
"""
|
||||
Detect the actual provider name from provider enum, model name, and endpoint.
|
||||
|
||||
Args:
|
||||
provider_enum: The APIProvider enum value (e.g., APIProvider.VIDEO, APIProvider.GEMINI)
|
||||
model_name: The model name (e.g., "alibaba/wan-2.5/text-to-video", "gemini-2.5-flash")
|
||||
endpoint: The API endpoint (e.g., "/video-generation/wavespeed", "/image-generation/stability")
|
||||
|
||||
Returns:
|
||||
Actual provider name: "wavespeed", "google", "huggingface", "stability", "openai", "anthropic", etc.
|
||||
"""
|
||||
|
||||
# For LLM providers, use the enum value directly
|
||||
if provider_enum in [APIProvider.GEMINI]:
|
||||
return "google"
|
||||
elif provider_enum == APIProvider.OPENAI:
|
||||
return "openai"
|
||||
elif provider_enum == APIProvider.ANTHROPIC:
|
||||
return "anthropic"
|
||||
elif provider_enum == APIProvider.MISTRAL:
|
||||
# MISTRAL enum is used for HuggingFace models
|
||||
return "huggingface"
|
||||
|
||||
# For search APIs, use the enum value
|
||||
elif provider_enum in [APIProvider.TAVILY, APIProvider.SERPER, APIProvider.METAPHOR, APIProvider.FIRECRAWL, APIProvider.EXA]:
|
||||
return provider_enum.value
|
||||
|
||||
# For media generation, detect from model name or endpoint
|
||||
elif provider_enum == APIProvider.VIDEO:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed models
|
||||
if any(x in model_lower for x in ["wan-2.5", "seedance", "infinitetalk", "wavespeed", "alibaba"]):
|
||||
return "wavespeed"
|
||||
# HuggingFace models
|
||||
elif any(x in model_lower for x in ["huggingface", "hf", "tencent", "hunyuan"]):
|
||||
return "huggingface"
|
||||
# Google models (future)
|
||||
elif any(x in model_lower for x in ["veo", "gemini"]):
|
||||
return "google"
|
||||
# OpenAI models (future)
|
||||
elif any(x in model_lower for x in ["sora", "openai"]):
|
||||
return "openai"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "huggingface" in endpoint_lower or "hf" in endpoint_lower:
|
||||
return "huggingface"
|
||||
elif "google" in endpoint_lower or "gemini" in endpoint_lower:
|
||||
return "google"
|
||||
elif "openai" in endpoint_lower:
|
||||
return "openai"
|
||||
|
||||
# Default for video: WaveSpeed (most common)
|
||||
return "wavespeed"
|
||||
|
||||
elif provider_enum == APIProvider.AUDIO:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed models
|
||||
if any(x in model_lower for x in ["minimax", "speech-02", "wavespeed"]):
|
||||
return "wavespeed"
|
||||
# Google models
|
||||
elif any(x in model_lower for x in ["google", "gemini", "tts"]):
|
||||
return "google"
|
||||
# OpenAI models
|
||||
elif any(x in model_lower for x in ["openai", "tts-1"]):
|
||||
return "openai"
|
||||
# ElevenLabs (future)
|
||||
elif "elevenlabs" in model_lower:
|
||||
return "elevenlabs"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "google" in endpoint_lower:
|
||||
return "google"
|
||||
elif "openai" in endpoint_lower:
|
||||
return "openai"
|
||||
|
||||
# Default for audio: WaveSpeed (most common)
|
||||
return "wavespeed"
|
||||
|
||||
elif provider_enum == APIProvider.STABILITY:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed OSS models
|
||||
if any(x in model_lower for x in ["qwen", "ideogram", "flux", "wavespeed"]):
|
||||
return "wavespeed"
|
||||
# Stability AI models
|
||||
elif any(x in model_lower for x in ["stability", "stable-diffusion", "sd-"]):
|
||||
return "stability"
|
||||
# HuggingFace models
|
||||
elif any(x in model_lower for x in ["huggingface", "hf", "runway"]):
|
||||
return "huggingface"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "stability" in endpoint_lower:
|
||||
return "stability"
|
||||
elif "huggingface" in endpoint_lower or "hf" in endpoint_lower:
|
||||
return "huggingface"
|
||||
|
||||
# Default: check if it's actually WaveSpeed based on common OSS models
|
||||
if model_name and any(x in model_name.lower() for x in ["qwen", "ideogram", "flux"]):
|
||||
return "wavespeed"
|
||||
|
||||
# Default for image generation: Stability (legacy)
|
||||
return "stability"
|
||||
|
||||
elif provider_enum == APIProvider.IMAGE_EDIT:
|
||||
# Check model name first
|
||||
if model_name:
|
||||
model_lower = model_name.lower()
|
||||
# WaveSpeed OSS models
|
||||
if any(x in model_lower for x in ["qwen", "flux", "kontext", "wavespeed"]):
|
||||
return "wavespeed"
|
||||
# Stability AI models
|
||||
elif any(x in model_lower for x in ["stability", "stable-diffusion"]):
|
||||
return "stability"
|
||||
|
||||
# Check endpoint
|
||||
if endpoint:
|
||||
endpoint_lower = endpoint.lower()
|
||||
if "wavespeed" in endpoint_lower:
|
||||
return "wavespeed"
|
||||
elif "stability" in endpoint_lower:
|
||||
return "stability"
|
||||
|
||||
# Default for image editing: WaveSpeed (OSS-first strategy)
|
||||
return "wavespeed"
|
||||
|
||||
# Fallback: use enum value
|
||||
logger.warning(f"Could not detect actual provider for {provider_enum.value}, using enum value")
|
||||
return provider_enum.value
|
||||
Reference in New Issue
Block a user