Merge branch 'main' into cursor/implement-usage-based-subscription-and-monitoring-0179
This commit is contained in:
351
backend/api/hallucination_detector.py
Normal file
351
backend/api/hallucination_detector.py
Normal file
@@ -0,0 +1,351 @@
|
||||
"""
|
||||
Hallucination Detector API endpoints.
|
||||
|
||||
Provides REST API endpoints for fact-checking and hallucination detection
|
||||
using Exa.ai integration, similar to the Exa.ai demo implementation.
|
||||
"""
|
||||
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from models.hallucination_models import (
|
||||
HallucinationDetectionRequest,
|
||||
HallucinationDetectionResponse,
|
||||
ClaimExtractionRequest,
|
||||
ClaimExtractionResponse,
|
||||
ClaimVerificationRequest,
|
||||
ClaimVerificationResponse,
|
||||
HealthCheckResponse,
|
||||
Claim,
|
||||
SourceDocument,
|
||||
AssessmentType
|
||||
)
|
||||
from services.hallucination_detector import HallucinationDetector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create router
|
||||
router = APIRouter(prefix="/api/hallucination-detector", tags=["Hallucination Detector"])
|
||||
|
||||
# Initialize detector service
|
||||
detector = HallucinationDetector()
|
||||
|
||||
@router.post("/detect", response_model=HallucinationDetectionResponse)
|
||||
async def detect_hallucinations(request: HallucinationDetectionRequest) -> HallucinationDetectionResponse:
|
||||
"""
|
||||
Detect hallucinations in the provided text.
|
||||
|
||||
This endpoint implements the complete hallucination detection pipeline:
|
||||
1. Extract verifiable claims from the text
|
||||
2. Search for evidence using Exa.ai
|
||||
3. Verify each claim against the found sources
|
||||
|
||||
Args:
|
||||
request: HallucinationDetectionRequest with text to analyze
|
||||
|
||||
Returns:
|
||||
HallucinationDetectionResponse with analysis results
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
logger.info(f"Starting hallucination detection for text of length: {len(request.text)}")
|
||||
|
||||
# Perform hallucination detection
|
||||
result = await detector.detect_hallucinations(request.text)
|
||||
|
||||
# Convert to response format
|
||||
claims = []
|
||||
for claim in result.claims:
|
||||
# Convert sources to SourceDocument objects
|
||||
supporting_sources = [
|
||||
SourceDocument(
|
||||
title=source.get('title', 'Untitled'),
|
||||
url=source.get('url', ''),
|
||||
text=source.get('text', ''),
|
||||
published_date=source.get('publishedDate'),
|
||||
author=source.get('author'),
|
||||
score=source.get('score', 0.5)
|
||||
)
|
||||
for source in claim.supporting_sources
|
||||
]
|
||||
|
||||
refuting_sources = [
|
||||
SourceDocument(
|
||||
title=source.get('title', 'Untitled'),
|
||||
url=source.get('url', ''),
|
||||
text=source.get('text', ''),
|
||||
published_date=source.get('publishedDate'),
|
||||
author=source.get('author'),
|
||||
score=source.get('score', 0.5)
|
||||
)
|
||||
for source in claim.refuting_sources
|
||||
]
|
||||
|
||||
claim_obj = Claim(
|
||||
text=claim.text,
|
||||
confidence=claim.confidence,
|
||||
assessment=AssessmentType(claim.assessment),
|
||||
supporting_sources=supporting_sources if request.include_sources else [],
|
||||
refuting_sources=refuting_sources if request.include_sources else [],
|
||||
reasoning=getattr(claim, 'reasoning', None)
|
||||
)
|
||||
claims.append(claim_obj)
|
||||
|
||||
processing_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
response = HallucinationDetectionResponse(
|
||||
success=True,
|
||||
claims=claims,
|
||||
overall_confidence=result.overall_confidence,
|
||||
total_claims=result.total_claims,
|
||||
supported_claims=result.supported_claims,
|
||||
refuted_claims=result.refuted_claims,
|
||||
insufficient_claims=result.insufficient_claims,
|
||||
timestamp=result.timestamp,
|
||||
processing_time_ms=processing_time
|
||||
)
|
||||
|
||||
logger.info(f"Hallucination detection completed successfully. Processing time: {processing_time}ms")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in hallucination detection: {str(e)}")
|
||||
processing_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
# Return proper error response
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"message": "Hallucination detection failed. Please check API keys and try again.",
|
||||
"timestamp": time.strftime('%Y-%m-%dT%H:%M:%S'),
|
||||
"processing_time_ms": processing_time
|
||||
}
|
||||
)
|
||||
|
||||
@router.post("/extract-claims", response_model=ClaimExtractionResponse)
|
||||
async def extract_claims(request: ClaimExtractionRequest) -> ClaimExtractionResponse:
|
||||
"""
|
||||
Extract verifiable claims from the provided text.
|
||||
|
||||
This endpoint performs only the claim extraction step of the
|
||||
hallucination detection pipeline.
|
||||
|
||||
Args:
|
||||
request: ClaimExtractionRequest with text to analyze
|
||||
|
||||
Returns:
|
||||
ClaimExtractionResponse with extracted claims
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Extracting claims from text of length: {len(request.text)}")
|
||||
|
||||
# Extract claims
|
||||
claims = await detector._extract_claims(request.text)
|
||||
|
||||
# Limit claims if requested
|
||||
if request.max_claims and len(claims) > request.max_claims:
|
||||
claims = claims[:request.max_claims]
|
||||
|
||||
response = ClaimExtractionResponse(
|
||||
success=True,
|
||||
claims=claims,
|
||||
total_claims=len(claims),
|
||||
timestamp=time.strftime('%Y-%m-%dT%H:%M:%S')
|
||||
)
|
||||
|
||||
logger.info(f"Claim extraction completed. Extracted {len(claims)} claims")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in claim extraction: {str(e)}")
|
||||
|
||||
return ClaimExtractionResponse(
|
||||
success=False,
|
||||
claims=[],
|
||||
total_claims=0,
|
||||
timestamp=time.strftime('%Y-%m-%dT%H:%M:%S'),
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@router.post("/verify-claim", response_model=ClaimVerificationResponse)
|
||||
async def verify_claim(request: ClaimVerificationRequest) -> ClaimVerificationResponse:
|
||||
"""
|
||||
Verify a single claim against available sources.
|
||||
|
||||
This endpoint performs claim verification using Exa.ai search
|
||||
and LLM-based assessment.
|
||||
|
||||
Args:
|
||||
request: ClaimVerificationRequest with claim to verify
|
||||
|
||||
Returns:
|
||||
ClaimVerificationResponse with verification results
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
logger.info(f"Verifying claim: {request.claim[:100]}...")
|
||||
|
||||
# Verify the claim
|
||||
claim_result = await detector._verify_claim(request.claim)
|
||||
|
||||
# Convert to response format
|
||||
supporting_sources = []
|
||||
refuting_sources = []
|
||||
|
||||
if request.include_sources:
|
||||
supporting_sources = [
|
||||
SourceDocument(
|
||||
title=source.get('title', 'Untitled'),
|
||||
url=source.get('url', ''),
|
||||
text=source.get('text', ''),
|
||||
published_date=source.get('publishedDate'),
|
||||
author=source.get('author'),
|
||||
score=source.get('score', 0.5)
|
||||
)
|
||||
for source in claim_result.supporting_sources
|
||||
]
|
||||
|
||||
refuting_sources = [
|
||||
SourceDocument(
|
||||
title=source.get('title', 'Untitled'),
|
||||
url=source.get('url', ''),
|
||||
text=source.get('text', ''),
|
||||
published_date=source.get('publishedDate'),
|
||||
author=source.get('author'),
|
||||
score=source.get('score', 0.5)
|
||||
)
|
||||
for source in claim_result.refuting_sources
|
||||
]
|
||||
|
||||
claim_obj = Claim(
|
||||
text=claim_result.text,
|
||||
confidence=claim_result.confidence,
|
||||
assessment=AssessmentType(claim_result.assessment),
|
||||
supporting_sources=supporting_sources,
|
||||
refuting_sources=refuting_sources,
|
||||
reasoning=getattr(claim_result, 'reasoning', None)
|
||||
)
|
||||
|
||||
processing_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
response = ClaimVerificationResponse(
|
||||
success=True,
|
||||
claim=claim_obj,
|
||||
timestamp=time.strftime('%Y-%m-%dT%H:%M:%S'),
|
||||
processing_time_ms=processing_time
|
||||
)
|
||||
|
||||
logger.info(f"Claim verification completed. Assessment: {claim_result.assessment}")
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in claim verification: {str(e)}")
|
||||
processing_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
return ClaimVerificationResponse(
|
||||
success=False,
|
||||
claim=Claim(
|
||||
text=request.claim,
|
||||
confidence=0.0,
|
||||
assessment=AssessmentType.INSUFFICIENT_INFORMATION,
|
||||
supporting_sources=[],
|
||||
refuting_sources=[],
|
||||
reasoning="Error during verification"
|
||||
),
|
||||
timestamp=time.strftime('%Y-%m-%dT%H:%M:%S'),
|
||||
processing_time_ms=processing_time,
|
||||
error=str(e)
|
||||
)
|
||||
|
||||
@router.get("/health", response_model=HealthCheckResponse)
|
||||
async def health_check() -> HealthCheckResponse:
|
||||
"""
|
||||
Health check endpoint for the hallucination detector service.
|
||||
|
||||
Returns:
|
||||
HealthCheckResponse with service status and API availability
|
||||
"""
|
||||
try:
|
||||
# Check API availability
|
||||
exa_available = bool(detector.exa_api_key)
|
||||
openai_available = bool(detector.openai_api_key)
|
||||
|
||||
status = "healthy" if (exa_available or openai_available) else "degraded"
|
||||
|
||||
response = HealthCheckResponse(
|
||||
status=status,
|
||||
version="1.0.0",
|
||||
exa_api_available=exa_available,
|
||||
openai_api_available=openai_available,
|
||||
timestamp=time.strftime('%Y-%m-%dT%H:%M:%S')
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in health check: {str(e)}")
|
||||
|
||||
return HealthCheckResponse(
|
||||
status="unhealthy",
|
||||
version="1.0.0",
|
||||
exa_api_available=False,
|
||||
openai_api_available=False,
|
||||
timestamp=time.strftime('%Y-%m-%dT%H:%M:%S')
|
||||
)
|
||||
|
||||
@router.get("/demo")
|
||||
async def demo_endpoint() -> Dict[str, Any]:
|
||||
"""
|
||||
Demo endpoint showing example usage of the hallucination detector.
|
||||
|
||||
Returns:
|
||||
Dictionary with example request/response data
|
||||
"""
|
||||
return {
|
||||
"description": "Hallucination Detector API Demo",
|
||||
"version": "1.0.0",
|
||||
"endpoints": {
|
||||
"detect": {
|
||||
"method": "POST",
|
||||
"path": "/api/hallucination-detector/detect",
|
||||
"description": "Detect hallucinations in text using Exa.ai",
|
||||
"example_request": {
|
||||
"text": "The Eiffel Tower is located in Paris and was built in 1889. It is 330 meters tall.",
|
||||
"include_sources": True,
|
||||
"max_claims": 5
|
||||
}
|
||||
},
|
||||
"extract_claims": {
|
||||
"method": "POST",
|
||||
"path": "/api/hallucination-detector/extract-claims",
|
||||
"description": "Extract verifiable claims from text",
|
||||
"example_request": {
|
||||
"text": "Our company increased sales by 25% last quarter. We launched 3 new products.",
|
||||
"max_claims": 10
|
||||
}
|
||||
},
|
||||
"verify_claim": {
|
||||
"method": "POST",
|
||||
"path": "/api/hallucination-detector/verify-claim",
|
||||
"description": "Verify a single claim against sources",
|
||||
"example_request": {
|
||||
"claim": "The Eiffel Tower is in Paris",
|
||||
"include_sources": True
|
||||
}
|
||||
}
|
||||
},
|
||||
"features": [
|
||||
"Claim extraction using LLM",
|
||||
"Evidence search using Exa.ai",
|
||||
"Claim verification with confidence scores",
|
||||
"Source attribution and credibility assessment",
|
||||
"Fallback mechanisms for API unavailability"
|
||||
]
|
||||
}
|
||||
@@ -45,6 +45,24 @@ class PersonaGenerationResponse(BaseModel):
|
||||
data_sufficiency: Optional[float] = None
|
||||
platforms_generated: List[str] = []
|
||||
|
||||
class LinkedInPersonaValidationRequest(BaseModel):
|
||||
"""Request model for LinkedIn persona validation."""
|
||||
persona_data: Dict[str, Any]
|
||||
|
||||
class LinkedInPersonaValidationResponse(BaseModel):
|
||||
"""Response model for LinkedIn persona validation."""
|
||||
is_valid: bool
|
||||
quality_score: float
|
||||
completeness_score: float
|
||||
professional_context_score: float
|
||||
linkedin_optimization_score: float
|
||||
missing_fields: List[str]
|
||||
incomplete_fields: List[str]
|
||||
recommendations: List[str]
|
||||
quality_issues: List[str]
|
||||
strengths: List[str]
|
||||
validation_details: Dict[str, Any]
|
||||
|
||||
# Dependency to get persona service
|
||||
def get_persona_service() -> PersonaAnalysisService:
|
||||
"""Get the persona analysis service instance."""
|
||||
@@ -380,6 +398,211 @@ async def get_supported_platforms():
|
||||
"description": "Newsletter platform for building subscriber relationships",
|
||||
"format": "email newsletter",
|
||||
"subscription_focus": True
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
class LinkedInOptimizationRequest(BaseModel):
|
||||
"""Request model for LinkedIn algorithm optimization."""
|
||||
persona_data: Dict[str, Any]
|
||||
|
||||
|
||||
class LinkedInOptimizationResponse(BaseModel):
|
||||
"""Response model for LinkedIn algorithm optimization."""
|
||||
optimized_persona: Dict[str, Any]
|
||||
optimization_applied: bool
|
||||
optimization_details: Dict[str, Any]
|
||||
|
||||
|
||||
async def validate_linkedin_persona(
|
||||
request: LinkedInPersonaValidationRequest,
|
||||
persona_service: PersonaAnalysisService = Depends(get_persona_service)
|
||||
):
|
||||
"""
|
||||
Validate LinkedIn persona data for completeness and quality.
|
||||
|
||||
This endpoint provides comprehensive validation of LinkedIn persona data,
|
||||
including core fields, LinkedIn-specific optimizations, professional context,
|
||||
and content quality assessments.
|
||||
"""
|
||||
try:
|
||||
logger.info("Validating LinkedIn persona data")
|
||||
|
||||
# Get LinkedIn persona service
|
||||
from services.persona.linkedin.linkedin_persona_service import LinkedInPersonaService
|
||||
linkedin_service = LinkedInPersonaService()
|
||||
|
||||
# Validate the persona data
|
||||
validation_results = linkedin_service.validate_linkedin_persona(request.persona_data)
|
||||
|
||||
logger.info(f"LinkedIn persona validation completed: Quality Score: {validation_results['quality_score']:.1f}%")
|
||||
|
||||
return LinkedInPersonaValidationResponse(**validation_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating LinkedIn persona: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to validate LinkedIn persona: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
async def optimize_linkedin_persona(
|
||||
request: LinkedInOptimizationRequest,
|
||||
persona_service: PersonaAnalysisService = Depends(get_persona_service)
|
||||
):
|
||||
"""
|
||||
Optimize LinkedIn persona data for maximum algorithm performance.
|
||||
|
||||
This endpoint applies comprehensive LinkedIn algorithm optimization to persona data,
|
||||
including content quality optimization, multimedia strategy, engagement optimization,
|
||||
timing optimization, and professional context optimization.
|
||||
"""
|
||||
try:
|
||||
logger.info("Optimizing LinkedIn persona for algorithm performance")
|
||||
|
||||
# Get LinkedIn persona service
|
||||
from services.persona.linkedin.linkedin_persona_service import LinkedInPersonaService
|
||||
linkedin_service = LinkedInPersonaService()
|
||||
|
||||
# Apply algorithm optimization
|
||||
optimized_persona = linkedin_service.optimize_for_linkedin_algorithm(request.persona_data)
|
||||
|
||||
# Extract optimization details
|
||||
optimization_details = optimized_persona.get("algorithm_optimization", {})
|
||||
|
||||
logger.info("✅ LinkedIn persona algorithm optimization completed successfully")
|
||||
|
||||
return LinkedInOptimizationResponse(
|
||||
optimized_persona=optimized_persona,
|
||||
optimization_applied=True,
|
||||
optimization_details={
|
||||
"optimization_categories": list(optimization_details.keys()),
|
||||
"total_optimization_strategies": sum(len(strategies) if isinstance(strategies, list) else 1
|
||||
for category in optimization_details.values()
|
||||
for strategies in category.values() if isinstance(category, dict)),
|
||||
"optimization_timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
]
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error optimizing LinkedIn persona: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to optimize LinkedIn persona: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
class FacebookPersonaValidationRequest(BaseModel):
|
||||
"""Request model for Facebook persona validation."""
|
||||
persona_data: Dict[str, Any]
|
||||
|
||||
|
||||
class FacebookPersonaValidationResponse(BaseModel):
|
||||
"""Response model for Facebook persona validation."""
|
||||
is_valid: bool
|
||||
quality_score: float
|
||||
completeness_score: float
|
||||
facebook_optimization_score: float
|
||||
engagement_strategy_score: float
|
||||
content_format_score: float
|
||||
audience_targeting_score: float
|
||||
community_building_score: float
|
||||
missing_fields: List[str]
|
||||
incomplete_fields: List[str]
|
||||
recommendations: List[str]
|
||||
quality_issues: List[str]
|
||||
strengths: List[str]
|
||||
validation_details: Dict[str, Any]
|
||||
|
||||
|
||||
class FacebookOptimizationRequest(BaseModel):
|
||||
"""Request model for Facebook algorithm optimization."""
|
||||
persona_data: Dict[str, Any]
|
||||
|
||||
|
||||
class FacebookOptimizationResponse(BaseModel):
|
||||
"""Response model for Facebook algorithm optimization."""
|
||||
optimized_persona: Dict[str, Any]
|
||||
optimization_applied: bool
|
||||
optimization_details: Dict[str, Any]
|
||||
|
||||
|
||||
async def validate_facebook_persona(
|
||||
request: FacebookPersonaValidationRequest,
|
||||
persona_service: PersonaAnalysisService = Depends(get_persona_service)
|
||||
):
|
||||
"""
|
||||
Validate Facebook persona data for completeness and quality.
|
||||
|
||||
This endpoint provides comprehensive validation of Facebook persona data,
|
||||
including core fields, Facebook-specific optimizations, engagement strategies,
|
||||
content formats, audience targeting, and community building assessments.
|
||||
"""
|
||||
try:
|
||||
logger.info("Validating Facebook persona data")
|
||||
|
||||
# Get Facebook persona service
|
||||
from services.persona.facebook.facebook_persona_service import FacebookPersonaService
|
||||
facebook_service = FacebookPersonaService()
|
||||
|
||||
# Validate the persona data
|
||||
validation_results = facebook_service.validate_facebook_persona(request.persona_data)
|
||||
|
||||
logger.info(f"Facebook persona validation completed: Quality Score: {validation_results['quality_score']:.1f}%")
|
||||
|
||||
return FacebookPersonaValidationResponse(**validation_results)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error validating Facebook persona: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to validate Facebook persona: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
async def optimize_facebook_persona(
|
||||
request: FacebookOptimizationRequest,
|
||||
persona_service: PersonaAnalysisService = Depends(get_persona_service)
|
||||
):
|
||||
"""
|
||||
Optimize Facebook persona data for maximum algorithm performance.
|
||||
|
||||
This endpoint applies comprehensive Facebook algorithm optimization to persona data,
|
||||
including engagement optimization, content quality optimization, timing optimization,
|
||||
audience targeting optimization, and community building strategies.
|
||||
"""
|
||||
try:
|
||||
logger.info("Optimizing Facebook persona for algorithm performance")
|
||||
|
||||
# Get Facebook persona service
|
||||
from services.persona.facebook.facebook_persona_service import FacebookPersonaService
|
||||
facebook_service = FacebookPersonaService()
|
||||
|
||||
# Apply algorithm optimization
|
||||
optimized_persona = facebook_service.optimize_for_facebook_algorithm(request.persona_data)
|
||||
|
||||
# Extract optimization details
|
||||
optimization_details = optimized_persona.get("algorithm_optimization", {})
|
||||
|
||||
logger.info("✅ Facebook persona algorithm optimization completed successfully")
|
||||
|
||||
# Use the optimization metadata from the service
|
||||
optimization_metadata = optimized_persona.get("optimization_metadata", {})
|
||||
|
||||
return FacebookOptimizationResponse(
|
||||
optimized_persona=optimized_persona,
|
||||
optimization_applied=True,
|
||||
optimization_details={
|
||||
"optimization_categories": optimization_metadata.get("optimization_categories", []),
|
||||
"total_optimization_strategies": optimization_metadata.get("total_optimization_strategies", 0),
|
||||
"optimization_timestamp": optimization_metadata.get("optimization_timestamp", datetime.utcnow().isoformat())
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error optimizing Facebook persona: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to optimize Facebook persona: {str(e)}"
|
||||
)
|
||||
@@ -16,7 +16,19 @@ from api.persona import (
|
||||
validate_persona_generation_readiness,
|
||||
generate_persona_preview,
|
||||
get_supported_platforms,
|
||||
PersonaGenerationRequest
|
||||
validate_linkedin_persona,
|
||||
optimize_linkedin_persona,
|
||||
validate_facebook_persona,
|
||||
optimize_facebook_persona,
|
||||
PersonaGenerationRequest,
|
||||
LinkedInPersonaValidationRequest,
|
||||
LinkedInPersonaValidationResponse,
|
||||
LinkedInOptimizationRequest,
|
||||
LinkedInOptimizationResponse,
|
||||
FacebookPersonaValidationRequest,
|
||||
FacebookPersonaValidationResponse,
|
||||
FacebookOptimizationRequest,
|
||||
FacebookOptimizationResponse
|
||||
)
|
||||
|
||||
from services.persona_replication_engine import PersonaReplicationEngine
|
||||
@@ -89,6 +101,34 @@ async def get_supported_platforms_endpoint():
|
||||
"""Get list of supported platforms for persona generation."""
|
||||
return await get_supported_platforms()
|
||||
|
||||
@router.post("/linkedin/validate", response_model=LinkedInPersonaValidationResponse)
|
||||
async def validate_linkedin_persona_endpoint(
|
||||
request: LinkedInPersonaValidationRequest
|
||||
):
|
||||
"""Validate LinkedIn persona data for completeness and quality."""
|
||||
return await validate_linkedin_persona(request)
|
||||
|
||||
@router.post("/linkedin/optimize", response_model=LinkedInOptimizationResponse)
|
||||
async def optimize_linkedin_persona_endpoint(
|
||||
request: LinkedInOptimizationRequest
|
||||
):
|
||||
"""Optimize LinkedIn persona data for maximum algorithm performance."""
|
||||
return await optimize_linkedin_persona(request)
|
||||
|
||||
@router.post("/facebook/validate", response_model=FacebookPersonaValidationResponse)
|
||||
async def validate_facebook_persona_endpoint(
|
||||
request: FacebookPersonaValidationRequest
|
||||
):
|
||||
"""Validate Facebook persona data for completeness and quality."""
|
||||
return await validate_facebook_persona(request)
|
||||
|
||||
@router.post("/facebook/optimize", response_model=FacebookOptimizationResponse)
|
||||
async def optimize_facebook_persona_endpoint(
|
||||
request: FacebookOptimizationRequest
|
||||
):
|
||||
"""Optimize Facebook persona data for maximum algorithm performance."""
|
||||
return await optimize_facebook_persona(request)
|
||||
|
||||
@router.post("/generate-content")
|
||||
async def generate_content_with_persona_endpoint(
|
||||
request: Dict[str, Any]
|
||||
|
||||
61
backend/api/writing_assistant.py
Normal file
61
backend/api/writing_assistant.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from services.writing_assistant import WritingAssistantService
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
|
||||
|
||||
class SuggestRequest(BaseModel):
|
||||
text: str
|
||||
max_results: int | None = 1
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
|
||||
title: str
|
||||
url: str
|
||||
text: str | None = ""
|
||||
author: str | None = ""
|
||||
published_date: str | None = ""
|
||||
score: float
|
||||
|
||||
|
||||
class SuggestionModel(BaseModel):
|
||||
text: str
|
||||
confidence: float
|
||||
sources: List[SourceModel]
|
||||
|
||||
|
||||
class SuggestResponse(BaseModel):
|
||||
success: bool
|
||||
suggestions: List[SuggestionModel]
|
||||
|
||||
|
||||
assistant_service = WritingAssistantService()
|
||||
|
||||
|
||||
@router.post("/suggest", response_model=SuggestResponse)
|
||||
async def suggest_endpoint(req: SuggestRequest) -> SuggestResponse:
|
||||
try:
|
||||
suggestions = await assistant_service.suggest(req.text, req.max_results or 1)
|
||||
return SuggestResponse(
|
||||
success=True,
|
||||
suggestions=[
|
||||
SuggestionModel(
|
||||
text=s.text,
|
||||
confidence=s.confidence,
|
||||
sources=[
|
||||
SourceModel(**src) for src in s.sources
|
||||
],
|
||||
)
|
||||
for s in suggestions
|
||||
],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Writing assistant error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user