feat: initial public release

ConsentOS — a privacy-first cookie consent management platform.

Self-hosted, source-available alternative to OneTrust, Cookiebot, and
CookieYes. Full standards coverage (IAB TCF v2.2, GPP v1, Google
Consent Mode v2, GPC, Shopify Customer Privacy API), multi-tenant
architecture with role-based access, configuration cascade
(system → org → group → site → region), dark-pattern detection in
the scanner, and a tamper-evident consent record audit trail.

This is the initial public release. Prior development history is
retained internally.

See README.md for the feature list, architecture overview, and
quick-start instructions. Licensed under the Elastic Licence 2.0 —
self-host freely; do not resell as a managed service.
This commit is contained in:
James Cottrill
2026-04-13 14:20:15 +00:00
commit fbf26453f2
341 changed files with 62807 additions and 0 deletions

View File

241
apps/api/tests/conftest.py Normal file
View File

@@ -0,0 +1,241 @@
"""Shared test fixtures for the CMP API test suite.
Provides two modes:
- Unit tests: use `app` and `client` fixtures (no database required)
- Integration tests: use `db_client` fixture (requires PostgreSQL)
Integration tests are automatically skipped when no database is available.
"""
import os
# Disable rate limiting for the test suite. Many tests make dozens of
# requests from the same loopback address in rapid succession and the
# middleware would legitimately reject them as a DoS; the middleware
# has its own dedicated test module.
os.environ.setdefault("RATE_LIMIT_ENABLED", "false")
os.environ.setdefault("ENVIRONMENT", "test")
import uuid
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from src.main import create_app
from src.models.base import Base
# ── Detect whether a test database is available ──────────────────────
_TEST_DB_URL = os.environ.get(
"TEST_DATABASE_URL",
os.environ.get("DATABASE_URL", ""),
)
_HAS_DB = bool(_TEST_DB_URL) and "localhost" in _TEST_DB_URL
def _requires_db(fn):
"""Mark a test as requiring a live database.
Also pins the event loop to session scope so that fixtures sharing the
session-scoped engine don't get 'Future attached to a different loop'.
"""
fn = pytest.mark.asyncio(loop_scope="session")(fn)
fn = pytest.mark.skipif(not _HAS_DB, reason="No test database available")(fn)
return fn
requires_db = _requires_db
# ── Unit test fixtures (no database) ─────────────────────────────────
@pytest.fixture
def app():
"""Create a fresh FastAPI application instance."""
return create_app()
@pytest.fixture
async def client(app):
"""Async HTTP client for unit tests (no database)."""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
# ── Integration test fixtures (with database) ────────────────────────
@pytest.fixture(scope="session")
def _test_engine():
"""Create a test database engine (session-scoped)."""
if not _HAS_DB:
pytest.skip("No test database available")
return create_async_engine(_TEST_DB_URL, echo=False)
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def _setup_db(_test_engine):
"""Create all tables once per test session, then seed fixture data.
Tests that depend on the cookie-category seed (normally applied by
the ``0001_initial_schema`` alembic migration) get the same rows
here so they can run without invoking alembic.
"""
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
await _seed_cookie_categories(conn)
yield
async with _test_engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
async def _seed_cookie_categories(conn) -> None:
"""Insert the default cookie categories. Mirrors migration 0001."""
import uuid as _uuid
from sqlalchemy import text
rows = [
("10000000-0000-0000-0000-000000000001", "Necessary", "necessary", True, 0),
("10000000-0000-0000-0000-000000000002", "Functional", "functional", False, 1),
("10000000-0000-0000-0000-000000000003", "Analytics", "analytics", False, 2),
("10000000-0000-0000-0000-000000000004", "Marketing", "marketing", False, 3),
("10000000-0000-0000-0000-000000000005", "Personalisation", "personalisation", False, 4),
]
stmt = text(
"""
INSERT INTO cookie_categories
(id, name, slug, description, is_essential, display_order)
VALUES (:id, :name, :slug, :description, :is_essential, :display_order)
ON CONFLICT (slug) DO NOTHING
""",
)
for row_id, name, slug, is_essential, order in rows:
await conn.execute(
stmt,
{
"id": _uuid.UUID(row_id),
"name": name,
"slug": slug,
"description": f"{name} cookies",
"is_essential": is_essential,
"display_order": order,
},
)
@pytest_asyncio.fixture(loop_scope="session")
async def db_client(_test_engine, _setup_db):
"""Async HTTP client where each route handler gets its own DB session.
Each request gets an independent session/connection so there are no
'another operation is in progress' errors from asyncpg.
"""
from src.db import get_db
app = create_app()
async def _override_get_db():
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
try:
yield session
await session.commit()
except Exception:
await session.rollback()
raise
app.dependency_overrides[get_db] = _override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
# ── Auth helper fixtures ─────────────────────────────────────────────
@pytest_asyncio.fixture(loop_scope="session")
async def test_org(_test_engine, _setup_db):
"""Create a test organisation in the database."""
from src.models.organisation import Organisation
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
org = Organisation(
id=uuid.uuid4(),
name="Test Organisation",
slug=f"test-org-{uuid.uuid4().hex[:8]}",
)
session.add(org)
await session.commit()
return org
@pytest_asyncio.fixture(loop_scope="session")
async def test_user(_test_engine, _setup_db, test_org):
"""Create a test user (owner role) with a known password."""
from src.models.user import User
from src.services.auth import hash_password
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
user = User(
id=uuid.uuid4(),
email=f"admin-{uuid.uuid4().hex[:8]}@test.com",
password_hash=hash_password("TestPassword123"),
full_name="Test Admin",
role="owner",
organisation_id=test_org.id,
)
session.add(user)
await session.commit()
return user
@pytest_asyncio.fixture(loop_scope="session")
async def auth_token(test_user):
"""Generate a valid JWT token for the test user."""
from src.services.auth import create_access_token
return create_access_token(
user_id=str(test_user.id),
organisation_id=str(test_user.organisation_id),
role=test_user.role,
email=test_user.email,
)
@pytest_asyncio.fixture(loop_scope="session")
async def auth_headers(auth_token):
"""HTTP headers with a valid Bearer token."""
return {"Authorization": f"Bearer {auth_token}"}
# ── Shared helper for creating sites in integration tests ────────────
async def create_test_site(
client: AsyncClient,
headers: dict,
*,
domain_prefix: str = "test",
display_name: str = "Test Site",
) -> str:
"""Create a site via the API and return its ID.
This is a helper function (not a fixture) so it can be called
inline within each test, avoiding async fixture event-loop issues.
"""
resp = await client.post(
"/api/v1/sites/",
json={
"domain": f"{domain_prefix}-{uuid.uuid4().hex[:8]}.com",
"display_name": display_name,
},
headers=headers,
)
assert resp.status_code == 201, f"Failed to create test site: {resp.text}"
return resp.json()["id"]

179
apps/api/tests/test_auth.py Normal file
View File

@@ -0,0 +1,179 @@
"""Tests for JWT authentication service and dependencies."""
import uuid
from datetime import UTC, datetime, timedelta
import pytest
from jose import JWTError, jwt
from src.config.settings import get_settings
from src.schemas.auth import CurrentUser
from src.services.auth import (
create_access_token,
create_refresh_token,
decode_token,
hash_password,
verify_password,
)
class TestPasswordHashing:
def test_hash_and_verify(self):
password = "s3cureP@ss!"
hashed = hash_password(password)
assert hashed != password
assert verify_password(password, hashed)
def test_wrong_password_fails(self):
hashed = hash_password("correct")
assert not verify_password("wrong", hashed)
def test_different_hashes_for_same_password(self):
h1 = hash_password("same")
h2 = hash_password("same")
assert h1 != h2 # bcrypt salts differ
class TestJWTTokens:
@pytest.fixture
def user_data(self):
return {
"user_id": uuid.uuid4(),
"organisation_id": uuid.uuid4(),
"role": "admin",
"email": "test@example.com",
}
def test_create_access_token_decodable(self, user_data):
token = create_access_token(**user_data)
payload = decode_token(token)
assert payload["sub"] == str(user_data["user_id"])
assert payload["org_id"] == str(user_data["organisation_id"])
assert payload["role"] == "admin"
assert payload["email"] == "test@example.com"
assert payload["type"] == "access"
def test_create_refresh_token_decodable(self, user_data):
token = create_refresh_token(
user_id=user_data["user_id"],
organisation_id=user_data["organisation_id"],
)
payload = decode_token(token)
assert payload["sub"] == str(user_data["user_id"])
assert payload["type"] == "refresh"
def test_access_token_expiry(self, user_data):
token = create_access_token(**user_data)
payload = decode_token(token)
settings = get_settings()
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
delta = exp - iat
assert abs(delta.total_seconds() - settings.jwt_access_token_expire_minutes * 60) < 5
def test_refresh_token_expiry(self, user_data):
token = create_refresh_token(
user_id=user_data["user_id"],
organisation_id=user_data["organisation_id"],
)
payload = decode_token(token)
settings = get_settings()
exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
delta = exp - iat
expected = settings.jwt_refresh_token_expire_days * 86400
assert abs(delta.total_seconds() - expected) < 5
def test_expired_token_raises(self):
settings = get_settings()
payload = {
"sub": str(uuid.uuid4()),
"org_id": str(uuid.uuid4()),
"role": "viewer",
"exp": datetime.now(UTC) - timedelta(hours=1),
"iat": datetime.now(UTC) - timedelta(hours=2),
"type": "access",
}
token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
with pytest.raises(JWTError):
decode_token(token)
def test_tampered_token_raises(self, user_data):
token = create_access_token(**user_data)
# Tamper with the token
tampered = token[:-5] + "XXXXX"
with pytest.raises(JWTError):
decode_token(tampered)
class TestCurrentUser:
def test_has_role(self):
user = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="admin@example.com",
role="admin",
)
assert user.has_role("admin", "owner")
assert not user.has_role("editor", "viewer")
def test_is_admin(self):
admin = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="a@b.com",
role="admin",
)
viewer = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="v@b.com",
role="viewer",
)
assert admin.is_admin
assert not viewer.is_admin
@pytest.mark.asyncio
class TestAuthEndpoints:
async def test_me_without_token_returns_401(self, client):
response = await client.get("/api/v1/auth/me")
assert response.status_code == 401
async def test_me_with_valid_token(self, client):
user_id = uuid.uuid4()
org_id = uuid.uuid4()
token = create_access_token(
user_id=user_id,
organisation_id=org_id,
role="editor",
email="user@example.com",
)
response = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 200
data = response.json()
assert data["id"] == str(user_id)
assert data["organisation_id"] == str(org_id)
assert data["role"] == "editor"
assert data["email"] == "user@example.com"
async def test_me_with_refresh_token_rejected(self, client):
token = create_refresh_token(
user_id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
)
response = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"},
)
assert response.status_code == 401
async def test_me_with_invalid_token(self, client):
response = await client.get(
"/api/v1/auth/me",
headers={"Authorization": "Bearer invalid.token.here"},
)
assert response.status_code == 401

View File

@@ -0,0 +1,124 @@
"""Tests for the initial admin bootstrap service."""
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import Settings
from src.models.organisation import Organisation
from src.models.user import User
from src.services.auth import verify_password
from src.services.bootstrap import bootstrap_initial_admin
from tests.conftest import requires_db
def _settings(**overrides) -> Settings:
base: dict = dict(
environment="test",
initial_admin_email=None,
initial_admin_password=None,
initial_admin_full_name="Administrator",
initial_org_name="Default Organisation",
initial_org_slug="default",
)
base.update(overrides)
return Settings(**base)
class TestBootstrapNoOp:
"""Pure unit tests — bootstrap must short-circuit before touching the DB."""
async def test_noop_when_email_unset(self):
settings = _settings(initial_admin_password="pw")
with patch("src.services.bootstrap.async_session_factory") as factory:
await bootstrap_initial_admin(settings)
factory.assert_not_called()
async def test_noop_when_password_unset(self):
settings = _settings(initial_admin_email="admin@example.com")
with patch("src.services.bootstrap.async_session_factory") as factory:
await bootstrap_initial_admin(settings)
factory.assert_not_called()
@requires_db
class TestBootstrapWithDatabase:
"""Integration tests — exercise the real SQL path."""
@pytest_asyncio.fixture(loop_scope="session")
async def clean_db(self, _test_engine, _setup_db):
"""Strip users and orgs so bootstrap sees an empty table."""
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
await session.execute(User.__table__.delete())
await session.execute(Organisation.__table__.delete())
await session.commit()
yield
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
await session.execute(User.__table__.delete())
await session.execute(Organisation.__table__.delete())
await session.commit()
async def test_creates_org_and_owner_when_empty(self, _test_engine, clean_db):
email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
settings = _settings(
initial_admin_email=email,
initial_admin_password="SuperSecret123",
initial_org_slug=slug,
initial_org_name="Bootstrapped Org",
)
def _factory():
return AsyncSession(_test_engine, expire_on_commit=False)
with patch("src.services.bootstrap.async_session_factory", _factory):
await bootstrap_initial_admin(settings)
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
user = (await session.execute(select(User).where(User.email == email))).scalar_one()
org = (
await session.execute(select(Organisation).where(Organisation.slug == slug))
).scalar_one()
assert user.role == "owner"
assert user.organisation_id == org.id
assert user.full_name == "Administrator"
assert verify_password("SuperSecret123", user.password_hash)
assert org.name == "Bootstrapped Org"
assert org.contact_email == email
async def test_idempotent_when_user_exists(self, _test_engine, clean_db):
"""A second invocation must not create a second user."""
email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
settings = _settings(
initial_admin_email=email,
initial_admin_password="SuperSecret123",
initial_org_slug=slug,
)
def _factory():
return AsyncSession(_test_engine, expire_on_commit=False)
with patch("src.services.bootstrap.async_session_factory", _factory):
await bootstrap_initial_admin(settings)
await bootstrap_initial_admin(
_settings(
initial_admin_email="someone-else@example.com",
initial_admin_password="Different123",
initial_org_slug=slug,
)
)
async with AsyncSession(_test_engine, expire_on_commit=False) as session:
users = (await session.execute(select(User))).scalars().all()
assert len(users) == 1
assert users[0].email == email
pytestmark = pytest.mark.asyncio(loop_scope="session")

View File

