feat: dynamic per-site CORS for public banner endpoints
Some checks failed
CI / API Lint (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled

DynamicCORSMedium middleware resolves the calling site's registered
domains (Site.domain + Site.additional_domains) and returns the
appropriate Access-Control-Allow-Origin header for banner script
requests to public endpoints:

- GET /api/v1/config/sites/{site_id}
- GET /api/v1/translations/{site_id}/{locale}
- POST /api/v1/consent/

Instead of hardcoding ALLOWED_ORIGINS env var, each merchant website
automatically gets CORS access as long as its domain is registered
in the site configuration.
This commit is contained in:
Kunthawat Greethong
2026-06-15 18:36:23 +07:00
parent 683aa2379d
commit 7973cb321e
2 changed files with 127 additions and 1 deletions

View File

@@ -10,6 +10,7 @@ from src.config.settings import get_settings
from src.extensions.registry import discover_extensions, get_registry from src.extensions.registry import discover_extensions, get_registry
from src.middleware.rate_limit import RateLimitMiddleware from src.middleware.rate_limit import RateLimitMiddleware
from src.middleware.security_headers import SecurityHeadersMiddleware from src.middleware.security_headers import SecurityHeadersMiddleware
from src.middleware.dynamic_cors import DynamicCORSMedium
from src.routers import ( from src.routers import (
auth, auth,
compliance, compliance,
@@ -116,7 +117,10 @@ def create_app() -> FastAPI:
auth_requests_per_minute=10, auth_requests_per_minute=10,
) )
# CORS # CORS — DynamicCORSMedium must come BEFORE CORSMiddleware so it can
# add per-site allowed origins for public banner endpoints
app.add_middleware(DynamicCORSMedium)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.allowed_origins_list, allow_origins=settings.allowed_origins_list,

View File

@@ -0,0 +1,122 @@
"""Dynamic CORS middleware for public API endpoints.
Public endpoints (banner config, translations) are called from merchant
websites. Instead of hardcoding allowed origins, this middleware resolves
the calling origin dynamically from the site's registered domains.
Only handles the CORS preflight (OPTIONS) and CORS headers for public
routes. All other requests pass through unchanged.
"""
from __future__ import annotations
import uuid
from typing import Callable
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from src.db import get_db
from src.models.site import Site
# Routes that are considered "public" (called by the banner script)
PUBLIC_PATH_PREFIXES = (
"/api/v1/config/sites/",
"/api/v1/translations/",
"/api/v1/consent/",
)
def _is_public_path(path: str) -> bool:
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
def _extract_site_id_from_path(path: str) -> str | None:
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
parts = path.split("/")
# /api/v1/config/sites/{site_id} or /api/v1/translations/{site_id}/{locale}
if len(parts) >= 5:
prefix = "/".join(parts[:4]) # e.g. /api/v1/config/sites
if prefix in ("/api/v1/config/sites", "/api/v1/translations"):
return parts[4]
return None
async def _get_allowed_origins_for_site(
db: AsyncSession, site_id: str
) -> set[str]:
"""Return all registered domains (primary + additional) for a site."""
try:
site_uuid = uuid.UUID(site_id)
except (ValueError, TypeError):
return set()
result = await db.execute(
select(Site).where(
Site.id == site_uuid,
Site.is_active.is_(True),
Site.deleted_at.is_(None),
)
)
site = result.scalar_one_or_none()
if not site:
return set()
origins: set[str] = {site.domain}
if site.additional_domains:
origins.update(site.additional_domains)
return origins
class DynamicCORSMedium(BaseHTTPMiddleware):
"""Add CORS headers dynamically based on the calling site's registered domains.
This middleware only acts on requests to public endpoints that may be
called from merchant websites (banner script calls). It reads the
``Origin`` header, resolves the site_id from the URL path, and checks
whether that origin matches a registered domain for the site.
"""
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
# Only handle public paths
if not _is_public_path(request.url.path):
return await call_next(request)
origin = request.headers.get("origin", "")
# If no Origin header, nothing to do
if not origin:
return await call_next(request)
site_id = _extract_site_id_from_path(request.url.path)
if site_id:
# Async db session required — get from request state if already set
# by the dependency injection, otherwise create a new one
db = request.state.db if hasattr(request.state, "db") else None
if db is None:
# Fall back to a new session for this lookup only
allowed: set[str] = set()
async for session in get_db():
allowed = await _get_allowed_origins_for_site(session, site_id)
break
else:
allowed = await _get_allowed_origins_for_site(db, site_id)
if origin in allowed:
response = await call_next(request)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Allow-Methods"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
return response
# Site not found or origin not registered — let the normal
# CORS middleware/global allowed_origins handle it
return await call_next(request)