ALwrity HALLUCINATION DETECTOR AND ASSISTIVE WRITING
This commit is contained in:
351
backend/api/hallucination_detector.py
Normal file
351
backend/api/hallucination_detector.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Hallucination Detector API endpoints.
|
||||
|
||||
Provides REST API endpoints for fact-checking and hallucination detection
|
||||
using Exa.ai integration, similar to the Exa.ai demo implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from models.hallucination_models import (
|
||||
HallucinationDetectionRequest,
|
||||
HallucinationDetectionResponse,
|
||||
ClaimExtractionRequest,
|
||||
ClaimExtractionResponse,
|
||||
ClaimVerificationRequest,
|
||||
ClaimVerificationResponse,
|
||||
HealthCheckResponse,
|
||||
Claim,
|
||||
SourceDocument,
|
||||
AssessmentType
|
||||
)
|
||||
from services.hallucination_detector import HallucinationDetector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router
|
||||
router = APIRouter(prefix="/api/hallucination-detector", tags=["Hallucination Detector"])
|
||||
|
||||
# Initialize detector service
|
||||
detector = HallucinationDetector()
|
||||
|
||||
@router.post("/detect", response_model=HallucinationDetectionResponse)
async def detect_hallucinations(request: HallucinationDetectionRequest) -> HallucinationDetectionResponse:
    """
    Detect hallucinations in the provided text.

    This endpoint implements the complete hallucination detection pipeline:
    1. Extract verifiable claims from the text
    2. Search for evidence using Exa.ai
    3. Verify each claim against the found sources

    Args:
        request: HallucinationDetectionRequest with text to analyze

    Returns:
        HallucinationDetectionResponse with analysis results
    """

    def as_documents(raw_sources):
        # The service hands back plain dicts (note the camelCase `publishedDate`
        # key); map them onto the API's SourceDocument model.
        return [
            SourceDocument(
                title=entry.get('title', 'Untitled'),
                url=entry.get('url', ''),
                text=entry.get('text', ''),
                published_date=entry.get('publishedDate'),
                author=entry.get('author'),
                score=entry.get('score', 0.5)
            )
            for entry in raw_sources
        ]

    start_time = time.time()

    try:
        logger.info(f"Starting hallucination detection for text of length: {len(request.text)}")

        # Run the full extract / search / verify pipeline in the service layer.
        outcome = await detector.detect_hallucinations(request.text)

        api_claims = []
        for item in outcome.claims:
            # Sources are always converted (and therefore validated), even when
            # the caller asked for them to be left out of the response.
            supporting = as_documents(item.supporting_sources)
            refuting = as_documents(item.refuting_sources)

            api_claims.append(
                Claim(
                    text=item.text,
                    confidence=item.confidence,
                    assessment=AssessmentType(item.assessment),
                    supporting_sources=supporting if request.include_sources else [],
                    refuting_sources=refuting if request.include_sources else [],
                    reasoning=getattr(item, 'reasoning', None)
                )
            )

        elapsed_ms = int((time.time() - start_time) * 1000)

        payload = HallucinationDetectionResponse(
            success=True,
            claims=api_claims,
            overall_confidence=outcome.overall_confidence,
            total_claims=outcome.total_claims,
            supported_claims=outcome.supported_claims,
            refuted_claims=outcome.refuted_claims,
            insufficient_claims=outcome.insufficient_claims,
            timestamp=outcome.timestamp,
            processing_time_ms=elapsed_ms
        )

        logger.info(f"Hallucination detection completed successfully. Processing time: {elapsed_ms}ms")
        return payload

    except Exception as e:
        logger.error(f"Error in hallucination detection: {str(e)}")
        elapsed_ms = int((time.time() - start_time) * 1000)

        # Failures bypass the response model and are reported as a raw HTTP 500 body.
        failure_body = {
            "success": False,
            "error": str(e),
            "message": "Hallucination detection failed. Please check API keys and try again.",
            "timestamp": time.strftime('%Y-%m-%dT%H:%M:%S'),
            "processing_time_ms": elapsed_ms
        }
        return JSONResponse(status_code=500, content=failure_body)
|
||||
|
||||
@router.post("/extract-claims", response_model=ClaimExtractionResponse)
async def extract_claims(request: ClaimExtractionRequest) -> ClaimExtractionResponse:
    """
    Extract verifiable claims from the provided text.

    This endpoint performs only the claim extraction step of the
    hallucination detection pipeline.

    Args:
        request: ClaimExtractionRequest with text to analyze

    Returns:
        ClaimExtractionResponse with extracted claims
    """
    try:
        logger.info(f"Extracting claims from text of length: {len(request.text)}")

        # Delegate to the service's extraction step.
        extracted = await detector._extract_claims(request.text)

        # Keep only the first `max_claims` entries when a cap was requested.
        cap = request.max_claims
        if cap and len(extracted) > cap:
            extracted = extracted[:cap]

        result = ClaimExtractionResponse(
            success=True,
            claims=extracted,
            total_claims=len(extracted),
            timestamp=time.strftime('%Y-%m-%dT%H:%M:%S')
        )

        logger.info(f"Claim extraction completed. Extracted {len(extracted)} claims")
        return result

    except Exception as e:
        logger.error(f"Error in claim extraction: {str(e)}")

        # Errors are reported inside a regular response model (success=False),
        # not as an HTTP error status.
        return ClaimExtractionResponse(
            success=False,
            claims=[],
            total_claims=0,
            timestamp=time.strftime('%Y-%m-%dT%H:%M:%S'),
            error=str(e)
        )
|
||||
|
||||
@router.post("/verify-claim", response_model=ClaimVerificationResponse)
async def verify_claim(request: ClaimVerificationRequest) -> ClaimVerificationResponse:
    """
    Verify a single claim against available sources.

    This endpoint performs claim verification using Exa.ai search
    and LLM-based assessment.

    Args:
        request: ClaimVerificationRequest with claim to verify

    Returns:
        ClaimVerificationResponse with verification results
    """

    def as_documents(raw_sources):
        # Translate the service's source dicts into API SourceDocument models.
        return [
            SourceDocument(
                title=entry.get('title', 'Untitled'),
                url=entry.get('url', ''),
                text=entry.get('text', ''),
                published_date=entry.get('publishedDate'),
                author=entry.get('author'),
                score=entry.get('score', 0.5)
            )
            for entry in raw_sources
        ]

    start_time = time.time()

    try:
        logger.info(f"Verifying claim: {request.claim[:100]}...")

        verdict = await detector._verify_claim(request.claim)

        # Sources are only converted when the caller wants them back.
        if request.include_sources:
            supporting = as_documents(verdict.supporting_sources)
            refuting = as_documents(verdict.refuting_sources)
        else:
            supporting = []
            refuting = []

        verified = Claim(
            text=verdict.text,
            confidence=verdict.confidence,
            assessment=AssessmentType(verdict.assessment),
            supporting_sources=supporting,
            refuting_sources=refuting,
            reasoning=getattr(verdict, 'reasoning', None)
        )

        elapsed_ms = int((time.time() - start_time) * 1000)

        result = ClaimVerificationResponse(
            success=True,
            claim=verified,
            timestamp=time.strftime('%Y-%m-%dT%H:%M:%S'),
            processing_time_ms=elapsed_ms
        )

        logger.info(f"Claim verification completed. Assessment: {verdict.assessment}")
        return result

    except Exception as e:
        logger.error(f"Error in claim verification: {str(e)}")
        elapsed_ms = int((time.time() - start_time) * 1000)

        # On failure, echo the claim back as "insufficient information" with
        # zero confidence and surface the error text alongside it.
        placeholder = Claim(
            text=request.claim,
            confidence=0.0,
            assessment=AssessmentType.INSUFFICIENT_INFORMATION,
            supporting_sources=[],
            refuting_sources=[],
            reasoning="Error during verification"
        )
        return ClaimVerificationResponse(
            success=False,
            claim=placeholder,
            timestamp=time.strftime('%Y-%m-%dT%H:%M:%S'),
            processing_time_ms=elapsed_ms,
            error=str(e)
        )
