Files
ALwrity/backend/tests/test_no_import_time_mkdir.py

73 lines
2.4 KiB
Python

"""Guardrail: startup-loaded modules must not write filesystem at import time."""
from __future__ import annotations
import ast
from pathlib import Path
STARTUP_MODULES = [
"backend/app.py",
"backend/alwrity_utils/router_manager.py",
"backend/routers/seo_tools.py",
"backend/middleware/logging_middleware.py",
"backend/api/youtube/router.py",
"backend/api/youtube/handlers/avatar.py",
"backend/api/youtube/handlers/images.py",
"backend/api/youtube/handlers/audio.py",
"backend/utils/media_utils.py",
]
def _is_forbidden_call(node: ast.Call) -> bool:
func = node.func
if isinstance(func, ast.Attribute):
if isinstance(func.value, ast.Name) and func.value.id == "os" and func.attr == "makedirs":
return True
if func.attr == "mkdir":
return True
return False
def _top_level_forbidden_calls(tree: ast.AST) -> list[tuple[int, str]]:
"""Return forbidden calls that execute during module import.
We intentionally do not flag calls inside function/class bodies, because those
are runtime operations and safe for this policy.
"""
violations: list[tuple[int, str]] = []
for statement in getattr(tree, "body", []):
# Skip function/class definitions: their bodies are not executed at import time.
if isinstance(statement, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
continue
nodes_to_visit = [statement]
while nodes_to_visit:
node = nodes_to_visit.pop()
# Do not descend into nested function/class definitions.
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
continue
if isinstance(node, ast.Call) and _is_forbidden_call(node):
violations.append((node.lineno, ast.unparse(node.func)))
nodes_to_visit.extend(ast.iter_child_nodes(node))
return violations
def test_no_import_time_mkdir_calls_in_startup_modules() -> None:
all_violations: list[str] = []
for rel in STARTUP_MODULES:
module_file = Path(rel)
source = module_file.read_text(encoding="utf-8")
tree = ast.parse(source)
violations = _top_level_forbidden_calls(tree)
for line, call in violations:
all_violations.append(f"{rel}:{line} -> {call}")
assert not all_violations, "Top-level filesystem writes found:\n" + "\n".join(all_violations)