"feat:enhance-podcast-topic-ai"
This commit is contained in:
@@ -108,6 +108,46 @@ def get_user_db_path(user_id: str) -> str:
|
||||
# Default to specific for new databases
|
||||
return specific_db_path
|
||||
|
||||
|
||||
def has_onboarding_session(user_id: str, db: Optional[Session] = None) -> bool:
|
||||
"""Return True when at least one onboarding session exists for the given user."""
|
||||
if not user_id:
|
||||
return False
|
||||
|
||||
db_session = db
|
||||
close_db = False
|
||||
|
||||
try:
|
||||
if db_session is None:
|
||||
# Avoid opening/creating a DB for non-existent user workspace.
|
||||
db_path = get_user_db_path(user_id)
|
||||
if not os.path.exists(db_path):
|
||||
return False
|
||||
db_session = get_session_for_user(user_id)
|
||||
close_db = True
|
||||
|
||||
if not db_session:
|
||||
return False
|
||||
|
||||
from models.onboarding import OnboardingSession
|
||||
|
||||
onboarding_row = (
|
||||
db_session.query(OnboardingSession.id)
|
||||
.filter(OnboardingSession.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
return onboarding_row is not None
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed onboarding session existence check for user {user_id}: {e}")
|
||||
return False
|
||||
finally:
|
||||
if close_db and db_session:
|
||||
try:
|
||||
db_session.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_all_user_ids() -> List[str]:
|
||||
"""
|
||||
Discover all user IDs by scanning workspace directories.
|
||||
|
||||
@@ -23,9 +23,15 @@ def track_agent_usage_sync(user_id: str, model_name: str, prompt: str, response_
|
||||
provider_enum = APIProvider.GEMINI
|
||||
actual_provider_name = "gemini"
|
||||
elif "gpt" in model_lower or "openai" in model_lower or "mistral" in model_lower:
|
||||
# HuggingFace/Mistral often mapped to gpt-oss or mistral
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
# Check if it's WaveSpeed vs HuggingFace based on context or model naming
|
||||
# WaveSpeed models don't have :cerebras suffix, HF models do
|
||||
if ":cerebras" in model_name.lower() or "huggingface" in model_name.lower():
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
else:
|
||||
# Assume WaveSpeed for gpt models without provider suffix
|
||||
provider_enum = APIProvider.WAVESPEED
|
||||
actual_provider_name = "wavespeed"
|
||||
elif "claude" in model_lower or "anthropic" in model_lower:
|
||||
provider_enum = APIProvider.ANTHROPIC
|
||||
actual_provider_name = "anthropic"
|
||||
|
||||
@@ -340,6 +340,7 @@ class BaseALwrityAgent(ABC):
|
||||
prompt=prompt,
|
||||
user_id=self.user_id,
|
||||
preferred_hf_models=LOW_COST_REMOTE_MODELS,
|
||||
flow_type="sif_agent",
|
||||
),
|
||||
)
|
||||
logger.warning(
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from loguru import logger
|
||||
from .base import SIFBaseAgent, TXTAI_AVAILABLE, Agent
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, TaskProposal
|
||||
from services.database import has_onboarding_session
|
||||
|
||||
try:
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
@@ -22,11 +23,16 @@ class CompetitorResponseAgent(BaseALwrityAgent):
|
||||
super().__init__(user_id, "competitor_analyst", shared_llm_name, llm, **kwargs)
|
||||
|
||||
self.sif_service = None
|
||||
if SIF_AVAILABLE:
|
||||
if SIF_AVAILABLE and has_onboarding_session(user_id):
|
||||
try:
|
||||
self.sif_service = SIFIntegrationService(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SIF service for CompetitorResponseAgent: {e}")
|
||||
elif SIF_AVAILABLE:
|
||||
logger.debug(
|
||||
"Skipping SIF service initialization for CompetitorResponseAgent user {}: no onboarding session",
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _create_txtai_agent(self):
|
||||
"""Create a specialized txtai Agent for competitor analysis."""
|
||||
|
||||
@@ -8,6 +8,7 @@ from .base import SIFBaseAgent, TXTAI_AVAILABLE, Agent
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, TaskProposal
|
||||
from services.seo_tools.content_strategy_service import ContentStrategyService
|
||||
from services.analytics import PlatformAnalyticsService
|
||||
from services.database import has_onboarding_session
|
||||
|
||||
try:
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
@@ -26,11 +27,16 @@ class ContentStrategyAgent(BaseALwrityAgent):
|
||||
|
||||
self.sif_service = None
|
||||
self.content_strategy_service = ContentStrategyService()
|
||||
if SIF_AVAILABLE:
|
||||
if SIF_AVAILABLE and has_onboarding_session(user_id):
|
||||
try:
|
||||
self.sif_service = SIFIntegrationService(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SIF service for ContentStrategyAgent: {e}")
|
||||
elif SIF_AVAILABLE:
|
||||
logger.debug(
|
||||
"Skipping SIF service initialization for ContentStrategyAgent user {}: no onboarding session",
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _create_txtai_agent(self):
|
||||
"""Create a specialized txtai Agent for content strategy with tools."""
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from loguru import logger
|
||||
from .base import SIFBaseAgent, TXTAI_AVAILABLE, Agent
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, TaskProposal
|
||||
from services.database import has_onboarding_session
|
||||
|
||||
try:
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
@@ -22,11 +23,16 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
super().__init__(user_id, "seo_specialist", shared_llm_name, llm, **kwargs)
|
||||
|
||||
self.sif_service = None
|
||||
if SIF_AVAILABLE:
|
||||
if SIF_AVAILABLE and has_onboarding_session(user_id):
|
||||
try:
|
||||
self.sif_service = SIFIntegrationService(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SIF service for SEOOptimizationAgent: {e}")
|
||||
elif SIF_AVAILABLE:
|
||||
logger.debug(
|
||||
"Skipping SIF service initialization for SEOOptimizationAgent user {}: no onboarding session",
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _create_txtai_agent(self):
|
||||
"""Create a specialized txtai Agent for SEO optimization."""
|
||||
|
||||
@@ -6,6 +6,7 @@ from datetime import datetime
|
||||
from loguru import logger
|
||||
from .base import SIFBaseAgent, TXTAI_AVAILABLE, Agent
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, TaskProposal
|
||||
from services.database import has_onboarding_session
|
||||
|
||||
try:
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
@@ -22,11 +23,16 @@ class SocialAmplificationAgent(BaseALwrityAgent):
|
||||
super().__init__(user_id, "social_media_manager", shared_llm_name, llm, **kwargs)
|
||||
|
||||
self.sif_service = None
|
||||
if SIF_AVAILABLE:
|
||||
if SIF_AVAILABLE and has_onboarding_session(user_id):
|
||||
try:
|
||||
self.sif_service = SIFIntegrationService(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to initialize SIF service for SocialAmplificationAgent: {e}")
|
||||
elif SIF_AVAILABLE:
|
||||
logger.debug(
|
||||
"Skipping SIF service initialization for SocialAmplificationAgent user {}: no onboarding session",
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _create_txtai_agent(self):
|
||||
"""Create a specialized txtai Agent for social media."""
|
||||
|
||||
@@ -13,6 +13,7 @@ from datetime import datetime, timedelta
|
||||
from dataclasses import dataclass, asdict
|
||||
from loguru import logger
|
||||
|
||||
from services.database import has_onboarding_session
|
||||
from ..txtai_service import TxtaiIntelligenceService
|
||||
from ..semantic_cache import semantic_cache_manager
|
||||
from ..sif_integration import SIFIntegrationService
|
||||
@@ -74,9 +75,15 @@ class RealTimeSemanticMonitor:
|
||||
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.intelligence_service = TxtaiIntelligenceService(user_id)
|
||||
self.cache_manager = semantic_cache_manager
|
||||
self.sif_service = SIFIntegrationService(user_id)
|
||||
self.sif_enabled = has_onboarding_session(user_id)
|
||||
self.intelligence_service = TxtaiIntelligenceService(user_id) if self.sif_enabled else None
|
||||
self.sif_service = SIFIntegrationService(user_id) if self.sif_enabled else None
|
||||
if not self.sif_enabled:
|
||||
logger.info(
|
||||
"Skipping semantic monitor SIF initialization for user {}: no onboarding session found",
|
||||
user_id,
|
||||
)
|
||||
|
||||
# Initialize monitoring agents (lazy initialization to avoid circular imports)
|
||||
self.strategy_agent = None
|
||||
@@ -239,6 +246,9 @@ class RealTimeSemanticMonitor:
|
||||
async def _check_semantic_health(self) -> List[SemanticHealthMetric]:
|
||||
"""Check overall semantic health of user's content."""
|
||||
metrics = []
|
||||
|
||||
if not self.sif_enabled or not self.sif_service:
|
||||
return metrics
|
||||
|
||||
try:
|
||||
# Get current semantic insights
|
||||
@@ -301,6 +311,8 @@ class RealTimeSemanticMonitor:
|
||||
async def _monitor_competitors(self) -> List[CompetitorSemanticSnapshot]:
|
||||
"""Monitor competitor semantic positioning."""
|
||||
snapshots = []
|
||||
if not self.sif_enabled or not self.intelligence_service:
|
||||
return snapshots
|
||||
try:
|
||||
# 1. Get competitors from SIF integration
|
||||
# We assume SIFIntegrationService has methods to get competitor data or we query index
|
||||
@@ -370,6 +382,9 @@ class RealTimeSemanticMonitor:
|
||||
async def _analyze_content_performance(self) -> List[ContentSemanticInsight]:
|
||||
"""Analyze content performance and identify insights using SIF Agents."""
|
||||
insights = []
|
||||
|
||||
if not self.sif_enabled or not self.sif_service:
|
||||
return insights
|
||||
|
||||
try:
|
||||
current_time = datetime.now()
|
||||
|
||||
@@ -34,7 +34,12 @@ class SharedLLMWrapper:
|
||||
try:
|
||||
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
|
||||
# but we could map them if needed.
|
||||
return llm_text_gen(prompt, user_id=self.user_id)
|
||||
return llm_text_gen(
|
||||
prompt,
|
||||
user_id=self.user_id,
|
||||
preferred_hf_models=LOW_COST_SHARED_REMOTE_MODELS,
|
||||
flow_type="sif_agent",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"SharedLLMWrapper failed to generate text: {e}")
|
||||
return f"[ERROR: Shared LLM generation failed for user {self.user_id}]"
|
||||
@@ -44,6 +49,12 @@ class SharedLLMWrapper:
|
||||
|
||||
_local_llm_cache = {}
|
||||
|
||||
LOW_COST_SHARED_REMOTE_MODELS = [
|
||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
||||
]
|
||||
|
||||
LOCAL_LLM_FALLBACKS = [
|
||||
"Qwen/Qwen2.5-1.5B-Instruct",
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
|
||||
@@ -12,7 +12,7 @@ from datetime import datetime
|
||||
from sqlalchemy import select, desc
|
||||
import json
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_session_for_user, has_onboarding_session
|
||||
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
|
||||
|
||||
# Import existing SIF components
|
||||
@@ -1070,8 +1070,14 @@ class SIFIntegrationAPI:
|
||||
def __init__(self):
|
||||
self.services: Dict[str, SIFIntegrationService] = {}
|
||||
|
||||
def get_service(self, user_id: str) -> SIFIntegrationService:
|
||||
def get_service(self, user_id: str) -> Optional[SIFIntegrationService]:
|
||||
"""Get or create SIF service for a user."""
|
||||
if not has_onboarding_session(user_id):
|
||||
logger.debug(
|
||||
"Skipping SIF service creation for user {} via SIFIntegrationAPI: no onboarding session",
|
||||
user_id,
|
||||
)
|
||||
return None
|
||||
if user_id not in self.services:
|
||||
self.services[user_id] = SIFIntegrationService(user_id)
|
||||
return self.services[user_id]
|
||||
@@ -1079,11 +1085,25 @@ class SIFIntegrationAPI:
|
||||
async def get_semantic_insights_with_cache(self, user_id: str, website_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get semantic insights with caching metadata."""
|
||||
service = self.get_service(user_id)
|
||||
if not service:
|
||||
return {
|
||||
"source": "skipped",
|
||||
"reason": "no_onboarding_session",
|
||||
"insights": {},
|
||||
}
|
||||
return await service.get_semantic_insights(website_data)
|
||||
|
||||
async def get_cache_performance(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get cache performance metrics for a user."""
|
||||
service = self.get_service(user_id)
|
||||
if not service:
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"cache_enabled": False,
|
||||
"performance": {},
|
||||
"reason": "no_onboarding_session",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
stats = service.get_cache_performance_stats()
|
||||
|
||||
return {
|
||||
@@ -1096,6 +1116,13 @@ class SIFIntegrationAPI:
|
||||
async def invalidate_user_cache(self, user_id: str, reason: str = "api_request") -> Dict[str, Any]:
|
||||
"""Invalidate cache for a specific user."""
|
||||
service = self.get_service(user_id)
|
||||
if not service:
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"success": False,
|
||||
"reason": "no_onboarding_session",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
success = await service.invalidate_user_cache(reason)
|
||||
|
||||
return {
|
||||
|
||||
@@ -83,6 +83,7 @@ from utils.logger_utils import get_service_logger
|
||||
logger = get_service_logger("gemini_provider")
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
@@ -114,7 +115,27 @@ def get_gemini_api_key() -> str:
|
||||
|
||||
return api_key
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def _is_non_retryable_gemini_error(exc: Exception) -> bool:
|
||||
"""Skip retries for deterministic quota exhaustion and auth errors."""
|
||||
msg = str(exc).lower()
|
||||
return (
|
||||
"resource_exhausted" in msg
|
||||
or "quota exceeded" in msg
|
||||
or "free_tier" in msg
|
||||
or "requestsperday" in msg
|
||||
or "authentication" in msg
|
||||
or "permission denied" in msg
|
||||
or "invalid api key" in msg
|
||||
)
|
||||
|
||||
def _should_retry_gemini_error(exc: Exception) -> bool:
|
||||
return not _is_non_retryable_gemini_error(exc)
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_gemini_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
)
|
||||
def gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_prompt):
|
||||
"""
|
||||
Generate text response using Google's Gemini Pro model.
|
||||
@@ -182,7 +203,7 @@ def gemini_text_response(prompt, temperature, top_p, n, max_tokens, system_promp
|
||||
#logger.info(f"Number of Token in Prompt Sent: {model.count_tokens(prompt)}")
|
||||
return response.text
|
||||
except Exception as err:
|
||||
logger.error(f"Failed to get response from Gemini: {err}. Retrying.")
|
||||
logger.error(f"Failed to get response from Gemini: {err}")
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Dict, Any
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -76,6 +76,7 @@ logger = get_service_logger("huggingface_provider")
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
@@ -90,10 +91,10 @@ except ImportError:
|
||||
logger.warn("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
HF_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b:groq",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:groq",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:groq",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
"openai/gpt-oss-120b:cerebras",
|
||||
"moonshotai/Kimi-K2-Instruct-0905:cerebras",
|
||||
"meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
]
|
||||
|
||||
|
||||
@@ -102,7 +103,7 @@ def _candidate_model_variants(model: str):
|
||||
if not model:
|
||||
return
|
||||
|
||||
# Try configured model first (supports provider suffixes like ":groq")
|
||||
# Try configured model first (supports provider suffixes like ":cerebras")
|
||||
yield model
|
||||
|
||||
# Fallback to base repo id when provider suffix is not recognized by the router
|
||||
@@ -112,8 +113,13 @@ def _candidate_model_variants(model: str):
|
||||
yield base_model
|
||||
|
||||
|
||||
def _fallback_model_sequence(model: str):
|
||||
sequence = [model] + HF_FALLBACK_MODELS
|
||||
def _fallback_model_sequence(model: str, fallback_models: Optional[List[str]] = None):
|
||||
# IMPORTANT: Do not apply implicit global fallback chains.
|
||||
# Callers must explicitly provide fallback_models when they want multi-model retries.
|
||||
if fallback_models:
|
||||
sequence = [model] + fallback_models
|
||||
else:
|
||||
sequence = [model]
|
||||
seen = set()
|
||||
for preferred_model in sequence:
|
||||
for candidate in _candidate_model_variants(preferred_model):
|
||||
@@ -121,6 +127,57 @@ def _fallback_model_sequence(model: str):
|
||||
seen.add(candidate)
|
||||
yield candidate
|
||||
|
||||
|
||||
def _is_non_retryable_hf_error(exc: Exception) -> bool:
|
||||
"""Skip retries for deterministic HF failures (e.g., unknown model ids, billing)."""
|
||||
msg = str(exc).lower()
|
||||
status = getattr(exc, "status_code", None)
|
||||
|
||||
# Non-retryable errors
|
||||
if isinstance(exc, NotFoundError) or "not found" in msg or "404" in msg:
|
||||
return True
|
||||
if status == 402 or "402" in msg or "depleted" in msg or "credits" in msg:
|
||||
return True
|
||||
if status == 401 or "unauthorized" in msg or "401" in msg:
|
||||
return True
|
||||
if status == 403 or "forbidden" in msg or "403" in msg:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _should_retry_hf_error(exc: Exception) -> bool:
|
||||
return not _is_non_retryable_hf_error(exc)
|
||||
|
||||
|
||||
def _classify_hf_error(exc: Exception) -> str:
|
||||
"""Classify HF failures for actionable logs."""
|
||||
msg = str(exc).lower()
|
||||
if any(token in msg for token in ["insufficient", "balance", "quota", "billing", "payment", "402"]):
|
||||
return "billing_or_quota"
|
||||
if "unauthorized" in msg or "forbidden" in msg or "401" in msg or "403" in msg:
|
||||
return "auth_or_permission"
|
||||
if "not found" in msg or "404" in msg:
|
||||
return "model_not_found"
|
||||
return "unknown"
|
||||
|
||||
|
||||
def _hf_error_details(exc: Exception) -> str:
|
||||
"""Return compact, actionable exception details for logs."""
|
||||
status = getattr(exc, "status_code", None)
|
||||
err_type = type(exc).__name__
|
||||
message = str(exc)
|
||||
raw_body = getattr(exc, "body", None)
|
||||
details = f"type={err_type}"
|
||||
if status is not None:
|
||||
details += f", status={status}"
|
||||
if message:
|
||||
details += f", message={message}"
|
||||
if raw_body:
|
||||
details += f", body={raw_body}"
|
||||
details += f", repr={repr(exc)}"
|
||||
return details
|
||||
|
||||
def get_huggingface_api_key() -> str:
|
||||
"""Get Hugging Face API key with proper error handling."""
|
||||
api_key = os.getenv('HF_TOKEN')
|
||||
@@ -137,10 +194,15 @@ def get_huggingface_api_key() -> str:
|
||||
|
||||
return api_key
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_hf_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
)
|
||||
def huggingface_text_response(
|
||||
prompt: str,
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
model: str = "openai/gpt-oss-120b:cerebras",
|
||||
fallback_models: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
top_p: float = 0.9,
|
||||
@@ -175,7 +237,7 @@ def huggingface_text_response(
|
||||
Example:
|
||||
result = huggingface_text_response(
|
||||
prompt="Write a blog post about AI",
|
||||
model="openai/gpt-oss-120b:groq",
|
||||
model="openai/gpt-oss-120b:cerebras",
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
system_prompt="You are a professional content writer."
|
||||
@@ -194,7 +256,7 @@ def huggingface_text_response(
|
||||
|
||||
# Initialize Hugging Face client
|
||||
client = OpenAI(
|
||||
base_url=f"https://router.huggingface.co/hf/v1",
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ Hugging Face client initialized for text response")
|
||||
@@ -231,27 +293,14 @@ def huggingface_text_response(
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
if candidate_model != model:
|
||||
logger.warning("HF text generation switched to fallback model: {}", candidate_model)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
logger.warning("HF model not found: {}. Trying fallback model.", candidate_model)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
raise last_error or Exception("Hugging Face text generation failed: all fallback models failed")
|
||||
# Call exactly the requested model; no retries, no fallbacks, no variants
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
# Extract text from response
|
||||
generated_text = response.choices[0].message.content
|
||||
@@ -267,14 +316,31 @@ def huggingface_text_response(
|
||||
return generated_text
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Hugging Face text generation failed: {str(e)}")
|
||||
error_class = _classify_hf_error(e)
|
||||
error_details = _hf_error_details(e)
|
||||
logger.error(f"❌ Hugging Face text generation failed: {error_details}")
|
||||
|
||||
# Extra diagnostics: try to capture raw response if available
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
logger.error(f"🔍 HF Error Diagnostics:")
|
||||
logger.error(f" - Status: {e.response.status_code}")
|
||||
logger.error(f" - Headers: {dict(e.response.headers)}")
|
||||
try:
|
||||
body_json = e.response.json()
|
||||
logger.error(f" - Body JSON: {json.dumps(body_json, indent=2)}")
|
||||
except Exception:
|
||||
logger.error(f" - Body Raw: {e.response.text[:1000]}")
|
||||
else:
|
||||
logger.error(f"🔍 No HTTP response attached to exception object.")
|
||||
|
||||
raise Exception(f"Hugging Face text generation failed: {str(e)}")
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def huggingface_structured_json_response(
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
model: str = "openai/gpt-oss-120b:groq",
|
||||
model: str = "openai/gpt-oss-120b:cerebras",
|
||||
fallback_models: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 8192,
|
||||
system_prompt: Optional[str] = None
|
||||
@@ -338,7 +404,7 @@ def huggingface_structured_json_response(
|
||||
# Initialize OpenAI client with Hugging Face base URL
|
||||
# Use standard Inference API endpoint
|
||||
client = OpenAI(
|
||||
base_url=f"https://router.huggingface.co/hf/v1",
|
||||
base_url="https://router.huggingface.co/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ Hugging Face client initialized for structured JSON response")
|
||||
@@ -387,7 +453,7 @@ def huggingface_structured_json_response(
|
||||
try:
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
@@ -444,7 +510,7 @@ def huggingface_structured_json_response(
|
||||
logger.info("Retrying without response_format...")
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
|
||||
@@ -22,6 +22,8 @@ def llm_text_gen(
|
||||
json_struct: Optional[Dict[str, Any]] = None,
|
||||
user_id: str = None,
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
preferred_provider: Optional[str] = None,
|
||||
flow_type: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
@@ -39,12 +41,16 @@ def llm_text_gen(
|
||||
RuntimeError: If subscription limits are exceeded or user_id is missing.
|
||||
"""
|
||||
try:
|
||||
logger.info("[llm_text_gen] Starting text generation")
|
||||
resolved_flow_type = flow_type or ("sif_agent" if preferred_hf_models else "premium_tool")
|
||||
flow_tag = f"flow_type={resolved_flow_type}"
|
||||
subscription_preflight_completed = False
|
||||
|
||||
logger.info(f"[llm_text_gen][{flow_tag}] Starting text generation")
|
||||
logger.debug(f"[llm_text_gen] Prompt length: {len(prompt)} characters")
|
||||
|
||||
# Set default values for LLM parameters
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
gpt_provider = "huggingface" # Default to premium HF route for ALwrity AI tools
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
temperature = 0.7
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
@@ -55,12 +61,87 @@ def llm_text_gen(
|
||||
|
||||
# Check for GPT_PROVIDER environment variable
|
||||
env_provider = os.getenv('GPT_PROVIDER', '').lower()
|
||||
if env_provider in ['gemini', 'google']:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif env_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
provider_list = [p.strip() for p in env_provider.split(',') if p.strip()]
|
||||
|
||||
# Determine if we're in strict mode (single provider) or fallback mode (multiple providers)
|
||||
strict_provider_mode = len(provider_list) == 1
|
||||
|
||||
if provider_list:
|
||||
# Use first provider as primary
|
||||
primary_provider = provider_list[0]
|
||||
if primary_provider in ['gemini', 'google']:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif primary_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
elif primary_provider == 'wavespeed':
|
||||
gpt_provider = "wavespeed"
|
||||
model = "openai/gpt-oss-120b"
|
||||
else:
|
||||
# Auto-detect mode
|
||||
strict_provider_mode = False # Auto-detect allows fallbacks
|
||||
gpt_provider = None
|
||||
model = None
|
||||
|
||||
# Explicit per-call provider override (used by tool-specific flows like podcast maker)
|
||||
if preferred_provider:
|
||||
preferred_providers = [p.strip() for p in preferred_provider.split(',') if p.strip()]
|
||||
# If explicit provider is set, it's strict mode (no cross-provider fallbacks)
|
||||
strict_provider_mode = len(preferred_providers) == 1
|
||||
|
||||
primary_provider = preferred_providers[0]
|
||||
if primary_provider in ['gemini', 'google']:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif primary_provider in ['hf_response_api', 'huggingface', 'hf']:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
elif primary_provider == 'wavespeed':
|
||||
gpt_provider = "wavespeed"
|
||||
model = "openai/gpt-oss-120b"
|
||||
|
||||
# Handle TEXTGEN_AI_MODELS for model selection
|
||||
textgen_models_env = os.getenv('TEXTGEN_AI_MODELS', '').strip()
|
||||
model_list = [m.strip() for m in textgen_models_env.split(',') if m.strip()] if textgen_models_env else []
|
||||
strict_model_mode = len(model_list) == 1
|
||||
|
||||
# Map model names to actual provider models
|
||||
if model_list:
|
||||
if gpt_provider == "huggingface":
|
||||
# Handle both short names and full model names
|
||||
model_mapping = {
|
||||
"gpt-oss": "openai/gpt-oss-120b:cerebras",
|
||||
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
|
||||
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"llama": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:cerebras"
|
||||
}
|
||||
# If model name contains "/", assume it's already a full model name
|
||||
if "/" in model_list[0]:
|
||||
model = model_list[0]
|
||||
else:
|
||||
model = model_mapping.get(model_list[0], model_list[0])
|
||||
elif gpt_provider == "wavespeed":
|
||||
# Handle both short names and full model names
|
||||
model_mapping = {
|
||||
"gpt-oss": "openai/gpt-oss-120b",
|
||||
"gpt-oss-120b": "openai/gpt-oss-120b",
|
||||
"mistral": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"llama": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct"
|
||||
}
|
||||
# If model name contains "/", assume it's already a full model name
|
||||
if "/" in model_list[0]:
|
||||
model = model_list[0]
|
||||
else:
|
||||
model = model_mapping.get(model_list[0], model_list[0])
|
||||
elif gpt_provider == "google":
|
||||
model = "gemini-2.0-flash-001" # Google has fewer options
|
||||
|
||||
# Default blog characteristics
|
||||
blog_tone = "Professional"
|
||||
@@ -77,42 +158,89 @@ def llm_text_gen(
|
||||
available_providers.append("google")
|
||||
if api_key_manager.get_api_key("hf_token"):
|
||||
available_providers.append("huggingface")
|
||||
if api_key_manager.get_api_key("wavespeed"):
|
||||
available_providers.append("wavespeed")
|
||||
logger.info(
|
||||
f"[llm_text_gen][{flow_tag}] Provider preflight: env_provider='{env_provider or 'auto'}', "
|
||||
f"provider_list={provider_list}, strict_provider_mode={strict_provider_mode}, "
|
||||
f"available_providers={available_providers}, preferred_provider={preferred_provider or 'none'}"
|
||||
)
|
||||
|
||||
if model_list:
|
||||
logger.info(
|
||||
f"[llm_text_gen][{flow_tag}] Model configuration: model_list={model_list}, "
|
||||
f"strict_model_mode={strict_model_mode}"
|
||||
)
|
||||
|
||||
# If no environment variable set, auto-detect based on available keys
|
||||
if not env_provider:
|
||||
# Prefer Google Gemini if available, otherwise use Hugging Face
|
||||
if "google" in available_providers:
|
||||
if preferred_provider:
|
||||
# Respect explicit per-call preference if the provider key exists
|
||||
if gpt_provider not in available_providers:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] Preferred provider {gpt_provider} unavailable, falling back to available providers"
|
||||
)
|
||||
if "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
elif "wavespeed" in available_providers:
|
||||
gpt_provider = "wavespeed"
|
||||
model = "openai/gpt-oss-120b"
|
||||
elif "google" in available_providers:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
else:
|
||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
||||
elif preferred_hf_models and "huggingface" in available_providers:
|
||||
# Low-cost SIF/agent flows pass preferred_hf_models; route directly to HF.
|
||||
gpt_provider = "huggingface"
|
||||
model = preferred_hf_models[0]
|
||||
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
|
||||
elif "google" in available_providers:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
elif "wavespeed" in available_providers:
|
||||
gpt_provider = "wavespeed"
|
||||
model = "openai/gpt-oss-120b"
|
||||
else:
|
||||
logger.error("[llm_text_gen] No API keys found for supported providers.")
|
||||
raise RuntimeError("No LLM API keys configured. Configure GEMINI_API_KEY or HF_TOKEN to enable AI responses.")
|
||||
else:
|
||||
# Environment variable was set, validate it's supported
|
||||
if gpt_provider not in available_providers:
|
||||
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
|
||||
if "google" in available_providers:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
if strict_provider_mode:
|
||||
# Strict mode: fail if specified provider not available
|
||||
raise RuntimeError(f"Provider {gpt_provider} not available. Available: {available_providers}")
|
||||
else:
|
||||
raise RuntimeError("No supported providers available.")
|
||||
# Fallback mode: try other providers
|
||||
logger.warning(f"[llm_text_gen] Provider {gpt_provider} not available, falling back to available providers")
|
||||
if "google" in available_providers:
|
||||
gpt_provider = "google"
|
||||
model = "gemini-2.0-flash-001"
|
||||
elif "huggingface" in available_providers:
|
||||
gpt_provider = "huggingface"
|
||||
model = "openai/gpt-oss-120b:cerebras"
|
||||
elif "wavespeed" in available_providers:
|
||||
gpt_provider = "wavespeed"
|
||||
model = "openai/gpt-oss-120b"
|
||||
else:
|
||||
raise RuntimeError("No supported providers available.")
|
||||
|
||||
if gpt_provider == "huggingface" and preferred_hf_models:
|
||||
model = preferred_hf_models[0]
|
||||
logger.info(f"[llm_text_gen] Using preferred low-cost HF model: {model}")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Using provider: {gpt_provider}, model: {model}")
|
||||
logger.info(f"[llm_text_gen][{flow_tag}] Using provider={gpt_provider}, model={model}")
|
||||
|
||||
# Map provider name to APIProvider enum (define at function scope for usage tracking)
|
||||
from models.subscription_models import APIProvider
|
||||
provider_enum = None
|
||||
# Store actual provider name for logging (e.g., "huggingface", "gemini")
|
||||
# Store actual provider name for logging (e.g., "huggingface", "gemini", "wavespeed")
|
||||
actual_provider_name = None
|
||||
if gpt_provider == "google":
|
||||
provider_enum = APIProvider.GEMINI
|
||||
@@ -120,6 +248,9 @@ def llm_text_gen(
|
||||
elif gpt_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
|
||||
actual_provider_name = "huggingface" # Keep actual provider name for logs
|
||||
elif gpt_provider == "wavespeed":
|
||||
provider_enum = APIProvider.WAVESPEED
|
||||
actual_provider_name = "wavespeed"
|
||||
|
||||
if not provider_enum:
|
||||
raise RuntimeError(f"Unknown provider {gpt_provider} for subscription checking")
|
||||
@@ -132,6 +263,11 @@ def llm_text_gen(
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
logger.info(
|
||||
f"[llm_text_gen][{flow_tag}] Starting subscription preflight for user={user_id}, "
|
||||
f"provider={actual_provider_name}, model={model}"
|
||||
)
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
@@ -162,6 +298,12 @@ def llm_text_gen(
|
||||
tokens_requested=estimated_total_tokens,
|
||||
actual_provider_name=actual_provider_name # Pass actual provider name for correct error messages
|
||||
)
|
||||
subscription_preflight_completed = True
|
||||
|
||||
logger.info(
|
||||
f"[llm_text_gen][{flow_tag}] Subscription preflight complete: can_proceed={can_proceed}, "
|
||||
f"estimated_tokens={estimated_total_tokens}, provider={actual_provider_name}"
|
||||
)
|
||||
|
||||
if not can_proceed:
|
||||
logger.warning(f"[llm_text_gen] Subscription limit exceeded for user {user_id}: {message}")
|
||||
@@ -219,6 +361,32 @@ def llm_text_gen(
|
||||
else:
|
||||
system_instructions = system_prompt
|
||||
|
||||
# HF behavior: fail fast on selected model; no intra-provider model fallback chain.
|
||||
hf_fallback_models: List[str] = []
|
||||
|
||||
# Set up model fallbacks based on strict_model_mode
|
||||
if not strict_model_mode and model_list and len(model_list) > 1:
|
||||
# Multi-model mode: create fallback list from TEXTGEN_AI_MODELS
|
||||
if gpt_provider == "huggingface":
|
||||
model_mapping = {
|
||||
"gpt-oss": "openai/gpt-oss-120b:cerebras",
|
||||
"gpt-oss-120b": "openai/gpt-oss-120b:cerebras",
|
||||
"mistral": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"mistral-7b": "mistralai/Mistral-7B-Instruct-v0.3:cerebras",
|
||||
"llama": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"llama-8b": "meta-llama/Llama-3.1-8B-Instruct:cerebras",
|
||||
"llama-70b": "meta-llama/Llama-3.1-70B-Instruct:cerebras"
|
||||
}
|
||||
hf_fallback_models = []
|
||||
for model_name in model_list[1:]:
|
||||
if "/" in model_name:
|
||||
# Full model name, use as-is
|
||||
hf_fallback_models.append(model_name)
|
||||
else:
|
||||
# Short name, map it
|
||||
mapped_model = model_mapping.get(model_name, model_name)
|
||||
hf_fallback_models.append(mapped_model)
|
||||
|
||||
# Generate response based on provider
|
||||
response_text = None
|
||||
actual_provider_used = gpt_provider
|
||||
@@ -249,6 +417,7 @@ def llm_text_gen(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model,
|
||||
fallback_models=hf_fallback_models,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
@@ -257,6 +426,29 @@ def llm_text_gen(
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
fallback_models=hf_fallback_models,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "wavespeed":
|
||||
from .wavespeed_provider import wavespeed_text_response, wavespeed_structured_json_response
|
||||
if json_struct:
|
||||
response_text = wavespeed_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model,
|
||||
fallback_models=None, # No fallbacks for WaveSpeed initially
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
fallback_models=None, # No fallbacks for WaveSpeed initially
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
@@ -264,11 +456,13 @@ def llm_text_gen(
|
||||
)
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
|
||||
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface")
|
||||
raise RuntimeError("Unknown LLM provider. Supported providers: google, huggingface, wavespeed")
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
logger.info(
|
||||
f"[llm_text_gen][{flow_tag}] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}"
|
||||
)
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
@@ -293,16 +487,37 @@ def llm_text_gen(
|
||||
|
||||
return response_text
|
||||
except Exception as provider_error:
|
||||
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
|
||||
logger.error(
|
||||
f"[llm_text_gen][{flow_tag}] Provider {gpt_provider} failed: {str(provider_error)} | "
|
||||
f"subscription_preflight_completed={subscription_preflight_completed} | model={model}"
|
||||
)
|
||||
|
||||
# CIRCUIT BREAKER: Only try ONE fallback to prevent expensive API calls
|
||||
fallback_providers = ["google", "huggingface"]
|
||||
# Use provider list from environment if available, otherwise default
|
||||
if provider_list and len(provider_list) > 1:
|
||||
# Use the specified fallback providers from GPT_PROVIDER
|
||||
fallback_providers = provider_list[1:] # Skip the primary (already tried)
|
||||
else:
|
||||
# Default fallback order
|
||||
fallback_providers = ["google", "huggingface", "wavespeed"]
|
||||
|
||||
# Filter to available providers and exclude current failed provider
|
||||
fallback_providers = [p for p in fallback_providers if p in available_providers and p != gpt_provider]
|
||||
|
||||
# Skip fallbacks if in strict provider mode
|
||||
if strict_provider_mode:
|
||||
logger.info(f"[llm_text_gen][{flow_tag}] Strict provider mode enabled; skipping cross-provider fallback")
|
||||
fallback_providers = []
|
||||
|
||||
if preferred_provider:
|
||||
# Caller explicitly pinned provider (e.g. podcast premium HF). Avoid cross-provider fallback noise.
|
||||
logger.info(f"[llm_text_gen][{flow_tag}] preferred_provider is set; skipping cross-provider fallback")
|
||||
fallback_providers = []
|
||||
|
||||
if fallback_providers:
|
||||
fallback_provider = fallback_providers[0] # Only try the first available
|
||||
try:
|
||||
logger.info(f"[llm_text_gen] Trying SINGLE fallback provider: {fallback_provider}")
|
||||
logger.info(f"[llm_text_gen][{flow_tag}] Trying SINGLE fallback provider: {fallback_provider}")
|
||||
actual_provider_used = fallback_provider
|
||||
|
||||
# Update provider enum for fallback
|
||||
@@ -313,7 +528,11 @@ def llm_text_gen(
|
||||
elif fallback_provider == "huggingface":
|
||||
provider_enum = APIProvider.MISTRAL
|
||||
actual_provider_name = "huggingface"
|
||||
fallback_model = "mistralai/Mistral-7B-Instruct-v0.3:groq"
|
||||
fallback_model = preferred_hf_models[0] if preferred_hf_models else "openai/gpt-oss-120b:cerebras"
|
||||
elif fallback_provider == "wavespeed":
|
||||
provider_enum = APIProvider.WAVESPEED
|
||||
actual_provider_name = "wavespeed"
|
||||
fallback_model = "openai/gpt-oss-120b"
|
||||
|
||||
if fallback_provider == "google":
|
||||
if json_struct:
|
||||
@@ -340,7 +559,8 @@ def llm_text_gen(
|
||||
response_text = huggingface_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
model=fallback_model,
|
||||
fallback_models=hf_fallback_models,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
@@ -348,7 +568,30 @@ def llm_text_gen(
|
||||
else:
|
||||
response_text = huggingface_text_response(
|
||||
prompt=prompt,
|
||||
model="mistralai/Mistral-7B-Instruct-v0.3:groq",
|
||||
model=fallback_model,
|
||||
fallback_models=hf_fallback_models,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif fallback_provider == "wavespeed":
|
||||
from .wavespeed_provider import wavespeed_text_response, wavespeed_structured_json_response
|
||||
if json_struct:
|
||||
response_text = wavespeed_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=fallback_model,
|
||||
fallback_models=None,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=fallback_model,
|
||||
fallback_models=None,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
@@ -357,7 +600,9 @@ def llm_text_gen(
|
||||
|
||||
# TRACK USAGE after successful fallback call
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
logger.info(
|
||||
f"[llm_text_gen][{flow_tag}] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}"
|
||||
)
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
@@ -376,19 +621,19 @@ def llm_text_gen(
|
||||
|
||||
return response_text
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"[llm_text_gen] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")
|
||||
logger.error(f"[llm_text_gen][{flow_tag}] Fallback provider {fallback_provider} also failed: {str(fallback_error)}")
|
||||
|
||||
# CIRCUIT BREAKER: Stop immediately to prevent expensive API calls
|
||||
logger.error("[llm_text_gen] CIRCUIT BREAKER: Stopping to prevent expensive API calls.")
|
||||
logger.error(f"[llm_text_gen][{flow_tag}] CIRCUIT BREAKER: Stopping to prevent expensive API calls.")
|
||||
raise RuntimeError("All LLM providers failed to generate a response.")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[llm_text_gen] Error during text generation: {str(e)}")
|
||||
logger.error(f"[llm_text_gen][{flow_tag}] Error during text generation: {str(e)}")
|
||||
raise
|
||||
|
||||
def check_gpt_provider(gpt_provider: str) -> bool:
|
||||
"""Check if the specified GPT provider is supported."""
|
||||
supported_providers = ["google", "huggingface"]
|
||||
supported_providers = ["google", "huggingface", "wavespeed"]
|
||||
return gpt_provider in supported_providers
|
||||
|
||||
def get_api_key(gpt_provider: str) -> Optional[str]:
|
||||
@@ -397,7 +642,8 @@ def get_api_key(gpt_provider: str) -> Optional[str]:
|
||||
api_key_manager = APIKeyManager()
|
||||
provider_mapping = {
|
||||
"google": "gemini",
|
||||
"huggingface": "hf_token"
|
||||
"huggingface": "hf_token",
|
||||
"wavespeed": "wavespeed"
|
||||
}
|
||||
|
||||
mapped_provider = provider_mapping.get(gpt_provider, gpt_provider)
|
||||
|
||||
527
backend/services/llm_providers/wavespeed_provider.py
Normal file
527
backend/services/llm_providers/wavespeed_provider.py
Normal file
@@ -0,0 +1,527 @@
|
||||
"""
|
||||
WaveSpeed LLM Provider Module for ALwrity
|
||||
|
||||
This module provides functions for interacting with WaveSpeed's LLM API
|
||||
using the OpenAI-compatible interface for text generation.
|
||||
|
||||
Key Features:
|
||||
- Text response generation with retry logic
|
||||
- Comprehensive error handling and logging
|
||||
- Automatic API key management
|
||||
- Support for gpt-oss and other WaveSpeed models
|
||||
- Integration with subscription/preflight checks
|
||||
|
||||
Best Practices:
|
||||
1. Use appropriate temperature for your use case (0.7 for creative, 0.1-0.3 for factual)
|
||||
2. Set max_tokens based on expected response length
|
||||
3. Use system_prompt to guide model behavior
|
||||
4. Handle errors gracefully in calling functions
|
||||
|
||||
Usage Examples:
|
||||
# Text response
|
||||
result = wavespeed_text_response(prompt, temperature=0.7, max_tokens=2048)
|
||||
|
||||
# Structured JSON response
|
||||
schema = {"type": "object", "properties": {"title": {"type": "string"}}}
|
||||
result = wavespeed_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
|
||||
|
||||
Dependencies:
|
||||
- openai (for WaveSpeed OpenAI-compatible API)
|
||||
- tenacity (for retry logic)
|
||||
- logging (for debugging)
|
||||
- json (for fallback parsing)
|
||||
|
||||
Author: ALwrity Team
|
||||
Version: 1.0
|
||||
Last Updated: March 2026
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Dict, Any, List
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Fix the environment loading path - load from backend directory
|
||||
current_dir = Path(__file__).parent.parent # services directory
|
||||
backend_dir = current_dir.parent # backend directory
|
||||
env_path = backend_dir / '.env'
|
||||
|
||||
if env_path.exists():
|
||||
load_dotenv(env_path)
|
||||
print(f"Loaded .env from: {env_path}")
|
||||
else:
|
||||
# Fallback to current directory
|
||||
load_dotenv()
|
||||
print(f"No .env found at {env_path}, using current directory")
|
||||
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("wavespeed_provider")
|
||||
|
||||
from tenacity import (
|
||||
retry,
|
||||
retry_if_exception,
|
||||
stop_after_attempt,
|
||||
wait_random_exponential,
|
||||
)
|
||||
|
||||
try:
|
||||
from openai import OpenAI
|
||||
from openai import NotFoundError
|
||||
OPENAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
OPENAI_AVAILABLE = False
|
||||
NotFoundError = Exception
|
||||
logger.warn("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
# Default WaveSpeed models for fallback
|
||||
WAVESPEED_FALLBACK_MODELS = [
|
||||
"openai/gpt-oss-120b",
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
"mistralai/Mistral-7B-Instruct-v0.3",
|
||||
"google/gemma-7b-it",
|
||||
]
|
||||
|
||||
def _candidate_model_variants(model: str):
|
||||
"""Yield model ids to try for a single logical model preference."""
|
||||
if not model:
|
||||
return
|
||||
|
||||
# Try configured model first
|
||||
yield model
|
||||
|
||||
# Fallback to base repo id when provider suffix is not recognized by the router
|
||||
if ":" in model:
|
||||
base_model = model.split(":", 1)[0]
|
||||
if base_model:
|
||||
yield base_model
|
||||
|
||||
def _fallback_model_sequence(model: str, fallback_models: Optional[List[str]] = None):
|
||||
# IMPORTANT: Do not apply implicit global fallback chains.
|
||||
# Callers must explicitly provide fallback_models when they want multi-model retries.
|
||||
if fallback_models:
|
||||
sequence = [model] + fallback_models
|
||||
else:
|
||||
sequence = [model]
|
||||
seen = set()
|
||||
for preferred_model in sequence:
|
||||
for candidate in _candidate_model_variants(preferred_model):
|
||||
if candidate and candidate not in seen:
|
||||
seen.add(candidate)
|
||||
yield candidate
|
||||
|
||||
def _is_non_retryable_wavespeed_error(exc: Exception) -> bool:
|
||||
"""Skip retries for deterministic WaveSpeed failures (e.g., unknown model ids, billing)."""
|
||||
msg = str(exc).lower()
|
||||
status = getattr(exc, "status_code", None)
|
||||
|
||||
# Non-retryable errors
|
||||
if isinstance(exc, NotFoundError) or "not found" in msg or "404" in msg:
|
||||
return True
|
||||
if status == 402 or "402" in msg or "depleted" in msg or "credits" in msg:
|
||||
return True
|
||||
if status == 401 or "unauthorized" in msg or "401" in msg:
|
||||
return True
|
||||
if status == 403 or "forbidden" in msg or "403" in msg:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def _should_retry_wavespeed_error(exc: Exception) -> bool:
|
||||
return not _is_non_retryable_wavespeed_error(exc)
|
||||
|
||||
def _classify_wavespeed_error(exc: Exception) -> str:
|
||||
"""Classify WaveSpeed failures for actionable logs."""
|
||||
msg = str(exc).lower()
|
||||
if any(token in msg for token in ["insufficient", "balance", "quota", "billing", "payment", "402"]):
|
||||
return "billing_or_quota"
|
||||
if "unauthorized" in msg or "forbidden" in msg or "401" in msg or "403" in msg:
|
||||
return "auth_or_permission"
|
||||
if "not found" in msg or "404" in msg:
|
||||
return "model_not_found"
|
||||
return "unknown"
|
||||
|
||||
def _wavespeed_error_details(exc: Exception) -> str:
|
||||
"""Return compact, actionable exception details for logs."""
|
||||
status = getattr(exc, "status_code", None)
|
||||
err_type = type(exc).__name__
|
||||
message = str(exc)
|
||||
raw_body = getattr(exc, "body", None)
|
||||
details = f"type={err_type}"
|
||||
if status is not None:
|
||||
details += f", status={status}"
|
||||
if message:
|
||||
details += f", message={message}"
|
||||
if raw_body:
|
||||
details += f", body={raw_body}"
|
||||
details += f", repr={repr(exc)}"
|
||||
return details
|
||||
|
||||
def get_wavespeed_api_key() -> str:
|
||||
"""Get WaveSpeed API key with proper error handling."""
|
||||
api_key = os.getenv('WAVESPEED_API_KEY')
|
||||
if not api_key:
|
||||
error_msg = "WAVESPEED_API_KEY environment variable is not set. Please set it in your .env file."
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Validate API key format (basic check)
|
||||
if not api_key or len(api_key) < 10:
|
||||
error_msg = "WAVESPEED_API_KEY appears to be invalid."
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
return api_key
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_wavespeed_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
)
|
||||
def wavespeed_text_response(
|
||||
prompt: str,
|
||||
model: str = "openai/gpt-oss-120b",
|
||||
fallback_models: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 2048,
|
||||
top_p: float = 0.9,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> str:
|
||||
"""
|
||||
Generate text response using WaveSpeed LLM API.
|
||||
|
||||
This function uses the WaveSpeed OpenAI-compatible API for text generation
|
||||
with built-in retry logic and error handling.
|
||||
|
||||
Args:
|
||||
prompt (str): The input prompt for the AI model
|
||||
model (str): WaveSpeed model identifier (default: "openai/gpt-oss-120b")
|
||||
temperature (float): Controls randomness (0.0-1.0)
|
||||
max_tokens (int): Maximum tokens in response
|
||||
top_p (float): Nucleus sampling parameter (0.0-1.0)
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
|
||||
Returns:
|
||||
str: Generated text response
|
||||
|
||||
Raises:
|
||||
Exception: If API key is missing or API call fails
|
||||
|
||||
Best Practices:
|
||||
- Use appropriate temperature for your use case (0.7 for creative, 0.1-0.3 for factual)
|
||||
- Set max_tokens based on expected response length
|
||||
- Use system_prompt to guide model behavior
|
||||
- Handle errors gracefully in calling functions
|
||||
|
||||
Example:
|
||||
result = wavespeed_text_response(
|
||||
prompt="Write a blog post about AI",
|
||||
model="openai/gpt-oss-120b",
|
||||
temperature=0.7,
|
||||
max_tokens=2048,
|
||||
system_prompt="You are a professional content writer."
|
||||
)
|
||||
"""
|
||||
try:
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
# Get API key with proper error handling
|
||||
api_key = get_wavespeed_api_key()
|
||||
logger.info(f"🔑 WaveSpeed API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
|
||||
|
||||
if not api_key:
|
||||
raise Exception("WAVESPEED_API_KEY not found in environment variables")
|
||||
|
||||
# Initialize WaveSpeed client
|
||||
client = OpenAI(
|
||||
base_url="https://llm.wavespeed.ai/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ WaveSpeed client initialized for text response")
|
||||
|
||||
# Prepare input for the API
|
||||
messages = []
|
||||
|
||||
# Add system prompt if provided
|
||||
if system_prompt:
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
})
|
||||
|
||||
# Add user prompt
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": prompt
|
||||
})
|
||||
|
||||
# Add debugging for API call
|
||||
logger.info(
|
||||
"WaveSpeed text call | model={} | prompt_len={} | temp={} | top_p={} | max_tokens={}",
|
||||
model,
|
||||
len(prompt) if isinstance(prompt, str) else '<non-str>',
|
||||
temperature,
|
||||
top_p,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
logger.info("🚀 Making WaveSpeed API call (chat completion)...")
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
# Call exactly the requested model; no retries, no fallbacks, no variants
|
||||
response = client.chat.completions.create(
|
||||
model=model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
# Extract text from response
|
||||
generated_text = response.choices[0].message.content
|
||||
|
||||
# Clean up the response
|
||||
if generated_text:
|
||||
# Remove any markdown formatting if present
|
||||
generated_text = re.sub(r'```[a-zA-Z]*\n?', '', generated_text)
|
||||
generated_text = re.sub(r'```\n?', '', generated_text)
|
||||
generated_text = generated_text.strip()
|
||||
|
||||
logger.info(f"✅ WaveSpeed text response generated successfully (length: {len(generated_text)})")
|
||||
return generated_text
|
||||
|
||||
except Exception as e:
|
||||
error_class = _classify_wavespeed_error(e)
|
||||
error_details = _wavespeed_error_details(e)
|
||||
logger.error(f"❌ WaveSpeed text generation failed: {error_details}")
|
||||
|
||||
# Extra diagnostics: try to capture raw response if available
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
logger.error(f"🔍 WaveSpeed Error Diagnostics:")
|
||||
logger.error(f" - Status: {e.response.status_code}")
|
||||
logger.error(f" - Headers: {dict(e.response.headers)}")
|
||||
try:
|
||||
body_json = e.response.json()
|
||||
logger.error(f" - Body JSON: {json.dumps(body_json, indent=2)}")
|
||||
except Exception:
|
||||
logger.error(f" - Body Raw: {e.response.text[:1000]}")
|
||||
else:
|
||||
logger.error(f"🔍 No HTTP response attached to exception object.")
|
||||
|
||||
raise Exception(f"WaveSpeed text generation failed: {str(e)}")
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def wavespeed_structured_json_response(
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
model: str = "openai/gpt-oss-120b",
|
||||
fallback_models: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
max_tokens: int = 8192,
|
||||
system_prompt: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate structured JSON response using WaveSpeed LLM API.
|
||||
|
||||
This function uses the WaveSpeed OpenAI-compatible API with structured output support
|
||||
to generate JSON responses that match a provided schema.
|
||||
|
||||
Args:
|
||||
prompt (str): The input prompt for the AI model
|
||||
schema (dict): JSON schema defining the expected output structure
|
||||
model (str): WaveSpeed model identifier (default: "openai/gpt-oss-120b")
|
||||
temperature (float): Controls randomness (0.0-1.0). Use 0.1-0.3 for structured output
|
||||
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
|
||||
Returns:
|
||||
dict: Parsed JSON response matching the provided schema
|
||||
|
||||
Raises:
|
||||
Exception: If API key is missing or API call fails
|
||||
|
||||
Best Practices:
|
||||
- Keep schemas simple and flat to avoid truncation
|
||||
- Use low temperature (0.1-0.3) for consistent structured output
|
||||
- Set max_tokens to 8192 for complex multi-field responses
|
||||
- Avoid deeply nested schemas with many required fields
|
||||
- Test with smaller outputs first, then scale up
|
||||
|
||||
Example:
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
result = wavespeed_structured_json_response(prompt, schema, temperature=0.2, max_tokens=8192)
|
||||
"""
|
||||
try:
|
||||
if not OPENAI_AVAILABLE:
|
||||
raise ImportError("OpenAI library not available. Install with: pip install openai")
|
||||
|
||||
# Get API key with proper error handling
|
||||
api_key = get_wavespeed_api_key()
|
||||
logger.info(f"🔑 WaveSpeed API key loaded: {bool(api_key)} (length: {len(api_key) if api_key else 0})")
|
||||
|
||||
if not api_key:
|
||||
raise Exception("WAVESPEED_API_KEY not found in environment variables")
|
||||
|
||||
# Initialize OpenAI client with WaveSpeed base URL
|
||||
client = OpenAI(
|
||||
base_url="https://llm.wavespeed.ai/v1",
|
||||
api_key=api_key,
|
||||
)
|
||||
logger.info("✅ WaveSpeed client initialized for structured JSON response")
|
||||
|
||||
# Prepare input for the API
|
||||
messages = []
|
||||
|
||||
# Add system prompt if provided
|
||||
if system_prompt:
|
||||
messages.append({
|
||||
"role": "system",
|
||||
"content": system_prompt
|
||||
})
|
||||
|
||||
# Add user prompt with JSON instruction
|
||||
json_instruction = "Please respond with valid JSON that matches the provided schema."
|
||||
messages.append({
|
||||
"role": "user",
|
||||
"content": f"{prompt}\n\n{json_instruction}"
|
||||
})
|
||||
|
||||
# Add debugging for API call
|
||||
logger.info(
|
||||
"WaveSpeed structured call | model={} | prompt_len={} | schema_kind={} | temp={} | max_tokens={}",
|
||||
model,
|
||||
len(prompt) if isinstance(prompt, str) else '<non-str>',
|
||||
type(schema).__name__,
|
||||
temperature,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
logger.info("🚀 Making WaveSpeed structured API call...")
|
||||
|
||||
# Add JSON schema to prompt for guidance
|
||||
json_schema_str = json.dumps(schema, indent=2)
|
||||
messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}"
|
||||
|
||||
# Add rate limiting to prevent expensive API calls
|
||||
import time
|
||||
time.sleep(1) # 1 second delay between API calls
|
||||
|
||||
try:
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
response_format={"type": "json_object"} # Try to enforce JSON mode if supported
|
||||
)
|
||||
if candidate_model != model:
|
||||
logger.warning("WaveSpeed structured generation switched to fallback model: {}", candidate_model)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
logger.warning("WaveSpeed structured model not found: {}. Trying fallback model.", candidate_model)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
raise last_error or Exception("WaveSpeed structured generation failed: all fallback models failed")
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
|
||||
# Clean up response text if needed
|
||||
response_text = response_text.strip()
|
||||
if response_text.startswith("```json"):
|
||||
response_text = response_text[7:]
|
||||
if response_text.endswith("```"):
|
||||
response_text = response_text[:-3]
|
||||
response_text = response_text.strip()
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"❌ JSON parsing failed: {json_err}")
|
||||
logger.error(f"Raw response: {response_text}")
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
extracted_json = json.loads(json_match.group())
|
||||
logger.info("✅ JSON extracted using regex fallback")
|
||||
return extracted_json
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WaveSpeed API call failed: {e}")
|
||||
# If 422 Unprocessable Entity (often due to response_format not supported), retry without it
|
||||
if "422" in str(e) or "not supported" in str(e).lower() or isinstance(e, NotFoundError):
|
||||
logger.info("Retrying without response_format...")
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
if candidate_model != model:
|
||||
logger.warning("WaveSpeed structured no-response-format fallback model: {}", candidate_model)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
logger.warning("WaveSpeed structured model not found (no response_format path): {}", candidate_model)
|
||||
continue
|
||||
|
||||
if response is None:
|
||||
raise last_error or e
|
||||
response_text = response.choices[0].message.content
|
||||
# ... (same parsing logic would apply, simplified here for brevity)
|
||||
try:
|
||||
return json.loads(response_text)
|
||||
except:
|
||||
# Regex fallback
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
error_msg = str(e) if str(e) else repr(e)
|
||||
error_type = type(e).__name__
|
||||
logger.error(f"❌ WaveSpeed structured JSON generation failed [{error_type}]: {error_msg}")
|
||||
raise Exception(f"WaveSpeed structured JSON generation failed: {error_msg}")
|
||||
@@ -22,30 +22,45 @@ class PodcastBibleService:
|
||||
logger.info(f"Generating Podcast Bible for user {user_id}")
|
||||
|
||||
try:
|
||||
preferences = self.personalization_service.get_user_preferences(user_id)
|
||||
preferences = self.personalization_service.get_user_preferences(user_id) or {}
|
||||
if not isinstance(preferences, dict):
|
||||
logger.warning(f"Podcast Bible preferences payload is non-dict for user {user_id}, using defaults")
|
||||
preferences = {}
|
||||
|
||||
writing_style = preferences.get("writing_style", {})
|
||||
if not isinstance(writing_style, dict):
|
||||
writing_style = {}
|
||||
|
||||
style_prefs = preferences.get("style_preferences", {})
|
||||
if not isinstance(style_prefs, dict):
|
||||
style_prefs = {}
|
||||
|
||||
target_audience = preferences.get("target_audience", {})
|
||||
if not isinstance(target_audience, dict):
|
||||
target_audience = {}
|
||||
|
||||
industry = preferences.get("industry", "General Business")
|
||||
if not isinstance(industry, str) or not industry.strip():
|
||||
industry = "General Business"
|
||||
|
||||
# 1. Map Host Persona
|
||||
host = HostPersona(
|
||||
name="Your AI Host",
|
||||
background=f"Expert in {industry}",
|
||||
expertise_level=writing_style.get("complexity", "Expert").capitalize(),
|
||||
expertise_level=str(writing_style.get("complexity") or "Expert").capitalize(),
|
||||
personality_traits=[
|
||||
writing_style.get("tone", "Professional").capitalize(),
|
||||
writing_style.get("engagement_level", "Informative").capitalize()
|
||||
str(writing_style.get("tone") or "Professional").capitalize(),
|
||||
str(writing_style.get("engagement_level") or "Informative").capitalize()
|
||||
],
|
||||
vocal_style=writing_style.get("voice", "Authoritative").capitalize(),
|
||||
vocal_characteristics=["Clear", "Articulate", writing_style.get("voice", "Steady")],
|
||||
vocal_style=str(writing_style.get("voice") or "Authoritative").capitalize(),
|
||||
vocal_characteristics=["Clear", "Articulate", str(writing_style.get("voice") or "Steady")],
|
||||
look=f"A professional individual dressed in business-casual attire, fitting the {industry} industry aesthetic.",
|
||||
catchphrases=[]
|
||||
)
|
||||
|
||||
# 2. Map Audience DNA
|
||||
audience = AudienceDNA(
|
||||
expertise_level=target_audience.get("expertise_level", "Intermediate").capitalize(),
|
||||
expertise_level=str(target_audience.get("expertise_level") or "Intermediate").capitalize(),
|
||||
interests=target_audience.get("interests", ["Industry Trends", "Innovation"]),
|
||||
pain_points=target_audience.get("pain_points", ["Staying ahead of competition", "Efficiency"]),
|
||||
demographics=None
|
||||
@@ -54,15 +69,15 @@ class PodcastBibleService:
|
||||
# 3. Map Brand DNA
|
||||
brand = BrandDNA(
|
||||
industry=industry,
|
||||
tone=writing_style.get("tone", "Professional").capitalize(),
|
||||
communication_style=writing_style.get("engagement_level", "Informative").capitalize(),
|
||||
tone=str(writing_style.get("tone") or "Professional").capitalize(),
|
||||
communication_style=str(writing_style.get("engagement_level") or "Informative").capitalize(),
|
||||
key_messages=preferences.get("brand_values", []),
|
||||
competitor_context=None
|
||||
)
|
||||
|
||||
# 4. Map Visual Style
|
||||
visual = VisualStyle(
|
||||
style_preset=style_prefs.get("aesthetic", "Professional Studio").capitalize(),
|
||||
style_preset=str(style_prefs.get("aesthetic") or "Professional Studio").capitalize(),
|
||||
environment=f"A modern {industry}-themed podcast studio with professional equipment.",
|
||||
lighting="Soft, warm studio lighting with subtle rim lights.",
|
||||
color_palette=preferences.get("brand_colors", ["#1e293b", "#3b82f6"]),
|
||||
@@ -72,7 +87,7 @@ class PodcastBibleService:
|
||||
# 5. Map Audio Environment
|
||||
audio_env = AudioEnvironment(
|
||||
soundscape="Pristine studio environment with deep, warm acoustics.",
|
||||
music_mood=f"{writing_style.get('tone', 'Professional').capitalize()} & {writing_style.get('engagement_level', 'Upbeat').capitalize()}",
|
||||
music_mood=f"{str(writing_style.get('tone') or 'Professional').capitalize()} & {str(writing_style.get('engagement_level') or 'Upbeat').capitalize()}",
|
||||
sfx_style="Modern, clean interface-inspired sounds."
|
||||
)
|
||||
|
||||
@@ -80,11 +95,11 @@ class PodcastBibleService:
|
||||
show_rules = ShowRules(
|
||||
intro_format=f"Start with a high-energy hook about the episode topic, followed by a warm welcome and an overview of the {industry} insights to be shared.",
|
||||
outro_format="Summarize the key takeaways, provide a clear call to action, and sign off with a professional closing.",
|
||||
interaction_tone=writing_style.get("engagement_level", "Conversational").capitalize(),
|
||||
interaction_tone=str(writing_style.get("engagement_level") or "Conversational").capitalize(),
|
||||
constraints=[
|
||||
"Avoid overly technical jargon unless defined",
|
||||
"Keep segments concise and factual",
|
||||
f"Maintain a {writing_style.get('tone', 'Professional')} tone at all times"
|
||||
f"Maintain a {str(writing_style.get('tone') or 'Professional')} tone at all times"
|
||||
]
|
||||
)
|
||||
|
||||
@@ -102,7 +117,7 @@ class PodcastBibleService:
|
||||
return bible
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Podcast Bible: {str(e)}")
|
||||
logger.error(f"Error generating Podcast Bible: {str(e)}", exc_info=True)
|
||||
# Return a default bible if something goes wrong to ensure project creation doesn't fail
|
||||
return self._get_default_bible(project_id)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ Extracts ALL onboarding data and provides personalized defaults for forms and re
|
||||
from typing import Dict, Any, Optional, List
|
||||
from loguru import logger
|
||||
|
||||
from services.database import SessionLocal
|
||||
from services.database import get_session_for_user
|
||||
from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService
|
||||
|
||||
|
||||
@@ -20,6 +20,14 @@ class PersonalizationService:
|
||||
"""Initialize Personalization Service."""
|
||||
self.logger = logger
|
||||
logger.info("[Personalization Service] Initialized")
|
||||
|
||||
@staticmethod
|
||||
def _as_dict(value: Any) -> Dict[str, Any]:
|
||||
return value if isinstance(value, dict) else {}
|
||||
|
||||
@staticmethod
|
||||
def _as_list(value: Any) -> List[Any]:
|
||||
return value if isinstance(value, list) else []
|
||||
|
||||
def get_user_preferences(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
@@ -36,20 +44,36 @@ class PersonalizationService:
|
||||
- templates: Recommended templates for user's industry
|
||||
- channels: Recommended channels based on platform personas
|
||||
"""
|
||||
db = SessionLocal()
|
||||
db = None
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[Personalization] No DB session available for user {user_id}; using default preferences")
|
||||
return self._get_default_preferences()
|
||||
|
||||
integration_service = OnboardingDataIntegrationService()
|
||||
integrated_data = integration_service.get_integrated_data_sync(user_id, db)
|
||||
if not isinstance(integrated_data, dict):
|
||||
logger.warning(
|
||||
f"[Personalization] Integrated onboarding payload is non-dict for user {user_id}; using defaults"
|
||||
)
|
||||
integrated_data = {}
|
||||
|
||||
canonical_profile = integrated_data.get('canonical_profile', {})
|
||||
if not isinstance(canonical_profile, dict):
|
||||
logger.warning(
|
||||
f"[Personalization] Canonical profile is non-dict for user {user_id}; using defaults"
|
||||
)
|
||||
canonical_profile = {}
|
||||
|
||||
# Map strictly from Canonical Profile
|
||||
preferences = {
|
||||
"industry": canonical_profile.get("industry"),
|
||||
"target_audience": canonical_profile.get("target_audience", {}),
|
||||
"platform_preferences": canonical_profile.get("platform_preferences", []),
|
||||
"content_preferences": canonical_profile.get("content_types", []),
|
||||
"style_preferences": canonical_profile.get("visual_style", {}),
|
||||
"brand_colors": canonical_profile.get("brand_colors", []),
|
||||
"target_audience": self._as_dict(canonical_profile.get("target_audience", {})),
|
||||
"platform_preferences": self._as_list(canonical_profile.get("platform_preferences", [])),
|
||||
"content_preferences": self._as_list(canonical_profile.get("content_types", [])),
|
||||
"style_preferences": self._as_dict(canonical_profile.get("visual_style", {})),
|
||||
"brand_colors": self._as_list(canonical_profile.get("brand_colors", [])),
|
||||
"recommended_templates": [],
|
||||
"recommended_channels": [],
|
||||
"writing_style": {
|
||||
@@ -58,7 +82,7 @@ class PersonalizationService:
|
||||
"complexity": canonical_profile.get("writing_complexity", "intermediate"),
|
||||
"engagement_level": canonical_profile.get("writing_engagement", "moderate"),
|
||||
},
|
||||
"brand_values": canonical_profile.get("brand_values", []),
|
||||
"brand_values": self._as_list(canonical_profile.get("brand_values", [])),
|
||||
}
|
||||
|
||||
# Ensure target_audience structure
|
||||
@@ -104,7 +128,8 @@ class PersonalizationService:
|
||||
logger.error(f"[Personalization] Error getting user preferences: {str(e)}", exc_info=True)
|
||||
return self._get_default_preferences()
|
||||
finally:
|
||||
db.close()
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
def get_personalized_defaults(
|
||||
self,
|
||||
|
||||
@@ -13,6 +13,7 @@ from models.website_analysis_monitoring_models import (
|
||||
SIFIndexingTask,
|
||||
SIFIndexingExecutionLog
|
||||
)
|
||||
from models.onboarding import OnboardingSession
|
||||
from services.scheduler.core.executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from services.scheduler.core.failure_detection_service import FailureDetectionService
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
@@ -57,6 +58,36 @@ class SIFIndexingExecutor(TaskExecutor):
|
||||
|
||||
try:
|
||||
logger.info(f"Executing SIF indexing for user {user_id} ({website_url})")
|
||||
|
||||
onboarding_session = (
|
||||
db.query(OnboardingSession)
|
||||
.filter(OnboardingSession.user_id == user_id)
|
||||
.order_by(OnboardingSession.updated_at.desc())
|
||||
.first()
|
||||
)
|
||||
if not onboarding_session:
|
||||
logger.info(
|
||||
f"Skipping SIF indexing for user {user_id}: no onboarding session found. "
|
||||
"Pausing task until onboarding completes."
|
||||
)
|
||||
task.last_executed = datetime.utcnow()
|
||||
task.status = "paused"
|
||||
task.next_execution = None
|
||||
|
||||
task_log.status = "skipped"
|
||||
task_log.result_data = {
|
||||
"reason": "no_onboarding_session",
|
||||
"website_url": website_url,
|
||||
}
|
||||
task_log.execution_time_ms = int((time.time() - start_time) * 1000)
|
||||
db.commit()
|
||||
|
||||
return TaskExecutionResult(
|
||||
success=False,
|
||||
result_data=task_log.result_data,
|
||||
execution_time_ms=task_log.execution_time_ms,
|
||||
retryable=False,
|
||||
)
|
||||
|
||||
# Initialize SIF Service
|
||||
sif_service = SIFIntegrationService(user_id)
|
||||
|
||||
@@ -12,7 +12,7 @@ from datetime import datetime
|
||||
from sqlalchemy import select, desc
|
||||
import json
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_session_for_user, has_onboarding_session
|
||||
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
|
||||
|
||||
# Import existing SIF components
|
||||
@@ -1081,8 +1081,14 @@ class SIFIntegrationAPI:
|
||||
def __init__(self):
|
||||
self.services: Dict[str, SIFIntegrationService] = {}
|
||||
|
||||
def get_service(self, user_id: str) -> SIFIntegrationService:
|
||||
def get_service(self, user_id: str) -> Optional[SIFIntegrationService]:
|
||||
"""Get or create SIF service for a user."""
|
||||
if not has_onboarding_session(user_id):
|
||||
logger.debug(
|
||||
"Skipping SIF service creation for user {} via SIFIntegrationAPI: no onboarding session",
|
||||
user_id,
|
||||
)
|
||||
return None
|
||||
if user_id not in self.services:
|
||||
self.services[user_id] = SIFIntegrationService(user_id)
|
||||
return self.services[user_id]
|
||||
@@ -1090,11 +1096,25 @@ class SIFIntegrationAPI:
|
||||
async def get_semantic_insights_with_cache(self, user_id: str, website_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Get semantic insights with caching metadata."""
|
||||
service = self.get_service(user_id)
|
||||
if not service:
|
||||
return {
|
||||
"source": "skipped",
|
||||
"reason": "no_onboarding_session",
|
||||
"insights": {},
|
||||
}
|
||||
return await service.get_semantic_insights(website_data)
|
||||
|
||||
async def get_cache_performance(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get cache performance metrics for a user."""
|
||||
service = self.get_service(user_id)
|
||||
if not service:
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"cache_enabled": False,
|
||||
"performance": {},
|
||||
"reason": "no_onboarding_session",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
stats = service.get_cache_performance_stats()
|
||||
|
||||
return {
|
||||
@@ -1107,6 +1127,13 @@ class SIFIntegrationAPI:
|
||||
async def invalidate_user_cache(self, user_id: str, reason: str = "api_request") -> Dict[str, Any]:
|
||||
"""Invalidate cache for a specific user."""
|
||||
service = self.get_service(user_id)
|
||||
if not service:
|
||||
return {
|
||||
"user_id": user_id,
|
||||
"success": False,
|
||||
"reason": "no_onboarding_session",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
success = await service.invalidate_user_cache(reason)
|
||||
|
||||
return {
|
||||
|
||||
@@ -79,10 +79,11 @@ class UsageTrackingService:
|
||||
# Calculate costs
|
||||
# Use specific model names instead of generic defaults
|
||||
default_models = {
|
||||
"gemini": "gemini-2.5-flash", # Use Flash as default (cost-effective)
|
||||
"openai": "gpt-4o-mini", # Use Mini as default (cost-effective)
|
||||
"anthropic": "claude-3.5-sonnet", # Use Sonnet as default
|
||||
"mistral": "openai/gpt-oss-120b:groq" # HuggingFace default model
|
||||
APIProvider.GEMINI: "gemini-2.5-flash", # Use Flash as default (cost-effective)
|
||||
APIProvider.OPENAI: "gpt-4o-mini", # Use Mini as default (cost-effective)
|
||||
APIProvider.ANTHROPIC: "claude-3.5-sonnet", # Use Sonnet as default
|
||||
APIProvider.MISTRAL: "openai/gpt-oss-120b:groq", # HuggingFace default model
|
||||
APIProvider.WAVESPEED: "openai/gpt-oss-120b" # WaveSpeed default model
|
||||
}
|
||||
|
||||
# For HuggingFace (stored as MISTRAL), use the actual model name or default
|
||||
@@ -91,9 +92,9 @@ class UsageTrackingService:
|
||||
if model_used:
|
||||
model_name = model_used
|
||||
else:
|
||||
model_name = default_models.get("mistral", "openai/gpt-oss-120b:groq")
|
||||
model_name = default_models.get(APIProvider.MISTRAL, "openai/gpt-oss-120b:groq")
|
||||
else:
|
||||
model_name = model_used or default_models.get(provider.value, f"{provider.value}-default")
|
||||
model_name = model_used or default_models.get(provider, f"{provider.value}-default")
|
||||
|
||||
cost_data = self.pricing_service.calculate_api_cost(
|
||||
provider=provider,
|
||||
@@ -199,7 +200,7 @@ class UsageTrackingService:
|
||||
setattr(summary, f"{provider_name}_calls", current_calls + 1)
|
||||
|
||||
# Update token usage for LLM providers
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL, APIProvider.WAVESPEED]:
|
||||
current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
|
||||
|
||||
@@ -901,12 +902,14 @@ class UsageTrackingService:
|
||||
summary.openai_calls = 0
|
||||
summary.anthropic_calls = 0
|
||||
summary.mistral_calls = 0
|
||||
summary.wavespeed_calls = 0
|
||||
|
||||
# Reset all LLM provider token counters
|
||||
summary.gemini_tokens = 0
|
||||
summary.openai_tokens = 0
|
||||
summary.anthropic_tokens = 0
|
||||
summary.mistral_tokens = 0
|
||||
summary.wavespeed_tokens = 0
|
||||
|
||||
# Reset search/research provider counters
|
||||
summary.tavily_calls = 0
|
||||
@@ -932,6 +935,7 @@ class UsageTrackingService:
|
||||
summary.openai_cost = 0.0
|
||||
summary.anthropic_cost = 0.0
|
||||
summary.mistral_cost = 0.0
|
||||
summary.wavespeed_cost = 0.0
|
||||
summary.tavily_cost = 0.0
|
||||
summary.serper_cost = 0.0
|
||||
summary.metaphor_cost = 0.0
|
||||
|
||||
@@ -68,30 +68,72 @@ class SpeechGenerator:
|
||||
model_path = "minimax/speech-02-hd"
|
||||
url = f"{self.base_url}/{model_path}"
|
||||
|
||||
payload = {
|
||||
"text": text,
|
||||
"voice_id": voice_id,
|
||||
"speed": speed,
|
||||
"volume": volume,
|
||||
"pitch": pitch,
|
||||
"emotion": emotion,
|
||||
"enable_sync_mode": enable_sync_mode,
|
||||
# Sanitize and validate parameters
|
||||
sanitized_text = str(text).strip()
|
||||
if not sanitized_text:
|
||||
raise ValueError("Text cannot be empty after sanitization")
|
||||
|
||||
sanitized_voice_id = str(voice_id).strip()
|
||||
if not sanitized_voice_id:
|
||||
raise ValueError("Voice ID cannot be empty after sanitization")
|
||||
|
||||
# Ensure numeric parameters are proper floats and within valid ranges
|
||||
sanitized_speed = max(0.5, min(2.0, float(speed))) if speed is not None else 1.0
|
||||
sanitized_volume = max(0.1, min(10.0, float(volume))) if volume is not None else 1.0
|
||||
sanitized_pitch = max(-12.0, min(12.0, float(pitch))) if pitch is not None else 0.0
|
||||
|
||||
# Sanitize emotion parameter - remove newlines and extra whitespace
|
||||
sanitized_emotion = str(emotion).strip().replace('\n', '').replace('\r', '')
|
||||
|
||||
# Map common emotions to minimax valid values
|
||||
emotion_mapping = {
|
||||
'neutral': 'neutral',
|
||||
'happy': 'happy',
|
||||
'sad': 'sad',
|
||||
'angry': 'angry',
|
||||
'excited': 'happy',
|
||||
'calm': 'neutral',
|
||||
'friendly': 'happy',
|
||||
'professional': 'neutral',
|
||||
'warm': 'happy',
|
||||
'serious': 'neutral'
|
||||
}
|
||||
|
||||
# Add optional parameters
|
||||
# Use mapped emotion or default to 'happy'
|
||||
mapped_emotion = emotion_mapping.get(sanitized_emotion.lower(), 'happy')
|
||||
|
||||
payload = {
|
||||
"text": sanitized_text,
|
||||
"voice_id": sanitized_voice_id,
|
||||
"speed": sanitized_speed,
|
||||
"volume": sanitized_volume,
|
||||
"pitch": sanitized_pitch,
|
||||
"emotion": mapped_emotion,
|
||||
"enable_sync_mode": bool(enable_sync_mode),
|
||||
}
|
||||
|
||||
# Add optional parameters with proper type validation
|
||||
optional_params = [
|
||||
"english_normalization",
|
||||
"sample_rate",
|
||||
"sample_rate",
|
||||
"bitrate",
|
||||
"channel",
|
||||
"format",
|
||||
"language_boost",
|
||||
]
|
||||
for param in optional_params:
|
||||
if param in kwargs:
|
||||
payload[param] = kwargs[param]
|
||||
if param in kwargs and kwargs[param] is not None:
|
||||
value = kwargs[param]
|
||||
# Convert to appropriate type based on parameter
|
||||
if param == "english_normalization":
|
||||
payload[param] = bool(value)
|
||||
elif param in ["sample_rate", "bitrate"]:
|
||||
payload[param] = int(value) if value is not None else None
|
||||
else:
|
||||
payload[param] = str(value).strip() if value is not None else None
|
||||
|
||||
logger.info(f"[WaveSpeed] Generating speech via {url} (voice={voice_id}, text_length={len(text)})")
|
||||
logger.debug(f"[WaveSpeed] Payload being sent: {payload}")
|
||||
|
||||
# Retry on transient connection issues
|
||||
max_retries = 2
|
||||
|
||||
Reference in New Issue
Block a user