diff --git a/backend/models/oauth_token_monitoring_models.py b/backend/models/oauth_token_monitoring_models.py index a6d4c48e..6830a62b 100644 --- a/backend/models/oauth_token_monitoring_models.py +++ b/backend/models/oauth_token_monitoring_models.py @@ -40,6 +40,10 @@ class OAuthTokenMonitoringTask(Base): # Scheduling next_check = Column(DateTime, nullable=True, index=True) # Next scheduled check time + next_retry_at = Column(DateTime, nullable=True, index=True) # Backoff retry schedule for refresh failures + refresh_attempts = Column(Integer, default=0) # Current retry attempt count for refresh workflow + terminal_failure_reason = Column(Text, nullable=True) # Permanent failure reason requiring user action + channel_status = Column(String(32), default='connected') # connected, degraded, disconnected # Metadata created_at = Column(DateTime, default=datetime.utcnow) @@ -97,4 +101,3 @@ class OAuthTokenExecutionLog(Base): def __repr__(self): return f"" - diff --git a/backend/services/scheduler/__init__.py b/backend/services/scheduler/__init__.py index 5d7d1983..08dd0f2a 100644 --- a/backend/services/scheduler/__init__.py +++ b/backend/services/scheduler/__init__.py @@ -26,7 +26,10 @@ from .executors.advertools_executor import AdvertoolsExecutor from .executors.sif_indexing_executor import SIFIndexingExecutor from .executors.market_trends_executor import MarketTrendsExecutor from .utils.task_loader import load_due_monitoring_tasks -from .utils.oauth_token_task_loader import load_due_oauth_token_monitoring_tasks +from .utils.oauth_token_task_loader import ( + load_due_oauth_token_monitoring_tasks, + load_near_expiry_oauth_token_tasks +) from .utils.website_analysis_task_loader import load_due_website_analysis_tasks from .utils.onboarding_full_website_analysis_task_loader import load_due_onboarding_full_website_analysis_tasks from .utils.deep_competitor_analysis_task_loader import load_due_deep_competitor_analysis_tasks @@ -70,6 +73,11 @@ def get_scheduler() -> 
TaskScheduler: oauth_token_executor, load_due_oauth_token_monitoring_tasks ) + _scheduler_instance.register_executor( + 'oauth_token_refresh', + oauth_token_executor, + load_near_expiry_oauth_token_tasks + ) # Register website analysis executor website_analysis_executor = WebsiteAnalysisExecutor() diff --git a/backend/services/scheduler/executors/oauth_token_monitoring_executor.py b/backend/services/scheduler/executors/oauth_token_monitoring_executor.py index 33922e4b..5c9be289 100644 --- a/backend/services/scheduler/executors/oauth_token_monitoring_executor.py +++ b/backend/services/scheduler/executors/oauth_token_monitoring_executor.py @@ -42,6 +42,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): self.exception_handler = SchedulerExceptionHandler() # Expiration warning window (7 days before expiration) self.expiration_warning_days = 7 + self.max_refresh_retries = 3 + self.base_retry_backoff_minutes = 15 async def execute_task(self, task: OAuthTokenMonitoringTask, db: Session) -> TaskExecutionResult: """ @@ -93,6 +95,10 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): task.last_success = datetime.utcnow() task.status = 'active' task.failure_reason = None + task.terminal_failure_reason = None + task.channel_status = 'connected' + task.refresh_attempts = 0 + task.next_retry_at = None # Reset failure tracking on success task.consecutive_failures = 0 task.failure_pattern = None @@ -112,6 +118,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): task.last_failure = datetime.utcnow() task.failure_reason = result.error_message + task.refresh_attempts = (task.refresh_attempts or 0) + 1 if pattern and pattern.should_cool_off: # Mark task for human intervention @@ -126,6 +133,9 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): } # Clear next_check - task won't run automatically task.next_check = None + task.next_retry_at = None + task.channel_status = "disconnected" + task.terminal_failure_reason = result.error_message self.logger.warning( f"Task {task.id} 
marked for human intervention: " @@ -133,10 +143,17 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): f"reason: {pattern.failure_reason.value}" ) else: - # Normal failure handling - task.status = 'failed' task.consecutive_failures = (task.consecutive_failures or 0) + 1 - # Do NOT update next_check - wait for manual trigger + if task.refresh_attempts >= self.max_refresh_retries: + task.status = 'failed' + task.channel_status = 'disconnected' + task.terminal_failure_reason = result.error_message + task.next_retry_at = None + else: + task.status = 'degraded' + task.channel_status = 'degraded' + delay_minutes = self.base_retry_backoff_minutes * (2 ** (task.refresh_attempts - 1)) + task.next_retry_at = datetime.utcnow() + timedelta(minutes=delay_minutes) self.logger.warning( f"OAuth token refresh failed for user {user_id}, platform {platform}. " @@ -144,7 +161,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): ) # Create UsageAlert notification for the user - self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db) + self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db, task) task.updated_at = datetime.utcnow() db.commit() @@ -193,12 +210,14 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): task.last_failure = datetime.utcnow() task.failure_reason = str(e) task.status = 'failed' + task.channel_status = 'disconnected' + task.terminal_failure_reason = str(e) task.last_check = datetime.utcnow() task.updated_at = datetime.utcnow() - # Do NOT update next_check - wait for manual trigger + task.next_retry_at = None # Create UsageAlert notification for the user - self._create_failure_alert(user_id, task.platform, str(e), None, db) + self._create_failure_alert(user_id, task.platform, str(e), None, db, task) db.commit() except Exception as commit_error: @@ -651,7 +670,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor): platform: str, error_message: str, result_data: Optional[Dict[str, Any]], - 
def load_near_expiry_oauth_token_tasks(
    db: Session,
    refresh_horizon_hours: int = 24,
    user_id: Optional[Union[str, int]] = None
) -> List[OAuthTokenMonitoringTask]:
    """
    Load OAuth tasks that should run token refresh logic soon.

    A task is returned when either:
    - it has a backoff retry scheduled and that retry is now due
      (next_retry_at <= now), or
    - it is 'active', has no retry pending, and its routine check is
      unscheduled (next_check is NULL) or falls inside the near-expiry
      horizon window (next_check <= now + horizon).

    The two conditions are kept separate on purpose. A task that is waiting
    on a future retry, or that has stopped retrying and needs user action
    (status 'failed' with no next_retry_at), can still carry an old
    next_check value; matching on that value alone would re-run the task on
    every scheduler pass and defeat the retry backoff.

    Args:
        db: Database session used to run the query.
        refresh_horizon_hours: Look-ahead window in hours. Values below 1
            are raised to 1.
        user_id: Optional user filter. Compared against the stored user_id
            as a string.

    Returns:
        List of matching OAuthTokenMonitoringTask rows.
    """
    now = datetime.utcnow()
    horizon = now + timedelta(hours=max(refresh_horizon_hours, 1))

    # A NULL next_retry_at never satisfies "<= now" in SQL, so this clause
    # only matches rows that actually have a retry scheduled.
    retry_due = and_(
        OAuthTokenMonitoringTask.status.in_(['active', 'failed', 'degraded']),
        OAuthTokenMonitoringTask.next_retry_at <= now
    )

    # Routine near-expiry checks apply only to healthy tasks that are not
    # in the middle of a retry sequence.
    routine_check_due = and_(
        OAuthTokenMonitoringTask.status == 'active',
        OAuthTokenMonitoringTask.next_retry_at.is_(None),
        or_(
            OAuthTokenMonitoringTask.next_check <= horizon,
            OAuthTokenMonitoringTask.next_check.is_(None)
        )
    )

    query = db.query(OAuthTokenMonitoringTask).filter(
        or_(retry_due, routine_check_due)
    )

    if user_id is not None:
        query = query.filter(OAuthTokenMonitoringTask.user_id == str(user_id))

    return query.all()