1019 lines
43 KiB
Python
1019 lines
43 KiB
Python
"""
Agent Safety Framework for ALwrity Autonomous Marketing Agents.

Implements safety constraints, action validation, velocity-based escalation,
user approval, and checkpoint/rollback mechanisms for autonomous agent actions.
"""
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Any, Optional, Set
|
|
from dataclasses import dataclass, asdict
|
|
from enum import Enum
|
|
|
|
from utils.logger_utils import get_service_logger
|
|
from services.database import get_session_for_user
|
|
from services.intelligence.agents.performance_monitor import EscalationVelocityPolicy, EscalationTier
|
|
|
|
# Module-level logger obtained from the project's service logging utilities.
logger = get_service_logger(__name__)
|
|
|
|
class RiskLevel(Enum):
    """Severity buckets assigned to an agent action.

    Members are declared from least to most severe. The lowercase string value
    of each member is what gets stored when a plain string is needed.
    """

    LOW = "low"  # least severe bucket
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"  # most severe bucket
|
|
|
|
class ActionCategory(Enum):
    """Kinds of actions an autonomous agent may attempt.

    Safety constraints are scoped to one or more of these categories, and each
    member's value is its lowercase snake_case name.
    """

    CONTENT_MODIFICATION = "content_modification"
    SEO_OPTIMIZATION = "seo_optimization"
    COMPETITOR_RESPONSE = "competitor_response"
    SOCIAL_AMPLIFICATION = "social_amplification"
    STRATEGY_CHANGE = "strategy_change"
    SYSTEM_CONFIGURATION = "system_configuration"
|
|
|
|
@dataclass
class SafetyConstraint:
    """A safety rule applied to agent actions in one or more categories.

    ``conditions`` and ``created_at`` may be omitted (or passed as ``None``).
    ``__post_init__`` then replaces them with an empty dict and the current
    UTC time in ISO-8601 format, respectively.
    """

    constraint_id: str  # unique key for this constraint
    name: str  # human-readable title
    description: str
    action_categories: List[ActionCategory]  # categories this constraint applies to
    risk_threshold: float  # Maximum allowed risk level (0.0 to 1.0)
    approval_required: bool  # when True, matching actions always need manual approval
    auto_approval_threshold: float  # Risk level below which auto-approval is allowed
    daily_limit: Optional[int] = None  # Maximum actions per day
    hourly_limit: Optional[int] = None  # Maximum actions per hour
    conditions: Optional[Dict[str, Any]] = None  # Additional conditions for validation
    created_at: Optional[str] = None  # ISO-8601 creation time (naive UTC)

    def __post_init__(self):
        """Fill in defaults for fields that were left as ``None``."""
        if self.created_at is None:
            self.created_at = datetime.utcnow().isoformat()
        # A fresh dict per instance; a shared mutable default would leak
        # conditions between constraints.
        if self.conditions is None:
            self.conditions = {}
|
|
|
|
@dataclass
class ActionCheckpoint:
    """Snapshot taken before an agent action so it can later be rolled back.

    ``created_at`` defaults to the current UTC time (ISO-8601) when omitted or
    passed as ``None``.
    """

    checkpoint_id: str
    action_id: str
    agent_id: str
    user_id: str
    action_type: str  # used to pick the rollback strategy
    action_data: Dict[str, Any]  # the action payload as supplied by the caller
    system_state: Dict[str, Any]  # caller-supplied state captured before the action
    created_at: Optional[str] = None  # ISO-8601 creation time (naive UTC)

    def __post_init__(self):
        """Stamp the checkpoint with the current time if none was given."""
        if self.created_at is None:
            self.created_at = datetime.utcnow().isoformat()
|
|
|
|
@dataclass
class SafetyValidation:
    """Outcome of validating one agent action against the safety rules.

    ``validation_timestamp`` defaults to the current UTC time (ISO-8601) when
    omitted or passed as ``None``.
    """

    is_valid: bool  # True only when the action may proceed without intervention
    risk_level: RiskLevel
    violations: List[str]  # human-readable descriptions of failed checks
    recommendations: List[str]
    requires_approval: bool  # True when a human must approve before execution
    confidence_score: float  # 0.0 to 1.0
    validation_timestamp: Optional[str] = None  # ISO-8601 time (naive UTC)

    def __post_init__(self):
        """Stamp the result with the current time if none was given."""
        if self.validation_timestamp is None:
            self.validation_timestamp = datetime.utcnow().isoformat()
|
|
|
|
|
|
|
|
@dataclass
class EscalationDecision:
    """Structured escalation payload for autonomous safety routing.

    ``created_at`` defaults to the current UTC time (ISO-8601) when omitted or
    passed as ``None``.
    """

    tier: str  # escalation tier identifier (string value of the tier)
    action: str  # what the safety layer decided to do
    confidence: float
    risk_class: str  # RiskLevel value string
    rationale: str  # human-readable explanation
    velocity: Dict[str, Any]  # velocity metrics behind the decision (may be empty)
    lockout_auto_edits: bool  # True when autonomous edits are blocked
    executor: Optional[str]  # executor the action is routed to, if any
    created_at: Optional[str] = None  # ISO-8601 creation time (naive UTC)

    def __post_init__(self):
        """Stamp the decision with the current time if none was given."""
        if self.created_at is None:
            self.created_at = datetime.utcnow().isoformat()