@@ -0,0 +1,869 @@
"""Tests for known cookies database and auto-categorisation engine — CMP-22.
Covers:
- Classification service logic (unit tests — pure functions)
- Pattern matching (exact, wildcard, regex)
- Priority ordering (allow-list → exact → regex → unmatched)
- Known cookie CRUD endpoints (unit tests with mocked DB)
- Classification endpoints (unit tests with mocked DB)
- Schema validation
- Integration tests against live database
"""
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from src.schemas.cookie import (
ClassificationResultResponse,
ClassifySingleRequest,
ClassifySiteResponse,
KnownCookieCreate,
KnownCookieResponse,
KnownCookieUpdate,
)
from src.services.classification import (
ClassificationResult,
MatchSource,
_match_pattern,
_match_regex,
classify_cookie,
)
# ── Schema tests ─────────────────────────────────────────────────────
class TestSchemas:
"""Validate known cookie and classification schemas."""
def test_known_cookie_create(self):
kc = KnownCookieCreate(
name_pattern="_ga",
domain_pattern="*",
category_id=uuid.uuid4(),
vendor="Google",
description="GA cookie",
)
assert kc.is_regex is False
def test_known_cookie_create_regex(self):
kc = KnownCookieCreate(
name_pattern="_hj.*",
domain_pattern=".*",
category_id=uuid.uuid4(),
is_regex=True,
)
assert kc.is_regex is True
def test_known_cookie_update_partial(self):
ku = KnownCookieUpdate(vendor="Updated Vendor")
dumped = ku.model_dump(exclude_unset=True)
assert "vendor" in dumped
assert "category_id" not in dumped
def test_known_cookie_response(self):
resp = KnownCookieResponse(
id=uuid.uuid4(),
name_pattern="_ga",
domain_pattern="*",
category_id=uuid.uuid4(),
is_regex=False,
created_at=datetime.now(),
updated_at=datetime.now(),
)
assert resp.vendor is None
def test_classification_result_response(self):
crr = ClassificationResultResponse(
cookie_name="_ga",
cookie_domain=".example.com",
match_source="known_exact",
matched=True,
)
assert crr.matched is True
def test_classify_single_request(self):
req = ClassifySingleRequest(cookie_name="_ga", cookie_domain=".example.com")
assert req.cookie_name == "_ga"
def test_classify_single_request_validation(self):
with pytest.raises(ValueError):
ClassifySingleRequest(cookie_name="", cookie_domain=".example.com")
def test_classify_site_response(self):
resp = ClassifySiteResponse(
site_id="abc",
total=10,
matched=7,
unmatched=3,
results=[],
)
assert resp.matched == 7
def test_match_source_enum(self):
assert MatchSource.ALLOW_LIST == "allow_list"
assert MatchSource.KNOWN_EXACT == "known_exact"
assert MatchSource.KNOWN_REGEX == "known_regex"
assert MatchSource.UNMATCHED == "unmatched"
# ── Pattern matching unit tests ──────────────────────────────────────
class TestPatternMatching:
"""Test the _match_pattern and _match_regex helpers."""
def test_exact_match(self):
assert _match_pattern("_ga", "_ga") is True
def test_exact_match_case_insensitive(self):
assert _match_pattern("_GA", "_ga") is True
assert _match_pattern("_ga", "_GA") is True
def test_exact_no_match(self):
assert _match_pattern("_ga", "_gid") is False
def test_wildcard_star(self):
assert _match_pattern("*", "_ga") is True
assert _match_pattern("*", "anything") is True
def test_wildcard_prefix(self):
assert _match_pattern("_ga_*", "_ga_ABC123") is True
assert _match_pattern("_ga_*", "_ga_") is True
assert _match_pattern("_ga_*", "_gid") is False
def test_wildcard_suffix(self):
assert _match_pattern("*.google.com", ".google.com") is True
assert _match_pattern("*.google.com", "www.google.com") is True
assert _match_pattern("*.google.com", ".facebook.com") is False
def test_wildcard_middle(self):
assert _match_pattern("_ga*id", "_ga_gid") is True # * matches _g
assert _match_pattern("_ga*id", "_gaid") is True
assert _match_pattern("_ga*id", "_ga") is False # must end in id
def test_empty_values(self):
assert _match_pattern("", "_ga") is False
assert _match_pattern("_ga", "") is False
assert _match_pattern("", "") is False
def test_regex_match(self):
assert _match_regex(r"_hj.*", "_hjSession_12345") is True
assert _match_regex(r"_hj.*", "_ga") is False
def test_regex_case_insensitive(self):
assert _match_regex(r"_hj.*", "_HJSession") is True
def test_regex_anchored(self):
# re.match anchors at start by default
assert _match_regex(r"_pk_id.*", "_pk_id.abc.123") is True
assert _match_regex(r"_pk_id.*", "x_pk_id") is False
def test_regex_invalid_pattern(self):
assert _match_regex(r"[invalid", "test") is False
def test_regex_full_domain_match(self):
assert _match_regex(r".*", ".example.com") is True
def test_wildcard_dynamic_id_suffix(self):
"""Cookies with dynamic IDs should match wildcard prefix patterns."""
assert _match_pattern("_hjSessionUser_*", "_hjSessionUser_1150536") is True
assert _match_pattern("_hjSession_*", "_hjSession_9876543") is True
assert _match_pattern("ri--*", "ri--zC77O2yRxuIvW5fjRAq0RdzNYaF-x") is True
assert _match_pattern("intercom-id-*", "intercom-id-abc123def") is True
assert _match_pattern("amp_*", "amp_ff29a3") is True
assert _match_pattern("mp_*", "mp_abc123_mixpanel") is True
def test_wildcard_does_not_overmatch(self):
"""Wildcard patterns should not match unrelated cookies."""
assert _match_pattern("_hjSessionUser_*", "_hjSession_123") is False
assert _match_pattern("ri--*", "ri-single-dash") is False
assert _match_pattern("intercom-id-*", "intercom-session-xyz") is False
# ── Classification engine unit tests ─────────────────────────────────
def _make_category(slug: str, cat_id: uuid.UUID | None = None):
"""Create a mock CookieCategory."""
cat = MagicMock()
cat.id = cat_id or uuid.uuid4()
cat.slug = slug
return cat
def _make_known(
name_pattern: str,
domain_pattern: str,
category_id: uuid.UUID,
vendor: str | None = None,
description: str | None = None,
is_regex: bool = False,
):
"""Create a mock KnownCookie."""
known = MagicMock()
known.name_pattern = name_pattern
known.domain_pattern = domain_pattern
known.category_id = category_id
known.vendor = vendor
known.description = description
known.is_regex = is_regex
return known
def _make_allow_entry(
name_pattern: str,
domain_pattern: str,
category_id: uuid.UUID,
description: str | None = None,
):
"""Create a mock CookieAllowListEntry."""
entry = MagicMock()
entry.name_pattern = name_pattern
entry.domain_pattern = domain_pattern
entry.category_id = category_id
entry.description = description
return entry
class TestClassifyCookie:
"""Test the classify_cookie pure function."""
def setup_method(self):
self.analytics_cat = _make_category("analytics")
self.marketing_cat = _make_category("marketing")
self.necessary_cat = _make_category("necessary")
self.category_map = {
self.analytics_cat.id: self.analytics_cat,
self.marketing_cat.id: self.marketing_cat,
self.necessary_cat.id: self.necessary_cat,
}
def test_exact_known_match(self):
known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
result = classify_cookie("_ga", ".example.com", [], [known], [], self.category_map)
assert result.matched is True
assert result.match_source == MatchSource.KNOWN_EXACT
assert result.category_slug == "analytics"
assert result.vendor == "Google"
def test_regex_known_match(self):
known = _make_known(
r"_hj.*",
r".*",
self.analytics_cat.id,
vendor="Hotjar",
is_regex=True,
)
result = classify_cookie(
"_hjSession_123",
".example.com",
[],
[],
[known],
self.category_map,
)
assert result.matched is True
assert result.match_source == MatchSource.KNOWN_REGEX
assert result.vendor == "Hotjar"
def test_allow_list_match(self):
entry = _make_allow_entry(
"_custom_cookie",
"*",
self.necessary_cat.id,
description="Site-specific override",
)
result = classify_cookie(
"_custom_cookie",
".example.com",
[entry],
[],
[],
self.category_map,
)
assert result.matched is True
assert result.match_source == MatchSource.ALLOW_LIST
assert result.category_slug == "necessary"
def test_allow_list_takes_priority_over_known(self):
"""Allow-list should override known cookies database."""
allow_entry = _make_allow_entry(
"_ga",
"*",
self.necessary_cat.id,
description="Overridden to necessary",
)
known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
result = classify_cookie(
"_ga",
".example.com",
[allow_entry],
[known],
[],
self.category_map,
)
assert result.match_source == MatchSource.ALLOW_LIST
assert result.category_slug == "necessary"
def test_exact_takes_priority_over_regex(self):
"""Exact match should be preferred over regex match."""
exact = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
regex = _make_known(
r"_g.*",
r".*",
self.marketing_cat.id,
vendor="Other",
is_regex=True,
)
result = classify_cookie(
"_ga",
".example.com",
[],
[exact],
[regex],
self.category_map,
)
assert result.match_source == MatchSource.KNOWN_EXACT
assert result.category_slug == "analytics"
def test_unmatched(self):
result = classify_cookie(
"obscure_cookie",
".unknown.com",
[],
[],
[],
self.category_map,
)
assert result.matched is False
assert result.match_source == MatchSource.UNMATCHED
assert result.category_id is None
def test_domain_must_match(self):
"""Cookie should not match if domain pattern doesn't match."""
known = _make_known("_ga", "*.google.com", self.analytics_cat.id)
result = classify_cookie(
"_ga",
".example.com",
[],
[known],
[],
self.category_map,
)
assert result.matched is False
def test_name_must_match(self):
"""Cookie should not match if name pattern doesn't match."""
known = _make_known("_gid", "*", self.analytics_cat.id)
result = classify_cookie(
"_ga",
".example.com",
[],
[known],
[],
self.category_map,
)
assert result.matched is False
def test_wildcard_domain_match(self):
known = _make_known(
"fr",
"*.facebook.com",
self.marketing_cat.id,
vendor="Meta",
)
result = classify_cookie(
"fr",
".facebook.com",
[],
[known],
[],
self.category_map,
)
assert result.matched is True
assert result.vendor == "Meta"
def test_classification_result_fields(self):
result = ClassificationResult(
cookie_name="_ga",
cookie_domain=".example.com",
)
assert result.category_id is None
assert result.match_source == MatchSource.UNMATCHED
assert result.matched is False
# ── Router unit tests (mocked service) ──────────────────────────────
def _mock_db():
"""Create a mock async DB session."""
db = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = MagicMock()
mock_result.scalars.return_value.all.return_value = []
db.execute.return_value = mock_result
return db
async def _client(app, db):
"""Create an async test client with mocked DB and auth."""
from src.db import get_db
from src.services.dependencies import get_current_user, require_role
user = MagicMock()
user.organisation_id = uuid.uuid4()
user.role = "owner"
async def _override_get_db():
yield db
app.dependency_overrides[get_db] = _override_get_db
app.dependency_overrides[get_current_user] = lambda: user
def _override_require_role(*_roles):
return lambda: user
app.dependency_overrides[require_role] = _override_require_role
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
class TestKnownCookieRoutes:
"""Test known cookie CRUD endpoints."""
@pytest.mark.asyncio
async def test_list_known_cookies(self, app):
db = _mock_db()
async with await _client(app, db) as client:
resp = await client.get("/api/v1/cookies/known")
assert resp.status_code == 200
assert isinstance(resp.json(), list)
@pytest.mark.asyncio
async def test_create_known_cookie(self, app):
db = _mock_db()
# Mock category validation
cat_result = MagicMock()
cat_result.scalar_one_or_none.return_value = MagicMock()
# Mock the created known cookie
known_mock = MagicMock()
known_mock.id = uuid.uuid4()
known_mock.name_pattern = "_ga"
known_mock.domain_pattern = "*"
known_mock.category_id = uuid.uuid4()
known_mock.vendor = "Google"
known_mock.description = "GA cookie"
known_mock.is_regex = False
known_mock.created_at = datetime.now()
known_mock.updated_at = datetime.now()
call_count = 0
async def mock_execute(stmt):
nonlocal call_count
call_count += 1
if call_count == 1:
# Category validation
return cat_result
return MagicMock()
db.execute = mock_execute
db.flush = AsyncMock()
db.refresh = AsyncMock(side_effect=lambda obj: None)
db.add = MagicMock()
with patch(
"src.routers.cookies.KnownCookie",
return_value=known_mock,
):
async with await _client(app, db) as client:
resp = await client.post(
"/api/v1/cookies/known",
json={
"name_pattern": "_ga",
"domain_pattern": "*",
"category_id": str(uuid.uuid4()),
"vendor": "Google",
},
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_get_known_cookie_not_found(self, app):
db = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
db.execute.return_value = mock_result
async with await _client(app, db) as client:
resp = await client.get(f"/api/v1/cookies/known/{uuid.uuid4()}")
assert resp.status_code == 404
class TestClassificationRoutes:
"""Test classification endpoint responses."""
@pytest.mark.asyncio
async def test_classify_preview(self, app):
db = _mock_db()
mock_result = ClassificationResult(
cookie_name="_ga",
cookie_domain=".example.com",
category_id=uuid.uuid4(),
category_slug="analytics",
vendor="Google",
match_source=MatchSource.KNOWN_EXACT,
matched=True,
)
with patch(
"src.routers.cookies.classify_single_cookie",
return_value=mock_result,
):
async with await _client(app, db) as client:
resp = await client.post(
f"/api/v1/cookies/sites/{uuid.uuid4()}/classify/preview",
json={
"cookie_name": "_ga",
"cookie_domain": ".example.com",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["matched"] is True
assert data["match_source"] == "known_exact"
# ── Integration tests ────────────────────────────────────────────────
try:
from tests.conftest import create_test_site, requires_db
except ImportError:
from conftest import create_test_site, requires_db
@requires_db
class TestClassificationIntegration:
"""Integration tests against a live database."""
async def _get_category_id(self, client: AsyncClient, headers: dict, slug: str) -> str:
"""Get a category ID by slug."""
resp = await client.get("/api/v1/cookies/categories", headers=headers)
assert resp.status_code == 200
for cat in resp.json():
if cat["slug"] == slug:
return cat["id"]
pytest.fail(f"Category '{slug}' not found")
async def _create_known_cookie(
self,
client: AsyncClient,
headers: dict,
name_pattern: str,
domain_pattern: str,
category_slug: str,
*,
vendor: str | None = None,
is_regex: bool = False,
) -> str:
"""Create a known cookie and return its ID."""
cat_id = await self._get_category_id(client, headers, category_slug)
resp = await client.post(
"/api/v1/cookies/known",
headers=headers,
json={
"name_pattern": name_pattern,
"domain_pattern": domain_pattern,
"category_id": cat_id,
"vendor": vendor,
"is_regex": is_regex,
},
)
assert resp.status_code == 201, resp.text
return resp.json()["id"]
async def _create_cookie(
self,
client: AsyncClient,
headers: dict,
site_id: str,
name: str,
domain: str,
) -> str:
"""Create a pending cookie on a site and return its ID."""
resp = await client.post(
f"/api/v1/cookies/sites/{site_id}",
headers=headers,
json={"name": name, "domain": domain},
)
assert resp.status_code == 201, resp.text
return resp.json()["id"]
async def test_known_cookies_crud(self, db_client, auth_headers):
"""Test full CRUD lifecycle for known cookies."""
cat_id = await self._get_category_id(db_client, auth_headers, "analytics")
# Create
resp = await db_client.post(
"/api/v1/cookies/known",
headers=auth_headers,
json={
"name_pattern": f"_test_{uuid.uuid4().hex[:6]}",
"domain_pattern": "*",
"category_id": cat_id,
"vendor": "TestVendor",
"description": "Test cookie",
},
)
assert resp.status_code == 201
known_id = resp.json()["id"]
# Read
resp = await db_client.get(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["vendor"] == "TestVendor"
# Update
resp = await db_client.patch(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
json={"vendor": "UpdatedVendor"},
)
assert resp.status_code == 200
assert resp.json()["vendor"] == "UpdatedVendor"
# List (with search)
resp = await db_client.get(
"/api/v1/cookies/known",
headers=auth_headers,
params={"vendor": "UpdatedVendor"},
)
assert resp.status_code == 200
assert any(k["id"] == known_id for k in resp.json())
# Delete
resp = await db_client.delete(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
)
assert resp.status_code == 204
# Verify deleted
resp = await db_client.get(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_classify_exact_match(self, db_client, auth_headers):
"""Test classification with exact known cookie match."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-exact")
# Create a known cookie pattern
pattern_name = f"_test_exact_{uuid.uuid4().hex[:6]}"
await self._create_known_cookie(
db_client,
auth_headers,
pattern_name,
"*",
"analytics",
vendor="TestVendor",
)
# Create a pending cookie on the site
await self._create_cookie(
db_client,
auth_headers,
site_id,
pattern_name,
".example.com",
)
# Classify
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["total"] >= 1
assert data["matched"] >= 1
matched = [r for r in data["results"] if r["matched"]]
assert any(r["cookie_name"] == pattern_name for r in matched)
async def test_classify_regex_match(self, db_client, auth_headers):
"""Test classification with regex known cookie match."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-regex")
prefix = f"_rx_{uuid.uuid4().hex[:4]}"
# Create regex pattern
await self._create_known_cookie(
db_client,
auth_headers,
f"{prefix}.*",
".*",
"analytics",
vendor="RegexVendor",
is_regex=True,
)
# Create a cookie that should match the regex
await self._create_cookie(
db_client,
auth_headers,
site_id,
f"{prefix}_session_123",
".example.com",
)
# Classify
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["matched"] >= 1
matched = [r for r in data["results"] if r["matched"]]
assert any(r["match_source"] == "known_regex" for r in matched)
async def test_classify_unmatched(self, db_client, auth_headers):
"""Cookies without known patterns should remain unmatched."""
site_id = await create_test_site(
db_client, auth_headers, domain_prefix="classify-unmatched"
)
unique_name = f"_unknown_{uuid.uuid4().hex[:8]}"
await self._create_cookie(
db_client,
auth_headers,
site_id,
unique_name,
".obscure-domain.com",
)
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["unmatched"] >= 1
async def test_classify_preview(self, db_client, auth_headers):
"""Test preview classification without saving."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-preview")
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify/preview",
headers=auth_headers,
json={
"cookie_name": "_unknown_cookie",
"cookie_domain": ".test.com",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["matched"] is False
assert data["match_source"] == "unmatched"
async def test_classify_allow_list_priority(self, db_client, auth_headers):
"""Allow-list entries should take priority over known cookies."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-allow")
cookie_name = f"_priority_{uuid.uuid4().hex[:6]}"
# Add to known cookies as marketing
await self._create_known_cookie(
db_client,
auth_headers,
cookie_name,
"*",
"marketing",
)
# Add to allow-list as necessary (should take priority)
necessary_id = await self._get_category_id(db_client, auth_headers, "necessary")
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/allow-list",
headers=auth_headers,
json={
"name_pattern": cookie_name,
"domain_pattern": "*",
"category_id": necessary_id,
},
)
assert resp.status_code == 201
# Create cookie and classify
await self._create_cookie(
db_client,
auth_headers,
site_id,
cookie_name,
".example.com",
)
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
matched = [r for r in data["results"] if r["cookie_name"] == cookie_name]
assert len(matched) == 1
assert matched[0]["match_source"] == "allow_list"
assert matched[0]["category_id"] == necessary_id
async def test_known_cookies_not_found(self, db_client, auth_headers):
resp = await db_client.get(
f"/api/v1/cookies/known/{uuid.uuid4()}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_known_cookies_invalid_category(self, db_client, auth_headers):
resp = await db_client.post(
"/api/v1/cookies/known",
headers=auth_headers,
json={
"name_pattern": "_test",
"domain_pattern": "*",
"category_id": str(uuid.uuid4()),
},
)
assert resp.status_code == 400
async def test_known_cookies_auth_required(self, db_client):
"""Known cookie endpoints require authentication."""
resp = await db_client.get("/api/v1/cookies/known")
assert resp.status_code == 401
async def test_classify_empty_site(self, db_client, auth_headers):
"""Classifying a site with no cookies should return empty results."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-empty")
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["total"] == 0
assert data["matched"] == 0
async def test_list_known_cookies_search(self, db_client, auth_headers):
"""Test searching known cookies by name pattern."""
unique = uuid.uuid4().hex[:6]
await self._create_known_cookie(
db_client,
auth_headers,
f"_search_{unique}",
"*",
"analytics",
)
resp = await db_client.get(
"/api/v1/cookies/known",
headers=auth_headers,
params={"search": f"_search_{unique}"},
)
assert resp.status_code == 200
results = resp.json()
assert len(results) >= 1
assert all(f"_search_{unique}" in r["name_pattern"] for r in results)

View File

@@ -0,0 +1,597 @@
"""Tests for the compliance rule engine and router."""
import uuid
import pytest
from httpx import ASGITransport, AsyncClient
from src.schemas.compliance import (
ComplianceCheckResponse,
ComplianceIssue,
Framework,
FrameworkResult,
Severity,
)
from src.services.compliance import (
CCPA_RULES,
CNIL_RULES,
EPRIVACY_RULES,
FRAMEWORK_RULES,
GDPR_RULES,
LGPD_RULES,
SiteContext,
calculate_overall_score,
run_compliance_check,
run_framework_check,
)
# ── SiteContext defaults ──────────────────────────────────────────────
class TestSiteContext:
def test_default_values(self):
ctx = SiteContext()
assert ctx.blocking_mode == "opt_in"
assert ctx.tcf_enabled is False
assert ctx.gcm_enabled is True
assert ctx.consent_expiry_days == 365
assert ctx.has_reject_button is True
assert ctx.has_granular_choices is True
assert ctx.has_cookie_wall is False
assert ctx.pre_ticked_boxes is False
def test_custom_values(self):
ctx = SiteContext(
blocking_mode="opt_out",
consent_expiry_days=180,
privacy_policy_url="https://example.com/privacy",
)
assert ctx.blocking_mode == "opt_out"
assert ctx.consent_expiry_days == 180
assert ctx.privacy_policy_url == "https://example.com/privacy"
# ── GDPR rules ────────────────────────────────────────────────────────
class TestGDPRRules:
def test_compliant_site(self):
ctx = SiteContext(
blocking_mode="opt_in",
has_reject_button=True,
has_granular_choices=True,
has_cookie_wall=False,
pre_ticked_boxes=False,
privacy_policy_url="https://example.com/privacy",
uncategorised_cookies=0,
)
result = run_framework_check(Framework.GDPR, ctx)
assert result.score == 100
assert result.status == "compliant"
assert len(result.issues) == 0
assert result.rules_passed == result.rules_checked
def test_opt_out_mode_fails(self):
ctx = SiteContext(
blocking_mode="opt_out",
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
assert any(i.rule_id == "gdpr_opt_in" for i in result.issues)
assert result.status == "non_compliant"
def test_informational_mode_fails(self):
ctx = SiteContext(
blocking_mode="informational",
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
assert any(i.rule_id == "gdpr_opt_in" for i in result.issues)
def test_no_reject_button_fails(self):
ctx = SiteContext(
has_reject_button=False,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
assert any(i.rule_id == "gdpr_reject_button" for i in result.issues)
def test_no_granular_consent_fails(self):
ctx = SiteContext(
has_granular_choices=False,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
assert any(i.rule_id == "gdpr_granular" for i in result.issues)
def test_cookie_wall_fails(self):
ctx = SiteContext(
has_cookie_wall=True,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
assert any(i.rule_id == "gdpr_cookie_wall" for i in result.issues)
def test_pre_ticked_fails(self):
ctx = SiteContext(
pre_ticked_boxes=True,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
assert any(i.rule_id == "gdpr_pre_ticked" for i in result.issues)
def test_no_privacy_policy_warns(self):
ctx = SiteContext(privacy_policy_url=None)
result = run_framework_check(Framework.GDPR, ctx)
policy_issues = [i for i in result.issues if i.rule_id == "gdpr_privacy_policy"]
assert len(policy_issues) == 1
assert policy_issues[0].severity == Severity.WARNING
def test_uncategorised_cookies_warns(self):
ctx = SiteContext(
uncategorised_cookies=5,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
uncat_issues = [i for i in result.issues if i.rule_id == "gdpr_uncategorised"]
assert len(uncat_issues) == 1
assert "5" in uncat_issues[0].message
def test_multiple_failures_accumulate(self):
ctx = SiteContext(
blocking_mode="opt_out",
has_reject_button=False,
has_granular_choices=False,
has_cookie_wall=True,
pre_ticked_boxes=True,
privacy_policy_url=None,
uncategorised_cookies=3,
)
result = run_framework_check(Framework.GDPR, ctx)
assert result.score == 0 # Capped at 0
assert result.status == "non_compliant"
assert len(result.issues) >= 5
# ── CNIL rules ────────────────────────────────────────────────────────
class TestCNILRules:
def test_compliant_site(self):
ctx = SiteContext(
blocking_mode="opt_in",
has_reject_button=True,
has_granular_choices=True,
privacy_policy_url="https://example.com/privacy",
consent_expiry_days=180,
)
result = run_framework_check(Framework.CNIL, ctx)
assert result.score == 100
assert result.status == "compliant"
def test_consent_expiry_too_long(self):
ctx = SiteContext(
consent_expiry_days=365,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.CNIL, ctx)
assert any(i.rule_id == "cnil_reconsent" for i in result.issues)
def test_consent_expiry_at_limit(self):
ctx = SiteContext(
consent_expiry_days=182,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.CNIL, ctx)
assert not any(i.rule_id == "cnil_reconsent" for i in result.issues)
def test_cookie_lifetime_too_long(self):
ctx = SiteContext(
consent_expiry_days=400,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.CNIL, ctx)
assert any(i.rule_id == "cnil_cookie_lifetime" for i in result.issues)
def test_cookie_lifetime_at_limit(self):
ctx = SiteContext(
consent_expiry_days=395,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.CNIL, ctx)
assert not any(i.rule_id == "cnil_cookie_lifetime" for i in result.issues)
def test_inherits_gdpr_rules(self):
"""CNIL should check all GDPR rules plus CNIL-specific ones."""
assert len(CNIL_RULES) > len(GDPR_RULES)
def test_reject_first_layer(self):
ctx = SiteContext(
has_reject_button=False,
consent_expiry_days=180,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.CNIL, ctx)
assert any(i.rule_id == "cnil_reject_first_layer" for i in result.issues)
# ── CCPA rules ────────────────────────────────────────────────────────
class TestCCPARules:
def test_compliant_site(self):
ctx = SiteContext(
blocking_mode="opt_out",
privacy_policy_url="https://example.com/privacy",
banner_config={"show_do_not_sell_link": True},
)
result = run_framework_check(Framework.CCPA, ctx)
assert result.score == 100
assert result.status == "compliant"
def test_opt_in_also_acceptable(self):
ctx = SiteContext(
blocking_mode="opt_in",
privacy_policy_url="https://example.com/privacy",
banner_config={"show_do_not_sell_link": True},
)
result = run_framework_check(Framework.CCPA, ctx)
assert not any(i.rule_id == "ccpa_opt_out" for i in result.issues)
def test_informational_mode_passes_ccpa(self):
"""CCPA opt-out check passes for informational (it's not 'informational')."""
ctx = SiteContext(
blocking_mode="informational",
privacy_policy_url="https://example.com/privacy",
banner_config={"show_do_not_sell_link": True},
)
result = run_framework_check(Framework.CCPA, ctx)
# informational is not in ("opt_out", "opt_in"), so it fails
assert any(i.rule_id == "ccpa_opt_out" for i in result.issues)
def test_no_do_not_sell_link(self):
ctx = SiteContext(
blocking_mode="opt_out",
privacy_policy_url="https://example.com/privacy",
banner_config={},
)
result = run_framework_check(Framework.CCPA, ctx)
assert any(i.rule_id == "ccpa_do_not_sell" for i in result.issues)
def test_no_banner_config_fails_dns(self):
ctx = SiteContext(
blocking_mode="opt_out",
privacy_policy_url="https://example.com/privacy",
banner_config=None,
)
result = run_framework_check(Framework.CCPA, ctx)
assert any(i.rule_id == "ccpa_do_not_sell" for i in result.issues)
def test_no_privacy_policy_warns(self):
ctx = SiteContext(
blocking_mode="opt_out",
privacy_policy_url=None,
banner_config={"show_do_not_sell_link": True},
)
result = run_framework_check(Framework.CCPA, ctx)
assert any(i.rule_id == "ccpa_privacy_policy" for i in result.issues)
# ── ePrivacy rules ────────────────────────────────────────────────────
class TestEPrivacyRules:
def test_compliant_site(self):
ctx = SiteContext(blocking_mode="opt_in")
result = run_framework_check(Framework.EPRIVACY, ctx)
assert result.score == 100
assert result.status == "compliant"
def test_opt_out_passes(self):
ctx = SiteContext(blocking_mode="opt_out")
result = run_framework_check(Framework.EPRIVACY, ctx)
assert not any(i.rule_id == "eprivacy_consent" for i in result.issues)
def test_informational_fails(self):
ctx = SiteContext(blocking_mode="informational")
result = run_framework_check(Framework.EPRIVACY, ctx)
assert any(i.rule_id == "eprivacy_consent" for i in result.issues)
# ── LGPD rules ────────────────────────────────────────────────────────
class TestLGPDRules:
def test_compliant_site(self):
ctx = SiteContext(
blocking_mode="opt_in",
privacy_policy_url="https://example.com/privacy",
has_granular_choices=True,
)
result = run_framework_check(Framework.LGPD, ctx)
assert result.score == 100
assert result.status == "compliant"
def test_informational_fails(self):
ctx = SiteContext(
blocking_mode="informational",
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.LGPD, ctx)
assert any(i.rule_id == "lgpd_consent_basis" for i in result.issues)
def test_no_privacy_policy_warns(self):
ctx = SiteContext(privacy_policy_url=None)
result = run_framework_check(Framework.LGPD, ctx)
assert any(i.rule_id == "lgpd_data_controller" for i in result.issues)
def test_no_granular_warns(self):
ctx = SiteContext(
has_granular_choices=False,
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.LGPD, ctx)
assert any(i.rule_id == "lgpd_granular" for i in result.issues)
def test_opt_out_passes(self):
ctx = SiteContext(
blocking_mode="opt_out",
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.LGPD, ctx)
assert not any(i.rule_id == "lgpd_consent_basis" for i in result.issues)
# ── Engine orchestration ──────────────────────────────────────────────
class TestComplianceEngine:
def test_run_all_frameworks(self):
ctx = SiteContext(
blocking_mode="opt_in",
privacy_policy_url="https://example.com/privacy",
has_reject_button=True,
has_granular_choices=True,
consent_expiry_days=180,
)
results = run_compliance_check(ctx)
assert len(results) == 5
frameworks = {r.framework for r in results}
assert frameworks == {
Framework.GDPR,
Framework.CNIL,
Framework.CCPA,
Framework.EPRIVACY,
Framework.LGPD,
}
def test_run_specific_frameworks(self):
ctx = SiteContext()
results = run_compliance_check(ctx, [Framework.GDPR, Framework.CCPA])
assert len(results) == 2
assert results[0].framework == Framework.GDPR
assert results[1].framework == Framework.CCPA
def test_run_single_framework(self):
ctx = SiteContext()
results = run_compliance_check(ctx, [Framework.EPRIVACY])
assert len(results) == 1
assert results[0].framework == Framework.EPRIVACY
def test_empty_frameworks_list_runs_all(self):
ctx = SiteContext(
privacy_policy_url="https://example.com/privacy",
consent_expiry_days=180,
)
results = run_compliance_check(ctx, None)
assert len(results) == 5
class TestScoring:
def test_perfect_score(self):
result = FrameworkResult(
framework=Framework.GDPR,
score=100,
status="compliant",
rules_checked=7,
rules_passed=7,
)
assert calculate_overall_score([result]) == 100
def test_zero_score(self):
result = FrameworkResult(
framework=Framework.GDPR,
score=0,
status="non_compliant",
rules_checked=7,
rules_passed=0,
)
assert calculate_overall_score([result]) == 0
def test_average_across_frameworks(self):
results = [
FrameworkResult(
framework=Framework.GDPR,
score=100,
status="compliant",
rules_checked=7,
rules_passed=7,
),
FrameworkResult(
framework=Framework.CCPA,
score=50,
status="partial",
rules_checked=3,
rules_passed=1,
),
]
assert calculate_overall_score(results) == 75
def test_empty_results(self):
assert calculate_overall_score([]) == 100
def test_critical_issues_deduct_20(self):
ctx = SiteContext(
blocking_mode="opt_out",
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
# opt_out causes one critical issue (gdpr_opt_in) → -20 points
assert result.score == 80
def test_warning_issues_deduct_5(self):
ctx = SiteContext(
blocking_mode="opt_in",
privacy_policy_url=None,
uncategorised_cookies=0,
)
result = run_framework_check(Framework.GDPR, ctx)
# Missing privacy policy is a warning → -5 points
assert result.score == 95
def test_score_floors_at_zero(self):
ctx = SiteContext(
blocking_mode="opt_out",
has_reject_button=False,
has_granular_choices=False,
has_cookie_wall=True,
pre_ticked_boxes=True,
privacy_policy_url=None,
uncategorised_cookies=10,
)
result = run_framework_check(Framework.GDPR, ctx)
assert result.score == 0
def test_status_non_compliant_with_critical(self):
ctx = SiteContext(blocking_mode="opt_out")
result = run_framework_check(Framework.GDPR, ctx)
assert result.status == "non_compliant"
def test_status_partial_with_warnings_only(self):
ctx = SiteContext(
blocking_mode="opt_in",
privacy_policy_url=None,
)
result = run_framework_check(Framework.GDPR, ctx)
assert result.status == "partial"
def test_status_compliant_with_no_issues(self):
ctx = SiteContext(
blocking_mode="opt_in",
privacy_policy_url="https://example.com/privacy",
)
result = run_framework_check(Framework.GDPR, ctx)
assert result.status == "compliant"
# ── Framework registry ────────────────────────────────────────────────
class TestFrameworkRegistry:
def test_all_frameworks_registered(self):
assert Framework.GDPR in FRAMEWORK_RULES
assert Framework.CNIL in FRAMEWORK_RULES
assert Framework.CCPA in FRAMEWORK_RULES
assert Framework.EPRIVACY in FRAMEWORK_RULES
assert Framework.LGPD in FRAMEWORK_RULES
def test_each_framework_has_rules(self):
for fw, rules in FRAMEWORK_RULES.items():
assert len(rules) > 0, f"{fw} has no rules"
def test_rule_ids_are_unique_per_framework(self):
for fw, rules in FRAMEWORK_RULES.items():
ids = [r.rule_id for r in rules]
assert len(ids) == len(set(ids)), f"Duplicate rule IDs in {fw}"
def test_gdpr_rule_count(self):
assert len(GDPR_RULES) == 7
def test_cnil_includes_gdpr_rules(self):
gdpr_ids = {r.rule_id for r in GDPR_RULES}
cnil_ids = {r.rule_id for r in CNIL_RULES}
assert gdpr_ids.issubset(cnil_ids)
def test_ccpa_rule_count(self):
assert len(CCPA_RULES) == 3
def test_eprivacy_rule_count(self):
assert len(EPRIVACY_RULES) == 2
def test_lgpd_rule_count(self):
assert len(LGPD_RULES) == 3
# ── Router tests ──────────────────────────────────────────────────────
class TestComplianceRouter:
@pytest.fixture
def app(self):
from src.main import create_app
return create_app()
@pytest.fixture
async def client(self, app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
async def test_list_frameworks(self, client):
resp = await client.get("/api/v1/compliance/frameworks")
assert resp.status_code == 200
data = resp.json()
assert len(data) == 5
ids = {fw["id"] for fw in data}
assert ids == {"gdpr", "cnil", "ccpa", "eprivacy", "lgpd"}
async def test_check_requires_auth(self, client):
resp = await client.post(f"/api/v1/compliance/check/{uuid.uuid4()}")
assert resp.status_code == 401
# ── Schema tests ──────────────────────────────────────────────────────
class TestSchemas:
def test_compliance_issue_schema(self):
issue = ComplianceIssue(
rule_id="test_rule",
severity=Severity.CRITICAL,
message="Test message",
recommendation="Test recommendation",
)
assert issue.rule_id == "test_rule"
assert issue.severity == Severity.CRITICAL
def test_framework_result_schema(self):
result = FrameworkResult(
framework=Framework.GDPR,
score=85,
status="partial",
rules_checked=7,
rules_passed=5,
)
assert result.framework == Framework.GDPR
assert result.score == 85
def test_compliance_check_response_schema(self):
response = ComplianceCheckResponse(
site_id="test-id",
results=[],
overall_score=100,
)
assert response.overall_score == 100
def test_severity_values(self):
assert Severity.CRITICAL == "critical"
assert Severity.WARNING == "warning"
assert Severity.INFO == "info"
def test_framework_values(self):
assert Framework.GDPR == "gdpr"
assert Framework.CNIL == "cnil"
assert Framework.CCPA == "ccpa"
assert Framework.EPRIVACY == "eprivacy"
assert Framework.LGPD == "lgpd"

View File

@@ -0,0 +1,258 @@
"""Tests for configuration hierarchy resolver and publisher."""
import uuid
import pytest
from src.services.config_resolver import (
SYSTEM_DEFAULTS,
build_public_config,
resolve_config,
)
class TestResolveConfig:
def test_returns_system_defaults_for_empty_config(self):
result = resolve_config({})
assert result["blocking_mode"] == "opt_in"
assert result["consent_expiry_days"] == 365
assert result["gcm_enabled"] is True
assert result["tcf_enabled"] is False
assert result["gpp_enabled"] is True
assert result["gpp_supported_apis"] == ["usnat"]
assert result["gpc_enabled"] is True
assert result["gpc_jurisdictions"] == [
"US-CA",
"US-CO",
"US-CT",
"US-TX",
"US-MT",
]
assert result["gpc_global_honour"] is False
def test_site_config_overrides_defaults(self):
site_config = {
"blocking_mode": "opt_out",
"consent_expiry_days": 180,
"tcf_enabled": True,
}
result = resolve_config(site_config)
assert result["blocking_mode"] == "opt_out"
assert result["consent_expiry_days"] == 180
assert result["tcf_enabled"] is True
# Non-overridden values stay as defaults
assert result["gcm_enabled"] is True
def test_org_defaults_override_system_defaults(self):
org_defaults = {"consent_expiry_days": 90}
result = resolve_config({}, org_defaults=org_defaults)
assert result["consent_expiry_days"] == 90
def test_site_config_overrides_org_defaults(self):
org_defaults = {"consent_expiry_days": 90}
site_config = {"consent_expiry_days": 30}
result = resolve_config(site_config, org_defaults=org_defaults)
assert result["consent_expiry_days"] == 30
def test_none_values_in_site_config_do_not_override(self):
site_config = {"blocking_mode": None, "consent_expiry_days": 180}
result = resolve_config(site_config)
# None should not override the default
assert result["blocking_mode"] == "opt_in"
assert result["consent_expiry_days"] == 180
def test_regional_override_applied(self):
site_config = {
"blocking_mode": "opt_in",
"regional_modes": {"US-CA": "opt_out", "EU": "opt_in"},
}
result = resolve_config(site_config, region="US-CA")
assert result["blocking_mode"] == "opt_out"
def test_regional_override_falls_back_to_default(self):
site_config = {
"blocking_mode": "opt_in",
"regional_modes": {"EU": "opt_in", "DEFAULT": "opt_out"},
}
result = resolve_config(site_config, region="BR")
assert result["blocking_mode"] == "opt_out"
def test_regional_override_no_match_keeps_site_config(self):
site_config = {
"blocking_mode": "opt_in",
"regional_modes": {"EU": "opt_out"},
}
result = resolve_config(site_config, region="JP")
assert result["blocking_mode"] == "opt_in"
def test_no_region_ignores_regional_modes(self):
site_config = {
"blocking_mode": "opt_in",
"regional_modes": {"US-CA": "opt_out"},
}
result = resolve_config(site_config)
assert result["blocking_mode"] == "opt_in"
def test_gpp_site_config_overrides_defaults(self):
site_config = {
"gpp_enabled": False,
"gpp_supported_apis": ["usca", "usva"],
}
result = resolve_config(site_config)
assert result["gpp_enabled"] is False
assert result["gpp_supported_apis"] == ["usca", "usva"]
def test_gpc_site_config_overrides_defaults(self):
site_config = {
"gpc_enabled": False,
"gpc_global_honour": True,
"gpc_jurisdictions": ["US-CA"],
}
result = resolve_config(site_config)
assert result["gpc_enabled"] is False
assert result["gpc_global_honour"] is True
assert result["gpc_jurisdictions"] == ["US-CA"]
def test_gpp_gpc_org_defaults_override_system(self):
org_defaults = {
"gpp_enabled": False,
"gpc_global_honour": True,
}
result = resolve_config({}, org_defaults=org_defaults)
assert result["gpp_enabled"] is False
assert result["gpc_global_honour"] is True
# Non-overridden GPP/GPC fields stay as system defaults
assert result["gpc_enabled"] is True
def test_gpp_gpc_site_overrides_org(self):
org_defaults = {"gpp_supported_apis": ["usca"]}
site_config = {"gpp_supported_apis": ["usnat", "usco"]}
result = resolve_config(site_config, org_defaults=org_defaults)
assert result["gpp_supported_apis"] == ["usnat", "usco"]
def test_group_defaults_override_org_defaults(self):
org_defaults = {"consent_expiry_days": 90, "tcf_enabled": True}
group_defaults = {"consent_expiry_days": 60}
result = resolve_config(
{},
org_defaults=org_defaults,
group_defaults=group_defaults,
)
assert result["consent_expiry_days"] == 60 # Group overrides org
assert result["tcf_enabled"] is True # Still from org
def test_site_config_overrides_group_defaults(self):
group_defaults = {"consent_expiry_days": 60}
site_config = {"consent_expiry_days": 30}
result = resolve_config(site_config, group_defaults=group_defaults)
assert result["consent_expiry_days"] == 30 # Site overrides group
def test_none_in_group_defaults_does_not_override(self):
org_defaults = {"consent_expiry_days": 90}
group_defaults = {"consent_expiry_days": None}
result = resolve_config(
{},
org_defaults=org_defaults,
group_defaults=group_defaults,
)
assert result["consent_expiry_days"] == 90 # Org value preserved
def test_full_hierarchy(self):
org_defaults = {
"consent_expiry_days": 90,
"tcf_enabled": True,
}
site_config = {
"consent_expiry_days": 60,
"banner_config": {"primaryColour": "#ff0000"},
"regional_modes": {"EU": "opt_in", "US": "opt_out"},
}
result = resolve_config(site_config, org_defaults=org_defaults, region="US")
assert result["consent_expiry_days"] == 60 # Site overrides org
assert result["tcf_enabled"] is True # From org defaults
assert result["blocking_mode"] == "opt_out" # Regional override
assert result["banner_config"] == {"primaryColour": "#ff0000"}
def test_full_hierarchy_with_group(self):
org_defaults = {
"consent_expiry_days": 90,
"tcf_enabled": True,
"blocking_mode": "opt_in",
}
group_defaults = {
"consent_expiry_days": 60,
"privacy_policy_url": "https://group.example.com/privacy",
}
site_config = {
"banner_config": {"primaryColour": "#ff0000"},
"regional_modes": {"US": "opt_out"},
}
result = resolve_config(
site_config,
org_defaults=org_defaults,
group_defaults=group_defaults,
region="US",
)
assert result["consent_expiry_days"] == 60 # From group
assert result["tcf_enabled"] is True # From org
assert result["blocking_mode"] == "opt_out" # Regional override
assert result["privacy_policy_url"] == "https://group.example.com/privacy" # From group
assert result["banner_config"] == {"primaryColour": "#ff0000"} # From site
class TestBuildPublicConfig:
def test_includes_required_fields(self):
site_id = str(uuid.uuid4())
resolved = {**SYSTEM_DEFAULTS, "id": "config-123"}
result = build_public_config(site_id, resolved)
assert result["site_id"] == site_id
assert result["id"] == "config-123"
assert result["blocking_mode"] == "opt_in"
assert result["consent_expiry_days"] == 365
assert "gcm_enabled" in result
assert "tcf_enabled" in result
assert "banner_config" in result
assert result["gpp_enabled"] is True
assert result["gpp_supported_apis"] == ["usnat"]
assert result["gpc_enabled"] is True
assert result["gpc_jurisdictions"] == [
"US-CA",
"US-CO",
"US-CT",
"US-TX",
"US-MT",
]
assert result["gpc_global_honour"] is False
def test_strips_unknown_internal_fields(self):
site_id = str(uuid.uuid4())
resolved = {
**SYSTEM_DEFAULTS,
"id": "",
"internal_field": "should_not_appear",
"scan_enabled": True,
}
result = build_public_config(site_id, resolved)
assert "internal_field" not in result
assert "scan_enabled" not in result
class TestConfigRoutes:
def test_resolved_config_route_registered(self, app):
routes = [r.path for r in app.routes]
assert "/api/v1/config/sites/{site_id}/resolved" in routes
def test_publish_route_registered(self, app):
routes = [r.path for r in app.routes]
assert "/api/v1/config/sites/{site_id}/publish" in routes
def test_inheritance_route_registered(self, app):
routes = [r.path for r in app.routes]
assert "/api/v1/config/sites/{site_id}/inheritance" in routes
@pytest.mark.asyncio
async def test_publish_requires_auth(self, client):
site_id = uuid.uuid4()
resp = await client.post(f"/api/v1/config/sites/{site_id}/publish")
assert resp.status_code == 401

View File

@@ -0,0 +1,130 @@
"""Tests for consent recording API schemas and routes."""
import uuid
import pytest
from pydantic import ValidationError
from src.schemas.consent import (
ConsentAction,
ConsentRecordCreate,
ConsentRecordResponse,
ConsentVerifyResponse,
)
class TestConsentSchemas:
def test_create_accept_all(self):
record = ConsentRecordCreate(
site_id=uuid.uuid4(),
visitor_id="visitor-abc-123",
action=ConsentAction.ACCEPT_ALL,
categories_accepted=["necessary", "analytics", "marketing"],
)
assert record.action == "accept_all"
assert len(record.categories_accepted) == 3
assert record.categories_rejected is None
def test_create_custom(self):
record = ConsentRecordCreate(
site_id=uuid.uuid4(),
visitor_id="visitor-xyz",
action=ConsentAction.CUSTOM,
categories_accepted=["necessary", "functional"],
categories_rejected=["analytics", "marketing"],
tc_string="COwQHgAAAAA",
gcm_state={"analytics_storage": "denied", "ad_storage": "denied"},
page_url="https://example.com/page",
country_code="GB",
region_code="GB-ENG",
)
assert record.action == "custom"
assert record.tc_string == "COwQHgAAAAA"
assert record.gcm_state["analytics_storage"] == "denied"
def test_create_reject_all(self):
record = ConsentRecordCreate(
site_id=uuid.uuid4(),
visitor_id="v-1",
action=ConsentAction.REJECT_ALL,
categories_accepted=["necessary"],
categories_rejected=["analytics", "marketing", "functional"],
)
assert record.action == "reject_all"
def test_empty_visitor_id_rejected(self):
with pytest.raises(ValidationError):
ConsentRecordCreate(
site_id=uuid.uuid4(),
visitor_id="",
action=ConsentAction.ACCEPT_ALL,
categories_accepted=["necessary"],
)
def test_invalid_action_rejected(self):
with pytest.raises(ValidationError):
ConsentRecordCreate(
site_id=uuid.uuid4(),
visitor_id="v-1",
action="invalid_action",
categories_accepted=[],
)
def test_response_from_attributes(self):
resp = ConsentRecordResponse(
id=uuid.uuid4(),
site_id=uuid.uuid4(),
visitor_id="v-1",
action="accept_all",
categories_accepted=["necessary"],
categories_rejected=None,
tc_string=None,
gcm_state=None,
page_url=None,
country_code=None,
region_code=None,
consented_at="2026-01-01T00:00:00Z",
)
assert resp.action == "accept_all"
def test_verify_response(self):
resp = ConsentVerifyResponse(
id=uuid.uuid4(),
site_id=uuid.uuid4(),
visitor_id="v-1",
action="accept_all",
categories_accepted=["necessary"],
consented_at="2026-01-01T00:00:00Z",
)
assert resp.valid is True
class TestConsentActions:
def test_action_values(self):
assert ConsentAction.ACCEPT_ALL == "accept_all"
assert ConsentAction.REJECT_ALL == "reject_all"
assert ConsentAction.CUSTOM == "custom"
assert ConsentAction.WITHDRAW == "withdraw"
@pytest.mark.asyncio
class TestConsentRoutesRegistered:
async def test_consent_routes_exist(self, client):
response = await client.get("/openapi.json")
paths = response.json()["paths"]
assert "/api/v1/consent/" in paths
assert "/api/v1/consent/{consent_id}" in paths
assert "/api/v1/consent/verify/{consent_id}" in paths
async def test_consent_post_validates_body(self, client):
"""POST /consent rejects invalid payloads."""
response = await client.post(
"/api/v1/consent/",
json={"invalid": "body"},
)
assert response.status_code == 422
async def test_config_public_endpoint_exists(self, client):
response = await client.get("/openapi.json")
paths = response.json()["paths"]
assert "/api/v1/config/sites/{site_id}" in paths

View File

@@ -0,0 +1,218 @@
"""Tests for cookie category, cookie, and allow-list schemas and routes."""
import uuid
import pytest
from pydantic import ValidationError
from src.schemas.cookie import (
AllowListEntryCreate,
AllowListEntryUpdate,
CookieCategoryResponse,
CookieCreate,
CookieResponse,
CookieUpdate,
ReviewStatus,
StorageType,
)
# ─── Schema tests ────────────────────────────────────────────────────
class TestStorageType:
def test_values(self):
assert StorageType.cookie == "cookie"
assert StorageType.local_storage == "local_storage"
assert StorageType.session_storage == "session_storage"
assert StorageType.indexed_db == "indexed_db"
class TestReviewStatus:
def test_values(self):
assert ReviewStatus.pending == "pending"
assert ReviewStatus.approved == "approved"
assert ReviewStatus.rejected == "rejected"
class TestCookieCreate:
def test_valid_minimal(self):
schema = CookieCreate(name="_ga", domain=".example.com")
assert schema.name == "_ga"
assert schema.domain == ".example.com"
assert schema.storage_type == StorageType.cookie
assert schema.category_id is None
def test_valid_full(self):
cat_id = uuid.uuid4()
schema = CookieCreate(
name="_ga",
domain=".google.com",
storage_type=StorageType.cookie,
category_id=cat_id,
description="Google Analytics cookie",
vendor="Google",
path="/",
max_age_seconds=63072000,
is_http_only=False,
is_secure=True,
same_site="Lax",
)
assert schema.category_id == cat_id
assert schema.max_age_seconds == 63072000
def test_rejects_empty_name(self):
with pytest.raises(ValidationError):
CookieCreate(name="", domain=".example.com")
def test_rejects_empty_domain(self):
with pytest.raises(ValidationError):
CookieCreate(name="_ga", domain="")
class TestCookieUpdate:
def test_partial_update(self):
schema = CookieUpdate(review_status=ReviewStatus.approved)
dump = schema.model_dump(exclude_unset=True)
assert dump == {"review_status": ReviewStatus.approved}
def test_update_category(self):
cat_id = uuid.uuid4()
schema = CookieUpdate(category_id=cat_id)
assert schema.category_id == cat_id
class TestAllowListEntryCreate:
def test_valid(self):
cat_id = uuid.uuid4()
schema = AllowListEntryCreate(
name_pattern="_ga*",
domain_pattern=".google.com",
category_id=cat_id,
description="Google Analytics cookies",
)
assert schema.name_pattern == "_ga*"
assert schema.category_id == cat_id
def test_rejects_empty_name_pattern(self):
with pytest.raises(ValidationError):
AllowListEntryCreate(
name_pattern="",
domain_pattern=".example.com",
category_id=uuid.uuid4(),
)
class TestAllowListEntryUpdate:
def test_partial_update(self):
schema = AllowListEntryUpdate(description="Updated description")
dump = schema.model_dump(exclude_unset=True)
assert dump == {"description": "Updated description"}
class TestCookieCategoryResponse:
def test_from_dict(self):
now = "2024-01-01T00:00:00"
resp = CookieCategoryResponse(
id=uuid.uuid4(),
name="Analytics",
slug="analytics",
description="Analytics cookies",
is_essential=False,
display_order=2,
tcf_purpose_ids=[1, 3],
gcm_consent_types=["analytics_storage"],
created_at=now,
updated_at=now,
)
assert resp.slug == "analytics"
assert resp.is_essential is False
class TestCookieResponse:
def test_from_dict(self):
now = "2024-01-01T00:00:00"
resp = CookieResponse(
id=uuid.uuid4(),
site_id=uuid.uuid4(),
name="_ga",
domain=".google.com",
storage_type="cookie",
review_status="pending",
created_at=now,
updated_at=now,
)
assert resp.name == "_ga"
assert resp.review_status == "pending"
# ─── Route tests ─────────────────────────────────────────────────────
class TestCookieCategoryRoutes:
def test_categories_route_registered(self, app):
"""Verify the categories routes are registered in the app."""
routes = [r.path for r in app.routes]
assert "/api/v1/cookies/categories" in routes
assert "/api/v1/cookies/categories/{category_id}" in routes
class TestCookieRoutes:
@pytest.mark.asyncio
async def test_list_cookies_requires_auth(self, client):
site_id = uuid.uuid4()
resp = await client.get(f"/api/v1/cookies/sites/{site_id}")
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_create_cookie_requires_auth(self, client):
site_id = uuid.uuid4()
resp = await client.post(
f"/api/v1/cookies/sites/{site_id}",
json={"name": "_ga", "domain": ".google.com"},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_create_cookie_rejects_invalid_body(self, client):
site_id = uuid.uuid4()
resp = await client.post(
f"/api/v1/cookies/sites/{site_id}",
json={"name": "", "domain": ""},
headers={"Authorization": "Bearer fake-token"},
)
# Should return 401 (bad token) or 422 (validation)
assert resp.status_code in (401, 422)
@pytest.mark.asyncio
async def test_summary_route_requires_auth(self, client):
site_id = uuid.uuid4()
resp = await client.get(f"/api/v1/cookies/sites/{site_id}/summary")
assert resp.status_code == 401
class TestAllowListRoutes:
@pytest.mark.asyncio
async def test_list_allow_list_requires_auth(self, client):
site_id = uuid.uuid4()
resp = await client.get(f"/api/v1/cookies/sites/{site_id}/allow-list")
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_create_allow_list_requires_auth(self, client):
site_id = uuid.uuid4()
resp = await client.post(
f"/api/v1/cookies/sites/{site_id}/allow-list",
json={
"name_pattern": "_ga*",
"domain_pattern": ".google.com",
"category_id": str(uuid.uuid4()),
},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_delete_allow_list_requires_auth(self, client):
site_id = uuid.uuid4()
entry_id = uuid.uuid4()
resp = await client.delete(f"/api/v1/cookies/sites/{site_id}/allow-list/{entry_id}")
assert resp.status_code == 401

222
apps/api/tests/test_cors.py Normal file
View File

@@ -0,0 +1,222 @@
"""Tests for the dynamic CORS origin validation service."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from src.services.cors import extract_domain_from_origin, get_allowed_domains, is_origin_allowed
class TestExtractDomainFromOrigin:
def test_https_origin(self):
assert extract_domain_from_origin("https://example.com") == "example.com"
def test_http_origin(self):
assert extract_domain_from_origin("http://example.com") == "example.com"
def test_origin_with_port(self):
assert extract_domain_from_origin("https://example.com:443") == "example.com"
def test_origin_with_subdomain(self):
assert extract_domain_from_origin("https://www.example.com") == "www.example.com"
def test_localhost(self):
assert extract_domain_from_origin("http://localhost:5173") == "localhost"
def test_empty_string(self):
assert extract_domain_from_origin("") is None
def test_invalid_url(self):
# urlparse is lenient, but hostname may be None for really bad input
result = extract_domain_from_origin("not-a-url")
# urlparse("not-a-url") sets hostname to None
assert result is None
class TestIsOriginAllowed:
def test_static_origin_exact_match(self):
assert (
is_origin_allowed(
"http://localhost:5173",
["http://localhost:5173"],
set(),
)
is True
)
def test_static_origin_no_match(self):
assert (
is_origin_allowed(
"https://evil.com",
["http://localhost:5173"],
set(),
)
is False
)
def test_wildcard_allows_everything(self):
assert (
is_origin_allowed(
"https://anything.com",
["*"],
set(),
)
is True
)
def test_registered_domain_match(self):
assert (
is_origin_allowed(
"https://example.com",
[],
{"example.com", "other.com"},
)
is True
)
def test_registered_domain_case_insensitive(self):
assert (
is_origin_allowed(
"https://Example.COM",
[],
{"example.com"},
)
is True
)
def test_registered_domain_no_match(self):
assert (
is_origin_allowed(
"https://evil.com",
[],
{"example.com"},
)
is False
)
def test_static_takes_priority(self):
assert (
is_origin_allowed(
"http://localhost:5173",
["http://localhost:5173"],
{"example.com"},
)
is True
)
def test_origin_with_port_matches_domain(self):
assert (
is_origin_allowed(
"https://example.com:8443",
[],
{"example.com"},
)
is True
)
def test_subdomain_matches_if_registered(self):
# www.example.com only matches if explicitly registered
assert (
is_origin_allowed(
"https://www.example.com",
[],
{"example.com"},
)
is False
)
def test_subdomain_matches_when_registered(self):
assert (
is_origin_allowed(
"https://www.example.com",
[],
{"www.example.com"},
)
is True
)
def test_empty_origin(self):
assert (
is_origin_allowed(
"",
[],
{"example.com"},
)
is False
)
def test_empty_lists(self):
assert (
is_origin_allowed(
"https://example.com",
[],
set(),
)
is False
)
class TestGetAllowedDomains:
@pytest.mark.asyncio
async def test_returns_primary_domains(self):
row1 = MagicMock()
row1.domain = "example.com"
row1.additional_domains = None
row2 = MagicMock()
row2.domain = "other.com"
row2.additional_domains = None
mock_result = MagicMock()
mock_result.all.return_value = [row1, row2]
db = AsyncMock()
db.execute = AsyncMock(return_value=mock_result)
domains = await get_allowed_domains(db)
assert "example.com" in domains
assert "other.com" in domains
@pytest.mark.asyncio
async def test_includes_additional_domains(self):
row = MagicMock()
row.domain = "example.com"
row.additional_domains = ["www.example.com", "app.example.com"]
mock_result = MagicMock()
mock_result.all.return_value = [row]
db = AsyncMock()
db.execute = AsyncMock(return_value=mock_result)
domains = await get_allowed_domains(db)
assert "example.com" in domains
assert "www.example.com" in domains
assert "app.example.com" in domains
@pytest.mark.asyncio
async def test_lowercases_domains(self):
row = MagicMock()
row.domain = "Example.COM"
row.additional_domains = ["WWW.Example.COM"]
mock_result = MagicMock()
mock_result.all.return_value = [row]
db = AsyncMock()
db.execute = AsyncMock(return_value=mock_result)
domains = await get_allowed_domains(db)
assert "example.com" in domains
assert "www.example.com" in domains
@pytest.mark.asyncio
async def test_empty_result(self):
mock_result = MagicMock()
mock_result.all.return_value = []
db = AsyncMock()
db.execute = AsyncMock(return_value=mock_result)
domains = await get_allowed_domains(db)
assert domains == set()

View File

@@ -0,0 +1,88 @@
"""Unit tests for auth dependencies."""
import uuid
from src.schemas.auth import CurrentUser
from src.services.auth import create_access_token, create_refresh_token, decode_token
class TestCurrentUser:
def test_has_role_matching(self):
user = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="test@test.com",
role="admin",
)
assert user.has_role("admin", "owner") is True
def test_has_role_not_matching(self):
user = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="test@test.com",
role="viewer",
)
assert user.has_role("admin", "owner") is False
def test_is_admin_property(self):
user = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="test@test.com",
role="admin",
)
assert user.is_admin is True
def test_is_admin_owner(self):
user = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="test@test.com",
role="owner",
)
assert user.is_admin is True
def test_is_admin_viewer(self):
user = CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="test@test.com",
role="viewer",
)
assert user.is_admin is False
class TestTokenCreation:
def test_access_token_roundtrip(self):
user_id = uuid.uuid4()
org_id = uuid.uuid4()
token = create_access_token(
user_id=user_id,
organisation_id=org_id,
role="editor",
email="test@test.com",
)
payload = decode_token(token)
assert payload["sub"] == str(user_id)
assert payload["org_id"] == str(org_id)
assert payload["role"] == "editor"
assert payload["type"] == "access"
def test_refresh_token_roundtrip(self):
user_id = uuid.uuid4()
org_id = uuid.uuid4()
token = create_refresh_token(user_id=user_id, organisation_id=org_id)
payload = decode_token(token)
assert payload["sub"] == str(user_id)
assert payload["type"] == "refresh"
def test_access_token_is_not_refresh(self):
token = create_access_token(
user_id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
role="viewer",
email="test@test.com",
)
payload = decode_token(token)
assert payload["type"] != "refresh"

View File

@@ -0,0 +1,141 @@
"""Tests for the extension registry and edition detection."""
import pytest
from fastapi import APIRouter, FastAPI
from src.config.edition import edition_name, is_ee
from src.extensions.registry import (
ExtensionRegistry,
OpenAPITag,
discover_extensions,
get_registry,
)
# -- Edition detection -------------------------------------------------------
class TestEditionDetection:
"""The ``is_ee()`` / ``edition_name()`` helpers should return a
consistent pair regardless of which edition is installed. Core tests
don't assume a specific edition — that's checked in each repo's
own integration tests."""
def test_edition_name_matches_is_ee(self):
assert edition_name() == ("ee" if is_ee() else "ce")
def test_edition_name_is_valid(self):
assert edition_name() in ("ce", "ee")
# -- Extension registry (unit) ----------------------------------------------
class TestExtensionRegistry:
def _make_registry(self) -> ExtensionRegistry:
return ExtensionRegistry()
def test_empty_registry(self):
reg = self._make_registry()
assert reg.routers == []
assert reg.model_modules == []
assert reg.startup_hooks == []
def test_add_router(self):
reg = self._make_registry()
router = APIRouter()
reg.add_router(router, prefix="/api/v1")
assert len(reg.routers) == 1
assert reg.routers[0].router is router
assert reg.routers[0].prefix == "/api/v1"
def test_add_router_with_tags(self):
reg = self._make_registry()
router = APIRouter()
tag = OpenAPITag(name="billing", description="Billing endpoints")
reg.add_router(router, tags=[tag])
assert reg.routers[0].tags == [tag]
def test_add_model_module(self):
reg = self._make_registry()
reg.add_model_module("ee.api.src.models.billing")
assert reg.model_modules == ["ee.api.src.models.billing"]
def test_add_startup_hook(self):
reg = self._make_registry()
async def hook(app: FastAPI) -> None:
pass
reg.add_startup_hook(hook)
assert len(reg.startup_hooks) == 1
def test_apply_mounts_routers(self):
reg = self._make_registry()
router = APIRouter()
@router.get("/test")
async def _test() -> dict[str, str]:
return {"ok": True}
reg.add_router(router, prefix="/ext")
app = FastAPI()
reg.apply(app)
# The router should be included in the app routes
paths = [r.path for r in app.routes]
assert "/ext/test" in paths
def test_apply_adds_openapi_tags(self):
reg = self._make_registry()
router = APIRouter()
tag = OpenAPITag(name="billing", description="Billing endpoints")
reg.add_router(router, tags=[tag])
app = FastAPI()
app.openapi_tags = []
reg.apply(app)
assert any(t["name"] == "billing" for t in app.openapi_tags)
def test_apply_skips_duplicate_tags(self):
reg = self._make_registry()
router = APIRouter()
tag = OpenAPITag(name="billing", description="Billing endpoints")
reg.add_router(router, tags=[tag])
app = FastAPI()
app.openapi_tags = [{"name": "billing", "description": "Existing"}]
reg.apply(app)
billing_tags = [t for t in app.openapi_tags if t["name"] == "billing"]
assert len(billing_tags) == 1
assert billing_tags[0]["description"] == "Existing"
# -- discover_extensions -----------------------------------------------------
class TestDiscoverExtensions:
def test_discover_extensions_does_not_raise(self):
"""discover_extensions should not raise regardless of edition."""
discover_extensions()
# -- Global registry ---------------------------------------------------------
class TestGlobalRegistry:
def test_get_registry_returns_singleton(self):
assert get_registry() is get_registry()
# -- Health endpoint with edition field --------------------------------------
@pytest.mark.asyncio
async def test_health_reports_edition(client):
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["edition"] in ("ce", "ee")

View File

@@ -0,0 +1,573 @@
"""Tests for the GeoIP service.
Covers header-based detection, IP lookup, country-to-region mapping,
client IP extraction, and the combined detect_region flow.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
import src.services.geoip as geoip_module
from src.services.geoip import (
GeoResult,
_is_private_ip,
country_to_region,
detect_region,
detect_region_from_headers,
get_client_ip,
lookup_ip_maxmind,
lookup_ip_region,
)
# ── country_to_region ────────────────────────────────────────────────
class TestCountryToRegion:
def test_eu_country_returns_eu(self):
assert country_to_region("DE") == "EU"
assert country_to_region("FR") == "EU"
assert country_to_region("IT") == "EU"
assert country_to_region("ES") == "EU"
def test_eu_country_case_insensitive(self):
assert country_to_region("de") == "EU"
assert country_to_region("fr") == "EU"
def test_gb_returns_gb(self):
assert country_to_region("GB") == "GB"
def test_br_returns_br(self):
assert country_to_region("BR") == "BR"
def test_us_without_state(self):
assert country_to_region("US") == "US"
def test_us_with_state(self):
assert country_to_region("US", "CA") == "US-CA"
assert country_to_region("US", "ny") == "US-NY"
def test_non_eu_country_returned_as_is(self):
assert country_to_region("JP") == "JP"
assert country_to_region("AU") == "AU"
assert country_to_region("CA") == "CA"
# ── detect_region_from_headers ───────────────────────────────────────
class TestDetectRegionFromHeaders:
def _make_request(self, headers: dict[str, str]) -> MagicMock:
request = MagicMock()
request.headers = headers
return request
def test_cloudflare_header(self):
request = self._make_request({"cf-ipcountry": "DE"})
result = detect_region_from_headers(request)
assert result.country_code == "DE"
assert result.region == "EU"
assert result.is_resolved is True
def test_vercel_header(self):
request = self._make_request({"x-vercel-ip-country": "GB"})
result = detect_region_from_headers(request)
assert result.country_code == "GB"
assert result.region == "GB"
def test_appengine_header(self):
request = self._make_request({"x-appengine-country": "BR"})
result = detect_region_from_headers(request)
assert result.country_code == "BR"
assert result.region == "BR"
def test_custom_header(self):
request = self._make_request({"x-country-code": "JP"})
result = detect_region_from_headers(request)
assert result.country_code == "JP"
assert result.region == "JP"
def test_no_geo_headers(self):
request = self._make_request({})
result = detect_region_from_headers(request)
assert result.country_code is None
assert result.region is None
assert result.is_resolved is False
def test_ignores_xx_value(self):
request = self._make_request({"cf-ipcountry": "XX"})
result = detect_region_from_headers(request)
assert result.is_resolved is False
def test_header_priority_cloudflare_first(self):
request = self._make_request(
{
"cf-ipcountry": "FR",
"x-vercel-ip-country": "DE",
}
)
result = detect_region_from_headers(request)
assert result.country_code == "FR"
def test_case_normalisation(self):
request = self._make_request({"cf-ipcountry": "gb"})
result = detect_region_from_headers(request)
assert result.country_code == "GB"
assert result.region == "GB"
def test_configured_custom_header(self):
"""An operator-configured header is honoured."""
request = self._make_request({"x-gclb-country": "JP"})
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_country_header = "x-gclb-country"
result = detect_region_from_headers(request)
assert result.country_code == "JP"
assert result.region == "JP"
def test_configured_custom_header_takes_priority(self):
"""When both a custom and a built-in header are present, the
custom one wins — that's the operator's explicit choice."""
request = self._make_request(
{
"cf-ipcountry": "FR",
"x-gclb-country": "JP",
}
)
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_country_header = "x-gclb-country"
result = detect_region_from_headers(request)
assert result.country_code == "JP"
def test_configured_header_falls_through_to_builtin(self):
"""If the custom header isn't present, the built-in list still
applies."""
request = self._make_request({"cf-ipcountry": "FR"})
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_country_header = "x-gclb-country"
mock_settings.return_value.geoip_region_header = None
result = detect_region_from_headers(request)
assert result.country_code == "FR"
assert result.region == "EU"
def test_configured_region_header_pairs_with_country(self):
"""A configured region header is paired with the custom country."""
request = self._make_request(
{
"x-gclb-country": "US",
"x-gclb-region": "CA",
}
)
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_country_header = "x-gclb-country"
mock_settings.return_value.geoip_region_header = "x-gclb-region"
result = detect_region_from_headers(request)
assert result.country_code == "US"
assert result.region == "US-CA"
def test_configured_region_header_strips_country_prefix(self):
"""ISO 3166-2 subdivisions may arrive prefixed (``US-CA``)."""
request = self._make_request(
{
"x-gclb-country": "US",
"x-gclb-region": "US-NY",
}
)
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_country_header = "x-gclb-country"
mock_settings.return_value.geoip_region_header = "x-gclb-region"
result = detect_region_from_headers(request)
assert result.region == "US-NY"
def test_configured_region_header_missing_is_country_only(self):
"""Only country hits region-aware path if the region header is absent."""
request = self._make_request({"x-gclb-country": "US"})
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_country_header = "x-gclb-country"
mock_settings.return_value.geoip_region_header = "x-gclb-region"
result = detect_region_from_headers(request)
assert result.country_code == "US"
assert result.region == "US"
def test_configured_region_header_xx_ignored(self):
"""Region value of ``XX`` is treated as unknown."""
request = self._make_request(
{
"x-gclb-country": "US",
"x-gclb-region": "XX",
}
)
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_country_header = "x-gclb-country"
mock_settings.return_value.geoip_region_header = "x-gclb-region"
result = detect_region_from_headers(request)
assert result.region == "US"
# ── get_client_ip ────────────────────────────────────────────────────
class TestGetClientIp:
def _make_request(
self,
headers: dict[str, str] | None = None,
client_host: str | None = None,
) -> MagicMock:
request = MagicMock()
request.headers = headers or {}
if client_host:
request.client = MagicMock()
request.client.host = client_host
else:
request.client = None
return request
def test_x_forwarded_for_single(self):
request = self._make_request({"x-forwarded-for": "1.2.3.4"})
assert get_client_ip(request) == "1.2.3.4"
def test_x_forwarded_for_multiple(self):
request = self._make_request({"x-forwarded-for": "1.2.3.4, 5.6.7.8, 9.10.11.12"})
assert get_client_ip(request) == "1.2.3.4"
def test_x_real_ip(self):
request = self._make_request({"x-real-ip": "1.2.3.4"})
assert get_client_ip(request) == "1.2.3.4"
def test_forwarded_for_takes_priority_over_real_ip(self):
request = self._make_request(
{
"x-forwarded-for": "1.1.1.1",
"x-real-ip": "2.2.2.2",
}
)
assert get_client_ip(request) == "1.1.1.1"
def test_falls_back_to_client_host(self):
request = self._make_request(client_host="10.0.0.1")
assert get_client_ip(request) == "10.0.0.1"
def test_returns_none_when_no_ip(self):
request = self._make_request()
assert get_client_ip(request) is None
# ── _is_private_ip ───────────────────────────────────────────────────
class TestIsPrivateIp:
def test_loopback(self):
assert _is_private_ip("127.0.0.1") is True
assert _is_private_ip("127.0.0.2") is True
def test_private_ranges(self):
assert _is_private_ip("10.0.0.1") is True
assert _is_private_ip("192.168.1.1") is True
assert _is_private_ip("172.16.0.1") is True
def test_ipv6_loopback(self):
assert _is_private_ip("::1") is True
def test_localhost_string(self):
assert _is_private_ip("localhost") is True
def test_public_ip(self):
assert _is_private_ip("8.8.8.8") is False
assert _is_private_ip("1.1.1.1") is False
# ── lookup_ip_region ─────────────────────────────────────────────────
class TestLookupIpRegion:
@pytest.mark.asyncio
async def test_private_ip_returns_unresolved(self):
result = await lookup_ip_region("127.0.0.1")
assert result.is_resolved is False
@pytest.mark.asyncio
async def test_private_ip_10_range(self):
result = await lookup_ip_region("10.0.0.1")
assert result.is_resolved is False
@pytest.mark.asyncio
async def test_successful_lookup(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"status": "success",
"countryCode": "DE",
"region": "BY",
}
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
result = await lookup_ip_region("8.8.8.8")
assert result.country_code == "DE"
assert result.region == "EU"
@pytest.mark.asyncio
async def test_failed_status(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"status": "fail", "message": "invalid query"}
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
result = await lookup_ip_region("8.8.8.8")
assert result.is_resolved is False
@pytest.mark.asyncio
async def test_http_error(self):
mock_response = MagicMock()
mock_response.status_code = 500
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
result = await lookup_ip_region("8.8.8.8")
assert result.is_resolved is False
@pytest.mark.asyncio
async def test_network_exception(self):
mock_client = AsyncMock()
mock_client.get = AsyncMock(side_effect=httpx.ConnectError("Connection refused"))
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
result = await lookup_ip_region("8.8.8.8")
assert result.is_resolved is False
@pytest.mark.asyncio
async def test_us_with_state_lookup(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"status": "success",
"countryCode": "US",
"region": "CA",
}
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
result = await lookup_ip_region("8.8.8.8")
assert result.country_code == "US"
assert result.region == "US-CA"
@pytest.mark.asyncio
async def test_missing_country_code(self):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"status": "success"}
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
result = await lookup_ip_region("8.8.8.8")
assert result.is_resolved is False
# ── detect_region (combined) ─────────────────────────────────────────
class TestDetectRegion:
@pytest.mark.asyncio
async def test_uses_headers_when_available(self):
request = MagicMock()
request.headers = {"cf-ipcountry": "FR"}
request.client = None
result = await detect_region(request)
assert result.country_code == "FR"
assert result.region == "EU"
@pytest.mark.asyncio
async def test_falls_back_to_ip_lookup(self):
request = MagicMock()
request.headers = {}
request.client = MagicMock()
request.client.host = "8.8.8.8"
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"status": "success",
"countryCode": "US",
"region": "CA",
}
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=mock_response)
mock_client.__aenter__ = AsyncMock(return_value=mock_client)
mock_client.__aexit__ = AsyncMock(return_value=False)
with patch("src.services.geoip.httpx.AsyncClient", return_value=mock_client):
result = await detect_region(request)
assert result.country_code == "US"
assert result.region == "US-CA"
@pytest.mark.asyncio
async def test_returns_unresolved_when_no_ip(self):
request = MagicMock()
request.headers = {}
request.client = None
result = await detect_region(request)
assert result.is_resolved is False
# ── GeoResult ────────────────────────────────────────────────────────
class TestGeoResult:
def test_is_resolved_true(self):
result = GeoResult(country_code="GB", region="GB")
assert result.is_resolved is True
def test_is_resolved_false(self):
result = GeoResult(country_code=None, region=None)
assert result.is_resolved is False
def test_frozen_dataclass(self):
result = GeoResult(country_code="GB", region="GB")
with pytest.raises(AttributeError):
result.country_code = "US" # type: ignore[misc]
# ── MaxMind database lookup ──────────────────────────────────────────
class TestLookupIpMaxmind:
def setup_method(self):
# Reset the module-level cache so each test starts clean.
geoip_module._maxmind_reader = None
geoip_module._maxmind_initialised = False
def _mock_reader(self, country_iso: str | None, subdivision_iso: str | None):
reader = MagicMock()
response = MagicMock()
response.country.iso_code = country_iso
if subdivision_iso is None:
response.subdivisions = None
else:
response.subdivisions.most_specific.iso_code = subdivision_iso
reader.city.return_value = response
return reader
def test_private_ip_returns_unresolved(self):
result = lookup_ip_maxmind("10.0.0.1")
assert result.is_resolved is False
def test_no_db_configured_returns_unresolved(self):
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_maxmind_db_path = None
result = lookup_ip_maxmind("8.8.8.8")
assert result.is_resolved is False
def test_successful_lookup_with_subdivision(self):
reader = self._mock_reader("US", "CA")
geoip_module._maxmind_reader = reader
geoip_module._maxmind_initialised = True
result = lookup_ip_maxmind("8.8.8.8")
assert result.country_code == "US"
assert result.region == "US-CA"
reader.city.assert_called_once_with("8.8.8.8")
def test_successful_lookup_without_subdivision(self):
reader = self._mock_reader("DE", None)
geoip_module._maxmind_reader = reader
geoip_module._maxmind_initialised = True
result = lookup_ip_maxmind("8.8.8.8")
assert result.country_code == "DE"
assert result.region == "EU"
def test_reader_raises_returns_unresolved(self):
reader = MagicMock()
reader.city.side_effect = RuntimeError("corrupt db")
geoip_module._maxmind_reader = reader
geoip_module._maxmind_initialised = True
result = lookup_ip_maxmind("8.8.8.8")
assert result.is_resolved is False
def test_reader_missing_country_returns_unresolved(self):
reader = self._mock_reader(None, None)
geoip_module._maxmind_reader = reader
geoip_module._maxmind_initialised = True
result = lookup_ip_maxmind("8.8.8.8")
assert result.is_resolved is False
def test_bad_db_path_is_cached_as_failure(self):
"""A missing ``.mmdb`` file should not reopen on every request."""
with patch("src.services.geoip.get_settings") as mock_settings:
mock_settings.return_value.geoip_maxmind_db_path = "/nonexistent/geo.mmdb"
r1 = lookup_ip_maxmind("8.8.8.8")
r2 = lookup_ip_maxmind("1.1.1.1")
assert r1.is_resolved is False
assert r2.is_resolved is False
assert geoip_module._maxmind_initialised is True
assert geoip_module._maxmind_reader is None
class TestDetectRegionMaxmind:
def setup_method(self):
geoip_module._maxmind_reader = None
geoip_module._maxmind_initialised = False
@pytest.mark.asyncio
async def test_uses_maxmind_before_external_api(self):
"""With MaxMind configured, ip-api.com must not be called."""
reader = MagicMock()
response = MagicMock()
response.country.iso_code = "GB"
response.subdivisions.most_specific.iso_code = "SCT"
reader.city.return_value = response
geoip_module._maxmind_reader = reader
geoip_module._maxmind_initialised = True
request = MagicMock()
request.headers = {"x-forwarded-for": "8.8.8.8"}
request.client = None
with (
patch("src.services.geoip.get_settings") as mock_settings,
patch("src.services.geoip.httpx.AsyncClient") as mock_httpx,
):
mock_settings.return_value.geoip_country_header = None
mock_settings.return_value.geoip_region_header = None
mock_settings.return_value.geoip_maxmind_db_path = "/data/GeoLite2-City.mmdb"
result = await detect_region(request)
assert result.country_code == "GB"
assert result.region == "GB-SCT"
mock_httpx.assert_not_called()

View File

@@ -0,0 +1,31 @@
import pytest
@pytest.mark.asyncio
async def test_health_endpoint(client):
response = await client.get("/health")
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert data["edition"] in ("ce", "ee")
@pytest.mark.asyncio
async def test_openapi_schema(client):
response = await client.get("/openapi.json")
assert response.status_code == 200
schema = response.json()
assert schema["info"]["title"] == "ConsentOS API"
assert schema["info"]["version"] == "0.1.0"
@pytest.mark.asyncio
async def test_api_routes_registered(client):
response = await client.get("/openapi.json")
paths = response.json()["paths"]
assert "/health" in paths
assert "/api/v1/auth/login" in paths
assert "/api/v1/config/sites/{site_id}" in paths
assert "/api/v1/consent/" in paths
assert "/api/v1/scanner/scans" in paths
assert "/api/v1/compliance/check/{site_id}" in paths

View File

@@ -0,0 +1,89 @@
"""Integration tests for authentication endpoints (requires database)."""
from tests.conftest import requires_db
@requires_db
class TestAuthLogin:
async def test_login_success(self, db_client, test_user):
resp = await db_client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123",
},
)
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
async def test_login_wrong_password(self, db_client, test_user):
resp = await db_client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "wrong",
},
)
assert resp.status_code == 401
async def test_login_nonexistent_user(self, db_client):
resp = await db_client.post(
"/api/v1/auth/login",
json={
"email": "nobody@test.com",
"password": "anything",
},
)
assert resp.status_code == 401
async def test_login_invalid_email(self, db_client):
resp = await db_client.post(
"/api/v1/auth/login",
json={"email": "not-an-email", "password": "anything"},
)
assert resp.status_code == 422
@requires_db
class TestAuthMe:
async def test_me_returns_user(self, db_client, auth_headers, test_user):
resp = await db_client.get("/api/v1/auth/me", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert data["email"] == test_user.email
assert data["role"] == "owner"
async def test_me_without_token(self, db_client):
resp = await db_client.get("/api/v1/auth/me")
assert resp.status_code == 401
@requires_db
class TestAuthRefresh:
async def test_refresh_returns_new_tokens(self, db_client, test_user):
# First login to get a refresh token
login_resp = await db_client.post(
"/api/v1/auth/login",
json={
"email": test_user.email,
"password": "TestPassword123",
},
)
refresh_token = login_resp.json()["refresh_token"]
resp = await db_client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token},
)
assert resp.status_code == 200
assert "access_token" in resp.json()
async def test_refresh_with_invalid_token(self, db_client):
resp = await db_client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid-token"},
)
assert resp.status_code == 401

