Merge_PR_416_fix_textgen_ai_models_mapping

This commit is contained in:
ajaysi
2026-03-12 16:05:47 +05:30
5 changed files with 509 additions and 81 deletions

View File

@@ -47,26 +47,10 @@ Last Updated: January 2025
""" """
import os import os
import sys
from pathlib import Path
import json import json
import re import re
from typing import Optional, Dict, Any, List from functools import lru_cache
from typing import Optional, Dict, Any
from dotenv import load_dotenv
# Fix the environment loading path - load from backend directory
current_dir = Path(__file__).parent.parent # services directory
backend_dir = current_dir.parent # backend directory
env_path = backend_dir / '.env'
if env_path.exists():
load_dotenv(env_path)
print(f"Loaded .env from: {env_path}")
else:
# Fallback to current directory
load_dotenv()
print(f"No .env found at {env_path}, using current directory")
from loguru import logger from loguru import logger
from utils.logger_utils import get_service_logger from utils.logger_utils import get_service_logger
@@ -74,22 +58,24 @@ from utils.logger_utils import get_service_logger
# Use service-specific logger to avoid conflicts # Use service-specific logger to avoid conflicts
logger = get_service_logger("huggingface_provider") logger = get_service_logger("huggingface_provider")
<<<<<<< HEAD
from tenacity import ( from tenacity import (
retry, retry,
retry_if_exception, retry_if_exception,
stop_after_attempt, stop_after_attempt,
wait_random_exponential, wait_random_exponential,
) )
=======
>>>>>>> pr-416
try: try:
from openai import OpenAI from openai import OpenAI
from openai import NotFoundError
OPENAI_AVAILABLE = True OPENAI_AVAILABLE = True
except ImportError: except ImportError:
OPENAI_AVAILABLE = False OPENAI_AVAILABLE = False
NotFoundError = Exception
logger.warn("OpenAI library not available. Install with: pip install openai") logger.warn("OpenAI library not available. Install with: pip install openai")
<<<<<<< HEAD
HF_FALLBACK_MODELS = [ HF_FALLBACK_MODELS = [
"openai/gpt-oss-120b:cerebras", "openai/gpt-oss-120b:cerebras",
"moonshotai/Kimi-K2-Instruct-0905:cerebras", "moonshotai/Kimi-K2-Instruct-0905:cerebras",
@@ -179,8 +165,32 @@ def _hf_error_details(exc: Exception) -> str:
return details return details
def get_huggingface_api_key() -> str: def get_huggingface_api_key() -> str:
=======
def _classify_hf_error(error: Exception) -> str:
    """Bucket a Hugging Face failure into a coarse category by keyword.

    The lowercased exception text is scanned for known substrings. The
    first matching category wins, checked in this order: billing/quota,
    auth/permission, model-not-found.

    Args:
        error: The exception raised by the provider call. A falsy value
            is treated as an empty message.

    Returns:
        One of ``"billing_or_quota"``, ``"auth_or_permission"``,
        ``"model_not_found"`` or ``"other"``.
    """
    text = str(error or "").lower()

    # Order matters: earlier categories take precedence over later ones.
    keyword_table = (
        (
            "billing_or_quota",
            ("insufficient", "quota", "billing", "payment", "credits", "balance"),
        ),
        (
            "auth_or_permission",
            ("unauthorized", "forbidden", "permission", "invalid api key", "authentication"),
        ),
        (
            "model_not_found",
            ("not found", "404"),
        ),
    )

    for label, needles in keyword_table:
        for needle in needles:
            if needle in text:
                return label
    return "other"
def _error_details(error: Exception) -> Dict[str, str]:
    """Summarize an exception as a flat dict of strings for logging.

    Args:
        error: The exception to describe.

    Returns:
        A dict with three keys: ``"type"`` (the exception class name),
        ``"message"`` (``str(error)``) and ``"repr"`` (``repr(error)``).
    """
    class_name = type(error).__name__
    text = str(error)
    representation = repr(error)
    return dict(type=class_name, message=text, repr=representation)
def get_huggingface_api_key(explicit_api_key: Optional[str] = None) -> str:
>>>>>>> pr-416
"""Get Hugging Face API key with proper error handling.""" """Get Hugging Face API key with proper error handling."""
api_key = os.getenv('HF_TOKEN') api_key = explicit_api_key or os.getenv('HF_TOKEN')
if not api_key: if not api_key:
error_msg = "HF_TOKEN environment variable is not set. Please set it in your .env file." error_msg = "HF_TOKEN environment variable is not set. Please set it in your .env file."
logger.error(error_msg) logger.error(error_msg)
@@ -194,11 +204,19 @@ def get_huggingface_api_key() -> str:
return api_key return api_key
<<<<<<< HEAD
@retry( @retry(
retry=retry_if_exception(_should_retry_hf_error), retry=retry_if_exception(_should_retry_hf_error),
wait=wait_random_exponential(min=1, max=60), wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6), stop=stop_after_attempt(6),
) )
=======
@lru_cache(maxsize=16)
def _get_hf_client(api_key: str):
    """Build an OpenAI client pointed at the Hugging Face router endpoint.

    The function is memoized with ``lru_cache`` (up to 16 distinct keys),
    so repeated calls with the same ``api_key`` return the same client
    object instead of constructing a new one.

    Args:
        api_key: Token passed to the client as its ``api_key``.

    Returns:
        The ``OpenAI`` client instance for that key.
    """
    client_options = {
        "base_url": "https://router.huggingface.co/v1",
        "api_key": api_key,
    }
    return OpenAI(**client_options)
>>>>>>> pr-416
def huggingface_text_response( def huggingface_text_response(
prompt: str, prompt: str,
model: str = "openai/gpt-oss-120b:cerebras", model: str = "openai/gpt-oss-120b:cerebras",
@@ -206,7 +224,8 @@ def huggingface_text_response(
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 2048, max_tokens: int = 2048,
top_p: float = 0.9, top_p: float = 0.9,
system_prompt: Optional[str] = None system_prompt: Optional[str] = None,
api_key: Optional[str] = None,
) -> str: ) -> str:
""" """
Generate text response using Hugging Face Inference Providers API. Generate text response using Hugging Face Inference Providers API.
@@ -248,17 +267,21 @@ def huggingface_text_response(
raise ImportError("OpenAI library not available. Install with: pip install openai") raise ImportError("OpenAI library not available. Install with: pip install openai")
# Get API key with proper error handling # Get API key with proper error handling
api_key = get_huggingface_api_key() api_key = get_huggingface_api_key(api_key)
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})") logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
if not api_key: if not api_key:
raise Exception("HF_TOKEN not found in environment variables") raise Exception("HF_TOKEN not found in environment variables")
# Initialize Hugging Face client # Initialize Hugging Face client
<<<<<<< HEAD
client = OpenAI( client = OpenAI(
base_url="https://router.huggingface.co/v1", base_url="https://router.huggingface.co/v1",
api_key=api_key, api_key=api_key,
) )
=======
client = _get_hf_client(api_key)
>>>>>>> pr-416
logger.info("✅ Hugging Face client initialized for text response") logger.info("✅ Hugging Face client initialized for text response")
# Prepare input for the API # Prepare input for the API
@@ -289,11 +312,14 @@ def huggingface_text_response(
logger.info("🚀 Making Hugging Face API call (chat completion)...") logger.info("🚀 Making Hugging Face API call (chat completion)...")
<<<<<<< HEAD
# Add rate limiting to prevent expensive API calls # Add rate limiting to prevent expensive API calls
import time import time
time.sleep(1) # 1 second delay between API calls time.sleep(1) # 1 second delay between API calls
# Call exactly the requested model; no retries, no fallbacks, no variants # Call exactly the requested model; no retries, no fallbacks, no variants
=======
>>>>>>> pr-416
response = client.chat.completions.create( response = client.chat.completions.create(
model=model, model=model,
messages=messages, messages=messages,
@@ -312,11 +338,12 @@ def huggingface_text_response(
generated_text = re.sub(r'```\n?', '', generated_text) generated_text = re.sub(r'```\n?', '', generated_text)
generated_text = generated_text.strip() generated_text = generated_text.strip()
logger.info(f"✅ Hugging Face text response generated successfully (length: {len(generated_text)})") logger.info("✅ Hugging Face text response generated successfully (length: {})", len(generated_text))
return generated_text return generated_text
except Exception as e: except Exception as e:
error_class = _classify_hf_error(e) error_class = _classify_hf_error(e)
<<<<<<< HEAD
error_details = _hf_error_details(e) error_details = _hf_error_details(e)
logger.error(f"❌ Hugging Face text generation failed: {error_details}") logger.error(f"❌ Hugging Face text generation failed: {error_details}")
@@ -333,9 +360,12 @@ def huggingface_text_response(
else: else:
logger.error(f"🔍 No HTTP response attached to exception object.") logger.error(f"🔍 No HTTP response attached to exception object.")
=======
details = _error_details(e)
logger.error("❌ Hugging Face text generation failed | error_class={} | type={} | message={} | repr={}", error_class, details["type"], details["message"], details["repr"])
>>>>>>> pr-416
raise Exception(f"Hugging Face text generation failed: {str(e)}") raise Exception(f"Hugging Face text generation failed: {str(e)}")
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def huggingface_structured_json_response( def huggingface_structured_json_response(
prompt: str, prompt: str,
schema: Dict[str, Any], schema: Dict[str, Any],
@@ -343,7 +373,8 @@ def huggingface_structured_json_response(
fallback_models: Optional[List[str]] = None, fallback_models: Optional[List[str]] = None,
temperature: float = 0.7, temperature: float = 0.7,
max_tokens: int = 8192, max_tokens: int = 8192,
system_prompt: Optional[str] = None system_prompt: Optional[str] = None,
api_key: Optional[str] = None,
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
Generate structured JSON response using Hugging Face Inference Providers API. Generate structured JSON response using Hugging Face Inference Providers API.
@@ -395,7 +426,7 @@ def huggingface_structured_json_response(
raise ImportError("OpenAI library not available. Install with: pip install openai") raise ImportError("OpenAI library not available. Install with: pip install openai")
# Get API key with proper error handling # Get API key with proper error handling
api_key = get_huggingface_api_key() api_key = get_huggingface_api_key(api_key)
logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})") logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
if not api_key: if not api_key:
@@ -403,10 +434,14 @@ def huggingface_structured_json_response(
# Initialize OpenAI client with Hugging Face base URL # Initialize OpenAI client with Hugging Face base URL
# Use standard Inference API endpoint # Use standard Inference API endpoint
<<<<<<< HEAD
client = OpenAI( client = OpenAI(
base_url="https://router.huggingface.co/v1", base_url="https://router.huggingface.co/v1",
api_key=api_key, api_key=api_key,
) )
=======
client = _get_hf_client(api_key)
>>>>>>> pr-416
logger.info("✅ Hugging Face client initialized for structured JSON response") logger.info("✅ Hugging Face client initialized for structured JSON response")
# Prepare input for the API # Prepare input for the API
@@ -446,11 +481,8 @@ def huggingface_structured_json_response(
json_schema_str = json.dumps(schema, indent=2) json_schema_str = json.dumps(schema, indent=2)
messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}" messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}"
# Add rate limiting to prevent expensive API calls
import time
time.sleep(1) # 1 second delay between API calls
try: try:
<<<<<<< HEAD
response = None response = None
last_error = None last_error = None
for candidate_model in _fallback_model_sequence(model, fallback_models): for candidate_model in _fallback_model_sequence(model, fallback_models):
@@ -525,25 +557,52 @@ def huggingface_structured_json_response(
last_error = nf_err last_error = nf_err
logger.warning("HF structured model not found (no response_format path): {}", candidate_model) logger.warning("HF structured model not found (no response_format path): {}", candidate_model)
continue continue
=======
response = client.chat.completions.create(
model=model,
messages=messages,
temperature=temperature,
max_tokens=max_tokens,
response_format={"type": "json_object"}
)
except Exception as e:
details = _error_details(e)
logger.error("❌ Hugging Face API call failed | error_class={} | type={} | message={} | repr={}", _classify_hf_error(e), details["type"], details["message"], details["repr"])
raise
>>>>>>> pr-416
if response is None: response_text = response.choices[0].message.content
raise last_error or e
response_text = response.choices[0].message.content # Clean up response text if needed
# ... (same parsing logic would apply, simplified here for brevity) response_text = response_text.strip()
if response_text.startswith("```json"):
response_text = response_text[7:]
if response_text.endswith("```"):
response_text = response_text[:-3]
response_text = response_text.strip()
try:
parsed_json = json.loads(response_text)
logger.info("✅ Hugging Face structured JSON response parsed successfully")
return parsed_json
except json.JSONDecodeError as json_err:
logger.error(f"❌ JSON parsing failed: {json_err}")
logger.error(f"Raw response: {response_text}")
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
if json_match:
try: try:
return json.loads(response_text) extracted_json = json.loads(json_match.group())
except: logger.info("✅ JSON extracted using regex fallback")
# Regex fallback return extracted_json
json_match = re.search(r'\{.*\}', response_text, re.DOTALL) except json.JSONDecodeError:
if json_match: pass
return json.loads(json_match.group()) return {"error": "Failed to parse JSON response", "raw_response": response_text}
return {"error": "Failed to parse JSON response", "raw_response": response_text}
raise e
except Exception as e: except Exception as e:
error_msg = str(e) if str(e) else repr(e) error_msg = str(e) if str(e) else repr(e)
error_type = type(e).__name__ error_type = type(e).__name__
logger.error(f"❌ Hugging Face structured JSON generation failed: {error_type}: {error_msg}") details = _error_details(e)
logger.error("❌ Hugging Face structured JSON generation failed | error_class={} | type={} | message={} | repr={}", _classify_hf_error(e), error_type, details["message"], details["repr"])
logger.error(f"❌ Full exception details: {repr(e)}") logger.error(f"❌ Full exception details: {repr(e)}")
import traceback import traceback
logger.error(f"❌ Traceback: {traceback.format_exc()}") logger.error(f"❌ Traceback: {traceback.format_exc()}")

View File

@@ -10,10 +10,124 @@ from typing import Optional, Dict, Any, List
from datetime import datetime from datetime import datetime
from loguru import logger from loguru import logger
from fastapi import HTTPException from fastapi import HTTPException
from ..onboarding.api_key_manager import APIKeyManager
from .gemini_provider import gemini_text_response, gemini_structured_json_response from .gemini_provider import gemini_text_response, gemini_structured_json_response
from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response
from .tenant_provider_config import get_available_text_providers, get_tenant_api_key
from .routing_observability import emit_routing_event
def _normalize_provider(provider: Optional[str]) -> Optional[str]:
    """Return the canonical provider name for a configured provider value.

    The input is stringified, stripped and lowercased, then looked up in
    the alias table. Names not in the table are returned in their cleaned
    form so callers can decide whether they are supported.

    Args:
        provider: Raw provider name (e.g. from an env var or a caller
            argument). May be ``None`` or blank.

    Returns:
        ``"google"`` or ``"huggingface"`` for known aliases, the cleaned
        lowercase name for anything else, or ``None`` when no provider
        was supplied (``None``, empty, or whitespace-only input).
    """
    if not provider:
        return None

    value = str(provider).strip().lower()
    if not value:
        # Whitespace-only input used to leak through as "" even though the
        # contract for "no provider" is None.
        return None

    provider_aliases = {
        "gemini": "google",
        "google": "google",
        "hf": "huggingface",
        "hf_response_api": "huggingface",
        "huggingface": "huggingface",
        "wavespeed": "huggingface",
    }
    return provider_aliases.get(value, value)
def _parse_csv_env(value: Optional[str]) -> List[str]:
    """Split a comma-separated setting into its non-blank, trimmed items.

    Args:
        value: Raw comma-separated string, or ``None``.

    Returns:
        The items in their original order with surrounding whitespace
        removed. Empty items are dropped, and a falsy ``value`` yields
        an empty list.
    """
    items: List[str] = []
    if not value:
        return items
    for piece in str(value).split(","):
        trimmed = piece.strip()
        if trimmed:
            items.append(trimmed)
    return items
def _resolve_provider_sequence(
    preferred_provider: Optional[str],
    env_provider_raw: str,
    available_providers: List[str],
) -> List[str]:
    """Decide which providers to try, and in what order.

    The configured list comes from ``preferred_provider`` when it is
    truthy, otherwise from ``env_provider_raw``. Both are parsed as
    comma-separated values and normalized to canonical provider names.

    Rules:
      * Nothing configured: return the first available of ``"google"``
        then ``"huggingface"`` (or ``[]`` if neither is available).
      * Exactly one configured entry (strict mode): return it only if it
        is available. There is no silent remap to another provider.
      * Several configured entries: return the available ones in the
        configured order. If none of them is available, fall back to
        whichever of ``"huggingface"``, ``"google"`` is available.

    Args:
        preferred_provider: Caller-supplied provider(s); wins over the env.
        env_provider_raw: Raw provider setting used when no preference
            is given.
        available_providers: Canonical names of usable providers.

    Returns:
        Ordered provider names without duplicates; may be empty.
    """
    configured = _parse_csv_env(preferred_provider) if preferred_provider else _parse_csv_env(env_provider_raw)
    # Normalize each entry once and drop anything that normalizes to nothing.
    normalized = [name for name in map(_normalize_provider, configured) if name]

    if not normalized:
        for default_provider in ("google", "huggingface"):
            if default_provider in available_providers:
                return [default_provider]
        return []

    # Keep the configured order and only available providers. Different
    # aliases can normalize to the same provider (e.g. "gemini,google"), so
    # de-duplicate to avoid attempting the same provider twice.
    sequence = list(dict.fromkeys(p for p in normalized if p in available_providers))

    # Strict mode for a single configured provider: no silent remap.
    if len(normalized) == 1:
        return sequence

    # Multi-provider mode: only when none of the configured providers is
    # available do we fall back to whatever supported provider is.
    if not sequence:
        return [p for p in ("huggingface", "google") if p in available_providers]
    return sequence
def _map_logical_model_to_provider_model(provider: str, model_name: str) -> str:
    """Translate a logical model alias into the provider's model ID.

    Args:
        provider: Provider name. ``"huggingface"``, ``"hf"`` and
            ``"hf_response_api"`` share one alias table; ``"wavespeed"``
            has its own. Any other provider has no table.
        model_name: Alias (matched case-insensitively) or a full model
            path. ``None`` or blank input is treated as an empty name.

    Returns:
        The stripped name unchanged when it is empty, contains a ``/``
        (treated as an explicit full path), has no alias entry, or the
        provider has no table. Otherwise the mapped provider model ID.
    """
    cleaned = (model_name or "").strip()
    # Empty names and explicit "org/model" paths are passed through as-is.
    if not cleaned or "/" in cleaned:
        return cleaned

    huggingface_aliases = {
        "gpt-oss": "openai/gpt-oss-120b:cerebras",
        "gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
        "gpt-oss-20b": "openai/gpt-oss-20b:cerebras",
        "mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
        "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
        "llama": "meta-llama/Llama-3.1-8B-Instruct:groq",
        "llama-8b": "meta-llama/Llama-3.1-8B-Instruct:groq",
        "llama-70b": "meta-llama/Llama-3.1-70B-Instruct:groq",
    }
    wavespeed_aliases = {
        "gpt-oss": "openai/gpt-oss-120b",
        "gpt-oss-120b": "openai/gpt-oss-120b",
        "gpt-oss-20b": "openai/gpt-oss-20b",
        "mistral": "mistralai/Mistral-7B-Instruct-v0.3",
        "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
        "llama": "meta-llama/Llama-3.1-8B-Instruct",
        "llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
        "llama-70b": "meta-llama/Llama-3.1-70B-Instruct",
    }

    # Dispatch on provider name; providers without a table pass through.
    tables_by_provider = {
        "huggingface": huggingface_aliases,
        "hf": huggingface_aliases,
        "hf_response_api": huggingface_aliases,
        "wavespeed": wavespeed_aliases,
    }
    alias_table = tables_by_provider.get(provider)
    if alias_table is None:
        return cleaned
    return alias_table.get(cleaned.lower(), cleaned)
def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[str]] = None) -> List[str]:
    """Return the ordered, never-empty list of model IDs to try.

    Resolution order:
      1. ``"google"`` always resolves to the single Gemini model.
      2. Caller-supplied ``preferred_hf_models``, mapped to provider IDs.
      3. The comma-separated ``TEXTGEN_AI_MODELS`` environment variable,
         mapped to provider IDs.
      4. The built-in default ``"openai/gpt-oss-120b:groq"``.

    Blank entries are discarded at every step. A preference list holding
    only blank entries therefore falls through to steps 3 and 4 rather
    than producing an empty list, which the caller would index with
    ``[0]`` and crash on.

    Args:
        provider: Canonical provider name the models are resolved for.
        preferred_hf_models: Optional caller-preferred model aliases/IDs.

    Returns:
        A non-empty list of provider-specific model IDs.
    """
    if provider == "google":
        return ["gemini-2.0-flash-001"]

    if preferred_hf_models:
        preferred = [
            mapped
            for mapped in (
                _map_logical_model_to_provider_model(provider, m)
                for m in preferred_hf_models
                if m
            )
            if mapped
        ]
        if preferred:
            return preferred

    # Only read the environment once it is actually needed.
    models_env = _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", ""))
    resolved = [
        mapped
        for mapped in (_map_logical_model_to_provider_model(provider, m) for m in models_env)
        if mapped
    ]
    return resolved or ["openai/gpt-oss-120b:groq"]
def llm_text_gen( def llm_text_gen(
@@ -23,7 +137,11 @@ def llm_text_gen(
user_id: str = None, user_id: str = None,
preferred_hf_models: Optional[List[str]] = None, preferred_hf_models: Optional[List[str]] = None,
preferred_provider: Optional[str] = None, preferred_provider: Optional[str] = None,
<<<<<<< HEAD
flow_type: Optional[str] = None, flow_type: Optional[str] = None,
=======
flow_type: str = "default",
>>>>>>> pr-416
) -> str: ) -> str:
""" """
Generate text using Language Model (LLM) based on the provided prompt. Generate text using Language Model (LLM) based on the provided prompt.
@@ -49,12 +167,18 @@ def llm_text_gen(
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters") logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
# Set default values for LLM parameters # Set default values for LLM parameters
<<<<<<< HEAD
gpt_provider = "huggingface" # Default to premium HF route for ALwrity AI tools gpt_provider = "huggingface" # Default to premium HF route for ALwrity AI tools
model = "openai/gpt-oss-120b:cerebras" model = "openai/gpt-oss-120b:cerebras"
=======
gpt_provider = "google"
model = "gemini-2.0-flash-001"
>>>>>>> pr-416
temperature = 0.7 temperature = 0.7
max_tokens = 4000 max_tokens = 4000
top_p = 0.9 top_p = 0.9
n = 1 n = 1
<<<<<<< HEAD
fp = 16 fp = 16
frequency_penalty = 0.0 frequency_penalty = 0.0
presence_penalty = 0.0 presence_penalty = 0.0
@@ -143,6 +267,13 @@ def llm_text_gen(
elif gpt_provider == "google": elif gpt_provider == "google":
model = "gemini-2.0-flash-001" # Google has fewer options model = "gemini-2.0-flash-001" # Google has fewer options
=======
env_provider_raw = os.getenv('GPT_PROVIDER', '').lower()
env_provider = _normalize_provider(env_provider_raw)
preferred_provider_normalized = _normalize_provider(preferred_provider)
>>>>>>> pr-416
# Default blog characteristics # Default blog characteristics
blog_tone = "Professional" blog_tone = "Professional"
blog_demographic = "Professional" blog_demographic = "Professional"
@@ -151,6 +282,7 @@ def llm_text_gen(
blog_output_format = "markdown" blog_output_format = "markdown"
blog_length = 2000 blog_length = 2000
<<<<<<< HEAD
# Check which providers have API keys available using APIKeyManager # Check which providers have API keys available using APIKeyManager
api_key_manager = APIKeyManager() api_key_manager = APIKeyManager()
available_providers = [] available_providers = []
@@ -230,12 +362,47 @@ def llm_text_gen(
model = "openai/gpt-oss-120b" model = "openai/gpt-oss-120b"
else: else:
raise RuntimeError("No supported providers available.") raise RuntimeError("No supported providers available.")
=======
available_providers = get_available_text_providers(user_id)
provider_sequence = _resolve_provider_sequence(preferred_provider, env_provider_raw, available_providers)
>>>>>>> pr-416
if gpt_provider == "huggingface" and preferred_hf_models: if not provider_sequence:
model = preferred_hf_models[0] logger.error("[llm_text_gen] No configured providers available for tenant.")
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}") raise RuntimeError("No LLM providers available for tenant.")
# strict mode if single configured provider; multi-provider fallback if comma-separated providers
pinned_provider = len(_parse_csv_env(preferred_provider or env_provider_raw)) == 1 and bool(preferred_provider or env_provider_raw)
gpt_provider = provider_sequence[0]
model_sequence = _resolve_model_sequence(gpt_provider, preferred_hf_models)
model = model_sequence[0]
hf_api_key = get_tenant_api_key(user_id, "huggingface") if gpt_provider == "huggingface" else None
logger.info(
"[llm_text_gen] Mode | providers={} | models={} | env_models={} | strict_provider={} | strict_model={}",
provider_sequence,
model_sequence,
_parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", "")),
pinned_provider,
len(model_sequence) == 1,
)
<<<<<<< HEAD
logger.info(f"[llm_text_gen][{flow_tag}] Using provider={gpt_provider}, model={model}") logger.info(f"[llm_text_gen][{flow_tag}] Using provider={gpt_provider}, model={model}")
=======
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
emit_routing_event(
logger,
"text_route_selected",
user_id=user_id,
flow_type=flow_type,
provider_selected=gpt_provider,
model_selected=model,
env_provider=env_provider_raw or "auto",
fallback_count=0,
)
>>>>>>> pr-416
# Map provider name to APIProvider enum (define at function scope for usage tracking) # Map provider name to APIProvider enum (define at function scope for usage tracking)
from models.subscription_models import APIProvider from models.subscription_models import APIProvider
@@ -291,6 +458,13 @@ def llm_text_gen(
estimated_output_tokens = int(input_tokens * 1.5) estimated_output_tokens = int(input_tokens * 1.5)
estimated_total_tokens = input_tokens + estimated_output_tokens estimated_total_tokens = input_tokens + estimated_output_tokens
logger.info(
"[llm_text_gen][subscription_preflight] start | user_id={} | provider={} | tokens_requested={}",
user_id,
actual_provider_name or provider_enum.value,
estimated_total_tokens,
)
# Check limits using sync method from pricing service (strict enforcement) # Check limits using sync method from pricing service (strict enforcement)
can_proceed, message, usage_info = pricing_service.check_usage_limits( can_proceed, message, usage_info = pricing_service.check_usage_limits(
user_id=user_id, user_id=user_id,
@@ -315,7 +489,14 @@ def llm_text_gen(
'usage_info': usage_info if usage_info else {} 'usage_info': usage_info if usage_info else {}
} }
raise HTTPException(status_code=429, detail=error_detail) raise HTTPException(status_code=429, detail=error_detail)
logger.info(
"[llm_text_gen][subscription_preflight] pass | user_id={} | provider={} | tokens_requested={}",
user_id,
actual_provider_name or provider_enum.value,
estimated_total_tokens,
)
# Get current usage for limit checking only # Get current usage for limit checking only
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
usage = db.query(UsageSummary).filter( usage = db.query(UsageSummary).filter(
@@ -361,6 +542,7 @@ def llm_text_gen(
else: else:
system_instructions = system_prompt system_instructions = system_prompt
<<<<<<< HEAD
# HF behavior: fail fast on selected model; no intra-provider model fallback chain. # HF behavior: fail fast on selected model; no intra-provider model fallback chain.
hf_fallback_models: List[str] = [] hf_fallback_models: List[str] = []
@@ -463,23 +645,27 @@ def llm_text_gen(
logger.info( logger.info(
f"[llm_text_gen][{flow_tag}] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}" f"[llm_text_gen][{flow_tag}] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}"
) )
=======
# Generate response based on provider/model sequence
response_text = None
errors: List[str] = []
for provider_idx, provider_name in enumerate(provider_sequence):
candidate_models = _resolve_model_sequence(provider_name, preferred_hf_models)
for model_idx, candidate_model in enumerate(candidate_models):
>>>>>>> pr-416
try: try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync emit_routing_event(
logger,
# Estimate tokens "text_route_attempt",
tokens_input = int(len(prompt.split()) * 1.3)
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
# Ideally we should track start_time at beginning of function
duration = 0.5
track_agent_usage_sync(
user_id=user_id, user_id=user_id,
model_name=model, flow_type=flow_type,
prompt=prompt, provider_selected=provider_name,
response_text=response_text, model_selected=candidate_model,
duration=duration provider_attempt=provider_idx + 1,
model_attempt=model_idx + 1,
) )
<<<<<<< HEAD
except Exception as usage_error: except Exception as usage_error:
# Non-blocking: log error but don't fail the request # Non-blocking: log error but don't fail the request
@@ -535,6 +721,10 @@ def llm_text_gen(
fallback_model = "openai/gpt-oss-120b" fallback_model = "openai/gpt-oss-120b"
if fallback_provider == "google": if fallback_provider == "google":
=======
if provider_name == "google":
>>>>>>> pr-416
if json_struct: if json_struct:
response_text = gemini_structured_json_response( response_text = gemini_structured_json_response(
prompt=prompt, prompt=prompt,
@@ -543,7 +733,7 @@ def llm_text_gen(
top_p=top_p, top_p=top_p,
top_k=n, top_k=n,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
) )
else: else:
response_text = gemini_text_response( response_text = gemini_text_response(
@@ -552,22 +742,29 @@ def llm_text_gen(
top_p=top_p, top_p=top_p,
n=n, n=n,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
) )
elif fallback_provider == "huggingface": elif provider_name == "huggingface":
hf_api_key_current = get_tenant_api_key(user_id, "huggingface")
if json_struct: if json_struct:
response_text = huggingface_structured_json_response( response_text = huggingface_structured_json_response(
prompt=prompt, prompt=prompt,
schema=json_struct, schema=json_struct,
<<<<<<< HEAD
model=fallback_model, model=fallback_model,
fallback_models=hf_fallback_models, fallback_models=hf_fallback_models,
=======
model=candidate_model,
>>>>>>> pr-416
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
system_prompt=system_instructions system_prompt=system_instructions,
api_key=hf_api_key_current,
) )
else: else:
response_text = huggingface_text_response( response_text = huggingface_text_response(
prompt=prompt, prompt=prompt,
<<<<<<< HEAD
model=fallback_model, model=fallback_model,
fallback_models=hf_fallback_models, fallback_models=hf_fallback_models,
temperature=temperature, temperature=temperature,
@@ -592,31 +789,37 @@ def llm_text_gen(
prompt=prompt, prompt=prompt,
model=fallback_model, model=fallback_model,
fallback_models=None, fallback_models=None,
=======
model=candidate_model,
>>>>>>> pr-416
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens,
top_p=top_p, top_p=top_p,
system_prompt=system_instructions system_prompt=system_instructions,
api_key=hf_api_key_current,
) )
else:
# TRACK USAGE after successful fallback call raise RuntimeError(f"Unknown provider {provider_name}")
if response_text: if response_text:
<<<<<<< HEAD
logger.info( logger.info(
f"[llm_text_gen][{flow_tag}] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}" f"[llm_text_gen][{flow_tag}] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}"
) )
=======
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
>>>>>>> pr-416
try: try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
track_agent_usage_sync( track_agent_usage_sync(
user_id=user_id, user_id=user_id,
model_name=fallback_model, model_name=candidate_model,
prompt=prompt, prompt=prompt,
response_text=response_text, response_text=response_text,
duration=0.5 # Approximate duration duration=0.5,
) )
except Exception as usage_error: except Exception as usage_error:
<<<<<<< HEAD
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True) logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
return response_text return response_text
@@ -626,6 +829,22 @@ def llm_text_gen(
# CIRCUIT BREAKER: Stop immediately to prevent expensive API calls # CIRCUIT BREAKER: Stop immediately to prevent expensive API calls
logger.error(f"[llm_text_gen][{flow_tag}] CIRCUIT BREAKER: Stopping to prevent expensive API calls.") logger.error(f"[llm_text_gen][{flow_tag}] CIRCUIT BREAKER: Stopping to prevent expensive API calls.")
raise RuntimeError("All LLM providers failed to generate a response.") raise RuntimeError("All LLM providers failed to generate a response.")
=======
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
return response_text
except Exception as provider_error:
err = f"provider={provider_name},model={candidate_model},error={provider_error}"
errors.append(err)
logger.error("[llm_text_gen] Attempt failed: {}", err)
continue
# strict provider mode: single configured provider should not switch
if pinned_provider and len(provider_sequence) == 1:
break
logger.error("[llm_text_gen] CIRCUIT BREAKER: All configured provider/model attempts failed. {}", errors)
raise RuntimeError("All configured LLM provider/model attempts failed.")
>>>>>>> pr-416
except Exception as e: except Exception as e:
logger.error(f"[llm_text_gen][{flow_tag}] Error during text generation: {str(e)}") logger.error(f"[llm_text_gen][{flow_tag}] Error during text generation: {str(e)}")
@@ -633,12 +852,21 @@ def llm_text_gen(
def check_gpt_provider(gpt_provider: str) -> bool: def check_gpt_provider(gpt_provider: str) -> bool:
"""Check if the specified GPT provider is supported.""" """Check if the specified GPT provider is supported."""
<<<<<<< HEAD
supported_providers = ["google", "huggingface", "wavespeed"] supported_providers = ["google", "huggingface", "wavespeed"]
return gpt_provider in supported_providers return gpt_provider in supported_providers
=======
providers = [_normalize_provider(p) for p in _parse_csv_env(gpt_provider)]
if not providers:
providers = [_normalize_provider(gpt_provider)]
supported_providers = {"google", "huggingface"}
return all(p in supported_providers for p in providers if p)
>>>>>>> pr-416
def get_api_key(gpt_provider: str) -> Optional[str]: def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]:
"""Get API key for the specified provider.""" """Get API key for the specified provider, preferring tenant-scoped keys."""
try: try:
<<<<<<< HEAD
api_key_manager = APIKeyManager() api_key_manager = APIKeyManager()
provider_mapping = { provider_mapping = {
"google": "gemini", "google": "gemini",
@@ -648,6 +876,10 @@ def get_api_key(gpt_provider: str) -> Optional[str]:
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider) mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
return api_key_manager.get_api_key(mapped_provider) return api_key_manager.get_api_key(mapped_provider)
=======
return get_tenant_api_key(user_id, gpt_provider)
>>>>>>> pr-416
except Exception as e: except Exception as e:
logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}") logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}")
return None return None

View File

@@ -0,0 +1,22 @@
"""Structured observability helpers for LLM routing decisions."""
from __future__ import annotations
import hashlib
import json
from typing import Any, Dict, Optional
def _mask_user_id(user_id: Optional[str]) -> str:
    """Return a short hashed tag for ``user_id`` that is safe to log.

    Falsy ids (``None``, empty string) yield the literal ``"anonymous"``;
    anything else yields the first 12 hex characters of the SHA-256 digest
    of its UTF-8 encoded string form.
    """
    if user_id:
        digest = hashlib.sha256(str(user_id).encode("utf-8"))
        return digest.hexdigest()[:12]
    return "anonymous"
def emit_routing_event(logger, event: str, *, user_id: Optional[str] = None, **fields: Any) -> None:
    """Log one routing decision as a single JSON-encoded line.

    Args:
        logger: Object exposing ``info(message, *args)``; the message uses
            ``{}`` placeholder formatting.
        event: Name of the routing event, stored under the ``"event"`` key.
        user_id: Optional user identifier; only its masked form is logged,
            under the ``"tenant"`` key.
        **fields: Extra key/value pairs merged into the payload. A
            ``tenant`` keyword here replaces the masked id.
    """
    record: Dict[str, Any] = {"event": event, "tenant": _mask_user_id(user_id)}
    record.update(fields)
    # sort_keys keeps output stable; default=str stringifies values that are
    # not JSON-native instead of raising.
    serialized = json.dumps(record, sort_keys=True, default=str)
    logger.info("[llm_routing] {}", serialized)

View File

@@ -0,0 +1,83 @@
"""Tenant-aware provider configuration and API key resolution for LLM providers."""
from __future__ import annotations
import os
import time
from typing import Dict, Optional
from loguru import logger
from services.database import get_session_for_user
from models.onboarding import APIKey, OnboardingSession
# Maps provider names/aliases accepted by callers to the provider key that is
# matched against APIKey.provider and used to index _PROVIDER_ENV_MAP.
_PROVIDER_KEY_MAP = {
    "google": "gemini",
    "gemini": "gemini",
    "huggingface": "hf_token",
    "hf": "hf_token",
    "hf_response_api": "hf_token",
}
# Environment variable consulted for each provider key when no tenant-scoped
# key was found.
_PROVIDER_ENV_MAP = {
    "gemini": "GEMINI_API_KEY",
    "hf_token": "HF_TOKEN",
}
# Lifetime of a cache entry in seconds. Read once at import time from
# TENANT_PROVIDER_CACHE_TTL (default "60"); a non-integer value raises here.
_CACHE_TTL_SECONDS = int(os.getenv("TENANT_PROVIDER_CACHE_TTL", "60"))
# Process-local cache: _cache_key(...) -> (timestamp when stored, resolved key
# or None). Entries are only overwritten, never evicted.
_cache: Dict[str, tuple[float, Optional[str]]] = {}
def _cache_key(user_id: Optional[str], provider_key: str) -> str:
    """Build the cache key for a (user, provider key) pair.

    A falsy ``user_id`` is replaced by the literal scope ``"global"``.
    """
    scope = user_id if user_id else "global"
    return "{}::{}".format(scope, provider_key)
def _normalize_provider(provider: str) -> str:
    """Translate a provider name or alias into its provider key.

    The name is lower-cased (``None``/empty becomes ``""``) and looked up in
    ``_PROVIDER_KEY_MAP``; names without a mapping are returned lower-cased.
    """
    lowered = (provider or "").lower()
    return _PROVIDER_KEY_MAP.get(lowered, lowered)
def get_tenant_api_key(user_id: Optional[str], provider: str) -> Optional[str]:
    """Resolve the API key for ``provider``, preferring the tenant's stored key.

    Resolution order:
      1. A cache entry younger than ``_CACHE_TTL_SECONDS``.
      2. When ``user_id`` is given, the most recently updated ``APIKey`` row
         for that user's onboarding session and the normalized provider key.
      3. The environment variable named in ``_PROVIDER_ENV_MAP``.

    Args:
        user_id: Tenant/user identifier, or ``None`` for the global scope.
        provider: Provider name or alias; normalized via ``_normalize_provider``.

    Returns:
        The resolved key, or ``None`` when neither source provides one.
    """
    provider_key = _normalize_provider(provider)
    ck = _cache_key(user_id, provider_key)
    # Monotonic clock: wall-clock adjustments must not stretch or shrink the TTL.
    now = time.monotonic()
    cached = _cache.get(ck)
    if cached and (now - cached[0]) < _CACHE_TTL_SECONDS:
        return cached[1]

    key: Optional[str] = None
    lookup_failed = False
    if user_id:
        db = None
        try:
            db = get_session_for_user(user_id)
            if db:
                record = (
                    db.query(APIKey.key)
                    .join(OnboardingSession, APIKey.session_id == OnboardingSession.id)
                    .filter(OnboardingSession.user_id == user_id, APIKey.provider == provider_key)
                    .order_by(APIKey.updated_at.desc())
                    .first()
                )
                if record and record[0]:
                    key = record[0]
        except Exception as exc:
            # Best-effort: fall through to the environment key, but remember
            # that this answer is not authoritative.
            lookup_failed = True
            logger.debug("tenant api-key lookup failed for user={}, provider={}: {}", user_id, provider_key, exc)
        finally:
            if db:
                db.close()

    if not key:
        env_var = _PROVIDER_ENV_MAP.get(provider_key)
        if env_var:
            key = os.getenv(env_var)

    # Do not cache a result produced after a failed lookup; otherwise one
    # transient DB error pins the tenant to the fallback key for the full TTL.
    if not lookup_failed:
        _cache[ck] = (now, key)
    return key
def get_available_text_providers(user_id: Optional[str]) -> list[str]:
    """List the text providers for which ``user_id`` has a usable API key.

    Providers are checked, and reported, in the order ``"google"`` then
    ``"huggingface"``; a provider is included when ``get_tenant_api_key``
    returns a truthy key for it.
    """
    # (name passed to the key lookup, name reported to the caller)
    candidates = (("gemini", "google"), ("huggingface", "huggingface"))
    return [
        reported_name
        for lookup_name, reported_name in candidates
        if get_tenant_api_key(user_id, lookup_name)
    ]

View File

@@ -10,6 +10,20 @@ from services.database import get_session_for_user
from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService
def _ensure_dict(value: Any) -> Dict[str, Any]:
    """Return ``value`` itself when it is a dict, otherwise a new empty dict."""
    if isinstance(value, dict):
        return value
    return {}
def _ensure_list(value: Any) -> List[Any]:
    """Coerce ``value`` into a list.

    ``None`` becomes an empty list, a list is returned as-is (same object),
    and any other value is wrapped in a single-element list.
    """
    if value is None:
        return []
    return value if isinstance(value, list) else [value]
class PersonalizationService: class PersonalizationService:
""" """
Service for extracting user preferences from onboarding data Service for extracting user preferences from onboarding data
@@ -52,6 +66,7 @@ class PersonalizationService:
return self._get_default_preferences() return self._get_default_preferences()
integration_service = OnboardingDataIntegrationService() integration_service = OnboardingDataIntegrationService()
<<<<<<< HEAD
integrated_data = integration_service.get_integrated_data_sync(user_id, db) integrated_data = integration_service.get_integrated_data_sync(user_id, db)
if not isinstance(integrated_data, dict): if not isinstance(integrated_data, dict):
logger.warning( logger.warning(
@@ -65,15 +80,28 @@ class PersonalizationService:
f"[Personalization] Canonical profile is non-dict for user {user_id}; using defaults" f"[Personalization] Canonical profile is non-dict for user {user_id}; using defaults"
) )
canonical_profile = {} canonical_profile = {}
=======
integrated_data_raw = integration_service.get_integrated_data_sync(user_id, db)
integrated_data = _ensure_dict(integrated_data_raw)
canonical_profile = _ensure_dict(integrated_data.get('canonical_profile'))
>>>>>>> pr-416
# Map strictly from Canonical Profile # Map strictly from Canonical Profile
preferences = { preferences = {
"industry": canonical_profile.get("industry"), "industry": canonical_profile.get("industry"),
<<<<<<< HEAD
"target_audience": self._as_dict(canonical_profile.get("target_audience", {})), "target_audience": self._as_dict(canonical_profile.get("target_audience", {})),
"platform_preferences": self._as_list(canonical_profile.get("platform_preferences", [])), "platform_preferences": self._as_list(canonical_profile.get("platform_preferences", [])),
"content_preferences": self._as_list(canonical_profile.get("content_types", [])), "content_preferences": self._as_list(canonical_profile.get("content_types", [])),
"style_preferences": self._as_dict(canonical_profile.get("visual_style", {})), "style_preferences": self._as_dict(canonical_profile.get("visual_style", {})),
"brand_colors": self._as_list(canonical_profile.get("brand_colors", [])), "brand_colors": self._as_list(canonical_profile.get("brand_colors", [])),
=======
"target_audience": _ensure_dict(canonical_profile.get("target_audience")),
"platform_preferences": _ensure_list(canonical_profile.get("platform_preferences")),
"content_preferences": _ensure_list(canonical_profile.get("content_types")),
"style_preferences": _ensure_dict(canonical_profile.get("visual_style")),
"brand_colors": _ensure_list(canonical_profile.get("brand_colors")),
>>>>>>> pr-416
"recommended_templates": [], "recommended_templates": [],
"recommended_channels": [], "recommended_channels": [],
"writing_style": { "writing_style": {
@@ -82,7 +110,11 @@ class PersonalizationService:
"complexity": canonical_profile.get("writing_complexity", "intermediate"), "complexity": canonical_profile.get("writing_complexity", "intermediate"),
"engagement_level": canonical_profile.get("writing_engagement", "moderate"), "engagement_level": canonical_profile.get("writing_engagement", "moderate"),
}, },
<<<<<<< HEAD
"brand_values": self._as_list(canonical_profile.get("brand_values", [])), "brand_values": self._as_list(canonical_profile.get("brand_values", [])),
=======
"brand_values": _ensure_list(canonical_profile.get("brand_values")),
>>>>>>> pr-416
} }
# Ensure target_audience structure # Ensure target_audience structure
@@ -118,7 +150,7 @@ class PersonalizationService:
if not preferences["recommended_channels"]: if not preferences["recommended_channels"]:
preferences["recommended_channels"] = self._get_recommended_channels( preferences["recommended_channels"] = self._get_recommended_channels(
preferences.get("industry"), preferences.get("industry"),
preferences.get("target_audience", {}).get("demographics", []) _ensure_list(_ensure_dict(preferences.get("target_audience")).get("demographics"))
) )
logger.info(f"[Personalization] Extracted preferences for user {user_id}: industry={preferences.get('industry')}") logger.info(f"[Personalization] Extracted preferences for user {user_id}: industry={preferences.get('industry')}")