""" Wix OAuth2 Service Handles Wix OAuth2 authentication flow and token storage. """ import os import sqlite3 from typing import Optional, Dict, Any, List from datetime import datetime, timedelta from loguru import logger from cryptography.fernet import Fernet, InvalidToken from services.database import get_user_db_path class WixOAuthService: """Manages Wix OAuth2 authentication flow and token storage.""" def __init__(self, db_path: Optional[str] = None): self.db_path = db_path self.token_encryption_key = ( os.getenv("WIX_TOKEN_ENCRYPTION_KEY") or os.getenv("OAUTH_TOKEN_ENCRYPTION_KEY") ) self._fernet = self._initialize_fernet() self._migration_done: set = set() def _initialize_fernet(self) -> Optional[Fernet]: if not self.token_encryption_key: logger.error("Wix token encryption key is not configured.") return None try: return Fernet(self.token_encryption_key.encode("utf-8")) except Exception: logger.error("Wix token encryption key is invalid.") return None def _encrypt_token(self, token: Optional[str]) -> Optional[str]: if not token: return None if not self._fernet: raise ValueError("Token encryption is unavailable: missing/invalid managed key") return self._fernet.encrypt(token.encode("utf-8")).decode("utf-8") def _decrypt_token(self, token_blob: Optional[str]) -> Optional[str]: if not token_blob: return None if not self._fernet: raise ValueError("Token decryption is unavailable: missing/invalid managed key") return self._fernet.decrypt(token_blob.encode("utf-8")).decode("utf-8") def _is_likely_encrypted_blob(self, value: Optional[str]) -> bool: return bool(value and value.startswith("gAAAAA")) def _migrate_plaintext_tokens_if_needed(self, conn: sqlite3.Connection, user_id: str) -> None: if not self._fernet or user_id in self._migration_done: return cursor = conn.cursor() cursor.execute( "SELECT id, access_token, refresh_token FROM wix_oauth_tokens WHERE user_id = ?", (user_id,), ) rows = cursor.fetchall() migrated = 0 for token_id, access_token, refresh_token in rows: needs_access = access_token and not self._is_likely_encrypted_blob(access_token) needs_refresh = refresh_token and not self._is_likely_encrypted_blob(refresh_token) if not (needs_access or needs_refresh): continue enc_access = self._encrypt_token(access_token) if needs_access else access_token enc_refresh = self._encrypt_token(refresh_token) if needs_refresh else refresh_token cursor.execute( "UPDATE wix_oauth_tokens SET access_token = ?, refresh_token = ?, updated_at = datetime('now') WHERE id = ? AND user_id = ?", (enc_access, enc_refresh, token_id, user_id), ) migrated += 1 if migrated: conn.commit() logger.info(f"Wix OAuth token migration completed for user {user_id}; rows migrated={migrated}") self._migration_done.add(user_id) def _get_db_path(self, user_id: str) -> str: if self.db_path: return self.db_path return get_user_db_path(user_id) def _init_db(self, user_id: str): """Initialize database tables for OAuth tokens.""" db_path = self._get_db_path(user_id) # Ensure directory exists os.makedirs(os.path.dirname(db_path), exist_ok=True) with sqlite3.connect(db_path) as conn: cursor = conn.cursor() cursor.execute(''' CREATE TABLE IF NOT EXISTS wix_oauth_tokens ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id TEXT NOT NULL, access_token TEXT NOT NULL, refresh_token TEXT, token_type TEXT DEFAULT 'bearer', expires_at TIMESTAMP, expires_in INTEGER, scope TEXT, site_id TEXT, member_id TEXT, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, is_active BOOLEAN DEFAULT TRUE ) ''') cursor.execute(''' CREATE TABLE IF NOT EXISTS wix_oauth_pkce_states ( id INTEGER PRIMARY KEY AUTOINCREMENT, user_id TEXT NOT NULL, state TEXT NOT NULL UNIQUE, code_verifier TEXT NOT NULL, expires_at TIMESTAMP NOT NULL, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, used_at TIMESTAMP ) ''') cursor.execute(''' CREATE INDEX IF NOT EXISTS idx_wix_oauth_pkce_user_state ON wix_oauth_pkce_states (user_id, state) ''') conn.commit() def cleanup_expired_pkce_states(self, user_id: str) -> int: """Delete expired or already-used PKCE state records.""" try: self._init_db(user_id) db_path = self._get_db_path(user_id) with sqlite3.connect(db_path) as conn: cursor = conn.cursor() cursor.execute( ''' DELETE FROM wix_oauth_pkce_states WHERE used_at IS NOT NULL OR expires_at <= datetime('now') ''' ) deleted = cursor.rowcount conn.commit() return deleted if deleted is not None else 0 except Exception as e: logger.warning(f"Failed to cleanup expired Wix PKCE states for user {user_id}: {e}") return 0 def store_pkce_verifier(self, user_id: str, state: str, code_verifier: str, ttl_seconds: int = 600) -> bool: """Store PKCE code verifier by OAuth state with short TTL.""" try: self._init_db(user_id) self.cleanup_expired_pkce_states(user_id) db_path = self._get_db_path(user_id) expires_at = datetime.now() + timedelta(seconds=ttl_seconds) with sqlite3.connect(db_path) as conn: cursor = conn.cursor() cursor.execute( ''' INSERT OR REPLACE INTO wix_oauth_pkce_states (user_id, state, code_verifier, expires_at, created_at, used_at) VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, NULL) ''', (user_id, state, code_verifier, expires_at) ) conn.commit() return True except Exception as e: logger.error(f"Failed storing Wix PKCE verifier for user {user_id}, state {state}: {e}") return False def consume_pkce_verifier(self, user_id: str, state: str) -> Optional[str]: """Get and invalidate one-time PKCE verifier for a state if valid and unexpired.""" try: self._init_db(user_id) self.cleanup_expired_pkce_states(user_id) db_path = self._get_db_path(user_id) with sqlite3.connect(db_path) as conn: cursor = conn.cursor() cursor.execute( ''' SELECT id, code_verifier FROM wix_oauth_pkce_states WHERE user_id = ? AND state = ? AND used_at IS NULL AND expires_at > datetime('now') LIMIT 1 ''', (user_id, state) ) row = cursor.fetchone() if not row: return None cursor.execute( "UPDATE wix_oauth_pkce_states SET used_at = CURRENT_TIMESTAMP WHERE id = ?", (row[0],) ) conn.commit() return row[1] except Exception as e: logger.error(f"Failed consuming Wix PKCE verifier for user {user_id}, state {state}: {e}") return None def store_tokens( self, user_id: str, access_token: str, refresh_token: Optional[str] = None, expires_in: Optional[int] = None, token_type: str = 'bearer', scope: Optional[str] = None, site_id: Optional[str] = None, member_id: Optional[str] = None ) -> bool: """ Store Wix OAuth tokens in the database. Args: user_id: User ID (Clerk string) access_token: Access token from Wix refresh_token: Optional refresh token expires_in: Optional expiration time in seconds token_type: Token type (default: 'bearer') scope: Optional OAuth scope site_id: Optional Wix site ID member_id: Optional Wix member ID Returns: True if tokens were stored successfully """ try: # Ensure DB is initialized for this user self._init_db(user_id) db_path = self._get_db_path(user_id) expires_at = None if expires_in: expires_at = datetime.now() + timedelta(seconds=expires_in) encrypted_access = self._encrypt_token(access_token) encrypted_refresh = self._encrypt_token(refresh_token) if refresh_token else None with sqlite3.connect(db_path) as conn: cursor = conn.cursor() cursor.execute(''' INSERT INTO wix_oauth_tokens (user_id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ''', (user_id, encrypted_access, encrypted_refresh, token_type, expires_at, expires_in, scope, site_id, member_id)) conn.commit() logger.info(f"Wix OAuth: Token inserted into database for user {user_id}") return True except Exception as e: logger.error(f"Error storing Wix tokens for user {user_id}: {e}") return False def get_user_tokens(self, user_id: str) -> List[Dict[str, Any]]: """Get all active Wix tokens for a user.""" try: # Ensure database tables exist to prevent 'no such table' errors self._init_db(user_id) db_path = self._get_db_path(user_id) if not os.path.exists(db_path): return [] with sqlite3.connect(db_path) as conn: self._migrate_plaintext_tokens_if_needed(conn, user_id) cursor = conn.cursor() cursor.execute(''' SELECT id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id, created_at FROM wix_oauth_tokens WHERE user_id = ? AND is_active = TRUE AND (expires_at IS NULL OR expires_at > datetime('now')) ORDER BY created_at DESC ''', (user_id,)) tokens = [] for row in cursor.fetchall(): access_token_val = row[1] refresh_token_val = row[2] try: decrypted_access = ( self._decrypt_token(access_token_val) if self._is_likely_encrypted_blob(access_token_val) else access_token_val ) except InvalidToken: logger.error(f"Failed to decrypt Wix access token for user {user_id}, token_id={row[0]}") continue try: decrypted_refresh = ( self._decrypt_token(refresh_token_val) if self._is_likely_encrypted_blob(refresh_token_val) else refresh_token_val ) except InvalidToken: decrypted_refresh = None tokens.append({ "id": row[0], "access_token": decrypted_access, "refresh_token": decrypted_refresh, "token_type": row[3], "expires_at": row[4], "expires_in": row[5], "scope": row[6], "site_id": row[7], "member_id": row[8], "created_at": row[9] }) return tokens except Exception as e: logger.error(f"Error getting Wix tokens for user {user_id}: {e}") return [] def get_user_token_status(self, user_id: str) -> Dict[str, Any]: """Get detailed token status for a user including expired tokens.""" try: # Ensure database tables exist to prevent 'no such table' errors self._init_db(user_id) db_path = self._get_db_path(user_id) if not os.path.exists(db_path): return { "has_tokens": False, "has_active_tokens": False, "has_expired_tokens": False, "active_tokens": [], "expired_tokens": [], "total_tokens": 0, "last_token_date": None } with sqlite3.connect(db_path) as conn: self._migrate_plaintext_tokens_if_needed(conn, user_id) cursor = conn.cursor() cursor.execute(''' SELECT id, access_token, refresh_token, token_type, expires_at, expires_in, scope, site_id, member_id, created_at, is_active FROM wix_oauth_tokens WHERE user_id = ? ORDER BY created_at DESC ''', (user_id,)) all_tokens = [] active_tokens = [] expired_tokens = [] for row in cursor.fetchall(): access_token_val = row[1] refresh_token_val = row[2] try: decrypted_access = ( self._decrypt_token(access_token_val) if self._is_likely_encrypted_blob(access_token_val) else access_token_val ) except InvalidToken: decrypted_access = None try: decrypted_refresh = ( self._decrypt_token(refresh_token_val) if self._is_likely_encrypted_blob(refresh_token_val) else refresh_token_val ) except InvalidToken: decrypted_refresh = None token_data = { "id": row[0], "access_token": decrypted_access, "refresh_token": decrypted_refresh, "token_type": row[3], "expires_at": row[4], "expires_in": row[5], "scope": row[6], "site_id": row[7], "member_id": row[8], "created_at": row[9], "is_active": bool(row[10]) } all_tokens.append(token_data) # Determine expiry using robust parsing and is_active flag is_active_flag = bool(row[10]) not_expired = False try: expires_at_val = row[4] if expires_at_val: # First try Python parsing try: dt = datetime.fromisoformat(expires_at_val) if isinstance(expires_at_val, str) else expires_at_val not_expired = dt > datetime.now() except Exception: # Fallback to SQLite comparison cursor.execute("SELECT datetime('now') < ?", (expires_at_val,)) not_expired = cursor.fetchone()[0] == 1 else: # No expiry stored => consider not expired not_expired = True except Exception: not_expired = False if is_active_flag and not_expired: active_tokens.append(token_data) else: expired_tokens.append(token_data) return { "has_tokens": len(all_tokens) > 0, "has_active_tokens": len(active_tokens) > 0, "has_expired_tokens": len(expired_tokens) > 0, "active_tokens": active_tokens, "expired_tokens": expired_tokens, "total_tokens": len(all_tokens), "last_token_date": all_tokens[0]["created_at"] if all_tokens else None } except Exception as e: logger.error(f"Error getting Wix token status for user {user_id}: {e}") return { "has_tokens": False, "has_active_tokens": False, "has_expired_tokens": False, "active_tokens": [], "expired_tokens": [], "total_tokens": 0, "last_token_date": None, "error": str(e) } def update_tokens( self, user_id: str, access_token: str, refresh_token: Optional[str] = None, expires_in: Optional[int] = None, token_id: Optional[int] = None ) -> bool: """Update tokens for a user (e.g., after refresh).""" try: self._init_db(user_id) db_path = self._get_db_path(user_id) expires_at = None if expires_in: expires_at = datetime.now() + timedelta(seconds=expires_in) encrypted_access = self._encrypt_token(access_token) encrypted_refresh = self._encrypt_token(refresh_token) if refresh_token else None with sqlite3.connect(db_path) as conn: self._migrate_plaintext_tokens_if_needed(conn, user_id) cursor = conn.cursor() if token_id: if encrypted_refresh: cursor.execute(''' UPDATE wix_oauth_tokens SET access_token = ?, refresh_token = ?, expires_at = ?, expires_in = ?, is_active = TRUE, updated_at = datetime('now') WHERE user_id = ? AND id = ? ''', (encrypted_access, encrypted_refresh, expires_at, expires_in, user_id, token_id)) else: cursor.execute(''' UPDATE wix_oauth_tokens SET access_token = ?, expires_at = ?, expires_in = ?, is_active = TRUE, updated_at = datetime('now') WHERE user_id = ? AND id = ? ''', (encrypted_access, expires_at, expires_in, user_id, token_id)) else: cursor.execute(''' UPDATE wix_oauth_tokens SET access_token = ?, expires_at = ?, expires_in = ?, is_active = TRUE, updated_at = datetime('now') WHERE user_id = ? AND id = (SELECT id FROM wix_oauth_tokens WHERE user_id = ? ORDER BY created_at DESC LIMIT 1) ''', (encrypted_access, expires_at, expires_in, user_id, user_id)) conn.commit() logger.info(f"Wix OAuth: Tokens updated for user {user_id}") return True except Exception as e: logger.error(f"Error updating Wix tokens for user {user_id}: {e}") return False def revoke_token(self, user_id: str, token_id: int) -> bool: """Revoke a Wix OAuth token.""" try: db_path = self._get_db_path(user_id) with sqlite3.connect(db_path) as conn: cursor = conn.cursor() cursor.execute(''' UPDATE wix_oauth_tokens SET is_active = FALSE, updated_at = datetime('now') WHERE user_id = ? AND id = ? ''', (user_id, token_id)) conn.commit() if cursor.rowcount > 0: logger.info(f"Wix token {token_id} revoked for user {user_id}") return True return False except Exception as e: logger.error(f"Error revoking Wix token: {e}") return False