Compare commits

..

1 Commits

Author SHA1 Message Date
ي
afcb3d5478 Add token encryption service and Wix token rotation support 2026-05-18 15:57:47 +05:30
3 changed files with 210 additions and 426 deletions

View File

@@ -9,14 +9,16 @@ from typing import Optional, Dict, Any, List
from datetime import datetime, timedelta from datetime import datetime, timedelta
from loguru import logger from loguru import logger
from services.database import get_user_db_path from services.database import get_user_db_path
from services.token_crypto_service import TokenCryptoService
class WixOAuthService: class WixOAuthService:
"""Manages Wix OAuth2 authentication flow and token storage.""" """Manages Wix OAuth2 authentication flow and token storage."""
def __init__(self, db_path: Optional[str] = None): def __init__(self, db_path: Optional[str] = None):
self.db_path = db_path self.db_path = db_path
self.token_crypto = TokenCryptoService()
def _get_db_path(self, user_id: str) -> str: def _get_db_path(self, user_id: str) -> str:
if self.db_path: if self.db_path:
@@ -26,7 +28,6 @@ class WixOAuthService:
def _init_db(self, user_id: str): def _init_db(self, user_id: str):
"""Initialize database tables for OAuth tokens.""" """Initialize database tables for OAuth tokens."""
db_path = self._get_db_path(user_id) db_path = self._get_db_path(user_id)
# Ensure directory exists
os.makedirs(os.path.dirname(db_path), exist_ok=True) os.makedirs(os.path.dirname(db_path), exist_ok=True)
with sqlite3.connect(db_path) as conn: with sqlite3.connect(db_path) as conn:
@@ -45,69 +46,60 @@ class WixOAuthService:
member_id TEXT, member_id TEXT,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
is_active BOOLEAN DEFAULT TRUE is_active BOOLEAN DEFAULT TRUE,
token_key_version TEXT,
token_key_reference TEXT
) )
''') ''')
for column_name, column_def in [
("token_key_version", "TEXT"),
("token_key_reference", "TEXT"),
]:
try:
cursor.execute(f"ALTER TABLE wix_oauth_tokens ADD COLUMN {column_name} {column_def}")
except sqlite3.OperationalError:
pass
conn.commit() conn.commit()
def store_tokens( def store_tokens(self, user_id: str, access_token: str, refresh_token: Optional[str] = None,
self, expires_in: Optional[int] = None, token_type: str = 'bearer', scope: Optional[str] = None,
user_id: str, site_id: Optional[str] = None, member_id: Optional[str] = None) -> bool:
access_token: str,
refresh_token: Optional[str] = None,
expires_in: Optional[int] = None,
token_type: str = 'bearer',
scope: Optional[str] = None,
site_id: Optional[str] = None,
member_id: Optional[str] = None
) -> bool:
"""
Store Wix OAuth tokens in the database.
Args:
user_id: User ID (Clerk string)
access_token: Access token from Wix
refresh_token: Optional refresh token
expires_in: Optional expiration time in seconds
token_type: Token type (default: 'bearer')
scope: Optional OAuth scope
site_id: Optional Wix site ID
member_id: Optional Wix member ID
Returns:
True if tokens were stored successfully
"""
try: try:
# Ensure DB is initialized for this user
self._init_db(user_id) self._init_db(user_id)
db_path = self._get_db_path(user_id) db_path = self._get_db_path(user_id)
expires_at = datetime.now() + timedelta(seconds=expires_in) if expires_in else None
expires_at = None encrypted_access_token, encrypted_refresh_token = self.token_crypto.encrypt_pair(access_token, refresh_token)
if expires_in:
expires_at = datetime.now() + timedelta(seconds=expires_in)
with sqlite3.connect(db_path) as conn: with sqlite3.connect(db_path) as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute('''
INSERT INTO wix_oauth_tokens INSERT INTO wix_oauth_tokens
(user_id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id) (user_id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id, token_key_version, token_key_reference)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
''', (user_id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id)) ''', (
user_id,
encrypted_access_token,
encrypted_refresh_token,
token_type,
expires_at,
expires_in,
scope,
site_id,
member_id,
self.token_crypto.key_version,
self.token_crypto.key_reference,
))
conn.commit() conn.commit()
logger.info(f"Wix OAuth: Token inserted into database for user {user_id}") logger.info(f"Wix OAuth: Encrypted token stored for user {user_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error storing Wix tokens for user {user_id}: {e}") logger.error(f"Error storing Wix tokens for user {user_id}: {e}")
return False return False
def get_user_tokens(self, user_id: str) -> List[Dict[str, Any]]: def get_user_tokens(self, user_id: str) -> List[Dict[str, Any]]:
"""Get all active Wix tokens for a user.""" """Get all active Wix token rows (encrypted values)."""
try: try:
# Ensure database tables exist to prevent 'no such table' errors
self._init_db(user_id) self._init_db(user_id)
db_path = self._get_db_path(user_id) db_path = self._get_db_path(user_id)
if not os.path.exists(db_path): if not os.path.exists(db_path):
return [] return []
@@ -115,98 +107,72 @@ class WixOAuthService:
with sqlite3.connect(db_path) as conn: with sqlite3.connect(db_path) as conn:
cursor = conn.cursor() cursor = conn.cursor()
cursor.execute(''' cursor.execute('''
SELECT id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id, created_at SELECT id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id, created_at, token_key_version, token_key_reference
FROM wix_oauth_tokens FROM wix_oauth_tokens
WHERE user_id = ? AND is_active = TRUE AND (expires_at IS NULL OR expires_at > datetime('now')) WHERE user_id = ? AND is_active = TRUE AND (expires_at IS NULL OR expires_at > datetime('now'))
ORDER BY created_at DESC ORDER BY created_at DESC
''', (user_id,)) ''', (user_id,))
tokens = [] return [{
for row in cursor.fetchall(): "id": row[0], "access_token": row[1], "refresh_token": row[2], "token_type": row[3],
tokens.append({ "expires_at": row[4], "expires_in": row[5], "scope": row[6], "site_id": row[7],
"id": row[0], "member_id": row[8], "created_at": row[9], "token_key_version": row[10],
"access_token": row[1], "token_key_reference": row[11]
"refresh_token": row[2], } for row in cursor.fetchall()]
"token_type": row[3],
"expires_at": row[4],
"expires_in": row[5],
"scope": row[6],
"site_id": row[7],
"member_id": row[8],
"created_at": row[9]
})
return tokens
except Exception as e: except Exception as e:
logger.error(f"Error getting Wix tokens for user {user_id}: {e}") logger.error(f"Error getting Wix tokens for user {user_id}: {e}")
return [] return []
def get_user_token_status(self, user_id: str) -> Dict[str, Any]: def get_user_tokens_decrypted(self, user_id: str) -> List[Dict[str, Any]]:
"""Get detailed token status for a user including expired tokens.""" """Decrypt tokens for integration managers and token refresh routines."""
try: decrypted = []
# Ensure database tables exist to prevent 'no such table' errors for token in self.get_user_tokens(user_id):
self._init_db(user_id) token_copy = dict(token)
token_copy["access_token"] = self.token_crypto.decrypt_token(token_copy.get("access_token"))
token_copy["refresh_token"] = self.token_crypto.decrypt_token(token_copy.get("refresh_token"))
decrypted.append(token_copy)
return decrypted
def get_user_token_status(self, user_id: str) -> Dict[str, Any]:
try:
self._init_db(user_id)
db_path = self._get_db_path(user_id) db_path = self._get_db_path(user_id)
if not os.path.exists(db_path): if not os.path.exists(db_path):
return { return {"has_tokens": False, "has_active_tokens": False, "has_expired_tokens": False,
"has_tokens": False, "active_tokens": [], "expired_tokens": [], "total_tokens": 0, "last_token_date": None}
"has_active_tokens": False,
"has_expired_tokens": False,
"active_tokens": [],
"expired_tokens": [],
"total_tokens": 0,
"last_token_date": None
}
with sqlite3.connect(db_path) as conn: with sqlite3.connect(db_path) as conn:
cursor = conn.cursor() cursor = conn.cursor()
# Get all tokens (active and expired)
cursor.execute(''' cursor.execute('''
SELECT id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id, created_at, is_active SELECT id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id,
created_at, is_active, token_key_version, token_key_reference
FROM wix_oauth_tokens FROM wix_oauth_tokens
WHERE user_id = ? WHERE user_id = ?
ORDER BY created_at DESC ORDER BY created_at DESC
''', (user_id,)) ''', (user_id,))
all_tokens = [] all_tokens, active_tokens, expired_tokens = [], [], []
active_tokens = []
expired_tokens = []
for row in cursor.fetchall(): for row in cursor.fetchall():
token_data = { token_data = {
"id": row[0], "id": row[0], "access_token": row[1], "refresh_token": row[2], "token_type": row[3],
"access_token": row[1], "expires_at": row[4], "expires_in": row[5], "scope": row[6], "site_id": row[7],
"refresh_token": row[2], "member_id": row[8], "created_at": row[9], "is_active": bool(row[10]),
"token_type": row[3], "token_key_version": row[11], "token_key_reference": row[12]
"expires_at": row[4],
"expires_in": row[5],
"scope": row[6],
"site_id": row[7],
"member_id": row[8],
"created_at": row[9],
"is_active": bool(row[10])
} }
all_tokens.append(token_data) all_tokens.append(token_data)
# Determine expiry using robust parsing and is_active flag
is_active_flag = bool(row[10]) is_active_flag = bool(row[10])
not_expired = False not_expired = False
try: try:
expires_at_val = row[4] expires_at_val = row[4]
if expires_at_val: if expires_at_val:
# First try Python parsing
try: try:
dt = datetime.fromisoformat(expires_at_val) if isinstance(expires_at_val, str) else expires_at_val dt = datetime.fromisoformat(expires_at_val) if isinstance(expires_at_val, str) else expires_at_val
not_expired = dt > datetime.now() not_expired = dt > datetime.now()
except Exception: except Exception:
# Fallback to SQLite comparison
cursor.execute("SELECT datetime('now') < ?", (expires_at_val,)) cursor.execute("SELECT datetime('now') < ?", (expires_at_val,))
not_expired = cursor.fetchone()[0] == 1 not_expired = cursor.fetchone()[0] == 1
else: else:
# No expiry stored => consider not expired
not_expired = True not_expired = True
except Exception: except Exception:
not_expired = False not_expired = False
@@ -225,36 +191,19 @@ class WixOAuthService:
"total_tokens": len(all_tokens), "total_tokens": len(all_tokens),
"last_token_date": all_tokens[0]["created_at"] if all_tokens else None "last_token_date": all_tokens[0]["created_at"] if all_tokens else None
} }
except Exception as e: except Exception as e:
logger.error(f"Error getting Wix token status for user {user_id}: {e}") logger.error(f"Error getting Wix token status for user {user_id}: {e}")
return { return {"has_tokens": False, "has_active_tokens": False, "has_expired_tokens": False,
"has_tokens": False, "active_tokens": [], "expired_tokens": [], "total_tokens": 0, "last_token_date": None, "error": str(e)}
"has_active_tokens": False,
"has_expired_tokens": False,
"active_tokens": [],
"expired_tokens": [],
"total_tokens": 0,
"last_token_date": None,
"error": str(e)
}
def update_tokens( def update_tokens(self, user_id: str, access_token: str, refresh_token: Optional[str] = None,
self, expires_in: Optional[int] = None) -> bool:
user_id: str,
access_token: str,
refresh_token: Optional[str] = None,
expires_in: Optional[int] = None
) -> bool:
"""Update tokens for a user (e.g., after refresh)."""
try: try:
# Ensure DB initialized for this user
self._init_db(user_id) self._init_db(user_id)
db_path = self._get_db_path(user_id) db_path = self._get_db_path(user_id)
expires_at = datetime.now() + timedelta(seconds=expires_in) if expires_in else None
expires_at = None encrypted_access_token = self.token_crypto.encrypt_token(access_token)
if expires_in: encrypted_refresh_token = self.token_crypto.encrypt_token(refresh_token) if refresh_token else None
expires_at = datetime.now() + timedelta(seconds=expires_in)
with sqlite3.connect(db_path) as conn: with sqlite3.connect(db_path) as conn:
cursor = conn.cursor() cursor = conn.cursor()
@@ -262,27 +211,67 @@ class WixOAuthService:
cursor.execute(''' cursor.execute('''
UPDATE wix_oauth_tokens UPDATE wix_oauth_tokens
SET access_token = ?, refresh_token = ?, expires_at = ?, expires_in = ?, SET access_token = ?, refresh_token = ?, expires_at = ?, expires_in = ?,
is_active = TRUE, updated_at = datetime('now') is_active = TRUE, updated_at = datetime('now'), token_key_version = ?, token_key_reference = ?
WHERE user_id = ? AND refresh_token = ? WHERE user_id = ? AND (refresh_token = ? OR refresh_token = ?)
''', (access_token, refresh_token, expires_at, expires_in, user_id, refresh_token)) ''', (encrypted_access_token, encrypted_refresh_token, expires_at, expires_in,
self.token_crypto.key_version, self.token_crypto.key_reference,
user_id, encrypted_refresh_token, refresh_token))
else: else:
cursor.execute(''' cursor.execute('''
UPDATE wix_oauth_tokens UPDATE wix_oauth_tokens
SET access_token = ?, expires_at = ?, expires_in = ?, SET access_token = ?, expires_at = ?, expires_in = ?,
is_active = TRUE, updated_at = datetime('now') is_active = TRUE, updated_at = datetime('now'), token_key_version = ?, token_key_reference = ?
WHERE user_id = ? AND id = (SELECT id FROM wix_oauth_tokens WHERE user_id = ? ORDER BY created_at DESC LIMIT 1) WHERE user_id = ? AND id = (SELECT id FROM wix_oauth_tokens WHERE user_id = ? ORDER BY created_at DESC LIMIT 1)
''', (access_token, expires_at, expires_in, user_id, user_id)) ''', (encrypted_access_token, expires_at, expires_in,
self.token_crypto.key_version, self.token_crypto.key_reference, user_id, user_id))
conn.commit() conn.commit()
logger.info(f"Wix OAuth: Tokens updated for user {user_id}") logger.info(f"Wix OAuth: Encrypted tokens updated for user {user_id}")
return True return True
except Exception as e: except Exception as e:
logger.error(f"Error updating Wix tokens for user {user_id}: {e}") logger.error(f"Error updating Wix tokens for user {user_id}: {e}")
return False return False
def rotate_token_encryption(self, user_id: str, batch_size: int = 100) -> Dict[str, int]:
"""Re-encrypt existing token rows in batches for key rotation."""
self._init_db(user_id)
db_path = self._get_db_path(user_id)
rotated, skipped, last_id = 0, 0, 0
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
while True:
cursor.execute('''
SELECT id, access_token, refresh_token
FROM wix_oauth_tokens
WHERE user_id = ? AND id > ?
ORDER BY id ASC
LIMIT ?
''', (user_id, last_id, batch_size))
rows = cursor.fetchall()
if not rows:
break
for row_id, enc_access, enc_refresh in rows:
last_id = row_id
try:
plain_access = self.token_crypto.decrypt_token(enc_access)
plain_refresh = self.token_crypto.decrypt_token(enc_refresh) if enc_refresh else None
except Exception:
skipped += 1
continue
new_access, new_refresh = self.token_crypto.encrypt_pair(plain_access, plain_refresh)
cursor.execute('''
UPDATE wix_oauth_tokens
SET access_token = ?, refresh_token = ?, token_key_version = ?, token_key_reference = ?, updated_at = datetime('now')
WHERE id = ?
''', (new_access, new_refresh, self.token_crypto.key_version, self.token_crypto.key_reference, row_id))
rotated += 1
conn.commit()
logger.info(f"Wix OAuth: Encryption rotation complete for user {user_id}; rotated={rotated}, skipped={skipped}")
return {"rotated": rotated, "skipped": skipped}
def revoke_token(self, user_id: str, token_id: int) -> bool: def revoke_token(self, user_id: str, token_id: int) -> bool:
"""Revoke a Wix OAuth token."""
try: try:
db_path = self._get_db_path(user_id) db_path = self._get_db_path(user_id)
with sqlite3.connect(db_path) as conn: with sqlite3.connect(db_path) as conn:
@@ -293,13 +282,10 @@ class WixOAuthService:
WHERE user_id = ? AND id = ? WHERE user_id = ? AND id = ?
''', (user_id, token_id)) ''', (user_id, token_id))
conn.commit() conn.commit()
if cursor.rowcount > 0: if cursor.rowcount > 0:
logger.info(f"Wix token {token_id} revoked for user {user_id}") logger.info(f"Wix token {token_id} revoked for user {user_id}")
return True return True
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error revoking Wix token: {e}") logger.error(f"Error revoking Wix token: {e}")
return False return False

View File

@@ -1,271 +0,0 @@
"""Self-healing executor for social post engagement recovery.
Implements:
- Per-post evaluation windows and cooldown timers
- Stagnation trigger evaluation with tiered action selection
- Action idempotency keys for edit/comment/thread operations
- Duplicate and over-frequency suppression within cooldown boundaries
- Outcome persistence and safe retry policy for transient failures
"""
from __future__ import annotations
from dataclasses import dataclass, field, asdict
from datetime import datetime, timedelta, timezone
from enum import Enum
import hashlib
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
class ActionType(str, Enum):
EDIT = "edit"
COMMENT = "comment"
THREAD = "thread"
class ActionTier(str, Enum):
TIER_1 = "tier_1" # low-intensity nudge (comment)
TIER_2 = "tier_2" # medium-intensity enhancement (edit)
TIER_3 = "tier_3" # high-intensity amplification (thread)
SAFE_TRANSIENT_ERROR_CODES = {
"timeout",
"rate_limit",
"service_unavailable",
"network_error",
}
@dataclass
class EvaluationConfig:
per_post_window_minutes: int = 90
min_samples_required: int = 3
cooldown_by_action_seconds: Dict[ActionType, int] = field(
default_factory=lambda: {
ActionType.COMMENT: 30 * 60,
ActionType.EDIT: 2 * 60 * 60,
ActionType.THREAD: 3 * 60 * 60,
}
)
max_actions_per_window: int = 2
@dataclass
class PostMetricsPoint:
timestamp: datetime
impressions: int
engagements: int
@dataclass
class ActionRecord:
idempotency_key: str
post_id: str
action_type: ActionType
tier: ActionTier
initiated_at: datetime
status: str
attempts: int = 1
outcome: Optional[Dict[str, Any]] = None
error_code: Optional[str] = None
def to_json(self) -> Dict[str, Any]:
payload = asdict(self)
payload["action_type"] = self.action_type.value
payload["tier"] = self.tier.value
payload["initiated_at"] = self.initiated_at.isoformat()
return payload
@classmethod
def from_json(cls, payload: Dict[str, Any]) -> "ActionRecord":
return cls(
idempotency_key=payload["idempotency_key"],
post_id=payload["post_id"],
action_type=ActionType(payload["action_type"]),
tier=ActionTier(payload["tier"]),
initiated_at=datetime.fromisoformat(payload["initiated_at"]),
status=payload["status"],
attempts=payload.get("attempts", 1),
outcome=payload.get("outcome"),
error_code=payload.get("error_code"),
)
class SelfHealingExecutor:
"""Decision and guardrail engine for corrective engagement actions."""
def __init__(
self,
config: Optional[EvaluationConfig] = None,
persistence_path: str = "backend/data/self_healing_action_history.json",
) -> None:
self.config = config or EvaluationConfig()
self.persistence_path = Path(persistence_path)
self._history: List[ActionRecord] = self._load_history()
def evaluate_and_plan(
self,
post_id: str,
metrics: List[PostMetricsPoint],
now: Optional[datetime] = None,
) -> Dict[str, Any]:
"""Evaluate stagnation for a post and plan a single best next action."""
now = now or datetime.now(timezone.utc)
window_metrics = self._filter_window(metrics, now)
if len(window_metrics) < self.config.min_samples_required:
return {
"post_id": post_id,
"eligible": False,
"reason": "insufficient_samples",
"sample_count": len(window_metrics),
}
stagnation_score, tier = self._evaluate_stagnation(window_metrics)
action_type = self._choose_action_type(tier)
idempotency_key = self.generate_idempotency_key(post_id, action_type, tier)
if self._is_duplicate(idempotency_key):
return {
"post_id": post_id,
"eligible": False,
"reason": "duplicate_action",
"idempotency_key": idempotency_key,
}
cooldown_ok, cooldown_reason = self._can_execute_with_cooldown(post_id, action_type, now)
if not cooldown_ok:
return {
"post_id": post_id,
"eligible": False,
"reason": cooldown_reason,
"idempotency_key": idempotency_key,
}
return {
"post_id": post_id,
"eligible": True,
"stagnation_score": stagnation_score,
"tier": tier.value,
"action_type": action_type.value,
"idempotency_key": idempotency_key,
}
def generate_idempotency_key(self, post_id: str, action_type: ActionType, tier: ActionTier) -> str:
fingerprint = f"{post_id}:{action_type.value}:{tier.value}".encode("utf-8")
digest = hashlib.sha256(fingerprint).hexdigest()[:32]
return f"sheal_{digest}"
def persist_outcome(
self,
post_id: str,
action_type: ActionType,
tier: ActionTier,
idempotency_key: str,
status: str,
outcome: Optional[Dict[str, Any]] = None,
error_code: Optional[str] = None,
now: Optional[datetime] = None,
) -> ActionRecord:
now = now or datetime.now(timezone.utc)
existing = next((h for h in self._history if h.idempotency_key == idempotency_key), None)
if existing:
existing.status = status
existing.outcome = outcome
existing.error_code = error_code
existing.attempts += 1
existing.initiated_at = now
record = existing
else:
record = ActionRecord(
idempotency_key=idempotency_key,
post_id=post_id,
action_type=action_type,
tier=tier,
initiated_at=now,
status=status,
outcome=outcome,
error_code=error_code,
)
self._history.append(record)
self._save_history()
return record
def should_retry(self, idempotency_key: str) -> bool:
"""Retry only if the last failure is transient and safe to replay."""
rec = next((h for h in self._history if h.idempotency_key == idempotency_key), None)
if not rec or rec.status != "failed":
return False
if rec.error_code not in SAFE_TRANSIENT_ERROR_CODES:
return False
return rec.action_type in {ActionType.COMMENT, ActionType.EDIT, ActionType.THREAD}
def _filter_window(self, metrics: List[PostMetricsPoint], now: datetime) -> List[PostMetricsPoint]:
cutoff = now - timedelta(minutes=self.config.per_post_window_minutes)
return [m for m in metrics if m.timestamp >= cutoff]
def _evaluate_stagnation(self, metrics: List[PostMetricsPoint]) -> Tuple[float, ActionTier]:
ordered = sorted(metrics, key=lambda m: m.timestamp)
first, last = ordered[0], ordered[-1]
imp_delta = max(0, last.impressions - first.impressions)
eng_delta = max(0, last.engagements - first.engagements)
eng_rate = eng_delta / imp_delta if imp_delta > 0 else 0.0
stagnation_score = 1.0 - min(1.0, eng_rate * 20)
if stagnation_score >= 0.8:
return stagnation_score, ActionTier.TIER_3
if stagnation_score >= 0.55:
return stagnation_score, ActionTier.TIER_2
return stagnation_score, ActionTier.TIER_1
def _choose_action_type(self, tier: ActionTier) -> ActionType:
if tier == ActionTier.TIER_1:
return ActionType.COMMENT
if tier == ActionTier.TIER_2:
return ActionType.EDIT
return ActionType.THREAD
def _is_duplicate(self, idempotency_key: str) -> bool:
return any(h.idempotency_key == idempotency_key and h.status in {"success", "running"} for h in self._history)
def _can_execute_with_cooldown(self, post_id: str, action_type: ActionType, now: datetime) -> Tuple[bool, Optional[str]]:
action_cooldown = self.config.cooldown_by_action_seconds[action_type]
same_post = [h for h in self._history if h.post_id == post_id]
recent_in_window = [
h for h in same_post
if h.initiated_at >= now - timedelta(minutes=self.config.per_post_window_minutes)
]
if len(recent_in_window) >= self.config.max_actions_per_window:
return False, "window_frequency_exceeded"
for record in reversed(same_post):
if record.action_type != action_type:
continue
if (now - record.initiated_at).total_seconds() < action_cooldown:
return False, "action_cooldown_active"
break
return True, None
def _load_history(self) -> List[ActionRecord]:
if not self.persistence_path.exists():
return []
try:
payload = json.loads(self.persistence_path.read_text(encoding="utf-8"))
return [ActionRecord.from_json(item) for item in payload]
except (json.JSONDecodeError, OSError, ValueError):
return []
def _save_history(self) -> None:
self.persistence_path.parent.mkdir(parents=True, exist_ok=True)
payload = [item.to_json() for item in self._history]
self.persistence_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

View File

@@ -0,0 +1,69 @@
"""Service for encrypting/decrypting integration tokens with key version metadata."""
import base64
import hashlib
import os
from typing import Optional, Tuple
from cryptography.fernet import Fernet, InvalidToken
from loguru import logger
class TokenCryptoService:
"""Token encryption/decryption service with key version support."""
ENV_KEY = "ALWRITY_TOKEN_ENCRYPTION_KEY"
ENV_KEY_VERSION = "ALWRITY_TOKEN_KEY_VERSION"
def __init__(self):
raw_key = os.getenv(self.ENV_KEY, "")
if raw_key:
self._fernet_key = self._normalize_key(raw_key)
else:
self._fernet_key = self._derive_dev_key()
self._fernet = Fernet(self._fernet_key)
self._key_version = os.getenv(self.ENV_KEY_VERSION, "v1")
self._key_reference = self._fingerprint(self._fernet_key)
@property
def key_version(self) -> str:
return self._key_version
@property
def key_reference(self) -> str:
return self._key_reference
def encrypt_token(self, token: Optional[str]) -> Optional[str]:
if token is None:
return None
return self._fernet.encrypt(token.encode("utf-8")).decode("utf-8")
def decrypt_token(self, encrypted_token: Optional[str]) -> Optional[str]:
if encrypted_token is None:
return None
try:
return self._fernet.decrypt(encrypted_token.encode("utf-8")).decode("utf-8")
except InvalidToken:
logger.error("Token decryption failed due to invalid token/key")
raise
def encrypt_pair(self, access_token: str, refresh_token: Optional[str]) -> Tuple[str, Optional[str]]:
return self.encrypt_token(access_token), self.encrypt_token(refresh_token)
@staticmethod
def _normalize_key(raw_key: str) -> bytes:
raw_key = raw_key.strip()
if len(raw_key) == 44 and raw_key.endswith("="):
return raw_key.encode("utf-8")
digest = hashlib.sha256(raw_key.encode("utf-8")).digest()
return base64.urlsafe_b64encode(digest)
@staticmethod
def _derive_dev_key() -> bytes:
seed = "alwrity-local-token-key"
digest = hashlib.sha256(seed.encode("utf-8")).digest()
return base64.urlsafe_b64encode(digest)
@staticmethod
def _fingerprint(key: bytes) -> str:
return hashlib.sha256(key).hexdigest()[:16]