View File

@@ -0,0 +1,152 @@
"""Integration tests for consent recording endpoints (requires database)."""
import uuid
from tests.conftest import create_test_site, requires_db
@requires_db
class TestConsentEndpoints:
async def test_record_consent(self, db_client, auth_headers):
"""POST /consent/ is public (no auth) — used by the banner."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent")
resp = await db_client.post(
"/api/v1/consent/",
json={
"site_id": site_id,
"visitor_id": str(uuid.uuid4()),
"action": "accept_all",
"categories_accepted": [
"necessary",
"functional",
"analytics",
"marketing",
"personalisation",
],
"categories_rejected": [],
},
)
assert resp.status_code == 201
data = resp.json()
assert data["action"] == "accept_all"
assert "id" in data
async def test_record_consent_reject_all(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-rej")
resp = await db_client.post(
"/api/v1/consent/",
json={
"site_id": site_id,
"visitor_id": str(uuid.uuid4()),
"action": "reject_all",
"categories_accepted": ["necessary"],
"categories_rejected": [
"functional",
"analytics",
"marketing",
"personalisation",
],
},
)
assert resp.status_code == 201
assert resp.json()["action"] == "reject_all"
async def test_record_consent_custom(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-cust")
resp = await db_client.post(
"/api/v1/consent/",
json={
"site_id": site_id,
"visitor_id": str(uuid.uuid4()),
"action": "custom",
"categories_accepted": [
"necessary",
"analytics",
],
"categories_rejected": [
"functional",
"marketing",
"personalisation",
],
},
)
assert resp.status_code == 201
assert resp.json()["action"] == "custom"
async def test_get_consent_record(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-get")
# Create a consent record
create_resp = await db_client.post(
"/api/v1/consent/",
json={
"site_id": site_id,
"visitor_id": str(uuid.uuid4()),
"action": "accept_all",
"categories_accepted": ["necessary"],
"categories_rejected": [],
},
)
consent_id = create_resp.json()["id"]
resp = await db_client.get(
f"/api/v1/consent/{consent_id}",
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["id"] == consent_id
async def test_get_consent_requires_auth(self, db_client):
"""Reading a consent record without auth must be rejected."""
resp = await db_client.get(f"/api/v1/consent/{uuid.uuid4()}")
assert resp.status_code == 401
async def test_verify_consent(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-ver")
create_resp = await db_client.post(
"/api/v1/consent/",
json={
"site_id": site_id,
"visitor_id": str(uuid.uuid4()),
"action": "accept_all",
"categories_accepted": ["necessary"],
"categories_rejected": [],
},
)
consent_id = create_resp.json()["id"]
resp = await db_client.get(
f"/api/v1/consent/verify/{consent_id}",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is True
assert str(data["id"]) == consent_id
async def test_get_consent_not_found(self, db_client, auth_headers):
resp = await db_client.get(
f"/api/v1/consent/{uuid.uuid4()}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_verify_consent_not_found(self, db_client, auth_headers):
resp = await db_client.get(
f"/api/v1/consent/verify/{uuid.uuid4()}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_record_consent_invalid_action(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="consent-inv")
resp = await db_client.post(
"/api/v1/consent/",
json={
"site_id": site_id,
"visitor_id": str(uuid.uuid4()),
"action": "invalid_action",
"categories_accepted": ["necessary"],
"categories_rejected": [],
},
)
assert resp.status_code == 422

View File

@@ -0,0 +1,169 @@
"""Integration tests for cookie and allow-list endpoints (requires database)."""
import uuid
import pytest
from tests.conftest import create_test_site, requires_db
@requires_db
class TestCookieCategoriesIntegration:
async def test_list_categories_with_db(self, db_client):
"""Categories are seeded by migration; verify the endpoint."""
resp = await db_client.get("/api/v1/cookies/categories")
assert resp.status_code == 200
categories = resp.json()
assert isinstance(categories, list)
# Should have at least the 5 seeded categories
slugs = {c["slug"] for c in categories}
assert "necessary" in slugs
assert "analytics" in slugs
async def test_get_category_by_id(self, db_client):
cats_resp = await db_client.get("/api/v1/cookies/categories")
if cats_resp.status_code == 200 and cats_resp.json():
cat_id = cats_resp.json()[0]["id"]
resp = await db_client.get(f"/api/v1/cookies/categories/{cat_id}")
assert resp.status_code == 200
async def test_get_category_not_found(self, db_client):
resp = await db_client.get(f"/api/v1/cookies/categories/{uuid.uuid4()}")
assert resp.status_code == 404
@requires_db
class TestCookieCRUDIntegration:
async def test_list_cookies_empty(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-empty")
resp = await db_client.get(
f"/api/v1/cookies/sites/{site_id}",
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json() == []
async def test_create_and_list_cookie(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-create")
create_resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}",
json={"name": "_ga", "domain": ".google.com"},
headers=auth_headers,
)
assert create_resp.status_code == 201
cookie = create_resp.json()
assert cookie["name"] == "_ga"
assert cookie["review_status"] == "pending"
# Should now appear in list
list_resp = await db_client.get(
f"/api/v1/cookies/sites/{site_id}",
headers=auth_headers,
)
assert len(list_resp.json()) >= 1
async def test_update_cookie_review_status(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-upd")
create_resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}",
json={"name": "_fbp", "domain": ".facebook.com"},
headers=auth_headers,
)
cookie_id = create_resp.json()["id"]
resp = await db_client.patch(
f"/api/v1/cookies/sites/{site_id}/{cookie_id}",
json={"review_status": "approved"},
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["review_status"] == "approved"
async def test_delete_cookie(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-del")
create_resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}",
json={"name": "_del_test", "domain": ".test.com"},
headers=auth_headers,
)
cookie_id = create_resp.json()["id"]
resp = await db_client.delete(
f"/api/v1/cookies/sites/{site_id}/{cookie_id}",
headers=auth_headers,
)
assert resp.status_code == 204
async def test_cookie_summary(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="cookie-sum")
resp = await db_client.get(
f"/api/v1/cookies/sites/{site_id}/summary",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert "total" in data
assert "by_status" in data
assert "uncategorised" in data
@requires_db
class TestAllowListIntegration:
async def _get_category_id(self, db_client):
"""Fetch the first available cookie category ID."""
resp = await db_client.get("/api/v1/cookies/categories")
categories = resp.json()
if categories:
return categories[0]["id"]
return None
async def test_create_allow_list_entry(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="allow-create")
category_id = await self._get_category_id(db_client)
if not category_id:
pytest.skip("No categories seeded")
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/allow-list",
json={
"name_pattern": "_ga*",
"domain_pattern": ".google.com",
"category_id": category_id,
},
headers=auth_headers,
)
assert resp.status_code == 201
data = resp.json()
assert data["name_pattern"] == "_ga*"
async def test_list_allow_list(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="allow-list")
resp = await db_client.get(
f"/api/v1/cookies/sites/{site_id}/allow-list",
headers=auth_headers,
)
assert resp.status_code == 200
assert isinstance(resp.json(), list)
async def test_delete_allow_list_entry(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="allow-del")
category_id = await self._get_category_id(db_client)
if not category_id:
pytest.skip("No categories seeded")
create_resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/allow-list",
json={
"name_pattern": "_del_test*",
"domain_pattern": ".test.com",
"category_id": category_id,
},
headers=auth_headers,
)
entry_id = create_resp.json()["id"]
resp = await db_client.delete(
f"/api/v1/cookies/sites/{site_id}/allow-list/{entry_id}",
headers=auth_headers,
)
assert resp.status_code == 204

View File

@@ -0,0 +1,173 @@
"""Integration tests for organisation and user endpoints (requires database)."""
import uuid
import pytest
from tests.conftest import requires_db
_BOOTSTRAP_TOKEN = "test-bootstrap-token-xyz"
_BOOTSTRAP_HEADERS = {"X-Admin-Bootstrap-Token": _BOOTSTRAP_TOKEN}
@pytest.fixture
def _bootstrap_enabled(monkeypatch):
"""Configure ``admin_bootstrap_token`` on the cached settings object."""
from src.config.settings import get_settings
monkeypatch.setattr(get_settings(), "admin_bootstrap_token", _BOOTSTRAP_TOKEN)
yield
@requires_db
class TestOrganisationEndpoints:
async def test_create_org(self, db_client, _bootstrap_enabled):
slug = f"new-org-{uuid.uuid4().hex[:8]}"
resp = await db_client.post(
"/api/v1/organisations/",
json={"name": "New Org", "slug": slug},
headers=_BOOTSTRAP_HEADERS,
)
assert resp.status_code == 201
data = resp.json()
assert data["name"] == "New Org"
assert data["slug"] == slug
assert "id" in data
async def test_create_org_duplicate_slug(self, db_client, _bootstrap_enabled):
slug = f"dup-org-{uuid.uuid4().hex[:8]}"
await db_client.post(
"/api/v1/organisations/",
json={"name": "Dup Org", "slug": slug},
headers=_BOOTSTRAP_HEADERS,
)
resp = await db_client.post(
"/api/v1/organisations/",
json={"name": "Dup Org 2", "slug": slug},
headers=_BOOTSTRAP_HEADERS,
)
assert resp.status_code == 409
async def test_get_my_org(self, db_client, auth_headers, test_org):
resp = await db_client.get("/api/v1/organisations/me", headers=auth_headers)
assert resp.status_code == 200
data = resp.json()
assert data["slug"] == test_org.slug
async def test_update_my_org(self, db_client, auth_headers):
resp = await db_client.patch(
"/api/v1/organisations/me",
json={"name": "Updated Org Name"},
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["name"] == "Updated Org Name"
@requires_db
class TestUserEndpoints:
async def test_list_users(self, db_client, auth_headers):
resp = await db_client.get("/api/v1/users/", headers=auth_headers)
assert resp.status_code == 200
assert isinstance(resp.json(), list)
assert len(resp.json()) >= 1 # At least the test user
async def test_create_user(self, db_client, auth_headers):
resp = await db_client.post(
"/api/v1/users/",
json={
"email": f"new-{uuid.uuid4().hex[:8]}@test.com",
"password": "SecurePass123",
"full_name": "New User",
"role": "editor",
},
headers=auth_headers,
)
assert resp.status_code == 201
data = resp.json()
assert data["role"] == "editor"
async def test_create_user_duplicate_email(self, db_client, auth_headers):
email = f"dup-{uuid.uuid4().hex[:8]}@test.com"
await db_client.post(
"/api/v1/users/",
json={
"email": email,
"password": "SecurePass123",
"full_name": "Dup User",
"role": "viewer",
},
headers=auth_headers,
)
resp = await db_client.post(
"/api/v1/users/",
json={
"email": email,
"password": "SecurePass123",
"full_name": "Dup User",
"role": "viewer",
},
headers=auth_headers,
)
assert resp.status_code == 409
async def test_get_user(self, db_client, auth_headers, test_user):
resp = await db_client.get(
f"/api/v1/users/{test_user.id}",
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["email"] == test_user.email
async def test_get_user_not_found(self, db_client, auth_headers):
resp = await db_client.get(
f"/api/v1/users/{uuid.uuid4()}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_update_user(self, db_client, auth_headers):
# Create a user to update
create_resp = await db_client.post(
"/api/v1/users/",
json={
"email": f"upd-{uuid.uuid4().hex[:8]}@test.com",
"password": "SecurePass123",
"full_name": "Update User",
"role": "viewer",
},
headers=auth_headers,
)
user_id = create_resp.json()["id"]
resp = await db_client.patch(
f"/api/v1/users/{user_id}",
json={
"full_name": "Updated Name",
"role": "editor",
},
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["full_name"] == "Updated Name"
assert resp.json()["role"] == "editor"
async def test_delete_user(self, db_client, auth_headers):
create_resp = await db_client.post(
"/api/v1/users/",
json={
"email": f"del-{uuid.uuid4().hex[:8]}@test.com",
"password": "SecurePass123",
"full_name": "Delete User",
"role": "viewer",
},
headers=auth_headers,
)
user_id = create_resp.json()["id"]
resp = await db_client.delete(f"/api/v1/users/{user_id}", headers=auth_headers)
assert resp.status_code == 204
async def test_users_require_auth(self, db_client):
resp = await db_client.get("/api/v1/users/")
assert resp.status_code == 401

View File

@@ -0,0 +1,192 @@
"""Integration tests for site and site config endpoints (requires database)."""
import uuid
from tests.conftest import requires_db
@requires_db
class TestSiteCRUD:
async def test_create_site(self, db_client, auth_headers):
domain = f"example-{uuid.uuid4().hex[:8]}.com"
resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Example Site",
},
headers=auth_headers,
)
assert resp.status_code == 201
data = resp.json()
assert data["domain"] == domain
assert data["display_name"] == "Example Site"
assert data["is_active"] is True
assert "id" in data
async def test_create_site_duplicate_domain(self, db_client, auth_headers):
domain = f"dup-{uuid.uuid4().hex[:8]}.com"
# Create first
await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Dup Test",
},
headers=auth_headers,
)
# Duplicate should fail
resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Dup Test",
},
headers=auth_headers,
)
assert resp.status_code == 409
async def test_list_sites(self, db_client, auth_headers):
resp = await db_client.get("/api/v1/sites/", headers=auth_headers)
assert resp.status_code == 200
assert isinstance(resp.json(), list)
async def test_get_site(self, db_client, auth_headers):
# Create a site first
domain = f"get-{uuid.uuid4().hex[:8]}.com"
create_resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Get Test",
},
headers=auth_headers,
)
site_id = create_resp.json()["id"]
resp = await db_client.get(f"/api/v1/sites/{site_id}", headers=auth_headers)
assert resp.status_code == 200
assert resp.json()["domain"] == domain
async def test_get_site_not_found(self, db_client, auth_headers):
resp = await db_client.get(
f"/api/v1/sites/{uuid.uuid4()}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_update_site(self, db_client, auth_headers):
domain = f"update-{uuid.uuid4().hex[:8]}.com"
create_resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Update Test",
},
headers=auth_headers,
)
site_id = create_resp.json()["id"]
resp = await db_client.patch(
f"/api/v1/sites/{site_id}",
json={"display_name": "Updated Name"},
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["display_name"] == "Updated Name"
async def test_delete_site_soft_deletes(self, db_client, auth_headers):
domain = f"delete-{uuid.uuid4().hex[:8]}.com"
create_resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Delete Test",
},
headers=auth_headers,
)
site_id = create_resp.json()["id"]
resp = await db_client.delete(f"/api/v1/sites/{site_id}", headers=auth_headers)
assert resp.status_code == 204
# Should no longer be findable
get_resp = await db_client.get(f"/api/v1/sites/{site_id}", headers=auth_headers)
assert get_resp.status_code == 404
async def test_create_site_requires_auth(self, db_client):
resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": "noauth.com",
"display_name": "No Auth",
},
)
assert resp.status_code == 401
@requires_db
class TestSiteConfig:
async def test_get_config_creates_default(self, db_client, auth_headers):
domain = f"config-{uuid.uuid4().hex[:8]}.com"
create_resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Config Test",
},
headers=auth_headers,
)
site_id = create_resp.json()["id"]
# PUT config to create it
put_resp = await db_client.put(
f"/api/v1/sites/{site_id}/config",
json={"blocking_mode": "opt_in"},
headers=auth_headers,
)
assert put_resp.status_code in (200, 201)
# GET config
resp = await db_client.get(
f"/api/v1/sites/{site_id}/config",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["blocking_mode"] == "opt_in"
async def test_update_config(self, db_client, auth_headers):
domain = f"config-upd-{uuid.uuid4().hex[:8]}.com"
create_resp = await db_client.post(
"/api/v1/sites/",
json={
"domain": domain,
"display_name": "Config Update",
},
headers=auth_headers,
)
site_id = create_resp.json()["id"]
# Create config
await db_client.put(
f"/api/v1/sites/{site_id}/config",
json={"blocking_mode": "opt_in"},
headers=auth_headers,
)
# Patch config
resp = await db_client.patch(
f"/api/v1/sites/{site_id}/config",
json={
"blocking_mode": "opt_out",
"gcm_enabled": False,
"consent_expiry_days": 180,
},
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["blocking_mode"] == "opt_out"
assert data["gcm_enabled"] is False
assert data["consent_expiry_days"] == 180

View File

@@ -0,0 +1,137 @@
"""Tests for the rate limiting middleware."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.middleware.rate_limit import RateLimitMiddleware
class TestRateLimitMiddleware:
@pytest.mark.asyncio
async def test_get_client_ip_from_forwarded_for(self):
from starlette.applications import Starlette
app = Starlette()
middleware = RateLimitMiddleware(app)
request = MagicMock()
request.headers = {"x-forwarded-for": "1.2.3.4, 5.6.7.8"}
assert middleware._get_client_ip(request) == "1.2.3.4"
@pytest.mark.asyncio
async def test_get_client_ip_from_real_ip(self):
from starlette.applications import Starlette
app = Starlette()
middleware = RateLimitMiddleware(app)
request = MagicMock()
request.headers = {"x-real-ip": "9.8.7.6"}
assert middleware._get_client_ip(request) == "9.8.7.6"
@pytest.mark.asyncio
async def test_get_client_ip_from_client(self):
from starlette.applications import Starlette
app = Starlette()
middleware = RateLimitMiddleware(app)
request = MagicMock()
request.headers = {}
request.client = MagicMock()
request.client.host = "10.0.0.1"
assert middleware._get_client_ip(request) == "10.0.0.1"
@pytest.mark.asyncio
async def test_get_client_ip_no_client(self):
from starlette.applications import Starlette
app = Starlette()
middleware = RateLimitMiddleware(app)
request = MagicMock()
request.headers = {}
request.client = None
assert middleware._get_client_ip(request) == "unknown"
@pytest.mark.asyncio
async def test_health_bypasses_rate_limit(self):
"""Health checks should never be rate limited."""
from src.main import create_app
app = create_app()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get("/health")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_passes_through_when_redis_unavailable(self):
"""When Redis is down, requests should still be served."""
from src.main import create_app
# Rate limiting disabled by default in test settings
app = create_app()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get("/health")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_rate_limit_headers_present(self):
"""Rate limit headers should be added when middleware is active."""
from fastapi import FastAPI
app = FastAPI()
@app.get("/test")
async def test_endpoint():
return {"ok": True}
# Mock Redis
mock_redis = AsyncMock()
mock_redis.incr = AsyncMock(return_value=1)
mock_redis.expire = AsyncMock()
middleware = RateLimitMiddleware(app, requests_per_minute=100)
middleware._redis = mock_redis
# Since we can't easily inject the mock Redis into the ASGI middleware,
# test the logic unit separately
assert middleware.requests_per_minute == 100
@pytest.mark.asyncio
async def test_middleware_creation(self):
"""Middleware should initialise with provided parameters."""
from starlette.applications import Starlette
app = Starlette()
middleware = RateLimitMiddleware(app, redis_url="redis://fake:6379", requests_per_minute=30)
assert middleware.requests_per_minute == 30
assert middleware.redis_url == "redis://fake:6379"
assert middleware._redis is None # Lazy initialisation
class TestRateLimitConfiguration:
def test_default_settings_enabled(self, monkeypatch):
"""Rate limiting is on by default — public endpoints must not be DoS-able.
Note: the suite-wide conftest sets ``RATE_LIMIT_ENABLED=false``
so other tests aren't rate-limited by Redis; we unset it here
to verify the baked-in default.
"""
monkeypatch.delenv("RATE_LIMIT_ENABLED", raising=False)
from src.config.settings import Settings
settings = Settings()
assert settings.rate_limit_enabled is True
def test_configurable_limit(self):
"""Rate limit per minute should be configurable."""
from src.config.settings import Settings
settings = Settings(rate_limit_per_minute=120)
assert settings.rate_limit_per_minute == 120

