On main: session-work-2026-05-22
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
import os
|
||||
import re
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
from dataclasses import dataclass
|
||||
from loguru import logger
|
||||
import random
|
||||
@@ -17,42 +18,33 @@ class WritingSuggestion:
|
||||
|
||||
class WritingAssistantService:
|
||||
"""
|
||||
Minimal writing assistant that combines Exa search with Gemini continuation.
|
||||
- Exa provides relevant sources with content snippets
|
||||
- Gemini generates a short, cited continuation based on current text and sources
|
||||
Writing assistant that combines Exa search with LLM continuation.
|
||||
- Searches relevant sources using the content near the cursor position
|
||||
- Generates a short continuation grounded in sources
|
||||
- Confidence derived from source availability and quality
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
# COST CONTROL: Daily usage limits
|
||||
self.daily_api_calls = 0
|
||||
self.daily_limit = 50 # Max 50 API calls per day (~$2.50 max cost)
|
||||
self.daily_limit = 50
|
||||
self.last_reset_date = None
|
||||
|
||||
def _get_cached_suggestion(self, text: str) -> WritingSuggestion | None:
|
||||
"""No cached suggestions - always use real API calls for authentic results."""
|
||||
return None
|
||||
|
||||
def _check_daily_limit(self) -> bool:
|
||||
"""Check if we're within daily API usage limits."""
|
||||
import datetime
|
||||
|
||||
today = datetime.date.today()
|
||||
|
||||
# Reset counter if it's a new day
|
||||
if self.last_reset_date != today:
|
||||
self.daily_api_calls = 0
|
||||
self.last_reset_date = today
|
||||
|
||||
# Check if we've exceeded the limit
|
||||
if self.daily_api_calls >= self.daily_limit:
|
||||
return False
|
||||
|
||||
# Increment counter for this API call
|
||||
self.daily_api_calls += 1
|
||||
logger.info(f"Writing assistant API call #{self.daily_api_calls}/{self.daily_limit} today")
|
||||
return True
|
||||
|
||||
async def suggest(self, text: str, user_id: str | None = None) -> List[WritingSuggestion]:
|
||||
async def suggest(self, text: str, user_id: str | None = None, cursor_position: Optional[int] = None) -> List[WritingSuggestion]:
|
||||
if not text or len(text.strip()) < 6:
|
||||
return []
|
||||
|
||||
@@ -67,26 +59,41 @@ class WritingAssistantService:
|
||||
if len(text.strip()) < 50:
|
||||
return []
|
||||
|
||||
# 1) Find relevant sources via Exa
|
||||
sources = await self._search_sources(text, user_id=user_id)
|
||||
# Use text before cursor for context (where the user is actively writing)
|
||||
if cursor_position is not None and 0 < cursor_position <= len(text):
|
||||
context_text = text[:cursor_position]
|
||||
else:
|
||||
context_text = text
|
||||
|
||||
# 2) Generate continuation suggestion via LLM grounded in sources
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources, user_id=user_id)
|
||||
# 1) Find relevant sources via Exa (non-fatal)
|
||||
sources = []
|
||||
try:
|
||||
sources = await self._search_sources(context_text, user_id=user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"WritingAssistant Exa search failed, proceeding without sources: {e}")
|
||||
|
||||
# 2) Generate continuation suggestion via LLM
|
||||
suggestion_text, confidence = await self._generate_continuation(context_text, sources, user_id=user_id)
|
||||
|
||||
if not suggestion_text:
|
||||
return []
|
||||
|
||||
return [WritingSuggestion(text=suggestion_text.strip(), confidence=confidence, sources=sources)]
|
||||
|
||||
async def _search_sources(self, text: str, user_id: str = None) -> List[Dict[str, Any]]:
|
||||
"""Search for relevant sources using ExaResearchProvider with subscription checks."""
|
||||
async def _search_sources(self, context_text: str, user_id: str = None) -> List[Dict[str, Any]]:
|
||||
"""Search Exa using the last sentence before cursor for a focused query."""
|
||||
try:
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
|
||||
exa_query = (
|
||||
(text[-1000:] if len(text) > 1000 else text)
|
||||
+ "\n\nIf you found the above interesting, here's another useful resource to read:"
|
||||
)
|
||||
# Extract the last sentence from context to use as a focused search query
|
||||
sentences = re.split(r'(?<=[.!?])\s+', context_text.strip())
|
||||
last_sentence = sentences[-1].strip().strip('"').strip("'") if sentences else context_text
|
||||
|
||||
# If very short, use last two sentences
|
||||
if len(last_sentence) < 20 and len(sentences) >= 2:
|
||||
last_sentence = ' '.join(s[-2:]).strip().strip('"').strip("'")
|
||||
|
||||
exa_query = last_sentence[:500] if len(last_sentence) > 500 else last_sentence
|
||||
|
||||
provider = ExaResearchProvider()
|
||||
sources = await provider.simple_search(
|
||||
@@ -95,7 +102,6 @@ class WritingAssistantService:
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Normalize keys to match expected format
|
||||
normalized = []
|
||||
for s in sources:
|
||||
normalized.append({
|
||||
@@ -104,7 +110,7 @@ class WritingAssistantService:
|
||||
"text": s.get("text", ""),
|
||||
"author": s.get("author", ""),
|
||||
"published_date": s.get("publishedDate", ""),
|
||||
"score": float(s.get("score", 0.5)),
|
||||
"score": float(s.get("score") if s.get("score") is not None else 0.5),
|
||||
})
|
||||
|
||||
if not normalized:
|
||||
@@ -151,8 +157,21 @@ class WritingAssistantService:
|
||||
suggestion = (str(ai_resp or "")).strip()
|
||||
if not suggestion:
|
||||
raise Exception("Assistive writer returned empty suggestion")
|
||||
confidence = 0.7
|
||||
return suggestion, confidence
|
||||
|
||||
# Dynamic confidence based on source quality and response signals
|
||||
confidence = 0.5
|
||||
if sources:
|
||||
# More sources and higher scores = more confident
|
||||
avg_score = sum(s.get("score", 0.5) for s in sources) / len(sources)
|
||||
confidence = 0.5 + (len(sources) / 6.0) * 0.3 + avg_score * 0.2
|
||||
if suggestion.endswith(('.', '!', '?')):
|
||||
confidence += 0.05
|
||||
# Check if citation hint was included
|
||||
if '[http' in suggestion or '((' in suggestion:
|
||||
confidence += 0.05
|
||||
confidence = min(confidence, 1.0)
|
||||
|
||||
return suggestion, round(confidence, 2)
|
||||
except Exception as e:
|
||||
logger.error(f"WritingAssistant _generate_continuation error: {e}")
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user