"""Tests for scan scheduling, diff engine, and scan endpoints — CMP-24.

Covers:
- Scanner schemas (new additions)
- Scan service (job lifecycle, diff engine, cookie sync)
- Scanner router (trigger, list, detail, diff endpoints)
- Integration tests against live database
"""

import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from httpx import ASGITransport, AsyncClient

from src.schemas.scanner import (
    CookieDiffItem,
    DiffStatus,
    ScanDiffResponse,
    ScanJobDetailResponse,
    ScanResultResponse,
    TriggerScanRequest,
)

# ── Schema tests ─────────────────────────────────────────────────────


class TestSchemas:
    """Validate scanner schema additions."""

    def test_scan_result_response(self):
        r = ScanResultResponse(
            id=uuid.uuid4(),
            scan_job_id=uuid.uuid4(),
            page_url="https://example.com",
            cookie_name="_ga",
            cookie_domain=".example.com",
            storage_type="cookie",
            found_at=datetime.now(UTC),
            created_at=datetime.now(UTC),
        )
        assert r.cookie_name == "_ga"

    def test_scan_job_detail_response(self):
        r = ScanJobDetailResponse(
            id=uuid.uuid4(),
            site_id=uuid.uuid4(),
            status="completed",
            trigger="manual",
            pages_scanned=5,
            pages_total=10,
            cookies_found=3,
            error_message=None,
            started_at=datetime.now(UTC),
            completed_at=datetime.now(UTC),
            created_at=datetime.now(UTC),
            updated_at=datetime.now(UTC),
            results=[],
        )
        assert r.status == "completed"
        assert r.results == []

    def test_trigger_scan_request(self):
        req = TriggerScanRequest(site_id=uuid.uuid4(), max_pages=100)
        assert req.max_pages == 100

    def test_trigger_scan_request_defaults(self):
        req = TriggerScanRequest(site_id=uuid.uuid4())
        assert req.max_pages == 50

    def test_trigger_scan_max_pages_validation(self):
        # Each out-of-range value gets its own ``raises`` block so both the
        # lower (0) and upper (501) bound are checked independently.
        with pytest.raises(ValueError):
            TriggerScanRequest(site_id=uuid.uuid4(), max_pages=0)
        with pytest.raises(ValueError):
            TriggerScanRequest(site_id=uuid.uuid4(), max_pages=501)

    def test_diff_status_values(self):
        assert DiffStatus.NEW == "new"
        assert DiffStatus.REMOVED == "removed"
        assert DiffStatus.CHANGED == "changed"

    def test_cookie_diff_item(self):
        item = CookieDiffItem(
            name="_ga",
            domain=".example.com",
            storage_type="cookie",
            diff_status=DiffStatus.NEW,
            details="First scan",
        )
        assert item.diff_status == "new"

    def test_scan_diff_response(self):
        resp = ScanDiffResponse(
            current_scan_id=uuid.uuid4(),
            previous_scan_id=uuid.uuid4(),
            new_cookies=[
                CookieDiffItem(
                    name="_ga",
                    domain=".example.com",
                    storage_type="cookie",
                    diff_status=DiffStatus.NEW,
                ),
            ],
            total_new=1,
        )
        assert resp.total_new == 1
        assert len(resp.new_cookies) == 1

    def test_scan_diff_response_no_previous(self):
        resp = ScanDiffResponse(
            current_scan_id=uuid.uuid4(),
            previous_scan_id=None,
        )
        assert resp.previous_scan_id is None
        assert resp.total_new == 0


# ── Diff engine unit tests ───────────────────────────────────────────


class TestDiffEngine:
    """Test the diff engine's result-keying helper against mocked scan results."""

    def _make_scan_result(
        self,
        name: str = "_ga",
        domain: str = ".example.com",
        storage_type: str = "cookie",
        script_source: str | None = None,
        auto_category: str | None = None,
        attributes: dict | None = None,
    ) -> MagicMock:
        """Create a mock ScanResult carrying only the fields set below."""
        mock = MagicMock()
        mock.cookie_name = name
        mock.cookie_domain = domain
        mock.storage_type = storage_type
        mock.script_source = script_source
        mock.auto_category = auto_category
        mock.attributes = attributes
        return mock

    def test_result_key(self):
        from src.services.scanner import _result_key

        mock = self._make_scan_result("_ga", ".example.com", "cookie")
        assert _result_key(mock) == ("_ga", ".example.com", "cookie")

    def test_result_key_different_storage(self):
        from src.services.scanner import _result_key

        mock = self._make_scan_result("key", "example.com", "local_storage")
        assert _result_key(mock) == ("key", "example.com", "local_storage")


