fix: strip www. prefix when matching origin to registered domains
Some checks failed
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled

This commit is contained in:
Kunthawat Greethong
2026-06-15 18:39:29 +07:00
parent 7973cb321e
commit 0bba7ef21a

View File

@@ -35,6 +35,11 @@ def _is_public_path(path: str) -> bool:
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES) return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
def _normalize_domain(domain: str) -> str:
"""Strip leading www. for comparison purposes."""
return domain.removeprefix("www.")
def _extract_site_id_from_path(path: str) -> str | None: def _extract_site_id_from_path(path: str) -> str | None:
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}.""" """Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
parts = path.split("/") parts = path.split("/")
@@ -49,7 +54,11 @@ def _extract_site_id_from_path(path: str) -> str | None:
async def _get_allowed_origins_for_site( async def _get_allowed_origins_for_site(
db: AsyncSession, site_id: str db: AsyncSession, site_id: str
) -> set[str]: ) -> set[str]:
"""Return all registered domains (primary + additional) for a site.""" """Return all registered domains (primary + additional) for a site.
Domains are stored without www. prefix; the origin is similarly
stripped before comparison so both variants match.
"""
try: try:
site_uuid = uuid.UUID(site_id) site_uuid = uuid.UUID(site_id)
except (ValueError, TypeError): except (ValueError, TypeError):
@@ -66,9 +75,9 @@ async def _get_allowed_origins_for_site(
if not site: if not site:
return set() return set()
origins: set[str] = {site.domain} origins: set[str] = {_normalize_domain(site.domain)}
if site.additional_domains: if site.additional_domains:
origins.update(site.additional_domains) origins.update(_normalize_domain(d) for d in site.additional_domains)
return origins return origins
@@ -96,7 +105,9 @@ class DynamicCORSMedium(BaseHTTPMiddleware):
site_id = _extract_site_id_from_path(request.url.path) site_id = _extract_site_id_from_path(request.url.path)
if site_id: if not site_id:
return await call_next(request)
# Async db session required — get from request state if already set # Async db session required — get from request state if already set
# by the dependency injection, otherwise create a new one # by the dependency injection, otherwise create a new one
db = request.state.db if hasattr(request.state, "db") else None db = request.state.db if hasattr(request.state, "db") else None
@@ -109,7 +120,8 @@ class DynamicCORSMedium(BaseHTTPMiddleware):
else: else:
allowed = await _get_allowed_origins_for_site(db, site_id) allowed = await _get_allowed_origins_for_site(db, site_id)
if origin in allowed: # Normalize origin (strip www.) before comparing
if _normalize_domain(origin) in allowed:
response = await call_next(request) response = await call_next(request)
response.headers["Access-Control-Allow-Origin"] = origin response.headers["Access-Control-Allow-Origin"] = origin
response.headers["Access-Control-Allow-Credentials"] = "true" response.headers["Access-Control-Allow-Credentials"] = "true"