fix: wildcard CORS for public banner API endpoints
Some checks failed
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
Some checks failed
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
Replace the fragile per-site dynamic CORS middleware with a public banner CORS middleware that allows non-credentialed wildcard CORS only for banner endpoints: - /api/v1/config/sites/* - /api/v1/translations/* - /api/v1/consent/ Admin/auth endpoints remain governed by the normal ALLOWED_ORIGINS based CORSMiddleware. Add regression tests for public GET/preflight behavior and for avoiding wildcard CORS on non-public endpoints.
This commit is contained in:
@@ -8,9 +8,9 @@ from src.config.edition import edition_name
|
||||
from src.config.logging import setup_logging
|
||||
from src.config.settings import get_settings
|
||||
from src.extensions.registry import discover_extensions, get_registry
|
||||
from src.middleware.public_banner_cors import PublicBannerCORSMiddleware
|
||||
from src.middleware.rate_limit import RateLimitMiddleware
|
||||
from src.middleware.security_headers import SecurityHeadersMiddleware
|
||||
from src.middleware.dynamic_cors import DynamicCORSMedium
|
||||
from src.routers import (
|
||||
auth,
|
||||
compliance,
|
||||
@@ -117,10 +117,8 @@ def create_app() -> FastAPI:
|
||||
auth_requests_per_minute=10,
|
||||
)
|
||||
|
||||
# CORS — DynamicCORSMedium must come BEFORE CORSMiddleware so it can
|
||||
# add per-site allowed origins for public banner endpoints
|
||||
app.add_middleware(DynamicCORSMedium)
|
||||
|
||||
# CORS for admin/auth endpoints. Public banner endpoints get wildcard,
|
||||
# non-credentialed CORS from PublicBannerCORSMiddleware below.
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.allowed_origins_list,
|
||||
@@ -129,6 +127,10 @@ def create_app() -> FastAPI:
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add this AFTER CORSMiddleware so it becomes the outermost middleware and
|
||||
# can override/remove credentialed CORS headers for public banner endpoints.
|
||||
app.add_middleware(PublicBannerCORSMiddleware)
|
||||
|
||||
# Core routers
|
||||
api_prefix = "/api/v1"
|
||||
app.include_router(auth.router, prefix=api_prefix)
|
||||
|
||||
@@ -1,134 +0,0 @@
|
||||
"""Dynamic CORS middleware for public API endpoints.
|
||||
|
||||
Public endpoints (banner config, translations) are called from merchant
|
||||
websites. Instead of hardcoding allowed origins, this middleware resolves
|
||||
the calling origin dynamically from the site's registered domains.
|
||||
|
||||
Only handles the CORS preflight (OPTIONS) and CORS headers for public
|
||||
routes. All other requests pass through unchanged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
from src.db import get_db
|
||||
from src.models.site import Site
|
||||
|
||||
|
||||
# Routes that are considered "public" (called by the banner script)
|
||||
PUBLIC_PATH_PREFIXES = (
|
||||
"/api/v1/config/sites/",
|
||||
"/api/v1/translations/",
|
||||
"/api/v1/consent/",
|
||||
)
|
||||
|
||||
|
||||
def _is_public_path(path: str) -> bool:
|
||||
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
|
||||
|
||||
|
||||
def _normalize_domain(domain: str) -> str:
|
||||
"""Strip leading www. for comparison purposes."""
|
||||
return domain.removeprefix("www.")
|
||||
|
||||
|
||||
def _extract_site_id_from_path(path: str) -> str | None:
|
||||
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
|
||||
parts = path.split("/")
|
||||
# /api/v1/config/sites/{site_id} or /api/v1/translations/{site_id}/{locale}
|
||||
if len(parts) >= 5:
|
||||
prefix = "/".join(parts[:4]) # e.g. /api/v1/config/sites
|
||||
if prefix in ("/api/v1/config/sites", "/api/v1/translations"):
|
||||
return parts[4]
|
||||
return None
|
||||
|
||||
|
||||
async def _get_allowed_origins_for_site(
|
||||
db: AsyncSession, site_id: str
|
||||
) -> set[str]:
|
||||
"""Return all registered domains (primary + additional) for a site.
|
||||
|
||||
Domains are stored without www. prefix; the origin is similarly
|
||||
stripped before comparison so both variants match.
|
||||
"""
|
||||
try:
|
||||
site_uuid = uuid.UUID(site_id)
|
||||
except (ValueError, TypeError):
|
||||
return set()
|
||||
|
||||
result = await db.execute(
|
||||
select(Site).where(
|
||||
Site.id == site_uuid,
|
||||
Site.is_active.is_(True),
|
||||
Site.deleted_at.is_(None),
|
||||
)
|
||||
)
|
||||
site = result.scalar_one_or_none()
|
||||
if not site:
|
||||
return set()
|
||||
|
||||
origins: set[str] = {_normalize_domain(site.domain)}
|
||||
if site.additional_domains:
|
||||
origins.update(_normalize_domain(d) for d in site.additional_domains)
|
||||
return origins
|
||||
|
||||
|
||||
class DynamicCORSMedium(BaseHTTPMiddleware):
|
||||
"""Add CORS headers dynamically based on the calling site's registered domains.
|
||||
|
||||
This middleware only acts on requests to public endpoints that may be
|
||||
called from merchant websites (banner script calls). It reads the
|
||||
``Origin`` header, resolves the site_id from the URL path, and checks
|
||||
whether that origin matches a registered domain for the site.
|
||||
"""
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: RequestResponseEndpoint
|
||||
) -> Response:
|
||||
# Only handle public paths
|
||||
if not _is_public_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
origin = request.headers.get("origin", "")
|
||||
|
||||
# If no Origin header, nothing to do
|
||||
if not origin:
|
||||
return await call_next(request)
|
||||
|
||||
site_id = _extract_site_id_from_path(request.url.path)
|
||||
|
||||
if not site_id:
|
||||
return await call_next(request)
|
||||
|
||||
# Async db session required — get from request state if already set
|
||||
# by the dependency injection, otherwise create a new one
|
||||
db = request.state.db if hasattr(request.state, "db") else None
|
||||
if db is None:
|
||||
# Fall back to a new session for this lookup only
|
||||
allowed: set[str] = set()
|
||||
async for session in get_db():
|
||||
allowed = await _get_allowed_origins_for_site(session, site_id)
|
||||
break
|
||||
else:
|
||||
allowed = await _get_allowed_origins_for_site(db, site_id)
|
||||
|
||||
# Normalize origin (strip www.) before comparing
|
||||
if _normalize_domain(origin) in allowed:
|
||||
response = await call_next(request)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Allow-Methods"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
return response
|
||||
|
||||
# Site not found or origin not registered — let the normal
|
||||
# CORS middleware/global allowed_origins handle it
|
||||
return await call_next(request)
|
||||
73
apps/api/src/middleware/public_banner_cors.py
Normal file
73
apps/api/src/middleware/public_banner_cors.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""CORS middleware for public banner API endpoints.
|
||||
|
||||
The banner script is embedded on merchant websites, so public banner
|
||||
endpoints must be readable from arbitrary origins. Admin/auth endpoints
|
||||
remain protected by the normal FastAPI CORSMiddleware configured with
|
||||
ALLOWED_ORIGINS.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import Response
|
||||
|
||||
PUBLIC_BANNER_PATH_PREFIXES = (
|
||||
"/api/v1/config/sites/",
|
||||
"/api/v1/translations/",
|
||||
"/api/v1/consent/",
|
||||
)
|
||||
|
||||
DEFAULT_ALLOWED_HEADERS = "Content-Type, Authorization, X-Requested-With"
|
||||
ALLOWED_METHODS = "GET, POST, OPTIONS"
|
||||
MAX_AGE_SECONDS = "86400"
|
||||
|
||||
|
||||
def is_public_banner_path(path: str) -> bool:
|
||||
"""Return True for public endpoints called by the banner script."""
|
||||
return any(path.startswith(prefix) for prefix in PUBLIC_BANNER_PATH_PREFIXES)
|
||||
|
||||
|
||||
def _public_cors_headers(request: Request) -> dict[str, str]:
|
||||
requested_headers = request.headers.get(
|
||||
"access-control-request-headers",
|
||||
DEFAULT_ALLOWED_HEADERS,
|
||||
)
|
||||
return {
|
||||
"Access-Control-Allow-Origin": "*",
|
||||
"Access-Control-Allow-Methods": ALLOWED_METHODS,
|
||||
"Access-Control-Allow-Headers": requested_headers,
|
||||
"Access-Control-Max-Age": MAX_AGE_SECONDS,
|
||||
}
|
||||
|
||||
|
||||
def _remove_credentials_header(response: Response) -> None:
|
||||
# Wildcard origins must not be paired with credentials. The normal global
|
||||
# CORSMiddleware may add this header because admin endpoints use credentialed
|
||||
# auth; strip it for public banner endpoints.
|
||||
if "access-control-allow-credentials" in response.headers:
|
||||
del response.headers["access-control-allow-credentials"]
|
||||
|
||||
|
||||
class PublicBannerCORSMiddleware(BaseHTTPMiddleware):
|
||||
"""Allow non-credentialed CORS for public banner endpoints only."""
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: RequestResponseEndpoint,
|
||||
) -> Response:
|
||||
if not is_public_banner_path(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# No Origin header means this is not a browser CORS request.
|
||||
if "origin" not in request.headers:
|
||||
return await call_next(request)
|
||||
|
||||
if request.method == "OPTIONS":
|
||||
return Response(status_code=204, headers=_public_cors_headers(request))
|
||||
|
||||
response = await call_next(request)
|
||||
response.headers.update(_public_cors_headers(request))
|
||||
_remove_credentials_header(response)
|
||||
return response
|
||||
84
apps/api/tests/test_public_banner_cors.py
Normal file
84
apps/api/tests/test_public_banner_cors.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""CORS behavior for public banner endpoints.
|
||||
|
||||
The banner API is embedded on merchant websites, so these public endpoints
|
||||
must be readable cross-origin without relying on the admin ALLOWED_ORIGINS
|
||||
setting.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.db import get_db
|
||||
from src.main import create_app
|
||||
|
||||
|
||||
def _db_returning_not_found() -> AsyncMock:
|
||||
session = AsyncMock()
|
||||
result = MagicMock()
|
||||
result.scalar_one_or_none.return_value = None
|
||||
session.execute = AsyncMock(return_value=result)
|
||||
return session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_config_get_allows_any_merchant_origin():
|
||||
app = create_app()
|
||||
db = _db_returning_not_found()
|
||||
|
||||
async def _override_get_db():
|
||||
yield db
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get(
|
||||
f"/api/v1/config/sites/{uuid.uuid4()}",
|
||||
headers={"Origin": "https://www.dealplustech.co.th"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 404
|
||||
assert resp.headers["access-control-allow-origin"] == "*"
|
||||
assert "access-control-allow-credentials" not in resp.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_public_consent_preflight_allows_json_post_from_any_merchant_origin():
|
||||
app = create_app()
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.options(
|
||||
"/api/v1/consent/",
|
||||
headers={
|
||||
"Origin": "https://www.dealplustech.co.th",
|
||||
"Access-Control-Request-Method": "POST",
|
||||
"Access-Control-Request-Headers": "content-type",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.status_code == 204
|
||||
assert resp.headers["access-control-allow-origin"] == "*"
|
||||
assert "POST" in resp.headers["access-control-allow-methods"]
|
||||
assert "content-type" in resp.headers["access-control-allow-headers"].lower()
|
||||
assert "access-control-allow-credentials" not in resp.headers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_public_endpoint_does_not_get_public_wildcard_cors():
|
||||
app = create_app()
|
||||
transport = ASGITransport(app=app)
|
||||
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get(
|
||||
"/health",
|
||||
headers={"Origin": "https://www.dealplustech.co.th"},
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.headers.get("access-control-allow-origin") != "*"
|
||||
Reference in New Issue
Block a user