feat: initial public release
ConsentOS — a privacy-first cookie consent management platform. Self-hosted, source-available alternative to OneTrust, Cookiebot, and CookieYes. Full standards coverage (IAB TCF v2.2, GPP v1, Google Consent Mode v2, GPC, Shopify Customer Privacy API), multi-tenant architecture with role-based access, configuration cascade (system → org → group → site → region), dark-pattern detection in the scanner, and a tamper-evident consent record audit trail. This is the initial public release. Prior development history is retained internally. See README.md for the feature list, architecture overview, and quick-start instructions. Licensed under the Elastic Licence 2.0 — self-host freely; do not resell as a managed service.
This commit is contained in:
0
apps/api/tests/__init__.py
Normal file
0
apps/api/tests/__init__.py
Normal file
241
apps/api/tests/conftest.py
Normal file
241
apps/api/tests/conftest.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Shared test fixtures for the CMP API test suite.
|
||||
|
||||
Provides two modes:
|
||||
- Unit tests: use `app` and `client` fixtures (no database required)
|
||||
- Integration tests: use `db_client` fixture (requires PostgreSQL)
|
||||
|
||||
Integration tests are automatically skipped when no database is available.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Disable rate limiting for the test suite. Many tests make dozens of
|
||||
# requests from the same loopback address in rapid succession and the
|
||||
# middleware would legitimately reject them as a DoS; the middleware
|
||||
# has its own dedicated test module.
|
||||
os.environ.setdefault("RATE_LIMIT_ENABLED", "false")
|
||||
os.environ.setdefault("ENVIRONMENT", "test")
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from src.main import create_app
|
||||
from src.models.base import Base
|
||||
|
||||
# ── Detect whether a test database is available ──────────────────────
|
||||
|
||||
_TEST_DB_URL = os.environ.get(
|
||||
"TEST_DATABASE_URL",
|
||||
os.environ.get("DATABASE_URL", ""),
|
||||
)
|
||||
|
||||
_HAS_DB = bool(_TEST_DB_URL) and "localhost" in _TEST_DB_URL
|
||||
|
||||
|
||||
def _requires_db(fn):
|
||||
"""Mark a test as requiring a live database.
|
||||
|
||||
Also pins the event loop to session scope so that fixtures sharing the
|
||||
session-scoped engine don't get 'Future attached to a different loop'.
|
||||
"""
|
||||
fn = pytest.mark.asyncio(loop_scope="session")(fn)
|
||||
fn = pytest.mark.skipif(not _HAS_DB, reason="No test database available")(fn)
|
||||
return fn
|
||||
|
||||
|
||||
requires_db = _requires_db
|
||||
|
||||
|
||||
# ── Unit test fixtures (no database) ─────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
"""Create a fresh FastAPI application instance."""
|
||||
return create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(app):
|
||||
"""Async HTTP client for unit tests (no database)."""
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
# ── Integration test fixtures (with database) ────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def _test_engine():
|
||||
"""Create a test database engine (session-scoped)."""
|
||||
if not _HAS_DB:
|
||||
pytest.skip("No test database available")
|
||||
return create_async_engine(_TEST_DB_URL, echo=False)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
|
||||
async def _setup_db(_test_engine):
|
||||
"""Create all tables once per test session, then seed fixture data.
|
||||
|
||||
Tests that depend on the cookie-category seed (normally applied by
|
||||
the ``0001_initial_schema`` alembic migration) get the same rows
|
||||
here so they can run without invoking alembic.
|
||||
"""
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
await _seed_cookie_categories(conn)
|
||||
yield
|
||||
async with _test_engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.drop_all)
|
||||
|
||||
|
||||
async def _seed_cookie_categories(conn) -> None:
|
||||
"""Insert the default cookie categories. Mirrors migration 0001."""
|
||||
import uuid as _uuid
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
rows = [
|
||||
("10000000-0000-0000-0000-000000000001", "Necessary", "necessary", True, 0),
|
||||
("10000000-0000-0000-0000-000000000002", "Functional", "functional", False, 1),
|
||||
("10000000-0000-0000-0000-000000000003", "Analytics", "analytics", False, 2),
|
||||
("10000000-0000-0000-0000-000000000004", "Marketing", "marketing", False, 3),
|
||||
("10000000-0000-0000-0000-000000000005", "Personalisation", "personalisation", False, 4),
|
||||
]
|
||||
stmt = text(
|
||||
"""
|
||||
INSERT INTO cookie_categories
|
||||
(id, name, slug, description, is_essential, display_order)
|
||||
VALUES (:id, :name, :slug, :description, :is_essential, :display_order)
|
||||
ON CONFLICT (slug) DO NOTHING
|
||||
""",
|
||||
)
|
||||
for row_id, name, slug, is_essential, order in rows:
|
||||
await conn.execute(
|
||||
stmt,
|
||||
{
|
||||
"id": _uuid.UUID(row_id),
|
||||
"name": name,
|
||||
"slug": slug,
|
||||
"description": f"{name} cookies",
|
||||
"is_essential": is_essential,
|
||||
"display_order": order,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def db_client(_test_engine, _setup_db):
|
||||
"""Async HTTP client where each route handler gets its own DB session.
|
||||
|
||||
Each request gets an independent session/connection so there are no
|
||||
'another operation is in progress' errors from asyncpg.
|
||||
"""
|
||||
from src.db import get_db
|
||||
|
||||
app = create_app()
|
||||
|
||||
async def _override_get_db():
|
||||
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
|
||||
try:
|
||||
yield session
|
||||
await session.commit()
|
||||
except Exception:
|
||||
await session.rollback()
|
||||
raise
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
# ── Auth helper fixtures ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def test_org(_test_engine, _setup_db):
|
||||
"""Create a test organisation in the database."""
|
||||
from src.models.organisation import Organisation
|
||||
|
||||
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
|
||||
org = Organisation(
|
||||
id=uuid.uuid4(),
|
||||
name="Test Organisation",
|
||||
slug=f"test-org-{uuid.uuid4().hex[:8]}",
|
||||
)
|
||||
session.add(org)
|
||||
await session.commit()
|
||||
return org
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def test_user(_test_engine, _setup_db, test_org):
|
||||
"""Create a test user (owner role) with a known password."""
|
||||
from src.models.user import User
|
||||
from src.services.auth import hash_password
|
||||
|
||||
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
|
||||
user = User(
|
||||
id=uuid.uuid4(),
|
||||
email=f"admin-{uuid.uuid4().hex[:8]}@test.com",
|
||||
password_hash=hash_password("TestPassword123"),
|
||||
full_name="Test Admin",
|
||||
role="owner",
|
||||
organisation_id=test_org.id,
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def auth_token(test_user):
|
||||
"""Generate a valid JWT token for the test user."""
|
||||
from src.services.auth import create_access_token
|
||||
|
||||
return create_access_token(
|
||||
user_id=str(test_user.id),
|
||||
organisation_id=str(test_user.organisation_id),
|
||||
role=test_user.role,
|
||||
email=test_user.email,
|
||||
)
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def auth_headers(auth_token):
|
||||
"""HTTP headers with a valid Bearer token."""
|
||||
return {"Authorization": f"Bearer {auth_token}"}
|
||||
|
||||
|
||||
# ── Shared helper for creating sites in integration tests ────────────
|
||||
|
||||
|
||||
async def create_test_site(
|
||||
client: AsyncClient,
|
||||
headers: dict,
|
||||
*,
|
||||
domain_prefix: str = "test",
|
||||
display_name: str = "Test Site",
|
||||
) -> str:
|
||||
"""Create a site via the API and return its ID.
|
||||
|
||||
This is a helper function (not a fixture) so it can be called
|
||||
inline within each test, avoiding async fixture event-loop issues.
|
||||
"""
|
||||
resp = await client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": f"{domain_prefix}-{uuid.uuid4().hex[:8]}.com",
|
||||
"display_name": display_name,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert resp.status_code == 201, f"Failed to create test site: {resp.text}"
|
||||
return resp.json()["id"]
|
||||
179
apps/api/tests/test_auth.py
Normal file
179
apps/api/tests/test_auth.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Tests for JWT authentication service and dependencies."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.services.auth import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
class TestPasswordHashing:
|
||||
def test_hash_and_verify(self):
|
||||
password = "s3cureP@ss!"
|
||||
hashed = hash_password(password)
|
||||
assert hashed != password
|
||||
assert verify_password(password, hashed)
|
||||
|
||||
def test_wrong_password_fails(self):
|
||||
hashed = hash_password("correct")
|
||||
assert not verify_password("wrong", hashed)
|
||||
|
||||
def test_different_hashes_for_same_password(self):
|
||||
h1 = hash_password("same")
|
||||
h2 = hash_password("same")
|
||||
assert h1 != h2 # bcrypt salts differ
|
||||
|
||||
|
||||
class TestJWTTokens:
|
||||
@pytest.fixture
|
||||
def user_data(self):
|
||||
return {
|
||||
"user_id": uuid.uuid4(),
|
||||
"organisation_id": uuid.uuid4(),
|
||||
"role": "admin",
|
||||
"email": "test@example.com",
|
||||
}
|
||||
|
||||
def test_create_access_token_decodable(self, user_data):
|
||||
token = create_access_token(**user_data)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == str(user_data["user_id"])
|
||||
assert payload["org_id"] == str(user_data["organisation_id"])
|
||||
assert payload["role"] == "admin"
|
||||
assert payload["email"] == "test@example.com"
|
||||
assert payload["type"] == "access"
|
||||
|
||||
def test_create_refresh_token_decodable(self, user_data):
|
||||
token = create_refresh_token(
|
||||
user_id=user_data["user_id"],
|
||||
organisation_id=user_data["organisation_id"],
|
||||
)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == str(user_data["user_id"])
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_access_token_expiry(self, user_data):
|
||||
token = create_access_token(**user_data)
|
||||
payload = decode_token(token)
|
||||
settings = get_settings()
|
||||
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||
iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
|
||||
delta = exp - iat
|
||||
assert abs(delta.total_seconds() - settings.jwt_access_token_expire_minutes * 60) < 5
|
||||
|
||||
def test_refresh_token_expiry(self, user_data):
|
||||
token = create_refresh_token(
|
||||
user_id=user_data["user_id"],
|
||||
organisation_id=user_data["organisation_id"],
|
||||
)
|
||||
payload = decode_token(token)
|
||||
settings = get_settings()
|
||||
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
|
||||
iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
|
||||
delta = exp - iat
|
||||
expected = settings.jwt_refresh_token_expire_days * 86400
|
||||
assert abs(delta.total_seconds() - expected) < 5
|
||||
|
||||
def test_expired_token_raises(self):
|
||||
settings = get_settings()
|
||||
payload = {
|
||||
"sub": str(uuid.uuid4()),
|
||||
"org_id": str(uuid.uuid4()),
|
||||
"role": "viewer",
|
||||
"exp": datetime.now(UTC) - timedelta(hours=1),
|
||||
"iat": datetime.now(UTC) - timedelta(hours=2),
|
||||
"type": "access",
|
||||
}
|
||||
token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
with pytest.raises(JWTError):
|
||||
decode_token(token)
|
||||
|
||||
def test_tampered_token_raises(self, user_data):
|
||||
token = create_access_token(**user_data)
|
||||
# Tamper with the token
|
||||
tampered = token[:-5] + "XXXXX"
|
||||
with pytest.raises(JWTError):
|
||||
decode_token(tampered)
|
||||
|
||||
|
||||
class TestCurrentUser:
|
||||
def test_has_role(self):
|
||||
user = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="admin@example.com",
|
||||
role="admin",
|
||||
)
|
||||
assert user.has_role("admin", "owner")
|
||||
assert not user.has_role("editor", "viewer")
|
||||
|
||||
def test_is_admin(self):
|
||||
admin = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="a@b.com",
|
||||
role="admin",
|
||||
)
|
||||
viewer = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="v@b.com",
|
||||
role="viewer",
|
||||
)
|
||||
assert admin.is_admin
|
||||
assert not viewer.is_admin
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestAuthEndpoints:
|
||||
async def test_me_without_token_returns_401(self, client):
|
||||
response = await client.get("/api/v1/auth/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_me_with_valid_token(self, client):
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
user_id=user_id,
|
||||
organisation_id=org_id,
|
||||
role="editor",
|
||||
email="user@example.com",
|
||||
)
|
||||
response = await client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == str(user_id)
|
||||
assert data["organisation_id"] == str(org_id)
|
||||
assert data["role"] == "editor"
|
||||
assert data["email"] == "user@example.com"
|
||||
|
||||
async def test_me_with_refresh_token_rejected(self, client):
|
||||
token = create_refresh_token(
|
||||
user_id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
)
|
||||
response = await client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_me_with_invalid_token(self, client):
|
||||
response = await client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": "Bearer invalid.token.here"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
124
apps/api/tests/test_bootstrap.py
Normal file
124
apps/api/tests/test_bootstrap.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Tests for the initial admin bootstrap service."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.models.organisation import Organisation
|
||||
from src.models.user import User
|
||||
from src.services.auth import verify_password
|
||||
from src.services.bootstrap import bootstrap_initial_admin
|
||||
from tests.conftest import requires_db
|
||||
|
||||
|
||||
def _settings(**overrides) -> Settings:
|
||||
base: dict = dict(
|
||||
environment="test",
|
||||
initial_admin_email=None,
|
||||
initial_admin_password=None,
|
||||
initial_admin_full_name="Administrator",
|
||||
initial_org_name="Default Organisation",
|
||||
initial_org_slug="default",
|
||||
)
|
||||
base.update(overrides)
|
||||
return Settings(**base)
|
||||
|
||||
|
||||
class TestBootstrapNoOp:
|
||||
"""Pure unit tests — bootstrap must short-circuit before touching the DB."""
|
||||
|
||||
async def test_noop_when_email_unset(self):
|
||||
settings = _settings(initial_admin_password="pw")
|
||||
with patch("src.services.bootstrap.async_session_factory") as factory:
|
||||
await bootstrap_initial_admin(settings)
|
||||
factory.assert_not_called()
|
||||
|
||||
async def test_noop_when_password_unset(self):
|
||||
settings = _settings(initial_admin_email="admin@example.com")
|
||||
with patch("src.services.bootstrap.async_session_factory") as factory:
|
||||
await bootstrap_initial_admin(settings)
|
||||
factory.assert_not_called()
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestBootstrapWithDatabase:
|
||||
"""Integration tests — exercise the real SQL path."""
|
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
|
||||
async def clean_db(self, _test_engine, _setup_db):
|
||||
"""Strip users and orgs so bootstrap sees an empty table."""
|
||||
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
|
||||
await session.execute(User.__table__.delete())
|
||||
await session.execute(Organisation.__table__.delete())
|
||||
await session.commit()
|
||||
yield
|
||||
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
|
||||
await session.execute(User.__table__.delete())
|
||||
await session.execute(Organisation.__table__.delete())
|
||||
await session.commit()
|
||||
|
||||
async def test_creates_org_and_owner_when_empty(self, _test_engine, clean_db):
|
||||
email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
|
||||
slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
|
||||
settings = _settings(
|
||||
initial_admin_email=email,
|
||||
initial_admin_password="SuperSecret123",
|
||||
initial_org_slug=slug,
|
||||
initial_org_name="Bootstrapped Org",
|
||||
)
|
||||
|
||||
def _factory():
|
||||
return AsyncSession(_test_engine, expire_on_commit=False)
|
||||
|
||||
with patch("src.services.bootstrap.async_session_factory", _factory):
|
||||
await bootstrap_initial_admin(settings)
|
||||
|
||||
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
|
||||
user = (await session.execute(select(User).where(User.email == email))).scalar_one()
|
||||
org = (
|
||||
await session.execute(select(Organisation).where(Organisation.slug == slug))
|
||||
).scalar_one()
|
||||
|
||||
assert user.role == "owner"
|
||||
assert user.organisation_id == org.id
|
||||
assert user.full_name == "Administrator"
|
||||
assert verify_password("SuperSecret123", user.password_hash)
|
||||
assert org.name == "Bootstrapped Org"
|
||||
assert org.contact_email == email
|
||||
|
||||
async def test_idempotent_when_user_exists(self, _test_engine, clean_db):
|
||||
"""A second invocation must not create a second user."""
|
||||
email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
|
||||
slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
|
||||
settings = _settings(
|
||||
initial_admin_email=email,
|
||||
initial_admin_password="SuperSecret123",
|
||||
initial_org_slug=slug,
|
||||
)
|
||||
|
||||
def _factory():
|
||||
return AsyncSession(_test_engine, expire_on_commit=False)
|
||||
|
||||
with patch("src.services.bootstrap.async_session_factory", _factory):
|
||||
await bootstrap_initial_admin(settings)
|
||||
await bootstrap_initial_admin(
|
||||
_settings(
|
||||
initial_admin_email="someone-else@example.com",
|
||||
initial_admin_password="Different123",
|
||||
initial_org_slug=slug,
|
||||
)
|
||||
)
|
||||
|
||||
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
|
||||
users = (await session.execute(select(User))).scalars().all()
|
||||
|
||||
assert len(users) == 1
|
||||
assert users[0].email == email
|
||||
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="session")
|
||||
869
apps/api/tests/test_classification.py
Normal file
869
apps/api/tests/test_classification.py
Normal file
@@ -0,0 +1,869 @@
|
||||
"""Tests for known cookies database and auto-categorisation engine — CMP-22.
|
||||
|
||||
Covers:
|
||||
- Classification service logic (unit tests — pure functions)
|
||||
- Pattern matching (exact, wildcard, regex)
|
||||
- Priority ordering (allow-list → exact → regex → unmatched)
|
||||
- Known cookie CRUD endpoints (unit tests with mocked DB)
|
||||
- Classification endpoints (unit tests with mocked DB)
|
||||
- Schema validation
|
||||
- Integration tests against live database
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.schemas.cookie import (
|
||||
ClassificationResultResponse,
|
||||
ClassifySingleRequest,
|
||||
ClassifySiteResponse,
|
||||
KnownCookieCreate,
|
||||
KnownCookieResponse,
|
||||
KnownCookieUpdate,
|
||||
)
|
||||
from src.services.classification import (
|
||||
ClassificationResult,
|
||||
MatchSource,
|
||||
_match_pattern,
|
||||
_match_regex,
|
||||
classify_cookie,
|
||||
)
|
||||
|
||||
# ── Schema tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSchemas:
|
||||
"""Validate known cookie and classification schemas."""
|
||||
|
||||
def test_known_cookie_create(self):
|
||||
kc = KnownCookieCreate(
|
||||
name_pattern="_ga",
|
||||
domain_pattern="*",
|
||||
category_id=uuid.uuid4(),
|
||||
vendor="Google",
|
||||
description="GA cookie",
|
||||
)
|
||||
assert kc.is_regex is False
|
||||
|
||||
def test_known_cookie_create_regex(self):
|
||||
kc = KnownCookieCreate(
|
||||
name_pattern="_hj.*",
|
||||
domain_pattern=".*",
|
||||
category_id=uuid.uuid4(),
|
||||
is_regex=True,
|
||||
)
|
||||
assert kc.is_regex is True
|
||||
|
||||
def test_known_cookie_update_partial(self):
|
||||
ku = KnownCookieUpdate(vendor="Updated Vendor")
|
||||
dumped = ku.model_dump(exclude_unset=True)
|
||||
assert "vendor" in dumped
|
||||
assert "category_id" not in dumped
|
||||
|
||||
def test_known_cookie_response(self):
|
||||
resp = KnownCookieResponse(
|
||||
id=uuid.uuid4(),
|
||||
name_pattern="_ga",
|
||||
domain_pattern="*",
|
||||
category_id=uuid.uuid4(),
|
||||
is_regex=False,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
assert resp.vendor is None
|
||||
|
||||
def test_classification_result_response(self):
|
||||
crr = ClassificationResultResponse(
|
||||
cookie_name="_ga",
|
||||
cookie_domain=".example.com",
|
||||
match_source="known_exact",
|
||||
matched=True,
|
||||
)
|
||||
assert crr.matched is True
|
||||
|
||||
def test_classify_single_request(self):
|
||||
req = ClassifySingleRequest(cookie_name="_ga", cookie_domain=".example.com")
|
||||
assert req.cookie_name == "_ga"
|
||||
|
||||
def test_classify_single_request_validation(self):
|
||||
with pytest.raises(ValueError):
|
||||
ClassifySingleRequest(cookie_name="", cookie_domain=".example.com")
|
||||
|
||||
def test_classify_site_response(self):
|
||||
resp = ClassifySiteResponse(
|
||||
site_id="abc",
|
||||
total=10,
|
||||
matched=7,
|
||||
unmatched=3,
|
||||
results=[],
|
||||
)
|
||||
assert resp.matched == 7
|
||||
|
||||
def test_match_source_enum(self):
|
||||
assert MatchSource.ALLOW_LIST == "allow_list"
|
||||
assert MatchSource.KNOWN_EXACT == "known_exact"
|
||||
assert MatchSource.KNOWN_REGEX == "known_regex"
|
||||
assert MatchSource.UNMATCHED == "unmatched"
|
||||
|
||||
|
||||
# ── Pattern matching unit tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestPatternMatching:
|
||||
"""Test the _match_pattern and _match_regex helpers."""
|
||||
|
||||
def test_exact_match(self):
|
||||
assert _match_pattern("_ga", "_ga") is True
|
||||
|
||||
def test_exact_match_case_insensitive(self):
|
||||
assert _match_pattern("_GA", "_ga") is True
|
||||
assert _match_pattern("_ga", "_GA") is True
|
||||
|
||||
def test_exact_no_match(self):
|
||||
assert _match_pattern("_ga", "_gid") is False
|
||||
|
||||
def test_wildcard_star(self):
|
||||
assert _match_pattern("*", "_ga") is True
|
||||
assert _match_pattern("*", "anything") is True
|
||||
|
||||
def test_wildcard_prefix(self):
|
||||
assert _match_pattern("_ga_*", "_ga_ABC123") is True
|
||||
assert _match_pattern("_ga_*", "_ga_") is True
|
||||
assert _match_pattern("_ga_*", "_gid") is False
|
||||
|
||||
def test_wildcard_suffix(self):
|
||||
assert _match_pattern("*.google.com", ".google.com") is True
|
||||
assert _match_pattern("*.google.com", "www.google.com") is True
|
||||
assert _match_pattern("*.google.com", ".facebook.com") is False
|
||||
|
||||
def test_wildcard_middle(self):
|
||||
assert _match_pattern("_ga*id", "_ga_gid") is True # * matches _g
|
||||
assert _match_pattern("_ga*id", "_gaid") is True
|
||||
assert _match_pattern("_ga*id", "_ga") is False # must end in id
|
||||
|
||||
def test_empty_values(self):
|
||||
assert _match_pattern("", "_ga") is False
|
||||
assert _match_pattern("_ga", "") is False
|
||||
assert _match_pattern("", "") is False
|
||||
|
||||
def test_regex_match(self):
|
||||
assert _match_regex(r"_hj.*", "_hjSession_12345") is True
|
||||
assert _match_regex(r"_hj.*", "_ga") is False
|
||||
|
||||
def test_regex_case_insensitive(self):
|
||||
assert _match_regex(r"_hj.*", "_HJSession") is True
|
||||
|
||||
def test_regex_anchored(self):
|
||||
# re.match anchors at start by default
|
||||
assert _match_regex(r"_pk_id.*", "_pk_id.abc.123") is True
|
||||
assert _match_regex(r"_pk_id.*", "x_pk_id") is False
|
||||
|
||||
def test_regex_invalid_pattern(self):
|
||||
assert _match_regex(r"[invalid", "test") is False
|
||||
|
||||
def test_regex_full_domain_match(self):
|
||||
assert _match_regex(r".*", ".example.com") is True
|
||||
|
||||
def test_wildcard_dynamic_id_suffix(self):
|
||||
"""Cookies with dynamic IDs should match wildcard prefix patterns."""
|
||||
assert _match_pattern("_hjSessionUser_*", "_hjSessionUser_1150536") is True
|
||||
assert _match_pattern("_hjSession_*", "_hjSession_9876543") is True
|
||||
assert _match_pattern("ri--*", "ri--zC77O2yRxuIvW5fjRAq0RdzNYaF-x") is True
|
||||
assert _match_pattern("intercom-id-*", "intercom-id-abc123def") is True
|
||||
assert _match_pattern("amp_*", "amp_ff29a3") is True
|
||||
assert _match_pattern("mp_*", "mp_abc123_mixpanel") is True
|
||||
|
||||
def test_wildcard_does_not_overmatch(self):
|
||||
"""Wildcard patterns should not match unrelated cookies."""
|
||||
assert _match_pattern("_hjSessionUser_*", "_hjSession_123") is False
|
||||
assert _match_pattern("ri--*", "ri-single-dash") is False
|
||||
assert _match_pattern("intercom-id-*", "intercom-session-xyz") is False
|
||||
|
||||
|
||||
# ── Classification engine unit tests ─────────────────────────────────
|
||||
|
||||
|
||||
def _make_category(slug: str, cat_id: uuid.UUID | None = None):
|
||||
"""Create a mock CookieCategory."""
|
||||
cat = MagicMock()
|
||||
cat.id = cat_id or uuid.uuid4()
|
||||
cat.slug = slug
|
||||
return cat
|
||||
|
||||
|
||||
def _make_known(
|
||||
name_pattern: str,
|
||||
domain_pattern: str,
|
||||
category_id: uuid.UUID,
|
||||
vendor: str | None = None,
|
||||
description: str | None = None,
|
||||
is_regex: bool = False,
|
||||
):
|
||||
"""Create a mock KnownCookie."""
|
||||
known = MagicMock()
|
||||
known.name_pattern = name_pattern
|
||||
known.domain_pattern = domain_pattern
|
||||
known.category_id = category_id
|
||||
known.vendor = vendor
|
||||
known.description = description
|
||||
known.is_regex = is_regex
|
||||
return known
|
||||
|
||||
|
||||
def _make_allow_entry(
|
||||
name_pattern: str,
|
||||
domain_pattern: str,
|
||||
category_id: uuid.UUID,
|
||||
description: str | None = None,
|
||||
):
|
||||
"""Create a mock CookieAllowListEntry."""
|
||||
entry = MagicMock()
|
||||
entry.name_pattern = name_pattern
|
||||
entry.domain_pattern = domain_pattern
|
||||
entry.category_id = category_id
|
||||
entry.description = description
|
||||
return entry
|
||||
|
||||
|
||||
class TestClassifyCookie:
|
||||
"""Test the classify_cookie pure function."""
|
||||
|
||||
def setup_method(self):
|
||||
self.analytics_cat = _make_category("analytics")
|
||||
self.marketing_cat = _make_category("marketing")
|
||||
self.necessary_cat = _make_category("necessary")
|
||||
self.category_map = {
|
||||
self.analytics_cat.id: self.analytics_cat,
|
||||
self.marketing_cat.id: self.marketing_cat,
|
||||
self.necessary_cat.id: self.necessary_cat,
|
||||
}
|
||||
|
||||
def test_exact_known_match(self):
|
||||
known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
|
||||
result = classify_cookie("_ga", ".example.com", [], [known], [], self.category_map)
|
||||
assert result.matched is True
|
||||
assert result.match_source == MatchSource.KNOWN_EXACT
|
||||
assert result.category_slug == "analytics"
|
||||
assert result.vendor == "Google"
|
||||
|
||||
def test_regex_known_match(self):
|
||||
known = _make_known(
|
||||
r"_hj.*",
|
||||
r".*",
|
||||
self.analytics_cat.id,
|
||||
vendor="Hotjar",
|
||||
is_regex=True,
|
||||
)
|
||||
result = classify_cookie(
|
||||
"_hjSession_123",
|
||||
".example.com",
|
||||
[],
|
||||
[],
|
||||
[known],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.matched is True
|
||||
assert result.match_source == MatchSource.KNOWN_REGEX
|
||||
assert result.vendor == "Hotjar"
|
||||
|
||||
def test_allow_list_match(self):
|
||||
entry = _make_allow_entry(
|
||||
"_custom_cookie",
|
||||
"*",
|
||||
self.necessary_cat.id,
|
||||
description="Site-specific override",
|
||||
)
|
||||
result = classify_cookie(
|
||||
"_custom_cookie",
|
||||
".example.com",
|
||||
[entry],
|
||||
[],
|
||||
[],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.matched is True
|
||||
assert result.match_source == MatchSource.ALLOW_LIST
|
||||
assert result.category_slug == "necessary"
|
||||
|
||||
def test_allow_list_takes_priority_over_known(self):
|
||||
"""Allow-list should override known cookies database."""
|
||||
allow_entry = _make_allow_entry(
|
||||
"_ga",
|
||||
"*",
|
||||
self.necessary_cat.id,
|
||||
description="Overridden to necessary",
|
||||
)
|
||||
known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
|
||||
result = classify_cookie(
|
||||
"_ga",
|
||||
".example.com",
|
||||
[allow_entry],
|
||||
[known],
|
||||
[],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.match_source == MatchSource.ALLOW_LIST
|
||||
assert result.category_slug == "necessary"
|
||||
|
||||
def test_exact_takes_priority_over_regex(self):
|
||||
"""Exact match should be preferred over regex match."""
|
||||
exact = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
|
||||
regex = _make_known(
|
||||
r"_g.*",
|
||||
r".*",
|
||||
self.marketing_cat.id,
|
||||
vendor="Other",
|
||||
is_regex=True,
|
||||
)
|
||||
result = classify_cookie(
|
||||
"_ga",
|
||||
".example.com",
|
||||
[],
|
||||
[exact],
|
||||
[regex],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.match_source == MatchSource.KNOWN_EXACT
|
||||
assert result.category_slug == "analytics"
|
||||
|
||||
def test_unmatched(self):
|
||||
result = classify_cookie(
|
||||
"obscure_cookie",
|
||||
".unknown.com",
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.matched is False
|
||||
assert result.match_source == MatchSource.UNMATCHED
|
||||
assert result.category_id is None
|
||||
|
||||
def test_domain_must_match(self):
|
||||
"""Cookie should not match if domain pattern doesn't match."""
|
||||
known = _make_known("_ga", "*.google.com", self.analytics_cat.id)
|
||||
result = classify_cookie(
|
||||
"_ga",
|
||||
".example.com",
|
||||
[],
|
||||
[known],
|
||||
[],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.matched is False
|
||||
|
||||
def test_name_must_match(self):
|
||||
"""Cookie should not match if name pattern doesn't match."""
|
||||
known = _make_known("_gid", "*", self.analytics_cat.id)
|
||||
result = classify_cookie(
|
||||
"_ga",
|
||||
".example.com",
|
||||
[],
|
||||
[known],
|
||||
[],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.matched is False
|
||||
|
||||
def test_wildcard_domain_match(self):
|
||||
known = _make_known(
|
||||
"fr",
|
||||
"*.facebook.com",
|
||||
self.marketing_cat.id,
|
||||
vendor="Meta",
|
||||
)
|
||||
result = classify_cookie(
|
||||
"fr",
|
||||
".facebook.com",
|
||||
[],
|
||||
[known],
|
||||
[],
|
||||
self.category_map,
|
||||
)
|
||||
assert result.matched is True
|
||||
assert result.vendor == "Meta"
|
||||
|
||||
def test_classification_result_fields(self):
|
||||
result = ClassificationResult(
|
||||
cookie_name="_ga",
|
||||
cookie_domain=".example.com",
|
||||
)
|
||||
assert result.category_id is None
|
||||
assert result.match_source == MatchSource.UNMATCHED
|
||||
assert result.matched is False
|
||||
|
||||
|
||||
# ── Router unit tests (mocked service) ──────────────────────────────
|
||||
|
||||
|
||||
def _mock_db():
|
||||
"""Create a mock async DB session."""
|
||||
db = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
db.execute.return_value = mock_result
|
||||
return db
|
||||
|
||||
|
||||
async def _client(app, db):
|
||||
"""Create an async test client with mocked DB and auth."""
|
||||
from src.db import get_db
|
||||
from src.services.dependencies import get_current_user, require_role
|
||||
|
||||
user = MagicMock()
|
||||
user.organisation_id = uuid.uuid4()
|
||||
user.role = "owner"
|
||||
|
||||
async def _override_get_db():
|
||||
yield db
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
app.dependency_overrides[get_current_user] = lambda: user
|
||||
|
||||
def _override_require_role(*_roles):
|
||||
return lambda: user
|
||||
|
||||
app.dependency_overrides[require_role] = _override_require_role
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestKnownCookieRoutes:
|
||||
"""Test known cookie CRUD endpoints."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_known_cookies(self, app):
|
||||
db = _mock_db()
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.get("/api/v1/cookies/known")
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_known_cookie(self, app):
|
||||
db = _mock_db()
|
||||
# Mock category validation
|
||||
cat_result = MagicMock()
|
||||
cat_result.scalar_one_or_none.return_value = MagicMock()
|
||||
# Mock the created known cookie
|
||||
known_mock = MagicMock()
|
||||
known_mock.id = uuid.uuid4()
|
||||
known_mock.name_pattern = "_ga"
|
||||
known_mock.domain_pattern = "*"
|
||||
known_mock.category_id = uuid.uuid4()
|
||||
known_mock.vendor = "Google"
|
||||
known_mock.description = "GA cookie"
|
||||
known_mock.is_regex = False
|
||||
known_mock.created_at = datetime.now()
|
||||
known_mock.updated_at = datetime.now()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(stmt):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# Category validation
|
||||
return cat_result
|
||||
return MagicMock()
|
||||
|
||||
db.execute = mock_execute
|
||||
db.flush = AsyncMock()
|
||||
db.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||
db.add = MagicMock()
|
||||
|
||||
with patch(
|
||||
"src.routers.cookies.KnownCookie",
|
||||
return_value=known_mock,
|
||||
):
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/cookies/known",
|
||||
json={
|
||||
"name_pattern": "_ga",
|
||||
"domain_pattern": "*",
|
||||
"category_id": str(uuid.uuid4()),
|
||||
"vendor": "Google",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_known_cookie_not_found(self, app):
|
||||
db = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = None
|
||||
db.execute.return_value = mock_result
|
||||
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.get(f"/api/v1/cookies/known/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestClassificationRoutes:
|
||||
"""Test classification endpoint responses."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_classify_preview(self, app):
|
||||
db = _mock_db()
|
||||
mock_result = ClassificationResult(
|
||||
cookie_name="_ga",
|
||||
cookie_domain=".example.com",
|
||||
category_id=uuid.uuid4(),
|
||||
category_slug="analytics",
|
||||
vendor="Google",
|
||||
match_source=MatchSource.KNOWN_EXACT,
|
||||
matched=True,
|
||||
)
|
||||
with patch(
|
||||
"src.routers.cookies.classify_single_cookie",
|
||||
return_value=mock_result,
|
||||
):
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{uuid.uuid4()}/classify/preview",
|
||||
json={
|
||||
"cookie_name": "_ga",
|
||||
"cookie_domain": ".example.com",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["matched"] is True
|
||||
assert data["match_source"] == "known_exact"
|
||||
|
||||
|
||||
# ── Integration tests ────────────────────────────────────────────────
|
||||
|
||||
|
||||
try:
|
||||
from tests.conftest import create_test_site, requires_db
|
||||
except ImportError:
|
||||
from conftest import create_test_site, requires_db
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestClassificationIntegration:
|
||||
"""Integration tests against a live database."""
|
||||
|
||||
async def _get_category_id(self, client: AsyncClient, headers: dict, slug: str) -> str:
|
||||
"""Get a category ID by slug."""
|
||||
resp = await client.get("/api/v1/cookies/categories", headers=headers)
|
||||
assert resp.status_code == 200
|
||||
for cat in resp.json():
|
||||
if cat["slug"] == slug:
|
||||
return cat["id"]
|
||||
pytest.fail(f"Category '{slug}' not found")
|
||||
|
||||
async def _create_known_cookie(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
headers: dict,
|
||||
name_pattern: str,
|
||||
domain_pattern: str,
|
||||
category_slug: str,
|
||||
*,
|
||||
vendor: str | None = None,
|
||||
is_regex: bool = False,
|
||||
) -> str:
|
||||
"""Create a known cookie and return its ID."""
|
||||
cat_id = await self._get_category_id(client, headers, category_slug)
|
||||
resp = await client.post(
|
||||
"/api/v1/cookies/known",
|
||||
headers=headers,
|
||||
json={
|
||||
"name_pattern": name_pattern,
|
||||
"domain_pattern": domain_pattern,
|
||||
"category_id": cat_id,
|
||||
"vendor": vendor,
|
||||
"is_regex": is_regex,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
return resp.json()["id"]
|
||||
|
||||
async def _create_cookie(
|
||||
self,
|
||||
client: AsyncClient,
|
||||
headers: dict,
|
||||
site_id: str,
|
||||
name: str,
|
||||
domain: str,
|
||||
) -> str:
|
||||
"""Create a pending cookie on a site and return its ID."""
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
headers=headers,
|
||||
json={"name": name, "domain": domain},
|
||||
)
|
||||
assert resp.status_code == 201, resp.text
|
||||
return resp.json()["id"]
|
||||
|
||||
async def test_known_cookies_crud(self, db_client, auth_headers):
|
||||
"""Test full CRUD lifecycle for known cookies."""
|
||||
cat_id = await self._get_category_id(db_client, auth_headers, "analytics")
|
||||
# Create
|
||||
resp = await db_client.post(
|
||||
"/api/v1/cookies/known",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name_pattern": f"_test_{uuid.uuid4().hex[:6]}",
|
||||
"domain_pattern": "*",
|
||||
"category_id": cat_id,
|
||||
"vendor": "TestVendor",
|
||||
"description": "Test cookie",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
known_id = resp.json()["id"]
|
||||
|
||||
# Read
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/cookies/known/{known_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["vendor"] == "TestVendor"
|
||||
|
||||
# Update
|
||||
resp = await db_client.patch(
|
||||
f"/api/v1/cookies/known/{known_id}",
|
||||
headers=auth_headers,
|
||||
json={"vendor": "UpdatedVendor"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["vendor"] == "UpdatedVendor"
|
||||
|
||||
# List (with search)
|
||||
resp = await db_client.get(
|
||||
"/api/v1/cookies/known",
|
||||
headers=auth_headers,
|
||||
params={"vendor": "UpdatedVendor"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert any(k["id"] == known_id for k in resp.json())
|
||||
|
||||
# Delete
|
||||
resp = await db_client.delete(
|
||||
f"/api/v1/cookies/known/{known_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Verify deleted
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/cookies/known/{known_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_classify_exact_match(self, db_client, auth_headers):
|
||||
"""Test classification with exact known cookie match."""
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-exact")
|
||||
# Create a known cookie pattern
|
||||
pattern_name = f"_test_exact_{uuid.uuid4().hex[:6]}"
|
||||
await self._create_known_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
pattern_name,
|
||||
"*",
|
||||
"analytics",
|
||||
vendor="TestVendor",
|
||||
)
|
||||
# Create a pending cookie on the site
|
||||
await self._create_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
site_id,
|
||||
pattern_name,
|
||||
".example.com",
|
||||
)
|
||||
# Classify
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/classify",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] >= 1
|
||||
assert data["matched"] >= 1
|
||||
matched = [r for r in data["results"] if r["matched"]]
|
||||
assert any(r["cookie_name"] == pattern_name for r in matched)
|
||||
|
||||
async def test_classify_regex_match(self, db_client, auth_headers):
|
||||
"""Test classification with regex known cookie match."""
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-regex")
|
||||
prefix = f"_rx_{uuid.uuid4().hex[:4]}"
|
||||
# Create regex pattern
|
||||
await self._create_known_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
f"{prefix}.*",
|
||||
".*",
|
||||
"analytics",
|
||||
vendor="RegexVendor",
|
||||
is_regex=True,
|
||||
)
|
||||
# Create a cookie that should match the regex
|
||||
await self._create_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
site_id,
|
||||
f"{prefix}_session_123",
|
||||
".example.com",
|
||||
)
|
||||
# Classify
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/classify",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["matched"] >= 1
|
||||
matched = [r for r in data["results"] if r["matched"]]
|
||||
assert any(r["match_source"] == "known_regex" for r in matched)
|
||||
|
||||
async def test_classify_unmatched(self, db_client, auth_headers):
|
||||
"""Cookies without known patterns should remain unmatched."""
|
||||
site_id = await create_test_site(
|
||||
db_client, auth_headers, domain_prefix="classify-unmatched"
|
||||
)
|
||||
unique_name = f"_unknown_{uuid.uuid4().hex[:8]}"
|
||||
await self._create_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
site_id,
|
||||
unique_name,
|
||||
".obscure-domain.com",
|
||||
)
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/classify",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["unmatched"] >= 1
|
||||
|
||||
async def test_classify_preview(self, db_client, auth_headers):
|
||||
"""Test preview classification without saving."""
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-preview")
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/classify/preview",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"cookie_name": "_unknown_cookie",
|
||||
"cookie_domain": ".test.com",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["matched"] is False
|
||||
assert data["match_source"] == "unmatched"
|
||||
|
||||
async def test_classify_allow_list_priority(self, db_client, auth_headers):
|
||||
"""Allow-list entries should take priority over known cookies."""
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-allow")
|
||||
cookie_name = f"_priority_{uuid.uuid4().hex[:6]}"
|
||||
|
||||
# Add to known cookies as marketing
|
||||
await self._create_known_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
cookie_name,
|
||||
"*",
|
||||
"marketing",
|
||||
)
|
||||
|
||||
# Add to allow-list as necessary (should take priority)
|
||||
necessary_id = await self._get_category_id(db_client, auth_headers, "necessary")
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/allow-list",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name_pattern": cookie_name,
|
||||
"domain_pattern": "*",
|
||||
"category_id": necessary_id,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
# Create cookie and classify
|
||||
await self._create_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
site_id,
|
||||
cookie_name,
|
||||
".example.com",
|
||||
)
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/classify",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
matched = [r for r in data["results"] if r["cookie_name"] == cookie_name]
|
||||
assert len(matched) == 1
|
||||
assert matched[0]["match_source"] == "allow_list"
|
||||
assert matched[0]["category_id"] == necessary_id
|
||||
|
||||
async def test_known_cookies_not_found(self, db_client, auth_headers):
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/cookies/known/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_known_cookies_invalid_category(self, db_client, auth_headers):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/cookies/known",
|
||||
headers=auth_headers,
|
||||
json={
|
||||
"name_pattern": "_test",
|
||||
"domain_pattern": "*",
|
||||
"category_id": str(uuid.uuid4()),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_known_cookies_auth_required(self, db_client):
|
||||
"""Known cookie endpoints require authentication."""
|
||||
resp = await db_client.get("/api/v1/cookies/known")
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_classify_empty_site(self, db_client, auth_headers):
|
||||
"""Classifying a site with no cookies should return empty results."""
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-empty")
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/classify",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["total"] == 0
|
||||
assert data["matched"] == 0
|
||||
|
||||
async def test_list_known_cookies_search(self, db_client, auth_headers):
|
||||
"""Test searching known cookies by name pattern."""
|
||||
unique = uuid.uuid4().hex[:6]
|
||||
await self._create_known_cookie(
|
||||
db_client,
|
||||
auth_headers,
|
||||
f"_search_{unique}",
|
||||
"*",
|
||||
"analytics",
|
||||
)
|
||||
resp = await db_client.get(
|
||||
"/api/v1/cookies/known",
|
||||
headers=auth_headers,
|
||||
params={"search": f"_search_{unique}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
results = resp.json()
|
||||
assert len(results) >= 1
|
||||
assert all(f"_search_{unique}" in r["name_pattern"] for r in results)
|
||||
597
apps/api/tests/test_compliance.py
Normal file
597
apps/api/tests/test_compliance.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""Tests for the compliance rule engine and router."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.schemas.compliance import (
|
||||
ComplianceCheckResponse,
|
||||
ComplianceIssue,
|
||||
Framework,
|
||||
FrameworkResult,
|
||||
Severity,
|
||||
)
|
||||
from src.services.compliance import (
|
||||
CCPA_RULES,
|
||||
CNIL_RULES,
|
||||
EPRIVACY_RULES,
|
||||
FRAMEWORK_RULES,
|
||||
GDPR_RULES,
|
||||
LGPD_RULES,
|
||||
SiteContext,
|
||||
calculate_overall_score,
|
||||
run_compliance_check,
|
||||
run_framework_check,
|
||||
)
|
||||
|
||||
# ── SiteContext defaults ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSiteContext:
|
||||
def test_default_values(self):
|
||||
ctx = SiteContext()
|
||||
assert ctx.blocking_mode == "opt_in"
|
||||
assert ctx.tcf_enabled is False
|
||||
assert ctx.gcm_enabled is True
|
||||
assert ctx.consent_expiry_days == 365
|
||||
assert ctx.has_reject_button is True
|
||||
assert ctx.has_granular_choices is True
|
||||
assert ctx.has_cookie_wall is False
|
||||
assert ctx.pre_ticked_boxes is False
|
||||
|
||||
def test_custom_values(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
consent_expiry_days=180,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
assert ctx.blocking_mode == "opt_out"
|
||||
assert ctx.consent_expiry_days == 180
|
||||
assert ctx.privacy_policy_url == "https://example.com/privacy"
|
||||
|
||||
|
||||
# ── GDPR rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGDPRRules:
|
||||
def test_compliant_site(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
has_reject_button=True,
|
||||
has_granular_choices=True,
|
||||
has_cookie_wall=False,
|
||||
pre_ticked_boxes=False,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
uncategorised_cookies=0,
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert result.score == 100
|
||||
assert result.status == "compliant"
|
||||
assert len(result.issues) == 0
|
||||
assert result.rules_passed == result.rules_checked
|
||||
|
||||
def test_opt_out_mode_fails(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert any(i.rule_id == "gdpr_opt_in" for i in result.issues)
|
||||
assert result.status == "non_compliant"
|
||||
|
||||
def test_informational_mode_fails(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="informational",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert any(i.rule_id == "gdpr_opt_in" for i in result.issues)
|
||||
|
||||
def test_no_reject_button_fails(self):
|
||||
ctx = SiteContext(
|
||||
has_reject_button=False,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert any(i.rule_id == "gdpr_reject_button" for i in result.issues)
|
||||
|
||||
def test_no_granular_consent_fails(self):
|
||||
ctx = SiteContext(
|
||||
has_granular_choices=False,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert any(i.rule_id == "gdpr_granular" for i in result.issues)
|
||||
|
||||
def test_cookie_wall_fails(self):
|
||||
ctx = SiteContext(
|
||||
has_cookie_wall=True,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert any(i.rule_id == "gdpr_cookie_wall" for i in result.issues)
|
||||
|
||||
def test_pre_ticked_fails(self):
|
||||
ctx = SiteContext(
|
||||
pre_ticked_boxes=True,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert any(i.rule_id == "gdpr_pre_ticked" for i in result.issues)
|
||||
|
||||
def test_no_privacy_policy_warns(self):
|
||||
ctx = SiteContext(privacy_policy_url=None)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
policy_issues = [i for i in result.issues if i.rule_id == "gdpr_privacy_policy"]
|
||||
assert len(policy_issues) == 1
|
||||
assert policy_issues[0].severity == Severity.WARNING
|
||||
|
||||
def test_uncategorised_cookies_warns(self):
|
||||
ctx = SiteContext(
|
||||
uncategorised_cookies=5,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
uncat_issues = [i for i in result.issues if i.rule_id == "gdpr_uncategorised"]
|
||||
assert len(uncat_issues) == 1
|
||||
assert "5" in uncat_issues[0].message
|
||||
|
||||
def test_multiple_failures_accumulate(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
has_reject_button=False,
|
||||
has_granular_choices=False,
|
||||
has_cookie_wall=True,
|
||||
pre_ticked_boxes=True,
|
||||
privacy_policy_url=None,
|
||||
uncategorised_cookies=3,
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert result.score == 0 # Capped at 0
|
||||
assert result.status == "non_compliant"
|
||||
assert len(result.issues) >= 5
|
||||
|
||||
|
||||
# ── CNIL rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCNILRules:
|
||||
def test_compliant_site(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
has_reject_button=True,
|
||||
has_granular_choices=True,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
consent_expiry_days=180,
|
||||
)
|
||||
result = run_framework_check(Framework.CNIL, ctx)
|
||||
assert result.score == 100
|
||||
assert result.status == "compliant"
|
||||
|
||||
def test_consent_expiry_too_long(self):
|
||||
ctx = SiteContext(
|
||||
consent_expiry_days=365,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.CNIL, ctx)
|
||||
assert any(i.rule_id == "cnil_reconsent" for i in result.issues)
|
||||
|
||||
def test_consent_expiry_at_limit(self):
|
||||
ctx = SiteContext(
|
||||
consent_expiry_days=182,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.CNIL, ctx)
|
||||
assert not any(i.rule_id == "cnil_reconsent" for i in result.issues)
|
||||
|
||||
def test_cookie_lifetime_too_long(self):
|
||||
ctx = SiteContext(
|
||||
consent_expiry_days=400,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.CNIL, ctx)
|
||||
assert any(i.rule_id == "cnil_cookie_lifetime" for i in result.issues)
|
||||
|
||||
def test_cookie_lifetime_at_limit(self):
|
||||
ctx = SiteContext(
|
||||
consent_expiry_days=395,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.CNIL, ctx)
|
||||
assert not any(i.rule_id == "cnil_cookie_lifetime" for i in result.issues)
|
||||
|
||||
def test_inherits_gdpr_rules(self):
|
||||
"""CNIL should check all GDPR rules plus CNIL-specific ones."""
|
||||
assert len(CNIL_RULES) > len(GDPR_RULES)
|
||||
|
||||
def test_reject_first_layer(self):
|
||||
ctx = SiteContext(
|
||||
has_reject_button=False,
|
||||
consent_expiry_days=180,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.CNIL, ctx)
|
||||
assert any(i.rule_id == "cnil_reject_first_layer" for i in result.issues)
|
||||
|
||||
|
||||
# ── CCPA rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCCPARules:
|
||||
def test_compliant_site(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
banner_config={"show_do_not_sell_link": True},
|
||||
)
|
||||
result = run_framework_check(Framework.CCPA, ctx)
|
||||
assert result.score == 100
|
||||
assert result.status == "compliant"
|
||||
|
||||
def test_opt_in_also_acceptable(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
banner_config={"show_do_not_sell_link": True},
|
||||
)
|
||||
result = run_framework_check(Framework.CCPA, ctx)
|
||||
assert not any(i.rule_id == "ccpa_opt_out" for i in result.issues)
|
||||
|
||||
def test_informational_mode_passes_ccpa(self):
|
||||
"""CCPA opt-out check passes for informational (it's not 'informational')."""
|
||||
ctx = SiteContext(
|
||||
blocking_mode="informational",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
banner_config={"show_do_not_sell_link": True},
|
||||
)
|
||||
result = run_framework_check(Framework.CCPA, ctx)
|
||||
# informational is not in ("opt_out", "opt_in"), so it fails
|
||||
assert any(i.rule_id == "ccpa_opt_out" for i in result.issues)
|
||||
|
||||
def test_no_do_not_sell_link(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
banner_config={},
|
||||
)
|
||||
result = run_framework_check(Framework.CCPA, ctx)
|
||||
assert any(i.rule_id == "ccpa_do_not_sell" for i in result.issues)
|
||||
|
||||
def test_no_banner_config_fails_dns(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
banner_config=None,
|
||||
)
|
||||
result = run_framework_check(Framework.CCPA, ctx)
|
||||
assert any(i.rule_id == "ccpa_do_not_sell" for i in result.issues)
|
||||
|
||||
def test_no_privacy_policy_warns(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
privacy_policy_url=None,
|
||||
banner_config={"show_do_not_sell_link": True},
|
||||
)
|
||||
result = run_framework_check(Framework.CCPA, ctx)
|
||||
assert any(i.rule_id == "ccpa_privacy_policy" for i in result.issues)
|
||||
|
||||
|
||||
# ── ePrivacy rules ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEPrivacyRules:
|
||||
def test_compliant_site(self):
|
||||
ctx = SiteContext(blocking_mode="opt_in")
|
||||
result = run_framework_check(Framework.EPRIVACY, ctx)
|
||||
assert result.score == 100
|
||||
assert result.status == "compliant"
|
||||
|
||||
def test_opt_out_passes(self):
|
||||
ctx = SiteContext(blocking_mode="opt_out")
|
||||
result = run_framework_check(Framework.EPRIVACY, ctx)
|
||||
assert not any(i.rule_id == "eprivacy_consent" for i in result.issues)
|
||||
|
||||
def test_informational_fails(self):
|
||||
ctx = SiteContext(blocking_mode="informational")
|
||||
result = run_framework_check(Framework.EPRIVACY, ctx)
|
||||
assert any(i.rule_id == "eprivacy_consent" for i in result.issues)
|
||||
|
||||
|
||||
# ── LGPD rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLGPDRules:
|
||||
def test_compliant_site(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
has_granular_choices=True,
|
||||
)
|
||||
result = run_framework_check(Framework.LGPD, ctx)
|
||||
assert result.score == 100
|
||||
assert result.status == "compliant"
|
||||
|
||||
def test_informational_fails(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="informational",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.LGPD, ctx)
|
||||
assert any(i.rule_id == "lgpd_consent_basis" for i in result.issues)
|
||||
|
||||
def test_no_privacy_policy_warns(self):
|
||||
ctx = SiteContext(privacy_policy_url=None)
|
||||
result = run_framework_check(Framework.LGPD, ctx)
|
||||
assert any(i.rule_id == "lgpd_data_controller" for i in result.issues)
|
||||
|
||||
def test_no_granular_warns(self):
|
||||
ctx = SiteContext(
|
||||
has_granular_choices=False,
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.LGPD, ctx)
|
||||
assert any(i.rule_id == "lgpd_granular" for i in result.issues)
|
||||
|
||||
def test_opt_out_passes(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.LGPD, ctx)
|
||||
assert not any(i.rule_id == "lgpd_consent_basis" for i in result.issues)
|
||||
|
||||
|
||||
# ── Engine orchestration ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestComplianceEngine:
|
||||
def test_run_all_frameworks(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
has_reject_button=True,
|
||||
has_granular_choices=True,
|
||||
consent_expiry_days=180,
|
||||
)
|
||||
results = run_compliance_check(ctx)
|
||||
assert len(results) == 5
|
||||
frameworks = {r.framework for r in results}
|
||||
assert frameworks == {
|
||||
Framework.GDPR,
|
||||
Framework.CNIL,
|
||||
Framework.CCPA,
|
||||
Framework.EPRIVACY,
|
||||
Framework.LGPD,
|
||||
}
|
||||
|
||||
def test_run_specific_frameworks(self):
|
||||
ctx = SiteContext()
|
||||
results = run_compliance_check(ctx, [Framework.GDPR, Framework.CCPA])
|
||||
assert len(results) == 2
|
||||
assert results[0].framework == Framework.GDPR
|
||||
assert results[1].framework == Framework.CCPA
|
||||
|
||||
def test_run_single_framework(self):
|
||||
ctx = SiteContext()
|
||||
results = run_compliance_check(ctx, [Framework.EPRIVACY])
|
||||
assert len(results) == 1
|
||||
assert results[0].framework == Framework.EPRIVACY
|
||||
|
||||
def test_empty_frameworks_list_runs_all(self):
|
||||
ctx = SiteContext(
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
consent_expiry_days=180,
|
||||
)
|
||||
results = run_compliance_check(ctx, None)
|
||||
assert len(results) == 5
|
||||
|
||||
|
||||
class TestScoring:
|
||||
def test_perfect_score(self):
|
||||
result = FrameworkResult(
|
||||
framework=Framework.GDPR,
|
||||
score=100,
|
||||
status="compliant",
|
||||
rules_checked=7,
|
||||
rules_passed=7,
|
||||
)
|
||||
assert calculate_overall_score([result]) == 100
|
||||
|
||||
def test_zero_score(self):
|
||||
result = FrameworkResult(
|
||||
framework=Framework.GDPR,
|
||||
score=0,
|
||||
status="non_compliant",
|
||||
rules_checked=7,
|
||||
rules_passed=0,
|
||||
)
|
||||
assert calculate_overall_score([result]) == 0
|
||||
|
||||
def test_average_across_frameworks(self):
|
||||
results = [
|
||||
FrameworkResult(
|
||||
framework=Framework.GDPR,
|
||||
score=100,
|
||||
status="compliant",
|
||||
rules_checked=7,
|
||||
rules_passed=7,
|
||||
),
|
||||
FrameworkResult(
|
||||
framework=Framework.CCPA,
|
||||
score=50,
|
||||
status="partial",
|
||||
rules_checked=3,
|
||||
rules_passed=1,
|
||||
),
|
||||
]
|
||||
assert calculate_overall_score(results) == 75
|
||||
|
||||
def test_empty_results(self):
|
||||
assert calculate_overall_score([]) == 100
|
||||
|
||||
def test_critical_issues_deduct_20(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
# opt_out causes one critical issue (gdpr_opt_in) → -20 points
|
||||
assert result.score == 80
|
||||
|
||||
def test_warning_issues_deduct_5(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
privacy_policy_url=None,
|
||||
uncategorised_cookies=0,
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
# Missing privacy policy is a warning → -5 points
|
||||
assert result.score == 95
|
||||
|
||||
def test_score_floors_at_zero(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_out",
|
||||
has_reject_button=False,
|
||||
has_granular_choices=False,
|
||||
has_cookie_wall=True,
|
||||
pre_ticked_boxes=True,
|
||||
privacy_policy_url=None,
|
||||
uncategorised_cookies=10,
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert result.score == 0
|
||||
|
||||
def test_status_non_compliant_with_critical(self):
|
||||
ctx = SiteContext(blocking_mode="opt_out")
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert result.status == "non_compliant"
|
||||
|
||||
def test_status_partial_with_warnings_only(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
privacy_policy_url=None,
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert result.status == "partial"
|
||||
|
||||
def test_status_compliant_with_no_issues(self):
|
||||
ctx = SiteContext(
|
||||
blocking_mode="opt_in",
|
||||
privacy_policy_url="https://example.com/privacy",
|
||||
)
|
||||
result = run_framework_check(Framework.GDPR, ctx)
|
||||
assert result.status == "compliant"
|
||||
|
||||
|
||||
# ── Framework registry ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFrameworkRegistry:
|
||||
def test_all_frameworks_registered(self):
|
||||
assert Framework.GDPR in FRAMEWORK_RULES
|
||||
assert Framework.CNIL in FRAMEWORK_RULES
|
||||
assert Framework.CCPA in FRAMEWORK_RULES
|
||||
assert Framework.EPRIVACY in FRAMEWORK_RULES
|
||||
assert Framework.LGPD in FRAMEWORK_RULES
|
||||
|
||||
def test_each_framework_has_rules(self):
|
||||
for fw, rules in FRAMEWORK_RULES.items():
|
||||
assert len(rules) > 0, f"{fw} has no rules"
|
||||
|
||||
def test_rule_ids_are_unique_per_framework(self):
|
||||
for fw, rules in FRAMEWORK_RULES.items():
|
||||
ids = [r.rule_id for r in rules]
|
||||
assert len(ids) == len(set(ids)), f"Duplicate rule IDs in {fw}"
|
||||
|
||||
def test_gdpr_rule_count(self):
|
||||
assert len(GDPR_RULES) == 7
|
||||
|
||||
def test_cnil_includes_gdpr_rules(self):
|
||||
gdpr_ids = {r.rule_id for r in GDPR_RULES}
|
||||
cnil_ids = {r.rule_id for r in CNIL_RULES}
|
||||
assert gdpr_ids.issubset(cnil_ids)
|
||||
|
||||
def test_ccpa_rule_count(self):
|
||||
assert len(CCPA_RULES) == 3
|
||||
|
||||
def test_eprivacy_rule_count(self):
|
||||
assert len(EPRIVACY_RULES) == 2
|
||||
|
||||
def test_lgpd_rule_count(self):
|
||||
assert len(LGPD_RULES) == 3
|
||||
|
||||
|
||||
# ── Router tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestComplianceRouter:
|
||||
@pytest.fixture
|
||||
def app(self):
|
||||
from src.main import create_app
|
||||
|
||||
return create_app()
|
||||
|
||||
@pytest.fixture
|
||||
async def client(self, app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
async def test_list_frameworks(self, client):
|
||||
resp = await client.get("/api/v1/compliance/frameworks")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 5
|
||||
ids = {fw["id"] for fw in data}
|
||||
assert ids == {"gdpr", "cnil", "ccpa", "eprivacy", "lgpd"}
|
||||
|
||||
async def test_check_requires_auth(self, client):
|
||||
resp = await client.post(f"/api/v1/compliance/check/{uuid.uuid4()}")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ── Schema tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSchemas:
|
||||
def test_compliance_issue_schema(self):
|
||||
issue = ComplianceIssue(
|
||||
rule_id="test_rule",
|
||||
severity=Severity.CRITICAL,
|
||||
message="Test message",
|
||||
recommendation="Test recommendation",
|
||||
)
|
||||
assert issue.rule_id == "test_rule"
|
||||
assert issue.severity == Severity.CRITICAL
|
||||
|
||||
def test_framework_result_schema(self):
|
||||
result = FrameworkResult(
|
||||
framework=Framework.GDPR,
|
||||
score=85,
|
||||
status="partial",
|
||||
rules_checked=7,
|
||||
rules_passed=5,
|
||||
)
|
||||
assert result.framework == Framework.GDPR
|
||||
assert result.score == 85
|
||||
|
||||
def test_compliance_check_response_schema(self):
|
||||
response = ComplianceCheckResponse(
|
||||
site_id="test-id",
|
||||
results=[],
|
||||
overall_score=100,
|
||||
)
|
||||
assert response.overall_score == 100
|
||||
|
||||
def test_severity_values(self):
|
||||
assert Severity.CRITICAL == "critical"
|
||||
assert Severity.WARNING == "warning"
|
||||
assert Severity.INFO == "info"
|
||||
|
||||
def test_framework_values(self):
|
||||
assert Framework.GDPR == "gdpr"
|
||||
assert Framework.CNIL == "cnil"
|
||||
assert Framework.CCPA == "ccpa"
|
||||
assert Framework.EPRIVACY == "eprivacy"
|
||||
assert Framework.LGPD == "lgpd"
|
||||
258
apps/api/tests/test_config_resolver.py
Normal file
258
apps/api/tests/test_config_resolver.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Tests for configuration hierarchy resolver and publisher."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services.config_resolver import (
|
||||
SYSTEM_DEFAULTS,
|
||||
build_public_config,
|
||||
resolve_config,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveConfig:
|
||||
def test_returns_system_defaults_for_empty_config(self):
|
||||
result = resolve_config({})
|
||||
assert result["blocking_mode"] == "opt_in"
|
||||
assert result["consent_expiry_days"] == 365
|
||||
assert result["gcm_enabled"] is True
|
||||
assert result["tcf_enabled"] is False
|
||||
assert result["gpp_enabled"] is True
|
||||
assert result["gpp_supported_apis"] == ["usnat"]
|
||||
assert result["gpc_enabled"] is True
|
||||
assert result["gpc_jurisdictions"] == [
|
||||
"US-CA",
|
||||
"US-CO",
|
||||
"US-CT",
|
||||
"US-TX",
|
||||
"US-MT",
|
||||
]
|
||||
assert result["gpc_global_honour"] is False
|
||||
|
||||
def test_site_config_overrides_defaults(self):
|
||||
site_config = {
|
||||
"blocking_mode": "opt_out",
|
||||
"consent_expiry_days": 180,
|
||||
"tcf_enabled": True,
|
||||
}
|
||||
result = resolve_config(site_config)
|
||||
assert result["blocking_mode"] == "opt_out"
|
||||
assert result["consent_expiry_days"] == 180
|
||||
assert result["tcf_enabled"] is True
|
||||
# Non-overridden values stay as defaults
|
||||
assert result["gcm_enabled"] is True
|
||||
|
||||
def test_org_defaults_override_system_defaults(self):
|
||||
org_defaults = {"consent_expiry_days": 90}
|
||||
result = resolve_config({}, org_defaults=org_defaults)
|
||||
assert result["consent_expiry_days"] == 90
|
||||
|
||||
def test_site_config_overrides_org_defaults(self):
|
||||
org_defaults = {"consent_expiry_days": 90}
|
||||
site_config = {"consent_expiry_days": 30}
|
||||
result = resolve_config(site_config, org_defaults=org_defaults)
|
||||
assert result["consent_expiry_days"] == 30
|
||||
|
||||
def test_none_values_in_site_config_do_not_override(self):
|
||||
site_config = {"blocking_mode": None, "consent_expiry_days": 180}
|
||||
result = resolve_config(site_config)
|
||||
# None should not override the default
|
||||
assert result["blocking_mode"] == "opt_in"
|
||||
assert result["consent_expiry_days"] == 180
|
||||
|
||||
def test_regional_override_applied(self):
|
||||
site_config = {
|
||||
"blocking_mode": "opt_in",
|
||||
"regional_modes": {"US-CA": "opt_out", "EU": "opt_in"},
|
||||
}
|
||||
result = resolve_config(site_config, region="US-CA")
|
||||
assert result["blocking_mode"] == "opt_out"
|
||||
|
||||
def test_regional_override_falls_back_to_default(self):
|
||||
site_config = {
|
||||
"blocking_mode": "opt_in",
|
||||
"regional_modes": {"EU": "opt_in", "DEFAULT": "opt_out"},
|
||||
}
|
||||
result = resolve_config(site_config, region="BR")
|
||||
assert result["blocking_mode"] == "opt_out"
|
||||
|
||||
def test_regional_override_no_match_keeps_site_config(self):
|
||||
site_config = {
|
||||
"blocking_mode": "opt_in",
|
||||
"regional_modes": {"EU": "opt_out"},
|
||||
}
|
||||
result = resolve_config(site_config, region="JP")
|
||||
assert result["blocking_mode"] == "opt_in"
|
||||
|
||||
def test_no_region_ignores_regional_modes(self):
|
||||
site_config = {
|
||||
"blocking_mode": "opt_in",
|
||||
"regional_modes": {"US-CA": "opt_out"},
|
||||
}
|
||||
result = resolve_config(site_config)
|
||||
assert result["blocking_mode"] == "opt_in"
|
||||
|
||||
def test_gpp_site_config_overrides_defaults(self):
|
||||
site_config = {
|
||||
"gpp_enabled": False,
|
||||
"gpp_supported_apis": ["usca", "usva"],
|
||||
}
|
||||
result = resolve_config(site_config)
|
||||
assert result["gpp_enabled"] is False
|
||||
assert result["gpp_supported_apis"] == ["usca", "usva"]
|
||||
|
||||
def test_gpc_site_config_overrides_defaults(self):
|
||||
site_config = {
|
||||
"gpc_enabled": False,
|
||||
"gpc_global_honour": True,
|
||||
"gpc_jurisdictions": ["US-CA"],
|
||||
}
|
||||
result = resolve_config(site_config)
|
||||
assert result["gpc_enabled"] is False
|
||||
assert result["gpc_global_honour"] is True
|
||||
assert result["gpc_jurisdictions"] == ["US-CA"]
|
||||
|
||||
def test_gpp_gpc_org_defaults_override_system(self):
|
||||
org_defaults = {
|
||||
"gpp_enabled": False,
|
||||
"gpc_global_honour": True,
|
||||
}
|
||||
result = resolve_config({}, org_defaults=org_defaults)
|
||||
assert result["gpp_enabled"] is False
|
||||
assert result["gpc_global_honour"] is True
|
||||
# Non-overridden GPP/GPC fields stay as system defaults
|
||||
assert result["gpc_enabled"] is True
|
||||
|
||||
def test_gpp_gpc_site_overrides_org(self):
|
||||
org_defaults = {"gpp_supported_apis": ["usca"]}
|
||||
site_config = {"gpp_supported_apis": ["usnat", "usco"]}
|
||||
result = resolve_config(site_config, org_defaults=org_defaults)
|
||||
assert result["gpp_supported_apis"] == ["usnat", "usco"]
|
||||
|
||||
def test_group_defaults_override_org_defaults(self):
|
||||
org_defaults = {"consent_expiry_days": 90, "tcf_enabled": True}
|
||||
group_defaults = {"consent_expiry_days": 60}
|
||||
result = resolve_config(
|
||||
{},
|
||||
org_defaults=org_defaults,
|
||||
group_defaults=group_defaults,
|
||||
)
|
||||
assert result["consent_expiry_days"] == 60 # Group overrides org
|
||||
assert result["tcf_enabled"] is True # Still from org
|
||||
|
||||
def test_site_config_overrides_group_defaults(self):
|
||||
group_defaults = {"consent_expiry_days": 60}
|
||||
site_config = {"consent_expiry_days": 30}
|
||||
result = resolve_config(site_config, group_defaults=group_defaults)
|
||||
assert result["consent_expiry_days"] == 30 # Site overrides group
|
||||
|
||||
def test_none_in_group_defaults_does_not_override(self):
|
||||
org_defaults = {"consent_expiry_days": 90}
|
||||
group_defaults = {"consent_expiry_days": None}
|
||||
result = resolve_config(
|
||||
{},
|
||||
org_defaults=org_defaults,
|
||||
group_defaults=group_defaults,
|
||||
)
|
||||
assert result["consent_expiry_days"] == 90 # Org value preserved
|
||||
|
||||
def test_full_hierarchy(self):
|
||||
org_defaults = {
|
||||
"consent_expiry_days": 90,
|
||||
"tcf_enabled": True,
|
||||
}
|
||||
site_config = {
|
||||
"consent_expiry_days": 60,
|
||||
"banner_config": {"primaryColour": "#ff0000"},
|
||||
"regional_modes": {"EU": "opt_in", "US": "opt_out"},
|
||||
}
|
||||
result = resolve_config(site_config, org_defaults=org_defaults, region="US")
|
||||
assert result["consent_expiry_days"] == 60 # Site overrides org
|
||||
assert result["tcf_enabled"] is True # From org defaults
|
||||
assert result["blocking_mode"] == "opt_out" # Regional override
|
||||
assert result["banner_config"] == {"primaryColour": "#ff0000"}
|
||||
|
||||
def test_full_hierarchy_with_group(self):
|
||||
org_defaults = {
|
||||
"consent_expiry_days": 90,
|
||||
"tcf_enabled": True,
|
||||
"blocking_mode": "opt_in",
|
||||
}
|
||||
group_defaults = {
|
||||
"consent_expiry_days": 60,
|
||||
"privacy_policy_url": "https://group.example.com/privacy",
|
||||
}
|
||||
site_config = {
|
||||
"banner_config": {"primaryColour": "#ff0000"},
|
||||
"regional_modes": {"US": "opt_out"},
|
||||
}
|
||||
result = resolve_config(
|
||||
site_config,
|
||||
org_defaults=org_defaults,
|
||||
group_defaults=group_defaults,
|
||||
region="US",
|
||||
)
|
||||
assert result["consent_expiry_days"] == 60 # From group
|
||||
assert result["tcf_enabled"] is True # From org
|
||||
assert result["blocking_mode"] == "opt_out" # Regional override
|
||||
assert result["privacy_policy_url"] == "https://group.example.com/privacy" # From group
|
||||
assert result["banner_config"] == {"primaryColour": "#ff0000"} # From site
|
||||
|
||||
|
||||
class TestBuildPublicConfig:
|
||||
def test_includes_required_fields(self):
|
||||
site_id = str(uuid.uuid4())
|
||||
resolved = {**SYSTEM_DEFAULTS, "id": "config-123"}
|
||||
result = build_public_config(site_id, resolved)
|
||||
|
||||
assert result["site_id"] == site_id
|
||||
assert result["id"] == "config-123"
|
||||
assert result["blocking_mode"] == "opt_in"
|
||||
assert result["consent_expiry_days"] == 365
|
||||
assert "gcm_enabled" in result
|
||||
assert "tcf_enabled" in result
|
||||
assert "banner_config" in result
|
||||
assert result["gpp_enabled"] is True
|
||||
assert result["gpp_supported_apis"] == ["usnat"]
|
||||
assert result["gpc_enabled"] is True
|
||||
assert result["gpc_jurisdictions"] == [
|
||||
"US-CA",
|
||||
"US-CO",
|
||||
"US-CT",
|
||||
"US-TX",
|
||||
"US-MT",
|
||||
]
|
||||
assert result["gpc_global_honour"] is False
|
||||
|
||||
def test_strips_unknown_internal_fields(self):
|
||||
site_id = str(uuid.uuid4())
|
||||
resolved = {
|
||||
**SYSTEM_DEFAULTS,
|
||||
"id": "",
|
||||
"internal_field": "should_not_appear",
|
||||
"scan_enabled": True,
|
||||
}
|
||||
result = build_public_config(site_id, resolved)
|
||||
assert "internal_field" not in result
|
||||
assert "scan_enabled" not in result
|
||||
|
||||
|
||||
class TestConfigRoutes:
|
||||
def test_resolved_config_route_registered(self, app):
|
||||
routes = [r.path for r in app.routes]
|
||||
assert "/api/v1/config/sites/{site_id}/resolved" in routes
|
||||
|
||||
def test_publish_route_registered(self, app):
|
||||
routes = [r.path for r in app.routes]
|
||||
assert "/api/v1/config/sites/{site_id}/publish" in routes
|
||||
|
||||
def test_inheritance_route_registered(self, app):
|
||||
routes = [r.path for r in app.routes]
|
||||
assert "/api/v1/config/sites/{site_id}/inheritance" in routes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_requires_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
resp = await client.post(f"/api/v1/config/sites/{site_id}/publish")
|
||||
assert resp.status_code == 401
|
||||
130
apps/api/tests/test_consent.py
Normal file
130
apps/api/tests/test_consent.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Tests for consent recording API schemas and routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.schemas.consent import (
|
||||
ConsentAction,
|
||||
ConsentRecordCreate,
|
||||
ConsentRecordResponse,
|
||||
ConsentVerifyResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestConsentSchemas:
|
||||
def test_create_accept_all(self):
|
||||
record = ConsentRecordCreate(
|
||||
site_id=uuid.uuid4(),
|
||||
visitor_id="visitor-abc-123",
|
||||
action=ConsentAction.ACCEPT_ALL,
|
||||
categories_accepted=["necessary", "analytics", "marketing"],
|
||||
)
|
||||
assert record.action == "accept_all"
|
||||
assert len(record.categories_accepted) == 3
|
||||
assert record.categories_rejected is None
|
||||
|
||||
def test_create_custom(self):
|
||||
record = ConsentRecordCreate(
|
||||
site_id=uuid.uuid4(),
|
||||
visitor_id="visitor-xyz",
|
||||
action=ConsentAction.CUSTOM,
|
||||
categories_accepted=["necessary", "functional"],
|
||||
categories_rejected=["analytics", "marketing"],
|
||||
tc_string="COwQHgAAAAA",
|
||||
gcm_state={"analytics_storage": "denied", "ad_storage": "denied"},
|
||||
page_url="https://example.com/page",
|
||||
country_code="GB",
|
||||
region_code="GB-ENG",
|
||||
)
|
||||
assert record.action == "custom"
|
||||
assert record.tc_string == "COwQHgAAAAA"
|
||||
assert record.gcm_state["analytics_storage"] == "denied"
|
||||
|
||||
def test_create_reject_all(self):
|
||||
record = ConsentRecordCreate(
|
||||
site_id=uuid.uuid4(),
|
||||
visitor_id="v-1",
|
||||
action=ConsentAction.REJECT_ALL,
|
||||
categories_accepted=["necessary"],
|
||||
categories_rejected=["analytics", "marketing", "functional"],
|
||||
)
|
||||
assert record.action == "reject_all"
|
||||
|
||||
def test_empty_visitor_id_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ConsentRecordCreate(
|
||||
site_id=uuid.uuid4(),
|
||||
visitor_id="",
|
||||
action=ConsentAction.ACCEPT_ALL,
|
||||
categories_accepted=["necessary"],
|
||||
)
|
||||
|
||||
def test_invalid_action_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ConsentRecordCreate(
|
||||
site_id=uuid.uuid4(),
|
||||
visitor_id="v-1",
|
||||
action="invalid_action",
|
||||
categories_accepted=[],
|
||||
)
|
||||
|
||||
def test_response_from_attributes(self):
|
||||
resp = ConsentRecordResponse(
|
||||
id=uuid.uuid4(),
|
||||
site_id=uuid.uuid4(),
|
||||
visitor_id="v-1",
|
||||
action="accept_all",
|
||||
categories_accepted=["necessary"],
|
||||
categories_rejected=None,
|
||||
tc_string=None,
|
||||
gcm_state=None,
|
||||
page_url=None,
|
||||
country_code=None,
|
||||
region_code=None,
|
||||
consented_at="2026-01-01T00:00:00Z",
|
||||
)
|
||||
assert resp.action == "accept_all"
|
||||
|
||||
def test_verify_response(self):
|
||||
resp = ConsentVerifyResponse(
|
||||
id=uuid.uuid4(),
|
||||
site_id=uuid.uuid4(),
|
||||
visitor_id="v-1",
|
||||
action="accept_all",
|
||||
categories_accepted=["necessary"],
|
||||
consented_at="2026-01-01T00:00:00Z",
|
||||
)
|
||||
assert resp.valid is True
|
||||
|
||||
|
||||
class TestConsentActions:
|
||||
def test_action_values(self):
|
||||
assert ConsentAction.ACCEPT_ALL == "accept_all"
|
||||
assert ConsentAction.REJECT_ALL == "reject_all"
|
||||
assert ConsentAction.CUSTOM == "custom"
|
||||
assert ConsentAction.WITHDRAW == "withdraw"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestConsentRoutesRegistered:
|
||||
async def test_consent_routes_exist(self, client):
|
||||
response = await client.get("/openapi.json")
|
||||
paths = response.json()["paths"]
|
||||
assert "/api/v1/consent/" in paths
|
||||
assert "/api/v1/consent/{consent_id}" in paths
|
||||
assert "/api/v1/consent/verify/{consent_id}" in paths
|
||||
|
||||
async def test_consent_post_validates_body(self, client):
|
||||
"""POST /consent rejects invalid payloads."""
|
||||
response = await client.post(
|
||||
"/api/v1/consent/",
|
||||
json={"invalid": "body"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_config_public_endpoint_exists(self, client):
|
||||
response = await client.get("/openapi.json")
|
||||
paths = response.json()["paths"]
|
||||
assert "/api/v1/config/sites/{site_id}" in paths
|
||||
218
apps/api/tests/test_cookies.py
Normal file
218
apps/api/tests/test_cookies.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Tests for cookie category, cookie, and allow-list schemas and routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.schemas.cookie import (
|
||||
AllowListEntryCreate,
|
||||
AllowListEntryUpdate,
|
||||
CookieCategoryResponse,
|
||||
CookieCreate,
|
||||
CookieResponse,
|
||||
CookieUpdate,
|
||||
ReviewStatus,
|
||||
StorageType,
|
||||
)
|
||||
|
||||
# ─── Schema tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStorageType:
|
||||
def test_values(self):
|
||||
assert StorageType.cookie == "cookie"
|
||||
assert StorageType.local_storage == "local_storage"
|
||||
assert StorageType.session_storage == "session_storage"
|
||||
assert StorageType.indexed_db == "indexed_db"
|
||||
|
||||
|
||||
class TestReviewStatus:
|
||||
def test_values(self):
|
||||
assert ReviewStatus.pending == "pending"
|
||||
assert ReviewStatus.approved == "approved"
|
||||
assert ReviewStatus.rejected == "rejected"
|
||||
|
||||
|
||||
class TestCookieCreate:
|
||||
def test_valid_minimal(self):
|
||||
schema = CookieCreate(name="_ga", domain=".example.com")
|
||||
assert schema.name == "_ga"
|
||||
assert schema.domain == ".example.com"
|
||||
assert schema.storage_type == StorageType.cookie
|
||||
assert schema.category_id is None
|
||||
|
||||
def test_valid_full(self):
|
||||
cat_id = uuid.uuid4()
|
||||
schema = CookieCreate(
|
||||
name="_ga",
|
||||
domain=".google.com",
|
||||
storage_type=StorageType.cookie,
|
||||
category_id=cat_id,
|
||||
description="Google Analytics cookie",
|
||||
vendor="Google",
|
||||
path="/",
|
||||
max_age_seconds=63072000,
|
||||
is_http_only=False,
|
||||
is_secure=True,
|
||||
same_site="Lax",
|
||||
)
|
||||
assert schema.category_id == cat_id
|
||||
assert schema.max_age_seconds == 63072000
|
||||
|
||||
def test_rejects_empty_name(self):
|
||||
with pytest.raises(ValidationError):
|
||||
CookieCreate(name="", domain=".example.com")
|
||||
|
||||
def test_rejects_empty_domain(self):
|
||||
with pytest.raises(ValidationError):
|
||||
CookieCreate(name="_ga", domain="")
|
||||
|
||||
|
||||
class TestCookieUpdate:
|
||||
def test_partial_update(self):
|
||||
schema = CookieUpdate(review_status=ReviewStatus.approved)
|
||||
dump = schema.model_dump(exclude_unset=True)
|
||||
assert dump == {"review_status": ReviewStatus.approved}
|
||||
|
||||
def test_update_category(self):
|
||||
cat_id = uuid.uuid4()
|
||||
schema = CookieUpdate(category_id=cat_id)
|
||||
assert schema.category_id == cat_id
|
||||
|
||||
|
||||
class TestAllowListEntryCreate:
|
||||
def test_valid(self):
|
||||
cat_id = uuid.uuid4()
|
||||
schema = AllowListEntryCreate(
|
||||
name_pattern="_ga*",
|
||||
domain_pattern=".google.com",
|
||||
category_id=cat_id,
|
||||
description="Google Analytics cookies",
|
||||
)
|
||||
assert schema.name_pattern == "_ga*"
|
||||
assert schema.category_id == cat_id
|
||||
|
||||
def test_rejects_empty_name_pattern(self):
|
||||
with pytest.raises(ValidationError):
|
||||
AllowListEntryCreate(
|
||||
name_pattern="",
|
||||
domain_pattern=".example.com",
|
||||
category_id=uuid.uuid4(),
|
||||
)
|
||||
|
||||
|
||||
class TestAllowListEntryUpdate:
|
||||
def test_partial_update(self):
|
||||
schema = AllowListEntryUpdate(description="Updated description")
|
||||
dump = schema.model_dump(exclude_unset=True)
|
||||
assert dump == {"description": "Updated description"}
|
||||
|
||||
|
||||
class TestCookieCategoryResponse:
|
||||
def test_from_dict(self):
|
||||
now = "2024-01-01T00:00:00"
|
||||
resp = CookieCategoryResponse(
|
||||
id=uuid.uuid4(),
|
||||
name="Analytics",
|
||||
slug="analytics",
|
||||
description="Analytics cookies",
|
||||
is_essential=False,
|
||||
display_order=2,
|
||||
tcf_purpose_ids=[1, 3],
|
||||
gcm_consent_types=["analytics_storage"],
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
assert resp.slug == "analytics"
|
||||
assert resp.is_essential is False
|
||||
|
||||
|
||||
class TestCookieResponse:
|
||||
def test_from_dict(self):
|
||||
now = "2024-01-01T00:00:00"
|
||||
resp = CookieResponse(
|
||||
id=uuid.uuid4(),
|
||||
site_id=uuid.uuid4(),
|
||||
name="_ga",
|
||||
domain=".google.com",
|
||||
storage_type="cookie",
|
||||
review_status="pending",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
assert resp.name == "_ga"
|
||||
assert resp.review_status == "pending"
|
||||
|
||||
|
||||
# ─── Route tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCookieCategoryRoutes:
|
||||
def test_categories_route_registered(self, app):
|
||||
"""Verify the categories routes are registered in the app."""
|
||||
routes = [r.path for r in app.routes]
|
||||
assert "/api/v1/cookies/categories" in routes
|
||||
assert "/api/v1/cookies/categories/{category_id}" in routes
|
||||
|
||||
|
||||
class TestCookieRoutes:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_cookies_requires_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
resp = await client.get(f"/api/v1/cookies/sites/{site_id}")
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_cookie_requires_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
json={"name": "_ga", "domain": ".google.com"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_cookie_rejects_invalid_body(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
json={"name": "", "domain": ""},
|
||||
headers={"Authorization": "Bearer fake-token"},
|
||||
)
|
||||
# Should return 401 (bad token) or 422 (validation)
|
||||
assert resp.status_code in (401, 422)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_summary_route_requires_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
resp = await client.get(f"/api/v1/cookies/sites/{site_id}/summary")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
class TestAllowListRoutes:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_allow_list_requires_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
resp = await client.get(f"/api/v1/cookies/sites/{site_id}/allow-list")
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_allow_list_requires_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/allow-list",
|
||||
json={
|
||||
"name_pattern": "_ga*",
|
||||
"domain_pattern": ".google.com",
|
||||
"category_id": str(uuid.uuid4()),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_allow_list_requires_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
entry_id = uuid.uuid4()
|
||||
resp = await client.delete(f"/api/v1/cookies/sites/{site_id}/allow-list/{entry_id}")
|
||||
assert resp.status_code == 401
|
||||
222
apps/api/tests/test_cors.py
Normal file
222
apps/api/tests/test_cors.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Tests for the dynamic CORS origin validation service."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services.cors import extract_domain_from_origin, get_allowed_domains, is_origin_allowed
|
||||
|
||||
|
||||
class TestExtractDomainFromOrigin:
|
||||
def test_https_origin(self):
|
||||
assert extract_domain_from_origin("https://example.com") == "example.com"
|
||||
|
||||
def test_http_origin(self):
|
||||
assert extract_domain_from_origin("http://example.com") == "example.com"
|
||||
|
||||
def test_origin_with_port(self):
|
||||
assert extract_domain_from_origin("https://example.com:443") == "example.com"
|
||||
|
||||
def test_origin_with_subdomain(self):
|
||||
assert extract_domain_from_origin("https://www.example.com") == "www.example.com"
|
||||
|
||||
def test_localhost(self):
|
||||
assert extract_domain_from_origin("http://localhost:5173") == "localhost"
|
||||
|
||||
def test_empty_string(self):
|
||||
assert extract_domain_from_origin("") is None
|
||||
|
||||
def test_invalid_url(self):
|
||||
# urlparse is lenient, but hostname may be None for really bad input
|
||||
result = extract_domain_from_origin("not-a-url")
|
||||
# urlparse("not-a-url") sets hostname to None
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsOriginAllowed:
|
||||
def test_static_origin_exact_match(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"http://localhost:5173",
|
||||
["http://localhost:5173"],
|
||||
set(),
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_static_origin_no_match(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://evil.com",
|
||||
["http://localhost:5173"],
|
||||
set(),
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_wildcard_allows_everything(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://anything.com",
|
||||
["*"],
|
||||
set(),
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_registered_domain_match(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://example.com",
|
||||
[],
|
||||
{"example.com", "other.com"},
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_registered_domain_case_insensitive(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://Example.COM",
|
||||
[],
|
||||
{"example.com"},
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_registered_domain_no_match(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://evil.com",
|
||||
[],
|
||||
{"example.com"},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_static_takes_priority(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"http://localhost:5173",
|
||||
["http://localhost:5173"],
|
||||
{"example.com"},
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_origin_with_port_matches_domain(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://example.com:8443",
|
||||
[],
|
||||
{"example.com"},
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_subdomain_matches_if_registered(self):
|
||||
# www.example.com only matches if explicitly registered
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://www.example.com",
|
||||
[],
|
||||
{"example.com"},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_subdomain_matches_when_registered(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://www.example.com",
|
||||
[],
|
||||
{"www.example.com"},
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_empty_origin(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"",
|
||||
[],
|
||||
{"example.com"},
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_empty_lists(self):
|
||||
assert (
|
||||
is_origin_allowed(
|
||||
"https://example.com",
|
||||
[],
|
||||
set(),
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
class TestGetAllowedDomains:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_primary_domains(self):
|
||||
row1 = MagicMock()
|
||||
row1.domain = "example.com"
|
||||
row1.additional_domains = None
|
||||
|
||||
row2 = MagicMock()
|
||||
row2.domain = "other.com"
|
||||
row2.additional_domains = None
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.all.return_value = [row1, row2]
|
||||
|
||||
db = AsyncMock()
|
||||
db.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
domains = await get_allowed_domains(db)
|
||||
assert "example.com" in domains
|
||||
assert "other.com" in domains
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_includes_additional_domains(self):
|
||||
row = MagicMock()
|
||||
row.domain = "example.com"
|
||||
row.additional_domains = ["www.example.com", "app.example.com"]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.all.return_value = [row]
|
||||
|
||||
db = AsyncMock()
|
||||
db.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
domains = await get_allowed_domains(db)
|
||||
assert "example.com" in domains
|
||||
assert "www.example.com" in domains
|
||||
assert "app.example.com" in domains
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lowercases_domains(self):
|
||||
row = MagicMock()
|
||||
row.domain = "Example.COM"
|
||||
row.additional_domains = ["WWW.Example.COM"]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.all.return_value = [row]
|
||||
|
||||
db = AsyncMock()
|
||||
db.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
domains = await get_allowed_domains(db)
|
||||
assert "example.com" in domains
|
||||
assert "www.example.com" in domains
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_result(self):
|
||||
mock_result = MagicMock()
|
||||
mock_result.all.return_value = []
|
||||
|
||||
db = AsyncMock()
|
||||
db.execute = AsyncMock(return_value=mock_result)
|
||||
|
||||
domains = await get_allowed_domains(db)
|
||||
assert domains == set()
|
||||
88
apps/api/tests/test_dependencies.py
Normal file
88
apps/api/tests/test_dependencies.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Unit tests for auth dependencies."""
|
||||
|
||||
import uuid
|
||||
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.services.auth import create_access_token, create_refresh_token, decode_token
|
||||
|
||||
|
||||
class TestCurrentUser:
|
||||
def test_has_role_matching(self):
|
||||
user = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
role="admin",
|
||||
)
|
||||
assert user.has_role("admin", "owner") is True
|
||||
|
||||
def test_has_role_not_matching(self):
|
||||
user = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
role="viewer",
|
||||
)
|
||||
assert user.has_role("admin", "owner") is False
|
||||
|
||||
def test_is_admin_property(self):
|
||||
user = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
role="admin",
|
||||
)
|
||||
assert user.is_admin is True
|
||||
|
||||
def test_is_admin_owner(self):
|
||||
user = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
role="owner",
|
||||
)
|
||||
assert user.is_admin is True
|
||||
|
||||
def test_is_admin_viewer(self):
|
||||
user = CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="test@test.com",
|
||||
role="viewer",
|
||||
)
|
||||
assert user.is_admin is False
|
||||
|
||||
|
||||
class TestTokenCreation:
|
||||
def test_access_token_roundtrip(self):
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
user_id=user_id,
|
||||
organisation_id=org_id,
|
||||
role="editor",
|
||||
email="test@test.com",
|
||||
)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == str(user_id)
|
||||
assert payload["org_id"] == str(org_id)
|
||||
assert payload["role"] == "editor"
|
||||
assert payload["type"] == "access"
|
||||
|
||||
def test_refresh_token_roundtrip(self):
|
||||
user_id = uuid.uuid4()
|
||||
org_id = uuid.uuid4()
|
||||
token = create_refresh_token(user_id=user_id, organisation_id=org_id)
|
||||
payload = decode_token(token)
|
||||
assert payload["sub"] == str(user_id)
|
||||
assert payload["type"] == "refresh"
|
||||
|
||||
def test_access_token_is_not_refresh(self):
|
||||
token = create_access_token(
|
||||
user_id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
role="viewer",
|
||||
email="test@test.com",
|
||||
)
|
||||
payload = decode_token(token)
|
||||
assert payload["type"] != "refresh"
|
||||
141
apps/api/tests/test_extensions.py
Normal file
141
apps/api/tests/test_extensions.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Tests for the extension registry and edition detection."""
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
from src.config.edition import edition_name, is_ee
|
||||
from src.extensions.registry import (
|
||||
ExtensionRegistry,
|
||||
OpenAPITag,
|
||||
discover_extensions,
|
||||
get_registry,
|
||||
)
|
||||
|
||||
# -- Edition detection -------------------------------------------------------
|
||||
|
||||
|
||||
class TestEditionDetection:
|
||||
"""The ``is_ee()`` / ``edition_name()`` helpers should return a
|
||||
consistent pair regardless of which edition is installed. Core tests
|
||||
don't assume a specific edition — that's checked in each repo's
|
||||
own integration tests."""
|
||||
|
||||
def test_edition_name_matches_is_ee(self):
|
||||
assert edition_name() == ("ee" if is_ee() else "ce")
|
||||
|
||||
def test_edition_name_is_valid(self):
|
||||
assert edition_name() in ("ce", "ee")
|
||||
|
||||
|
||||
# -- Extension registry (unit) ----------------------------------------------
|
||||
|
||||
|
||||
class TestExtensionRegistry:
|
||||
def _make_registry(self) -> ExtensionRegistry:
|
||||
return ExtensionRegistry()
|
||||
|
||||
def test_empty_registry(self):
|
||||
reg = self._make_registry()
|
||||
assert reg.routers == []
|
||||
assert reg.model_modules == []
|
||||
assert reg.startup_hooks == []
|
||||
|
||||
def test_add_router(self):
|
||||
reg = self._make_registry()
|
||||
router = APIRouter()
|
||||
reg.add_router(router, prefix="/api/v1")
|
||||
assert len(reg.routers) == 1
|
||||
assert reg.routers[0].router is router
|
||||
assert reg.routers[0].prefix == "/api/v1"
|
||||
|
||||
def test_add_router_with_tags(self):
|
||||
reg = self._make_registry()
|
||||
router = APIRouter()
|
||||
tag = OpenAPITag(name="billing", description="Billing endpoints")
|
||||
reg.add_router(router, tags=[tag])
|
||||
assert reg.routers[0].tags == [tag]
|
||||
|
||||
def test_add_model_module(self):
|
||||
reg = self._make_registry()
|
||||
reg.add_model_module("ee.api.src.models.billing")
|
||||
assert reg.model_modules == ["ee.api.src.models.billing"]
|
||||
|
||||
def test_add_startup_hook(self):
|
||||
reg = self._make_registry()
|
||||
|
||||
async def hook(app: FastAPI) -> None:
|
||||
pass
|
||||
|
||||
reg.add_startup_hook(hook)
|
||||
assert len(reg.startup_hooks) == 1
|
||||
|
||||
def test_apply_mounts_routers(self):
|
||||
reg = self._make_registry()
|
||||
router = APIRouter()
|
||||
|
||||
@router.get("/test")
|
||||
async def _test() -> dict[str, str]:
|
||||
return {"ok": True}
|
||||
|
||||
reg.add_router(router, prefix="/ext")
|
||||
|
||||
app = FastAPI()
|
||||
reg.apply(app)
|
||||
|
||||
# The router should be included in the app routes
|
||||
paths = [r.path for r in app.routes]
|
||||
assert "/ext/test" in paths
|
||||
|
||||
def test_apply_adds_openapi_tags(self):
|
||||
reg = self._make_registry()
|
||||
router = APIRouter()
|
||||
tag = OpenAPITag(name="billing", description="Billing endpoints")
|
||||
reg.add_router(router, tags=[tag])
|
||||
|
||||
app = FastAPI()
|
||||
app.openapi_tags = []
|
||||
reg.apply(app)
|
||||
|
||||
assert any(t["name"] == "billing" for t in app.openapi_tags)
|
||||
|
||||
def test_apply_skips_duplicate_tags(self):
|
||||
reg = self._make_registry()
|
||||
router = APIRouter()
|
||||
tag = OpenAPITag(name="billing", description="Billing endpoints")
|
||||
reg.add_router(router, tags=[tag])
|
||||
|
||||
app = FastAPI()
|
||||
app.openapi_tags = [{"name": "billing", "description": "Existing"}]
|
||||
reg.apply(app)
|
||||
|
||||
billing_tags = [t for t in app.openapi_tags if t["name"] == "billing"]
|
||||
assert len(billing_tags) == 1
|
||||
assert billing_tags[0]["description"] == "Existing"
|
||||
|
||||
|
||||
# -- discover_extensions -----------------------------------------------------
|
||||
|
||||
|
||||
class TestDiscoverExtensions:
|
||||
def test_discover_extensions_does_not_raise(self):
|
||||
"""discover_extensions should not raise regardless of edition."""
|
||||
discover_extensions()
|
||||
|
||||
|
||||
# -- Global registry ---------------------------------------------------------
|
||||
|
||||
|
||||
class TestGlobalRegistry:
|
||||
def test_get_registry_returns_singleton(self):
|
||||
assert get_registry() is get_registry()
|
||||
|
||||
|
||||
# -- Health endpoint with edition field --------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_reports_edition(client):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["edition"] in ("ce", "ee")
|
||||
573
apps/api/tests/test_geoip.py
Normal file
573
apps/api/tests/test_geoip.py
Normal file
@@ -0,0 +1,573 @@
|
||||
"""Tests for the GeoIP service.
|
||||
|
||||
Covers header-based detection, IP lookup, country-to-region mapping,
|
||||
client IP extraction, and the combined detect_region flow.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
import src.services.geoip as geoip_module
|
||||
from src.services.geoip import (
|
||||
GeoResult,
|
||||
_is_private_ip,
|
||||
country_to_region,
|
||||
detect_region,
|
||||
detect_region_from_headers,
|
||||
get_client_ip,
|
||||
lookup_ip_maxmind,
|
||||
lookup_ip_region,
|
||||
)
|
||||
|
||||
# ── country_to_region ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCountryToRegion:
|
||||
def test_eu_country_returns_eu(self):
|
||||
assert country_to_region("DE") == "EU"
|
||||
assert country_to_region("FR") == "EU"
|
||||
assert country_to_region("IT") == "EU"
|
||||
assert country_to_region("ES") == "EU"
|
||||
|
||||
def test_eu_country_case_insensitive(self):
|
||||
assert country_to_region("de") == "EU"
|
||||
assert country_to_region("fr") == "EU"
|
||||
|
||||
def test_gb_returns_gb(self):
|
||||
assert country_to_region("GB") == "GB"
|
||||
|
||||
def test_br_returns_br(self):
|
||||
assert country_to_region("BR") == "BR"
|
||||
|
||||
def test_us_without_state(self):
|
||||
assert country_to_region("US") == "US"
|
||||
|
||||
def test_us_with_state(self):
|
||||
assert country_to_region("US", "CA") == "US-CA"
|
||||
assert country_to_region("US", "ny") == "US-NY"
|
||||
|
||||
def test_non_eu_country_returned_as_is(self):
|
||||
assert country_to_region("JP") == "JP"
|
||||
assert country_to_region("AU") == "AU"
|
||||
assert country_to_region("CA") == "CA"
|
||||
|
||||
|
||||
# ── detect_region_from_headers ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestDetectRegionFromHeaders:
|
||||
def _make_request(self, headers: dict[str, str]) -> MagicMock:
|
||||
request = MagicMock()
|
||||
request.headers = headers
|
||||
return request
|
||||
|
||||
def test_cloudflare_header(self):
|
||||
request = self._make_request({"cf-ipcountry": "DE"})
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "DE"
|
||||
assert result.region == "EU"
|
||||
assert result.is_resolved is True
|
||||
|
||||
def test_vercel_header(self):
|
||||
request = self._make_request({"x-vercel-ip-country": "GB"})
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "GB"
|
||||
assert result.region == "GB"
|
||||
|
||||
def test_appengine_header(self):
|
||||
request = self._make_request({"x-appengine-country": "BR"})
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "BR"
|
||||
assert result.region == "BR"
|
||||
|
||||
def test_custom_header(self):
|
||||
request = self._make_request({"x-country-code": "JP"})
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "JP"
|
||||
assert result.region == "JP"
|
||||
|
||||
def test_no_geo_headers(self):
|
||||
request = self._make_request({})
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code is None
|
||||
assert result.region is None
|
||||
assert result.is_resolved is False
|
||||
|
||||
def test_ignores_xx_value(self):
|
||||
request = self._make_request({"cf-ipcountry": "XX"})
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.is_resolved is False
|
||||
|
||||
def test_header_priority_cloudflare_first(self):
|
||||
request = self._make_request(
|
||||
{
|
||||
"cf-ipcountry": "FR",
|
||||
"x-vercel-ip-country": "DE",
|
||||
}
|
||||
)
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "FR"
|
||||
|
||||
def test_case_normalisation(self):
|
||||
request = self._make_request({"cf-ipcountry": "gb"})
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "GB"
|
||||
assert result.region == "GB"
|
||||
|
||||
def test_configured_custom_header(self):
|
||||
"""An operator-configured header is honoured."""
|
||||
request = self._make_request({"x-gclb-country": "JP"})
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_country_header = "x-gclb-country"
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "JP"
|
||||
assert result.region == "JP"
|
||||
|
||||
def test_configured_custom_header_takes_priority(self):
|
||||
"""When both a custom and a built-in header are present, the
|
||||
custom one wins — that's the operator's explicit choice."""
|
||||
request = self._make_request(
|
||||
{
|
||||
"cf-ipcountry": "FR",
|
||||
"x-gclb-country": "JP",
|
||||
}
|
||||
)
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_country_header = "x-gclb-country"
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "JP"
|
||||
|
||||
def test_configured_header_falls_through_to_builtin(self):
|
||||
"""If the custom header isn't present, the built-in list still
|
||||
applies."""
|
||||
request = self._make_request({"cf-ipcountry": "FR"})
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_country_header = "x-gclb-country"
|
||||
mock_settings.return_value.geoip_region_header = None
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "FR"
|
||||
assert result.region == "EU"
|
||||
|
||||
def test_configured_region_header_pairs_with_country(self):
|
||||
"""A configured region header is paired with the custom country."""
|
||||
request = self._make_request(
|
||||
{
|
||||
"x-gclb-country": "US",
|
||||
"x-gclb-region": "CA",
|
||||
}
|
||||
)
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_country_header = "x-gclb-country"
|
||||
mock_settings.return_value.geoip_region_header = "x-gclb-region"
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "US"
|
||||
assert result.region == "US-CA"
|
||||
|
||||
def test_configured_region_header_strips_country_prefix(self):
|
||||
"""ISO 3166-2 subdivisions may arrive prefixed (``US-CA``)."""
|
||||
request = self._make_request(
|
||||
{
|
||||
"x-gclb-country": "US",
|
||||
"x-gclb-region": "US-NY",
|
||||
}
|
||||
)
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_country_header = "x-gclb-country"
|
||||
mock_settings.return_value.geoip_region_header = "x-gclb-region"
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.region == "US-NY"
|
||||
|
||||
def test_configured_region_header_missing_is_country_only(self):
|
||||
"""Only country hits region-aware path if the region header is absent."""
|
||||
request = self._make_request({"x-gclb-country": "US"})
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_country_header = "x-gclb-country"
|
||||
mock_settings.return_value.geoip_region_header = "x-gclb-region"
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.country_code == "US"
|
||||
assert result.region == "US"
|
||||
|
||||
def test_configured_region_header_xx_ignored(self):
|
||||
"""Region value of ``XX`` is treated as unknown."""
|
||||
request = self._make_request(
|
||||
{
|
||||
"x-gclb-country": "US",
|
||||
"x-gclb-region": "XX",
|
||||
}
|
||||
)
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_country_header = "x-gclb-country"
|
||||
mock_settings.return_value.geoip_region_header = "x-gclb-region"
|
||||
result = detect_region_from_headers(request)
|
||||
assert result.region == "US"
|
||||
|
||||
|
||||
# ── get_client_ip ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetClientIp:
|
||||
def _make_request(
|
||||
self,
|
||||
headers: dict[str, str] | None = None,
|
||||
client_host: str | None = None,
|
||||
) -> MagicMock:
|
||||
request = MagicMock()
|
||||
request.headers = headers or {}
|
||||
if client_host:
|
||||
request.client = MagicMock()
|
||||
request.client.host = client_host
|
||||
else:
|
||||
request.client = None
|
||||
return request
|
||||
|
||||
def test_x_forwarded_for_single(self):
|
||||
request = self._make_request({"x-forwarded-for": "1.2.3.4"})
|
||||
assert get_client_ip(request) == "1.2.3.4"
|
||||
|
||||
def test_x_forwarded_for_multiple(self):
|
||||
request = self._make_request({"x-forwarded-for": "1.2.3.4, 5.6.7.8, 9.10.11.12"})
|
||||
assert get_client_ip(request) == "1.2.3.4"
|
||||
|
||||
def test_x_real_ip(self):
|
||||
request = self._make_request({"x-real-ip": "1.2.3.4"})
|
||||
assert get_client_ip(request) == "1.2.3.4"
|
||||
|
||||
def test_forwarded_for_takes_priority_over_real_ip(self):
|
||||
request = self._make_request(
|
||||
{
|
||||
"x-forwarded-for": "1.1.1.1",
|
||||
"x-real-ip": "2.2.2.2",
|
||||
}
|
||||
)
|
||||
assert get_client_ip(request) == "1.1.1.1"
|
||||
|
||||
def test_falls_back_to_client_host(self):
|
||||
request = self._make_request(client_host="10.0.0.1")
|
||||
assert get_client_ip(request) == "10.0.0.1"
|
||||
|
||||
def test_returns_none_when_no_ip(self):
|
||||
request = self._make_request()
|
||||
assert get_client_ip(request) is None
|
||||
|
||||
|
||||
# ── _is_private_ip ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIsPrivateIp:
|
||||
def test_loopback(self):
|
||||
assert _is_private_ip("127.0.0.1") is True
|
||||
assert _is_private_ip("127.0.0.2") is True
|
||||
|
||||
def test_private_ranges(self):
|
||||
assert _is_private_ip("10.0.0.1") is True
|
||||
assert _is_private_ip("192.168.1.1") is True
|
||||
assert _is_private_ip("172.16.0.1") is True
|
||||
|
||||
def test_ipv6_loopback(self):
|
||||
assert _is_private_ip("::1") is True
|
||||
|
||||
def test_localhost_string(self):
|
||||
assert _is_private_ip("localhost") is True
|
||||
|
||||
def test_public_ip(self):
|
||||
assert _is_private_ip("8.8.8.8") is False
|
||||
assert _is_private_ip("1.1.1.1") is False
|
||||
|
||||
|
||||
# ── lookup_ip_region ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLookupIpRegion:
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_ip_returns_unresolved(self):
|
||||
result = await lookup_ip_region("127.0.0.1")
|
||||
assert result.is_resolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_private_ip_10_range(self):
|
||||
result = await lookup_ip_region("10.0.0.1")
|
||||
assert result.is_resolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_lookup(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"status": "success",
|
||||
"countryCode": "DE",
|
||||
"region": "BY",
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await lookup_ip_region("8.8.8.8")
|
||||
|
||||
assert result.country_code == "DE"
|
||||
assert result.region == "EU"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failed_status(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"status": "fail", "message": "invalid query"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await lookup_ip_region("8.8.8.8")
|
||||
|
||||
assert result.is_resolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_error(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await lookup_ip_region("8.8.8.8")
|
||||
|
||||
assert result.is_resolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_network_exception(self):
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await lookup_ip_region("8.8.8.8")
|
||||
|
||||
assert result.is_resolved is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_us_with_state_lookup(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"status": "success",
|
||||
"countryCode": "US",
|
||||
"region": "CA",
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await lookup_ip_region("8.8.8.8")
|
||||
|
||||
assert result.country_code == "US"
|
||||
assert result.region == "US-CA"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_country_code(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"status": "success"}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await lookup_ip_region("8.8.8.8")
|
||||
|
||||
assert result.is_resolved is False
|
||||
|
||||
|
||||
# ── detect_region (combined) ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDetectRegion:
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_headers_when_available(self):
|
||||
request = MagicMock()
|
||||
request.headers = {"cf-ipcountry": "FR"}
|
||||
request.client = None
|
||||
|
||||
result = await detect_region(request)
|
||||
assert result.country_code == "FR"
|
||||
assert result.region == "EU"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_ip_lookup(self):
|
||||
request = MagicMock()
|
||||
request.headers = {}
|
||||
request.client = MagicMock()
|
||||
request.client.host = "8.8.8.8"
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"status": "success",
|
||||
"countryCode": "US",
|
||||
"region": "CA",
|
||||
}
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.get = AsyncMock(return_value=mock_response)
|
||||
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
|
||||
mock_client.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await detect_region(request)
|
||||
|
||||
assert result.country_code == "US"
|
||||
assert result.region == "US-CA"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_unresolved_when_no_ip(self):
|
||||
request = MagicMock()
|
||||
request.headers = {}
|
||||
request.client = None
|
||||
|
||||
result = await detect_region(request)
|
||||
assert result.is_resolved is False
|
||||
|
||||
|
||||
# ── GeoResult ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGeoResult:
|
||||
def test_is_resolved_true(self):
|
||||
result = GeoResult(country_code="GB", region="GB")
|
||||
assert result.is_resolved is True
|
||||
|
||||
def test_is_resolved_false(self):
|
||||
result = GeoResult(country_code=None, region=None)
|
||||
assert result.is_resolved is False
|
||||
|
||||
def test_frozen_dataclass(self):
|
||||
result = GeoResult(country_code="GB", region="GB")
|
||||
with pytest.raises(AttributeError):
|
||||
result.country_code = "US" # type: ignore[misc]
|
||||
|
||||
|
||||
# ── MaxMind database lookup ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLookupIpMaxmind:
|
||||
def setup_method(self):
|
||||
# Reset the module-level cache so each test starts clean.
|
||||
geoip_module._maxmind_reader = None
|
||||
geoip_module._maxmind_initialised = False
|
||||
|
||||
def _mock_reader(self, country_iso: str | None, subdivision_iso: str | None):
|
||||
reader = MagicMock()
|
||||
response = MagicMock()
|
||||
response.country.iso_code = country_iso
|
||||
if subdivision_iso is None:
|
||||
response.subdivisions = None
|
||||
else:
|
||||
response.subdivisions.most_specific.iso_code = subdivision_iso
|
||||
reader.city.return_value = response
|
||||
return reader
|
||||
|
||||
def test_private_ip_returns_unresolved(self):
|
||||
result = lookup_ip_maxmind("10.0.0.1")
|
||||
assert result.is_resolved is False
|
||||
|
||||
def test_no_db_configured_returns_unresolved(self):
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_maxmind_db_path = None
|
||||
result = lookup_ip_maxmind("8.8.8.8")
|
||||
assert result.is_resolved is False
|
||||
|
||||
def test_successful_lookup_with_subdivision(self):
|
||||
reader = self._mock_reader("US", "CA")
|
||||
geoip_module._maxmind_reader = reader
|
||||
geoip_module._maxmind_initialised = True
|
||||
|
||||
result = lookup_ip_maxmind("8.8.8.8")
|
||||
assert result.country_code == "US"
|
||||
assert result.region == "US-CA"
|
||||
reader.city.assert_called_once_with("8.8.8.8")
|
||||
|
||||
def test_successful_lookup_without_subdivision(self):
|
||||
reader = self._mock_reader("DE", None)
|
||||
geoip_module._maxmind_reader = reader
|
||||
geoip_module._maxmind_initialised = True
|
||||
|
||||
result = lookup_ip_maxmind("8.8.8.8")
|
||||
assert result.country_code == "DE"
|
||||
assert result.region == "EU"
|
||||
|
||||
def test_reader_raises_returns_unresolved(self):
|
||||
reader = MagicMock()
|
||||
reader.city.side_effect = RuntimeError("corrupt db")
|
||||
geoip_module._maxmind_reader = reader
|
||||
geoip_module._maxmind_initialised = True
|
||||
|
||||
result = lookup_ip_maxmind("8.8.8.8")
|
||||
assert result.is_resolved is False
|
||||
|
||||
def test_reader_missing_country_returns_unresolved(self):
|
||||
reader = self._mock_reader(None, None)
|
||||
geoip_module._maxmind_reader = reader
|
||||
geoip_module._maxmind_initialised = True
|
||||
|
||||
result = lookup_ip_maxmind("8.8.8.8")
|
||||
assert result.is_resolved is False
|
||||
|
||||
def test_bad_db_path_is_cached_as_failure(self):
|
||||
"""A missing ``.mmdb`` file should not reopen on every request."""
|
||||
with patch("src.services.geoip.get_settings") as mock_settings:
|
||||
mock_settings.return_value.geoip_maxmind_db_path = "/nonexistent/geo.mmdb"
|
||||
r1 = lookup_ip_maxmind("8.8.8.8")
|
||||
r2 = lookup_ip_maxmind("1.1.1.1")
|
||||
assert r1.is_resolved is False
|
||||
assert r2.is_resolved is False
|
||||
assert geoip_module._maxmind_initialised is True
|
||||
assert geoip_module._maxmind_reader is None
|
||||
|
||||
|
||||
class TestDetectRegionMaxmind:
|
||||
def setup_method(self):
|
||||
geoip_module._maxmind_reader = None
|
||||
geoip_module._maxmind_initialised = False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uses_maxmind_before_external_api(self):
|
||||
"""With MaxMind configured, ip-api.com must not be called."""
|
||||
reader = MagicMock()
|
||||
response = MagicMock()
|
||||
response.country.iso_code = "GB"
|
||||
response.subdivisions.most_specific.iso_code = "SCT"
|
||||
reader.city.return_value = response
|
||||
geoip_module._maxmind_reader = reader
|
||||
geoip_module._maxmind_initialised = True
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"x-forwarded-for": "8.8.8.8"}
|
||||
request.client = None
|
||||
|
||||
with (
|
||||
patch("src.services.geoip.get_settings") as mock_settings,
|
||||
patch("src.services.geoip.httpx.AsyncClient") as mock_httpx,
|
||||
):
|
||||
mock_settings.return_value.geoip_country_header = None
|
||||
mock_settings.return_value.geoip_region_header = None
|
||||
mock_settings.return_value.geoip_maxmind_db_path = "/data/GeoLite2-City.mmdb"
|
||||
|
||||
result = await detect_region(request)
|
||||
|
||||
assert result.country_code == "GB"
|
||||
assert result.region == "GB-SCT"
|
||||
mock_httpx.assert_not_called()
|
||||
31
apps/api/tests/test_health.py
Normal file
31
apps/api/tests/test_health.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint(client):
|
||||
response = await client.get("/health")
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["edition"] in ("ce", "ee")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openapi_schema(client):
|
||||
response = await client.get("/openapi.json")
|
||||
assert response.status_code == 200
|
||||
schema = response.json()
|
||||
assert schema["info"]["title"] == "ConsentOS API"
|
||||
assert schema["info"]["version"] == "0.1.0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_routes_registered(client):
|
||||
response = await client.get("/openapi.json")
|
||||
paths = response.json()["paths"]
|
||||
assert "/health" in paths
|
||||
assert "/api/v1/auth/login" in paths
|
||||
assert "/api/v1/config/sites/{site_id}" in paths
|
||||
assert "/api/v1/consent/" in paths
|
||||
assert "/api/v1/scanner/scans" in paths
|
||||
assert "/api/v1/compliance/check/{site_id}" in paths
|
||||
89
apps/api/tests/test_integration_auth.py
Normal file
89
apps/api/tests/test_integration_auth.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Integration tests for authentication endpoints (requires database)."""
|
||||
|
||||
from tests.conftest import requires_db
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestAuthLogin:
|
||||
async def test_login_success(self, db_client, test_user):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": test_user.email,
|
||||
"password": "TestPassword123",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
async def test_login_wrong_password(self, db_client, test_user):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": test_user.email,
|
||||
"password": "wrong",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_login_nonexistent_user(self, db_client):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": "nobody@test.com",
|
||||
"password": "anything",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_login_invalid_email(self, db_client):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "not-an-email", "password": "anything"},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestAuthMe:
|
||||
async def test_me_returns_user(self, db_client, auth_headers, test_user):
|
||||
resp = await db_client.get("/api/v1/auth/me", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["email"] == test_user.email
|
||||
assert data["role"] == "owner"
|
||||
|
||||
async def test_me_without_token(self, db_client):
|
||||
resp = await db_client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestAuthRefresh:
|
||||
async def test_refresh_returns_new_tokens(self, db_client, test_user):
|
||||
# First login to get a refresh token
|
||||
login_resp = await db_client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={
|
||||
"email": test_user.email,
|
||||
"password": "TestPassword123",
|
||||
},
|
||||
)
|
||||
refresh_token = login_resp.json()["refresh_token"]
|
||||
|
||||
resp = await db_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert "access_token" in resp.json()
|
||||
|
||||
async def test_refresh_with_invalid_token(self, db_client):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid-token"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
152
apps/api/tests/test_integration_consent.py
Normal file
152
apps/api/tests/test_integration_consent.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""Integration tests for consent recording endpoints (requires database)."""
|
||||
|
||||
import uuid
|
||||
|
||||
from tests.conftest import create_test_site, requires_db
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestConsentEndpoints:
|
||||
async def test_record_consent(self, db_client, auth_headers):
|
||||
"""POST /consent/ is public (no auth) — used by the banner."""
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent")
|
||||
resp = await db_client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"visitor_id": str(uuid.uuid4()),
|
||||
"action": "accept_all",
|
||||
"categories_accepted": [
|
||||
"necessary",
|
||||
"functional",
|
||||
"analytics",
|
||||
"marketing",
|
||||
"personalisation",
|
||||
],
|
||||
"categories_rejected": [],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["action"] == "accept_all"
|
||||
assert "id" in data
|
||||
|
||||
async def test_record_consent_reject_all(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-rej")
|
||||
resp = await db_client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"visitor_id": str(uuid.uuid4()),
|
||||
"action": "reject_all",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": [
|
||||
"functional",
|
||||
"analytics",
|
||||
"marketing",
|
||||
"personalisation",
|
||||
],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["action"] == "reject_all"
|
||||
|
||||
async def test_record_consent_custom(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-cust")
|
||||
resp = await db_client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"visitor_id": str(uuid.uuid4()),
|
||||
"action": "custom",
|
||||
"categories_accepted": [
|
||||
"necessary",
|
||||
"analytics",
|
||||
],
|
||||
"categories_rejected": [
|
||||
"functional",
|
||||
"marketing",
|
||||
"personalisation",
|
||||
],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
assert resp.json()["action"] == "custom"
|
||||
|
||||
async def test_get_consent_record(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-get")
|
||||
# Create a consent record
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"visitor_id": str(uuid.uuid4()),
|
||||
"action": "accept_all",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": [],
|
||||
},
|
||||
)
|
||||
consent_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/consent/{consent_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["id"] == consent_id
|
||||
|
||||
async def test_get_consent_requires_auth(self, db_client):
|
||||
"""Reading a consent record without auth must be rejected."""
|
||||
resp = await db_client.get(f"/api/v1/consent/{uuid.uuid4()}")
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_verify_consent(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-ver")
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"visitor_id": str(uuid.uuid4()),
|
||||
"action": "accept_all",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": [],
|
||||
},
|
||||
)
|
||||
consent_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/consent/verify/{consent_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is True
|
||||
assert str(data["id"]) == consent_id
|
||||
|
||||
async def test_get_consent_not_found(self, db_client, auth_headers):
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/consent/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_verify_consent_not_found(self, db_client, auth_headers):
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/consent/verify/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_record_consent_invalid_action(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-inv")
|
||||
resp = await db_client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"visitor_id": str(uuid.uuid4()),
|
||||
"action": "invalid_action",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": [],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
169
apps/api/tests/test_integration_cookies.py
Normal file
169
apps/api/tests/test_integration_cookies.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Integration tests for cookie and allow-list endpoints (requires database)."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import create_test_site, requires_db
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestCookieCategoriesIntegration:
|
||||
async def test_list_categories_with_db(self, db_client):
|
||||
"""Categories are seeded by migration; verify the endpoint."""
|
||||
resp = await db_client.get("/api/v1/cookies/categories")
|
||||
assert resp.status_code == 200
|
||||
categories = resp.json()
|
||||
assert isinstance(categories, list)
|
||||
# Should have at least the 5 seeded categories
|
||||
slugs = {c["slug"] for c in categories}
|
||||
assert "necessary" in slugs
|
||||
assert "analytics" in slugs
|
||||
|
||||
async def test_get_category_by_id(self, db_client):
|
||||
cats_resp = await db_client.get("/api/v1/cookies/categories")
|
||||
if cats_resp.status_code == 200 and cats_resp.json():
|
||||
cat_id = cats_resp.json()[0]["id"]
|
||||
resp = await db_client.get(f"/api/v1/cookies/categories/{cat_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_get_category_not_found(self, db_client):
|
||||
resp = await db_client.get(f"/api/v1/cookies/categories/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestCookieCRUDIntegration:
|
||||
async def test_list_cookies_empty(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-empty")
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
async def test_create_and_list_cookie(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-create")
|
||||
create_resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
json={"name": "_ga", "domain": ".google.com"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert create_resp.status_code == 201
|
||||
cookie = create_resp.json()
|
||||
assert cookie["name"] == "_ga"
|
||||
assert cookie["review_status"] == "pending"
|
||||
|
||||
# Should now appear in list
|
||||
list_resp = await db_client.get(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert len(list_resp.json()) >= 1
|
||||
|
||||
async def test_update_cookie_review_status(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-upd")
|
||||
create_resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
json={"name": "_fbp", "domain": ".facebook.com"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
cookie_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.patch(
|
||||
f"/api/v1/cookies/sites/{site_id}/{cookie_id}",
|
||||
json={"review_status": "approved"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["review_status"] == "approved"
|
||||
|
||||
async def test_delete_cookie(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-del")
|
||||
create_resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
json={"name": "_del_test", "domain": ".test.com"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
cookie_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.delete(
|
||||
f"/api/v1/cookies/sites/{site_id}/{cookie_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_cookie_summary(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-sum")
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/cookies/sites/{site_id}/summary",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "total" in data
|
||||
assert "by_status" in data
|
||||
assert "uncategorised" in data
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestAllowListIntegration:
|
||||
async def _get_category_id(self, db_client):
|
||||
"""Fetch the first available cookie category ID."""
|
||||
resp = await db_client.get("/api/v1/cookies/categories")
|
||||
categories = resp.json()
|
||||
if categories:
|
||||
return categories[0]["id"]
|
||||
return None
|
||||
|
||||
async def test_create_allow_list_entry(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="allow-create")
|
||||
category_id = await self._get_category_id(db_client)
|
||||
if not category_id:
|
||||
pytest.skip("No categories seeded")
|
||||
|
||||
resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/allow-list",
|
||||
json={
|
||||
"name_pattern": "_ga*",
|
||||
"domain_pattern": ".google.com",
|
||||
"category_id": category_id,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name_pattern"] == "_ga*"
|
||||
|
||||
async def test_list_allow_list(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="allow-list")
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/cookies/sites/{site_id}/allow-list",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
|
||||
async def test_delete_allow_list_entry(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="allow-del")
|
||||
category_id = await self._get_category_id(db_client)
|
||||
if not category_id:
|
||||
pytest.skip("No categories seeded")
|
||||
|
||||
create_resp = await db_client.post(
|
||||
f"/api/v1/cookies/sites/{site_id}/allow-list",
|
||||
json={
|
||||
"name_pattern": "_del_test*",
|
||||
"domain_pattern": ".test.com",
|
||||
"category_id": category_id,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
entry_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.delete(
|
||||
f"/api/v1/cookies/sites/{site_id}/allow-list/{entry_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
173
apps/api/tests/test_integration_orgs_users.py
Normal file
173
apps/api/tests/test_integration_orgs_users.py
Normal file
@@ -0,0 +1,173 @@
|
||||
"""Integration tests for organisation and user endpoints (requires database)."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import requires_db
|
||||
|
||||
_BOOTSTRAP_TOKEN = "test-bootstrap-token-xyz"
|
||||
_BOOTSTRAP_HEADERS = {"X-Admin-Bootstrap-Token": _BOOTSTRAP_TOKEN}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _bootstrap_enabled(monkeypatch):
|
||||
"""Configure ``admin_bootstrap_token`` on the cached settings object."""
|
||||
from src.config.settings import get_settings
|
||||
|
||||
monkeypatch.setattr(get_settings(), "admin_bootstrap_token", _BOOTSTRAP_TOKEN)
|
||||
yield
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestOrganisationEndpoints:
|
||||
async def test_create_org(self, db_client, _bootstrap_enabled):
|
||||
slug = f"new-org-{uuid.uuid4().hex[:8]}"
|
||||
resp = await db_client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "New Org", "slug": slug},
|
||||
headers=_BOOTSTRAP_HEADERS,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "New Org"
|
||||
assert data["slug"] == slug
|
||||
assert "id" in data
|
||||
|
||||
async def test_create_org_duplicate_slug(self, db_client, _bootstrap_enabled):
|
||||
slug = f"dup-org-{uuid.uuid4().hex[:8]}"
|
||||
await db_client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "Dup Org", "slug": slug},
|
||||
headers=_BOOTSTRAP_HEADERS,
|
||||
)
|
||||
resp = await db_client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "Dup Org 2", "slug": slug},
|
||||
headers=_BOOTSTRAP_HEADERS,
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_get_my_org(self, db_client, auth_headers, test_org):
|
||||
resp = await db_client.get("/api/v1/organisations/me", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["slug"] == test_org.slug
|
||||
|
||||
async def test_update_my_org(self, db_client, auth_headers):
|
||||
resp = await db_client.patch(
|
||||
"/api/v1/organisations/me",
|
||||
json={"name": "Updated Org Name"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["name"] == "Updated Org Name"
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestUserEndpoints:
|
||||
async def test_list_users(self, db_client, auth_headers):
|
||||
resp = await db_client.get("/api/v1/users/", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
assert len(resp.json()) >= 1 # At least the test user
|
||||
|
||||
async def test_create_user(self, db_client, auth_headers):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": f"new-{uuid.uuid4().hex[:8]}@test.com",
|
||||
"password": "SecurePass123",
|
||||
"full_name": "New User",
|
||||
"role": "editor",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["role"] == "editor"
|
||||
|
||||
async def test_create_user_duplicate_email(self, db_client, auth_headers):
|
||||
email = f"dup-{uuid.uuid4().hex[:8]}@test.com"
|
||||
await db_client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "SecurePass123",
|
||||
"full_name": "Dup User",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
resp = await db_client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": email,
|
||||
"password": "SecurePass123",
|
||||
"full_name": "Dup User",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_get_user(self, db_client, auth_headers, test_user):
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/users/{test_user.id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["email"] == test_user.email
|
||||
|
||||
async def test_get_user_not_found(self, db_client, auth_headers):
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/users/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_update_user(self, db_client, auth_headers):
|
||||
# Create a user to update
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": f"upd-{uuid.uuid4().hex[:8]}@test.com",
|
||||
"password": "SecurePass123",
|
||||
"full_name": "Update User",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
user_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.patch(
|
||||
f"/api/v1/users/{user_id}",
|
||||
json={
|
||||
"full_name": "Updated Name",
|
||||
"role": "editor",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["full_name"] == "Updated Name"
|
||||
assert resp.json()["role"] == "editor"
|
||||
|
||||
async def test_delete_user(self, db_client, auth_headers):
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": f"del-{uuid.uuid4().hex[:8]}@test.com",
|
||||
"password": "SecurePass123",
|
||||
"full_name": "Delete User",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
user_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.delete(f"/api/v1/users/{user_id}", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_users_require_auth(self, db_client):
|
||||
resp = await db_client.get("/api/v1/users/")
|
||||
assert resp.status_code == 401
|
||||
192
apps/api/tests/test_integration_sites.py
Normal file
192
apps/api/tests/test_integration_sites.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""Integration tests for site and site config endpoints (requires database)."""
|
||||
|
||||
import uuid
|
||||
|
||||
from tests.conftest import requires_db
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestSiteCRUD:
|
||||
async def test_create_site(self, db_client, auth_headers):
|
||||
domain = f"example-{uuid.uuid4().hex[:8]}.com"
|
||||
resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Example Site",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["domain"] == domain
|
||||
assert data["display_name"] == "Example Site"
|
||||
assert data["is_active"] is True
|
||||
assert "id" in data
|
||||
|
||||
async def test_create_site_duplicate_domain(self, db_client, auth_headers):
|
||||
domain = f"dup-{uuid.uuid4().hex[:8]}.com"
|
||||
# Create first
|
||||
await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Dup Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
# Duplicate should fail
|
||||
resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Dup Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
async def test_list_sites(self, db_client, auth_headers):
|
||||
resp = await db_client.get("/api/v1/sites/", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert isinstance(resp.json(), list)
|
||||
|
||||
async def test_get_site(self, db_client, auth_headers):
|
||||
# Create a site first
|
||||
domain = f"get-{uuid.uuid4().hex[:8]}.com"
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Get Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
site_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.get(f"/api/v1/sites/{site_id}", headers=auth_headers)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["domain"] == domain
|
||||
|
||||
async def test_get_site_not_found(self, db_client, auth_headers):
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/sites/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_update_site(self, db_client, auth_headers):
|
||||
domain = f"update-{uuid.uuid4().hex[:8]}.com"
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Update Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
site_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.patch(
|
||||
f"/api/v1/sites/{site_id}",
|
||||
json={"display_name": "Updated Name"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["display_name"] == "Updated Name"
|
||||
|
||||
async def test_delete_site_soft_deletes(self, db_client, auth_headers):
|
||||
domain = f"delete-{uuid.uuid4().hex[:8]}.com"
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Delete Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
site_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.delete(f"/api/v1/sites/{site_id}", headers=auth_headers)
|
||||
assert resp.status_code == 204
|
||||
|
||||
# Should no longer be findable
|
||||
get_resp = await db_client.get(f"/api/v1/sites/{site_id}", headers=auth_headers)
|
||||
assert get_resp.status_code == 404
|
||||
|
||||
async def test_create_site_requires_auth(self, db_client):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": "noauth.com",
|
||||
"display_name": "No Auth",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestSiteConfig:
|
||||
async def test_get_config_creates_default(self, db_client, auth_headers):
|
||||
domain = f"config-{uuid.uuid4().hex[:8]}.com"
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Config Test",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
site_id = create_resp.json()["id"]
|
||||
|
||||
# PUT config to create it
|
||||
put_resp = await db_client.put(
|
||||
f"/api/v1/sites/{site_id}/config",
|
||||
json={"blocking_mode": "opt_in"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert put_resp.status_code in (200, 201)
|
||||
|
||||
# GET config
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/sites/{site_id}/config",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["blocking_mode"] == "opt_in"
|
||||
|
||||
async def test_update_config(self, db_client, auth_headers):
|
||||
domain = f"config-upd-{uuid.uuid4().hex[:8]}.com"
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/sites/",
|
||||
json={
|
||||
"domain": domain,
|
||||
"display_name": "Config Update",
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
site_id = create_resp.json()["id"]
|
||||
|
||||
# Create config
|
||||
await db_client.put(
|
||||
f"/api/v1/sites/{site_id}/config",
|
||||
json={"blocking_mode": "opt_in"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
# Patch config
|
||||
resp = await db_client.patch(
|
||||
f"/api/v1/sites/{site_id}/config",
|
||||
json={
|
||||
"blocking_mode": "opt_out",
|
||||
"gcm_enabled": False,
|
||||
"consent_expiry_days": 180,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["blocking_mode"] == "opt_out"
|
||||
assert data["gcm_enabled"] is False
|
||||
assert data["consent_expiry_days"] == 180
|
||||
137
apps/api/tests/test_middleware_rate_limit.py
Normal file
137
apps/api/tests/test_middleware_rate_limit.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Tests for the rate limiting middleware."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
|
||||
|
||||
class TestRateLimitMiddleware:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_ip_from_forwarded_for(self):
|
||||
from starlette.applications import Starlette
|
||||
|
||||
app = Starlette()
|
||||
middleware = RateLimitMiddleware(app)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"x-forwarded-for": "1.2.3.4, 5.6.7.8"}
|
||||
assert middleware._get_client_ip(request) == "1.2.3.4"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_ip_from_real_ip(self):
|
||||
from starlette.applications import Starlette
|
||||
|
||||
app = Starlette()
|
||||
middleware = RateLimitMiddleware(app)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {"x-real-ip": "9.8.7.6"}
|
||||
assert middleware._get_client_ip(request) == "9.8.7.6"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_ip_from_client(self):
|
||||
from starlette.applications import Starlette
|
||||
|
||||
app = Starlette()
|
||||
middleware = RateLimitMiddleware(app)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {}
|
||||
request.client = MagicMock()
|
||||
request.client.host = "10.0.0.1"
|
||||
assert middleware._get_client_ip(request) == "10.0.0.1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_client_ip_no_client(self):
|
||||
from starlette.applications import Starlette
|
||||
|
||||
app = Starlette()
|
||||
middleware = RateLimitMiddleware(app)
|
||||
|
||||
request = MagicMock()
|
||||
request.headers = {}
|
||||
request.client = None
|
||||
assert middleware._get_client_ip(request) == "unknown"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_bypasses_rate_limit(self):
|
||||
"""Health checks should never be rate limited."""
|
||||
from src.main import create_app
|
||||
|
||||
app = create_app()
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_through_when_redis_unavailable(self):
|
||||
"""When Redis is down, requests should still be served."""
|
||||
from src.main import create_app
|
||||
|
||||
# Rate limiting disabled by default in test settings
|
||||
app = create_app()
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/health")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_headers_present(self):
|
||||
"""Rate limit headers should be added when middleware is active."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
async def test_endpoint():
|
||||
return {"ok": True}
|
||||
|
||||
# Mock Redis
|
||||
mock_redis = AsyncMock()
|
||||
mock_redis.incr = AsyncMock(return_value=1)
|
||||
mock_redis.expire = AsyncMock()
|
||||
|
||||
middleware = RateLimitMiddleware(app, requests_per_minute=100)
|
||||
middleware._redis = mock_redis
|
||||
|
||||
# Since we can't easily inject the mock Redis into the ASGI middleware,
|
||||
# test the logic unit separately
|
||||
assert middleware.requests_per_minute == 100
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_middleware_creation(self):
|
||||
"""Middleware should initialise with provided parameters."""
|
||||
from starlette.applications import Starlette
|
||||
|
||||
app = Starlette()
|
||||
middleware = RateLimitMiddleware(app, redis_url="redis://fake:6379", requests_per_minute=30)
|
||||
assert middleware.requests_per_minute == 30
|
||||
assert middleware.redis_url == "redis://fake:6379"
|
||||
assert middleware._redis is None # Lazy initialisation
|
||||
|
||||
|
||||
class TestRateLimitConfiguration:
|
||||
def test_default_settings_enabled(self, monkeypatch):
|
||||
"""Rate limiting is on by default — public endpoints must not be DoS-able.
|
||||
|
||||
Note: the suite-wide conftest sets ``RATE_LIMIT_ENABLED=false``
|
||||
so other tests aren't rate-limited by Redis; we unset it here
|
||||
to verify the baked-in default.
|
||||
"""
|
||||
monkeypatch.delenv("RATE_LIMIT_ENABLED", raising=False)
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
settings = Settings()
|
||||
assert settings.rate_limit_enabled is True
|
||||
|
||||
def test_configurable_limit(self):
|
||||
"""Rate limit per minute should be configurable."""
|
||||
from src.config.settings import Settings
|
||||
|
||||
settings = Settings(rate_limit_per_minute=120)
|
||||
assert settings.rate_limit_per_minute == 120
|
||||
65
apps/api/tests/test_middleware_security.py
Normal file
65
apps/api/tests/test_middleware_security.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Tests for the security headers middleware."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
return create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
class TestSecurityHeaders:
|
||||
@pytest.mark.asyncio
|
||||
async def test_x_content_type_options(self, client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.headers["x-content-type-options"] == "nosniff"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_x_frame_options(self, client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.headers["x-frame-options"] == "DENY"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_x_xss_protection(self, client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.headers["x-xss-protection"] == "0"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_referrer_policy(self, client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.headers["referrer-policy"] == "strict-origin-when-cross-origin"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_content_security_policy(self, client):
|
||||
resp = await client.get("/health")
|
||||
assert resp.headers["content-security-policy"] == "default-src 'none'"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_hsts_on_http(self, client):
|
||||
resp = await client.get("/health")
|
||||
assert "strict-transport-security" not in resp.headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hsts_on_https(self, app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="https://test") as client:
|
||||
resp = await client.get("/health")
|
||||
assert "strict-transport-security" in resp.headers
|
||||
assert "max-age=63072000" in resp.headers["strict-transport-security"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_headers_present_on_non_existent_route(self, client):
|
||||
# Even 404s on unknown routes should have security headers
|
||||
resp = await client.get("/this-does-not-exist")
|
||||
assert resp.headers["x-content-type-options"] == "nosniff"
|
||||
assert resp.headers["x-frame-options"] == "DENY"
|
||||
166
apps/api/tests/test_models.py
Normal file
166
apps/api/tests/test_models.py
Normal file
@@ -0,0 +1,166 @@
|
||||
"""Tests for SQLAlchemy model definitions.
|
||||
|
||||
These tests verify model structure without needing a database connection.
|
||||
"""
|
||||
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from src.models import (
|
||||
Base,
|
||||
ConsentRecord,
|
||||
Cookie,
|
||||
CookieAllowListEntry,
|
||||
CookieCategory,
|
||||
KnownCookie,
|
||||
Organisation,
|
||||
ScanJob,
|
||||
ScanResult,
|
||||
Site,
|
||||
SiteConfig,
|
||||
Translation,
|
||||
User,
|
||||
)
|
||||
|
||||
|
||||
def test_all_models_registered_in_metadata():
|
||||
"""All expected tables should be present in Base.metadata."""
|
||||
table_names = set(Base.metadata.tables.keys())
|
||||
expected = {
|
||||
"organisations",
|
||||
"users",
|
||||
"sites",
|
||||
"site_configs",
|
||||
"cookie_categories",
|
||||
"cookies",
|
||||
"cookie_allow_list",
|
||||
"known_cookies",
|
||||
"consent_records",
|
||||
"scan_jobs",
|
||||
"scan_results",
|
||||
"translations",
|
||||
}
|
||||
assert expected.issubset(table_names), f"Missing tables: {expected - table_names}"
|
||||
|
||||
|
||||
def test_organisation_columns():
|
||||
mapper = inspect(Organisation)
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
assert "id" in column_names
|
||||
assert "name" in column_names
|
||||
assert "slug" in column_names
|
||||
assert "contact_email" in column_names
|
||||
assert "billing_plan" in column_names
|
||||
assert "created_at" in column_names
|
||||
assert "updated_at" in column_names
|
||||
assert "deleted_at" in column_names
|
||||
|
||||
|
||||
def test_user_columns_and_fk():
|
||||
mapper = inspect(User)
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
assert "organisation_id" in column_names
|
||||
assert "email" in column_names
|
||||
assert "password_hash" in column_names
|
||||
assert "role" in column_names
|
||||
|
||||
|
||||
def test_site_unique_constraint():
|
||||
table = Site.__table__
|
||||
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
|
||||
assert "uq_sites_org_domain" in constraint_names
|
||||
|
||||
|
||||
def test_site_config_jsonb_fields():
|
||||
mapper = inspect(SiteConfig)
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
for field in ["regional_modes", "gcm_default", "banner_config"]:
|
||||
assert field in column_names, f"Missing JSONB field: {field}"
|
||||
|
||||
|
||||
def test_cookie_category_columns():
|
||||
mapper = inspect(CookieCategory)
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
assert "tcf_purpose_ids" in column_names
|
||||
assert "gcm_consent_types" in column_names
|
||||
assert "is_essential" in column_names
|
||||
|
||||
|
||||
def test_cookie_unique_constraint():
|
||||
table = Cookie.__table__
|
||||
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
|
||||
assert "uq_cookies_site_name_domain_type" in constraint_names
|
||||
|
||||
|
||||
def test_cookie_allow_list_unique_constraint():
|
||||
table = CookieAllowListEntry.__table__
|
||||
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
|
||||
assert "uq_allow_list_site_name_domain" in constraint_names
|
||||
|
||||
|
||||
def test_known_cookie_unique_constraint():
|
||||
table = KnownCookie.__table__
|
||||
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
|
||||
assert "uq_known_cookies_name_domain" in constraint_names
|
||||
|
||||
|
||||
def test_consent_record_columns():
|
||||
mapper = inspect(ConsentRecord)
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
for field in [
|
||||
"visitor_id",
|
||||
"action",
|
||||
"categories_accepted",
|
||||
"tc_string",
|
||||
"gcm_state",
|
||||
"country_code",
|
||||
"consented_at",
|
||||
]:
|
||||
assert field in column_names, f"Missing field: {field}"
|
||||
|
||||
|
||||
def test_scan_job_columns():
|
||||
mapper = inspect(ScanJob)
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
assert "status" in column_names
|
||||
assert "pages_scanned" in column_names
|
||||
assert "cookies_found" in column_names
|
||||
|
||||
|
||||
def test_scan_result_columns():
|
||||
mapper = inspect(ScanResult)
|
||||
column_names = {c.key for c in mapper.columns}
|
||||
assert "page_url" in column_names
|
||||
assert "cookie_name" in column_names
|
||||
assert "script_source" in column_names
|
||||
assert "auto_category" in column_names
|
||||
|
||||
|
||||
def test_translation_unique_constraint():
|
||||
table = Translation.__table__
|
||||
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
|
||||
assert "uq_translations_site_locale" in constraint_names
|
||||
|
||||
|
||||
def test_uuid_primary_keys():
|
||||
"""All models should use UUID primary keys."""
|
||||
models = [
|
||||
Organisation,
|
||||
User,
|
||||
Site,
|
||||
SiteConfig,
|
||||
CookieCategory,
|
||||
Cookie,
|
||||
CookieAllowListEntry,
|
||||
KnownCookie,
|
||||
ConsentRecord,
|
||||
ScanJob,
|
||||
ScanResult,
|
||||
Translation,
|
||||
]
|
||||
for model in models:
|
||||
mapper = inspect(model)
|
||||
pk_cols = mapper.primary_key
|
||||
assert len(pk_cols) == 1, f"{model.__name__} should have exactly one PK column"
|
||||
assert str(pk_cols[0].type) == "UUID", (
|
||||
f"{model.__name__} PK should be UUID, got {pk_cols[0].type}"
|
||||
)
|
||||
96
apps/api/tests/test_openapi.py
Normal file
96
apps/api/tests/test_openapi.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""Tests for OpenAPI schema generation and documentation."""
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app():
|
||||
return create_app()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def client(app):
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac
|
||||
|
||||
|
||||
class TestOpenAPISchema:
|
||||
@pytest.mark.asyncio
|
||||
async def test_openapi_endpoint_accessible(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_has_info(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
schema = resp.json()
|
||||
assert schema["info"]["title"] == "ConsentOS API"
|
||||
assert "version" in schema["info"]
|
||||
assert "description" in schema["info"]
|
||||
assert "consent" in schema["info"]["description"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_schema_has_tags(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
schema = resp.json()
|
||||
tag_names = {t["name"] for t in schema.get("tags", [])}
|
||||
expected_tags = {
|
||||
"auth",
|
||||
"config",
|
||||
"consent",
|
||||
"sites",
|
||||
"cookies",
|
||||
"scanner",
|
||||
"compliance",
|
||||
"organisations",
|
||||
"users",
|
||||
}
|
||||
assert expected_tags.issubset(tag_names)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_tags_have_descriptions(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
schema = resp.json()
|
||||
for tag in schema.get("tags", []):
|
||||
assert "description" in tag, f"Tag '{tag['name']}' missing description"
|
||||
assert len(tag["description"]) > 10, f"Tag '{tag['name']}' has weak description"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_health_endpoint_in_schema(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
schema = resp.json()
|
||||
assert "/health" in schema["paths"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_endpoints_present(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
paths = resp.json()["paths"]
|
||||
assert "/api/v1/auth/login" in paths
|
||||
assert "/api/v1/consent/" in paths
|
||||
assert "/api/v1/sites/" in paths
|
||||
assert "/api/v1/config/geo" in paths
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_docs_endpoint_accessible(self, client):
|
||||
resp = await client.get("/docs")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
class TestOpenAPIEndpoints:
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_endpoints_documented(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
paths = resp.json()["paths"]
|
||||
config_paths = [p for p in paths if "/config/" in p]
|
||||
assert len(config_paths) >= 4 # public, resolved, geo-resolved, publish, geo
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consent_endpoints_documented(self, client):
|
||||
resp = await client.get("/openapi.json")
|
||||
paths = resp.json()["paths"]
|
||||
consent_paths = [p for p in paths if "/consent" in p]
|
||||
assert len(consent_paths) >= 1
|
||||
146
apps/api/tests/test_org_user_crud.py
Normal file
146
apps/api/tests/test_org_user_crud.py
Normal file
@@ -0,0 +1,146 @@
|
||||
"""Tests for organisation and user CRUD endpoints and schemas."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.schemas.organisation import OrganisationCreate, OrganisationResponse, OrganisationUpdate
|
||||
from src.schemas.user import UserCreate, UserResponse, UserRole, UserUpdate
|
||||
|
||||
|
||||
class TestOrganisationSchemas:
|
||||
def test_create_valid(self):
|
||||
org = OrganisationCreate(name="Acme Corp", slug="acme-corp")
|
||||
assert org.name == "Acme Corp"
|
||||
assert org.slug == "acme-corp"
|
||||
assert org.billing_plan == "free"
|
||||
|
||||
def test_create_invalid_slug(self):
|
||||
with pytest.raises(ValidationError):
|
||||
OrganisationCreate(name="Acme", slug="INVALID SLUG!")
|
||||
|
||||
def test_create_empty_name_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
OrganisationCreate(name="", slug="valid-slug")
|
||||
|
||||
def test_update_partial(self):
|
||||
update = OrganisationUpdate(name="New Name")
|
||||
data = update.model_dump(exclude_unset=True)
|
||||
assert data == {"name": "New Name"}
|
||||
assert "contact_email" not in data
|
||||
|
||||
def test_response_from_attributes(self):
|
||||
now = "2026-01-01T00:00:00Z"
|
||||
resp = OrganisationResponse(
|
||||
id=uuid.uuid4(),
|
||||
name="Test",
|
||||
slug="test",
|
||||
contact_email=None,
|
||||
billing_plan="free",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
assert resp.name == "Test"
|
||||
|
||||
|
||||
class TestUserSchemas:
|
||||
def test_create_valid(self):
|
||||
user = UserCreate(
|
||||
email="test@example.com",
|
||||
password="securepass123",
|
||||
full_name="Test User",
|
||||
)
|
||||
assert user.role == UserRole.VIEWER
|
||||
|
||||
def test_create_short_password_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
UserCreate(email="a@b.com", password="short", full_name="Test")
|
||||
|
||||
def test_create_invalid_email_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
UserCreate(email="not-an-email", password="securepass123", full_name="Test")
|
||||
|
||||
def test_create_with_role(self):
|
||||
user = UserCreate(
|
||||
email="admin@example.com",
|
||||
password="securepass123",
|
||||
full_name="Admin",
|
||||
role=UserRole.ADMIN,
|
||||
)
|
||||
assert user.role == UserRole.ADMIN
|
||||
|
||||
def test_update_partial(self):
|
||||
update = UserUpdate(role=UserRole.EDITOR)
|
||||
data = update.model_dump(exclude_unset=True)
|
||||
assert data == {"role": "editor"}
|
||||
|
||||
def test_response_from_attributes(self):
|
||||
now = "2026-01-01T00:00:00Z"
|
||||
resp = UserResponse(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="a@b.com",
|
||||
full_name="Test",
|
||||
role="viewer",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
assert resp.role == "viewer"
|
||||
|
||||
|
||||
class TestUserRole:
|
||||
def test_role_values(self):
|
||||
assert UserRole.OWNER == "owner"
|
||||
assert UserRole.ADMIN == "admin"
|
||||
assert UserRole.EDITOR == "editor"
|
||||
assert UserRole.VIEWER == "viewer"
|
||||
|
||||
def test_invalid_role_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
UserCreate(
|
||||
email="a@b.com",
|
||||
password="securepass123",
|
||||
full_name="Test",
|
||||
role="superadmin",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRoutesRegistered:
|
||||
async def test_org_routes(self, client):
|
||||
response = await client.get("/openapi.json")
|
||||
paths = response.json()["paths"]
|
||||
assert "/api/v1/organisations/" in paths
|
||||
assert "/api/v1/organisations/me" in paths
|
||||
|
||||
async def test_user_routes(self, client):
|
||||
response = await client.get("/openapi.json")
|
||||
paths = response.json()["paths"]
|
||||
assert "/api/v1/users/" in paths
|
||||
assert "/api/v1/users/{user_id}" in paths
|
||||
|
||||
async def test_org_endpoints_require_auth(self, client):
|
||||
response = await client.get("/api/v1/organisations/me")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_user_endpoints_require_auth(self, client):
|
||||
response = await client.get("/api/v1/users/")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_create_org_rejects_invalid_body(self, client, monkeypatch):
|
||||
"""Create org endpoint validates the request body schema.
|
||||
|
||||
We need to enable the bootstrap token first so the request
|
||||
reaches the body-validation stage (the token guard otherwise
|
||||
fires before Pydantic validation and we'd see 403 instead).
|
||||
"""
|
||||
from src.config.settings import get_settings
|
||||
|
||||
monkeypatch.setattr(get_settings(), "admin_bootstrap_token", "test-token")
|
||||
response = await client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "", "slug": "INVALID SLUG!"},
|
||||
headers={"X-Admin-Bootstrap-Token": "test-token"},
|
||||
)
|
||||
assert response.status_code == 422
|
||||
88
apps/api/tests/test_publisher.py
Normal file
88
apps/api/tests/test_publisher.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Unit tests for the CDN publisher service."""
|
||||
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services.publisher import PublishResult, _publish_local, publish_site_config
|
||||
|
||||
|
||||
class TestPublishResult:
|
||||
def test_success_result(self):
|
||||
result = PublishResult(success=True, path="/cdn/config.json")
|
||||
assert result.success is True
|
||||
assert result.path == "/cdn/config.json"
|
||||
assert result.published_at is not None
|
||||
assert result.error is None
|
||||
|
||||
def test_failure_result(self):
|
||||
result = PublishResult(success=False, path="", error="Something went wrong")
|
||||
assert result.success is False
|
||||
assert result.published_at is None
|
||||
assert result.error == "Something went wrong"
|
||||
|
||||
|
||||
class TestPublishLocal:
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_creates_files(self):
|
||||
config = {"site_id": "abc-123", "blocking_mode": "opt_in"}
|
||||
path = await _publish_local("abc-123", config, "https://cdn.example.com")
|
||||
assert os.path.exists(path)
|
||||
with open(path) as f:
|
||||
data = json.load(f)
|
||||
assert data["site_id"] == "abc-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_creates_versioned_copy(self):
|
||||
config = {"site_id": "def-456", "blocking_mode": "opt_out"}
|
||||
path = await _publish_local("def-456", config, "https://cdn.example.com")
|
||||
publish_dir = Path(path).parent
|
||||
versioned = list(publish_dir.glob("site-config-def-456-*.json"))
|
||||
assert len(versioned) >= 1
|
||||
|
||||
|
||||
class TestPublishSiteConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_success(self):
|
||||
site_config = {
|
||||
"blocking_mode": "opt_in",
|
||||
"tcf_enabled": False,
|
||||
"gcm_enabled": True,
|
||||
"consent_expiry_days": 365,
|
||||
}
|
||||
result = await publish_site_config("site-123", site_config)
|
||||
assert result.success is True
|
||||
assert result.path != ""
|
||||
assert result.published_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_with_org_defaults(self):
|
||||
site_config = {"blocking_mode": "opt_in"}
|
||||
org_defaults = {"consent_expiry_days": 180}
|
||||
result = await publish_site_config("site-456", site_config, org_defaults)
|
||||
assert result.success is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_failure_returns_error(self):
|
||||
with patch(
|
||||
"src.services.publisher._publish_local",
|
||||
side_effect=OSError("Permission denied"),
|
||||
):
|
||||
result = await publish_site_config("site-789", {"blocking_mode": "opt_in"})
|
||||
assert result.success is False
|
||||
assert "Permission denied" in result.error
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_cdn():
|
||||
yield
|
||||
cdn_dir = Path("cdn-publish")
|
||||
if cdn_dir.exists():
|
||||
for f in cdn_dir.glob("site-config-*.json"):
|
||||
f.unlink(missing_ok=True)
|
||||
with contextlib.suppress(OSError):
|
||||
cdn_dir.rmdir()
|
||||
203
apps/api/tests/test_routers_auth.py
Normal file
203
apps/api/tests/test_routers_auth.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""Unit tests for auth router — mocked database."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token, create_refresh_token, hash_password
|
||||
|
||||
|
||||
def _make_user(org_id: uuid.UUID | None = None, **overrides):
|
||||
"""Build a mock User ORM object."""
|
||||
_org_id = org_id or uuid.uuid4()
|
||||
_id = overrides.pop("id", uuid.uuid4())
|
||||
user = MagicMock()
|
||||
user.id = _id
|
||||
user.organisation_id = _org_id
|
||||
user.email = overrides.get("email", "admin@test.com")
|
||||
user.password_hash = overrides.get("password_hash", hash_password("TestPassword123"))
|
||||
user.full_name = overrides.get("full_name", "Test Admin")
|
||||
user.role = overrides.get("role", "owner")
|
||||
user.deleted_at = None
|
||||
user.is_active = True
|
||||
return user
|
||||
|
||||
|
||||
def _mock_db(scalars=None, scalar_one_or_none=None):
|
||||
"""Create a mock AsyncSession.
|
||||
|
||||
When a query is executed:
|
||||
- result.scalar_one_or_none() returns `scalar_one_or_none`
|
||||
- result.scalars().all() returns `scalars or []`
|
||||
"""
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = scalar_one_or_none
|
||||
scalars_obj = MagicMock()
|
||||
scalars_obj.all.return_value = scalars or []
|
||||
result.scalars.return_value = scalars_obj
|
||||
session.execute.return_value = result
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
"""Build a test client with the given mock session."""
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestLoginEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_success(self, mock_app):
|
||||
org_id = uuid.uuid4()
|
||||
user = _make_user(org_id=org_id)
|
||||
db = _mock_db(scalar_one_or_none=user)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "admin@test.com", "password": "TestPassword123"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
assert "refresh_token" in data
|
||||
assert data["token_type"] == "bearer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_wrong_password(self, mock_app):
|
||||
user = _make_user()
|
||||
db = _mock_db(scalar_one_or_none=user)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "admin@test.com", "password": "WrongPassword"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_not_found(self, mock_app):
|
||||
db = _mock_db(scalar_one_or_none=None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"email": "nobody@test.com", "password": "whatever"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_invalid_body(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post("/api/v1/auth/login", json={"email": "not-an-email"})
|
||||
assert resp.status_code == 422
|
||||
|
||||
|
||||
class TestMeEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_returns_user(self, mock_app):
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
token = create_access_token(
|
||||
user_id=user_id,
|
||||
organisation_id=org_id,
|
||||
role="owner",
|
||||
email="admin@test.com",
|
||||
)
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
"/api/v1/auth/me",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["email"] == "admin@test.com"
|
||||
assert data["role"] == "owner"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_me_without_token(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/auth/me")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
class TestRefreshEndpoint:
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_success(self, mock_app):
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
user = _make_user(org_id=org_id, id=user_id)
|
||||
refresh_token = create_refresh_token(
|
||||
user_id=user_id,
|
||||
organisation_id=org_id,
|
||||
)
|
||||
db = _mock_db(scalar_one_or_none=user)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "access_token" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_with_access_token_rejected(self, mock_app):
|
||||
"""An access token should not be usable as a refresh token."""
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
access_token = create_access_token(
|
||||
user_id=user_id,
|
||||
organisation_id=org_id,
|
||||
role="owner",
|
||||
email="admin@test.com",
|
||||
)
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": access_token},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_invalid_token(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": "invalid.token.here"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_refresh_user_deleted(self, mock_app):
|
||||
org_id = uuid.uuid4()
|
||||
user_id = uuid.uuid4()
|
||||
refresh_token = create_refresh_token(
|
||||
user_id=user_id,
|
||||
organisation_id=org_id,
|
||||
)
|
||||
db = _mock_db(scalar_one_or_none=None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/auth/refresh",
|
||||
json={"refresh_token": refresh_token},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
assert "no longer exists" in resp.json()["detail"]
|
||||
296
apps/api/tests/test_routers_config.py
Normal file
296
apps/api/tests/test_routers_config.py
Normal file
@@ -0,0 +1,296 @@
|
||||
"""Unit tests for config router — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token
|
||||
|
||||
ORG_ID = uuid.uuid4()
|
||||
USER_ID = uuid.uuid4()
|
||||
|
||||
|
||||
def _auth_headers(role="owner"):
|
||||
token = create_access_token(
|
||||
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _mock_config(**overrides):
|
||||
config = MagicMock(spec=[])
|
||||
config.id = overrides.get("id", uuid.uuid4())
|
||||
config.site_id = overrides.get("site_id", uuid.uuid4())
|
||||
config.blocking_mode = overrides.get("blocking_mode", "opt_in")
|
||||
config.tcf_enabled = overrides.get("tcf_enabled", False)
|
||||
config.tcf_publisher_cc = overrides.get("tcf_publisher_cc")
|
||||
config.gpp_enabled = overrides.get("gpp_enabled", True)
|
||||
config.gpp_supported_apis = overrides.get("gpp_supported_apis", ["usnat"])
|
||||
config.gpc_enabled = overrides.get("gpc_enabled", True)
|
||||
default_jurisdictions = ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"]
|
||||
config.gpc_jurisdictions = overrides.get("gpc_jurisdictions", default_jurisdictions)
|
||||
config.gpc_global_honour = overrides.get("gpc_global_honour", False)
|
||||
config.gcm_enabled = overrides.get("gcm_enabled", True)
|
||||
config.gcm_default = overrides.get("gcm_default")
|
||||
config.banner_config = overrides.get("banner_config", {})
|
||||
config.regional_modes = overrides.get("regional_modes")
|
||||
config.privacy_policy_url = overrides.get("privacy_policy_url")
|
||||
config.scan_schedule_cron = overrides.get("scan_schedule_cron")
|
||||
config.scan_max_pages = overrides.get("scan_max_pages", 50)
|
||||
config.consent_expiry_days = overrides.get("consent_expiry_days", 365)
|
||||
config.created_at = datetime.now(UTC)
|
||||
config.updated_at = datetime.now(UTC)
|
||||
return config
|
||||
|
||||
|
||||
def _mock_db_sequence(*results):
|
||||
session = AsyncMock()
|
||||
mock_results = []
|
||||
for r in results:
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = r
|
||||
mock_results.append(result)
|
||||
session.execute = AsyncMock(side_effect=mock_results)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestPublicSiteConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_public_config(self, mock_app):
|
||||
config = _mock_config()
|
||||
db = _mock_db_sequence(config)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/config/sites/{config.site_id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_public_config_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/config/sites/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestResolvedConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resolved_config(self, mock_app):
|
||||
config = _mock_config()
|
||||
# Resolved endpoint does 4 queries: config, site org_id, org_config, site group_id
|
||||
db = _mock_db_sequence(config, ORG_ID, None, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/config/sites/{config.site_id}/resolved")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "site_id" in data
|
||||
assert "blocking_mode" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resolved_config_with_region(self, mock_app):
|
||||
config = _mock_config(regional_modes={"EU": "opt_in", "US": "opt_out"})
|
||||
db = _mock_db_sequence(config, ORG_ID, None, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/config/sites/{config.site_id}/resolved?region=EU")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resolved_config_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/config/sites/{uuid.uuid4()}/resolved")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestPublishConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_config_success(self, mock_app):
|
||||
config = _mock_config()
|
||||
# Publish does: config query, org_config query, group_id, active A/B test query
|
||||
db = _mock_db_sequence(config, None, None, None)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.success = True
|
||||
mock_result.path = "/cdn/site-config.json"
|
||||
mock_result.published_at = datetime.now(UTC).isoformat()
|
||||
|
||||
with patch(
|
||||
"src.routers.config.publish_site_config",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/config/sites/{config.site_id}/publish",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["published"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_config_failure(self, mock_app):
|
||||
config = _mock_config()
|
||||
db = _mock_db_sequence(config, None, None, None)
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.success = False
|
||||
mock_result.error = "Disk full"
|
||||
|
||||
with patch(
|
||||
"src.routers.config.publish_site_config",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_result,
|
||||
):
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/config/sites/{config.site_id}/publish",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 500
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_config_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/config/sites/{uuid.uuid4()}/publish",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_publish_requires_admin(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/config/sites/{uuid.uuid4()}/publish",
|
||||
headers=_auth_headers(role="viewer"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
class TestGeoResolvedConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_geo_resolved_config_with_header(self, mock_app):
|
||||
config = _mock_config(
|
||||
regional_modes={"EU": "opt_in", "US": "opt_out", "DEFAULT": "informational"},
|
||||
)
|
||||
# Geo-resolved does: config, site org_id, org_config, site group_id
|
||||
db = _mock_db_sequence(config, ORG_ID, None, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/config/sites/{config.site_id}/geo-resolved",
|
||||
headers={"cf-ipcountry": "DE"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["blocking_mode"] == "opt_in"
|
||||
assert data["detected_country"] == "DE"
|
||||
assert data["detected_region"] == "EU"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_geo_resolved_config_us(self, mock_app):
|
||||
config = _mock_config(
|
||||
regional_modes={"EU": "opt_in", "US-CA": "opt_out", "DEFAULT": "informational"},
|
||||
)
|
||||
db = _mock_db_sequence(config, ORG_ID, None, None)
|
||||
|
||||
with patch(
|
||||
"src.routers.config.detect_region",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
from src.services.geoip import GeoResult
|
||||
|
||||
mock_detect.return_value = GeoResult(country_code="US", region="US-CA")
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/config/sites/{config.site_id}/geo-resolved",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["blocking_mode"] == "opt_out"
|
||||
assert data["detected_region"] == "US-CA"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_geo_resolved_config_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/config/sites/{uuid.uuid4()}/geo-resolved",
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_geo_resolved_config_no_region_detected(self, mock_app):
|
||||
config = _mock_config(
|
||||
regional_modes={"EU": "opt_in", "DEFAULT": "informational"},
|
||||
)
|
||||
db = _mock_db_sequence(config, ORG_ID, None, None)
|
||||
|
||||
with patch(
|
||||
"src.routers.config.detect_region",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
from src.services.geoip import GeoResult
|
||||
|
||||
mock_detect.return_value = GeoResult(country_code=None, region=None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/config/sites/{config.site_id}/geo-resolved",
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["detected_country"] is None
|
||||
assert data["detected_region"] is None
|
||||
|
||||
|
||||
class TestVisitorGeo:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_visitor_geo_with_header(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
"/api/v1/config/geo",
|
||||
headers={"cf-ipcountry": "GB"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["country_code"] == "GB"
|
||||
assert data["region"] == "GB"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_visitor_geo_no_headers(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
|
||||
with patch(
|
||||
"src.routers.config.detect_region",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_detect:
|
||||
from src.services.geoip import GeoResult
|
||||
|
||||
mock_detect.return_value = GeoResult(country_code=None, region=None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/config/geo")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["country_code"] is None
|
||||
assert data["region"] is None
|
||||
230
apps/api/tests/test_routers_consent.py
Normal file
230
apps/api/tests/test_routers_consent.py
Normal file
@@ -0,0 +1,230 @@
|
||||
"""Unit tests for consent router — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
|
||||
|
||||
def _mock_consent_record(**overrides):
|
||||
"""Build a mock ConsentRecord ORM object."""
|
||||
record = MagicMock()
|
||||
record.id = overrides.get("id", uuid.uuid4())
|
||||
record.site_id = overrides.get("site_id", uuid.uuid4())
|
||||
record.visitor_id = overrides.get("visitor_id", "visitor-123")
|
||||
record.ip_hash = "abc123"
|
||||
record.user_agent_hash = "def456"
|
||||
record.action = overrides.get("action", "accept_all")
|
||||
record.categories_accepted = overrides.get("categories_accepted", ["necessary"])
|
||||
record.categories_rejected = overrides.get("categories_rejected", [])
|
||||
record.tc_string = overrides.get("tc_string")
|
||||
record.gcm_state = overrides.get("gcm_state")
|
||||
record.gpp_string = overrides.get("gpp_string")
|
||||
record.gpc_detected = overrides.get("gpc_detected")
|
||||
record.gpc_honoured = overrides.get("gpc_honoured")
|
||||
record.page_url = overrides.get("page_url")
|
||||
record.country_code = overrides.get("country_code")
|
||||
record.region_code = overrides.get("region_code")
|
||||
record.consented_at = overrides.get("consented_at", datetime.now(UTC))
|
||||
return record
|
||||
|
||||
|
||||
def _mock_db(scalar_one_or_none=None):
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = scalar_one_or_none
|
||||
session.execute.return_value = result
|
||||
|
||||
_added_objects = []
|
||||
|
||||
def _fake_add(obj):
|
||||
_added_objects.append(obj)
|
||||
|
||||
session.add = MagicMock(side_effect=_fake_add)
|
||||
|
||||
async def _fake_flush():
|
||||
"""Simulate DB flush — populate server-side defaults."""
|
||||
for obj in _added_objects:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = uuid.uuid4()
|
||||
if hasattr(obj, "consented_at") and getattr(obj, "consented_at", None) is None:
|
||||
obj.consented_at = datetime.now(UTC)
|
||||
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(UTC)
|
||||
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(UTC)
|
||||
|
||||
session.flush = AsyncMock(side_effect=_fake_flush)
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
from src.services.dependencies import get_current_user, require_role
|
||||
|
||||
user = MagicMock()
|
||||
user.organisation_id = uuid.uuid4()
|
||||
user.role = "owner"
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
app.dependency_overrides[get_current_user] = lambda: user
|
||||
|
||||
def _override_require_role(*_roles):
|
||||
return lambda: user
|
||||
|
||||
app.dependency_overrides[require_role] = _override_require_role
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestRecordConsent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_consent_success(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"visitor_id": "visitor-123",
|
||||
"action": "accept_all",
|
||||
"categories_accepted": ["necessary", "analytics"],
|
||||
"categories_rejected": [],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_consent_reject_all(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"visitor_id": "visitor-456",
|
||||
"action": "reject_all",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": ["analytics", "marketing"],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_consent_custom(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"visitor_id": "visitor-789",
|
||||
"action": "custom",
|
||||
"categories_accepted": ["necessary", "analytics"],
|
||||
"categories_rejected": ["marketing"],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_consent_invalid_action(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"visitor_id": "visitor-000",
|
||||
"action": "invalid_action",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": [],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_consent_empty_visitor_id(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"visitor_id": "",
|
||||
"action": "accept_all",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": [],
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 422
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_consent_with_optional_fields(self, mock_app):
|
||||
db = _mock_db()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/consent/",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"visitor_id": "visitor-opt",
|
||||
"action": "accept_all",
|
||||
"categories_accepted": ["necessary"],
|
||||
"categories_rejected": [],
|
||||
"tc_string": "CPXxRAAAA",
|
||||
"gcm_state": {"analytics_storage": "granted"},
|
||||
"page_url": "https://example.com",
|
||||
"country_code": "GB",
|
||||
"region_code": "ENG",
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
|
||||
class TestGetConsent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_consent_found(self, mock_app):
|
||||
record = _mock_consent_record()
|
||||
db = _mock_db(scalar_one_or_none=record)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/consent/{record.id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_consent_not_found(self, mock_app):
|
||||
db = _mock_db(scalar_one_or_none=None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/consent/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestVerifyConsent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_consent_valid(self, mock_app):
|
||||
record = _mock_consent_record()
|
||||
db = _mock_db(scalar_one_or_none=record)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/consent/verify/{record.id}")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_verify_consent_not_found(self, mock_app):
|
||||
db = _mock_db(scalar_one_or_none=None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/consent/verify/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
437
apps/api/tests/test_routers_cookies.py
Normal file
437
apps/api/tests/test_routers_cookies.py
Normal file
@@ -0,0 +1,437 @@
|
||||
"""Unit tests for cookie, category, and allow-list routers — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token
|
||||
|
||||
ORG_ID = uuid.uuid4()
|
||||
USER_ID = uuid.uuid4()
|
||||
|
||||
|
||||
def _auth_headers(role="owner"):
|
||||
token = create_access_token(
|
||||
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _mock_category(**overrides):
|
||||
cat = MagicMock(spec=[])
|
||||
cat.id = overrides.get("id", uuid.uuid4())
|
||||
cat.name = overrides.get("name", "Analytics")
|
||||
cat.slug = overrides.get("slug", "analytics")
|
||||
cat.description = overrides.get("description", "Analytics cookies")
|
||||
cat.is_essential = overrides.get("is_essential", False)
|
||||
cat.display_order = overrides.get("display_order", 3)
|
||||
cat.tcf_purpose_ids = overrides.get("tcf_purpose_ids", [])
|
||||
cat.gcm_consent_types = overrides.get("gcm_consent_types", ["analytics_storage"])
|
||||
cat.created_at = datetime.now(UTC)
|
||||
cat.updated_at = datetime.now(UTC)
|
||||
return cat
|
||||
|
||||
|
||||
def _mock_cookie(**overrides):
|
||||
cookie = MagicMock(spec=[])
|
||||
cookie.id = overrides.get("id", uuid.uuid4())
|
||||
cookie.site_id = overrides.get("site_id", uuid.uuid4())
|
||||
cookie.name = overrides.get("name", "_ga")
|
||||
cookie.domain = overrides.get("domain", ".google.com")
|
||||
cookie.path = overrides.get("path", "/")
|
||||
cookie.category_id = overrides.get("category_id")
|
||||
cookie.storage_type = overrides.get("storage_type", "cookie")
|
||||
cookie.review_status = overrides.get("review_status", "pending")
|
||||
cookie.description = overrides.get("description")
|
||||
cookie.vendor = overrides.get("vendor")
|
||||
cookie.max_age_seconds = overrides.get("max_age_seconds")
|
||||
cookie.is_http_only = overrides.get("is_http_only")
|
||||
cookie.is_secure = overrides.get("is_secure")
|
||||
cookie.same_site = overrides.get("same_site")
|
||||
cookie.first_seen_at = overrides.get("first_seen_at", datetime.now(UTC).isoformat())
|
||||
cookie.last_seen_at = overrides.get("last_seen_at", datetime.now(UTC).isoformat())
|
||||
cookie.created_at = datetime.now(UTC)
|
||||
cookie.updated_at = datetime.now(UTC)
|
||||
return cookie
|
||||
|
||||
|
||||
def _mock_site():
|
||||
site = MagicMock(spec=[])
|
||||
site.id = uuid.uuid4()
|
||||
site.organisation_id = ORG_ID
|
||||
site.domain = "test.com"
|
||||
site.deleted_at = None
|
||||
return site
|
||||
|
||||
|
||||
def _mock_allow_list_entry(**overrides):
|
||||
entry = MagicMock(spec=[])
|
||||
entry.id = overrides.get("id", uuid.uuid4())
|
||||
entry.site_id = overrides.get("site_id", uuid.uuid4())
|
||||
entry.name_pattern = overrides.get("name_pattern", "_ga*")
|
||||
entry.domain_pattern = overrides.get("domain_pattern", ".google.com")
|
||||
entry.category_id = overrides.get("category_id", uuid.uuid4())
|
||||
entry.description = overrides.get("description")
|
||||
entry.created_at = datetime.now(UTC)
|
||||
entry.updated_at = datetime.now(UTC)
|
||||
return entry
|
||||
|
||||
|
||||
def _mock_db_sequence(*results):
|
||||
session = AsyncMock()
|
||||
mock_results = []
|
||||
for r in results:
|
||||
result = MagicMock()
|
||||
if isinstance(r, list):
|
||||
result.scalar_one_or_none.return_value = r[0] if r else None
|
||||
scalars_obj = MagicMock()
|
||||
scalars_obj.all.return_value = r
|
||||
result.scalars.return_value = scalars_obj
|
||||
elif isinstance(r, dict) and "scalar" in r:
|
||||
result.scalar.return_value = r["scalar"]
|
||||
elif isinstance(r, dict) and "all" in r:
|
||||
result.all.return_value = r["all"]
|
||||
else:
|
||||
result.scalar_one_or_none.return_value = r
|
||||
mock_results.append(result)
|
||||
session.execute = AsyncMock(side_effect=mock_results)
|
||||
|
||||
_added = []
|
||||
|
||||
def _fake_add(obj):
|
||||
_added.append(obj)
|
||||
|
||||
session.add = MagicMock(side_effect=_fake_add)
|
||||
|
||||
async def _fake_flush():
|
||||
for obj in _added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = uuid.uuid4()
|
||||
if hasattr(obj, "review_status") and getattr(obj, "review_status", None) is None:
|
||||
obj.review_status = "pending"
|
||||
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(UTC)
|
||||
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(UTC)
|
||||
|
||||
session.flush = AsyncMock(side_effect=_fake_flush)
|
||||
session.refresh = AsyncMock()
|
||||
session.delete = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestCookieCategories:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_categories(self, mock_app):
|
||||
cats = [_mock_category(slug="necessary"), _mock_category(slug="analytics")]
|
||||
db = _mock_db_sequence(cats)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/cookies/categories")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_category(self, mock_app):
|
||||
cat = _mock_category()
|
||||
db = _mock_db_sequence(cat)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/cookies/categories/{cat.id}")
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_category_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/cookies/categories/{uuid.uuid4()}")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestCookieCRUD:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_cookies(self, mock_app):
|
||||
site = _mock_site()
|
||||
cookies = [_mock_cookie(site_id=site.id)]
|
||||
db = _mock_db_sequence(site, cookies)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/cookies/sites/{site.id}", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_cookies_empty(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, [])
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/cookies/sites/{site.id}", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_cookie(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site) # site found
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site.id}",
|
||||
json={"name": "_ga", "domain": ".google.com"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_cookie_with_invalid_category(self, mock_app):
|
||||
site = _mock_site()
|
||||
cat_id = uuid.uuid4()
|
||||
db = _mock_db_sequence(site, None) # site found, category not found
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site.id}",
|
||||
json={"name": "_ga", "domain": ".google.com", "category_id": str(cat_id)},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cookie(self, mock_app):
|
||||
site = _mock_site()
|
||||
cookie = _mock_cookie(site_id=site.id)
|
||||
db = _mock_db_sequence(site, cookie)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_cookie_not_found(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/cookies/sites/{site.id}/{uuid.uuid4()}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_cookie(self, mock_app):
|
||||
site = _mock_site()
|
||||
cookie = _mock_cookie(site_id=site.id)
|
||||
db = _mock_db_sequence(site, cookie)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
|
||||
json={"review_status": "approved"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_cookie_not_found(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/cookies/sites/{site.id}/{uuid.uuid4()}",
|
||||
json={"review_status": "approved"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_cookie_invalid_category(self, mock_app):
|
||||
site = _mock_site()
|
||||
cookie = _mock_cookie(site_id=site.id)
|
||||
cat_id = uuid.uuid4()
|
||||
# site found, cookie found, category validation fails
|
||||
db = _mock_db_sequence(site, cookie, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
|
||||
json={"category_id": str(cat_id)},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_cookie(self, mock_app):
|
||||
site = _mock_site()
|
||||
cookie = _mock_cookie(site_id=site.id)
|
||||
db = _mock_db_sequence(site, cookie)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(
|
||||
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_cookie_not_found(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(
|
||||
f"/api/v1/cookies/sites/{site.id}/{uuid.uuid4()}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestCookieSummary:
|
||||
@pytest.mark.asyncio
|
||||
async def test_cookie_summary(self, mock_app):
|
||||
site = _mock_site()
|
||||
# summary makes 4 queries: _get_org_site, status count, category count, uncategorised
|
||||
db = _mock_db_sequence(
|
||||
site,
|
||||
{"all": [("pending", 5), ("approved", 3)]},
|
||||
{"all": [("analytics", 4), ("marketing", 2)]},
|
||||
{"scalar": 2},
|
||||
)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/cookies/sites/{site.id}/summary",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "total" in data
|
||||
assert "by_status" in data
|
||||
assert "uncategorised" in data
|
||||
|
||||
|
||||
class TestAllowList:
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_allow_list(self, mock_app):
|
||||
site = _mock_site()
|
||||
entries = [_mock_allow_list_entry(site_id=site.id)]
|
||||
db = _mock_db_sequence(site, entries)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_allow_list_entry(self, mock_app):
|
||||
site = _mock_site()
|
||||
cat = _mock_category()
|
||||
db = _mock_db_sequence(site, cat) # site found, category valid
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list",
|
||||
json={
|
||||
"name_pattern": "_ga*",
|
||||
"domain_pattern": ".google.com",
|
||||
"category_id": str(cat.id),
|
||||
},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_allow_list_invalid_category(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None) # site found, category not found
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list",
|
||||
json={
|
||||
"name_pattern": "_ga*",
|
||||
"domain_pattern": ".google.com",
|
||||
"category_id": str(uuid.uuid4()),
|
||||
},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_allow_list_entry(self, mock_app):
|
||||
site = _mock_site()
|
||||
entry = _mock_allow_list_entry(site_id=site.id)
|
||||
db = _mock_db_sequence(site, entry)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list/{entry.id}",
|
||||
json={"name_pattern": "_fbp*"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_allow_list_not_found(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list/{uuid.uuid4()}",
|
||||
json={"name_pattern": "_fbp*"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_allow_list_invalid_category(self, mock_app):
|
||||
site = _mock_site()
|
||||
entry = _mock_allow_list_entry(site_id=site.id)
|
||||
db = _mock_db_sequence(site, entry, None) # site, entry found, category invalid
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list/{entry.id}",
|
||||
json={"category_id": str(uuid.uuid4())},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 400
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_allow_list_entry(self, mock_app):
|
||||
site = _mock_site()
|
||||
entry = _mock_allow_list_entry(site_id=site.id)
|
||||
db = _mock_db_sequence(site, entry)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list/{entry.id}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_allow_list_not_found(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(
|
||||
f"/api/v1/cookies/sites/{site.id}/allow-list/{uuid.uuid4()}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_site_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None) # site not found
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/cookies/sites/{uuid.uuid4()}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
184
apps/api/tests/test_routers_org_config.py
Normal file
184
apps/api/tests/test_routers_org_config.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Unit tests for org-config router — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token
|
||||
|
||||
ORG_ID = uuid.uuid4()
|
||||
USER_ID = uuid.uuid4()
|
||||
|
||||
|
||||
def _auth_headers(role="owner"):
|
||||
token = create_access_token(
|
||||
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _mock_org_config(**overrides):
|
||||
config = MagicMock()
|
||||
config.id = overrides.get("id", uuid.uuid4())
|
||||
config.organisation_id = overrides.get("organisation_id", ORG_ID)
|
||||
config.blocking_mode = overrides.get("blocking_mode")
|
||||
config.regional_modes = overrides.get("regional_modes")
|
||||
config.tcf_enabled = overrides.get("tcf_enabled")
|
||||
config.tcf_publisher_cc = overrides.get("tcf_publisher_cc")
|
||||
config.gcm_enabled = overrides.get("gcm_enabled")
|
||||
config.gcm_default = overrides.get("gcm_default")
|
||||
config.banner_config = overrides.get("banner_config")
|
||||
config.gpp_enabled = overrides.get("gpp_enabled")
|
||||
config.gpp_supported_apis = overrides.get("gpp_supported_apis")
|
||||
config.gpc_enabled = overrides.get("gpc_enabled")
|
||||
config.gpc_jurisdictions = overrides.get("gpc_jurisdictions")
|
||||
config.gpc_global_honour = overrides.get("gpc_global_honour")
|
||||
config.shopify_privacy_enabled = overrides.get("shopify_privacy_enabled")
|
||||
config.privacy_policy_url = overrides.get("privacy_policy_url")
|
||||
config.terms_url = overrides.get("terms_url")
|
||||
config.scan_schedule_cron = overrides.get("scan_schedule_cron")
|
||||
config.scan_max_pages = overrides.get("scan_max_pages")
|
||||
config.consent_expiry_days = overrides.get("consent_expiry_days")
|
||||
config.consent_retention_days = overrides.get("consent_retention_days")
|
||||
config.created_at = datetime.now(UTC)
|
||||
config.updated_at = datetime.now(UTC)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
def _mock_db_sequence(*results):
|
||||
"""Create a mock session returning different results on successive execute() calls."""
|
||||
session = AsyncMock()
|
||||
mock_results = []
|
||||
for r in results:
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = r
|
||||
mock_results.append(result)
|
||||
session.execute = AsyncMock(side_effect=mock_results)
|
||||
|
||||
_added = []
|
||||
|
||||
def _fake_add(obj):
|
||||
_added.append(obj)
|
||||
|
||||
session.add = MagicMock(side_effect=_fake_add)
|
||||
|
||||
async def _fake_flush():
|
||||
for obj in _added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = uuid.uuid4()
|
||||
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(UTC)
|
||||
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(UTC)
|
||||
|
||||
session.flush = AsyncMock(side_effect=_fake_flush)
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
class TestGetOrgConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_existing_config(self, mock_app):
|
||||
"""GET /org-config/ returns existing config."""
|
||||
config = _mock_org_config(blocking_mode="opt_out", consent_expiry_days=180)
|
||||
db = _mock_db_sequence(config)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/org-config/", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["blocking_mode"] == "opt_out"
|
||||
assert data["consent_expiry_days"] == 180
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_auto_creates_when_missing(self, mock_app):
|
||||
"""GET /org-config/ auto-creates a blank config if none exists."""
|
||||
db = _mock_db_sequence(None) # no existing config
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/org-config/", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
# All optional fields should be None
|
||||
assert data["blocking_mode"] is None
|
||||
assert data["tcf_enabled"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_requires_auth(self, mock_app):
|
||||
"""GET /org-config/ returns 401 without token."""
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/org-config/")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
class TestUpdateOrgConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_existing_config(self, mock_app):
|
||||
"""PUT /org-config/ updates existing config."""
|
||||
config = _mock_org_config()
|
||||
db = _mock_db_sequence(config)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
"/api/v1/org-config/",
|
||||
json={"blocking_mode": "opt_out", "consent_expiry_days": 90},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
# Verify setattr was called on the mock
|
||||
assert config.blocking_mode == "opt_out"
|
||||
assert config.consent_expiry_days == 90
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_creates_when_missing(self, mock_app):
|
||||
"""PUT /org-config/ creates config if none exists."""
|
||||
db = _mock_db_sequence(None) # no existing config
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
"/api/v1/org-config/",
|
||||
json={"tcf_enabled": True},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_requires_admin(self, mock_app):
|
||||
"""PUT /org-config/ returns 403 for viewers."""
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
"/api/v1/org-config/",
|
||||
json={"blocking_mode": "opt_in"},
|
||||
headers=_auth_headers(role="viewer"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_allows_editor_role_fails(self, mock_app):
|
||||
"""PUT /org-config/ returns 403 for editors (only owner/admin can update)."""
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
"/api/v1/org-config/",
|
||||
json={"blocking_mode": "opt_in"},
|
||||
headers=_auth_headers(role="editor"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
344
apps/api/tests/test_routers_orgs_users.py
Normal file
344
apps/api/tests/test_routers_orgs_users.py
Normal file
@@ -0,0 +1,344 @@
|
||||
"""Unit tests for organisation and user routers — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token, hash_password
|
||||
|
||||
ORG_ID = uuid.uuid4()
|
||||
USER_ID = uuid.uuid4()
|
||||
|
||||
|
||||
def _auth_headers(role="owner"):
|
||||
token = create_access_token(
|
||||
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _mock_org(**overrides):
|
||||
org = MagicMock(spec=[])
|
||||
org.id = overrides.get("id", ORG_ID)
|
||||
org.name = overrides.get("name", "Test Org")
|
||||
org.slug = overrides.get("slug", "test-org")
|
||||
org.contact_email = overrides.get("contact_email")
|
||||
org.billing_plan = overrides.get("billing_plan", "free")
|
||||
org.deleted_at = None
|
||||
org.created_at = datetime.now(UTC)
|
||||
org.updated_at = datetime.now(UTC)
|
||||
return org
|
||||
|
||||
|
||||
def _mock_user(**overrides):
|
||||
user = MagicMock(spec=[])
|
||||
user.id = overrides.get("id", uuid.uuid4())
|
||||
user.organisation_id = overrides.get("organisation_id", ORG_ID)
|
||||
user.email = overrides.get("email", "user@test.com")
|
||||
user.password_hash = overrides.get("password_hash", hash_password("Pass123"))
|
||||
user.full_name = overrides.get("full_name", "Test User")
|
||||
user.role = overrides.get("role", "editor")
|
||||
user.is_active = True
|
||||
user.deleted_at = None
|
||||
user.created_at = datetime.now(UTC)
|
||||
user.updated_at = datetime.now(UTC)
|
||||
return user
|
||||
|
||||
|
||||
def _mock_db_sequence(*results):
|
||||
session = AsyncMock()
|
||||
mock_results = []
|
||||
for r in results:
|
||||
result = MagicMock()
|
||||
if isinstance(r, list):
|
||||
result.scalar_one_or_none.return_value = r[0] if r else None
|
||||
scalars_obj = MagicMock()
|
||||
scalars_obj.all.return_value = r
|
||||
result.scalars.return_value = scalars_obj
|
||||
else:
|
||||
result.scalar_one_or_none.return_value = r
|
||||
mock_results.append(result)
|
||||
session.execute = AsyncMock(side_effect=mock_results)
|
||||
|
||||
_added = []
|
||||
|
||||
def _fake_add(obj):
|
||||
_added.append(obj)
|
||||
|
||||
session.add = MagicMock(side_effect=_fake_add)
|
||||
|
||||
async def _fake_flush():
|
||||
for obj in _added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = uuid.uuid4()
|
||||
if hasattr(obj, "is_active") and getattr(obj, "is_active", None) is None:
|
||||
obj.is_active = True
|
||||
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(UTC)
|
||||
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(UTC)
|
||||
|
||||
session.flush = AsyncMock(side_effect=_fake_flush)
|
||||
session.refresh = AsyncMock()
|
||||
session.delete = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
_BOOTSTRAP_TOKEN = "test-bootstrap-token-xyz"
|
||||
_BOOTSTRAP_HEADERS = {"X-Admin-Bootstrap-Token": _BOOTSTRAP_TOKEN}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _bootstrap_enabled(monkeypatch):
|
||||
"""Configure the bootstrap token so org creation is permitted."""
|
||||
from src.config import settings as settings_mod
|
||||
|
||||
monkeypatch.setattr(
|
||||
settings_mod.get_settings(),
|
||||
"admin_bootstrap_token",
|
||||
_BOOTSTRAP_TOKEN,
|
||||
)
|
||||
yield
|
||||
|
||||
|
||||
class TestOrganisationRouter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org(self, mock_app, _bootstrap_enabled):
|
||||
db = _mock_db_sequence(None) # no duplicate slug
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "New Org", "slug": "new-org"},
|
||||
headers=_BOOTSTRAP_HEADERS,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_duplicate_slug(self, mock_app, _bootstrap_enabled):
|
||||
existing = _mock_org(slug="dup-slug")
|
||||
db = _mock_db_sequence(existing)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "Another", "slug": "dup-slug"},
|
||||
headers=_BOOTSTRAP_HEADERS,
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_disabled_without_token(self, mock_app):
|
||||
"""With no ``admin_bootstrap_token`` configured, creation is forbidden."""
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "X", "slug": "x"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_org_wrong_token(self, mock_app, _bootstrap_enabled):
|
||||
"""With an incorrect token, creation is unauthorised."""
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/organisations/",
|
||||
json={"name": "X", "slug": "x"},
|
||||
headers={"X-Admin-Bootstrap-Token": "wrong"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_org(self, mock_app):
|
||||
org = _mock_org()
|
||||
db = _mock_db_sequence(org)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/organisations/me", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_my_org_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/organisations/me", headers=_auth_headers())
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_my_org(self, mock_app):
|
||||
org = _mock_org()
|
||||
db = _mock_db_sequence(org)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
"/api/v1/organisations/me",
|
||||
json={"name": "Updated Name"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_my_org_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
"/api/v1/organisations/me",
|
||||
json={"name": "Updated"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_my_org_requires_admin(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
"/api/v1/organisations/me",
|
||||
json={"name": "Updated"},
|
||||
headers=_auth_headers(role="viewer"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
class TestUserRouter:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user(self, mock_app):
|
||||
db = _mock_db_sequence(None) # no duplicate email
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": "new@test.com",
|
||||
"password": "SecurePass123",
|
||||
"full_name": "New User",
|
||||
"role": "editor",
|
||||
},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_duplicate_email(self, mock_app):
|
||||
existing = _mock_user(email="dup@test.com")
|
||||
db = _mock_db_sequence(existing)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": "dup@test.com",
|
||||
"password": "SecurePass123",
|
||||
"full_name": "Dup User",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(self, mock_app):
|
||||
users = [_mock_user(), _mock_user(email="two@test.com")]
|
||||
db = _mock_db_sequence(users)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/users/", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user(self, mock_app):
|
||||
user = _mock_user()
|
||||
db = _mock_db_sequence(user)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/users/{user.id}", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/users/{uuid.uuid4()}", headers=_auth_headers())
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user(self, mock_app):
|
||||
user = _mock_user()
|
||||
db = _mock_db_sequence(user)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/users/{user.id}",
|
||||
json={"full_name": "Updated Name", "role": "admin"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/users/{uuid.uuid4()}",
|
||||
json={"full_name": "Nope"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user(self, mock_app):
|
||||
user = _mock_user(id=uuid.uuid4())
|
||||
db = _mock_db_sequence(user)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(f"/api/v1/users/{user.id}", headers=_auth_headers())
|
||||
assert resp.status_code == 204
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_self_rejected(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(f"/api/v1/users/{USER_ID}", headers=_auth_headers())
|
||||
assert resp.status_code == 400
|
||||
assert "yourself" in resp.json()["detail"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(f"/api/v1/users/{uuid.uuid4()}", headers=_auth_headers())
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_users_require_auth(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/users/")
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_user_viewer_forbidden(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/users/",
|
||||
json={
|
||||
"email": "new@test.com",
|
||||
"password": "SecurePass123",
|
||||
"role": "viewer",
|
||||
},
|
||||
headers=_auth_headers(role="viewer"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
189
apps/api/tests/test_routers_site_groups.py
Normal file
189
apps/api/tests/test_routers_site_groups.py
Normal file
@@ -0,0 +1,189 @@
|
||||
"""Unit tests for site-groups router — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token
|
||||
|
||||
ORG_ID = uuid.uuid4()
|
||||
USER_ID = uuid.uuid4()
|
||||
|
||||
|
||||
def _auth_headers():
|
||||
token = create_access_token(
|
||||
user_id=USER_ID, organisation_id=ORG_ID, role="owner", email="admin@test.com"
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _mock_group(**overrides):
|
||||
group = MagicMock()
|
||||
group.id = overrides.get("id", uuid.uuid4())
|
||||
group.organisation_id = overrides.get("organisation_id", ORG_ID)
|
||||
group.name = overrides.get("name", "Steve Madden")
|
||||
group.description = overrides.get("description")
|
||||
group.deleted_at = None
|
||||
group.created_at = datetime.now(UTC)
|
||||
group.updated_at = datetime.now(UTC)
|
||||
return group
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
def _mock_db_sequence(*results):
|
||||
"""Create a mock session returning different results on successive execute() calls."""
|
||||
session = AsyncMock()
|
||||
mock_results = []
|
||||
for r in results:
|
||||
result = MagicMock()
|
||||
if isinstance(r, list):
|
||||
result.scalar_one_or_none.return_value = r[0] if r else None
|
||||
result.scalar_one.return_value = len(r) if isinstance(r[0], int) else 0
|
||||
scalars_obj = MagicMock()
|
||||
scalars_obj.all.return_value = r
|
||||
result.scalars.return_value = scalars_obj
|
||||
result.all.return_value = r
|
||||
elif isinstance(r, int):
|
||||
result.scalar_one.return_value = r
|
||||
result.scalar_one_or_none.return_value = r
|
||||
elif isinstance(r, tuple):
|
||||
# (group, site_count) rows for list endpoint
|
||||
result.all.return_value = r
|
||||
else:
|
||||
result.scalar_one_or_none.return_value = r
|
||||
result.scalar_one.return_value = r if r is not None else 0
|
||||
mock_results.append(result)
|
||||
session.execute = AsyncMock(side_effect=mock_results)
|
||||
|
||||
_added = []
|
||||
|
||||
def _fake_add(obj):
|
||||
_added.append(obj)
|
||||
|
||||
session.add = MagicMock(side_effect=_fake_add)
|
||||
|
||||
async def _fake_flush():
|
||||
for obj in _added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = uuid.uuid4()
|
||||
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(UTC)
|
||||
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(UTC)
|
||||
|
||||
session.flush = AsyncMock(side_effect=_fake_flush)
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
class TestSiteGroupCRUD:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_site_group(self, mock_app):
|
||||
"""POST /site-groups/ creates a new group."""
|
||||
db = _mock_db_sequence(None) # no duplicate
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/site-groups/",
|
||||
json={"name": "Steve Madden"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["name"] == "Steve Madden"
|
||||
assert data["site_count"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_site_group_conflict(self, mock_app):
|
||||
"""POST /site-groups/ returns 409 when name exists."""
|
||||
existing = _mock_group(name="Steve Madden")
|
||||
db = _mock_db_sequence(existing)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/site-groups/",
|
||||
json={"name": "Steve Madden"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_site_groups(self, mock_app):
|
||||
"""GET /site-groups/ returns groups with site counts."""
|
||||
group = _mock_group(name="Steve Madden")
|
||||
row = MagicMock()
|
||||
row.SiteGroup = group
|
||||
row.site_count = 3
|
||||
rows = (row,)
|
||||
db = _mock_db_sequence(rows)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
"/api/v1/site-groups/",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["name"] == "Steve Madden"
|
||||
assert data[0]["site_count"] == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_site_group(self, mock_app):
|
||||
"""GET /site-groups/{id} returns a single group."""
|
||||
group_id = uuid.uuid4()
|
||||
group = _mock_group(id=group_id, description="SM brand")
|
||||
db = _mock_db_sequence(group, 2) # group lookup, site count
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/site-groups/{group_id}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["description"] == "SM brand"
|
||||
assert data["site_count"] == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_site_group_not_found(self, mock_app):
|
||||
"""GET /site-groups/{id} returns 404 for unknown ID."""
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/site-groups/{uuid.uuid4()}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_site_group(self, mock_app):
|
||||
"""DELETE /site-groups/{id} soft-deletes and ungroups sites."""
|
||||
group_id = uuid.uuid4()
|
||||
group = _mock_group(id=group_id)
|
||||
site = MagicMock()
|
||||
site.site_group_id = group_id
|
||||
db = _mock_db_sequence(group, [site]) # group lookup, sites in group
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(
|
||||
f"/api/v1/site-groups/{group_id}",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
assert site.site_group_id is None
|
||||
assert group.deleted_at is not None
|
||||
266
apps/api/tests/test_routers_sites.py
Normal file
266
apps/api/tests/test_routers_sites.py
Normal file
@@ -0,0 +1,266 @@
|
||||
"""Unit tests for sites router — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token
|
||||
|
||||
ORG_ID = uuid.uuid4()
|
||||
USER_ID = uuid.uuid4()
|
||||
|
||||
|
||||
def _auth_headers():
|
||||
token = create_access_token(
|
||||
user_id=USER_ID, organisation_id=ORG_ID, role="owner", email="admin@test.com"
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _mock_site(**overrides):
|
||||
site = MagicMock()
|
||||
site.id = overrides.get("id", uuid.uuid4())
|
||||
site.organisation_id = overrides.get("organisation_id", ORG_ID)
|
||||
site.domain = overrides.get("domain", "example.com")
|
||||
site.display_name = overrides.get("display_name", "Example Site")
|
||||
site.is_active = overrides.get("is_active", True)
|
||||
site.additional_domains = overrides.get("additional_domains")
|
||||
site.site_group_id = overrides.get("site_group_id")
|
||||
site.deleted_at = None
|
||||
site.created_at = datetime.now(UTC)
|
||||
site.updated_at = datetime.now(UTC)
|
||||
# Alias for SiteResponse.name field
|
||||
site.name = site.display_name
|
||||
return site
|
||||
|
||||
|
||||
def _mock_config(**overrides):
|
||||
config = MagicMock(spec=[]) # spec=[] prevents auto-attr generation
|
||||
config.id = overrides.get("id", uuid.uuid4())
|
||||
config.site_id = overrides.get("site_id", uuid.uuid4())
|
||||
config.blocking_mode = overrides.get("blocking_mode", "opt_in")
|
||||
config.tcf_enabled = overrides.get("tcf_enabled", False)
|
||||
config.tcf_publisher_cc = overrides.get("tcf_publisher_cc")
|
||||
config.gpp_enabled = overrides.get("gpp_enabled", True)
|
||||
config.gpp_supported_apis = overrides.get("gpp_supported_apis", ["usnat"])
|
||||
config.gpc_enabled = overrides.get("gpc_enabled", True)
|
||||
default_jurisdictions = ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"]
|
||||
config.gpc_jurisdictions = overrides.get("gpc_jurisdictions", default_jurisdictions)
|
||||
config.gpc_global_honour = overrides.get("gpc_global_honour", False)
|
||||
config.gcm_enabled = overrides.get("gcm_enabled", True)
|
||||
config.gcm_default = overrides.get("gcm_default")
|
||||
config.banner_config = overrides.get("banner_config", {})
|
||||
config.regional_modes = overrides.get("regional_modes")
|
||||
config.privacy_policy_url = overrides.get("privacy_policy_url")
|
||||
config.scan_schedule_cron = overrides.get("scan_schedule_cron")
|
||||
config.scan_max_pages = overrides.get("scan_max_pages", 50)
|
||||
config.consent_expiry_days = overrides.get("consent_expiry_days", 365)
|
||||
config.created_at = datetime.now(UTC)
|
||||
config.updated_at = datetime.now(UTC)
|
||||
return config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
def _mock_db_sequence(*results):
|
||||
"""Create a mock session that returns different results on successive execute() calls."""
|
||||
session = AsyncMock()
|
||||
mock_results = []
|
||||
for r in results:
|
||||
result = MagicMock()
|
||||
if isinstance(r, list):
|
||||
result.scalar_one_or_none.return_value = r[0] if r else None
|
||||
scalars_obj = MagicMock()
|
||||
scalars_obj.all.return_value = r
|
||||
result.scalars.return_value = scalars_obj
|
||||
else:
|
||||
result.scalar_one_or_none.return_value = r
|
||||
mock_results.append(result)
|
||||
session.execute = AsyncMock(side_effect=mock_results)
|
||||
|
||||
_added = []
|
||||
|
||||
def _fake_add(obj):
|
||||
_added.append(obj)
|
||||
|
||||
session.add = MagicMock(side_effect=_fake_add)
|
||||
|
||||
async def _fake_flush():
|
||||
for obj in _added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = uuid.uuid4()
|
||||
if hasattr(obj, "is_active") and getattr(obj, "is_active", None) is None:
|
||||
obj.is_active = True
|
||||
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(UTC)
|
||||
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(UTC)
|
||||
|
||||
session.flush = AsyncMock(side_effect=_fake_flush)
|
||||
session.refresh = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
class TestSiteCRUD:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_site_success(self, mock_app):
|
||||
# First execute: check existing (None), second: after flush
|
||||
db = _mock_db_sequence(None) # no duplicate
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/sites/",
|
||||
json={"domain": "new-site.com", "display_name": "New Site"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_site_duplicate(self, mock_app):
|
||||
existing_site = _mock_site(domain="dup.com")
|
||||
db = _mock_db_sequence(existing_site)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/sites/",
|
||||
json={"domain": "dup.com", "display_name": "Dup Site"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_sites(self, mock_app):
|
||||
sites = [_mock_site(), _mock_site(domain="two.com")]
|
||||
db = _mock_db_sequence(sites)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get("/api/v1/sites/", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_site_success(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/sites/{site.id}", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_site_not_found(self, mock_app):
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/sites/{uuid.uuid4()}", headers=_auth_headers())
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_site(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/sites/{site.id}",
|
||||
json={"display_name": "Updated"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_site(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(f"/api/v1/sites/{site.id}", headers=_auth_headers())
|
||||
assert resp.status_code == 204
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_site_requires_auth(self, mock_app):
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/sites/", json={"domain": "noauth.com", "display_name": "No Auth"}
|
||||
)
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
|
||||
class TestSiteConfig:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_config_success(self, mock_app):
|
||||
site = _mock_site()
|
||||
config = _mock_config(site_id=site.id)
|
||||
db = _mock_db_sequence(site, config)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/sites/{site.id}/config", headers=_auth_headers())
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_config_not_found(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/sites/{site.id}/config", headers=_auth_headers())
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_config_create(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None) # site found, no existing config
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
f"/api/v1/sites/{site.id}/config",
|
||||
json={"blocking_mode": "opt_in"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_config_replace(self, mock_app):
|
||||
site = _mock_site()
|
||||
config = _mock_config(site_id=site.id)
|
||||
db = _mock_db_sequence(site, config)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
f"/api/v1/sites/{site.id}/config",
|
||||
json={"blocking_mode": "opt_out"},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_config_success(self, mock_app):
|
||||
site = _mock_site()
|
||||
config = _mock_config(site_id=site.id)
|
||||
db = _mock_db_sequence(site, config)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/sites/{site.id}/config",
|
||||
json={"gcm_enabled": False, "consent_expiry_days": 180},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_patch_config_not_found(self, mock_app):
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.patch(
|
||||
f"/api/v1/sites/{site.id}/config",
|
||||
json={"gcm_enabled": False},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
277
apps/api/tests/test_routers_translations.py
Normal file
277
apps/api/tests/test_routers_translations.py
Normal file
@@ -0,0 +1,277 @@
|
||||
"""Unit tests for translations router — mocked database."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.main import create_app
|
||||
from src.services.auth import create_access_token
|
||||
|
||||
ORG_ID = uuid.uuid4()
|
||||
USER_ID = uuid.uuid4()
|
||||
SITE_ID = uuid.uuid4()
|
||||
|
||||
|
||||
def _auth_headers(role="owner"):
|
||||
token = create_access_token(
|
||||
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
|
||||
)
|
||||
return {"Authorization": f"Bearer {token}"}
|
||||
|
||||
|
||||
def _mock_site(**overrides):
|
||||
site = MagicMock()
|
||||
site.id = overrides.get("id", SITE_ID)
|
||||
site.organisation_id = overrides.get("organisation_id", ORG_ID)
|
||||
site.domain = "example.com"
|
||||
site.display_name = "Example"
|
||||
site.is_active = True
|
||||
site.deleted_at = None
|
||||
site.additional_domains = None
|
||||
site.site_group_id = None
|
||||
site.created_at = datetime.now(UTC)
|
||||
site.updated_at = datetime.now(UTC)
|
||||
return site
|
||||
|
||||
|
||||
def _mock_translation(**overrides):
|
||||
t = MagicMock()
|
||||
t.id = overrides.get("id", uuid.uuid4())
|
||||
t.site_id = overrides.get("site_id", SITE_ID)
|
||||
t.locale = overrides.get("locale", "fr")
|
||||
t.strings = overrides.get(
|
||||
"strings",
|
||||
{"title": "Nous utilisons des cookies", "acceptAll": "Tout accepter"},
|
||||
)
|
||||
t.created_at = datetime.now(UTC)
|
||||
t.updated_at = datetime.now(UTC)
|
||||
return t
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_app():
|
||||
return create_app()
|
||||
|
||||
|
||||
async def _client(app, mock_session):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override():
|
||||
yield mock_session
|
||||
|
||||
app.dependency_overrides[get_db] = _override
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
def _mock_db_sequence(*results):
|
||||
"""Create a mock session returning different results on successive execute() calls."""
|
||||
session = AsyncMock()
|
||||
mock_results = []
|
||||
for r in results:
|
||||
result = MagicMock()
|
||||
if isinstance(r, list):
|
||||
scalars_obj = MagicMock()
|
||||
scalars_obj.all.return_value = r
|
||||
result.scalars.return_value = scalars_obj
|
||||
result.scalar_one_or_none.return_value = r[0] if r else None
|
||||
else:
|
||||
result.scalar_one_or_none.return_value = r
|
||||
mock_results.append(result)
|
||||
session.execute = AsyncMock(side_effect=mock_results)
|
||||
|
||||
_added = []
|
||||
|
||||
def _fake_add(obj):
|
||||
_added.append(obj)
|
||||
|
||||
session.add = MagicMock(side_effect=_fake_add)
|
||||
|
||||
async def _fake_flush():
|
||||
for obj in _added:
|
||||
if getattr(obj, "id", None) is None:
|
||||
obj.id = uuid.uuid4()
|
||||
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
|
||||
obj.created_at = datetime.now(UTC)
|
||||
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
|
||||
obj.updated_at = datetime.now(UTC)
|
||||
|
||||
session.flush = AsyncMock(side_effect=_fake_flush)
|
||||
session.refresh = AsyncMock()
|
||||
session.delete = AsyncMock()
|
||||
return session
|
||||
|
||||
|
||||
class TestListTranslations:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_translations(self, mock_app):
|
||||
"""GET /sites/{id}/translations/ returns all translations."""
|
||||
site = _mock_site()
|
||||
fr = _mock_translation(locale="fr")
|
||||
de = _mock_translation(locale="de")
|
||||
db = _mock_db_sequence(site, [fr, de])
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert len(data) == 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_translations_empty(self, mock_app):
|
||||
"""GET /sites/{id}/translations/ returns empty list when no translations."""
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, [])
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
class TestGetTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_translation(self, mock_app):
|
||||
"""GET /sites/{id}/translations/fr returns the French translation."""
|
||||
site = _mock_site()
|
||||
fr = _mock_translation(locale="fr")
|
||||
db = _mock_db_sequence(site, fr)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/fr",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["locale"] == "fr"
|
||||
assert "title" in data["strings"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_translation_not_found(self, mock_app):
|
||||
"""GET /sites/{id}/translations/xx returns 404."""
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/xx",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestCreateTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_translation(self, mock_app):
|
||||
"""POST /sites/{id}/translations/ creates a new translation."""
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None) # site lookup, duplicate check
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/",
|
||||
json={
|
||||
"locale": "de",
|
||||
"strings": {"title": "Wir verwenden Cookies"},
|
||||
},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["locale"] == "de"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_translation_conflict(self, mock_app):
|
||||
"""POST returns 409 when locale already exists."""
|
||||
site = _mock_site()
|
||||
existing = _mock_translation(locale="fr")
|
||||
db = _mock_db_sequence(site, existing)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.post(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/",
|
||||
json={"locale": "fr", "strings": {"title": "test"}},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
class TestUpdateTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_translation(self, mock_app):
|
||||
"""PUT /sites/{id}/translations/fr updates the strings."""
|
||||
site = _mock_site()
|
||||
fr = _mock_translation(locale="fr")
|
||||
db = _mock_db_sequence(site, fr)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/fr",
|
||||
json={"strings": {"title": "Updated title"}},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert fr.strings == {"title": "Updated title"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_translation_not_found(self, mock_app):
|
||||
"""PUT returns 404 when locale does not exist."""
|
||||
site = _mock_site()
|
||||
db = _mock_db_sequence(site, None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.put(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/xx",
|
||||
json={"strings": {"title": "test"}},
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestDeleteTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_translation(self, mock_app):
|
||||
"""DELETE /sites/{id}/translations/fr removes the translation."""
|
||||
site = _mock_site()
|
||||
fr = _mock_translation(locale="fr")
|
||||
db = _mock_db_sequence(site, fr)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/fr",
|
||||
headers=_auth_headers(),
|
||||
)
|
||||
assert resp.status_code == 204
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_requires_admin(self, mock_app):
|
||||
"""DELETE returns 403 for editors."""
|
||||
db = _mock_db_sequence()
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.delete(
|
||||
f"/api/v1/sites/{SITE_ID}/translations/fr",
|
||||
headers=_auth_headers(role="editor"),
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
class TestPublicTranslation:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_public_translation(self, mock_app):
|
||||
"""GET /translations/{site_id}/fr returns raw strings (no auth)."""
|
||||
fr = _mock_translation(locale="fr")
|
||||
db = _mock_db_sequence(fr)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/translations/{SITE_ID}/fr")
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "title" in data
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_public_translation_not_found(self, mock_app):
|
||||
"""GET /translations/{site_id}/xx returns 404."""
|
||||
db = _mock_db_sequence(None)
|
||||
async with await _client(mock_app, db) as client:
|
||||
resp = await client.get(f"/api/v1/translations/{SITE_ID}/xx")
|
||||
assert resp.status_code == 404
|
||||
593
apps/api/tests/test_scan_scheduling.py
Normal file
593
apps/api/tests/test_scan_scheduling.py
Normal file
@@ -0,0 +1,593 @@
|
||||
"""Tests for scan scheduling, diff engine, and scan endpoints — CMP-24.
|
||||
|
||||
Covers:
|
||||
- Scanner schemas (new additions)
|
||||
- Scan service (job lifecycle, diff engine, cookie sync)
|
||||
- Scanner router (trigger, list, detail, diff endpoints)
|
||||
- Integration tests against live database
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.schemas.scanner import (
|
||||
CookieDiffItem,
|
||||
DiffStatus,
|
||||
ScanDiffResponse,
|
||||
ScanJobDetailResponse,
|
||||
ScanResultResponse,
|
||||
TriggerScanRequest,
|
||||
)
|
||||
|
||||
# ── Schema tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSchemas:
|
||||
"""Validate scanner schema additions."""
|
||||
|
||||
def test_scan_result_response(self):
|
||||
r = ScanResultResponse(
|
||||
id=uuid.uuid4(),
|
||||
scan_job_id=uuid.uuid4(),
|
||||
page_url="https://example.com",
|
||||
cookie_name="_ga",
|
||||
cookie_domain=".example.com",
|
||||
storage_type="cookie",
|
||||
found_at=datetime.now(UTC),
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
assert r.cookie_name == "_ga"
|
||||
|
||||
def test_scan_job_detail_response(self):
|
||||
r = ScanJobDetailResponse(
|
||||
id=uuid.uuid4(),
|
||||
site_id=uuid.uuid4(),
|
||||
status="completed",
|
||||
trigger="manual",
|
||||
pages_scanned=5,
|
||||
pages_total=10,
|
||||
cookies_found=3,
|
||||
error_message=None,
|
||||
started_at=datetime.now(UTC),
|
||||
completed_at=datetime.now(UTC),
|
||||
created_at=datetime.now(UTC),
|
||||
updated_at=datetime.now(UTC),
|
||||
results=[],
|
||||
)
|
||||
assert r.status == "completed"
|
||||
assert r.results == []
|
||||
|
||||
def test_trigger_scan_request(self):
|
||||
req = TriggerScanRequest(site_id=uuid.uuid4(), max_pages=100)
|
||||
assert req.max_pages == 100
|
||||
|
||||
def test_trigger_scan_request_defaults(self):
|
||||
req = TriggerScanRequest(site_id=uuid.uuid4())
|
||||
assert req.max_pages == 50
|
||||
|
||||
def test_trigger_scan_max_pages_validation(self):
|
||||
with pytest.raises(ValueError):
|
||||
TriggerScanRequest(site_id=uuid.uuid4(), max_pages=0)
|
||||
with pytest.raises(ValueError):
|
||||
TriggerScanRequest(site_id=uuid.uuid4(), max_pages=501)
|
||||
|
||||
def test_diff_status_values(self):
|
||||
assert DiffStatus.NEW == "new"
|
||||
assert DiffStatus.REMOVED == "removed"
|
||||
assert DiffStatus.CHANGED == "changed"
|
||||
|
||||
def test_cookie_diff_item(self):
|
||||
item = CookieDiffItem(
|
||||
name="_ga",
|
||||
domain=".example.com",
|
||||
storage_type="cookie",
|
||||
diff_status=DiffStatus.NEW,
|
||||
details="First scan",
|
||||
)
|
||||
assert item.diff_status == "new"
|
||||
|
||||
def test_scan_diff_response(self):
|
||||
resp = ScanDiffResponse(
|
||||
current_scan_id=uuid.uuid4(),
|
||||
previous_scan_id=uuid.uuid4(),
|
||||
new_cookies=[
|
||||
CookieDiffItem(
|
||||
name="_ga",
|
||||
domain=".example.com",
|
||||
storage_type="cookie",
|
||||
diff_status=DiffStatus.NEW,
|
||||
),
|
||||
],
|
||||
total_new=1,
|
||||
)
|
||||
assert resp.total_new == 1
|
||||
assert len(resp.new_cookies) == 1
|
||||
|
||||
def test_scan_diff_response_no_previous(self):
|
||||
resp = ScanDiffResponse(
|
||||
current_scan_id=uuid.uuid4(),
|
||||
previous_scan_id=None,
|
||||
)
|
||||
assert resp.previous_scan_id is None
|
||||
assert resp.total_new == 0
|
||||
|
||||
|
||||
# ── Diff engine unit tests ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDiffEngine:
|
||||
"""Test the scan diff engine with mocked data."""
|
||||
|
||||
def _make_scan_result(
|
||||
self,
|
||||
name: str = "_ga",
|
||||
domain: str = ".example.com",
|
||||
storage_type: str = "cookie",
|
||||
script_source: str | None = None,
|
||||
auto_category: str | None = None,
|
||||
attributes: dict | None = None,
|
||||
):
|
||||
"""Create a mock ScanResult."""
|
||||
mock = MagicMock()
|
||||
mock.cookie_name = name
|
||||
mock.cookie_domain = domain
|
||||
mock.storage_type = storage_type
|
||||
mock.script_source = script_source
|
||||
mock.auto_category = auto_category
|
||||
mock.attributes = attributes
|
||||
return mock
|
||||
|
||||
def test_result_key(self):
|
||||
from src.services.scanner import _result_key
|
||||
|
||||
mock = self._make_scan_result("_ga", ".example.com", "cookie")
|
||||
assert _result_key(mock) == ("_ga", ".example.com", "cookie")
|
||||
|
||||
def test_result_key_different_storage(self):
|
||||
from src.services.scanner import _result_key
|
||||
|
||||
mock = self._make_scan_result("key", "example.com", "local_storage")
|
||||
assert _result_key(mock) == ("key", "example.com", "local_storage")
|
||||
|
||||
|
||||
# ── Scan service unit tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestScanService:
|
||||
"""Test scan service functions with mocked DB."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_scan_job(self):
|
||||
from src.services.scanner import create_scan_job
|
||||
|
||||
db = AsyncMock()
|
||||
db.add = MagicMock()
|
||||
db.flush = AsyncMock()
|
||||
|
||||
site_id = uuid.uuid4()
|
||||
job = await create_scan_job(db, site_id=site_id, trigger="manual", max_pages=10)
|
||||
|
||||
assert job.site_id == site_id
|
||||
assert job.status == "pending"
|
||||
assert job.trigger == "manual"
|
||||
assert job.pages_total == 10
|
||||
db.add.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_scan_job(self):
|
||||
from src.services.scanner import start_scan_job
|
||||
|
||||
db = AsyncMock()
|
||||
db.flush = AsyncMock()
|
||||
|
||||
job = MagicMock()
|
||||
job.status = "pending"
|
||||
job.started_at = None
|
||||
|
||||
result = await start_scan_job(db, job)
|
||||
|
||||
assert result.status == "running"
|
||||
assert result.started_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_scan_job_success(self):
|
||||
from src.services.scanner import complete_scan_job
|
||||
|
||||
db = AsyncMock()
|
||||
db.flush = AsyncMock()
|
||||
|
||||
job = MagicMock()
|
||||
result = await complete_scan_job(db, job, pages_scanned=5, cookies_found=10)
|
||||
|
||||
assert result.status == "completed"
|
||||
assert result.pages_scanned == 5
|
||||
assert result.cookies_found == 10
|
||||
assert result.completed_at is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_scan_job_failure(self):
|
||||
from src.services.scanner import complete_scan_job
|
||||
|
||||
db = AsyncMock()
|
||||
db.flush = AsyncMock()
|
||||
|
||||
job = MagicMock()
|
||||
result = await complete_scan_job(db, job, error_message="Connection failed")
|
||||
|
||||
assert result.status == "failed"
|
||||
assert result.error_message == "Connection failed"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_scan_result(self):
|
||||
from src.services.scanner import add_scan_result
|
||||
|
||||
db = AsyncMock()
|
||||
db.add = MagicMock()
|
||||
db.flush = AsyncMock()
|
||||
|
||||
scan_job_id = uuid.uuid4()
|
||||
result = await add_scan_result(
|
||||
db,
|
||||
scan_job_id=scan_job_id,
|
||||
page_url="https://example.com",
|
||||
cookie_name="_ga",
|
||||
cookie_domain=".example.com",
|
||||
storage_type="cookie",
|
||||
auto_category="analytics",
|
||||
)
|
||||
|
||||
assert result.scan_job_id == scan_job_id
|
||||
assert result.cookie_name == "_ga"
|
||||
assert result.auto_category == "analytics"
|
||||
db.add.assert_called_once()
|
||||
|
||||
|
||||
# ── Router unit tests (mocked DB) ───────────────────────────────────
|
||||
|
||||
|
||||
def _mock_auth_user():
|
||||
"""Create a mock authenticated user."""
|
||||
from src.schemas.auth import CurrentUser
|
||||
|
||||
return CurrentUser(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
email="test@example.com",
|
||||
role="owner",
|
||||
)
|
||||
|
||||
|
||||
async def _authed_client(app, db, user=None):
|
||||
"""Create an authenticated test client with mocked DB."""
|
||||
from src.db import get_db
|
||||
from src.services.dependencies import get_current_user
|
||||
|
||||
if user is None:
|
||||
user = _mock_auth_user()
|
||||
|
||||
async def _override_get_db():
|
||||
yield db
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
app.dependency_overrides[get_current_user] = lambda: user
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestTriggerScan:
|
||||
"""Test POST /scanner/scans."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_scan_success(self, app):
|
||||
user = _mock_auth_user()
|
||||
db = AsyncMock()
|
||||
|
||||
# Site exists and belongs to user's org
|
||||
site_mock = MagicMock()
|
||||
site_mock.organisation_id = user.organisation_id
|
||||
|
||||
site_id = uuid.uuid4()
|
||||
job_id = uuid.uuid4()
|
||||
now = datetime.now(UTC)
|
||||
|
||||
# Mock scan job returned by create_scan_job
|
||||
mock_job = MagicMock()
|
||||
mock_job.id = job_id
|
||||
mock_job.site_id = site_id
|
||||
mock_job.status = "pending"
|
||||
mock_job.trigger = "manual"
|
||||
mock_job.pages_scanned = 0
|
||||
mock_job.pages_total = 25
|
||||
mock_job.cookies_found = 0
|
||||
mock_job.error_message = None
|
||||
mock_job.started_at = None
|
||||
mock_job.completed_at = None
|
||||
mock_job.created_at = now
|
||||
mock_job.updated_at = now
|
||||
|
||||
# First call: site lookup. Second call: running scan count.
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(stmt):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
result = MagicMock()
|
||||
if call_count == 1:
|
||||
# Site lookup
|
||||
result.scalar_one_or_none.return_value = site_mock
|
||||
elif call_count == 2:
|
||||
# Active scan jobs query — none running
|
||||
result.scalars.return_value.all.return_value = []
|
||||
return result
|
||||
|
||||
db.execute = mock_execute
|
||||
db.add = MagicMock()
|
||||
db.flush = AsyncMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"src.routers.scanner.create_scan_job",
|
||||
new=AsyncMock(return_value=mock_job),
|
||||
),
|
||||
patch("src.tasks.scanner.run_scan", create=True),
|
||||
):
|
||||
async with await _authed_client(app, db, user) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={
|
||||
"site_id": str(site_id),
|
||||
"max_pages": 25,
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 201
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_scan_site_not_found(self, app):
|
||||
db = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = None
|
||||
db.execute = AsyncMock(return_value=result)
|
||||
|
||||
async with await _authed_client(app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"max_pages": 50,
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_scan_conflict(self, app):
|
||||
user = _mock_auth_user()
|
||||
db = AsyncMock()
|
||||
|
||||
# Build a non-stale active job so the router raises 409
|
||||
active_job = MagicMock()
|
||||
active_job.status = "running"
|
||||
active_job.created_at = datetime.now(UTC)
|
||||
active_job.started_at = datetime.now(UTC)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(stmt):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
result = MagicMock()
|
||||
if call_count == 1:
|
||||
# Site lookup
|
||||
site_mock = MagicMock()
|
||||
site_mock.organisation_id = user.organisation_id
|
||||
result.scalar_one_or_none.return_value = site_mock
|
||||
elif call_count == 2:
|
||||
# Active scan jobs query — return a non-stale job
|
||||
result.scalars.return_value.all.return_value = [active_job]
|
||||
return result
|
||||
|
||||
db.execute = mock_execute
|
||||
|
||||
async with await _authed_client(app, db, user) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": str(uuid.uuid4())},
|
||||
)
|
||||
|
||||
assert resp.status_code == 409
|
||||
|
||||
|
||||
class TestListScans:
|
||||
"""Test GET /scanner/scans/site/{site_id}."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_scans_success(self, app):
|
||||
user = _mock_auth_user()
|
||||
db = AsyncMock()
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(stmt):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
result = MagicMock()
|
||||
if call_count == 1:
|
||||
# Site access check
|
||||
site_mock = MagicMock()
|
||||
site_mock.organisation_id = user.organisation_id
|
||||
result.scalar_one_or_none.return_value = site_mock
|
||||
else:
|
||||
# Scan list
|
||||
result.scalars.return_value.all.return_value = []
|
||||
return result
|
||||
|
||||
db.execute = mock_execute
|
||||
|
||||
async with await _authed_client(app, db, user) as client:
|
||||
resp = await client.get(f"/api/v1/scanner/scans/site/{uuid.uuid4()}")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == []
|
||||
|
||||
|
||||
class TestGetScan:
|
||||
"""Test GET /scanner/scans/{scan_id}."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_scan_not_found(self, app):
|
||||
db = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = None
|
||||
db.execute = AsyncMock(return_value=result)
|
||||
|
||||
async with await _authed_client(app, db) as client:
|
||||
resp = await client.get(f"/api/v1/scanner/scans/{uuid.uuid4()}")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestGetScanDiff:
|
||||
"""Test GET /scanner/scans/{scan_id}/diff."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_diff_scan_not_found(self, app):
|
||||
db = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = None
|
||||
db.execute = AsyncMock(return_value=result)
|
||||
|
||||
async with await _authed_client(app, db) as client:
|
||||
resp = await client.get(f"/api/v1/scanner/scans/{uuid.uuid4()}/diff")
|
||||
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ── Integration tests ────────────────────────────────────────────────
|
||||
|
||||
try:
|
||||
from tests.conftest import create_test_site, requires_db
|
||||
except ImportError:
|
||||
from conftest import create_test_site, requires_db
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestScanIntegration:
|
||||
"""Integration tests against a live database."""
|
||||
|
||||
async def test_trigger_scan(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-trigger")
|
||||
resp = await db_client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": site_id, "max_pages": 10},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
data = resp.json()
|
||||
assert data["status"] == "pending"
|
||||
assert data["trigger"] == "manual"
|
||||
assert data["pages_total"] == 10
|
||||
|
||||
async def test_trigger_scan_conflict(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-conflict")
|
||||
# First scan
|
||||
resp1 = await db_client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": site_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp1.status_code == 201
|
||||
|
||||
# Second scan — should conflict
|
||||
resp2 = await db_client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": site_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp2.status_code == 409
|
||||
|
||||
async def test_list_scans(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-list")
|
||||
# Trigger a scan
|
||||
await db_client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": site_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/scanner/scans/site/{site_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
scans = resp.json()
|
||||
assert len(scans) >= 1
|
||||
assert scans[0]["site_id"] == site_id
|
||||
|
||||
async def test_get_scan_detail(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-detail")
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": site_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
scan_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/scanner/scans/{scan_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == scan_id
|
||||
assert "results" in data
|
||||
|
||||
async def test_get_scan_diff(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-diff")
|
||||
create_resp = await db_client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": site_id},
|
||||
headers=auth_headers,
|
||||
)
|
||||
scan_id = create_resp.json()["id"]
|
||||
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/scanner/scans/{scan_id}/diff",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["current_scan_id"] == scan_id
|
||||
# No previous scan, so previous_scan_id should be null
|
||||
assert data["previous_scan_id"] is None
|
||||
|
||||
async def test_scan_not_found(self, db_client, auth_headers):
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/scanner/scans/{uuid.uuid4()}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_list_scans_pagination(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-page")
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/scanner/scans/site/{site_id}?limit=5&offset=0",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_trigger_scan_requires_auth(self, db_client):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/scanner/scans",
|
||||
json={"site_id": str(uuid.uuid4())},
|
||||
)
|
||||
assert resp.status_code in (401, 403)
|
||||
|
||||
async def test_list_scans_requires_auth(self, db_client):
|
||||
resp = await db_client.get(f"/api/v1/scanner/scans/site/{uuid.uuid4()}")
|
||||
assert resp.status_code in (401, 403)
|
||||
346
apps/api/tests/test_scanner_report.py
Normal file
346
apps/api/tests/test_scanner_report.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Tests for scanner cookie report endpoint — CMP-23 (API side).
|
||||
|
||||
Covers:
|
||||
- Schema validation
|
||||
- Report endpoint (unit tests with mocked DB)
|
||||
- Integration tests against live database
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.schemas.scanner import (
|
||||
CookieReportRequest,
|
||||
CookieReportResponse,
|
||||
ReportedCookie,
|
||||
ScanStatus,
|
||||
ScanTrigger,
|
||||
)
|
||||
|
||||
# ── Schema tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSchemas:
|
||||
"""Validate scanner schemas."""
|
||||
|
||||
def test_scan_status_values(self):
|
||||
assert ScanStatus.PENDING == "pending"
|
||||
assert ScanStatus.COMPLETED == "completed"
|
||||
|
||||
def test_scan_trigger_values(self):
|
||||
assert ScanTrigger.CLIENT_REPORT == "client_report"
|
||||
|
||||
def test_reported_cookie(self):
|
||||
rc = ReportedCookie(
|
||||
name="_ga",
|
||||
domain=".example.com",
|
||||
storage_type="cookie",
|
||||
value_length=30,
|
||||
)
|
||||
assert rc.name == "_ga"
|
||||
|
||||
def test_reported_cookie_validation(self):
|
||||
with pytest.raises(ValueError):
|
||||
ReportedCookie(name="", domain=".example.com")
|
||||
|
||||
def test_cookie_report_request(self):
|
||||
req = CookieReportRequest(
|
||||
site_id=uuid.uuid4(),
|
||||
page_url="https://example.com/page",
|
||||
cookies=[
|
||||
ReportedCookie(name="_ga", domain=".example.com"),
|
||||
],
|
||||
collected_at=datetime.now(),
|
||||
)
|
||||
assert len(req.cookies) == 1
|
||||
|
||||
def test_cookie_report_response(self):
|
||||
resp = CookieReportResponse(
|
||||
accepted=True,
|
||||
cookies_received=5,
|
||||
new_cookies=2,
|
||||
)
|
||||
assert resp.new_cookies == 2
|
||||
|
||||
|
||||
# ── Router unit tests (mocked DB) ───────────────────────────────────
|
||||
|
||||
|
||||
def _mock_db_with_site():
|
||||
"""Create a mock DB that returns a site for validation."""
|
||||
db = AsyncMock()
|
||||
site_mock = MagicMock()
|
||||
cookie_result = MagicMock()
|
||||
cookie_result.scalar_one_or_none.return_value = None # no existing cookie
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_execute(stmt):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# Site validation
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = site_mock
|
||||
return result
|
||||
# Cookie existence checks
|
||||
return cookie_result
|
||||
|
||||
db.execute = mock_execute
|
||||
db.add = MagicMock()
|
||||
db.flush = AsyncMock()
|
||||
return db
|
||||
|
||||
|
||||
async def _client(app, db):
|
||||
from src.db import get_db
|
||||
|
||||
async def _override_get_db():
|
||||
yield db
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
transport = ASGITransport(app=app)
|
||||
return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestReportEndpoint:
|
||||
"""Test POST /scanner/report."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_success(self, app):
|
||||
db = _mock_db_with_site()
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"page_url": "https://example.com",
|
||||
"cookies": [
|
||||
{
|
||||
"name": "_ga",
|
||||
"domain": ".example.com",
|
||||
"storage_type": "cookie",
|
||||
"value_length": 30,
|
||||
},
|
||||
],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
assert data["accepted"] is True
|
||||
assert data["cookies_received"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_site_not_found(self, app):
|
||||
db = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = None
|
||||
db.execute.return_value = result
|
||||
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"page_url": "https://example.com",
|
||||
"cookies": [],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_empty_cookies(self, app):
|
||||
db = AsyncMock()
|
||||
site_result = MagicMock()
|
||||
site_result.scalar_one_or_none.return_value = MagicMock()
|
||||
db.execute.return_value = site_result
|
||||
db.flush = AsyncMock()
|
||||
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"page_url": "https://example.com",
|
||||
"cookies": [],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
assert resp.json()["cookies_received"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_report_multiple_storage_types(self, app):
|
||||
db = _mock_db_with_site()
|
||||
async with await _client(app, db) as client:
|
||||
resp = await client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"page_url": "https://example.com",
|
||||
"cookies": [
|
||||
{
|
||||
"name": "_ga",
|
||||
"domain": ".example.com",
|
||||
"storage_type": "cookie",
|
||||
"value_length": 30,
|
||||
},
|
||||
{
|
||||
"name": "analytics_id",
|
||||
"domain": "example.com",
|
||||
"storage_type": "local_storage",
|
||||
"value_length": 10,
|
||||
},
|
||||
{
|
||||
"name": "session_key",
|
||||
"domain": "example.com",
|
||||
"storage_type": "session_storage",
|
||||
"value_length": 20,
|
||||
},
|
||||
],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
assert resp.json()["cookies_received"] == 3
|
||||
|
||||
|
||||
# ── Integration tests ────────────────────────────────────────────────
|
||||
|
||||
|
||||
try:
|
||||
from tests.conftest import create_test_site, requires_db
|
||||
except ImportError:
|
||||
from conftest import create_test_site, requires_db
|
||||
|
||||
|
||||
@requires_db
|
||||
class TestScannerReportIntegration:
|
||||
"""Integration tests against a live database."""
|
||||
|
||||
async def test_report_creates_new_cookies(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-new")
|
||||
resp = await db_client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"page_url": "https://report-new.com/page",
|
||||
"cookies": [
|
||||
{
|
||||
"name": "_ga",
|
||||
"domain": ".report-new.com",
|
||||
"storage_type": "cookie",
|
||||
"value_length": 30,
|
||||
},
|
||||
{
|
||||
"name": "analytics_id",
|
||||
"domain": "report-new.com",
|
||||
"storage_type": "local_storage",
|
||||
"value_length": 10,
|
||||
},
|
||||
],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
assert data["cookies_received"] == 2
|
||||
assert data["new_cookies"] == 2
|
||||
|
||||
# Verify cookies were created
|
||||
cookies_resp = await db_client.get(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert cookies_resp.status_code == 200
|
||||
cookies = cookies_resp.json()
|
||||
names = [c["name"] for c in cookies]
|
||||
assert "_ga" in names
|
||||
assert "analytics_id" in names
|
||||
|
||||
async def test_report_deduplicates_existing_cookies(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-dedup")
|
||||
report_payload = {
|
||||
"site_id": site_id,
|
||||
"page_url": "https://report-dedup.com",
|
||||
"cookies": [
|
||||
{
|
||||
"name": "_dedup_cookie",
|
||||
"domain": ".report-dedup.com",
|
||||
"storage_type": "cookie",
|
||||
"value_length": 10,
|
||||
},
|
||||
],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
}
|
||||
|
||||
# First report — should create
|
||||
resp1 = await db_client.post("/api/v1/scanner/report", json=report_payload)
|
||||
assert resp1.status_code == 202
|
||||
assert resp1.json()["new_cookies"] == 1
|
||||
|
||||
# Second report — should not create duplicate
|
||||
resp2 = await db_client.post("/api/v1/scanner/report", json=report_payload)
|
||||
assert resp2.status_code == 202
|
||||
assert resp2.json()["new_cookies"] == 0
|
||||
|
||||
async def test_report_sets_review_status_pending(self, db_client, auth_headers):
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-status")
|
||||
await db_client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"page_url": "https://report-status.com",
|
||||
"cookies": [
|
||||
{
|
||||
"name": "_status_cookie",
|
||||
"domain": ".report-status.com",
|
||||
"storage_type": "cookie",
|
||||
"value_length": 5,
|
||||
},
|
||||
],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
# Check the created cookie's review status
|
||||
cookies_resp = await db_client.get(
|
||||
f"/api/v1/cookies/sites/{site_id}",
|
||||
headers=auth_headers,
|
||||
)
|
||||
cookies = cookies_resp.json()
|
||||
status_cookie = next((c for c in cookies if c["name"] == "_status_cookie"), None)
|
||||
assert status_cookie is not None
|
||||
assert status_cookie["review_status"] == "pending"
|
||||
|
||||
async def test_report_no_auth_required(self, db_client, auth_headers):
|
||||
"""Report endpoint should work without authentication."""
|
||||
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-noauth")
|
||||
# POST without auth headers
|
||||
resp = await db_client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": site_id,
|
||||
"page_url": "https://report-noauth.com",
|
||||
"cookies": [],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 202
|
||||
|
||||
async def test_report_invalid_site(self, db_client):
|
||||
resp = await db_client.post(
|
||||
"/api/v1/scanner/report",
|
||||
json={
|
||||
"site_id": str(uuid.uuid4()),
|
||||
"page_url": "https://unknown.com",
|
||||
"cookies": [],
|
||||
"collected_at": datetime.now().isoformat(),
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
37
apps/api/tests/test_settings.py
Normal file
37
apps/api/tests/test_settings.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Tests for application settings parsing."""
|
||||
|
||||
from src.config.settings import Settings
|
||||
|
||||
|
||||
class TestAllowedOrigins:
|
||||
"""Tests for the allowed_origins_list property."""
|
||||
|
||||
def test_comma_separated_string(self) -> None:
|
||||
"""Comma-separated string is parsed into a list."""
|
||||
settings = Settings(allowed_origins="https://a.com,https://b.com")
|
||||
assert settings.allowed_origins_list == ["https://a.com", "https://b.com"]
|
||||
|
||||
def test_comma_separated_with_spaces(self) -> None:
|
||||
"""Whitespace around commas is stripped."""
|
||||
settings = Settings(allowed_origins="https://a.com , https://b.com")
|
||||
assert settings.allowed_origins_list == ["https://a.com", "https://b.com"]
|
||||
|
||||
def test_single_origin_string(self) -> None:
|
||||
"""A single origin string (no comma) is a single-element list."""
|
||||
settings = Settings(allowed_origins="https://a.com")
|
||||
assert settings.allowed_origins_list == ["https://a.com"]
|
||||
|
||||
def test_empty_string(self) -> None:
|
||||
"""An empty string results in an empty list."""
|
||||
settings = Settings(allowed_origins="")
|
||||
assert settings.allowed_origins_list == []
|
||||
|
||||
def test_trailing_comma_ignored(self) -> None:
|
||||
"""Trailing commas don't produce empty entries."""
|
||||
settings = Settings(allowed_origins="https://a.com,")
|
||||
assert settings.allowed_origins_list == ["https://a.com"]
|
||||
|
||||
def test_default_value(self) -> None:
|
||||
"""Default value is localhost:5173."""
|
||||
settings = Settings()
|
||||
assert settings.allowed_origins_list == ["http://localhost:5173"]
|
||||
132
apps/api/tests/test_site_crud.py
Normal file
132
apps/api/tests/test_site_crud.py
Normal file
@@ -0,0 +1,132 @@
|
||||
"""Tests for site and site config CRUD endpoints and schemas."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.schemas.site import (
|
||||
BlockingMode,
|
||||
SiteConfigCreate,
|
||||
SiteConfigResponse,
|
||||
SiteConfigUpdate,
|
||||
SiteCreate,
|
||||
SiteResponse,
|
||||
SiteUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestSiteSchemas:
|
||||
def test_create_valid(self):
|
||||
site = SiteCreate(domain="example.com", display_name="Example Site")
|
||||
assert site.domain == "example.com"
|
||||
|
||||
def test_create_empty_domain_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
SiteCreate(domain="", display_name="Test")
|
||||
|
||||
def test_update_partial(self):
|
||||
update = SiteUpdate(display_name="New Name")
|
||||
data = update.model_dump(exclude_unset=True)
|
||||
assert data == {"display_name": "New Name"}
|
||||
|
||||
def test_response_from_attributes(self):
|
||||
now = "2026-01-01T00:00:00Z"
|
||||
resp = SiteResponse(
|
||||
id=uuid.uuid4(),
|
||||
organisation_id=uuid.uuid4(),
|
||||
domain="example.com",
|
||||
display_name="Example",
|
||||
is_active=True,
|
||||
additional_domains=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
assert resp.is_active
|
||||
|
||||
|
||||
class TestSiteConfigSchemas:
|
||||
def test_create_defaults(self):
|
||||
config = SiteConfigCreate()
|
||||
assert config.blocking_mode == BlockingMode.OPT_IN
|
||||
assert config.gcm_enabled is True
|
||||
assert config.tcf_enabled is False
|
||||
assert config.scan_max_pages == 50
|
||||
assert config.consent_expiry_days == 365
|
||||
|
||||
def test_create_with_regional_modes(self):
|
||||
config = SiteConfigCreate(
|
||||
regional_modes={"EU": "opt_in", "US-CA": "opt_out", "DEFAULT": "opt_in"}
|
||||
)
|
||||
assert config.regional_modes["EU"] == "opt_in"
|
||||
|
||||
def test_scan_max_pages_bounds(self):
|
||||
with pytest.raises(ValidationError):
|
||||
SiteConfigCreate(scan_max_pages=0)
|
||||
with pytest.raises(ValidationError):
|
||||
SiteConfigCreate(scan_max_pages=1001)
|
||||
|
||||
def test_consent_expiry_bounds(self):
|
||||
with pytest.raises(ValidationError):
|
||||
SiteConfigCreate(consent_expiry_days=0)
|
||||
with pytest.raises(ValidationError):
|
||||
SiteConfigCreate(consent_expiry_days=731)
|
||||
|
||||
def test_update_partial(self):
|
||||
update = SiteConfigUpdate(blocking_mode=BlockingMode.OPT_OUT)
|
||||
data = update.model_dump(exclude_unset=True)
|
||||
assert data == {"blocking_mode": "opt_out"}
|
||||
|
||||
def test_response_from_attributes(self):
|
||||
now = "2026-01-01T00:00:00Z"
|
||||
resp = SiteConfigResponse(
|
||||
id=uuid.uuid4(),
|
||||
site_id=uuid.uuid4(),
|
||||
blocking_mode="opt_in",
|
||||
regional_modes=None,
|
||||
tcf_enabled=False,
|
||||
tcf_publisher_cc=None,
|
||||
gcm_enabled=True,
|
||||
gcm_default=None,
|
||||
banner_config=None,
|
||||
privacy_policy_url=None,
|
||||
scan_schedule_cron=None,
|
||||
scan_max_pages=50,
|
||||
consent_expiry_days=365,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
)
|
||||
assert resp.blocking_mode == "opt_in"
|
||||
|
||||
def test_display_mode_in_banner_config(self):
|
||||
"""Display mode is stored inside banner_config, not as a top-level field."""
|
||||
config = SiteConfigCreate(
|
||||
banner_config={"displayMode": "overlay"},
|
||||
)
|
||||
assert config.banner_config["displayMode"] == "overlay"
|
||||
|
||||
|
||||
class TestEnums:
|
||||
def test_blocking_modes(self):
|
||||
assert BlockingMode.OPT_IN == "opt_in"
|
||||
assert BlockingMode.OPT_OUT == "opt_out"
|
||||
assert BlockingMode.INFORMATIONAL == "informational"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestSiteRoutesRegistered:
|
||||
async def test_site_routes_exist(self, client):
|
||||
response = await client.get("/openapi.json")
|
||||
paths = response.json()["paths"]
|
||||
assert "/api/v1/sites/" in paths
|
||||
assert "/api/v1/sites/{site_id}" in paths
|
||||
assert "/api/v1/sites/{site_id}/config" in paths
|
||||
|
||||
async def test_site_endpoints_require_auth(self, client):
|
||||
response = await client.get("/api/v1/sites/")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_site_config_endpoints_require_auth(self, client):
|
||||
site_id = uuid.uuid4()
|
||||
response = await client.get(f"/api/v1/sites/{site_id}/config")
|
||||
assert response.status_code == 401
|
||||
104
apps/api/tests/test_site_group_config.py
Normal file
104
apps/api/tests/test_site_group_config.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Tests for site group config endpoints."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.conftest import requires_db
|
||||
|
||||
|
||||
class TestSiteGroupConfigRoutes:
|
||||
"""Unit tests — no database required."""
|
||||
|
||||
def test_group_config_get_route_registered(self, app):
|
||||
routes = [r.path for r in app.routes]
|
||||
assert "/api/v1/site-groups/{group_id}/config" in routes
|
||||
|
||||
def test_group_config_put_route_registered(self, app):
|
||||
routes = [r.path for r in app.routes]
|
||||
assert "/api/v1/site-groups/{group_id}/config" in routes
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_group_config_requires_auth(self, client):
|
||||
group_id = uuid.uuid4()
|
||||
resp = await client.get(f"/api/v1/site-groups/{group_id}/config")
|
||||
assert resp.status_code == 401
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_put_group_config_requires_auth(self, client):
|
||||
group_id = uuid.uuid4()
|
||||
resp = await client.put(
|
||||
f"/api/v1/site-groups/{group_id}/config",
|
||||
json={"blocking_mode": "opt_in"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
class TestSiteGroupConfigIntegration:
|
||||
"""Integration tests — require a running PostgreSQL database."""
|
||||
|
||||
@requires_db
|
||||
async def test_create_group_and_get_config(self, db_client, auth_headers):
|
||||
# Create a group
|
||||
resp = await db_client.post(
|
||||
"/api/v1/site-groups/",
|
||||
json={"name": f"test-group-{uuid.uuid4().hex[:8]}"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
group_id = resp.json()["id"]
|
||||
|
||||
# GET config (auto-creates empty row)
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/site-groups/{group_id}/config",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["site_group_id"] == group_id
|
||||
assert data["blocking_mode"] is None
|
||||
assert data["consent_expiry_days"] is None
|
||||
|
||||
@requires_db
|
||||
async def test_update_group_config(self, db_client, auth_headers):
|
||||
# Create a group
|
||||
resp = await db_client.post(
|
||||
"/api/v1/site-groups/",
|
||||
json={"name": f"cfg-group-{uuid.uuid4().hex[:8]}"},
|
||||
headers=auth_headers,
|
||||
)
|
||||
group_id = resp.json()["id"]
|
||||
|
||||
# PUT config
|
||||
resp = await db_client.put(
|
||||
f"/api/v1/site-groups/{group_id}/config",
|
||||
json={
|
||||
"blocking_mode": "opt_out",
|
||||
"consent_expiry_days": 90,
|
||||
"tcf_enabled": True,
|
||||
},
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["blocking_mode"] == "opt_out"
|
||||
assert data["consent_expiry_days"] == 90
|
||||
assert data["tcf_enabled"] is True
|
||||
|
||||
# GET confirms persistence
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/site-groups/{group_id}/config",
|
||||
headers=auth_headers,
|
||||
)
|
||||
data = resp.json()
|
||||
assert data["blocking_mode"] == "opt_out"
|
||||
assert data["consent_expiry_days"] == 90
|
||||
|
||||
@requires_db
|
||||
async def test_group_config_not_found_for_other_org(self, db_client, auth_headers):
|
||||
fake_group_id = str(uuid.uuid4())
|
||||
resp = await db_client.get(
|
||||
f"/api/v1/site-groups/{fake_group_id}/config",
|
||||
headers=auth_headers,
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
Reference in New Issue
Block a user