Compare commits

..

1 Commits

Author SHA1 Message Date
ي
fb75377d37 Add OAuth social proxy callback binding and reconnect handling 2026-05-18 15:57:22 +05:30
3 changed files with 183 additions and 173 deletions

View File

@@ -0,0 +1,182 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Optional
from urllib.parse import urlencode
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from loguru import logger
from sqlalchemy import text
from sqlalchemy.orm import Session
from services.database import get_db
router = APIRouter(prefix="/v1/social-proxy", tags=["social-proxy"])
def _utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _ensure_tables(db: Session) -> None:
# Keep this router backward-compatible on tenant DBs without migrations.
db.execute(text("""
CREATE TABLE IF NOT EXISTS oauth_nonce_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
state TEXT NOT NULL UNIQUE,
nonce TEXT NOT NULL,
user_id TEXT NOT NULL,
platform TEXT NOT NULL,
channel_id INTEGER,
consumed_at TEXT,
expires_at TEXT,
created_at TEXT NOT NULL
)
"""))
db.execute(text("""
CREATE TABLE IF NOT EXISTS social_channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
platform TEXT NOT NULL,
platform_account_id TEXT NOT NULL,
token_bundle TEXT NOT NULL,
token_version INTEGER NOT NULL DEFAULT 1,
publication_linkage TEXT,
is_connected INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
UNIQUE(platform, platform_account_id)
)
"""))
def _build_redirect(base_url: str, code: str, message: str, channel_id: Optional[int] = None) -> RedirectResponse:
params = {"code": code, "message": message}
if channel_id is not None:
params["channel_id"] = str(channel_id)
return RedirectResponse(url=f"{base_url}?{urlencode(params)}", status_code=303)
@router.get("/oauth/callback")
def oauth_callback(
state: str = Query(...),
platform: str = Query(...),
account_id: str = Query(...),
token_bundle: str = Query(..., description="Serialized token payload"),
ui_redirect: str = Query("/dashboard/connections"),
db: Session = Depends(get_db),
):
"""Consume OAuth callback, bind to user/platform, and upsert social channel connection."""
_ensure_tables(db)
record = db.execute(
text("""
SELECT id, nonce, user_id, platform, channel_id, consumed_at, expires_at
FROM oauth_nonce_sessions WHERE state = :state
"""),
{"state": state},
).mappings().first()
if not record:
return _build_redirect(ui_redirect, "invalid_state", "Missing OAuth session")
if record["consumed_at"] is not None:
return _build_redirect(ui_redirect, "state_reused", "OAuth state already consumed")
if record["platform"] != platform:
return _build_redirect(ui_redirect, "platform_mismatch", "Platform mismatch")
if record["expires_at"] and record["expires_at"] < _utc_now_iso():
return _build_redirect(ui_redirect, "state_expired", "OAuth session expired")
user_id = record["user_id"]
# Validate token payload is JSON.
try:
parsed_bundle = json.loads(token_bundle)
except json.JSONDecodeError as exc:
raise HTTPException(status_code=400, detail="Invalid token_bundle JSON") from exc
now = _utc_now_iso()
existing = db.execute(
text("""
SELECT id, publication_linkage, token_version
FROM social_channels
WHERE platform = :platform AND platform_account_id = :account_id
"""),
{"platform": platform, "account_id": account_id},
).mappings().first()
if existing:
# Reconnect path: preserve publication linkage and bump token version.
db.execute(
text("""
UPDATE social_channels
SET user_id = :user_id,
token_bundle = :token_bundle,
token_version = :token_version,
is_connected = 1,
updated_at = :updated_at
WHERE id = :id
"""),
{
"id": existing["id"],
"user_id": user_id,
"token_bundle": json.dumps(parsed_bundle),
"token_version": int(existing["token_version"] or 0) + 1,
"updated_at": now,
},
)
channel_id = existing["id"]
result_code = "reconnected"
result_message = "Channel reconnected"
else:
db.execute(
text("""
INSERT INTO social_channels (
user_id, platform, platform_account_id, token_bundle,
token_version, publication_linkage, is_connected, created_at, updated_at
) VALUES (
:user_id, :platform, :account_id, :token_bundle,
1, :publication_linkage, 1, :created_at, :updated_at
)
"""),
{
"user_id": user_id,
"platform": platform,
"account_id": account_id,
"token_bundle": json.dumps(parsed_bundle),
"publication_linkage": None,
"created_at": now,
"updated_at": now,
},
)
channel_id = db.execute(text("SELECT last_insert_rowid()")).scalar_one()
result_code = "connected"
result_message = "Channel connected"
# Bind callback session to concrete channel/user/platform and mark consumed.
db.execute(
text("""
UPDATE oauth_nonce_sessions
SET consumed_at = :consumed_at,
channel_id = :channel_id,
user_id = :user_id,
platform = :platform
WHERE id = :id
"""),
{
"id": record["id"],
"consumed_at": now,
"channel_id": channel_id,
"user_id": user_id,
"platform": platform,
},
)
db.commit()
logger.info(f"OAuth callback complete user={user_id} platform={platform} channel_id={channel_id}")
return _build_redirect(ui_redirect, result_code, result_message, channel_id)

