fix: strip www. prefix when matching origin to registered domains
Some checks failed
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
Some checks failed
CI / Banner Lint & Typecheck (push) Has been cancelled
CI / Detect changes (push) Has been cancelled
CI / API Lint (push) Has been cancelled
CI / API Tests (push) Has been cancelled
CI / Scanner Lint (push) Has been cancelled
CI / Scanner Tests (push) Has been cancelled
CI / Banner Tests (push) Has been cancelled
CI / Banner Build (push) Has been cancelled
CI / Admin UI Typecheck (push) Has been cancelled
CI / Admin UI Tests (push) Has been cancelled
CI / Admin UI Build (push) Has been cancelled
This commit is contained in:
@@ -35,6 +35,11 @@ def _is_public_path(path: str) -> bool:
|
||||
return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES)
|
||||
|
||||
|
||||
def _normalize_domain(domain: str) -> str:
|
||||
"""Strip leading www. for comparison purposes."""
|
||||
return domain.removeprefix("www.")
|
||||
|
||||
|
||||
def _extract_site_id_from_path(path: str) -> str | None:
|
||||
"""Extract site_id UUID from public config path like /api/v1/config/sites/{id}."""
|
||||
parts = path.split("/")
|
||||
@@ -49,7 +54,11 @@ def _extract_site_id_from_path(path: str) -> str | None:
|
||||
async def _get_allowed_origins_for_site(
|
||||
db: AsyncSession, site_id: str
|
||||
) -> set[str]:
|
||||
"""Return all registered domains (primary + additional) for a site."""
|
||||
"""Return all registered domains (primary + additional) for a site.
|
||||
|
||||
Domains are stored without www. prefix; the origin is similarly
|
||||
stripped before comparison so both variants match.
|
||||
"""
|
||||
try:
|
||||
site_uuid = uuid.UUID(site_id)
|
||||
except (ValueError, TypeError):
|
||||
@@ -66,9 +75,9 @@ async def _get_allowed_origins_for_site(
|
||||
if not site:
|
||||
return set()
|
||||
|
||||
origins: set[str] = {site.domain}
|
||||
origins: set[str] = {_normalize_domain(site.domain)}
|
||||
if site.additional_domains:
|
||||
origins.update(site.additional_domains)
|
||||
origins.update(_normalize_domain(d) for d in site.additional_domains)
|
||||
return origins
|
||||
|
||||
|
||||
@@ -96,26 +105,29 @@ class DynamicCORSMedium(BaseHTTPMiddleware):
|
||||
|
||||
site_id = _extract_site_id_from_path(request.url.path)
|
||||
|
||||
if site_id:
|
||||
# Async db session required — get from request state if already set
|
||||
# by the dependency injection, otherwise create a new one
|
||||
db = request.state.db if hasattr(request.state, "db") else None
|
||||
if db is None:
|
||||
# Fall back to a new session for this lookup only
|
||||
allowed: set[str] = set()
|
||||
async for session in get_db():
|
||||
allowed = await _get_allowed_origins_for_site(session, site_id)
|
||||
break
|
||||
else:
|
||||
allowed = await _get_allowed_origins_for_site(db, site_id)
|
||||
if not site_id:
|
||||
return await call_next(request)
|
||||
|
||||
if origin in allowed:
|
||||
response = await call_next(request)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Allow-Methods"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
return response
|
||||
# Async db session required — get from request state if already set
|
||||
# by the dependency injection, otherwise create a new one
|
||||
db = request.state.db if hasattr(request.state, "db") else None
|
||||
if db is None:
|
||||
# Fall back to a new session for this lookup only
|
||||
allowed: set[str] = set()
|
||||
async for session in get_db():
|
||||
allowed = await _get_allowed_origins_for_site(session, site_id)
|
||||
break
|
||||
else:
|
||||
allowed = await _get_allowed_origins_for_site(db, site_id)
|
||||
|
||||
# Normalize origin (strip www.) before comparing
|
||||
if _normalize_domain(origin) in allowed:
|
||||
response = await call_next(request)
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Allow-Methods"] = "*"
|
||||
response.headers["Access-Control-Allow-Headers"] = "*"
|
||||
return response
|
||||
|
||||
# Site not found or origin not registered — let the normal
|
||||
# CORS middleware/global allowed_origins handle it
|
||||
|
||||
Reference in New Issue
Block a user