feat: dynamic per-site CORS for public banner endpoints
Some checks failed
CI / API Lint (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
Some checks failed
CI / API Lint (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
DynamicCORSMedium middleware resolves the calling site's registered
domains (Site.domain + Site.additional_domains) and returns the
appropriate Access-Control-Allow-Origin header for banner script
requests to public endpoints:
- GET /api/v1/config/sites/{site_id}
- GET /api/v1/translations/{site_id}/{locale}
- POST /api/v1/consent/
Instead of hardcoding ALLOWED_ORIGINS env var, each merchant website
automatically gets CORS access as long as its domain is registered
in the site configuration.
This commit is contained in:
@@ -10,6 +10,7 @@ from src.config.settings import get_settings
|
|||||||
from src.extensions.registry import discover_extensions, get_registry
|
from src.extensions.registry import discover_extensions, get_registry
|
||||||
from src.middleware.rate_limit import RateLimitMiddleware
|
from src.middleware.rate_limit import RateLimitMiddleware
|
||||||
from src.middleware.security_headers import SecurityHeadersMiddleware
|
from src.middleware.security_headers import SecurityHeadersMiddleware
|
||||||
|
from src.middleware.dynamic_cors import DynamicCORSMedium
|
||||||
from src.routers import (
|
from src.routers import (
|
||||||
auth,
|
auth,
|
||||||
compliance,
|
compliance,
|
||||||
@@ -116,7 +117,10 @@ def create_app() -> FastAPI:
|
|||||||
auth_requests_per_minute=10,
|
auth_requests_per_minute=10,
|
||||||
)
|
)
|
||||||
|
|
||||||
# CORS
|
# CORS — DynamicCORSMedium must come BEFORE CORSMiddleware so it can
|
||||||
|
# add per-site allowed origins for public banner endpoints
|
||||||
|
app.add_middleware(DynamicCORSMedium)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
allow_origins=settings.allowed_origins_list,
|
allow_origins=settings.allowed_origins_list,
|
||||||
|
|||||||
122
apps/api/src/middleware/dynamic_cors.py
Normal file
122
apps/api/src/middleware/dynamic_cors.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Dynamic CORS middleware for public API endpoints.
|
||||||
|
|
||||||
|
Public endpoints (banner config, translations) are called from merchant
|
||||||
|
websites. Instead of hardcoding allowed origins, this middleware resolves
|
||||||
|
the calling origin dynamically from the site's registered domains.
|
||||||
|
|
||||||
|
Only handles the CORS preflight (OPTIONS) and CORS headers for public
|
||||||
|
routes. All other requests pass through unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import Response
|
||||||
|
|
||||||
|
from src.db import get_db
|
||||||
|
from src.models.site import Site
|
||||||
|
|
||||||
|
|
||||||
|
# Routes that are considered "public" (called by the banner script)
|
||||||
|
PUBLIC_PATH_PREFIXES = (
|
||||||
|
"/api/v1/config/sites/",
|
||||||
|
"/api/v1/translations/",
|
||||||
|
"/api/v1/consent/",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_public_path(path: str) -> bool:
|
||||||
|
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_site_id_from_path(path: str) -> str | None:
|
||||||
|
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
|
||||||
|
parts = path.split("/")
|
||||||
|
# /api/v1/config/sites/{site_id} or /api/v1/translations/{site_id}/{locale}
|
||||||
|
if len(parts) >= 5:
|
||||||
|
prefix = "/".join(parts[:4]) # e.g. /api/v1/config/sites
|
||||||
|
if prefix in ("/api/v1/config/sites", "/api/v1/translations"):
|
||||||
|
return parts[4]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_allowed_origins_for_site(
|
||||||
|
db: AsyncSession, site_id: str
|
||||||
|
) -> set[str]:
|
||||||
|
"""Return all registered domains (primary + additional) for a site."""
|
||||||
|
try:
|
||||||
|
site_uuid = uuid.UUID(site_id)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
return set()
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Site).where(
|
||||||
|
Site.id == site_uuid,
|
||||||
|
Site.is_active.is_(True),
|
||||||
|
Site.deleted_at.is_(None),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
site = result.scalar_one_or_none()
|
||||||
|
if not site:
|
||||||
|
return set()
|
||||||
|
|
||||||
|
origins: set[str] = {site.domain}
|
||||||
|
if site.additional_domains:
|
||||||
|
origins.update(site.additional_domains)
|
||||||
|
return origins
|
||||||
|
|
||||||
|
|
||||||
|
class DynamicCORSMedium(BaseHTTPMiddleware):
|
||||||
|
"""Add CORS headers dynamically based on the calling site's registered domains.
|
||||||
|
|
||||||
|
This middleware only acts on requests to public endpoints that may be
|
||||||
|
called from merchant websites (banner script calls). It reads the
|
||||||
|
``Origin`` header, resolves the site_id from the URL path, and checks
|
||||||
|
whether that origin matches a registered domain for the site.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def dispatch(
|
||||||
|
self, request: Request, call_next: RequestResponseEndpoint
|
||||||
|
) -> Response:
|
||||||
|
# Only handle public paths
|
||||||
|
if not _is_public_path(request.url.path):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
origin = request.headers.get("origin", "")
|
||||||
|
|
||||||
|
# If no Origin header, nothing to do
|
||||||
|
if not origin:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
site_id = _extract_site_id_from_path(request.url.path)
|
||||||
|
|
||||||
|
if site_id:
|
||||||
|
# Async db session required — get from request state if already set
|
||||||
|
# by the dependency injection, otherwise create a new one
|
||||||
|
db = request.state.db if hasattr(request.state, "db") else None
|
||||||
|
if db is None:
|
||||||
|
# Fall back to a new session for this lookup only
|
||||||
|
allowed: set[str] = set()
|
||||||
|
async for session in get_db():
|
||||||
|
allowed = await _get_allowed_origins_for_site(session, site_id)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
allowed = await _get_allowed_origins_for_site(db, site_id)
|
||||||
|
|
||||||
|
if origin in allowed:
|
||||||
|
response = await call_next(request)
|
||||||
|
response.headers["Access-Control-Allow-Origin"] = origin
|
||||||
|
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
|
response.headers["Access-Control-Allow-Methods"] = "*"
|
||||||
|
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Site not found or origin not registered — let the normal
|
||||||
|
# CORS middleware/global allowed_origins handle it
|
||||||
|
return await call_next(request)
|
||||||
Reference in New Issue
Block a user