feat: initial public release

ConsentOS — a privacy-first cookie consent management platform.

Self-hosted, source-available alternative to OneTrust, Cookiebot, and
CookieYes. Full standards coverage (IAB TCF v2.2, GPP v1, Google
Consent Mode v2, GPC, Shopify Customer Privacy API), multi-tenant
architecture with role-based access, configuration cascade
(system → org → group → site → region), dark-pattern detection in
the scanner, and a tamper-evident consent record audit trail.

This is the initial public release. Prior development history is
retained internally.

See README.md for the feature list, architecture overview, and
quick-start instructions. Licensed under the Elastic License 2.0 —
self-host freely; do not resell as a managed service.
This commit is contained in:
James Cottrill
2026-04-13 14:20:15 +00:00
commit fbf26453f2
341 changed files with 62807 additions and 0 deletions

0
apps/api/src/__init__.py Normal file
View File

View File

@@ -0,0 +1,89 @@
"""Celery application and task definitions for the CMP API.

Provides async-compatible scan scheduling via Celery with Redis as the
broker and result backend.
"""
import ssl

from celery import Celery
from celery.schedules import crontab

from src.config.settings import get_settings

settings = get_settings()

# Named `app` by Celery convention — the CLI finds it via -A src.celery_app
app = Celery(
    "cmp",
    broker=settings.redis_url,
    backend=settings.redis_url,
)

# Shared serialisation/reliability settings. acks_late together with a
# prefetch multiplier of 1 means a task is only acknowledged after it
# completes, so long-running scans are re-queued if a worker dies.
_conf: dict = {
    "task_serializer": "json",
    "accept_content": ["json"],
    "result_serializer": "json",
    "timezone": "UTC",
    "enable_utc": True,
    "task_track_started": True,
    "task_acks_late": True,
    "worker_prefetch_multiplier": 1,
}

# When using rediss:// (TLS) — e.g. Upstash — Celery requires explicit
# SSL certificate settings for both broker and backend. Verify the
# server certificate: the previous CERT_NONE disabled verification
# entirely, leaving consent traffic to the broker open to MITM.
# Managed Redis providers present publicly trusted certificates, so
# CERT_REQUIRED works without extra CA configuration.
if settings.redis_url.startswith("rediss://"):
    _conf["broker_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_REQUIRED}
    _conf["redis_backend_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_REQUIRED}

app.conf.update(**_conf)
# ── Beat schedule (periodic tasks) ──────────────────────────────────
# Defined as a module-level dict first, then assigned, so later code
# (e.g. the EE block) can extend app.conf.beat_schedule in place.
_beat_schedule: dict = {
    "check-scheduled-scans": {
        "task": "src.tasks.scanner.check_scheduled_scans",
        "schedule": crontab(minute="*/15"),  # Every 15 minutes
    },
    "recover-stale-scans": {
        "task": "src.tasks.scanner.recover_stale_scans",
        "schedule": crontab(minute="*/5"),  # Every 5 minutes
    },
    "purge-expired-consent-records": {
        "task": "src.tasks.retention.purge_expired_consent_records",
        "schedule": crontab(hour="1", minute="0"),  # Daily at 01:00 UTC
    },
}
app.conf.beat_schedule = _beat_schedule
# ── Explicit task imports ───────────────────────────────────────────
# Must be at the bottom to avoid circular imports. These ensure the
# worker process registers all @app.task definitions on startup.
import src.tasks.retention  # noqa: E402
import src.tasks.scanner  # noqa: E402, F401

# EE tasks are registered conditionally — they only exist in EE mode.
# All three imports sit in one try block, so either every EE schedule
# below is added or none are.
# NOTE(review): the beat entries reference tasks under ``src.tasks.*``
# while the imported modules live under ``ee.api.src.tasks`` — confirm
# the EE task decorators set explicit names matching these strings.
try:
    import ee.api.src.tasks.compliance_scanner
    import ee.api.src.tasks.compliance_scoring
    import ee.api.src.tasks.retention  # noqa: F401

    app.conf.beat_schedule.update(
        {
            "check-scheduled-compliance-scans": {
                "task": "src.tasks.compliance_scanner.check_scheduled_compliance_scans",
                "schedule": crontab(hour="3", minute="0"),  # Daily at 03:00 UTC
            },
            "compute-daily-compliance-scores": {
                "task": "src.tasks.compliance_scoring.compute_daily_scores",
                "schedule": crontab(hour="4", minute="0"),  # Daily at 04:00 UTC
            },
            "run-retention-purge": {
                "task": "src.tasks.retention.run_retention_purge",
                "schedule": crontab(hour="2", minute="0"),  # Daily at 02:00 UTC
            },
        }
    )
except ImportError:
    # CE mode — the ``ee`` package is not installed.
    pass

View File

View File

@@ -0,0 +1,40 @@
"""One-shot bootstrap of an initial organisation and owner user.
Usage:
python -m src.cli.bootstrap_admin
Reads ``INITIAL_ADMIN_EMAIL`` and ``INITIAL_ADMIN_PASSWORD`` (plus the
optional ``INITIAL_ADMIN_FULL_NAME``, ``INITIAL_ORG_NAME``, and
``INITIAL_ORG_SLUG``) from the environment. If the ``users`` table is
empty and both credentials are set, creates the org and owner user so
the operator can log in to the admin UI. Idempotent: if any user
already exists, exits 0 without touching the database.
Intended to be run as a one-shot init container *after* the database
migrations have been applied — typically via ``depends_on`` with
``service_healthy`` on the API container.
"""
from __future__ import annotations
import asyncio
import sys
from src.config.logging import setup_logging
from src.config.settings import get_settings
from src.services.bootstrap import bootstrap_initial_admin
async def _main() -> int:
    """Run the one-shot bootstrap and return a process exit code (always 0)."""
    cfg = get_settings()
    setup_logging(cfg.log_level)
    await bootstrap_initial_admin(cfg)
    return 0
def main() -> None:
    """Synchronous console entry point wrapping the async workflow."""
    exit_code = asyncio.run(_main())
    sys.exit(exit_code)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,137 @@
"""Seed the known_cookies table from the Open Cookie Database CSV.
Usage:
python -m src.cli.seed_known_cookies [--csv PATH] [--clear]
The Open Cookie Database is a community-maintained catalogue of ~2,200+
cookie patterns. See https://github.com/jkwakman/Open-Cookie-Database
"""
from __future__ import annotations
import argparse
import csv
import sys
import uuid
from pathlib import Path
import sqlalchemy as sa
# ---------------------------------------------------------------------------
# Category mapping: Open Cookie Database category → CMP slug
# ---------------------------------------------------------------------------
# CSV rows whose category is not listed here are skipped during import.
# Note: upstream "Security" cookies are filed under the "necessary"
# slug, and the US spelling "Personalization" maps to the British
# "personalisation" slug used by the CMP schema.
_CATEGORY_MAP: dict[str, str] = {
    "Functional": "functional",
    "Analytics": "analytics",
    "Marketing": "marketing",
    "Personalization": "personalisation",
    "Security": "necessary",
}

# Bundled copy of the CSV, three directory levels above this module.
_DEFAULT_CSV = Path(__file__).resolve().parent.parent.parent / "data" / "open-cookie-database.csv"
def _build_sync_url(async_url: str) -> str:
"""Convert an asyncpg DSN to a psycopg2 DSN for one-off scripts."""
return async_url.replace("postgresql+asyncpg://", "postgresql://")
def _parse_row(row: dict) -> dict | None:
    """Extract the persisted fields from one Open Cookie Database row.

    Returns ``None`` when the row has no mapped category or no cookie
    name. ``csv.DictReader`` fills missing trailing fields with ``None``
    (not ``""``), so every value is coerced with ``or ""`` before
    stripping — the previous ``row.get(key, "").strip()`` raised
    ``AttributeError`` on short rows.
    """
    category = (row.get("Category") or "").strip()
    slug = _CATEGORY_MAP.get(category)
    if not slug:
        return None
    name = (row.get("Cookie / Data Key name") or "").strip()
    if not name:
        return None
    wildcard = (row.get("Wildcard match") or "0").strip() == "1"
    return {
        "slug": slug,
        # If wildcard, append * to the name for glob matching.
        "name_pattern": f"{name}*" if wildcard else name,
        "domain_pattern": (row.get("Domain") or "").strip() or "*",
        "vendor": (row.get("Platform") or "").strip() or None,
        "description": (row.get("Description") or "").strip() or None,
    }


def seed(csv_path: Path, *, clear: bool = False) -> int:
    """Read the CSV and upsert rows into known_cookies.

    Returns the number of rows written — inserted or, on a
    (name_pattern, domain_pattern) conflict, updated in place.
    """
    from src.config.settings import get_settings

    settings = get_settings()
    engine = sa.create_engine(_build_sync_url(settings.database_url))
    upsert_sql = sa.text(
        """
        INSERT INTO known_cookies
            (id, name_pattern, domain_pattern, category_id,
             vendor, description, is_regex, created_at, updated_at)
        VALUES
            (:id, :name_pattern, :domain_pattern, :category_id,
             :vendor, :description, :is_regex, NOW(), NOW())
        ON CONFLICT (name_pattern, domain_pattern) DO UPDATE SET
            category_id = EXCLUDED.category_id,
            vendor = EXCLUDED.vendor,
            description = EXCLUDED.description,
            is_regex = EXCLUDED.is_regex,
            updated_at = NOW()
        """
    )
    with engine.begin() as conn:
        # Build slug → category_id lookup from the live schema; rows
        # whose mapped slug has no category are skipped below.
        rows = conn.execute(sa.text("SELECT id, slug FROM cookie_categories"))
        slug_to_id: dict[str, str] = {r[1]: str(r[0]) for r in rows}
        if clear:
            conn.execute(sa.text("DELETE FROM known_cookies"))
        written = 0
        with csv_path.open(newline="", encoding="utf-8") as fh:
            for row in csv.DictReader(fh):
                parsed = _parse_row(row)
                if parsed is None or parsed["slug"] not in slug_to_id:
                    continue
                conn.execute(
                    upsert_sql,
                    {
                        "id": str(uuid.uuid4()),
                        "name_pattern": parsed["name_pattern"],
                        "domain_pattern": parsed["domain_pattern"],
                        "category_id": slug_to_id[parsed["slug"]],
                        "vendor": parsed["vendor"],
                        "description": parsed["description"],
                        # The Open Cookie Database ships glob-style
                        # patterns only, never regexes.
                        "is_regex": False,
                    },
                )
                written += 1
    return written
def _build_parser() -> argparse.ArgumentParser:
    """Construct the CLI argument parser for the seeding script."""
    parser = argparse.ArgumentParser(description="Seed known cookies from Open Cookie Database")
    parser.add_argument(
        "--csv",
        type=Path,
        default=_DEFAULT_CSV,
        help="Path to the Open Cookie Database CSV (default: bundled copy)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Delete all existing known_cookies before importing",
    )
    return parser


def main() -> None:
    """CLI entry point: validate the CSV path, then run the import."""
    args = _build_parser().parse_args()
    if not args.csv.exists():
        print(f"Error: CSV not found at {args.csv}", file=sys.stderr)
        sys.exit(1)
    count = seed(args.csv, clear=args.clear)
    print(f"Seeded {count} known cookie patterns from {args.csv.name}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,3 @@
# Re-export the settings API so callers can import from ``src.config``
# without knowing the submodule layout.
from src.config.settings import Settings, get_settings

__all__ = ["Settings", "get_settings"]

View File

@@ -0,0 +1,26 @@
"""Edition detection for the open-core architecture.
Determines whether the application is running in community edition (CE)
or enterprise edition (EE) based on the availability of the ``ee``
package.
"""
from __future__ import annotations
from functools import lru_cache
@lru_cache(maxsize=1)
def is_ee() -> bool:
    """Return ``True`` if enterprise extensions are available.

    Detection is a plain import attempt of the optional ``ee`` package;
    the result is cached for the life of the process.
    """
    try:
        import ee  # noqa: F401
    except ImportError:
        return False
    return True


def edition_name() -> str:
    """Return a human-readable edition label (``"ee"`` or ``"ce"``)."""
    if is_ee():
        return "ee"
    return "ce"

View File

@@ -0,0 +1,26 @@
import logging
import sys
import structlog
def setup_logging(log_level: str = "INFO") -> None:
    """Configure structured logging with structlog.

    Renders human-friendly coloured output when stderr is a TTY and
    JSON lines otherwise; unknown level names fall back to INFO.
    """
    renderer = (
        structlog.dev.ConsoleRenderer()
        if sys.stderr.isatty()
        else structlog.processors.JSONRenderer()
    )
    effective_level = getattr(logging, log_level.upper(), logging.INFO)
    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.StackInfoRenderer(),
            structlog.dev.set_exc_info,
            structlog.processors.TimeStamper(fmt="iso"),
            renderer,
        ],
        wrapper_class=structlog.make_filtering_bound_logger(effective_level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
        cache_logger_on_first_use=True,
    )

View File

@@ -0,0 +1,166 @@
from functools import lru_cache
from pydantic import model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Placeholder value — the application refuses to start in non-dev
# environments if ``jwt_secret_key`` is left at this literal.
_JWT_PLACEHOLDER = "CHANGE-ME-in-production"


class Settings(BaseSettings):
    """Application settings loaded from environment variables.

    Values come from the process environment first, then from a local
    ``.env`` file; names are matched case-insensitively. A post-init
    validator (``_check_production_safety``) rejects known-unsafe
    defaults outside development-like environments.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # Application
    app_name: str = "ConsentOS API"
    app_version: str = "0.1.0"
    debug: bool = False
    environment: str = "development"
    log_level: str = "INFO"

    # Server
    host: str = "0.0.0.0"
    port: int = 8000
    # Comma-separated list; parsed by ``allowed_origins_list`` below.
    allowed_origins: str = "http://localhost:5173"

    @property
    def allowed_origins_list(self) -> list[str]:
        """Parse allowed_origins as a comma-separated string.

        Empty entries (from trailing commas or double commas) are
        dropped.
        """
        return [o.strip() for o in self.allowed_origins.split(",") if o.strip()]

    # Database
    database_url: str = "postgresql+asyncpg://consentos:consentos@localhost:5432/consentos"
    database_echo: bool = False
    database_pool_size: int = 20
    database_max_overflow: int = 10

    # Redis
    redis_url: str = "redis://localhost:6379/0"

    # JWT
    jwt_secret_key: str = _JWT_PLACEHOLDER
    jwt_algorithm: str = "HS256"
    jwt_access_token_expire_minutes: int = 30
    jwt_refresh_token_expire_days: int = 7

    # Pseudonymisation — HMAC key for IP / UA hashing on consent records.
    # Defaults to deriving from the JWT secret if not explicitly set.
    pseudonymisation_secret: str | None = None

    # Bootstrap token — required as ``X-Admin-Bootstrap-Token`` on
    # ``POST /api/v1/organisations/``. When unset (the default), the
    # endpoint is disabled. Rotate or unset after your first org is
    # provisioned to prevent further tenant creation.
    admin_bootstrap_token: str | None = None

    # Initial admin bootstrap — on first startup, if the ``users`` table
    # is empty and both credentials below are set, the API creates an
    # organisation and an owner user so the operator can log in to the
    # admin UI for the first time. Idempotent: once any user exists this
    # is a no-op, so the variables can safely remain set across restarts.
    # Rotate the password via the admin UI after first login.
    initial_admin_email: str | None = None
    initial_admin_password: str | None = None
    initial_admin_full_name: str = "Administrator"
    initial_org_name: str = "Default Organisation"
    initial_org_slug: str = "default"

    # CDN — public URL where banner scripts (consent-loader.js,
    # consent-bundle.js) are hosted. In dev the admin UI dog-foods
    # the banner so localhost:5173 works for testing; in production
    # this should be a real CDN URL (CloudFlare Pages, S3+CloudFront,
    # Cloud CDN, etc.) — see docs for setup.
    cdn_base_url: str = "http://localhost:5173"

    # Scanner service
    scanner_service_url: str = "http://localhost:8001"
    scanner_timeout_seconds: int = 300

    # Extra GeoIP country header — checked *before* the built-in list
    # (``cf-ipcountry``, ``x-vercel-ip-country``, ``x-appengine-country``,
    # ``x-country-code``). Set this when running behind a CDN/load
    # balancer that uses a non-standard header, e.g. Google Cloud
    # Load Balancer's ``x-gclb-country`` or an internal edge proxy.
    # Header names are case-insensitive. Leave unset if one of the
    # built-in headers is fine.
    geoip_country_header: str | None = None

    # Subdivision/state code header — optional companion to
    # ``GEOIP_COUNTRY_HEADER``. When both are set the API pairs them to
    # produce region keys like ``US-CA`` or ``GB-SCT`` (ISO 3166-2
    # subdivision without the country prefix). Different CDNs expose
    # this under different names: Cloudflare Enterprise uses
    # ``cf-region-code``, Vercel uses ``x-vercel-ip-country-region``,
    # GCP Load Balancer uses ``x-gclb-region``, CloudFront functions
    # use ``cloudfront-viewer-country-region``. Leave unset if you
    # only need country-level granularity.
    geoip_region_header: str | None = None

    # Local MaxMind GeoLite2/GeoIP2 City database — used as a fallback
    # when no CDN header is present. Download GeoLite2-City.mmdb from
    # https://dev.maxmind.com/geoip/geolite2-free-geolocation-data and
    # mount it into the container (e.g. ``/data/GeoLite2-City.mmdb``).
    # When unset, lookups fall back to the free external ip-api.com
    # service, which is rate-limited and should not be relied on in
    # production.
    geoip_maxmind_db_path: str | None = None

    # Rate limiting — on by default. Public endpoints (banner config +
    # consent submission) are internet-exposed and must not be DoS-able.
    # Auth endpoints get a stricter bucket via ``RateLimitMiddleware``.
    rate_limit_enabled: bool = True
    rate_limit_per_minute: int = 120

    @model_validator(mode="after")
    def _check_production_safety(self) -> "Settings":
        """Refuse to start with unsafe defaults in non-dev environments.

        Violations are collected and reported together in a single
        ValueError so the operator sees every problem at once rather
        than one per restart.
        """
        # Dev-like environments are exempt from all checks.
        if self.environment.lower() in ("development", "dev", "local", "test"):
            return self
        errors: list[str] = []
        if self.jwt_secret_key == _JWT_PLACEHOLDER:
            errors.append(
                "JWT_SECRET_KEY is set to the placeholder value "
                f"{_JWT_PLACEHOLDER!r}. Generate a strong random value "
                "(e.g. `openssl rand -base64 48`) and set it in the "
                "environment before starting the API."
            )
        if "*" in self.allowed_origins_list:
            errors.append(
                "ALLOWED_ORIGINS contains '*'. Wildcard CORS combined with "
                "allow_credentials=True is a credential-theft vector. "
                "Set ALLOWED_ORIGINS to an explicit list of trusted origins."
            )
        if errors:
            msg = "Refusing to start with unsafe configuration:\n - " + "\n - ".join(
                errors,
            )
            raise ValueError(msg)
        return self

    @property
    def pseudonymisation_key(self) -> bytes:
        """Return the HMAC key used for pseudonymising IP/UA values.

        If ``pseudonymisation_secret`` is not set, derives a per-instance
        key from the JWT secret so operators don't have to configure two
        secrets. Using JWT_SECRET directly is acceptable because the
        HMAC is one-way and the resulting hashes are not reversible.

        NOTE(review): when falling back to the JWT secret, rotating
        JWT_SECRET_KEY also changes every derived pseudonym — set
        PSEUDONYMISATION_SECRET explicitly if pseudonyms must remain
        stable across JWT key rotation.
        """
        source = self.pseudonymisation_secret or self.jwt_secret_key
        return source.encode("utf-8")
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the process-wide ``Settings`` instance.

    Cached so the environment and ``.env`` file are read once per
    process; tests can call ``get_settings.cache_clear()`` to force a
    reload.
    """
    return Settings()

View File

@@ -0,0 +1,3 @@
# Re-export the session dependency so callers can import ``src.db``
# without knowing the submodule layout.
from src.db.session import get_db

__all__ = ["get_db"]

View File

@@ -0,0 +1,31 @@
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from src.config.settings import get_settings
settings = get_settings()

# Process-wide async engine. pool_size + max_overflow is the hard
# ceiling on concurrent connections this process will open.
engine = create_async_engine(
    settings.database_url,
    echo=settings.database_echo,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
)

# expire_on_commit=False keeps ORM objects readable after the closing
# commit in get_db() without triggering lazy refresh queries.
async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)
async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields an async database session.

    Commits once the request handler returns successfully; if the
    handler (or the commit itself) raises, the transaction is rolled
    back and the exception re-raised. The context manager closes the
    session in both cases. Note the commit runs for every request,
    including read-only ones.
    """
    async with async_session_factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise

View File

@@ -0,0 +1,3 @@
# Re-export the registry entry points so callers can import
# ``src.extensions`` without knowing the submodule layout.
from src.extensions.registry import discover_extensions, get_registry

__all__ = ["discover_extensions", "get_registry"]

View File

@@ -0,0 +1,197 @@
"""Extension registry for the open-core architecture.
Provides registration hooks that allow enterprise/commercial code to inject
routers, model modules, startup tasks, and OpenAPI tags into the core
application — without the core needing any direct knowledge of the
extensions.
In community edition (CE) mode, ``discover_extensions()`` is a no-op
because the ``ee`` package is not present.
"""
from __future__ import annotations
import importlib
import logging
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from collections.abc import Callable, Coroutine
from typing import Any
from fastapi import APIRouter, FastAPI
logger = logging.getLogger(__name__)
@dataclass
class OpenAPITag:
    """Metadata for a FastAPI OpenAPI tag."""

    # Tag identifier — compared against existing app tags by name when
    # merged in ExtensionRegistry.apply().
    name: str
    description: str


@dataclass
class RouterEntry:
    """A router registered by an extension."""

    router: APIRouter
    # Mount prefix; defaults to the same prefix the core routers use.
    prefix: str = "/api/v1"
    tags: list[OpenAPITag] = field(default_factory=list)
@dataclass
class ExtensionRegistry:
    """Central registry for extension-contributed components.

    Extensions call the module-level helper functions (``register_router``,
    ``register_model_module``, etc.) which delegate to the singleton
    instance stored in ``_registry``.
    """

    routers: list[RouterEntry] = field(default_factory=list)
    model_modules: list[str] = field(default_factory=list)
    startup_hooks: list[Callable[[FastAPI], Coroutine[Any, Any, None]]] = field(
        default_factory=list,
    )
    config_enrichers: list[Callable] = field(default_factory=list)
    consent_record_hooks: list[Callable] = field(default_factory=list)

    # ------------------------------------------------------------------
    # Registration helpers
    # ------------------------------------------------------------------
    def add_router(
        self,
        router: APIRouter,
        *,
        prefix: str = "/api/v1",
        tags: list[OpenAPITag] | None = None,
    ) -> None:
        """Queue *router* (with optional tags) for mounting in apply()."""
        entry = RouterEntry(router=router, prefix=prefix, tags=tags or [])
        self.routers.append(entry)

    def add_model_module(self, module_path: str) -> None:
        """Queue a dotted module path to import for SQLAlchemy models."""
        self.model_modules.append(module_path)

    def add_startup_hook(
        self,
        hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
    ) -> None:
        """Queue an async hook to run at application startup."""
        self.startup_hooks.append(hook)

    def add_config_enricher(self, enricher: Callable) -> None:
        """Queue a callable that augments published site config."""
        self.config_enrichers.append(enricher)

    def add_consent_record_hook(self, hook: Callable) -> None:
        """Queue a callable invoked after a consent record persists."""
        self.consent_record_hooks.append(hook)

    # ------------------------------------------------------------------
    # Application wiring
    # ------------------------------------------------------------------
    def apply(self, app: FastAPI) -> None:
        """Mount all registered routers and tags onto *app*."""
        for entry in self.routers:
            # Merge each tag into the app's OpenAPI metadata, skipping
            # names that are already present.
            for tag in entry.tags:
                existing = app.openapi_tags or []
                if all(t["name"] != tag.name for t in existing):
                    existing.append(
                        {"name": tag.name, "description": tag.description},
                    )
                    app.openapi_tags = existing
            app.include_router(entry.router, prefix=entry.prefix)
        if self.routers:
            logger.info(
                "Registered %d extension router(s)",
                len(self.routers),
            )
        # Import model modules so SQLAlchemy picks them up
        for module_path in self.model_modules:
            importlib.import_module(module_path)
        if self.model_modules:
            logger.info(
                "Registered %d extension model module(s)",
                len(self.model_modules),
            )
# Singleton ------------------------------------------------------------------
# One process-wide registry; both extension registration and the app
# factory's apply() call go through this instance.
_registry = ExtensionRegistry()


def get_registry() -> ExtensionRegistry:
    """Return the global extension registry."""
    return _registry
# Convenience module-level API -----------------------------------------------
# Thin wrappers over the singleton so extension code can call these
# functions directly instead of fetching the registry first.
def register_router(
    router: APIRouter,
    *,
    prefix: str = "/api/v1",
    tags: list[OpenAPITag] | None = None,
) -> None:
    """Register an API router to be mounted at startup."""
    _registry.add_router(router, prefix=prefix, tags=tags)


def register_model_module(module_path: str) -> None:
    """Register a dotted module path whose SQLAlchemy models should be imported."""
    _registry.add_model_module(module_path)


def register_startup_hook(
    hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
) -> None:
    """Register an async callable to run during application startup."""
    _registry.add_startup_hook(hook)


def register_config_enricher(enricher: Callable) -> None:
    """Register a callable that enriches published config.

    The callable signature is ``async (site_id: UUID, db: AsyncSession, config: dict) -> None``.
    It should mutate *config* in-place to add extension-specific data
    (e.g. A/B test variants).
    """
    _registry.add_config_enricher(enricher)


def register_consent_record_hook(hook: Callable) -> None:
    """Register a callable invoked after a consent record is persisted.

    The callable signature is ``async (db: AsyncSession, consent_record) -> None``.
    It is called from ``POST /api/v1/consent`` after the record has been
    flushed to the database. Typical use: generating a consent receipt
    (EE), writing audit logs, firing webhooks.
    """
    _registry.add_consent_record_hook(hook)
# Discovery ------------------------------------------------------------------
def discover_extensions() -> None:
    """Import the EE registration module if installed.

    Enterprise edition is distributed as a separate ``consent-enterprise``
    package. When installed in the same environment, importing
    ``ee.api.src.register`` triggers its side-effect registrations. In
    community edition the import simply fails and we carry on.
    """
    try:
        import ee.api.src.register  # noqa: F401
    except ImportError:
        logger.debug("No enterprise extensions found (CE mode)")
        return
    logger.info("Enterprise extensions loaded")

210
apps/api/src/main.py Normal file
View File

@@ -0,0 +1,210 @@
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from src.config.edition import edition_name
from src.config.logging import setup_logging
from src.config.settings import get_settings
from src.extensions.registry import discover_extensions, get_registry
from src.middleware.rate_limit import RateLimitMiddleware
from src.middleware.security_headers import SecurityHeadersMiddleware
from src.routers import (
auth,
compliance,
config,
consent,
cookies,
org_config,
organisations,
scanner,
site_group_config,
site_groups,
sites,
translations,
users,
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application startup and shutdown lifecycle.

    Startup: configure structured logging from settings. No shutdown
    work is performed — execution simply resumes past the ``yield``
    when the server stops.
    """
    settings = get_settings()
    setup_logging(settings.log_level)
    yield
def create_app() -> FastAPI:
    """Application factory.

    Builds the FastAPI app, installs middleware (security headers,
    optional rate limiting, CORS), mounts the core routers plus any
    enterprise extension routers, and defines the health endpoints.
    """
    settings = get_settings()
    app = FastAPI(
        title=settings.app_name,
        version=settings.app_version,
        description=(
            "Multi-tenant cookie consent management platform API. "
            "Provides consent collection, cookie scanning, auto-blocking, "
            "compliance checking, and analytics across multiple sites."
        ),
        debug=settings.debug,
        lifespan=lifespan,
        openapi_tags=[
            {
                "name": "auth",
                "description": "Authentication — login, token refresh, and current user.",
            },
            {
                "name": "config",
                "description": (
                    "Site configuration — public endpoints for the banner script "
                    "to fetch config, GeoIP-resolved config, and CDN publishing."
                ),
            },
            {
                "name": "consent",
                "description": (
                    "Consent recording and retrieval — public endpoints called "
                    "by the banner script to record visitor consent decisions."
                ),
            },
            {
                "name": "sites",
                "description": "Site and site config CRUD — manage domains and settings.",
            },
            {
                "name": "cookies",
                "description": (
                    "Cookie management — categories, discovered cookies, allow-list, "
                    "known cookies database, and auto-classification."
                ),
            },
            {
                "name": "scanner",
                "description": (
                    "Cookie scanner — trigger scans, view results, and receive "
                    "client-side cookie reports from the banner script."
                ),
            },
            {
                "name": "compliance",
                "description": (
                    "Compliance checking — run checks against GDPR, CNIL, CCPA, "
                    "ePrivacy, and LGPD frameworks."
                ),
            },
            {
                "name": "organisations",
                "description": "Organisation management — multi-tenant root entities.",
            },
            {
                "name": "users",
                "description": "User management — org-scoped users with role-based access.",
            },
        ],
    )
    # Security headers
    app.add_middleware(SecurityHeadersMiddleware)
    # Rate limiting (must be added before CORS to count requests correctly)
    # NOTE(review): Starlette middleware runs in reverse add order, so
    # CORS (added last) wraps the rate limiter — preflight OPTIONS
    # requests answered by CORS are never counted. Confirm that is the
    # intent of the comment above.
    if settings.rate_limit_enabled:
        app.add_middleware(
            RateLimitMiddleware,
            redis_url=settings.redis_url,
            requests_per_minute=settings.rate_limit_per_minute,
            auth_requests_per_minute=10,
        )
    # CORS
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.allowed_origins_list,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    # Core routers — all mounted under the same versioned prefix.
    api_prefix = "/api/v1"
    app.include_router(auth.router, prefix=api_prefix)
    app.include_router(config.router, prefix=api_prefix)
    app.include_router(consent.router, prefix=api_prefix)
    app.include_router(scanner.router, prefix=api_prefix)
    app.include_router(compliance.router, prefix=api_prefix)
    app.include_router(organisations.router, prefix=api_prefix)
    app.include_router(org_config.router, prefix=api_prefix)
    app.include_router(users.router, prefix=api_prefix)
    app.include_router(site_groups.router, prefix=api_prefix)
    app.include_router(site_group_config.router, prefix=api_prefix)
    app.include_router(sites.router, prefix=api_prefix)
    app.include_router(cookies.router, prefix=api_prefix)
    app.include_router(translations.router, prefix=api_prefix)
    # translations exposes a second, unauthenticated router.
    app.include_router(translations.public_router, prefix=api_prefix)
    # Discover and mount enterprise extensions (no-op in CE mode)
    discover_extensions()
    registry = get_registry()
    registry.apply(app)

    @app.get("/health", tags=["health"])
    async def health() -> dict[str, str]:
        """Shallow liveness check.

        Answers "is the process running?". Suitable for orchestrator
        liveness probes. For deployment readiness, use
        ``/health/ready`` which verifies downstream dependencies.
        """
        return {"status": "ok", "edition": edition_name()}

    @app.get("/health/ready", tags=["health"])
    async def health_ready() -> dict[str, object]:
        """Deep readiness check — verifies database and Redis.

        Returns HTTP 503 if either dependency is unreachable so load
        balancers route traffic away from broken instances. Failures
        are reported by exception class name only, never by message,
        to avoid leaking connection details.
        """
        from fastapi import HTTPException
        from sqlalchemy import text
        from src.db.session import engine as db_engine

        checks: dict[str, str] = {}
        overall_ok = True
        # Database
        try:
            async with db_engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            checks["database"] = "ok"
        except Exception as exc:
            checks["database"] = f"error: {type(exc).__name__}"
            overall_ok = False
        # Redis
        try:
            import redis.asyncio as aioredis

            r = aioredis.from_url(settings.redis_url, decode_responses=True)
            pong = await r.ping()
            checks["redis"] = "ok" if pong else "error: ping failed"
            if not pong:
                overall_ok = False
            # NOTE(review): if ping() raises, this aclose() is skipped
            # and the client is abandoned to garbage collection —
            # consider a try/finally around the close.
            await r.aclose()
        except Exception as exc:
            checks["redis"] = f"error: {type(exc).__name__}"
            overall_ok = False
        payload = {
            "status": "ok" if overall_ok else "degraded",
            "edition": edition_name(),
            "checks": checks,
        }
        if not overall_ok:
            raise HTTPException(status_code=503, detail=payload)
        return payload

    return app


# Module-level ASGI application — what uvicorn/gunicorn imports.
app = create_app()

View File

View File

@@ -0,0 +1,111 @@
"""Redis-backed rate limiting middleware.
Applies per-IP rate limits to all incoming requests. Public endpoints
(consent recording, config fetching) are the primary protection target.
Uses a fixed per-minute window counter stored in Redis with automatic expiry.
"""
from __future__ import annotations
import logging
import time
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
logger = logging.getLogger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple per-IP rate limiter backed by Redis.

    Counts requests in fixed one-minute windows keyed on
    ``(bucket, client IP, minute)`` and fails open whenever Redis is
    unavailable or a Redis call raises.
    """

    def __init__(
        self,
        app: object,
        redis_url: str = "redis://localhost:6379/0",
        requests_per_minute: int = 120,
        auth_requests_per_minute: int = 10,
    ) -> None:
        super().__init__(app)  # type: ignore[arg-type]
        self.redis_url = redis_url
        self.requests_per_minute = requests_per_minute
        self.auth_requests_per_minute = auth_requests_per_minute
        # Lazily-created redis.asyncio client; see _get_redis().
        self._redis: object | None = None

    async def _get_redis(self) -> object | None:
        """Lazy-initialise Redis connection.

        NOTE(review): ``from_url`` only constructs a client — it does
        not connect — so this except branch effectively covers import
        or URL-parsing failures; actual connection errors surface later
        inside ``dispatch`` and are handled there.
        """
        if self._redis is not None:
            return self._redis
        try:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self.redis_url, decode_responses=True)
            return self._redis
        except Exception:
            logger.warning("Rate limiting disabled: Redis unavailable")
            return None

    def _get_client_ip(self, request: Request) -> str:
        """Extract the real client IP.

        NOTE(review): X-Forwarded-For and X-Real-IP are trusted as-is;
        a client that can reach the app directly can spoof them to
        rotate buckets. Safe only when the API is reachable solely via
        a proxy that overwrites these headers — confirm the topology.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # First entry in the chain is the originating client.
            return forwarded.split(",")[0].strip()
        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip()
        if request.client:
            return request.client.host
        return "unknown"

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Count the request against its bucket; return 429 when over the limit."""
        # Skip rate limiting for health checks
        if request.url.path in ("/health", "/health/ready", "/health/live"):
            return await call_next(request)
        r = await self._get_redis()
        if r is None:
            # Redis unavailable — allow request through
            return await call_next(request)
        # Auth endpoints get a stricter bucket to slow down credential
        # stuffing — login, password reset, token refresh.
        path = request.url.path
        is_auth = path.startswith("/api/v1/auth/") and path not in ("/api/v1/auth/me",)
        limit = self.auth_requests_per_minute if is_auth else self.requests_per_minute
        bucket = "auth" if is_auth else "req"
        client_ip = self._get_client_ip(request)
        # Fixed window: every request in the same wall-clock minute
        # shares one key.
        window = int(time.time() // 60)
        key = f"cmp:rate:{bucket}:{client_ip}:{window}"
        try:
            current = await r.incr(key)  # type: ignore[union-attr]
            if current == 1:
                # First hit in this window — set a TTL (2 minutes) so
                # stale window keys clean themselves up.
                await r.expire(key, 120)  # type: ignore[union-attr]
            if current > limit:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many requests. Please try again later."},
                    headers={
                        "Retry-After": "60",
                        "X-RateLimit-Limit": str(limit),
                        "X-RateLimit-Remaining": "0",
                    },
                )
            response = await call_next(request)
            remaining = max(0, limit - current)
            response.headers["X-RateLimit-Limit"] = str(limit)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            return response
        except Exception:
            # Fail open: a Redis hiccup must not take the API down.
            logger.debug("Rate limit check failed", exc_info=True)
            return await call_next(request)

View File

@@ -0,0 +1,41 @@
"""Security headers middleware.
Adds standard security headers to all API responses:
- X-Content-Type-Options: nosniff
- X-Frame-Options: DENY
- X-XSS-Protection: 0 (disabled in favour of CSP)
- Referrer-Policy: strict-origin-when-cross-origin
- Content-Security-Policy: default-src 'none'
- Strict-Transport-Security (HSTS) in production
"""
from __future__ import annotations
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a standard set of security headers to every response."""

    # Headers applied unconditionally, in order.
    _STATIC_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "0"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Content-Security-Policy", "default-src 'none'"),
    )

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        response = await call_next(request)
        for header, value in self._STATIC_HEADERS:
            response.headers[header] = value
        # HSTS — only on HTTPS requests (reverse proxy may terminate TLS)
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = (
                "max-age=63072000; includeSubDomains; preload"
            )
        return response