# ── Scan service unit tests ──────────────────────────────────────────


class TestScanService:
    """Test scan service functions with mocked DB."""

    @pytest.mark.asyncio
    async def test_create_scan_job(self):
        from src.services.scanner import create_scan_job

        db = AsyncMock()
        # Children of an AsyncMock are themselves async; ``add`` is called
        # synchronously by the service, so give it a plain MagicMock.
        db.add = MagicMock()
        db.flush = AsyncMock()
        site_id = uuid.uuid4()
        job = await create_scan_job(db, site_id=site_id, trigger="manual", max_pages=10)
        assert job.site_id == site_id
        assert job.status == "pending"
        assert job.trigger == "manual"
        assert job.pages_total == 10
        db.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_start_scan_job(self):
        from src.services.scanner import start_scan_job

        db = AsyncMock()
        db.flush = AsyncMock()
        job = MagicMock()
        job.status = "pending"
        job.started_at = None
        result = await start_scan_job(db, job)
        assert result.status == "running"
        assert result.started_at is not None

    @pytest.mark.asyncio
    async def test_complete_scan_job_success(self):
        from src.services.scanner import complete_scan_job

        db = AsyncMock()
        db.flush = AsyncMock()
        job = MagicMock()
        # MagicMock auto-creates a (non-None) child for any attribute read
        # before it is set, so seed the pre-completion state explicitly;
        # otherwise ``completed_at is not None`` would pass even if the
        # service never set it.
        job.status = "running"
        job.completed_at = None
        job.error_message = None
        result = await complete_scan_job(db, job, pages_scanned=5, cookies_found=10)
        assert result.status == "completed"
        assert result.pages_scanned == 5
        assert result.cookies_found == 10
        assert result.completed_at is not None

    @pytest.mark.asyncio
    async def test_complete_scan_job_failure(self):
        from src.services.scanner import complete_scan_job

        db = AsyncMock()
        db.flush = AsyncMock()
        job = MagicMock()
        # Seed a realistic running job so the assertions observe real writes
        # rather than auto-created mock attributes.
        job.status = "running"
        job.completed_at = None
        job.error_message = None
        result = await complete_scan_job(db, job, error_message="Connection failed")
        assert result.status == "failed"
        assert result.error_message == "Connection failed"

    @pytest.mark.asyncio
    async def test_add_scan_result(self):
        from src.services.scanner import add_scan_result

        db = AsyncMock()
        db.add = MagicMock()
        db.flush = AsyncMock()
        scan_job_id = uuid.uuid4()
        result = await add_scan_result(
            db,
            scan_job_id=scan_job_id,
            page_url="https://example.com",
            cookie_name="_ga",
            cookie_domain=".example.com",
            storage_type="cookie",
            auto_category="analytics",
        )
        assert result.scan_job_id == scan_job_id
        assert result.cookie_name == "_ga"
        assert result.auto_category == "analytics"
        db.add.assert_called_once()


# ── Router unit tests (mocked DB) ───────────────────────────────────
def _mock_auth_user():
    """Create a mock authenticated user (role ``owner``, fresh org id)."""
    from src.schemas.auth import CurrentUser

    return CurrentUser(
        id=uuid.uuid4(),
        organisation_id=uuid.uuid4(),
        email="test@example.com",
        role="owner",
    )


async def _authed_client(app, db, user=None):
    """Create an authenticated test client with mocked DB.

    Overrides ``get_db`` to yield *db* and ``get_current_user`` to return
    *user* (a fresh mock user when omitted), then returns an httpx client
    bound to *app* in-process.
    """
    from src.db import get_db
    from src.services.dependencies import get_current_user

    if user is None:
        user = _mock_auth_user()

    async def _override_get_db():
        yield db

    # NOTE(review): these overrides are never removed here; presumably the
    # ``app`` fixture provides a fresh app or clears ``dependency_overrides``
    # between tests — confirm, otherwise they leak into later tests.
    app.dependency_overrides[get_db] = _override_get_db
    app.dependency_overrides[get_current_user] = lambda: user
    transport = ASGITransport(app=app)
    return AsyncClient(transport=transport, base_url="http://test")


