feat: dynamic per-site CORS for public banner endpoints
Some checks failed
CI / API Lint (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
Some checks failed
CI / API Lint (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
DynamicCORSMedium middleware resolves the calling site's registered
domains (Site.domain + Site.additional_domains) and returns the
appropriate Access-Control-Allow-Origin header for banner script
requests to public endpoints:
- GET /api/v1/config/sites/{site_id}
- GET /api/v1/translations/{site_id}/{locale}
- POST /api/v1/consent/
Instead of hardcoding ALLOWED_ORIGINS env var, each merchant website
automatically gets CORS access as long as its domain is registered
in the site configuration.
This commit is contained in:
@@ -10,6 +10,7 @@ from src.config.settings import get_settings
|
||||
from src.extensions.registry import discover_extensions, get_registry
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from src.middleware.dynamic_cors import DynamicCORSMedium
|
||||
from src.routers import (
|
||||
auth,
|
||||
compliance,
|
||||
@@ -116,7 +117,10 @@ def create_app() -> FastAPI:
|
||||
auth_requests_per_minute=10,
|
||||
)
|
||||
|
||||
# CORS
|
||||
# CORS — DynamicCORSMedium must come BEFORE CORSMiddleware so it can
|
||||
# add per-site allowed origins for public banner endpoints
|
||||
app.add_middleware(DynamicCORSMedium)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.allowed_origins_list,
|
||||
|
||||
122
apps/api/src/middleware/dynamic_cors.py
Normal file
122
apps/api/src/middleware/dynamic_cors.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Dynamic CORS middleware for public API endpoints.
|
||||
|
||||
Public endpoints (banner config, translations) are called from merchant
|
||||
websites. Instead of hardcoding allowed origins, this middleware resolves
|
||||
the calling origin dynamically from the site's registered domains.
|
||||
|
||||
Only handles the CORS preflight (OPTIONS) and CORS headers for public
|
||||
routes. All other requests pass through unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
|
||||
|
||||
# Routes that are considered "public" (called by the banner script)
|
||||
PUBLIC_PATH_PREFIXES = (
|
||||
"/api/v1/config/sites/",
|
||||
"/api/v1/translations/",
|
||||
"/api/v1/consent/",
|
||||
)
|
||||
|
||||
|
||||
def _is_public_path(path: str) -> bool:
|
||||
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
|
||||
|
||||
|
||||
def _extract_site_id_from_path(path: str) -> str | None:
|
||||
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
|
||||
parts = path.split("/")
|
||||
# /api/v1/config/sites/{site_id} or /api/v1/translations/{site_id}/{locale}
|
||||
if len(parts) >= 5:
|
||||
prefix = "/".join(parts[:4]) # e.g. /api/v1/config/sites
|
||||
if prefix in ("/api/v1/config/sites", "/api/v1/translations"):
|
||||
return parts[4]
|
||||
return None
|
||||
|
||||
|
||||
async def _get_allowed_origins_for_site(
|
||||
db: AsyncSession, site_id: str
|
||||
) -> set[str]:
|
||||
"""Return all registered domains (primary + additional) for a site."""
|
||||
try:
|
||||
site_uuid = uuid.UUID(site_id)
|
||||
except (ValueError, TypeError):
|
||||
return set()
|
||||
|
||||
result = await db.execute(
|
||||
select(Site).where(
|
||||
Site.id == site_uuid,
|
||||
Site.is_active.is_(True),
|
||||
Site.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
site = result.scalar_one_or_none()
|
||||
if not site:
|
||||
return set()
|
||||
|
||||
origins: set[str] = {site.domain}
|
||||
if site.additional_domains:
|
||||
origins.update(site.additional_domains)
|
||||
return origins
|
||||
|
||||
|
||||
class DynamicCORSMedium(BaseHTTPMiddleware):
|
||||
"""Add CORS headers dynamically based on the calling site's registered domains.
|
||||
|
||||
This middleware only acts on requests to public endpoints that may be
|
||||
called from merchant websites (banner script calls). It reads the
|
||||
``Origin`` header, resolves the site_id from the URL path, and checks
|
||||
whether that origin matches a registered domain for the site.
|
||||
"""
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
# Only handle public paths
|
||||
if not _is_public_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
origin = request.headers.get("origin", "")
|
||||
|
||||
# If no Origin header, nothing to do
|
||||
if not origin:
|
||||
return await call_next(request)
|
||||
|
||||
site_id = _extract_site_id_from_path(request.url.path)
|
||||
|
||||
if site_id:
|
||||
# Async db session required — get from request state if already set
|
||||
# by the dependency injection, otherwise create a new one
|
||||
db = request.state.db if hasattr(request.state, "db") else None
|
||||
if db is None:
|
||||
# Fall back to a new session for this lookup only
|
||||
allowed: set[str] = set()
|
||||
async for session in get_db():
|
||||
allowed = await _get_allowed_origins_for_site(session, site_id)
|
||||
break
|
||||
else:
|
||||
allowed = await _get_allowed_origins_for_site(db, site_id)
|
||||
|
||||
if origin in allowed:
|
||||
response = await call_next(request)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Allow-Methods"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
return response
|
||||
|
||||
# Site not found or origin not registered — let the normal
|
||||
# CORS middleware/global allowed_origins handle it
|
||||
return await call_next(request)
|
||||
Reference in New Issue
Block a user