Compare commits
1 Commits
codex/impl
...
codex/add-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6fdf318d79 |
@@ -40,6 +40,10 @@ class OAuthTokenMonitoringTask(Base):
|
||||
|
||||
# Scheduling
|
||||
next_check = Column(DateTime, nullable=True, index=True) # Next scheduled check time
|
||||
next_retry_at = Column(DateTime, nullable=True, index=True) # Backoff retry schedule for refresh failures
|
||||
refresh_attempts = Column(Integer, default=0) # Current retry attempt count for refresh workflow
|
||||
terminal_failure_reason = Column(Text, nullable=True) # Permanent failure reason requiring user action
|
||||
channel_status = Column(String(32), default='connected') # connected, degraded, disconnected
|
||||
|
||||
# Metadata
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
@@ -97,4 +101,3 @@ class OAuthTokenExecutionLog(Base):
|
||||
|
||||
def __repr__(self):
|
||||
return f"<OAuthTokenExecutionLog(id={self.id}, task_id={self.task_id}, status={self.status}, execution_date={self.execution_date})>"
|
||||
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/v1/social-proxy", tags=["social-proxy"])
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _ensure_tables(db: Session) -> None:
|
||||
# Keep this router backward-compatible on tenant DBs without migrations.
|
||||
db.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS oauth_nonce_sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
state TEXT NOT NULL UNIQUE,
|
||||
nonce TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
channel_id INTEGER,
|
||||
consumed_at TEXT,
|
||||
expires_at TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
"""))
|
||||
db.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS social_channels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
platform_account_id TEXT NOT NULL,
|
||||
token_bundle TEXT NOT NULL,
|
||||
token_version INTEGER NOT NULL DEFAULT 1,
|
||||
publication_linkage TEXT,
|
||||
is_connected INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
UNIQUE(platform, platform_account_id)
|
||||
)
|
||||
"""))
|
||||
|
||||
|
||||
def _build_redirect(base_url: str, code: str, message: str, channel_id: Optional[int] = None) -> RedirectResponse:
|
||||
params = {"code": code, "message": message}
|
||||
if channel_id is not None:
|
||||
params["channel_id"] = str(channel_id)
|
||||
return RedirectResponse(url=f"{base_url}?{urlencode(params)}", status_code=303)
|
||||
|
||||
|
||||
@router.get("/oauth/callback")
|
||||
def oauth_callback(
|
||||
state: str = Query(...),
|
||||
platform: str = Query(...),
|
||||
account_id: str = Query(...),
|
||||
token_bundle: str = Query(..., description="Serialized token payload"),
|
||||
ui_redirect: str = Query("/dashboard/connections"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Consume OAuth callback, bind to user/platform, and upsert social channel connection."""
|
||||
_ensure_tables(db)
|
||||
|
||||
record = db.execute(
|
||||
text("""
|
||||
SELECT id, nonce, user_id, platform, channel_id, consumed_at, expires_at
|
||||
FROM oauth_nonce_sessions WHERE state = :state
|
||||
"""),
|
||||
{"state": state},
|
||||
).mappings().first()
|
||||
|
||||
if not record:
|
||||
return _build_redirect(ui_redirect, "invalid_state", "Missing OAuth session")
|
||||
|
||||
if record["consumed_at"] is not None:
|
||||
return _build_redirect(ui_redirect, "state_reused", "OAuth state already consumed")
|
||||
|
||||
if record["platform"] != platform:
|
||||
return _build_redirect(ui_redirect, "platform_mismatch", "Platform mismatch")
|
||||
|
||||
if record["expires_at"] and record["expires_at"] < _utc_now_iso():
|
||||
return _build_redirect(ui_redirect, "state_expired", "OAuth session expired")
|
||||
|
||||
user_id = record["user_id"]
|
||||
|
||||
# Validate token payload is JSON.
|
||||
try:
|
||||
parsed_bundle = json.loads(token_bundle)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid token_bundle JSON") from exc
|
||||
|
||||
now = _utc_now_iso()
|
||||
|
||||
existing = db.execute(
|
||||
text("""
|
||||
SELECT id, publication_linkage, token_version
|
||||
FROM social_channels
|
||||
WHERE platform = :platform AND platform_account_id = :account_id
|
||||
"""),
|
||||
{"platform": platform, "account_id": account_id},
|
||||
).mappings().first()
|
||||
|
||||
if existing:
|
||||
# Reconnect path: preserve publication linkage and bump token version.
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE social_channels
|
||||
SET user_id = :user_id,
|
||||
token_bundle = :token_bundle,
|
||||
token_version = :token_version,
|
||||
is_connected = 1,
|
||||
updated_at = :updated_at
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{
|
||||
"id": existing["id"],
|
||||
"user_id": user_id,
|
||||
"token_bundle": json.dumps(parsed_bundle),
|
||||
"token_version": int(existing["token_version"] or 0) + 1,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
channel_id = existing["id"]
|
||||
result_code = "reconnected"
|
||||
result_message = "Channel reconnected"
|
||||
else:
|
||||
db.execute(
|
||||
text("""
|
||||
INSERT INTO social_channels (
|
||||
user_id, platform, platform_account_id, token_bundle,
|
||||
token_version, publication_linkage, is_connected, created_at, updated_at
|
||||
) VALUES (
|
||||
:user_id, :platform, :account_id, :token_bundle,
|
||||
1, :publication_linkage, 1, :created_at, :updated_at
|
||||
)
|
||||
"""),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"platform": platform,
|
||||
"account_id": account_id,
|
||||
"token_bundle": json.dumps(parsed_bundle),
|
||||
"publication_linkage": None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
channel_id = db.execute(text("SELECT last_insert_rowid()")).scalar_one()
|
||||
result_code = "connected"
|
||||
result_message = "Channel connected"
|
||||
|
||||
# Bind callback session to concrete channel/user/platform and mark consumed.
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE oauth_nonce_sessions
|
||||
SET consumed_at = :consumed_at,
|
||||
channel_id = :channel_id,
|
||||
user_id = :user_id,
|
||||
platform = :platform
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{
|
||||
"id": record["id"],
|
||||
"consumed_at": now,
|
||||
"channel_id": channel_id,
|
||||
"user_id": user_id,
|
||||
"platform": platform,
|
||||
},
|
||||
)
|
||||
|
||||
db.commit()
|
||||
logger.info(f"OAuth callback complete user={user_id} platform={platform} channel_id={channel_id}")
|
||||
return _build_redirect(ui_redirect, result_code, result_message, channel_id)
|
||||
@@ -26,7 +26,10 @@ from .executors.advertools_executor import AdvertoolsExecutor
|
||||
from .executors.sif_indexing_executor import SIFIndexingExecutor
|
||||
from .executors.market_trends_executor import MarketTrendsExecutor
|
||||
from .utils.task_loader import load_due_monitoring_tasks
|
||||
from .utils.oauth_token_task_loader import load_due_oauth_token_monitoring_tasks
|
||||
from .utils.oauth_token_task_loader import (
|
||||
load_due_oauth_token_monitoring_tasks,
|
||||
load_near_expiry_oauth_token_tasks
|
||||
)
|
||||
from .utils.website_analysis_task_loader import load_due_website_analysis_tasks
|
||||
from .utils.onboarding_full_website_analysis_task_loader import load_due_onboarding_full_website_analysis_tasks
|
||||
from .utils.deep_competitor_analysis_task_loader import load_due_deep_competitor_analysis_tasks
|
||||
@@ -70,6 +73,11 @@ def get_scheduler() -> TaskScheduler:
|
||||
oauth_token_executor,
|
||||
load_due_oauth_token_monitoring_tasks
|
||||
)
|
||||
_scheduler_instance.register_executor(
|
||||
'oauth_token_refresh',
|
||||
oauth_token_executor,
|
||||
load_near_expiry_oauth_token_tasks
|
||||
)
|
||||
|
||||
# Register website analysis executor
|
||||
website_analysis_executor = WebsiteAnalysisExecutor()
|
||||
|
||||
@@ -42,6 +42,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
self.exception_handler = SchedulerExceptionHandler()
|
||||
# Expiration warning window (7 days before expiration)
|
||||
self.expiration_warning_days = 7
|
||||
self.max_refresh_retries = 3
|
||||
self.base_retry_backoff_minutes = 15
|
||||
|
||||
async def execute_task(self, task: OAuthTokenMonitoringTask, db: Session) -> TaskExecutionResult:
|
||||
"""
|
||||
@@ -93,6 +95,10 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
task.last_success = datetime.utcnow()
|
||||
task.status = 'active'
|
||||
task.failure_reason = None
|
||||
task.terminal_failure_reason = None
|
||||
task.channel_status = 'connected'
|
||||
task.refresh_attempts = 0
|
||||
task.next_retry_at = None
|
||||
# Reset failure tracking on success
|
||||
task.consecutive_failures = 0
|
||||
task.failure_pattern = None
|
||||
@@ -112,6 +118,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
|
||||
task.last_failure = datetime.utcnow()
|
||||
task.failure_reason = result.error_message
|
||||
task.refresh_attempts = (task.refresh_attempts or 0) + 1
|
||||
|
||||
if pattern and pattern.should_cool_off:
|
||||
# Mark task for human intervention
|
||||
@@ -126,6 +133,9 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
}
|
||||
# Clear next_check - task won't run automatically
|
||||
task.next_check = None
|
||||
task.next_retry_at = None
|
||||
task.channel_status = "disconnected"
|
||||
task.terminal_failure_reason = result.error_message
|
||||
|
||||
self.logger.warning(
|
||||
f"Task {task.id} marked for human intervention: "
|
||||
@@ -133,10 +143,17 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
f"reason: {pattern.failure_reason.value}"
|
||||
)
|
||||
else:
|
||||
# Normal failure handling
|
||||
task.status = 'failed'
|
||||
task.consecutive_failures = (task.consecutive_failures or 0) + 1
|
||||
# Do NOT update next_check - wait for manual trigger
|
||||
if task.refresh_attempts >= self.max_refresh_retries:
|
||||
task.status = 'failed'
|
||||
task.channel_status = 'disconnected'
|
||||
task.terminal_failure_reason = result.error_message
|
||||
task.next_retry_at = None
|
||||
else:
|
||||
task.status = 'degraded'
|
||||
task.channel_status = 'degraded'
|
||||
delay_minutes = self.base_retry_backoff_minutes * (2 ** (task.refresh_attempts - 1))
|
||||
task.next_retry_at = datetime.utcnow() + timedelta(minutes=delay_minutes)
|
||||
|
||||
self.logger.warning(
|
||||
f"OAuth token refresh failed for user {user_id}, platform {platform}. "
|
||||
@@ -144,7 +161,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
)
|
||||
|
||||
# Create UsageAlert notification for the user
|
||||
self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db)
|
||||
self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db, task)
|
||||
|
||||
task.updated_at = datetime.utcnow()
|
||||
db.commit()
|
||||
@@ -193,12 +210,14 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
task.last_failure = datetime.utcnow()
|
||||
task.failure_reason = str(e)
|
||||
task.status = 'failed'
|
||||
task.channel_status = 'disconnected'
|
||||
task.terminal_failure_reason = str(e)
|
||||
task.last_check = datetime.utcnow()
|
||||
task.updated_at = datetime.utcnow()
|
||||
# Do NOT update next_check - wait for manual trigger
|
||||
task.next_retry_at = None
|
||||
|
||||
# Create UsageAlert notification for the user
|
||||
self._create_failure_alert(user_id, task.platform, str(e), None, db)
|
||||
self._create_failure_alert(user_id, task.platform, str(e), None, db, task)
|
||||
|
||||
db.commit()
|
||||
except Exception as commit_error:
|
||||
@@ -651,7 +670,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
platform: str,
|
||||
error_message: str,
|
||||
result_data: Optional[Dict[str, Any]],
|
||||
db: Session
|
||||
db: Session,
|
||||
task: Optional[OAuthTokenMonitoringTask] = None
|
||||
):
|
||||
"""
|
||||
Create a UsageAlert notification when OAuth token refresh fails.
|
||||
@@ -723,6 +743,20 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
# Get current billing period (YYYY-MM format)
|
||||
from datetime import datetime
|
||||
billing_period = datetime.utcnow().strftime("%Y-%m")
|
||||
|
||||
alert_payload = {
|
||||
"requires_user_action": True,
|
||||
"platform": platform,
|
||||
"channel_status": getattr(task, "channel_status", "disconnected"),
|
||||
"terminal_failure_reason": getattr(task, "terminal_failure_reason", error_message),
|
||||
"next_retry_at": (
|
||||
task.next_retry_at.isoformat() if task and task.next_retry_at else None
|
||||
),
|
||||
"refresh_attempts": getattr(task, "refresh_attempts", 0),
|
||||
"max_refresh_retries": self.max_refresh_retries,
|
||||
}
|
||||
|
||||
message = f"{message} [ALERT_PAYLOAD] {alert_payload}"
|
||||
|
||||
# Create UsageAlert
|
||||
alert = UsageAlert(
|
||||
@@ -786,4 +820,3 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
|
||||
f"Defaulting to Weekly (7 days)."
|
||||
)
|
||||
return last_execution + timedelta(days=7)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ OAuth Token Monitoring Task Loader
|
||||
Functions to load due OAuth token monitoring tasks from database.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Union
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
@@ -52,3 +52,34 @@ def load_due_oauth_token_monitoring_tasks(
|
||||
|
||||
return query.all()
|
||||
|
||||
|
||||
def load_near_expiry_oauth_token_tasks(
|
||||
db: Session,
|
||||
refresh_horizon_hours: int = 24,
|
||||
user_id: Optional[Union[str, int]] = None
|
||||
) -> List[OAuthTokenMonitoringTask]:
|
||||
"""
|
||||
Load OAuth tasks that should run token refresh logic soon.
|
||||
|
||||
Includes:
|
||||
- tasks with a scheduled retry now due (next_retry_at <= now)
|
||||
- tasks whose routine check is inside the near-expiry horizon window
|
||||
"""
|
||||
now = datetime.utcnow()
|
||||
horizon = now + timedelta(hours=max(refresh_horizon_hours, 1))
|
||||
|
||||
query = db.query(OAuthTokenMonitoringTask).filter(
|
||||
and_(
|
||||
OAuthTokenMonitoringTask.status.in_(['active', 'failed', 'degraded']),
|
||||
or_(
|
||||
OAuthTokenMonitoringTask.next_retry_at <= now,
|
||||
OAuthTokenMonitoringTask.next_check <= horizon,
|
||||
OAuthTokenMonitoringTask.next_check.is_(None)
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
if user_id is not None:
|
||||
query = query.filter(OAuthTokenMonitoringTask.user_id == str(user_id))
|
||||
|
||||
return query.all()
|
||||
|
||||
Reference in New Issue
Block a user