class TestTriggerScan:
    """Test POST /scanner/scans."""

    @pytest.mark.asyncio
    async def test_trigger_scan_success(self, app):
        user = _mock_auth_user()
        db = AsyncMock()

        # Site exists and belongs to user's org
        site_mock = MagicMock()
        site_mock.organisation_id = user.organisation_id

        site_id = uuid.uuid4()
        job_id = uuid.uuid4()
        now = datetime.now(UTC)

        # Mock scan job returned by create_scan_job
        mock_job = MagicMock()
        mock_job.id = job_id
        mock_job.site_id = site_id
        mock_job.status = "pending"
        mock_job.trigger = "manual"
        mock_job.pages_scanned = 0
        mock_job.pages_total = 25
        mock_job.cookies_found = 0
        mock_job.error_message = None
        mock_job.started_at = None
        mock_job.completed_at = None
        mock_job.created_at = now
        mock_job.updated_at = now

        # First call: site lookup. Second call: running scan count.
        call_count = 0

        async def mock_execute(stmt):
            nonlocal call_count
            call_count += 1
            result = MagicMock()
            if call_count == 1:
                # Site lookup
                result.scalar_one_or_none.return_value = site_mock
            elif call_count == 2:
                # Active scan jobs query — none running
                result.scalars.return_value.all.return_value = []
            return result

        db.execute = mock_execute
        db.add = MagicMock()
        db.flush = AsyncMock()

        with (
            patch(
                "src.routers.scanner.create_scan_job",
                new=AsyncMock(return_value=mock_job),
            ),
            patch("src.tasks.scanner.run_scan", create=True),
        ):
            async with await _authed_client(app, db, user) as client:
                resp = await client.post(
                    "/api/v1/scanner/scans",
                    json={
                        "site_id": str(site_id),
                        "max_pages": 25,
                    },
                )

        assert resp.status_code == 201
        # Pin the body too, so the test proves the created job is returned.
        data = resp.json()
        assert data["id"] == str(job_id)
        assert data["status"] == "pending"

    @pytest.mark.asyncio
    async def test_trigger_scan_site_not_found(self, app):
        db = AsyncMock()
        result = MagicMock()
        result.scalar_one_or_none.return_value = None
        db.execute = AsyncMock(return_value=result)

        async with await _authed_client(app, db) as client:
            resp = await client.post(
                "/api/v1/scanner/scans",
                json={
                    "site_id": str(uuid.uuid4()),
                    "max_pages": 50,
                },
            )

        assert resp.status_code == 404

    @pytest.mark.asyncio
    async def test_trigger_scan_conflict(self, app):
        user = _mock_auth_user()
        db = AsyncMock()

        # Build a non-stale active job so the router raises 409
        active_job = MagicMock()
        active_job.status = "running"
        active_job.created_at = datetime.now(UTC)
        active_job.started_at = datetime.now(UTC)

        call_count = 0

        async def mock_execute(stmt):
            nonlocal call_count
            call_count += 1
            result = MagicMock()
            if call_count == 1:
                # Site lookup
                site_mock = MagicMock()
                site_mock.organisation_id = user.organisation_id
                result.scalar_one_or_none.return_value = site_mock
            elif call_count == 2:
                # Active scan jobs query — return a non-stale job
                result.scalars.return_value.all.return_value = [active_job]
            return result

        db.execute = mock_execute

        async with await _authed_client(app, db, user) as client:
            resp = await client.post(
                "/api/v1/scanner/scans",
                json={"site_id": str(uuid.uuid4())},
            )

        assert resp.status_code == 409


class TestListScans:
    """Test GET /scanner/scans/site/{site_id}."""

    @pytest.mark.asyncio
    async def test_list_scans_success(self, app):
        user = _mock_auth_user()
        db = AsyncMock()

        call_count = 0

        async def mock_execute(stmt):
            nonlocal call_count
            call_count += 1
            result = MagicMock()
            if call_count == 1:
                # Site access check
                site_mock = MagicMock()
                site_mock.organisation_id = user.organisation_id
                result.scalar_one_or_none.return_value = site_mock
            else:
                # Scan list
                result.scalars.return_value.all.return_value = []
            return result

        db.execute = mock_execute

        async with await _authed_client(app, db, user) as client:
            resp = await client.get(f"/api/v1/scanner/scans/site/{uuid.uuid4()}")

        assert resp.status_code == 200
        assert resp.json() == []


class TestGetScan:
    """Test GET /scanner/scans/{scan_id}."""

    @pytest.mark.asyncio
    async def test_get_scan_not_found(self, app):
        db = AsyncMock()
        result = MagicMock()
        result.scalar_one_or_none.return_value = None
        db.execute = AsyncMock(return_value=result)

        async with await _authed_client(app, db) as client:
            resp = await client.get(f"/api/v1/scanner/scans/{uuid.uuid4()}")

        assert resp.status_code == 404


