"""Persistence helpers for agent activity: runs, events, alerts and approval requests.

Every timestamp written by this module is a naive ``datetime`` expressed in UTC
(see ``_utcnow``). Incoming ISO-8601 filter values are normalised to the same
representation (see ``AgentActivityService._parse_datetime``) before being
compared against stored columns.
"""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Union

from sqlalchemy import func
from sqlalchemy.orm import Session

from models.agent_activity_models import AgentAlert, AgentApprovalRequest, AgentEvent, AgentRun

# Risk level stored on an approval request when the caller passes ``None``.
_DEFAULT_RISK_LEVEL = 0.5


def _utcnow() -> datetime:
    """Return the current UTC time as a naive datetime.

    Same value as the deprecated ``datetime.utcnow()``: the time is taken in
    UTC and the tzinfo is dropped, so it stays comparable with the naive
    values this module stores and filters on.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


@dataclass
class AgentEventPayload:
    """Shared schema for agent activity event payloads."""

    phase: Optional[str] = None
    step: Optional[str] = None
    tool_name: Optional[str] = None
    progress_percent: Optional[float] = None
    input_summary: Optional[str] = None
    output_summary: Optional[str] = None
    decision_reason: Optional[str] = None
    evidence_refs: List[str] = field(default_factory=list)
    safe_debug: bool = True
    metadata: Dict[str, Any] = field(default_factory=dict)


def build_agent_event_payload(
    *,
    phase: Optional[str] = None,
    step: Optional[str] = None,
    tool_name: Optional[str] = None,
    progress_percent: Optional[float] = None,
    input_summary: Optional[str] = None,
    output_summary: Optional[str] = None,
    decision_reason: Optional[str] = None,
    evidence_refs: Optional[List[str]] = None,
    safe_debug: bool = True,
    metadata: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Build a plain-dict event payload following the ``AgentEventPayload`` schema.

    ``evidence_refs`` and ``metadata`` are copied (``None`` becomes an empty
    list/dict), so the returned dict never aliases the caller's containers.
    ``safe_debug`` is coerced with ``bool()``.

    Returns:
        A dict with exactly the ``AgentEventPayload`` field names as keys.
    """
    return asdict(
        AgentEventPayload(
            phase=phase,
            step=step,
            tool_name=tool_name,
            progress_percent=progress_percent,
            input_summary=input_summary,
            output_summary=output_summary,
            decision_reason=decision_reason,
            evidence_refs=list(evidence_refs or []),
            safe_debug=bool(safe_debug),
            metadata=dict(metadata or {}),
        )
    )


def _normalize_event_payload(payload: Optional[Union[Dict[str, Any], AgentEventPayload]]) -> Dict[str, Any]:
    """Coerce an arbitrary event payload into the ``AgentEventPayload`` dict schema.

    - ``None`` yields an all-default payload.
    - An ``AgentEventPayload`` instance is converted with ``asdict``.
    - Any other non-dict value is stringified into ``output_summary``
      (truncated to 2000 characters) and marked ``safe_debug=False``.
    - A dict contributes only the known schema keys; unknown keys are dropped.
      ``evidence_refs``/``metadata`` of the wrong type are replaced by empties.
    """
    if payload is None:
        return build_agent_event_payload()
    if isinstance(payload, AgentEventPayload):
        return asdict(payload)
    if not isinstance(payload, dict):
        return build_agent_event_payload(output_summary=str(payload)[:2000], safe_debug=False)
    refs = payload.get("evidence_refs")
    meta = payload.get("metadata")
    return build_agent_event_payload(
        phase=payload.get("phase"),
        step=payload.get("step"),
        tool_name=payload.get("tool_name"),
        progress_percent=payload.get("progress_percent"),
        input_summary=payload.get("input_summary"),
        output_summary=payload.get("output_summary"),
        decision_reason=payload.get("decision_reason"),
        evidence_refs=refs if isinstance(refs, list) else [],
        safe_debug=bool(payload.get("safe_debug", True)),
        metadata=meta if isinstance(meta, dict) else {},
    )


