Add tenant-aware provider config resolver across LLM facades
This commit is contained in:
@@ -22,40 +22,34 @@ from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
|
||||
logger = get_service_logger("image_generation.facade")
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
# User requested WaveSpeed as default provider
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
if gpt_provider.startswith("hf"):
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
|
||||
# Fallback to huggingface to enable a path if configured
|
||||
return "huggingface"
|
||||
def _select_provider(explicit: Optional[str], user_id: Optional[str] = None) -> str:
|
||||
cfg = tenant_provider_config_resolver.resolve(
|
||||
modality="image",
|
||||
user_id=user_id,
|
||||
explicit_provider=explicit,
|
||||
)
|
||||
return (cfg.selected_providers or [explicit or "huggingface"])[0]
|
||||
|
||||
|
||||
def _get_provider(provider_name: str):
|
||||
def _get_provider(provider_name: str, user_id: Optional[str] = None):
|
||||
key, _source = tenant_provider_config_resolver.resolve_provider_key(provider_name, user_id=user_id)
|
||||
if provider_name == "huggingface":
|
||||
return HuggingFaceImageProvider()
|
||||
return HuggingFaceImageProvider(api_key=key)
|
||||
if provider_name == "gemini":
|
||||
if key:
|
||||
os.environ["GEMINI_API_KEY"] = key
|
||||
os.environ.setdefault("GOOGLE_API_KEY", key)
|
||||
return GeminiImageProvider()
|
||||
if provider_name == "stability":
|
||||
return StabilityImageProvider()
|
||||
return StabilityImageProvider(api_key=key)
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedImageProvider()
|
||||
return WaveSpeedImageProvider(api_key=key)
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
@@ -328,7 +322,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
log_prefix="[Image Generation]"
|
||||
)
|
||||
opts = options or {}
|
||||
provider_name = _select_provider(opts.get("provider"))
|
||||
provider_name = _select_provider(opts.get("provider"), user_id=user_id)
|
||||
|
||||
image_options = ImageGenerationOptions(
|
||||
prompt=prompt,
|
||||
@@ -370,7 +364,7 @@ def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_i
|
||||
image_options.model = "qwen-image"
|
||||
|
||||
logger.info("Generating image via provider=%s model=%s", provider_name, image_options.model)
|
||||
provider = _get_provider(provider_name)
|
||||
provider = _get_provider(provider_name, user_id=user_id)
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
|
||||
Reference in New Issue
Block a user