From 27a3e777ae829e2158a15959c733f84d35f23b1f Mon Sep 17 00:00:00 2001 From: Kunthawat Greethong Date: Mon, 15 Jun 2026 21:12:59 +0700 Subject: [PATCH] fix: wildcard CORS for public banner API endpoints Replace the fragile per-site dynamic CORS middleware with a public banner CORS middleware that allows non-credentialed wildcard CORS only for banner endpoints: - /api/v1/config/sites/* - /api/v1/translations/* - /api/v1/consent/ Admin/auth endpoints remain governed by the normal ALLOWED_ORIGINS based CORSMiddleware. Add regression tests for public GET/preflight behavior and for avoiding wildcard CORS on non-public endpoints. --- apps/api/src/main.py | 12 +- apps/api/src/middleware/dynamic_cors.py | 134 ------------------ apps/api/src/middleware/public_banner_cors.py | 73 ++++++++++ apps/api/tests/test_public_banner_cors.py | 84 +++++++++++ 4 files changed, 164 insertions(+), 139 deletions(-) delete mode 100644 apps/api/src/middleware/dynamic_cors.py create mode 100644 apps/api/src/middleware/public_banner_cors.py create mode 100644 apps/api/tests/test_public_banner_cors.py diff --git a/apps/api/src/main.py b/apps/api/src/main.py index b1ebdc6..b32a90e 100644 --- a/apps/api/src/main.py +++ b/apps/api/src/main.py @@ -8,9 +8,9 @@ from src.config.edition import edition_name from src.config.logging import setup_logging from src.config.settings import get_settings from src.extensions.registry import discover_extensions, get_registry +from src.middleware.public_banner_cors import PublicBannerCORSMiddleware from src.middleware.rate_limit import RateLimitMiddleware from src.middleware.security_headers import SecurityHeadersMiddleware -from src.middleware.dynamic_cors import DynamicCORSMedium from src.routers import ( auth, compliance, @@ -117,10 +117,8 @@ def create_app() -> FastAPI: auth_requests_per_minute=10, ) - # CORS — DynamicCORSMedium must come BEFORE CORSMiddleware so it can - # add per-site allowed origins for public banner endpoints - app.add_middleware(DynamicCORSMedium) - + # CORS for admin/auth endpoints. Public banner endpoints get wildcard, + # non-credentialed CORS from PublicBannerCORSMiddleware below. app.add_middleware( CORSMiddleware, allow_origins=settings.allowed_origins_list, @@ -129,6 +127,10 @@ def create_app() -> FastAPI: allow_headers=["*"], ) + # Add this AFTER CORSMiddleware so it becomes the outermost middleware and + # can override/remove credentialed CORS headers for public banner endpoints. + app.add_middleware(PublicBannerCORSMiddleware) + # Core routers api_prefix = "/api/v1" app.include_router(auth.router, prefix=api_prefix) diff --git a/apps/api/src/middleware/dynamic_cors.py b/apps/api/src/middleware/dynamic_cors.py deleted file mode 100644 index 2a40cc1..0000000 --- a/apps/api/src/middleware/dynamic_cors.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Dynamic CORS middleware for public API endpoints. - -Public endpoints (banner config, translations) are called from merchant -websites. Instead of hardcoding allowed origins, this middleware resolves -the calling origin dynamically from the site's registered domains. - -Only handles the CORS preflight (OPTIONS) and CORS headers for public -routes. All other requests pass through unchanged. -""" - -from __future__ import annotations - -import uuid -from typing import Callable - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint -from starlette.requests import Request -from starlette.responses import Response - -from src.db import get_db -from src.models.site import Site - - -# Routes that are considered "public" (called by the banner script) -PUBLIC_PATH_PREFIXES = ( - "/api/v1/config/sites/", - "/api/v1/translations/", - "/api/v1/consent/", -) - - -def _is_public_path(path: str) -> bool: - return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES) - - -def _normalize_domain(domain: str) -> str: - """Strip leading www. for comparison purposes.""" - return domain.removeprefix("www.") - - -def _extract_site_id_from_path(path: str) -> str | None: - """Extract site_id UUID from public config path like /api/v1/config/sites/{id}.""" - parts = path.split("/") - # /api/v1/config/sites/{site_id} or /api/v1/translations/{site_id}/{locale} - if len(parts) >= 5: - prefix = "/".join(parts[:4]) # e.g. /api/v1/config/sites - if prefix in ("/api/v1/config/sites", "/api/v1/translations"): - return parts[4] - return None - - -async def _get_allowed_origins_for_site( - db: AsyncSession, site_id: str -) -> set[str]: - """Return all registered domains (primary + additional) for a site. - - Domains are stored without www. prefix; the origin is similarly - stripped before comparison so both variants match. - """ - try: - site_uuid = uuid.UUID(site_id) - except (ValueError, TypeError): - return set() - - result = await db.execute( - select(Site).where( - Site.id == site_uuid, - Site.is_active.is_(True), - Site.deleted_at.is_(None), - ) - ) - site = result.scalar_one_or_none() - if not site: - return set() - - origins: set[str] = {_normalize_domain(site.domain)} - if site.additional_domains: - origins.update(_normalize_domain(d) for d in site.additional_domains) - return origins - - -class DynamicCORSMedium(BaseHTTPMiddleware): - """Add CORS headers dynamically based on the calling site's registered domains. - - This middleware only acts on requests to public endpoints that may be - called from merchant websites (banner script calls). It reads the - ``Origin`` header, resolves the site_id from the URL path, and checks - whether that origin matches a registered domain for the site. - """ - - async def dispatch( - self, request: Request, call_next: RequestResponseEndpoint - ) -> Response: - # Only handle public paths - if not _is_public_path(request.url.path): - return await call_next(request) - - origin = request.headers.get("origin", "") - - # If no Origin header, nothing to do - if not origin: - return await call_next(request) - - site_id = _extract_site_id_from_path(request.url.path) - - if not site_id: - return await call_next(request) - - # Async db session required — get from request state if already set - # by the dependency injection, otherwise create a new one - db = request.state.db if hasattr(request.state, "db") else None - if db is None: - # Fall back to a new session for this lookup only - allowed: set[str] = set() - async for session in get_db(): - allowed = await _get_allowed_origins_for_site(session, site_id) - break - else: - allowed = await _get_allowed_origins_for_site(db, site_id) - - # Normalize origin (strip www.) before comparing - if _normalize_domain(origin) in allowed: - response = await call_next(request) - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Credentials"] = "true" - response.headers["Access-Control-Allow-Methods"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" - return response - - # Site not found or origin not registered — let the normal - # CORS middleware/global allowed_origins handle it - return await call_next(request) diff --git a/apps/api/src/middleware/public_banner_cors.py b/apps/api/src/middleware/public_banner_cors.py new file mode 100644 index 0000000..4a390f5 --- /dev/null +++ b/apps/api/src/middleware/public_banner_cors.py @@ -0,0 +1,73 @@ +"""CORS middleware for public banner API endpoints. + +The banner script is embedded on merchant websites, so public banner +endpoints must be readable from arbitrary origins. Admin/auth endpoints +remain protected by the normal FastAPI CORSMiddleware configured with +ALLOWED_ORIGINS. +""" + +from __future__ import annotations + +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from starlette.requests import Request +from starlette.responses import Response + +PUBLIC_BANNER_PATH_PREFIXES = ( + "/api/v1/config/sites/", + "/api/v1/translations/", + "/api/v1/consent/", +) + +DEFAULT_ALLOWED_HEADERS = "Content-Type, Authorization, X-Requested-With" +ALLOWED_METHODS = "GET, POST, OPTIONS" +MAX_AGE_SECONDS = "86400" + + +def is_public_banner_path(path: str) -> bool: + """Return True for public endpoints called by the banner script.""" + return any(path.startswith(prefix) for prefix in PUBLIC_BANNER_PATH_PREFIXES) + + +def _public_cors_headers(request: Request) -> dict[str, str]: + requested_headers = request.headers.get( + "access-control-request-headers", + DEFAULT_ALLOWED_HEADERS, + ) + return { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": ALLOWED_METHODS, + "Access-Control-Allow-Headers": requested_headers, + "Access-Control-Max-Age": MAX_AGE_SECONDS, + } + + +def _remove_credentials_header(response: Response) -> None: + # Wildcard origins must not be paired with credentials. The normal global + # CORSMiddleware may add this header because admin endpoints use credentialed + # auth; strip it for public banner endpoints. + if "access-control-allow-credentials" in response.headers: + del response.headers["access-control-allow-credentials"] + + +class PublicBannerCORSMiddleware(BaseHTTPMiddleware): + """Allow non-credentialed CORS for public banner endpoints only.""" + + async def dispatch( + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + if not is_public_banner_path(request.url.path): + return await call_next(request) + + # No Origin header means this is not a browser CORS request. + if "origin" not in request.headers: + return await call_next(request) + + if request.method == "OPTIONS": + return Response(status_code=204, headers=_public_cors_headers(request)) + + response = await call_next(request) + response.headers.update(_public_cors_headers(request)) + _remove_credentials_header(response) + return response diff --git a/apps/api/tests/test_public_banner_cors.py b/apps/api/tests/test_public_banner_cors.py new file mode 100644 index 0000000..5a3440b --- /dev/null +++ b/apps/api/tests/test_public_banner_cors.py @@ -0,0 +1,84 @@ +"""CORS behavior for public banner endpoints. + +The banner API is embedded on merchant websites, so these public endpoints +must be readable cross-origin without relying on the admin ALLOWED_ORIGINS +setting. +""" + +from __future__ import annotations + +import uuid +from unittest.mock import AsyncMock, MagicMock + +import pytest +from httpx import ASGITransport, AsyncClient + +from src.db import get_db +from src.main import create_app + + +def _db_returning_not_found() -> AsyncMock: + session = AsyncMock() + result = MagicMock() + result.scalar_one_or_none.return_value = None + session.execute = AsyncMock(return_value=result) + return session + + +@pytest.mark.asyncio +async def test_public_config_get_allows_any_merchant_origin(): + app = create_app() + db = _db_returning_not_found() + + async def _override_get_db(): + yield db + + app.dependency_overrides[get_db] = _override_get_db + + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + f"/api/v1/config/sites/{uuid.uuid4()}", + headers={"Origin": "https://www.dealplustech.co.th"}, + ) + + assert resp.status_code == 404 + assert resp.headers["access-control-allow-origin"] == "*" + assert "access-control-allow-credentials" not in resp.headers + + +@pytest.mark.asyncio +async def test_public_consent_preflight_allows_json_post_from_any_merchant_origin(): + app = create_app() + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.options( + "/api/v1/consent/", + headers={ + "Origin": "https://www.dealplustech.co.th", + "Access-Control-Request-Method": "POST", + "Access-Control-Request-Headers": "content-type", + }, + ) + + assert resp.status_code == 204 + assert resp.headers["access-control-allow-origin"] == "*" + assert "POST" in resp.headers["access-control-allow-methods"] + assert "content-type" in resp.headers["access-control-allow-headers"].lower() + assert "access-control-allow-credentials" not in resp.headers + + +@pytest.mark.asyncio +async def test_non_public_endpoint_does_not_get_public_wildcard_cors(): + app = create_app() + transport = ASGITransport(app=app) + + async with AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.get( + "/health", + headers={"Origin": "https://www.dealplustech.co.th"}, + ) + + assert resp.status_code == 200 + assert resp.headers.get("access-control-allow-origin") != "*"