Base code
31  backend/services/blog_writer/research/__init__.py  Normal file
@@ -0,0 +1,31 @@
"""
Research module for AI Blog Writer.

This module handles all research-related functionality including:
- Google Search grounding integration
- Keyword analysis and competitor research
- Content angle discovery
- Research caching and optimization
"""

from .research_service import ResearchService
from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator
from .data_filter import ResearchDataFilter
from .base_provider import ResearchProvider as BaseResearchProvider
from .google_provider import GoogleResearchProvider
from .exa_provider import ExaResearchProvider
from .tavily_provider import TavilyResearchProvider

__all__ = [
    'ResearchService',
    'KeywordAnalyzer',
    'CompetitorAnalyzer',
    'ContentAngleGenerator',
    'ResearchDataFilter',
    'BaseResearchProvider',
    'GoogleResearchProvider',
    'ExaResearchProvider',
    'TavilyResearchProvider',
]
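A quick usage sketch of the package exports as the API layer would presumably wire them; the request fields and user id below are illustrative assumptions, not values from this commit:

# Illustrative only: end-to-end call into the research package.
# BlogResearchRequest field values and the user id are hypothetical.
import asyncio
from models.blog_models import BlogResearchRequest
from services.blog_writer.research import ResearchService

async def run_research_demo():
    service = ResearchService()
    request = BlogResearchRequest(keywords=["ai blog writer"], industry="SaaS")
    response = await service.research(request, user_id="demo_user")
    print(response.success, len(response.sources))

# asyncio.run(run_research_demo())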
37  backend/services/blog_writer/research/base_provider.py  Normal file
@@ -0,0 +1,37 @@
"""
Base Research Provider Interface

Abstract base class for research provider implementations.
Ensures consistency across different research providers (Google, Exa, etc.)
"""

from abc import ABC, abstractmethod
from typing import Dict, Any


class ResearchProvider(ABC):
    """Abstract base class for research providers."""

    @abstractmethod
    async def search(
        self,
        prompt: str,
        topic: str,
        industry: str,
        target_audience: str,
        config: Any,  # ResearchConfig
        user_id: str
    ) -> Dict[str, Any]:
        """Execute research and return raw results."""
        pass

    @abstractmethod
    def get_provider_enum(self):
        """Return APIProvider enum for subscription tracking."""
        pass

    @abstractmethod
    def estimate_tokens(self) -> int:
        """Estimate token usage for pre-flight validation."""
        pass
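A minimal sketch of what implementing this interface looks like; the DummyResearchProvider below is hypothetical and only illustrates the required methods and the standardized return shape:

# Illustrative sketch only: shows the contract defined by ResearchProvider.
# "DummyResearchProvider" and its canned payload are assumptions for demonstration.
from typing import Dict, Any
from services.blog_writer.research import BaseResearchProvider

class DummyResearchProvider(BaseResearchProvider):
    """Toy provider that returns a fixed payload in the standardized shape."""

    async def search(self, prompt: str, topic: str, industry: str,
                     target_audience: str, config: Any, user_id: str) -> Dict[str, Any]:
        # A real provider would call an external search API here.
        return {
            'sources': [],
            'content': f"No external search performed for: {topic}",
            'search_queries': [f"{topic} {industry}"],
            'provider': 'dummy',
        }

    def get_provider_enum(self):
        return None  # A real provider returns an APIProvider enum member.

    def estimate_tokens(self) -> int:
        return 0  # Used by pre-flight subscription validation.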
72  backend/services/blog_writer/research/competitor_analyzer.py  Normal file
@@ -0,0 +1,72 @@
"""
Competitor Analyzer - AI-powered competitor analysis for research content.

Extracts competitor insights and market intelligence from research content.
"""

from typing import Dict, Any
from loguru import logger


class CompetitorAnalyzer:
    """Analyzes competitors and market intelligence from research content."""

    def analyze(self, content: str, user_id: str = None) -> Dict[str, Any]:
        """Parse comprehensive competitor analysis from the research content using AI."""
        competitor_prompt = f"""
        Analyze the following research content and extract competitor insights:

        Research Content:
        {content[:3000]}

        Extract and analyze:
        1. Top competitors mentioned (companies, brands, platforms)
        2. Content gaps (what competitors are missing)
        3. Market opportunities (untapped areas)
        4. Competitive advantages (what makes content unique)
        5. Market positioning insights
        6. Industry leaders and their strategies

        Respond with JSON:
        {{
            "top_competitors": ["competitor1", "competitor2"],
            "content_gaps": ["gap1", "gap2"],
            "opportunities": ["opportunity1", "opportunity2"],
            "competitive_advantages": ["advantage1", "advantage2"],
            "market_positioning": "positioning insights",
            "industry_leaders": ["leader1", "leader2"],
            "analysis_notes": "Comprehensive competitor analysis summary"
        }}
        """

        from services.llm_providers.main_text_generation import llm_text_gen

        competitor_schema = {
            "type": "object",
            "properties": {
                "top_competitors": {"type": "array", "items": {"type": "string"}},
                "content_gaps": {"type": "array", "items": {"type": "string"}},
                "opportunities": {"type": "array", "items": {"type": "string"}},
                "competitive_advantages": {"type": "array", "items": {"type": "string"}},
                "market_positioning": {"type": "string"},
                "industry_leaders": {"type": "array", "items": {"type": "string"}},
                "analysis_notes": {"type": "string"}
            },
            "required": ["top_competitors", "content_gaps", "opportunities", "competitive_advantages", "market_positioning", "industry_leaders", "analysis_notes"]
        }

        competitor_analysis = llm_text_gen(
            prompt=competitor_prompt,
            json_struct=competitor_schema,
            user_id=user_id
        )

        if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
            logger.info("✅ AI competitor analysis completed successfully")
            return competitor_analysis
        else:
            # Fail gracefully - no fallback data
            error_msg = competitor_analysis.get('error', 'Unknown error') if isinstance(competitor_analysis, dict) else str(competitor_analysis)
            logger.error(f"AI competitor analysis failed: {error_msg}")
            raise ValueError(f"Competitor analysis failed: {error_msg}")
80  backend/services/blog_writer/research/content_angle_generator.py  Normal file
@@ -0,0 +1,80 @@
"""
Content Angle Generator - AI-powered content angle discovery.

Generates strategic content angles from research content for blog posts.
"""

from typing import List
from loguru import logger


class ContentAngleGenerator:
    """Generates strategic content angles from research content."""

    def generate(self, content: str, topic: str, industry: str, user_id: str = None) -> List[str]:
        """Parse strategic content angles from the research content using AI."""
        angles_prompt = f"""
        Analyze the following research content and create strategic content angles for: {topic} in {industry}

        Research Content:
        {content[:3000]}

        Create 7 compelling content angles that:
        1. Leverage current trends and data from the research
        2. Address content gaps and opportunities
        3. Appeal to different audience segments
        4. Include unique perspectives not covered by competitors
        5. Incorporate specific statistics, case studies, or expert insights
        6. Create emotional connection and urgency
        7. Provide actionable value to readers

        Each angle should be:
        - Specific and data-driven
        - Unique and differentiated
        - Compelling and click-worthy
        - Actionable for readers

        Respond with JSON:
        {{
            "content_angles": [
                "Specific angle 1 with data/trends",
                "Specific angle 2 with unique perspective",
                "Specific angle 3 with actionable insights",
                "Specific angle 4 with case study focus",
                "Specific angle 5 with future outlook",
                "Specific angle 6 with problem-solving focus",
                "Specific angle 7 with industry insights"
            ]
        }}
        """

        from services.llm_providers.main_text_generation import llm_text_gen

        angles_schema = {
            "type": "object",
            "properties": {
                "content_angles": {
                    "type": "array",
                    "items": {"type": "string"},
                    "minItems": 5,
                    "maxItems": 7
                }
            },
            "required": ["content_angles"]
        }

        angles_result = llm_text_gen(
            prompt=angles_prompt,
            json_struct=angles_schema,
            user_id=user_id
        )

        if isinstance(angles_result, dict) and 'content_angles' in angles_result:
            logger.info("✅ AI content angles generation completed successfully")
            return angles_result['content_angles'][:7]
        else:
            # Fail gracefully - no fallback data
            error_msg = angles_result.get('error', 'Unknown error') if isinstance(angles_result, dict) else str(angles_result)
            logger.error(f"AI content angles generation failed: {error_msg}")
            raise ValueError(f"Content angles generation failed: {error_msg}")
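The analyzers above all follow the same pattern: truncate the grounded research content, call llm_text_gen with a JSON schema, and raise instead of returning fallback data. A hedged usage sketch; the sample content string and user id are placeholders:

# Illustrative usage sketch only; "research_content" and the user id are placeholders.
from services.blog_writer.research import (
    KeywordAnalyzer, CompetitorAnalyzer, ContentAngleGenerator,
)

research_content = "grounded research text returned by the provider (placeholder)"

try:
    keywords = KeywordAnalyzer().analyze(research_content, ["ai blogging"], user_id="demo_user")
    competitors = CompetitorAnalyzer().analyze(research_content, user_id="demo_user")
    angles = ContentAngleGenerator().generate(research_content, "AI blogging", "SaaS", user_id="demo_user")
except ValueError as exc:
    # Each analyzer raises instead of returning fallback data.
    print(f"Analysis failed: {exc}")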
519  backend/services/blog_writer/research/data_filter.py  Normal file
@@ -0,0 +1,519 @@
"""
Research Data Filter - Filters and cleans research data for optimal AI processing.

This module provides intelligent filtering and cleaning of research data to:
1. Remove low-quality sources and irrelevant content
2. Optimize data for AI processing (reduce tokens, improve quality)
3. Ensure only high-value insights are sent to AI prompts
4. Maintain data integrity while improving processing efficiency
"""

from typing import Dict, Any, List, Optional, Tuple
from datetime import datetime, timedelta
import re
from loguru import logger

from models.blog_models import (
    BlogResearchResponse,
    ResearchSource,
    GroundingMetadata,
    GroundingChunk,
    GroundingSupport,
    Citation,
)


class ResearchDataFilter:
    """Filters and cleans research data for optimal AI processing."""

    def __init__(self):
        """Initialize the research data filter with default settings."""
        # Be conservative but avoid over-filtering which can lead to empty UI
        self.min_credibility_score = 0.5
        self.min_excerpt_length = 20
        self.max_sources = 15
        self.max_grounding_chunks = 20
        self.max_content_gaps = 5
        self.max_keywords_per_category = 10
        self.min_grounding_confidence = 0.5
        self.max_source_age_days = 365 * 5  # allow up to 5 years if relevant

        # Common stop words for keyword cleaning
        self.stop_words = {
            'the', 'a', 'an', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
            'is', 'are', 'was', 'were', 'be', 'been', 'being', 'have', 'has', 'had', 'do', 'does', 'did',
            'will', 'would', 'could', 'should', 'may', 'might', 'must', 'can', 'this', 'that', 'these', 'those'
        }

        # Irrelevant source patterns
        self.irrelevant_patterns = [
            r'\.(pdf|doc|docx|xls|xlsx|ppt|pptx)$',  # Document files
            r'\.(jpg|jpeg|png|gif|svg|webp)$',  # Image files
            r'\.(mp4|avi|mov|wmv|flv|webm)$',  # Video files
            r'\.(mp3|wav|flac|aac)$',  # Audio files
            r'\.(zip|rar|7z|tar|gz)$',  # Archive files
            r'^https?://(www\.)?(facebook|twitter|instagram|linkedin|youtube)\.com',  # Social media
            r'^https?://(www\.)?(amazon|ebay|etsy)\.com',  # E-commerce
            r'^https?://(www\.)?(wikipedia)\.org',  # Wikipedia (too generic)
        ]

        logger.info("✅ ResearchDataFilter initialized with quality thresholds")

    def filter_research_data(self, research_data: BlogResearchResponse) -> BlogResearchResponse:
        """
        Main filtering method that processes all research data components.

        Args:
            research_data: Raw research data from the research service

        Returns:
            Filtered and cleaned research data optimized for AI processing
        """
        logger.info(f"Starting research data filtering for {len(research_data.sources)} sources")

        # Track original counts for logging
        original_counts = {
            'sources': len(research_data.sources),
            'grounding_chunks': len(research_data.grounding_metadata.grounding_chunks) if research_data.grounding_metadata else 0,
            'grounding_supports': len(research_data.grounding_metadata.grounding_supports) if research_data.grounding_metadata else 0,
            'citations': len(research_data.grounding_metadata.citations) if research_data.grounding_metadata else 0,
        }

        # Filter sources
        filtered_sources = self.filter_sources(research_data.sources)

        # Filter grounding metadata
        filtered_grounding_metadata = self.filter_grounding_metadata(research_data.grounding_metadata)

        # Clean keyword analysis
        cleaned_keyword_analysis = self.clean_keyword_analysis(research_data.keyword_analysis)

        # Clean competitor analysis
        cleaned_competitor_analysis = self.clean_competitor_analysis(research_data.competitor_analysis)

        # Filter content gaps
        filtered_content_gaps = self.filter_content_gaps(
            research_data.keyword_analysis.get('content_gaps', []),
            research_data
        )

        # Update keyword analysis with filtered content gaps
        cleaned_keyword_analysis['content_gaps'] = filtered_content_gaps

        # Create filtered research response
        filtered_research = BlogResearchResponse(
            success=research_data.success,
            sources=filtered_sources,
            keyword_analysis=cleaned_keyword_analysis,
            competitor_analysis=cleaned_competitor_analysis,
            suggested_angles=research_data.suggested_angles,  # Keep as-is for now
            search_widget=research_data.search_widget,
            search_queries=research_data.search_queries,
            grounding_metadata=filtered_grounding_metadata,
            error_message=research_data.error_message
        )

        # Log filtering results
        self._log_filtering_results(original_counts, filtered_research)

        return filtered_research

    def filter_sources(self, sources: List[ResearchSource]) -> List[ResearchSource]:
        """
        Filter sources based on quality, relevance, and recency criteria.

        Args:
            sources: List of research sources to filter

        Returns:
            Filtered list of high-quality sources
        """
        if not sources:
            return []

        filtered_sources = []

        for source in sources:
            # Quality filters
            if not self._is_source_high_quality(source):
                continue

            # Relevance filters
            if not self._is_source_relevant(source):
                continue

            # Recency filters
            if not self._is_source_recent(source):
                continue

            filtered_sources.append(source)

        # Sort by credibility score and limit to max_sources
        filtered_sources.sort(key=lambda s: s.credibility_score or 0.8, reverse=True)
        filtered_sources = filtered_sources[:self.max_sources]

        # Fail-open: if everything was filtered out, return a trimmed set of original sources
        if not filtered_sources and sources:
            logger.warning("All sources filtered out by thresholds. Falling back to top sources without strict filters.")
            fallback = sorted(
                sources,
                key=lambda s: (s.credibility_score or 0.8),
                reverse=True
            )[:self.max_sources]
            return fallback

        logger.info(f"Filtered sources: {len(sources)} → {len(filtered_sources)}")
        return filtered_sources

    def filter_grounding_metadata(self, grounding_metadata: Optional[GroundingMetadata]) -> Optional[GroundingMetadata]:
        """
        Filter grounding metadata to keep only high-confidence, relevant data.

        Args:
            grounding_metadata: Raw grounding metadata to filter

        Returns:
            Filtered grounding metadata with high-quality data only
        """
        if not grounding_metadata:
            return None

        # Filter grounding chunks by confidence
        filtered_chunks = []
        for chunk in grounding_metadata.grounding_chunks:
            if chunk.confidence_score and chunk.confidence_score >= self.min_grounding_confidence:
                filtered_chunks.append(chunk)

        # Limit chunks to max_grounding_chunks
        filtered_chunks = filtered_chunks[:self.max_grounding_chunks]

        # Filter grounding supports by confidence
        filtered_supports = []
        for support in grounding_metadata.grounding_supports:
            if support.confidence_scores and max(support.confidence_scores) >= self.min_grounding_confidence:
                filtered_supports.append(support)

        # Filter citations by type and relevance
        filtered_citations = []
        for citation in grounding_metadata.citations:
            if self._is_citation_relevant(citation):
                filtered_citations.append(citation)

        # Fail-open strategies to avoid empty UI:
        if not filtered_chunks and grounding_metadata.grounding_chunks:
            logger.warning("All grounding chunks filtered out. Falling back to first N chunks without confidence filter.")
            filtered_chunks = grounding_metadata.grounding_chunks[:self.max_grounding_chunks]
        if not filtered_supports and grounding_metadata.grounding_supports:
            logger.warning("All grounding supports filtered out. Falling back to first N supports without confidence filter.")
            filtered_supports = grounding_metadata.grounding_supports[:self.max_grounding_chunks]

        # Create filtered grounding metadata
        filtered_metadata = GroundingMetadata(
            grounding_chunks=filtered_chunks,
            grounding_supports=filtered_supports,
            citations=filtered_citations,
            search_entry_point=grounding_metadata.search_entry_point,
            web_search_queries=grounding_metadata.web_search_queries
        )

        logger.info(f"Filtered grounding metadata: {len(grounding_metadata.grounding_chunks)} chunks → {len(filtered_chunks)} chunks")
        return filtered_metadata

    def clean_keyword_analysis(self, keyword_analysis: Dict[str, Any]) -> Dict[str, Any]:
        """
        Clean and deduplicate keyword analysis data.

        Args:
            keyword_analysis: Raw keyword analysis data

        Returns:
            Cleaned and deduplicated keyword analysis
        """
        if not keyword_analysis:
            return {}

        cleaned_analysis = {}

        # Clean and deduplicate keyword lists
        keyword_categories = ['primary', 'secondary', 'long_tail', 'semantic_keywords', 'trending_terms']

        for category in keyword_categories:
            if category in keyword_analysis and isinstance(keyword_analysis[category], list):
                cleaned_keywords = self._clean_keyword_list(keyword_analysis[category])
                cleaned_analysis[category] = cleaned_keywords[:self.max_keywords_per_category]

        # Clean other fields
        other_fields = ['search_intent', 'difficulty', 'analysis_insights']
        for field in other_fields:
            if field in keyword_analysis:
                cleaned_analysis[field] = keyword_analysis[field]

        # Clean content gaps separately (handled by filter_content_gaps)
        # Don't add content_gaps if it's empty to avoid adding empty lists
        if 'content_gaps' in keyword_analysis and keyword_analysis['content_gaps']:
            cleaned_analysis['content_gaps'] = keyword_analysis['content_gaps']  # Will be filtered later

        logger.info(f"Cleaned keyword analysis: {len(keyword_analysis)} categories → {len(cleaned_analysis)} categories")
        return cleaned_analysis

    def clean_competitor_analysis(self, competitor_analysis: Dict[str, Any]) -> Dict[str, Any]:
        """
        Clean and validate competitor analysis data.

        Args:
            competitor_analysis: Raw competitor analysis data

        Returns:
            Cleaned competitor analysis data
        """
        if not competitor_analysis:
            return {}

        cleaned_analysis = {}

        # Clean competitor lists
        competitor_lists = ['top_competitors', 'opportunities', 'competitive_advantages']
        for field in competitor_lists:
            if field in competitor_analysis and isinstance(competitor_analysis[field], list):
                cleaned_list = [item.strip() for item in competitor_analysis[field] if item.strip()]
                cleaned_analysis[field] = cleaned_list[:10]  # Limit to top 10

        # Clean other fields
        other_fields = ['market_positioning', 'competitive_landscape', 'market_share']
        for field in other_fields:
            if field in competitor_analysis:
                cleaned_analysis[field] = competitor_analysis[field]

        logger.info(f"Cleaned competitor analysis: {len(competitor_analysis)} fields → {len(cleaned_analysis)} fields")
        return cleaned_analysis

    def filter_content_gaps(self, content_gaps: List[str], research_data: BlogResearchResponse) -> List[str]:
        """
        Filter content gaps to keep only actionable, high-value ones.

        Args:
            content_gaps: List of identified content gaps
            research_data: Research data for context

        Returns:
            Filtered list of actionable content gaps
        """
        if not content_gaps:
            return []

        filtered_gaps = []

        for gap in content_gaps:
            # Quality filters
            if not self._is_gap_high_quality(gap):
                continue

            # Relevance filters
            if not self._is_gap_relevant_to_topic(gap, research_data):
                continue

            # Actionability filters
            if not self._is_gap_actionable(gap):
                continue

            filtered_gaps.append(gap)

        # Limit to max_content_gaps
        filtered_gaps = filtered_gaps[:self.max_content_gaps]

        logger.info(f"Filtered content gaps: {len(content_gaps)} → {len(filtered_gaps)}")
        return filtered_gaps

    # Private helper methods

    def _is_source_high_quality(self, source: ResearchSource) -> bool:
        """Check if source meets quality criteria."""
        # Credibility score check
        if source.credibility_score and source.credibility_score < self.min_credibility_score:
            return False

        # Excerpt length check
        if source.excerpt and len(source.excerpt) < self.min_excerpt_length:
            return False

        # Title quality check
        if not source.title or len(source.title.strip()) < 10:
            return False

        return True

    def _is_source_relevant(self, source: ResearchSource) -> bool:
        """Check if source is relevant (not irrelevant patterns)."""
        if not source.url:
            return True  # Keep sources without URLs

        # Check against irrelevant patterns
        for pattern in self.irrelevant_patterns:
            if re.search(pattern, source.url, re.IGNORECASE):
                return False

        return True

    def _is_source_recent(self, source: ResearchSource) -> bool:
        """Check if source is recent enough."""
        if not source.published_at:
            return True  # Keep sources without dates

        try:
            # Parse date (assuming ISO format or common formats)
            published_date = self._parse_date(source.published_at)
            if published_date:
                cutoff_date = datetime.now() - timedelta(days=self.max_source_age_days)
                return published_date >= cutoff_date
        except Exception as e:
            logger.warning(f"Error parsing date '{source.published_at}': {e}")

        return True  # Keep sources with unparseable dates

    def _is_citation_relevant(self, citation: Citation) -> bool:
        """Check if citation is relevant and high-quality."""
        # Check citation type
        relevant_types = ['expert_opinion', 'statistical_data', 'recent_news', 'research_study']
        if citation.citation_type not in relevant_types:
            return False

        # Check text quality
        if not citation.text or len(citation.text.strip()) < 20:
            return False

        return True

    def _is_gap_high_quality(self, gap: str) -> bool:
        """Check if content gap is high quality."""
        gap = gap.strip()

        # Length check
        if len(gap) < 10:
            return False

        # Generic gap check
        generic_gaps = ['general', 'overview', 'introduction', 'basics', 'fundamentals']
        if gap.lower() in generic_gaps:
            return False

        # Check for meaningful content
        if len(gap.split()) < 3:
            return False

        return True

    def _is_gap_relevant_to_topic(self, gap: str, research_data: BlogResearchResponse) -> bool:
        """Check if content gap is relevant to the research topic."""
        # Simple relevance check - could be enhanced with more sophisticated matching
        primary_keywords = research_data.keyword_analysis.get('primary', [])

        if not primary_keywords:
            return True  # Keep gaps if no keywords available

        gap_lower = gap.lower()
        for keyword in primary_keywords:
            if keyword.lower() in gap_lower:
                return True

        # If no direct keyword match, check for common AI-related terms
        ai_terms = ['ai', 'artificial intelligence', 'machine learning', 'automation', 'technology', 'digital']
        for term in ai_terms:
            if term in gap_lower:
                return True

        return True  # Default to keeping gaps if no clear relevance check

    def _is_gap_actionable(self, gap: str) -> bool:
        """Check if content gap is actionable (can be addressed with content)."""
        gap_lower = gap.lower()

        # Check for actionable indicators
        actionable_indicators = [
            'how to', 'guide', 'tutorial', 'steps', 'process', 'method',
            'best practices', 'tips', 'strategies', 'techniques', 'approach',
            'comparison', 'vs', 'versus', 'difference', 'pros and cons',
            'trends', 'future', '2024', '2025', 'emerging', 'new'
        ]

        for indicator in actionable_indicators:
            if indicator in gap_lower:
                return True

        return True  # Default to actionable if no specific indicators

    def _clean_keyword_list(self, keywords: List[str]) -> List[str]:
        """Clean and deduplicate a list of keywords."""
        cleaned_keywords = []
        seen_keywords = set()

        for keyword in keywords:
            if not keyword or not isinstance(keyword, str):
                continue

            # Clean keyword
            cleaned_keyword = keyword.strip().lower()

            # Skip empty or too short keywords
            if len(cleaned_keyword) < 2:
                continue

            # Skip stop words
            if cleaned_keyword in self.stop_words:
                continue

            # Skip duplicates
            if cleaned_keyword in seen_keywords:
                continue

            cleaned_keywords.append(cleaned_keyword)
            seen_keywords.add(cleaned_keyword)

        return cleaned_keywords

    def _parse_date(self, date_str: str) -> Optional[datetime]:
        """Parse date string into datetime object."""
        if not date_str:
            return None

        # Common date formats
        date_formats = [
            '%Y-%m-%d',
            '%Y-%m-%dT%H:%M:%S',
            '%Y-%m-%dT%H:%M:%SZ',
            '%Y-%m-%dT%H:%M:%S.%fZ',
            '%B %d, %Y',
            '%b %d, %Y',
            '%d %B %Y',
            '%d %b %Y',
            '%m/%d/%Y',
            '%d/%m/%Y'
        ]

        for fmt in date_formats:
            try:
                return datetime.strptime(date_str, fmt)
            except ValueError:
                continue

        return None

    def _log_filtering_results(self, original_counts: Dict[str, int], filtered_research: BlogResearchResponse):
        """Log the results of filtering operations."""
        filtered_counts = {
            'sources': len(filtered_research.sources),
            'grounding_chunks': len(filtered_research.grounding_metadata.grounding_chunks) if filtered_research.grounding_metadata else 0,
            'grounding_supports': len(filtered_research.grounding_metadata.grounding_supports) if filtered_research.grounding_metadata else 0,
            'citations': len(filtered_research.grounding_metadata.citations) if filtered_research.grounding_metadata else 0,
        }

        logger.info("📊 Research Data Filtering Results:")
        for key, original_count in original_counts.items():
            filtered_count = filtered_counts[key]
            reduction_percent = ((original_count - filtered_count) / original_count * 100) if original_count > 0 else 0
            logger.info(f"  {key}: {original_count} → {filtered_count} ({reduction_percent:.1f}% reduction)")

        # Log content gaps filtering
        original_gaps = len(filtered_research.keyword_analysis.get('content_gaps', []))
        logger.info(f"  content_gaps: {original_gaps} → {len(filtered_research.keyword_analysis.get('content_gaps', []))}")

        logger.info("✅ Research data filtering completed successfully")
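A short sketch of tuning the filter before use; because every filter fails open, stricter thresholds degrade gracefully rather than emptying the UI (the threshold values below are illustrative):

# Illustrative only: tightening thresholds before filtering; the filter falls back
# to unfiltered top results if a threshold would remove everything.
from services.blog_writer.research import ResearchDataFilter

data_filter = ResearchDataFilter()
data_filter.min_credibility_score = 0.7   # stricter than the 0.5 default
data_filter.max_sources = 10              # trim the source list further

# filtered = data_filter.filter_research_data(raw_response)  # raw_response: a BlogResearchResponse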
226  backend/services/blog_writer/research/exa_provider.py  Normal file
@@ -0,0 +1,226 @@
"""
Exa Research Provider

Neural search implementation using Exa API for high-quality, citation-rich research.
"""

from exa_py import Exa
import os
from loguru import logger
from models.subscription_models import APIProvider
from .base_provider import ResearchProvider as BaseProvider


class ExaResearchProvider(BaseProvider):
    """Exa neural search provider."""

    def __init__(self):
        self.api_key = os.getenv("EXA_API_KEY")
        if not self.api_key:
            raise RuntimeError("EXA_API_KEY not configured")
        self.exa = Exa(self.api_key)
        logger.info("✅ Exa Research Provider initialized")

    async def search(self, prompt, topic, industry, target_audience, config, user_id):
        """Execute Exa neural search and return standardized results."""
        # Build Exa query
        query = f"{topic} {industry} {target_audience}"

        # Determine category: use exa_category if set, otherwise map from source_types
        category = config.exa_category if config.exa_category else self._map_source_type_to_category(config.source_types)

        # Build search kwargs - use correct Exa API format
        search_kwargs = {
            'type': config.exa_search_type or "auto",
            'num_results': min(config.max_sources, 25),
            'text': {'max_characters': 1000},
            'summary': {'query': f"Key insights about {topic}"},
            'highlights': {
                'num_sentences': 2,
                'highlights_per_url': 3
            }
        }

        # Add optional filters
        if category:
            search_kwargs['category'] = category
        if config.exa_include_domains:
            search_kwargs['include_domains'] = config.exa_include_domains
        if config.exa_exclude_domains:
            search_kwargs['exclude_domains'] = config.exa_exclude_domains

        logger.info(f"[Exa Research] Executing search: {query}")

        # Execute Exa search - pass contents parameters directly, not nested
        try:
            results = self.exa.search_and_contents(
                query,
                text={'max_characters': 1000},
                summary={'query': f"Key insights about {topic}"},
                highlights={'num_sentences': 2, 'highlights_per_url': 3},
                type=config.exa_search_type or "auto",
                num_results=min(config.max_sources, 25),
                **({k: v for k, v in {
                    'category': category,
                    'include_domains': config.exa_include_domains,
                    'exclude_domains': config.exa_exclude_domains
                }.items() if v})
            )
        except Exception as e:
            logger.error(f"[Exa Research] API call failed: {e}")
            # Try simpler call without contents if the above fails
            try:
                logger.info("[Exa Research] Retrying with simplified parameters")
                results = self.exa.search_and_contents(
                    query,
                    type=config.exa_search_type or "auto",
                    num_results=min(config.max_sources, 25),
                    **({k: v for k, v in {
                        'category': category,
                        'include_domains': config.exa_include_domains,
                        'exclude_domains': config.exa_exclude_domains
                    }.items() if v})
                )
            except Exception as retry_error:
                logger.error(f"[Exa Research] Retry also failed: {retry_error}")
                raise RuntimeError(f"Exa search failed: {str(retry_error)}") from retry_error

        # Transform to standardized format
        sources = self._transform_sources(results.results)
        content = self._aggregate_content(results.results)
        search_type = getattr(results, 'resolvedSearchType', 'neural') if hasattr(results, 'resolvedSearchType') else 'neural'

        # Get cost if available
        cost = 0.005  # Default Exa cost for 1-25 results
        if hasattr(results, 'costDollars'):
            if hasattr(results.costDollars, 'total'):
                cost = results.costDollars.total

        logger.info(f"[Exa Research] Search completed: {len(sources)} sources, type: {search_type}")

        return {
            'sources': sources,
            'content': content,
            'search_type': search_type,
            'provider': 'exa',
            'search_queries': [query],
            'cost': {'total': cost}
        }

    def get_provider_enum(self):
        """Return EXA provider enum for subscription tracking."""
        return APIProvider.EXA

    def estimate_tokens(self) -> int:
        """Estimate token usage for Exa (not token-based)."""
        return 0  # Exa is per-search, not token-based

    def _map_source_type_to_category(self, source_types):
        """Map SourceType enum to Exa category parameter."""
        if not source_types:
            return None

        category_map = {
            'research paper': 'research paper',
            'news': 'news',
            'web': 'personal site',
            'industry': 'company',
            'expert': 'linkedin profile'
        }

        for st in source_types:
            if st.value in category_map:
                return category_map[st.value]

        return None

    def _transform_sources(self, results):
        """Transform Exa results to ResearchSource format."""
        sources = []
        for idx, result in enumerate(results):
            source_type = self._determine_source_type(result.url if hasattr(result, 'url') else '')

            sources.append({
                'title': result.title if hasattr(result, 'title') else '',
                'url': result.url if hasattr(result, 'url') else '',
                'excerpt': self._get_excerpt(result),
                'credibility_score': 0.85,  # Exa results are high quality
                'published_at': result.publishedDate if hasattr(result, 'publishedDate') else None,
                'index': idx,
                'source_type': source_type,
                'content': result.text if hasattr(result, 'text') else '',
                'highlights': result.highlights if hasattr(result, 'highlights') else [],
                'summary': result.summary if hasattr(result, 'summary') else ''
            })

        return sources

    def _get_excerpt(self, result):
        """Extract excerpt from Exa result."""
        if hasattr(result, 'text') and result.text:
            return result.text[:500]
        elif hasattr(result, 'summary') and result.summary:
            return result.summary
        return ''

    def _determine_source_type(self, url):
        """Determine source type from URL."""
        if not url:
            return 'web'

        url_lower = url.lower()
        if 'arxiv.org' in url_lower or 'research' in url_lower:
            return 'academic'
        elif any(news in url_lower for news in ['cnn.com', 'bbc.com', 'reuters.com', 'theguardian.com']):
            return 'news'
        elif 'linkedin.com' in url_lower:
            return 'expert'
        else:
            return 'web'

    def _aggregate_content(self, results):
        """Aggregate content from Exa results for LLM analysis."""
        content_parts = []

        for idx, result in enumerate(results):
            if hasattr(result, 'summary') and result.summary:
                content_parts.append(f"Source {idx + 1}: {result.summary}")
            elif hasattr(result, 'text') and result.text:
                content_parts.append(f"Source {idx + 1}: {result.text[:1000]}")

        return "\n\n".join(content_parts)

    def track_exa_usage(self, user_id: str, cost: float):
        """Track Exa API usage after successful call."""
        from services.database import get_db
        from services.subscription import PricingService
        from sqlalchemy import text

        db = next(get_db())
        try:
            pricing_service = PricingService(db)
            current_period = pricing_service.get_current_billing_period(user_id)

            # Update exa_calls and exa_cost via SQL UPDATE
            update_query = text("""
                UPDATE usage_summaries
                SET exa_calls = COALESCE(exa_calls, 0) + 1,
                    exa_cost = COALESCE(exa_cost, 0) + :cost,
                    total_calls = total_calls + 1,
                    total_cost = total_cost + :cost
                WHERE user_id = :user_id AND billing_period = :period
            """)
            db.execute(update_query, {
                'cost': cost,
                'user_id': user_id,
                'period': current_period
            })
            db.commit()

            logger.info(f"[Exa] Tracked usage: user={user_id}, cost=${cost}")
        except Exception as e:
            logger.error(f"[Exa] Failed to track usage: {e}")
            db.rollback()
        finally:
            db.close()
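A hedged invocation sketch for the Exa provider; it assumes EXA_API_KEY is set and that ResearchConfig supplies defaults for the exa_* fields this provider reads (values and user id are illustrative):

# Illustrative only; requires EXA_API_KEY in the environment.
import asyncio
from models.blog_models import ResearchConfig, ResearchMode, ResearchProvider
from services.blog_writer.research import ExaResearchProvider

async def run_exa_demo():
    # Assumes the remaining Exa-related config fields have sensible defaults.
    config = ResearchConfig(mode=ResearchMode.BASIC, provider=ResearchProvider.EXA)
    provider = ExaResearchProvider()
    result = await provider.search(
        prompt="(unused by the Exa path)", topic="AI blogging", industry="SaaS",
        target_audience="content marketers", config=config, user_id="demo_user",
    )
    print(result['search_type'], len(result['sources']))

# asyncio.run(run_exa_demo())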
40  backend/services/blog_writer/research/google_provider.py  Normal file
@@ -0,0 +1,40 @@
"""
Google Research Provider

Wrapper for Gemini native Google Search grounding to match base provider interface.
"""

from services.llm_providers.gemini_grounded_provider import GeminiGroundedProvider
from models.subscription_models import APIProvider
from .base_provider import ResearchProvider as BaseProvider
from loguru import logger


class GoogleResearchProvider(BaseProvider):
    """Google research provider using Gemini native grounding."""

    def __init__(self):
        self.gemini = GeminiGroundedProvider()

    async def search(self, prompt, topic, industry, target_audience, config, user_id):
        """Call Gemini grounding with pre-flight validation."""
        logger.info(f"[Google Research] Executing search for topic: {topic}")

        result = await self.gemini.generate_grounded_content(
            prompt=prompt,
            content_type="research",
            max_tokens=2000,
            user_id=user_id,
            validate_subsequent_operations=True
        )

        return result

    def get_provider_enum(self):
        """Return GEMINI provider enum for subscription tracking."""
        return APIProvider.GEMINI

    def estimate_tokens(self) -> int:
        """Estimate token usage for Google grounding."""
        return 1200  # Conservative estimate
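The equivalent sketch for the Google path, whose result payload keys ("content", "search_widget", "search_queries", "token_usage") are what ResearchService reads downstream; the prompt and user id are illustrative:

# Illustrative only: Google/Gemini grounding path.
import asyncio
from models.blog_models import ResearchConfig, ResearchMode, ResearchProvider
from services.blog_writer.research import GoogleResearchProvider

async def run_google_demo():
    config = ResearchConfig(mode=ResearchMode.BASIC, provider=ResearchProvider.GOOGLE)
    provider = GoogleResearchProvider()
    result = await provider.search(
        prompt="Research the state of AI blogging tools",
        topic="AI blogging", industry="SaaS", target_audience="content marketers",
        config=config, user_id="demo_user",
    )
    # Keys below match what ResearchService extracts from the grounded result.
    print(len(result.get("content", "")), result.get("search_queries", []))

# asyncio.run(run_google_demo())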
79  backend/services/blog_writer/research/keyword_analyzer.py  Normal file
@@ -0,0 +1,79 @@
"""
Keyword Analyzer - AI-powered keyword analysis for research content.

Extracts and analyzes keywords from research content using structured AI responses.
"""

from typing import Dict, Any, List
from loguru import logger


class KeywordAnalyzer:
    """Analyzes keywords from research content using AI-powered extraction."""

    def analyze(self, content: str, original_keywords: List[str], user_id: str = None) -> Dict[str, Any]:
        """Parse comprehensive keyword analysis from the research content using AI."""
        # Use AI to extract and analyze keywords from the rich research content.
        # Content is truncated to 3000 characters to avoid token limits.
        keyword_prompt = f"""
        Analyze the following research content and extract comprehensive keyword insights for: {', '.join(original_keywords)}

        Research Content:
        {content[:3000]}

        Extract and analyze:
        1. Primary keywords (main topic terms)
        2. Secondary keywords (related terms, synonyms)
        3. Long-tail opportunities (specific phrases people search for)
        4. Search intent (informational, commercial, navigational, transactional)
        5. Keyword difficulty assessment (1-10 scale)
        6. Content gaps (what competitors are missing)
        7. Semantic keywords (related concepts)
        8. Trending terms (emerging keywords)

        Respond with JSON:
        {{
            "primary": ["keyword1", "keyword2"],
            "secondary": ["related1", "related2"],
            "long_tail": ["specific phrase 1", "specific phrase 2"],
            "search_intent": "informational|commercial|navigational|transactional",
            "difficulty": 7,
            "content_gaps": ["gap1", "gap2"],
            "semantic_keywords": ["concept1", "concept2"],
            "trending_terms": ["trend1", "trend2"],
            "analysis_insights": "Brief analysis of keyword landscape"
        }}
        """

        from services.llm_providers.main_text_generation import llm_text_gen

        keyword_schema = {
            "type": "object",
            "properties": {
                "primary": {"type": "array", "items": {"type": "string"}},
                "secondary": {"type": "array", "items": {"type": "string"}},
                "long_tail": {"type": "array", "items": {"type": "string"}},
                "search_intent": {"type": "string"},
                "difficulty": {"type": "integer"},
                "content_gaps": {"type": "array", "items": {"type": "string"}},
                "semantic_keywords": {"type": "array", "items": {"type": "string"}},
                "trending_terms": {"type": "array", "items": {"type": "string"}},
                "analysis_insights": {"type": "string"}
            },
            "required": ["primary", "secondary", "long_tail", "search_intent", "difficulty", "content_gaps", "semantic_keywords", "trending_terms", "analysis_insights"]
        }

        keyword_analysis = llm_text_gen(
            prompt=keyword_prompt,
            json_struct=keyword_schema,
            user_id=user_id
        )

        if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
            logger.info("✅ AI keyword analysis completed successfully")
            return keyword_analysis
        else:
            # Fail gracefully - no fallback data
            error_msg = keyword_analysis.get('error', 'Unknown error') if isinstance(keyword_analysis, dict) else str(keyword_analysis)
            logger.error(f"AI keyword analysis failed: {error_msg}")
            raise ValueError(f"Keyword analysis failed: {error_msg}")
914  backend/services/blog_writer/research/research_service.py  Normal file
@@ -0,0 +1,914 @@
"""
Research Service - Core research functionality for AI Blog Writer.

Handles Google Search grounding, caching, and research orchestration.
"""

from typing import Dict, Any, List, Optional
from datetime import datetime
from loguru import logger

from models.blog_models import (
    BlogResearchRequest,
    BlogResearchResponse,
    ResearchSource,
    GroundingMetadata,
    GroundingChunk,
    GroundingSupport,
    Citation,
    ResearchConfig,
    ResearchMode,
    ResearchProvider,
)
from services.blog_writer.logger_config import blog_writer_logger, log_function_call
from fastapi import HTTPException

from .keyword_analyzer import KeywordAnalyzer
from .competitor_analyzer import CompetitorAnalyzer
from .content_angle_generator import ContentAngleGenerator
from .data_filter import ResearchDataFilter
from .research_strategies import get_strategy_for_mode


class ResearchService:
    """Service for conducting comprehensive research using Google Search grounding."""

    def __init__(self):
        self.keyword_analyzer = KeywordAnalyzer()
        self.competitor_analyzer = CompetitorAnalyzer()
        self.content_angle_generator = ContentAngleGenerator()
        self.data_filter = ResearchDataFilter()

    @log_function_call("research_operation")
    async def research(self, request: BlogResearchRequest, user_id: str) -> BlogResearchResponse:
        """
        Stage 1: Research & Strategy (AI Orchestration)
        Uses ONLY Gemini's native Google Search grounding - ONE API call for everything.
        Follows LinkedIn service pattern for efficiency and cost optimization.
        Includes intelligent caching for exact keyword matches.
        """
        try:
            from services.cache.research_cache import research_cache
            from services.cache.persistent_research_cache import persistent_research_cache  # used below when caching the filtered result

            topic = request.topic or ", ".join(request.keywords)
            industry = request.industry or (request.persona.industry if request.persona and request.persona.industry else "General")
            target_audience = getattr(request.persona, 'target_audience', 'General') if request.persona else 'General'

            # Log research parameters
            blog_writer_logger.log_operation_start(
                "research",
                topic=topic,
                industry=industry,
                target_audience=target_audience,
                keywords=request.keywords,
                keyword_count=len(request.keywords)
            )

            # Check cache first for exact keyword match
            cached_result = research_cache.get_cached_result(
                keywords=request.keywords,
                industry=industry,
                target_audience=target_audience
            )

            if cached_result:
                logger.info(f"Returning cached research result for keywords: {request.keywords}")
                blog_writer_logger.log_operation_end("research", 0, success=True, cache_hit=True)
                # Normalize cached data to fix None values in confidence_scores
                normalized_result = self._normalize_cached_research_data(cached_result)
                return BlogResearchResponse(**normalized_result)

            # User ID validation (validation logic is now in Google Grounding provider)
            if not user_id:
                raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")

            # Cache miss - proceed with API call
            logger.info(f"Cache miss - making API call for keywords: {request.keywords}")
            blog_writer_logger.log_operation_start("research_api_call", api_name="research", operation="research")

            # Determine research mode and get appropriate strategy
            research_mode = request.research_mode or ResearchMode.BASIC
            config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
            strategy = get_strategy_for_mode(research_mode)

            logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")

            # Build research prompt based on strategy
            research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)

            # Route to appropriate provider
            if config.provider == ResearchProvider.EXA:
                # Exa research workflow
                from .exa_provider import ExaResearchProvider
                from services.subscription.preflight_validator import validate_exa_research_operations
                from services.database import get_db
                from services.subscription import PricingService
                import os
                import time

                # Pre-flight validation
                db_val = next(get_db())
                try:
                    pricing_service = PricingService(db_val)
                    gpt_provider = os.getenv("GPT_PROVIDER", "google")
                    validate_exa_research_operations(pricing_service, user_id, gpt_provider)
                finally:
                    db_val.close()

                # Execute Exa search
                api_start_time = time.time()
                try:
                    exa_provider = ExaResearchProvider()
                    raw_result = await exa_provider.search(
                        research_prompt, topic, industry, target_audience, config, user_id
                    )
                    api_duration_ms = (time.time() - api_start_time) * 1000

                    # Track usage
                    cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
                    exa_provider.track_exa_usage(user_id, cost)

                    # Log API call performance
                    blog_writer_logger.log_api_call(
                        "exa_search",
                        "search_and_contents",
                        api_duration_ms,
                        token_usage={},
                        content_length=len(raw_result.get('content', ''))
                    )

                    # Extract content for downstream analysis
                    content = raw_result.get('content', '')
                    sources = raw_result.get('sources', [])
                    search_widget = ""  # Exa doesn't provide search widgets
                    search_queries = raw_result.get('search_queries', [])
                    grounding_metadata = None  # Exa doesn't provide grounding metadata

                except RuntimeError as e:
                    if "EXA_API_KEY not configured" in str(e):
                        logger.warning("Exa not configured, falling back to Google")
                        config.provider = ResearchProvider.GOOGLE
                        # Continue to Google flow below
                        raw_result = None
                    else:
                        raise

            elif config.provider == ResearchProvider.TAVILY:
                # Tavily research workflow
                from .tavily_provider import TavilyResearchProvider
                from services.database import get_db
                from services.subscription import PricingService
                import os
                import time

                # Pre-flight validation (similar to Exa)
                db_val = next(get_db())
                try:
                    pricing_service = PricingService(db_val)
                    # Check Tavily usage limits
                    limits = pricing_service.get_user_limits(user_id)
                    tavily_limit = limits.get('limits', {}).get('tavily_calls', 0) if limits else 0

                    # Get current usage
                    from models.subscription_models import UsageSummary
                    from datetime import datetime
                    current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
                    usage = db_val.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()

                    current_calls = getattr(usage, 'tavily_calls', 0) or 0 if usage else 0

                    if tavily_limit > 0 and current_calls >= tavily_limit:
                        raise HTTPException(
                            status_code=429,
                            detail={
                                'error': 'Tavily API call limit exceeded',
                                'message': f'You have reached your Tavily API call limit ({tavily_limit} calls). Please upgrade your plan or wait for the next billing period.',
                                'provider': 'tavily',
                                'usage_info': {
                                    'current': current_calls,
                                    'limit': tavily_limit
                                }
                            }
                        )
                except HTTPException:
                    raise
                except Exception as e:
                    logger.warning(f"Error checking Tavily limits: {e}")
                finally:
                    db_val.close()

                # Execute Tavily search
                api_start_time = time.time()
                try:
                    tavily_provider = TavilyResearchProvider()
                    raw_result = await tavily_provider.search(
                        research_prompt, topic, industry, target_audience, config, user_id
                    )
                    api_duration_ms = (time.time() - api_start_time) * 1000

                    # Track usage
                    cost = raw_result.get('cost', {}).get('total', 0.001) if isinstance(raw_result.get('cost'), dict) else 0.001
                    search_depth = config.tavily_search_depth or "basic"
                    tavily_provider.track_tavily_usage(user_id, cost, search_depth)

                    # Log API call performance
                    blog_writer_logger.log_api_call(
                        "tavily_search",
                        "search",
                        api_duration_ms,
                        token_usage={},
                        content_length=len(raw_result.get('content', ''))
                    )

                    # Extract content for downstream analysis
                    content = raw_result.get('content', '')
                    sources = raw_result.get('sources', [])
                    search_widget = ""  # Tavily doesn't provide search widgets
                    search_queries = raw_result.get('search_queries', [])
                    grounding_metadata = None  # Tavily doesn't provide grounding metadata

                except RuntimeError as e:
                    if "TAVILY_API_KEY not configured" in str(e):
                        logger.warning("Tavily not configured, falling back to Google")
                        config.provider = ResearchProvider.GOOGLE
                        # Continue to Google flow below
                        raw_result = None
                    else:
                        raise

            if config.provider not in [ResearchProvider.EXA, ResearchProvider.TAVILY]:
                # Google research (existing flow) or fallback from Exa
                from .google_provider import GoogleResearchProvider
                import time

                api_start_time = time.time()
                google_provider = GoogleResearchProvider()
                gemini_result = await google_provider.search(
                    research_prompt, topic, industry, target_audience, config, user_id
                )
                api_duration_ms = (time.time() - api_start_time) * 1000

                # Log API call performance
                blog_writer_logger.log_api_call(
                    "gemini_grounded",
                    "generate_grounded_content",
                    api_duration_ms,
                    token_usage=gemini_result.get("token_usage", {}),
                    content_length=len(gemini_result.get("content", ""))
                )

                # Extract sources and content
                sources = self._extract_sources_from_grounding(gemini_result)
                content = gemini_result.get("content", "")
                search_widget = gemini_result.get("search_widget", "") or ""
                search_queries = gemini_result.get("search_queries", []) or []
                grounding_metadata = self._extract_grounding_metadata(gemini_result)

            # Continue with common analysis (same for both providers)
            keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
            competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
            suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)

            logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")

            # Log analysis results
            blog_writer_logger.log_performance(
                "research_analysis",
                len(content),
                "characters",
                sources_count=len(sources),
                search_queries_count=len(search_queries),
                keyword_analysis_keys=len(keyword_analysis),
                suggested_angles_count=len(suggested_angles)
            )

            # Create the response
            response = BlogResearchResponse(
                success=True,
                sources=sources,
                keyword_analysis=keyword_analysis,
                competitor_analysis=competitor_analysis,
                suggested_angles=suggested_angles,
                # Add search widget and queries for UI display
                search_widget=search_widget if 'search_widget' in locals() else "",
                search_queries=search_queries if 'search_queries' in locals() else [],
                # Add grounding metadata for detailed UI display
                grounding_metadata=grounding_metadata,
            )

            # Filter and clean research data for optimal AI processing
            filtered_response = self.data_filter.filter_research_data(response)
            logger.info("Research data filtering completed successfully")

            # Cache the successful result for future exact keyword matches (both caches)
            persistent_research_cache.cache_result(
                keywords=request.keywords,
                industry=industry,
                target_audience=target_audience,
                result=filtered_response.dict()
            )

            # Also cache in memory for faster access
            research_cache.cache_result(
                keywords=request.keywords,
                industry=industry,
                target_audience=target_audience,
                result=filtered_response.dict()
            )

            return filtered_response

        except HTTPException:
            # Re-raise HTTPException (subscription errors) - let task manager handle it
            raise
        except Exception as e:
            error_message = str(e)
            logger.error(f"Research failed: {error_message}")

            # Log error with full context
            blog_writer_logger.log_error(
                e,
                "research",
                context={
                    "topic": topic,
                    "keywords": request.keywords,
                    "industry": industry,
                    "target_audience": target_audience
                }
            )

            # Import custom exceptions for better error handling
            from services.blog_writer.exceptions import (
                ResearchFailedException,
                APIRateLimitException,
                APITimeoutException,
                ValidationException
            )

            # Determine if this is a retryable error
            retry_suggested = True
            user_message = "Research failed. Please try again with different keywords or check your internet connection."

            if isinstance(e, APIRateLimitException):
                retry_suggested = True
                user_message = f"Rate limit exceeded. Please wait {e.context.get('retry_after', 60)} seconds before trying again."
            elif isinstance(e, APITimeoutException):
                retry_suggested = True
                user_message = "Research request timed out. Please try again with a shorter query or check your internet connection."
            elif isinstance(e, ValidationException):
                retry_suggested = False
                user_message = "Invalid research request. Please check your input parameters and try again."
            elif "401" in error_message or "403" in error_message:
                retry_suggested = False
                user_message = "Authentication failed. Please check your API credentials."
            elif "400" in error_message:
                retry_suggested = False
                user_message = "Invalid request. Please check your input parameters."

            # Return a graceful failure response with enhanced error information
            return BlogResearchResponse(
                success=False,
                sources=[],
                keyword_analysis={},
                competitor_analysis={},
                suggested_angles=[],
                search_widget="",
                search_queries=[],
                error_message=user_message,
                retry_suggested=retry_suggested,
                error_code=getattr(e, 'error_code', 'RESEARCH_FAILED'),
                actionable_steps=getattr(e, 'actionable_steps', [
                    "Try with different keywords",
                    "Check your internet connection",
                    "Wait a few minutes and try again",
                    "Contact support if the issue persists"
                ])
            )

    @log_function_call("research_with_progress")
    async def research_with_progress(self, request: BlogResearchRequest, task_id: str, user_id: str) -> BlogResearchResponse:
        """
        Research method with progress updates for real-time feedback.
        """
        try:
            from services.cache.research_cache import research_cache
            from services.cache.persistent_research_cache import persistent_research_cache
            from api.blog_writer.task_manager import task_manager

            topic = request.topic or ", ".join(request.keywords)
            industry = request.industry or (request.persona.industry if request.persona and request.persona.industry else "General")
            target_audience = getattr(request.persona, 'target_audience', 'General') if request.persona else 'General'

            # Check cache first for exact keyword match (try both caches)
            await task_manager.update_progress(task_id, "🔍 Checking cache for existing research...")

            # Try persistent cache first (survives restarts)
            cached_result = persistent_research_cache.get_cached_result(
                keywords=request.keywords,
                industry=industry,
                target_audience=target_audience
            )

            # Fallback to in-memory cache
            if not cached_result:
                cached_result = research_cache.get_cached_result(
                    keywords=request.keywords,
                    industry=industry,
                    target_audience=target_audience
                )

            if cached_result:
                await task_manager.update_progress(task_id, "✅ Found cached research results! Returning instantly...")
                logger.info(f"Returning cached research result for keywords: {request.keywords}")
                # Normalize cached data to fix None values in confidence_scores
                normalized_result = self._normalize_cached_research_data(cached_result)
                return BlogResearchResponse(**normalized_result)

            # User ID validation
            if not user_id:
                await task_manager.update_progress(task_id, "❌ Error: User ID is required for research operation")
|
||||
raise ValueError("user_id is required for research operation. Please provide Clerk user ID.")
|
||||
|
||||
# Determine research mode and get appropriate strategy
|
||||
research_mode = request.research_mode or ResearchMode.BASIC
|
||||
config = request.config or ResearchConfig(mode=research_mode, provider=ResearchProvider.GOOGLE)
|
||||
strategy = get_strategy_for_mode(research_mode)
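# The strategy determines how the research prompt is built (basic, comprehensive, or targeted); provider routing happens below.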
|
||||
|
||||
logger.info(f"Research: mode={research_mode.value}, provider={config.provider.value}")
|
||||
|
||||
# Build research prompt based on strategy
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Route to appropriate provider
|
||||
if config.provider == ResearchProvider.EXA:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
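# Raises HTTPException when the user's plan does not allow another Exa research call (caught and re-raised below).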
|
||||
except HTTPException as http_error:
|
||||
logger.error(f"Subscription limit exceeded for Exa research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
finally:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
|
||||
try:
|
||||
exa_provider = ExaResearchProvider()
|
||||
raw_result = await exa_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
|
||||
# Track usage
|
||||
cost = raw_result.get('cost', {}).get('total', 0.005) if isinstance(raw_result.get('cost'), dict) else 0.005
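# Falls back to a 0.005 default when Exa returns no cost breakdown.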
|
||||
exa_provider.track_exa_usage(user_id, cost)
|
||||
|
||||
# Extract content for downstream analysis
|
||||
# Handle None result case
|
||||
if raw_result is None:
|
||||
logger.error("raw_result is None after Exa search - this should not happen if HTTPException was raised")
|
||||
raise ValueError("Exa research result is None - search operation failed unexpectedly")
|
||||
|
||||
if not isinstance(raw_result, dict):
|
||||
logger.warning(f"raw_result is not a dict (type: {type(raw_result)}), using defaults")
|
||||
raw_result = {}
|
||||
|
||||
content = raw_result.get('content', '')
|
||||
sources = raw_result.get('sources', []) or []
|
||||
search_widget = "" # Exa doesn't provide search widgets
|
||||
search_queries = raw_result.get('search_queries', []) or []
|
||||
grounding_metadata = None # Exa doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "EXA_API_KEY not configured" in str(e):
|
||||
logger.warning("Exa not configured, falling back to Google")
|
||||
await task_manager.update_progress(task_id, "⚠️ Exa not configured, falling back to Google Search")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
else:
|
||||
raise
|
||||
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Tavily AI search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
tavily_limit = limits.get('limits', {}).get('tavily_calls', 0) if limits else 0
|
||||
|
||||
# Get current usage
|
||||
from models.subscription_models import UsageSummary
|
||||
from datetime import datetime
|
||||
current_period = pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
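# Falls back to the current calendar month when no billing period is recorded for the user.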
|
||||
usage = db_val.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
current_calls = (getattr(usage, 'tavily_calls', 0) or 0) if usage else 0
|
||||
|
||||
if tavily_limit > 0 and current_calls >= tavily_limit:
|
||||
await task_manager.update_progress(task_id, f"❌ Tavily API call limit exceeded ({current_calls}/{tavily_limit})")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail={
|
||||
'error': 'Tavily API call limit exceeded',
|
||||
'message': f'You have reached your Tavily API call limit ({tavily_limit} calls). Please upgrade your plan or wait for the next billing period.',
|
||||
'provider': 'tavily',
|
||||
'usage_info': {
|
||||
'current': current_calls,
|
||||
'limit': tavily_limit
|
||||
}
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Tavily limits: {e}")
|
||||
finally:
|
||||
db_val.close()
|
||||
|
||||
# Execute Tavily search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Tavily AI search...")
|
||||
try:
|
||||
tavily_provider = TavilyResearchProvider()
|
||||
raw_result = await tavily_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
|
||||
# Track usage
|
||||
cost = raw_result.get('cost', {}).get('total', 0.001) if isinstance(raw_result.get('cost'), dict) else 0.001
|
||||
search_depth = config.tavily_search_depth or "basic"
|
||||
tavily_provider.track_tavily_usage(user_id, cost, search_depth)
|
||||
|
||||
# Extract content for downstream analysis
|
||||
if raw_result is None:
|
||||
logger.error("raw_result is None after Tavily search")
|
||||
raise ValueError("Tavily research result is None - search operation failed unexpectedly")
|
||||
|
||||
if not isinstance(raw_result, dict):
|
||||
logger.warning(f"raw_result is not a dict (type: {type(raw_result)}), using defaults")
|
||||
raw_result = {}
|
||||
|
||||
content = raw_result.get('content', '')
|
||||
sources = raw_result.get('sources', []) or []
|
||||
search_widget = "" # Tavily doesn't provide search widgets
|
||||
search_queries = raw_result.get('search_queries', []) or []
|
||||
grounding_metadata = None # Tavily doesn't provide grounding metadata
|
||||
|
||||
except RuntimeError as e:
|
||||
if "TAVILY_API_KEY not configured" in str(e):
|
||||
logger.warning("Tavily not configured, falling back to Google")
|
||||
await task_manager.update_progress(task_id, "⚠️ Tavily not configured, falling back to Google Search")
|
||||
config.provider = ResearchProvider.GOOGLE
|
||||
# Continue to Google flow below
|
||||
else:
|
||||
raise
|
||||
|
||||
if config.provider not in [ResearchProvider.EXA, ResearchProvider.TAVILY]:
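# This branch covers explicit Google selections as well as Exa/Tavily fallbacks that reset config.provider above.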
|
||||
# Google research (existing flow)
|
||||
from .google_provider import GoogleResearchProvider
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Google Search grounding...")
|
||||
google_provider = GoogleResearchProvider()
|
||||
|
||||
await task_manager.update_progress(task_id, "🤖 Making AI request to Gemini with Google Search grounding...")
|
||||
try:
|
||||
gemini_result = await google_provider.search(
|
||||
research_prompt, topic, industry, target_audience, config, user_id
|
||||
)
|
||||
except HTTPException as http_error:
|
||||
logger.error(f"Subscription limit exceeded for Google research: {http_error.detail}")
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
|
||||
await task_manager.update_progress(task_id, "📊 Processing research results and extracting insights...")
|
||||
# Extract sources and content
|
||||
# Handle None result case
|
||||
if gemini_result is None:
|
||||
logger.error("gemini_result is None after search - this should not happen if HTTPException was raised")
|
||||
raise ValueError("Research result is None - search operation failed unexpectedly")
|
||||
|
||||
sources = self._extract_sources_from_grounding(gemini_result)
|
||||
content = gemini_result.get("content", "") if isinstance(gemini_result, dict) else ""
|
||||
search_widget = gemini_result.get("search_widget", "") or "" if isinstance(gemini_result, dict) else ""
|
||||
search_queries = gemini_result.get("search_queries", []) or [] if isinstance(gemini_result, dict) else []
|
||||
grounding_metadata = self._extract_grounding_metadata(gemini_result)
|
||||
|
||||
# Continue with common analysis (same for both providers)
|
||||
await task_manager.update_progress(task_id, "🔍 Analyzing keywords and content angles...")
|
||||
keyword_analysis = self.keyword_analyzer.analyze(content, request.keywords, user_id=user_id)
|
||||
competitor_analysis = self.competitor_analyzer.analyze(content, user_id=user_id)
|
||||
suggested_angles = self.content_angle_generator.generate(content, topic, industry, user_id=user_id)
|
||||
|
||||
await task_manager.update_progress(task_id, "💾 Caching results for future use...")
|
||||
logger.info(f"Research completed successfully with {len(sources)} sources and {len(search_queries)} search queries")
|
||||
|
||||
# Create the response
|
||||
response = BlogResearchResponse(
|
||||
success=True,
|
||||
sources=sources,
|
||||
keyword_analysis=keyword_analysis,
|
||||
competitor_analysis=competitor_analysis,
|
||||
suggested_angles=suggested_angles,
|
||||
# Add search widget and queries for UI display
|
||||
search_widget=search_widget,
|
||||
search_queries=search_queries,
|
||||
# Add grounding metadata for detailed UI display
|
||||
grounding_metadata=grounding_metadata,
|
||||
# Preserve original user keywords for caching
|
||||
original_keywords=request.keywords,
|
||||
)
|
||||
|
||||
# Filter and clean research data for optimal AI processing
|
||||
await task_manager.update_progress(task_id, "🔍 Filtering and cleaning research data...")
|
||||
filtered_response = self.data_filter.filter_research_data(response)
|
||||
logger.info("Research data filtering completed successfully")
|
||||
|
||||
# Cache the successful result for future exact keyword matches (both caches)
|
||||
persistent_research_cache.cache_result(
|
||||
keywords=request.keywords,
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
result=filtered_response.dict()
|
||||
)
|
||||
|
||||
# Also cache in memory for faster access
|
||||
research_cache.cache_result(
|
||||
keywords=request.keywords,
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
result=filtered_response.dict()
|
||||
)
|
||||
|
||||
return filtered_response
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPException (subscription errors) - let task manager handle it
|
||||
raise
|
||||
except Exception as e:
|
||||
error_message = str(e)
|
||||
logger.error(f"Research failed: {error_message}")
|
||||
|
||||
# Log error with full context
|
||||
blog_writer_logger.log_error(
|
||||
e,
|
||||
"research",
|
||||
context={
|
||||
"topic": topic,
|
||||
"keywords": request.keywords,
|
||||
"industry": industry,
|
||||
"target_audience": target_audience
|
||||
}
|
||||
)
|
||||
|
||||
# Import custom exceptions for better error handling
|
||||
from services.blog_writer.exceptions import (
|
||||
ResearchFailedException,
|
||||
APIRateLimitException,
|
||||
APITimeoutException,
|
||||
ValidationException
|
||||
)
|
||||
|
||||
# Determine if this is a retryable error
|
||||
retry_suggested = True
|
||||
user_message = "Research failed. Please try again with different keywords or check your internet connection."
|
||||
|
||||
if isinstance(e, APIRateLimitException):
|
||||
retry_suggested = True
|
||||
user_message = f"Rate limit exceeded. Please wait {e.context.get('retry_after', 60)} seconds before trying again."
|
||||
elif isinstance(e, APITimeoutException):
|
||||
retry_suggested = True
|
||||
user_message = "Research request timed out. Please try again with a shorter query or check your internet connection."
|
||||
elif isinstance(e, ValidationException):
|
||||
retry_suggested = False
|
||||
user_message = "Invalid research request. Please check your input parameters and try again."
|
||||
elif "401" in error_message or "403" in error_message:
|
||||
retry_suggested = False
|
||||
user_message = "Authentication failed. Please check your API credentials."
|
||||
elif "400" in error_message:
|
||||
retry_suggested = False
|
||||
user_message = "Invalid request. Please check your input parameters."
|
||||
|
||||
# Return a graceful failure response with enhanced error information
|
||||
return BlogResearchResponse(
|
||||
success=False,
|
||||
sources=[],
|
||||
keyword_analysis={},
|
||||
competitor_analysis={},
|
||||
suggested_angles=[],
|
||||
search_widget="",
|
||||
search_queries=[],
|
||||
error_message=user_message,
|
||||
retry_suggested=retry_suggested,
|
||||
error_code=getattr(e, 'error_code', 'RESEARCH_FAILED'),
|
||||
actionable_steps=getattr(e, 'actionable_steps', [
|
||||
"Try with different keywords",
|
||||
"Check your internet connection",
|
||||
"Wait a few minutes and try again",
|
||||
"Contact support if the issue persists"
|
||||
])
|
||||
)
|
||||
|
||||
def _extract_sources_from_grounding(self, gemini_result: Dict[str, Any]) -> List[ResearchSource]:
|
||||
"""Extract sources from Gemini grounding metadata."""
|
||||
sources = []
|
||||
|
||||
# Handle None or invalid gemini_result
|
||||
if not gemini_result or not isinstance(gemini_result, dict):
|
||||
logger.warning("gemini_result is None or not a dict, returning empty sources")
|
||||
return sources
|
||||
|
||||
# The Gemini grounded provider already extracts sources and puts them in the 'sources' field
|
||||
raw_sources = gemini_result.get("sources", [])
|
||||
# Ensure raw_sources is a list (handle None case)
|
||||
if raw_sources is None:
|
||||
raw_sources = []
|
||||
|
||||
for src in raw_sources:
|
||||
source = ResearchSource(
|
||||
title=src.get("title", "Untitled"),
|
||||
url=src.get("url", ""),
|
||||
excerpt=src.get("content", "")[:500] if src.get("content") else f"Source from {src.get('title', 'web')}",
|
||||
credibility_score=float(src.get("credibility_score", 0.8)),
|
||||
published_at=str(src.get("publication_date", "2024-01-01")),
|
||||
index=src.get("index"),
|
||||
source_type=src.get("type", "web")
|
||||
)
|
||||
sources.append(source)
|
||||
|
||||
return sources
|
||||
|
||||
def _normalize_cached_research_data(self, cached_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Normalize cached research data to fix None values in confidence_scores.
|
||||
Ensures all GroundingSupport objects have confidence_scores as a list.
|
||||
"""
|
||||
if not isinstance(cached_data, dict):
|
||||
return cached_data
|
||||
|
||||
normalized = cached_data.copy()
|
||||
|
||||
# Normalize grounding_metadata if present
|
||||
if "grounding_metadata" in normalized and normalized["grounding_metadata"]:
|
||||
grounding_metadata = normalized["grounding_metadata"].copy() if isinstance(normalized["grounding_metadata"], dict) else {}
|
||||
|
||||
# Normalize grounding_supports
|
||||
if "grounding_supports" in grounding_metadata and isinstance(grounding_metadata["grounding_supports"], list):
|
||||
normalized_supports = []
|
||||
for support in grounding_metadata["grounding_supports"]:
|
||||
if isinstance(support, dict):
|
||||
normalized_support = support.copy()
|
||||
# Fix confidence_scores: ensure it's a list, not None
|
||||
if normalized_support.get("confidence_scores") is None:
|
||||
normalized_support["confidence_scores"] = []
|
||||
elif not isinstance(normalized_support.get("confidence_scores"), list):
|
||||
# If it's not a list, try to convert or default to empty list
|
||||
normalized_support["confidence_scores"] = []
|
||||
# Fix grounding_chunk_indices: ensure it's a list, not None
|
||||
if normalized_support.get("grounding_chunk_indices") is None:
|
||||
normalized_support["grounding_chunk_indices"] = []
|
||||
elif not isinstance(normalized_support.get("grounding_chunk_indices"), list):
|
||||
normalized_support["grounding_chunk_indices"] = []
|
||||
# Ensure segment_text is a string
|
||||
if normalized_support.get("segment_text") is None:
|
||||
normalized_support["segment_text"] = ""
|
||||
normalized_supports.append(normalized_support)
|
||||
else:
|
||||
normalized_supports.append(support)
|
||||
grounding_metadata["grounding_supports"] = normalized_supports
|
||||
|
||||
normalized["grounding_metadata"] = grounding_metadata
|
||||
|
||||
return normalized
|
||||
|
||||
def _extract_grounding_metadata(self, gemini_result: Dict[str, Any]) -> GroundingMetadata:
|
||||
"""Extract detailed grounding metadata from Gemini result."""
|
||||
grounding_chunks = []
|
||||
grounding_supports = []
|
||||
citations = []
|
||||
|
||||
# Handle None or invalid gemini_result
|
||||
if not gemini_result or not isinstance(gemini_result, dict):
|
||||
logger.warning("gemini_result is None or not a dict, returning empty grounding metadata")
|
||||
return GroundingMetadata(
|
||||
grounding_chunks=grounding_chunks,
|
||||
grounding_supports=grounding_supports,
|
||||
citations=citations
|
||||
)
|
||||
|
||||
# Extract grounding chunks from the raw grounding metadata
|
||||
raw_grounding = gemini_result.get("grounding_metadata", {})
|
||||
|
||||
# Handle case where grounding_metadata might be a GroundingMetadata object
|
||||
if hasattr(raw_grounding, 'grounding_chunks'):
|
||||
raw_chunks = raw_grounding.grounding_chunks
|
||||
else:
|
||||
raw_chunks = raw_grounding.get("grounding_chunks", []) if isinstance(raw_grounding, dict) else []
|
||||
|
||||
# Ensure raw_chunks is a list (handle None case)
|
||||
if raw_chunks is None:
|
||||
raw_chunks = []
|
||||
|
||||
for chunk in raw_chunks:
|
||||
if "web" in chunk:
|
||||
web_data = chunk["web"]
|
||||
grounding_chunk = GroundingChunk(
|
||||
title=web_data.get("title", "Untitled"),
|
||||
url=web_data.get("uri", ""),
|
||||
confidence_score=None # Will be set from supports
|
||||
)
|
||||
grounding_chunks.append(grounding_chunk)
|
||||
|
||||
# Extract grounding supports with confidence scores
|
||||
if hasattr(raw_grounding, 'grounding_supports'):
|
||||
raw_supports = raw_grounding.grounding_supports
|
||||
else:
|
||||
raw_supports = raw_grounding.get("grounding_supports", [])
|
||||
for support in (raw_supports or []):
|
||||
# Handle both dictionary and GroundingSupport object formats
|
||||
if hasattr(support, 'confidence_scores'):
|
||||
confidence_scores = support.confidence_scores
|
||||
chunk_indices = support.grounding_chunk_indices
|
||||
segment_text = getattr(support, 'segment_text', '')
|
||||
start_index = getattr(support, 'start_index', None)
|
||||
end_index = getattr(support, 'end_index', None)
|
||||
else:
|
||||
confidence_scores = support.get("confidence_scores", [])
|
||||
chunk_indices = support.get("grounding_chunk_indices", [])
|
||||
segment = support.get("segment", {})
|
||||
segment_text = segment.get("text", "")
|
||||
start_index = segment.get("start_index")
|
||||
end_index = segment.get("end_index")
|
||||
|
||||
grounding_support = GroundingSupport(
|
||||
confidence_scores=confidence_scores,
|
||||
grounding_chunk_indices=chunk_indices,
|
||||
segment_text=segment_text,
|
||||
start_index=start_index,
|
||||
end_index=end_index
|
||||
)
|
||||
grounding_supports.append(grounding_support)
|
||||
|
||||
# Update confidence scores for chunks
|
||||
if confidence_scores and chunk_indices:
|
||||
avg_confidence = sum(confidence_scores) / len(confidence_scores)
|
||||
for idx in chunk_indices:
|
||||
if idx < len(grounding_chunks):
|
||||
grounding_chunks[idx].confidence_score = avg_confidence
|
||||
|
||||
# Extract citations from the raw result
|
||||
raw_citations = gemini_result.get("citations", [])
|
||||
for citation in raw_citations:
|
||||
citation_obj = Citation(
|
||||
citation_type=citation.get("type", "inline"),
|
||||
start_index=citation.get("start_index", 0),
|
||||
end_index=citation.get("end_index", 0),
|
||||
text=citation.get("text", ""),
|
||||
source_indices=citation.get("source_indices", []),
|
||||
reference=citation.get("reference", "")
|
||||
)
|
||||
citations.append(citation_obj)
|
||||
|
||||
# Extract search entry point and web search queries
|
||||
if hasattr(raw_grounding, 'search_entry_point'):
|
||||
search_entry_point = getattr(raw_grounding.search_entry_point, 'rendered_content', '') if raw_grounding.search_entry_point else ''
|
||||
else:
|
||||
search_entry_point = (raw_grounding.get("search_entry_point") or {}).get("rendered_content", "") if isinstance(raw_grounding, dict) else ""
|
||||
|
||||
if hasattr(raw_grounding, 'web_search_queries'):
|
||||
web_search_queries = raw_grounding.web_search_queries
|
||||
else:
|
||||
web_search_queries = raw_grounding.get("web_search_queries", [])
|
||||
|
||||
return GroundingMetadata(
|
||||
grounding_chunks=grounding_chunks,
|
||||
grounding_supports=grounding_supports,
|
||||
citations=citations,
|
||||
search_entry_point=search_entry_point,
|
||||
web_search_queries=web_search_queries or []
|
||||
)
|
||||
230
backend/services/blog_writer/research/research_strategies.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""
|
||||
Research Strategy Pattern Implementation
|
||||
|
||||
Different strategies for executing research based on depth and focus.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import BlogResearchRequest, ResearchMode, ResearchConfig
|
||||
from .keyword_analyzer import KeywordAnalyzer
|
||||
from .competitor_analyzer import CompetitorAnalyzer
|
||||
from .content_angle_generator import ContentAngleGenerator
|
||||
|
||||
|
||||
class ResearchStrategy(ABC):
|
||||
"""Base class for research strategies."""
|
||||
|
||||
def __init__(self):
|
||||
self.keyword_analyzer = KeywordAnalyzer()
|
||||
self.competitor_analyzer = CompetitorAnalyzer()
|
||||
self.content_angle_generator = ContentAngleGenerator()
|
||||
|
||||
@abstractmethod
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build the research prompt for the strategy."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_mode(self) -> ResearchMode:
|
||||
"""Return the research mode this strategy handles."""
|
||||
pass
|
||||
|
||||
|
||||
class BasicResearchStrategy(ResearchStrategy):
|
||||
"""Basic research strategy - keyword focused, minimal analysis."""
|
||||
|
||||
def get_mode(self) -> ResearchMode:
|
||||
return ResearchMode.BASIC
|
||||
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build basic research prompt focused on podcast-ready, actionable insights."""
|
||||
prompt = f"""You are a podcast researcher creating TALKING POINTS and FACT CARDS for a {industry} audience of {target_audience}.
|
||||
|
||||
Research Topic: "{topic}"
|
||||
|
||||
Provide analysis in this EXACT format:
|
||||
|
||||
## PODCAST HOOKS (3)
|
||||
- [Hook line with tension + data point + source URL]
|
||||
|
||||
## OBJECTIONS & COUNTERS (3)
|
||||
- Objection: [common listener objection]
|
||||
Counter: [concise rebuttal with stat + source URL]
|
||||
|
||||
## KEY STATS & PROOF (6)
|
||||
- [Specific metric with %/number, date, and source URL]
|
||||
|
||||
## MINI CASE SNAPS (3)
|
||||
- [Brand/company], [what they did], [outcome metric], [source URL]
|
||||
|
||||
## KEYWORDS TO MENTION (Primary + 5 Secondary)
|
||||
- Primary: "{topic}"
|
||||
- Secondary: [5 related keywords]
|
||||
|
||||
## 5 CONTENT ANGLES
|
||||
1. [Angle with audience benefit + why-now]
|
||||
2. [Angle ...]
|
||||
3. [Angle ...]
|
||||
4. [Angle ...]
|
||||
5. [Angle ...]
|
||||
|
||||
## FACT CARD LIST (8)
|
||||
- For each: Quote/claim, source URL, published date, metric/context.
|
||||
|
||||
REQUIREMENTS:
|
||||
- Every claim MUST include a source URL (authoritative, recent: 2024-2025 preferred).
|
||||
- Use concrete numbers, dates, outcomes; avoid generic advice.
|
||||
- Keep bullets tight and scannable for spoken narration."""
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
class ComprehensiveResearchStrategy(ResearchStrategy):
|
||||
"""Comprehensive research strategy - full analysis with all components."""
|
||||
|
||||
def get_mode(self) -> ResearchMode:
|
||||
return ResearchMode.COMPREHENSIVE
|
||||
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build comprehensive research prompt with podcast-focused, high-value insights."""
|
||||
date_filter = f"\nDate Focus: {config.date_range.value.replace('_', ' ')}" if config.date_range else ""
|
||||
source_filter = f"\nPriority Sources: {', '.join([s.value for s in config.source_types])}" if config.source_types else ""
|
||||
|
||||
prompt = f"""You are a senior podcast researcher creating deeply sourced talking points for a {industry} audience of {target_audience}.
|
||||
|
||||
Research Topic: "{topic}"{date_filter}{source_filter}
|
||||
|
||||
Provide COMPLETE analysis in this EXACT format:
|
||||
|
||||
## WHAT'S CHANGED (2024-2025)
|
||||
[5-7 concise trend bullets with numbers + source URLs]
|
||||
|
||||
## PROOF & NUMBERS
|
||||
[10 stats with metric, date, sample size/method, and source URL]
|
||||
|
||||
## EXPERT SIGNALS
|
||||
[5 expert quotes with name, title/company, source URL]
|
||||
|
||||
## RECENT MOVES
|
||||
[5-7 news items or launches with dates and source URLs]
|
||||
|
||||
## MARKET SNAPSHOTS
|
||||
[3-5 insights with TAM/SAM/SOM or adoption metrics, source URLs]
|
||||
|
||||
## CASE SNAPS
|
||||
[3-5 cases: who, what they did, outcome metric, source URL]
|
||||
|
||||
## KEYWORD PLAN
|
||||
Primary (3), Secondary (8-10), Long-tail (5-7) with intent hints.
|
||||
|
||||
## COMPETITOR GAPS
|
||||
- Top 5 competitors (URL) + 1-line strength
|
||||
- 5 content gaps we can own
|
||||
- 3 unique angles to differentiate
|
||||
|
||||
## PODCAST-READY ANGLES (5)
|
||||
- Each: Hook, promised takeaway, data or example, source URL.
|
||||
|
||||
## FACT CARD LIST (10)
|
||||
- Each: Quote/claim, source URL, published date, metric/context, suggested angle tag.
|
||||
|
||||
VERIFICATION REQUIREMENTS:
|
||||
- Minimum 2 authoritative sources per major claim.
|
||||
- Prefer industry reports > research papers > news > blogs.
|
||||
- 2024-2025 data strongly preferred.
|
||||
- All numbers must include timeframe and methodology.
|
||||
- Every bullet must be concise for spoken narration and actionable for {target_audience}."""
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
class TargetedResearchStrategy(ResearchStrategy):
|
||||
"""Targeted research strategy - focused on specific aspects."""
|
||||
|
||||
def get_mode(self) -> ResearchMode:
|
||||
return ResearchMode.TARGETED
|
||||
|
||||
def build_research_prompt(
|
||||
self,
|
||||
topic: str,
|
||||
industry: str,
|
||||
target_audience: str,
|
||||
config: ResearchConfig
|
||||
) -> str:
|
||||
"""Build targeted research prompt based on config preferences."""
|
||||
sections = []
|
||||
|
||||
if config.include_trends:
|
||||
sections.append("""## CURRENT TRENDS
|
||||
[3-5 trends with data and source URLs]""")
|
||||
|
||||
if config.include_statistics:
|
||||
sections.append("""## KEY STATISTICS
|
||||
[5-7 statistics with numbers and source URLs]""")
|
||||
|
||||
if config.include_expert_quotes:
|
||||
sections.append("""## EXPERT OPINIONS
|
||||
[3-4 expert quotes with attribution and source URLs]""")
|
||||
|
||||
if config.include_competitors:
|
||||
sections.append("""## COMPETITOR ANALYSIS
|
||||
Top Competitors: [3-5]
|
||||
Content Gaps: [3-5]""")
|
||||
|
||||
# Always include keywords and angles
|
||||
sections.append("""## KEYWORD ANALYSIS
|
||||
Primary: [2-3 variations]
|
||||
Secondary: [5-7 keywords]
|
||||
Long-Tail: [3-5 phrases]""")
|
||||
|
||||
sections.append("""## CONTENT ANGLES (3-5)
|
||||
[Unique blog angles with reasoning]""")
|
||||
|
||||
sections_str = "\n\n".join(sections)
|
||||
|
||||
prompt = f"""You are a blog content strategist conducting targeted research for a {industry} blog targeting {target_audience}.
|
||||
|
||||
Research Topic: "{topic}"
|
||||
|
||||
Provide focused analysis in this EXACT format:
|
||||
|
||||
{sections_str}
|
||||
|
||||
REQUIREMENTS:
|
||||
- Cite all claims with authoritative source URLs
|
||||
- Include specific numbers, dates, examples
|
||||
- Focus on actionable insights for {target_audience}
|
||||
- Use 2024-2025 data when available"""
|
||||
return prompt.strip()
|
||||
|
||||
|
||||
def get_strategy_for_mode(mode: ResearchMode) -> ResearchStrategy:
|
||||
"""Factory function to get the appropriate strategy for a mode."""
|
||||
strategy_map = {
|
||||
ResearchMode.BASIC: BasicResearchStrategy,
|
||||
ResearchMode.COMPREHENSIVE: ComprehensiveResearchStrategy,
|
||||
ResearchMode.TARGETED: TargetedResearchStrategy,
|
||||
}
|
||||
|
||||
strategy_class = strategy_map.get(mode, BasicResearchStrategy)
|
||||
return strategy_class()
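# Illustrative usage (assumes ResearchProvider is exported by models.blog_models alongside ResearchMode/ResearchConfig):
#   strategy = get_strategy_for_mode(ResearchMode.COMPREHENSIVE)
#   config = ResearchConfig(mode=ResearchMode.COMPREHENSIVE, provider=ResearchProvider.GOOGLE)
#   prompt = strategy.build_research_prompt("AI podcast tooling", "Media", "independent creators", config)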
|
||||
|
||||
169
backend/services/blog_writer/research/tavily_provider.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""
|
||||
Tavily Research Provider
|
||||
|
||||
AI-powered search implementation using Tavily API for high-quality research.
|
||||
"""
|
||||
|
||||
import os
|
||||
from loguru import logger
|
||||
from models.subscription_models import APIProvider
|
||||
from services.research.tavily_service import TavilyService
|
||||
from .base_provider import ResearchProvider as BaseProvider
|
||||
|
||||
|
||||
class TavilyResearchProvider(BaseProvider):
|
||||
"""Tavily AI-powered search provider."""
|
||||
|
||||
def __init__(self):
|
||||
self.api_key = os.getenv("TAVILY_API_KEY")
|
||||
if not self.api_key:
|
||||
raise RuntimeError("TAVILY_API_KEY not configured")
|
||||
self.tavily_service = TavilyService()
|
||||
logger.info("✅ Tavily Research Provider initialized")
|
||||
|
||||
async def search(self, prompt, topic, industry, target_audience, config, user_id):
|
||||
"""Execute Tavily search and return standardized results."""
|
||||
# Build Tavily query
|
||||
query = f"{topic} {industry} {target_audience}"
|
||||
|
||||
# Get Tavily-specific config options
|
||||
tavily_topic = config.tavily_topic or "general"  # Tavily topic category ("general"/"news"), distinct from the research topic
|
||||
search_depth = config.tavily_search_depth or "basic"
|
||||
|
||||
logger.info(f"[Tavily Research] Executing search: {query}")
|
||||
|
||||
# Execute Tavily search
|
||||
result = await self.tavily_service.search(
|
||||
query=query,
|
||||
topic=tavily_topic,
|
||||
search_depth=search_depth,
|
||||
max_results=min(config.max_sources, 20),
|
||||
include_domains=config.tavily_include_domains or None,
|
||||
exclude_domains=config.tavily_exclude_domains or None,
|
||||
include_answer=config.tavily_include_answer or False,
|
||||
include_raw_content=config.tavily_include_raw_content or False,
|
||||
include_images=config.tavily_include_images or False,
|
||||
include_image_descriptions=config.tavily_include_image_descriptions or False,
|
||||
time_range=config.tavily_time_range,
|
||||
start_date=config.tavily_start_date,
|
||||
end_date=config.tavily_end_date,
|
||||
country=config.tavily_country,
|
||||
chunks_per_source=config.tavily_chunks_per_source or 3,
|
||||
auto_parameters=config.tavily_auto_parameters or False
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
raise RuntimeError(f"Tavily search failed: {result.get('error', 'Unknown error')}")
|
||||
|
||||
# Transform to standardized format
|
||||
sources = self._transform_sources(result.get("results", []))
|
||||
content = self._aggregate_content(result.get("results", []))
|
||||
|
||||
# Calculate cost (basic = 1 credit, advanced = 2 credits)
|
||||
cost = 0.001 if search_depth == "basic" else 0.002 # Estimate cost per search
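# Dollar figures are rough internal estimates; Tavily itself bills in credits as noted above.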
|
||||
|
||||
logger.info(f"[Tavily Research] Search completed: {len(sources)} sources, depth: {search_depth}")
|
||||
|
||||
return {
|
||||
'sources': sources,
|
||||
'content': content,
|
||||
'search_type': search_depth,
|
||||
'provider': 'tavily',
|
||||
'search_queries': [query],
|
||||
'cost': {'total': cost},
|
||||
'answer': result.get("answer"), # If include_answer was requested
|
||||
'images': result.get("images", [])
|
||||
}
|
||||
|
||||
def get_provider_enum(self):
|
||||
"""Return TAVILY provider enum for subscription tracking."""
|
||||
return APIProvider.TAVILY
|
||||
|
||||
def estimate_tokens(self) -> int:
|
||||
"""Estimate token usage for Tavily (not token-based, but we estimate API calls)."""
|
||||
return 0 # Tavily is per-search, not token-based
|
||||
|
||||
def _transform_sources(self, results):
|
||||
"""Transform Tavily results to ResearchSource format."""
|
||||
sources = []
|
||||
for idx, result in enumerate(results):
|
||||
source_type = self._determine_source_type(result.get("url", ""))
|
||||
|
||||
sources.append({
|
||||
'title': result.get("title", ""),
|
||||
'url': result.get("url", ""),
|
||||
'excerpt': result.get("content", "")[:500], # First 500 chars
|
||||
'credibility_score': result.get("relevance_score", 0.5),
|
||||
'published_at': result.get("published_date"),
|
||||
'index': idx,
|
||||
'source_type': source_type,
|
||||
'content': result.get("content", ""),
|
||||
'raw_content': result.get("raw_content"), # If include_raw_content was requested
|
||||
'score': result.get("score", result.get("relevance_score", 0.5)),
|
||||
'favicon': result.get("favicon")
|
||||
})
|
||||
|
||||
return sources
|
||||
|
||||
def _determine_source_type(self, url):
|
||||
"""Determine source type from URL."""
|
||||
if not url:
|
||||
return 'web'
|
||||
|
||||
url_lower = url.lower()
|
||||
if 'arxiv.org' in url_lower or 'research' in url_lower or '.edu' in url_lower:
|
||||
return 'academic'
|
||||
elif any(news in url_lower for news in ['cnn.com', 'bbc.com', 'reuters.com', 'theguardian.com', 'nytimes.com']):
|
||||
return 'news'
|
||||
elif 'linkedin.com' in url_lower:
|
||||
return 'expert'
|
||||
elif '.gov' in url_lower:
|
||||
return 'government'
|
||||
else:
|
||||
return 'web'
|
||||
|
||||
def _aggregate_content(self, results):
|
||||
"""Aggregate content from Tavily results for LLM analysis."""
|
||||
content_parts = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
content = result.get("content", "")
|
||||
if content:
|
||||
content_parts.append(f"Source {idx + 1}: {content}")
|
||||
|
||||
return "\n\n".join(content_parts)
|
||||
|
||||
def track_tavily_usage(self, user_id: str, cost: float, search_depth: str):
|
||||
"""Track Tavily API usage after successful call."""
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
# Update tavily_calls and tavily_cost via SQL UPDATE
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET tavily_calls = COALESCE(tavily_calls, 0) + 1,
|
||||
tavily_cost = COALESCE(tavily_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[Tavily] Tracked usage: user={user_id}, cost=${cost}, depth={search_depth}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Tavily] Failed to track usage: {e}", exc_info=True)
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||