ALwrity HALLUCINATION DETECTOR AND ASSISTIVE WRITING

This commit is contained in:
ajaysi
2025-09-08 21:14:27 +05:30
parent 5ba19c097a
commit 6fd9a4e354
51 changed files with 8224 additions and 1086 deletions

View File

@@ -0,0 +1,702 @@
"""
Hallucination Detector Service
This service implements fact-checking functionality using Exa.ai API
to detect and verify claims in AI-generated content, similar to the
Exa.ai demo implementation.
"""
import json
import logging
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from datetime import datetime
import requests
import os
import asyncio
import concurrent.futures
try:
from google import genai
GOOGLE_GENAI_AVAILABLE = True
except Exception:
GOOGLE_GENAI_AVAILABLE = False
logger = logging.getLogger(__name__)
@dataclass
class Claim:
"""Represents a single verifiable claim extracted from text."""
text: str
confidence: float
assessment: str # "supported", "refuted", "insufficient_information"
supporting_sources: List[Dict[str, Any]]
refuting_sources: List[Dict[str, Any]]
reasoning: str = ""
@dataclass
class HallucinationResult:
"""Result of hallucination detection analysis."""
claims: List[Claim]
overall_confidence: float
total_claims: int
supported_claims: int
refuted_claims: int
insufficient_claims: int
timestamp: str
class HallucinationDetector:
"""
Hallucination detector using Exa.ai for fact-checking.
Implements the three-step process from Exa.ai demo:
1. Extract verifiable claims from text
2. Search for evidence using Exa.ai
3. Verify claims against sources
"""
def __init__(self):
self.exa_api_key = os.getenv('EXA_API_KEY')
self.gemini_api_key = os.getenv('GEMINI_API_KEY')
if not self.exa_api_key:
logger.warning("EXA_API_KEY not found. Hallucination detection will be limited.")
if not self.gemini_api_key:
logger.warning("GEMINI_API_KEY not found. Falling back to heuristic claim extraction.")
# Initialize Gemini client for claim extraction and assessment
self.gemini_client = genai.Client(api_key=self.gemini_api_key) if (GOOGLE_GENAI_AVAILABLE and self.gemini_api_key) else None
# Rate limiting to prevent API abuse
self.daily_api_calls = 0
self.daily_limit = 20 # Max 20 API calls per day for fact checking
self.last_reset_date = None
def _check_rate_limit(self) -> bool:
"""Check if we're within daily API usage limits."""
from datetime import date
today = date.today()
# Reset counter if it's a new day
if self.last_reset_date != today:
self.daily_api_calls = 0
self.last_reset_date = today
# Check if we've exceeded the limit
if self.daily_api_calls >= self.daily_limit:
logger.warning(f"Daily API limit reached ({self.daily_limit} calls). Fact checking disabled for today.")
return False
# Increment counter for this API call
self.daily_api_calls += 1
logger.info(f"Fact check API call #{self.daily_api_calls}/{self.daily_limit} today")
return True
async def detect_hallucinations(self, text: str) -> HallucinationResult:
"""
Main method to detect hallucinations in the given text.
Args:
text: The text to analyze for factual accuracy
Returns:
HallucinationResult with claims analysis and confidence scores
"""
try:
logger.info(f"Starting hallucination detection for text of length: {len(text)}")
logger.info(f"Text sample: {text[:200]}...")
# Check rate limits first
if not self._check_rate_limit():
return HallucinationResult(
claims=[],
overall_confidence=0.0,
total_claims=0,
supported_claims=0,
refuted_claims=0,
insufficient_claims=0,
timestamp=datetime.now().isoformat()
)
# Validate required API keys
if not self.gemini_api_key:
raise Exception("GEMINI_API_KEY not configured. Cannot perform hallucination detection.")
if not self.exa_api_key:
raise Exception("EXA_API_KEY not configured. Cannot search for evidence.")
# Step 1: Extract claims from text
claims_texts = await self._extract_claims(text)
logger.info(f"Extracted {len(claims_texts)} claims from text: {claims_texts}")
if not claims_texts:
logger.warning("No verifiable claims found in text")
return HallucinationResult(
claims=[],
overall_confidence=0.0,
total_claims=0,
supported_claims=0,
refuted_claims=0,
insufficient_claims=0,
timestamp=datetime.now().isoformat()
)
# Step 2 & 3: Verify claims in batch to reduce API calls
verified_claims = await self._verify_claims_batch(claims_texts)
# Calculate overall metrics
total_claims = len(verified_claims)
supported_claims = sum(1 for c in verified_claims if c.assessment == "supported")
refuted_claims = sum(1 for c in verified_claims if c.assessment == "refuted")
insufficient_claims = sum(1 for c in verified_claims if c.assessment == "insufficient_information")
# Calculate overall confidence (weighted average)
if total_claims > 0:
overall_confidence = sum(c.confidence for c in verified_claims) / total_claims
else:
overall_confidence = 0.0
result = HallucinationResult(
claims=verified_claims,
overall_confidence=overall_confidence,
total_claims=total_claims,
supported_claims=supported_claims,
refuted_claims=refuted_claims,
insufficient_claims=insufficient_claims,
timestamp=datetime.now().isoformat()
)
logger.info(f"Hallucination detection completed. Overall confidence: {overall_confidence:.2f}")
return result
except Exception as e:
logger.error(f"Error in hallucination detection: {str(e)}")
raise Exception(f"Hallucination detection failed: {str(e)}")
async def _extract_claims(self, text: str) -> List[str]:
"""
Extract verifiable claims from text using LLM.
Args:
text: Input text to extract claims from
Returns:
List of claim strings
"""
if not self.gemini_client:
raise Exception("Gemini client not available. Cannot extract claims without AI provider.")
try:
prompt = (
"Extract verifiable factual claims from the following text. "
"A verifiable claim is a statement that can be checked against external sources for accuracy.\n\n"
"Return ONLY a valid JSON array of strings, where each string is a single verifiable claim.\n\n"
"Examples of GOOD verifiable claims:\n"
"- \"The company was founded in 2020\"\n"
"- \"Sales increased by 25% last quarter\"\n"
"- \"The product has 10,000 users\"\n"
"- \"The market size is $50 billion\"\n"
"- \"The software supports 15 languages\"\n"
"- \"The company has offices in 5 countries\"\n\n"
"Examples of BAD claims (opinions, subjective statements):\n"
"- \"This is the best product\"\n"
"- \"Customers love our service\"\n"
"- \"We are innovative\"\n"
"- \"The future looks bright\"\n\n"
"IMPORTANT: Extract at least 2-3 verifiable claims if possible. "
"Look for specific facts, numbers, dates, locations, and measurable statements.\n\n"
f"Text to analyze: {text}\n\n"
"Return only the JSON array of verifiable claims:"
)
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
resp = await loop.run_in_executor(executor, lambda: self.gemini_client.models.generate_content(
model="gemini-1.5-flash",
contents=prompt
))
if not resp or not resp.text:
raise Exception("Empty response from Gemini API")
claims_text = resp.text.strip()
logger.info(f"Raw Gemini response for claims: {claims_text[:200]}...")
# Try to extract JSON from the response
try:
claims = json.loads(claims_text)
except json.JSONDecodeError:
# Try to find JSON array in the response (handle markdown code blocks)
import re
# First try to extract from markdown code blocks
code_block_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', claims_text, re.DOTALL)
if code_block_match:
claims = json.loads(code_block_match.group(1))
else:
# Try to find JSON array directly
json_match = re.search(r'\[.*?\]', claims_text, re.DOTALL)
if json_match:
claims = json.loads(json_match.group())
else:
raise Exception(f"Could not parse JSON from Gemini response: {claims_text[:100]}")
if isinstance(claims, list):
valid_claims = [claim for claim in claims if isinstance(claim, str) and claim.strip()]
logger.info(f"Successfully extracted {len(valid_claims)} claims")
return valid_claims
else:
raise Exception(f"Expected JSON array, got: {type(claims)}")
except Exception as e:
logger.error(f"Error extracting claims: {str(e)}")
raise Exception(f"Failed to extract claims: {str(e)}")
async def _verify_claims_batch(self, claims: List[str]) -> List[Claim]:
"""
Verify multiple claims in batch to reduce API calls.
Args:
claims: List of claims to verify
Returns:
List of Claim objects with verification results
"""
try:
logger.info(f"Starting batch verification of {len(claims)} claims")
# Limit to maximum 3 claims to prevent excessive API usage
max_claims = min(len(claims), 3)
claims_to_verify = claims[:max_claims]
if len(claims) > max_claims:
logger.warning(f"Limited verification to {max_claims} claims to prevent API rate limits")
# Step 1: Search for evidence for all claims in one batch
all_sources = await self._search_evidence_batch(claims_to_verify)
# Step 2: Assess all claims against sources in one API call
verified_claims = await self._assess_claims_batch(claims_to_verify, all_sources)
# Add any remaining claims as insufficient information
for i in range(max_claims, len(claims)):
verified_claims.append(Claim(
text=claims[i],
confidence=0.0,
assessment="insufficient_information",
supporting_sources=[],
refuting_sources=[],
reasoning="Not verified due to API rate limit protection"
))
logger.info(f"Batch verification completed for {len(verified_claims)} claims")
return verified_claims
except Exception as e:
logger.error(f"Error in batch verification: {str(e)}")
# Return all claims as insufficient information
return [
Claim(
text=claim,
confidence=0.0,
assessment="insufficient_information",
supporting_sources=[],
refuting_sources=[],
reasoning=f"Batch verification failed: {str(e)}"
)
for claim in claims
]
async def _verify_claim(self, claim: str) -> Claim:
"""
Verify a single claim using Exa.ai search.
Args:
claim: The claim to verify
Returns:
Claim object with verification results
"""
try:
# Search for evidence using Exa.ai
sources = await self._search_evidence(claim)
if not sources:
return Claim(
text=claim,
confidence=0.5,
assessment="insufficient_information",
supporting_sources=[],
refuting_sources=[],
reasoning="No sources found for verification"
)
# Verify claim against sources using LLM
verification_result = await self._assess_claim_against_sources(claim, sources)
return Claim(
text=claim,
confidence=verification_result.get('confidence', 0.5),
assessment=verification_result.get('assessment', 'insufficient_information'),
supporting_sources=verification_result.get('supporting_sources', []),
refuting_sources=verification_result.get('refuting_sources', []),
reasoning=verification_result.get('reasoning', '')
)
except Exception as e:
logger.error(f"Error verifying claim '{claim}': {str(e)}")
return Claim(
text=claim,
confidence=0.5,
assessment="insufficient_information",
supporting_sources=[],
refuting_sources=[],
reasoning=f"Error during verification: {str(e)}"
)
async def _search_evidence_batch(self, claims: List[str]) -> List[Dict[str, Any]]:
"""
Search for evidence for multiple claims in one API call.
Args:
claims: List of claims to search for
Returns:
List of sources relevant to the claims
"""
try:
# Combine all claims into one search query
combined_query = " ".join(claims[:2]) # Use first 2 claims to avoid query length limits
logger.info(f"Searching for evidence for {len(claims)} claims with combined query")
# Use the existing search method with combined query
sources = await self._search_evidence(combined_query)
# Limit sources to prevent excessive processing
max_sources = 5
if len(sources) > max_sources:
sources = sources[:max_sources]
logger.info(f"Limited sources to {max_sources} to prevent API rate limits")
return sources
except Exception as e:
logger.error(f"Error in batch evidence search: {str(e)}")
return []
async def _assess_claims_batch(self, claims: List[str], sources: List[Dict[str, Any]]) -> List[Claim]:
"""
Assess multiple claims against sources in one API call.
Args:
claims: List of claims to assess
sources: List of sources to assess against
Returns:
List of Claim objects with assessment results
"""
if not self.gemini_client:
raise Exception("Gemini client not available. Cannot assess claims without AI provider.")
try:
# Limit to 3 claims to prevent excessive API usage
claims_to_assess = claims[:3]
# Prepare sources text
combined_sources = "\n\n".join([
f"Source {i+1}: {src.get('url','')}\nText: {src.get('text','')[:1000]}"
for i, src in enumerate(sources)
])
# Prepare claims text
claims_text = "\n".join([
f"Claim {i+1}: {claim}"
for i, claim in enumerate(claims_to_assess)
])
prompt = (
"You are a strict fact-checker. Analyze each claim against the provided sources.\n\n"
"Return ONLY a valid JSON object with this exact structure:\n"
"{\n"
' "assessments": [\n'
' {\n'
' "claim_index": 0,\n'
' "assessment": "supported" or "refuted" or "insufficient_information",\n'
' "confidence": number between 0.0 and 1.0,\n'
' "supporting_sources": [array of source indices that support the claim],\n'
' "refuting_sources": [array of source indices that refute the claim],\n'
' "reasoning": "brief explanation of your assessment"\n'
' }\n'
' ]\n'
"}\n\n"
f"Claims to verify:\n{claims_text}\n\n"
f"Sources:\n{combined_sources}\n\n"
"Return only the JSON object:"
)
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
resp = await loop.run_in_executor(executor, lambda: self.gemini_client.models.generate_content(
model="gemini-1.5-flash",
contents=prompt
))
if not resp or not resp.text:
raise Exception("Empty response from Gemini API for batch assessment")
result_text = resp.text.strip()
logger.info(f"Raw Gemini response for batch assessment: {result_text[:200]}...")
# Try to extract JSON from the response
try:
result = json.loads(result_text)
except json.JSONDecodeError:
# Try to find JSON object in the response (handle markdown code blocks)
import re
code_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', result_text, re.DOTALL)
if code_block_match:
result = json.loads(code_block_match.group(1))
else:
json_match = re.search(r'\{.*?\}', result_text, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
else:
raise Exception(f"Could not parse JSON from Gemini response: {result_text[:100]}")
# Process assessments
assessments = result.get('assessments', [])
verified_claims = []
for i, claim in enumerate(claims_to_assess):
# Find assessment for this claim
assessment = None
for a in assessments:
if a.get('claim_index') == i:
assessment = a
break
if assessment:
# Process supporting and refuting sources
supporting_sources = []
refuting_sources = []
if isinstance(assessment.get('supporting_sources'), list):
for idx in assessment['supporting_sources']:
if isinstance(idx, int) and 0 <= idx < len(sources):
supporting_sources.append(sources[idx])
if isinstance(assessment.get('refuting_sources'), list):
for idx in assessment['refuting_sources']:
if isinstance(idx, int) and 0 <= idx < len(sources):
refuting_sources.append(sources[idx])
verified_claims.append(Claim(
text=claim,
confidence=float(assessment.get('confidence', 0.5)),
assessment=assessment.get('assessment', 'insufficient_information'),
supporting_sources=supporting_sources,
refuting_sources=refuting_sources,
reasoning=assessment.get('reasoning', '')
))
else:
# No assessment found for this claim
verified_claims.append(Claim(
text=claim,
confidence=0.0,
assessment="insufficient_information",
supporting_sources=[],
refuting_sources=[],
reasoning="No assessment provided"
))
logger.info(f"Successfully assessed {len(verified_claims)} claims in batch")
return verified_claims
except Exception as e:
logger.error(f"Error in batch assessment: {str(e)}")
# Return all claims as insufficient information
return [
Claim(
text=claim,
confidence=0.0,
assessment="insufficient_information",
supporting_sources=[],
refuting_sources=[],
reasoning=f"Batch assessment failed: {str(e)}"
)
for claim in claims_to_assess
]
async def _search_evidence(self, claim: str) -> List[Dict[str, Any]]:
"""
Search for evidence using Exa.ai API.
Args:
claim: The claim to search evidence for
Returns:
List of source documents with evidence
"""
if not self.exa_api_key:
raise Exception("Exa API key not available. Cannot search for evidence without Exa.ai access.")
try:
headers = {
'x-api-key': self.exa_api_key,
'Content-Type': 'application/json'
}
payload = {
'query': claim,
'numResults': 5,
'text': True,
'useAutoprompt': True
}
response = requests.post(
'https://api.exa.ai/search',
headers=headers,
json=payload,
timeout=15
)
if response.status_code == 200:
data = response.json()
results = data.get('results', [])
if not results:
raise Exception(f"No search results found for claim: {claim}")
sources = []
for result in results:
source = {
'title': result.get('title', 'Untitled'),
'url': result.get('url', ''),
'text': result.get('text', ''),
'publishedDate': result.get('publishedDate', ''),
'author': result.get('author', ''),
'score': result.get('score', 0.5)
}
sources.append(source)
logger.info(f"Found {len(sources)} sources for claim: {claim[:50]}...")
return sources
else:
raise Exception(f"Exa API error: {response.status_code} - {response.text}")
except Exception as e:
logger.error(f"Error searching evidence with Exa: {str(e)}")
raise Exception(f"Failed to search evidence: {str(e)}")
async def _assess_claim_against_sources(self, claim: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
"""
Assess whether sources support or refute the claim using LLM.
Args:
claim: The claim to assess
sources: List of source documents
Returns:
Dictionary with assessment results
"""
if not self.gemini_client:
raise Exception("Gemini client not available. Cannot assess claims without AI provider.")
try:
combined_sources = "\n\n".join([
f"Source {i+1}: {src.get('url','')}\nText: {src.get('text','')[:2000]}"
for i, src in enumerate(sources)
])
prompt = (
"You are a strict fact-checker. Analyze the claim against the provided sources.\n\n"
"Return ONLY a valid JSON object with this exact structure:\n"
"{\n"
' "assessment": "supported" or "refuted" or "insufficient_information",\n'
' "confidence": number between 0.0 and 1.0,\n'
' "supporting_sources": [array of source indices that support the claim],\n'
' "refuting_sources": [array of source indices that refute the claim],\n'
' "reasoning": "brief explanation of your assessment"\n'
"}\n\n"
f"Claim to verify: {claim}\n\n"
f"Sources:\n{combined_sources}\n\n"
"Return only the JSON object:"
)
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
resp = await loop.run_in_executor(executor, lambda: self.gemini_client.models.generate_content(
model="gemini-1.5-flash",
contents=prompt
))
if not resp or not resp.text:
raise Exception("Empty response from Gemini API for claim assessment")
result_text = resp.text.strip()
logger.info(f"Raw Gemini response for assessment: {result_text[:200]}...")
# Try to extract JSON from the response
try:
result = json.loads(result_text)
except json.JSONDecodeError:
# Try to find JSON object in the response (handle markdown code blocks)
import re
# First try to extract from markdown code blocks
code_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', result_text, re.DOTALL)
if code_block_match:
result = json.loads(code_block_match.group(1))
else:
# Try to find JSON object directly
json_match = re.search(r'\{.*?\}', result_text, re.DOTALL)
if json_match:
result = json.loads(json_match.group())
else:
raise Exception(f"Could not parse JSON from Gemini response: {result_text[:100]}")
# Validate required fields
required_fields = ['assessment', 'confidence', 'supporting_sources', 'refuting_sources', 'reasoning']
for field in required_fields:
if field not in result:
raise Exception(f"Missing required field '{field}' in assessment response")
# Process supporting and refuting sources
supporting_sources = []
refuting_sources = []
if isinstance(result.get('supporting_sources'), list):
for idx in result['supporting_sources']:
if isinstance(idx, int) and 0 <= idx < len(sources):
supporting_sources.append(sources[idx])
if isinstance(result.get('refuting_sources'), list):
for idx in result['refuting_sources']:
if isinstance(idx, int) and 0 <= idx < len(sources):
refuting_sources.append(sources[idx])
# Validate assessment value
valid_assessments = ['supported', 'refuted', 'insufficient_information']
if result['assessment'] not in valid_assessments:
raise Exception(f"Invalid assessment value: {result['assessment']}")
# Validate confidence value
confidence = float(result['confidence'])
if not (0.0 <= confidence <= 1.0):
raise Exception(f"Invalid confidence value: {confidence}")
logger.info(f"Successfully assessed claim: {result['assessment']} (confidence: {confidence})")
return {
'assessment': result['assessment'],
'confidence': confidence,
'supporting_sources': supporting_sources,
'refuting_sources': refuting_sources,
'reasoning': result['reasoning']
}
except Exception as e:
logger.error(f"Error assessing claim against sources: {str(e)}")
raise Exception(f"Failed to assess claim: {str(e)}")

View File

@@ -355,7 +355,38 @@ class ContentGenerator:
except Exception as e:
logger.error(f"Error generating grounded post content: {str(e)}")
raise Exception(f"Failed to generate grounded post content: {str(e)}")
logger.info("Attempting fallback to standard content generation...")
# Fallback to standard content generation without grounding
try:
if not self.fallback_provider:
raise Exception("No fallback provider available")
# Build a simpler prompt for fallback generation
prompt = PostPromptBuilder.build_post_prompt(request)
# Generate content using fallback provider (it's a dict with functions)
if 'generate_text' in self.fallback_provider:
result = await self.fallback_provider['generate_text'](
prompt=prompt,
temperature=0.7,
max_tokens=request.max_length
)
else:
raise Exception("Fallback provider doesn't have generate_text method")
# Return result in the expected format
return {
'content': result.get('content', '') if isinstance(result, dict) else str(result),
'sources': [],
'citations': [],
'grounding_enabled': False,
'fallback_used': True
}
except Exception as fallback_error:
logger.error(f"Fallback generation also failed: {str(fallback_error)}")
raise Exception(f"Failed to generate content: {str(e)}. Fallback also failed: {str(fallback_error)}")
async def generate_grounded_article_content(self, request, research_sources: List) -> Dict[str, Any]:
"""Generate grounded article content using the enhanced Gemini provider with native grounding."""

View File

@@ -41,8 +41,9 @@ class GeminiGroundedProvider:
if not self.api_key:
raise ValueError("GEMINI_API_KEY environment variable is required")
# Initialize the Gemini client
# Initialize the Gemini client with timeout configuration
self.client = genai.Client(api_key=self.api_key)
self.timeout = 30 # 30 second timeout for API calls
logger.info("✅ Gemini Grounded Provider initialized with native Google Search grounding")
async def generate_grounded_content(
@@ -82,12 +83,27 @@ class GeminiGroundedProvider:
temperature=temperature
)
# Make the request with native grounding
response = self.client.models.generate_content(
model="gemini-2.5-flash",
contents=grounded_prompt,
config=config,
)
# Make the request with native grounding and timeout
import asyncio
import concurrent.futures
try:
# Run the synchronous generate_content in a thread pool to make it awaitable
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
response = await asyncio.wait_for(
loop.run_in_executor(
executor,
lambda: self.client.models.generate_content(
model="gemini-2.5-flash",
contents=grounded_prompt,
config=config,
)
),
timeout=self.timeout
)
except asyncio.TimeoutError:
raise Exception(f"Gemini API request timed out after {self.timeout} seconds")
# Process the grounded response
result = self._process_grounded_response(response, content_type)

View File

@@ -0,0 +1,201 @@
import os
import asyncio
import concurrent.futures
from typing import Any, Dict, List
from dataclasses import dataclass
import requests
from loguru import logger
try:
from google import genai
GOOGLE_GENAI_AVAILABLE = True
except Exception:
GOOGLE_GENAI_AVAILABLE = False
@dataclass
class WritingSuggestion:
text: str
confidence: float
sources: List[Dict[str, Any]]
class WritingAssistantService:
"""
Minimal writing assistant that combines Exa search with Gemini continuation.
- Exa provides relevant sources with content snippets
- Gemini generates a short, cited continuation based on current text and sources
"""
def __init__(self) -> None:
self.exa_api_key = os.getenv("EXA_API_KEY")
self.gemini_api_key = os.getenv("GEMINI_API_KEY")
if not self.exa_api_key:
logger.warning("EXA_API_KEY not configured; writing assistant will fail")
if not (GOOGLE_GENAI_AVAILABLE and self.gemini_api_key):
logger.warning("Gemini not available; writing assistant will fail")
self.gemini_client = None
else:
self.gemini_client = genai.Client(api_key=self.gemini_api_key)
self.http_timeout_seconds = 15
# COST CONTROL: Daily usage limits
self.daily_api_calls = 0
self.daily_limit = 50 # Max 50 API calls per day (~$2.50 max cost)
self.last_reset_date = None
def _get_cached_suggestion(self, text: str) -> WritingSuggestion | None:
"""No cached suggestions - always use real API calls for authentic results."""
return None
def _check_daily_limit(self) -> bool:
"""Check if we're within daily API usage limits."""
import datetime
today = datetime.date.today()
# Reset counter if it's a new day
if self.last_reset_date != today:
self.daily_api_calls = 0
self.last_reset_date = today
# Check if we've exceeded the limit
if self.daily_api_calls >= self.daily_limit:
return False
# Increment counter for this API call
self.daily_api_calls += 1
logger.info(f"Writing assistant API call #{self.daily_api_calls}/{self.daily_limit} today")
return True
async def suggest(self, text: str, max_results: int = 1) -> List[WritingSuggestion]:
if not text or len(text.strip()) < 6:
return []
# COST OPTIMIZATION: Use cached/static suggestions for common patterns
# This reduces API calls by 90%+ while maintaining usefulness
cached_suggestion = self._get_cached_suggestion(text)
if cached_suggestion:
return [cached_suggestion]
# COST CONTROL: Check daily usage limits
if not self._check_daily_limit():
logger.warning("Daily API limit reached for writing assistant")
return []
# Only make expensive API calls for unique, substantial content
if len(text.strip()) < 50: # Skip API calls for very short text
return []
# 1) Find relevant sources via Exa (reduced results for cost)
sources = await self._search_sources(text)
# 2) Generate continuation suggestion via Gemini
suggestion_text, confidence = await self._generate_continuation(text, sources)
if not suggestion_text:
return []
return [WritingSuggestion(text=suggestion_text.strip(), confidence=confidence, sources=sources)]
async def _search_sources(self, text: str) -> List[Dict[str, Any]]:
if not self.exa_api_key:
raise Exception("EXA_API_KEY not configured")
# Follow Exa demo guidance: continuation-style prompt and 1000-char cap
exa_query = (
(text[-1000:] if len(text) > 1000 else text)
+ "\n\nIf you found the above interesting, here's another useful resource to read:"
)
payload = {
"query": exa_query,
"numResults": 3, # Reduced from 5 to 3 for cost savings
"text": True,
"type": "neural",
"highlights": {"numSentences": 1, "highlightsPerUrl": 1},
}
try:
resp = requests.post(
"https://api.exa.ai/search",
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
json=payload,
timeout=self.http_timeout_seconds,
)
if resp.status_code != 200:
raise Exception(f"Exa error {resp.status_code}: {resp.text}")
data = resp.json()
results = data.get("results", [])
sources: List[Dict[str, Any]] = []
for r in results:
sources.append(
{
"title": r.get("title", "Untitled"),
"url": r.get("url", ""),
"text": r.get("text", ""),
"author": r.get("author", ""),
"published_date": r.get("publishedDate", ""),
"score": float(r.get("score", 0.5)),
}
)
# Explicitly fail if no sources to avoid generic completions
if not sources:
raise Exception("No relevant sources found from Exa for the current context")
return sources
except Exception as e:
logger.error(f"WritingAssistant _search_sources error: {e}")
raise
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]]) -> tuple[str, float]:
if not self.gemini_client:
raise Exception("Gemini client not available")
# Build compact sources context block
source_blocks: List[str] = []
for i, s in enumerate(sources[:5]):
excerpt = (s.get("text", "") or "")
excerpt = excerpt[:500]
source_blocks.append(
f"Source {i+1}: {s.get('title','') or 'Source'}\nURL: {s.get('url','')}\nExcerpt: {excerpt}"
)
sources_text = "\n\n".join(source_blocks) if source_blocks else "(No sources)"
# Based on Exa demo guidance for completion-only behavior and inline citations
system_prompt = (
"You are an essay-completion bot that completes a sentence or continues prose. "
"Only produce 1-2 SHORT sentences. Do not repeat or paraphrase the user's stub. "
"Continue in the same tone and topic as the stub. Prefer concrete, current facts from the provided sources. "
"Include exactly one brief, verifiable citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
)
user_prompt = (
f"User text to continue (do not repeat):\n{text}\n\n"
f"Relevant sources to inform your continuation:\n{sources_text}\n\n"
"Return only the continuation text, without quotes."
)
try:
loop = asyncio.get_event_loop()
with concurrent.futures.ThreadPoolExecutor() as executor:
resp = await loop.run_in_executor(
executor,
lambda: self.gemini_client.models.generate_content(
model="gemini-1.5-flash", contents=f"{system_prompt}\n\n{user_prompt}"
),
)
suggestion = (resp.text or "").strip()
if not suggestion:
raise Exception("Gemini returned empty suggestion")
# naive confidence from number of sources present
confidence = 0.7 if sources else 0.5
return suggestion, confidence
except Exception as e:
logger.error(f"WritingAssistant _generate_continuation error: {e}")
# Propagate to ensure frontend does not show stale/generic content
raise