feat: initial public release
ConsentOS — a privacy-first cookie consent management platform. Self-hosted, source-available alternative to OneTrust, Cookiebot, and CookieYes. Full standards coverage (IAB TCF v2.2, GPP v1, Google Consent Mode v2, GPC, Shopify Customer Privacy API), multi-tenant architecture with role-based access, configuration cascade (system → org → group → site → region), dark-pattern detection in the scanner, and a tamper-evident consent record audit trail. This is the initial public release. Prior development history is retained internally. See README.md for the feature list, architecture overview, and quick-start instructions. Licensed under the Elastic Licence 2.0 — self-host freely; do not resell as a managed service.
This commit is contained in:
0
apps/api/src/__init__.py
Normal file
0
apps/api/src/__init__.py
Normal file
89
apps/api/src/celery_app.py
Normal file
89
apps/api/src/celery_app.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Celery application and task definitions for the CMP API.
|
||||
|
||||
Provides async-compatible scan scheduling via Celery with Redis as the
|
||||
broker and result backend.
|
||||
"""
|
||||
|
||||
import ssl
|
||||
|
||||
from celery import Celery
|
||||
from celery.schedules import crontab
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
# Named `app` by Celery convention — the CLI finds it via -A src.celery_app
|
||||
app = Celery(
|
||||
"cmp",
|
||||
broker=settings.redis_url,
|
||||
backend=settings.redis_url,
|
||||
)
|
||||
|
||||
# When using rediss:// (TLS) — e.g. Upstash — Celery requires explicit
|
||||
# SSL certificate verification settings for both broker and backend.
|
||||
_conf: dict = {
|
||||
"task_serializer": "json",
|
||||
"accept_content": ["json"],
|
||||
"result_serializer": "json",
|
||||
"timezone": "UTC",
|
||||
"enable_utc": True,
|
||||
"task_track_started": True,
|
||||
"task_acks_late": True,
|
||||
"worker_prefetch_multiplier": 1,
|
||||
}
|
||||
|
||||
if settings.redis_url.startswith("rediss://"):
|
||||
_conf["broker_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_NONE}
|
||||
_conf["redis_backend_use_ssl"] = {"ssl_cert_reqs": ssl.CERT_NONE}
|
||||
|
||||
app.conf.update(**_conf)
|
||||
|
||||
|
||||
# ── Beat schedule (periodic tasks) ──────────────────────────────────
|
||||
|
||||
app.conf.beat_schedule = {
|
||||
"check-scheduled-scans": {
|
||||
"task": "src.tasks.scanner.check_scheduled_scans",
|
||||
"schedule": crontab(minute="*/15"), # Every 15 minutes
|
||||
},
|
||||
"recover-stale-scans": {
|
||||
"task": "src.tasks.scanner.recover_stale_scans",
|
||||
"schedule": crontab(minute="*/5"), # Every 5 minutes
|
||||
},
|
||||
"purge-expired-consent-records": {
|
||||
"task": "src.tasks.retention.purge_expired_consent_records",
|
||||
"schedule": crontab(hour="1", minute="0"), # Daily at 01:00 UTC
|
||||
},
|
||||
}
|
||||
|
||||
# ── Explicit task imports ───────────────────────────────────────────
|
||||
# Must be at the bottom to avoid circular imports. These ensure the
|
||||
# worker process registers all @app.task definitions on startup.
|
||||
import src.tasks.retention # noqa: E402
|
||||
import src.tasks.scanner # noqa: E402, F401
|
||||
|
||||
# EE tasks are registered conditionally — they only exist in EE mode.
|
||||
try:
|
||||
import ee.api.src.tasks.compliance_scanner
|
||||
import ee.api.src.tasks.compliance_scoring
|
||||
import ee.api.src.tasks.retention # noqa: F401
|
||||
|
||||
app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-scheduled-compliance-scans": {
|
||||
"task": "src.tasks.compliance_scanner.check_scheduled_compliance_scans",
|
||||
"schedule": crontab(hour="3", minute="0"),
|
||||
},
|
||||
"compute-daily-compliance-scores": {
|
||||
"task": "src.tasks.compliance_scoring.compute_daily_scores",
|
||||
"schedule": crontab(hour="4", minute="0"),
|
||||
},
|
||||
"run-retention-purge": {
|
||||
"task": "src.tasks.retention.run_retention_purge",
|
||||
"schedule": crontab(hour="2", minute="0"),
|
||||
},
|
||||
}
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
0
apps/api/src/cli/__init__.py
Normal file
0
apps/api/src/cli/__init__.py
Normal file
40
apps/api/src/cli/bootstrap_admin.py
Normal file
40
apps/api/src/cli/bootstrap_admin.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""One-shot bootstrap of an initial organisation and owner user.
|
||||
|
||||
Usage:
|
||||
python -m src.cli.bootstrap_admin
|
||||
|
||||
Reads ``INITIAL_ADMIN_EMAIL`` and ``INITIAL_ADMIN_PASSWORD`` (plus the
|
||||
optional ``INITIAL_ADMIN_FULL_NAME``, ``INITIAL_ORG_NAME``, and
|
||||
``INITIAL_ORG_SLUG``) from the environment. If the ``users`` table is
|
||||
empty and both credentials are set, creates the org and owner user so
|
||||
the operator can log in to the admin UI. Idempotent: if any user
|
||||
already exists, exits 0 without touching the database.
|
||||
|
||||
Intended to be run as a one-shot init container *after* the database
|
||||
migrations have been applied — typically via ``depends_on`` with
|
||||
``service_healthy`` on the API container.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
from src.config.logging import setup_logging
|
||||
from src.config.settings import get_settings
|
||||
from src.services.bootstrap import bootstrap_initial_admin
|
||||
|
||||
|
||||
async def _main() -> int:
|
||||
settings = get_settings()
|
||||
setup_logging(settings.log_level)
|
||||
await bootstrap_initial_admin(settings)
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> None:
|
||||
sys.exit(asyncio.run(_main()))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
137
apps/api/src/cli/seed_known_cookies.py
Normal file
137
apps/api/src/cli/seed_known_cookies.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Seed the known_cookies table from the Open Cookie Database CSV.
|
||||
|
||||
Usage:
|
||||
python -m src.cli.seed_known_cookies [--csv PATH] [--clear]
|
||||
|
||||
The Open Cookie Database is a community-maintained catalogue of ~2,200+
|
||||
cookie patterns. See https://github.com/jkwakman/Open-Cookie-Database
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import sys
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Category mapping: Open Cookie Database category → CMP slug
|
||||
# ---------------------------------------------------------------------------
|
||||
# ---------------------------------------------------------------------------
# Category mapping: Open Cookie Database category → CMP slug
# ---------------------------------------------------------------------------
_CATEGORY_MAP: dict[str, str] = {
    "Functional": "functional",
    "Analytics": "analytics",
    "Marketing": "marketing",
    "Personalization": "personalisation",
    "Security": "necessary",
}

# Bundled CSV, resolved relative to this file (apps/api/data/).
_DEFAULT_CSV = Path(__file__).resolve().parent.parent.parent / "data" / "open-cookie-database.csv"


def _build_sync_url(async_url: str) -> str:
    """Rewrite an asyncpg DSN into a plain psycopg2 one for one-off scripts."""
    sync_url = async_url.replace("postgresql+asyncpg://", "postgresql://")
    return sync_url
|
||||
|
||||
|
||||
def seed(csv_path: Path, *, clear: bool = False) -> int:
    """Import the Open Cookie Database CSV into ``known_cookies``.

    Rows whose category has no CMP equivalent, or whose cookie name is
    blank, are skipped. Existing ``(name_pattern, domain_pattern)`` pairs
    are updated in place via upsert.

    Returns the number of rows inserted.
    """
    from src.config.settings import get_settings

    settings = get_settings()
    engine = sa.create_engine(_build_sync_url(settings.database_url))

    upsert_sql = sa.text(
        """
        INSERT INTO known_cookies
            (id, name_pattern, domain_pattern, category_id,
             vendor, description, is_regex, created_at, updated_at)
        VALUES
            (:id, :name_pattern, :domain_pattern, :category_id,
             :vendor, :description, :is_regex, NOW(), NOW())
        ON CONFLICT (name_pattern, domain_pattern) DO UPDATE SET
            category_id = EXCLUDED.category_id,
            vendor = EXCLUDED.vendor,
            description = EXCLUDED.description,
            is_regex = EXCLUDED.is_regex,
            updated_at = NOW()
        """
    )

    with engine.begin() as conn:
        # slug → category_id lookup for mapping CSV categories to FKs.
        category_rows = conn.execute(sa.text("SELECT id, slug FROM cookie_categories"))
        category_ids: dict[str, str] = {r[1]: str(r[0]) for r in category_rows}

        if clear:
            conn.execute(sa.text("DELETE FROM known_cookies"))

        inserted = 0
        with csv_path.open(newline="", encoding="utf-8") as handle:
            for row in csv.DictReader(handle):
                slug = _CATEGORY_MAP.get(row.get("Category", "").strip())
                name = row.get("Cookie / Data Key name", "").strip()
                # Skip unmapped categories, categories missing from the
                # DB, and rows with no cookie name.
                if not slug or slug not in category_ids or not name:
                    continue

                domain = row.get("Domain", "").strip() or "*"
                wildcard = row.get("Wildcard match", "0").strip() == "1"

                conn.execute(
                    upsert_sql,
                    {
                        "id": str(uuid.uuid4()),
                        # Glob-style matching: wildcard rows get a trailing *.
                        "name_pattern": f"{name}*" if wildcard else name,
                        "domain_pattern": domain,
                        "category_id": category_ids[slug],
                        "vendor": row.get("Platform", "").strip() or None,
                        "description": row.get("Description", "").strip() or None,
                        "is_regex": False,
                    },
                )
                inserted += 1

    return inserted
|
||||
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse arguments, validate the CSV path, run the import."""
    parser = argparse.ArgumentParser(description="Seed known cookies from Open Cookie Database")
    parser.add_argument(
        "--csv",
        type=Path,
        default=_DEFAULT_CSV,
        help="Path to the Open Cookie Database CSV (default: bundled copy)",
    )
    parser.add_argument(
        "--clear",
        action="store_true",
        help="Delete all existing known_cookies before importing",
    )
    opts = parser.parse_args()

    # Fail fast with a clear message rather than a traceback from open().
    if not opts.csv.exists():
        print(f"Error: CSV not found at {opts.csv}", file=sys.stderr)
        sys.exit(1)

    count = seed(opts.csv, clear=opts.clear)
    print(f"Seeded {count} known cookie patterns from {opts.csv.name}")
|
||||
3
apps/api/src/config/__init__.py
Normal file
3
apps/api/src/config/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.config.settings import Settings, get_settings
|
||||
|
||||
__all__ = ["Settings", "get_settings"]
|
||||
26
apps/api/src/config/edition.py
Normal file
26
apps/api/src/config/edition.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""Edition detection for the open-core architecture.
|
||||
|
||||
Determines whether the application is running in community edition (CE)
|
||||
or enterprise edition (EE) based on the availability of the ``ee``
|
||||
package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from functools import lru_cache
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def is_ee() -> bool:
    """Return ``True`` if enterprise extensions are available.

    The probe is a plain import of the ``ee`` package; the result is
    cached for the life of the process.
    """
    try:
        import ee  # noqa: F401
    except ImportError:
        return False
    return True


def edition_name() -> str:
    """Return a human-readable edition label (``"ee"`` or ``"ce"``)."""
    return "ee" if is_ee() else "ce"
|
||||
26
apps/api/src/config/logging.py
Normal file
26
apps/api/src/config/logging.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import structlog
|
||||
|
||||
|
||||
def setup_logging(log_level: str = "INFO") -> None:
    """Configure structured logging with structlog.

    Renders human-friendly console output on a TTY and JSON otherwise;
    unknown *log_level* strings fall back to INFO.
    """
    # Pick the final renderer up front: pretty console for interactive
    # sessions, JSON for captured/aggregated logs.
    renderer = (
        structlog.dev.ConsoleRenderer()
        if sys.stderr.isatty()
        else structlog.processors.JSONRenderer()
    )
    level = getattr(logging, log_level.upper(), logging.INFO)

    structlog.configure(
        processors=[
            structlog.contextvars.merge_contextvars,
            structlog.processors.add_log_level,
            structlog.processors.StackInfoRenderer(),
            structlog.dev.set_exc_info,
            structlog.processors.TimeStamper(fmt="iso"),
            renderer,
        ],
        wrapper_class=structlog.make_filtering_bound_logger(level),
        context_class=dict,
        logger_factory=structlog.PrintLoggerFactory(),
        cache_logger_on_first_use=True,
    )
|
||||
166
apps/api/src/config/settings.py
Normal file
166
apps/api/src/config/settings.py
Normal file
@@ -0,0 +1,166 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
# Placeholder value — the application refuses to start in non-dev
|
||||
# environments if ``jwt_secret_key`` is left at this literal.
|
||||
_JWT_PLACEHOLDER = "CHANGE-ME-in-production"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    model_config = SettingsConfigDict(
        env_file=".env",
        env_file_encoding="utf-8",
        case_sensitive=False,
    )

    # Application
    app_name: str = "ConsentOS API"
    app_version: str = "0.1.0"
    debug: bool = False
    environment: str = "development"
    log_level: str = "INFO"

    # Server
    host: str = "0.0.0.0"
    port: int = 8000
    allowed_origins: str = "http://localhost:5173"

    @property
    def allowed_origins_list(self) -> list[str]:
        """Split ``allowed_origins`` on commas, dropping blank entries."""
        return [origin.strip() for origin in self.allowed_origins.split(",") if origin.strip()]

    # Database
    database_url: str = "postgresql+asyncpg://consentos:consentos@localhost:5432/consentos"
    database_echo: bool = False
    database_pool_size: int = 20
    database_max_overflow: int = 10

    # Redis
    redis_url: str = "redis://localhost:6379/0"

    # JWT
    jwt_secret_key: str = _JWT_PLACEHOLDER
    jwt_algorithm: str = "HS256"
    jwt_access_token_expire_minutes: int = 30
    jwt_refresh_token_expire_days: int = 7

    # HMAC key for IP / UA pseudonymisation on consent records; derived
    # from the JWT secret when not explicitly configured.
    pseudonymisation_secret: str | None = None

    # Required as ``X-Admin-Bootstrap-Token`` on
    # ``POST /api/v1/organisations/``. The endpoint is disabled while
    # this is unset (the default). Rotate or unset it once the first org
    # is provisioned to prevent further tenant creation.
    admin_bootstrap_token: str | None = None

    # First-startup bootstrap: if the ``users`` table is empty and both
    # credentials below are set, the API creates an organisation and an
    # owner user for the first admin-UI login. No-op once any user
    # exists, so these variables can safely stay set across restarts;
    # rotate the password via the admin UI after first login.
    initial_admin_email: str | None = None
    initial_admin_password: str | None = None
    initial_admin_full_name: str = "Administrator"
    initial_org_name: str = "Default Organisation"
    initial_org_slug: str = "default"

    # Public URL hosting the banner scripts (consent-loader.js,
    # consent-bundle.js). In dev the admin UI dog-foods the banner so
    # localhost:5173 works for testing; production should use a real CDN
    # (CloudFlare Pages, S3+CloudFront, Cloud CDN, …) — see docs.
    cdn_base_url: str = "http://localhost:5173"

    # Scanner service
    scanner_service_url: str = "http://localhost:8001"
    scanner_timeout_seconds: int = 300

    # Extra GeoIP country header, consulted *before* the built-in list
    # (``cf-ipcountry``, ``x-vercel-ip-country``, ``x-appengine-country``,
    # ``x-country-code``). For CDNs / load balancers using a
    # non-standard header, e.g. GCP LB's ``x-gclb-country`` or an
    # internal edge proxy. Header names are case-insensitive; leave
    # unset when a built-in header suffices.
    geoip_country_header: str | None = None

    # Optional subdivision/state header paired with
    # ``GEOIP_COUNTRY_HEADER`` to produce region keys like ``US-CA`` or
    # ``GB-SCT`` (ISO 3166-2 subdivision without the country prefix).
    # Vendor names vary: ``cf-region-code`` (Cloudflare Enterprise),
    # ``x-vercel-ip-country-region`` (Vercel), ``x-gclb-region`` (GCP
    # LB), ``cloudfront-viewer-country-region`` (CloudFront functions).
    # Leave unset for country-level granularity only.
    geoip_region_header: str | None = None

    # Path to a local MaxMind GeoLite2/GeoIP2 City database, used as the
    # fallback when no CDN header is present. Download
    # GeoLite2-City.mmdb from
    # https://dev.maxmind.com/geoip/geolite2-free-geolocation-data and
    # mount it (e.g. ``/data/GeoLite2-City.mmdb``). When unset, lookups
    # fall back to the rate-limited external ip-api.com service — not
    # suitable for production.
    geoip_maxmind_db_path: str | None = None

    # Rate limiting — on by default; the public banner-config and
    # consent endpoints are internet-exposed and must not be DoS-able.
    # Auth endpoints get a stricter bucket via ``RateLimitMiddleware``.
    rate_limit_enabled: bool = True
    rate_limit_per_minute: int = 120

    @model_validator(mode="after")
    def _check_production_safety(self) -> "Settings":
        """Refuse to start with unsafe defaults in non-dev environments."""
        if self.environment.lower() in {"development", "dev", "local", "test"}:
            return self

        errors: list[str] = []

        if self.jwt_secret_key == _JWT_PLACEHOLDER:
            errors.append(
                "JWT_SECRET_KEY is set to the placeholder value "
                f"{_JWT_PLACEHOLDER!r}. Generate a strong random value "
                "(e.g. `openssl rand -base64 48`) and set it in the "
                "environment before starting the API."
            )

        if "*" in self.allowed_origins_list:
            errors.append(
                "ALLOWED_ORIGINS contains '*'. Wildcard CORS combined with "
                "allow_credentials=True is a credential-theft vector. "
                "Set ALLOWED_ORIGINS to an explicit list of trusted origins."
            )

        if errors:
            raise ValueError(
                "Refusing to start with unsafe configuration:\n - " + "\n - ".join(errors)
            )

        return self

    @property
    def pseudonymisation_key(self) -> bytes:
        """Return the HMAC key used for pseudonymising IP/UA values.

        Falls back to the JWT secret when ``pseudonymisation_secret`` is
        unset so operators need not manage two secrets. Acceptable
        because the HMAC is one-way and the resulting hashes are not
        reversible.
        """
        return (self.pseudonymisation_secret or self.jwt_secret_key).encode("utf-8")
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def get_settings() -> Settings:
    """Return the lazily-constructed, process-wide ``Settings`` instance."""
    return Settings()
|
||||
3
apps/api/src/db/__init__.py
Normal file
3
apps/api/src/db/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.db.session import get_db
|
||||
|
||||
__all__ = ["get_db"]
|
||||
31
apps/api/src/db/session.py
Normal file
31
apps/api/src/db/session.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
settings = get_settings()

# Engine and session factory are module-level singletons shared across
# the application.
engine = create_async_engine(
    settings.database_url,
    echo=settings.database_echo,
    pool_size=settings.database_pool_size,
    max_overflow=settings.database_max_overflow,
)

async_session_factory = async_sessionmaker(
    engine,
    class_=AsyncSession,
    expire_on_commit=False,
)


async def get_db() -> AsyncGenerator[AsyncSession, None]:
    """Dependency that yields an async database session.

    Commits once the request handler returns normally; rolls back and
    re-raises when it (or the commit itself) errors.
    """
    async with async_session_factory() as db_session:
        try:
            yield db_session
            await db_session.commit()
        except Exception:
            await db_session.rollback()
            raise
|
||||
3
apps/api/src/extensions/__init__.py
Normal file
3
apps/api/src/extensions/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from src.extensions.registry import discover_extensions, get_registry
|
||||
|
||||
__all__ = ["discover_extensions", "get_registry"]
|
||||
197
apps/api/src/extensions/registry.py
Normal file
197
apps/api/src/extensions/registry.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""Extension registry for the open-core architecture.
|
||||
|
||||
Provides registration hooks that allow enterprise/commercial code to inject
|
||||
routers, model modules, startup tasks, and OpenAPI tags into the core
|
||||
application — without the core needing any direct knowledge of the
|
||||
extensions.
|
||||
|
||||
In community edition (CE) mode, ``discover_extensions()`` is a no-op
|
||||
because the ``ee`` package is not present.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable, Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, FastAPI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class OpenAPITag:
    """Metadata for a FastAPI OpenAPI tag."""

    # Tag name as it appears in the generated OpenAPI schema.
    name: str
    # Description rendered in the interactive docs UI.
    description: str


@dataclass
class RouterEntry:
    """A router registered by an extension."""

    # The FastAPI router to mount.
    router: APIRouter
    # URL prefix the router is mounted under.
    prefix: str = "/api/v1"
    # OpenAPI tags to declare alongside the router.
    tags: list[OpenAPITag] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class ExtensionRegistry:
    """Central registry for extension-contributed components.

    Extensions call the module-level helper functions (``register_router``,
    ``register_model_module``, etc.) which delegate to the singleton
    instance stored in ``_registry``.
    """

    routers: list[RouterEntry] = field(default_factory=list)
    model_modules: list[str] = field(default_factory=list)
    startup_hooks: list[Callable[[FastAPI], Coroutine[Any, Any, None]]] = field(
        default_factory=list,
    )
    config_enrichers: list[Callable] = field(default_factory=list)
    consent_record_hooks: list[Callable] = field(default_factory=list)

    # -- Registration helpers ------------------------------------------

    def add_router(
        self,
        router: APIRouter,
        *,
        prefix: str = "/api/v1",
        tags: list[OpenAPITag] | None = None,
    ) -> None:
        """Queue *router* for mounting under *prefix* with optional tags."""
        entry = RouterEntry(router=router, prefix=prefix, tags=tags or [])
        self.routers.append(entry)

    def add_model_module(self, module_path: str) -> None:
        """Queue a dotted module path for import at wiring time."""
        self.model_modules.append(module_path)

    def add_startup_hook(
        self,
        hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
    ) -> None:
        """Queue an async hook to run during application startup."""
        self.startup_hooks.append(hook)

    def add_config_enricher(self, enricher: Callable) -> None:
        """Queue a published-config enricher callable."""
        self.config_enrichers.append(enricher)

    def add_consent_record_hook(self, hook: Callable) -> None:
        """Queue a post-persist consent-record hook."""
        self.consent_record_hooks.append(hook)

    # -- Application wiring --------------------------------------------

    def apply(self, app: FastAPI) -> None:
        """Mount all registered routers and tags onto *app*."""
        for entry in self.routers:
            # Declare each tag in the OpenAPI schema exactly once.
            for tag in entry.tags:
                declared = app.openapi_tags or []
                if not any(existing["name"] == tag.name for existing in declared):
                    declared.append(
                        {"name": tag.name, "description": tag.description},
                    )
                app.openapi_tags = declared

            app.include_router(entry.router, prefix=entry.prefix)

        if self.routers:
            logger.info(
                "Registered %d extension router(s)",
                len(self.routers),
            )

        # Importing model modules is what makes SQLAlchemy aware of
        # their table definitions.
        for module_path in self.model_modules:
            importlib.import_module(module_path)

        if self.model_modules:
            logger.info(
                "Registered %d extension model module(s)",
                len(self.model_modules),
            )
|
||||
|
||||
|
||||
# Singleton ------------------------------------------------------------------

_registry = ExtensionRegistry()


def get_registry() -> ExtensionRegistry:
    """Return the global extension registry."""
    return _registry


# Convenience module-level API -----------------------------------------------


def register_router(
    router: APIRouter,
    *,
    prefix: str = "/api/v1",
    tags: list[OpenAPITag] | None = None,
) -> None:
    """Register an API router to be mounted at startup."""
    _registry.add_router(router, prefix=prefix, tags=tags)


def register_model_module(module_path: str) -> None:
    """Register a dotted module path whose SQLAlchemy models should be imported."""
    _registry.add_model_module(module_path)


def register_startup_hook(
    hook: Callable[[FastAPI], Coroutine[Any, Any, None]],
) -> None:
    """Register an async callable to run during application startup."""
    _registry.add_startup_hook(hook)


def register_config_enricher(enricher: Callable) -> None:
    """Register a callable that enriches published config.

    Signature: ``async (site_id: UUID, db: AsyncSession, config: dict) -> None``.
    The callable mutates *config* in place to add extension-specific
    data (e.g. A/B test variants).
    """
    _registry.add_config_enricher(enricher)


def register_consent_record_hook(hook: Callable) -> None:
    """Register a callable invoked after a consent record is persisted.

    Signature: ``async (db: AsyncSession, consent_record) -> None``.
    Called from ``POST /api/v1/consent`` once the record has been
    flushed to the database. Typical uses: consent receipts (EE), audit
    logs, webhooks.
    """
    _registry.add_consent_record_hook(hook)
|
||||
|
||||
|
||||
# Discovery ------------------------------------------------------------------
|
||||
|
||||
|
||||
def discover_extensions() -> None:
    """Import the EE registration module if installed.

    Enterprise edition ships as a separate ``consent-enterprise``
    package; importing ``ee.api.src.register`` triggers its side-effect
    registrations. In community edition the import fails and we simply
    carry on.
    """
    try:
        import ee.api.src.register  # noqa: F401
    except ImportError:
        logger.debug("No enterprise extensions found (CE mode)")
    else:
        logger.info("Enterprise extensions loaded")
