Add tenant-aware provider config resolver across LLM facades

This commit is contained in:
ي
2026-03-12 15:04:42 +05:30
parent b410ece4ca
commit feacbc6d59
5 changed files with 241 additions and 84 deletions

View File

@@ -13,12 +13,16 @@ from loguru import logger
from fastapi import HTTPException from fastapi import HTTPException
from services.wavespeed.client import WaveSpeedClient from services.wavespeed.client import WaveSpeedClient
from services.onboarding.api_key_manager import APIKeyManager
from utils.logger_utils import get_service_logger from utils.logger_utils import get_service_logger
from .tenant_provider_config import tenant_provider_config_resolver
logger = get_service_logger("audio_generation") logger = get_service_logger("audio_generation")
def _get_wavespeed_client(user_id: Optional[str]) -> WaveSpeedClient:
key, _source = tenant_provider_config_resolver.resolve_provider_key("wavespeed", user_id=user_id)
return WaveSpeedClient(api_key=key)
class AudioGenerationResult: class AudioGenerationResult:
"""Result of audio generation.""" """Result of audio generation."""
@@ -165,7 +169,7 @@ def generate_audio(
# Track response time # Track response time
import time import time
start_time = time.time() start_time = time.time()
client = WaveSpeedClient() client = _get_wavespeed_client(user_id)
audio_bytes = client.generate_speech( audio_bytes = client.generate_speech(
text=text, text=text,
voice_id=voice_id, voice_id=voice_id,
@@ -424,7 +428,7 @@ def clone_voice(
import time import time
start_time = time.time() start_time = time.time()
client = WaveSpeedClient() client = _get_wavespeed_client(user_id)
preview_audio_bytes = client.voice_clone( preview_audio_bytes = client.voice_clone(
audio_bytes=bytes(audio_bytes), audio_bytes=bytes(audio_bytes),
custom_voice_id=custom_voice_id, custom_voice_id=custom_voice_id,
@@ -617,7 +621,7 @@ def qwen3_voice_clone(
import time import time
start_time = time.time() start_time = time.time()
client = WaveSpeedClient() client = _get_wavespeed_client(user_id)
preview_audio_bytes = client.qwen3_voice_clone( preview_audio_bytes = client.qwen3_voice_clone(
audio_bytes=bytes(audio_bytes), audio_bytes=bytes(audio_bytes),
text=text, text=text,
@@ -802,7 +806,7 @@ def qwen3_voice_design(
import time import time
start_time = time.time() start_time = time.time()
client = WaveSpeedClient() client = _get_wavespeed_client(user_id)
preview_audio_bytes = client.voice_design( preview_audio_bytes = client.voice_design(
text=text, text=text,
voice_description=voice_description, voice_description=voice_description,
@@ -989,7 +993,7 @@ def cosyvoice_voice_clone(
import time import time
start_time = time.time() start_time = time.time()
client = WaveSpeedClient() client = _get_wavespeed_client(user_id)
preview_audio_bytes = client.cosyvoice_voice_clone( preview_audio_bytes = client.cosyvoice_voice_clone(
audio_bytes=bytes(audio_bytes), audio_bytes=bytes(audio_bytes),
text=text, text=text,

View File

@@ -22,40 +22,34 @@ from .image_generation.base import FaceSwapOptions, FaceSwapProvider
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
from utils.logger_utils import get_service_logger from utils.logger_utils import get_service_logger
from .tenant_provider_config import tenant_provider_config_resolver
logger = get_service_logger("image_generation.facade") logger = get_service_logger("image_generation.facade")
def _select_provider(explicit: Optional[str]) -> str: def _select_provider(explicit: Optional[str], user_id: Optional[str] = None) -> str:
if explicit: cfg = tenant_provider_config_resolver.resolve(
return explicit modality="image",
user_id=user_id,
# User requested WaveSpeed as default provider explicit_provider=explicit,
if os.getenv("WAVESPEED_API_KEY"): )
return "wavespeed" return (cfg.selected_providers or [explicit or "huggingface"])[0]
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
if gpt_provider.startswith("gemini"):
return "gemini"
if gpt_provider.startswith("hf"):
return "huggingface"
if os.getenv("STABILITY_API_KEY"):
return "stability"
# Fallback to huggingface to enable a path if configured
return "huggingface"
def _get_provider(provider_name: str): def _get_provider(provider_name: str, user_id: Optional[str] = None):
key, _source = tenant_provider_config_resolver.resolve_provider_key(provider_name, user_id=user_id)
if provider_name == "huggingface": if provider_name == "huggingface":
return HuggingFaceImageProvider() return HuggingFaceImageProvider(api_key=key)
if provider_name == "gemini": if provider_name == "gemini":
if key:
os.environ["GEMINI_API_KEY"] = key
os.environ.setdefault("GOOGLE_API_KEY", key)
return GeminiImageProvider() return GeminiImageProvider()
if provider_name == "stability": if provider_name == "stability":
return StabilityImageProvider() return StabilityImageProvider(api_key=key)
if provider_name == "wavespeed": if provider_name == "wavespeed":
return WaveSpeedImageProvider() return WaveSpeedImageProvider(api_key=key)
raise ValueError(f"Unknown image provider: {provider_name}") raise ValueError(f"Unknown image provider: {provider_name}")
@@ -328,7 +322,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
log_prefix="[Image Generation]" log_prefix="[Image Generation]"
) )
opts = options or {} opts = options or {}
provider_name = _select_provider(opts.get("provider")) provider_name = _select_provider(opts.get("provider"), user_id=user_id)
image_options = ImageGenerationOptions( image_options = ImageGenerationOptions(
prompt=prompt, prompt=prompt,
@@ -370,7 +364,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
image_options.model = "qwen-image" image_options.model = "qwen-image"
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model) logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
provider = _get_provider(provider_name) provider = _get_provider(provider_name, user_id=user_id)
# Track response time # Track response time
import time import time

View File

@@ -10,10 +10,10 @@ from typing import Optional, Dict, Any, List
from datetime import datetime from datetime import datetime
from loguru import logger from loguru import logger
from fastapi import HTTPException from fastapi import HTTPException
from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
from .tenant_provider_config import tenant_provider_config_resolver
def llm_text_gen( def llm_text_gen(
@@ -53,14 +53,17 @@ def llm_text_gen(
frequency_penalty = 0.0 frequency_penalty = 0.0
presence_penalty = 0.0 presence_penalty = 0.0
# Check for GPT_PROVIDER environment variable provider_cfg = tenant_provider_config_resolver.resolve(
env_provider = os.getenv('GPT_PROVIDER', '').lower() modality="text",
if env_provider in ['gemini', 'google']: user_id=user_id,
)
selected_provider = (provider_cfg.selected_providers or [None])[0]
if selected_provider in ["gemini", "google"]:
gpt_provider = "google" gpt_provider = "google"
model = "gemini-2.0-flash-001" model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001"
elif env_provider in ['hf_response_api', 'huggingface', 'hf']: elif selected_provider == "huggingface":
gpt_provider = "huggingface" gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq" model = provider_cfg.model_policy.get("default_model") or "mistralai/Mistral-7B-Instruct-v0.3:groq"
# Default blog characteristics # Default blog characteristics
blog_tone = "Professional" blog_tone = "Professional"
@@ -70,38 +73,26 @@ def llm_text_gen(
blog_output_format = "markdown" blog_output_format = "markdown"
blog_length = 2000 blog_length = 2000
# Check which providers have API keys available using APIKeyManager
api_key_manager = APIKeyManager()
available_providers = [] available_providers = []
if api_key_manager.get_api_key("gemini"): for provider in ("google", "huggingface"):
available_providers.append("google") if get_api_key(provider, user_id=user_id):
if api_key_manager.get_api_key("hf_token"): available_providers.append(provider)
available_providers.append("huggingface")
if gpt_provider not in available_providers:
# If no environment variable set, auto-detect based on available keys logger.warning(f"[llm_text_gen] Provider {gpt_provider} unavailable for user {user_id}, falling back.")
if not env_provider: if available_providers:
# Prefer Google Gemini if available, otherwise use Hugging Face gpt_provider = available_providers[0]
if "google" in available_providers:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
else: else:
logger.error("[llm_text_gen] No API keys found for supported providers.") logger.error("[llm_text_gen] No API keys found for supported providers.")
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.") raise RuntimeError("No LLM API keys configured for tenant or environment defaults.")
else:
# Environment variable was set, validate it's supported # Ensure downstream provider clients (currently env-based) receive resolved key
if gpt_provider not in available_providers: resolved_key = get_api_key(gpt_provider, user_id=user_id)
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers") if gpt_provider == "google" and resolved_key:
if "google" in available_providers: os.environ["GEMINI_API_KEY"] = resolved_key
gpt_provider = "google" os.environ.setdefault("GOOGLE_API_KEY", resolved_key)
model = "gemini-2.0-flash-001" elif gpt_provider == "huggingface" and resolved_key:
elif "huggingface" in available_providers: os.environ["HF_TOKEN"] = resolved_key
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
else:
raise RuntimeError("No supported providers available.")
if gpt_provider == "huggingface" and preferred_hf_models: if gpt_provider == "huggingface" and preferred_hf_models:
model = preferred_hf_models[0] model = preferred_hf_models[0]
@@ -391,17 +382,16 @@ def check_gpt_provider(gpt_provider: str) -> bool:
supported_providers = ["google", "huggingface"] supported_providers = ["google", "huggingface"]
return gpt_provider in supported_providers return gpt_provider in supported_providers
def get_api_key(gpt_provider: str) -> Optional[str]: def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
"""Get API key for the specified provider.""" """Get API key for the specified provider."""
try: try:
api_key_manager = APIKeyManager()
provider_mapping = { provider_mapping = {
"google": "gemini", "google": "gemini",
"huggingface": "hf_token" "huggingface": "huggingface"
} }
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider) mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
return api_key_manager.get_api_key(mapped_provider) key, _source = tenant_provider_config_resolver.resolve_provider_key(mapped_provider, user_id=user_id)
return key
except Exception as e: except Exception as e:
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}") logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
return None return None

View File

@@ -25,10 +25,10 @@ except ImportError:
HF_HUB_AVAILABLE = False HF_HUB_AVAILABLE = False
InferenceClient = None InferenceClient = None
from ..onboarding.api_key_manager import APIKeyManager
from services.subscription import PricingService from services.subscription import PricingService
from services.subscription.provider_detection import detect_actual_provider from services.subscription.provider_detection import detect_actual_provider
from utils.logger_utils import get_service_logger from utils.logger_utils import get_service_logger
from .tenant_provider_config import tenant_provider_config_resolver
logger = get_service_logger("video_generation_service") logger = get_service_logger("video_generation_service")
@@ -202,16 +202,10 @@ def _track_video_operation_usage(
return {} return {}
def _get_api_key(provider: str) -> Optional[str]: def _get_api_key(provider: str, user_id: Optional[str] = None) -> Optional[str]:
try: try:
manager = APIKeyManager() key, _source = tenant_provider_config_resolver.resolve_provider_key(provider, user_id=user_id)
mapping = { return key
"huggingface": "hf_token",
"wavespeed": "wavespeed", # WaveSpeed API key
"gemini": "gemini", # placeholder for Veo 3
"openai": "openai_api_key", # placeholder for Sora
}
return manager.get_api_key(mapping.get(provider, provider))
except Exception as e: except Exception as e:
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}") logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
return None return None
@@ -297,6 +291,7 @@ def _coerce_video_bytes(output: Any) -> bytes:
def _generate_with_huggingface( def _generate_with_huggingface(
user_id: Optional[str],
prompt: str, prompt: str,
num_frames: int = 24 * 4, num_frames: int = 24 * 4,
guidance_scale: float = 7.5, guidance_scale: float = 7.5,
@@ -311,7 +306,7 @@ def _generate_with_huggingface(
if not HF_HUB_AVAILABLE: if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub") raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
token = _get_api_key("huggingface") token = _get_api_key("huggingface", user_id=user_id)
if not token: if not token:
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.") raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
@@ -618,7 +613,13 @@ async def ai_video_generate(
- height: Video height in pixels (for image-to-video) - height: Video height in pixels (for image-to-video)
- metadata: Additional metadata dict - metadata: Additional metadata dict
""" """
logger.info(f"[video_gen] operation={operation_type}, provider={provider}") cfg = tenant_provider_config_resolver.resolve(
modality="video",
user_id=user_id,
explicit_provider=provider,
)
provider = (cfg.selected_providers or [provider])[0]
logger.info(f"[video_gen] operation={operation_type}, provider={provider}, credential_source={cfg.credential_source.get(provider)}")
# Enforce authentication usage like text gen does # Enforce authentication usage like text gen does
if not user_id: if not user_id:
@@ -679,7 +680,7 @@ async def ai_video_generate(
try: try:
if operation_type == "text-to-video": if operation_type == "text-to-video":
if provider == "huggingface": if provider == "huggingface":
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs) video_bytes = _generate_with_huggingface(user_id=user_id, prompt=prompt, **kwargs)
result = { result = {
"video_bytes": video_bytes, "video_bytes": video_bytes,
"model_name": kwargs.get("model", "tencent/HunyuanVideo"), "model_name": kwargs.get("model", "tencent/HunyuanVideo"),

View File

@@ -0,0 +1,168 @@
from __future__ import annotations
import os
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from utils.logger_utils import get_service_logger
logger = get_service_logger("tenant_provider_config")
@dataclass
class TenantProviderConfig:
selected_providers: List[str]
model_policy: Dict[str, Optional[str]]
credential_source: Dict[str, str]
provider_keys: Dict[str, str] = field(default_factory=dict)
class TenantProviderConfigResolver:
"""Resolves per-request provider, model policy, and credential source.
Priority: tenant-scoped DB key (future vault hook) -> environment defaults.
"""
_PROVIDER_ALIASES: Dict[str, Tuple[str, ...]] = {
"gemini": ("gemini", "google", "google_api_key", "gemini_api_key"),
"google": ("gemini", "google", "google_api_key", "gemini_api_key"),
"huggingface": ("huggingface", "hf", "hf_token"),
"hf": ("huggingface", "hf", "hf_token"),
"stability": ("stability", "stability_api_key"),
"wavespeed": ("wavespeed", "wavespeed_api_key"),
"openai": ("openai", "openai_api_key"),
}
_ENV_VARS: Dict[str, Tuple[str, ...]] = {
"gemini": ("GEMINI_API_KEY", "GOOGLE_API_KEY"),
"huggingface": ("HF_TOKEN",),
"stability": ("STABILITY_API_KEY",),
"wavespeed": ("WAVESPEED_API_KEY",),
"openai": ("OPENAI_API_KEY",),
}
_ENV_PROVIDER_DEFAULTS: Dict[str, str] = {
"text": "GPT_PROVIDER",
"image": "GPT_PROVIDER",
"video": "VIDEO_PROVIDER",
"audio": "AUDIO_PROVIDER",
}
_DEFAULT_MODELS: Dict[Tuple[str, str], str] = {
("text", "google"): "gemini-2.0-flash-001",
("text", "huggingface"): "mistralai/Mistral-7B-Instruct-v0.3:groq",
("image", "wavespeed"): "qwen-image",
("image", "huggingface"): "black-forest-labs/FLUX.1-Krea-dev",
("video", "huggingface"): "tencent/HunyuanVideo",
("video", "wavespeed"): "hunyuan-video-1.5",
("audio", "wavespeed"): "minimax-speech-02-hd",
}
def resolve(self, modality: str, user_id: Optional[str], explicit_provider: Optional[str] = None) -> TenantProviderConfig:
provider_candidates = self._resolve_providers(modality=modality, explicit_provider=explicit_provider)
provider_keys: Dict[str, str] = {}
credential_source: Dict[str, str] = {}
for provider in provider_candidates:
key, source = self.resolve_provider_key(provider=provider, user_id=user_id)
if key:
provider_keys[provider] = key
credential_source[provider] = source
selected_providers = [p for p in provider_candidates if p in provider_keys]
if not selected_providers and provider_candidates:
selected_providers = [provider_candidates[0]]
model_policy = {
"modality": modality,
"default_model": self._DEFAULT_MODELS.get((modality, selected_providers[0]), None) if selected_providers else None,
"allow_fallback": True,
}
return TenantProviderConfig(
selected_providers=selected_providers,
model_policy=model_policy,
credential_source=credential_source,
provider_keys=provider_keys,
)
def resolve_provider_key(self, provider: str, user_id: Optional[str]) -> Tuple[Optional[str], str]:
normalized = self._normalize_provider(provider)
tenant_key = self._get_tenant_key_from_db(user_id=user_id, provider=normalized)
if tenant_key:
return tenant_key, "tenant_db"
env_key = self._get_key_from_env(normalized)
if env_key:
return env_key, "env_default"
return None, "missing"
def _resolve_providers(self, modality: str, explicit_provider: Optional[str]) -> List[str]:
if explicit_provider:
return [self._normalize_provider(explicit_provider)]
env_provider = os.getenv(self._ENV_PROVIDER_DEFAULTS.get(modality, ""), "").strip().lower()
if env_provider:
normalized = self._normalize_provider(env_provider)
return [normalized]
defaults = {
"text": ["google", "huggingface"],
"image": ["wavespeed", "gemini", "huggingface", "stability"],
"video": ["huggingface", "wavespeed"],
"audio": ["wavespeed"],
}
return defaults.get(modality, ["google"])
def _normalize_provider(self, provider: str) -> str:
provider_l = (provider or "").strip().lower()
if provider_l in ("gemini", "google"):
return "gemini"
if provider_l in ("hf", "huggingface", "hf_response_api"):
return "huggingface"
return provider_l
def _get_tenant_key_from_db(self, user_id: Optional[str], provider: str) -> Optional[str]:
if not user_id:
return None
try:
from services.database import get_session_for_user
from models.onboarding import OnboardingSession, APIKey
db = get_session_for_user(user_id)
if not db:
return None
try:
session = (
db.query(OnboardingSession)
.filter(OnboardingSession.user_id == user_id)
.order_by(OnboardingSession.updated_at.desc())
.first()
)
if not session:
return None
aliases = self._PROVIDER_ALIASES.get(provider, (provider,))
rec = (
db.query(APIKey)
.filter(APIKey.session_id == session.id, APIKey.provider.in_(aliases))
.order_by(APIKey.updated_at.desc())
.first()
)
return rec.key if rec and rec.key else None
finally:
db.close()
except Exception as exc:
logger.debug("Tenant DB key lookup failed for provider=%s user_id=%s: %s", provider, user_id, exc)
return None
def _get_key_from_env(self, provider: str) -> Optional[str]:
for env_var in self._ENV_VARS.get(provider, ()): # pragma: no branch
value = os.getenv(env_var)
if value:
return value
return None
tenant_provider_config_resolver = TenantProviderConfigResolver()