From 7973cb321ee702133681d359d3a73111609c826a Mon Sep 17 00:00:00 2001 From: Kunthawat Greethong Date: Mon, 15 Jun 2026 18:36:23 +0700 Subject: [PATCH] feat: dynamic per-site CORS for public banner endpoints DynamicCORSMedium middleware resolves the calling site's registered domains (Site.domain + Site.additional_domains) and returns the appropriate Access-Control-Allow-Origin header for banner script requests to public endpoints: - GET /api/v1/config/sites/{site_id} - GET /api/v1/translations/{site_id}/{locale} - POST /api/v1/consent/ Instead of hardcoding ALLOWED_ORIGINS env var, each merchant website automatically gets CORS access as long as its domain is registered in the site configuration. --- apps/api/src/main.py | 6 +- apps/api/src/middleware/dynamic_cors.py | 122 ++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) create mode 100644 apps/api/src/middleware/dynamic_cors.py diff --git a/apps/api/src/main.py b/apps/api/src/main.py index 808ea98..b1ebdc6 100644 --- a/apps/api/src/main.py +++ b/apps/api/src/main.py @@ -10,6 +10,7 @@ from src.config.settings import get_settings from src.extensions.registry import discover_extensions, get_registry from src.middleware.rate_limit import RateLimitMiddleware from src.middleware.security_headers import SecurityHeadersMiddleware +from src.middleware.dynamic_cors import DynamicCORSMedium from src.routers import ( auth, compliance, @@ -116,7 +117,10 @@ def create_app() -> FastAPI: auth_requests_per_minute=10, ) - # CORS + # CORS — DynamicCORSMedium must come BEFORE CORSMiddleware so it can + # add per-site allowed origins for public banner endpoints + app.add_middleware(DynamicCORSMedium) + app.add_middleware( CORSMiddleware, allow_origins=settings.allowed_origins_list, diff --git a/apps/api/src/middleware/dynamic_cors.py b/apps/api/src/middleware/dynamic_cors.py new file mode 100644 index 0000000..dbc7d5a --- /dev/null +++ b/apps/api/src/middleware/dynamic_cors.py @@ -0,0 +1,122 @@ +"""Dynamic CORS middleware for public API endpoints. + +Public endpoints (banner config, translations) are called from merchant +websites. Instead of hardcoding allowed origins, this middleware resolves +the calling origin dynamically from the site's registered domains. + +Only handles the CORS preflight (OPTIONS) and CORS headers for public +routes. All other requests pass through unchanged. +""" + +from __future__ import annotations + +import uuid +from typing import Callable + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response + +from src.db import get_db +from src.models.site import Site + + +# Routes that are considered "public" (called by the banner script) +PUBLIC_PATH_PREFIXES = ( + "/api/v1/config/sites/", + "/api/v1/translations/", + "/api/v1/consent/", +) + + +def _is_public_path(path: str) -> bool: + return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES) + + +def _extract_site_id_from_path(path: str) -> str | None: + """Extract site_id UUID from public config path like /api/v1/config/sites/{id}.""" + parts = path.split("/") + # /api/v1/config/sites/{site_id} or /api/v1/translations/{site_id}/{locale} + if len(parts) >= 5: + prefix = "/".join(parts[:4]) # e.g. /api/v1/config/sites + if prefix in ("/api/v1/config/sites", "/api/v1/translations"): + return parts[4] + return None + + +async def _get_allowed_origins_for_site( + db: AsyncSession, site_id: str +) -> set[str]: + """Return all registered domains (primary + additional) for a site.""" + try: + site_uuid = uuid.UUID(site_id) + except (ValueError, TypeError): + return set() + + result = await db.execute( + select(Site).where( + Site.id == site_uuid, + Site.is_active.is_(True), + Site.deleted_at.is_(None), + ) + ) + site = result.scalar_one_or_none() + if not site: + return set() + + origins: set[str] = {site.domain} + if site.additional_domains: + origins.update(site.additional_domains) + return origins + + +class DynamicCORSMedium(BaseHTTPMiddleware): + """Add CORS headers dynamically based on the calling site's registered domains. + + This middleware only acts on requests to public endpoints that may be + called from merchant websites (banner script calls). It reads the + ``Origin`` header, resolves the site_id from the URL path, and checks + whether that origin matches a registered domain for the site. + """ + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + # Only handle public paths + if not _is_public_path(request.url.path): + return await call_next(request) + + origin = request.headers.get("origin", "") + + # If no Origin header, nothing to do + if not origin: + return await call_next(request) + + site_id = _extract_site_id_from_path(request.url.path) + + if site_id: + # Async db session required — get from request state if already set + # by the dependency injection, otherwise create a new one + db = request.state.db if hasattr(request.state, "db") else None + if db is None: + # Fall back to a new session for this lookup only + allowed: set[str] = set() + async for session in get_db(): + allowed = await _get_allowed_origins_for_site(session, site_id) + break + else: + allowed = await _get_allowed_origins_for_site(db, site_id) + + if origin in allowed: + response = await call_next(request) + response.headers["Access-Control-Allow-Origin"] = origin + response.headers["Access-Control-Allow-Credentials"] = "true" + response.headers["Access-Control-Allow-Methods"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + return response + + # Site not found or origin not registered — let the normal + # CORS middleware/global allowed_origins handle it + return await call_next(request)