View File

@@ -99,58 +99,6 @@ class OptimizationRecommendation:
expires = datetime.utcnow().timestamp() + (7 * 24 * 60 * 60)
self.expires_at = datetime.fromtimestamp(expires).isoformat()
@dataclass
class EscalationVelocitySignal:
"""Measured action velocity signal used for escalation tiering."""
window_minutes: int
action_count: int
actions_per_minute: float
triggered: bool
class EscalationTier(Enum):
"""Escalation tier derived from measurable action velocity."""
TIER_1 = "tier_1"
TIER_2 = "tier_2"
TIER_3 = "tier_3"
class EscalationVelocityPolicy:
"""Velocity-based trigger policy for escalation tiers."""
def __init__(self):
self.tier_thresholds = {
EscalationTier.TIER_1: {"window_minutes": 15, "actions_per_minute": 0.8},
EscalationTier.TIER_2: {"window_minutes": 10, "actions_per_minute": 1.5},
EscalationTier.TIER_3: {"window_minutes": 5, "actions_per_minute": 3.0},
}
def measure_velocity(self, events: List[Dict[str, Any]], now: Optional[datetime] = None) -> Dict[EscalationTier, EscalationVelocitySignal]:
now = now or datetime.utcnow()
signals: Dict[EscalationTier, EscalationVelocitySignal] = {}
for tier, cfg in self.tier_thresholds.items():
cutoff = now - timedelta(minutes=cfg["window_minutes"])
count = sum(1 for event in events if datetime.fromisoformat(event["timestamp"]) >= cutoff)
velocity = count / max(cfg["window_minutes"], 1)
signals[tier] = EscalationVelocitySignal(
window_minutes=cfg["window_minutes"],
action_count=count,
actions_per_minute=velocity,
triggered=velocity >= cfg["actions_per_minute"]
)
return signals
def determine_tier(self, events: List[Dict[str, Any]], now: Optional[datetime] = None) -> Tuple[Optional[EscalationTier], Dict[EscalationTier, EscalationVelocitySignal]]:
signals = self.measure_velocity(events, now=now)
for tier in [EscalationTier.TIER_3, EscalationTier.TIER_2, EscalationTier.TIER_1]:
if signals[tier].triggered:
return tier, signals
return None, signals
class AgentPerformanceMonitor:
"""Main performance monitoring system for agents"""

View File

