Add tenant-aware provider config resolver across LLM facades

This commit is contained in:
ي
2026-03-12 15:04:42 +05:30
parent b410ece4ca
commit feacbc6d59
5 changed files with 241 additions and 84 deletions

View File

@@ -22,40 +22,34 @@ from .image_generation.base import FaceSwapOptions, FaceSwapProvider
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
from utils.logger_utils import get_service_logger
from .tenant_provider_config import tenant_provider_config_resolver
logger = get_service_logger("image_generation.facade")
def _select_provider(explicit: Optional[str]) -> str:
if explicit:
return explicit
# User requested WaveSpeed as default provider
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
if gpt_provider.startswith("gemini"):
return "gemini"
if gpt_provider.startswith("hf"):
return "huggingface"
if os.getenv("STABILITY_API_KEY"):
return "stability"
# Fallback to huggingface to enable a path if configured
return "huggingface"
def _select_provider(explicit: Optional[str], user_id: Optional[str] = None) -> str:
cfg = tenant_provider_config_resolver.resolve(
modality="image",
user_id=user_id,
explicit_provider=explicit,
)
return (cfg.selected_providers or [explicit or "huggingface"])[0]
def _get_provider(provider_name: str):
def _get_provider(provider_name: str, user_id: Optional[str] = None):
key, _source = tenant_provider_config_resolver.resolve_provider_key(provider_name, user_id=user_id)
if provider_name == "huggingface":
return HuggingFaceImageProvider()
return HuggingFaceImageProvider(api_key=key)
if provider_name == "gemini":
if key:
os.environ["GEMINI_API_KEY"] = key
os.environ.setdefault("GOOGLE_API_KEY", key)
return GeminiImageProvider()
if provider_name == "stability":
return StabilityImageProvider()
return StabilityImageProvider(api_key=key)
if provider_name == "wavespeed":
return WaveSpeedImageProvider()
return WaveSpeedImageProvider(api_key=key)
raise ValueError(f"Unknown image provider: {provider_name}")
@@ -328,7 +322,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
log_prefix="[Image Generation]"
)
opts = options or {}
provider_name = _select_provider(opts.get("provider"))
provider_name = _select_provider(opts.get("provider"), user_id=user_id)
image_options = ImageGenerationOptions(
prompt=prompt,
@@ -370,7 +364,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
image_options.model = "qwen-image"
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
provider = _get_provider(provider_name)
provider = _get_provider(provider_name, user_id=user_id)
# Track response time
import time