class AgentActivityService:
    """User-scoped CRUD and feed queries over agent runs, events, alerts and approvals.

    Every query is filtered by ``self.user_id``. Every write commits
    immediately on the supplied session.
    """

    def __init__(self, db: Session, user_id: str):
        self.db = db
        self.user_id = user_id

    # ------------------------------------------------------------------ runs

    def start_run(self, agent_type: str, prompt: Optional[str] = None, mlflow_run_id: Optional[str] = None) -> AgentRun:
        """Create, commit and return a new run with status ``"running"``."""
        run = AgentRun(
            user_id=self.user_id,
            agent_type=agent_type,
            prompt=prompt,
            status="running",
            mlflow_run_id=mlflow_run_id,
            started_at=_utcnow(),
        )
        self.db.add(run)
        self.db.commit()
        self.db.refresh(run)
        return run

    def finish_run(
        self,
        run_id: int,
        success: bool,
        result_summary: Optional[str] = None,
        error_message: Optional[str] = None,
    ) -> None:
        """Mark this user's run as ``"completed"``/``"failed"`` and stamp ``finished_at``.

        Silently does nothing if the run does not exist or belongs to another user.
        """
        run = self.db.query(AgentRun).filter(AgentRun.id == run_id, AgentRun.user_id == self.user_id).first()
        if not run:
            return
        run.status = "completed" if success else "failed"
        run.success = bool(success)
        run.result_summary = result_summary
        run.error_message = error_message
        run.finished_at = _utcnow()
        self.db.add(run)
        self.db.commit()

    def list_runs(self, limit: int = 30) -> List[AgentRun]:
        """Return up to ``limit`` of this user's runs, newest ``started_at`` first."""
        return (
            self.db.query(AgentRun)
            .filter(AgentRun.user_id == self.user_id)
            .order_by(AgentRun.started_at.desc())
            .limit(limit)
            .all()
        )

    # ---------------------------------------------------------------- events

    def log_event(
        self,
        event_type: str,
        severity: str = "info",
        message: Optional[str] = None,
        payload: Optional[Union[Dict[str, Any], AgentEventPayload]] = None,
        run_id: Optional[int] = None,
        agent_type: Optional[str] = None,
    ) -> AgentEvent:
        """Persist an event whose payload is normalised to the shared schema."""
        evt = AgentEvent(
            run_id=run_id,
            user_id=self.user_id,
            agent_type=agent_type,
            event_type=event_type,
            severity=severity,
            message=message,
            payload=_normalize_event_payload(payload),
            created_at=_utcnow(),
        )
        self.db.add(evt)
        self.db.commit()
        self.db.refresh(evt)
        return evt

    def list_events(self, run_id: Optional[int] = None, limit: int = 200) -> List[AgentEvent]:
        """Return up to ``limit`` events (optionally for one run), newest first."""
        q = self.db.query(AgentEvent).filter(AgentEvent.user_id == self.user_id)
        if run_id is not None:
            q = q.filter(AgentEvent.run_id == run_id)
        return q.order_by(AgentEvent.created_at.desc()).limit(limit).all()

    # ---------------------------------------------------------------- alerts

    def create_alert(
        self,
        alert_type: str,
        title: str,
        message: str,
        severity: str = "info",
        payload: Optional[Dict[str, Any]] = None,
        cta_path: Optional[str] = None,
        dedupe_key: Optional[str] = None,
    ) -> Optional[AgentAlert]:
        """Create an alert, or return ``None`` if an unread alert shares ``dedupe_key``.

        Deduplication only considers unread alerts of this user, so the same key
        can be raised again once the previous alert has been marked read.
        """
        if dedupe_key:
            existing = (
                self.db.query(AgentAlert)
                .filter(
                    AgentAlert.user_id == self.user_id,
                    AgentAlert.dedupe_key == dedupe_key,
                    AgentAlert.read_at.is_(None),
                )
                .first()
            )
            if existing:
                return None
        alert = AgentAlert(
            user_id=self.user_id,
            source="agents",
            alert_type=alert_type,
            severity=severity,
            title=title,
            message=message,
            cta_path=cta_path,
            payload=payload,
            dedupe_key=dedupe_key,
            created_at=_utcnow(),
        )
        self.db.add(alert)
        self.db.commit()
        self.db.refresh(alert)
        return alert

    def list_alerts(self, unread_only: bool = True, limit: int = 50) -> List[AgentAlert]:
        """Return up to ``limit`` alerts (unread only by default), newest first."""
        q = self.db.query(AgentAlert).filter(AgentAlert.user_id == self.user_id)
        if unread_only:
            q = q.filter(AgentAlert.read_at.is_(None))
        return q.order_by(AgentAlert.created_at.desc()).limit(limit).all()

    def mark_alert_read(self, alert_id: int) -> bool:
        """Stamp ``read_at`` on this user's alert; return ``False`` if it was not found."""
        alert = self.db.query(AgentAlert).filter(AgentAlert.id == alert_id, AgentAlert.user_id == self.user_id).first()
        if not alert:
            return False
        alert.read_at = _utcnow()
        self.db.add(alert)
        self.db.commit()
        return True

    # ------------------------------------------------------------- approvals

    def create_approval_request(
        self,
        action_id: str,
        action_type: str,
        risk_level: float,
        payload: Optional[Dict[str, Any]] = None,
        agent_type: Optional[str] = None,
        target_resource: Optional[str] = None,
        run_id: Optional[int] = None,
        expires_at: Optional[datetime] = None,
    ) -> AgentApprovalRequest:
        """Create a ``"pending"`` approval request.

        ``risk_level`` falls back to ``_DEFAULT_RISK_LEVEL`` only when it is
        ``None``; an explicit ``0.0`` is kept as-is.
        """
        req = AgentApprovalRequest(
            user_id=self.user_id,
            run_id=run_id,
            agent_type=agent_type,
            action_id=action_id,
            action_type=action_type,
            target_resource=target_resource,
            # `risk_level or 0.5` used to turn a genuine 0.0 into 0.5.
            risk_level=_DEFAULT_RISK_LEVEL if risk_level is None else float(risk_level),
            payload=payload,
            status="pending",
            expires_at=expires_at,
            created_at=_utcnow(),
        )
        self.db.add(req)
        self.db.commit()
        self.db.refresh(req)
        return req

    def list_approval_requests(self, status: Optional[str] = "pending", limit: int = 50) -> List[AgentApprovalRequest]:
        """Return up to ``limit`` approval requests, filtered by ``status`` when truthy."""
        q = self.db.query(AgentApprovalRequest).filter(AgentApprovalRequest.user_id == self.user_id)
        if status:
            q = q.filter(AgentApprovalRequest.status == status)
        return q.order_by(AgentApprovalRequest.created_at.desc()).limit(limit).all()

    def decide_approval_request(self, approval_id: int, decision: str, user_comments: str = "") -> Optional[AgentApprovalRequest]:
        """Record an approve/reject decision on this user's approval request.

        Any ``decision`` other than ``"approved"`` (case/whitespace-insensitive)
        is treated as ``"rejected"``. Comments are truncated to 4000 characters.

        Returns:
            The updated request, or ``None`` if it was not found for this user.
        """
        # NOTE(review): the current status and expires_at are not checked, so an
        # already-decided or expired request can be decided again — confirm intent.
        req = (
            self.db.query(AgentApprovalRequest)
            .filter(AgentApprovalRequest.id == approval_id, AgentApprovalRequest.user_id == self.user_id)
            .first()
        )
        if not req:
            return None
        decision_value = str(decision or "").lower().strip()
        if decision_value != "approved":
            decision_value = "rejected"
        req.status = decision_value
        req.decision = decision_value
        req.user_comments = (user_comments or "")[:4000]
        req.decided_at = _utcnow()
        self.db.add(req)
        self.db.commit()
        self.db.refresh(req)
        return req

    # ------------------------------------------------------------------ feed

    def get_huddle_feed(
        self,
        since: Optional[str] = None,
        cursor: Optional[str] = None,
        runs_limit: int = 20,
        events_limit: int = 50,
        alerts_limit: int = 20,
        approvals_limit: int = 20,
    ) -> Dict[str, Any]:
        """Assemble the combined activity feed for this user.

        ``since`` (inclusive lower bound) and ``cursor`` (exclusive upper bound)
        are ISO-8601 strings. Unparseable values are ignored. The same ``cursor``
        is applied to all four streams, while a separate next-cursor is
        returned for each stream under ``"cursors"``.
        """
        now = _utcnow()
        since_dt = self._parse_datetime(since)
        cursor_dt = self._parse_datetime(cursor)
        statuses = self._get_active_statuses()
        runs = self._list_runs_for_feed(limit=runs_limit, since_dt=since_dt, cursor_dt=cursor_dt)
        events = self._list_events_for_feed(limit=events_limit, since_dt=since_dt, cursor_dt=cursor_dt)
        alerts = self._list_alerts_for_feed(limit=alerts_limit, since_dt=since_dt, cursor_dt=cursor_dt)
        approvals = self._list_approvals_for_feed(limit=approvals_limit, since_dt=since_dt, cursor_dt=cursor_dt)
        cursors = {
            "runs": self._next_cursor(runs, "started_at"),
            "events": self._next_cursor(events, "created_at"),
            "alerts": self._next_cursor(alerts, "created_at"),
            "approvals": self._next_cursor(approvals, "created_at"),
            "feed": now.isoformat(),
        }
        return {
            "statuses": statuses,
            "runs": [self._serialize_run(run) for run in runs],
            "events": [self._serialize_event(evt) for evt in events],
            "alerts": [self._serialize_alert(alert) for alert in alerts],
            "approvals": [self._serialize_approval(req) for req in approvals],
            "unread_alerts": self._count_unread_alerts(),
            "pending_approvals": self._count_pending_approvals(),
            "cursors": cursors,
            "server_timestamp": now.isoformat(),
        }

    @staticmethod
    def _parse_datetime(value: Optional[str]) -> Optional[datetime]:
        """Parse an ISO-8601 string into a naive UTC datetime, or ``None``.

        A trailing ``Z`` is accepted. Offset-aware values are converted to UTC
        before the tzinfo is dropped, so they compare correctly with the naive
        UTC timestamps this module stores. Naive input is returned unchanged
        (assumed to already be UTC).
        """
        if not value:
            return None
        text = str(value).strip()
        if not text:
            return None
        if text.endswith("Z"):
            text = text[:-1] + "+00:00"
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError:
            return None
        if parsed.tzinfo is not None:
            # Previously the offset was discarded without conversion, shifting
            # e.g. "+02:00" inputs by two hours relative to stored UTC values.
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    def _get_active_statuses(self) -> List[Dict[str, Any]]:
        """Return one status entry per agent type, built from its latest-started run."""
        latest_started = (
            self.db.query(
                AgentRun.agent_type.label("agent_type"),
                func.max(AgentRun.started_at).label("max_started_at"),
            )
            .filter(AgentRun.user_id == self.user_id)
            .group_by(AgentRun.agent_type)
            .subquery()
        )
        rows = (
            self.db.query(AgentRun)
            .join(
                latest_started,
                (AgentRun.agent_type == latest_started.c.agent_type)
                & (AgentRun.started_at == latest_started.c.max_started_at),
            )
            .filter(AgentRun.user_id == self.user_id)
            .all()
        )
        # Several runs of one agent type can share the maximal started_at; keep
        # only the highest id so each agent type appears exactly once.
        latest: Dict[Any, AgentRun] = {}
        for row in rows:
            current = latest.get(row.agent_type)
            if current is None or (row.id or 0) > (current.id or 0):
                latest[row.agent_type] = row
        statuses: List[Dict[str, Any]] = []
        for row in latest.values():
            updated = row.finished_at or row.started_at
            statuses.append(
                {
                    "agent_type": row.agent_type,
                    "status": row.status,
                    "success": row.success,
                    "run_id": row.id,
                    "updated_at": updated.isoformat() if updated else None,
                }
            )
        return statuses

    # NOTE(review): the feed queries below page with a strict `< cursor_dt` on the
    # timestamp of the last returned row (see _next_cursor), so rows that share
    # that exact timestamp but did not fit in the page are skipped on the next page.

    def _list_runs_for_feed(self, limit: int, since_dt: Optional[datetime], cursor_dt: Optional[datetime]) -> List[AgentRun]:
        """Runs in ``[since_dt, cursor_dt)`` by ``started_at``, newest first."""
        q = self.db.query(AgentRun).filter(AgentRun.user_id == self.user_id)
        if since_dt:
            q = q.filter(AgentRun.started_at >= since_dt)
        if cursor_dt:
            q = q.filter(AgentRun.started_at < cursor_dt)
        return q.order_by(AgentRun.started_at.desc()).limit(limit).all()

    def _list_events_for_feed(self, limit: int, since_dt: Optional[datetime], cursor_dt: Optional[datetime]) -> List[AgentEvent]:
        """Events in ``[since_dt, cursor_dt)`` by ``created_at``, newest first."""
        q = self.db.query(AgentEvent).filter(AgentEvent.user_id == self.user_id)
        if since_dt:
            q = q.filter(AgentEvent.created_at >= since_dt)
        if cursor_dt:
            q = q.filter(AgentEvent.created_at < cursor_dt)
        return q.order_by(AgentEvent.created_at.desc()).limit(limit).all()

    def _list_alerts_for_feed(self, limit: int, since_dt: Optional[datetime], cursor_dt: Optional[datetime]) -> List[AgentAlert]:
        """Unread alerts in ``[since_dt, cursor_dt)`` by ``created_at``, newest first."""
        q = self.db.query(AgentAlert).filter(AgentAlert.user_id == self.user_id, AgentAlert.read_at.is_(None))
        if since_dt:
            q = q.filter(AgentAlert.created_at >= since_dt)
        if cursor_dt:
            q = q.filter(AgentAlert.created_at < cursor_dt)
        return q.order_by(AgentAlert.created_at.desc()).limit(limit).all()

    def _list_approvals_for_feed(
        self,
        limit: int,
        since_dt: Optional[datetime],
        cursor_dt: Optional[datetime],
    ) -> List[AgentApprovalRequest]:
        """Pending approvals in ``[since_dt, cursor_dt)`` by ``created_at``, newest first."""
        q = self.db.query(AgentApprovalRequest).filter(
            AgentApprovalRequest.user_id == self.user_id,
            AgentApprovalRequest.status == "pending",
        )
        if since_dt:
            q = q.filter(AgentApprovalRequest.created_at >= since_dt)
        if cursor_dt:
            q = q.filter(AgentApprovalRequest.created_at < cursor_dt)
        return q.order_by(AgentApprovalRequest.created_at.desc()).limit(limit).all()

    def _count_unread_alerts(self) -> int:
        """Count this user's alerts with no ``read_at``."""
        return self.db.query(AgentAlert).filter(
            AgentAlert.user_id == self.user_id,
            AgentAlert.read_at.is_(None),
        ).count()

    def _count_pending_approvals(self) -> int:
        """Count this user's approval requests whose status is ``"pending"``."""
        return self.db.query(AgentApprovalRequest).filter(
            AgentApprovalRequest.user_id == self.user_id,
            AgentApprovalRequest.status == "pending",
        ).count()

    @staticmethod
    def _next_cursor(items: List[Any], time_attr: str) -> Optional[str]:
        """ISO timestamp of ``time_attr`` on the last item, or ``None`` if unavailable."""
        if not items:
            return None
        ts = getattr(items[-1], time_attr, None)
        return ts.isoformat() if ts else None

    # ---------------------------------------------------------- serializers

    @staticmethod
    def _serialize_run(run: AgentRun) -> Dict[str, Any]:
        """JSON-friendly dict for a run; datetimes rendered as ISO strings."""
        return {
            "id": run.id,
            "user_id": run.user_id,
            "agent_type": run.agent_type,
            "status": run.status,
            "success": run.success,
            "error_message": run.error_message,
            "result_summary": run.result_summary,
            "mlflow_run_id": run.mlflow_run_id,
            "started_at": run.started_at.isoformat() if run.started_at else None,
            "finished_at": run.finished_at.isoformat() if run.finished_at else None,
        }

    @staticmethod
    def _serialize_event(evt: AgentEvent) -> Dict[str, Any]:
        """JSON-friendly dict for an event; datetimes rendered as ISO strings."""
        return {
            "id": evt.id,
            "run_id": evt.run_id,
            "agent_type": evt.agent_type,
            "event_type": evt.event_type,
            "severity": evt.severity,
            "message": evt.message,
            "payload": evt.payload,
            "created_at": evt.created_at.isoformat() if evt.created_at else None,
        }

    @staticmethod
    def _serialize_alert(alert: AgentAlert) -> Dict[str, Any]:
        """JSON-friendly dict for an alert (``alert_type`` exposed as ``"type"``)."""
        return {
            "id": alert.id,
            "source": alert.source,
            "type": alert.alert_type,
            "severity": alert.severity,
            "title": alert.title,
            "message": alert.message,
            "cta_path": alert.cta_path,
            "payload": alert.payload,
            "created_at": alert.created_at.isoformat() if alert.created_at else None,
            "read_at": alert.read_at.isoformat() if alert.read_at else None,
        }

    @staticmethod
    def _serialize_approval(req: AgentApprovalRequest) -> Dict[str, Any]:
        """JSON-friendly dict for an approval request; datetimes as ISO strings."""
        return {
            "id": req.id,
            "status": req.status,
            "decision": req.decision,
            "action_id": req.action_id,
            "action_type": req.action_type,
            "agent_type": req.agent_type,
            "target_resource": req.target_resource,
            "risk_level": req.risk_level,
            "payload": req.payload,
            "created_at": req.created_at.isoformat() if req.created_at else None,
            "decided_at": req.decided_at.isoformat() if req.decided_at else None,
        }