View File

@@ -0,0 +1,65 @@
"""Tests for the security headers middleware."""
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
@pytest.fixture
def app():
return create_app()
@pytest.fixture
async def client(app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
class TestSecurityHeaders:
@pytest.mark.asyncio
async def test_x_content_type_options(self, client):
resp = await client.get("/health")
assert resp.headers["x-content-type-options"] == "nosniff"
@pytest.mark.asyncio
async def test_x_frame_options(self, client):
resp = await client.get("/health")
assert resp.headers["x-frame-options"] == "DENY"
@pytest.mark.asyncio
async def test_x_xss_protection(self, client):
resp = await client.get("/health")
assert resp.headers["x-xss-protection"] == "0"
@pytest.mark.asyncio
async def test_referrer_policy(self, client):
resp = await client.get("/health")
assert resp.headers["referrer-policy"] == "strict-origin-when-cross-origin"
@pytest.mark.asyncio
async def test_content_security_policy(self, client):
resp = await client.get("/health")
assert resp.headers["content-security-policy"] == "default-src 'none'"
@pytest.mark.asyncio
async def test_no_hsts_on_http(self, client):
resp = await client.get("/health")
assert "strict-transport-security" not in resp.headers
@pytest.mark.asyncio
async def test_hsts_on_https(self, app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="https://test") as client:
resp = await client.get("/health")
assert "strict-transport-security" in resp.headers
assert "max-age=63072000" in resp.headers["strict-transport-security"]
@pytest.mark.asyncio
async def test_headers_present_on_non_existent_route(self, client):
# Even 404s on unknown routes should have security headers
resp = await client.get("/this-does-not-exist")
assert resp.headers["x-content-type-options"] == "nosniff"
assert resp.headers["x-frame-options"] == "DENY"

View File

@@ -0,0 +1,166 @@
"""Tests for SQLAlchemy model definitions.
These tests verify model structure without needing a database connection.
"""
from sqlalchemy import inspect
from src.models import (
Base,
ConsentRecord,
Cookie,
CookieAllowListEntry,
CookieCategory,
KnownCookie,
Organisation,
ScanJob,
ScanResult,
Site,
SiteConfig,
Translation,
User,
)
def test_all_models_registered_in_metadata():
"""All expected tables should be present in Base.metadata."""
table_names = set(Base.metadata.tables.keys())
expected = {
"organisations",
"users",
"sites",
"site_configs",
"cookie_categories",
"cookies",
"cookie_allow_list",
"known_cookies",
"consent_records",
"scan_jobs",
"scan_results",
"translations",
}
assert expected.issubset(table_names), f"Missing tables: {expected - table_names}"
def test_organisation_columns():
mapper = inspect(Organisation)
column_names = {c.key for c in mapper.columns}
assert "id" in column_names
assert "name" in column_names
assert "slug" in column_names
assert "contact_email" in column_names
assert "billing_plan" in column_names
assert "created_at" in column_names
assert "updated_at" in column_names
assert "deleted_at" in column_names
def test_user_columns_and_fk():
mapper = inspect(User)
column_names = {c.key for c in mapper.columns}
assert "organisation_id" in column_names
assert "email" in column_names
assert "password_hash" in column_names
assert "role" in column_names
def test_site_unique_constraint():
table = Site.__table__
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
assert "uq_sites_org_domain" in constraint_names
def test_site_config_jsonb_fields():
mapper = inspect(SiteConfig)
column_names = {c.key for c in mapper.columns}
for field in ["regional_modes", "gcm_default", "banner_config"]:
assert field in column_names, f"Missing JSONB field: {field}"
def test_cookie_category_columns():
mapper = inspect(CookieCategory)
column_names = {c.key for c in mapper.columns}
assert "tcf_purpose_ids" in column_names
assert "gcm_consent_types" in column_names
assert "is_essential" in column_names
def test_cookie_unique_constraint():
table = Cookie.__table__
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
assert "uq_cookies_site_name_domain_type" in constraint_names
def test_cookie_allow_list_unique_constraint():
table = CookieAllowListEntry.__table__
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
assert "uq_allow_list_site_name_domain" in constraint_names
def test_known_cookie_unique_constraint():
table = KnownCookie.__table__
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
assert "uq_known_cookies_name_domain" in constraint_names
def test_consent_record_columns():
mapper = inspect(ConsentRecord)
column_names = {c.key for c in mapper.columns}
for field in [
"visitor_id",
"action",
"categories_accepted",
"tc_string",
"gcm_state",
"country_code",
"consented_at",
]:
assert field in column_names, f"Missing field: {field}"
def test_scan_job_columns():
mapper = inspect(ScanJob)
column_names = {c.key for c in mapper.columns}
assert "status" in column_names
assert "pages_scanned" in column_names
assert "cookies_found" in column_names
def test_scan_result_columns():
mapper = inspect(ScanResult)
column_names = {c.key for c in mapper.columns}
assert "page_url" in column_names
assert "cookie_name" in column_names
assert "script_source" in column_names
assert "auto_category" in column_names
def test_translation_unique_constraint():
table = Translation.__table__
constraint_names = {c.name for c in table.constraints if hasattr(c, "name") and c.name}
assert "uq_translations_site_locale" in constraint_names
def test_uuid_primary_keys():
"""All models should use UUID primary keys."""
models = [
Organisation,
User,
Site,
SiteConfig,
CookieCategory,
Cookie,
CookieAllowListEntry,
KnownCookie,
ConsentRecord,
ScanJob,
ScanResult,
Translation,
]
for model in models:
mapper = inspect(model)
pk_cols = mapper.primary_key
assert len(pk_cols) == 1, f"{model.__name__} should have exactly one PK column"
assert str(pk_cols[0].type) == "UUID", (
f"{model.__name__} PK should be UUID, got {pk_cols[0].type}"
)

View File

@@ -0,0 +1,96 @@
"""Tests for OpenAPI schema generation and documentation."""
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
@pytest.fixture
def app():
return create_app()
@pytest.fixture
async def client(app):
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac
class TestOpenAPISchema:
@pytest.mark.asyncio
async def test_openapi_endpoint_accessible(self, client):
resp = await client.get("/openapi.json")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_schema_has_info(self, client):
resp = await client.get("/openapi.json")
schema = resp.json()
assert schema["info"]["title"] == "ConsentOS API"
assert "version" in schema["info"]
assert "description" in schema["info"]
assert "consent" in schema["info"]["description"].lower()
@pytest.mark.asyncio
async def test_schema_has_tags(self, client):
resp = await client.get("/openapi.json")
schema = resp.json()
tag_names = {t["name"] for t in schema.get("tags", [])}
expected_tags = {
"auth",
"config",
"consent",
"sites",
"cookies",
"scanner",
"compliance",
"organisations",
"users",
}
assert expected_tags.issubset(tag_names)
@pytest.mark.asyncio
async def test_all_tags_have_descriptions(self, client):
resp = await client.get("/openapi.json")
schema = resp.json()
for tag in schema.get("tags", []):
assert "description" in tag, f"Tag '{tag['name']}' missing description"
assert len(tag["description"]) > 10, f"Tag '{tag['name']}' has weak description"
@pytest.mark.asyncio
async def test_health_endpoint_in_schema(self, client):
resp = await client.get("/openapi.json")
schema = resp.json()
assert "/health" in schema["paths"]
@pytest.mark.asyncio
async def test_key_endpoints_present(self, client):
resp = await client.get("/openapi.json")
paths = resp.json()["paths"]
assert "/api/v1/auth/login" in paths
assert "/api/v1/consent/" in paths
assert "/api/v1/sites/" in paths
assert "/api/v1/config/geo" in paths
@pytest.mark.asyncio
async def test_docs_endpoint_accessible(self, client):
resp = await client.get("/docs")
assert resp.status_code == 200
class TestOpenAPIEndpoints:
@pytest.mark.asyncio
async def test_config_endpoints_documented(self, client):
resp = await client.get("/openapi.json")
paths = resp.json()["paths"]
config_paths = [p for p in paths if "/config/" in p]
assert len(config_paths) >= 4 # public, resolved, geo-resolved, publish, geo
@pytest.mark.asyncio
async def test_consent_endpoints_documented(self, client):
resp = await client.get("/openapi.json")
paths = resp.json()["paths"]
consent_paths = [p for p in paths if "/consent" in p]
assert len(consent_paths) >= 1

View File

@@ -0,0 +1,146 @@
"""Tests for organisation and user CRUD endpoints and schemas."""
import uuid
import pytest
from pydantic import ValidationError
from src.schemas.organisation import OrganisationCreate, OrganisationResponse, OrganisationUpdate
from src.schemas.user import UserCreate, UserResponse, UserRole, UserUpdate
class TestOrganisationSchemas:
def test_create_valid(self):
org = OrganisationCreate(name="Acme Corp", slug="acme-corp")
assert org.name == "Acme Corp"
assert org.slug == "acme-corp"
assert org.billing_plan == "free"
def test_create_invalid_slug(self):
with pytest.raises(ValidationError):
OrganisationCreate(name="Acme", slug="INVALID SLUG!")
def test_create_empty_name_rejected(self):
with pytest.raises(ValidationError):
OrganisationCreate(name="", slug="valid-slug")
def test_update_partial(self):
update = OrganisationUpdate(name="New Name")
data = update.model_dump(exclude_unset=True)
assert data == {"name": "New Name"}
assert "contact_email" not in data
def test_response_from_attributes(self):
now = "2026-01-01T00:00:00Z"
resp = OrganisationResponse(
id=uuid.uuid4(),
name="Test",
slug="test",
contact_email=None,
billing_plan="free",
created_at=now,
updated_at=now,
)
assert resp.name == "Test"
class TestUserSchemas:
def test_create_valid(self):
user = UserCreate(
email="test@example.com",
password="securepass123",
full_name="Test User",
)
assert user.role == UserRole.VIEWER
def test_create_short_password_rejected(self):
with pytest.raises(ValidationError):
UserCreate(email="a@b.com", password="short", full_name="Test")
def test_create_invalid_email_rejected(self):
with pytest.raises(ValidationError):
UserCreate(email="not-an-email", password="securepass123", full_name="Test")
def test_create_with_role(self):
user = UserCreate(
email="admin@example.com",
password="securepass123",
full_name="Admin",
role=UserRole.ADMIN,
)
assert user.role == UserRole.ADMIN
def test_update_partial(self):
update = UserUpdate(role=UserRole.EDITOR)
data = update.model_dump(exclude_unset=True)
assert data == {"role": "editor"}
def test_response_from_attributes(self):
now = "2026-01-01T00:00:00Z"
resp = UserResponse(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="a@b.com",
full_name="Test",
role="viewer",
created_at=now,
updated_at=now,
)
assert resp.role == "viewer"
class TestUserRole:
def test_role_values(self):
assert UserRole.OWNER == "owner"
assert UserRole.ADMIN == "admin"
assert UserRole.EDITOR == "editor"
assert UserRole.VIEWER == "viewer"
def test_invalid_role_rejected(self):
with pytest.raises(ValidationError):
UserCreate(
email="a@b.com",
password="securepass123",
full_name="Test",
role="superadmin",
)
@pytest.mark.asyncio
class TestRoutesRegistered:
async def test_org_routes(self, client):
response = await client.get("/openapi.json")
paths = response.json()["paths"]
assert "/api/v1/organisations/" in paths
assert "/api/v1/organisations/me" in paths
async def test_user_routes(self, client):
response = await client.get("/openapi.json")
paths = response.json()["paths"]
assert "/api/v1/users/" in paths
assert "/api/v1/users/{user_id}" in paths
async def test_org_endpoints_require_auth(self, client):
response = await client.get("/api/v1/organisations/me")
assert response.status_code == 401
async def test_user_endpoints_require_auth(self, client):
response = await client.get("/api/v1/users/")
assert response.status_code == 401
async def test_create_org_rejects_invalid_body(self, client, monkeypatch):
"""Create org endpoint validates the request body schema.
We need to enable the bootstrap token first so the request
reaches the body-validation stage (the token guard otherwise
fires before Pydantic validation and we'd see 403 instead).
"""
from src.config.settings import get_settings
monkeypatch.setattr(get_settings(), "admin_bootstrap_token", "test-token")
response = await client.post(
"/api/v1/organisations/",
json={"name": "", "slug": "INVALID SLUG!"},
headers={"X-Admin-Bootstrap-Token": "test-token"},
)
assert response.status_code == 422

View File

@@ -0,0 +1,88 @@
"""Unit tests for the CDN publisher service."""
import contextlib
import json
import os
from pathlib import Path
from unittest.mock import patch
import pytest
from src.services.publisher import PublishResult, _publish_local, publish_site_config
class TestPublishResult:
def test_success_result(self):
result = PublishResult(success=True, path="/cdn/config.json")
assert result.success is True
assert result.path == "/cdn/config.json"
assert result.published_at is not None
assert result.error is None
def test_failure_result(self):
result = PublishResult(success=False, path="", error="Something went wrong")
assert result.success is False
assert result.published_at is None
assert result.error == "Something went wrong"
class TestPublishLocal:
@pytest.mark.asyncio
async def test_publish_creates_files(self):
config = {"site_id": "abc-123", "blocking_mode": "opt_in"}
path = await _publish_local("abc-123", config, "https://cdn.example.com")
assert os.path.exists(path)
with open(path) as f:
data = json.load(f)
assert data["site_id"] == "abc-123"
@pytest.mark.asyncio
async def test_publish_creates_versioned_copy(self):
config = {"site_id": "def-456", "blocking_mode": "opt_out"}
path = await _publish_local("def-456", config, "https://cdn.example.com")
publish_dir = Path(path).parent
versioned = list(publish_dir.glob("site-config-def-456-*.json"))
assert len(versioned) >= 1
class TestPublishSiteConfig:
@pytest.mark.asyncio
async def test_publish_success(self):
site_config = {
"blocking_mode": "opt_in",
"tcf_enabled": False,
"gcm_enabled": True,
"consent_expiry_days": 365,
}
result = await publish_site_config("site-123", site_config)
assert result.success is True
assert result.path != ""
assert result.published_at is not None
@pytest.mark.asyncio
async def test_publish_with_org_defaults(self):
site_config = {"blocking_mode": "opt_in"}
org_defaults = {"consent_expiry_days": 180}
result = await publish_site_config("site-456", site_config, org_defaults)
assert result.success is True
@pytest.mark.asyncio
async def test_publish_failure_returns_error(self):
with patch(
"src.services.publisher._publish_local",
side_effect=OSError("Permission denied"),
):
result = await publish_site_config("site-789", {"blocking_mode": "opt_in"})
assert result.success is False
assert "Permission denied" in result.error
@pytest.fixture(autouse=True)
def _cleanup_cdn():
yield
cdn_dir = Path("cdn-publish")
if cdn_dir.exists():
for f in cdn_dir.glob("site-config-*.json"):
f.unlink(missing_ok=True)
with contextlib.suppress(OSError):
cdn_dir.rmdir()

View File

@@ -0,0 +1,203 @@
"""Unit tests for auth router — mocked database."""
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token, create_refresh_token, hash_password
def _make_user(org_id: uuid.UUID | None = None, **overrides):
"""Build a mock User ORM object."""
_org_id = org_id or uuid.uuid4()
_id = overrides.pop("id", uuid.uuid4())
user = MagicMock()
user.id = _id
user.organisation_id = _org_id
user.email = overrides.get("email", "admin@test.com")
user.password_hash = overrides.get("password_hash", hash_password("TestPassword123"))
user.full_name = overrides.get("full_name", "Test Admin")
user.role = overrides.get("role", "owner")
user.deleted_at = None
user.is_active = True
return user
def _mock_db(scalars=None, scalar_one_or_none=None):
"""Create a mock AsyncSession.
When a query is executed:
- result.scalar_one_or_none() returns `scalar_one_or_none`
- result.scalars().all() returns `scalars or []`
"""
session = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = scalar_one_or_none
scalars_obj = MagicMock()
scalars_obj.all.return_value = scalars or []
result.scalars.return_value = scalars_obj
session.execute.return_value = result
return session
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
"""Build a test client with the given mock session."""
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
class TestLoginEndpoint:
@pytest.mark.asyncio
async def test_login_success(self, mock_app):
org_id = uuid.uuid4()
user = _make_user(org_id=org_id)
db = _mock_db(scalar_one_or_none=user)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/auth/login",
json={"email": "admin@test.com", "password": "TestPassword123"},
)
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
assert "refresh_token" in data
assert data["token_type"] == "bearer"
@pytest.mark.asyncio
async def test_login_wrong_password(self, mock_app):
user = _make_user()
db = _mock_db(scalar_one_or_none=user)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/auth/login",
json={"email": "admin@test.com", "password": "WrongPassword"},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_login_user_not_found(self, mock_app):
db = _mock_db(scalar_one_or_none=None)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/auth/login",
json={"email": "nobody@test.com", "password": "whatever"},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_login_invalid_body(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post("/api/v1/auth/login", json={"email": "not-an-email"})
assert resp.status_code == 422
class TestMeEndpoint:
@pytest.mark.asyncio
async def test_me_returns_user(self, mock_app):
org_id = uuid.uuid4()
user_id = uuid.uuid4()
token = create_access_token(
user_id=user_id,
organisation_id=org_id,
role="owner",
email="admin@test.com",
)
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.get(
"/api/v1/auth/me",
headers={"Authorization": f"Bearer {token}"},
)
assert resp.status_code == 200
data = resp.json()
assert data["email"] == "admin@test.com"
assert data["role"] == "owner"
@pytest.mark.asyncio
async def test_me_without_token(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/auth/me")
assert resp.status_code in (401, 403)
class TestRefreshEndpoint:
@pytest.mark.asyncio
async def test_refresh_success(self, mock_app):
org_id = uuid.uuid4()
user_id = uuid.uuid4()
user = _make_user(org_id=org_id, id=user_id)
refresh_token = create_refresh_token(
user_id=user_id,
organisation_id=org_id,
)
db = _mock_db(scalar_one_or_none=user)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token},
)
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
@pytest.mark.asyncio
async def test_refresh_with_access_token_rejected(self, mock_app):
"""An access token should not be usable as a refresh token."""
org_id = uuid.uuid4()
user_id = uuid.uuid4()
access_token = create_access_token(
user_id=user_id,
organisation_id=org_id,
role="owner",
email="admin@test.com",
)
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": access_token},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_refresh_invalid_token(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "invalid.token.here"},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_refresh_user_deleted(self, mock_app):
org_id = uuid.uuid4()
user_id = uuid.uuid4()
refresh_token = create_refresh_token(
user_id=user_id,
organisation_id=org_id,
)
db = _mock_db(scalar_one_or_none=None)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/auth/refresh",
json={"refresh_token": refresh_token},
)
assert resp.status_code == 401
assert "no longer exists" in resp.json()["detail"]

