Enable phase-3 task memory filtering and add coverage
This commit is contained in:
@@ -13,6 +13,10 @@ from sqlalchemy.orm import Session
|
|||||||
from models.daily_workflow_models import TaskHistory, DailyWorkflowTask
|
from models.daily_workflow_models import TaskHistory, DailyWorkflowTask
|
||||||
from services.intelligence.txtai_service import TxtaiIntelligenceService
|
from services.intelligence.txtai_service import TxtaiIntelligenceService
|
||||||
|
|
||||||
|
EXACT_DUPLICATE_LOOKBACK_DAYS = 7
|
||||||
|
SEMANTIC_SUPPRESSION_SCORE_THRESHOLD = 0.85
|
||||||
|
SUPPRESSED_STATUSES = {"dismissed", "rejected"}
|
||||||
|
|
||||||
class TaskMemoryService:
|
class TaskMemoryService:
|
||||||
"""
|
"""
|
||||||
Manages the long-term memory of user tasks.
|
Manages the long-term memory of user tasks.
|
||||||
@@ -96,7 +100,7 @@ class TaskMemoryService:
|
|||||||
filtered = []
|
filtered = []
|
||||||
|
|
||||||
# Get recent history hashes (last 7 days)
|
# Get recent history hashes (within the EXACT_DUPLICATE_LOOKBACK_DAYS window)
|
||||||
cutoff = datetime.utcnow() - timedelta(days=7)
|
cutoff = datetime.utcnow() - timedelta(days=EXACT_DUPLICATE_LOOKBACK_DAYS)
|
||||||
recent_hashes = {
|
recent_hashes = {
|
||||||
row.task_hash for row in
|
row.task_hash for row in
|
||||||
self.db.query(TaskHistory.task_hash)
|
self.db.query(TaskHistory.task_hash)
|
||||||
@@ -117,23 +121,39 @@ class TaskMemoryService:
|
|||||||
is_semantic_duplicate = False
|
is_semantic_duplicate = False
|
||||||
try:
|
try:
|
||||||
# Check if similar tasks were REJECTED recently
|
# Check if similar tasks were REJECTED recently
|
||||||
results = self.intelligence.search(
|
results = await self.intelligence.search(
|
||||||
f"{p.title} {p.description}",
|
f"{p.title} {p.description}",
|
||||||
limit=1
|
limit=1
|
||||||
)
|
)
|
||||||
|
|
||||||
if results:
|
if results:
|
||||||
top = results[0]
|
top = results[0]
|
||||||
# If very similar (>0.85) and was REJECTED/DISMISSED
|
top_score = float(top.get("score", 0))
|
||||||
# We might need to fetch the metadata from the result if txtai returns it
|
if top_score >= SEMANTIC_SUPPRESSION_SCORE_THRESHOLD:
|
||||||
# For now, this is a heuristic stub. Txtai search returns dict with 'id', 'score', 'text', etc.
|
indexed_status = self._extract_indexed_status(top)
|
||||||
# If we stored 'status' in metadata, we check it.
|
if indexed_status in SUPPRESSED_STATUSES:
|
||||||
|
logger.info(
|
||||||
if top['score'] > 0.85:
|
f"Filtering redundant task (semantic {top_score:.2f}, indexed status={indexed_status}): {p.title}"
|
||||||
# Retrieve status from DB using vector_id if needed, or if metadata is returned
|
)
|
||||||
# Assuming we want to avoid repeating REJECTED ideas
|
is_semantic_duplicate = True
|
||||||
# This requires storing 'status' in the index metadata
|
else:
|
||||||
pass
|
vector_id = top.get("id") or top.get("vector_id")
|
||||||
|
if vector_id:
|
||||||
|
history = (
|
||||||
|
self.db.query(TaskHistory.status)
|
||||||
|
.filter(
|
||||||
|
TaskHistory.user_id == self.user_id,
|
||||||
|
TaskHistory.vector_id == str(vector_id),
|
||||||
|
)
|
||||||
|
.order_by(TaskHistory.created_at.desc())
|
||||||
|
.first()
|
||||||
|
)
|
||||||
|
history_status = getattr(history, "status", None)
|
||||||
|
if history_status in SUPPRESSED_STATUSES:
|
||||||
|
logger.info(
|
||||||
|
f"Filtering redundant task (semantic {top_score:.2f}, history status={history_status}): {p.title}"
|
||||||
|
)
|
||||||
|
is_semantic_duplicate = True
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -141,3 +161,16 @@ class TaskMemoryService:
|
|||||||
filtered.append(p)
|
filtered.append(p)
|
||||||
|
|
||||||
return filtered
|
return filtered
|
||||||
|
|
||||||
|
def _extract_indexed_status(self, search_result: Dict[str, Any]) -> Optional[str]:
|
||||||
|
"""Extract indexed status from txtai result metadata if available."""
|
||||||
|
status = search_result.get("status")
|
||||||
|
if status:
|
||||||
|
return str(status).lower()
|
||||||
|
|
||||||
|
obj = search_result.get("object")
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
obj_status = obj.get("status")
|
||||||
|
return str(obj_status).lower() if obj_status else None
|
||||||
|
|
||||||
|
return None
|
||||||
|
|||||||
@@ -204,8 +204,7 @@ async def generate_agent_enhanced_plan(db: Session, user_id: str, date: str) ->
|
|||||||
agent_tasks = list(unique_map.values())
|
agent_tasks = list(unique_map.values())
|
||||||
|
|
||||||
# Phase 3: Check memory for rejections (Semantic Filter)
|
# Phase 3: Check memory for rejections (Semantic Filter)
|
||||||
# For now, we rely on exact match logic in memory service if implemented fully
|
agent_tasks = await memory_service.filter_redundant_proposals(agent_tasks)
|
||||||
# agent_tasks = await memory_service.filter_redundant_proposals(agent_tasks)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Committee proposal phase failed: {e}")
|
logger.error(f"Committee proposal phase failed: {e}")
|
||||||
|
|||||||
101
backend/test/test_task_memory_service.py
Normal file
101
backend/test/test_task_memory_service.py
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
|
from models.daily_workflow_models import TaskHistory
|
||||||
|
from models.enhanced_strategy_models import Base
|
||||||
|
from services.task_memory_service import TaskMemoryService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def db_session():
    """Yield a SQLAlchemy session backed by a fresh in-memory SQLite database.

    Only the ``TaskHistory`` table is created.  On teardown the session is
    closed and the engine is disposed so no pooled connection outlives the
    test.
    """
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine, tables=[TaskHistory.__table__])
    SessionLocal = sessionmaker(bind=engine)
    session = SessionLocal()
    try:
        yield session
    finally:
        session.close()
        # Release the engine's connection pool; closing the session alone
        # only returns the connection to the pool.
        engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_filter_redundant_proposals_suppresses_exact_hash_duplicates(db_session):
    """A proposal whose hash matches a stored history row is filtered out."""
    memory = TaskMemoryService(user_id="user-1", db=db_session)
    # The semantic search mock returns no hits, so it cannot be the reason
    # the proposal is dropped.
    memory.intelligence = SimpleNamespace(search=AsyncMock(return_value=[]))

    task_fields = {
        "title": "Create LinkedIn post",
        "description": "Draft a post about customer success stories",
    }

    # Store a history row carrying the same hash the service computes for
    # the incoming proposal.
    prior = TaskHistory(
        user_id="user-1",
        task_hash=memory._compute_hash(
            task_fields["title"], task_fields["description"]
        ),
        pillar_id="engage",
        status="completed",
        created_at=datetime.utcnow(),
        vector_id="vec-exact",
        **task_fields,
    )
    db_session.add(prior)
    db_session.commit()

    remaining = await memory.filter_redundant_proposals(
        [SimpleNamespace(**task_fields)]
    )

    assert remaining == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_filter_redundant_proposals_suppresses_semantic_dismissed_by_vector_id_lookup(db_session):
    """A high-scoring hit is suppressed when its history row is dismissed."""
    memory = TaskMemoryService(user_id="user-2", db=db_session)

    # The hit has an id and a score but no status of its own; the only
    # "dismissed" status in play is on the history row sharing its vector_id.
    hit = {"id": "vec-dismissed", "score": 0.93}
    memory.intelligence = SimpleNamespace(search=AsyncMock(return_value=[hit]))

    dismissed_row = TaskHistory(
        user_id="user-2",
        task_hash="hash-1",
        title="Old task",
        description="Old description",
        pillar_id="plan",
        status="dismissed",
        created_at=datetime.utcnow(),
        vector_id="vec-dismissed",
    )
    db_session.add(dismissed_row)
    db_session.commit()

    candidate = SimpleNamespace(
        title="Plan daily content topics",
        description="Choose 3 content ideas for this week",
    )

    remaining = await memory.filter_redundant_proposals([candidate])

    assert remaining == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_filter_redundant_proposals_keeps_non_duplicates(db_session):
    """A proposal with only a low-scoring semantic hit passes through."""
    memory = TaskMemoryService(user_id="user-3", db=db_session)

    # Low-scoring hit, and no history rows are stored in this test.
    weak_hit = {"id": "vec-completed", "score": 0.40}
    search_mock = AsyncMock(return_value=[weak_hit])
    memory.intelligence = SimpleNamespace(search=search_mock)

    candidate = SimpleNamespace(
        title="Write newsletter intro",
        description="Prepare a short intro for the weekly newsletter",
    )

    remaining = await memory.filter_redundant_proposals([candidate])

    assert remaining == [candidate]
    # The semantic search must have been awaited exactly once.
    search_mock.assert_awaited_once()
|
||||||
Reference in New Issue
Block a user