feat: Brainstorm Topics with GSC + Issue #518 fixes + Blog Editor enhancements
Issue #518 - Subscription not updating after checkout: - Fix stale closure in SubscriptionContext checkout polling (use subscriptionRef) - Move checkout success polling from InitialRouteHandler into SubscriptionContext - Remove redundant polling code from InitialRouteHandler - Fix plan label: 'Free' instead of 'No Plan', proper capitalization - Add plan refresh button in UserBadge - Add 'View Costing Details' to UserBadge dropdown - Rename 'ALwrity Podcast Maker' to 'Podcast Creator' across UI - Clean subscription=success URL param after verification Blog Writer WYSIWYG Editor enhancements: - Per-section preview toggle (view/edit icons) - Enhanced hover-based toolbar - Circular SVG progress stats bar with detailed tooltip - Research tool chips in stats bar footer - Per-section TTS with useTextToSpeech hook (browser native) - Full blog preview modal with print/PDF support - PlayAllTTSButton: sequential playback with progress bar - OnThisPageNav: floating sidebar with scroll tracking - Section data attributes for scroll anchoring GSC Brainstorm Topics feature: - Backend: gsc_brainstorm_service.py (rule-based + LLM recommendations) - Backend: POST /gsc/brainstorm endpoint with 3-word minimum validation - Frontend: gscBrainstorm.ts API client - Frontend: useGSCBrainstormConnection hook (popup OAuth, no /onboarding redirect) - Frontend: useGSCBrainstorm hook (connect check + brainstorm call) - Frontend: GSCBrainstormModal (3-tab results: Opportunities, Gaps, AI Recs) - Frontend: BrainstormButton (visible at 3+ words, GSC connect overlay) - Wire BrainstormButton into ManualResearchForm and ResearchAction - Add blog_writer to gsc_auth router features for ALWRITY_ENABLED_FEATURES
This commit is contained in:
@@ -44,8 +44,8 @@ CORE_ROUTER_REGISTRY = [
|
||||
OPTIONAL_ROUTER_REGISTRY = [
|
||||
{"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story_writer"}},
|
||||
{"name": "wix", "module": "api.wix_routes", "attr": "router", "features": {"all"}},
|
||||
{"name": "wix_test", "module": "api.wix_routes", "attr": "qa_router", "features": {"all"}},
|
||||
{"name": "wix", "module": "api.wix_routes", "attr": "router", "features": {"all"}},
|
||||
{"name": "wix_test", "module": "api.wix_routes", "attr": "qa_router", "features": {"all"}},
|
||||
{"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "persona", "module": "api.persona_routes", "attr": "router", "features": {"all", "persona"}},
|
||||
{"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video_studio"}},
|
||||
|
||||
192
backend/api/charts.py
Normal file
192
backend/api/charts.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""
|
||||
Chart API — Shared chart generation endpoints for Blog Writer, Podcast Maker, etc.
|
||||
|
||||
Two modes:
|
||||
1. Explicit: POST /api/charts/generate with { chart_type, chart_data, title }
|
||||
2. AI-driven: POST /api/charts/generate with { text } → LLM infers chart_type + data
|
||||
|
||||
Both return { preview_url, chart_id, chart_type?, chart_data?, title? }
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.chart_service import get_chart_service, VALID_CHART_TYPES
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/charts", tags=["Charts"])
|
||||
|
||||
|
||||
class ChartGenerateRequest(BaseModel):
    """Payload accepted by the chart generation endpoint.

    Supply one of two shapes:
    - chart_type + chart_data (explicit mode), OR
    - text (AI inference mode — the LLM determines chart_type + data)
    """

    # --- Explicit-mode inputs -------------------------------------------
    chart_data: Optional[Dict[str, Any]] = Field(
        None,
        description="Chart data dict (labels, values, before/after, etc.)",
    )
    chart_type: Optional[str] = Field(
        None,
        description=f"Chart type: {', '.join(VALID_CHART_TYPES)}",
    )

    # --- Presentation (forwarded to the renderer in explicit mode) -------
    title: str = Field("", description="Chart title")
    subtitle: Optional[str] = Field("", description="Optional subtitle")

    # --- AI-mode inputs ---------------------------------------------------
    text: Optional[str] = Field(
        None,
        description="Text to infer chart from (AI mode). Mutually exclusive with chart_type+chart_data.",
    )
    section_heading: Optional[str] = Field(
        None,
        description="Blog section heading for context (AI mode with research)",
    )
    # Element type is not constrained here; presumably a list of strings — confirm.
    section_key_points: Optional[list] = Field(
        None,
        description="Key points from the section (AI mode with research)",
    )
|
||||
|
||||
|
||||
class ChartGenerateResponse(BaseModel):
    """Result returned by the chart generation endpoint."""

    # URL of the rendered PNG, served by the preview route of this router.
    preview_url: str = ""
    # Identifier embedded in the preview URL.
    chart_id: str = ""
    # Echo of the chart definition actually rendered.
    chart_type: Optional[str] = None
    chart_data: Optional[Dict[str, Any]] = None
    title: Optional[str] = None
    # Non-fatal problems reported by the generation pipeline.
    warnings: list = Field(
        default_factory=list,
        description="Pipeline warnings (e.g. Exa search failures)",
    )
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ChartGenerateResponse)
async def generate_chart(
    request: ChartGenerateRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """
    Render a chart PNG preview and return its URL.

    Mode selection:
    1. AI-driven: ``text`` is set and ``chart_type`` is not; the chart
       service infers chart_type + chart_data from the text.
    2. Explicit: both ``chart_type`` and ``chart_data`` are set.

    Raises:
        HTTPException 400: neither mode's inputs were supplied, or the
            chart_type is unknown even after alias normalization.
        HTTPException 500: the service produced no output path, or any
            unexpected error occurred.
    """
    user_id = require_authenticated_user(current_user)

    # Pure reads of the request; evaluated up front so the branches below
    # can be written as guard clauses.
    ai_mode = bool(request.text) and not request.chart_type
    explicit_mode = bool(request.chart_type) and bool(request.chart_data)

    try:
        chart_svc = get_chart_service(user_id=user_id)

        if ai_mode:
            logger.info(f"[Charts] AI inference mode for user {user_id}, text length={len(request.text)}")
            inferred = await chart_svc.generate_chart_from_text(
                text=request.text,
                user_id=user_id,
                section_heading=request.section_heading,
                section_key_points=request.section_key_points,
            )

            if not inferred.get("path"):
                raise HTTPException(status_code=500, detail="Chart generation failed")

            chart_id = inferred["chart_id"]
            filename = inferred.get("filename", f"chart_preview_{chart_id}.png")

            return ChartGenerateResponse(
                preview_url=f"/api/charts/preview/{chart_id}/{filename}",
                chart_id=chart_id,
                chart_type=inferred.get("chart_type"),
                chart_data=inferred.get("chart_data"),
                title=inferred.get("title"),
                warnings=inferred.get("warnings", []),
            )

        if not explicit_mode:
            raise HTTPException(
                status_code=400,
                detail="Provide either 'text' (AI mode) or 'chart_type' + 'chart_data' (explicit mode)"
            )

        # Explicit mode: validate the type, falling back to alias normalization.
        chart_type = request.chart_type
        if chart_type not in VALID_CHART_TYPES:
            from services.chart_service import _normalize_chart_type

            chart_type = _normalize_chart_type(chart_type)
            if chart_type not in VALID_CHART_TYPES:
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid chart_type. Must be one of: {VALID_CHART_TYPES}"
                )

        logger.info(f"[Charts] Explicit mode: type={chart_type}, user={user_id}")

        # Short random id; it becomes part of the preview URL.
        chart_id = uuid.uuid4().hex[:8]
        rendered = chart_svc.generate_chart(
            chart_data=request.chart_data,
            chart_type=chart_type,
            title=request.title,
            subtitle=request.subtitle or "",
            chart_id=chart_id,
        )

        if not rendered.get("path"):
            raise HTTPException(status_code=500, detail="Chart generation failed — check chart_data format")

        filename = rendered.get("filename", f"chart_preview_{chart_id}.png")

        return ChartGenerateResponse(
            preview_url=f"/api/charts/preview/{chart_id}/{filename}",
            chart_id=chart_id,
            chart_type=chart_type,
            chart_data=request.chart_data,
            title=request.title,
        )

    except HTTPException:
        # Deliberate HTTP errors raised above pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"[Charts] Generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Chart generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/preview/{chart_id}/{filename}")