View File

@@ -0,0 +1,296 @@
"""Unit tests for config router — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token
ORG_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
def _auth_headers(role="owner"):
token = create_access_token(
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
)
return {"Authorization": f"Bearer {token}"}
def _mock_config(**overrides):
config = MagicMock(spec=[])
config.id = overrides.get("id", uuid.uuid4())
config.site_id = overrides.get("site_id", uuid.uuid4())
config.blocking_mode = overrides.get("blocking_mode", "opt_in")
config.tcf_enabled = overrides.get("tcf_enabled", False)
config.tcf_publisher_cc = overrides.get("tcf_publisher_cc")
config.gpp_enabled = overrides.get("gpp_enabled", True)
config.gpp_supported_apis = overrides.get("gpp_supported_apis", ["usnat"])
config.gpc_enabled = overrides.get("gpc_enabled", True)
default_jurisdictions = ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"]
config.gpc_jurisdictions = overrides.get("gpc_jurisdictions", default_jurisdictions)
config.gpc_global_honour = overrides.get("gpc_global_honour", False)
config.gcm_enabled = overrides.get("gcm_enabled", True)
config.gcm_default = overrides.get("gcm_default")
config.banner_config = overrides.get("banner_config", {})
config.regional_modes = overrides.get("regional_modes")
config.privacy_policy_url = overrides.get("privacy_policy_url")
config.scan_schedule_cron = overrides.get("scan_schedule_cron")
config.scan_max_pages = overrides.get("scan_max_pages", 50)
config.consent_expiry_days = overrides.get("consent_expiry_days", 365)
config.created_at = datetime.now(UTC)
config.updated_at = datetime.now(UTC)
return config
def _mock_db_sequence(*results):
session = AsyncMock()
mock_results = []
for r in results:
result = MagicMock()
result.scalar_one_or_none.return_value = r
mock_results.append(result)
session.execute = AsyncMock(side_effect=mock_results)
return session
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
class TestPublicSiteConfig:
@pytest.mark.asyncio
async def test_get_public_config(self, mock_app):
config = _mock_config()
db = _mock_db_sequence(config)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/config/sites/{config.site_id}")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_public_config_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/config/sites/{uuid.uuid4()}")
assert resp.status_code == 404
class TestResolvedConfig:
@pytest.mark.asyncio
async def test_get_resolved_config(self, mock_app):
config = _mock_config()
# Resolved endpoint does 4 queries: config, site org_id, org_config, site group_id
db = _mock_db_sequence(config, ORG_ID, None, None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/config/sites/{config.site_id}/resolved")
assert resp.status_code == 200
data = resp.json()
assert "site_id" in data
assert "blocking_mode" in data
@pytest.mark.asyncio
async def test_get_resolved_config_with_region(self, mock_app):
config = _mock_config(regional_modes={"EU": "opt_in", "US": "opt_out"})
db = _mock_db_sequence(config, ORG_ID, None, None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/config/sites/{config.site_id}/resolved?region=EU")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_resolved_config_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/config/sites/{uuid.uuid4()}/resolved")
assert resp.status_code == 404
class TestPublishConfig:
@pytest.mark.asyncio
async def test_publish_config_success(self, mock_app):
config = _mock_config()
# Publish does: config query, org_config query, group_id, active A/B test query
db = _mock_db_sequence(config, None, None, None)
mock_result = MagicMock()
mock_result.success = True
mock_result.path = "/cdn/site-config.json"
mock_result.published_at = datetime.now(UTC).isoformat()
with patch(
"src.routers.config.publish_site_config",
new_callable=AsyncMock,
return_value=mock_result,
):
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/config/sites/{config.site_id}/publish",
headers=_auth_headers(),
)
assert resp.status_code == 200
data = resp.json()
assert data["published"] is True
@pytest.mark.asyncio
async def test_publish_config_failure(self, mock_app):
config = _mock_config()
db = _mock_db_sequence(config, None, None, None)
mock_result = MagicMock()
mock_result.success = False
mock_result.error = "Disk full"
with patch(
"src.routers.config.publish_site_config",
new_callable=AsyncMock,
return_value=mock_result,
):
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/config/sites/{config.site_id}/publish",
headers=_auth_headers(),
)
assert resp.status_code == 500
@pytest.mark.asyncio
async def test_publish_config_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/config/sites/{uuid.uuid4()}/publish",
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_publish_requires_admin(self, mock_app):
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/config/sites/{uuid.uuid4()}/publish",
headers=_auth_headers(role="viewer"),
)
assert resp.status_code == 403
class TestGeoResolvedConfig:
@pytest.mark.asyncio
async def test_get_geo_resolved_config_with_header(self, mock_app):
config = _mock_config(
regional_modes={"EU": "opt_in", "US": "opt_out", "DEFAULT": "informational"},
)
# Geo-resolved does: config, site org_id, org_config, site group_id
db = _mock_db_sequence(config, ORG_ID, None, None)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/config/sites/{config.site_id}/geo-resolved",
headers={"cf-ipcountry": "DE"},
)
assert resp.status_code == 200
data = resp.json()
assert data["blocking_mode"] == "opt_in"
assert data["detected_country"] == "DE"
assert data["detected_region"] == "EU"
@pytest.mark.asyncio
async def test_get_geo_resolved_config_us(self, mock_app):
config = _mock_config(
regional_modes={"EU": "opt_in", "US-CA": "opt_out", "DEFAULT": "informational"},
)
db = _mock_db_sequence(config, ORG_ID, None, None)
with patch(
"src.routers.config.detect_region",
new_callable=AsyncMock,
) as mock_detect:
from src.services.geoip import GeoResult
mock_detect.return_value = GeoResult(country_code="US", region="US-CA")
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/config/sites/{config.site_id}/geo-resolved",
)
assert resp.status_code == 200
data = resp.json()
assert data["blocking_mode"] == "opt_out"
assert data["detected_region"] == "US-CA"
@pytest.mark.asyncio
async def test_get_geo_resolved_config_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/config/sites/{uuid.uuid4()}/geo-resolved",
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_get_geo_resolved_config_no_region_detected(self, mock_app):
config = _mock_config(
regional_modes={"EU": "opt_in", "DEFAULT": "informational"},
)
db = _mock_db_sequence(config, ORG_ID, None, None)
with patch(
"src.routers.config.detect_region",
new_callable=AsyncMock,
) as mock_detect:
from src.services.geoip import GeoResult
mock_detect.return_value = GeoResult(country_code=None, region=None)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/config/sites/{config.site_id}/geo-resolved",
)
assert resp.status_code == 200
data = resp.json()
assert data["detected_country"] is None
assert data["detected_region"] is None
class TestVisitorGeo:
@pytest.mark.asyncio
async def test_get_visitor_geo_with_header(self, mock_app):
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.get(
"/api/v1/config/geo",
headers={"cf-ipcountry": "GB"},
)
assert resp.status_code == 200
data = resp.json()
assert data["country_code"] == "GB"
assert data["region"] == "GB"
@pytest.mark.asyncio
async def test_get_visitor_geo_no_headers(self, mock_app):
db = _mock_db_sequence()
with patch(
"src.routers.config.detect_region",
new_callable=AsyncMock,
) as mock_detect:
from src.services.geoip import GeoResult
mock_detect.return_value = GeoResult(country_code=None, region=None)
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/config/geo")
assert resp.status_code == 200
data = resp.json()
assert data["country_code"] is None
assert data["region"] is None

View File

@@ -0,0 +1,230 @@
"""Unit tests for consent router — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
def _mock_consent_record(**overrides):
"""Build a mock ConsentRecord ORM object."""
record = MagicMock()
record.id = overrides.get("id", uuid.uuid4())
record.site_id = overrides.get("site_id", uuid.uuid4())
record.visitor_id = overrides.get("visitor_id", "visitor-123")
record.ip_hash = "abc123"
record.user_agent_hash = "def456"
record.action = overrides.get("action", "accept_all")
record.categories_accepted = overrides.get("categories_accepted", ["necessary"])
record.categories_rejected = overrides.get("categories_rejected", [])
record.tc_string = overrides.get("tc_string")
record.gcm_state = overrides.get("gcm_state")
record.gpp_string = overrides.get("gpp_string")
record.gpc_detected = overrides.get("gpc_detected")
record.gpc_honoured = overrides.get("gpc_honoured")
record.page_url = overrides.get("page_url")
record.country_code = overrides.get("country_code")
record.region_code = overrides.get("region_code")
record.consented_at = overrides.get("consented_at", datetime.now(UTC))
return record
def _mock_db(scalar_one_or_none=None):
session = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = scalar_one_or_none
session.execute.return_value = result
_added_objects = []
def _fake_add(obj):
_added_objects.append(obj)
session.add = MagicMock(side_effect=_fake_add)
async def _fake_flush():
"""Simulate DB flush — populate server-side defaults."""
for obj in _added_objects:
if getattr(obj, "id", None) is None:
obj.id = uuid.uuid4()
if hasattr(obj, "consented_at") and getattr(obj, "consented_at", None) is None:
obj.consented_at = datetime.now(UTC)
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
obj.created_at = datetime.now(UTC)
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
obj.updated_at = datetime.now(UTC)
session.flush = AsyncMock(side_effect=_fake_flush)
session.refresh = AsyncMock()
return session
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
from src.services.dependencies import get_current_user, require_role
user = MagicMock()
user.organisation_id = uuid.uuid4()
user.role = "owner"
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
app.dependency_overrides[get_current_user] = lambda: user
def _override_require_role(*_roles):
return lambda: user
app.dependency_overrides[require_role] = _override_require_role
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
class TestRecordConsent:
@pytest.mark.asyncio
async def test_record_consent_success(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/consent/",
json={
"site_id": str(uuid.uuid4()),
"visitor_id": "visitor-123",
"action": "accept_all",
"categories_accepted": ["necessary", "analytics"],
"categories_rejected": [],
},
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_record_consent_reject_all(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/consent/",
json={
"site_id": str(uuid.uuid4()),
"visitor_id": "visitor-456",
"action": "reject_all",
"categories_accepted": ["necessary"],
"categories_rejected": ["analytics", "marketing"],
},
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_record_consent_custom(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/consent/",
json={
"site_id": str(uuid.uuid4()),
"visitor_id": "visitor-789",
"action": "custom",
"categories_accepted": ["necessary", "analytics"],
"categories_rejected": ["marketing"],
},
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_record_consent_invalid_action(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/consent/",
json={
"site_id": str(uuid.uuid4()),
"visitor_id": "visitor-000",
"action": "invalid_action",
"categories_accepted": ["necessary"],
"categories_rejected": [],
},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_record_consent_empty_visitor_id(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/consent/",
json={
"site_id": str(uuid.uuid4()),
"visitor_id": "",
"action": "accept_all",
"categories_accepted": ["necessary"],
"categories_rejected": [],
},
)
assert resp.status_code == 422
@pytest.mark.asyncio
async def test_record_consent_with_optional_fields(self, mock_app):
db = _mock_db()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/consent/",
json={
"site_id": str(uuid.uuid4()),
"visitor_id": "visitor-opt",
"action": "accept_all",
"categories_accepted": ["necessary"],
"categories_rejected": [],
"tc_string": "CPXxRAAAA",
"gcm_state": {"analytics_storage": "granted"},
"page_url": "https://example.com",
"country_code": "GB",
"region_code": "ENG",
},
)
assert resp.status_code == 201
class TestGetConsent:
@pytest.mark.asyncio
async def test_get_consent_found(self, mock_app):
record = _mock_consent_record()
db = _mock_db(scalar_one_or_none=record)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/consent/{record.id}")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_consent_not_found(self, mock_app):
db = _mock_db(scalar_one_or_none=None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/consent/{uuid.uuid4()}")
assert resp.status_code == 404
class TestVerifyConsent:
@pytest.mark.asyncio
async def test_verify_consent_valid(self, mock_app):
record = _mock_consent_record()
db = _mock_db(scalar_one_or_none=record)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/consent/verify/{record.id}")
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is True
@pytest.mark.asyncio
async def test_verify_consent_not_found(self, mock_app):
db = _mock_db(scalar_one_or_none=None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/consent/verify/{uuid.uuid4()}")
assert resp.status_code == 404

View File

@@ -0,0 +1,437 @@
"""Unit tests for cookie, category, and allow-list routers — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token
ORG_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
def _auth_headers(role="owner"):
token = create_access_token(
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
)
return {"Authorization": f"Bearer {token}"}
def _mock_category(**overrides):
cat = MagicMock(spec=[])
cat.id = overrides.get("id", uuid.uuid4())
cat.name = overrides.get("name", "Analytics")
cat.slug = overrides.get("slug", "analytics")
cat.description = overrides.get("description", "Analytics cookies")
cat.is_essential = overrides.get("is_essential", False)
cat.display_order = overrides.get("display_order", 3)
cat.tcf_purpose_ids = overrides.get("tcf_purpose_ids", [])
cat.gcm_consent_types = overrides.get("gcm_consent_types", ["analytics_storage"])
cat.created_at = datetime.now(UTC)
cat.updated_at = datetime.now(UTC)
return cat
def _mock_cookie(**overrides):
cookie = MagicMock(spec=[])
cookie.id = overrides.get("id", uuid.uuid4())
cookie.site_id = overrides.get("site_id", uuid.uuid4())
cookie.name = overrides.get("name", "_ga")
cookie.domain = overrides.get("domain", ".google.com")
cookie.path = overrides.get("path", "/")
cookie.category_id = overrides.get("category_id")
cookie.storage_type = overrides.get("storage_type", "cookie")
cookie.review_status = overrides.get("review_status", "pending")
cookie.description = overrides.get("description")
cookie.vendor = overrides.get("vendor")
cookie.max_age_seconds = overrides.get("max_age_seconds")
cookie.is_http_only = overrides.get("is_http_only")
cookie.is_secure = overrides.get("is_secure")
cookie.same_site = overrides.get("same_site")
cookie.first_seen_at = overrides.get("first_seen_at", datetime.now(UTC).isoformat())
cookie.last_seen_at = overrides.get("last_seen_at", datetime.now(UTC).isoformat())
cookie.created_at = datetime.now(UTC)
cookie.updated_at = datetime.now(UTC)
return cookie
def _mock_site():
site = MagicMock(spec=[])
site.id = uuid.uuid4()
site.organisation_id = ORG_ID
site.domain = "test.com"
site.deleted_at = None
return site
def _mock_allow_list_entry(**overrides):
entry = MagicMock(spec=[])
entry.id = overrides.get("id", uuid.uuid4())
entry.site_id = overrides.get("site_id", uuid.uuid4())
entry.name_pattern = overrides.get("name_pattern", "_ga*")
entry.domain_pattern = overrides.get("domain_pattern", ".google.com")
entry.category_id = overrides.get("category_id", uuid.uuid4())
entry.description = overrides.get("description")
entry.created_at = datetime.now(UTC)
entry.updated_at = datetime.now(UTC)
return entry
def _mock_db_sequence(*results):
session = AsyncMock()
mock_results = []
for r in results:
result = MagicMock()
if isinstance(r, list):
result.scalar_one_or_none.return_value = r[0] if r else None
scalars_obj = MagicMock()
scalars_obj.all.return_value = r
result.scalars.return_value = scalars_obj
elif isinstance(r, dict) and "scalar" in r:
result.scalar.return_value = r["scalar"]
elif isinstance(r, dict) and "all" in r:
result.all.return_value = r["all"]
else:
result.scalar_one_or_none.return_value = r
mock_results.append(result)
session.execute = AsyncMock(side_effect=mock_results)
_added = []
def _fake_add(obj):
_added.append(obj)
session.add = MagicMock(side_effect=_fake_add)
async def _fake_flush():
for obj in _added:
if getattr(obj, "id", None) is None:
obj.id = uuid.uuid4()
if hasattr(obj, "review_status") and getattr(obj, "review_status", None) is None:
obj.review_status = "pending"
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
obj.created_at = datetime.now(UTC)
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
obj.updated_at = datetime.now(UTC)
session.flush = AsyncMock(side_effect=_fake_flush)
session.refresh = AsyncMock()
session.delete = AsyncMock()
return session
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
class TestCookieCategories:
@pytest.mark.asyncio
async def test_list_categories(self, mock_app):
cats = [_mock_category(slug="necessary"), _mock_category(slug="analytics")]
db = _mock_db_sequence(cats)
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/cookies/categories")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_category(self, mock_app):
cat = _mock_category()
db = _mock_db_sequence(cat)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/cookies/categories/{cat.id}")
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_category_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/cookies/categories/{uuid.uuid4()}")
assert resp.status_code == 404
class TestCookieCRUD:
@pytest.mark.asyncio
async def test_list_cookies(self, mock_app):
site = _mock_site()
cookies = [_mock_cookie(site_id=site.id)]
db = _mock_db_sequence(site, cookies)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/cookies/sites/{site.id}", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_list_cookies_empty(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, [])
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/cookies/sites/{site.id}", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_create_cookie(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site) # site found
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/cookies/sites/{site.id}",
json={"name": "_ga", "domain": ".google.com"},
headers=_auth_headers(),
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_create_cookie_with_invalid_category(self, mock_app):
site = _mock_site()
cat_id = uuid.uuid4()
db = _mock_db_sequence(site, None) # site found, category not found
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/cookies/sites/{site.id}",
json={"name": "_ga", "domain": ".google.com", "category_id": str(cat_id)},
headers=_auth_headers(),
)
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_get_cookie(self, mock_app):
site = _mock_site()
cookie = _mock_cookie(site_id=site.id)
db = _mock_db_sequence(site, cookie)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_cookie_not_found(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/cookies/sites/{site.id}/{uuid.uuid4()}",
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_update_cookie(self, mock_app):
site = _mock_site()
cookie = _mock_cookie(site_id=site.id)
db = _mock_db_sequence(site, cookie)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
json={"review_status": "approved"},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_update_cookie_not_found(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/cookies/sites/{site.id}/{uuid.uuid4()}",
json={"review_status": "approved"},
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_update_cookie_invalid_category(self, mock_app):
site = _mock_site()
cookie = _mock_cookie(site_id=site.id)
cat_id = uuid.uuid4()
# site found, cookie found, category validation fails
db = _mock_db_sequence(site, cookie, None)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
json={"category_id": str(cat_id)},
headers=_auth_headers(),
)
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_delete_cookie(self, mock_app):
site = _mock_site()
cookie = _mock_cookie(site_id=site.id)
db = _mock_db_sequence(site, cookie)
async with await _client(mock_app, db) as client:
resp = await client.delete(
f"/api/v1/cookies/sites/{site.id}/{cookie.id}",
headers=_auth_headers(),
)
assert resp.status_code == 204
@pytest.mark.asyncio
async def test_delete_cookie_not_found(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.delete(
f"/api/v1/cookies/sites/{site.id}/{uuid.uuid4()}",
headers=_auth_headers(),
)
assert resp.status_code == 404
class TestCookieSummary:
@pytest.mark.asyncio
async def test_cookie_summary(self, mock_app):
site = _mock_site()
# summary makes 4 queries: _get_org_site, status count, category count, uncategorised
db = _mock_db_sequence(
site,
{"all": [("pending", 5), ("approved", 3)]},
{"all": [("analytics", 4), ("marketing", 2)]},
{"scalar": 2},
)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/cookies/sites/{site.id}/summary",
headers=_auth_headers(),
)
assert resp.status_code == 200
data = resp.json()
assert "total" in data
assert "by_status" in data
assert "uncategorised" in data
class TestAllowList:
@pytest.mark.asyncio
@pytest.mark.asyncio
async def test_list_allow_list(self, mock_app):
site = _mock_site()
entries = [_mock_allow_list_entry(site_id=site.id)]
db = _mock_db_sequence(site, entries)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/cookies/sites/{site.id}/allow-list",
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_create_allow_list_entry(self, mock_app):
site = _mock_site()
cat = _mock_category()
db = _mock_db_sequence(site, cat) # site found, category valid
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/cookies/sites/{site.id}/allow-list",
json={
"name_pattern": "_ga*",
"domain_pattern": ".google.com",
"category_id": str(cat.id),
},
headers=_auth_headers(),
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_create_allow_list_invalid_category(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None) # site found, category not found
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/cookies/sites/{site.id}/allow-list",
json={
"name_pattern": "_ga*",
"domain_pattern": ".google.com",
"category_id": str(uuid.uuid4()),
},
headers=_auth_headers(),
)
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_update_allow_list_entry(self, mock_app):
site = _mock_site()
entry = _mock_allow_list_entry(site_id=site.id)
db = _mock_db_sequence(site, entry)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/cookies/sites/{site.id}/allow-list/{entry.id}",
json={"name_pattern": "_fbp*"},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_update_allow_list_not_found(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/cookies/sites/{site.id}/allow-list/{uuid.uuid4()}",
json={"name_pattern": "_fbp*"},
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_update_allow_list_invalid_category(self, mock_app):
site = _mock_site()
entry = _mock_allow_list_entry(site_id=site.id)
db = _mock_db_sequence(site, entry, None) # site, entry found, category invalid
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/cookies/sites/{site.id}/allow-list/{entry.id}",
json={"category_id": str(uuid.uuid4())},
headers=_auth_headers(),
)
assert resp.status_code == 400
@pytest.mark.asyncio
async def test_delete_allow_list_entry(self, mock_app):
site = _mock_site()
entry = _mock_allow_list_entry(site_id=site.id)
db = _mock_db_sequence(site, entry)
async with await _client(mock_app, db) as client:
resp = await client.delete(
f"/api/v1/cookies/sites/{site.id}/allow-list/{entry.id}",
headers=_auth_headers(),
)
assert resp.status_code == 204
@pytest.mark.asyncio
async def test_delete_allow_list_not_found(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.delete(
f"/api/v1/cookies/sites/{site.id}/allow-list/{uuid.uuid4()}",
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_site_not_found(self, mock_app):
db = _mock_db_sequence(None) # site not found
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/cookies/sites/{uuid.uuid4()}",
headers=_auth_headers(),
)
assert resp.status_code == 404

View File

@@ -0,0 +1,184 @@
"""Unit tests for org-config router — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token
ORG_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
def _auth_headers(role="owner"):
token = create_access_token(
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
)
return {"Authorization": f"Bearer {token}"}
def _mock_org_config(**overrides):
config = MagicMock()
config.id = overrides.get("id", uuid.uuid4())
config.organisation_id = overrides.get("organisation_id", ORG_ID)
config.blocking_mode = overrides.get("blocking_mode")
config.regional_modes = overrides.get("regional_modes")
config.tcf_enabled = overrides.get("tcf_enabled")
config.tcf_publisher_cc = overrides.get("tcf_publisher_cc")
config.gcm_enabled = overrides.get("gcm_enabled")
config.gcm_default = overrides.get("gcm_default")
config.banner_config = overrides.get("banner_config")
config.gpp_enabled = overrides.get("gpp_enabled")
config.gpp_supported_apis = overrides.get("gpp_supported_apis")
config.gpc_enabled = overrides.get("gpc_enabled")
config.gpc_jurisdictions = overrides.get("gpc_jurisdictions")
config.gpc_global_honour = overrides.get("gpc_global_honour")
config.shopify_privacy_enabled = overrides.get("shopify_privacy_enabled")
config.privacy_policy_url = overrides.get("privacy_policy_url")
config.terms_url = overrides.get("terms_url")
config.scan_schedule_cron = overrides.get("scan_schedule_cron")
config.scan_max_pages = overrides.get("scan_max_pages")
config.consent_expiry_days = overrides.get("consent_expiry_days")
config.consent_retention_days = overrides.get("consent_retention_days")
config.created_at = datetime.now(UTC)
config.updated_at = datetime.now(UTC)
return config
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
def _mock_db_sequence(*results):
"""Create a mock session returning different results on successive execute() calls."""
session = AsyncMock()
mock_results = []
for r in results:
result = MagicMock()
result.scalar_one_or_none.return_value = r
mock_results.append(result)
session.execute = AsyncMock(side_effect=mock_results)
_added = []
def _fake_add(obj):
_added.append(obj)
session.add = MagicMock(side_effect=_fake_add)
async def _fake_flush():
for obj in _added:
if getattr(obj, "id", None) is None:
obj.id = uuid.uuid4()
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
obj.created_at = datetime.now(UTC)
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
obj.updated_at = datetime.now(UTC)
session.flush = AsyncMock(side_effect=_fake_flush)
session.refresh = AsyncMock()
return session
class TestGetOrgConfig:
@pytest.mark.asyncio
async def test_get_existing_config(self, mock_app):
"""GET /org-config/ returns existing config."""
config = _mock_org_config(blocking_mode="opt_out", consent_expiry_days=180)
db = _mock_db_sequence(config)
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/org-config/", headers=_auth_headers())
assert resp.status_code == 200
data = resp.json()
assert data["blocking_mode"] == "opt_out"
assert data["consent_expiry_days"] == 180
@pytest.mark.asyncio
async def test_get_auto_creates_when_missing(self, mock_app):
"""GET /org-config/ auto-creates a blank config if none exists."""
db = _mock_db_sequence(None) # no existing config
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/org-config/", headers=_auth_headers())
assert resp.status_code == 200
data = resp.json()
# All optional fields should be None
assert data["blocking_mode"] is None
assert data["tcf_enabled"] is None
@pytest.mark.asyncio
async def test_get_requires_auth(self, mock_app):
"""GET /org-config/ returns 401 without token."""
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/org-config/")
assert resp.status_code == 401
class TestUpdateOrgConfig:
@pytest.mark.asyncio
async def test_update_existing_config(self, mock_app):
"""PUT /org-config/ updates existing config."""
config = _mock_org_config()
db = _mock_db_sequence(config)
async with await _client(mock_app, db) as client:
resp = await client.put(
"/api/v1/org-config/",
json={"blocking_mode": "opt_out", "consent_expiry_days": 90},
headers=_auth_headers(),
)
assert resp.status_code == 200
# Verify setattr was called on the mock
assert config.blocking_mode == "opt_out"
assert config.consent_expiry_days == 90
@pytest.mark.asyncio
async def test_update_creates_when_missing(self, mock_app):
"""PUT /org-config/ creates config if none exists."""
db = _mock_db_sequence(None) # no existing config
async with await _client(mock_app, db) as client:
resp = await client.put(
"/api/v1/org-config/",
json={"tcf_enabled": True},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_update_requires_admin(self, mock_app):
"""PUT /org-config/ returns 403 for viewers."""
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.put(
"/api/v1/org-config/",
json={"blocking_mode": "opt_in"},
headers=_auth_headers(role="viewer"),
)
assert resp.status_code == 403
@pytest.mark.asyncio
async def test_update_allows_editor_role_fails(self, mock_app):
"""PUT /org-config/ returns 403 for editors (only owner/admin can update)."""
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.put(
"/api/v1/org-config/",
json={"blocking_mode": "opt_in"},
headers=_auth_headers(role="editor"),
)
assert resp.status_code == 403

View File

@@ -0,0 +1,344 @@
"""Unit tests for organisation and user routers — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token, hash_password
ORG_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
def _auth_headers(role="owner"):
token = create_access_token(
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
)
return {"Authorization": f"Bearer {token}"}
def _mock_org(**overrides):
org = MagicMock(spec=[])
org.id = overrides.get("id", ORG_ID)
org.name = overrides.get("name", "Test Org")
org.slug = overrides.get("slug", "test-org")
org.contact_email = overrides.get("contact_email")
org.billing_plan = overrides.get("billing_plan", "free")
org.deleted_at = None
org.created_at = datetime.now(UTC)
org.updated_at = datetime.now(UTC)
return org
def _mock_user(**overrides):
user = MagicMock(spec=[])
user.id = overrides.get("id", uuid.uuid4())
user.organisation_id = overrides.get("organisation_id", ORG_ID)
user.email = overrides.get("email", "user@test.com")
user.password_hash = overrides.get("password_hash", hash_password("Pass123"))
user.full_name = overrides.get("full_name", "Test User")
user.role = overrides.get("role", "editor")
user.is_active = True
user.deleted_at = None
user.created_at = datetime.now(UTC)
user.updated_at = datetime.now(UTC)
return user
def _mock_db_sequence(*results):
session = AsyncMock()
mock_results = []
for r in results:
result = MagicMock()
if isinstance(r, list):
result.scalar_one_or_none.return_value = r[0] if r else None
scalars_obj = MagicMock()
scalars_obj.all.return_value = r
result.scalars.return_value = scalars_obj
else:
result.scalar_one_or_none.return_value = r
mock_results.append(result)
session.execute = AsyncMock(side_effect=mock_results)
_added = []
def _fake_add(obj):
_added.append(obj)
session.add = MagicMock(side_effect=_fake_add)
async def _fake_flush():
for obj in _added:
if getattr(obj, "id", None) is None:
obj.id = uuid.uuid4()
if hasattr(obj, "is_active") and getattr(obj, "is_active", None) is None:
obj.is_active = True
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
obj.created_at = datetime.now(UTC)
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
obj.updated_at = datetime.now(UTC)
session.flush = AsyncMock(side_effect=_fake_flush)
session.refresh = AsyncMock()
session.delete = AsyncMock()
return session
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
_BOOTSTRAP_TOKEN = "test-bootstrap-token-xyz"
_BOOTSTRAP_HEADERS = {"X-Admin-Bootstrap-Token": _BOOTSTRAP_TOKEN}
@pytest.fixture
def _bootstrap_enabled(monkeypatch):
"""Configure the bootstrap token so org creation is permitted."""
from src.config import settings as settings_mod
monkeypatch.setattr(
settings_mod.get_settings(),
"admin_bootstrap_token",
_BOOTSTRAP_TOKEN,
)
yield
class TestOrganisationRouter:
@pytest.mark.asyncio
async def test_create_org(self, mock_app, _bootstrap_enabled):
db = _mock_db_sequence(None) # no duplicate slug
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/organisations/",
json={"name": "New Org", "slug": "new-org"},
headers=_BOOTSTRAP_HEADERS,
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_create_org_duplicate_slug(self, mock_app, _bootstrap_enabled):
existing = _mock_org(slug="dup-slug")
db = _mock_db_sequence(existing)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/organisations/",
json={"name": "Another", "slug": "dup-slug"},
headers=_BOOTSTRAP_HEADERS,
)
assert resp.status_code == 409
@pytest.mark.asyncio
async def test_create_org_disabled_without_token(self, mock_app):
"""With no ``admin_bootstrap_token`` configured, creation is forbidden."""
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/organisations/",
json={"name": "X", "slug": "x"},
)
assert resp.status_code == 403
@pytest.mark.asyncio
async def test_create_org_wrong_token(self, mock_app, _bootstrap_enabled):
"""With an incorrect token, creation is unauthorised."""
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/organisations/",
json={"name": "X", "slug": "x"},
headers={"X-Admin-Bootstrap-Token": "wrong"},
)
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_get_my_org(self, mock_app):
org = _mock_org()
db = _mock_db_sequence(org)
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/organisations/me", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_my_org_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/organisations/me", headers=_auth_headers())
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_update_my_org(self, mock_app):
org = _mock_org()
db = _mock_db_sequence(org)
async with await _client(mock_app, db) as client:
resp = await client.patch(
"/api/v1/organisations/me",
json={"name": "Updated Name"},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_update_my_org_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.patch(
"/api/v1/organisations/me",
json={"name": "Updated"},
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_update_my_org_requires_admin(self, mock_app):
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.patch(
"/api/v1/organisations/me",
json={"name": "Updated"},
headers=_auth_headers(role="viewer"),
)
assert resp.status_code == 403
class TestUserRouter:
@pytest.mark.asyncio
async def test_create_user(self, mock_app):
db = _mock_db_sequence(None) # no duplicate email
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/users/",
json={
"email": "new@test.com",
"password": "SecurePass123",
"full_name": "New User",
"role": "editor",
},
headers=_auth_headers(),
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_create_user_duplicate_email(self, mock_app):
existing = _mock_user(email="dup@test.com")
db = _mock_db_sequence(existing)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/users/",
json={
"email": "dup@test.com",
"password": "SecurePass123",
"full_name": "Dup User",
"role": "viewer",
},
headers=_auth_headers(),
)
assert resp.status_code == 409
@pytest.mark.asyncio
async def test_list_users(self, mock_app):
users = [_mock_user(), _mock_user(email="two@test.com")]
db = _mock_db_sequence(users)
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/users/", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_user(self, mock_app):
user = _mock_user()
db = _mock_db_sequence(user)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/users/{user.id}", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_user_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/users/{uuid.uuid4()}", headers=_auth_headers())
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_update_user(self, mock_app):
user = _mock_user()
db = _mock_db_sequence(user)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/users/{user.id}",
json={"full_name": "Updated Name", "role": "admin"},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_update_user_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/users/{uuid.uuid4()}",
json={"full_name": "Nope"},
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_delete_user(self, mock_app):
user = _mock_user(id=uuid.uuid4())
db = _mock_db_sequence(user)
async with await _client(mock_app, db) as client:
resp = await client.delete(f"/api/v1/users/{user.id}", headers=_auth_headers())
assert resp.status_code == 204
@pytest.mark.asyncio
async def test_delete_self_rejected(self, mock_app):
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.delete(f"/api/v1/users/{USER_ID}", headers=_auth_headers())
assert resp.status_code == 400
assert "yourself" in resp.json()["detail"]
@pytest.mark.asyncio
async def test_delete_user_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.delete(f"/api/v1/users/{uuid.uuid4()}", headers=_auth_headers())
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_users_require_auth(self, mock_app):
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/users/")
assert resp.status_code in (401, 403)
@pytest.mark.asyncio
async def test_create_user_viewer_forbidden(self, mock_app):
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/users/",
json={
"email": "new@test.com",
"password": "SecurePass123",
"role": "viewer",
},
headers=_auth_headers(role="viewer"),
)
assert resp.status_code == 403

View File

@@ -0,0 +1,189 @@
"""Unit tests for site-groups router — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token
ORG_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
def _auth_headers():
token = create_access_token(
user_id=USER_ID, organisation_id=ORG_ID, role="owner", email="admin@test.com"
)
return {"Authorization": f"Bearer {token}"}
def _mock_group(**overrides):
group = MagicMock()
group.id = overrides.get("id", uuid.uuid4())
group.organisation_id = overrides.get("organisation_id", ORG_ID)
group.name = overrides.get("name", "Steve Madden")
group.description = overrides.get("description")
group.deleted_at = None
group.created_at = datetime.now(UTC)
group.updated_at = datetime.now(UTC)
return group
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
def _mock_db_sequence(*results):
"""Create a mock session returning different results on successive execute() calls."""
session = AsyncMock()
mock_results = []
for r in results:
result = MagicMock()
if isinstance(r, list):
result.scalar_one_or_none.return_value = r[0] if r else None
result.scalar_one.return_value = len(r) if isinstance(r[0], int) else 0
scalars_obj = MagicMock()
scalars_obj.all.return_value = r
result.scalars.return_value = scalars_obj
result.all.return_value = r
elif isinstance(r, int):
result.scalar_one.return_value = r
result.scalar_one_or_none.return_value = r
elif isinstance(r, tuple):
# (group, site_count) rows for list endpoint
result.all.return_value = r
else:
result.scalar_one_or_none.return_value = r
result.scalar_one.return_value = r if r is not None else 0
mock_results.append(result)
session.execute = AsyncMock(side_effect=mock_results)
_added = []
def _fake_add(obj):
_added.append(obj)
session.add = MagicMock(side_effect=_fake_add)
async def _fake_flush():
for obj in _added:
if getattr(obj, "id", None) is None:
obj.id = uuid.uuid4()
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
obj.created_at = datetime.now(UTC)
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
obj.updated_at = datetime.now(UTC)
session.flush = AsyncMock(side_effect=_fake_flush)
session.refresh = AsyncMock()
return session
class TestSiteGroupCRUD:
@pytest.mark.asyncio
async def test_create_site_group(self, mock_app):
"""POST /site-groups/ creates a new group."""
db = _mock_db_sequence(None) # no duplicate
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/site-groups/",
json={"name": "Steve Madden"},
headers=_auth_headers(),
)
assert resp.status_code == 201
data = resp.json()
assert data["name"] == "Steve Madden"
assert data["site_count"] == 0
@pytest.mark.asyncio
async def test_create_site_group_conflict(self, mock_app):
"""POST /site-groups/ returns 409 when name exists."""
existing = _mock_group(name="Steve Madden")
db = _mock_db_sequence(existing)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/site-groups/",
json={"name": "Steve Madden"},
headers=_auth_headers(),
)
assert resp.status_code == 409
@pytest.mark.asyncio
async def test_list_site_groups(self, mock_app):
"""GET /site-groups/ returns groups with site counts."""
group = _mock_group(name="Steve Madden")
row = MagicMock()
row.SiteGroup = group
row.site_count = 3
rows = (row,)
db = _mock_db_sequence(rows)
async with await _client(mock_app, db) as client:
resp = await client.get(
"/api/v1/site-groups/",
headers=_auth_headers(),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["name"] == "Steve Madden"
assert data[0]["site_count"] == 3
@pytest.mark.asyncio
async def test_get_site_group(self, mock_app):
"""GET /site-groups/{id} returns a single group."""
group_id = uuid.uuid4()
group = _mock_group(id=group_id, description="SM brand")
db = _mock_db_sequence(group, 2) # group lookup, site count
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/site-groups/{group_id}",
headers=_auth_headers(),
)
assert resp.status_code == 200
data = resp.json()
assert data["description"] == "SM brand"
assert data["site_count"] == 2
@pytest.mark.asyncio
async def test_get_site_group_not_found(self, mock_app):
"""GET /site-groups/{id} returns 404 for unknown ID."""
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/site-groups/{uuid.uuid4()}",
headers=_auth_headers(),
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_delete_site_group(self, mock_app):
"""DELETE /site-groups/{id} soft-deletes and ungroups sites."""
group_id = uuid.uuid4()
group = _mock_group(id=group_id)
site = MagicMock()
site.site_group_id = group_id
db = _mock_db_sequence(group, [site]) # group lookup, sites in group
async with await _client(mock_app, db) as client:
resp = await client.delete(
f"/api/v1/site-groups/{group_id}",
headers=_auth_headers(),
)
assert resp.status_code == 204
assert site.site_group_id is None
assert group.deleted_at is not None

View File

@@ -0,0 +1,266 @@
"""Unit tests for sites router — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token
ORG_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
def _auth_headers():
token = create_access_token(
user_id=USER_ID, organisation_id=ORG_ID, role="owner", email="admin@test.com"
)
return {"Authorization": f"Bearer {token}"}
def _mock_site(**overrides):
site = MagicMock()
site.id = overrides.get("id", uuid.uuid4())
site.organisation_id = overrides.get("organisation_id", ORG_ID)
site.domain = overrides.get("domain", "example.com")
site.display_name = overrides.get("display_name", "Example Site")
site.is_active = overrides.get("is_active", True)
site.additional_domains = overrides.get("additional_domains")
site.site_group_id = overrides.get("site_group_id")
site.deleted_at = None
site.created_at = datetime.now(UTC)
site.updated_at = datetime.now(UTC)
# Alias for SiteResponse.name field
site.name = site.display_name
return site
def _mock_config(**overrides):
config = MagicMock(spec=[]) # spec=[] prevents auto-attr generation
config.id = overrides.get("id", uuid.uuid4())
config.site_id = overrides.get("site_id", uuid.uuid4())
config.blocking_mode = overrides.get("blocking_mode", "opt_in")
config.tcf_enabled = overrides.get("tcf_enabled", False)
config.tcf_publisher_cc = overrides.get("tcf_publisher_cc")
config.gpp_enabled = overrides.get("gpp_enabled", True)
config.gpp_supported_apis = overrides.get("gpp_supported_apis", ["usnat"])
config.gpc_enabled = overrides.get("gpc_enabled", True)
default_jurisdictions = ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"]
config.gpc_jurisdictions = overrides.get("gpc_jurisdictions", default_jurisdictions)
config.gpc_global_honour = overrides.get("gpc_global_honour", False)
config.gcm_enabled = overrides.get("gcm_enabled", True)
config.gcm_default = overrides.get("gcm_default")
config.banner_config = overrides.get("banner_config", {})
config.regional_modes = overrides.get("regional_modes")
config.privacy_policy_url = overrides.get("privacy_policy_url")
config.scan_schedule_cron = overrides.get("scan_schedule_cron")
config.scan_max_pages = overrides.get("scan_max_pages", 50)
config.consent_expiry_days = overrides.get("consent_expiry_days", 365)
config.created_at = datetime.now(UTC)
config.updated_at = datetime.now(UTC)
return config
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
def _mock_db_sequence(*results):
"""Create a mock session that returns different results on successive execute() calls."""
session = AsyncMock()
mock_results = []
for r in results:
result = MagicMock()
if isinstance(r, list):
result.scalar_one_or_none.return_value = r[0] if r else None
scalars_obj = MagicMock()
scalars_obj.all.return_value = r
result.scalars.return_value = scalars_obj
else:
result.scalar_one_or_none.return_value = r
mock_results.append(result)
session.execute = AsyncMock(side_effect=mock_results)
_added = []
def _fake_add(obj):
_added.append(obj)
session.add = MagicMock(side_effect=_fake_add)
async def _fake_flush():
for obj in _added:
if getattr(obj, "id", None) is None:
obj.id = uuid.uuid4()
if hasattr(obj, "is_active") and getattr(obj, "is_active", None) is None:
obj.is_active = True
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
obj.created_at = datetime.now(UTC)
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
obj.updated_at = datetime.now(UTC)
session.flush = AsyncMock(side_effect=_fake_flush)
session.refresh = AsyncMock()
return session
class TestSiteCRUD:
@pytest.mark.asyncio
async def test_create_site_success(self, mock_app):
# First execute: check existing (None), second: after flush
db = _mock_db_sequence(None) # no duplicate
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/sites/",
json={"domain": "new-site.com", "display_name": "New Site"},
headers=_auth_headers(),
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_create_site_duplicate(self, mock_app):
existing_site = _mock_site(domain="dup.com")
db = _mock_db_sequence(existing_site)
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/sites/",
json={"domain": "dup.com", "display_name": "Dup Site"},
headers=_auth_headers(),
)
assert resp.status_code == 409
@pytest.mark.asyncio
async def test_list_sites(self, mock_app):
sites = [_mock_site(), _mock_site(domain="two.com")]
db = _mock_db_sequence(sites)
async with await _client(mock_app, db) as client:
resp = await client.get("/api/v1/sites/", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_site_success(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/sites/{site.id}", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_site_not_found(self, mock_app):
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/sites/{uuid.uuid4()}", headers=_auth_headers())
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_update_site(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/sites/{site.id}",
json={"display_name": "Updated"},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_delete_site(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site)
async with await _client(mock_app, db) as client:
resp = await client.delete(f"/api/v1/sites/{site.id}", headers=_auth_headers())
assert resp.status_code == 204
@pytest.mark.asyncio
async def test_create_site_requires_auth(self, mock_app):
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.post(
"/api/v1/sites/", json={"domain": "noauth.com", "display_name": "No Auth"}
)
assert resp.status_code in (401, 403)
class TestSiteConfig:
@pytest.mark.asyncio
async def test_get_config_success(self, mock_app):
site = _mock_site()
config = _mock_config(site_id=site.id)
db = _mock_db_sequence(site, config)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/sites/{site.id}/config", headers=_auth_headers())
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_get_config_not_found(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/sites/{site.id}/config", headers=_auth_headers())
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_put_config_create(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None) # site found, no existing config
async with await _client(mock_app, db) as client:
resp = await client.put(
f"/api/v1/sites/{site.id}/config",
json={"blocking_mode": "opt_in"},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_put_config_replace(self, mock_app):
site = _mock_site()
config = _mock_config(site_id=site.id)
db = _mock_db_sequence(site, config)
async with await _client(mock_app, db) as client:
resp = await client.put(
f"/api/v1/sites/{site.id}/config",
json={"blocking_mode": "opt_out"},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_patch_config_success(self, mock_app):
site = _mock_site()
config = _mock_config(site_id=site.id)
db = _mock_db_sequence(site, config)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/sites/{site.id}/config",
json={"gcm_enabled": False, "consent_expiry_days": 180},
headers=_auth_headers(),
)
assert resp.status_code == 200
@pytest.mark.asyncio
async def test_patch_config_not_found(self, mock_app):
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.patch(
f"/api/v1/sites/{site.id}/config",
json={"gcm_enabled": False},
headers=_auth_headers(),
)
assert resp.status_code == 404

View File

@@ -0,0 +1,277 @@
"""Unit tests for translations router — mocked database."""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.main import create_app
from src.services.auth import create_access_token
ORG_ID = uuid.uuid4()
USER_ID = uuid.uuid4()
SITE_ID = uuid.uuid4()
def _auth_headers(role="owner"):
token = create_access_token(
user_id=USER_ID, organisation_id=ORG_ID, role=role, email="admin@test.com"
)
return {"Authorization": f"Bearer {token}"}
def _mock_site(**overrides):
site = MagicMock()
site.id = overrides.get("id", SITE_ID)
site.organisation_id = overrides.get("organisation_id", ORG_ID)
site.domain = "example.com"
site.display_name = "Example"
site.is_active = True
site.deleted_at = None
site.additional_domains = None
site.site_group_id = None
site.created_at = datetime.now(UTC)
site.updated_at = datetime.now(UTC)
return site
def _mock_translation(**overrides):
t = MagicMock()
t.id = overrides.get("id", uuid.uuid4())
t.site_id = overrides.get("site_id", SITE_ID)
t.locale = overrides.get("locale", "fr")
t.strings = overrides.get(
"strings",
{"title": "Nous utilisons des cookies", "acceptAll": "Tout accepter"},
)
t.created_at = datetime.now(UTC)
t.updated_at = datetime.now(UTC)
return t
@pytest.fixture
def mock_app():
return create_app()
async def _client(app, mock_session):
from src.db import get_db
async def _override():
yield mock_session
app.dependency_overrides[get_db] = _override
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
def _mock_db_sequence(*results):
"""Create a mock session returning different results on successive execute() calls."""
session = AsyncMock()
mock_results = []
for r in results:
result = MagicMock()
if isinstance(r, list):
scalars_obj = MagicMock()
scalars_obj.all.return_value = r
result.scalars.return_value = scalars_obj
result.scalar_one_or_none.return_value = r[0] if r else None
else:
result.scalar_one_or_none.return_value = r
mock_results.append(result)
session.execute = AsyncMock(side_effect=mock_results)
_added = []
def _fake_add(obj):
_added.append(obj)
session.add = MagicMock(side_effect=_fake_add)
async def _fake_flush():
for obj in _added:
if getattr(obj, "id", None) is None:
obj.id = uuid.uuid4()
if hasattr(obj, "created_at") and getattr(obj, "created_at", None) is None:
obj.created_at = datetime.now(UTC)
if hasattr(obj, "updated_at") and getattr(obj, "updated_at", None) is None:
obj.updated_at = datetime.now(UTC)
session.flush = AsyncMock(side_effect=_fake_flush)
session.refresh = AsyncMock()
session.delete = AsyncMock()
return session
class TestListTranslations:
@pytest.mark.asyncio
async def test_list_translations(self, mock_app):
"""GET /sites/{id}/translations/ returns all translations."""
site = _mock_site()
fr = _mock_translation(locale="fr")
de = _mock_translation(locale="de")
db = _mock_db_sequence(site, [fr, de])
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/sites/{SITE_ID}/translations/",
headers=_auth_headers(),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
@pytest.mark.asyncio
async def test_list_translations_empty(self, mock_app):
"""GET /sites/{id}/translations/ returns empty list when no translations."""
site = _mock_site()
db = _mock_db_sequence(site, [])
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/sites/{SITE_ID}/translations/",
headers=_auth_headers(),
)
assert resp.status_code == 200
assert resp.json() == []
class TestGetTranslation:
@pytest.mark.asyncio
async def test_get_translation(self, mock_app):
"""GET /sites/{id}/translations/fr returns the French translation."""
site = _mock_site()
fr = _mock_translation(locale="fr")
db = _mock_db_sequence(site, fr)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/sites/{SITE_ID}/translations/fr",
headers=_auth_headers(),
)
assert resp.status_code == 200
data = resp.json()
assert data["locale"] == "fr"
assert "title" in data["strings"]
@pytest.mark.asyncio
async def test_get_translation_not_found(self, mock_app):
"""GET /sites/{id}/translations/xx returns 404."""
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.get(
f"/api/v1/sites/{SITE_ID}/translations/xx",
headers=_auth_headers(),
)
assert resp.status_code == 404
class TestCreateTranslation:
@pytest.mark.asyncio
async def test_create_translation(self, mock_app):
"""POST /sites/{id}/translations/ creates a new translation."""
site = _mock_site()
db = _mock_db_sequence(site, None) # site lookup, duplicate check
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/sites/{SITE_ID}/translations/",
json={
"locale": "de",
"strings": {"title": "Wir verwenden Cookies"},
},
headers=_auth_headers(),
)
assert resp.status_code == 201
data = resp.json()
assert data["locale"] == "de"
@pytest.mark.asyncio
async def test_create_translation_conflict(self, mock_app):
"""POST returns 409 when locale already exists."""
site = _mock_site()
existing = _mock_translation(locale="fr")
db = _mock_db_sequence(site, existing)
async with await _client(mock_app, db) as client:
resp = await client.post(
f"/api/v1/sites/{SITE_ID}/translations/",
json={"locale": "fr", "strings": {"title": "test"}},
headers=_auth_headers(),
)
assert resp.status_code == 409
class TestUpdateTranslation:
@pytest.mark.asyncio
async def test_update_translation(self, mock_app):
"""PUT /sites/{id}/translations/fr updates the strings."""
site = _mock_site()
fr = _mock_translation(locale="fr")
db = _mock_db_sequence(site, fr)
async with await _client(mock_app, db) as client:
resp = await client.put(
f"/api/v1/sites/{SITE_ID}/translations/fr",
json={"strings": {"title": "Updated title"}},
headers=_auth_headers(),
)
assert resp.status_code == 200
assert fr.strings == {"title": "Updated title"}
@pytest.mark.asyncio
async def test_update_translation_not_found(self, mock_app):
"""PUT returns 404 when locale does not exist."""
site = _mock_site()
db = _mock_db_sequence(site, None)
async with await _client(mock_app, db) as client:
resp = await client.put(
f"/api/v1/sites/{SITE_ID}/translations/xx",
json={"strings": {"title": "test"}},
headers=_auth_headers(),
)
assert resp.status_code == 404
class TestDeleteTranslation:
@pytest.mark.asyncio
async def test_delete_translation(self, mock_app):
"""DELETE /sites/{id}/translations/fr removes the translation."""
site = _mock_site()
fr = _mock_translation(locale="fr")
db = _mock_db_sequence(site, fr)
async with await _client(mock_app, db) as client:
resp = await client.delete(
f"/api/v1/sites/{SITE_ID}/translations/fr",
headers=_auth_headers(),
)
assert resp.status_code == 204
@pytest.mark.asyncio
async def test_delete_requires_admin(self, mock_app):
"""DELETE returns 403 for editors."""
db = _mock_db_sequence()
async with await _client(mock_app, db) as client:
resp = await client.delete(
f"/api/v1/sites/{SITE_ID}/translations/fr",
headers=_auth_headers(role="editor"),
)
assert resp.status_code == 403
class TestPublicTranslation:
@pytest.mark.asyncio
async def test_get_public_translation(self, mock_app):
"""GET /translations/{site_id}/fr returns raw strings (no auth)."""
fr = _mock_translation(locale="fr")
db = _mock_db_sequence(fr)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/translations/{SITE_ID}/fr")
assert resp.status_code == 200
data = resp.json()
assert "title" in data
@pytest.mark.asyncio
async def test_get_public_translation_not_found(self, mock_app):
"""GET /translations/{site_id}/xx returns 404."""
db = _mock_db_sequence(None)
async with await _client(mock_app, db) as client:
resp = await client.get(f"/api/v1/translations/{SITE_ID}/xx")
assert resp.status_code == 404

View File

@@ -0,0 +1,593 @@
"""Tests for scan scheduling, diff engine, and scan endpoints — CMP-24.
Covers:
- Scanner schemas (new additions)
- Scan service (job lifecycle, diff engine, cookie sync)
- Scanner router (trigger, list, detail, diff endpoints)
- Integration tests against live database
"""
import uuid
from datetime import UTC, datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from src.schemas.scanner import (
CookieDiffItem,
DiffStatus,
ScanDiffResponse,
ScanJobDetailResponse,
ScanResultResponse,
TriggerScanRequest,
)
# ── Schema tests ─────────────────────────────────────────────────────
class TestSchemas:
"""Validate scanner schema additions."""
def test_scan_result_response(self):
r = ScanResultResponse(
id=uuid.uuid4(),
scan_job_id=uuid.uuid4(),
page_url="https://example.com",
cookie_name="_ga",
cookie_domain=".example.com",
storage_type="cookie",
found_at=datetime.now(UTC),
created_at=datetime.now(UTC),
)
assert r.cookie_name == "_ga"
def test_scan_job_detail_response(self):
r = ScanJobDetailResponse(
id=uuid.uuid4(),
site_id=uuid.uuid4(),
status="completed",
trigger="manual",
pages_scanned=5,
pages_total=10,
cookies_found=3,
error_message=None,
started_at=datetime.now(UTC),
completed_at=datetime.now(UTC),
created_at=datetime.now(UTC),
updated_at=datetime.now(UTC),
results=[],
)
assert r.status == "completed"
assert r.results == []
def test_trigger_scan_request(self):
req = TriggerScanRequest(site_id=uuid.uuid4(), max_pages=100)
assert req.max_pages == 100
def test_trigger_scan_request_defaults(self):
req = TriggerScanRequest(site_id=uuid.uuid4())
assert req.max_pages == 50
def test_trigger_scan_max_pages_validation(self):
with pytest.raises(ValueError):
TriggerScanRequest(site_id=uuid.uuid4(), max_pages=0)
with pytest.raises(ValueError):
TriggerScanRequest(site_id=uuid.uuid4(), max_pages=501)
def test_diff_status_values(self):
assert DiffStatus.NEW == "new"
assert DiffStatus.REMOVED == "removed"
assert DiffStatus.CHANGED == "changed"
def test_cookie_diff_item(self):
item = CookieDiffItem(
name="_ga",
domain=".example.com",
storage_type="cookie",
diff_status=DiffStatus.NEW,
details="First scan",
)
assert item.diff_status == "new"
def test_scan_diff_response(self):
resp = ScanDiffResponse(
current_scan_id=uuid.uuid4(),
previous_scan_id=uuid.uuid4(),
new_cookies=[
CookieDiffItem(
name="_ga",
domain=".example.com",
storage_type="cookie",
diff_status=DiffStatus.NEW,
),
],
total_new=1,
)
assert resp.total_new == 1
assert len(resp.new_cookies) == 1
def test_scan_diff_response_no_previous(self):
resp = ScanDiffResponse(
current_scan_id=uuid.uuid4(),
previous_scan_id=None,
)
assert resp.previous_scan_id is None
assert resp.total_new == 0
# ── Diff engine unit tests ───────────────────────────────────────────
class TestDiffEngine:
"""Test the scan diff engine with mocked data."""
def _make_scan_result(
self,
name: str = "_ga",
domain: str = ".example.com",
storage_type: str = "cookie",
script_source: str | None = None,
auto_category: str | None = None,
attributes: dict | None = None,
):
"""Create a mock ScanResult."""
mock = MagicMock()
mock.cookie_name = name
mock.cookie_domain = domain
mock.storage_type = storage_type
mock.script_source = script_source
mock.auto_category = auto_category
mock.attributes = attributes
return mock
def test_result_key(self):
from src.services.scanner import _result_key
mock = self._make_scan_result("_ga", ".example.com", "cookie")
assert _result_key(mock) == ("_ga", ".example.com", "cookie")
def test_result_key_different_storage(self):
from src.services.scanner import _result_key
mock = self._make_scan_result("key", "example.com", "local_storage")
assert _result_key(mock) == ("key", "example.com", "local_storage")
# ── Scan service unit tests ──────────────────────────────────────────
class TestScanService:
"""Test scan service functions with mocked DB."""
@pytest.mark.asyncio
async def test_create_scan_job(self):
from src.services.scanner import create_scan_job
db = AsyncMock()
db.add = MagicMock()
db.flush = AsyncMock()
site_id = uuid.uuid4()
job = await create_scan_job(db, site_id=site_id, trigger="manual", max_pages=10)
assert job.site_id == site_id
assert job.status == "pending"
assert job.trigger == "manual"
assert job.pages_total == 10
db.add.assert_called_once()
@pytest.mark.asyncio
async def test_start_scan_job(self):
from src.services.scanner import start_scan_job
db = AsyncMock()
db.flush = AsyncMock()
job = MagicMock()
job.status = "pending"
job.started_at = None
result = await start_scan_job(db, job)
assert result.status == "running"
assert result.started_at is not None
@pytest.mark.asyncio
async def test_complete_scan_job_success(self):
from src.services.scanner import complete_scan_job
db = AsyncMock()
db.flush = AsyncMock()
job = MagicMock()
result = await complete_scan_job(db, job, pages_scanned=5, cookies_found=10)
assert result.status == "completed"
assert result.pages_scanned == 5
assert result.cookies_found == 10
assert result.completed_at is not None
@pytest.mark.asyncio
async def test_complete_scan_job_failure(self):
from src.services.scanner import complete_scan_job
db = AsyncMock()
db.flush = AsyncMock()
job = MagicMock()
result = await complete_scan_job(db, job, error_message="Connection failed")
assert result.status == "failed"
assert result.error_message == "Connection failed"
@pytest.mark.asyncio
async def test_add_scan_result(self):
from src.services.scanner import add_scan_result
db = AsyncMock()
db.add = MagicMock()
db.flush = AsyncMock()
scan_job_id = uuid.uuid4()
result = await add_scan_result(
db,
scan_job_id=scan_job_id,
page_url="https://example.com",
cookie_name="_ga",
cookie_domain=".example.com",
storage_type="cookie",
auto_category="analytics",
)
assert result.scan_job_id == scan_job_id
assert result.cookie_name == "_ga"
assert result.auto_category == "analytics"
db.add.assert_called_once()
# ── Router unit tests (mocked DB) ───────────────────────────────────
def _mock_auth_user():
"""Create a mock authenticated user."""
from src.schemas.auth import CurrentUser
return CurrentUser(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
email="test@example.com",
role="owner",
)
async def _authed_client(app, db, user=None):
"""Create an authenticated test client with mocked DB."""
from src.db import get_db
from src.services.dependencies import get_current_user
if user is None:
user = _mock_auth_user()
async def _override_get_db():
yield db
app.dependency_overrides[get_db] = _override_get_db
app.dependency_overrides[get_current_user] = lambda: user
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
class TestTriggerScan:
"""Test POST /scanner/scans."""
@pytest.mark.asyncio
async def test_trigger_scan_success(self, app):
user = _mock_auth_user()
db = AsyncMock()
# Site exists and belongs to user's org
site_mock = MagicMock()
site_mock.organisation_id = user.organisation_id
site_id = uuid.uuid4()
job_id = uuid.uuid4()
now = datetime.now(UTC)
# Mock scan job returned by create_scan_job
mock_job = MagicMock()
mock_job.id = job_id
mock_job.site_id = site_id
mock_job.status = "pending"
mock_job.trigger = "manual"
mock_job.pages_scanned = 0
mock_job.pages_total = 25
mock_job.cookies_found = 0
mock_job.error_message = None
mock_job.started_at = None
mock_job.completed_at = None
mock_job.created_at = now
mock_job.updated_at = now
# First call: site lookup. Second call: running scan count.
call_count = 0
async def mock_execute(stmt):
nonlocal call_count
call_count += 1
result = MagicMock()
if call_count == 1:
# Site lookup
result.scalar_one_or_none.return_value = site_mock
elif call_count == 2:
# Active scan jobs query — none running
result.scalars.return_value.all.return_value = []
return result
db.execute = mock_execute
db.add = MagicMock()
db.flush = AsyncMock()
with (
patch(
"src.routers.scanner.create_scan_job",
new=AsyncMock(return_value=mock_job),
),
patch("src.tasks.scanner.run_scan", create=True),
):
async with await _authed_client(app, db, user) as client:
resp = await client.post(
"/api/v1/scanner/scans",
json={
"site_id": str(site_id),
"max_pages": 25,
},
)
assert resp.status_code == 201
@pytest.mark.asyncio
async def test_trigger_scan_site_not_found(self, app):
db = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = None
db.execute = AsyncMock(return_value=result)
async with await _authed_client(app, db) as client:
resp = await client.post(
"/api/v1/scanner/scans",
json={
"site_id": str(uuid.uuid4()),
"max_pages": 50,
},
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_trigger_scan_conflict(self, app):
user = _mock_auth_user()
db = AsyncMock()
# Build a non-stale active job so the router raises 409
active_job = MagicMock()
active_job.status = "running"
active_job.created_at = datetime.now(UTC)
active_job.started_at = datetime.now(UTC)
call_count = 0
async def mock_execute(stmt):
nonlocal call_count
call_count += 1
result = MagicMock()
if call_count == 1:
# Site lookup
site_mock = MagicMock()
site_mock.organisation_id = user.organisation_id
result.scalar_one_or_none.return_value = site_mock
elif call_count == 2:
# Active scan jobs query — return a non-stale job
result.scalars.return_value.all.return_value = [active_job]
return result
db.execute = mock_execute
async with await _authed_client(app, db, user) as client:
resp = await client.post(
"/api/v1/scanner/scans",
json={"site_id": str(uuid.uuid4())},
)
assert resp.status_code == 409
class TestListScans:
"""Test GET /scanner/scans/site/{site_id}."""
@pytest.mark.asyncio
async def test_list_scans_success(self, app):
user = _mock_auth_user()
db = AsyncMock()
call_count = 0
async def mock_execute(stmt):
nonlocal call_count
call_count += 1
result = MagicMock()
if call_count == 1:
# Site access check
site_mock = MagicMock()
site_mock.organisation_id = user.organisation_id
result.scalar_one_or_none.return_value = site_mock
else:
# Scan list
result.scalars.return_value.all.return_value = []
return result
db.execute = mock_execute
async with await _authed_client(app, db, user) as client:
resp = await client.get(f"/api/v1/scanner/scans/site/{uuid.uuid4()}")
assert resp.status_code == 200
assert resp.json() == []
class TestGetScan:
"""Test GET /scanner/scans/{scan_id}."""
@pytest.mark.asyncio
async def test_get_scan_not_found(self, app):
db = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = None
db.execute = AsyncMock(return_value=result)
async with await _authed_client(app, db) as client:
resp = await client.get(f"/api/v1/scanner/scans/{uuid.uuid4()}")
assert resp.status_code == 404
class TestGetScanDiff:
"""Test GET /scanner/scans/{scan_id}/diff."""
@pytest.mark.asyncio
async def test_diff_scan_not_found(self, app):
db = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = None
db.execute = AsyncMock(return_value=result)
async with await _authed_client(app, db) as client:
resp = await client.get(f"/api/v1/scanner/scans/{uuid.uuid4()}/diff")
assert resp.status_code == 404
# ── Integration tests ────────────────────────────────────────────────
try:
from tests.conftest import create_test_site, requires_db
except ImportError:
from conftest import create_test_site, requires_db
@requires_db
class TestScanIntegration:
"""Integration tests against a live database."""
async def test_trigger_scan(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-trigger")
resp = await db_client.post(
"/api/v1/scanner/scans",
json={"site_id": site_id, "max_pages": 10},
headers=auth_headers,
)
assert resp.status_code == 201
data = resp.json()
assert data["status"] == "pending"
assert data["trigger"] == "manual"
assert data["pages_total"] == 10
async def test_trigger_scan_conflict(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-conflict")
# First scan
resp1 = await db_client.post(
"/api/v1/scanner/scans",
json={"site_id": site_id},
headers=auth_headers,
)
assert resp1.status_code == 201
# Second scan — should conflict
resp2 = await db_client.post(
"/api/v1/scanner/scans",
json={"site_id": site_id},
headers=auth_headers,
)
assert resp2.status_code == 409
async def test_list_scans(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-list")
# Trigger a scan
await db_client.post(
"/api/v1/scanner/scans",
json={"site_id": site_id},
headers=auth_headers,
)
resp = await db_client.get(
f"/api/v1/scanner/scans/site/{site_id}",
headers=auth_headers,
)
assert resp.status_code == 200
scans = resp.json()
assert len(scans) >= 1
assert scans[0]["site_id"] == site_id
async def test_get_scan_detail(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-detail")
create_resp = await db_client.post(
"/api/v1/scanner/scans",
json={"site_id": site_id},
headers=auth_headers,
)
scan_id = create_resp.json()["id"]
resp = await db_client.get(
f"/api/v1/scanner/scans/{scan_id}",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["id"] == scan_id
assert "results" in data
async def test_get_scan_diff(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-diff")
create_resp = await db_client.post(
"/api/v1/scanner/scans",
json={"site_id": site_id},
headers=auth_headers,
)
scan_id = create_resp.json()["id"]
resp = await db_client.get(
f"/api/v1/scanner/scans/{scan_id}/diff",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["current_scan_id"] == scan_id
# No previous scan, so previous_scan_id should be null
assert data["previous_scan_id"] is None
async def test_scan_not_found(self, db_client, auth_headers):
resp = await db_client.get(
f"/api/v1/scanner/scans/{uuid.uuid4()}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_list_scans_pagination(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="scan-page")
resp = await db_client.get(
f"/api/v1/scanner/scans/site/{site_id}?limit=5&offset=0",
headers=auth_headers,
)
assert resp.status_code == 200
async def test_trigger_scan_requires_auth(self, db_client):
resp = await db_client.post(
"/api/v1/scanner/scans",
json={"site_id": str(uuid.uuid4())},
)
assert resp.status_code in (401, 403)
async def test_list_scans_requires_auth(self, db_client):
resp = await db_client.get(f"/api/v1/scanner/scans/site/{uuid.uuid4()}")
assert resp.status_code in (401, 403)

View File

@@ -0,0 +1,346 @@
"""Tests for scanner cookie report endpoint — CMP-23 (API side).
Covers:
- Schema validation
- Report endpoint (unit tests with mocked DB)
- Integration tests against live database
"""
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.schemas.scanner import (
CookieReportRequest,
CookieReportResponse,
ReportedCookie,
ScanStatus,
ScanTrigger,
)
# ── Schema tests ─────────────────────────────────────────────────────
class TestSchemas:
"""Validate scanner schemas."""
def test_scan_status_values(self):
assert ScanStatus.PENDING == "pending"
assert ScanStatus.COMPLETED == "completed"
def test_scan_trigger_values(self):
assert ScanTrigger.CLIENT_REPORT == "client_report"
def test_reported_cookie(self):
rc = ReportedCookie(
name="_ga",
domain=".example.com",
storage_type="cookie",
value_length=30,
)
assert rc.name == "_ga"
def test_reported_cookie_validation(self):
with pytest.raises(ValueError):
ReportedCookie(name="", domain=".example.com")
def test_cookie_report_request(self):
req = CookieReportRequest(
site_id=uuid.uuid4(),
page_url="https://example.com/page",
cookies=[
ReportedCookie(name="_ga", domain=".example.com"),
],
collected_at=datetime.now(),
)
assert len(req.cookies) == 1
def test_cookie_report_response(self):
resp = CookieReportResponse(
accepted=True,
cookies_received=5,
new_cookies=2,
)
assert resp.new_cookies == 2
# ── Router unit tests (mocked DB) ───────────────────────────────────
def _mock_db_with_site():
"""Create a mock DB that returns a site for validation."""
db = AsyncMock()
site_mock = MagicMock()
cookie_result = MagicMock()
cookie_result.scalar_one_or_none.return_value = None # no existing cookie
call_count = 0
async def mock_execute(stmt):
nonlocal call_count
call_count += 1
if call_count == 1:
# Site validation
result = MagicMock()
result.scalar_one_or_none.return_value = site_mock
return result
# Cookie existence checks
return cookie_result
db.execute = mock_execute
db.add = MagicMock()
db.flush = AsyncMock()
return db
async def _client(app, db):
from src.db import get_db
async def _override_get_db():
yield db
app.dependency_overrides[get_db] = _override_get_db
transport = ASGITransport(app=app)
return AsyncClient(transport=transport, base_url="http://test")
class TestReportEndpoint:
"""Test POST /scanner/report."""
@pytest.mark.asyncio
async def test_report_success(self, app):
db = _mock_db_with_site()
async with await _client(app, db) as client:
resp = await client.post(
"/api/v1/scanner/report",
json={
"site_id": str(uuid.uuid4()),
"page_url": "https://example.com",
"cookies": [
{
"name": "_ga",
"domain": ".example.com",
"storage_type": "cookie",
"value_length": 30,
},
],
"collected_at": datetime.now().isoformat(),
},
)
assert resp.status_code == 202
data = resp.json()
assert data["accepted"] is True
assert data["cookies_received"] == 1
@pytest.mark.asyncio
async def test_report_site_not_found(self, app):
db = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = None
db.execute.return_value = result
async with await _client(app, db) as client:
resp = await client.post(
"/api/v1/scanner/report",
json={
"site_id": str(uuid.uuid4()),
"page_url": "https://example.com",
"cookies": [],
"collected_at": datetime.now().isoformat(),
},
)
assert resp.status_code == 404
@pytest.mark.asyncio
async def test_report_empty_cookies(self, app):
db = AsyncMock()
site_result = MagicMock()
site_result.scalar_one_or_none.return_value = MagicMock()
db.execute.return_value = site_result
db.flush = AsyncMock()
async with await _client(app, db) as client:
resp = await client.post(
"/api/v1/scanner/report",
json={
"site_id": str(uuid.uuid4()),
"page_url": "https://example.com",
"cookies": [],
"collected_at": datetime.now().isoformat(),
},
)
assert resp.status_code == 202
assert resp.json()["cookies_received"] == 0
@pytest.mark.asyncio
async def test_report_multiple_storage_types(self, app):
db = _mock_db_with_site()
async with await _client(app, db) as client:
resp = await client.post(
"/api/v1/scanner/report",
json={
"site_id": str(uuid.uuid4()),
"page_url": "https://example.com",
"cookies": [
{
"name": "_ga",
"domain": ".example.com",
"storage_type": "cookie",
"value_length": 30,
},
{
"name": "analytics_id",
"domain": "example.com",
"storage_type": "local_storage",
"value_length": 10,
},
{
"name": "session_key",
"domain": "example.com",
"storage_type": "session_storage",
"value_length": 20,
},
],
"collected_at": datetime.now().isoformat(),
},
)
assert resp.status_code == 202
assert resp.json()["cookies_received"] == 3
# ── Integration tests ────────────────────────────────────────────────
try:
from tests.conftest import create_test_site, requires_db
except ImportError:
from conftest import create_test_site, requires_db
@requires_db
class TestScannerReportIntegration:
"""Integration tests against a live database."""
async def test_report_creates_new_cookies(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-new")
resp = await db_client.post(
"/api/v1/scanner/report",
json={
"site_id": site_id,
"page_url": "https://report-new.com/page",
"cookies": [
{
"name": "_ga",
"domain": ".report-new.com",
"storage_type": "cookie",
"value_length": 30,
},
{
"name": "analytics_id",
"domain": "report-new.com",
"storage_type": "local_storage",
"value_length": 10,
},
],
"collected_at": datetime.now().isoformat(),
},
)
assert resp.status_code == 202
data = resp.json()
assert data["cookies_received"] == 2
assert data["new_cookies"] == 2
# Verify cookies were created
cookies_resp = await db_client.get(
f"/api/v1/cookies/sites/{site_id}",
headers=auth_headers,
)
assert cookies_resp.status_code == 200
cookies = cookies_resp.json()
names = [c["name"] for c in cookies]
assert "_ga" in names
assert "analytics_id" in names
async def test_report_deduplicates_existing_cookies(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-dedup")
report_payload = {
"site_id": site_id,
"page_url": "https://report-dedup.com",
"cookies": [
{
"name": "_dedup_cookie",
"domain": ".report-dedup.com",
"storage_type": "cookie",
"value_length": 10,
},
],
"collected_at": datetime.now().isoformat(),
}
# First report — should create
resp1 = await db_client.post("/api/v1/scanner/report", json=report_payload)
assert resp1.status_code == 202
assert resp1.json()["new_cookies"] == 1
# Second report — should not create duplicate
resp2 = await db_client.post("/api/v1/scanner/report", json=report_payload)
assert resp2.status_code == 202
assert resp2.json()["new_cookies"] == 0
async def test_report_sets_review_status_pending(self, db_client, auth_headers):
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-status")
await db_client.post(
"/api/v1/scanner/report",
json={
"site_id": site_id,
"page_url": "https://report-status.com",
"cookies": [
{
"name": "_status_cookie",
"domain": ".report-status.com",
"storage_type": "cookie",
"value_length": 5,
},
],
"collected_at": datetime.now().isoformat(),
},
)
# Check the created cookie's review status
cookies_resp = await db_client.get(
f"/api/v1/cookies/sites/{site_id}",
headers=auth_headers,
)
cookies = cookies_resp.json()
status_cookie = next((c for c in cookies if c["name"] == "_status_cookie"), None)
assert status_cookie is not None
assert status_cookie["review_status"] == "pending"
async def test_report_no_auth_required(self, db_client, auth_headers):
"""Report endpoint should work without authentication."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="report-noauth")
# POST without auth headers
resp = await db_client.post(
"/api/v1/scanner/report",
json={
"site_id": site_id,
"page_url": "https://report-noauth.com",
"cookies": [],
"collected_at": datetime.now().isoformat(),
},
)
assert resp.status_code == 202
async def test_report_invalid_site(self, db_client):
resp = await db_client.post(
"/api/v1/scanner/report",
json={
"site_id": str(uuid.uuid4()),
"page_url": "https://unknown.com",
"cookies": [],
"collected_at": datetime.now().isoformat(),
},
)
assert resp.status_code == 404

View File

@@ -0,0 +1,37 @@
"""Tests for application settings parsing."""
from src.config.settings import Settings
class TestAllowedOrigins:
"""Tests for the allowed_origins_list property."""
def test_comma_separated_string(self) -> None:
"""Comma-separated string is parsed into a list."""
settings = Settings(allowed_origins="https://a.com,https://b.com")
assert settings.allowed_origins_list == ["https://a.com", "https://b.com"]
def test_comma_separated_with_spaces(self) -> None:
"""Whitespace around commas is stripped."""
settings = Settings(allowed_origins="https://a.com , https://b.com")
assert settings.allowed_origins_list == ["https://a.com", "https://b.com"]
def test_single_origin_string(self) -> None:
"""A single origin string (no comma) is a single-element list."""
settings = Settings(allowed_origins="https://a.com")
assert settings.allowed_origins_list == ["https://a.com"]
def test_empty_string(self) -> None:
"""An empty string results in an empty list."""
settings = Settings(allowed_origins="")
assert settings.allowed_origins_list == []
def test_trailing_comma_ignored(self) -> None:
"""Trailing commas don't produce empty entries."""
settings = Settings(allowed_origins="https://a.com,")
assert settings.allowed_origins_list == ["https://a.com"]
def test_default_value(self) -> None:
"""Default value is localhost:5173."""
settings = Settings()
assert settings.allowed_origins_list == ["http://localhost:5173"]

View File

@@ -0,0 +1,132 @@
"""Tests for site and site config CRUD endpoints and schemas."""
import uuid
import pytest
from pydantic import ValidationError
from src.schemas.site import (
BlockingMode,
SiteConfigCreate,
SiteConfigResponse,
SiteConfigUpdate,
SiteCreate,
SiteResponse,
SiteUpdate,
)
class TestSiteSchemas:
def test_create_valid(self):
site = SiteCreate(domain="example.com", display_name="Example Site")
assert site.domain == "example.com"
def test_create_empty_domain_rejected(self):
with pytest.raises(ValidationError):
SiteCreate(domain="", display_name="Test")
def test_update_partial(self):
update = SiteUpdate(display_name="New Name")
data = update.model_dump(exclude_unset=True)
assert data == {"display_name": "New Name"}
def test_response_from_attributes(self):
now = "2026-01-01T00:00:00Z"
resp = SiteResponse(
id=uuid.uuid4(),
organisation_id=uuid.uuid4(),
domain="example.com",
display_name="Example",
is_active=True,
additional_domains=None,
created_at=now,
updated_at=now,
)
assert resp.is_active
class TestSiteConfigSchemas:
def test_create_defaults(self):
config = SiteConfigCreate()
assert config.blocking_mode == BlockingMode.OPT_IN
assert config.gcm_enabled is True
assert config.tcf_enabled is False
assert config.scan_max_pages == 50
assert config.consent_expiry_days == 365
def test_create_with_regional_modes(self):
config = SiteConfigCreate(
regional_modes={"EU": "opt_in", "US-CA": "opt_out", "DEFAULT": "opt_in"}
)
assert config.regional_modes["EU"] == "opt_in"
def test_scan_max_pages_bounds(self):
with pytest.raises(ValidationError):
SiteConfigCreate(scan_max_pages=0)
with pytest.raises(ValidationError):
SiteConfigCreate(scan_max_pages=1001)
def test_consent_expiry_bounds(self):
with pytest.raises(ValidationError):
SiteConfigCreate(consent_expiry_days=0)
with pytest.raises(ValidationError):
SiteConfigCreate(consent_expiry_days=731)
def test_update_partial(self):
update = SiteConfigUpdate(blocking_mode=BlockingMode.OPT_OUT)
data = update.model_dump(exclude_unset=True)
assert data == {"blocking_mode": "opt_out"}
def test_response_from_attributes(self):
now = "2026-01-01T00:00:00Z"
resp = SiteConfigResponse(
id=uuid.uuid4(),
site_id=uuid.uuid4(),
blocking_mode="opt_in",
regional_modes=None,
tcf_enabled=False,
tcf_publisher_cc=None,
gcm_enabled=True,
gcm_default=None,
banner_config=None,
privacy_policy_url=None,
scan_schedule_cron=None,
scan_max_pages=50,
consent_expiry_days=365,
created_at=now,
updated_at=now,
)
assert resp.blocking_mode == "opt_in"
def test_display_mode_in_banner_config(self):
"""Display mode is stored inside banner_config, not as a top-level field."""
config = SiteConfigCreate(
banner_config={"displayMode": "overlay"},
)
assert config.banner_config["displayMode"] == "overlay"
class TestEnums:
def test_blocking_modes(self):
assert BlockingMode.OPT_IN == "opt_in"
assert BlockingMode.OPT_OUT == "opt_out"
assert BlockingMode.INFORMATIONAL == "informational"
@pytest.mark.asyncio
class TestSiteRoutesRegistered:
async def test_site_routes_exist(self, client):
response = await client.get("/openapi.json")
paths = response.json()["paths"]
assert "/api/v1/sites/" in paths
assert "/api/v1/sites/{site_id}" in paths
assert "/api/v1/sites/{site_id}/config" in paths
async def test_site_endpoints_require_auth(self, client):
response = await client.get("/api/v1/sites/")
assert response.status_code == 401
async def test_site_config_endpoints_require_auth(self, client):
site_id = uuid.uuid4()
response = await client.get(f"/api/v1/sites/{site_id}/config")
assert response.status_code == 401

View File

@@ -0,0 +1,104 @@
"""Tests for site group config endpoints."""
import uuid
import pytest
from tests.conftest import requires_db
class TestSiteGroupConfigRoutes:
"""Unit tests — no database required."""
def test_group_config_get_route_registered(self, app):
routes = [r.path for r in app.routes]
assert "/api/v1/site-groups/{group_id}/config" in routes
def test_group_config_put_route_registered(self, app):
routes = [r.path for r in app.routes]
assert "/api/v1/site-groups/{group_id}/config" in routes
@pytest.mark.asyncio
async def test_get_group_config_requires_auth(self, client):
group_id = uuid.uuid4()
resp = await client.get(f"/api/v1/site-groups/{group_id}/config")
assert resp.status_code == 401
@pytest.mark.asyncio
async def test_put_group_config_requires_auth(self, client):
group_id = uuid.uuid4()
resp = await client.put(
f"/api/v1/site-groups/{group_id}/config",
json={"blocking_mode": "opt_in"},
)
assert resp.status_code == 401
class TestSiteGroupConfigIntegration:
"""Integration tests — require a running PostgreSQL database."""
@requires_db
async def test_create_group_and_get_config(self, db_client, auth_headers):
# Create a group
resp = await db_client.post(
"/api/v1/site-groups/",
json={"name": f"test-group-{uuid.uuid4().hex[:8]}"},
headers=auth_headers,
)
assert resp.status_code == 201
group_id = resp.json()["id"]
# GET config (auto-creates empty row)
resp = await db_client.get(
f"/api/v1/site-groups/{group_id}/config",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["site_group_id"] == group_id
assert data["blocking_mode"] is None
assert data["consent_expiry_days"] is None
@requires_db
async def test_update_group_config(self, db_client, auth_headers):
# Create a group
resp = await db_client.post(
"/api/v1/site-groups/",
json={"name": f"cfg-group-{uuid.uuid4().hex[:8]}"},
headers=auth_headers,
)
group_id = resp.json()["id"]
# PUT config
resp = await db_client.put(
f"/api/v1/site-groups/{group_id}/config",
json={
"blocking_mode": "opt_out",
"consent_expiry_days": 90,
"tcf_enabled": True,
},
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["blocking_mode"] == "opt_out"
assert data["consent_expiry_days"] == 90
assert data["tcf_enabled"] is True
# GET confirms persistence
resp = await db_client.get(
f"/api/v1/site-groups/{group_id}/config",
headers=auth_headers,
)
data = resp.json()
assert data["blocking_mode"] == "opt_out"
assert data["consent_expiry_days"] == 90
@requires_db
async def test_group_config_not_found_for_other_org(self, db_client, auth_headers):
fake_group_id = str(uuid.uuid4())
resp = await db_client.get(
f"/api/v1/site-groups/{fake_group_id}/config",
headers=auth_headers,
)
assert resp.status_code == 404