# File: ALwrity/backend/services/intelligence/agents/safety_framework.py (Python, 1019 lines, 43 KiB)

"""
Agent Safety Framework for ALwrity Autonomous Marketing Agents
Implements safety constraints, validation, and rollback mechanisms
"""
import asyncio
import json
import logging
from datetime import datetime, timedelta
from typing import Dict, List, Any, Optional, Set
from dataclasses import dataclass, asdict
from enum import Enum
from utils.logger_utils import get_service_logger
from services.database import get_session_for_user
from services.intelligence.agents.performance_monitor import EscalationVelocityPolicy, EscalationTier
logger = get_service_logger(__name__)
class RiskLevel(Enum):
    """Risk tier assigned to an agent action.

    Each member's value is the lowercase form of its name; those strings are
    what gets logged and stored in escalation payloads.
    """

    LOW = "low"
    MEDIUM = "medium"
    HIGH = "high"
    CRITICAL = "critical"
class ActionCategory(Enum):
    """Kind of work an agent action performs.

    Safety constraints list the categories they apply to; each member's value
    is the lowercase form of its name.
    """

    CONTENT_MODIFICATION = "content_modification"
    SEO_OPTIMIZATION = "seo_optimization"
    COMPETITOR_RESPONSE = "competitor_response"
    SOCIAL_AMPLIFICATION = "social_amplification"
    STRATEGY_CHANGE = "strategy_change"
    SYSTEM_CONFIGURATION = "system_configuration"
@dataclass
class SafetyConstraint:
    """A safety rule that agent actions in given categories are checked against.

    ``conditions`` and ``created_at`` may be omitted (or passed as ``None``);
    ``__post_init__`` then replaces them with an empty dict and the current
    UTC time as an ISO-8601 string, respectively.
    """

    constraint_id: str
    name: str
    description: str
    action_categories: List[ActionCategory]
    risk_threshold: float  # Maximum allowed risk level (0.0 to 1.0)
    approval_required: bool
    auto_approval_threshold: float  # Risk level below which auto-approval is allowed
    daily_limit: Optional[int] = None  # Maximum actions per day
    hourly_limit: Optional[int] = None  # Maximum actions per hour
    # Additional conditions for validation; None means "no extra conditions".
    conditions: Optional[Dict[str, Any]] = None
    # ISO-8601 creation time; filled in automatically when left as None.
    created_at: Optional[str] = None

    def __post_init__(self):
        """Fill in defaults that must be computed per instance."""
        if self.created_at is None:
            self.created_at = datetime.utcnow().isoformat()
        if self.conditions is None:
            # A fresh dict per instance, so constraints never share state.
            self.conditions = {}
@dataclass
class ActionCheckpoint:
    """Snapshot taken before an action runs, used later for rollback.

    ``action_data`` and ``system_state`` are stored as given (not copied).
    ``created_at`` defaults to the current UTC time as an ISO-8601 string.
    """

    checkpoint_id: str
    action_id: str
    agent_id: str
    user_id: str
    action_type: str
    action_data: Dict[str, Any]
    system_state: Dict[str, Any]
    # ISO-8601 creation time; filled in automatically when left as None.
    created_at: Optional[str] = None

    def __post_init__(self):
        """Stamp the checkpoint with the creation time if none was supplied."""
        if self.created_at is None:
            self.created_at = datetime.utcnow().isoformat()
@dataclass
class SafetyValidation:
    """Outcome of validating one action against the safety constraints.

    ``validation_timestamp`` defaults to the current UTC time as an ISO-8601
    string when not supplied.
    """

    is_valid: bool
    risk_level: RiskLevel
    violations: List[str]
    recommendations: List[str]
    requires_approval: bool
    confidence_score: float  # 0.0 to 1.0
    # ISO-8601 validation time; filled in automatically when left as None.
    validation_timestamp: Optional[str] = None

    def __post_init__(self):
        """Stamp the result with the validation time if none was supplied."""
        if self.validation_timestamp is None:
            self.validation_timestamp = datetime.utcnow().isoformat()
@dataclass
class EscalationDecision:
    """Structured escalation payload for autonomous safety routing.

    ``created_at`` defaults to the current UTC time as an ISO-8601 string when
    not supplied.
    """

    tier: str
    action: str
    confidence: float
    risk_class: str
    rationale: str
    velocity: Dict[str, Any]
    lockout_auto_edits: bool
    executor: Optional[str]
    # ISO-8601 creation time; filled in automatically when left as None.
    created_at: Optional[str] = None

    def __post_init__(self):
        """Stamp the decision with the creation time if none was supplied."""
        if self.created_at is None:
            self.created_at = datetime.utcnow().isoformat()
