fix: wildcard CORS for public banner API endpoints
Some checks failed
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled

Replace the fragile per-site dynamic CORS middleware with a public banner
CORS middleware that allows non-credentialed wildcard CORS only for banner
endpoints:

- /api/v1/config/sites/*
- /api/v1/translations/*
- /api/v1/consent/

Admin/auth endpoints remain governed by the normal ALLOWED_ORIGINS based
CORSMiddleware. Add regression tests for public GET/preflight behavior and
for avoiding wildcard CORS on non-public endpoints.
This commit is contained in:
Kunthawat Greethong
2026-06-15 21:12:59 +07:00
parent 0bba7ef21a
commit 27a3e777ae
4 changed files with 164 additions and 139 deletions

View File

@@ -8,9 +8,9 @@ from src.config.edition import edition_name
from src.config.logging import setup_logging from src.config.logging import setup_logging
from src.config.settings import get_settings from src.config.settings import get_settings
from src.extensions.registry import discover_extensions, get_registry from src.extensions.registry import discover_extensions, get_registry
from src.middleware.public_banner_cors import PublicBannerCORSMiddleware
from src.middleware.rate_limit import RateLimitMiddleware from src.middleware.rate_limit import RateLimitMiddleware
from src.middleware.security_headers import SecurityHeadersMiddleware from src.middleware.security_headers import SecurityHeadersMiddleware
from src.middleware.dynamic_cors import DynamicCORSMedium
from src.routers import ( from src.routers import (
auth, auth,
compliance, compliance,
@@ -117,10 +117,8 @@ def create_app() -> FastAPI:
auth_requests_per_minute=10, auth_requests_per_minute=10,
) )
# CORS — DynamicCORSMedium must come BEFORE CORSMiddleware so it can # CORS for admin/auth endpoints. Public banner endpoints get wildcard,
# add per-site allowed origins for public banner endpoints # non-credentialed CORS from PublicBannerCORSMiddleware below.
app.add_middleware(DynamicCORSMedium)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.allowed_origins_list, allow_origins=settings.allowed_origins_list,
@@ -129,6 +127,10 @@ def create_app() -> FastAPI:
allow_headers=["*"], allow_headers=["*"],
) )
# Add this AFTER CORSMiddleware so it becomes the outermost middleware and
# can override/remove credentialed CORS headers for public banner endpoints.
app.add_middleware(PublicBannerCORSMiddleware)
# Core routers # Core routers
api_prefix = "/api/v1" api_prefix = "/api/v1"
app.include_router(auth.router, prefix=api_prefix) app.include_router(auth.router, prefix=api_prefix)

View File

@@ -1,134 +0,0 @@
"""Dynamic CORS middleware for public API endpoints.
Public endpoints (banner config, translations) are called from merchant
websites. Instead of hardcoding allowed origins, this middleware resolves
the calling origin dynamically from the site's registered domains.
Only handles the CORS preflight (OPTIONS) and CORS headers for public
routes. All other requests pass through unchanged.
"""
from __future__ import annotations
import uuid
from typing import Callable
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from src.db import get_db
from src.models.site import Site
# Routes that are considered "public" (called by the banner script)
PUBLIC_PATH_PREFIXES = (
"/api/v1/config/sites/",
"/api/v1/translations/",
"/api/v1/consent/",
)
def _is_public_path(path: str) -> bool:
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
def _normalize_domain(domain: str) -> str:
"""Strip leading www. for comparison purposes."""
return domain.removeprefix("www.")
def _extract_site_id_from_path(path: str) -> str | None:
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
parts = path.split("/")
# /api/v1/config/sites/{site_id} or /api/v1/translations/{site_id}/{locale}
if len(parts) >= 5:
prefix = "/".join(parts[:4]) # e.g. /api/v1/config/sites
if prefix in ("/api/v1/config/sites", "/api/v1/translations"):
return parts[4]
return None
async def _get_allowed_origins_for_site(
db: AsyncSession, site_id: str
) -> set[str]:
"""Return all registered domains (primary + additional) for a site.
Domains are stored without www. prefix; the origin is similarly
stripped before comparison so both variants match.
"""
try:
site_uuid = uuid.UUID(site_id)
except (ValueError, TypeError):
return set()
result = await db.execute(
select(Site).where(
Site.id == site_uuid,
Site.is_active.is_(True),
Site.deleted_at.is_(None),
)
)
site = result.scalar_one_or_none()
if not site:
return set()
origins: set[str] = {_normalize_domain(site.domain)}
if site.additional_domains:
origins.update(_normalize_domain(d) for d in site.additional_domains)
return origins
class DynamicCORSMedium(BaseHTTPMiddleware):
"""Add CORS headers dynamically based on the calling site's registered domains.
This middleware only acts on requests to public endpoints that may be
called from merchant websites (banner script calls). It reads the
``Origin`` header, resolves the site_id from the URL path, and checks
whether that origin matches a registered domain for the site.
"""
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
# Only handle public paths
if not _is_public_path(request.url.path):
return await call_next(request)
origin = request.headers.get("origin", "")
# If no Origin header, nothing to do
if not origin:
return await call_next(request)
site_id = _extract_site_id_from_path(request.url.path)
if not site_id:
return await call_next(request)
# Async db session required — get from request state if already set
# by the dependency injection, otherwise create a new one
db = request.state.db if hasattr(request.state, "db") else None
if db is None:
# Fall back to a new session for this lookup only
allowed: set[str] = set()
async for session in get_db():
allowed = await _get_allowed_origins_for_site(session, site_id)
break
else:
allowed = await _get_allowed_origins_for_site(db, site_id)
# Normalize origin (strip www.) before comparing
if _normalize_domain(origin) in allowed:
response = await call_next(request)
response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Credentials"] = "true"
response.headers["Access-Control-Allow-Methods"] = "*"
response.headers["Access-Control-Allow-Headers"] = "*"
return response
# Site not found or origin not registered — let the normal
# CORS middleware/global allowed_origins handle it
return await call_next(request)

View File

@@ -0,0 +1,73 @@
"""CORS middleware for public banner API endpoints.
The banner script is embedded on merchant websites, so public banner
endpoints must be readable from arbitrary origins. Admin/auth endpoints
remain protected by the normal FastAPI CORSMiddleware configured with
ALLOWED_ORIGINS.
"""
from __future__ import annotations
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
PUBLIC_BANNER_PATH_PREFIXES = (
"/api/v1/config/sites/",
"/api/v1/translations/",
"/api/v1/consent/",
)
DEFAULT_ALLOWED_HEADERS = "Content-Type, Authorization, X-Requested-With"
ALLOWED_METHODS = "GET, POST, OPTIONS"
MAX_AGE_SECONDS = "86400"
def is_public_banner_path(path: str) -> bool:
"""Return True for public endpoints called by the banner script."""
return any(path.startswith(prefix) for prefix in PUBLIC_BANNER_PATH_PREFIXES)
def _public_cors_headers(request: Request) -> dict[str, str]:
requested_headers = request.headers.get(
"access-control-request-headers",
DEFAULT_ALLOWED_HEADERS,
)
return {
"Access-Control-Allow-Origin": "*",
"Access-Control-Allow-Methods": ALLOWED_METHODS,
"Access-Control-Allow-Headers": requested_headers,
"Access-Control-Max-Age": MAX_AGE_SECONDS,
}
def _remove_credentials_header(response: Response) -> None:
# Wildcard origins must not be paired with credentials. The normal global
# CORSMiddleware may add this header because admin endpoints use credentialed
# auth; strip it for public banner endpoints.
if "access-control-allow-credentials" in response.headers:
del response.headers["access-control-allow-credentials"]
class PublicBannerCORSMiddleware(BaseHTTPMiddleware):
"""Allow non-credentialed CORS for public banner endpoints only."""
async def dispatch(
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
if not is_public_banner_path(request.url.path):
return await call_next(request)
# No Origin header means this is not a browser CORS request.
if "origin" not in request.headers:
return await call_next(request)
if request.method == "OPTIONS":
return Response(status_code=204, headers=_public_cors_headers(request))
response = await call_next(request)
response.headers.update(_public_cors_headers(request))
_remove_credentials_header(response)
return response

View File

@@ -0,0 +1,84 @@
"""CORS behavior for public banner endpoints.
The banner API is embedded on merchant websites, so these public endpoints
must be readable cross-origin without relying on the admin ALLOWED_ORIGINS
setting.
"""
from __future__ import annotations
import uuid
from unittest.mock import AsyncMock, MagicMock
import pytest
from httpx import ASGITransport, AsyncClient
from src.db import get_db
from src.main import create_app
def _db_returning_not_found() -> AsyncMock:
session = AsyncMock()
result = MagicMock()
result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=result)
return session
@pytest.mark.asyncio
async def test_public_config_get_allows_any_merchant_origin():
app = create_app()
db = _db_returning_not_found()
async def _override_get_db():
yield db
app.dependency_overrides[get_db] = _override_get_db
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get(
f"/api/v1/config/sites/{uuid.uuid4()}",
headers={"Origin": "https://www.dealplustech.co.th"},
)
assert resp.status_code == 404
assert resp.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in resp.headers
@pytest.mark.asyncio
async def test_public_consent_preflight_allows_json_post_from_any_merchant_origin():
app = create_app()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.options(
"/api/v1/consent/",
headers={
"Origin": "https://www.dealplustech.co.th",
"Access-Control-Request-Method": "POST",
"Access-Control-Request-Headers": "content-type",
},
)
assert resp.status_code == 204
assert resp.headers["access-control-allow-origin"] == "*"
assert "POST" in resp.headers["access-control-allow-methods"]
assert "content-type" in resp.headers["access-control-allow-headers"].lower()
assert "access-control-allow-credentials" not in resp.headers
@pytest.mark.asyncio
async def test_non_public_endpoint_does_not_get_public_wildcard_cors():
app = create_app()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
resp = await client.get(
"/health",
headers={"Origin": "https://www.dealplustech.co.th"},
)
assert resp.status_code == 200
assert resp.headers.get("access-control-allow-origin") != "*"