"feat:enhance-podcast-topic-ai"
This commit is contained in:
@@ -51,7 +51,7 @@ import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -76,6 +76,7 @@ logger = get_service_logger("huggingface_provider")
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
@@ -90,10 +91,10 @@ except ImportError:
|
||||
logger.warn("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
HF_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
]
|
||||
|
||||
|
||||
@@ -102,7 +103,7 @@ def _candidate_model_variants(model: str):
|
||||
if not model:
|
||||
return
|
||||
|
||||
# Try configured model first (supports provider suffixes like ":groq")
|
||||
# Try configured model first (supports provider suffixes like ":cerebras")
|
||||
yield model
|
||||
|
||||
# Fallback to base repo id when provider suffix is not recognized by the router
|
||||
@@ -112,8 +113,13 @@ def _candidate_model_variants(model: str):
|
||||
yield base_model
|
||||
|
||||
|
||||
def _fallback_model_sequence(model: str):
|
||||
sequence = [model] + HF_FALLBACK_MODELS
|
||||
def _fallback_model_sequence(model: str, fallback_models: Optional[List[str]] = None):
|
||||
# IMPORTANT: Do not apply implicit global fallback chains.
|
||||
# Callers must explicitly provide fallback_models when they want multi-model retries.
|
||||
if fallback_models:
|
||||
sequence = [model] + fallback_models
|
||||
else:
|
||||
sequence = [model]
|
||||
seen = set()
|
||||
for preferred_model in sequence:
|
||||
for candidate in _candidate_model_variants(preferred_model):
|
||||
@@ -121,6 +127,57 @@ def _fallback_model_sequence(model: str):
|
||||
seen.add(candidate)
|
||||
yield candidate
|
||||
|
||||
|
||||
def _is_non_retryable_hf_error(exc: Exception) -> bool:
|
||||
"""Skip retries for deterministic HF failures (e.g., unknown model ids, billing)."""
|
||||
msg = str(exc).lower()
|
||||
status = getattr(exc, "status_code", None)
|
||||
|
||||
# Non-retryable errors
|
||||
if isinstance(exc, NotFoundError) or "not found" in msg or "404" in msg:
|
||||
return True
|
||||
if status == 402 or "402" in msg or "depleted" in msg or "credits" in msg:
|
||||
return True
|
||||
if status == 401 or "unauthorized" in msg or "401" in msg:
|
||||
return True
|
||||
if status == 403 or "forbidden" in msg or "403" in msg:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _should_retry_hf_error(exc: Exception) -> bool:
    """Return True when *exc* is worth retrying.

    This is the logical inverse of ``_is_non_retryable_hf_error``.
    """
    if _is_non_retryable_hf_error(exc):
        return False
    return True
|
||||
|
||||
|
||||
def _classify_hf_error(exc: Exception) -> str:
|
||||
"""Classify HF failures for actionable logs."""
|
||||
msg = str(exc).lower()
|
||||
if any(token in msg for token in ["insufficient", "balance", "quota", "billing", "payment", "402"]):
|
||||
return "billing_or_quota"
|
||||
if "unauthorized" in msg or "forbidden" in msg or "401" in msg or "403" in msg:
|
||||
return "auth_or_permission"
|
||||
if "not found" in msg or "404" in msg:
|
||||
return "model_not_found"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _hf_error_details(exc: Exception) -> str:
|
||||
"""Return compact, actionable exception details for logs."""
|
||||
status = getattr(exc, "status_code", None)
|
||||
err_type = type(exc).__name__
|
||||
message = str(exc)
|
||||
raw_body = getattr(exc, "body", None)
|
||||
details = f"type={err_type}"
|
||||
if status is not None:
|
||||
details += f", status={status}"
|
||||
if message:
|
||||
details += f", message={message}"
|
||||
if raw_body:
|
||||
details += f", body={raw_body}"
|
||||
details += f", repr={repr(exc)}"
|
||||
return details
|
||||
|
||||
def get_huggingface_api_key() -> str:
|
||||
"""Get Hugging Face API key with proper error handling."""
|
||||
api_key = os.getenv('HF_TOKEN')
|
||||
@@ -137,10 +194,15 @@ def get_huggingface_api_key() -> str:
|
||||
|
||||
return api_key
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_hf_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
)
|
||||
def huggingface_text_response(
|
||||
prompt: str,
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
model: str = "openai/gpt-oss-120b:cerebras",
|
||||
fallback_models: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
top_p: float = 0.9,
|
||||
@@ -175,7 +237,7 @@ def huggingface_text_response(
|
||||
Example:
|
||||
result = huggingface_text_response(
|
||||
prompt="Write a blog post about AI",
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
model="openai/gpt-oss-120b:cerebras",
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
system_prompt="You are a professional content writer."
|
||||
@@ -194,7 +256,7 @@ def huggingface_text_response(
|
||||
|
||||
# Initialize Hugging Face client
|
||||
client = OpenAI(
|
||||
base_url=f"https://router.huggingface.co/hf/v1",
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ Hugging Face client initialized for text response")
|
||||
@@ -231,27 +293,14 @@ def huggingface_text_response(
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
if candidate_model != model:
|
||||
logger.warning("HF text generation switched to fallback model: {}", candidate_model)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
logger.warning("HF model not found: {}. Trying fallback model.", candidate_model)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
raise last_error or Exception("Hugging Face text generation failed: all fallback models failed")
|
||||
# Call exactly the requested model; no retries, no fallbacks, no variants
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
# Extract text from response
|
||||
generated_text = response.choices[0].message.content
|
||||
@@ -267,14 +316,31 @@ def huggingface_text_response(
|
||||
return generated_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Hugging Face text generation failed: {str(e)}")
|
||||
error_class = _classify_hf_error(e)
|
||||
error_details = _hf_error_details(e)
|
||||
logger.error(f"❌ Hugging Face text generation failed: {error_details}")
|
||||
|
||||
# Extra diagnostics: try to capture raw response if available
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
logger.error(f"🔍 HF Error Diagnostics:")
|
||||
logger.error(f" - Status: {e.response.status_code}")
|
||||
logger.error(f" - Headers: {dict(e.response.headers)}")
|
||||
try:
|
||||
body_json = e.response.json()
|
||||
logger.error(f" - Body JSON: {json.dumps(body_json, indent=2)}")
|
||||
except Exception:
|
||||
logger.error(f" - Body Raw: {e.response.text[:1000]}")
|
||||
else:
|
||||
logger.error(f"🔍 No HTTP response attached to exception object.")
|
||||
|
||||
raise Exception(f"Hugging Face text generation failed: {str(e)}")
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def huggingface_structured_json_response(
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
model: str = "openai/gpt-oss-120b:cerebras",
|
||||
fallback_models: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 8192,
|
||||
system_prompt: Optional[str] = None
|
||||
@@ -338,7 +404,7 @@ def huggingface_structured_json_response(
|
||||
# Initialize OpenAI client with Hugging Face base URL
|
||||
# Use standard Inference API endpoint
|
||||
client = OpenAI(
|
||||
base_url=f"https://router.huggingface.co/hf/v1",
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ Hugging Face client initialized for structured JSON response")
|
||||
@@ -387,7 +453,7 @@ def huggingface_structured_json_response(
|
||||
try:
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
@@ -444,7 +510,7 @@ def huggingface_structured_json_response(
|
||||
logger.info("Retrying without response_format...")
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
|
||||
Reference in New Issue
Block a user