@@ -13,7 +13,6 @@ from enum import Enum
from utils.logger_utils import get_service_logger
from services.database import get_session_for_user
from services.intelligence.agents.performance_monitor import EscalationVelocityPolicy, EscalationTier
logger = get_service_logger(__name__)
@@ -85,25 +84,6 @@ class SafetyValidation:
if self.validation_timestamp is None:
self.validation_timestamp = datetime.utcnow().isoformat()
@dataclass
class EscalationDecision:
"""Structured escalation payload for autonomous safety routing."""
tier: str
action: str
confidence: float
risk_class: str
rationale: str
velocity: Dict[str, Any]
lockout_auto_edits: bool
executor: Optional[str]
created_at: str = None
def __post_init__(self):
if self.created_at is None:
self.created_at = datetime.utcnow().isoformat()
class SafetyConstraintManager:
"""Manages safety constraints for agent actions"""
@@ -112,11 +92,6 @@ class SafetyConstraintManager:
self.constraints: Dict[str, SafetyConstraint] = {}
self.action_history: List[Dict[str, Any]] = []
self.violation_history: List[Dict[str, Any]] = []
self.escalation_policy = EscalationVelocityPolicy()
self.escalation_history: List[Dict[str, Any]] = []
self.auto_edit_lockout = False
self.executor_routes = {"tier_1": "autonomous_guardian_executor", "tier_2": "autonomous_recovery_executor"}
self.alert_history: List[Dict[str, Any]] = []
# Initialize default constraints
self._initialize_default_constraints()
@@ -238,7 +213,7 @@ class SafetyConstraintManager:
# Record in history
await self._record_validation_history(action_data, is_valid, violations)
validation = SafetyValidation(
return SafetyValidation(
is_valid=is_valid,
risk_level=risk_level,
violations=violations,
@@ -246,10 +221,6 @@ class SafetyConstraintManager:
requires_approval=requires_approval,
confidence_score=max(0.0, min(1.0, confidence_score))
)
escalation = await self.evaluate_escalation(action_data, validation)
if escalation:
recommendations.append(f"Escalation action: {escalation.action} ({escalation.tier})")
return validation
except Exception as e:
logger.error(f"Error validating action for user {self.user_id}: {e}")
@@ -495,97 +466,6 @@ class SafetyConstraintManager:
if len(self.violation_history) > 500:
self.violation_history = self.violation_history[-500:]
async def evaluate_escalation(self, action_data: Dict[str, Any], validation: SafetyValidation) -> Optional[EscalationDecision]:
"""Evaluate velocity-triggered escalation and produce structured decision payload."""
if self.auto_edit_lockout:
decision = EscalationDecision(
tier=EscalationTier.TIER_3.value,
action="lockout_enforced",
confidence=1.0,
risk_class=RiskLevel.CRITICAL.value,
rationale="Tier 3 lockout already active; autonomous edits blocked until manual reset",
velocity={},
lockout_auto_edits=True,
executor=None
)
await self._persist_escalation_decision(decision, action_data, outcome={"status": "blocked_by_lockout"})
return decision
tier, signals = self.escalation_policy.determine_tier(self.action_history)
if not tier:
return None
risk_class_map = {EscalationTier.TIER_1: RiskLevel.MEDIUM.value, EscalationTier.TIER_2: RiskLevel.HIGH.value, EscalationTier.TIER_3: RiskLevel.CRITICAL.value}
confidence = min(1.0, max(0.1, 0.55 + (len(validation.violations) * 0.05) + ((1 - validation.confidence_score) * 0.4)))
velocity_signal = signals[tier]
velocity_payload = {
"window_minutes": velocity_signal.window_minutes,
"action_count": velocity_signal.action_count,
"actions_per_minute": round(velocity_signal.actions_per_minute, 4),
"threshold_actions_per_minute": self.escalation_policy.tier_thresholds[tier]["actions_per_minute"],
}
executor = self.executor_routes.get(tier.value)
action = "route_to_autonomous_executor" if tier in (EscalationTier.TIER_1, EscalationTier.TIER_2) else "lockout_autonomous_edits"
rationale = f"{tier.value} triggered by velocity {velocity_payload['actions_per_minute']}/min over {velocity_signal.window_minutes}m window"
decision = EscalationDecision(
tier=tier.value,
action=action,
confidence=round(confidence, 3),
risk_class=risk_class_map[tier],
rationale=rationale,
velocity=velocity_payload,
lockout_auto_edits=(tier == EscalationTier.TIER_3),
executor=executor if tier != EscalationTier.TIER_3 else None
)
outcome = await self._apply_escalation_decision(decision, action_data, validation)
await self._persist_escalation_decision(decision, action_data, outcome=outcome)
return decision
async def _apply_escalation_decision(self, decision: EscalationDecision, action_data: Dict[str, Any], validation: SafetyValidation) -> Dict[str, Any]:
if decision.tier in (EscalationTier.TIER_1.value, EscalationTier.TIER_2.value):
return {
"status": "routed",
"executor": decision.executor,
"reason": decision.rationale
}
self.auto_edit_lockout = True
brief = {
"type": "diagnostic_brief",
"severity": "critical",
"tier": decision.tier,
"user_rationale": "Autonomous edits have been paused to protect account safety after sustained high-velocity actions.",
"validation_violations": validation.violations,
"action_type": action_data.get("action_type", "unknown"),
"timestamp": datetime.utcnow().isoformat()
}
self.alert_history.append(brief)
if len(self.alert_history) > 500:
self.alert_history = self.alert_history[-500:]
return {"status": "lockout_enabled", "diagnostic_brief": brief}
async def _persist_escalation_decision(self, decision: EscalationDecision, action_data: Dict[str, Any], outcome: Dict[str, Any]):
record = {
"timestamp": datetime.utcnow().isoformat(),
"decision": asdict(decision),
"action_data": action_data,
"outcome": outcome
}
self.escalation_history.append(record)
if len(self.escalation_history) > 2000:
self.escalation_history = self.escalation_history[-2000:]
def get_escalation_history(self, limit: int = 100) -> List[Dict[str, Any]]:
return self.escalation_history[-limit:] if self.escalation_history else []
def reset_auto_edit_lockout(self):
self.auto_edit_lockout = False
def add_custom_constraint(self, constraint: SafetyConstraint):
"""Add a custom safety constraint"""
self.constraints[constraint.constraint_id] = constraint