class TestGetScanDiff:
    """Test GET /scanner/scans/{scan_id}/diff."""

    @pytest.mark.asyncio
    async def test_diff_scan_not_found(self, app):
        db = AsyncMock()
        result = MagicMock()
        result.scalar_one_or_none.return_value = None
        db.execute = AsyncMock(return_value=result)

        async with await _authed_client(app, db) as client:
            resp = await client.get(f"/api/v1/scanner/scans/{uuid.uuid4()}/diff")

        assert resp.status_code == 404


# ── Integration tests ────────────────────────────────────────────────

try:
    from tests.conftest import create_test_site, requires_db
except ImportError:
    from conftest import create_test_site, requires_db


@requires_db
class TestScanIntegration:
    """Integration tests against a live database."""

    # NOTE(review): unlike the unit tests above, these coroutines carry no
    # explicit ``@pytest.mark.asyncio``; presumably asyncio_mode=auto (or the
    # ``requires_db`` decorator) covers them — confirm.

    async def test_trigger_scan(self, db_client, auth_headers):
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-trigger")
        resp = await db_client.post(
            "/api/v1/scanner/scans",
            json={"site_id": site_id, "max_pages": 10},
            headers=auth_headers,
        )
        assert resp.status_code == 201
        data = resp.json()
        assert data["status"] == "pending"
        assert data["trigger"] == "manual"
        assert data["pages_total"] == 10

    async def test_trigger_scan_conflict(self, db_client, auth_headers):
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-conflict")
        # First scan
        resp1 = await db_client.post(
            "/api/v1/scanner/scans",
            json={"site_id": site_id},
            headers=auth_headers,
        )
        assert resp1.status_code == 201

        # Second scan — should conflict
        resp2 = await db_client.post(
            "/api/v1/scanner/scans",
            json={"site_id": site_id},
            headers=auth_headers,
        )
        assert resp2.status_code == 409

    async def test_list_scans(self, db_client, auth_headers):
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-list")
        # Trigger a scan; fail fast here if it was not created, rather than
        # later with a misleading empty-list assertion.
        trigger_resp = await db_client.post(
            "/api/v1/scanner/scans",
            json={"site_id": site_id},
            headers=auth_headers,
        )
        assert trigger_resp.status_code == 201

        resp = await db_client.get(
            f"/api/v1/scanner/scans/site/{site_id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        scans = resp.json()
        assert len(scans) >= 1
        assert scans[0]["site_id"] == site_id

    async def test_get_scan_detail(self, db_client, auth_headers):
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-detail")
        create_resp = await db_client.post(
            "/api/v1/scanner/scans",
            json={"site_id": site_id},
            headers=auth_headers,
        )
        # Guard before indexing "id" so a failed trigger reports its status
        # code instead of a KeyError.
        assert create_resp.status_code == 201
        scan_id = create_resp.json()["id"]

        resp = await db_client.get(
            f"/api/v1/scanner/scans/{scan_id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["id"] == scan_id
        assert "results" in data

    async def test_get_scan_diff(self, db_client, auth_headers):
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-diff")
        create_resp = await db_client.post(
            "/api/v1/scanner/scans",
            json={"site_id": site_id},
            headers=auth_headers,
        )
        # Guard before indexing "id" so a failed trigger reports its status
        # code instead of a KeyError.
        assert create_resp.status_code == 201
        scan_id = create_resp.json()["id"]

        resp = await db_client.get(
            f"/api/v1/scanner/scans/{scan_id}/diff",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["current_scan_id"] == scan_id
        # No previous scan, so previous_scan_id should be null
        assert data["previous_scan_id"] is None

    async def test_scan_not_found(self, db_client, auth_headers):
        resp = await db_client.get(
            f"/api/v1/scanner/scans/{uuid.uuid4()}",
            headers=auth_headers,
        )
        assert resp.status_code == 404

    async def test_list_scans_pagination(self, db_client, auth_headers):
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-page")
        resp = await db_client.get(
            f"/api/v1/scanner/scans/site/{site_id}?limit=5&offset=0",
            headers=auth_headers,
        )
        assert resp.status_code == 200

    async def test_trigger_scan_requires_auth(self, db_client):
        resp = await db_client.post(
            "/api/v1/scanner/scans",
            json={"site_id": str(uuid.uuid4())},
        )
        assert resp.status_code in (401, 403)

    async def test_list_scans_requires_auth(self, db_client):
        resp = await db_client.get(f"/api/v1/scanner/scans/site/{uuid.uuid4()}")
        assert resp.status_code in (401, 403)