Harden user-scoped subscription route access checks
This commit is contained in:
@@ -10,6 +10,8 @@ from loguru import logger
|
|||||||
|
|
||||||
from services.database import get_db
|
from services.database import get_db
|
||||||
from models.subscription_models import UsageAlert
|
from models.subscription_models import UsageAlert
|
||||||
|
from middleware.auth_middleware import get_current_user
|
||||||
|
from ..dependencies import verify_user_access
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
@@ -19,9 +21,12 @@ async def get_usage_alerts(
|
|||||||
user_id: str,
|
user_id: str,
|
||||||
unread_only: bool = Query(False, description="Only return unread alerts"),
|
unread_only: bool = Query(False, description="Only return unread alerts"),
|
||||||
limit: int = Query(50, ge=1, le=100, description="Maximum number of alerts"),
|
limit: int = Query(50, ge=1, le=100, description="Maximum number of alerts"),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Get usage alerts for a user."""
|
"""Get usage alerts for a user."""
|
||||||
|
|
||||||
|
verify_user_access(user_id, current_user)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
query = db.query(UsageAlert).filter(
|
query = db.query(UsageAlert).filter(
|
||||||
@@ -79,16 +84,20 @@ async def get_usage_alerts(
|
|||||||
@router.post("/alerts/{alert_id}/mark-read")
|
@router.post("/alerts/{alert_id}/mark-read")
|
||||||
async def mark_alert_read(
|
async def mark_alert_read(
|
||||||
alert_id: int,
|
alert_id: int,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Mark an alert as read."""
|
"""Mark an alert as read."""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
alert = db.query(UsageAlert).filter(UsageAlert.id == alert_id).first()
|
alert = db.query(UsageAlert).filter(UsageAlert.id == alert_id).first()
|
||||||
|
|
||||||
if not alert:
|
if not alert:
|
||||||
raise HTTPException(status_code=404, detail="Alert not found")
|
raise HTTPException(status_code=404, detail="Alert not found")
|
||||||
|
|
||||||
|
if str(alert.user_id) != str(current_user.get("id")):
|
||||||
|
raise HTTPException(status_code=403, detail="Access denied")
|
||||||
|
|
||||||
alert.is_read = True
|
alert.is_read = True
|
||||||
alert.read_at = datetime.utcnow()
|
alert.read_at = datetime.utcnow()
|
||||||
db.commit()
|
db.commit()
|
||||||
@@ -98,6 +107,8 @@ async def mark_alert_read(
|
|||||||
"message": "Alert marked as read"
|
"message": "Alert marked as read"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error marking alert as read: {e}")
|
logger.error(f"Error marking alert as read: {e}")
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ from services.database import get_db
|
|||||||
from services.subscription import UsageTrackingService, PricingService
|
from services.subscription import UsageTrackingService, PricingService
|
||||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||||
from models.subscription_models import UsageAlert
|
from models.subscription_models import UsageAlert
|
||||||
|
from middleware.auth_middleware import get_current_user
|
||||||
|
from ..dependencies import verify_user_access
|
||||||
from ..cache import get_cached_dashboard, set_cached_dashboard
|
from ..cache import get_cached_dashboard, set_cached_dashboard
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
@@ -22,9 +24,12 @@ router = APIRouter()
|
|||||||
async def get_dashboard_data(
|
async def get_dashboard_data(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
billing_period: str = None,
|
billing_period: str = None,
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Get comprehensive dashboard data for usage monitoring."""
|
"""Get comprehensive dashboard data for usage monitoring."""
|
||||||
|
|
||||||
|
verify_user_access(user_id, current_user)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ensure_subscription_plan_columns(db)
|
ensure_subscription_plan_columns(db)
|
||||||
|
|||||||
82
backend/api/subscription/routes/route_access_audit.py
Normal file
82
backend/api/subscription/routes/route_access_audit.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Quick route-level audit to enforce user-scoped access checks."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ast
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
ROUTES_DIR = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
|
||||||
|
def _decorator_path(decorator: ast.AST) -> str | None:
|
||||||
|
"""Extract route path from decorators like @router.get("/usage/{user_id}")."""
|
||||||
|
if not isinstance(decorator, ast.Call):
|
||||||
|
return None
|
||||||
|
if not isinstance(decorator.func, ast.Attribute):
|
||||||
|
return None
|
||||||
|
if not decorator.args:
|
||||||
|
return None
|
||||||
|
|
||||||
|
first_arg = decorator.args[0]
|
||||||
|
if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
|
||||||
|
return first_arg.value
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _has_current_user_dependency(fn_node: ast.AsyncFunctionDef) -> bool:
|
||||||
|
for arg in fn_node.args.args:
|
||||||
|
if arg.arg != "current_user":
|
||||||
|
continue
|
||||||
|
default_index = fn_node.args.args.index(arg) - (len(fn_node.args.args) - len(fn_node.args.defaults))
|
||||||
|
if default_index < 0:
|
||||||
|
continue
|
||||||
|
default_node = fn_node.args.defaults[default_index]
|
||||||
|
if not isinstance(default_node, ast.Call):
|
||||||
|
continue
|
||||||
|
if isinstance(default_node.func, ast.Name) and default_node.func.id == "Depends":
|
||||||
|
if default_node.args and isinstance(default_node.args[0], ast.Name):
|
||||||
|
return default_node.args[0].id == "get_current_user"
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _has_verify_user_access_call(fn_node: ast.AsyncFunctionDef) -> bool:
|
||||||
|
for node in ast.walk(fn_node):
|
||||||
|
if not isinstance(node, ast.Call):
|
||||||
|
continue
|
||||||
|
if isinstance(node.func, ast.Name) and node.func.id == "verify_user_access":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def run_access_audit() -> List[Tuple[str, str]]:
|
||||||
|
"""Return (file, function) pairs for user-scoped routes missing auth checks."""
|
||||||
|
failures: List[Tuple[str, str]] = []
|
||||||
|
|
||||||
|
for route_file in ROUTES_DIR.glob("*.py"):
|
||||||
|
if route_file.name in {"__init__.py", Path(__file__).name}:
|
||||||
|
continue
|
||||||
|
|
||||||
|
tree = ast.parse(route_file.read_text(), filename=str(route_file))
|
||||||
|
for node in tree.body:
|
||||||
|
if not isinstance(node, ast.AsyncFunctionDef):
|
||||||
|
continue
|
||||||
|
|
||||||
|
route_paths = [p for d in node.decorator_list if (p := _decorator_path(d))]
|
||||||
|
if not any("{user_id}" in p for p in route_paths):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not _has_current_user_dependency(node) or not _has_verify_user_access_call(node):
|
||||||
|
failures.append((route_file.name, node.name))
|
||||||
|
|
||||||
|
return failures
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
issues = run_access_audit()
|
||||||
|
if issues:
|
||||||
|
for file_name, fn_name in issues:
|
||||||
|
print(f"FAIL: {file_name}:{fn_name} missing get_current_user/verify_user_access pattern")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
print("PASS: all user-scoped routes include get_current_user and verify_user_access")
|
||||||
@@ -26,7 +26,7 @@ async def get_user_usage(
|
|||||||
|
|
||||||
# Verify user can only access their own data
|
# Verify user can only access their own data
|
||||||
verify_user_access(user_id, current_user)
|
verify_user_access(user_id, current_user)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
usage_service = UsageTrackingService(db)
|
usage_service = UsageTrackingService(db)
|
||||||
stats = usage_service.get_user_usage_stats(user_id, billing_period)
|
stats = usage_service.get_user_usage_stats(user_id, billing_period)
|
||||||
@@ -44,9 +44,12 @@ async def get_user_usage(
|
|||||||
async def get_usage_trends(
|
async def get_usage_trends(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
months: int = Query(6, ge=1, le=24, description="Number of months to include"),
|
months: int = Query(6, ge=1, le=24, description="Number of months to include"),
|
||||||
db: Session = Depends(get_db)
|
db: Session = Depends(get_db),
|
||||||
|
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""Get usage trends over time."""
|
"""Get usage trends over time."""
|
||||||
|
|
||||||
|
verify_user_access(user_id, current_user)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
usage_service = UsageTrackingService(db)
|
usage_service = UsageTrackingService(db)
|
||||||
|
|||||||
Reference in New Issue
Block a user