|
|
|
|
class SafetyConstraintManager:
    """Manages safety constraints for one user's autonomous agent actions.

    The manager does four things:

    * validates proposed actions against the registered constraints,
      per-category rate limits and suspicious-pattern heuristics;
    * evaluates velocity-based escalation through ``EscalationVelocityPolicy``;
    * keeps size-capped, in-memory histories of validations, violations,
      escalation decisions and alerts;
    * maintains an auto-edit lockout flag, which a tier-3 escalation sets.
    """

    def __init__(self, user_id: str):
        self.user_id = user_id
        self.constraints: Dict[str, SafetyConstraint] = {}
        self.action_history: List[Dict[str, Any]] = []
        self.violation_history: List[Dict[str, Any]] = []
        self.escalation_policy = EscalationVelocityPolicy()
        self.escalation_history: List[Dict[str, Any]] = []
        self.auto_edit_lockout = False
        # Keyed by EscalationTier value. Tier 3 has no executor because it
        # triggers a lockout instead of routing.
        self.executor_routes = {"tier_1": "autonomous_guardian_executor", "tier_2": "autonomous_recovery_executor"}
        self.alert_history: List[Dict[str, Any]] = []

        # Initialize default constraints
        self._initialize_default_constraints()

        logger.info(f"Initialized SafetyConstraintManager for user: {user_id}")

    def _initialize_default_constraints(self):
        """Register the built-in constraint set, keyed by constraint_id."""
        default_constraints = [
            SafetyConstraint(
                constraint_id="content_modification_limit",
                name="Content Modification Daily Limit",
                description="Limit the number of content modifications per day",
                action_categories=[ActionCategory.CONTENT_MODIFICATION],
                risk_threshold=0.7,
                approval_required=False,
                auto_approval_threshold=0.3,
                daily_limit=50,
                hourly_limit=10
            ),
            SafetyConstraint(
                constraint_id="high_risk_approval_required",
                name="High Risk Action Approval",
                description="Require approval for high-risk actions",
                action_categories=[ActionCategory.STRATEGY_CHANGE, ActionCategory.SYSTEM_CONFIGURATION],
                risk_threshold=0.8,
                approval_required=True,
                auto_approval_threshold=0.2
            ),
            SafetyConstraint(
                constraint_id="competitor_response_cooldown",
                name="Competitor Response Cooldown",
                description="Prevent excessive competitor responses",
                action_categories=[ActionCategory.COMPETITOR_RESPONSE],
                risk_threshold=0.6,
                approval_required=False,
                auto_approval_threshold=0.4,
                daily_limit=20,
                hourly_limit=5
            ),
            SafetyConstraint(
                constraint_id="seo_optimization_safety",
                name="SEO Optimization Safety",
                description="Ensure SEO optimizations don't harm rankings",
                action_categories=[ActionCategory.SEO_OPTIMIZATION],
                risk_threshold=0.5,
                approval_required=False,
                auto_approval_threshold=0.3,
                daily_limit=30,
                hourly_limit=8
            ),
            SafetyConstraint(
                constraint_id="social_amplification_limits",
                name="Social Amplification Limits",
                description="Limit social media amplification to prevent spam",
                action_categories=[ActionCategory.SOCIAL_AMPLIFICATION],
                risk_threshold=0.6,
                approval_required=False,
                auto_approval_threshold=0.4,
                daily_limit=25,
                hourly_limit=6
            )
        ]

        for constraint in default_constraints:
            self.constraints[constraint.constraint_id] = constraint

    async def validate_action(self, action_data: Dict[str, Any]) -> SafetyValidation:
        """Validate an action against safety constraints.

        Args:
            action_data: Action description. Keys read here are ``action_type``
                (default ``'unknown'``), ``risk_score`` and ``impact_score``
                (both default ``0.5``). Custom constraint conditions may also
                read ``content``.

        Returns:
            A ``SafetyValidation``. ``is_valid`` is True only when there are no
            violations AND no approval is required. If any check raises, a
            fail-closed result is returned: invalid, CRITICAL, approval required,
            confidence 0.0.

        Side effects:
            Appends to the validation/violation histories. It may also append
            to the escalation history and may set the auto-edit lockout.
        """
        try:
            logger.info(f"Validating action for user {self.user_id}: {action_data.get('action_type', 'unknown')}")

            violations: List[str] = []
            recommendations: List[str] = []
            requires_approval = False
            confidence_score = 1.0

            # Extract action details
            action_type = action_data.get('action_type', 'unknown')
            action_category = self._determine_action_category(action_type)
            risk_score = action_data.get('risk_score', 0.5)
            impact_score = action_data.get('impact_score', 0.5)

            risk_level = self._calculate_risk_level(risk_score, impact_score)

            # Check against every constraint that covers this action's category.
            for constraint in self.constraints.values():
                if action_category not in constraint.action_categories:
                    continue
                constraint_result = await self._check_constraint(constraint, action_data, risk_level)
                if not constraint_result['is_valid']:
                    violations.extend(constraint_result['violations'])
                    confidence_score *= 0.9  # Reduce confidence for violations
                if constraint_result['requires_approval']:
                    requires_approval = True
                recommendations.extend(constraint_result['recommendations'])

            # Category-wide rate limits, independent of individual constraints.
            rate_limit_result = await self._check_rate_limits(action_category, action_data)
            if not rate_limit_result['is_valid']:
                violations.extend(rate_limit_result['violations'])
                confidence_score *= 0.8

            pattern_result = await self._check_suspicious_patterns(action_data)
            if not pattern_result['is_valid']:
                violations.extend(pattern_result['violations'])
                confidence_score *= 0.7
                requires_approval = True  # Suspicious patterns always require approval

            is_valid = len(violations) == 0 and not requires_approval

            logger.info(f"Action validation completed for user {self.user_id}. Valid: {is_valid}, Risk: {risk_level.value}, Violations: {len(violations)}")

            # Recorded before escalation so the velocity policy sees this action.
            await self._record_validation_history(action_data, is_valid, violations)

            validation = SafetyValidation(
                is_valid=is_valid,
                risk_level=risk_level,
                violations=violations,
                recommendations=recommendations,
                requires_approval=requires_approval,
                confidence_score=max(0.0, min(1.0, confidence_score))
            )

            escalation = await self.evaluate_escalation(action_data, validation)
            if escalation:
                validation.recommendations.append(f"Escalation action: {escalation.action} ({escalation.tier})")

            return validation

        except Exception as e:
            logger.error(f"Error validating action for user {self.user_id}: {e}")

            # Fail closed: an error in the safety layer must never approve an action.
            return SafetyValidation(
                is_valid=False,
                risk_level=RiskLevel.CRITICAL,
                violations=["Validation system error"],
                recommendations=["Manual review required"],
                requires_approval=True,
                confidence_score=0.0
            )

    def _determine_action_category(self, action_type: str) -> ActionCategory:
        """Map a free-form action type to an ActionCategory by keyword.

        Keyword groups are tested in order and the first matching group wins.
        Unrecognised types fall back to CONTENT_MODIFICATION.
        """
        action_type_lower = action_type.lower()
        keyword_table = (
            (ActionCategory.CONTENT_MODIFICATION, ('content', 'blog', 'article', 'post')),
            (ActionCategory.SEO_OPTIMIZATION, ('seo', 'meta', 'keyword', 'optimization')),
            (ActionCategory.COMPETITOR_RESPONSE, ('competitor', 'competitive', 'response')),
            (ActionCategory.SOCIAL_AMPLIFICATION, ('social', 'share', 'amplify', 'distribute')),
            (ActionCategory.STRATEGY_CHANGE, ('strategy', 'plan', 'approach')),
            (ActionCategory.SYSTEM_CONFIGURATION, ('config', 'setting', 'system')),
        )
        for category, keywords in keyword_table:
            if any(keyword in action_type_lower for keyword in keywords):
                return category
        return ActionCategory.CONTENT_MODIFICATION  # Default category

    def _calculate_risk_level(self, risk_score: float, impact_score: float) -> RiskLevel:
        """Bucket a 60/40 weighted blend of risk and impact into a RiskLevel."""
        combined_score = (risk_score * 0.6) + (impact_score * 0.4)

        if combined_score >= 0.8:
            return RiskLevel.CRITICAL
        elif combined_score >= 0.6:
            return RiskLevel.HIGH
        elif combined_score >= 0.3:
            return RiskLevel.MEDIUM
        else:
            return RiskLevel.LOW

    async def _check_constraint(self, constraint: SafetyConstraint, action_data: Dict[str, Any], risk_level: RiskLevel) -> Dict[str, Any]:
        """Check an action against one constraint.

        Returns:
            A dict with ``is_valid``, ``violations``, ``recommendations`` and
            ``requires_approval``. ``is_valid`` is False whenever approval is
            required, even if there are no violations.
        """
        violations: List[str] = []
        recommendations: List[str] = []
        requires_approval = False

        # NOTE(review): this compares the categorical risk level with a fixed
        # 0.8 cut-off on risk_threshold, not with the numeric threshold itself.
        # Confirm that this is intended.
        if risk_level.value in ['high', 'critical'] and constraint.risk_threshold < 0.8:
            violations.append(f"Risk level {risk_level.value} exceeds constraint threshold")
            requires_approval = True

        # Per-constraint rate limits count only actions in this constraint's categories.
        if constraint.daily_limit:
            daily_count = await self._get_daily_action_count(constraint.constraint_id)
            if daily_count >= constraint.daily_limit:
                violations.append(f"Daily limit exceeded: {daily_count}/{constraint.daily_limit}")

        if constraint.hourly_limit:
            hourly_count = await self._get_hourly_action_count(constraint.constraint_id)
            if hourly_count >= constraint.hourly_limit:
                violations.append(f"Hourly limit exceeded: {hourly_count}/{constraint.hourly_limit}")

        if constraint.approval_required:
            requires_approval = True
            recommendations.append("Action requires manual approval due to safety constraints")

        # Anything riskier than the auto-approval threshold needs a human.
        risk_score = action_data.get('risk_score', 0.5)
        if risk_score > constraint.auto_approval_threshold:
            requires_approval = True

        if constraint.conditions:
            condition_result = await self._check_custom_conditions(constraint.conditions, action_data)
            if not condition_result['is_valid']:
                violations.extend(condition_result['violations'])

        is_valid = len(violations) == 0 and not requires_approval

        return {
            "is_valid": is_valid,
            "violations": violations,
            "recommendations": recommendations,
            "requires_approval": requires_approval
        }

    async def _check_rate_limits(self, action_category: ActionCategory, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply fallback per-category limits: over 50 per hour or over 200 per day."""
        violations: List[str] = []

        recent_actions = await self._get_recent_actions(hours=1)
        category_actions = [action for action in recent_actions if self._determine_action_category(action.get('action_type', '')) == action_category]

        if len(category_actions) > 50:  # Default hourly limit
            violations.append(f"Hourly action limit exceeded for {action_category.value}")

        daily_actions = await self._get_recent_actions(hours=24)
        daily_category_actions = [action for action in daily_actions if self._determine_action_category(action.get('action_type', '')) == action_category]

        if len(daily_category_actions) > 200:  # Default daily limit
            violations.append(f"Daily action limit exceeded for {action_category.value}")

        return {
            "is_valid": len(violations) == 0,
            "violations": violations
        }

    async def _check_suspicious_patterns(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Flag repetitive, high-frequency or self-contradicting action patterns."""
        violations: List[str] = []

        recent_actions = await self._get_recent_actions(hours=24)

        # Rapid repetition: more than 10 actions of the same type in 24 hours.
        action_type = action_data.get('action_type', '')
        similar_actions = [action for action in recent_actions if action.get('action_type') == action_type]

        if len(similar_actions) > 10:
            violations.append(f"Suspicious pattern: {len(similar_actions)} similar actions in 24 hours")

        # Burst frequency: more than 100 actions within the last hour.
        # This must use a 1-hour window; the 24-hour list above would make
        # the check far too aggressive.
        hourly_actions = await self._get_recent_actions(hours=1)
        if len(hourly_actions) > 100:
            violations.append("Suspicious pattern: Unusually high action frequency")

        conflicting_actions = await self._detect_conflicting_actions(action_data, recent_actions)
        if conflicting_actions:
            violations.append(f"Conflicting actions detected: {len(conflicting_actions)}")

        return {
            "is_valid": len(violations) == 0,
            "violations": violations
        }

    async def _detect_conflicting_actions(self, current_action: Dict[str, Any], recent_actions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return the first recent action whose type is the opposite of the current one.

        The result is a list with at most one element. Opposites are defined by
        a fixed set of action-type pairs, and the relation is symmetric.
        """
        conflicting_pairs = [
            ("optimize_content", "delete_content"),
            ("increase_keywords", "decrease_keywords"),
            ("enable_feature", "disable_feature")
        ]
        opposites: Dict[str, str] = {}
        for first, second in conflicting_pairs:
            opposites[first] = second
            opposites[second] = first

        opposite_type = opposites.get(current_action.get('action_type', ''))
        if opposite_type is None:
            return []

        for action in recent_actions:
            if action.get('action_type') == opposite_type:
                return [action]
        return []

    async def _check_custom_conditions(self, conditions: Dict[str, Any], action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Evaluate the optional ``max_content_length`` and ``allowed_keywords`` conditions.

        ``content`` may be missing or ``None``; both are treated as empty.
        Non-string content is converted with ``str()`` for the keyword search.
        """
        violations: List[str] = []

        raw_content = action_data.get('content')
        content = raw_content if raw_content is not None else ''

        if conditions.get('max_content_length'):
            content_length = len(content)
            if content_length > conditions['max_content_length']:
                violations.append(f"Content length {content_length} exceeds maximum {conditions['max_content_length']}")

        if conditions.get('allowed_keywords'):
            text = content if isinstance(content, str) else str(content)
            haystack = text.lower()
            allowed_keywords = [kw.lower() for kw in conditions['allowed_keywords']]
            if not any(keyword in haystack for keyword in allowed_keywords):
                violations.append("Content does not contain required keywords")

        return {
            "is_valid": len(violations) == 0,
            "violations": violations
        }

    async def _get_recent_actions(self, hours: int = 24) -> List[Dict[str, Any]]:
        """Return validation records newer than ``hours`` ago (naive UTC timestamps)."""
        cutoff_time = datetime.utcnow() - timedelta(hours=hours)
        recent: List[Dict[str, Any]] = []
        for action in self.action_history:
            timestamp = action.get('timestamp')
            # Records without a timestamp are treated as "now" and always included.
            if timestamp is None or datetime.fromisoformat(timestamp) > cutoff_time:
                recent.append(action)
        return recent

    def _count_constraint_actions(self, actions: List[Dict[str, Any]], constraint_id: str) -> int:
        """Count ``actions`` whose category is covered by the given constraint.

        If the constraint id is unknown, every action is counted.
        """
        constraint = self.constraints.get(constraint_id)
        if constraint is None:
            return len(actions)
        categories = constraint.action_categories
        return sum(
            1 for action in actions
            if self._determine_action_category(action.get('action_type', '')) in categories
        )

    async def _get_daily_action_count(self, constraint_id: str) -> int:
        """Get the 24-hour count of actions governed by a specific constraint."""
        daily_actions = await self._get_recent_actions(hours=24)
        return self._count_constraint_actions(daily_actions, constraint_id)

    async def _get_hourly_action_count(self, constraint_id: str) -> int:
        """Get the 1-hour count of actions governed by a specific constraint."""
        hourly_actions = await self._get_recent_actions(hours=1)
        return self._count_constraint_actions(hourly_actions, constraint_id)

    async def _record_validation_history(self, action_data: Dict[str, Any], is_valid: bool, violations: List[str]):
        """Append a validation record and, if needed, a violation record.

        Every validated action is recorded, including rejected ones. The rate
        limits therefore count attempts, not only executed actions.
        """
        validation_record = {
            "timestamp": datetime.utcnow().isoformat(),
            "action_type": action_data.get('action_type', 'unknown'),
            "is_valid": is_valid,
            "violations": violations,
            "action_data": action_data
        }
        self.action_history.append(validation_record)

        # Keep only recent history (last 1000 records)
        if len(self.action_history) > 1000:
            self.action_history = self.action_history[-1000:]

        if violations:
            violation_record = {
                "timestamp": datetime.utcnow().isoformat(),
                "action_type": action_data.get('action_type', 'unknown'),
                "violations": violations,
                "severity": "high" if len(violations) > 2 else "medium"
            }
            self.violation_history.append(violation_record)

            # Keep only recent violations (last 500 records)
            if len(self.violation_history) > 500:
                self.violation_history = self.violation_history[-500:]

    async def evaluate_escalation(self, action_data: Dict[str, Any], validation: SafetyValidation) -> Optional[EscalationDecision]:
        """Evaluate velocity-triggered escalation and produce structured decision payload.

        While the auto-edit lockout is active, a tier-3 "lockout_enforced"
        decision is always returned. Otherwise the velocity policy picks a tier
        from the action history, and None is returned when no tier is
        triggered. Every decision is persisted to the escalation history.
        """
        if self.auto_edit_lockout:
            decision = EscalationDecision(
                tier=EscalationTier.TIER_3.value,
                action="lockout_enforced",
                confidence=1.0,
                risk_class=RiskLevel.CRITICAL.value,
                rationale="Tier 3 lockout already active; autonomous edits blocked until manual reset",
                velocity={},
                lockout_auto_edits=True,
                executor=None
            )
            await self._persist_escalation_decision(decision, action_data, outcome={"status": "blocked_by_lockout"})
            return decision

        # NOTE(review): assumes determine_tier returns (tier or None, {tier: signal}).
        # The expected shape is inferred from usage below.
        tier, signals = self.escalation_policy.determine_tier(self.action_history)
        if not tier:
            return None

        risk_class_map = {EscalationTier.TIER_1: RiskLevel.MEDIUM.value, EscalationTier.TIER_2: RiskLevel.HIGH.value, EscalationTier.TIER_3: RiskLevel.CRITICAL.value}
        # Confidence rises with the violation count and with lower validation confidence, clamped to [0.1, 1.0].
        confidence = min(1.0, max(0.1, 0.55 + (len(validation.violations) * 0.05) + ((1 - validation.confidence_score) * 0.4)))

        velocity_signal = signals[tier]
        velocity_payload = {
            "window_minutes": velocity_signal.window_minutes,
            "action_count": velocity_signal.action_count,
            "actions_per_minute": round(velocity_signal.actions_per_minute, 4),
            "threshold_actions_per_minute": self.escalation_policy.tier_thresholds[tier]["actions_per_minute"],
        }

        is_routable = tier in (EscalationTier.TIER_1, EscalationTier.TIER_2)
        executor = self.executor_routes.get(tier.value)
        action = "route_to_autonomous_executor" if is_routable else "lockout_autonomous_edits"
        rationale = f"{tier.value} triggered by velocity {velocity_payload['actions_per_minute']}/min over {velocity_signal.window_minutes}m window"

        decision = EscalationDecision(
            tier=tier.value,
            action=action,
            confidence=round(confidence, 3),
            risk_class=risk_class_map[tier],
            rationale=rationale,
            velocity=velocity_payload,
            lockout_auto_edits=(tier == EscalationTier.TIER_3),
            executor=executor if tier != EscalationTier.TIER_3 else None
        )

        outcome = await self._apply_escalation_decision(decision, action_data, validation)
        await self._persist_escalation_decision(decision, action_data, outcome=outcome)
        return decision

    async def _apply_escalation_decision(self, decision: EscalationDecision, action_data: Dict[str, Any], validation: SafetyValidation) -> Dict[str, Any]:
        """Carry out a decision: route tiers 1/2; for tier 3, set the lockout and record an alert."""
        if decision.tier in (EscalationTier.TIER_1.value, EscalationTier.TIER_2.value):
            return {
                "status": "routed",
                "executor": decision.executor,
                "reason": decision.rationale
            }

        self.auto_edit_lockout = True
        brief = {
            "type": "diagnostic_brief",
            "severity": "critical",
            "tier": decision.tier,
            "user_rationale": "Autonomous edits have been paused to protect account safety after sustained high-velocity actions.",
            "validation_violations": validation.violations,
            "action_type": action_data.get("action_type", "unknown"),
            "timestamp": datetime.utcnow().isoformat()
        }
        self.alert_history.append(brief)
        if len(self.alert_history) > 500:
            self.alert_history = self.alert_history[-500:]

        return {"status": "lockout_enabled", "diagnostic_brief": brief}

    async def _persist_escalation_decision(self, decision: EscalationDecision, action_data: Dict[str, Any], outcome: Dict[str, Any]):
        """Append a decision record to the in-memory escalation history (capped at 2000)."""
        record = {
            "timestamp": datetime.utcnow().isoformat(),
            "decision": asdict(decision),
            "action_data": action_data,
            "outcome": outcome
        }
        self.escalation_history.append(record)
        if len(self.escalation_history) > 2000:
            self.escalation_history = self.escalation_history[-2000:]

    def get_escalation_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent escalation records (empty if ``limit`` <= 0)."""
        if limit <= 0:
            return []
        return self.escalation_history[-limit:]

    def reset_auto_edit_lockout(self):
        """Manually clear a tier-3 lockout so autonomous edits may resume."""
        self.auto_edit_lockout = False

    def add_custom_constraint(self, constraint: SafetyConstraint):
        """Add (or replace, by constraint_id) a safety constraint."""
        self.constraints[constraint.constraint_id] = constraint
        logger.info(f"Added custom constraint for user {self.user_id}: {constraint.constraint_id}")

    def remove_constraint(self, constraint_id: str):
        """Remove a safety constraint; unknown ids are ignored."""
        if constraint_id in self.constraints:
            del self.constraints[constraint_id]
            logger.info(f"Removed constraint for user {self.user_id}: {constraint_id}")

    def get_constraints(self) -> Dict[str, SafetyConstraint]:
        """Return a shallow copy of the constraint mapping."""
        return self.constraints.copy()

    def get_validation_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent validation records (empty if ``limit`` <= 0)."""
        if limit <= 0:
            return []
        return self.action_history[-limit:]

    def get_violation_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to ``limit`` most recent violation records (empty if ``limit`` <= 0)."""
        if limit <= 0:
            return []
        return self.violation_history[-limit:]
|
|
|
|
class RollbackManager:
|
|
"""Manages rollback operations for agent actions"""
|
|
|
|
def __init__(self, user_id: str):
|
|
self.user_id = user_id
|
|
self.checkpoints: List[ActionCheckpoint] = []
|
|
self.rollback_history: List[Dict[str, Any]] = []
|
|
|
|
logger.info(f"Initialized RollbackManager for user: {user_id}")
|
|
|
|
async def create_checkpoint(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> str:
|
|
"""Create a checkpoint before executing an action"""
|
|
try:
|
|
checkpoint_id = f"checkpoint_{self.user_id}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}"
|
|
|
|
checkpoint = ActionCheckpoint(
|
|
checkpoint_id=checkpoint_id,
|
|
action_id=action_data.get('action_id', 'unknown'),
|
|
agent_id=action_data.get('agent_id', 'unknown'),
|
|
user_id=self.user_id,
|
|
action_type=action_data.get('action_type', 'unknown'),
|
|
action_data=action_data,
|
|
system_state=system_state
|
|
)
|
|
|
|
self.checkpoints.append(checkpoint)
|
|
|
|
# Keep only recent checkpoints (last 100)
|
|
if len(self.checkpoints) > 100:
|
|
self.checkpoints = self.checkpoints[-100:]
|
|
|
|
logger.info(f"Created checkpoint for user {self.user_id}: {checkpoint_id}")
|
|
return checkpoint_id
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error creating checkpoint for user {self.user_id}: {e}")
|
|
raise e
|
|
|
|
async def rollback_to_checkpoint(self, checkpoint_id: str) -> Dict[str, Any]:
|
|
"""Rollback to a specific checkpoint"""
|
|
try:
|
|
# Find checkpoint
|
|
checkpoint = next((cp for cp in self.checkpoints if cp.checkpoint_id == checkpoint_id), None)
|
|
|
|
if not checkpoint:
|
|
return {
|
|
"success": False,
|
|
"error": f"Checkpoint not found: {checkpoint_id}"
|
|
}
|
|
|
|
logger.info(f"Rolling back to checkpoint for user {self.user_id}: {checkpoint_id}")
|
|
|
|
# Execute rollback (implementation depends on action type)
|
|
rollback_result = await self._execute_rollback(checkpoint)
|
|
|
|
# Record in history
|
|
rollback_record = {
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"checkpoint_id": checkpoint_id,
|
|
"action_type": checkpoint.action_type,
|
|
"success": rollback_result["success"],
|
|
"details": rollback_result
|
|
}
|
|
self.rollback_history.append(rollback_record)
|
|
|
|
# Keep only recent rollback history (last 50)
|
|
if len(self.rollback_history) > 50:
|
|
self.rollback_history = self.rollback_history[-50:]
|
|
|
|
return rollback_result
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error rolling back to checkpoint {checkpoint_id} for user {self.user_id}: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e)
|
|
}
|
|
|
|
async def _execute_rollback(self, checkpoint: ActionCheckpoint) -> Dict[str, Any]:
|
|
"""Execute the rollback operation based on action type"""
|
|
try:
|
|
action_type = checkpoint.action_type
|
|
action_data = checkpoint.action_data
|
|
system_state = checkpoint.system_state
|
|
|
|
# Implement rollback logic for different action types
|
|
if action_type == "content_modification":
|
|
return await self._rollback_content_modification(action_data, system_state)
|
|
elif action_type == "seo_optimization":
|
|
return await self._rollback_seo_optimization(action_data, system_state)
|
|
elif action_type == "competitor_response":
|
|
return await self._rollback_competitor_response(action_data, system_state)
|
|
elif action_type == "social_amplification":
|
|
return await self._rollback_social_amplification(action_data, system_state)
|
|
else:
|
|
# Generic rollback
|
|
return await self._rollback_generic(action_data, system_state)
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error executing rollback for action {action_type}: {e}")
|
|
return {
|
|
"success": False,
|
|
"error": str(e)
|
|
}
|
|
|
|
async def _rollback_content_modification(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Rollback content modification"""
|
|
try:
|
|
# Implementation would depend on how content is stored and managed
|
|
# For now, return a placeholder implementation
|
|
|
|
original_content = system_state.get('original_content', {})
|
|
modified_content = action_data.get('content', {})
|
|
|
|
logger.info(f"Rolling back content modification: {action_data.get('content_id', 'unknown')}")
|
|
|
|
return {
|
|
"success": True,
|
|
"message": "Content modification rolled back successfully",
|
|
"details": {
|
|
"content_id": action_data.get('content_id'),
|
|
"rollback_type": "content_modification",
|
|
"original_state_restored": bool(original_content)
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"success": False,
|
|
"error": f"Failed to rollback content modification: {str(e)}"
|
|
}
|
|
|
|
async def _rollback_seo_optimization(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Rollback SEO optimization"""
|
|
try:
|
|
original_seo_state = system_state.get('seo_state', {})
|
|
|
|
logger.info(f"Rolling back SEO optimization: {action_data.get('optimization_type', 'unknown')}")
|
|
|
|
return {
|
|
"success": True,
|
|
"message": "SEO optimization rolled back successfully",
|
|
"details": {
|
|
"optimization_type": action_data.get('optimization_type'),
|
|
"rollback_type": "seo_optimization",
|
|
"original_state_restored": bool(original_seo_state)
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"success": False,
|
|
"error": f"Failed to rollback SEO optimization: {str(e)}"
|
|
}
|
|
|
|
async def _rollback_competitor_response(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Rollback competitor response"""
|
|
try:
|
|
logger.info(f"Rolling back competitor response: {action_data.get('response_type', 'unknown')}")
|
|
|
|
return {
|
|
"success": True,
|
|
"message": "Competitor response rolled back successfully",
|
|
"details": {
|
|
"response_type": action_data.get('response_type'),
|
|
"rollback_type": "competitor_response",
|
|
"original_state_restored": True
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"success": False,
|
|
"error": f"Failed to rollback competitor response: {str(e)}"
|
|
}
|
|
|
|
async def _rollback_social_amplification(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Rollback social amplification"""
|
|
try:
|
|
logger.info(f"Rolling back social amplification: {action_data.get('platform', 'unknown')}")
|
|
|
|
return {
|
|
"success": True,
|
|
"message": "Social amplification rolled back successfully",
|
|
"details": {
|
|
"platform": action_data.get('platform'),
|
|
"rollback_type": "social_amplification",
|
|
"original_state_restored": True
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"success": False,
|
|
"error": f"Failed to rollback social amplification: {str(e)}"
|
|
}
|
|
|
|
async def _rollback_generic(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
|
|
"""Generic rollback for unknown action types"""
|
|
try:
|
|
logger.info(f"Performing generic rollback for action: {action_data.get('action_type', 'unknown')}")
|
|
|
|
return {
|
|
"success": True,
|
|
"message": "Generic rollback completed",
|
|
"details": {
|
|
"action_type": action_data.get('action_type'),
|
|
"rollback_type": "generic",
|
|
"system_state_available": bool(system_state)
|
|
}
|
|
}
|
|
|
|
except Exception as e:
|
|
return {
|
|
"success": False,
|
|
"error": f"Failed to perform generic rollback: {str(e)}"
|
|
}
|
|
|
|
async def rollback_latest_actions(self, count: int = 1) -> List[Dict[str, Any]]:
|
|
"""Rollback the latest N actions"""
|
|
results = []
|
|
|
|
# Get latest checkpoints
|
|
latest_checkpoints = self.checkpoints[-count:] if self.checkpoints else []
|
|
|
|
for checkpoint in reversed(latest_checkpoints):
|
|
result = await self.rollback_to_checkpoint(checkpoint.checkpoint_id)
|
|
results.append(result)
|
|
|
|
return results
|
|
|
|
def get_checkpoints(self, limit: int = 50) -> List[Dict[str, Any]]:
|
|
"""Get recent checkpoints"""
|
|
checkpoints_data = []
|
|
|
|
for checkpoint in self.checkpoints[-limit:]:
|
|
checkpoints_data.append({
|
|
"checkpoint_id": checkpoint.checkpoint_id,
|
|
"action_id": checkpoint.action_id,
|
|
"action_type": checkpoint.action_type,
|
|
"agent_id": checkpoint.agent_id,
|
|
"created_at": checkpoint.created_at,
|
|
"system_state_keys": list(checkpoint.system_state.keys())
|
|
})
|
|
|
|
return checkpoints_data
|
|
|
|
def get_rollback_history(self, limit: int = 50) -> List[Dict[str, Any]]:
    """Return up to ``limit`` of the most recent rollback records.

    Args:
        limit: Maximum number of records to return (most recent last).
            Values ``<= 0`` return an empty list.

    Returns:
        A new list (slice copy) of the tail of ``self.rollback_history``.
    """
    # ``seq[-0:]`` is the whole sequence, so limit=0 must be handled explicitly.
    if limit <= 0 or not self.rollback_history:
        return []
    return self.rollback_history[-limit:]
|
|
|
|
class UserApprovalSystem:
    """Manages user approval for high-risk actions.

    Requests live in ``pending_approvals`` (keyed by approval id) until the
    user decides or they pass ``expires_at`` (``APPROVAL_TTL`` after creation).
    Decided requests move to ``approval_history``, which keeps only the newest
    ``MAX_HISTORY`` entries. Timestamps are naive ``datetime.utcnow()`` values
    serialised with ``isoformat()``.
    """

    # How long a request remains actionable after it is created.
    APPROVAL_TTL = timedelta(hours=24)
    # Maximum number of decided requests retained in ``approval_history``.
    MAX_HISTORY = 100

    def __init__(self, user_id: str):
        self.user_id = user_id
        # approval_id -> request dict (keys set in ``request_approval``).
        self.pending_approvals: Dict[str, Dict[str, Any]] = {}
        # Decided requests, oldest first.
        self.approval_history: List[Dict[str, Any]] = []

        logger.info(f"Initialized UserApprovalSystem for user: {user_id}")

    def _purge_expired(self) -> None:
        """Drop pending requests whose ``expires_at`` is in the past.

        Like ``approve_action``, expired requests are discarded without being
        recorded in ``approval_history``.
        """
        now = datetime.utcnow()
        expired_ids = [
            approval_id
            for approval_id, request in self.pending_approvals.items()
            if datetime.fromisoformat(request["expires_at"]) < now
        ]
        for approval_id in expired_ids:
            del self.pending_approvals[approval_id]

    async def request_approval(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Create a pending approval request for ``action_data``.

        Returns:
            ``{"success": True, "approval_id": ..., "status": "pending", ...}``
            on success, or ``{"success": False, "error": ...}`` on failure.
        """
        import uuid  # function-scope import keeps this block self-contained

        try:
            # Free expired requests so the pending map does not grow unbounded.
            self._purge_expired()

            now = datetime.utcnow()
            # The random suffix keeps ids unique when several requests are made
            # within the same second; a bare timestamp let a later request
            # silently overwrite an earlier one in ``pending_approvals``.
            approval_id = (
                f"approval_{self.user_id}_{now.strftime('%Y%m%d%H%M%S')}"
                f"_{uuid.uuid4().hex[:8]}"
            )

            approval_request = {
                "approval_id": approval_id,
                "action_data": action_data,
                "requested_at": now.isoformat(),
                "status": "pending",
                "expires_at": (now + self.APPROVAL_TTL).isoformat(),
            }

            self.pending_approvals[approval_id] = approval_request

            logger.info(f"Created approval request for user {self.user_id}: {approval_id}")

            return {
                "success": True,
                "approval_id": approval_id,
                "status": "pending",
                "message": "Approval request created successfully"
            }

        except Exception as e:
            logger.error(f"Error creating approval request for user {self.user_id}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    async def approve_action(self, approval_id: str, user_decision: str, user_comments: str = "") -> Dict[str, Any]:
        """Record the user's decision on a pending request.

        Args:
            approval_id: Id returned by ``request_approval``.
            user_decision: Stored verbatim as the request's ``status``.
            user_comments: Free-text note stored with the decision.

        Returns:
            ``{"success": True, "approval_id", "status", "message"}`` on
            success; ``{"success": False, "error": ...}`` if the request is
            unknown, expired (it is then discarded), or processing raised.
        """
        try:
            if approval_id not in self.pending_approvals:
                return {
                    "success": False,
                    "error": "Approval request not found"
                }

            approval_request = self.pending_approvals[approval_id]

            # Expired requests can no longer be decided; discard them.
            expires_at = datetime.fromisoformat(approval_request["expires_at"])
            if datetime.utcnow() > expires_at:
                del self.pending_approvals[approval_id]
                return {
                    "success": False,
                    "error": "Approval request has expired"
                }

            # NOTE(review): user_decision is not validated; only "approved" and
            # "rejected" are counted by get_approval_statistics — confirm callers
            # use exactly those values.
            approval_request["status"] = user_decision
            approval_request["decision_at"] = datetime.utcnow().isoformat()
            approval_request["user_comments"] = user_comments

            # Move the request from pending to history.
            self.approval_history.append(approval_request)
            del self.pending_approvals[approval_id]

            # Keep only the most recent MAX_HISTORY decisions.
            if len(self.approval_history) > self.MAX_HISTORY:
                self.approval_history = self.approval_history[-self.MAX_HISTORY:]

            logger.info(f"Processed approval decision for user {self.user_id}: {approval_id} - {user_decision}")

            return {
                "success": True,
                "approval_id": approval_id,
                "status": user_decision,
                "message": f"Action {user_decision} successfully"
            }

        except Exception as e:
            logger.error(f"Error processing approval decision for user {self.user_id}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    def get_pending_approvals(self) -> List[Dict[str, Any]]:
        """Return all pending approval requests that have not yet expired.

        Expired requests are removed from ``pending_approvals`` as a side effect.
        """
        self._purge_expired()
        return list(self.pending_approvals.values())

    def get_approval_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Return up to ``limit`` of the most recent decided requests.

        Values of ``limit <= 0`` return an empty list (``seq[-0:]`` would
        otherwise return the whole history).
        """
        if limit <= 0 or not self.approval_history:
            return []
        return self.approval_history[-limit:]

    def get_approval_statistics(self) -> Dict[str, Any]:
        """Summarise decided requests and count live pending ones.

        ``approval_rate`` is ``approved_count / total_approvals`` (0.0 when the
        history is empty). Decisions other than "approved"/"rejected" count
        toward the total only. Expired pending requests are purged first so
        ``pending_count`` reflects only actionable requests.
        """
        self._purge_expired()

        total = len(self.approval_history)
        approved = sum(1 for a in self.approval_history if a["status"] == "approved")
        rejected = sum(1 for a in self.approval_history if a["status"] == "rejected")

        return {
            "total_approvals": total,
            "approved_count": approved,
            "rejected_count": rejected,
            "approval_rate": approved / total if total > 0 else 0.0,
            "pending_count": len(self.pending_approvals)
        }
|
|
|
|
# Per-user registry of safety framework components, keyed by user_id.
safety_framework_instances: Dict[str, Dict[str, Any]] = {}


def get_safety_framework(user_id: str) -> Dict[str, Any]:
    """Return the safety components for ``user_id``, building them on first use.

    The dict holds ``"constraint_manager"``, ``"rollback_manager"`` and
    ``"approval_system"``. It is cached in ``safety_framework_instances``, so
    later calls for the same user return the very same dict. Nothing in this
    function ever evicts a cached entry.
    """
    if user_id in safety_framework_instances:
        return safety_framework_instances[user_id]

    components = {
        "constraint_manager": SafetyConstraintManager(user_id),
        "rollback_manager": RollbackManager(user_id),
        "approval_system": UserApprovalSystem(user_id),
    }
    safety_framework_instances[user_id] = components
    return components
|
|
|
|
# Module-level convenience wrappers around the per-user framework components.

async def validate_agent_action(user_id: str, action_data: Dict[str, Any]) -> SafetyValidation:
    """Validate ``action_data`` with the user's ``SafetyConstraintManager``.

    Returns whatever ``validate_action`` of that user's constraint manager returns.
    """
    constraint_manager = get_safety_framework(user_id)["constraint_manager"]
    return await constraint_manager.validate_action(action_data)
|
|
|
|
async def create_action_checkpoint(user_id: str, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> str:
    """Record a checkpoint for ``action_data`` via the user's ``RollbackManager``.

    Returns the string from ``create_checkpoint`` (presumably the checkpoint
    id later passed to ``rollback_to_checkpoint`` — confirm).
    """
    rollback_manager = get_safety_framework(user_id)["rollback_manager"]
    return await rollback_manager.create_checkpoint(action_data, system_state)
|
|
|
|
async def rollback_to_checkpoint(user_id: str, checkpoint_id: str) -> Dict[str, Any]:
    """Roll the user's state back to ``checkpoint_id``.

    Delegates to ``RollbackManager.rollback_to_checkpoint`` and returns its
    result dict unchanged.
    """
    rollback_manager = get_safety_framework(user_id)["rollback_manager"]
    return await rollback_manager.rollback_to_checkpoint(checkpoint_id)
|
|
|
|
async def request_user_approval(user_id: str, action_data: Dict[str, Any]) -> Dict[str, Any]:
    """Open an approval request for ``action_data`` in the user's approval system.

    Returns the result dict produced by ``UserApprovalSystem.request_approval``.
    """
    approval_system = get_safety_framework(user_id)["approval_system"]
    return await approval_system.request_approval(action_data)