|
||||
|
||||
@router.get("/health", response_model=HealthCheckResponse)
async def health_check() -> HealthCheckResponse:
    """
    Health check endpoint for the hallucination detector service.

    Returns:
        HealthCheckResponse with service status and API availability
    """
    try:
        # Check API availability
        exa_available = bool(detector.exa_api_key)

        # The detector's visible constructor stores its LLM credential as
        # `gemini_api_key` and defines no `openai_api_key`; reading the latter
        # directly raised AttributeError and made this endpoint always answer
        # "unhealthy". Accept either attribute so the response field (whose
        # name is part of the public schema) reflects LLM-key availability.
        llm_key = getattr(detector, 'openai_api_key', None) or getattr(detector, 'gemini_api_key', None)
        openai_available = bool(llm_key)

        # NOTE(review): the service's detect_hallucinations raises when either
        # key is missing, so "healthy" with only one key may be optimistic;
        # the original `or` is preserved here - confirm the intended meaning.
        status = "healthy" if (exa_available or openai_available) else "degraded"

        response = HealthCheckResponse(
            status=status,
            version="1.0.0",
            exa_api_available=exa_available,
            openai_api_available=openai_available,
            timestamp=time.strftime('%Y-%m-%dT%H:%M:%S')
        )

        return response

    except Exception as e:
        logger.error(f"Error in health check: {str(e)}")

        return HealthCheckResponse(
            status="unhealthy",
            version="1.0.0",
            exa_api_available=False,
            openai_api_available=False,
            timestamp=time.strftime('%Y-%m-%dT%H:%M:%S')
        )
|
||||
|
||||
@router.get("/demo")
async def demo_endpoint() -> Dict[str, Any]:
    """
    Demo endpoint showing example usage of the hallucination detector.

    Returns:
        Dictionary with example request/response data
    """
    # One entry per POST endpoint, each carrying a sample request body.
    detect_example = {
        "method": "POST",
        "path": "/api/hallucination-detector/detect",
        "description": "Detect hallucinations in text using Exa.ai",
        "example_request": {
            "text": "The Eiffel Tower is located in Paris and was built in 1889. It is 330 meters tall.",
            "include_sources": True,
            "max_claims": 5
        }
    }
    extract_example = {
        "method": "POST",
        "path": "/api/hallucination-detector/extract-claims",
        "description": "Extract verifiable claims from text",
        "example_request": {
            "text": "Our company increased sales by 25% last quarter. We launched 3 new products.",
            "max_claims": 10
        }
    }
    verify_example = {
        "method": "POST",
        "path": "/api/hallucination-detector/verify-claim",
        "description": "Verify a single claim against sources",
        "example_request": {
            "claim": "The Eiffel Tower is in Paris",
            "include_sources": True
        }
    }
    feature_list = [
        "Claim extraction using LLM",
        "Evidence search using Exa.ai",
        "Claim verification with confidence scores",
        "Source attribution and credibility assessment",
        "Fallback mechanisms for API unavailability"
    ]

    return {
        "description": "Hallucination Detector API Demo",
        "version": "1.0.0",
        "endpoints": {
            "detect": detect_example,
            "extract_claims": extract_example,
            "verify_claim": verify_example
        },
        "features": feature_list
    }
|
||||
61
backend/api/writing_assistant.py
Normal file
61
backend/api/writing_assistant.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from services.writing_assistant import WritingAssistantService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
|
||||
|
||||
class SuggestRequest(BaseModel):
    # Request body for POST /api/writing-assistant/suggest.

    # Text the suggestion should be based on.
    text: str
    # Requested number of suggestions; the endpoint substitutes 1 for a falsy value.
    max_results: int | None = 1
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
    # A single source attached to a suggestion.

    # Required identification of the source.
    title: str
    url: str

    # Optional details; each defaults to an empty string when omitted.
    text: str | None = ""
    author: str | None = ""
    published_date: str | None = ""

    # Required numeric score for the source.
    score: float
|
||||
|
||||
|
||||
class SuggestionModel(BaseModel):
    # One writing suggestion together with the sources attached to it.

    # Suggested text.
    text: str
    # Confidence value reported by the writing-assistant service.
    confidence: float
    # Sources accompanying the suggestion.
    sources: List[SourceModel]
|
||||
|
||||
|
||||
class SuggestResponse(BaseModel):
    # Response body for POST /api/writing-assistant/suggest.

    # Flag indicating the request was served.
    success: bool
    # Suggestions produced for the submitted text.
    suggestions: List[SuggestionModel]
