diff --git a/apps/api/src/middleware/dynamic_cors.py b/apps/api/src/middleware/dynamic_cors.py index dbc7d5a..2a40cc1 100644 --- a/apps/api/src/middleware/dynamic_cors.py +++ b/apps/api/src/middleware/dynamic_cors.py @@ -35,6 +35,11 @@ def _is_public_path(path: str) -> bool: return any(path.startswith(prefix) for prefix in PUBLIC_PATH_PREFIXES) +def _normalize_domain(domain: str) -> str: + """Strip leading www. for comparison purposes.""" + return domain.removeprefix("www.") + + def _extract_site_id_from_path(path: str) -> str | None: """Extract site_id UUID from public config path like /api/v1/config/sites/{id}.""" parts = path.split("/") @@ -49,7 +54,11 @@ def _extract_site_id_from_path(path: str) -> str | None: async def _get_allowed_origins_for_site( db: AsyncSession, site_id: str ) -> set[str]: - """Return all registered domains (primary + additional) for a site.""" + """Return all registered domains (primary + additional) for a site. + + Domains are stored without www. prefix; the origin is similarly + stripped before comparison so both variants match. + """ try: site_uuid = uuid.UUID(site_id) except (ValueError, TypeError): @@ -66,9 +75,9 @@ async def _get_allowed_origins_for_site( if not site: return set() - origins: set[str] = {site.domain} + origins: set[str] = {_normalize_domain(site.domain)} if site.additional_domains: - origins.update(site.additional_domains) + origins.update(_normalize_domain(d) for d in site.additional_domains) return origins @@ -96,26 +105,29 @@ class DynamicCORSMedium(BaseHTTPMiddleware): site_id = _extract_site_id_from_path(request.url.path) - if site_id: - # Async db session required — get from request state if already set - # by the dependency injection, otherwise create a new one - db = request.state.db if hasattr(request.state, "db") else None - if db is None: - # Fall back to a new session for this lookup only - allowed: set[str] = set() - async for session in get_db(): - allowed = await _get_allowed_origins_for_site(session, site_id) - break - else: - allowed = await _get_allowed_origins_for_site(db, site_id) + if not site_id: + return await call_next(request) - if origin in allowed: - response = await call_next(request) - response.headers["Access-Control-Allow-Origin"] = origin - response.headers["Access-Control-Allow-Credentials"] = "true" - response.headers["Access-Control-Allow-Methods"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" - return response + # Async db session required — get from request state if already set + # by the dependency injection, otherwise create a new one + db = request.state.db if hasattr(request.state, "db") else None + if db is None: + # Fall back to a new session for this lookup only + allowed: set[str] = set() + async for session in get_db(): + allowed = await _get_allowed_origins_for_site(session, site_id) + break + else: + allowed = await _get_allowed_origins_for_site(db, site_id) + + # Normalize origin (strip www.) before comparing + if _normalize_domain(origin) in allowed: + response = await call_next(request) + response.headers["Access-Control-Allow-Origin"] = origin + response.headers["Access-Control-Allow-Credentials"] = "true" + response.headers["Access-Control-Allow-Methods"] = "*" + response.headers["Access-Control-Allow-Headers"] = "*" + return response # Site not found or origin not registered — let the normal # CORS middleware/global allowed_origins handle it