async def serve_chart_preview(
    chart_id: str,
    filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """Serve chart preview PNG files. Auth via header or query token.

    The file on disk is located from ``chart_id`` via the chart service;
    ``filename`` is only used as the download name in the response.

    Raises:
        HTTPException 400: ``filename`` contains path separators or "..".
        HTTPException 403: the resolved path lies outside the service's
            output directory.
        HTTPException 404: no preview file exists for ``chart_id``.
    """
    user_id = require_authenticated_user(current_user)

    if ".." in filename or "/" in filename or "\\" in filename:
        raise HTTPException(status_code=400, detail="Invalid filename")

    chart_svc = get_chart_service(user_id=user_id)
    file_path = chart_svc.get_chart_preview_path(chart_id)

    # Containment is checked component-wise with relative_to() rather than a
    # string prefix test: "/x/charts_evil/a.png".startswith("/x/charts") is
    # True even though the file is outside "/x/charts". It also runs before
    # the existence check so the 404/403 distinction cannot be used to probe
    # for files outside the output directory.
    try:
        file_path.resolve().relative_to(chart_svc.output_dir.resolve())
    except ValueError:
        raise HTTPException(status_code=403, detail="Access denied") from None

    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Chart preview not found")

    return FileResponse(
        path=str(file_path),
        media_type="image/png",
        filename=filename,
    )
|
||||
|
||||
|
||||
@router.get("/health")
async def charts_health():
    """Liveness probe for the Charts router; always reports "ok"."""
    payload = {"status": "ok", "service": "charts"}
    return payload
|
||||
@@ -8,7 +8,7 @@ using Exa.ai integration, similar to the Exa.ai demo implementation.
|
||||
import time
|
||||
import logging
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from models.hallucination_models import (
|
||||
@@ -24,6 +24,7 @@ from models.hallucination_models import (
|
||||
AssessmentType
|
||||
)
|
||||
from services.hallucination_detector import HallucinationDetector
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -34,7 +35,7 @@ router = APIRouter(prefix="/api/hallucination-detector", tags=["Hallucination De
|
||||
detector = HallucinationDetector()
|
||||
|
||||
@router.post("/detect", response_model=HallucinationDetectionResponse)
|
||||
async def detect_hallucinations(request: HallucinationDetectionRequest) -> HallucinationDetectionResponse:
|
||||
async def detect_hallucinations(request: HallucinationDetectionRequest, current_user: Dict[str, Any] = Depends(get_current_user)) -> HallucinationDetectionResponse:
|
||||
"""
|
||||
Detect hallucinations in the provided text.
|
||||
|
||||
@@ -54,8 +55,10 @@ async def detect_hallucinations(request: HallucinationDetectionRequest) -> Hallu
|
||||
try:
|
||||
logger.info(f"Starting hallucination detection for text of length: {len(request.text)}")
|
||||
|
||||
user_id = current_user.get("id")
|
||||
|
||||
# Perform hallucination detection
|
||||
result = await detector.detect_hallucinations(request.text)
|
||||
result = await detector.detect_hallucinations(request.text, user_id=user_id)
|
||||
|
||||
# Convert to response format
|
||||
claims = []
|
||||
@@ -113,6 +116,8 @@ async def detect_hallucinations(request: HallucinationDetectionRequest) -> Hallu
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
logger.error(f"Error in hallucination detection: {str(e)}")
|
||||
processing_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -174,7 +179,7 @@ async def extract_claims(request: ClaimExtractionRequest) -> ClaimExtractionResp
|
||||
)
|
||||
|
||||
@router.post("/verify-claim", response_model=ClaimVerificationResponse)
|
||||
async def verify_claim(request: ClaimVerificationRequest) -> ClaimVerificationResponse:
|
||||
async def verify_claim(request: ClaimVerificationRequest, current_user: Dict[str, Any] = Depends(get_current_user)) -> ClaimVerificationResponse:
|
||||
"""
|
||||
Verify a single claim against available sources.
|
||||
|
||||
@@ -192,8 +197,10 @@ async def verify_claim(request: ClaimVerificationRequest) -> ClaimVerificationRe
|
||||
try:
|
||||
logger.info(f"Verifying claim: {request.claim[:100]}...")
|
||||
|
||||
user_id = current_user.get("id")
|
||||
|
||||
# Verify the claim
|
||||
claim_result = await detector._verify_claim(request.claim)
|
||||
claim_result = await detector._verify_claim(request.claim, user_id=user_id)
|
||||
|
||||
# Convert to response format
|
||||
supporting_sources = []
|
||||
@@ -246,6 +253,8 @@ async def verify_claim(request: ClaimVerificationRequest) -> ClaimVerificationRe
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
if isinstance(e, HTTPException):
|
||||
raise e
|
||||
logger.error(f"Error in claim verification: {str(e)}")
|
||||
processing_time = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -273,17 +282,21 @@ async def health_check() -> HealthCheckResponse:
|
||||
HealthCheckResponse with service status and API availability
|
||||
"""
|
||||
try:
|
||||
# Check API availability
|
||||
exa_available = bool(detector.exa_api_key)
|
||||
openai_available = bool(detector.openai_api_key)
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
try:
|
||||
exa_provider = ExaResearchProvider()
|
||||
exa_available = bool(exa_provider.api_key)
|
||||
except RuntimeError:
|
||||
exa_available = False
|
||||
llm_available = True # llm_text_gen handles provider selection via GPT_PROVIDER
|
||||
|
||||
status = "healthy" if (exa_available or openai_available) else "degraded"
|
||||
status = "healthy" if (exa_available and llm_available) else ("degraded" if exa_available or llm_available else "unhealthy")
|
||||
|
||||
response = HealthCheckResponse(
|
||||
status=status,
|
||||
version="1.0.0",
|
||||
exa_api_available=exa_available,
|
||||
openai_api_available=openai_available,
|
||||
openai_api_available=llm_available,
|
||||
timestamp=time.strftime('%Y-%m-%dT%H:%M:%S')
|
||||
)
|
||||
|
||||
|
||||
185
backend/api/links.py
Normal file
185
backend/api/links.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Link Search API — Internal & external link discovery and reword-with-links.
|
||||
|
||||
Endpoints:
|
||||
POST /api/links/search — Search for internal or external links via Exa
|
||||
POST /api/links/reword — Reword text to naturally incorporate selected links
|
||||
GET /api/links/health — Health check
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.link_search_service import get_link_search_service
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/links", tags=["Links"])
|
||||
|
||||
|
||||
class LinkSearchRequest(BaseModel):
    """Parameters for an internal or external link search."""

    query: str = Field(
        ...,
        description="Search query (typically section heading or topic)",
    )
    # Only 'internal' and 'external' are accepted; the endpoint enforces this.
    link_type: str = Field(..., description="Type of links: 'internal' or 'external'")
    site_url: Optional[str] = Field(
        None,
        description="User's website URL (required for internal links, optional for external to exclude own domain)",
    )
    # Bounded to 1..15 by the field constraints.
    num_results: int = Field(5, description="Number of results to return", ge=1, le=15)
|
||||
|
||||
|
||||
class LinkSearchResult(BaseModel):
    """One hit returned by a link search.

    Every field has a default, so partially populated hits still validate.
    """

    title: str = ""
    url: str = ""
    text: str = ""
    # camelCase presumably mirrors the search provider's field name — confirm.
    publishedDate: str = ""
    author: str = ""
    # Falls back to 0.5 when no score is supplied.
    score: float = 0.5
|
||||
|
||||
|
||||
class LinkSearchResponse(BaseModel):
    """Envelope returned by the link search endpoint."""

    # Matching links; empty when nothing was found.
    results: List[LinkSearchResult] = Field(default_factory=list)
    # Non-fatal messages passed through from the search service.
    warnings: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class RewordRequest(BaseModel):
    """Input for rewording section text so it incorporates chosen links."""

    section_text: str = Field(..., description="Full section text")
    selected_text: Optional[str] = Field(
        None,
        description="If provided, only reword this portion of the text",
    )
    section_heading: Optional[str] = Field(None, description="Section heading for context")
    # Required; the endpoint additionally rejects an empty list and entries
    # without a 'url'.
    links: List[Dict[str, str]] = Field(
        ...,
        description="List of {'url': str, 'title': str} dicts to incorporate",
    )
|
||||
|
||||
|
||||
class RewordResponse(BaseModel):
    """Envelope returned by the reword-with-links endpoint."""

    # The rewritten text.
    reworded_text: str = ""
    # Non-fatal messages passed through from the service.
    warnings: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.post("/search", response_model=LinkSearchResponse)
async def search_links(
    request: LinkSearchRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Search for internal or external links using Exa.

    Raises:
        HTTPException 400: unknown ``link_type``, missing ``site_url`` for an
            internal search, or a query longer than 500 characters.
        HTTPException 500: the search service raised an unexpected error.
    """
    user_id = require_authenticated_user(current_user)

    # --- Input validation (order determines which 400 the caller sees) ----
    wants_internal = request.link_type == "internal"
    if not wants_internal and request.link_type != "external":
        raise HTTPException(
            status_code=400,
            detail="link_type must be 'internal' or 'external'",
        )
    if wants_internal and not request.site_url:
        raise HTTPException(
            status_code=400,
            detail="site_url is required for internal link search",
        )
    if len(request.query) > 500:
        raise HTTPException(
            status_code=400,
            detail="Query must be 500 characters or less",
        )

    service = get_link_search_service(user_id=user_id)

    # Both search flavours take the same keyword arguments.
    search_args = dict(
        query=request.query,
        site_url=request.site_url,
        user_id=user_id,
        num_results=request.num_results,
    )

    try:
        if wants_internal:
            logger.info(f"[Links] Internal search: query='{request.query[:50]}', site='{request.site_url}', user={user_id}")
            outcome = await service.search_internal(**search_args)
        else:
            logger.info(f"[Links] External search: query='{request.query[:50]}', user={user_id}")
            outcome = await service.search_external(**search_args)

        hits = []
        for raw_hit in outcome.get("results", []):
            hits.append(LinkSearchResult(**raw_hit))

        return LinkSearchResponse(
            results=hits,
            warnings=outcome.get("warnings", []),
        )

    except HTTPException:
        # HTTP errors raised by the service pass through unchanged.
        raise
    except Exception as e:
        logger.error(f"[Links] Search failed: {e}")
        raise HTTPException(status_code=500, detail=f"Link search failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/reword", response_model=RewordResponse)
async def reword_with_links(
    request: RewordRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Reword text to naturally incorporate selected links.

    Raises:
        HTTPException 400: no links supplied, a link lacks a 'url', or
            ``section_text`` exceeds 10000 characters.
        HTTPException (any status): re-raised unchanged when the service
            itself raises one.
        HTTPException 500: the service raised any other exception.
    """
    user_id = require_authenticated_user(current_user)

    if not request.links:
        raise HTTPException(
            status_code=400,
            detail="At least one link must be provided",
        )

    # Validate each link has a url
    for i, link in enumerate(request.links):
        if not link.get("url"):
            raise HTTPException(
                status_code=400,
                detail=f"Link at index {i} is missing a 'url' field",
            )

    if len(request.section_text) > 10000:
        raise HTTPException(
            status_code=400,
            detail="section_text must be 10000 characters or less",
        )

    service = get_link_search_service(user_id=user_id)

    try:
        logger.info(f"[Links] Reword: heading='{request.section_heading}', links={len(request.links)}, user={user_id}")
        result = service.reword_with_links(
            section_text=request.section_text,
            links=request.links,
            section_heading=request.section_heading,
            selected_text=request.selected_text,
            user_id=user_id,
        )

        return RewordResponse(
            # Fall back to the unmodified text if the service returned none.
            reworded_text=result.get("reworded_text", request.section_text),
            warnings=result.get("warnings", []),
        )

    except HTTPException:
        # Keep the status code of HTTP errors raised by the service instead
        # of rewrapping them as a generic 500, matching the search endpoint.
        raise
    except Exception as e:
        logger.error(f"[Links] Reword failed: {e}")
        raise HTTPException(status_code=500, detail=f"Reword failed: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/health")
async def links_health():
    """Liveness probe for the Links router; always reports "ok"."""
    payload = {"status": "ok", "service": "links"}
    return payload
|
||||
@@ -123,3 +123,187 @@ async def stripe_webhook(
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook: {e}")
|
||||
raise HTTPException(status_code=500, detail="Webhook processing failed")
|
||||
|
||||
@router.get("/verify-checkout/{user_id}")
async def verify_checkout_status(
    user_id: str,
    db: Session = Depends(get_db),
    current_user: Dict[str, Any] = Depends(get_current_user),
    request: Request = None
) -> Dict[str, Any]:
    """
    Directly query Stripe for user's current subscription status.
    Used during post-checkout polling to get fresh data without waiting for webhooks.

    Rate limited: 5 requests per minute per user to prevent abuse. Rejected
    requests are also recorded, so continued polling while limited keeps the
    window full.

    Resolution order:
    1. Active subscription found in Stripe (local DB is updated first)
       -> source "stripe_direct".
    2. Active subscription in the local DB -> source "local_db".
    3. An active free plan exists -> source "free_tier".
    4. Otherwise an inactive result with source "none".

    Raises:
        HTTPException 429: more than 5 calls within the last 60 seconds.
        HTTPException 500: an unexpected error outside the best-effort
            Stripe lookups.
        Whatever ``verify_user_access`` raises when the caller may not
        access ``user_id``.
    """
    from ..dependencies import verify_user_access
    from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier
    from services.subscription import PricingService
    from api.subscription.utils import format_plan_limits

    verify_user_access(user_id, current_user)

    # Rate limiting: 5 requests per minute per user
    now = time.time()
    window_start = now - 60  # 1 minute window
    attempts = _checkout_attempts_by_user.setdefault(user_id, [])
    # Drop timestamps that fell out of the window, then record this call.
    attempts[:] = [ts for ts in attempts if ts >= window_start]
    attempts.append(now)

    if len(attempts) > 5:
        client_ip = request.client.host if request and request.client else "unknown"
        logger.warning(f"Verify-checkout rate limit exceeded for user_id={user_id}, ip={client_ip}")
        raise HTTPException(status_code=429, detail="Too many verification requests. Please wait before trying again.")

    stripe_service = StripeService(db)

    try:
        # First, try to find user in local DB
        subscription = db.query(UserSubscription).filter(
            UserSubscription.user_id == user_id
        ).first()

        stripe_customer_id = subscription.stripe_customer_id if subscription else None

        # If no stripe_customer_id in DB, try to find it by email
        if not stripe_customer_id:
            try:
                import stripe
                # Get user email from auth context
                user_email = current_user.get("email")
                if user_email:
                    customers = stripe.Customer.list(email=user_email, limit=1)
                    if customers and customers.data:
                        stripe_customer_id = customers.data[0].id
                        logger.info(f"Verify-checkout: Found Stripe customer by email for user {user_id}")

                        # Update DB with found customer ID
                        if subscription:
                            subscription.stripe_customer_id = stripe_customer_id
                            db.commit()
                        else:
                            logger.info(f"Verify-checkout: No local subscription record for user {user_id}, will query Stripe directly")
            except Exception as email_err:
                # Best effort: fall through to the local-DB / free-tier paths.
                logger.warning(f"Failed to find Stripe customer by email: {email_err}")

        # If user has a Stripe customer ID, query Stripe directly
        if stripe_customer_id:
            try:
                import stripe
                stripe_subscriptions = stripe.Subscription.list(
                    customer=stripe_customer_id,
                    status="active",
                    limit=1
                )

                if stripe_subscriptions and stripe_subscriptions.data:
                    stripe_sub = stripe_subscriptions.data[0]
                    price_id = stripe_sub['items']['data'][0]['price']['id']

                    logger.info(f"Verify-checkout: Found active Stripe subscription for user {user_id}, plan from price {price_id}")

                    # Update local DB with fresh Stripe data
                    stripe_service._update_user_subscription(
                        user_id,
                        stripe_customer_id=stripe_customer_id,
                        stripe_subscription_id=stripe_sub.id,
                        status="active",
                        price_id=price_id
                    )

                    # Clear caches. Failures stay non-fatal but are logged,
                    # because a stale cache would keep showing the old plan.
                    try:
                        PricingService.clear_user_cache(user_id)
                    except Exception as cache_err:
                        logger.warning(f"Verify-checkout: failed to clear pricing cache for user {user_id}: {cache_err}")
                    try:
                        from api.subscription.cache import clear_dashboard_cache
                        clear_dashboard_cache(user_id)
                    except Exception as cache_err:
                        logger.warning(f"Verify-checkout: failed to clear dashboard cache for user {user_id}: {cache_err}")

                    # Discard cached ORM state so the re-query sees the update.
                    db.expire_all()

                    # Re-query with fresh data
                    subscription = db.query(UserSubscription).filter(
                        UserSubscription.user_id == user_id,
                        UserSubscription.is_active == True
                    ).first()

                    if subscription:
                        return {
                            "success": True,
                            "data": {
                                "active": True,
                                "plan": subscription.plan.tier.value,
                                "tier": subscription.plan.tier.value,
                                "can_use_api": True,
                                "limits": format_plan_limits(subscription.plan),
                                "source": "stripe_direct"
                            }
                        }
            except Exception as stripe_err:
                # Best effort: fall back to whatever the local DB says.
                logger.warning(f"Failed to query Stripe directly for user {user_id}: {stripe_err}")

        # Fallback to local DB status
        if subscription and subscription.is_active:
            from services.subscription.pricing_service import PricingService
            pricing = PricingService(db)
            try:
                pricing._ensure_subscription_current(subscription)
            except Exception as refresh_err:
                # Non-fatal: still report the locally stored subscription.
                logger.warning(f"Verify-checkout: failed to refresh subscription for user {user_id}: {refresh_err}")

            return {
                "success": True,
                "data": {
                    "active": True,
                    "plan": subscription.plan.tier.value,
                    "tier": subscription.plan.tier.value,
                    "can_use_api": True,
                    "limits": format_plan_limits(subscription.plan),
                    "source": "local_db"
                }
            }

        # No active subscription - return free tier
        free_plan = db.query(SubscriptionPlan).filter(
            SubscriptionPlan.tier == SubscriptionTier.FREE,
            SubscriptionPlan.is_active == True
        ).first()

        if free_plan:
            return {
                "success": True,
                "data": {
                    "active": True,
                    "plan": "free",
                    "tier": "free",
                    "can_use_api": True,
                    "limits": format_plan_limits(free_plan),
                    "source": "free_tier"
                }
            }

        return {
            "success": True,
            "data": {
                "active": False,
                "plan": "none",
                "tier": "none",
                "can_use_api": False,
                "reason": "No active subscription found",
                "source": "none"
            }
        }

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error verifying checkout status for user {user_id}: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to verify subscription: {str(e)}")
|
||||
|
||||
@@ -156,10 +156,13 @@ class WixPublishRequest(BaseModel):
|
||||
content: str
|
||||
cover_image_url: Optional[str] = None
|
||||
category_ids: Optional[list] = None
|
||||
category_names: Optional[list] = None
|
||||
tag_ids: Optional[list] = None
|
||||
tag_names: Optional[list] = None
|
||||
publish: bool = True
|
||||
# Optional access token for test-real publish flow
|
||||
access_token: Optional[str] = None
|
||||
member_id: Optional[str] = None
|
||||
seo_metadata: Optional[Dict[str, Any]] = None
|
||||
class WixCreateCategoryRequest(BaseModel):
|
||||
access_token: str
|
||||
label: str
|
||||
@@ -398,31 +401,29 @@ async def handle_oauth_callback_get(code: str, state: Optional[str] = None, requ
|
||||
|
||||
|
||||
@router.get("/connection/status")
|
||||
async def get_connection_status(current_user: dict = Depends(get_current_user)) -> WixConnectionStatus:
|
||||
async def get_connection_status(current_user: dict = Depends(get_current_user)) -> Dict[str, Any]:
|
||||
"""
|
||||
Check Wix connection status and permissions
|
||||
|
||||
Args:
|
||||
current_user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Connection status and permissions
|
||||
Check Wix connection status and permissions.
|
||||
Returns connected: false when no tokens are stored (instead of 401).
|
||||
"""
|
||||
try:
|
||||
token_info = _resolve_valid_wix_token(current_user)
|
||||
access_token = token_info["access_token"]
|
||||
site_info = wix_service.get_site_info(access_token)
|
||||
permissions = wix_service.check_blog_permissions(access_token)
|
||||
return WixConnectionStatus(
|
||||
connected=True,
|
||||
has_permissions=permissions.get("has_permissions", False),
|
||||
site_info=site_info,
|
||||
permissions=permissions
|
||||
)
|
||||
return {
|
||||
"connected": True,
|
||||
"has_permissions": permissions.get("has_permissions", False),
|
||||
"site_info": site_info,
|
||||
"permissions": permissions
|
||||
}
|
||||
except HTTPException as e:
|
||||
if e.status_code == 401:
|
||||
return {"connected": False, "has_permissions": False}
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check connection status: {e}")
|
||||
mapped = _map_wix_error(e, "Failed to check Wix connection status")
|
||||
raise mapped
|
||||
return {"connected": False, "has_permissions": False}
|
||||
|
||||
|
||||
@router.get("/status")
|
||||
@@ -450,41 +451,81 @@ async def get_wix_status(current_user: dict = Depends(get_current_user)) -> Dict
|
||||
@router.post("/publish")
|
||||
async def publish_to_wix(request: WixPublishRequest, current_user: dict = Depends(get_current_user)) -> Dict[str, Any]:
|
||||
"""
|
||||
Publish blog post to Wix
|
||||
Publish blog post to Wix using server-stored OAuth tokens.
|
||||
|
||||
Args:
|
||||
request: Blog post data
|
||||
current_user: Current authenticated user
|
||||
|
||||
Returns:
|
||||
Published blog post information
|
||||
The backend resolves the access token from the database (via
|
||||
_resolve_valid_wix_token), so callers do NOT need to pass
|
||||
access_token unless they want to override the stored one.
|
||||
"""
|
||||
try:
|
||||
token_info = _resolve_valid_wix_token(current_user)
|
||||
access_token = token_info["access_token"]
|
||||
if request.access_token:
|
||||
from services.integrations.wix.utils import normalize_token_string
|
||||
access_token = normalize_token_string(request.access_token)
|
||||
else:
|
||||
try:
|
||||
token_info = _resolve_valid_wix_token(current_user)
|
||||
access_token = token_info["access_token"]
|
||||
except HTTPException:
|
||||
access_token = None
|
||||
|
||||
member_id = token_info.get("member_id") or wix_service.extract_member_id_from_access_token(access_token)
|
||||
if not access_token:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Wix account not connected. Connect your Wix account first.",
|
||||
}
|
||||
|
||||
member_id = request.member_id
|
||||
if not member_id:
|
||||
member_id = wix_service.extract_member_id_from_access_token(access_token)
|
||||
if not member_id:
|
||||
member_info = wix_service.get_current_member(access_token)
|
||||
member_id = (member_info.get("member") or {}).get("id") or member_info.get("id")
|
||||
if not member_id:
|
||||
raise HTTPException(status_code=401, detail="Unable to resolve Wix member ID")
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Unable to resolve Wix member ID. Please reconnect your Wix account.",
|
||||
}
|
||||
|
||||
# Resolve categories: accept IDs or names (looked up/created)
|
||||
category_ids = request.category_ids or request.category_names
|
||||
tag_ids = request.tag_ids or request.tag_names
|
||||
|
||||
seo_metadata = request.seo_metadata
|
||||
if seo_metadata:
|
||||
if not category_ids and seo_metadata.get("blog_categories"):
|
||||
category_ids = seo_metadata.get("blog_categories")
|
||||
if not tag_ids and seo_metadata.get("blog_tags"):
|
||||
tag_ids = seo_metadata.get("blog_tags")
|
||||
|
||||
# Ensure category_ids and tag_ids are lists of strings (not ints)
|
||||
if category_ids:
|
||||
category_ids = [str(c) for c in category_ids if c is not None]
|
||||
if tag_ids:
|
||||
tag_ids = [str(t) for t in tag_ids if t is not None]
|
||||
|
||||
result = wix_service.create_blog_post(
|
||||
access_token=access_token,
|
||||
title=request.title,
|
||||
content=request.content,
|
||||
cover_image_url=request.cover_image_url,
|
||||
category_ids=request.category_ids,
|
||||
tag_ids=request.tag_ids,
|
||||
category_ids=category_ids,
|
||||
tag_ids=tag_ids,
|
||||
publish=request.publish,
|
||||
member_id=member_id,
|
||||
seo_metadata=seo_metadata,
|
||||
)
|
||||
post = result.get("draftPost") or result.get("post") or result
|
||||
raw_url = post.get("url")
|
||||
if isinstance(raw_url, dict):
|
||||
post_url = raw_url.get("base", "").rstrip("/") + "/" + raw_url.get("path", "").lstrip("/")
|
||||
elif isinstance(raw_url, str):
|
||||
post_url = raw_url
|
||||
else:
|
||||
post_url = None
|
||||
return {
|
||||
"success": True,
|
||||
"post_id": post.get("id"),
|
||||
"url": post.get("url"),
|
||||
"post_id": str(post.get("id", "")),
|
||||
"url": post_url,
|
||||
"publish_state": "PUBLISHED" if request.publish else "DRAFT"
|
||||
}
|
||||
except Exception as e:
|
||||
|
||||
@@ -55,6 +55,8 @@ async def suggest_endpoint(req: SuggestRequest, current_user: Dict[str, Any] = D
|
||||
for s in suggestions
|
||||
],
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Writing assistant error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -147,13 +147,26 @@ else:
|
||||
product_marketing_router = None
|
||||
campaign_creator_router = None
|
||||
|
||||
# Import hallucination detector router (skip in feature-only modes - triggers heavy ML)
|
||||
if _is_full_mode():
|
||||
# Import hallucination detector router
|
||||
try:
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
else:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import hallucination_detector router: {e}")
|
||||
hallucination_detector_router = None
|
||||
writing_assistant_router = None
|
||||
|
||||
# Import charts router (shared chart generation for blog writer, podcast, etc.)
|
||||
try:
|
||||
from api.charts import router as charts_router
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import charts router: {e}")
|
||||
charts_router = None
|
||||
|
||||
# Import links router (internal & external link search and rewording)
|
||||
try:
|
||||
from api.links import router as links_router
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import links router: {e}")
|
||||
links_router = None
|
||||
|
||||
# Import research configuration router (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
@@ -486,10 +499,18 @@ else:
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Safety net: explicitly include hallucination detector (router_manager may skip silently)
|
||||
# Safety net: explicitly include hallucination detector (import may fail gracefully)
|
||||
if hallucination_detector_router:
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
|
||||
# Include charts router (shared chart generation)
|
||||
if charts_router:
|
||||
router_manager.include_router_safely(charts_router, "charts")
|
||||
|
||||
# Include links router (internal & external link search)
|
||||
if links_router:
|
||||
router_manager.include_router_safely(links_router, "links")
|
||||
|
||||
# Log startup summary
|
||||
router_manager.log_startup_summary()
|
||||
|
||||
|
||||
@@ -81,6 +81,8 @@ from routers.campaign_creator import router as campaign_creator_router
|
||||
# Import hallucination detector router
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
from api.charts import router as charts_router
|
||||
from api.links import router as links_router
|
||||
|
||||
# Import research configuration router
|
||||
from api.research_config import router as research_config_router
|
||||
@@ -254,6 +256,10 @@ router_manager.include_core_routers()
|
||||
router_manager.include_router_safely(subscription_router, "subscription")
|
||||
# Include hallucination detector explicitly (router_manager may skip silently on import failure)
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
# Include charts router (shared chart generation for blog writer, podcast, etc.)
|
||||
router_manager.include_router_safely(charts_router, "charts")
|
||||
# Include links router (internal & external link search and rewording)
|
||||
router_manager.include_router_safely(links_router, "links")
|
||||
router_manager.include_optional_routers()
|
||||
|
||||
# SEO Dashboard endpoints
|
||||
|
||||
@@ -157,6 +157,9 @@ class BlogOutlineSection(BaseModel):
|
||||
references: List[ResearchSource] = []
|
||||
target_words: Optional[int] = None
|
||||
keywords: List[str] = []
|
||||
chart_data: Optional[Dict[str, Any]] = None
|
||||
chart_url: Optional[str] = None
|
||||
chart_id: Optional[str] = None
|
||||
|
||||
|
||||
class BlogOutlineRequest(BaseModel):
|
||||
|
||||
@@ -8,6 +8,7 @@ from loguru import logger
|
||||
import os
|
||||
|
||||
from services.gsc_service import GSCService
|
||||
from services.gsc_brainstorm_service import GSCBrainstormService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
# Initialize router
|
||||
@@ -15,6 +16,7 @@ router = APIRouter(prefix="/gsc", tags=["Google Search Console"])
|
||||
|
||||
# Initialize GSC service
|
||||
gsc_service = GSCService()
|
||||
brainstorm_service = GSCBrainstormService(gsc_service)
|
||||
|
||||
# Pydantic models
|
||||
class GSCAnalyticsRequest(BaseModel):
|
||||
@@ -22,6 +24,10 @@ class GSCAnalyticsRequest(BaseModel):
|
||||
start_date: Optional[str] = None
|
||||
end_date: Optional[str] = None
|
||||
|
||||
class GSCBrainstormRequest(BaseModel):
    """Request body for the POST /gsc/brainstorm endpoint."""

    # Seed phrase for brainstorming; the endpoint rejects input with fewer
    # than 3 whitespace-separated words.
    keywords: str
    # Optional GSC property to analyse; the endpoint forwards it unchanged
    # to the brainstorm service (None lets the service pick a site).
    site_url: Optional[str] = None
|
||||
|
||||
class GSCStatusResponse(BaseModel):
|
||||
connected: bool
|
||||
sites: Optional[List[Dict[str, Any]]] = None
|
||||
@@ -199,6 +205,49 @@ async def get_gsc_analytics(
|
||||
logger.error(f"Error getting GSC analytics: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Error getting analytics: {str(e)}")
|
||||
|
||||
@router.post("/brainstorm")
async def brainstorm_topics(
    request: GSCBrainstormRequest,
    user: dict = Depends(get_current_user),
):
    """Brainstorm blog topic suggestions based on the user's GSC data.

    The user must have GSC connected. If no site_url is provided,
    the first verified site is used automatically.

    Args:
        request: Seed keywords (at least 3 whitespace-separated words) and
            an optional site URL.
        user: Authenticated user dict; its ``id`` is forwarded to the service.

    Returns:
        The result returned by ``brainstorm_service.brainstorm_topics``.

    Raises:
        HTTPException: 400 when the user ID is missing, fewer than 3 words
            are supplied, or the service reports "No GSC sites"; 500 for any
            other service error or unexpected exception.
    """
    try:
        user_id = user.get('id')
        if not user_id:
            raise HTTPException(status_code=400, detail="User ID not found")

        tokens = request.keywords.split()
        if len(tokens) < 3:
            raise HTTPException(
                status_code=400,
                detail="Please provide at least 3 words for brainstorming topic suggestions.",
            )

        # Hand the service the same normalized text that was validated above
        # (no leading/trailing or repeated whitespace).
        keywords = " ".join(tokens)
        logger.info(f"GSC brainstorm for user: {user_id}, keywords: {keywords!r}")

        result = brainstorm_service.brainstorm_topics(
            user_id=user_id,
            keywords=keywords,
            site_url=request.site_url,
        )

        # An "error" key is only fatal when no opportunities were produced;
        # partial results are still returned to the caller.
        if "error" in result and not result.get("content_opportunities"):
            error_status = 400 if "No GSC sites" in result["error"] else 500
            raise HTTPException(status_code=error_status, detail=result["error"])

        logger.info(f"GSC brainstorm completed for user: {user_id}")
        return result

    except HTTPException:
        # Preserve deliberate HTTP errors instead of wrapping them in a 500.
        raise
    except Exception as e:
        logger.error(f"Error in GSC brainstorm: {e}")
        raise HTTPException(status_code=500, detail=f"Error brainstorming topics: {str(e)}")
|
||||
|
||||
@router.get("/sitemaps/{site_url:path}")
|
||||
async def get_gsc_sitemaps(
|
||||
site_url: str,
|
||||
|
||||
@@ -269,16 +269,18 @@ class MediumBlogGenerator:
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=full_content,
|
||||
source_module="medium_blog_writer",
|
||||
source_module="blog_writer",
|
||||
title=result.title,
|
||||
description=f"Generated medium blog: {result.title}",
|
||||
tags=req.researchKeywords or ["medium_blog", "ai_generated"],
|
||||
description=f"Blog: {result.title}",
|
||||
tags=req.researchKeywords or ["blog", "ai_generated"],
|
||||
asset_metadata={
|
||||
"blog_type": "medium",
|
||||
"model": result.model,
|
||||
"generation_time_ms": result.generation_time_ms,
|
||||
"word_count": sum(s.wordCount for s in result.sections)
|
||||
"word_count": sum(s.wordCount for s in result.sections),
|
||||
"section_count": len(result.sections),
|
||||
},
|
||||
subdirectory="medium_blogs"
|
||||
subdirectory="blogs"
|
||||
)
|
||||
logger.info(f"Saved medium blog content to user workspace for user {user_id}")
|
||||
except Exception as e:
|
||||
|
||||
@@ -6,8 +6,11 @@ Neural search implementation using Exa API for high-quality, citation-rich resea
|
||||
|
||||
from exa_py import Exa
|
||||
import os
|
||||
import asyncio
|
||||
from typing import List, Dict, Any
|
||||
from loguru import logger
|
||||
from models.subscription_models import APIProvider
|
||||
from fastapi import HTTPException
|
||||
from .base_provider import ResearchProvider as BaseProvider
|
||||
|
||||
|
||||
@@ -216,6 +219,123 @@ class ExaResearchProvider(BaseProvider):
|
||||
"""Estimate token usage for Exa (not token-based)."""
|
||||
return 0 # Exa is per-search, not token-based
|
||||
|
||||
async def simple_search(
    self,
    query: str,
    num_results: int = 5,
    user_id: str = None,
    include_domains: List[str] = None,
    exclude_domains: List[str] = None,
) -> List[Dict[str, Any]]:
    """
    Simple Exa search for fact-checking and writing assistance.
    Handles subscription preflight check and usage tracking.

    Args:
        query: Search query string
        num_results: Number of results to return (default 5)
        user_id: Optional user ID for subscription checking
        include_domains: Only return results from these domains (for internal links)
        exclude_domains: Exclude results from these domains (for external-only links)

    Returns:
        List of source dicts with title, url, text, publishedDate, author, score keys

    Raises:
        HTTPException(429): If user has exceeded subscription limits
        Exception: If Exa API key not configured
        RuntimeError: If both the initial search and the simplified retry fail
    """
    if not self.api_key:
        raise Exception("EXA_API_KEY not configured")

    # Preflight subscription check
    if user_id:
        from services.subscription import PricingService
        from services.database import get_session_for_user
        db = get_session_for_user(user_id)
        if db:
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.EXA,
                    tokens_requested=0,
                    actual_provider_name="exa",
                )
                if not can_proceed:
                    raise HTTPException(status_code=429, detail={
                        'error': 'insufficient_balance',
                        'message': message,
                        'provider': 'exa',
                        'usage_info': usage_info or {}
                    })
            except HTTPException:
                raise
            except Exception as e:
                # Best effort: a broken preflight must not block the search.
                logger.warning(f"[Exa simple_search] Preflight check failed: {e}")
            finally:
                try:
                    db.close()
                except Exception:
                    pass

    # Domain filters are shared by the primary request and the retry.
    domain_filters = {}
    if include_domains:
        domain_filters["include_domains"] = include_domains
    if exclude_domains:
        domain_filters["exclude_domains"] = exclude_domains

    search_kwargs = {
        "type": "auto",
        "num_results": num_results,
        "text": {"max_characters": 1000},
        "highlights": {"num_sentences": 2, "highlights_per_url": 2},
        **domain_filters,
    }
    retry_kwargs = {"type": "auto", "num_results": num_results, "text": True, **domain_filters}

    # Resolve the loop before the try so the retry path can always use it.
    loop = asyncio.get_running_loop()
    try:
        # The Exa client call is blocking; run it in the default executor.
        results = await loop.run_in_executor(
            None,
            lambda: self.exa.search_and_contents(query, **search_kwargs),
        )
    except Exception as e:
        logger.error(f"[Exa simple_search] API call failed: {e}")
        # Retry with simpler parameters
        try:
            logger.info("[Exa simple_search] Retrying with simplified parameters")
            results = await loop.run_in_executor(
                None,
                lambda: self.exa.search_and_contents(query, **retry_kwargs),
            )
        except Exception as retry_error:
            logger.error(f"[Exa simple_search] Retry also failed: {retry_error}")
            raise RuntimeError(f"Exa search failed: {str(retry_error)}") from retry_error

    def _first_set(obj, names, default):
        """Return the first non-None attribute among *names*, else *default*."""
        for name in names:
            value = getattr(obj, name, None)
            if value is not None:
                return value
        return default

    sources = []
    for result in (getattr(results, 'results', None) or []):
        sources.append({
            'title': _first_set(result, ('title',), 'Untitled'),
            'url': _first_set(result, ('url',), ''),
            'text': _first_set(result, ('text',), ''),
            # The Exa SDK exposes snake_case ``published_date``; the camelCase
            # name is kept as a fallback for objects that carry it.
            'publishedDate': _first_set(result, ('published_date', 'publishedDate'), ''),
            'author': _first_set(result, ('author',), ''),
            'score': _first_set(result, ('score',), 0.5),
        })

    # Track usage
    if user_id:
        cost = 0.005  # ~0.5 cents per search
        try:
            self.track_exa_usage(user_id, cost)
        except Exception as e:
            logger.warning(f"[Exa simple_search] Failed to track usage: {e}")

    logger.info(f"[Exa simple_search] Found {len(sources)} sources for query: {query[:80]}...")
    return sources
|
||||
|
||||
def _map_source_type_to_category(self, source_types):
|
||||
"""Map SourceType enum to Exa category parameter."""
|
||||
if not source_types:
|
||||
|
||||
951
backend/services/chart_service.py
Normal file
951
backend/services/chart_service.py
Normal file
@@ -0,0 +1,951 @@
|
||||
"""
|
||||
Chart Service — Shared chart generation for Blog Writer, Podcast Maker, and future modules.
|
||||
|
||||
Extracts the chart rendering logic from podcast/broll_composer into a reusable service
|
||||
that any module can call. Supports:
|
||||
- Direct chart rendering (caller provides chart_type + chart_data)
|
||||
- AI-driven chart inference (caller provides text, LLM infers chart_type + chart_data)
|
||||
|
||||
Chart types: bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet_points
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List
|
||||
from dataclasses import dataclass, field
|
||||
from loguru import logger
|
||||
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
CHART_STYLE = {
|
||||
"bg": "#0D0D0D",
|
||||
"bar_before": "#2E4057",
|
||||
"bar_after": "#E63946",
|
||||
"text": "#F1F1EF",
|
||||
"grid": "#2A2A2A",
|
||||
"accent": "#E63946",
|
||||
"pie_colors": ["#E63946", "#2E4057", "#457B9D", "#A8DADC", "#F4A261", "#2A9D8F"],
|
||||
}
|
||||
|
||||
VALID_CHART_TYPES = [
|
||||
"bar_comparison", "bar_chart_comparison",
|
||||
"bar_horizontal", "line_trend",
|
||||
"pie", "stacked_bar",
|
||||
"bullet", "bullet_points",
|
||||
]
|
||||
|
||||
CHART_INFERENCE_SYSTEM_PROMPT = """You are a data visualization expert. Given text content, determine the most appropriate chart type and extract structured data for rendering.
|
||||
|
||||
You MUST respond with ONLY a valid JSON object (no markdown, no explanation) with this exact structure:
|
||||
{
|
||||
"chart_type": "one of: bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet_points",
|
||||
"chart_data": { ... appropriate data structure for the chart type ... },
|
||||
"title": "A clear, concise chart title"
|
||||
}
|
||||
|
||||
Chart data structures by type:
|
||||
- bar_comparison: {"labels": [...], "before": [...], "after": [...]} OR {"labels": [...], "values": [...]}
|
||||
- bar_horizontal: {"labels": [...], "values": [...]}
|
||||
- line_trend: {"labels": [...], "values": [...]}
|
||||
- pie: {"labels": [...], "values": [...]}
|
||||
- stacked_bar: {"labels": [...], "stacks": [[...], [...]]}
|
||||
- bullet_points: {"bullet_points": [...]}
|
||||
|
||||
Rules:
|
||||
1. Choose the chart type that best represents the information in the text.
|
||||
2. Use bar_comparison for before/after comparisons.
|
||||
3. Use line_trend for time-series or sequential data.
|
||||
4. Use pie for proportional breakdowns of a whole.
|
||||
5. Use bar_horizontal for rankings or comparisons.
|
||||
6. Use bullet_points if the text is qualitative with no strong numeric data.
|
||||
7. Extract realistic numeric values from the text when available.
|
||||
8. If no data is extractable, use bullet_points and list key points.
|
||||
9. Keep labels short (under 20 chars)."""
|
||||
|
||||
|
||||
CHART_INFERENCE_USER_PROMPT = """Create a chart from this text:
|
||||
|
||||
{text}
|
||||
|
||||
Return ONLY the JSON object with chart_type, chart_data, and title."""
|
||||
|
||||
|
||||
CHART_ANALYSIS_SYSTEM_PROMPT = """You are a data visualization analyst. Given text from a blog section, your job is to:
|
||||
1. Determine whether the text contains enough specific numeric data to create a meaningful chart
|
||||
2. If YES: explain what data is available and suggest a chart type
|
||||
3. If NO: suggest 2-3 specific search queries that would find relevant statistics/data to create a chart for this topic
|
||||
|
||||
You MUST respond with ONLY a valid JSON object (no markdown, no explanation):
|
||||
{
|
||||
"has_data": true/false,
|
||||
"data_description": "brief description of what data is available or why it's insufficient",
|
||||
"suggested_chart_type": "best chart type if has_data is true, otherwise null",
|
||||
"search_queries": ["query1", "query2", "query3"] // Empty array if has_data is true
|
||||
}
|
||||
|
||||
Be optimistic — if there's ANY numeric claim, percentage, comparison, or trend in the text, set has_data to true.
|
||||
Only set has_data to false if the text is purely qualitative with no numbers, percentages, comparisons, or trends."""
|
||||
|
||||
|
||||
CHART_ANALYSIS_USER_PROMPT = """Analyze this text for chart potential:
|
||||
|
||||
Section: {section_heading}
|
||||
{key_points_section}
|
||||
Text: {text}
|
||||
|
||||
Determine if this text contains enough data for a chart, or suggest search queries to find the data."""
|
||||
|
||||
|
||||
CHART_SYNTHESIS_SYSTEM_PROMPT = """You are a data visualization expert. You have been given:
|
||||
1. Original text from a blog section
|
||||
2. Research data found from web searches
|
||||
|
||||
Create a chart that visualizes the most interesting insight from the combination of the original text and research data.
|
||||
|
||||
You MUST respond with ONLY a valid JSON object (no markdown, no explanation) with this exact structure:
|
||||
{
|
||||
"chart_type": "one of: bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet_points",
|
||||
"chart_data": { ... appropriate data structure ... },
|
||||
"title": "A clear, concise chart title",
|
||||
"source": "Brief source attribution"
|
||||
}
|
||||
|
||||
Chart data structures by type:
|
||||
- bar_comparison: {"labels": [...], "before": [...], "after": [...]} OR {"labels": [...], "values": [...]}
|
||||
- bar_horizontal: {"labels": [...], "values": [...]}
|
||||
- line_trend: {"labels": [...], "values": [...]}
|
||||
- pie: {"labels": [...], "values": [...]}
|
||||
- stacked_bar: {"labels": [...], "stacks": [[...], [...]]}
|
||||
- bullet_points: {"bullet_points": [...]}
|
||||
|
||||
Rules:
|
||||
1. Use the research data to create accurate, fact-based charts
|
||||
2. Prefer bar_comparison for before/after or categorical comparisons
|
||||
3. Prefer line_trend for trends over time
|
||||
4. Prefer pie for market share or proportional breakdowns
|
||||
5. Keep labels short (under 20 characters)
|
||||
6. Use realistic values from the research — do NOT invent numbers
|
||||
7. Always include a source attribution based on where the data came from
|
||||
8. If the research doesn't contain useful numeric data, fall back to bullet_points with key insights"""
|
||||
|
||||
|
||||
CHART_SYNTHESIS_USER_PROMPT = """Original text:
|
||||
{text}
|
||||
|
||||
Research data found:
|
||||
{research}
|
||||
|
||||
Create a chart that visualizes the most interesting data insight from the combination above."""
|
||||
|
||||
|
||||
def _normalize_chart_type(chart_type: str) -> str:
|
||||
"""Normalize chart type aliases."""
|
||||
mapping = {
|
||||
"bar_chart_comparison": "bar_comparison",
|
||||
"bullet": "bullet_points",
|
||||
}
|
||||
return mapping.get(chart_type, chart_type)
|
||||
|
||||
|
||||
def _add_source_overlay(image_path: str, source: str) -> None:
|
||||
"""Add a source attribution overlay to a chart image (in-place)."""
|
||||
if not source or not os.path.exists(image_path):
|
||||
return
|
||||
try:
|
||||
img = Image.open(image_path).convert("RGBA")
|
||||
draw = ImageDraw.Draw(img)
|
||||
source_text = f"Source: {source[:80]}"
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 11)
|
||||
except (OSError, IOError):
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 11)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
text_bbox = draw.textbbox((0, 0), source_text, font=font)
|
||||
text_w = text_bbox[2] - text_bbox[0]
|
||||
text_h = text_bbox[3] - text_bbox[1]
|
||||
x = img.width - text_w - 12
|
||||
y = img.height - text_h - 8
|
||||
draw.rectangle([x - 4, y - 2, x + text_w + 4, y + text_h + 2], fill=(0, 0, 0, 140))
|
||||
draw.text((x, y), source_text, fill=(200, 200, 200, 220), font=font)
|
||||
img.save(image_path)
|
||||
except Exception as e:
|
||||
logger.warning(f"[ChartService] Source overlay failed (non-fatal): {e}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chart generators (Matplotlib → PNG with transparency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_bar_chart(data: dict, out_path: str, title: str = "",
                   show_legend: bool = True, value_suffix: str = "%",
                   subtitle: str = "") -> str:
    """Render a grouped before/after bar chart and save it to ``out_path``.

    Args:
        data: ``{"labels", "before", "after"}``, or ``{"labels", "values"}``
            in which case ``values`` is plotted as the "After" series next
            to an all-zero "Before" series.
        out_path: Destination file for the rendered figure.
        title: Optional chart title.
        show_legend: Whether to draw the Before/After legend.
        value_suffix: Text appended to each bar's value label.
        subtitle: Optional italic caption drawn under the plot.

    Returns:
        ``out_path`` on success, or ``""`` when there is nothing to plot
        (matching the other chart generators).
    """
    labels = list(data.get("labels") or [])
    before = list(data.get("before") or [])
    after = list(data.get("after") or [])

    if not before and not after:
        # Single-series input: plot it as the "After" bars.
        after = list(data.get("values") or [])

    # Bring all series to one common length so matplotlib does not raise on
    # ragged input; a missing series is padded with zeros.
    if before and after:
        n = min(len(before), len(after))
    else:
        n = max(len(before), len(after))
    if labels:
        n = min(n, len(labels))
    if n == 0:
        return ""

    labels = labels[:n] if labels else [str(i + 1) for i in range(n)]
    before = (before + [0] * n)[:n]
    after = (after + [0] * n)[:n]

    fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
    try:
        ax.set_facecolor("none")

        x = np.arange(n)
        w = 0.35
        bars_b = ax.bar(x - w / 2, before, w, color=CHART_STYLE["bar_before"],
                        label="Before", zorder=3, edgecolor="none")
        bars_a = ax.bar(x + w / 2, after, w, color=CHART_STYLE["bar_after"],
                        label="After", zorder=3, edgecolor="none")

        ax.set_xticks(x)
        ax.set_xticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
        ax.tick_params(axis="y", colors=CHART_STYLE["text"])
        ax.spines[:].set_visible(False)
        ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
        ax.set_axisbelow(True)

        # Value label just above each bar.
        for bar in [*bars_b, *bars_a]:
            h = bar.get_height()
            ax.text(bar.get_x() + bar.get_width() / 2, h + 0.5, f"{h:.0f}{value_suffix}",
                    ha="center", va="bottom", color=CHART_STYLE["text"], fontsize=9,
                    fontweight="bold")

        if show_legend:
            ax.legend(frameon=False, labelcolor=CHART_STYLE["text"],
                      fontsize=10, loc="upper left")

        if title:
            ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
                         fontweight="bold", pad=12)
        if subtitle:
            fig.text(0.5, 0.02, subtitle, ha='center', color=CHART_STYLE["text"],
                     fontsize=10, style='italic')

        # Reserve a strip at the bottom for the subtitle when present.
        fig.tight_layout(pad=0.5, rect=(0, 0.03 if subtitle else 0, 1, 1))
        fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
    finally:
        # Always release the figure; pyplot keeps figures alive until closed.
        plt.close(fig)
    return out_path
|
||||
|
||||
|
||||
def make_horizontal_bar(data: dict, out_path: str, title: str = "",
                        value_suffix: str = "%", bar_color: Optional[str] = None) -> str:
    """Render a horizontal bar chart and save it to ``out_path``.

    Args:
        data: ``{"labels": [...], "values": [...]}``; ``"y"`` is accepted as
            a fallback key for the values.
        out_path: Destination file for the rendered figure.
        title: Optional chart title.
        value_suffix: Text appended to each bar's value label.
        bar_color: Bar colour; defaults to ``CHART_STYLE["bar_after"]``.

    Returns:
        ``out_path`` on success, or ``""`` when there are no values.
    """
    values = data.get("values", data.get("y", []))
    if not values:
        return ""

    values = list(values)
    labels = list(data.get("labels") or [])
    # Align labels and values; matplotlib raises on a length mismatch.
    if labels:
        n = min(len(labels), len(values))
        if n == 0:
            return ""
        labels, values = labels[:n], values[:n]
    else:
        labels = [str(i + 1) for i in range(len(values))]

    bar_color = bar_color or CHART_STYLE["bar_after"]

    fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
    try:
        ax.set_facecolor("none")

        y_pos = np.arange(len(labels))
        bars = ax.barh(y_pos, values, color=bar_color, zorder=3, edgecolor="none", height=0.6)

        ax.set_yticks(y_pos)
        ax.set_yticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
        ax.tick_params(axis="x", colors=CHART_STYLE["text"])
        ax.spines[:].set_visible(False)
        ax.xaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
        ax.set_axisbelow(True)
        # First label at the top, reading downwards.
        ax.invert_yaxis()

        # Value label just past the end of each bar.
        for bar in bars:
            width = bar.get_width()
            ax.text(width + 0.5, bar.get_y() + bar.get_height()/2, f"{width:.0f}{value_suffix}",
                    ha="left", va="center", color=CHART_STYLE["text"], fontsize=10,
                    fontweight="bold")

        if title:
            ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
                         fontweight="bold", pad=12)

        fig.tight_layout(pad=0.5)
        fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    return out_path
|
||||
|
||||
|
||||
def make_pie_chart(data: dict, out_path: str, title: str = "",
                   show_labels: bool = True, show_percent: bool = True,
                   donut: bool = False) -> str:
    """Render a pie (or donut) chart and save it to ``out_path``.

    Args:
        data: ``{"labels": [...], "values": [...]}``; ``"y"`` is accepted as
            a fallback key for the values.
        out_path: Destination file for the rendered figure.
        title: Optional chart title.
        show_labels: Whether to draw the slice labels.
        show_percent: Whether to draw the percentage inside each slice.
        donut: Draw a ring instead of a full pie.

    Returns:
        ``out_path`` on success, or ``""`` when the values are missing,
        non-numeric, negative, or sum to zero (a pie cannot be drawn).
    """
    raw_values = data.get("values", data.get("y", []))
    if not raw_values:
        return ""
    try:
        values = [float(v) for v in raw_values]
    except (TypeError, ValueError):
        return ""

    labels = list(data.get("labels") or [])
    if labels:
        # Align labels and values; matplotlib raises on a length mismatch.
        n = min(len(labels), len(values))
        labels, values = labels[:n], values[:n]

    if any(v < 0 for v in values) or sum(values) <= 0:
        return ""

    # Repeat the palette when there are more slices than colours.
    palette = CHART_STYLE["pie_colors"]
    colors = [palette[i % len(palette)] for i in range(len(values))]

    pie_kwargs = dict(
        labels=labels if (show_labels and labels) else None,
        colors=colors,
        # autopct is always supplied so ax.pie returns three values.
        autopct=lambda p: f'{p:.1f}%' if show_percent else '',
        startangle=90,
        pctdistance=0.75 if donut else 0.8,
    )
    if donut:
        pie_kwargs["wedgeprops"] = dict(width=0.5, edgecolor="none")

    fig, ax = plt.subplots(figsize=(6, 4.5), facecolor="none")
    try:
        ax.set_facecolor("none")

        wedges, texts, autotexts = ax.pie(values, **pie_kwargs)

        for text in texts:
            text.set_color(CHART_STYLE["text"])
            text.set_fontsize(10)

        for autotext in autotexts:
            autotext.set_color(CHART_STYLE["text"])
            autotext.set_fontsize(9)
            autotext.set_fontweight("bold")

        if title:
            ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
                         fontweight="bold", pad=12)

        fig.tight_layout(pad=0.5)
        fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    return out_path