View File

@@ -0,0 +1,31 @@
from src.models.base import Base
from src.models.consent import ConsentRecord
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
from src.models.org_config import OrgConfig
from src.models.organisation import Organisation
from src.models.scan import ScanJob, ScanResult
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.models.site_group import SiteGroup
from src.models.site_group_config import SiteGroupConfig
from src.models.translation import Translation
from src.models.user import User
# Public re-export surface of the models package: importing it registers every
# mapped class on the shared metadata (used by Alembic / table creation).
__all__ = [
    "Base",
    "ConsentRecord",
    "Cookie",
    "CookieAllowListEntry",
    "CookieCategory",
    "KnownCookie",
    "OrgConfig",
    "Organisation",
    "ScanJob",
    "ScanResult",
    "Site",
    "SiteConfig",
    "SiteGroup",
    "SiteGroupConfig",
    "Translation",
    "User",
]

View File

@@ -0,0 +1,48 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, func
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
class Base(DeclarativeBase):
    """Declarative base shared by all ORM models; owns the common metadata."""
    pass
class TimestampMixin:
    """Mixin that adds created_at and updated_at columns.

    Both default to the database clock (``func.now()``), so timestamps are
    consistent even across multiple app servers. ``updated_at`` is refreshed
    on every ORM UPDATE via ``onupdate``.
    """
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
class UUIDPrimaryKeyMixin:
    """Mixin that adds a UUID primary key.

    The default is generated client-side (``uuid.uuid4``), so the id is
    available on the instance before flush.
    """
    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
class SoftDeleteMixin:
    """Mixin that adds soft delete support.

    A row is considered deleted when ``deleted_at`` is non-NULL; queries must
    filter on ``deleted_at.is_(None)`` explicitly (no automatic filtering).
    """
    deleted_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=None,
    )

View File

@@ -0,0 +1,81 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column
from src.models.base import Base, UUIDPrimaryKeyMixin
class ConsentRecord(UUIDPrimaryKeyMixin, Base):
    """Append-only audit trail of every consent event.

    Deliberately has no ``TimestampMixin``: rows are immutable, so the single
    ``consented_at`` timestamp suffices. NOTE(review): described elsewhere as
    "partitioned by month", but no partitioning is declared on this model —
    presumably set up in migrations; confirm before relying on it.
    """
    __tablename__ = "consent_records"
    __table_args__ = (
        # Composite index for the most common analytics query pattern:
        # "records for site X between dates A and B". Declared ascending;
        # a btree can be scanned backwards for "latest consents for site X"
        # without an extra sort.
        Index(
            "ix_consent_records_site_consented_at",
            "site_id",
            "consented_at",
        ),
    )
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Visitor identification (anonymous). The 64-char hash columns are
    # presumably hex SHA-256 digests — confirm against the hashing service.
    visitor_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    ip_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
    user_agent_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
    # Consent details: what the visitor did and which categories resulted.
    action: Mapped[str] = mapped_column(String(30), nullable=False)
    categories_accepted: Mapped[list] = mapped_column(JSONB, nullable=False)
    categories_rejected: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # TCF (IAB Transparency & Consent Framework) consent string, if enabled.
    tc_string: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Google Consent Mode state at time of consent.
    gcm_state: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # GPP (Global Privacy Platform) string, if enabled.
    gpp_string: Mapped[str | None] = mapped_column(Text, nullable=True)
    # GPC (Global Privacy Control): whether the signal was seen and acted on.
    gpc_detected: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_honoured: Mapped[bool | None] = mapped_column(nullable=True)
    # A/B testing — soft references to EE `ab_tests` / `ab_test_variants`
    # tables. Intentionally *no* FK constraint so the core schema works
    # without the EE extension installed.
    ab_test_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
        index=True,
    )
    ab_variant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )
    # Context captured at consent time.
    page_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    country_code: Mapped[str | None] = mapped_column(String(5), nullable=True)
    region_code: Mapped[str | None] = mapped_column(String(10), nullable=True)
    # Timestamp — database clock, indexed for the range queries above.
    consented_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
    )

View File

@@ -0,0 +1,130 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class CookieCategory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Cookie category taxonomy (necessary, functional, analytics, marketing, personalisation)."""
    __tablename__ = "cookie_categories"
    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    # Stable machine identifier; unique like ``name``.
    slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Presumably marks categories that cannot be rejected (strictly
    # necessary) — confirm against banner/consent logic.
    is_essential: Mapped[bool] = mapped_column(default=False, nullable=False)
    display_order: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    # TCF purpose mapping (list of purpose ids).
    tcf_purpose_ids: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # Google Consent Mode consent type mapping (list of consent type names).
    gcm_consent_types: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # Relationships
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="category")
    allow_list_entries: Mapped[list["CookieAllowListEntry"]] = relationship(
        back_populates="category"
    )
class Cookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie discovered on a site via scanning or client-side reporting.

    Uniqueness is per site on (name, domain, storage_type), so the same
    cookie name can exist for both ``cookie`` and web-storage entries.
    """
    __tablename__ = "cookies"
    __table_args__ = (
        UniqueConstraint(
            "site_id",
            "name",
            "domain",
            "storage_type",
            name="uq_cookies_site_name_domain_type",
        ),
    )
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # SET NULL on category deletion: the cookie survives, uncategorised.
    category_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain: Mapped[str] = mapped_column(String(255), nullable=False)
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    path: Mapped[str | None] = mapped_column(String(500), nullable=True)
    max_age_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Attribute flags are tri-state: NULL means "not observed".
    is_http_only: Mapped[bool | None] = mapped_column(nullable=True)
    is_secure: Mapped[bool | None] = mapped_column(nullable=True)
    same_site: Mapped[str | None] = mapped_column(String(10), nullable=True)
    review_status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
    # NOTE(review): first/last seen are stored as strings, not DateTime —
    # presumably ISO timestamps written by the scanner; a DateTime column
    # would allow range queries. Needs a migration to change.
    first_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)
    last_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookies")  # noqa: F821
    category: Mapped["CookieCategory | None"] = relationship(back_populates="cookies")
class CookieAllowListEntry(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Approved cookies per site with category assignment.

    Patterns (not exact names) so one entry can cover a cookie family.
    Category deletion is RESTRICTed while entries reference it.
    """
    __tablename__ = "cookie_allow_list"
    __table_args__ = (
        UniqueConstraint(
            "site_id",
            "name_pattern",
            "domain_pattern",
            name="uq_allow_list_site_name_domain",
        ),
    )
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookie_allow_list")  # noqa: F821
    category: Mapped["CookieCategory"] = relationship(back_populates="allow_list_entries")
class KnownCookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Shared knowledge base of known cookie patterns for auto-categorisation.

    Global (not per site). ``is_regex`` selects how the patterns are matched;
    when False they are presumably literal/glob matches — confirm in the
    categorisation service.
    """
    __tablename__ = "known_cookies"
    __table_args__ = (
        UniqueConstraint("name_pattern", "domain_pattern", name="uq_known_cookies_name_domain"),
    )
    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    is_regex: Mapped[bool] = mapped_column(default=False, nullable=False)

View File

@@ -0,0 +1,64 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class OrgConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Organisation-level default configuration.

    These defaults sit between system defaults and site config in the cascade:
    System Defaults → Org Config → Site Group Config → Site Config → Regional Overrides

    Every setting is nullable: NULL means "not set at this level, fall
    through to the next layer of the cascade".
    """
    __tablename__ = "org_configs"
    # One config row per organisation (unique FK = one-to-one).
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )
    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # TCF (IAB Transparency & Consent Framework)
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # ISO 3166-1 alpha-2 publisher country code.
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)
    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)
    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Consent
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Relationship
    organisation: Mapped["Organisation"] = relationship(back_populates="org_config")  # noqa: F821

View File

@@ -0,0 +1,26 @@
from sqlalchemy import String, Text
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class Organisation(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """Multi-tenant root entity. Each organisation has multiple sites and users.

    Soft-deletable; queries must filter ``deleted_at.is_(None)`` themselves.
    """
    __tablename__ = "organisations"
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Globally unique, URL-safe tenant identifier.
    slug: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
    contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
    billing_plan: Mapped[str] = mapped_column(String(50), server_default="free", nullable=False)
    notes: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Relationships
    users: Mapped[list["User"]] = relationship(back_populates="organisation")  # noqa: F821
    sites: Mapped[list["Site"]] = relationship(back_populates="organisation")  # noqa: F821
    site_groups: Mapped[list["SiteGroup"]] = relationship(  # noqa: F821
        back_populates="organisation"
    )
    # One-to-one with the org-level config layer of the cascade.
    org_config: Mapped["OrgConfig | None"] = relationship(  # noqa: F821
        back_populates="organisation", uselist=False
    )

View File

@@ -0,0 +1,68 @@
import uuid
from datetime import datetime
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class ScanJob(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie scanning job for a site.

    Tracks lifecycle (``status``), progress counters, and wall-clock
    start/end times; individual findings live in ``ScanResult``.
    """
    __tablename__ = "scan_jobs"
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Indexed for "pending jobs" style worker queries.
    status: Mapped[str] = mapped_column(
        String(20), server_default="pending", nullable=False, index=True
    )
    # What started the job (e.g. manual vs scheduled).
    trigger: Mapped[str] = mapped_column(String(20), server_default="manual", nullable=False)
    pages_scanned: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    pages_total: Mapped[int | None] = mapped_column(Integer, nullable=True)
    cookies_found: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)
    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="scan_jobs")  # noqa: F821
    results: Mapped[list["ScanResult"]] = relationship(back_populates="scan_job")
class ScanResult(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Individual result from a scan — a cookie found on a specific page."""
    __tablename__ = "scan_results"
    scan_job_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("scan_jobs.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    page_url: Mapped[str] = mapped_column(Text, nullable=False)
    cookie_name: Mapped[str] = mapped_column(String(255), nullable=False)
    cookie_domain: Mapped[str] = mapped_column(String(255), nullable=False)
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    # Raw cookie attributes as observed (free-form JSON).
    attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # URL of the script that set the cookie, when attributable.
    script_source: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Category suggested by the known-cookies knowledge base, if matched.
    auto_category: Mapped[str | None] = mapped_column(String(50), nullable=True)
    initiator_chain: Mapped[list[str] | None] = mapped_column(
        ARRAY(Text), nullable=True, comment="Ordered script URLs from root initiator to leaf"
    )
    found_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    # Relationships
    scan_job: Mapped["ScanJob"] = relationship(back_populates="results")

View File

@@ -0,0 +1,48 @@
import uuid
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import ARRAY, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class Site(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A domain being managed for cookie consent, belongs to an organisation.

    Soft-deletable. A domain is unique within its organisation, but the same
    domain may exist under different organisations.
    """
    __tablename__ = "sites"
    __table_args__ = (UniqueConstraint("organisation_id", "domain", name="uq_sites_org_domain"),)
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    domain: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    display_name: Mapped[str] = mapped_column(String(255), nullable=False)
    is_active: Mapped[bool] = mapped_column(default=True, nullable=False)
    # Extra domains served by the same site (aliases/locales).
    additional_domains: Mapped[list[str] | None] = mapped_column(
        ARRAY(String(255)), nullable=True, server_default=None
    )
    # SET NULL: deleting a group ungroups its sites rather than deleting them.
    site_group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )
    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="sites")  # noqa: F821
    site_group: Mapped["SiteGroup | None"] = relationship(back_populates="sites")  # noqa: F821
    config: Mapped["SiteConfig | None"] = relationship(  # noqa: F821
        back_populates="site", uselist=False
    )
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="site")  # noqa: F821
    cookie_allow_list: Mapped[list["CookieAllowListEntry"]] = relationship(  # noqa: F821
        back_populates="site"
    )
    scan_jobs: Mapped[list["ScanJob"]] = relationship(back_populates="site")  # noqa: F821
    translations: Mapped[list["Translation"]] = relationship(  # noqa: F821
        back_populates="site"
    )

View File

@@ -0,0 +1,63 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class SiteConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Full configuration for a site: blocking mode, TCF, GCM, banner, scanning, consent.

    Unlike the org/group layers, every setting here has a concrete value
    (non-null with defaults) — this is the most specific non-regional layer
    of the cascade. Note: no ``SoftDeleteMixin``, so there is no
    ``deleted_at`` column on this model.
    """
    __tablename__ = "site_configs"
    # One config row per site (unique FK = one-to-one).
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )
    # Blocking mode
    blocking_mode: Mapped[str] = mapped_column(String(20), server_default="opt_in", nullable=False)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # TCF (IAB Transparency & Consent Framework)
    tcf_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)
    # ISO 3166-1 alpha-2 publisher country code.
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)
    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Google Consent Mode
    gcm_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    display_mode: Mapped[str] = mapped_column(
        String(30), server_default="bottom_banner", nullable=False
    )
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int] = mapped_column(Integer, server_default="50", nullable=False)
    # Consent
    consent_expiry_days: Mapped[int] = mapped_column(Integer, server_default="365", nullable=False)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Relationship
    site: Mapped["Site"] = relationship(back_populates="config")  # noqa: F821

View File

@@ -0,0 +1,32 @@
import uuid
from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class SiteGroup(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A logical grouping of sites within an organisation (e.g. a brand).

    Soft-deletable. Group names are unique per organisation.
    """
    __tablename__ = "site_groups"
    __table_args__ = (UniqueConstraint("organisation_id", "name", name="uq_site_groups_org_name"),)
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Relationships
    organisation: Mapped["Organisation"] = relationship(  # noqa: F821
        back_populates="site_groups"
    )
    sites: Mapped[list["Site"]] = relationship(back_populates="site_group")  # noqa: F821
    # One-to-one with the group-level config layer of the cascade.
    group_config: Mapped["SiteGroupConfig | None"] = relationship(  # noqa: F821
        back_populates="site_group", uselist=False
    )

View File