|
||||
|
||||
|
||||
assistant_service = WritingAssistantService()
|
||||
|
||||
|
||||
@router.post("/suggest", response_model=SuggestResponse)
async def suggest_endpoint(req: SuggestRequest) -> SuggestResponse:
    # Ask the writing-assistant service for suggestions and wrap them in the
    # API models; any failure is logged and reported as HTTP 500.
    try:
        # A missing or zero `max_results` falls back to a single suggestion.
        wanted = req.max_results or 1
        raw_suggestions = await assistant_service.suggest(req.text, wanted)

        wrapped = []
        for item in raw_suggestions:
            # Each source is a mapping whose keys are passed to SourceModel as keyword arguments.
            item_sources = [SourceModel(**src) for src in item.sources]
            wrapped.append(
                SuggestionModel(
                    text=item.text,
                    confidence=item.confidence,
                    sources=item_sources,
                )
            )

        return SuggestResponse(success=True, suggestions=wrapped)
    except Exception as e:
        logger.error(f"Writing assistant error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@@ -57,6 +57,10 @@ from routers.linkedin import router as linkedin_router
|
||||
# Import LinkedIn image generation router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
|
||||
# Import hallucination detector router
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
|
||||
# Import user data endpoints
|
||||
# Import content planning endpoints
|
||||
from api.content_planning.api.router import router as content_planning_router
|
||||
@@ -380,6 +384,10 @@ app.include_router(linkedin_router)
|
||||
# Include LinkedIn image generation router
|
||||
app.include_router(linkedin_image_router)
|
||||
|
||||
# Include hallucination detector router
|
||||
app.include_router(hallucination_detector_router)
|
||||
app.include_router(writing_assistant_router)
|
||||
|
||||
# Include user data router
|
||||
# Include content planning router
|
||||
app.include_router(content_planning_router)
|
||||
|
||||
85
backend/models/hallucination_models.py
Normal file
85
backend/models/hallucination_models.py
Normal file
@@ -0,0 +1,85 @@
|
||||
"""
|
||||
Pydantic models for hallucination detection API endpoints.
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
class AssessmentType(str, Enum):
    """Assessment types for claim verification."""

    # Members are also plain strings (str mixin), so they compare equal to
    # their lowercase wire values.
    SUPPORTED = "supported"
    REFUTED = "refuted"
    INSUFFICIENT_INFORMATION = "insufficient_information"
|
||||
|
||||
class SourceDocument(BaseModel):
    """Represents a source document used for fact-checking."""

    # Required descriptive fields.
    title: str = Field(default=..., description="Title of the source document")
    url: str = Field(default=..., description="URL of the source document")
    text: str = Field(default=..., description="Relevant text content from the source")

    # Optional metadata; stays None when not supplied.
    published_date: Optional[str] = Field(default=None, description="Publication date of the source")
    author: Optional[str] = Field(default=None, description="Author of the source")

    # Defaults to 0.5 when the caller provides no score.
    score: float = Field(default=0.5, description="Relevance score of the source (0.0-1.0)")
|
||||
|
||||
class Claim(BaseModel):
    """Represents a single verifiable claim extracted from text."""

    # Required: the statement, its score and its verdict.
    text: str = Field(default=..., description="The claim text")
    confidence: float = Field(
        default=...,
        description="Confidence score for the claim assessment (0.0-1.0)",
    )
    assessment: AssessmentType = Field(default=..., description="Assessment result for the claim")

    # Evidence lists; each instance gets its own fresh empty list by default.
    supporting_sources: List[SourceDocument] = Field(
        default_factory=list,
        description="Sources that support the claim",
    )
    refuting_sources: List[SourceDocument] = Field(
        default_factory=list,
        description="Sources that refute the claim",
    )

    # Optional free-text explanation.
    reasoning: Optional[str] = Field(default=None, description="Explanation for the assessment")
|
||||
|
||||
class HallucinationDetectionRequest(BaseModel):
    """Request model for hallucination detection."""

    # Input text, constrained to 10-5000 characters.
    text: str = Field(
        default=...,
        description="Text to analyze for factual accuracy",
        min_length=10,
        max_length=5000,
    )
    # Sources are included in the response unless the caller opts out.
    include_sources: bool = Field(
        default=True,
        description="Whether to include source documents in the response",
    )
    # Claim cap, constrained to 1-20 with a default of 10.
    max_claims: int = Field(
        default=10,
        description="Maximum number of claims to extract and verify",
        ge=1,
        le=20,
    )
|
||||
|
||||
class HallucinationDetectionResponse(BaseModel):
    """Response model for hallucination detection."""

    # Outcome flag and the per-claim results.
    success: bool = Field(default=..., description="Whether the analysis was successful")
    claims: List[Claim] = Field(
        default_factory=list,
        description="List of extracted and verified claims",
    )

    # Aggregate figures for the whole analysis.
    overall_confidence: float = Field(
        default=...,
        description="Overall confidence score for the analysis (0.0-1.0)",
    )
    total_claims: int = Field(default=..., description="Total number of claims extracted")
    supported_claims: int = Field(
        default=...,
        description="Number of claims that are supported by sources",
    )
    refuted_claims: int = Field(
        default=...,
        description="Number of claims that are refuted by sources",
    )
    insufficient_claims: int = Field(
        default=...,
        description="Number of claims with insufficient information",
    )

    # Bookkeeping: when the analysis ran, how long it took, and any error text.
    timestamp: str = Field(default=..., description="Timestamp of the analysis")
    processing_time_ms: Optional[int] = Field(
        default=None,
        description="Processing time in milliseconds",
    )
    error: Optional[str] = Field(default=None, description="Error message if analysis failed")
|
||||
|
||||
class ClaimExtractionRequest(BaseModel):
    """Request model for claim extraction only."""

    # Input text, constrained to 10-5000 characters.
    text: str = Field(
        default=...,
        description="Text to extract claims from",
        min_length=10,
        max_length=5000,
    )
    # Claim cap, constrained to 1-20 with a default of 10.
    max_claims: int = Field(
        default=10,
        description="Maximum number of claims to extract",
        ge=1,
        le=20,
    )
|
||||
|
||||
class ClaimExtractionResponse(BaseModel):
    """Response model for claim extraction."""

    # Outcome flag plus the extracted claim strings and their count.
    success: bool = Field(default=..., description="Whether the extraction was successful")
    claims: List[str] = Field(
        default_factory=list,
        description="List of extracted claim texts",
    )
    total_claims: int = Field(default=..., description="Total number of claims extracted")

    # Bookkeeping: when the extraction ran and any error text.
    timestamp: str = Field(default=..., description="Timestamp of the extraction")
    error: Optional[str] = Field(default=None, description="Error message if extraction failed")
|
||||
|
||||
class ClaimVerificationRequest(BaseModel):
    """Request model for verifying a single claim."""

    # The claim itself, constrained to 5-500 characters.
    claim: str = Field(
        default=...,
        description="Claim to verify",
        min_length=5,
        max_length=500,
    )
    # Sources are included in the response unless the caller opts out.
    include_sources: bool = Field(
        default=True,
        description="Whether to include source documents in the response",
    )
|
||||
|
||||
class ClaimVerificationResponse(BaseModel):
    """Response model for claim verification."""

    # Outcome flag and the assessed claim.
    success: bool = Field(default=..., description="Whether the verification was successful")
    claim: Claim = Field(default=..., description="Verified claim with assessment results")

    # Bookkeeping: when the check ran, how long it took, and any error text.
    timestamp: str = Field(default=..., description="Timestamp of the verification")
    processing_time_ms: Optional[int] = Field(
        default=None,
        description="Processing time in milliseconds",
    )
    error: Optional[str] = Field(default=None, description="Error message if verification failed")
|
||||
|
||||
class HealthCheckResponse(BaseModel):
    """Response model for health check."""

    # Overall state and service version string.
    status: str = Field(default=..., description="Service status")
    version: str = Field(default=..., description="Service version")

    # Per-provider availability flags.
    exa_api_available: bool = Field(default=..., description="Whether Exa API is available")
    openai_api_available: bool = Field(default=..., description="Whether OpenAI API is available")

    # When the check was performed.
    timestamp: str = Field(default=..., description="Timestamp of the health check")
|
||||
702
backend/services/hallucination_detector.py
Normal file
702
backend/services/hallucination_detector.py
Normal file
@@ -0,0 +1,702 @@
|
||||
"""
|
||||
Hallucination Detector Service
|
||||
|
||||
This service implements fact-checking functionality using Exa.ai API
|
||||
to detect and verify claims in AI-generated content, similar to the
|
||||
Exa.ai demo implementation.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import requests
|
||||
import os
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
try:
|
||||
from google import genai
|
||||
GOOGLE_GENAI_AVAILABLE = True
|
||||
except Exception:
|
||||
GOOGLE_GENAI_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@dataclass
class Claim:
    """A statement taken from the analysed text plus the outcome of checking it."""

    # The statement being checked.
    text: str
    # Score attached to the assessment.
    confidence: float
    # Verdict label: "supported", "refuted" or "insufficient_information".
    assessment: str
    # Source dicts recorded as backing the claim.
    supporting_sources: List[Dict[str, Any]]
    # Source dicts recorded as contradicting the claim.
    refuting_sources: List[Dict[str, Any]]
    # Free-text explanation; empty when none was given.
    reasoning: str = ""
|
||||
|
||||
@dataclass
class HallucinationResult:
    """Aggregate outcome of one hallucination-detection run."""

    # Per-claim verification results.
    claims: List[Claim]
    # Single confidence figure for the whole run.
    overall_confidence: float
    # Number of claims in the result.
    total_claims: int
    # Counts per verdict label.
    supported_claims: int
    refuted_claims: int
    insufficient_claims: int
    # Creation time of the result as a string.
    timestamp: str
|
||||
|
||||
class HallucinationDetector:
|
||||
"""
|
||||
Hallucination detector using Exa.ai for fact-checking.
|
||||
|
||||
Implements the three-step process from Exa.ai demo:
|
||||
1. Extract verifiable claims from text
|
||||
2. Search for evidence using Exa.ai
|
||||
3. Verify claims against sources
|
||||
"""
|
||||
|
||||
def __init__(self):
    """Read API keys from the environment, build the Gemini client, and reset the daily call counter."""
    self.exa_api_key = os.getenv('EXA_API_KEY')
    self.gemini_api_key = os.getenv('GEMINI_API_KEY')

    # Missing keys only produce warnings here; construction never fails on them.
    if not self.exa_api_key:
        logger.warning("EXA_API_KEY not found. Hallucination detection will be limited.")
    if not self.gemini_api_key:
        logger.warning("GEMINI_API_KEY not found. Falling back to heuristic claim extraction.")

    # The Gemini client (used for claim extraction and assessment) is created
    # only when the SDK import succeeded and a key is present.
    if GOOGLE_GENAI_AVAILABLE and self.gemini_api_key:
        self.gemini_client = genai.Client(api_key=self.gemini_api_key)
    else:
        self.gemini_client = None

    # In-memory daily quota: at most 20 fact-check calls per day.
    self.daily_api_calls = 0
    self.daily_limit = 20
    self.last_reset_date = None
|
||||
|
||||
def _check_rate_limit(self) -> bool:
    """Consume one unit of the daily quota; return False when the quota is already used up."""
    from datetime import date

    current_day = date.today()

    # First call on a new calendar day starts a fresh count.
    if self.last_reset_date != current_day:
        self.last_reset_date = current_day
        self.daily_api_calls = 0

    if self.daily_api_calls < self.daily_limit:
        # Still within budget: record this call and allow it.
        self.daily_api_calls += 1
        logger.info(f"Fact check API call #{self.daily_api_calls}/{self.daily_limit} today")
        return True

    logger.warning(f"Daily API limit reached ({self.daily_limit} calls). Fact checking disabled for today.")
    return False
|
||||
|
||||
async def detect_hallucinations(self, text: str) -> HallucinationResult:
    """
    Main method to detect hallucinations in the given text.

    Extracts claims from the text, verifies them in one batch, and
    aggregates per-verdict counts plus the mean claim confidence.

    Args:
        text: The text to analyze for factual accuracy

    Returns:
        HallucinationResult with claims analysis and confidence scores.
        An empty result (no claims, zero confidence) is returned when the
        daily quota is exhausted or when no verifiable claims are found.

    Raises:
        Exception: if an API key is missing or any pipeline step fails;
            the original error is attached as the cause.
    """

    def empty_result() -> HallucinationResult:
        # Shared shape for "nothing was analysed" outcomes.
        return HallucinationResult(
            claims=[],
            overall_confidence=0.0,
            total_claims=0,
            supported_claims=0,
            refuted_claims=0,
            insufficient_claims=0,
            timestamp=datetime.now().isoformat()
        )

    try:
        logger.info(f"Starting hallucination detection for text of length: {len(text)}")
        logger.info(f"Text sample: {text[:200]}...")

        # Validate configuration BEFORE touching the quota: a misconfigured
        # service should fail the same way on every call instead of burning
        # the daily allowance and then silently returning empty results.
        if not self.gemini_api_key:
            raise Exception("GEMINI_API_KEY not configured. Cannot perform hallucination detection.")
        if not self.exa_api_key:
            raise Exception("EXA_API_KEY not configured. Cannot search for evidence.")

        # Check rate limits
        if not self._check_rate_limit():
            return empty_result()

        # Step 1: Extract claims from text
        claims_texts = await self._extract_claims(text)
        logger.info(f"Extracted {len(claims_texts)} claims from text: {claims_texts}")

        if not claims_texts:
            logger.warning("No verifiable claims found in text")
            return empty_result()

        # Step 2 & 3: Verify claims in batch to reduce API calls
        verified_claims = await self._verify_claims_batch(claims_texts)

        # Calculate overall metrics
        total_claims = len(verified_claims)
        supported_claims = sum(1 for c in verified_claims if c.assessment == "supported")
        refuted_claims = sum(1 for c in verified_claims if c.assessment == "refuted")
        insufficient_claims = sum(1 for c in verified_claims if c.assessment == "insufficient_information")

        # Overall confidence is the plain (unweighted) mean of claim confidences.
        if total_claims > 0:
            overall_confidence = sum(c.confidence for c in verified_claims) / total_claims
        else:
            overall_confidence = 0.0

        result = HallucinationResult(
            claims=verified_claims,
            overall_confidence=overall_confidence,
            total_claims=total_claims,
            supported_claims=supported_claims,
            refuted_claims=refuted_claims,
            insufficient_claims=insufficient_claims,
            timestamp=datetime.now().isoformat()
        )

        logger.info(f"Hallucination detection completed. Overall confidence: {overall_confidence:.2f}")
        return result

    except Exception as e:
        logger.error(f"Error in hallucination detection: {str(e)}")
        # Chain the cause so the original traceback is not lost.
        raise Exception(f"Hallucination detection failed: {str(e)}") from e
|
||||
|
||||
async def _extract_claims(self, text: str) -> List[str]:
    """
    Extract verifiable claims from text using LLM.

    Sends a fixed prompt to Gemini ("gemini-1.5-flash") and parses the
    reply as a JSON array of strings. If the reply is not pure JSON, the
    array is recovered from a Markdown code fence or from the first
    position in the reply where a valid JSON array starts.

    Args:
        text: Input text to extract claims from

    Returns:
        List of claim strings (non-string and blank entries are dropped)

    Raises:
        Exception: if no Gemini client is configured, the reply is empty,
            or no JSON array can be recovered from it.
    """
    if not self.gemini_client:
        raise Exception("Gemini client not available. Cannot extract claims without AI provider.")

    try:
        prompt = (
            "Extract verifiable factual claims from the following text. "
            "A verifiable claim is a statement that can be checked against external sources for accuracy.\n\n"
            "Return ONLY a valid JSON array of strings, where each string is a single verifiable claim.\n\n"
            "Examples of GOOD verifiable claims:\n"
            "- \"The company was founded in 2020\"\n"
            "- \"Sales increased by 25% last quarter\"\n"
            "- \"The product has 10,000 users\"\n"
            "- \"The market size is $50 billion\"\n"
            "- \"The software supports 15 languages\"\n"
            "- \"The company has offices in 5 countries\"\n\n"
            "Examples of BAD claims (opinions, subjective statements):\n"
            "- \"This is the best product\"\n"
            "- \"Customers love our service\"\n"
            "- \"We are innovative\"\n"
            "- \"The future looks bright\"\n\n"
            "IMPORTANT: Extract at least 2-3 verifiable claims if possible. "
            "Look for specific facts, numbers, dates, locations, and measurable statements.\n\n"
            f"Text to analyze: {text}\n\n"
            "Return only the JSON array of verifiable claims:"
        )

        # The Gemini call is blocking, so run it on the loop's default
        # executor. get_running_loop() is the supported call inside a
        # coroutine, and the default executor avoids building and tearing
        # down a thread pool on every request.
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None,
            lambda: self.gemini_client.models.generate_content(
                model="gemini-1.5-flash",
                contents=prompt
            )
        )

        if not resp or not resp.text:
            raise Exception("Empty response from Gemini API")

        claims_text = resp.text.strip()
        logger.info(f"Raw Gemini response for claims: {claims_text[:200]}...")

        # Try to extract JSON from the response
        try:
            claims = json.loads(claims_text)
        except json.JSONDecodeError:
            import re

            # Prefer the contents of a markdown code block when there is one.
            code_block_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', claims_text, re.DOTALL)
            haystack = code_block_match.group(1) if code_block_match else claims_text

            # Decode from each '[' in turn until one starts a valid JSON
            # array. Unlike the regex r'\[.*?\]', this is not cut short by a
            # ']' inside a claim string (e.g. a citation marker like "[1]").
            decoder = json.JSONDecoder()
            claims = None
            start = haystack.find('[')
            while start != -1:
                try:
                    claims, _ = decoder.raw_decode(haystack, start)
                    break
                except json.JSONDecodeError:
                    start = haystack.find('[', start + 1)

            if claims is None:
                raise Exception(f"Could not parse JSON from Gemini response: {claims_text[:100]}")

        if isinstance(claims, list):
            valid_claims = [claim for claim in claims if isinstance(claim, str) and claim.strip()]
            logger.info(f"Successfully extracted {len(valid_claims)} claims")
            return valid_claims
        else:
            raise Exception(f"Expected JSON array, got: {type(claims)}")

    except Exception as e:
        logger.error(f"Error extracting claims: {str(e)}")
        # Chain the cause so the original traceback is not lost.
        raise Exception(f"Failed to extract claims: {str(e)}") from e
