Merge_PR_420_add_tenant_aware_provider_config_resolver_across_llm_facades
This commit is contained in:
@@ -13,12 +13,16 @@ from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from services.onboarding.api_key_manager import APIKeyManager
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
logger = get_service_logger("audio_generation")
|
||||
|
||||
|
||||
def _get_wavespeed_client(user_id: Optional[str]) -> WaveSpeedClient:
|
||||
key, _source = tenant_provider_config_resolver.resolve_provider_key("wavespeed", user_id=user_id)
|
||||
return WaveSpeedClient(api_key=key)
|
||||
|
||||
class AudioGenerationResult:
|
||||
"""Result of audio generation."""
|
||||
|
||||
@@ -165,7 +169,7 @@ def generate_audio(
|
||||
# Track response time
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
client = _get_wavespeed_client(user_id)
|
||||
audio_bytes = client.generate_speech(
|
||||
text=text,
|
||||
voice_id=voice_id,
|
||||
@@ -424,7 +428,7 @@ def clone_voice(
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
client = _get_wavespeed_client(user_id)
|
||||
preview_audio_bytes = client.voice_clone(
|
||||
audio_bytes=bytes(audio_bytes),
|
||||
custom_voice_id=custom_voice_id,
|
||||
@@ -617,7 +621,7 @@ def qwen3_voice_clone(
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
client = _get_wavespeed_client(user_id)
|
||||
preview_audio_bytes = client.qwen3_voice_clone(
|
||||
audio_bytes=bytes(audio_bytes),
|
||||
text=text,
|
||||
@@ -802,7 +806,7 @@ def qwen3_voice_design(
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
client = _get_wavespeed_client(user_id)
|
||||
preview_audio_bytes = client.voice_design(
|
||||
text=text,
|
||||
voice_description=voice_description,
|
||||
@@ -989,7 +993,7 @@ def cosyvoice_voice_clone(
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
client = WaveSpeedClient()
|
||||
client = _get_wavespeed_client(user_id)
|
||||
preview_audio_bytes = client.cosyvoice_voice_clone(
|
||||
audio_bytes=bytes(audio_bytes),
|
||||
text=text,
|
||||
|
||||
@@ -22,40 +22,34 @@ from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.facade")
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
# User requested WaveSpeed as default provider
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
if gpt_provider.startswith("hf"):
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
|
||||
# Fallback to huggingface to enable a path if configured
|
||||
return "huggingface"
|
||||
def _select_provider(explicit: Optional[str], user_id: Optional[str] = None) -> str:
|
||||
cfg = tenant_provider_config_resolver.resolve(
|
||||
modality="image",
|
||||
user_id=user_id,
|
||||
explicit_provider=explicit,
|
||||
)
|
||||
return (cfg.selected_providers or [explicit or "huggingface"])[0]
|
||||
|
||||
|
||||
def _get_provider(provider_name: str):
|
||||
def _get_provider(provider_name: str, user_id: Optional[str] = None):
|
||||
key, _source = tenant_provider_config_resolver.resolve_provider_key(provider_name, user_id=user_id)
|
||||
if provider_name == "huggingface":
|
||||
return HuggingFaceImageProvider()
|
||||
return HuggingFaceImageProvider(api_key=key)
|
||||
if provider_name == "gemini":
|
||||
if key:
|
||||
os.environ["GEMINI_API_KEY"] = key
|
||||
os.environ.setdefault("GOOGLE_API_KEY", key)
|
||||
return GeminiImageProvider()
|
||||
if provider_name == "stability":
|
||||
return StabilityImageProvider()
|
||||
return StabilityImageProvider(api_key=key)
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedImageProvider()
|
||||
return WaveSpeedImageProvider(api_key=key)
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
@@ -328,7 +322,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
log_prefix="[Image Generation]"
|
||||
)
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
provider_name = _select_provider(opts.get("provider"), user_id=user_id)
|
||||
|
||||
image_options = ImageGenerationOptions(
|
||||
prompt=prompt,
|
||||
@@ -370,7 +364,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
image_options.model = "qwen-image"
|
||||
|
||||
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
||||
provider = _get_provider(provider_name)
|
||||
provider = _get_provider(provider_name, user_id=user_id)
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
<<<<<<< HEAD
|
||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
<<<<<<< HEAD
|
||||
@@ -140,6 +141,12 @@ from .routing_policy import (
|
||||
PREMIUM_HF_MINIMAL_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
]
|
||||
=======
|
||||
|
||||
from .gemini_provider import gemini_text_response, gemini_structured_json_response
|
||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
>>>>>>> pr-420
|
||||
|
||||
|
||||
def llm_text_gen(
|
||||
@@ -199,6 +206,7 @@ def llm_text_gen(
|
||||
frequency_penalty = 0.0
|
||||
presence_penalty = 0.0
|
||||
|
||||
<<<<<<< HEAD
|
||||
# Check for GPT_PROVIDER environment variable
|
||||
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
||||
<<<<<<< HEAD
|
||||
@@ -212,6 +220,19 @@ def llm_text_gen(
|
||||
gpt_provider = "huggingface"
|
||||
model = PREMIUM_DEFAULT_MODEL
|
||||
>>>>>>> pr-417
|
||||
=======
|
||||
provider_cfg = tenant_provider_config_resolver.resolve(
|
||||
modality="text",
|
||||
user_id=user_id,
|
||||
)
|
||||
selected_provider = (provider_cfg.selected_providers or [None])[0]
|
||||
if selected_provider in ["gemini", "google"]:
|
||||
gpt_provider = "google"
|
||||
model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001"
|
||||
elif selected_provider == "huggingface":
|
||||
gpt_provider = "huggingface"
|
||||
model = provider_cfg.model_policy.get("default_model") or "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
>>>>>>> pr-420
|
||||
|
||||
# Determine if we're in strict mode (single provider) or fallback mode (multiple providers)
|
||||
strict_provider_mode = len(provider_list) == 1
|
||||
@@ -308,6 +329,7 @@ def llm_text_gen(
|
||||
blog_output_format = "markdown"
|
||||
blog_length = 2000
|
||||
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
# Check which providers have API keys available using APIKeyManager
|
||||
api_key_manager = APIKeyManager()
|
||||
@@ -406,6 +428,28 @@ def llm_text_gen(
|
||||
available_providers = get_available_text_providers(user_id)
|
||||
provider_sequence = _resolve_provider_sequence(preferred_provider, env_provider_raw, available_providers)
|
||||
>>>>>>> pr-416
|
||||
=======
|
||||
available_providers = []
|
||||
for provider in ("google", "huggingface"):
|
||||
if get_api_key(provider, user_id=user_id):
|
||||
available_providers.append(provider)
|
||||
|
||||
if gpt_provider not in available_providers:
|
||||
logger.warning(f"[llm_text_gen] Provider {gpt_provider} unavailable for user {user_id}, falling back.")
|
||||
if available_providers:
|
||||
gpt_provider = available_providers[0]
|
||||
else:
|
||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||
raise RuntimeError("No LLM API keys configured for tenant or environment defaults.")
|
||||
|
||||
# Ensure downstream provider clients (currently env-based) receive resolved key
|
||||
resolved_key = get_api_key(gpt_provider, user_id=user_id)
|
||||
if gpt_provider == "google" and resolved_key:
|
||||
os.environ["GEMINI_API_KEY"] = resolved_key
|
||||
os.environ.setdefault("GOOGLE_API_KEY", resolved_key)
|
||||
elif gpt_provider == "huggingface" and resolved_key:
|
||||
os.environ["HF_TOKEN"] = resolved_key
|
||||
>>>>>>> pr-420
|
||||
|
||||
<<<<<<< HEAD
|
||||
<<<<<<< HEAD
|
||||
@@ -958,6 +1002,7 @@ def check_gpt_provider(gpt_provider: str) -> bool:
|
||||
>>>>>>> pr-417
|
||||
|
||||
def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
|
||||
<<<<<<< HEAD
|
||||
"""Get API key for the specified provider, preferring tenant-scoped keys."""
|
||||
try:
|
||||
<<<<<<< HEAD
|
||||
@@ -978,6 +1023,17 @@ def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[st
|
||||
=======
|
||||
return get_tenant_api_key(user_id, gpt_provider)
|
||||
>>>>>>> pr-416
|
||||
=======
|
||||
"""Get API key for the specified provider."""
|
||||
try:
|
||||
provider_mapping = {
|
||||
"google": "gemini",
|
||||
"huggingface": "huggingface"
|
||||
}
|
||||
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
|
||||
key, _source = tenant_provider_config_resolver.resolve_provider_key(mapped_provider, user_id=user_id)
|
||||
return key
|
||||
>>>>>>> pr-420
|
||||
except Exception as e:
|
||||
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
|
||||
return None
|
||||
|
||||
@@ -25,10 +25,10 @@ except ImportError:
|
||||
HF_HUB_AVAILABLE = False
|
||||
InferenceClient = None
|
||||
|
||||
from ..onboarding.api_key_manager import APIKeyManager
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
logger = get_service_logger("video_generation_service")
|
||||
|
||||
@@ -202,16 +202,10 @@ def _track_video_operation_usage(
|
||||
return {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> Optional[str]:
|
||||
def _get_api_key(provider: str, user_id: Optional[str] = None) -> Optional[str]:
|
||||
try:
|
||||
manager = APIKeyManager()
|
||||
mapping = {
|
||||
"huggingface": "hf_token",
|
||||
"wavespeed": "wavespeed", # WaveSpeed API key
|
||||
"gemini": "gemini", # placeholder for Veo 3
|
||||
"openai": "openai_api_key", # placeholder for Sora
|
||||
}
|
||||
return manager.get_api_key(mapping.get(provider, provider))
|
||||
key, _source = tenant_provider_config_resolver.resolve_provider_key(provider, user_id=user_id)
|
||||
return key
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Failed to read API key for {provider}: {e}")
|
||||
return None
|
||||
@@ -297,6 +291,7 @@ def _coerce_video_bytes(output: Any) -> bytes:
|
||||
|
||||
|
||||
def _generate_with_huggingface(
|
||||
user_id: Optional[str],
|
||||
prompt: str,
|
||||
num_frames: int = 24 * 4,
|
||||
guidance_scale: float = 7.5,
|
||||
@@ -311,7 +306,7 @@ def _generate_with_huggingface(
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
token = _get_api_key("huggingface")
|
||||
token = _get_api_key("huggingface", user_id=user_id)
|
||||
if not token:
|
||||
raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.")
|
||||
|
||||
@@ -618,7 +613,13 @@ async def ai_video_generate(
|
||||
- height: Video height in pixels (for image-to-video)
|
||||
- metadata: Additional metadata dict
|
||||
"""
|
||||
logger.info(f"[video_gen] operation={operation_type}, provider={provider}")
|
||||
cfg = tenant_provider_config_resolver.resolve(
|
||||
modality="video",
|
||||
user_id=user_id,
|
||||
explicit_provider=provider,
|
||||
)
|
||||
provider = (cfg.selected_providers or [provider])[0]
|
||||
logger.info(f"[video_gen] operation={operation_type}, provider={provider}, credential_source={cfg.credential_source.get(provider)}")
|
||||
|
||||
# Enforce authentication usage like text gen does
|
||||
if not user_id:
|
||||
@@ -679,7 +680,7 @@ async def ai_video_generate(
|
||||
try:
|
||||
if operation_type == "text-to-video":
|
||||
if provider == "huggingface":
|
||||
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
|
||||
video_bytes = _generate_with_huggingface(user_id=user_id, prompt=prompt, **kwargs)
|
||||
result = {
|
||||
"video_bytes": video_bytes,
|
||||
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
<<<<<<< HEAD
|
||||
"""Tenant-aware provider configuration and API key resolution for LLM providers."""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -81,3 +82,173 @@ def get_available_text_providers(user_id: Optional[str]) -> list[str]:
|
||||
if get_tenant_api_key(user_id, "huggingface"):
|
||||
providers.append("huggingface")
|
||||
return providers
|
||||
=======
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("tenant_provider_config")
|
||||
|
||||
|
||||
@dataclass
|
||||
class TenantProviderConfig:
|
||||
selected_providers: List[str]
|
||||
model_policy: Dict[str, Optional[str]]
|
||||
credential_source: Dict[str, str]
|
||||
provider_keys: Dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
class TenantProviderConfigResolver:
|
||||
"""Resolves per-request provider, model policy, and credential source.
|
||||
|
||||
Priority: tenant-scoped DB key (future vault hook) -> environment defaults.
|
||||
"""
|
||||
|
||||
_PROVIDER_ALIASES: Dict[str, Tuple[str, ...]] = {
|
||||
"gemini": ("gemini", "google", "google_api_key", "gemini_api_key"),
|
||||
"google": ("gemini", "google", "google_api_key", "gemini_api_key"),
|
||||
"huggingface": ("huggingface", "hf", "hf_token"),
|
||||
"hf": ("huggingface", "hf", "hf_token"),
|
||||
"stability": ("stability", "stability_api_key"),
|
||||
"wavespeed": ("wavespeed", "wavespeed_api_key"),
|
||||
"openai": ("openai", "openai_api_key"),
|
||||
}
|
||||
|
||||
_ENV_VARS: Dict[str, Tuple[str, ...]] = {
|
||||
"gemini": ("GEMINI_API_KEY", "GOOGLE_API_KEY"),
|
||||
"huggingface": ("HF_TOKEN",),
|
||||
"stability": ("STABILITY_API_KEY",),
|
||||
"wavespeed": ("WAVESPEED_API_KEY",),
|
||||
"openai": ("OPENAI_API_KEY",),
|
||||
}
|
||||
|
||||
_ENV_PROVIDER_DEFAULTS: Dict[str, str] = {
|
||||
"text": "GPT_PROVIDER",
|
||||
"image": "GPT_PROVIDER",
|
||||
"video": "VIDEO_PROVIDER",
|
||||
"audio": "AUDIO_PROVIDER",
|
||||
}
|
||||
|
||||
_DEFAULT_MODELS: Dict[Tuple[str, str], str] = {
|
||||
("text", "google"): "gemini-2.0-flash-001",
|
||||
("text", "huggingface"): "mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
("image", "wavespeed"): "qwen-image",
|
||||
("image", "huggingface"): "black-forest-labs/FLUX.1-Krea-dev",
|
||||
("video", "huggingface"): "tencent/HunyuanVideo",
|
||||
("video", "wavespeed"): "hunyuan-video-1.5",
|
||||
("audio", "wavespeed"): "minimax-speech-02-hd",
|
||||
}
|
||||
|
||||
def resolve(self, modality: str, user_id: Optional[str], explicit_provider: Optional[str] = None) -> TenantProviderConfig:
|
||||
provider_candidates = self._resolve_providers(modality=modality, explicit_provider=explicit_provider)
|
||||
provider_keys: Dict[str, str] = {}
|
||||
credential_source: Dict[str, str] = {}
|
||||
|
||||
for provider in provider_candidates:
|
||||
key, source = self.resolve_provider_key(provider=provider, user_id=user_id)
|
||||
if key:
|
||||
provider_keys[provider] = key
|
||||
credential_source[provider] = source
|
||||
|
||||
selected_providers = [p for p in provider_candidates if p in provider_keys]
|
||||
if not selected_providers and provider_candidates:
|
||||
selected_providers = [provider_candidates[0]]
|
||||
|
||||
model_policy = {
|
||||
"modality": modality,
|
||||
"default_model": self._DEFAULT_MODELS.get((modality, selected_providers[0]), None) if selected_providers else None,
|
||||
"allow_fallback": True,
|
||||
}
|
||||
return TenantProviderConfig(
|
||||
selected_providers=selected_providers,
|
||||
model_policy=model_policy,
|
||||
credential_source=credential_source,
|
||||
provider_keys=provider_keys,
|
||||
)
|
||||
|
||||
def resolve_provider_key(self, provider: str, user_id: Optional[str]) -> Tuple[Optional[str], str]:
|
||||
normalized = self._normalize_provider(provider)
|
||||
|
||||
tenant_key = self._get_tenant_key_from_db(user_id=user_id, provider=normalized)
|
||||
if tenant_key:
|
||||
return tenant_key, "tenant_db"
|
||||
|
||||
env_key = self._get_key_from_env(normalized)
|
||||
if env_key:
|
||||
return env_key, "env_default"
|
||||
|
||||
return None, "missing"
|
||||
|
||||
def _resolve_providers(self, modality: str, explicit_provider: Optional[str]) -> List[str]:
|
||||
if explicit_provider:
|
||||
return [self._normalize_provider(explicit_provider)]
|
||||
|
||||
env_provider = os.getenv(self._ENV_PROVIDER_DEFAULTS.get(modality, ""), "").strip().lower()
|
||||
if env_provider:
|
||||
normalized = self._normalize_provider(env_provider)
|
||||
return [normalized]
|
||||
|
||||
defaults = {
|
||||
"text": ["google", "huggingface"],
|
||||
"image": ["wavespeed", "gemini", "huggingface", "stability"],
|
||||
"video": ["huggingface", "wavespeed"],
|
||||
"audio": ["wavespeed"],
|
||||
}
|
||||
return defaults.get(modality, ["google"])
|
||||
|
||||
def _normalize_provider(self, provider: str) -> str:
|
||||
provider_l = (provider or "").strip().lower()
|
||||
if provider_l in ("gemini", "google"):
|
||||
return "gemini"
|
||||
if provider_l in ("hf", "huggingface", "hf_response_api"):
|
||||
return "huggingface"
|
||||
return provider_l
|
||||
|
||||
def _get_tenant_key_from_db(self, user_id: Optional[str], provider: str) -> Optional[str]:
|
||||
if not user_id:
|
||||
return None
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from models.onboarding import OnboardingSession, APIKey
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
return None
|
||||
try:
|
||||
session = (
|
||||
db.query(OnboardingSession)
|
||||
.filter(OnboardingSession.user_id == user_id)
|
||||
.order_by(OnboardingSession.updated_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
aliases = self._PROVIDER_ALIASES.get(provider, (provider,))
|
||||
rec = (
|
||||
db.query(APIKey)
|
||||
.filter(APIKey.session_id == session.id, APIKey.provider.in_(aliases))
|
||||
.order_by(APIKey.updated_at.desc())
|
||||
.first()
|
||||
)
|
||||
return rec.key if rec and rec.key else None
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as exc:
|
||||
logger.debug("Tenant DB key lookup failed for provider=%s user_id=%s: %s", provider, user_id, exc)
|
||||
return None
|
||||
|
||||
def _get_key_from_env(self, provider: str) -> Optional[str]:
|
||||
for env_var in self._ENV_VARS.get(provider, ()): # pragma: no branch
|
||||
value = os.getenv(env_var)
|
||||
if value:
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
tenant_provider_config_resolver = TenantProviderConfigResolver()
|
||||
>>>>>>> pr-420
|
||||
|
||||
Reference in New Issue
Block a user