|
||||
210
apps/api/src/main.py
Normal file
210
apps/api/src/main.py
Normal file
@@ -0,0 +1,210 @@
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from src.config.edition import edition_name
|
||||
from src.config.logging import setup_logging
|
||||
from src.config.settings import get_settings
|
||||
from src.extensions.registry import discover_extensions, get_registry
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from src.routers import (
|
||||
auth,
|
||||
compliance,
|
||||
config,
|
||||
consent,
|
||||
cookies,
|
||||
org_config,
|
||||
organisations,
|
||||
scanner,
|
||||
site_group_config,
|
||||
site_groups,
|
||||
sites,
|
||||
translations,
|
||||
users,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Application startup and shutdown lifecycle.

    Configures logging once before the app serves traffic; there is no
    teardown work after the ``yield``.
    """
    setup_logging(get_settings().log_level)
    yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
    """Application factory.

    Builds the FastAPI app, wires middleware (security headers,
    optional rate limiting, CORS), mounts all core routers under
    ``/api/v1``, applies discovered enterprise extensions, and defines
    the two health endpoints.
    """
    settings = get_settings()

    app = FastAPI(
        title=settings.app_name,
        version=settings.app_version,
        description=(
            "Multi-tenant cookie consent management platform API. "
            "Provides consent collection, cookie scanning, auto-blocking, "
            "compliance checking, and analytics across multiple sites."
        ),
        debug=settings.debug,
        lifespan=lifespan,
        # Tag metadata shown in the interactive API docs.
        openapi_tags=[
            {
                "name": "auth",
                "description": "Authentication — login, token refresh, and current user.",
            },
            {
                "name": "config",
                "description": (
                    "Site configuration — public endpoints for the banner script "
                    "to fetch config, GeoIP-resolved config, and CDN publishing."
                ),
            },
            {
                "name": "consent",
                "description": (
                    "Consent recording and retrieval — public endpoints called "
                    "by the banner script to record visitor consent decisions."
                ),
            },
            {
                "name": "sites",
                "description": "Site and site config CRUD — manage domains and settings.",
            },
            {
                "name": "cookies",
                "description": (
                    "Cookie management — categories, discovered cookies, allow-list, "
                    "known cookies database, and auto-classification."
                ),
            },
            {
                "name": "scanner",
                "description": (
                    "Cookie scanner — trigger scans, view results, and receive "
                    "client-side cookie reports from the banner script."
                ),
            },
            {
                "name": "compliance",
                "description": (
                    "Compliance checking — run checks against GDPR, CNIL, CCPA, "
                    "ePrivacy, and LGPD frameworks."
                ),
            },
            {
                "name": "organisations",
                "description": "Organisation management — multi-tenant root entities.",
            },
            {
                "name": "users",
                "description": "User management — org-scoped users with role-based access.",
            },
        ],
    )

    # Security headers
    app.add_middleware(SecurityHeadersMiddleware)

    # Rate limiting (must be added before CORS to count requests correctly)
    if settings.rate_limit_enabled:
        app.add_middleware(
            RateLimitMiddleware,
            redis_url=settings.redis_url,
            requests_per_minute=settings.rate_limit_per_minute,
            # Stricter bucket for authentication endpoints.
            auth_requests_per_minute=10,
        )

    # CORS — origins come from the ALLOWED_ORIGINS env setting.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=settings.allowed_origins_list,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Core routers — all mounted under the versioned API prefix.
    api_prefix = "/api/v1"
    app.include_router(auth.router, prefix=api_prefix)
    app.include_router(config.router, prefix=api_prefix)
    app.include_router(consent.router, prefix=api_prefix)
    app.include_router(scanner.router, prefix=api_prefix)
    app.include_router(compliance.router, prefix=api_prefix)
    app.include_router(organisations.router, prefix=api_prefix)
    app.include_router(org_config.router, prefix=api_prefix)
    app.include_router(users.router, prefix=api_prefix)
    app.include_router(site_groups.router, prefix=api_prefix)
    app.include_router(site_group_config.router, prefix=api_prefix)
    app.include_router(sites.router, prefix=api_prefix)
    app.include_router(cookies.router, prefix=api_prefix)
    app.include_router(translations.router, prefix=api_prefix)
    # Public (unauthenticated) translation endpoints for the banner.
    app.include_router(translations.public_router, prefix=api_prefix)

    # Discover and mount enterprise extensions (no-op in CE mode)
    discover_extensions()
    registry = get_registry()
    registry.apply(app)

    @app.get("/health", tags=["health"])
    async def health() -> dict[str, str]:
        """Shallow liveness check.

        Answers "is the process running?". Suitable for orchestrator
        liveness probes. For deployment readiness, use
        ``/health/ready`` which verifies downstream dependencies.
        """
        return {"status": "ok", "edition": edition_name()}

    @app.get("/health/ready", tags=["health"])
    async def health_ready() -> dict[str, object]:
        """Deep readiness check — verifies database and Redis.

        Returns HTTP 503 if either dependency is unreachable so load
        balancers route traffic away from broken instances.
        """
        # Local imports keep the hot module import path light and avoid
        # circular imports with the DB session module.
        from fastapi import HTTPException
        from sqlalchemy import text

        from src.db.session import engine as db_engine

        checks: dict[str, str] = {}
        overall_ok = True

        # Database — a trivial round-trip query proves connectivity.
        try:
            async with db_engine.connect() as conn:
                await conn.execute(text("SELECT 1"))
            checks["database"] = "ok"
        except Exception as exc:
            checks["database"] = f"error: {type(exc).__name__}"
            overall_ok = False

        # Redis — PING via a short-lived client.
        try:
            import redis.asyncio as aioredis

            r = aioredis.from_url(settings.redis_url, decode_responses=True)
            pong = await r.ping()
            checks["redis"] = "ok" if pong else "error: ping failed"
            if not pong:
                overall_ok = False
            await r.aclose()
        except Exception as exc:
            checks["redis"] = f"error: {type(exc).__name__}"
            overall_ok = False

        payload = {
            "status": "ok" if overall_ok else "degraded",
            "edition": edition_name(),
            "checks": checks,
        }
        if not overall_ok:
            # 503 with the per-check detail so operators can see which
            # dependency failed.
            raise HTTPException(status_code=503, detail=payload)
        return payload

    return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
0
apps/api/src/middleware/__init__.py
Normal file
0
apps/api/src/middleware/__init__.py
Normal file
111
apps/api/src/middleware/rate_limit.py
Normal file
111
apps/api/src/middleware/rate_limit.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Redis-backed rate limiting middleware.
|
||||
|
||||
Applies per-IP rate limits to all incoming requests. Public endpoints
|
||||
(consent recording, config fetching) are the primary protection target.
|
||||
|
||||
Uses a sliding window counter stored in Redis with automatic expiry.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple per-IP rate limiter backed by Redis.

    Uses fixed one-minute windows: each request INCRs a key stamped
    with the current minute, and the count is compared to the budget.
    The limiter *fails open* — if Redis is unreachable or any command
    errors, the request is allowed through rather than failing 5xx.
    """

    def __init__(
        self,
        app: object,
        redis_url: str = "redis://localhost:6379/0",
        requests_per_minute: int = 120,
        auth_requests_per_minute: int = 10,
    ) -> None:
        """Configure per-IP budgets.

        Args:
            app: Downstream ASGI application.
            redis_url: Redis connection URL used for the counters.
            requests_per_minute: Budget for general endpoints.
            auth_requests_per_minute: Stricter budget for auth
                endpoints (credential-stuffing mitigation).
        """
        super().__init__(app)  # type: ignore[arg-type]
        self.redis_url = redis_url
        self.requests_per_minute = requests_per_minute
        self.auth_requests_per_minute = auth_requests_per_minute
        # Lazily-created client; stays None until the first successful
        # initialisation in _get_redis().
        self._redis: object | None = None

    async def _get_redis(self) -> object | None:
        """Lazy-initialise Redis connection.

        Returns the cached client, or ``None`` when the redis package
        is missing or client construction fails — rate limiting is then
        skipped for this request, and a later request retries the init.
        """
        if self._redis is not None:
            return self._redis
        try:
            import redis.asyncio as aioredis

            self._redis = aioredis.from_url(self.redis_url, decode_responses=True)
            return self._redis
        except Exception:
            logger.warning("Rate limiting disabled: Redis unavailable")
            return None

    def _get_client_ip(self, request: Request) -> str:
        """Extract the real client IP.

        Prefers X-Forwarded-For, then X-Real-IP, then the socket peer.
        NOTE(review): these headers are client-spoofable unless a
        trusted proxy strips/sets them — confirm deployment topology.
        """
        forwarded = request.headers.get("x-forwarded-for")
        if forwarded:
            # First hop of the XFF chain is the originating client.
            return forwarded.split(",")[0].strip()
        real_ip = request.headers.get("x-real-ip")
        if real_ip:
            return real_ip.strip()
        if request.client:
            return request.client.host
        return "unknown"

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        """Count the request against its per-IP bucket; 429 over budget."""
        # Skip rate limiting for health checks
        if request.url.path in ("/health", "/health/ready", "/health/live"):
            return await call_next(request)

        r = await self._get_redis()
        if r is None:
            # Redis unavailable — allow request through
            return await call_next(request)

        # Auth endpoints get a stricter bucket to slow down credential
        # stuffing — login, password reset, token refresh.
        path = request.url.path
        is_auth = path.startswith("/api/v1/auth/") and path not in ("/api/v1/auth/me",)
        limit = self.auth_requests_per_minute if is_auth else self.requests_per_minute
        bucket = "auth" if is_auth else "req"

        client_ip = self._get_client_ip(request)
        # Fixed window: the integer minute since epoch is baked into the
        # key, so a new key (and fresh count) starts every minute.
        window = int(time.time() // 60)
        key = f"cmp:rate:{bucket}:{client_ip}:{window}"

        try:
            current = await r.incr(key)  # type: ignore[union-attr]
            if current == 1:
                # TTL is twice the window length so stale keys always
                # self-clean; bucketing is done by the key, not the TTL.
                await r.expire(key, 120)  # type: ignore[union-attr]

            if current > limit:
                return JSONResponse(
                    status_code=429,
                    content={"detail": "Too many requests. Please try again later."},
                    headers={
                        "Retry-After": "60",
                        "X-RateLimit-Limit": str(limit),
                        "X-RateLimit-Remaining": "0",
                    },
                )

            response = await call_next(request)
            remaining = max(0, limit - current)
            response.headers["X-RateLimit-Limit"] = str(limit)
            response.headers["X-RateLimit-Remaining"] = str(remaining)
            return response

        except Exception:
            # Fail open: a Redis hiccup must not take the API down.
            logger.debug("Rate limit check failed", exc_info=True)
            return await call_next(request)
|
||||
41
apps/api/src/middleware/security_headers.py
Normal file
41
apps/api/src/middleware/security_headers.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Security headers middleware.
|
||||
|
||||
Adds standard security headers to all API responses:
|
||||
- X-Content-Type-Options: nosniff
|
||||
- X-Frame-Options: DENY
|
||||
- X-XSS-Protection: 0 (disabled in favour of CSP)
|
||||
- Referrer-Policy: strict-origin-when-cross-origin
|
||||
- Content-Security-Policy: default-src 'none'
|
||||
- Strict-Transport-Security (HSTS) in production
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of security headers to every response.

    The static headers are applied unconditionally; HSTS is added only
    when the request itself arrived over HTTPS (a reverse proxy may
    terminate TLS upstream).
    """

    # Header name/value pairs applied to every response.
    _STATIC_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "0"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Content-Security-Policy", "default-src 'none'"),
    )

    async def dispatch(
        self,
        request: Request,
        call_next: RequestResponseEndpoint,
    ) -> Response:
        response = await call_next(request)

        for header_name, header_value in self._STATIC_HEADERS:
            response.headers[header_name] = header_value

        # HSTS — only on HTTPS requests (reverse proxy may terminate TLS)
        if request.url.scheme == "https":
            response.headers["Strict-Transport-Security"] = (
                "max-age=63072000; includeSubDomains; preload"
            )

        return response
|
||||
31
apps/api/src/models/__init__.py
Normal file
31
apps/api/src/models/__init__.py
Normal file
@@ -0,0 +1,31 @@
|
||||
from src.models.base import Base
|
||||
from src.models.consent import ConsentRecord
|
||||
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
|
||||
from src.models.org_config import OrgConfig
|
||||
from src.models.organisation import Organisation
|
||||
from src.models.scan import ScanJob, ScanResult
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.models.site_group import SiteGroup
|
||||
from src.models.site_group_config import SiteGroupConfig
|
||||
from src.models.translation import Translation
|
||||
from src.models.user import User
|
||||
|
||||
# Public re-export surface of the models package. Importing
# ``src.models`` registers every mapped class on the shared declarative
# Base metadata (needed before create_all / Alembic autogenerate runs).
__all__ = [
    "Base",
    "ConsentRecord",
    "Cookie",
    "CookieAllowListEntry",
    "CookieCategory",
    "KnownCookie",
    "OrgConfig",
    "Organisation",
    "ScanJob",
    "ScanResult",
    "Site",
    "SiteConfig",
    "SiteGroup",
    "SiteGroupConfig",
    "Translation",
    "User",
]
|
||||
48
apps/api/src/models/base.py
Normal file
48
apps/api/src/models/base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Base class for all SQLAlchemy models.

    Every model inherits from this single declarative base so they
    share one metadata registry.
    """

    pass
|
||||
|
||||
|
||||
class TimestampMixin:
    """Mixin that adds created_at and updated_at columns.

    Both columns are timezone-aware and defaulted server-side via SQL
    ``now()``, so rows get consistent timestamps regardless of
    app-host clock skew.
    """

    # Set once by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
    )
    # Defaulted on insert; the onupdate hook refreshes it on every
    # UPDATE issued through the ORM.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )
|
||||
|
||||
|
||||
class UUIDPrimaryKeyMixin:
    """Mixin that adds a UUID primary key.

    Generated client-side with ``uuid.uuid4``, so the id is available
    on the Python object before the INSERT is flushed.
    """

    id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        primary_key=True,
        default=uuid.uuid4,
    )
|
||||
|
||||
|
||||
class SoftDeleteMixin:
    """Mixin that adds soft delete support.

    A row counts as deleted when ``deleted_at`` is non-NULL; queries
    are expected to filter on it instead of issuing DELETEs.
    """

    deleted_at: Mapped[datetime | None] = mapped_column(
        DateTime(timezone=True),
        nullable=True,
        default=None,
    )
|
||||
81
apps/api/src/models/consent.py
Normal file
81
apps/api/src/models/consent.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Index, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from src.models.base import Base, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class ConsentRecord(UUIDPrimaryKeyMixin, Base):
    """Audit trail of every consent event. Partitioned by month for performance.

    No TimestampMixin here: ``consented_at`` is the only timestamp this
    record needs, and records are written once per event.
    """

    __tablename__ = "consent_records"
    __table_args__ = (
        # Composite index for the most common analytics query pattern:
        # "records for site X between dates A and B". The (site_id,
        # consented_at DESC) ordering also supports "latest consents
        # for site X" without an extra sort.
        Index(
            "ix_consent_records_site_consented_at",
            "site_id",
            "consented_at",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Visitor identification (anonymous)
    visitor_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    # 64-char hex columns — presumably SHA-256 digests of IP / UA, so
    # raw values are never stored; confirm against the hashing code.
    ip_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)
    user_agent_hash: Mapped[str | None] = mapped_column(String(64), nullable=True)

    # Consent details
    action: Mapped[str] = mapped_column(String(30), nullable=False)
    categories_accepted: Mapped[list] = mapped_column(JSONB, nullable=False)
    categories_rejected: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # TCF
    tc_string: Mapped[str | None] = mapped_column(Text, nullable=True)

    # GCM state at time of consent
    gcm_state: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # GPP
    gpp_string: Mapped[str | None] = mapped_column(Text, nullable=True)

    # GPC
    gpc_detected: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_honoured: Mapped[bool | None] = mapped_column(nullable=True)

    # A/B testing — soft references to EE `ab_tests` / `ab_test_variants`
    # tables. Intentionally *no* FK constraint so the core schema works
    # without the EE extension installed.
    ab_test_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
        index=True,
    )
    ab_variant_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        nullable=True,
    )

    # Context
    page_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    country_code: Mapped[str | None] = mapped_column(String(5), nullable=True)
    region_code: Mapped[str | None] = mapped_column(String(10), nullable=True)

    # Timestamp
    consented_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        nullable=False,
        index=True,
    )
|
||||
130
apps/api/src/models/cookie.py
Normal file
130
apps/api/src/models/cookie.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class CookieCategory(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Cookie category taxonomy (necessary, functional, analytics, marketing, personalisation)."""

    __tablename__ = "cookie_categories"

    name: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    # Stable key-safe identifier used to reference the category.
    slug: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Essential categories are presumably non-rejectable by visitors —
    # enforcement lives in the consent layer; confirm there.
    is_essential: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Ordering hint for banner/UI listings.
    display_order: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)

    # TCF purpose mapping
    tcf_purpose_ids: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # Google Consent Mode consent type mapping
    gcm_consent_types: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # Relationships
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="category")
    allow_list_entries: Mapped[list["CookieAllowListEntry"]] = relationship(
        back_populates="category"
    )
|
||||
|
||||
|
||||
class Cookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie discovered on a site via scanning or client-side reporting."""

    __tablename__ = "cookies"
    __table_args__ = (
        # A cookie is unique per site by (name, domain, storage type);
        # re-discovery updates the existing row rather than duplicating.
        UniqueConstraint(
            "site_id",
            "name",
            "domain",
            "storage_type",
            name="uq_cookies_site_name_domain_type",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # SET NULL on category delete: the cookie survives uncategorised.
    category_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )

    name: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain: Mapped[str] = mapped_column(String(255), nullable=False)
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    path: Mapped[str | None] = mapped_column(String(500), nullable=True)
    max_age_seconds: Mapped[int | None] = mapped_column(Integer, nullable=True)
    is_http_only: Mapped[bool | None] = mapped_column(nullable=True)
    is_secure: Mapped[bool | None] = mapped_column(nullable=True)
    same_site: Mapped[str | None] = mapped_column(String(10), nullable=True)
    review_status: Mapped[str] = mapped_column(String(20), server_default="pending", nullable=False)
    # NOTE(review): stored as String(50) rather than DateTime — looks
    # like ISO-8601 text; confirm this is intentional (vs a typed
    # timestamp like ScanJob.started_at).
    first_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)
    last_seen_at: Mapped[str | None] = mapped_column(String(50), nullable=True)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookies")  # noqa: F821
    category: Mapped["CookieCategory | None"] = relationship(back_populates="cookies")
|
||||
|
||||
|
||||
class CookieAllowListEntry(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Approved cookies per site with category assignment."""

    __tablename__ = "cookie_allow_list"
    __table_args__ = (
        # One entry per (site, name pattern, domain pattern).
        UniqueConstraint(
            "site_id",
            "name_pattern",
            "domain_pattern",
            name="uq_allow_list_site_name_domain",
        ),
    )

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # RESTRICT: a category with allow-list entries cannot be deleted.
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    # Match patterns — semantics (glob/regex/exact) are defined by the
    # matcher that consumes them; confirm there before relying on it.
    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="cookie_allow_list")  # noqa: F821
    category: Mapped["CookieCategory"] = relationship(back_populates="allow_list_entries")
|
||||
|
||||
|
||||
class KnownCookie(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Shared knowledge base of known cookie patterns for auto-categorisation."""

    __tablename__ = "known_cookies"
    __table_args__ = (
        UniqueConstraint("name_pattern", "domain_pattern", name="uq_known_cookies_name_domain"),
    )

    name_pattern: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    domain_pattern: Mapped[str] = mapped_column(String(255), nullable=False)
    # RESTRICT: categories referenced by the knowledge base cannot be
    # deleted out from under it.
    category_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("cookie_categories.id", ondelete="RESTRICT"),
        nullable=False,
    )
    vendor: Mapped[str | None] = mapped_column(String(255), nullable=True)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # When True the patterns are presumably treated as regexes instead
    # of literal/glob matches — confirm in the matching code.
    is_regex: Mapped[bool] = mapped_column(default=False, nullable=False)
|
||||
64
apps/api/src/models/org_config.py
Normal file
64
apps/api/src/models/org_config.py
Normal file
@@ -0,0 +1,64 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class OrgConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Organisation-level default configuration.

    These defaults sit between system defaults and site config in the cascade:
    System Defaults → Org Config → Site Group Config → Site Config → Regional Overrides

    NOTE(review): every setting column is nullable — a NULL presumably
    means "no override at this level" so the cascade falls through;
    confirm against the config resolver.
    """

    __tablename__ = "org_configs"

    # unique=True makes this a one-to-one with the organisation.
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )

    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # TCF
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # Publisher country code — two-letter, per the column width.
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)

    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)

    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)

    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Consent
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Relationship
    organisation: Mapped["Organisation"] = relationship(back_populates="org_config")  # noqa: F821
|
||||
26
apps/api/src/models/organisation.py
Normal file
26
apps/api/src/models/organisation.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class Organisation(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """Multi-tenant root entity. Each organisation has multiple sites and users.

    Soft-deleted (via SoftDeleteMixin) rather than removed, so consent
    history tied to its sites survives an org being retired.
    """

    __tablename__ = "organisations"

    name: Mapped[str] = mapped_column(String(255), nullable=False)
    # URL-safe unique identifier for routing/lookups.
    slug: Mapped[str] = mapped_column(String(100), unique=True, nullable=False, index=True)
    contact_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
    billing_plan: Mapped[str] = mapped_column(String(50), server_default="free", nullable=False)
    notes: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships
    users: Mapped[list["User"]] = relationship(back_populates="organisation")  # noqa: F821
    sites: Mapped[list["Site"]] = relationship(back_populates="organisation")  # noqa: F821
    site_groups: Mapped[list["SiteGroup"]] = relationship(  # noqa: F821
        back_populates="organisation"
    )
    # One-to-one with the org-level config overrides (may be absent).
    org_config: Mapped["OrgConfig | None"] = relationship(  # noqa: F821
        back_populates="organisation", uselist=False
    )
|
||||
68
apps/api/src/models/scan.py
Normal file
68
apps/api/src/models/scan.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, Integer, String, Text, func
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class ScanJob(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """A cookie scanning job for a site."""

    __tablename__ = "scan_jobs"

    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    # Lifecycle state; starts as "pending". Indexed for worker queries
    # that pick up runnable jobs.
    status: Mapped[str] = mapped_column(
        String(20), server_default="pending", nullable=False, index=True
    )
    # How the job was started (defaults to "manual"; schedulers
    # presumably set another value — confirm in the scheduling code).
    trigger: Mapped[str] = mapped_column(String(20), server_default="manual", nullable=False)
    # Progress counters, updated as the scan runs.
    pages_scanned: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    pages_total: Mapped[int | None] = mapped_column(Integer, nullable=True)
    cookies_found: Mapped[int] = mapped_column(Integer, server_default="0", nullable=False)
    error_message: Mapped[str | None] = mapped_column(Text, nullable=True)

    started_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="scan_jobs")  # noqa: F821
    results: Mapped[list["ScanResult"]] = relationship(back_populates="scan_job")
|
||||
|
||||
|
||||
class ScanResult(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Individual result from a scan — a cookie found on a specific page."""

    __tablename__ = "scan_results"

    scan_job_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("scan_jobs.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )

    page_url: Mapped[str] = mapped_column(Text, nullable=False)
    cookie_name: Mapped[str] = mapped_column(String(255), nullable=False)
    cookie_domain: Mapped[str] = mapped_column(String(255), nullable=False)
    storage_type: Mapped[str] = mapped_column(String(30), server_default="cookie", nullable=False)
    # Raw cookie attributes as captured by the scanner (free-form JSON).
    attributes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    # URL of the script that set the cookie, when attributable.
    script_source: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Category suggested automatically (e.g. from the known-cookies KB).
    auto_category: Mapped[str | None] = mapped_column(String(50), nullable=True)
    initiator_chain: Mapped[list[str] | None] = mapped_column(
        ARRAY(Text), nullable=True, comment="Ordered script URLs from root initiator to leaf"
    )

    found_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )

    # Relationships
    scan_job: Mapped["ScanJob"] = relationship(back_populates="results")
|
||||
48
apps/api/src/models/site.py
Normal file
48
apps/api/src/models/site.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class Site(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A domain being managed for cookie consent, belongs to an organisation."""

    __tablename__ = "sites"
    # A domain may appear once per organisation (but can exist in
    # several different organisations).
    __table_args__ = (UniqueConstraint("organisation_id", "domain", name="uq_sites_org_domain"),)

    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    domain: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
    display_name: Mapped[str] = mapped_column(String(255), nullable=False)
    is_active: Mapped[bool] = mapped_column(default=True, nullable=False)
    # Extra domains served by the same site (aliases, locale domains).
    additional_domains: Mapped[list[str] | None] = mapped_column(
        ARRAY(String(255)), nullable=True, server_default=None
    )
    # Optional grouping; SET NULL so deleting a group orphans the site
    # back to org-level config rather than deleting it.
    site_group_id: Mapped[uuid.UUID | None] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="SET NULL"),
        nullable=True,
        index=True,
    )

    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="sites")  # noqa: F821
    site_group: Mapped["SiteGroup | None"] = relationship(back_populates="sites")  # noqa: F821
    # One-to-one site-level config overrides (may be absent).
    config: Mapped["SiteConfig | None"] = relationship(  # noqa: F821
        back_populates="site", uselist=False
    )
    cookies: Mapped[list["Cookie"]] = relationship(back_populates="site")  # noqa: F821
    cookie_allow_list: Mapped[list["CookieAllowListEntry"]] = relationship(  # noqa: F821
        back_populates="site"
    )
    scan_jobs: Mapped[list["ScanJob"]] = relationship(back_populates="site")  # noqa: F821
    translations: Mapped[list["Translation"]] = relationship(  # noqa: F821
        back_populates="site"
    )
|
||||
63
apps/api/src/models/site_config.py
Normal file
63
apps/api/src/models/site_config.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class SiteConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Full configuration for a site: blocking mode, TCF, GCM, banner, scanning, consent.

    Unlike OrgConfig, most columns here are non-nullable with concrete
    defaults — this looks like the terminal layer of the config cascade
    where every setting resolves to a value; confirm in the resolver.
    """

    __tablename__ = "site_configs"

    # unique=True makes this a one-to-one with the site.
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )

    # Blocking mode
    blocking_mode: Mapped[str] = mapped_column(String(20), server_default="opt_in", nullable=False)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # TCF
    tcf_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)
    # Publisher country code — two-letter, per the column width.
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)

    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool] = mapped_column(default=False, nullable=False)

    # Google Consent Mode
    gcm_enabled: Mapped[bool] = mapped_column(default=True, nullable=False)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool] = mapped_column(default=False, nullable=False)

    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    display_mode: Mapped[str] = mapped_column(
        String(30), server_default="bottom_banner", nullable=False
    )
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int] = mapped_column(Integer, server_default="50", nullable=False)

    # Consent
    consent_expiry_days: Mapped[int] = mapped_column(Integer, server_default="365", nullable=False)
    consent_retention_days: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Relationship
    site: Mapped["Site"] = relationship(back_populates="config")  # noqa: F821