|
||||
|
||||
|
||||
async def _verify_claims_batch(self, claims: List[str]) -> List[Claim]:
    """
    Check several claims with one evidence search and one assessment call.

    Only the first three claims are actually verified; the rest are
    returned as "insufficient_information" with zero confidence. If any
    step raises, every claim is returned that way with the error text as
    its reasoning.

    Args:
        claims: Claim strings to verify

    Returns:
        Claim objects, one per entry the batch produced plus one per
        unverified leftover
    """

    def unverified(claim_text: str, why: str) -> Claim:
        # Placeholder verdict for a claim that was not (or could not be) checked.
        return Claim(
            text=claim_text,
            confidence=0.0,
            assessment="insufficient_information",
            supporting_sources=[],
            refuting_sources=[],
            reasoning=why
        )

    try:
        logger.info(f"Starting batch verification of {len(claims)} claims")

        # Cap the work at three claims to keep API usage down.
        selected = claims[:3]
        leftover = claims[3:]

        if leftover:
            logger.warning(f"Limited verification to {len(selected)} claims to prevent API rate limits")

        # One search for the whole batch, then one assessment over its results.
        evidence = await self._search_evidence_batch(selected)
        results = await self._assess_claims_batch(selected, evidence)

        # Claims beyond the cap are reported, but marked as not verified.
        results.extend(
            unverified(extra, "Not verified due to API rate limit protection")
            for extra in leftover
        )

        logger.info(f"Batch verification completed for {len(results)} claims")
        return results

    except Exception as e:
        logger.error(f"Error in batch verification: {str(e)}")
        # Degrade gracefully: every claim comes back as unverified.
        return [unverified(entry, f"Batch verification failed: {str(e)}") for entry in claims]
