diff --git a/backend/services/llm_providers/main_audio_generation.py b/backend/services/llm_providers/main_audio_generation.py index 0621c958..a54454aa 100644 --- a/backend/services/llm_providers/main_audio_generation.py +++ b/backend/services/llm_providers/main_audio_generation.py @@ -13,12 +13,16 @@ from loguru import logger from fastapi import HTTPException from services.wavespeed.client import WaveSpeedClient -from services.onboarding.api_key_manager import APIKeyManager from utils.logger_utils import get_service_logger +from .tenant_provider_config import tenant_provider_config_resolver logger = get_service_logger("audio_generation") +def _get_wavespeed_client(user_id: Optional[str]) -> WaveSpeedClient: + key, _source = tenant_provider_config_resolver.resolve_provider_key("wavespeed", user_id=user_id) + return WaveSpeedClient(api_key=key) + class AudioGenerationResult: """Result of audio generation.""" @@ -165,7 +169,7 @@ def generate_audio( # Track response time import time start_time = time.time() - client = WaveSpeedClient() + client = _get_wavespeed_client(user_id) audio_bytes = client.generate_speech( text=text, voice_id=voice_id, @@ -424,7 +428,7 @@ def clone_voice( import time start_time = time.time() - client = WaveSpeedClient() + client = _get_wavespeed_client(user_id) preview_audio_bytes = client.voice_clone( audio_bytes=bytes(audio_bytes), custom_voice_id=custom_voice_id, @@ -617,7 +621,7 @@ def qwen3_voice_clone( import time start_time = time.time() - client = WaveSpeedClient() + client = _get_wavespeed_client(user_id) preview_audio_bytes = client.qwen3_voice_clone( audio_bytes=bytes(audio_bytes), text=text, @@ -802,7 +806,7 @@ def qwen3_voice_design( import time start_time = time.time() - client = WaveSpeedClient() + client = _get_wavespeed_client(user_id) preview_audio_bytes = client.voice_design( text=text, voice_description=voice_description, @@ -989,7 +993,7 @@ def cosyvoice_voice_clone( import time start_time = time.time() - client = WaveSpeedClient() + client = _get_wavespeed_client(user_id) preview_audio_bytes = client.cosyvoice_voice_clone( audio_bytes=bytes(audio_bytes), text=text, diff --git a/backend/services/llm_providers/main_image_generation.py b/backend/services/llm_providers/main_image_generation.py index f9140b8a..ad8042de 100644 --- a/backend/services/llm_providers/main_image_generation.py +++ b/backend/services/llm_providers/main_image_generation.py @@ -22,40 +22,34 @@ from .image_generation.base import FaceSwapOptions, FaceSwapProvider from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider from utils.logger_utils import get_service_logger +from .tenant_provider_config import tenant_provider_config_resolver logger = get_service_logger("image_generation.facade") -def _select_provider(explicit: Optional[str]) -> str: - if explicit: - return explicit - - # User requested WaveSpeed as default provider - if os.getenv("WAVESPEED_API_KEY"): - return "wavespeed" - - gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower() - if gpt_provider.startswith("gemini"): - return "gemini" - if gpt_provider.startswith("hf"): - return "huggingface" - if os.getenv("STABILITY_API_KEY"): - return "stability" - - # Fallback to huggingface to enable a path if configured - return "huggingface" +def _select_provider(explicit: Optional[str], user_id: Optional[str] = None) -> str: + cfg = tenant_provider_config_resolver.resolve( + modality="image", + user_id=user_id, + explicit_provider=explicit, + ) + return (cfg.selected_providers or [explicit or "huggingface"])[0] -def _get_provider(provider_name: str): +def _get_provider(provider_name: str, user_id: Optional[str] = None): + key, _source = tenant_provider_config_resolver.resolve_provider_key(provider_name, user_id=user_id) if provider_name == "huggingface": - return HuggingFaceImageProvider() + return HuggingFaceImageProvider(api_key=key) if provider_name == "gemini": + if key: + os.environ["GEMINI_API_KEY"] = key + os.environ.setdefault("GOOGLE_API_KEY", key) return GeminiImageProvider() if provider_name == "stability": - return StabilityImageProvider() + return StabilityImageProvider(api_key=key) if provider_name == "wavespeed": - return WaveSpeedImageProvider() + return WaveSpeedImageProvider(api_key=key) raise ValueError(f"Unknown image provider: {provider_name}") @@ -328,7 +322,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i log_prefix="[Image Generation]" ) opts = options or {} - provider_name = _select_provider(opts.get("provider")) + provider_name = _select_provider(opts.get("provider"), user_id=user_id) image_options = ImageGenerationOptions( prompt=prompt, @@ -370,7 +364,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i image_options.model = "qwen-image" logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model) - provider = _get_provider(provider_name) + provider = _get_provider(provider_name, user_id=user_id) # Track response time import time diff --git a/backend/services/llm_providers/main_text_generation.py b/backend/services/llm_providers/main_text_generation.py index dd4ec672..2e766d29 100644 --- a/backend/services/llm_providers/main_text_generation.py +++ b/backend/services/llm_providers/main_text_generation.py @@ -10,10 +10,10 @@ from typing import Optional, Dict, Any, List from datetime import datetime from loguru import logger from fastapi import HTTPException -from ..onboarding.api_key_manager import APIKeyManager from .gemini_provider import gemini_text_response, gemini_structured_json_response from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response +from .tenant_provider_config import tenant_provider_config_resolver def llm_text_gen( @@ -53,14 +53,17 @@ def llm_text_gen( frequency_penalty = 0.0 presence_penalty = 0.0 - # Check for GPT_PROVIDER environment variable - env_provider = os.getenv('GPT_PROVIDER', '').lower() - if env_provider in ['gemini', 'google']: + provider_cfg = tenant_provider_config_resolver.resolve( + modality="text", + user_id=user_id, + ) + selected_provider = (provider_cfg.selected_providers or [None])[0] + if selected_provider in ["gemini", "google"]: gpt_provider = "google" - model = "gemini-2.0-flash-001" - elif env_provider in ['hf_response_api', 'huggingface', 'hf']: + model = provider_cfg.model_policy.get("default_model") or "gemini-2.0-flash-001" + elif selected_provider == "huggingface": gpt_provider = "huggingface" - model = "mistralai/Mistral-7B-Instruct-v0.3:groq" + model = provider_cfg.model_policy.get("default_model") or "mistralai/Mistral-7B-Instruct-v0.3:groq" # Default blog characteristics blog_tone = "Professional" @@ -70,38 +73,26 @@ def llm_text_gen( blog_output_format = "markdown" blog_length = 2000 - # Check which providers have API keys available using APIKeyManager - api_key_manager = APIKeyManager() available_providers = [] - if api_key_manager.get_api_key("gemini"): - available_providers.append("google") - if api_key_manager.get_api_key("hf_token"): - available_providers.append("huggingface") - - # If no environment variable set, auto-detect based on available keys - if not env_provider: - # Prefer Google Gemini if available, otherwise use Hugging Face - if "google" in available_providers: - gpt_provider = "google" - model = "gemini-2.0-flash-001" - elif "huggingface" in available_providers: - gpt_provider = "huggingface" - model = "mistralai/Mistral-7B-Instruct-v0.3:groq" + for provider in ("google", "huggingface"): + if get_api_key(provider, user_id=user_id): + available_providers.append(provider) + + if gpt_provider not in available_providers: + logger.warning(f"[llm_text_gen] Provider {gpt_provider} unavailable for user {user_id}, falling back.") + if available_providers: + gpt_provider = available_providers[0] else: logger.error("[llm_text_gen] No API keys found for supported providers.") - raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.") - else: - # Environment variable was set, validate it's supported - if gpt_provider not in available_providers: - logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers") - if "google" in available_providers: - gpt_provider = "google" - model = "gemini-2.0-flash-001" - elif "huggingface" in available_providers: - gpt_provider = "huggingface" - model = "mistralai/Mistral-7B-Instruct-v0.3:groq" - else: - raise RuntimeError("No supported providers available.") + raise RuntimeError("No LLM API keys configured for tenant or environment defaults.") + + # Ensure downstream provider clients (currently env-based) receive resolved key + resolved_key = get_api_key(gpt_provider, user_id=user_id) + if gpt_provider == "google" and resolved_key: + os.environ["GEMINI_API_KEY"] = resolved_key + os.environ.setdefault("GOOGLE_API_KEY", resolved_key) + elif gpt_provider == "huggingface" and resolved_key: + os.environ["HF_TOKEN"] = resolved_key if gpt_provider == "huggingface" and preferred_hf_models: model = preferred_hf_models[0] @@ -391,17 +382,16 @@ def check_gpt_provider(gpt_provider: str) -> bool: supported_providers = ["google", "huggingface"] return gpt_provider in supported_providers -def get_api_key(gpt_provider: str) -> Optional[str]: +def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]: """Get API key for the specified provider.""" try: - api_key_manager = APIKeyManager() provider_mapping = { "google": "gemini", - "huggingface": "hf_token" + "huggingface": "huggingface" } - mapped_provider = provider_mapping.get(gpt_provider, gpt_provider) - return api_key_manager.get_api_key(mapped_provider) + key, _source = tenant_provider_config_resolver.resolve_provider_key(mapped_provider, user_id=user_id) + return key except Exception as e: logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}") return None diff --git a/backend/services/llm_providers/main_video_generation.py b/backend/services/llm_providers/main_video_generation.py index 70f617f2..e1fa1ef6 100644 --- a/backend/services/llm_providers/main_video_generation.py +++ b/backend/services/llm_providers/main_video_generation.py @@ -25,10 +25,10 @@ except ImportError: HF_HUB_AVAILABLE = False InferenceClient = None -from ..onboarding.api_key_manager import APIKeyManager from services.subscription import PricingService from services.subscription.provider_detection import detect_actual_provider from utils.logger_utils import get_service_logger +from .tenant_provider_config import tenant_provider_config_resolver logger = get_service_logger("video_generation_service") @@ -202,16 +202,10 @@ def _track_video_operation_usage( return {} -def _get_api_key(provider: str) -> Optional[str]: +def _get_api_key(provider: str, user_id: Optional[str] = None) -> Optional[str]: try: - manager = APIKeyManager() - mapping = { - "huggingface": "hf_token", - "wavespeed": "wavespeed", # WaveSpeed API key - "gemini": "gemini", # placeholder for Veo 3 - "openai": "openai_api_key", # placeholder for Sora - } - return manager.get_api_key(mapping.get(provider, provider)) + key, _source = tenant_provider_config_resolver.resolve_provider_key(provider, user_id=user_id) + return key except Exception as e: logger.error(f"[video_gen] Failed to read API key for {provider}: {e}") return None @@ -297,6 +291,7 @@ def _coerce_video_bytes(output: Any) -> bytes: def _generate_with_huggingface( + user_id: Optional[str], prompt: str, num_frames: int = 24 * 4, guidance_scale: float = 7.5, @@ -311,7 +306,7 @@ def _generate_with_huggingface( if not HF_HUB_AVAILABLE: raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub") - token = _get_api_key("huggingface") + token = _get_api_key("huggingface", user_id=user_id) if not token: raise RuntimeError("HF token not configured. Set an hf_token in APIKeyManager.") @@ -618,7 +613,13 @@ async def ai_video_generate( - height: Video height in pixels (for image-to-video) - metadata: Additional metadata dict """ - logger.info(f"[video_gen] operation={operation_type}, provider={provider}") + cfg = tenant_provider_config_resolver.resolve( + modality="video", + user_id=user_id, + explicit_provider=provider, + ) + provider = (cfg.selected_providers or [provider])[0] + logger.info(f"[video_gen] operation={operation_type}, provider={provider}, credential_source={cfg.credential_source.get(provider)}") # Enforce authentication usage like text gen does if not user_id: @@ -679,7 +680,7 @@ async def ai_video_generate( try: if operation_type == "text-to-video": if provider == "huggingface": - video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs) + video_bytes = _generate_with_huggingface(user_id=user_id, prompt=prompt, **kwargs) result = { "video_bytes": video_bytes, "model_name": kwargs.get("model", "tencent/HunyuanVideo"), diff --git a/backend/services/llm_providers/tenant_provider_config.py b/backend/services/llm_providers/tenant_provider_config.py new file mode 100644 index 00000000..c9cf5f3a --- /dev/null +++ b/backend/services/llm_providers/tenant_provider_config.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import os +from dataclasses import dataclass, field +from typing import Dict, List, Optional, Tuple + +from utils.logger_utils import get_service_logger + +logger = get_service_logger("tenant_provider_config") + + +@dataclass +class TenantProviderConfig: + selected_providers: List[str] + model_policy: Dict[str, Optional[str]] + credential_source: Dict[str, str] + provider_keys: Dict[str, str] = field(default_factory=dict) + + +class TenantProviderConfigResolver: + """Resolves per-request provider, model policy, and credential source. + + Priority: tenant-scoped DB key (future vault hook) -> environment defaults. + """ + + _PROVIDER_ALIASES: Dict[str, Tuple[str, ...]] = { + "gemini": ("gemini", "google", "google_api_key", "gemini_api_key"), + "google": ("gemini", "google", "google_api_key", "gemini_api_key"), + "huggingface": ("huggingface", "hf", "hf_token"), + "hf": ("huggingface", "hf", "hf_token"), + "stability": ("stability", "stability_api_key"), + "wavespeed": ("wavespeed", "wavespeed_api_key"), + "openai": ("openai", "openai_api_key"), + } + + _ENV_VARS: Dict[str, Tuple[str, ...]] = { + "gemini": ("GEMINI_API_KEY", "GOOGLE_API_KEY"), + "huggingface": ("HF_TOKEN",), + "stability": ("STABILITY_API_KEY",), + "wavespeed": ("WAVESPEED_API_KEY",), + "openai": ("OPENAI_API_KEY",), + } + + _ENV_PROVIDER_DEFAULTS: Dict[str, str] = { + "text": "GPT_PROVIDER", + "image": "GPT_PROVIDER", + "video": "VIDEO_PROVIDER", + "audio": "AUDIO_PROVIDER", + } + + _DEFAULT_MODELS: Dict[Tuple[str, str], str] = { + ("text", "google"): "gemini-2.0-flash-001", + ("text", "huggingface"): "mistralai/Mistral-7B-Instruct-v0.3:groq", + ("image", "wavespeed"): "qwen-image", + ("image", "huggingface"): "black-forest-labs/FLUX.1-Krea-dev", + ("video", "huggingface"): "tencent/HunyuanVideo", + ("video", "wavespeed"): "hunyuan-video-1.5", + ("audio", "wavespeed"): "minimax-speech-02-hd", + } + + def resolve(self, modality: str, user_id: Optional[str], explicit_provider: Optional[str] = None) -> TenantProviderConfig: + provider_candidates = self._resolve_providers(modality=modality, explicit_provider=explicit_provider) + provider_keys: Dict[str, str] = {} + credential_source: Dict[str, str] = {} + + for provider in provider_candidates: + key, source = self.resolve_provider_key(provider=provider, user_id=user_id) + if key: + provider_keys[provider] = key + credential_source[provider] = source + + selected_providers = [p for p in provider_candidates if p in provider_keys] + if not selected_providers and provider_candidates: + selected_providers = [provider_candidates[0]] + + model_policy = { + "modality": modality, + "default_model": self._DEFAULT_MODELS.get((modality, selected_providers[0]), None) if selected_providers else None, + "allow_fallback": True, + } + return TenantProviderConfig( + selected_providers=selected_providers, + model_policy=model_policy, + credential_source=credential_source, + provider_keys=provider_keys, + ) + + def resolve_provider_key(self, provider: str, user_id: Optional[str]) -> Tuple[Optional[str], str]: + normalized = self._normalize_provider(provider) + + tenant_key = self._get_tenant_key_from_db(user_id=user_id, provider=normalized) + if tenant_key: + return tenant_key, "tenant_db" + + env_key = self._get_key_from_env(normalized) + if env_key: + return env_key, "env_default" + + return None, "missing" + + def _resolve_providers(self, modality: str, explicit_provider: Optional[str]) -> List[str]: + if explicit_provider: + return [self._normalize_provider(explicit_provider)] + + env_provider = os.getenv(self._ENV_PROVIDER_DEFAULTS.get(modality, ""), "").strip().lower() + if env_provider: + normalized = self._normalize_provider(env_provider) + return [normalized] + + defaults = { + "text": ["google", "huggingface"], + "image": ["wavespeed", "gemini", "huggingface", "stability"], + "video": ["huggingface", "wavespeed"], + "audio": ["wavespeed"], + } + return defaults.get(modality, ["google"]) + + def _normalize_provider(self, provider: str) -> str: + provider_l = (provider or "").strip().lower() + if provider_l in ("gemini", "google"): + return "gemini" + if provider_l in ("hf", "huggingface", "hf_response_api"): + return "huggingface" + return provider_l + + def _get_tenant_key_from_db(self, user_id: Optional[str], provider: str) -> Optional[str]: + if not user_id: + return None + try: + from services.database import get_session_for_user + from models.onboarding import OnboardingSession, APIKey + + db = get_session_for_user(user_id) + if not db: + return None + try: + session = ( + db.query(OnboardingSession) + .filter(OnboardingSession.user_id == user_id) + .order_by(OnboardingSession.updated_at.desc()) + .first() + ) + if not session: + return None + + aliases = self._PROVIDER_ALIASES.get(provider, (provider,)) + rec = ( + db.query(APIKey) + .filter(APIKey.session_id == session.id, APIKey.provider.in_(aliases)) + .order_by(APIKey.updated_at.desc()) + .first() + ) + return rec.key if rec and rec.key else None + finally: + db.close() + except Exception as exc: + logger.debug("Tenant DB key lookup failed for provider=%s user_id=%s: %s", provider, user_id, exc) + return None + + def _get_key_from_env(self, provider: str) -> Optional[str]: + for env_var in self._ENV_VARS.get(provider, ()): # pragma: no branch + value = os.getenv(env_var) + if value: + return value + return None + + +tenant_provider_config_resolver = TenantProviderConfigResolver()