feat: initial public release

ConsentOS — a privacy-first cookie consent management platform.

Self-hosted, source-available alternative to OneTrust, Cookiebot, and
CookieYes. Full standards coverage (IAB TCF v2.2, GPP v1, Google
Consent Mode v2, GPC, Shopify Customer Privacy API), multi-tenant
architecture with role-based access, configuration cascade
(system → org → group → site → region), dark-pattern detection in
the scanner, and a tamper-evident consent record audit trail.

This is the initial public release. Prior development history is
retained internally.

See README.md for the feature list, architecture overview, and
quick-start instructions. Licensed under the Elastic License 2.0 —
self-host freely; do not resell as a managed service.
This commit is contained in:
James Cottrill
2026-04-13 14:20:15 +00:00
commit fbf26453f2
341 changed files with 62807 additions and 0 deletions

12
apps/api/.dockerignore Normal file
View File

@@ -0,0 +1,12 @@
.venv/
__pycache__/
*.py[cod]
.pytest_cache/
.mypy_cache/
.coverage
htmlcov/
*.egg-info/
tests/
fly.toml
.env
.env.*

51
apps/api/Dockerfile Normal file
View File

@@ -0,0 +1,51 @@
# ── Build stage ──────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc libpq-dev \
&& rm -rf /var/lib/apt/lists/*
COPY pyproject.toml .
RUN pip install --no-cache-dir --prefix=/install .
# ── Runtime stage ────────────────────────────────────────────────────
FROM python:3.12-slim
RUN apt-get update && apt-get install -y --no-install-recommends \
libpq5 curl \
&& rm -rf /var/lib/apt/lists/*
# Non-root user for security
RUN groupadd -r cmp && useradd -r -g cmp -d /app -s /sbin/nologin cmp
WORKDIR /app
# Copy installed dependencies from builder
COPY --from=builder /install /usr/local
# Copy application code
COPY . .
RUN chown -R cmp:cmp /app
USER cmp
EXPOSE 8000
HEALTHCHECK --interval=30s --timeout=5s --start-period=30s --retries=3 \
CMD curl -f http://localhost:8000/health || exit 1
# Start the server. Database migrations and the initial-admin
# bootstrap are owned by a separate init container (see the OSS
# docker-compose in consentos-deployment) — the API assumes the
# schema is ready by the time it starts.
# Workers configurable via WEB_CONCURRENCY (default 4, use 1 for 256MB RAM)
CMD ["sh", "-c", "uvicorn src.main:app \
--host 0.0.0.0 \
--port ${PORT:-8000} \
--workers ${WEB_CONCURRENCY:-4} \
--access-log \
--proxy-headers \
--forwarded-allow-ips '*'"]

149
apps/api/alembic.ini Normal file
View File

@@ -0,0 +1,149 @@
# A generic, single database configuration.
[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic
# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s
# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .
# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =
# max length of characters to apply to the "slug" field
# truncate_slug_length = 40
# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false
# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false
# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions
# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
# "version_path_separator" key, which if absent then falls back to the legacy
# behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
# behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os
# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false
# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8
# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file. Note: env.py overrides this value with the DATABASE_URL environment
# variable when set, so the hardcoded credentials below are a local-development
# fallback only — never the production configuration.
sqlalchemy.url = postgresql://consentos:consentos@localhost:5432/consentos
[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples
# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME
# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME
# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARNING
handlers = console
qualname =
[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

1
apps/api/alembic/README Normal file
View File

@@ -0,0 +1 @@
Generic single-database configuration.

61
apps/api/alembic/env.py Normal file
View File

@@ -0,0 +1,61 @@
import os
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
from src.models import Base
# Alembic Config object
config = context.config
# Override sqlalchemy.url from environment if set
database_url = os.environ.get("DATABASE_URL")
if database_url:
# Alembic needs the synchronous driver
database_url = database_url.replace("postgresql+asyncpg://", "postgresql://")
config.set_main_option("sqlalchemy.url", database_url)
# Set up Python logging from the config file
if config.config_file_name is not None:
fileConfig(config.config_file_name)
target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode."""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
connectable = engine_from_config(
config.get_section(config.config_ini_section, {}),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(
connection=connection,
target_metadata=target_metadata,
)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
"""Upgrade schema."""
${upgrades if upgrades else "pass"}
def downgrade() -> None:
"""Downgrade schema."""
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,442 @@
"""initial schema
Revision ID: 0001
Revises:
Create Date: 2026-04-13
Creates the full core schema plus seeds the default cookie categories.
"""
import uuid
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = '0001'
down_revision: Union[str, Sequence[str], None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
"""Upgrade schema."""
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('cookie_categories',
sa.Column('name', sa.String(length=50), nullable=False),
sa.Column('slug', sa.String(length=50), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('is_essential', sa.Boolean(), nullable=False),
sa.Column('display_order', sa.Integer(), server_default='0', nullable=False),
sa.Column('tcf_purpose_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gcm_consent_types', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name'),
sa.UniqueConstraint('slug')
)
op.create_table('organisations',
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('slug', sa.String(length=100), nullable=False),
sa.Column('contact_email', sa.String(length=255), nullable=True),
sa.Column('billing_plan', sa.String(length=50), server_default='free', nullable=False),
sa.Column('notes', sa.Text(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_organisations_slug'), 'organisations', ['slug'], unique=True)
op.create_table('known_cookies',
sa.Column('name_pattern', sa.String(length=255), nullable=False),
sa.Column('domain_pattern', sa.String(length=255), nullable=False),
sa.Column('category_id', sa.UUID(), nullable=False),
sa.Column('vendor', sa.String(length=255), nullable=True),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('is_regex', sa.Boolean(), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['category_id'], ['cookie_categories.id'], ondelete='RESTRICT'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('name_pattern', 'domain_pattern', name='uq_known_cookies_name_domain')
)
op.create_index(op.f('ix_known_cookies_name_pattern'), 'known_cookies', ['name_pattern'], unique=False)
op.create_table('org_configs',
sa.Column('organisation_id', sa.UUID(), nullable=False),
sa.Column('blocking_mode', sa.String(length=20), nullable=True),
sa.Column('regional_modes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('tcf_enabled', sa.Boolean(), nullable=True),
sa.Column('tcf_publisher_cc', sa.String(length=2), nullable=True),
sa.Column('gpp_enabled', sa.Boolean(), nullable=True),
sa.Column('gpp_supported_apis', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gpc_enabled', sa.Boolean(), nullable=True),
sa.Column('gpc_jurisdictions', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gpc_global_honour', sa.Boolean(), nullable=True),
sa.Column('gcm_enabled', sa.Boolean(), nullable=True),
sa.Column('gcm_default', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('shopify_privacy_enabled', sa.Boolean(), nullable=True),
sa.Column('banner_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('privacy_policy_url', sa.Text(), nullable=True),
sa.Column('terms_url', sa.Text(), nullable=True),
sa.Column('scan_schedule_cron', sa.String(length=100), nullable=True),
sa.Column('scan_max_pages', sa.Integer(), nullable=True),
sa.Column('consent_expiry_days', sa.Integer(), nullable=True),
sa.Column('consent_retention_days', sa.Integer(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('organisation_id')
)
op.create_table('site_groups',
sa.Column('organisation_id', sa.UUID(), nullable=False),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('organisation_id', 'name', name='uq_site_groups_org_name')
)
op.create_index(op.f('ix_site_groups_organisation_id'), 'site_groups', ['organisation_id'], unique=False)
op.create_table('users',
sa.Column('organisation_id', sa.UUID(), nullable=False),
sa.Column('email', sa.String(length=255), nullable=False),
sa.Column('password_hash', sa.String(length=255), nullable=False),
sa.Column('full_name', sa.String(length=255), nullable=False),
sa.Column('role', sa.String(length=20), server_default='viewer', nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
op.create_index(op.f('ix_users_organisation_id'), 'users', ['organisation_id'], unique=False)
op.create_table('site_group_configs',
sa.Column('site_group_id', sa.UUID(), nullable=False),
sa.Column('blocking_mode', sa.String(length=20), nullable=True),
sa.Column('regional_modes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('tcf_enabled', sa.Boolean(), nullable=True),
sa.Column('tcf_publisher_cc', sa.String(length=2), nullable=True),
sa.Column('gpp_enabled', sa.Boolean(), nullable=True),
sa.Column('gpp_supported_apis', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gpc_enabled', sa.Boolean(), nullable=True),
sa.Column('gpc_jurisdictions', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gpc_global_honour', sa.Boolean(), nullable=True),
sa.Column('gcm_enabled', sa.Boolean(), nullable=True),
sa.Column('gcm_default', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('shopify_privacy_enabled', sa.Boolean(), nullable=True),
sa.Column('banner_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('privacy_policy_url', sa.Text(), nullable=True),
sa.Column('terms_url', sa.Text(), nullable=True),
sa.Column('scan_schedule_cron', sa.String(length=100), nullable=True),
sa.Column('scan_max_pages', sa.Integer(), nullable=True),
sa.Column('consent_expiry_days', sa.Integer(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['site_group_id'], ['site_groups.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('site_group_id')
)
op.create_table('sites',
sa.Column('organisation_id', sa.UUID(), nullable=False),
sa.Column('domain', sa.String(length=255), nullable=False),
sa.Column('display_name', sa.String(length=255), nullable=False),
sa.Column('is_active', sa.Boolean(), nullable=False),
sa.Column('additional_domains', postgresql.ARRAY(sa.String(length=255)), nullable=True),
sa.Column('site_group_id', sa.UUID(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['site_group_id'], ['site_groups.id'], ondelete='SET NULL'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('organisation_id', 'domain', name='uq_sites_org_domain')
)
op.create_index(op.f('ix_sites_domain'), 'sites', ['domain'], unique=False)
op.create_index(op.f('ix_sites_organisation_id'), 'sites', ['organisation_id'], unique=False)
op.create_index(op.f('ix_sites_site_group_id'), 'sites', ['site_group_id'], unique=False)
op.create_table('consent_records',
sa.Column('site_id', sa.UUID(), nullable=False),
sa.Column('visitor_id', sa.String(length=255), nullable=False),
sa.Column('ip_hash', sa.String(length=64), nullable=True),
sa.Column('user_agent_hash', sa.String(length=64), nullable=True),
sa.Column('action', sa.String(length=30), nullable=False),
sa.Column('categories_accepted', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('categories_rejected', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('tc_string', sa.Text(), nullable=True),
sa.Column('gcm_state', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gpp_string', sa.Text(), nullable=True),
sa.Column('gpc_detected', sa.Boolean(), nullable=True),
sa.Column('gpc_honoured', sa.Boolean(), nullable=True),
sa.Column('ab_test_id', sa.UUID(), nullable=True),
sa.Column('ab_variant_id', sa.UUID(), nullable=True),
sa.Column('page_url', sa.Text(), nullable=True),
sa.Column('country_code', sa.String(length=5), nullable=True),
sa.Column('region_code', sa.String(length=10), nullable=True),
sa.Column('consented_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_consent_records_ab_test_id'), 'consent_records', ['ab_test_id'], unique=False)
op.create_index(op.f('ix_consent_records_consented_at'), 'consent_records', ['consented_at'], unique=False)
op.create_index(op.f('ix_consent_records_site_id'), 'consent_records', ['site_id'], unique=False)
op.create_index(op.f('ix_consent_records_visitor_id'), 'consent_records', ['visitor_id'], unique=False)
op.create_table('cookie_allow_list',
sa.Column('site_id', sa.UUID(), nullable=False),
sa.Column('category_id', sa.UUID(), nullable=False),
sa.Column('name_pattern', sa.String(length=255), nullable=False),
sa.Column('domain_pattern', sa.String(length=255), nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['category_id'], ['cookie_categories.id'], ondelete='RESTRICT'),
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('site_id', 'name_pattern', 'domain_pattern', name='uq_allow_list_site_name_domain')
)
op.create_index(op.f('ix_cookie_allow_list_site_id'), 'cookie_allow_list', ['site_id'], unique=False)
op.create_table('cookies',
sa.Column('site_id', sa.UUID(), nullable=False),
sa.Column('category_id', sa.UUID(), nullable=True),
sa.Column('name', sa.String(length=255), nullable=False),
sa.Column('domain', sa.String(length=255), nullable=False),
sa.Column('storage_type', sa.String(length=30), server_default='cookie', nullable=False),
sa.Column('description', sa.Text(), nullable=True),
sa.Column('vendor', sa.String(length=255), nullable=True),
sa.Column('path', sa.String(length=500), nullable=True),
sa.Column('max_age_seconds', sa.Integer(), nullable=True),
sa.Column('is_http_only', sa.Boolean(), nullable=True),
sa.Column('is_secure', sa.Boolean(), nullable=True),
sa.Column('same_site', sa.String(length=10), nullable=True),
sa.Column('review_status', sa.String(length=20), server_default='pending', nullable=False),
sa.Column('first_seen_at', sa.String(length=50), nullable=True),
sa.Column('last_seen_at', sa.String(length=50), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['category_id'], ['cookie_categories.id'], ondelete='SET NULL'),
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('site_id', 'name', 'domain', 'storage_type', name='uq_cookies_site_name_domain_type')
)
op.create_index(op.f('ix_cookies_category_id'), 'cookies', ['category_id'], unique=False)
op.create_index(op.f('ix_cookies_name'), 'cookies', ['name'], unique=False)
op.create_index(op.f('ix_cookies_site_id'), 'cookies', ['site_id'], unique=False)
op.create_table('scan_jobs',
sa.Column('site_id', sa.UUID(), nullable=False),
sa.Column('status', sa.String(length=20), server_default='pending', nullable=False),
sa.Column('trigger', sa.String(length=20), server_default='manual', nullable=False),
sa.Column('pages_scanned', sa.Integer(), server_default='0', nullable=False),
sa.Column('pages_total', sa.Integer(), nullable=True),
sa.Column('cookies_found', sa.Integer(), server_default='0', nullable=False),
sa.Column('error_message', sa.Text(), nullable=True),
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_scan_jobs_site_id'), 'scan_jobs', ['site_id'], unique=False)
op.create_index(op.f('ix_scan_jobs_status'), 'scan_jobs', ['status'], unique=False)
op.create_table('site_configs',
sa.Column('site_id', sa.UUID(), nullable=False),
sa.Column('blocking_mode', sa.String(length=20), server_default='opt_in', nullable=False),
sa.Column('regional_modes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('tcf_enabled', sa.Boolean(), nullable=False),
sa.Column('tcf_publisher_cc', sa.String(length=2), nullable=True),
sa.Column('gpp_enabled', sa.Boolean(), nullable=False),
sa.Column('gpp_supported_apis', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gpc_enabled', sa.Boolean(), nullable=False),
sa.Column('gpc_jurisdictions', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('gpc_global_honour', sa.Boolean(), nullable=False),
sa.Column('gcm_enabled', sa.Boolean(), nullable=False),
sa.Column('gcm_default', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('shopify_privacy_enabled', sa.Boolean(), nullable=False),
sa.Column('banner_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('display_mode', sa.String(length=30), server_default='bottom_banner', nullable=False),
sa.Column('privacy_policy_url', sa.Text(), nullable=True),
sa.Column('terms_url', sa.Text(), nullable=True),
sa.Column('scan_schedule_cron', sa.String(length=100), nullable=True),
sa.Column('scan_max_pages', sa.Integer(), server_default='50', nullable=False),
sa.Column('consent_expiry_days', sa.Integer(), server_default='365', nullable=False),
sa.Column('consent_retention_days', sa.Integer(), nullable=True),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('site_id')
)
op.create_table('translations',
sa.Column('site_id', sa.UUID(), nullable=False),
sa.Column('locale', sa.String(length=10), nullable=False),
sa.Column('strings', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('site_id', 'locale', name='uq_translations_site_locale')
)
op.create_index(op.f('ix_translations_site_id'), 'translations', ['site_id'], unique=False)
op.create_table('scan_results',
sa.Column('scan_job_id', sa.UUID(), nullable=False),
sa.Column('page_url', sa.Text(), nullable=False),
sa.Column('cookie_name', sa.String(length=255), nullable=False),
sa.Column('cookie_domain', sa.String(length=255), nullable=False),
sa.Column('storage_type', sa.String(length=30), server_default='cookie', nullable=False),
sa.Column('attributes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column('script_source', sa.Text(), nullable=True),
sa.Column('auto_category', sa.String(length=50), nullable=True),
sa.Column('initiator_chain', postgresql.ARRAY(sa.Text()), nullable=True, comment='Ordered script URLs from root initiator to leaf'),
sa.Column('found_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('id', sa.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['scan_job_id'], ['scan_jobs.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_scan_results_scan_job_id'), 'scan_results', ['scan_job_id'], unique=False)
# ### end Alembic commands ###
# ── Seed default cookie categories ───────────────────────────────
cookie_categories_table = sa.table(
"cookie_categories",
sa.column("id", sa.UUID()),
sa.column("name", sa.String),
sa.column("slug", sa.String),
sa.column("description", sa.Text),
sa.column("is_essential", sa.Boolean),
sa.column("display_order", sa.Integer),
sa.column("tcf_purpose_ids", postgresql.JSONB),
sa.column("gcm_consent_types", postgresql.JSONB),
)
op.bulk_insert(
cookie_categories_table,
[
{
"id": uuid.UUID("10000000-0000-0000-0000-000000000001"),
"name": "Necessary",
"slug": "necessary",
"description": (
"Essential cookies required for the website to function. "
"These cannot be disabled."
),
"is_essential": True,
"display_order": 0,
"tcf_purpose_ids": None,
"gcm_consent_types": ["functionality_storage", "security_storage"],
},
{
"id": uuid.UUID("10000000-0000-0000-0000-000000000002"),
"name": "Functional",
"slug": "functional",
"description": (
"Cookies that enable enhanced functionality and personalisation, "
"such as remembering preferences."
),
"is_essential": False,
"display_order": 1,
"tcf_purpose_ids": [1],
"gcm_consent_types": ["functionality_storage", "personalization_storage"],
},
{
"id": uuid.UUID("10000000-0000-0000-0000-000000000003"),
"name": "Analytics",
"slug": "analytics",
"description": (
"Cookies used to collect information about how visitors use the website, "
"helping to improve the site."
),
"is_essential": False,
"display_order": 2,
"tcf_purpose_ids": [7, 8, 9],
"gcm_consent_types": ["analytics_storage"],
},
{
"id": uuid.UUID("10000000-0000-0000-0000-000000000004"),
"name": "Marketing",
"slug": "marketing",
"description": (
"Cookies used to deliver personalised advertisements and "
"track advertising campaign performance."
),
"is_essential": False,
"display_order": 3,
"tcf_purpose_ids": [2, 3, 4, 5, 6, 10, 11],
"gcm_consent_types": ["ad_storage", "ad_user_data", "ad_personalization"],
},
{
"id": uuid.UUID("10000000-0000-0000-0000-000000000005"),
"name": "Personalisation",
"slug": "personalisation",
"description": (
"Cookies that enable content personalisation based on "
"user profiles and browsing behaviour."
),
"is_essential": False,
"display_order": 4,
"tcf_purpose_ids": [3, 4, 6],
"gcm_consent_types": ["personalization_storage"],
},
],
)
def downgrade() -> None:
    """Downgrade schema."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Tables are dropped in reverse creation order so foreign-key
    # dependencies resolve; each table's secondary indexes go first.
    drop_plan: list[tuple[str, list[str]]] = [
        ("scan_results", ["ix_scan_results_scan_job_id"]),
        ("translations", ["ix_translations_site_id"]),
        ("site_configs", []),
        ("scan_jobs", ["ix_scan_jobs_status", "ix_scan_jobs_site_id"]),
        ("cookies", ["ix_cookies_site_id", "ix_cookies_name", "ix_cookies_category_id"]),
        ("cookie_allow_list", ["ix_cookie_allow_list_site_id"]),
        (
            "consent_records",
            [
                "ix_consent_records_visitor_id",
                "ix_consent_records_site_id",
                "ix_consent_records_consented_at",
                "ix_consent_records_ab_test_id",
            ],
        ),
        ("sites", ["ix_sites_site_group_id", "ix_sites_organisation_id", "ix_sites_domain"]),
        ("site_group_configs", []),
        ("users", ["ix_users_organisation_id", "ix_users_email"]),
        ("site_groups", ["ix_site_groups_organisation_id"]),
        ("org_configs", []),
        ("known_cookies", ["ix_known_cookies_name_pattern"]),
        ("organisations", ["ix_organisations_slug"]),
        ("cookie_categories", []),
    ]
    for table_name, index_names in drop_plan:
        for index_name in index_names:
            op.drop_index(op.f(index_name), table_name=table_name)
        op.drop_table(table_name)
    # ### end Alembic commands ###

View File

@@ -0,0 +1,36 @@
"""composite index on consent_records(site_id, consented_at)
Revision ID: 0002
Revises: 0001
Create Date: 2026-04-13
The most common analytic query pattern is "consents for site X in date
range" (consent rates, trends, regional breakdowns). The single-column
indexes on ``site_id`` and ``consented_at`` each help a little, but a
composite index is materially faster for the combined filter.
"""
from typing import Sequence, Union
from alembic import op
# Alembic revision identifiers — these link this migration into the chain.
revision: str = "0002"
down_revision: Union[str, Sequence[str], None] = "0001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Create the composite index serving site + date-range consent queries."""
    index_name = "ix_consent_records_site_consented_at"
    op.create_index(index_name, "consent_records", ["site_id", "consented_at"], unique=False)
def downgrade() -> None:
    """Remove the composite index added by this revision."""
    op.drop_index("ix_consent_records_site_consented_at", table_name="consent_records")

View File

@@ -0,0 +1,9 @@
# Open Cookie Database
The file `open-cookie-database.csv` is sourced from the
[Open Cookie Database](https://github.com/jkwakman/Open-Cookie-Database)
by jkwakman, licensed under the Creative Commons Attribution 4.0 International
(CC BY 4.0) licence.
To update the database, download the latest CSV from the repository above and
replace this file, then run `make seed` to reload the data.

File diff suppressed because it is too large Load Diff

65
apps/api/fly.toml Normal file
View File

@@ -0,0 +1,65 @@
# Fly.io configuration for the ConsentOS API
# See https://fly.io/docs/reference/configuration/ for reference.
#
# This app runs three process groups from the same Docker image:
# - app: FastAPI web server (handles HTTP traffic)
# - worker: Celery worker (processes scan jobs and background tasks)
# - beat: Celery beat scheduler (triggers periodic tasks)
app = "consentos-api"
primary_region = "lhr" # London

[build]
dockerfile = "Dockerfile"

[env]
ENVIRONMENT = "production"
LOG_LEVEL = "INFO"
PORT = "8000"
RATE_LIMIT_ENABLED = "true"
RATE_LIMIT_PER_MINUTE = "120"

# ── Migrations run once per deployment, before processes start ──────
[deploy]
release_command = "python -m alembic upgrade head"

# ── Process groups ──────────────────────────────────────────────────
[processes]
app = "sh start.sh"
worker = "celery -A src.celery_app worker --loglevel=info --concurrency=2"
beat = "celery -A src.celery_app beat --loglevel=info"

# ── HTTP service (only the 'app' process serves HTTP) ───────────────
[http_service]
internal_port = 8000
force_https = true
auto_stop_machines = "stop"
auto_start_machines = true
# Scale-to-zero: the first request after an idle period pays a machine
# cold-start. Raise to 1 if that latency matters for your traffic.
min_machines_running = 0
processes = ["app"]

[http_service.concurrency]
type = "requests"
hard_limit = 250
soft_limit = 200

# ── VM sizing per process ───────────────────────────────────────────
# The app and beat processes are lightweight; the worker needs more
# memory for processing scan results.
# NOTE(review): the comment above says the worker "needs more" memory,
# yet all three groups are sized at 256mb — confirm whether the worker
# VM should be larger (e.g. 512mb).
[[vm]]
memory = "256mb"
cpu_kind = "shared"
cpus = 1
processes = ["app"]

[[vm]]
memory = "256mb"
cpu_kind = "shared"
cpus = 1
processes = ["worker"]

[[vm]]
memory = "256mb"
cpu_kind = "shared"
cpus = 1
processes = ["beat"]

64
apps/api/pyproject.toml Normal file
View File

@@ -0,0 +1,64 @@
[project]
name = "consentos-api"
version = "0.1.0"
description = "ConsentOS — API service"
license = "Elastic-2.0"
requires-python = ">=3.12"
dependencies = [
    "fastapi>=0.115,<1",
    "uvicorn[standard]>=0.34,<1",
    "sqlalchemy[asyncio]>=2.0,<3",
    # asyncpg drives the async app; psycopg2-binary (below) serves the
    # synchronous one-off scripts such as the known-cookies seeder.
    "asyncpg>=0.30,<1",
    "alembic>=1.14,<2",
    "pydantic>=2.0,<3",
    "pydantic-settings>=2.0,<3",
    # NOTE(review): python-jose has had recent security advisories and
    # slow maintenance; consider migrating to PyJWT.
    "python-jose[cryptography]>=3.3,<4",
    "bcrypt>=4.0,<5",
    "redis>=5.0,<6",
    "celery>=5.4,<6",
    "httpx>=0.28,<1",
    "structlog>=24.0,<25",
    "psycopg2-binary>=2.9,<3",
    "email-validator>=2.0,<3",
    "jinja2>=3.1,<4",
    "markupsafe>=2.1,<3",
    "reportlab>=4.0,<5",
    "geoip2>=4.8,<5",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0,<9",
    "pytest-asyncio>=0.24,<1",
    "pytest-cov>=6.0,<7",
    "httpx>=0.28,<1",
    "ruff>=0.9,<1",
    "mypy>=1.13,<2",
]

[build-system]
requires = ["setuptools>=68"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
where = ["."]
include = ["src*"]

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"
asyncio_default_fixture_loop_scope = "session"
filterwarnings = ["ignore::DeprecationWarning"]

[tool.ruff]
target-version = "py312"
line-length = 100

[tool.ruff.lint]
select = ["E", "F", "I", "N", "W", "UP", "B", "SIM", "RUF"]
ignore = ["B008"] # Depends() in FastAPI defaults is idiomatic

0
apps/api/src/__init__.py Normal file
View File

View File

@@ -0,0 +1,89 @@
"""Celery application and task definitions for the CMP API.
Provides async-compatible scan scheduling via Celery with Redis as the
broker and result backend.
"""
import ssl
from celery import Celery
from celery.schedules import crontab
from src.config.settings import get_settings
settings = get_settings()

# Named `app` by Celery convention — the CLI finds it via -A src.celery_app
app = Celery(
    "cmp",
    broker=settings.redis_url,
    backend=settings.redis_url,
)

# When using rediss:// (TLS) — e.g. Upstash — Celery requires explicit
# SSL certificate verification settings for both broker and backend.
_conf: dict = {
    "task_serializer": "json",
    "accept_content": ["json"],
    "result_serializer": "json",
    "timezone": "UTC",
    "enable_utc": True,
    # Report STARTED state so in-progress scans are visible, not just PENDING.
    "task_track_started": True,
    # Ack after completion so tasks from crashed workers are re-delivered.
    "task_acks_late": True,
    # Fair dispatch: don't let one worker hoard long-running scan jobs.
    "worker_prefetch_multiplier": 1,
}
if settings.redis_url.startswith("rediss://"):
    # NOTE(review): CERT_NONE skips certificate verification — the TLS
    # connection is encrypted but the broker is not authenticated.
    # Prefer ssl.CERT_REQUIRED with a CA bundle where the broker's
    # certificate chain allows it.
    _conf["broker_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_NONE}
    _conf["redis_backend_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_NONE}
app.conf.update(**_conf)

# ── Beat schedule (periodic tasks) ──────────────────────────────────
app.conf.beat_schedule = {
    "check-scheduled-scans": {
        "task": "src.tasks.scanner.check_scheduled_scans",
        "schedule": crontab(minute="*/15"),  # Every 15 minutes
    },
    "recover-stale-scans": {
        "task": "src.tasks.scanner.recover_stale_scans",
        "schedule": crontab(minute="*/5"),  # Every 5 minutes
    },
    "purge-expired-consent-records": {
        "task": "src.tasks.retention.purge_expired_consent_records",
        "schedule": crontab(hour="1", minute="0"),  # Daily at 01:00 UTC
    },
}

# ── Explicit task imports ───────────────────────────────────────────
# Must be at the bottom to avoid circular imports. These ensure the
# worker process registers all @app.task definitions on startup.
import src.tasks.retention  # noqa: E402
import src.tasks.scanner  # noqa: E402, F401

# EE tasks are registered conditionally — they only exist in EE mode.
try:
    import ee.api.src.tasks.compliance_scanner
    import ee.api.src.tasks.compliance_scoring
    import ee.api.src.tasks.retention  # noqa: F401

    # NOTE(review): the beat entries below name tasks under "src.tasks.*"
    # while the modules imported above live under "ee.api.src.tasks.*" —
    # this assumes the EE tasks register themselves under explicit task
    # names; confirm against the EE package.
    app.conf.beat_schedule.update(
        {
            "check-scheduled-compliance-scans": {
                "task": "src.tasks.compliance_scanner.check_scheduled_compliance_scans",
                "schedule": crontab(hour="3", minute="0"),
            },
            "compute-daily-compliance-scores": {
                "task": "src.tasks.compliance_scoring.compute_daily_scores",
                "schedule": crontab(hour="4", minute="0"),
            },
            "run-retention-purge": {
                "task": "src.tasks.retention.run_retention_purge",
                "schedule": crontab(hour="2", minute="0"),
            },
        }
    )
except ImportError:
    pass

View File

View File

@@ -0,0 +1,40 @@
"""One-shot bootstrap of an initial organisation and owner user.
Usage:
python -m src.cli.bootstrap_admin
Reads ``INITIAL_ADMIN_EMAIL`` and ``INITIAL_ADMIN_PASSWORD`` (plus the
optional ``INITIAL_ADMIN_FULL_NAME``, ``INITIAL_ORG_NAME``, and
``INITIAL_ORG_SLUG``) from the environment. If the ``users`` table is
empty and both credentials are set, creates the org and owner user so
the operator can log in to the admin UI. Idempotent: if any user
already exists, exits 0 without touching the database.
Intended to be run as a one-shot init container *after* the database
migrations have been applied — typically via ``depends_on`` with
``service_healthy`` on the API container.
"""
from __future__ import annotations
import asyncio
import sys
from src.config.logging import setup_logging
from src.config.settings import get_settings
from src.services.bootstrap import bootstrap_initial_admin
async def _main() -> int:
    """Run the admin bootstrap and report the process exit code."""
    app_settings = get_settings()
    setup_logging(app_settings.log_level)
    await bootstrap_initial_admin(app_settings)
    return 0
def main() -> None:
    """Synchronous console entry point."""
    # SystemExit carries the int status from the async runner.
    raise SystemExit(asyncio.run(_main()))


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,137 @@
"""Seed the known_cookies table from the Open Cookie Database CSV.
Usage:
python -m src.cli.seed_known_cookies [--csv PATH] [--clear]
The Open Cookie Database is a community-maintained catalogue of ~2,200+
cookie patterns. See https://github.com/jkwakman/Open-Cookie-Database
"""
from __future__ import annotations
import argparse
import csv
import sys
import uuid
from pathlib import Path
import sqlalchemy as sa
# ---------------------------------------------------------------------------
# Category mapping: Open Cookie Database category → CMP slug
# ---------------------------------------------------------------------------
# Rows whose category is absent from this map are skipped by ``seed``.
_CATEGORY_MAP: dict[str, str] = {
    "Functional": "functional",
    "Analytics": "analytics",
    "Marketing": "marketing",
    # Upstream uses US spelling; local slugs use British spelling.
    "Personalization": "personalisation",
    # Security cookies are treated as strictly necessary.
    "Security": "necessary",
}

# Bundled copy shipped with the repo: apps/api/data/open-cookie-database.csv
_DEFAULT_CSV = Path(__file__).resolve().parent.parent.parent / "data" / "open-cookie-database.csv"
def _build_sync_url(async_url: str) -> str:
"""Convert an asyncpg DSN to a psycopg2 DSN for one-off scripts."""
return async_url.replace("postgresql+asyncpg://", "postgresql://")
def seed(csv_path: Path, *, clear: bool = False) -> int:
    """Read the CSV and upsert rows into known_cookies.

    Returns the number of rows inserted.
    """
    # Imported lazily so module import (e.g. ``--help``) does not require
    # a configured environment.
    from src.config.settings import get_settings

    settings = get_settings()
    # One-off script: use a synchronous engine derived from the async DSN.
    engine = sa.create_engine(_build_sync_url(settings.database_url))
    with engine.begin() as conn:
        # Build slug → category_id lookup
        rows = conn.execute(sa.text("SELECT id, slug FROM cookie_categories"))
        slug_to_id: dict[str, str] = {r[1]: str(r[0]) for r in rows}
        if clear:
            conn.execute(sa.text("DELETE FROM known_cookies"))
        inserted = 0
        with csv_path.open(newline="", encoding="utf-8") as fh:
            reader = csv.DictReader(fh)
            for row in reader:
                # Skip rows whose category has no local equivalent.
                category = row.get("Category", "").strip()
                slug = _CATEGORY_MAP.get(category)
                if not slug or slug not in slug_to_id:
                    continue
                name = row.get("Cookie / Data Key name", "").strip()
                if not name:
                    continue
                # "*" marks a pattern that applies to any domain.
                domain_raw = row.get("Domain", "").strip()
                domain = domain_raw if domain_raw else "*"
                wildcard = row.get("Wildcard match", "0").strip() == "1"
                description = row.get("Description", "").strip() or None
                vendor = row.get("Platform", "").strip() or None
                # Build pattern: if wildcard, append * to name for glob matching
                name_pattern = f"{name}*" if wildcard else name
                is_regex = False
                # NOTE(review): the upsert relies on a unique constraint over
                # (name_pattern, domain_pattern) existing; also note that
                # ``inserted`` counts rows that merely updated an existing
                # entry, not just fresh inserts.
                conn.execute(
                    sa.text(
                        """
                        INSERT INTO known_cookies
                            (id, name_pattern, domain_pattern, category_id,
                             vendor, description, is_regex, created_at, updated_at)
                        VALUES
                            (:id, :name_pattern, :domain_pattern, :category_id,
                             :vendor, :description, :is_regex, NOW(), NOW())
                        ON CONFLICT (name_pattern, domain_pattern) DO UPDATE SET
                            category_id = EXCLUDED.category_id,
                            vendor = EXCLUDED.vendor,
                            description = EXCLUDED.description,
                            is_regex = EXCLUDED.is_regex,
                            updated_at = NOW()
                        """
                    ),
                    {
                        "id": str(uuid.uuid4()),
                        "name_pattern": name_pattern,
                        "domain_pattern": domain,
                        "category_id": slug_to_id[slug],
                        "vendor": vendor,
                        "description": description,
                        "is_regex": is_regex,
                    },
                )
                inserted += 1
    return inserted
def main() -> None:
    """Parse CLI arguments, validate the CSV path, and run the import."""
    parser = argparse.ArgumentParser(description="Seed known cookies from Open Cookie Database")
    parser.add_argument(
        "--csv",
        type=Path,
        default=_DEFAULT_CSV,
        help="Path to the Open Cookie Database CSV (default: bundled copy)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Delete all existing known_cookies before importing",
    )
    opts = parser.parse_args()

    # Fail fast with a clear message rather than a traceback from open().
    if not opts.csv.exists():
        print(f"Error: CSV not found at {opts.csv}", file=sys.stderr)
        sys.exit(1)

    total = seed(opts.csv, clear=opts.clear)
    print(f"Seeded {total} known cookie patterns from {opts.csv.name}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,3 @@
from src.config.settings import Settings, get_settings
__all__ = ["Settings", "get_settings"]

View File

@@ -0,0 +1,26 @@
"""Edition detection for the open-core architecture.
Determines whether the application is running in community edition (CE)
or enterprise edition (EE) based on the availability of the ``ee``
package.
"""
from __future__ import annotations
from functools import lru_cache
@lru_cache(maxsize=1)
def is_ee() -> bool:
    """Return ``True`` if enterprise extensions are available."""
    # Availability of the ``ee`` package is the sole edition signal;
    # the probe result is cached for the life of the process.
    try:
        import ee  # noqa: F401
    except ImportError:
        return False
    return True
def edition_name() -> str:
    """Return a human-readable edition label (``"ee"`` or ``"ce"``)."""
    if is_ee():
        return "ee"
    return "ce"

View File

@@ -0,0 +1,26 @@
import logging
import sys
import structlog
def setup_logging(log_level: str = "INFO") -> None:
    """Configure structured logging with structlog."""
    # Pretty console output for interactive terminals, JSON otherwise.
    final_renderer = (
        structlog.dev.ConsoleRenderer()
        if sys.stderr.isatty()
        else structlog.processors.JSONRenderer()
    )
    # Unknown level names fall back to INFO.
    numeric_level = getattr(logging, log_level.upper(), logging.INFO)
    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.StackInfoRenderer(),
            structlog.dev.set_exc_info,
            structlog.processors.TimeStamper(fmt="iso"),
            final_renderer,
        ],
        wrapper_class=structlog.make_filtering_bound_logger(numeric_level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
        cache_logger_on_first_use=True,
    )

View File

@@ -0,0 +1,166 @@
from functools import lru_cache
from pydantic import model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Placeholder value — the application refuses to start in non-dev
# environments if ``jwt_secret_key`` is left at this literal.
# (Enforced by ``Settings._check_production_safety``.)
_JWT_PLACEHOLDER = "CHANGE-ME-in-production"
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values come from the process environment, optionally layered over a
    local ``.env`` file; matching is case-insensitive. Dangerous
    defaults are rejected outside dev-like environments by
    ``_check_production_safety``.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # Application
    app_name: str = "ConsentOS API"
    app_version: str = "0.1.0"
    debug: bool = False
    environment: str = "development"
    log_level: str = "INFO"

    # Server
    host: str = "0.0.0.0"
    port: int = 8000
    # Comma-separated list; parsed by ``allowed_origins_list`` below.
    allowed_origins: str = "http://localhost:5173"

    @property
    def allowed_origins_list(self) -> list[str]:
        """Parse allowed_origins as a comma-separated string."""
        # Empty fragments (e.g. trailing commas) are dropped.
        return [o.strip() for o in self.allowed_origins.split(",") if o.strip()]

    # Database
    database_url: str = "postgresql+asyncpg://consentos:consentos@localhost:5432/consentos"
    database_echo: bool = False
    database_pool_size: int = 20
    database_max_overflow: int = 10

    # Redis
    redis_url: str = "redis://localhost:6379/0"

    # JWT
    jwt_secret_key: str = _JWT_PLACEHOLDER
    jwt_algorithm: str = "HS256"
    jwt_access_token_expire_minutes: int = 30
    jwt_refresh_token_expire_days: int = 7

    # Pseudonymisation — HMAC key for IP / UA hashing on consent records.
    # Defaults to deriving from the JWT secret if not explicitly set.
    pseudonymisation_secret: str | None = None

    # Bootstrap token — required as ``X-Admin-Bootstrap-Token`` on
    # ``POST /api/v1/organisations/``. When unset (the default), the
    # endpoint is disabled. Rotate or unset after your first org is
    # provisioned to prevent further tenant creation.
    admin_bootstrap_token: str | None = None

    # Initial admin bootstrap — on first startup, if the ``users`` table
    # is empty and both credentials below are set, the API creates an
    # organisation and an owner user so the operator can log in to the
    # admin UI for the first time. Idempotent: once any user exists this
    # is a no-op, so the variables can safely remain set across restarts.
    # Rotate the password via the admin UI after first login.
    initial_admin_email: str | None = None
    initial_admin_password: str | None = None
    initial_admin_full_name: str = "Administrator"
    initial_org_name: str = "Default Organisation"
    initial_org_slug: str = "default"

    # CDN — public URL where banner scripts (consent-loader.js,
    # consent-bundle.js) are hosted. In dev the admin UI dog-foods
    # the banner so localhost:5173 works for testing; in production
    # this should be a real CDN URL (CloudFlare Pages, S3+CloudFront,
    # Cloud CDN, etc.) — see docs for setup.
    cdn_base_url: str = "http://localhost:5173"

    # Scanner service
    scanner_service_url: str = "http://localhost:8001"
    scanner_timeout_seconds: int = 300

    # Extra GeoIP country header — checked *before* the built-in list
    # (``cf-ipcountry``, ``x-vercel-ip-country``, ``x-appengine-country``,
    # ``x-country-code``). Set this when running behind a CDN/load
    # balancer that uses a non-standard header, e.g. Google Cloud
    # Load Balancer's ``x-gclb-country`` or an internal edge proxy.
    # Header names are case-insensitive. Leave unset if one of the
    # built-in headers is fine.
    geoip_country_header: str | None = None

    # Subdivision/state code header — optional companion to
    # ``GEOIP_COUNTRY_HEADER``. When both are set the API pairs them to
    # produce region keys like ``US-CA`` or ``GB-SCT`` (ISO 3166-2
    # subdivision without the country prefix). Different CDNs expose
    # this under different names: Cloudflare Enterprise uses
    # ``cf-region-code``, Vercel uses ``x-vercel-ip-country-region``,
    # GCP Load Balancer uses ``x-gclb-region``, CloudFront functions
    # use ``cloudfront-viewer-country-region``. Leave unset if you
    # only need country-level granularity.
    geoip_region_header: str | None = None

    # Local MaxMind GeoLite2/GeoIP2 City database — used as a fallback
    # when no CDN header is present. Download GeoLite2-City.mmdb from
    # https://dev.maxmind.com/geoip/geolite2-free-geolocation-data and
    # mount it into the container (e.g. ``/data/GeoLite2-City.mmdb``).
    # When unset, lookups fall back to the free external ip-api.com
    # service, which is rate-limited and should not be relied on in
    # production.
    geoip_maxmind_db_path: str | None = None

    # Rate limiting — on by default. Public endpoints (banner config +
    # consent submission) are internet-exposed and must not be DoS-able.
    # Auth endpoints get a stricter bucket via ``RateLimitMiddleware``.
    rate_limit_enabled: bool = True
    rate_limit_per_minute: int = 120

    @model_validator(mode="after")
    def _check_production_safety(self) -> "Settings":
        """Refuse to start with unsafe defaults in non-dev environments."""
        # Dev-like environments may keep the placeholder secrets.
        if self.environment.lower() in ("development", "dev", "local", "test"):
            return self
        errors: list[str] = []
        if self.jwt_secret_key == _JWT_PLACEHOLDER:
            errors.append(
                "JWT_SECRET_KEY is set to the placeholder value "
                f"{_JWT_PLACEHOLDER!r}. Generate a strong random value "
                "(e.g. `openssl rand -base64 48`) and set it in the "
                "environment before starting the API."
            )
        if "*" in self.allowed_origins_list:
            errors.append(
                "ALLOWED_ORIGINS contains '*'. Wildcard CORS combined with "
                "allow_credentials=True is a credential-theft vector. "
                "Set ALLOWED_ORIGINS to an explicit list of trusted origins."
            )
        if errors:
            # Collect every problem so the operator can fix them in one pass.
            msg = "Refusing to start with unsafe configuration:\n - " + "\n - ".join(
                errors,
            )
            raise ValueError(msg)
        return self

    @property
    def pseudonymisation_key(self) -> bytes:
        """Return the HMAC key used for pseudonymising IP/UA values.

        If ``pseudonymisation_secret`` is not set, derives a per-instance
        key from the JWT secret so operators don't have to configure two
        secrets. Using JWT_SECRET directly is acceptable because the
        HMAC is one-way and the resulting hashes are not reversible.
        """
        source = self.pseudonymisation_secret or self.jwt_secret_key
        return source.encode("utf-8")
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide cached ``Settings`` instance."""
    return Settings()

View File

@@ -0,0 +1,3 @@
from src.db.session import get_db
__all__ = ["get_db"]

View File

@@ -0,0 +1,31 @@
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from src.config.settings import get_settings
settings = get_settings()

# Module-level engine so the whole process shares one connection pool.
engine = create_async_engine(
    settings.database_url,
    echo=settings.database_echo,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
)

# expire_on_commit=False keeps ORM objects usable after the
# request-scope commit in ``get_db`` (no lazy refresh post-commit).
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields an async database session.

    Commits once the request handler finishes successfully; rolls back
    and re-raises if the handler (or the commit itself) raises. The
    ``async with`` block closes the session in every case.
    """
    async with async_session_factory() as session:
        try:
            yield session
            # Commit whatever the handler left pending — a no-op if it
            # already committed.
            await session.commit()
        except Exception:
            await session.rollback()
            raise

View File

@@ -0,0 +1,3 @@
from src.extensions.registry import discover_extensions, get_registry
__all__ = ["discover_extensions", "get_registry"]

View File

@@ -0,0 +1,197 @@
"""Extension registry for the open-core architecture.
Provides registration hooks that allow enterprise/commercial code to inject
routers, model modules, startup tasks, and OpenAPI tags into the core
application — without the core needing any direct knowledge of the
extensions.
In community edition (CE) mode, ``discover_extensions()`` is a no-op
because the ``ee`` package is not present.
"""
from __future__ import annotations
import importlib
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable, Coroutine
from typing import Any
from fastapi import APIRouter, FastAPI
logger = logging.getLogger(__name__)


@dataclass
class OpenAPITag:
    """Metadata for a FastAPI OpenAPI tag."""

    # Tag name as shown in the OpenAPI spec / docs UI.
    name: str
    # Human-readable blurb displayed under the tag heading.
    description: str


@dataclass
class RouterEntry:
    """A router registered by an extension."""

    router: APIRouter
    # Mount prefix; defaults to the versioned API root.
    prefix: str = "/api/v1"
    tags: list[OpenAPITag] = field(default_factory=list)
@dataclass
class ExtensionRegistry:
    """Central registry for extension-contributed components.

    Extensions call the module-level helper functions (``register_router``,
    ``register_model_module``, etc.) which delegate to the singleton
    instance stored in ``_registry``.
    """

    routers: list[RouterEntry] = field(default_factory=list)
    # Dotted module paths imported in ``apply`` so SQLAlchemy sees the models.
    model_modules: list[str] = field(default_factory=list)
    startup_hooks: list[Callable[[FastAPI], Coroutine[Any, Any, None]]] = field(
        default_factory=list,
    )
    config_enrichers: list[Callable] = field(default_factory=list)
    consent_record_hooks: list[Callable] = field(default_factory=list)

    # ------------------------------------------------------------------
    # Registration helpers
    # ------------------------------------------------------------------
    def add_router(
        self,
        router: APIRouter,
        *,
        prefix: str = "/api/v1",
        tags: list[OpenAPITag] | None = None,
    ) -> None:
        # Queue *router* for mounting under *prefix* when ``apply`` runs.
        self.routers.append(RouterEntry(router=router, prefix=prefix, tags=tags or []))

    def add_model_module(self, module_path: str) -> None:
        # Queue a dotted module path for import during ``apply``.
        self.model_modules.append(module_path)

    def add_startup_hook(
        self,
        hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
    ) -> None:
        # Record an async hook intended to run at application startup.
        self.startup_hooks.append(hook)

    def add_config_enricher(self, enricher: Callable) -> None:
        # Record a callable that augments published site config.
        self.config_enrichers.append(enricher)

    def add_consent_record_hook(self, hook: Callable) -> None:
        # Record a callable invoked after a consent record is persisted.
        self.consent_record_hooks.append(hook)

    # ------------------------------------------------------------------
    # Application wiring
    # ------------------------------------------------------------------
    def apply(self, app: FastAPI) -> None:
        """Mount all registered routers and tags onto *app*."""
        for entry in self.routers:
            # Inject OpenAPI tags (deduplicated by tag name).
            for tag in entry.tags:
                existing = app.openapi_tags or []
                if not any(t["name"] == tag.name for t in existing):
                    existing.append(
                        {"name": tag.name, "description": tag.description},
                    )
                app.openapi_tags = existing
            app.include_router(entry.router, prefix=entry.prefix)
        if self.routers:
            logger.info(
                "Registered %d extension router(s)",
                len(self.routers),
            )
        # Import model modules so SQLAlchemy picks them up
        for mod in self.model_modules:
            importlib.import_module(mod)
        if self.model_modules:
            logger.info(
                "Registered %d extension model module(s)",
                len(self.model_modules),
            )
# Singleton ------------------------------------------------------------------
# Module-level singleton; mutated via the convenience helpers below.
_registry = ExtensionRegistry()


def get_registry() -> ExtensionRegistry:
    """Return the global extension registry."""
    return _registry
# Convenience module-level API -----------------------------------------------
def register_router(
    router: APIRouter,
    *,
    prefix: str = "/api/v1",
    tags: list[OpenAPITag] | None = None,
) -> None:
    """Register an API router for mounting when the app starts."""
    get_registry().add_router(router, prefix=prefix, tags=tags)
def register_model_module(module_path: str) -> None:
    """Register a dotted module path whose SQLAlchemy models should be imported."""
    get_registry().add_model_module(module_path)
def register_startup_hook(
    hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
) -> None:
    """Register an async callable to run during application startup."""
    get_registry().add_startup_hook(hook)
def register_config_enricher(enricher: Callable) -> None:
    """Register a callable that enriches published config.

    Expected signature:
    ``async (site_id: UUID, db: AsyncSession, config: dict) -> None``.
    The callable mutates *config* in-place to add extension-specific
    data (e.g. A/B test variants).
    """
    get_registry().add_config_enricher(enricher)
def register_consent_record_hook(hook: Callable) -> None:
    """Register a callable invoked after a consent record is persisted.

    Expected signature: ``async (db: AsyncSession, consent_record) -> None``.
    Invoked from ``POST /api/v1/consent`` once the record has been
    flushed to the database. Typical use: generating a consent receipt
    (EE), writing audit logs, firing webhooks.
    """
    get_registry().add_consent_record_hook(hook)
# Discovery ------------------------------------------------------------------
def discover_extensions() -> None:
    """Import the EE registration module if installed.

    Enterprise edition ships as a separate ``consent-enterprise``
    package; importing ``ee.api.src.register`` triggers its
    side-effect registrations. In community edition the import fails
    and the application simply continues without extensions.
    """
    try:
        import ee.api.src.register  # noqa: F401
    except ImportError:
        logger.debug("No enterprise extensions found (CE mode)")
    else:
        logger.info("Enterprise extensions loaded")

210
apps/api/src/main.py Normal file
View File

@@ -0,0 +1,210 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.config.edition import edition_name
from src.config.logging import setup_logging
from src.config.settings import get_settings
from src.extensions.registry import discover_extensions, get_registry
from src.middleware.rate_limit import RateLimitMiddleware
from src.middleware.security_headers import SecurityHeadersMiddleware
from src.routers import (
auth,
compliance,
config,
consent,
cookies,
org_config,
organisations,
scanner,
site_group_config,
site_groups,
sites,
translations,
users,
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application startup and shutdown lifecycle."""
    cfg = get_settings()
    setup_logging(cfg.log_level)
    yield
def create_app() -> FastAPI:
"""Application factory."""
settings = get_settings()
app = FastAPI(
title=settings.app_name,
version=settings.app_version,
description=(
"Multi-tenant cookie consent management platform API. "
"Provides consent collection, cookie scanning, auto-blocking, "
"compliance checking, and analytics across multiple sites."
),
debug=settings.debug,
lifespan=lifespan,
openapi_tags=[
{
"name": "auth",
"description": "Authentication — login, token refresh, and current user.",
},
{
"name": "config",
"description": (
"Site configuration — public endpoints for the banner script "
"to fetch config, GeoIP-resolved config, and CDN publishing."
),
},
{
"name": "consent",
"description": (
"Consent recording and retrieval — public endpoints called "
"by the banner script to record visitor consent decisions."
),
},
{
"name": "sites",
"description": "Site and site config CRUD — manage domains and settings.",
},
{
"name": "cookies",
"description": (
"Cookie management — categories, discovered cookies, allow-list, "
"known cookies database, and auto-classification."
),
},
{
"name": "scanner",
"description": (
"Cookie scanner — trigger scans, view results, and receive "
"client-side cookie reports from the banner script."
),
},
{
"name": "compliance",
"description": (
"Compliance checking — run checks against GDPR, CNIL, CCPA, "
"ePrivacy, and LGPD frameworks."
),
},
{
"name": "organisations",
"description": "Organisation management — multi-tenant root entities.",
},
{
"name": "users",
"description": "User management — org-scoped users with role-based access.",
},
],
)
# Security headers
app.add_middleware(SecurityHeadersMiddleware)
# Rate limiting (must be added before CORS to count requests correctly)
if settings.rate_limit_enabled:
app.add_middleware(
RateLimitMiddleware,
redis_url=settings.redis_url,
requests_per_minute=settings.rate_limit_per_minute,
auth_requests_per_minute=10,
)
# CORS
app.add_middleware(
CORSMiddleware,
allow_origins=settings.allowed_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Core routers
api_prefix = "/api/v1"
app.include_router(auth.router, prefix=api_prefix)
app.include_router(config.router, prefix=api_prefix)
app.include_router(consent.router, prefix=api_prefix)
app.include_router(scanner.router, prefix=api_prefix)
app.include_router(compliance.router, prefix=api_prefix)
app.include_router(organisations.router, prefix=api_prefix)
app.include_router(org_config.router, prefix=api_prefix)
app.include_router(users.router, prefix=api_prefix)
app.include_router(site_groups.router, prefix=api_prefix)
app.include_router(site_group_config.router, prefix=api_prefix)
app.include_router(sites.router, prefix=api_prefix)
app.include_router(cookies.router, prefix=api_prefix)
app.include_router(translations.router, prefix=api_prefix)
app.include_router(translations.public_router, prefix=api_prefix)
# Discover and mount enterprise extensions (no-op in CE mode)
discover_extensions()
registry = get_registry()
registry.apply(app)
@app.get("/health", tags=["health"])
async def health() -> dict[str, str]:
"""Shallow liveness check.
Answers "is the process running?". Suitable for orchestrator
liveness probes. For deployment readiness, use
``/health/ready`` which verifies downstream dependencies.
"""
return {"status": "ok", "edition": edition_name()}
@app.get("/health/ready", tags=["health"])
async def health_ready() -> dict[str, object]:
"""Deep readiness check — verifies database and Redis.
Returns HTTP 503 if either dependency is unreachable so load
balancers route traffic away from broken instances.
"""
from fastapi import HTTPException
from sqlalchemy import text
from src.db.session import engine as db_engine
checks: dict[str, str] = {}
overall_ok = True
# Database
try:
async with db_engine.connect() as conn:
await conn.execute(text("SELECT 1"))
checks["database"] = "ok"
except Exception as exc:
checks["database"] = f"error: {type(exc).__name__}"
overall_ok = False
# Redis
try:
import redis.asyncio as aioredis
r = aioredis.from_url(settings.redis_url, decode_responses=True)
pong = await r.ping()
checks["redis"] = "ok" if pong else "error: ping failed"
if not pong:
overall_ok = False
await r.aclose()
except Exception as exc:
checks["redis"] = f"error: {type(exc).__name__}"
overall_ok = False
payload = {
"status": "ok" if overall_ok else "degraded",
"edition": edition_name(),
"checks": checks,
}
if not overall_ok:
raise HTTPException(status_code=503, detail=payload)
return payload
return app
# Module-level ASGI application, built at import time so servers can
# point at this module's ``app`` attribute (e.g. ``uvicorn src.main:app``).
app = create_app()

View File

View File

@@ -0,0 +1,111 @@
"""Redis-backed rate limiting middleware.
Applies per-IP rate limits to all incoming requests. Public endpoints
(consent recording, config fetching) are the primary protection target.
Uses a fixed-window counter (one key per client per minute) stored in Redis with automatic expiry.
"""
from __future__ import annotations
import logging
import time
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP request throttle backed by Redis.

    Implements a fixed-window counter: one Redis key per IP per minute,
    expired automatically. Auth endpoints are counted in a stricter
    bucket to slow credential stuffing. If Redis is unreachable the
    middleware fails open and forwards the request unthrottled.
    """

    def __init__(
        self,
        app: object,
        redis_url: str = "redis://localhost:6379/0",
        requests_per_minute: int = 120,
        auth_requests_per_minute: int = 10,
    ) -> None:
        super().__init__(app)  # type: ignore[arg-type]
        self.redis_url = redis_url
        self.requests_per_minute = requests_per_minute
        self.auth_requests_per_minute = auth_requests_per_minute
        # Created lazily on first use; stays None while Redis is unavailable.
        self._redis: object | None = None

    async def _get_redis(self) -> object | None:
        """Return the (lazily created) Redis client, or ``None`` on failure."""
        if self._redis is None:
            try:
                import redis.asyncio as aioredis

                self._redis = aioredis.from_url(self.redis_url, decode_responses=True)
            except Exception:
                # Import/connection setup failed — limiter is disabled for
                # this request; we will retry on the next one.
                logger.warning("Rate limiting disabled: Redis unavailable")
                return None
        return self._redis

    def _get_client_ip(self, request: Request) -> str:
        """Best-effort client IP: XFF first hop, then X-Real-IP, then peer.

        NOTE(review): these headers are client-controlled unless a trusted
        reverse proxy strips/sets them — confirm the deployment does.
        """
        forwarded_for = request.headers.get("x-forwarded-for")
        if forwarded_for:
            return forwarded_for.split(",")[0].strip()
        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip()
        return request.client.host if request.client else "unknown"

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        path = request.url.path
        # Health probes must never be throttled.
        if path in ("/health", "/health/ready", "/health/live"):
            return await call_next(request)
        redis_client = await self._get_redis()
        if redis_client is None:
            # Fail open when Redis is down.
            return await call_next(request)
        # Stricter bucket for credential endpoints (login, reset, refresh);
        # /auth/me is a read endpoint and uses the normal bucket.
        is_auth = path.startswith("/api/v1/auth/") and path not in ("/api/v1/auth/me",)
        if is_auth:
            bucket, limit = "auth", self.auth_requests_per_minute
        else:
            bucket, limit = "req", self.requests_per_minute
        minute_window = int(time.time() // 60)
        key = f"cmp:rate:{bucket}:{self._get_client_ip(request)}:{minute_window}"
        try:
            count = await redis_client.incr(key)  # type: ignore[union-attr]
            if count == 1:
                # First hit this window: expire after two windows so stale
                # keys clean themselves up.
                await redis_client.expire(key, 120)  # type: ignore[union-attr]
            if count > limit:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many requests. Please try again later."},
                    headers={
                        "Retry-After": "60",
                        "X-RateLimit-Limit": str(limit),
                        "X-RateLimit-Remaining": "0",
                    },
                )
            response = await call_next(request)
            response.headers["X-RateLimit-Limit"] = str(limit)
            response.headers["X-RateLimit-Remaining"] = str(max(0, limit - count))
            return response
        except Exception:
            # Any Redis hiccup mid-check: log at debug and let the
            # request through rather than erroring out.
            logger.debug("Rate limit check failed", exc_info=True)
            return await call_next(request)

View File

@@ -0,0 +1,41 @@
"""Security headers middleware.
Adds standard security headers to all API responses:
- X-Content-Type-Options: nosniff
- X-Frame-Options: DENY
- X-XSS-Protection: 0 (disabled in favour of CSP)
- Referrer-Policy: strict-origin-when-cross-origin
- Content-Security-Policy: default-src 'none'
- Strict-Transport-Security (HSTS) in production
"""
from __future__ import annotations
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach standard security headers to every response."""

    # Headers applied unconditionally, in order.
    _STATIC_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        # Legacy XSS auditor explicitly disabled; CSP is the mechanism.
        ("X-XSS-Protection", "0"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        # The API serves no documents, so CSP is locked all the way down.
        ("Content-Security-Policy", "default-src 'none'"),
    )

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        response = await call_next(request)
        for header_name, header_value in self._STATIC_HEADERS:
            response.headers[header_name] = header_value
        # HSTS only makes sense over TLS. NOTE(review): with upstream TLS
        # termination, scheme is "https" only if proxy headers are honoured
        # by the server — confirm deployment config.
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = (
                "max-age=63072000; includeSubDomains; preload"
            )
        return response

View File

@@ -0,0 +1,31 @@
from src.models.base import Base
from src.models.consent import ConsentRecord
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
from src.models.org_config import OrgConfig
from src.models.organisation import Organisation
from src.models.scan import ScanJob, ScanResult
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.models.site_group import SiteGroup
from src.models.site_group_config import SiteGroupConfig
from src.models.translation import Translation
from src.models.user import User
# Explicit public API of the models package; importing this module also
# ensures every mapped class is registered on ``Base.metadata``.
__all__ = [
    "Base",
    "ConsentRecord",
    "Cookie",
    "CookieAllowListEntry",
    "CookieCategory",
    "KnownCookie",
    "OrgConfig",
    "Organisation",
    "ScanJob",
    "ScanResult",
    "Site",
    "SiteConfig",
    "SiteGroup",
    "SiteGroupConfig",
    "Translation",
    "User",
]

View File

@@ -0,0 +1,48 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
    """Declarative base shared by every ORM model.

    All mapped classes inherit from this so their tables accumulate on a
    single ``Base.metadata``.
    """

    pass
class TimestampMixin:
    """Mixin that adds created_at and updated_at columns.

    Both default to the database's ``now()``; ``updated_at`` is also
    refreshed on every UPDATE via ``onupdate``.
    """

    # Set once when the row is inserted (server default).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Refreshed on each UPDATE; initial value comes from the server default.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
class UUIDPrimaryKeyMixin:
    """Mixin that adds a UUID primary key.

    Keys are generated in Python via ``uuid.uuid4`` (column ``default``),
    not by the database.
    """

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
class SoftDeleteMixin:
    """Mixin that adds soft delete support.

    A row counts as deleted when ``deleted_at`` is non-NULL; queries must
    filter on it explicitly (e.g. ``deleted_at.is_(None)``).
    """

    deleted_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=None,
    )

View File

@@ -0,0 +1,81 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from src.models.base import Base, UUIDPrimaryKeyMixin
class ConsentRecord(UUIDPrimaryKeyMixin, Base):
    """Audit trail of every consent event. Partitioned by month for performance.

    Append-only record of banner interactions. NOTE(review): monthly
    partitioning is not declared in this mapping — presumably applied in
    the migration layer; confirm.
    """

    __tablename__ = "consent_records"
    __table_args__ = (
        # Composite index for the most common analytics query pattern:
        # "records for site X between dates A and B". The (site_id,
        # consented_at DESC) ordering also supports "latest consents
        # for site X" without an extra sort.
        Index(
            "ix_consent_records_site_consented_at",
            "site_id",
            "consented_at",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Visitor identification (anonymous) — IP/UA are stored only as hashes.
    visitor_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    ip_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
    user_agent_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
    # Consent details — `action` is a short verb string (values not
    # constrained at the schema level).
    action: Mapped[str] = mapped_column(String(30), nullable=False)
    categories_accepted: Mapped[list] = mapped_column(JSONB, nullable=False)
    categories_rejected: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # TCF consent string captured verbatim.
    tc_string: Mapped[str | None] = mapped_column(Text, nullable=True)
    # GCM state at time of consent
    gcm_state: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # GPP consent string captured verbatim.
    gpp_string: Mapped[str | None] = mapped_column(Text, nullable=True)
    # GPC — whether the signal was seen and whether it was acted upon.
    gpc_detected: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_honoured: Mapped[bool | None] = mapped_column(nullable=True)
    # A/B testing — soft references to EE `ab_tests` / `ab_test_variants`
    # tables. Intentionally *no* FK constraint so the core schema works
    # without the EE extension installed.
    ab_test_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
        index=True,
    )
    ab_variant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )
    # Context
    page_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    country_code: Mapped[str | None] = mapped_column(String(5), nullable=True)
    region_code: Mapped[str | None] = mapped_column(String(10), nullable=True)
    # Timestamp — assigned by the database on insert.
    consented_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
    )

View File

@@ -0,0 +1,130 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class CookieCategory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Cookie category taxonomy (necessary, functional, analytics, marketing, personalisation)."""

    __tablename__ = "cookie_categories"

    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Essential categories cannot be rejected by visitors (presumably
    # enforced at the service/banner layer — not at the schema level).
    is_essential: Mapped[bool] = mapped_column(default=False, nullable=False)
    display_order: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    # TCF purpose mapping
    tcf_purpose_ids: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # Google Consent Mode consent type mapping
    gcm_consent_types: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # Relationships
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="category")
    allow_list_entries: Mapped[list["CookieAllowListEntry"]] = relationship(
        back_populates="category"
    )
class Cookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie discovered on a site via scanning or client-side reporting.

    Uniqueness is per (site, name, domain, storage_type) so the same
    cookie name may exist for different storage mechanisms or domains.
    """

    __tablename__ = "cookies"
    __table_args__ = (
        UniqueConstraint(
            "site_id",
            "name",
            "domain",
            "storage_type",
            name="uq_cookies_site_name_domain_type",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Category is nullable: unclassified cookies keep NULL, and deleting
    # a category detaches (SET NULL) rather than deleting cookies.
    category_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain: Mapped[str] = mapped_column(String(255), nullable=False)
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    path: Mapped[str | None] = mapped_column(String(500), nullable=True)
    max_age_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
    is_http_only: Mapped[bool | None] = mapped_column(nullable=True)
    is_secure: Mapped[bool | None] = mapped_column(nullable=True)
    same_site: Mapped[str | None] = mapped_column(String(10), nullable=True)
    review_status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
    # NOTE(review): stored as strings, not DateTime — callers apparently
    # write preformatted timestamps; consider migrating to DateTime.
    first_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)
    last_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookies")  # noqa: F821
    category: Mapped["CookieCategory | None"] = relationship(back_populates="cookies")
class CookieAllowListEntry(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Approved cookies per site with category assignment.

    Entries match cookies by name/domain patterns; one entry per
    (site, name_pattern, domain_pattern).
    """

    __tablename__ = "cookie_allow_list"
    __table_args__ = (
        UniqueConstraint(
            "site_id",
            "name_pattern",
            "domain_pattern",
            name="uq_allow_list_site_name_domain",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # RESTRICT: a category in use by allow-list entries cannot be deleted.
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookie_allow_list")  # noqa: F821
    category: Mapped["CookieCategory"] = relationship(back_populates="allow_list_entries")
class KnownCookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Shared knowledge base of known cookie patterns for auto-categorisation.

    Global (not site-scoped): scanners consult these patterns to assign a
    category to newly discovered cookies.
    """

    __tablename__ = "known_cookies"
    __table_args__ = (
        UniqueConstraint("name_pattern", "domain_pattern", name="uq_known_cookies_name_domain"),
    )

    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    # RESTRICT: categories referenced here cannot be deleted.
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # When True the patterns are regexes; otherwise presumably literal/glob
    # matching — confirm against the matcher implementation.
    is_regex: Mapped[bool] = mapped_column(default=False, nullable=False)

View File

@@ -0,0 +1,64 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class OrgConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Organisation-level default configuration.

    These defaults sit between system defaults and site config in the cascade:
    System Defaults → Org Config → Site Group Config → Site Config → Regional Overrides

    Every setting is nullable: NULL means "not set at this level — fall
    through to the next level of the cascade".
    """

    __tablename__ = "org_configs"

    # One config row per organisation (unique FK).
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )
    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # TCF
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)
    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)
    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Consent
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Relationship
    organisation: Mapped["Organisation"] = relationship(back_populates="org_config")  # noqa: F821

View File

@@ -0,0 +1,26 @@
from sqlalchemy import String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class Organisation(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """Multi-tenant root entity. Each organisation has multiple sites and users.

    Soft-deletable; ``slug`` is the globally unique URL-friendly identifier.
    """

    __tablename__ = "organisations"

    name: Mapped[str] = mapped_column(String(255), nullable=False)
    slug: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
    contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
    billing_plan: Mapped[str] = mapped_column(String(50), server_default="free", nullable=False)
    notes: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Relationships
    users: Mapped[list["User"]] = relationship(back_populates="organisation")  # noqa: F821
    sites: Mapped[list["Site"]] = relationship(back_populates="organisation")  # noqa: F821
    site_groups: Mapped[list["SiteGroup"]] = relationship(  # noqa: F821
        back_populates="organisation"
    )
    # One-to-one: optional org-level config defaults.
    org_config: Mapped["OrgConfig | None"] = relationship(  # noqa: F821
        back_populates="organisation", uselist=False
    )

View File

@@ -0,0 +1,68 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class ScanJob(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie scanning job for a site.

    Tracks lifecycle (status/started/completed) and progress counters;
    individual findings live in :class:`ScanResult`.
    """

    __tablename__ = "scan_jobs"

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Short status string (e.g. starts as "pending"); indexed for queue-style
    # lookups.
    status: Mapped[str] = mapped_column(
        String(20), server_default="pending", nullable=False, index=True
    )
    # What started the job — defaults to "manual".
    trigger: Mapped[str] = mapped_column(String(20), server_default="manual", nullable=False)
    pages_scanned: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    pages_total: Mapped[int | None] = mapped_column(Integer, nullable=True)
    cookies_found: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="scan_jobs")  # noqa: F821
    results: Mapped[list["ScanResult"]] = relationship(back_populates="scan_job")
class ScanResult(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Individual result from a scan — a cookie found on a specific page."""

    __tablename__ = "scan_results"

    scan_job_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("scan_jobs.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    page_url: Mapped[str] = mapped_column(Text, nullable=False)
    cookie_name: Mapped[str] = mapped_column(String(255), nullable=False)
    cookie_domain: Mapped[str] = mapped_column(String(255), nullable=False)
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    # Raw cookie attributes as captured by the scanner (shape not fixed here).
    attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    script_source: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Category suggested automatically (e.g. from the known-cookies DB).
    auto_category: Mapped[str | None] = mapped_column(String(50), nullable=True)
    initiator_chain: Mapped[list[str] | None] = mapped_column(
        ARRAY(Text), nullable=True, comment="Ordered script URLs from root initiator to leaf"
    )
    # Set by the database when the result row is inserted.
    found_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    # Relationships
    scan_job: Mapped["ScanJob"] = relationship(back_populates="results")

View File

@@ -0,0 +1,48 @@
import uuid
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import ARRAY, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class Site(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A domain being managed for cookie consent, belongs to an organisation.

    A domain may be registered only once per organisation
    (``uq_sites_org_domain``) but can repeat across organisations.
    """

    __tablename__ = "sites"
    __table_args__ = (UniqueConstraint("organisation_id", "domain", name="uq_sites_org_domain"),)

    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    domain: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    display_name: Mapped[str] = mapped_column(String(255), nullable=False)
    is_active: Mapped[bool] = mapped_column(default=True, nullable=False)
    # Extra hostnames served under the same consent configuration.
    additional_domains: Mapped[list[str] | None] = mapped_column(
        ARRAY(String(255)), nullable=True, server_default=None
    )
    # Optional grouping; deleting the group detaches the site (SET NULL).
    site_group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="sites")  # noqa: F821
    site_group: Mapped["SiteGroup | None"] = relationship(back_populates="sites")  # noqa: F821
    # One-to-one: optional site-level config.
    config: Mapped["SiteConfig | None"] = relationship(  # noqa: F821
        back_populates="site", uselist=False
    )
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="site")  # noqa: F821
    cookie_allow_list: Mapped[list["CookieAllowListEntry"]] = relationship(  # noqa: F821
        back_populates="site"
    )
    scan_jobs: Mapped[list["ScanJob"]] = relationship(back_populates="site")  # noqa: F821
    translations: Mapped[list["Translation"]] = relationship(  # noqa: F821
        back_populates="site"
    )

View File

@@ -0,0 +1,63 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class SiteConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Full configuration for a site: blocking mode, TCF, GCM, banner, scanning, consent.

    Unlike the org/group configs, columns here are mostly non-nullable
    with concrete defaults — this is the terminal level of the cascade
    (before regional overrides).
    """

    __tablename__ = "site_configs"

    # One config row per site (unique FK).
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )
    # Blocking mode
    blocking_mode: Mapped[str] = mapped_column(String(20), server_default="opt_in", nullable=False)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # TCF
    tcf_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)
    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Google Consent Mode
    gcm_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    display_mode: Mapped[str] = mapped_column(
        String(30), server_default="bottom_banner", nullable=False
    )
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int] = mapped_column(Integer, server_default="50", nullable=False)
    # Consent
    consent_expiry_days: Mapped[int] = mapped_column(Integer, server_default="365", nullable=False)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Relationship
    site: Mapped["Site"] = relationship(back_populates="config")  # noqa: F821

View File

@@ -0,0 +1,32 @@
import uuid
from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class SiteGroup(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A logical grouping of sites within an organisation (e.g. a brand).

    Group names are unique per organisation; soft-deletable.
    """

    __tablename__ = "site_groups"
    __table_args__ = (UniqueConstraint("organisation_id", "name", name="uq_site_groups_org_name"),)

    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Relationships
    organisation: Mapped["Organisation"] = relationship(  # noqa: F821
        back_populates="site_groups"
    )
    sites: Mapped[list["Site"]] = relationship(back_populates="site_group")  # noqa: F821
    # One-to-one: optional group-level config defaults.
    group_config: Mapped["SiteGroupConfig | None"] = relationship(  # noqa: F821
        back_populates="site_group", uselist=False
    )

View File

@@ -0,0 +1,63 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class SiteGroupConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Site-group-level default configuration.

    These defaults sit between org defaults and site config in the cascade:
    System Defaults -> Org Config -> Site Group Config -> Site Config -> Regional Overrides

    All settings are nullable: NULL means "not set here — fall through to
    the next cascade level". NOTE(review): unlike OrgConfig, this level
    has no ``consent_retention_days`` — confirm that is intentional.
    """

    __tablename__ = "site_group_configs"

    # One config row per site group (unique FK).
    site_group_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )
    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # TCF
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)
    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)
    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Consent
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Relationship
    site_group: Mapped["SiteGroup"] = relationship(back_populates="group_config")  # noqa: F821

View File

@@ -0,0 +1,26 @@
import uuid
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class Translation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Internationalisation strings per site per locale.

    One row per (site, locale); ``strings`` is a JSONB mapping of
    translation keys to localised text (key schema defined by the banner).
    """

    __tablename__ = "translations"
    __table_args__ = (UniqueConstraint("site_id", "locale", name="uq_translations_site_locale"),)

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # BCP-47-style locale tag, e.g. "en" or "en-GB" (max 10 chars).
    locale: Mapped[str] = mapped_column(String(10), nullable=False)
    strings: Mapped[dict] = mapped_column(JSONB, nullable=False)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="translations")  # noqa: F821

View File

@@ -0,0 +1,31 @@
import uuid
from sqlalchemy import ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class User(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """User account, scoped to an organisation with a role.

    Soft-deletable. NOTE(review): ``email`` is unique across ALL
    organisations, so one address cannot belong to two tenants — confirm
    this is intentional for the multi-tenant model.
    """

    __tablename__ = "users"

    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    # Only the hash is stored, never the plaintext password.
    password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
    full_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Role string used for RBAC; new accounts default to "viewer".
    role: Mapped[str] = mapped_column(
        String(20),
        nullable=False,
        server_default="viewer",
    )
    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="users")  # noqa: F821

View File

View File

@@ -0,0 +1,108 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from jose import JWTError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import get_settings
from src.db import get_db
from src.models.user import User
from src.schemas.auth import CurrentUser, LoginRequest, RefreshRequest, TokenResponse
from src.services.auth import (
create_access_token,
create_refresh_token,
decode_token,
verify_password,
)
from src.services.dependencies import get_current_user
# All auth endpoints mount under /auth (prefixed again with /api/v1 in main).
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
    """Verify email/password credentials and issue an access/refresh pair.

    Responds 401 with a deliberately generic message on any failure so the
    response body does not reveal whether the email exists.
    """
    lookup = await db.execute(
        select(User).where(User.email == body.email, User.deleted_at.is_(None))
    )
    account = lookup.scalar_one_or_none()
    # NOTE(review): when no account matches, verify_password is skipped, so
    # response timing may differ — consider a dummy-hash check to blunt
    # timing-based user enumeration.
    if account is None or not verify_password(body.password, account.password_hash):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )
    settings = get_settings()
    return TokenResponse(
        access_token=create_access_token(
            user_id=account.id,
            organisation_id=account.organisation_id,
            role=account.role,
            email=account.email,
        ),
        refresh_token=create_refresh_token(
            user_id=account.id,
            organisation_id=account.organisation_id,
        ),
        expires_in=settings.jwt_access_token_expire_minutes * 60,
    )
@router.post("/refresh", response_model=TokenResponse)
async def refresh(
    body: RefreshRequest,
    db: AsyncSession = Depends(get_db),
) -> TokenResponse:
    """Exchange a valid refresh token for a new access/refresh token pair.

    Raises:
        HTTPException: 401 when the token is invalid, expired, malformed,
            of the wrong type, or when the user no longer exists.
    """
    try:
        payload = decode_token(body.refresh_token)
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        ) from exc
    if payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token is not a refresh token",
        )
    # A structurally valid JWT may still carry a missing or malformed "sub"
    # claim; treat that as an invalid token (401) rather than letting the
    # KeyError/ValueError surface as a 500.
    try:
        user_id = uuid.UUID(payload["sub"])
    except (KeyError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        ) from exc
    result = await db.execute(select(User).where(User.id == user_id, User.deleted_at.is_(None)))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User no longer exists",
        )
    settings = get_settings()
    access_token = create_access_token(
        user_id=user.id,
        organisation_id=user.organisation_id,
        role=user.role,
        email=user.email,
    )
    new_refresh_token = create_refresh_token(
        user_id=user.id,
        organisation_id=user.organisation_id,
    )
    return TokenResponse(
        access_token=access_token,
        refresh_token=new_refresh_token,
        expires_in=settings.jwt_access_token_expire_minutes * 60,
    )
@router.get("/me", response_model=CurrentUser)
async def get_me(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    """Echo back the authenticated caller's profile, as decoded from the JWT."""
    return current_user

View File

@@ -0,0 +1,135 @@
"""Compliance checking endpoints.
Evaluates a site's configuration against regulatory frameworks (GDPR, CNIL,
CCPA, ePrivacy, LGPD) and returns per-framework compliance reports with scores,
issues, and recommendations.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.cookie import Cookie
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.schemas.compliance import (
ComplianceCheckRequest,
ComplianceCheckResponse,
Framework,
)
from src.services.compliance import (
SiteContext,
calculate_overall_score,
run_compliance_check,
)
from src.services.dependencies import get_current_user
# Compliance-report endpoints: per-framework checks and framework listing.
router = APIRouter(prefix="/compliance", tags=["compliance"])
async def _build_site_context(
    site_id: uuid.UUID,
    db: AsyncSession,
) -> SiteContext:
    """Load site config and cookie stats to build a SiteContext.

    Args:
        site_id: Site whose configuration and cookie inventory to load.
        db: Async database session.

    Returns:
        A SiteContext populated from the (non-deleted) config row plus the
        cookie counts; when no config row exists, only the cookie counts
        are set and the other fields keep their SiteContext defaults.
    """
    # Fetch site config (soft-deleted rows are ignored)
    result = await db.execute(
        select(SiteConfig).where(
            SiteConfig.site_id == site_id,
            SiteConfig.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    # Fetch cookie statistics
    total_q = await db.execute(
        select(func.count()).select_from(Cookie).where(Cookie.site_id == site_id)
    )
    total_cookies = total_q.scalar() or 0
    uncat_q = await db.execute(
        select(func.count())
        .select_from(Cookie)
        .where(
            Cookie.site_id == site_id,
            Cookie.category_id.is_(None),
        )
    )
    uncategorised_cookies = uncat_q.scalar() or 0
    if config is None:
        return SiteContext(
            total_cookies=total_cookies,
            uncategorised_cookies=uncategorised_cookies,
        )
    # Normalise a NULL banner_config to an empty dict once, and use the
    # normalised value both for the dark-pattern flags and for the
    # banner_config field itself (previously a NULL column leaked None
    # into SiteContext.banner_config while the flags saw {}).
    banner_config = config.banner_config or {}
    return SiteContext(
        blocking_mode=config.blocking_mode,
        regional_modes=config.regional_modes,
        tcf_enabled=config.tcf_enabled,
        gcm_enabled=config.gcm_enabled,
        consent_expiry_days=config.consent_expiry_days,
        privacy_policy_url=config.privacy_policy_url,
        display_mode=config.display_mode,
        banner_config=banner_config,
        total_cookies=total_cookies,
        uncategorised_cookies=uncategorised_cookies,
        has_reject_button=banner_config.get("show_reject_all", True),
        has_granular_choices=banner_config.get("show_category_toggles", True),
        has_cookie_wall=banner_config.get("cookie_wall", False),
        pre_ticked_boxes=banner_config.get("pre_ticked", False),
    )
@router.post(
    "/check/{site_id}",
    response_model=ComplianceCheckResponse,
    status_code=status.HTTP_200_OK,
)
async def check_compliance(
    site_id: uuid.UUID,
    body: ComplianceCheckRequest | None = None,
    db: AsyncSession = Depends(get_db),
    _user=Depends(get_current_user),
) -> ComplianceCheckResponse:
    """Run compliance checks against a site's configuration.

    When the request body is omitted, ``frameworks=None`` is passed through
    to the compliance service.

    NOTE(review): this only verifies the site exists — it does not check
    that the site belongs to the caller's organisation, unlike the other
    authenticated endpoints. Confirm whether cross-tenant checks are
    intentional.

    Raises:
        HTTPException: 404 if the site does not exist or is soft-deleted.
    """
    # Verify site exists
    site_result = await db.execute(
        select(Site).where(Site.id == site_id, Site.deleted_at.is_(None))
    )
    site = site_result.scalar_one_or_none()
    if site is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    # Build the config + cookie-count context, then score each framework.
    ctx = await _build_site_context(site_id, db)
    frameworks = body.frameworks if body else None
    results = run_compliance_check(ctx, frameworks)
    overall_score = calculate_overall_score(results)
    return ComplianceCheckResponse(
        site_id=str(site_id),
        results=results,
        overall_score=overall_score,
    )
@router.get("/frameworks", response_model=list[dict])
async def list_frameworks() -> list[dict]:
    """List all available compliance frameworks."""
    # Human-readable description for each supported framework.
    descriptions = {
        Framework.GDPR: "EU General Data Protection Regulation",
        Framework.CNIL: "French Data Protection Authority (stricter GDPR)",
        Framework.CCPA: "California Consumer Privacy Act / CPRA",
        Framework.EPRIVACY: "EU ePrivacy Directive",
        Framework.LGPD: "Brazilian General Data Protection Law",
    }
    return [
        {"id": fw.value, "name": fw.value.upper(), "description": desc}
        for fw, desc in descriptions.items()
    ]

View File

@@ -0,0 +1,324 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.extensions.registry import get_registry
from src.models.org_config import OrgConfig
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.models.site_group_config import SiteGroupConfig
from src.schemas.auth import CurrentUser
from src.schemas.site import SiteConfigResponse
from src.services.config_resolver import (
CONFIG_FIELDS,
build_public_config,
orm_to_config_dict,
resolve_config,
)
from src.services.dependencies import require_role
from src.services.geoip import detect_region
from src.services.publisher import publish_site_config
# Config endpoints: public banner config, cascade resolution, publishing.
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/sites/{site_id}", response_model=SiteConfigResponse)
async def get_public_site_config(
    site_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Public endpoint: retrieve site config for the banner script. No auth required."""
    # Only active, non-deleted sites are served to the public.
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    config = (await db.execute(stmt)).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    return config
@router.get("/sites/{site_id}/resolved")
async def get_resolved_config(
    site_id: uuid.UUID,
    region: str | None = None,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: retrieve fully resolved config with regional overrides applied.
    Applies the full cascade: System → Org → Group → Site → Regional.

    Args:
        site_id: Target site; must be active and not soft-deleted.
        region: Optional region code forwarded to resolve_config for
            regional overrides.

    Raises:
        HTTPException: 404 when no matching site configuration exists.
    """
    result = await db.execute(
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    config_dict = orm_to_config_dict(config, include_id=True)
    # Load org defaults via the site
    org_id = await _get_site_org_id(site_id, db)
    org_defaults = await _load_org_defaults(org_id, db) if org_id else None
    # Load site group defaults
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None
    # Merge the cascade and hand the result to build_public_config for the response.
    resolved = resolve_config(
        config_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
        region=region,
    )
    return build_public_config(str(site_id), resolved)
@router.get("/sites/{site_id}/geo-resolved")
async def get_geo_resolved_config(
    site_id: uuid.UUID,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: resolve config using the visitor's detected region.
    Detects the visitor's region from CDN headers or IP geolocation,
    then applies regional blocking mode overrides automatically.
    Uses the full cascade: System → Org → Group → Site → Regional.

    Raises:
        HTTPException: 404 when no matching site configuration exists.
    """
    result = await db.execute(
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    # Detect region from request
    geo = await detect_region(request)
    config_dict = orm_to_config_dict(config, include_id=True)
    # Same cascade loading as /resolved, but the region comes from geo detection.
    org_id = await _get_site_org_id(site_id, db)
    org_defaults = await _load_org_defaults(org_id, db) if org_id else None
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None
    resolved = resolve_config(
        config_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
        region=geo.region,
    )
    public = build_public_config(str(site_id), resolved)
    # Include detected geo info so the banner can use it
    public["detected_country"] = geo.country_code
    public["detected_region"] = geo.region
    return public
@router.get("/geo")
async def get_visitor_geo(request: Request) -> dict:
    """Public endpoint: return the detected region for the current visitor.

    Useful for banner scripts that need to know the region before
    fetching the full config.
    """
    detected = await detect_region(request)
    return {"country_code": detected.country_code, "region": detected.region}
@router.get("/sites/{site_id}/inheritance")
async def get_config_inheritance(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return the full config inheritance chain for a site.

    Shows the value at each level so the UI can display where each setting
    comes from: system, org, group, or site.

    Raises:
        HTTPException: 404 if the site config does not exist or the site
            is outside the caller's organisation.
    """
    from src.services.config_resolver import SYSTEM_DEFAULTS
    result = await db.execute(
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    site_dict = orm_to_config_dict(config)
    org_defaults = await _load_org_defaults(current_user.organisation_id, db)
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None
    resolved = resolve_config(
        site_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
    )
    # For each config field, determine the source: the highest-priority
    # level with a non-None value wins (site → group → org → system).
    sources: dict[str, dict] = {}
    for field in CONFIG_FIELDS:
        site_val = site_dict.get(field)
        group_val = group_defaults.get(field) if group_defaults else None
        org_val = org_defaults.get(field) if org_defaults else None
        system_val = SYSTEM_DEFAULTS.get(field)
        if site_val is not None:
            source = "site"
        elif group_val is not None:
            source = "group"
        elif org_val is not None:
            source = "org"
        else:
            # Fields absent from SYSTEM_DEFAULTS still report "system";
            # the previous `elif system_val is not None` branch was
            # redundant since both arms assigned the same value.
            source = "system"
        sources[field] = {
            "resolved_value": resolved.get(field),
            "source": source,
            "site_value": site_val,
            "group_value": group_val,
            "org_value": org_val,
            "system_value": system_val,
        }
    return {
        "site_id": str(site_id),
        "site_group_id": str(group_id) if group_id else None,
        "fields": sources,
    }
@router.post("/sites/{site_id}/publish")
async def publish_config(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Publish fully-resolved site config to CDN. Requires admin role.

    Resolves the org → group → site cascade, lets registered extensions
    enrich the payload, then hands it to the publisher service.

    Raises:
        HTTPException: 404 if the site config is missing or outside the
            caller's organisation; 500 if publishing fails.
    """
    result = await db.execute(
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    config_dict = orm_to_config_dict(config, include_id=True)
    org_defaults = await _load_org_defaults(current_user.organisation_id, db)
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None
    resolved = resolve_config(
        config_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
    )
    # Allow extensions to enrich the published config (e.g. A/B test data)
    registry = get_registry()
    for enricher in registry.config_enrichers:
        await enricher(site_id, db, resolved)
    publish_result = await publish_site_config(str(site_id), resolved)
    if not publish_result.success:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Publish failed: {publish_result.error}",
        )
    return {
        "published": True,
        "path": publish_result.path,
        "published_at": publish_result.published_at,
    }
# ── Helpers ──────────────────────────────────────────────────────────
async def _get_site_org_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Look up the organisation_id for a site."""
    row = await db.execute(select(Site.organisation_id).where(Site.id == site_id))
    return row.scalar_one_or_none()
async def _get_site_group_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Look up the site_group_id for a site."""
    row = await db.execute(select(Site.site_group_id).where(Site.id == site_id))
    return row.scalar_one_or_none()
async def _load_org_defaults(organisation_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Load the org-level config defaults, or None if not set."""
    stmt = select(OrgConfig).where(OrgConfig.organisation_id == organisation_id)
    org_config = (await db.execute(stmt)).scalar_one_or_none()
    return None if org_config is None else orm_to_config_dict(org_config)
async def _load_group_defaults(group_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Load the site-group-level config defaults, or None if not set."""
    stmt = select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    group_config = (await db.execute(stmt)).scalar_one_or_none()
    return None if group_config is None else orm_to_config_dict(group_config)

View File

@@ -0,0 +1,125 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.extensions.registry import get_registry
from src.models.consent import ConsentRecord
from src.models.site import Site
from src.schemas.auth import CurrentUser
from src.schemas.consent import (
ConsentRecordCreate,
ConsentRecordResponse,
ConsentVerifyResponse,
)
from src.services.dependencies import require_role
from src.services.pseudonymisation import pseudonymise
# Consent-record endpoints: public banner recording, tenant-scoped reads.
router = APIRouter(prefix="/consent", tags=["consent"])
@router.post("/", response_model=ConsentRecordResponse, status_code=status.HTTP_201_CREATED)
async def record_consent(
    body: ConsentRecordCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Record a consent event from the banner. Public endpoint (no auth required).

    NOTE(review): body.site_id is stored as-is with no check that the site
    exists or is active — confirm whether unknown site IDs should be
    rejected here.
    """
    # Pseudonymise IP and user agent with HMAC so the resulting values
    # cannot be reversed without the server-side secret.
    client_ip = request.client.host if request.client else ""
    user_agent = request.headers.get("user-agent", "")
    record = ConsentRecord(
        site_id=body.site_id,
        visitor_id=body.visitor_id,
        ip_hash=pseudonymise(client_ip),
        user_agent_hash=pseudonymise(user_agent),
        action=body.action,
        categories_accepted=body.categories_accepted,
        categories_rejected=body.categories_rejected,
        tc_string=body.tc_string,
        gcm_state=body.gcm_state,
        page_url=body.page_url,
        country_code=body.country_code,
        region_code=body.region_code,
    )
    db.add(record)
    # Flush + refresh so DB-generated fields are populated before serialisation.
    await db.flush()
    await db.refresh(record)
    # Invoke any registered post-record hooks (EE consent receipts, etc.)
    for hook in get_registry().consent_record_hooks:
        await hook(db, record)
    return record
async def _load_record_for_org(
    consent_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> ConsentRecord:
    """Load a consent record and enforce tenant isolation.

    The record's site must belong to the caller's organisation. A record
    from another tenant returns 404 rather than 403 so we don't leak
    existence across tenants.
    """
    query = (
        select(ConsentRecord)
        .join(Site, Site.id == ConsentRecord.site_id)
        .where(
            ConsentRecord.id == consent_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    result = await db.execute(query)
    record = result.scalar_one_or_none()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Consent record not found",
        )
    return record
@router.get("/{consent_id}", response_model=ConsentRecordResponse)
async def get_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Fetch a single consent record by ID.

    Authentication and tenant membership are enforced via
    :func:`_load_record_for_org`; records contain PII-adjacent data
    (hashed IP, page URL, category decisions) and must not be readable
    by anyone merely holding a record UUID.
    """
    return await _load_record_for_org(consent_id, current_user, db)
@router.get("/verify/{consent_id}", response_model=ConsentVerifyResponse)
async def verify_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Verify that a consent record exists (audit proof).

    Same tenant-scoped auth as :func:`get_consent` — proof of consent
    is only meaningful to the organisation that owns the site, and
    leaking existence to arbitrary callers enables enumeration.
    """
    rec = await _load_record_for_org(consent_id, current_user, db)
    proof = {
        "id": rec.id,
        "site_id": rec.site_id,
        "visitor_id": rec.visitor_id,
        "action": rec.action,
        "categories_accepted": rec.categories_accepted,
        "consented_at": rec.consented_at,
    }
    proof["valid"] = True
    return proof

View File

@@ -0,0 +1,582 @@
"""Cookie category, cookie, and allow-list management endpoints."""
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
from src.models.site import Site
from src.schemas.auth import CurrentUser
from src.schemas.cookie import (
AllowListEntryCreate,
AllowListEntryResponse,
AllowListEntryUpdate,
ClassificationResultResponse,
ClassifySingleRequest,
ClassifySiteResponse,
CookieCategoryResponse,
CookieCreate,
CookieResponse,
CookieUpdate,
KnownCookieCreate,
KnownCookieResponse,
KnownCookieUpdate,
ReviewStatus,
)
from src.services.classification import classify_single_cookie, classify_site_cookies
from src.services.dependencies import get_current_user, require_role
# Cookie endpoints: categories, per-site inventory, allow-lists, known DB.
router = APIRouter(prefix="/cookies", tags=["cookies"])
# ── Cookie categories (read-only, seeded by migration) ──────────────
@router.get("/categories", response_model=list[CookieCategoryResponse])
async def list_categories(
    db: AsyncSession = Depends(get_db),
) -> list[CookieCategory]:
    """List all cookie categories. Public endpoint used by banner and admin."""
    stmt = select(CookieCategory).order_by(CookieCategory.display_order)
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
@router.get("/categories/{category_id}", response_model=CookieCategoryResponse)
async def get_category(
    category_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> CookieCategory:
    """Get a single cookie category by ID."""
    stmt = select(CookieCategory).where(CookieCategory.id == category_id)
    category = (await db.execute(stmt)).scalar_one_or_none()
    if category is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category not found")
    return category
# ── Cookies per site ─────────────────────────────────────────────────
async def _get_org_site(
    site_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Fetch a site ensuring it belongs to the user's organisation."""
    stmt = select(Site).where(
        Site.id == site_id,
        Site.organisation_id == current_user.organisation_id,
        Site.deleted_at.is_(None),
    )
    site = (await db.execute(stmt)).scalar_one_or_none()
    if site is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site
@router.get(
    "/sites/{site_id}",
    response_model=list[CookieResponse],
)
async def list_cookies(
    site_id: uuid.UUID,
    review_status: ReviewStatus | None = Query(None),
    category_id: uuid.UUID | None = Query(None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[Cookie]:
    """List cookies discovered on a site, with optional filters."""
    await _get_org_site(site_id, current_user, db)
    stmt = select(Cookie).where(Cookie.site_id == site_id)
    # Both filters are optional and combine with AND when given.
    if review_status:
        stmt = stmt.where(Cookie.review_status == review_status.value)
    if category_id:
        stmt = stmt.where(Cookie.category_id == category_id)
    rows = await db.execute(stmt.order_by(Cookie.name))
    return list(rows.scalars().all())
@router.post(
    "/sites/{site_id}",
    response_model=CookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_cookie(
    site_id: uuid.UUID,
    body: CookieCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Create a cookie record for a site (manual entry or from scanner).

    Raises:
        HTTPException: 404 if the site is not in the caller's
            organisation; 400 if the supplied category_id does not exist.
    """
    await _get_org_site(site_id, current_user, db)
    # Validate category if provided
    if body.category_id:
        cat = await db.execute(select(CookieCategory).where(CookieCategory.id == body.category_id))
        if not cat.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    # Capture a single timestamp so first_seen_at and last_seen_at are
    # exactly equal on creation — two separate now() calls could differ
    # by microseconds.
    now = datetime.now(UTC).isoformat()
    cookie = Cookie(
        site_id=site_id,
        **body.model_dump(),
        first_seen_at=now,
        last_seen_at=now,
    )
    db.add(cookie)
    await db.flush()
    await db.refresh(cookie)
    return cookie
@router.get("/sites/{site_id}/summary")
async def cookie_summary(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Get a summary of cookies for a site (counts by status and category).

    Returns:
        dict with "total", "by_status" (review_status → count),
        "by_category" (category slug → count), and "uncategorised".
    """
    await _get_org_site(site_id, current_user, db)
    # Count by review status
    status_result = await db.execute(
        select(Cookie.review_status, func.count(Cookie.id))
        .where(Cookie.site_id == site_id)
        .group_by(Cookie.review_status)
    )
    by_status = {row[0]: row[1] for row in status_result.all()}
    # Count by category
    # NOTE(review): the WHERE on Cookie.site_id effectively turns this
    # outer join into an inner join, so categories with zero cookies on
    # the site are omitted — confirm that is the intended shape.
    cat_result = await db.execute(
        select(CookieCategory.slug, func.count(Cookie.id))
        .outerjoin(Cookie, Cookie.category_id == CookieCategory.id)
        .where(Cookie.site_id == site_id)
        .group_by(CookieCategory.slug)
    )
    by_category = {row[0]: row[1] for row in cat_result.all()}
    # Uncategorised count
    uncat_result = await db.execute(
        select(func.count(Cookie.id)).where(Cookie.site_id == site_id, Cookie.category_id.is_(None))
    )
    uncategorised = uncat_result.scalar() or 0
    # "total" is derived from the status buckets rather than a fourth query.
    return {
        "total": sum(by_status.values()),
        "by_status": by_status,
        "by_category": by_category,
        "uncategorised": uncategorised,
    }
# ── Allow-list per site ──────────────────────────────────────────────
# (Must be defined before {cookie_id} routes to avoid path conflicts)
@router.get(
    "/sites/{site_id}/allow-list",
    response_model=list[AllowListEntryResponse],
)
async def list_allow_list(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[CookieAllowListEntry]:
    """List all allow-list entries for a site."""
    await _get_org_site(site_id, current_user, db)
    stmt = (
        select(CookieAllowListEntry)
        .where(CookieAllowListEntry.site_id == site_id)
        .order_by(CookieAllowListEntry.name_pattern)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
@router.post(
    "/sites/{site_id}/allow-list",
    response_model=AllowListEntryResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_allow_list_entry(
    site_id: uuid.UUID,
    body: AllowListEntryCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Add a cookie pattern to the allow-list for a site."""
    await _get_org_site(site_id, current_user, db)
    # The referenced category must exist before the entry is created.
    category_check = await db.execute(
        select(CookieCategory).where(CookieCategory.id == body.category_id)
    )
    if category_check.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )
    entry = CookieAllowListEntry(site_id=site_id, **body.model_dump())
    db.add(entry)
    await db.flush()
    await db.refresh(entry)
    return entry
@router.patch(
    "/sites/{site_id}/allow-list/{entry_id}",
    response_model=AllowListEntryResponse,
)
async def update_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    body: AllowListEntryUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Update an allow-list entry.

    Only fields present in the request body are applied (exclude_unset).

    Raises:
        HTTPException: 404 for unknown site/entry, 400 for an unknown
            category_id.
    """
    await _get_org_site(site_id, current_user, db)
    result = await db.execute(
        select(CookieAllowListEntry).where(
            CookieAllowListEntry.id == entry_id,
            CookieAllowListEntry.site_id == site_id,
        )
    )
    entry = result.scalar_one_or_none()
    if not entry:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )
    updates = body.model_dump(exclude_unset=True)
    # Validate a changed category before applying any updates.
    if "category_id" in updates and updates["category_id"] is not None:
        cat = await db.execute(
            select(CookieCategory).where(CookieCategory.id == updates["category_id"])
        )
        if not cat.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    for field, value in updates.items():
        setattr(entry, field, value)
    # Stamp updated_at explicitly with the current UTC time.
    entry.updated_at = datetime.now(UTC)
    await db.flush()
    await db.refresh(entry)
    return entry
@router.delete(
    "/sites/{site_id}/allow-list/{entry_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove an entry from the allow-list."""
    await _get_org_site(site_id, current_user, db)
    stmt = select(CookieAllowListEntry).where(
        CookieAllowListEntry.id == entry_id,
        CookieAllowListEntry.site_id == site_id,
    )
    entry = (await db.execute(stmt)).scalar_one_or_none()
    if entry is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )
    await db.delete(entry)
# ── Individual cookie by ID (must come after /summary and /allow-list) ──
@router.get("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
async def get_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Get a single cookie by ID."""
    await _get_org_site(site_id, current_user, db)
    stmt = select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
    found = (await db.execute(stmt)).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
    return found
@router.patch("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
async def update_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    body: CookieUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Update a cookie record (e.g. assign category, change review status).

    Only fields present in the request body are applied (exclude_unset).

    Raises:
        HTTPException: 404 for unknown site/cookie, 400 for an unknown
            category_id.
    """
    await _get_org_site(site_id, current_user, db)
    result = await db.execute(
        select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
    )
    cookie = result.scalar_one_or_none()
    if not cookie:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
    updates = body.model_dump(exclude_unset=True)
    # Validate category if being changed
    if "category_id" in updates and updates["category_id"] is not None:
        cat = await db.execute(
            select(CookieCategory).where(CookieCategory.id == updates["category_id"])
        )
        if not cat.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    for field, value in updates.items():
        setattr(cookie, field, value)
    # Stamp updated_at explicitly with the current UTC time.
    cookie.updated_at = datetime.now(UTC)
    await db.flush()
    await db.refresh(cookie)
    return cookie
@router.delete(
    "/sites/{site_id}/{cookie_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a cookie record."""
    await _get_org_site(site_id, current_user, db)
    stmt = select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
    cookie = (await db.execute(stmt)).scalar_one_or_none()
    if cookie is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
    await db.delete(cookie)
# ── Known cookies database ──────────────────────────────────────────
@router.get("/known", response_model=list[KnownCookieResponse])
async def list_known_cookies(
    vendor: str | None = Query(None, description="Filter by vendor name"),
    search: str | None = Query(None, description="Search by name pattern"),
    db: AsyncSession = Depends(get_db),
    _user: CurrentUser = Depends(get_current_user),
) -> list[KnownCookie]:
    """List known cookie patterns from the shared database."""
    stmt = select(KnownCookie).order_by(KnownCookie.name_pattern)
    # Optional exact-vendor and substring-name filters.
    if vendor:
        stmt = stmt.where(KnownCookie.vendor == vendor)
    if search:
        stmt = stmt.where(KnownCookie.name_pattern.ilike(f"%{search}%"))
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
@router.post(
    "/known",
    response_model=KnownCookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_known_cookie(
    body: KnownCookieCreate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Add a new pattern to the known cookies database."""
    # The target category must exist.
    category_check = await db.execute(
        select(CookieCategory).where(CookieCategory.id == body.category_id)
    )
    if category_check.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )
    known = KnownCookie(**body.model_dump())
    db.add(known)
    await db.flush()
    await db.refresh(known)
    return known
@router.get("/known/{known_id}", response_model=KnownCookieResponse)
async def get_known_cookie(
    known_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    _user: CurrentUser = Depends(get_current_user),
) -> KnownCookie:
    """Get a single known cookie pattern by ID."""
    stmt = select(KnownCookie).where(KnownCookie.id == known_id)
    known = (await db.execute(stmt)).scalar_one_or_none()
    if known is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    return known
@router.patch("/known/{known_id}", response_model=KnownCookieResponse)
async def update_known_cookie(
    known_id: uuid.UUID,
    body: KnownCookieUpdate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Apply a partial update to a known cookie pattern.

    Validates a new ``category_id`` (when supplied) before writing, and
    stamps ``updated_at`` on success.
    """
    lookup = await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    record = lookup.scalar_one_or_none()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    changes = body.model_dump(exclude_unset=True)
    # Only re-validate the category when the caller actually changes it.
    new_category_id = changes.get("category_id")
    if new_category_id is not None:
        category_lookup = await db.execute(
            select(CookieCategory).where(CookieCategory.id == new_category_id)
        )
        if category_lookup.scalar_one_or_none() is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    for attr, value in changes.items():
        setattr(record, attr, value)
    record.updated_at = datetime.now(UTC)
    await db.flush()
    await db.refresh(record)
    return record
@router.delete(
    "/known/{known_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_known_cookie(
    known_id: uuid.UUID,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a known cookie pattern, or 404 if it does not exist."""
    lookup = await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    record = lookup.scalar_one_or_none()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    await db.delete(record)
# ── Classification endpoints ────────────────────────────────────────
@router.post(
    "/sites/{site_id}/classify",
    response_model=ClassifySiteResponse,
)
async def classify_cookies(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> ClassifySiteResponse:
    """Auto-classify pending cookies for a site against known patterns.

    Returns per-cookie classification outcomes plus matched/unmatched
    totals for the summary header in the UI.
    """
    await _get_org_site(site_id, current_user, db)
    outcomes = await classify_site_cookies(db, site_id, only_pending=True)
    hit_count = sum(1 for outcome in outcomes if outcome.matched)
    items = [
        ClassificationResultResponse(
            cookie_name=outcome.cookie_name,
            cookie_domain=outcome.cookie_domain,
            category_id=outcome.category_id,
            category_slug=outcome.category_slug,
            vendor=outcome.vendor,
            description=outcome.description,
            match_source=outcome.match_source,
            matched=outcome.matched,
        )
        for outcome in outcomes
    ]
    return ClassifySiteResponse(
        site_id=str(site_id),
        total=len(outcomes),
        matched=hit_count,
        unmatched=len(outcomes) - hit_count,
        results=items,
    )
@router.post(
    "/sites/{site_id}/classify/preview",
    response_model=ClassificationResultResponse,
)
async def classify_preview(
    site_id: uuid.UUID,
    body: ClassifySingleRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> ClassificationResultResponse:
    """Preview how a single cookie would be classified, without saving."""
    await _get_org_site(site_id, current_user, db)
    outcome = await classify_single_cookie(db, site_id, body.cookie_name, body.cookie_domain)
    return ClassificationResultResponse(
        cookie_name=outcome.cookie_name,
        cookie_domain=outcome.cookie_domain,
        category_id=outcome.category_id,
        category_slug=outcome.category_slug,
        vendor=outcome.vendor,
        description=outcome.description,
        match_source=outcome.match_source,
        matched=outcome.matched,
    )

View File

@@ -0,0 +1,69 @@
"""Organisation-level default configuration endpoints.
Provides GET and PUT for the organisation's global config defaults.
These defaults sit between system defaults and site config in the cascade.
"""
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.org_config import OrgConfig
from src.schemas.auth import CurrentUser
from src.schemas.org_config import OrgConfigResponse, OrgConfigUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/org-config", tags=["organisations"])
@router.get("/", response_model=OrgConfigResponse)
async def get_org_config(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> OrgConfig:
    """Retrieve the organisation's global configuration defaults.

    Lazily creates an empty config row on first access so callers never
    see a 404 for a missing config.
    """
    org_id = current_user.organisation_id
    lookup = await db.execute(select(OrgConfig).where(OrgConfig.organisation_id == org_id))
    config = lookup.scalar_one_or_none()
    if config is not None:
        return config
    # First read for this organisation — materialise an empty row.
    config = OrgConfig(organisation_id=org_id)
    db.add(config)
    await db.flush()
    await db.refresh(config)
    return config
@router.put("/", response_model=OrgConfigResponse)
async def update_org_config(
    body: OrgConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> OrgConfig:
    """Create or update the organisation's global configuration defaults.

    Only non-None fields will override system defaults when resolving
    site config.
    """
    org_id = current_user.organisation_id
    lookup = await db.execute(select(OrgConfig).where(OrgConfig.organisation_id == org_id))
    config = lookup.scalar_one_or_none()
    changes = body.model_dump(exclude_unset=True)
    if config is None:
        # No row yet — create one seeded from the supplied fields.
        config = OrgConfig(organisation_id=org_id, **changes)
        db.add(config)
    else:
        for attr, value in changes.items():
            setattr(config, attr, value)
    await db.flush()
    await db.refresh(config)
    return config

View File

@@ -0,0 +1,118 @@
import hmac
from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import get_settings
from src.db import get_db
from src.models.organisation import Organisation
from src.schemas.auth import CurrentUser
from src.schemas.organisation import (
OrganisationCreate,
OrganisationResponse,
OrganisationUpdate,
)
from src.services.dependencies import require_role
router = APIRouter(prefix="/organisations", tags=["organisations"])
def _require_bootstrap_token(
    x_admin_bootstrap_token: str | None = Header(default=None),
) -> None:
    """Gate organisation creation behind a static bootstrap token.

    The token comes from ``ADMIN_BOOTSTRAP_TOKEN``. While it is unset
    (the default) the endpoint is disabled entirely — operators must
    explicitly opt in and should rotate or unset the value after their
    initial org is provisioned.

    Raises:
        HTTPException 403: no token configured (feature disabled).
        HTTPException 401: header missing or token mismatch.
    """
    configured = get_settings().admin_bootstrap_token
    if not configured:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                "Organisation creation is disabled. Set ADMIN_BOOTSTRAP_TOKEN "
                "in the environment to enable it."
            ),
        )
    # Constant-time comparison so the token can't be probed via timing.
    token_ok = bool(x_admin_bootstrap_token) and hmac.compare_digest(
        x_admin_bootstrap_token,
        configured,
    )
    if not token_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing admin bootstrap token",
        )
@router.post("/", response_model=OrganisationResponse, status_code=status.HTTP_201_CREATED)
async def create_organisation(
    body: OrganisationCreate,
    db: AsyncSession = Depends(get_db),
    _: None = Depends(_require_bootstrap_token),
) -> Organisation:
    """Create a new organisation. Gated by ``X-Admin-Bootstrap-Token``.

    See :func:`_require_bootstrap_token` for the gating semantics. Once
    your initial organisation exists, rotate or unset
    ``ADMIN_BOOTSTRAP_TOKEN`` to disable further tenant creation.
    """
    # Slugs are unique across all tenants — reject duplicates with 409.
    slug_lookup = await db.execute(select(Organisation).where(Organisation.slug == body.slug))
    if slug_lookup.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Organisation with slug '{body.slug}' already exists",
        )
    organisation = Organisation(**body.model_dump())
    db.add(organisation)
    await db.flush()
    await db.refresh(organisation)
    return organisation
@router.get("/me", response_model=OrganisationResponse)
async def get_my_organisation(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Organisation:
    """Return the caller's (non-deleted) organisation, or 404."""
    lookup = await db.execute(
        select(Organisation).where(
            Organisation.id == current_user.organisation_id,
            Organisation.deleted_at.is_(None),
        )
    )
    organisation = lookup.scalar_one_or_none()
    if organisation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")
    return organisation
@router.patch("/me", response_model=OrganisationResponse)
async def update_my_organisation(
    body: OrganisationUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> Organisation:
    """Partially update the caller's organisation (owner/admin only)."""
    lookup = await db.execute(
        select(Organisation).where(
            Organisation.id == current_user.organisation_id,
            Organisation.deleted_at.is_(None),
        )
    )
    organisation = lookup.scalar_one_or_none()
    if organisation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")
    for attr, value in body.model_dump(exclude_unset=True).items():
        setattr(organisation, attr, value)
    await db.flush()
    await db.refresh(organisation)
    return organisation

View File

@@ -0,0 +1,310 @@
"""Scanner and client-side cookie report endpoints.
Accepts cookie reports from the client-side reporter embedded in the banner
bundle, upserts discovered cookies into the site's cookie inventory, and
provides scan job management (trigger, list, detail, diff).
"""
import logging
import uuid
from datetime import UTC, datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.cookie import Cookie
from src.models.scan import ScanJob, ScanResult
from src.models.site import Site
from src.schemas.auth import CurrentUser
from src.schemas.scanner import (
CookieReportRequest,
CookieReportResponse,
ScanDiffResponse,
ScanJobDetailResponse,
ScanJobResponse,
TriggerScanRequest,
)
from src.services.dependencies import get_current_user
from src.services.scanner import (
compute_scan_diff,
create_scan_job,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/scanner", tags=["scanner"])
# ── Client-side cookie report (public, no auth) ─────────────────────
@router.post(
    "/report",
    response_model=CookieReportResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def receive_cookie_report(
    body: CookieReportRequest,
    db: AsyncSession = Depends(get_db),
) -> CookieReportResponse:
    """Receive a cookie report from the client-side reporter.

    This is a public endpoint (no auth) since it's called from the banner
    script running on end-user browsers. The site_id acts as implicit auth.

    Cookies already in the inventory (matched on name + domain +
    storage_type) get ``last_seen_at`` bumped; unknown ones are inserted
    with ``review_status='pending'``.

    Raises:
        HTTPException 404: the referenced site is missing or soft-deleted.
    """
    # Verify site exists and is not soft-deleted.
    site_result = await db.execute(
        select(Site).where(
            Site.id == body.site_id,
            Site.deleted_at.is_(None),
        )
    )
    if site_result.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    # Load the site's existing inventory once and index it by identity key,
    # rather than issuing one SELECT per reported cookie (avoids N+1 queries
    # for large reports).
    existing_result = await db.execute(
        select(Cookie).where(Cookie.site_id == body.site_id)
    )
    by_key = {
        (c.name, c.domain, c.storage_type): c
        for c in existing_result.scalars().all()
    }
    new_cookies = 0
    now_iso = datetime.now(UTC).isoformat()
    for reported in body.cookies:
        key = (reported.name, reported.domain, reported.storage_type)
        cookie = by_key.get(key)
        if cookie is not None:
            # Already known — just bump the sighting timestamp.
            cookie.last_seen_at = now_iso
        else:
            cookie = Cookie(
                site_id=body.site_id,
                name=reported.name,
                domain=reported.domain,
                storage_type=reported.storage_type,
                path=reported.path,
                is_secure=reported.is_secure,
                same_site=reported.same_site,
                review_status="pending",
                first_seen_at=now_iso,
                last_seen_at=now_iso,
            )
            db.add(cookie)
            # Register the new record so duplicate entries within the same
            # report update it instead of inserting twice.
            by_key[key] = cookie
            new_cookies += 1
    await db.flush()
    return CookieReportResponse(
        accepted=True,
        cookies_received=len(body.cookies),
        new_cookies=new_cookies,
    )
# ── Scan job management (authenticated) ─────────────────────────────
async def _verify_site_access(
    site_id: uuid.UUID,
    user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Return the site if it exists, is active, and belongs to the
    user's organisation; otherwise raise 404."""
    lookup = await db.execute(
        select(Site).where(
            Site.id == site_id,
            Site.organisation_id == user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    site = lookup.scalar_one_or_none()
    if site is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    return site
@router.post(
    "/scans",
    response_model=ScanJobResponse,
    status_code=status.HTTP_201_CREATED,
)
async def trigger_scan(
    body: TriggerScanRequest,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanJob:
    """Trigger a new cookie scan for a site.

    Creates a scan job in 'pending' state and dispatches it to the
    Celery worker queue for execution. Stale jobs (pending > 5 min or
    running > 10 min) are failed and superseded instead of blocking.

    Raises:
        HTTPException 404: site missing or outside the caller's organisation.
        HTTPException 409: a non-stale scan is already pending/running.
        HTTPException 503: the Celery dispatch failed.
    """
    # Deferred import keeps this module importable when scanner deps are heavy.
    from src.services.scanner import complete_scan_job
    await _verify_site_access(body.site_id, user, db)
    # Check for an already-running scan
    active_result = await db.execute(
        select(ScanJob).where(
            ScanJob.site_id == body.site_id,
            ScanJob.status.in_(["pending", "running"]),
        )
    )
    active_jobs = list(active_result.scalars().all())
    now = datetime.now(UTC)
    # Jobs stuck past these windows are presumed dead (worker crash / lost
    # dispatch) and are failed so a fresh scan can proceed.
    stale_pending_cutoff = now - timedelta(minutes=5)
    stale_running_cutoff = now - timedelta(minutes=10)
    for active_job in active_jobs:
        # NOTE(review): .replace(tzinfo=UTC) assumes the DB stores naive
        # UTC timestamps — confirm against the column definitions.
        is_stale_pending = (
            active_job.status == "pending"
            and active_job.created_at.replace(tzinfo=UTC) < stale_pending_cutoff
        )
        is_stale_running = (
            active_job.status == "running"
            and active_job.started_at
            and active_job.started_at.replace(tzinfo=UTC) < stale_running_cutoff
        )
        if is_stale_pending or is_stale_running:
            logger.warning(
                "Failing stale %s scan job %s for site %s",
                active_job.status,
                active_job.id,
                body.site_id,
            )
            # Mark the stale job failed so it no longer blocks new scans.
            await complete_scan_job(
                db,
                active_job,
                error_message=(
                    f"Job was stale ({active_job.status} too long), superseded by new scan"
                ),
            )
        else:
            # A genuinely active scan exists — refuse to start another.
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="A scan is already in progress for this site",
            )
    job = await create_scan_job(
        db,
        site_id=body.site_id,
        trigger="manual",
        max_pages=body.max_pages,
    )
    # Commit before dispatching to Celery so the worker can find the
    # job in the database immediately (avoids race condition).
    await db.commit()
    # Dispatch to Celery (import here to avoid import at module level
    # when Celery broker is unavailable during testing)
    try:
        from src.tasks.scanner import run_scan
        run_scan.delay(str(job.id), str(body.site_id))
    except Exception:
        logger.exception("Failed to dispatch scan job %s to Celery", job.id)
        # The committed job stays 'pending' on dispatch failure; the stale
        # sweep above fails it after ~5 minutes on the next trigger.
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=(
                "Background task queue is unavailable — scan job"
                " created but cannot be processed. Please try again later."
            ),
        ) from None
    return job
@router.get("/scans/site/{site_id}", response_model=list[ScanJobResponse])
async def list_scans(
    site_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
    limit: int = Query(default=20, ge=1, le=100),
    offset: int = Query(default=0, ge=0),
) -> list[ScanJob]:
    """Return a paginated list of scan jobs for a site, newest first."""
    await _verify_site_access(site_id, user, db)
    stmt = (
        select(ScanJob)
        .where(ScanJob.site_id == site_id)
        .order_by(ScanJob.created_at.desc())
        .limit(limit)
        .offset(offset)
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
@router.get("/scans/{scan_id}", response_model=ScanJobDetailResponse)
async def get_scan(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Retrieve a scan job together with its per-cookie results."""
    job_lookup = await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    job = job_lookup.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )
    # The job row itself carries no org — check access via its site.
    await _verify_site_access(job.site_id, user, db)
    result_rows = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == scan_id).order_by(ScanResult.cookie_name)
    )
    return {
        "id": job.id,
        "site_id": job.site_id,
        "status": job.status,
        "trigger": job.trigger,
        "pages_scanned": job.pages_scanned,
        "pages_total": job.pages_total,
        "cookies_found": job.cookies_found,
        "error_message": job.error_message,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "created_at": job.created_at,
        "updated_at": job.updated_at,
        "results": list(result_rows.scalars().all()),
    }
@router.get("/scans/{scan_id}/diff", response_model=ScanDiffResponse)
async def get_scan_diff(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanDiffResponse:
    """Compute the diff between a scan and the previous scan of its site."""
    lookup = await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    job = lookup.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )
    # Tenant check goes through the job's site.
    await _verify_site_access(job.site_id, user, db)
    return await compute_scan_diff(db, current_scan_id=scan_id, site_id=job.site_id)

View File

@@ -0,0 +1,101 @@
"""Site-group-level default configuration endpoints.
Provides GET and PUT for a site group's config defaults.
These defaults sit between org defaults and site config in the cascade.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site_group import SiteGroup
from src.models.site_group_config import SiteGroupConfig
from src.schemas.auth import CurrentUser
from src.schemas.site_group_config import SiteGroupConfigResponse, SiteGroupConfigUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
@router.get("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def get_site_group_config(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Retrieve configuration defaults for a site group.

    Lazily creates an empty config row on first access so the response
    is always a valid config object.
    """
    await _verify_group_ownership(group_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    config = lookup.scalar_one_or_none()
    if config is not None:
        return config
    config = SiteGroupConfig(site_group_id=group_id)
    db.add(config)
    await db.flush()
    await db.refresh(config)
    return config
@router.put("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def update_site_group_config(
    group_id: uuid.UUID,
    body: SiteGroupConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Create or update configuration defaults for a site group.

    Only non-None fields will override org/system defaults when
    resolving site config.
    """
    await _verify_group_ownership(group_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    config = lookup.scalar_one_or_none()
    changes = body.model_dump(exclude_unset=True)
    if config is None:
        # No row yet — create one seeded from the supplied fields.
        config = SiteGroupConfig(site_group_id=group_id, **changes)
        db.add(config)
    else:
        for attr, value in changes.items():
            setattr(config, attr, value)
    await db.flush()
    await db.refresh(config)
    return config
# -- Helpers ------------------------------------------------------------------
async def _verify_group_ownership(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> None:
    """Raise 404 unless the (non-deleted) site group belongs to the
    given organisation."""
    lookup = await db.execute(
        select(SiteGroup).where(
            SiteGroup.id == group_id,
            SiteGroup.organisation_id == organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
    )
    if lookup.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site group not found",
        )

View File

@@ -0,0 +1,198 @@
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site import Site
from src.models.site_group import SiteGroup
from src.schemas.auth import CurrentUser
from src.schemas.site_group import SiteGroupCreate, SiteGroupResponse, SiteGroupUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
@router.post("/", response_model=SiteGroupResponse, status_code=status.HTTP_201_CREATED)
async def create_site_group(
    body: SiteGroupCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Create a new site group within the current organisation.

    Group names must be unique among the organisation's active groups.
    """
    duplicate = await db.execute(
        select(SiteGroup).where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.name == body.name,
            SiteGroup.deleted_at.is_(None),
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Site group '{body.name}' already exists in this organisation",
        )
    group = SiteGroup(
        organisation_id=current_user.organisation_id,
        name=body.name,
        description=body.description,
    )
    db.add(group)
    await db.flush()
    await db.refresh(group)
    # A freshly created group cannot contain sites yet.
    return _to_response(group, site_count=0)
@router.get("/", response_model=list[SiteGroupResponse])
async def list_site_groups(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict]:
    """List all site groups in the current organisation with site counts."""
    # Aggregate active-site counts per group once, in a grouped subquery,
    # so the listing needs a single round trip.
    counts = (
        select(
            Site.site_group_id,
            func.count(Site.id).label("cnt"),
        )
        .where(Site.deleted_at.is_(None))
        .group_by(Site.site_group_id)
        .subquery()
    )
    rows = await db.execute(
        select(SiteGroup, func.coalesce(counts.c.cnt, 0).label("site_count"))
        .outerjoin(counts, SiteGroup.id == counts.c.site_group_id)
        .where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
        .order_by(SiteGroup.name)
    )
    return [_to_response(group, site_count=count) for group, count in rows.all()]
@router.get("/{group_id}", response_model=SiteGroupResponse)
async def get_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Fetch one site group (with its active-site count) by ID."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    return _to_response(group, site_count=await _count_sites(group_id, db))
@router.patch("/{group_id}", response_model=SiteGroupResponse)
async def update_site_group(
    group_id: uuid.UUID,
    body: SiteGroupUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Update a site group's name or description.

    A rename is rejected with 409 when another active group in the
    organisation already uses the new name.
    """
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    changes = body.model_dump(exclude_unset=True)
    if "name" in changes and changes["name"] != group.name:
        clash = await db.execute(
            select(SiteGroup).where(
                SiteGroup.organisation_id == current_user.organisation_id,
                SiteGroup.name == changes["name"],
                SiteGroup.deleted_at.is_(None),
                SiteGroup.id != group_id,
            )
        )
        if clash.scalar_one_or_none() is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Site group '{changes['name']}' already exists",
            )
    for attr, value in changes.items():
        setattr(group, attr, value)
    await db.flush()
    await db.refresh(group)
    return _to_response(group, site_count=await _count_sites(group_id, db))
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete a site group. Sites in this group become ungrouped."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    # Detach every active site before tombstoning the group.
    members = await db.execute(
        select(Site).where(
            Site.site_group_id == group_id,
            Site.deleted_at.is_(None),
        )
    )
    for member in members.scalars().all():
        member.site_group_id = None
    group.deleted_at = datetime.now(UTC)
    await db.flush()
# ── Helpers ──────────────────────────────────────────────────────────
async def _get_org_group(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> SiteGroup:
    """Return the (non-deleted) site group if it belongs to the given
    organisation; otherwise raise 404."""
    lookup = await db.execute(
        select(SiteGroup).where(
            SiteGroup.id == group_id,
            SiteGroup.organisation_id == organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
    )
    group = lookup.scalar_one_or_none()
    if group is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site group not found",
        )
    return group
async def _count_sites(group_id: uuid.UUID, db: AsyncSession) -> int:
    """Return the number of active (non-deleted) sites in a group."""
    stmt = select(func.count(Site.id)).where(
        Site.site_group_id == group_id,
        Site.deleted_at.is_(None),
    )
    result = await db.execute(stmt)
    return result.scalar_one()
def _to_response(group: SiteGroup, *, site_count: int) -> dict:
    """Serialise a SiteGroup plus its site_count into a response dict."""
    payload = {
        field: getattr(group, field)
        for field in (
            "id",
            "organisation_id",
            "name",
            "description",
            "created_at",
            "updated_at",
        )
    }
    payload["site_count"] = site_count
    return payload

View File

@@ -0,0 +1,220 @@
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.schemas.auth import CurrentUser
from src.schemas.site import (
SiteConfigCreate,
SiteConfigResponse,
SiteConfigUpdate,
SiteCreate,
SiteResponse,
SiteUpdate,
)
from src.services.dependencies import require_role
router = APIRouter(prefix="/sites", tags=["sites"])
# ── Site CRUD ────────────────────────────────────────────────────────
@router.post("/", response_model=SiteResponse, status_code=status.HTTP_201_CREATED)
async def create_site(
    body: SiteCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Create a new site within the current organisation.

    Domains are unique among the organisation's active sites; a default
    site configuration row is created alongside the site.
    """
    duplicate = await db.execute(
        select(Site).where(
            Site.organisation_id == current_user.organisation_id,
            Site.domain == body.domain,
            Site.deleted_at.is_(None),
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Site with domain '{body.domain}' already exists in this organisation",
        )
    site = Site(
        organisation_id=current_user.organisation_id,
        domain=body.domain,
        display_name=body.display_name,
        site_group_id=body.site_group_id,
    )
    db.add(site)
    # Flush first so site.id exists for the config's foreign key.
    await db.flush()
    db.add(SiteConfig(site_id=site.id))
    await db.flush()
    await db.refresh(site)
    return site
@router.get("/", response_model=list[SiteResponse])
async def list_sites(
    site_group_id: uuid.UUID | None = Query(default=None),
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[Site]:
    """List active sites in the organisation, optionally filtered by group."""
    stmt = select(Site).where(
        Site.organisation_id == current_user.organisation_id,
        Site.deleted_at.is_(None),
    )
    if site_group_id is not None:
        stmt = stmt.where(Site.site_group_id == site_group_id)
    rows = await db.execute(stmt.order_by(Site.domain))
    return list(rows.scalars().all())
@router.get("/{site_id}", response_model=SiteResponse)
async def get_site(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Fetch one site by ID, scoped to the caller's organisation."""
    return await _get_org_site(site_id, current_user.organisation_id, db)
@router.patch("/{site_id}", response_model=SiteResponse)
async def update_site(
    site_id: uuid.UUID,
    body: SiteUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Partially update a site (e.g. display name or active status)."""
    site = await _get_org_site(site_id, current_user.organisation_id, db)
    for attr, value in body.model_dump(exclude_unset=True).items():
        setattr(site, attr, value)
    await db.flush()
    await db.refresh(site)
    return site
@router.delete("/{site_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_site(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete a site by stamping its deleted_at timestamp."""
    site = await _get_org_site(site_id, current_user.organisation_id, db)
    site.deleted_at = datetime.now(UTC)
    await db.flush()
# ── Site config CRUD ─────────────────────────────────────────────────
@router.get("/{site_id}/config", response_model=SiteConfigResponse)
async def get_site_config(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Return the site's configuration, or 404 when none exists yet."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    config = lookup.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found. Create one first.",
        )
    return config
@router.put("/{site_id}/config", response_model=SiteConfigResponse)
async def create_or_replace_site_config(
    site_id: uuid.UUID,
    body: SiteConfigCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Create or fully replace the configuration for a site (PUT semantics)."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    config = lookup.scalar_one_or_none()
    payload = body.model_dump()
    if config is None:
        # No config yet — create one from the full payload.
        config = SiteConfig(site_id=site_id, **payload)
        db.add(config)
    else:
        # Overwrite every field wholesale.
        for attr, value in payload.items():
            setattr(config, attr, value)
    await db.flush()
    await db.refresh(config)
    return config
@router.patch("/{site_id}/config", response_model=SiteConfigResponse)
async def update_site_config(
    site_id: uuid.UUID,
    body: SiteConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Partially update the configuration for a site; 404 when none exists."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    config = lookup.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found. Create one first.",
        )
    for attr, value in body.model_dump(exclude_unset=True).items():
        setattr(config, attr, value)
    await db.flush()
    await db.refresh(config)
    return config
# ── Helpers ──────────────────────────────────────────────────────────
async def _get_org_site(
    site_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> Site:
    """Return the (non-deleted) site if it belongs to the organisation;
    otherwise raise 404."""
    lookup = await db.execute(
        select(Site).where(
            Site.id == site_id,
            Site.organisation_id == organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    site = lookup.scalar_one_or_none()
    if site is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site

View File

@@ -0,0 +1,195 @@
"""Translation management endpoints.
CRUD for per-site, per-locale translation strings used by the banner script.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site import Site
from src.models.translation import Translation
from src.schemas.auth import CurrentUser
from src.schemas.translation import TranslationCreate, TranslationResponse, TranslationUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/sites/{site_id}/translations", tags=["translations"])
async def _get_org_site(site_id: uuid.UUID, organisation_id: uuid.UUID, db: AsyncSession) -> Site:
    """Ensure site belongs to the current organisation.

    Soft-deleted sites are treated as absent; raises 404 so cross-tenant
    probing cannot confirm that a site exists.
    """
    query = select(Site).where(
        Site.id == site_id,
        Site.organisation_id == organisation_id,
        Site.deleted_at.is_(None),
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return found
@router.get("/", response_model=list[TranslationResponse])
async def list_translations(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[Translation]:
    """Return every translation for the site, ordered by locale."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    query = (
        select(Translation)
        .where(Translation.site_id == site_id)
        .order_by(Translation.locale)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
@router.get("/{locale}", response_model=TranslationResponse)
async def get_translation(
    site_id: uuid.UUID,
    locale: str,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Get translation strings for a specific locale of a site."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    query = select(Translation).where(
        Translation.site_id == site_id,
        Translation.locale == locale,
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )
    return found
@router.post("/", response_model=TranslationResponse, status_code=status.HTTP_201_CREATED)
async def create_translation(
    site_id: uuid.UUID,
    body: TranslationCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Create a translation for a locale that does not exist yet.

    Returns 409 if the site already has strings for ``body.locale``.
    """
    await _get_org_site(site_id, current_user.organisation_id, db)
    # Reject duplicates up front so the client gets a clean 409.
    dup_query = select(Translation).where(
        Translation.site_id == site_id,
        Translation.locale == body.locale,
    )
    duplicate = (await db.execute(dup_query)).scalar_one_or_none()
    if duplicate is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Translation for locale '{body.locale}' already exists",
        )
    record = Translation(site_id=site_id, locale=body.locale, strings=body.strings)
    db.add(record)
    await db.flush()
    await db.refresh(record)
    return record
@router.put("/{locale}", response_model=TranslationResponse)
async def update_translation(
    site_id: uuid.UUID,
    locale: str,
    body: TranslationUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Replace the strings for an existing locale translation."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    query = select(Translation).where(
        Translation.site_id == site_id,
        Translation.locale == locale,
    )
    record = (await db.execute(query)).scalar_one_or_none()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )
    # PUT semantics: whole-document replace, no per-key merging.
    record.strings = body.strings
    await db.flush()
    await db.refresh(record)
    return record
@router.delete("/{locale}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_translation(
    site_id: uuid.UUID,
    locale: str,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a translation for a specific locale (hard delete)."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    query = select(Translation).where(
        Translation.site_id == site_id,
        Translation.locale == locale,
    )
    record = (await db.execute(query)).scalar_one_or_none()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )
    await db.delete(record)
    await db.flush()
# ── Public endpoint for the banner script ────────────────────────────
public_router = APIRouter(prefix="/translations", tags=["translations"])
@public_router.get("/{site_id}/{locale}")
async def get_public_translation(
    site_id: uuid.UUID,
    locale: str,
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Unauthenticated lookup of banner strings for one site/locale.

    Only active, non-deleted sites are served. A 404 tells the banner to
    fall back to its built-in English defaults.
    """
    query = (
        select(Translation)
        .join(Site)
        .where(
            Translation.site_id == site_id,
            Translation.locale == locale,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Translation not found",
        )
    # Return only the raw strings dict, not the full ORM record.
    return found.strings

View File

@@ -0,0 +1,136 @@
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.user import User
from src.schemas.auth import CurrentUser
from src.schemas.user import UserCreate, UserResponse, UserUpdate
from src.services.auth import hash_password
from src.services.dependencies import require_role
router = APIRouter(prefix="/users", tags=["users"])
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
    body: UserCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Invite/create a new user within the current organisation."""
    # Email uniqueness is checked platform-wide (not per organisation, and
    # including soft-deleted users).
    # NOTE(review): the 409 reveals that an address exists in *some*
    # organisation — confirm this enumeration trade-off is acceptable.
    dup = (await db.execute(select(User).where(User.email == body.email))).scalar_one_or_none()
    if dup is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"User with email '{body.email}' already exists",
        )
    new_user = User(
        organisation_id=current_user.organisation_id,
        email=body.email,
        password_hash=hash_password(body.password),
        full_name=body.full_name,
        role=body.role,
    )
    db.add(new_user)
    await db.flush()
    await db.refresh(new_user)
    return new_user
@router.get("/", response_model=list[UserResponse])
async def list_users(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[User]:
    """List all active (non-soft-deleted) users in the current organisation."""
    query = (
        select(User)
        .where(
            User.organisation_id == current_user.organisation_id,
            User.deleted_at.is_(None),
        )
        .order_by(User.created_at)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Get a specific user by ID within the current organisation."""
    query = select(User).where(
        User.id == user_id,
        User.organisation_id == current_user.organisation_id,
        User.deleted_at.is_(None),
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    return found
@router.patch("/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: uuid.UUID,
    body: UserUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Update a user's name or role. Requires owner or admin.

    NOTE(review): an admin can currently change an owner's role (or their
    own) — confirm whether role changes should be restricted to owners.
    """
    query = select(User).where(
        User.id == user_id,
        User.organisation_id == current_user.organisation_id,
        User.deleted_at.is_(None),
    )
    target = (await db.execute(query)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    # Only apply fields the caller explicitly sent.
    for name, value in body.model_dump(exclude_unset=True).items():
        setattr(target, name, value)
    await db.flush()
    await db.refresh(target)
    return target
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_user(
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete (deactivate) a user. Requires owner or admin."""
    # Guard against locking yourself out of the organisation.
    if user_id == current_user.id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot deactivate yourself",
        )
    query = select(User).where(
        User.id == user_id,
        User.organisation_id == current_user.organisation_id,
        User.deleted_at.is_(None),
    )
    target = (await db.execute(query)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    # Soft delete: the row is kept; all reads filter on deleted_at.
    target.deleted_at = datetime.now(UTC)
    await db.flush()

View File

View File

@@ -0,0 +1,45 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, EmailStr
class LoginRequest(BaseModel):
    """Credentials submitted to the login endpoint."""
    email: EmailStr
    password: str
class RefreshRequest(BaseModel):
    """Body for exchanging a refresh token for a new token pair."""
    refresh_token: str
class TokenResponse(BaseModel):
    """Token pair returned after login/refresh (OAuth2-style shape)."""
    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # access-token lifetime (OAuth2 "expires_in"; conventionally seconds — confirm)
class TokenPayload(BaseModel):
    """Decoded JWT claim set."""
    sub: str  # user ID
    org_id: str  # organisation ID
    role: str  # user role
    exp: datetime
    iat: datetime
    type: str = "access"  # "access" or "refresh"
class CurrentUser(BaseModel):
    """Represents the authenticated user extracted from a JWT."""
    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    role: str
    def has_role(self, *roles: str) -> bool:
        """Return True if the user's role is one of *roles*."""
        return self.role in roles
    @property
    def is_admin(self) -> bool:
        """True for organisation-managing roles (owner or admin)."""
        return self.role in ("owner", "admin")

View File

@@ -0,0 +1,56 @@
"""Pydantic schemas for compliance check results."""
from enum import StrEnum
from pydantic import BaseModel, Field
class Severity(StrEnum):
    """Severity level of a single compliance issue."""
    CRITICAL = "critical"
    WARNING = "warning"
    INFO = "info"
class Framework(StrEnum):
    """Regulatory frameworks the compliance checker can evaluate."""
    GDPR = "gdpr"
    CNIL = "cnil"
    CCPA = "ccpa"
    EPRIVACY = "eprivacy"
    LGPD = "lgpd"
class ComplianceIssue(BaseModel):
    """A single compliance issue found during a check."""
    rule_id: str
    severity: Severity
    message: str
    recommendation: str
class FrameworkResult(BaseModel):
    """Compliance result for a single regulatory framework."""
    framework: Framework
    score: int = Field(ge=0, le=100, description="Compliance score (0-100)")
    status: str = Field(description="Overall status: compliant, partial, non_compliant")
    issues: list[ComplianceIssue] = Field(default_factory=list)
    rules_checked: int = 0
    rules_passed: int = 0
class ComplianceCheckRequest(BaseModel):
    """Request body for compliance checks."""
    frameworks: list[Framework] | None = Field(
        default=None,
        description="Frameworks to check. If null, all frameworks are checked.",
    )
class ComplianceCheckResponse(BaseModel):
    """Full compliance check response for a site."""
    site_id: str
    results: list[FrameworkResult]
    overall_score: int = Field(ge=0, le=100, description="Weighted average across all frameworks")

View File

@@ -0,0 +1,62 @@
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class ConsentAction(StrEnum):
    """The choice the visitor made in the consent banner."""
    ACCEPT_ALL = "accept_all"
    REJECT_ALL = "reject_all"
    CUSTOM = "custom"  # per-category selection
    WITHDRAW = "withdraw"  # previously-given consent revoked
class ConsentRecordCreate(BaseModel):
    """Payload sent by the banner when a consent event occurs."""
    site_id: uuid.UUID
    visitor_id: str = Field(min_length=1, max_length=255)
    action: ConsentAction
    categories_accepted: list[str]
    categories_rejected: list[str] | None = None
    tc_string: str | None = None  # IAB TCF consent string, if TCF is enabled
    gcm_state: dict | None = None  # Google Consent Mode state
    gpp_string: str | None = None  # IAB GPP string
    gpc_detected: bool | None = None  # Global Privacy Control signal seen?
    gpc_honoured: bool | None = None
    page_url: str | None = None
    country_code: str | None = Field(default=None, max_length=5)
    region_code: str | None = Field(default=None, max_length=10)
class ConsentRecordResponse(BaseModel):
    """API representation of a stored consent record."""
    id: uuid.UUID
    site_id: uuid.UUID
    visitor_id: str
    action: str
    categories_accepted: list[str]
    categories_rejected: list[str] | None = None
    tc_string: str | None = None
    gcm_state: dict | None = None
    gpp_string: str | None = None
    gpc_detected: bool | None = None
    gpc_honoured: bool | None = None
    page_url: str | None = None
    country_code: str | None = None
    region_code: str | None = None
    consented_at: datetime
    model_config = {"from_attributes": True}
class ConsentVerifyResponse(BaseModel):
    """Audit proof that a consent record exists."""
    id: uuid.UUID
    site_id: uuid.UUID
    visitor_id: str
    action: str
    categories_accepted: list[str]
    consented_at: datetime
    valid: bool = True

View File

@@ -0,0 +1,210 @@
"""Pydantic schemas for cookie categories, cookies, and allow-list entries."""
from __future__ import annotations
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
# ─── Cookie category schemas ───
class CookieCategoryResponse(BaseModel):
    """Response schema for a cookie category."""
    id: uuid.UUID
    name: str
    slug: str
    description: str | None = None
    is_essential: bool  # presumably essential categories cannot be rejected — confirm banner logic
    display_order: int
    tcf_purpose_ids: list[int] | None = None  # IAB TCF purposes this category maps to — confirm
    gcm_consent_types: list[str] | None = None  # Google Consent Mode types this category maps to — confirm
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Storage type enum ───
class StorageType(StrEnum):
    """Type of browser storage used by the cookie/tracker."""
    cookie = "cookie"
    local_storage = "local_storage"
    session_storage = "session_storage"
    indexed_db = "indexed_db"
# ─── Review status enum ───
class ReviewStatus(StrEnum):
    """Review status for a discovered cookie."""
    pending = "pending"
    approved = "approved"
    rejected = "rejected"
# ─── Cookie schemas ───
class CookieCreate(BaseModel):
    """Schema for creating a cookie record (typically from scanner/reporter)."""
    name: str = Field(..., min_length=1, max_length=255)
    domain: str = Field(..., min_length=1, max_length=255)
    storage_type: StorageType = StorageType.cookie
    category_id: uuid.UUID | None = None
    description: str | None = None
    vendor: str | None = Field(None, max_length=255)
    path: str | None = Field(None, max_length=500)
    max_age_seconds: int | None = None
    is_http_only: bool | None = None
    is_secure: bool | None = None
    same_site: str | None = Field(None, max_length=10)  # "Strict"/"Lax"/"None" — confirm casing convention
class CookieUpdate(BaseModel):
    """Schema for updating a cookie record."""
    category_id: uuid.UUID | None = None
    description: str | None = None
    vendor: str | None = Field(None, max_length=255)
    review_status: ReviewStatus | None = None
class CookieResponse(BaseModel):
    """Response schema for a cookie.

    Built from the ORM row via ``from_attributes``.
    """
    id: uuid.UUID
    site_id: uuid.UUID
    category_id: uuid.UUID | None = None
    name: str
    domain: str
    storage_type: str
    description: str | None = None
    vendor: str | None = None
    path: str | None = None
    max_age_seconds: int | None = None
    is_http_only: bool | None = None
    is_secure: bool | None = None
    same_site: str | None = None
    review_status: str
    # Fix: these are timestamps, not strings. Sibling response models in this
    # codebase (e.g. ScanResultResponse.found_at) use datetime, and under
    # Pydantic v2 a datetime ORM attribute fails validation against `str`.
    # `datetime` still accepts ISO-8601 strings, so this is backward-compatible
    # for any caller that was passing strings.
    first_seen_at: datetime | None = None
    last_seen_at: datetime | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Allow-list schemas ───
class AllowListEntryCreate(BaseModel):
    """Schema for adding a cookie to the allow-list."""
    name_pattern: str = Field(..., min_length=1, max_length=255)
    domain_pattern: str = Field(..., min_length=1, max_length=255)
    category_id: uuid.UUID
    description: str | None = None
class AllowListEntryUpdate(BaseModel):
    """Schema for updating an allow-list entry."""
    category_id: uuid.UUID | None = None
    description: str | None = None
class AllowListEntryResponse(BaseModel):
    """Response schema for an allow-list entry."""
    id: uuid.UUID
    site_id: uuid.UUID
    category_id: uuid.UUID
    name_pattern: str
    domain_pattern: str
    description: str | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Known cookie schemas ───
class KnownCookieCreate(BaseModel):
    """Schema for creating a known cookie pattern."""
    name_pattern: str = Field(..., min_length=1, max_length=255)
    domain_pattern: str = Field(..., min_length=1, max_length=255)
    category_id: uuid.UUID
    vendor: str | None = Field(None, max_length=255)
    description: str | None = None
    is_regex: bool = False  # when True the patterns are regexes; otherwise presumably exact/glob — confirm matcher
class KnownCookieUpdate(BaseModel):
    """Schema for updating a known cookie pattern."""
    category_id: uuid.UUID | None = None
    vendor: str | None = Field(None, max_length=255)
    description: str | None = None
    is_regex: bool | None = None
class KnownCookieResponse(BaseModel):
    """Response schema for a known cookie pattern."""
    id: uuid.UUID
    name_pattern: str
    domain_pattern: str
    category_id: uuid.UUID
    vendor: str | None = None
    description: str | None = None
    is_regex: bool
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Classification schemas ───
class ClassificationResultResponse(BaseModel):
    """Response for a single cookie classification result."""
    cookie_name: str
    cookie_domain: str
    category_id: uuid.UUID | None = None
    category_slug: str | None = None
    vendor: str | None = None
    description: str | None = None
    match_source: str  # where the match came from (e.g. allow-list vs known-cookie DB) — confirm values
    matched: bool
class ClassifySiteResponse(BaseModel):
    """Response for classifying all cookies on a site."""
    site_id: str
    total: int
    matched: int
    unmatched: int
    results: list[ClassificationResultResponse]
class ClassifySingleRequest(BaseModel):
    """Request to classify a single cookie (preview/test)."""
    cookie_name: str = Field(..., min_length=1, max_length=255)
    cookie_domain: str = Field(..., min_length=1, max_length=255)

View File

@@ -0,0 +1,61 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from src.schemas.site import BlockingMode
class OrgConfigUpdate(BaseModel):
    """Update (or create) organisation-level default configuration.
    All fields are optional — only non-None values override the system defaults.
    """
    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None
    tcf_enabled: bool | None = None  # IAB TCF support
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)  # 2-letter publisher country code
    gpp_enabled: bool | None = None  # IAB GPP support
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None  # Global Privacy Control handling
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None  # Google Consent Mode
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
class OrgConfigResponse(BaseModel):
    """Organisation-level configuration defaults.

    Fields are nullable: None presumably means "inherit from the system
    default" per the config cascade — confirm resolution logic.
    """
    id: uuid.UUID
    organisation_id: uuid.UUID
    blocking_mode: str | None
    regional_modes: dict | None
    tcf_enabled: bool | None
    tcf_publisher_cc: str | None
    gpp_enabled: bool | None
    gpp_supported_apis: list[str] | None
    gpc_enabled: bool | None
    gpc_jurisdictions: list[str] | None
    gpc_global_honour: bool | None
    gcm_enabled: bool | None
    gcm_default: dict | None
    shopify_privacy_enabled: bool | None
    banner_config: dict | None
    privacy_policy_url: str | None
    terms_url: str | None
    scan_schedule_cron: str | None
    scan_max_pages: int | None
    consent_expiry_days: int | None
    consent_retention_days: int | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,29 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
class OrganisationCreate(BaseModel):
    """Payload for creating an organisation (tenant)."""
    name: str = Field(min_length=1, max_length=255)
    slug: str = Field(min_length=1, max_length=100, pattern=r"^[a-z0-9-]+$")  # URL-safe identifier
    contact_email: str | None = None
    billing_plan: str = "free"
class OrganisationUpdate(BaseModel):
    """Partial update of an organisation; the slug is immutable."""
    name: str | None = Field(default=None, min_length=1, max_length=255)
    contact_email: str | None = None
    billing_plan: str | None = None
class OrganisationResponse(BaseModel):
    """API representation of an organisation."""
    id: uuid.UUID
    name: str
    slug: str
    contact_email: str | None
    billing_plan: str
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,142 @@
"""Pydantic schemas for scanner and client-side cookie reports."""
from __future__ import annotations
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class ScanStatus(StrEnum):
    """Lifecycle state of a scan job."""
    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
class ScanTrigger(StrEnum):
    """What caused a scan job to be created."""
    MANUAL = "manual"
    SCHEDULED = "scheduled"
    CLIENT_REPORT = "client_report"
# ── Client-side cookie report ────────────────────────────────────────
class ReportedCookie(BaseModel):
    """A single cookie/storage item reported by the client-side reporter."""
    name: str = Field(..., min_length=1, max_length=255)
    domain: str = Field(..., min_length=1, max_length=255)
    storage_type: str = Field(default="cookie", max_length=30)
    value_length: int = Field(default=0, ge=0)  # only the length is reported, never the raw value — confirm reporter
    path: str | None = None
    is_secure: bool | None = None
    same_site: str | None = None
    script_source: str | None = None  # URL of the script that set it, when attributable
class CookieReportRequest(BaseModel):
    """Payload from the client-side cookie reporter."""
    site_id: uuid.UUID
    page_url: str = Field(..., max_length=2000)
    cookies: list[ReportedCookie] = Field(..., max_length=500)  # hard cap to bound request size
    collected_at: datetime
    user_agent: str = Field(default="", max_length=500)
class CookieReportResponse(BaseModel):
    """Acknowledgement response for a cookie report."""
    accepted: bool = True
    cookies_received: int
    new_cookies: int = 0
# ── Scan job schemas ─────────────────────────────────────────────────
class ScanResultResponse(BaseModel):
    """A single scan result — a cookie found on a specific page."""
    id: uuid.UUID
    scan_job_id: uuid.UUID
    page_url: str
    cookie_name: str
    cookie_domain: str
    storage_type: str
    attributes: dict | None = None
    script_source: str | None = None
    auto_category: str | None = None  # category guessed by the classifier, if any
    initiator_chain: list[str] | None = None  # chain of scripts that led to the cookie being set — confirm
    found_at: datetime
    created_at: datetime
    model_config = {"from_attributes": True}
class ScanJobResponse(BaseModel):
    """Response schema for a scan job."""
    id: uuid.UUID
    site_id: uuid.UUID
    status: str
    trigger: str
    pages_scanned: int
    pages_total: int | None
    cookies_found: int
    error_message: str | None
    started_at: datetime | None
    completed_at: datetime | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
class ScanJobDetailResponse(ScanJobResponse):
    """Scan job with results included."""
    results: list[ScanResultResponse] = []
class TriggerScanRequest(BaseModel):
    """Request to trigger a new scan."""
    site_id: uuid.UUID
    max_pages: int = Field(default=50, ge=1, le=500)
# ── Diff engine schemas ──────────────────────────────────────────────
class DiffStatus(StrEnum):
    """How a cookie changed between two scans."""
    NEW = "new"
    REMOVED = "removed"
    CHANGED = "changed"
class CookieDiffItem(BaseModel):
    """A single cookie difference between two scans."""
    name: str
    domain: str
    storage_type: str
    diff_status: DiffStatus
    details: str | None = None
class ScanDiffResponse(BaseModel):
    """Diff between two scans."""
    current_scan_id: uuid.UUID
    previous_scan_id: uuid.UUID | None
    new_cookies: list[CookieDiffItem] = []
    removed_cookies: list[CookieDiffItem] = []
    changed_cookies: list[CookieDiffItem] = []
    total_new: int = 0
    total_removed: int = 0
    total_changed: int = 0

View File

@@ -0,0 +1,117 @@
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class BlockingMode(StrEnum):
    """How non-essential trackers are treated before a consent choice."""
    OPT_IN = "opt_in"  # blocked until the visitor consents
    OPT_OUT = "opt_out"  # allowed until the visitor objects
    INFORMATIONAL = "informational"  # banner informs only; nothing is blocked
# ── Site schemas ─────────────────────────────────────────────────────
class SiteCreate(BaseModel):
    """Payload for registering a new site."""
    domain: str = Field(min_length=1, max_length=255)
    display_name: str = Field(min_length=1, max_length=255)
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
class SiteUpdate(BaseModel):
    """Partial update of a site; the primary domain is immutable here."""
    display_name: str | None = Field(default=None, min_length=1, max_length=255)
    is_active: bool | None = None
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
class SiteResponse(BaseModel):
    """API representation of a site."""
    id: uuid.UUID
    organisation_id: uuid.UUID
    domain: str
    display_name: str
    is_active: bool
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ── Site config schemas ──────────────────────────────────────────────
class SiteConfigCreate(BaseModel):
    """Full site configuration (PUT body); unset fields take these defaults."""
    blocking_mode: BlockingMode = BlockingMode.OPT_IN
    regional_modes: dict | None = None  # presumably per-region overrides of blocking_mode — confirm shape
    tcf_enabled: bool = False  # IAB TCF support
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)  # 2-letter publisher country code
    gpp_enabled: bool = True  # IAB GPP support
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool = True  # honour Global Privacy Control signals
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool = False  # presumably: honour GPC everywhere, not only listed jurisdictions — confirm
    gcm_enabled: bool = True  # Google Consent Mode
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool = False  # Shopify Customer Privacy API integration
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None  # cron expression; None disables scheduled scans — confirm
    scan_max_pages: int = Field(default=50, ge=1, le=1000)
    consent_expiry_days: int = Field(default=365, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
class SiteConfigUpdate(BaseModel):
    """Partial site configuration (PATCH body); only sent fields change."""
    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None
    tcf_enabled: bool | None = None
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
class SiteConfigResponse(BaseModel):
    """API representation of a site's configuration."""
    id: uuid.UUID
    site_id: uuid.UUID
    blocking_mode: str
    regional_modes: dict | None
    tcf_enabled: bool
    tcf_publisher_cc: str | None = None
    gpp_enabled: bool = True
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool = True
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool = False
    gcm_enabled: bool
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool = False
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int = 50
    consent_expiry_days: int = 365
    consent_retention_days: int | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,26 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
class SiteGroupCreate(BaseModel):
    """Payload for creating a site group."""
    name: str = Field(min_length=1, max_length=255)
    description: str | None = None
class SiteGroupUpdate(BaseModel):
    """Partial update of a site group."""
    name: str | None = Field(default=None, min_length=1, max_length=255)
    description: str | None = None
class SiteGroupResponse(BaseModel):
    """API representation of a site group."""
    id: uuid.UUID
    organisation_id: uuid.UUID
    name: str
    description: str | None
    created_at: datetime
    updated_at: datetime
    site_count: int = 0  # presumably filled in by the handler, not an ORM column — confirm
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,59 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from src.schemas.site import BlockingMode
class SiteGroupConfigUpdate(BaseModel):
    """Update (or create) site-group-level default configuration.
    All fields are optional — only non-None values override the org/system defaults.
    """
    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None
    tcf_enabled: bool | None = None
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
class SiteGroupConfigResponse(BaseModel):
    """Site-group-level configuration defaults.

    Nullable fields: None presumably means "inherit from the org/system
    layer" in the config cascade — confirm resolution logic.
    """
    id: uuid.UUID
    site_group_id: uuid.UUID
    blocking_mode: str | None
    regional_modes: dict | None
    tcf_enabled: bool | None
    tcf_publisher_cc: str | None
    gpp_enabled: bool | None
    gpp_supported_apis: list[str] | None
    gpc_enabled: bool | None
    gpc_jurisdictions: list[str] | None
    gpc_global_honour: bool | None
    gcm_enabled: bool | None
    gcm_default: dict | None
    shopify_privacy_enabled: bool | None
    banner_config: dict | None
    privacy_policy_url: str | None
    terms_url: str | None
    scan_schedule_cron: str | None
    scan_max_pages: int | None
    consent_expiry_days: int | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,24 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
class TranslationCreate(BaseModel):
    """Create banner strings for a new locale on a site."""
    locale: str = Field(min_length=2, max_length=10)  # e.g. "en" or "en-GB" — confirm accepted format
    strings: dict[str, str]
class TranslationUpdate(BaseModel):
    """Full replacement of a locale's strings (used by PUT)."""
    strings: dict[str, str]
class TranslationResponse(BaseModel):
    """API representation of a stored translation."""
    id: uuid.UUID
    site_id: uuid.UUID
    locale: str
    strings: dict[str, str]
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,36 @@
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, EmailStr, Field
class UserRole(StrEnum):
    """Roles used by the require_role access checks."""
    OWNER = "owner"
    ADMIN = "admin"
    EDITOR = "editor"
    VIEWER = "viewer"
class UserCreate(BaseModel):
    """Payload for inviting/creating a user."""
    email: EmailStr
    password: str = Field(min_length=8, max_length=72)  # 72 matches bcrypt's input limit (bytes, not chars — confirm multibyte handling)
    full_name: str = Field(min_length=1, max_length=255)
    role: UserRole = UserRole.VIEWER
class UserUpdate(BaseModel):
    """Partial update of a user's name and/or role."""
    full_name: str | None = Field(default=None, min_length=1, max_length=255)
    role: UserRole | None = None
class UserResponse(BaseModel):
    """API representation of a user; never exposes the password hash."""
    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    full_name: str
    role: str
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

View File

@@ -0,0 +1,59 @@
import uuid
from datetime import UTC, datetime, timedelta
import bcrypt
from jose import jwt
from src.config.settings import get_settings
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated random salt.

    NOTE(review): bcrypt only processes the first 72 bytes of input;
    UserCreate caps passwords at 72 *characters*, so multibyte input may
    still exceed 72 bytes — confirm the installed bcrypt's behaviour.
    """
    return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches the stored bcrypt hash."""
    return bcrypt.checkpw(plain_password.encode(), hashed_password.encode())
def create_access_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
    role: str,
    email: str,
) -> str:
    """Mint a signed, short-lived access JWT for the given user.

    Carries the identity claims the API needs per-request (user, org,
    role, email) plus standard ``exp``/``iat`` and a ``type`` marker so
    refresh tokens cannot be replayed as access tokens.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "role": role,
        "email": email,
        "exp": issued_at + timedelta(minutes=settings.jwt_access_token_expire_minutes),
        "iat": issued_at,
        "type": "access",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
def create_refresh_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
) -> str:
    """Issue a signed, long-lived refresh JWT (no role/email claims).

    Expiry is driven by ``jwt_refresh_token_expire_days`` in settings.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "exp": issued_at + timedelta(days=settings.jwt_refresh_token_expire_days),
        "iat": issued_at,
        "type": "refresh",  # distinguishes refresh from access tokens
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
def decode_token(token: str) -> dict:
    """Decode and validate a JWT token. Raises JWTError on failure."""
    cfg = get_settings()
    return jwt.decode(token, cfg.jwt_secret_key, algorithms=[cfg.jwt_algorithm])

View File

@@ -0,0 +1,79 @@
"""First-run bootstrap of an organisation and owner user.
Runs once on API startup. If ``INITIAL_ADMIN_EMAIL`` and
``INITIAL_ADMIN_PASSWORD`` are set and the ``users`` table is empty,
creates an organisation and a single owner user so the operator can log
in to the admin UI for the first time. Idempotent: once any user
exists, this is a no-op, so the environment variables can safely remain
set across restarts. Complements ``ADMIN_BOOTSTRAP_TOKEN`` — that gates
runtime org creation; this creates the *initial* org + owner without
requiring a second round-trip.
"""
import logging
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import Settings
from src.db.session import async_session_factory
from src.models.organisation import Organisation
from src.models.user import User
from src.services.auth import hash_password
logger = logging.getLogger(__name__)
async def bootstrap_initial_admin(settings: Settings) -> None:
    """Create the first organisation and owner user if none exist.

    Does nothing unless both initial-admin credentials are configured.
    Unexpected errors are logged and suppressed: a failed bootstrap must
    never block API startup, since manual provisioning remains possible.
    """
    have_credentials = bool(
        settings.initial_admin_email and settings.initial_admin_password
    )
    if not have_credentials:
        logger.debug("Initial admin bootstrap skipped: credentials not configured")
        return
    try:
        async with async_session_factory() as session:
            await _bootstrap(session, settings)
    except Exception:  # pragma: no cover — defensive, logged
        logger.exception("Initial admin bootstrap failed")
async def _bootstrap(session: AsyncSession, settings: Settings) -> None:
    """Do the actual first-run provisioning inside an open session.

    Idempotence gate: any existing user row means a previous run (or manual
    provisioning) already happened, so this is a no-op.
    """
    existing_users = await session.scalar(select(func.count()).select_from(User))
    if existing_users:
        logger.debug("Initial admin bootstrap skipped: %d user(s) already exist", existing_users)
        return
    # Reuse an organisation with the configured slug if one already exists
    # (e.g. created out-of-band); otherwise create it.
    org = await session.scalar(
        select(Organisation).where(Organisation.slug == settings.initial_org_slug)
    )
    if org is None:
        org = Organisation(
            name=settings.initial_org_name,
            slug=settings.initial_org_slug,
            contact_email=settings.initial_admin_email,
        )
        session.add(org)
        # Flush so org.id is populated before the user row references it.
        await session.flush()
    user = User(
        organisation_id=org.id,
        email=settings.initial_admin_email,
        password_hash=hash_password(settings.initial_admin_password),
        full_name=settings.initial_admin_full_name,
        role="owner",
    )
    session.add(user)
    await session.commit()
    # Warning level on purpose: this should be prominently visible in logs
    # so the operator remembers to rotate the bootstrap password.
    logger.warning(
        "Initial admin bootstrap created owner %s in organisation '%s'. "
        "Rotate the password via the admin UI as soon as possible.",
        settings.initial_admin_email,
        org.slug,
    )

View File

@@ -0,0 +1,298 @@
"""Cookie auto-categorisation engine.
Matches discovered cookies against the known_cookies database using exact name
matching, domain matching, and regex patterns. Also checks site-specific
allow-list entries. Returns a classification result with category, vendor, and
confidence level.
Matching priority (highest first):
1. Site-specific allow-list (exact or pattern match)
2. Known cookies — exact name + domain match
3. Known cookies — regex pattern match on name + domain
4. Unmatched → remains as 'pending'
"""
from __future__ import annotations
import re
import uuid
from dataclasses import dataclass
from enum import StrEnum
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.models.cookie import (
Cookie,
CookieAllowListEntry,
CookieCategory,
KnownCookie,
)
class MatchSource(StrEnum):
    """Where the classification match came from."""
    ALLOW_LIST = "allow_list"  # site-specific allow-list entry (highest priority)
    KNOWN_EXACT = "known_exact"  # known_cookies row, exact/wildcard pattern
    KNOWN_REGEX = "known_regex"  # known_cookies row, regex pattern
    UNMATCHED = "unmatched"  # nothing matched; cookie stays pending review
@dataclass
class ClassificationResult:
    """Result of classifying a single cookie."""
    cookie_name: str
    cookie_domain: str
    category_id: uuid.UUID | None = None  # resolved category, if matched
    category_slug: str | None = None  # convenience slug for the category above
    vendor: str | None = None  # only populated from known-cookie matches
    description: str | None = None
    match_source: MatchSource = MatchSource.UNMATCHED
    matched: bool = False  # True iff any rule matched (see classify_cookie)
async def _load_allow_list(
    db: AsyncSession,
    site_id: uuid.UUID,
) -> list[CookieAllowListEntry]:
    """Fetch every allow-list entry configured for the given site."""
    stmt = select(CookieAllowListEntry).where(
        CookieAllowListEntry.site_id == site_id,
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
async def _load_known_cookies(
    db: AsyncSession,
) -> tuple[list[KnownCookie], list[KnownCookie]]:
    """Load the known-cookie catalogue, partitioned into (exact, regex) lists."""
    rows = await db.execute(select(KnownCookie))
    exact: list[KnownCookie] = []
    regex: list[KnownCookie] = []
    for entry in rows.scalars().all():
        (regex if entry.is_regex else exact).append(entry)
    return exact, regex
async def _load_category_map(
    db: AsyncSession,
) -> dict[uuid.UUID, CookieCategory]:
    """Build an id → CookieCategory lookup table."""
    rows = await db.execute(select(CookieCategory))
    mapping: dict[uuid.UUID, CookieCategory] = {}
    for category in rows.scalars().all():
        mapping[category.id] = category
    return mapping
def _match_pattern(pattern: str, value: str) -> bool:
    """Case-insensitively match a value against an exact or wildcard pattern.

    Supported pattern forms:
    - Exact match (e.g. "_ga")
    - Wildcard with * (e.g. "_ga*", "*.google.com"), where * matches any
      run of characters (including none), anchored at both ends

    Empty patterns or values never match.
    """
    if not pattern or not value:
        return False
    needle = pattern.lower()
    haystack = value.lower()
    if needle == haystack:
        return True
    if "*" not in needle:
        return False
    # Escape everything, then turn the escaped wildcard back into ".*".
    translated = "^" + re.escape(needle).replace(r"\*", ".*") + "$"
    return re.match(translated, haystack) is not None
def _match_regex(pattern: str, value: str) -> bool:
    """Case-insensitive regex match anchored at the start of the value.

    Invalid regex patterns are treated as non-matching, never raised.
    """
    try:
        compiled = re.compile(pattern, re.IGNORECASE)
    except re.error:
        return False
    return compiled.match(value) is not None
def _match_allow_list(
    cookie_name: str,
    cookie_domain: str,
    allow_list: list[CookieAllowListEntry],
) -> CookieAllowListEntry | None:
    """Return the first allow-list entry whose name AND domain patterns match."""
    return next(
        (
            entry
            for entry in allow_list
            if _match_pattern(entry.name_pattern, cookie_name)
            and _match_pattern(entry.domain_pattern, cookie_domain)
        ),
        None,
    )
def _match_exact_known(
    cookie_name: str,
    cookie_domain: str,
    exact_known: list[KnownCookie],
) -> KnownCookie | None:
    """Return the first exact/wildcard known-cookie entry matching name + domain."""
    return next(
        (
            known
            for known in exact_known
            if _match_pattern(known.name_pattern, cookie_name)
            and _match_pattern(known.domain_pattern, cookie_domain)
        ),
        None,
    )
def _match_regex_known(
    cookie_name: str,
    cookie_domain: str,
    regex_known: list[KnownCookie],
) -> KnownCookie | None:
    """Return the first regex known-cookie entry matching name + domain."""
    return next(
        (
            known
            for known in regex_known
            if _match_regex(known.name_pattern, cookie_name)
            and _match_regex(known.domain_pattern, cookie_domain)
        ),
        None,
    )
def classify_cookie(
    cookie_name: str,
    cookie_domain: str,
    allow_list: list[CookieAllowListEntry],
    exact_known: list[KnownCookie],
    regex_known: list[KnownCookie],
    category_map: dict[uuid.UUID, CookieCategory],
) -> ClassificationResult:
    """Classify a single cookie against allow-list and known cookies DB.

    Pure function — all lookup data is passed in, no DB calls. Match priority:
    allow-list, then exact known cookies, then regex known cookies; anything
    else is returned unmatched.
    """
    # Site-specific allow-list overrides everything else. Note: allow-list
    # entries carry no vendor, so vendor stays None here.
    hit = _match_allow_list(cookie_name, cookie_domain, allow_list)
    if hit is not None:
        category = category_map.get(hit.category_id)
        return ClassificationResult(
            cookie_name=cookie_name,
            cookie_domain=cookie_domain,
            category_id=hit.category_id,
            category_slug=category.slug if category else None,
            description=hit.description,
            match_source=MatchSource.ALLOW_LIST,
            matched=True,
        )
    # Known-cookie catalogue: exact/wildcard entries first, then regex.
    for finder, pool, source in (
        (_match_exact_known, exact_known, MatchSource.KNOWN_EXACT),
        (_match_regex_known, regex_known, MatchSource.KNOWN_REGEX),
    ):
        known = finder(cookie_name, cookie_domain, pool)
        if known is not None:
            category = category_map.get(known.category_id)
            return ClassificationResult(
                cookie_name=cookie_name,
                cookie_domain=cookie_domain,
                category_id=known.category_id,
                category_slug=category.slug if category else None,
                vendor=known.vendor,
                description=known.description,
                match_source=source,
                matched=True,
            )
    # Nothing matched — leave the cookie pending review.
    return ClassificationResult(
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
    )
async def classify_site_cookies(
    db: AsyncSession,
    site_id: uuid.UUID,
    *,
    only_pending: bool = True,
) -> list[ClassificationResult]:
    """Classify a site's cookies against the allow-list and known-cookie DB.

    With only_pending=True (the default), only cookies that still have
    review_status='pending' and no category are considered. Matched cookies
    are updated in place — category always, vendor/description only when not
    already set — and the session is flushed before returning the results.
    """
    # Lookup data is loaded once and shared across all cookies.
    allow_list = await _load_allow_list(db, site_id)
    exact_known, regex_known = await _load_known_cookies(db)
    category_map = await _load_category_map(db)

    stmt = select(Cookie).where(Cookie.site_id == site_id)
    if only_pending:
        stmt = stmt.where(
            Cookie.review_status == "pending",
            Cookie.category_id.is_(None),
        )
    rows = await db.execute(stmt)

    outcomes: list[ClassificationResult] = []
    for cookie in rows.scalars().all():
        outcome = classify_cookie(
            cookie.name,
            cookie.domain,
            allow_list,
            exact_known,
            regex_known,
            category_map,
        )
        outcomes.append(outcome)
        if outcome.matched and outcome.category_id:
            cookie.category_id = outcome.category_id
            # Never clobber a manually set vendor/description.
            if outcome.vendor and not cookie.vendor:
                cookie.vendor = outcome.vendor
            if outcome.description and not cookie.description:
                cookie.description = outcome.description
    await db.flush()
    return outcomes
async def classify_single_cookie(
    db: AsyncSession,
    site_id: uuid.UUID,
    cookie_name: str,
    cookie_domain: str,
) -> ClassificationResult:
    """Classify one ad-hoc cookie (e.g. for preview/testing) without persisting."""
    lookup = (
        await _load_allow_list(db, site_id),
        *await _load_known_cookies(db),
        await _load_category_map(db),
    )
    return classify_cookie(cookie_name, cookie_domain, *lookup)

View File

@@ -0,0 +1,482 @@
"""Pluggable compliance rule engine.
Each regulatory framework (GDPR, CNIL, CCPA, ePrivacy, LGPD) is defined as a
list of ComplianceRule objects. Rules evaluate site configuration, banner
settings, cookie data, and consent parameters to produce issues with severity,
message, and recommendation.
The engine aggregates individual rule results into per-framework reports with
a compliance score, status, and actionable issues list.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from src.schemas.compliance import (
ComplianceIssue,
Framework,
FrameworkResult,
Severity,
)
# ── Rule context ──────────────────────────────────────────────────────
@dataclass
class SiteContext:
    """All data needed to evaluate compliance rules for a site."""
    # Site config fields
    blocking_mode: str = "opt_in"  # rules compare against "opt_in"/"opt_out"/"informational"
    regional_modes: dict[str, str] | None = None
    tcf_enabled: bool = False
    gcm_enabled: bool = True
    consent_expiry_days: int = 365  # CNIL rules compare against 182 / 395 day limits
    privacy_policy_url: str | None = None
    # Banner config (JSONB — may have any keys)
    banner_config: dict[str, Any] | None = None
    # Cookie statistics
    total_cookies: int = 0
    uncategorised_cookies: int = 0
    cookies_without_expiry: int = 0
    # Consent settings
    has_reject_button: bool = True
    has_granular_choices: bool = True
    has_cookie_wall: bool = False
    pre_ticked_boxes: bool = False
# ── Rule definition ───────────────────────────────────────────────────
# A check function receives a SiteContext and returns a list of issues.
CheckFn = Callable[[SiteContext], list[ComplianceIssue]]
@dataclass
class ComplianceRule:
    """A single compliance rule with an ID, description, and check function."""
    rule_id: str  # stable identifier; also used by the issues the rule emits
    description: str  # short human-readable summary of what the rule enforces
    check: CheckFn  # returns [] when the rule passes, issue(s) otherwise
# ── Helper factories ──────────────────────────────────────────────────
def _issue(
    rule_id: str,
    severity: Severity,
    message: str,
    recommendation: str,
) -> ComplianceIssue:
    """Shorthand for building a ComplianceIssue from positional parts."""
    return ComplianceIssue(
        rule_id=rule_id,
        severity=severity,
        message=message,
        recommendation=recommendation,
    )
# ── GDPR rules ────────────────────────────────────────────────────────
def _gdpr_opt_in(ctx: SiteContext) -> list[ComplianceIssue]:
    """Non-essential cookies may only be set after explicit opt-in."""
    if ctx.blocking_mode == "opt_in":
        return []
    return [
        _issue(
            "gdpr_opt_in",
            Severity.CRITICAL,
            "GDPR requires opt-in consent before setting non-essential cookies.",
            "Set blocking mode to 'opt_in'.",
        )
    ]
def _gdpr_reject_button(ctx: SiteContext) -> list[ComplianceIssue]:
    """Rejecting must be as easy as accepting."""
    if ctx.has_reject_button:
        return []
    return [
        _issue(
            "gdpr_reject_button",
            Severity.CRITICAL,
            "The reject option must be as prominent as the accept option.",
            "Add a clearly visible 'Reject all' button to the first layer.",
        )
    ]
def _gdpr_granular_consent(ctx: SiteContext) -> list[ComplianceIssue]:
    """Consent must be available per cookie category."""
    if ctx.has_granular_choices:
        return []
    return [
        _issue(
            "gdpr_granular",
            Severity.CRITICAL,
            "Users must be able to consent to individual cookie categories.",
            "Provide granular category toggles in the consent banner.",
        )
    ]
def _gdpr_no_cookie_wall(ctx: SiteContext) -> list[ComplianceIssue]:
    """Access must not be conditioned on giving consent."""
    if not ctx.has_cookie_wall:
        return []
    return [
        _issue(
            "gdpr_cookie_wall",
            Severity.CRITICAL,
            "Cookie walls (blocking access unless consent is given) are not permitted.",
            "Remove the cookie wall and allow access without consent.",
        )
    ]
def _gdpr_no_pre_ticked(ctx: SiteContext) -> list[ComplianceIssue]:
    """Consent checkboxes must default to unchecked."""
    if not ctx.pre_ticked_boxes:
        return []
    return [
        _issue(
            "gdpr_pre_ticked",
            Severity.CRITICAL,
            "Pre-ticked consent boxes do not constitute valid consent.",
            "Ensure all non-essential category checkboxes default to unchecked.",
        )
    ]
def _gdpr_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
    """A privacy-policy link should be reachable from the banner."""
    if ctx.privacy_policy_url:
        return []
    return [
        _issue(
            "gdpr_privacy_policy",
            Severity.WARNING,
            "A link to the privacy policy should be accessible from the banner.",
            "Configure a privacy policy URL in the site settings.",
        )
    ]
def _gdpr_uncategorised_cookies(ctx: SiteContext) -> list[ComplianceIssue]:
    """Every discovered cookie should have an assigned category."""
    if ctx.uncategorised_cookies <= 0:
        return []
    return [
        _issue(
            "gdpr_uncategorised",
            Severity.WARNING,
            f"{ctx.uncategorised_cookies} cookie(s) have not been categorised.",
            "Review and assign a category to all discovered cookies.",
        )
    ]
GDPR_RULES: list[ComplianceRule] = [
    ComplianceRule("gdpr_opt_in", "Opt-in consent required", _gdpr_opt_in),
    ComplianceRule("gdpr_reject_button", "Reject as prominent as accept", _gdpr_reject_button),
    ComplianceRule("gdpr_granular", "Granular category consent", _gdpr_granular_consent),
    ComplianceRule("gdpr_cookie_wall", "No cookie walls", _gdpr_no_cookie_wall),
    ComplianceRule("gdpr_pre_ticked", "No pre-ticked boxes", _gdpr_no_pre_ticked),
    ComplianceRule("gdpr_privacy_policy", "Privacy policy link", _gdpr_privacy_policy),
    ComplianceRule("gdpr_uncategorised", "All cookies categorised", _gdpr_uncategorised_cookies),
]
# ── CNIL rules (French — stricter GDPR) ──────────────────────────────
def _cnil_consent_expiry(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL mandates re-consent every 6 months (≈ 182 days)."""
    if ctx.consent_expiry_days <= 182:
        return []
    return [
        _issue(
            "cnil_reconsent",
            Severity.CRITICAL,
            "CNIL requires re-consent at least every 6 months.",
            "Set consent_expiry_days to 182 or fewer.",
        )
    ]
def _cnil_cookie_lifetime(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL limits cookie lifetime to 13 months (≈ 395 days)."""
    if ctx.consent_expiry_days <= 395:
        return []
    return [
        _issue(
            "cnil_cookie_lifetime",
            Severity.CRITICAL,
            "CNIL limits consent cookie lifetime to 13 months.",
            "Set consent_expiry_days to 395 or fewer.",
        )
    ]
def _cnil_reject_first_layer(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL requires 'Tout refuser' on the first layer of the banner."""
    if ctx.has_reject_button:
        return []
    return [
        _issue(
            "cnil_reject_first_layer",
            Severity.CRITICAL,
            "CNIL requires a 'Reject all' button on the first layer of the banner.",
            "Ensure the 'Reject all' button is visible on the first banner view.",
        )
    ]
# CNIL rules include all GDPR rules plus CNIL-specific ones
CNIL_RULES: list[ComplianceRule] = [
    *GDPR_RULES,
    ComplianceRule("cnil_reconsent", "Re-consent every 6 months", _cnil_consent_expiry),
    ComplianceRule("cnil_cookie_lifetime", "13-month cookie lifetime", _cnil_cookie_lifetime),
    ComplianceRule(
        "cnil_reject_first_layer",
        "Reject on first layer",
        _cnil_reject_first_layer,
    ),
]
# ── CCPA / CPRA rules ────────────────────────────────────────────────
def _ccpa_opt_out(ctx: SiteContext) -> list[ComplianceIssue]:
    """CCPA uses an opt-out model — opt_out (or stricter opt_in) is required."""
    if ctx.blocking_mode in ("opt_out", "opt_in"):
        return []
    return [
        _issue(
            "ccpa_opt_out",
            Severity.CRITICAL,
            "CCPA requires at minimum an opt-out mechanism for data sale.",
            "Set blocking mode to 'opt_out' or 'opt_in'.",
        )
    ]
def _ccpa_do_not_sell(ctx: SiteContext) -> list[ComplianceIssue]:
    """CCPA requires a 'Do Not Sell My Personal Information' link."""
    banner = ctx.banner_config or {}
    if banner.get("show_do_not_sell_link", False):
        return []
    return [
        _issue(
            "ccpa_do_not_sell",
            Severity.CRITICAL,
            "CCPA requires a 'Do Not Sell My Personal Information' link.",
            "Enable 'show_do_not_sell_link' in the banner configuration.",
        )
    ]
def _ccpa_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
    """A privacy policy is required under CCPA."""
    if ctx.privacy_policy_url:
        return []
    return [
        _issue(
            "ccpa_privacy_policy",
            Severity.WARNING,
            "A privacy policy is required under CCPA.",
            "Configure a privacy policy URL in the site settings.",
        )
    ]
CCPA_RULES: list[ComplianceRule] = [
    ComplianceRule("ccpa_opt_out", "Opt-out mechanism", _ccpa_opt_out),
    ComplianceRule("ccpa_do_not_sell", "Do Not Sell link", _ccpa_do_not_sell),
    ComplianceRule("ccpa_privacy_policy", "Privacy policy required", _ccpa_privacy_policy),
]
# ── ePrivacy rules ───────────────────────────────────────────────────
def _eprivacy_consent(ctx: SiteContext) -> list[ComplianceIssue]:
    """ePrivacy requires consent for non-essential cookies."""
    if ctx.blocking_mode != "informational":
        return []
    return [
        _issue(
            "eprivacy_consent",
            Severity.CRITICAL,
            "ePrivacy Directive requires consent for non-essential cookies.",
            "Set blocking mode to 'opt_in' or 'opt_out'.",
        )
    ]
def _eprivacy_necessary_exempt(ctx: SiteContext) -> list[ComplianceIssue]:
    """Strictly necessary cookies must be exempt from consent.

    Configuration-guidance check only: the blocker exempts necessary cookies
    by default, so there is currently nothing to flag.
    """
    return []
EPRIVACY_RULES: list[ComplianceRule] = [
    ComplianceRule("eprivacy_consent", "Consent for non-essential", _eprivacy_consent),
    ComplianceRule(
        "eprivacy_necessary_exempt",
        "Necessary cookies exempt",
        _eprivacy_necessary_exempt,
    ),
]
# ── LGPD rules (Brazil) ──────────────────────────────────────────────
def _lgpd_consent_basis(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD requires consent or legitimate interest as legal basis."""
    if ctx.blocking_mode != "informational":
        return []
    return [
        _issue(
            "lgpd_consent_basis",
            Severity.CRITICAL,
            "LGPD requires a legal basis (consent or legitimate interest) for data processing.",
            "Set blocking mode to 'opt_in' or 'opt_out'.",
        )
    ]
def _lgpd_data_controller(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD requires identifying the data controller."""
    if ctx.privacy_policy_url:
        return []
    return [
        _issue(
            "lgpd_data_controller",
            Severity.WARNING,
            "LGPD requires identification of the data controller.",
            "Link to a privacy policy that identifies the data controller.",
        )
    ]
def _lgpd_granular(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD recommends granular consent choices."""
    if ctx.has_granular_choices:
        return []
    return [
        _issue(
            "lgpd_granular",
            Severity.WARNING,
            "LGPD recommends granular consent choices.",
            "Provide individual category toggles in the consent banner.",
        )
    ]
LGPD_RULES: list[ComplianceRule] = [
    ComplianceRule("lgpd_consent_basis", "Legal basis for processing", _lgpd_consent_basis),
    ComplianceRule("lgpd_data_controller", "Identify data controller", _lgpd_data_controller),
    ComplianceRule("lgpd_granular", "Granular consent choices", _lgpd_granular),
]
# ── Framework registry ────────────────────────────────────────────────
# Maps each supported regulatory framework to its rule list. A framework
# missing from this registry evaluates zero rules and scores 100 (see
# run_framework_check, which falls back to an empty list).
FRAMEWORK_RULES: dict[Framework, list[ComplianceRule]] = {
    Framework.GDPR: GDPR_RULES,
    Framework.CNIL: CNIL_RULES,
    Framework.CCPA: CCPA_RULES,
    Framework.EPRIVACY: EPRIVACY_RULES,
    Framework.LGPD: LGPD_RULES,
}
# ── Engine ────────────────────────────────────────────────────────────
def run_framework_check(
    framework: Framework,
    ctx: SiteContext,
) -> FrameworkResult:
    """Evaluate every rule of one framework and aggregate into a result."""
    rules = FRAMEWORK_RULES.get(framework, [])
    # Evaluate each rule exactly once; an empty issue list means it passed.
    per_rule = [rule.check(ctx) for rule in rules]
    all_issues = [issue for found in per_rule for issue in found]
    rules_passed = sum(1 for found in per_rule if not found)
    rules_checked = len(rules)
    score = _calculate_score(all_issues, rules_checked)
    return FrameworkResult(
        framework=framework,
        score=score,
        status=_determine_status(score, all_issues),
        issues=all_issues,
        rules_checked=rules_checked,
        rules_passed=rules_passed,
    )
def run_compliance_check(
    ctx: SiteContext,
    frameworks: list[Framework] | None = None,
) -> list[FrameworkResult]:
    """Evaluate the given frameworks (default: all registered) against ctx."""
    selected = frameworks or list(FRAMEWORK_RULES)
    return [run_framework_check(framework, ctx) for framework in selected]
def calculate_overall_score(results: list[FrameworkResult]) -> int:
    """Average the per-framework scores, rounded to the nearest integer.

    An empty result set counts as fully compliant (100).
    """
    if not results:
        return 100
    return round(sum(r.score for r in results) / len(results))
# ── Scoring helpers ───────────────────────────────────────────────────
def _calculate_score(
    issues: list[ComplianceIssue],
    rules_checked: int,
) -> int:
    """Score from 0-100: critical issues deduct 20 pts, warnings 5, info 0."""
    if not rules_checked:
        return 100
    penalty = sum(
        20 if issue.severity == Severity.CRITICAL
        else 5 if issue.severity == Severity.WARNING
        else 0  # INFO issues don't affect the score
        for issue in issues
    )
    return max(0, 100 - penalty)
def _determine_status(
    score: int,
    issues: list[ComplianceIssue],
) -> str:
    """Derive the overall status string: any critical issue means non-compliance."""
    if any(issue.severity == Severity.CRITICAL for issue in issues):
        return "non_compliant"
    return "compliant" if score >= 100 else "partial"

View File

@@ -0,0 +1,156 @@
"""Configuration hierarchy resolver.
Resolves site configuration by merging:
System Defaults → Org Defaults → Site Group Defaults → Site Config → Regional Overrides
Produces a fully resolved public config suitable for the banner script.
"""
from __future__ import annotations
from typing import Any
# System-level defaults (hard-coded, lowest priority)
SYSTEM_DEFAULTS: dict[str, Any] = {
    "blocking_mode": "opt_in",  # safest default: block until consent
    "tcf_enabled": False,
    "gpp_enabled": True,
    "gpp_supported_apis": ["usnat"],
    "gpc_enabled": True,
    # Jurisdictions where a Global Privacy Control signal is honoured
    "gpc_jurisdictions": ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"],
    "gpc_global_honour": False,
    "gcm_enabled": True,
    "shopify_privacy_enabled": False,
    # Google Consent Mode default states: everything denied except security
    "gcm_default": {
        "ad_storage": "denied",
        "ad_user_data": "denied",
        "ad_personalization": "denied",
        "analytics_storage": "denied",
        "functionality_storage": "denied",
        "personalization_storage": "denied",
        "security_storage": "granted",
    },
    "banner_config": None,
    "privacy_policy_url": None,
    "terms_url": None,
    "consent_expiry_days": 365,
}
def resolve_config(
    site_config: dict[str, Any],
    org_defaults: dict[str, Any] | None = None,
    group_defaults: dict[str, Any] | None = None,
    region: str | None = None,
) -> dict[str, Any]:
    """Resolve the full configuration by merging layers.

    Merge order (lowest → highest priority):
    system defaults → org defaults → group defaults → site config,
    followed by an optional per-region blocking-mode override.

    Args:
        site_config: Site-specific configuration from the database.
        org_defaults: Organisation-level default overrides (optional).
        group_defaults: Site-group-level default overrides (optional).
        region: ISO region code for regional mode override (optional).

    Returns:
        Fully resolved configuration dictionary.
    """
    # Start with system defaults, then overlay each layer (None values in a
    # layer mean "inherit" and are skipped by _merge_non_none).
    resolved = {**SYSTEM_DEFAULTS}
    if org_defaults:
        _merge_non_none(resolved, org_defaults)
    if group_defaults:
        _merge_non_none(resolved, group_defaults)
    _merge_non_none(resolved, site_config)
    # Regional blocking-mode override. BUGFIX: read regional_modes from the
    # merged *resolved* config rather than site_config, so regional modes
    # inherited from org/group defaults take effect even when the site itself
    # defines none — matching the documented configuration cascade.
    if region:
        regional_modes = resolved.get("regional_modes")
        if isinstance(regional_modes, dict):
            # Try exact region match first, then fall back to DEFAULT
            regional_mode = regional_modes.get(region) or regional_modes.get("DEFAULT")
            if regional_mode:
                resolved["blocking_mode"] = regional_mode
    return resolved
def build_public_config(
    site_id: str,
    resolved: dict[str, Any],
) -> dict[str, Any]:
    """Project a resolved configuration onto the public banner-script shape.

    Internal-only fields are dropped and site_id is added for identification.
    Key order matches the served JSON shape.
    """
    public: dict[str, Any] = {
        "id": resolved.get("id", ""),
        "site_id": site_id,
    }
    # (key, required) pairs in output order; required keys must be present in
    # a fully resolved config, optional ones default to None when absent.
    field_spec = (
        ("blocking_mode", True),
        ("regional_modes", False),
        ("tcf_enabled", True),
        ("gpp_enabled", True),
        ("gpp_supported_apis", False),
        ("gpc_enabled", True),
        ("gpc_jurisdictions", False),
        ("gpc_global_honour", True),
        ("gcm_enabled", True),
        ("gcm_default", False),
        ("shopify_privacy_enabled", True),
        ("banner_config", False),
        ("privacy_policy_url", False),
        ("terms_url", False),
        ("consent_expiry_days", True),
        ("consent_group_id", False),
        ("ab_test", False),
    )
    for key, required in field_spec:
        public[key] = resolved[key] if required else resolved.get(key)
    return public
# Config attributes mirrored from SiteConfig / OrgConfig ORM rows.
CONFIG_FIELDS = (
    "blocking_mode",
    "regional_modes",
    "tcf_enabled",
    "tcf_publisher_cc",
    "gpp_enabled",
    "gpp_supported_apis",
    "gpc_enabled",
    "gpc_jurisdictions",
    "gpc_global_honour",
    "gcm_enabled",
    "gcm_default",
    "shopify_privacy_enabled",
    "banner_config",
    "privacy_policy_url",
    "terms_url",
    "consent_expiry_days",
)
def orm_to_config_dict(obj: Any, *, include_id: bool = False) -> dict[str, Any]:
    """Convert a SiteConfig or OrgConfig ORM object to a dict of config fields.

    Only attributes that are both present and non-NULL are included, so unset
    fields at a higher-priority layer don't mask values inherited from a
    lower-priority layer in the config cascade.
    """
    config: dict[str, Any] = {}
    if include_id and hasattr(obj, "id"):
        config["id"] = str(obj.id)
    config.update(
        {
            field: getattr(obj, field)
            for field in CONFIG_FIELDS
            if getattr(obj, field, None) is not None
        }
    )
    return config
def _merge_non_none(target: dict[str, Any], source: dict[str, Any]) -> None:
    """Overlay source onto target in place, ignoring None-valued source keys."""
    target.update({key: val for key, val in source.items() if val is not None})

View File

@@ -0,0 +1,77 @@
"""Dynamic CORS origin validation.
Provides an origin validator that checks incoming origins against
registered site domains (primary + additional) in addition to the
statically configured allowed_origins list.
"""
from __future__ import annotations
import logging
from urllib.parse import urlparse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.models.site import Site
logger = logging.getLogger(__name__)
def extract_domain_from_origin(origin: str) -> str | None:
"""Extract the hostname from an origin URL.
e.g. 'https://example.com:443''example.com'
"""
try:
parsed = urlparse(origin)
return parsed.hostname
except Exception:
return None
async def get_allowed_domains(db: AsyncSession) -> set[str]:
    """Collect every registered domain (primary + additional) across active,
    non-deleted sites, lower-cased for case-insensitive comparison."""
    rows = await db.execute(
        select(Site.domain, Site.additional_domains).where(
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    allowed: set[str] = set()
    for primary, additional in rows.all():
        allowed.add(primary.lower())
        for extra in additional or ():
            allowed.add(extra.lower())
    return allowed
def is_origin_allowed(
    origin: str,
    static_origins: list[str],
    registered_domains: set[str],
) -> bool:
    """Decide whether an Origin header value is acceptable for CORS.

    Allowed when it appears verbatim in the static origin list, when the
    static list contains the "*" wildcard, or when its hostname matches a
    registered site domain.

    Args:
        origin: The Origin header value (e.g. 'https://example.com').
        static_origins: Statically configured allowed origins from settings.
        registered_domains: Set of registered site domains from the database.

    Returns:
        True if the origin is allowed.
    """
    if origin in static_origins or "*" in static_origins:
        return True
    host = extract_domain_from_origin(origin)
    if host is None:
        return False
    return host.lower() in registered_domains

View File

@@ -0,0 +1,54 @@
import uuid
from collections.abc import Callable
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError
from src.schemas.auth import CurrentUser
from src.services.auth import decode_token
# Shared HTTP Bearer scheme; FastAPI uses it to pull the Authorization header.
bearer_scheme = HTTPBearer()
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> CurrentUser:
    """Extract and validate the current user from the JWT bearer token.

    Rejects invalid/expired tokens and non-access token types with 401.
    """
    try:
        claims = decode_token(credentials.credentials)
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    # Refresh tokens must not be usable as access tokens.
    if claims.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )
    return CurrentUser(
        id=uuid.UUID(claims["sub"]),
        organisation_id=uuid.UUID(claims["org_id"]),
        email=claims.get("email", ""),
        role=claims.get("role", "viewer"),
    )
def require_role(*allowed_roles: str) -> Callable:
    """Dependency factory that restricts access to users with specific roles."""
    async def _enforce(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.has_role(*allowed_roles):
            return current_user
        # 403, not 401: the caller is authenticated but not authorised.
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Role '{current_user.role}' is not permitted for this action",
        )
    return _enforce

View File

@@ -0,0 +1,339 @@
"""GeoIP service — resolve an IP address to a country/region code.
Resolution order (see :func:`detect_region`):
1. **CDN / proxy headers.** Operators configure ``GEOIP_COUNTRY_HEADER``
(and optionally ``GEOIP_REGION_HEADER``) to match whatever their edge
uses — e.g. ``cf-ipcountry`` + ``cf-region-code`` on Cloudflare
Enterprise, or ``x-gclb-country`` + ``x-gclb-region`` on GCP. A short
built-in country list (``cf-ipcountry``, ``x-vercel-ip-country``,
``x-appengine-country``, ``x-country-code``) covers the common case
where only country-level granularity is needed.
2. **Local MaxMind GeoLite2-City database.** Set
``GEOIP_MAXMIND_DB_PATH`` to a mounted ``.mmdb`` file. Gives both
country and ISO 3166-2 subdivision without any external calls.
3. **External ip-api.com lookup** (rate-limited, no API key). Last-ditch
fallback; fine for development, not recommended for production.
4. Unresolved — the caller should fall back to the default region.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
import geoip2.database
import httpx
from fastapi import Request
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
# Lazily-initialised MaxMind reader. ``geoip2.database.Reader`` opens
# the file once and then every lookup is a memory-mapped read, so we
# cache it for the lifetime of the process. ``None`` means either no
# path is configured, initialisation failed, or we haven't tried yet.
_maxmind_reader: geoip2.database.Reader | None = None
_maxmind_initialised = False
# Standard headers set by CDN / reverse proxy providers. Operators
# running behind a CDN that uses a non-standard header (e.g. Google
# Cloud Load Balancer's ``x-gclb-country``) can add one more via the
# ``GEOIP_COUNTRY_HEADER`` env var — see ``detect_region_from_headers``.
_GEO_HEADERS = [
"cf-ipcountry", # Cloudflare
"x-vercel-ip-country", # Vercel
"x-appengine-country", # Google App Engine
"x-country-code", # Generic / custom
]
# Mapping from two-letter country code to region codes used in regional_modes
# EU member states → "EU", US states handled separately, etc.
# NOTE(review): this is the EU-27 only. EEA members (IS, LI, NO) and the UK
# are NOT in the set, so visitors there resolve to their plain country code
# instead of "EU". GDPR also applies in the EEA and UK GDPR in the UK —
# operators must configure "NO"/"IS"/"LI"/"GB" keys in regional_modes (or
# rely on the default region). Confirm this is intentional.
_EU_COUNTRIES = frozenset(
    {
        "AT",
        "BE",
        "BG",
        "HR",
        "CY",
        "CZ",
        "DK",
        "EE",
        "FI",
        "FR",
        "DE",
        "GR",
        "HU",
        "IE",
        "IT",
        "LV",
        "LT",
        "LU",
        "MT",
        "NL",
        "PL",
        "PT",
        "RO",
        "SK",
        "SI",
        "ES",
        "SE",
    }
)
@dataclass(frozen=True)
class GeoResult:
    """Immutable outcome of a GeoIP resolution attempt.

    Both fields are ``None`` when no tier of the lookup chain could
    resolve the visitor's location.
    """

    # ISO 3166-1 alpha-2 country code (upper-case), or None if unresolved.
    country_code: str | None
    # regional_modes key derived from the country (e.g. "EU", "US-CA").
    region: str | None

    @property
    def is_resolved(self) -> bool:
        """True when the lookup produced at least a country code."""
        return self.country_code is not None
def country_to_region(country_code: str, state_code: str | None = None) -> str:
    """Translate a country (and optional subdivision) into a regional_modes key.

    EU member states always collapse to ``"EU"`` — the bloc is treated
    as one unit, so any subdivision is ignored. Outside the EU, a
    subdivision yields ``"{CC}-{SUB}"`` (``"US-CA"``, ``"GB-SCT"``);
    operators opt in to that granularity by configuring such a key in
    ``regional_modes``, and the fallback resolver still matches the
    bare country code otherwise. With no subdivision the upper-cased
    country code is returned unchanged.
    """
    cc = country_code.upper()
    if cc in _EU_COUNTRIES:
        return "EU"
    if not state_code:
        return cc
    return f"{cc}-{state_code.upper()}"
def detect_region_from_headers(request: Request) -> GeoResult:
    """Resolve the visitor's region purely from proxy/CDN headers.

    Fastest tier of the lookup chain — no I/O. A custom country header
    (``GEOIP_COUNTRY_HEADER``) is consulted before the built-in list so
    operators can plumb in non-standard edge headers without code
    changes. When ``GEOIP_REGION_HEADER`` is configured as well and the
    custom country header resolved, its subdivision is combined into
    keys like ``US-CA``. The built-in headers are country-only.

    Header lookups are case-insensitive; the sentinel value ``XX``
    (unknown) is treated as unresolved.
    """
    settings = get_settings()
    headers = request.headers
    custom_country = settings.geoip_country_header
    custom_region = settings.geoip_region_header

    if custom_country:
        raw_country = headers.get(custom_country)
        if raw_country and raw_country.upper() != "XX":
            country = raw_country.upper().strip()
            subdivision: str | None = None
            if custom_region:
                raw_sub = headers.get(custom_region)
                if raw_sub and raw_sub.upper() != "XX":
                    # Accept either a full ISO 3166-2 code ("US-CA") or a
                    # bare subdivision ("CA"); keep only the subdivision
                    # part for country_to_region.
                    normalised = raw_sub.strip().upper()
                    if "-" in normalised:
                        subdivision = normalised.split("-", 1)[1]
                    else:
                        subdivision = normalised
            return GeoResult(
                country_code=country,
                region=country_to_region(country, subdivision),
            )

    for name in _GEO_HEADERS:
        raw = headers.get(name)
        if not raw or raw.upper() == "XX":
            continue
        country = raw.upper().strip()
        return GeoResult(
            country_code=country,
            region=country_to_region(country),
        )

    return GeoResult(country_code=None, region=None)
def get_client_ip(request: Request) -> str | None:
    """Best-effort extraction of the originating client IP.

    Preference order: the first hop of ``X-Forwarded-For``, then
    ``X-Real-IP``, then the transport-level peer address. Returns
    ``None`` when none of these are available.
    """
    xff = request.headers.get("x-forwarded-for")
    if xff:
        # X-Forwarded-For is "client, proxy1, proxy2" — the left-most
        # entry is the original client.
        first_hop, _, _ = xff.partition(",")
        return first_hop.strip()

    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        return real_ip.strip()

    return request.client.host if request.client else None
async def lookup_ip_region(ip: str) -> GeoResult:
    """Resolve *ip* via the external ip-api.com service.

    Free tier: no API key, 45 requests/minute — acceptable for
    development, but production deployments should prefer the local
    MaxMind database tier. Private/loopback addresses and any error
    (network, non-200, malformed body) yield an unresolved result.
    """
    unresolved = GeoResult(country_code=None, region=None)
    if _is_private_ip(ip):
        return unresolved
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(
                f"http://ip-api.com/json/{ip}",
                params={"fields": "status,countryCode,region"},
            )
            if resp.status_code != 200:
                return unresolved
            payload = resp.json()
        if payload.get("status") != "success":
            return unresolved
        country = payload.get("countryCode")
        if not country:
            return unresolved
        # "region" carries the state/province subdivision code.
        state = payload.get("region")
        return GeoResult(country_code=country, region=country_to_region(country, state))
    except Exception:
        # Best-effort tier: swallow and log at debug; callers fall back
        # to the default region.
        logger.debug("GeoIP lookup failed for %s", ip, exc_info=True)
        return unresolved
def _get_maxmind_reader() -> geoip2.database.Reader | None:
    """Return the cached MaxMind reader, opening the DB on first use.

    Caches both successful opens and failures (via
    ``_maxmind_initialised``) so we don't retry a bad path on every
    request. Returns ``None`` if no path is configured or the DB
    couldn't be opened.

    NOTE(review): not guarded by a lock — under a multi-threaded server
    two requests could race the first open; harmless in a
    single-threaded async worker, but confirm the deployment model.
    """
    global _maxmind_reader, _maxmind_initialised
    if _maxmind_initialised:
        return _maxmind_reader
    # Set the flag up-front so a failed open below is also cached and
    # never re-attempted for the lifetime of the process.
    _maxmind_initialised = True
    db_path = get_settings().geoip_maxmind_db_path
    if not db_path:
        return None
    try:
        _maxmind_reader = geoip2.database.Reader(db_path)
        logger.info("GeoIP: opened MaxMind database at %s", db_path)
    except Exception:
        logger.warning(
            "GeoIP: failed to open MaxMind database at %s — falling back to "
            "external lookups. Check GEOIP_MAXMIND_DB_PATH and that the file "
            "is readable inside the container.",
            db_path,
            exc_info=True,
        )
        _maxmind_reader = None
    return _maxmind_reader
def lookup_ip_maxmind(ip: str) -> GeoResult:
    """Resolve an IP via the local MaxMind database.

    Memory-mapped read, no network I/O — cheap enough to call
    synchronously from the async path. Returns an unresolved
    ``GeoResult`` when the DB isn't configured, the IP is private, or
    the record can't be found.
    """
    if _is_private_ip(ip):
        return GeoResult(country_code=None, region=None)
    reader = _get_maxmind_reader()
    if reader is None:
        return GeoResult(country_code=None, region=None)
    try:
        response = reader.city(ip)
    except Exception:
        # geoip2 raises AddressNotFoundError for unknown IPs and
        # ValueError for malformed input — both are non-fatal here.
        logger.debug("MaxMind lookup failed for %s", ip, exc_info=True)
        return GeoResult(country_code=None, region=None)
    country = response.country.iso_code
    if not country:
        return GeoResult(country_code=None, region=None)
    # Per the GeoIP2 docs, ``subdivisions`` is ordered from least to
    # most specific and ``most_specific`` returns the LAST entry; its
    # ``iso_code`` is the ISO 3166-2 subdivision code without the
    # country prefix (e.g. "CA", not "US-CA").
    state = response.subdivisions.most_specific.iso_code if response.subdivisions else None
    return GeoResult(
        country_code=country.upper(),
        region=country_to_region(country, state),
    )
async def detect_region(request: Request) -> GeoResult:
    """Resolve the visitor's region, cheapest tier first.

    1. CDN/proxy headers (no I/O).
    2. Local MaxMind database, when ``GEOIP_MAXMIND_DB_PATH`` is set.
    3. External ip-api.com lookup as the last resort.

    Returns an unresolved :class:`GeoResult` if every tier fails.
    """
    from_headers = detect_region_from_headers(request)
    if from_headers.is_resolved:
        return from_headers

    client_ip = get_client_ip(request)
    if not client_ip:
        # Nothing to look up — caller falls back to the default region.
        return GeoResult(country_code=None, region=None)

    if get_settings().geoip_maxmind_db_path:
        local = lookup_ip_maxmind(client_ip)
        if local.is_resolved:
            return local

    return await lookup_ip_region(client_ip)
def _is_private_ip(ip: str) -> bool:
"""Check if an IP address is a private/loopback address."""
return (
ip.startswith("127.")
or ip.startswith("10.")
or ip.startswith("192.168.")
or ip.startswith("172.16.")
or ip.startswith("172.17.")
or ip.startswith("172.18.")
or ip.startswith("172.19.")
or ip.startswith("172.2")
or ip.startswith("172.3")
or ip == "::1"
or ip == "localhost"
)

View File

@@ -0,0 +1,41 @@
"""Pseudonymisation helpers for consent records.
Consent records capture a hash of the visitor's IP address and
user-agent string for abuse protection and audit trail purposes.
Previously this used an unsalted truncated SHA-256, which is trivially
reversible for IPv4 addresses (only ~4 billion inputs). We now use
HMAC-SHA256 keyed with a server-side secret so the hash cannot be
recovered without access to the secret.
Public API: :func:`pseudonymise`.
"""
from __future__ import annotations
import hmac
from hashlib import sha256
from src.config.settings import get_settings
# Length of the hex-encoded digest stored in the database. 32 hex chars
# = 128 bits, which is more than enough entropy while keeping the
# column compact. (Previous code used 16 hex chars = 64 bits.)
_DIGEST_HEX_LEN = 32
def pseudonymise(value: str) -> str:
    """Return a keyed hash of *value* safe to store in an audit record.

    Computes HMAC-SHA256 under the configured
    ``pseudonymisation_secret`` (``jwt_secret_key`` as fallback, per
    settings) and truncates the hex digest to 32 characters (128 bits).
    Empty input maps to the empty string so callers never need to
    branch on missing data.
    """
    if not value:
        return ""
    secret = get_settings().pseudonymisation_key
    mac = hmac.new(secret, value.encode("utf-8"), sha256)
    return mac.hexdigest()[:_DIGEST_HEX_LEN]

View File

@@ -0,0 +1,89 @@
"""CDN publishing pipeline.
Publishes resolved site configurations as static JSON files for the
banner script to fetch. Supports local filesystem (development) and
can be extended for S3/GCS/CloudFront.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from src.config.settings import get_settings
from .config_resolver import build_public_config, resolve_config
logger = logging.getLogger(__name__)
class PublishResult:
"""Result of a publish operation."""
def __init__(self, success: bool, path: str, error: str | None = None) -> None:
self.success = success
self.path = path
self.error = error
self.published_at = datetime.now(UTC).isoformat() if success else None
async def publish_site_config(
    site_id: str,
    site_config: dict[str, Any],
    org_defaults: dict[str, Any] | None = None,
) -> PublishResult:
    """Resolve a site's configuration cascade and publish it to the CDN.

    Args:
        site_id: The site UUID as a string.
        site_config: Raw site configuration from the database.
        org_defaults: Organisation-level defaults (optional).

    Returns:
        PublishResult — failures are captured in the result rather than
        raised, so callers always get a status object.
    """
    try:
        # Cascade: merge org defaults into the site's raw config, then
        # strip it down to the public (banner-facing) shape.
        merged = resolve_config(site_config, org_defaults)
        public_config = build_public_config(site_id, merged)
        path = await _publish_local(site_id, public_config, get_settings().cdn_base_url)
    except Exception as exc:
        logger.exception("Failed to publish config for site %s", site_id)
        return PublishResult(success=False, path="", error=str(exc))
    logger.info("Published config for site %s to %s", site_id, path)
    return PublishResult(success=True, path=path)
async def _publish_local(
    site_id: str,
    config: dict[str, Any],
    cdn_base: str,
) -> str:
    """Write the public config to the local filesystem backend.

    Targets the CDN proxy's HTML directory (``/app/cdn-publish`` inside
    a container, ``cdn-publish`` relative to CWD otherwise) so nginx
    can serve the files. Writes both a stable name and a
    timestamp-versioned copy for cache-busting. Returns the stable
    path as a string.
    """
    in_container = Path("/app").exists()
    target_dir = Path("/app/cdn-publish") if in_container else Path("cdn-publish")
    target_dir.mkdir(parents=True, exist_ok=True)

    # Serialise once; both files get identical content.
    serialised = json.dumps(config, indent=2, default=str)

    primary = target_dir / f"site-config-{site_id}.json"
    primary.write_text(serialised)

    stamp = datetime.now(UTC).strftime("%Y%m%d%H%M%S")
    versioned = target_dir / f"site-config-{site_id}-{stamp}.json"
    versioned.write_text(serialised)

    return str(primary)

View File

@@ -0,0 +1,322 @@
"""Scan orchestration and diff engine.
Provides scan job lifecycle management, result diffing between scans,
and cookie inventory synchronisation from scan results.
"""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.models.cookie import Cookie
from src.models.scan import ScanJob, ScanResult
from src.models.site import Site
from src.schemas.scanner import (
CookieDiffItem,
DiffStatus,
ScanDiffResponse,
)
async def create_scan_job(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    trigger: str = "manual",
    max_pages: int = 50,
) -> ScanJob:
    """Insert a ScanJob row in 'pending' state.

    Flushes so the returned job carries its database-assigned ID; the
    caller owns the transaction commit.
    """
    pending_job = ScanJob(
        site_id=site_id,
        status="pending",
        trigger=trigger,
        pages_total=max_pages,
    )
    db.add(pending_job)
    await db.flush()
    return pending_job
async def start_scan_job(db: AsyncSession, job: ScanJob) -> ScanJob:
    """Move a scan job into the 'running' state.

    Idempotent: a job already in 'running' (e.g. after Celery
    re-delivery following a worker crash) is returned untouched. A job
    being retried out of 'failed' has its previous error cleared so the
    new attempt starts clean.
    """
    if job.status != "running":
        job.status = "running"
        job.started_at = datetime.now(UTC)
        # Wipe any stale error from a prior failed attempt.
        job.error_message = None
        await db.flush()
    return job
async def complete_scan_job(
    db: AsyncSession,
    job: ScanJob,
    *,
    pages_scanned: int = 0,
    cookies_found: int = 0,
    error_message: str | None = None,
) -> ScanJob:
    """Finalise a scan job.

    A truthy *error_message* marks the job 'failed'; otherwise it is
    marked 'completed'. Stamps the completion time and result counters,
    then flushes — commit remains the caller's responsibility.
    """
    job.completed_at = datetime.now(UTC)
    job.pages_scanned = pages_scanned
    job.cookies_found = cookies_found
    job.error_message = error_message
    job.status = "completed" if not error_message else "failed"
    await db.flush()
    return job
async def add_scan_result(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    page_url: str,
    cookie_name: str,
    cookie_domain: str,
    storage_type: str = "cookie",
    attributes: dict | None = None,
    script_source: str | None = None,
    auto_category: str | None = None,
    initiator_chain: list[str] | None = None,
) -> ScanResult:
    """Persist a single cookie/storage discovery for a scan job.

    Flushes so the returned row carries its database identity; commit
    is left to the caller.
    """
    row = ScanResult(
        scan_job_id=scan_job_id,
        page_url=page_url,
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
        storage_type=storage_type,
        attributes=attributes,
        script_source=script_source,
        auto_category=auto_category,
        initiator_chain=initiator_chain,
    )
    db.add(row)
    await db.flush()
    return row
async def get_previous_completed_scan(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    before_scan_id: uuid.UUID,
) -> ScanJob | None:
    """Return the newest completed scan created before the given one.

    Returns ``None`` when the reference scan doesn't exist or no older
    completed scan is found for the site.
    """
    # Anchor on the reference scan's creation time.
    ref_time = (
        await db.execute(select(ScanJob.created_at).where(ScanJob.id == before_scan_id))
    ).scalar_one_or_none()
    if ref_time is None:
        return None

    stmt = (
        select(ScanJob)
        .where(
            ScanJob.site_id == site_id,
            ScanJob.status == "completed",
            ScanJob.id != before_scan_id,
            ScanJob.created_at < ref_time,
        )
        .order_by(ScanJob.created_at.desc())
        .limit(1)
    )
    return (await db.execute(stmt)).scalar_one_or_none()
def _result_key(r: ScanResult) -> tuple[str, str, str]:
"""Unique key for a scan result (cookie identity)."""
return (r.cookie_name, r.cookie_domain, r.storage_type)
async def compute_scan_diff(
    db: AsyncSession,
    *,
    current_scan_id: uuid.UUID,
    site_id: uuid.UUID,
) -> ScanDiffResponse:
    """Diff the current scan against the site's previous completed scan.

    Classifies each cookie identity as new (current only), removed
    (previous only) or changed (present in both with differing
    script source, auto-category or attributes). With no previous scan,
    every current cookie is reported as new.
    """
    previous_scan = await get_previous_completed_scan(
        db, site_id=site_id, before_scan_id=current_scan_id
    )

    current_rows = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == current_scan_id)
    )
    current_items = list(current_rows.scalars().all())
    current_by_key = {_result_key(r): r for r in current_items}

    if previous_scan is None:
        # Baseline scan: report every discovered item as new.
        baseline = [
            CookieDiffItem(
                name=r.cookie_name,
                domain=r.cookie_domain,
                storage_type=r.storage_type,
                diff_status=DiffStatus.NEW,
                details="First scan — no previous data",
            )
            for r in current_items
        ]
        return ScanDiffResponse(
            current_scan_id=current_scan_id,
            previous_scan_id=None,
            new_cookies=baseline,
            total_new=len(baseline),
        )

    prev_rows = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == previous_scan.id)
    )
    prev_by_key = {_result_key(r): r for r in prev_rows.scalars().all()}

    new_cookies: list[CookieDiffItem] = []
    changed_cookies: list[CookieDiffItem] = []

    # One pass over the current scan handles both "new" and "changed".
    for key, curr in current_by_key.items():
        prev = prev_by_key.get(key)
        if prev is None:
            new_cookies.append(
                CookieDiffItem(
                    name=curr.cookie_name,
                    domain=curr.cookie_domain,
                    storage_type=curr.storage_type,
                    diff_status=DiffStatus.NEW,
                )
            )
            continue
        field_changes: list[str] = []
        if curr.script_source != prev.script_source:
            field_changes.append("script_source changed")
        if curr.auto_category != prev.auto_category:
            field_changes.append("auto_category changed")
        # None and {} are treated as equivalent attribute sets.
        if (curr.attributes or {}) != (prev.attributes or {}):
            field_changes.append("attributes changed")
        if field_changes:
            changed_cookies.append(
                CookieDiffItem(
                    name=curr.cookie_name,
                    domain=curr.cookie_domain,
                    storage_type=curr.storage_type,
                    diff_status=DiffStatus.CHANGED,
                    details="; ".join(field_changes),
                )
            )

    removed_cookies = [
        CookieDiffItem(
            name=prev.cookie_name,
            domain=prev.cookie_domain,
            storage_type=prev.storage_type,
            diff_status=DiffStatus.REMOVED,
        )
        for key, prev in prev_by_key.items()
        if key not in current_by_key
    ]

    return ScanDiffResponse(
        current_scan_id=current_scan_id,
        previous_scan_id=previous_scan.id,
        new_cookies=new_cookies,
        removed_cookies=removed_cookies,
        changed_cookies=changed_cookies,
        total_new=len(new_cookies),
        total_removed=len(removed_cookies),
        total_changed=len(changed_cookies),
    )
async def sync_scan_results_to_cookies(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    site_id: uuid.UUID,
) -> int:
    """Upsert scan results into the site's cookie inventory.

    Creates new Cookie records for newly discovered items or updates
    ``last_seen_at`` for existing ones. Returns the number of new
    cookies added.

    Fixes over the previous implementation:
    - The existing inventory is loaded in ONE query and indexed by
      (name, domain, storage_type) instead of issuing a SELECT per scan
      result (N+1).
    - Newly created cookies join the same index, so duplicate rows in
      the scan results can no longer insert duplicate Cookie records
      (the old per-item SELECT couldn't see unflushed inserts from
      earlier iterations).
    """
    results = await db.execute(select(ScanResult).where(ScanResult.scan_job_id == scan_job_id))
    items = list(results.scalars().all())

    # Index the current inventory by cookie identity with one query.
    existing_rows = await db.execute(select(Cookie).where(Cookie.site_id == site_id))
    inventory: dict[tuple[str, str, str], Cookie] = {
        (c.name, c.domain, c.storage_type): c for c in existing_rows.scalars().all()
    }

    now_iso = datetime.now(UTC).isoformat()
    new_count = 0
    for item in items:
        key = (item.cookie_name, item.cookie_domain, item.storage_type)
        cookie = inventory.get(key)
        if cookie is not None:
            # Known cookie (or duplicate within this scan) — refresh only.
            cookie.last_seen_at = now_iso
            continue
        cookie = Cookie(
            site_id=site_id,
            name=item.cookie_name,
            domain=item.cookie_domain,
            storage_type=item.storage_type,
            review_status="pending",
            first_seen_at=now_iso,
            last_seen_at=now_iso,
        )
        db.add(cookie)
        inventory[key] = cookie
        new_count += 1
    await db.flush()
    return new_count
async def get_sites_due_for_scan(db: AsyncSession) -> list[Site]:
    """Find sites with a scan schedule that are due for scanning.

    A site is due when it has a scan_schedule_cron set and either has
    never been scanned or the last scan completed before the schedule
    interval. For simplicity, this checks the most recent scan's
    completed_at against the current time minus a derived interval.

    NOTE(review): the paragraph above describes intent, but the query
    below does NOT inspect scan history — it returns EVERY active,
    non-deleted site with a cron expression set. Combined with the
    15-minute ``check_scheduled_scans`` beat task, that triggers a scan
    for each such site on every tick. Either filter on the latest
    ScanJob's completed_at here, or confirm callers deduplicate.
    """
    # Local import — presumably avoids a circular import at module
    # load time; TODO confirm.
    from src.models.site_config import SiteConfig
    # Find sites with a cron schedule
    result = await db.execute(
        select(Site)
        .join(SiteConfig, SiteConfig.site_id == Site.id)
        .where(
            Site.deleted_at.is_(None),
            Site.is_active.is_(True),
            SiteConfig.scan_schedule_cron.isnot(None),
        )
    )
    return list(result.scalars().all())

View File

View File

@@ -0,0 +1,87 @@
"""Consent record retention purge.
Deletes consent records older than each site's configured
``consent_retention_days``. Sites with no retention configured are
skipped — operators must explicitly opt in per site (or set it at the
org/system level and let the cascade resolve it).
Scheduled by ``celery beat`` daily at 01:00 UTC via the entry in
``src.celery_app.beat_schedule``.
"""
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from src.celery_app import app
logger = logging.getLogger(__name__)
async def _purge() -> dict[str, int]:
    """Delete expired consent records across all sites with retention set.

    Returns a summary ``{"sites_processed": N, "records_deleted": M}``.

    All per-site deletes run in a single transaction committed once at
    the end; a failure on any site rolls back the whole run (the
    session context manager closes without committing).
    """
    # Imported lazily so the Celery worker doesn't pay for SQLAlchemy /
    # model imports at module load — TODO confirm that is the intent.
    from sqlalchemy import delete, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from src.config.settings import get_settings
    from src.models.consent import ConsentRecord
    from src.models.site_config import SiteConfig
    settings = get_settings()
    # Fresh engine per run — nothing is shared with the API process.
    engine = create_async_engine(settings.database_url, echo=False)
    sites_processed = 0
    records_deleted = 0
    async with AsyncSession(engine, expire_on_commit=False) as session:
        # Only sites that explicitly opted in to retention.
        configs = (
            (
                await session.execute(
                    select(SiteConfig).where(SiteConfig.consent_retention_days.isnot(None)),
                )
            )
            .scalars()
            .all()
        )
        now = datetime.now(UTC)
        for cfg in configs:
            retention_days = cfg.consent_retention_days
            # Guard against zero/negative values slipping through config.
            if not retention_days or retention_days <= 0:
                continue
            cutoff = now - timedelta(days=retention_days)
            result = await session.execute(
                delete(ConsentRecord).where(
                    ConsentRecord.site_id == cfg.site_id,
                    ConsentRecord.consented_at < cutoff,
                ),
            )
            # rowcount may be None depending on driver; coerce to 0.
            deleted = result.rowcount or 0
            records_deleted += deleted
            sites_processed += 1
            if deleted:
                logger.info(
                    "retention.purged",
                    extra={
                        "site_id": str(cfg.site_id),
                        "retention_days": retention_days,
                        "deleted": deleted,
                        "cutoff": cutoff.isoformat(),
                    },
                )
        await session.commit()
    # NOTE(review): dispose is skipped if anything above raises —
    # consider try/finally so the engine is always cleaned up.
    await engine.dispose()
    return {"sites_processed": sites_processed, "records_deleted": records_deleted}
@app.task(name="src.tasks.retention.purge_expired_consent_records")
def purge_expired_consent_records() -> dict[str, int]:
    """Synchronous Celery shim: drive the async purge to completion.

    Returns the purge summary dict so it lands in the task result.
    """
    summary = asyncio.run(_purge())
    return summary

View File

@@ -0,0 +1,308 @@
"""Celery tasks for scan job execution and scheduling.
The run_scan task calls the scanner HTTP service to execute a Playwright
crawl, then processes the results: stores scan results, runs auto-
classification, syncs discovered cookies to the site inventory, and
computes diffs against the previous scan.
"""
from __future__ import annotations
import logging
import uuid
import httpx
from src.celery_app import app
logger = logging.getLogger(__name__)
@app.task(name="src.tasks.scanner.run_scan", bind=True, max_retries=2)
def run_scan(self, scan_job_id: str, site_id: str) -> dict:
    """Execute a scan job by calling the scanner service.

    1. Transition job to 'running'
    2. Look up site domain
    3. Call scanner HTTP service with the domain
    4. Store scan results and run auto-classification
    5. Sync discovered cookies to the site inventory
    6. Mark job as completed

    Args:
        scan_job_id: ScanJob UUID (string form, as sent over the broker).
        site_id: Site UUID (string form).

    Returns:
        A summary dict on success, or ``{"error": ...}`` when the job or
        site is missing or a non-HTTP exception occurs. HTTP errors from
        the scanner service trigger a Celery retry (max_retries=2,
        30s countdown).
    """
    # All imports are task-local so the Celery worker only loads the web
    # stack when this task actually runs.
    import asyncio
    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from src.config.settings import get_settings
    from src.models.scan import ScanJob
    from src.models.site import Site
    from src.services.classification import classify_single_cookie
    from src.services.scanner import (
        add_scan_result,
        complete_scan_job,
        start_scan_job,
        sync_scan_results_to_cookies,
    )
    settings = get_settings()
    job_uuid = uuid.UUID(scan_job_id)
    site_uuid = uuid.UUID(site_id)
    async def _execute() -> dict:
        # Per-task engine: nothing shared across worker processes/forks.
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                # Load the job
                result = await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
                job = result.scalar_one_or_none()
                if job is None:
                    return {"error": "Scan job not found"}
                # Load the site to get the domain
                site_result = await db.execute(select(Site).where(Site.id == site_uuid))
                site = site_result.scalar_one_or_none()
                if site is None:
                    return {"error": "Site not found"}
                # Transition to running (idempotent on re-delivery) and
                # commit so the status is visible to the API immediately.
                await start_scan_job(db, job)
                await db.commit()
                # Call the scanner service
                scanner_url = f"{settings.scanner_service_url}/scan"
                max_pages = job.pages_total or 50
                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(settings.scanner_timeout_seconds)
                ) as client:
                    resp = await client.post(
                        scanner_url,
                        json={
                            "domain": site.domain,
                            "max_pages": max_pages,
                        },
                    )
                    resp.raise_for_status()
                    scan_data = resp.json()
                # Store scan results
                cookies = scan_data.get("cookies", [])
                pages_crawled = scan_data.get("pages_crawled", 0)
                for cookie in cookies:
                    # Auto-classify the cookie
                    category = await classify_single_cookie(
                        db,
                        site_id=site_uuid,
                        cookie_name=cookie["name"],
                        cookie_domain=cookie["domain"],
                    )
                    await add_scan_result(
                        db,
                        scan_job_id=job_uuid,
                        page_url=cookie.get("page_url", ""),
                        cookie_name=cookie["name"],
                        cookie_domain=cookie["domain"],
                        storage_type=cookie.get("storage_type", "cookie"),
                        attributes={
                            "path": cookie.get("path"),
                            "http_only": cookie.get("http_only"),
                            "secure": cookie.get("secure"),
                            "same_site": cookie.get("same_site"),
                            "value_length": cookie.get("value_length", 0),
                        },
                        script_source=cookie.get("script_source"),
                        auto_category=category.category_slug if category else None,
                        initiator_chain=cookie.get("initiator_chain") or None,
                    )
                await db.commit()
                # Mark job as completed
                await complete_scan_job(
                    db,
                    job,
                    pages_scanned=pages_crawled,
                    cookies_found=len(cookies),
                )
                await db.commit()
                # Sync results to cookie inventory. Committed separately —
                # a failure here leaves the job 'completed' with results
                # stored but the inventory unsynced.
                new_cookies = await sync_scan_results_to_cookies(
                    db,
                    scan_job_id=job_uuid,
                    site_id=site_uuid,
                )
                await db.commit()
                logger.info(
                    "Scan %s completed: %d pages, %d cookies, %d new",
                    scan_job_id,
                    pages_crawled,
                    len(cookies),
                    new_cookies,
                )
                return {
                    "scan_job_id": scan_job_id,
                    "status": "completed",
                    "pages_scanned": pages_crawled,
                    "cookies_found": len(cookies),
                    "new_cookies_synced": new_cookies,
                }
            except httpx.HTTPError as exc:
                logger.error("Scanner service error for job %s: %s", scan_job_id, exc)
                await db.rollback()
                # Only mark failed on the final retry; otherwise let the
                # retry set status back to "running" cleanly.
                if self.request.retries >= self.max_retries:
                    await _mark_failed(db, job_uuid, f"Scanner service error: {exc}")
                # NOTE(review): on the final attempt self.retry raises
                # MaxRetriesExceededError after the job was marked failed
                # — confirm Celery surfacing that as a task failure is
                # the intended terminal behaviour.
                raise self.retry(exc=exc, countdown=30) from exc
            except Exception as exc:
                logger.exception("Scan task failed for job %s", scan_job_id)
                await db.rollback()
                await _mark_failed(db, job_uuid, str(exc))
                return {"error": str(exc)}
            finally:
                await engine.dispose()
    return asyncio.run(_execute())
async def _mark_failed(db, job_uuid: uuid.UUID, message: str) -> None:
    """Best-effort: flag a scan job as failed and commit; never raises."""
    from sqlalchemy import select
    from src.models.scan import ScanJob
    from src.services.scanner import complete_scan_job
    try:
        row = await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
        job = row.scalar_one_or_none()
        if job is None:
            # Nothing to update — the job vanished or never existed.
            return
        await complete_scan_job(db, job, error_message=message)
        await db.commit()
    except Exception:
        # Swallow deliberately: this runs inside error paths where a
        # second failure must not mask the original one.
        logger.exception("Failed to mark scan job %s as failed", job_uuid)
@app.task(name="src.tasks.scanner.check_scheduled_scans")
def check_scheduled_scans() -> dict:
    """Periodic task (Celery Beat, every 15 minutes): trigger due scans.

    Creates a 'scheduled' scan job for each site returned by
    ``get_sites_due_for_scan`` and dispatches ``run_scan`` for it.
    Returns ``{"sites_checked": N, "scans_triggered": M}``.
    """
    import asyncio
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from src.config.settings import get_settings
    from src.services.scanner import create_scan_job, get_sites_due_for_scan
    settings = get_settings()

    async def _check() -> dict:
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                due_sites = await get_sites_due_for_scan(db)
                dispatched = 0
                for site in due_sites:
                    job = await create_scan_job(db, site_id=site.id, trigger="scheduled")
                    # Commit before dispatch so the worker can see the row.
                    await db.commit()
                    run_scan.delay(str(job.id), str(site.id))
                    dispatched += 1
                return {"sites_checked": len(due_sites), "scans_triggered": dispatched}
            except Exception:
                await db.rollback()
                raise
            finally:
                await engine.dispose()

    return asyncio.run(_check())
@app.task(name="src.tasks.scanner.recover_stale_scans")
def recover_stale_scans() -> dict:
    """Periodic task: detect and recover scan jobs stuck in pending/running.

    - Jobs stuck in 'pending' for >5 minutes are re-dispatched to Celery.
    - Jobs stuck in 'running' for >10 minutes are marked as failed.

    Returns ``{"stale_jobs_found": N, "redispatched": X, "failed": Y}``.

    NOTE(review): re-dispatching a pending job leaves it in 'pending',
    so a job held up by broker backlog is re-queued on EVERY run of this
    task, stacking duplicate deliveries. run_scan's idempotent start
    avoids concurrent double-running, but consider recording a
    re-dispatch marker or capping attempts.
    """
    import asyncio
    from datetime import UTC, datetime, timedelta
    from sqlalchemy import or_, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from src.config.settings import get_settings
    from src.models.scan import ScanJob
    from src.services.scanner import complete_scan_job
    settings = get_settings()
    async def _recover() -> dict:
        # Dedicated engine per run; disposed in the finally below.
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                now = datetime.now(UTC)
                stale_pending_cutoff = now - timedelta(minutes=5)
                stale_running_cutoff = now - timedelta(minutes=10)
                result = await db.execute(
                    select(ScanJob).where(
                        or_(
                            # Pending too long — likely never picked up
                            (ScanJob.status == "pending")
                            & (ScanJob.created_at < stale_pending_cutoff),
                            # Running too long — likely worker died
                            (ScanJob.status == "running")
                            & (ScanJob.started_at < stale_running_cutoff),
                        )
                    )
                )
                stale_jobs = list(result.scalars().all())
                redispatched = 0
                failed = 0
                for job in stale_jobs:
                    if job.status == "pending":
                        # Re-dispatch to Celery
                        logger.warning("Re-dispatching stale pending scan job %s", job.id)
                        run_scan.delay(str(job.id), str(job.site_id))
                        redispatched += 1
                    elif job.status == "running":
                        # Mark as failed — the worker likely died
                        logger.warning("Failing stale running scan job %s", job.id)
                        await complete_scan_job(
                            db,
                            job,
                            error_message=(
                                "Job timed out (running too long, worker may have crashed)"
                            ),
                        )
                        failed += 1
                # Single commit covers all status flips above.
                await db.commit()
                return {
                    "stale_jobs_found": len(stale_jobs),
                    "redispatched": redispatched,
                    "failed": failed,
                }
            except Exception:
                await db.rollback()
                raise
            finally:
                await engine.dispose()
    return asyncio.run(_recover())

8
apps/api/start.sh Executable file
View File

@@ -0,0 +1,8 @@
#!/bin/sh
# API container entrypoint.
#
# PORT and WEB_CONCURRENCY are platform-injected; defaults: 8000 / 1 worker.
#
# --proxy-headers with --forwarded-allow-ips '*' makes uvicorn trust
# X-Forwarded-For / X-Forwarded-Proto from ANY peer. That is only safe when
# the app is reachable exclusively through the platform's edge proxy.
# NOTE(review): confirm the container port is never directly exposed,
# otherwise clients can spoof their IP (affects rate limiting and GeoIP).
#
# `exec` replaces the shell so uvicorn receives signals (SIGTERM) directly.
exec uvicorn src.main:app \
--host 0.0.0.0 \
--port "${PORT:-8000}" \
--workers "${WEB_CONCURRENCY:-1}" \
--access-log \
--proxy-headers \
--forwarded-allow-ips '*'

View File

241
apps/api/tests/conftest.py Normal file
View File

@@ -0,0 +1,241 @@
"""Shared test fixtures for the CMP API test suite.
Provides two modes:
- Unit tests: use `app` and `client` fixtures (no database required)
- Integration tests: use `db_client` fixture (requires PostgreSQL)
Integration tests are automatically skipped when no database is available.
"""
import os
# Disable rate limiting for the test suite. Many tests make dozens of
# requests from the same loopback address in rapid succession and the
# middleware would legitimately reject them as a DoS; the middleware
# has its own dedicated test module.
os.environ.setdefault("RATE_LIMIT_ENABLED", "false")
os.environ.setdefault("ENVIRONMENT", "test")
import uuid
import pytest
import pytest_asyncio
from httpx import ASGITransport, AsyncClient
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from src.main import create_app
from src.models.base import Base
# ── Detect whether a test database is available ──────────────────────
_TEST_DB_URL = os.environ.get(
"TEST_DATABASE_URL",
os.environ.get("DATABASE_URL", ""),
)
_HAS_DB = bool(_TEST_DB_URL) and "localhost" in _TEST_DB_URL
def _requires_db(fn):
    """Mark a test as requiring a live database.

    Also pins the event loop to session scope so that fixtures sharing the
    session-scoped engine don't get 'Future attached to a different loop'.
    """
    session_loop = pytest.mark.asyncio(loop_scope="session")
    skip_without_db = pytest.mark.skipif(
        not _HAS_DB, reason="No test database available"
    )
    # Same stacking order as applying asyncio first, then skipif.
    return skip_without_db(session_loop(fn))


# Public alias used by the test modules.
requires_db = _requires_db
# ── Unit test fixtures (no database) ─────────────────────────────────
@pytest.fixture
def app():
    """Create a fresh FastAPI application instance.

    A new instance per test keeps per-test dependency overrides from
    leaking between tests.
    """
    return create_app()
@pytest.fixture
async def client(app):
    """Async HTTP client for unit tests (no database).

    Dispatches requests in-process via ASGITransport — no network socket.
    """
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
# ── Integration test fixtures (with database) ────────────────────────
@pytest.fixture(scope="session")
def _test_engine():
    """Create a test database engine (session-scoped).

    Skips (rather than fails) dependent tests when no local test
    database is configured.
    """
    if not _HAS_DB:
        pytest.skip("No test database available")
    return create_async_engine(_TEST_DB_URL, echo=False)
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def _setup_db(_test_engine):
    """Create all tables once per test session, then seed fixture data.

    Tests that depend on the cookie-category seed (normally applied by
    the ``0001_initial_schema`` alembic migration) get the same rows
    here so they can run without invoking alembic.
    """
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
        await _seed_cookie_categories(conn)
    yield
    # Teardown: drop everything so the next session starts from a clean slate.
    async with _test_engine.begin() as conn:
        await conn.run_sync(Base.metadata.drop_all)
async def _seed_cookie_categories(conn) -> None:
    """Insert the default cookie categories. Mirrors migration 0001.

    Idempotent: ``ON CONFLICT (slug) DO NOTHING`` makes re-seeding a
    no-op, so a partially initialised database is safe.

    Args:
        conn: an async SQLAlchemy connection inside an open transaction.
    """
    # Local import keeps `text` out of module scope; sqlalchemy itself is
    # already a test dependency.
    from sqlalchemy import text

    # (id, name, slug, is_essential, display_order) — fixed ids so tests
    # can reference categories deterministically.
    rows = [
        ("10000000-0000-0000-0000-000000000001", "Necessary", "necessary", True, 0),
        ("10000000-0000-0000-0000-000000000002", "Functional", "functional", False, 1),
        ("10000000-0000-0000-0000-000000000003", "Analytics", "analytics", False, 2),
        ("10000000-0000-0000-0000-000000000004", "Marketing", "marketing", False, 3),
        ("10000000-0000-0000-0000-000000000005", "Personalisation", "personalisation", False, 4),
    ]
    stmt = text(
        """
        INSERT INTO cookie_categories
            (id, name, slug, description, is_essential, display_order)
        VALUES (:id, :name, :slug, :description, :is_essential, :display_order)
        ON CONFLICT (slug) DO NOTHING
        """,
    )
    for row_id, name, slug, is_essential, order in rows:
        await conn.execute(
            stmt,
            {
                # Module-level `uuid` is already imported; the previous
                # function-local `import uuid as _uuid` was redundant.
                "id": uuid.UUID(row_id),
                "name": name,
                "slug": slug,
                "description": f"{name} cookies",
                "is_essential": is_essential,
                "display_order": order,
            },
        )
@pytest_asyncio.fixture(loop_scope="session")
async def db_client(_test_engine, _setup_db):
    """Async HTTP client where each route handler gets its own DB session.

    Each request gets an independent session/connection so there are no
    'another operation is in progress' errors from asyncpg.
    """
    from src.db import get_db

    app = create_app()

    async def _override_get_db():
        # One fresh session per request; commit on success, roll back on error.
        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    app.dependency_overrides[get_db] = _override_get_db
    transport = ASGITransport(app=app)
    async with AsyncClient(transport=transport, base_url="http://test") as ac:
        yield ac
# ── Auth helper fixtures ─────────────────────────────────────────────
@pytest_asyncio.fixture(loop_scope="session")
async def test_org(_test_engine, _setup_db):
    """Create a test organisation in the database.

    Random slug suffix avoids unique-constraint collisions across tests.
    """
    from src.models.organisation import Organisation

    async with AsyncSession(_test_engine, expire_on_commit=False) as session:
        org = Organisation(
            id=uuid.uuid4(),
            name="Test Organisation",
            slug=f"test-org-{uuid.uuid4().hex[:8]}",
        )
        session.add(org)
        await session.commit()
        # expire_on_commit=False keeps attributes readable after commit.
        return org
@pytest_asyncio.fixture(loop_scope="session")
async def test_user(_test_engine, _setup_db, test_org):
    """Create a test user (owner role) with a known password.

    The plaintext password is "TestPassword123"; tests that exercise the
    login flow can use it directly.
    """
    from src.models.user import User
    from src.services.auth import hash_password

    async with AsyncSession(_test_engine, expire_on_commit=False) as session:
        user = User(
            id=uuid.uuid4(),
            email=f"admin-{uuid.uuid4().hex[:8]}@test.com",
            password_hash=hash_password("TestPassword123"),
            full_name="Test Admin",
            role="owner",
            organisation_id=test_org.id,
        )
        session.add(user)
        await session.commit()
        return user
@pytest_asyncio.fixture(loop_scope="session")
async def auth_token(test_user):
    """Generate a valid JWT access token for the test user."""
    from src.services.auth import create_access_token

    return create_access_token(
        user_id=str(test_user.id),
        organisation_id=str(test_user.organisation_id),
        role=test_user.role,
        email=test_user.email,
    )
@pytest_asyncio.fixture(loop_scope="session")
async def auth_headers(auth_token):
    """HTTP headers with a valid Bearer token for authenticated requests."""
    return {"Authorization": f"Bearer {auth_token}"}
# ── Shared helper for creating sites in integration tests ────────────
async def create_test_site(
    client: AsyncClient,
    headers: dict,
    *,
    domain_prefix: str = "test",
    display_name: str = "Test Site",
) -> str:
    """Create a site via the API and return its ID.

    This is a helper function (not a fixture) so it can be called
    inline within each test, avoiding async fixture event-loop issues.
    The random domain suffix prevents uniqueness collisions.
    """
    payload = {
        "domain": f"{domain_prefix}-{uuid.uuid4().hex[:8]}.com",
        "display_name": display_name,
    }
    response = await client.post("/api/v1/sites/", json=payload, headers=headers)
    assert response.status_code == 201, f"Failed to create test site: {response.text}"
    return response.json()["id"]

179
apps/api/tests/test_auth.py Normal file
View File

@@ -0,0 +1,179 @@
"""Tests for JWT authentication service and dependencies."""
import uuid
from datetime import UTC, datetime, timedelta
import pytest
from jose import JWTError, jwt
from src.config.settings import get_settings
from src.schemas.auth import CurrentUser
from src.services.auth import (
create_access_token,
create_refresh_token,
decode_token,
hash_password,
verify_password,
)
class TestPasswordHashing:
    """Round-trip behaviour of hash_password / verify_password."""

    def test_hash_and_verify(self):
        plain = "s3cureP@ss!"
        digest = hash_password(plain)
        # The stored value must never be the plaintext itself.
        assert digest != plain
        assert verify_password(plain, digest)

    def test_wrong_password_fails(self):
        digest = hash_password("correct")
        assert not verify_password("wrong", digest)

    def test_different_hashes_for_same_password(self):
        first = hash_password("same")
        second = hash_password("same")
        assert first != second  # bcrypt salts differ
class TestJWTTokens:
    """Claims, expiry windows, and tamper/expiry rejection for JWTs."""

    @pytest.fixture
    def user_data(self):
        # Fresh random identities per test to avoid cross-test coupling.
        return {
            "user_id": uuid.uuid4(),
            "organisation_id": uuid.uuid4(),
            "role": "admin",
            "email": "test@example.com",
        }

    def test_create_access_token_decodable(self, user_data):
        token = create_access_token(**user_data)
        payload = decode_token(token)
        # UUIDs are serialised as strings in the claims.
        assert payload["sub"] == str(user_data["user_id"])
        assert payload["org_id"] == str(user_data["organisation_id"])
        assert payload["role"] == "admin"
        assert payload["email"] == "test@example.com"
        assert payload["type"] == "access"

    def test_create_refresh_token_decodable(self, user_data):
        token = create_refresh_token(
            user_id=user_data["user_id"],
            organisation_id=user_data["organisation_id"],
        )
        payload = decode_token(token)
        assert payload["sub"] == str(user_data["user_id"])
        assert payload["type"] == "refresh"

    def test_access_token_expiry(self, user_data):
        token = create_access_token(**user_data)
        payload = decode_token(token)
        settings = get_settings()
        exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
        iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
        delta = exp - iat
        # 5-second tolerance absorbs clock rounding in the encoder.
        assert abs(delta.total_seconds() - settings.jwt_access_token_expire_minutes * 60) < 5

    def test_refresh_token_expiry(self, user_data):
        token = create_refresh_token(
            user_id=user_data["user_id"],
            organisation_id=user_data["organisation_id"],
        )
        payload = decode_token(token)
        settings = get_settings()
        exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
        iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
        delta = exp - iat
        expected = settings.jwt_refresh_token_expire_days * 86400
        assert abs(delta.total_seconds() - expected) < 5

    def test_expired_token_raises(self):
        # Hand-craft a token whose exp is already in the past.
        settings = get_settings()
        payload = {
            "sub": str(uuid.uuid4()),
            "org_id": str(uuid.uuid4()),
            "role": "viewer",
            "exp": datetime.now(UTC) - timedelta(hours=1),
            "iat": datetime.now(UTC) - timedelta(hours=2),
            "type": "access",
        }
        token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
        with pytest.raises(JWTError):
            decode_token(token)

    def test_tampered_token_raises(self, user_data):
        token = create_access_token(**user_data)
        # Tamper with the token — corrupting the signature must break verification.
        tampered = token[:-5] + "XXXXX"
        with pytest.raises(JWTError):
            decode_token(tampered)
class TestCurrentUser:
    """Role helpers exposed by the CurrentUser schema."""

    @staticmethod
    def _user(role: str, email: str) -> CurrentUser:
        """Build a CurrentUser with random ids and the given role/email."""
        return CurrentUser(
            id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
            email=email,
            role=role,
        )

    def test_has_role(self):
        user = self._user("admin", "admin@example.com")
        assert user.has_role("admin", "owner")
        assert not user.has_role("editor", "viewer")

    def test_is_admin(self):
        admin = self._user("admin", "a@b.com")
        viewer = self._user("viewer", "v@b.com")
        assert admin.is_admin
        assert not viewer.is_admin
@pytest.mark.asyncio
class TestAuthEndpoints:
    """HTTP-level checks on /api/v1/auth/me using the no-database client."""

    async def test_me_without_token_returns_401(self, client):
        response = await client.get("/api/v1/auth/me")
        assert response.status_code == 401

    async def test_me_with_valid_token(self, client):
        user_id = uuid.uuid4()
        org_id = uuid.uuid4()
        token = create_access_token(
            user_id=user_id,
            organisation_id=org_id,
            role="editor",
            email="user@example.com",
        )
        response = await client.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(user_id)
        assert data["organisation_id"] == str(org_id)
        assert data["role"] == "editor"
        assert data["email"] == "user@example.com"

    async def test_me_with_refresh_token_rejected(self, client):
        # A refresh token must not grant access to authenticated routes.
        token = create_refresh_token(
            user_id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
        )
        response = await client.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert response.status_code == 401

    async def test_me_with_invalid_token(self, client):
        response = await client.get(
            "/api/v1/auth/me",
            headers={"Authorization": "Bearer invalid.token.here"},
        )
        assert response.status_code == 401
View File

@@ -0,0 +1,124 @@
"""Tests for the initial admin bootstrap service."""
import uuid
from unittest.mock import patch
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import Settings
from src.models.organisation import Organisation
from src.models.user import User
from src.services.auth import verify_password
from src.services.bootstrap import bootstrap_initial_admin
from tests.conftest import requires_db
def _settings(**overrides) -> Settings:
    """Build a Settings object with bootstrap-relevant defaults.

    Any keyword argument overrides the corresponding default below.
    """
    merged: dict = {
        "environment": "test",
        "initial_admin_email": None,
        "initial_admin_password": None,
        "initial_admin_full_name": "Administrator",
        "initial_org_name": "Default Organisation",
        "initial_org_slug": "default",
        **overrides,
    }
    return Settings(**merged)
class TestBootstrapNoOp:
    """Pure unit tests — bootstrap must short-circuit before touching the DB.

    These async tests run via the module-level ``pytestmark`` asyncio marker.
    """

    async def test_noop_when_email_unset(self):
        settings = _settings(initial_admin_password="pw")
        # Patching the session factory lets us prove no session was opened.
        with patch("src.services.bootstrap.async_session_factory") as factory:
            await bootstrap_initial_admin(settings)
        factory.assert_not_called()

    async def test_noop_when_password_unset(self):
        settings = _settings(initial_admin_email="admin@example.com")
        with patch("src.services.bootstrap.async_session_factory") as factory:
            await bootstrap_initial_admin(settings)
        factory.assert_not_called()
@requires_db
class TestBootstrapWithDatabase:
    """Integration tests — exercise the real SQL path."""

    @pytest_asyncio.fixture(loop_scope="session")
    async def clean_db(self, _test_engine, _setup_db):
        """Strip users and orgs so bootstrap sees an empty table."""
        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            await session.execute(User.__table__.delete())
            await session.execute(Organisation.__table__.delete())
            await session.commit()
        yield
        # Clean up again so later tests don't see bootstrap's rows.
        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            await session.execute(User.__table__.delete())
            await session.execute(Organisation.__table__.delete())
            await session.commit()

    async def test_creates_org_and_owner_when_empty(self, _test_engine, clean_db):
        email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
        slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
        settings = _settings(
            initial_admin_email=email,
            initial_admin_password="SuperSecret123",
            initial_org_slug=slug,
            initial_org_name="Bootstrapped Org",
        )

        # Point the bootstrap service at the test engine.
        def _factory():
            return AsyncSession(_test_engine, expire_on_commit=False)

        with patch("src.services.bootstrap.async_session_factory", _factory):
            await bootstrap_initial_admin(settings)

        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            user = (await session.execute(select(User).where(User.email == email))).scalar_one()
            org = (
                await session.execute(select(Organisation).where(Organisation.slug == slug))
            ).scalar_one()
            assert user.role == "owner"
            assert user.organisation_id == org.id
            assert user.full_name == "Administrator"
            # The password must be stored hashed, not plaintext.
            assert verify_password("SuperSecret123", user.password_hash)
            assert org.name == "Bootstrapped Org"
            assert org.contact_email == email

    async def test_idempotent_when_user_exists(self, _test_engine, clean_db):
        """A second invocation must not create a second user."""
        email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
        slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
        settings = _settings(
            initial_admin_email=email,
            initial_admin_password="SuperSecret123",
            initial_org_slug=slug,
        )

        def _factory():
            return AsyncSession(_test_engine, expire_on_commit=False)

        with patch("src.services.bootstrap.async_session_factory", _factory):
            await bootstrap_initial_admin(settings)
            # Second call with different credentials must be a no-op.
            await bootstrap_initial_admin(
                _settings(
                    initial_admin_email="someone-else@example.com",
                    initial_admin_password="Different123",
                    initial_org_slug=slug,
                )
            )

        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            users = (await session.execute(select(User))).scalars().all()
            assert len(users) == 1
            assert users[0].email == email
pytestmark = pytest.mark.asyncio(loop_scope="session")

View File

@@ -0,0 +1,869 @@
"""Tests for known cookies database and auto-categorisation engine — CMP-22.
Covers:
- Classification service logic (unit tests — pure functions)
- Pattern matching (exact, wildcard, regex)
- Priority ordering (allow-list → exact → regex → unmatched)
- Known cookie CRUD endpoints (unit tests with mocked DB)
- Classification endpoints (unit tests with mocked DB)
- Schema validation
- Integration tests against live database
"""
import uuid
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from src.schemas.cookie import (
ClassificationResultResponse,
ClassifySingleRequest,
ClassifySiteResponse,
KnownCookieCreate,
KnownCookieResponse,
KnownCookieUpdate,
)
from src.services.classification import (
ClassificationResult,
MatchSource,
_match_pattern,
_match_regex,
classify_cookie,
)
# ── Schema tests ─────────────────────────────────────────────────────
class TestSchemas:
    """Validate known cookie and classification schemas."""

    def test_known_cookie_create(self):
        kc = KnownCookieCreate(
            name_pattern="_ga",
            domain_pattern="*",
            category_id=uuid.uuid4(),
            vendor="Google",
            description="GA cookie",
        )
        # is_regex defaults to False when not supplied.
        assert kc.is_regex is False

    def test_known_cookie_create_regex(self):
        kc = KnownCookieCreate(
            name_pattern="_hj.*",
            domain_pattern=".*",
            category_id=uuid.uuid4(),
            is_regex=True,
        )
        assert kc.is_regex is True

    def test_known_cookie_update_partial(self):
        # Partial updates must only carry the fields that were set.
        ku = KnownCookieUpdate(vendor="Updated Vendor")
        dumped = ku.model_dump(exclude_unset=True)
        assert "vendor" in dumped
        assert "category_id" not in dumped

    def test_known_cookie_response(self):
        resp = KnownCookieResponse(
            id=uuid.uuid4(),
            name_pattern="_ga",
            domain_pattern="*",
            category_id=uuid.uuid4(),
            is_regex=False,
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )
        # Optional fields default to None.
        assert resp.vendor is None

    def test_classification_result_response(self):
        crr = ClassificationResultResponse(
            cookie_name="_ga",
            cookie_domain=".example.com",
            match_source="known_exact",
            matched=True,
        )
        assert crr.matched is True

    def test_classify_single_request(self):
        req = ClassifySingleRequest(cookie_name="_ga", cookie_domain=".example.com")
        assert req.cookie_name == "_ga"

    def test_classify_single_request_validation(self):
        # An empty cookie name must be rejected by validation.
        with pytest.raises(ValueError):
            ClassifySingleRequest(cookie_name="", cookie_domain=".example.com")

    def test_classify_site_response(self):
        resp = ClassifySiteResponse(
            site_id="abc",
            total=10,
            matched=7,
            unmatched=3,
            results=[],
        )
        assert resp.matched == 7

    def test_match_source_enum(self):
        # Enum values are the literal strings used in API payloads.
        assert MatchSource.ALLOW_LIST == "allow_list"
        assert MatchSource.KNOWN_EXACT == "known_exact"
        assert MatchSource.KNOWN_REGEX == "known_regex"
        assert MatchSource.UNMATCHED == "unmatched"
# ── Pattern matching unit tests ──────────────────────────────────────
class TestPatternMatching:
    """Test the _match_pattern and _match_regex helpers."""

    def test_exact_match(self):
        assert _match_pattern("_ga", "_ga") is True

    def test_exact_match_case_insensitive(self):
        assert _match_pattern("_GA", "_ga") is True
        assert _match_pattern("_ga", "_GA") is True

    def test_exact_no_match(self):
        assert _match_pattern("_ga", "_gid") is False

    def test_wildcard_star(self):
        # Bare "*" matches any value.
        assert _match_pattern("*", "_ga") is True
        assert _match_pattern("*", "anything") is True

    def test_wildcard_prefix(self):
        assert _match_pattern("_ga_*", "_ga_ABC123") is True
        assert _match_pattern("_ga_*", "_ga_") is True
        assert _match_pattern("_ga_*", "_gid") is False

    def test_wildcard_suffix(self):
        assert _match_pattern("*.google.com", ".google.com") is True
        assert _match_pattern("*.google.com", "www.google.com") is True
        assert _match_pattern("*.google.com", ".facebook.com") is False

    def test_wildcard_middle(self):
        assert _match_pattern("_ga*id", "_ga_gid") is True  # * matches _g
        assert _match_pattern("_ga*id", "_gaid") is True
        assert _match_pattern("_ga*id", "_ga") is False  # must end in id

    def test_empty_values(self):
        # Empty pattern or value never matches.
        assert _match_pattern("", "_ga") is False
        assert _match_pattern("_ga", "") is False
        assert _match_pattern("", "") is False

    def test_regex_match(self):
        assert _match_regex(r"_hj.*", "_hjSession_12345") is True
        assert _match_regex(r"_hj.*", "_ga") is False

    def test_regex_case_insensitive(self):
        assert _match_regex(r"_hj.*", "_HJSession") is True

    def test_regex_anchored(self):
        # re.match anchors at start by default
        assert _match_regex(r"_pk_id.*", "_pk_id.abc.123") is True
        assert _match_regex(r"_pk_id.*", "x_pk_id") is False

    def test_regex_invalid_pattern(self):
        # An invalid regex must fail safe (no match) rather than raise.
        assert _match_regex(r"[invalid", "test") is False

    def test_regex_full_domain_match(self):
        assert _match_regex(r".*", ".example.com") is True

    def test_wildcard_dynamic_id_suffix(self):
        """Cookies with dynamic IDs should match wildcard prefix patterns."""
        assert _match_pattern("_hjSessionUser_*", "_hjSessionUser_1150536") is True
        assert _match_pattern("_hjSession_*", "_hjSession_9876543") is True
        assert _match_pattern("ri--*", "ri--zC77O2yRxuIvW5fjRAq0RdzNYaF-x") is True
        assert _match_pattern("intercom-id-*", "intercom-id-abc123def") is True
        assert _match_pattern("amp_*", "amp_ff29a3") is True
        assert _match_pattern("mp_*", "mp_abc123_mixpanel") is True

    def test_wildcard_does_not_overmatch(self):
        """Wildcard patterns should not match unrelated cookies."""
        assert _match_pattern("_hjSessionUser_*", "_hjSession_123") is False
        assert _match_pattern("ri--*", "ri-single-dash") is False
        assert _match_pattern("intercom-id-*", "intercom-session-xyz") is False
# ── Classification engine unit tests ─────────────────────────────────
def _make_category(slug: str, cat_id: uuid.UUID | None = None):
"""Create a mock CookieCategory."""
cat = MagicMock()
cat.id = cat_id or uuid.uuid4()
cat.slug = slug
return cat
def _make_known(
name_pattern: str,
domain_pattern: str,
category_id: uuid.UUID,
vendor: str | None = None,
description: str | None = None,
is_regex: bool = False,
):
"""Create a mock KnownCookie."""
known = MagicMock()
known.name_pattern = name_pattern
known.domain_pattern = domain_pattern
known.category_id = category_id
known.vendor = vendor
known.description = description
known.is_regex = is_regex
return known
def _make_allow_entry(
name_pattern: str,
domain_pattern: str,
category_id: uuid.UUID,
description: str | None = None,
):
"""Create a mock CookieAllowListEntry."""
entry = MagicMock()
entry.name_pattern = name_pattern
entry.domain_pattern = domain_pattern
entry.category_id = category_id
entry.description = description
return entry
class TestClassifyCookie:
    """Test the classify_cookie pure function.

    Call shape used throughout (per the named-priority tests below):
    classify_cookie(name, domain, allow_entries, exact_known, regex_known,
    category_map).
    """

    def setup_method(self):
        # Three mock categories plus the id→category lookup map the
        # classifier uses to resolve slugs.
        self.analytics_cat = _make_category("analytics")
        self.marketing_cat = _make_category("marketing")
        self.necessary_cat = _make_category("necessary")
        self.category_map = {
            self.analytics_cat.id: self.analytics_cat,
            self.marketing_cat.id: self.marketing_cat,
            self.necessary_cat.id: self.necessary_cat,
        }

    def test_exact_known_match(self):
        known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
        result = classify_cookie("_ga", ".example.com", [], [known], [], self.category_map)
        assert result.matched is True
        assert result.match_source == MatchSource.KNOWN_EXACT
        assert result.category_slug == "analytics"
        assert result.vendor == "Google"

    def test_regex_known_match(self):
        known = _make_known(
            r"_hj.*",
            r".*",
            self.analytics_cat.id,
            vendor="Hotjar",
            is_regex=True,
        )
        result = classify_cookie(
            "_hjSession_123",
            ".example.com",
            [],
            [],
            [known],
            self.category_map,
        )
        assert result.matched is True
        assert result.match_source == MatchSource.KNOWN_REGEX
        assert result.vendor == "Hotjar"

    def test_allow_list_match(self):
        entry = _make_allow_entry(
            "_custom_cookie",
            "*",
            self.necessary_cat.id,
            description="Site-specific override",
        )
        result = classify_cookie(
            "_custom_cookie",
            ".example.com",
            [entry],
            [],
            [],
            self.category_map,
        )
        assert result.matched is True
        assert result.match_source == MatchSource.ALLOW_LIST
        assert result.category_slug == "necessary"

    def test_allow_list_takes_priority_over_known(self):
        """Allow-list should override known cookies database."""
        allow_entry = _make_allow_entry(
            "_ga",
            "*",
            self.necessary_cat.id,
            description="Overridden to necessary",
        )
        known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
        result = classify_cookie(
            "_ga",
            ".example.com",
            [allow_entry],
            [known],
            [],
            self.category_map,
        )
        assert result.match_source == MatchSource.ALLOW_LIST
        assert result.category_slug == "necessary"

    def test_exact_takes_priority_over_regex(self):
        """Exact match should be preferred over regex match."""
        exact = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
        regex = _make_known(
            r"_g.*",
            r".*",
            self.marketing_cat.id,
            vendor="Other",
            is_regex=True,
        )
        result = classify_cookie(
            "_ga",
            ".example.com",
            [],
            [exact],
            [regex],
            self.category_map,
        )
        assert result.match_source == MatchSource.KNOWN_EXACT
        assert result.category_slug == "analytics"

    def test_unmatched(self):
        result = classify_cookie(
            "obscure_cookie",
            ".unknown.com",
            [],
            [],
            [],
            self.category_map,
        )
        assert result.matched is False
        assert result.match_source == MatchSource.UNMATCHED
        assert result.category_id is None

    def test_domain_must_match(self):
        """Cookie should not match if domain pattern doesn't match."""
        known = _make_known("_ga", "*.google.com", self.analytics_cat.id)
        result = classify_cookie(
            "_ga",
            ".example.com",
            [],
            [known],
            [],
            self.category_map,
        )
        assert result.matched is False

    def test_name_must_match(self):
        """Cookie should not match if name pattern doesn't match."""
        known = _make_known("_gid", "*", self.analytics_cat.id)
        result = classify_cookie(
            "_ga",
            ".example.com",
            [],
            [known],
            [],
            self.category_map,
        )
        assert result.matched is False

    def test_wildcard_domain_match(self):
        known = _make_known(
            "fr",
            "*.facebook.com",
            self.marketing_cat.id,
            vendor="Meta",
        )
        result = classify_cookie(
            "fr",
            ".facebook.com",
            [],
            [known],
            [],
            self.category_map,
        )
        assert result.matched is True
        assert result.vendor == "Meta"

    def test_classification_result_fields(self):
        # A bare ClassificationResult defaults to "unmatched".
        result = ClassificationResult(
            cookie_name="_ga",
            cookie_domain=".example.com",
        )
        assert result.category_id is None
        assert result.match_source == MatchSource.UNMATCHED
        assert result.matched is False
# ── Router unit tests (mocked service) ──────────────────────────────
def _mock_db():
"""Create a mock async DB session."""
db = AsyncMock()
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = MagicMock()
mock_result.scalars.return_value.all.return_value = []
db.execute.return_value = mock_result
return db
async def _client(app, db):
    """Create an async test client with mocked DB and auth.

    The caller is responsible for closing the returned client
    (``async with await _client(...) as client``). Auth dependencies are
    overridden with a mock "owner" user so routes behave as authenticated.
    """
    from src.db import get_db
    from src.services.dependencies import get_current_user, require_role

    user = MagicMock()
    user.organisation_id = uuid.uuid4()
    user.role = "owner"

    async def _override_get_db():
        yield db

    app.dependency_overrides[get_db] = _override_get_db
    app.dependency_overrides[get_current_user] = lambda: user

    def _override_require_role(*_roles):
        # Grant every role check by returning the mock user.
        return lambda: user

    app.dependency_overrides[require_role] = _override_require_role
    transport = ASGITransport(app=app)
    return AsyncClient(transport=transport, base_url="http://test")
class TestKnownCookieRoutes:
    """Test known cookie CRUD endpoints."""

    @pytest.mark.asyncio
    async def test_list_known_cookies(self, app):
        db = _mock_db()
        async with await _client(app, db) as client:
            resp = await client.get("/api/v1/cookies/known")
            assert resp.status_code == 200
            assert isinstance(resp.json(), list)

    @pytest.mark.asyncio
    async def test_create_known_cookie(self, app):
        db = _mock_db()
        # Mock category validation
        cat_result = MagicMock()
        cat_result.scalar_one_or_none.return_value = MagicMock()
        # Mock the created known cookie
        known_mock = MagicMock()
        known_mock.id = uuid.uuid4()
        known_mock.name_pattern = "_ga"
        known_mock.domain_pattern = "*"
        known_mock.category_id = uuid.uuid4()
        known_mock.vendor = "Google"
        known_mock.description = "GA cookie"
        known_mock.is_regex = False
        known_mock.created_at = datetime.now()
        known_mock.updated_at = datetime.now()
        call_count = 0

        # The route executes more than one statement; only the first (the
        # category existence check) needs a specific result.
        async def mock_execute(stmt):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # Category validation
                return cat_result
            return MagicMock()

        db.execute = mock_execute
        db.flush = AsyncMock()
        db.refresh = AsyncMock(side_effect=lambda obj: None)
        db.add = MagicMock()
        with patch(
            "src.routers.cookies.KnownCookie",
            return_value=known_mock,
        ):
            async with await _client(app, db) as client:
                resp = await client.post(
                    "/api/v1/cookies/known",
                    json={
                        "name_pattern": "_ga",
                        "domain_pattern": "*",
                        "category_id": str(uuid.uuid4()),
                        "vendor": "Google",
                    },
                )
                assert resp.status_code == 201

    @pytest.mark.asyncio
    async def test_get_known_cookie_not_found(self, app):
        # A lookup that resolves to None must surface as HTTP 404.
        db = AsyncMock()
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        db.execute.return_value = mock_result
        async with await _client(app, db) as client:
            resp = await client.get(f"/api/v1/cookies/known/{uuid.uuid4()}")
            assert resp.status_code == 404
class TestClassificationRoutes:
    """Test classification endpoint responses."""

    @pytest.mark.asyncio
    async def test_classify_preview(self, app):
        db = _mock_db()
        # Canned service result — the route only serialises it.
        mock_result = ClassificationResult(
            cookie_name="_ga",
            cookie_domain=".example.com",
            category_id=uuid.uuid4(),
            category_slug="analytics",
            vendor="Google",
            match_source=MatchSource.KNOWN_EXACT,
            matched=True,
        )
        with patch(
            "src.routers.cookies.classify_single_cookie",
            return_value=mock_result,
        ):
            async with await _client(app, db) as client:
                resp = await client.post(
                    f"/api/v1/cookies/sites/{uuid.uuid4()}/classify/preview",
                    json={
                        "cookie_name": "_ga",
                        "cookie_domain": ".example.com",
                    },
                )
                assert resp.status_code == 200
                data = resp.json()
                assert data["matched"] is True
                assert data["match_source"] == "known_exact"
# ── Integration tests ────────────────────────────────────────────────
# The fallback import covers invocation styles where the tests directory
# itself is on sys.path, so "tests." is not a valid package prefix.
try:
    from tests.conftest import create_test_site, requires_db
except ImportError:
    from conftest import create_test_site, requires_db
@requires_db
class TestClassificationIntegration:
"""Integration tests against a live database."""
async def _get_category_id(self, client: AsyncClient, headers: dict, slug: str) -> str:
"""Get a category ID by slug."""
resp = await client.get("/api/v1/cookies/categories", headers=headers)
assert resp.status_code == 200
for cat in resp.json():
if cat["slug"] == slug:
return cat["id"]
pytest.fail(f"Category '{slug}' not found")
async def _create_known_cookie(
self,
client: AsyncClient,
headers: dict,
name_pattern: str,
domain_pattern: str,
category_slug: str,
*,
vendor: str | None = None,
is_regex: bool = False,
) -> str:
"""Create a known cookie and return its ID."""
cat_id = await self._get_category_id(client, headers, category_slug)
resp = await client.post(
"/api/v1/cookies/known",
headers=headers,
json={
"name_pattern": name_pattern,
"domain_pattern": domain_pattern,
"category_id": cat_id,
"vendor": vendor,
"is_regex": is_regex,
},
)
assert resp.status_code == 201, resp.text
return resp.json()["id"]
async def _create_cookie(
self,
client: AsyncClient,
headers: dict,
site_id: str,
name: str,
domain: str,
) -> str:
"""Create a pending cookie on a site and return its ID."""
resp = await client.post(
f"/api/v1/cookies/sites/{site_id}",
headers=headers,
json={"name": name, "domain": domain},
)
assert resp.status_code == 201, resp.text
return resp.json()["id"]
async def test_known_cookies_crud(self, db_client, auth_headers):
"""Test full CRUD lifecycle for known cookies."""
cat_id = await self._get_category_id(db_client, auth_headers, "analytics")
# Create
resp = await db_client.post(
"/api/v1/cookies/known",
headers=auth_headers,
json={
"name_pattern": f"_test_{uuid.uuid4().hex[:6]}",
"domain_pattern": "*",
"category_id": cat_id,
"vendor": "TestVendor",
"description": "Test cookie",
},
)
assert resp.status_code == 201
known_id = resp.json()["id"]
# Read
resp = await db_client.get(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
)
assert resp.status_code == 200
assert resp.json()["vendor"] == "TestVendor"
# Update
resp = await db_client.patch(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
json={"vendor": "UpdatedVendor"},
)
assert resp.status_code == 200
assert resp.json()["vendor"] == "UpdatedVendor"
# List (with search)
resp = await db_client.get(
"/api/v1/cookies/known",
headers=auth_headers,
params={"vendor": "UpdatedVendor"},
)
assert resp.status_code == 200
assert any(k["id"] == known_id for k in resp.json())
# Delete
resp = await db_client.delete(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
)
assert resp.status_code == 204
# Verify deleted
resp = await db_client.get(
f"/api/v1/cookies/known/{known_id}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_classify_exact_match(self, db_client, auth_headers):
"""Test classification with exact known cookie match."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-exact")
# Create a known cookie pattern
pattern_name = f"_test_exact_{uuid.uuid4().hex[:6]}"
await self._create_known_cookie(
db_client,
auth_headers,
pattern_name,
"*",
"analytics",
vendor="TestVendor",
)
# Create a pending cookie on the site
await self._create_cookie(
db_client,
auth_headers,
site_id,
pattern_name,
".example.com",
)
# Classify
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["total"] >= 1
assert data["matched"] >= 1
matched = [r for r in data["results"] if r["matched"]]
assert any(r["cookie_name"] == pattern_name for r in matched)
async def test_classify_regex_match(self, db_client, auth_headers):
"""Test classification with regex known cookie match."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-regex")
prefix = f"_rx_{uuid.uuid4().hex[:4]}"
# Create regex pattern
await self._create_known_cookie(
db_client,
auth_headers,
f"{prefix}.*",
".*",
"analytics",
vendor="RegexVendor",
is_regex=True,
)
# Create a cookie that should match the regex
await self._create_cookie(
db_client,
auth_headers,
site_id,
f"{prefix}_session_123",
".example.com",
)
# Classify
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["matched"] >= 1
matched = [r for r in data["results"] if r["matched"]]
assert any(r["match_source"] == "known_regex" for r in matched)
async def test_classify_unmatched(self, db_client, auth_headers):
"""Cookies without known patterns should remain unmatched."""
site_id = await create_test_site(
db_client, auth_headers, domain_prefix="classify-unmatched"
)
unique_name = f"_unknown_{uuid.uuid4().hex[:8]}"
await self._create_cookie(
db_client,
auth_headers,
site_id,
unique_name,
".obscure-domain.com",
)
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["unmatched"] >= 1
async def test_classify_preview(self, db_client, auth_headers):
"""Test preview classification without saving."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-preview")
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify/preview",
headers=auth_headers,
json={
"cookie_name": "_unknown_cookie",
"cookie_domain": ".test.com",
},
)
assert resp.status_code == 200
data = resp.json()
assert data["matched"] is False
assert data["match_source"] == "unmatched"
async def test_classify_allow_list_priority(self, db_client, auth_headers):
"""Allow-list entries should take priority over known cookies."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-allow")
cookie_name = f"_priority_{uuid.uuid4().hex[:6]}"
# Add to known cookies as marketing
await self._create_known_cookie(
db_client,
auth_headers,
cookie_name,
"*",
"marketing",
)
# Add to allow-list as necessary (should take priority)
necessary_id = await self._get_category_id(db_client, auth_headers, "necessary")
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/allow-list",
headers=auth_headers,
json={
"name_pattern": cookie_name,
"domain_pattern": "*",
"category_id": necessary_id,
},
)
assert resp.status_code == 201
# Create cookie and classify
await self._create_cookie(
db_client,
auth_headers,
site_id,
cookie_name,
".example.com",
)
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
matched = [r for r in data["results"] if r["cookie_name"] == cookie_name]
assert len(matched) == 1
assert matched[0]["match_source"] == "allow_list"
assert matched[0]["category_id"] == necessary_id
async def test_known_cookies_not_found(self, db_client, auth_headers):
resp = await db_client.get(
f"/api/v1/cookies/known/{uuid.uuid4()}",
headers=auth_headers,
)
assert resp.status_code == 404
async def test_known_cookies_invalid_category(self, db_client, auth_headers):
resp = await db_client.post(
"/api/v1/cookies/known",
headers=auth_headers,
json={
"name_pattern": "_test",
"domain_pattern": "*",
"category_id": str(uuid.uuid4()),
},
)
assert resp.status_code == 400
async def test_known_cookies_auth_required(self, db_client):
"""Known cookie endpoints require authentication."""
resp = await db_client.get("/api/v1/cookies/known")
assert resp.status_code == 401
async def test_classify_empty_site(self, db_client, auth_headers):
"""Classifying a site with no cookies should return empty results."""
site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-empty")
resp = await db_client.post(
f"/api/v1/cookies/sites/{site_id}/classify",
headers=auth_headers,
)
assert resp.status_code == 200
data = resp.json()
assert data["total"] == 0
assert data["matched"] == 0
async def test_list_known_cookies_search(self, db_client, auth_headers):
"""Test searching known cookies by name pattern."""
unique = uuid.uuid4().hex[:6]
await self._create_known_cookie(
db_client,
auth_headers,
f"_search_{unique}",
"*",
"analytics",
)
resp = await db_client.get(
"/api/v1/cookies/known",
headers=auth_headers,
params={"search": f"_search_{unique}"},
)
assert resp.status_code == 200
results = resp.json()
assert len(results) >= 1
assert all(f"_search_{unique}" in r["name_pattern"] for r in results)

View File

@@ -0,0 +1,597 @@
"""Tests for the compliance rule engine and router."""
import uuid
import pytest
from httpx import ASGITransport, AsyncClient
from src.schemas.compliance import (
ComplianceCheckResponse,
ComplianceIssue,
Framework,
FrameworkResult,
Severity,
)
from src.services.compliance import (
CCPA_RULES,
CNIL_RULES,
EPRIVACY_RULES,
FRAMEWORK_RULES,
GDPR_RULES,
LGPD_RULES,
SiteContext,
calculate_overall_score,
run_compliance_check,
run_framework_check,
)
# ── SiteContext defaults ──────────────────────────────────────────────
class TestSiteContext:
    """SiteContext should ship sensible defaults and accept overrides."""

    def test_default_values(self):
        defaults = SiteContext()
        assert defaults.blocking_mode == "opt_in"
        assert defaults.tcf_enabled is False
        assert defaults.gcm_enabled is True
        assert defaults.consent_expiry_days == 365
        assert defaults.has_reject_button is True
        assert defaults.has_granular_choices is True
        assert defaults.has_cookie_wall is False
        assert defaults.pre_ticked_boxes is False

    def test_custom_values(self):
        customised = SiteContext(
            blocking_mode="opt_out",
            consent_expiry_days=180,
            privacy_policy_url="https://example.com/privacy",
        )
        assert customised.blocking_mode == "opt_out"
        assert customised.consent_expiry_days == 180
        assert customised.privacy_policy_url == "https://example.com/privacy"
# ── GDPR rules ────────────────────────────────────────────────────────
class TestGDPRRules:
    """Rule-by-rule coverage of the GDPR framework checks."""

    # A valid policy URL so only the rule under test can fire.
    POLICY = "https://example.com/privacy"

    @staticmethod
    def _rule_ids(result):
        """Collect the ids of every triggered rule in a check result."""
        return {issue.rule_id for issue in result.issues}

    def test_compliant_site(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(
                blocking_mode="opt_in",
                has_reject_button=True,
                has_granular_choices=True,
                has_cookie_wall=False,
                pre_ticked_boxes=False,
                privacy_policy_url=self.POLICY,
                uncategorised_cookies=0,
            ),
        )
        assert result.score == 100
        assert result.status == "compliant"
        assert not result.issues
        assert result.rules_passed == result.rules_checked

    def test_opt_out_mode_fails(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(blocking_mode="opt_out", privacy_policy_url=self.POLICY),
        )
        assert "gdpr_opt_in" in self._rule_ids(result)
        assert result.status == "non_compliant"

    def test_informational_mode_fails(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(blocking_mode="informational", privacy_policy_url=self.POLICY),
        )
        assert "gdpr_opt_in" in self._rule_ids(result)

    def test_no_reject_button_fails(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(has_reject_button=False, privacy_policy_url=self.POLICY),
        )
        assert "gdpr_reject_button" in self._rule_ids(result)

    def test_no_granular_consent_fails(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(has_granular_choices=False, privacy_policy_url=self.POLICY),
        )
        assert "gdpr_granular" in self._rule_ids(result)

    def test_cookie_wall_fails(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(has_cookie_wall=True, privacy_policy_url=self.POLICY),
        )
        assert "gdpr_cookie_wall" in self._rule_ids(result)

    def test_pre_ticked_fails(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(pre_ticked_boxes=True, privacy_policy_url=self.POLICY),
        )
        assert "gdpr_pre_ticked" in self._rule_ids(result)

    def test_no_privacy_policy_warns(self):
        result = run_framework_check(Framework.GDPR, SiteContext(privacy_policy_url=None))
        policy_issues = [
            issue for issue in result.issues if issue.rule_id == "gdpr_privacy_policy"
        ]
        assert len(policy_issues) == 1
        assert policy_issues[0].severity == Severity.WARNING

    def test_uncategorised_cookies_warns(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(uncategorised_cookies=5, privacy_policy_url=self.POLICY),
        )
        uncategorised = [
            issue for issue in result.issues if issue.rule_id == "gdpr_uncategorised"
        ]
        assert len(uncategorised) == 1
        assert "5" in uncategorised[0].message

    def test_multiple_failures_accumulate(self):
        result = run_framework_check(
            Framework.GDPR,
            SiteContext(
                blocking_mode="opt_out",
                has_reject_button=False,
                has_granular_choices=False,
                has_cookie_wall=True,
                pre_ticked_boxes=True,
                privacy_policy_url=None,
                uncategorised_cookies=3,
            ),
        )
        assert result.score == 0  # score is floored at zero
        assert result.status == "non_compliant"
        assert len(result.issues) >= 5
# ── CNIL rules ────────────────────────────────────────────────────────
class TestCNILRules:
    """CNIL adds French-specific rules on top of the GDPR baseline."""

    POLICY = "https://example.com/privacy"

    @staticmethod
    def _triggered(ctx):
        """Run the CNIL check and return the set of triggered rule ids."""
        return {issue.rule_id for issue in run_framework_check(Framework.CNIL, ctx).issues}

    def test_compliant_site(self):
        result = run_framework_check(
            Framework.CNIL,
            SiteContext(
                blocking_mode="opt_in",
                has_reject_button=True,
                has_granular_choices=True,
                privacy_policy_url=self.POLICY,
                consent_expiry_days=180,
            ),
        )
        assert result.score == 100
        assert result.status == "compliant"

    def test_consent_expiry_too_long(self):
        ctx = SiteContext(consent_expiry_days=365, privacy_policy_url=self.POLICY)
        assert "cnil_reconsent" in self._triggered(ctx)

    def test_consent_expiry_at_limit(self):
        ctx = SiteContext(consent_expiry_days=182, privacy_policy_url=self.POLICY)
        assert "cnil_reconsent" not in self._triggered(ctx)

    def test_cookie_lifetime_too_long(self):
        ctx = SiteContext(consent_expiry_days=400, privacy_policy_url=self.POLICY)
        assert "cnil_cookie_lifetime" in self._triggered(ctx)

    def test_cookie_lifetime_at_limit(self):
        ctx = SiteContext(consent_expiry_days=395, privacy_policy_url=self.POLICY)
        assert "cnil_cookie_lifetime" not in self._triggered(ctx)

    def test_inherits_gdpr_rules(self):
        """CNIL should check all GDPR rules plus CNIL-specific ones."""
        assert len(CNIL_RULES) > len(GDPR_RULES)

    def test_reject_first_layer(self):
        ctx = SiteContext(
            has_reject_button=False,
            consent_expiry_days=180,
            privacy_policy_url=self.POLICY,
        )
        assert "cnil_reject_first_layer" in self._triggered(ctx)
# ── CCPA rules ────────────────────────────────────────────────────────
class TestCCPARules:
    """CCPA checks: opt-out mechanism, Do-Not-Sell link, privacy policy."""

    def test_compliant_site(self):
        """opt_out mode + Do-Not-Sell link + policy URL scores a clean 100."""
        ctx = SiteContext(
            blocking_mode="opt_out",
            privacy_policy_url="https://example.com/privacy",
            banner_config={"show_do_not_sell_link": True},
        )
        result = run_framework_check(Framework.CCPA, ctx)
        assert result.score == 100
        assert result.status == "compliant"

    def test_opt_in_also_acceptable(self):
        """opt_in is stricter than CCPA requires, so ccpa_opt_out must not fire."""
        ctx = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url="https://example.com/privacy",
            banner_config={"show_do_not_sell_link": True},
        )
        result = run_framework_check(Framework.CCPA, ctx)
        assert not any(i.rule_id == "ccpa_opt_out" for i in result.issues)

    def test_informational_mode_fails_ccpa(self):
        """Informational mode offers no opt-out mechanism, so ccpa_opt_out fires.

        Renamed from test_informational_mode_passes_ccpa: the old name and
        docstring claimed the check passes, but the assertion (correctly)
        expects it to fail — "informational" is neither "opt_out" nor "opt_in".
        """
        ctx = SiteContext(
            blocking_mode="informational",
            privacy_policy_url="https://example.com/privacy",
            banner_config={"show_do_not_sell_link": True},
        )
        result = run_framework_check(Framework.CCPA, ctx)
        assert any(i.rule_id == "ccpa_opt_out" for i in result.issues)

    def test_no_do_not_sell_link(self):
        """A banner config without the Do-Not-Sell link triggers ccpa_do_not_sell."""
        ctx = SiteContext(
            blocking_mode="opt_out",
            privacy_policy_url="https://example.com/privacy",
            banner_config={},
        )
        result = run_framework_check(Framework.CCPA, ctx)
        assert any(i.rule_id == "ccpa_do_not_sell" for i in result.issues)

    def test_no_banner_config_fails_dns(self):
        """A missing banner config also counts as no Do-Not-Sell link."""
        ctx = SiteContext(
            blocking_mode="opt_out",
            privacy_policy_url="https://example.com/privacy",
            banner_config=None,
        )
        result = run_framework_check(Framework.CCPA, ctx)
        assert any(i.rule_id == "ccpa_do_not_sell" for i in result.issues)

    def test_no_privacy_policy_warns(self):
        """A missing privacy policy URL triggers ccpa_privacy_policy."""
        ctx = SiteContext(
            blocking_mode="opt_out",
            privacy_policy_url=None,
            banner_config={"show_do_not_sell_link": True},
        )
        result = run_framework_check(Framework.CCPA, ctx)
        assert any(i.rule_id == "ccpa_privacy_policy" for i in result.issues)
# ── ePrivacy rules ────────────────────────────────────────────────────
class TestEPrivacyRules:
    """ePrivacy: any consent-based mode passes; purely informational banners fail."""

    def test_compliant_site(self):
        result = run_framework_check(Framework.EPRIVACY, SiteContext(blocking_mode="opt_in"))
        assert result.score == 100
        assert result.status == "compliant"

    def test_opt_out_passes(self):
        result = run_framework_check(Framework.EPRIVACY, SiteContext(blocking_mode="opt_out"))
        assert all(issue.rule_id != "eprivacy_consent" for issue in result.issues)

    def test_informational_fails(self):
        result = run_framework_check(
            Framework.EPRIVACY, SiteContext(blocking_mode="informational")
        )
        assert any(issue.rule_id == "eprivacy_consent" for issue in result.issues)
# ── LGPD rules ────────────────────────────────────────────────────────
class TestLGPDRules:
    """LGPD checks: consent basis, controller information, granularity."""

    POLICY = "https://example.com/privacy"

    @staticmethod
    def _triggered(ctx):
        """Run the LGPD check and return the set of triggered rule ids."""
        return {issue.rule_id for issue in run_framework_check(Framework.LGPD, ctx).issues}

    def test_compliant_site(self):
        result = run_framework_check(
            Framework.LGPD,
            SiteContext(
                blocking_mode="opt_in",
                privacy_policy_url=self.POLICY,
                has_granular_choices=True,
            ),
        )
        assert result.score == 100
        assert result.status == "compliant"

    def test_informational_fails(self):
        ctx = SiteContext(blocking_mode="informational", privacy_policy_url=self.POLICY)
        assert "lgpd_consent_basis" in self._triggered(ctx)

    def test_no_privacy_policy_warns(self):
        assert "lgpd_data_controller" in self._triggered(SiteContext(privacy_policy_url=None))

    def test_no_granular_warns(self):
        ctx = SiteContext(has_granular_choices=False, privacy_policy_url=self.POLICY)
        assert "lgpd_granular" in self._triggered(ctx)

    def test_opt_out_passes(self):
        ctx = SiteContext(blocking_mode="opt_out", privacy_policy_url=self.POLICY)
        assert "lgpd_consent_basis" not in self._triggered(ctx)
# ── Engine orchestration ──────────────────────────────────────────────
class TestComplianceEngine:
    """Orchestration: which frameworks run, and in what order."""

    ALL_FRAMEWORKS = {
        Framework.GDPR,
        Framework.CNIL,
        Framework.CCPA,
        Framework.EPRIVACY,
        Framework.LGPD,
    }

    def test_run_all_frameworks(self):
        ctx = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url="https://example.com/privacy",
            has_reject_button=True,
            has_granular_choices=True,
            consent_expiry_days=180,
        )
        results = run_compliance_check(ctx)
        assert len(results) == 5
        assert {r.framework for r in results} == self.ALL_FRAMEWORKS

    def test_run_specific_frameworks(self):
        results = run_compliance_check(SiteContext(), [Framework.GDPR, Framework.CCPA])
        # Requested order is preserved in the results.
        assert [r.framework for r in results] == [Framework.GDPR, Framework.CCPA]

    def test_run_single_framework(self):
        results = run_compliance_check(SiteContext(), [Framework.EPRIVACY])
        assert [r.framework for r in results] == [Framework.EPRIVACY]

    def test_empty_frameworks_list_runs_all(self):
        # Passing None for the framework selection means "run every framework".
        ctx = SiteContext(
            privacy_policy_url="https://example.com/privacy",
            consent_expiry_days=180,
        )
        assert len(run_compliance_check(ctx, None)) == 5
class TestScoring:
    """Score arithmetic: deductions, floors, averaging, and status mapping."""

    @staticmethod
    def _result(framework, score, status, checked, passed):
        """Build a FrameworkResult with the given counters."""
        return FrameworkResult(
            framework=framework,
            score=score,
            status=status,
            rules_checked=checked,
            rules_passed=passed,
        )

    def test_perfect_score(self):
        perfect = self._result(Framework.GDPR, 100, "compliant", 7, 7)
        assert calculate_overall_score([perfect]) == 100

    def test_zero_score(self):
        failing = self._result(Framework.GDPR, 0, "non_compliant", 7, 0)
        assert calculate_overall_score([failing]) == 0

    def test_average_across_frameworks(self):
        mixed = [
            self._result(Framework.GDPR, 100, "compliant", 7, 7),
            self._result(Framework.CCPA, 50, "partial", 3, 1),
        ]
        assert calculate_overall_score(mixed) == 75

    def test_empty_results(self):
        # No frameworks checked counts as fully compliant.
        assert calculate_overall_score([]) == 100

    def test_critical_issues_deduct_20(self):
        # opt_out triggers one critical issue (gdpr_opt_in): 100 - 20.
        ctx = SiteContext(
            blocking_mode="opt_out",
            privacy_policy_url="https://example.com/privacy",
        )
        assert run_framework_check(Framework.GDPR, ctx).score == 80

    def test_warning_issues_deduct_5(self):
        # A missing privacy policy is a warning: 100 - 5.
        ctx = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url=None,
            uncategorised_cookies=0,
        )
        assert run_framework_check(Framework.GDPR, ctx).score == 95

    def test_score_floors_at_zero(self):
        ctx = SiteContext(
            blocking_mode="opt_out",
            has_reject_button=False,
            has_granular_choices=False,
            has_cookie_wall=True,
            pre_ticked_boxes=True,
            privacy_policy_url=None,
            uncategorised_cookies=10,
        )
        assert run_framework_check(Framework.GDPR, ctx).score == 0

    def test_status_non_compliant_with_critical(self):
        ctx = SiteContext(blocking_mode="opt_out")
        assert run_framework_check(Framework.GDPR, ctx).status == "non_compliant"

    def test_status_partial_with_warnings_only(self):
        ctx = SiteContext(blocking_mode="opt_in", privacy_policy_url=None)
        assert run_framework_check(Framework.GDPR, ctx).status == "partial"

    def test_status_compliant_with_no_issues(self):
        ctx = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url="https://example.com/privacy",
        )
        assert run_framework_check(Framework.GDPR, ctx).status == "compliant"
# ── Framework registry ────────────────────────────────────────────────
class TestFrameworkRegistry:
    """Sanity checks on the FRAMEWORK_RULES registry contents."""

    def test_all_frameworks_registered(self):
        for framework in (
            Framework.GDPR,
            Framework.CNIL,
            Framework.CCPA,
            Framework.EPRIVACY,
            Framework.LGPD,
        ):
            assert framework in FRAMEWORK_RULES

    def test_each_framework_has_rules(self):
        for framework, rules in FRAMEWORK_RULES.items():
            assert len(rules) > 0, f"{framework} has no rules"

    def test_rule_ids_are_unique_per_framework(self):
        for framework, rules in FRAMEWORK_RULES.items():
            seen = [rule.rule_id for rule in rules]
            assert len(seen) == len(set(seen)), f"Duplicate rule IDs in {framework}"

    def test_gdpr_rule_count(self):
        assert len(GDPR_RULES) == 7

    def test_cnil_includes_gdpr_rules(self):
        cnil_ids = {rule.rule_id for rule in CNIL_RULES}
        assert {rule.rule_id for rule in GDPR_RULES} <= cnil_ids

    def test_ccpa_rule_count(self):
        assert len(CCPA_RULES) == 3

    def test_eprivacy_rule_count(self):
        assert len(EPRIVACY_RULES) == 2

    def test_lgpd_rule_count(self):
        assert len(LGPD_RULES) == 3
# ── Router tests ──────────────────────────────────────────────────────
class TestComplianceRouter:
    """HTTP-level tests for the compliance endpoints."""

    @pytest.fixture
    def app(self):
        from src.main import create_app

        return create_app()

    @pytest.fixture
    async def client(self, app):
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as http:
            yield http

    async def test_list_frameworks(self, client):
        response = await client.get("/api/v1/compliance/frameworks")
        assert response.status_code == 200
        frameworks = response.json()
        assert len(frameworks) == 5
        assert {fw["id"] for fw in frameworks} == {
            "gdpr",
            "cnil",
            "ccpa",
            "eprivacy",
            "lgpd",
        }

    async def test_check_requires_auth(self, client):
        response = await client.post(f"/api/v1/compliance/check/{uuid.uuid4()}")
        assert response.status_code == 401
# ── Schema tests ──────────────────────────────────────────────────────
class TestSchemas:
    """Construction and value checks for the compliance Pydantic schemas."""

    def test_compliance_issue_schema(self):
        issue = ComplianceIssue(
            rule_id="test_rule",
            severity=Severity.CRITICAL,
            message="Test message",
            recommendation="Test recommendation",
        )
        assert issue.rule_id == "test_rule"
        assert issue.severity == Severity.CRITICAL

    def test_framework_result_schema(self):
        framework_result = FrameworkResult(
            framework=Framework.GDPR,
            score=85,
            status="partial",
            rules_checked=7,
            rules_passed=5,
        )
        assert framework_result.framework == Framework.GDPR
        assert framework_result.score == 85

    def test_compliance_check_response_schema(self):
        check = ComplianceCheckResponse(site_id="test-id", results=[], overall_score=100)
        assert check.overall_score == 100

    def test_severity_values(self):
        assert [Severity.CRITICAL, Severity.WARNING, Severity.INFO] == [
            "critical",
            "warning",
            "info",
        ]

    def test_framework_values(self):
        assert [
            Framework.GDPR,
            Framework.CNIL,
            Framework.CCPA,
            Framework.EPRIVACY,
            Framework.LGPD,
        ] == ["gdpr", "cnil", "ccpa", "eprivacy", "lgpd"]

View File

@@ -0,0 +1,258 @@
"""Tests for configuration hierarchy resolver and publisher."""
import uuid
import pytest
from src.services.config_resolver import (
SYSTEM_DEFAULTS,
build_public_config,
resolve_config,
)
class TestResolveConfig:
    """Config cascade: system defaults < org < group < site, plus regional modes."""

    def test_returns_system_defaults_for_empty_config(self):
        resolved = resolve_config({})
        assert resolved["blocking_mode"] == "opt_in"
        assert resolved["consent_expiry_days"] == 365
        assert resolved["gcm_enabled"] is True
        assert resolved["tcf_enabled"] is False
        assert resolved["gpp_enabled"] is True
        assert resolved["gpp_supported_apis"] == ["usnat"]
        assert resolved["gpc_enabled"] is True
        assert resolved["gpc_jurisdictions"] == ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"]
        assert resolved["gpc_global_honour"] is False

    def test_site_config_overrides_defaults(self):
        resolved = resolve_config(
            {"blocking_mode": "opt_out", "consent_expiry_days": 180, "tcf_enabled": True}
        )
        assert resolved["blocking_mode"] == "opt_out"
        assert resolved["consent_expiry_days"] == 180
        assert resolved["tcf_enabled"] is True
        assert resolved["gcm_enabled"] is True  # untouched keys keep their defaults

    def test_org_defaults_override_system_defaults(self):
        resolved = resolve_config({}, org_defaults={"consent_expiry_days": 90})
        assert resolved["consent_expiry_days"] == 90

    def test_site_config_overrides_org_defaults(self):
        resolved = resolve_config(
            {"consent_expiry_days": 30}, org_defaults={"consent_expiry_days": 90}
        )
        assert resolved["consent_expiry_days"] == 30

    def test_none_values_in_site_config_do_not_override(self):
        resolved = resolve_config({"blocking_mode": None, "consent_expiry_days": 180})
        # An explicit None must not clobber the system default.
        assert resolved["blocking_mode"] == "opt_in"
        assert resolved["consent_expiry_days"] == 180

    def test_regional_override_applied(self):
        overrides = {
            "blocking_mode": "opt_in",
            "regional_modes": {"US-CA": "opt_out", "EU": "opt_in"},
        }
        assert resolve_config(overrides, region="US-CA")["blocking_mode"] == "opt_out"

    def test_regional_override_falls_back_to_default(self):
        overrides = {
            "blocking_mode": "opt_in",
            "regional_modes": {"EU": "opt_in", "DEFAULT": "opt_out"},
        }
        assert resolve_config(overrides, region="BR")["blocking_mode"] == "opt_out"

    def test_regional_override_no_match_keeps_site_config(self):
        overrides = {"blocking_mode": "opt_in", "regional_modes": {"EU": "opt_out"}}
        assert resolve_config(overrides, region="JP")["blocking_mode"] == "opt_in"

    def test_no_region_ignores_regional_modes(self):
        overrides = {"blocking_mode": "opt_in", "regional_modes": {"US-CA": "opt_out"}}
        assert resolve_config(overrides)["blocking_mode"] == "opt_in"

    def test_gpp_site_config_overrides_defaults(self):
        resolved = resolve_config(
            {"gpp_enabled": False, "gpp_supported_apis": ["usca", "usva"]}
        )
        assert resolved["gpp_enabled"] is False
        assert resolved["gpp_supported_apis"] == ["usca", "usva"]

    def test_gpc_site_config_overrides_defaults(self):
        resolved = resolve_config(
            {
                "gpc_enabled": False,
                "gpc_global_honour": True,
                "gpc_jurisdictions": ["US-CA"],
            }
        )
        assert resolved["gpc_enabled"] is False
        assert resolved["gpc_global_honour"] is True
        assert resolved["gpc_jurisdictions"] == ["US-CA"]

    def test_gpp_gpc_org_defaults_override_system(self):
        resolved = resolve_config(
            {}, org_defaults={"gpp_enabled": False, "gpc_global_honour": True}
        )
        assert resolved["gpp_enabled"] is False
        assert resolved["gpc_global_honour"] is True
        assert resolved["gpc_enabled"] is True  # untouched GPC field keeps system default

    def test_gpp_gpc_site_overrides_org(self):
        resolved = resolve_config(
            {"gpp_supported_apis": ["usnat", "usco"]},
            org_defaults={"gpp_supported_apis": ["usca"]},
        )
        assert resolved["gpp_supported_apis"] == ["usnat", "usco"]

    def test_group_defaults_override_org_defaults(self):
        resolved = resolve_config(
            {},
            org_defaults={"consent_expiry_days": 90, "tcf_enabled": True},
            group_defaults={"consent_expiry_days": 60},
        )
        assert resolved["consent_expiry_days"] == 60  # group beats org
        assert resolved["tcf_enabled"] is True  # untouched key falls through from org

    def test_site_config_overrides_group_defaults(self):
        resolved = resolve_config(
            {"consent_expiry_days": 30}, group_defaults={"consent_expiry_days": 60}
        )
        assert resolved["consent_expiry_days"] == 30

    def test_none_in_group_defaults_does_not_override(self):
        resolved = resolve_config(
            {},
            org_defaults={"consent_expiry_days": 90},
            group_defaults={"consent_expiry_days": None},
        )
        assert resolved["consent_expiry_days"] == 90  # org value survives the None

    def test_full_hierarchy(self):
        resolved = resolve_config(
            {
                "consent_expiry_days": 60,
                "banner_config": {"primaryColour": "#ff0000"},
                "regional_modes": {"EU": "opt_in", "US": "opt_out"},
            },
            org_defaults={"consent_expiry_days": 90, "tcf_enabled": True},
            region="US",
        )
        assert resolved["consent_expiry_days"] == 60  # site beats org
        assert resolved["tcf_enabled"] is True  # from org defaults
        assert resolved["blocking_mode"] == "opt_out"  # regional override wins
        assert resolved["banner_config"] == {"primaryColour": "#ff0000"}

    def test_full_hierarchy_with_group(self):
        resolved = resolve_config(
            {
                "banner_config": {"primaryColour": "#ff0000"},
                "regional_modes": {"US": "opt_out"},
            },
            org_defaults={
                "consent_expiry_days": 90,
                "tcf_enabled": True,
                "blocking_mode": "opt_in",
            },
            group_defaults={
                "consent_expiry_days": 60,
                "privacy_policy_url": "https://group.example.com/privacy",
            },
            region="US",
        )
        assert resolved["consent_expiry_days"] == 60  # from group
        assert resolved["tcf_enabled"] is True  # from org
        assert resolved["blocking_mode"] == "opt_out"  # regional override
        assert resolved["privacy_policy_url"] == "https://group.example.com/privacy"
        assert resolved["banner_config"] == {"primaryColour": "#ff0000"}  # from site
class TestBuildPublicConfig:
    """build_public_config should expose SDK fields and strip internal ones."""

    def test_includes_required_fields(self):
        site_id = str(uuid.uuid4())
        public = build_public_config(site_id, {**SYSTEM_DEFAULTS, "id": "config-123"})
        assert public["site_id"] == site_id
        assert public["id"] == "config-123"
        assert public["blocking_mode"] == "opt_in"
        assert public["consent_expiry_days"] == 365
        for key in ("gcm_enabled", "tcf_enabled", "banner_config"):
            assert key in public
        assert public["gpp_enabled"] is True
        assert public["gpp_supported_apis"] == ["usnat"]
        assert public["gpc_enabled"] is True
        assert public["gpc_jurisdictions"] == ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"]
        assert public["gpc_global_honour"] is False

    def test_strips_unknown_internal_fields(self):
        site_id = str(uuid.uuid4())
        internal = {
            **SYSTEM_DEFAULTS,
            "id": "",
            "internal_field": "should_not_appear",
            "scan_enabled": True,
        }
        public = build_public_config(site_id, internal)
        assert "internal_field" not in public
        assert "scan_enabled" not in public
class TestConfigRoutes:
    """Route registration and auth checks for the config endpoints."""

    def test_resolved_config_route_registered(self, app):
        assert "/api/v1/config/sites/{site_id}/resolved" in {r.path for r in app.routes}

    def test_publish_route_registered(self, app):
        assert "/api/v1/config/sites/{site_id}/publish" in {r.path for r in app.routes}

    def test_inheritance_route_registered(self, app):
        assert "/api/v1/config/sites/{site_id}/inheritance" in {r.path for r in app.routes}

    @pytest.mark.asyncio
    async def test_publish_requires_auth(self, client):
        response = await client.post(f"/api/v1/config/sites/{uuid.uuid4()}/publish")
        assert response.status_code == 401

View File

@@ -0,0 +1,130 @@
"""Tests for consent recording API schemas and routes."""
import uuid
import pytest
from pydantic import ValidationError
from src.schemas.consent import (
ConsentAction,
ConsentRecordCreate,
ConsentRecordResponse,
ConsentVerifyResponse,
)
class TestConsentSchemas:
    """Validation behaviour of the consent record request/response schemas."""

    def test_create_accept_all(self):
        payload = ConsentRecordCreate(
            site_id=uuid.uuid4(),
            visitor_id="visitor-abc-123",
            action=ConsentAction.ACCEPT_ALL,
            categories_accepted=["necessary", "analytics", "marketing"],
        )
        assert payload.action == "accept_all"
        assert len(payload.categories_accepted) == 3
        assert payload.categories_rejected is None

    def test_create_custom(self):
        payload = ConsentRecordCreate(
            site_id=uuid.uuid4(),
            visitor_id="visitor-xyz",
            action=ConsentAction.CUSTOM,
            categories_accepted=["necessary", "functional"],
            categories_rejected=["analytics", "marketing"],
            tc_string="COwQHgAAAAA",
            gcm_state={"analytics_storage": "denied", "ad_storage": "denied"},
            page_url="https://example.com/page",
            country_code="GB",
            region_code="GB-ENG",
        )
        assert payload.action == "custom"
        assert payload.tc_string == "COwQHgAAAAA"
        assert payload.gcm_state["analytics_storage"] == "denied"

    def test_create_reject_all(self):
        payload = ConsentRecordCreate(
            site_id=uuid.uuid4(),
            visitor_id="v-1",
            action=ConsentAction.REJECT_ALL,
            categories_accepted=["necessary"],
            categories_rejected=["analytics", "marketing", "functional"],
        )
        assert payload.action == "reject_all"

    def test_empty_visitor_id_rejected(self):
        with pytest.raises(ValidationError):
            ConsentRecordCreate(
                site_id=uuid.uuid4(),
                visitor_id="",
                action=ConsentAction.ACCEPT_ALL,
                categories_accepted=["necessary"],
            )

    def test_invalid_action_rejected(self):
        with pytest.raises(ValidationError):
            ConsentRecordCreate(
                site_id=uuid.uuid4(),
                visitor_id="v-1",
                action="invalid_action",
                categories_accepted=[],
            )

    def test_response_from_attributes(self):
        record = ConsentRecordResponse(
            id=uuid.uuid4(),
            site_id=uuid.uuid4(),
            visitor_id="v-1",
            action="accept_all",
            categories_accepted=["necessary"],
            categories_rejected=None,
            tc_string=None,
            gcm_state=None,
            page_url=None,
            country_code=None,
            region_code=None,
            consented_at="2026-01-01T00:00:00Z",
        )
        assert record.action == "accept_all"

    def test_verify_response(self):
        verified = ConsentVerifyResponse(
            id=uuid.uuid4(),
            site_id=uuid.uuid4(),
            visitor_id="v-1",
            action="accept_all",
            categories_accepted=["necessary"],
            consented_at="2026-01-01T00:00:00Z",
        )
        assert verified.valid is True
class TestConsentActions:
    """ConsentAction enum members compare equal to their wire values."""

    def test_action_values(self):
        expected = {
            ConsentAction.ACCEPT_ALL: "accept_all",
            ConsentAction.REJECT_ALL: "reject_all",
            ConsentAction.CUSTOM: "custom",
            ConsentAction.WITHDRAW: "withdraw",
        }
        for member, wire_value in expected.items():
            assert member == wire_value
@pytest.mark.asyncio
class TestConsentRoutesRegistered:
    """Consent and public-config routes appear in the OpenAPI document."""

    async def test_consent_routes_exist(self, client):
        spec = await client.get("/openapi.json")
        registered = spec.json()["paths"]
        for path in (
            "/api/v1/consent/",
            "/api/v1/consent/{consent_id}",
            "/api/v1/consent/verify/{consent_id}",
        ):
            assert path in registered

    async def test_consent_post_validates_body(self, client):
        """POST /consent rejects invalid payloads."""
        response = await client.post("/api/v1/consent/", json={"invalid": "body"})
        assert response.status_code == 422

    async def test_config_public_endpoint_exists(self, client):
        spec = await client.get("/openapi.json")
        assert "/api/v1/config/sites/{site_id}" in spec.json()["paths"]

View File

@@ -0,0 +1,218 @@
"""Tests for cookie category, cookie, and allow-list schemas and routes."""
import uuid
import pytest
from pydantic import ValidationError
from src.schemas.cookie import (
AllowListEntryCreate,
AllowListEntryUpdate,
CookieCategoryResponse,
CookieCreate,
CookieResponse,
CookieUpdate,
ReviewStatus,
StorageType,
)
# ─── Schema tests ────────────────────────────────────────────────────
class TestStorageType:
    """StorageType members serialise to their snake_case wire names."""

    def test_values(self):
        expected = {
            "cookie": StorageType.cookie,
            "local_storage": StorageType.local_storage,
            "session_storage": StorageType.session_storage,
            "indexed_db": StorageType.indexed_db,
        }
        for wire_value, member in expected.items():
            assert member == wire_value
class TestReviewStatus:
    """ReviewStatus members serialise to their snake_case wire names."""

    def test_values(self):
        expected = {
            "pending": ReviewStatus.pending,
            "approved": ReviewStatus.approved,
            "rejected": ReviewStatus.rejected,
        }
        for wire_value, member in expected.items():
            assert member == wire_value
class TestCookieCreate:
    """Validation behaviour of the CookieCreate input schema."""

    def test_valid_minimal(self):
        """Name + domain suffice; storage type defaults to `cookie`."""
        schema = CookieCreate(domain=".example.com", name="_ga")
        assert schema.name == "_ga"
        assert schema.domain == ".example.com"
        assert schema.storage_type == StorageType.cookie
        assert schema.category_id is None

    def test_valid_full(self):
        """All optional attributes round-trip through validation."""
        cat_id = uuid.uuid4()
        attrs = {
            "name": "_ga",
            "domain": ".google.com",
            "storage_type": StorageType.cookie,
            "category_id": cat_id,
            "description": "Google Analytics cookie",
            "vendor": "Google",
            "path": "/",
            "max_age_seconds": 63072000,
            "is_http_only": False,
            "is_secure": True,
            "same_site": "Lax",
        }
        schema = CookieCreate(**attrs)
        assert schema.category_id == cat_id
        assert schema.max_age_seconds == 63072000

    def test_rejects_empty_name(self):
        kwargs = {"name": "", "domain": ".example.com"}
        with pytest.raises(ValidationError):
            CookieCreate(**kwargs)

    def test_rejects_empty_domain(self):
        kwargs = {"name": "_ga", "domain": ""}
        with pytest.raises(ValidationError):
            CookieCreate(**kwargs)
class TestCookieUpdate:
    """Partial-update semantics of the CookieUpdate schema."""

    def test_partial_update(self):
        """Only explicitly-set fields appear in an exclude_unset dump."""
        schema = CookieUpdate(review_status=ReviewStatus.approved)
        dumped = schema.model_dump(exclude_unset=True)
        assert dumped == {"review_status": ReviewStatus.approved}

    def test_update_category(self):
        new_category = uuid.uuid4()
        assert CookieUpdate(category_id=new_category).category_id == new_category
class TestAllowListEntryCreate:
    """Validation behaviour of the allow-list creation schema."""

    def test_valid(self):
        cat_id = uuid.uuid4()
        entry = AllowListEntryCreate(
            description="Google Analytics cookies",
            category_id=cat_id,
            domain_pattern=".google.com",
            name_pattern="_ga*",
        )
        assert entry.name_pattern == "_ga*"
        assert entry.category_id == cat_id

    def test_rejects_empty_name_pattern(self):
        kwargs = {
            "name_pattern": "",
            "domain_pattern": ".example.com",
            "category_id": uuid.uuid4(),
        }
        with pytest.raises(ValidationError):
            AllowListEntryCreate(**kwargs)
class TestAllowListEntryUpdate:
    """Partial-update semantics of the allow-list update schema."""

    def test_partial_update(self):
        """Only the explicitly-set description appears in the dump."""
        dumped = AllowListEntryUpdate(
            description="Updated description"
        ).model_dump(exclude_unset=True)
        assert dumped == {"description": "Updated description"}
class TestCookieCategoryResponse:
    """The category response schema accepts plain-dict construction."""

    def test_from_dict(self):
        timestamp = "2024-01-01T00:00:00"
        payload = {
            "id": uuid.uuid4(),
            "name": "Analytics",
            "slug": "analytics",
            "description": "Analytics cookies",
            "is_essential": False,
            "display_order": 2,
            "tcf_purpose_ids": [1, 3],
            "gcm_consent_types": ["analytics_storage"],
            "created_at": timestamp,
            "updated_at": timestamp,
        }
        resp = CookieCategoryResponse(**payload)
        assert resp.slug == "analytics"
        assert resp.is_essential is False
class TestCookieResponse:
    """The cookie response schema accepts plain-dict construction."""

    def test_from_dict(self):
        timestamp = "2024-01-01T00:00:00"
        payload = {
            "id": uuid.uuid4(),
            "site_id": uuid.uuid4(),
            "name": "_ga",
            "domain": ".google.com",
            "storage_type": "cookie",
            "review_status": "pending",
            "created_at": timestamp,
            "updated_at": timestamp,
        }
        resp = CookieResponse(**payload)
        assert resp.name == "_ga"
        assert resp.review_status == "pending"
# ─── Route tests ─────────────────────────────────────────────────────
class TestCookieCategoryRoutes:
    def test_categories_route_registered(self, app):
        """Verify the categories routes are registered in the app."""
        registered = {route.path for route in app.routes}
        expected = {
            "/api/v1/cookies/categories",
            "/api/v1/cookies/categories/{category_id}",
        }
        assert expected <= registered
class TestCookieRoutes:
    """Cookie CRUD endpoints must be auth-gated and validate input."""

    @pytest.mark.asyncio
    async def test_list_cookies_requires_auth(self, client):
        """Anonymous GET of a site's cookies is rejected."""
        resp = await client.get(f"/api/v1/cookies/sites/{uuid.uuid4()}")
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_create_cookie_requires_auth(self, client):
        """Anonymous POST of a new cookie is rejected."""
        body = {"name": "_ga", "domain": ".google.com"}
        resp = await client.post(f"/api/v1/cookies/sites/{uuid.uuid4()}", json=body)
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_create_cookie_rejects_invalid_body(self, client):
        """An invalid payload with a bad token never creates a cookie."""
        resp = await client.post(
            f"/api/v1/cookies/sites/{uuid.uuid4()}",
            json={"name": "", "domain": ""},
            headers={"Authorization": "Bearer fake-token"},
        )
        # Should return 401 (bad token) or 422 (validation)
        assert resp.status_code in (401, 422)

    @pytest.mark.asyncio
    async def test_summary_route_requires_auth(self, client):
        """Anonymous GET of the cookie summary is rejected."""
        resp = await client.get(f"/api/v1/cookies/sites/{uuid.uuid4()}/summary")
        assert resp.status_code == 401
class TestAllowListRoutes:
    """Allow-list endpoints must all require authentication."""

    @pytest.mark.asyncio
    async def test_list_allow_list_requires_auth(self, client):
        url = f"/api/v1/cookies/sites/{uuid.uuid4()}/allow-list"
        resp = await client.get(url)
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_create_allow_list_requires_auth(self, client):
        body = {
            "name_pattern": "_ga*",
            "domain_pattern": ".google.com",
            "category_id": str(uuid.uuid4()),
        }
        url = f"/api/v1/cookies/sites/{uuid.uuid4()}/allow-list"
        resp = await client.post(url, json=body)
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_delete_allow_list_requires_auth(self, client):
        url = f"/api/v1/cookies/sites/{uuid.uuid4()}/allow-list/{uuid.uuid4()}"
        resp = await client.delete(url)
        assert resp.status_code == 401

222
apps/api/tests/test_cors.py Normal file
View File

@@ -0,0 +1,222 @@
"""Tests for the dynamic CORS origin validation service."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from src.services.cors import extract_domain_from_origin, get_allowed_domains, is_origin_allowed
class TestExtractDomainFromOrigin:
    """Host extraction from Origin header values."""

    def test_https_origin(self):
        assert extract_domain_from_origin("https://example.com") == "example.com"

    def test_http_origin(self):
        assert extract_domain_from_origin("http://example.com") == "example.com"

    def test_origin_with_port(self):
        # The port is dropped from the extracted host.
        assert extract_domain_from_origin("https://example.com:443") == "example.com"

    def test_origin_with_subdomain(self):
        # Subdomains are preserved, not collapsed to the apex domain.
        assert extract_domain_from_origin("https://www.example.com") == "www.example.com"

    def test_localhost(self):
        assert extract_domain_from_origin("http://localhost:5173") == "localhost"

    def test_empty_string(self):
        assert extract_domain_from_origin("") is None

    def test_invalid_url(self):
        # urlparse("not-a-url") has no netloc, so hostname is None.
        assert extract_domain_from_origin("not-a-url") is None
class TestIsOriginAllowed:
    """Origin checks against the static allow-list and registered site domains."""

    def test_static_origin_exact_match(self):
        allowed = is_origin_allowed(
            "http://localhost:5173", ["http://localhost:5173"], set()
        )
        assert allowed is True

    def test_static_origin_no_match(self):
        allowed = is_origin_allowed(
            "https://evil.com", ["http://localhost:5173"], set()
        )
        assert allowed is False

    def test_wildcard_allows_everything(self):
        allowed = is_origin_allowed("https://anything.com", ["*"], set())
        assert allowed is True

    def test_registered_domain_match(self):
        allowed = is_origin_allowed(
            "https://example.com", [], {"example.com", "other.com"}
        )
        assert allowed is True

    def test_registered_domain_case_insensitive(self):
        allowed = is_origin_allowed("https://Example.COM", [], {"example.com"})
        assert allowed is True

    def test_registered_domain_no_match(self):
        allowed = is_origin_allowed("https://evil.com", [], {"example.com"})
        assert allowed is False

    def test_static_takes_priority(self):
        allowed = is_origin_allowed(
            "http://localhost:5173", ["http://localhost:5173"], {"example.com"}
        )
        assert allowed is True

    def test_origin_with_port_matches_domain(self):
        allowed = is_origin_allowed("https://example.com:8443", [], {"example.com"})
        assert allowed is True

    def test_subdomain_matches_if_registered(self):
        # www.example.com only matches if explicitly registered
        allowed = is_origin_allowed("https://www.example.com", [], {"example.com"})
        assert allowed is False

    def test_subdomain_matches_when_registered(self):
        allowed = is_origin_allowed("https://www.example.com", [], {"www.example.com"})
        assert allowed is True

    def test_empty_origin(self):
        allowed = is_origin_allowed("", [], {"example.com"})
        assert allowed is False

    def test_empty_lists(self):
        allowed = is_origin_allowed("https://example.com", [], set())
        assert allowed is False
class TestGetAllowedDomains:
    """Collection of allowed CORS domains from registered sites."""

    @staticmethod
    def _site_row(domain, additional=None):
        # Fake DB row exposing the attributes get_allowed_domains reads.
        row = MagicMock()
        row.domain = domain
        row.additional_domains = additional
        return row

    @staticmethod
    def _db_returning(rows):
        # Async session double whose execute() result yields the given rows.
        result = MagicMock()
        result.all.return_value = rows
        session = AsyncMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @pytest.mark.asyncio
    async def test_returns_primary_domains(self):
        db = self._db_returning(
            [self._site_row("example.com"), self._site_row("other.com")]
        )
        domains = await get_allowed_domains(db)
        assert "example.com" in domains
        assert "other.com" in domains

    @pytest.mark.asyncio
    async def test_includes_additional_domains(self):
        db = self._db_returning(
            [self._site_row("example.com", ["www.example.com", "app.example.com"])]
        )
        domains = await get_allowed_domains(db)
        assert "example.com" in domains
        assert "www.example.com" in domains
        assert "app.example.com" in domains

    @pytest.mark.asyncio
    async def test_lowercases_domains(self):
        db = self._db_returning([self._site_row("Example.COM", ["WWW.Example.COM"])])
        domains = await get_allowed_domains(db)
        assert "example.com" in domains
        assert "www.example.com" in domains

    @pytest.mark.asyncio
    async def test_empty_result(self):
        db = self._db_returning([])
        assert await get_allowed_domains(db) == set()

View File

@@ -0,0 +1,88 @@
"""Unit tests for auth dependencies."""
import uuid
from src.schemas.auth import CurrentUser
from src.services.auth import create_access_token, create_refresh_token, decode_token
class TestCurrentUser:
    """Role helpers on the CurrentUser auth schema."""

    @staticmethod
    def _user(role):
        # Build a CurrentUser with throwaway ids and the given role.
        return CurrentUser(
            id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
            email="test@test.com",
            role=role,
        )

    def test_has_role_matching(self):
        assert self._user("admin").has_role("admin", "owner") is True

    def test_has_role_not_matching(self):
        assert self._user("viewer").has_role("admin", "owner") is False

    def test_is_admin_property(self):
        assert self._user("admin").is_admin is True

    def test_is_admin_owner(self):
        assert self._user("owner").is_admin is True

    def test_is_admin_viewer(self):
        assert self._user("viewer").is_admin is False
class TestTokenCreation:
    """JWT creation/decoding round-trips for access and refresh tokens."""

    def test_access_token_roundtrip(self):
        user_id, org_id = uuid.uuid4(), uuid.uuid4()
        token = create_access_token(
            user_id=user_id,
            organisation_id=org_id,
            role="editor",
            email="test@test.com",
        )
        claims = decode_token(token)
        assert claims["sub"] == str(user_id)
        assert claims["org_id"] == str(org_id)
        assert claims["role"] == "editor"
        assert claims["type"] == "access"

    def test_refresh_token_roundtrip(self):
        user_id = uuid.uuid4()
        token = create_refresh_token(user_id=user_id, organisation_id=uuid.uuid4())
        claims = decode_token(token)
        assert claims["sub"] == str(user_id)
        assert claims["type"] == "refresh"

    def test_access_token_is_not_refresh(self):
        token = create_access_token(
            user_id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
            role="viewer",
            email="test@test.com",
        )
        assert decode_token(token)["type"] != "refresh"

View File

@@ -0,0 +1,141 @@
"""Tests for the extension registry and edition detection."""
import pytest
from fastapi import APIRouter, FastAPI
from src.config.edition import edition_name, is_ee
from src.extensions.registry import (
ExtensionRegistry,
OpenAPITag,
discover_extensions,
get_registry,
)
# -- Edition detection -------------------------------------------------------
class TestEditionDetection:
    """``is_ee()`` and ``edition_name()`` must agree with each other no
    matter which edition is installed; core tests never assume a specific
    edition — each repo's own integration tests pin that down."""

    def test_edition_name_matches_is_ee(self):
        expected = "ee" if is_ee() else "ce"
        assert edition_name() == expected

    def test_edition_name_is_valid(self):
        assert edition_name() in ("ce", "ee")
# -- Extension registry (unit) ----------------------------------------------
class TestExtensionRegistry:
    """Unit tests for ExtensionRegistry collection and application."""

    def _make_registry(self) -> ExtensionRegistry:
        # Fresh registry per test; the global singleton is tested separately.
        return ExtensionRegistry()

    def test_empty_registry(self):
        reg = self._make_registry()
        assert reg.routers == []
        assert reg.model_modules == []
        assert reg.startup_hooks == []

    def test_add_router(self):
        reg = self._make_registry()
        router = APIRouter()
        reg.add_router(router, prefix="/api/v1")
        assert len(reg.routers) == 1
        assert reg.routers[0].router is router
        assert reg.routers[0].prefix == "/api/v1"

    def test_add_router_with_tags(self):
        reg = self._make_registry()
        router = APIRouter()
        tag = OpenAPITag(name="billing", description="Billing endpoints")
        reg.add_router(router, tags=[tag])
        assert reg.routers[0].tags == [tag]

    def test_add_model_module(self):
        reg = self._make_registry()
        reg.add_model_module("ee.api.src.models.billing")
        assert reg.model_modules == ["ee.api.src.models.billing"]

    def test_add_startup_hook(self):
        reg = self._make_registry()
        async def hook(app: FastAPI) -> None:
            pass
        reg.add_startup_hook(hook)
        assert len(reg.startup_hooks) == 1

    def test_apply_mounts_routers(self):
        reg = self._make_registry()
        router = APIRouter()
        # Fix: the handler returns {"ok": True} (a bool value), so the
        # return annotation must be dict[str, bool] — it was dict[str, str],
        # which also misleads FastAPI's response-model derivation.
        @router.get("/test")
        async def _test() -> dict[str, bool]:
            return {"ok": True}
        reg.add_router(router, prefix="/ext")
        app = FastAPI()
        reg.apply(app)
        # The router should be included in the app routes
        paths = [r.path for r in app.routes]
        assert "/ext/test" in paths

    def test_apply_adds_openapi_tags(self):
        reg = self._make_registry()
        router = APIRouter()
        tag = OpenAPITag(name="billing", description="Billing endpoints")
        reg.add_router(router, tags=[tag])
        app = FastAPI()
        app.openapi_tags = []
        reg.apply(app)
        assert any(t["name"] == "billing" for t in app.openapi_tags)

    def test_apply_skips_duplicate_tags(self):
        # A pre-existing tag with the same name must win over the extension's.
        reg = self._make_registry()
        router = APIRouter()
        tag = OpenAPITag(name="billing", description="Billing endpoints")
        reg.add_router(router, tags=[tag])
        app = FastAPI()
        app.openapi_tags = [{"name": "billing", "description": "Existing"}]
        reg.apply(app)
        billing_tags = [t for t in app.openapi_tags if t["name"] == "billing"]
        assert len(billing_tags) == 1
        assert billing_tags[0]["description"] == "Existing"
# -- discover_extensions -----------------------------------------------------
class TestDiscoverExtensions:
    """Extension discovery must be safe to call in any edition."""

    def test_discover_extensions_does_not_raise(self):
        # Completing without an exception is the whole contract here.
        discover_extensions()
# -- Global registry ---------------------------------------------------------
class TestGlobalRegistry:
    """The process-wide registry accessor must behave as a singleton."""

    def test_get_registry_returns_singleton(self):
        first, second = get_registry(), get_registry()
        assert first is second
# -- Health endpoint with edition field --------------------------------------
@pytest.mark.asyncio
async def test_health_reports_edition(client):
    """/health must report which edition (ce/ee) is running."""
    response = await client.get("/health")
    assert response.status_code == 200
    body = response.json()
    assert body["edition"] in ("ce", "ee")

View File

@@ -0,0 +1,573 @@
"""Tests for the GeoIP service.
Covers header-based detection, IP lookup, country-to-region mapping,
client IP extraction, and the combined detect_region flow.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
import src.services.geoip as geoip_module
from src.services.geoip import (
GeoResult,
_is_private_ip,
country_to_region,
detect_region,
detect_region_from_headers,
get_client_ip,
lookup_ip_maxmind,
lookup_ip_region,
)
# ── country_to_region ────────────────────────────────────────────────
class TestCountryToRegion:
    """Mapping of ISO country (and US state) codes to consent regions."""

    def test_eu_country_returns_eu(self):
        for code in ("DE", "FR", "IT", "ES"):
            assert country_to_region(code) == "EU"

    def test_eu_country_case_insensitive(self):
        for code in ("de", "fr"):
            assert country_to_region(code) == "EU"

    def test_gb_returns_gb(self):
        assert country_to_region("GB") == "GB"

    def test_br_returns_br(self):
        assert country_to_region("BR") == "BR"

    def test_us_without_state(self):
        assert country_to_region("US") == "US"

    def test_us_with_state(self):
        # US states yield compound "US-<STATE>" regions, state upper-cased.
        assert country_to_region("US", "CA") == "US-CA"
        assert country_to_region("US", "ny") == "US-NY"

    def test_non_eu_country_returned_as_is(self):
        for code in ("JP", "AU", "CA"):
            assert country_to_region(code) == code
# ── detect_region_from_headers ───────────────────────────────────────
class TestDetectRegionFromHeaders:
    """Header-based geo detection: CDN-supplied and operator-configured
    country headers are consulted before any IP lookup happens."""
    def _make_request(self, headers: dict[str, str]) -> MagicMock:
        # Minimal request double: only .headers is read by the code under test.
        request = MagicMock()
        request.headers = headers
        return request
    # ── Built-in CDN headers ─────────────────────────────────────────
    def test_cloudflare_header(self):
        request = self._make_request({"cf-ipcountry": "DE"})
        result = detect_region_from_headers(request)
        assert result.country_code == "DE"
        assert result.region == "EU"
        assert result.is_resolved is True
    def test_vercel_header(self):
        request = self._make_request({"x-vercel-ip-country": "GB"})
        result = detect_region_from_headers(request)
        assert result.country_code == "GB"
        assert result.region == "GB"
    def test_appengine_header(self):
        request = self._make_request({"x-appengine-country": "BR"})
        result = detect_region_from_headers(request)
        assert result.country_code == "BR"
        assert result.region == "BR"
    def test_custom_header(self):
        request = self._make_request({"x-country-code": "JP"})
        result = detect_region_from_headers(request)
        assert result.country_code == "JP"
        assert result.region == "JP"
    def test_no_geo_headers(self):
        # No recognised header at all → fully unresolved result.
        request = self._make_request({})
        result = detect_region_from_headers(request)
        assert result.country_code is None
        assert result.region is None
        assert result.is_resolved is False
    def test_ignores_xx_value(self):
        # "XX" is the conventional unknown-country sentinel; treated as absent.
        request = self._make_request({"cf-ipcountry": "XX"})
        result = detect_region_from_headers(request)
        assert result.is_resolved is False
    def test_header_priority_cloudflare_first(self):
        # With several built-in headers present, Cloudflare's is used first.
        request = self._make_request(
            {
                "cf-ipcountry": "FR",
                "x-vercel-ip-country": "DE",
            }
        )
        result = detect_region_from_headers(request)
        assert result.country_code == "FR"
    def test_case_normalisation(self):
        # Lower-case header values are upper-cased in the result.
        request = self._make_request({"cf-ipcountry": "gb"})
        result = detect_region_from_headers(request)
        assert result.country_code == "GB"
        assert result.region == "GB"
    # ── Operator-configured headers (settings-driven) ────────────────
    # NOTE(review): these tests only set the settings attributes they care
    # about; unset attributes on the get_settings() MagicMock are auto-created
    # mocks, which the code under test apparently tolerates — confirm.
    def test_configured_custom_header(self):
        """An operator-configured header is honoured."""
        request = self._make_request({"x-gclb-country": "JP"})
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_country_header = "x-gclb-country"
            result = detect_region_from_headers(request)
            assert result.country_code == "JP"
            assert result.region == "JP"
    def test_configured_custom_header_takes_priority(self):
        """When both a custom and a built-in header are present, the
        custom one wins — that's the operator's explicit choice."""
        request = self._make_request(
            {
                "cf-ipcountry": "FR",
                "x-gclb-country": "JP",
            }
        )
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_country_header = "x-gclb-country"
            result = detect_region_from_headers(request)
            assert result.country_code == "JP"
    def test_configured_header_falls_through_to_builtin(self):
        """If the custom header isn't present, the built-in list still
        applies."""
        request = self._make_request({"cf-ipcountry": "FR"})
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_country_header = "x-gclb-country"
            mock_settings.return_value.geoip_region_header = None
            result = detect_region_from_headers(request)
            assert result.country_code == "FR"
            assert result.region == "EU"
    def test_configured_region_header_pairs_with_country(self):
        """A configured region header is paired with the custom country."""
        request = self._make_request(
            {
                "x-gclb-country": "US",
                "x-gclb-region": "CA",
            }
        )
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_country_header = "x-gclb-country"
            mock_settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(request)
            assert result.country_code == "US"
            assert result.region == "US-CA"
    def test_configured_region_header_strips_country_prefix(self):
        """ISO 3166-2 subdivisions may arrive prefixed (``US-CA``)."""
        request = self._make_request(
            {
                "x-gclb-country": "US",
                "x-gclb-region": "US-NY",
            }
        )
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_country_header = "x-gclb-country"
            mock_settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(request)
            # Prefix already present → not doubled into "US-US-NY".
            assert result.region == "US-NY"
    def test_configured_region_header_missing_is_country_only(self):
        """Only country hits region-aware path if the region header is absent."""
        request = self._make_request({"x-gclb-country": "US"})
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_country_header = "x-gclb-country"
            mock_settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(request)
            assert result.country_code == "US"
            assert result.region == "US"
    def test_configured_region_header_xx_ignored(self):
        """Region value of ``XX`` is treated as unknown."""
        request = self._make_request(
            {
                "x-gclb-country": "US",
                "x-gclb-region": "XX",
            }
        )
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_country_header = "x-gclb-country"
            mock_settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(request)
            # Falls back to the bare country region.
            assert result.region == "US"
# ── get_client_ip ────────────────────────────────────────────────────
class TestGetClientIp:
    """Client-IP extraction from proxy headers with a socket fallback."""

    @staticmethod
    def _request(headers=None, client_host=None):
        # Fake request exposing .headers plus optional .client.host.
        req = MagicMock()
        req.headers = headers or {}
        if client_host is None:
            req.client = None
        else:
            req.client = MagicMock()
            req.client.host = client_host
        return req

    def test_x_forwarded_for_single(self):
        req = self._request({"x-forwarded-for": "1.2.3.4"})
        assert get_client_ip(req) == "1.2.3.4"

    def test_x_forwarded_for_multiple(self):
        # The first (left-most) entry is the original client.
        req = self._request({"x-forwarded-for": "1.2.3.4, 5.6.7.8, 9.10.11.12"})
        assert get_client_ip(req) == "1.2.3.4"

    def test_x_real_ip(self):
        req = self._request({"x-real-ip": "1.2.3.4"})
        assert get_client_ip(req) == "1.2.3.4"

    def test_forwarded_for_takes_priority_over_real_ip(self):
        req = self._request({"x-forwarded-for": "1.1.1.1", "x-real-ip": "2.2.2.2"})
        assert get_client_ip(req) == "1.1.1.1"

    def test_falls_back_to_client_host(self):
        req = self._request(client_host="10.0.0.1")
        assert get_client_ip(req) == "10.0.0.1"

    def test_returns_none_when_no_ip(self):
        assert get_client_ip(self._request()) is None
# ── _is_private_ip ───────────────────────────────────────────────────
class TestIsPrivateIp:
    """Classification of loopback/private vs public addresses."""

    def test_loopback(self):
        for addr in ("127.0.0.1", "127.0.0.2"):
            assert _is_private_ip(addr) is True

    def test_private_ranges(self):
        for addr in ("10.0.0.1", "192.168.1.1", "172.16.0.1"):
            assert _is_private_ip(addr) is True

    def test_ipv6_loopback(self):
        assert _is_private_ip("::1") is True

    def test_localhost_string(self):
        assert _is_private_ip("localhost") is True

    def test_public_ip(self):
        for addr in ("8.8.8.8", "1.1.1.1"):
            assert _is_private_ip(addr) is False
# ── lookup_ip_region ─────────────────────────────────────────────────
class TestLookupIpRegion:
    """External ip-api.com lookups, exercised through a mocked httpx client."""

    @staticmethod
    def _json_response(payload, status=200):
        # HTTP response double with a status code and JSON body.
        response = MagicMock()
        response.status_code = status
        response.json.return_value = payload
        return response

    @staticmethod
    def _client_for(response=None, error=None):
        # Async-context-manager double standing in for httpx.AsyncClient.
        client = AsyncMock()
        client.get = AsyncMock(return_value=response, side_effect=error)
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return client

    @pytest.mark.asyncio
    async def test_private_ip_returns_unresolved(self):
        result = await lookup_ip_region("127.0.0.1")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_private_ip_10_range(self):
        result = await lookup_ip_region("10.0.0.1")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_successful_lookup(self):
        payload = {"status": "success", "countryCode": "DE", "region": "BY"}
        client = self._client_for(self._json_response(payload))
        with patch("src.services.geoip.httpx.AsyncClient", return_value=client):
            result = await lookup_ip_region("8.8.8.8")
        assert result.country_code == "DE"
        assert result.region == "EU"

    @pytest.mark.asyncio
    async def test_failed_status(self):
        payload = {"status": "fail", "message": "invalid query"}
        client = self._client_for(self._json_response(payload))
        with patch("src.services.geoip.httpx.AsyncClient", return_value=client):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_http_error(self):
        error_response = MagicMock()
        error_response.status_code = 500
        client = self._client_for(error_response)
        with patch("src.services.geoip.httpx.AsyncClient", return_value=client):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_network_exception(self):
        client = self._client_for(error=httpx.ConnectError("Connection refused"))
        with patch("src.services.geoip.httpx.AsyncClient", return_value=client):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_us_with_state_lookup(self):
        payload = {"status": "success", "countryCode": "US", "region": "CA"}
        client = self._client_for(self._json_response(payload))
        with patch("src.services.geoip.httpx.AsyncClient", return_value=client):
            result = await lookup_ip_region("8.8.8.8")
        assert result.country_code == "US"
        assert result.region == "US-CA"

    @pytest.mark.asyncio
    async def test_missing_country_code(self):
        client = self._client_for(self._json_response({"status": "success"}))
        with patch("src.services.geoip.httpx.AsyncClient", return_value=client):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False
# ── detect_region (combined) ─────────────────────────────────────────
class TestDetectRegion:
    """End-to-end region detection: headers first, then IP lookup."""

    @pytest.mark.asyncio
    async def test_uses_headers_when_available(self):
        request = MagicMock()
        request.headers = {"cf-ipcountry": "FR"}
        request.client = None
        result = await detect_region(request)
        assert (result.country_code, result.region) == ("FR", "EU")

    @pytest.mark.asyncio
    async def test_falls_back_to_ip_lookup(self):
        request = MagicMock()
        request.headers = {}
        request.client = MagicMock()
        request.client.host = "8.8.8.8"
        api_response = MagicMock()
        api_response.status_code = 200
        api_response.json.return_value = {
            "status": "success",
            "countryCode": "US",
            "region": "CA",
        }
        http_client = AsyncMock()
        http_client.get = AsyncMock(return_value=api_response)
        http_client.__aenter__ = AsyncMock(return_value=http_client)
        http_client.__aexit__ = AsyncMock(return_value=False)
        with patch("src.services.geoip.httpx.AsyncClient", return_value=http_client):
            result = await detect_region(request)
        assert (result.country_code, result.region) == ("US", "US-CA")

    @pytest.mark.asyncio
    async def test_returns_unresolved_when_no_ip(self):
        request = MagicMock()
        request.headers = {}
        request.client = None
        result = await detect_region(request)
        assert result.is_resolved is False
# ── GeoResult ────────────────────────────────────────────────────────
class TestGeoResult:
    """Behaviour of the frozen GeoResult dataclass."""

    def test_is_resolved_true(self):
        assert GeoResult(country_code="GB", region="GB").is_resolved is True

    def test_is_resolved_false(self):
        assert GeoResult(country_code=None, region=None).is_resolved is False

    def test_frozen_dataclass(self):
        result = GeoResult(country_code="GB", region="GB")
        # FrozenInstanceError is a subclass of AttributeError.
        with pytest.raises(AttributeError):
            result.country_code = "US"  # type: ignore[misc]
# ── MaxMind database lookup ──────────────────────────────────────────
class TestLookupIpMaxmind:
    """Local MaxMind ``.mmdb`` lookups via the module-level cached reader."""

    def setup_method(self):
        # Reset the module-level cache so each test starts clean.
        geoip_module._maxmind_reader = None
        geoip_module._maxmind_initialised = False

    def teardown_method(self):
        # Fix: also reset AFTER each test. Without this, a MagicMock reader
        # installed below stays cached module-wide and leaks into any later
        # test (in this session, possibly other files) that does not reset
        # the cache itself.
        geoip_module._maxmind_reader = None
        geoip_module._maxmind_initialised = False

    def _mock_reader(self, country_iso: str | None, subdivision_iso: str | None):
        """Build a geoip2-style reader double returning the given ISO codes."""
        reader = MagicMock()
        response = MagicMock()
        response.country.iso_code = country_iso
        if subdivision_iso is None:
            response.subdivisions = None
        else:
            response.subdivisions.most_specific.iso_code = subdivision_iso
        reader.city.return_value = response
        return reader

    def test_private_ip_returns_unresolved(self):
        result = lookup_ip_maxmind("10.0.0.1")
        assert result.is_resolved is False

    def test_no_db_configured_returns_unresolved(self):
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_maxmind_db_path = None
            result = lookup_ip_maxmind("8.8.8.8")
        assert result.is_resolved is False

    def test_successful_lookup_with_subdivision(self):
        reader = self._mock_reader("US", "CA")
        geoip_module._maxmind_reader = reader
        geoip_module._maxmind_initialised = True
        result = lookup_ip_maxmind("8.8.8.8")
        assert result.country_code == "US"
        assert result.region == "US-CA"
        reader.city.assert_called_once_with("8.8.8.8")

    def test_successful_lookup_without_subdivision(self):
        reader = self._mock_reader("DE", None)
        geoip_module._maxmind_reader = reader
        geoip_module._maxmind_initialised = True
        result = lookup_ip_maxmind("8.8.8.8")
        assert result.country_code == "DE"
        assert result.region == "EU"

    def test_reader_raises_returns_unresolved(self):
        # A corrupt database must degrade to "unresolved", not propagate.
        reader = MagicMock()
        reader.city.side_effect = RuntimeError("corrupt db")
        geoip_module._maxmind_reader = reader
        geoip_module._maxmind_initialised = True
        result = lookup_ip_maxmind("8.8.8.8")
        assert result.is_resolved is False

    def test_reader_missing_country_returns_unresolved(self):
        reader = self._mock_reader(None, None)
        geoip_module._maxmind_reader = reader
        geoip_module._maxmind_initialised = True
        result = lookup_ip_maxmind("8.8.8.8")
        assert result.is_resolved is False

    def test_bad_db_path_is_cached_as_failure(self):
        """A missing ``.mmdb`` file should not reopen on every request."""
        with patch("src.services.geoip.get_settings") as mock_settings:
            mock_settings.return_value.geoip_maxmind_db_path = "/nonexistent/geo.mmdb"
            r1 = lookup_ip_maxmind("8.8.8.8")
            r2 = lookup_ip_maxmind("1.1.1.1")
        assert r1.is_resolved is False
        assert r2.is_resolved is False
        # The failure is memoised: initialised flag set, reader left unset.
        assert geoip_module._maxmind_initialised is True
        assert geoip_module._maxmind_reader is None
class TestDetectRegionMaxmind:
    """``detect_region`` should prefer a local MaxMind database when one is
    configured, instead of calling the external geo-IP API."""

    def setup_method(self):
        # Clear the module-level reader cache between tests.
        geoip_module._maxmind_reader = None
        geoip_module._maxmind_initialised = False

    @pytest.mark.asyncio
    async def test_uses_maxmind_before_external_api(self):
        """With MaxMind configured, ip-api.com must not be called."""
        city_response = MagicMock()
        city_response.country.iso_code = "GB"
        city_response.subdivisions.most_specific.iso_code = "SCT"
        fake_reader = MagicMock()
        fake_reader.city.return_value = city_response
        geoip_module._maxmind_reader = fake_reader
        geoip_module._maxmind_initialised = True

        fake_request = MagicMock()
        fake_request.headers = {"x-forwarded-for": "8.8.8.8"}
        fake_request.client = None

        with (
            patch("src.services.geoip.get_settings") as settings_mock,
            patch("src.services.geoip.httpx.AsyncClient") as httpx_mock,
        ):
            settings_mock.return_value.geoip_country_header = None
            settings_mock.return_value.geoip_region_header = None
            settings_mock.return_value.geoip_maxmind_db_path = "/data/GeoLite2-City.mmdb"
            outcome = await detect_region(fake_request)

        assert outcome.country_code == "GB"
        assert outcome.region == "GB-SCT"
        # The external HTTP client was never even constructed.
        httpx_mock.assert_not_called()

View File

@@ -0,0 +1,31 @@
import pytest
@pytest.mark.asyncio
async def test_health_endpoint(client):
    """The health probe reports ok and a recognised edition flag."""
    resp = await client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
    assert body["edition"] in ("ce", "ee")
@pytest.mark.asyncio
async def test_openapi_schema(client):
    """The generated OpenAPI document carries the expected title and version."""
    resp = await client.get("/openapi.json")
    assert resp.status_code == 200
    info = resp.json()["info"]
    assert info["title"] == "ConsentOS API"
    assert info["version"] == "0.1.0"
@pytest.mark.asyncio
async def test_api_routes_registered(client):
    """Each top-level router must appear in the OpenAPI path map."""
    resp = await client.get("/openapi.json")
    registered = resp.json()["paths"]
    for expected in (
        "/health",
        "/api/v1/auth/login",
        "/api/v1/config/sites/{site_id}",
        "/api/v1/consent/",
        "/api/v1/scanner/scans",
        "/api/v1/compliance/check/{site_id}",
    ):
        assert expected in registered

View File

@@ -0,0 +1,89 @@
"""Integration tests for authentication endpoints (requires database)."""
from tests.conftest import requires_db
@requires_db
class TestAuthLogin:
    """Login endpoint behaviour against a real database."""

    async def test_login_success(self, db_client, test_user):
        payload = {"email": test_user.email, "password": "TestPassword123"}
        response = await db_client.post("/api/v1/auth/login", json=payload)
        assert response.status_code == 200
        body = response.json()
        assert "access_token" in body
        assert "refresh_token" in body
        assert body["token_type"] == "bearer"

    async def test_login_wrong_password(self, db_client, test_user):
        payload = {"email": test_user.email, "password": "wrong"}
        response = await db_client.post("/api/v1/auth/login", json=payload)
        assert response.status_code == 401

    async def test_login_nonexistent_user(self, db_client):
        payload = {"email": "nobody@test.com", "password": "anything"}
        response = await db_client.post("/api/v1/auth/login", json=payload)
        assert response.status_code == 401

    async def test_login_invalid_email(self, db_client):
        response = await db_client.post(
            "/api/v1/auth/login",
            json={"email": "not-an-email", "password": "anything"},
        )
        # A malformed address fails schema validation before auth runs.
        assert response.status_code == 422
@requires_db
class TestAuthMe:
    """The /auth/me endpoint returns the authenticated user's profile."""

    async def test_me_returns_user(self, db_client, auth_headers, test_user):
        response = await db_client.get("/api/v1/auth/me", headers=auth_headers)
        assert response.status_code == 200
        body = response.json()
        assert body["email"] == test_user.email
        assert body["role"] == "owner"

    async def test_me_without_token(self, db_client):
        # No Authorization header -> request is unauthenticated.
        response = await db_client.get("/api/v1/auth/me")
        assert response.status_code == 401
@requires_db
class TestAuthRefresh:
    """Refresh-token exchange flow."""

    async def test_refresh_returns_new_tokens(self, db_client, test_user):
        # Log in first so we hold a genuine refresh token to exchange.
        login_response = await db_client.post(
            "/api/v1/auth/login",
            json={"email": test_user.email, "password": "TestPassword123"},
        )
        token = login_response.json()["refresh_token"]
        response = await db_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": token},
        )
        assert response.status_code == 200
        assert "access_token" in response.json()

    async def test_refresh_with_invalid_token(self, db_client):
        response = await db_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": "invalid-token"},
        )
        assert response.status_code == 401

Some files were not shown because too many files have changed in this diff Show More