diff --git a/backend/services/llm_providers/main_text_generation.py b/backend/services/llm_providers/main_text_generation.py index 4aeff637..d0105689 100644 --- a/backend/services/llm_providers/main_text_generation.py +++ b/backend/services/llm_providers/main_text_generation.py @@ -92,19 +92,38 @@ def llm_text_gen( # Determine provider based on env vars or tenant config if provider_list: primary_provider = provider_list[0] - if primary_provider in ['gemini', 'google']: + if primary_provider in ['wavespeed', 'wave']: + gpt_provider = "wavespeed" + model = os.getenv('WAVESPEED_TEXT_MODEL', 'Qwen/Qwen2.5-7B-Instruct') + elif primary_provider in ['gemini', 'google']: gpt_provider = "google" model = "gemini-2.0-flash-001" elif primary_provider in ['hf_response_api', 'huggingface', 'hf']: gpt_provider = "huggingface" model = "openai/gpt-oss-120b:cerebras" + elif primary_provider in ['openai', 'gpt']: + gpt_provider = "openai" + model = os.getenv('OPENAI_MODEL', 'gpt-4o-mini') + else: + logger.warning(f"[llm_text_gen] Unknown GPT_PROVIDER: {primary_provider}, using auto-select") + gpt_provider = None + model = None elif preferred_provider: - if preferred_provider in ['gemini', 'google']: + if preferred_provider in ['wavespeed', 'wave']: + gpt_provider = "wavespeed" + model = os.getenv('WAVESPEED_TEXT_MODEL', 'Qwen/Qwen2.5-7B-Instruct') + elif preferred_provider in ['openai', 'gpt']: + gpt_provider = "openai" + model = os.getenv('OPENAI_MODEL', 'gpt-4o-mini') + elif preferred_provider in ['gemini', 'google']: gpt_provider = "google" model = "gemini-2.0-flash-001" elif preferred_provider in ['hf_response_api', 'huggingface', 'hf']: gpt_provider = "huggingface" model = "openai/gpt-oss-120b:cerebras" + else: + gpt_provider = None + model = None else: # Fall back to tenant config provider_cfg = tenant_provider_config_resolver.resolve( @@ -190,9 +209,16 @@ def llm_text_gen( elif gpt_provider == "huggingface": provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking actual_provider_name = "huggingface" # Keep actual provider name for logs + elif gpt_provider == "wavespeed": + provider_enum = APIProvider.OPENAI # Map to OpenAI for tracking purposes + actual_provider_name = "wavespeed" + elif gpt_provider == "openai": + provider_enum = APIProvider.OPENAI + actual_provider_name = "openai" if not provider_enum: - raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking") + # For unknown providers, try to proceed without subscription tracking + logger.warning(f"[llm_text_gen] Unknown provider {gpt_provider}, proceeding without subscription check") # SUBSCRIPTION CHECK - Required and strict enforcement if not user_id: @@ -332,9 +358,19 @@ def llm_text_gen( top_p=top_p, system_prompt=system_instructions ) + elif gpt_provider == "wavespeed": + from services.llm_providers.wavespeed_provider import wavespeed_text_response + response_text = wavespeed_text_response( + prompt=prompt, + model=model or "Qwen/Qwen2.5-7B-Instruct", + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + system_prompt=system_instructions + ) else: logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}") - raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface") + raise RuntimeError(f"Unknown LLM provider: {gpt_provider}. Supported providers: google, huggingface, wavespeed") # TRACK USAGE after successful API call if response_text: