Compare commits

..

1 Commits

Author SHA1 Message Date
ي
b54c2978c3 Restrict podcast task status access by owner 2026-03-30 08:05:44 +05:30
6 changed files with 46 additions and 77 deletions

View File

@@ -203,7 +203,10 @@ async def create_audio_dubbing_task(
"""
user_id = require_authenticated_user(current_user)
task_id = task_manager.create_task("audio_dubbing")
task_id = task_manager.create_task(
"audio_dubbing",
metadata={"owner_user_id": user_id},
)
background_tasks.add_task(
_execute_dubbing_task,
@@ -240,7 +243,7 @@ async def get_dubbing_result(
"""
user_id = require_authenticated_user(current_user)
task_status = task_manager.get_task_status(task_id)
task_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")
@@ -403,7 +406,10 @@ async def create_voice_clone_task(
"""
user_id = require_authenticated_user(current_user)
task_id = task_manager.create_task("voice_clone")
task_id = task_manager.create_task(
"voice_clone",
metadata={"owner_user_id": user_id},
)
background_tasks.add_task(
_execute_voice_clone_task,
@@ -434,7 +440,7 @@ async def get_voice_clone_result(
"""
user_id = require_authenticated_user(current_user)
task_status = task_manager.get_task_status(task_id)
task_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")

View File

@@ -222,7 +222,7 @@ def _execute_podcast_video_task(
)
# Verify the task status was updated correctly
updated_status = task_manager.get_task_status(task_id)
updated_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
logger.info(
f"[Podcast] Task status after update: task_id={task_id}, status={updated_status.get('status') if updated_status else 'None'}, has_result={bool(updated_status.get('result') if updated_status else False)}, video_url={updated_status.get('result', {}).get('video_url') if updated_status else 'N/A'}"
)
@@ -358,7 +358,10 @@ async def generate_podcast_video(
logger.warning(f"[Podcast] Failed to extract auth token from headers: {e}")
# Create async task
task_id = task_manager.create_task("podcast_video_generation")
task_id = task_manager.create_task(
"podcast_video_generation",
metadata={"owner_user_id": user_id},
)
background_tasks.add_task(
_execute_podcast_video_task,
task_id=task_id,
@@ -488,7 +491,10 @@ async def combine_podcast_videos(
raise HTTPException(status_code=400, detail="No scene videos provided")
# Create async task
task_id = task_manager.create_task("podcast_combine_videos")
task_id = task_manager.create_task(
"podcast_combine_videos",
metadata={"owner_user_id": user_id},
)
# Extract token for authenticated URL building
auth_token = None

View File

@@ -4,7 +4,7 @@ Podcast Maker API Router
Main router that imports and registers all handler modules.
"""
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any
from middleware.auth_middleware import get_current_user
@@ -32,5 +32,8 @@ router.include_router(dubbing.router)
@router.get("/task/{task_id}/status")
async def podcast_task_status(task_id: str, current_user: Dict[str, Any] = Depends(get_current_user)):
"""Expose task status under podcast namespace (reuses shared task manager)."""
require_authenticated_user(current_user)
return task_manager.get_task_status(task_id)
user_id = require_authenticated_user(current_user)
task_status = task_manager.get_task_status(task_id, requester_user_id=user_id)
if not task_status:
raise HTTPException(status_code=404, detail="Task not found")
return task_status

View File

@@ -34,9 +34,14 @@ class TaskManager:
del self.task_storage[task_id]
logger.debug(f"[StoryWriter] Cleaned up old task: {task_id}")
def create_task(self, task_type: str = "story_generation") -> str:
def create_task(
self,
task_type: str = "story_generation",
metadata: Optional[Dict[str, Any]] = None,
) -> str:
"""Create a new task and return its ID."""
task_id = str(uuid.uuid4())
task_metadata = metadata or {}
self.task_storage[task_id] = {
"status": "pending",
@@ -45,13 +50,14 @@ class TaskManager:
"error": None,
"progress_messages": [],
"task_type": task_type,
"progress": 0.0
"progress": 0.0,
"metadata": task_metadata,
}
logger.info(f"[StoryWriter] Created task: {task_id} (type: {task_type})")
return task_id
def get_task_status(self, task_id: str) -> Optional[Dict[str, Any]]:
def get_task_status(self, task_id: str, requester_user_id: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Get the status of a task."""
self.cleanup_old_tasks()
@@ -62,6 +68,15 @@ class TaskManager:
return None
task = self.task_storage[task_id]
metadata = task.get("metadata", {}) or {}
owner_user_id = metadata.get("owner_user_id")
if requester_user_id is not None and owner_user_id is not None and requester_user_id != owner_user_id:
logger.warning(
f"[StoryWriter] Task access denied for task {task_id}: requester does not match owner"
)
return None
response = {
"task_id": task_id,
"status": task["status"],

View File

@@ -15,7 +15,6 @@ from services.database import (
init_database,
default_engine,
)
from services.user_api_key_context import get_user_api_keys
_REQUIRED_SCHEMA: Dict[str, List[str]] = {
"onboarding_sessions": ["id", "user_id", "updated_at"],
@@ -145,62 +144,6 @@ def _check_db_access(checks: List[Dict[str, Any]], errors: List[str], warnings:
return candidate_user
def _check_production_api_key_loading(
checks: List[Dict[str, Any]],
errors: List[str],
warnings: List[str],
) -> None:
deploy_env = os.getenv("DEPLOY_ENV", "local").strip().lower()
if deploy_env == "local":
_record_check(checks, "production_api_key_loading", True, "skipped in local deploy mode")
return
test_tenant_id = os.getenv("ALWRITY_STARTUP_TEST_TENANT_ID", "").strip()
if not test_tenant_id:
message = (
"Missing ALWRITY_STARTUP_TEST_TENANT_ID for production API key startup check."
)
errors.append(message)
_record_check(checks, "production_api_key_loading", False, message)
return
try:
keys = get_user_api_keys(test_tenant_id)
except Exception as exc:
errors.append(
f"Failed to load API keys for startup test tenant '{test_tenant_id}': {exc}"
)
_record_check(checks, "production_api_key_loading", False, str(exc))
return
if not isinstance(keys, dict):
errors.append(
f"API key loader returned invalid payload type for startup test tenant '{test_tenant_id}'."
)
_record_check(checks, "production_api_key_loading", False, "invalid payload type")
return
non_empty_keys = [provider for provider, value in keys.items() if value]
if not non_empty_keys:
errors.append(
f"No API keys could be loaded for startup test tenant '{test_tenant_id}'."
)
_record_check(checks, "production_api_key_loading", False, "no non-empty keys loaded")
return
warning = None
if len(non_empty_keys) < len(keys):
warning = (
f"Startup test tenant '{test_tenant_id}' has {len(non_empty_keys)}/{len(keys)} non-empty API keys."
)
warnings.append(warning)
detail = f"loaded {len(non_empty_keys)} non-empty keys for tenant {test_tenant_id}"
if warning:
detail = f"{detail}; {warning}"
_record_check(checks, "production_api_key_loading", True, detail)
def run_startup_health_routine() -> Dict[str, Any]:
checks: List[Dict[str, Any]] = []
errors: List[str] = []
@@ -209,8 +152,6 @@ def run_startup_health_routine() -> Dict[str, Any]:
_check_workspace_root(checks, errors)
if not errors:
_check_db_access(checks, errors, warnings)
if not errors:
_check_production_api_key_loading(checks, errors, warnings)
status = "healthy" if not errors else "failed"
report = {

View File

@@ -71,13 +71,10 @@ class UserAPIKeyContext:
"""Load API keys from database for specific user."""
try:
from api.content_planning.services.content_strategy.onboarding import OnboardingDataIntegrationService
from services.database import get_session_for_user
from services.database import SessionLocal
integration_service = OnboardingDataIntegrationService()
db = get_session_for_user(user_id)
if not db:
logger.error(f"Failed to create DB session for user {user_id}")
return {}
db = SessionLocal()
try:
integrated_data = integration_service.get_integrated_data_sync(user_id, db)
keys = integrated_data.get('api_keys_data', {})
@@ -156,3 +153,4 @@ def get_tavily_key(user_id: Optional[str] = None) -> Optional[str]:
def get_copilotkit_key(user_id: Optional[str] = None) -> Optional[str]:
"""Get CopilotKit API key for user."""
return UserAPIKeyContext.get_user_key(user_id, 'copilotkit')