Add tenant-aware provider config resolver across LLM facades
This commit is contained in:
@@ -13,12 +13,16 @@ from loguru import logger
|
|||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from services.wavespeed.client import WaveSpeedClient
|
from services.wavespeed.client import WaveSpeedClient
|
||||||
from services.onboarding.api_key_manager import APIKeyManager
|
|
||||||
from utils.logger_utils import get_service_logger
|
from utils.logger_utils import get_service_logger
|
||||||
|
from .tenant_provider_config import tenant_provider_config_resolver
|
||||||
|
|
||||||
logger = get_service_logger("audio_generation")
|
logger = get_service_logger("audio_generation")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_wavespeed_client(user_id: Optional[str]) -> WaveSpeedClient:
    """Build a WaveSpeed client from the key resolved for ``user_id``.

    The resolver returns a ``(key, source)`` pair; only the key is used here.
    The key is passed through unchanged, so it may be ``None`` when the
    resolver finds nothing.
    """
    resolved = tenant_provider_config_resolver.resolve_provider_key(
        "wavespeed",
        user_id=user_id,
    )
    api_key = resolved[0]
    return WaveSpeedClient(api_key=api_key)
|
||||||
|
|
||||||
class AudioGenerationResult:
|
class AudioGenerationResult:
|
||||||
"""Result of audio generation."""
|
"""Result of audio generation."""
|
||||||
|
|
||||||
@@ -165,7 +169,7 @@ def generate_audio(
|
|||||||
# Track response time
|
# Track response time
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
client = WaveSpeedClient()
|
client = _get_wavespeed_client(user_id)
|
||||||
audio_bytes = client.generate_speech(
|
audio_bytes = client.generate_speech(
|
||||||
text=text,
|
text=text,
|
||||||
voice_id=voice_id,
|
voice_id=voice_id,
|
||||||
@@ -424,7 +428,7 @@ def clone_voice(
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
client = WaveSpeedClient()
|
client = _get_wavespeed_client(user_id)
|
||||||
preview_audio_bytes = client.voice_clone(
|
preview_audio_bytes = client.voice_clone(
|
||||||
audio_bytes=bytes(audio_bytes),
|
audio_bytes=bytes(audio_bytes),
|
||||||
custom_voice_id=custom_voice_id,
|
custom_voice_id=custom_voice_id,
|
||||||
@@ -617,7 +621,7 @@ def qwen3_voice_clone(
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
client = WaveSpeedClient()
|
client = _get_wavespeed_client(user_id)
|
||||||
preview_audio_bytes = client.qwen3_voice_clone(
|
preview_audio_bytes = client.qwen3_voice_clone(
|
||||||
audio_bytes=bytes(audio_bytes),
|
audio_bytes=bytes(audio_bytes),
|
||||||
text=text,
|
text=text,
|
||||||
@@ -802,7 +806,7 @@ def qwen3_voice_design(
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
client = WaveSpeedClient()
|
client = _get_wavespeed_client(user_id)
|
||||||
preview_audio_bytes = client.voice_design(
|
preview_audio_bytes = client.voice_design(
|
||||||
text=text,
|
text=text,
|
||||||
voice_description=voice_description,
|
voice_description=voice_description,
|
||||||
@@ -989,7 +993,7 @@ def cosyvoice_voice_clone(
|
|||||||
|
|
||||||
import time
|
import time
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
client = WaveSpeedClient()
|
client = _get_wavespeed_client(user_id)
|
||||||
preview_audio_bytes = client.cosyvoice_voice_clone(
|
preview_audio_bytes = client.cosyvoice_voice_clone(
|
||||||
audio_bytes=bytes(audio_bytes),
|
audio_bytes=bytes(audio_bytes),
|
||||||
text=text,
|
text=text,
|
||||||
|
|||||||
@@ -22,40 +22,34 @@ from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
|||||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||||
from utils.logger_utils import get_service_logger
|
from utils.logger_utils import get_service_logger
|
||||||
|
from .tenant_provider_config import tenant_provider_config_resolver
|
||||||
|
|
||||||
|
|
||||||
logger = get_service_logger("image_generation.facade")
|
logger = get_service_logger("image_generation.facade")
|
||||||
|
|
||||||
|
|
||||||
def _select_provider(explicit: Optional[str]) -> str:
|
def _select_provider(explicit: Optional[str], user_id: Optional[str] = None) -> str:
|
||||||
if explicit:
|
cfg = tenant_provider_config_resolver.resolve(
|
||||||
return explicit
|
modality="image",
|
||||||
|
user_id=user_id,
|
||||||
# User requested WaveSpeed as default provider
|
explicit_provider=explicit,
|
||||||
if os.getenv("WAVESPEED_API_KEY"):
|
)
|
||||||
return "wavespeed"
|
return (cfg.selected_providers or [explicit or "huggingface"])[0]
|
||||||
|
|
||||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
|
||||||
if gpt_provider.startswith("gemini"):
|
|
||||||
return "gemini"
|
|
||||||
if gpt_provider.startswith("hf"):
|
|
||||||
return "huggingface"
|
|
||||||
if os.getenv("STABILITY_API_KEY"):
|
|
||||||
return "stability"
|
|
||||||
|
|
||||||
# Fallback to huggingface to enable a path if configured
|
|
||||||
return "huggingface"
|
|
||||||
|
|
||||||
|
|
||||||
def _get_provider(provider_name: str):
|
def _get_provider(provider_name: str, user_id: Optional[str] = None):
|
||||||
|
key, _source = tenant_provider_config_resolver.resolve_provider_key(provider_name, user_id=user_id)
|
||||||
if provider_name == "huggingface":
|
if provider_name == "huggingface":
|
||||||
return HuggingFaceImageProvider()
|
return HuggingFaceImageProvider(api_key=key)
|
||||||
if provider_name == "gemini":
|
if provider_name == "gemini":
|
||||||
|
if key:
|
||||||
|
os.environ["GEMINI_API_KEY"] = key
|
||||||
|
os.environ.setdefault("GOOGLE_API_KEY", key)
|
||||||
return GeminiImageProvider()
|
return GeminiImageProvider()
|
||||||
if provider_name == "stability":
|
if provider_name == "stability":
|
||||||
return StabilityImageProvider()
|
return StabilityImageProvider(api_key=key)
|
||||||
if provider_name == "wavespeed":
|
if provider_name == "wavespeed":
|
||||||
return WaveSpeedImageProvider()
|
return WaveSpeedImageProvider(api_key=key)
|
||||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||||
|
|
||||||
|
|
||||||
@@ -328,7 +322,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
|||||||
log_prefix="[Image Generation]"
|
log_prefix="[Image Generation]"
|
||||||
)
|
)
|
||||||
opts = options or {}
|
opts = options or {}
|
||||||
provider_name = _select_provider(opts.get("provider"))
|
provider_name = _select_provider(opts.get("provider"), user_id=user_id)
|
||||||
|
|
||||||
image_options = ImageGenerationOptions(
|
image_options = ImageGenerationOptions(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
@@ -370,7 +364,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
|||||||
image_options.model = "qwen-image"
|
image_options.model = "qwen-image"
|
||||||
|
|
||||||
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
||||||
provider = _get_provider(provider_name)
|
provider = _get_provider(provider_name, user_id=user_id)
|
||||||
|
|
||||||
# Track response time
|
# Track response time
|
||||||
import time
|
import time
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ from typing import Optional, Dict, Any, List
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from ..onboarding.api_key_manager import APIKeyManager
|
|
||||||
|
|
||||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||||
|
from .tenant_provider_config import tenant_provider_config_resolver
|
||||||
|
|
||||||
|
|
||||||
def llm_text_gen(
|
def llm_text_gen(
|
||||||
@@ -53,14 +53,17 @@ def llm_text_gen(
|
|||||||
frequency_penalty = 0.0
|
frequency_penalty = 0.0
|
||||||
presence_penalty = 0.0
|
presence_penalty = 0.0
|
||||||
|
|
||||||
# Check for GPT_PROVIDER environment variable
|
provider_cfg = tenant_provider_config_resolver.resolve(
|
||||||
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
modality="text",
|
||||||
if env_provider in ['gemini', 'google']:
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
selected_provider = (provider_cfg.selected_providers or [None])[0]
|
||||||
|
if selected_provider in ["gemini", "google"]:
|
||||||
gpt_provider = "google"
|
gpt_provider = "google"
|
||||||
model = "gemini-2.0-flash-001"
|
model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001"
|
||||||
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
|
elif selected_provider == "huggingface":
|
||||||
gpt_provider = "huggingface"
|
gpt_provider = "huggingface"
|
||||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
model = provider_cfg.model_policy.get("default_model") or "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||||
|
|
||||||
# Default blog characteristics
|
# Default blog characteristics
|
||||||
blog_tone = "Professional"
|
blog_tone = "Professional"
|
||||||
@@ -70,38 +73,26 @@ def llm_text_gen(
|
|||||||
blog_output_format = "markdown"
|
blog_output_format = "markdown"
|
||||||
blog_length = 2000
|
blog_length = 2000
|
||||||
|
|
||||||
# Check which providers have API keys available using APIKeyManager
|
|
||||||
api_key_manager = APIKeyManager()
|
|
||||||
available_providers = []
|
available_providers = []
|
||||||
if api_key_manager.get_api_key("gemini"):
|
for provider in ("google", "huggingface"):
|
||||||
available_providers.append("google")
|
if get_api_key(provider, user_id=user_id):
|
||||||
if api_key_manager.get_api_key("hf_token"):
|
available_providers.append(provider)
|
||||||
available_providers.append("huggingface")
|
|
||||||
|
if gpt_provider not in available_providers:
|
||||||
# If no environment variable set, auto-detect based on available keys
|
logger.warning(f"[llm_text_gen] Provider {gpt_provider} unavailable for user {user_id}, falling back.")
|
||||||
if not env_provider:
|
if available_providers:
|
||||||
# Prefer Google Gemini if available, otherwise use Hugging Face
|
gpt_provider = available_providers[0]
|
||||||
if "google" in available_providers:
|
|
||||||
gpt_provider = "google"
|
|
||||||
model = "gemini-2.0-flash-001"
|
|
||||||
elif "huggingface" in available_providers:
|
|
||||||
gpt_provider = "huggingface"
|
|
||||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
|
||||||
else:
|
else:
|
||||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||||
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
raise RuntimeError("No LLM API keys configured for tenant or environment defaults.")
|
||||||
else:
|
|
||||||
# Environment variable was set, validate it's supported
|
# Ensure downstream provider clients (currently env-based) receive resolved key
|
||||||
if gpt_provider not in available_providers:
|
resolved_key = get_api_key(gpt_provider, user_id=user_id)
|
||||||
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
|
if gpt_provider == "google" and resolved_key:
|
||||||
if "google" in available_providers:
|
os.environ["GEMINI_API_KEY"] = resolved_key
|
||||||
gpt_provider = "google"
|
os.environ.setdefault("GOOGLE_API_KEY", resolved_key)
|
||||||
model = "gemini-2.0-flash-001"
|
elif gpt_provider == "huggingface" and resolved_key:
|
||||||
elif "huggingface" in available_providers:
|
os.environ["HF_TOKEN"] = resolved_key
|
||||||
gpt_provider = "huggingface"
|
|
||||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
|
||||||
else:
|
|
||||||
raise RuntimeError("No supported providers available.")
|
|
||||||
|
|
||||||
if gpt_provider == "huggingface" and preferred_hf_models:
|
if gpt_provider == "huggingface" and preferred_hf_models:
|
||||||
model = preferred_hf_models[0]
|
model = preferred_hf_models[0]
|
||||||
@@ -391,17 +382,16 @@ def check_gpt_provider(gpt_provider: str) -> bool:
|
|||||||
supported_providers = ["google", "huggingface"]
|
supported_providers = ["google", "huggingface"]
|
||||||
return gpt_provider in supported_providers
|
return gpt_provider in supported_providers
|
||||||
|
|
||||||
def get_api_key(gpt_provider: str) -> Optional[str]:
|
def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
|
||||||
"""Get API key for the specified provider."""
|
"""Get API key for the specified provider."""
|
||||||
try:
|
try:
|
||||||
api_key_manager = APIKeyManager()
|
|
||||||
provider_mapping = {
|
provider_mapping = {
|
||||||
"google": "gemini",
|
"google": "gemini",
|
||||||
"huggingface": "hf_token"
|
"huggingface": "huggingface"
|
||||||
}
|
}
|
||||||
|
|
||||||
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
|
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
|
||||||
return api_key_manager.get_api_key(mapped_provider)
|
key, _source = tenant_provider_config_resolver.resolve_provider_key(mapped_provider, user_id=user_id)
|
||||||
|
return key
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
|
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -25,10 +25,10 @@ except ImportError:
|
|||||||
HF_HUB_AVAILABLE = False
|
HF_HUB_AVAILABLE = False
|
||||||
InferenceClient = None
|
InferenceClient = None
|
||||||
|
|
||||||
from ..onboarding.api_key_manager import APIKeyManager
|
|
||||||
from services.subscription import PricingService
|
from services.subscription import PricingService
|
||||||
from services.subscription.provider_detection import detect_actual_provider
|
from services.subscription.provider_detection import detect_actual_provider
|
||||||
from utils.logger_utils import get_service_logger
|
from utils.logger_utils import get_service_logger
|
||||||
|
from .tenant_provider_config import tenant_provider_config_resolver
|
||||||
|
|
||||||
logger = get_service_logger("video_generation_service")
|
logger = get_service_logger("video_generation_service")
|
||||||
|
|
||||||
@@ -202,16 +202,10 @@ def _track_video_operation_usage(
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def _get_api_key(provider: str) -> Optional[str]:
|
def _get_api_key(provider: str, user_id: Optional[str] = None) -> Optional[str]:
|
||||||
try:
|
try:
|
||||||
manager = APIKeyManager()
|
key, _source = tenant_provider_config_resolver.resolve_provider_key(provider, user_id=user_id)
|
||||||
mapping = {
|
return key
|
||||||
"huggingface": "hf_token",
|
|
||||||
"wavespeed": "wavespeed", # WaveSpeed API key
|
|
||||||
"gemini": "gemini", # placeholder for Veo 3
|
|
||||||
"openai": "openai_api_key", # placeholder for Sora
|
|
||||||
}
|
|
||||||
return manager.get_api_key(mapping.get(provider, provider))
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
|
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
|
||||||
return None
|
return None
|
||||||
@@ -297,6 +291,7 @@ def _coerce_video_bytes(output: Any) -> bytes:
|
|||||||
|
|
||||||
|
|
||||||
def _generate_with_huggingface(
|
def _generate_with_huggingface(
|
||||||
|
user_id: Optional[str],
|
||||||
prompt: str,
|
prompt: str,
|
||||||
num_frames: int = 24 * 4,
|
num_frames: int = 24 * 4,
|
||||||
guidance_scale: float = 7.5,
|
guidance_scale: float = 7.5,
|
||||||
@@ -311,7 +306,7 @@ def _generate_with_huggingface(
|
|||||||
if not HF_HUB_AVAILABLE:
|
if not HF_HUB_AVAILABLE:
|
||||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||||
|
|
||||||
token = _get_api_key("huggingface")
|
token = _get_api_key("huggingface", user_id=user_id)
|
||||||
if not token:
|
if not token:
|
||||||
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
|
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
|
||||||
|
|
||||||
@@ -618,7 +613,13 @@ async def ai_video_generate(
|
|||||||
- height: Video height in pixels (for image-to-video)
|
- height: Video height in pixels (for image-to-video)
|
||||||
- metadata: Additional metadata dict
|
- metadata: Additional metadata dict
|
||||||
"""
|
"""
|
||||||
logger.info(f"[video_gen] operation={operation_type}, provider={provider}")
|
cfg = tenant_provider_config_resolver.resolve(
|
||||||
|
modality="video",
|
||||||
|
user_id=user_id,
|
||||||
|
explicit_provider=provider,
|
||||||
|
)
|
||||||
|
provider = (cfg.selected_providers or [provider])[0]
|
||||||
|
logger.info(f"[video_gen] operation={operation_type}, provider={provider}, credential_source={cfg.credential_source.get(provider)}")
|
||||||
|
|
||||||
# Enforce authentication usage like text gen does
|
# Enforce authentication usage like text gen does
|
||||||
if not user_id:
|
if not user_id:
|
||||||
@@ -679,7 +680,7 @@ async def ai_video_generate(
|
|||||||
try:
|
try:
|
||||||
if operation_type == "text-to-video":
|
if operation_type == "text-to-video":
|
||||||
if provider == "huggingface":
|
if provider == "huggingface":
|
||||||
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
|
video_bytes = _generate_with_huggingface(user_id=user_id, prompt=prompt, **kwargs)
|
||||||
result = {
|
result = {
|
||||||
"video_bytes": video_bytes,
|
"video_bytes": video_bytes,
|
||||||
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
|
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
|
||||||
|
|||||||
168
backend/services/llm_providers/tenant_provider_config.py
Normal file
168
backend/services/llm_providers/tenant_provider_config.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, List, Optional, Tuple
|
||||||
|
|
||||||
|
from utils.logger_utils import get_service_logger
|
||||||
|
|
||||||
|
logger = get_service_logger("tenant_provider_config")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class TenantProviderConfig:
    """Provider selection resolved for a single request.

    Instances are plain value holders built by ``TenantProviderConfigResolver.resolve``.
    """

    # Provider names in preference order; callers read the first entry.
    selected_providers: List[str]
    # Policy entries keyed by name. Values are not all strings: the resolver
    # stores a bool under "allow_fallback" next to the string/None entries.
    model_policy: Dict[str, object]
    # Maps provider name -> where its key was found (e.g. "tenant_db", "env_default").
    credential_source: Dict[str, str]
    # Maps provider name -> raw API key. Kept out of repr() so that logging or
    # printing a config object cannot leak credentials.
    provider_keys: Dict[str, str] = field(default_factory=dict, repr=False)
|
||||||
|
|
||||||
|
|
||||||
|
class TenantProviderConfigResolver:
    """Resolves per-request provider, model policy, and credential source.

    Priority: tenant-scoped DB key (future vault hook) -> environment defaults.
    """

    # Spellings accepted in the stored ``APIKey.provider`` column for each
    # provider name; also used to find default models registered under an alias.
    _PROVIDER_ALIASES: Dict[str, Tuple[str, ...]] = {
        "gemini": ("gemini", "google", "google_api_key", "gemini_api_key"),
        "google": ("gemini", "google", "google_api_key", "gemini_api_key"),
        "huggingface": ("huggingface", "hf", "hf_token"),
        "hf": ("huggingface", "hf", "hf_token"),
        "stability": ("stability", "stability_api_key"),
        "wavespeed": ("wavespeed", "wavespeed_api_key"),
        "openai": ("openai", "openai_api_key"),
    }

    # Environment variables checked, in order, for each normalized provider.
    _ENV_VARS: Dict[str, Tuple[str, ...]] = {
        "gemini": ("GEMINI_API_KEY", "GOOGLE_API_KEY"),
        "huggingface": ("HF_TOKEN",),
        "stability": ("STABILITY_API_KEY",),
        "wavespeed": ("WAVESPEED_API_KEY",),
        "openai": ("OPENAI_API_KEY",),
    }

    # Environment variable naming the preferred provider for each modality.
    _ENV_PROVIDER_DEFAULTS: Dict[str, str] = {
        "text": "GPT_PROVIDER",
        "image": "GPT_PROVIDER",
        "video": "VIDEO_PROVIDER",
        "audio": "AUDIO_PROVIDER",
    }

    # Default model per (modality, provider name).
    _DEFAULT_MODELS: Dict[Tuple[str, str], str] = {
        ("text", "google"): "gemini-2.0-flash-001",
        ("text", "huggingface"): "mistralai/Mistral-7B-Instruct-v0.3:groq",
        ("image", "wavespeed"): "qwen-image",
        ("image", "huggingface"): "black-forest-labs/FLUX.1-Krea-dev",
        ("video", "huggingface"): "tencent/HunyuanVideo",
        ("video", "wavespeed"): "hunyuan-video-1.5",
        ("audio", "wavespeed"): "minimax-speech-02-hd",
    }

    def resolve(
        self,
        modality: str,
        user_id: Optional[str],
        explicit_provider: Optional[str] = None,
    ) -> "TenantProviderConfig":
        """Build the provider configuration for one request.

        Args:
            modality: Key into the per-modality tables ("text", "image",
                "video", "audio").
            user_id: Tenant whose stored keys are consulted first; ``None``
                skips the DB lookup.
            explicit_provider: Caller-requested provider; when given it is the
                only candidate.

        Returns:
            A ``TenantProviderConfig`` whose ``selected_providers`` lists the
            candidates that have a key. If none has a key, the first candidate
            is kept so callers still get a provider name.
        """
        provider_candidates = self._resolve_providers(
            modality=modality, explicit_provider=explicit_provider
        )
        provider_keys: Dict[str, str] = {}
        credential_source: Dict[str, str] = {}

        for provider in provider_candidates:
            key, source = self.resolve_provider_key(provider=provider, user_id=user_id)
            if key:
                provider_keys[provider] = key
                credential_source[provider] = source

        selected_providers = [p for p in provider_candidates if p in provider_keys]
        if not selected_providers and provider_candidates:
            selected_providers = [provider_candidates[0]]

        default_model = (
            self._default_model_for(modality, selected_providers[0])
            if selected_providers
            else None
        )
        model_policy: Dict[str, object] = {
            "modality": modality,
            "default_model": default_model,
            "allow_fallback": True,
        }
        return TenantProviderConfig(
            selected_providers=selected_providers,
            model_policy=model_policy,
            credential_source=credential_source,
            provider_keys=provider_keys,
        )

    def resolve_provider_key(self, provider: str, user_id: Optional[str]) -> Tuple[Optional[str], str]:
        """Return ``(key, source)`` for ``provider``.

        ``source`` is "tenant_db" when the tenant lookup produced a key,
        "env_default" when an environment variable did, and "missing"
        (with ``key`` set to ``None``) otherwise.
        """
        normalized = self._normalize_provider(provider)

        tenant_key = self._get_tenant_key_from_db(user_id=user_id, provider=normalized)
        if tenant_key:
            return tenant_key, "tenant_db"

        env_key = self._get_key_from_env(normalized)
        if env_key:
            return env_key, "env_default"

        return None, "missing"

    def _default_model_for(self, modality: str, provider: str) -> Optional[str]:
        """Look up the default model, trying the provider name and its aliases.

        The table registers the text Gemini model under "google", while
        normalization yields "gemini"; checking aliases makes both spellings
        resolve to the same default instead of ``None``.
        """
        for name in (provider,) + self._PROVIDER_ALIASES.get(provider, ()):
            model = self._DEFAULT_MODELS.get((modality, name))
            if model:
                return model
        return None

    def _resolve_providers(self, modality: str, explicit_provider: Optional[str]) -> List[str]:
        """Return candidate providers: explicit choice, then env override, then built-in order."""
        if explicit_provider:
            return [self._normalize_provider(explicit_provider)]

        # Unknown modalities have no override variable; skip the env lookup
        # rather than querying an empty variable name.
        env_var = self._ENV_PROVIDER_DEFAULTS.get(modality)
        env_provider = os.getenv(env_var, "").strip().lower() if env_var else ""
        if env_provider:
            return [self._normalize_provider(env_provider)]

        defaults = {
            "text": ["google", "huggingface"],
            "image": ["wavespeed", "gemini", "huggingface", "stability"],
            "video": ["huggingface", "wavespeed"],
            "audio": ["wavespeed"],
        }
        return defaults.get(modality, ["google"])

    def _normalize_provider(self, provider: str) -> str:
        """Lower-case and trim ``provider``, folding known synonyms to one name."""
        provider_l = (provider or "").strip().lower()
        if provider_l in ("gemini", "google"):
            return "gemini"
        if provider_l in ("hf", "huggingface", "hf_response_api"):
            return "huggingface"
        return provider_l

    def _get_tenant_key_from_db(self, user_id: Optional[str], provider: str) -> Optional[str]:
        """Return the tenant's stored key for ``provider``, or ``None``.

        Best-effort: any failure (import error, DB error) is logged at debug
        level and treated as "no tenant key" so the caller can fall back to
        environment defaults.
        """
        if not user_id:
            return None
        try:
            # Imported lazily inside the function; presumably to avoid import
            # cycles or a hard DB dependency at module load — TODO confirm.
            from services.database import get_session_for_user
            from models.onboarding import OnboardingSession, APIKey

            db = get_session_for_user(user_id)
            if not db:
                return None
            try:
                # Use the most recently updated onboarding session for the user.
                session = (
                    db.query(OnboardingSession)
                    .filter(OnboardingSession.user_id == user_id)
                    .order_by(OnboardingSession.updated_at.desc())
                    .first()
                )
                if not session:
                    return None

                aliases = self._PROVIDER_ALIASES.get(provider, (provider,))
                rec = (
                    db.query(APIKey)
                    .filter(APIKey.session_id == session.id, APIKey.provider.in_(aliases))
                    .order_by(APIKey.updated_at.desc())
                    .first()
                )
                return rec.key if rec and rec.key else None
            finally:
                db.close()
        except Exception as exc:
            logger.debug("Tenant DB key lookup failed for provider=%s user_id=%s: %s", provider, user_id, exc)
            return None

    def _get_key_from_env(self, provider: str) -> Optional[str]:
        """Return the first non-empty environment value configured for ``provider``."""
        for env_var in self._ENV_VARS.get(provider, ()):  # pragma: no branch
            value = os.getenv(env_var)
            if value:
                return value
        return None
|
||||||
|
|
||||||
|
|
||||||
|
tenant_provider_config_resolver = TenantProviderConfigResolver()
|
||||||
Reference in New Issue
Block a user