diff --git a/backend/services/llm_providers/huggingface_provider.py b/backend/services/llm_providers/huggingface_provider.py index d09ef0be..3dfa7426 100644 --- a/backend/services/llm_providers/huggingface_provider.py +++ b/backend/services/llm_providers/huggingface_provider.py @@ -47,26 +47,10 @@ Last Updated: January 2025 """ import os -import sys -from pathlib import Path import json import re -from typing import Optional, Dict, Any, List - -from dotenv import load_dotenv - -# Fix the environment loading path - load from backend directory -current_dir = Path(__file__).parent.parent # services directory -backend_dir = current_dir.parent # backend directory -env_path = backend_dir / '.env' - -if env_path.exists(): - load_dotenv(env_path) - print(f"Loaded .env from: {env_path}") -else: - # Fallback to current directory - load_dotenv() - print(f"No .env found at {env_path}, using current directory") +from functools import lru_cache +from typing import Optional, Dict, Any from loguru import logger from utils.logger_utils import get_service_logger @@ -74,22 +58,24 @@ from utils.logger_utils import get_service_logger # Use service-specific logger to avoid conflicts logger = get_service_logger("huggingface_provider") +<<<<<<< HEAD from tenacity import ( retry, retry_if_exception, stop_after_attempt, wait_random_exponential, ) +======= +>>>>>>> pr-416 try: from openai import OpenAI - from openai import NotFoundError OPENAI_AVAILABLE = True except ImportError: OPENAI_AVAILABLE = False - NotFoundError = Exception logger.warn("OpenAI library not available. Install with: pip install openai") +<<<<<<< HEAD HF_FALLBACK_MODELS = [ "openai/gpt-oss-120b:cerebras", "moonshotai/Kimi-K2-Instruct-0905:cerebras", @@ -179,8 +165,32 @@ def _hf_error_details(exc: Exception) -> str: return details def get_huggingface_api_key() -> str: +======= + + +def _classify_hf_error(error: Exception) -> str: + message = str(error or "").lower() + if any(x in message for x in ["insufficient", "quota", "billing", "payment", "credits", "balance"]): + return "billing_or_quota" + if any(x in message for x in ["unauthorized", "forbidden", "permission", "invalid api key", "authentication"]): + return "auth_or_permission" + if ("not found" in message) or ("404" in message): + return "model_not_found" + return "other" + + +def _error_details(error: Exception) -> Dict[str, str]: + return { + "type": type(error).__name__, + "message": str(error), + "repr": repr(error), + } + + +def get_huggingface_api_key(explicit_api_key: Optional[str] = None) -> str: +>>>>>>> pr-416 """Get Hugging Face API key with proper error handling.""" - api_key = os.getenv('HF_TOKEN') + api_key = explicit_api_key or os.getenv('HF_TOKEN') if not api_key: error_msg = "HF_TOKEN environment variable is not set. Please set it in your .env file." logger.error(error_msg) @@ -194,11 +204,19 @@ def get_huggingface_api_key() -> str: return api_key +<<<<<<< HEAD @retry( retry=retry_if_exception(_should_retry_hf_error), wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), ) +======= +@lru_cache(maxsize=16) +def _get_hf_client(api_key: str): + return OpenAI(base_url="https://router.huggingface.co/v1", api_key=api_key) + + +>>>>>>> pr-416 def huggingface_text_response( prompt: str, model: str = "openai/gpt-oss-120b:cerebras", @@ -206,7 +224,8 @@ def huggingface_text_response( temperature: float = 0.7, max_tokens: int = 2048, top_p: float = 0.9, - system_prompt: Optional[str] = None + system_prompt: Optional[str] = None, + api_key: Optional[str] = None, ) -> str: """ Generate text response using Hugging Face Inference Providers API. @@ -248,17 +267,21 @@ def huggingface_text_response( raise ImportError("OpenAI library not available. Install with: pip install openai") # Get API key with proper error handling - api_key = get_huggingface_api_key() + api_key = get_huggingface_api_key(api_key) logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})") if not api_key: raise Exception("HF_TOKEN not found in environment variables") # Initialize Hugging Face client +<<<<<<< HEAD client = OpenAI( base_url="https://router.huggingface.co/v1", api_key=api_key, ) +======= + client = _get_hf_client(api_key) +>>>>>>> pr-416 logger.info("✅ Hugging Face client initialized for text response") # Prepare input for the API @@ -289,11 +312,14 @@ def huggingface_text_response( logger.info("🚀 Making Hugging Face API call (chat completion)...") +<<<<<<< HEAD # Add rate limiting to prevent expensive API calls import time time.sleep(1) # 1 second delay between API calls # Call exactly the requested model; no retries, no fallbacks, no variants +======= +>>>>>>> pr-416 response = client.chat.completions.create( model=model, messages=messages, @@ -312,11 +338,12 @@ def huggingface_text_response( generated_text = re.sub(r'```\n?', '', generated_text) generated_text = generated_text.strip() - logger.info(f"✅ Hugging Face text response generated successfully (length: {len(generated_text)})") + logger.info("✅ Hugging Face text response generated successfully (length: {})", len(generated_text)) return generated_text except Exception as e: error_class = _classify_hf_error(e) +<<<<<<< HEAD error_details = _hf_error_details(e) logger.error(f"❌ Hugging Face text generation failed: {error_details}") @@ -333,9 +360,12 @@ def huggingface_text_response( else: logger.error(f"🔍 No HTTP response attached to exception object.") +======= + details = _error_details(e) + logger.error("❌ Hugging Face text generation failed | error_class={} | type={} | message={} | repr={}", error_class, details["type"], details["message"], details["repr"]) +>>>>>>> pr-416 raise Exception(f"Hugging Face text generation failed: {str(e)}") -@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) def huggingface_structured_json_response( prompt: str, schema: Dict[str, Any], @@ -343,7 +373,8 @@ def huggingface_structured_json_response( fallback_models: Optional[List[str]] = None, temperature: float = 0.7, max_tokens: int = 8192, - system_prompt: Optional[str] = None + system_prompt: Optional[str] = None, + api_key: Optional[str] = None, ) -> Dict[str, Any]: """ Generate structured JSON response using Hugging Face Inference Providers API. @@ -395,7 +426,7 @@ def huggingface_structured_json_response( raise ImportError("OpenAI library not available. Install with: pip install openai") # Get API key with proper error handling - api_key = get_huggingface_api_key() + api_key = get_huggingface_api_key(api_key) logger.info(f"🔑 Hugging Face API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})") if not api_key: @@ -403,10 +434,14 @@ def huggingface_structured_json_response( # Initialize OpenAI client with Hugging Face base URL # Use standard Inference API endpoint +<<<<<<< HEAD client = OpenAI( base_url="https://router.huggingface.co/v1", api_key=api_key, ) +======= + client = _get_hf_client(api_key) +>>>>>>> pr-416 logger.info("✅ Hugging Face client initialized for structured JSON response") # Prepare input for the API @@ -446,11 +481,8 @@ def huggingface_structured_json_response( json_schema_str = json.dumps(schema, indent=2) messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}" - # Add rate limiting to prevent expensive API calls - import time - time.sleep(1) # 1 second delay between API calls - try: +<<<<<<< HEAD response = None last_error = None for candidate_model in _fallback_model_sequence(model, fallback_models): @@ -525,25 +557,52 @@ def huggingface_structured_json_response( last_error = nf_err logger.warning("HF structured model not found (no response_format path): {}", candidate_model) continue +======= + response = client.chat.completions.create( + model=model, + messages=messages, + temperature=temperature, + max_tokens=max_tokens, + response_format={"type": "json_object"} + ) + except Exception as e: + details = _error_details(e) + logger.error("❌ Hugging Face API call failed | error_class={} | type={} | message={} | repr={}", _classify_hf_error(e), details["type"], details["message"], details["repr"]) + raise +>>>>>>> pr-416 - if response is None: - raise last_error or e - response_text = response.choices[0].message.content - # ... (same parsing logic would apply, simplified here for brevity) + response_text = response.choices[0].message.content + + # Clean up response text if needed + response_text = response_text.strip() + if response_text.startswith("```json"): + response_text = response_text[7:] + if response_text.endswith("```"): + response_text = response_text[:-3] + response_text = response_text.strip() + + try: + parsed_json = json.loads(response_text) + logger.info("✅ Hugging Face structured JSON response parsed successfully") + return parsed_json + except json.JSONDecodeError as json_err: + logger.error(f"❌ JSON parsing failed: {json_err}") + logger.error(f"Raw response: {response_text}") + json_match = re.search(r'\{.*\}', response_text, re.DOTALL) + if json_match: try: - return json.loads(response_text) - except: - # Regex fallback - json_match = re.search(r'\{.*\}', response_text, re.DOTALL) - if json_match: - return json.loads(json_match.group()) - return {"error": "Failed to parse JSON response", "raw_response": response_text} - raise e + extracted_json = json.loads(json_match.group()) + logger.info("✅ JSON extracted using regex fallback") + return extracted_json + except json.JSONDecodeError: + pass + return {"error": "Failed to parse JSON response", "raw_response": response_text} except Exception as e: error_msg = str(e) if str(e) else repr(e) error_type = type(e).__name__ - logger.error(f"❌ Hugging Face structured JSON generation failed: {error_type}: {error_msg}") + details = _error_details(e) + logger.error("❌ Hugging Face structured JSON generation failed | error_class={} | type={} | message={} | repr={}", _classify_hf_error(e), error_type, details["message"], details["repr"]) logger.error(f"❌ Full exception details: {repr(e)}") import traceback logger.error(f"❌ Traceback: {traceback.format_exc()}") diff --git a/backend/services/llm_providers/main_text_generation.py b/backend/services/llm_providers/main_text_generation.py index 0299bad4..bd013997 100644 --- a/backend/services/llm_providers/main_text_generation.py +++ b/backend/services/llm_providers/main_text_generation.py @@ -10,10 +10,124 @@ from typing import Optional, Dict, Any, List from datetime import datetime from loguru import logger from fastapi import HTTPException -from ..onboarding.api_key_manager import APIKeyManager - from .gemini_provider import gemini_text_response, gemini_structured_json_response from .huggingface_provider import huggingface_text_response, huggingface_structured_json_response +from .tenant_provider_config import get_available_text_providers, get_tenant_api_key +from .routing_observability import emit_routing_event + + +def _normalize_provider(provider: Optional[str]) -> Optional[str]: + if not provider: + return None + provider_aliases = { + "gemini": "google", + "google": "google", + "hf": "huggingface", + "hf_response_api": "huggingface", + "huggingface": "huggingface", + "wavespeed": "huggingface", + } + value = str(provider).strip().lower() + return provider_aliases.get(value, value) + + + + +def _parse_csv_env(value: Optional[str]) -> List[str]: + if not value: + return [] + return [v.strip() for v in str(value).split(",") if v.strip()] + + +def _resolve_provider_sequence( + preferred_provider: Optional[str], + env_provider_raw: str, + available_providers: List[str], +) -> List[str]: + configured = _parse_csv_env(preferred_provider) if preferred_provider else _parse_csv_env(env_provider_raw) + normalized = [_normalize_provider(p) for p in configured if _normalize_provider(p)] + + if not normalized: + if "google" in available_providers: + return ["google"] + if "huggingface" in available_providers: + return ["huggingface"] + return [] + + # preserve order and keep only available providers + sequence = [] + for provider in normalized: + if provider in available_providers: + sequence.append(provider) + + # strict mode for single configured provider: no silent remap + if len(normalized) == 1: + return sequence + + # multi-provider mode: append any other available providers as tail only if none configured are available + if not sequence: + return [p for p in ["huggingface", "google"] if p in available_providers] + + return sequence + + + + +def _map_logical_model_to_provider_model(provider: str, model_name: str) -> str: + """Map logical model aliases/full names to provider-specific model IDs.""" + raw = (model_name or "").strip() + if not raw: + return raw + + # Full provider path supplied explicitly; use as-is. + if "/" in raw: + return raw + + key = raw.lower() + + hf_map = { + "gpt-oss": "openai/gpt-oss-120b:cerebras", + "gpt-oss-120b": "openai/gpt-oss-120b:cerebras", + "gpt-oss-20b": "openai/gpt-oss-20b:cerebras", + "mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras", + "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras", + "llama": "meta-llama/Llama-3.1-8B-Instruct:groq", + "llama-8b": "meta-llama/Llama-3.1-8B-Instruct:groq", + "llama-70b": "meta-llama/Llama-3.1-70B-Instruct:groq", + } + + wavespeed_map = { + "gpt-oss": "openai/gpt-oss-120b", + "gpt-oss-120b": "openai/gpt-oss-120b", + "gpt-oss-20b": "openai/gpt-oss-20b", + "mistral": "mistralai/Mistral-7B-Instruct-v0.3", + "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3", + "llama": "meta-llama/Llama-3.1-8B-Instruct", + "llama-8b": "meta-llama/Llama-3.1-8B-Instruct", + "llama-70b": "meta-llama/Llama-3.1-70B-Instruct", + } + + if provider in {"huggingface", "hf", "hf_response_api"}: + return hf_map.get(key, raw) + if provider == "wavespeed": + return wavespeed_map.get(key, raw) + + return raw + +def _resolve_model_sequence(provider: str, preferred_hf_models: Optional[List[str]] = None) -> List[str]: + models_env = _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", "")) + + if provider == "google": + return ["gemini-2.0-flash-001"] + + if preferred_hf_models: + return [_map_logical_model_to_provider_model(provider, m) for m in preferred_hf_models if m] + + if not models_env: + return ["openai/gpt-oss-120b:groq"] + + resolved = [_map_logical_model_to_provider_model(provider, m) for m in models_env if m.strip()] + return resolved or ["openai/gpt-oss-120b:groq"] def llm_text_gen( @@ -23,7 +137,11 @@ def llm_text_gen( user_id: str = None, preferred_hf_models: Optional[List[str]] = None, preferred_provider: Optional[str] = None, +<<<<<<< HEAD flow_type: Optional[str] = None, +======= + flow_type: str = "default", +>>>>>>> pr-416 ) -> str: """ Generate text using Language Model (LLM) based on the provided prompt. @@ -49,12 +167,18 @@ def llm_text_gen( logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters") # Set default values for LLM parameters +<<<<<<< HEAD gpt_provider = "huggingface" # Default to premium HF route for ALwrity AI tools model = "openai/gpt-oss-120b:cerebras" +======= + gpt_provider = "google" + model = "gemini-2.0-flash-001" +>>>>>>> pr-416 temperature = 0.7 max_tokens = 4000 top_p = 0.9 n = 1 +<<<<<<< HEAD fp = 16 frequency_penalty = 0.0 presence_penalty = 0.0 @@ -143,6 +267,13 @@ def llm_text_gen( elif gpt_provider == "google": model = "gemini-2.0-flash-001" # Google has fewer options +======= + + env_provider_raw = os.getenv('GPT_PROVIDER', '').lower() + env_provider = _normalize_provider(env_provider_raw) + preferred_provider_normalized = _normalize_provider(preferred_provider) + +>>>>>>> pr-416 # Default blog characteristics blog_tone = "Professional" blog_demographic = "Professional" @@ -151,6 +282,7 @@ def llm_text_gen( blog_output_format = "markdown" blog_length = 2000 +<<<<<<< HEAD # Check which providers have API keys available using APIKeyManager api_key_manager = APIKeyManager() available_providers = [] @@ -230,12 +362,47 @@ def llm_text_gen( model = "openai/gpt-oss-120b" else: raise RuntimeError("No supported providers available.") +======= + available_providers = get_available_text_providers(user_id) + provider_sequence = _resolve_provider_sequence(preferred_provider, env_provider_raw, available_providers) +>>>>>>> pr-416 - if gpt_provider == "huggingface" and preferred_hf_models: - model = preferred_hf_models[0] - logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}") + if not provider_sequence: + logger.error("[llm_text_gen] No configured providers available for tenant.") + raise RuntimeError("No LLM providers available for tenant.") + + # strict mode if single configured provider; multi-provider fallback if comma-separated providers + pinned_provider = len(_parse_csv_env(preferred_provider or env_provider_raw)) == 1 and bool(preferred_provider or env_provider_raw) + gpt_provider = provider_sequence[0] + model_sequence = _resolve_model_sequence(gpt_provider, preferred_hf_models) + model = model_sequence[0] + + hf_api_key = get_tenant_api_key(user_id, "huggingface") if gpt_provider == "huggingface" else None + + logger.info( + "[llm_text_gen] Mode | providers={} | models={} | env_models={} | strict_provider={} | strict_model={}", + provider_sequence, + model_sequence, + _parse_csv_env(os.getenv("TEXTGEN_AI_MODELS", "")), + pinned_provider, + len(model_sequence) == 1, + ) +<<<<<<< HEAD logger.info(f"[llm_text_gen][{flow_tag}] Using provider={gpt_provider}, model={model}") +======= + logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}") + emit_routing_event( + logger, + "text_route_selected", + user_id=user_id, + flow_type=flow_type, + provider_selected=gpt_provider, + model_selected=model, + env_provider=env_provider_raw or "auto", + fallback_count=0, + ) +>>>>>>> pr-416 # Map provider name to APIProvider enum (define at function scope for usage tracking) from models.subscription_models import APIProvider @@ -291,6 +458,13 @@ def llm_text_gen( estimated_output_tokens = int(input_tokens * 1.5) estimated_total_tokens = input_tokens + estimated_output_tokens + logger.info( + "[llm_text_gen][subscription_preflight] start | user_id={} | provider={} | tokens_requested={}", + user_id, + actual_provider_name or provider_enum.value, + estimated_total_tokens, + ) + # Check limits using sync method from pricing service (strict enforcement) can_proceed, message, usage_info = pricing_service.check_usage_limits( user_id=user_id, @@ -315,7 +489,14 @@ def llm_text_gen( 'usage_info': usage_info if usage_info else {} } raise HTTPException(status_code=429, detail=error_detail) - + + logger.info( + "[llm_text_gen][subscription_preflight] pass | user_id={} | provider={} | tokens_requested={}", + user_id, + actual_provider_name or provider_enum.value, + estimated_total_tokens, + ) + # Get current usage for limit checking only current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") usage = db.query(UsageSummary).filter( @@ -361,6 +542,7 @@ def llm_text_gen( else: system_instructions = system_prompt +<<<<<<< HEAD # HF behavior: fail fast on selected model; no intra-provider model fallback chain. hf_fallback_models: List[str] = [] @@ -463,23 +645,27 @@ def llm_text_gen( logger.info( f"[llm_text_gen][{flow_tag}] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}" ) +======= + # Generate response based on provider/model sequence + response_text = None + errors: List[str] = [] + + for provider_idx, provider_name in enumerate(provider_sequence): + candidate_models = _resolve_model_sequence(provider_name, preferred_hf_models) + for model_idx, candidate_model in enumerate(candidate_models): +>>>>>>> pr-416 try: - from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync - - # Estimate tokens - tokens_input = int(len(prompt.split()) * 1.3) - - # Calculate duration (mocking it since we didn't track start time explicitly in this function) - # Ideally we should track start_time at beginning of function - duration = 0.5 - - track_agent_usage_sync( + emit_routing_event( + logger, + "text_route_attempt", user_id=user_id, - model_name=model, - prompt=prompt, - response_text=response_text, - duration=duration + flow_type=flow_type, + provider_selected=provider_name, + model_selected=candidate_model, + provider_attempt=provider_idx + 1, + model_attempt=model_idx + 1, ) +<<<<<<< HEAD except Exception as usage_error: # Non-blocking: log error but don't fail the request @@ -535,6 +721,10 @@ def llm_text_gen( fallback_model = "openai/gpt-oss-120b" if fallback_provider == "google": +======= + + if provider_name == "google": +>>>>>>> pr-416 if json_struct: response_text = gemini_structured_json_response( prompt=prompt, @@ -543,7 +733,7 @@ def llm_text_gen( top_p=top_p, top_k=n, max_tokens=max_tokens, - system_prompt=system_instructions + system_prompt=system_instructions, ) else: response_text = gemini_text_response( @@ -552,22 +742,29 @@ def llm_text_gen( top_p=top_p, n=n, max_tokens=max_tokens, - system_prompt=system_instructions + system_prompt=system_instructions, ) - elif fallback_provider == "huggingface": + elif provider_name == "huggingface": + hf_api_key_current = get_tenant_api_key(user_id, "huggingface") if json_struct: response_text = huggingface_structured_json_response( prompt=prompt, schema=json_struct, +<<<<<<< HEAD model=fallback_model, fallback_models=hf_fallback_models, +======= + model=candidate_model, +>>>>>>> pr-416 temperature=temperature, max_tokens=max_tokens, - system_prompt=system_instructions + system_prompt=system_instructions, + api_key=hf_api_key_current, ) else: response_text = huggingface_text_response( prompt=prompt, +<<<<<<< HEAD model=fallback_model, fallback_models=hf_fallback_models, temperature=temperature, @@ -592,31 +789,37 @@ def llm_text_gen( prompt=prompt, model=fallback_model, fallback_models=None, +======= + model=candidate_model, +>>>>>>> pr-416 temperature=temperature, max_tokens=max_tokens, top_p=top_p, - system_prompt=system_instructions + system_prompt=system_instructions, + api_key=hf_api_key_current, ) - - # TRACK USAGE after successful fallback call + else: + raise RuntimeError(f"Unknown provider {provider_name}") + if response_text: +<<<<<<< HEAD logger.info( f"[llm_text_gen][{flow_tag}] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}" ) +======= + logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}") +>>>>>>> pr-416 try: from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync - - # Estimate tokens - tokens_input = int(len(prompt.split()) * 1.3) - track_agent_usage_sync( user_id=user_id, - model_name=fallback_model, + model_name=candidate_model, prompt=prompt, response_text=response_text, - duration=0.5 # Approximate duration + duration=0.5, ) except Exception as usage_error: +<<<<<<< HEAD logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True) return response_text @@ -626,6 +829,22 @@ def llm_text_gen( # CIRCUIT BREAKER: Stop immediately to prevent expensive API calls logger.error(f"[llm_text_gen][{flow_tag}] CIRCUIT BREAKER: Stopping to prevent expensive API calls.") raise RuntimeError("All LLM providers failed to generate a response.") +======= + logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True) + return response_text + except Exception as provider_error: + err = f"provider={provider_name},model={candidate_model},error={provider_error}" + errors.append(err) + logger.error("[llm_text_gen] Attempt failed: {}", err) + continue + + # strict provider mode: single configured provider should not switch + if pinned_provider and len(provider_sequence) == 1: + break + + logger.error("[llm_text_gen] CIRCUIT BREAKER: All configured provider/model attempts failed. {}", errors) + raise RuntimeError("All configured LLM provider/model attempts failed.") +>>>>>>> pr-416 except Exception as e: logger.error(f"[llm_text_gen][{flow_tag}] Error during text generation: {str(e)}") @@ -633,12 +852,21 @@ def llm_text_gen( def check_gpt_provider(gpt_provider: str) -> bool: """Check if the specified GPT provider is supported.""" +<<<<<<< HEAD supported_providers = ["google", "huggingface", "wavespeed"] return gpt_provider in supported_providers +======= + providers = [_normalize_provider(p) for p in _parse_csv_env(gpt_provider)] + if not providers: + providers = [_normalize_provider(gpt_provider)] + supported_providers = {"google", "huggingface"} + return all(p in supported_providers for p in providers if p) +>>>>>>> pr-416 -def get_api_key(gpt_provider: str) -> Optional[str]: - """Get API key for the specified provider.""" +def get_api_key(gpt_provider: str, user_id: Optional[str] = None) -> Optional[str]: + """Get API key for the specified provider, preferring tenant-scoped keys.""" try: +<<<<<<< HEAD api_key_manager = APIKeyManager() provider_mapping = { "google": "gemini", @@ -648,6 +876,10 @@ def get_api_key(gpt_provider: str) -> Optional[str]: mapped_provider = provider_mapping.get(gpt_provider, gpt_provider) return api_key_manager.get_api_key(mapped_provider) +======= + return get_tenant_api_key(user_id, gpt_provider) +>>>>>>> pr-416 except Exception as e: logger.error(f"[get_api_key] Error getting API key for {gpt_provider}: {str(e)}") - return None + return None + diff --git a/backend/services/llm_providers/routing_observability.py b/backend/services/llm_providers/routing_observability.py new file mode 100644 index 00000000..768432ef --- /dev/null +++ b/backend/services/llm_providers/routing_observability.py @@ -0,0 +1,22 @@ +"""Structured observability helpers for LLM routing decisions.""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any, Dict, Optional + + +def _mask_user_id(user_id: Optional[str]) -> str: + if not user_id: + return "anonymous" + return hashlib.sha256(str(user_id).encode("utf-8")).hexdigest()[:12] + + +def emit_routing_event(logger, event: str, *, user_id: Optional[str] = None, **fields: Any) -> None: + payload: Dict[str, Any] = { + "event": event, + "tenant": _mask_user_id(user_id), + **fields, + } + logger.info("[llm_routing] {}", json.dumps(payload, sort_keys=True, default=str)) diff --git a/backend/services/llm_providers/tenant_provider_config.py b/backend/services/llm_providers/tenant_provider_config.py new file mode 100644 index 00000000..dc637e91 --- /dev/null +++ b/backend/services/llm_providers/tenant_provider_config.py @@ -0,0 +1,83 @@ +"""Tenant-aware provider configuration and API key resolution for LLM providers.""" + +from __future__ import annotations + +import os +import time +from typing import Dict, Optional + +from loguru import logger + +from services.database import get_session_for_user +from models.onboarding import APIKey, OnboardingSession + +_PROVIDER_KEY_MAP = { + "google": "gemini", + "gemini": "gemini", + "huggingface": "hf_token", + "hf": "hf_token", + "hf_response_api": "hf_token", +} + +_PROVIDER_ENV_MAP = { + "gemini": "GEMINI_API_KEY", + "hf_token": "HF_TOKEN", +} + +_CACHE_TTL_SECONDS = int(os.getenv("TENANT_PROVIDER_CACHE_TTL", "60")) +_cache: Dict[str, tuple[float, Optional[str]]] = {} + + +def _cache_key(user_id: Optional[str], provider_key: str) -> str: + return f"{user_id or 'global'}::{provider_key}" + + +def _normalize_provider(provider: str) -> str: + return _PROVIDER_KEY_MAP.get((provider or "").lower(), (provider or "").lower()) + + +def get_tenant_api_key(user_id: Optional[str], provider: str) -> Optional[str]: + provider_key = _normalize_provider(provider) + ck = _cache_key(user_id, provider_key) + cached = _cache.get(ck) + now = time.time() + if cached and (now - cached[0]) < _CACHE_TTL_SECONDS: + return cached[1] + + key: Optional[str] = None + if user_id: + db = None + try: + db = get_session_for_user(user_id) + if db: + record = ( + db.query(APIKey.key) + .join(OnboardingSession, APIKey.session_id == OnboardingSession.id) + .filter(OnboardingSession.user_id == user_id, APIKey.provider == provider_key) + .order_by(APIKey.updated_at.desc()) + .first() + ) + if record and record[0]: + key = record[0] + except Exception as exc: + logger.debug("tenant api-key lookup failed for user={}, provider={}: {}", user_id, provider_key, exc) + finally: + if db: + db.close() + + if not key: + env_var = _PROVIDER_ENV_MAP.get(provider_key) + if env_var: + key = os.getenv(env_var) + + _cache[ck] = (now, key) + return key + + +def get_available_text_providers(user_id: Optional[str]) -> list[str]: + providers = [] + if get_tenant_api_key(user_id, "gemini"): + providers.append("google") + if get_tenant_api_key(user_id, "huggingface"): + providers.append("huggingface") + return providers diff --git a/backend/services/product_marketing/personalization_service.py b/backend/services/product_marketing/personalization_service.py index 4099eaf3..c1107ea6 100644 --- a/backend/services/product_marketing/personalization_service.py +++ b/backend/services/product_marketing/personalization_service.py @@ -10,6 +10,20 @@ from services.database import get_session_for_user from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService +def _ensure_dict(value: Any) -> Dict[str, Any]: + """Safely coerce arbitrary payload shape into a dictionary.""" + return value if isinstance(value, dict) else {} + + +def _ensure_list(value: Any) -> List[Any]: + """Safely coerce arbitrary payload shape into a list.""" + if isinstance(value, list): + return value + if value is None: + return [] + return [value] + + class PersonalizationService: """ Service for extracting user preferences from onboarding data @@ -52,6 +66,7 @@ class PersonalizationService: return self._get_default_preferences() integration_service = OnboardingDataIntegrationService() +<<<<<<< HEAD integrated_data = integration_service.get_integrated_data_sync(user_id, db) if not isinstance(integrated_data, dict): logger.warning( @@ -65,15 +80,28 @@ class PersonalizationService: f"[Personalization] Canonical profile is non-dict for user {user_id}; using defaults" ) canonical_profile = {} +======= + integrated_data_raw = integration_service.get_integrated_data_sync(user_id, db) + integrated_data = _ensure_dict(integrated_data_raw) + canonical_profile = _ensure_dict(integrated_data.get('canonical_profile')) +>>>>>>> pr-416 # Map strictly from Canonical Profile preferences = { "industry": canonical_profile.get("industry"), +<<<<<<< HEAD "target_audience": self._as_dict(canonical_profile.get("target_audience", {})), "platform_preferences": self._as_list(canonical_profile.get("platform_preferences", [])), "content_preferences": self._as_list(canonical_profile.get("content_types", [])), "style_preferences": self._as_dict(canonical_profile.get("visual_style", {})), "brand_colors": self._as_list(canonical_profile.get("brand_colors", [])), +======= + "target_audience": _ensure_dict(canonical_profile.get("target_audience")), + "platform_preferences": _ensure_list(canonical_profile.get("platform_preferences")), + "content_preferences": _ensure_list(canonical_profile.get("content_types")), + "style_preferences": _ensure_dict(canonical_profile.get("visual_style")), + "brand_colors": _ensure_list(canonical_profile.get("brand_colors")), +>>>>>>> pr-416 "recommended_templates": [], "recommended_channels": [], "writing_style": { @@ -82,7 +110,11 @@ class PersonalizationService: "complexity": canonical_profile.get("writing_complexity", "intermediate"), "engagement_level": canonical_profile.get("writing_engagement", "moderate"), }, +<<<<<<< HEAD "brand_values": self._as_list(canonical_profile.get("brand_values", [])), +======= + "brand_values": _ensure_list(canonical_profile.get("brand_values")), +>>>>>>> pr-416 } # Ensure target_audience structure @@ -118,7 +150,7 @@ class PersonalizationService: if not preferences["recommended_channels"]: preferences["recommended_channels"] = self._get_recommended_channels( preferences.get("industry"), - preferences.get("target_audience", {}).get("demographics", []) + _ensure_list(_ensure_dict(preferences.get("target_audience")).get("demographics")) ) logger.info(f"[Personalization] Extracted preferences for user {user_id}: industry={preferences.get('industry')}")