"""Tests for the extension registry and edition detection.""" import pytest from fastapi import APIRouter, FastAPI from src.config.edition import edition_name, is_ee from src.extensions.registry import ( ExtensionRegistry, OpenAPITag, discover_extensions, get_registry, ) # -- Edition detection ------------------------------------------------------- class TestEditionDetection: """The ``is_ee()`` / ``edition_name()`` helpers should return a consistent pair regardless of which edition is installed. Core tests don't assume a specific edition — that's checked in each repo's own integration tests.""" def test_edition_name_matches_is_ee(self): assert edition_name() == ("ee" if is_ee() else "ce") def test_edition_name_is_valid(self): assert edition_name() in ("ce", "ee") # -- Extension registry (unit) ---------------------------------------------- class TestExtensionRegistry: def _make_registry(self) -> ExtensionRegistry: return ExtensionRegistry() def test_empty_registry(self): reg = self._make_registry() assert reg.routers == [] assert reg.model_modules == [] assert reg.startup_hooks == [] def test_add_router(self): reg = self._make_registry() router = APIRouter() reg.add_router(router, prefix="/api/v1") assert len(reg.routers) == 1 assert reg.routers[0].router is router assert reg.routers[0].prefix == "/api/v1" def test_add_router_with_tags(self): reg = self._make_registry() router = APIRouter() tag = OpenAPITag(name="billing", description="Billing endpoints") reg.add_router(router, tags=[tag]) assert reg.routers[0].tags == [tag] def test_add_model_module(self): reg = self._make_registry() reg.add_model_module("ee.api.src.models.billing") assert reg.model_modules == ["ee.api.src.models.billing"] def test_add_startup_hook(self): reg = self._make_registry() async def hook(app: FastAPI) -> None: pass reg.add_startup_hook(hook) assert len(reg.startup_hooks) == 1 def test_apply_mounts_routers(self): reg = self._make_registry() router = APIRouter() @router.get("/test") async def _test() -> dict[str, str]: return {"ok": True} reg.add_router(router, prefix="/ext") app = FastAPI() reg.apply(app) # The router should be included in the app routes paths = [r.path for r in app.routes] assert "/ext/test" in paths def test_apply_adds_openapi_tags(self): reg = self._make_registry() router = APIRouter() tag = OpenAPITag(name="billing", description="Billing endpoints") reg.add_router(router, tags=[tag]) app = FastAPI() app.openapi_tags = [] reg.apply(app) assert any(t["name"] == "billing" for t in app.openapi_tags) def test_apply_skips_duplicate_tags(self): reg = self._make_registry() router = APIRouter() tag = OpenAPITag(name="billing", description="Billing endpoints") reg.add_router(router, tags=[tag]) app = FastAPI() app.openapi_tags = [{"name": "billing", "description": "Existing"}] reg.apply(app) billing_tags = [t for t in app.openapi_tags if t["name"] == "billing"] assert len(billing_tags) == 1 assert billing_tags[0]["description"] == "Existing" # -- discover_extensions ----------------------------------------------------- class TestDiscoverExtensions: def test_discover_extensions_does_not_raise(self): """discover_extensions should not raise regardless of edition.""" discover_extensions() # -- Global registry --------------------------------------------------------- class TestGlobalRegistry: def test_get_registry_returns_singleton(self): assert get_registry() is get_registry() # -- Health endpoint with edition field -------------------------------------- @pytest.mark.asyncio async def test_health_reports_edition(client): response = await client.get("/health") assert response.status_code == 200 data = response.json() assert data["edition"] in ("ce", "ee")