|
||||
|
||||
|
||||
def make_stacked_bar(data: dict, out_path: str, title: str = "",
                     stack_labels: Optional[list] = None) -> str:
    """Render a stacked bar chart and save it to ``out_path``.

    Args:
        data: ``{"labels": [...], "stacks": [[...], [...], ...]}`` where each
            inner list is one series with one value per label.
        out_path: Destination file for the rendered figure.
        title: Optional chart title.
        stack_labels: Legend names per series; missing names default to
            ``"Series N"``.

    Returns:
        ``out_path`` on success, or ``""`` when fewer than two series are
        given or the series contain no plottable values.
    """
    stacks = data.get("stacks", [])
    if not stacks or len(stacks) < 2:
        return ""

    stacks = [list(s or []) for s in stacks]
    labels = list(data.get("labels") or [])

    # Trim every series (and the labels) to a common length so ragged input
    # does not make matplotlib raise.
    n = min(len(s) for s in stacks)
    if labels:
        n = min(n, len(labels))
    if n == 0:
        return ""
    labels = labels[:n] if labels else [str(i + 1) for i in range(n)]
    stacks = [s[:n] for s in stacks]

    # Fill in any legend names the caller did not provide.
    provided = list(stack_labels or [])
    series_names = [
        provided[i] if i < len(provided) else f"Series {i+1}"
        for i in range(len(stacks))
    ]

    fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
    try:
        ax.set_facecolor("none")

        x = np.arange(n)
        bottom = np.zeros(n)
        palette = CHART_STYLE["pie_colors"]

        for i, stack in enumerate(stacks):
            # Cycle the palette so more series than colours cannot overflow it.
            bars = ax.bar(x, stack, 0.6, bottom=bottom, color=palette[i % len(palette)],
                          label=series_names[i], zorder=3, edgecolor="none")

            for j, bar in enumerate(bars):
                height = bar.get_height()
                # Only annotate segments taller than 5 units.
                if height > 5:
                    ax.text(bar.get_x() + bar.get_width()/2,
                            bottom[j] + height/2,
                            f"{height:.0f}", ha="center", va="center",
                            color=CHART_STYLE["text"], fontsize=8, fontweight="bold")

            bottom = bottom + np.array(stack)

        ax.set_xticks(x)
        ax.set_xticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
        ax.tick_params(axis="y", colors=CHART_STYLE["text"])
        ax.spines[:].set_visible(False)
        ax.legend(frameon=False, labelcolor=CHART_STYLE["text"], fontsize=9, loc="upper left")

        if title:
            ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
                         fontweight="bold", pad=12)

        fig.tight_layout(pad=0.5)
        fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    return out_path
|
||||
|
||||
|
||||
def make_line_trend(data: dict, out_path: str, title: str = "") -> str:
    """Render a single-series line chart to ``out_path`` and return that path.

    Args:
        data: Mapping with x values under ``"labels"`` (or ``"x"``) and
            y values under ``"values"`` (or ``"y"``).
        out_path: Destination file for the rendered image.
        title: Optional chart title; no title is drawn when empty.

    Returns:
        ``out_path`` on success, or ``""`` when either axis has no data.
    """
    x_labels = data.get("labels", data.get("x", []))
    y_vals = data.get("values", data.get("y", []))

    if not x_labels or not y_vals:
        return ""

    # Numeric labels are plotted at their own values. Anything else is placed
    # at 0..n-1 and shown as text tick labels.
    try:
        x_vals = [float(v) for v in x_labels]
        numeric_x = True
    except (ValueError, TypeError):
        x_vals = list(range(len(x_labels)))
        numeric_x = False

    fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
    try:
        ax.set_facecolor("none")

        ax.plot(x_vals, y_vals, color=CHART_STYLE["accent"],
                linewidth=2.5, marker="o", markersize=7, zorder=3)
        ax.fill_between(x_vals, y_vals, alpha=0.12, color=CHART_STYLE["accent"])
        ax.spines[:].set_visible(False)
        ax.tick_params(colors=CHART_STYLE["text"])
        ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)

        if not numeric_x:
            ax.set_xticks(x_vals)
            ax.set_xticklabels(x_labels, color=CHART_STYLE["text"], fontsize=10)

        if title:
            ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
                         fontweight="bold", pad=12)
        fig.tight_layout(pad=0.5)
        fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving raises.
        plt.close(fig)
    return out_path
|
||||
|
||||
|
||||
def make_bullet_overlay(lines: list, out_path: str,
                        width: int = 900, font_size: int = 32) -> str:
    """Draw a bulleted list on a translucent rounded panel and save it as PNG.

    Args:
        lines: Bullet texts, drawn one per row in the given order.
        out_path: Destination file for the PNG.
        width: Panel width in pixels.
        font_size: Size requested for the DejaVu Sans Bold font; the row
            height is derived from it.

    Returns:
        ``out_path``.
    """
    margin = 32
    row_height = font_size + 16
    canvas_height = margin * 2 + len(lines) * row_height + 12

    canvas = Image.new("RGBA", (width, canvas_height), (0, 0, 0, 0))
    pen = ImageDraw.Draw(canvas)

    # Semi-transparent dark backdrop covering the whole canvas.
    pen.rounded_rectangle([0, 0, width - 1, canvas_height - 1],
                          radius=18, fill=(10, 10, 10, 185))

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
                                  font_size)
    except OSError:
        # The font file is not available; use Pillow's built-in default font.
        font = ImageFont.load_default()

    text_x = margin + 18
    for row, line in enumerate(lines):
        text_y = margin + row * row_height
        pen.text((text_x, text_y), f"\u2022 {line}", font=font, fill=(241, 241, 239, 255))

    canvas.save(out_path, format="PNG")
    return out_path
|
||||
|
||||
|
||||
# Chart-type name -> module-level renderer. Several names are aliases for the
# same function. make_bullet_overlay takes a list of lines as its first
# argument, unlike make_stacked_bar and make_line_trend, which take a data dict.
CHART_RENDERERS = {
    **dict.fromkeys(("bar_comparison", "bar_chart_comparison"), make_bar_chart),
    "bar_horizontal": make_horizontal_bar,
    "line_trend": make_line_trend,
    "pie": make_pie_chart,
    "stacked_bar": make_stacked_bar,
    **dict.fromkeys(("bullet_points", "bullet"), make_bullet_overlay),
}
|
||||
|
||||
|
||||
class ChartService:
|
||||
"""Shared chart generation service for all modules."""
|
||||
|
||||
def __init__(self, output_dir: Optional[str] = None, user_id: Optional[str] = None):
    """Set up the directory that rendered charts are written to.

    Args:
        output_dir: Explicit output directory; takes precedence when given.
        user_id: Passed to ``_default_chart_dir`` when ``output_dir`` is
            not given.
    """
    self.output_dir = Path(output_dir) if output_dir else self._default_chart_dir(user_id)

    # Create the directory up front so later writes do not have to.
    self.output_dir.mkdir(parents=True, exist_ok=True)
    logger.info(f"[ChartService] Initialized with output directory: {self.output_dir}")
