fix: strip www. prefix when matching origin to registered domains
Some checks failed
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
Some checks failed
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
This commit is contained in:
@@ -35,6 +35,11 @@ def _is_public_path(path: str) -> bool:
|
|||||||
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
|
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_domain(domain: str) -> str:
|
||||||
|
"""Strip leading www. for comparison purposes."""
|
||||||
|
return domain.removeprefix("www.")
|
||||||
|
|
||||||
|
|
||||||
def _extract_site_id_from_path(path: str) -> str | None:
|
def _extract_site_id_from_path(path: str) -> str | None:
|
||||||
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
|
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
|
||||||
parts = path.split("/")
|
parts = path.split("/")
|
||||||
@@ -49,7 +54,11 @@ def _extract_site_id_from_path(path: str) -> str | None:
|
|||||||
async def _get_allowed_origins_for_site(
|
async def _get_allowed_origins_for_site(
|
||||||
db: AsyncSession, site_id: str
|
db: AsyncSession, site_id: str
|
||||||
) -> set[str]:
|
) -> set[str]:
|
||||||
"""Return all registered domains (primary + additional) for a site."""
|
"""Return all registered domains (primary + additional) for a site.
|
||||||
|
|
||||||
|
Domains are stored without www. prefix; the origin is similarly
|
||||||
|
stripped before comparison so both variants match.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
site_uuid = uuid.UUID(site_id)
|
site_uuid = uuid.UUID(site_id)
|
||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
@@ -66,9 +75,9 @@ async def _get_allowed_origins_for_site(
|
|||||||
if not site:
|
if not site:
|
||||||
return set()
|
return set()
|
||||||
|
|
||||||
origins: set[str] = {site.domain}
|
origins: set[str] = {_normalize_domain(site.domain)}
|
||||||
if site.additional_domains:
|
if site.additional_domains:
|
||||||
origins.update(site.additional_domains)
|
origins.update(_normalize_domain(d) for d in site.additional_domains)
|
||||||
return origins
|
return origins
|
||||||
|
|
||||||
|
|
||||||
@@ -96,7 +105,9 @@ class DynamicCORSMedium(BaseHTTPMiddleware):
|
|||||||
|
|
||||||
site_id = _extract_site_id_from_path(request.url.path)
|
site_id = _extract_site_id_from_path(request.url.path)
|
||||||
|
|
||||||
if site_id:
|
if not site_id:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
# Async db session required — get from request state if already set
|
# Async db session required — get from request state if already set
|
||||||
# by the dependency injection, otherwise create a new one
|
# by the dependency injection, otherwise create a new one
|
||||||
db = request.state.db if hasattr(request.state, "db") else None
|
db = request.state.db if hasattr(request.state, "db") else None
|
||||||
@@ -109,7 +120,8 @@ class DynamicCORSMedium(BaseHTTPMiddleware):
|
|||||||
else:
|
else:
|
||||||
allowed = await _get_allowed_origins_for_site(db, site_id)
|
allowed = await _get_allowed_origins_for_site(db, site_id)
|
||||||
|
|
||||||
if origin in allowed:
|
# Normalize origin (strip www.) before comparing
|
||||||
|
if _normalize_domain(origin) in allowed:
|
||||||
response = await call_next(request)
|
response = await call_next(request)
|
||||||
response.headers["Access-Control-Allow-Origin"] = origin
|
response.headers["Access-Control-Allow-Origin"] = origin
|
||||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||||
|
|||||||
Reference in New Issue
Block a user