Compare commits

..

1 Commits

Author SHA1 Message Date
ي
6fdf318d79 Add OAuth token refresh retries, status persistence, and alert payloads 2026-05-18 15:56:57 +05:30
6 changed files with 86 additions and 125 deletions

View File

@@ -40,6 +40,10 @@ class OAuthTokenMonitoringTask(Base):
# Scheduling
next_check = Column(DateTime, nullable=True, index=True) # Next scheduled check time
next_retry_at = Column(DateTime, nullable=True, index=True) # Backoff retry schedule for refresh failures
refresh_attempts = Column(Integer, default=0) # Current retry attempt count for refresh workflow
terminal_failure_reason = Column(Text, nullable=True) # Permanent failure reason requiring user action
channel_status = Column(String(32), default='connected') # connected, degraded, disconnected
# Metadata
created_at = Column(DateTime, default=datetime.utcnow)
@@ -97,4 +101,3 @@ class OAuthTokenExecutionLog(Base):
def __repr__(self):
return f"<OAuthTokenExecutionLog(id={self.id}, task_id={self.task_id}, status={self.status}, execution_date={self.execution_date})>"

View File

@@ -697,39 +697,6 @@ class BaseALwrityAgent(ABC):
"action_id": action.action_id,
"agent_id": self.agent_id,
}
capability_decision = self._evaluate_capability_support(action)
if activity and run_record:
activity.log_event(
event_type="decision",
severity="info" if capability_decision.get("supported", False) else "warning",
message=capability_decision.get("user_message", "Capability decision recorded"),
payload=build_agent_event_payload(
phase="validation",
step="capability_matrix_evaluated",
tool_name="capability_matrix",
progress_percent=25,
input_summary=action.action_type,
output_summary="Supported action" if capability_decision.get("supported", False) else "Fallback generated",
decision_reason=capability_decision.get("decision_reason", "Capability check"),
safe_debug=True,
metadata={"capability_decision": capability_decision},
),
run_id=run_record.id,
agent_type=self.agent_type,
)
if not capability_decision.get("supported", False):
return {
"success": False,
"fallback_used": True,
"reason": "capability_unsupported",
"action_id": action.action_id,
"agent_id": self.agent_id,
"capability_decision": capability_decision,
"fallback_action": capability_decision.get("fallback_action"),
"user_message": capability_decision.get("user_message"),
}
# 2. Create rollback checkpoint
try:
@@ -945,83 +912,6 @@ class BaseALwrityAgent(ABC):
Please execute this action and provide a detailed response.
Consider user goals, safety constraints, and potential impacts.
"""
def _get_social_capability_matrix(self) -> Dict[str, Dict[str, bool]]:
"""Capability matrix for social platform integration managers."""
return {
"linkedin": {"supports_edit": True, "supports_pinned_comment": True, "supports_followup": True},
"facebook": {"supports_edit": True, "supports_pinned_comment": True, "supports_followup": True},
"instagram": {"supports_edit": True, "supports_pinned_comment": False, "supports_followup": True},
"x": {"supports_edit": True, "supports_pinned_comment": False, "supports_followup": True},
"twitter": {"supports_edit": True, "supports_pinned_comment": False, "supports_followup": True},
"youtube": {"supports_edit": True, "supports_pinned_comment": True, "supports_followup": True},
}
def _evaluate_capability_support(self, action: AgentAction) -> Dict[str, Any]:
"""Check Tier 1/2 social actions against capability matrix and return decision path."""
platform = str(action.parameters.get("platform", "")).strip().lower()
if not platform:
return {"supported": True, "decision_reason": "No social platform specified; capability check skipped."}
matrix = self._get_social_capability_matrix()
platform_caps = matrix.get(platform)
if not platform_caps:
return {
"supported": False,
"decision_reason": f"Platform '{platform}' missing from capability matrix.",
"fallback_action": self._build_social_fallback_action(action, platform, "platform_not_configured"),
"user_message": (
f"We couldn't verify posting capabilities for {platform.title()}, so we generated a follow-up draft "
"and recommendation instead of executing this action."
),
}
action_tier = str(action.parameters.get("action_tier", "")).strip().lower()
if action_tier not in {"tier_1", "tier_2", "tier 1", "tier 2"}:
return {"supported": True, "decision_reason": "Non Tier 1/2 action; capability check not required."}
action_type = action.action_type.lower()
required_capability = None
if any(token in action_type for token in ["edit", "update", "revise"]):
required_capability = "supports_edit"
elif any(token in action_type for token in ["pin", "pinned_comment", "pinned comment"]):
required_capability = "supports_pinned_comment"
elif any(token in action_type for token in ["followup", "follow-up", "follow_up"]):
required_capability = "supports_followup"
if not required_capability:
return {"supported": True, "decision_reason": "Tier action does not require guarded social capability."}
supported = bool(platform_caps.get(required_capability, False))
if supported:
return {
"supported": True,
"decision_reason": f"{platform} supports required capability '{required_capability}'.",
"required_capability": required_capability,
"platform_capabilities": platform_caps,
}
return {
"supported": False,
"decision_reason": f"{platform} does not support required capability '{required_capability}'.",
"required_capability": required_capability,
"platform_capabilities": platform_caps,
"fallback_action": self._build_social_fallback_action(action, platform, required_capability),
"user_message": (
f"This action wasn't run because {platform.title()} does not support {required_capability}. "
"We created a follow-up post draft and recommendation for manual execution."
),
}
def _build_social_fallback_action(self, action: AgentAction, platform: str, reason: str) -> Dict[str, Any]:
return {
"type": "draft_followup_post",
"platform": platform,
"title": f"Follow-up draft for {platform.title()}",
"draft": f"Follow-up for original action '{action.action_type}' on {action.target_resource}.",
"recommendation": "Review and publish manually, then notify the team.",
"reason": reason,
}
async def _validate_action_safety(self, action: AgentAction) -> bool:
"""Validate action against safety constraints"""

View File

@@ -69,10 +69,6 @@ class SocialAmplificationAgent(BaseALwrityAgent):
# Instruction will be provided via orchestrator context or initial prompt
# Instruction should be provided during invocation or via orchestrator context
)
def get_social_integration_capabilities(self) -> Dict[str, Dict[str, bool]]:
"""Expose platform capability flags used by social integration managers."""
return self._get_social_capability_matrix()
# Tool Implementations

View File

@@ -26,7 +26,10 @@ from .executors.advertools_executor import AdvertoolsExecutor
from .executors.sif_indexing_executor import SIFIndexingExecutor
from .executors.market_trends_executor import MarketTrendsExecutor
from .utils.task_loader import load_due_monitoring_tasks
from .utils.oauth_token_task_loader import load_due_oauth_token_monitoring_tasks
from .utils.oauth_token_task_loader import (
load_due_oauth_token_monitoring_tasks,
load_near_expiry_oauth_token_tasks
)
from .utils.website_analysis_task_loader import load_due_website_analysis_tasks
from .utils.onboarding_full_website_analysis_task_loader import load_due_onboarding_full_website_analysis_tasks
from .utils.deep_competitor_analysis_task_loader import load_due_deep_competitor_analysis_tasks
@@ -70,6 +73,11 @@ def get_scheduler() -> TaskScheduler:
oauth_token_executor,
load_due_oauth_token_monitoring_tasks
)
_scheduler_instance.register_executor(
'oauth_token_refresh',
oauth_token_executor,
load_near_expiry_oauth_token_tasks
)
# Register website analysis executor
website_analysis_executor = WebsiteAnalysisExecutor()

View File

@@ -42,6 +42,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
self.exception_handler = SchedulerExceptionHandler()
# Expiration warning window (7 days before expiration)
self.expiration_warning_days = 7
self.max_refresh_retries = 3
self.base_retry_backoff_minutes = 15
async def execute_task(self, task: OAuthTokenMonitoringTask, db: Session) -> TaskExecutionResult:
"""
@@ -93,6 +95,10 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
task.last_success = datetime.utcnow()
task.status = 'active'
task.failure_reason = None
task.terminal_failure_reason = None
task.channel_status = 'connected'
task.refresh_attempts = 0
task.next_retry_at = None
# Reset failure tracking on success
task.consecutive_failures = 0
task.failure_pattern = None
@@ -112,6 +118,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
task.last_failure = datetime.utcnow()
task.failure_reason = result.error_message
task.refresh_attempts = (task.refresh_attempts or 0) + 1
if pattern and pattern.should_cool_off:
# Mark task for human intervention
@@ -126,6 +133,9 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
}
# Clear next_check - task won't run automatically
task.next_check = None
task.next_retry_at = None
task.channel_status = "disconnected"
task.terminal_failure_reason = result.error_message
self.logger.warning(
f"Task {task.id} marked for human intervention: "
@@ -133,10 +143,17 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
f"reason: {pattern.failure_reason.value}"
)
else:
# Normal failure handling
task.status = 'failed'
task.consecutive_failures = (task.consecutive_failures or 0) + 1
# Do NOT update next_check - wait for manual trigger
if task.refresh_attempts >= self.max_refresh_retries:
task.status = 'failed'
task.channel_status = 'disconnected'
task.terminal_failure_reason = result.error_message
task.next_retry_at = None
else:
task.status = 'degraded'
task.channel_status = 'degraded'
delay_minutes = self.base_retry_backoff_minutes * (2 ** (task.refresh_attempts - 1))
task.next_retry_at = datetime.utcnow() + timedelta(minutes=delay_minutes)
self.logger.warning(
f"OAuth token refresh failed for user {user_id}, platform {platform}. "
@@ -144,7 +161,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
)
# Create UsageAlert notification for the user
self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db)
self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db, task)
task.updated_at = datetime.utcnow()
db.commit()
@@ -193,12 +210,14 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
task.last_failure = datetime.utcnow()
task.failure_reason = str(e)
task.status = 'failed'
task.channel_status = 'disconnected'
task.terminal_failure_reason = str(e)
task.last_check = datetime.utcnow()
task.updated_at = datetime.utcnow()
# Do NOT update next_check - wait for manual trigger
task.next_retry_at = None
# Create UsageAlert notification for the user
self._create_failure_alert(user_id, task.platform, str(e), None, db)
self._create_failure_alert(user_id, task.platform, str(e), None, db, task)
db.commit()
except Exception as commit_error:
@@ -651,7 +670,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
platform: str,
error_message: str,
result_data: Optional[Dict[str, Any]],
db: Session
db: Session,
task: Optional[OAuthTokenMonitoringTask] = None
):
"""
Create a UsageAlert notification when OAuth token refresh fails.
@@ -723,6 +743,20 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
# Get current billing period (YYYY-MM format)
from datetime import datetime
billing_period = datetime.utcnow().strftime("%Y-%m")
alert_payload = {
"requires_user_action": True,
"platform": platform,
"channel_status": getattr(task, "channel_status", "disconnected"),
"terminal_failure_reason": getattr(task, "terminal_failure_reason", error_message),
"next_retry_at": (
task.next_retry_at.isoformat() if task and task.next_retry_at else None
),
"refresh_attempts": getattr(task, "refresh_attempts", 0),
"max_refresh_retries": self.max_refresh_retries,
}
message = f"{message} [ALERT_PAYLOAD] {alert_payload}"
# Create UsageAlert
alert = UsageAlert(
@@ -786,4 +820,3 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
f"Defaulting to Weekly (7 days)."
)
return last_execution + timedelta(days=7)

View File

@@ -3,7 +3,7 @@ OAuth Token Monitoring Task Loader
Functions to load due OAuth token monitoring tasks from database.
"""
from datetime import datetime
from datetime import datetime, timedelta
from typing import List, Optional, Union
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
@@ -52,3 +52,34 @@ def load_due_oauth_token_monitoring_tasks(
return query.all()
def load_near_expiry_oauth_token_tasks(
db: Session,
refresh_horizon_hours: int = 24,
user_id: Optional[Union[str, int]] = None
) -> List[OAuthTokenMonitoringTask]:
"""
Load OAuth tasks that should run token refresh logic soon.
Includes:
- tasks with a scheduled retry now due (next_retry_at <= now)
- tasks whose routine check is inside the near-expiry horizon window
"""
now = datetime.utcnow()
horizon = now + timedelta(hours=max(refresh_horizon_hours, 1))
query = db.query(OAuthTokenMonitoringTask).filter(
and_(
OAuthTokenMonitoringTask.status.in_(['active', 'failed', 'degraded']),
or_(
OAuthTokenMonitoringTask.next_retry_at <= now,
OAuthTokenMonitoringTask.next_check <= horizon,
OAuthTokenMonitoringTask.next_check.is_(None)
)
)
)
if user_id is not None:
query = query.filter(OAuthTokenMonitoringTask.user_id == str(user_id))
return query.all()