|
||||
|
||||
@staticmethod
|
||||
def _default_chart_dir(user_id: Optional[str] = None) -> Path:
|
||||
"""Get default chart directory (workspace-aware if user_id provided)."""
|
||||
if user_id:
|
||||
try:
|
||||
from api.podcast.constants import get_podcast_media_dir
|
||||
return get_podcast_media_dir("chart", user_id, ensure_exists=True)
|
||||
except Exception:
|
||||
pass
|
||||
base = Path.home() / ".alwrity" / "charts"
|
||||
base.mkdir(parents=True, exist_ok=True)
|
||||
return base
|
||||
|
||||
def get_output_path(self, filename: str) -> Path:
|
||||
return self.output_dir / filename
|
||||
|
||||
def get_chart_preview_path(self, chart_id: str) -> Path:
|
||||
return self.get_output_path(f"chart_preview_{chart_id}.png")
|
||||
|
||||
def generate_chart(
    self,
    chart_data: Dict[str, Any],
    chart_type: str = "bar_comparison",
    title: str = "",
    subtitle: str = "",
    chart_id: Optional[str] = None,
) -> Dict[str, str]:
    """
    Generate a chart PNG and return metadata.

    Args:
        chart_data: Data for the renderer. An optional ``"source"`` entry is
            stamped onto the rendered image via ``_add_source_overlay``.
        chart_type: Requested chart type; normalised before dispatch.
        title: Chart title passed to the renderer.
        subtitle: Chart subtitle passed to the renderer.
        chart_id: Identifier used in the output filename. A random 8-char
            hex id is generated when omitted.

    Returns:
        {"path": str, "chart_id": str, "filename": str}
        Returns {"path": "", "chart_id": str, "filename": ""} on failure.
    """
    resolved_id = chart_id or uuid.uuid4().hex[:8]
    out_path = str(self.get_chart_preview_path(resolved_id))
    normalized_type = _normalize_chart_type(chart_type)
    failure = {"path": "", "chart_id": resolved_id, "filename": ""}

    logger.info(f"[ChartService] Generating chart: type={normalized_type}, id={resolved_id}")

    try:
        result_path = self._render_chart(normalized_type, chart_data, out_path, title, subtitle)

        if not result_path or not os.path.exists(result_path):
            logger.warning(f"[ChartService] Chart rendering returned empty path or file missing for type={normalized_type}")
            return failure

        # "source" may be absent, None or a non-string value. Coerce it so a
        # malformed source cannot turn a successfully rendered chart into a
        # reported failure.
        source = str(chart_data.get("source") or "").strip()
        if source:
            _add_source_overlay(result_path, source)

        filename = Path(result_path).name
        logger.info(f"[ChartService] Chart generated: id={resolved_id}, path={result_path}")
        return {"path": result_path, "chart_id": resolved_id, "filename": filename}

    except Exception as e:
        # Callers rely on the empty-path result rather than on an exception.
        logger.error(f"[ChartService] Chart generation failed: {e}")
        return failure
