Fix TEXTGEN_AI_MODELS full-name mapping and unify model resolution

This commit is contained in:
ي
2026-03-12 15:02:47 +05:30
parent b410ece4ca
commit 4b7f443509
5 changed files with 436 additions and 376 deletions

View File

@@ -47,83 +47,49 @@ Last Updated: January 2025
""" """
import os import os
import sys
from pathlib import Path
import json import json
import re import re
from functools import lru_cache
from typing import Optional, Dict, Any from typing import Optional, Dict, Any
from dotenv import load_dotenv
# Fix the environment loading path - load from backend directory
current_dir = Path(__file__).parent.parent # services directory
backend_dir = current_dir.parent # backend directory
env_path = backend_dir / '.env'
if env_path.exists():
load_dotenv(env_path)
print(f"Loaded .env from: {env_path}")
else:
# Fallback to current directory
load_dotenv()
print(f"No .env found at {env_path}, using current directory")
from loguru import logger from loguru import logger
from utils.logger_utils import get_service_logger from utils.logger_utils import get_service_logger
# Use service-specific logger to avoid conflicts # Use service-specific logger to avoid conflicts
logger = get_service_logger("huggingface_provider") logger = get_service_logger("huggingface_provider")
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)
try: try:
from openai import OpenAI from openai import OpenAI
from openai import NotFoundError
OPENAI_AVAILABLE = True OPENAI_AVAILABLE = True
except ImportError: except ImportError:
OPENAI_AVAILABLE = False OPENAI_AVAILABLE = False
NotFoundError = Exception
logger.warn("OpenAI library not available. Install with: pip install openai") logger.warn("OpenAI library not available. Install with: pip install openai")
HF_FALLBACK_MODELS = [
"openai/gpt-oss-120b:groq",
"moonshotai/Kimi-K2-Instruct-0905:groq",
"meta-llama/Llama-3.1-8B-Instruct:groq",
"mistralai/Mistral-7B-Instruct-v0.3:groq",
]
def _candidate_model_variants(model: str): def _classify_hf_error(error: Exception) -> str:
"""Yield model ids to try for a single logical model preference.""" message = str(error or "").lower()
if not model: if any(x in message for x in ["insufficient", "quota", "billing", "payment", "credits", "balance"]):
return return "billing_or_quota"
if any(x in message for x in ["unauthorized", "forbidden", "permission", "invalid api key", "authentication"]):
# Try configured model first (supports provider suffixes like ":groq") return "auth_or_permission"
yield model if ("not found" in message) or ("404" in message):
return "model_not_found"
# Fallback to base repo id when provider suffix is not recognized by the router return "other"
if ":" in model:
base_model = model.split(":", 1)[0]
if base_model:
yield base_model
def _fallback_model_sequence(model: str): def _error_details(error: Exception) -> Dict[str, str]:
sequence = [model] + HF_FALLBACK_MODELS return {
seen = set() "type": type(error).__name__,
for preferred_model in sequence: "message": str(error),
for candidate in _candidate_model_variants(preferred_model): "repr": repr(error),
if candidate and candidate not in seen: }
seen.add(candidate)
yield candidate
def get_huggingface_api_key() -> str:
def get_huggingface_api_key(explicit_api_key: Optional[str] = None) -> str:
"""Get Hugging Face API key with proper error handling.""" """Get Hugging Face API key with proper error handling."""
api_key = os.getenv('HF_TOKEN') api_key = explicit_api_key or os.getenv('HF_TOKEN')
if not api_key: if not api_key:
error_msg = "HF_TOKEN environment variable is not set. Please set it in your .env file." error_msg = "HF_TOKEN environment variable is not set. Please set it in your .env file."
logger.error(error_msg) logger.error(error_msg)
@@ -137,14 +103,19 @@ def get_huggingface_api_key() -> str:
return api_key return api_key
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) @lru_cache(maxsize=16)
def _get_hf_client(api_key: str):
return OpenAI(base_url="https://router.huggingface.co/v1", api_key=api_key)
def huggingface_text_response( def huggingface_text_response(
prompt: str, prompt: str,
model: str = "openai/gpt-oss-120b:groq", model: str = "openai/gpt-oss-120b:groq",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 2048, max_tokens: int = 2048,
top_p: float = 0.9, top_p: float = 0.9,
system_prompt: Optional[str] = None system_prompt: Optional[str] = None,
api_key: Optional[str] = None,
) -> str: ) -> str:
""" """
Generate text response using Hugging Face Inference Providers API. Generate text response using Hugging Face Inference Providers API.
@@ -186,17 +157,14 @@ def huggingface_text_response(
raise ImportError("OpenAI library not available. Install with: pip install openai") raise ImportError("OpenAI library not available. Install with: pip install openai")
# Get API key with proper error handling # Get API key with proper error handling
api_key = get_huggingface_api_key() api_key = get_huggingface_api_key(api_key)
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})") logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
if not api_key: if not api_key:
raise Exception("HF_TOKEN not found in environment variables") raise Exception("HF_TOKEN not found in environment variables")
# Initialize Hugging Face client # Initialize Hugging Face client
client = OpenAI( client = _get_hf_client(api_key)
base_url=f"https://router.huggingface.co/hf/v1",
api_key=api_key,
)
logger.info("✅ Hugging Face client initialized for text response") logger.info("✅ Hugging Face client initialized for text response")
# Prepare input for the API # Prepare input for the API
@@ -227,31 +195,13 @@ def huggingface_text_response(
logger.info("🚀 Making Hugging Face API call (chat completion)...") logger.info("🚀 Making Hugging Face API call (chat completion)...")
# Add rate limiting to prevent expensive API calls response = client.chat.completions.create(
import time model=model,
time.sleep(1) # 1 second delay between API calls messages=messages,
temperature=temperature,
response = None top_p=top_p,
last_error = None max_tokens=max_tokens
for candidate_model in _fallback_model_sequence(model): )
try:
response = client.chat.completions.create(
model=candidate_model,
messages=messages,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens
)
if candidate_model != model:
logger.warning("HF text generation switched to fallback model: {}", candidate_model)
break
except NotFoundError as nf_err:
last_error = nf_err
logger.warning("HF model not found: {}. Trying fallback model.", candidate_model)
continue
if response is None:
raise last_error or Exception("Hugging Face text generation failed: all fallback models failed")
# Extract text from response # Extract text from response
generated_text = response.choices[0].message.content generated_text = response.choices[0].message.content
@@ -263,21 +213,23 @@ def huggingface_text_response(
generated_text = re.sub(r'```\n?', '', generated_text) generated_text = re.sub(r'```\n?', '', generated_text)
generated_text = generated_text.strip() generated_text = generated_text.strip()
logger.info(f"✅ Hugging Face text response generated successfully (length: {len(generated_text)})") logger.info("✅ Hugging Face text response generated successfully (length: {})", len(generated_text))
return generated_text return generated_text
except Exception as e: except Exception as e:
logger.error(f"❌ Hugging Face text generation failed: {str(e)}") error_class = _classify_hf_error(e)
details = _error_details(e)
logger.error("❌ Hugging Face text generation failed | error_class={} | type={} | message={} | repr={}", error_class, details["type"], details["message"], details["repr"])
raise Exception(f"Hugging Face text generation failed: {str(e)}") raise Exception(f"Hugging Face text generation failed: {str(e)}")
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def huggingface_structured_json_response( def huggingface_structured_json_response(
prompt: str, prompt: str,
schema: Dict[str, Any], schema: Dict[str, Any],
model: str = "openai/gpt-oss-120b:groq", model: str = "openai/gpt-oss-120b:groq",
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 8192, max_tokens: int = 8192,
system_prompt: Optional[str] = None system_prompt: Optional[str] = None,
api_key: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Generate structured JSON response using Hugging Face Inference Providers API. Generate structured JSON response using Hugging Face Inference Providers API.
@@ -329,7 +281,7 @@ def huggingface_structured_json_response(
raise ImportError("OpenAI library not available. Install with: pip install openai") raise ImportError("OpenAI library not available. Install with: pip install openai")
# Get API key with proper error handling # Get API key with proper error handling
api_key = get_huggingface_api_key() api_key = get_huggingface_api_key(api_key)
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})") logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
if not api_key: if not api_key:
@@ -337,10 +289,7 @@ def huggingface_structured_json_response(
# Initialize OpenAI client with Hugging Face base URL # Initialize OpenAI client with Hugging Face base URL
# Use standard Inference API endpoint # Use standard Inference API endpoint
client = OpenAI( client = _get_hf_client(api_key)
base_url=f"https://router.huggingface.co/hf/v1",
api_key=api_key,
)
logger.info("✅ Hugging Face client initialized for structured JSON response") logger.info("✅ Hugging Face client initialized for structured JSON response")
# Prepare input for the API # Prepare input for the API
@@ -380,104 +329,51 @@ def huggingface_structured_json_response(
json_schema_str = json.dumps(schema, indent=2) json_schema_str = json.dumps(schema, indent=2)
messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}" messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}"
# Add rate limiting to prevent expensive API calls try:
import time response = client.chat.completions.create(
time.sleep(1) # 1 second delay between API calls model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format={"type": "json_object"}
)
except Exception as e:
details = _error_details(e)
logger.error("❌ Hugging Face API call failed | error_class={} | type={} | message={} | repr={}", _classify_hf_error(e), details["type"], details["message"], details["repr"])
raise
response_text = response.choices[0].message.content
# Clean up response text if needed
response_text = response_text.strip()
if response_text.startswith("```json"):
response_text = response_text[7:]
if response_text.endswith("```"):
response_text = response_text[:-3]
response_text = response_text.strip()
try: try:
response = None parsed_json = json.loads(response_text)
last_error = None logger.info("✅ Hugging Face structured JSON response parsed successfully")
for candidate_model in _fallback_model_sequence(model): return parsed_json
except json.JSONDecodeError as json_err:
logger.error(f"❌ JSON parsing failed: {json_err}")
logger.error(f"Raw response: {response_text}")
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
try: try:
response = client.chat.completions.create( extracted_json = json.loads(json_match.group())
model=candidate_model, logger.info("✅ JSON extracted using regex fallback")
messages=messages, return extracted_json
temperature=temperature, except json.JSONDecodeError:
max_tokens=max_tokens, pass
response_format={"type": "json_object"} # Try to enforce JSON mode if supported return {"error": "Failed to parse JSON response", "raw_response": response_text}
)
if candidate_model != model:
logger.warning("HF structured generation switched to fallback model: {}", candidate_model)
break
except NotFoundError as nf_err:
last_error = nf_err
logger.warning("HF structured model not found: {}. Trying fallback model.", candidate_model)
continue
if response is None:
raise last_error or Exception("Hugging Face structured generation failed: all fallback models failed")
response_text = response.choices[0].message.content
# Clean up response text if needed
response_text = response_text.strip()
if response_text.startswith("```json"):
response_text = response_text[7:]
if response_text.endswith("```"):
response_text = response_text[:-3]
response_text = response_text.strip()
try:
parsed_json = json.loads(response_text)
logger.info("✅ Hugging Face structured JSON response parsed successfully")
return parsed_json
except json.JSONDecodeError as json_err:
logger.error(f"❌ JSON parsing failed: {json_err}")
logger.error(f"Raw response: {response_text}")
# Try to extract JSON from the response using regex
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
try:
extracted_json = json.loads(json_match.group())
logger.info("✅ JSON extracted using regex fallback")
return extracted_json
except json.JSONDecodeError:
pass
return {"error": "Failed to parse JSON response", "raw_response": response_text}
except Exception as e:
logger.error(f"❌ Hugging Face API call failed: {e}")
# If 422 Unprocessable Entity (often due to response_format not supported), retry without it
if "422" in str(e) or "not supported" in str(e).lower() or isinstance(e, NotFoundError):
logger.info("Retrying without response_format...")
response = None
last_error = None
for candidate_model in _fallback_model_sequence(model):
try:
response = client.chat.completions.create(
model=candidate_model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens
)
if candidate_model != model:
logger.warning("HF structured no-response_format fallback model: {}", candidate_model)
break
except NotFoundError as nf_err:
last_error = nf_err
logger.warning("HF structured model not found (no response_format path): {}", candidate_model)
continue
if response is None:
raise last_error or e
response_text = response.choices[0].message.content
# ... (same parsing logic would apply, simplified here for brevity)
try:
return json.loads(response_text)
except:
# Regex fallback
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
return json.loads(json_match.group())
return {"error": "Failed to parse JSON response", "raw_response": response_text}
raise e
except Exception as e: except Exception as e:
error_msg = str(e) if str(e) else repr(e) error_msg = str(e) if str(e) else repr(e)
error_type = type(e).__name__ error_type = type(e).__name__
logger.error(f"❌ Hugging Face structured JSON generation failed: {error_type}: {error_msg}") details = _error_details(e)
logger.error("❌ Hugging Face structured JSON generation failed | error_class={} | type={} | message={} | repr={}", _classify_hf_error(e), error_type, details["message"], details["repr"])
logger.error(f"❌ Full exception details: {repr(e)}") logger.error(f"❌ Full exception details: {repr(e)}")
import traceback import traceback
logger.error(f"❌ Traceback: {traceback.format_exc()}") logger.error(f"❌ Traceback: {traceback.format_exc()}")

View File

@@ -10,10 +10,124 @@ from typing import Optional, Dict, Any, List
from datetime import datetime from datetime import datetime
from loguru import logger from loguru import logger
from fastapi import HTTPException from fastapi import HTTPException
from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
from .tenant_provider_config import get_available_text_providers, get_tenant_api_key
from .routing_observability import emit_routing_event
def _normalize_provider(provider: Optional[str]) -> Optional[str]:
if not provider:
return None
provider_aliases = {
"gemini": "google",
"google": "google",
"hf": "huggingface",
"hf_response_api": "huggingface",
"huggingface": "huggingface",
"wavespeed": "huggingface",
}
value = str(provider).strip().lower()
return provider_aliases.get(value, value)
def _parse_csv_env(value: Optional[str]) -> List[str]:
if not value:
return []
return [v.strip() for v in str(value).split(",") if v.strip()]
def _resolve_provider_sequence(
preferred_provider: Optional[str],
env_provider_raw: str,
available_providers: List[str],
) -> List[str]:
configured = _parse_csv_env(preferred_provider) if preferred_provider else _parse_csv_env(env_provider_raw)
normalized = [_normalize_provider(p) for p in configured if _normalize_provider(p)]
if not normalized:
if "google" in available_providers:
return ["google"]
if "huggingface" in available_providers:
return ["huggingface"]
return []
# preserve order and keep only available providers
sequence = []
for provider in normalized:
if provider in available_providers:
sequence.append(provider)
# strict mode for single configured provider: no silent remap
if len(normalized) == 1:
return sequence
# multi-provider mode: append any other available providers as tail only if none configured are available
if not sequence:
return [p for p in ["huggingface", "google"] if p in available_providers]
return sequence
def _map_logical_model_to_provider_model(provider: str, model_name: str) -> str:
"""Map logical model aliases/full names to provider-specific model IDs."""
raw = (model_name or "").strip()
if not raw:
return raw
# Full provider path supplied explicitly; use as-is.
if "/" in raw:
return raw
key = raw.lower()
hf_map = {
"gpt-oss": "openai/gpt-oss-120b:cerebras",
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
"gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
"llama": "meta-llama/Llama-3.1-8B-Instruct:groq",
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:groq",
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:groq",
}
wavespeed_map = {
"gpt-oss": "openai/gpt-oss-120b",
"gpt-oss-120b": "openai/gpt-oss-120b",
"gpt-oss-20b": "openai/gpt-oss-20b",
"mistral": "mistralai/Mistral-7B-Instruct-v0.3",
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
"llama": "meta-llama/Llama-3.1-8B-Instruct",
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct",
}
if provider in {"huggingface", "hf", "hf_response_api"}:
return hf_map.get(key, raw)
if provider == "wavespeed":
return wavespeed_map.get(key, raw)
return raw
def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[str]] = None) -> List[str]:
models_env = _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", ""))
if provider == "google":
return ["gemini-2.0-flash-001"]
if preferred_hf_models:
return [_map_logical_model_to_provider_model(provider, m) for m in preferred_hf_models if m]
if not models_env:
return ["openai/gpt-oss-120b:groq"]
resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()]
return resolved or ["openai/gpt-oss-120b:groq"]
def llm_text_gen( def llm_text_gen(
@@ -22,6 +136,8 @@ def llm_text_gen(
json_struct: Optional[Dict[str, Any]] = None, json_struct: Optional[Dict[str, Any]] = None,
user_id: str = None, user_id: str = None,
preferred_hf_models: Optional[List[str]] = None, preferred_hf_models: Optional[List[str]] = None,
preferred_provider: Optional[str] = None,
flow_type: str = "default",
) -> str: ) -> str:
""" """
Generate text using Language Model (LLM) based on the provided prompt. Generate text using Language Model (LLM) based on the provided prompt.
@@ -43,24 +159,16 @@ def llm_text_gen(
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters") logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
# Set default values for LLM parameters # Set default values for LLM parameters
gpt_provider = "google" # Default to Google Gemini gpt_provider = "google"
model = "gemini-2.0-flash-001" model = "gemini-2.0-flash-001"
temperature = 0.7 temperature = 0.7
max_tokens = 4000 max_tokens = 4000
top_p = 0.9 top_p = 0.9
n = 1 n = 1
fp = 16
frequency_penalty = 0.0
presence_penalty = 0.0
# Check for GPT_PROVIDER environment variable env_provider_raw = os.getenv('GPT_PROVIDER', '').lower()
env_provider = os.getenv('GPT_PROVIDER', '').lower() env_provider = _normalize_provider(env_provider_raw)
if env_provider in ['gemini', 'google']: preferred_provider_normalized = _normalize_provider(preferred_provider)
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
# Default blog characteristics # Default blog characteristics
blog_tone = "Professional" blog_tone = "Professional"
@@ -70,44 +178,41 @@ def llm_text_gen(
blog_output_format = "markdown" blog_output_format = "markdown"
blog_length = 2000 blog_length = 2000
# Check which providers have API keys available using APIKeyManager available_providers = get_available_text_providers(user_id)
api_key_manager = APIKeyManager() provider_sequence = _resolve_provider_sequence(preferred_provider, env_provider_raw, available_providers)
available_providers = []
if api_key_manager.get_api_key("gemini"):
available_providers.append("google")
if api_key_manager.get_api_key("hf_token"):
available_providers.append("huggingface")
# If no environment variable set, auto-detect based on available keys if not provider_sequence:
if not env_provider: logger.error("[llm_text_gen] No configured providers available for tenant.")
# Prefer Google Gemini if available, otherwise use Hugging Face raise RuntimeError("No LLM providers available for tenant.")
if "google" in available_providers:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
else:
logger.error("[llm_text_gen] No API keys found for supported providers.")
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
else:
# Environment variable was set, validate it's supported
if gpt_provider not in available_providers:
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
if "google" in available_providers:
gpt_provider = "google"
model = "gemini-2.0-flash-001"
elif "huggingface" in available_providers:
gpt_provider = "huggingface"
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
else:
raise RuntimeError("No supported providers available.")
if gpt_provider == "huggingface" and preferred_hf_models: # strict mode if single configured provider; multi-provider fallback if comma-separated providers
model = preferred_hf_models[0] pinned_provider = len(_parse_csv_env(preferred_provider or env_provider_raw)) == 1 and bool(preferred_provider or env_provider_raw)
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}") gpt_provider = provider_sequence[0]
model_sequence = _resolve_model_sequence(gpt_provider, preferred_hf_models)
model = model_sequence[0]
hf_api_key = get_tenant_api_key(user_id, "huggingface") if gpt_provider == "huggingface" else None
logger.info(
"[llm_text_gen] Mode | providers={} | models={} | env_models={} | strict_provider={} | strict_model={}",
provider_sequence,
model_sequence,
_parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", "")),
pinned_provider,
len(model_sequence) == 1,
)
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}") logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
emit_routing_event(
logger,
"text_route_selected",
user_id=user_id,
flow_type=flow_type,
provider_selected=gpt_provider,
model_selected=model,
env_provider=env_provider_raw or "auto",
fallback_count=0,
)
# Map provider name to APIProvider enum (define at function scope for usage tracking) # Map provider name to APIProvider enum (define at function scope for usage tracking)
from models.subscription_models import APIProvider from models.subscription_models import APIProvider
@@ -155,6 +260,13 @@ def llm_text_gen(
estimated_output_tokens = int(input_tokens * 1.5) estimated_output_tokens = int(input_tokens * 1.5)
estimated_total_tokens = input_tokens + estimated_output_tokens estimated_total_tokens = input_tokens + estimated_output_tokens
logger.info(
"[llm_text_gen][subscription_preflight] start | user_id={} | provider={} | tokens_requested={}",
user_id,
actual_provider_name or provider_enum.value,
estimated_total_tokens,
)
# Check limits using sync method from pricing service (strict enforcement) # Check limits using sync method from pricing service (strict enforcement)
can_proceed, message, usage_info = pricing_service.check_usage_limits( can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id, user_id=user_id,
@@ -174,6 +286,13 @@ def llm_text_gen(
} }
raise HTTPException(status_code=429, detail=error_detail) raise HTTPException(status_code=429, detail=error_detail)
logger.info(
"[llm_text_gen][subscription_preflight] pass | user_id={} | provider={} | tokens_requested={}",
user_id,
actual_provider_name or provider_enum.value,
estimated_total_tokens,
)
# Get current usage for limit checking only # Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter( usage = db.query(UsageSummary).filter(
@@ -219,103 +338,26 @@ def llm_text_gen(
else: else:
system_instructions = system_prompt system_instructions = system_prompt
# Generate response based on provider # Generate response based on provider/model sequence
response_text = None response_text = None
actual_provider_used = gpt_provider errors: List[str] = []
try:
if gpt_provider == "google":
if json_struct:
response_text = gemini_structured_json_response(
prompt=prompt,
schema=json_struct,
temperature=temperature,
top_p=top_p,
top_k=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = gemini_text_response(
prompt=prompt,
temperature=temperature,
top_p=top_p,
n=n,
max_tokens=max_tokens,
system_prompt=system_instructions
)
elif gpt_provider == "huggingface":
if json_struct:
response_text = huggingface_structured_json_response(
prompt=prompt,
schema=json_struct,
model=model,
temperature=temperature,
max_tokens=max_tokens,
system_prompt=system_instructions
)
else:
response_text = huggingface_text_response(
prompt=prompt,
model=model,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
system_prompt=system_instructions
)
else:
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
# TRACK USAGE after successful API call for provider_idx, provider_name in enumerate(provider_sequence):
if response_text: candidate_models = _resolve_model_sequence(provider_name, preferred_hf_models)
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}") for model_idx, candidate_model in enumerate(candidate_models):
try: try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync emit_routing_event(
logger,
# Estimate tokens "text_route_attempt",
tokens_input = int(len(prompt.split()) * 1.3)
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
# Ideally we should track start_time at beginning of function
duration = 0.5
track_agent_usage_sync(
user_id=user_id, user_id=user_id,
model_name=model, flow_type=flow_type,
prompt=prompt, provider_selected=provider_name,
response_text=response_text, model_selected=candidate_model,
duration=duration provider_attempt=provider_idx + 1,
model_attempt=model_idx + 1,
) )
except Exception as usage_error: if provider_name == "google":
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return response_text
except Exception as provider_error:
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
# CIRCUIT BREAKER: Only try ONE fallback to prevent expensive API calls
fallback_providers = ["google", "huggingface"]
fallback_providers = [p for p in fallback_providers if p in available_providers and p != gpt_provider]
if fallback_providers:
fallback_provider = fallback_providers[0] # Only try the first available
try:
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
actual_provider_used = fallback_provider
# Update provider enum for fallback
if fallback_provider == "google":
provider_enum = APIProvider.GEMINI
actual_provider_name = "gemini"
fallback_model = "gemini-2.0-flash-lite"
elif fallback_provider == "huggingface":
provider_enum = APIProvider.MISTRAL
actual_provider_name = "huggingface"
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
if fallback_provider == "google":
if json_struct: if json_struct:
response_text = gemini_structured_json_response( response_text = gemini_structured_json_response(
prompt=prompt, prompt=prompt,
@@ -324,7 +366,7 @@ def llm_text_gen(
top_p=top_p, top_p=top_p,
top_k=n, top_k=n,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
) )
else: else:
response_text = gemini_text_response( response_text = gemini_text_response(
@@ -333,54 +375,59 @@ def llm_text_gen(
top_p=top_p, top_p=top_p,
n=n, n=n,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
) )
elif fallback_provider == "huggingface": elif provider_name == "huggingface":
hf_api_key_current = get_tenant_api_key(user_id, "huggingface")
if json_struct: if json_struct:
response_text = huggingface_structured_json_response( response_text = huggingface_structured_json_response(
prompt=prompt, prompt=prompt,
schema=json_struct, schema=json_struct,
model="mistralai/Mistral-7B-Instruct-v0.3:groq", model=candidate_model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
api_key=hf_api_key_current,
) )
else: else:
response_text = huggingface_text_response( response_text = huggingface_text_response(
prompt=prompt, prompt=prompt,
model="mistralai/Mistral-7B-Instruct-v0.3:groq", model=candidate_model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
system_prompt=system_instructions system_prompt=system_instructions,
api_key=hf_api_key_current,
) )
else:
raise RuntimeError(f"Unknown provider {provider_name}")
# TRACK USAGE after successful fallback call
if response_text: if response_text:
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}") logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try: try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
track_agent_usage_sync( track_agent_usage_sync(
user_id=user_id, user_id=user_id,
model_name=fallback_model, model_name=candidate_model,
prompt=prompt, prompt=prompt,
response_text=response_text, response_text=response_text,
duration=0.5 # Approximate duration duration=0.5,
) )
except Exception as usage_error: except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True) logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return response_text
except Exception as provider_error:
err = f"provider={provider_name},model={candidate_model},error={provider_error}"
errors.append(err)
logger.error("[llm_text_gen] Attempt failed: {}", err)
continue
return response_text # strict provider mode: single configured provider should not switch
except Exception as fallback_error: if pinned_provider and len(provider_sequence) == 1:
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}") break
# CIRCUIT BREAKER: Stop immediately to prevent expensive API calls logger.error("[llm_text_gen] CIRCUIT BREAKER: All configured provider/model attempts failed. {}", errors)
logger.error("[llm_text_gen] CIRCUIT BREAKER: Stopping to prevent expensive API calls.") raise RuntimeError("All configured LLM provider/model attempts failed.")
raise RuntimeError("All LLM providers failed to generate a response.")
except Exception as e: except Exception as e:
logger.error(f"[llm_text_gen] Error during text generation: {str(e)}") logger.error(f"[llm_text_gen] Error during text generation: {str(e)}")
@@ -388,20 +435,17 @@ def llm_text_gen(
def check_gpt_provider(gpt_provider: str) -> bool: def check_gpt_provider(gpt_provider: str) -> bool:
"""Check if the specified GPT provider is supported.""" """Check if the specified GPT provider is supported."""
supported_providers = ["google", "huggingface"] providers = [_normalize_provider(p) for p in _parse_csv_env(gpt_provider)]
return gpt_provider in supported_providers if not providers:
providers = [_normalize_provider(gpt_provider)]
supported_providers = {"google", "huggingface"}
return all(p in supported_providers for p in providers if p)
def get_api_key(gpt_provider: str) -> Optional[str]: def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
"""Get API key for the specified provider.""" """Get API key for the specified provider, preferring tenant-scoped keys."""
try: try:
api_key_manager = APIKeyManager() return get_tenant_api_key(user_id, gpt_provider)
provider_mapping = {
"google": "gemini",
"huggingface": "hf_token"
}
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
return api_key_manager.get_api_key(mapped_provider)
except Exception as e: except Exception as e:
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}") logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
return None return None

View File

@@ -0,0 +1,22 @@
"""Structured observability helpers for LLM routing decisions."""
from __future__ import annotations
import hashlib
import json
from typing import Any, Dict, Optional
def _mask_user_id(user_id: Optional[str]) -> str:
if not user_id:
return "anonymous"
return hashlib.sha256(str(user_id).encode("utf-8")).hexdigest()[:12]
def emit_routing_event(logger, event: str, *, user_id: Optional[str] = None, **fields: Any) -> None:
payload: Dict[str, Any] = {
"event": event,
"tenant": _mask_user_id(user_id),
**fields,
}
logger.info("[llm_routing] {}", json.dumps(payload, sort_keys=True, default=str))

View File

@@ -0,0 +1,83 @@
"""Tenant-aware provider configuration and API key resolution for LLM providers."""
from __future__ import annotations
import os
import time
from typing import Dict, Optional
from loguru import logger
from services.database import get_session_for_user
from models.onboarding import APIKey, OnboardingSession
_PROVIDER_KEY_MAP = {
"google": "gemini",
"gemini": "gemini",
"huggingface": "hf_token",
"hf": "hf_token",
"hf_response_api": "hf_token",
}
_PROVIDER_ENV_MAP = {
"gemini": "GEMINI_API_KEY",
"hf_token": "HF_TOKEN",
}
_CACHE_TTL_SECONDS = int(os.getenv("TENANT_PROVIDER_CACHE_TTL", "60"))
_cache: Dict[str, tuple[float, Optional[str]]] = {}
def _cache_key(user_id: Optional[str], provider_key: str) -> str:
return f"{user_id or 'global'}::{provider_key}"
def _normalize_provider(provider: str) -> str:
return _PROVIDER_KEY_MAP.get((provider or "").lower(), (provider or "").lower())
def get_tenant_api_key(user_id: Optional[str], provider: str) -> Optional[str]:
provider_key = _normalize_provider(provider)
ck = _cache_key(user_id, provider_key)
cached = _cache.get(ck)
now = time.time()
if cached and (now - cached[0]) < _CACHE_TTL_SECONDS:
return cached[1]
key: Optional[str] = None
if user_id:
db = None
try:
db = get_session_for_user(user_id)
if db:
record = (
db.query(APIKey.key)
.join(OnboardingSession, APIKey.session_id == OnboardingSession.id)
.filter(OnboardingSession.user_id == user_id, APIKey.provider == provider_key)
.order_by(APIKey.updated_at.desc())
.first()
)
if record and record[0]:
key = record[0]
except Exception as exc:
logger.debug("tenant api-key lookup failed for user={}, provider={}: {}", user_id, provider_key, exc)
finally:
if db:
db.close()
if not key:
env_var = _PROVIDER_ENV_MAP.get(provider_key)
if env_var:
key = os.getenv(env_var)
_cache[ck] = (now, key)
return key
def get_available_text_providers(user_id: Optional[str]) -> list[str]:
providers = []
if get_tenant_api_key(user_id, "gemini"):
providers.append("google")
if get_tenant_api_key(user_id, "huggingface"):
providers.append("huggingface")
return providers

View File

@@ -10,6 +10,20 @@ from services.database import SessionLocal
from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService
def _ensure_dict(value: Any) -> Dict[str, Any]:
"""Safely coerce arbitrary payload shape into a dictionary."""
return value if isinstance(value, dict) else {}
def _ensure_list(value: Any) -> List[Any]:
"""Safely coerce arbitrary payload shape into a list."""
if isinstance(value, list):
return value
if value is None:
return []
return [value]
class PersonalizationService: class PersonalizationService:
""" """
Service for extracting user preferences from onboarding data Service for extracting user preferences from onboarding data
@@ -39,17 +53,18 @@ class PersonalizationService:
db = SessionLocal() db = SessionLocal()
try: try:
integration_service = OnboardingDataIntegrationService() integration_service = OnboardingDataIntegrationService()
integrated_data = integration_service.get_integrated_data_sync(user_id, db) integrated_data_raw = integration_service.get_integrated_data_sync(user_id, db)
canonical_profile = integrated_data.get('canonical_profile', {}) integrated_data = _ensure_dict(integrated_data_raw)
canonical_profile = _ensure_dict(integrated_data.get('canonical_profile'))
# Map strictly from Canonical Profile # Map strictly from Canonical Profile
preferences = { preferences = {
"industry": canonical_profile.get("industry"), "industry": canonical_profile.get("industry"),
"target_audience": canonical_profile.get("target_audience", {}), "target_audience": _ensure_dict(canonical_profile.get("target_audience")),
"platform_preferences": canonical_profile.get("platform_preferences", []), "platform_preferences": _ensure_list(canonical_profile.get("platform_preferences")),
"content_preferences": canonical_profile.get("content_types", []), "content_preferences": _ensure_list(canonical_profile.get("content_types")),
"style_preferences": canonical_profile.get("visual_style", {}), "style_preferences": _ensure_dict(canonical_profile.get("visual_style")),
"brand_colors": canonical_profile.get("brand_colors", []), "brand_colors": _ensure_list(canonical_profile.get("brand_colors")),
"recommended_templates": [], "recommended_templates": [],
"recommended_channels": [], "recommended_channels": [],
"writing_style": { "writing_style": {
@@ -58,7 +73,7 @@ class PersonalizationService:
"complexity": canonical_profile.get("writing_complexity", "intermediate"), "complexity": canonical_profile.get("writing_complexity", "intermediate"),
"engagement_level": canonical_profile.get("writing_engagement", "moderate"), "engagement_level": canonical_profile.get("writing_engagement", "moderate"),
}, },
"brand_values": canonical_profile.get("brand_values", []), "brand_values": _ensure_list(canonical_profile.get("brand_values")),
} }
# Ensure target_audience structure # Ensure target_audience structure
@@ -94,7 +109,7 @@ class PersonalizationService:
if not preferences["recommended_channels"]: if not preferences["recommended_channels"]:
preferences["recommended_channels"] = self._get_recommended_channels( preferences["recommended_channels"] = self._get_recommended_channels(
preferences.get("industry"), preferences.get("industry"),
preferences.get("target_audience", {}).get("demographics", []) _ensure_list(_ensure_dict(preferences.get("target_audience")).get("demographics"))
) )
logger.info(f"[Personalization] Extracted preferences for user {user_id}: industry={preferences.get('industry')}") logger.info(f"[Personalization] Extracted preferences for user {user_id}: industry={preferences.get('industry')}")