|
||||
32
apps/api/src/models/site_group.py
Normal file
32
apps/api/src/models/site_group.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class SiteGroup(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """A logical grouping of sites within an organisation (e.g. a brand)."""

    __tablename__ = "site_groups"
    # Group names must be unique within an organisation, not globally.
    __table_args__ = (UniqueConstraint("organisation_id", "name", name="uq_site_groups_org_name"),)

    # Owning organisation; rows are removed when the organisation is deleted.
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Relationships
    organisation: Mapped["Organisation"] = relationship(  # noqa: F821
        back_populates="site_groups"
    )
    sites: Mapped[list["Site"]] = relationship(back_populates="site_group")  # noqa: F821
    # One-to-one: group-level config defaults (uselist=False).
    group_config: Mapped["SiteGroupConfig | None"] = relationship(  # noqa: F821
        back_populates="site_group", uselist=False
    )
|
||||
63
apps/api/src/models/site_group_config.py
Normal file
63
apps/api/src/models/site_group_config.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, Integer, String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class SiteGroupConfig(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Site-group-level default configuration.

    These defaults sit between org defaults and site config in the cascade:
    System Defaults -> Org Config -> Site Group Config -> Site Config -> Regional Overrides

    Every setting column is nullable: NULL means "not set at this level",
    letting the resolver fall through to the next level of the cascade.
    """

    __tablename__ = "site_group_configs"

    # One config row per site group (unique FK => effectively one-to-one).
    site_group_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("site_groups.id", ondelete="CASCADE"),
        unique=True,
        nullable=False,
    )

    # Blocking mode
    blocking_mode: Mapped[str | None] = mapped_column(String(20), nullable=True)
    regional_modes: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # TCF
    tcf_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    # Two-letter publisher country code.
    tcf_publisher_cc: Mapped[str | None] = mapped_column(String(2), nullable=True)

    # GPP (Global Privacy Platform)
    gpp_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpp_supported_apis: Mapped[list | None] = mapped_column(JSONB, nullable=True)

    # GPC (Global Privacy Control)
    gpc_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gpc_jurisdictions: Mapped[list | None] = mapped_column(JSONB, nullable=True)
    gpc_global_honour: Mapped[bool | None] = mapped_column(nullable=True)

    # Google Consent Mode
    gcm_enabled: Mapped[bool | None] = mapped_column(nullable=True)
    gcm_default: Mapped[dict | None] = mapped_column(JSONB, nullable=True)

    # Shopify Customer Privacy API
    shopify_privacy_enabled: Mapped[bool | None] = mapped_column(nullable=True)

    # Banner
    banner_config: Mapped[dict | None] = mapped_column(JSONB, nullable=True)
    privacy_policy_url: Mapped[str | None] = mapped_column(Text, nullable=True)
    terms_url: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Scanning
    scan_schedule_cron: Mapped[str | None] = mapped_column(String(100), nullable=True)
    scan_max_pages: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Consent
    consent_expiry_days: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Relationship
    site_group: Mapped["SiteGroup"] = relationship(back_populates="group_config")  # noqa: F821
|
||||
26
apps/api/src/models/translation.py
Normal file
26
apps/api/src/models/translation.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.dialects.postgresql import JSONB, UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class Translation(UUIDPrimaryKeyMixin, TimestampMixin, Base):
    """Internationalisation strings per site per locale."""

    __tablename__ = "translations"
    # At most one string bundle per (site, locale) pair.
    __table_args__ = (UniqueConstraint("site_id", "locale", name="uq_translations_site_locale"),)

    # Owning site; translations are removed when the site is deleted.
    site_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("sites.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Locale tag, max 10 chars — presumably BCP 47 style ("en", "en-GB");
    # TODO confirm the expected format against callers.
    locale: Mapped[str] = mapped_column(String(10), nullable=False)
    # Key/value bundle of UI strings for this locale.
    strings: Mapped[dict] = mapped_column(JSONB, nullable=False)

    # Relationships
    site: Mapped["Site"] = relationship(back_populates="translations")  # noqa: F821
|
||||
31
apps/api/src/models/user.py
Normal file
31
apps/api/src/models/user.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy import ForeignKey, String
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from src.models.base import Base, SoftDeleteMixin, TimestampMixin, UUIDPrimaryKeyMixin
|
||||
|
||||
|
||||
class User(UUIDPrimaryKeyMixin, TimestampMixin, SoftDeleteMixin, Base):
    """User account, scoped to an organisation with a role."""

    __tablename__ = "users"

    # Owning organisation; user rows are removed when it is deleted.
    organisation_id: Mapped[uuid.UUID] = mapped_column(
        UUID(as_uuid=True),
        ForeignKey("organisations.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    # Globally unique login identifier (unique across all organisations).
    email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
    # Stored hash only — the plaintext password is never persisted.
    password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
    full_name: Mapped[str] = mapped_column(String(255), nullable=False)
    # Role within the organisation; new rows default to "viewer" at the
    # database level (server_default).
    role: Mapped[str] = mapped_column(
        String(20),
        nullable=False,
        server_default="viewer",
    )

    # Relationships
    organisation: Mapped["Organisation"] = relationship(back_populates="users")  # noqa: F821
|
||||
0
apps/api/src/routers/__init__.py
Normal file
0
apps/api/src/routers/__init__.py
Normal file
108
apps/api/src/routers/auth.py
Normal file
108
apps/api/src/routers/auth.py
Normal file
@@ -0,0 +1,108 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import JWTError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.db import get_db
|
||||
from src.models.user import User
|
||||
from src.schemas.auth import CurrentUser, LoginRequest, RefreshRequest, TokenResponse
|
||||
from src.services.auth import (
|
||||
create_access_token,
|
||||
create_refresh_token,
|
||||
decode_token,
|
||||
verify_password,
|
||||
)
|
||||
from src.services.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
@router.post("/login", response_model=TokenResponse)
async def login(body: LoginRequest, db: AsyncSession = Depends(get_db)) -> TokenResponse:
    """Authenticate a user with email and password, return JWT tokens."""
    found = await db.execute(
        select(User).where(User.email == body.email, User.deleted_at.is_(None))
    )
    user = found.scalar_one_or_none()

    # One generic error for both "unknown email" and "wrong password" so the
    # response body does not reveal which check failed.
    # NOTE(review): verify_password is skipped when the email is unknown, so
    # response timing may still differ — confirm whether a dummy-hash
    # comparison is wanted here.
    credentials_ok = user is not None and verify_password(body.password, user.password_hash)
    if not credentials_ok:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid email or password",
        )

    settings = get_settings()
    return TokenResponse(
        access_token=create_access_token(
            user_id=user.id,
            organisation_id=user.organisation_id,
            role=user.role,
            email=user.email,
        ),
        refresh_token=create_refresh_token(
            user_id=user.id,
            organisation_id=user.organisation_id,
        ),
        expires_in=settings.jwt_access_token_expire_minutes * 60,
    )
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=TokenResponse)
async def refresh(
    body: RefreshRequest,
    db: AsyncSession = Depends(get_db),
) -> TokenResponse:
    """Exchange a valid refresh token for a new access/refresh token pair."""
    try:
        payload = decode_token(body.refresh_token)
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired refresh token",
        ) from exc

    # An access token must not be usable as a refresh token.
    if payload.get("type") != "refresh":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token is not a refresh token",
        )

    # The subject must still map to a live (non-soft-deleted) account.
    subject_id = uuid.UUID(payload["sub"])
    user = (
        await db.execute(
            select(User).where(User.id == subject_id, User.deleted_at.is_(None))
        )
    ).scalar_one_or_none()
    if user is None:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="User no longer exists",
        )

    # Rotate both tokens: the caller receives a fresh refresh token too.
    settings = get_settings()
    return TokenResponse(
        access_token=create_access_token(
            user_id=user.id,
            organisation_id=user.organisation_id,
            role=user.role,
            email=user.email,
        ),
        refresh_token=create_refresh_token(
            user_id=user.id,
            organisation_id=user.organisation_id,
        ),
        expires_in=settings.jwt_access_token_expire_minutes * 60,
    )
|
||||
|
||||
|
||||
@router.get("/me", response_model=CurrentUser)
async def get_me(current_user: CurrentUser = Depends(get_current_user)) -> CurrentUser:
    """Return the currently authenticated user's profile from the JWT."""
    # Authentication happens entirely in the get_current_user dependency;
    # this endpoint just echoes the resolved profile back to the caller.
    return current_user
|
||||
135
apps/api/src/routers/compliance.py
Normal file
135
apps/api/src/routers/compliance.py
Normal file
@@ -0,0 +1,135 @@
|
||||
"""Compliance checking endpoints.
|
||||
|
||||
Evaluates a site's configuration against regulatory frameworks (GDPR, CNIL,
|
||||
CCPA, ePrivacy, LGPD) and returns per-framework compliance reports with scores,
|
||||
issues, and recommendations.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.cookie import Cookie
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.schemas.compliance import (
|
||||
ComplianceCheckRequest,
|
||||
ComplianceCheckResponse,
|
||||
Framework,
|
||||
)
|
||||
from src.services.compliance import (
|
||||
SiteContext,
|
||||
calculate_overall_score,
|
||||
run_compliance_check,
|
||||
)
|
||||
from src.services.dependencies import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/compliance", tags=["compliance"])
|
||||
|
||||
|
||||
async def _build_site_context(
    site_id: uuid.UUID,
    db: AsyncSession,
) -> SiteContext:
    """Load site config and cookie stats to build a SiteContext."""
    config = (
        await db.execute(
            select(SiteConfig).where(
                SiteConfig.site_id == site_id,
                SiteConfig.deleted_at.is_(None),
            )
        )
    ).scalar_one_or_none()

    # Cookie statistics: total discovered cookies and how many are still
    # missing a category assignment.
    total_cookies = (
        await db.execute(
            select(func.count()).select_from(Cookie).where(Cookie.site_id == site_id)
        )
    ).scalar() or 0
    uncategorised_cookies = (
        await db.execute(
            select(func.count())
            .select_from(Cookie)
            .where(
                Cookie.site_id == site_id,
                Cookie.category_id.is_(None),
            )
        )
    ).scalar() or 0

    # No stored config: report only the cookie counts.
    if config is None:
        return SiteContext(
            total_cookies=total_cookies,
            uncategorised_cookies=uncategorised_cookies,
        )

    banner = config.banner_config or {}
    return SiteContext(
        blocking_mode=config.blocking_mode,
        regional_modes=config.regional_modes,
        tcf_enabled=config.tcf_enabled,
        gcm_enabled=config.gcm_enabled,
        consent_expiry_days=config.consent_expiry_days,
        privacy_policy_url=config.privacy_policy_url,
        display_mode=config.display_mode,
        banner_config=config.banner_config,
        total_cookies=total_cookies,
        uncategorised_cookies=uncategorised_cookies,
        # Dark-pattern-relevant banner flags; defaults match a compliant setup.
        has_reject_button=banner.get("show_reject_all", True),
        has_granular_choices=banner.get("show_category_toggles", True),
        has_cookie_wall=banner.get("cookie_wall", False),
        pre_ticked_boxes=banner.get("pre_ticked", False),
    )
|
||||
|
||||
|
||||
@router.post(
    "/check/{site_id}",
    response_model=ComplianceCheckResponse,
    status_code=status.HTTP_200_OK,
)
async def check_compliance(
    site_id: uuid.UUID,
    body: ComplianceCheckRequest | None = None,
    db: AsyncSession = Depends(get_db),
    _user=Depends(get_current_user),
) -> ComplianceCheckResponse:
    """Run compliance checks against a site's configuration."""
    # The site must exist and not be soft-deleted before anything is evaluated.
    site = (
        await db.execute(
            select(Site).where(Site.id == site_id, Site.deleted_at.is_(None))
        )
    ).scalar_one_or_none()
    if site is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )

    # No request body means "check all frameworks" (None is passed through).
    ctx = await _build_site_context(site_id, db)
    requested_frameworks = body.frameworks if body else None
    results = run_compliance_check(ctx, requested_frameworks)

    return ComplianceCheckResponse(
        site_id=str(site_id),
        results=results,
        overall_score=calculate_overall_score(results),
    )
|
||||
|
||||
|
||||
@router.get("/frameworks", response_model=list[dict])
async def list_frameworks() -> list[dict]:
    """List all available compliance frameworks."""
    # Static catalogue; insertion order defines the response order.
    descriptions = {
        Framework.GDPR: "EU General Data Protection Regulation",
        Framework.CNIL: "French Data Protection Authority (stricter GDPR)",
        Framework.CCPA: "California Consumer Privacy Act / CPRA",
        Framework.EPRIVACY: "EU ePrivacy Directive",
        Framework.LGPD: "Brazilian General Data Protection Law",
    }
    return [
        {"id": fw.value, "name": fw.value.upper(), "description": desc}
        for fw, desc in descriptions.items()
    ]
|
||||
324
apps/api/src/routers/config.py
Normal file
324
apps/api/src/routers/config.py
Normal file
@@ -0,0 +1,324 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.extensions.registry import get_registry
|
||||
from src.models.org_config import OrgConfig
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.models.site_group_config import SiteGroupConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site import SiteConfigResponse
|
||||
from src.services.config_resolver import (
|
||||
CONFIG_FIELDS,
|
||||
build_public_config,
|
||||
orm_to_config_dict,
|
||||
resolve_config,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
from src.services.geoip import detect_region
|
||||
from src.services.publisher import publish_site_config
|
||||
|
||||
router = APIRouter(prefix="/config", tags=["config"])
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}", response_model=SiteConfigResponse)
async def get_public_site_config(
    site_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Public endpoint: retrieve site config for the banner script. No auth required."""
    # Only active, non-deleted sites are served to the public banner.
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    config = (await db.execute(stmt)).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )
    return config
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/resolved")
async def get_resolved_config(
    site_id: uuid.UUID,
    region: str | None = None,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: retrieve fully resolved config with regional overrides applied.

    Applies the full cascade: System → Org → Group → Site → Regional.
    """
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    config = (await db.execute(stmt)).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    config_dict = orm_to_config_dict(config, include_id=True)

    # Pull parent-level defaults so the cascade can fill unset fields.
    org_id = await _get_site_org_id(site_id, db)
    group_id = await _get_site_group_id(site_id, db)
    org_defaults = await _load_org_defaults(org_id, db) if org_id else None
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None

    resolved = resolve_config(
        config_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
        region=region,
    )
    return build_public_config(str(site_id), resolved)
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/geo-resolved")
async def get_geo_resolved_config(
    site_id: uuid.UUID,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Public endpoint: resolve config using the visitor's detected region.

    Detects the visitor's region from CDN headers or IP geolocation,
    then applies regional blocking mode overrides automatically.
    Uses the full cascade: System → Org → Group → Site → Regional.
    """
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    config = (await db.execute(stmt)).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    # Region comes from the request (CDN headers / IP), not a query param.
    geo = await detect_region(request)

    config_dict = orm_to_config_dict(config, include_id=True)
    org_id = await _get_site_org_id(site_id, db)
    group_id = await _get_site_group_id(site_id, db)
    org_defaults = await _load_org_defaults(org_id, db) if org_id else None
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None

    resolved = resolve_config(
        config_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
        region=geo.region,
    )

    # Attach the detected geo info so the banner can use it client-side.
    public = build_public_config(str(site_id), resolved)
    public["detected_country"] = geo.country_code
    public["detected_region"] = geo.region
    return public
|
||||
|
||||
|
||||
@router.get("/geo")
async def get_visitor_geo(request: Request) -> dict:
    """Public endpoint: return the detected region for the current visitor.

    Useful for banner scripts that need to know the region before
    fetching the full config.
    """
    geo = await detect_region(request)
    return {"country_code": geo.country_code, "region": geo.region}
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/inheritance")
async def get_config_inheritance(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Return the full config inheritance chain for a site.

    Shows the value at each level so the UI can display where each setting
    comes from: system, org, group, or site.

    Raises:
        HTTPException: 404 when the site config does not exist or the site
            belongs to a different organisation (tenant isolation).
    """
    # Function-scope import mirrors the original; presumably avoids an
    # import cycle with the resolver module — confirm before hoisting.
    from src.services.config_resolver import SYSTEM_DEFAULTS

    result = await db.execute(
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    config = result.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    site_dict = orm_to_config_dict(config)
    org_defaults = await _load_org_defaults(current_user.organisation_id, db)
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None

    resolved = resolve_config(
        site_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
    )

    # For each config field, report the effective value and which level it
    # came from. Highest-priority non-None value wins.
    sources: dict[str, dict] = {}
    for field in CONFIG_FIELDS:
        site_val = site_dict.get(field)
        group_val = group_defaults.get(field) if group_defaults else None
        org_val = org_defaults.get(field) if org_defaults else None
        system_val = SYSTEM_DEFAULTS.get(field)

        if site_val is not None:
            source = "site"
        elif group_val is not None:
            source = "group"
        elif org_val is not None:
            source = "org"
        else:
            # System is the fallback whether or not SYSTEM_DEFAULTS carries
            # the field. (The previous `elif system_val is not None` branch
            # was dead: both arms produced "system".)
            source = "system"

        sources[field] = {
            "resolved_value": resolved.get(field),
            "source": source,
            "site_value": site_val,
            "group_value": group_val,
            "org_value": org_val,
            "system_value": system_val,
        }

    return {
        "site_id": str(site_id),
        "site_group_id": str(group_id) if group_id else None,
        "fields": sources,
    }
|
||||
|
||||
|
||||
@router.post("/sites/{site_id}/publish")
async def publish_config(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Publish fully-resolved site config to CDN. Requires admin role."""
    # Tenant-scoped lookup: the site must belong to the caller's organisation.
    stmt = (
        select(SiteConfig)
        .join(Site)
        .where(
            SiteConfig.site_id == site_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    config = (await db.execute(stmt)).scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found",
        )

    config_dict = orm_to_config_dict(config, include_id=True)
    org_defaults = await _load_org_defaults(current_user.organisation_id, db)
    group_id = await _get_site_group_id(site_id, db)
    group_defaults = await _load_group_defaults(group_id, db) if group_id else None
    resolved = resolve_config(
        config_dict,
        org_defaults=org_defaults,
        group_defaults=group_defaults,
    )

    # Registered extensions may enrich the published payload in place
    # (e.g. A/B test data).
    for enricher in get_registry().config_enrichers:
        await enricher(site_id, db, resolved)

    outcome = await publish_site_config(str(site_id), resolved)
    if not outcome.success:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Publish failed: {outcome.error}",
        )

    return {
        "published": True,
        "path": outcome.path,
        "published_at": outcome.published_at,
    }
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_site_org_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Look up the organisation_id for a site."""
    stmt = select(Site.organisation_id).where(Site.id == site_id)
    return (await db.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def _get_site_group_id(site_id: uuid.UUID, db: AsyncSession) -> uuid.UUID | None:
    """Look up the site_group_id for a site."""
    stmt = select(Site.site_group_id).where(Site.id == site_id)
    return (await db.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def _load_org_defaults(organisation_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Load the org-level config defaults, or None if not set."""
    stmt = select(OrgConfig).where(OrgConfig.organisation_id == organisation_id)
    org_config = (await db.execute(stmt)).scalar_one_or_none()
    return None if org_config is None else orm_to_config_dict(org_config)
|
||||
|
||||
|
||||
async def _load_group_defaults(group_id: uuid.UUID, db: AsyncSession) -> dict | None:
    """Load the site-group-level config defaults, or None if not set."""
    stmt = select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    group_config = (await db.execute(stmt)).scalar_one_or_none()
    return None if group_config is None else orm_to_config_dict(group_config)
|
||||
125
apps/api/src/routers/consent.py
Normal file
125
apps/api/src/routers/consent.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.extensions.registry import get_registry
|
||||
from src.models.consent import ConsentRecord
|
||||
from src.models.site import Site
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.consent import (
|
||||
ConsentRecordCreate,
|
||||
ConsentRecordResponse,
|
||||
ConsentVerifyResponse,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
from src.services.pseudonymisation import pseudonymise
|
||||
|
||||
router = APIRouter(prefix="/consent", tags=["consent"])
|
||||
|
||||
|
||||
@router.post("/", response_model=ConsentRecordResponse, status_code=status.HTTP_201_CREATED)
async def record_consent(
    body: ConsentRecordCreate,
    request: Request,
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Record a consent event from the banner. Public endpoint (no auth required)."""
    # Pseudonymise IP and user agent with HMAC so the resulting values
    # cannot be reversed without the server-side secret.
    client_ip = request.client.host if request.client else ""
    user_agent = request.headers.get("user-agent", "")

    record = ConsentRecord(
        site_id=body.site_id,
        visitor_id=body.visitor_id,
        ip_hash=pseudonymise(client_ip),
        user_agent_hash=pseudonymise(user_agent),
        action=body.action,
        categories_accepted=body.categories_accepted,
        categories_rejected=body.categories_rejected,
        tc_string=body.tc_string,
        gcm_state=body.gcm_state,
        page_url=body.page_url,
        country_code=body.country_code,
        region_code=body.region_code,
    )
    db.add(record)
    # Flush + refresh so server-generated fields are populated before the
    # hooks run and the record is serialised into the response.
    await db.flush()
    await db.refresh(record)

    # Invoke any registered post-record hooks (EE consent receipts, etc.)
    for hook in get_registry().consent_record_hooks:
        await hook(db, record)

    return record
|
||||
|
||||
|
||||
async def _load_record_for_org(
    consent_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> ConsentRecord:
    """Load a consent record and enforce tenant isolation.

    The record's site must belong to the caller's organisation. A record
    from another tenant returns 404 rather than 403 so we don't leak
    existence across tenants.
    """
    query = (
        select(ConsentRecord)
        .join(Site, Site.id == ConsentRecord.site_id)
        .where(
            ConsentRecord.id == consent_id,
            Site.organisation_id == current_user.organisation_id,
            Site.deleted_at.is_(None),
        )
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Consent record not found",
        )
    return found
|
||||
|
||||
|
||||
@router.get("/{consent_id}", response_model=ConsentRecordResponse)
async def get_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> ConsentRecord:
    """Retrieve a consent record by ID.

    Requires authentication and tenant membership. Consent records
    contain PII-adjacent data (hashed IP, page URL, category decisions)
    and must not be readable by anyone holding a record UUID.

    Raises:
        HTTPException: 404 when the record does not exist or belongs to
            another organisation (see _load_record_for_org).
    """
    return await _load_record_for_org(consent_id, current_user, db)
|
||||
|
||||
|
||||
@router.get("/verify/{consent_id}", response_model=ConsentVerifyResponse)
async def verify_consent(
    consent_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Verify that a consent record exists (audit proof).

    Same tenant-scoped auth as :func:`get_consent` — proof of consent
    is only meaningful to the organisation that owns the site, and
    leaking existence to arbitrary callers enables enumeration.
    """
    record = await _load_record_for_org(consent_id, current_user, db)
    payload = {
        field: getattr(record, field)
        for field in (
            "id",
            "site_id",
            "visitor_id",
            "action",
            "categories_accepted",
            "consented_at",
        )
    }
    # Reaching this point means the tenant-scoped lookup succeeded.
    payload["valid"] = True
    return payload
|
||||
582
apps/api/src/routers/cookies.py
Normal file
582
apps/api/src/routers/cookies.py
Normal file
@@ -0,0 +1,582 @@
|
||||
"""Cookie category, cookie, and allow-list management endpoints."""
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.cookie import Cookie, CookieAllowListEntry, CookieCategory, KnownCookie
|
||||
from src.models.site import Site
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.cookie import (
|
||||
AllowListEntryCreate,
|
||||
AllowListEntryResponse,
|
||||
AllowListEntryUpdate,
|
||||
ClassificationResultResponse,
|
||||
ClassifySingleRequest,
|
||||
ClassifySiteResponse,
|
||||
CookieCategoryResponse,
|
||||
CookieCreate,
|
||||
CookieResponse,
|
||||
CookieUpdate,
|
||||
KnownCookieCreate,
|
||||
KnownCookieResponse,
|
||||
KnownCookieUpdate,
|
||||
ReviewStatus,
|
||||
)
|
||||
from src.services.classification import classify_single_cookie, classify_site_cookies
|
||||
from src.services.dependencies import get_current_user, require_role
|
||||
|
||||
router = APIRouter(prefix="/cookies", tags=["cookies"])
|
||||
|
||||
|
||||
# ── Cookie categories (read-only, seeded by migration) ──────────────
|
||||
|
||||
|
||||
@router.get("/categories", response_model=list[CookieCategoryResponse])
|
||||
async def list_categories(
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> list[CookieCategory]:
|
||||
"""List all cookie categories. Public endpoint used by banner and admin."""
|
||||
result = await db.execute(select(CookieCategory).order_by(CookieCategory.display_order))
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@router.get("/categories/{category_id}", response_model=CookieCategoryResponse)
|
||||
async def get_category(
|
||||
category_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> CookieCategory:
|
||||
"""Get a single cookie category by ID."""
|
||||
result = await db.execute(select(CookieCategory).where(CookieCategory.id == category_id))
|
||||
category = result.scalar_one_or_none()
|
||||
if not category:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Category not found")
|
||||
return category
|
||||
|
||||
|
||||
# ── Cookies per site ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_org_site(
    site_id: uuid.UUID,
    current_user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Return the site iff it is live and owned by the caller's organisation.

    Raises:
        HTTPException: 404 when the site is missing, soft-deleted, or
            belongs to another tenant (cross-tenant existence is hidden).
    """
    stmt = select(Site).where(
        Site.id == site_id,
        Site.organisation_id == current_user.organisation_id,
        Site.deleted_at.is_(None),
    )
    site = (await db.execute(stmt)).scalar_one_or_none()
    if site is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
    return site
|
||||
|
||||
|
||||
@router.get(
    "/sites/{site_id}",
    response_model=list[CookieResponse],
)
async def list_cookies(
    site_id: uuid.UUID,
    review_status: ReviewStatus | None = Query(None),
    category_id: uuid.UUID | None = Query(None),
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[Cookie]:
    """List a site's discovered cookies, optionally filtered by review status and category."""
    await _get_org_site(site_id, current_user, db)

    # Collect filters first, then build the statement in one go.
    conditions = [Cookie.site_id == site_id]
    if review_status:
        conditions.append(Cookie.review_status == review_status.value)
    if category_id:
        conditions.append(Cookie.category_id == category_id)

    stmt = select(Cookie).where(*conditions).order_by(Cookie.name)
    return list((await db.execute(stmt)).scalars().all())
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}",
    response_model=CookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_cookie(
    site_id: uuid.UUID,
    body: CookieCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Create a cookie record for a site (manual entry or from scanner).

    Raises:
        HTTPException: 404 if the site is not in the caller's organisation;
            400 if ``body.category_id`` references a nonexistent category.
    """
    await _get_org_site(site_id, current_user, db)

    # Validate category if provided so the FK can't fail at flush time.
    if body.category_id:
        cat = await db.execute(select(CookieCategory).where(CookieCategory.id == body.category_id))
        if not cat.scalar_one_or_none():
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    # Capture one timestamp so first_seen_at and last_seen_at are identical
    # on creation (two separate now() calls produced slightly different
    # values; the scanner's report endpoint also uses a single timestamp).
    now_iso = datetime.now(UTC).isoformat()
    cookie = Cookie(
        site_id=site_id,
        **body.model_dump(),
        first_seen_at=now_iso,
        last_seen_at=now_iso,
    )
    db.add(cookie)
    await db.flush()
    await db.refresh(cookie)
    return cookie
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/summary")
|
||||
async def cookie_summary(
|
||||
site_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> dict:
|
||||
"""Get a summary of cookies for a site (counts by status and category)."""
|
||||
await _get_org_site(site_id, current_user, db)
|
||||
|
||||
# Count by review status
|
||||
status_result = await db.execute(
|
||||
select(Cookie.review_status, func.count(Cookie.id))
|
||||
.where(Cookie.site_id == site_id)
|
||||
.group_by(Cookie.review_status)
|
||||
)
|
||||
by_status = {row[0]: row[1] for row in status_result.all()}
|
||||
|
||||
# Count by category
|
||||
cat_result = await db.execute(
|
||||
select(CookieCategory.slug, func.count(Cookie.id))
|
||||
.outerjoin(Cookie, Cookie.category_id == CookieCategory.id)
|
||||
.where(Cookie.site_id == site_id)
|
||||
.group_by(CookieCategory.slug)
|
||||
)
|
||||
by_category = {row[0]: row[1] for row in cat_result.all()}
|
||||
|
||||
# Uncategorised count
|
||||
uncat_result = await db.execute(
|
||||
select(func.count(Cookie.id)).where(Cookie.site_id == site_id, Cookie.category_id.is_(None))
|
||||
)
|
||||
uncategorised = uncat_result.scalar() or 0
|
||||
|
||||
return {
|
||||
"total": sum(by_status.values()),
|
||||
"by_status": by_status,
|
||||
"by_category": by_category,
|
||||
"uncategorised": uncategorised,
|
||||
}
|
||||
|
||||
|
||||
# ── Allow-list per site ──────────────────────────────────────────────
|
||||
# (Must be defined before {cookie_id} routes to avoid path conflicts)
|
||||
|
||||
|
||||
@router.get(
    "/sites/{site_id}/allow-list",
    response_model=list[AllowListEntryResponse],
)
async def list_allow_list(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> list[CookieAllowListEntry]:
    """Return the site's allow-list entries, ordered by name pattern."""
    await _get_org_site(site_id, current_user, db)

    stmt = (
        select(CookieAllowListEntry)
        .where(CookieAllowListEntry.site_id == site_id)
        .order_by(CookieAllowListEntry.name_pattern)
    )
    entries = (await db.execute(stmt)).scalars().all()
    return list(entries)
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}/allow-list",
    response_model=AllowListEntryResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_allow_list_entry(
    site_id: uuid.UUID,
    body: AllowListEntryCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Add a cookie pattern to a site's allow-list.

    Raises:
        HTTPException: 404 for a site outside the caller's organisation;
            400 when the category does not exist.
    """
    await _get_org_site(site_id, current_user, db)

    # Reject unknown categories up front so the FK never fails at flush.
    cat_stmt = select(CookieCategory).where(CookieCategory.id == body.category_id)
    if (await db.execute(cat_stmt)).scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )

    entry = CookieAllowListEntry(site_id=site_id, **body.model_dump())
    db.add(entry)
    await db.flush()
    await db.refresh(entry)
    return entry
|
||||
|
||||
|
||||
@router.patch(
    "/sites/{site_id}/allow-list/{entry_id}",
    response_model=AllowListEntryResponse,
)
async def update_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    body: AllowListEntryUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> CookieAllowListEntry:
    """Apply a partial update to an allow-list entry.

    Only fields explicitly set on the request body are written; a new
    ``category_id`` is validated against the categories table first.
    """
    await _get_org_site(site_id, current_user, db)

    entry_stmt = select(CookieAllowListEntry).where(
        CookieAllowListEntry.id == entry_id,
        CookieAllowListEntry.site_id == site_id,
    )
    entry = (await db.execute(entry_stmt)).scalar_one_or_none()
    if entry is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )

    updates = body.model_dump(exclude_unset=True)

    new_category = updates.get("category_id")
    if new_category is not None:
        cat_stmt = select(CookieCategory).where(CookieCategory.id == new_category)
        if (await db.execute(cat_stmt)).scalar_one_or_none() is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    for field, value in updates.items():
        setattr(entry, field, value)
    entry.updated_at = datetime.now(UTC)

    await db.flush()
    await db.refresh(entry)
    return entry
|
||||
|
||||
|
||||
@router.delete(
    "/sites/{site_id}/allow-list/{entry_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_allow_list_entry(
    site_id: uuid.UUID,
    entry_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Remove an entry from the site's allow-list (owner/admin only)."""
    await _get_org_site(site_id, current_user, db)

    stmt = select(CookieAllowListEntry).where(
        CookieAllowListEntry.id == entry_id,
        CookieAllowListEntry.site_id == site_id,
    )
    entry = (await db.execute(stmt)).scalar_one_or_none()
    if entry is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Allow-list entry not found",
        )

    await db.delete(entry)
|
||||
|
||||
|
||||
# ── Individual cookie by ID (must come after /summary and /allow-list) ──
|
||||
|
||||
|
||||
@router.get("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
|
||||
async def get_cookie(
|
||||
site_id: uuid.UUID,
|
||||
cookie_id: uuid.UUID,
|
||||
current_user: CurrentUser = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Cookie:
|
||||
"""Get a single cookie by ID."""
|
||||
await _get_org_site(site_id, current_user, db)
|
||||
|
||||
result = await db.execute(
|
||||
select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
|
||||
)
|
||||
cookie = result.scalar_one_or_none()
|
||||
if not cookie:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")
|
||||
return cookie
|
||||
|
||||
|
||||
@router.patch("/sites/{site_id}/{cookie_id}", response_model=CookieResponse)
async def update_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    body: CookieUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Cookie:
    """Partially update a cookie record (e.g. assign category, change review status).

    Only fields explicitly set on the request body are written; a new
    ``category_id`` must reference an existing category.
    """
    await _get_org_site(site_id, current_user, db)

    cookie_stmt = select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
    cookie = (await db.execute(cookie_stmt)).scalar_one_or_none()
    if cookie is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")

    updates = body.model_dump(exclude_unset=True)

    # Validate category if being changed (a None value clears it and
    # needs no lookup).
    new_category = updates.get("category_id")
    if new_category is not None:
        cat_stmt = select(CookieCategory).where(CookieCategory.id == new_category)
        if (await db.execute(cat_stmt)).scalar_one_or_none() is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    for field, value in updates.items():
        setattr(cookie, field, value)
    cookie.updated_at = datetime.now(UTC)

    await db.flush()
    await db.refresh(cookie)
    return cookie
|
||||
|
||||
|
||||
@router.delete(
    "/sites/{site_id}/{cookie_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_cookie(
    site_id: uuid.UUID,
    cookie_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a cookie record from a tenant-checked site (owner/admin only)."""
    await _get_org_site(site_id, current_user, db)

    stmt = select(Cookie).where(Cookie.id == cookie_id, Cookie.site_id == site_id)
    cookie = (await db.execute(stmt)).scalar_one_or_none()
    if cookie is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Cookie not found")

    await db.delete(cookie)
|
||||
|
||||
|
||||
# ── Known cookies database ──────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/known", response_model=list[KnownCookieResponse])
|
||||
async def list_known_cookies(
|
||||
vendor: str | None = Query(None, description="Filter by vendor name"),
|
||||
search: str | None = Query(None, description="Search by name pattern"),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_user: CurrentUser = Depends(get_current_user),
|
||||
) -> list[KnownCookie]:
|
||||
"""List known cookie patterns from the shared database."""
|
||||
query = select(KnownCookie).order_by(KnownCookie.name_pattern)
|
||||
if vendor:
|
||||
query = query.where(KnownCookie.vendor == vendor)
|
||||
if search:
|
||||
query = query.where(KnownCookie.name_pattern.ilike(f"%{search}%"))
|
||||
result = await db.execute(query)
|
||||
return list(result.scalars().all())
|
||||
|
||||
|
||||
@router.post(
    "/known",
    response_model=KnownCookieResponse,
    status_code=status.HTTP_201_CREATED,
)
async def create_known_cookie(
    body: KnownCookieCreate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Add a new pattern to the known cookies database (owner/admin only).

    Raises:
        HTTPException: 400 when ``body.category_id`` does not exist.
    """
    # Validate category before inserting.
    cat_stmt = select(CookieCategory).where(CookieCategory.id == body.category_id)
    if (await db.execute(cat_stmt)).scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid category_id",
        )

    known = KnownCookie(**body.model_dump())
    db.add(known)
    await db.flush()
    await db.refresh(known)
    return known
|
||||
|
||||
|
||||
@router.get("/known/{known_id}", response_model=KnownCookieResponse)
|
||||
async def get_known_cookie(
|
||||
known_id: uuid.UUID,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_user: CurrentUser = Depends(get_current_user),
|
||||
) -> KnownCookie:
|
||||
"""Get a single known cookie pattern by ID."""
|
||||
result = await db.execute(select(KnownCookie).where(KnownCookie.id == known_id))
|
||||
known = result.scalar_one_or_none()
|
||||
if not known:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="Known cookie not found",
|
||||
)
|
||||
return known
|
||||
|
||||
|
||||
@router.patch("/known/{known_id}", response_model=KnownCookieResponse)
async def update_known_cookie(
    known_id: uuid.UUID,
    body: KnownCookieUpdate,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> KnownCookie:
    """Partially update a known cookie pattern (owner/admin only).

    Only explicitly-set fields are written; a new ``category_id`` is
    validated against the categories table first.
    """
    stmt = select(KnownCookie).where(KnownCookie.id == known_id)
    known = (await db.execute(stmt)).scalar_one_or_none()
    if known is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )

    updates = body.model_dump(exclude_unset=True)
    new_category = updates.get("category_id")
    if new_category is not None:
        cat_stmt = select(CookieCategory).where(CookieCategory.id == new_category)
        if (await db.execute(cat_stmt)).scalar_one_or_none() is None:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid category_id",
            )

    for field, value in updates.items():
        setattr(known, field, value)
    known.updated_at = datetime.now(UTC)

    await db.flush()
    await db.refresh(known)
    return known
|
||||
|
||||
|
||||
@router.delete(
    "/known/{known_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
async def delete_known_cookie(
    known_id: uuid.UUID,
    _user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Delete a known cookie pattern (owner/admin only); 404 if absent."""
    stmt = select(KnownCookie).where(KnownCookie.id == known_id)
    known = (await db.execute(stmt)).scalar_one_or_none()
    if known is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Known cookie not found",
        )
    await db.delete(known)
|
||||
|
||||
|
||||
# ── Classification endpoints ────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}/classify",
    response_model=ClassifySiteResponse,
)
async def classify_cookies(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> ClassifySiteResponse:
    """Auto-classify a site's pending cookies against known patterns.

    Delegates to the classification service and reports per-cookie
    results plus matched/unmatched totals.
    """
    await _get_org_site(site_id, current_user, db)

    results = await classify_site_cookies(db, site_id, only_pending=True)
    matched_count = len([r for r in results if r.matched])

    result_payloads = [
        ClassificationResultResponse(
            cookie_name=r.cookie_name,
            cookie_domain=r.cookie_domain,
            category_id=r.category_id,
            category_slug=r.category_slug,
            vendor=r.vendor,
            description=r.description,
            match_source=r.match_source,
            matched=r.matched,
        )
        for r in results
    ]

    return ClassifySiteResponse(
        site_id=str(site_id),
        total=len(results),
        matched=matched_count,
        unmatched=len(results) - matched_count,
        results=result_payloads,
    )
|
||||
|
||||
|
||||
@router.post(
    "/sites/{site_id}/classify/preview",
    response_model=ClassificationResultResponse,
)
async def classify_preview(
    site_id: uuid.UUID,
    body: ClassifySingleRequest,
    current_user: CurrentUser = Depends(get_current_user),
    db: AsyncSession = Depends(get_db),
) -> ClassificationResultResponse:
    """Preview how a single cookie would be classified, without persisting anything."""
    await _get_org_site(site_id, current_user, db)

    outcome = await classify_single_cookie(db, site_id, body.cookie_name, body.cookie_domain)
    return ClassificationResultResponse(
        cookie_name=outcome.cookie_name,
        cookie_domain=outcome.cookie_domain,
        category_id=outcome.category_id,
        category_slug=outcome.category_slug,
        vendor=outcome.vendor,
        description=outcome.description,
        match_source=outcome.match_source,
        matched=outcome.matched,
    )
|
||||
69
apps/api/src/routers/org_config.py
Normal file
69
apps/api/src/routers/org_config.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Organisation-level default configuration endpoints.
|
||||
|
||||
Provides GET and PUT for the organisation's global config defaults.
|
||||
These defaults sit between system defaults and site config in the cascade.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.org_config import OrgConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.org_config import OrgConfigResponse, OrgConfigUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/org-config", tags=["organisations"])
|
||||
|
||||
|
||||
@router.get("/", response_model=OrgConfigResponse)
|
||||
async def get_org_config(
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> OrgConfig:
|
||||
"""Retrieve the organisation's global configuration defaults."""
|
||||
result = await db.execute(
|
||||
select(OrgConfig).where(OrgConfig.organisation_id == current_user.organisation_id)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config is None:
|
||||
# Auto-create an empty config row so the response is always valid
|
||||
config = OrgConfig(organisation_id=current_user.organisation_id)
|
||||
db.add(config)
|
||||
await db.flush()
|
||||
await db.refresh(config)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
@router.put("/", response_model=OrgConfigResponse)
|
||||
async def update_org_config(
|
||||
body: OrgConfigUpdate,
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> OrgConfig:
|
||||
"""Create or update the organisation's global configuration defaults.
|
||||
|
||||
Only non-None fields will override system defaults when resolving site config.
|
||||
"""
|
||||
result = await db.execute(
|
||||
select(OrgConfig).where(OrgConfig.organisation_id == current_user.organisation_id)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
|
||||
if config is None:
|
||||
config = OrgConfig(
|
||||
organisation_id=current_user.organisation_id,
|
||||
**body.model_dump(exclude_unset=True),
|
||||
)
|
||||
db.add(config)
|
||||
else:
|
||||
update_data = body.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(config, field, value)
|
||||
|
||||
await db.flush()
|
||||
await db.refresh(config)
|
||||
return config
|
||||
118
apps/api/src/routers/organisations.py
Normal file
118
apps/api/src/routers/organisations.py
Normal file
@@ -0,0 +1,118 @@
|
||||
import hmac
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.db import get_db
|
||||
from src.models.organisation import Organisation
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.organisation import (
|
||||
OrganisationCreate,
|
||||
OrganisationResponse,
|
||||
OrganisationUpdate,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/organisations", tags=["organisations"])
|
||||
|
||||
|
||||
def _require_bootstrap_token(
    x_admin_bootstrap_token: str | None = Header(default=None),
) -> None:
    """Gate organisation creation behind a static bootstrap token.

    The token is configured via ``ADMIN_BOOTSTRAP_TOKEN``. When unset
    (the default), the endpoint is disabled entirely — operators must
    explicitly opt in and should rotate or unset the value after their
    initial org is provisioned.

    Raises:
        HTTPException: 403 when the feature is disabled; 401 when the
            supplied header is missing or does not match.
    """
    expected = get_settings().admin_bootstrap_token
    if not expected:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=(
                "Organisation creation is disabled. Set ADMIN_BOOTSTRAP_TOKEN "
                "in the environment to enable it."
            ),
        )

    # compare_digest gives a constant-time comparison, so response
    # timing leaks nothing about how much of the token matched.
    supplied = x_admin_bootstrap_token
    if not supplied or not hmac.compare_digest(supplied, expected):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or missing admin bootstrap token",
        )
|
||||
|
||||
|
||||
@router.post("/", response_model=OrganisationResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_organisation(
|
||||
body: OrganisationCreate,
|
||||
db: AsyncSession = Depends(get_db),
|
||||
_: None = Depends(_require_bootstrap_token),
|
||||
) -> Organisation:
|
||||
"""Create a new organisation. Gated by ``X-Admin-Bootstrap-Token``.
|
||||
|
||||
See :func:`_require_bootstrap_token` for the gating semantics. Once
|
||||
your initial organisation exists, rotate or unset
|
||||
``ADMIN_BOOTSTRAP_TOKEN`` to disable further tenant creation.
|
||||
"""
|
||||
# Check slug uniqueness
|
||||
existing = await db.execute(select(Organisation).where(Organisation.slug == body.slug))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail=f"Organisation with slug '{body.slug}' already exists",
|
||||
)
|
||||
|
||||
org = Organisation(**body.model_dump())
|
||||
db.add(org)
|
||||
await db.flush()
|
||||
await db.refresh(org)
|
||||
return org
|
||||
|
||||
|
||||
@router.get("/me", response_model=OrganisationResponse)
|
||||
async def get_my_organisation(
|
||||
current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
|
||||
db: AsyncSession = Depends(get_db),
|
||||
) -> Organisation:
|
||||
"""Get the current user's organisation."""
|
||||
result = await db.execute(
|
||||
select(Organisation).where(
|
||||
Organisation.id == current_user.organisation_id,
|
||||
Organisation.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
org = result.scalar_one_or_none()
|
||||
if org is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")
|
||||
return org
|
||||
|
||||
|
||||
@router.patch("/me", response_model=OrganisationResponse)
async def update_my_organisation(
    body: OrganisationUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> Organisation:
    """Partially update the caller's organisation. Requires owner or admin role.

    Only fields explicitly set on the request body are written.
    """
    stmt = select(Organisation).where(
        Organisation.id == current_user.organisation_id,
        Organisation.deleted_at.is_(None),
    )
    org = (await db.execute(stmt)).scalar_one_or_none()
    if org is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Organisation not found")

    for field, value in body.model_dump(exclude_unset=True).items():
        setattr(org, field, value)

    await db.flush()
    await db.refresh(org)
    return org
|
||||
310
apps/api/src/routers/scanner.py
Normal file
310
apps/api/src/routers/scanner.py
Normal file
@@ -0,0 +1,310 @@
|
||||
"""Scanner and client-side cookie report endpoints.
|
||||
|
||||
Accepts cookie reports from the client-side reporter embedded in the banner
|
||||
bundle, upserts discovered cookies into the site's cookie inventory, and
|
||||
provides scan job management (trigger, list, detail, diff).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.cookie import Cookie
|
||||
from src.models.scan import ScanJob, ScanResult
|
||||
from src.models.site import Site
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.scanner import (
|
||||
CookieReportRequest,
|
||||
CookieReportResponse,
|
||||
ScanDiffResponse,
|
||||
ScanJobDetailResponse,
|
||||
ScanJobResponse,
|
||||
TriggerScanRequest,
|
||||
)
|
||||
from src.services.dependencies import get_current_user
|
||||
from src.services.scanner import (
|
||||
compute_scan_diff,
|
||||
create_scan_job,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/scanner", tags=["scanner"])
|
||||
|
||||
|
||||
# ── Client-side cookie report (public, no auth) ─────────────────────
|
||||
|
||||
|
||||
@router.post(
    "/report",
    response_model=CookieReportResponse,
    status_code=status.HTTP_202_ACCEPTED,
)
async def receive_cookie_report(
    body: CookieReportRequest,
    db: AsyncSession = Depends(get_db),
) -> CookieReportResponse:
    """Receive a cookie report from the client-side reporter.

    This is a public endpoint (no auth) since it's called from the banner
    script running on end-user browsers. The site_id acts as implicit auth.

    Upserts each reported cookie into the site inventory keyed on
    (name, domain, storage_type): existing rows get a refreshed
    ``last_seen_at``; unknown cookies are inserted with
    ``review_status='pending'``.

    Raises:
        HTTPException: 404 when the site does not exist or is soft-deleted.
    """
    # Verify site exists
    site_result = await db.execute(
        select(Site).where(
            Site.id == body.site_id,
            Site.deleted_at.is_(None),
        )
    )
    if site_result.scalar_one_or_none() is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )

    now_iso = datetime.now(UTC).isoformat()

    # Load the site's existing inventory once and index it, instead of
    # issuing one SELECT per reported cookie (the previous N+1 pattern).
    existing_result = await db.execute(select(Cookie).where(Cookie.site_id == body.site_id))
    by_key: dict[tuple, Cookie] = {
        (c.name, c.domain, c.storage_type): c for c in existing_result.scalars()
    }

    new_cookies = 0
    for reported in body.cookies:
        key = (reported.name, reported.domain, reported.storage_type)
        cookie = by_key.get(key)
        if cookie is not None:
            # Known cookie — just refresh its sighting timestamp.
            cookie.last_seen_at = now_iso
        else:
            cookie = Cookie(
                site_id=body.site_id,
                name=reported.name,
                domain=reported.domain,
                storage_type=reported.storage_type,
                path=reported.path,
                is_secure=reported.is_secure,
                same_site=reported.same_site,
                review_status="pending",
                first_seen_at=now_iso,
                last_seen_at=now_iso,
            )
            db.add(cookie)
            # Register the new row so a duplicate entry later in the same
            # report updates it rather than inserting a second row.
            by_key[key] = cookie
            new_cookies += 1

    await db.flush()

    return CookieReportResponse(
        accepted=True,
        cookies_received=len(body.cookies),
        new_cookies=new_cookies,
    )
|
||||
|
||||
|
||||
# ── Scan job management (authenticated) ─────────────────────────────
|
||||
|
||||
|
||||
async def _verify_site_access(
    site_id: uuid.UUID,
    user: CurrentUser,
    db: AsyncSession,
) -> Site:
    """Return the site iff it is live and owned by *user*'s organisation.

    Raises:
        HTTPException: 404 when the site is missing, soft-deleted, or
            belongs to another tenant.
    """
    stmt = select(Site).where(
        Site.id == site_id,
        Site.organisation_id == user.organisation_id,
        Site.deleted_at.is_(None),
    )
    site = (await db.execute(stmt)).scalar_one_or_none()
    if site is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site not found",
        )
    return site
|
||||
|
||||
|
||||
@router.post(
    "/scans",
    response_model=ScanJobResponse,
    status_code=status.HTTP_201_CREATED,
)
async def trigger_scan(
    body: TriggerScanRequest,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanJob:
    """Trigger a new cookie scan for a site.

    Creates a scan job in 'pending' state and dispatches it to the
    Celery worker queue for execution.

    Raises:
        HTTPException 404: site not found in the caller's organisation.
        HTTPException 409: a non-stale scan is already pending/running.
        HTTPException 503: the task queue is unreachable at dispatch time.
    """
    from src.services.scanner import complete_scan_job

    await _verify_site_access(body.site_id, user, db)

    # Check for an already-running scan
    active_result = await db.execute(
        select(ScanJob).where(
            ScanJob.site_id == body.site_id,
            ScanJob.status.in_(["pending", "running"]),
        )
    )
    active_jobs = list(active_result.scalars().all())

    now = datetime.now(UTC)
    stale_pending_cutoff = now - timedelta(minutes=5)
    stale_running_cutoff = now - timedelta(minutes=10)

    for active_job in active_jobs:
        # Timestamps are re-tagged as UTC for the comparison — assumes they
        # are stored as naive UTC; TODO confirm against the model definition.
        is_stale_pending = (
            active_job.status == "pending"
            and active_job.created_at.replace(tzinfo=UTC) < stale_pending_cutoff
        )
        is_stale_running = (
            active_job.status == "running"
            and active_job.started_at
            and active_job.started_at.replace(tzinfo=UTC) < stale_running_cutoff
        )
        if is_stale_pending or is_stale_running:
            logger.warning(
                "Failing stale %s scan job %s for site %s",
                active_job.status,
                active_job.id,
                body.site_id,
            )
            await complete_scan_job(
                db,
                active_job,
                error_message=(
                    f"Job was stale ({active_job.status} too long), superseded by new scan"
                ),
            )
        else:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail="A scan is already in progress for this site",
            )

    job = await create_scan_job(
        db,
        site_id=body.site_id,
        trigger="manual",
        max_pages=body.max_pages,
    )

    # Commit before dispatching to Celery so the worker can find the
    # job in the database immediately (avoids race condition).
    await db.commit()

    # Dispatch to Celery (import here to avoid import at module level
    # when Celery broker is unavailable during testing)
    try:
        from src.tasks.scanner import run_scan

        run_scan.delay(str(job.id), str(body.site_id))
    except Exception:
        logger.exception("Failed to dispatch scan job %s to Celery", job.id)
        # Fix: the job was already committed as 'pending' above. If we left
        # it that way, every retry would hit the 409 branch for up to five
        # minutes until the stale sweep reclaims it. Mark it failed so the
        # caller can retry immediately.
        await complete_scan_job(
            db,
            job,
            error_message="Dispatch to task queue failed",
        )
        await db.commit()
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail=(
                "Background task queue is unavailable — scan job"
                " created but cannot be processed. Please try again later."
            ),
        ) from None

    return job
|
||||
|
||||
|
||||
@router.get("/scans/site/{site_id}", response_model=list[ScanJobResponse])
async def list_scans(
    site_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
    limit: int = Query(default=20, ge=1, le=100),
    offset: int = Query(default=0, ge=0),
) -> list[ScanJob]:
    """Return a page of scan jobs for a site, newest first.

    404s (via the access helper) when the site is not in the caller's org.
    """
    await _verify_site_access(site_id, user, db)

    query = (
        select(ScanJob)
        .where(ScanJob.site_id == site_id)
        .order_by(ScanJob.created_at.desc())
        .limit(limit)
        .offset(offset)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.get("/scans/{scan_id}", response_model=ScanJobDetailResponse)
async def get_scan(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> dict:
    """Retrieve a scan job with its results.

    Returns a dict matching ScanJobDetailResponse: the job's own fields
    plus a ``results`` list of ScanResult rows ordered by cookie name.
    Raises a 404 if the job does not exist, or (via the access helper)
    if the job's site is outside the caller's organisation.
    """
    result = await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    job = result.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )

    # Verify org access — the job lookup above is not org-scoped, so this
    # is the step that prevents cross-tenant reads (it raises 404 itself).
    await _verify_site_access(job.site_id, user, db)

    # Load results
    results = await db.execute(
        select(ScanResult).where(ScanResult.scan_job_id == scan_id).order_by(ScanResult.cookie_name)
    )
    scan_results = list(results.scalars().all())

    # Assemble the response by hand so the per-result rows can be attached
    # alongside the job's scalar fields.
    return {
        "id": job.id,
        "site_id": job.site_id,
        "status": job.status,
        "trigger": job.trigger,
        "pages_scanned": job.pages_scanned,
        "pages_total": job.pages_total,
        "cookies_found": job.cookies_found,
        "error_message": job.error_message,
        "started_at": job.started_at,
        "completed_at": job.completed_at,
        "created_at": job.created_at,
        "updated_at": job.updated_at,
        "results": scan_results,
    }
|
||||
|
||||
|
||||
@router.get("/scans/{scan_id}/diff", response_model=ScanDiffResponse)
async def get_scan_diff(
    scan_id: uuid.UUID,
    db: AsyncSession = Depends(get_db),
    user: CurrentUser = Depends(get_current_user),
) -> ScanDiffResponse:
    """Compute the diff between a scan and the scan that preceded it."""
    lookup = await db.execute(select(ScanJob).where(ScanJob.id == scan_id))
    job = lookup.scalar_one_or_none()
    if job is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Scan job not found",
        )

    # Org-scope check: raises 404 if the job's site is not the caller's.
    await _verify_site_access(job.site_id, user, db)

    return await compute_scan_diff(db, current_scan_id=scan_id, site_id=job.site_id)
|
||||
101
apps/api/src/routers/site_group_config.py
Normal file
101
apps/api/src/routers/site_group_config.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Site-group-level default configuration endpoints.
|
||||
|
||||
Provides GET and PUT for a site group's config defaults.
|
||||
These defaults sit between org defaults and site config in the cascade.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site_group import SiteGroup
|
||||
from src.models.site_group_config import SiteGroupConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site_group_config import SiteGroupConfigResponse, SiteGroupConfigUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
|
||||
|
||||
|
||||
@router.get("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def get_site_group_config(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Return a site group's config defaults, creating an empty row on first read."""
    await _verify_group_ownership(group_id, current_user.organisation_id, db)

    lookup = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    config = lookup.scalar_one_or_none()
    if config is not None:
        return config

    # First read for this group: persist an empty config row so the
    # response model always has a concrete record to serialise.
    config = SiteGroupConfig(site_group_id=group_id)
    db.add(config)
    await db.flush()
    await db.refresh(config)
    return config
|
||||
|
||||
|
||||
@router.put("/{group_id}/config", response_model=SiteGroupConfigResponse)
async def update_site_group_config(
    group_id: uuid.UUID,
    body: SiteGroupConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> SiteGroupConfig:
    """Create or update configuration defaults for a site group.

    Upsert semantics: only fields explicitly set on the request body are
    applied (``model_dump(exclude_unset=True)``), so omitted fields keep
    their stored value on update and their model default on create.
    Raises a 404 (via the ownership helper) if the group does not belong
    to the caller's organisation.
    """
    await _verify_group_ownership(group_id, current_user.organisation_id, db)

    result = await db.execute(
        select(SiteGroupConfig).where(SiteGroupConfig.site_group_id == group_id)
    )
    config = result.scalar_one_or_none()

    if config is None:
        # No row yet — create one seeded with only the provided fields.
        config = SiteGroupConfig(
            site_group_id=group_id,
            **body.model_dump(exclude_unset=True),
        )
        db.add(config)
    else:
        # Partial update: touch only the fields present in the request.
        update_data = body.model_dump(exclude_unset=True)
        for field, value in update_data.items():
            setattr(config, field, value)

    await db.flush()
    await db.refresh(config)
    return config
|
||||
|
||||
|
||||
# -- Helpers ------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _verify_group_ownership(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> None:
    """Raise a 404 unless the (non-deleted) group belongs to the organisation."""
    query = select(SiteGroup).where(
        SiteGroup.id == group_id,
        SiteGroup.organisation_id == organisation_id,
        SiteGroup.deleted_at.is_(None),
    )
    found = (await db.execute(query)).scalar_one_or_none()
    if found is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site group not found",
        )
|
||||
198
apps/api/src/routers/site_groups.py
Normal file
198
apps/api/src/routers/site_groups.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
from src.models.site_group import SiteGroup
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site_group import SiteGroupCreate, SiteGroupResponse, SiteGroupUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/site-groups", tags=["site-groups"])
|
||||
|
||||
|
||||
@router.post("/", response_model=SiteGroupResponse, status_code=status.HTTP_201_CREATED)
async def create_site_group(
    body: SiteGroupCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Create a site group in the caller's organisation; 409 on duplicate name."""
    # Group names must be unique among the org's non-deleted groups.
    duplicate = await db.execute(
        select(SiteGroup).where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.name == body.name,
            SiteGroup.deleted_at.is_(None),
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Site group '{body.name}' already exists in this organisation",
        )

    group = SiteGroup(
        organisation_id=current_user.organisation_id,
        name=body.name,
        description=body.description,
    )
    db.add(group)
    await db.flush()
    await db.refresh(group)
    # A freshly created group cannot contain sites yet.
    return _to_response(group, site_count=0)
|
||||
|
||||
|
||||
@router.get("/", response_model=list[SiteGroupResponse])
async def list_site_groups(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[dict]:
    """List the organisation's site groups (name order) with per-group site counts."""
    # Aggregate active-site counts per group once, then outer-join so
    # empty groups still appear with a count of zero.
    counts = (
        select(
            Site.site_group_id,
            func.count(Site.id).label("cnt"),
        )
        .where(Site.deleted_at.is_(None))
        .group_by(Site.site_group_id)
        .subquery()
    )

    rows = await db.execute(
        select(SiteGroup, func.coalesce(counts.c.cnt, 0).label("site_count"))
        .outerjoin(counts, SiteGroup.id == counts.c.site_group_id)
        .where(
            SiteGroup.organisation_id == current_user.organisation_id,
            SiteGroup.deleted_at.is_(None),
        )
        .order_by(SiteGroup.name)
    )

    return [_to_response(row.SiteGroup, site_count=row.site_count) for row in rows.all()]
|
||||
|
||||
|
||||
@router.get("/{group_id}", response_model=SiteGroupResponse)
async def get_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Fetch one site group with its active-site count; 404 if not in the org."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)
    return _to_response(group, site_count=await _count_sites(group_id, db))
|
||||
|
||||
|
||||
@router.patch("/{group_id}", response_model=SiteGroupResponse)
async def update_site_group(
    group_id: uuid.UUID,
    body: SiteGroupUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> dict:
    """Partially update a group's name/description; renames are uniqueness-checked."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)

    changes = body.model_dump(exclude_unset=True)

    # A rename must not collide with another live group in the same org.
    if "name" in changes and changes["name"] != group.name:
        clash = await db.execute(
            select(SiteGroup).where(
                SiteGroup.organisation_id == current_user.organisation_id,
                SiteGroup.name == changes["name"],
                SiteGroup.deleted_at.is_(None),
                SiteGroup.id != group_id,
            )
        )
        if clash.scalar_one_or_none() is not None:
            raise HTTPException(
                status_code=status.HTTP_409_CONFLICT,
                detail=f"Site group '{changes['name']}' already exists",
            )

    for attr, new_value in changes.items():
        setattr(group, attr, new_value)

    await db.flush()
    await db.refresh(group)
    return _to_response(group, site_count=await _count_sites(group_id, db))
|
||||
|
||||
|
||||
@router.delete("/{group_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_site_group(
    group_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete a site group; its member sites become ungrouped."""
    group = await _get_org_group(group_id, current_user.organisation_id, db)

    # Detach every active site so none points at a deleted group.
    members = await db.execute(
        select(Site).where(
            Site.site_group_id == group_id,
            Site.deleted_at.is_(None),
        )
    )
    for member in members.scalars().all():
        member.site_group_id = None

    group.deleted_at = datetime.now(UTC)
    await db.flush()
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_org_group(
    group_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> SiteGroup:
    """Fetch a non-deleted site group scoped to an organisation, or raise 404."""
    query = select(SiteGroup).where(
        SiteGroup.id == group_id,
        SiteGroup.organisation_id == organisation_id,
        SiteGroup.deleted_at.is_(None),
    )
    group = (await db.execute(query)).scalar_one_or_none()
    if group is not None:
        return group
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Site group not found",
    )
|
||||
|
||||
|
||||
async def _count_sites(group_id: uuid.UUID, db: AsyncSession) -> int:
    """Return the number of non-deleted sites attached to a group."""
    query = select(func.count(Site.id)).where(
        Site.site_group_id == group_id,
        Site.deleted_at.is_(None),
    )
    return (await db.execute(query)).scalar_one()
|
||||
|
||||
|
||||
def _to_response(group: SiteGroup, *, site_count: int) -> dict:
    """Flatten a SiteGroup model into a response dict, appending site_count."""
    fields = ("id", "organisation_id", "name", "description", "created_at", "updated_at")
    payload = {name: getattr(group, name) for name in fields}
    payload["site_count"] = site_count
    return payload
|
||||
220
apps/api/src/routers/sites.py
Normal file
220
apps/api/src/routers/sites.py
Normal file
@@ -0,0 +1,220 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
from src.models.site_config import SiteConfig
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.site import (
|
||||
SiteConfigCreate,
|
||||
SiteConfigResponse,
|
||||
SiteConfigUpdate,
|
||||
SiteCreate,
|
||||
SiteResponse,
|
||||
SiteUpdate,
|
||||
)
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/sites", tags=["sites"])
|
||||
|
||||
|
||||
# ── Site CRUD ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/", response_model=SiteResponse, status_code=status.HTTP_201_CREATED)
async def create_site(
    body: SiteCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Create a site (plus its default config row) in the caller's organisation."""
    # Domains are unique among the org's non-deleted sites; 409 on a clash.
    duplicate = await db.execute(
        select(Site).where(
            Site.organisation_id == current_user.organisation_id,
            Site.domain == body.domain,
            Site.deleted_at.is_(None),
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Site with domain '{body.domain}' already exists in this organisation",
        )

    site = Site(
        organisation_id=current_user.organisation_id,
        domain=body.domain,
        display_name=body.display_name,
        site_group_id=body.site_group_id,
    )
    db.add(site)
    await db.flush()

    # Auto-create a default configuration row for the new site.
    db.add(SiteConfig(site_id=site.id))
    await db.flush()

    await db.refresh(site)
    return site
|
||||
|
||||
|
||||
@router.get("/", response_model=list[SiteResponse])
async def list_sites(
    site_group_id: uuid.UUID | None = Query(default=None),
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[Site]:
    """List active sites in the organisation, optionally restricted to one group."""
    conditions = [
        Site.organisation_id == current_user.organisation_id,
        Site.deleted_at.is_(None),
    ]
    if site_group_id is not None:
        conditions.append(Site.site_group_id == site_group_id)
    rows = await db.execute(select(Site).where(*conditions).order_by(Site.domain))
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.get("/{site_id}", response_model=SiteResponse)
async def get_site(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Fetch a single site by ID; 404 unless it belongs to the caller's org."""
    return await _get_org_site(site_id, current_user.organisation_id, db)
|
||||
|
||||
|
||||
@router.patch("/{site_id}", response_model=SiteResponse)
async def update_site(
    site_id: uuid.UUID,
    body: SiteUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Site:
    """Apply a partial update to a site (only fields set on the request)."""
    site = await _get_org_site(site_id, current_user.organisation_id, db)

    for attr, new_value in body.model_dump(exclude_unset=True).items():
        setattr(site, attr, new_value)

    await db.flush()
    await db.refresh(site)
    return site
|
||||
|
||||
|
||||
@router.delete("/{site_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_site(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete a site by stamping deleted_at; the row itself is retained."""
    site = await _get_org_site(site_id, current_user.organisation_id, db)
    site.deleted_at = datetime.now(UTC)
    await db.flush()
|
||||
|
||||
|
||||
# ── Site config CRUD ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/{site_id}/config", response_model=SiteConfigResponse)
async def get_site_config(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Return the stored configuration for a site; 404 when none exists."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    config = lookup.scalar_one_or_none()
    if config is not None:
        return config
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail="Site configuration not found. Create one first.",
    )
|
||||
|
||||
|
||||
@router.put("/{site_id}/config", response_model=SiteConfigResponse)
async def create_or_replace_site_config(
    site_id: uuid.UUID,
    body: SiteConfigCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Full upsert: overwrite every config field, creating the row if absent."""
    await _get_org_site(site_id, current_user.organisation_id, db)

    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    config = lookup.scalar_one_or_none()

    if config is None:
        # No existing row — create one from the full request payload.
        config = SiteConfig(site_id=site_id, **body.model_dump())
        db.add(config)
    else:
        # PUT semantics: every field is replaced, including ones the
        # client left at their defaults (model_dump() yields all fields).
        for attr, new_value in body.model_dump().items():
            setattr(config, attr, new_value)

    await db.flush()
    await db.refresh(config)
    return config
|
||||
|
||||
|
||||
@router.patch("/{site_id}/config", response_model=SiteConfigResponse)
async def update_site_config(
    site_id: uuid.UUID,
    body: SiteConfigUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> SiteConfig:
    """Apply a partial config update; 404 when the site has no config row."""
    await _get_org_site(site_id, current_user.organisation_id, db)

    lookup = await db.execute(select(SiteConfig).where(SiteConfig.site_id == site_id))
    config = lookup.scalar_one_or_none()
    if config is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Site configuration not found. Create one first.",
        )

    # PATCH semantics: only fields explicitly present in the request change.
    for attr, new_value in body.model_dump(exclude_unset=True).items():
        setattr(config, attr, new_value)

    await db.flush()
    await db.refresh(config)
    return config
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _get_org_site(
    site_id: uuid.UUID,
    organisation_id: uuid.UUID,
    db: AsyncSession,
) -> Site:
    """Fetch a non-deleted site scoped to an organisation, or raise 404."""
    query = select(Site).where(
        Site.id == site_id,
        Site.organisation_id == organisation_id,
        Site.deleted_at.is_(None),
    )
    site = (await db.execute(query)).scalar_one_or_none()
    if site is not None:
        return site
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
|
||||
195
apps/api/src/routers/translations.py
Normal file
195
apps/api/src/routers/translations.py
Normal file
@@ -0,0 +1,195 @@
|
||||
"""Translation management endpoints.
|
||||
|
||||
CRUD for per-site, per-locale translation strings used by the banner script.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
from src.models.translation import Translation
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.translation import TranslationCreate, TranslationResponse, TranslationUpdate
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/sites/{site_id}/translations", tags=["translations"])
|
||||
|
||||
|
||||
async def _get_org_site(site_id: uuid.UUID, organisation_id: uuid.UUID, db: AsyncSession) -> Site:
    """Resolve a non-deleted site in the given organisation, or raise 404."""
    query = select(Site).where(
        Site.id == site_id,
        Site.organisation_id == organisation_id,
        Site.deleted_at.is_(None),
    )
    site = (await db.execute(query)).scalar_one_or_none()
    if site is not None:
        return site
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Site not found")
|
||||
|
||||
|
||||
@router.get("/", response_model=list[TranslationResponse])
async def list_translations(
    site_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[Translation]:
    """List every locale translation stored for a site, ordered by locale."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    query = (
        select(Translation)
        .where(Translation.site_id == site_id)
        .order_by(Translation.locale)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.get("/{locale}", response_model=TranslationResponse)
async def get_translation(
    site_id: uuid.UUID,
    locale: str,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Fetch the translation for one (site, locale) pair; 404 if absent."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == locale,
        )
    )
    translation = lookup.scalar_one_or_none()
    if translation is not None:
        return translation
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"No translation found for locale '{locale}'",
    )
|
||||
|
||||
|
||||
@router.post("/", response_model=TranslationResponse, status_code=status.HTTP_201_CREATED)
async def create_translation(
    site_id: uuid.UUID,
    body: TranslationCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Add a translation for a locale the site does not yet have; 409 on duplicate."""
    await _get_org_site(site_id, current_user.organisation_id, db)

    # One translation row per (site, locale) pair.
    duplicate = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == body.locale,
        )
    )
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"Translation for locale '{body.locale}' already exists",
        )

    translation = Translation(
        site_id=site_id,
        locale=body.locale,
        strings=body.strings,
    )
    db.add(translation)
    await db.flush()
    await db.refresh(translation)
    return translation
|
||||
|
||||
|
||||
@router.put("/{locale}", response_model=TranslationResponse)
async def update_translation(
    site_id: uuid.UUID,
    locale: str,
    body: TranslationUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor")),
    db: AsyncSession = Depends(get_db),
) -> Translation:
    """Replace the full strings payload of an existing locale; 404 if absent."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == locale,
        )
    )
    translation = lookup.scalar_one_or_none()
    if translation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )

    translation.strings = body.strings
    await db.flush()
    await db.refresh(translation)
    return translation
|
||||
|
||||
|
||||
@router.delete("/{locale}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_translation(
    site_id: uuid.UUID,
    locale: str,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Hard-delete one locale's translation row; 404 if it does not exist."""
    await _get_org_site(site_id, current_user.organisation_id, db)
    lookup = await db.execute(
        select(Translation).where(
            Translation.site_id == site_id,
            Translation.locale == locale,
        )
    )
    translation = lookup.scalar_one_or_none()
    if translation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"No translation found for locale '{locale}'",
        )
    await db.delete(translation)
    await db.flush()
|
||||
|
||||
|
||||
# ── Public endpoint for the banner script ────────────────────────────
|
||||
|
||||
public_router = APIRouter(prefix="/translations", tags=["translations"])
|
||||
|
||||
|
||||
@public_router.get("/{site_id}/{locale}")
async def get_public_translation(
    site_id: uuid.UUID,
    locale: str,
    db: AsyncSession = Depends(get_db),
) -> dict[str, str]:
    """Public endpoint: return translation strings for the banner script.

    No auth required. Returns the raw strings dict for a given site and locale.
    Returns 404 if no translation exists (banner falls back to English defaults).
    """
    # Join against Site so translations for inactive or soft-deleted
    # sites are never served publicly.
    query = (
        select(Translation)
        .join(Site)
        .where(
            Translation.site_id == site_id,
            Translation.locale == locale,
            Site.is_active.is_(True),
            Site.deleted_at.is_(None),
        )
    )
    translation = (await db.execute(query)).scalar_one_or_none()
    if translation is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Translation not found",
        )
    return translation.strings
|
||||
136
apps/api/src/routers/users.py
Normal file
136
apps/api/src/routers/users.py
Normal file
@@ -0,0 +1,136 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.user import User
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.schemas.user import UserCreate, UserResponse, UserUpdate
|
||||
from src.services.auth import hash_password
|
||||
from src.services.dependencies import require_role
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
|
||||
|
||||
@router.post("/", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
async def create_user(
    body: UserCreate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Invite/create a new user within the current organisation."""
    # Reject duplicate emails up front. The check is across ALL organisations,
    # i.e. emails are treated as globally unique.
    # NOTE(review): check-then-insert is racy under concurrent requests —
    # presumably a unique index on User.email backs this up; confirm in the
    # model/migration.
    duplicate = await db.execute(select(User).where(User.email == body.email))
    if duplicate.scalar_one_or_none() is not None:
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail=f"User with email '{body.email}' already exists",
        )

    new_user = User(
        organisation_id=current_user.organisation_id,
        email=body.email,
        password_hash=hash_password(body.password),
        full_name=body.full_name,
        role=body.role,
    )
    db.add(new_user)
    # flush + refresh so generated fields (id, timestamps) are populated
    # before the response model is built.
    await db.flush()
    await db.refresh(new_user)
    return new_user
|
||||
|
||||
|
||||
@router.get("/", response_model=list[UserResponse])
async def list_users(
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> list[User]:
    """List all active users in the current organisation."""
    # Soft-deleted users (deleted_at set) are excluded; oldest accounts first.
    query = (
        select(User)
        .where(
            User.organisation_id == current_user.organisation_id,
            User.deleted_at.is_(None),
        )
        .order_by(User.created_at)
    )
    rows = await db.execute(query)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
@router.get("/{user_id}", response_model=UserResponse)
async def get_user(
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin", "editor", "viewer")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Get a specific user by ID within the current organisation."""
    # Scoping by organisation_id means a cross-tenant ID probe yields the same
    # 404 as a missing user — no information leak.
    query = select(User).where(
        User.id == user_id,
        User.organisation_id == current_user.organisation_id,
        User.deleted_at.is_(None),
    )
    user = (await db.execute(query)).scalar_one_or_none()
    if user is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    return user
|
||||
|
||||
|
||||
@router.patch("/{user_id}", response_model=UserResponse)
async def update_user(
    user_id: uuid.UUID,
    body: UserUpdate,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> User:
    """Update a user's name or role. Requires owner or admin."""
    query = select(User).where(
        User.id == user_id,
        User.organisation_id == current_user.organisation_id,
        User.deleted_at.is_(None),
    )
    user = (await db.execute(query)).scalar_one_or_none()
    if user is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    # PATCH semantics: only fields the client actually sent are applied.
    for field_name, new_value in body.model_dump(exclude_unset=True).items():
        setattr(user, field_name, new_value)

    await db.flush()
    await db.refresh(user)
    return user
|
||||
|
||||
|
||||
@router.delete("/{user_id}", status_code=status.HTTP_204_NO_CONTENT)
async def deactivate_user(
    user_id: uuid.UUID,
    current_user: CurrentUser = Depends(require_role("owner", "admin")),
    db: AsyncSession = Depends(get_db),
) -> None:
    """Soft-delete (deactivate) a user. Requires owner or admin."""
    # Guard: an admin must not be able to lock themselves out.
    if user_id == current_user.id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot deactivate yourself",
        )

    query = select(User).where(
        User.id == user_id,
        User.organisation_id == current_user.organisation_id,
        User.deleted_at.is_(None),
    )
    target = (await db.execute(query)).scalar_one_or_none()
    if target is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")

    # Soft delete: stamp deleted_at instead of removing the row, so the account
    # history survives for auditing. Already-deactivated users 404 above.
    target.deleted_at = datetime.now(UTC)
    await db.flush()
|
||||
0
apps/api/src/schemas/__init__.py
Normal file
0
apps/api/src/schemas/__init__.py
Normal file
45
apps/api/src/schemas/auth.py
Normal file
45
apps/api/src/schemas/auth.py
Normal file
@@ -0,0 +1,45 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, EmailStr
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
    """Credentials submitted to the login endpoint."""

    email: EmailStr
    password: str
|
||||
|
||||
|
||||
class RefreshRequest(BaseModel):
    """Body for exchanging a refresh token for a new token pair."""

    refresh_token: str
|
||||
|
||||
|
||||
class TokenResponse(BaseModel):
    """Bearer token pair returned after a successful login or refresh."""

    access_token: str
    refresh_token: str
    token_type: str = "bearer"
    expires_in: int  # access-token lifetime — presumably seconds; confirm against issuer
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
    """Decoded JWT claims carried by both access and refresh tokens."""

    sub: str  # user ID
    org_id: str  # organisation ID
    role: str  # user role
    exp: datetime  # expiry (standard JWT "exp" claim)
    iat: datetime  # issued-at (standard JWT "iat" claim)
    type: str = "access"  # "access" or "refresh"
|
||||
|
||||
|
||||
class CurrentUser(BaseModel):
    """Represents the authenticated user extracted from a JWT."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    role: str

    def has_role(self, *roles: str) -> bool:
        """Return True when this user's role is one of *roles*."""
        return self.role in roles

    @property
    def is_admin(self) -> bool:
        """Return True for the elevated roles ("owner" or "admin")."""
        return self.has_role("owner", "admin")
|
||||
56
apps/api/src/schemas/compliance.py
Normal file
56
apps/api/src/schemas/compliance.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Pydantic schemas for compliance check results."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class Severity(StrEnum):
|
||||
CRITICAL = "critical"
|
||||
WARNING = "warning"
|
||||
INFO = "info"
|
||||
|
||||
|
||||
class Framework(StrEnum):
|
||||
GDPR = "gdpr"
|
||||
CNIL = "cnil"
|
||||
CCPA = "ccpa"
|
||||
EPRIVACY = "eprivacy"
|
||||
LGPD = "lgpd"
|
||||
|
||||
|
||||
class ComplianceIssue(BaseModel):
|
||||
"""A single compliance issue found during a check."""
|
||||
|
||||
rule_id: str
|
||||
severity: Severity
|
||||
message: str
|
||||
recommendation: str
|
||||
|
||||
|
||||
class FrameworkResult(BaseModel):
|
||||
"""Compliance result for a single regulatory framework."""
|
||||
|
||||
framework: Framework
|
||||
score: int = Field(ge=0, le=100, description="Compliance score (0-100)")
|
||||
status: str = Field(description="Overall status: compliant, partial, non_compliant")
|
||||
issues: list[ComplianceIssue] = Field(default_factory=list)
|
||||
rules_checked: int = 0
|
||||
rules_passed: int = 0
|
||||
|
||||
|
||||
class ComplianceCheckRequest(BaseModel):
|
||||
"""Request body for compliance checks."""
|
||||
|
||||
frameworks: list[Framework] | None = Field(
|
||||
default=None,
|
||||
description="Frameworks to check. If null, all frameworks are checked.",
|
||||
)
|
||||
|
||||
|
||||
class ComplianceCheckResponse(BaseModel):
|
||||
"""Full compliance check response for a site."""
|
||||
|
||||
site_id: str
|
||||
results: list[FrameworkResult]
|
||||
overall_score: int = Field(ge=0, le=100, description="Weighted average across all frameworks")
|
||||
62
apps/api/src/schemas/consent.py
Normal file
62
apps/api/src/schemas/consent.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ConsentAction(StrEnum):
    """Kind of consent event recorded from the banner."""

    ACCEPT_ALL = "accept_all"
    REJECT_ALL = "reject_all"
    CUSTOM = "custom"  # per-category selection (see categories_accepted/_rejected)
    WITHDRAW = "withdraw"  # previously granted consent revoked
|
||||
|
||||
|
||||
class ConsentRecordCreate(BaseModel):
|
||||
"""Payload sent by the banner when a consent event occurs."""
|
||||
|
||||
site_id: uuid.UUID
|
||||
visitor_id: str = Field(min_length=1, max_length=255)
|
||||
action: ConsentAction
|
||||
categories_accepted: list[str]
|
||||
categories_rejected: list[str] | None = None
|
||||
tc_string: str | None = None
|
||||
gcm_state: dict | None = None
|
||||
gpp_string: str | None = None
|
||||
gpc_detected: bool | None = None
|
||||
gpc_honoured: bool | None = None
|
||||
page_url: str | None = None
|
||||
country_code: str | None = Field(default=None, max_length=5)
|
||||
region_code: str | None = Field(default=None, max_length=10)
|
||||
|
||||
|
||||
class ConsentRecordResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
site_id: uuid.UUID
|
||||
visitor_id: str
|
||||
action: str
|
||||
categories_accepted: list
|
||||
categories_rejected: list | None = None
|
||||
tc_string: str | None = None
|
||||
gcm_state: dict | None = None
|
||||
gpp_string: str | None = None
|
||||
gpc_detected: bool | None = None
|
||||
gpc_honoured: bool | None = None
|
||||
page_url: str | None = None
|
||||
country_code: str | None = None
|
||||
region_code: str | None = None
|
||||
consented_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ConsentVerifyResponse(BaseModel):
|
||||
"""Audit proof that a consent record exists."""
|
||||
|
||||
id: uuid.UUID
|
||||
site_id: uuid.UUID
|
||||
visitor_id: str
|
||||
action: str
|
||||
categories_accepted: list
|
||||
consented_at: datetime
|
||||
valid: bool = True
|
||||
210
apps/api/src/schemas/cookie.py
Normal file
210
apps/api/src/schemas/cookie.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Pydantic schemas for cookie categories, cookies, and allow-list entries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# ─── Cookie category schemas ───
|
||||
|
||||
|
||||
class CookieCategoryResponse(BaseModel):
|
||||
"""Response schema for a cookie category."""
|
||||
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
slug: str
|
||||
description: str | None = None
|
||||
is_essential: bool
|
||||
display_order: int
|
||||
tcf_purpose_ids: list[int] | None = None
|
||||
gcm_consent_types: list[str] | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Storage type enum ───
|
||||
|
||||
|
||||
class StorageType(StrEnum):
|
||||
"""Type of browser storage used by the cookie/tracker."""
|
||||
|
||||
cookie = "cookie"
|
||||
local_storage = "local_storage"
|
||||
session_storage = "session_storage"
|
||||
indexed_db = "indexed_db"
|
||||
|
||||
|
||||
# ─── Review status enum ───
|
||||
|
||||
|
||||
class ReviewStatus(StrEnum):
|
||||
"""Review status for a discovered cookie."""
|
||||
|
||||
pending = "pending"
|
||||
approved = "approved"
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
# ─── Cookie schemas ───
|
||||
|
||||
|
||||
class CookieCreate(BaseModel):
|
||||
"""Schema for creating a cookie record (typically from scanner/reporter)."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
domain: str = Field(..., min_length=1, max_length=255)
|
||||
storage_type: StorageType = StorageType.cookie
|
||||
category_id: uuid.UUID | None = None
|
||||
description: str | None = None
|
||||
vendor: str | None = Field(None, max_length=255)
|
||||
path: str | None = Field(None, max_length=500)
|
||||
max_age_seconds: int | None = None
|
||||
is_http_only: bool | None = None
|
||||
is_secure: bool | None = None
|
||||
same_site: str | None = Field(None, max_length=10)
|
||||
|
||||
|
||||
class CookieUpdate(BaseModel):
|
||||
"""Schema for updating a cookie record."""
|
||||
|
||||
category_id: uuid.UUID | None = None
|
||||
description: str | None = None
|
||||
vendor: str | None = Field(None, max_length=255)
|
||||
review_status: ReviewStatus | None = None
|
||||
|
||||
|
||||
class CookieResponse(BaseModel):
    """Response schema for a cookie."""

    id: uuid.UUID
    site_id: uuid.UUID
    category_id: uuid.UUID | None = None
    name: str
    domain: str
    storage_type: str
    description: str | None = None
    vendor: str | None = None
    path: str | None = None
    max_age_seconds: int | None = None
    is_http_only: bool | None = None
    is_secure: bool | None = None
    same_site: str | None = None
    review_status: str
    # Fix: these are timestamps like every other *_at field in this module.
    # They were declared `str | None`, which breaks from_attributes validation
    # against a datetime ORM attribute (pydantic v2 does not coerce datetime
    # to str in lax mode).
    first_seen_at: datetime | None = None
    last_seen_at: datetime | None = None
    created_at: datetime
    updated_at: datetime

    model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Allow-list schemas ───
|
||||
|
||||
|
||||
class AllowListEntryCreate(BaseModel):
|
||||
"""Schema for adding a cookie to the allow-list."""
|
||||
|
||||
name_pattern: str = Field(..., min_length=1, max_length=255)
|
||||
domain_pattern: str = Field(..., min_length=1, max_length=255)
|
||||
category_id: uuid.UUID
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class AllowListEntryUpdate(BaseModel):
|
||||
"""Schema for updating an allow-list entry."""
|
||||
|
||||
category_id: uuid.UUID | None = None
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class AllowListEntryResponse(BaseModel):
|
||||
"""Response schema for an allow-list entry."""
|
||||
|
||||
id: uuid.UUID
|
||||
site_id: uuid.UUID
|
||||
category_id: uuid.UUID
|
||||
name_pattern: str
|
||||
domain_pattern: str
|
||||
description: str | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Known cookie schemas ───
|
||||
|
||||
|
||||
class KnownCookieCreate(BaseModel):
|
||||
"""Schema for creating a known cookie pattern."""
|
||||
|
||||
name_pattern: str = Field(..., min_length=1, max_length=255)
|
||||
domain_pattern: str = Field(..., min_length=1, max_length=255)
|
||||
category_id: uuid.UUID
|
||||
vendor: str | None = Field(None, max_length=255)
|
||||
description: str | None = None
|
||||
is_regex: bool = False
|
||||
|
||||
|
||||
class KnownCookieUpdate(BaseModel):
|
||||
"""Schema for updating a known cookie pattern."""
|
||||
|
||||
category_id: uuid.UUID | None = None
|
||||
vendor: str | None = Field(None, max_length=255)
|
||||
description: str | None = None
|
||||
is_regex: bool | None = None
|
||||
|
||||
|
||||
class KnownCookieResponse(BaseModel):
|
||||
"""Response schema for a known cookie pattern."""
|
||||
|
||||
id: uuid.UUID
|
||||
name_pattern: str
|
||||
domain_pattern: str
|
||||
category_id: uuid.UUID
|
||||
vendor: str | None = None
|
||||
description: str | None = None
|
||||
is_regex: bool
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ─── Classification schemas ───
|
||||
|
||||
|
||||
class ClassificationResultResponse(BaseModel):
|
||||
"""Response for a single cookie classification result."""
|
||||
|
||||
cookie_name: str
|
||||
cookie_domain: str
|
||||
category_id: uuid.UUID | None = None
|
||||
category_slug: str | None = None
|
||||
vendor: str | None = None
|
||||
description: str | None = None
|
||||
match_source: str
|
||||
matched: bool
|
||||
|
||||
|
||||
class ClassifySiteResponse(BaseModel):
|
||||
"""Response for classifying all cookies on a site."""
|
||||
|
||||
site_id: str
|
||||
total: int
|
||||
matched: int
|
||||
unmatched: int
|
||||
results: list[ClassificationResultResponse]
|
||||
|
||||
|
||||
class ClassifySingleRequest(BaseModel):
|
||||
"""Request to classify a single cookie (preview/test)."""
|
||||
|
||||
cookie_name: str = Field(..., min_length=1, max_length=255)
|
||||
cookie_domain: str = Field(..., min_length=1, max_length=255)
|
||||
61
apps/api/src/schemas/org_config.py
Normal file
61
apps/api/src/schemas/org_config.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.schemas.site import BlockingMode
|
||||
|
||||
|
||||
class OrgConfigUpdate(BaseModel):
|
||||
"""Update (or create) organisation-level default configuration.
|
||||
|
||||
All fields are optional — only non-None values override the system defaults.
|
||||
"""
|
||||
|
||||
blocking_mode: BlockingMode | None = None
|
||||
regional_modes: dict | None = None
|
||||
tcf_enabled: bool | None = None
|
||||
tcf_publisher_cc: str | None = Field(default=None, max_length=2)
|
||||
gpp_enabled: bool | None = None
|
||||
gpp_supported_apis: list[str] | None = None
|
||||
gpc_enabled: bool | None = None
|
||||
gpc_jurisdictions: list[str] | None = None
|
||||
gpc_global_honour: bool | None = None
|
||||
gcm_enabled: bool | None = None
|
||||
gcm_default: dict | None = None
|
||||
shopify_privacy_enabled: bool | None = None
|
||||
banner_config: dict | None = None
|
||||
privacy_policy_url: str | None = None
|
||||
terms_url: str | None = None
|
||||
scan_schedule_cron: str | None = None
|
||||
scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
|
||||
consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
|
||||
consent_retention_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class OrgConfigResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
organisation_id: uuid.UUID
|
||||
blocking_mode: str | None
|
||||
regional_modes: dict | None
|
||||
tcf_enabled: bool | None
|
||||
tcf_publisher_cc: str | None
|
||||
gpp_enabled: bool | None
|
||||
gpp_supported_apis: list[str] | None
|
||||
gpc_enabled: bool | None
|
||||
gpc_jurisdictions: list[str] | None
|
||||
gpc_global_honour: bool | None
|
||||
gcm_enabled: bool | None
|
||||
gcm_default: dict | None
|
||||
shopify_privacy_enabled: bool | None
|
||||
banner_config: dict | None
|
||||
privacy_policy_url: str | None
|
||||
terms_url: str | None
|
||||
scan_schedule_cron: str | None
|
||||
scan_max_pages: int | None
|
||||
consent_expiry_days: int | None
|
||||
consent_retention_days: int | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
29
apps/api/src/schemas/organisation.py
Normal file
29
apps/api/src/schemas/organisation.py
Normal file
@@ -0,0 +1,29 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class OrganisationCreate(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
slug: str = Field(min_length=1, max_length=100, pattern=r"^[a-z0-9-]+$")
|
||||
contact_email: str | None = None
|
||||
billing_plan: str = "free"
|
||||
|
||||
|
||||
class OrganisationUpdate(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
contact_email: str | None = None
|
||||
billing_plan: str | None = None
|
||||
|
||||
|
||||
class OrganisationResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
name: str
|
||||
slug: str
|
||||
contact_email: str | None
|
||||
billing_plan: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
142
apps/api/src/schemas/scanner.py
Normal file
142
apps/api/src/schemas/scanner.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""Pydantic schemas for scanner and client-side cookie reports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ScanStatus(StrEnum):
|
||||
PENDING = "pending"
|
||||
RUNNING = "running"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class ScanTrigger(StrEnum):
|
||||
MANUAL = "manual"
|
||||
SCHEDULED = "scheduled"
|
||||
CLIENT_REPORT = "client_report"
|
||||
|
||||
|
||||
# ── Client-side cookie report ────────────────────────────────────────
|
||||
|
||||
|
||||
class ReportedCookie(BaseModel):
|
||||
"""A single cookie/storage item reported by the client-side reporter."""
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=255)
|
||||
domain: str = Field(..., min_length=1, max_length=255)
|
||||
storage_type: str = Field(default="cookie", max_length=30)
|
||||
value_length: int = Field(default=0, ge=0)
|
||||
path: str | None = None
|
||||
is_secure: bool | None = None
|
||||
same_site: str | None = None
|
||||
script_source: str | None = None
|
||||
|
||||
|
||||
class CookieReportRequest(BaseModel):
|
||||
"""Payload from the client-side cookie reporter."""
|
||||
|
||||
site_id: uuid.UUID
|
||||
page_url: str = Field(..., max_length=2000)
|
||||
cookies: list[ReportedCookie] = Field(..., max_length=500)
|
||||
collected_at: datetime
|
||||
user_agent: str = Field(default="", max_length=500)
|
||||
|
||||
|
||||
class CookieReportResponse(BaseModel):
|
||||
"""Acknowledgement response for a cookie report."""
|
||||
|
||||
accepted: bool = True
|
||||
cookies_received: int
|
||||
new_cookies: int = 0
|
||||
|
||||
|
||||
# ── Scan job schemas ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class ScanResultResponse(BaseModel):
|
||||
"""A single scan result — a cookie found on a specific page."""
|
||||
|
||||
id: uuid.UUID
|
||||
scan_job_id: uuid.UUID
|
||||
page_url: str
|
||||
cookie_name: str
|
||||
cookie_domain: str
|
||||
storage_type: str
|
||||
attributes: dict | None = None
|
||||
script_source: str | None = None
|
||||
auto_category: str | None = None
|
||||
initiator_chain: list[str] | None = None
|
||||
found_at: datetime
|
||||
created_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ScanJobResponse(BaseModel):
|
||||
"""Response schema for a scan job."""
|
||||
|
||||
id: uuid.UUID
|
||||
site_id: uuid.UUID
|
||||
status: str
|
||||
trigger: str
|
||||
pages_scanned: int
|
||||
pages_total: int | None
|
||||
cookies_found: int
|
||||
error_message: str | None
|
||||
started_at: datetime | None
|
||||
completed_at: datetime | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class ScanJobDetailResponse(ScanJobResponse):
    """Scan job with results included."""

    # default_factory makes the list default explicit and matches the style
    # used elsewhere in the schemas (e.g. FrameworkResult.issues), rather than
    # relying on pydantic's implicit copy of a mutable `= []` default.
    results: list[ScanResultResponse] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TriggerScanRequest(BaseModel):
|
||||
"""Request to trigger a new scan."""
|
||||
|
||||
site_id: uuid.UUID
|
||||
max_pages: int = Field(default=50, ge=1, le=500)
|
||||
|
||||
|
||||
# ── Diff engine schemas ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class DiffStatus(StrEnum):
|
||||
NEW = "new"
|
||||
REMOVED = "removed"
|
||||
CHANGED = "changed"
|
||||
|
||||
|
||||
class CookieDiffItem(BaseModel):
|
||||
"""A single cookie difference between two scans."""
|
||||
|
||||
name: str
|
||||
domain: str
|
||||
storage_type: str
|
||||
diff_status: DiffStatus
|
||||
details: str | None = None
|
||||
|
||||
|
||||
class ScanDiffResponse(BaseModel):
    """Diff between two scans."""

    current_scan_id: uuid.UUID
    previous_scan_id: uuid.UUID | None
    # default_factory keeps the list defaults explicit and consistent with the
    # sibling schemas (e.g. FrameworkResult.issues) instead of mutable `= []`.
    new_cookies: list[CookieDiffItem] = Field(default_factory=list)
    removed_cookies: list[CookieDiffItem] = Field(default_factory=list)
    changed_cookies: list[CookieDiffItem] = Field(default_factory=list)
    total_new: int = 0
    total_removed: int = 0
    total_changed: int = 0
||||
117
apps/api/src/schemas/site.py
Normal file
117
apps/api/src/schemas/site.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class BlockingMode(StrEnum):
|
||||
OPT_IN = "opt_in"
|
||||
OPT_OUT = "opt_out"
|
||||
INFORMATIONAL = "informational"
|
||||
|
||||
|
||||
# ── Site schemas ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class SiteCreate(BaseModel):
|
||||
domain: str = Field(min_length=1, max_length=255)
|
||||
display_name: str = Field(min_length=1, max_length=255)
|
||||
additional_domains: list[str] | None = None
|
||||
site_group_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class SiteUpdate(BaseModel):
|
||||
display_name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
is_active: bool | None = None
|
||||
additional_domains: list[str] | None = None
|
||||
site_group_id: uuid.UUID | None = None
|
||||
|
||||
|
||||
class SiteResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
organisation_id: uuid.UUID
|
||||
domain: str
|
||||
display_name: str
|
||||
is_active: bool
|
||||
additional_domains: list[str] | None = None
|
||||
site_group_id: uuid.UUID | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
# ── Site config schemas ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class SiteConfigCreate(BaseModel):
|
||||
blocking_mode: BlockingMode = BlockingMode.OPT_IN
|
||||
regional_modes: dict | None = None
|
||||
tcf_enabled: bool = False
|
||||
tcf_publisher_cc: str | None = Field(default=None, max_length=2)
|
||||
gpp_enabled: bool = True
|
||||
gpp_supported_apis: list[str] | None = None
|
||||
gpc_enabled: bool = True
|
||||
gpc_jurisdictions: list[str] | None = None
|
||||
gpc_global_honour: bool = False
|
||||
gcm_enabled: bool = True
|
||||
gcm_default: dict | None = None
|
||||
shopify_privacy_enabled: bool = False
|
||||
banner_config: dict | None = None
|
||||
privacy_policy_url: str | None = None
|
||||
terms_url: str | None = None
|
||||
scan_schedule_cron: str | None = None
|
||||
scan_max_pages: int = Field(default=50, ge=1, le=1000)
|
||||
consent_expiry_days: int = Field(default=365, ge=1, le=730)
|
||||
consent_retention_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class SiteConfigUpdate(BaseModel):
|
||||
blocking_mode: BlockingMode | None = None
|
||||
regional_modes: dict | None = None
|
||||
tcf_enabled: bool | None = None
|
||||
tcf_publisher_cc: str | None = Field(default=None, max_length=2)
|
||||
gpp_enabled: bool | None = None
|
||||
gpp_supported_apis: list[str] | None = None
|
||||
gpc_enabled: bool | None = None
|
||||
gpc_jurisdictions: list[str] | None = None
|
||||
gpc_global_honour: bool | None = None
|
||||
gcm_enabled: bool | None = None
|
||||
gcm_default: dict | None = None
|
||||
shopify_privacy_enabled: bool | None = None
|
||||
banner_config: dict | None = None
|
||||
privacy_policy_url: str | None = None
|
||||
terms_url: str | None = None
|
||||
scan_schedule_cron: str | None = None
|
||||
scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
|
||||
consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
|
||||
consent_retention_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class SiteConfigResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
site_id: uuid.UUID
|
||||
blocking_mode: str
|
||||
regional_modes: dict | None
|
||||
tcf_enabled: bool
|
||||
tcf_publisher_cc: str | None = None
|
||||
gpp_enabled: bool = True
|
||||
gpp_supported_apis: list[str] | None = None
|
||||
gpc_enabled: bool = True
|
||||
gpc_jurisdictions: list[str] | None = None
|
||||
gpc_global_honour: bool = False
|
||||
gcm_enabled: bool
|
||||
gcm_default: dict | None = None
|
||||
shopify_privacy_enabled: bool = False
|
||||
banner_config: dict | None = None
|
||||
privacy_policy_url: str | None = None
|
||||
terms_url: str | None = None
|
||||
scan_schedule_cron: str | None = None
|
||||
scan_max_pages: int = 50
|
||||
consent_expiry_days: int = 365
|
||||
consent_retention_days: int | None = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
26
apps/api/src/schemas/site_group.py
Normal file
26
apps/api/src/schemas/site_group.py
Normal file
@@ -0,0 +1,26 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SiteGroupCreate(BaseModel):
|
||||
name: str = Field(min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class SiteGroupUpdate(BaseModel):
|
||||
name: str | None = Field(default=None, min_length=1, max_length=255)
|
||||
description: str | None = None
|
||||
|
||||
|
||||
class SiteGroupResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
organisation_id: uuid.UUID
|
||||
name: str
|
||||
description: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
site_count: int = 0
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
59
apps/api/src/schemas/site_group_config.py
Normal file
59
apps/api/src/schemas/site_group_config.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.schemas.site import BlockingMode
|
||||
|
||||
|
||||
class SiteGroupConfigUpdate(BaseModel):
|
||||
"""Update (or create) site-group-level default configuration.
|
||||
|
||||
All fields are optional — only non-None values override the org/system defaults.
|
||||
"""
|
||||
|
||||
blocking_mode: BlockingMode | None = None
|
||||
regional_modes: dict | None = None
|
||||
tcf_enabled: bool | None = None
|
||||
tcf_publisher_cc: str | None = Field(default=None, max_length=2)
|
||||
gpp_enabled: bool | None = None
|
||||
gpp_supported_apis: list[str] | None = None
|
||||
gpc_enabled: bool | None = None
|
||||
gpc_jurisdictions: list[str] | None = None
|
||||
gpc_global_honour: bool | None = None
|
||||
gcm_enabled: bool | None = None
|
||||
gcm_default: dict | None = None
|
||||
shopify_privacy_enabled: bool | None = None
|
||||
banner_config: dict | None = None
|
||||
privacy_policy_url: str | None = None
|
||||
terms_url: str | None = None
|
||||
scan_schedule_cron: str | None = None
|
||||
scan_max_pages: int | None = Field(default=None, ge=1, le=1000)
|
||||
consent_expiry_days: int | None = Field(default=None, ge=1, le=730)
|
||||
|
||||
|
||||
class SiteGroupConfigResponse(BaseModel):
|
||||
id: uuid.UUID
|
||||
site_group_id: uuid.UUID
|
||||
blocking_mode: str | None
|
||||
regional_modes: dict | None
|
||||
tcf_enabled: bool | None
|
||||
tcf_publisher_cc: str | None
|
||||
gpp_enabled: bool | None
|
||||
gpp_supported_apis: list[str] | None
|
||||
gpc_enabled: bool | None
|
||||
gpc_jurisdictions: list[str] | None
|
||||
gpc_global_honour: bool | None
|
||||
gcm_enabled: bool | None
|
||||
gcm_default: dict | None
|
||||
shopify_privacy_enabled: bool | None
|
||||
banner_config: dict | None
|
||||
privacy_policy_url: str | None
|
||||
terms_url: str | None
|
||||
scan_schedule_cron: str | None
|
||||
scan_max_pages: int | None
|
||||
consent_expiry_days: int | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
model_config = {"from_attributes": True}
|
||||
24
apps/api/src/schemas/translation.py
Normal file
24
apps/api/src/schemas/translation.py
Normal file
@@ -0,0 +1,24 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TranslationCreate(BaseModel):
|
||||
locale: str = Field(min_length=2, max_length=10)
|
||||
strings: dict[str, str]
|
||||
|
||||
|
||||
class TranslationUpdate(BaseModel):
    """Payload for replacing a translation's string table."""

    strings: dict[str, str]
|
||||
|
||||
|
||||
class TranslationResponse(BaseModel):
    """Read model for a site banner translation."""

    id: uuid.UUID
    site_id: uuid.UUID
    locale: str
    strings: dict[str, str]
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||
36
apps/api/src/schemas/user.py
Normal file
36
apps/api/src/schemas/user.py
Normal file
@@ -0,0 +1,36 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, EmailStr, Field
|
||||
|
||||
|
||||
class UserRole(StrEnum):
    """Organisation-level roles used for role-based access control."""

    OWNER = "owner"
    ADMIN = "admin"
    EDITOR = "editor"
    VIEWER = "viewer"  # default role for newly created users (see UserCreate)
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
    """Payload for creating a user within the caller's organisation."""

    email: EmailStr
    # max_length=72 mirrors bcrypt's 72-byte input limit (auth.py hashes with
    # bcrypt) — NOTE(review): the bcrypt limit is bytes, not characters, so a
    # multi-byte password near 72 chars may still exceed it; confirm handling.
    password: str = Field(min_length=8, max_length=72)
    full_name: str = Field(min_length=1, max_length=255)
    # Least-privileged role by default.
    role: UserRole = UserRole.VIEWER
|
||||
|
||||
|
||||
class UserUpdate(BaseModel):
    """Partial-update payload for a user; unset fields are left unchanged."""

    full_name: str | None = Field(default=None, min_length=1, max_length=255)
    role: UserRole | None = None
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
    """Read model for a user account (no password material exposed)."""

    id: uuid.UUID
    organisation_id: uuid.UUID
    email: str
    full_name: str
    # Plain string here (not UserRole) — serialised as stored.
    role: str
    created_at: datetime
    updated_at: datetime

    # Allow construction directly from ORM objects.
    model_config = {"from_attributes": True}
|
||||
0
apps/api/src/services/__init__.py
Normal file
0
apps/api/src/services/__init__.py
Normal file
59
apps/api/src/services/auth.py
Normal file
59
apps/api/src/services/auth.py
Normal file
@@ -0,0 +1,59 @@
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import bcrypt
|
||||
from jose import jwt
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
    """Hash *password* with bcrypt using a freshly generated salt."""
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode(), salt)
    return digest.decode()
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True when *plain_password* matches the bcrypt hash."""
    candidate = plain_password.encode()
    expected = hashed_password.encode()
    return bcrypt.checkpw(candidate, expected)
|
||||
|
||||
|
||||
def create_access_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
    role: str,
    email: str,
) -> str:
    """Build a signed JWT access token carrying the user's identity claims.

    Expiry is taken from ``jwt_access_token_expire_minutes`` in settings;
    timestamps are UTC-aware.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    expires_at = issued_at + timedelta(minutes=settings.jwt_access_token_expire_minutes)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "role": role,
        "email": email,
        "exp": expires_at,
        "iat": issued_at,
        "type": "access",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def create_refresh_token(
    user_id: uuid.UUID,
    organisation_id: uuid.UUID,
) -> str:
    """Build a signed JWT refresh token (identity only, no role/email).

    Expiry is taken from ``jwt_refresh_token_expire_days`` in settings.
    """
    settings = get_settings()
    issued_at = datetime.now(UTC)
    expires_at = issued_at + timedelta(days=settings.jwt_refresh_token_expire_days)
    claims = {
        "sub": str(user_id),
        "org_id": str(organisation_id),
        "exp": expires_at,
        "iat": issued_at,
        "type": "refresh",
    }
    return jwt.encode(claims, settings.jwt_secret_key, algorithm=settings.jwt_algorithm)
|
||||
|
||||
|
||||
def decode_token(token: str) -> dict:
    """Decode and validate a JWT token. Raises JWTError on failure."""
    settings = get_settings()
    allowed_algorithms = [settings.jwt_algorithm]
    return jwt.decode(token, settings.jwt_secret_key, algorithms=allowed_algorithms)
|
||||
79
apps/api/src/services/bootstrap.py
Normal file
79
apps/api/src/services/bootstrap.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""First-run bootstrap of an organisation and owner user.
|
||||
|
||||
Runs once on API startup. If ``INITIAL_ADMIN_EMAIL`` and
|
||||
``INITIAL_ADMIN_PASSWORD`` are set and the ``users`` table is empty,
|
||||
creates an organisation and a single owner user so the operator can log
|
||||
in to the admin UI for the first time. Idempotent: once any user
|
||||
exists, this is a no-op, so the environment variables can safely remain
|
||||
set across restarts. Complements ``ADMIN_BOOTSTRAP_TOKEN`` — that gates
|
||||
runtime org creation; this creates the *initial* org + owner without
|
||||
requiring a second round-trip.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.config.settings import Settings
|
||||
from src.db.session import async_session_factory
|
||||
from src.models.organisation import Organisation
|
||||
from src.models.user import User
|
||||
from src.services.auth import hash_password
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def bootstrap_initial_admin(settings: Settings) -> None:
    """Create the first organisation and owner user if none exist.

    Does nothing when either INITIAL_ADMIN_* credential is unset, or when
    the database already contains at least one user. Unexpected errors are
    logged and suppressed — a failed bootstrap must never prevent the API
    from starting, since operators can fall back to manual provisioning.
    """
    credentials_configured = bool(
        settings.initial_admin_email and settings.initial_admin_password
    )
    if not credentials_configured:
        logger.debug("Initial admin bootstrap skipped: credentials not configured")
        return

    try:
        async with async_session_factory() as session:
            await _bootstrap(session, settings)
    except Exception:  # pragma: no cover — defensive, logged
        logger.exception("Initial admin bootstrap failed")
|
||||
|
||||
|
||||
async def _bootstrap(session: AsyncSession, settings: Settings) -> None:
    """Perform the actual first-run provisioning inside one session.

    Idempotence guard: any existing user means bootstrap already ran (or
    provisioning was done manually), so this returns without changes.
    """
    existing_users = await session.scalar(select(func.count()).select_from(User))
    if existing_users:
        logger.debug("Initial admin bootstrap skipped: %d user(s) already exist", existing_users)
        return

    # Reuse an organisation with the configured slug if one already exists
    # (e.g. created out-of-band); otherwise create it.
    org = await session.scalar(
        select(Organisation).where(Organisation.slug == settings.initial_org_slug)
    )
    if org is None:
        org = Organisation(
            name=settings.initial_org_name,
            slug=settings.initial_org_slug,
            contact_email=settings.initial_admin_email,
        )
        session.add(org)
        # flush() assigns org.id so the user row below can reference it
        # before anything is committed.
        await session.flush()

    user = User(
        organisation_id=org.id,
        email=settings.initial_admin_email,
        password_hash=hash_password(settings.initial_admin_password),
        full_name=settings.initial_admin_full_name,
        role="owner",
    )
    session.add(user)
    # Single commit covers both the organisation (if new) and the user.
    await session.commit()

    # WARNING level on purpose: this should stand out in startup logs so
    # the operator rotates the bootstrap password promptly.
    logger.warning(
        "Initial admin bootstrap created owner %s in organisation '%s'. "
        "Rotate the password via the admin UI as soon as possible.",
        settings.initial_admin_email,
        org.slug,
    )
|
||||
298
apps/api/src/services/classification.py
Normal file
298
apps/api/src/services/classification.py
Normal file
@@ -0,0 +1,298 @@
|
||||
"""Cookie auto-categorisation engine.
|
||||
|
||||
Matches discovered cookies against the known_cookies database using exact name
|
||||
matching, domain matching, and regex patterns. Also checks site-specific
|
||||
allow-list entries. Returns a classification result with category, vendor, and
|
||||
confidence level.
|
||||
|
||||
Matching priority (highest first):
|
||||
1. Site-specific allow-list (exact or pattern match)
|
||||
2. Known cookies — exact name + domain match
|
||||
3. Known cookies — regex pattern match on name + domain
|
||||
4. Unmatched → remains as 'pending'
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.cookie import (
|
||||
Cookie,
|
||||
CookieAllowListEntry,
|
||||
CookieCategory,
|
||||
KnownCookie,
|
||||
)
|
||||
|
||||
|
||||
class MatchSource(StrEnum):
    """Where the classification match came from."""

    ALLOW_LIST = "allow_list"    # site-specific allow-list entry (highest priority)
    KNOWN_EXACT = "known_exact"  # known_cookies row with exact name/domain pattern
    KNOWN_REGEX = "known_regex"  # known_cookies row flagged is_regex
    UNMATCHED = "unmatched"      # no rule applied; cookie stays pending
|
||||
|
||||
|
||||
@dataclass
class ClassificationResult:
    """Result of classifying a single cookie."""

    # Identity of the cookie that was classified.
    cookie_name: str
    cookie_domain: str
    # Resolved category; None when unmatched or the category ID is unknown.
    category_id: uuid.UUID | None = None
    category_slug: str | None = None
    # Vendor/description copied from the matching known-cookie entry, if any.
    vendor: str | None = None
    description: str | None = None
    # Which lookup produced the match; UNMATCHED when no rule applied.
    match_source: MatchSource = MatchSource.UNMATCHED
    matched: bool = False
|
||||
|
||||
|
||||
async def _load_allow_list(
    db: AsyncSession,
    site_id: uuid.UUID,
) -> list[CookieAllowListEntry]:
    """Fetch every allow-list entry configured for *site_id*."""
    stmt = select(CookieAllowListEntry).where(
        CookieAllowListEntry.site_id == site_id,
    )
    rows = await db.execute(stmt)
    return list(rows.scalars().all())
|
||||
|
||||
|
||||
async def _load_known_cookies(
    db: AsyncSession,
) -> tuple[list[KnownCookie], list[KnownCookie]]:
    """Load known cookies, partitioned into (exact, regex) lists.

    Single pass over the result set; relative ordering within each list
    matches the query order.
    """
    rows = await db.execute(select(KnownCookie))
    exact: list[KnownCookie] = []
    regex: list[KnownCookie] = []
    for known in rows.scalars().all():
        (regex if known.is_regex else exact).append(known)
    return exact, regex
|
||||
|
||||
|
||||
async def _load_category_map(
    db: AsyncSession,
) -> dict[uuid.UUID, CookieCategory]:
    """Build an ID → CookieCategory lookup table."""
    rows = await db.execute(select(CookieCategory))
    categories = rows.scalars().all()
    return {category.id: category for category in categories}
|
||||
|
||||
|
||||
def _match_pattern(pattern: str, value: str) -> bool:
|
||||
"""Check if a value matches a pattern (case-insensitive).
|
||||
|
||||
Patterns support:
|
||||
- Exact match (e.g. "_ga")
|
||||
- Wildcard with * (e.g. "_ga*", "*.google.com")
|
||||
- Regex if it contains regex-specific characters
|
||||
"""
|
||||
if not pattern or not value:
|
||||
return False
|
||||
|
||||
pattern_lower = pattern.lower()
|
||||
value_lower = value.lower()
|
||||
|
||||
# Simple exact match
|
||||
if pattern_lower == value_lower:
|
||||
return True
|
||||
|
||||
# Wildcard: convert * to regex .*
|
||||
if "*" in pattern_lower:
|
||||
regex_pattern = "^" + re.escape(pattern_lower).replace(r"\*", ".*") + "$"
|
||||
return bool(re.match(regex_pattern, value_lower))
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def _match_regex(pattern: str, value: str) -> bool:
|
||||
"""Match a value against a regex pattern (case-insensitive)."""
|
||||
try:
|
||||
return bool(re.match(pattern, value, re.IGNORECASE))
|
||||
except re.error:
|
||||
return False
|
||||
|
||||
|
||||
def _match_allow_list(
    cookie_name: str,
    cookie_domain: str,
    allow_list: list[CookieAllowListEntry],
) -> CookieAllowListEntry | None:
    """Return the first allow-list entry whose name AND domain patterns match."""
    return next(
        (
            entry
            for entry in allow_list
            if _match_pattern(entry.name_pattern, cookie_name)
            and _match_pattern(entry.domain_pattern, cookie_domain)
        ),
        None,
    )
|
||||
|
||||
|
||||
def _match_exact_known(
    cookie_name: str,
    cookie_domain: str,
    exact_known: list[KnownCookie],
) -> KnownCookie | None:
    """Return the first non-regex known cookie matching both name and domain."""
    return next(
        (
            known
            for known in exact_known
            if _match_pattern(known.name_pattern, cookie_name)
            and _match_pattern(known.domain_pattern, cookie_domain)
        ),
        None,
    )
|
||||
|
||||
|
||||
def _match_regex_known(
    cookie_name: str,
    cookie_domain: str,
    regex_known: list[KnownCookie],
) -> KnownCookie | None:
    """Return the first regex-flagged known cookie matching both name and domain."""
    return next(
        (
            known
            for known in regex_known
            if _match_regex(known.name_pattern, cookie_name)
            and _match_regex(known.domain_pattern, cookie_domain)
        ),
        None,
    )
|
||||
|
||||
|
||||
def classify_cookie(
    cookie_name: str,
    cookie_domain: str,
    allow_list: list[CookieAllowListEntry],
    exact_known: list[KnownCookie],
    regex_known: list[KnownCookie],
    category_map: dict[uuid.UUID, CookieCategory],
) -> ClassificationResult:
    """Classify a single cookie against allow-list and known cookies DB.

    This is a pure function — all data is passed in, no DB calls.

    Matching priority (highest first): site allow-list, exact known
    cookies, regex known cookies; anything else is returned unmatched.
    """

    def _matched(
        category_id: uuid.UUID,
        source: MatchSource,
        vendor: str | None,
        description: str | None,
    ) -> ClassificationResult:
        # Shared constructor for the three matched branches (previously
        # triplicated inline). Tolerates an unknown category ID by
        # leaving category_slug as None.
        cat = category_map.get(category_id)
        return ClassificationResult(
            cookie_name=cookie_name,
            cookie_domain=cookie_domain,
            category_id=category_id,
            category_slug=cat.slug if cat else None,
            vendor=vendor,
            description=description,
            match_source=source,
            matched=True,
        )

    # 1. Site-specific allow-list overrides everything else. Allow-list
    #    entries carry no vendor, so vendor stays None here.
    allow_match = _match_allow_list(cookie_name, cookie_domain, allow_list)
    if allow_match:
        return _matched(
            allow_match.category_id,
            MatchSource.ALLOW_LIST,
            None,
            allow_match.description,
        )

    # 2. Exact known cookies.
    exact_match = _match_exact_known(cookie_name, cookie_domain, exact_known)
    if exact_match:
        return _matched(
            exact_match.category_id,
            MatchSource.KNOWN_EXACT,
            exact_match.vendor,
            exact_match.description,
        )

    # 3. Regex known cookies.
    regex_match = _match_regex_known(cookie_name, cookie_domain, regex_known)
    if regex_match:
        return _matched(
            regex_match.category_id,
            MatchSource.KNOWN_REGEX,
            regex_match.vendor,
            regex_match.description,
        )

    # 4. Unmatched — defaults leave match_source=UNMATCHED, matched=False.
    return ClassificationResult(
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
    )
|
||||
|
||||
|
||||
async def classify_site_cookies(
    db: AsyncSession,
    site_id: uuid.UUID,
    *,
    only_pending: bool = True,
) -> list[ClassificationResult]:
    """Classify all cookies for a site against known patterns.

    If only_pending is True, only cookies with review_status='pending'
    and no category are classified.

    Returns a list of results. Also updates matching cookies in the DB
    (category always; vendor/description only where currently empty).
    """
    # Load lookup data once up front — shared across all cookies below.
    allow_list = await _load_allow_list(db, site_id)
    exact_known, regex_known = await _load_known_cookies(db)
    category_map = await _load_category_map(db)

    # Load cookies to classify
    query = select(Cookie).where(Cookie.site_id == site_id)
    if only_pending:
        query = query.where(
            Cookie.review_status == "pending",
            Cookie.category_id.is_(None),
        )
    result = await db.execute(query)
    cookies = list(result.scalars().all())

    results: list[ClassificationResult] = []
    for cookie in cookies:
        cr = classify_cookie(
            cookie.name,
            cookie.domain,
            allow_list,
            exact_known,
            regex_known,
            category_map,
        )
        results.append(cr)

        # Update the cookie if matched. Vendor/description only fill
        # gaps — existing (possibly hand-curated) values are preserved.
        if cr.matched and cr.category_id:
            cookie.category_id = cr.category_id
            if cr.vendor and not cookie.vendor:
                cookie.vendor = cr.vendor
            if cr.description and not cookie.description:
                cookie.description = cr.description

    # flush(), not commit() — the caller controls the transaction boundary.
    await db.flush()
    return results
|
||||
|
||||
|
||||
async def classify_single_cookie(
    db: AsyncSession,
    site_id: uuid.UUID,
    cookie_name: str,
    cookie_domain: str,
) -> ClassificationResult:
    """Classify a single cookie (e.g. for preview/testing).

    Loads the same lookup data as the batch path, then delegates to the
    pure ``classify_cookie`` function. Never writes to the database.
    """
    site_allow_list = await _load_allow_list(db, site_id)
    known_exact, known_regex = await _load_known_cookies(db)
    categories_by_id = await _load_category_map(db)

    return classify_cookie(
        cookie_name,
        cookie_domain,
        site_allow_list,
        known_exact,
        known_regex,
        categories_by_id,
    )
|
||||
482
apps/api/src/services/compliance.py
Normal file
482
apps/api/src/services/compliance.py
Normal file
@@ -0,0 +1,482 @@
|
||||
"""Pluggable compliance rule engine.
|
||||
|
||||
Each regulatory framework (GDPR, CNIL, CCPA, ePrivacy, LGPD) is defined as a
|
||||
list of ComplianceRule objects. Rules evaluate site configuration, banner
|
||||
settings, cookie data, and consent parameters to produce issues with severity,
|
||||
message, and recommendation.
|
||||
|
||||
The engine aggregates individual rule results into per-framework reports with
|
||||
a compliance score, status, and actionable issues list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from src.schemas.compliance import (
|
||||
ComplianceIssue,
|
||||
Framework,
|
||||
FrameworkResult,
|
||||
Severity,
|
||||
)
|
||||
|
||||
# ── Rule context ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class SiteContext:
    """All data needed to evaluate compliance rules for a site.

    A plain snapshot: rule functions only read these fields, so a single
    context can be evaluated against multiple frameworks.
    """

    # Site config fields
    blocking_mode: str = "opt_in"
    regional_modes: dict[str, str] | None = None
    tcf_enabled: bool = False
    gcm_enabled: bool = True
    consent_expiry_days: int = 365
    privacy_policy_url: str | None = None

    # Banner config (JSONB — may have any keys)
    banner_config: dict[str, Any] | None = None

    # Cookie statistics
    total_cookies: int = 0
    uncategorised_cookies: int = 0
    cookies_without_expiry: int = 0

    # Consent settings (optimistic defaults: absent data reads as compliant)
    has_reject_button: bool = True
    has_granular_choices: bool = True
    has_cookie_wall: bool = False
    pre_ticked_boxes: bool = False
|
||||
|
||||
# ── Rule definition ───────────────────────────────────────────────────
|
||||
|
||||
# A check function receives a SiteContext and returns a list of issues.
# An empty list means the rule passed.
CheckFn = Callable[[SiteContext], list[ComplianceIssue]]


@dataclass
class ComplianceRule:
    """A single compliance rule with an ID, description, and check function."""

    rule_id: str       # stable identifier, echoed in the issues a rule emits
    description: str   # short human-readable summary of the requirement
    check: CheckFn     # returns [] on pass, one issue per violation
|
||||
|
||||
|
||||
# ── Helper factories ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _issue(
    rule_id: str,
    severity: Severity,
    message: str,
    recommendation: str,
) -> ComplianceIssue:
    """Keyword-forwarding shorthand for constructing a ComplianceIssue."""
    fields = {
        "rule_id": rule_id,
        "severity": severity,
        "message": message,
        "recommendation": recommendation,
    }
    return ComplianceIssue(**fields)
|
||||
|
||||
|
||||
# ── GDPR rules ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _gdpr_opt_in(ctx: SiteContext) -> list[ComplianceIssue]:
    """Pass only when the site requires opt-in consent."""
    if ctx.blocking_mode == "opt_in":
        return []
    return [
        _issue(
            "gdpr_opt_in",
            Severity.CRITICAL,
            "GDPR requires opt-in consent before setting non-essential cookies.",
            "Set blocking mode to 'opt_in'.",
        )
    ]


def _gdpr_reject_button(ctx: SiteContext) -> list[ComplianceIssue]:
    """Reject must be as easy to find as accept."""
    if ctx.has_reject_button:
        return []
    return [
        _issue(
            "gdpr_reject_button",
            Severity.CRITICAL,
            "The reject option must be as prominent as the accept option.",
            "Add a clearly visible 'Reject all' button to the first layer.",
        )
    ]


def _gdpr_granular_consent(ctx: SiteContext) -> list[ComplianceIssue]:
    """Per-category consent toggles are required."""
    if ctx.has_granular_choices:
        return []
    return [
        _issue(
            "gdpr_granular",
            Severity.CRITICAL,
            "Users must be able to consent to individual cookie categories.",
            "Provide granular category toggles in the consent banner.",
        )
    ]


def _gdpr_no_cookie_wall(ctx: SiteContext) -> list[ComplianceIssue]:
    """Access may not be conditioned on consent."""
    if not ctx.has_cookie_wall:
        return []
    return [
        _issue(
            "gdpr_cookie_wall",
            Severity.CRITICAL,
            "Cookie walls (blocking access unless consent is given) are not permitted.",
            "Remove the cookie wall and allow access without consent.",
        )
    ]


def _gdpr_no_pre_ticked(ctx: SiteContext) -> list[ComplianceIssue]:
    """Consent checkboxes must default to unchecked."""
    if not ctx.pre_ticked_boxes:
        return []
    return [
        _issue(
            "gdpr_pre_ticked",
            Severity.CRITICAL,
            "Pre-ticked consent boxes do not constitute valid consent.",
            "Ensure all non-essential category checkboxes default to unchecked.",
        )
    ]


def _gdpr_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
    """A privacy policy URL should be configured."""
    if ctx.privacy_policy_url:
        return []
    return [
        _issue(
            "gdpr_privacy_policy",
            Severity.WARNING,
            "A link to the privacy policy should be accessible from the banner.",
            "Configure a privacy policy URL in the site settings.",
        )
    ]


def _gdpr_uncategorised_cookies(ctx: SiteContext) -> list[ComplianceIssue]:
    """Every discovered cookie should have a category."""
    if ctx.uncategorised_cookies <= 0:
        return []
    return [
        _issue(
            "gdpr_uncategorised",
            Severity.WARNING,
            f"{ctx.uncategorised_cookies} cookie(s) have not been categorised.",
            "Review and assign a category to all discovered cookies.",
        )
    ]


GDPR_RULES: list[ComplianceRule] = [
    ComplianceRule("gdpr_opt_in", "Opt-in consent required", _gdpr_opt_in),
    ComplianceRule("gdpr_reject_button", "Reject as prominent as accept", _gdpr_reject_button),
    ComplianceRule("gdpr_granular", "Granular category consent", _gdpr_granular_consent),
    ComplianceRule("gdpr_cookie_wall", "No cookie walls", _gdpr_no_cookie_wall),
    ComplianceRule("gdpr_pre_ticked", "No pre-ticked boxes", _gdpr_no_pre_ticked),
    ComplianceRule("gdpr_privacy_policy", "Privacy policy link", _gdpr_privacy_policy),
    ComplianceRule("gdpr_uncategorised", "All cookies categorised", _gdpr_uncategorised_cookies),
]
|
||||
|
||||
|
||||
# ── CNIL rules (French — stricter GDPR) ──────────────────────────────
|
||||
|
||||
|
||||
def _cnil_consent_expiry(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL mandates re-consent every 6 months (≈ 182 days)."""
    if ctx.consent_expiry_days <= 182:
        return []
    return [
        _issue(
            "cnil_reconsent",
            Severity.CRITICAL,
            "CNIL requires re-consent at least every 6 months.",
            "Set consent_expiry_days to 182 or fewer.",
        )
    ]


def _cnil_cookie_lifetime(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL limits cookie lifetime to 13 months (≈ 395 days)."""
    if ctx.consent_expiry_days <= 395:
        return []
    return [
        _issue(
            "cnil_cookie_lifetime",
            Severity.CRITICAL,
            "CNIL limits consent cookie lifetime to 13 months.",
            "Set consent_expiry_days to 395 or fewer.",
        )
    ]


def _cnil_reject_first_layer(ctx: SiteContext) -> list[ComplianceIssue]:
    """CNIL requires 'Tout refuser' on the first layer of the banner."""
    if ctx.has_reject_button:
        return []
    return [
        _issue(
            "cnil_reject_first_layer",
            Severity.CRITICAL,
            "CNIL requires a 'Reject all' button on the first layer of the banner.",
            "Ensure the 'Reject all' button is visible on the first banner view.",
        )
    ]


# CNIL rules include all GDPR rules plus CNIL-specific ones
CNIL_RULES: list[ComplianceRule] = [
    *GDPR_RULES,
    ComplianceRule("cnil_reconsent", "Re-consent every 6 months", _cnil_consent_expiry),
    ComplianceRule("cnil_cookie_lifetime", "13-month cookie lifetime", _cnil_cookie_lifetime),
    ComplianceRule(
        "cnil_reject_first_layer",
        "Reject on first layer",
        _cnil_reject_first_layer,
    ),
]
|
||||
|
||||
|
||||
# ── CCPA / CPRA rules ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _ccpa_opt_out(ctx: SiteContext) -> list[ComplianceIssue]:
    """CCPA uses an opt-out model — blocking mode should be opt_out."""
    if ctx.blocking_mode in ("opt_out", "opt_in"):
        return []
    return [
        _issue(
            "ccpa_opt_out",
            Severity.CRITICAL,
            "CCPA requires at minimum an opt-out mechanism for data sale.",
            "Set blocking mode to 'opt_out' or 'opt_in'.",
        )
    ]


def _ccpa_do_not_sell(ctx: SiteContext) -> list[ComplianceIssue]:
    """CCPA requires a 'Do Not Sell My Personal Information' link."""
    banner = ctx.banner_config or {}
    if banner.get("show_do_not_sell_link", False):
        return []
    return [
        _issue(
            "ccpa_do_not_sell",
            Severity.CRITICAL,
            "CCPA requires a 'Do Not Sell My Personal Information' link.",
            "Enable 'show_do_not_sell_link' in the banner configuration.",
        )
    ]


def _ccpa_privacy_policy(ctx: SiteContext) -> list[ComplianceIssue]:
    """A privacy policy is mandatory under CCPA."""
    if ctx.privacy_policy_url:
        return []
    return [
        _issue(
            "ccpa_privacy_policy",
            Severity.WARNING,
            "A privacy policy is required under CCPA.",
            "Configure a privacy policy URL in the site settings.",
        )
    ]


CCPA_RULES: list[ComplianceRule] = [
    ComplianceRule("ccpa_opt_out", "Opt-out mechanism", _ccpa_opt_out),
    ComplianceRule("ccpa_do_not_sell", "Do Not Sell link", _ccpa_do_not_sell),
    ComplianceRule("ccpa_privacy_policy", "Privacy policy required", _ccpa_privacy_policy),
]
|
||||
|
||||
|
||||
# ── ePrivacy rules ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _eprivacy_consent(ctx: SiteContext) -> list[ComplianceIssue]:
    """ePrivacy requires consent for non-essential cookies."""
    if ctx.blocking_mode != "informational":
        return []
    return [
        _issue(
            "eprivacy_consent",
            Severity.CRITICAL,
            "ePrivacy Directive requires consent for non-essential cookies.",
            "Set blocking mode to 'opt_in' or 'opt_out'.",
        )
    ]


def _eprivacy_necessary_exempt(ctx: SiteContext) -> list[ComplianceIssue]:
    """Strictly necessary cookies must be exempt from consent.

    Configuration-guidance placeholder: the blocker exempts necessary
    cookies by default, so this check currently never emits an issue.
    """
    return []


EPRIVACY_RULES: list[ComplianceRule] = [
    ComplianceRule("eprivacy_consent", "Consent for non-essential", _eprivacy_consent),
    ComplianceRule(
        "eprivacy_necessary_exempt",
        "Necessary cookies exempt",
        _eprivacy_necessary_exempt,
    ),
]
|
||||
|
||||
|
||||
# ── LGPD rules (Brazil) ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def _lgpd_consent_basis(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD requires consent or legitimate interest as legal basis."""
    if ctx.blocking_mode != "informational":
        return []
    return [
        _issue(
            "lgpd_consent_basis",
            Severity.CRITICAL,
            "LGPD requires a legal basis (consent or legitimate interest) for data processing.",
            "Set blocking mode to 'opt_in' or 'opt_out'.",
        )
    ]


def _lgpd_data_controller(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD requires identifying the data controller."""
    if ctx.privacy_policy_url:
        return []
    return [
        _issue(
            "lgpd_data_controller",
            Severity.WARNING,
            "LGPD requires identification of the data controller.",
            "Link to a privacy policy that identifies the data controller.",
        )
    ]


def _lgpd_granular(ctx: SiteContext) -> list[ComplianceIssue]:
    """LGPD recommends per-category consent choices."""
    if ctx.has_granular_choices:
        return []
    return [
        _issue(
            "lgpd_granular",
            Severity.WARNING,
            "LGPD recommends granular consent choices.",
            "Provide individual category toggles in the consent banner.",
        )
    ]


LGPD_RULES: list[ComplianceRule] = [
    ComplianceRule("lgpd_consent_basis", "Legal basis for processing", _lgpd_consent_basis),
    ComplianceRule("lgpd_data_controller", "Identify data controller", _lgpd_data_controller),
    ComplianceRule("lgpd_granular", "Granular consent choices", _lgpd_granular),
]
|
||||
|
||||
|
||||
# ── Framework registry ────────────────────────────────────────────────
|
||||
|
||||
# Maps each supported framework to its rule set; run_framework_check
# falls back to an empty list for unknown frameworks.
FRAMEWORK_RULES: dict[Framework, list[ComplianceRule]] = {
    Framework.GDPR: GDPR_RULES,
    Framework.CNIL: CNIL_RULES,  # superset: all GDPR rules plus CNIL-specific ones
    Framework.CCPA: CCPA_RULES,
    Framework.EPRIVACY: EPRIVACY_RULES,
    Framework.LGPD: LGPD_RULES,
}
|
||||
|
||||
|
||||
# ── Engine ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def run_framework_check(
    framework: Framework,
    ctx: SiteContext,
) -> FrameworkResult:
    """Run all rules for a single framework and produce a result.

    A rule "passes" when its check returns no issues; every issue from
    failing rules is collected into the result.
    """
    rules = FRAMEWORK_RULES.get(framework, [])
    collected: list[ComplianceIssue] = []
    passed = 0

    for rule in rules:
        found = rule.check(ctx)
        if not found:
            passed += 1
        else:
            collected.extend(found)

    checked = len(rules)
    score = _calculate_score(collected, checked)
    return FrameworkResult(
        framework=framework,
        score=score,
        status=_determine_status(score, collected),
        issues=collected,
        rules_checked=checked,
        rules_passed=passed,
    )
|
||||
|
||||
|
||||
def run_compliance_check(
    ctx: SiteContext,
    frameworks: list[Framework] | None = None,
) -> list[FrameworkResult]:
    """Run compliance checks for the given frameworks (all when omitted)."""
    selected = frameworks or list(FRAMEWORK_RULES)
    return [run_framework_check(framework, ctx) for framework in selected]
|
||||
|
||||
|
||||
def calculate_overall_score(results: list[FrameworkResult]) -> int:
    """Average the framework scores, rounded to the nearest int.

    An empty result list scores a perfect 100.
    """
    if not results:
        return 100
    scores = [result.score for result in results]
    return round(sum(scores) / len(scores))
|
||||
|
||||
|
||||
# ── Scoring helpers ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _calculate_score(
|
||||
issues: list[ComplianceIssue],
|
||||
rules_checked: int,
|
||||
) -> int:
|
||||
"""Score from 0-100. Critical issues deduct 20 pts, warnings 5 pts."""
|
||||
if rules_checked == 0:
|
||||
return 100
|
||||
|
||||
deductions = 0
|
||||
for issue in issues:
|
||||
if issue.severity == Severity.CRITICAL:
|
||||
deductions += 20
|
||||
elif issue.severity == Severity.WARNING:
|
||||
deductions += 5
|
||||
# INFO issues don't affect the score
|
||||
|
||||
return max(0, 100 - deductions)
|
||||
|
||||
|
||||
def _determine_status(
|
||||
score: int,
|
||||
issues: list[ComplianceIssue],
|
||||
) -> str:
|
||||
"""Derive overall status string from score and issues."""
|
||||
has_critical = any(i.severity == Severity.CRITICAL for i in issues)
|
||||
if has_critical:
|
||||
return "non_compliant"
|
||||
if score >= 100:
|
||||
return "compliant"
|
||||
return "partial"
|
||||
156
apps/api/src/services/config_resolver.py
Normal file
156
apps/api/src/services/config_resolver.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Configuration hierarchy resolver.
|
||||
|
||||
Resolves site configuration by merging:
|
||||
System Defaults → Org Defaults → Site Group Defaults → Site Config → Regional Overrides
|
||||
|
||||
Produces a fully resolved public config suitable for the banner script.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
# System-level defaults (hard-coded, lowest priority)
|
||||
SYSTEM_DEFAULTS: dict[str, Any] = {
|
||||
"blocking_mode": "opt_in",
|
||||
"tcf_enabled": False,
|
||||
"gpp_enabled": True,
|
||||
"gpp_supported_apis": ["usnat"],
|
||||
"gpc_enabled": True,
|
||||
"gpc_jurisdictions": ["US-CA", "US-CO", "US-CT", "US-TX", "US-MT"],
|
||||
"gpc_global_honour": False,
|
||||
"gcm_enabled": True,
|
||||
"shopify_privacy_enabled": False,
|
||||
"gcm_default": {
|
||||
"ad_storage": "denied",
|
||||
"ad_user_data": "denied",
|
||||
"ad_personalization": "denied",
|
||||
"analytics_storage": "denied",
|
||||
"functionality_storage": "denied",
|
||||
"personalization_storage": "denied",
|
||||
"security_storage": "granted",
|
||||
},
|
||||
"banner_config": None,
|
||||
"privacy_policy_url": None,
|
||||
"terms_url": None,
|
||||
"consent_expiry_days": 365,
|
||||
}
|
||||
|
||||
|
||||
def resolve_config(
|
||||
site_config: dict[str, Any],
|
||||
org_defaults: dict[str, Any] | None = None,
|
||||
group_defaults: dict[str, Any] | None = None,
|
||||
region: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve the full configuration by merging layers.
|
||||
|
||||
Args:
|
||||
site_config: Site-specific configuration from the database.
|
||||
org_defaults: Organisation-level default overrides (optional).
|
||||
group_defaults: Site-group-level default overrides (optional).
|
||||
region: ISO region code for regional mode override (optional).
|
||||
|
||||
Returns:
|
||||
Fully resolved configuration dictionary.
|
||||
"""
|
||||
# Start with system defaults
|
||||
resolved = {**SYSTEM_DEFAULTS}
|
||||
|
||||
# Apply organisation defaults (if any)
|
||||
if org_defaults:
|
||||
_merge_non_none(resolved, org_defaults)
|
||||
|
||||
# Apply site group defaults (if any)
|
||||
if group_defaults:
|
||||
_merge_non_none(resolved, group_defaults)
|
||||
|
||||
# Apply site-specific config
|
||||
_merge_non_none(resolved, site_config)
|
||||
|
||||
# Apply regional blocking mode override
|
||||
if region and site_config.get("regional_modes"):
|
||||
regional_modes = site_config["regional_modes"]
|
||||
if isinstance(regional_modes, dict):
|
||||
# Try exact match first, then fall back to DEFAULT
|
||||
regional_mode = regional_modes.get(region) or regional_modes.get("DEFAULT")
|
||||
if regional_mode:
|
||||
resolved["blocking_mode"] = regional_mode
|
||||
|
||||
return resolved
|
||||
|
||||
|
||||
def build_public_config(
    site_id: str,
    resolved: dict[str, Any],
) -> dict[str, Any]:
    """Project a resolved configuration onto the public banner payload.

    Internal-only fields are dropped and ``site_id`` is added so the
    banner script can identify itself. Fields in ``optional`` default to
    None when absent; every other field must exist in *resolved*.
    """
    # Fields the resolver may legitimately leave unset.
    optional = frozenset({
        "regional_modes",
        "gpp_supported_apis",
        "gpc_jurisdictions",
        "gcm_default",
        "banner_config",
        "privacy_policy_url",
        "terms_url",
        "consent_group_id",
        "ab_test",
    })
    # Output key order matters for stable published JSON.
    ordered_fields = (
        "blocking_mode",
        "regional_modes",
        "tcf_enabled",
        "gpp_enabled",
        "gpp_supported_apis",
        "gpc_enabled",
        "gpc_jurisdictions",
        "gpc_global_honour",
        "gcm_enabled",
        "gcm_default",
        "shopify_privacy_enabled",
        "banner_config",
        "privacy_policy_url",
        "terms_url",
        "consent_expiry_days",
        "consent_group_id",
        "ab_test",
    )

    public: dict[str, Any] = {"id": resolved.get("id", ""), "site_id": site_id}
    for field in ordered_fields:
        public[field] = resolved.get(field) if field in optional else resolved[field]
    return public
|
||||
|
||||
|
||||
# The configuration columns shared by SiteConfig / OrgConfig ORM models;
# drives the extraction in orm_to_config_dict below.
CONFIG_FIELDS = (
    "blocking_mode",
    "regional_modes",
    "tcf_enabled",
    "tcf_publisher_cc",
    "gpp_enabled",
    "gpp_supported_apis",
    "gpc_enabled",
    "gpc_jurisdictions",
    "gpc_global_honour",
    "gcm_enabled",
    "gcm_default",
    "shopify_privacy_enabled",
    "banner_config",
    "privacy_policy_url",
    "terms_url",
    "consent_expiry_days",
)


def orm_to_config_dict(obj: Any, *, include_id: bool = False) -> dict[str, Any]:
    """Convert a SiteConfig or OrgConfig ORM object to a dict of config fields.

    Only fields that are present on the object AND non-NULL are included,
    so unset fields at higher-priority cascade layers don't block
    inheritance from lower-priority layers.
    """
    out: dict[str, Any] = {}
    if include_id and hasattr(obj, "id"):
        out["id"] = str(obj.id)
    present = (
        (name, getattr(obj, name)) for name in CONFIG_FIELDS if hasattr(obj, name)
    )
    out.update({name: value for name, value in present if value is not None})
    return out
|
||||
|
||||
|
||||
def _merge_non_none(target: dict[str, Any], source: dict[str, Any]) -> None:
|
||||
"""Merge source into target, skipping None values in source."""
|
||||
for key, value in source.items():
|
||||
if value is not None:
|
||||
target[key] = value
|
||||
77
apps/api/src/services/cors.py
Normal file
77
apps/api/src/services/cors.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Dynamic CORS origin validation.
|
||||
|
||||
Provides an origin validator that checks incoming origins against
|
||||
registered site domains (primary + additional) in addition to the
|
||||
statically configured allowed_origins list.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.site import Site
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_domain_from_origin(origin: str) -> str | None:
|
||||
"""Extract the hostname from an origin URL.
|
||||
|
||||
e.g. 'https://example.com:443' → 'example.com'
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(origin)
|
||||
return parsed.hostname
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
async def get_allowed_domains(db: AsyncSession) -> set[str]:
    """Fetch all registered domains (primary + additional) from active sites.

    Domains are lower-cased so callers can do case-insensitive lookups.
    Soft-deleted and inactive sites are excluded.
    """
    stmt = select(Site.domain, Site.additional_domains).where(
        Site.is_active.is_(True),
        Site.deleted_at.is_(None),
    )
    rows = (await db.execute(stmt)).all()

    allowed: set[str] = set()
    for primary, extras in rows:
        allowed.add(primary.lower())
        for extra in extras or ():
            allowed.add(extra.lower())
    return allowed
|
||||
|
||||
|
||||
def is_origin_allowed(
    origin: str,
    static_origins: list[str],
    registered_domains: set[str],
) -> bool:
    """Decide whether *origin* may make cross-origin requests.

    Allowed when it matches the static allow-list exactly, when the
    static list contains the ``"*"`` wildcard, or when its hostname is
    one of the registered site domains.

    Args:
        origin: The Origin header value (e.g. 'https://example.com').
        static_origins: Statically configured allowed origins from settings.
        registered_domains: Set of registered site domains from the database.

    Returns:
        True if the origin is allowed.
    """
    if origin in static_origins or "*" in static_origins:
        return True

    # Fall back to matching the origin's hostname against registered domains.
    try:
        host = urlparse(origin).hostname
    except Exception:
        host = None
    return bool(host and host.lower() in registered_domains)
|
||||
54
apps/api/src/services/dependencies.py
Normal file
54
apps/api/src/services/dependencies.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from jose import JWTError
|
||||
|
||||
from src.schemas.auth import CurrentUser
|
||||
from src.services.auth import decode_token
|
||||
|
||||
# Module-level HTTP Bearer scheme shared by the auth dependencies below;
# it extracts the token from the "Authorization: Bearer <token>" header.
bearer_scheme = HTTPBearer()
|
||||
|
||||
|
||||
async def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(bearer_scheme),
) -> CurrentUser:
    """Extract and validate the current user from the JWT bearer token.

    Raises 401 when the token fails to decode or is not an access token
    (e.g. a refresh token presented where an access token is required).
    """
    try:
        claims = decode_token(credentials.credentials)
    except JWTError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    # Only "access" tokens may authenticate API calls.
    if claims.get("type") != "access":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token type",
        )

    return CurrentUser(
        id=uuid.UUID(claims["sub"]),
        organisation_id=uuid.UUID(claims["org_id"]),
        email=claims.get("email", ""),
        role=claims.get("role", "viewer"),
    )
|
||||
|
||||
|
||||
def require_role(*allowed_roles: str) -> Callable:
    """Build a FastAPI dependency admitting only the given roles.

    The returned dependency resolves the current user (so authentication
    errors surface first) and raises 403 when the user's role is not in
    *allowed_roles*.
    """

    async def _guard(
        current_user: CurrentUser = Depends(get_current_user),
    ) -> CurrentUser:
        if current_user.has_role(*allowed_roles):
            return current_user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Role '{current_user.role}' is not permitted for this action",
        )

    return _guard
|
||||
339
apps/api/src/services/geoip.py
Normal file
339
apps/api/src/services/geoip.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""GeoIP service — resolve an IP address to a country/region code.
|
||||
|
||||
Resolution order (see :func:`detect_region`):
|
||||
|
||||
1. **CDN / proxy headers.** Operators configure ``GEOIP_COUNTRY_HEADER``
|
||||
(and optionally ``GEOIP_REGION_HEADER``) to match whatever their edge
|
||||
uses — e.g. ``cf-ipcountry`` + ``cf-region-code`` on Cloudflare
|
||||
Enterprise, or ``x-gclb-country`` + ``x-gclb-region`` on GCP. A short
|
||||
built-in country list (``cf-ipcountry``, ``x-vercel-ip-country``,
|
||||
``x-appengine-country``, ``x-country-code``) covers the common case
|
||||
where only country-level granularity is needed.
|
||||
2. **Local MaxMind GeoLite2-City database.** Set
|
||||
``GEOIP_MAXMIND_DB_PATH`` to a mounted ``.mmdb`` file. Gives both
|
||||
country and ISO 3166-2 subdivision without any external calls.
|
||||
3. **External ip-api.com lookup** (rate-limited, no API key). Last-ditch
|
||||
fallback; fine for development, not recommended for production.
|
||||
4. Unresolved — the caller should fall back to the default region.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
|
||||
import geoip2.database
|
||||
import httpx
|
||||
from fastapi import Request
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazily-initialised MaxMind reader. ``geoip2.database.Reader`` opens
# the file once and then every lookup is a memory-mapped read, so we
# cache it for the lifetime of the process. ``None`` means either no
# path is configured, initialisation failed, or we haven't tried yet.
_maxmind_reader: geoip2.database.Reader | None = None
# Flips to True after the first open attempt so a bad path isn't retried
# on every request (see _get_maxmind_reader).
_maxmind_initialised: bool = False

# Standard headers set by CDN / reverse proxy providers. Operators
# running behind a CDN that uses a non-standard header (e.g. Google
# Cloud Load Balancer's ``x-gclb-country``) can add one more via the
# ``GEOIP_COUNTRY_HEADER`` env var — see ``detect_region_from_headers``.
_GEO_HEADERS = [
    "cf-ipcountry",  # Cloudflare
    "x-vercel-ip-country",  # Vercel
    "x-appengine-country",  # Google App Engine
    "x-country-code",  # Generic / custom
]
|
||||
|
||||
# Mapping from two-letter country code to region codes used in regional_modes
|
||||
# EU member states → "EU", US states handled separately, etc.
|
||||
_EU_COUNTRIES = frozenset(
|
||||
{
|
||||
"AT",
|
||||
"BE",
|
||||
"BG",
|
||||
"HR",
|
||||
"CY",
|
||||
"CZ",
|
||||
"DK",
|
||||
"EE",
|
||||
"FI",
|
||||
"FR",
|
||||
"DE",
|
||||
"GR",
|
||||
"HU",
|
||||
"IE",
|
||||
"IT",
|
||||
"LV",
|
||||
"LT",
|
||||
"LU",
|
||||
"MT",
|
||||
"NL",
|
||||
"PL",
|
||||
"PT",
|
||||
"RO",
|
||||
"SK",
|
||||
"SI",
|
||||
"ES",
|
||||
"SE",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class GeoResult:
|
||||
"""Result of a GeoIP lookup."""
|
||||
|
||||
country_code: str | None
|
||||
region: str | None
|
||||
|
||||
@property
|
||||
def is_resolved(self) -> bool:
|
||||
return self.country_code is not None
|
||||
|
||||
|
||||
def country_to_region(country_code: str, state_code: str | None = None) -> str:
|
||||
"""Map a country code (+ optional subdivision) to a regional_modes key.
|
||||
|
||||
Resolution order:
|
||||
- EU member states collapse to ``"EU"`` regardless of subdivision;
|
||||
regional_modes treats the bloc as a single unit.
|
||||
- Any other country with a subdivision produces ``"{CC}-{SUB}"``
|
||||
(e.g. ``"US-CA"``, ``"GB-SCT"``, ``"BR-SP"``). The operator
|
||||
opts in to subdivision-level resolution by configuring a key
|
||||
of that form in ``regional_modes``; if they don't, the
|
||||
fallback resolver still matches on the plain country code.
|
||||
- Country with no subdivision is returned as-is (``"GB"``,
|
||||
``"BR"``, …).
|
||||
"""
|
||||
upper = country_code.upper()
|
||||
|
||||
if upper in _EU_COUNTRIES:
|
||||
return "EU"
|
||||
|
||||
if state_code:
|
||||
return f"{upper}-{state_code.upper()}"
|
||||
|
||||
return upper
|
||||
|
||||
|
||||
def detect_region_from_headers(request: Request) -> GeoResult:
    """Resolve the visitor's region from CDN / reverse-proxy headers.

    This is the fastest path — no external calls. The operator-configured
    ``GEOIP_COUNTRY_HEADER`` (optionally paired with ``GEOIP_REGION_HEADER``
    for subdivision keys like ``US-CA``) takes priority over the built-in
    country-only header list, so non-standard CDN/load-balancer headers
    can be plumbed in without code changes.

    Header lookups are case-insensitive; a value of ``XX`` (unknown) is
    treated as absent.
    """
    settings = get_settings()

    custom_header = settings.geoip_country_header
    if custom_header:
        raw_country = request.headers.get(custom_header)
        if raw_country and raw_country.upper() != "XX":
            country = raw_country.upper().strip()
            subdivision: str | None = None
            if settings.geoip_region_header:
                raw_sub = request.headers.get(settings.geoip_region_header)
                if raw_sub and raw_sub.upper() != "XX":
                    # Accept either a full ISO 3166-2 code ("US-CA") or a
                    # bare subdivision ("CA") — keep only the subdivision
                    # part so country_to_region builds the key itself.
                    normalised = raw_sub.strip().upper()
                    subdivision = (
                        normalised.split("-", 1)[-1] if "-" in normalised else normalised
                    )
            return GeoResult(
                country_code=country,
                region=country_to_region(country, subdivision),
            )

    # Built-in list: country-level granularity only.
    for builtin in _GEO_HEADERS:
        raw = request.headers.get(builtin)
        if raw and raw.upper() != "XX":
            country = raw.upper().strip()
            return GeoResult(country_code=country, region=country_to_region(country))

    return GeoResult(country_code=None, region=None)
|
||||
|
||||
|
||||
def get_client_ip(request: Request) -> str | None:
    """Best-effort extraction of the originating client IP.

    Honours X-Forwarded-For, then X-Real-IP, before falling back to the
    direct connection address. Returns None when nothing is available.
    """
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        # Format is "client, proxy1, proxy2" — the first entry is the client.
        first_hop, _, _ = forwarded.partition(",")
        return first_hop.strip()

    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        return real_ip.strip()

    client = request.client
    return client.host if client else None
|
||||
|
||||
|
||||
async def lookup_ip_region(ip: str) -> GeoResult:
    """Look up the region for an IP address via an external API.

    Uses ip-api.com (free tier, no key required, 45 req/min). Last-ditch
    fallback — production deployments should prefer the local MaxMind
    database. Private/loopback addresses and any failure resolve to an
    unresolved result.
    """
    unresolved = GeoResult(country_code=None, region=None)
    if _is_private_ip(ip):
        return unresolved

    try:
        async with httpx.AsyncClient(timeout=3.0) as client:
            resp = await client.get(
                f"http://ip-api.com/json/{ip}",
                params={"fields": "status,countryCode,region"},
            )
            if resp.status_code != 200:
                return unresolved

            payload = resp.json()
            if payload.get("status") != "success":
                return unresolved

            country = payload.get("countryCode")
            if not country:
                return unresolved

            state = payload.get("region")  # State/province code
            return GeoResult(
                country_code=country,
                region=country_to_region(country, state),
            )
    except Exception:
        logger.debug("GeoIP lookup failed for %s", ip, exc_info=True)
        return unresolved
|
||||
|
||||
|
||||
def _get_maxmind_reader() -> geoip2.database.Reader | None:
    """Return the cached MaxMind reader, opening the DB on first use.

    Both success and failure are cached via ``_maxmind_initialised`` so a
    bad path isn't retried on every request. Returns ``None`` when no
    path is configured or the database could not be opened.
    """
    global _maxmind_reader, _maxmind_initialised

    if not _maxmind_initialised:
        _maxmind_initialised = True
        db_path = get_settings().geoip_maxmind_db_path
        if db_path:
            try:
                _maxmind_reader = geoip2.database.Reader(db_path)
                logger.info("GeoIP: opened MaxMind database at %s", db_path)
            except Exception:
                logger.warning(
                    "GeoIP: failed to open MaxMind database at %s — falling back to "
                    "external lookups. Check GEOIP_MAXMIND_DB_PATH and that the file "
                    "is readable inside the container.",
                    db_path,
                    exc_info=True,
                )
                _maxmind_reader = None

    return _maxmind_reader
|
||||
|
||||
|
||||
def lookup_ip_maxmind(ip: str) -> GeoResult:
    """Resolve an IP via the local MaxMind database.

    Memory-mapped read, no network I/O — cheap enough to call
    synchronously from the async path. Returns an unresolved
    ``GeoResult`` when the DB isn't configured, the IP is private, or
    no record is found.
    """
    unresolved = GeoResult(country_code=None, region=None)
    if _is_private_ip(ip):
        return unresolved

    reader = _get_maxmind_reader()
    if reader is None:
        return unresolved

    try:
        record = reader.city(ip)
    except Exception:
        logger.debug("MaxMind lookup failed for %s", ip, exc_info=True)
        return unresolved

    country = record.country.iso_code
    if not country:
        return unresolved

    # ``subdivisions`` is ordered most-specific first; its ISO code has
    # no country prefix.
    subdivisions = record.subdivisions
    state = subdivisions.most_specific.iso_code if subdivisions else None
    return GeoResult(
        country_code=country.upper(),
        region=country_to_region(country, state),
    )
|
||||
|
||||
|
||||
async def detect_region(request: Request) -> GeoResult:
    """Detect the visitor's region.

    Tiers, cheapest first:
    1. CDN/proxy headers (:func:`detect_region_from_headers`).
    2. Local MaxMind database, when ``GEOIP_MAXMIND_DB_PATH`` is set.
    3. External ``ip-api.com`` lookup — last resort.

    Returns an unresolved :class:`GeoResult` when every tier fails.
    """
    from_headers = detect_region_from_headers(request)
    if from_headers.is_resolved:
        return from_headers

    ip = get_client_ip(request)
    if not ip:
        return GeoResult(country_code=None, region=None)

    if get_settings().geoip_maxmind_db_path:
        local = lookup_ip_maxmind(ip)
        if local.is_resolved:
            return local

    return await lookup_ip_region(ip)
|
||||
|
||||
|
||||
def _is_private_ip(ip: str) -> bool:
|
||||
"""Check if an IP address is a private/loopback address."""
|
||||
return (
|
||||
ip.startswith("127.")
|
||||
or ip.startswith("10.")
|
||||
or ip.startswith("192.168.")
|
||||
or ip.startswith("172.16.")
|
||||
or ip.startswith("172.17.")
|
||||
or ip.startswith("172.18.")
|
||||
or ip.startswith("172.19.")
|
||||
or ip.startswith("172.2")
|
||||
or ip.startswith("172.3")
|
||||
or ip == "::1"
|
||||
or ip == "localhost"
|
||||
)
|
||||
41
apps/api/src/services/pseudonymisation.py
Normal file
41
apps/api/src/services/pseudonymisation.py
Normal file
@@ -0,0 +1,41 @@
|
||||
"""Pseudonymisation helpers for consent records.
|
||||
|
||||
Consent records capture a hash of the visitor's IP address and
|
||||
user-agent string for abuse protection and audit trail purposes.
|
||||
|
||||
Previously this used an unsalted truncated SHA-256, which is trivially
|
||||
reversible for IPv4 addresses (only ~4 billion inputs). We now use
|
||||
HMAC-SHA256 keyed with a server-side secret so the hash cannot be
|
||||
recovered without access to the secret.
|
||||
|
||||
Public API: :func:`pseudonymise`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hmac
|
||||
from hashlib import sha256
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
# Length of the hex-encoded digest stored in the database. 32 hex chars
# = 128 bits — ample entropy while keeping the column compact. (An
# earlier revision stored only 16 hex chars = 64 bits.)
_DIGEST_HEX_LEN = 32


def pseudonymise(value: str) -> str:
    """Return a keyed HMAC-SHA256 hash of *value* safe for audit storage.

    The key comes from the configured ``pseudonymisation_secret``
    (falling back to ``jwt_secret_key`` when unset), so the original
    value cannot be recovered without server-side access. The hex
    digest is truncated to 32 characters (128 bits).

    An empty input always maps to an empty string so callers don't have
    to branch on missing data.
    """
    if not value:
        return ""
    secret = get_settings().pseudonymisation_key
    mac = hmac.new(secret, value.encode("utf-8"), sha256)
    return mac.hexdigest()[:_DIGEST_HEX_LEN]
||||
89
apps/api/src/services/publisher.py
Normal file
89
apps/api/src/services/publisher.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""CDN publishing pipeline.
|
||||
|
||||
Publishes resolved site configurations as static JSON files for the
|
||||
banner script to fetch. Supports local filesystem (development) and
|
||||
can be extended for S3/GCS/CloudFront.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from src.config.settings import get_settings
|
||||
|
||||
from .config_resolver import build_public_config, resolve_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PublishResult:
|
||||
"""Result of a publish operation."""
|
||||
|
||||
def __init__(self, success: bool, path: str, error: str | None = None) -> None:
|
||||
self.success = success
|
||||
self.path = path
|
||||
self.error = error
|
||||
self.published_at = datetime.now(UTC).isoformat() if success else None
|
||||
|
||||
|
||||
async def publish_site_config(
    site_id: str,
    site_config: dict[str, Any],
    org_defaults: dict[str, Any] | None = None,
) -> PublishResult:
    """Resolve a site's configuration hierarchy and publish it to the CDN.

    Never raises — all failures are captured in the returned result so a
    publish error cannot break the calling request/task.

    Args:
        site_id: The site UUID as a string.
        site_config: Raw site configuration from the database.
        org_defaults: Organisation-level defaults (optional).

    Returns:
        PublishResult with success status and the written path.
    """
    try:
        # Merge the config cascade, then strip it down to the public shape.
        merged = resolve_config(site_config, org_defaults)
        public = build_public_config(site_id, merged)

        # Hand off to the configured publishing backend.
        target = await _publish_local(site_id, public, get_settings().cdn_base_url)

        logger.info("Published config for site %s to %s", site_id, target)
        return PublishResult(success=True, path=target)
    except Exception as exc:
        logger.exception("Failed to publish config for site %s", site_id)
        return PublishResult(success=False, path="", error=str(exc))
|
||||
|
||||
|
||||
async def _publish_local(
|
||||
site_id: str,
|
||||
config: dict[str, Any],
|
||||
cdn_base: str,
|
||||
) -> str:
|
||||
"""Publish config to local filesystem (for development/Docker Compose).
|
||||
|
||||
Writes to the CDN proxy's HTML directory so nginx can serve it.
|
||||
"""
|
||||
# Default local publish directory
|
||||
publish_dir = Path("/app/cdn-publish") if Path("/app").exists() else Path("cdn-publish")
|
||||
publish_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Write the config JSON
|
||||
config_path = publish_dir / f"site-config-{site_id}.json"
|
||||
config_path.write_text(json.dumps(config, indent=2, default=str))
|
||||
|
||||
# Also write a versioned copy for cache-busting
|
||||
version = datetime.now(UTC).strftime("%Y%m%d%H%M%S")
|
||||
versioned_path = publish_dir / f"site-config-{site_id}-{version}.json"
|
||||
versioned_path.write_text(json.dumps(config, indent=2, default=str))
|
||||
|
||||
return str(config_path)
|
||||
322
apps/api/src/services/scanner.py
Normal file
322
apps/api/src/services/scanner.py
Normal file
@@ -0,0 +1,322 @@
|
||||
"""Scan orchestration and diff engine.
|
||||
|
||||
Provides scan job lifecycle management, result diffing between scans,
|
||||
and cookie inventory synchronisation from scan results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.models.cookie import Cookie
|
||||
from src.models.scan import ScanJob, ScanResult
|
||||
from src.models.site import Site
|
||||
from src.schemas.scanner import (
|
||||
CookieDiffItem,
|
||||
DiffStatus,
|
||||
ScanDiffResponse,
|
||||
)
|
||||
|
||||
|
||||
async def create_scan_job(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    trigger: str = "manual",
    max_pages: int = 50,
) -> ScanJob:
    """Insert a new scan job in the 'pending' state and flush it.

    ``pages_total`` is initialised to the crawl budget (*max_pages*).
    """
    pending = ScanJob(
        site_id=site_id,
        status="pending",
        trigger=trigger,
        pages_total=max_pages,
    )
    db.add(pending)
    await db.flush()
    return pending
|
||||
|
||||
|
||||
async def start_scan_job(db: AsyncSession, job: ScanJob) -> ScanJob:
|
||||
"""Transition a scan job to 'running'.
|
||||
|
||||
Idempotent: if the job is already running (e.g. Celery re-delivered the
|
||||
task after a worker crash), this is a no-op. Also handles re-delivery
|
||||
after a transient failure that left the job in 'failed' state mid-retry.
|
||||
"""
|
||||
if job.status == "running":
|
||||
return job
|
||||
job.status = "running"
|
||||
job.started_at = datetime.now(UTC)
|
||||
# Reset any previous error so the retry starts clean
|
||||
job.error_message = None
|
||||
await db.flush()
|
||||
return job
|
||||
|
||||
|
||||
async def complete_scan_job(
|
||||
db: AsyncSession,
|
||||
job: ScanJob,
|
||||
*,
|
||||
pages_scanned: int = 0,
|
||||
cookies_found: int = 0,
|
||||
error_message: str | None = None,
|
||||
) -> ScanJob:
|
||||
"""Mark a scan job as completed or failed."""
|
||||
job.status = "failed" if error_message else "completed"
|
||||
job.completed_at = datetime.now(UTC)
|
||||
job.pages_scanned = pages_scanned
|
||||
job.cookies_found = cookies_found
|
||||
job.error_message = error_message
|
||||
await db.flush()
|
||||
return job
|
||||
|
||||
|
||||
async def add_scan_result(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    page_url: str,
    cookie_name: str,
    cookie_domain: str,
    storage_type: str = "cookie",
    attributes: dict | None = None,
    script_source: str | None = None,
    auto_category: str | None = None,
    initiator_chain: list[str] | None = None,
) -> ScanResult:
    """Persist a single cookie discovery made during a scan and flush it."""
    row = ScanResult(
        scan_job_id=scan_job_id,
        page_url=page_url,
        cookie_name=cookie_name,
        cookie_domain=cookie_domain,
        storage_type=storage_type,
        attributes=attributes,
        script_source=script_source,
        auto_category=auto_category,
        initiator_chain=initiator_chain,
    )
    db.add(row)
    await db.flush()
    return row
|
||||
|
||||
|
||||
async def get_previous_completed_scan(
    db: AsyncSession,
    *,
    site_id: uuid.UUID,
    before_scan_id: uuid.UUID,
) -> ScanJob | None:
    """Return the newest completed scan that predates *before_scan_id*.

    Returns None when the reference scan doesn't exist or when no earlier
    completed scan is found for the site.
    """
    # Anchor on the reference scan's creation time.
    anchor = (
        await db.execute(select(ScanJob.created_at).where(ScanJob.id == before_scan_id))
    ).scalar_one_or_none()
    if anchor is None:
        return None

    stmt = (
        select(ScanJob)
        .where(
            ScanJob.site_id == site_id,
            ScanJob.status == "completed",
            ScanJob.id != before_scan_id,
            ScanJob.created_at < anchor,
        )
        .order_by(ScanJob.created_at.desc())
        .limit(1)
    )
    return (await db.execute(stmt)).scalar_one_or_none()
|
||||
|
||||
|
||||
def _result_key(r: ScanResult) -> tuple[str, str, str]:
    """Identity key for a scan result: (cookie name, domain, storage type)."""
    name = r.cookie_name
    domain = r.cookie_domain
    storage = r.storage_type
    return (name, domain, storage)
|
||||
|
||||
|
||||
async def compute_scan_diff(
    db: AsyncSession,
    *,
    current_scan_id: uuid.UUID,
    site_id: uuid.UUID,
) -> ScanDiffResponse:
    """Diff the current scan's cookies against the previous completed scan.

    Categorizes cookies as new (present now, absent before), removed
    (absent now, present before), or changed (present in both but with a
    different script source, auto category, or attribute set). When no
    earlier completed scan exists, every current cookie is reported as
    new and the diff lists for removed/changed are omitted.
    """

    def make_item(row: ScanResult, status: DiffStatus, details: str | None = None) -> CookieDiffItem:
        # Only pass `details` when one was produced, so the model's own
        # default applies otherwise (matches the original call sites).
        kwargs: dict = {
            "name": row.cookie_name,
            "domain": row.cookie_domain,
            "storage_type": row.storage_type,
            "diff_status": status,
        }
        if details is not None:
            kwargs["details"] = details
        return CookieDiffItem(**kwargs)

    baseline = await get_previous_completed_scan(
        db, site_id=site_id, before_scan_id=current_scan_id
    )

    # Current scan results, indexed by cookie identity (insertion order
    # preserved so diff output order matches query order).
    current_rows = list(
        (
            await db.execute(select(ScanResult).where(ScanResult.scan_job_id == current_scan_id))
        ).scalars().all()
    )
    current_by_key = {_result_key(row): row for row in current_rows}

    if baseline is None:
        # First scan for this site — everything counts as new.
        fresh = [
            make_item(row, DiffStatus.NEW, "First scan — no previous data")
            for row in current_rows
        ]
        return ScanDiffResponse(
            current_scan_id=current_scan_id,
            previous_scan_id=None,
            new_cookies=fresh,
            total_new=len(fresh),
        )

    previous_rows = list(
        (
            await db.execute(select(ScanResult).where(ScanResult.scan_job_id == baseline.id))
        ).scalars().all()
    )
    previous_by_key = {_result_key(row): row for row in previous_rows}

    # In current but not previous → new.
    new_cookies = [
        make_item(row, DiffStatus.NEW)
        for key, row in current_by_key.items()
        if key not in previous_by_key
    ]
    # In previous but not current → removed.
    removed_cookies = [
        make_item(row, DiffStatus.REMOVED)
        for key, row in previous_by_key.items()
        if key not in current_by_key
    ]

    # In both → changed when any tracked field differs.
    changed_cookies: list[CookieDiffItem] = []
    for key, curr in current_by_key.items():
        prev = previous_by_key.get(key)
        if prev is None:
            continue
        deltas: list[str] = []
        if curr.script_source != prev.script_source:
            deltas.append("script_source changed")
        if curr.auto_category != prev.auto_category:
            deltas.append("auto_category changed")
        # Treat NULL attributes the same as an empty dict for comparison.
        if (curr.attributes or {}) != (prev.attributes or {}):
            deltas.append("attributes changed")
        if deltas:
            changed_cookies.append(make_item(curr, DiffStatus.CHANGED, "; ".join(deltas)))

    return ScanDiffResponse(
        current_scan_id=current_scan_id,
        previous_scan_id=baseline.id,
        new_cookies=new_cookies,
        removed_cookies=removed_cookies,
        changed_cookies=changed_cookies,
        total_new=len(new_cookies),
        total_removed=len(removed_cookies),
        total_changed=len(changed_cookies),
    )
|
||||
|
||||
|
||||
async def sync_scan_results_to_cookies(
    db: AsyncSession,
    *,
    scan_job_id: uuid.UUID,
    site_id: uuid.UUID,
) -> int:
    """Upsert scan results into the site's cookie inventory.

    Creates new Cookie records for newly discovered items or updates
    last_seen_at for existing ones. Returns the number of new cookies.

    The site's existing inventory is loaded once up front (instead of one
    SELECT per scan result as before), so the sync issues two queries
    regardless of how many cookies a scan found. Duplicate
    (name, domain, storage_type) entries within a single scan update the
    same Cookie row and are counted once.
    """
    results = await db.execute(select(ScanResult).where(ScanResult.scan_job_id == scan_job_id))
    items = list(results.scalars().all())

    # Load the whole inventory for the site in one query and index it by
    # the same identity key the scanner uses.
    existing = await db.execute(select(Cookie).where(Cookie.site_id == site_id))
    inventory: dict[tuple[str, str, str], Cookie] = {
        (c.name, c.domain, c.storage_type): c for c in existing.scalars().all()
    }

    now_iso = datetime.now(UTC).isoformat()
    new_count = 0

    for item in items:
        key = (item.cookie_name, item.cookie_domain, item.storage_type)
        cookie = inventory.get(key)
        if cookie:
            cookie.last_seen_at = now_iso
        else:
            cookie = Cookie(
                site_id=site_id,
                name=item.cookie_name,
                domain=item.cookie_domain,
                storage_type=item.storage_type,
                review_status="pending",
                first_seen_at=now_iso,
                last_seen_at=now_iso,
            )
            db.add(cookie)
            # Track the new row so a duplicate discovery later in the same
            # scan updates it instead of creating a second record.
            inventory[key] = cookie
            new_count += 1

    await db.flush()
    return new_count
|
||||
|
||||
|
||||
async def get_sites_due_for_scan(db: AsyncSession) -> list[Site]:
    """Return active, non-deleted sites that have a scan schedule configured.

    NOTE(review): despite the name, no timing check happens here — any
    site whose ``SiteConfig.scan_schedule_cron`` is non-NULL is returned,
    regardless of when it was last scanned. The periodic caller will
    therefore trigger a scan for every scheduled site on each invocation;
    actual cron evaluation against the last scan time is a TODO.
    """
    # Imported locally — presumably to avoid a circular import at module
    # load time; confirm before hoisting to the top of the file.
    from src.models.site_config import SiteConfig

    # Find sites with a cron schedule. Joining through SiteConfig could
    # yield duplicate Site rows if a site ever had multiple config rows —
    # assumes one config per site (TODO confirm).
    result = await db.execute(
        select(Site)
        .join(SiteConfig, SiteConfig.site_id == Site.id)
        .where(
            Site.deleted_at.is_(None),
            Site.is_active.is_(True),
            SiteConfig.scan_schedule_cron.isnot(None),
        )
    )
    return list(result.scalars().all())
|
||||
0
apps/api/src/tasks/__init__.py
Normal file
0
apps/api/src/tasks/__init__.py
Normal file
87
apps/api/src/tasks/retention.py
Normal file
87
apps/api/src/tasks/retention.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Consent record retention purge.
|
||||
|
||||
Deletes consent records older than each site's configured
|
||||
``consent_retention_days``. Sites with no retention configured are
|
||||
skipped — operators must explicitly opt in per site (or set it at the
|
||||
org/system level and let the cascade resolve it).
|
||||
|
||||
Scheduled by ``celery beat`` daily at 01:00 UTC via the entry in
|
||||
``src.celery_app.beat_schedule``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from src.celery_app import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _purge() -> dict[str, int]:
    """Delete expired consent records across all sites with retention set.

    For each ``SiteConfig`` whose ``consent_retention_days`` is a positive
    number, deletes that site's ``ConsentRecord`` rows with
    ``consented_at`` older than now minus the retention window. All
    deletions are committed in a single transaction at the end.

    Returns a summary ``{"sites_processed": N, "records_deleted": M}``.
    """
    from sqlalchemy import delete, select
    from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine

    from src.config.settings import get_settings
    from src.models.consent import ConsentRecord
    from src.models.site_config import SiteConfig

    settings = get_settings()
    engine = create_async_engine(settings.database_url, echo=False)

    sites_processed = 0
    records_deleted = 0

    # Fix: previously the engine was only disposed on the success path, so
    # an exception mid-purge leaked the connection pool. Dispose in a
    # finally, matching the pattern used by the scanner tasks.
    try:
        async with AsyncSession(engine, expire_on_commit=False) as session:
            configs = (
                (
                    await session.execute(
                        select(SiteConfig).where(SiteConfig.consent_retention_days.isnot(None)),
                    )
                )
                .scalars()
                .all()
            )

            now = datetime.now(UTC)
            for cfg in configs:
                retention_days = cfg.consent_retention_days
                # Treat zero/negative values as "not configured" — never
                # purge with a non-positive window.
                if not retention_days or retention_days <= 0:
                    continue
                cutoff = now - timedelta(days=retention_days)
                result = await session.execute(
                    delete(ConsentRecord).where(
                        ConsentRecord.site_id == cfg.site_id,
                        ConsentRecord.consented_at < cutoff,
                    ),
                )
                deleted = result.rowcount or 0
                records_deleted += deleted
                sites_processed += 1
                if deleted:
                    logger.info(
                        "retention.purged",
                        extra={
                            "site_id": str(cfg.site_id),
                            "retention_days": retention_days,
                            "deleted": deleted,
                            "cutoff": cutoff.isoformat(),
                        },
                    )

            await session.commit()
    finally:
        await engine.dispose()
    return {"sites_processed": sites_processed, "records_deleted": records_deleted}
|
||||
|
||||
|
||||
@app.task(name="src.tasks.retention.purge_expired_consent_records")
|
||||
def purge_expired_consent_records() -> dict[str, int]:
|
||||
"""Celery entrypoint for the retention purge."""
|
||||
return asyncio.run(_purge())
|
||||
308
apps/api/src/tasks/scanner.py
Normal file
308
apps/api/src/tasks/scanner.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Celery tasks for scan job execution and scheduling.
|
||||
|
||||
The run_scan task calls the scanner HTTP service to execute a Playwright
|
||||
crawl, then processes the results: stores scan results, runs auto-
|
||||
classification, syncs discovered cookies to the site inventory, and
|
||||
computes diffs against the previous scan.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
|
||||
from src.celery_app import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@app.task(name="src.tasks.scanner.run_scan", bind=True, max_retries=2)
|
||||
def run_scan(self, scan_job_id: str, site_id: str) -> dict:
|
||||
"""Execute a scan job by calling the scanner service.
|
||||
|
||||
1. Transition job to 'running'
|
||||
2. Look up site domain
|
||||
3. Call scanner HTTP service with the domain
|
||||
4. Store scan results and run auto-classification
|
||||
5. Sync discovered cookies to the site inventory
|
||||
6. Mark job as completed
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.models.scan import ScanJob
|
||||
from src.models.site import Site
|
||||
from src.services.classification import classify_single_cookie
|
||||
from src.services.scanner import (
|
||||
add_scan_result,
|
||||
complete_scan_job,
|
||||
start_scan_job,
|
||||
sync_scan_results_to_cookies,
|
||||
)
|
||||
|
||||
settings = get_settings()
|
||||
job_uuid = uuid.UUID(scan_job_id)
|
||||
site_uuid = uuid.UUID(site_id)
|
||||
|
||||
async def _execute() -> dict:
|
||||
engine = create_async_engine(settings.database_url, echo=False)
|
||||
async with AsyncSession(engine, expire_on_commit=False) as db:
|
||||
try:
|
||||
# Load the job
|
||||
result = await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
|
||||
job = result.scalar_one_or_none()
|
||||
if job is None:
|
||||
return {"error": "Scan job not found"}
|
||||
|
||||
# Load the site to get the domain
|
||||
site_result = await db.execute(select(Site).where(Site.id == site_uuid))
|
||||
site = site_result.scalar_one_or_none()
|
||||
if site is None:
|
||||
return {"error": "Site not found"}
|
||||
|
||||
# Transition to running
|
||||
await start_scan_job(db, job)
|
||||
await db.commit()
|
||||
|
||||
# Call the scanner service
|
||||
scanner_url = f"{settings.scanner_service_url}/scan"
|
||||
max_pages = job.pages_total or 50
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
timeout=httpx.Timeout(settings.scanner_timeout_seconds)
|
||||
) as client:
|
||||
resp = await client.post(
|
||||
scanner_url,
|
||||
json={
|
||||
"domain": site.domain,
|
||||
"max_pages": max_pages,
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
scan_data = resp.json()
|
||||
|
||||
# Store scan results
|
||||
cookies = scan_data.get("cookies", [])
|
||||
pages_crawled = scan_data.get("pages_crawled", 0)
|
||||
|
||||
for cookie in cookies:
|
||||
# Auto-classify the cookie
|
||||
category = await classify_single_cookie(
|
||||
db,
|
||||
site_id=site_uuid,
|
||||
cookie_name=cookie["name"],
|
||||
cookie_domain=cookie["domain"],
|
||||
)
|
||||
|
||||
await add_scan_result(
|
||||
db,
|
||||
scan_job_id=job_uuid,
|
||||
page_url=cookie.get("page_url", ""),
|
||||
cookie_name=cookie["name"],
|
||||
cookie_domain=cookie["domain"],
|
||||
storage_type=cookie.get("storage_type", "cookie"),
|
||||
attributes={
|
||||
"path": cookie.get("path"),
|
||||
"http_only": cookie.get("http_only"),
|
||||
"secure": cookie.get("secure"),
|
||||
"same_site": cookie.get("same_site"),
|
||||
"value_length": cookie.get("value_length", 0),
|
||||
},
|
||||
script_source=cookie.get("script_source"),
|
||||
auto_category=category.category_slug if category else None,
|
||||
initiator_chain=cookie.get("initiator_chain") or None,
|
||||
)
|
||||
|
||||
await db.commit()
|
||||
|
||||
# Mark job as completed
|
||||
await complete_scan_job(
|
||||
db,
|
||||
job,
|
||||
pages_scanned=pages_crawled,
|
||||
cookies_found=len(cookies),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# Sync results to cookie inventory
|
||||
new_cookies = await sync_scan_results_to_cookies(
|
||||
db,
|
||||
scan_job_id=job_uuid,
|
||||
site_id=site_uuid,
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
logger.info(
|
||||
"Scan %s completed: %d pages, %d cookies, %d new",
|
||||
scan_job_id,
|
||||
pages_crawled,
|
||||
len(cookies),
|
||||
new_cookies,
|
||||
)
|
||||
|
||||
return {
|
||||
"scan_job_id": scan_job_id,
|
||||
"status": "completed",
|
||||
"pages_scanned": pages_crawled,
|
||||
"cookies_found": len(cookies),
|
||||
"new_cookies_synced": new_cookies,
|
||||
}
|
||||
|
||||
except httpx.HTTPError as exc:
|
||||
logger.error("Scanner service error for job %s: %s", scan_job_id, exc)
|
||||
await db.rollback()
|
||||
# Only mark failed on the final retry; otherwise let the
|
||||
# retry set status back to "running" cleanly.
|
||||
if self.request.retries >= self.max_retries:
|
||||
await _mark_failed(db, job_uuid, f"Scanner service error: {exc}")
|
||||
raise self.retry(exc=exc, countdown=30) from exc
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception("Scan task failed for job %s", scan_job_id)
|
||||
await db.rollback()
|
||||
await _mark_failed(db, job_uuid, str(exc))
|
||||
return {"error": str(exc)}
|
||||
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
return asyncio.run(_execute())
|
||||
|
||||
|
||||
async def _mark_failed(db, job_uuid: uuid.UUID, message: str) -> None:
    """Best-effort transition of a scan job to the failed state.

    Any error during the bookkeeping is logged and swallowed so it never
    masks the original task exception.
    """
    from sqlalchemy import select

    from src.models.scan import ScanJob
    from src.services.scanner import complete_scan_job

    try:
        job = (
            await db.execute(select(ScanJob).where(ScanJob.id == job_uuid))
        ).scalar_one_or_none()
        if job is None:
            return
        await complete_scan_job(db, job, error_message=message)
        await db.commit()
    except Exception:
        logger.exception("Failed to mark scan job %s as failed", job_uuid)
|
||||
|
||||
|
||||
@app.task(name="src.tasks.scanner.check_scheduled_scans")
|
||||
def check_scheduled_scans() -> dict:
|
||||
"""Periodic task: check which sites are due for a scheduled scan.
|
||||
|
||||
Runs every 15 minutes via Celery Beat. For each site with a
|
||||
scan_schedule_cron, checks if a scan is overdue and triggers one.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.services.scanner import create_scan_job, get_sites_due_for_scan
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
async def _check() -> dict:
|
||||
engine = create_async_engine(settings.database_url, echo=False)
|
||||
async with AsyncSession(engine, expire_on_commit=False) as db:
|
||||
try:
|
||||
sites = await get_sites_due_for_scan(db)
|
||||
triggered = 0
|
||||
|
||||
for site in sites:
|
||||
job = await create_scan_job(db, site_id=site.id, trigger="scheduled")
|
||||
await db.commit()
|
||||
# Dispatch the scan task
|
||||
run_scan.delay(str(job.id), str(site.id))
|
||||
triggered += 1
|
||||
|
||||
return {"sites_checked": len(sites), "scans_triggered": triggered}
|
||||
except Exception:
|
||||
await db.rollback()
|
||||
raise
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
return asyncio.run(_check())
|
||||
|
||||
|
||||
@app.task(name="src.tasks.scanner.recover_stale_scans")
|
||||
def recover_stale_scans() -> dict:
|
||||
"""Periodic task: detect and recover scan jobs stuck in pending/running.
|
||||
|
||||
- Jobs stuck in 'pending' for >5 minutes are re-dispatched to Celery.
|
||||
- Jobs stuck in 'running' for >10 minutes are marked as failed.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from sqlalchemy import or_, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from src.config.settings import get_settings
|
||||
from src.models.scan import ScanJob
|
||||
from src.services.scanner import complete_scan_job
|
||||
|
||||
settings = get_settings()
|
||||
|
||||
async def _recover() -> dict:
|
||||
engine = create_async_engine(settings.database_url, echo=False)
|
||||
async with AsyncSession(engine, expire_on_commit=False) as db:
|
||||
try:
|
||||
now = datetime.now(UTC)
|
||||
stale_pending_cutoff = now - timedelta(minutes=5)
|
||||
stale_running_cutoff = now - timedelta(minutes=10)
|
||||
|
||||
result = await db.execute(
|
||||
select(ScanJob).where(
|
||||
or_(
|
||||
# Pending too long — likely never picked up
|
||||
(ScanJob.status == "pending")
|
||||
& (ScanJob.created_at < stale_pending_cutoff),
|
||||
# Running too long — likely worker died
|
||||
(ScanJob.status == "running")
|
||||
& (ScanJob.started_at < stale_running_cutoff),
|
||||
)
|
||||
)
|
||||
)
|
||||
stale_jobs = list(result.scalars().all())
|
||||
|
||||
redispatched = 0
|
||||
failed = 0
|
||||
|
||||
for job in stale_jobs:
|
||||
if job.status == "pending":
|
||||
# Re-dispatch to Celery
|
||||
logger.warning("Re-dispatching stale pending scan job %s", job.id)
|
||||
run_scan.delay(str(job.id), str(job.site_id))
|
||||
redispatched += 1
|
||||
elif job.status == "running":
|
||||
# Mark as failed — the worker likely died
|
||||
logger.warning("Failing stale running scan job %s", job.id)
|
||||
await complete_scan_job(
|
||||
db,
|
||||
job,
|
||||
error_message=(
|
||||
"Job timed out (running too long, worker may have crashed)"
|
||||
),
|
||||
)
|
||||
failed += 1
|
||||
|
||||
await db.commit()
|
||||
return {
|
||||
"stale_jobs_found": len(stale_jobs),
|
||||
"redispatched": redispatched,
|
||||
"failed": failed,
|
||||
}
|
||||
except Exception:
|
||||
await db.rollback()
|
||||
raise
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
return asyncio.run(_recover())
|
||||
Reference in New Issue
Block a user