feat: initial public release
ConsentOS — a privacy-first cookie consent management platform. Self-hosted, source-available alternative to OneTrust, Cookiebot, and CookieYes. Full standards coverage (IAB TCF v2.2, GPP v1, Google Consent Mode v2, GPC, Shopify Customer Privacy API), multi-tenant architecture with role-based access, configuration cascade (system → org → group → site → region), dark-pattern detection in the scanner, and a tamper-evident consent record audit trail. This is the initial public release. Prior development history is retained internally. See README.md for the feature list, architecture overview, and quick-start instructions. Licensed under the Elastic Licence 2.0 — self-host freely; do not resell as a managed service.
This commit is contained in:
12
apps/api/.dockerignore
Normal file
12
apps/api/.dockerignore
Normal file
@@ -0,0 +1,12 @@
|
||||
# Virtual environment
.venv/
# Python bytecode caches
__pycache__/
*.py[cod]
# Tool caches
.pytest_cache/
.mypy_cache/
# Coverage artifacts
.coverage
htmlcov/
# Packaging metadata
*.egg-info/
# Tests are not needed inside the runtime image
tests/
# Fly.io deployment config
fly.toml
# Local environment files — keep secrets out of the build context
.env
.env.*
|
||||
51
apps/api/Dockerfile
Normal file
51
apps/api/Dockerfile
Normal file
@@ -0,0 +1,51 @@
|
||||
# ── Build stage ──────────────────────────────────────────────────────
FROM python:3.12-slim AS builder

WORKDIR /build

# Build-time deps only: gcc + libpq headers so PostgreSQL client packages
# can compile from source if no wheel is available. Apt lists are removed
# to keep the layer small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc libpq-dev \
    && rm -rf /var/lib/apt/lists/*

# NOTE(review): only pyproject.toml is copied before `pip install .` — this
# assumes the build backend can install the project (or at least resolve its
# dependencies) without the rest of the source tree present. Confirm against
# pyproject.toml; otherwise this step fails or installs an empty package.
COPY pyproject.toml .
# Install into /install so the runtime stage can copy a clean prefix.
RUN pip install --no-cache-dir --prefix=/install .

# ── Runtime stage ────────────────────────────────────────────────────
FROM python:3.12-slim

# Runtime deps: libpq5 (PostgreSQL client lib), curl for the HEALTHCHECK below.
RUN apt-get update && apt-get install -y --no-install-recommends \
    libpq5 curl \
    && rm -rf /var/lib/apt/lists/*

# Non-root user for security
RUN groupadd -r cmp && useradd -r -g cmp -d /app -s /sbin/nologin cmp

WORKDIR /app

# Copy installed dependencies from builder
COPY --from=builder /install /usr/local

# Copy application code
COPY . .

# Give the non-root user ownership before dropping privileges.
RUN chown -R cmp:cmp /app

USER cmp

EXPOSE 8000

# Container is healthy when the API answers on /health.
HEALTHCHECK --interval=30s --timeout=5s --start-period=30s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Start the server. Database migrations and the initial-admin
# bootstrap are owned by a separate init container (see the OSS
# docker-compose in consentos-deployment) — the API assumes the
# schema is ready by the time it starts.
# Workers configurable via WEB_CONCURRENCY (default 4, use 1 for 256MB RAM)
# sh -c is used so ${PORT}/${WEB_CONCURRENCY} expand at container start.
CMD ["sh", "-c", "uvicorn src.main:app \
    --host 0.0.0.0 \
    --port ${PORT:-8000} \
    --workers ${WEB_CONCURRENCY:-4} \
    --access-log \
    --proxy-headers \
    --forwarded-allow-ips '*'"]
|
||||
149
apps/api/alembic.ini
Normal file
149
apps/api/alembic.ini
Normal file
@@ -0,0 +1,149 @@
|
||||
# A generic, single database configuration.

[alembic]
# path to migration scripts.
# this is typically a path given in POSIX (e.g. forward slashes)
# format, relative to the token %(here)s which refers to the location of this
# ini file
script_location = %(here)s/alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file
# for all available tokens
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s
# Or organize into date-based subdirectories (requires recursive_version_locations = true)
# file_template = %%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory. for multiple paths, the path separator
# is defined by "path_separator" below.
prepend_sys_path = .


# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the tzdata library which can be installed by adding
# `alembic[tz]` to the pip requirements.
# string value is passed to ZoneInfo()
# leave blank for localtime
# timezone =

# max length of characters to apply to the "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to <script_location>/versions. When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "path_separator"
# below.
# version_locations = %(here)s/bar:%(here)s/bat:%(here)s/alembic/versions

# path_separator; This indicates what character is used to split lists of file
# paths, including version_locations and prepend_sys_path within configparser
# files such as alembic.ini.
# The default rendered in new alembic.ini files is "os", which uses os.pathsep
# to provide os-dependent path splitting.
#
# Note that in order to support legacy alembic.ini files, this default does NOT
# take place if path_separator is not present in alembic.ini. If this
# option is omitted entirely, fallback logic is as follows:
#
# 1. Parsing of the version_locations option falls back to using the legacy
#    "version_path_separator" key, which if absent then falls back to the legacy
#    behavior of splitting on spaces and/or commas.
# 2. Parsing of the prepend_sys_path option falls back to the legacy
#    behavior of splitting on spaces, commas, or colons.
#
# Valid values for path_separator are:
#
# path_separator = :
# path_separator = ;
# path_separator = space
# path_separator = newline
#
# Use os.pathsep. Default configuration used for new projects.
path_separator = os

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# database URL. This is consumed by the user-maintained env.py script only.
# other means of configuring database URLs may be customized within the env.py
# file.
# NOTE(review): dev-only default credentials for a local Postgres; env.py
# overrides this value from DATABASE_URL at runtime, so no deployment secret
# should ever live here.
sqlalchemy.url = postgresql://consentos:consentos@localhost:5432/consentos


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts. See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
# hooks = black
# black.type = console_scripts
# black.entrypoint = black
# black.options = -l 79 REVISION_SCRIPT_FILENAME

# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module
# hooks = ruff
# ruff.type = module
# ruff.module = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Alternatively, use the exec runner to execute a binary found on your PATH
# hooks = ruff
# ruff.type = exec
# ruff.executable = ruff
# ruff.options = check --fix REVISION_SCRIPT_FILENAME

# Logging configuration. This is also consumed by the user-maintained
# env.py script only.
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARNING
handlers = console
qualname =

[logger_sqlalchemy]
level = WARNING
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
|
||||
1
apps/api/alembic/README
Normal file
1
apps/api/alembic/README
Normal file
@@ -0,0 +1 @@
|
||||
Generic single-database configuration.
|
||||
61
apps/api/alembic/env.py
Normal file
61
apps/api/alembic/env.py
Normal file
@@ -0,0 +1,61 @@
|
||||
"""Alembic migration environment: wires the app's SQLAlchemy metadata
and runtime DATABASE_URL into Alembic's offline/online migration runners."""

import os
from logging.config import fileConfig

from sqlalchemy import engine_from_config, pool

from alembic import context
from src.models import Base

# Alembic Config object — gives access to the values in alembic.ini.
config = context.config

# Override sqlalchemy.url from environment if set, so the ini file's
# dev-only default is never used in real deployments.
database_url = os.environ.get("DATABASE_URL")
if database_url:
    # Alembic needs the synchronous driver (the app itself presumably uses
    # asyncpg — the async URL scheme is rewritten to the sync one).
    database_url = database_url.replace("postgresql+asyncpg://", "postgresql://")
    config.set_main_option("sqlalchemy.url", database_url)

# Set up Python logging from the config file ([loggers]/[handlers] sections).
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Metadata of all app models; handed to context.configure below so Alembic
# can diff it against the live schema (autogenerate support).
target_metadata = Base.metadata
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
    """Emit migration SQL without a live database connection.

    Configures the Alembic context with only the configured URL so each
    statement is rendered as literal SQL ("offline" mode) rather than
    executed against an engine.
    """
    configure_kwargs = {
        "url": config.get_main_option("sqlalchemy.url"),
        "target_metadata": target_metadata,
        # Render bound parameters inline — required when producing SQL text.
        "literal_binds": True,
        "dialect_opts": {"paramstyle": "named"},
    }
    context.configure(**configure_kwargs)

    with context.begin_transaction():
        context.run_migrations()
|
||||
|
||||
|
||||
def run_migrations_online() -> None:
    """Run migrations against a live database connection ("online" mode)."""
    # Build a short-lived engine from the [alembic] ini section; NullPool so
    # no connections linger once the migration run completes.
    engine = engine_from_config(
        config.get_section(config.config_ini_section, {}),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with engine.connect() as conn:
        context.configure(connection=conn, target_metadata=target_metadata)
        with context.begin_transaction():
            context.run_migrations()
|
||||
|
||||
|
||||
# Entry point: Alembic selects offline mode for `alembic upgrade --sql`-style
# invocations; otherwise migrations run against a live connection.
if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
|
||||
28
apps/api/alembic/script.py.mako
Normal file
28
apps/api/alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
||||
"""${message}
|
||||
|
||||
Revision ID: ${up_revision}
|
||||
Revises: ${down_revision | comma,n}
|
||||
Create Date: ${create_date}
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
${imports if imports else ""}
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = ${repr(up_revision)}
|
||||
down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)}
|
||||
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
${upgrades if upgrades else "pass"}
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Downgrade schema."""
|
||||
${downgrades if downgrades else "pass"}
|
||||
442
apps/api/alembic/versions/0001_initial_schema.py
Normal file
442
apps/api/alembic/versions/0001_initial_schema.py
Normal file
@@ -0,0 +1,442 @@
|
||||
"""initial schema
|
||||
|
||||
Revision ID: 0001
|
||||
Revises:
|
||||
Create Date: 2026-04-13
|
||||
|
||||
Creates the full core schema plus seeds the default cookie categories.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '0001'
|
||||
down_revision: Union[str, Sequence[str], None] = None
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Upgrade schema."""
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('cookie_categories',
|
||||
sa.Column('name', sa.String(length=50), nullable=False),
|
||||
sa.Column('slug', sa.String(length=50), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_essential', sa.Boolean(), nullable=False),
|
||||
sa.Column('display_order', sa.Integer(), server_default='0', nullable=False),
|
||||
sa.Column('tcf_purpose_ids', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gcm_consent_types', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('name'),
|
||||
sa.UniqueConstraint('slug')
|
||||
)
|
||||
op.create_table('organisations',
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('slug', sa.String(length=100), nullable=False),
|
||||
sa.Column('contact_email', sa.String(length=255), nullable=True),
|
||||
sa.Column('billing_plan', sa.String(length=50), server_default='free', nullable=False),
|
||||
sa.Column('notes', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_organisations_slug'), 'organisations', ['slug'], unique=True)
|
||||
op.create_table('known_cookies',
|
||||
sa.Column('name_pattern', sa.String(length=255), nullable=False),
|
||||
sa.Column('domain_pattern', sa.String(length=255), nullable=False),
|
||||
sa.Column('category_id', sa.UUID(), nullable=False),
|
||||
sa.Column('vendor', sa.String(length=255), nullable=True),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('is_regex', sa.Boolean(), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['category_id'], ['cookie_categories.id'], ondelete='RESTRICT'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('name_pattern', 'domain_pattern', name='uq_known_cookies_name_domain')
|
||||
)
|
||||
op.create_index(op.f('ix_known_cookies_name_pattern'), 'known_cookies', ['name_pattern'], unique=False)
|
||||
op.create_table('org_configs',
|
||||
sa.Column('organisation_id', sa.UUID(), nullable=False),
|
||||
sa.Column('blocking_mode', sa.String(length=20), nullable=True),
|
||||
sa.Column('regional_modes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('tcf_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('tcf_publisher_cc', sa.String(length=2), nullable=True),
|
||||
sa.Column('gpp_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('gpp_supported_apis', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gpc_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('gpc_jurisdictions', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gpc_global_honour', sa.Boolean(), nullable=True),
|
||||
sa.Column('gcm_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('gcm_default', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('shopify_privacy_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('banner_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('privacy_policy_url', sa.Text(), nullable=True),
|
||||
sa.Column('terms_url', sa.Text(), nullable=True),
|
||||
sa.Column('scan_schedule_cron', sa.String(length=100), nullable=True),
|
||||
sa.Column('scan_max_pages', sa.Integer(), nullable=True),
|
||||
sa.Column('consent_expiry_days', sa.Integer(), nullable=True),
|
||||
sa.Column('consent_retention_days', sa.Integer(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('organisation_id')
|
||||
)
|
||||
op.create_table('site_groups',
|
||||
sa.Column('organisation_id', sa.UUID(), nullable=False),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('organisation_id', 'name', name='uq_site_groups_org_name')
|
||||
)
|
||||
op.create_index(op.f('ix_site_groups_organisation_id'), 'site_groups', ['organisation_id'], unique=False)
|
||||
op.create_table('users',
|
||||
sa.Column('organisation_id', sa.UUID(), nullable=False),
|
||||
sa.Column('email', sa.String(length=255), nullable=False),
|
||||
sa.Column('password_hash', sa.String(length=255), nullable=False),
|
||||
sa.Column('full_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('role', sa.String(length=20), server_default='viewer', nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_users_email'), 'users', ['email'], unique=True)
|
||||
op.create_index(op.f('ix_users_organisation_id'), 'users', ['organisation_id'], unique=False)
|
||||
op.create_table('site_group_configs',
|
||||
sa.Column('site_group_id', sa.UUID(), nullable=False),
|
||||
sa.Column('blocking_mode', sa.String(length=20), nullable=True),
|
||||
sa.Column('regional_modes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('tcf_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('tcf_publisher_cc', sa.String(length=2), nullable=True),
|
||||
sa.Column('gpp_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('gpp_supported_apis', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gpc_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('gpc_jurisdictions', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gpc_global_honour', sa.Boolean(), nullable=True),
|
||||
sa.Column('gcm_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('gcm_default', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('shopify_privacy_enabled', sa.Boolean(), nullable=True),
|
||||
sa.Column('banner_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('privacy_policy_url', sa.Text(), nullable=True),
|
||||
sa.Column('terms_url', sa.Text(), nullable=True),
|
||||
sa.Column('scan_schedule_cron', sa.String(length=100), nullable=True),
|
||||
sa.Column('scan_max_pages', sa.Integer(), nullable=True),
|
||||
sa.Column('consent_expiry_days', sa.Integer(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['site_group_id'], ['site_groups.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('site_group_id')
|
||||
)
|
||||
op.create_table('sites',
|
||||
sa.Column('organisation_id', sa.UUID(), nullable=False),
|
||||
sa.Column('domain', sa.String(length=255), nullable=False),
|
||||
sa.Column('display_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('is_active', sa.Boolean(), nullable=False),
|
||||
sa.Column('additional_domains', postgresql.ARRAY(sa.String(length=255)), nullable=True),
|
||||
sa.Column('site_group_id', sa.UUID(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('deleted_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['organisation_id'], ['organisations.id'], ondelete='CASCADE'),
|
||||
sa.ForeignKeyConstraint(['site_group_id'], ['site_groups.id'], ondelete='SET NULL'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('organisation_id', 'domain', name='uq_sites_org_domain')
|
||||
)
|
||||
op.create_index(op.f('ix_sites_domain'), 'sites', ['domain'], unique=False)
|
||||
op.create_index(op.f('ix_sites_organisation_id'), 'sites', ['organisation_id'], unique=False)
|
||||
op.create_index(op.f('ix_sites_site_group_id'), 'sites', ['site_group_id'], unique=False)
|
||||
op.create_table('consent_records',
|
||||
sa.Column('site_id', sa.UUID(), nullable=False),
|
||||
sa.Column('visitor_id', sa.String(length=255), nullable=False),
|
||||
sa.Column('ip_hash', sa.String(length=64), nullable=True),
|
||||
sa.Column('user_agent_hash', sa.String(length=64), nullable=True),
|
||||
sa.Column('action', sa.String(length=30), nullable=False),
|
||||
sa.Column('categories_accepted', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('categories_rejected', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('tc_string', sa.Text(), nullable=True),
|
||||
sa.Column('gcm_state', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gpp_string', sa.Text(), nullable=True),
|
||||
sa.Column('gpc_detected', sa.Boolean(), nullable=True),
|
||||
sa.Column('gpc_honoured', sa.Boolean(), nullable=True),
|
||||
sa.Column('ab_test_id', sa.UUID(), nullable=True),
|
||||
sa.Column('ab_variant_id', sa.UUID(), nullable=True),
|
||||
sa.Column('page_url', sa.Text(), nullable=True),
|
||||
sa.Column('country_code', sa.String(length=5), nullable=True),
|
||||
sa.Column('region_code', sa.String(length=10), nullable=True),
|
||||
sa.Column('consented_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_consent_records_ab_test_id'), 'consent_records', ['ab_test_id'], unique=False)
|
||||
op.create_index(op.f('ix_consent_records_consented_at'), 'consent_records', ['consented_at'], unique=False)
|
||||
op.create_index(op.f('ix_consent_records_site_id'), 'consent_records', ['site_id'], unique=False)
|
||||
op.create_index(op.f('ix_consent_records_visitor_id'), 'consent_records', ['visitor_id'], unique=False)
|
||||
op.create_table('cookie_allow_list',
|
||||
sa.Column('site_id', sa.UUID(), nullable=False),
|
||||
sa.Column('category_id', sa.UUID(), nullable=False),
|
||||
sa.Column('name_pattern', sa.String(length=255), nullable=False),
|
||||
sa.Column('domain_pattern', sa.String(length=255), nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['category_id'], ['cookie_categories.id'], ondelete='RESTRICT'),
|
||||
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('site_id', 'name_pattern', 'domain_pattern', name='uq_allow_list_site_name_domain')
|
||||
)
|
||||
op.create_index(op.f('ix_cookie_allow_list_site_id'), 'cookie_allow_list', ['site_id'], unique=False)
|
||||
op.create_table('cookies',
|
||||
sa.Column('site_id', sa.UUID(), nullable=False),
|
||||
sa.Column('category_id', sa.UUID(), nullable=True),
|
||||
sa.Column('name', sa.String(length=255), nullable=False),
|
||||
sa.Column('domain', sa.String(length=255), nullable=False),
|
||||
sa.Column('storage_type', sa.String(length=30), server_default='cookie', nullable=False),
|
||||
sa.Column('description', sa.Text(), nullable=True),
|
||||
sa.Column('vendor', sa.String(length=255), nullable=True),
|
||||
sa.Column('path', sa.String(length=500), nullable=True),
|
||||
sa.Column('max_age_seconds', sa.Integer(), nullable=True),
|
||||
sa.Column('is_http_only', sa.Boolean(), nullable=True),
|
||||
sa.Column('is_secure', sa.Boolean(), nullable=True),
|
||||
sa.Column('same_site', sa.String(length=10), nullable=True),
|
||||
sa.Column('review_status', sa.String(length=20), server_default='pending', nullable=False),
|
||||
sa.Column('first_seen_at', sa.String(length=50), nullable=True),
|
||||
sa.Column('last_seen_at', sa.String(length=50), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['category_id'], ['cookie_categories.id'], ondelete='SET NULL'),
|
||||
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('site_id', 'name', 'domain', 'storage_type', name='uq_cookies_site_name_domain_type')
|
||||
)
|
||||
op.create_index(op.f('ix_cookies_category_id'), 'cookies', ['category_id'], unique=False)
|
||||
op.create_index(op.f('ix_cookies_name'), 'cookies', ['name'], unique=False)
|
||||
op.create_index(op.f('ix_cookies_site_id'), 'cookies', ['site_id'], unique=False)
|
||||
op.create_table('scan_jobs',
|
||||
sa.Column('site_id', sa.UUID(), nullable=False),
|
||||
sa.Column('status', sa.String(length=20), server_default='pending', nullable=False),
|
||||
sa.Column('trigger', sa.String(length=20), server_default='manual', nullable=False),
|
||||
sa.Column('pages_scanned', sa.Integer(), server_default='0', nullable=False),
|
||||
sa.Column('pages_total', sa.Integer(), nullable=True),
|
||||
sa.Column('cookies_found', sa.Integer(), server_default='0', nullable=False),
|
||||
sa.Column('error_message', sa.Text(), nullable=True),
|
||||
sa.Column('started_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('completed_at', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_scan_jobs_site_id'), 'scan_jobs', ['site_id'], unique=False)
|
||||
op.create_index(op.f('ix_scan_jobs_status'), 'scan_jobs', ['status'], unique=False)
|
||||
op.create_table('site_configs',
|
||||
sa.Column('site_id', sa.UUID(), nullable=False),
|
||||
sa.Column('blocking_mode', sa.String(length=20), server_default='opt_in', nullable=False),
|
||||
sa.Column('regional_modes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('tcf_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('tcf_publisher_cc', sa.String(length=2), nullable=True),
|
||||
sa.Column('gpp_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('gpp_supported_apis', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gpc_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('gpc_jurisdictions', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('gpc_global_honour', sa.Boolean(), nullable=False),
|
||||
sa.Column('gcm_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('gcm_default', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('shopify_privacy_enabled', sa.Boolean(), nullable=False),
|
||||
sa.Column('banner_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('display_mode', sa.String(length=30), server_default='bottom_banner', nullable=False),
|
||||
sa.Column('privacy_policy_url', sa.Text(), nullable=True),
|
||||
sa.Column('terms_url', sa.Text(), nullable=True),
|
||||
sa.Column('scan_schedule_cron', sa.String(length=100), nullable=True),
|
||||
sa.Column('scan_max_pages', sa.Integer(), server_default='50', nullable=False),
|
||||
sa.Column('consent_expiry_days', sa.Integer(), server_default='365', nullable=False),
|
||||
sa.Column('consent_retention_days', sa.Integer(), nullable=True),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('site_id')
|
||||
)
|
||||
op.create_table('translations',
|
||||
sa.Column('site_id', sa.UUID(), nullable=False),
|
||||
sa.Column('locale', sa.String(length=10), nullable=False),
|
||||
sa.Column('strings', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['site_id'], ['sites.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('site_id', 'locale', name='uq_translations_site_locale')
|
||||
)
|
||||
op.create_index(op.f('ix_translations_site_id'), 'translations', ['site_id'], unique=False)
|
||||
op.create_table('scan_results',
|
||||
sa.Column('scan_job_id', sa.UUID(), nullable=False),
|
||||
sa.Column('page_url', sa.Text(), nullable=False),
|
||||
sa.Column('cookie_name', sa.String(length=255), nullable=False),
|
||||
sa.Column('cookie_domain', sa.String(length=255), nullable=False),
|
||||
sa.Column('storage_type', sa.String(length=30), server_default='cookie', nullable=False),
|
||||
sa.Column('attributes', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
|
||||
sa.Column('script_source', sa.Text(), nullable=True),
|
||||
sa.Column('auto_category', sa.String(length=50), nullable=True),
|
||||
sa.Column('initiator_chain', postgresql.ARRAY(sa.Text()), nullable=True, comment='Ordered script URLs from root initiator to leaf'),
|
||||
sa.Column('found_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('id', sa.UUID(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['scan_job_id'], ['scan_jobs.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_scan_results_scan_job_id'), 'scan_results', ['scan_job_id'], unique=False)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
# ── Seed default cookie categories ───────────────────────────────
|
||||
cookie_categories_table = sa.table(
|
||||
"cookie_categories",
|
||||
sa.column("id", sa.UUID()),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("slug", sa.String),
|
||||
sa.column("description", sa.Text),
|
||||
sa.column("is_essential", sa.Boolean),
|
||||
sa.column("display_order", sa.Integer),
|
||||
sa.column("tcf_purpose_ids", postgresql.JSONB),
|
||||
sa.column("gcm_consent_types", postgresql.JSONB),
|
||||
)
|
||||
op.bulk_insert(
|
||||
cookie_categories_table,
|
||||
[
|
||||
{
|
||||
"id": uuid.UUID("10000000-0000-0000-0000-000000000001"),
|
||||
"name": "Necessary",
|
||||
"slug": "necessary",
|
||||
"description": (
|
||||
"Essential cookies required for the website to function. "
|
||||
"These cannot be disabled."
|
||||
),
|
||||
"is_essential": True,
|
||||
"display_order": 0,
|
||||
"tcf_purpose_ids": None,
|
||||
"gcm_consent_types": ["functionality_storage", "security_storage"],
|
||||
},
|
||||
{
|
||||
"id": uuid.UUID("10000000-0000-0000-0000-000000000002"),
|
||||
"name": "Functional",
|
||||
"slug": "functional",
|
||||
"description": (
|
||||
"Cookies that enable enhanced functionality and personalisation, "
|
||||
"such as remembering preferences."
|
||||
),
|
||||
"is_essential": False,
|
||||
"display_order": 1,
|
||||
"tcf_purpose_ids": [1],
|
||||
"gcm_consent_types": ["functionality_storage", "personalization_storage"],
|
||||
},
|
||||
{
|
||||
"id": uuid.UUID("10000000-0000-0000-0000-000000000003"),
|
||||
"name": "Analytics",
|
||||
"slug": "analytics",
|
||||
"description": (
|
||||
"Cookies used to collect information about how visitors use the website, "
|
||||
"helping to improve the site."
|
||||
),
|
||||
"is_essential": False,
|
||||
"display_order": 2,
|
||||
"tcf_purpose_ids": [7, 8, 9],
|
||||
"gcm_consent_types": ["analytics_storage"],
|
||||
},
|
||||
{
|
||||
"id": uuid.UUID("10000000-0000-0000-0000-000000000004"),
|
||||
"name": "Marketing",
|
||||
"slug": "marketing",
|
||||
"description": (
|
||||
"Cookies used to deliver personalised advertisements and "
|
||||
"track advertising campaign performance."
|
||||
),
|
||||
"is_essential": False,
|
||||
"display_order": 3,
|
||||
"tcf_purpose_ids": [2, 3, 4, 5, 6, 10, 11],
|
||||
"gcm_consent_types": ["ad_storage", "ad_user_data", "ad_personalization"],
|
||||
},
|
||||
{
|
||||
"id": uuid.UUID("10000000-0000-0000-0000-000000000005"),
|
||||
"name": "Personalisation",
|
||||
"slug": "personalisation",
|
||||
"description": (
|
||||
"Cookies that enable content personalisation based on "
|
||||
"user profiles and browsing behaviour."
|
||||
),
|
||||
"is_essential": False,
|
||||
"display_order": 4,
|
||||
"tcf_purpose_ids": [3, 4, 6],
|
||||
"gcm_consent_types": ["personalization_storage"],
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Downgrade schema.

    Drops everything created by this revision. Tables are removed in
    reverse dependency order (children holding foreign keys first, then
    their parents) so no FK constraint is violated mid-downgrade; each
    table's indexes are dropped before the table itself. The seeded
    cookie_categories rows disappear implicitly with the final
    drop_table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_index(op.f('ix_scan_results_scan_job_id'), table_name='scan_results')
    op.drop_table('scan_results')
    op.drop_index(op.f('ix_translations_site_id'), table_name='translations')
    op.drop_table('translations')
    op.drop_table('site_configs')
    op.drop_index(op.f('ix_scan_jobs_status'), table_name='scan_jobs')
    op.drop_index(op.f('ix_scan_jobs_site_id'), table_name='scan_jobs')
    op.drop_table('scan_jobs')
    op.drop_index(op.f('ix_cookies_site_id'), table_name='cookies')
    op.drop_index(op.f('ix_cookies_name'), table_name='cookies')
    op.drop_index(op.f('ix_cookies_category_id'), table_name='cookies')
    op.drop_table('cookies')
    op.drop_index(op.f('ix_cookie_allow_list_site_id'), table_name='cookie_allow_list')
    op.drop_table('cookie_allow_list')
    op.drop_index(op.f('ix_consent_records_visitor_id'), table_name='consent_records')
    op.drop_index(op.f('ix_consent_records_site_id'), table_name='consent_records')
    op.drop_index(op.f('ix_consent_records_consented_at'), table_name='consent_records')
    op.drop_index(op.f('ix_consent_records_ab_test_id'), table_name='consent_records')
    op.drop_table('consent_records')
    op.drop_index(op.f('ix_sites_site_group_id'), table_name='sites')
    op.drop_index(op.f('ix_sites_organisation_id'), table_name='sites')
    op.drop_index(op.f('ix_sites_domain'), table_name='sites')
    op.drop_table('sites')
    op.drop_table('site_group_configs')
    op.drop_index(op.f('ix_users_organisation_id'), table_name='users')
    op.drop_index(op.f('ix_users_email'), table_name='users')
    op.drop_table('users')
    op.drop_index(op.f('ix_site_groups_organisation_id'), table_name='site_groups')
    op.drop_table('site_groups')
    op.drop_table('org_configs')
    op.drop_index(op.f('ix_known_cookies_name_pattern'), table_name='known_cookies')
    op.drop_table('known_cookies')
    op.drop_index(op.f('ix_organisations_slug'), table_name='organisations')
    op.drop_table('organisations')
    op.drop_table('cookie_categories')
    # ### end Alembic commands ###
|
||||
36
apps/api/alembic/versions/0002_composite_consent_index.py
Normal file
36
apps/api/alembic/versions/0002_composite_consent_index.py
Normal file
@@ -0,0 +1,36 @@
|
||||
"""composite index on consent_records(site_id, consented_at)
|
||||
|
||||
Revision ID: 0002
|
||||
Revises: 0001
|
||||
Create Date: 2026-04-13
|
||||
|
||||
The most common analytic query pattern is "consents for site X in date
|
||||
range" (consent rates, trends, regional breakdowns). The single-column
|
||||
indexes on ``site_id`` and ``consented_at`` each help a little, but a
|
||||
composite index is materially faster for the combined filter.
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "0002"
|
||||
down_revision: Union[str, Sequence[str], None] = "0001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the composite (site_id, consented_at) index on consent_records."""
    index_name = "ix_consent_records_site_consented_at"
    index_columns = ["site_id", "consented_at"]
    op.create_index(index_name, "consent_records", index_columns, unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the composite (site_id, consented_at) index again."""
    op.drop_index("ix_consent_records_site_consented_at", table_name="consent_records")
|
||||
9
apps/api/data/ATTRIBUTION.md
Normal file
9
apps/api/data/ATTRIBUTION.md
Normal file
@@ -0,0 +1,9 @@
|
||||
# Open Cookie Database
|
||||
|
||||
The file `open-cookie-database.csv` is sourced from the
|
||||
[Open Cookie Database](https://github.com/jkwakman/Open-Cookie-Database)
|
||||
by jkwakman, licensed under the Creative Commons Attribution 4.0 International
|
||||
(CC BY 4.0) licence.
|
||||
|
||||
To update the database, download the latest CSV from the repository above and
|
||||
replace this file, then run `make seed` to reload the data.
|
||||
2265
apps/api/data/open-cookie-database.csv
Normal file
2265
apps/api/data/open-cookie-database.csv
Normal file
File diff suppressed because it is too large
Load Diff
65
apps/api/fly.toml
Normal file
65
apps/api/fly.toml
Normal file
@@ -0,0 +1,65 @@
|
||||
# Fly.io configuration for the ConsentOS API
|
||||
# See https://fly.io/docs/reference/configuration/ for reference.
|
||||
#
|
||||
# This app runs three process groups from the same Docker image:
|
||||
# - app: FastAPI web server (handles HTTP traffic)
|
||||
# - worker: Celery worker (processes scan jobs and background tasks)
|
||||
# - beat: Celery beat scheduler (triggers periodic tasks)
|
||||
|
||||
app = "consentos-api"
|
||||
primary_region = "lhr" # London
|
||||
|
||||
[build]
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
[env]
|
||||
ENVIRONMENT = "production"
|
||||
LOG_LEVEL = "INFO"
|
||||
PORT = "8000"
|
||||
RATE_LIMIT_ENABLED = "true"
|
||||
RATE_LIMIT_PER_MINUTE = "120"
|
||||
|
||||
# ── Migrations run once per deployment, before processes start ──────
|
||||
[deploy]
|
||||
release_command = "python -m alembic upgrade head"
|
||||
|
||||
# ── Process groups ──────────────────────────────────────────────────
|
||||
[processes]
|
||||
app = "sh start.sh"
|
||||
worker = "celery -A src.celery_app worker --loglevel=info --concurrency=2"
|
||||
beat = "celery -A src.celery_app beat --loglevel=info"
|
||||
|
||||
# ── HTTP service (only the 'app' process serves HTTP) ───────────────
|
||||
[http_service]
|
||||
internal_port = 8000
|
||||
force_https = true
|
||||
auto_stop_machines = "stop"
|
||||
auto_start_machines = true
|
||||
min_machines_running = 0
|
||||
processes = ["app"]
|
||||
|
||||
[http_service.concurrency]
|
||||
type = "requests"
|
||||
hard_limit = 250
|
||||
soft_limit = 200
|
||||
|
||||
# ── VM sizing per process ───────────────────────────────────────────
|
||||
# The app and beat processes are lightweight; the worker needs more
|
||||
# memory for processing scan results.
|
||||
[[vm]]
|
||||
memory = "256mb"
|
||||
cpu_kind = "shared"
|
||||
cpus = 1
|
||||
processes = ["app"]
|
||||
|
||||
[[vm]]
|
||||
memory = "256mb"
|
||||
cpu_kind = "shared"
|
||||
cpus = 1
|
||||
processes = ["worker"]
|
||||
|
||||
[[vm]]
|
||||
memory = "256mb"
|
||||
cpu_kind = "shared"
|
||||
cpus = 1
|
||||
processes = ["beat"]
|
||||
64
apps/api/pyproject.toml
Normal file
64
apps/api/pyproject.toml
Normal file
@@ -0,0 +1,64 @@
|
||||
[project]
|
||||
name = "consentos-api"
|
||||
version = "0.1.0"
|
||||
description = "ConsentOS — API service"
|
||||
license = "Elastic-2.0"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"fastapi>=0.115,<1",
|
||||
"uvicorn[standard]>=0.34,<1",
|
||||
"sqlalchemy[asyncio]>=2.0,<3",
|
||||
"asyncpg>=0.30,<1",
|
||||
"alembic>=1.14,<2",
|
||||
"pydantic>=2.0,<3",
|
||||
"pydantic-settings>=2.0,<3",
|
||||
"python-jose[cryptography]>=3.3,<4",
|
||||
"bcrypt>=4.0,<5",
|
||||
"redis>=5.0,<6",
|
||||
"celery>=5.4,<6",
|
||||
"httpx>=0.28,<1",
|
||||
"structlog>=24.0,<25",
|
||||
"psycopg2-binary>=2.9,<3",
|
||||
"email-validator>=2.0,<3",
|
||||
"jinja2>=3.1,<4",
|
||||
"markupsafe>=2.1,<3",
|
||||
"reportlab>=4.0,<5",
|
||||
"geoip2>=4.8,<5",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
dev = [
|
||||
"pytest>=8.0,<9",
|
||||
"pytest-asyncio>=0.24,<1",
|
||||
"pytest-cov>=6.0,<7",
|
||||
"httpx>=0.28,<1",
|
||||
"ruff>=0.9,<1",
|
||||
"mypy>=1.13,<2",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["setuptools>=68"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["."]
|
||||
include = ["src*"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
asyncio_mode = "auto"
|
||||
asyncio_default_fixture_loop_scope = "session"
|
||||
filterwarnings = ["ignore::DeprecationWarning"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I", "N", "W", "UP", "B", "SIM", "RUF"]
|
||||
ignore = ["B008"] # Depends() in FastAPI defaults is idiomatic
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.12"
|
||||
strict = true
|
||||
plugins = ["pydantic.mypy"]
|
||||
0
apps/api/src/__init__.py
Normal file
0
apps/api/src/__init__.py
Normal file
89
apps/api/src/celery_app.py
Normal file
89
apps/api/src/celery_app.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Celery application and task definitions for the CMP API.
|
||||
|
||||
Provides async-compatible scan scheduling via Celery with Redis as the
|
||||
broker and result backend.
|
||||
"""
|
||||
|
||||
import ssl
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Named `app` by Celery convention — the CLI finds it via -A src.celery_app
|
||||
app = Celery(
|
||||
"cmp",
|
||||
broker=settings.redis_url,
|
||||
backend=settings.redis_url,
|
||||
)
|
||||
|
||||
# When using rediss:// (TLS) — e.g. Upstash — Celery requires explicit
|
||||
# SSL certificate verification settings for both broker and backend.
|
||||
_conf: dict = {
|
||||
"task_serializer": "json",
|
||||
"accept_content": ["json"],
|
||||
"result_serializer": "json",
|
||||
"timezone": "UTC",
|
||||
"enable_utc": True,
|
||||
"task_track_started": True,
|
||||
"task_acks_late": True,
|
||||
"worker_prefetch_multiplier": 1,
|
||||
}
|
||||
|
||||
if settings.redis_url.startswith("rediss://"):
|
||||
_conf["broker_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_NONE}
|
||||
_conf["redis_backend_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_NONE}
|
||||
|
||||
app.conf.update(**_conf)
|
||||
|
||||
|
||||
# ── Beat schedule (periodic tasks) ──────────────────────────────────
|
||||
|
||||
app.conf.beat_schedule = {
|
||||
"check-scheduled-scans": {
|
||||
"task": "src.tasks.scanner.check_scheduled_scans",
|
||||
"schedule": crontab(minute="*/15"), # Every 15 minutes
|
||||
},
|
||||
"recover-stale-scans": {
|
||||
"task": "src.tasks.scanner.recover_stale_scans",
|
||||
"schedule": crontab(minute="*/5"), # Every 5 minutes
|
||||
},
|
||||
"purge-expired-consent-records": {
|
||||
"task": "src.tasks.retention.purge_expired_consent_records",
|
||||
"schedule": crontab(hour="1", minute="0"), # Daily at 01:00 UTC
|
||||
},
|
||||
}
|
||||
|
||||
# ── Explicit task imports ───────────────────────────────────────────
|
||||
# Must be at the bottom to avoid circular imports. These ensure the
|
||||
# worker process registers all @app.task definitions on startup.
|
||||
import src.tasks.retention # noqa: E402
|
||||
import src.tasks.scanner # noqa: E402, F401
|
||||
|
||||
# EE tasks are registered conditionally — they only exist in EE mode.
|
||||
try:
|
||||
import ee.api.src.tasks.compliance_scanner
|
||||
import ee.api.src.tasks.compliance_scoring
|
||||
import ee.api.src.tasks.retention # noqa: F401
|
||||
|
||||
app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-scheduled-compliance-scans": {
|
||||
"task": "src.tasks.compliance_scanner.check_scheduled_compliance_scans",
|
||||
"schedule": crontab(hour="3", minute="0"),
|
||||
},
|
||||
"compute-daily-compliance-scores": {
|
||||
"task": "src.tasks.compliance_scoring.compute_daily_scores",
|
||||
"schedule": crontab(hour="4", minute="0"),
|
||||
},
|
||||
"run-retention-purge": {
|
||||
"task": "src.tasks.retention.run_retention_purge",
|
||||
"schedule": crontab(hour="2", minute="0"),
|
||||
},
|
||||
}
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
0
apps/api/src/cli/__init__.py
Normal file
0
apps/api/src/cli/__init__.py
Normal file
40
apps/api/src/cli/bootstrap_admin.py
Normal file
40
apps/api/src/cli/bootstrap_admin.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""One-shot bootstrap of an initial organisation and owner user.
|
||||
|
||||
Usage:
|
||||
python -m src.cli.bootstrap_admin
|
||||
|
||||
Reads ``INITIAL_ADMIN_EMAIL`` and ``INITIAL_ADMIN_PASSWORD`` (plus the
|
||||
optional ``INITIAL_ADMIN_FULL_NAME``, ``INITIAL_ORG_NAME``, and
|
||||
``INITIAL_ORG_SLUG``) from the environment. If the ``users`` table is
|
||||
empty and both credentials are set, creates the org and owner user so
|
||||
the operator can log in to the admin UI. Idempotent: if any user
|
||||
already exists, exits 0 without touching the database.
|
||||
|
||||
Intended to be run as a one-shot init container *after* the database
|
||||
migrations have been applied — typically via ``depends_on`` with
|
||||
``service_healthy`` on the API container.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from src.config.logging import setup_logging
|
||||
from src.config.settings import get_settings
|
||||
from src.services.bootstrap import bootstrap_initial_admin
|
||||
|
||||
|
||||
async def _main() -> int:
    """Run the one-shot admin bootstrap and return a process exit code."""
    cfg = get_settings()
    setup_logging(cfg.log_level)
    await bootstrap_initial_admin(cfg)
    return 0
|
||||
|
||||
|
||||
def main() -> None:
    """Synchronous console entry point wrapping the async bootstrap."""
    raise SystemExit(asyncio.run(_main()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
137
apps/api/src/cli/seed_known_cookies.py
Normal file
137
apps/api/src/cli/seed_known_cookies.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Seed the known_cookies table from the Open Cookie Database CSV.
|
||||
|
||||
Usage:
|
||||
python -m src.cli.seed_known_cookies [--csv PATH] [--clear]
|
||||
|
||||
The Open Cookie Database is a community-maintained catalogue of ~2,200+
|
||||
cookie patterns. See https://github.com/jkwakman/Open-Cookie-Database
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Category mapping: Open Cookie Database category → CMP slug
|
||||
# ---------------------------------------------------------------------------
|
||||
_CATEGORY_MAP: dict[str, str] = {
|
||||
"Functional": "functional",
|
||||
"Analytics": "analytics",
|
||||
"Marketing": "marketing",
|
||||
"Personalization": "personalisation",
|
||||
"Security": "necessary",
|
||||
}
|
||||
|
||||
_DEFAULT_CSV = Path(__file__).resolve().parent.parent.parent / "data" / "open-cookie-database.csv"
|
||||
|
||||
|
||||
def _build_sync_url(async_url: str) -> str:
|
||||
"""Convert an asyncpg DSN to a psycopg2 DSN for one-off scripts."""
|
||||
return async_url.replace("postgresql+asyncpg://", "postgresql://")
|
||||
|
||||
|
||||
def seed(csv_path: Path, *, clear: bool = False) -> int:
    """Read the Open Cookie Database CSV and upsert rows into known_cookies.

    Args:
        csv_path: Path to the Open Cookie Database CSV export.
        clear: When True, delete all existing known_cookies first.

    Returns:
        The number of rows written — note this counts every upsert, so
        rows updated via ON CONFLICT are included, not only fresh inserts.
    """
    # Imported inside the function (not at module top) — presumably so
    # argparse/--help works without a configured environment; confirm.
    from src.config.settings import get_settings

    settings = get_settings()
    # One-off script: use the synchronous psycopg2 engine.
    engine = sa.create_engine(_build_sync_url(settings.database_url))

    # engine.begin() wraps the whole import in a single transaction:
    # either every row lands or none do.
    with engine.begin() as conn:
        # Build slug → category_id lookup
        rows = conn.execute(sa.text("SELECT id, slug FROM cookie_categories"))
        slug_to_id: dict[str, str] = {r[1]: str(r[0]) for r in rows}

        if clear:
            conn.execute(sa.text("DELETE FROM known_cookies"))

        inserted = 0
        with csv_path.open(newline="", encoding="utf-8") as fh:
            reader = csv.DictReader(fh)
            for row in reader:
                # Rows with unknown/unmapped categories are skipped
                # entirely rather than imported uncategorised.
                category = row.get("Category", "").strip()
                slug = _CATEGORY_MAP.get(category)
                if not slug or slug not in slug_to_id:
                    continue

                name = row.get("Cookie / Data Key name", "").strip()
                if not name:
                    continue

                # "*" means the pattern applies to any domain.
                domain_raw = row.get("Domain", "").strip()
                domain = domain_raw if domain_raw else "*"

                wildcard = row.get("Wildcard match", "0").strip() == "1"
                description = row.get("Description", "").strip() or None
                vendor = row.get("Platform", "").strip() or None

                # Build pattern: if wildcard, append * to name for glob matching
                name_pattern = f"{name}*" if wildcard else name
                is_regex = False

                # Upsert keyed on (name_pattern, domain_pattern) — this
                # relies on a unique constraint over that pair existing
                # on known_cookies (assumed from the ON CONFLICT target;
                # verify against the migration).
                conn.execute(
                    sa.text(
                        """
                        INSERT INTO known_cookies
                            (id, name_pattern, domain_pattern, category_id,
                             vendor, description, is_regex, created_at, updated_at)
                        VALUES
                            (:id, :name_pattern, :domain_pattern, :category_id,
                             :vendor, :description, :is_regex, NOW(), NOW())
                        ON CONFLICT (name_pattern, domain_pattern) DO UPDATE SET
                            category_id = EXCLUDED.category_id,
                            vendor = EXCLUDED.vendor,
                            description = EXCLUDED.description,
                            is_regex = EXCLUDED.is_regex,
                            updated_at = NOW()
                        """
                    ),
                    {
                        "id": str(uuid.uuid4()),
                        "name_pattern": name_pattern,
                        "domain_pattern": domain,
                        "category_id": slug_to_id[slug],
                        "vendor": vendor,
                        "description": description,
                        "is_regex": is_regex,
                    },
                )
                inserted += 1

    return inserted
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, validate the CSV path, run seed."""
    parser = argparse.ArgumentParser(
        description="Seed known cookies from Open Cookie Database"
    )
    parser.add_argument(
        "--csv",
        type=Path,
        default=_DEFAULT_CSV,
        help="Path to the Open Cookie Database CSV (default: bundled copy)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Delete all existing known_cookies before importing",
    )
    args = parser.parse_args()

    csv_path: Path = args.csv
    if not csv_path.exists():
        print(f"Error: CSV not found at {csv_path}", file=sys.stderr)
        sys.exit(1)

    total = seed(csv_path, clear=args.clear)
    print(f"Seeded {total} known cookie patterns from {csv_path.name}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
3
apps/api/src/config/__init__.py
Normal file
3
apps/api/src/config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.config.settings import Settings, get_settings
|
||||
|
||||
__all__ = ["Settings", "get_settings"]
|
||||
26
apps/api/src/config/edition.py
Normal file
26
apps/api/src/config/edition.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Edition detection for the open-core architecture.
|
||||
|
||||
Determines whether the application is running in community edition (CE)
|
||||
or enterprise edition (EE) based on the availability of the ``ee``
|
||||
package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def is_ee() -> bool:
    """Report whether the enterprise ``ee`` package can be imported.

    The result is cached for the life of the process — edition cannot
    change at runtime.
    """
    try:
        import ee  # noqa: F401
    except ImportError:
        return False
    return True
|
||||
|
||||
|
||||
def edition_name() -> str:
    """Return the short edition label: ``"ee"`` or ``"ce"``."""
    if is_ee():
        return "ee"
    return "ce"
|
||||
26
apps/api/src/config/logging.py
Normal file
26
apps/api/src/config/logging.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def setup_logging(log_level: str = "INFO") -> None:
    """Configure structured logging with structlog.

    Renders human-readable console output when stderr is a TTY and JSON
    lines otherwise (e.g. in containers), filtered at ``log_level``.
    """
    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.StackInfoRenderer(),
            structlog.dev.set_exc_info,
            structlog.processors.TimeStamper(fmt="iso"),
            # Pretty renderer for interactive terminals; JSON for log
            # aggregation pipelines.
            structlog.dev.ConsoleRenderer()
            if sys.stderr.isatty()
            else structlog.processors.JSONRenderer(),
        ],
        # Unknown level names silently fall back to INFO.
        wrapper_class=structlog.make_filtering_bound_logger(
            getattr(logging, log_level.upper(), logging.INFO)
        ),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
        cache_logger_on_first_use=True,
    )
|
||||
166
apps/api/src/config/settings.py
Normal file
166
apps/api/src/config/settings.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
# Placeholder value — the application refuses to start in non-dev
|
||||
# environments if ``jwt_secret_key`` is left at this literal.
|
||||
_JWT_PLACEHOLDER = "CHANGE-ME-in-production"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Field names map to upper-case environment variables (matching is
    case-insensitive); a local ``.env`` file is read as a fallback.
    Unsafe defaults are rejected at construction time for non-dev
    environments by ``_check_production_safety``.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # Application
    app_name: str = "ConsentOS API"
    app_version: str = "0.1.0"
    debug: bool = False
    environment: str = "development"
    log_level: str = "INFO"

    # Server
    host: str = "0.0.0.0"
    port: int = 8000
    allowed_origins: str = "http://localhost:5173"

    @property
    def allowed_origins_list(self) -> list[str]:
        """Parse allowed_origins as a comma-separated string.

        Whitespace around entries is stripped; empty entries dropped.
        """
        return [o.strip() for o in self.allowed_origins.split(",") if o.strip()]

    # Database
    database_url: str = "postgresql+asyncpg://consentos:consentos@localhost:5432/consentos"
    database_echo: bool = False
    database_pool_size: int = 20
    database_max_overflow: int = 10

    # Redis
    redis_url: str = "redis://localhost:6379/0"

    # JWT
    jwt_secret_key: str = _JWT_PLACEHOLDER
    jwt_algorithm: str = "HS256"
    jwt_access_token_expire_minutes: int = 30
    jwt_refresh_token_expire_days: int = 7

    # Pseudonymisation — HMAC key for IP / UA hashing on consent records.
    # Defaults to deriving from the JWT secret if not explicitly set.
    pseudonymisation_secret: str | None = None

    # Bootstrap token — required as ``X-Admin-Bootstrap-Token`` on
    # ``POST /api/v1/organisations/``. When unset (the default), the
    # endpoint is disabled. Rotate or unset after your first org is
    # provisioned to prevent further tenant creation.
    admin_bootstrap_token: str | None = None

    # Initial admin bootstrap — on first startup, if the ``users`` table
    # is empty and both credentials below are set, the API creates an
    # organisation and an owner user so the operator can log in to the
    # admin UI for the first time. Idempotent: once any user exists this
    # is a no-op, so the variables can safely remain set across restarts.
    # Rotate the password via the admin UI after first login.
    initial_admin_email: str | None = None
    initial_admin_password: str | None = None
    initial_admin_full_name: str = "Administrator"
    initial_org_name: str = "Default Organisation"
    initial_org_slug: str = "default"

    # CDN — public URL where banner scripts (consent-loader.js,
    # consent-bundle.js) are hosted. In dev the admin UI dog-foods
    # the banner so localhost:5173 works for testing; in production
    # this should be a real CDN URL (CloudFlare Pages, S3+CloudFront,
    # Cloud CDN, etc.) — see docs for setup.
    cdn_base_url: str = "http://localhost:5173"

    # Scanner service
    scanner_service_url: str = "http://localhost:8001"
    scanner_timeout_seconds: int = 300

    # Extra GeoIP country header — checked *before* the built-in list
    # (``cf-ipcountry``, ``x-vercel-ip-country``, ``x-appengine-country``,
    # ``x-country-code``). Set this when running behind a CDN/load
    # balancer that uses a non-standard header, e.g. Google Cloud
    # Load Balancer's ``x-gclb-country`` or an internal edge proxy.
    # Header names are case-insensitive. Leave unset if one of the
    # built-in headers is fine.
    geoip_country_header: str | None = None

    # Subdivision/state code header — optional companion to
    # ``GEOIP_COUNTRY_HEADER``. When both are set the API pairs them to
    # produce region keys like ``US-CA`` or ``GB-SCT`` (ISO 3166-2
    # subdivision without the country prefix). Different CDNs expose
    # this under different names: Cloudflare Enterprise uses
    # ``cf-region-code``, Vercel uses ``x-vercel-ip-country-region``,
    # GCP Load Balancer uses ``x-gclb-region``, CloudFront functions
    # use ``cloudfront-viewer-country-region``. Leave unset if you
    # only need country-level granularity.
    geoip_region_header: str | None = None

    # Local MaxMind GeoLite2/GeoIP2 City database — used as a fallback
    # when no CDN header is present. Download GeoLite2-City.mmdb from
    # https://dev.maxmind.com/geoip/geolite2-free-geolocation-data and
    # mount it into the container (e.g. ``/data/GeoLite2-City.mmdb``).
    # When unset, lookups fall back to the free external ip-api.com
    # service, which is rate-limited and should not be relied on in
    # production.
    geoip_maxmind_db_path: str | None = None

    # Rate limiting — on by default. Public endpoints (banner config +
    # consent submission) are internet-exposed and must not be DoS-able.
    # Auth endpoints get a stricter bucket via ``RateLimitMiddleware``.
    rate_limit_enabled: bool = True
    rate_limit_per_minute: int = 120

    @model_validator(mode="after")
    def _check_production_safety(self) -> "Settings":
        """Refuse to start with unsafe defaults in non-dev environments.

        Raises:
            ValueError: if the JWT secret is still the placeholder or
                CORS is configured with a wildcard origin.
        """
        # Dev-like environments are exempt so local setups work with
        # zero configuration.
        if self.environment.lower() in ("development", "dev", "local", "test"):
            return self

        errors: list[str] = []

        if self.jwt_secret_key == _JWT_PLACEHOLDER:
            errors.append(
                "JWT_SECRET_KEY is set to the placeholder value "
                f"{_JWT_PLACEHOLDER!r}. Generate a strong random value "
                "(e.g. `openssl rand -base64 48`) and set it in the "
                "environment before starting the API."
            )

        if "*" in self.allowed_origins_list:
            errors.append(
                "ALLOWED_ORIGINS contains '*'. Wildcard CORS combined with "
                "allow_credentials=True is a credential-theft vector. "
                "Set ALLOWED_ORIGINS to an explicit list of trusted origins."
            )

        if errors:
            # Collect all problems into one message so the operator
            # fixes everything in a single iteration.
            msg = "Refusing to start with unsafe configuration:\n - " + "\n - ".join(
                errors,
            )
            raise ValueError(msg)

        return self

    @property
    def pseudonymisation_key(self) -> bytes:
        """Return the HMAC key used for pseudonymising IP/UA values.

        If ``pseudonymisation_secret`` is not set, derives a per-instance
        key from the JWT secret so operators don't have to configure two
        secrets. Using JWT_SECRET directly is acceptable because the
        HMAC is one-way and the resulting hashes are not reversible.
        """
        source = self.pseudonymisation_secret or self.jwt_secret_key
        return source.encode("utf-8")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide ``Settings`` instance.

    Cached so environment parsing and validation run exactly once per
    process; tests can call ``get_settings.cache_clear()`` to force a
    fresh load.
    """
    return Settings()
|
||||
3
apps/api/src/db/__init__.py
Normal file
3
apps/api/src/db/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.db.session import get_db
|
||||
|
||||
__all__ = ["get_db"]
|
||||
31
apps/api/src/db/session.py
Normal file
31
apps/api/src/db/session.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
# Engine and session factory are created once at import time; every part
# of the app shares this connection pool.
settings = get_settings()

engine = create_async_engine(
    settings.database_url,
    echo=settings.database_echo,  # log SQL when enabled in settings
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
)

# expire_on_commit=False keeps ORM objects usable after commit — needed
# because responses are serialised after get_db() has committed.
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
|
||||
|
||||
|
||||
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """FastAPI dependency yielding one async database session per request.

    Commits when the request handler finishes cleanly; rolls back and
    re-raises when the handler (or the commit itself) fails. Closing the
    session is handled by the ``async with`` block either way.
    """
    async with async_session_factory() as db:
        try:
            yield db
            # Commit stays inside the try so a failing commit still
            # triggers the rollback below.
            await db.commit()
        except Exception:
            await db.rollback()
            raise
|
||||
3
apps/api/src/extensions/__init__.py
Normal file
3
apps/api/src/extensions/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.extensions.registry import discover_extensions, get_registry
|
||||
|
||||
__all__ = ["discover_extensions", "get_registry"]
|
||||
197
apps/api/src/extensions/registry.py
Normal file
197
apps/api/src/extensions/registry.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Extension registry for the open-core architecture.
|
||||
|
||||
Provides registration hooks that allow enterprise/commercial code to inject
|
||||
routers, model modules, startup tasks, and OpenAPI tags into the core
|
||||
application — without the core needing any direct knowledge of the
|
||||
extensions.
|
||||
|
||||
In community edition (CE) mode, ``discover_extensions()`` is a no-op
|
||||
because the ``ee`` package is not present.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class OpenAPITag:
    """Metadata for a FastAPI OpenAPI tag.

    Mirrors the ``{"name": ..., "description": ...}`` dicts FastAPI
    accepts in ``openapi_tags``; ``ExtensionRegistry.apply`` converts
    instances to that dict shape when mounting extension routers.
    """

    # Tag identifier shown in the generated API docs.
    name: str
    # Human-readable blurb rendered under the tag heading.
    description: str
|
||||
|
||||
|
||||
@dataclass
class RouterEntry:
    """A router registered by an extension.

    Bundles the router with its mount prefix and any OpenAPI tag
    metadata; consumed by ``ExtensionRegistry.apply`` at startup.
    """

    # The APIRouter contributed by the extension.
    router: APIRouter
    # URL prefix the router is mounted under.
    prefix: str = "/api/v1"
    # OpenAPI tags to merge into the app's tag list when mounting.
    tags: list[OpenAPITag] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class ExtensionRegistry:
    """Collects components contributed by optional extension packages.

    Extension code does not touch this class directly; it calls the
    module-level helpers (``register_router``, ``register_model_module``,
    etc.), which forward to the singleton stored in ``_registry``.
    """

    routers: list[RouterEntry] = field(default_factory=list)
    model_modules: list[str] = field(default_factory=list)
    startup_hooks: list[Callable[[FastAPI], Coroutine[Any, Any, None]]] = field(
        default_factory=list,
    )
    config_enrichers: list[Callable] = field(default_factory=list)
    consent_record_hooks: list[Callable] = field(default_factory=list)

    # ------------------------------------------------------------------
    # Registration helpers
    # ------------------------------------------------------------------

    def add_router(
        self,
        router: APIRouter,
        *,
        prefix: str = "/api/v1",
        tags: list[OpenAPITag] | None = None,
    ) -> None:
        """Queue *router* for mounting under *prefix* at startup."""
        self.routers.append(RouterEntry(router=router, prefix=prefix, tags=tags or []))

    def add_model_module(self, module_path: str) -> None:
        """Queue a dotted module path for import during :meth:`apply`."""
        self.model_modules.append(module_path)

    def add_startup_hook(
        self,
        hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
    ) -> None:
        """Queue an async callable to run at application startup."""
        self.startup_hooks.append(hook)

    def add_config_enricher(self, enricher: Callable) -> None:
        """Queue a published-config enricher callable."""
        self.config_enrichers.append(enricher)

    def add_consent_record_hook(self, hook: Callable) -> None:
        """Queue a callable invoked after a consent record is persisted."""
        self.consent_record_hooks.append(hook)

    # ------------------------------------------------------------------
    # Application wiring
    # ------------------------------------------------------------------

    def apply(self, app: FastAPI) -> None:
        """Mount all registered routers and tags onto *app*."""
        for mounted in self.routers:
            # Merge this entry's OpenAPI tags, skipping names already present.
            for tag in mounted.tags:
                known = app.openapi_tags or []
                if all(t["name"] != tag.name for t in known):
                    known.append(
                        {"name": tag.name, "description": tag.description},
                    )
                app.openapi_tags = known

            app.include_router(mounted.router, prefix=mounted.prefix)

        if self.routers:
            logger.info(
                "Registered %d extension router(s)",
                len(self.routers),
            )

        # Importing each module is what makes SQLAlchemy pick up its tables.
        for dotted in self.model_modules:
            importlib.import_module(dotted)

        if self.model_modules:
            logger.info(
                "Registered %d extension model module(s)",
                len(self.model_modules),
            )
|
||||
|
||||
|
||||
# Singleton ------------------------------------------------------------------
|
||||
|
||||
_registry = ExtensionRegistry()
|
||||
|
||||
|
||||
def get_registry() -> ExtensionRegistry:
    """Return the global extension registry (module-level singleton)."""
    return _registry
|
||||
|
||||
|
||||
# Convenience module-level API -----------------------------------------------
|
||||
|
||||
|
||||
def register_router(
    router: APIRouter,
    *,
    prefix: str = "/api/v1",
    tags: list[OpenAPITag] | None = None,
) -> None:
    """Register an API router to be mounted at startup.

    Args:
        router: The extension's ``APIRouter``.
        prefix: URL prefix to mount under; defaults to the core API prefix.
        tags: Optional OpenAPI tag metadata merged into the app's tag list.
    """
    _registry.add_router(router, prefix=prefix, tags=tags)
|
||||
|
||||
|
||||
def register_model_module(module_path: str) -> None:
    """Register a dotted module path whose SQLAlchemy models should be imported.

    The module is imported during ``ExtensionRegistry.apply`` so its
    tables are known before the application starts serving.
    """
    _registry.add_model_module(module_path)
|
||||
|
||||
|
||||
def register_startup_hook(
    hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
) -> None:
    """Register an async callable to run during application startup.

    The hook receives the ``FastAPI`` application instance.
    """
    _registry.add_startup_hook(hook)
|
||||
|
||||
|
||||
def register_config_enricher(enricher: Callable) -> None:
    """Register a callable that enriches published config.

    The callable signature is ``async (site_id: UUID, db: AsyncSession, config: dict) -> None``.
    It should mutate *config* in-place to add extension-specific data
    (e.g. A/B test variants); return values are ignored.
    """
    _registry.add_config_enricher(enricher)
|
||||
|
||||
|
||||
def register_consent_record_hook(hook: Callable) -> None:
    """Register a callable invoked after a consent record is persisted.

    The callable signature is ``async (db: AsyncSession, consent_record) -> None``.
    It is called from ``POST /api/v1/consent`` after the record has been
    flushed to the database. Typical use: generating a consent receipt
    (EE), writing audit logs, firing webhooks.
    """
    _registry.add_consent_record_hook(hook)
|
||||
|
||||
|
||||
# Discovery ------------------------------------------------------------------
|
||||
|
||||
|
||||
def discover_extensions() -> None:
    """Load enterprise extensions when the EE package is installed.

    The enterprise edition ships as a separate ``consent-enterprise``
    distribution. Importing ``ee.api.src.register`` triggers its
    side-effect registrations against the singleton registry. In the
    community edition that package is absent, the ImportError is
    expected, and we carry on in CE mode.
    """
    try:
        importlib.import_module("ee.api.src.register")
    except ImportError:
        logger.debug("No enterprise extensions found (CE mode)")
    else:
        logger.info("Enterprise extensions loaded")
|
||||
210
apps/api/src/main.py
Normal file
210
apps/api/src/main.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from src.config.edition import edition_name
|
||||
from src.config.logging import setup_logging
|
||||
from src.config.settings import get_settings
|
||||
from src.extensions.registry import discover_extensions, get_registry
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from src.routers import (
|
||||
auth,
|
||||
compliance,
|
||||
config,
|
||||
consent,
|
||||
cookies,
|
||||
org_config,
|
||||
organisations,
|
||||
scanner,
|
||||
site_group_config,
|
||||
site_groups,
|
||||
sites,
|
||||
translations,
|
||||
users,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application startup and shutdown lifecycle.

    Configures logging before the app begins serving requests; there is
    no teardown work after the ``yield``.
    """
    setup_logging(get_settings().log_level)
    yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Application factory.

    Builds the FastAPI app: OpenAPI metadata, middleware stack, core
    routers, extension routers, and the health endpoints. Returned app is
    fully wired; module-level ``app = create_app()`` exposes it to ASGI
    servers.
    """
    settings = get_settings()

    app = FastAPI(
        title=settings.app_name,
        version=settings.app_version,
        description=(
            "Multi-tenant cookie consent management platform API. "
            "Provides consent collection, cookie scanning, auto-blocking, "
            "compliance checking, and analytics across multiple sites."
        ),
        debug=settings.debug,
        lifespan=lifespan,
        openapi_tags=[
            {
                "name": "auth",
                "description": "Authentication — login, token refresh, and current user.",
            },
            {
                "name": "config",
                "description": (
                    "Site configuration — public endpoints for the banner script "
                    "to fetch config, GeoIP-resolved config, and CDN publishing."
                ),
            },
            {
                "name": "consent",
                "description": (
                    "Consent recording and retrieval — public endpoints called "
                    "by the banner script to record visitor consent decisions."
                ),
            },
            {
                "name": "sites",
                "description": "Site and site config CRUD — manage domains and settings.",
            },
            {
                "name": "cookies",
                "description": (
                    "Cookie management — categories, discovered cookies, allow-list, "
                    "known cookies database, and auto-classification."
                ),
            },
            {
                "name": "scanner",
                "description": (
                    "Cookie scanner — trigger scans, view results, and receive "
                    "client-side cookie reports from the banner script."
                ),
            },
            {
                "name": "compliance",
                "description": (
                    "Compliance checking — run checks against GDPR, CNIL, CCPA, "
                    "ePrivacy, and LGPD frameworks."
                ),
            },
            {
                "name": "organisations",
                "description": "Organisation management — multi-tenant root entities.",
            },
            {
                "name": "users",
                "description": "User management — org-scoped users with role-based access.",
            },
        ],
    )

    # Security headers
    app.add_middleware(SecurityHeadersMiddleware)

    # Rate limiting (must be added before CORS to count requests correctly)
    # NOTE(review): Starlette applies middleware in reverse add order, so
    # CORS (added last) is outermost — preflight OPTIONS responses
    # short-circuit in CORS and bypass the rate limiter. Confirm that is
    # the intended meaning of "count requests correctly".
    if settings.rate_limit_enabled:
        app.add_middleware(
            RateLimitMiddleware,
            redis_url=settings.redis_url,
            requests_per_minute=settings.rate_limit_per_minute,
            auth_requests_per_minute=10,
        )

    # CORS — origin list is validated against '*' in Settings for non-dev
    # environments, because allow_credentials=True is set here.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.allowed_origins_list,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Core routers
    api_prefix = "/api/v1"
    app.include_router(auth.router, prefix=api_prefix)
    app.include_router(config.router, prefix=api_prefix)
    app.include_router(consent.router, prefix=api_prefix)
    app.include_router(scanner.router, prefix=api_prefix)
    app.include_router(compliance.router, prefix=api_prefix)
    app.include_router(organisations.router, prefix=api_prefix)
    app.include_router(org_config.router, prefix=api_prefix)
    app.include_router(users.router, prefix=api_prefix)
    app.include_router(site_groups.router, prefix=api_prefix)
    app.include_router(site_group_config.router, prefix=api_prefix)
    app.include_router(sites.router, prefix=api_prefix)
    app.include_router(cookies.router, prefix=api_prefix)
    app.include_router(translations.router, prefix=api_prefix)
    app.include_router(translations.public_router, prefix=api_prefix)

    # Discover and mount enterprise extensions (no-op in CE mode)
    discover_extensions()
    registry = get_registry()
    registry.apply(app)

    @app.get("/health", tags=["health"])
    async def health() -> dict[str, str]:
        """Shallow liveness check.

        Answers "is the process running?". Suitable for orchestrator
        liveness probes. For deployment readiness, use
        ``/health/ready`` which verifies downstream dependencies.
        """
        return {"status": "ok", "edition": edition_name()}

    @app.get("/health/ready", tags=["health"])
    async def health_ready() -> dict[str, object]:
        """Deep readiness check — verifies database and Redis.

        Returns HTTP 503 if either dependency is unreachable so load
        balancers route traffic away from broken instances.
        """
        # Local imports — presumably to avoid import cycles and keep the
        # module import light; confirm before hoisting to module level.
        from fastapi import HTTPException
        from sqlalchemy import text

        from src.db.session import engine as db_engine

        checks: dict[str, str] = {}
        overall_ok = True

        # Database — a plain SELECT 1 round-trip through the pool.
        try:
            async with db_engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            checks["database"] = "ok"
        except Exception as exc:
            # Only the exception type is exposed, not its message,
            # to avoid leaking connection details.
            checks["database"] = f"error: {type(exc).__name__}"
            overall_ok = False

        # Redis — fresh connection per probe, closed before returning.
        try:
            import redis.asyncio as aioredis

            r = aioredis.from_url(settings.redis_url, decode_responses=True)
            pong = await r.ping()
            checks["redis"] = "ok" if pong else "error: ping failed"
            if not pong:
                overall_ok = False
            await r.aclose()
        except Exception as exc:
            checks["redis"] = f"error: {type(exc).__name__}"
            overall_ok = False

        payload = {
            "status": "ok" if overall_ok else "degraded",
            "edition": edition_name(),
            "checks": checks,
        }
        if not overall_ok:
            # 503 with the same payload in `detail` so probes still see
            # which dependency failed.
            raise HTTPException(status_code=503, detail=payload)
        return payload

    return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
0
apps/api/src/middleware/__init__.py
Normal file
0
apps/api/src/middleware/__init__.py
Normal file
111
apps/api/src/middleware/rate_limit.py
Normal file
111
apps/api/src/middleware/rate_limit.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Redis-backed rate limiting middleware.
|
||||
|
||||
Applies per-IP rate limits to all incoming requests. Public endpoints
|
||||
(consent recording, config fetching) are the primary protection target.
|
||||
|
||||
Uses a sliding window counter stored in Redis with automatic expiry.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple per-IP rate limiter backed by Redis.

    Counts requests per client IP in fixed one-minute windows
    (``time.time() // 60``). NOTE(review): the module docstring calls
    this a "sliding window", but the implementation below is a fixed
    window counter — up to 2x the limit can pass across a window
    boundary; confirm which is intended. Fails open: any Redis problem
    lets the request through.
    """

    def __init__(
        self,
        app: object,
        redis_url: str = "redis://localhost:6379/0",
        requests_per_minute: int = 120,
        auth_requests_per_minute: int = 10,
    ) -> None:
        """Store limits; the Redis client is created lazily on first use."""
        super().__init__(app)  # type: ignore[arg-type]
        self.redis_url = redis_url
        # General per-IP budget for all non-auth endpoints.
        self.requests_per_minute = requests_per_minute
        # Stricter budget for /api/v1/auth/* (credential-stuffing defence).
        self.auth_requests_per_minute = auth_requests_per_minute
        self._redis: object | None = None

    async def _get_redis(self) -> object | None:
        """Lazy-initialise Redis connection.

        Returns ``None`` (and logs a warning) when the redis package
        cannot be imported or the client cannot be constructed; callers
        treat ``None`` as "rate limiting disabled".
        """
        if self._redis is not None:
            return self._redis
        try:
            import redis.asyncio as aioredis

            # from_url does not connect eagerly; connection errors
            # surface later in dispatch() and are caught there.
            self._redis = aioredis.from_url(self.redis_url, decode_responses=True)
            return self._redis
        except Exception:
            logger.warning("Rate limiting disabled: Redis unavailable")
            return None

    def _get_client_ip(self, request: Request) -> str:
        """Extract the real client IP.

        NOTE(review): trusts X-Forwarded-For / X-Real-IP unconditionally.
        These headers are client-supplied and spoofable unless a trusted
        reverse proxy always overwrites them — confirm the deployment
        guarantees that, otherwise limits can be evaded per request.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # First entry is the originating client in proxy chains.
            return forwarded.split(",")[0].strip()
        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip()
        if request.client:
            return request.client.host
        return "unknown"

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Count the request against its bucket; 429 when over limit."""
        # Skip rate limiting for health checks
        if request.url.path in ("/health", "/health/ready", "/health/live"):
            return await call_next(request)

        r = await self._get_redis()
        if r is None:
            # Redis unavailable — allow request through
            return await call_next(request)

        # Auth endpoints get a stricter bucket to slow down credential
        # stuffing — login, password reset, token refresh.
        path = request.url.path
        is_auth = path.startswith("/api/v1/auth/") and path not in ("/api/v1/auth/me",)
        limit = self.auth_requests_per_minute if is_auth else self.requests_per_minute
        bucket = "auth" if is_auth else "req"

        client_ip = self._get_client_ip(request)
        # Fixed 60-second window id; a new Redis key per window.
        window = int(time.time() // 60)
        key = f"cmp:rate:{bucket}:{client_ip}:{window}"

        try:
            current = await r.incr(key)  # type: ignore[union-attr]
            if current == 1:
                # Expiry of 2x the window length so stale keys always
                # disappear even with clock skew at the boundary.
                await r.expire(key, 120)  # type: ignore[union-attr]

            if current > limit:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many requests. Please try again later."},
                    headers={
                        "Retry-After": "60",
                        "X-RateLimit-Limit": str(limit),
                        "X-RateLimit-Remaining": "0",
                    },
                )

            response = await call_next(request)
            remaining = max(0, limit - current)
            response.headers["X-RateLimit-Limit"] = str(limit)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            return response

        except Exception:
            # Fail open: a broken limiter must not take the API down.
            logger.debug("Rate limit check failed", exc_info=True)
            return await call_next(request)
|
||||
41
apps/api/src/middleware/security_headers.py
Normal file
41
apps/api/src/middleware/security_headers.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Security headers middleware.
|
||||
|
||||
Adds standard security headers to all API responses:
|
||||
- X-Content-Type-Options: nosniff
|
||||
- X-Frame-Options: DENY
|
||||
- X-XSS-Protection: 0 (disabled in favour of CSP)
|
||||
- Referrer-Policy: strict-origin-when-cross-origin
|
||||
- Content-Security-Policy: default-src 'none'
|
||||
- Strict-Transport-Security (HSTS) in production
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add security headers to all responses.

    The fixed headers are applied to every response; HSTS is added only
    when the request itself arrived over HTTPS.
    """

    # Header name/value pairs set unconditionally, in this order.
    _STATIC_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "0"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Content-Security-Policy", "default-src 'none'"),
    )

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        response = await call_next(request)

        for header_name, header_value in self._STATIC_HEADERS:
            response.headers[header_name] = header_value

        # HSTS — only on HTTPS requests (reverse proxy may terminate TLS)
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = (
                "max-age=63072000; includeSubDomains; preload"
            )

        return response
|
||||
31
apps/api/src/models/__init__.py
Normal file
31
apps/api/src/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from src.models.base import Base
|
||||
from src.models.consent import ConsentRecord
|
||||
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
|
||||
from src.models.org_config import OrgConfig
|
||||
from src.models.organisation import Organisation
|
||||
from src.models.scan import ScanJob, ScanResult
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.models.site_group import SiteGroup
|
||||
from src.models.site_group_config import SiteGroupConfig
|
||||
from src.models.translation import Translation
|
||||
from src.models.user import User
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"ConsentRecord",
|
||||
"Cookie",
|
||||
"CookieAllowListEntry",
|
||||
"CookieCategory",
|
||||
"KnownCookie",
|
||||
"OrgConfig",
|
||||
"Organisation",
|
||||
"ScanJob",
|
||||
"ScanResult",
|
||||
"Site",
|
||||
"SiteConfig",
|
||||
"SiteGroup",
|
||||
"SiteGroupConfig",
|
||||
"Translation",
|
||||
"User",
|
||||
]
|
||||
48
apps/api/src/models/base.py
Normal file
48
apps/api/src/models/base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Base class for all SQLAlchemy models.

    Every ORM model inherits from this so all tables share one
    ``MetaData`` registry.
    """

    pass
|
||||
|
||||
|
||||
class TimestampMixin:
    """Mixin that adds created_at and updated_at columns.

    Both columns are timezone-aware and default to the database's
    ``now()``, so values are consistent across app instances with
    skewed clocks.
    """

    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # ``onupdate`` is applied by SQLAlchemy on ORM-issued UPDATEs;
    # raw SQL updates bypass it.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
|
||||
|
||||
|
||||
class UUIDPrimaryKeyMixin:
    """Mixin that adds a UUID primary key.

    The key is generated in Python (``uuid.uuid4``) rather than by the
    database.
    """

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
|
||||
|
||||
|
||||
class SoftDeleteMixin:
    """Mixin that adds soft delete support.

    A row counts as deleted when ``deleted_at`` is non-NULL. No automatic
    query filtering is applied by this mixin — queries must exclude
    soft-deleted rows themselves.
    """

    deleted_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=None,
    )
|
||||
81
apps/api/src/models/consent.py
Normal file
81
apps/api/src/models/consent.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from src.models.base import Base, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class ConsentRecord(UUIDPrimaryKeyMixin, Base):
    """Audit trail of every consent event.

    Append-only (no ``TimestampMixin``/``SoftDeleteMixin`` — a record is
    never updated or soft-deleted). Visitor IP and user agent are stored
    only as hashes. NOTE(review): the original docstring said
    "partitioned by month for performance", but no partitioning is
    declared on this model — presumably handled in migrations; confirm.
    """

    __tablename__ = "consent_records"
    __table_args__ = (
        # Composite index for the most common analytics query pattern:
        # "records for site X between dates A and B". The (site_id,
        # consented_at DESC) ordering also supports "latest consents
        # for site X" without an extra sort.
        Index(
            "ix_consent_records_site_consented_at",
            "site_id",
            "consented_at",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Visitor identification (anonymous) — ip_hash / user_agent_hash are
    # 64-char hex digests; presumably HMACs keyed by
    # Settings.pseudonymisation_key — confirm against the writer.
    visitor_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    ip_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
    user_agent_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)

    # Consent details
    action: Mapped[str] = mapped_column(String(30), nullable=False)
    categories_accepted: Mapped[list] = mapped_column(JSONB, nullable=False)
    categories_rejected: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # TCF — IAB Transparency & Consent Framework TC string.
    tc_string: Mapped[str | None] = mapped_column(Text, nullable=True)

    # GCM state at time of consent (Google Consent Mode)
    gcm_state: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # GPP — Global Privacy Platform string.
    gpp_string: Mapped[str | None] = mapped_column(Text, nullable=True)

    # GPC — Global Privacy Control signal: whether it was present, and
    # whether it was honoured.
    gpc_detected: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_honoured: Mapped[bool | None] = mapped_column(nullable=True)

    # A/B testing — soft references to EE `ab_tests` / `ab_test_variants`
    # tables. Intentionally *no* FK constraint so the core schema works
    # without the EE extension installed.
    ab_test_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
        index=True,
    )
    ab_variant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )

    # Context — where and from which jurisdiction the consent was given.
    page_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    country_code: Mapped[str | None] = mapped_column(String(5), nullable=True)
    region_code: Mapped[str | None] = mapped_column(String(10), nullable=True)

    # Timestamp — set database-side at insert.
    consented_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
    )
|
||||
130
apps/api/src/models/cookie.py
Normal file
130
apps/api/src/models/cookie.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class CookieCategory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Cookie category taxonomy (necessary, functional, analytics, marketing, personalisation)."""

    __tablename__ = "cookie_categories"

    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Essential categories cannot be rejected by visitors — presumably
    # enforced by the consent endpoints; confirm.
    is_essential: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Ordering of categories in the banner UI.
    display_order: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)

    # TCF purpose mapping (list of IAB TCF purpose ids)
    tcf_purpose_ids: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # Google Consent Mode consent type mapping
    gcm_consent_types: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # Relationships
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="category")
    allow_list_entries: Mapped[list["CookieAllowListEntry"]] = relationship(
        back_populates="category"
    )
|
||||
|
||||
|
||||
class Cookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie discovered on a site via scanning or client-side reporting.

    Uniqueness is per (site, name, domain, storage_type), so the same
    cookie name can exist for different domains or storage types.
    """

    __tablename__ = "cookies"
    __table_args__ = (
        UniqueConstraint(
            "site_id",
            "name",
            "domain",
            "storage_type",
            name="uq_cookies_site_name_domain_type",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Nullable: a freshly-discovered cookie starts uncategorised, and
    # SET NULL keeps the row if its category is deleted.
    category_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )

    name: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain: Mapped[str] = mapped_column(String(255), nullable=False)
    # e.g. "cookie"; other storage kinds share this table.
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    path: Mapped[str | None] = mapped_column(String(500), nullable=True)
    max_age_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
    is_http_only: Mapped[bool | None] = mapped_column(nullable=True)
    is_secure: Mapped[bool | None] = mapped_column(nullable=True)
    same_site: Mapped[str | None] = mapped_column(String(10), nullable=True)
    review_status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
    # NOTE(review): stored as String(50) rather than DateTime — confirm
    # this is intentional (e.g. raw scanner-reported timestamps).
    first_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)
    last_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookies")  # noqa: F821
    category: Mapped["CookieCategory | None"] = relationship(back_populates="cookies")
|
||||
|
||||
|
||||
class CookieAllowListEntry(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Approved cookies per site with category assignment.

    Entries match by name/domain pattern rather than exact cookie rows,
    so one entry can cover a family of cookies.
    """

    __tablename__ = "cookie_allow_list"
    __table_args__ = (
        UniqueConstraint(
            "site_id",
            "name_pattern",
            "domain_pattern",
            name="uq_allow_list_site_name_domain",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # RESTRICT: a category with allow-list entries cannot be deleted.
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookie_allow_list")  # noqa: F821
    category: Mapped["CookieCategory"] = relationship(back_populates="allow_list_entries")
|
||||
|
||||
|
||||
class KnownCookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Shared knowledge base of known cookie patterns for auto-categorisation."""

    __tablename__ = "known_cookies"
    __table_args__ = (
        # A (name pattern, domain pattern) pair appears at most once in the KB.
        UniqueConstraint("name_pattern", "domain_pattern", name="uq_known_cookies_name_domain"),
    )

    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    # Category assigned when a scanned cookie matches; RESTRICT keeps referenced
    # categories from being deleted while still in use here.
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # When True the patterns are presumably interpreted as regular expressions
    # by the matching logic (which lives elsewhere) — TODO confirm.
    is_regex: Mapped[bool] = mapped_column(default=False, nullable=False)
|
||||
64
apps/api/src/models/org_config.py
Normal file
64
apps/api/src/models/org_config.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class OrgConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Organisation-level default configuration.

    These defaults sit between system defaults and site config in the cascade:
    System Defaults → Org Config → Site Group Config → Site Config → Regional Overrides

    Every column except ``organisation_id`` is nullable: NULL presumably means
    "not set at this level", so resolution falls through to the next cascade
    layer (see src.services.config_resolver — confirm).
    """

    __tablename__ = "org_configs"

    # One config row per organisation (one-to-one via unique FK).
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )

    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    # Per-region overrides of the blocking mode, as JSON.
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # TCF (IAB Transparency & Consent Framework)
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # Publisher country code (two characters).
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)

    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)

    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)

    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Consent
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Relationship
    organisation: Mapped["Organisation"] = relationship(back_populates="org_config")  # noqa: F821
|
||||
26
apps/api/src/models/organisation.py
Normal file
26
apps/api/src/models/organisation.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class Organisation(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """Multi-tenant root entity. Each organisation has multiple sites and users."""

    __tablename__ = "organisations"

    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # URL-safe identifier, unique across all tenants.
    slug: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
    contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
    billing_plan: Mapped[str] = mapped_column(String(50), server_default="free", nullable=False)
    notes: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships
    users: Mapped[list["User"]] = relationship(back_populates="organisation")  # noqa: F821
    sites: Mapped[list["Site"]] = relationship(back_populates="organisation")  # noqa: F821
    site_groups: Mapped[list["SiteGroup"]] = relationship(  # noqa: F821
        back_populates="organisation"
    )
    # Optional one-to-one org-level config layer of the cascade.
    org_config: Mapped["OrgConfig | None"] = relationship(  # noqa: F821
        back_populates="organisation", uselist=False
    )
|
||||
68
apps/api/src/models/scan.py
Normal file
68
apps/api/src/models/scan.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class ScanJob(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie scanning job for a site."""

    __tablename__ = "scan_jobs"

    # Site being scanned; jobs are removed together with the site.
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Lifecycle state (starts as "pending"); indexed, presumably for worker polling.
    status: Mapped[str] = mapped_column(
        String(20), server_default="pending", nullable=False, index=True
    )
    # How the job was started; defaults to "manual".
    trigger: Mapped[str] = mapped_column(String(20), server_default="manual", nullable=False)
    # Progress counters, maintained as the scan runs.
    pages_scanned: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    pages_total: Mapped[int | None] = mapped_column(Integer, nullable=True)
    cookies_found: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    # Populated when the job fails.
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)

    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="scan_jobs")  # noqa: F821
    results: Mapped[list["ScanResult"]] = relationship(back_populates="scan_job")
|
||||
|
||||
|
||||
class ScanResult(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Individual result from a scan — a cookie found on a specific page."""

    __tablename__ = "scan_results"

    # Parent job; results are removed together with the job.
    scan_job_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("scan_jobs.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    page_url: Mapped[str] = mapped_column(Text, nullable=False)
    cookie_name: Mapped[str] = mapped_column(String(255), nullable=False)
    cookie_domain: Mapped[str] = mapped_column(String(255), nullable=False)
    # Storage mechanism the value was found in; defaults to "cookie".
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    # Raw attributes captured by the scanner, as free-form JSON.
    attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    script_source: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Automatically suggested category (string label, not a foreign key).
    auto_category: Mapped[str | None] = mapped_column(String(50), nullable=True)
    initiator_chain: Mapped[list[str] | None] = mapped_column(
        ARRAY(Text), nullable=True, comment="Ordered script URLs from root initiator to leaf"
    )

    # Set by the database at insert time.
    found_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )

    # Relationships
    scan_job: Mapped["ScanJob"] = relationship(back_populates="results")
|
||||
48
apps/api/src/models/site.py
Normal file
48
apps/api/src/models/site.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class Site(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A domain being managed for cookie consent, belongs to an organisation."""

    __tablename__ = "sites"
    # A domain may only be registered once per organisation.
    __table_args__ = (UniqueConstraint("organisation_id", "domain", name="uq_sites_org_domain"),)

    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    domain: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    display_name: Mapped[str] = mapped_column(String(255), nullable=False)
    is_active: Mapped[bool] = mapped_column(default=True, nullable=False)
    # Extra domains served by this site. NOTE(review): server_default=None is
    # the implicit default and could be dropped.
    additional_domains: Mapped[list[str] | None] = mapped_column(
        ARRAY(String(255)), nullable=True, server_default=None
    )
    # Optional membership in a site group; the site is detached (SET NULL),
    # not deleted, when the group goes away.
    site_group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )

    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="sites")  # noqa: F821
    site_group: Mapped["SiteGroup | None"] = relationship(back_populates="sites")  # noqa: F821
    # Optional one-to-one site-level config layer.
    config: Mapped["SiteConfig | None"] = relationship(  # noqa: F821
        back_populates="site", uselist=False
    )
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="site")  # noqa: F821
    cookie_allow_list: Mapped[list["CookieAllowListEntry"]] = relationship(  # noqa: F821
        back_populates="site"
    )
    scan_jobs: Mapped[list["ScanJob"]] = relationship(back_populates="site")  # noqa: F821
    translations: Mapped[list["Translation"]] = relationship(  # noqa: F821
        back_populates="site"
    )
|
||||
63
apps/api/src/models/site_config.py
Normal file
63
apps/api/src/models/site_config.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class SiteConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Full configuration for a site: blocking mode, TCF, GCM, banner, scanning, consent.

    Unlike OrgConfig/SiteGroupConfig (whose columns are all nullable), this is
    the concrete leaf layer of the cascade: most columns carry defaults and are
    NOT NULL. Note this model has no SoftDeleteMixin, so no deleted_at column.
    """

    __tablename__ = "site_configs"

    # One config row per site (one-to-one via unique FK).
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )

    # Blocking mode
    blocking_mode: Mapped[str] = mapped_column(String(20), server_default="opt_in", nullable=False)
    # Per-region overrides of the blocking mode, as JSON.
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # TCF (IAB Transparency & Consent Framework)
    tcf_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Publisher country code (two characters).
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)

    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool] = mapped_column(default=False, nullable=False)

    # Google Consent Mode
    gcm_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)

    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    display_mode: Mapped[str] = mapped_column(
        String(30), server_default="bottom_banner", nullable=False
    )
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int] = mapped_column(Integer, server_default="50", nullable=False)

    # Consent
    consent_expiry_days: Mapped[int] = mapped_column(Integer, server_default="365", nullable=False)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Relationship
    site: Mapped["Site"] = relationship(back_populates="config")  # noqa: F821
|
||||
32
apps/api/src/models/site_group.py
Normal file
32
apps/api/src/models/site_group.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class SiteGroup(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A logical grouping of sites within an organisation (e.g. a brand)."""

    __tablename__ = "site_groups"
    # Group names are unique within an organisation.
    __table_args__ = (UniqueConstraint("organisation_id", "name", name="uq_site_groups_org_name"),)

    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships
    organisation: Mapped["Organisation"] = relationship(  # noqa: F821
        back_populates="site_groups"
    )
    sites: Mapped[list["Site"]] = relationship(back_populates="site_group")  # noqa: F821
    # Optional one-to-one group-level config layer of the cascade.
    group_config: Mapped["SiteGroupConfig | None"] = relationship(  # noqa: F821
        back_populates="site_group", uselist=False
    )
|
||||
63
apps/api/src/models/site_group_config.py
Normal file
63
apps/api/src/models/site_group_config.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class SiteGroupConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Site-group-level default configuration.

    These defaults sit between org defaults and site config in the cascade:
    System Defaults -> Org Config -> Site Group Config -> Site Config -> Regional Overrides

    Every column except ``site_group_id`` is nullable: NULL presumably means
    "not set at this level", so resolution falls through to the next cascade
    layer (see src.services.config_resolver — confirm).
    """

    __tablename__ = "site_group_configs"

    # One config row per site group (one-to-one via unique FK).
    site_group_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )

    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    # Per-region overrides of the blocking mode, as JSON.
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # TCF (IAB Transparency & Consent Framework)
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)

    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)

    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)

    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Consent
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # NOTE(review): OrgConfig and SiteConfig both expose consent_retention_days;
    # its absence here looks like an oversight — confirm whether the group layer
    # should be able to override it.

    # Relationship
    site_group: Mapped["SiteGroup"] = relationship(back_populates="group_config")  # noqa: F821
|
||||
26
apps/api/src/models/translation.py
Normal file
26
apps/api/src/models/translation.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class Translation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Internationalisation strings per site per locale."""

    __tablename__ = "translations"
    # One bundle of strings per (site, locale).
    __table_args__ = (UniqueConstraint("site_id", "locale", name="uq_translations_site_locale"),)

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Locale tag, max 10 chars — presumably BCP-47 style ("en", "en-GB"); confirm format.
    locale: Mapped[str] = mapped_column(String(10), nullable=False)
    # JSON mapping of translation keys to localised strings.
    strings: Mapped[dict] = mapped_column(JSONB, nullable=False)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="translations")  # noqa: F821
|
||||
31
apps/api/src/models/user.py
Normal file
31
apps/api/src/models/user.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class User(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """User account, scoped to an organisation with a role."""

    __tablename__ = "users"

    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Unique across ALL organisations, not just within one.
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    # Only the hash is stored; verification lives in src.services.auth.
    password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
    full_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Role within the organisation; least-privilege default.
    role: Mapped[str] = mapped_column(
        String(20),
        nullable=False,
        server_default="viewer",
    )

    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="users")  # noqa: F821
|
||||
0
apps/api/src/routers/__init__.py
Normal file
0
apps/api/src/routers/__init__.py
Normal file
108
apps/api/src/routers/auth.py
Normal file
108
apps/api/src/routers/auth.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import JWTError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.db import get_db
|
||||
from src.models.user import User
|
||||
from src.schemas.auth import CurrentUser, LoginRequest, RefreshRequest, TokenResponse
|
||||
from src.services.auth import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
verify_password,
|
||||
)
|
||||
from src.services.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
|
||||
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
|
||||
"""Authenticate a user with email and password, return JWT tokens."""
|
||||
result = await db.execute(
|
||||
select(User).where(User.email == body.email, User.deleted_at.is_(None))
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None or not verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid email or password",
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
access_token = create_access_token(
|
||||
user_id=user.id,
|
||||
organisation_id=user.organisation_id,
|
||||
role=user.role,
|
||||
email=user.email,
|
||||
)
|
||||
refresh_token = create_refresh_token(
|
||||
user_id=user.id,
|
||||
organisation_id=user.organisation_id,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_in=settings.jwt_access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
|
||||
async def refresh(
|
||||
body: RefreshRequest,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> TokenResponse:
|
||||
"""Exchange a valid refresh token for a new access/refresh token pair."""
|
||||
try:
|
||||
payload = decode_token(body.refresh_token)
|
||||
except JWTError as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Invalid or expired refresh token",
|
||||
) from exc
|
||||
|
||||
if payload.get("type") != "refresh":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token is not a refresh token",
|
||||
)
|
||||
|
||||
user_id = uuid.UUID(payload["sub"])
|
||||
result = await db.execute(select(User).where(User.id == user_id, User.deleted_at.is_(None)))
|
||||
user = result.scalar_one_or_none()
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="User no longer exists",
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
access_token = create_access_token(
|
||||
user_id=user.id,
|
||||
organisation_id=user.organisation_id,
|
||||
role=user.role,
|
||||
email=user.email,
|
||||
)
|
||||
new_refresh_token = create_refresh_token(
|
||||
user_id=user.id,
|
||||
organisation_id=user.organisation_id,
|
||||
)
|
||||
|
||||
return TokenResponse(
|
||||
access_token=access_token,
|
||||
refresh_token=new_refresh_token,
|
||||
expires_in=settings.jwt_access_token_expire_minutes * 60,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=CurrentUser)
|
||||
async def get_me(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
|
||||
"""Return the currently authenticated user's profile from the JWT."""
|
||||
return current_user
|
||||
135
apps/api/src/routers/compliance.py
Normal file
135
apps/api/src/routers/compliance.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Compliance checking endpoints.
|
||||
|
||||
Evaluates a site's configuration against regulatory frameworks (GDPR, CNIL,
|
||||
CCPA, ePrivacy, LGPD) and returns per-framework compliance reports with scores,
|
||||
issues, and recommendations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.cookie import Cookie
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.schemas.compliance import (
|
||||
ComplianceCheckRequest,
|
||||
ComplianceCheckResponse,
|
||||
Framework,
|
||||
)
|
||||
from src.services.compliance import (
|
||||
SiteContext,
|
||||
calculate_overall_score,
|
||||
run_compliance_check,
|
||||
)
|
||||
from src.services.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
|
||||
|
||||
async def _build_site_context(
    site_id: uuid.UUID,
    db: AsyncSession,
) -> SiteContext:
    """Load site config and cookie stats to build a SiteContext.

    Args:
        site_id: Primary key of the site being evaluated.
        db: Async database session.

    Returns:
        A fully populated SiteContext when the site has a SiteConfig row,
        otherwise a minimal context carrying only the cookie counts.
    """
    # Fetch site config. Fix: SiteConfig uses only TimestampMixin (not
    # SoftDeleteMixin), so it has no deleted_at column — the previous
    # SiteConfig.deleted_at filter could never resolve and has been removed.
    result = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    config = result.scalar_one_or_none()

    # Cookie statistics: total observed cookies, and how many still lack a category.
    total_q = await db.execute(
        select(func.count()).select_from(Cookie).where(Cookie.site_id == site_id)
    )
    total_cookies = total_q.scalar() or 0

    uncat_q = await db.execute(
        select(func.count())
        .select_from(Cookie)
        .where(
            Cookie.site_id == site_id,
            Cookie.category_id.is_(None),
        )
    )
    uncategorised_cookies = uncat_q.scalar() or 0

    if config is None:
        # No config row yet: report only the counts; config-dependent fields
        # keep SiteContext's own defaults.
        return SiteContext(
            total_cookies=total_cookies,
            uncategorised_cookies=uncategorised_cookies,
        )

    banner_config = config.banner_config or {}
    return SiteContext(
        blocking_mode=config.blocking_mode,
        regional_modes=config.regional_modes,
        tcf_enabled=config.tcf_enabled,
        gcm_enabled=config.gcm_enabled,
        consent_expiry_days=config.consent_expiry_days,
        privacy_policy_url=config.privacy_policy_url,
        display_mode=config.display_mode,
        banner_config=config.banner_config,
        total_cookies=total_cookies,
        uncategorised_cookies=uncategorised_cookies,
        # Dark-pattern signals read from the banner JSON; absent keys default
        # to compliant values for the first two and "not present" for the rest.
        has_reject_button=banner_config.get("show_reject_all", True),
        has_granular_choices=banner_config.get("show_category_toggles", True),
        has_cookie_wall=banner_config.get("cookie_wall", False),
        pre_ticked_boxes=banner_config.get("pre_ticked", False),
    )
|
||||
|
||||
|
||||
@router.post(
    "/check/{site_id}",
    response_model=ComplianceCheckResponse,
    status_code=status.HTTP_200_OK,
)
async def check_compliance(
    site_id: uuid.UUID,
    body: ComplianceCheckRequest | None = None,
    db: AsyncSession = Depends(get_db),
    _user=Depends(get_current_user),
) -> ComplianceCheckResponse:
    """Run compliance checks against a site's configuration.

    Raises:
        HTTPException 404: when the site does not exist or is soft-deleted.
    """
    # The site must exist (and not be soft-deleted) before we evaluate it.
    existing = (
        await db.execute(select(Site).where(Site.id == site_id, Site.deleted_at.is_(None)))
    ).scalar_one_or_none()
    if existing is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )

    # Build the evaluation context, run the requested frameworks (all when the
    # body is omitted), then fold the per-framework results into one score.
    context = await _build_site_context(site_id, db)
    requested = body.frameworks if body else None
    reports = run_compliance_check(context, requested)

    return ComplianceCheckResponse(
        site_id=str(site_id),
        results=reports,
        overall_score=calculate_overall_score(reports),
    )
|
||||
|
||||
|
||||
@router.get("/frameworks", response_model=list[dict])
|
||||
async def list_frameworks() -> list[dict]:
|
||||
"""List all available compliance frameworks."""
|
||||
return [
|
||||
{"id": fw.value, "name": fw.value.upper(), "description": desc}
|
||||
for fw, desc in [
|
||||
(Framework.GDPR, "EU General Data Protection Regulation"),
|
||||
(Framework.CNIL, "French Data Protection Authority (stricter GDPR)"),
|
||||
(Framework.CCPA, "California Consumer Privacy Act / CPRA"),
|
||||
(Framework.EPRIVACY, "EU ePrivacy Directive"),
|
||||
(Framework.LGPD, "Brazilian General Data Protection Law"),
|
||||
]
|
||||
]
|
||||
324
apps/api/src/routers/config.py
Normal file
324
apps/api/src/routers/config.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.extensions.registry import get_registry
|
||||
from src.models.org_config import OrgConfig
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.models.site_group_config import SiteGroupConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site import SiteConfigResponse
|
||||
from src.services.config_resolver import (
|
||||
CONFIG_FIELDS,
|
||||
build_public_config,
|
||||
orm_to_config_dict,
|
||||
resolve_config,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
from src.services.geoip import detect_region
|
||||
from src.services.publisher import publish_site_config
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}", response_model=SiteConfigResponse)
|
||||
async def get_public_site_config(
|
||||
site_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> SiteConfig:
|
||||
"""Public endpoint: retrieve site config for the banner script. No auth required."""
|
||||
result = await db.execute(
|
||||
select(SiteConfig)
|
||||
.join(Site)
|
||||
.where(
|
||||
SiteConfig.site_id == site_id,
|
||||
Site.is_active.is_(True),
|
||||
Site.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
if config is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Site configuration not found",
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/resolved")
async def get_resolved_config(
    site_id: uuid.UUID,
    region: str | None = None,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: retrieve fully resolved config with regional overrides applied.

    Applies the full cascade: System → Org → Group → Site → Regional.
    """
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    site_config = (await db.execute(stmt)).scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    base = orm_to_config_dict(site_config, include_id=True)

    # Pull parent-level defaults (org and site group) for the cascade.
    org_id = await _get_site_org_id(site_id, db)
    group_id = await _get_site_group_id(site_id, db)

    merged = resolve_config(
        base,
        org_defaults=await _load_org_defaults(org_id, db) if org_id else None,
        group_defaults=await _load_group_defaults(group_id, db) if group_id else None,
        region=region,
    )
    return build_public_config(str(site_id), merged)
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/geo-resolved")
async def get_geo_resolved_config(
    site_id: uuid.UUID,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: resolve config using the visitor's detected region.

    Detects the visitor's region from CDN headers or IP geolocation,
    then applies regional blocking mode overrides automatically.
    Uses the full cascade: System → Org → Group → Site → Regional.
    """
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    site_config = (await db.execute(stmt)).scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    # Detect region from request headers / IP before resolving.
    geo = await detect_region(request)

    org_id = await _get_site_org_id(site_id, db)
    group_id = await _get_site_group_id(site_id, db)

    merged = resolve_config(
        orm_to_config_dict(site_config, include_id=True),
        org_defaults=await _load_org_defaults(org_id, db) if org_id else None,
        group_defaults=await _load_group_defaults(group_id, db) if group_id else None,
        region=geo.region,
    )
    payload = build_public_config(str(site_id), merged)

    # Surface the detected geo info so the banner can use it.
    payload["detected_country"] = geo.country_code
    payload["detected_region"] = geo.region
    return payload
|
||||
|
||||
|
||||
@router.get("/geo")
async def get_visitor_geo(request: Request) -> dict:
    """Public endpoint: return the detected region for the current visitor.

    Useful for banner scripts that need to know the region before
    fetching the full config.
    """
    detected = await detect_region(request)
    return {"country_code": detected.country_code, "region": detected.region}
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/inheritance")
async def get_config_inheritance(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return the full config inheritance chain for a site.

    Shows the value at each level so the UI can display where each setting
    comes from: system, org, group, or site.  Tenant-scoped: the site must
    belong to the caller's organisation, otherwise a 404 is returned.
    """
    # Local import keeps SYSTEM_DEFAULTS out of the module-level surface.
    from src.services.config_resolver import SYSTEM_DEFAULTS

    result = await db.execute(
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    site_dict = orm_to_config_dict(config)
    org_defaults = await _load_org_defaults(current_user.organisation_id, db)
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None

    resolved = resolve_config(
        site_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
    )

    # For each config field, record the value at every level and which
    # level actually supplied the resolved value.
    sources: dict[str, dict] = {}
    for field in CONFIG_FIELDS:
        site_val = site_dict.get(field)
        group_val = group_defaults.get(field) if group_defaults else None
        org_val = org_defaults.get(field) if org_defaults else None
        system_val = SYSTEM_DEFAULTS.get(field)

        # Highest-priority non-None level wins; "system" is the fallback
        # whether or not a system default exists for the field.  (The
        # previous chain had a redundant `elif system_val is not None`
        # arm whose result was identical to the else branch.)
        if site_val is not None:
            source = "site"
        elif group_val is not None:
            source = "group"
        elif org_val is not None:
            source = "org"
        else:
            source = "system"

        sources[field] = {
            "resolved_value": resolved.get(field),
            "source": source,
            "site_value": site_val,
            "group_value": group_val,
            "org_value": org_val,
            "system_value": system_val,
        }

    return {
        "site_id": str(site_id),
        "site_group_id": str(group_id) if group_id else None,
        "fields": sources,
    }
|
||||
|
||||
|
||||
@router.post("/sites/{site_id}/publish")
async def publish_config(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Publish fully-resolved site config to CDN. Requires admin role."""
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    site_config = (await db.execute(stmt)).scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    group_id = await _get_site_group_id(site_id, db)
    resolved = resolve_config(
        orm_to_config_dict(site_config, include_id=True),
        org_defaults=await _load_org_defaults(current_user.organisation_id, db),
        group_defaults=await _load_group_defaults(group_id, db) if group_id else None,
    )

    # Allow extensions to enrich the published config (e.g. A/B test data)
    for enricher in get_registry().config_enrichers:
        await enricher(site_id, db, resolved)

    outcome = await publish_site_config(str(site_id), resolved)
    if not outcome.success:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Publish failed: {outcome.error}",
        )

    return {
        "published": True,
        "path": outcome.path,
        "published_at": outcome.published_at,
    }
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_site_org_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Look up the organisation_id for a site, or None if the site is unknown."""
    row = await db.execute(select(Site.organisation_id).where(Site.id == site_id))
    return row.scalar_one_or_none()
|
||||
|
||||
|
||||
async def _get_site_group_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Look up the site_group_id for a site, or None if the site is unknown."""
    row = await db.execute(select(Site.site_group_id).where(Site.id == site_id))
    return row.scalar_one_or_none()
|
||||
|
||||
|
||||
async def _load_org_defaults(organisation_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Load the org-level config defaults, or None if not set."""
    org_config = (
        await db.execute(select(OrgConfig).where(OrgConfig.organisation_id == organisation_id))
    ).scalar_one_or_none()
    return None if org_config is None else orm_to_config_dict(org_config)
|
||||
|
||||
|
||||
async def _load_group_defaults(group_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Load the site-group-level config defaults, or None if not set."""
    group_config = (
        await db.execute(select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id))
    ).scalar_one_or_none()
    return None if group_config is None else orm_to_config_dict(group_config)
|
||||
125
apps/api/src/routers/consent.py
Normal file
125
apps/api/src/routers/consent.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.extensions.registry import get_registry
|
||||
from src.models.consent import ConsentRecord
|
||||
from src.models.site import Site
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.consent import (
|
||||
ConsentRecordCreate,
|
||||
ConsentRecordResponse,
|
||||
ConsentVerifyResponse,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
from src.services.pseudonymisation import pseudonymise
|
||||
|
||||
router = APIRouter(prefix="/consent", tags=["consent"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ConsentRecordResponse, status_code=status.HTTP_201_CREATED)
async def record_consent(
    body: ConsentRecordCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Record a consent event from the banner. Public endpoint (no auth required)."""
    # HMAC-pseudonymise IP and user agent: the stored hashes cannot be
    # reversed without the server-side secret.
    raw_ip = request.client.host if request.client else ""
    raw_ua = request.headers.get("user-agent", "")

    record = ConsentRecord(
        site_id=body.site_id,
        visitor_id=body.visitor_id,
        ip_hash=pseudonymise(raw_ip),
        user_agent_hash=pseudonymise(raw_ua),
        action=body.action,
        categories_accepted=body.categories_accepted,
        categories_rejected=body.categories_rejected,
        tc_string=body.tc_string,
        gcm_state=body.gcm_state,
        page_url=body.page_url,
        country_code=body.country_code,
        region_code=body.region_code,
    )
    db.add(record)
    await db.flush()
    await db.refresh(record)

    # Invoke any registered post-record hooks (EE consent receipts, etc.)
    for hook in get_registry().consent_record_hooks:
        await hook(db, record)

    return record
|
||||
|
||||
|
||||
async def _load_record_for_org(
    consent_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> ConsentRecord:
    """Load a consent record and enforce tenant isolation.

    The record's site must belong to the caller's organisation. A record
    from another tenant returns 404 rather than 403 so we don't leak
    existence across tenants.
    """
    query = (
        select(ConsentRecord)
        .join(Site, Site.id == ConsentRecord.site_id)
        .where(
            ConsentRecord.id == consent_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Consent record not found",
        )
    return found
|
||||
|
||||
|
||||
@router.get("/{consent_id}", response_model=ConsentRecordResponse)
async def get_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Retrieve a consent record by ID.

    Requires authentication and tenant membership. Consent records
    contain PII-adjacent data (hashed IP, page URL, category decisions)
    and must not be readable by anyone holding a record UUID.
    """
    record = await _load_record_for_org(consent_id, current_user, db)
    return record
|
||||
|
||||
|
||||
@router.get("/verify/{consent_id}", response_model=ConsentVerifyResponse)
async def verify_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Verify that a consent record exists (audit proof).

    Same tenant-scoped auth as :func:`get_consent` — proof of consent
    is only meaningful to the organisation that owns the site, and
    leaking existence to arbitrary callers enables enumeration.
    """
    rec = await _load_record_for_org(consent_id, current_user, db)
    return {
        "id": rec.id,
        "site_id": rec.site_id,
        "visitor_id": rec.visitor_id,
        "action": rec.action,
        "categories_accepted": rec.categories_accepted,
        "consented_at": rec.consented_at,
        "valid": True,
    }
|
||||
582
apps/api/src/routers/cookies.py
Normal file
582
apps/api/src/routers/cookies.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""Cookie category, cookie, and allow-list management endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
|
||||
from src.models.site import Site
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.cookie import (
|
||||
AllowListEntryCreate,
|
||||
AllowListEntryResponse,
|
||||
AllowListEntryUpdate,
|
||||
ClassificationResultResponse,
|
||||
ClassifySingleRequest,
|
||||
ClassifySiteResponse,
|
||||
CookieCategoryResponse,
|
||||
CookieCreate,
|
||||
CookieResponse,
|
||||
CookieUpdate,
|
||||
KnownCookieCreate,
|
||||
KnownCookieResponse,
|
||||
KnownCookieUpdate,
|
||||
ReviewStatus,
|
||||
)
|
||||
from src.services.classification import classify_single_cookie, classify_site_cookies
|
||||
from src.services.dependencies import get_current_user, require_role
|
||||
|
||||
router = APIRouter(prefix="/cookies", tags=["cookies"])
|
||||
|
||||
|
||||
# ── Cookie categories (read-only, seeded by migration) ──────────────
|
||||
|
||||
|
||||
@router.get("/categories", response_model=list[CookieCategoryResponse])
async def list_categories(
    db: AsyncSession = Depends(get_db),
) -> list[CookieCategory]:
    """List all cookie categories. Public endpoint used by banner and admin."""
    rows = await db.execute(select(CookieCategory).order_by(CookieCategory.display_order))
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.get("/categories/{category_id}", response_model=CookieCategoryResponse)
async def get_category(
    category_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> CookieCategory:
    """Get a single cookie category by ID; 404 when unknown."""
    found = (
        await db.execute(select(CookieCategory).where(CookieCategory.id == category_id))
    ).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category not found")
    return found
|
||||
|
||||
|
||||
# ── Cookies per site ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_org_site(
    site_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Fetch a site ensuring it belongs to the user's organisation.

    Raises 404 (not 403) when the site is missing, deleted, or owned by
    another tenant.
    """
    found = (
        await db.execute(
            select(Site).where(
                Site.id == site_id,
                Site.organisation_id == current_user.organisation_id,
                Site.deleted_at.is_(None),
            )
        )
    ).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return found
|
||||
|
||||
|
||||
@router.get(
    "/sites/{site_id}",
    response_model=list[CookieResponse],
)
async def list_cookies(
    site_id: uuid.UUID,
    review_status: ReviewStatus | None = Query(None),
    category_id: uuid.UUID | None = Query(None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[Cookie]:
    """List cookies discovered on a site, with optional filters."""
    await _get_org_site(site_id, current_user, db)

    stmt = select(Cookie).where(Cookie.site_id == site_id)
    # Apply the optional filters only when supplied.
    if review_status:
        stmt = stmt.where(Cookie.review_status == review_status.value)
    if category_id:
        stmt = stmt.where(Cookie.category_id == category_id)

    rows = await db.execute(stmt.order_by(Cookie.name))
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}",
    response_model=CookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_cookie(
    site_id: uuid.UUID,
    body: CookieCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Create a cookie record for a site (manual entry or from scanner).

    Raises 404 when the site is not in the caller's organisation, and
    400 when an unknown category_id is supplied.
    """
    await _get_org_site(site_id, current_user, db)

    # Validate category if provided
    if body.category_id:
        cat = await db.execute(select(CookieCategory).where(CookieCategory.id == body.category_id))
        if not cat.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    # Capture the timestamp once so first_seen_at == last_seen_at on
    # creation — two separate now() calls could yield different values.
    now_iso = datetime.now(UTC).isoformat()
    cookie = Cookie(
        site_id=site_id,
        **body.model_dump(),
        first_seen_at=now_iso,
        last_seen_at=now_iso,
    )
    db.add(cookie)
    await db.flush()
    await db.refresh(cookie)
    return cookie
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/summary")
async def cookie_summary(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Get a summary of cookies for a site (counts by status and category)."""
    await _get_org_site(site_id, current_user, db)

    # Counts grouped by review status
    status_rows = (
        await db.execute(
            select(Cookie.review_status, func.count(Cookie.id))
            .where(Cookie.site_id == site_id)
            .group_by(Cookie.review_status)
        )
    ).all()
    by_status = dict(status_rows)

    # Counts grouped by category slug
    category_rows = (
        await db.execute(
            select(CookieCategory.slug, func.count(Cookie.id))
            .outerjoin(Cookie, Cookie.category_id == CookieCategory.id)
            .where(Cookie.site_id == site_id)
            .group_by(CookieCategory.slug)
        )
    ).all()
    by_category = dict(category_rows)

    # Cookies with no category assigned yet
    uncategorised = (
        await db.execute(
            select(func.count(Cookie.id)).where(
                Cookie.site_id == site_id, Cookie.category_id.is_(None)
            )
        )
    ).scalar() or 0

    return {
        "total": sum(by_status.values()),
        "by_status": by_status,
        "by_category": by_category,
        "uncategorised": uncategorised,
    }
|
||||
|
||||
|
||||
# ── Allow-list per site ──────────────────────────────────────────────
|
||||
# (Must be defined before {cookie_id} routes to avoid path conflicts)
|
||||
|
||||
|
||||
@router.get(
    "/sites/{site_id}/allow-list",
    response_model=list[AllowListEntryResponse],
)
async def list_allow_list(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[CookieAllowListEntry]:
    """List all allow-list entries for a site, ordered by name pattern."""
    await _get_org_site(site_id, current_user, db)

    rows = await db.execute(
        select(CookieAllowListEntry)
        .where(CookieAllowListEntry.site_id == site_id)
        .order_by(CookieAllowListEntry.name_pattern)
    )
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}/allow-list",
    response_model=AllowListEntryResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_allow_list_entry(
    site_id: uuid.UUID,
    body: AllowListEntryCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Add a cookie pattern to the allow-list for a site."""
    await _get_org_site(site_id, current_user, db)

    # Reject unknown categories up front.
    category = (
        await db.execute(select(CookieCategory).where(CookieCategory.id == body.category_id))
    ).scalar_one_or_none()
    if not category:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )

    entry = CookieAllowListEntry(site_id=site_id, **body.model_dump())
    db.add(entry)
    await db.flush()
    await db.refresh(entry)
    return entry
|
||||
|
||||
|
||||
@router.patch(
    "/sites/{site_id}/allow-list/{entry_id}",
    response_model=AllowListEntryResponse,
)
async def update_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    body: AllowListEntryUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Update an allow-list entry (partial update; only provided fields change)."""
    await _get_org_site(site_id, current_user, db)

    entry = (
        await db.execute(
            select(CookieAllowListEntry).where(
                CookieAllowListEntry.id == entry_id,
                CookieAllowListEntry.site_id == site_id,
            )
        )
    ).scalar_one_or_none()
    if not entry:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )

    changes = body.model_dump(exclude_unset=True)

    # A category change must reference an existing category.
    if changes.get("category_id") is not None:
        category = (
            await db.execute(
                select(CookieCategory).where(CookieCategory.id == changes["category_id"])
            )
        ).scalar_one_or_none()
        if not category:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    for attr, new_value in changes.items():
        setattr(entry, attr, new_value)
    entry.updated_at = datetime.now(UTC)

    await db.flush()
    await db.refresh(entry)
    return entry
|
||||
|
||||
|
||||
@router.delete(
    "/sites/{site_id}/allow-list/{entry_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove an entry from the allow-list."""
    await _get_org_site(site_id, current_user, db)

    entry = (
        await db.execute(
            select(CookieAllowListEntry).where(
                CookieAllowListEntry.id == entry_id,
                CookieAllowListEntry.site_id == site_id,
            )
        )
    ).scalar_one_or_none()
    if not entry:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )

    await db.delete(entry)
|
||||
|
||||
|
||||
# ── Individual cookie by ID (must come after /summary and /allow-list) ──
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
async def get_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Get a single cookie by ID; 404 when it doesn't belong to the site."""
    await _get_org_site(site_id, current_user, db)

    found = (
        await db.execute(select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id))
    ).scalar_one_or_none()
    if not found:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
    return found
|
||||
|
||||
|
||||
@router.patch("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
async def update_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    body: CookieUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Update a cookie record (e.g. assign category, change review status)."""
    await _get_org_site(site_id, current_user, db)

    target = (
        await db.execute(select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id))
    ).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")

    changes = body.model_dump(exclude_unset=True)

    # A category change must reference an existing category.
    if changes.get("category_id") is not None:
        category = (
            await db.execute(
                select(CookieCategory).where(CookieCategory.id == changes["category_id"])
            )
        ).scalar_one_or_none()
        if not category:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    for attr, new_value in changes.items():
        setattr(target, attr, new_value)
    target.updated_at = datetime.now(UTC)

    await db.flush()
    await db.refresh(target)
    return target
|
||||
|
||||
|
||||
@router.delete(
    "/sites/{site_id}/{cookie_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a cookie record."""
    await _get_org_site(site_id, current_user, db)

    target = (
        await db.execute(select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id))
    ).scalar_one_or_none()
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")

    await db.delete(target)
|
||||
|
||||
|
||||
# ── Known cookies database ──────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/known", response_model=list[KnownCookieResponse])
async def list_known_cookies(
    vendor: str | None = Query(None, description="Filter by vendor name"),
    search: str | None = Query(None, description="Search by name pattern"),
    db: AsyncSession = Depends(get_db),
    _user: CurrentUser = Depends(get_current_user),
) -> list[KnownCookie]:
    """List known cookie patterns from the shared database."""
    stmt = select(KnownCookie).order_by(KnownCookie.name_pattern)
    # Apply the optional filters only when supplied.
    if vendor:
        stmt = stmt.where(KnownCookie.vendor == vendor)
    if search:
        stmt = stmt.where(KnownCookie.name_pattern.ilike(f"%{search}%"))
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.post(
    "/known",
    response_model=KnownCookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_known_cookie(
    body: KnownCookieCreate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Add a new pattern to the known cookies database."""
    # Reject unknown categories up front.
    category = (
        await db.execute(select(CookieCategory).where(CookieCategory.id == body.category_id))
    ).scalar_one_or_none()
    if not category:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )

    known = KnownCookie(**body.model_dump())
    db.add(known)
    await db.flush()
    await db.refresh(known)
    return known
|
||||
|
||||
|
||||
@router.get("/known/{known_id}", response_model=KnownCookieResponse)
async def get_known_cookie(
    known_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    _user: CurrentUser = Depends(get_current_user),
) -> KnownCookie:
    """Get a single known cookie pattern by ID; 404 when unknown."""
    found = (
        await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    ).scalar_one_or_none()
    if not found:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    return found
|
||||
|
||||
|
||||
@router.patch("/known/{known_id}", response_model=KnownCookieResponse)
async def update_known_cookie(
    known_id: uuid.UUID,
    body: KnownCookieUpdate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Update a known cookie pattern (partial update; only provided fields change)."""
    known = (
        await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    ).scalar_one_or_none()
    if not known:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )

    changes = body.model_dump(exclude_unset=True)

    # A category change must reference an existing category.
    if changes.get("category_id") is not None:
        category = (
            await db.execute(
                select(CookieCategory).where(CookieCategory.id == changes["category_id"])
            )
        ).scalar_one_or_none()
        if not category:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    for attr, new_value in changes.items():
        setattr(known, attr, new_value)
    known.updated_at = datetime.now(UTC)

    await db.flush()
    await db.refresh(known)
    return known
|
||||
|
||||
|
||||
@router.delete(
    "/known/{known_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_known_cookie(
    known_id: uuid.UUID,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove a known cookie pattern; 404 when no such pattern exists."""
    found = (
        await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    ).scalar_one_or_none()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    await db.delete(found)
|
||||
|
||||
|
||||
# ── Classification endpoints ────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}/classify",
    response_model=ClassifySiteResponse,
)
async def classify_cookies(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> ClassifySiteResponse:
    """Auto-classify pending cookies for a site against known patterns."""
    # 404s unless the site belongs to the caller's organisation.
    await _get_org_site(site_id, current_user, db)

    outcomes = await classify_site_cookies(db, site_id, only_pending=True)
    hits = sum(1 for item in outcomes if item.matched)

    payload = [
        ClassificationResultResponse(
            cookie_name=item.cookie_name,
            cookie_domain=item.cookie_domain,
            category_id=item.category_id,
            category_slug=item.category_slug,
            vendor=item.vendor,
            description=item.description,
            match_source=item.match_source,
            matched=item.matched,
        )
        for item in outcomes
    ]
    return ClassifySiteResponse(
        site_id=str(site_id),
        total=len(outcomes),
        matched=hits,
        unmatched=len(outcomes) - hits,
        results=payload,
    )
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}/classify/preview",
    response_model=ClassificationResultResponse,
)
async def classify_preview(
    site_id: uuid.UUID,
    body: ClassifySingleRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> ClassificationResultResponse:
    """Dry-run classification of a single cookie; nothing is persisted."""
    await _get_org_site(site_id, current_user, db)

    outcome = await classify_single_cookie(db, site_id, body.cookie_name, body.cookie_domain)
    copied = (
        "cookie_name",
        "cookie_domain",
        "category_id",
        "category_slug",
        "vendor",
        "description",
        "match_source",
        "matched",
    )
    return ClassificationResultResponse(**{name: getattr(outcome, name) for name in copied})
|
||||
69
apps/api/src/routers/org_config.py
Normal file
69
apps/api/src/routers/org_config.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Organisation-level default configuration endpoints.
|
||||
|
||||
Provides GET and PUT for the organisation's global config defaults.
|
||||
These defaults sit between system defaults and site config in the cascade.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.org_config import OrgConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.org_config import OrgConfigResponse, OrgConfigUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/org-config", tags=["organisations"])
|
||||
|
||||
|
||||
@router.get("/", response_model=OrgConfigResponse)
async def get_org_config(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> OrgConfig:
    """Return the organisation's global configuration defaults.

    Lazily creates an empty row on first read so callers always receive
    a well-formed config object.
    """
    row = await db.execute(
        select(OrgConfig).where(OrgConfig.organisation_id == current_user.organisation_id)
    )
    config = row.scalar_one_or_none()
    if config is not None:
        return config

    config = OrgConfig(organisation_id=current_user.organisation_id)
    db.add(config)
    await db.flush()
    await db.refresh(config)
    return config
|
||||
|
||||
|
||||
@router.put("/", response_model=OrgConfigResponse)
async def update_org_config(
    body: OrgConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> OrgConfig:
    """Upsert the organisation's global configuration defaults.

    Only fields explicitly set on the request body are written; non-None
    values override system defaults when resolving site config.
    """
    row = await db.execute(
        select(OrgConfig).where(OrgConfig.organisation_id == current_user.organisation_id)
    )
    config = row.scalar_one_or_none()
    supplied = body.model_dump(exclude_unset=True)

    if config is None:
        # First write for this organisation: create the row directly
        # from the supplied fields.
        config = OrgConfig(organisation_id=current_user.organisation_id, **supplied)
        db.add(config)
    else:
        for name, value in supplied.items():
            setattr(config, name, value)

    await db.flush()
    await db.refresh(config)
    return config
|
||||
118
apps/api/src/routers/organisations.py
Normal file
118
apps/api/src/routers/organisations.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import hmac
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.db import get_db
|
||||
from src.models.organisation import Organisation
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.organisation import (
|
||||
OrganisationCreate,
|
||||
OrganisationResponse,
|
||||
OrganisationUpdate,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/organisations", tags=["organisations"])
|
||||
|
||||
|
||||
def _require_bootstrap_token(
    x_admin_bootstrap_token: str | None = Header(default=None),
) -> None:
    """Gate organisation creation behind a static bootstrap token.

    The token is configured via ``ADMIN_BOOTSTRAP_TOKEN``. When unset
    (the default), the endpoint is disabled entirely — operators must
    explicitly opt in and should rotate or unset the value after their
    initial org is provisioned.

    Raises:
        HTTPException: 403 when no token is configured server-side;
            401 when the ``X-Admin-Bootstrap-Token`` header is missing
            or does not match the configured value.
    """
    expected = get_settings().admin_bootstrap_token
    if not expected:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                "Organisation creation is disabled. Set ADMIN_BOOTSTRAP_TOKEN "
                "in the environment to enable it."
            ),
        )
    # hmac.compare_digest performs a constant-time comparison, avoiding
    # timing side-channels on the token check.
    if not x_admin_bootstrap_token or not hmac.compare_digest(
        x_admin_bootstrap_token,
        expected,
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing admin bootstrap token",
        )
|
||||
|
||||
|
||||
@router.post("/", response_model=OrganisationResponse, status_code=status.HTTP_201_CREATED)
async def create_organisation(
    body: OrganisationCreate,
    db: AsyncSession = Depends(get_db),
    _: None = Depends(_require_bootstrap_token),
) -> Organisation:
    """Create a new organisation. Gated by ``X-Admin-Bootstrap-Token``.

    See :func:`_require_bootstrap_token` for the gating semantics. Once
    your initial organisation exists, rotate or unset
    ``ADMIN_BOOTSTRAP_TOKEN`` to disable further tenant creation.
    """
    # Slugs are unique across all tenants.
    clash = await db.execute(select(Organisation).where(Organisation.slug == body.slug))
    if clash.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Organisation with slug '{body.slug}' already exists",
        )

    organisation = Organisation(**body.model_dump())
    db.add(organisation)
    await db.flush()
    await db.refresh(organisation)
    return organisation
|
||||
|
||||
|
||||
@router.get("/me", response_model=OrganisationResponse)
async def get_my_organisation(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Organisation:
    """Return the caller's (non-deleted) organisation."""
    query = select(Organisation).where(
        Organisation.id == current_user.organisation_id,
        Organisation.deleted_at.is_(None),
    )
    organisation = (await db.execute(query)).scalar_one_or_none()
    if organisation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")
    return organisation
|
||||
|
||||
|
||||
@router.patch("/me", response_model=OrganisationResponse)
async def update_my_organisation(
    body: OrganisationUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> Organisation:
    """Partially update the caller's organisation (owner/admin only)."""
    query = select(Organisation).where(
        Organisation.id == current_user.organisation_id,
        Organisation.deleted_at.is_(None),
    )
    organisation = (await db.execute(query)).scalar_one_or_none()
    if organisation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")

    for name, value in body.model_dump(exclude_unset=True).items():
        setattr(organisation, name, value)

    await db.flush()
    await db.refresh(organisation)
    return organisation
|
||||
310
apps/api/src/routers/scanner.py
Normal file
310
apps/api/src/routers/scanner.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Scanner and client-side cookie report endpoints.
|
||||
|
||||
Accepts cookie reports from the client-side reporter embedded in the banner
|
||||
bundle, upserts discovered cookies into the site's cookie inventory, and
|
||||
provides scan job management (trigger, list, detail, diff).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.cookie import Cookie
|
||||
from src.models.scan import ScanJob, ScanResult
|
||||
from src.models.site import Site
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.scanner import (
|
||||
CookieReportRequest,
|
||||
CookieReportResponse,
|
||||
ScanDiffResponse,
|
||||
ScanJobDetailResponse,
|
||||
ScanJobResponse,
|
||||
TriggerScanRequest,
|
||||
)
|
||||
from src.services.dependencies import get_current_user
|
||||
from src.services.scanner import (
|
||||
compute_scan_diff,
|
||||
create_scan_job,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/scanner", tags=["scanner"])
|
||||
|
||||
|
||||
# ── Client-side cookie report (public, no auth) ─────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/report",
    response_model=CookieReportResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def receive_cookie_report(
    body: CookieReportRequest,
    db: AsyncSession = Depends(get_db),
) -> CookieReportResponse:
    """Receive a cookie report from the client-side reporter.

    This is a public endpoint (no auth) since it's called from the banner
    script running on end-user browsers. The site_id acts as implicit auth.

    Upserts each reported cookie into the site's inventory: known entries
    (matched on name + domain + storage_type) get their ``last_seen_at``
    refreshed; unseen ones are created in ``pending`` review state.
    """
    # Verify the site exists and is not soft-deleted.
    site_result = await db.execute(
        select(Site).where(
            Site.id == body.site_id,
            Site.deleted_at.is_(None),
        )
    )
    if site_result.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )

    # Prefetch the site's inventory in ONE query instead of one SELECT per
    # reported cookie (avoids N+1). Indexing by the upsert identity key —
    # and updating the index as new cookies are added — also means duplicate
    # entries within a single report cannot create duplicate rows (the old
    # per-item SELECT could not see rows added but not yet flushed).
    existing_result = await db.execute(select(Cookie).where(Cookie.site_id == body.site_id))
    by_identity: dict[tuple, Cookie] = {
        (c.name, c.domain, c.storage_type): c for c in existing_result.scalars()
    }

    new_cookies = 0
    now_iso = datetime.now(UTC).isoformat()

    for reported in body.cookies:
        key = (reported.name, reported.domain, reported.storage_type)
        cookie = by_identity.get(key)
        if cookie is not None:
            # Already known: just refresh the sighting timestamp.
            cookie.last_seen_at = now_iso
        else:
            cookie = Cookie(
                site_id=body.site_id,
                name=reported.name,
                domain=reported.domain,
                storage_type=reported.storage_type,
                path=reported.path,
                is_secure=reported.is_secure,
                same_site=reported.same_site,
                review_status="pending",
                first_seen_at=now_iso,
                last_seen_at=now_iso,
            )
            db.add(cookie)
            by_identity[key] = cookie
            new_cookies += 1

    await db.flush()

    return CookieReportResponse(
        accepted=True,
        cookies_received=len(body.cookies),
        new_cookies=new_cookies,
    )
|
||||
|
||||
|
||||
# ── Scan job management (authenticated) ─────────────────────────────
|
||||
|
||||
|
||||
async def _verify_site_access(
    site_id: uuid.UUID,
    user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Return the site if it is active and owned by the user's organisation.

    Raises a 404 otherwise, deliberately not distinguishing "no such
    site" from "not your site".
    """
    query = select(Site).where(
        Site.id == site_id,
        Site.organisation_id == user.organisation_id,
        Site.deleted_at.is_(None),
    )
    site = (await db.execute(query)).scalar_one_or_none()
    if site is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    return site
|
||||
|
||||
|
||||
@router.post(
    "/scans",
    response_model=ScanJobResponse,
    status_code=status.HTTP_201_CREATED,
)
async def trigger_scan(
    body: TriggerScanRequest,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanJob:
    """Trigger a new cookie scan for a site.

    Creates a scan job in 'pending' state and dispatches it to the
    Celery worker queue for execution.

    Raises:
        HTTPException: 404 if the site is not visible to the caller;
            409 if a live (non-stale) scan is already in progress;
            503 if the job was created but could not be dispatched.
    """
    # Imported lazily so merely importing this router does not require
    # the scanner service / Celery machinery at module-load time.
    from src.services.scanner import complete_scan_job

    await _verify_site_access(body.site_id, user, db)

    # Check for an already-running scan
    active_result = await db.execute(
        select(ScanJob).where(
            ScanJob.site_id == body.site_id,
            ScanJob.status.in_(["pending", "running"]),
        )
    )
    active_jobs = list(active_result.scalars().all())

    now = datetime.now(UTC)
    # Jobs stuck beyond these windows are presumed dead and are failed
    # below so a fresh scan can supersede them.
    stale_pending_cutoff = now - timedelta(minutes=5)
    stale_running_cutoff = now - timedelta(minutes=10)

    for active_job in active_jobs:
        # NOTE(review): .replace(tzinfo=UTC) attaches the zone without
        # shifting the value — this assumes timestamps are stored naive
        # in UTC; confirm against the column definitions.
        is_stale_pending = (
            active_job.status == "pending"
            and active_job.created_at.replace(tzinfo=UTC) < stale_pending_cutoff
        )
        is_stale_running = (
            active_job.status == "running"
            and active_job.started_at
            and active_job.started_at.replace(tzinfo=UTC) < stale_running_cutoff
        )
        if is_stale_pending or is_stale_running:
            logger.warning(
                "Failing stale %s scan job %s for site %s",
                active_job.status,
                active_job.id,
                body.site_id,
            )
            # Mark the stuck job failed so it no longer blocks new scans.
            await complete_scan_job(
                db,
                active_job,
                error_message=(
                    f"Job was stale ({active_job.status} too long), superseded by new scan"
                ),
            )
        else:
            # A genuinely live job blocks a second concurrent scan.
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="A scan is already in progress for this site",
            )

    job = await create_scan_job(
        db,
        site_id=body.site_id,
        trigger="manual",
        max_pages=body.max_pages,
    )

    # Commit before dispatching to Celery so the worker can find the
    # job in the database immediately (avoids race condition).
    await db.commit()

    # Dispatch to Celery (import here to avoid import at module level
    # when Celery broker is unavailable during testing)
    try:
        from src.tasks.scanner import run_scan

        run_scan.delay(str(job.id), str(body.site_id))
    except Exception:
        # The job row is already committed; surface the dispatch failure
        # instead of leaving the caller thinking a scan is underway.
        logger.exception("Failed to dispatch scan job %s to Celery", job.id)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=(
                "Background task queue is unavailable — scan job"
                " created but cannot be processed. Please try again later."
            ),
        ) from None

    return job
|
||||
|
||||
|
||||
@router.get("/scans/site/{site_id}", response_model=list[ScanJobResponse])
async def list_scans(
    site_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
    limit: int = Query(default=20, ge=1, le=100),
    offset: int = Query(default=0, ge=0),
) -> list[ScanJob]:
    """Page through a site's scan jobs, newest first."""
    await _verify_site_access(site_id, user, db)

    query = (
        select(ScanJob)
        .where(ScanJob.site_id == site_id)
        .order_by(ScanJob.created_at.desc())
        .limit(limit)
        .offset(offset)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.get("/scans/{scan_id}", response_model=ScanJobDetailResponse)
async def get_scan(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Return one scan job together with its per-cookie results."""
    job = (
        await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    ).scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )

    # 404s when the job's site lies outside the caller's organisation.
    await _verify_site_access(job.site_id, user, db)

    result_rows = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == scan_id).order_by(ScanResult.cookie_name)
    )
    scan_results = list(result_rows.scalars().all())

    # Assemble the payload by hand so the results list can ride alongside
    # the job's own columns in one response dict.
    detail = {
        field: getattr(job, field)
        for field in (
            "id",
            "site_id",
            "status",
            "trigger",
            "pages_scanned",
            "pages_total",
            "cookies_found",
            "error_message",
            "started_at",
            "completed_at",
            "created_at",
            "updated_at",
        )
    }
    detail["results"] = scan_results
    return detail
|
||||
|
||||
|
||||
@router.get("/scans/{scan_id}/diff", response_model=ScanDiffResponse)
async def get_scan_diff(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanDiffResponse:
    """Diff a scan's findings against the preceding scan of the same site."""
    job = (
        await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    ).scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )

    await _verify_site_access(job.site_id, user, db)

    return await compute_scan_diff(db, current_scan_id=scan_id, site_id=job.site_id)
|
||||
101
apps/api/src/routers/site_group_config.py
Normal file
101
apps/api/src/routers/site_group_config.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Site-group-level default configuration endpoints.
|
||||
|
||||
Provides GET and PUT for a site group's config defaults.
|
||||
These defaults sit between org defaults and site config in the cascade.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site_group import SiteGroup
|
||||
from src.models.site_group_config import SiteGroupConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site_group_config import SiteGroupConfigResponse, SiteGroupConfigUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
|
||||
|
||||
|
||||
@router.get("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def get_site_group_config(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Return a site group's config defaults, creating an empty row on first read."""
    await _verify_group_ownership(group_id, current_user.organisation_id, db)

    row = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    config = row.scalar_one_or_none()
    if config is not None:
        return config

    config = SiteGroupConfig(site_group_id=group_id)
    db.add(config)
    await db.flush()
    await db.refresh(config)
    return config
|
||||
|
||||
|
||||
@router.put("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def update_site_group_config(
    group_id: uuid.UUID,
    body: SiteGroupConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Upsert configuration defaults for a site group.

    Only non-None fields will override org/system defaults when resolving
    site config.
    """
    await _verify_group_ownership(group_id, current_user.organisation_id, db)

    row = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    config = row.scalar_one_or_none()
    supplied = body.model_dump(exclude_unset=True)

    if config is None:
        config = SiteGroupConfig(site_group_id=group_id, **supplied)
        db.add(config)
    else:
        for name, value in supplied.items():
            setattr(config, name, value)

    await db.flush()
    await db.refresh(config)
    return config
|
||||
|
||||
|
||||
# -- Helpers ------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _verify_group_ownership(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> None:
    """Raise 404 unless the (non-deleted) group belongs to the organisation."""
    query = select(SiteGroup).where(
        SiteGroup.id == group_id,
        SiteGroup.organisation_id == organisation_id,
        SiteGroup.deleted_at.is_(None),
    )
    if (await db.execute(query)).scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site group not found",
        )
|
||||
198
apps/api/src/routers/site_groups.py
Normal file
198
apps/api/src/routers/site_groups.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
from src.models.site_group import SiteGroup
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site_group import SiteGroupCreate, SiteGroupResponse, SiteGroupUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
|
||||
|
||||
|
||||
@router.post("/", response_model=SiteGroupResponse, status_code=status.HTTP_201_CREATED)
async def create_site_group(
    body: SiteGroupCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Create a site group; names are unique per organisation (409 on clash)."""
    clash = await db.execute(
        select(SiteGroup).where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.name == body.name,
            SiteGroup.deleted_at.is_(None),
        )
    )
    if clash.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Site group '{body.name}' already exists in this organisation",
        )

    group = SiteGroup(
        organisation_id=current_user.organisation_id,
        name=body.name,
        description=body.description,
    )
    db.add(group)
    await db.flush()
    await db.refresh(group)
    # A freshly-created group cannot contain any sites yet.
    return _to_response(group, site_count=0)
|
||||
|
||||
|
||||
@router.get("/", response_model=list[SiteGroupResponse])
async def list_site_groups(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict]:
    """List the organisation's site groups (name order) with per-group site counts."""
    # Count active sites per group once, then LEFT JOIN so that empty
    # groups still appear with a count of zero.
    counts = (
        select(
            Site.site_group_id,
            func.count(Site.id).label("cnt"),
        )
        .where(Site.deleted_at.is_(None))
        .group_by(Site.site_group_id)
        .subquery()
    )

    rows = await db.execute(
        select(SiteGroup, func.coalesce(counts.c.cnt, 0).label("site_count"))
        .outerjoin(counts, SiteGroup.id == counts.c.site_group_id)
        .where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
        .order_by(SiteGroup.name)
    )

    return [_to_response(row.SiteGroup, site_count=row.site_count) for row in rows.all()]
|
||||
|
||||
|
||||
@router.get("/{group_id}", response_model=SiteGroupResponse)
async def get_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Fetch a single site group (with its active-site count) by ID."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    return _to_response(group, site_count=await _count_sites(group_id, db))
|
||||
|
||||
|
||||
@router.patch("/{group_id}", response_model=SiteGroupResponse)
async def update_site_group(
    group_id: uuid.UUID,
    body: SiteGroupUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Rename a site group and/or change its description.

    Renames are checked for uniqueness within the organisation (409 on
    clash); an unchanged name skips the check.
    """
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    changes = body.model_dump(exclude_unset=True)

    renamed = "name" in changes and changes["name"] != group.name
    if renamed:
        clash = await db.execute(
            select(SiteGroup).where(
                SiteGroup.organisation_id == current_user.organisation_id,
                SiteGroup.name == changes["name"],
                SiteGroup.deleted_at.is_(None),
                SiteGroup.id != group_id,
            )
        )
        if clash.scalar_one_or_none() is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Site group '{changes['name']}' already exists",
            )

    for name, value in changes.items():
        setattr(group, name, value)

    await db.flush()
    await db.refresh(group)
    return _to_response(group, site_count=await _count_sites(group_id, db))
|
||||
|
||||
|
||||
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete a site group; its sites are detached, not deleted."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)

    # Detach every active member site so none points at a deleted group.
    members = await db.execute(
        select(Site).where(
            Site.site_group_id == group_id,
            Site.deleted_at.is_(None),
        )
    )
    for member in members.scalars().all():
        member.site_group_id = None

    group.deleted_at = datetime.now(UTC)
    await db.flush()
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_org_group(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> SiteGroup:
    """Fetch a non-deleted site group scoped to the organisation, or 404."""
    query = select(SiteGroup).where(
        SiteGroup.id == group_id,
        SiteGroup.organisation_id == organisation_id,
        SiteGroup.deleted_at.is_(None),
    )
    group = (await db.execute(query)).scalar_one_or_none()
    if group is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site group not found",
        )
    return group
|
||||
|
||||
|
||||
async def _count_sites(group_id: uuid.UUID, db: AsyncSession) -> int:
    """Return the number of active (non-deleted) sites in the group."""
    query = select(func.count(Site.id)).where(
        Site.site_group_id == group_id,
        Site.deleted_at.is_(None),
    )
    return (await db.execute(query)).scalar_one()
|
||||
|
||||
|
||||
def _to_response(group: SiteGroup, *, site_count: int) -> dict:
    """Serialise a SiteGroup plus its computed site_count into a response dict."""
    payload = {
        field: getattr(group, field)
        for field in ("id", "organisation_id", "name", "description", "created_at", "updated_at")
    }
    payload["site_count"] = site_count
    return payload
|
||||
220
apps/api/src/routers/sites.py
Normal file
220
apps/api/src/routers/sites.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site import (
|
||||
SiteConfigCreate,
|
||||
SiteConfigResponse,
|
||||
SiteConfigUpdate,
|
||||
SiteCreate,
|
||||
SiteResponse,
|
||||
SiteUpdate,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/sites", tags=["sites"])
|
||||
|
||||
|
||||
# ── Site CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/", response_model=SiteResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_site(
|
||||
body: SiteCreate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Site:
|
||||
"""Create a new site within the current organisation."""
|
||||
# Check domain uniqueness within the org
|
||||
existing = await db.execute(
|
||||
select(Site).where(
|
||||
Site.organisation_id == current_user.organisation_id,
|
||||
Site.domain == body.domain,
|
||||
Site.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Site with domain '{body.domain}' already exists in this organisation",
|
||||
)
|
||||
|
||||
site = Site(
|
||||
organisation_id=current_user.organisation_id,
|
||||
domain=body.domain,
|
||||
display_name=body.display_name,
|
||||
site_group_id=body.site_group_id,
|
||||
)
|
||||
db.add(site)
|
||||
await db.flush()
|
||||
|
||||
# Auto-create a default site configuration
|
||||
default_config = SiteConfig(site_id=site.id)
|
||||
db.add(default_config)
|
||||
await db.flush()
|
||||
|
||||
await db.refresh(site)
|
||||
return site
|
||||
|
||||
|
||||
@router.get("/", response_model=list[SiteResponse])
|
||||
async def list_sites(
|
||||
site_group_id: uuid.UUID | None = Query(default=None),
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> list[Site]:
|
||||
"""List all active sites in the current organisation, optionally filtered by group."""
|
||||
query = select(Site).where(
|
||||
Site.organisation_id == current_user.organisation_id,
|
||||
Site.deleted_at.is_(None),
|
||||
)
|
||||
if site_group_id is not None:
|
||||
query = query.where(Site.site_group_id == site_group_id)
|
||||
result = await db.execute(query.order_by(Site.domain))
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@router.get("/{site_id}", response_model=SiteResponse)
|
||||
async def get_site(
|
||||
site_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Site:
|
||||
"""Get a specific site by ID."""
|
||||
site = await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
return site
|
||||
|
||||
|
||||
@router.patch("/{site_id}", response_model=SiteResponse)
|
||||
async def update_site(
|
||||
site_id: uuid.UUID,
|
||||
body: SiteUpdate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Site:
|
||||
"""Update a site's display name or active status."""
|
||||
site = await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(site, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(site)
|
||||
return site
|
||||
|
||||
|
||||
@router.delete("/{site_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def deactivate_site(
|
||||
site_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> None:
|
||||
"""Soft-delete a site."""
|
||||
site = await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
site.deleted_at = datetime.now(UTC)
|
||||
await db.flush()
|
||||
|
||||
|
||||
# ── Site config CRUD ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/{site_id}/config", response_model=SiteConfigResponse)
|
||||
async def get_site_config(
|
||||
site_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> SiteConfig:
|
||||
"""Get the configuration for a site."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
result = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
|
||||
config = result.scalar_one_or_none()
|
||||
if config is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Site configuration not found. Create one first.",
|
||||
)
|
||||
return config
|
||||
|
||||
|
||||
@router.put("/{site_id}/config", response_model=SiteConfigResponse)
|
||||
async def create_or_replace_site_config(
|
||||
site_id: uuid.UUID,
|
||||
body: SiteConfigCreate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> SiteConfig:
|
||||
"""Create or replace the full configuration for a site."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
|
||||
result = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing is not None:
|
||||
for field, value in body.model_dump().items():
|
||||
setattr(existing, field, value)
|
||||
await db.flush()
|
||||
await db.refresh(existing)
|
||||
return existing
|
||||
|
||||
config = SiteConfig(site_id=site_id, **body.model_dump())
|
||||
db.add(config)
|
||||
await db.flush()
|
||||
await db.refresh(config)
|
||||
return config
|
||||
|
||||
|
||||
@router.patch("/{site_id}/config", response_model=SiteConfigResponse)
|
||||
async def update_site_config(
|
||||
site_id: uuid.UUID,
|
||||
body: SiteConfigUpdate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> SiteConfig:
|
||||
"""Partially update the configuration for a site."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
|
||||
result = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
|
||||
config = result.scalar_one_or_none()
|
||||
if config is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Site configuration not found. Create one first.",
|
||||
)
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(config, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(config)
|
||||
return config
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_org_site(
    site_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> Site:
    """Fetch a site ensuring it belongs to the given organisation.

    Raises HTTP 404 when the site is missing, soft-deleted, or owned by a
    different organisation — foreign tenants see the same 404 as a miss.
    """
    row = await db.execute(
        select(Site).where(
            Site.id == site_id,
            Site.organisation_id == organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    if (site := row.scalar_one_or_none()) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site
|
||||
195
apps/api/src/routers/translations.py
Normal file
195
apps/api/src/routers/translations.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Translation management endpoints.
|
||||
|
||||
CRUD for per-site, per-locale translation strings used by the banner script.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
from src.models.translation import Translation
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.translation import TranslationCreate, TranslationResponse, TranslationUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/sites/{site_id}/translations", tags=["translations"])
|
||||
|
||||
|
||||
async def _get_org_site(site_id: uuid.UUID, organisation_id: uuid.UUID, db: AsyncSession) -> Site:
    """Ensure site belongs to the current organisation.

    Returns the Site row or raises HTTP 404 (missing, soft-deleted, or
    foreign-tenant sites are indistinguishable to the caller).
    """
    row = await db.execute(
        select(Site).where(
            Site.id == site_id,
            Site.organisation_id == organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    if (site := row.scalar_one_or_none()) is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site
|
||||
|
||||
|
||||
@router.get("/", response_model=list[TranslationResponse])
|
||||
async def list_translations(
|
||||
site_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> list[Translation]:
|
||||
"""List all translations for a site."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
result = await db.execute(
|
||||
select(Translation).where(Translation.site_id == site_id).order_by(Translation.locale)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@router.get("/{locale}", response_model=TranslationResponse)
|
||||
async def get_translation(
|
||||
site_id: uuid.UUID,
|
||||
locale: str,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Translation:
|
||||
"""Get translation strings for a specific locale."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
result = await db.execute(
|
||||
select(Translation).where(
|
||||
Translation.site_id == site_id,
|
||||
Translation.locale == locale,
|
||||
)
|
||||
)
|
||||
translation = result.scalar_one_or_none()
|
||||
if translation is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No translation found for locale '{locale}'",
|
||||
)
|
||||
return translation
|
||||
|
||||
|
||||
@router.post("/", response_model=TranslationResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_translation(
|
||||
site_id: uuid.UUID,
|
||||
body: TranslationCreate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Translation:
|
||||
"""Create a translation for a new locale."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
|
||||
# Check for duplicate locale
|
||||
existing = await db.execute(
|
||||
select(Translation).where(
|
||||
Translation.site_id == site_id,
|
||||
Translation.locale == body.locale,
|
||||
)
|
||||
)
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Translation for locale '{body.locale}' already exists",
|
||||
)
|
||||
|
||||
translation = Translation(
|
||||
site_id=site_id,
|
||||
locale=body.locale,
|
||||
strings=body.strings,
|
||||
)
|
||||
db.add(translation)
|
||||
await db.flush()
|
||||
await db.refresh(translation)
|
||||
return translation
|
||||
|
||||
|
||||
@router.put("/{locale}", response_model=TranslationResponse)
|
||||
async def update_translation(
|
||||
site_id: uuid.UUID,
|
||||
locale: str,
|
||||
body: TranslationUpdate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Translation:
|
||||
"""Replace the strings for an existing locale translation."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
result = await db.execute(
|
||||
select(Translation).where(
|
||||
Translation.site_id == site_id,
|
||||
Translation.locale == locale,
|
||||
)
|
||||
)
|
||||
translation = result.scalar_one_or_none()
|
||||
if translation is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No translation found for locale '{locale}'",
|
||||
)
|
||||
|
||||
translation.strings = body.strings
|
||||
await db.flush()
|
||||
await db.refresh(translation)
|
||||
return translation
|
||||
|
||||
|
||||
@router.delete("/{locale}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_translation(
|
||||
site_id: uuid.UUID,
|
||||
locale: str,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> None:
|
||||
"""Delete a translation for a specific locale."""
|
||||
await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
result = await db.execute(
|
||||
select(Translation).where(
|
||||
Translation.site_id == site_id,
|
||||
Translation.locale == locale,
|
||||
)
|
||||
)
|
||||
translation = result.scalar_one_or_none()
|
||||
if translation is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"No translation found for locale '{locale}'",
|
||||
)
|
||||
await db.delete(translation)
|
||||
await db.flush()
|
||||
|
||||
|
||||
# ── Public endpoint for the banner script ────────────────────────────
|
||||
|
||||
public_router = APIRouter(prefix="/translations", tags=["translations"])
|
||||
|
||||
|
||||
@public_router.get("/{site_id}/{locale}")
|
||||
async def get_public_translation(
|
||||
site_id: uuid.UUID,
|
||||
locale: str,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict[str, str]:
|
||||
"""Public endpoint: return translation strings for the banner script.
|
||||
|
||||
No auth required. Returns the raw strings dict for a given site and locale.
|
||||
Returns 404 if no translation exists (banner falls back to English defaults).
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(Translation)
|
||||
.join(Site)
|
||||
.where(
|
||||
Translation.site_id == site_id,
|
||||
Translation.locale == locale,
|
||||
Site.is_active.is_(True),
|
||||
Site.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
translation = result.scalar_one_or_none()
|
||||
if translation is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Translation not found",
|
||||
)
|
||||
return translation.strings
|
||||
136
apps/api/src/routers/users.py
Normal file
136
apps/api/src/routers/users.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.user import User
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.user import UserCreate, UserResponse, UserUpdate
|
||||
from src.services.auth import hash_password
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_user(
|
||||
body: UserCreate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""Invite/create a new user within the current organisation."""
|
||||
# Check email uniqueness
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"User with email '{body.email}' already exists",
|
||||
)
|
||||
|
||||
user = User(
|
||||
organisation_id=current_user.organisation_id,
|
||||
email=body.email,
|
||||
password_hash=hash_password(body.password),
|
||||
full_name=body.full_name,
|
||||
role=body.role,
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.get("/", response_model=list[UserResponse])
|
||||
async def list_users(
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> list[User]:
|
||||
"""List all active users in the current organisation."""
|
||||
result = await db.execute(
|
||||
select(User)
|
||||
.where(
|
||||
User.organisation_id == current_user.organisation_id,
|
||||
User.deleted_at.is_(None),
|
||||
)
|
||||
.order_by(User.created_at)
|
||||
)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
|
||||
async def get_user(
|
||||
user_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""Get a specific user by ID within the current organisation."""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.organisation_id == current_user.organisation_id,
|
||||
User.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
return user
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserResponse)
|
||||
async def update_user(
|
||||
user_id: uuid.UUID,
|
||||
body: UserUpdate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> User:
|
||||
"""Update a user's name or role. Requires owner or admin."""
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.organisation_id == current_user.organisation_id,
|
||||
User.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(user, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def deactivate_user(
|
||||
user_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> None:
|
||||
"""Soft-delete (deactivate) a user. Requires owner or admin."""
|
||||
if user_id == current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot deactivate yourself",
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(User).where(
|
||||
User.id == user_id,
|
||||
User.organisation_id == current_user.organisation_id,
|
||||
User.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
|
||||
|
||||
user.deleted_at = datetime.now(UTC)
|
||||
await db.flush()
|
||||
0
apps/api/src/schemas/__init__.py
Normal file
0
apps/api/src/schemas/__init__.py
Normal file
45
apps/api/src/schemas/auth.py
Normal file
45
apps/api/src/schemas/auth.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Credentials submitted to the login endpoint."""

    email: EmailStr
    password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
    """Body for exchanging a refresh token for a new token pair."""

    refresh_token: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
    """Token pair returned by login/refresh endpoints."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"  # OAuth2-style bearer scheme
    expires_in: int  # access-token lifetime in seconds
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
    """Decoded JWT claims carried by access and refresh tokens."""

    sub: str  # user ID
    org_id: str  # organisation ID
    role: str  # user role
    exp: datetime  # expiry timestamp
    iat: datetime  # issued-at timestamp
    type: str = "access"  # "access" or "refresh"
|
||||
|
||||
|
||||
class CurrentUser(BaseModel):
    """Represents the authenticated user extracted from a JWT."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    role: str

    def has_role(self, *roles: str) -> bool:
        """Return True when the user's role is one of *roles*."""
        return self.role in roles

    @property
    def is_admin(self) -> bool:
        """True for the two privileged roles ("owner" and "admin")."""
        return self.role in ("owner", "admin")
|
||||
56
apps/api/src/schemas/compliance.py
Normal file
56
apps/api/src/schemas/compliance.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Pydantic schemas for compliance check results."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Severity(StrEnum):
|
||||
CRITICAL = "critical"
|
||||
WARNING = "warning"
|
||||
INFO = "info"
|
||||
|
||||
|
||||
class Framework(StrEnum):
|
||||
GDPR = "gdpr"
|
||||
CNIL = "cnil"
|
||||
CCPA = "ccpa"
|
||||
EPRIVACY = "eprivacy"
|
||||
LGPD = "lgpd"
|
||||
|
||||
|
||||
class ComplianceIssue(BaseModel):
    """A single compliance issue found during a check."""

    rule_id: str  # stable identifier of the rule that fired
    severity: Severity
    message: str  # human-readable description of the finding
    recommendation: str  # suggested remediation
|
||||
|
||||
|
||||
class FrameworkResult(BaseModel):
    """Compliance result for a single regulatory framework."""

    framework: Framework
    score: int = Field(ge=0, le=100, description="Compliance score (0-100)")
    status: str = Field(description="Overall status: compliant, partial, non_compliant")
    issues: list[ComplianceIssue] = Field(default_factory=list)
    rules_checked: int = 0
    rules_passed: int = 0
|
||||
|
||||
|
||||
class ComplianceCheckRequest(BaseModel):
    """Request body for compliance checks."""

    frameworks: list[Framework] | None = Field(
        default=None,
        description="Frameworks to check. If null, all frameworks are checked.",
    )
|
||||
|
||||
|
||||
class ComplianceCheckResponse(BaseModel):
    """Full compliance check response for a site."""

    site_id: str
    results: list[FrameworkResult]  # one entry per framework checked
    overall_score: int = Field(ge=0, le=100, description="Weighted average across all frameworks")
|
||||
62
apps/api/src/schemas/consent.py
Normal file
62
apps/api/src/schemas/consent.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ConsentAction(StrEnum):
|
||||
ACCEPT_ALL = "accept_all"
|
||||
REJECT_ALL = "reject_all"
|
||||
CUSTOM = "custom"
|
||||
WITHDRAW = "withdraw"
|
||||
|
||||
|
||||
class ConsentRecordCreate(BaseModel):
    """Payload sent by the banner when a consent event occurs.

    Standards-specific fields (tc_string, gcm_state, gpp_string, gpc_*) are
    optional and populated only when the matching framework is enabled.
    """

    site_id: uuid.UUID
    visitor_id: str = Field(min_length=1, max_length=255)
    action: ConsentAction
    categories_accepted: list[str]
    categories_rejected: list[str] | None = None
    tc_string: str | None = None  # IAB TCF consent string
    gcm_state: dict | None = None  # Google Consent Mode state
    gpp_string: str | None = None  # IAB GPP string
    gpc_detected: bool | None = None  # Global Privacy Control signal seen
    gpc_honoured: bool | None = None  # whether the GPC signal was honoured
    page_url: str | None = None
    country_code: str | None = Field(default=None, max_length=5)
    region_code: str | None = Field(default=None, max_length=10)
|
||||
|
||||
|
||||
class ConsentRecordResponse(BaseModel):
    """Serialised consent record as stored, for API responses."""

    id: uuid.UUID
    site_id: uuid.UUID
    visitor_id: str
    action: str
    categories_accepted: list
    categories_rejected: list | None = None
    tc_string: str | None = None
    gcm_state: dict | None = None
    gpp_string: str | None = None
    gpc_detected: bool | None = None
    gpc_honoured: bool | None = None
    page_url: str | None = None
    country_code: str | None = None
    region_code: str | None = None
    consented_at: datetime

    # Populate directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ConsentVerifyResponse(BaseModel):
    """Audit proof that a consent record exists."""

    id: uuid.UUID
    site_id: uuid.UUID
    visitor_id: str
    action: str
    categories_accepted: list
    consented_at: datetime
    valid: bool = True  # always True when the record was found
|
||||
210
apps/api/src/schemas/cookie.py
Normal file
210
apps/api/src/schemas/cookie.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Pydantic schemas for cookie categories, cookies, and allow-list entries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ─── Cookie category schemas ───
|
||||
|
||||
|
||||
class CookieCategoryResponse(BaseModel):
    """Response schema for a cookie category."""

    id: uuid.UUID
    name: str
    slug: str
    description: str | None = None
    is_essential: bool  # essential categories cannot be rejected by visitors
    display_order: int  # sort position in the banner UI
    tcf_purpose_ids: list[int] | None = None  # mapped IAB TCF purposes
    gcm_consent_types: list[str] | None = None  # mapped Google Consent Mode types
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Storage type enum ───
|
||||
|
||||
|
||||
class StorageType(StrEnum):
|
||||
"""Type of browser storage used by the cookie/tracker."""
|
||||
|
||||
cookie = "cookie"
|
||||
local_storage = "local_storage"
|
||||
session_storage = "session_storage"
|
||||
indexed_db = "indexed_db"
|
||||
|
||||
|
||||
# ─── Review status enum ───
|
||||
|
||||
|
||||
class ReviewStatus(StrEnum):
|
||||
"""Review status for a discovered cookie."""
|
||||
|
||||
pending = "pending"
|
||||
approved = "approved"
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
# ─── Cookie schemas ───
|
||||
|
||||
|
||||
class CookieCreate(BaseModel):
    """Schema for creating a cookie record (typically from scanner/reporter)."""

    name: str = Field(..., min_length=1, max_length=255)
    domain: str = Field(..., min_length=1, max_length=255)
    storage_type: StorageType = StorageType.cookie
    category_id: uuid.UUID | None = None  # None = not yet classified
    description: str | None = None
    vendor: str | None = Field(None, max_length=255)
    path: str | None = Field(None, max_length=500)
    max_age_seconds: int | None = None  # None for session cookies
    is_http_only: bool | None = None
    is_secure: bool | None = None
    same_site: str | None = Field(None, max_length=10)
|
||||
|
||||
|
||||
class CookieUpdate(BaseModel):
    """Schema for updating a cookie record (all fields optional)."""

    category_id: uuid.UUID | None = None
    description: str | None = None
    vendor: str | None = Field(None, max_length=255)
    review_status: ReviewStatus | None = None
|
||||
|
||||
|
||||
class CookieResponse(BaseModel):
    """Response schema for a cookie."""

    id: uuid.UUID
    site_id: uuid.UUID
    category_id: uuid.UUID | None = None
    name: str
    domain: str
    storage_type: str
    description: str | None = None
    vendor: str | None = None
    path: str | None = None
    max_age_seconds: int | None = None
    is_http_only: bool | None = None
    is_secure: bool | None = None
    same_site: str | None = None
    review_status: str
    # NOTE(review): these two are typed str while created_at/updated_at are
    # datetime — presumably the model stores them pre-serialised; confirm
    # against the Cookie ORM model before tightening the type.
    first_seen_at: str | None = None
    last_seen_at: str | None = None
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Allow-list schemas ───
|
||||
|
||||
|
||||
class AllowListEntryCreate(BaseModel):
    """Schema for adding a cookie to the allow-list."""

    name_pattern: str = Field(..., min_length=1, max_length=255)
    domain_pattern: str = Field(..., min_length=1, max_length=255)
    category_id: uuid.UUID  # category entries on this list are classified into
    description: str | None = None
|
||||
|
||||
|
||||
class AllowListEntryUpdate(BaseModel):
    """Schema for updating an allow-list entry (all fields optional)."""

    category_id: uuid.UUID | None = None
    description: str | None = None
|
||||
|
||||
|
||||
class AllowListEntryResponse(BaseModel):
    """Response schema for an allow-list entry."""

    id: uuid.UUID
    site_id: uuid.UUID
    category_id: uuid.UUID
    name_pattern: str
    domain_pattern: str
    description: str | None = None
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Known cookie schemas ───
|
||||
|
||||
|
||||
class KnownCookieCreate(BaseModel):
    """Schema for creating a known cookie pattern."""

    name_pattern: str = Field(..., min_length=1, max_length=255)
    domain_pattern: str = Field(..., min_length=1, max_length=255)
    category_id: uuid.UUID
    vendor: str | None = Field(None, max_length=255)
    description: str | None = None
    is_regex: bool = False  # when True, patterns are regular expressions
|
||||
|
||||
|
||||
class KnownCookieUpdate(BaseModel):
    """Schema for updating a known cookie pattern (all fields optional)."""

    category_id: uuid.UUID | None = None
    vendor: str | None = Field(None, max_length=255)
    description: str | None = None
    is_regex: bool | None = None
|
||||
|
||||
|
||||
class KnownCookieResponse(BaseModel):
    """Response schema for a known cookie pattern."""

    id: uuid.UUID
    name_pattern: str
    domain_pattern: str
    category_id: uuid.UUID
    vendor: str | None = None
    description: str | None = None
    is_regex: bool
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Classification schemas ───
|
||||
|
||||
|
||||
class ClassificationResultResponse(BaseModel):
    """Response for a single cookie classification result."""

    cookie_name: str
    cookie_domain: str
    category_id: uuid.UUID | None = None  # None when no pattern matched
    category_slug: str | None = None
    vendor: str | None = None
    description: str | None = None
    match_source: str  # which pattern store produced the match
    matched: bool
|
||||
|
||||
|
||||
class ClassifySiteResponse(BaseModel):
    """Response for classifying all cookies on a site."""

    site_id: str
    total: int  # total cookies considered
    matched: int
    unmatched: int
    results: list[ClassificationResultResponse]
|
||||
|
||||
|
||||
class ClassifySingleRequest(BaseModel):
    """Request to classify a single cookie (preview/test)."""

    cookie_name: str = Field(..., min_length=1, max_length=255)
    cookie_domain: str = Field(..., min_length=1, max_length=255)
|
||||
61
apps/api/src/schemas/org_config.py
Normal file
61
apps/api/src/schemas/org_config.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.schemas.site import BlockingMode
|
||||
|
||||
|
||||
class OrgConfigUpdate(BaseModel):
    """Update (or create) organisation-level default configuration.

    All fields are optional — only non-None values override the system defaults.
    """

    # Blocking behaviour
    blocking_mode: BlockingMode | None = None
    # Per-region override of blocking mode (free-form JSON mapping).
    regional_modes: dict | None = None
    # Standards toggles: IAB TCF, GPP, GPC, Google Consent Mode, Shopify.
    tcf_enabled: bool | None = None
    # ISO 3166-1 alpha-2 publisher country code.
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    # Banner & legal links
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    # Scanner settings
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    # Consent lifecycle (max two years)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class OrgConfigResponse(BaseModel):
    """Organisation-level configuration as stored.

    Every field is nullable — None means "not set at this level", i.e. the
    value cascades from the system defaults.
    """

    id: uuid.UUID
    organisation_id: uuid.UUID
    blocking_mode: str | None
    regional_modes: dict | None
    tcf_enabled: bool | None
    tcf_publisher_cc: str | None
    gpp_enabled: bool | None
    gpp_supported_apis: list[str] | None
    gpc_enabled: bool | None
    gpc_jurisdictions: list[str] | None
    gpc_global_honour: bool | None
    gcm_enabled: bool | None
    gcm_default: dict | None
    shopify_privacy_enabled: bool | None
    banner_config: dict | None
    privacy_policy_url: str | None
    terms_url: str | None
    scan_schedule_cron: str | None
    scan_max_pages: int | None
    consent_expiry_days: int | None
    consent_retention_days: int | None
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
29
apps/api/src/schemas/organisation.py
Normal file
29
apps/api/src/schemas/organisation.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OrganisationCreate(BaseModel):
    """Payload for creating an organisation."""

    name: str = Field(min_length=1, max_length=255)
    # URL-safe identifier: lowercase letters, digits, and hyphens only.
    slug: str = Field(min_length=1, max_length=100, pattern=r"^[a-z0-9-]+$")
    contact_email: str | None = None
    billing_plan: str = "free"
|
||||
|
||||
|
||||
class OrganisationUpdate(BaseModel):
    """Partial update of an organisation; the slug is immutable here."""

    name: str | None = Field(default=None, min_length=1, max_length=255)
    contact_email: str | None = None
    billing_plan: str | None = None
|
||||
|
||||
|
||||
class OrganisationResponse(BaseModel):
    """Response schema for an organisation."""

    id: uuid.UUID
    name: str
    slug: str
    contact_email: str | None
    billing_plan: str
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
142
apps/api/src/schemas/scanner.py
Normal file
142
apps/api/src/schemas/scanner.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Pydantic schemas for scanner and client-side cookie reports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ScanStatus(StrEnum):
    """Lifecycle states of a scan job."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
|
||||
|
||||
|
||||
class ScanTrigger(StrEnum):
    """What caused a scan job to be created."""

    MANUAL = "manual"
    SCHEDULED = "scheduled"
    CLIENT_REPORT = "client_report"
|
||||
|
||||
|
||||
# ── Client-side cookie report ────────────────────────────────────────
|
||||
|
||||
|
||||
class ReportedCookie(BaseModel):
    """A single cookie/storage item reported by the client-side reporter."""

    name: str = Field(..., min_length=1, max_length=255)
    domain: str = Field(..., min_length=1, max_length=255)
    # e.g. "cookie"; other values presumably cover localStorage etc. — confirm.
    storage_type: str = Field(default="cookie", max_length=30)
    # Only the length is reported, never the value itself (privacy).
    value_length: int = Field(default=0, ge=0)
    path: str | None = None
    is_secure: bool | None = None
    same_site: str | None = None
    # Script URL that set the item, when the reporter could attribute it.
    script_source: str | None = None
|
||||
|
||||
|
||||
class CookieReportRequest(BaseModel):
    """Payload from the client-side cookie reporter."""

    site_id: uuid.UUID
    page_url: str = Field(..., max_length=2000)
    # Hard cap of 500 items per report to bound payload size.
    cookies: list[ReportedCookie] = Field(..., max_length=500)
    collected_at: datetime
    user_agent: str = Field(default="", max_length=500)
|
||||
|
||||
|
||||
class CookieReportResponse(BaseModel):
    """Acknowledgement response for a cookie report."""

    accepted: bool = True
    # Count of items received in the request.
    cookies_received: int
    # Count of items not previously known for the site.
    new_cookies: int = 0
|
||||
|
||||
|
||||
# ── Scan job schemas ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ScanResultResponse(BaseModel):
    """A single scan result — a cookie found on a specific page."""

    id: uuid.UUID
    scan_job_id: uuid.UUID
    page_url: str
    cookie_name: str
    cookie_domain: str
    storage_type: str
    # Raw cookie attributes as captured by the scanner (free-form JSON).
    attributes: dict | None = None
    script_source: str | None = None
    # Category assigned automatically by the classifier, if any.
    auto_category: str | None = None
    # Chain of script URLs that led to the cookie being set.
    initiator_chain: list[str] | None = None
    found_at: datetime
    created_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ScanJobResponse(BaseModel):
    """Response schema for a scan job."""

    id: uuid.UUID
    site_id: uuid.UUID
    # See ScanStatus for the value set.
    status: str
    # See ScanTrigger for the value set.
    trigger: str
    pages_scanned: int
    # None while the page count is not yet known.
    pages_total: int | None
    cookies_found: int
    error_message: str | None
    started_at: datetime | None
    completed_at: datetime | None
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ScanJobDetailResponse(ScanJobResponse):
    """Scan job with results included."""

    # Pydantic deep-copies mutable defaults, so the shared [] is safe here.
    results: list[ScanResultResponse] = []
|
||||
|
||||
|
||||
class TriggerScanRequest(BaseModel):
    """Request to trigger a new scan."""

    site_id: uuid.UUID
    # Per-request crawl cap (1–500), defaulting to 50 pages.
    max_pages: int = Field(default=50, ge=1, le=500)
|
||||
|
||||
|
||||
# ── Diff engine schemas ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class DiffStatus(StrEnum):
    """Kind of change detected between two scans."""

    NEW = "new"
    REMOVED = "removed"
    CHANGED = "changed"
|
||||
|
||||
|
||||
class CookieDiffItem(BaseModel):
    """A single cookie difference between two scans."""

    name: str
    domain: str
    storage_type: str
    diff_status: DiffStatus
    # Human-readable explanation of what changed, if available.
    details: str | None = None
|
||||
|
||||
|
||||
class ScanDiffResponse(BaseModel):
    """Diff between two scans."""

    current_scan_id: uuid.UUID
    # None when there is no earlier scan to compare against.
    previous_scan_id: uuid.UUID | None
    new_cookies: list[CookieDiffItem] = []
    removed_cookies: list[CookieDiffItem] = []
    changed_cookies: list[CookieDiffItem] = []
    # Convenience counters mirroring the list lengths above.
    total_new: int = 0
    total_removed: int = 0
    total_changed: int = 0
|
||||
117
apps/api/src/schemas/site.py
Normal file
117
apps/api/src/schemas/site.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BlockingMode(StrEnum):
    """How non-essential cookies are treated before consent is given."""

    OPT_IN = "opt_in"
    OPT_OUT = "opt_out"
    INFORMATIONAL = "informational"
|
||||
|
||||
|
||||
# ── Site schemas ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class SiteCreate(BaseModel):
    """Payload for registering a site."""

    domain: str = Field(min_length=1, max_length=255)
    display_name: str = Field(min_length=1, max_length=255)
    # Extra domains served by the same site (e.g. aliases).
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class SiteUpdate(BaseModel):
    """Partial update of a site; the primary domain is immutable here."""

    display_name: str | None = Field(default=None, min_length=1, max_length=255)
    is_active: bool | None = None
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class SiteResponse(BaseModel):
    """Response schema for a site."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    domain: str
    display_name: str
    is_active: bool
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── Site config schemas ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class SiteConfigCreate(BaseModel):
    """Create site-level configuration with concrete (non-cascading) defaults."""

    # Blocking behaviour — opt-in (block until consent) by default.
    blocking_mode: BlockingMode = BlockingMode.OPT_IN
    regional_modes: dict | None = None
    # Standards toggles: TCF off by default; GPP/GPC/GCM on by default.
    tcf_enabled: bool = False
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool = True
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool = True
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool = False
    gcm_enabled: bool = True
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool = False
    # Banner & legal links
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    # Scanner settings
    scan_schedule_cron: str | None = None
    scan_max_pages: int = Field(default=50, ge=1, le=1000)
    # Consent lifecycle: 1 year default expiry, max two years.
    consent_expiry_days: int = Field(default=365, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class SiteConfigUpdate(BaseModel):
    """Partial update of site-level configuration.

    All fields optional — presumably only non-None values are applied;
    confirm against the service layer.
    """

    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None
    tcf_enabled: bool | None = None
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class SiteConfigResponse(BaseModel):
    """Resolved site-level configuration."""

    id: uuid.UUID
    site_id: uuid.UUID
    blocking_mode: str
    regional_modes: dict | None
    tcf_enabled: bool
    tcf_publisher_cc: str | None = None
    gpp_enabled: bool = True
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool = True
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool = False
    gcm_enabled: bool
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool = False
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int = 50
    consent_expiry_days: int = 365
    consent_retention_days: int | None = None
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
26
apps/api/src/schemas/site_group.py
Normal file
26
apps/api/src/schemas/site_group.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SiteGroupCreate(BaseModel):
    """Payload for creating a site group."""

    name: str = Field(min_length=1, max_length=255)
    description: str | None = None
|
||||
|
||||
|
||||
class SiteGroupUpdate(BaseModel):
    """Partial update of a site group."""

    name: str | None = Field(default=None, min_length=1, max_length=255)
    description: str | None = None
|
||||
|
||||
|
||||
class SiteGroupResponse(BaseModel):
    """Response schema for a site group."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    name: str
    description: str | None
    created_at: datetime
    updated_at: datetime
    # Number of sites in the group; defaults to 0 when not computed.
    site_count: int = 0

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
59
apps/api/src/schemas/site_group_config.py
Normal file
59
apps/api/src/schemas/site_group_config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.schemas.site import BlockingMode
|
||||
|
||||
|
||||
class SiteGroupConfigUpdate(BaseModel):
    """Update (or create) site-group-level default configuration.

    All fields are optional — only non-None values override the org/system defaults.
    """

    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None
    tcf_enabled: bool | None = None
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    # NOTE(review): unlike OrgConfigUpdate, consent_retention_days is not
    # overridable at group level — confirm this is deliberate.
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class SiteGroupConfigResponse(BaseModel):
    """Site-group-level configuration as stored.

    Every field is nullable — None means "not set at this level" and the
    value cascades from the organisation/system defaults.
    """

    id: uuid.UUID
    site_group_id: uuid.UUID
    blocking_mode: str | None
    regional_modes: dict | None
    tcf_enabled: bool | None
    tcf_publisher_cc: str | None
    gpp_enabled: bool | None
    gpp_supported_apis: list[str] | None
    gpc_enabled: bool | None
    gpc_jurisdictions: list[str] | None
    gpc_global_honour: bool | None
    gcm_enabled: bool | None
    gcm_default: dict | None
    shopify_privacy_enabled: bool | None
    banner_config: dict | None
    privacy_policy_url: str | None
    terms_url: str | None
    scan_schedule_cron: str | None
    scan_max_pages: int | None
    consent_expiry_days: int | None
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
24
apps/api/src/schemas/translation.py
Normal file
24
apps/api/src/schemas/translation.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TranslationCreate(BaseModel):
    """Payload for creating a banner translation for one locale."""

    # BCP-47-style locale tag (e.g. "en", "pt-BR") — presumably; confirm.
    locale: str = Field(min_length=2, max_length=10)
    # Mapping of string key → translated text.
    strings: dict[str, str]
|
||||
|
||||
|
||||
class TranslationUpdate(BaseModel):
    """Replace the string table of an existing translation."""

    strings: dict[str, str]
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
    """Response schema for a banner translation."""

    id: uuid.UUID
    site_id: uuid.UUID
    locale: str
    strings: dict[str, str]
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
36
apps/api/src/schemas/user.py
Normal file
36
apps/api/src/schemas/user.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class UserRole(StrEnum):
    """Role-based access levels, from most to least privileged."""

    OWNER = "owner"
    ADMIN = "admin"
    EDITOR = "editor"
    VIEWER = "viewer"
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
    """Payload for creating a user."""

    email: EmailStr
    # 72-char cap — presumably aligned with bcrypt's 72-*byte* input limit;
    # note multi-byte characters can still exceed 72 bytes. TODO confirm.
    password: str = Field(min_length=8, max_length=72)
    full_name: str = Field(min_length=1, max_length=255)
    # Least-privileged role by default.
    role: UserRole = UserRole.VIEWER
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
    """Partial update of a user; email and password change elsewhere."""

    full_name: str | None = Field(default=None, min_length=1, max_length=255)
    role: UserRole | None = None
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
    """Response schema for a user — never exposes the password hash."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    full_name: str
    role: str
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from SQLAlchemy objects.
    model_config = {"from_attributes": True}
|
||||
0
apps/api/src/services/__init__.py
Normal file
0
apps/api/src/services/__init__.py
Normal file
59
apps/api/src/services/auth.py
Normal file
59
apps/api/src/services/auth.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import bcrypt
|
||||
from jose import jwt
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt."""
    # NOTE(review): bcrypt silently ignores input beyond 72 bytes —
    # presumably guarded by the 72-char cap on UserCreate.password; confirm.
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode(), salt)
    return digest.decode()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check *plain_password* against a stored bcrypt hash."""
    candidate = plain_password.encode()
    stored = hashed_password.encode()
    return bcrypt.checkpw(candidate, stored)
|
||||
|
||||
|
||||
def create_access_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
    role: str,
    email: str,
) -> str:
    """Issue a short-lived signed JWT access token.

    Embeds the user, organisation, role, and email as claims; expiry is
    taken from ``jwt_access_token_expire_minutes`` in settings.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    lifetime = timedelta(minutes=settings.jwt_access_token_expire_minutes)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "role": role,
        "email": email,
        "exp": issued_at + lifetime,
        "iat": issued_at,
        # Distinguishes access tokens from refresh tokens on decode.
        "type": "access",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def create_refresh_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
) -> str:
    """Issue a long-lived signed JWT refresh token.

    Carries only identity claims (no role/email); expiry is taken from
    ``jwt_refresh_token_expire_days`` in settings.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    lifetime = timedelta(days=settings.jwt_refresh_token_expire_days)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "exp": issued_at + lifetime,
        "iat": issued_at,
        # Distinguishes refresh tokens from access tokens on decode.
        "type": "refresh",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
    """Decode and validate a JWT token. Raises JWTError on failure."""
    cfg = get_settings()
    return jwt.decode(
        token,
        cfg.jwt_secret_key,
        algorithms=[cfg.jwt_algorithm],
    )
|
||||
79
apps/api/src/services/bootstrap.py
Normal file
79
apps/api/src/services/bootstrap.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""First-run bootstrap of an organisation and owner user.
|
||||
|
||||
Runs once on API startup. If ``INITIAL_ADMIN_EMAIL`` and
|
||||
``INITIAL_ADMIN_PASSWORD`` are set and the ``users`` table is empty,
|
||||
creates an organisation and a single owner user so the operator can log
|
||||
in to the admin UI for the first time. Idempotent: once any user
|
||||
exists, this is a no-op, so the environment variables can safely remain
|
||||
set across restarts. Complements ``ADMIN_BOOTSTRAP_TOKEN`` — that gates
|
||||
runtime org creation; this creates the *initial* org + owner without
|
||||
requiring a second round-trip.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.db.session import async_session_factory
|
||||
from src.models.organisation import Organisation
|
||||
from src.models.user import User
|
||||
from src.services.auth import hash_password
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def bootstrap_initial_admin(settings: Settings) -> None:
    """Create the first organisation and owner user if none exist.

    No-op when either credential env var is unset or when the database
    already contains at least one user. Unexpected errors are logged and
    swallowed — a failed bootstrap must not prevent the API from
    starting; operators can always fall back to manual provisioning.
    """
    have_credentials = bool(
        settings.initial_admin_email and settings.initial_admin_password
    )
    if not have_credentials:
        logger.debug("Initial admin bootstrap skipped: credentials not configured")
        return

    try:
        async with async_session_factory() as session:
            await _bootstrap(session, settings)
    except Exception:  # pragma: no cover — defensive, logged
        logger.exception("Initial admin bootstrap failed")
|
||||
|
||||
|
||||
async def _bootstrap(session: AsyncSession, settings: Settings) -> None:
    """Do the actual first-run provisioning inside an open session.

    Idempotence: bails out if any user already exists; reuses an existing
    organisation with the configured slug rather than creating a duplicate.
    """
    # Guard: any existing user means bootstrap has already happened.
    existing_users = await session.scalar(select(func.count()).select_from(User))
    if existing_users:
        logger.debug("Initial admin bootstrap skipped: %d user(s) already exist", existing_users)
        return

    # Reuse the organisation if one with the configured slug already exists.
    org = await session.scalar(
        select(Organisation).where(Organisation.slug == settings.initial_org_slug)
    )
    if org is None:
        org = Organisation(
            name=settings.initial_org_name,
            slug=settings.initial_org_slug,
            contact_email=settings.initial_admin_email,
        )
        session.add(org)
        # Flush so org.id is populated before creating the user.
        await session.flush()

    user = User(
        organisation_id=org.id,
        email=settings.initial_admin_email,
        password_hash=hash_password(settings.initial_admin_password),
        full_name=settings.initial_admin_full_name,
        role="owner",
    )
    session.add(user)
    await session.commit()

    # Warning level on purpose: this should be visible in default logs.
    logger.warning(
        "Initial admin bootstrap created owner %s in organisation '%s'. "
        "Rotate the password via the admin UI as soon as possible.",
        settings.initial_admin_email,
        org.slug,
    )
|
||||
298
apps/api/src/services/classification.py
Normal file
298
apps/api/src/services/classification.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Cookie auto-categorisation engine.
|
||||
|
||||
Matches discovered cookies against the known_cookies database using exact name
|
||||
matching, domain matching, and regex patterns. Also checks site-specific
|
||||
allow-list entries. Returns a classification result with category, vendor, and
|
||||
confidence level.
|
||||
|
||||
Matching priority (highest first):
|
||||
1. Site-specific allow-list (exact or pattern match)
|
||||
2. Known cookies — exact name + domain match
|
||||
3. Known cookies — regex pattern match on name + domain
|
||||
4. Unmatched → remains as 'pending'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.cookie import (
|
||||
Cookie,
|
||||
CookieAllowListEntry,
|
||||
CookieCategory,
|
||||
KnownCookie,
|
||||
)
|
||||
|
||||
|
||||
class MatchSource(StrEnum):
    """Where the classification match came from.

    Listed in matching-priority order (highest first).
    """

    ALLOW_LIST = "allow_list"
    KNOWN_EXACT = "known_exact"
    KNOWN_REGEX = "known_regex"
    UNMATCHED = "unmatched"
|
||||
|
||||
|
||||
@dataclass
class ClassificationResult:
    """Result of classifying a single cookie."""

    cookie_name: str
    cookie_domain: str
    # All of the following stay at their defaults for an unmatched cookie.
    category_id: uuid.UUID | None = None
    category_slug: str | None = None
    vendor: str | None = None
    description: str | None = None
    match_source: MatchSource = MatchSource.UNMATCHED
    matched: bool = False
|
||||
|
||||
|
||||
async def _load_allow_list(
    db: AsyncSession,
    site_id: uuid.UUID,
) -> list[CookieAllowListEntry]:
    """Fetch all allow-list entries belonging to *site_id*."""
    stmt = select(CookieAllowListEntry).where(
        CookieAllowListEntry.site_id == site_id,
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
async def _load_known_cookies(
    db: AsyncSession,
) -> tuple[list[KnownCookie], list[KnownCookie]]:
    """Load all known cookies, returned as (exact, regex) lists."""
    rows = await db.execute(select(KnownCookie))

    exact: list[KnownCookie] = []
    regex: list[KnownCookie] = []
    for known in rows.scalars().all():
        # Route each pattern to the matcher that will handle it.
        (regex if known.is_regex else exact).append(known)
    return exact, regex
|
||||
|
||||
|
||||
async def _load_category_map(
    db: AsyncSession,
) -> dict[uuid.UUID, CookieCategory]:
    """Build a lookup of category ID → CookieCategory."""
    rows = await db.execute(select(CookieCategory))
    mapping: dict[uuid.UUID, CookieCategory] = {}
    for category in rows.scalars().all():
        mapping[category.id] = category
    return mapping
|
||||
|
||||
|
||||
def _match_pattern(pattern: str, value: str) -> bool:
|
||||
"""Check if a value matches a pattern (case-insensitive).
|
||||
|
||||
Patterns support:
|
||||
- Exact match (e.g. "_ga")
|
||||
- Wildcard with * (e.g. "_ga*", "*.google.com")
|
||||
- Regex if it contains regex-specific characters
|
||||
"""
|
||||
if not pattern or not value:
|
||||
return False
|
||||
|
||||
pattern_lower = pattern.lower()
|
||||
value_lower = value.lower()
|
||||
|
||||
# Simple exact match
|
||||
if pattern_lower == value_lower:
|
||||
return True
|
||||
|
||||
# Wildcard: convert * to regex .*
|
||||
if "*" in pattern_lower:
|
||||
regex_pattern = "^" + re.escape(pattern_lower).replace(r"\*", ".*") + "$"
|
||||
return bool(re.match(regex_pattern, value_lower))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _match_regex(pattern: str, value: str) -> bool:
|
||||
"""Match a value against a regex pattern (case-insensitive)."""
|
||||
try:
|
||||
return bool(re.match(pattern, value, re.IGNORECASE))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
|
||||
def _match_allow_list(
    cookie_name: str,
    cookie_domain: str,
    allow_list: list[CookieAllowListEntry],
) -> CookieAllowListEntry | None:
    """Return the first allow-list entry matching the cookie, or None."""
    return next(
        (
            entry
            for entry in allow_list
            if _match_pattern(entry.name_pattern, cookie_name)
            and _match_pattern(entry.domain_pattern, cookie_domain)
        ),
        None,
    )
|
||||
|
||||
|
||||
def _match_exact_known(
    cookie_name: str,
    cookie_domain: str,
    exact_known: list[KnownCookie],
) -> KnownCookie | None:
    """Return the first non-regex known-cookie entry matching the cookie.

    "Exact" here means exact-or-wildcard: patterns go through
    ``_match_pattern``, which also understands ``*`` wildcards.
    """
    return next(
        (
            known
            for known in exact_known
            if _match_pattern(known.name_pattern, cookie_name)
            and _match_pattern(known.domain_pattern, cookie_domain)
        ),
        None,
    )
|
||||
|
||||
|
||||
def _match_regex_known(
    cookie_name: str,
    cookie_domain: str,
    regex_known: list[KnownCookie],
) -> KnownCookie | None:
    """Return the first regex known-cookie entry matching the cookie."""
    return next(
        (
            known
            for known in regex_known
            if _match_regex(known.name_pattern, cookie_name)
            and _match_regex(known.domain_pattern, cookie_domain)
        ),
        None,
    )
|
||||
|
||||
|
||||
def classify_cookie(
    cookie_name: str,
    cookie_domain: str,
    allow_list: list[CookieAllowListEntry],
    exact_known: list[KnownCookie],
    regex_known: list[KnownCookie],
    category_map: dict[uuid.UUID, CookieCategory],
) -> ClassificationResult:
    """Classify a single cookie against allow-list and known cookies DB.

    This is a pure function — all data is passed in, no DB calls.

    Matching priority (highest first):
      1. Site-specific allow-list
      2. Known cookies — exact/wildcard match
      3. Known cookies — regex match
      4. Unmatched (result keeps its defaults)
    """

    def _hit(
        category_id: uuid.UUID,
        source: MatchSource,
        vendor: str | None,
        description: str | None,
    ) -> ClassificationResult:
        # Shared result construction for all three match branches; a
        # dangling category_id simply yields category_slug=None.
        cat = category_map.get(category_id)
        return ClassificationResult(
            cookie_name=cookie_name,
            cookie_domain=cookie_domain,
            category_id=category_id,
            category_slug=cat.slug if cat else None,
            vendor=vendor,
            description=description,
            match_source=source,
            matched=True,
        )

    # 1. Check allow-list first (site-specific overrides).
    # Allow-list entries carry no vendor, hence vendor=None.
    allow_match = _match_allow_list(cookie_name, cookie_domain, allow_list)
    if allow_match:
        return _hit(
            allow_match.category_id,
            MatchSource.ALLOW_LIST,
            None,
            allow_match.description,
        )

    # 2. Check exact known cookies.
    exact_match = _match_exact_known(cookie_name, cookie_domain, exact_known)
    if exact_match:
        return _hit(
            exact_match.category_id,
            MatchSource.KNOWN_EXACT,
            exact_match.vendor,
            exact_match.description,
        )

    # 3. Check regex known cookies.
    regex_match = _match_regex_known(cookie_name, cookie_domain, regex_known)
    if regex_match:
        return _hit(
            regex_match.category_id,
            MatchSource.KNOWN_REGEX,
            regex_match.vendor,
            regex_match.description,
        )

    # 4. Unmatched — defaults: match_source=UNMATCHED, matched=False.
    return ClassificationResult(
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
    )
|
||||
|
||||
|
||||
async def classify_site_cookies(
    db: AsyncSession,
    site_id: uuid.UUID,
    *,
    only_pending: bool = True,
) -> list[ClassificationResult]:
    """Classify all cookies for a site against known patterns.

    If only_pending is True, only cookies with review_status='pending'
    and no category are classified.

    Returns a list of results (one per examined cookie, matched or not).
    Also updates matching cookies in the DB: assigns category_id, and
    fills vendor/description only when the cookie has none yet. Changes
    are flushed but NOT committed — the caller owns the transaction.
    """
    # Load lookup data once, up front — shared across all cookies.
    allow_list = await _load_allow_list(db, site_id)
    exact_known, regex_known = await _load_known_cookies(db)
    category_map = await _load_category_map(db)

    # Load cookies to classify
    query = select(Cookie).where(Cookie.site_id == site_id)
    if only_pending:
        query = query.where(
            Cookie.review_status == "pending",
            Cookie.category_id.is_(None),
        )
    result = await db.execute(query)
    cookies = list(result.scalars().all())

    results: list[ClassificationResult] = []
    for cookie in cookies:
        cr = classify_cookie(
            cookie.name,
            cookie.domain,
            allow_list,
            exact_known,
            regex_known,
            category_map,
        )
        results.append(cr)

        # Update the cookie if matched; never overwrite an existing
        # vendor/description set by a human reviewer.
        if cr.matched and cr.category_id:
            cookie.category_id = cr.category_id
            if cr.vendor and not cookie.vendor:
                cookie.vendor = cr.vendor
            if cr.description and not cookie.description:
                cookie.description = cr.description

    # Flush (not commit) so changes are visible within the transaction.
    await db.flush()
    return results
|
||||
|
||||
|
||||
async def classify_single_cookie(
    db: AsyncSession,
    site_id: uuid.UUID,
    cookie_name: str,
    cookie_domain: str,
) -> ClassificationResult:
    """Classify one cookie on demand (e.g. for preview/testing)."""
    # Same lookup-data loading order as classify_site_cookies.
    allow_entries = await _load_allow_list(db, site_id)
    exact_patterns, regex_patterns = await _load_known_cookies(db)
    categories = await _load_category_map(db)

    return classify_cookie(
        cookie_name,
        cookie_domain,
        allow_entries,
        exact_patterns,
        regex_patterns,
        categories,
    )
|
||||
482
apps/api/src/services/compliance.py
Normal file
482
apps/api/src/services/compliance.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""Pluggable compliance rule engine.
|
||||
|
||||
Each regulatory framework (GDPR, CNIL, CCPA, ePrivacy, LGPD) is defined as a
|
||||
list of ComplianceRule objects. Rules evaluate site configuration, banner
|
||||
settings, cookie data, and consent parameters to produce issues with severity,
|
||||
message, and recommendation.
|
||||
|
||||
The engine aggregates individual rule results into per-framework reports with
|
||||
a compliance score, status, and actionable issues list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.schemas.compliance import (
|
||||
ComplianceIssue,
|
||||
Framework,
|
||||
FrameworkResult,
|
||||
Severity,
|
||||
)
|
||||
|
||||
# ── Rule context ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class SiteContext:
    """All data needed to evaluate compliance rules for a site.

    A plain snapshot: the caller gathers site config, banner settings,
    cookie statistics, and consent-UI facts into one object so the rule
    functions below can stay pure functions of this context.
    """

    # Site config fields
    blocking_mode: str = "opt_in"  # "opt_in" / "opt_out" / "informational"
    regional_modes: dict[str, str] | None = None  # region code -> blocking mode
    tcf_enabled: bool = False
    gcm_enabled: bool = True
    consent_expiry_days: int = 365  # checked against CNIL's 182/395-day limits
    privacy_policy_url: str | None = None

    # Banner config (JSONB — may have any keys)
    banner_config: dict[str, Any] | None = None

    # Cookie statistics
    total_cookies: int = 0
    uncategorised_cookies: int = 0
    cookies_without_expiry: int = 0

    # Consent settings (facts about the consent UI, used by several rules)
    has_reject_button: bool = True
    has_granular_choices: bool = True
    has_cookie_wall: bool = False
    pre_ticked_boxes: bool = False
|
||||
|
||||
|
||||
# ── Rule definition ───────────────────────────────────────────────────
|
||||
|
||||
# A check function receives a SiteContext and returns a list of issues.
# An empty list means the site passes that rule.
CheckFn = Callable[[SiteContext], list[ComplianceIssue]]
|
||||
|
||||
|
||||
@dataclass
class ComplianceRule:
    """A single compliance rule with an ID, description, and check function."""

    # Stable identifier; the registries below keep it in sync with the
    # rule_id each check passes to _issue().
    rule_id: str
    # Human-readable summary of what the rule enforces.
    description: str
    # Evaluates a SiteContext; an empty result means the rule passed.
    check: CheckFn
|
||||
|
||||
|
||||
# ── Helper factories ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _issue(
    rule_id: str,
    severity: Severity,
    message: str,
    recommendation: str,
) -> ComplianceIssue:
    """Shorthand constructor for a ComplianceIssue."""
    fields = {
        "rule_id": rule_id,
        "severity": severity,
        "message": message,
        "recommendation": recommendation,
    }
    return ComplianceIssue(**fields)
|
||||
|
||||
|
||||
# ── GDPR rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _gdpr_opt_in(ctx: SiteContext) -> list[ComplianceIssue]:
    """Flag sites that set non-essential cookies without opt-in consent."""
    if ctx.blocking_mode == "opt_in":
        return []
    issue = _issue(
        "gdpr_opt_in",
        Severity.CRITICAL,
        "GDPR requires opt-in consent before setting non-essential cookies.",
        "Set blocking mode to 'opt_in'.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _gdpr_reject_button(ctx: SiteContext) -> list[ComplianceIssue]:
    """Flag banners whose first layer has no reject option."""
    if ctx.has_reject_button:
        return []
    issue = _issue(
        "gdpr_reject_button",
        Severity.CRITICAL,
        "The reject option must be as prominent as the accept option.",
        "Add a clearly visible 'Reject all' button to the first layer.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _gdpr_granular_consent(ctx: SiteContext) -> list[ComplianceIssue]:
    """Flag banners that offer no per-category consent choices."""
    if ctx.has_granular_choices:
        return []
    issue = _issue(
        "gdpr_granular",
        Severity.CRITICAL,
        "Users must be able to consent to individual cookie categories.",
        "Provide granular category toggles in the consent banner.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _gdpr_no_cookie_wall(ctx: SiteContext) -> list[ComplianceIssue]:
    """Flag sites that gate access behind consent (a cookie wall)."""
    if not ctx.has_cookie_wall:
        return []
    issue = _issue(
        "gdpr_cookie_wall",
        Severity.CRITICAL,
        "Cookie walls (blocking access unless consent is given) are not permitted.",
        "Remove the cookie wall and allow access without consent.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _gdpr_no_pre_ticked(ctx: SiteContext) -> list[ComplianceIssue]:
    """Flag consent UIs whose category boxes are ticked by default."""
    if not ctx.pre_ticked_boxes:
        return []
    issue = _issue(
        "gdpr_pre_ticked",
        Severity.CRITICAL,
        "Pre-ticked consent boxes do not constitute valid consent.",
        "Ensure all non-essential category checkboxes default to unchecked.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _gdpr_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
    """Warn when no privacy policy URL is configured for the site."""
    if ctx.privacy_policy_url:
        return []
    issue = _issue(
        "gdpr_privacy_policy",
        Severity.WARNING,
        "A link to the privacy policy should be accessible from the banner.",
        "Configure a privacy policy URL in the site settings.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _gdpr_uncategorised_cookies(ctx: SiteContext) -> list[ComplianceIssue]:
    """Warn when discovered cookies are still waiting for a category."""
    if ctx.uncategorised_cookies <= 0:
        return []
    issue = _issue(
        "gdpr_uncategorised",
        Severity.WARNING,
        f"{ctx.uncategorised_cookies} cookie(s) have not been categorised.",
        "Review and assign a category to all discovered cookies.",
    )
    return [issue]
|
||||
|
||||
|
||||
# GDPR rule registry. Each registry ID matches the rule_id the check
# itself passes to _issue(), so reports stay consistent.
GDPR_RULES: list[ComplianceRule] = [
    ComplianceRule("gdpr_opt_in", "Opt-in consent required", _gdpr_opt_in),
    ComplianceRule("gdpr_reject_button", "Reject as prominent as accept", _gdpr_reject_button),
    ComplianceRule("gdpr_granular", "Granular category consent", _gdpr_granular_consent),
    ComplianceRule("gdpr_cookie_wall", "No cookie walls", _gdpr_no_cookie_wall),
    ComplianceRule("gdpr_pre_ticked", "No pre-ticked boxes", _gdpr_no_pre_ticked),
    ComplianceRule("gdpr_privacy_policy", "Privacy policy link", _gdpr_privacy_policy),
    ComplianceRule("gdpr_uncategorised", "All cookies categorised", _gdpr_uncategorised_cookies),
]
|
||||
|
||||
|
||||
# ── CNIL rules (French — stricter GDPR) ──────────────────────────────
|
||||
|
||||
|
||||
def _cnil_consent_expiry(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL mandates re-consent every 6 months (≈ 182 days)."""
    if ctx.consent_expiry_days <= 182:
        return []
    issue = _issue(
        "cnil_reconsent",
        Severity.CRITICAL,
        "CNIL requires re-consent at least every 6 months.",
        "Set consent_expiry_days to 182 or fewer.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _cnil_cookie_lifetime(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL limits cookie lifetime to 13 months (≈ 395 days)."""
    if ctx.consent_expiry_days <= 395:
        return []
    issue = _issue(
        "cnil_cookie_lifetime",
        Severity.CRITICAL,
        "CNIL limits consent cookie lifetime to 13 months.",
        "Set consent_expiry_days to 395 or fewer.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _cnil_reject_first_layer(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL requires 'Tout refuser' on the first layer of the banner."""
    if ctx.has_reject_button:
        return []
    issue = _issue(
        "cnil_reject_first_layer",
        Severity.CRITICAL,
        "CNIL requires a 'Reject all' button on the first layer of the banner.",
        "Ensure the 'Reject all' button is visible on the first banner view.",
    )
    return [issue]
|
||||
|
||||
|
||||
# CNIL rules include all GDPR rules plus CNIL-specific ones
# (CNIL is the French regulator; its guidance tightens plain GDPR).
CNIL_RULES: list[ComplianceRule] = [
    *GDPR_RULES,
    ComplianceRule("cnil_reconsent", "Re-consent every 6 months", _cnil_consent_expiry),
    ComplianceRule("cnil_cookie_lifetime", "13-month cookie lifetime", _cnil_cookie_lifetime),
    ComplianceRule(
        "cnil_reject_first_layer",
        "Reject on first layer",
        _cnil_reject_first_layer,
    ),
]
|
||||
|
||||
|
||||
# ── CCPA / CPRA rules ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ccpa_opt_out(ctx: SiteContext) -> list[ComplianceIssue]:
    """CCPA uses an opt-out model — blocking mode should be opt_out.

    Opt-in is accepted too, as it is strictly stronger than opt-out.
    """
    if ctx.blocking_mode in ("opt_out", "opt_in"):
        return []
    issue = _issue(
        "ccpa_opt_out",
        Severity.CRITICAL,
        "CCPA requires at minimum an opt-out mechanism for data sale.",
        "Set blocking mode to 'opt_out' or 'opt_in'.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _ccpa_do_not_sell(ctx: SiteContext) -> list[ComplianceIssue]:
    """CCPA requires a 'Do Not Sell My Personal Information' link."""
    banner = ctx.banner_config or {}
    if banner.get("show_do_not_sell_link", False):
        return []
    issue = _issue(
        "ccpa_do_not_sell",
        Severity.CRITICAL,
        "CCPA requires a 'Do Not Sell My Personal Information' link.",
        "Enable 'show_do_not_sell_link' in the banner configuration.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _ccpa_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
    """Warn when the site lacks a privacy policy URL (required by CCPA)."""
    if ctx.privacy_policy_url:
        return []
    issue = _issue(
        "ccpa_privacy_policy",
        Severity.WARNING,
        "A privacy policy is required under CCPA.",
        "Configure a privacy policy URL in the site settings.",
    )
    return [issue]
|
||||
|
||||
|
||||
# CCPA / CPRA rule registry.
CCPA_RULES: list[ComplianceRule] = [
    ComplianceRule("ccpa_opt_out", "Opt-out mechanism", _ccpa_opt_out),
    ComplianceRule("ccpa_do_not_sell", "Do Not Sell link", _ccpa_do_not_sell),
    ComplianceRule("ccpa_privacy_policy", "Privacy policy required", _ccpa_privacy_policy),
]
|
||||
|
||||
|
||||
# ── ePrivacy rules ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _eprivacy_consent(ctx: SiteContext) -> list[ComplianceIssue]:
    """ePrivacy requires consent for non-essential cookies."""
    if ctx.blocking_mode != "informational":
        return []
    issue = _issue(
        "eprivacy_consent",
        Severity.CRITICAL,
        "ePrivacy Directive requires consent for non-essential cookies.",
        "Set blocking mode to 'opt_in' or 'opt_out'.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _eprivacy_necessary_exempt(ctx: SiteContext) -> list[ComplianceIssue]:
    """Strictly necessary cookies must be exempt from consent.

    Currently a configuration-guidance no-op: per the original note, the
    blocker leaves necessary cookies unblocked by default even in opt-in
    mode, so there is nothing to flag from configuration alone.
    """
    return []
|
||||
|
||||
|
||||
# ePrivacy Directive rule registry.
EPRIVACY_RULES: list[ComplianceRule] = [
    ComplianceRule("eprivacy_consent", "Consent for non-essential", _eprivacy_consent),
    ComplianceRule(
        "eprivacy_necessary_exempt",
        "Necessary cookies exempt",
        _eprivacy_necessary_exempt,
    ),
]
|
||||
|
||||
|
||||
# ── LGPD rules (Brazil) ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _lgpd_consent_basis(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD requires consent or legitimate interest as legal basis."""
    if ctx.blocking_mode != "informational":
        return []
    issue = _issue(
        "lgpd_consent_basis",
        Severity.CRITICAL,
        "LGPD requires a legal basis (consent or legitimate interest) for data processing.",
        "Set blocking mode to 'opt_in' or 'opt_out'.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _lgpd_data_controller(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD requires identifying the data controller.

    The privacy policy URL is used as the proxy signal here.
    """
    if ctx.privacy_policy_url:
        return []
    issue = _issue(
        "lgpd_data_controller",
        Severity.WARNING,
        "LGPD requires identification of the data controller.",
        "Link to a privacy policy that identifies the data controller.",
    )
    return [issue]
|
||||
|
||||
|
||||
def _lgpd_granular(ctx: SiteContext) -> list[ComplianceIssue]:
    """Warn when the banner offers no granular consent choices."""
    if ctx.has_granular_choices:
        return []
    issue = _issue(
        "lgpd_granular",
        Severity.WARNING,
        "LGPD recommends granular consent choices.",
        "Provide individual category toggles in the consent banner.",
    )
    return [issue]
|
||||
|
||||
|
||||
# LGPD (Brazil) rule registry.
LGPD_RULES: list[ComplianceRule] = [
    ComplianceRule("lgpd_consent_basis", "Legal basis for processing", _lgpd_consent_basis),
    ComplianceRule("lgpd_data_controller", "Identify data controller", _lgpd_data_controller),
    ComplianceRule("lgpd_granular", "Granular consent choices", _lgpd_granular),
]
|
||||
|
||||
|
||||
# ── Framework registry ────────────────────────────────────────────────
|
||||
|
||||
# Maps each supported framework to its rule set. run_framework_check
# falls back to an empty rule list for frameworks not registered here.
FRAMEWORK_RULES: dict[Framework, list[ComplianceRule]] = {
    Framework.GDPR: GDPR_RULES,
    Framework.CNIL: CNIL_RULES,
    Framework.CCPA: CCPA_RULES,
    Framework.EPRIVACY: EPRIVACY_RULES,
    Framework.LGPD: LGPD_RULES,
}
|
||||
|
||||
|
||||
# ── Engine ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_framework_check(
    framework: Framework,
    ctx: SiteContext,
) -> FrameworkResult:
    """Evaluate every rule registered for *framework* against *ctx*.

    Aggregates the individual rule findings into one FrameworkResult
    carrying the score, status, issue list, and pass/checked counters.
    """
    rules = FRAMEWORK_RULES.get(framework, [])
    collected: list[ComplianceIssue] = []
    passed = 0

    for rule in rules:
        findings = rule.check(ctx)
        if not findings:
            passed += 1
        else:
            collected.extend(findings)

    checked = len(rules)
    score = _calculate_score(collected, checked)

    return FrameworkResult(
        framework=framework,
        score=score,
        status=_determine_status(score, collected),
        issues=collected,
        rules_checked=checked,
        rules_passed=passed,
    )
|
||||
|
||||
|
||||
def run_compliance_check(
    ctx: SiteContext,
    frameworks: list[Framework] | None = None,
) -> list[FrameworkResult]:
    """Run compliance checks for the specified (or all) frameworks.

    A None or empty selection means "check every registered framework".
    """
    targets = frameworks or list(FRAMEWORK_RULES)
    return [run_framework_check(framework, ctx) for framework in targets]
|
||||
|
||||
|
||||
def calculate_overall_score(results: list[FrameworkResult]) -> int:
    """Return the mean compliance score across framework results.

    This is an unweighted arithmetic mean — every framework counts
    equally. (The previous docstring said "weighted average", which did
    not match the computation.)

    Args:
        results: Per-framework results; only ``score`` is read.

    Returns:
        The rounded mean score, or 100 when there are no results.
    """
    if not results:
        return 100
    return round(sum(r.score for r in results) / len(results))
|
||||
|
||||
|
||||
# ── Scoring helpers ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _calculate_score(
    issues: list[ComplianceIssue],
    rules_checked: int,
) -> int:
    """Score from 0-100. Critical issues deduct 20 pts, warnings 5 pts.

    INFO-level issues carry no penalty; the result is floored at 0.
    """
    if rules_checked == 0:
        return 100

    penalty = {Severity.CRITICAL: 20, Severity.WARNING: 5}
    total_deduction = sum(penalty.get(issue.severity, 0) for issue in issues)
    return max(0, 100 - total_deduction)
|
||||
|
||||
|
||||
def _determine_status(
    score: int,
    issues: list[ComplianceIssue],
) -> str:
    """Derive the overall status string from score and issues.

    Any critical issue forces "non_compliant" regardless of score; a
    perfect score with no criticals is "compliant"; otherwise "partial".
    """
    for issue in issues:
        if issue.severity == Severity.CRITICAL:
            return "non_compliant"
    return "compliant" if score >= 100 else "partial"
|
||||
156
apps/api/src/services/config_resolver.py
Normal file
156
apps/api/src/services/config_resolver.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Configuration hierarchy resolver.
|
||||
|
||||
Resolves site configuration by merging:
|
||||
System Defaults → Org Defaults → Site Group Defaults → Site Config → Regional Overrides
|
||||
|
||||
Produces a fully resolved public config suitable for the banner script.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# System-level defaults (hard-coded, lowest priority).
# NOTE(review): resolve_config copies this dict shallowly, so the list
# and dict values here are shared by reference across resolutions —
# callers must not mutate them in place.
SYSTEM_DEFAULTS: dict[str, Any] = {
    "blocking_mode": "opt_in",
    "tcf_enabled": False,
    "gpp_enabled": True,
    "gpp_supported_apis": ["usnat"],
    "gpc_enabled": True,
    "gpc_jurisdictions": ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"],
    "gpc_global_honour": False,
    "gcm_enabled": True,
    "shopify_privacy_enabled": False,
    # Google Consent Mode default signals: everything denied except
    # security_storage until the visitor grants consent.
    "gcm_default": {
        "ad_storage": "denied",
        "ad_user_data": "denied",
        "ad_personalization": "denied",
        "analytics_storage": "denied",
        "functionality_storage": "denied",
        "personalization_storage": "denied",
        "security_storage": "granted",
    },
    "banner_config": None,
    "privacy_policy_url": None,
    "terms_url": None,
    "consent_expiry_days": 365,
}
|
||||
|
||||
|
||||
def resolve_config(
    site_config: dict[str, Any],
    org_defaults: dict[str, Any] | None = None,
    group_defaults: dict[str, Any] | None = None,
    region: str | None = None,
) -> dict[str, Any]:
    """Resolve the full configuration by merging layers.

    Layers apply lowest-priority first: system defaults, organisation
    defaults, site-group defaults, the site's own config, and finally a
    regional blocking-mode override.

    Args:
        site_config: Site-specific configuration from the database.
        org_defaults: Organisation-level default overrides (optional).
        group_defaults: Site-group-level default overrides (optional).
        region: ISO region code for regional mode override (optional).

    Returns:
        Fully resolved configuration dictionary.
    """
    resolved = dict(SYSTEM_DEFAULTS)

    for layer in (org_defaults, group_defaults, site_config):
        if layer:
            _merge_non_none(resolved, layer)

    # Regional blocking-mode override: an exact region key wins,
    # falling back to a "DEFAULT" entry when present.
    regional_modes = site_config.get("regional_modes")
    if region and isinstance(regional_modes, dict):
        override = regional_modes.get(region) or regional_modes.get("DEFAULT")
        if override:
            resolved["blocking_mode"] = override

    return resolved
|
||||
|
||||
|
||||
def build_public_config(
    site_id: str,
    resolved: dict[str, Any],
) -> dict[str, Any]:
    """Build a public configuration JSON for the banner script.

    Strips internal fields and adds the site_id for identification.
    Required fields are indexed directly (a missing one raises KeyError,
    signalling a broken resolver); optional fields default to None.
    """
    # (key, required) pairs, in the order the public payload emits them.
    field_spec = (
        ("blocking_mode", True),
        ("regional_modes", False),
        ("tcf_enabled", True),
        ("gpp_enabled", True),
        ("gpp_supported_apis", False),
        ("gpc_enabled", True),
        ("gpc_jurisdictions", False),
        ("gpc_global_honour", True),
        ("gcm_enabled", True),
        ("gcm_default", False),
        ("shopify_privacy_enabled", True),
        ("banner_config", False),
        ("privacy_policy_url", False),
        ("terms_url", False),
        ("consent_expiry_days", True),
        ("consent_group_id", False),
        ("ab_test", False),
    )

    public: dict[str, Any] = {"id": resolved.get("id", ""), "site_id": site_id}
    for key, required in field_spec:
        public[key] = resolved[key] if required else resolved.get(key)
    return public
|
||||
|
||||
|
||||
# Config fields that participate in the hierarchy merge;
# orm_to_config_dict copies exactly these attributes off ORM objects.
CONFIG_FIELDS = (
    "blocking_mode",
    "regional_modes",
    "tcf_enabled",
    "tcf_publisher_cc",
    "gpp_enabled",
    "gpp_supported_apis",
    "gpc_enabled",
    "gpc_jurisdictions",
    "gpc_global_honour",
    "gcm_enabled",
    "gcm_default",
    "shopify_privacy_enabled",
    "banner_config",
    "privacy_policy_url",
    "terms_url",
    "consent_expiry_days",
)
|
||||
|
||||
|
||||
def orm_to_config_dict(obj: Any, *, include_id: bool = False) -> dict[str, Any]:
    """Convert a SiteConfig or OrgConfig ORM object to a dict of config fields.

    Only includes fields that are explicitly set (not NULL). This allows the
    hierarchy to work correctly: unset fields at higher-priority layers don't
    block inheritance from lower-priority layers.
    """
    config: dict[str, Any] = {}
    if include_id and hasattr(obj, "id"):
        config["id"] = str(obj.id)
    for name in CONFIG_FIELDS:
        # A None default folds "attribute missing" and "attribute NULL"
        # into the same skipped case.
        value = getattr(obj, name, None)
        if value is not None:
            config[name] = value
    return config
|
||||
|
||||
|
||||
def _merge_non_none(target: dict[str, Any], source: dict[str, Any]) -> None:
    """Merge *source* into *target* in place, skipping None values in source."""
    target.update(
        {key: value for key, value in source.items() if value is not None}
    )
|
||||
77
apps/api/src/services/cors.py
Normal file
77
apps/api/src/services/cors.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Dynamic CORS origin validation.
|
||||
|
||||
Provides an origin validator that checks incoming origins against
|
||||
registered site domains (primary + additional) in addition to the
|
||||
statically configured allowed_origins list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.site import Site
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_domain_from_origin(origin: str) -> str | None:
    """Extract the hostname from an origin URL.

    e.g. 'https://example.com:443' → 'example.com'
    """
    try:
        hostname = urlparse(origin).hostname
    except Exception:
        # Malformed origins (e.g. an invalid port) resolve to "no domain".
        return None
    return hostname
|
||||
|
||||
|
||||
async def get_allowed_domains(db: AsyncSession) -> set[str]:
    """Fetch all registered domains (primary + additional) from active,
    non-deleted sites, lower-cased for case-insensitive comparison."""
    rows = await db.execute(
        select(Site.domain, Site.additional_domains).where(
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )

    domains: set[str] = set()
    for row in rows.all():
        domains.add(row.domain.lower())
        # additional_domains may be NULL in the database.
        domains.update(extra.lower() for extra in (row.additional_domains or []))

    return domains
|
||||
|
||||
|
||||
def is_origin_allowed(
    origin: str,
    static_origins: list[str],
    registered_domains: set[str],
) -> bool:
    """Check if an origin is allowed by either the static list or registered domains.

    Args:
        origin: The Origin header value (e.g. 'https://example.com').
        static_origins: Statically configured allowed origins from settings.
        registered_domains: Set of registered site domains from the database.

    Returns:
        True if the origin is allowed.
    """
    # An exact static match or a configured wildcard short-circuits.
    if origin in static_origins or "*" in static_origins:
        return True

    # Otherwise compare the origin's hostname against the registered
    # site domains.
    try:
        hostname = urlparse(origin).hostname
    except Exception:
        hostname = None
    return bool(hostname and hostname.lower() in registered_domains)
|
||||
54
apps/api/src/services/dependencies.py
Normal file
54
apps/api/src/services/dependencies.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError
|
||||
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.services.auth import decode_token
|
||||
|
||||
# Shared bearer-token security scheme, consumed by get_current_user below.
bearer_scheme = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> CurrentUser:
    """Extract and validate the current user from the JWT bearer token.

    Raises:
        HTTPException: 401 when the token is invalid, expired, or not an
            access-type token.
    """
    try:
        payload = decode_token(credentials.credentials)
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    # Only "access"-type tokens grant API access (e.g. not refresh tokens).
    token_type = payload.get("type")
    if token_type != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )

    return CurrentUser(
        id=uuid.UUID(payload["sub"]),
        organisation_id=uuid.UUID(payload["org_id"]),
        email=payload.get("email", ""),
        role=payload.get("role", "viewer"),
    )
|
||||
|
||||
|
||||
def require_role(*allowed_roles: str) -> Callable:
    """Dependency factory that restricts access to users with specific roles."""

    async def _guard(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.has_role(*allowed_roles):
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Role '{current_user.role}' is not permitted for this action",
        )

    return _guard
|
||||
339
apps/api/src/services/geoip.py
Normal file
339
apps/api/src/services/geoip.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""GeoIP service — resolve an IP address to a country/region code.
|
||||
|
||||
Resolution order (see :func:`detect_region`):
|
||||
|
||||
1. **CDN / proxy headers.** Operators configure ``GEOIP_COUNTRY_HEADER``
|
||||
(and optionally ``GEOIP_REGION_HEADER``) to match whatever their edge
|
||||
uses — e.g. ``cf-ipcountry`` + ``cf-region-code`` on Cloudflare
|
||||
Enterprise, or ``x-gclb-country`` + ``x-gclb-region`` on GCP. A short
|
||||
built-in country list (``cf-ipcountry``, ``x-vercel-ip-country``,
|
||||
``x-appengine-country``, ``x-country-code``) covers the common case
|
||||
where only country-level granularity is needed.
|
||||
2. **Local MaxMind GeoLite2-City database.** Set
|
||||
``GEOIP_MAXMIND_DB_PATH`` to a mounted ``.mmdb`` file. Gives both
|
||||
country and ISO 3166-2 subdivision without any external calls.
|
||||
3. **External ip-api.com lookup** (rate-limited, no API key). Last-ditch
|
||||
fallback; fine for development, not recommended for production.
|
||||
4. Unresolved — the caller should fall back to the default region.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import geoip2.database
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazily-initialised MaxMind reader. ``geoip2.database.Reader`` opens
# the file once and then every lookup is a memory-mapped read, so we
# cache it for the lifetime of the process. ``None`` means either no
# path is configured, initialisation failed, or we haven't tried yet.
# ``_maxmind_initialised`` distinguishes "never tried" from the other
# ``None`` cases so initialisation is attempted at most once.
_maxmind_reader: geoip2.database.Reader | None = None
_maxmind_initialised = False
|
||||
|
||||
# Standard headers set by CDN / reverse proxy providers. Operators
# running behind a CDN that uses a non-standard header (e.g. Google
# Cloud Load Balancer's ``x-gclb-country``) can add one more via the
# ``GEOIP_COUNTRY_HEADER`` env var — see ``detect_region_from_headers``.
# Checked in order; the first header carrying a non-"XX" value wins.
_GEO_HEADERS = [
    "cf-ipcountry",  # Cloudflare
    "x-vercel-ip-country",  # Vercel
    "x-appengine-country",  # Google App Engine
    "x-country-code",  # Generic / custom
]
|
||||
|
||||
# Mapping from two-letter country code to region codes used in regional_modes
# EU member states → "EU", US states handled separately, etc.
_EU_COUNTRIES = frozenset({
    "AT", "BE", "BG", "HR", "CY", "CZ", "DK", "EE", "FI",
    "FR", "DE", "GR", "HU", "IE", "IT", "LV", "LT", "LU",
    "MT", "NL", "PL", "PT", "RO", "SK", "SI", "ES", "SE",
})
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class GeoResult:
    """Result of a GeoIP lookup."""

    # Two-letter country code, or None when the lookup failed.
    country_code: str | None
    # regional_modes key derived from the country (e.g. "EU", "US-CA"),
    # or None when the lookup failed.
    region: str | None

    @property
    def is_resolved(self) -> bool:
        """True when the lookup produced at least a country code."""
        return self.country_code is not None
|
||||
|
||||
|
||||
def country_to_region(country_code: str, state_code: str | None = None) -> str:
    """Map a country code (+ optional subdivision) to a regional_modes key.

    Resolution order:
    - EU member states collapse to ``"EU"`` regardless of subdivision;
      regional_modes treats the bloc as a single unit.
    - Any other country with a subdivision produces ``"{CC}-{SUB}"``
      (e.g. ``"US-CA"``, ``"GB-SCT"``, ``"BR-SP"``). The operator
      opts in to subdivision-level resolution by configuring a key
      of that form in ``regional_modes``; if they don't, the
      fallback resolver still matches on the plain country code.
    - Country with no subdivision is returned as-is (``"GB"``,
      ``"BR"``, …).
    """
    cc = country_code.upper()
    if cc in _EU_COUNTRIES:
        return "EU"
    return f"{cc}-{state_code.upper()}" if state_code else cc
|
||||
|
||||
|
||||
def detect_region_from_headers(request: Request) -> GeoResult:
    """Attempt to detect the visitor's region from proxy/CDN headers.

    This is the fastest path — no external calls needed. A custom
    country header configured via ``GEOIP_COUNTRY_HEADER`` takes
    priority over the built-in list so operators can plumb in
    non-standard CDN/load-balancer headers (e.g. ``x-gclb-country``)
    without code changes.

    When ``GEOIP_REGION_HEADER`` is also set and the custom country
    header resolved, the subdivision code from that header is paired
    with the country to build region keys like ``US-CA``. The built-in
    country list is country-only — operators who need subdivision
    granularity must configure the explicit pair.

    Header lookups are case-insensitive.

    Returns:
        A resolved GeoResult, or one with both fields None when no
        usable header was found.
    """
    settings = get_settings()
    custom_country = settings.geoip_country_header
    custom_region = settings.geoip_region_header

    # Operator-configured header pair takes priority over the built-in
    # list. "XX" is treated as "country unknown" and skipped.
    if custom_country:
        value = request.headers.get(custom_country)
        if value and value.upper() != "XX":
            country = value.upper().strip()
            state: str | None = None
            if custom_region:
                raw_state = request.headers.get(custom_region)
                if raw_state and raw_state.upper() != "XX":
                    # ISO 3166-2 subdivision codes may be prefixed
                    # with the country (e.g. ``US-CA``) or bare (e.g.
                    # ``CA``). Strip the prefix so ``country_to_region``
                    # sees just the subdivision.
                    stripped = raw_state.strip().upper()
                    state = stripped.split("-", 1)[-1] if "-" in stripped else stripped
            return GeoResult(
                country_code=country,
                region=country_to_region(country, state),
            )

    # If the custom header is absent or "XX", fall through to the
    # built-in country-only headers, in declaration order.
    for header in _GEO_HEADERS:
        value = request.headers.get(header)
        if value and value.upper() != "XX":
            country = value.upper().strip()
            return GeoResult(
                country_code=country,
                region=country_to_region(country),
            )

    return GeoResult(country_code=None, region=None)
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str | None:
    """Best-effort extraction of the originating client IP.

    Trust order: the first entry of ``X-Forwarded-For`` (the original
    client in the standard ``client, proxy1, proxy2`` chain), then
    ``X-Real-IP``, then the transport-level peer address. Returns
    ``None`` when none of those are available.
    """
    chain = request.headers.get("x-forwarded-for")
    if chain:
        first_hop, _, _ = chain.partition(",")
        return first_hop.strip()

    if direct := request.headers.get("x-real-ip"):
        return direct.strip()

    client = request.client
    return client.host if client else None
|
||||
|
||||
|
||||
async def lookup_ip_region(ip: str) -> GeoResult:
    """Resolve *ip* to a region via the external ip-api.com service.

    Free tier, no API key, rate-limited to 45 req/min — suitable only
    as a fallback; production deployments should prefer a local
    MaxMind database. Private/loopback addresses and any lookup
    failure yield an unresolved result.
    """
    unresolved = GeoResult(country_code=None, region=None)

    if _is_private_ip(ip):
        return unresolved

    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(
                f"http://ip-api.com/json/{ip}",
                params={"fields": "status,countryCode,region"},
            )
        if resp.status_code != 200:
            return unresolved

        payload = resp.json()
        country = payload.get("countryCode")
        if payload.get("status") != "success" or not country:
            return unresolved

        # "region" in the ip-api payload is the state/province code.
        return GeoResult(
            country_code=country,
            region=country_to_region(country, payload.get("region")),
        )
    except Exception:
        logger.debug("GeoIP lookup failed for %s", ip, exc_info=True)
        return unresolved
|
||||
|
||||
|
||||
def _get_maxmind_reader() -> geoip2.database.Reader | None:
    """Return the cached MaxMind reader, opening the DB on first use.

    Caches both successful opens and failures (via
    ``_maxmind_initialised``) so we don't retry a bad path on every
    request. Returns ``None`` if no path is configured or the DB
    couldn't be opened.

    NOTE(review): there is no locking around the module-global
    initialisation; two concurrent first calls could each open a
    reader. Confirm this benign race is acceptable (only one reader
    ends up cached, the other is leaked until GC).
    """
    global _maxmind_reader, _maxmind_initialised
    # Fast path: a previous call already decided (reader or None).
    if _maxmind_initialised:
        return _maxmind_reader

    # Mark initialised *before* attempting the open so a failing path
    # is never retried on subsequent requests.
    _maxmind_initialised = True
    db_path = get_settings().geoip_maxmind_db_path
    if not db_path:
        # MaxMind not configured — callers fall back to other tiers.
        return None

    try:
        _maxmind_reader = geoip2.database.Reader(db_path)
        logger.info("GeoIP: opened MaxMind database at %s", db_path)
    except Exception:
        # Deliberately broad: any open failure (missing file, bad
        # permissions, corrupt DB) degrades to external lookups.
        logger.warning(
            "GeoIP: failed to open MaxMind database at %s — falling back to "
            "external lookups. Check GEOIP_MAXMIND_DB_PATH and that the file "
            "is readable inside the container.",
            db_path,
            exc_info=True,
        )
        _maxmind_reader = None

    return _maxmind_reader
|
||||
|
||||
|
||||
def lookup_ip_maxmind(ip: str) -> GeoResult:
    """Resolve *ip* against the locally configured MaxMind database.

    Pure in-process lookup (memory-mapped file, no network), so it is
    safe to call synchronously from the async request path. Any of the
    following yields an unresolved result: private/loopback address,
    no database configured, or no record found for the address.
    """
    miss = GeoResult(country_code=None, region=None)

    if _is_private_ip(ip):
        return miss

    reader = _get_maxmind_reader()
    if reader is None:
        return miss

    try:
        record = reader.city(ip)
    except Exception:
        logger.debug("MaxMind lookup failed for %s", ip, exc_info=True)
        return miss

    country = record.country.iso_code
    if not country:
        return miss

    # Subdivisions are ordered most-specific first; ``most_specific``
    # yields the bare ISO 3166-2 code (no country prefix).
    subdivision = record.subdivisions.most_specific.iso_code if record.subdivisions else None
    return GeoResult(
        country_code=country.upper(),
        region=country_to_region(country, subdivision),
    )
|
||||
|
||||
|
||||
async def detect_region(request: Request) -> GeoResult:
    """Resolve the visitor's region, cheapest source first.

    Tier 1: CDN/proxy geo headers (no I/O).
    Tier 2: local MaxMind database, when ``GEOIP_MAXMIND_DB_PATH`` is set.
    Tier 3: external ip-api.com lookup.

    Returns an unresolved :class:`GeoResult` when every tier fails.
    """
    from_headers = detect_region_from_headers(request)
    if from_headers.is_resolved:
        return from_headers

    ip = get_client_ip(request)
    if not ip:
        # No usable client address — nothing left to look up.
        return GeoResult(country_code=None, region=None)

    if get_settings().geoip_maxmind_db_path:
        local = lookup_ip_maxmind(ip)
        if local.is_resolved:
            return local

    return await lookup_ip_region(ip)
|
||||
|
||||
|
||||
def _is_private_ip(ip: str) -> bool:
|
||||
"""Check if an IP address is a private/loopback address."""
|
||||
return (
|
||||
ip.startswith("127.")
|
||||
or ip.startswith("10.")
|
||||
or ip.startswith("192.168.")
|
||||
or ip.startswith("172.16.")
|
||||
or ip.startswith("172.17.")
|
||||
or ip.startswith("172.18.")
|
||||
or ip.startswith("172.19.")
|
||||
or ip.startswith("172.2")
|
||||
or ip.startswith("172.3")
|
||||
or ip == "::1"
|
||||
or ip == "localhost"
|
||||
)
|
||||
41
apps/api/src/services/pseudonymisation.py
Normal file
41
apps/api/src/services/pseudonymisation.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Pseudonymisation helpers for consent records.
|
||||
|
||||
Consent records capture a hash of the visitor's IP address and
|
||||
user-agent string for abuse protection and audit trail purposes.
|
||||
|
||||
Previously this used an unsalted truncated SHA-256, which is trivially
|
||||
reversible for IPv4 addresses (only ~4 billion inputs). We now use
|
||||
HMAC-SHA256 keyed with a server-side secret so the hash cannot be
|
||||
recovered without access to the secret.
|
||||
|
||||
Public API: :func:`pseudonymise`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
from hashlib import sha256
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
# Length of the hex-encoded digest stored in the database. 32 hex chars
|
||||
# = 128 bits, which is more than enough entropy while keeping the
|
||||
# column compact. (Previous code used 16 hex chars = 64 bits.)
|
||||
_DIGEST_HEX_LEN = 32
|
||||
|
||||
|
||||
def pseudonymise(value: str) -> str:
    """Return a keyed, truncated hash of *value* for audit storage.

    Computes HMAC-SHA256 under the server-side pseudonymisation key,
    hex-encodes it, and truncates to ``_DIGEST_HEX_LEN`` characters
    (128 bits). Unlike a plain hash, the output cannot be brute-forced
    back to an IP address without the key.

    Empty input maps to an empty string so callers need not branch on
    missing data.
    """
    if not value:
        return ""
    mac = hmac.new(
        get_settings().pseudonymisation_key,
        value.encode("utf-8"),
        sha256,
    )
    return mac.hexdigest()[:_DIGEST_HEX_LEN]
|
||||
89
apps/api/src/services/publisher.py
Normal file
89
apps/api/src/services/publisher.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""CDN publishing pipeline.
|
||||
|
||||
Publishes resolved site configurations as static JSON files for the
|
||||
banner script to fetch. Supports local filesystem (development) and
|
||||
can be extended for S3/GCS/CloudFront.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
from .config_resolver import build_public_config, resolve_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PublishResult:
|
||||
"""Result of a publish operation."""
|
||||
|
||||
def __init__(self, success: bool, path: str, error: str | None = None) -> None:
|
||||
self.success = success
|
||||
self.path = path
|
||||
self.error = error
|
||||
self.published_at = datetime.now(UTC).isoformat() if success else None
|
||||
|
||||
|
||||
async def publish_site_config(
    site_id: str,
    site_config: dict[str, Any],
    org_defaults: dict[str, Any] | None = None,
) -> PublishResult:
    """Resolve a site's configuration cascade and push it to the CDN.

    Args:
        site_id: The site UUID as a string.
        site_config: Raw site configuration from the database.
        org_defaults: Organisation-level defaults (optional).

    Returns:
        PublishResult describing success/failure and the written path.
        Never raises — failures are captured in the result.
    """
    try:
        # Cascade resolution, then projection to the public shape the
        # banner script consumes.
        resolved = resolve_config(site_config, org_defaults)
        public_config = build_public_config(site_id, resolved)

        destination = await _publish_local(
            site_id, public_config, get_settings().cdn_base_url
        )
        logger.info("Published config for site %s to %s", site_id, destination)
        return PublishResult(success=True, path=destination)
    except Exception as exc:
        logger.exception("Failed to publish config for site %s", site_id)
        return PublishResult(success=False, path="", error=str(exc))
|
||||
|
||||
|
||||
async def _publish_local(
|
||||
site_id: str,
|
||||
config: dict[str, Any],
|
||||
cdn_base: str,
|
||||
) -> str:
|
||||
"""Publish config to local filesystem (for development/Docker Compose).
|
||||
|
||||
Writes to the CDN proxy's HTML directory so nginx can serve it.
|
||||
"""
|
||||
# Default local publish directory
|
||||
publish_dir = Path("/app/cdn-publish") if Path("/app").exists() else Path("cdn-publish")
|
||||
publish_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write the config JSON
|
||||
config_path = publish_dir / f"site-config-{site_id}.json"
|
||||
config_path.write_text(json.dumps(config, indent=2, default=str))
|
||||
|
||||
# Also write a versioned copy for cache-busting
|
||||
version = datetime.now(UTC).strftime("%Y%m%d%H%M%S")
|
||||
versioned_path = publish_dir / f"site-config-{site_id}-{version}.json"
|
||||
versioned_path.write_text(json.dumps(config, indent=2, default=str))
|
||||
|
||||
return str(config_path)
|
||||
322
apps/api/src/services/scanner.py
Normal file
322
apps/api/src/services/scanner.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""Scan orchestration and diff engine.
|
||||
|
||||
Provides scan job lifecycle management, result diffing between scans,
|
||||
and cookie inventory synchronisation from scan results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.cookie import Cookie
|
||||
from src.models.scan import ScanJob, ScanResult
|
||||
from src.models.site import Site
|
||||
from src.schemas.scanner import (
|
||||
CookieDiffItem,
|
||||
DiffStatus,
|
||||
ScanDiffResponse,
|
||||
)
|
||||
|
||||
|
||||
async def create_scan_job(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    trigger: str = "manual",
    max_pages: int = 50,
) -> ScanJob:
    """Insert a new scan job in the 'pending' state and return it.

    ``max_pages`` is stored as ``pages_total`` so the worker knows the
    crawl budget. The session is flushed (not committed) so the job's
    primary key is populated; committing is the caller's job.
    """
    pending_job = ScanJob(
        site_id=site_id,
        status="pending",
        trigger=trigger,
        pages_total=max_pages,
    )
    db.add(pending_job)
    await db.flush()
    return pending_job
|
||||
|
||||
|
||||
async def start_scan_job(db: AsyncSession, job: ScanJob) -> ScanJob:
    """Move a scan job into the 'running' state.

    Idempotent: a job that is already running (e.g. Celery re-delivered
    the task after a worker crash) is returned untouched. A job being
    retried after a transient failure is restarted with a fresh
    timestamp and a cleared error message.
    """
    if job.status == "running":
        return job

    job.status = "running"
    job.started_at = datetime.now(UTC)
    job.error_message = None  # retries start with a clean slate
    await db.flush()
    return job
|
||||
|
||||
|
||||
async def complete_scan_job(
    db: AsyncSession,
    job: ScanJob,
    *,
    pages_scanned: int = 0,
    cookies_found: int = 0,
    error_message: str | None = None,
) -> ScanJob:
    """Finalise a scan job as either 'completed' or 'failed'.

    A non-empty ``error_message`` marks the job failed; otherwise it is
    completed. Counters and the completion timestamp are recorded in
    both cases. Flushes without committing.
    """
    job.status = "completed" if not error_message else "failed"
    job.completed_at = datetime.now(UTC)
    job.pages_scanned = pages_scanned
    job.cookies_found = cookies_found
    job.error_message = error_message
    await db.flush()
    return job
|
||||
|
||||
|
||||
async def add_scan_result(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    page_url: str,
    cookie_name: str,
    cookie_domain: str,
    storage_type: str = "cookie",
    attributes: dict | None = None,
    script_source: str | None = None,
    auto_category: str | None = None,
    initiator_chain: list[str] | None = None,
) -> ScanResult:
    """Record a single cookie discovery from a scan.

    Args:
        db: Active async session; flushed (not committed) before return.
        scan_job_id: Owning scan job.
        page_url: Page on which the item was observed.
        cookie_name: Cookie (or storage key) name.
        cookie_domain: Domain the cookie was set for.
        storage_type: Defaults to ``"cookie"``; other storage kinds
            use their own value (presumably e.g. localStorage — confirm
            against the scanner service's output).
        attributes: Raw cookie attributes (path, secure, httpOnly, ...).
        script_source: URL of the script that set the cookie, if known.
        auto_category: Classifier-assigned category slug, if any.
        initiator_chain: Script initiator chain captured by the scanner,
            if available.

    Returns:
        The persisted ScanResult with its primary key populated.
    """
    result = ScanResult(
        scan_job_id=scan_job_id,
        page_url=page_url,
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
        storage_type=storage_type,
        attributes=attributes,
        script_source=script_source,
        auto_category=auto_category,
        initiator_chain=initiator_chain,
    )
    db.add(result)
    await db.flush()
    return result
|
||||
|
||||
|
||||
async def get_previous_completed_scan(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    before_scan_id: uuid.UUID,
) -> ScanJob | None:
    """Return the latest completed scan created before the given one.

    Returns ``None`` when the reference scan does not exist or the site
    has no earlier completed scan.
    """
    # Anchor on the reference scan's creation time.
    anchor = (
        await db.execute(select(ScanJob.created_at).where(ScanJob.id == before_scan_id))
    ).scalar_one_or_none()
    if anchor is None:
        return None

    query = (
        select(ScanJob)
        .where(
            ScanJob.site_id == site_id,
            ScanJob.status == "completed",
            ScanJob.id != before_scan_id,
            ScanJob.created_at < anchor,
        )
        .order_by(ScanJob.created_at.desc())
        .limit(1)
    )
    return (await db.execute(query)).scalar_one_or_none()
|
||||
|
||||
|
||||
def _result_key(r: ScanResult) -> tuple[str, str, str]:
|
||||
"""Unique key for a scan result (cookie identity)."""
|
||||
return (r.cookie_name, r.cookie_domain, r.storage_type)
|
||||
|
||||
|
||||
async def compute_scan_diff(
    db: AsyncSession,
    *,
    current_scan_id: uuid.UUID,
    site_id: uuid.UUID,
) -> ScanDiffResponse:
    """Compute the diff between the current scan and the previous one.

    Returns new, removed, and changed cookies. If no previous scan exists,
    all cookies in the current scan are marked as 'new'.

    Cookies are matched on the (name, domain, storage_type) identity
    key from :func:`_result_key`. Because results are keyed into a
    dict, multiple observations of the same cookie across pages
    collapse to one entry (the last row loaded wins).
    """
    previous_scan = await get_previous_completed_scan(
        db, site_id=site_id, before_scan_id=current_scan_id
    )

    # Load current scan results
    current_results = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == current_scan_id)
    )
    current_items = list(current_results.scalars().all())
    current_keys = {_result_key(r): r for r in current_items}

    if previous_scan is None:
        # No previous scan — everything is new
        new_cookies = [
            CookieDiffItem(
                name=r.cookie_name,
                domain=r.cookie_domain,
                storage_type=r.storage_type,
                diff_status=DiffStatus.NEW,
                details="First scan — no previous data",
            )
            for r in current_items
        ]
        return ScanDiffResponse(
            current_scan_id=current_scan_id,
            previous_scan_id=None,
            new_cookies=new_cookies,
            total_new=len(new_cookies),
        )

    # Load previous scan results
    prev_results = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == previous_scan.id)
    )
    prev_items = list(prev_results.scalars().all())
    prev_keys = {_result_key(r): r for r in prev_items}

    new_cookies: list[CookieDiffItem] = []
    removed_cookies: list[CookieDiffItem] = []
    changed_cookies: list[CookieDiffItem] = []

    # New cookies: in current but not in previous
    for key, r in current_keys.items():
        if key not in prev_keys:
            new_cookies.append(
                CookieDiffItem(
                    name=r.cookie_name,
                    domain=r.cookie_domain,
                    storage_type=r.storage_type,
                    diff_status=DiffStatus.NEW,
                )
            )

    # Removed cookies: in previous but not in current
    for key, r in prev_keys.items():
        if key not in current_keys:
            removed_cookies.append(
                CookieDiffItem(
                    name=r.cookie_name,
                    domain=r.cookie_domain,
                    storage_type=r.storage_type,
                    diff_status=DiffStatus.REMOVED,
                )
            )

    # Changed cookies: in both but with different attributes
    for key in current_keys:
        if key in prev_keys:
            curr = current_keys[key]
            prev = prev_keys[key]
            changes: list[str] = []

            if curr.script_source != prev.script_source:
                changes.append("script_source changed")
            if curr.auto_category != prev.auto_category:
                changes.append("auto_category changed")
            # Compare cookie attributes (e.g. secure, httpOnly);
            # normalise None to {} so a missing dict equals an empty one.
            if (curr.attributes or {}) != (prev.attributes or {}):
                changes.append("attributes changed")

            if changes:
                changed_cookies.append(
                    CookieDiffItem(
                        name=curr.cookie_name,
                        domain=curr.cookie_domain,
                        storage_type=curr.storage_type,
                        diff_status=DiffStatus.CHANGED,
                        details="; ".join(changes),
                    )
                )

    return ScanDiffResponse(
        current_scan_id=current_scan_id,
        previous_scan_id=previous_scan.id,
        new_cookies=new_cookies,
        removed_cookies=removed_cookies,
        changed_cookies=changed_cookies,
        total_new=len(new_cookies),
        total_removed=len(removed_cookies),
        total_changed=len(changed_cookies),
    )
|
||||
|
||||
|
||||
async def sync_scan_results_to_cookies(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    site_id: uuid.UUID,
) -> int:
    """Upsert scan results into the site's cookie inventory.

    New discoveries become ``Cookie`` rows in 'pending' review state;
    cookies already in the inventory get ``last_seen_at`` refreshed.
    Flushes without committing.

    The previous implementation issued one SELECT per scan result
    (N+1); the inventory is now loaded once and matched in memory on
    the (name, domain, storage_type) identity key.

    Returns:
        The number of newly created cookies.
    """
    results = await db.execute(select(ScanResult).where(ScanResult.scan_job_id == scan_job_id))
    items = list(results.scalars().all())

    # Load the whole inventory for this site once, keyed by identity.
    existing = await db.execute(select(Cookie).where(Cookie.site_id == site_id))
    inventory = {(c.name, c.domain, c.storage_type): c for c in existing.scalars().all()}

    now_iso = datetime.now(UTC).isoformat()
    new_count = 0

    for item in items:
        key = (item.cookie_name, item.cookie_domain, item.storage_type)
        cookie = inventory.get(key)
        if cookie is not None:
            cookie.last_seen_at = now_iso
        else:
            cookie = Cookie(
                site_id=site_id,
                name=item.cookie_name,
                domain=item.cookie_domain,
                storage_type=item.storage_type,
                review_status="pending",
                first_seen_at=now_iso,
                last_seen_at=now_iso,
            )
            db.add(cookie)
            # Register the new row so duplicate results within one
            # scan don't create duplicate inventory entries.
            inventory[key] = cookie
            new_count += 1

    await db.flush()
    return new_count
|
||||
|
||||
|
||||
async def get_sites_due_for_scan(db: AsyncSession) -> list[Site]:
    """Return active, non-deleted sites that have a scan schedule configured.

    NOTE(review): despite the name, no due-time evaluation happens here
    — every site with a non-null ``scan_schedule_cron`` is returned on
    every call. The periodic scheduler that consumes this will
    therefore trigger a scan each time it runs; confirm whether cron
    evaluation / last-scan throttling should be added.
    """
    # Imported here rather than at module level — presumably to avoid
    # an import cycle; confirm before hoisting.
    from src.models.site_config import SiteConfig

    # Join through SiteConfig, where the cron schedule lives.
    result = await db.execute(
        select(Site)
        .join(SiteConfig, SiteConfig.site_id == Site.id)
        .where(
            Site.deleted_at.is_(None),
            Site.is_active.is_(True),
            SiteConfig.scan_schedule_cron.isnot(None),
        )
    )
    return list(result.scalars().all())
|
||||
0
apps/api/src/tasks/__init__.py
Normal file
0
apps/api/src/tasks/__init__.py
Normal file
87
apps/api/src/tasks/retention.py
Normal file
87
apps/api/src/tasks/retention.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Consent record retention purge.
|
||||
|
||||
Deletes consent records older than each site's configured
|
||||
``consent_retention_days``. Sites with no retention configured are
|
||||
skipped — operators must explicitly opt in per site (or set it at the
|
||||
org/system level and let the cascade resolve it).
|
||||
|
||||
Scheduled by ``celery beat`` daily at 01:00 UTC via the entry in
|
||||
``src.celery_app.beat_schedule``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from src.celery_app import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _purge() -> dict[str, int]:
    """Delete expired consent records across all sites with retention set.

    Sites whose config has a null/zero/negative ``consent_retention_days``
    are skipped. All deletes are committed together in one transaction
    at the end; a failure rolls everything back for the run.

    Returns a summary ``{"sites_processed": N, "records_deleted": M}``.
    """
    # Imported inside the coroutine — presumably to keep the Celery
    # worker's module import light / avoid cycles; confirm before hoisting.
    from sqlalchemy import delete, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from src.config.settings import get_settings
    from src.models.consent import ConsentRecord
    from src.models.site_config import SiteConfig

    settings = get_settings()
    # A fresh engine per invocation, disposed at the end — no pooled
    # state survives between beat runs.
    engine = create_async_engine(settings.database_url, echo=False)

    sites_processed = 0
    records_deleted = 0

    async with AsyncSession(engine, expire_on_commit=False) as session:
        configs = (
            (
                await session.execute(
                    select(SiteConfig).where(SiteConfig.consent_retention_days.isnot(None)),
                )
            )
            .scalars()
            .all()
        )

        now = datetime.now(UTC)
        for cfg in configs:
            retention_days = cfg.consent_retention_days
            # Zero/negative retention is treated as "not configured".
            if not retention_days or retention_days <= 0:
                continue
            cutoff = now - timedelta(days=retention_days)
            result = await session.execute(
                delete(ConsentRecord).where(
                    ConsentRecord.site_id == cfg.site_id,
                    ConsentRecord.consented_at < cutoff,
                ),
            )
            deleted = result.rowcount or 0
            records_deleted += deleted
            sites_processed += 1
            if deleted:
                logger.info(
                    "retention.purged",
                    extra={
                        "site_id": str(cfg.site_id),
                        "retention_days": retention_days,
                        "deleted": deleted,
                        "cutoff": cutoff.isoformat(),
                    },
                )

        # Single commit for the whole run.
        await session.commit()

    await engine.dispose()
    return {"sites_processed": sites_processed, "records_deleted": records_deleted}
|
||||
|
||||
|
||||
@app.task(name="src.tasks.retention.purge_expired_consent_records")
def purge_expired_consent_records() -> dict[str, int]:
    """Celery entrypoint for the retention purge.

    Thin synchronous wrapper: Celery task bodies are sync, so the
    async implementation runs on a fresh event loop via ``asyncio.run``.
    Returns the summary dict produced by :func:`_purge`.
    """
    return asyncio.run(_purge())
|
||||
308
apps/api/src/tasks/scanner.py
Normal file
308
apps/api/src/tasks/scanner.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Celery tasks for scan job execution and scheduling.
|
||||
|
||||
The run_scan task calls the scanner HTTP service to execute a Playwright
|
||||
crawl, then processes the results: stores scan results, runs auto-
|
||||
classification, syncs discovered cookies to the site inventory, and
|
||||
computes diffs against the previous scan.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
|
||||
from src.celery_app import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.task(name="src.tasks.scanner.run_scan", bind=True, max_retries=2)
def run_scan(self, scan_job_id: str, site_id: str) -> dict:
    """Execute a scan job by calling the scanner service.

    1. Transition job to 'running'
    2. Look up site domain
    3. Call scanner HTTP service with the domain
    4. Store scan results and run auto-classification
    5. Sync discovered cookies to the site inventory
    6. Mark job as completed

    Scanner-service (httpx) errors are retried up to ``max_retries``
    times with a 30s delay; the job is only marked failed on the final
    attempt. Any other exception fails the job immediately and returns
    an ``{"error": ...}`` dict instead of raising.
    """
    # Imported inside the task — presumably to keep worker startup
    # light / avoid import cycles; confirm before hoisting.
    import asyncio

    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from src.config.settings import get_settings
    from src.models.scan import ScanJob
    from src.models.site import Site
    from src.services.classification import classify_single_cookie
    from src.services.scanner import (
        add_scan_result,
        complete_scan_job,
        start_scan_job,
        sync_scan_results_to_cookies,
    )

    settings = get_settings()
    job_uuid = uuid.UUID(scan_job_id)
    site_uuid = uuid.UUID(site_id)

    async def _execute() -> dict:
        # Fresh engine per task invocation; disposed in the finally.
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                # Load the job
                result = await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
                job = result.scalar_one_or_none()
                if job is None:
                    return {"error": "Scan job not found"}

                # Load the site to get the domain
                site_result = await db.execute(select(Site).where(Site.id == site_uuid))
                site = site_result.scalar_one_or_none()
                if site is None:
                    return {"error": "Site not found"}

                # Transition to running; committed immediately so the
                # status is visible to other readers during the crawl.
                await start_scan_job(db, job)
                await db.commit()

                # Call the scanner service
                scanner_url = f"{settings.scanner_service_url}/scan"
                max_pages = job.pages_total or 50

                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(settings.scanner_timeout_seconds)
                ) as client:
                    resp = await client.post(
                        scanner_url,
                        json={
                            "domain": site.domain,
                            "max_pages": max_pages,
                        },
                    )
                    resp.raise_for_status()
                    scan_data = resp.json()

                # Store scan results
                cookies = scan_data.get("cookies", [])
                pages_crawled = scan_data.get("pages_crawled", 0)

                for cookie in cookies:
                    # Auto-classify the cookie
                    category = await classify_single_cookie(
                        db,
                        site_id=site_uuid,
                        cookie_name=cookie["name"],
                        cookie_domain=cookie["domain"],
                    )

                    await add_scan_result(
                        db,
                        scan_job_id=job_uuid,
                        page_url=cookie.get("page_url", ""),
                        cookie_name=cookie["name"],
                        cookie_domain=cookie["domain"],
                        storage_type=cookie.get("storage_type", "cookie"),
                        attributes={
                            "path": cookie.get("path"),
                            "http_only": cookie.get("http_only"),
                            "secure": cookie.get("secure"),
                            "same_site": cookie.get("same_site"),
                            "value_length": cookie.get("value_length", 0),
                        },
                        script_source=cookie.get("script_source"),
                        auto_category=category.category_slug if category else None,
                        initiator_chain=cookie.get("initiator_chain") or None,
                    )

                await db.commit()

                # Mark job as completed
                await complete_scan_job(
                    db,
                    job,
                    pages_scanned=pages_crawled,
                    cookies_found=len(cookies),
                )
                await db.commit()

                # Sync results to cookie inventory
                new_cookies = await sync_scan_results_to_cookies(
                    db,
                    scan_job_id=job_uuid,
                    site_id=site_uuid,
                )
                await db.commit()

                logger.info(
                    "Scan %s completed: %d pages, %d cookies, %d new",
                    scan_job_id,
                    pages_crawled,
                    len(cookies),
                    new_cookies,
                )

                return {
                    "scan_job_id": scan_job_id,
                    "status": "completed",
                    "pages_scanned": pages_crawled,
                    "cookies_found": len(cookies),
                    "new_cookies_synced": new_cookies,
                }

            except httpx.HTTPError as exc:
                logger.error("Scanner service error for job %s: %s", scan_job_id, exc)
                await db.rollback()
                # Only mark failed on the final retry; otherwise let the
                # retry set status back to "running" cleanly.
                if self.request.retries >= self.max_retries:
                    await _mark_failed(db, job_uuid, f"Scanner service error: {exc}")
                # self.retry raises (Retry); re-raised with the original
                # cause attached for traceback clarity.
                raise self.retry(exc=exc, countdown=30) from exc

            except Exception as exc:
                logger.exception("Scan task failed for job %s", scan_job_id)
                await db.rollback()
                await _mark_failed(db, job_uuid, str(exc))
                return {"error": str(exc)}

            finally:
                await engine.dispose()

    return asyncio.run(_execute())
|
||||
|
||||
|
||||
async def _mark_failed(db, job_uuid: uuid.UUID, message: str) -> None:
    """Best-effort transition of a scan job to the 'failed' state.

    Swallows (but logs) any error: this runs from exception handlers,
    and a failure to record the failure must not mask the original
    problem.
    """
    from sqlalchemy import select

    from src.models.scan import ScanJob
    from src.services.scanner import complete_scan_job

    try:
        lookup = await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
        job = lookup.scalar_one_or_none()
        if job is None:
            return
        await complete_scan_job(db, job, error_message=message)
        await db.commit()
    except Exception:
        logger.exception("Failed to mark scan job %s as failed", job_uuid)
|
||||
|
||||
|
||||
@app.task(name="src.tasks.scanner.check_scheduled_scans")
def check_scheduled_scans() -> dict:
    """Periodic task: trigger scans for sites with a scan schedule.

    Scheduled via Celery Beat. NOTE(review): ``get_sites_due_for_scan``
    returns every site with a ``scan_schedule_cron`` set — no actual
    cron/due-time evaluation is performed — so each beat run triggers a
    scan for every scheduled site. Confirm that is intended.

    Returns ``{"sites_checked": N, "scans_triggered": M}``.
    """
    # Imported inside the task — presumably to keep worker startup
    # light / avoid import cycles; confirm before hoisting.
    import asyncio

    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from src.config.settings import get_settings
    from src.services.scanner import create_scan_job, get_sites_due_for_scan

    settings = get_settings()

    async def _check() -> dict:
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                sites = await get_sites_due_for_scan(db)
                triggered = 0

                for site in sites:
                    job = await create_scan_job(db, site_id=site.id, trigger="scheduled")
                    # Commit per site so the job row exists before the
                    # worker picks up the dispatched task.
                    await db.commit()
                    # Dispatch the scan task
                    run_scan.delay(str(job.id), str(site.id))
                    triggered += 1

                return {"sites_checked": len(sites), "scans_triggered": triggered}
            except Exception:
                await db.rollback()
                raise
            finally:
                await engine.dispose()

    return asyncio.run(_check())
|
||||
|
||||
|
||||
@app.task(name="src.tasks.scanner.recover_stale_scans")
def recover_stale_scans() -> dict:
    """Periodic task: detect and recover scan jobs stuck in pending/running.

    - Jobs stuck in 'pending' for >5 minutes are re-dispatched to Celery.
    - Jobs stuck in 'running' for >10 minutes are marked as failed.

    Returns {"stale_jobs_found": int, "redispatched": int, "failed": int}.
    """
    # Function-local imports — presumably to keep worker module import
    # light; confirm intent.
    import asyncio
    from datetime import UTC, datetime, timedelta

    from sqlalchemy import or_, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from src.config.settings import get_settings
    from src.models.scan import ScanJob
    from src.services.scanner import complete_scan_job

    settings = get_settings()

    async def _recover() -> dict:
        # Fresh engine per invocation; disposed in the finally below.
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                now = datetime.now(UTC)
                stale_pending_cutoff = now - timedelta(minutes=5)
                stale_running_cutoff = now - timedelta(minutes=10)

                result = await db.execute(
                    select(ScanJob).where(
                        or_(
                            # Pending too long — likely never picked up
                            (ScanJob.status == "pending")
                            & (ScanJob.created_at < stale_pending_cutoff),
                            # Running too long — likely worker died
                            # NOTE(review): a NULL started_at makes this clause
                            # false (SQL NULL semantics) — confirm that rows in
                            # 'running' always have started_at set.
                            (ScanJob.status == "running")
                            & (ScanJob.started_at < stale_running_cutoff),
                        )
                    )
                )
                stale_jobs = list(result.scalars().all())

                redispatched = 0
                failed = 0

                for job in stale_jobs:
                    if job.status == "pending":
                        # Re-dispatch to Celery
                        logger.warning("Re-dispatching stale pending scan job %s", job.id)
                        run_scan.delay(str(job.id), str(job.site_id))
                        redispatched += 1
                    elif job.status == "running":
                        # Mark as failed — the worker likely died
                        logger.warning("Failing stale running scan job %s", job.id)
                        await complete_scan_job(
                            db,
                            job,
                            error_message=(
                                "Job timed out (running too long, worker may have crashed)"
                            ),
                        )
                        failed += 1

                # One commit covers every status change made above.
                await db.commit()
                return {
                    "stale_jobs_found": len(stale_jobs),
                    "redispatched": redispatched,
                    "failed": failed,
                }
            except Exception:
                await db.rollback()
                raise
            finally:
                await engine.dispose()

    # Celery tasks are sync; drive the async body on a private event loop.
    return asyncio.run(_recover())
8
apps/api/start.sh
Executable file
8
apps/api/start.sh
Executable file
@@ -0,0 +1,8 @@
|
||||
#!/bin/sh
|
||||
exec uvicorn src.main:app \
|
||||
--host 0.0.0.0 \
|
||||
--port "${PORT:-8000}" \
|
||||
--workers "${WEB_CONCURRENCY:-1}" \
|
||||
--access-log \
|
||||
--proxy-headers \
|
||||
--forwarded-allow-ips '*'
|
||||
0
apps/api/tests/__init__.py
Normal file
0
apps/api/tests/__init__.py
Normal file
241
apps/api/tests/conftest.py
Normal file
241
apps/api/tests/conftest.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""Shared test fixtures for the CMP API test suite.
|
||||
|
||||
Provides two modes:
|
||||
- Unit tests: use `app` and `client` fixtures (no database required)
|
||||
- Integration tests: use `db_client` fixture (requires PostgreSQL)
|
||||
|
||||
Integration tests are automatically skipped when no database is available.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Disable rate limiting for the test suite. Many tests make dozens of
|
||||
# requests from the same loopback address in rapid succession and the
|
||||
# middleware would legitimately reject them as a DoS; the middleware
|
||||
# has its own dedicated test module.
|
||||
os.environ.setdefault("RATE_LIMIT_ENABLED", "false")
|
||||
os.environ.setdefault("ENVIRONMENT", "test")
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from src.main import create_app
|
||||
from src.models.base import Base
|
||||
|
||||
# ── Detect whether a test database is available ──────────────────────

_TEST_DB_URL = os.environ.get(
    "TEST_DATABASE_URL",
    os.environ.get("DATABASE_URL", ""),
)

# Only opt in when the URL points at a local server.
_HAS_DB = bool(_TEST_DB_URL) and "localhost" in _TEST_DB_URL


def _requires_db(fn):
    """Mark a test as requiring a live database.

    Also pins the event loop to session scope so that fixtures sharing the
    session-scoped engine don't get 'Future attached to a different loop'.
    """
    decorators = (
        pytest.mark.asyncio(loop_scope="session"),
        pytest.mark.skipif(not _HAS_DB, reason="No test database available"),
    )
    for decorate in decorators:
        fn = decorate(fn)
    return fn


requires_db = _requires_db
||||
|
||||
|
||||
# ── Unit test fixtures (no database) ─────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
def app():
    """Build a brand-new FastAPI application for each test."""
    application = create_app()
    return application
||||
@pytest.fixture
async def client(app):
    """Yield an in-process async HTTP client (no database involved)."""
    async with AsyncClient(
        transport=ASGITransport(app=app), base_url="http://test"
    ) as http:
        yield http
|
||||
# ── Integration test fixtures (with database) ────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
def _test_engine():
    """Session-scoped async engine bound to the test database."""
    if not _HAS_DB:
        pytest.skip("No test database available")
    engine = create_async_engine(_TEST_DB_URL, echo=False)
    return engine
|
||||
@pytest_asyncio.fixture(scope="session", loop_scope="session")
async def _setup_db(_test_engine):
    """Create all tables once per test session, then seed fixture data.

    Tests that depend on the cookie-category seed (normally applied by
    the ``0001_initial_schema`` alembic migration) get the same rows
    here so they can run without invoking alembic.
    """
    async with _test_engine.begin() as setup_conn:
        await setup_conn.run_sync(Base.metadata.create_all)
        await _seed_cookie_categories(setup_conn)

    yield

    # Teardown: drop everything so the next session starts clean.
    async with _test_engine.begin() as teardown_conn:
        await teardown_conn.run_sync(Base.metadata.drop_all)
||||
|
||||
async def _seed_cookie_categories(conn) -> None:
    """Insert the default cookie categories. Mirrors migration 0001."""
    import uuid as _uuid

    from sqlalchemy import text

    rows = [
        ("10000000-0000-0000-0000-000000000001", "Necessary", "necessary", True, 0),
        ("10000000-0000-0000-0000-000000000002", "Functional", "functional", False, 1),
        ("10000000-0000-0000-0000-000000000003", "Analytics", "analytics", False, 2),
        ("10000000-0000-0000-0000-000000000004", "Marketing", "marketing", False, 3),
        ("10000000-0000-0000-0000-000000000005", "Personalisation", "personalisation", False, 4),
    ]
    stmt = text(
        """
        INSERT INTO cookie_categories
        (id, name, slug, description, is_essential, display_order)
        VALUES (:id, :name, :slug, :description, :is_essential, :display_order)
        ON CONFLICT (slug) DO NOTHING
        """,
    )
    for row_id, name, slug, is_essential, order in rows:
        params = {
            "id": _uuid.UUID(row_id),
            "name": name,
            "slug": slug,
            "description": f"{name} cookies",
            "is_essential": is_essential,
            "display_order": order,
        }
        await conn.execute(stmt, params)
|
||||
@pytest_asyncio.fixture(loop_scope="session")
async def db_client(_test_engine, _setup_db):
    """Async HTTP client where each route handler gets its own DB session.

    Each request gets an independent session/connection so there are no
    'another operation is in progress' errors from asyncpg.
    """
    from src.db import get_db

    test_app = create_app()

    async def _fresh_session():
        # One session per request, committed on success, rolled back on error.
        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            try:
                yield session
                await session.commit()
            except Exception:
                await session.rollback()
                raise

    test_app.dependency_overrides[get_db] = _fresh_session

    async with AsyncClient(
        transport=ASGITransport(app=test_app), base_url="http://test"
    ) as http:
        yield http
||||
|
||||
# ── Auth helper fixtures ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
async def test_org(_test_engine, _setup_db):
    """Persist and return a throwaway organisation row."""
    from src.models.organisation import Organisation

    async with AsyncSession(_test_engine, expire_on_commit=False) as db:
        organisation = Organisation(
            id=uuid.uuid4(),
            name="Test Organisation",
            slug=f"test-org-{uuid.uuid4().hex[:8]}",
        )
        db.add(organisation)
        await db.commit()
        return organisation
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
async def test_user(_test_engine, _setup_db, test_org):
    """Persist and return an owner-role user with a known password."""
    from src.models.user import User
    from src.services.auth import hash_password

    async with AsyncSession(_test_engine, expire_on_commit=False) as db:
        owner = User(
            id=uuid.uuid4(),
            email=f"admin-{uuid.uuid4().hex[:8]}@test.com",
            password_hash=hash_password("TestPassword123"),
            full_name="Test Admin",
            role="owner",
            organisation_id=test_org.id,
        )
        db.add(owner)
        await db.commit()
        return owner
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
async def auth_token(test_user):
    """Mint a valid access JWT for the seeded test user."""
    from src.services.auth import create_access_token

    claims = {
        "user_id": str(test_user.id),
        "organisation_id": str(test_user.organisation_id),
        "role": test_user.role,
        "email": test_user.email,
    }
    return create_access_token(**claims)
||||
|
||||
@pytest_asyncio.fixture(loop_scope="session")
async def auth_headers(auth_token):
    """HTTP headers carrying a valid Bearer token."""
    bearer = f"Bearer {auth_token}"
    return {"Authorization": bearer}
||||
|
||||
# ── Shared helper for creating sites in integration tests ────────────
|
||||
|
||||
|
||||
async def create_test_site(
    client: AsyncClient,
    headers: dict,
    *,
    domain_prefix: str = "test",
    display_name: str = "Test Site",
) -> str:
    """Create a site via the API and return its ID.

    This is a helper function (not a fixture) so it can be called
    inline within each test, avoiding async fixture event-loop issues.
    """
    payload = {
        "domain": f"{domain_prefix}-{uuid.uuid4().hex[:8]}.com",
        "display_name": display_name,
    }
    resp = await client.post("/api/v1/sites/", json=payload, headers=headers)
    assert resp.status_code == 201, f"Failed to create test site: {resp.text}"
    return resp.json()["id"]
179
apps/api/tests/test_auth.py
Normal file
179
apps/api/tests/test_auth.py
Normal file
@@ -0,0 +1,179 @@
|
||||
"""Tests for JWT authentication service and dependencies."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.services.auth import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
hash_password,
|
||||
verify_password,
|
||||
)
|
||||
|
||||
|
||||
class TestPasswordHashing:
    """Round-trip behaviour of the password hashing helpers."""

    def test_hash_and_verify(self):
        plaintext = "s3cureP@ss!"
        digest = hash_password(plaintext)
        assert digest != plaintext
        assert verify_password(plaintext, digest)

    def test_wrong_password_fails(self):
        digest = hash_password("correct")
        assert not verify_password("wrong", digest)

    def test_different_hashes_for_same_password(self):
        first = hash_password("same")
        second = hash_password("same")
        assert first != second  # bcrypt salts differ
||||
|
||||
class TestJWTTokens:
    """Issue/decode behaviour of access and refresh JWTs."""

    @pytest.fixture
    def user_data(self):
        # Keyword arguments accepted by create_access_token.
        return {
            "user_id": uuid.uuid4(),
            "organisation_id": uuid.uuid4(),
            "role": "admin",
            "email": "test@example.com",
        }

    def test_create_access_token_decodable(self, user_data):
        """Every identity claim survives the encode/decode round trip."""
        token = create_access_token(**user_data)
        payload = decode_token(token)
        assert payload["sub"] == str(user_data["user_id"])
        assert payload["org_id"] == str(user_data["organisation_id"])
        assert payload["role"] == "admin"
        assert payload["email"] == "test@example.com"
        assert payload["type"] == "access"

    def test_create_refresh_token_decodable(self, user_data):
        """Refresh tokens carry the subject and a 'refresh' type claim."""
        token = create_refresh_token(
            user_id=user_data["user_id"],
            organisation_id=user_data["organisation_id"],
        )
        payload = decode_token(token)
        assert payload["sub"] == str(user_data["user_id"])
        assert payload["type"] == "refresh"

    def test_access_token_expiry(self, user_data):
        """exp - iat matches the configured access-token lifetime (±5s)."""
        token = create_access_token(**user_data)
        payload = decode_token(token)
        settings = get_settings()
        exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
        iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
        delta = exp - iat
        assert abs(delta.total_seconds() - settings.jwt_access_token_expire_minutes * 60) < 5

    def test_refresh_token_expiry(self, user_data):
        """exp - iat matches the configured refresh-token lifetime (±5s)."""
        token = create_refresh_token(
            user_id=user_data["user_id"],
            organisation_id=user_data["organisation_id"],
        )
        payload = decode_token(token)
        settings = get_settings()
        exp = datetime.fromtimestamp(payload["exp"], tz=UTC)
        iat = datetime.fromtimestamp(payload["iat"], tz=UTC)
        delta = exp - iat
        expected = settings.jwt_refresh_token_expire_days * 86400
        assert abs(delta.total_seconds() - expected) < 5

    def test_expired_token_raises(self):
        """A token whose exp is in the past must be rejected on decode."""
        settings = get_settings()
        # Hand-craft an already-expired token with the real signing key.
        payload = {
            "sub": str(uuid.uuid4()),
            "org_id": str(uuid.uuid4()),
            "role": "viewer",
            "exp": datetime.now(UTC) - timedelta(hours=1),
            "iat": datetime.now(UTC) - timedelta(hours=2),
            "type": "access",
        }
        token = jwt.encode(payload, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
        with pytest.raises(JWTError):
            decode_token(token)

    def test_tampered_token_raises(self, user_data):
        """Any signature corruption must be rejected on decode."""
        token = create_access_token(**user_data)
        # Tamper with the token
        tampered = token[:-5] + "XXXXX"
        with pytest.raises(JWTError):
            decode_token(tampered)
||||
|
||||
class TestCurrentUser:
    """Role helpers exposed by the CurrentUser schema."""

    @staticmethod
    def _user(role: str, email: str) -> CurrentUser:
        # Build a CurrentUser with fresh ids and the given role/email.
        return CurrentUser(
            id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
            email=email,
            role=role,
        )

    def test_has_role(self):
        user = self._user("admin", "admin@example.com")
        assert user.has_role("admin", "owner")
        assert not user.has_role("editor", "viewer")

    def test_is_admin(self):
        assert self._user("admin", "a@b.com").is_admin
        assert not self._user("viewer", "v@b.com").is_admin
||||
|
||||
@pytest.mark.asyncio
class TestAuthEndpoints:
    """HTTP-level checks for the /api/v1/auth/me token handling."""

    async def test_me_without_token_returns_401(self, client):
        response = await client.get("/api/v1/auth/me")
        assert response.status_code == 401

    async def test_me_with_valid_token(self, client):
        """A freshly minted access token identifies the caller."""
        user_id = uuid.uuid4()
        org_id = uuid.uuid4()
        token = create_access_token(
            user_id=user_id,
            organisation_id=org_id,
            role="editor",
            email="user@example.com",
        )
        response = await client.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert response.status_code == 200
        data = response.json()
        assert data["id"] == str(user_id)
        assert data["organisation_id"] == str(org_id)
        assert data["role"] == "editor"
        assert data["email"] == "user@example.com"

    async def test_me_with_refresh_token_rejected(self, client):
        # A refresh token must not pass where an access token is required.
        token = create_refresh_token(
            user_id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
        )
        response = await client.get(
            "/api/v1/auth/me",
            headers={"Authorization": f"Bearer {token}"},
        )
        assert response.status_code == 401

    async def test_me_with_invalid_token(self, client):
        response = await client.get(
            "/api/v1/auth/me",
            headers={"Authorization": "Bearer invalid.token.here"},
        )
        assert response.status_code == 401
124
apps/api/tests/test_bootstrap.py
Normal file
124
apps/api/tests/test_bootstrap.py
Normal file
@@ -0,0 +1,124 @@
|
||||
"""Tests for the initial admin bootstrap service."""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.models.organisation import Organisation
|
||||
from src.models.user import User
|
||||
from src.services.auth import verify_password
|
||||
from src.services.bootstrap import bootstrap_initial_admin
|
||||
from tests.conftest import requires_db
|
||||
|
||||
|
||||
def _settings(**overrides) -> Settings:
    """Build a Settings object with bootstrap-relevant defaults applied."""
    defaults: dict = {
        "environment": "test",
        "initial_admin_email": None,
        "initial_admin_password": None,
        "initial_admin_full_name": "Administrator",
        "initial_org_name": "Default Organisation",
        "initial_org_slug": "default",
    }
    return Settings(**{**defaults, **overrides})
||||
|
||||
class TestBootstrapNoOp:
    """Pure unit tests — bootstrap must short-circuit before touching the DB."""

    async def test_noop_when_email_unset(self):
        cfg = _settings(initial_admin_password="pw")
        with patch("src.services.bootstrap.async_session_factory") as session_factory:
            await bootstrap_initial_admin(cfg)
            session_factory.assert_not_called()

    async def test_noop_when_password_unset(self):
        cfg = _settings(initial_admin_email="admin@example.com")
        with patch("src.services.bootstrap.async_session_factory") as session_factory:
            await bootstrap_initial_admin(cfg)
            session_factory.assert_not_called()
||||
|
||||
@requires_db
class TestBootstrapWithDatabase:
    """Integration tests — exercise the real SQL path."""

    @pytest_asyncio.fixture(loop_scope="session")
    async def clean_db(self, _test_engine, _setup_db):
        """Strip users and orgs so bootstrap sees an empty table."""
        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            await session.execute(User.__table__.delete())
            await session.execute(Organisation.__table__.delete())
            await session.commit()
        yield
        # Teardown: leave the shared tables empty for later tests too.
        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            await session.execute(User.__table__.delete())
            await session.execute(Organisation.__table__.delete())
            await session.commit()

    async def test_creates_org_and_owner_when_empty(self, _test_engine, clean_db):
        """First run on an empty DB creates the org plus its owner user."""
        email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
        slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
        settings = _settings(
            initial_admin_email=email,
            initial_admin_password="SuperSecret123",
            initial_org_slug=slug,
            initial_org_name="Bootstrapped Org",
        )

        # Bootstrap opens its own session through this factory.
        def _factory():
            return AsyncSession(_test_engine, expire_on_commit=False)

        with patch("src.services.bootstrap.async_session_factory", _factory):
            await bootstrap_initial_admin(settings)

        # Verify via an independent session that both rows were committed.
        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            user = (await session.execute(select(User).where(User.email == email))).scalar_one()
            org = (
                await session.execute(select(Organisation).where(Organisation.slug == slug))
            ).scalar_one()

        assert user.role == "owner"
        assert user.organisation_id == org.id
        assert user.full_name == "Administrator"
        assert verify_password("SuperSecret123", user.password_hash)
        assert org.name == "Bootstrapped Org"
        assert org.contact_email == email

    async def test_idempotent_when_user_exists(self, _test_engine, clean_db):
        """A second invocation must not create a second user."""
        email = f"admin-{uuid.uuid4().hex[:8]}@example.com"
        slug = f"bootstrap-{uuid.uuid4().hex[:8]}"
        settings = _settings(
            initial_admin_email=email,
            initial_admin_password="SuperSecret123",
            initial_org_slug=slug,
        )

        def _factory():
            return AsyncSession(_test_engine, expire_on_commit=False)

        with patch("src.services.bootstrap.async_session_factory", _factory):
            await bootstrap_initial_admin(settings)
            # Second call with different credentials must be a no-op.
            await bootstrap_initial_admin(
                _settings(
                    initial_admin_email="someone-else@example.com",
                    initial_admin_password="Different123",
                    initial_org_slug=slug,
                )
            )

        async with AsyncSession(_test_engine, expire_on_commit=False) as session:
            users = (await session.execute(select(User))).scalars().all()

        assert len(users) == 1
        assert users[0].email == email
|
||||
|
||||
pytestmark = pytest.mark.asyncio(loop_scope="session")
|
||||
869
apps/api/tests/test_classification.py
Normal file
869
apps/api/tests/test_classification.py
Normal file
@@ -0,0 +1,869 @@
|
||||
"""Tests for known cookies database and auto-categorisation engine — CMP-22.
|
||||
|
||||
Covers:
|
||||
- Classification service logic (unit tests — pure functions)
|
||||
- Pattern matching (exact, wildcard, regex)
|
||||
- Priority ordering (allow-list → exact → regex → unmatched)
|
||||
- Known cookie CRUD endpoints (unit tests with mocked DB)
|
||||
- Classification endpoints (unit tests with mocked DB)
|
||||
- Schema validation
|
||||
- Integration tests against live database
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.schemas.cookie import (
|
||||
ClassificationResultResponse,
|
||||
ClassifySingleRequest,
|
||||
ClassifySiteResponse,
|
||||
KnownCookieCreate,
|
||||
KnownCookieResponse,
|
||||
KnownCookieUpdate,
|
||||
)
|
||||
from src.services.classification import (
|
||||
ClassificationResult,
|
||||
MatchSource,
|
||||
_match_pattern,
|
||||
_match_regex,
|
||||
classify_cookie,
|
||||
)
|
||||
|
||||
# ── Schema tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSchemas:
    """Validate known cookie and classification schemas."""

    def test_known_cookie_create(self):
        # is_regex defaults to False when not supplied.
        kc = KnownCookieCreate(
            name_pattern="_ga",
            domain_pattern="*",
            category_id=uuid.uuid4(),
            vendor="Google",
            description="GA cookie",
        )
        assert kc.is_regex is False

    def test_known_cookie_create_regex(self):
        kc = KnownCookieCreate(
            name_pattern="_hj.*",
            domain_pattern=".*",
            category_id=uuid.uuid4(),
            is_regex=True,
        )
        assert kc.is_regex is True

    def test_known_cookie_update_partial(self):
        # exclude_unset must emit only the fields explicitly provided.
        ku = KnownCookieUpdate(vendor="Updated Vendor")
        dumped = ku.model_dump(exclude_unset=True)
        assert "vendor" in dumped
        assert "category_id" not in dumped

    def test_known_cookie_response(self):
        # vendor is optional and defaults to None.
        resp = KnownCookieResponse(
            id=uuid.uuid4(),
            name_pattern="_ga",
            domain_pattern="*",
            category_id=uuid.uuid4(),
            is_regex=False,
            created_at=datetime.now(),
            updated_at=datetime.now(),
        )
        assert resp.vendor is None

    def test_classification_result_response(self):
        crr = ClassificationResultResponse(
            cookie_name="_ga",
            cookie_domain=".example.com",
            match_source="known_exact",
            matched=True,
        )
        assert crr.matched is True

    def test_classify_single_request(self):
        req = ClassifySingleRequest(cookie_name="_ga", cookie_domain=".example.com")
        assert req.cookie_name == "_ga"

    def test_classify_single_request_validation(self):
        # An empty cookie name must fail validation.
        with pytest.raises(ValueError):
            ClassifySingleRequest(cookie_name="", cookie_domain=".example.com")

    def test_classify_site_response(self):
        resp = ClassifySiteResponse(
            site_id="abc",
            total=10,
            matched=7,
            unmatched=3,
            results=[],
        )
        assert resp.matched == 7

    def test_match_source_enum(self):
        # Enum members must compare equal to their wire-format strings.
        assert MatchSource.ALLOW_LIST == "allow_list"
        assert MatchSource.KNOWN_EXACT == "known_exact"
        assert MatchSource.KNOWN_REGEX == "known_regex"
        assert MatchSource.UNMATCHED == "unmatched"
|
||||
# ── Pattern matching unit tests ──────────────────────────────────────
|
||||
|
||||
|
||||
class TestPatternMatching:
    """Test the _match_pattern and _match_regex helpers."""

    def test_exact_match(self):
        assert _match_pattern("_ga", "_ga") is True

    def test_exact_match_case_insensitive(self):
        # Matching is case-insensitive in both directions.
        assert _match_pattern("_GA", "_ga") is True
        assert _match_pattern("_ga", "_GA") is True

    def test_exact_no_match(self):
        assert _match_pattern("_ga", "_gid") is False

    def test_wildcard_star(self):
        # A bare "*" matches any value.
        assert _match_pattern("*", "_ga") is True
        assert _match_pattern("*", "anything") is True

    def test_wildcard_prefix(self):
        assert _match_pattern("_ga_*", "_ga_ABC123") is True
        assert _match_pattern("_ga_*", "_ga_") is True  # "*" may match empty
        assert _match_pattern("_ga_*", "_gid") is False

    def test_wildcard_suffix(self):
        assert _match_pattern("*.google.com", ".google.com") is True
        assert _match_pattern("*.google.com", "www.google.com") is True
        assert _match_pattern("*.google.com", ".facebook.com") is False

    def test_wildcard_middle(self):
        assert _match_pattern("_ga*id", "_ga_gid") is True  # * matches _g
        assert _match_pattern("_ga*id", "_gaid") is True
        assert _match_pattern("_ga*id", "_ga") is False  # must end in id

    def test_empty_values(self):
        # Empty pattern or value never matches.
        assert _match_pattern("", "_ga") is False
        assert _match_pattern("_ga", "") is False
        assert _match_pattern("", "") is False

    def test_regex_match(self):
        assert _match_regex(r"_hj.*", "_hjSession_12345") is True
        assert _match_regex(r"_hj.*", "_ga") is False

    def test_regex_case_insensitive(self):
        assert _match_regex(r"_hj.*", "_HJSession") is True

    def test_regex_anchored(self):
        # re.match anchors at start by default
        assert _match_regex(r"_pk_id.*", "_pk_id.abc.123") is True
        assert _match_regex(r"_pk_id.*", "x_pk_id") is False

    def test_regex_invalid_pattern(self):
        # A malformed regex must fail closed (no match), not raise.
        assert _match_regex(r"[invalid", "test") is False

    def test_regex_full_domain_match(self):
        assert _match_regex(r".*", ".example.com") is True

    def test_wildcard_dynamic_id_suffix(self):
        """Cookies with dynamic IDs should match wildcard prefix patterns."""
        assert _match_pattern("_hjSessionUser_*", "_hjSessionUser_1150536") is True
        assert _match_pattern("_hjSession_*", "_hjSession_9876543") is True
        assert _match_pattern("ri--*", "ri--zC77O2yRxuIvW5fjRAq0RdzNYaF-x") is True
        assert _match_pattern("intercom-id-*", "intercom-id-abc123def") is True
        assert _match_pattern("amp_*", "amp_ff29a3") is True
        assert _match_pattern("mp_*", "mp_abc123_mixpanel") is True

    def test_wildcard_does_not_overmatch(self):
        """Wildcard patterns should not match unrelated cookies."""
        assert _match_pattern("_hjSessionUser_*", "_hjSession_123") is False
        assert _match_pattern("ri--*", "ri-single-dash") is False
        assert _match_pattern("intercom-id-*", "intercom-session-xyz") is False
||||
|
||||
# ── Classification engine unit tests ─────────────────────────────────
|
||||
|
||||
|
||||
def _make_category(slug: str, cat_id: uuid.UUID | None = None):
    """Build a stand-in CookieCategory carrying only id and slug."""
    category = MagicMock()
    category.id = cat_id if cat_id is not None else uuid.uuid4()
    category.slug = slug
    return category
||||
|
||||
def _make_known(
    name_pattern: str,
    domain_pattern: str,
    category_id: uuid.UUID,
    vendor: str | None = None,
    description: str | None = None,
    is_regex: bool = False,
):
    """Build a stand-in KnownCookie with the fields the classifier reads."""
    stub = MagicMock()
    stub.name_pattern = name_pattern
    stub.domain_pattern = domain_pattern
    stub.category_id = category_id
    stub.vendor = vendor
    stub.description = description
    stub.is_regex = is_regex
    return stub
||||
|
||||
def _make_allow_entry(
    name_pattern: str,
    domain_pattern: str,
    category_id: uuid.UUID,
    description: str | None = None,
):
    """Build a stand-in CookieAllowListEntry for classifier tests."""
    stub = MagicMock()
    stub.name_pattern = name_pattern
    stub.domain_pattern = domain_pattern
    stub.category_id = category_id
    stub.description = description
    return stub
|
||||
|
||||
class TestClassifyCookie:
    """Test the classify_cookie pure function.

    Positional argument order, as exercised by every call below:
    (cookie_name, cookie_domain, allow_list_entries, exact_known_cookies,
    regex_known_cookies, category_map).
    """

    def setup_method(self):
        # Fresh mock categories per test; category_map lets the classifier
        # resolve a category slug from a category id.
        self.analytics_cat = _make_category("analytics")
        self.marketing_cat = _make_category("marketing")
        self.necessary_cat = _make_category("necessary")
        self.category_map = {
            self.analytics_cat.id: self.analytics_cat,
            self.marketing_cat.id: self.marketing_cat,
            self.necessary_cat.id: self.necessary_cat,
        }

    def test_exact_known_match(self):
        known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
        result = classify_cookie("_ga", ".example.com", [], [known], [], self.category_map)
        assert result.matched is True
        assert result.match_source == MatchSource.KNOWN_EXACT
        assert result.category_slug == "analytics"
        assert result.vendor == "Google"

    def test_regex_known_match(self):
        # Regex-flagged patterns go in the fifth argument (regex pool).
        known = _make_known(
            r"_hj.*",
            r".*",
            self.analytics_cat.id,
            vendor="Hotjar",
            is_regex=True,
        )
        result = classify_cookie(
            "_hjSession_123",
            ".example.com",
            [],
            [],
            [known],
            self.category_map,
        )
        assert result.matched is True
        assert result.match_source == MatchSource.KNOWN_REGEX
        assert result.vendor == "Hotjar"

    def test_allow_list_match(self):
        entry = _make_allow_entry(
            "_custom_cookie",
            "*",
            self.necessary_cat.id,
            description="Site-specific override",
        )
        result = classify_cookie(
            "_custom_cookie",
            ".example.com",
            [entry],
            [],
            [],
            self.category_map,
        )
        assert result.matched is True
        assert result.match_source == MatchSource.ALLOW_LIST
        assert result.category_slug == "necessary"

    def test_allow_list_takes_priority_over_known(self):
        """Allow-list should override known cookies database."""
        allow_entry = _make_allow_entry(
            "_ga",
            "*",
            self.necessary_cat.id,
            description="Overridden to necessary",
        )
        known = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
        result = classify_cookie(
            "_ga",
            ".example.com",
            [allow_entry],
            [known],
            [],
            self.category_map,
        )
        assert result.match_source == MatchSource.ALLOW_LIST
        assert result.category_slug == "necessary"

    def test_exact_takes_priority_over_regex(self):
        """Exact match should be preferred over regex match."""
        exact = _make_known("_ga", "*", self.analytics_cat.id, vendor="Google")
        regex = _make_known(
            r"_g.*",
            r".*",
            self.marketing_cat.id,
            vendor="Other",
            is_regex=True,
        )
        result = classify_cookie(
            "_ga",
            ".example.com",
            [],
            [exact],
            [regex],
            self.category_map,
        )
        assert result.match_source == MatchSource.KNOWN_EXACT
        assert result.category_slug == "analytics"

    def test_unmatched(self):
        # No patterns supplied at all -> classifier must report UNMATCHED.
        result = classify_cookie(
            "obscure_cookie",
            ".unknown.com",
            [],
            [],
            [],
            self.category_map,
        )
        assert result.matched is False
        assert result.match_source == MatchSource.UNMATCHED
        assert result.category_id is None

    def test_domain_must_match(self):
        """Cookie should not match if domain pattern doesn't match."""
        known = _make_known("_ga", "*.google.com", self.analytics_cat.id)
        result = classify_cookie(
            "_ga",
            ".example.com",
            [],
            [known],
            [],
            self.category_map,
        )
        assert result.matched is False

    def test_name_must_match(self):
        """Cookie should not match if name pattern doesn't match."""
        known = _make_known("_gid", "*", self.analytics_cat.id)
        result = classify_cookie(
            "_ga",
            ".example.com",
            [],
            [known],
            [],
            self.category_map,
        )
        assert result.matched is False

    def test_wildcard_domain_match(self):
        # "*.facebook.com" is expected to cover the ".facebook.com" cookie domain.
        known = _make_known(
            "fr",
            "*.facebook.com",
            self.marketing_cat.id,
            vendor="Meta",
        )
        result = classify_cookie(
            "fr",
            ".facebook.com",
            [],
            [known],
            [],
            self.category_map,
        )
        assert result.matched is True
        assert result.vendor == "Meta"

    def test_classification_result_fields(self):
        # Constructing with only name/domain exercises the schema defaults:
        # unmatched, no category.
        result = ClassificationResult(
            cookie_name="_ga",
            cookie_domain=".example.com",
        )
        assert result.category_id is None
        assert result.match_source == MatchSource.UNMATCHED
        assert result.matched is False
|
||||
|
||||
|
||||
# ── Router unit tests (mocked service) ──────────────────────────────
|
||||
|
||||
|
||||
def _mock_db():
|
||||
"""Create a mock async DB session."""
|
||||
db = AsyncMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.scalar_one_or_none.return_value = MagicMock()
|
||||
mock_result.scalars.return_value.all.return_value = []
|
||||
db.execute.return_value = mock_result
|
||||
return db
|
||||
|
||||
|
||||
async def _client(app, db):
    """Create an async test client with mocked DB and auth.

    Installs FastAPI dependency overrides on *app* so every request uses the
    supplied mock session and a fixed "owner" user, then returns an httpx
    AsyncClient driving the app in-process via ASGITransport (no real server).
    The caller is responsible for closing the client (tests use
    ``async with await _client(...)``).
    """
    # Local imports: presumably to avoid pulling the app's DB/auth modules in
    # at collection time — confirm against project convention.
    from src.db import get_db
    from src.services.dependencies import get_current_user, require_role

    user = MagicMock()
    user.organisation_id = uuid.uuid4()
    user.role = "owner"  # "owner" is used so role checks pass in these tests

    async def _override_get_db():
        # Async generator dependency, mirroring the real get_db shape.
        yield db

    app.dependency_overrides[get_db] = _override_get_db
    app.dependency_overrides[get_current_user] = lambda: user

    def _override_require_role(*_roles):
        # Whatever roles a route demands, hand back the same mock user.
        return lambda: user

    # NOTE(review): keying the override on the `require_role` factory only
    # takes effect if routes depend on the factory object itself; confirm
    # against how the routers declare their role dependencies.
    app.dependency_overrides[require_role] = _override_require_role

    transport = ASGITransport(app=app)
    return AsyncClient(transport=transport, base_url="http://test")
|
||||
|
||||
|
||||
class TestKnownCookieRoutes:
    """Test known cookie CRUD endpoints."""

    @pytest.mark.asyncio
    async def test_list_known_cookies(self, app):
        # _mock_db() returns [] from scalars().all(), so the list endpoint
        # should serialise an empty JSON array.
        db = _mock_db()
        async with await _client(app, db) as client:
            resp = await client.get("/api/v1/cookies/known")
            assert resp.status_code == 200
            assert isinstance(resp.json(), list)

    @pytest.mark.asyncio
    async def test_create_known_cookie(self, app):
        db = _mock_db()
        # Mock category validation
        cat_result = MagicMock()
        cat_result.scalar_one_or_none.return_value = MagicMock()
        # Mock the created known cookie
        known_mock = MagicMock()
        known_mock.id = uuid.uuid4()
        known_mock.name_pattern = "_ga"
        known_mock.domain_pattern = "*"
        known_mock.category_id = uuid.uuid4()
        known_mock.vendor = "Google"
        known_mock.description = "GA cookie"
        known_mock.is_regex = False
        known_mock.created_at = datetime.now()
        known_mock.updated_at = datetime.now()

        # The route is expected to hit the DB more than once; only the first
        # execute (category validation) needs a meaningful result.
        call_count = 0

        async def mock_execute(stmt):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                # Category validation
                return cat_result
            return MagicMock()

        db.execute = mock_execute
        db.flush = AsyncMock()
        db.refresh = AsyncMock(side_effect=lambda obj: None)
        db.add = MagicMock()

        # Patch the model constructor so the route "creates" our known_mock.
        with patch(
            "src.routers.cookies.KnownCookie",
            return_value=known_mock,
        ):
            async with await _client(app, db) as client:
                resp = await client.post(
                    "/api/v1/cookies/known",
                    json={
                        "name_pattern": "_ga",
                        "domain_pattern": "*",
                        "category_id": str(uuid.uuid4()),
                        "vendor": "Google",
                    },
                )
                assert resp.status_code == 201

    @pytest.mark.asyncio
    async def test_get_known_cookie_not_found(self, app):
        # scalar_one_or_none -> None simulates a missing row; expect 404.
        db = AsyncMock()
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        db.execute.return_value = mock_result

        async with await _client(app, db) as client:
            resp = await client.get(f"/api/v1/cookies/known/{uuid.uuid4()}")
            assert resp.status_code == 404
|
||||
|
||||
|
||||
class TestClassificationRoutes:
    """Test classification endpoint responses."""

    @pytest.mark.asyncio
    async def test_classify_preview(self, app):
        db = _mock_db()
        # Canned result the patched service function will return; the route
        # should serialise it verbatim (enum -> "known_exact").
        mock_result = ClassificationResult(
            cookie_name="_ga",
            cookie_domain=".example.com",
            category_id=uuid.uuid4(),
            category_slug="analytics",
            vendor="Google",
            match_source=MatchSource.KNOWN_EXACT,
            matched=True,
        )
        with patch(
            "src.routers.cookies.classify_single_cookie",
            return_value=mock_result,
        ):
            async with await _client(app, db) as client:
                resp = await client.post(
                    f"/api/v1/cookies/sites/{uuid.uuid4()}/classify/preview",
                    json={
                        "cookie_name": "_ga",
                        "cookie_domain": ".example.com",
                    },
                )
                assert resp.status_code == 200
                data = resp.json()
                assert data["matched"] is True
                assert data["match_source"] == "known_exact"
|
||||
|
||||
|
||||
# ── Integration tests ────────────────────────────────────────────────
|
||||
|
||||
|
||||
try:
|
||||
from tests.conftest import create_test_site, requires_db
|
||||
except ImportError:
|
||||
from conftest import create_test_site, requires_db
|
||||
|
||||
|
||||
@requires_db
class TestClassificationIntegration:
    """Integration tests against a live database.

    Uses the `db_client` / `auth_headers` fixtures and the shared
    `create_test_site` helper from conftest; skipped (via @requires_db)
    when no database is available.
    """

    async def _get_category_id(self, client: AsyncClient, headers: dict, slug: str) -> str:
        """Get a category ID by slug."""
        resp = await client.get("/api/v1/cookies/categories", headers=headers)
        assert resp.status_code == 200
        for cat in resp.json():
            if cat["slug"] == slug:
                return cat["id"]
        pytest.fail(f"Category '{slug}' not found")

    async def _create_known_cookie(
        self,
        client: AsyncClient,
        headers: dict,
        name_pattern: str,
        domain_pattern: str,
        category_slug: str,
        *,
        vendor: str | None = None,
        is_regex: bool = False,
    ) -> str:
        """Create a known cookie and return its ID."""
        cat_id = await self._get_category_id(client, headers, category_slug)
        resp = await client.post(
            "/api/v1/cookies/known",
            headers=headers,
            json={
                "name_pattern": name_pattern,
                "domain_pattern": domain_pattern,
                "category_id": cat_id,
                "vendor": vendor,
                "is_regex": is_regex,
            },
        )
        assert resp.status_code == 201, resp.text
        return resp.json()["id"]

    async def _create_cookie(
        self,
        client: AsyncClient,
        headers: dict,
        site_id: str,
        name: str,
        domain: str,
    ) -> str:
        """Create a pending cookie on a site and return its ID."""
        resp = await client.post(
            f"/api/v1/cookies/sites/{site_id}",
            headers=headers,
            json={"name": name, "domain": domain},
        )
        assert resp.status_code == 201, resp.text
        return resp.json()["id"]

    async def test_known_cookies_crud(self, db_client, auth_headers):
        """Test full CRUD lifecycle for known cookies."""
        cat_id = await self._get_category_id(db_client, auth_headers, "analytics")
        # Create — unique name_pattern so reruns don't collide.
        resp = await db_client.post(
            "/api/v1/cookies/known",
            headers=auth_headers,
            json={
                "name_pattern": f"_test_{uuid.uuid4().hex[:6]}",
                "domain_pattern": "*",
                "category_id": cat_id,
                "vendor": "TestVendor",
                "description": "Test cookie",
            },
        )
        assert resp.status_code == 201
        known_id = resp.json()["id"]

        # Read
        resp = await db_client.get(
            f"/api/v1/cookies/known/{known_id}",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        assert resp.json()["vendor"] == "TestVendor"

        # Update
        resp = await db_client.patch(
            f"/api/v1/cookies/known/{known_id}",
            headers=auth_headers,
            json={"vendor": "UpdatedVendor"},
        )
        assert resp.status_code == 200
        assert resp.json()["vendor"] == "UpdatedVendor"

        # List (with search)
        resp = await db_client.get(
            "/api/v1/cookies/known",
            headers=auth_headers,
            params={"vendor": "UpdatedVendor"},
        )
        assert resp.status_code == 200
        assert any(k["id"] == known_id for k in resp.json())

        # Delete
        resp = await db_client.delete(
            f"/api/v1/cookies/known/{known_id}",
            headers=auth_headers,
        )
        assert resp.status_code == 204

        # Verify deleted
        resp = await db_client.get(
            f"/api/v1/cookies/known/{known_id}",
            headers=auth_headers,
        )
        assert resp.status_code == 404

    async def test_classify_exact_match(self, db_client, auth_headers):
        """Test classification with exact known cookie match."""
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-exact")
        # Create a known cookie pattern
        pattern_name = f"_test_exact_{uuid.uuid4().hex[:6]}"
        await self._create_known_cookie(
            db_client,
            auth_headers,
            pattern_name,
            "*",
            "analytics",
            vendor="TestVendor",
        )
        # Create a pending cookie on the site
        await self._create_cookie(
            db_client,
            auth_headers,
            site_id,
            pattern_name,
            ".example.com",
        )
        # Classify
        resp = await db_client.post(
            f"/api/v1/cookies/sites/{site_id}/classify",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] >= 1
        assert data["matched"] >= 1
        matched = [r for r in data["results"] if r["matched"]]
        assert any(r["cookie_name"] == pattern_name for r in matched)

    async def test_classify_regex_match(self, db_client, auth_headers):
        """Test classification with regex known cookie match."""
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-regex")
        prefix = f"_rx_{uuid.uuid4().hex[:4]}"
        # Create regex pattern
        await self._create_known_cookie(
            db_client,
            auth_headers,
            f"{prefix}.*",
            ".*",
            "analytics",
            vendor="RegexVendor",
            is_regex=True,
        )
        # Create a cookie that should match the regex
        await self._create_cookie(
            db_client,
            auth_headers,
            site_id,
            f"{prefix}_session_123",
            ".example.com",
        )
        # Classify
        resp = await db_client.post(
            f"/api/v1/cookies/sites/{site_id}/classify",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["matched"] >= 1
        matched = [r for r in data["results"] if r["matched"]]
        assert any(r["match_source"] == "known_regex" for r in matched)

    async def test_classify_unmatched(self, db_client, auth_headers):
        """Cookies without known patterns should remain unmatched."""
        site_id = await create_test_site(
            db_client, auth_headers, domain_prefix="classify-unmatched"
        )
        unique_name = f"_unknown_{uuid.uuid4().hex[:8]}"
        await self._create_cookie(
            db_client,
            auth_headers,
            site_id,
            unique_name,
            ".obscure-domain.com",
        )
        resp = await db_client.post(
            f"/api/v1/cookies/sites/{site_id}/classify",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["unmatched"] >= 1

    async def test_classify_preview(self, db_client, auth_headers):
        """Test preview classification without saving."""
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-preview")
        resp = await db_client.post(
            f"/api/v1/cookies/sites/{site_id}/classify/preview",
            headers=auth_headers,
            json={
                "cookie_name": "_unknown_cookie",
                "cookie_domain": ".test.com",
            },
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["matched"] is False
        assert data["match_source"] == "unmatched"

    async def test_classify_allow_list_priority(self, db_client, auth_headers):
        """Allow-list entries should take priority over known cookies."""
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-allow")
        cookie_name = f"_priority_{uuid.uuid4().hex[:6]}"

        # Add to known cookies as marketing
        await self._create_known_cookie(
            db_client,
            auth_headers,
            cookie_name,
            "*",
            "marketing",
        )

        # Add to allow-list as necessary (should take priority)
        necessary_id = await self._get_category_id(db_client, auth_headers, "necessary")
        resp = await db_client.post(
            f"/api/v1/cookies/sites/{site_id}/allow-list",
            headers=auth_headers,
            json={
                "name_pattern": cookie_name,
                "domain_pattern": "*",
                "category_id": necessary_id,
            },
        )
        assert resp.status_code == 201

        # Create cookie and classify
        await self._create_cookie(
            db_client,
            auth_headers,
            site_id,
            cookie_name,
            ".example.com",
        )
        resp = await db_client.post(
            f"/api/v1/cookies/sites/{site_id}/classify",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        matched = [r for r in data["results"] if r["cookie_name"] == cookie_name]
        assert len(matched) == 1
        assert matched[0]["match_source"] == "allow_list"
        assert matched[0]["category_id"] == necessary_id

    async def test_known_cookies_not_found(self, db_client, auth_headers):
        # Random UUID should not exist in the table.
        resp = await db_client.get(
            f"/api/v1/cookies/known/{uuid.uuid4()}",
            headers=auth_headers,
        )
        assert resp.status_code == 404

    async def test_known_cookies_invalid_category(self, db_client, auth_headers):
        # A non-existent category_id must be rejected with 400.
        resp = await db_client.post(
            "/api/v1/cookies/known",
            headers=auth_headers,
            json={
                "name_pattern": "_test",
                "domain_pattern": "*",
                "category_id": str(uuid.uuid4()),
            },
        )
        assert resp.status_code == 400

    async def test_known_cookies_auth_required(self, db_client):
        """Known cookie endpoints require authentication."""
        resp = await db_client.get("/api/v1/cookies/known")
        assert resp.status_code == 401

    async def test_classify_empty_site(self, db_client, auth_headers):
        """Classifying a site with no cookies should return empty results."""
        site_id = await create_test_site(db_client, auth_headers, domain_prefix="classify-empty")
        resp = await db_client.post(
            f"/api/v1/cookies/sites/{site_id}/classify",
            headers=auth_headers,
        )
        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 0
        assert data["matched"] == 0

    async def test_list_known_cookies_search(self, db_client, auth_headers):
        """Test searching known cookies by name pattern."""
        unique = uuid.uuid4().hex[:6]
        await self._create_known_cookie(
            db_client,
            auth_headers,
            f"_search_{unique}",
            "*",
            "analytics",
        )
        resp = await db_client.get(
            "/api/v1/cookies/known",
            headers=auth_headers,
            params={"search": f"_search_{unique}"},
        )
        assert resp.status_code == 200
        results = resp.json()
        assert len(results) >= 1
        assert all(f"_search_{unique}" in r["name_pattern"] for r in results)
|
||||
597
apps/api/tests/test_compliance.py
Normal file
597
apps/api/tests/test_compliance.py
Normal file
@@ -0,0 +1,597 @@
|
||||
"""Tests for the compliance rule engine and router."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.schemas.compliance import (
|
||||
ComplianceCheckResponse,
|
||||
ComplianceIssue,
|
||||
Framework,
|
||||
FrameworkResult,
|
||||
Severity,
|
||||
)
|
||||
from src.services.compliance import (
|
||||
CCPA_RULES,
|
||||
CNIL_RULES,
|
||||
EPRIVACY_RULES,
|
||||
FRAMEWORK_RULES,
|
||||
GDPR_RULES,
|
||||
LGPD_RULES,
|
||||
SiteContext,
|
||||
calculate_overall_score,
|
||||
run_compliance_check,
|
||||
run_framework_check,
|
||||
)
|
||||
|
||||
# ── SiteContext defaults ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSiteContext:
    """SiteContext construction: defaults and keyword overrides."""

    def test_default_values(self):
        defaults = SiteContext()
        # Feature flags default to a strict, consent-first posture.
        assert defaults.tcf_enabled is False
        assert defaults.gcm_enabled is True
        assert defaults.has_reject_button is True
        assert defaults.has_granular_choices is True
        assert defaults.has_cookie_wall is False
        assert defaults.pre_ticked_boxes is False
        # Scalar defaults.
        assert defaults.blocking_mode == "opt_in"
        assert defaults.consent_expiry_days == 365

    def test_custom_values(self):
        overridden = SiteContext(
            blocking_mode="opt_out",
            consent_expiry_days=180,
            privacy_policy_url="https://example.com/privacy",
        )
        assert overridden.blocking_mode == "opt_out"
        assert overridden.consent_expiry_days == 180
        assert overridden.privacy_policy_url == "https://example.com/privacy"
|
||||
|
||||
|
||||
# ── GDPR rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGDPRRules:
    """GDPR framework checks: consent model, banner UI, and dark patterns."""

    def test_compliant_site(self):
        # Every rule satisfied -> full score, no issues, all rules passed.
        ctx = SiteContext(
            blocking_mode="opt_in",
            has_reject_button=True,
            has_granular_choices=True,
            has_cookie_wall=False,
            pre_ticked_boxes=False,
            privacy_policy_url="https://example.com/privacy",
            uncategorised_cookies=0,
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert result.score == 100
        assert result.status == "compliant"
        assert len(result.issues) == 0
        assert result.rules_passed == result.rules_checked

    def test_opt_out_mode_fails(self):
        # GDPR requires prior consent; opt-out blocking is non-compliant.
        ctx = SiteContext(
            blocking_mode="opt_out",
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert any(i.rule_id == "gdpr_opt_in" for i in result.issues)
        assert result.status == "non_compliant"

    def test_informational_mode_fails(self):
        ctx = SiteContext(
            blocking_mode="informational",
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert any(i.rule_id == "gdpr_opt_in" for i in result.issues)

    def test_no_reject_button_fails(self):
        ctx = SiteContext(
            has_reject_button=False,
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert any(i.rule_id == "gdpr_reject_button" for i in result.issues)

    def test_no_granular_consent_fails(self):
        ctx = SiteContext(
            has_granular_choices=False,
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert any(i.rule_id == "gdpr_granular" for i in result.issues)

    def test_cookie_wall_fails(self):
        ctx = SiteContext(
            has_cookie_wall=True,
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert any(i.rule_id == "gdpr_cookie_wall" for i in result.issues)

    def test_pre_ticked_fails(self):
        ctx = SiteContext(
            pre_ticked_boxes=True,
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert any(i.rule_id == "gdpr_pre_ticked" for i in result.issues)

    def test_no_privacy_policy_warns(self):
        # Missing policy is a WARNING-severity issue, not a hard failure.
        ctx = SiteContext(privacy_policy_url=None)
        result = run_framework_check(Framework.GDPR, ctx)
        policy_issues = [i for i in result.issues if i.rule_id == "gdpr_privacy_policy"]
        assert len(policy_issues) == 1
        assert policy_issues[0].severity == Severity.WARNING

    def test_uncategorised_cookies_warns(self):
        # The issue message is expected to include the cookie count.
        ctx = SiteContext(
            uncategorised_cookies=5,
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        uncat_issues = [i for i in result.issues if i.rule_id == "gdpr_uncategorised"]
        assert len(uncat_issues) == 1
        assert "5" in uncat_issues[0].message

    def test_multiple_failures_accumulate(self):
        # Worst-case site: score must clamp at 0 and report every violation.
        ctx = SiteContext(
            blocking_mode="opt_out",
            has_reject_button=False,
            has_granular_choices=False,
            has_cookie_wall=True,
            pre_ticked_boxes=True,
            privacy_policy_url=None,
            uncategorised_cookies=3,
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert result.score == 0  # Capped at 0
        assert result.status == "non_compliant"
        assert len(result.issues) >= 5
|
||||
|
||||
|
||||
# ── CNIL rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCNILRules:
    """CNIL checks: the GDPR rule set plus French expiry/lifetime caps."""

    _POLICY = "https://example.com/privacy"

    def _run(self, **overrides):
        """Build a SiteContext from keyword overrides and run the CNIL check."""
        return run_framework_check(Framework.CNIL, SiteContext(**overrides))

    @staticmethod
    def _fired(result, rule_id):
        """True when `rule_id` appears among the result's issues."""
        return any(issue.rule_id == rule_id for issue in result.issues)

    def test_compliant_site(self):
        outcome = self._run(
            blocking_mode="opt_in",
            has_reject_button=True,
            has_granular_choices=True,
            privacy_policy_url=self._POLICY,
            consent_expiry_days=180,
        )
        assert outcome.score == 100
        assert outcome.status == "compliant"

    def test_consent_expiry_too_long(self):
        # A full year exceeds the CNIL re-consent window.
        outcome = self._run(consent_expiry_days=365, privacy_policy_url=self._POLICY)
        assert self._fired(outcome, "cnil_reconsent")

    def test_consent_expiry_at_limit(self):
        # 182 days (~6 months) is still acceptable.
        outcome = self._run(consent_expiry_days=182, privacy_policy_url=self._POLICY)
        assert not self._fired(outcome, "cnil_reconsent")

    def test_cookie_lifetime_too_long(self):
        outcome = self._run(consent_expiry_days=400, privacy_policy_url=self._POLICY)
        assert self._fired(outcome, "cnil_cookie_lifetime")

    def test_cookie_lifetime_at_limit(self):
        # 395 days sits exactly at the lifetime ceiling and must not fire.
        outcome = self._run(consent_expiry_days=395, privacy_policy_url=self._POLICY)
        assert not self._fired(outcome, "cnil_cookie_lifetime")

    def test_inherits_gdpr_rules(self):
        """CNIL should check all GDPR rules plus CNIL-specific ones."""
        assert len(CNIL_RULES) > len(GDPR_RULES)

    def test_reject_first_layer(self):
        outcome = self._run(
            has_reject_button=False,
            consent_expiry_days=180,
            privacy_policy_url=self._POLICY,
        )
        assert self._fired(outcome, "cnil_reject_first_layer")
|
||||
|
||||
|
||||
# ── CCPA rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCCPARules:
    """CCPA checks: opt-out model, 'Do Not Sell' link, privacy policy."""

    _POLICY = "https://example.com/privacy"

    def _run(self, **overrides):
        """Build a SiteContext from keyword overrides and run the CCPA check."""
        return run_framework_check(Framework.CCPA, SiteContext(**overrides))

    @staticmethod
    def _fired(result, rule_id):
        """True when `rule_id` appears among the result's issues."""
        return any(issue.rule_id == rule_id for issue in result.issues)

    def test_compliant_site(self):
        outcome = self._run(
            blocking_mode="opt_out",
            privacy_policy_url=self._POLICY,
            banner_config={"show_do_not_sell_link": True},
        )
        assert outcome.score == 100
        assert outcome.status == "compliant"

    def test_opt_in_also_acceptable(self):
        # Opt-in is stricter than CCPA demands, so the opt-out rule stays quiet.
        outcome = self._run(
            blocking_mode="opt_in",
            privacy_policy_url=self._POLICY,
            banner_config={"show_do_not_sell_link": True},
        )
        assert not self._fired(outcome, "ccpa_opt_out")

    def test_informational_mode_passes_ccpa(self):
        """Informational mode is neither opt_out nor opt_in, so ccpa_opt_out fires."""
        outcome = self._run(
            blocking_mode="informational",
            privacy_policy_url=self._POLICY,
            banner_config={"show_do_not_sell_link": True},
        )
        assert self._fired(outcome, "ccpa_opt_out")

    def test_no_do_not_sell_link(self):
        outcome = self._run(
            blocking_mode="opt_out",
            privacy_policy_url=self._POLICY,
            banner_config={},
        )
        assert self._fired(outcome, "ccpa_do_not_sell")

    def test_no_banner_config_fails_dns(self):
        # No banner config at all means the link cannot be shown either.
        outcome = self._run(
            blocking_mode="opt_out",
            privacy_policy_url=self._POLICY,
            banner_config=None,
        )
        assert self._fired(outcome, "ccpa_do_not_sell")

    def test_no_privacy_policy_warns(self):
        outcome = self._run(
            blocking_mode="opt_out",
            privacy_policy_url=None,
            banner_config={"show_do_not_sell_link": True},
        )
        assert self._fired(outcome, "ccpa_privacy_policy")
|
||||
|
||||
|
||||
# ── ePrivacy rules ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEPrivacyRules:
    """ePrivacy directive: any consent model except purely informational."""

    def _run(self, mode):
        """Run the ePrivacy check for a context with the given blocking mode."""
        return run_framework_check(Framework.EPRIVACY, SiteContext(blocking_mode=mode))

    def test_compliant_site(self):
        outcome = self._run("opt_in")
        assert outcome.score == 100
        assert outcome.status == "compliant"

    def test_opt_out_passes(self):
        outcome = self._run("opt_out")
        assert all(issue.rule_id != "eprivacy_consent" for issue in outcome.issues)

    def test_informational_fails(self):
        outcome = self._run("informational")
        assert any(issue.rule_id == "eprivacy_consent" for issue in outcome.issues)
|
||||
|
||||
|
||||
# ── LGPD rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLGPDRules:
    """LGPD (Brazil): consent basis, controller disclosure, granular choices."""

    _POLICY = "https://example.com/privacy"

    def _run(self, **overrides):
        """Build a SiteContext from keyword overrides and run the LGPD check."""
        return run_framework_check(Framework.LGPD, SiteContext(**overrides))

    @staticmethod
    def _fired(result, rule_id):
        """True when `rule_id` appears among the result's issues."""
        return any(issue.rule_id == rule_id for issue in result.issues)

    def test_compliant_site(self):
        outcome = self._run(
            blocking_mode="opt_in",
            privacy_policy_url=self._POLICY,
            has_granular_choices=True,
        )
        assert outcome.score == 100
        assert outcome.status == "compliant"

    def test_informational_fails(self):
        outcome = self._run(
            blocking_mode="informational",
            privacy_policy_url=self._POLICY,
        )
        assert self._fired(outcome, "lgpd_consent_basis")

    def test_no_privacy_policy_warns(self):
        outcome = self._run(privacy_policy_url=None)
        assert self._fired(outcome, "lgpd_data_controller")

    def test_no_granular_warns(self):
        outcome = self._run(
            has_granular_choices=False,
            privacy_policy_url=self._POLICY,
        )
        assert self._fired(outcome, "lgpd_granular")

    def test_opt_out_passes(self):
        outcome = self._run(
            blocking_mode="opt_out",
            privacy_policy_url=self._POLICY,
        )
        assert not self._fired(outcome, "lgpd_consent_basis")
|
||||
|
||||
|
||||
# ── Engine orchestration ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestComplianceEngine:
    """Orchestration: which frameworks run, and in what order."""

    _ALL_FRAMEWORKS = {
        Framework.GDPR,
        Framework.CNIL,
        Framework.CCPA,
        Framework.EPRIVACY,
        Framework.LGPD,
    }

    def test_run_all_frameworks(self):
        context = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url="https://example.com/privacy",
            has_reject_button=True,
            has_granular_choices=True,
            consent_expiry_days=180,
        )
        results = run_compliance_check(context)
        assert len(results) == 5
        assert {r.framework for r in results} == self._ALL_FRAMEWORKS

    def test_run_specific_frameworks(self):
        results = run_compliance_check(SiteContext(), [Framework.GDPR, Framework.CCPA])
        # Results come back in the requested order.
        assert [r.framework for r in results] == [Framework.GDPR, Framework.CCPA]

    def test_run_single_framework(self):
        results = run_compliance_check(SiteContext(), [Framework.EPRIVACY])
        assert len(results) == 1
        assert results[0].framework == Framework.EPRIVACY

    def test_empty_frameworks_list_runs_all(self):
        # Passing None for the frameworks argument means "run everything".
        context = SiteContext(
            privacy_policy_url="https://example.com/privacy",
            consent_expiry_days=180,
        )
        assert len(run_compliance_check(context, None)) == 5
|
||||
|
||||
|
||||
class TestScoring:
    """Score arithmetic and status derivation for compliance results."""

    def test_perfect_score(self):
        # A single fully-passing framework yields an overall score of 100.
        result = FrameworkResult(
            framework=Framework.GDPR,
            score=100,
            status="compliant",
            rules_checked=7,
            rules_passed=7,
        )
        assert calculate_overall_score([result]) == 100

    def test_zero_score(self):
        # A fully failing framework contributes 0.
        result = FrameworkResult(
            framework=Framework.GDPR,
            score=0,
            status="non_compliant",
            rules_checked=7,
            rules_passed=0,
        )
        assert calculate_overall_score([result]) == 0

    def test_average_across_frameworks(self):
        # The overall score is the mean of the per-framework scores.
        results = [
            FrameworkResult(
                framework=Framework.GDPR,
                score=100,
                status="compliant",
                rules_checked=7,
                rules_passed=7,
            ),
            FrameworkResult(
                framework=Framework.CCPA,
                score=50,
                status="partial",
                rules_checked=3,
                rules_passed=1,
            ),
        ]
        assert calculate_overall_score(results) == 75

    def test_empty_results(self):
        # No frameworks checked means nothing failed: the score is 100.
        assert calculate_overall_score([]) == 100

    def test_critical_issues_deduct_20(self):
        ctx = SiteContext(
            blocking_mode="opt_out",
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        # opt_out causes one critical issue (gdpr_opt_in) → -20 points
        assert result.score == 80

    def test_warning_issues_deduct_5(self):
        ctx = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url=None,
            uncategorised_cookies=0,
        )
        result = run_framework_check(Framework.GDPR, ctx)
        # Missing privacy policy is a warning → -5 points
        assert result.score == 95

    def test_score_floors_at_zero(self):
        # Many simultaneous violations: deductions clamp at 0, never negative.
        ctx = SiteContext(
            blocking_mode="opt_out",
            has_reject_button=False,
            has_granular_choices=False,
            has_cookie_wall=True,
            pre_ticked_boxes=True,
            privacy_policy_url=None,
            uncategorised_cookies=10,
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert result.score == 0

    def test_status_non_compliant_with_critical(self):
        # Any critical issue forces the non_compliant status.
        ctx = SiteContext(blocking_mode="opt_out")
        result = run_framework_check(Framework.GDPR, ctx)
        assert result.status == "non_compliant"

    def test_status_partial_with_warnings_only(self):
        # Warnings alone downgrade the status to partial, not non_compliant.
        ctx = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url=None,
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert result.status == "partial"

    def test_status_compliant_with_no_issues(self):
        ctx = SiteContext(
            blocking_mode="opt_in",
            privacy_policy_url="https://example.com/privacy",
        )
        result = run_framework_check(Framework.GDPR, ctx)
        assert result.status == "compliant"
|
||||
|
||||
|
||||
# ── Framework registry ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestFrameworkRegistry:
    """Sanity checks over the FRAMEWORK_RULES registry contents."""

    def test_all_frameworks_registered(self):
        """Every supported framework has an entry in the registry."""
        for framework in (
            Framework.GDPR,
            Framework.CNIL,
            Framework.CCPA,
            Framework.EPRIVACY,
            Framework.LGPD,
        ):
            assert framework in FRAMEWORK_RULES

    def test_each_framework_has_rules(self):
        """No framework is registered with an empty rule list."""
        for fw, rules in FRAMEWORK_RULES.items():
            assert rules, f"{fw} has no rules"

    def test_rule_ids_are_unique_per_framework(self):
        """Rule IDs never collide within a single framework."""
        for fw, rules in FRAMEWORK_RULES.items():
            seen = [rule.rule_id for rule in rules]
            assert len(seen) == len(set(seen)), f"Duplicate rule IDs in {fw}"

    def test_gdpr_rule_count(self):
        assert len(GDPR_RULES) == 7

    def test_cnil_includes_gdpr_rules(self):
        """CNIL is a strict superset of the GDPR rule set."""
        assert {r.rule_id for r in GDPR_RULES} <= {r.rule_id for r in CNIL_RULES}

    def test_ccpa_rule_count(self):
        assert len(CCPA_RULES) == 3

    def test_eprivacy_rule_count(self):
        assert len(EPRIVACY_RULES) == 2

    def test_lgpd_rule_count(self):
        assert len(LGPD_RULES) == 3
|
||||
|
||||
|
||||
# ── Router tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestComplianceRouter:
    """HTTP-level tests for the compliance API router."""

    @pytest.fixture
    def app(self):
        # Build a fresh application instance for each test.
        from src.main import create_app

        return create_app()

    @pytest.fixture
    async def client(self, app):
        # In-process ASGI client; no real network sockets are opened.
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield ac

    async def test_list_frameworks(self, client):
        # The frameworks listing is unauthenticated and enumerates all five.
        resp = await client.get("/api/v1/compliance/frameworks")
        assert resp.status_code == 200
        data = resp.json()
        assert len(data) == 5
        ids = {fw["id"] for fw in data}
        assert ids == {"gdpr", "cnil", "ccpa", "eprivacy", "lgpd"}

    async def test_check_requires_auth(self, client):
        # Running a compliance check is an authenticated operation.
        resp = await client.post(f"/api/v1/compliance/check/{uuid.uuid4()}")
        assert resp.status_code == 401
|
||||
|
||||
|
||||
# ── Schema tests ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestSchemas:
    """Construction and field access for the compliance Pydantic schemas."""

    def test_compliance_issue_schema(self):
        issue = ComplianceIssue(
            rule_id="test_rule",
            severity=Severity.CRITICAL,
            message="Test message",
            recommendation="Test recommendation",
        )
        assert issue.rule_id == "test_rule"
        assert issue.severity == Severity.CRITICAL

    def test_framework_result_schema(self):
        result = FrameworkResult(
            framework=Framework.GDPR,
            score=85,
            status="partial",
            rules_checked=7,
            rules_passed=5,
        )
        assert result.framework == Framework.GDPR
        assert result.score == 85

    def test_compliance_check_response_schema(self):
        response = ComplianceCheckResponse(site_id="test-id", results=[], overall_score=100)
        assert response.overall_score == 100

    def test_severity_values(self):
        """Severity members compare equal to their wire strings."""
        expected = {
            Severity.CRITICAL: "critical",
            Severity.WARNING: "warning",
            Severity.INFO: "info",
        }
        for member, value in expected.items():
            assert member == value

    def test_framework_values(self):
        """Framework members compare equal to their wire strings."""
        expected = {
            Framework.GDPR: "gdpr",
            Framework.CNIL: "cnil",
            Framework.CCPA: "ccpa",
            Framework.EPRIVACY: "eprivacy",
            Framework.LGPD: "lgpd",
        }
        for member, value in expected.items():
            assert member == value
|
||||
# ══ New file: apps/api/tests/test_config_resolver.py (258 lines, @@ -0,0 +1,258 @@) ══
|
||||
"""Tests for configuration hierarchy resolver and publisher."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services.config_resolver import (
|
||||
SYSTEM_DEFAULTS,
|
||||
build_public_config,
|
||||
resolve_config,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveConfig:
    """Configuration cascade: system defaults → org → group → site → region."""

    def test_returns_system_defaults_for_empty_config(self):
        # An empty site config resolves to the full SYSTEM_DEFAULTS set.
        result = resolve_config({})
        assert result["blocking_mode"] == "opt_in"
        assert result["consent_expiry_days"] == 365
        assert result["gcm_enabled"] is True
        assert result["tcf_enabled"] is False
        assert result["gpp_enabled"] is True
        assert result["gpp_supported_apis"] == ["usnat"]
        assert result["gpc_enabled"] is True
        assert result["gpc_jurisdictions"] == [
            "US-CA",
            "US-CO",
            "US-CT",
            "US-TX",
            "US-MT",
        ]
        assert result["gpc_global_honour"] is False

    def test_site_config_overrides_defaults(self):
        site_config = {
            "blocking_mode": "opt_out",
            "consent_expiry_days": 180,
            "tcf_enabled": True,
        }
        result = resolve_config(site_config)
        assert result["blocking_mode"] == "opt_out"
        assert result["consent_expiry_days"] == 180
        assert result["tcf_enabled"] is True
        # Non-overridden values stay as defaults
        assert result["gcm_enabled"] is True

    def test_org_defaults_override_system_defaults(self):
        org_defaults = {"consent_expiry_days": 90}
        result = resolve_config({}, org_defaults=org_defaults)
        assert result["consent_expiry_days"] == 90

    def test_site_config_overrides_org_defaults(self):
        org_defaults = {"consent_expiry_days": 90}
        site_config = {"consent_expiry_days": 30}
        result = resolve_config(site_config, org_defaults=org_defaults)
        assert result["consent_expiry_days"] == 30

    def test_none_values_in_site_config_do_not_override(self):
        site_config = {"blocking_mode": None, "consent_expiry_days": 180}
        result = resolve_config(site_config)
        # None should not override the default
        assert result["blocking_mode"] == "opt_in"
        assert result["consent_expiry_days"] == 180

    def test_regional_override_applied(self):
        # A region-specific mode beats the site-wide blocking mode.
        site_config = {
            "blocking_mode": "opt_in",
            "regional_modes": {"US-CA": "opt_out", "EU": "opt_in"},
        }
        result = resolve_config(site_config, region="US-CA")
        assert result["blocking_mode"] == "opt_out"

    def test_regional_override_falls_back_to_default(self):
        # The "DEFAULT" key applies to regions with no explicit entry.
        site_config = {
            "blocking_mode": "opt_in",
            "regional_modes": {"EU": "opt_in", "DEFAULT": "opt_out"},
        }
        result = resolve_config(site_config, region="BR")
        assert result["blocking_mode"] == "opt_out"

    def test_regional_override_no_match_keeps_site_config(self):
        # No matching region and no DEFAULT: site-level mode is kept.
        site_config = {
            "blocking_mode": "opt_in",
            "regional_modes": {"EU": "opt_out"},
        }
        result = resolve_config(site_config, region="JP")
        assert result["blocking_mode"] == "opt_in"

    def test_no_region_ignores_regional_modes(self):
        # Without a region argument the regional_modes map is not consulted.
        site_config = {
            "blocking_mode": "opt_in",
            "regional_modes": {"US-CA": "opt_out"},
        }
        result = resolve_config(site_config)
        assert result["blocking_mode"] == "opt_in"

    def test_gpp_site_config_overrides_defaults(self):
        site_config = {
            "gpp_enabled": False,
            "gpp_supported_apis": ["usca", "usva"],
        }
        result = resolve_config(site_config)
        assert result["gpp_enabled"] is False
        assert result["gpp_supported_apis"] == ["usca", "usva"]

    def test_gpc_site_config_overrides_defaults(self):
        site_config = {
            "gpc_enabled": False,
            "gpc_global_honour": True,
            "gpc_jurisdictions": ["US-CA"],
        }
        result = resolve_config(site_config)
        assert result["gpc_enabled"] is False
        assert result["gpc_global_honour"] is True
        assert result["gpc_jurisdictions"] == ["US-CA"]

    def test_gpp_gpc_org_defaults_override_system(self):
        org_defaults = {
            "gpp_enabled": False,
            "gpc_global_honour": True,
        }
        result = resolve_config({}, org_defaults=org_defaults)
        assert result["gpp_enabled"] is False
        assert result["gpc_global_honour"] is True
        # Non-overridden GPP/GPC fields stay as system defaults
        assert result["gpc_enabled"] is True

    def test_gpp_gpc_site_overrides_org(self):
        org_defaults = {"gpp_supported_apis": ["usca"]}
        site_config = {"gpp_supported_apis": ["usnat", "usco"]}
        result = resolve_config(site_config, org_defaults=org_defaults)
        assert result["gpp_supported_apis"] == ["usnat", "usco"]

    def test_group_defaults_override_org_defaults(self):
        org_defaults = {"consent_expiry_days": 90, "tcf_enabled": True}
        group_defaults = {"consent_expiry_days": 60}
        result = resolve_config(
            {},
            org_defaults=org_defaults,
            group_defaults=group_defaults,
        )
        assert result["consent_expiry_days"] == 60  # Group overrides org
        assert result["tcf_enabled"] is True  # Still from org

    def test_site_config_overrides_group_defaults(self):
        group_defaults = {"consent_expiry_days": 60}
        site_config = {"consent_expiry_days": 30}
        result = resolve_config(site_config, group_defaults=group_defaults)
        assert result["consent_expiry_days"] == 30  # Site overrides group

    def test_none_in_group_defaults_does_not_override(self):
        # A None at the group layer is transparent: the org value shows through.
        org_defaults = {"consent_expiry_days": 90}
        group_defaults = {"consent_expiry_days": None}
        result = resolve_config(
            {},
            org_defaults=org_defaults,
            group_defaults=group_defaults,
        )
        assert result["consent_expiry_days"] == 90  # Org value preserved

    def test_full_hierarchy(self):
        # End-to-end cascade: org → site → region, plus pass-through extras.
        org_defaults = {
            "consent_expiry_days": 90,
            "tcf_enabled": True,
        }
        site_config = {
            "consent_expiry_days": 60,
            "banner_config": {"primaryColour": "#ff0000"},
            "regional_modes": {"EU": "opt_in", "US": "opt_out"},
        }
        result = resolve_config(site_config, org_defaults=org_defaults, region="US")
        assert result["consent_expiry_days"] == 60  # Site overrides org
        assert result["tcf_enabled"] is True  # From org defaults
        assert result["blocking_mode"] == "opt_out"  # Regional override
        assert result["banner_config"] == {"primaryColour": "#ff0000"}

    def test_full_hierarchy_with_group(self):
        # Full cascade including the group layer between org and site.
        org_defaults = {
            "consent_expiry_days": 90,
            "tcf_enabled": True,
            "blocking_mode": "opt_in",
        }
        group_defaults = {
            "consent_expiry_days": 60,
            "privacy_policy_url": "https://group.example.com/privacy",
        }
        site_config = {
            "banner_config": {"primaryColour": "#ff0000"},
            "regional_modes": {"US": "opt_out"},
        }
        result = resolve_config(
            site_config,
            org_defaults=org_defaults,
            group_defaults=group_defaults,
            region="US",
        )
        assert result["consent_expiry_days"] == 60  # From group
        assert result["tcf_enabled"] is True  # From org
        assert result["blocking_mode"] == "opt_out"  # Regional override
        assert result["privacy_policy_url"] == "https://group.example.com/privacy"  # From group
        assert result["banner_config"] == {"primaryColour": "#ff0000"}  # From site
|
||||
|
||||
|
||||
class TestBuildPublicConfig:
    """Shape of the public (SDK-facing) config payload."""

    def test_includes_required_fields(self):
        """The payload carries the site id plus all consumer-visible settings."""
        site_id = str(uuid.uuid4())
        payload = build_public_config(site_id, {**SYSTEM_DEFAULTS, "id": "config-123"})

        assert payload["site_id"] == site_id
        assert payload["id"] == "config-123"
        assert payload["blocking_mode"] == "opt_in"
        assert payload["consent_expiry_days"] == 365
        for key in ("gcm_enabled", "tcf_enabled", "banner_config"):
            assert key in payload
        assert payload["gpp_enabled"] is True
        assert payload["gpp_supported_apis"] == ["usnat"]
        assert payload["gpc_enabled"] is True
        assert payload["gpc_jurisdictions"] == ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"]
        assert payload["gpc_global_honour"] is False

    def test_strips_unknown_internal_fields(self):
        """Internal-only keys never leak into the public payload."""
        site_id = str(uuid.uuid4())
        resolved = {
            **SYSTEM_DEFAULTS,
            "id": "",
            "internal_field": "should_not_appear",
            "scan_enabled": True,
        }
        payload = build_public_config(site_id, resolved)
        assert "internal_field" not in payload
        assert "scan_enabled" not in payload
|
||||
|
||||
|
||||
class TestConfigRoutes:
    """Route registration and auth checks for the config API."""

    def test_resolved_config_route_registered(self, app):
        # The resolved-config endpoint is mounted on the app.
        routes = [r.path for r in app.routes]
        assert "/api/v1/config/sites/{site_id}/resolved" in routes

    def test_publish_route_registered(self, app):
        routes = [r.path for r in app.routes]
        assert "/api/v1/config/sites/{site_id}/publish" in routes

    def test_inheritance_route_registered(self, app):
        routes = [r.path for r in app.routes]
        assert "/api/v1/config/sites/{site_id}/inheritance" in routes

    @pytest.mark.asyncio
    async def test_publish_requires_auth(self, client):
        # Publishing a config snapshot is an authenticated operation.
        site_id = uuid.uuid4()
        resp = await client.post(f"/api/v1/config/sites/{site_id}/publish")
        assert resp.status_code == 401
|
||||
# ══ New file: apps/api/tests/test_consent.py (130 lines, @@ -0,0 +1,130 @@) ══
|
||||
"""Tests for consent recording API schemas and routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.schemas.consent import (
|
||||
ConsentAction,
|
||||
ConsentRecordCreate,
|
||||
ConsentRecordResponse,
|
||||
ConsentVerifyResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestConsentSchemas:
    """Validation behaviour of the consent record request/response schemas."""

    def test_create_accept_all(self):
        # Accept-all: every category accepted, rejected list left unset.
        record = ConsentRecordCreate(
            site_id=uuid.uuid4(),
            visitor_id="visitor-abc-123",
            action=ConsentAction.ACCEPT_ALL,
            categories_accepted=["necessary", "analytics", "marketing"],
        )
        assert record.action == "accept_all"
        assert len(record.categories_accepted) == 3
        assert record.categories_rejected is None

    def test_create_custom(self):
        # Custom records may carry TCF, GCM, and geo metadata alongside choices.
        record = ConsentRecordCreate(
            site_id=uuid.uuid4(),
            visitor_id="visitor-xyz",
            action=ConsentAction.CUSTOM,
            categories_accepted=["necessary", "functional"],
            categories_rejected=["analytics", "marketing"],
            tc_string="COwQHgAAAAA",
            gcm_state={"analytics_storage": "denied", "ad_storage": "denied"},
            page_url="https://example.com/page",
            country_code="GB",
            region_code="GB-ENG",
        )
        assert record.action == "custom"
        assert record.tc_string == "COwQHgAAAAA"
        assert record.gcm_state["analytics_storage"] == "denied"

    def test_create_reject_all(self):
        record = ConsentRecordCreate(
            site_id=uuid.uuid4(),
            visitor_id="v-1",
            action=ConsentAction.REJECT_ALL,
            categories_accepted=["necessary"],
            categories_rejected=["analytics", "marketing", "functional"],
        )
        assert record.action == "reject_all"

    def test_empty_visitor_id_rejected(self):
        # An empty visitor_id fails validation (presumably a min-length
        # constraint on the schema — confirm against src.schemas.consent).
        with pytest.raises(ValidationError):
            ConsentRecordCreate(
                site_id=uuid.uuid4(),
                visitor_id="",
                action=ConsentAction.ACCEPT_ALL,
                categories_accepted=["necessary"],
            )

    def test_invalid_action_rejected(self):
        # Strings outside the ConsentAction enum are rejected.
        with pytest.raises(ValidationError):
            ConsentRecordCreate(
                site_id=uuid.uuid4(),
                visitor_id="v-1",
                action="invalid_action",
                categories_accepted=[],
            )

    def test_response_from_attributes(self):
        # The response schema accepts a full record with optionals set to None.
        resp = ConsentRecordResponse(
            id=uuid.uuid4(),
            site_id=uuid.uuid4(),
            visitor_id="v-1",
            action="accept_all",
            categories_accepted=["necessary"],
            categories_rejected=None,
            tc_string=None,
            gcm_state=None,
            page_url=None,
            country_code=None,
            region_code=None,
            consented_at="2026-01-01T00:00:00Z",
        )
        assert resp.action == "accept_all"

    def test_verify_response(self):
        # `valid` is not passed here, so it presumably defaults to True —
        # confirm the field default in src.schemas.consent.
        resp = ConsentVerifyResponse(
            id=uuid.uuid4(),
            site_id=uuid.uuid4(),
            visitor_id="v-1",
            action="accept_all",
            categories_accepted=["necessary"],
            consented_at="2026-01-01T00:00:00Z",
        )
        assert resp.valid is True
|
||||
|
||||
|
||||
class TestConsentActions:
    """String values of the ConsentAction enum."""

    def test_action_values(self):
        expected = {
            ConsentAction.ACCEPT_ALL: "accept_all",
            ConsentAction.REJECT_ALL: "reject_all",
            ConsentAction.CUSTOM: "custom",
            ConsentAction.WITHDRAW: "withdraw",
        }
        for member, value in expected.items():
            assert member == value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestConsentRoutesRegistered:
    """Presence and validation behaviour of the consent HTTP routes."""

    async def test_consent_routes_exist(self, client):
        # The OpenAPI document lists all consent endpoints.
        response = await client.get("/openapi.json")
        paths = response.json()["paths"]
        assert "/api/v1/consent/" in paths
        assert "/api/v1/consent/{consent_id}" in paths
        assert "/api/v1/consent/verify/{consent_id}" in paths

    async def test_consent_post_validates_body(self, client):
        """POST /consent rejects invalid payloads."""
        response = await client.post(
            "/api/v1/consent/",
            json={"invalid": "body"},
        )
        assert response.status_code == 422

    async def test_config_public_endpoint_exists(self, client):
        # The public per-site config endpoint is exposed.
        response = await client.get("/openapi.json")
        paths = response.json()["paths"]
        assert "/api/v1/config/sites/{site_id}" in paths
|
||||
# ══ New file: apps/api/tests/test_cookies.py (218 lines, @@ -0,0 +1,218 @@) ══
|
||||
"""Tests for cookie category, cookie, and allow-list schemas and routes."""
|
||||
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from src.schemas.cookie import (
|
||||
AllowListEntryCreate,
|
||||
AllowListEntryUpdate,
|
||||
CookieCategoryResponse,
|
||||
CookieCreate,
|
||||
CookieResponse,
|
||||
CookieUpdate,
|
||||
ReviewStatus,
|
||||
StorageType,
|
||||
)
|
||||
|
||||
# ─── Schema tests ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestStorageType:
    """String values of the StorageType enum."""

    def test_values(self):
        expected = {
            StorageType.cookie: "cookie",
            StorageType.local_storage: "local_storage",
            StorageType.session_storage: "session_storage",
            StorageType.indexed_db: "indexed_db",
        }
        for member, value in expected.items():
            assert member == value
||||
|
||||
|
||||
class TestReviewStatus:
    """String values of the ReviewStatus enum."""

    def test_values(self):
        expected = {
            ReviewStatus.pending: "pending",
            ReviewStatus.approved: "approved",
            ReviewStatus.rejected: "rejected",
        }
        for member, value in expected.items():
            assert member == value
|
||||
|
||||
|
||||
class TestCookieCreate:
    """Validation of the cookie creation schema."""

    def test_valid_minimal(self):
        # Only name and domain are required; other fields take defaults.
        schema = CookieCreate(name="_ga", domain=".example.com")
        assert schema.name == "_ga"
        assert schema.domain == ".example.com"
        assert schema.storage_type == StorageType.cookie
        assert schema.category_id is None

    def test_valid_full(self):
        # All optional cookie attributes can be supplied at creation time.
        cat_id = uuid.uuid4()
        schema = CookieCreate(
            name="_ga",
            domain=".google.com",
            storage_type=StorageType.cookie,
            category_id=cat_id,
            description="Google Analytics cookie",
            vendor="Google",
            path="/",
            max_age_seconds=63072000,
            is_http_only=False,
            is_secure=True,
            same_site="Lax",
        )
        assert schema.category_id == cat_id
        assert schema.max_age_seconds == 63072000

    def test_rejects_empty_name(self):
        with pytest.raises(ValidationError):
            CookieCreate(name="", domain=".example.com")

    def test_rejects_empty_domain(self):
        with pytest.raises(ValidationError):
            CookieCreate(name="_ga", domain="")
|
||||
|
||||
|
||||
class TestCookieUpdate:
    """Partial-update semantics of the CookieUpdate schema."""

    def test_partial_update(self):
        """Only explicitly set fields survive model_dump(exclude_unset=True)."""
        dumped = CookieUpdate(review_status=ReviewStatus.approved).model_dump(exclude_unset=True)
        assert dumped == {"review_status": ReviewStatus.approved}

    def test_update_category(self):
        """The category can be reassigned via the update schema."""
        new_category = uuid.uuid4()
        assert CookieUpdate(category_id=new_category).category_id == new_category
|
||||
|
||||
|
||||
class TestAllowListEntryCreate:
    """Validation of allow-list entry creation payloads."""

    def test_valid(self):
        category = uuid.uuid4()
        entry = AllowListEntryCreate(
            name_pattern="_ga*",
            domain_pattern=".google.com",
            category_id=category,
            description="Google Analytics cookies",
        )
        assert entry.name_pattern == "_ga*"
        assert entry.category_id == category

    def test_rejects_empty_name_pattern(self):
        """An empty name pattern fails schema validation."""
        with pytest.raises(ValidationError):
            AllowListEntryCreate(
                name_pattern="",
                domain_pattern=".example.com",
                category_id=uuid.uuid4(),
            )
|
||||
|
||||
|
||||
class TestAllowListEntryUpdate:
    """Partial-update semantics for allow-list entries."""

    def test_partial_update(self):
        """Only explicitly set fields survive model_dump(exclude_unset=True)."""
        schema = AllowListEntryUpdate(description="Updated description")
        assert schema.model_dump(exclude_unset=True) == {"description": "Updated description"}
|
||||
|
||||
|
||||
class TestCookieCategoryResponse:
    """Construction of the cookie-category response schema from plain values."""

    def test_from_dict(self):
        now = "2024-01-01T00:00:00"
        resp = CookieCategoryResponse(
            id=uuid.uuid4(),
            name="Analytics",
            slug="analytics",
            description="Analytics cookies",
            is_essential=False,
            display_order=2,
            tcf_purpose_ids=[1, 3],
            gcm_consent_types=["analytics_storage"],
            created_at=now,
            updated_at=now,
        )
        assert resp.slug == "analytics"
        assert resp.is_essential is False
|
||||
|
||||
|
||||
class TestCookieResponse:
    """Construction of the cookie response schema from plain values."""

    def test_from_dict(self):
        now = "2024-01-01T00:00:00"
        resp = CookieResponse(
            id=uuid.uuid4(),
            site_id=uuid.uuid4(),
            name="_ga",
            domain=".google.com",
            storage_type="cookie",
            review_status="pending",
            created_at=now,
            updated_at=now,
        )
        assert resp.name == "_ga"
        assert resp.review_status == "pending"
|
||||
|
||||
|
||||
# ─── Route tests ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCookieCategoryRoutes:
    """Registration of the cookie-category endpoints."""

    def test_categories_route_registered(self, app):
        """Verify the categories routes are registered in the app."""
        registered = {route.path for route in app.routes}
        assert "/api/v1/cookies/categories" in registered
        assert "/api/v1/cookies/categories/{category_id}" in registered
|
||||
|
||||
|
||||
class TestCookieRoutes:
    """Auth and validation behaviour of the per-site cookie endpoints."""

    @pytest.mark.asyncio
    async def test_list_cookies_requires_auth(self, client):
        site_id = uuid.uuid4()
        resp = await client.get(f"/api/v1/cookies/sites/{site_id}")
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_create_cookie_requires_auth(self, client):
        site_id = uuid.uuid4()
        resp = await client.post(
            f"/api/v1/cookies/sites/{site_id}",
            json={"name": "_ga", "domain": ".google.com"},
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_create_cookie_rejects_invalid_body(self, client):
        # With a bogus bearer token either auth or validation may fire first.
        site_id = uuid.uuid4()
        resp = await client.post(
            f"/api/v1/cookies/sites/{site_id}",
            json={"name": "", "domain": ""},
            headers={"Authorization": "Bearer fake-token"},
        )
        # Should return 401 (bad token) or 422 (validation)
        assert resp.status_code in (401, 422)

    @pytest.mark.asyncio
    async def test_summary_route_requires_auth(self, client):
        site_id = uuid.uuid4()
        resp = await client.get(f"/api/v1/cookies/sites/{site_id}/summary")
        assert resp.status_code == 401
|
||||
|
||||
|
||||
class TestAllowListRoutes:
    """Auth checks for the per-site allow-list endpoints."""

    @pytest.mark.asyncio
    async def test_list_allow_list_requires_auth(self, client):
        site_id = uuid.uuid4()
        resp = await client.get(f"/api/v1/cookies/sites/{site_id}/allow-list")
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_create_allow_list_requires_auth(self, client):
        site_id = uuid.uuid4()
        resp = await client.post(
            f"/api/v1/cookies/sites/{site_id}/allow-list",
            json={
                "name_pattern": "_ga*",
                "domain_pattern": ".google.com",
                "category_id": str(uuid.uuid4()),
            },
        )
        assert resp.status_code == 401

    @pytest.mark.asyncio
    async def test_delete_allow_list_requires_auth(self, client):
        site_id = uuid.uuid4()
        entry_id = uuid.uuid4()
        resp = await client.delete(f"/api/v1/cookies/sites/{site_id}/allow-list/{entry_id}")
        assert resp.status_code == 401
|
||||
# ══ New file: apps/api/tests/test_cors.py (222 lines, @@ -0,0 +1,222 @@) ══
|
||||
"""Tests for the dynamic CORS origin validation service."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.services.cors import extract_domain_from_origin, get_allowed_domains, is_origin_allowed
|
||||
|
||||
|
||||
class TestExtractDomainFromOrigin:
    """Hostname extraction from Origin header values."""

    def test_https_origin(self):
        assert extract_domain_from_origin("https://example.com") == "example.com"

    def test_http_origin(self):
        assert extract_domain_from_origin("http://example.com") == "example.com"

    def test_origin_with_port(self):
        """An explicit port is dropped from the extracted hostname."""
        assert extract_domain_from_origin("https://example.com:443") == "example.com"

    def test_origin_with_subdomain(self):
        """Subdomains are preserved, not collapsed to the apex domain."""
        assert extract_domain_from_origin("https://www.example.com") == "www.example.com"

    def test_localhost(self):
        assert extract_domain_from_origin("http://localhost:5173") == "localhost"

    def test_empty_string(self):
        assert extract_domain_from_origin("") is None

    def test_invalid_url(self):
        # urlparse("not-a-url") treats the whole string as a path, so the
        # hostname is None and extraction reports failure.
        assert extract_domain_from_origin("not-a-url") is None
|
||||
|
||||
|
||||
class TestIsOriginAllowed:
    """Origin admission against static origins and registered site domains."""

    def test_static_origin_exact_match(self):
        assert is_origin_allowed("http://localhost:5173", ["http://localhost:5173"], set()) is True

    def test_static_origin_no_match(self):
        assert is_origin_allowed("https://evil.com", ["http://localhost:5173"], set()) is False

    def test_wildcard_allows_everything(self):
        """A wildcard entry in the static list admits any origin."""
        assert is_origin_allowed("https://anything.com", ["*"], set()) is True

    def test_registered_domain_match(self):
        assert is_origin_allowed("https://example.com", [], {"example.com", "other.com"}) is True

    def test_registered_domain_case_insensitive(self):
        """Hostname comparison ignores letter case."""
        assert is_origin_allowed("https://Example.COM", [], {"example.com"}) is True

    def test_registered_domain_no_match(self):
        assert is_origin_allowed("https://evil.com", [], {"example.com"}) is False

    def test_static_takes_priority(self):
        """A static match succeeds even when registered domains also exist."""
        assert is_origin_allowed("http://localhost:5173", ["http://localhost:5173"], {"example.com"}) is True

    def test_origin_with_port_matches_domain(self):
        """An explicit port does not prevent a registered-domain match."""
        assert is_origin_allowed("https://example.com:8443", [], {"example.com"}) is True

    def test_subdomain_matches_if_registered(self):
        # www.example.com only matches if explicitly registered
        assert is_origin_allowed("https://www.example.com", [], {"example.com"}) is False

    def test_subdomain_matches_when_registered(self):
        assert is_origin_allowed("https://www.example.com", [], {"www.example.com"}) is True

    def test_empty_origin(self):
        assert is_origin_allowed("", [], {"example.com"}) is False

    def test_empty_lists(self):
        """No static origins and no registered domains denies everything."""
        assert is_origin_allowed("https://example.com", [], set()) is False
|
||||
|
||||
|
||||
class TestGetAllowedDomains:
    """``get_allowed_domains`` collects primary and additional domains,
    lower-cased, from the site rows returned by the database."""

    @staticmethod
    def _row(domain, additional=None):
        # Fake site row with a primary domain and optional extras.
        row = MagicMock()
        row.domain = domain
        row.additional_domains = additional
        return row

    @staticmethod
    def _db_returning(rows):
        # AsyncSession double whose execute() yields the given rows.
        result = MagicMock()
        result.all.return_value = rows
        session = AsyncMock()
        session.execute = AsyncMock(return_value=result)
        return session

    @pytest.mark.asyncio
    async def test_returns_primary_domains(self):
        db = self._db_returning([self._row("example.com"), self._row("other.com")])
        domains = await get_allowed_domains(db)
        assert "example.com" in domains
        assert "other.com" in domains

    @pytest.mark.asyncio
    async def test_includes_additional_domains(self):
        db = self._db_returning(
            [self._row("example.com", ["www.example.com", "app.example.com"])]
        )
        domains = await get_allowed_domains(db)
        assert "example.com" in domains
        assert "www.example.com" in domains
        assert "app.example.com" in domains

    @pytest.mark.asyncio
    async def test_lowercases_domains(self):
        db = self._db_returning([self._row("Example.COM", ["WWW.Example.COM"])])
        domains = await get_allowed_domains(db)
        assert "example.com" in domains
        assert "www.example.com" in domains

    @pytest.mark.asyncio
    async def test_empty_result(self):
        db = self._db_returning([])
        domains = await get_allowed_domains(db)
        assert domains == set()
|
||||
88
apps/api/tests/test_dependencies.py
Normal file
88
apps/api/tests/test_dependencies.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""Unit tests for auth dependencies."""
|
||||
|
||||
import uuid
|
||||
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.services.auth import create_access_token, create_refresh_token, decode_token
|
||||
|
||||
|
||||
class TestCurrentUser:
    """Role helpers on the ``CurrentUser`` auth schema."""

    @staticmethod
    def _user(role):
        # Build a user with fresh random ids and the given role.
        return CurrentUser(
            id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
            email="test@test.com",
            role=role,
        )

    def test_has_role_matching(self):
        assert self._user("admin").has_role("admin", "owner") is True

    def test_has_role_not_matching(self):
        assert self._user("viewer").has_role("admin", "owner") is False

    def test_is_admin_property(self):
        assert self._user("admin").is_admin is True

    def test_is_admin_owner(self):
        # Owners count as admins too.
        assert self._user("owner").is_admin is True

    def test_is_admin_viewer(self):
        assert self._user("viewer").is_admin is False
|
||||
|
||||
|
||||
class TestTokenCreation:
    """Round-trip checks: tokens produced by the auth service decode back
    to the claims they were created with."""

    def test_access_token_roundtrip(self):
        uid, oid = uuid.uuid4(), uuid.uuid4()
        token = create_access_token(
            user_id=uid,
            organisation_id=oid,
            role="editor",
            email="test@test.com",
        )
        claims = decode_token(token)
        assert claims["sub"] == str(uid)
        assert claims["org_id"] == str(oid)
        assert claims["role"] == "editor"
        assert claims["type"] == "access"

    def test_refresh_token_roundtrip(self):
        uid, oid = uuid.uuid4(), uuid.uuid4()
        claims = decode_token(create_refresh_token(user_id=uid, organisation_id=oid))
        assert claims["sub"] == str(uid)
        assert claims["type"] == "refresh"

    def test_access_token_is_not_refresh(self):
        # The "type" claim must distinguish the two token kinds.
        token = create_access_token(
            user_id=uuid.uuid4(),
            organisation_id=uuid.uuid4(),
            role="viewer",
            email="test@test.com",
        )
        assert decode_token(token)["type"] != "refresh"
|
||||
141
apps/api/tests/test_extensions.py
Normal file
141
apps/api/tests/test_extensions.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""Tests for the extension registry and edition detection."""
|
||||
|
||||
import pytest
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
from src.config.edition import edition_name, is_ee
|
||||
from src.extensions.registry import (
|
||||
ExtensionRegistry,
|
||||
OpenAPITag,
|
||||
discover_extensions,
|
||||
get_registry,
|
||||
)
|
||||
|
||||
# -- Edition detection -------------------------------------------------------
|
||||
|
||||
|
||||
class TestEditionDetection:
    """``is_ee()`` and ``edition_name()`` must agree with each other,
    whichever edition happens to be installed. Core tests stay
    edition-agnostic; edition-specific behaviour is covered by each
    repo's own integration tests."""

    def test_edition_name_matches_is_ee(self):
        expected = "ee" if is_ee() else "ce"
        assert edition_name() == expected

    def test_edition_name_is_valid(self):
        assert edition_name() in ("ce", "ee")
|
||||
|
||||
|
||||
# -- Extension registry (unit) ----------------------------------------------
|
||||
|
||||
|
||||
class TestExtensionRegistry:
    """Unit tests for ``ExtensionRegistry``: router registration and
    mounting, model-module bookkeeping, startup hooks, and OpenAPI tag
    merging (including de-duplication against pre-existing tags)."""

    def _make_registry(self) -> ExtensionRegistry:
        # Fresh registry per test so no state leaks between cases.
        return ExtensionRegistry()

    def test_empty_registry(self):
        reg = self._make_registry()
        assert reg.routers == []
        assert reg.model_modules == []
        assert reg.startup_hooks == []

    def test_add_router(self):
        reg = self._make_registry()
        router = APIRouter()
        reg.add_router(router, prefix="/api/v1")
        assert len(reg.routers) == 1
        # The registry must keep the exact router object and its prefix.
        assert reg.routers[0].router is router
        assert reg.routers[0].prefix == "/api/v1"

    def test_add_router_with_tags(self):
        reg = self._make_registry()
        router = APIRouter()
        tag = OpenAPITag(name="billing", description="Billing endpoints")
        reg.add_router(router, tags=[tag])
        assert reg.routers[0].tags == [tag]

    def test_add_model_module(self):
        reg = self._make_registry()
        reg.add_model_module("ee.api.src.models.billing")
        assert reg.model_modules == ["ee.api.src.models.billing"]

    def test_add_startup_hook(self):
        reg = self._make_registry()

        async def hook(app: FastAPI) -> None:
            pass

        reg.add_startup_hook(hook)
        assert len(reg.startup_hooks) == 1

    def test_apply_mounts_routers(self):
        reg = self._make_registry()
        router = APIRouter()

        # Fix: the handler returns {"ok": True}, so the correct return
        # annotation is dict[str, bool]; the previous dict[str, str] would
        # also mislead FastAPI's derived response model.
        @router.get("/test")
        async def _test() -> dict[str, bool]:
            return {"ok": True}

        reg.add_router(router, prefix="/ext")

        app = FastAPI()
        reg.apply(app)

        # The router should be included in the app routes under the prefix.
        paths = [r.path for r in app.routes]
        assert "/ext/test" in paths

    def test_apply_adds_openapi_tags(self):
        reg = self._make_registry()
        router = APIRouter()
        tag = OpenAPITag(name="billing", description="Billing endpoints")
        reg.add_router(router, tags=[tag])

        app = FastAPI()
        app.openapi_tags = []
        reg.apply(app)

        assert any(t["name"] == "billing" for t in app.openapi_tags)

    def test_apply_skips_duplicate_tags(self):
        reg = self._make_registry()
        router = APIRouter()
        tag = OpenAPITag(name="billing", description="Billing endpoints")
        reg.add_router(router, tags=[tag])

        app = FastAPI()
        # A tag with the same name already exists; apply() must not
        # duplicate it or overwrite its description.
        app.openapi_tags = [{"name": "billing", "description": "Existing"}]
        reg.apply(app)

        billing_tags = [t for t in app.openapi_tags if t["name"] == "billing"]
        assert len(billing_tags) == 1
        assert billing_tags[0]["description"] == "Existing"
|
||||
|
||||
|
||||
# -- discover_extensions -----------------------------------------------------
|
||||
|
||||
|
||||
class TestDiscoverExtensions:
    def test_discover_extensions_does_not_raise(self):
        """Discovery must be safe to call regardless of edition."""
        discover_extensions()  # must not raise
|
||||
|
||||
|
||||
# -- Global registry ---------------------------------------------------------
|
||||
|
||||
|
||||
class TestGlobalRegistry:
    def test_get_registry_returns_singleton(self):
        # Repeated calls must hand back the same module-level instance.
        first = get_registry()
        second = get_registry()
        assert first is second
|
||||
|
||||
|
||||
# -- Health endpoint with edition field --------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_reports_edition(client):
    """The /health payload advertises which edition is running."""
    response = await client.get("/health")
    assert response.status_code == 200
    assert response.json()["edition"] in ("ce", "ee")
|
||||
573
apps/api/tests/test_geoip.py
Normal file
573
apps/api/tests/test_geoip.py
Normal file
@@ -0,0 +1,573 @@
|
||||
"""Tests for the GeoIP service.
|
||||
|
||||
Covers header-based detection, IP lookup, country-to-region mapping,
|
||||
client IP extraction, and the combined detect_region flow.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
import src.services.geoip as geoip_module
|
||||
from src.services.geoip import (
|
||||
GeoResult,
|
||||
_is_private_ip,
|
||||
country_to_region,
|
||||
detect_region,
|
||||
detect_region_from_headers,
|
||||
get_client_ip,
|
||||
lookup_ip_maxmind,
|
||||
lookup_ip_region,
|
||||
)
|
||||
|
||||
# ── country_to_region ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCountryToRegion:
    """Mapping of ISO country (and US state) codes to consent regions."""

    def test_eu_country_returns_eu(self):
        for code in ("DE", "FR", "IT", "ES"):
            assert country_to_region(code) == "EU"

    def test_eu_country_case_insensitive(self):
        for code in ("de", "fr"):
            assert country_to_region(code) == "EU"

    def test_gb_returns_gb(self):
        assert country_to_region("GB") == "GB"

    def test_br_returns_br(self):
        assert country_to_region("BR") == "BR"

    def test_us_without_state(self):
        assert country_to_region("US") == "US"

    def test_us_with_state(self):
        # The state code is upper-cased and appended.
        assert country_to_region("US", "CA") == "US-CA"
        assert country_to_region("US", "ny") == "US-NY"

    def test_non_eu_country_returned_as_is(self):
        for code in ("JP", "AU", "CA"):
            assert country_to_region(code) == code
|
||||
|
||||
|
||||
# ── detect_region_from_headers ───────────────────────────────────────
|
||||
|
||||
|
||||
class TestDetectRegionFromHeaders:
    """Header-based geo detection: the built-in CDN headers plus
    operator-configured custom country/region headers."""

    @staticmethod
    def _request(headers):
        # Minimal Request double; only .headers is consulted here.
        req = MagicMock()
        req.headers = headers
        return req

    def test_cloudflare_header(self):
        result = detect_region_from_headers(self._request({"cf-ipcountry": "DE"}))
        assert result.country_code == "DE"
        assert result.region == "EU"
        assert result.is_resolved is True

    def test_vercel_header(self):
        result = detect_region_from_headers(self._request({"x-vercel-ip-country": "GB"}))
        assert result.country_code == "GB"
        assert result.region == "GB"

    def test_appengine_header(self):
        result = detect_region_from_headers(self._request({"x-appengine-country": "BR"}))
        assert result.country_code == "BR"
        assert result.region == "BR"

    def test_custom_header(self):
        result = detect_region_from_headers(self._request({"x-country-code": "JP"}))
        assert result.country_code == "JP"
        assert result.region == "JP"

    def test_no_geo_headers(self):
        result = detect_region_from_headers(self._request({}))
        assert result.country_code is None
        assert result.region is None
        assert result.is_resolved is False

    def test_ignores_xx_value(self):
        # "XX" is the conventional placeholder for "unknown".
        result = detect_region_from_headers(self._request({"cf-ipcountry": "XX"}))
        assert result.is_resolved is False

    def test_header_priority_cloudflare_first(self):
        req = self._request({"cf-ipcountry": "FR", "x-vercel-ip-country": "DE"})
        assert detect_region_from_headers(req).country_code == "FR"

    def test_case_normalisation(self):
        result = detect_region_from_headers(self._request({"cf-ipcountry": "gb"}))
        assert result.country_code == "GB"
        assert result.region == "GB"

    def test_configured_custom_header(self):
        """An operator-configured header is honoured."""
        req = self._request({"x-gclb-country": "JP"})
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_country_header = "x-gclb-country"
            result = detect_region_from_headers(req)
        assert result.country_code == "JP"
        assert result.region == "JP"

    def test_configured_custom_header_takes_priority(self):
        """When both a custom and a built-in header are present, the
        custom one wins — that's the operator's explicit choice."""
        req = self._request({"cf-ipcountry": "FR", "x-gclb-country": "JP"})
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_country_header = "x-gclb-country"
            result = detect_region_from_headers(req)
        assert result.country_code == "JP"

    def test_configured_header_falls_through_to_builtin(self):
        """If the custom header isn't present, the built-in list still
        applies."""
        req = self._request({"cf-ipcountry": "FR"})
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_country_header = "x-gclb-country"
            settings.return_value.geoip_region_header = None
            result = detect_region_from_headers(req)
        assert result.country_code == "FR"
        assert result.region == "EU"

    def test_configured_region_header_pairs_with_country(self):
        """A configured region header is paired with the custom country."""
        req = self._request({"x-gclb-country": "US", "x-gclb-region": "CA"})
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_country_header = "x-gclb-country"
            settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(req)
        assert result.country_code == "US"
        assert result.region == "US-CA"

    def test_configured_region_header_strips_country_prefix(self):
        """ISO 3166-2 subdivisions may arrive prefixed (``US-CA``)."""
        req = self._request({"x-gclb-country": "US", "x-gclb-region": "US-NY"})
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_country_header = "x-gclb-country"
            settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(req)
        assert result.region == "US-NY"

    def test_configured_region_header_missing_is_country_only(self):
        """Only country hits region-aware path if the region header is absent."""
        req = self._request({"x-gclb-country": "US"})
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_country_header = "x-gclb-country"
            settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(req)
        assert result.country_code == "US"
        assert result.region == "US"

    def test_configured_region_header_xx_ignored(self):
        """Region value of ``XX`` is treated as unknown."""
        req = self._request({"x-gclb-country": "US", "x-gclb-region": "XX"})
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_country_header = "x-gclb-country"
            settings.return_value.geoip_region_header = "x-gclb-region"
            result = detect_region_from_headers(req)
        assert result.region == "US"
|
||||
|
||||
|
||||
# ── get_client_ip ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGetClientIp:
    """Client IP extraction from proxy headers with fallbacks."""

    @staticmethod
    def _request(headers=None, client_host=None):
        req = MagicMock()
        req.headers = headers or {}
        if client_host:
            req.client = MagicMock()
            req.client.host = client_host
        else:
            # No transport-level peer available.
            req.client = None
        return req

    def test_x_forwarded_for_single(self):
        req = self._request({"x-forwarded-for": "1.2.3.4"})
        assert get_client_ip(req) == "1.2.3.4"

    def test_x_forwarded_for_multiple(self):
        # The left-most entry is the original client.
        req = self._request({"x-forwarded-for": "1.2.3.4, 5.6.7.8, 9.10.11.12"})
        assert get_client_ip(req) == "1.2.3.4"

    def test_x_real_ip(self):
        req = self._request({"x-real-ip": "1.2.3.4"})
        assert get_client_ip(req) == "1.2.3.4"

    def test_forwarded_for_takes_priority_over_real_ip(self):
        req = self._request({"x-forwarded-for": "1.1.1.1", "x-real-ip": "2.2.2.2"})
        assert get_client_ip(req) == "1.1.1.1"

    def test_falls_back_to_client_host(self):
        req = self._request(client_host="10.0.0.1")
        assert get_client_ip(req) == "10.0.0.1"

    def test_returns_none_when_no_ip(self):
        assert get_client_ip(self._request()) is None
|
||||
|
||||
|
||||
# ── _is_private_ip ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestIsPrivateIp:
    """Private / loopback addresses must never be sent to a geo lookup."""

    def test_loopback(self):
        for addr in ("127.0.0.1", "127.0.0.2"):
            assert _is_private_ip(addr) is True

    def test_private_ranges(self):
        # RFC 1918 ranges.
        for addr in ("10.0.0.1", "192.168.1.1", "172.16.0.1"):
            assert _is_private_ip(addr) is True

    def test_ipv6_loopback(self):
        assert _is_private_ip("::1") is True

    def test_localhost_string(self):
        assert _is_private_ip("localhost") is True

    def test_public_ip(self):
        for addr in ("8.8.8.8", "1.1.1.1"):
            assert _is_private_ip(addr) is False
|
||||
|
||||
|
||||
# ── lookup_ip_region ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLookupIpRegion:
    """IP lookups against the external geo API, with httpx mocked."""

    @staticmethod
    def _response(status_code=200, payload=None):
        # Fake httpx.Response with the given status and JSON body.
        resp = MagicMock()
        resp.status_code = status_code
        if payload is not None:
            resp.json.return_value = payload
        return resp

    @staticmethod
    def _patched_client(response=None, error=None):
        # Patch src.services.geoip.httpx.AsyncClient so `async with`
        # yields a client whose get() returns `response` or raises `error`.
        client = AsyncMock()
        if error is not None:
            client.get = AsyncMock(side_effect=error)
        else:
            client.get = AsyncMock(return_value=response)
        client.__aenter__ = AsyncMock(return_value=client)
        client.__aexit__ = AsyncMock(return_value=False)
        return patch("src.services.geoip.httpx.AsyncClient", return_value=client)

    @pytest.mark.asyncio
    async def test_private_ip_returns_unresolved(self):
        result = await lookup_ip_region("127.0.0.1")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_private_ip_10_range(self):
        result = await lookup_ip_region("10.0.0.1")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_successful_lookup(self):
        payload = {"status": "success", "countryCode": "DE", "region": "BY"}
        with self._patched_client(self._response(payload=payload)):
            result = await lookup_ip_region("8.8.8.8")
        assert result.country_code == "DE"
        assert result.region == "EU"

    @pytest.mark.asyncio
    async def test_failed_status(self):
        payload = {"status": "fail", "message": "invalid query"}
        with self._patched_client(self._response(payload=payload)):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_http_error(self):
        with self._patched_client(self._response(status_code=500)):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_network_exception(self):
        with self._patched_client(error=httpx.ConnectError("Connection refused")):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False

    @pytest.mark.asyncio
    async def test_us_with_state_lookup(self):
        payload = {"status": "success", "countryCode": "US", "region": "CA"}
        with self._patched_client(self._response(payload=payload)):
            result = await lookup_ip_region("8.8.8.8")
        assert result.country_code == "US"
        assert result.region == "US-CA"

    @pytest.mark.asyncio
    async def test_missing_country_code(self):
        # A "success" body without countryCode cannot be resolved.
        with self._patched_client(self._response(payload={"status": "success"})):
            result = await lookup_ip_region("8.8.8.8")
        assert result.is_resolved is False
|
||||
|
||||
|
||||
# ── detect_region (combined) ─────────────────────────────────────────
|
||||
|
||||
|
||||
class TestDetectRegion:
    """Combined flow: headers first, then IP lookup, else unresolved."""

    @pytest.mark.asyncio
    async def test_uses_headers_when_available(self):
        req = MagicMock()
        req.headers = {"cf-ipcountry": "FR"}
        req.client = None

        result = await detect_region(req)
        assert result.country_code == "FR"
        assert result.region == "EU"

    @pytest.mark.asyncio
    async def test_falls_back_to_ip_lookup(self):
        # No geo headers, but a public client IP is present.
        req = MagicMock()
        req.headers = {}
        req.client = MagicMock()
        req.client.host = "8.8.8.8"

        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {
            "status": "success",
            "countryCode": "US",
            "region": "CA",
        }

        http_client = AsyncMock()
        http_client.get = AsyncMock(return_value=resp)
        http_client.__aenter__ = AsyncMock(return_value=http_client)
        http_client.__aexit__ = AsyncMock(return_value=False)

        with patch("src.services.geoip.httpx.AsyncClient", return_value=http_client):
            result = await detect_region(req)

        assert result.country_code == "US"
        assert result.region == "US-CA"

    @pytest.mark.asyncio
    async def test_returns_unresolved_when_no_ip(self):
        req = MagicMock()
        req.headers = {}
        req.client = None

        result = await detect_region(req)
        assert result.is_resolved is False
|
||||
|
||||
|
||||
# ── GeoResult ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestGeoResult:
    """Invariants of the GeoResult value object."""

    def test_is_resolved_true(self):
        assert GeoResult(country_code="GB", region="GB").is_resolved is True

    def test_is_resolved_false(self):
        assert GeoResult(country_code=None, region=None).is_resolved is False

    def test_frozen_dataclass(self):
        result = GeoResult(country_code="GB", region="GB")
        # FrozenInstanceError subclasses AttributeError.
        with pytest.raises(AttributeError):
            result.country_code = "US"  # type: ignore[misc]
|
||||
|
||||
|
||||
# ── MaxMind database lookup ──────────────────────────────────────────
|
||||
|
||||
|
||||
class TestLookupIpMaxmind:
    """MaxMind database lookups, driving the module-level reader cache."""

    def setup_method(self):
        # Every test starts with an uninitialised module-level cache.
        geoip_module._maxmind_reader = None
        geoip_module._maxmind_initialised = False

    @staticmethod
    def _install_reader(country_iso, subdivision_iso):
        # Build a fake geoip2 reader and plant it in the module cache so
        # lookup_ip_maxmind skips the open-database path.
        reader = MagicMock()
        city = MagicMock()
        city.country.iso_code = country_iso
        if subdivision_iso is None:
            city.subdivisions = None
        else:
            city.subdivisions.most_specific.iso_code = subdivision_iso
        reader.city.return_value = city
        geoip_module._maxmind_reader = reader
        geoip_module._maxmind_initialised = True
        return reader

    def test_private_ip_returns_unresolved(self):
        assert lookup_ip_maxmind("10.0.0.1").is_resolved is False

    def test_no_db_configured_returns_unresolved(self):
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_maxmind_db_path = None
            assert lookup_ip_maxmind("8.8.8.8").is_resolved is False

    def test_successful_lookup_with_subdivision(self):
        reader = self._install_reader("US", "CA")
        result = lookup_ip_maxmind("8.8.8.8")
        assert result.country_code == "US"
        assert result.region == "US-CA"
        reader.city.assert_called_once_with("8.8.8.8")

    def test_successful_lookup_without_subdivision(self):
        self._install_reader("DE", None)
        result = lookup_ip_maxmind("8.8.8.8")
        assert result.country_code == "DE"
        assert result.region == "EU"

    def test_reader_raises_returns_unresolved(self):
        broken = MagicMock()
        broken.city.side_effect = RuntimeError("corrupt db")
        geoip_module._maxmind_reader = broken
        geoip_module._maxmind_initialised = True

        assert lookup_ip_maxmind("8.8.8.8").is_resolved is False

    def test_reader_missing_country_returns_unresolved(self):
        self._install_reader(None, None)
        assert lookup_ip_maxmind("8.8.8.8").is_resolved is False

    def test_bad_db_path_is_cached_as_failure(self):
        """A missing ``.mmdb`` file should not reopen on every request."""
        with patch("src.services.geoip.get_settings") as settings:
            settings.return_value.geoip_maxmind_db_path = "/nonexistent/geo.mmdb"
            first = lookup_ip_maxmind("8.8.8.8")
            second = lookup_ip_maxmind("1.1.1.1")
        assert first.is_resolved is False
        assert second.is_resolved is False
        # The failed open is remembered: initialised, but no reader.
        assert geoip_module._maxmind_initialised is True
        assert geoip_module._maxmind_reader is None
|
||||
|
||||
|
||||
class TestDetectRegionMaxmind:
    """detect_region must prefer a configured MaxMind database over the
    external HTTP geo API."""

    def setup_method(self):
        geoip_module._maxmind_reader = None
        geoip_module._maxmind_initialised = False

    @pytest.mark.asyncio
    async def test_uses_maxmind_before_external_api(self):
        """With MaxMind configured, ip-api.com must not be called."""
        # Plant a cached reader that resolves to GB / Scotland.
        city = MagicMock()
        city.country.iso_code = "GB"
        city.subdivisions.most_specific.iso_code = "SCT"
        reader = MagicMock()
        reader.city.return_value = city
        geoip_module._maxmind_reader = reader
        geoip_module._maxmind_initialised = True

        req = MagicMock()
        req.headers = {"x-forwarded-for": "8.8.8.8"}
        req.client = None

        with (
            patch("src.services.geoip.get_settings") as settings,
            patch("src.services.geoip.httpx.AsyncClient") as http_client,
        ):
            settings.return_value.geoip_country_header = None
            settings.return_value.geoip_region_header = None
            settings.return_value.geoip_maxmind_db_path = "/data/GeoLite2-City.mmdb"

            result = await detect_region(req)

        assert result.country_code == "GB"
        assert result.region == "GB-SCT"
        # The external API client must never have been constructed.
        http_client.assert_not_called()
|
||||
31
apps/api/tests/test_health.py
Normal file
31
apps/api/tests/test_health.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_health_endpoint(client):
    """The health probe answers 200 with a status flag and edition marker."""
    resp = await client.get("/health")
    assert resp.status_code == 200
    body = resp.json()
    assert body["status"] == "ok"
    assert body["edition"] in ("ce", "ee")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_openapi_schema(client):
    """The generated OpenAPI document carries the expected title/version."""
    resp = await client.get("/openapi.json")
    assert resp.status_code == 200
    info = resp.json()["info"]
    assert info["title"] == "ConsentOS API"
    assert info["version"] == "0.1.0"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_api_routes_registered(client):
    """Each top-level router is mounted and visible in the OpenAPI paths."""
    resp = await client.get("/openapi.json")
    registered = resp.json()["paths"]
    # One representative route per router keeps this check cheap but broad.
    for route in (
        "/health",
        "/api/v1/auth/login",
        "/api/v1/config/sites/{site_id}",
        "/api/v1/consent/",
        "/api/v1/scanner/scans",
        "/api/v1/compliance/check/{site_id}",
    ):
        assert route in registered
|
||||
89
apps/api/tests/test_integration_auth.py
Normal file
89
apps/api/tests/test_integration_auth.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Integration tests for authentication endpoints (requires database)."""
|
||||
|
||||
from tests.conftest import requires_db
|
||||
|
||||
|
||||
@requires_db
class TestAuthLogin:
    """Behaviour of ``POST /api/v1/auth/login`` across credential cases."""

    @staticmethod
    async def _login(client, email, password):
        # Single place that shapes the login payload for every case below.
        return await client.post(
            "/api/v1/auth/login",
            json={"email": email, "password": password},
        )

    async def test_login_success(self, db_client, test_user):
        resp = await self._login(db_client, test_user.email, "TestPassword123")
        assert resp.status_code == 200
        data = resp.json()
        assert "access_token" in data
        assert "refresh_token" in data
        assert data["token_type"] == "bearer"

    async def test_login_wrong_password(self, db_client, test_user):
        resp = await self._login(db_client, test_user.email, "wrong")
        assert resp.status_code == 401

    async def test_login_nonexistent_user(self, db_client):
        resp = await self._login(db_client, "nobody@test.com", "anything")
        assert resp.status_code == 401

    async def test_login_invalid_email(self, db_client):
        # Malformed address fails request validation before auth runs.
        resp = await self._login(db_client, "not-an-email", "anything")
        assert resp.status_code == 422
|
||||
|
||||
|
||||
@requires_db
class TestAuthMe:
    """Behaviour of ``GET /api/v1/auth/me``."""

    async def test_me_returns_user(self, db_client, auth_headers, test_user):
        resp = await db_client.get("/api/v1/auth/me", headers=auth_headers)
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["email"] == test_user.email
        assert payload["role"] == "owner"

    async def test_me_without_token(self, db_client):
        # No Authorization header at all -> unauthenticated.
        resp = await db_client.get("/api/v1/auth/me")
        assert resp.status_code == 401
|
||||
|
||||
|
||||
@requires_db
class TestAuthRefresh:
    """Behaviour of ``POST /api/v1/auth/refresh``."""

    async def test_refresh_returns_new_tokens(self, db_client, test_user):
        # Log in first so we hold a genuine refresh token to exchange.
        login = await db_client.post(
            "/api/v1/auth/login",
            json={"email": test_user.email, "password": "TestPassword123"},
        )
        token = login.json()["refresh_token"]

        resp = await db_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": token},
        )
        assert resp.status_code == 200
        assert "access_token" in resp.json()

    async def test_refresh_with_invalid_token(self, db_client):
        resp = await db_client.post(
            "/api/v1/auth/refresh",
            json={"refresh_token": "invalid-token"},
        )
        assert resp.status_code == 401
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user