Compare commits

..

1 Commits

Author SHA1 Message Date
ي
b54c2978c3 Restrict podcast task status access by owner 2026-03-30 08:05:44 +05:30
6 changed files with 45 additions and 72 deletions

View File

@@ -203,7 +203,10 @@ async def create_audio_dubbing_task(
""" """
user_id = require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
task_id = task_manager.create_task("audio_dubbing") task_id = task_manager.create_task(
"audio_dubbing",
metadata={"owner_user_id": user_id},
)
background_tasks.add_task( background_tasks.add_task(
_execute_dubbing_task, _execute_dubbing_task,
@@ -240,7 +243,7 @@ async def get_dubbing_result(
""" """
user_id = require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
task_status = task_manager.get_task_status(task_id) task_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
if not task_status: if not task_status:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")
@@ -403,7 +406,10 @@ async def create_voice_clone_task(
""" """
user_id = require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
task_id = task_manager.create_task("voice_clone") task_id = task_manager.create_task(
"voice_clone",
metadata={"owner_user_id": user_id},
)
background_tasks.add_task( background_tasks.add_task(
_execute_voice_clone_task, _execute_voice_clone_task,
@@ -434,7 +440,7 @@ async def get_voice_clone_result(
""" """
user_id = require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
task_status = task_manager.get_task_status(task_id) task_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
if not task_status: if not task_status:
raise HTTPException(status_code=404, detail="Task not found") raise HTTPException(status_code=404, detail="Task not found")

View File

@@ -222,7 +222,7 @@ def _execute_podcast_video_task(
) )
# Verify the task status was updated correctly # Verify the task status was updated correctly
updated_status = task_manager.get_task_status(task_id) updated_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
logger.info( logger.info(
f"[Podcast] Task status after update: task_id={task_id}, status={updated_status.get('status') if updated_status else 'None'}, has_result={bool(updated_status.get('result') if updated_status else False)}, video_url={updated_status.get('result', {}).get('video_url') if updated_status else 'N/A'}" f"[Podcast] Task status after update: task_id={task_id}, status={updated_status.get('status') if updated_status else 'None'}, has_result={bool(updated_status.get('result') if updated_status else False)}, video_url={updated_status.get('result', {}).get('video_url') if updated_status else 'N/A'}"
) )
@@ -358,7 +358,10 @@ async def generate_podcast_video(
logger.warning(f"[Podcast] Failed to extract auth token from headers: {e}") logger.warning(f"[Podcast] Failed to extract auth token from headers: {e}")
# Create async task # Create async task
task_id = task_manager.create_task("podcast_video_generation") task_id = task_manager.create_task(
"podcast_video_generation",
metadata={"owner_user_id": user_id},
)
background_tasks.add_task( background_tasks.add_task(
_execute_podcast_video_task, _execute_podcast_video_task,
task_id=task_id, task_id=task_id,
@@ -488,7 +491,10 @@ async def combine_podcast_videos(
raise HTTPException(status_code=400, detail="No scene videos provided") raise HTTPException(status_code=400, detail="No scene videos provided")
# Create async task # Create async task
task_id = task_manager.create_task("podcast_combine_videos") task_id = task_manager.create_task(
"podcast_combine_videos",
metadata={"owner_user_id": user_id},
)
# Extract token for authenticated URL building # Extract token for authenticated URL building
auth_token = None auth_token = None

View File

@@ -4,7 +4,7 @@ Podcast Maker API Router
Main router that imports and registers all handler modules. Main router that imports and registers all handler modules.
""" """
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any from typing import Dict, Any
from middleware.auth_middleware import get_current_user from middleware.auth_middleware import get_current_user
@@ -32,5 +32,8 @@ router.include_router(dubbing.router)
@router.get("/task/{task_id}/status") @router.get("/task/{task_id}/status")
async def podcast_task_status(task_id: str, current_user: Dict[str, Any] = Depends(get_current_user)): async def podcast_task_status(task_id: str, current_user: Dict[str, Any] = Depends(get_current_user)):
"""Expose task status under podcast namespace (reuses shared task manager).""" """Expose task status under podcast namespace (reuses shared task manager)."""
require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
return task_manager.get_task_status(task_id) task_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")
return task_status

View File

@@ -34,9 +34,14 @@ class TaskManager:
del self.task_storage[task_id] del self.task_storage[task_id]
logger.debug(f"[StoryWriter] Cleaned up old task: {task_id}") logger.debug(f"[StoryWriter] Cleaned up old task: {task_id}")
def create_task(self, task_type: str = "story_generation") -> str: def create_task(
self,
task_type: str = "story_generation",
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Create a new task and return its ID.""" """Create a new task and return its ID."""
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
task_metadata = metadata or {}
self.task_storage[task_id] = { self.task_storage[task_id] = {
"status": "pending", "status": "pending",
@@ -45,13 +50,14 @@ class TaskManager:
"error": None, "error": None,
"progress_messages": [], "progress_messages": [],
"task_type": task_type, "task_type": task_type,
"progress": 0.0 "progress": 0.0,
"metadata": task_metadata,
} }
logger.info(f"[StoryWriter] Created task: {task_id} (type: {task_type})") logger.info(f"[StoryWriter] Created task: {task_id} (type: {task_type})")
return task_id return task_id
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]: def get_task_status(self, task_id: str, requester_user_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Get the status of a task.""" """Get the status of a task."""
self.cleanup_old_tasks() self.cleanup_old_tasks()
@@ -62,6 +68,15 @@ class TaskManager:
return None return None
task = self.task_storage[task_id] task = self.task_storage[task_id]
metadata = task.get("metadata", {}) or {}
owner_user_id = metadata.get("owner_user_id")
if requester_user_id is not None and owner_user_id is not None and requester_user_id != owner_user_id:
logger.warning(
f"[StoryWriter] Task access denied for task {task_id}: requester does not match owner"
)
return None
response = { response = {
"task_id": task_id, "task_id": task_id,
"status": task["status"], "status": task["status"],

View File

@@ -462,7 +462,7 @@ async def serve_frontend():
async def startup_event(): async def startup_event():
"""Initialize services on startup.""" """Initialize services on startup."""
try: try:
startup_report = run_startup_health_routine(app) startup_report = run_startup_health_routine()
if startup_report.get("status") != "healthy": if startup_report.get("status") != "healthy":
logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}") logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}")

View File

@@ -3,8 +3,6 @@ from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from fastapi import FastAPI
from fastapi.routing import APIRoute
from loguru import logger from loguru import logger
from sqlalchemy import inspect, text from sqlalchemy import inspect, text
@@ -51,60 +49,6 @@ def _record_check(checks: List[Dict[str, Any]], name: str, ok: bool, detail: str
checks.append({"name": name, "ok": ok, "detail": detail}) checks.append({"name": name, "ok": ok, "detail": detail})
def _is_demo_mode() -> bool:
app_env = os.getenv("APP_ENV", os.getenv("ENV", os.getenv("DEPLOY_ENV", ""))).strip().lower()
if app_env == "demo":
return True
return _env_true("ALWRITY_DEMO_MODE", default=False)
def _check_required_demo_routes(
app: Optional[FastAPI],
checks: List[Dict[str, Any]],
errors: List[str],
) -> None:
if not _is_demo_mode():
_record_check(
checks,
"demo_required_routes",
True,
"Skipped (not in demo mode). Set APP_ENV=demo or ALWRITY_DEMO_MODE=true to enforce.",
)
return
if app is None:
errors.append(
"Demo startup route check could not run because FastAPI app context was not provided to startup health routine."
)
_record_check(checks, "demo_required_routes_context", False, "missing app context")
return
required_routes = {
"/api/subscription/plans": "GET",
"/api/podcast/projects": "GET",
}
available_routes = {
(route.path, method)
for route in app.router.routes
if isinstance(route, APIRoute)
for method in route.methods
}
missing: List[str] = []
for path, method in required_routes.items():
if (path, method) in available_routes:
_record_check(checks, f"demo_route_{path}_{method}", True, "route registered")
else:
missing.append(f"{method} {path}")
_record_check(checks, f"demo_route_{path}_{method}", False, "route missing")
if missing:
errors.append(
"Demo mode startup check failed. Missing required API endpoints: "
f"{', '.join(missing)}. Ensure subscription and podcast routers are imported and included during app setup."
)
def _check_workspace_root(checks: List[Dict[str, Any]], errors: List[str]) -> None: def _check_workspace_root(checks: List[Dict[str, Any]], errors: List[str]) -> None:
workspace = Path(WORKSPACE_DIR) workspace = Path(WORKSPACE_DIR)
if not workspace.exists(): if not workspace.exists():
@@ -200,7 +144,7 @@ def _check_db_access(checks: List[Dict[str, Any]], errors: List[str], warnings:
return candidate_user return candidate_user
def run_startup_health_routine(app: Optional[FastAPI] = None) -> Dict[str, Any]: def run_startup_health_routine() -> Dict[str, Any]:
checks: List[Dict[str, Any]] = [] checks: List[Dict[str, Any]] = []
errors: List[str] = [] errors: List[str] = []
warnings: List[str] = [] warnings: List[str] = []
@@ -208,7 +152,6 @@ def run_startup_health_routine(app: Optional[FastAPI] = None) -> Dict[str, Any]:
_check_workspace_root(checks, errors) _check_workspace_root(checks, errors)
if not errors: if not errors:
_check_db_access(checks, errors, warnings) _check_db_access(checks, errors, warnings)
_check_required_demo_routes(app, checks, errors)
status = "healthy" if not errors else "failed" status = "healthy" if not errors else "failed"
report = { report = {