feat(podcast): add pre-estimate endpoint, enhance cost estimator with multi-model support, cleanup alpha pricing seeding
- Add POST /podcast/pre-estimate endpoint for cost estimation before analysis - Enhance cost_estimator.py with multi-model support (gemini, audio, voice clone, image, video) - Add detailed cost breakdown (llm, audio, media costs + per-phase breakdown) - Remove redundant pricing seeding from init_alpha_subscription_tiers.py - Add SSOT pricing via PricingService.initialize_default_pricing() - Update TopicUrlInput tooltip to show estimate details - Add debug logging for pricing seeding and pre-estimate - Clean up verbose podcast mode debug logs in app.py
This commit is contained in:
@@ -3,6 +3,13 @@ Podcast cost estimation helpers.
|
||||
|
||||
Builds user-facing podcast estimates from the subscription pricing catalog
|
||||
instead of hard-coded frontend heuristics.
|
||||
|
||||
Supports multiple models for each component:
|
||||
- Audio TTS: minimax/speech-02-hd (default), qwen3-tts, cosyvoice-tts
|
||||
- Voice Clone: qwen3, cosyvoice, minimax
|
||||
- Image: qwen-image (default), ideogram-v3-turbo
|
||||
- Video: wan-2.5 (default), kling-v2.5, infinitetalk
|
||||
- LLM: gemini-2.5-flash (default)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -23,6 +30,7 @@ def _load_pricing(
|
||||
provider: APIProvider,
|
||||
preferred_model: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load pricing for a provider and model, with fallback to default."""
|
||||
pricing = pricing_service.get_pricing_for_provider_model(provider, preferred_model)
|
||||
if pricing:
|
||||
return pricing
|
||||
@@ -30,6 +38,17 @@ def _load_pricing(
|
||||
return pricing_service.get_pricing_for_provider_model(provider, "default")
|
||||
|
||||
|
||||
# Default models used in podcast generation
|
||||
DEFAULT_MODELS = {
|
||||
"gemini": "gemini-2.5-flash",
|
||||
"exa": "exa-search",
|
||||
"audio_tts": "minimax/speech-02-hd",
|
||||
"voice_clone": "wavespeed-ai/qwen3-tts/voice-clone",
|
||||
"image": "qwen-image",
|
||||
"video": "wan-2.5",
|
||||
}
|
||||
|
||||
|
||||
def estimate_podcast_cost(
|
||||
*,
|
||||
db: Session,
|
||||
@@ -37,88 +56,150 @@ def estimate_podcast_cost(
|
||||
speakers: int,
|
||||
query_count: int,
|
||||
include_avatar_phase: bool = True,
|
||||
# Optional model overrides
|
||||
gemini_model: str = "gemini-2.5-flash",
|
||||
audio_tts_model: str = "minimax/speech-02-hd",
|
||||
voice_clone_engine: str = "qwen3",
|
||||
image_model: str = "qwen-image",
|
||||
video_model: str = "wan-2.5",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Compute a backend estimate for podcast creation.
|
||||
|
||||
Returns None when pricing rows are unavailable so UI can display "Unavailable".
|
||||
|
||||
Supports customizable models for each component.
|
||||
Uses pricing_catalog for accurate cost calculation.
|
||||
"""
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
gemini_pricing = _load_pricing(pricing_service, APIProvider.GEMINI, "gemini-2.5-flash")
|
||||
# Load pricing for each component and model
|
||||
gemini_pricing = _load_pricing(pricing_service, APIProvider.GEMINI, gemini_model)
|
||||
exa_pricing = _load_pricing(pricing_service, APIProvider.EXA, "exa-search")
|
||||
audio_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, "minimax/speech-02-hd")
|
||||
video_pricing = _load_pricing(pricing_service, APIProvider.VIDEO, "default")
|
||||
image_pricing = _load_pricing(pricing_service, APIProvider.STABILITY, "qwen-image")
|
||||
|
||||
# Audio TTS pricing (minimax/speech-02-hd)
|
||||
audio_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, audio_tts_model)
|
||||
|
||||
# Voice clone pricing (different engines)
|
||||
voice_clone_model = f"wavespeed-ai/{voice_clone_engine}-tts/voice-clone"
|
||||
voice_clone_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, voice_clone_model)
|
||||
if not voice_clone_pricing:
|
||||
# Try alternate model names
|
||||
voice_clone_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, f"{voice_clone_engine}/voice-clone")
|
||||
|
||||
# Image pricing (qwen-image or ideogram)
|
||||
image_pricing = _load_pricing(pricing_service, APIProvider.STABILITY, image_model)
|
||||
|
||||
# Video pricing (wan-2.5, kling, or infinitetalk)
|
||||
video_pricing = _load_pricing(pricing_service, APIProvider.VIDEO, video_model)
|
||||
|
||||
# Return None if critical pricing unavailable (fail fast)
|
||||
if not gemini_pricing:
|
||||
return None
|
||||
|
||||
# Configuration
|
||||
minutes = max(1, int(duration_minutes or 1))
|
||||
speaker_count = max(1, int(speakers or 1))
|
||||
research_queries = max(1, int(query_count or 1))
|
||||
|
||||
# Phase-level usage assumptions (token/request proxies for pre-creation estimate).
|
||||
# Token usage assumptions per phase
|
||||
analysis_input_tokens = 1800
|
||||
analysis_output_tokens = 1000
|
||||
research_synthesis_input_tokens = 2200
|
||||
research_synthesis_output_tokens = 900
|
||||
script_input_tokens = max(1800, minutes * 300)
|
||||
script_output_tokens = max(2200, minutes * 700)
|
||||
|
||||
# TTS token proxy: ~900 chars per minute per speaker.
|
||||
|
||||
# TTS: ~900 chars per minute per speaker
|
||||
estimated_tts_tokens = max(900, minutes * 900 * speaker_count)
|
||||
|
||||
# Voice clone: 1 clone operation per speaker
|
||||
voice_clone_count = speaker_count
|
||||
|
||||
# ===== COST CALCULATIONS =====
|
||||
|
||||
# 1. Analysis phase (LLM)
|
||||
analysis_cost = (
|
||||
analysis_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ analysis_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
|
||||
+ float(gemini_pricing.get("cost_per_request") or 0.0)
|
||||
)
|
||||
|
||||
# 2. Research phase
|
||||
# 2a. LLM for research synthesis
|
||||
research_llm_cost = (
|
||||
research_synthesis_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ research_synthesis_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
|
||||
+ float(gemini_pricing.get("cost_per_request") or 0.0)
|
||||
)
|
||||
script_cost = (
|
||||
script_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ script_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
|
||||
+ float(gemini_pricing.get("cost_per_request") or 0.0)
|
||||
)
|
||||
|
||||
# 2b. Search API (Exa)
|
||||
research_search_cost = 0.0
|
||||
if exa_pricing:
|
||||
research_search_cost = research_queries * float(exa_pricing.get("cost_per_request") or 0.0)
|
||||
research_cost = research_search_cost + research_llm_cost
|
||||
|
||||
# 3. Script generation (LLM)
|
||||
script_cost = (
|
||||
script_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ script_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
|
||||
)
|
||||
|
||||
# 4. Audio TTS
|
||||
tts_cost = 0.0
|
||||
if audio_pricing:
|
||||
tts_cost = (
|
||||
estimated_tts_tokens * float(audio_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ float(audio_pricing.get("cost_per_request") or 0.0)
|
||||
tts_cost = estimated_tts_tokens * float(audio_pricing.get("cost_per_input_token") or 0.0)
|
||||
|
||||
# 5. Voice cloning (if needed)
|
||||
voice_clone_cost = 0.0
|
||||
if voice_clone_pricing:
|
||||
voice_clone_cost = voice_clone_count * (
|
||||
float(voice_clone_pricing.get("cost_per_request") or 0.0)
|
||||
+ estimated_tts_tokens * float(voice_clone_pricing.get("cost_per_input_token") or 0.0)
|
||||
)
|
||||
|
||||
# Assume one video render request per minute (upper-bound planning estimate).
|
||||
video_cost = 0.0
|
||||
if video_pricing:
|
||||
video_cost = minutes * float(video_pricing.get("cost_per_request") or 0.0)
|
||||
|
||||
# 6. Avatar image generation
|
||||
avatar_cost = 0.0
|
||||
if include_avatar_phase and image_pricing:
|
||||
image_unit = float(image_pricing.get("cost_per_image") or image_pricing.get("cost_per_request") or 0.0)
|
||||
avatar_cost = speaker_count * image_unit
|
||||
|
||||
research_cost = research_search_cost + research_llm_cost
|
||||
total = analysis_cost + research_cost + script_cost + tts_cost + video_cost + avatar_cost
|
||||
# 7. Video rendering
|
||||
video_cost = 0.0
|
||||
if video_pricing:
|
||||
# Assume 1 video render per minute (upper bound)
|
||||
video_cost = minutes * float(video_pricing.get("cost_per_request") or 0.0)
|
||||
|
||||
# ===== TOTALS =====
|
||||
llm_total = analysis_cost + research_llm_cost + script_cost
|
||||
audio_total = tts_cost + voice_clone_cost
|
||||
media_total = avatar_cost + video_cost
|
||||
total = llm_total + research_search_cost + audio_total + media_total
|
||||
|
||||
return {
|
||||
# Cost breakdown
|
||||
"analysisCost": _round_money(analysis_cost),
|
||||
"researchCost": _round_money(research_cost),
|
||||
"researchSearchCost": _round_money(research_search_cost),
|
||||
"researchLlmCost": _round_money(research_llm_cost),
|
||||
"scriptCost": _round_money(script_cost),
|
||||
"ttsCost": _round_money(tts_cost),
|
||||
"voiceCloneCost": _round_money(voice_clone_cost),
|
||||
"avatarCost": _round_money(avatar_cost),
|
||||
"videoCost": _round_money(video_cost),
|
||||
"researchCost": _round_money(research_cost),
|
||||
"analysisCost": _round_money(analysis_cost),
|
||||
"scriptCost": _round_money(script_cost),
|
||||
"total": _round_money(total),
|
||||
# Totals by category
|
||||
"llmCost": _round_money(llm_total),
|
||||
"audioCost": _round_money(audio_total),
|
||||
"mediaCost": _round_money(media_total),
|
||||
# Currency
|
||||
"currency": "USD",
|
||||
"source": "pricing_catalog",
|
||||
# Models used for this estimate
|
||||
"models": {
|
||||
"llm": gemini_model,
|
||||
"research": "exa-search",
|
||||
"audio_tts": audio_tts_model,
|
||||
"voice_clone": voice_clone_model,
|
||||
"image": image_model,
|
||||
"video": video_model,
|
||||
},
|
||||
# Assumptions used
|
||||
"assumptions": {
|
||||
"analysis_input_tokens": analysis_input_tokens,
|
||||
"analysis_output_tokens": analysis_output_tokens,
|
||||
@@ -128,6 +209,8 @@ def estimate_podcast_cost(
|
||||
"script_output_tokens": script_output_tokens,
|
||||
"estimated_tts_tokens": estimated_tts_tokens,
|
||||
"research_queries": research_queries,
|
||||
"voice_clone_count": voice_clone_count,
|
||||
"video_requests": minutes,
|
||||
"avatar_requests": speaker_count if include_avatar_phase else 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -4,8 +4,9 @@ Podcast Analysis Handlers
|
||||
Analysis endpoint for podcast ideas.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -21,11 +22,18 @@ from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
import os
|
||||
from ..constants import get_podcast_media_dir
|
||||
from ..prompts import get_enhance_topic_prompt, format_website_context
|
||||
from ..models import (
|
||||
PodcastAnalyzeRequest,
|
||||
PodcastAnalyzeResponse,
|
||||
PodcastEnhanceIdeaRequest,
|
||||
PodcastEnhanceIdeaResponse
|
||||
PodcastEnhanceIdeaResponse,
|
||||
ExtractUrlRequest,
|
||||
ExtractUrlResponse,
|
||||
WebsiteAnalysisRequest,
|
||||
WebsiteAnalysisResponse,
|
||||
PodcastPreEstimateRequest,
|
||||
PodcastPreEstimateResponse,
|
||||
)
|
||||
from ..cost_estimator import estimate_podcast_cost
|
||||
|
||||
@@ -37,6 +45,74 @@ def _is_podcast_only_mode() -> bool:
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/pre-estimate", response_model=PodcastPreEstimateResponse)
|
||||
async def pre_estimate_cost(
|
||||
request: PodcastPreEstimateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Lightweight endpoint to estimate podcast creation cost before analysis.
|
||||
|
||||
Takes user configuration (duration, speakers, query_count, podcast_mode) and returns
|
||||
a cost estimate WITHOUT running full analysis.
|
||||
|
||||
Optional model overrides can be specified to estimate with different models.
|
||||
"""
|
||||
try:
|
||||
include_avatar_phase = request.podcast_mode != "audio_only"
|
||||
|
||||
estimate = estimate_podcast_cost(
|
||||
db=db,
|
||||
duration_minutes=request.duration,
|
||||
speakers=request.speakers,
|
||||
query_count=request.query_count,
|
||||
include_avatar_phase=include_avatar_phase,
|
||||
# Model overrides if provided
|
||||
gemini_model=request.gemini_model or "gemini-2.5-flash",
|
||||
audio_tts_model=request.audio_tts_model or "minimax/speech-02-hd",
|
||||
voice_clone_engine=request.voice_clone_engine or "qwen3",
|
||||
image_model=request.image_model or "qwen-image",
|
||||
video_model=request.video_model or "wan-2.5",
|
||||
)
|
||||
|
||||
# Debug: get pricing row count and providers
|
||||
from models.subscription_models import APIProviderPricing
|
||||
pricing_count = db.query(APIProviderPricing).count()
|
||||
providers = db.query(APIProviderPricing.provider).distinct().all()
|
||||
provider_list = sorted([p[0].value for p in providers]) if providers else []
|
||||
|
||||
debug_info = {
|
||||
"pricing_rows": pricing_count,
|
||||
"providers": provider_list,
|
||||
}
|
||||
|
||||
# Log pricing debug info at warning level
|
||||
logger.warning(f"[PRE-ESTIMATE] Pricing debug: rows={pricing_count}, providers={provider_list}")
|
||||
logger.warning(f"[PRE-ESTIMATE] Models: llm={request.gemini_model}, tts={request.audio_tts_model}, video={request.video_model}")
|
||||
|
||||
if estimate is None:
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=None,
|
||||
error="Pricing data unavailable. Please try again later.",
|
||||
pricing_available=False,
|
||||
debug=debug_info,
|
||||
)
|
||||
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=estimate,
|
||||
error=None,
|
||||
pricing_available=True,
|
||||
debug=debug_info,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pre-estimate error: {e}")
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=None,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/idea/enhance", response_model=PodcastEnhanceIdeaResponse)
|
||||
async def enhance_podcast_idea(
|
||||
request: PodcastEnhanceIdeaRequest,
|
||||
@@ -77,39 +153,27 @@ async def enhance_podcast_idea(
|
||||
except Exception as exc:
|
||||
logger.debug(f"[Podcast Enhance] Bible parsing skipped in podcast mode: {exc}")
|
||||
|
||||
prompt = f"""
|
||||
You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
# Log what's being used for context
|
||||
context_used = []
|
||||
if bible_context:
|
||||
context_used.append("Podcast Bible")
|
||||
if request.website_data:
|
||||
context_used.append("Website Extraction")
|
||||
if request.topic_context:
|
||||
category = request.topic_context.get("category", "unknown")
|
||||
context_used.append(f"Category Research ({category})")
|
||||
|
||||
logger.warning(f"[Podcast Enhance] Generating with context: {', '.join(context_used) if context_used else 'basic idea only'}")
|
||||
|
||||
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{request.idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}}
|
||||
"""
|
||||
# Use new context builder for prompt generation
|
||||
from services.podcast_context_builder import context_builder
|
||||
context_result = context_builder.build_enhance_context(
|
||||
idea=request.idea,
|
||||
bible_context=bible_context,
|
||||
website_data=request.website_data,
|
||||
topic_context=request.topic_context,
|
||||
)
|
||||
prompt = context_result["prompt"]
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
@@ -502,3 +566,316 @@ Requirements:
|
||||
except Exception as exc:
|
||||
logger.error(f"[Regenerate Queries] Failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Regenerate queries failed: {exc}")
|
||||
|
||||
|
||||
@router.post("/extract-url", response_model=ExtractUrlResponse)
|
||||
async def extract_url_content(
|
||||
request: ExtractUrlRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Extract content from a URL using Exa's get_contents API.
|
||||
|
||||
This allows users to paste a blog post or article URL as their podcast topic,
|
||||
and we'll extract the content to use as the podcast idea.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
from exa_py import Exa
|
||||
import os
|
||||
|
||||
api_key = os.getenv("EXA_API_KEY")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="EXA_API_KEY not configured")
|
||||
|
||||
exa = Exa(api_key)
|
||||
|
||||
logger.warning(f"[ExtractUrl] Extracting content from: {request.url} for user {user_id}")
|
||||
|
||||
try:
|
||||
result = exa.get_contents(
|
||||
urls=[request.url],
|
||||
text=True,
|
||||
highlights=True,
|
||||
summary=True,
|
||||
subpages=2,
|
||||
)
|
||||
except Exception as exa_error:
|
||||
logger.error(f"[ExtractUrl] Exa call error: {exa_error}")
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error=f"Exa API error: {str(exa_error)}"
|
||||
)
|
||||
|
||||
# Check for errors using the correct attribute (statuses is array of status objects)
|
||||
if hasattr(result, 'statuses') and result.statuses:
|
||||
for status in result.statuses:
|
||||
if status.status == "error":
|
||||
logger.error(f"[ExtractUrl] Failed to extract {status.id}: {status.error.tag if hasattr(status.error, 'tag') else 'unknown'}")
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error=f"Failed to extract content: {status.error.tag if hasattr(status.error, 'tag') else 'unknown error'}"
|
||||
)
|
||||
|
||||
if not result.results:
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error="No content found at the provided URL"
|
||||
)
|
||||
|
||||
# Extract content - safe to access result now
|
||||
content = result.results[0]
|
||||
|
||||
# Extract all available fields from Exa response
|
||||
extracted_text = content.text or ""
|
||||
extracted_summary = getattr(content, 'summary', "") or ""
|
||||
extracted_title = content.title or ""
|
||||
|
||||
# Highlights - extract from content.highlights array if available
|
||||
highlights = []
|
||||
if hasattr(content, 'highlights') and content.highlights:
|
||||
highlights = [h for h in content.highlights if h]
|
||||
|
||||
# Additional fields from Exa response
|
||||
image = getattr(content, 'image', None)
|
||||
favicon = getattr(content, 'favicon', None)
|
||||
|
||||
# Subpages - extract with their own content
|
||||
subpages = []
|
||||
if hasattr(content, 'subpages') and content.subpages:
|
||||
for sp in content.subpages:
|
||||
subpages.append({
|
||||
'id': sp.get('id', ''),
|
||||
'title': sp.get('title', ''),
|
||||
'url': sp.get('url', ''),
|
||||
'summary': sp.get('summary', ''),
|
||||
'text': sp.get('text', '')[:500] if sp.get('text') else '', # First 500 chars
|
||||
})
|
||||
|
||||
logger.warning(f"[ExtractUrl] Successfully extracted {len(extracted_text)} chars from {request.url}")
|
||||
logger.warning(f"[ExtractUrl] title={extracted_title[:50]}, summary={extracted_summary[:50]}, highlights={len(highlights)}, subpages={len(subpages)}")
|
||||
|
||||
return ExtractUrlResponse(
|
||||
success=True,
|
||||
title=extracted_title,
|
||||
text=extracted_text,
|
||||
summary=extracted_summary,
|
||||
author=getattr(content, 'author', None),
|
||||
highlights=highlights,
|
||||
url=request.url,
|
||||
image=image,
|
||||
favicon=favicon,
|
||||
subpages=subpages,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/website-analysis", response_model=WebsiteAnalysisResponse)
|
||||
async def save_website_analysis(
|
||||
request: WebsiteAnalysisRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save the user's website analysis for reuse in future podcasts."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.user_data_service import user_data_service
|
||||
|
||||
website_data = {
|
||||
"website_url": request.website_url,
|
||||
"extracted_at": datetime.now().isoformat(),
|
||||
"exa_content": request.exa_content,
|
||||
"full_analysis": None,
|
||||
"analysis_status": "pending",
|
||||
}
|
||||
|
||||
success = user_data_service.save_user_data(
|
||||
user_id=user_id,
|
||||
data_key="website_analysis",
|
||||
data_value=website_data,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.warning(f"[WebsiteAnalysis] Saved analysis for user {user_id}: {request.website_url}")
|
||||
return WebsiteAnalysisResponse(
|
||||
success=True,
|
||||
website_url=request.website_url,
|
||||
message="Website analysis saved successfully",
|
||||
)
|
||||
else:
|
||||
return WebsiteAnalysisResponse(
|
||||
success=False,
|
||||
error="Failed to save website analysis",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteAnalysis] Failed to save for user {user_id}: {exc}")
|
||||
return WebsiteAnalysisResponse(
|
||||
success=False,
|
||||
error=f"Failed to save: {str(exc)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/website-extraction")
|
||||
async def get_saved_website_extraction(request: Request = None):
|
||||
"""Get previously saved website extraction data for this user."""
|
||||
try:
|
||||
# Safely get current_user from Depends
|
||||
if request is None or not hasattr(request, 'state'):
|
||||
logger.warning("[WebsiteExtraction] No request or state - user not authenticated")
|
||||
return {"success": False, "data": None, "error": "Not authenticated"}
|
||||
|
||||
current_user = getattr(request.state, 'user', None)
|
||||
if not current_user:
|
||||
logger.warning("[WebsiteExtraction] No user in request state")
|
||||
return {"success": False, "data": None, "error": "Not authenticated"}
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
from services.user_data_service import UserDataService
|
||||
from services.database import get_db
|
||||
db = next(get_db())
|
||||
|
||||
user_service = UserDataService(db)
|
||||
extraction = user_service.get_website_extraction(user_id)
|
||||
|
||||
if extraction:
|
||||
logger.info(f"[WebsiteExtraction] Found saved data for user {user_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": extraction
|
||||
}
|
||||
else:
|
||||
logger.info(f"[WebsiteExtraction] No saved data for user {user_id}")
|
||||
return {
|
||||
"success": False,
|
||||
"data": None
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteExtraction] Failed for user: {exc}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/website-extraction")
|
||||
async def save_website_extraction(
|
||||
extraction: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save website extraction data for future use."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.user_data_service import UserDataService
|
||||
from services.database import get_db
|
||||
db = next(get_db())
|
||||
|
||||
user_service = UserDataService(db)
|
||||
success = user_service.save_website_extraction(user_id, extraction)
|
||||
|
||||
if success:
|
||||
logger.info(f"[WebsiteExtraction] Saved for user {user_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Website extraction saved"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Failed to save"
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteExtraction] Save failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/project/{project_id}/topic-context")
|
||||
async def save_topic_context(
|
||||
project_id: str,
|
||||
topic_context: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save topic context (category research) to a podcast project."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from models.podcast_models import PodcastProject
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
# Find the project
|
||||
project = db.query(PodcastProject).filter(
|
||||
PodcastProject.project_id == project_id,
|
||||
PodcastProject.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not project:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Project not found"
|
||||
}
|
||||
|
||||
# Update topic context
|
||||
project.topic_context = topic_context
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[TopicContext] Saved for project {project_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Topic context saved"
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[TopicContext] Save failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/project/{project_id}/topic-context")
|
||||
async def get_topic_context(
|
||||
project_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get topic context from a podcast project."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from models.podcast_models import PodcastProject
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
project = db.query(PodcastProject).filter(
|
||||
PodcastProject.project_id == project_id,
|
||||
PodcastProject.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not project:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Project not found"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": project.topic_context
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[TopicContext] Get failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
251
backend/api/podcast/handlers/tavily_category_research.py
Normal file
251
backend/api/podcast/handlers/tavily_category_research.py
Normal file
@@ -0,0 +1,251 @@
|
||||
"""
|
||||
Category Research Handlers
|
||||
|
||||
Research endpoints using Tavily or Exa for category-based topic discovery.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from types import SimpleNamespace
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.research.tavily_service import TavilyService
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
|
||||
router = APIRouter(prefix="/research", tags=["Podcast Category Research"])
|
||||
|
||||
CATEGORY_PROVIDER_MAP = {
|
||||
"news": "tavily",
|
||||
"finance": "tavily",
|
||||
"research-paper": "exa",
|
||||
"personal-site": "exa",
|
||||
}
|
||||
|
||||
EXA_CATEGORY_MAP = {
|
||||
"research-paper": "research paper",
|
||||
"personal-site": "personal site",
|
||||
}
|
||||
|
||||
|
||||
class CategoryResearchRequest(BaseModel):
|
||||
category: str
|
||||
keyword: Optional[str] = None
|
||||
max_results: Optional[int] = 8
|
||||
website_url: Optional[str] = None
|
||||
|
||||
|
||||
class CategoryTopic(BaseModel):
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
score: float
|
||||
favicon: Optional[str] = None
|
||||
|
||||
|
||||
class CategoryResearchResponse(BaseModel):
|
||||
success: bool
|
||||
category: str
|
||||
provider: str
|
||||
topics: List[CategoryTopic]
|
||||
query: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def _normalize_tavily_results(results: List[Dict]) -> List[CategoryTopic]:
|
||||
topics = []
|
||||
for item in results:
|
||||
topics.append(CategoryTopic(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("content", ""),
|
||||
score=item.get("score", 0.0),
|
||||
favicon=item.get("favicon"),
|
||||
))
|
||||
return topics
|
||||
|
||||
|
||||
def _normalize_exa_results(results: List[Dict], query: str) -> List[CategoryTopic]:
|
||||
topics = []
|
||||
for idx, item in enumerate(results):
|
||||
score = 1.0 - (idx * 0.1)
|
||||
topics.append(CategoryTopic(
|
||||
title=item.get("title", "") or f"Result {idx + 1}",
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("summary", "") or item.get("text", "") or "",
|
||||
score=max(0.5, score),
|
||||
favicon=None,
|
||||
))
|
||||
return topics
|
||||
|
||||
|
||||
async def _search_tavily(category: str, keyword: str, max_results: int) -> CategoryResearchResponse:
|
||||
logger.info(f"[CategoryResearch] Using Tavily for category={category}, keyword={keyword}")
|
||||
|
||||
try:
|
||||
tavily = TavilyService()
|
||||
result = await tavily.search(
|
||||
query=keyword,
|
||||
topic=category,
|
||||
search_depth="basic",
|
||||
max_results=max_results,
|
||||
include_favicon=True,
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=result.get("error", "Tavily search failed")
|
||||
)
|
||||
|
||||
topics = _normalize_tavily_results(result.get("results", []))
|
||||
logger.info(f"[CategoryResearch] Tavily found {len(topics)} topics")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
provider="tavily",
|
||||
topics=topics,
|
||||
query=keyword,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Tavily error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _search_exa(category: str, keyword: str, max_results: int, website_url: Optional[str] = None) -> CategoryResearchResponse:
|
||||
exa_category = EXA_CATEGORY_MAP.get(category, category)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa: category={category}, exa_category={exa_category}, keyword={keyword}, website_url={website_url}")
|
||||
|
||||
try:
|
||||
# Import exa directly for more control
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
exa_api_key = os.getenv("EXA_API_KEY")
|
||||
if not exa_api_key:
|
||||
raise HTTPException(status_code=500, detail="EXA_API_KEY not configured")
|
||||
|
||||
from exa_py import Exa
|
||||
exa = Exa(exa_api_key)
|
||||
logger.info(f"[CategoryResearch] Exa client initialized")
|
||||
|
||||
# Build search parameters
|
||||
search_params = {
|
||||
"num_results": max_results,
|
||||
"category": exa_category,
|
||||
}
|
||||
|
||||
# For personal-site, extract domain from URL if provided
|
||||
include_domains = None
|
||||
if category == "personal-site" and website_url:
|
||||
try:
|
||||
parsed = urlparse(website_url)
|
||||
if parsed.netloc:
|
||||
include_domains = [parsed.netloc]
|
||||
logger.info(f"[CategoryResearch] Personal site - limiting to domain: {parsed.netloc}")
|
||||
elif parsed.path and "." in parsed.path:
|
||||
# Could be domain without protocol
|
||||
include_domains = [parsed.path]
|
||||
logger.info(f"[CategoryResearch] Personal site - using as domain: {parsed.path}")
|
||||
except Exception as url_err:
|
||||
logger.warning(f"[CategoryResearch] Failed to parse website_url: {url_err}")
|
||||
|
||||
logger.info(f"[CategoryResearch] Calling Exa with params: {search_params}, include_domains={include_domains}")
|
||||
|
||||
# Make the search call
|
||||
results = exa.search_and_contents(
|
||||
query=keyword,
|
||||
type="auto" if category != "personal-site" else "neural",
|
||||
num_results=max_results,
|
||||
category=exa_category,
|
||||
text=True,
|
||||
summary=True,
|
||||
include_domains=include_domains,
|
||||
)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa search completed, got results")
|
||||
|
||||
# Transform results to our format
|
||||
topics = []
|
||||
if results and hasattr(results, 'results'):
|
||||
for item in results.results:
|
||||
title = getattr(item, 'title', 'Untitled')
|
||||
url = getattr(item, 'url', '')
|
||||
snippet = getattr(item, 'summary', '') or getattr(item, 'text', '') or ''
|
||||
score = 0.8 # Default score for Exa results
|
||||
|
||||
topics.append(CategoryTopic(
|
||||
title=title,
|
||||
url=url,
|
||||
snippet=snippet[:300] if snippet else '',
|
||||
score=score,
|
||||
favicon=None,
|
||||
))
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa found {len(topics)} topics")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
provider="exa",
|
||||
topics=topics,
|
||||
query=keyword,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"[CategoryResearch] Exa error: {type(e).__name__}: {e}")
|
||||
logger.error(f"[CategoryResearch] Stack: {traceback.format_exc()}")
|
||||
raise HTTPException(status_code=500, detail=f"Exa search failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tavily-category", response_model=CategoryResearchResponse)
|
||||
async def research_by_category(
|
||||
request: CategoryResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Research topics by category using Tavily or Exa.
|
||||
|
||||
Categories:
|
||||
- news, finance: Uses Tavily
|
||||
- research-paper, personal-site: Uses Exa
|
||||
"""
|
||||
category = request.category.lower()
|
||||
valid_categories = list(CATEGORY_PROVIDER_MAP.keys())
|
||||
|
||||
logger.info(f"[CategoryResearch] Full request payload: category={request.category}, keyword={request.keyword}, website_url={request.website_url}")
|
||||
|
||||
if category not in valid_categories:
|
||||
logger.error(f"[CategoryResearch] Invalid category: {category}, valid: {valid_categories}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Category must be one of: {', '.join(valid_categories)}"
|
||||
)
|
||||
|
||||
keyword = request.keyword or category
|
||||
max_results = min(max(request.max_results or 8, 5), 10)
|
||||
website_url = request.website_url
|
||||
|
||||
logger.info(f"[CategoryResearch] Processing: category={category}, keyword={keyword}, max_results={max_results}, website_url={website_url}")
|
||||
|
||||
provider = CATEGORY_PROVIDER_MAP.get(category, "tavily")
|
||||
logger.info(f"[CategoryResearch] Selected provider: {provider} for category: {category}")
|
||||
|
||||
try:
|
||||
if provider == "tavily":
|
||||
return await _search_tavily(category, keyword, max_results)
|
||||
elif provider == "exa":
|
||||
return await _search_exa(category, keyword, max_results, website_url)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unknown provider")
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Outer error: {type(e).__name__}: {e}", exc_info=True)
|
||||
raise
|
||||
@@ -18,6 +18,7 @@ class PodcastTrendsRequest(BaseModel):
|
||||
keywords: List[str] = Field(..., min_length=1, max_length=5, description="1-5 keywords to analyze")
|
||||
timeframe: str = Field(default="today 12-m", description="Timeframe: 'today 3-m', 'today 12-m', 'today 5-y', 'all'")
|
||||
geo: str = Field(default="US", description="Country code: 'US', 'GB', 'IN', etc.")
|
||||
source: str = Field(default="web", description="Data source: 'web' (Google), 'podcast' (YouTube)")
|
||||
|
||||
|
||||
class PodcastTrendsResponse(BaseModel):
|
||||
@@ -47,12 +48,39 @@ async def get_podcast_trends(
|
||||
|
||||
try:
|
||||
service = GoogleTrendsService()
|
||||
# Map 'source' to 'gprop' - 'podcast' uses YouTube for video/podcast relevance
|
||||
gprop_map = {"": "", "web": "", "podcast": "youtube", "news": "news", "images": "images", "shopping": "froogle"}
|
||||
gprop = gprop_map.get(request.source, "")
|
||||
|
||||
result = await service.analyze_trends(
|
||||
keywords=request.keywords,
|
||||
timeframe=request.timeframe,
|
||||
geo=request.geo,
|
||||
gprop=gprop,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
has_error = result.get("error")
|
||||
has_data = (
|
||||
len(result.get("interest_over_time", [])) > 0
|
||||
or len(result.get("interest_by_region", [])) > 0
|
||||
or len(result.get("related_topics", {}).get("top", [])) > 0
|
||||
or len(result.get("related_topics", {}).get("rising", [])) > 0
|
||||
or len(result.get("related_queries", {}).get("top", [])) > 0
|
||||
or len(result.get("related_queries", {}).get("rising", [])) > 0
|
||||
)
|
||||
|
||||
# Return error if: has error OR no data (meaning blocked/empty)
|
||||
if has_error and not has_data:
|
||||
error_msg = result.get("error", "")
|
||||
logger.warning(f"[Trends] No data or error: {error_msg[:100]}")
|
||||
return PodcastTrendsResponse(success=False, data=result, error=error_msg or "No trends data available. Google may be blocking requests.")
|
||||
|
||||
# Even if no error but empty data - return error
|
||||
if not has_data:
|
||||
logger.warning("[Trends] Empty data returned")
|
||||
return PodcastTrendsResponse(success=False, data=result, error="No trends data available. Please try different keywords.")
|
||||
|
||||
return PodcastTrendsResponse(success=True, data=result)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
@@ -80,6 +80,14 @@ class PodcastEnhanceIdeaRequest(BaseModel):
|
||||
"""Request model for enhancing a podcast idea with AI."""
|
||||
idea: str = Field(..., description="The raw podcast idea or keywords")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
website_data: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Optional website extraction data for enriched context (title, summary, highlights, subpages, url)"
|
||||
)
|
||||
topic_context: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Optional category research context (category, topics, selected_topic)"
|
||||
)
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaResponse(BaseModel):
|
||||
@@ -470,3 +478,59 @@ class VoiceCloneResult(BaseModel):
|
||||
file_size: int
|
||||
task_id: str
|
||||
status: str = "completed"
|
||||
|
||||
|
||||
class ExtractUrlRequest(BaseModel):
|
||||
"""Request to extract content from a URL using Exa."""
|
||||
url: str = Field(..., description="URL to extract content from")
|
||||
|
||||
|
||||
class ExtractUrlResponse(BaseModel):
|
||||
"""Response with extracted content from URL."""
|
||||
success: bool
|
||||
title: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
highlights: Optional[List[str]] = Field(default_factory=list, description="Key highlights from the content")
|
||||
url: str
|
||||
image: Optional[str] = None
|
||||
favicon: Optional[str] = None
|
||||
subpages: Optional[List[Dict[str, Any]]] = Field(default_factory=list, description="Subpages with their own content")
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class WebsiteAnalysisRequest(BaseModel):
|
||||
"""Request to save user's website analysis."""
|
||||
website_url: str = Field(..., description="The website URL")
|
||||
exa_content: Dict[str, Any] = Field(default_factory=dict, description="Exa extracted content")
|
||||
|
||||
|
||||
class WebsiteAnalysisResponse(BaseModel):
|
||||
"""Response for website analysis."""
|
||||
success: bool
|
||||
website_url: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastPreEstimateRequest(BaseModel):
|
||||
"""Request model for pre-analysis cost estimate."""
|
||||
duration: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
query_count: int = Field(default=3, description="Number of research queries")
|
||||
podcast_mode: str = Field(default="audio_video", description="Podcast mode: audio_only, video_only, or audio_video")
|
||||
# Optional model overrides for cost estimation
|
||||
gemini_model: Optional[str] = Field(default=None, description="LLM model: gemini-2.5-flash, gemini-1.5-flash, etc.")
|
||||
audio_tts_model: Optional[str] = Field(default=None, description="Audio TTS model: minimax/speech-02-hd")
|
||||
voice_clone_engine: Optional[str] = Field(default=None, description="Voice clone engine: qwen3, cosyvoice, minimax")
|
||||
image_model: Optional[str] = Field(default=None, description="Image model: qwen-image, ideogram-v3-turbo")
|
||||
video_model: Optional[str] = Field(default=None, description="Video model: wan-2.5, kling-v2.5-turbo-std-5s, wavespeed-ai/infinitetalk")
|
||||
|
||||
|
||||
class PodcastPreEstimateResponse(BaseModel):
|
||||
"""Response model for pre-analysis cost estimate."""
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
pricing_available: bool = Field(default=False, description="Whether pricing data is available in DB")
|
||||
debug: Optional[Dict[str, Any]] = Field(default=None, description="Debug info: pricing rows count, providers")
|
||||
|
||||
24
backend/api/podcast/prompts/__init__.py
Normal file
24
backend/api/podcast/prompts/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Prompts module for podcast topic enhancement.
|
||||
"""
|
||||
|
||||
from .website_enhance_prompts import (
|
||||
get_enhance_topic_prompt,
|
||||
format_website_context,
|
||||
STANDARD_ENHANCE_PROMPT,
|
||||
WEBSITE_AWARE_ENHANCE_PROMPT,
|
||||
)
|
||||
|
||||
from services.podcast_context_builder import (
|
||||
PodcastContextBuilder,
|
||||
context_builder,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_enhance_topic_prompt",
|
||||
"format_website_context",
|
||||
"STANDARD_ENHANCE_PROMPT",
|
||||
"WEBSITE_AWARE_ENHANCE_PROMPT",
|
||||
"PodcastContextBuilder",
|
||||
"context_builder",
|
||||
]
|
||||
187
backend/api/podcast/prompts/website_enhance_prompts.py
Normal file
187
backend/api/podcast/prompts/website_enhance_prompts.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Website-aware prompts for podcast topic enhancement.
|
||||
|
||||
This module provides prompts for enhancing podcast topics with optional
|
||||
website extraction data for richer context.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from string import Template
|
||||
|
||||
|
||||
# Standard prompt for when no website data is available
|
||||
STANDARD_ENHANCE_PROMPT = Template("""">You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
|
||||
${bible_context}
|
||||
|
||||
RAW IDEA/KEYWORDS: "$idea"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
# Website-aware prompt for when website data is available
|
||||
WEBSITE_AWARE_ENHANCE_PROMPT = Template("""">You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea, enriched with website content analysis.
|
||||
|
||||
${bible_context}
|
||||
|
||||
WEBSITE CONTENT ANALYSIS:
|
||||
${website_context}
|
||||
|
||||
RAW IDEA/KEYWORDS: "$idea"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle, that INCORPORATE the website content context:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise from the website)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections tied to the brand)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance leveraging the site's focus areas)
|
||||
|
||||
Each version should:
|
||||
- Be 2-3 sentences
|
||||
- Reference specific elements from the website content when relevant
|
||||
- Be audience-focused and align with host persona if provided
|
||||
- NOT just repeat the website summary - create fresh podcast angles
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
def get_enhance_topic_prompt(
|
||||
idea: str,
|
||||
bible_context: str = "",
|
||||
website_data: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Returns the appropriate prompt based on available context.
|
||||
|
||||
Args:
|
||||
idea: The raw podcast idea or keywords
|
||||
bible_context: Optional Podcast Bible context string
|
||||
website_data: Optional website extraction data
|
||||
|
||||
Returns:
|
||||
Formatted prompt string with appropriate context
|
||||
"""
|
||||
# Build bible context section
|
||||
bible_section = f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""
|
||||
|
||||
if website_data:
|
||||
# Build website context section
|
||||
website_context_parts = []
|
||||
if website_data.get('url'):
|
||||
website_context_parts.append(f"Source: {website_data.get('url')}")
|
||||
if website_data.get('title'):
|
||||
website_context_parts.append(f"Company/Organization: {website_data.get('title')}")
|
||||
if website_data.get('summary'):
|
||||
website_context_parts.append(f"About: {website_data.get('summary')}")
|
||||
if website_data.get('highlights'):
|
||||
highlights_str = ', '.join(website_data.get('highlights', [])[:3])
|
||||
website_context_parts.append(f"Key Highlights: {highlights_str}")
|
||||
if website_data.get('subpages'):
|
||||
subpages_str = ', '.join([
|
||||
sp.get('title', sp.get('url', ''))
|
||||
for sp in website_data.get('subpages', [])[:3]
|
||||
])
|
||||
website_context_parts.append(f"Subpages: {subpages_str}")
|
||||
|
||||
website_context_str = "\n".join(website_context_parts)
|
||||
|
||||
return WEBSITE_AWARE_ENHANCE_PROMPT.substitute(
|
||||
idea=idea,
|
||||
bible_context=bible_section,
|
||||
website_context=website_context_str
|
||||
)
|
||||
else:
|
||||
return STANDARD_ENHANCE_PROMPT.substitute(
|
||||
idea=idea,
|
||||
bible_context=bible_section
|
||||
)
|
||||
|
||||
|
||||
def format_website_context(website_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format website data for inclusion in progress messages.
|
||||
|
||||
Args:
|
||||
website_data: Website extraction data
|
||||
|
||||
Returns:
|
||||
Formatted string describing what's being used
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if website_data.get('title'):
|
||||
parts.append(f"• {website_data['title']}")
|
||||
|
||||
if website_data.get('summary'):
|
||||
summary_preview = website_data['summary'][:100]
|
||||
parts.append(f"• Summary: {summary_preview}...")
|
||||
|
||||
if website_data.get('highlights'):
|
||||
parts.append(f"• {len(website_data['highlights'])} key highlights")
|
||||
|
||||
if website_data.get('subpages'):
|
||||
parts.append(f"• {len(website_data['subpages'])} subpages analyzed")
|
||||
|
||||
if website_data.get('url'):
|
||||
parts.append(f"• Source: {website_data['url']}")
|
||||
|
||||
return "\n".join(parts) if parts else "Basic website analysis"
|
||||
|
||||
if website_data.get('title'):
|
||||
parts.append(f"• {website_data['title']}")
|
||||
|
||||
if website_data.get('summary'):
|
||||
summary_preview = website_data['summary'][:100]
|
||||
parts.append(f"• Summary: {summary_preview}...")
|
||||
|
||||
if website_data.get('highlights'):
|
||||
parts.append(f"• {len(website_data['highlights'])} key highlights")
|
||||
|
||||
if website_data.get('subpages'):
|
||||
parts.append(f"• {len(website_data['subpages'])} subpages analyzed")
|
||||
|
||||
if website_data.get('url'):
|
||||
parts.append(f"• Source: {website_data['url']}")
|
||||
|
||||
return "\n".join(parts) if parts else "Basic website analysis"
|
||||
@@ -12,7 +12,7 @@ from api.story_writer.utils.auth import require_authenticated_user
|
||||
from api.story_writer.task_manager import task_manager
|
||||
|
||||
# Import all handler routers
|
||||
from .handlers import projects, analysis, research, script, audio, images, video, avatar, dubbing, broll, trends
|
||||
from .handlers import projects, analysis, research, script, audio, images, video, avatar, dubbing, broll, trends, tavily_category_research
|
||||
|
||||
# Create main router
|
||||
router = APIRouter(prefix="/api/podcast", tags=["Podcast Maker"])
|
||||
@@ -29,6 +29,7 @@ router.include_router(avatar.router)
|
||||
router.include_router(dubbing.router)
|
||||
router.include_router(broll.router)
|
||||
router.include_router(trends.router)
|
||||
router.include_router(tavily_category_research.router)
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/status")
|
||||
|
||||
@@ -52,7 +52,7 @@ def is_podcast_only_demo_mode() -> bool:
|
||||
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "all")
|
||||
enabled = get_enabled_features()
|
||||
result = "podcast" in enabled and "all" not in enabled
|
||||
print(f"[DEBUG] is_podcast_only_demo_mode: ALWRITY_ENABLED_FEATURES={env_val}, enabled={enabled}, result={result}", flush=True)
|
||||
# Removed debug print - too verbose during startup
|
||||
return result
|
||||
|
||||
|
||||
@@ -712,6 +712,9 @@ async def startup_event():
|
||||
try:
|
||||
_log_memory_usage()
|
||||
|
||||
# Note: Pricing is initialized per-user in services/database.py:init_user_database()
|
||||
# which runs on first database access for each user. No global seeding needed at startup.
|
||||
|
||||
# Skip startup health checks in podcast-only mode to avoid unnecessary DB errors
|
||||
if not is_podcast_only_demo_mode():
|
||||
startup_report = run_startup_health_routine(app)
|
||||
|
||||
@@ -45,6 +45,9 @@ class PodcastProject(Base):
|
||||
knobs = Column(JSON, nullable=True) # Knobs settings
|
||||
research_provider = Column(String(50), nullable=True, default="google") # Research provider
|
||||
|
||||
# Project-specific topic context (category research, selected topics)
|
||||
topic_context = Column(JSON, nullable=True) # { category: "news"|"finance", topics: [...], selected_topic: {...} }
|
||||
|
||||
# UI state
|
||||
show_script_editor = Column(Boolean, default=False)
|
||||
show_render_queue = Column(Boolean, default=False)
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
"""
|
||||
Initialize Alpha Tester Subscription Tiers
|
||||
Creates subscription plans for alpha testing with appropriate limits.
|
||||
|
||||
NOTE: Pricing is seeded via PricingService.initialize_default_pricing()
|
||||
which runs in services/database.py:init_user_database()
|
||||
NOT via this script.
|
||||
"""
|
||||
|
||||
import sys
|
||||
@@ -10,7 +14,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from models.subscription_models import (
|
||||
SubscriptionPlan, SubscriptionTier, APIProviderPricing, APIProvider
|
||||
SubscriptionPlan, SubscriptionTier
|
||||
)
|
||||
from services.database import get_db_session
|
||||
from datetime import datetime
|
||||
@@ -24,7 +28,7 @@ def create_alpha_subscription_tiers():
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error("❌ Could not get database session")
|
||||
logger.error("Could not get database session")
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -38,12 +42,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Free tier for alpha testing - Limited usage",
|
||||
"features": ["blog_writer", "basic_seo", "content_planning"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 50, # 50 calls per day
|
||||
"gemini_tokens_limit": 10000, # 10k tokens per day
|
||||
"tavily_calls_limit": 20, # 20 searches per day
|
||||
"serper_calls_limit": 10, # 10 SEO searches per day
|
||||
"stability_calls_limit": 5, # 5 images per day
|
||||
"monthly_cost_limit": 5.0 # $5 monthly limit
|
||||
"gemini_calls_limit": 50,
|
||||
"gemini_tokens_limit": 10000,
|
||||
"tavily_calls_limit": 20,
|
||||
"serper_calls_limit": 10,
|
||||
"stability_calls_limit": 5,
|
||||
"monthly_cost_limit": 5.0
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -54,12 +58,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Basic alpha tier - Moderate usage for testing",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 200, # 200 calls per day
|
||||
"gemini_tokens_limit": 50000, # 50k tokens per day
|
||||
"tavily_calls_limit": 100, # 100 searches per day
|
||||
"serper_calls_limit": 50, # 50 SEO searches per day
|
||||
"stability_calls_limit": 25, # 25 images per day
|
||||
"monthly_cost_limit": 25.0 # $25 monthly limit
|
||||
"gemini_calls_limit": 200,
|
||||
"gemini_tokens_limit": 50000,
|
||||
"tavily_calls_limit": 100,
|
||||
"serper_calls_limit": 50,
|
||||
"stability_calls_limit": 25,
|
||||
"monthly_cost_limit": 25.0
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -70,12 +74,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Pro alpha tier - High usage for power users",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot", "advanced_analytics"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 500, # 500 calls per day
|
||||
"gemini_tokens_limit": 150000, # 150k tokens per day
|
||||
"tavily_calls_limit": 300, # 300 searches per day
|
||||
"serper_calls_limit": 150, # 150 SEO searches per day
|
||||
"stability_calls_limit": 100, # 100 images per day
|
||||
"monthly_cost_limit": 100.0 # $100 monthly limit
|
||||
"gemini_calls_limit": 500,
|
||||
"gemini_tokens_limit": 150000,
|
||||
"tavily_calls_limit": 300,
|
||||
"serper_calls_limit": 150,
|
||||
"stability_calls_limit": 100,
|
||||
"monthly_cost_limit": 100.0
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -86,34 +90,31 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Enterprise alpha tier - Unlimited usage for enterprise testing",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot", "advanced_analytics", "custom_integrations"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 0, # Unlimited calls
|
||||
"gemini_tokens_limit": 0, # Unlimited tokens
|
||||
"tavily_calls_limit": 0, # Unlimited searches
|
||||
"serper_calls_limit": 0, # Unlimited SEO searches
|
||||
"stability_calls_limit": 0, # Unlimited images
|
||||
"monthly_cost_limit": 500.0 # $500 monthly limit
|
||||
"gemini_calls_limit": 0,
|
||||
"gemini_tokens_limit": 0,
|
||||
"tavily_calls_limit": 0,
|
||||
"serper_calls_limit": 0,
|
||||
"stability_calls_limit": 0,
|
||||
"monthly_cost_limit": 500.0
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Create subscription plans
|
||||
for tier_data in alpha_tiers:
|
||||
# Check if plan already exists
|
||||
existing_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.name == tier_data["name"]
|
||||
).first()
|
||||
|
||||
if existing_plan:
|
||||
logger.info(f"✅ Plan '{tier_data['name']}' already exists, updating...")
|
||||
# Update existing plan
|
||||
logger.info(f"Plan '{tier_data['name']}' already exists, updating...")
|
||||
for key, value in tier_data["limits"].items():
|
||||
setattr(existing_plan, key, value)
|
||||
existing_plan.description = tier_data["description"]
|
||||
existing_plan.features = tier_data["features"]
|
||||
existing_plan.updated_at = datetime.utcnow()
|
||||
else:
|
||||
logger.info(f"🆕 Creating new plan: {tier_data['name']}")
|
||||
# Create new plan
|
||||
logger.info(f"Creating new plan: {tier_data['name']}")
|
||||
plan = SubscriptionPlan(
|
||||
name=tier_data["name"],
|
||||
tier=tier_data["tier"],
|
||||
@@ -126,106 +127,17 @@ def create_alpha_subscription_tiers():
|
||||
db.add(plan)
|
||||
|
||||
db.commit()
|
||||
logger.info("✅ Alpha subscription tiers created/updated successfully!")
|
||||
|
||||
# Create API provider pricing
|
||||
create_api_pricing(db)
|
||||
logger.info("Alpha subscription tiers created/updated successfully!")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating alpha subscription tiers: {e}")
|
||||
logger.error(f"Error creating alpha subscription tiers: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def create_api_pricing(db: Session):
|
||||
"""Create API provider pricing configuration."""
|
||||
|
||||
try:
|
||||
# Gemini pricing (based on current Google AI pricing)
|
||||
gemini_pricing = [
|
||||
{
|
||||
"model_name": "gemini-2.0-flash-exp",
|
||||
"cost_per_input_token": 0.00000075, # $0.75 per 1M tokens
|
||||
"cost_per_output_token": 0.000003, # $3 per 1M tokens
|
||||
"description": "Gemini 2.0 Flash Experimental"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-1.5-flash",
|
||||
"cost_per_input_token": 0.00000075, # $0.75 per 1M tokens
|
||||
"cost_per_output_token": 0.000003, # $3 per 1M tokens
|
||||
"description": "Gemini 1.5 Flash"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-1.5-pro",
|
||||
"cost_per_input_token": 0.00000125, # $1.25 per 1M tokens
|
||||
"cost_per_output_token": 0.000005, # $5 per 1M tokens
|
||||
"description": "Gemini 1.5 Pro"
|
||||
}
|
||||
]
|
||||
|
||||
# Tavily pricing
|
||||
tavily_pricing = [
|
||||
{
|
||||
"model_name": "search",
|
||||
"cost_per_search": 0.001, # $0.001 per search
|
||||
"description": "Tavily Search API"
|
||||
}
|
||||
]
|
||||
|
||||
# Serper pricing
|
||||
serper_pricing = [
|
||||
{
|
||||
"model_name": "search",
|
||||
"cost_per_search": 0.001, # $0.001 per search
|
||||
"description": "Serper Google Search API"
|
||||
}
|
||||
]
|
||||
|
||||
# Stability AI pricing
|
||||
stability_pricing = [
|
||||
{
|
||||
"model_name": "stable-diffusion-xl",
|
||||
"cost_per_image": 0.01, # $0.01 per image
|
||||
"description": "Stable Diffusion XL"
|
||||
}
|
||||
]
|
||||
|
||||
# Create pricing records
|
||||
pricing_configs = [
|
||||
(APIProvider.GEMINI, gemini_pricing),
|
||||
(APIProvider.TAVILY, tavily_pricing),
|
||||
(APIProvider.SERPER, serper_pricing),
|
||||
(APIProvider.STABILITY, stability_pricing)
|
||||
]
|
||||
|
||||
for provider, pricing_list in pricing_configs:
|
||||
for pricing_data in pricing_list:
|
||||
# Check if pricing already exists
|
||||
existing_pricing = db.query(APIProviderPricing).filter(
|
||||
APIProviderPricing.provider == provider,
|
||||
APIProviderPricing.model_name == pricing_data["model_name"]
|
||||
).first()
|
||||
|
||||
if existing_pricing:
|
||||
logger.info(f"✅ Pricing for {provider.value}/{pricing_data['model_name']} already exists")
|
||||
else:
|
||||
logger.info(f"🆕 Creating pricing for {provider.value}/{pricing_data['model_name']}")
|
||||
pricing = APIProviderPricing(
|
||||
provider=provider,
|
||||
**pricing_data
|
||||
)
|
||||
db.add(pricing)
|
||||
|
||||
db.commit()
|
||||
logger.info("✅ API provider pricing created successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating API pricing: {e}")
|
||||
db.rollback()
|
||||
|
||||
def assign_default_plan_to_users():
|
||||
"""Assign Free Alpha plan to all existing users."""
|
||||
if os.getenv('ENABLE_ALPHA', 'false').lower() not in {'1','true','yes','on'}:
|
||||
@@ -234,32 +146,28 @@ def assign_default_plan_to_users():
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error("❌ Could not get database session")
|
||||
logger.error("Could not get database session")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get Free Alpha plan
|
||||
free_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.name == "Free Alpha"
|
||||
).first()
|
||||
|
||||
if not free_plan:
|
||||
logger.error("❌ Free Alpha plan not found")
|
||||
logger.error("Free Alpha plan not found")
|
||||
return False
|
||||
|
||||
# For now, we'll create a default user subscription
|
||||
# In a real system, you'd query actual users
|
||||
|
||||
from models.subscription_models import UserSubscription, BillingCycle, UsageStatus
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import timedelta
|
||||
|
||||
# Create default user subscription for testing
|
||||
default_user_id = "default_user"
|
||||
existing_subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == default_user_id
|
||||
).first()
|
||||
|
||||
if not existing_subscription:
|
||||
logger.info(f"🆕 Creating default subscription for {default_user_id}")
|
||||
logger.info(f"Creating default subscription for {default_user_id}")
|
||||
subscription = UserSubscription(
|
||||
user_id=default_user_id,
|
||||
plan_id=free_plan.id,
|
||||
@@ -272,33 +180,32 @@ def assign_default_plan_to_users():
|
||||
)
|
||||
db.add(subscription)
|
||||
db.commit()
|
||||
logger.info(f"✅ Default subscription created for {default_user_id}")
|
||||
logger.info(f"Default subscription created for {default_user_id}")
|
||||
else:
|
||||
logger.info(f"✅ Default subscription already exists for {default_user_id}")
|
||||
logger.info(f"Default subscription already exists for {default_user_id}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error assigning default plan: {e}")
|
||||
logger.error(f"Error assigning default plan: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 Initializing Alpha Subscription Tiers...")
|
||||
logger.info("Initializing Alpha Subscription Tiers...")
|
||||
|
||||
success = create_alpha_subscription_tiers()
|
||||
if success:
|
||||
logger.info("✅ Subscription tiers created successfully!")
|
||||
logger.info("Subscription tiers created successfully!")
|
||||
|
||||
# Assign default plan
|
||||
assign_success = assign_default_plan_to_users()
|
||||
if assign_success:
|
||||
logger.info("✅ Default plan assigned successfully!")
|
||||
logger.info("Default plan assigned successfully!")
|
||||
else:
|
||||
logger.error("❌ Failed to assign default plan")
|
||||
logger.error("Failed to assign default plan")
|
||||
else:
|
||||
logger.error("❌ Failed to create subscription tiers")
|
||||
logger.error("Failed to create subscription tiers")
|
||||
|
||||
logger.info("🎉 Alpha subscription system initialization complete!")
|
||||
logger.info("Alpha subscription system initialization complete!")
|
||||
@@ -67,10 +67,11 @@ import sys
|
||||
from pathlib import Path
|
||||
import google.genai as genai
|
||||
from google.genai import types
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
from services.api_key_manager import APIKeyManager
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("gemini_audio_text")
|
||||
|
||||
281
backend/services/podcast_context_builder.py
Normal file
281
backend/services/podcast_context_builder.py
Normal file
@@ -0,0 +1,281 @@
|
||||
"""
|
||||
Podcast Context Builder Service
|
||||
|
||||
Builds unified context for AI prompts from multiple sources:
|
||||
- Podcast Bible (user personalization)
|
||||
- Website Extraction (from Exa)
|
||||
- Topic Context (category research: News/Finance)
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional, List
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class PodcastContextBuilder:
|
||||
"""Builds unified context for AI prompt enhancements."""
|
||||
|
||||
def build_enhance_context(
|
||||
self,
|
||||
idea: str,
|
||||
bible_context: str = "",
|
||||
website_data: Optional[Dict[str, Any]] = None,
|
||||
topic_context: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build context for topic enhancement prompt.
|
||||
|
||||
Args:
|
||||
idea: Raw podcast idea/keywords
|
||||
bible_context: Serialized Podcast Bible string
|
||||
website_data: Website extraction data (title, summary, highlights, url, subpages)
|
||||
topic_context: Category research data (category, topics, selected_topic)
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- prompt: The formatted prompt
|
||||
- contexts_used: List of context types being used
|
||||
- context_description: Human-readable description for logging
|
||||
"""
|
||||
contexts_used = []
|
||||
context_parts = []
|
||||
|
||||
# Track what contexts are available
|
||||
if bible_context:
|
||||
contexts_used.append("Podcast Bible")
|
||||
|
||||
if website_data:
|
||||
contexts_used.append("Website Analysis")
|
||||
|
||||
if topic_context:
|
||||
category = topic_context.get("category", "unknown")
|
||||
contexts_used.append(f"Category Research ({category})")
|
||||
|
||||
# Build Bible section
|
||||
if bible_context:
|
||||
context_parts.append(f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}")
|
||||
|
||||
# Build Website section
|
||||
if website_data:
|
||||
website_section = self._format_website_section(website_data)
|
||||
context_parts.append(website_section)
|
||||
|
||||
# Build Topic/Category section
|
||||
if topic_context:
|
||||
topic_section = self._format_topic_section(topic_context)
|
||||
context_parts.append(topic_section)
|
||||
|
||||
# Select appropriate prompt template based on available context
|
||||
prompt = self._select_prompt(idea, context_parts, website_data, topic_context)
|
||||
|
||||
return {
|
||||
"prompt": prompt,
|
||||
"contexts_used": contexts_used,
|
||||
"context_description": ", ".join(contexts_used) if contexts_used else "basic idea only",
|
||||
}
|
||||
|
||||
def _format_website_section(self, website_data: Dict[str, Any]) -> str:
|
||||
"""Format website data for prompt inclusion."""
|
||||
parts = []
|
||||
|
||||
if website_data.get("url"):
|
||||
parts.append(f"Source URL: {website_data['url']}")
|
||||
|
||||
if website_data.get("title"):
|
||||
parts.append(f"Company/Organization: {website_data['title']}")
|
||||
|
||||
if website_data.get("summary"):
|
||||
parts.append(f"About: {website_data['summary']}")
|
||||
|
||||
if website_data.get("highlights"):
|
||||
highlights = website_data.get("highlights", [])
|
||||
if highlights:
|
||||
parts.append(f"Key Highlights: {', '.join(highlights[:3])}")
|
||||
|
||||
if website_data.get("subpages"):
|
||||
subpages = website_data.get("subpages", [])
|
||||
if subpages:
|
||||
subpage_titles = [sp.get("title", sp.get("url", "")) for sp in subpages[:3]]
|
||||
parts.append(f"Subpages: {', '.join(subpage_titles)}")
|
||||
|
||||
return "WEBSITE CONTENT ANALYSIS:\n" + "\n".join(parts)
|
||||
|
||||
def _format_topic_section(self, topic_context: Dict[str, Any]) -> str:
|
||||
"""Format category research data for prompt inclusion."""
|
||||
parts = []
|
||||
|
||||
category = topic_context.get("category", "")
|
||||
if category:
|
||||
parts.append(f"Research Category: {category.upper()}")
|
||||
|
||||
# Include selected topic details
|
||||
selected = topic_context.get("selected_topic", {})
|
||||
if selected:
|
||||
if selected.get("title"):
|
||||
parts.append(f"Selected Topic: {selected['title']}")
|
||||
if selected.get("snippet"):
|
||||
parts.append(f"Context: {selected['snippet']}")
|
||||
if selected.get("url"):
|
||||
parts.append(f"Source: {selected['url']}")
|
||||
|
||||
# Include some alternative topics for reference
|
||||
topics = topic_context.get("topics", [])
|
||||
if topics:
|
||||
alt_titles = [t.get("title", "") for t in topics[:3] if t.get("title")]
|
||||
if alt_titles:
|
||||
parts.append(f"Related Topics: {', '.join(alt_titles)}")
|
||||
|
||||
return "CATEGORY RESEARCH CONTEXT:\n" + "\n".join(parts)
|
||||
|
||||
def _select_prompt(
|
||||
self,
|
||||
idea: str,
|
||||
context_parts: List[str],
|
||||
website_data: Optional[Dict[str, Any]],
|
||||
topic_context: Optional[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""Select and format the appropriate prompt based on available context."""
|
||||
|
||||
context_str = "\n\n".join(context_parts)
|
||||
|
||||
# Full context prompt (all sources available)
|
||||
if website_data and topic_context:
|
||||
return f"""You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea, enriched with website content analysis AND category research.
|
||||
|
||||
{context_str}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions that INCORPORATE both the website content AND category research context:
|
||||
1. Professional & Expert-led angle (leverage website authority + research insights)
|
||||
2. Storytelling & Human interest angle (brand narratives + research findings)
|
||||
3. Trendy & Contemporary angle (current trends + research relevance)
|
||||
|
||||
Each version should:
|
||||
- Be 2-3 sentences
|
||||
- Reference specific elements from both website AND research when relevant
|
||||
- Be audience-focused and align with host persona if provided
|
||||
- NOT just repeat summaries - create fresh podcast angles
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings (each a complete episode pitch)
|
||||
- rationales: array of 3 strings explaining each approach
|
||||
|
||||
Example format:
|
||||
{{
|
||||
"enhanced_ideas": ["Pitch 1...", "Pitch 2...", "Pitch 3..."],
|
||||
"rationales": ["Reason 1", "Reason 2", "Reason 3"]
|
||||
}}
|
||||
"""
|
||||
|
||||
# Website-only context
|
||||
elif website_data:
|
||||
return f"""You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea, enriched with website content analysis.
|
||||
|
||||
{context_str}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions that INCORPORATE the website content:
|
||||
1. Professional & Expert-led angle (focus on authority, insights from website)
|
||||
2. Storytelling & Human interest angle (brand narratives, personal connections)
|
||||
3. Trendy & Contemporary angle (modern perspectives, current relevance)
|
||||
|
||||
Each version should:
|
||||
- Be 2-3 sentences
|
||||
- Reference specific elements from the website when relevant
|
||||
- Be audience-focused and align with host persona if provided
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings
|
||||
- rationales: array of 3 strings
|
||||
|
||||
Example format:
|
||||
{{
|
||||
"enhanced_ideas": ["Pitch 1...", "Pitch 2...", "Pitch 3..."],
|
||||
"rationales": ["Reason 1", "Reason 2", "Reason 3"]
|
||||
}}
|
||||
"""
|
||||
|
||||
# Category research only context
|
||||
elif topic_context:
|
||||
category = topic_context.get("category", "research").upper()
|
||||
return f"""You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea, enriched with {category} category research.
|
||||
|
||||
{context_str}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions that INCORPORATE the {category} research:
|
||||
1. Professional & Expert-led angle (leverage research insights and data)
|
||||
2. Storytelling & Human interest angle (real-world applications, human impact)
|
||||
3. Trendy & Contemporary angle (cutting-edge trends, future outlook)
|
||||
|
||||
Each version should:
|
||||
- Be 2-3 sentences
|
||||
- Reference specific elements from the research when relevant
|
||||
- Connect the research to the raw idea meaningfully
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings
|
||||
- rationales: array of 3 strings
|
||||
|
||||
Example format:
|
||||
{{
|
||||
"enhanced_ideas": ["Pitch 1...", "Pitch 2...", "Pitch 3..."],
|
||||
"rationales": ["Reason 1", "Reason 2", "Reason 3"]
|
||||
}}
|
||||
"""
|
||||
|
||||
# Standard context (no additional context)
|
||||
else:
|
||||
return f"""You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
|
||||
{context_str}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions with unique angles:
|
||||
1. Professional & Expert-led angle (focus on authority, insights)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions)
|
||||
3. Trendy & Contemporary angle (focus on trends, modern relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings
|
||||
- rationales: array of 3 strings
|
||||
|
||||
Example format:
|
||||
{{
|
||||
"enhanced_ideas": ["Pitch 1...", "Pitch 2...", "Pitch 3..."],
|
||||
"rationales": ["Reason 1", "Reason 2", "Reason 3"]
|
||||
}}
|
||||
"""
|
||||
|
||||
def format_context_for_logging(
|
||||
self,
|
||||
website_data: Optional[Dict] = None,
|
||||
topic_context: Optional[Dict] = None,
|
||||
) -> str:
|
||||
"""Format context description for logging."""
|
||||
contexts = []
|
||||
|
||||
if website_data:
|
||||
title = website_data.get("title", "Unknown")
|
||||
contexts.append(f"Website: {title[:30]}...")
|
||||
|
||||
if topic_context:
|
||||
category = topic_context.get("category", "unknown")
|
||||
selected = topic_context.get("selected_topic", {})
|
||||
topic_title = selected.get("title", "Not selected")
|
||||
contexts.append(f"Category: {category} ({topic_title[:20]}...)")
|
||||
|
||||
return " | ".join(contexts) if contexts else "No extended context"
|
||||
|
||||
|
||||
# Singleton instance for reuse
|
||||
context_builder = PodcastContextBuilder()
|
||||
@@ -4,147 +4,273 @@ Google Trends Service
|
||||
Provides Google Trends data integration for the Research Engine.
|
||||
Handles rate limiting, caching, error handling, and data serialization.
|
||||
|
||||
Key design decisions:
|
||||
- Monkey-patches urllib3 Retry to fix method_whitelist→allowed_methods (urllib3 2.x)
|
||||
- Monkey-patches pytrends related_topics/related_queries to catch IndexError bug
|
||||
- Uses TrendReq built-in retries (3 retries, 1s backoff) for automatic 429 handling
|
||||
- Random user-agent rotation per instance to reduce fingerprinting
|
||||
- 1-second delays between sequential requests to respect rate limits
|
||||
- 24-hour in-memory cache to avoid redundant API calls
|
||||
|
||||
Author: ALwrity Team
|
||||
Version: 1.0
|
||||
Version: 2.0
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from loguru import logger
|
||||
import pandas as pd
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Monkey-patches: fix compatibility issues before importing/using pytrends
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Patch 1: urllib3 2.x renamed Retry's `method_whitelist` to `allowed_methods`.
|
||||
# pytrends 4.9.2 still uses `method_whitelist`, which crashes with urllib3 2.x.
|
||||
# We patch Retry.__init__ to accept `method_whitelist` and remap it.
|
||||
try:
|
||||
from pytrends.request import TrendReq
|
||||
from urllib3.util.retry import Retry as _OrigRetry
|
||||
|
||||
_orig_retry_init = _OrigRetry.__init__
|
||||
|
||||
def _patched_retry_init(self, *args, **kwargs):
|
||||
if 'method_whitelist' in kwargs and 'allowed_methods' not in kwargs:
|
||||
kwargs['allowed_methods'] = kwargs.pop('method_whitelist')
|
||||
_orig_retry_init(self, *args, **kwargs)
|
||||
|
||||
_OrigRetry.__init__ = _patched_retry_init
|
||||
logger.debug("[Trends] Patched urllib3 Retry.__init__ for method_whitelist→allowed_methods")
|
||||
except Exception as _patch_err:
|
||||
logger.warning(f"[Trends] Could not patch urllib3 Retry: {_patch_err}")
|
||||
|
||||
# Now safe to import pytrends
|
||||
try:
|
||||
from pytrends.request import TrendReq as _TrendReq
|
||||
PYTrends_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYTrends_AVAILABLE = False
|
||||
logger.warning("pytrends not installed. Google Trends features will be unavailable.")
|
||||
|
||||
# Patch 2: pytrends related_topics() and related_queries() use keyword[0]
|
||||
# which raises IndexError on empty lists, but only catch KeyError.
|
||||
# We fix this by catching (KeyError, IndexError) for the keyword extraction.
|
||||
if PYTrends_AVAILABLE:
|
||||
import json as _json
|
||||
import pandas as _pd
|
||||
|
||||
def _fixed_related_topics(self):
|
||||
result_dict = {}
|
||||
related_payload = {}
|
||||
for request_json in self.related_topics_widget_list:
|
||||
try:
|
||||
kw = request_json['request']['restriction'][
|
||||
'complexKeywordsRestriction']['keyword'][0]['value']
|
||||
except (KeyError, IndexError):
|
||||
kw = ''
|
||||
related_payload['req'] = _json.dumps(request_json['request'])
|
||||
related_payload['token'] = request_json['token']
|
||||
related_payload['tz'] = self.tz
|
||||
req_json = self._get_data(
|
||||
url=_TrendReq.RELATED_QUERIES_URL,
|
||||
method=_TrendReq.GET_METHOD,
|
||||
trim_chars=5,
|
||||
params=related_payload,
|
||||
)
|
||||
try:
|
||||
top_list = req_json['default']['rankedList'][0]['rankedKeyword']
|
||||
df_top = _pd.json_normalize(top_list, sep='_')
|
||||
except (KeyError, IndexError):
|
||||
df_top = None
|
||||
try:
|
||||
rising_list = req_json['default']['rankedList'][1]['rankedKeyword']
|
||||
df_rising = _pd.json_normalize(rising_list, sep='_')
|
||||
except (KeyError, IndexError):
|
||||
df_rising = None
|
||||
result_dict[kw] = {'rising': df_rising, 'top': df_top}
|
||||
return result_dict
|
||||
|
||||
def _fixed_related_queries(self):
|
||||
result_dict = {}
|
||||
related_payload = {}
|
||||
for request_json in self.related_queries_widget_list:
|
||||
try:
|
||||
kw = request_json['request']['restriction'][
|
||||
'complexKeywordsRestriction']['keyword'][0]['value']
|
||||
except (KeyError, IndexError):
|
||||
kw = ''
|
||||
related_payload['req'] = _json.dumps(request_json['request'])
|
||||
related_payload['token'] = request_json['token']
|
||||
related_payload['tz'] = self.tz
|
||||
req_json = self._get_data(
|
||||
url=_TrendReq.RELATED_QUERIES_URL,
|
||||
method=_TrendReq.GET_METHOD,
|
||||
trim_chars=5,
|
||||
params=related_payload,
|
||||
)
|
||||
try:
|
||||
top_df = _pd.DataFrame(
|
||||
req_json['default']['rankedList'][0]['rankedKeyword'])
|
||||
top_df = top_df[['query', 'value']]
|
||||
except (KeyError, IndexError):
|
||||
top_df = None
|
||||
try:
|
||||
rising_df = _pd.DataFrame(
|
||||
req_json['default']['rankedList'][1]['rankedKeyword'])
|
||||
rising_df = rising_df[['query', 'value']]
|
||||
except (KeyError, IndexError):
|
||||
rising_df = None
|
||||
result_dict[kw] = {'top': top_df, 'rising': rising_df}
|
||||
return result_dict
|
||||
|
||||
_TrendReq.related_topics = _fixed_related_topics
|
||||
_TrendReq.related_queries = _fixed_related_queries
|
||||
logger.debug("[Trends] Patched TrendReq.related_topics/related_queries for IndexError")
|
||||
|
||||
from .rate_limiter import RateLimiter
|
||||
|
||||
|
||||
class GoogleTrendsService:
|
||||
"""
|
||||
Service for fetching and analyzing Google Trends data.
|
||||
|
||||
Features:
|
||||
- Interest over time
|
||||
- Interest by region
|
||||
- Related topics
|
||||
- Related queries
|
||||
- Rate limiting (1 req/sec)
|
||||
- Caching (24-hour TTL)
|
||||
- Async support
|
||||
- Error handling with retry logic
|
||||
|
||||
Uses TrendReq with no retries (fail-fast) to avoid hitting CAPTCHA on blocks.
|
||||
429 retry handling (1s, 2s, 4s backoff). Random user-agent is set
|
||||
per instance to reduce fingerprinting.
|
||||
"""
|
||||
|
||||
|
||||
USER_AGENTS = [
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:125.0) Gecko/20100101 Firefox/125.0",
|
||||
"Mozilla/5.0 (Macintosh; Intel Mac OS X 14_4) AppleWebKit/605.1.15 (KHTML, like Gecko) Version/17.3 Safari/605.1.15",
|
||||
"Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the Google Trends service."""
|
||||
if not PYTrends_AVAILABLE:
|
||||
raise RuntimeError("pytrends library is required. Install with: pip install pytrends")
|
||||
|
||||
self.rate_limiter = RateLimiter(max_calls=1, period=1.0) # 1 request per second
|
||||
self.cache: Dict[str, Dict[str, Any]] = {} # Simple in-memory cache
|
||||
self.cache_ttl = timedelta(hours=24) # 24-hour cache
|
||||
|
||||
logger.info("GoogleTrendsService initialized")
|
||||
|
||||
|
||||
self.rate_limiter = RateLimiter(max_calls=1, period=1.0)
|
||||
self.cache: Dict[str, Any] = {}
|
||||
self.cache_ttl = timedelta(hours=24)
|
||||
|
||||
logger.info("GoogleTrendsService initialized (pytrends 4.9.2, fail-fast, 2s delays)")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
async def analyze_trends(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str = "today 12-m",
|
||||
geo: str = "US",
|
||||
gprop: str = "",
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Comprehensive trends analysis.
|
||||
|
||||
Fetches all trends data in a single optimized call:
|
||||
- Interest over time
|
||||
- Interest by region
|
||||
- Related topics (top & rising)
|
||||
- Related queries (top & rising)
|
||||
|
||||
|
||||
Args:
|
||||
keywords: List of keywords to analyze (1-5 keywords recommended)
|
||||
timeframe: Timeframe string (e.g., "today 12-m", "today 1-y", "all")
|
||||
keywords: List of keywords to analyze (1-5)
|
||||
timeframe: Timeframe (e.g., "today 12-m", "today 3-m", "today 5-y")
|
||||
geo: Country code (e.g., "US", "GB", "IN")
|
||||
user_id: User ID for subscription checks (optional for now)
|
||||
|
||||
Returns:
|
||||
Dict containing all trends data in serializable format
|
||||
|
||||
Raises:
|
||||
ValueError: If keywords list is empty or too long
|
||||
RuntimeError: If pytrends is not available or API fails
|
||||
gprop: Google property filter - '' for web, 'youtube' for YouTube, 'news', 'images', 'froogle'
|
||||
user_id: Optional user ID for tracking
|
||||
|
||||
Fetches: interest over time, interest by region, related topics,
|
||||
and related queries using a single TrendReq session.
|
||||
"""
|
||||
if not keywords:
|
||||
raise ValueError("Keywords list cannot be empty")
|
||||
|
||||
|
||||
if len(keywords) > 5:
|
||||
logger.warning(f"Too many keywords ({len(keywords)}), using first 5")
|
||||
keywords = keywords[:5]
|
||||
|
||||
# Check cache first
|
||||
|
||||
cache_key = self._build_cache_key(keywords, timeframe, geo)
|
||||
cached_data = self._get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Returning cached trends data for: {keywords}")
|
||||
return {**cached_data, "cached": True}
|
||||
|
||||
# Rate limit
|
||||
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
|
||||
total_start = time.monotonic()
|
||||
|
||||
interest_over_time: List[Dict[str, Any]] = []
|
||||
interest_by_region: List[Dict[str, Any]] = []
|
||||
related_topics: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
related_queries: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
|
||||
try:
|
||||
logger.info(f"Fetching Google Trends data for: {keywords} (timeframe: {timeframe}, geo: {geo})")
|
||||
|
||||
# Initialize pytrends (sync operation, run in thread)
|
||||
logger.info(f"[Trends] ===== START analyze_trends ===== keywords={keywords} timeframe={timeframe} geo={geo}")
|
||||
|
||||
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
|
||||
init_start = time.monotonic()
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._initialize_pytrends,
|
||||
self._create_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo
|
||||
geo,
|
||||
gprop,
|
||||
)
|
||||
|
||||
# Fetch all data in parallel (pytrends methods are sync, so use to_thread)
|
||||
interest_over_time_task = asyncio.to_thread(
|
||||
lambda: self._safe_interest_over_time(pytrends)
|
||||
init_ms = int((time.monotonic() - init_start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
|
||||
|
||||
# --- Interest Over Time ---
|
||||
iot_start = time.monotonic()
|
||||
interest_over_time = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_over_time(pytrends)
|
||||
)
|
||||
interest_by_region_task = asyncio.to_thread(
|
||||
lambda: self._safe_interest_by_region(pytrends)
|
||||
iot_ms = int((time.monotonic() - iot_start) * 1000)
|
||||
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# --- Interest By Region ---
|
||||
ibr_start = time.monotonic()
|
||||
interest_by_region = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_by_region(pytrends)
|
||||
)
|
||||
related_topics_task = asyncio.to_thread(
|
||||
lambda: self._safe_related_topics(pytrends, keywords)
|
||||
ibr_ms = int((time.monotonic() - ibr_start) * 1000)
|
||||
logger.info(f"[Trends] interest_by_region took {ibr_ms}ms, returned {len(interest_by_region)} regions")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# --- Related Topics ---
|
||||
rt_start = time.monotonic()
|
||||
related_topics = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_topics(pytrends)
|
||||
)
|
||||
related_queries_task = asyncio.to_thread(
|
||||
lambda: self._safe_related_queries(pytrends, keywords)
|
||||
rt_ms = int((time.monotonic() - rt_start) * 1000)
|
||||
rt_top = len(related_topics.get("top", []))
|
||||
rt_rising = len(related_topics.get("rising", []))
|
||||
logger.info(f"[Trends] related_topics took {rt_ms}ms, top={rt_top} rising={rt_rising}")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# --- Related Queries ---
|
||||
rq_start = time.monotonic()
|
||||
related_queries = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_queries(pytrends)
|
||||
)
|
||||
|
||||
# Wait for all tasks
|
||||
interest_over_time, interest_by_region, related_topics, related_queries = await asyncio.gather(
|
||||
interest_over_time_task,
|
||||
interest_by_region_task,
|
||||
related_topics_task,
|
||||
related_queries_task,
|
||||
return_exceptions=True
|
||||
rq_ms = int((time.monotonic() - rq_start) * 1000)
|
||||
rq_top = len(related_queries.get("top", []))
|
||||
rq_rising = len(related_queries.get("rising", []))
|
||||
logger.info(f"[Trends] related_queries took {rq_ms}ms, top={rq_top} rising={rq_rising}")
|
||||
|
||||
total_ms = int((time.monotonic() - total_start) * 1000)
|
||||
logger.info(
|
||||
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
|
||||
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
|
||||
f"rt_top={rt_top} rq_top={rq_top}"
|
||||
)
|
||||
|
||||
# Handle exceptions
|
||||
if isinstance(interest_over_time, Exception):
|
||||
logger.error(f"Interest over time failed: {interest_over_time}")
|
||||
interest_over_time = []
|
||||
if isinstance(interest_by_region, Exception):
|
||||
logger.error(f"Interest by region failed: {interest_by_region}")
|
||||
interest_by_region = []
|
||||
if isinstance(related_topics, Exception):
|
||||
logger.error(f"Related topics failed: {related_topics}")
|
||||
related_topics = {"top": [], "rising": []}
|
||||
if isinstance(related_queries, Exception):
|
||||
logger.error(f"Related queries failed: {related_queries}")
|
||||
related_queries = {"top": [], "rising": []}
|
||||
|
||||
# Build result
|
||||
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
@@ -153,186 +279,268 @@ class GoogleTrendsService:
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False
|
||||
"cached": False,
|
||||
}
|
||||
|
||||
# Cache result
|
||||
|
||||
self._save_to_cache(cache_key, result)
|
||||
|
||||
logger.info(f"Google Trends data fetched successfully: {len(interest_over_time)} time points, {len(interest_by_region)} regions")
|
||||
|
||||
|
||||
logger.info(
|
||||
f"Google Trends data fetched successfully: "
|
||||
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google Trends analysis failed: {e}")
|
||||
# Return fallback response
|
||||
return self._create_fallback_response(keywords, timeframe, geo, str(e))
|
||||
|
||||
def _initialize_pytrends(
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# TrendReq factory
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _create_pytrends(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str,
|
||||
geo: str
|
||||
) -> TrendReq:
|
||||
"""Initialize pytrends and build payload (sync operation)."""
|
||||
pytrends = TrendReq(hl='en-US', tz=360)
|
||||
pytrends.build_payload(kw_list=keywords, timeframe=timeframe, geo=geo)
|
||||
geo: str,
|
||||
gprop: str = "",
|
||||
) -> _TrendReq:
|
||||
"""Create TrendReq with optional gprop (e.g., 'youtube' for video trends)."""
|
||||
start = time.monotonic()
|
||||
ua = random.choice(self.USER_AGENTS)
|
||||
logger.info(f"[Trends] Creating TrendReq (fail-fast, gprop='{gprop}', UA={ua[:40]}...)")
|
||||
pytrends = _TrendReq(
|
||||
hl='en-US',
|
||||
tz=360,
|
||||
timeout=(10, 30),
|
||||
retries=0,
|
||||
backoff_factor=0,
|
||||
requests_args={'headers': {'User-Agent': ua}},
|
||||
)
|
||||
# gprop: '' = web, 'youtube' = YouTube, 'news', 'images', 'froogle'
|
||||
pytrends.build_payload(kw_list=keywords, timeframe=timeframe, geo=geo, gprop=gprop)
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload completed in {elapsed}ms (gprop={gprop})")
|
||||
return pytrends
|
||||
|
||||
def _safe_interest_over_time(self, pytrends: TrendReq) -> List[Dict[str, Any]]:
|
||||
"""Safely fetch interest over time data."""
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Data fetchers — each catches all exceptions and returns defaults
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _fetch_interest_over_time(self, pytrends: _TrendReq, keywords: List[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Fetch interest over time data."""
|
||||
start = time.monotonic()
|
||||
try:
|
||||
df = pytrends.interest_over_time()
|
||||
if df.empty:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
if df is None or (hasattr(df, 'empty') and df.empty):
|
||||
logger.info(f"[Trends] interest_over_time returned empty in {elapsed}ms")
|
||||
return []
|
||||
return self._format_dataframe(df.reset_index())
|
||||
# Use pytrends.kw_list if keywords not provided
|
||||
kw = keywords or pytrends.kw_list
|
||||
result = self._format_dataframe(df.reset_index(), kw)
|
||||
logger.info(f"[Trends] interest_over_time returned {len(result)} points in {elapsed}ms")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching interest over time: {e}")
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
logger.error(f"[Trends] interest_over_time failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
def _safe_interest_by_region(self, pytrends: TrendReq) -> List[Dict[str, Any]]:
|
||||
"""Safely fetch interest by region data."""
|
||||
|
||||
def _fetch_interest_by_region(self, pytrends: _TrendReq, keywords: List[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Fetch interest by region data."""
|
||||
start = time.monotonic()
|
||||
try:
|
||||
df = pytrends.interest_by_region(resolution='COUNTRY', inc_low_vol=True, inc_geo_code=False)
|
||||
if df.empty:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
if df is None or (hasattr(df, 'empty') and df.empty):
|
||||
logger.info(f"[Trends] interest_by_region returned empty in {elapsed}ms")
|
||||
return []
|
||||
return self._format_dataframe(df.reset_index())
|
||||
result = self._format_dataframe(df.reset_index(), keywords or pytrends.kw_list)
|
||||
logger.info(f"[Trends] interest_by_region returned {len(result)} regions in {elapsed}ms")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching interest by region: {e}")
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
logger.error(f"[Trends] interest_by_region failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
def _safe_related_topics(
|
||||
self,
|
||||
pytrends: TrendReq,
|
||||
keywords: List[str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Safely fetch related topics."""
|
||||
|
||||
def _fetch_related_topics(self, pytrends: _TrendReq) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Fetch related topics. Patches catch IndexError from pytrends bug."""
|
||||
start = time.monotonic()
|
||||
result = {"top": [], "rising": []}
|
||||
try:
|
||||
topics_data = pytrends.related_topics()
|
||||
result = {"top": [], "rising": []}
|
||||
|
||||
for keyword in keywords:
|
||||
if keyword in topics_data and isinstance(topics_data[keyword], dict):
|
||||
keyword_topics = topics_data[keyword]
|
||||
|
||||
if "top" in keyword_topics and not keyword_topics["top"].empty:
|
||||
top_df = keyword_topics["top"]
|
||||
# Select relevant columns
|
||||
if "topic_title" in top_df.columns and "value" in top_df.columns:
|
||||
top_data = top_df[["topic_title", "value"]].to_dict('records')
|
||||
result["top"].extend(top_data)
|
||||
|
||||
if "rising" in keyword_topics and not keyword_topics["rising"].empty:
|
||||
rising_df = keyword_topics["rising"]
|
||||
if "topic_title" in rising_df.columns and "value" in rising_df.columns:
|
||||
rising_data = rising_df[["topic_title", "value"]].to_dict('records')
|
||||
result["rising"].extend(rising_data)
|
||||
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
|
||||
if topics_data is None:
|
||||
logger.info(f"[Trends] related_topics returned None in {elapsed}ms")
|
||||
return result
|
||||
|
||||
if not isinstance(topics_data, dict):
|
||||
logger.info(f"[Trends] related_topics returned {type(topics_data).__name__}, expected dict")
|
||||
return result
|
||||
|
||||
for key, keyword_data in topics_data.items():
|
||||
if keyword_data is None or not isinstance(keyword_data, dict):
|
||||
continue
|
||||
|
||||
for section in ["top", "rising"]:
|
||||
section_df = keyword_data.get(section)
|
||||
if section_df is None:
|
||||
continue
|
||||
if hasattr(section_df, 'empty') and section_df.empty:
|
||||
continue
|
||||
if not hasattr(section_df, 'to_dict'):
|
||||
continue
|
||||
|
||||
try:
|
||||
if "topic_title" in section_df.columns and "value" in section_df.columns:
|
||||
data = section_df[["topic_title", "value"]].to_dict('records')
|
||||
else:
|
||||
data = section_df.to_dict('records')
|
||||
result[section].extend(data)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error parsing {section} topics for key '{key}': {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"[Trends] related_topics completed in {elapsed}ms, top={len(result['top'])} rising={len(result['rising'])}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching related topics: {e}")
|
||||
return {"top": [], "rising": []}
|
||||
|
||||
def _safe_related_queries(
|
||||
self,
|
||||
pytrends: TrendReq,
|
||||
keywords: List[str]
|
||||
) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Safely fetch related queries."""
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
logger.error(f"[Trends] related_topics failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
def _fetch_related_queries(self, pytrends: _TrendReq) -> Dict[str, List[Dict[str, Any]]]:
|
||||
"""Fetch related queries. Patches catch IndexError from pytrends bug."""
|
||||
start = time.monotonic()
|
||||
result = {"top": [], "rising": []}
|
||||
try:
|
||||
queries_data = pytrends.related_queries()
|
||||
result = {"top": [], "rising": []}
|
||||
|
||||
for keyword in keywords:
|
||||
if keyword in queries_data and isinstance(queries_data[keyword], dict):
|
||||
keyword_queries = queries_data[keyword]
|
||||
|
||||
if "top" in keyword_queries and not keyword_queries["top"].empty:
|
||||
top_df = keyword_queries["top"]
|
||||
result["top"].extend(top_df.to_dict('records'))
|
||||
|
||||
if "rising" in keyword_queries and not keyword_queries["rising"].empty:
|
||||
rising_df = keyword_queries["rising"]
|
||||
result["rising"].extend(rising_df.to_dict('records'))
|
||||
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
|
||||
if queries_data is None:
|
||||
logger.info(f"[Trends] related_queries returned None in {elapsed}ms")
|
||||
return result
|
||||
|
||||
if not isinstance(queries_data, dict):
|
||||
logger.info(f"[Trends] related_queries returned {type(queries_data).__name__}, expected dict")
|
||||
return result
|
||||
|
||||
for key, keyword_data in queries_data.items():
|
||||
if keyword_data is None or not isinstance(keyword_data, dict):
|
||||
continue
|
||||
|
||||
for section in ["top", "rising"]:
|
||||
section_df = keyword_data.get(section)
|
||||
if section_df is None:
|
||||
continue
|
||||
if hasattr(section_df, 'empty') and section_df.empty:
|
||||
continue
|
||||
if not hasattr(section_df, 'to_dict'):
|
||||
continue
|
||||
|
||||
try:
|
||||
data = section_df.to_dict('records')
|
||||
result[section].extend(data)
|
||||
except Exception as e:
|
||||
logger.debug(f"Error parsing {section} queries for key '{key}': {e}")
|
||||
continue
|
||||
|
||||
logger.info(f"[Trends] related_queries completed in {elapsed}ms, top={len(result['top'])} rising={len(result['rising'])}")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching related queries: {e}")
|
||||
return {"top": [], "rising": []}
|
||||
|
||||
def _format_dataframe(self, df: pd.DataFrame) -> List[Dict[str, Any]]:
|
||||
"""Convert DataFrame to list of dicts (serializable format)."""
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
logger.error(f"[Trends] related_queries failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Helpers
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
def _format_dataframe(self, df: pd.DataFrame, keywords: List[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Convert DataFrame to list of dicts. Handles both pytrends and SerpAPI formats."""
|
||||
if df.empty:
|
||||
return []
|
||||
|
||||
# Convert datetime columns to strings
|
||||
for col in df.columns:
|
||||
if pd.api.types.is_datetime64_any_dtype(df[col]):
|
||||
df[col] = df[col].astype(str)
|
||||
# Try to detect and handle SerpAPI-style nested data
|
||||
# Check if the dataframe has 'date' column and 'values' array column
|
||||
records = df.to_dict('records')
|
||||
|
||||
# Convert to dict records
|
||||
return df.to_dict('records')
|
||||
|
||||
# Check first record for nested values pattern (SerpAPI format)
|
||||
if records and 'values' in records[0] and isinstance(records[0]['values'], list):
|
||||
# SerpAPI-style: need to flatten
|
||||
flat_records = []
|
||||
for record in records:
|
||||
date_str = record.get('date', '')
|
||||
timestamp = record.get('timestamp', '')
|
||||
is_partial = record.get('partial_data', False)
|
||||
|
||||
# Extract values from nested array
|
||||
for val_entry in record['values']:
|
||||
keyword_name = val_entry.get('query', '')
|
||||
value = val_entry.get('value', val_entry.get('extracted_value', 0))
|
||||
flat_record = {
|
||||
'date': date_str,
|
||||
'timestamp': timestamp,
|
||||
keyword_name: int(value) if value else 0,
|
||||
}
|
||||
if is_partial:
|
||||
flat_record['isPartial'] = True
|
||||
flat_records.append(flat_record)
|
||||
records = flat_records
|
||||
|
||||
# Convert datetime columns to strings
|
||||
for record in records:
|
||||
for key, value in record.items():
|
||||
if hasattr(value, 'year'): # datetime-like
|
||||
record[key] = str(value)
|
||||
|
||||
return records
|
||||
|
||||
def _build_cache_key(self, keywords: List[str], timeframe: str, geo: str) -> str:
|
||||
"""Build cache key from parameters."""
|
||||
keywords_str = ":".join(sorted(keywords))
|
||||
return f"google_trends:{keywords_str}:{timeframe}:{geo}"
|
||||
|
||||
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
"""Get data from cache if not expired."""
|
||||
if cache_key not in self.cache:
|
||||
return None
|
||||
|
||||
cached_entry = self.cache[cache_key]
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
# Expired, remove from cache
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
|
||||
# Return cached data (without cached flag)
|
||||
result = {**cached_entry}
|
||||
result.pop("cached", None)
|
||||
return result
|
||||
|
||||
|
||||
def _save_to_cache(self, cache_key: str, data: Dict[str, Any]):
|
||||
"""Save data to cache."""
|
||||
# Store with timestamp
|
||||
cache_entry = {
|
||||
**data,
|
||||
"cached_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
cache_entry = {**data, "cached_at": datetime.utcnow().isoformat()}
|
||||
self.cache[cache_key] = cache_entry
|
||||
|
||||
# Clean up old cache entries periodically
|
||||
if len(self.cache) > 100: # Limit cache size
|
||||
if len(self.cache) > 100:
|
||||
self._cleanup_cache()
|
||||
|
||||
|
||||
def _cleanup_cache(self):
|
||||
"""Remove expired cache entries."""
|
||||
now = datetime.utcnow()
|
||||
expired_keys = []
|
||||
|
||||
for key, entry in self.cache.items():
|
||||
cached_time = datetime.fromisoformat(entry.get("cached_at", entry.get("timestamp", "")))
|
||||
if now - cached_time > self.cache_ttl:
|
||||
expired_keys.append(key)
|
||||
|
||||
for key in expired_keys:
|
||||
del self.cache[key]
|
||||
|
||||
logger.debug(f"Cleaned up {len(expired_keys)} expired cache entries")
|
||||
|
||||
|
||||
def _create_fallback_response(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str,
|
||||
geo: str,
|
||||
error_message: str
|
||||
gprop: str = "",
|
||||
error_message: str = "",
|
||||
) -> Dict[str, Any]:
|
||||
"""Create fallback response when trends analysis fails."""
|
||||
source = "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop
|
||||
return {
|
||||
"interest_over_time": [],
|
||||
"interest_by_region": [],
|
||||
@@ -341,40 +549,38 @@ class GoogleTrendsService:
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": source,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
"error": error_message
|
||||
"error": error_message,
|
||||
}
|
||||
|
||||
|
||||
async def get_trending_searches(
|
||||
self,
|
||||
country: str = "united_states",
|
||||
user_id: Optional[str] = None
|
||||
user_id: Optional[str] = None,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Get current trending searches for a country.
|
||||
|
||||
Args:
|
||||
country: Country name (e.g., "united_states", "united_kingdom")
|
||||
user_id: User ID for subscription checks
|
||||
|
||||
Returns:
|
||||
List of trending search terms
|
||||
"""
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
|
||||
try:
|
||||
pytrends = TrendReq(hl='en-US', tz=360)
|
||||
ua = random.choice(self.USER_AGENTS)
|
||||
pytrends = _TrendReq(
|
||||
hl='en-US',
|
||||
tz=360,
|
||||
timeout=(10, 30),
|
||||
retries=0,
|
||||
backoff_factor=0,
|
||||
requests_args={'headers': {'User-Agent': ua}},
|
||||
)
|
||||
trending_df = await asyncio.to_thread(
|
||||
lambda: pytrends.trending_searches(pn=country)
|
||||
)
|
||||
|
||||
if trending_df.empty:
|
||||
|
||||
if trending_df is None or (hasattr(trending_df, 'empty') and trending_df.empty):
|
||||
return []
|
||||
|
||||
# Return as list of strings
|
||||
|
||||
return trending_df[0].tolist() if len(trending_df.columns) > 0 else []
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching trending searches: {e}")
|
||||
return []
|
||||
return []
|
||||
@@ -494,7 +494,16 @@ class PricingService:
|
||||
logger.debug(f"Added new pricing for {pricing_data['provider'].value}:{pricing_data['model_name']}")
|
||||
|
||||
self.db.commit()
|
||||
logger.info("Default API pricing initialized/updated. HuggingFace pricing loaded from env vars if available.")
|
||||
|
||||
# Debug: count pricing rows seeded
|
||||
total_rows = self.db.query(APIProviderPricing).count()
|
||||
providers = self.db.query(APIProviderPricing.provider).distinct().all()
|
||||
provider_list = sorted([p[0].value for p in providers]) if providers else []
|
||||
logger.info(f"[PRICING_INIT] Default API pricing initialized: {len(all_pricing)} rows configured, {total_rows} rows in DB, providers: {provider_list}")
|
||||
|
||||
# Warning-level log that will be visible
|
||||
logger.warning(f"[PRICING_INIT] Pricing ready: {total_rows} rows for {len(provider_list)} providers")
|
||||
logger.warning("Default API pricing initialized/updated. HuggingFace pricing loaded from env vars if available.")
|
||||
|
||||
def initialize_default_plans(self):
|
||||
"""Initialize default subscription plans."""
|
||||
|
||||
@@ -4,6 +4,7 @@ Handles fetching user data from the onboarding database.
|
||||
"""
|
||||
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
@@ -92,5 +93,88 @@ class UserDataService:
|
||||
return integrated_data.get('website_analysis')
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user website analysis: {str(e)}")
|
||||
logger.error(f"Error getting user website analysis: {e}")
|
||||
return None
|
||||
|
||||
def save_website_extraction(self, user_id: str, extraction_data: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Save website extraction data for future use.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
extraction_data: Website extraction data (title, summary, highlights, url, subpages)
|
||||
|
||||
Returns:
|
||||
True if saved successfully
|
||||
"""
|
||||
try:
|
||||
# Clean data - remove images/favicon
|
||||
clean_data = {
|
||||
k: v for k, v in extraction_data.items()
|
||||
if k not in ('image', 'favicon')
|
||||
}
|
||||
clean_data['saved_at'] = datetime.now().isoformat()
|
||||
|
||||
# Find or create user session for storing
|
||||
onboarding = self.db.query(OnboardingSession).filter(
|
||||
OnboardingSession.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not onboarding:
|
||||
# Create new session if not exists
|
||||
onboarding = OnboardingSession(user_id=user_id)
|
||||
self.db.add(onboarding)
|
||||
|
||||
# Try to update website_analysis field
|
||||
# The field might be JSON in the model
|
||||
try:
|
||||
existing = onboarding.website_analysis
|
||||
if isinstance(existing, dict):
|
||||
existing.update(clean_data)
|
||||
onboarding.website_analysis = existing
|
||||
else:
|
||||
onboarding.website_analysis = clean_data
|
||||
except Exception as ex:
|
||||
logger.warning(f"Could not update website_analysis: {ex}")
|
||||
onboarding.website_analysis = clean_data
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Saved website extraction for user {user_id}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving website extraction: {str(e)}")
|
||||
self.db.rollback()
|
||||
return False
|
||||
|
||||
def get_website_extraction(self, user_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Get saved website extraction data.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
|
||||
Returns:
|
||||
Website extraction data or None
|
||||
"""
|
||||
try:
|
||||
onboarding = self.db.query(OnboardingSession).filter(
|
||||
OnboardingSession.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not onboarding:
|
||||
return None
|
||||
|
||||
extraction = onboarding.website_analysis
|
||||
if isinstance(extraction, dict):
|
||||
# Return clean data without internal fields
|
||||
return {
|
||||
k: v for k, v in extraction.items()
|
||||
if k not in ('saved_at', 'full_analysis', 'analysis_status')
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting website extraction: {str(e)}")
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user