|
||||
|
||||
def _render_chart(self, chart_type: str, chart_data: Dict[str, Any],
|
||||
out_path: str, title: str, subtitle: str) -> str:
|
||||
"""Dispatch to the appropriate chart renderer."""
|
||||
|
||||
if chart_type in ("bar_comparison", "bar_chart_comparison"):
|
||||
labels = chart_data.get("labels", [])
|
||||
before = chart_data.get("before", [])
|
||||
after = chart_data.get("after", [])
|
||||
if not before and not after:
|
||||
values = chart_data.get("values", [])
|
||||
if values and labels:
|
||||
n = min(len(labels), len(values))
|
||||
chart_data = {**chart_data, "labels": labels[:n], "before": [0] * n, "after": values[:n]}
|
||||
return make_bar_chart(chart_data, out_path, title, subtitle=subtitle)
|
||||
|
||||
elif chart_type == "bar_horizontal":
|
||||
return make_horizontal_bar(chart_data, out_path, title)
|
||||
|
||||
elif chart_type == "line_trend":
|
||||
return make_line_trend(chart_data, out_path, title)
|
||||
|
||||
elif chart_type == "pie":
|
||||
return make_pie_chart(chart_data, out_path, title)
|
||||
|
||||
elif chart_type == "stacked_bar":
|
||||
return make_stacked_bar(chart_data, out_path, title)
|
||||
|
||||
elif chart_type in ("bullet", "bullet_points"):
|
||||
bullet_points = chart_data.get("bullet_points", chart_data.get("labels", []))
|
||||
if bullet_points:
|
||||
return make_bullet_overlay(bullet_points, out_path)
|
||||
return ""
|
||||
|
||||
else:
|
||||
logger.warning(f"[ChartService] Unknown chart type: {chart_type}, falling back to bar_comparison")
|
||||
return make_bar_chart(chart_data, out_path, title, subtitle=subtitle)
|
||||
|
||||
def infer_chart_from_text(self, text: str, user_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Ask the LLM for a chart spec (type, data, title) describing ``text``.

    Only the first 3000 characters of ``text`` are sent. The reply is parsed
    as JSON, optionally wrapped in a Markdown code fence.

    Returns:
        {"chart_type": str, "chart_data": dict, "title": str}
        On any failure, a ``bullet_points`` spec built from up to five
        sentences of ``text`` longer than 10 characters.
    """
    try:
        import json
        import re

        llm_output = llm_text_gen(
            prompt=CHART_INFERENCE_USER_PROMPT.format(text=text[:3000]),
            system_prompt=CHART_INFERENCE_SYSTEM_PROMPT,
            json_struct=None,
            max_tokens=2000,
            user_id=user_id,
        )

        if isinstance(llm_output, dict) and llm_output.get("text"):
            payload = llm_output["text"]
        elif llm_output:
            payload = str(llm_output)
        else:
            payload = ""

        payload = payload.strip()
        if payload.startswith("```"):
            # Strip the code fence around the JSON object, if one is found.
            fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", payload, re.DOTALL)
            if fenced:
                payload = fenced.group(1)

        spec = json.loads(payload)

        chart_type = spec.get("chart_type", "bullet_points")
        chart_data = spec.get("chart_data", {})
        title = spec.get("title", "")

        # Try normalising an unrecognised type before giving up on it.
        if chart_type not in VALID_CHART_TYPES:
            chart_type = _normalize_chart_type(chart_type)
            if chart_type not in VALID_CHART_TYPES:
                chart_type = "bullet_points"

        logger.info(f"[ChartService] Inferred chart: type={chart_type}, title={title}")
        return {"chart_type": chart_type, "chart_data": chart_data, "title": title}

    except Exception as e:
        logger.error(f"[ChartService] Chart inference failed: {e}")
        fragments = text.replace(".", ". ").split(". ")
        sentences = [piece.strip() for piece in fragments if len(piece.strip()) > 10][:5]
        return {
            "chart_type": "bullet_points",
            "chart_data": {"bullet_points": sentences or ["No data extracted"]},
            "title": "Key Points",
        }
|
||||
|
||||
async def _analyze_chart_potential(
    self,
    text: str,
    section_heading: Optional[str] = None,
    section_key_points: Optional[List[str]] = None,
    user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Stage 1: ask the LLM whether ``text`` holds enough data for a chart.

    The prompt carries the heading (default "Blog Section"), up to five key
    points and the first 3000 characters of ``text``. When the LLM call or
    the JSON parsing fails, generic search queries are built from the
    heading and the first ten words of ``text``.

    Returns:
        {"has_data": bool, "data_description": str,
         "suggested_chart_type": str|None, "search_queries": [...],
         "warnings": [...]}
    """
    if section_key_points:
        bullet_block = "\n".join(f"- {p}" for p in section_key_points[:5])
        key_points_text = f"\n\nKey points:\n" + bullet_block
    else:
        key_points_text = ""

    prompt = CHART_ANALYSIS_USER_PROMPT.format(
        section_heading=section_heading or "Blog Section",
        key_points_section=key_points_text,
        text=text[:3000],
    )

    try:
        import json
        import re

        response = llm_text_gen(
            prompt=prompt,
            system_prompt=CHART_ANALYSIS_SYSTEM_PROMPT,
            json_struct=None,
            max_tokens=1500,
            user_id=user_id,
        )

        if isinstance(response, dict):
            raw = response.get("text", "")
        elif response:
            raw = str(response)
        else:
            raw = ""

        raw = raw.strip()
        if raw.startswith("```"):
            # Strip the code fence around the JSON object, if one is found.
            fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
            if fenced:
                raw = fenced.group(1)

        verdict = json.loads(raw)

        has_data = verdict.get("has_data", False)
        data_description = verdict.get("data_description", "")
        suggested_chart_type = verdict.get("suggested_chart_type")
        search_queries = verdict.get("search_queries", [])

        # A suggestion that cannot be normalised to a valid type is dropped.
        if suggested_chart_type and suggested_chart_type not in VALID_CHART_TYPES:
            suggested_chart_type = _normalize_chart_type(suggested_chart_type)
            if suggested_chart_type not in VALID_CHART_TYPES:
                suggested_chart_type = None

        logger.info(f"[ChartService] Chart analysis: has_data={has_data}, queries={search_queries}")
        return {
            "has_data": has_data,
            "data_description": data_description,
            "suggested_chart_type": suggested_chart_type,
            "search_queries": search_queries,
            "warnings": [],
        }

    except Exception as e:
        logger.error(f"[ChartService] Chart analysis failed: {e}")
        heading = section_heading or ""
        words = text.split()[:10]
        if heading.strip() or text.strip():
            fallback_queries = [
                f"{heading} statistics data",
                f"{heading} trends report",
                f"{' '.join(words)} statistics",
            ]
        else:
            fallback_queries = []
        return {
            "has_data": False,
            "data_description": f"Analysis failed: {e}",
            "suggested_chart_type": None,
            "search_queries": fallback_queries,
            "warnings": [f"Chart analysis LLM call failed: {e}"],
        }
|
||||
|
||||
async def _search_for_chart_data(
|
||||
self,
|
||||
queries: List[str],
|
||||
section_heading: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Stage 2: Use Exa search to find relevant statistics and data for chart creation.
|
||||
|
||||
Returns:
|
||||
{"research": str, "warnings": list[str]}
|
||||
"""
|
||||
if not queries:
|
||||
return {"research": "", "warnings": []}
|
||||
|
||||
warnings = []
|
||||
try:
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
|
||||
provider = ExaResearchProvider()
|
||||
all_results = []
|
||||
search_errors = 0
|
||||
|
||||
for query in queries[:3]:
|
||||
try:
|
||||
results = await provider.simple_search(
|
||||
query=query,
|
||||
num_results=3,
|
||||
user_id=user_id,
|
||||
)
|
||||
all_results.extend(results)
|
||||
except Exception as e:
|
||||
search_errors += 1
|
||||
logger.warning(f"[ChartService] Exa search for '{query}' failed: {e}")
|
||||
continue
|
||||
|
||||
if search_errors == len(queries[:3]):
|
||||
warnings.append("All Exa search queries failed — external data search unavailable. Chart may lack supporting data.")
|
||||
|
||||
if not all_results:
|
||||
return {"research": "", "warnings": warnings}
|
||||
|
||||
research_parts = []
|
||||
seen_urls = set()
|
||||
for r in all_results:
|
||||
url = r.get("url", "")
|
||||
if url in seen_urls:
|
||||
continue
|
||||
seen_urls.add(url)
|
||||
title = r.get("title", "Untitled")
|
||||
text = r.get("text", "")[:500]
|
||||
if text:
|
||||
research_parts.append(f"- {title} ({url}): {text}")
|
||||
|
||||
if not research_parts:
|
||||
return {"research": "", "warnings": warnings}
|
||||
|
||||
return {"research": "\n".join(research_parts), "warnings": warnings}
|
||||
|
||||
except ImportError:
|
||||
msg = "Exa provider not available — skipping external data search."
|
||||
logger.warning(f"[ChartService] {msg}")
|
||||
warnings.append(msg)
|
||||
return {"research": "", "warnings": warnings}
|
||||
except Exception as e:
|
||||
msg = f"Chart data search failed: {e}"
|
||||
logger.error(f"[ChartService] {msg}")
|
||||
warnings.append(msg)
|
||||
return {"research": "", "warnings": warnings}
|
||||
|
||||
async def _synthesize_chart_from_research(
    self,
    text: str,
    research: str,
    section_heading: Optional[str] = None,
    user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Stage 3: ask the LLM for a chart spec built from text plus research.

    The prompt carries the first 2000 characters of ``text`` and the first
    3000 characters of ``research``. A ``source`` reported by the LLM is not
    returned as its own key; it is stored as ``chart_data["source"]``.

    Returns:
        {"chart_type": str, "chart_data": dict, "title": str}
        On any failure, a ``bullet_points`` spec built from up to five
        sentences of ``text``, titled with the heading or "Key Points".
    """
    try:
        import json
        import re

        response = llm_text_gen(
            prompt=CHART_SYNTHESIS_USER_PROMPT.format(
                text=text[:2000],
                research=research[:3000],
            ),
            system_prompt=CHART_SYNTHESIS_SYSTEM_PROMPT,
            json_struct=None,
            max_tokens=2000,
            user_id=user_id,
        )

        if isinstance(response, dict):
            raw = response.get("text", "")
        elif response:
            raw = str(response)
        else:
            raw = ""

        raw = raw.strip()
        if raw.startswith("```"):
            # Strip the code fence around the JSON object, if one is found.
            fenced = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", raw, re.DOTALL)
            if fenced:
                raw = fenced.group(1)

        spec = json.loads(raw)

        chart_type = spec.get("chart_type", "bullet_points")
        chart_data = spec.get("chart_data", {})
        title = spec.get("title", "")
        source = spec.get("source", "")

        # Try normalising an unrecognised type before giving up on it.
        if chart_type not in VALID_CHART_TYPES:
            chart_type = _normalize_chart_type(chart_type)
            if chart_type not in VALID_CHART_TYPES:
                chart_type = "bullet_points"

        # Carry the attribution inside chart_data.
        if source and isinstance(chart_data, dict):
            chart_data["source"] = source

        logger.info(f"[ChartService] Synthesized chart: type={chart_type}, title={title}")
        return {"chart_type": chart_type, "chart_data": chart_data, "title": title}

    except Exception as e:
        logger.error(f"[ChartService] Chart synthesis failed: {e}")
        fragments = text.replace(".", ". ").split(". ")
        sentences = [piece.strip() for piece in fragments if len(piece.strip()) > 10][:5]
        return {
            "chart_type": "bullet_points",
            "chart_data": {"bullet_points": sentences or ["No data available"]},
            "title": section_heading or "Key Points",
        }
|
||||
|
||||
async def infer_chart_with_research(
    self,
    text: str,
    section_heading: Optional[str] = None,
    section_key_points: Optional[List[str]] = None,
    user_id: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Run the 3-stage chart inference pipeline.

    1. Analyze: does the text already contain chartable data?
    2. Search: if not, look up statistics through the Exa search stage.
    3. Synthesize: build the chart spec from text plus research.

    Warnings collected from every stage are attached to the result.

    Returns:
        {"chart_type": str, "chart_data": dict, "title": str, "warnings": list[str]}
    """
    warnings = []
    logger.info(f"[ChartService] infer_chart_with_research: heading={section_heading}, text_len={len(text)}, user={user_id}")

    # Stage 1: Analyze
    analysis = await self._analyze_chart_potential(
        text=text,
        section_heading=section_heading,
        section_key_points=section_key_points,
        user_id=user_id,
    )
    warnings.extend(analysis.get("warnings", []))

    suggested_type = analysis.get("suggested_chart_type")
    if analysis.get("has_data") and suggested_type:
        # The text itself is enough: infer directly, no search needed.
        logger.info("[ChartService] Text has sufficient data, using direct inference")
        result = self.infer_chart_from_text(text, user_id=user_id)
        # Prefer the analysis suggestion over the bullet-point default.
        if result.get("chart_type") == "bullet_points":
            result["chart_type"] = suggested_type
        result["warnings"] = warnings
        return result

    # Stage 2: Search for data
    search_queries = analysis.get("search_queries", [])
    if not search_queries:
        # No queries suggested: derive some from the heading and leading words.
        heading = section_heading or ""
        leading_words = text.split()[:10]
        search_queries = [
            f"{heading} statistics data",
            f"{heading} trends report",
            f"{' '.join(leading_words)} statistics",
        ]

    logger.info(f"[ChartService] Searching Exa for chart data, queries: {search_queries}")
    search_result = await self._search_for_chart_data(
        queries=search_queries,
        section_heading=section_heading,
        user_id=user_id,
    )
    warnings.extend(search_result.get("warnings", []))
    research = search_result.get("research", "")

    if not research:
        logger.warning("[ChartService] No research data found, falling back to text-only inference")
        result = self.infer_chart_from_text(text, user_id=user_id)
    else:
        # Stage 3: Synthesize chart from text + research
        logger.info("[ChartService] Synthesizing chart from text + research data")
        result = await self._synthesize_chart_from_research(
            text=text,
            research=research,
            section_heading=section_heading,
            user_id=user_id,
        )

    result["warnings"] = warnings
    return result
|
||||
|
||||
async def generate_chart_from_text(
    self,
    text: str,
    user_id: Optional[str] = None,
    chart_id: Optional[str] = None,
    section_heading: Optional[str] = None,
    section_key_points: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """
    End-to-end helper: infer a chart spec for ``text``, then render it.

    Inference goes through ``infer_chart_with_research`` (analyze → search →
    synthesize) and rendering through ``generate_chart``.

    Returns:
        {"path": str, "chart_id": str, "filename": str, "chart_type": str,
         "chart_data": dict, "title": str, "warnings": list[str]}
    """
    inference = await self.infer_chart_with_research(
        text=text,
        section_heading=section_heading,
        section_key_points=section_key_points,
        user_id=user_id,
    )

    chart_data = inference["chart_data"]
    chart_type = inference["chart_type"]
    title = inference["title"]

    result = self.generate_chart(
        chart_data=chart_data,
        chart_type=chart_type,
        title=title,
        chart_id=chart_id,
    )

    # Echo the inferred spec back alongside the render metadata.
    result.update({
        "chart_type": chart_type,
        "chart_data": chart_data,
        "title": title,
        "warnings": inference.get("warnings", []),
    })
    return result
|
||||
|
||||
|
||||
# Cache of ChartService instances, keyed by output directory, user id or "default".
_chart_service_instances: Dict[str, ChartService] = {}


def get_chart_service(output_dir: Optional[str] = None, user_id: Optional[str] = None) -> ChartService:
    """Return the cached ChartService for these arguments, creating it once.

    The cache key is ``output_dir`` when given, otherwise ``user_id``,
    otherwise the literal ``"default"``.
    """
    cache_key = output_dir or user_id or "default"
    service = _chart_service_instances.get(cache_key)
    if service is None:
        service = ChartService(output_dir=output_dir, user_id=user_id)
        _chart_service_instances[cache_key] = service
    return service
|
||||
404
backend/services/gsc_brainstorm_service.py
Normal file
404
backend/services/gsc_brainstorm_service.py
Normal file
@@ -0,0 +1,404 @@
|
||||
"""
|
||||
GSC Brainstorm Service for ALwrity.
|
||||
|
||||
Analyzes Google Search Console data to suggest blog topics the user should write about.
|
||||
Combines rule-based heuristics (high-impression/low-CTR keywords, near-page-1 positions)
|
||||
with LLM-powered strategic recommendations tailored to the user's topic intent.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Any, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.gsc_service import GSCService
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
class GSCBrainstormService:
|
||||
"""
|
||||
Suggests blog topics based on the user's live GSC data.
|
||||
|
||||
Flow:
|
||||
1. Fetch real GSC search analytics (query + page data, 30 days)
|
||||
2. Apply rule-based filters (Content Optimization, Content Enhancement, Keyword Gap)
|
||||
3. Generate LLM-powered strategic recommendations contextualised to the user's keywords
|
||||
4. Return structured results
|
||||
"""
|
||||
|
||||
def __init__(self, gsc_service: GSCService = None):
    """Store the GSC client, creating a default ``GSCService`` when none is given."""
    if not gsc_service:
        gsc_service = GSCService()
    self.gsc_service = gsc_service
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Public entry point
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def brainstorm_topics(
|
||||
self,
|
||||
user_id: str,
|
||||
keywords: str,
|
||||
site_url: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate blog topic suggestions from the user's GSC data.
|
||||
|
||||
Args:
|
||||
user_id: Clerk user ID (must have GSC connected).
|
||||
keywords: User's 3+ word topic intent (e.g. "content marketing strategy").
|
||||
site_url: Optional site URL; auto-selected from user's first GSC site if omitted.
|
||||
|
||||
Returns:
|
||||
Dict with content_opportunities, keyword_gaps, ai_recommendations, summary.
|
||||
"""
|
||||
self._user_id = user_id
|
||||
# 1. Resolve site_url
|
||||
if not site_url:
|
||||
sites = self.gsc_service.get_site_list(user_id)
|
||||
if not sites:
|
||||
return {
|
||||
"error": "No GSC sites found. Make sure your site is verified in Google Search Console.",
|
||||
"content_opportunities": [],
|
||||
"keyword_gaps": [],
|
||||
"ai_recommendations": {},
|
||||
"summary": {},
|
||||
}
|
||||
site_url = sites[0].get("siteUrl", "")
|
||||
|
||||
# 2. Fetch GSC analytics (30 days)
|
||||
end_date = datetime.now().strftime("%Y-%m-%d")
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime("%Y-%m-%d")
|
||||
|
||||
analytics = self.gsc_service.get_search_analytics(
|
||||
user_id=user_id,
|
||||
site_url=site_url,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
if "error" in analytics:
|
||||
return {
|
||||
"error": analytics.get("error", "Failed to fetch GSC data"),
|
||||
"content_opportunities": [],
|
||||
"keyword_gaps": [],
|
||||
"ai_recommendations": {},
|
||||
"summary": {},
|
||||
}
|
||||
|
||||
# 3. Parse GSC rows into structured data
|
||||
query_rows = analytics.get("query_data", {}).get("rows", [])
|
||||
page_rows = analytics.get("page_data", {}).get("rows", [])
|
||||
|
||||
keywords_data = self._parse_query_rows(query_rows)
|
||||
pages_data = self._parse_page_rows(page_rows)
|
||||
|
||||
if not keywords_data:
|
||||
return {
|
||||
"error": "No keyword data available for the selected period.",
|
||||
"content_opportunities": [],
|
||||
"keyword_gaps": [],
|
||||
"ai_recommendations": {},
|
||||
"summary": {
|
||||
"site_url": site_url,
|
||||
"date_range": {"start": start_date, "end": end_date},
|
||||
"total_keywords_analyzed": 0,
|
||||
},
|
||||
}
|
||||
|
||||
# 4. Rule-based analysis
|
||||
content_opportunities = self._identify_content_opportunities(keywords_data)
|
||||
keyword_gaps = self._identify_keyword_gaps(keywords_data)
|
||||
|
||||
# 5. Summary metrics
|
||||
summary = self._compute_summary(keywords_data, pages_data, site_url, start_date, end_date)
|
||||
|
||||
# 6. AI recommendations (best-effort; don't fail the whole request on LLM error)
|
||||
ai_recommendations = self._generate_ai_recommendations(
|
||||
keywords_data, pages_data, summary, keywords
|
||||
)
|
||||
|
||||
return {
|
||||
"content_opportunities": content_opportunities,
|
||||
"keyword_gaps": keyword_gaps,
|
||||
"ai_recommendations": ai_recommendations,
|
||||
"summary": summary,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Data parsing helpers
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@staticmethod
|
||||
def _parse_query_rows(rows: List[Dict]) -> List[Dict[str, Any]]:
|
||||
parsed = []
|
||||
for row in rows:
|
||||
keys = row.get("keys", [])
|
||||
keyword = keys[0] if len(keys) >= 1 else "(not set)"
|
||||
parsed.append({
|
||||
"keyword": keyword,
|
||||
"clicks": row.get("clicks", 0),
|
||||
"impressions": row.get("impressions", 0),
|
||||
"ctr": round(row.get("ctr", 0) * 100, 2),
|
||||
"position": round(row.get("position", 0), 1),
|
||||
})
|
||||
return parsed
|
||||
|
||||
@staticmethod
|
||||
def _parse_page_rows(rows: List[Dict]) -> List[Dict[str, Any]]:
|
||||
parsed = []
|
||||
for row in rows:
|
||||
keys = row.get("keys", [])
|
||||
page = keys[0] if len(keys) >= 1 else "(not set)"
|
||||
parsed.append({
|
||||
"page": page,
|
||||
"clicks": row.get("clicks", 0),
|
||||
"impressions": row.get("impressions", 0),
|
||||
"ctr": round(row.get("ctr", 0) * 100, 2),
|
||||
"position": round(row.get("position", 0), 1),
|
||||
})
|
||||
return parsed
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Rule-based opportunity identification
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@staticmethod
|
||||
def _identify_content_opportunities(
|
||||
keywords_data: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
opportunities: List[Dict[str, Any]] = []
|
||||
|
||||
# Rule 1: Content Optimization — high impressions, low CTR
|
||||
for kw in keywords_data:
|
||||
if kw["impressions"] > 500 and kw["ctr"] < 3:
|
||||
opportunities.append({
|
||||
"type": "Content Optimization",
|
||||
"keyword": kw["keyword"],
|
||||
"opportunity": (
|
||||
f"Optimize existing content for '{kw['keyword']}' "
|
||||
f"to improve CTR from {kw['ctr']:.1f}% "
|
||||
f"(position {kw['position']:.1f})"
|
||||
),
|
||||
"potential_impact": "High",
|
||||
"current_position": kw["position"],
|
||||
"impressions": kw["impressions"],
|
||||
"priority": "High" if kw["impressions"] > 1000 else "Medium",
|
||||
})
|
||||
|
||||
# Rule 2: Content Enhancement — positions 11-20 with decent impressions
|
||||
for kw in keywords_data:
|
||||
if 10 < kw["position"] <= 20 and kw["impressions"] > 100:
|
||||
opportunities.append({
|
||||
"type": "Content Enhancement",
|
||||
"keyword": kw["keyword"],
|
||||
"opportunity": (
|
||||
f"Enhance content for '{kw['keyword']}' to move from "
|
||||
f"position {kw['position']:.1f} to the first page"
|
||||
),
|
||||
"potential_impact": "Medium",
|
||||
"current_position": kw["position"],
|
||||
"impressions": kw["impressions"],
|
||||
"priority": "Medium",
|
||||
})
|
||||
|
||||
# Sort by impressions descending, keep top 10
|
||||
opportunities.sort(key=lambda x: x["impressions"], reverse=True)
|
||||
return opportunities[:10]
|
||||
|
||||
@staticmethod
|
||||
def _identify_keyword_gaps(
|
||||
keywords_data: List[Dict[str, Any]],
|
||||
) -> List[Dict[str, Any]]:
|
||||
gaps: List[Dict[str, Any]] = []
|
||||
|
||||
for kw in keywords_data:
|
||||
if 4 <= kw["position"] <= 20 and kw["impressions"] >= 50:
|
||||
gaps.append({
|
||||
"keyword": kw["keyword"],
|
||||
"position": kw["position"],
|
||||
"impressions": kw["impressions"],
|
||||
})
|
||||
|
||||
gaps.sort(key=lambda x: x["impressions"], reverse=True)
|
||||
return gaps[:10]
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Summary metrics
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
@staticmethod
|
||||
def _compute_summary(
|
||||
keywords_data: List[Dict],
|
||||
pages_data: List[Dict],
|
||||
site_url: str,
|
||||
start_date: str,
|
||||
end_date: str,
|
||||
) -> Dict[str, Any]:
|
||||
total_impressions = sum(kw["impressions"] for kw in keywords_data)
|
||||
total_clicks = sum(kw["clicks"] for kw in keywords_data)
|
||||
avg_ctr = round((total_clicks / total_impressions * 100) if total_impressions else 0, 2)
|
||||
avg_position = round(
|
||||
sum(kw["position"] for kw in keywords_data) / len(keywords_data), 1
|
||||
) if keywords_data else 0
|
||||
|
||||
pos_1_3 = len([kw for kw in keywords_data if kw["position"] <= 3])
|
||||
pos_4_10 = len([kw for kw in keywords_data if 3 < kw["position"] <= 10])
|
||||
pos_11_20 = len([kw for kw in keywords_data if 10 < kw["position"] <= 20])
|
||||
pos_21_plus = len([kw for kw in keywords_data if kw["position"] > 20])
|
||||
|
||||
top_keywords = sorted(keywords_data, key=lambda x: x["impressions"], reverse=True)[:5]
|
||||
top_pages = sorted(pages_data, key=lambda x: x["clicks"], reverse=True)[:3]
|
||||
|
||||
return {
|
||||
"site_url": site_url,
|
||||
"date_range": {"start": start_date, "end": end_date},
|
||||
"total_keywords_analyzed": len(keywords_data),
|
||||
"total_impressions": total_impressions,
|
||||
"total_clicks": total_clicks,
|
||||
"avg_ctr": avg_ctr,
|
||||
"avg_position": avg_position,
|
||||
"keyword_distribution": {
|
||||
"positions_1_3": pos_1_3,
|
||||
"positions_4_10": pos_4_10,
|
||||
"positions_11_20": pos_11_20,
|
||||
"positions_21_plus": pos_21_plus,
|
||||
},
|
||||
"top_keywords": [
|
||||
{"keyword": kw["keyword"], "impressions": kw["impressions"], "position": kw["position"]}
|
||||
for kw in top_keywords
|
||||
],
|
||||
"top_pages": [
|
||||
{"page": pg["page"], "clicks": pg["clicks"], "impressions": pg["impressions"]}
|
||||
for pg in top_pages
|
||||
],
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# AI-powered strategic recommendations
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def _generate_ai_recommendations(
    self,
    keywords_data: List[Dict],
    pages_data: List[Dict],
    summary: Dict,
    user_keywords: str,
) -> Dict[str, Any]:
    """Ask the LLM for blog-topic recommendations based on a GSC summary.

    A prompt is assembled from the aggregate numbers in ``summary`` and the
    user's stated topic (``user_keywords``), sent through ``llm_text_gen``,
    and the reply is handed to ``_parse_ai_response``.  Whenever the LLM
    returns nothing, the reply cannot be parsed, or any step raises, the
    result of ``_fallback_ai_recommendations(keywords_data)`` is returned
    instead, so this method does not propagate errors raised inside the
    ``try`` block.

    Args:
        keywords_data: Keyword rows; only forwarded to the fallback.
        pages_data: Accepted for signature symmetry; not referenced in
            this method body.
        summary: Summary dict read via ``.get`` for ``top_keywords``,
            ``keyword_distribution`` and the total/average figures.
        user_keywords: The user's topic intent, interpolated verbatim
            into the prompt.

    Returns:
        The parsed LLM recommendations when available, otherwise the
        fallback recommendations.
    """
    try:
        # Flatten the summary's top keywords into one comma-separated line.
        top_kw = ", ".join(kw["keyword"] for kw in summary.get("top_keywords", []))
        dist = summary.get("keyword_distribution", {})

        # NOTE: literal braces of the JSON example are doubled because this
        # is an f-string.
        prompt = f"""Analyze this Google Search Console data and suggest blog topics the user should write about.

USER'S TOPIC INTENT: "{user_keywords}"

SEARCH PERFORMANCE SUMMARY:
- Total Keywords Tracked: {summary.get('total_keywords_analyzed', 0)}
- Total Impressions: {summary.get('total_impressions', 0):,}
- Total Clicks: {summary.get('total_clicks', 0):,}
- Average CTR: {summary.get('avg_ctr', 0):.2f}%
- Average Position: {summary.get('avg_position', 0):.1f}

TOP PERFORMING KEYWORDS:
{top_kw}

KEYWORD POSITION DISTRIBUTION:
- Positions 1-3: {dist.get('positions_1_3', 0)}
- Positions 4-10: {dist.get('positions_4_10', 0)}
- Positions 11-20: {dist.get('positions_11_20', 0)}
- Positions 21+: {dist.get('positions_21_plus', 0)}

Based on this data, provide:

1. IMMEDIATE TOPIC OPPORTUNITIES (0-30 days):
- Specific blog post titles the user should write
- Each tied to a keyword opportunity from the data
- 3-5 suggestions

2. CONTENT STRATEGY TOPICS (1-3 months):
- New topic clusters to build authority
- Content pillar ideas
- 3-5 suggestions

3. LONG-TERM CONTENT VISION (3-12 months):
- Market expansion topics
- Authority-building content ideas
- 3-5 suggestions

IMPORTANT: Relate every topic suggestion to the user's interest in "{user_keywords}".
Return your response in this exact JSON format:
{{
"immediate_opportunities": ["topic 1", "topic 2", "topic 3"],
"content_strategy": ["strategy 1", "strategy 2", "strategy 3"],
"long_term_strategy": ["vision 1", "vision 2", "vision 3"]
}}"""

        system_prompt = (
            "You are an enterprise SEO content strategist. Provide specific, data-driven "
            "blog topic suggestions that will improve the user's search performance. "
            "Always respond with valid JSON matching the requested format."
        )

        result = llm_text_gen(
            prompt=prompt,
            system_prompt=system_prompt,
            # ``_user_id`` is optional on the instance; None is passed when absent.
            user_id=getattr(self, '_user_id', None),
            flow_type="gsc_brainstorm",
        )

        if result:
            parsed = self._parse_ai_response(result)
            if parsed:
                return parsed

        # Empty or unparseable reply: use the rule-based recommendations.
        return self._fallback_ai_recommendations(keywords_data)

    except Exception as e:
        # Best-effort feature: log and degrade to the fallback instead of failing.
        logger.warning(f"GSC brainstorm AI recommendations failed: {e}")
        return self._fallback_ai_recommendations(keywords_data)
|
||||
|
||||
@staticmethod
|
||||
def _parse_ai_response(raw: str) -> Optional[Dict[str, List[str]]]:
|
||||
try:
|
||||
json_start = raw.find("{")
|
||||
json_end = raw.rfind("}") + 1
|
||||
if json_start == -1 or json_end == 0:
|
||||
return None
|
||||
|
||||
chunk = raw[json_start:json_end]
|
||||
parsed = json.loads(chunk)
|
||||
|
||||
return {
|
||||
"immediate_opportunities": parsed.get("immediate_opportunities", [])[:5],
|
||||
"content_strategy": parsed.get("content_strategy", [])[:5],
|
||||
"long_term_strategy": parsed.get("long_term_strategy", [])[:5],
|
||||
}
|
||||
except (json.JSONDecodeError, ValueError) as e:
|
||||
logger.warning(f"Failed to parse AI brainstorm response as JSON: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _fallback_ai_recommendations(
|
||||
keywords_data: List[Dict],
|
||||
) -> Dict[str, Any]:
|
||||
top_kw = keywords_data[:3] if keywords_data else []
|
||||
immediate = []
|
||||
for kw in top_kw:
|
||||
immediate.append(
|
||||
f"Write a comprehensive guide on '{kw['keyword']}' "
|
||||
f"(currently at position {kw['position']:.1f} with "
|
||||
f"{kw['impressions']} impressions)"
|
||||
)
|
||||
|
||||
return {
|
||||
"immediate_opportunities": immediate or ["No keyword data available for recommendations"],
|
||||
"content_strategy": [
|
||||
"Develop topic clusters around your top-performing keywords",
|
||||
"Create comparison and vs-style content for competitive terms",
|
||||
"Build FAQ sections targeting question-based queries",
|
||||
],
|
||||
"long_term_strategy": [
|
||||
"Build domain authority through pillar content",
|
||||
"Expand into adjacent topic areas",
|
||||
"Develop thought leadership content series",
|
||||
],
|
||||
}
|
||||
@@ -1,9 +1,9 @@
|
||||
"""
|
||||
Hallucination Detector Service
|
||||
|
||||
This service implements fact-checking functionality using Exa.ai API
|
||||
to detect and verify claims in AI-generated content, similar to the
|
||||
Exa.ai demo implementation.
|
||||
Implements fact-checking using Exa.ai for evidence search and the
|
||||
configured LLM provider (via GPT_PROVIDER) for claim extraction and assessment.
|
||||
Respects GPT_PROVIDER env var: google, wavespeed, openai, huggingface.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -11,15 +11,9 @@ import logging
|
||||
from typing import List, Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
import requests
|
||||
import os
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
try:
|
||||
from google import genai
|
||||
GOOGLE_GENAI_AVAILABLE = True
|
||||
except Exception:
|
||||
GOOGLE_GENAI_AVAILABLE = False
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -44,70 +38,121 @@ class HallucinationResult:
|
||||
insufficient_claims: int
|
||||
timestamp: str
|
||||
|
||||
|
||||
def _get_llm_provider_info() -> Dict[str, str]:
|
||||
"""Determine the LLM provider from GPT_PROVIDER env var."""
|
||||
provider_env = os.getenv('GPT_PROVIDER', 'google').lower().strip()
|
||||
provider = provider_env.split(',')[0].strip() if provider_env else 'google'
|
||||
|
||||
if provider in ('wavespeed', 'wave'):
|
||||
return {'provider': 'wavespeed', 'name': 'WaveSpeed'}
|
||||
elif provider in ('gemini', 'google'):
|
||||
return {'provider': 'google', 'name': 'Gemini'}
|
||||
elif provider in ('openai', 'gpt'):
|
||||
return {'provider': 'openai', 'name': 'OpenAI'}
|
||||
elif provider in ('hf_response_api', 'huggingface', 'hf'):
|
||||
return {'provider': 'huggingface', 'name': 'HuggingFace'}
|
||||
else:
|
||||
return {'provider': provider, 'name': provider.capitalize()}
|
||||
|
||||
|
||||
class HallucinationDetector:
|
||||
"""
|
||||
Hallucination detector using Exa.ai for fact-checking.
|
||||
|
||||
Implements the three-step process from Exa.ai demo:
|
||||
Hallucination detector using Exa.ai for evidence search
|
||||
and the configured LLM provider (GPT_PROVIDER) for claim extraction/assessment.
|
||||
|
||||
Implements the three-step process:
|
||||
1. Extract verifiable claims from text
|
||||
2. Search for evidence using Exa.ai
|
||||
3. Verify claims against sources
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.exa_api_key = os.getenv('EXA_API_KEY')
|
||||
self.gemini_api_key = os.getenv('GEMINI_API_KEY')
|
||||
|
||||
if not self.exa_api_key:
|
||||
logger.warning("EXA_API_KEY not found. Hallucination detection will be limited.")
|
||||
|
||||
if not self.gemini_api_key:
|
||||
logger.warning("GEMINI_API_KEY not found. Falling back to heuristic claim extraction.")
|
||||
|
||||
# Initialize Gemini client for claim extraction and assessment
|
||||
self.gemini_client = genai.Client(api_key=self.gemini_api_key) if (GOOGLE_GENAI_AVAILABLE and self.gemini_api_key) else None
|
||||
|
||||
# Rate limiting to prevent API abuse
|
||||
self._llm_provider_info = _get_llm_provider_info()
|
||||
|
||||
# Check that at least one LLM key is available for the configured provider
|
||||
self._check_provider_keys()
|
||||
|
||||
# Rate limiting
|
||||
self.daily_api_calls = 0
|
||||
self.daily_limit = 20 # Max 20 API calls per day for fact checking
|
||||
self.daily_limit = 20
|
||||
self.last_reset_date = None
|
||||
|
||||
|
||||
def _check_provider_keys(self):
|
||||
"""Check that API keys for the configured provider are available."""
|
||||
provider = self._llm_provider_info['provider']
|
||||
if provider == 'google':
|
||||
key = os.getenv('GEMINI_API_KEY')
|
||||
if not key:
|
||||
logger.warning(f"GEMINI_API_KEY not found. Hallucination detection will fail for provider '{provider}'.")
|
||||
elif provider == 'wavespeed':
|
||||
key = os.getenv('WAVESPEED_API_KEY')
|
||||
if not key:
|
||||
logger.warning(f"WAVESPEED_API_KEY not found. Hallucination detection will fail for provider '{provider}'.")
|
||||
elif provider == 'openai':
|
||||
key = os.getenv('OPENAI_API_KEY')
|
||||
if not key:
|
||||
logger.warning(f"OPENAI_API_KEY not found. Hallucination detection will fail for provider '{provider}'.")
|
||||
# huggingface uses serverless endpoint or HF token
|
||||
|
||||
@property
def provider_name(self) -> str:
    """Return the ``'name'`` entry of ``self._llm_provider_info`` (the provider's display name)."""
    return self._llm_provider_info['name']
|
||||
|
||||
@property
def provider_key(self) -> str:
    """Return the ``'provider'`` entry of ``self._llm_provider_info`` (the provider's canonical key)."""
    return self._llm_provider_info['provider']
|
||||
|
||||
def _check_rate_limit(self) -> bool:
    """Consume one unit of today's fact-check budget if any is left.

    The per-instance counter is reset the first time this is called on a
    new calendar day.  When the counter has already reached
    ``self.daily_limit`` the call is refused.

    Returns:
        True if the call is allowed (and has been counted), False when the
        daily limit is exhausted.
    """
    from datetime import date

    current_day = date.today()

    # First call of a new day: start counting from zero again.
    if self.last_reset_date != current_day:
        self.daily_api_calls = 0
        self.last_reset_date = current_day

    budget_left = self.daily_api_calls < self.daily_limit
    if budget_left:
        # Count this call against today's budget.
        self.daily_api_calls += 1
        logger.info(f"Fact check API call #{self.daily_api_calls}/{self.daily_limit} today")
        return True

    logger.warning(f"Daily API limit reached ({self.daily_limit} calls). Fact checking disabled for today.")
    return False
|
||||
|
||||
async def detect_hallucinations(self, text: str) -> HallucinationResult:
|
||||
|
||||
def _generate_text(self, prompt: str, system_prompt: Optional[str] = None, user_id: Optional[str] = None) -> str:
    """Generate text by delegating to ``llm_text_gen``.

    Provider selection is not done here; it is presumably handled inside
    ``llm_text_gen`` (which is expected to honour GPT_PROVIDER).

    Args:
        prompt: The user prompt to send.
        system_prompt: Optional system prompt; when it is ``None`` or empty
            a default fact-checking instruction is used instead.
        user_id: Forwarded unchanged to ``llm_text_gen``.

    Returns:
        Whatever ``llm_text_gen`` returns.
    """
    # Imported lazily, at call time rather than module import time.
    from services.llm_providers.main_text_generation import llm_text_gen

    result = llm_text_gen(
        prompt=prompt,
        system_prompt=system_prompt or "You are a precise fact-checking assistant. Respond only with valid JSON as instructed.",
        max_tokens=4000,
        user_id=user_id,
    )
    return result
|
||||
|
||||
async def _generate_text_async(self, prompt: str, system_prompt: Optional[str] = None, user_id: str = None) -> str:
|
||||
"""Async wrapper for _generate_text."""
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
result = await loop.run_in_executor(
|
||||
executor,
|
||||
lambda: self._generate_text(prompt, system_prompt, user_id)
|
||||
)
|
||||
return result
|
||||
|
||||
async def detect_hallucinations(self, text: str, user_id: str = None) -> HallucinationResult:
|
||||
"""
|
||||
Main method to detect hallucinations in the given text.
|
||||
|
||||
|
||||
Args:
|
||||
text: The text to analyze for factual accuracy
|
||||
|
||||
|
||||
Returns:
|
||||
HallucinationResult with claims analysis and confidence scores
|
||||
"""
|
||||
try:
|
||||
logger.info(f"Starting hallucination detection for text of length: {len(text)}")
|
||||
logger.info(f"Text sample: {text[:200]}...")
|
||||
|
||||
# Check rate limits first
|
||||
|
||||
if not self._check_rate_limit():
|
||||
return HallucinationResult(
|
||||
claims=[],
|
||||
@@ -118,17 +163,11 @@ class HallucinationDetector:
|
||||
insufficient_claims=0,
|
||||
timestamp=datetime.now().isoformat()
|
||||
)
|
||||
|
||||
# Validate required API keys
|
||||
if not self.gemini_api_key:
|
||||
raise Exception("GEMINI_API_KEY not configured. Cannot perform hallucination detection.")
|
||||
if not self.exa_api_key:
|
||||
raise Exception("EXA_API_KEY not configured. Cannot search for evidence.")
|
||||
|
||||
|
||||
# Step 1: Extract claims from text
|
||||
claims_texts = await self._extract_claims(text)
|
||||
claims_texts = await self._extract_claims(text, user_id=user_id)
|
||||
logger.info(f"Extracted {len(claims_texts)} claims from text: {claims_texts}")
|
||||
|
||||
|
||||
if not claims_texts:
|
||||
logger.warning("No verifiable claims found in text")
|
||||
return HallucinationResult(
|
||||
@@ -140,22 +179,18 @@ class HallucinationDetector:
|
||||
insufficient_claims=0,
|
||||
timestamp=datetime.now().isoformat()
|
||||
)
|
||||
|
||||
# Step 2 & 3: Verify claims in batch to reduce API calls
|
||||
verified_claims = await self._verify_claims_batch(claims_texts)
|
||||
|
||||
|
||||
# Step 2 & 3: Verify claims in batch
|
||||
verified_claims = await self._verify_claims_batch(claims_texts, user_id=user_id)
|
||||
|
||||
# Calculate overall metrics
|
||||
total_claims = len(verified_claims)
|
||||
supported_claims = sum(1 for c in verified_claims if c.assessment == "supported")
|
||||
refuted_claims = sum(1 for c in verified_claims if c.assessment == "refuted")
|
||||
insufficient_claims = sum(1 for c in verified_claims if c.assessment == "insufficient_information")
|
||||
|
||||
# Calculate overall confidence (weighted average)
|
||||
if total_claims > 0:
|
||||
overall_confidence = sum(c.confidence for c in verified_claims) / total_claims
|
||||
else:
|
||||
overall_confidence = 0.0
|
||||
|
||||
|
||||
overall_confidence = sum(c.confidence for c in verified_claims) / total_claims if total_claims > 0 else 0.0
|
||||
|
||||
result = HallucinationResult(
|
||||
claims=verified_claims,
|
||||
overall_confidence=overall_confidence,
|
||||
@@ -165,120 +200,67 @@ class HallucinationDetector:
|
||||
insufficient_claims=insufficient_claims,
|
||||
timestamp=datetime.now().isoformat()
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"Hallucination detection completed. Overall confidence: {overall_confidence:.2f}")
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in hallucination detection: {str(e)}")
|
||||
raise Exception(f"Hallucination detection failed: {str(e)}")
|
||||
|
||||
async def _extract_claims(self, text: str) -> List[str]:
|
||||
"""
|
||||
Extract verifiable claims from text using LLM.
|
||||
|
||||
Args:
|
||||
text: Input text to extract claims from
|
||||
|
||||
Returns:
|
||||
List of claim strings
|
||||
"""
|
||||
if not self.gemini_client:
|
||||
raise Exception("Gemini client not available. Cannot extract claims without AI provider.")
|
||||
|
||||
|
||||
async def _extract_claims(self, text: str, user_id: str = None) -> List[str]:
|
||||
"""Extract verifiable claims from text using LLM."""
|
||||
try:
|
||||
prompt = (
|
||||
"Extract verifiable factual claims from the following text. "
|
||||
"A verifiable claim is a statement that can be checked against external sources for accuracy.\n\n"
|
||||
"Return ONLY a valid JSON array of strings, where each string is a single verifiable claim.\n\n"
|
||||
"Examples of GOOD verifiable claims:\n"
|
||||
"- \"The company was founded in 2020\"\n"
|
||||
"- \"Sales increased by 25% last quarter\"\n"
|
||||
"- \"The product has 10,000 users\"\n"
|
||||
"- \"The market size is $50 billion\"\n"
|
||||
"- \"The software supports 15 languages\"\n"
|
||||
"- \"The company has offices in 5 countries\"\n\n"
|
||||
'- "The company was founded in 2020"\n'
|
||||
'- "Sales increased by 25% last quarter"\n'
|
||||
'- "The product has 10,000 users"\n\n'
|
||||
"Examples of BAD claims (opinions, subjective statements):\n"
|
||||
"- \"This is the best product\"\n"
|
||||
"- \"Customers love our service\"\n"
|
||||
"- \"We are innovative\"\n"
|
||||
"- \"The future looks bright\"\n\n"
|
||||
'- "This is the best product"\n'
|
||||
'- "Customers love our service"\n\n'
|
||||
"IMPORTANT: Extract at least 2-3 verifiable claims if possible. "
|
||||
"Look for specific facts, numbers, dates, locations, and measurable statements.\n\n"
|
||||
f"Text to analyze: {text}\n\n"
|
||||
"Return only the JSON array of verifiable claims:"
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
resp = await loop.run_in_executor(executor, lambda: self.gemini_client.models.generate_content(
|
||||
model="gemini-1.5-flash",
|
||||
contents=prompt
|
||||
))
|
||||
|
||||
if not resp or not resp.text:
|
||||
raise Exception("Empty response from Gemini API")
|
||||
|
||||
claims_text = resp.text.strip()
|
||||
logger.info(f"Raw Gemini response for claims: {claims_text[:200]}...")
|
||||
|
||||
# Try to extract JSON from the response
|
||||
try:
|
||||
claims = json.loads(claims_text)
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON array in the response (handle markdown code blocks)
|
||||
import re
|
||||
# First try to extract from markdown code blocks
|
||||
code_block_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', claims_text, re.DOTALL)
|
||||
if code_block_match:
|
||||
claims = json.loads(code_block_match.group(1))
|
||||
else:
|
||||
# Try to find JSON array directly
|
||||
json_match = re.search(r'\[.*?\]', claims_text, re.DOTALL)
|
||||
if json_match:
|
||||
claims = json.loads(json_match.group())
|
||||
else:
|
||||
raise Exception(f"Could not parse JSON from Gemini response: {claims_text[:100]}")
|
||||
|
||||
|
||||
result_text = await self._generate_text_async(prompt, user_id=user_id)
|
||||
logger.info(f"Raw LLM response for claims: {result_text[:200]}...")
|
||||
|
||||
claims = self._parse_json_from_response(result_text, expect_array=True)
|
||||
|
||||
if isinstance(claims, list):
|
||||
valid_claims = [claim for claim in claims if isinstance(claim, str) and claim.strip()]
|
||||
logger.info(f"Successfully extracted {len(valid_claims)} claims")
|
||||
return valid_claims
|
||||
else:
|
||||
raise Exception(f"Expected JSON array, got: {type(claims)}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting claims: {str(e)}")
|
||||
raise Exception(f"Failed to extract claims: {str(e)}")
|
||||
|
||||
|
||||
async def _verify_claims_batch(self, claims: List[str]) -> List[Claim]:
|
||||
"""
|
||||
Verify multiple claims in batch to reduce API calls.
|
||||
|
||||
Args:
|
||||
claims: List of claims to verify
|
||||
|
||||
Returns:
|
||||
List of Claim objects with verification results
|
||||
"""
|
||||
|
||||
async def _verify_claims_batch(self, claims: List[str], user_id: str = None) -> List[Claim]:
|
||||
"""Verify multiple claims in batch to reduce API calls."""
|
||||
try:
|
||||
logger.info(f"Starting batch verification of {len(claims)} claims")
|
||||
|
||||
# Limit to maximum 3 claims to prevent excessive API usage
|
||||
max_claims = min(len(claims), 3)
|
||||
claims_to_verify = claims[:max_claims]
|
||||
|
||||
|
||||
if len(claims) > max_claims:
|
||||
logger.warning(f"Limited verification to {max_claims} claims to prevent API rate limits")
|
||||
|
||||
# Step 1: Search for evidence for all claims in one batch
|
||||
all_sources = await self._search_evidence_batch(claims_to_verify)
|
||||
|
||||
# Step 2: Assess all claims against sources in one API call
|
||||
verified_claims = await self._assess_claims_batch(claims_to_verify, all_sources)
|
||||
|
||||
# Add any remaining claims as insufficient information
|
||||
|
||||
# Step 1: Search for evidence
|
||||
all_sources = await self._search_evidence_batch(claims_to_verify, user_id=user_id)
|
||||
|
||||
# Step 2: Assess claims against sources
|
||||
verified_claims = await self._assess_claims_batch(claims_to_verify, all_sources, user_id=user_id)
|
||||
|
||||
# Add remaining claims as insufficient information
|
||||
for i in range(max_claims, len(claims)):
|
||||
verified_claims.append(Claim(
|
||||
text=claims[i],
|
||||
@@ -288,13 +270,12 @@ class HallucinationDetector:
|
||||
refuting_sources=[],
|
||||
reasoning="Not verified due to API rate limit protection"
|
||||
))
|
||||
|
||||
|
||||
logger.info(f"Batch verification completed for {len(verified_claims)} claims")
|
||||
return verified_claims
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch verification: {str(e)}")
|
||||
# Return all claims as insufficient information
|
||||
return [
|
||||
Claim(
|
||||
text=claim,
|
||||
@@ -307,20 +288,11 @@ class HallucinationDetector:
|
||||
for claim in claims
|
||||
]
|
||||
|
||||
async def _verify_claim(self, claim: str) -> Claim:
|
||||
"""
|
||||
Verify a single claim using Exa.ai search.
|
||||
|
||||
Args:
|
||||
claim: The claim to verify
|
||||
|
||||
Returns:
|
||||
Claim object with verification results
|
||||
"""
|
||||
async def _verify_claim(self, claim: str, user_id: str = None) -> Claim:
|
||||
"""Verify a single claim using Exa.ai search."""
|
||||
try:
|
||||
# Search for evidence using Exa.ai
|
||||
sources = await self._search_evidence(claim)
|
||||
|
||||
sources = await self._search_evidence(claim, user_id=user_id)
|
||||
|
||||
if not sources:
|
||||
return Claim(
|
||||
text=claim,
|
||||
@@ -330,10 +302,9 @@ class HallucinationDetector:
|
||||
refuting_sources=[],
|
||||
reasoning="No sources found for verification"
|
||||
)
|
||||
|
||||
# Verify claim against sources using LLM
|
||||
verification_result = await self._assess_claim_against_sources(claim, sources)
|
||||
|
||||
|
||||
verification_result = await self._assess_claim_against_sources(claim, sources, user_id=user_id)
|
||||
|
||||
return Claim(
|
||||
text=claim,
|
||||
confidence=verification_result.get('confidence', 0.5),
|
||||
@@ -342,7 +313,7 @@ class HallucinationDetector:
|
||||
refuting_sources=verification_result.get('refuting_sources', []),
|
||||
reasoning=verification_result.get('reasoning', '')
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error verifying claim '{claim}': {str(e)}")
|
||||
return Claim(
|
||||
@@ -353,68 +324,40 @@ class HallucinationDetector:
|
||||
refuting_sources=[],
|
||||
reasoning=f"Error during verification: {str(e)}"
|
||||
)
|
||||
|
||||
async def _search_evidence_batch(self, claims: List[str]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for evidence for multiple claims in one API call.
|
||||
|
||||
Args:
|
||||
claims: List of claims to search for
|
||||
|
||||
Returns:
|
||||
List of sources relevant to the claims
|
||||
"""
|
||||
|
||||
async def _search_evidence_batch(self, claims: List[str], user_id: str = None) -> List[Dict[str, Any]]:
|
||||
"""Search for evidence for multiple claims in one API call."""
|
||||
try:
|
||||
# Combine all claims into one search query
|
||||
combined_query = " ".join(claims[:2]) # Use first 2 claims to avoid query length limits
|
||||
|
||||
combined_query = " ".join(claims[:2])
|
||||
logger.info(f"Searching for evidence for {len(claims)} claims with combined query")
|
||||
|
||||
# Use the existing search method with combined query
|
||||
sources = await self._search_evidence(combined_query)
|
||||
|
||||
# Limit sources to prevent excessive processing
|
||||
sources = await self._search_evidence(combined_query, user_id=user_id)
|
||||
|
||||
max_sources = 5
|
||||
if len(sources) > max_sources:
|
||||
sources = sources[:max_sources]
|
||||
logger.info(f"Limited sources to {max_sources} to prevent API rate limits")
|
||||
|
||||
|
||||
return sources
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch evidence search: {str(e)}")
|
||||
return []
|
||||
|
||||
async def _assess_claims_batch(self, claims: List[str], sources: List[Dict[str, Any]]) -> List[Claim]:
|
||||
"""
|
||||
Assess multiple claims against sources in one API call.
|
||||
|
||||
Args:
|
||||
claims: List of claims to assess
|
||||
sources: List of sources to assess against
|
||||
|
||||
Returns:
|
||||
List of Claim objects with assessment results
|
||||
"""
|
||||
if not self.gemini_client:
|
||||
raise Exception("Gemini client not available. Cannot assess claims without AI provider.")
|
||||
|
||||
async def _assess_claims_batch(self, claims: List[str], sources: List[Dict[str, Any]], user_id: str = None) -> List[Claim]:
|
||||
"""Assess multiple claims against sources in one LLM call."""
|
||||
try:
|
||||
# Limit to 3 claims to prevent excessive API usage
|
||||
claims_to_assess = claims[:3]
|
||||
|
||||
# Prepare sources text
|
||||
|
||||
combined_sources = "\n\n".join([
|
||||
f"Source {i+1}: {src.get('url','')}\nText: {src.get('text','')[:1000]}"
|
||||
for i, src in enumerate(sources)
|
||||
])
|
||||
|
||||
# Prepare claims text
|
||||
|
||||
claims_text = "\n".join([
|
||||
f"Claim {i+1}: {claim}"
|
||||
for i, claim in enumerate(claims_to_assess)
|
||||
])
|
||||
|
||||
|
||||
prompt = (
|
||||
"You are a strict fact-checker. Analyze each claim against the provided sources.\n\n"
|
||||
"Return ONLY a valid JSON object with this exact structure:\n"
|
||||
@@ -434,63 +377,36 @@ class HallucinationDetector:
|
||||
f"Sources:\n{combined_sources}\n\n"
|
||||
"Return only the JSON object:"
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
resp = await loop.run_in_executor(executor, lambda: self.gemini_client.models.generate_content(
|
||||
model="gemini-1.5-flash",
|
||||
contents=prompt
|
||||
))
|
||||
|
||||
if not resp or not resp.text:
|
||||
raise Exception("Empty response from Gemini API for batch assessment")
|
||||
|
||||
result_text = resp.text.strip()
|
||||
logger.info(f"Raw Gemini response for batch assessment: {result_text[:200]}...")
|
||||
|
||||
# Try to extract JSON from the response
|
||||
try:
|
||||
result = json.loads(result_text)
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON object in the response (handle markdown code blocks)
|
||||
import re
|
||||
code_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', result_text, re.DOTALL)
|
||||
if code_block_match:
|
||||
result = json.loads(code_block_match.group(1))
|
||||
else:
|
||||
json_match = re.search(r'\{.*?\}', result_text, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
else:
|
||||
raise Exception(f"Could not parse JSON from Gemini response: {result_text[:100]}")
|
||||
|
||||
# Process assessments
|
||||
|
||||
result_text = await self._generate_text_async(prompt, user_id=user_id)
|
||||
logger.info(f"Raw LLM response for batch assessment: {result_text[:200]}...")
|
||||
|
||||
result = self._parse_json_from_response(result_text, expect_array=False)
|
||||
|
||||
assessments = result.get('assessments', [])
|
||||
verified_claims = []
|
||||
|
||||
|
||||
for i, claim in enumerate(claims_to_assess):
|
||||
# Find assessment for this claim
|
||||
assessment = None
|
||||
for a in assessments:
|
||||
if a.get('claim_index') == i:
|
||||
assessment = a
|
||||
break
|
||||
|
||||
|
||||
if assessment:
|
||||
# Process supporting and refuting sources
|
||||
supporting_sources = []
|
||||
refuting_sources = []
|
||||
|
||||
|
||||
if isinstance(assessment.get('supporting_sources'), list):
|
||||
for idx in assessment['supporting_sources']:
|
||||
if isinstance(idx, int) and 0 <= idx < len(sources):
|
||||
supporting_sources.append(sources[idx])
|
||||
|
||||
|
||||
if isinstance(assessment.get('refuting_sources'), list):
|
||||
for idx in assessment['refuting_sources']:
|
||||
if isinstance(idx, int) and 0 <= idx < len(sources):
|
||||
refuting_sources.append(sources[idx])
|
||||
|
||||
|
||||
verified_claims.append(Claim(
|
||||
text=claim,
|
||||
confidence=float(assessment.get('confidence', 0.5)),
|
||||
@@ -500,7 +416,6 @@ class HallucinationDetector:
|
||||
reasoning=assessment.get('reasoning', '')
|
||||
))
|
||||
else:
|
||||
# No assessment found for this claim
|
||||
verified_claims.append(Claim(
|
||||
text=claim,
|
||||
confidence=0.0,
|
||||
@@ -509,13 +424,12 @@ class HallucinationDetector:
|
||||
refuting_sources=[],
|
||||
reasoning="No assessment provided"
|
||||
))
|
||||
|
||||
|
||||
logger.info(f"Successfully assessed {len(verified_claims)} claims in batch")
|
||||
return verified_claims
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in batch assessment: {str(e)}")
|
||||
# Return all claims as insufficient information
|
||||
return [
|
||||
Claim(
|
||||
text=claim,
|
||||
@@ -528,88 +442,32 @@ class HallucinationDetector:
|
||||
for claim in claims_to_assess
|
||||
]
|
||||
|
||||
async def _search_evidence(self, claim: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Search for evidence using Exa.ai API.
|
||||
|
||||
Args:
|
||||
claim: The claim to search evidence for
|
||||
|
||||
Returns:
|
||||
List of source documents with evidence
|
||||
"""
|
||||
if not self.exa_api_key:
|
||||
raise Exception("Exa API key not available. Cannot search for evidence without Exa.ai access.")
|
||||
|
||||
async def _search_evidence(self, claim: str, user_id: str = None) -> List[Dict[str, Any]]:
|
||||
"""Search for evidence using ExaResearchProvider with subscription checks."""
|
||||
try:
|
||||
headers = {
|
||||
'x-api-key': self.exa_api_key,
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
payload = {
|
||||
'query': claim,
|
||||
'numResults': 5,
|
||||
'text': True,
|
||||
'useAutoprompt': True
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
'https://api.exa.ai/search',
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=15
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
provider = ExaResearchProvider()
|
||||
sources = await provider.simple_search(
|
||||
query=claim,
|
||||
num_results=5,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
results = data.get('results', [])
|
||||
|
||||
if not results:
|
||||
raise Exception(f"No search results found for claim: {claim}")
|
||||
|
||||
sources = []
|
||||
for result in results:
|
||||
source = {
|
||||
'title': result.get('title', 'Untitled'),
|
||||
'url': result.get('url', ''),
|
||||
'text': result.get('text', ''),
|
||||
'publishedDate': result.get('publishedDate', ''),
|
||||
'author': result.get('author', ''),
|
||||
'score': result.get('score', 0.5)
|
||||
}
|
||||
sources.append(source)
|
||||
|
||||
logger.info(f"Found {len(sources)} sources for claim: {claim[:50]}...")
|
||||
return sources
|
||||
else:
|
||||
raise Exception(f"Exa API error: {response.status_code} - {response.text}")
|
||||
|
||||
if not sources:
|
||||
raise Exception(f"No search results found for claim: {claim}")
|
||||
logger.info(f"Found {len(sources)} sources for claim: {claim[:50]}...")
|
||||
return sources
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching evidence with Exa: {str(e)}")
|
||||
raise Exception(f"Failed to search evidence: {str(e)}")
|
||||
|
||||
|
||||
async def _assess_claim_against_sources(self, claim: str, sources: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Assess whether sources support or refute the claim using LLM.
|
||||
|
||||
Args:
|
||||
claim: The claim to assess
|
||||
sources: List of source documents
|
||||
|
||||
Returns:
|
||||
Dictionary with assessment results
|
||||
"""
|
||||
if not self.gemini_client:
|
||||
raise Exception("Gemini client not available. Cannot assess claims without AI provider.")
|
||||
|
||||
|
||||
async def _assess_claim_against_sources(self, claim: str, sources: List[Dict[str, Any]], user_id: str = None) -> Dict[str, Any]:
|
||||
"""Assess whether sources support or refute the claim using LLM."""
|
||||
try:
|
||||
combined_sources = "\n\n".join([
|
||||
f"Source {i+1}: {src.get('url','')}\nText: {src.get('text','')[:2000]}"
|
||||
for i, src in enumerate(sources)
|
||||
])
|
||||
|
||||
|
||||
prompt = (
|
||||
"You are a strict fact-checker. Analyze the claim against the provided sources.\n\n"
|
||||
"Return ONLY a valid JSON object with this exact structure:\n"
|
||||
@@ -624,70 +482,44 @@ class HallucinationDetector:
|
||||
f"Sources:\n{combined_sources}\n\n"
|
||||
"Return only the JSON object:"
|
||||
)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
resp = await loop.run_in_executor(executor, lambda: self.gemini_client.models.generate_content(
|
||||
model="gemini-1.5-flash",
|
||||
contents=prompt
|
||||
))
|
||||
|
||||
if not resp or not resp.text:
|
||||
raise Exception("Empty response from Gemini API for claim assessment")
|
||||
|
||||
result_text = resp.text.strip()
|
||||
logger.info(f"Raw Gemini response for assessment: {result_text[:200]}...")
|
||||
|
||||
# Try to extract JSON from the response
|
||||
try:
|
||||
result = json.loads(result_text)
|
||||
except json.JSONDecodeError:
|
||||
# Try to find JSON object in the response (handle markdown code blocks)
|
||||
import re
|
||||
# First try to extract from markdown code blocks
|
||||
code_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', result_text, re.DOTALL)
|
||||
if code_block_match:
|
||||
result = json.loads(code_block_match.group(1))
|
||||
else:
|
||||
# Try to find JSON object directly
|
||||
json_match = re.search(r'\{.*?\}', result_text, re.DOTALL)
|
||||
if json_match:
|
||||
result = json.loads(json_match.group())
|
||||
else:
|
||||
raise Exception(f"Could not parse JSON from Gemini response: {result_text[:100]}")
|
||||
|
||||
|
||||
result_text = await self._generate_text_async(prompt, user_id=user_id)
|
||||
logger.info(f"Raw LLM response for assessment: {result_text[:200]}...")
|
||||
|
||||
result = self._parse_json_from_response(result_text, expect_array=False)
|
||||
|
||||
# Validate required fields
|
||||
required_fields = ['assessment', 'confidence', 'supporting_sources', 'refuting_sources', 'reasoning']
|
||||
for field in required_fields:
|
||||
if field not in result:
|
||||
raise Exception(f"Missing required field '{field}' in assessment response")
|
||||
|
||||
|
||||
# Process supporting and refuting sources
|
||||
supporting_sources = []
|
||||
refuting_sources = []
|
||||
|
||||
|
||||
if isinstance(result.get('supporting_sources'), list):
|
||||
for idx in result['supporting_sources']:
|
||||
if isinstance(idx, int) and 0 <= idx < len(sources):
|
||||
supporting_sources.append(sources[idx])
|
||||
|
||||
|
||||
if isinstance(result.get('refuting_sources'), list):
|
||||
for idx in result['refuting_sources']:
|
||||
if isinstance(idx, int) and 0 <= idx < len(sources):
|
||||
refuting_sources.append(sources[idx])
|
||||
|
||||
|
||||
# Validate assessment value
|
||||
valid_assessments = ['supported', 'refuted', 'insufficient_information']
|
||||
if result['assessment'] not in valid_assessments:
|
||||
raise Exception(f"Invalid assessment value: {result['assessment']}")
|
||||
|
||||
|
||||
# Validate confidence value
|
||||
confidence = float(result['confidence'])
|
||||
if not (0.0 <= confidence <= 1.0):
|
||||
raise Exception(f"Invalid confidence value: {confidence}")
|
||||
|
||||
|
||||
logger.info(f"Successfully assessed claim: {result['assessment']} (confidence: {confidence})")
|
||||
|
||||
|
||||
return {
|
||||
'assessment': result['assessment'],
|
||||
'confidence': confidence,
|
||||
@@ -695,8 +527,39 @@ class HallucinationDetector:
|
||||
'refuting_sources': refuting_sources,
|
||||
'reasoning': result['reasoning']
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error assessing claim against sources: {str(e)}")
|
||||
raise Exception(f"Failed to assess claim: {str(e)}")
|
||||
|
||||
|
||||
def _parse_json_from_response(self, text: str, expect_array: bool = False):
|
||||
"""Extract and parse JSON from LLM response, handling markdown code blocks."""
|
||||
text = text.strip()
|
||||
|
||||
# Try direct parse first
|
||||
try:
|
||||
result = json.loads(text)
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
import re
|
||||
# Try to extract from markdown code blocks
|
||||
if expect_array:
|
||||
code_block_match = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
|
||||
if code_block_match:
|
||||
return json.loads(code_block_match.group(1))
|
||||
# Try to find JSON array directly
|
||||
json_match = re.search(r'\[.*\]', text, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
else:
|
||||
code_block_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', text, re.DOTALL)
|
||||
if code_block_match:
|
||||
return json.loads(code_block_match.group(1))
|
||||
# Try to find JSON object directly
|
||||
json_match = re.search(r'\{.*\}', text, re.DOTALL)
|
||||
if json_match:
|
||||
return json.loads(json_match.group())
|
||||
|
||||
raise Exception(f"Could not parse JSON from LLM response: {text[:100]}")
|
||||
@@ -53,6 +53,7 @@ class WixBlogService:
|
||||
"""Create draft post with consolidated logging"""
|
||||
from .logger import wix_logger
|
||||
import json
|
||||
import traceback as tb
|
||||
|
||||
# Build payload summary for logging
|
||||
payload_summary = {}
|
||||
@@ -65,7 +66,14 @@ class WixBlogService:
|
||||
}
|
||||
|
||||
request_headers = self.headers(access_token, extra_headers)
|
||||
response = requests.post(f"{self.base_url}/blog/v3/draft-posts", headers=request_headers, json=payload)
|
||||
try:
|
||||
response = requests.post(f"{self.base_url}/blog/v3/draft-posts", headers=request_headers, json=payload)
|
||||
except TypeError as e:
|
||||
logger.error(f"TypeError during requests.post in create_draft_post: {e}")
|
||||
logger.error(f"Traceback: {tb.format_exc()}")
|
||||
logger.error(f"access_token type: {type(access_token)}")
|
||||
logger.error(f"payload type: {type(payload)}, keys: {list(payload.keys()) if isinstance(payload, dict) else 'N/A'}")
|
||||
raise
|
||||
|
||||
# Consolidated error logging
|
||||
error_body = None
|
||||
|
||||
@@ -5,6 +5,7 @@ Handles blog post creation, validation, and publishing to Wix.
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import requests
|
||||
import jwt
|
||||
@@ -398,6 +399,30 @@ def create_blog_post(
|
||||
# Ensure we only have 'nodes' in richContent for CREATE endpoint
|
||||
ricos_content = {'nodes': ricos_content['nodes']}
|
||||
|
||||
# SAFE ITEM 4: Prepend H1 title node if content doesn't start with one.
|
||||
# The markdown typically starts at ## (H2) because the title is separate,
|
||||
# but Wix renders the richContent as the full post body including the title.
|
||||
# Without an H1, the post looks like it has no heading.
|
||||
existing_first = ricos_content['nodes'][0] if ricos_content['nodes'] else None
|
||||
has_h1 = existing_first and existing_first.get('type') == 'HEADING' and existing_first.get('headingData', {}).get('level') == 1
|
||||
if not has_h1 and title:
|
||||
title_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'HEADING',
|
||||
'nodes': [{
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'TEXT',
|
||||
'nodes': [],
|
||||
'textData': {
|
||||
'text': str(title).strip(),
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'headingData': {'level': 1}
|
||||
}
|
||||
ricos_content['nodes'] = [title_node] + ricos_content['nodes']
|
||||
logger.debug(f"Prepended H1 title node: '{str(title).strip()[:50]}'")
|
||||
|
||||
logger.debug(f"✅ richContent structure validated: {len(ricos_content['nodes'])} nodes, keys: {list(ricos_content.keys())}")
|
||||
|
||||
# Minimal payload per Wix docs: title, memberId, and richContent
|
||||
@@ -407,15 +432,39 @@ def create_blog_post(
|
||||
'title': str(title).strip() if title else "Untitled",
|
||||
'memberId': str(member_id).strip(), # Required for third-party apps (validated above)
|
||||
'richContent': ricos_content, # Must be a valid Ricos object with ONLY 'nodes'
|
||||
'language': 'en',
|
||||
},
|
||||
'publish': bool(publish),
|
||||
'fieldsets': ['URL'] # Simplified fieldsets
|
||||
}
|
||||
|
||||
# Add excerpt only if content exists and is not empty (avoid None or empty strings)
|
||||
excerpt = (content or '').strip()[:200] if content else None
|
||||
if excerpt and len(excerpt) > 0:
|
||||
blog_data['draftPost']['excerpt'] = str(excerpt)
|
||||
# SAFE ITEM 1: Auto-generate seoSlug from title if not provided by SEO metadata
|
||||
# Wix uses this for the URL path (e.g. /post/my-blog-title)
|
||||
slug_source = None
|
||||
if seo_metadata and seo_metadata.get('url_slug'):
|
||||
slug_source = str(seo_metadata['url_slug']).strip()
|
||||
elif title:
|
||||
slug_source = re.sub(r'[^a-z0-9]+', '-', str(title).strip().lower()).strip('-')
|
||||
slug_source = slug_source[:60].rstrip('-')
|
||||
if slug_source:
|
||||
blog_data['draftPost']['seoSlug'] = slug_source
|
||||
|
||||
# SAFE ITEM 3: Better excerpt — prefer meta_description, then first plain-text paragraph
|
||||
excerpt = None
|
||||
if seo_metadata and seo_metadata.get('meta_description'):
|
||||
excerpt = str(seo_metadata['meta_description']).strip()[:200]
|
||||
if not excerpt and content:
|
||||
for node in ricos_content['nodes']:
|
||||
if node.get('type') == 'PARAGRAPH':
|
||||
texts = []
|
||||
for child in node.get('nodes', []):
|
||||
if child.get('type') == 'TEXT' and child.get('textData', {}).get('text'):
|
||||
texts.append(child['textData']['text'])
|
||||
if texts:
|
||||
excerpt = ' '.join(texts).strip()[:200]
|
||||
break
|
||||
if excerpt:
|
||||
blog_data['draftPost']['excerpt'] = excerpt
|
||||
|
||||
# Add cover image if provided
|
||||
if cover_image_url and import_image_func:
|
||||
@@ -495,7 +544,6 @@ def create_blog_post(
|
||||
|
||||
# Build SEO data from metadata if provided
|
||||
# NOTE: seoData is optional - if it causes issues, we can create post without it
|
||||
seo_data = None
|
||||
if seo_metadata:
|
||||
try:
|
||||
seo_data = build_seo_data(seo_metadata, title)
|
||||
@@ -506,13 +554,8 @@ def create_blog_post(
|
||||
blog_data['draftPost']['seoData'] = seo_data
|
||||
except Exception as e:
|
||||
logger.warning(f"⚠️ Wix: SEO data build failed - {str(e)[:50]}")
|
||||
wix_logger.add_warning(f"SEO build: {str(e)[:50]}")
|
||||
|
||||
# Add SEO slug if provided
|
||||
if seo_metadata.get('url_slug'):
|
||||
blog_data['draftPost']['seoSlug'] = str(seo_metadata.get('url_slug')).strip()
|
||||
else:
|
||||
logger.warning("⚠️ No SEO metadata provided to create_blog_post")
|
||||
logger.debug("No SEO metadata provided to create_blog_post")
|
||||
|
||||
try:
|
||||
# Extract wix-site-id from token if possible
|
||||
@@ -534,7 +577,6 @@ def create_blog_post(
|
||||
meta_site_id = instance_data.get('metaSiteId')
|
||||
if isinstance(meta_site_id, str) and meta_site_id:
|
||||
extra_headers['wix-site-id'] = meta_site_id
|
||||
headers['wix-site-id'] = meta_site_id
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -574,156 +616,27 @@ def create_blog_post(
|
||||
logger.error(f"❌ Payload validation failed: {e}")
|
||||
raise
|
||||
|
||||
# Log full payload structure for debugging (sanitized)
|
||||
logger.warning(f"📦 Full payload structure validation:")
|
||||
logger.warning(f" - draftPost type: {type(draft_post)}")
|
||||
logger.warning(f" - draftPost keys: {list(draft_post.keys())}")
|
||||
logger.warning(f" - richContent type: {type(draft_post.get('richContent'))}")
|
||||
if 'richContent' in draft_post:
|
||||
rc = draft_post['richContent']
|
||||
logger.warning(f" - richContent keys: {list(rc.keys()) if isinstance(rc, dict) else 'N/A'}")
|
||||
logger.warning(f" - richContent.nodes type: {type(rc.get('nodes'))}, count: {len(rc.get('nodes', []))}")
|
||||
logger.warning(f" - richContent.metadata type: {type(rc.get('metadata'))}")
|
||||
logger.warning(f" - richContent.documentStyle type: {type(rc.get('documentStyle'))}")
|
||||
logger.warning(f" - seoData type: {type(draft_post.get('seoData'))}")
|
||||
if 'seoData' in draft_post:
|
||||
seo = draft_post['seoData']
|
||||
logger.warning(f" - seoData keys: {list(seo.keys()) if isinstance(seo, dict) else 'N/A'}")
|
||||
logger.warning(f" - seoData.tags type: {type(seo.get('tags'))}, count: {len(seo.get('tags', []))}")
|
||||
logger.warning(f" - seoData.settings type: {type(seo.get('settings'))}")
|
||||
if 'categoryIds' in draft_post:
|
||||
logger.warning(f" - categoryIds type: {type(draft_post.get('categoryIds'))}, count: {len(draft_post.get('categoryIds', []))}")
|
||||
if 'tagIds' in draft_post:
|
||||
logger.warning(f" - tagIds type: {type(draft_post.get('tagIds'))}, count: {len(draft_post.get('tagIds', []))}")
|
||||
|
||||
# Log a sample of the payload JSON to see exact structure (first 2000 chars)
|
||||
try:
|
||||
import json
|
||||
payload_json = json.dumps(blog_data, indent=2, ensure_ascii=False)
|
||||
logger.warning(f"📄 Payload JSON preview (first 3000 chars):\n{payload_json[:3000]}...")
|
||||
|
||||
# Also log a deep structure inspection of richContent.nodes (first few nodes)
|
||||
if 'richContent' in blog_data['draftPost']:
|
||||
nodes = blog_data['draftPost']['richContent'].get('nodes', [])
|
||||
if nodes:
|
||||
logger.warning(f"🔍 Inspecting first 5 richContent.nodes:")
|
||||
for i, node in enumerate(nodes[:5]):
|
||||
logger.warning(f" Node {i+1}: type={node.get('type')}, keys={list(node.keys())}")
|
||||
# Check for any None values in node
|
||||
for key, value in node.items():
|
||||
if value is None:
|
||||
logger.error(f" ⚠️ Node {i+1}.{key} is None!")
|
||||
elif isinstance(value, dict):
|
||||
for k, v in value.items():
|
||||
if v is None:
|
||||
logger.error(f" ⚠️ Node {i+1}.{key}.{k} is None!")
|
||||
# Deep check: if it's a list-type node, inspect list items
|
||||
if node.get('type') in ['BULLETED_LIST', 'ORDERED_LIST']:
|
||||
list_items = node.get('nodes', [])
|
||||
if list_items:
|
||||
logger.warning(f" List has {len(list_items)} items, checking first LIST_ITEM:")
|
||||
first_item = list_items[0]
|
||||
logger.warning(f" LIST_ITEM keys: {list(first_item.keys())}")
|
||||
# Verify listItemData is NOT present (correct per Wix API spec)
|
||||
if 'listItemData' in first_item:
|
||||
logger.error(f" ❌ LIST_ITEM incorrectly has listItemData!")
|
||||
else:
|
||||
logger.debug(f" ✅ LIST_ITEM correctly has no listItemData")
|
||||
# Check nested PARAGRAPH nodes
|
||||
nested_nodes = first_item.get('nodes', [])
|
||||
if nested_nodes:
|
||||
logger.warning(f" LIST_ITEM has {len(nested_nodes)} nested nodes")
|
||||
for n_idx, n_node in enumerate(nested_nodes[:2]):
|
||||
logger.warning(f" Nested node {n_idx+1}: type={n_node.get('type')}, keys={list(n_node.keys())}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not serialize payload for logging: {e}")
|
||||
|
||||
# Note: All node validation is done by validate_ricos_content() which runs earlier
|
||||
# The recursive validation ensures all required data fields are present at any depth
|
||||
# Log payload summary
|
||||
logger.debug(f"Payload: draftPost keys={list(draft_post.keys())}, "
|
||||
f"nodes={len(draft_post.get('richContent', {}).get('nodes', []))}, "
|
||||
f"has_seo={'seoData' in draft_post}")
|
||||
|
||||
# Final deep validation: Serialize and deserialize to catch any JSON-serialization issues
|
||||
# This will raise an error if there are any objects that can't be serialized
|
||||
try:
|
||||
import json
|
||||
test_json = json.dumps(blog_data, ensure_ascii=False)
|
||||
test_parsed = json.loads(test_json)
|
||||
logger.debug("✅ Payload JSON serialization test passed")
|
||||
json.dumps(blog_data, ensure_ascii=False)
|
||||
except (TypeError, ValueError) as e:
|
||||
logger.error(f"❌ Payload JSON serialization failed: {e}")
|
||||
raise ValueError(f"Payload contains non-serializable data: {e}")
|
||||
|
||||
# Final check: Ensure documentStyle and metadata are valid objects (not None, not empty strings)
|
||||
# Clean up None values that Wix API would reject
|
||||
rc = blog_data['draftPost']['richContent']
|
||||
if 'documentStyle' in rc:
|
||||
doc_style = rc['documentStyle']
|
||||
if doc_style is None or doc_style == "":
|
||||
logger.warning("⚠️ documentStyle is None or empty string, removing it")
|
||||
del rc['documentStyle']
|
||||
elif not isinstance(doc_style, dict):
|
||||
logger.warning(f"⚠️ documentStyle is not a dict ({type(doc_style)}), removing it")
|
||||
del rc['documentStyle']
|
||||
for field in ['documentStyle', 'metadata']:
|
||||
if field in rc and (rc[field] is None or rc[field] == "" or not isinstance(rc[field], dict)):
|
||||
del rc[field]
|
||||
|
||||
if 'metadata' in rc:
|
||||
metadata = rc['metadata']
|
||||
if metadata is None or metadata == "":
|
||||
logger.warning("⚠️ metadata is None or empty string, removing it")
|
||||
del rc['metadata']
|
||||
elif not isinstance(metadata, dict):
|
||||
logger.warning(f"⚠️ metadata is not a dict ({type(metadata)}), removing it")
|
||||
del rc['metadata']
|
||||
|
||||
# Check for any None values in critical nested structures
|
||||
def check_none_in_dict(d, path=""):
|
||||
"""Recursively check for None values that shouldn't be there"""
|
||||
issues = []
|
||||
if isinstance(d, dict):
|
||||
for key, value in d.items():
|
||||
current_path = f"{path}.{key}" if path else key
|
||||
if value is None:
|
||||
# Some fields can legitimately be None, but most shouldn't
|
||||
if key not in ['decorations', 'nodeStyle', 'props']:
|
||||
issues.append(current_path)
|
||||
elif isinstance(value, dict):
|
||||
issues.extend(check_none_in_dict(value, current_path))
|
||||
elif isinstance(value, list):
|
||||
for i, item in enumerate(value):
|
||||
if item is None:
|
||||
issues.append(f"{current_path}[{i}]")
|
||||
elif isinstance(item, dict):
|
||||
issues.extend(check_none_in_dict(item, f"{current_path}[{i}]"))
|
||||
return issues
|
||||
|
||||
none_issues = check_none_in_dict(blog_data['draftPost']['richContent'])
|
||||
if none_issues:
|
||||
logger.error(f"❌ Found None values in richContent at: {none_issues[:10]}") # Limit to first 10
|
||||
# Remove None values from critical paths
|
||||
for issue_path in none_issues[:5]: # Fix first 5
|
||||
parts = issue_path.split('.')
|
||||
try:
|
||||
obj = blog_data['draftPost']['richContent']
|
||||
for part in parts[:-1]:
|
||||
if '[' in part:
|
||||
key, idx = part.split('[')
|
||||
idx = int(idx.rstrip(']'))
|
||||
obj = obj[key][idx]
|
||||
else:
|
||||
obj = obj[part]
|
||||
final_key = parts[-1]
|
||||
if '[' in final_key:
|
||||
key, idx = final_key.split('[')
|
||||
idx = int(idx.rstrip(']'))
|
||||
obj[key][idx] = {}
|
||||
else:
|
||||
obj[final_key] = {}
|
||||
logger.warning(f"Fixed None value at {issue_path}")
|
||||
except:
|
||||
pass
|
||||
|
||||
# Log the final payload structure one more time before sending
|
||||
logger.warning(f"📤 Final payload ready - draftPost keys: {list(blog_data['draftPost'].keys())}")
|
||||
logger.warning(f"📤 RichContent nodes count: {len(blog_data['draftPost']['richContent'].get('nodes', []))}")
|
||||
logger.warning(f"📤 RichContent has metadata: {bool(blog_data['draftPost']['richContent'].get('metadata'))}")
|
||||
logger.warning(f"📤 RichContent has documentStyle: {bool(blog_data['draftPost']['richContent'].get('documentStyle'))}")
|
||||
logger.info(f"📤 Publishing to Wix: title='{blog_data['draftPost'].get('title', '')}', "
|
||||
f"nodes={len(rc.get('nodes', []))}")
|
||||
|
||||
result = blog_service.create_draft_post(access_token, blog_data, extra_headers or None)
|
||||
|
||||
@@ -734,6 +647,11 @@ def create_blog_post(
|
||||
logger.success(f"✅ Wix: Blog post created - ID: {post_id}")
|
||||
|
||||
return result
|
||||
except TypeError as e:
|
||||
import traceback
|
||||
logger.error(f"TypeError in create_blog_post: {e}")
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
raise
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"Failed to create blog post: {e}")
|
||||
if hasattr(e, 'response') and e.response is not None:
|
||||
|
||||
@@ -66,7 +66,8 @@ class WixLogger:
|
||||
if 'title' in dp:
|
||||
parts.append(f"title='{str(dp['title'])[:50]}...'")
|
||||
if 'richContent' in dp:
|
||||
nodes_count = len(dp['richContent'].get('nodes', []))
|
||||
nodes_val = dp['richContent'].get('nodes', [])
|
||||
nodes_count = nodes_val if isinstance(nodes_val, int) else len(nodes_val)
|
||||
parts.append(f"nodes={nodes_count}")
|
||||
if 'seoData' in dp:
|
||||
parts.append("has_seoData")
|
||||
|
||||
323
backend/services/link_search_service.py
Normal file
323
backend/services/link_search_service.py
Normal file
@@ -0,0 +1,323 @@
|
||||
"""
|
||||
Link Search Service — Internal & external link discovery and rewording.
|
||||
|
||||
Provides:
|
||||
- Internal link search (Exa include_domains scoped to user's website)
|
||||
- External link search (Exa general search, optionally excluding user's domain)
|
||||
- Reword-with-links (LLM embeds selected links naturally into section/selected text)
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Dict, Any, List, Optional
|
||||
from loguru import logger
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
|
||||
LINK_SEARCH_SYSTEM_PROMPT = """You are an SEO and content linking expert. Your task is to naturally incorporate provided links into text using markdown link syntax, following the best practices below.
|
||||
|
||||
## SEO Linking Best Practices
|
||||
|
||||
1. **Anchor text must be descriptive and keyword-rich.** Use the surrounding context to create natural, specific anchor text. Never use "click here", "read more", "learn more", or bare URLs as anchors.
|
||||
- GOOD: [HubSpot's content marketing statistics](url) — descriptive, includes keywords
|
||||
- BAD: [click here](url) — vague, no SEO value
|
||||
- BAD: [https://example.com](url) — raw URL, harmful to readability
|
||||
|
||||
2. **Match link type to content context:**
|
||||
- Internal links: Point anchor text at relevant topic keywords that describe the destination page
|
||||
- External links: Cite authoritative sources (research, official docs, industry leaders) using the source name or key finding as anchor text
|
||||
|
||||
3. **Link equity (PageRank) distribution:** Spread links naturally. Aim for 1-2 links per paragraph at most. Don't cluster all links together.
|
||||
|
||||
4. **Preserve the original text's meaning, tone, structure, and approximate length.** You are inserting links, NOT rewriting the content.
|
||||
|
||||
5. **If selected_text is provided, ONLY modify that specific portion.** The rest of section_text must remain IDENTICAL — character-for-character unchanged.
|
||||
|
||||
6. **If selected_text is NOT provided, you may insert links throughout the entire section_text.**
|
||||
|
||||
7. **Link placement should feel earned, not forced.** Only insert a link where a reader would genuinely want to learn more. If a link doesn't naturally fit, skip it.
|
||||
|
||||
8. **Prioritize high-authority external sources** (research papers, official documentation, industry leaders) when linking externally.
|
||||
|
||||
9. **Return ONLY the reworded text.** No explanations, no preamble, no markdown code fences. Just the text with [anchor text](url) links embedded."""
|
||||
|
||||
|
||||
LINK_SEARCH_USER_PROMPT = """## Section Heading
|
||||
{section_heading}
|
||||
|
||||
## Full Section Text
|
||||
{section_text}
|
||||
|
||||
{selected_text_block}
|
||||
|
||||
## Available Links to Incorporate
|
||||
{links}
|
||||
|
||||
## Instructions
|
||||
Carefully read the section text above and insert the most relevant links from the "Available Links" list using markdown format: [descriptive anchor text](url).
|
||||
|
||||
Remember:
|
||||
- Use keyword-rich, descriptive anchor text (NOT "click here" or bare URLs)
|
||||
- Only insert links where they naturally enhance the reader's experience
|
||||
- Preserve the original text's meaning, tone, and structure
|
||||
- Aim for 1-2 links per paragraph maximum
|
||||
- If no links fit naturally, return the text unchanged
|
||||
|
||||
Return ONLY the text with links embedded. No explanations."""
|
||||
|
||||
|
||||
def _extract_domain(url: str) -> str:
|
||||
"""Extract the registered domain from a URL.
|
||||
|
||||
Handles common multi-part TLDs like .co.uk, .com.au, .co.jp, etc.
|
||||
Falls back to last two parts for unknown TLDs.
|
||||
"""
|
||||
url = url.strip()
|
||||
if not url:
|
||||
return ""
|
||||
# Add protocol if missing
|
||||
if not url.startswith(("http://", "https://")):
|
||||
url = "https://" + url
|
||||
# Remove protocol
|
||||
domain = re.sub(r"^https?://", "", url)
|
||||
# Remove path and query
|
||||
domain = domain.split("/")[0].split("?")[0].split("#")[0]
|
||||
# Remove port
|
||||
domain = domain.split(":")[0]
|
||||
# Remove userinfo (user:pass@)
|
||||
if "@" in domain:
|
||||
domain = domain.split("@")[-1]
|
||||
domain = domain.lower().strip()
|
||||
if not domain:
|
||||
return ""
|
||||
|
||||
# Known multi-part TLDs (common ccTLDs with second-level domains)
|
||||
multi_part_tlds = {
|
||||
"co.uk", "org.uk", "ac.uk", "gov.uk", "co.jp", "or.jp", "ne.jp", "ac.jp",
|
||||
"co.au", "com.au", "org.au", "net.au", "co.nz", "net.nz", "org.nz",
|
||||
"co.in", "net.in", "org.in", "ac.in", "co.kr", "co.za", "org.za", "web.za",
|
||||
"com.br", "com.mx", "com.ar", "com.sg", "com.hk", "com.tw", "com.my",
|
||||
"com.cn", "org.cn", "net.cn", "ac.ke", "co.ke",
|
||||
}
|
||||
parts = domain.split(".")
|
||||
if len(parts) < 2:
|
||||
return domain
|
||||
|
||||
# Check if last two parts form a known multi-part TLD
|
||||
last_two = ".".join(parts[-2:])
|
||||
if last_two in multi_part_tlds and len(parts) > 2:
|
||||
# e.g. blog.example.co.uk → example.co.uk
|
||||
return ".".join(parts[-3:])
|
||||
# Default: last two parts (example.com)
|
||||
return ".".join(parts[-2:])
|
||||
|
||||
|
||||
def _filter_search_results(results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Filter out results with empty URLs or missing essential fields."""
|
||||
filtered = []
|
||||
for r in results:
|
||||
url = r.get("url", "").strip()
|
||||
title = r.get("title", "").strip() or "Untitled"
|
||||
if url:
|
||||
filtered.append({
|
||||
"title": title,
|
||||
"url": url,
|
||||
"text": r.get("text", ""),
|
||||
"publishedDate": r.get("publishedDate", ""),
|
||||
"author": r.get("author", ""),
|
||||
"score": r.get("score", 0.5),
|
||||
})
|
||||
return filtered
|
||||
|
||||
|
||||
class LinkSearchService:
    """Service for finding internal/external links and rewording text to include them."""

    # Upper bound on how much of a section is sent to the LLM per reword call.
    _MAX_SECTION_CHARS = 3000

    async def _run_search(
        self,
        kind: str,
        query: str,
        user_id: Optional[str],
        num_results: int,
        **domain_filters: Any,
    ) -> Dict[str, Any]:
        """Run one Exa search and normalise the results.

        Shared by ``search_internal`` and ``search_external``; failures are
        reported through the ``warnings`` list instead of being raised.

        Args:
            kind: "Internal" or "External" — used only in the error log line.
            query: Search query.
            user_id: Optional user ID forwarded to the provider.
            num_results: Number of results to request.
            **domain_filters: Extra keyword arguments forwarded verbatim to
                ``ExaResearchProvider.simple_search`` (``include_domains``
                or ``exclude_domains``).

        Returns:
            {"results": [...], "warnings": [...]}
        """
        warnings: List[str] = []
        try:
            # Imported lazily so a missing provider degrades to a warning
            # instead of breaking import of this module.
            from services.blog_writer.research.exa_provider import ExaResearchProvider

            provider = ExaResearchProvider()
            results = await provider.simple_search(
                query=query,
                num_results=num_results,
                user_id=user_id,
                **domain_filters,
            )
            return {"results": _filter_search_results(results), "warnings": warnings}

        except ImportError:
            msg = "Exa provider not available — link search requires Exa API."
            logger.warning(f"[LinkSearchService] {msg}")
            warnings.append(msg)
            return {"results": [], "warnings": warnings}
        except Exception as e:
            logger.error(f"[LinkSearchService] {kind} link search failed: {e}")
            warnings.append(f"Search failed: {str(e)}")
            return {"results": [], "warnings": warnings}

    async def search_internal(
        self,
        query: str,
        site_url: str,
        user_id: Optional[str] = None,
        num_results: int = 5,
    ) -> Dict[str, Any]:
        """
        Search for internal links (from the user's own website).

        Args:
            query: Search query (section topic/heading)
            site_url: User's website URL to scope search via include_domains
            user_id: Optional user ID for subscription tracking
            num_results: Number of results to return

        Returns:
            {"results": [...], "warnings": [...]}
        """
        domain = _extract_domain(site_url)

        # Without a domain the search cannot be scoped to the user's site.
        if not domain:
            return {
                "results": [],
                "warnings": [f"Could not extract domain from '{site_url}'"],
            }

        return await self._run_search(
            "Internal", query, user_id, num_results, include_domains=[domain]
        )

    async def search_external(
        self,
        query: str,
        site_url: Optional[str] = None,
        user_id: Optional[str] = None,
        num_results: int = 5,
    ) -> Dict[str, Any]:
        """
        Search for external links (optionally excluding the user's own domain).

        Args:
            query: Search query
            site_url: User's website URL — results from this domain will be excluded
            user_id: Optional user ID for subscription tracking
            num_results: Number of results to return

        Returns:
            {"results": [...], "warnings": [...]}
        """
        exclude_domains = None

        if site_url:
            domain = _extract_domain(site_url)
            if domain:
                exclude_domains = [domain]

        return await self._run_search(
            "External", query, user_id, num_results, exclude_domains=exclude_domains
        )

    def reword_with_links(
        self,
        section_text: str,
        links: List[Dict[str, str]],
        section_heading: Optional[str] = None,
        selected_text: Optional[str] = None,
        user_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Use LLM to reword text, naturally incorporating the selected links.

        Only the first ``_MAX_SECTION_CHARS`` characters of the section are
        sent to the LLM. Any remainder is appended unchanged to the reworded
        text so that long sections are not silently truncated.

        Args:
            section_text: Full section text
            links: List of {"url": str, "title": str} dicts; entries without
                a URL are ignored
            section_heading: Optional section heading for context
            selected_text: If provided, only reword this portion of the text
            user_id: Optional user ID for LLM routing

        Returns:
            {"reworded_text": str, "warnings": [...]} — on any failure
            ``reworded_text`` is the original ``section_text``.
        """
        warnings: List[str] = []

        if not links:
            return {
                "reworded_text": section_text,
                "warnings": ["No links provided — returning original text unchanged."],
            }

        # A link without a URL would reach the prompt as "[title]()" and
        # invite the LLM to emit an empty markdown link.
        usable_links = [link for link in links if str(link.get("url") or "").strip()]
        if not usable_links:
            return {
                "reworded_text": section_text,
                "warnings": ["No links with a URL were provided — returning original text unchanged."],
            }

        links_text = "\n".join(
            f"- [{link.get('title', 'Untitled')}]({link.get('url', '')}) — {link.get('title', '')}"
            for link in usable_links
        )

        selected_text_block = ""
        if selected_text:
            selected_text_block = f"Selected text to reword (keep surrounding text unchanged):\n{selected_text}"

        # Split the section into the part sent to the LLM and an untouched
        # tail that is re-attached afterwards.
        limit = self._MAX_SECTION_CHARS
        head, tail = section_text[:limit], section_text[limit:]
        if tail:
            # Prefer cutting at a line break so the LLM sees whole paragraphs.
            cut = head.rfind("\n")
            if cut > 0:
                head, tail = section_text[:cut], section_text[cut:]
            warnings.append(
                f"Section text exceeds {limit} characters — only the first part was "
                "sent for rewording; the remainder is returned unchanged."
            )

        prompt = LINK_SEARCH_USER_PROMPT.format(
            section_heading=section_heading or "Blog Section",
            section_text=head,
            selected_text_block=selected_text_block,
            links=links_text,
        )

        try:
            result = llm_text_gen(
                prompt=prompt,
                system_prompt=LINK_SEARCH_SYSTEM_PROMPT,
                json_struct=None,
                max_tokens=3000,
                user_id=user_id,
            )

            # The generator may hand back a dict with a "text" key or a plain value.
            if isinstance(result, dict):
                raw = str(result.get("text") or "")
            elif result:
                raw = str(result)
            else:
                raw = ""
            raw = raw.strip()

            # Strip markdown code fences if the LLM wrapped the output
            if raw.startswith("```"):
                match = re.search(r"```(?:markdown|md)?\s*(.*?)\s*```", raw, re.DOTALL)
                if match:
                    raw = match.group(1).strip()

            if not raw:
                warnings.append("LLM returned empty reworded text — returning original.")
                return {"reworded_text": section_text, "warnings": warnings}

            logger.info(f"[LinkSearchService] Reworded text: {len(raw)} chars, {len(links)} links provided")
            return {"reworded_text": raw + tail, "warnings": warnings}

        except Exception as e:
            logger.error(f"[LinkSearchService] Reword failed: {e}")
            warnings.append(f"Reword failed: {str(e)}")
            return {"reworded_text": section_text, "warnings": warnings}
|
||||
|
||||
|
||||
# Registry of LinkSearchService objects keyed by user ID ("default" when no
# user ID is given). The service holds no state, so this exists only to mirror
# the chart_service pattern.
_link_search_instances: Dict[str, LinkSearchService] = {}


def get_link_search_service(user_id: Optional[str] = None) -> LinkSearchService:
    """Return the LinkSearchService registered for ``user_id``, creating it on first use.

    Args:
        user_id: Key for the registry; a missing or empty value maps to
            the shared "default" entry.

    Returns:
        The service instance stored under that key.
    """
    key = user_id if user_id else "default"
    service = _link_search_instances.get(key)
    if service is None:
        service = LinkSearchService()
        _link_search_instances[key] = service
    return service
|
||||
@@ -429,6 +429,23 @@ def llm_text_gen(
|
||||
except Exception as provider_error:
|
||||
logger.error(f"[llm_text_gen] Provider {gpt_provider} failed: {str(provider_error)}")
|
||||
|
||||
# Surface balance/quota errors immediately without fallback
|
||||
error_str = str(provider_error).lower()
|
||||
if "insufficient_balance" in error_str or "balance_not_enough" in error_str or ("403" in error_str and "balance" in error_str):
|
||||
logger.error(f"[llm_text_gen] Balance/quota error from {gpt_provider}, not attempting fallback")
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "insufficient_balance",
|
||||
"message": f"Your {gpt_provider.capitalize()} API balance is insufficient. Please top up your account or switch providers.",
|
||||
"usage_info": {
|
||||
"error_type": "insufficient_balance",
|
||||
"provider": gpt_provider,
|
||||
"suggestion": f"Set GPT_PROVIDER=google in your environment to use Gemini instead, or add credits to your {gpt_provider.capitalize()} account."
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
# CIRCUIT BREAKER: Only try ONE fallback to prevent expensive API calls
|
||||
fallback_providers = ["google", "huggingface"]
|
||||
fallback_providers = [p for p in fallback_providers if p in available_providers and p != gpt_provider]
|
||||
|
||||
@@ -353,7 +353,11 @@ def wavespeed_text_response(
|
||||
|
||||
raise Exception(f"WaveSpeed text generation failed: {str(e)}")
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_wavespeed_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
stop=stop_after_attempt(6),
|
||||
)
|
||||
def wavespeed_structured_json_response(
|
||||
prompt: str,
|
||||
schema: Dict[str, Any],
|
||||
@@ -608,4 +612,20 @@ def wavespeed_structured_json_response(
|
||||
error_msg = str(e) if str(e) else repr(e)
|
||||
error_type = type(e).__name__
|
||||
logger.error(f"❌ WaveSpeed structured JSON generation failed [{error_type}]: {error_msg}")
|
||||
|
||||
# Surface balance/quota errors as HTTPException so upstream can show user-friendly messages
|
||||
from fastapi import HTTPException
|
||||
if "balance_not_enough" in error_msg or "403" in error_msg or "PermissionDenied" in error_type:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": "insufficient_balance",
|
||||
"message": "WaveSpeed API balance is insufficient. Please top up your WaveSpeed account or switch to a different provider.",
|
||||
"usage_info": {
|
||||
"error_type": "insufficient_balance",
|
||||
"provider": "wavespeed",
|
||||
"suggestion": "Set GPT_PROVIDER=google in your environment to use Gemini instead, or add credits to your WaveSpeed account."
|
||||
}
|
||||
}
|
||||
)
|
||||
raise Exception(f"WaveSpeed structured JSON generation failed: {error_msg}")
|
||||
|
||||
@@ -5,6 +5,8 @@ This service handles:
|
||||
- Chart data extraction from research
|
||||
- Individual scene B-roll video generation
|
||||
- Final video composition from multiple B-roll scenes
|
||||
|
||||
Chart preview generation is delegated to the shared ChartService.
|
||||
"""
|
||||
|
||||
import json
|
||||
@@ -15,21 +17,18 @@ from pathlib import Path
|
||||
from typing import Dict, Any, Optional, List, TYPE_CHECKING
|
||||
from loguru import logger
|
||||
|
||||
# Import chart generators directly
|
||||
# Import video compositing from broll_composer
|
||||
from services.podcast.broll_composer import (
|
||||
Insight,
|
||||
SceneAssets,
|
||||
dispatch_scene,
|
||||
compose_video,
|
||||
make_bar_chart,
|
||||
make_horizontal_bar,
|
||||
make_line_trend,
|
||||
make_pie_chart,
|
||||
make_stacked_bar,
|
||||
make_bullet_overlay,
|
||||
make_insight_card,
|
||||
)
|
||||
|
||||
# Import shared chart service for preview generation
|
||||
from services.chart_service import ChartService, get_chart_service
|
||||
|
||||
|
||||
class BrollService:
|
||||
"""Orchestrates B-roll composition for podcast scenes."""
|
||||
@@ -42,13 +41,14 @@ class BrollService:
|
||||
output_dir: Base directory for B-roll output. Defaults to workspace chart directory.
|
||||
user_id: User ID for multi-tenant workspace isolation.
|
||||
"""
|
||||
self._user_id = user_id
|
||||
if output_dir:
|
||||
self.output_dir = Path(output_dir)
|
||||
else:
|
||||
self.output_dir = self._get_chart_dir(user_id)
|
||||
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
logger.warning(f"[BrollService] Initialized with output directory: {self.output_dir}")
|
||||
logger.info(f"[BrollService] Initialized with output directory: {self.output_dir}")
|
||||
|
||||
def _get_chart_dir(self, user_id: Optional[str] = None) -> Path:
|
||||
"""Get chart directory from podcast constants (workspace-aware)."""
|
||||
@@ -78,145 +78,22 @@ class BrollService:
|
||||
"""
|
||||
Generate a chart PNG preview (static, for Write phase).
|
||||
|
||||
Args:
|
||||
chart_data: Chart data dict with labels, before/after, etc.
|
||||
chart_type: Type of chart (bar_comparison, bar_horizontal, line_trend, pie, stacked_bar, bullet)
|
||||
title: Title for the chart
|
||||
subtitle: Optional subtitle at bottom
|
||||
|
||||
Returns:
|
||||
Path to generated PNG file
|
||||
Delegates to ChartService for rendering, then returns the local file path.
|
||||
"""
|
||||
resolved_chart_id = chart_id or uuid.uuid4().hex[:8]
|
||||
out_path = str(self.get_chart_preview_path(resolved_chart_id))
|
||||
|
||||
# Debug logging
|
||||
logger.warning(f"[BrollService] Generating: type={chart_type}, data keys={list(chart_data.keys())}")
|
||||
logger.info(f"[BrollService] Generating chart preview: type={chart_type}, id={resolved_chart_id}")
|
||||
|
||||
try:
|
||||
if chart_type == "bar_comparison":
|
||||
# Accept both formats: {labels, before, after} OR {labels, values}
|
||||
labels = chart_data.get("labels", [])
|
||||
before = chart_data.get("before", [])
|
||||
after = chart_data.get("after", [])
|
||||
# If using new format (labels, values), treat as single bar chart
|
||||
if not before and not after:
|
||||
values = chart_data.get("values", [])
|
||||
if values:
|
||||
# Normalize to same length, truncating or padding as needed
|
||||
n = min(len(labels), len(values))
|
||||
labels = labels[:n]
|
||||
before = [0] * n
|
||||
after = values[:n]
|
||||
# Create modified data dict with proper format for make_bar_chart
|
||||
chart_data_for_render = {
|
||||
"labels": labels,
|
||||
"before": before,
|
||||
"after": after
|
||||
}
|
||||
else:
|
||||
chart_data_for_render = chart_data
|
||||
else:
|
||||
chart_data_for_render = chart_data
|
||||
if not labels or (not before and not after):
|
||||
logger.warning(f"[BrollService] Missing required data for bar_comparison: labels={len(labels)}, before={len(before)}, after={len(after)}")
|
||||
return ""
|
||||
if len(labels) != len(before) or len(labels) != len(after):
|
||||
logger.warning(f"[BrollService] Data shape mismatch: labels={len(labels)}, before={len(before)}, after={len(after)}")
|
||||
return ""
|
||||
make_bar_chart(chart_data_for_render, out_path, title, subtitle=subtitle)
|
||||
logger.warning(f"[BrollService] bar_comparison rendered: {out_path}, exists={os.path.exists(out_path)}")
|
||||
elif chart_type == "bar_horizontal":
|
||||
labels = chart_data.get("labels", [])
|
||||
values = chart_data.get("values", [])
|
||||
if not labels or not values:
|
||||
logger.warning("[BrollService] Missing required data for bar_horizontal")
|
||||
return ""
|
||||
make_horizontal_bar(chart_data, out_path, title)
|
||||
logger.warning(f"[BrollService] bar_horizontal rendered: {out_path}, exists={os.path.exists(out_path)}")
|
||||
elif chart_type == "line_trend":
|
||||
labels = chart_data.get("labels", [])
|
||||
values = chart_data.get("values", [])
|
||||
if not labels or not values:
|
||||
logger.warning("[BrollService] Missing required data for line_trend")
|
||||
return ""
|
||||
make_line_trend(chart_data, out_path, title)
|
||||
logger.warning(f"[BrollService] line_trend rendered: {out_path}, exists={os.path.exists(out_path)}")
|
||||
elif chart_type == "pie":
|
||||
labels = chart_data.get("labels", [])
|
||||
values = chart_data.get("values", [])
|
||||
if not labels or not values:
|
||||
logger.warning("[BrollService] Missing required data for pie")
|
||||
return ""
|
||||
make_pie_chart(chart_data, out_path, title)
|
||||
logger.warning(f"[BrollService] pie rendered: {out_path}, exists={os.path.exists(out_path)}")
|
||||
elif chart_type == "stacked_bar":
|
||||
labels = chart_data.get("labels", [])
|
||||
segments = chart_data.get("segments", [])
|
||||
if not labels or not segments:
|
||||
logger.warning("[BrollService] Missing required data for stacked_bar")
|
||||
return ""
|
||||
make_stacked_bar(chart_data, out_path, title)
|
||||
logger.warning(f"[BrollService] stacked_bar rendered: {out_path}, exists={os.path.exists(out_path)}")
|
||||
elif chart_type == "bullet" or chart_type == "bullet_points":
|
||||
# Accept both: bullet_points OR labels
|
||||
bullet_points = chart_data.get("bullet_points", [])
|
||||
# If using new format, use labels as bullet points
|
||||
if not bullet_points:
|
||||
bullet_points = chart_data.get("labels", [])
|
||||
if not bullet_points:
|
||||
labels_fallback = chart_data.get("labels", [])
|
||||
if labels_fallback:
|
||||
bullet_points = labels_fallback
|
||||
if bullet_points:
|
||||
make_bullet_overlay(bullet_points, out_path)
|
||||
logger.warning(f"[BrollService] bullet_points rendered: {out_path}, exists={os.path.exists(out_path)}")
|
||||
else:
|
||||
logger.warning("[BrollService] No bullet points provided")
|
||||
return ""
|
||||
else:
|
||||
logger.warning(f"[BrollService] Unknown chart type: {chart_type}, falling back to bar_comparison")
|
||||
# Try bar_comparison as fallback
|
||||
try:
|
||||
make_bar_chart(chart_data, out_path, title, subtitle=subtitle)
|
||||
return out_path
|
||||
except Exception as fallback_err:
|
||||
logger.warning(f"[BrollService] Fallback also failed: {fallback_err}")
|
||||
return ""
|
||||
|
||||
logger.warning(f"[BrollService] Chart preview generated: {out_path}, exists={os.path.exists(out_path) if out_path else 'N/A'}")
|
||||
|
||||
# Add source attribution overlay if present
|
||||
source = chart_data.get("source", "").strip()
|
||||
if source and out_path and os.path.exists(out_path):
|
||||
try:
|
||||
from PIL import Image as PILImage, ImageDraw, ImageFont
|
||||
img = PILImage.open(out_path).convert("RGBA")
|
||||
draw = ImageDraw.Draw(img)
|
||||
source_text = f"Source: {source[:80]}"
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 11)
|
||||
except (OSError, IOError):
|
||||
try:
|
||||
font = ImageFont.truetype("arial.ttf", 11)
|
||||
except (OSError, IOError):
|
||||
font = ImageFont.load_default()
|
||||
text_bbox = draw.textbbox((0, 0), source_text, font=font)
|
||||
text_w = text_bbox[2] - text_bbox[0]
|
||||
text_h = text_bbox[3] - text_bbox[1]
|
||||
x = img.width - text_w - 12
|
||||
y = img.height - text_h - 8
|
||||
draw.rectangle([x - 4, y - 2, x + text_w + 4, y + text_h + 2], fill=(0, 0, 0, 140))
|
||||
draw.text((x, y), source_text, fill=(200, 200, 200, 220), font=font)
|
||||
img.save(out_path)
|
||||
except Exception as src_err:
|
||||
logger.warning(f"[BrollService] Source overlay failed (non-fatal): {src_err}")
|
||||
|
||||
return out_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[BrollService] Failed to generate chart preview: {e}")
|
||||
return ""
|
||||
chart_svc = get_chart_service(user_id=self._user_id)
|
||||
result = chart_svc.generate_chart(
|
||||
chart_data=chart_data,
|
||||
chart_type=chart_type,
|
||||
title=title,
|
||||
subtitle=subtitle or "",
|
||||
chart_id=resolved_chart_id,
|
||||
)
|
||||
|
||||
return result.get("path", "")
|
||||
|
||||
def generate_scene_broll(
|
||||
self,
|
||||
|
||||
@@ -2,7 +2,6 @@ import os
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from dataclasses import dataclass
|
||||
import httpx
|
||||
from loguru import logger
|
||||
import random
|
||||
|
||||
@@ -24,13 +23,6 @@ class WritingAssistantService:
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.exa_api_key = os.getenv("EXA_API_KEY")
|
||||
|
||||
if not self.exa_api_key:
|
||||
logger.warning("EXA_API_KEY not configured; writing assistant will fail")
|
||||
|
||||
self.http_timeout_seconds = 15
|
||||
|
||||
# COST CONTROL: Daily usage limits
|
||||
self.daily_api_calls = 0
|
||||
self.daily_limit = 50 # Max 50 API calls per day (~$2.50 max cost)
|
||||
@@ -76,7 +68,7 @@ class WritingAssistantService:
|
||||
return []
|
||||
|
||||
# 1) Find relevant sources via Exa
|
||||
sources = await self._search_sources(text)
|
||||
sources = await self._search_sources(text, user_id=user_id)
|
||||
|
||||
# 2) Generate continuation suggestion via LLM grounded in sources
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources, user_id=user_id)
|
||||
@@ -86,51 +78,38 @@ class WritingAssistantService:
|
||||
|
||||
return [WritingSuggestion(text=suggestion_text.strip(), confidence=confidence, sources=sources)]
|
||||
|
||||
async def _search_sources(self, text: str) -> List[Dict[str, Any]]:
|
||||
if not self.exa_api_key:
|
||||
raise Exception("EXA_API_KEY not configured")
|
||||
|
||||
# Follow Exa demo guidance: continuation-style prompt and 1000-char cap
|
||||
exa_query = (
|
||||
(text[-1000:] if len(text) > 1000 else text)
|
||||
+ "\n\nIf you found the above interesting, here's another useful resource to read:"
|
||||
)
|
||||
|
||||
payload = {
|
||||
"query": exa_query,
|
||||
"numResults": 3, # Reduced from 5 to 3 for cost savings
|
||||
"text": True,
|
||||
"type": "neural",
|
||||
"highlights": {"numSentences": 1, "highlightsPerUrl": 1},
|
||||
}
|
||||
|
||||
async def _search_sources(self, text: str, user_id: str = None) -> List[Dict[str, Any]]:
|
||||
"""Search for relevant sources using ExaResearchProvider with subscription checks."""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.http_timeout_seconds) as client:
|
||||
resp = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Exa error {resp.status_code}: {resp.text}")
|
||||
data = resp.json()
|
||||
results = data.get("results", [])
|
||||
sources: List[Dict[str, Any]] = []
|
||||
for r in results:
|
||||
sources.append(
|
||||
{
|
||||
"title": r.get("title", "Untitled"),
|
||||
"url": r.get("url", ""),
|
||||
"text": r.get("text", ""),
|
||||
"author": r.get("author", ""),
|
||||
"published_date": r.get("publishedDate", ""),
|
||||
"score": float(r.get("score", 0.5)),
|
||||
}
|
||||
)
|
||||
# Explicitly fail if no sources to avoid generic completions
|
||||
if not sources:
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
|
||||
exa_query = (
|
||||
(text[-1000:] if len(text) > 1000 else text)
|
||||
+ "\n\nIf you found the above interesting, here's another useful resource to read:"
|
||||
)
|
||||
|
||||
provider = ExaResearchProvider()
|
||||
sources = await provider.simple_search(
|
||||
query=exa_query,
|
||||
num_results=3,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Normalize keys to match expected format
|
||||
normalized = []
|
||||
for s in sources:
|
||||
normalized.append({
|
||||
"title": s.get("title", "Untitled"),
|
||||
"url": s.get("url", ""),
|
||||
"text": s.get("text", ""),
|
||||
"author": s.get("author", ""),
|
||||
"published_date": s.get("publishedDate", ""),
|
||||
"score": float(s.get("score", 0.5)),
|
||||
})
|
||||
|
||||
if not normalized:
|
||||
raise Exception("No relevant sources found from Exa for the current context")
|
||||
return sources
|
||||
return normalized
|
||||
except Exception as e:
|
||||
logger.error(f"WritingAssistant _search_sources error: {e}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user