Make SIF fail fast and add low-cost remote LLM fallback
This commit is contained in:
@@ -82,11 +82,29 @@ from tenacity import (
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from openai import NotFoundError
|
||||
OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
OPENAI_AVAILABLE = False
|
||||
NotFoundError = Exception
|
||||
logger.warn("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
HF_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
]
|
||||
|
||||
|
||||
def _fallback_model_sequence(model: str):
|
||||
sequence = [model] + HF_FALLBACK_MODELS
|
||||
seen = set()
|
||||
for candidate in sequence:
|
||||
if candidate and candidate not in seen:
|
||||
seen.add(candidate)
|
||||
yield candidate
|
||||
|
||||
def get_huggingface_api_key() -> str:
|
||||
"""Get Hugging Face API key with proper error handling."""
|
||||
api_key = os.getenv('HF_TOKEN')
|
||||
@@ -197,14 +215,27 @@ def huggingface_text_response(
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
# Make the API call using Chat Completions
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
if candidate_model != model:
|
||||
logger.warning("HF text generation switched to fallback model: %s", candidate_model)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
logger.warning("HF model not found: %s. Trying fallback model.", candidate_model)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
raise last_error or Exception("Hugging Face text generation failed: all fallback models failed")
|
||||
|
||||
# Extract text from response
|
||||
generated_text = response.choices[0].message.content
|
||||
@@ -338,13 +369,27 @@ def huggingface_structured_json_response(
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
response_format={"type": "json_object"} # Try to enforce JSON mode if supported
|
||||
)
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
response_format={"type": "json_object"} # Try to enforce JSON mode if supported
|
||||
)
|
||||
if candidate_model != model:
|
||||
logger.warning("HF structured generation switched to fallback model: %s", candidate_model)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
logger.warning("HF structured model not found: %s. Trying fallback model.", candidate_model)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
raise last_error or Exception("Hugging Face structured generation failed: all fallback models failed")
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
|
||||
@@ -379,14 +424,28 @@ def huggingface_structured_json_response(
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Hugging Face API call failed: {e}")
|
||||
# If 422 Unprocessable Entity (often due to response_format not supported), retry without it
|
||||
if "422" in str(e) or "not supported" in str(e).lower():
|
||||
if "422" in str(e) or "not supported" in str(e).lower() or isinstance(e, NotFoundError):
|
||||
logger.info("Retrying without response_format...")
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
if candidate_model != model:
|
||||
logger.warning("HF structured no-response_format fallback model: %s", candidate_model)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
logger.warning("HF structured model not found (no response_format path): %s", candidate_model)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
raise last_error or e
|
||||
response_text = response.choices[0].message.content
|
||||
# ... (same parsing logic would apply, simplified here for brevity)
|
||||
try:
|
||||
|
||||
@@ -6,7 +6,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.
|
||||
|
||||
import os
|
||||
import json
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
@@ -16,7 +16,13 @@ from .gemini_provider import gemini_text_response, gemini_structured_json_respon
|
||||
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
|
||||
|
||||
|
||||
def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct: Optional[Dict[str, Any]] = None, user_id: str = None) -> str:
|
||||
def llm_text_gen(
|
||||
prompt: str,
|
||||
system_prompt: Optional[str] = None,
|
||||
json_struct: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None,
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
|
||||
@@ -54,7 +60,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
|
||||
# Default blog characteristics
|
||||
blog_tone = "Professional"
|
||||
@@ -80,7 +86,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
else:
|
||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
||||
@@ -93,9 +99,13 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
else:
|
||||
raise RuntimeError("No supported providers available.")
|
||||
|
||||
if gpt_provider == "huggingface" and preferred_hf_models:
|
||||
model = preferred_hf_models[0]
|
||||
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
|
||||
@@ -303,7 +313,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
elif fallback_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3"
|
||||
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
|
||||
if fallback_provider == "google":
|
||||
if json_struct:
|
||||
@@ -330,7 +340,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3",
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
@@ -338,7 +348,7 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
else:
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3",
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
@@ -394,4 +404,4 @@ def get_api_key(gpt_provider: str) -> Optional[str]:
|
||||
return api_key_manager.get_api_key(mapped_provider)
|
||||
except Exception as e:
|
||||
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
|
||||
return None
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user