@@ -0,0 +1,63 @@
import uuid
from sqlalchemy import ForeignKey, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class SiteGroupConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Site-group-level default configuration.

    These defaults sit between org defaults and site config in the cascade:
    System Defaults -> Org Config -> Site Group Config -> Site Config -> Regional Overrides

    Every setting is nullable: NULL means "not set at this level, fall
    through to the next layer of the cascade".
    """
    __tablename__ = "site_group_configs"
    # One config row per site group (unique FK = one-to-one).
    site_group_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )
    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # TCF (IAB Transparency & Consent Framework)
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # ISO 3166-1 alpha-2 publisher country code.
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)
    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)
    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Consent (no retention override at this level, unlike OrgConfig).
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # Relationship
    site_group: Mapped["SiteGroup"] = relationship(back_populates="group_config")  # noqa: F821

View File

@@ -0,0 +1,26 @@
import uuid
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.dialects.postgresql import JSONB, UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
class Translation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Internationalisation strings per site per locale.

    One row per (site, locale); the ``strings`` JSON holds the full
    key → translated-text mapping for that locale.
    """
    __tablename__ = "translations"
    __table_args__ = (UniqueConstraint("site_id", "locale", name="uq_translations_site_locale"),)
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # BCP 47 style tag, e.g. "en" or "en-GB" — presumably; confirm against
    # the locale-resolution code.
    locale: Mapped[str] = mapped_column(String(10), nullable=False)
    strings: Mapped[dict] = mapped_column(JSONB, nullable=False)
    # Relationships
    site: Mapped["Site"] = relationship(back_populates="translations")  # noqa: F821

View File

@@ -0,0 +1,31 @@
import uuid
from sqlalchemy import ForeignKey, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Mapped, mapped_column, relationship
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
class User(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """User account, scoped to an organisation with a role.

    Soft-deletable. Note that ``email`` is unique globally, not per
    organisation — one email cannot belong to two tenants.
    """
    __tablename__ = "users"
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
    full_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Role string used by RBAC checks; defaults to least privilege.
    role: Mapped[str] = mapped_column(
        String(20),
        nullable=False,
        server_default="viewer",
    )
    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="users")  # noqa: F821

View File

View File

@@ -0,0 +1,108 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from jose import JWTError
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import get_settings
from src.db import get_db
from src.models.user import User
from src.schemas.auth import CurrentUser, LoginRequest, RefreshRequest, TokenResponse
from src.services.auth import (
create_access_token,
create_refresh_token,
decode_token,
verify_password,
)
from src.services.dependencies import get_current_user
router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
    """Authenticate a user with email and password, return JWT tokens.

    Soft-deleted accounts are treated as non-existent; a missing user and a
    wrong password produce the same 401 so the response does not reveal
    which one failed.
    """
    query = select(User).where(User.email == body.email, User.deleted_at.is_(None))
    account = (await db.execute(query)).scalar_one_or_none()
    credentials_ok = account is not None and verify_password(
        body.password, account.password_hash
    )
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )
    settings = get_settings()
    return TokenResponse(
        access_token=create_access_token(
            user_id=account.id,
            organisation_id=account.organisation_id,
            role=account.role,
            email=account.email,
        ),
        refresh_token=create_refresh_token(
            user_id=account.id,
            organisation_id=account.organisation_id,
        ),
        expires_in=settings.jwt_access_token_expire_minutes * 60,
    )
@router.post("/refresh", response_model=TokenResponse)
async def refresh(
    body: RefreshRequest,
    db: AsyncSession = Depends(get_db),
) -> TokenResponse:
    """Exchange a valid refresh token for a new access/refresh token pair.

    Returns 401 for: signature/expiry failures, a non-refresh token type,
    a missing or malformed ``sub`` claim, or a user that no longer exists
    (including soft-deleted accounts).
    """
    try:
        payload = decode_token(body.refresh_token)
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        ) from exc
    if payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token is not a refresh token",
        )
    # A token with a valid signature can still carry a missing or malformed
    # "sub" claim; that must be a 401, not an unhandled KeyError/ValueError
    # surfacing as a 500.
    try:
        user_id = uuid.UUID(payload["sub"])
    except (KeyError, TypeError, ValueError) as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        ) from exc
    result = await db.execute(select(User).where(User.id == user_id, User.deleted_at.is_(None)))
    user = result.scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User no longer exists",
        )
    settings = get_settings()
    access_token = create_access_token(
        user_id=user.id,
        organisation_id=user.organisation_id,
        role=user.role,
        email=user.email,
    )
    new_refresh_token = create_refresh_token(
        user_id=user.id,
        organisation_id=user.organisation_id,
    )
    return TokenResponse(
        access_token=access_token,
        refresh_token=new_refresh_token,
        expires_in=settings.jwt_access_token_expire_minutes * 60,
    )
