Compare commits

..

1 Commits

Author SHA1 Message Date
ي
6fdf318d79 Add OAuth token refresh retries, status persistence, and alert payloads 2026-05-18 15:56:57 +05:30
10 changed files with 100 additions and 169 deletions

View File

@@ -5,7 +5,7 @@ API endpoints for managing unified content assets across all modules.
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any, Set
from typing import List, Optional, Dict, Any
from pydantic import BaseModel, Field
from datetime import datetime
@@ -47,33 +47,6 @@ class AssetResponse(BaseModel):
from_attributes = True
def _parse_source_modules(source_module: Optional[List[str]]) -> Optional[List[AssetSource]]:
"""Parse source_module query values from repeated params and/or comma-separated values."""
if not source_module:
return None
parsed_values: List[AssetSource] = []
seen: Set[AssetSource] = set()
for raw_value in source_module:
for value in raw_value.split(","):
normalized = value.strip().lower()
if not normalized:
continue
try:
module = AssetSource(normalized)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid source module: {value.strip()}")
if module not in seen:
seen.add(module)
parsed_values.append(module)
return parsed_values or None
class AssetListResponse(BaseModel):
"""Response model for asset list."""
assets: List[AssetResponse]
@@ -85,7 +58,7 @@ class AssetListResponse(BaseModel):
@router.get("/", response_model=AssetListResponse)
async def get_assets(
asset_type: Optional[str] = Query(None, description="Filter by asset type"),
source_module: Optional[List[str]] = Query(None, description="Filter by source module(s); supports repeated params and comma-separated values"),
source_module: Optional[str] = Query(None, description="Filter by source module"),
search: Optional[str] = Query(None, description="Search query"),
tags: Optional[str] = Query(None, description="Comma-separated tags"),
favorites_only: bool = Query(False, description="Only favorites"),
@@ -116,7 +89,12 @@ async def get_assets(
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid asset type: {asset_type}")
source_modules_enum = _parse_source_modules(source_module)
source_module_enum = None
if source_module:
try:
source_module_enum = AssetSource(source_module.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid source module: {source_module}")
tags_list = None
if tags:
@@ -148,7 +126,7 @@ async def get_assets(
assets, total = service.get_user_assets(
user_id=user_id,
asset_type=asset_type_enum,
source_modules=source_modules_enum,
source_module=source_module_enum,
search_query=search,
tags=tags_list,
favorites_only=favorites_only,
@@ -222,7 +200,7 @@ async def create_asset(
asset = service.create_asset(
user_id=user_id,
asset_type=asset_type_enum,
source_modules=source_modules_enum,
source_module=source_module_enum,
filename=asset_data.filename,
file_url=asset_data.file_url,
file_path=asset_data.file_path,

View File

@@ -40,6 +40,10 @@ class OAuthTokenMonitoringTask(Base):
# Scheduling
next_check = Column(DateTime, nullable=True, index=True) # Next scheduled check time
next_retry_at = Column(DateTime, nullable=True, index=True) # Backoff retry schedule for refresh failures
refresh_attempts = Column(Integer, default=0) # Current retry attempt count for refresh workflow
terminal_failure_reason = Column(Text, nullable=True) # Permanent failure reason requiring user action
channel_status = Column(String(32), default='connected') # connected, degraded, disconnected
# Metadata
created_at = Column(DateTime, default=datetime.utcnow)
@@ -97,4 +101,3 @@ class OAuthTokenExecutionLog(Base):
def __repr__(self):
return f"<OAuthTokenExecutionLog(id={self.id}, task_id={self.task_id}, status={self.status}, execution_date={self.execution_date})>"

View File

@@ -107,7 +107,6 @@ class ContentAssetService:
user_id: str,
asset_type: Optional[AssetType] = None,
source_module: Optional[AssetSource] = None,
source_modules: Optional[List[AssetSource]] = None,
search_query: Optional[str] = None,
tags: Optional[List[str]] = None,
favorites_only: bool = False,
@@ -126,7 +125,6 @@ class ContentAssetService:
user_id: Clerk user ID
asset_type: Filter by asset type (optional)
source_module: Filter by source module (optional)
source_modules: Filter by multiple source modules (optional)
search_query: Search in title, description, prompt (optional)
tags: Filter by tags (optional)
favorites_only: Only return favorites (optional)
@@ -144,9 +142,7 @@ class ContentAssetService:
if asset_type:
query = query.filter(ContentAsset.asset_type == asset_type)
if source_modules:
query = query.filter(ContentAsset.source_module.in_(source_modules))
elif source_module:
if source_module:
query = query.filter(ContentAsset.source_module == source_module)
if favorites_only:

View File

@@ -26,7 +26,10 @@ from .executors.advertools_executor import AdvertoolsExecutor
from .executors.sif_indexing_executor import SIFIndexingExecutor
from .executors.market_trends_executor import MarketTrendsExecutor
from .utils.task_loader import load_due_monitoring_tasks
from .utils.oauth_token_task_loader import load_due_oauth_token_monitoring_tasks
from .utils.oauth_token_task_loader import (
load_due_oauth_token_monitoring_tasks,
load_near_expiry_oauth_token_tasks
)
from .utils.website_analysis_task_loader import load_due_website_analysis_tasks
from .utils.onboarding_full_website_analysis_task_loader import load_due_onboarding_full_website_analysis_tasks
from .utils.deep_competitor_analysis_task_loader import load_due_deep_competitor_analysis_tasks
@@ -70,6 +73,11 @@ def get_scheduler() -> TaskScheduler:
oauth_token_executor,
load_due_oauth_token_monitoring_tasks
)
_scheduler_instance.register_executor(
'oauth_token_refresh',
oauth_token_executor,
load_near_expiry_oauth_token_tasks
)
# Register website analysis executor
website_analysis_executor = WebsiteAnalysisExecutor()

View File

@@ -42,6 +42,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
self.exception_handler = SchedulerExceptionHandler()
# Expiration warning window (7 days before expiration)
self.expiration_warning_days = 7
self.max_refresh_retries = 3
self.base_retry_backoff_minutes = 15
async def execute_task(self, task: OAuthTokenMonitoringTask, db: Session) -> TaskExecutionResult:
"""
@@ -93,6 +95,10 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
task.last_success = datetime.utcnow()
task.status = 'active'
task.failure_reason = None
task.terminal_failure_reason = None
task.channel_status = 'connected'
task.refresh_attempts = 0
task.next_retry_at = None
# Reset failure tracking on success
task.consecutive_failures = 0
task.failure_pattern = None
@@ -112,6 +118,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
task.last_failure = datetime.utcnow()
task.failure_reason = result.error_message
task.refresh_attempts = (task.refresh_attempts or 0) + 1
if pattern and pattern.should_cool_off:
# Mark task for human intervention
@@ -126,6 +133,9 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
}
# Clear next_check - task won't run automatically
task.next_check = None
task.next_retry_at = None
task.channel_status = "disconnected"
task.terminal_failure_reason = result.error_message
self.logger.warning(
f"Task {task.id} marked for human intervention: "
@@ -133,10 +143,17 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
f"reason: {pattern.failure_reason.value}"
)
else:
# Normal failure handling
task.status = 'failed'
task.consecutive_failures = (task.consecutive_failures or 0) + 1
# Do NOT update next_check - wait for manual trigger
if task.refresh_attempts >= self.max_refresh_retries:
task.status = 'failed'
task.channel_status = 'disconnected'
task.terminal_failure_reason = result.error_message
task.next_retry_at = None
else:
task.status = 'degraded'
task.channel_status = 'degraded'
delay_minutes = self.base_retry_backoff_minutes * (2 ** (task.refresh_attempts - 1))
task.next_retry_at = datetime.utcnow() + timedelta(minutes=delay_minutes)
self.logger.warning(
f"OAuth token refresh failed for user {user_id}, platform {platform}. "
@@ -144,7 +161,7 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
)
# Create UsageAlert notification for the user
self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db)
self._create_failure_alert(user_id, platform, result.error_message, result.result_data, db, task)
task.updated_at = datetime.utcnow()
db.commit()
@@ -193,12 +210,14 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
task.last_failure = datetime.utcnow()
task.failure_reason = str(e)
task.status = 'failed'
task.channel_status = 'disconnected'
task.terminal_failure_reason = str(e)
task.last_check = datetime.utcnow()
task.updated_at = datetime.utcnow()
# Do NOT update next_check - wait for manual trigger
task.next_retry_at = None
# Create UsageAlert notification for the user
self._create_failure_alert(user_id, task.platform, str(e), None, db)
self._create_failure_alert(user_id, task.platform, str(e), None, db, task)
db.commit()
except Exception as commit_error:
@@ -651,7 +670,8 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
platform: str,
error_message: str,
result_data: Optional[Dict[str, Any]],
db: Session
db: Session,
task: Optional[OAuthTokenMonitoringTask] = None
):
"""
Create a UsageAlert notification when OAuth token refresh fails.
@@ -723,6 +743,20 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
# Get current billing period (YYYY-MM format)
from datetime import datetime
billing_period = datetime.utcnow().strftime("%Y-%m")
alert_payload = {
"requires_user_action": True,
"platform": platform,
"channel_status": getattr(task, "channel_status", "disconnected"),
"terminal_failure_reason": getattr(task, "terminal_failure_reason", error_message),
"next_retry_at": (
task.next_retry_at.isoformat() if task and task.next_retry_at else None
),
"refresh_attempts": getattr(task, "refresh_attempts", 0),
"max_refresh_retries": self.max_refresh_retries,
}
message = f"{message} [ALERT_PAYLOAD] {alert_payload}"
# Create UsageAlert
alert = UsageAlert(
@@ -786,4 +820,3 @@ class OAuthTokenMonitoringExecutor(TaskExecutor):
f"Defaulting to Weekly (7 days)."
)
return last_execution + timedelta(days=7)

View File

@@ -3,7 +3,7 @@ OAuth Token Monitoring Task Loader
Functions to load due OAuth token monitoring tasks from database.
"""
from datetime import datetime
from datetime import datetime, timedelta
from typing import List, Optional, Union
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
@@ -52,3 +52,34 @@ def load_due_oauth_token_monitoring_tasks(
return query.all()
def load_near_expiry_oauth_token_tasks(
db: Session,
refresh_horizon_hours: int = 24,
user_id: Optional[Union[str, int]] = None
) -> List[OAuthTokenMonitoringTask]:
"""
Load OAuth tasks that should run token refresh logic soon.
Includes:
- tasks with a scheduled retry now due (next_retry_at <= now)
- tasks whose routine check is inside the near-expiry horizon window
"""
now = datetime.utcnow()
horizon = now + timedelta(hours=max(refresh_horizon_hours, 1))
query = db.query(OAuthTokenMonitoringTask).filter(
and_(
OAuthTokenMonitoringTask.status.in_(['active', 'failed', 'degraded']),
or_(
OAuthTokenMonitoringTask.next_retry_at <= now,
OAuthTokenMonitoringTask.next_check <= horizon,
OAuthTokenMonitoringTask.next_check.is_(None)
)
)
)
if user_id is not None:
query = query.filter(OAuthTokenMonitoringTask.user_id == str(user_id))
return query.all()

View File

@@ -1,31 +0,0 @@
import importlib.util
from pathlib import Path
from fastapi import HTTPException
ROOT = Path(__file__).resolve().parents[3]
ROUTER_PATH = ROOT / 'backend' / 'api' / 'content_assets' / 'router.py'
MODELS_PATH = ROOT / 'backend' / 'models' / 'content_asset_models.py'
models_spec = importlib.util.spec_from_file_location('content_asset_models', MODELS_PATH)
models = importlib.util.module_from_spec(models_spec)
models_spec.loader.exec_module(models)
AssetSource = models.AssetSource
router_spec = importlib.util.spec_from_file_location('content_assets_router', ROUTER_PATH)
router = importlib.util.module_from_spec(router_spec)
router_spec.loader.exec_module(router)
def test_parse_source_modules_supports_repeated_and_csv_values():
parsed = router._parse_source_modules(["blog_writer", "youtube,podcast"])
assert parsed == [AssetSource.BLOG_WRITER, AssetSource.YOUTUBE, AssetSource.PODCAST]
def test_parse_source_modules_raises_for_invalid_values():
try:
router._parse_source_modules(["blog_writer,unknown"])
except HTTPException as exc:
assert exc.status_code == 400
assert "Invalid source module" in exc.detail
else:
raise AssertionError("Expected HTTPException for invalid source module")

View File

@@ -1,50 +0,0 @@
import importlib.util
from pathlib import Path
ROOT = Path(__file__).resolve().parents[3]
SERVICE_PATH = ROOT / 'backend' / 'services' / 'content_asset_service.py'
MODELS_PATH = ROOT / 'backend' / 'models' / 'content_asset_models.py'
models_spec = importlib.util.spec_from_file_location('content_asset_models', MODELS_PATH)
models = importlib.util.module_from_spec(models_spec)
models_spec.loader.exec_module(models)
AssetSource = models.AssetSource
service_spec = importlib.util.spec_from_file_location('content_asset_service', SERVICE_PATH)
service_module = importlib.util.module_from_spec(service_spec)
service_spec.loader.exec_module(service_module)
ContentAssetService = service_module.ContentAssetService
class DummyQuery:
def __init__(self):
self.filters = []
def filter(self, expr):
self.filters.append(expr)
return self
def count(self): return 0
def order_by(self, *_args, **_kwargs): return self
def limit(self, *_args, **_kwargs): return self
def offset(self, *_args, **_kwargs): return self
def all(self): return []
class DummyDB:
def __init__(self): self.query_obj = DummyQuery()
def query(self, *_args, **_kwargs): return self.query_obj
def test_get_user_assets_accepts_multiple_source_modules_filter():
db = DummyDB()
service = ContentAssetService(db)
assets, total = service.get_user_assets(
user_id="user-1",
source_modules=[AssetSource.BLOG_WRITER, AssetSource.YOUTUBE],
)
assert assets == []
assert total == 0
assert len(db.query_obj.filters) >= 2

View File

@@ -1,35 +0,0 @@
import { renderHook, waitFor } from '@testing-library/react';
import { useContentAssets } from '../useContentAssets';
const getTokenMock = jest.fn();
jest.mock('@clerk/clerk-react', () => ({
useAuth: () => ({ getToken: getTokenMock }),
}));
describe('useContentAssets', () => {
beforeEach(() => {
getTokenMock.mockResolvedValue('test-token');
global.fetch = jest.fn().mockResolvedValue({
ok: true,
json: async () => ({ assets: [], total: 0, limit: 100, offset: 0 }),
} as Response);
});
afterEach(() => {
jest.clearAllMocks();
});
it('sends all source_module values as repeated query params', async () => {
renderHook(() =>
useContentAssets({ source_module: ['blog_writer', 'youtube'], limit: 50, offset: 0 })
);
await waitFor(() => expect(global.fetch).toHaveBeenCalled());
const calledUrl = (global.fetch as jest.Mock).mock.calls[0][0] as string;
const params = new URL(calledUrl).searchParams;
expect(params.getAll('source_module')).toEqual(['blog_writer', 'youtube']);
});
});

View File

@@ -29,7 +29,7 @@ export interface ContentAsset {
export interface AssetFilters {
asset_type?: 'text' | 'image' | 'video' | 'audio';
source_module?: string | string[]; // Supports single or multiple source modules
source_module?: string | string[]; // Support single or multiple source modules
search?: string;
tags?: string[];
favorites_only?: boolean;
@@ -146,10 +146,8 @@ export const useContentAssets = (filters: AssetFilters = {}) => {
if (currentFilters.source_module) {
// Handle both string and array cases
if (Array.isArray(currentFilters.source_module)) {
// Send every selected source module as repeated query params
currentFilters.source_module.forEach((module) => {
params.append('source_module', module);
});
// For arrays, use the first value (backend doesn't support multiple yet)
params.append('source_module', currentFilters.source_module[0]);
} else {
params.append('source_module', currentFilters.source_module);
}