diff --git a/backend/api/subscription/routes/alerts.py b/backend/api/subscription/routes/alerts.py index e214fc42..6eec18c6 100644 --- a/backend/api/subscription/routes/alerts.py +++ b/backend/api/subscription/routes/alerts.py @@ -10,6 +10,8 @@ from loguru import logger from services.database import get_db from models.subscription_models import UsageAlert +from middleware.auth_middleware import get_current_user +from ..dependencies import verify_user_access router = APIRouter() @@ -19,9 +21,12 @@ async def get_usage_alerts( user_id: str, unread_only: bool = Query(False, description="Only return unread alerts"), limit: int = Query(50, ge=1, le=100, description="Maximum number of alerts"), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + current_user: Dict[str, Any] = Depends(get_current_user) ) -> Dict[str, Any]: """Get usage alerts for a user.""" + + verify_user_access(user_id, current_user) try: query = db.query(UsageAlert).filter( @@ -79,16 +84,20 @@ async def get_usage_alerts( @router.post("/alerts/{alert_id}/mark-read") async def mark_alert_read( alert_id: int, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + current_user: Dict[str, Any] = Depends(get_current_user) ) -> Dict[str, Any]: """Mark an alert as read.""" try: alert = db.query(UsageAlert).filter(UsageAlert.id == alert_id).first() - + if not alert: raise HTTPException(status_code=404, detail="Alert not found") - + + if str(alert.user_id) != str(current_user.get("id")): + raise HTTPException(status_code=403, detail="Access denied") + alert.is_read = True alert.read_at = datetime.utcnow() db.commit() @@ -98,6 +107,8 @@ async def mark_alert_read( "message": "Alert marked as read" } + except HTTPException: + raise except Exception as e: logger.error(f"Error marking alert as read: {e}") raise HTTPException(status_code=500, detail=str(e)) diff --git a/backend/api/subscription/routes/dashboard.py b/backend/api/subscription/routes/dashboard.py index f9cae8b3..d5e33448 100644 --- a/backend/api/subscription/routes/dashboard.py +++ b/backend/api/subscription/routes/dashboard.py @@ -13,6 +13,8 @@ from services.database import get_db from services.subscription import UsageTrackingService, PricingService from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns from models.subscription_models import UsageAlert +from middleware.auth_middleware import get_current_user +from ..dependencies import verify_user_access from ..cache import get_cached_dashboard, set_cached_dashboard router = APIRouter() @@ -22,9 +24,12 @@ router = APIRouter() async def get_dashboard_data( user_id: str, billing_period: str = None, - db: Session = Depends(get_db) + db: Session = Depends(get_db), + current_user: Dict[str, Any] = Depends(get_current_user) ) -> Dict[str, Any]: """Get comprehensive dashboard data for usage monitoring.""" + + verify_user_access(user_id, current_user) try: ensure_subscription_plan_columns(db) diff --git a/backend/api/subscription/routes/route_access_audit.py b/backend/api/subscription/routes/route_access_audit.py new file mode 100644 index 00000000..8754d80a --- /dev/null +++ b/backend/api/subscription/routes/route_access_audit.py @@ -0,0 +1,82 @@ +"""Quick route-level audit to enforce user-scoped access checks.""" + +from __future__ import annotations + +import ast +from pathlib import Path +from typing import List, Tuple + +ROUTES_DIR = Path(__file__).resolve().parent + + +def _decorator_path(decorator: ast.AST) -> str | None: + """Extract route path from decorators like @router.get("/usage/{user_id}").""" + if not isinstance(decorator, ast.Call): + return None + if not isinstance(decorator.func, ast.Attribute): + return None + if not decorator.args: + return None + + first_arg = decorator.args[0] + if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str): + return first_arg.value + return None + + +def _has_current_user_dependency(fn_node: ast.AsyncFunctionDef) -> bool: + for arg in fn_node.args.args: + if arg.arg != "current_user": + continue + default_index = fn_node.args.args.index(arg) - (len(fn_node.args.args) - len(fn_node.args.defaults)) + if default_index < 0: + continue + default_node = fn_node.args.defaults[default_index] + if not isinstance(default_node, ast.Call): + continue + if isinstance(default_node.func, ast.Name) and default_node.func.id == "Depends": + if default_node.args and isinstance(default_node.args[0], ast.Name): + return default_node.args[0].id == "get_current_user" + return False + + +def _has_verify_user_access_call(fn_node: ast.AsyncFunctionDef) -> bool: + for node in ast.walk(fn_node): + if not isinstance(node, ast.Call): + continue + if isinstance(node.func, ast.Name) and node.func.id == "verify_user_access": + return True + return False + + +def run_access_audit() -> List[Tuple[str, str]]: + """Return (file, function) pairs for user-scoped routes missing auth checks.""" + failures: List[Tuple[str, str]] = [] + + for route_file in ROUTES_DIR.glob("*.py"): + if route_file.name in {"__init__.py", Path(__file__).name}: + continue + + tree = ast.parse(route_file.read_text(), filename=str(route_file)) + for node in tree.body: + if not isinstance(node, ast.AsyncFunctionDef): + continue + + route_paths = [p for d in node.decorator_list if (p := _decorator_path(d))] + if not any("{user_id}" in p for p in route_paths): + continue + + if not _has_current_user_dependency(node) or not _has_verify_user_access_call(node): + failures.append((route_file.name, node.name)) + + return failures + + +if __name__ == "__main__": + issues = run_access_audit() + if issues: + for file_name, fn_name in issues: + print(f"FAIL: {file_name}:{fn_name} missing get_current_user/verify_user_access pattern") + raise SystemExit(1) + + print("PASS: all user-scoped routes include get_current_user and verify_user_access") diff --git a/backend/api/subscription/routes/usage.py b/backend/api/subscription/routes/usage.py index 636858f1..41efda0f 100644 --- a/backend/api/subscription/routes/usage.py +++ b/backend/api/subscription/routes/usage.py @@ -26,7 +26,7 @@ async def get_user_usage( # Verify user can only access their own data verify_user_access(user_id, current_user) - + try: usage_service = UsageTrackingService(db) stats = usage_service.get_user_usage_stats(user_id, billing_period) @@ -44,9 +44,12 @@ async def get_user_usage( async def get_usage_trends( user_id: str, months: int = Query(6, ge=1, le=24, description="Number of months to include"), - db: Session = Depends(get_db) + db: Session = Depends(get_db), + current_user: Dict[str, Any] = Depends(get_current_user) ) -> Dict[str, Any]: """Get usage trends over time.""" + + verify_user_access(user_id, current_user) try: usage_service = UsageTrackingService(db)