|
||||
|
||||
async def _verify_claim(self, claim: str) -> Claim:
    """
    Verify one claim: search for evidence, then assess it with the LLM.

    Args:
        claim: The claim to verify

    Returns:
        Claim object with the verification result. When no sources are
        returned, or when the search or assessment raises, the claim is
        reported as "insufficient_information" with confidence 0.5.
    """
    try:
        evidence = await self._search_evidence(claim)

        if evidence:
            verdict = await self._assess_claim_against_sources(claim, evidence)
            details = dict(
                confidence=verdict.get('confidence', 0.5),
                assessment=verdict.get('assessment', 'insufficient_information'),
                supporting_sources=verdict.get('supporting_sources', []),
                refuting_sources=verdict.get('refuting_sources', []),
                reasoning=verdict.get('reasoning', ''),
            )
        else:
            details = dict(
                confidence=0.5,
                assessment="insufficient_information",
                supporting_sources=[],
                refuting_sources=[],
                reasoning="No sources found for verification",
            )

        return Claim(text=claim, **details)

    except Exception as e:
        logger.error(f"Error verifying claim '{claim}': {str(e)}")
        return Claim(
            text=claim,
            confidence=0.5,
            assessment="insufficient_information",
            supporting_sources=[],
            refuting_sources=[],
            reasoning=f"Error during verification: {str(e)}",
        )
|
||||
|
||||
async def _search_evidence_batch(self, claims: List[str]) -> List[Dict[str, Any]]:
    """
    Run a single evidence search on behalf of several claims.

    Only the first two claims are joined into the search query (to keep
    the query short); later claims do not influence the search.

    Args:
        claims: List of claims to search for

    Returns:
        At most five sources for the combined query, or an empty list if
        the search raises.
    """
    max_sources = 5

    try:
        # Query is built from the first two claims only.
        query = " ".join(claims[:2])

        logger.info(f"Searching for evidence for {len(claims)} claims with combined query")

        found = await self._search_evidence(query)

        if len(found) <= max_sources:
            return found

        # Trim to keep the follow-up assessment prompt small.
        logger.info(f"Limited sources to {max_sources} to prevent API rate limits")
        return found[:max_sources]

    except Exception as e:
        logger.error(f"Error in batch evidence search: {str(e)}")
        return []
|
||||
|
||||
async def _assess_claims_batch(self, claims: List[str], sources: List[Dict[str, Any]]) -> List[Claim]:
    """
    Assess up to three claims against the sources in one Gemini call.

    Claims and sources are labelled with zero-based numbers in the prompt so
    that the indices the model returns ("claim_index", "supporting_sources",
    "refuting_sources") can be used directly as list indices.

    Args:
        claims: List of claims to assess; only the first three are used
        sources: List of source dicts; their "url" and "text" values are
            sent to the model (text truncated to 1000 characters)

    Returns:
        One Claim per assessed claim, in input order. A claim the model did
        not cover, and every claim when the call or the parsing fails, is
        reported as "insufficient_information" with confidence 0.0.

    Raises:
        Exception: If no Gemini client is configured.
    """
    if not self.gemini_client:
        raise Exception("Gemini client not available. Cannot assess claims without AI provider.")

    valid_assessments = ('supported', 'refuted', 'insufficient_information')

    # Bound before the try block so the error fallback can always use it.
    # Limit to 3 claims to prevent excessive API usage.
    claims_to_assess = claims[:3]

    def unverified(claim_text: str, why: str) -> Claim:
        # Placeholder result for a claim without a usable assessment.
        return Claim(
            text=claim_text,
            confidence=0.0,
            assessment="insufficient_information",
            supporting_sources=[],
            refuting_sources=[],
            reasoning=why,
        )

    def pick_sources(indices: Any) -> List[Dict[str, Any]]:
        # Keep only in-range integer indices; ignore anything else the model sent.
        if not isinstance(indices, list):
            return []
        return [
            sources[idx]
            for idx in indices
            if isinstance(idx, int) and not isinstance(idx, bool) and 0 <= idx < len(sources)
        ]

    try:
        # Labels are zero-based to match the indices requested from the model.
        combined_sources = "\n\n".join([
            f"Source {i}: {src.get('url','')}\nText: {src.get('text','')[:1000]}"
            for i, src in enumerate(sources)
        ])

        claims_text = "\n".join([
            f"Claim {i}: {claim}"
            for i, claim in enumerate(claims_to_assess)
        ])

        prompt = (
            "You are a strict fact-checker. Analyze each claim against the provided sources.\n\n"
            "Return ONLY a valid JSON object with this exact structure:\n"
            "{\n"
            '  "assessments": [\n'
            '    {\n'
            '      "claim_index": 0,\n'
            '      "assessment": "supported" or "refuted" or "insufficient_information",\n'
            '      "confidence": number between 0.0 and 1.0,\n'
            '      "supporting_sources": [array of source indices that support the claim],\n'
            '      "refuting_sources": [array of source indices that refute the claim],\n'
            '      "reasoning": "brief explanation of your assessment"\n'
            '    }\n'
            '  ]\n'
            "}\n\n"
            "claim_index and all source indices are zero-based and must match the numbers "
            "in the 'Claim N' and 'Source N' labels below.\n\n"
            f"Claims to verify:\n{claims_text}\n\n"
            f"Sources:\n{combined_sources}\n\n"
            "Return only the JSON object:"
        )

        # The Gemini client call is synchronous; run it in the loop's default
        # thread pool instead of creating a new executor per request.
        loop = asyncio.get_running_loop()
        resp = await loop.run_in_executor(
            None,
            lambda: self.gemini_client.models.generate_content(
                model="gemini-1.5-flash",
                contents=prompt
            ),
        )

        if not resp or not resp.text:
            raise Exception("Empty response from Gemini API for batch assessment")

        result_text = resp.text.strip()
        logger.info(f"Raw Gemini response for batch assessment: {result_text[:200]}...")

        try:
            result = json.loads(result_text)
        except json.JSONDecodeError:
            # The model may wrap the JSON in prose or a markdown fence. Take the
            # outermost {...} span: a non-greedy match would stop at the first
            # closing brace and cut the nested "assessments" objects in half.
            start = result_text.find('{')
            end = result_text.rfind('}')
            if start == -1 or end <= start:
                raise Exception(f"Could not parse JSON from Gemini response: {result_text[:100]}")
            result = json.loads(result_text[start:end + 1])

        if not isinstance(result, dict):
            raise Exception(f"Could not parse JSON from Gemini response: {result_text[:100]}")

        # Index the assessments by claim_index; the first entry for an index wins.
        assessments = result.get('assessments', [])
        by_index: Dict[int, Dict[str, Any]] = {}
        if isinstance(assessments, list):
            for entry in assessments:
                if not isinstance(entry, dict):
                    continue
                claim_index = entry.get('claim_index')
                if isinstance(claim_index, int) and not isinstance(claim_index, bool):
                    by_index.setdefault(claim_index, entry)

        verified_claims = []
        for i, claim in enumerate(claims_to_assess):
            assessment = by_index.get(i)
            if assessment is None:
                verified_claims.append(unverified(claim, "No assessment provided"))
                continue

            # A malformed confidence must not discard the rest of the batch.
            try:
                confidence = float(assessment.get('confidence', 0.5))
            except (TypeError, ValueError):
                confidence = 0.5
            confidence = min(1.0, max(0.0, confidence))

            label = assessment.get('assessment', 'insufficient_information')
            if label not in valid_assessments:
                label = 'insufficient_information'

            verified_claims.append(Claim(
                text=claim,
                confidence=confidence,
                assessment=label,
                supporting_sources=pick_sources(assessment.get('supporting_sources')),
                refuting_sources=pick_sources(assessment.get('refuting_sources')),
                reasoning=assessment.get('reasoning', '')
            ))

        logger.info(f"Successfully assessed {len(verified_claims)} claims in batch")
        return verified_claims

    except Exception as e:
        logger.error(f"Error in batch assessment: {str(e)}")
        # Return all claims as insufficient information
        return [
            unverified(claim, f"Batch assessment failed: {str(e)}")
            for claim in claims_to_assess
        ]