class SafetyConstraintManager:
    """Validates agent actions for one user against safety constraints.

    All state (constraints, action / violation / escalation / alert histories
    and the auto-edit lockout flag) is kept in memory on the instance.
    """

    def __init__(self, user_id: str):
        """Create the manager for ``user_id`` and install the default constraints."""
        self.user_id = user_id
        self.constraints: Dict[str, SafetyConstraint] = {}
        self.action_history: List[Dict[str, Any]] = []
        self.violation_history: List[Dict[str, Any]] = []
        self.escalation_policy = EscalationVelocityPolicy()
        self.escalation_history: List[Dict[str, Any]] = []
        self.auto_edit_lockout = False
        self.executor_routes = {"tier_1": "autonomous_guardian_executor", "tier_2": "autonomous_recovery_executor"}
        self.alert_history: List[Dict[str, Any]] = []
        # Initialize default constraints
        self._initialize_default_constraints()
        logger.info(f"Initialized SafetyConstraintManager for user: {user_id}")

    def _initialize_default_constraints(self):
        """Register the built-in constraints, keyed by ``constraint_id``."""
        default_constraints = [
            SafetyConstraint(
                constraint_id="content_modification_limit",
                name="Content Modification Daily Limit",
                description="Limit the number of content modifications per day",
                action_categories=[ActionCategory.CONTENT_MODIFICATION],
                risk_threshold=0.7,
                approval_required=False,
                auto_approval_threshold=0.3,
                daily_limit=50,
                hourly_limit=10,
            ),
            SafetyConstraint(
                constraint_id="high_risk_approval_required",
                name="High Risk Action Approval",
                description="Require approval for high-risk actions",
                action_categories=[ActionCategory.STRATEGY_CHANGE, ActionCategory.SYSTEM_CONFIGURATION],
                risk_threshold=0.8,
                approval_required=True,
                auto_approval_threshold=0.2,
            ),
            SafetyConstraint(
                constraint_id="competitor_response_cooldown",
                name="Competitor Response Cooldown",
                description="Prevent excessive competitor responses",
                action_categories=[ActionCategory.COMPETITOR_RESPONSE],
                risk_threshold=0.6,
                approval_required=False,
                auto_approval_threshold=0.4,
                daily_limit=20,
                hourly_limit=5,
            ),
            SafetyConstraint(
                constraint_id="seo_optimization_safety",
                name="SEO Optimization Safety",
                description="Ensure SEO optimizations don't harm rankings",
                action_categories=[ActionCategory.SEO_OPTIMIZATION],
                risk_threshold=0.5,
                approval_required=False,
                auto_approval_threshold=0.3,
                daily_limit=30,
                hourly_limit=8,
            ),
            SafetyConstraint(
                constraint_id="social_amplification_limits",
                name="Social Amplification Limits",
                description="Limit social media amplification to prevent spam",
                action_categories=[ActionCategory.SOCIAL_AMPLIFICATION],
                risk_threshold=0.6,
                approval_required=False,
                auto_approval_threshold=0.4,
                daily_limit=25,
                hourly_limit=6,
            ),
        ]
        for constraint in default_constraints:
            self.constraints[constraint.constraint_id] = constraint

    async def validate_action(self, action_data: Dict[str, Any]) -> SafetyValidation:
        """Validate an action against safety constraints.

        Reads ``action_type``, ``risk_score`` and ``impact_score`` from
        ``action_data`` (defaults: ``'unknown'``, 0.5, 0.5), runs the constraint,
        rate-limit and suspicious-pattern checks, records the outcome in
        ``action_history`` and evaluates velocity escalation.

        Returns:
            A ``SafetyValidation``. ``is_valid`` is True only when there are no
            violations, no approval is required and no auto-edit lockout is in
            force. Any exception raised while validating yields a fail-safe
            result (invalid, CRITICAL, approval required).
        """
        try:
            logger.info(f"Validating action for user {self.user_id}: {action_data.get('action_type', 'unknown')}")
            violations: List[str] = []
            recommendations: List[str] = []
            requires_approval = False
            confidence_score = 1.0

            # Extract action details
            action_type = action_data.get('action_type', 'unknown')
            action_category = self._determine_action_category(action_type)
            risk_score = action_data.get('risk_score', 0.5)
            impact_score = action_data.get('impact_score', 0.5)

            # Determine risk level
            risk_level = self._calculate_risk_level(risk_score, impact_score)

            # Check against all relevant constraints
            for constraint in self.constraints.values():
                if action_category not in constraint.action_categories:
                    continue
                constraint_result = await self._check_constraint(constraint, action_data, risk_level)
                if not constraint_result['is_valid']:
                    violations.extend(constraint_result['violations'])
                    confidence_score *= 0.9  # Reduce confidence for violations
                if constraint_result['requires_approval']:
                    requires_approval = True
                recommendations.extend(constraint_result['recommendations'])

            # Check rate limits
            rate_limit_result = await self._check_rate_limits(action_category, action_data)
            if not rate_limit_result['is_valid']:
                violations.extend(rate_limit_result['violations'])
                confidence_score *= 0.8

            # Check for suspicious patterns
            pattern_result = await self._check_suspicious_patterns(action_data)
            if not pattern_result['is_valid']:
                violations.extend(pattern_result['violations'])
                confidence_score *= 0.7
                requires_approval = True  # Suspicious patterns always require approval

            # A Tier 3 lockout that is already active must block the action;
            # previously it only surfaced as a recommendation string.
            locked_out_before = self.auto_edit_lockout
            if locked_out_before:
                violations.append("Autonomous edits are locked out pending manual reset")
                requires_approval = True

            # Final validation
            is_valid = len(violations) == 0 and not requires_approval
            logger.info(f"Action validation completed for user {self.user_id}. Valid: {is_valid}, Risk: {risk_level.value}, Violations: {len(violations)}")

            # Record in history (escalation velocity is computed from this history)
            await self._record_validation_history(action_data, is_valid, violations)

            validation = SafetyValidation(
                is_valid=is_valid,
                risk_level=risk_level,
                violations=violations,
                recommendations=recommendations,
                requires_approval=requires_approval,
                confidence_score=max(0.0, min(1.0, confidence_score))
            )

            escalation = await self.evaluate_escalation(action_data, validation)
            if escalation:
                # ``recommendations`` is the same list object held by ``validation``.
                recommendations.append(f"Escalation action: {escalation.action} ({escalation.tier})")
                if escalation.lockout_auto_edits and not locked_out_before:
                    # This very action tripped the lockout: block it as well.
                    violations.append("Autonomous edits locked out by escalation")
                    validation.is_valid = False
                    validation.requires_approval = True
            return validation
        except Exception as e:
            logger.error(f"Error validating action for user {self.user_id}: {e}")
            # Return safe default on error
            return SafetyValidation(
                is_valid=False,
                risk_level=RiskLevel.CRITICAL,
                violations=["Validation system error"],
                recommendations=["Manual review required"],
                requires_approval=True,
                confidence_score=0.0
            )

    def _determine_action_category(self, action_type: str) -> ActionCategory:
        """Map a free-form action type to a category by keyword substring match.

        Keyword groups are tried in order and the first group with any hit
        wins; unmatched types fall back to ``CONTENT_MODIFICATION``.
        """
        lowered = action_type.lower()
        keyword_table = (
            (('content', 'blog', 'article', 'post'), ActionCategory.CONTENT_MODIFICATION),
            (('seo', 'meta', 'keyword', 'optimization'), ActionCategory.SEO_OPTIMIZATION),
            (('competitor', 'competitive', 'response'), ActionCategory.COMPETITOR_RESPONSE),
            (('social', 'share', 'amplify', 'distribute'), ActionCategory.SOCIAL_AMPLIFICATION),
            (('strategy', 'plan', 'approach'), ActionCategory.STRATEGY_CHANGE),
            (('config', 'setting', 'system'), ActionCategory.SYSTEM_CONFIGURATION),
        )
        for keywords, category in keyword_table:
            if any(keyword in lowered for keyword in keywords):
                return category
        return ActionCategory.CONTENT_MODIFICATION  # Default category

    def _calculate_risk_level(self, risk_score: float, impact_score: float) -> RiskLevel:
        """Bucket a 60/40 weighted blend of risk and impact into a ``RiskLevel``."""
        combined_score = (risk_score * 0.6) + (impact_score * 0.4)
        if combined_score >= 0.8:
            return RiskLevel.CRITICAL
        if combined_score >= 0.6:
            return RiskLevel.HIGH
        if combined_score >= 0.3:
            return RiskLevel.MEDIUM
        return RiskLevel.LOW

    async def _check_constraint(self, constraint: SafetyConstraint, action_data: Dict[str, Any], risk_level: RiskLevel) -> Dict[str, Any]:
        """Check an action against a specific constraint.

        Returns a dict with ``is_valid``, ``violations``, ``recommendations``
        and ``requires_approval``; ``is_valid`` is False whenever there is a
        violation or approval is required.
        """
        violations = []
        recommendations = []
        requires_approval = False

        # Check risk threshold
        # NOTE(review): this compares the constraint's threshold with a fixed
        # 0.8 rather than comparing the action's score with the threshold;
        # behaviour kept as-is - confirm intent.
        if risk_level.value in ['high', 'critical'] and constraint.risk_threshold < 0.8:
            violations.append(f"Risk level {risk_level.value} exceeds constraint threshold")
            requires_approval = True

        # Check rate limits (a limit of None or 0 disables the check)
        if constraint.daily_limit:
            daily_count = await self._get_daily_action_count(constraint.constraint_id)
            if daily_count >= constraint.daily_limit:
                violations.append(f"Daily limit exceeded: {daily_count}/{constraint.daily_limit}")
        if constraint.hourly_limit:
            hourly_count = await self._get_hourly_action_count(constraint.constraint_id)
            if hourly_count >= constraint.hourly_limit:
                violations.append(f"Hourly limit exceeded: {hourly_count}/{constraint.hourly_limit}")

        # Check approval requirement
        if constraint.approval_required:
            requires_approval = True
            recommendations.append("Action requires manual approval due to safety constraints")

        # Check auto-approval threshold
        risk_score = action_data.get('risk_score', 0.5)
        if risk_score > constraint.auto_approval_threshold:
            requires_approval = True

        # Custom condition checks
        if constraint.conditions:
            condition_result = await self._check_custom_conditions(constraint.conditions, action_data)
            if not condition_result['is_valid']:
                violations.extend(condition_result['violations'])

        is_valid = len(violations) == 0 and not requires_approval
        return {
            "is_valid": is_valid,
            "violations": violations,
            "recommendations": recommendations,
            "requires_approval": requires_approval
        }

    async def _check_rate_limits(self, action_category: ActionCategory, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Apply the global per-category ceilings (50 per hour, 200 per day).

        ``action_data`` is accepted for signature compatibility but not used.
        """
        violations = []

        def in_category(action: Dict[str, Any]) -> bool:
            return self._determine_action_category(action.get('action_type', '')) == action_category

        # Check hourly limits
        hourly_actions = await self._get_recent_actions(hours=1)
        if sum(1 for action in hourly_actions if in_category(action)) > 50:  # Default hourly limit
            violations.append(f"Hourly action limit exceeded for {action_category.value}")

        # Check daily limits
        daily_actions = await self._get_recent_actions(hours=24)
        if sum(1 for action in daily_actions if in_category(action)) > 200:  # Default daily limit
            violations.append(f"Daily action limit exceeded for {action_category.value}")

        return {
            "is_valid": len(violations) == 0,
            "violations": violations
        }

    async def _check_suspicious_patterns(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Check for suspicious patterns in actions.

        Flags: more than 10 recorded actions of the same ``action_type`` in the
        last 24 hours, more than 100 recorded actions of any type in the last
        hour, and actions that conflict with one recorded in the last 24 hours.
        """
        violations = []
        recent_actions = await self._get_recent_actions(hours=24)

        # Check for rapid repetitive actions
        action_type = action_data.get('action_type', '')
        similar_actions = [action for action in recent_actions if action.get('action_type') == action_type]
        if len(similar_actions) > 10:  # More than 10 similar actions in 24 hours
            violations.append(f"Suspicious pattern: {len(similar_actions)} similar actions in 24 hours")

        # Check for unusual timing patterns. The frequency threshold is defined
        # per hour, so it must be measured on the 1-hour window (it used to be
        # applied to the 24-hour list).
        last_hour_actions = await self._get_recent_actions(hours=1)
        if len(last_hour_actions) > 100:  # More than 100 actions in 1 hour
            violations.append("Suspicious pattern: Unusually high action frequency")

        # Check for conflicting actions
        conflicting_actions = await self._detect_conflicting_actions(action_data, recent_actions)
        if conflicting_actions:
            violations.append(f"Conflicting actions detected: {len(conflicting_actions)}")

        return {
            "is_valid": len(violations) == 0,
            "violations": violations
        }

    async def _detect_conflicting_actions(self, current_action: Dict[str, Any], recent_actions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Return recent actions whose type is the opposite of the current one.

        For each known opposing pair that contains the current action type, the
        first recent action of the opposite type is reported.
        """
        conflicting_pairs = [
            ("optimize_content", "delete_content"),
            ("increase_keywords", "decrease_keywords"),
            ("enable_feature", "disable_feature")
        ]
        current_action_type = current_action.get('action_type', '')
        conflicts = []
        for first, second in conflicting_pairs:
            if current_action_type == first:
                opposite = second
            elif current_action_type == second:
                opposite = first
            else:
                continue
            match = next((action for action in recent_actions if action.get('action_type') == opposite), None)
            if match is not None:
                conflicts.append(match)
        return conflicts

    async def _check_custom_conditions(self, conditions: Dict[str, Any], action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Check custom conditions for constraints.

        Supported keys: ``max_content_length`` (upper bound on ``len(content)``)
        and ``allowed_keywords`` (at least one must occur in the content,
        case-insensitively). A missing or ``None`` ``content`` is treated as
        empty; non-string content is matched against its ``str()`` form.
        """
        violations = []
        # ``or ''`` also covers an explicit None, which previously raised.
        content = action_data.get('content') or ''

        max_length = conditions.get('max_content_length')
        if max_length:
            content_length = len(content)
            if content_length > max_length:
                violations.append(f"Content length {content_length} exceeds maximum {max_length}")

        if conditions.get('allowed_keywords'):
            text = (content if isinstance(content, str) else str(content)).lower()
            allowed_keywords = [kw.lower() for kw in conditions['allowed_keywords']]
            if not any(keyword in text for keyword in allowed_keywords):
                violations.append("Content does not contain required keywords")

        return {
            "is_valid": len(violations) == 0,
            "violations": violations
        }

    async def _get_recent_actions(self, hours: int = 24) -> List[Dict[str, Any]]:
        """Return history records newer than ``hours`` ago.

        Records without a ``timestamp`` are treated as recent.
        """
        cutoff_time = datetime.utcnow() - timedelta(hours=hours)
        recent = []
        for action in self.action_history:
            stamp = action.get('timestamp')
            if stamp is None or datetime.fromisoformat(stamp) > cutoff_time:
                recent.append(action)
        return recent

    async def _count_constraint_actions(self, constraint_id: str, hours: int) -> int:
        """Count recent recorded actions that fall under a constraint's categories.

        If ``constraint_id`` is not registered, every recent action is counted.
        """
        recent = await self._get_recent_actions(hours=hours)
        constraint = self.constraints.get(constraint_id)
        if constraint is None:
            return len(recent)
        categories = constraint.action_categories
        return sum(
            1 for action in recent
            if self._determine_action_category(action.get('action_type') or '') in categories
        )

    async def _get_daily_action_count(self, constraint_id: str) -> int:
        """Get daily action count for a specific constraint.

        Only actions in the constraint's categories are counted, so one
        category's traffic cannot exhaust another category's limit.
        """
        return await self._count_constraint_actions(constraint_id, hours=24)

    async def _get_hourly_action_count(self, constraint_id: str) -> int:
        """Get hourly action count for a specific constraint (see daily variant)."""
        return await self._count_constraint_actions(constraint_id, hours=1)

    async def _record_validation_history(self, action_data: Dict[str, Any], is_valid: bool, violations: List[str]):
        """Append the validation (and any violations) to the capped histories."""
        validation_record = {
            "timestamp": datetime.utcnow().isoformat(),
            "action_type": action_data.get('action_type', 'unknown'),
            "is_valid": is_valid,
            "violations": violations,
            "action_data": action_data
        }
        self.action_history.append(validation_record)
        # Keep only recent history (last 1000 records)
        if len(self.action_history) > 1000:
            self.action_history = self.action_history[-1000:]

        # Record violations separately
        if violations:
            violation_record = {
                "timestamp": datetime.utcnow().isoformat(),
                "action_type": action_data.get('action_type', 'unknown'),
                "violations": violations,
                "severity": "high" if len(violations) > 2 else "medium"
            }
            self.violation_history.append(violation_record)
            # Keep only recent violations (last 500 records)
            if len(self.violation_history) > 500:
                self.violation_history = self.violation_history[-500:]

    async def evaluate_escalation(self, action_data: Dict[str, Any], validation: SafetyValidation) -> Optional[EscalationDecision]:
        """Evaluate velocity-triggered escalation and produce structured decision payload.

        Returns ``None`` when the escalation policy reports no tier. While the
        lockout flag is set, a "lockout_enforced" decision is returned without
        consulting the policy.
        """
        if self.auto_edit_lockout:
            decision = EscalationDecision(
                tier=EscalationTier.TIER_3.value,
                action="lockout_enforced",
                confidence=1.0,
                risk_class=RiskLevel.CRITICAL.value,
                rationale="Tier 3 lockout already active; autonomous edits blocked until manual reset",
                velocity={},
                lockout_auto_edits=True,
                executor=None
            )
            await self._persist_escalation_decision(decision, action_data, outcome={"status": "blocked_by_lockout"})
            return decision

        tier, signals = self.escalation_policy.determine_tier(self.action_history)
        if not tier:
            return None

        risk_class_map = {EscalationTier.TIER_1: RiskLevel.MEDIUM.value, EscalationTier.TIER_2: RiskLevel.HIGH.value, EscalationTier.TIER_3: RiskLevel.CRITICAL.value}
        # Base 0.55, raised by violation count and by lost validation
        # confidence, clamped to [0.1, 1.0].
        confidence = min(1.0, max(0.1, 0.55 + (len(validation.violations) * 0.05) + ((1 - validation.confidence_score) * 0.4)))
        velocity_signal = signals[tier]
        velocity_payload = {
            "window_minutes": velocity_signal.window_minutes,
            "action_count": velocity_signal.action_count,
            "actions_per_minute": round(velocity_signal.actions_per_minute, 4),
            "threshold_actions_per_minute": self.escalation_policy.tier_thresholds[tier]["actions_per_minute"],
        }
        executor = self.executor_routes.get(tier.value)
        action = "route_to_autonomous_executor" if tier in (EscalationTier.TIER_1, EscalationTier.TIER_2) else "lockout_autonomous_edits"
        rationale = f"{tier.value} triggered by velocity {velocity_payload['actions_per_minute']}/min over {velocity_signal.window_minutes}m window"
        decision = EscalationDecision(
            tier=tier.value,
            action=action,
            confidence=round(confidence, 3),
            risk_class=risk_class_map[tier],
            rationale=rationale,
            velocity=velocity_payload,
            lockout_auto_edits=(tier == EscalationTier.TIER_3),
            executor=executor if tier != EscalationTier.TIER_3 else None
        )
        outcome = await self._apply_escalation_decision(decision, action_data, validation)
        await self._persist_escalation_decision(decision, action_data, outcome=outcome)
        return decision

    async def _apply_escalation_decision(self, decision: EscalationDecision, action_data: Dict[str, Any], validation: SafetyValidation) -> Dict[str, Any]:
        """Act on a decision: report routing for tiers 1-2, otherwise enable the lockout.

        Enabling the lockout also appends a diagnostic brief to
        ``alert_history`` (capped at 500 entries).
        """
        if decision.tier in (EscalationTier.TIER_1.value, EscalationTier.TIER_2.value):
            return {
                "status": "routed",
                "executor": decision.executor,
                "reason": decision.rationale
            }
        self.auto_edit_lockout = True
        brief = {
            "type": "diagnostic_brief",
            "severity": "critical",
            "tier": decision.tier,
            "user_rationale": "Autonomous edits have been paused to protect account safety after sustained high-velocity actions.",
            "validation_violations": validation.violations,
            "action_type": action_data.get("action_type", "unknown"),
            "timestamp": datetime.utcnow().isoformat()
        }
        self.alert_history.append(brief)
        if len(self.alert_history) > 500:
            self.alert_history = self.alert_history[-500:]
        return {"status": "lockout_enabled", "diagnostic_brief": brief}

    async def _persist_escalation_decision(self, decision: EscalationDecision, action_data: Dict[str, Any], outcome: Dict[str, Any]):
        """Append the decision and its outcome to ``escalation_history`` (capped at 2000)."""
        record = {
            "timestamp": datetime.utcnow().isoformat(),
            "decision": asdict(decision),
            "action_data": action_data,
            "outcome": outcome
        }
        self.escalation_history.append(record)
        if len(self.escalation_history) > 2000:
            self.escalation_history = self.escalation_history[-2000:]

    @staticmethod
    def _latest(records: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
        """Return a new list with the last ``limit`` records.

        ``records[-0:]`` is the whole list, so non-positive limits are handled
        explicitly and yield an empty list.
        """
        return records[-limit:] if limit > 0 else []

    def get_escalation_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get the most recent escalation records (oldest first)."""
        return self._latest(self.escalation_history, limit)

    def reset_auto_edit_lockout(self):
        """Clear the auto-edit lockout flag."""
        self.auto_edit_lockout = False

    def add_custom_constraint(self, constraint: SafetyConstraint):
        """Add a custom safety constraint, replacing any with the same id."""
        self.constraints[constraint.constraint_id] = constraint
        logger.info(f"Added custom constraint for user {self.user_id}: {constraint.constraint_id}")

    def remove_constraint(self, constraint_id: str):
        """Remove a safety constraint; unknown ids are ignored."""
        if constraint_id in self.constraints:
            del self.constraints[constraint_id]
            logger.info(f"Removed constraint for user {self.user_id}: {constraint_id}")

    def get_constraints(self) -> Dict[str, SafetyConstraint]:
        """Get all safety constraints as a shallow copy of the registry."""
        return self.constraints.copy()

    def get_validation_history(self, limit: int = 100) -> List[Dict[str, Any]]:
        """Get recent validation history (oldest first)."""
        return self._latest(self.action_history, limit)

    def get_violation_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Get recent violation history (oldest first)."""
        return self._latest(self.violation_history, limit)
class RollbackManager:
    """Manages rollback operations for agent actions.

    Checkpoints and rollback history are kept in memory on the instance. The
    per-type rollback handlers are placeholders: they log and report success
    but do not restore any external state themselves.
    """

    def __init__(self, user_id: str):
        """Create an empty rollback manager for ``user_id``."""
        self.user_id = user_id
        self.checkpoints: List[ActionCheckpoint] = []
        self.rollback_history: List[Dict[str, Any]] = []
        logger.info(f"Initialized RollbackManager for user: {user_id}")

    async def create_checkpoint(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> str:
        """Create a checkpoint before executing an action.

        Returns:
            The new checkpoint id. Ids are built from the user id and a
            second-resolution UTC timestamp; if that id is already held by a
            stored checkpoint a numeric suffix (``_1``, ``_2``, ...) is added
            so a later rollback cannot resolve to the wrong checkpoint.

        Raises:
            Exception: any error while building or storing the checkpoint is
                logged and re-raised.
        """
        try:
            base_id = f"checkpoint_{self.user_id}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}"
            taken = {cp.checkpoint_id for cp in self.checkpoints}
            checkpoint_id = base_id
            suffix = 1
            while checkpoint_id in taken:
                checkpoint_id = f"{base_id}_{suffix}"
                suffix += 1

            checkpoint = ActionCheckpoint(
                checkpoint_id=checkpoint_id,
                action_id=action_data.get('action_id', 'unknown'),
                agent_id=action_data.get('agent_id', 'unknown'),
                user_id=self.user_id,
                action_type=action_data.get('action_type', 'unknown'),
                action_data=action_data,
                system_state=system_state
            )
            self.checkpoints.append(checkpoint)
            # Keep only recent checkpoints (last 100)
            if len(self.checkpoints) > 100:
                self.checkpoints = self.checkpoints[-100:]
            logger.info(f"Created checkpoint for user {self.user_id}: {checkpoint_id}")
            return checkpoint_id
        except Exception as e:
            logger.error(f"Error creating checkpoint for user {self.user_id}: {e}")
            raise  # bare raise keeps the original traceback

    async def rollback_to_checkpoint(self, checkpoint_id: str) -> Dict[str, Any]:
        """Rollback to a specific checkpoint.

        Returns the handler's result dict, or ``{"success": False, "error": ...}``
        when the checkpoint is unknown or an exception occurs. Each attempted
        rollback is appended to ``rollback_history`` (capped at 50).
        """
        try:
            # Find checkpoint
            checkpoint = next((cp for cp in self.checkpoints if cp.checkpoint_id == checkpoint_id), None)
            if not checkpoint:
                return {
                    "success": False,
                    "error": f"Checkpoint not found: {checkpoint_id}"
                }
            logger.info(f"Rolling back to checkpoint for user {self.user_id}: {checkpoint_id}")

            # Execute rollback (implementation depends on action type)
            rollback_result = await self._execute_rollback(checkpoint)

            # Record in history
            self.rollback_history.append({
                "timestamp": datetime.utcnow().isoformat(),
                "checkpoint_id": checkpoint_id,
                "action_type": checkpoint.action_type,
                "success": rollback_result["success"],
                "details": rollback_result
            })
            # Keep only recent rollback history (last 50)
            if len(self.rollback_history) > 50:
                self.rollback_history = self.rollback_history[-50:]
            return rollback_result
        except Exception as e:
            logger.error(f"Error rolling back to checkpoint {checkpoint_id} for user {self.user_id}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    async def _execute_rollback(self, checkpoint: ActionCheckpoint) -> Dict[str, Any]:
        """Dispatch to the rollback handler matching the checkpoint's action type.

        Types without a dedicated handler use the generic rollback.
        """
        # Resolved before the try block so the error log below can always
        # reference it, even if reading the checkpoint is what failed.
        action_type = getattr(checkpoint, "action_type", "unknown")
        try:
            handlers = {
                "content_modification": self._rollback_content_modification,
                "seo_optimization": self._rollback_seo_optimization,
                "competitor_response": self._rollback_competitor_response,
                "social_amplification": self._rollback_social_amplification,
            }
            if isinstance(action_type, str):
                handler = handlers.get(action_type, self._rollback_generic)
            else:
                handler = self._rollback_generic
            return await handler(checkpoint.action_data, checkpoint.system_state)
        except Exception as e:
            logger.error(f"Error executing rollback for action {action_type}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    async def _rollback_content_modification(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
        """Rollback content modification (placeholder: reports, does not restore)."""
        try:
            # Implementation would depend on how content is stored and managed
            original_content = system_state.get('original_content', {})
            logger.info(f"Rolling back content modification: {action_data.get('content_id', 'unknown')}")
            return {
                "success": True,
                "message": "Content modification rolled back successfully",
                "details": {
                    "content_id": action_data.get('content_id'),
                    "rollback_type": "content_modification",
                    "original_state_restored": bool(original_content)
                }
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"Failed to rollback content modification: {str(e)}"
            }

    async def _rollback_seo_optimization(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
        """Rollback SEO optimization (placeholder: reports, does not restore)."""
        try:
            original_seo_state = system_state.get('seo_state', {})
            logger.info(f"Rolling back SEO optimization: {action_data.get('optimization_type', 'unknown')}")
            return {
                "success": True,
                "message": "SEO optimization rolled back successfully",
                "details": {
                    "optimization_type": action_data.get('optimization_type'),
                    "rollback_type": "seo_optimization",
                    "original_state_restored": bool(original_seo_state)
                }
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"Failed to rollback SEO optimization: {str(e)}"
            }

    async def _rollback_competitor_response(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
        """Rollback competitor response (placeholder: reports, does not restore)."""
        try:
            logger.info(f"Rolling back competitor response: {action_data.get('response_type', 'unknown')}")
            return {
                "success": True,
                "message": "Competitor response rolled back successfully",
                "details": {
                    "response_type": action_data.get('response_type'),
                    "rollback_type": "competitor_response",
                    "original_state_restored": True
                }
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"Failed to rollback competitor response: {str(e)}"
            }

    async def _rollback_social_amplification(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
        """Rollback social amplification (placeholder: reports, does not restore)."""
        try:
            logger.info(f"Rolling back social amplification: {action_data.get('platform', 'unknown')}")
            return {
                "success": True,
                "message": "Social amplification rolled back successfully",
                "details": {
                    "platform": action_data.get('platform'),
                    "rollback_type": "social_amplification",
                    "original_state_restored": True
                }
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"Failed to rollback social amplification: {str(e)}"
            }

    async def _rollback_generic(self, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> Dict[str, Any]:
        """Generic rollback for unknown action types (placeholder)."""
        try:
            logger.info(f"Performing generic rollback for action: {action_data.get('action_type', 'unknown')}")
            return {
                "success": True,
                "message": "Generic rollback completed",
                "details": {
                    "action_type": action_data.get('action_type'),
                    "rollback_type": "generic",
                    "system_state_available": bool(system_state)
                }
            }
        except Exception as e:
            return {
                "success": False,
                "error": f"Failed to perform generic rollback: {str(e)}"
            }

    async def rollback_latest_actions(self, count: int = 1) -> List[Dict[str, Any]]:
        """Rollback the latest N actions, newest first.

        A non-positive ``count`` rolls back nothing. (``checkpoints[-0:]`` is
        the whole list, so ``count=0`` used to roll back every checkpoint.)
        """
        if count <= 0:
            return []
        results = []
        for checkpoint in reversed(self.checkpoints[-count:]):
            results.append(await self.rollback_to_checkpoint(checkpoint.checkpoint_id))
        return results

    def get_checkpoints(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Get summaries of the most recent checkpoints (oldest first).

        A non-positive ``limit`` yields an empty list.
        """
        if limit <= 0:
            return []
        return [
            {
                "checkpoint_id": checkpoint.checkpoint_id,
                "action_id": checkpoint.action_id,
                "action_type": checkpoint.action_type,
                "agent_id": checkpoint.agent_id,
                "created_at": checkpoint.created_at,
                "system_state_keys": list(checkpoint.system_state.keys())
            }
            for checkpoint in self.checkpoints[-limit:]
        ]

    def get_rollback_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Get rollback history (oldest first); non-positive ``limit`` yields []."""
        return self.rollback_history[-limit:] if limit > 0 else []
class UserApprovalSystem:
    """Manages user approval for high-risk actions.

    Pending requests and decision history are kept in memory on the instance.
    """

    def __init__(self, user_id: str):
        """Create an approval system with no pending requests for ``user_id``."""
        self.user_id = user_id
        self.pending_approvals: Dict[str, Dict[str, Any]] = {}
        self.approval_history: List[Dict[str, Any]] = []
        logger.info(f"Initialized UserApprovalSystem for user: {user_id}")

    async def request_approval(self, action_data: Dict[str, Any]) -> Dict[str, Any]:
        """Request user approval for an action.

        The request expires 24 hours after creation. Ids are built from the
        user id and a second-resolution UTC timestamp; if that id is already
        in use a numeric suffix (``_1``, ``_2``, ...) is added so a second
        request in the same second cannot overwrite a pending one.

        Returns:
            ``{"success": True, "approval_id", "status": "pending", "message"}``
            or ``{"success": False, "error"}`` on failure.
        """
        try:
            base_id = f"approval_{self.user_id}_{datetime.utcnow().strftime('%Y%m%d%H%M%S')}"
            used_ids = set(self.pending_approvals)
            used_ids.update(record.get("approval_id") for record in self.approval_history)
            approval_id = base_id
            suffix = 1
            while approval_id in used_ids:
                approval_id = f"{base_id}_{suffix}"
                suffix += 1

            self.pending_approvals[approval_id] = {
                "approval_id": approval_id,
                "action_data": action_data,
                "requested_at": datetime.utcnow().isoformat(),
                "status": "pending",
                "expires_at": (datetime.utcnow() + timedelta(hours=24)).isoformat()
            }
            logger.info(f"Created approval request for user {self.user_id}: {approval_id}")
            return {
                "success": True,
                "approval_id": approval_id,
                "status": "pending",
                "message": "Approval request created successfully"
            }
        except Exception as e:
            logger.error(f"Error creating approval request for user {self.user_id}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    async def approve_action(self, approval_id: str, user_decision: str, user_comments: str = "") -> Dict[str, Any]:
        """Process user approval decision.

        ``user_decision`` is stored verbatim as the request status; the
        statistics only count the values "approved" and "rejected".
        Expired requests are dropped from the pending set and reported as an
        error without being added to the history.
        """
        try:
            approval_request = self.pending_approvals.get(approval_id)
            if approval_request is None:
                return {
                    "success": False,
                    "error": "Approval request not found"
                }

            # Check if approval has expired
            expires_at = datetime.fromisoformat(approval_request["expires_at"])
            if datetime.utcnow() > expires_at:
                del self.pending_approvals[approval_id]
                return {
                    "success": False,
                    "error": "Approval request has expired"
                }

            # Process decision
            approval_request["status"] = user_decision
            approval_request["decision_at"] = datetime.utcnow().isoformat()
            approval_request["user_comments"] = user_comments

            # Move the request from pending to history
            self.approval_history.append(approval_request)
            del self.pending_approvals[approval_id]

            # Keep only recent history (last 100)
            if len(self.approval_history) > 100:
                self.approval_history = self.approval_history[-100:]

            logger.info(f"Processed approval decision for user {self.user_id}: {approval_id} - {user_decision}")
            return {
                "success": True,
                "approval_id": approval_id,
                "status": user_decision,
                "message": f"Action {user_decision} successfully"
            }
        except Exception as e:
            logger.error(f"Error processing approval decision for user {self.user_id}: {e}")
            return {
                "success": False,
                "error": str(e)
            }

    def get_pending_approvals(self) -> List[Dict[str, Any]]:
        """Get all pending approval requests (expired ones are not filtered out)."""
        return list(self.pending_approvals.values())

    def get_approval_history(self, limit: int = 50) -> List[Dict[str, Any]]:
        """Get recent approval history (oldest first).

        A non-positive ``limit`` yields an empty list (``history[-0:]`` would
        otherwise return everything).
        """
        return self.approval_history[-limit:] if limit > 0 else []

    def get_approval_statistics(self) -> Dict[str, Any]:
        """Get approval statistics.

        ``approval_rate`` is approved / total decisions, or 0.0 when there are
        no decisions yet.
        """
        total = len(self.approval_history)
        approved = sum(1 for record in self.approval_history if record["status"] == "approved")
        rejected = sum(1 for record in self.approval_history if record["status"] == "rejected")
        return {
            "total_approvals": total,
            "approved_count": approved,
            "rejected_count": rejected,
            "approval_rate": approved / total if total > 0 else 0.0,
            "pending_count": len(self.pending_approvals)
        }
# Process-wide registry of safety framework components, keyed by user id.
safety_framework_instances: Dict[str, Dict[str, Any]] = {}


def get_safety_framework(user_id: str) -> Dict[str, Any]:
    """Return the safety framework components for ``user_id``.

    The first call for a user builds a constraint manager, a rollback manager
    and an approval system and caches them; later calls return the same dict.
    """
    framework = safety_framework_instances.get(user_id)
    if framework is None:
        framework = {
            "constraint_manager": SafetyConstraintManager(user_id),
            "rollback_manager": RollbackManager(user_id),
            "approval_system": UserApprovalSystem(user_id)
        }
        safety_framework_instances[user_id] = framework
    return framework
# Convenience functions
async def validate_agent_action(user_id: str, action_data: Dict[str, Any]) -> SafetyValidation:
    """Run safety validation of ``action_data`` with the user's constraint manager."""
    constraint_manager = get_safety_framework(user_id)["constraint_manager"]
    validation = await constraint_manager.validate_action(action_data)
    return validation
async def create_action_checkpoint(user_id: str, action_data: Dict[str, Any], system_state: Dict[str, Any]) -> str:
    """Create a checkpoint via the user's rollback manager and return its id."""
    rollback_manager = get_safety_framework(user_id)["rollback_manager"]
    checkpoint_id = await rollback_manager.create_checkpoint(action_data, system_state)
    return checkpoint_id
async def rollback_to_checkpoint(user_id: str, checkpoint_id: str) -> Dict[str, Any]:
    """Roll back to ``checkpoint_id`` via the user's rollback manager."""
    rollback_manager = get_safety_framework(user_id)["rollback_manager"]
    outcome = await rollback_manager.rollback_to_checkpoint(checkpoint_id)
    return outcome
async def request_user_approval(user_id: str, action_data: Dict[str, Any]) -> Dict[str, Any]:
    """Open an approval request for ``action_data`` in the user's approval system."""
    approval_system = get_safety_framework(user_id)["approval_system"]
    response = await approval_system.request_approval(action_data)
    return response