@router.get("/me", response_model=CurrentUser)
async def get_me(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    """Return the currently authenticated user's profile from the JWT.

    The ``CurrentUser`` value is produced entirely by the
    ``get_current_user`` dependency; this endpoint just echoes it back.
    """
    return current_user

View File

@@ -0,0 +1,135 @@
"""Compliance checking endpoints.
Evaluates a site's configuration against regulatory frameworks (GDPR, CNIL,
CCPA, ePrivacy, LGPD) and returns per-framework compliance reports with scores,
issues, and recommendations.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.cookie import Cookie
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.schemas.compliance import (
ComplianceCheckRequest,
ComplianceCheckResponse,
Framework,
)
from src.services.compliance import (
SiteContext,
calculate_overall_score,
run_compliance_check,
)
from src.services.dependencies import get_current_user
router = APIRouter(prefix="/compliance", tags=["compliance"])
async def _build_site_context(
    site_id: uuid.UUID,
    db: AsyncSession,
) -> SiteContext:
    """Load site config and cookie stats to build a SiteContext.

    Falls back to a minimal context (cookie counts only) when the site has
    no ``SiteConfig`` row yet.
    """
    # Fetch site config. NOTE: SiteConfig does not use SoftDeleteMixin and
    # therefore has no ``deleted_at`` column — the previous filter on
    # ``SiteConfig.deleted_at`` raised AttributeError at request time.
    result = await db.execute(
        select(SiteConfig).where(SiteConfig.site_id == site_id)
    )
    config = result.scalar_one_or_none()
    # Cookie statistics: total discovered cookies and how many still lack a
    # category assignment (both feed the compliance rules).
    total_q = await db.execute(
        select(func.count()).select_from(Cookie).where(Cookie.site_id == site_id)
    )
    total_cookies = total_q.scalar() or 0
    uncat_q = await db.execute(
        select(func.count())
        .select_from(Cookie)
        .where(
            Cookie.site_id == site_id,
            Cookie.category_id.is_(None),
        )
    )
    uncategorised_cookies = uncat_q.scalar() or 0
    if config is None:
        return SiteContext(
            total_cookies=total_cookies,
            uncategorised_cookies=uncategorised_cookies,
        )
    # Normalised copy only for flag lookups; the raw (possibly None) value
    # is passed through so callers can distinguish "not configured".
    banner_config = config.banner_config or {}
    return SiteContext(
        blocking_mode=config.blocking_mode,
        regional_modes=config.regional_modes,
        tcf_enabled=config.tcf_enabled,
        gcm_enabled=config.gcm_enabled,
        consent_expiry_days=config.consent_expiry_days,
        privacy_policy_url=config.privacy_policy_url,
        display_mode=config.display_mode,
        banner_config=config.banner_config,
        total_cookies=total_cookies,
        uncategorised_cookies=uncategorised_cookies,
        has_reject_button=banner_config.get("show_reject_all", True),
        has_granular_choices=banner_config.get("show_category_toggles", True),
        has_cookie_wall=banner_config.get("cookie_wall", False),
        pre_ticked_boxes=banner_config.get("pre_ticked", False),
    )
@router.post(
    "/check/{site_id}",
    response_model=ComplianceCheckResponse,
    status_code=status.HTTP_200_OK,
)
async def check_compliance(
    site_id: uuid.UUID,
    body: ComplianceCheckRequest | None = None,
    db: AsyncSession = Depends(get_db),
    _user=Depends(get_current_user),
) -> ComplianceCheckResponse:
    """Run compliance checks against a site's configuration.

    404s when the site is missing or soft-deleted; otherwise evaluates the
    requested frameworks (all of them when no body is supplied).
    """
    # Confirm the site exists before doing any heavier work.
    site = (
        await db.execute(
            select(Site).where(Site.id == site_id, Site.deleted_at.is_(None))
        )
    ).scalar_one_or_none()
    if site is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    ctx = await _build_site_context(site_id, db)
    requested_frameworks = body.frameworks if body else None
    framework_results = run_compliance_check(ctx, requested_frameworks)
    return ComplianceCheckResponse(
        site_id=str(site_id),
        results=framework_results,
        overall_score=calculate_overall_score(framework_results),
    )
@router.get("/frameworks", response_model=list[dict])
async def list_frameworks() -> list[dict]:
    """List all available compliance frameworks."""
    # Insertion order of this mapping defines the response order.
    descriptions = {
        Framework.GDPR: "EU General Data Protection Regulation",
        Framework.CNIL: "French Data Protection Authority (stricter GDPR)",
        Framework.CCPA: "California Consumer Privacy Act / CPRA",
        Framework.EPRIVACY: "EU ePrivacy Directive",
        Framework.LGPD: "Brazilian General Data Protection Law",
    }
    return [
        {"id": fw.value, "name": fw.value.upper(), "description": text}
        for fw, text in descriptions.items()
    ]

View File

@@ -0,0 +1,324 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.extensions.registry import get_registry
from src.models.org_config import OrgConfig
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.models.site_group_config import SiteGroupConfig
from src.schemas.auth import CurrentUser
from src.schemas.site import SiteConfigResponse
from src.services.config_resolver import (
CONFIG_FIELDS,
build_public_config,
orm_to_config_dict,
resolve_config,
)
from src.services.dependencies import require_role
from src.services.geoip import detect_region
from src.services.publisher import publish_site_config
router = APIRouter(prefix="/config", tags=["config"])
@router.get("/sites/{site_id}", response_model=SiteConfigResponse)
async def get_public_site_config(
    site_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Public endpoint: retrieve site config for the banner script. No auth required."""
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    site_config = (await db.execute(stmt)).scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    return site_config
@router.get("/sites/{site_id}/resolved")
async def get_resolved_config(
    site_id: uuid.UUID,
    region: str | None = None,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: retrieve fully resolved config with regional overrides applied.
    Applies the full cascade: System → Org → Group → Site → Regional.
    """
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    site_config = (await db.execute(stmt)).scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    base = orm_to_config_dict(site_config, include_id=True)
    # Walk up the cascade: org defaults, then site-group defaults.
    org_id = await _get_site_org_id(site_id, db)
    group_id = await _get_site_group_id(site_id, db)
    merged = resolve_config(
        base,
        org_defaults=(await _load_org_defaults(org_id, db)) if org_id else None,
        group_defaults=(await _load_group_defaults(group_id, db)) if group_id else None,
        region=region,
    )
    return build_public_config(str(site_id), merged)
@router.get("/sites/{site_id}/geo-resolved")
async def get_geo_resolved_config(
    site_id: uuid.UUID,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: resolve config using the visitor's detected region.
    Region detection uses CDN headers or IP geolocation; the matching
    regional blocking-mode overrides are then applied on top of the full
    System → Org → Group → Site cascade.
    """
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    site_config = (await db.execute(stmt)).scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    # Only detect the region once we know the site is real.
    geo = await detect_region(request)
    base = orm_to_config_dict(site_config, include_id=True)
    org_id = await _get_site_org_id(site_id, db)
    group_id = await _get_site_group_id(site_id, db)
    merged = resolve_config(
        base,
        org_defaults=(await _load_org_defaults(org_id, db)) if org_id else None,
        group_defaults=(await _load_group_defaults(group_id, db)) if group_id else None,
        region=geo.region,
    )
    payload = build_public_config(str(site_id), merged)
    # Surface the detection result so the banner can use it directly.
    payload["detected_country"] = geo.country_code
    payload["detected_region"] = geo.region
    return payload
@router.get("/geo")
async def get_visitor_geo(request: Request) -> dict:
    """Public endpoint: return the detected region for the current visitor.
    Lets banner scripts learn the region up front, before requesting the
    full configuration payload.
    """
    detected = await detect_region(request)
    return {"country_code": detected.country_code, "region": detected.region}
@router.get("/sites/{site_id}/inheritance")
async def get_config_inheritance(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return the full config inheritance chain for a site.

    For every config field, reports the resolved value, the value present
    at each cascade level (site, group, org, system), and which level the
    resolved value came from, so the UI can show where a setting originates.

    Raises 404 if the site config is missing or the site belongs to a
    different organisation.
    """
    # Imported lazily rather than at module top — presumably to avoid an
    # import cycle with config_resolver; confirm before hoisting.
    from src.services.config_resolver import SYSTEM_DEFAULTS
    result = await db.execute(
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    site_dict = orm_to_config_dict(config)
    org_defaults = await _load_org_defaults(current_user.organisation_id, db)
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None
    resolved = resolve_config(
        site_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
    )
    # For each config field, determine the source: the highest-priority
    # level holding a non-None value wins, and "system" is the
    # unconditional fallback (no separate system_val check is needed).
    sources: dict[str, dict] = {}
    for field in CONFIG_FIELDS:
        site_val = site_dict.get(field)
        group_val = group_defaults.get(field) if group_defaults else None
        org_val = org_defaults.get(field) if org_defaults else None
        system_val = SYSTEM_DEFAULTS.get(field)
        if site_val is not None:
            source = "site"
        elif group_val is not None:
            source = "group"
        elif org_val is not None:
            source = "org"
        else:
            source = "system"
        sources[field] = {
            "resolved_value": resolved.get(field),
            "source": source,
            "site_value": site_val,
            "group_value": group_val,
            "org_value": org_val,
            "system_value": system_val,
        }
    return {
        "site_id": str(site_id),
        "site_group_id": str(group_id) if group_id else None,
        "fields": sources,
    }
@router.post("/sites/{site_id}/publish")
async def publish_config(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Publish fully-resolved site config to CDN. Requires admin role."""
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    site_config = (await db.execute(stmt)).scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    base = orm_to_config_dict(site_config, include_id=True)
    org_defaults = await _load_org_defaults(current_user.organisation_id, db)
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None
    resolved = resolve_config(
        base,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
    )
    # Registered extensions may enrich the payload in place (e.g. A/B test data).
    for enricher in get_registry().config_enrichers:
        await enricher(site_id, db, resolved)
    outcome = await publish_site_config(str(site_id), resolved)
    if not outcome.success:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Publish failed: {outcome.error}",
        )
    return {
        "published": True,
        "path": outcome.path,
        "published_at": outcome.published_at,
    }
# ── Helpers ──────────────────────────────────────────────────────────
async def _get_site_org_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Return the owning organisation_id for a site, or None if the site is unknown."""
    row = await db.execute(select(Site.organisation_id).where(Site.id == site_id))
    return row.scalar_one_or_none()
async def _get_site_group_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Return the site_group_id a site belongs to, or None if ungrouped/unknown."""
    row = await db.execute(select(Site.site_group_id).where(Site.id == site_id))
    return row.scalar_one_or_none()
async def _load_org_defaults(organisation_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Return the org-level config defaults as a dict, or None when unset."""
    row = await db.execute(
        select(OrgConfig).where(OrgConfig.organisation_id == organisation_id)
    )
    org_config = row.scalar_one_or_none()
    return None if org_config is None else orm_to_config_dict(org_config)
async def _load_group_defaults(group_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Return the site-group-level config defaults as a dict, or None when unset."""
    row = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    group_config = row.scalar_one_or_none()
    return None if group_config is None else orm_to_config_dict(group_config)

View File

@@ -0,0 +1,125 @@
import uuid
from fastapi import APIRouter, Depends, HTTPException, Request, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.extensions.registry import get_registry
from src.models.consent import ConsentRecord
from src.models.site import Site
from src.schemas.auth import CurrentUser
from src.schemas.consent import (
ConsentRecordCreate,
ConsentRecordResponse,
ConsentVerifyResponse,
)
from src.services.dependencies import require_role
from src.services.pseudonymisation import pseudonymise
router = APIRouter(prefix="/consent", tags=["consent"])
@router.post("/", response_model=ConsentRecordResponse, status_code=status.HTTP_201_CREATED)
async def record_consent(
    body: ConsentRecordCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Record a consent event from the banner. Public endpoint (no auth required)."""
    # HMAC-pseudonymise IP and user agent so the stored values cannot be
    # reversed without the server-side secret.
    source_ip = request.client.host if request.client else ""
    agent = request.headers.get("user-agent", "")
    record = ConsentRecord(
        site_id=body.site_id,
        visitor_id=body.visitor_id,
        ip_hash=pseudonymise(source_ip),
        user_agent_hash=pseudonymise(agent),
        action=body.action,
        categories_accepted=body.categories_accepted,
        categories_rejected=body.categories_rejected,
        tc_string=body.tc_string,
        gcm_state=body.gcm_state,
        page_url=body.page_url,
        country_code=body.country_code,
        region_code=body.region_code,
    )
    db.add(record)
    await db.flush()
    await db.refresh(record)
    # Fire post-record hooks registered by extensions (EE consent receipts, etc.).
    for hook in get_registry().consent_record_hooks:
        await hook(db, record)
    return record
async def _load_record_for_org(
    consent_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> ConsentRecord:
    """Load a consent record, enforcing tenant isolation.

    The record's site must belong to the caller's organisation. Records
    owned by another tenant yield 404 rather than 403, so existence is
    never leaked across tenants.
    """
    record = (
        await db.execute(
            select(ConsentRecord)
            .join(Site, Site.id == ConsentRecord.site_id)
            .where(
                ConsentRecord.id == consent_id,
                Site.organisation_id == current_user.organisation_id,
                Site.deleted_at.is_(None),
            )
        )
    ).scalar_one_or_none()
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Consent record not found",
        )
    return record
@router.get("/{consent_id}", response_model=ConsentRecordResponse)
async def get_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Retrieve a consent record by ID.

    Authenticated and tenant-scoped: consent records hold PII-adjacent
    data (hashed IP, page URL, category decisions), so possession of a
    record UUID alone must not grant read access.
    """
    return await _load_record_for_org(consent_id, current_user, db)
@router.get("/verify/{consent_id}", response_model=ConsentVerifyResponse)
async def verify_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Verify that a consent record exists (audit proof).

    Tenant-scoped exactly like :func:`get_consent` — proof of consent is
    only meaningful to the organisation owning the site, and exposing
    record existence to arbitrary callers would enable enumeration.
    """
    record = await _load_record_for_org(consent_id, current_user, db)
    return {
        "id": record.id,
        "site_id": record.site_id,
        "visitor_id": record.visitor_id,
        "action": record.action,
        "categories_accepted": record.categories_accepted,
        "consented_at": record.consented_at,
        "valid": True,
    }

View File

@@ -0,0 +1,582 @@
"""Cookie category, cookie, and allow-list management endpoints."""
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
from src.models.site import Site
from src.schemas.auth import CurrentUser
from src.schemas.cookie import (
AllowListEntryCreate,
AllowListEntryResponse,
AllowListEntryUpdate,
ClassificationResultResponse,
ClassifySingleRequest,
ClassifySiteResponse,
CookieCategoryResponse,
CookieCreate,
CookieResponse,
CookieUpdate,
KnownCookieCreate,
KnownCookieResponse,
KnownCookieUpdate,
ReviewStatus,
)
from src.services.classification import classify_single_cookie, classify_site_cookies
from src.services.dependencies import get_current_user, require_role
router = APIRouter(prefix="/cookies", tags=["cookies"])
# ── Cookie categories (read-only, seeded by migration) ──────────────
@router.get("/categories", response_model=list[CookieCategoryResponse])
async def list_categories(
    db: AsyncSession = Depends(get_db),
) -> list[CookieCategory]:
    """List all cookie categories. Public endpoint used by banner and admin."""
    stmt = select(CookieCategory).order_by(CookieCategory.display_order)
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
@router.get("/categories/{category_id}", response_model=CookieCategoryResponse)
async def get_category(
    category_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> CookieCategory:
    """Get a single cookie category by ID."""
    category = (
        await db.execute(select(CookieCategory).where(CookieCategory.id == category_id))
    ).scalar_one_or_none()
    if category is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category not found")
    return category
# ── Cookies per site ─────────────────────────────────────────────────
async def _get_org_site(
    site_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Fetch a site, ensuring it belongs to the caller's organisation.

    Raises 404 when the site is missing, soft-deleted, or owned by a
    different tenant.
    """
    site = (
        await db.execute(
            select(Site).where(
                Site.id == site_id,
                Site.organisation_id == current_user.organisation_id,
                Site.deleted_at.is_(None),
            )
        )
    ).scalar_one_or_none()
    if site is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site
@router.get(
    "/sites/{site_id}",
    response_model=list[CookieResponse],
)
async def list_cookies(
    site_id: uuid.UUID,
    review_status: ReviewStatus | None = Query(None),
    category_id: uuid.UUID | None = Query(None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[Cookie]:
    """List cookies discovered on a site, with optional filters."""
    await _get_org_site(site_id, current_user, db)
    stmt = select(Cookie).where(Cookie.site_id == site_id)
    if review_status:
        stmt = stmt.where(Cookie.review_status == review_status.value)
    if category_id:
        stmt = stmt.where(Cookie.category_id == category_id)
    rows = await db.execute(stmt.order_by(Cookie.name))
    return list(rows.scalars().all())
@router.post(
    "/sites/{site_id}",
    response_model=CookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_cookie(
    site_id: uuid.UUID,
    body: CookieCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Create a cookie record for a site (manual entry or from scanner).

    Raises 404 if the site is not in the caller's organisation and 400 if
    the supplied category_id does not reference an existing category.
    """
    await _get_org_site(site_id, current_user, db)
    # Validate category if provided
    if body.category_id:
        cat = await db.execute(select(CookieCategory).where(CookieCategory.id == body.category_id))
        if not cat.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    # Capture the timestamp once so first_seen_at and last_seen_at are
    # identical on creation — two now() calls can differ by microseconds.
    now_iso = datetime.now(UTC).isoformat()
    cookie = Cookie(
        site_id=site_id,
        **body.model_dump(),
        first_seen_at=now_iso,
        last_seen_at=now_iso,
    )
    db.add(cookie)
    await db.flush()
    await db.refresh(cookie)
    return cookie
@router.get("/sites/{site_id}/summary")
async def cookie_summary(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Get a summary of cookies for a site (counts by status and category)."""
    await _get_org_site(site_id, current_user, db)
    # Breakdown by review status
    status_rows = await db.execute(
        select(Cookie.review_status, func.count(Cookie.id))
        .where(Cookie.site_id == site_id)
        .group_by(Cookie.review_status)
    )
    by_status = {review_status: count for review_status, count in status_rows.all()}
    # Breakdown by category slug
    category_rows = await db.execute(
        select(CookieCategory.slug, func.count(Cookie.id))
        .outerjoin(Cookie, Cookie.category_id == CookieCategory.id)
        .where(Cookie.site_id == site_id)
        .group_by(CookieCategory.slug)
    )
    by_category = {slug: count for slug, count in category_rows.all()}
    # Cookies with no category assigned yet
    uncategorised = (
        await db.execute(
            select(func.count(Cookie.id)).where(
                Cookie.site_id == site_id, Cookie.category_id.is_(None)
            )
        )
    ).scalar() or 0
    return {
        "total": sum(by_status.values()),
        "by_status": by_status,
        "by_category": by_category,
        "uncategorised": uncategorised,
    }
# ── Allow-list per site ──────────────────────────────────────────────
# (Must be defined before {cookie_id} routes to avoid path conflicts)
@router.get(
    "/sites/{site_id}/allow-list",
    response_model=list[AllowListEntryResponse],
)
async def list_allow_list(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[CookieAllowListEntry]:
    """List all allow-list entries for a site."""
    await _get_org_site(site_id, current_user, db)
    stmt = (
        select(CookieAllowListEntry)
        .where(CookieAllowListEntry.site_id == site_id)
        .order_by(CookieAllowListEntry.name_pattern)
    )
    return list((await db.execute(stmt)).scalars().all())
@router.post(
    "/sites/{site_id}/allow-list",
    response_model=AllowListEntryResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_allow_list_entry(
    site_id: uuid.UUID,
    body: AllowListEntryCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Add a cookie pattern to the allow-list for a site."""
    await _get_org_site(site_id, current_user, db)
    # The target category must exist.
    category_row = await db.execute(
        select(CookieCategory).where(CookieCategory.id == body.category_id)
    )
    if category_row.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )
    entry = CookieAllowListEntry(site_id=site_id, **body.model_dump())
    db.add(entry)
    await db.flush()
    await db.refresh(entry)
    return entry
@router.patch(
    "/sites/{site_id}/allow-list/{entry_id}",
    response_model=AllowListEntryResponse,
)
async def update_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    body: AllowListEntryUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Update an allow-list entry."""
    await _get_org_site(site_id, current_user, db)
    entry = (
        await db.execute(
            select(CookieAllowListEntry).where(
                CookieAllowListEntry.id == entry_id,
                CookieAllowListEntry.site_id == site_id,
            )
        )
    ).scalar_one_or_none()
    if entry is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )
    changes = body.model_dump(exclude_unset=True)
    # A changed category must reference an existing category row.
    if changes.get("category_id") is not None:
        category_row = await db.execute(
            select(CookieCategory).where(CookieCategory.id == changes["category_id"])
        )
        if category_row.scalar_one_or_none() is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    for attr, new_value in changes.items():
        setattr(entry, attr, new_value)
    entry.updated_at = datetime.now(UTC)
    await db.flush()
    await db.refresh(entry)
    return entry
@router.delete(
    "/sites/{site_id}/allow-list/{entry_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove an entry from the allow-list."""
    await _get_org_site(site_id, current_user, db)
    entry = (
        await db.execute(
            select(CookieAllowListEntry).where(
                CookieAllowListEntry.id == entry_id,
                CookieAllowListEntry.site_id == site_id,
            )
        )
    ).scalar_one_or_none()
    if entry is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )
    await db.delete(entry)
# ── Individual cookie by ID (must come after /summary and /allow-list) ──
@router.get("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
async def get_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Get a single cookie by ID."""
    await _get_org_site(site_id, current_user, db)
    cookie = (
        await db.execute(
            select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
        )
    ).scalar_one_or_none()
    if cookie is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
    return cookie
@router.patch("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
async def update_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    body: CookieUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Update a cookie record (e.g. assign category, change review status)."""
    await _get_org_site(site_id, current_user, db)
    cookie = (
        await db.execute(
            select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
        )
    ).scalar_one_or_none()
    if cookie is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
    changes = body.model_dump(exclude_unset=True)
    # A changed category must reference an existing category row.
    if changes.get("category_id") is not None:
        category_row = await db.execute(
            select(CookieCategory).where(CookieCategory.id == changes["category_id"])
        )
        if category_row.scalar_one_or_none() is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    for attr, new_value in changes.items():
        setattr(cookie, attr, new_value)
    cookie.updated_at = datetime.now(UTC)
    await db.flush()
    await db.refresh(cookie)
    return cookie
@router.delete(
    "/sites/{site_id}/{cookie_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a cookie record."""
    await _get_org_site(site_id, current_user, db)
    cookie = (
        await db.execute(
            select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
        )
    ).scalar_one_or_none()
    if cookie is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
    await db.delete(cookie)
# ── Known cookies database ──────────────────────────────────────────
@router.get("/known", response_model=list[KnownCookieResponse])
async def list_known_cookies(
    vendor: str | None = Query(None, description="Filter by vendor name"),
    search: str | None = Query(None, description="Search by name pattern"),
    db: AsyncSession = Depends(get_db),
    _user: CurrentUser = Depends(get_current_user),
) -> list[KnownCookie]:
    """List known cookie patterns from the shared database."""
    stmt = select(KnownCookie).order_by(KnownCookie.name_pattern)
    if vendor:
        stmt = stmt.where(KnownCookie.vendor == vendor)
    if search:
        # Case-insensitive substring match on the pattern name.
        stmt = stmt.where(KnownCookie.name_pattern.ilike(f"%{search}%"))
    return list((await db.execute(stmt)).scalars().all())
@router.post(
    "/known",
    response_model=KnownCookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_known_cookie(
    body: KnownCookieCreate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Add a new pattern to the known cookies database."""
    # The referenced category must exist.
    category_row = await db.execute(
        select(CookieCategory).where(CookieCategory.id == body.category_id)
    )
    if category_row.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )
    known = KnownCookie(**body.model_dump())
    db.add(known)
    await db.flush()
    await db.refresh(known)
    return known
@router.get("/known/{known_id}", response_model=KnownCookieResponse)
async def get_known_cookie(
    known_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    _user: CurrentUser = Depends(get_current_user),
) -> KnownCookie:
    """Get a single known cookie pattern by ID."""
    known = (
        await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    ).scalar_one_or_none()
    if known is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    return known
@router.patch("/known/{known_id}", response_model=KnownCookieResponse)
async def update_known_cookie(
    known_id: uuid.UUID,
    body: KnownCookieUpdate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Update a known cookie pattern."""
    known = (
        await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    ).scalar_one_or_none()
    if known is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    changes = body.model_dump(exclude_unset=True)
    # A changed category must reference an existing category row.
    if changes.get("category_id") is not None:
        category_row = await db.execute(
            select(CookieCategory).where(CookieCategory.id == changes["category_id"])
        )
        if category_row.scalar_one_or_none() is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )
    for attr, new_value in changes.items():
        setattr(known, attr, new_value)
    known.updated_at = datetime.now(UTC)
    await db.flush()
    await db.refresh(known)
    return known
@router.delete(
    "/known/{known_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_known_cookie(
    known_id: uuid.UUID,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a known cookie pattern."""
    known = (
        await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
    ).scalar_one_or_none()
    if known is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    await db.delete(known)
# ── Classification endpoints ────────────────────────────────────────
@router.post(
    "/sites/{site_id}/classify",
    response_model=ClassifySiteResponse,
)
async def classify_cookies(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> ClassifySiteResponse:
    """Auto-classify pending cookies for a site against known patterns."""
    await _get_org_site(site_id, current_user, db)
    outcomes = await classify_site_cookies(db, site_id, only_pending=True)
    matched_total = sum(1 for o in outcomes if o.matched)
    response_items = [
        ClassificationResultResponse(
            cookie_name=o.cookie_name,
            cookie_domain=o.cookie_domain,
            category_id=o.category_id,
            category_slug=o.category_slug,
            vendor=o.vendor,
            description=o.description,
            match_source=o.match_source,
            matched=o.matched,
        )
        for o in outcomes
    ]
    return ClassifySiteResponse(
        site_id=str(site_id),
        total=len(outcomes),
        matched=matched_total,
        unmatched=len(outcomes) - matched_total,
        results=response_items,
    )
@router.post(
    "/sites/{site_id}/classify/preview",
    response_model=ClassificationResultResponse,
)
async def classify_preview(
    site_id: uuid.UUID,
    body: ClassifySingleRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> ClassificationResultResponse:
    """Preview classification for a single cookie without saving."""
    await _get_org_site(site_id, current_user, db)
    outcome = await classify_single_cookie(db, site_id, body.cookie_name, body.cookie_domain)
    return ClassificationResultResponse(
        cookie_name=outcome.cookie_name,
        cookie_domain=outcome.cookie_domain,
        category_id=outcome.category_id,
        category_slug=outcome.category_slug,
        vendor=outcome.vendor,
        description=outcome.description,
        match_source=outcome.match_source,
        matched=outcome.matched,
    )

View File

@@ -0,0 +1,69 @@
"""Organisation-level default configuration endpoints.
Provides GET and PUT for the organisation's global config defaults.
These defaults sit between system defaults and site config in the cascade.
"""
from fastapi import APIRouter, Depends
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.org_config import OrgConfig
from src.schemas.auth import CurrentUser
from src.schemas.org_config import OrgConfigResponse, OrgConfigUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/org-config", tags=["organisations"])
@router.get("/", response_model=OrgConfigResponse)
async def get_org_config(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> OrgConfig:
    """Retrieve the organisation's global configuration defaults."""
    config = (
        await db.execute(
            select(OrgConfig).where(OrgConfig.organisation_id == current_user.organisation_id)
        )
    ).scalar_one_or_none()
    if config is None:
        # No row yet: lazily create an empty one so the response shape is
        # always valid.
        config = OrgConfig(organisation_id=current_user.organisation_id)
        db.add(config)
        await db.flush()
        await db.refresh(config)
    return config
@router.put("/", response_model=OrgConfigResponse)
async def update_org_config(
    body: OrgConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> OrgConfig:
    """Create or update the organisation's global configuration defaults.
    Only non-None fields will override system defaults when resolving site config.
    """
    config = (
        await db.execute(
            select(OrgConfig).where(OrgConfig.organisation_id == current_user.organisation_id)
        )
    ).scalar_one_or_none()
    changes = body.model_dump(exclude_unset=True)
    if config is None:
        # Upsert: the first write creates the row.
        config = OrgConfig(organisation_id=current_user.organisation_id, **changes)
        db.add(config)
    else:
        for attr, new_value in changes.items():
            setattr(config, attr, new_value)
    await db.flush()
    await db.refresh(config)
    return config

View File

@@ -0,0 +1,118 @@
import hmac
from fastapi import APIRouter, Depends, Header, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import get_settings
from src.db import get_db
from src.models.organisation import Organisation
from src.schemas.auth import CurrentUser
from src.schemas.organisation import (
OrganisationCreate,
OrganisationResponse,
OrganisationUpdate,
)
from src.services.dependencies import require_role
router = APIRouter(prefix="/organisations", tags=["organisations"])
def _require_bootstrap_token(
    x_admin_bootstrap_token: str | None = Header(default=None),
) -> None:
    """Gate organisation creation behind a static bootstrap token.

    The expected value comes from ``ADMIN_BOOTSTRAP_TOKEN``. While it is
    unset (the default) the endpoint stays disabled; operators opt in
    explicitly and should rotate or clear the token once their first
    organisation has been provisioned.
    """
    configured = get_settings().admin_bootstrap_token
    if not configured:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                "Organisation creation is disabled. Set ADMIN_BOOTSTRAP_TOKEN "
                "in the environment to enable it."
            ),
        )
    supplied = x_admin_bootstrap_token
    # Constant-time comparison so the token cannot be guessed byte-by-byte.
    token_ok = bool(supplied) and hmac.compare_digest(supplied, configured)
    if not token_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing admin bootstrap token",
        )
@router.post("/", response_model=OrganisationResponse, status_code=status.HTTP_201_CREATED)
async def create_organisation(
    body: OrganisationCreate,
    db: AsyncSession = Depends(get_db),
    _: None = Depends(_require_bootstrap_token),
) -> Organisation:
    """Create a new organisation. Gated by ``X-Admin-Bootstrap-Token``.

    See :func:`_require_bootstrap_token` for the gating semantics. Once
    your initial organisation exists, rotate or unset
    ``ADMIN_BOOTSTRAP_TOKEN`` to disable further tenant creation.
    """
    # Slugs are unique across all tenants — reject duplicates with a 409.
    slug_lookup = await db.execute(select(Organisation).where(Organisation.slug == body.slug))
    if slug_lookup.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Organisation with slug '{body.slug}' already exists",
        )
    organisation = Organisation(**body.model_dump())
    db.add(organisation)
    await db.flush()
    await db.refresh(organisation)
    return organisation
@router.get("/me", response_model=OrganisationResponse)
async def get_my_organisation(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Organisation:
    """Return the organisation the authenticated user belongs to."""
    lookup = await db.execute(
        select(Organisation).where(
            Organisation.id == current_user.organisation_id,
            Organisation.deleted_at.is_(None),
        )
    )
    organisation = lookup.scalar_one_or_none()
    if organisation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")
    return organisation
@router.patch("/me", response_model=OrganisationResponse)
async def update_my_organisation(
    body: OrganisationUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> Organisation:
    """Apply a partial update to the caller's organisation (owner/admin only)."""
    lookup = await db.execute(
        select(Organisation).where(
            Organisation.id == current_user.organisation_id,
            Organisation.deleted_at.is_(None),
        )
    )
    organisation = lookup.scalar_one_or_none()
    if organisation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")
    # Only explicitly-supplied fields are written back.
    for attr, value in body.model_dump(exclude_unset=True).items():
        setattr(organisation, attr, value)
    await db.flush()
    await db.refresh(organisation)
    return organisation

View File

@@ -0,0 +1,310 @@
"""Scanner and client-side cookie report endpoints.
Accepts cookie reports from the client-side reporter embedded in the banner
bundle, upserts discovered cookies into the site's cookie inventory, and
provides scan job management (trigger, list, detail, diff).
"""
import logging
import uuid
from datetime import UTC, datetime, timedelta
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.cookie import Cookie
from src.models.scan import ScanJob, ScanResult
from src.models.site import Site
from src.schemas.auth import CurrentUser
from src.schemas.scanner import (
CookieReportRequest,
CookieReportResponse,
ScanDiffResponse,
ScanJobDetailResponse,
ScanJobResponse,
TriggerScanRequest,
)
from src.services.dependencies import get_current_user
from src.services.scanner import (
compute_scan_diff,
create_scan_job,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/scanner", tags=["scanner"])
# ── Client-side cookie report (public, no auth) ─────────────────────
@router.post(
    "/report",
    response_model=CookieReportResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def receive_cookie_report(
    body: CookieReportRequest,
    db: AsyncSession = Depends(get_db),
) -> CookieReportResponse:
    """Ingest a cookie report posted by the in-banner reporter script.

    No authentication: the endpoint is hit from end-user browsers, so the
    site_id in the payload serves as the implicit credential. Known
    cookies get their last-seen timestamp bumped; unknown ones join the
    inventory in 'pending' review state.
    """
    # The target site must exist and not be soft-deleted.
    site_lookup = await db.execute(
        select(Site).where(
            Site.id == body.site_id,
            Site.deleted_at.is_(None),
        )
    )
    if site_lookup.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    seen_at = datetime.now(UTC).isoformat()
    created_count = 0
    for entry in body.cookies:
        # A cookie is identified by (name, domain, storage_type) within a site.
        match = await db.execute(
            select(Cookie).where(
                Cookie.site_id == body.site_id,
                Cookie.name == entry.name,
                Cookie.domain == entry.domain,
                Cookie.storage_type == entry.storage_type,
            )
        )
        known = match.scalar_one_or_none()
        if known is None:
            db.add(
                Cookie(
                    site_id=body.site_id,
                    name=entry.name,
                    domain=entry.domain,
                    storage_type=entry.storage_type,
                    path=entry.path,
                    is_secure=entry.is_secure,
                    same_site=entry.same_site,
                    review_status="pending",
                    first_seen_at=seen_at,
                    last_seen_at=seen_at,
                )
            )
            created_count += 1
        else:
            known.last_seen_at = seen_at
    await db.flush()
    return CookieReportResponse(
        accepted=True,
        cookies_received=len(body.cookies),
        new_cookies=created_count,
    )
# ── Scan job management (authenticated) ─────────────────────────────
async def _verify_site_access(
    site_id: uuid.UUID,
    user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Load a non-deleted site, raising 404 unless it is in the user's organisation."""
    lookup = await db.execute(
        select(Site).where(
            Site.id == site_id,
            Site.organisation_id == user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    site = lookup.scalar_one_or_none()
    if site is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    return site
@router.post(
    "/scans",
    response_model=ScanJobResponse,
    status_code=status.HTTP_201_CREATED,
)
async def trigger_scan(
    body: TriggerScanRequest,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanJob:
    """Trigger a new cookie scan for a site.

    Creates a scan job in 'pending' state and dispatches it to the
    Celery worker queue for execution.

    Stale prior jobs (pending > 5 min, or running > 10 min) are marked
    failed and superseded rather than blocking the new scan; a fresh
    active job still yields a 409. If the task queue is unreachable a
    503 is returned — note the job row has already been committed at
    that point, so it remains in 'pending' until it goes stale.
    """
    # Local import keeps this helper's import lazy, mirroring the
    # Celery import further down.
    from src.services.scanner import complete_scan_job
    await _verify_site_access(body.site_id, user, db)
    # Check for an already-running scan
    active_result = await db.execute(
        select(ScanJob).where(
            ScanJob.site_id == body.site_id,
            ScanJob.status.in_(["pending", "running"]),
        )
    )
    active_jobs = list(active_result.scalars().all())
    now = datetime.now(UTC)
    # Cutoffs past which a previous job is treated as dead, not in progress.
    stale_pending_cutoff = now - timedelta(minutes=5)
    stale_running_cutoff = now - timedelta(minutes=10)
    for active_job in active_jobs:
        # .replace(tzinfo=UTC) treats the stored timestamps as UTC
        # wall-clock — presumably they are persisted naive; confirm
        # against the ScanJob model if this ever misbehaves.
        is_stale_pending = (
            active_job.status == "pending"
            and active_job.created_at.replace(tzinfo=UTC) < stale_pending_cutoff
        )
        is_stale_running = (
            active_job.status == "running"
            and active_job.started_at
            and active_job.started_at.replace(tzinfo=UTC) < stale_running_cutoff
        )
        if is_stale_pending or is_stale_running:
            logger.warning(
                "Failing stale %s scan job %s for site %s",
                active_job.status,
                active_job.id,
                body.site_id,
            )
            await complete_scan_job(
                db,
                active_job,
                error_message=(
                    f"Job was stale ({active_job.status} too long), superseded by new scan"
                ),
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="A scan is already in progress for this site",
            )
    job = await create_scan_job(
        db,
        site_id=body.site_id,
        trigger="manual",
        max_pages=body.max_pages,
    )
    # Commit before dispatching to Celery so the worker can find the
    # job in the database immediately (avoids race condition).
    await db.commit()
    # Dispatch to Celery (import here to avoid import at module level
    # when Celery broker is unavailable during testing)
    try:
        from src.tasks.scanner import run_scan
        run_scan.delay(str(job.id), str(body.site_id))
    except Exception:
        # Dispatch failure is logged with traceback; `from None` keeps the
        # broker's internal error out of the client-facing HTTP response.
        logger.exception("Failed to dispatch scan job %s to Celery", job.id)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=(
                "Background task queue is unavailable — scan job"
                " created but cannot be processed. Please try again later."
            ),
        ) from None
    return job
@router.get("/scans/site/{site_id}", response_model=list[ScanJobResponse])
async def list_scans(
    site_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
    limit: int = Query(default=20, ge=1, le=100),
    offset: int = Query(default=0, ge=0),
) -> list[ScanJob]:
    """Return a page of scan jobs for a site, newest first."""
    await _verify_site_access(site_id, user, db)
    page_query = (
        select(ScanJob)
        .where(ScanJob.site_id == site_id)
        .order_by(ScanJob.created_at.desc())
        .limit(limit)
        .offset(offset)
    )
    rows = await db.execute(page_query)
    return list(rows.scalars().all())
@router.get("/scans/{scan_id}", response_model=ScanJobDetailResponse)
async def get_scan(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Return a scan job together with its per-cookie results."""
    job_row = await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    job = job_row.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )
    # Tenancy check: the job's site must belong to the caller's organisation.
    await _verify_site_access(job.site_id, user, db)
    result_rows = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == scan_id).order_by(ScanResult.cookie_name)
    )
    payload = {
        field: getattr(job, field)
        for field in (
            "id",
            "site_id",
            "status",
            "trigger",
            "pages_scanned",
            "pages_total",
            "cookies_found",
            "error_message",
            "started_at",
            "completed_at",
            "created_at",
            "updated_at",
        )
    }
    payload["results"] = list(result_rows.scalars().all())
    return payload
@router.get("/scans/{scan_id}/diff", response_model=ScanDiffResponse)
async def get_scan_diff(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanDiffResponse:
    """Return the cookie-level diff between this scan and its predecessor."""
    job_row = await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    job = job_row.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )
    # Tenancy check before revealing any scan data.
    await _verify_site_access(job.site_id, user, db)
    return await compute_scan_diff(db, current_scan_id=scan_id, site_id=job.site_id)

View File

@@ -0,0 +1,101 @@
"""Site-group-level default configuration endpoints.
Provides GET and PUT for a site group's config defaults.
These defaults sit between org defaults and site config in the cascade.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site_group import SiteGroup
from src.models.site_group_config import SiteGroupConfig
from src.schemas.auth import CurrentUser
from src.schemas.site_group_config import SiteGroupConfigResponse, SiteGroupConfigUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
@router.get("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def get_site_group_config(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Return a site group's config defaults, creating an empty row on first read."""
    await _verify_group_ownership(group_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    group_config = lookup.scalar_one_or_none()
    if group_config is not None:
        return group_config
    # First read: persist an empty row so responses are always well-formed.
    group_config = SiteGroupConfig(site_group_id=group_id)
    db.add(group_config)
    await db.flush()
    await db.refresh(group_config)
    return group_config
@router.put("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def update_site_group_config(
    group_id: uuid.UUID,
    body: SiteGroupConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Create or update configuration defaults for a site group.

    Only fields explicitly supplied by the client are applied; omitted
    fields keep falling through to org/system defaults when site config
    is resolved.
    """
    await _verify_group_ownership(group_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    group_config = lookup.scalar_one_or_none()
    changes = body.model_dump(exclude_unset=True)
    if group_config is None:
        # No row yet — seed it directly from the supplied fields.
        group_config = SiteGroupConfig(site_group_id=group_id, **changes)
        db.add(group_config)
    else:
        for attr, value in changes.items():
            setattr(group_config, attr, value)
    await db.flush()
    await db.refresh(group_config)
    return group_config
# -- Helpers ------------------------------------------------------------------
async def _verify_group_ownership(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> None:
    """Raise 404 unless the site group exists, is live, and belongs to the organisation."""
    lookup = await db.execute(
        select(SiteGroup).where(
            SiteGroup.id == group_id,
            SiteGroup.organisation_id == organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
    )
    if lookup.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site group not found",
        )

View File

@@ -0,0 +1,198 @@
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site import Site
from src.models.site_group import SiteGroup
from src.schemas.auth import CurrentUser
from src.schemas.site_group import SiteGroupCreate, SiteGroupResponse, SiteGroupUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
@router.post("/", response_model=SiteGroupResponse, status_code=status.HTTP_201_CREATED)
async def create_site_group(
    body: SiteGroupCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Create a site group; names are unique within an organisation."""
    duplicate = await db.execute(
        select(SiteGroup).where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.name == body.name,
            SiteGroup.deleted_at.is_(None),
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Site group '{body.name}' already exists in this organisation",
        )
    group = SiteGroup(
        organisation_id=current_user.organisation_id,
        name=body.name,
        description=body.description,
    )
    db.add(group)
    await db.flush()
    await db.refresh(group)
    # A brand-new group cannot contain any sites yet.
    return _to_response(group, site_count=0)
@router.get("/", response_model=list[SiteGroupResponse])
async def list_site_groups(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict]:
    """List the organisation's site groups, each annotated with its active-site count."""
    # Aggregate active sites per group once, then outer-join so empty
    # groups still appear with a count of zero.
    counts = (
        select(
            Site.site_group_id,
            func.count(Site.id).label("cnt"),
        )
        .where(Site.deleted_at.is_(None))
        .group_by(Site.site_group_id)
        .subquery()
    )
    rows = await db.execute(
        select(SiteGroup, func.coalesce(counts.c.cnt, 0).label("site_count"))
        .outerjoin(counts, SiteGroup.id == counts.c.site_group_id)
        .where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
        .order_by(SiteGroup.name)
    )
    return [
        _to_response(record.SiteGroup, site_count=record.site_count)
        for record in rows.all()
    ]
@router.get("/{group_id}", response_model=SiteGroupResponse)
async def get_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Fetch one site group by ID, annotated with its active-site count."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    return _to_response(group, site_count=await _count_sites(group_id, db))
@router.patch("/{group_id}", response_model=SiteGroupResponse)
async def update_site_group(
    group_id: uuid.UUID,
    body: SiteGroupUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Rename a site group and/or change its description."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    changes = body.model_dump(exclude_unset=True)
    # A rename must not collide with another live group in the same org.
    if "name" in changes and changes["name"] != group.name:
        clash = await db.execute(
            select(SiteGroup).where(
                SiteGroup.organisation_id == current_user.organisation_id,
                SiteGroup.name == changes["name"],
                SiteGroup.deleted_at.is_(None),
                SiteGroup.id != group_id,
            )
        )
        if clash.scalar_one_or_none() is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Site group '{changes['name']}' already exists",
            )
    for attr, value in changes.items():
        setattr(group, attr, value)
    await db.flush()
    await db.refresh(group)
    return _to_response(group, site_count=await _count_sites(group_id, db))
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete a site group; member sites are detached, not deleted."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    members = await db.execute(
        select(Site).where(
            Site.site_group_id == group_id,
            Site.deleted_at.is_(None),
        )
    )
    # Detach every active site so nothing references the deleted group.
    for member in members.scalars().all():
        member.site_group_id = None
    group.deleted_at = datetime.now(UTC)
    await db.flush()
# ── Helpers ──────────────────────────────────────────────────────────
async def _get_org_group(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> SiteGroup:
    """Load a non-deleted site group, raising 404 unless it is in the organisation."""
    lookup = await db.execute(
        select(SiteGroup).where(
            SiteGroup.id == group_id,
            SiteGroup.organisation_id == organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
    )
    group = lookup.scalar_one_or_none()
    if group is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site group not found",
        )
    return group
async def _count_sites(group_id: uuid.UUID, db: AsyncSession) -> int:
    """Return how many non-deleted sites currently belong to the group."""
    count_query = select(func.count(Site.id)).where(
        Site.site_group_id == group_id,
        Site.deleted_at.is_(None),
    )
    return (await db.execute(count_query)).scalar_one()
def _to_response(group: SiteGroup, *, site_count: int) -> dict:
"""Convert a SiteGroup model to a response dict with site_count."""
return {
"id": group.id,
"organisation_id": group.organisation_id,
"name": group.name,
"description": group.description,
"created_at": group.created_at,
"updated_at": group.updated_at,
"site_count": site_count,
}

View File

@@ -0,0 +1,220 @@
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site import Site
from src.models.site_config import SiteConfig
from src.schemas.auth import CurrentUser
from src.schemas.site import (
SiteConfigCreate,
SiteConfigResponse,
SiteConfigUpdate,
SiteCreate,
SiteResponse,
SiteUpdate,
)
from src.services.dependencies import require_role
router = APIRouter(prefix="/sites", tags=["sites"])
# ── Site CRUD ────────────────────────────────────────────────────────
@router.post("/", response_model=SiteResponse, status_code=status.HTTP_201_CREATED)
async def create_site(
    body: SiteCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Register a new site under the caller's organisation.

    Domains are unique per organisation; a default (empty) site
    configuration row is created alongside the site.
    """
    duplicate = await db.execute(
        select(Site).where(
            Site.organisation_id == current_user.organisation_id,
            Site.domain == body.domain,
            Site.deleted_at.is_(None),
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Site with domain '{body.domain}' already exists in this organisation",
        )
    site = Site(
        organisation_id=current_user.organisation_id,
        domain=body.domain,
        display_name=body.display_name,
        site_group_id=body.site_group_id,
    )
    db.add(site)
    # Flush first so site.id is assigned before the config row references it.
    await db.flush()
    db.add(SiteConfig(site_id=site.id))
    await db.flush()
    await db.refresh(site)
    return site
@router.get("/", response_model=list[SiteResponse])
async def list_sites(
    site_group_id: uuid.UUID | None = Query(default=None),
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[Site]:
    """List active sites for the organisation, optionally restricted to one group."""
    conditions = [
        Site.organisation_id == current_user.organisation_id,
        Site.deleted_at.is_(None),
    ]
    if site_group_id is not None:
        conditions.append(Site.site_group_id == site_group_id)
    rows = await db.execute(select(Site).where(*conditions).order_by(Site.domain))
    return list(rows.scalars().all())
@router.get("/{site_id}", response_model=SiteResponse)
async def get_site(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Fetch one site by ID, scoped to the caller's organisation."""
    return await _get_org_site(site_id, current_user.organisation_id, db)
@router.patch("/{site_id}", response_model=SiteResponse)
async def update_site(
    site_id: uuid.UUID,
    body: SiteUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Apply a partial update (display name, active flag, …) to a site."""
    site = await _get_org_site(site_id, current_user.organisation_id, db)
    # Only explicitly-supplied fields are written back.
    for attr, value in body.model_dump(exclude_unset=True).items():
        setattr(site, attr, value)
    await db.flush()
    await db.refresh(site)
    return site
@router.delete("/{site_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_site(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete a site by stamping its deleted_at timestamp."""
    site = await _get_org_site(site_id, current_user.organisation_id, db)
    site.deleted_at = datetime.now(UTC)
    await db.flush()
# ── Site config CRUD ─────────────────────────────────────────────────
@router.get("/{site_id}/config", response_model=SiteConfigResponse)
async def get_site_config(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Return the site's configuration row, or 404 if none exists yet."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    site_config = lookup.scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found. Create one first.",
        )
    return site_config
@router.put("/{site_id}/config", response_model=SiteConfigResponse)
async def create_or_replace_site_config(
    site_id: uuid.UUID,
    body: SiteConfigCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Replace the site's configuration wholesale, creating it if absent."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    site_config = lookup.scalar_one_or_none()
    payload = body.model_dump()
    if site_config is None:
        site_config = SiteConfig(site_id=site_id, **payload)
        db.add(site_config)
    else:
        # PUT semantics: every field is overwritten, including defaults.
        for attr, value in payload.items():
            setattr(site_config, attr, value)
    await db.flush()
    await db.refresh(site_config)
    return site_config
@router.patch("/{site_id}/config", response_model=SiteConfigResponse)
async def update_site_config(
    site_id: uuid.UUID,
    body: SiteConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Apply a partial update to the site's configuration."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    site_config = lookup.scalar_one_or_none()
    if site_config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found. Create one first.",
        )
    # Only explicitly-supplied fields are written back.
    for attr, value in body.model_dump(exclude_unset=True).items():
        setattr(site_config, attr, value)
    await db.flush()
    await db.refresh(site_config)
    return site_config
# ── Helpers ──────────────────────────────────────────────────────────
async def _get_org_site(
    site_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> Site:
    """Load a non-deleted site, raising 404 unless it belongs to the organisation."""
    lookup = await db.execute(
        select(Site).where(
            Site.id == site_id,
            Site.organisation_id == organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    site = lookup.scalar_one_or_none()
    if site is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site

View File

@@ -0,0 +1,195 @@
"""Translation management endpoints.
CRUD for per-site, per-locale translation strings used by the banner script.
"""
import uuid
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.site import Site
from src.models.translation import Translation
from src.schemas.auth import CurrentUser
from src.schemas.translation import TranslationCreate, TranslationResponse, TranslationUpdate
from src.services.dependencies import require_role
router = APIRouter(prefix="/sites/{site_id}/translations", tags=["translations"])
async def _get_org_site(site_id: uuid.UUID, organisation_id: uuid.UUID, db: AsyncSession) -> Site:
    """Raise 404 unless the site exists, is live, and belongs to the organisation."""
    lookup = await db.execute(
        select(Site).where(
            Site.id == site_id,
            Site.organisation_id == organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    site = lookup.scalar_one_or_none()
    if site is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site
@router.get("/", response_model=list[TranslationResponse])
async def list_translations(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[Translation]:
    """Return every locale translation defined for the site, ordered by locale."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    rows = await db.execute(
        select(Translation).where(Translation.site_id == site_id).order_by(Translation.locale)
    )
    return list(rows.scalars().all())
@router.get("/{locale}", response_model=TranslationResponse)
async def get_translation(
    site_id: uuid.UUID,
    locale: str,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Fetch the translation strings for one locale of a site."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == locale,
        )
    )
    translation = lookup.scalar_one_or_none()
    if translation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )
    return translation
@router.post("/", response_model=TranslationResponse, status_code=status.HTTP_201_CREATED)
async def create_translation(
    site_id: uuid.UUID,
    body: TranslationCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Create a translation for a new locale; one translation per locale per site."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    duplicate = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == body.locale,
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Translation for locale '{body.locale}' already exists",
        )
    translation = Translation(
        site_id=site_id,
        locale=body.locale,
        strings=body.strings,
    )
    db.add(translation)
    await db.flush()
    await db.refresh(translation)
    return translation
@router.put("/{locale}", response_model=TranslationResponse)
async def update_translation(
    site_id: uuid.UUID,
    locale: str,
    body: TranslationUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Replace the full strings payload of an existing locale translation."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == locale,
        )
    )
    translation = lookup.scalar_one_or_none()
    if translation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )
    translation.strings = body.strings
    await db.flush()
    await db.refresh(translation)
    return translation
@router.delete("/{locale}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_translation(
    site_id: uuid.UUID,
    locale: str,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove the translation for one locale of a site (hard delete)."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == locale,
        )
    )
    translation = lookup.scalar_one_or_none()
    if translation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )
    await db.delete(translation)
    await db.flush()
# ── Public endpoint for the banner script ────────────────────────────
public_router = APIRouter(prefix="/translations", tags=["translations"])
@public_router.get("/{site_id}/{locale}")
async def get_public_translation(
    site_id: uuid.UUID,
    locale: str,
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Serve translation strings to the banner script (unauthenticated).

    Only active, non-deleted sites are served. A 404 lets the banner
    fall back to its built-in English defaults.
    """
    lookup = await db.execute(
        select(Translation)
        .join(Site)
        .where(
            Translation.site_id == site_id,
            Translation.locale == locale,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    translation = lookup.scalar_one_or_none()
    if translation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Translation not found",
        )
    return translation.strings

View File

@@ -0,0 +1,136 @@
import uuid
from datetime import UTC, datetime
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.db import get_db
from src.models.user import User
from src.schemas.auth import CurrentUser
from src.schemas.user import UserCreate, UserResponse, UserUpdate
from src.services.auth import hash_password
from src.services.dependencies import require_role
router = APIRouter(prefix="/users", tags=["users"])
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
    body: UserCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Invite/create a new user within the current organisation."""
    # Reject duplicates up front; note the email check is global, not
    # scoped to the caller's organisation.
    duplicate = await db.execute(select(User).where(User.email == body.email))
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"User with email '{body.email}' already exists",
        )
    # NOTE(review): an "admin" can currently create a user with role
    # "owner" — confirm whether role assignment should be capped at the
    # caller's own role.
    new_user = User(
        organisation_id=current_user.organisation_id,
        email=body.email,
        password_hash=hash_password(body.password),
        full_name=body.full_name,
        role=body.role,
    )
    db.add(new_user)
    await db.flush()
    await db.refresh(new_user)
    return new_user
@router.get("/", response_model=list[UserResponse])
async def list_users(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[User]:
    """List all active users in the current organisation."""
    # Soft-deleted users (deleted_at set) are excluded; oldest first.
    query = (
        select(User)
        .where(
            User.organisation_id == current_user.organisation_id,
            User.deleted_at.is_(None),
        )
        .order_by(User.created_at)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Get a specific user by ID within the current organisation."""
    # The organisation filter doubles as the tenancy check: users in
    # other organisations 404 rather than 403, so IDs are not probeable.
    rows = await db.execute(
        select(User).where(
            User.id == user_id,
            User.organisation_id == current_user.organisation_id,
            User.deleted_at.is_(None),
        )
    )
    found = rows.scalar_one_or_none()
    if found is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    return found
@router.patch("/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: uuid.UUID,
    body: UserUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Update a user's name or role. Requires owner or admin."""
    rows = await db.execute(
        select(User).where(
            User.id == user_id,
            User.organisation_id == current_user.organisation_id,
            User.deleted_at.is_(None),
        )
    )
    target = rows.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    # Only fields the client actually sent are applied (PATCH semantics).
    # NOTE(review): an "admin" can currently assign any role, including
    # "owner" — confirm whether that should be restricted.
    changes = body.model_dump(exclude_unset=True)
    for attr, new_value in changes.items():
        setattr(target, attr, new_value)
    await db.flush()
    await db.refresh(target)
    return target
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_user(
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete (deactivate) a user. Requires owner or admin."""
    # Guard against callers locking themselves out mid-session.
    if user_id == current_user.id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot deactivate yourself",
        )
    rows = await db.execute(
        select(User).where(
            User.id == user_id,
            User.organisation_id == current_user.organisation_id,
            User.deleted_at.is_(None),
        )
    )
    target = rows.scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    # Soft delete: stamp deleted_at instead of removing the row, so
    # historical references to the user remain intact.
    target.deleted_at = datetime.now(UTC)
    await db.flush()

View File

View File

@@ -0,0 +1,45 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, EmailStr
class LoginRequest(BaseModel):
    """Credentials payload for the login endpoint."""

    email: EmailStr
    password: str
class RefreshRequest(BaseModel):
    """Payload for exchanging a refresh token for a new token pair."""

    refresh_token: str
class TokenResponse(BaseModel):
    """Token pair returned by login/refresh (OAuth2-style bearer response)."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # access-token lifetime in seconds
class TokenPayload(BaseModel):
    """Decoded JWT claim set used internally after token validation."""

    sub: str  # user ID
    org_id: str  # organisation ID
    role: str  # user role
    exp: datetime  # expiry timestamp
    iat: datetime  # issued-at timestamp
    type: str = "access"  # "access" or "refresh"
class CurrentUser(BaseModel):
    """Represents the authenticated user extracted from a JWT."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    role: str

    def has_role(self, *roles: str) -> bool:
        """Return True if the user's role is one of *roles*."""
        return self.role in roles

    @property
    def is_admin(self) -> bool:
        """True for the privileged roles ("owner" or "admin")."""
        return self.role in ("owner", "admin")

View File

@@ -0,0 +1,56 @@
"""Pydantic schemas for compliance check results."""
from enum import StrEnum
from pydantic import BaseModel, Field
class Severity(StrEnum):
    """Severity level of a single compliance issue."""

    CRITICAL = "critical"
    WARNING = "warning"
    INFO = "info"
class Framework(StrEnum):
    """Regulatory frameworks the compliance checker can evaluate."""

    GDPR = "gdpr"
    CNIL = "cnil"  # French DPA guidelines
    CCPA = "ccpa"
    EPRIVACY = "eprivacy"
    LGPD = "lgpd"  # Brazilian data-protection law
class ComplianceIssue(BaseModel):
    """A single compliance issue found during a check."""

    rule_id: str  # stable identifier of the rule that fired
    severity: Severity
    message: str  # human-readable description of the issue
    recommendation: str  # suggested remediation
class FrameworkResult(BaseModel):
    """Compliance result for a single regulatory framework."""

    framework: Framework
    score: int = Field(ge=0, le=100, description="Compliance score (0-100)")
    # NOTE(review): free-text status rather than an enum — the described
    # values are "compliant", "partial", "non_compliant".
    status: str = Field(description="Overall status: compliant, partial, non_compliant")
    issues: list[ComplianceIssue] = Field(default_factory=list)
    rules_checked: int = 0
    rules_passed: int = 0
class ComplianceCheckRequest(BaseModel):
    """Request body for compliance checks."""

    frameworks: list[Framework] | None = Field(
        default=None,
        description="Frameworks to check. If null, all frameworks are checked.",
    )
class ComplianceCheckResponse(BaseModel):
    """Full compliance check response for a site."""

    site_id: str
    results: list[FrameworkResult]  # one entry per framework checked
    overall_score: int = Field(ge=0, le=100, description="Weighted average across all frameworks")

View File

@@ -0,0 +1,62 @@
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class ConsentAction(StrEnum):
    """The choice a visitor made in the consent banner."""

    ACCEPT_ALL = "accept_all"
    REJECT_ALL = "reject_all"
    CUSTOM = "custom"  # per-category selection
    WITHDRAW = "withdraw"  # revocation of previously given consent
class ConsentRecordCreate(BaseModel):
    """Payload sent by the banner when a consent event occurs."""

    site_id: uuid.UUID
    visitor_id: str = Field(min_length=1, max_length=255)
    action: ConsentAction
    categories_accepted: list[str]
    categories_rejected: list[str] | None = None
    tc_string: str | None = None  # IAB TCF consent string, when TCF is enabled
    gcm_state: dict | None = None  # Google Consent Mode state snapshot
    gpp_string: str | None = None  # IAB GPP string, when GPP is enabled
    gpc_detected: bool | None = None  # browser sent a Global Privacy Control signal
    gpc_honoured: bool | None = None  # whether the GPC signal was honoured
    page_url: str | None = None
    country_code: str | None = Field(default=None, max_length=5)
    region_code: str | None = Field(default=None, max_length=10)
class ConsentRecordResponse(BaseModel):
    """Stored consent record as returned by the API (ORM-backed)."""

    id: uuid.UUID
    site_id: uuid.UUID
    visitor_id: str
    action: str
    # NOTE(review): bare `list` (unparameterized) — element validation is
    # skipped; confirm whether `list[str]` was intended before tightening,
    # as changing the annotation changes Pydantic validation behaviour.
    categories_accepted: list
    categories_rejected: list | None = None
    tc_string: str | None = None
    gcm_state: dict | None = None
    gpp_string: str | None = None
    gpc_detected: bool | None = None
    gpc_honoured: bool | None = None
    page_url: str | None = None
    country_code: str | None = None
    region_code: str | None = None
    consented_at: datetime
    model_config = {"from_attributes": True}
class ConsentVerifyResponse(BaseModel):
    """Audit proof that a consent record exists."""

    id: uuid.UUID
    site_id: uuid.UUID
    visitor_id: str
    action: str
    categories_accepted: list
    consented_at: datetime
    valid: bool = True  # record was found and is intact

View File

@@ -0,0 +1,210 @@
"""Pydantic schemas for cookie categories, cookies, and allow-list entries."""
from __future__ import annotations
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
# ─── Cookie category schemas ───
class CookieCategoryResponse(BaseModel):
    """Response schema for a cookie category."""

    id: uuid.UUID
    name: str
    slug: str
    description: str | None = None
    is_essential: bool  # essential categories cannot be rejected by visitors
    display_order: int  # sort position in the banner UI
    tcf_purpose_ids: list[int] | None = None  # mapped IAB TCF purpose IDs
    gcm_consent_types: list[str] | None = None  # mapped Google Consent Mode types
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Storage type enum ───
class StorageType(StrEnum):
    """Type of browser storage used by the cookie/tracker."""

    cookie = "cookie"
    local_storage = "local_storage"
    session_storage = "session_storage"
    indexed_db = "indexed_db"
# ─── Review status enum ───
class ReviewStatus(StrEnum):
    """Review status for a discovered cookie."""

    pending = "pending"  # awaiting human categorisation review
    approved = "approved"
    rejected = "rejected"
# ─── Cookie schemas ───
class CookieCreate(BaseModel):
    """Schema for creating a cookie record (typically from scanner/reporter)."""

    name: str = Field(..., min_length=1, max_length=255)
    domain: str = Field(..., min_length=1, max_length=255)
    storage_type: StorageType = StorageType.cookie
    category_id: uuid.UUID | None = None
    description: str | None = None
    vendor: str | None = Field(None, max_length=255)
    path: str | None = Field(None, max_length=500)
    max_age_seconds: int | None = None  # cookie lifetime; None for session cookies
    is_http_only: bool | None = None
    is_secure: bool | None = None
    same_site: str | None = Field(None, max_length=10)  # e.g. "Lax", "Strict", "None"
class CookieUpdate(BaseModel):
    """Schema for updating a cookie record."""

    category_id: uuid.UUID | None = None
    description: str | None = None
    vendor: str | None = Field(None, max_length=255)
    review_status: ReviewStatus | None = None
class CookieResponse(BaseModel):
    """Response schema for a cookie."""

    id: uuid.UUID
    site_id: uuid.UUID
    category_id: uuid.UUID | None = None
    name: str
    domain: str
    storage_type: str
    description: str | None = None
    vendor: str | None = None
    path: str | None = None
    max_age_seconds: int | None = None
    is_http_only: bool | None = None
    is_secure: bool | None = None
    same_site: str | None = None
    review_status: str
    # NOTE(review): typed `str | None` while created_at/updated_at are
    # `datetime` — confirm whether these should also be datetimes.
    first_seen_at: str | None = None
    last_seen_at: str | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Allow-list schemas ───
class AllowListEntryCreate(BaseModel):
    """Schema for adding a cookie to the allow-list."""

    name_pattern: str = Field(..., min_length=1, max_length=255)
    domain_pattern: str = Field(..., min_length=1, max_length=255)
    category_id: uuid.UUID
    description: str | None = None
class AllowListEntryUpdate(BaseModel):
    """Schema for updating an allow-list entry."""

    category_id: uuid.UUID | None = None
    description: str | None = None
class AllowListEntryResponse(BaseModel):
    """Response schema for an allow-list entry."""

    id: uuid.UUID
    site_id: uuid.UUID
    category_id: uuid.UUID
    name_pattern: str
    domain_pattern: str
    description: str | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Known cookie schemas ───
class KnownCookieCreate(BaseModel):
    """Schema for creating a known cookie pattern."""

    name_pattern: str = Field(..., min_length=1, max_length=255)
    domain_pattern: str = Field(..., min_length=1, max_length=255)
    category_id: uuid.UUID
    vendor: str | None = Field(None, max_length=255)
    description: str | None = None
    is_regex: bool = False  # patterns are exact/wildcard unless flagged as regex
class KnownCookieUpdate(BaseModel):
    """Schema for updating a known cookie pattern."""

    category_id: uuid.UUID | None = None
    vendor: str | None = Field(None, max_length=255)
    description: str | None = None
    is_regex: bool | None = None
class KnownCookieResponse(BaseModel):
    """Response schema for a known cookie pattern."""

    id: uuid.UUID
    name_pattern: str
    domain_pattern: str
    category_id: uuid.UUID
    vendor: str | None = None
    description: str | None = None
    is_regex: bool
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ─── Classification schemas ───
class ClassificationResultResponse(BaseModel):
    """Response for a single cookie classification result."""

    cookie_name: str
    cookie_domain: str
    category_id: uuid.UUID | None = None
    category_slug: str | None = None
    vendor: str | None = None
    description: str | None = None
    match_source: str  # which rule source matched (allow-list / known DB / none)
    matched: bool
class ClassifySiteResponse(BaseModel):
    """Response for classifying all cookies on a site."""

    site_id: str
    total: int  # cookies examined
    matched: int
    unmatched: int
    results: list[ClassificationResultResponse]
class ClassifySingleRequest(BaseModel):
    """Request to classify a single cookie (preview/test)."""

    cookie_name: str = Field(..., min_length=1, max_length=255)
    cookie_domain: str = Field(..., min_length=1, max_length=255)

View File

@@ -0,0 +1,61 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from src.schemas.site import BlockingMode
class OrgConfigUpdate(BaseModel):
    """Update (or create) organisation-level default configuration.

    All fields are optional — only non-None values override the system defaults.
    """

    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None  # per-region blocking-mode overrides
    tcf_enabled: bool | None = None
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)  # ISO country code
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None  # honour GPC everywhere, not just listed jurisdictions
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None  # Google Consent Mode default state
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
class OrgConfigResponse(BaseModel):
    """Organisation-level config as stored; None means "inherit default"."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    blocking_mode: str | None
    regional_modes: dict | None
    tcf_enabled: bool | None
    tcf_publisher_cc: str | None
    gpp_enabled: bool | None
    gpp_supported_apis: list[str] | None
    gpc_enabled: bool | None
    gpc_jurisdictions: list[str] | None
    gpc_global_honour: bool | None
    gcm_enabled: bool | None
    gcm_default: dict | None
    shopify_privacy_enabled: bool | None
    banner_config: dict | None
    privacy_policy_url: str | None
    terms_url: str | None
    scan_schedule_cron: str | None
    scan_max_pages: int | None
    consent_expiry_days: int | None
    consent_retention_days: int | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,29 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
class OrganisationCreate(BaseModel):
    """Payload for creating an organisation (tenant)."""

    name: str = Field(min_length=1, max_length=255)
    slug: str = Field(min_length=1, max_length=100, pattern=r"^[a-z0-9-]+$")  # URL-safe identifier
    contact_email: str | None = None
    billing_plan: str = "free"
class OrganisationUpdate(BaseModel):
    """Partial update for an organisation; the slug is immutable."""

    name: str | None = Field(default=None, min_length=1, max_length=255)
    contact_email: str | None = None
    billing_plan: str | None = None
class OrganisationResponse(BaseModel):
    """Organisation as returned by the API (ORM-backed)."""

    id: uuid.UUID
    name: str
    slug: str
    contact_email: str | None
    billing_plan: str
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,142 @@
"""Pydantic schemas for scanner and client-side cookie reports."""
from __future__ import annotations
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class ScanStatus(StrEnum):
    """Lifecycle state of a scan job."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    FAILED = "failed"
class ScanTrigger(StrEnum):
    """What caused a scan job to be created."""

    MANUAL = "manual"
    SCHEDULED = "scheduled"
    CLIENT_REPORT = "client_report"  # triggered by the in-page cookie reporter
# ── Client-side cookie report ────────────────────────────────────────
class ReportedCookie(BaseModel):
    """A single cookie/storage item reported by the client-side reporter."""

    name: str = Field(..., min_length=1, max_length=255)
    domain: str = Field(..., min_length=1, max_length=255)
    storage_type: str = Field(default="cookie", max_length=30)
    # Only the value's length is reported — never the value itself.
    value_length: int = Field(default=0, ge=0)
    path: str | None = None
    is_secure: bool | None = None
    same_site: str | None = None
    script_source: str | None = None  # URL of the script that set it, if known
class CookieReportRequest(BaseModel):
    """Payload from the client-side cookie reporter."""

    site_id: uuid.UUID
    page_url: str = Field(..., max_length=2000)
    cookies: list[ReportedCookie] = Field(..., max_length=500)  # capped per report
    collected_at: datetime
    user_agent: str = Field(default="", max_length=500)
class CookieReportResponse(BaseModel):
    """Acknowledgement response for a cookie report."""

    accepted: bool = True
    cookies_received: int
    new_cookies: int = 0  # cookies not previously known for the site
# ── Scan job schemas ─────────────────────────────────────────────────
class ScanResultResponse(BaseModel):
    """A single scan result — a cookie found on a specific page."""

    id: uuid.UUID
    scan_job_id: uuid.UUID
    page_url: str
    cookie_name: str
    cookie_domain: str
    storage_type: str
    attributes: dict | None = None  # raw cookie attributes as captured
    script_source: str | None = None
    auto_category: str | None = None  # category assigned by auto-classification
    initiator_chain: list[str] | None = None  # scripts that led to the cookie being set
    found_at: datetime
    created_at: datetime
    model_config = {"from_attributes": True}
class ScanJobResponse(BaseModel):
    """Response schema for a scan job."""

    id: uuid.UUID
    site_id: uuid.UUID
    status: str
    trigger: str
    pages_scanned: int
    pages_total: int | None
    cookies_found: int
    error_message: str | None  # populated when status is "failed"
    started_at: datetime | None
    completed_at: datetime | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
class ScanJobDetailResponse(ScanJobResponse):
    """Scan job with results included."""

    # Mutable default is safe here: Pydantic deep-copies field defaults
    # per instance, unlike plain Python defaults.
    results: list[ScanResultResponse] = []
class TriggerScanRequest(BaseModel):
    """Request to trigger a new scan."""

    site_id: uuid.UUID
    max_pages: int = Field(default=50, ge=1, le=500)  # crawl budget for this run
# ── Diff engine schemas ──────────────────────────────────────────────
class DiffStatus(StrEnum):
    """How a cookie changed between two scans."""

    NEW = "new"
    REMOVED = "removed"
    CHANGED = "changed"
class CookieDiffItem(BaseModel):
    """A single cookie difference between two scans."""

    name: str
    domain: str
    storage_type: str
    diff_status: DiffStatus
    details: str | None = None  # human-readable description of what changed
class ScanDiffResponse(BaseModel):
    """Diff between two scans."""

    current_scan_id: uuid.UUID
    previous_scan_id: uuid.UUID | None  # None when there is no earlier scan to compare
    new_cookies: list[CookieDiffItem] = []
    removed_cookies: list[CookieDiffItem] = []
    changed_cookies: list[CookieDiffItem] = []
    total_new: int = 0
    total_removed: int = 0
    total_changed: int = 0

View File

@@ -0,0 +1,117 @@
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, Field
class BlockingMode(StrEnum):
    """How non-essential cookies are handled before a consent choice."""

    OPT_IN = "opt_in"  # blocked until accepted (GDPR-style)
    OPT_OUT = "opt_out"  # allowed until rejected (CCPA-style)
    INFORMATIONAL = "informational"  # banner shown, nothing blocked
# ── Site schemas ─────────────────────────────────────────────────────
class SiteCreate(BaseModel):
    """Payload for registering a site."""

    domain: str = Field(min_length=1, max_length=255)
    display_name: str = Field(min_length=1, max_length=255)
    additional_domains: list[str] | None = None  # aliases covered by the same config
    site_group_id: uuid.UUID | None = None
class SiteUpdate(BaseModel):
    """Partial update for a site; the primary domain is immutable here."""

    display_name: str | None = Field(default=None, min_length=1, max_length=255)
    is_active: bool | None = None
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
class SiteResponse(BaseModel):
    """Site as returned by the API (ORM-backed)."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    domain: str
    display_name: str
    is_active: bool
    additional_domains: list[str] | None = None
    site_group_id: uuid.UUID | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}
# ── Site config schemas ──────────────────────────────────────────────
class SiteConfigCreate(BaseModel):
    """Full site-level configuration with concrete defaults.

    Unlike the org/group update schemas, these values are not optional —
    a site config is a complete leaf in the configuration cascade.
    """

    blocking_mode: BlockingMode = BlockingMode.OPT_IN
    regional_modes: dict | None = None  # per-region blocking-mode overrides
    tcf_enabled: bool = False
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)  # ISO country code
    gpp_enabled: bool = True
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool = True
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool = False
    gcm_enabled: bool = True
    gcm_default: dict | None = None  # Google Consent Mode default state
    shopify_privacy_enabled: bool = False
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int = Field(default=50, ge=1, le=1000)
    consent_expiry_days: int = Field(default=365, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
class SiteConfigUpdate(BaseModel):
    """Partial site-config update — only non-None fields are applied."""

    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None
    tcf_enabled: bool | None = None
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
    consent_retention_days: int | None = Field(default=None, ge=1, le=730)
class SiteConfigResponse(BaseModel):
    """Site config as returned by the API (ORM-backed).

    NOTE(review): some fields carry schema-level defaults and some do
    not — since values come from ORM attributes the defaults should be
    inert, but the inconsistency is worth confirming.
    """

    id: uuid.UUID
    site_id: uuid.UUID
    blocking_mode: str
    regional_modes: dict | None
    tcf_enabled: bool
    tcf_publisher_cc: str | None = None
    gpp_enabled: bool = True
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool = True
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool = False
    gcm_enabled: bool
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool = False
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int = 50
    consent_expiry_days: int = 365
    consent_retention_days: int | None = None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,26 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
class SiteGroupCreate(BaseModel):
    """Payload for creating a site group."""

    name: str = Field(min_length=1, max_length=255)
    description: str | None = None
class SiteGroupUpdate(BaseModel):
    """Partial update for a site group."""

    name: str | None = Field(default=None, min_length=1, max_length=255)
    description: str | None = None
class SiteGroupResponse(BaseModel):
    """Site group as returned by the API (ORM-backed)."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    name: str
    description: str | None
    created_at: datetime
    updated_at: datetime
    site_count: int = 0  # number of sites in the group; computed by the endpoint
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,59 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from src.schemas.site import BlockingMode
class SiteGroupConfigUpdate(BaseModel):
    """Update (or create) site-group-level default configuration.

    All fields are optional — only non-None values override the org/system defaults.

    NOTE(review): unlike OrgConfigUpdate/SiteConfigUpdate there is no
    ``consent_retention_days`` field here — confirm whether retention is
    deliberately not configurable at group level.
    """

    blocking_mode: BlockingMode | None = None
    regional_modes: dict | None = None
    tcf_enabled: bool | None = None
    tcf_publisher_cc: str | None = Field(default=None, max_length=2)
    gpp_enabled: bool | None = None
    gpp_supported_apis: list[str] | None = None
    gpc_enabled: bool | None = None
    gpc_jurisdictions: list[str] | None = None
    gpc_global_honour: bool | None = None
    gcm_enabled: bool | None = None
    gcm_default: dict | None = None
    shopify_privacy_enabled: bool | None = None
    banner_config: dict | None = None
    privacy_policy_url: str | None = None
    terms_url: str | None = None
    scan_schedule_cron: str | None = None
    scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
    consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
class SiteGroupConfigResponse(BaseModel):
    """Group-level config as stored; None means "inherit from org/system"."""

    id: uuid.UUID
    site_group_id: uuid.UUID
    blocking_mode: str | None
    regional_modes: dict | None
    tcf_enabled: bool | None
    tcf_publisher_cc: str | None
    gpp_enabled: bool | None
    gpp_supported_apis: list[str] | None
    gpc_enabled: bool | None
    gpc_jurisdictions: list[str] | None
    gpc_global_honour: bool | None
    gcm_enabled: bool | None
    gcm_default: dict | None
    shopify_privacy_enabled: bool | None
    banner_config: dict | None
    privacy_policy_url: str | None
    terms_url: str | None
    scan_schedule_cron: str | None
    scan_max_pages: int | None
    consent_expiry_days: int | None
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,24 @@
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
class TranslationCreate(BaseModel):
    """Payload for creating a banner translation for one locale."""

    locale: str = Field(min_length=2, max_length=10)  # e.g. "fr", "pt-BR"
    strings: dict[str, str]  # translation key → localised text
class TranslationUpdate(BaseModel):
    """Replace the strings of an existing translation (locale is fixed)."""

    strings: dict[str, str]
class TranslationResponse(BaseModel):
    """Translation as returned by the API (ORM-backed)."""

    id: uuid.UUID
    site_id: uuid.UUID
    locale: str
    strings: dict[str, str]
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

@@ -0,0 +1,36 @@
import uuid
from datetime import datetime
from enum import StrEnum
from pydantic import BaseModel, EmailStr, Field
class UserRole(StrEnum):
    """Role a user holds within their organisation."""

    OWNER = "owner"
    ADMIN = "admin"
    EDITOR = "editor"
    VIEWER = "viewer"
class UserCreate(BaseModel):
    """Payload for inviting/creating a user."""

    email: EmailStr
    # max_length=72 matches bcrypt's input limit (bcrypt ignores bytes
    # beyond 72), so no password material is silently truncated.
    password: str = Field(min_length=8, max_length=72)
    full_name: str = Field(min_length=1, max_length=255)
    role: UserRole = UserRole.VIEWER
class UserUpdate(BaseModel):
    """Partial update for a user (name and/or role)."""

    full_name: str | None = Field(default=None, min_length=1, max_length=255)
    role: UserRole | None = None
class UserResponse(BaseModel):
    """User as returned by the API (ORM-backed; never exposes the hash)."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    full_name: str
    role: str
    created_at: datetime
    updated_at: datetime
    model_config = {"from_attributes": True}

View File

View File

@@ -0,0 +1,59 @@
import uuid
from datetime import UTC, datetime, timedelta
import bcrypt
from jose import jwt
from src.config.settings import get_settings
def hash_password(password: str) -> str:
    """Hash a plaintext password with bcrypt using a freshly generated salt."""
    salt = bcrypt.gensalt()
    return bcrypt.hashpw(password.encode(), salt).decode()
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Constant-time check of a plaintext password against a bcrypt hash."""
    return bcrypt.checkpw(
        plain_password.encode(),
        hashed_password.encode(),
    )
def create_access_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
    role: str,
    email: str,
) -> str:
    """Mint a signed JWT access token for the given user.

    Carries the user id (``sub``), organisation id, role and email, and
    expires ``jwt_access_token_expire_minutes`` after issuance.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    lifetime = timedelta(minutes=settings.jwt_access_token_expire_minutes)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "role": role,
        "email": email,
        "exp": issued_at + lifetime,
        "iat": issued_at,
        "type": "access",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
def create_refresh_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
) -> str:
    """Mint a signed JWT refresh token (``type: refresh``).

    Carries only the user and organisation ids and expires
    ``jwt_refresh_token_expire_days`` after issuance.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    lifetime = timedelta(days=settings.jwt_refresh_token_expire_days)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "exp": issued_at + lifetime,
        "iat": issued_at,
        "type": "refresh",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
def decode_token(token: str) -> dict:
    """Decode and validate a JWT token. Raises JWTError on failure.

    Signature and standard claims (exp, iat) are verified by the jose
    library using the configured secret and algorithm.
    """
    settings = get_settings()
    return jwt.decode(token, settings.jwt_secret_key, algorithms=[settings.jwt_algorithm])

View File

@@ -0,0 +1,79 @@
"""First-run bootstrap of an organisation and owner user.
Runs once on API startup. If ``INITIAL_ADMIN_EMAIL`` and
``INITIAL_ADMIN_PASSWORD`` are set and the ``users`` table is empty,
creates an organisation and a single owner user so the operator can log
in to the admin UI for the first time. Idempotent: once any user
exists, this is a no-op, so the environment variables can safely remain
set across restarts. Complements ``ADMIN_BOOTSTRAP_TOKEN`` — that gates
runtime org creation; this creates the *initial* org + owner without
requiring a second round-trip.
"""
import logging
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from src.config.settings import Settings
from src.db.session import async_session_factory
from src.models.organisation import Organisation
from src.models.user import User
from src.services.auth import hash_password
logger = logging.getLogger(__name__)
async def bootstrap_initial_admin(settings: Settings) -> None:
    """Create the first organisation and owner user if none exist.

    Does nothing when either admin credential env var is unset, or when
    the database already contains at least one user. Unexpected errors
    are logged and suppressed — a failed bootstrap must never block API
    startup, since operators can always provision manually instead.
    """
    have_credentials = bool(
        settings.initial_admin_email and settings.initial_admin_password
    )
    if not have_credentials:
        logger.debug("Initial admin bootstrap skipped: credentials not configured")
        return
    try:
        async with async_session_factory() as session:
            await _bootstrap(session, settings)
    except Exception:  # pragma: no cover — defensive, logged
        logger.exception("Initial admin bootstrap failed")
async def _bootstrap(session: AsyncSession, settings: Settings) -> None:
    """Insert the initial organisation and owner user (idempotent)."""
    user_count = await session.scalar(select(func.count()).select_from(User))
    if user_count:
        logger.debug("Initial admin bootstrap skipped: %d user(s) already exist", user_count)
        return
    organisation = await session.scalar(
        select(Organisation).where(Organisation.slug == settings.initial_org_slug)
    )
    if organisation is None:
        organisation = Organisation(
            name=settings.initial_org_name,
            slug=settings.initial_org_slug,
            contact_email=settings.initial_admin_email,
        )
        session.add(organisation)
        # Flush so the organisation gets its primary key before the user
        # row references it below.
        await session.flush()
    owner = User(
        organisation_id=organisation.id,
        email=settings.initial_admin_email,
        password_hash=hash_password(settings.initial_admin_password),
        full_name=settings.initial_admin_full_name,
        role="owner",
    )
    session.add(owner)
    await session.commit()
    logger.warning(
        "Initial admin bootstrap created owner %s in organisation '%s'. "
        "Rotate the password via the admin UI as soon as possible.",
        settings.initial_admin_email,
        organisation.slug,
    )

View File

@@ -0,0 +1,298 @@
"""Cookie auto-categorisation engine.
Matches discovered cookies against the known_cookies database using exact name
matching, domain matching, and regex patterns. Also checks site-specific
allow-list entries. Returns a classification result with category, vendor, and
confidence level.
Matching priority (highest first):
1. Site-specific allow-list (exact or pattern match)
2. Known cookies — exact name + domain match
3. Known cookies — regex pattern match on name + domain
4. Unmatched → remains as 'pending'
"""
from __future__ import annotations
import re
import uuid
from dataclasses import dataclass
from enum import StrEnum
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.models.cookie import (
Cookie,
CookieAllowListEntry,
CookieCategory,
KnownCookie,
)
class MatchSource(StrEnum):
    """Where the classification match came from."""

    ALLOW_LIST = "allow_list"  # site-specific allow-list entry (highest priority)
    KNOWN_EXACT = "known_exact"  # known-cookies DB, exact/wildcard pattern
    KNOWN_REGEX = "known_regex"  # known-cookies DB, regex pattern
    UNMATCHED = "unmatched"  # no rule matched
@dataclass
class ClassificationResult:
    """Result of classifying a single cookie."""

    cookie_name: str
    cookie_domain: str
    category_id: uuid.UUID | None = None
    category_slug: str | None = None
    vendor: str | None = None
    description: str | None = None
    match_source: MatchSource = MatchSource.UNMATCHED  # which rule source matched
    matched: bool = False  # False → cookie stays in 'pending' review state
async def _load_allow_list(
    db: AsyncSession,
    site_id: uuid.UUID,
) -> list[CookieAllowListEntry]:
    """Fetch every allow-list entry configured for *site_id*."""
    query = select(CookieAllowListEntry).where(
        CookieAllowListEntry.site_id == site_id,
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
async def _load_known_cookies(
    db: AsyncSession,
) -> tuple[list[KnownCookie], list[KnownCookie]]:
    """Load known cookies, split into (exact, regex) pattern lists."""
    rows = await db.execute(select(KnownCookie))
    exact: list[KnownCookie] = []
    regex: list[KnownCookie] = []
    for known in rows.scalars().all():
        (regex if known.is_regex else exact).append(known)
    return exact, regex
async def _load_category_map(
    db: AsyncSession,
) -> dict[uuid.UUID, CookieCategory]:
    """Build an id → CookieCategory lookup covering all categories."""
    rows = await db.execute(select(CookieCategory))
    return {category.id: category for category in rows.scalars().all()}
def _match_pattern(pattern: str, value: str) -> bool:
"""Check if a value matches a pattern (case-insensitive).
Patterns support:
- Exact match (e.g. "_ga")
- Wildcard with * (e.g. "_ga*", "*.google.com")
- Regex if it contains regex-specific characters
"""
if not pattern or not value:
return False
pattern_lower = pattern.lower()
value_lower = value.lower()
# Simple exact match
if pattern_lower == value_lower:
return True
# Wildcard: convert * to regex .*
if "*" in pattern_lower:
regex_pattern = "^" + re.escape(pattern_lower).replace(r"\*", ".*") + "$"
return bool(re.match(regex_pattern, value_lower))
return False
def _match_regex(pattern: str, value: str) -> bool:
"""Match a value against a regex pattern (case-insensitive)."""
try:
return bool(re.match(pattern, value, re.IGNORECASE))
except re.error:
return False
def _match_allow_list(
cookie_name: str,
cookie_domain: str,
allow_list: list[CookieAllowListEntry],
) -> CookieAllowListEntry | None:
"""Check if a cookie matches any allow-list entry."""
for entry in allow_list:
name_match = _match_pattern(entry.name_pattern, cookie_name)
domain_match = _match_pattern(entry.domain_pattern, cookie_domain)
if name_match and domain_match:
return entry
return None
def _match_exact_known(
cookie_name: str,
cookie_domain: str,
exact_known: list[KnownCookie],
) -> KnownCookie | None:
"""Find an exact match in the known cookies database."""
for known in exact_known:
name_match = _match_pattern(known.name_pattern, cookie_name)
domain_match = _match_pattern(known.domain_pattern, cookie_domain)
if name_match and domain_match:
return known
return None
def _match_regex_known(
cookie_name: str,
cookie_domain: str,
regex_known: list[KnownCookie],
) -> KnownCookie | None:
"""Find a regex match in the known cookies database."""
for known in regex_known:
name_match = _match_regex(known.name_pattern, cookie_name)
domain_match = _match_regex(known.domain_pattern, cookie_domain)
if name_match and domain_match:
return known
return None
def classify_cookie(
    cookie_name: str,
    cookie_domain: str,
    allow_list: list[CookieAllowListEntry],
    exact_known: list[KnownCookie],
    regex_known: list[KnownCookie],
    category_map: dict[uuid.UUID, CookieCategory],
) -> ClassificationResult:
    """Classify one cookie against the allow-list and known-cookie DB.

    Pure function — every lookup table is passed in; no DB calls.
    Precedence: site allow-list wins over known cookies, exact known
    patterns win over regex ones, anything else comes back UNMATCHED.
    Allow-list matches carry only category and description (no vendor).
    """
    def _from_known(known: KnownCookie, source: MatchSource) -> ClassificationResult:
        # Shared constructor for the two known-cookie branches.
        category = category_map.get(known.category_id)
        return ClassificationResult(
            cookie_name=cookie_name,
            cookie_domain=cookie_domain,
            category_id=known.category_id,
            category_slug=category.slug if category else None,
            vendor=known.vendor,
            description=known.description,
            match_source=source,
            matched=True,
        )

    # 1. Site-specific allow-list overrides everything else.
    entry = _match_allow_list(cookie_name, cookie_domain, allow_list)
    if entry is not None:
        category = category_map.get(entry.category_id)
        return ClassificationResult(
            cookie_name=cookie_name,
            cookie_domain=cookie_domain,
            category_id=entry.category_id,
            category_slug=category.slug if category else None,
            description=entry.description,
            match_source=MatchSource.ALLOW_LIST,
            matched=True,
        )
    # 2. Exact known-cookie patterns.
    exact = _match_exact_known(cookie_name, cookie_domain, exact_known)
    if exact is not None:
        return _from_known(exact, MatchSource.KNOWN_EXACT)
    # 3. Regex known-cookie patterns.
    rx = _match_regex_known(cookie_name, cookie_domain, regex_known)
    if rx is not None:
        return _from_known(rx, MatchSource.KNOWN_REGEX)
    # 4. Nothing matched — defaults on ClassificationResult apply.
    return ClassificationResult(
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
    )
async def classify_site_cookies(
    db: AsyncSession,
    site_id: uuid.UUID,
    *,
    only_pending: bool = True,
) -> list[ClassificationResult]:
    """Classify every cookie recorded for a site and persist the matches.

    When *only_pending* is True, only cookies with review_status
    'pending' and no category are considered. Matched cookies get their
    category written back; vendor and description are filled only when
    currently empty. Changes are flushed but not committed — the caller
    owns the transaction.

    Returns one ClassificationResult per cookie examined.
    """
    # Lookup data shared by every classification.
    allow_list = await _load_allow_list(db, site_id)
    exact_known, regex_known = await _load_known_cookies(db)
    category_map = await _load_category_map(db)
    # Select the cookies in scope.
    stmt = select(Cookie).where(Cookie.site_id == site_id)
    if only_pending:
        stmt = stmt.where(
            Cookie.review_status == "pending",
            Cookie.category_id.is_(None),
        )
    cookies = list((await db.execute(stmt)).scalars().all())
    outcomes: list[ClassificationResult] = []
    for cookie in cookies:
        outcome = classify_cookie(
            cookie.name,
            cookie.domain,
            allow_list,
            exact_known,
            regex_known,
            category_map,
        )
        outcomes.append(outcome)
        if not (outcome.matched and outcome.category_id):
            continue
        cookie.category_id = outcome.category_id
        # Only fill vendor/description when the cookie has none yet.
        if outcome.vendor and not cookie.vendor:
            cookie.vendor = outcome.vendor
        if outcome.description and not cookie.description:
            cookie.description = outcome.description
    await db.flush()
    return outcomes
async def classify_single_cookie(
    db: AsyncSession,
    site_id: uuid.UUID,
    cookie_name: str,
    cookie_domain: str,
) -> ClassificationResult:
    """Classify one cookie on demand (e.g. for preview/testing).

    Loads the same lookup tables as classify_site_cookies but performs
    no writes.
    """
    allow_list = await _load_allow_list(db, site_id)
    exact, regex = await _load_known_cookies(db)
    categories = await _load_category_map(db)
    return classify_cookie(
        cookie_name,
        cookie_domain,
        allow_list,
        exact,
        regex,
        categories,
    )

View File

@@ -0,0 +1,482 @@
"""Pluggable compliance rule engine.
Each regulatory framework (GDPR, CNIL, CCPA, ePrivacy, LGPD) is defined as a
list of ComplianceRule objects. Rules evaluate site configuration, banner
settings, cookie data, and consent parameters to produce issues with severity,
message, and recommendation.
The engine aggregates individual rule results into per-framework reports with
a compliance score, status, and actionable issues list.
"""
from __future__ import annotations
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any
from src.schemas.compliance import (
ComplianceIssue,
Framework,
FrameworkResult,
Severity,
)
# ── Rule context ──────────────────────────────────────────────────────
@dataclass
class SiteContext:
    """All data needed to evaluate compliance rules for a site.

    A plain snapshot assembled by the caller — every rule reads these
    fields only, keeping each check a pure function of the context.
    """
    # Site config fields
    blocking_mode: str = "opt_in"
    regional_modes: dict[str, str] | None = None
    tcf_enabled: bool = False
    gcm_enabled: bool = True
    consent_expiry_days: int = 365
    privacy_policy_url: str | None = None
    # Banner config (JSONB — may have any keys)
    banner_config: dict[str, Any] | None = None
    # Cookie statistics
    total_cookies: int = 0
    uncategorised_cookies: int = 0
    cookies_without_expiry: int = 0
    # Consent settings — banner UX facts used by the GDPR/CNIL rules
    has_reject_button: bool = True
    has_granular_choices: bool = True
    has_cookie_wall: bool = False
    pre_ticked_boxes: bool = False
# ── Rule definition ───────────────────────────────────────────────────
# A check function receives a SiteContext and returns a list of issues;
# an empty list means the rule passed.
CheckFn = Callable[[SiteContext], list[ComplianceIssue]]
@dataclass
class ComplianceRule:
    """A single compliance rule with an ID, description, and check function."""
    # Stable identifier; the check functions emit issues carrying the same id.
    rule_id: str
    # Human-readable one-liner of what the rule verifies.
    description: str
    # Pure check: [] on pass, one or more ComplianceIssues on failure.
    check: CheckFn
# ── Helper factories ──────────────────────────────────────────────────
def _issue(
    rule_id: str,
    severity: Severity,
    message: str,
    recommendation: str,
) -> ComplianceIssue:
    """Build a ComplianceIssue from positional fields (keeps rule code terse)."""
    fields = {
        "rule_id": rule_id,
        "severity": severity,
        "message": message,
        "recommendation": recommendation,
    }
    return ComplianceIssue(**fields)
# ── GDPR rules ────────────────────────────────────────────────────────
def _gdpr_opt_in(ctx: SiteContext) -> list[ComplianceIssue]:
if ctx.blocking_mode != "opt_in":
return [
_issue(
"gdpr_opt_in",
Severity.CRITICAL,
"GDPR requires opt-in consent before setting non-essential cookies.",
"Set blocking mode to 'opt_in'.",
)
]
return []
def _gdpr_reject_button(ctx: SiteContext) -> list[ComplianceIssue]:
if not ctx.has_reject_button:
return [
_issue(
"gdpr_reject_button",
Severity.CRITICAL,
"The reject option must be as prominent as the accept option.",
"Add a clearly visible 'Reject all' button to the first layer.",
)
]
return []
def _gdpr_granular_consent(ctx: SiteContext) -> list[ComplianceIssue]:
if not ctx.has_granular_choices:
return [
_issue(
"gdpr_granular",
Severity.CRITICAL,
"Users must be able to consent to individual cookie categories.",
"Provide granular category toggles in the consent banner.",
)
]
return []
def _gdpr_no_cookie_wall(ctx: SiteContext) -> list[ComplianceIssue]:
if ctx.has_cookie_wall:
return [
_issue(
"gdpr_cookie_wall",
Severity.CRITICAL,
"Cookie walls (blocking access unless consent is given) are not permitted.",
"Remove the cookie wall and allow access without consent.",
)
]
return []
def _gdpr_no_pre_ticked(ctx: SiteContext) -> list[ComplianceIssue]:
if ctx.pre_ticked_boxes:
return [
_issue(
"gdpr_pre_ticked",
Severity.CRITICAL,
"Pre-ticked consent boxes do not constitute valid consent.",
"Ensure all non-essential category checkboxes default to unchecked.",
)
]
return []
def _gdpr_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
if not ctx.privacy_policy_url:
return [
_issue(
"gdpr_privacy_policy",
Severity.WARNING,
"A link to the privacy policy should be accessible from the banner.",
"Configure a privacy policy URL in the site settings.",
)
]
return []
def _gdpr_uncategorised_cookies(ctx: SiteContext) -> list[ComplianceIssue]:
if ctx.uncategorised_cookies > 0:
return [
_issue(
"gdpr_uncategorised",
Severity.WARNING,
f"{ctx.uncategorised_cookies} cookie(s) have not been categorised.",
"Review and assign a category to all discovered cookies.",
)
]
return []
# Registry of GDPR checks: consent model, banner UX (reject / granular /
# no dark patterns), transparency link, and cookie-inventory hygiene.
GDPR_RULES: list[ComplianceRule] = [
    ComplianceRule("gdpr_opt_in", "Opt-in consent required", _gdpr_opt_in),
    ComplianceRule("gdpr_reject_button", "Reject as prominent as accept", _gdpr_reject_button),
    ComplianceRule("gdpr_granular", "Granular category consent", _gdpr_granular_consent),
    ComplianceRule("gdpr_cookie_wall", "No cookie walls", _gdpr_no_cookie_wall),
    ComplianceRule("gdpr_pre_ticked", "No pre-ticked boxes", _gdpr_no_pre_ticked),
    ComplianceRule("gdpr_privacy_policy", "Privacy policy link", _gdpr_privacy_policy),
    ComplianceRule("gdpr_uncategorised", "All cookies categorised", _gdpr_uncategorised_cookies),
]
# ── CNIL rules (French — stricter GDPR) ──────────────────────────────
def _cnil_consent_expiry(ctx: SiteContext) -> list[ComplianceIssue]:
"""CNIL mandates re-consent every 6 months (≈ 182 days)."""
if ctx.consent_expiry_days > 182:
return [
_issue(
"cnil_reconsent",
Severity.CRITICAL,
"CNIL requires re-consent at least every 6 months.",
"Set consent_expiry_days to 182 or fewer.",
)
]
return []
def _cnil_cookie_lifetime(ctx: SiteContext) -> list[ComplianceIssue]:
"""CNIL limits cookie lifetime to 13 months (≈ 395 days)."""
if ctx.consent_expiry_days > 395:
return [
_issue(
"cnil_cookie_lifetime",
Severity.CRITICAL,
"CNIL limits consent cookie lifetime to 13 months.",
"Set consent_expiry_days to 395 or fewer.",
)
]
return []
def _cnil_reject_first_layer(ctx: SiteContext) -> list[ComplianceIssue]:
"""CNIL requires 'Tout refuser' on the first layer of the banner."""
if not ctx.has_reject_button:
return [
_issue(
"cnil_reject_first_layer",
Severity.CRITICAL,
"CNIL requires a 'Reject all' button on the first layer of the banner.",
"Ensure the 'Reject all' button is visible on the first banner view.",
)
]
return []
# CNIL rules include all GDPR rules plus CNIL-specific ones.
# GDPR_RULES is unpacked by value here — a snapshot taken at import
# time, not a live reference.
CNIL_RULES: list[ComplianceRule] = [
    *GDPR_RULES,
    ComplianceRule("cnil_reconsent", "Re-consent every 6 months", _cnil_consent_expiry),
    ComplianceRule("cnil_cookie_lifetime", "13-month cookie lifetime", _cnil_cookie_lifetime),
    ComplianceRule(
        "cnil_reject_first_layer",
        "Reject on first layer",
        _cnil_reject_first_layer,
    ),
]
# ── CCPA / CPRA rules ────────────────────────────────────────────────
def _ccpa_opt_out(ctx: SiteContext) -> list[ComplianceIssue]:
"""CCPA uses an opt-out model — blocking mode should be opt_out."""
if ctx.blocking_mode not in ("opt_out", "opt_in"):
return [
_issue(
"ccpa_opt_out",
Severity.CRITICAL,
"CCPA requires at minimum an opt-out mechanism for data sale.",
"Set blocking mode to 'opt_out' or 'opt_in'.",
)
]
return []
def _ccpa_do_not_sell(ctx: SiteContext) -> list[ComplianceIssue]:
"""CCPA requires a 'Do Not Sell My Personal Information' link."""
bc = ctx.banner_config or {}
has_dns = bc.get("show_do_not_sell_link", False)
if not has_dns:
return [
_issue(
"ccpa_do_not_sell",
Severity.CRITICAL,
"CCPA requires a 'Do Not Sell My Personal Information' link.",
"Enable 'show_do_not_sell_link' in the banner configuration.",
)
]
return []
def _ccpa_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
if not ctx.privacy_policy_url:
return [
_issue(
"ccpa_privacy_policy",
Severity.WARNING,
"A privacy policy is required under CCPA.",
"Configure a privacy policy URL in the site settings.",
)
]
return []
# CCPA/CPRA is opt-out based, so the GDPR rule set is NOT inherited here.
CCPA_RULES: list[ComplianceRule] = [
    ComplianceRule("ccpa_opt_out", "Opt-out mechanism", _ccpa_opt_out),
    ComplianceRule("ccpa_do_not_sell", "Do Not Sell link", _ccpa_do_not_sell),
    ComplianceRule("ccpa_privacy_policy", "Privacy policy required", _ccpa_privacy_policy),
]
# ── ePrivacy rules ───────────────────────────────────────────────────
def _eprivacy_consent(ctx: SiteContext) -> list[ComplianceIssue]:
"""ePrivacy requires consent for non-essential cookies."""
if ctx.blocking_mode == "informational":
return [
_issue(
"eprivacy_consent",
Severity.CRITICAL,
"ePrivacy Directive requires consent for non-essential cookies.",
"Set blocking mode to 'opt_in' or 'opt_out'.",
)
]
return []
def _eprivacy_necessary_exempt(ctx: SiteContext) -> list[ComplianceIssue]:
"""Strictly necessary cookies must be exempt from consent."""
# This is a configuration guidance check — ensure opt-in mode
# doesn't block necessary cookies (which the blocker handles by default).
# We report an info if everything looks good.
return []
# ePrivacy Directive checks (the "necessary exempt" rule is currently a no-op).
EPRIVACY_RULES: list[ComplianceRule] = [
    ComplianceRule("eprivacy_consent", "Consent for non-essential", _eprivacy_consent),
    ComplianceRule(
        "eprivacy_necessary_exempt",
        "Necessary cookies exempt",
        _eprivacy_necessary_exempt,
    ),
]
# ── LGPD rules (Brazil) ──────────────────────────────────────────────
def _lgpd_consent_basis(ctx: SiteContext) -> list[ComplianceIssue]:
"""LGPD requires consent or legitimate interest as legal basis."""
if ctx.blocking_mode == "informational":
return [
_issue(
"lgpd_consent_basis",
Severity.CRITICAL,
"LGPD requires a legal basis (consent or legitimate interest) for data processing.",
"Set blocking mode to 'opt_in' or 'opt_out'.",
)
]
return []
def _lgpd_data_controller(ctx: SiteContext) -> list[ComplianceIssue]:
"""LGPD requires identifying the data controller."""
if not ctx.privacy_policy_url:
return [
_issue(
"lgpd_data_controller",
Severity.WARNING,
"LGPD requires identification of the data controller.",
"Link to a privacy policy that identifies the data controller.",
)
]
return []
def _lgpd_granular(ctx: SiteContext) -> list[ComplianceIssue]:
if not ctx.has_granular_choices:
return [
_issue(
"lgpd_granular",
Severity.WARNING,
"LGPD recommends granular consent choices.",
"Provide individual category toggles in the consent banner.",
)
]
return []
# LGPD (Brazil) checks — consent basis is critical, the rest advisory.
LGPD_RULES: list[ComplianceRule] = [
    ComplianceRule("lgpd_consent_basis", "Legal basis for processing", _lgpd_consent_basis),
    ComplianceRule("lgpd_data_controller", "Identify data controller", _lgpd_data_controller),
    ComplianceRule("lgpd_granular", "Granular consent choices", _lgpd_granular),
]
# ── Framework registry ────────────────────────────────────────────────
# Maps each supported Framework to its rule list; unknown frameworks
# resolve to an empty rule set inside run_framework_check.
FRAMEWORK_RULES: dict[Framework, list[ComplianceRule]] = {
    Framework.GDPR: GDPR_RULES,
    Framework.CNIL: CNIL_RULES,
    Framework.CCPA: CCPA_RULES,
    Framework.EPRIVACY: EPRIVACY_RULES,
    Framework.LGPD: LGPD_RULES,
}
# ── Engine ────────────────────────────────────────────────────────────
def run_framework_check(
    framework: Framework,
    ctx: SiteContext,
) -> FrameworkResult:
    """Evaluate every rule registered for *framework* against *ctx*.

    Collects issues from failing rules, derives score and status, and
    wraps everything in a FrameworkResult.
    """
    rules = FRAMEWORK_RULES.get(framework, [])
    collected: list[ComplianceIssue] = []
    passed = 0
    for rule in rules:
        found = rule.check(ctx)
        if not found:
            passed += 1
            continue
        collected.extend(found)
    checked = len(rules)
    score = _calculate_score(collected, checked)
    return FrameworkResult(
        framework=framework,
        score=score,
        status=_determine_status(score, collected),
        issues=collected,
        rules_checked=checked,
        rules_passed=passed,
    )
def run_compliance_check(
    ctx: SiteContext,
    frameworks: list[Framework] | None = None,
) -> list[FrameworkResult]:
    """Run compliance checks for the requested frameworks.

    When *frameworks* is None or empty, every registered framework is checked.
    """
    selected = list(frameworks) if frameworks else list(FRAMEWORK_RULES)
    results: list[FrameworkResult] = []
    for framework in selected:
        results.append(run_framework_check(framework, ctx))
    return results
def calculate_overall_score(results: list[FrameworkResult]) -> int:
    """Calculate the overall score across framework results.

    This is an unweighted arithmetic mean — every framework contributes
    equally. (The code never applied weights, so the old "weighted
    average" wording was misleading.)

    Args:
        results: Per-framework results to aggregate.

    Returns:
        The rounded mean score, or 100 when there are no results —
        nothing checked means nothing failed.
    """
    if not results:
        return 100
    return round(sum(r.score for r in results) / len(results))
# ── Scoring helpers ───────────────────────────────────────────────────
def _calculate_score(
issues: list[ComplianceIssue],
rules_checked: int,
) -> int:
"""Score from 0-100. Critical issues deduct 20 pts, warnings 5 pts."""
if rules_checked == 0:
return 100
deductions = 0
for issue in issues:
if issue.severity == Severity.CRITICAL:
deductions += 20
elif issue.severity == Severity.WARNING:
deductions += 5
# INFO issues don't affect the score
return max(0, 100 - deductions)
def _determine_status(
score: int,
issues: list[ComplianceIssue],
) -> str:
"""Derive overall status string from score and issues."""
has_critical = any(i.severity == Severity.CRITICAL for i in issues)
if has_critical:
return "non_compliant"
if score >= 100:
return "compliant"
return "partial"

View File

@@ -0,0 +1,156 @@
"""Configuration hierarchy resolver.
Resolves site configuration by merging:
System Defaults → Org Defaults → Site Group Defaults → Site Config → Regional Overrides
Produces a fully resolved public config suitable for the banner script.
"""
from __future__ import annotations
from typing import Any
# System-level defaults (hard-coded, lowest priority).
# Every key can be overridden by the org / group / site layers via
# resolve_config; None simply means "no value configured at this layer".
SYSTEM_DEFAULTS: dict[str, Any] = {
    "blocking_mode": "opt_in",
    "tcf_enabled": False,
    "gpp_enabled": True,
    "gpp_supported_apis": ["usnat"],
    "gpc_enabled": True,
    "gpc_jurisdictions": ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"],
    "gpc_global_honour": False,
    "gcm_enabled": True,
    "shopify_privacy_enabled": False,
    # Google Consent Mode default signals: all denied except security_storage.
    "gcm_default": {
        "ad_storage": "denied",
        "ad_user_data": "denied",
        "ad_personalization": "denied",
        "analytics_storage": "denied",
        "functionality_storage": "denied",
        "personalization_storage": "denied",
        "security_storage": "granted",
    },
    "banner_config": None,
    "privacy_policy_url": None,
    "terms_url": None,
    "consent_expiry_days": 365,
}
def resolve_config(
    site_config: dict[str, Any],
    org_defaults: dict[str, Any] | None = None,
    group_defaults: dict[str, Any] | None = None,
    region: str | None = None,
) -> dict[str, Any]:
    """Resolve the full configuration by merging layers.

    Merge order (lowest → highest priority): system defaults, org
    defaults, group defaults, site config. Afterwards, a regional
    blocking-mode override from the SITE config's regional_modes map is
    applied — exact region key first, then "DEFAULT".

    Args:
        site_config: Site-specific configuration from the database.
        org_defaults: Organisation-level default overrides (optional).
        group_defaults: Site-group-level default overrides (optional).
        region: ISO region code for regional mode override (optional).

    Returns:
        Fully resolved configuration dictionary.
    """
    merged = dict(SYSTEM_DEFAULTS)
    for layer in (org_defaults, group_defaults):
        if layer:
            _merge_non_none(merged, layer)
    _merge_non_none(merged, site_config)
    if region:
        # Regional override is always read from the site layer, never
        # from inherited defaults.
        modes = site_config.get("regional_modes")
        if isinstance(modes, dict):
            override = modes.get(region) or modes.get("DEFAULT")
            if override:
                merged["blocking_mode"] = override
    return merged
def build_public_config(
    site_id: str,
    resolved: dict[str, Any],
) -> dict[str, Any]:
    """Build the public configuration JSON served to the banner script.

    Strips internal fields and adds site_id. Required keys (second
    element True) raise KeyError when absent from *resolved*; optional
    keys default to None. Key order mirrors the served payload.
    """
    spec = (
        ("blocking_mode", True),
        ("regional_modes", False),
        ("tcf_enabled", True),
        ("gpp_enabled", True),
        ("gpp_supported_apis", False),
        ("gpc_enabled", True),
        ("gpc_jurisdictions", False),
        ("gpc_global_honour", True),
        ("gcm_enabled", True),
        ("gcm_default", False),
        ("shopify_privacy_enabled", True),
        ("banner_config", False),
        ("privacy_policy_url", False),
        ("terms_url", False),
        ("consent_expiry_days", True),
        ("consent_group_id", False),
        ("ab_test", False),
    )
    public: dict[str, Any] = {"id": resolved.get("id", ""), "site_id": site_id}
    for key, required in spec:
        public[key] = resolved[key] if required else resolved.get(key)
    return public
# The config fields copied from ORM objects by orm_to_config_dict — i.e.
# the set of keys that participate in the inheritance cascade.
CONFIG_FIELDS = (
    "blocking_mode",
    "regional_modes",
    "tcf_enabled",
    "tcf_publisher_cc",
    "gpp_enabled",
    "gpp_supported_apis",
    "gpc_enabled",
    "gpc_jurisdictions",
    "gpc_global_honour",
    "gcm_enabled",
    "gcm_default",
    "shopify_privacy_enabled",
    "banner_config",
    "privacy_policy_url",
    "terms_url",
    "consent_expiry_days",
)
def orm_to_config_dict(obj: Any, *, include_id: bool = False) -> dict[str, Any]:
    """Convert a SiteConfig or OrgConfig ORM object to a dict of config fields.

    Only fields that are present AND non-NULL are included, so unset
    fields at higher-priority layers don't block inheritance from
    lower-priority layers.
    """
    config: dict[str, Any] = {}
    if include_id and hasattr(obj, "id"):
        config["id"] = str(obj.id)
    for name in CONFIG_FIELDS:
        # Missing attribute and NULL value are treated the same: skip.
        value = getattr(obj, name, None)
        if value is not None:
            config[name] = value
    return config
def _merge_non_none(target: dict[str, Any], source: dict[str, Any]) -> None:
"""Merge source into target, skipping None values in source."""
for key, value in source.items():
if value is not None:
target[key] = value

View File

@@ -0,0 +1,77 @@
"""Dynamic CORS origin validation.
Provides an origin validator that checks incoming origins against
registered site domains (primary + additional) in addition to the
statically configured allowed_origins list.
"""
from __future__ import annotations
import logging
from urllib.parse import urlparse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.models.site import Site
logger = logging.getLogger(__name__)
def extract_domain_from_origin(origin: str) -> str | None:
"""Extract the hostname from an origin URL.
e.g. 'https://example.com:443''example.com'
"""
try:
parsed = urlparse(origin)
return parsed.hostname
except Exception:
return None
async def get_allowed_domains(db: AsyncSession) -> set[str]:
    """Collect all registered domains (primary + additional) from active,
    non-deleted sites, lower-cased for case-insensitive comparison."""
    stmt = select(Site.domain, Site.additional_domains).where(
        Site.is_active.is_(True),
        Site.deleted_at.is_(None),
    )
    result = await db.execute(stmt)
    domains: set[str] = set()
    for primary, extras in result.all():
        domains.add(primary.lower())
        # additional_domains may be NULL — treat as empty.
        for extra in extras or ():
            domains.add(extra.lower())
    return domains
def is_origin_allowed(
    origin: str,
    static_origins: list[str],
    registered_domains: set[str],
) -> bool:
    """Decide whether *origin* may make cross-origin requests.

    Allowed when it appears verbatim in the static list, when the static
    list contains the '*' wildcard, or when its hostname matches a
    registered site domain (case-insensitive).

    Args:
        origin: The Origin header value (e.g. 'https://example.com').
        static_origins: Statically configured allowed origins from settings.
        registered_domains: Set of registered site domains from the database.

    Returns:
        True if the origin is allowed.
    """
    if "*" in static_origins or origin in static_origins:
        return True
    host = extract_domain_from_origin(origin)
    if host is None:
        return False
    return host.lower() in registered_domains

View File

@@ -0,0 +1,54 @@
import uuid
from collections.abc import Callable
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from jose import JWTError
from src.schemas.auth import CurrentUser
from src.services.auth import decode_token
# Shared FastAPI security scheme that parses the `Authorization: Bearer`
# header for get_current_user below.
bearer_scheme = HTTPBearer()
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> CurrentUser:
    """Resolve the authenticated user from the bearer token.

    Decodes the JWT, rejects anything that isn't a valid 'access' token
    with a 401, and maps the claims onto a CurrentUser. The 'sub' and
    'org_id' claims must be UUIDs; role defaults to 'viewer'.
    """
    try:
        claims = decode_token(credentials.credentials)
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    # Refresh (or other) tokens must not grant API access.
    if claims.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )
    return CurrentUser(
        id=uuid.UUID(claims["sub"]),
        organisation_id=uuid.UUID(claims["org_id"]),
        email=claims.get("email", ""),
        role=claims.get("role", "viewer"),
    )
def require_role(*allowed_roles: str) -> Callable:
    """Dependency factory restricting an endpoint to the given roles.

    Returns a dependency that resolves the current user and raises 403
    when their role is not in *allowed_roles*.
    """
    async def _guard(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.has_role(*allowed_roles):
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Role '{current_user.role}' is not permitted for this action",
        )
    return _guard

View File

@@ -0,0 +1,339 @@
"""GeoIP service — resolve an IP address to a country/region code.
Resolution order (see :func:`detect_region`):
1. **CDN / proxy headers.** Operators configure ``GEOIP_COUNTRY_HEADER``
(and optionally ``GEOIP_REGION_HEADER``) to match whatever their edge
uses — e.g. ``cf-ipcountry`` + ``cf-region-code`` on Cloudflare
Enterprise, or ``x-gclb-country`` + ``x-gclb-region`` on GCP. A short
built-in country list (``cf-ipcountry``, ``x-vercel-ip-country``,
``x-appengine-country``, ``x-country-code``) covers the common case
where only country-level granularity is needed.
2. **Local MaxMind GeoLite2-City database.** Set
``GEOIP_MAXMIND_DB_PATH`` to a mounted ``.mmdb`` file. Gives both
country and ISO 3166-2 subdivision without any external calls.
3. **External ip-api.com lookup** (rate-limited, no API key). Last-ditch
fallback; fine for development, not recommended for production.
4. Unresolved — the caller should fall back to the default region.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
import geoip2.database
import httpx
from fastapi import Request
from src.config.settings import get_settings
logger = logging.getLogger(__name__)
# Lazily-initialised MaxMind reader. ``geoip2.database.Reader`` opens
# the file once and then every lookup is a memory-mapped read, so we
# cache it for the lifetime of the process. ``None`` means either no
# path is configured, initialisation failed, or we haven't tried yet;
# ``_maxmind_initialised`` distinguishes "not tried" from "failed".
_maxmind_reader: geoip2.database.Reader | None = None
_maxmind_initialised = False
# Country-level headers set by common CDN / reverse-proxy providers,
# probed in order by detect_region_from_headers. Operators behind a CDN
# with a non-standard header (e.g. Google Cloud Load Balancer's
# ``x-gclb-country``) can add one via the ``GEOIP_COUNTRY_HEADER`` env
# var — see ``detect_region_from_headers``.
_GEO_HEADERS = [
    "cf-ipcountry",  # Cloudflare
    "x-vercel-ip-country",  # Vercel
    "x-appengine-country",  # Google App Engine
    "x-country-code",  # Generic / custom
]
# Mapping from two-letter country code to region codes used in regional_modes
# EU member states → "EU", US states handled separately, etc.
_EU_COUNTRIES = frozenset(
{
"AT",
"BE",
"BG",
"HR",
"CY",
"CZ",
"DK",
"EE",
"FI",
"FR",
"DE",
"GR",
"HU",
"IE",
"IT",
"LV",
"LT",
"LU",
"MT",
"NL",
"PL",
"PT",
"RO",
"SK",
"SI",
"ES",
"SE",
}
)
@dataclass(frozen=True)
class GeoResult:
    """Result of a GeoIP lookup.

    ``region`` is a regional_modes key (e.g. "EU", "US-CA", "GB");
    both fields are None when resolution failed.
    """
    # ISO 3166-1 alpha-2 country code, or None when unresolved.
    country_code: str | None
    # regional_modes key derived via country_to_region, or None.
    region: str | None
    @property
    def is_resolved(self) -> bool:
        # A country code is the minimum useful signal.
        return self.country_code is not None
def country_to_region(country_code: str, state_code: str | None = None) -> str:
    """Map a country code (+ optional subdivision) to a regional_modes key.

    EU member states collapse to "EU" regardless of subdivision — the
    bloc is treated as one unit. Any other country with a subdivision
    yields "{CC}-{SUB}" (e.g. "US-CA", "GB-SCT"); operators opt in to
    that granularity by configuring such a key in regional_modes. A
    country without a subdivision is returned as-is ("GB", "BR", …).
    """
    cc = country_code.upper()
    if cc in _EU_COUNTRIES:
        return "EU"
    if not state_code:
        return cc
    return f"{cc}-{state_code.upper()}"
def detect_region_from_headers(request: Request) -> GeoResult:
    """Detect the visitor's region from proxy/CDN headers — the fastest path.

    A custom country header configured via ``GEOIP_COUNTRY_HEADER``
    takes priority over the built-in list, so operators can plumb in
    non-standard CDN/load-balancer headers without code changes. When
    ``GEOIP_REGION_HEADER`` is also set and the custom country header
    resolved, its subdivision code is paired with the country to build
    keys like ``US-CA``. The built-in list is country-only. Header
    lookups are case-insensitive; the "XX" sentinel means unknown.
    """
    settings = get_settings()
    country_header = settings.geoip_country_header
    region_header = settings.geoip_region_header
    if country_header:
        raw = request.headers.get(country_header)
        if raw and raw.upper() != "XX":
            country = raw.upper().strip()
            state: str | None = None
            if region_header:
                raw_region = request.headers.get(region_header)
                if raw_region and raw_region.upper() != "XX":
                    # Subdivision codes may arrive prefixed ("US-CA") or
                    # bare ("CA") — keep only the subdivision part.
                    normalised = raw_region.strip().upper()
                    state = normalised.split("-", 1)[-1] if "-" in normalised else normalised
            return GeoResult(
                country_code=country,
                region=country_to_region(country, state),
            )
    for candidate in _GEO_HEADERS:
        raw = request.headers.get(candidate)
        if not raw or raw.upper() == "XX":
            continue
        country = raw.upper().strip()
        return GeoResult(
            country_code=country,
            region=country_to_region(country),
        )
    return GeoResult(country_code=None, region=None)
def get_client_ip(request: Request) -> str | None:
    """Extract the client IP from the request.

    Prefers X-Forwarded-For (first hop of "client, proxy1, proxy2"),
    then X-Real-IP, then the direct connection address.
    NOTE(review): both headers are client-controlled unless the edge
    strips them — confirm the deployment's proxy does so.
    """
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        return forwarded.split(",", 1)[0].strip()
    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        return real_ip.strip()
    return request.client.host if request.client else None
async def lookup_ip_region(ip: str) -> GeoResult:
    """Resolve an IP's region via the external ip-api.com service.

    Free tier, no key required, 45 req/min — fine for development; the
    module docstring recommends a local MaxMind database in production.
    Private IPs and any failure (HTTP error, timeout, bad payload) yield
    an unresolved GeoResult.
    """
    unresolved = GeoResult(country_code=None, region=None)
    if _is_private_ip(ip):
        return unresolved
    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(
                f"http://ip-api.com/json/{ip}",
                params={"fields": "status,countryCode,region"},
            )
            if resp.status_code != 200:
                return unresolved
            payload = resp.json()
            if payload.get("status") != "success":
                return unresolved
            country = payload.get("countryCode")
            if not country:
                return unresolved
            # "region" carries the state/province subdivision code.
            state = payload.get("region")
            return GeoResult(
                country_code=country,
                region=country_to_region(country, state),
            )
    except Exception:
        logger.debug("GeoIP lookup failed for %s", ip, exc_info=True)
        return unresolved
def _get_maxmind_reader() -> geoip2.database.Reader | None:
    """Return the cached MaxMind reader, opening the DB on first use.

    Caches both successful opens and failures (via
    ``_maxmind_initialised``) so we don't retry a bad path on every
    request. Returns ``None`` if no path is configured or the DB
    couldn't be opened.
    """
    global _maxmind_reader, _maxmind_initialised
    if not _maxmind_initialised:
        # One-shot initialisation — success or failure, we never come
        # back through this branch.
        _maxmind_initialised = True
        db_path = get_settings().geoip_maxmind_db_path
        if db_path:
            try:
                _maxmind_reader = geoip2.database.Reader(db_path)
                logger.info("GeoIP: opened MaxMind database at %s", db_path)
            except Exception:
                logger.warning(
                    "GeoIP: failed to open MaxMind database at %s — falling back to "
                    "external lookups. Check GEOIP_MAXMIND_DB_PATH and that the file "
                    "is readable inside the container.",
                    db_path,
                    exc_info=True,
                )
                _maxmind_reader = None
    return _maxmind_reader
def lookup_ip_maxmind(ip: str) -> GeoResult:
    """Resolve an IP via the local MaxMind database.

    Memory-mapped read, no network I/O — cheap enough to call
    synchronously from the async path. Returns an unresolved
    ``GeoResult`` when the DB isn't configured, the IP is private, or
    the record can't be found.
    """
    miss = GeoResult(country_code=None, region=None)
    if _is_private_ip(ip):
        return miss
    reader = _get_maxmind_reader()
    if reader is None:
        return miss
    try:
        record = reader.city(ip)
    except Exception:
        # AddressNotFoundError and friends — treat as unresolved.
        logger.debug("MaxMind lookup failed for %s", ip, exc_info=True)
        return miss
    country = record.country.iso_code
    if not country:
        return miss
    # ``subdivisions`` is ordered most-specific first; ``iso_code`` is
    # the ISO 3166-2 subdivision without the country prefix.
    subdivision = record.subdivisions.most_specific.iso_code if record.subdivisions else None
    return GeoResult(
        country_code=country.upper(),
        region=country_to_region(country, subdivision),
    )
async def detect_region(request: Request) -> GeoResult:
    """Detect the visitor's region.

    Resolution order:
    1. CDN/proxy headers (see :func:`detect_region_from_headers`).
    2. Local MaxMind database, if ``GEOIP_MAXMIND_DB_PATH`` is set.
    3. External ``ip-api.com`` lookup — last-ditch fallback.

    Returns an unresolved :class:`GeoResult` if every tier fails.
    """
    from_headers = detect_region_from_headers(request)
    if from_headers.is_resolved:
        return from_headers
    ip = get_client_ip(request)
    if not ip:
        # No client IP at all — nothing left to look up.
        return GeoResult(country_code=None, region=None)
    if get_settings().geoip_maxmind_db_path:
        local = lookup_ip_maxmind(ip)
        if local.is_resolved:
            return local
    return await lookup_ip_region(ip)
def _is_private_ip(ip: str) -> bool:
"""Check if an IP address is a private/loopback address."""
return (
ip.startswith("127.")
or ip.startswith("10.")
or ip.startswith("192.168.")
or ip.startswith("172.16.")
or ip.startswith("172.17.")
or ip.startswith("172.18.")
or ip.startswith("172.19.")
or ip.startswith("172.2")
or ip.startswith("172.3")
or ip == "::1"
or ip == "localhost"
)

View File

@@ -0,0 +1,41 @@
"""Pseudonymisation helpers for consent records.
Consent records capture a hash of the visitor's IP address and
user-agent string for abuse protection and audit trail purposes.
Previously this used an unsalted truncated SHA-256, which is trivially
reversible for IPv4 addresses (only ~4 billion inputs). We now use
HMAC-SHA256 keyed with a server-side secret so the hash cannot be
recovered without access to the secret.
Public API: :func:`pseudonymise`.
"""
from __future__ import annotations
import hmac
from hashlib import sha256
from src.config.settings import get_settings
# Length of the hex-encoded digest stored in the database. 32 hex chars
# = 128 bits, which is more than enough entropy while keeping the
# column compact. (Previous code used 16 hex chars = 64 bits.)
_DIGEST_HEX_LEN = 32
def pseudonymise(value: str) -> str:
    """Return a keyed hash of *value* safe to store in an audit record.

    Uses HMAC-SHA256 with the configured ``pseudonymisation_secret``
    (falling back to ``jwt_secret_key`` if not explicitly set). The
    resulting hex digest is truncated to 32 characters (128 bits).

    An empty input always returns an empty string so callers don't
    have to branch on missing data.
    """
    if not value:
        return ""
    mac = hmac.new(
        get_settings().pseudonymisation_key,
        msg=value.encode("utf-8"),
        digestmod=sha256,
    )
    return mac.hexdigest()[:_DIGEST_HEX_LEN]

View File

@@ -0,0 +1,89 @@
"""CDN publishing pipeline.
Publishes resolved site configurations as static JSON files for the
banner script to fetch. Supports local filesystem (development) and
can be extended for S3/GCS/CloudFront.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
from src.config.settings import get_settings
from .config_resolver import build_public_config, resolve_config
logger = logging.getLogger(__name__)
class PublishResult:
"""Result of a publish operation."""
def __init__(self, success: bool, path: str, error: str | None = None) -> None:
self.success = success
self.path = path
self.error = error
self.published_at = datetime.now(UTC).isoformat() if success else None
async def publish_site_config(
    site_id: str,
    site_config: dict[str, Any],
    org_defaults: dict[str, Any] | None = None,
) -> PublishResult:
    """Resolve and publish a site configuration to CDN.

    Args:
        site_id: The site UUID as a string.
        site_config: Raw site configuration from the database.
        org_defaults: Organisation-level defaults (optional).

    Returns:
        PublishResult with success status and path.
    """
    try:
        # Cascade the raw config, then strip it down to the
        # public-facing shape the banner script consumes.
        merged = resolve_config(site_config, org_defaults)
        public = build_public_config(site_id, merged)
        settings = get_settings()
        written_to = await _publish_local(site_id, public, settings.cdn_base_url)
        logger.info("Published config for site %s to %s", site_id, written_to)
        return PublishResult(success=True, path=written_to)
    except Exception as exc:
        # Publishing is best-effort from the caller's perspective:
        # report the failure in the result rather than raising.
        logger.exception("Failed to publish config for site %s", site_id)
        return PublishResult(success=False, path="", error=str(exc))
async def _publish_local(
site_id: str,
config: dict[str, Any],
cdn_base: str,
) -> str:
"""Publish config to local filesystem (for development/Docker Compose).
Writes to the CDN proxy's HTML directory so nginx can serve it.
"""
# Default local publish directory
publish_dir = Path("/app/cdn-publish") if Path("/app").exists() else Path("cdn-publish")
publish_dir.mkdir(parents=True, exist_ok=True)
# Write the config JSON
config_path = publish_dir / f"site-config-{site_id}.json"
config_path.write_text(json.dumps(config, indent=2, default=str))
# Also write a versioned copy for cache-busting
version = datetime.now(UTC).strftime("%Y%m%d%H%M%S")
versioned_path = publish_dir / f"site-config-{site_id}-{version}.json"
versioned_path.write_text(json.dumps(config, indent=2, default=str))
return str(config_path)

View File

@@ -0,0 +1,322 @@
"""Scan orchestration and diff engine.
Provides scan job lifecycle management, result diffing between scans,
and cookie inventory synchronisation from scan results.
"""
from __future__ import annotations
import uuid
from datetime import UTC, datetime
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from src.models.cookie import Cookie
from src.models.scan import ScanJob, ScanResult
from src.models.site import Site
from src.schemas.scanner import (
CookieDiffItem,
DiffStatus,
ScanDiffResponse,
)
async def create_scan_job(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    trigger: str = "manual",
    max_pages: int = 50,
) -> ScanJob:
    """Create a new scan job in 'pending' state.

    ``pages_total`` records the crawl budget; ``pages_scanned`` is
    filled in when the job completes.
    """
    new_job = ScanJob(
        site_id=site_id,
        status="pending",
        trigger=trigger,
        pages_total=max_pages,
    )
    db.add(new_job)
    # Flush so the row (and its generated fields) exists within the
    # caller's transaction before we hand the object back.
    await db.flush()
    return new_job
async def start_scan_job(db: AsyncSession, job: ScanJob) -> ScanJob:
"""Transition a scan job to 'running'.
Idempotent: if the job is already running (e.g. Celery re-delivered the
task after a worker crash), this is a no-op. Also handles re-delivery
after a transient failure that left the job in 'failed' state mid-retry.
"""
if job.status == "running":
return job
job.status = "running"
job.started_at = datetime.now(UTC)
# Reset any previous error so the retry starts clean
job.error_message = None
await db.flush()
return job
async def complete_scan_job(
db: AsyncSession,
job: ScanJob,
*,
pages_scanned: int = 0,
cookies_found: int = 0,
error_message: str | None = None,
) -> ScanJob:
"""Mark a scan job as completed or failed."""
job.status = "failed" if error_message else "completed"
job.completed_at = datetime.now(UTC)
job.pages_scanned = pages_scanned
job.cookies_found = cookies_found
job.error_message = error_message
await db.flush()
return job
async def add_scan_result(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    page_url: str,
    cookie_name: str,
    cookie_domain: str,
    storage_type: str = "cookie",
    attributes: dict | None = None,
    script_source: str | None = None,
    auto_category: str | None = None,
    initiator_chain: list[str] | None = None,
) -> ScanResult:
    """Record a single cookie discovery from a scan.

    One row per (page, cookie) observation; diffing and inventory sync
    aggregate these later.
    """
    row = ScanResult(
        scan_job_id=scan_job_id,
        page_url=page_url,
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
        storage_type=storage_type,
        attributes=attributes,
        script_source=script_source,
        auto_category=auto_category,
        initiator_chain=initiator_chain,
    )
    db.add(row)
    await db.flush()
    return row
async def get_previous_completed_scan(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    before_scan_id: uuid.UUID,
) -> ScanJob | None:
    """Find the most recent completed scan before the given one."""
    # Anchor on the reference scan's creation time; if the reference
    # doesn't exist there is nothing to compare against.
    anchor = (
        await db.execute(select(ScanJob.created_at).where(ScanJob.id == before_scan_id))
    ).scalar_one_or_none()
    if anchor is None:
        return None
    query = (
        select(ScanJob)
        .where(
            ScanJob.site_id == site_id,
            ScanJob.status == "completed",
            ScanJob.id != before_scan_id,
            ScanJob.created_at < anchor,
        )
        .order_by(ScanJob.created_at.desc())
        .limit(1)
    )
    return (await db.execute(query)).scalar_one_or_none()
def _result_key(r: ScanResult) -> tuple[str, str, str]:
"""Unique key for a scan result (cookie identity)."""
return (r.cookie_name, r.cookie_domain, r.storage_type)
async def compute_scan_diff(
    db: AsyncSession,
    *,
    current_scan_id: uuid.UUID,
    site_id: uuid.UUID,
) -> ScanDiffResponse:
    """Compute the diff between the current scan and the previous one.

    Returns new, removed, and changed cookies. If no previous scan
    exists, all cookies in the current scan are marked as 'new'.
    """
    previous = await get_previous_completed_scan(
        db, site_id=site_id, before_scan_id=current_scan_id
    )
    current_rows = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == current_scan_id)
    )
    # Keep the raw row list (duplicates included) alongside the keyed
    # view used for membership tests.
    current_list = list(current_rows.scalars().all())
    current_by_key = {_result_key(r): r for r in current_list}

    if previous is None:
        # First scan for this site: report every observation as new.
        first_scan_items = [
            CookieDiffItem(
                name=r.cookie_name,
                domain=r.cookie_domain,
                storage_type=r.storage_type,
                diff_status=DiffStatus.NEW,
                details="First scan — no previous data",
            )
            for r in current_list
        ]
        return ScanDiffResponse(
            current_scan_id=current_scan_id,
            previous_scan_id=None,
            new_cookies=first_scan_items,
            total_new=len(first_scan_items),
        )

    prev_rows = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == previous.id)
    )
    prev_by_key = {_result_key(r): r for r in prev_rows.scalars().all()}

    # Present now but not before.
    added = [
        CookieDiffItem(
            name=r.cookie_name,
            domain=r.cookie_domain,
            storage_type=r.storage_type,
            diff_status=DiffStatus.NEW,
        )
        for key, r in current_by_key.items()
        if key not in prev_by_key
    ]
    # Present before but gone now.
    removed = [
        CookieDiffItem(
            name=r.cookie_name,
            domain=r.cookie_domain,
            storage_type=r.storage_type,
            diff_status=DiffStatus.REMOVED,
        )
        for key, r in prev_by_key.items()
        if key not in current_by_key
    ]
    # Present in both, but with differing metadata.
    changed: list[CookieDiffItem] = []
    for key, curr in current_by_key.items():
        prev = prev_by_key.get(key)
        if prev is None:
            continue
        deltas: list[str] = []
        if curr.script_source != prev.script_source:
            deltas.append("script_source changed")
        if curr.auto_category != prev.auto_category:
            deltas.append("auto_category changed")
        # Compare cookie attributes (e.g. secure, httpOnly); None and
        # {} are equivalent.
        if (curr.attributes or {}) != (prev.attributes or {}):
            deltas.append("attributes changed")
        if deltas:
            changed.append(
                CookieDiffItem(
                    name=curr.cookie_name,
                    domain=curr.cookie_domain,
                    storage_type=curr.storage_type,
                    diff_status=DiffStatus.CHANGED,
                    details="; ".join(deltas),
                )
            )
    return ScanDiffResponse(
        current_scan_id=current_scan_id,
        previous_scan_id=previous.id,
        new_cookies=added,
        removed_cookies=removed,
        changed_cookies=changed,
        total_new=len(added),
        total_removed=len(removed),
        total_changed=len(changed),
    )
async def sync_scan_results_to_cookies(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    site_id: uuid.UUID,
) -> int:
    """Upsert scan results into the site's cookie inventory.

    Creates new Cookie records for newly discovered items or updates
    ``last_seen_at`` for existing ones. Returns the number of new
    cookies.

    The previous implementation issued one SELECT per scan result
    (N+1 queries); this version loads the site's inventory once and
    matches in memory on (name, domain, storage_type).
    """
    results = await db.execute(select(ScanResult).where(ScanResult.scan_job_id == scan_job_id))
    items = list(results.scalars().all())
    # Load the whole inventory for the site in a single query.
    inventory_rows = await db.execute(select(Cookie).where(Cookie.site_id == site_id))
    inventory: dict[tuple[str, str, str], Cookie] = {
        (c.name, c.domain, c.storage_type): c for c in inventory_rows.scalars().all()
    }
    now_iso = datetime.now(UTC).isoformat()
    new_count = 0
    for item in items:
        key = (item.cookie_name, item.cookie_domain, item.storage_type)
        cookie = inventory.get(key)
        if cookie:
            cookie.last_seen_at = now_iso
        else:
            cookie = Cookie(
                site_id=site_id,
                name=item.cookie_name,
                domain=item.cookie_domain,
                storage_type=item.storage_type,
                review_status="pending",
                first_seen_at=now_iso,
                last_seen_at=now_iso,
            )
            db.add(cookie)
            # Track it so duplicate results within this scan update the
            # freshly created row instead of inserting twice.
            inventory[key] = cookie
            new_count += 1
    await db.flush()
    return new_count
async def get_sites_due_for_scan(db: AsyncSession) -> list[Site]:
    """Return active, non-deleted sites with a scan schedule configured.

    NOTE(review): despite the name, this does NOT check whether a scan
    is actually *due* — it returns every site whose SiteConfig has a
    non-null ``scan_schedule_cron``, with no comparison against the
    last scan's ``completed_at`` or the cron interval. Callers that run
    periodically will therefore trigger a scan for every scheduled site
    on every invocation. TODO: evaluate the cron expression against the
    most recent completed scan before returning a site.
    """
    # Function-level import — presumably to avoid a circular import at
    # module load time; confirm before hoisting to the top of the file.
    from src.models.site_config import SiteConfig
    # Live sites joined to their config row, filtered to those with a
    # schedule set.
    result = await db.execute(
        select(Site)
        .join(SiteConfig, SiteConfig.site_id == Site.id)
        .where(
            Site.deleted_at.is_(None),
            Site.is_active.is_(True),
            SiteConfig.scan_schedule_cron.isnot(None),
        )
    )
    return list(result.scalars().all())

View File

View File

@@ -0,0 +1,87 @@
"""Consent record retention purge.
Deletes consent records older than each site's configured
``consent_retention_days``. Sites with no retention configured are
skipped — operators must explicitly opt in per site (or set it at the
org/system level and let the cascade resolve it).
Scheduled by ``celery beat`` daily at 01:00 UTC via the entry in
``src.celery_app.beat_schedule``.
"""
from __future__ import annotations
import asyncio
import logging
from datetime import UTC, datetime, timedelta
from src.celery_app import app
logger = logging.getLogger(__name__)
async def _purge() -> dict[str, int]:
    """Delete expired consent records across all sites with retention set.

    Opens a dedicated engine/session for the task, deletes every
    ``ConsentRecord`` older than each site's ``consent_retention_days``,
    and commits once at the end.

    Fix: ``engine.dispose()`` now runs in a ``finally`` block — the
    previous version leaked the engine's connection pool whenever the
    purge raised partway through (matches the try/finally pattern used
    by the scanner tasks).

    Returns a summary ``{"sites_processed": N, "records_deleted": M}``.
    """
    from sqlalchemy import delete, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from src.config.settings import get_settings
    from src.models.consent import ConsentRecord
    from src.models.site_config import SiteConfig

    settings = get_settings()
    engine = create_async_engine(settings.database_url, echo=False)
    sites_processed = 0
    records_deleted = 0
    try:
        async with AsyncSession(engine, expire_on_commit=False) as session:
            configs = (
                (
                    await session.execute(
                        select(SiteConfig).where(SiteConfig.consent_retention_days.isnot(None)),
                    )
                )
                .scalars()
                .all()
            )
            now = datetime.now(UTC)
            for cfg in configs:
                retention_days = cfg.consent_retention_days
                # Guard against zero/negative values slipping past the
                # not-null filter above.
                if not retention_days or retention_days <= 0:
                    continue
                cutoff = now - timedelta(days=retention_days)
                result = await session.execute(
                    delete(ConsentRecord).where(
                        ConsentRecord.site_id == cfg.site_id,
                        ConsentRecord.consented_at < cutoff,
                    ),
                )
                deleted = result.rowcount or 0
                records_deleted += deleted
                sites_processed += 1
                if deleted:
                    logger.info(
                        "retention.purged",
                        extra={
                            "site_id": str(cfg.site_id),
                            "retention_days": retention_days,
                            "deleted": deleted,
                            "cutoff": cutoff.isoformat(),
                        },
                    )
            await session.commit()
    finally:
        # Always release the pool, even if the purge raised partway.
        await engine.dispose()
    return {"sites_processed": sites_processed, "records_deleted": records_deleted}
@app.task(name="src.tasks.retention.purge_expired_consent_records")
def purge_expired_consent_records() -> dict[str, int]:
    """Celery entrypoint for the retention purge.

    Synchronous Celery task that drives the async :func:`_purge`
    coroutine to completion on a fresh event loop and returns its
    summary dict (``sites_processed`` / ``records_deleted``).
    """
    return asyncio.run(_purge())

View File

@@ -0,0 +1,308 @@
"""Celery tasks for scan job execution and scheduling.
The run_scan task calls the scanner HTTP service to execute a Playwright
crawl, then processes the results: stores scan results, runs auto-
classification, syncs discovered cookies to the site inventory, and
computes diffs against the previous scan.
"""
from __future__ import annotations
import logging
import uuid
import httpx
from src.celery_app import app
logger = logging.getLogger(__name__)
@app.task(name="src.tasks.scanner.run_scan", bind=True, max_retries=2)
def run_scan(self, scan_job_id: str, site_id: str) -> dict:
    """Execute a scan job by calling the scanner service.

    1. Transition job to 'running'
    2. Look up site domain
    3. Call scanner HTTP service with the domain
    4. Store scan results and run auto-classification
    5. Sync discovered cookies to the site inventory
    6. Mark job as completed

    HTTP errors from the scanner service are retried up to
    ``max_retries`` times with a 30s delay; the job is marked failed
    only on the final attempt. Any other exception marks the job failed
    immediately and returns an ``{"error": ...}`` dict instead of
    raising.
    """
    import asyncio
    # Imports are function-local — presumably to keep Celery worker
    # startup light and avoid import cycles; confirm before hoisting.
    from sqlalchemy import select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from src.config.settings import get_settings
    from src.models.scan import ScanJob
    from src.models.site import Site
    from src.services.classification import classify_single_cookie
    from src.services.scanner import (
        add_scan_result,
        complete_scan_job,
        start_scan_job,
        sync_scan_results_to_cookies,
    )
    settings = get_settings()
    job_uuid = uuid.UUID(scan_job_id)
    site_uuid = uuid.UUID(site_id)
    async def _execute() -> dict:
        # Dedicated engine per task invocation; disposed in `finally`.
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                # Load the job
                result = await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
                job = result.scalar_one_or_none()
                if job is None:
                    return {"error": "Scan job not found"}
                # Load the site to get the domain
                site_result = await db.execute(select(Site).where(Site.id == site_uuid))
                site = site_result.scalar_one_or_none()
                if site is None:
                    return {"error": "Site not found"}
                # Transition to running and commit immediately so the
                # status is visible to other sessions during the
                # (potentially long) crawl.
                await start_scan_job(db, job)
                await db.commit()
                # Call the scanner service (blocking for the whole crawl;
                # timeout comes from settings).
                scanner_url = f"{settings.scanner_service_url}/scan"
                max_pages = job.pages_total or 50
                async with httpx.AsyncClient(
                    timeout=httpx.Timeout(settings.scanner_timeout_seconds)
                ) as client:
                    resp = await client.post(
                        scanner_url,
                        json={
                            "domain": site.domain,
                            "max_pages": max_pages,
                        },
                    )
                    resp.raise_for_status()
                    scan_data = resp.json()
                # Store scan results
                cookies = scan_data.get("cookies", [])
                pages_crawled = scan_data.get("pages_crawled", 0)
                for cookie in cookies:
                    # Auto-classify the cookie before persisting so the
                    # result row carries its category slug.
                    category = await classify_single_cookie(
                        db,
                        site_id=site_uuid,
                        cookie_name=cookie["name"],
                        cookie_domain=cookie["domain"],
                    )
                    await add_scan_result(
                        db,
                        scan_job_id=job_uuid,
                        page_url=cookie.get("page_url", ""),
                        cookie_name=cookie["name"],
                        cookie_domain=cookie["domain"],
                        storage_type=cookie.get("storage_type", "cookie"),
                        attributes={
                            "path": cookie.get("path"),
                            "http_only": cookie.get("http_only"),
                            "secure": cookie.get("secure"),
                            "same_site": cookie.get("same_site"),
                            "value_length": cookie.get("value_length", 0),
                        },
                        script_source=cookie.get("script_source"),
                        auto_category=category.category_slug if category else None,
                        initiator_chain=cookie.get("initiator_chain") or None,
                    )
                await db.commit()
                # Mark job as completed
                await complete_scan_job(
                    db,
                    job,
                    pages_scanned=pages_crawled,
                    cookies_found=len(cookies),
                )
                await db.commit()
                # Sync results to cookie inventory (separate commit so a
                # sync failure doesn't undo the completed status).
                new_cookies = await sync_scan_results_to_cookies(
                    db,
                    scan_job_id=job_uuid,
                    site_id=site_uuid,
                )
                await db.commit()
                logger.info(
                    "Scan %s completed: %d pages, %d cookies, %d new",
                    scan_job_id,
                    pages_crawled,
                    len(cookies),
                    new_cookies,
                )
                return {
                    "scan_job_id": scan_job_id,
                    "status": "completed",
                    "pages_scanned": pages_crawled,
                    "cookies_found": len(cookies),
                    "new_cookies_synced": new_cookies,
                }
            except httpx.HTTPError as exc:
                logger.error("Scanner service error for job %s: %s", scan_job_id, exc)
                await db.rollback()
                # Only mark failed on the final retry; otherwise let the
                # retry set status back to "running" cleanly.
                if self.request.retries >= self.max_retries:
                    await _mark_failed(db, job_uuid, f"Scanner service error: {exc}")
                # self.retry raises (Retry, or MaxRetriesExceededError on
                # the last attempt), so nothing runs after this line.
                raise self.retry(exc=exc, countdown=30) from exc
            except Exception as exc:
                logger.exception("Scan task failed for job %s", scan_job_id)
                await db.rollback()
                await _mark_failed(db, job_uuid, str(exc))
                return {"error": str(exc)}
            finally:
                # Release the pool whether the scan succeeded or not.
                await engine.dispose()
    return asyncio.run(_execute())
async def _mark_failed(db, job_uuid: uuid.UUID, message: str) -> None:
    """Mark a scan job as failed (best-effort)."""
    from sqlalchemy import select
    from src.models.scan import ScanJob
    from src.services.scanner import complete_scan_job
    try:
        row = await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
        stale_job = row.scalar_one_or_none()
        if stale_job is not None:
            await complete_scan_job(db, stale_job, error_message=message)
            await db.commit()
    except Exception:
        # Failing to record the failure must not mask the original
        # error in the caller — log and swallow.
        logger.exception("Failed to mark scan job %s as failed", job_uuid)
@app.task(name="src.tasks.scanner.check_scheduled_scans")
def check_scheduled_scans() -> dict:
    """Periodic task: check which sites are due for a scheduled scan.

    Runs every 15 minutes via Celery Beat. For each site with a
    scan_schedule_cron, checks if a scan is overdue and triggers one.
    """
    import asyncio
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from src.config.settings import get_settings
    from src.services.scanner import create_scan_job, get_sites_due_for_scan

    settings = get_settings()

    async def _check() -> dict:
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                due_sites = await get_sites_due_for_scan(db)
                dispatched = 0
                for due_site in due_sites:
                    new_job = await create_scan_job(db, site_id=due_site.id, trigger="scheduled")
                    # Commit the job row before dispatching so the
                    # worker can see it immediately.
                    await db.commit()
                    run_scan.delay(str(new_job.id), str(due_site.id))
                    dispatched += 1
                return {"sites_checked": len(due_sites), "scans_triggered": dispatched}
            except Exception:
                await db.rollback()
                raise
            finally:
                await engine.dispose()

    return asyncio.run(_check())
@app.task(name="src.tasks.scanner.recover_stale_scans")
def recover_stale_scans() -> dict:
    """Periodic task: detect and recover scan jobs stuck in pending/running.

    - Jobs stuck in 'pending' for >5 minutes are re-dispatched to Celery.
    - Jobs stuck in 'running' for >10 minutes are marked as failed.
    """
    import asyncio
    from datetime import UTC, datetime, timedelta
    from sqlalchemy import or_, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
    from src.config.settings import get_settings
    from src.models.scan import ScanJob
    from src.services.scanner import complete_scan_job

    settings = get_settings()

    async def _recover() -> dict:
        engine = create_async_engine(settings.database_url, echo=False)
        async with AsyncSession(engine, expire_on_commit=False) as db:
            try:
                now = datetime.now(UTC)
                pending_cutoff = now - timedelta(minutes=5)
                running_cutoff = now - timedelta(minutes=10)
                # Pending too long: likely never picked up.
                # Running too long: the worker likely died.
                stale_filter = or_(
                    (ScanJob.status == "pending") & (ScanJob.created_at < pending_cutoff),
                    (ScanJob.status == "running") & (ScanJob.started_at < running_cutoff),
                )
                rows = await db.execute(select(ScanJob).where(stale_filter))
                stale = list(rows.scalars().all())
                requeued = 0
                timed_out = 0
                for stale_job in stale:
                    if stale_job.status == "pending":
                        logger.warning("Re-dispatching stale pending scan job %s", stale_job.id)
                        run_scan.delay(str(stale_job.id), str(stale_job.site_id))
                        requeued += 1
                    elif stale_job.status == "running":
                        logger.warning("Failing stale running scan job %s", stale_job.id)
                        await complete_scan_job(
                            db,
                            stale_job,
                            error_message=(
                                "Job timed out (running too long, worker may have crashed)"
                            ),
                        )
                        timed_out += 1
                await db.commit()
                return {
                    "stale_jobs_found": len(stale),
                    "redispatched": requeued,
                    "failed": timed_out,
                }
            except Exception:
                await db.rollback()
                raise
            finally:
                await engine.dispose()

    return asyncio.run(_recover())