|
||||
|
||||
async def _search_evidence(self, claim: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for evidence using Exa.ai API.
|
||||
|
||||
Args:
|
||||
claim: The claim to search evidence for
|
||||
|
||||
Returns:
|
||||
List of source documents with evidence
|
||||
"""
|
||||
if not self.exa_api_key:
|
||||
raise Exception("Exa API key not available. Cannot search for evidence without Exa.ai access.")
|
||||
|
||||
try:
|
||||
headers = {
|
||||
'x-api-key': self.exa_api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
payload = {
|
||||
'query': claim,
|
||||
'numResults': 5,
|
||||
'text': True,
|
||||
'useAutoprompt': True
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'https://api.exa.ai/search',
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=15
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
results = data.get('results', [])
|
||||
|
||||
if not results:
|
||||
raise Exception(f"No search results found for claim: {claim}")
|
||||
|
||||
sources = []
|
||||
for result in results:
|
||||
source = {
|
||||
'title': result.get('title', 'Untitled'),
|
||||
'url': result.get('url', ''),
|
||||
'text': result.get('text', ''),
|
||||
'publishedDate': result.get('publishedDate', ''),
|
||||
'author': result.get('author', ''),
|
||||
'score': result.get('score', 0.5)
|
||||
}
|
||||
sources.append(source)
|
||||
|
||||
logger.info(f"Found {len(sources)} sources for claim: {claim[:50]}...")
|
||||
return sources
|
||||
else:
|
||||
raise Exception(f"Exa API error: {response.status_code} - {response.text}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching evidence with Exa: {str(e)}")
|
||||
raise Exception(f"Failed to search evidence: {str(e)}")
|
||||
|
||||
|
||||
async def _assess_claim_against_sources(self, claim: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Assess whether sources support or refute the claim using LLM.
|
||||
|
||||
Args:
|
||||
claim: The claim to assess
|
||||
sources: List of source documents
|
||||
|
||||
Returns:
|
||||
Dictionary with assessment results
|
||||
"""
|
||||
if not self.gemini_client:
|
||||
raise Exception("Gemini client not available. Cannot assess claims without AI provider.")
|
||||
|
||||
try:
|
||||
combined_sources = "\n\n".join([
|
||||
f"Source {i+1}: {src.get('url','')}\nText: {src.get('text','')[:2000]}"
|
||||
for i, src in enumerate(sources)
|
||||
])
|
||||
|
||||
prompt = (
|
||||
"You are a strict fact-checker. Analyze the claim against the provided sources.\n\n"
|
||||
"Return ONLY a valid JSON object with this exact structure:\n"
|
||||
"{\n"
|
||||
' "assessment": "supported" or "refuted" or "insufficient_information",\n'
|
||||
' "confidence": number between 0.0 and 1.0,\n'
|
||||
' "supporting_sources": [array of source indices that support the claim],\n'
|
||||
' "refuting_sources": [array of source indices that refute the claim],\n'
|
||||
' "reasoning": "brief explanation of your assessment"\n'
|
||||
"}\n\n"
|
||||
f"Claim to verify: {claim}\n\n"
|
||||
f"Sources:\n{combined_sources}\n\n"
|
||||
"Return only the JSON object:"
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
resp = await loop.run_in_executor(executor, lambda: self.gemini_client.models.generate_content(
|
||||
model="gemini-1.5-flash",
|
||||
contents=prompt
|
||||
))
|
||||
|
||||
if not resp or not resp.text:
|
||||
raise Exception("Empty response from Gemini API for claim assessment")
|
||||
|
||||
result_text = resp.text.strip()
|
||||
logger.info(f"Raw Gemini response for assessment: {result_text[:200]}...")
|
||||
|
||||
# Try to extract JSON from the response
|
||||
try:
|
||||
result = json.loads(result_text)
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON object in the response (handle markdown code blocks)
|
||||
import re
|
||||
# First try to extract from markdown code blocks
|
||||
code_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', result_text, re.DOTALL)
|
||||
if code_block_match:
|
||||
result = json.loads(code_block_match.group(1))
|
||||
else:
|
||||
# Try to find JSON object directly
|
||||
json_match = re.search(r'\{.*?\}', result_text, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
else:
|
||||
raise Exception(f"Could not parse JSON from Gemini response: {result_text[:100]}")
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ['assessment', 'confidence', 'supporting_sources', 'refuting_sources', 'reasoning']
|
||||
for field in required_fields:
|
||||
if field not in result:
|
||||
raise Exception(f"Missing required field '{field}' in assessment response")
|
||||
|
||||
# Process supporting and refuting sources
|
||||
supporting_sources = []
|
||||
refuting_sources = []
|
||||
|
||||
if isinstance(result.get('supporting_sources'), list):
|
||||
for idx in result['supporting_sources']:
|
||||
if isinstance(idx, int) and 0 <= idx < len(sources):
|
||||
supporting_sources.append(sources[idx])
|
||||
|
||||
if isinstance(result.get('refuting_sources'), list):
|
||||
for idx in result['refuting_sources']:
|
||||
if isinstance(idx, int) and 0 <= idx < len(sources):
|
||||
refuting_sources.append(sources[idx])
|
||||
|
||||
# Validate assessment value
|
||||
valid_assessments = ['supported', 'refuted', 'insufficient_information']
|
||||
if result['assessment'] not in valid_assessments:
|
||||
raise Exception(f"Invalid assessment value: {result['assessment']}")
|
||||
|
||||
# Validate confidence value
|
||||
confidence = float(result['confidence'])
|
||||
if not (0.0 <= confidence <= 1.0):
|
||||
raise Exception(f"Invalid confidence value: {confidence}")
|
||||
|
||||
logger.info(f"Successfully assessed claim: {result['assessment']} (confidence: {confidence})")
|
||||
|
||||
return {
|
||||
'assessment': result['assessment'],
|
||||
'confidence': confidence,
|
||||
'supporting_sources': supporting_sources,
|
||||
'refuting_sources': refuting_sources,
|
||||
'reasoning': result['reasoning']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error assessing claim against sources: {str(e)}")
|
||||
raise Exception(f"Failed to assess claim: {str(e)}")
|
||||
|
||||
@@ -355,7 +355,38 @@ class ContentGenerator:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating grounded post content: {str(e)}")
|
||||
raise Exception(f"Failed to generate grounded post content: {str(e)}")
|
||||
logger.info("Attempting fallback to standard content generation...")
|
||||
|
||||
# Fallback to standard content generation without grounding
|
||||
try:
|
||||
if not self.fallback_provider:
|
||||
raise Exception("No fallback provider available")
|
||||
|
||||
# Build a simpler prompt for fallback generation
|
||||
prompt = PostPromptBuilder.build_post_prompt(request)
|
||||
|
||||
# Generate content using fallback provider (it's a dict with functions)
|
||||
if 'generate_text' in self.fallback_provider:
|
||||
result = await self.fallback_provider['generate_text'](
|
||||
prompt=prompt,
|
||||
temperature=0.7,
|
||||
max_tokens=request.max_length
|
||||
)
|
||||
else:
|
||||
raise Exception("Fallback provider doesn't have generate_text method")
|
||||
|
||||
# Return result in the expected format
|
||||
return {
|
||||
'content': result.get('content', '') if isinstance(result, dict) else str(result),
|
||||
'sources': [],
|
||||
'citations': [],
|
||||
'grounding_enabled': False,
|
||||
'fallback_used': True
|
||||
}
|
||||
|
||||
except Exception as fallback_error:
|
||||
logger.error(f"Fallback generation also failed: {str(fallback_error)}")
|
||||
raise Exception(f"Failed to generate content: {str(e)}. Fallback also failed: {str(fallback_error)}")
|
||||
|
||||
async def generate_grounded_article_content(self, request, research_sources: List) -> Dict[str, Any]:
|
||||
"""Generate grounded article content using the enhanced Gemini provider with native grounding."""
|
||||
|
||||
@@ -41,8 +41,9 @@ class GeminiGroundedProvider:
|
||||
if not self.api_key:
|
||||
raise ValueError("GEMINI_API_KEY environment variable is required")
|
||||
|
||||
# Initialize the Gemini client
|
||||
# Initialize the Gemini client with timeout configuration
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
self.timeout = 30 # 30 second timeout for API calls
|
||||
logger.info("✅ Gemini Grounded Provider initialized with native Google Search grounding")
|
||||
|
||||
async def generate_grounded_content(
|
||||
@@ -82,12 +83,27 @@ class GeminiGroundedProvider:
|
||||
temperature=temperature
|
||||
)
|
||||
|
||||
# Make the request with native grounding
|
||||
response = self.client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents=grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
# Make the request with native grounding and timeout
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
|
||||
try:
|
||||
# Run the synchronous generate_content in a thread pool to make it awaitable
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
response = await asyncio.wait_for(
|
||||
loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self.client.models.generate_content(
|
||||
model="gemini-2.5-flash",
|
||||
contents=grounded_prompt,
|
||||
config=config,
|
||||
)
|
||||
),
|
||||
timeout=self.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise Exception(f"Gemini API request timed out after {self.timeout} seconds")
|
||||
|
||||
# Process the grounded response
|
||||
result = self._process_grounded_response(response, content_type)
|
||||
|
||||
201
backend/services/writing_assistant.py
Normal file
201
backend/services/writing_assistant.py
Normal file
@@ -0,0 +1,201 @@
|
||||
import os
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
from typing import Any, Dict, List
|
||||
from dataclasses import dataclass
|
||||
import requests
|
||||
from loguru import logger
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
GOOGLE_GENAI_AVAILABLE = True
|
||||
except Exception:
|
||||
GOOGLE_GENAI_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
class WritingSuggestion:
    """A single continuation suggestion produced by the writing assistant."""

    # Suggested continuation text.
    text: str
    # Confidence score attached to the suggestion.
    confidence: float
    # Source dicts that informed the suggestion.
    sources: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class WritingAssistantService:
    """
    Minimal writing assistant that combines Exa search with Gemini continuation.
    - Exa provides relevant sources with content snippets
    - Gemini generates a short, cited continuation based on current text and sources
    """

    # Inputs shorter than this (after stripping) never trigger paid API calls.
    MIN_TEXT_LENGTH = 50

    def __init__(self) -> None:
        """Read API keys from the environment and set up the Gemini client."""
        self.exa_api_key = os.getenv("EXA_API_KEY")
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")

        if not self.exa_api_key:
            logger.warning("EXA_API_KEY not configured; writing assistant will fail")

        if not (GOOGLE_GENAI_AVAILABLE and self.gemini_api_key):
            logger.warning("Gemini not available; writing assistant will fail")
            self.gemini_client = None
        else:
            self.gemini_client = genai.Client(api_key=self.gemini_api_key)

        self.http_timeout_seconds = 15

        # COST CONTROL: Daily usage limits
        self.daily_api_calls = 0
        self.daily_limit = 50  # Max 50 API calls per day (~$2.50 max cost)
        self.last_reset_date = None

    def _get_cached_suggestion(self, text: str) -> "WritingSuggestion | None":
        """No cached suggestions - always use real API calls for authentic results."""
        return None

    def _check_daily_limit(self) -> bool:
        """
        Check the daily usage limit and count this request against it.

        Returns:
            True if the request is within the limit (the counter has been
            incremented), False if the limit is already reached.
        """
        import datetime

        today = datetime.date.today()

        # Reset counter if it's a new day
        if self.last_reset_date != today:
            self.daily_api_calls = 0
            self.last_reset_date = today

        # Check if we've exceeded the limit
        if self.daily_api_calls >= self.daily_limit:
            return False

        # Increment counter for this API call
        self.daily_api_calls += 1
        logger.info(f"Writing assistant API call #{self.daily_api_calls}/{self.daily_limit} today")
        return True

    async def suggest(self, text: str, max_results: int = 1) -> "List[WritingSuggestion]":
        """
        Suggest a continuation for the given text.

        Args:
            text: The text written so far
            max_results: Accepted for interface compatibility; at most one
                suggestion is returned

        Returns:
            A list with one WritingSuggestion, or an empty list when the
            text is too short, the daily limit is reached, or the model
            produced no text.

        Raises:
            Exception: Propagated from the source search or the generation
                step when either of them fails.
        """
        # Skip short text BEFORE touching the daily counter, so inputs that
        # never reach an API do not consume quota.
        if not text or len(text.strip()) < self.MIN_TEXT_LENGTH:
            return []

        # Hook for cached suggestions (currently always a miss).
        cached_suggestion = self._get_cached_suggestion(text)
        if cached_suggestion:
            return [cached_suggestion]

        # COST CONTROL: Check daily usage limits
        if not self._check_daily_limit():
            logger.warning("Daily API limit reached for writing assistant")
            return []

        # 1) Find relevant sources via Exa (reduced results for cost)
        sources = await self._search_sources(text)

        # 2) Generate continuation suggestion via Gemini
        suggestion_text, confidence = await self._generate_continuation(text, sources)

        if not suggestion_text:
            return []

        return [WritingSuggestion(text=suggestion_text.strip(), confidence=confidence, sources=sources)]

    async def _search_sources(self, text: str) -> List[Dict[str, Any]]:
        """
        Search Exa for sources related to the end of the text.

        Args:
            text: The text written so far; only the last 1000 characters
                are used for the query

        Returns:
            List of source dicts with the keys "title", "url", "text",
            "author", "published_date" and "score"

        Raises:
            Exception: If the Exa API key is missing, the request fails,
                the API answers with a non-200 status, or no sources are
                returned.
        """
        if not self.exa_api_key:
            raise Exception("EXA_API_KEY not configured")

        # Follow Exa demo guidance: continuation-style prompt and 1000-char cap
        exa_query = (
            text[-1000:]
            + "\n\nIf you found the above interesting, here's another useful resource to read:"
        )

        payload = {
            "query": exa_query,
            "numResults": 3,  # Reduced from 5 to 3 for cost savings
            "text": True,
            "type": "neural",
            "highlights": {"numSentences": 1, "highlightsPerUrl": 1},
        }

        try:
            # requests is synchronous: run it in a worker thread so the event
            # loop is not blocked while waiting for Exa.
            resp = await asyncio.to_thread(
                requests.post,
                "https://api.exa.ai/search",
                headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
                json=payload,
                timeout=self.http_timeout_seconds,
            )
            if resp.status_code != 200:
                raise Exception(f"Exa error {resp.status_code}: {resp.text}")

            results = resp.json().get("results", [])
            sources: List[Dict[str, Any]] = [
                {
                    "title": r.get("title", "Untitled"),
                    "url": r.get("url", ""),
                    "text": r.get("text", ""),
                    "author": r.get("author", ""),
                    "published_date": r.get("publishedDate", ""),
                    "score": float(r.get("score", 0.5)),
                }
                for r in results
            ]
            # Explicitly fail if no sources to avoid generic completions
            if not sources:
                raise Exception("No relevant sources found from Exa for the current context")
            return sources
        except Exception as e:
            logger.error(f"WritingAssistant _search_sources error: {e}")
            raise

    async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]]) -> tuple[str, float]:
        """
        Ask Gemini for a short continuation of the text, informed by sources.

        Args:
            text: The text to continue
            sources: Source dicts; the first five are included in the prompt
                with their text truncated to 500 characters

        Returns:
            Tuple of (suggestion text, confidence); confidence is 0.7 when
            sources were supplied and 0.5 otherwise.

        Raises:
            Exception: If no Gemini client is configured, the call fails,
                or the model returns an empty suggestion.
        """
        if not self.gemini_client:
            raise Exception("Gemini client not available")

        # Build compact sources context block
        source_blocks: List[str] = []
        for i, s in enumerate(sources[:5]):
            excerpt = (s.get("text", "") or "")[:500]
            source_blocks.append(
                f"Source {i+1}: {s.get('title','') or 'Source'}\nURL: {s.get('url','')}\nExcerpt: {excerpt}"
            )
        sources_text = "\n\n".join(source_blocks) if source_blocks else "(No sources)"

        # Based on Exa demo guidance for completion-only behavior and inline citations
        system_prompt = (
            "You are an essay-completion bot that completes a sentence or continues prose. "
            "Only produce 1-2 SHORT sentences. Do not repeat or paraphrase the user's stub. "
            "Continue in the same tone and topic as the stub. Prefer concrete, current facts from the provided sources. "
            "Include exactly one brief, verifiable citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
        )

        user_prompt = (
            f"User text to continue (do not repeat):\n{text}\n\n"
            f"Relevant sources to inform your continuation:\n{sources_text}\n\n"
            "Return only the continuation text, without quotes."
        )

        try:
            # The Gemini client call is synchronous; run it in a worker thread.
            resp = await asyncio.to_thread(
                self.gemini_client.models.generate_content,
                model="gemini-1.5-flash",
                contents=f"{system_prompt}\n\n{user_prompt}",
            )
            suggestion = (resp.text or "").strip()
            if not suggestion:
                raise Exception("Gemini returned empty suggestion")
            # naive confidence from number of sources present
            confidence = 0.7 if sources else 0.5
            return suggestion, confidence
        except Exception as e:
            logger.error(f"WritingAssistant _generate_continuation error: {e}")
            # Propagate to ensure frontend does not show stale/generic content
            raise
|
||||
|
||||
|
||||
134
backend/test_hallucination_detector.py
Normal file
134
backend/test_hallucination_detector.py
Normal file
@@ -0,0 +1,134 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for the hallucination detector service.
|
||||
|
||||
This script tests the hallucination detector functionality
|
||||
without requiring the full FastAPI server to be running.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add the backend directory to the Python path
|
||||
backend_dir = Path(__file__).parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from services.hallucination_detector import HallucinationDetector
|
||||
|
||||
async def test_hallucination_detector():
    """Run claim extraction and full detection on a sample text and print the results."""

    def show_sources(heading, sources):
        # Print the heading and the first two sources of a non-empty list.
        if not sources:
            return
        print(heading)
        for j, source in enumerate(sources[:2], 1):  # Show first 2
            print(f" {j}. {source.get('title', 'Untitled')} (Score: {source.get('score', 0):.2f})")

    print("🧪 Testing Hallucination Detector")
    print("=" * 50)

    # Initialize detector
    detector = HallucinationDetector()

    # Sample text mixing checkable facts with unverifiable statements
    test_text = """
The Eiffel Tower is located in Paris, France. It was built in 1889 and stands 330 meters tall.
The tower was designed by Gustave Eiffel and is one of the most visited monuments in the world.
Our company increased sales by 25% last quarter and launched three new products.
The weather today is sunny with a temperature of 22 degrees Celsius.
"""

    print(f"📝 Test Text:\n{test_text.strip()}\n")

    try:
        # Step 1: claim extraction on its own
        print("🔍 Testing claim extraction...")
        extracted = await detector._extract_claims(test_text)
        print(f"✅ Extracted {len(extracted)} claims:")
        for number, extracted_claim in enumerate(extracted, 1):
            print(f" {number}. {extracted_claim}")
        print()

        # Step 2: the complete detection pipeline
        print("🔍 Testing full hallucination detection...")
        result = await detector.detect_hallucinations(test_text)

        print(f"✅ Analysis completed:")
        print(f" Overall Confidence: {result.overall_confidence:.2f}")
        print(f" Total Claims: {result.total_claims}")
        print(f" Supported: {result.supported_claims}")
        print(f" Refuted: {result.refuted_claims}")
        print(f" Insufficient: {result.insufficient_claims}")
        print()

        # Per-claim breakdown
        print("📊 Individual Claim Analysis:")
        for i, claim in enumerate(result.claims, 1):
            print(f"\n Claim {i}: {claim.text}")
            print(f" Assessment: {claim.assessment}")
            print(f" Confidence: {claim.confidence:.2f}")
            print(f" Supporting Sources: {len(claim.supporting_sources)}")
            print(f" Refuting Sources: {len(claim.refuting_sources)}")

            show_sources(" Supporting Sources:", claim.supporting_sources)
            show_sources(" Refuting Sources:", claim.refuting_sources)

        print("\n✅ Test completed successfully!")

    except Exception as e:
        # Failures are reported on stdout together with the traceback.
        print(f"❌ Test failed with error: {str(e)}")
        import traceback
        traceback.print_exc()
|
||||
|
||||
async def test_health_check():
    """Print which API keys the detector has picked up and what that allows."""

    def mark(available):
        return '✅' if available else '❌'

    print("\n🏥 Testing Health Check")
    print("=" * 30)

    detector = HallucinationDetector()

    # Check API availability
    # NOTE(review): the detector methods visible in this commit use a Gemini
    # client; confirm that HallucinationDetector defines `openai_api_key`.
    has_exa = bool(detector.exa_api_key)
    has_openai = bool(detector.openai_api_key)

    print(f"Exa.ai API Available: {mark(has_exa)}")
    print(f"OpenAI API Available: {mark(has_openai)}")

    if not has_exa:
        print("⚠️ Exa.ai API key not found. Set EXA_API_KEY environment variable.")

    if not has_openai:
        print("⚠️ OpenAI API key not found. Set OPENAI_API_KEY environment variable.")

    if has_openai:
        if has_exa:
            print("✅ All APIs are available for full functionality.")
        else:
            print("⚠️ Limited functionality available (claim extraction only).")
    else:
        print("❌ No APIs available. Only fallback functionality will work.")
|
||||
|
||||
def main():
    """Report which API keys are set, then run the health check and the detector test."""

    print("🚀 Hallucination Detector Test Suite")
    print("=" * 50)

    # Check environment variables
    print("🔧 Environment Check:")
    for variable in ('EXA_API_KEY', 'OPENAI_API_KEY'):
        print(f"{variable}: {'✅ Set' if os.getenv(variable) else '❌ Not set'}")
    print()

    # Run tests, each in its own event loop
    for test_coroutine in (test_health_check, test_hallucination_detector):
        asyncio.run(test_coroutine())


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user