Compare commits

..

1 Commits

Author SHA1 Message Date
ي
27c167ebe8 Use tenant-scoped dubbed audio paths with safe file resolution 2026-03-30 08:07:01 +05:30
2 changed files with 47 additions and 77 deletions

View File

@@ -29,16 +29,45 @@ from ..models import (
VoiceCloneResult, VoiceCloneResult,
) )
from services.dubbing import AudioDubbingService from services.dubbing import AudioDubbingService
from ..constants import get_podcast_media_read_dirs, get_podcast_media_dir
router = APIRouter() router = APIRouter()
_dubbing_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="podcast_dubbing") _dubbing_executor = ThreadPoolExecutor(max_workers=4, thread_name_prefix="podcast_dubbing")
DUBBED_AUDIO_DIR = Path(__file__).resolve().parents[3] / "data" / "media" / "dubbed_audio" _DUBBED_AUDIO_SUBDIR = Path("dubbed_audio")
_LEGACY_DUBBED_AUDIO_DIR = Path(__file__).resolve().parents[3] / "data" / "media" / "dubbed_audio"
def _ensure_dubbed_audio_dir(): def _get_dubbed_audio_dir(user_id: str, *, ensure_exists: bool = False) -> Path:
DUBBED_AUDIO_DIR.mkdir(parents=True, exist_ok=True) """Resolve tenant-scoped dubbed audio directory under podcast audio media."""
base_dir = get_podcast_media_dir("audio", user_id, ensure_exists=ensure_exists)
dubbed_dir = (base_dir / _DUBBED_AUDIO_SUBDIR).resolve()
if ensure_exists:
dubbed_dir.mkdir(parents=True, exist_ok=True)
return dubbed_dir
def _resolve_dubbed_audio_file(filename: str, user_id: str) -> Path:
    """Resolve a dubbed audio filename to an existing file in an allowed directory.

    Directories are searched in order: every tenant read directory returned by
    ``get_podcast_media_read_dirs("audio", user_id)`` (each with the dubbed-audio
    subdirectory appended), then the legacy shared directory as a fallback.

    Args:
        filename: Requested file name; anything after the first ``?`` is ignored.
        user_id: Identifier used to look up the tenant-scoped read directories.

    Returns:
        The resolved path of the first matching regular file.

    Raises:
        HTTPException: 400 if the filename is empty after cleaning, 403 if it
            resolves outside the directory being checked, 404 if no candidate
            directory contains the file.
    """
    # Drop a trailing query string (e.g. cache-busting suffix) before using the name.
    clean_filename = filename.split("?", 1)[0].strip()
    if not clean_filename:
        raise HTTPException(status_code=400, detail="Invalid filename")

    candidate_dirs: list[Path] = [
        (base_dir / _DUBBED_AUDIO_SUBDIR).resolve()
        for base_dir in get_podcast_media_read_dirs("audio", user_id)
    ]
    candidate_dirs.append(_LEGACY_DUBBED_AUDIO_DIR.resolve())

    for target_dir in candidate_dirs:
        candidate = (target_dir / clean_filename).resolve()
        # Compare path components, not string prefixes: a plain startswith()
        # check would accept a sibling such as ".../dubbed_audio_evil/x" because
        # its text begins with ".../dubbed_audio".
        if candidate != target_dir and target_dir not in candidate.parents:
            logger.error(f"[Podcast][Dubbing] Attempted path traversal: {filename}")
            raise HTTPException(status_code=403, detail="Invalid audio path")
        # Only regular files can be served; a directory falls through to 404.
        if candidate.is_file():
            return candidate

    raise HTTPException(status_code=404, detail="Audio file not found")
def _execute_dubbing_task( def _execute_dubbing_task(
@@ -62,9 +91,8 @@ def _execute_dubbing_task(
message="Starting audio dubbing..." message="Starting audio dubbing..."
) )
_ensure_dubbed_audio_dir() dubbed_audio_dir = _get_dubbed_audio_dir(user_id, ensure_exists=True)
service = AudioDubbingService(output_dir=dubbed_audio_dir)
service = AudioDubbingService(output_dir=DUBBED_AUDIO_DIR)
def progress_callback(progress: float, message: str): def progress_callback(progress: float, message: str):
task_manager.update_task_status( task_manager.update_task_status(
@@ -136,9 +164,8 @@ def _execute_voice_clone_task(
message="Starting voice cloning..." message="Starting voice cloning..."
) )
_ensure_dubbed_audio_dir() dubbed_audio_dir = _get_dubbed_audio_dir(user_id, ensure_exists=True)
service = AudioDubbingService(output_dir=dubbed_audio_dir)
service = AudioDubbingService(output_dir=DUBBED_AUDIO_DIR)
task_manager.update_task_status( task_manager.update_task_status(
task_id, "processing", progress=30.0, task_id, "processing", progress=30.0,
@@ -301,12 +328,7 @@ async def serve_dubbed_audio(
""" """
user_id = require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
_ensure_dubbed_audio_dir() audio_path = _resolve_dubbed_audio_file(filename, user_id)
audio_path = DUBBED_AUDIO_DIR / filename
if not audio_path.exists():
raise HTTPException(status_code=404, detail="Audio file not found")
return FileResponse( return FileResponse(
path=audio_path, path=audio_path,
@@ -327,7 +349,8 @@ async def estimate_dubbing_cost(
""" """
user_id = require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
service = AudioDubbingService(output_dir=DUBBED_AUDIO_DIR) dubbed_audio_dir = _get_dubbed_audio_dir(user_id, ensure_exists=True)
service = AudioDubbingService(output_dir=dubbed_audio_dir)
cost_estimate = service.estimate_cost( cost_estimate = service.estimate_cost(
audio_duration_seconds=request.audio_duration_seconds, audio_duration_seconds=request.audio_duration_seconds,
@@ -479,12 +502,12 @@ async def serve_voice_audio(
""" """
user_id = require_authenticated_user(current_user) user_id = require_authenticated_user(current_user)
_ensure_dubbed_audio_dir() try:
audio_path = _resolve_dubbed_audio_file(filename, user_id)
audio_path = DUBBED_AUDIO_DIR / filename except HTTPException as exc:
if exc.status_code == 404:
if not audio_path.exists(): raise HTTPException(status_code=404, detail="Voice audio file not found") from exc
raise HTTPException(status_code=404, detail="Voice audio file not found") raise
return FileResponse( return FileResponse(
path=audio_path, path=audio_path,

View File

@@ -8,7 +8,6 @@ IMPORTANT: This is a compatibility layer. For new code, use UserAPIKeyContext di
""" """
import os import os
import time
from fastapi import Request from fastapi import Request
from loguru import logger from loguru import logger
from typing import Callable from typing import Callable
@@ -21,61 +20,8 @@ class APIKeyInjectionMiddleware:
for the duration of each request. for the duration of each request.
""" """
# Shared across middleware instances (module currently instantiates per request)
_missing_keys_log_timestamps = {}
def __init__(self): def __init__(self):
self.original_keys = {} self.original_keys = {}
@staticmethod
def _should_skip_missing_key_warning(request: Request) -> bool:
    """
    Decide whether a missing-key warning should be suppressed for this request.

    Suppression is controlled by API_KEY_INJECTION_SKIP_NON_AI_WARNINGS
    (default: true; the values '1', 'true' and 'yes' enable it, case-insensitively).
    When enabled, the warning is skipped for subscription and onboarding routes
    and for any path ending in '/status' or '/health'.

    Returns:
        True if the warning should be skipped, False otherwise.
    """
    flag = os.getenv('API_KEY_INJECTION_SKIP_NON_AI_WARNINGS', 'true').lower()
    if flag not in ('1', 'true', 'yes'):
        return False

    path_lower = (request.url.path or '').lower()
    # endswith() already matches the bare '/health' and '/status' paths, so no
    # separate equality checks are needed.
    return (
        path_lower.startswith(('/api/subscription/', '/api/onboarding/'))
        or path_lower.endswith(('/status', '/health'))
    )
def _log_missing_keys_non_blocking(self, request: Request, user_id: str) -> None:
    """
    Log missing API keys without interrupting request flow.

    - Defaults to debug-level logging.
    - Optional warn once-per-user-per-interval via env:
      API_KEY_INJECTION_MISSING_KEYS_LOG_MODE=warn_once
      API_KEY_INJECTION_MISSING_KEYS_LOG_INTERVAL_SECONDS=900
      A non-integer interval falls back to 900 seconds; values below 1 are
      treated as 1.

    Any exception raised while logging is caught and reported at debug level.
    """
    try:
        if self._should_skip_missing_key_warning(request):
            logger.debug(f"[API Key Injection] Missing keys for user {user_id} on non-AI route; skipping warning")
            return

        log_mode = os.getenv('API_KEY_INJECTION_MISSING_KEYS_LOG_MODE', 'debug').lower()
        if log_mode != 'warn_once':
            logger.debug(f"No API keys found for user {user_id}")
            return

        # A malformed interval must not silence the warning entirely, so fall
        # back to the documented default instead of letting int() raise.
        try:
            interval_seconds = int(os.getenv('API_KEY_INJECTION_MISSING_KEYS_LOG_INTERVAL_SECONDS', '900'))
        except ValueError:
            interval_seconds = 900

        now = time.time()
        last_logged_at = self._missing_keys_log_timestamps.get(user_id, 0)
        if (now - last_logged_at) >= max(interval_seconds, 1):
            logger.warning(f"No API keys found for user {user_id}")
            self._missing_keys_log_timestamps[user_id] = now
        else:
            logger.debug(f"No API keys found for user {user_id} (warning suppressed by interval)")
    except Exception as log_error:
        # Logging should never block request processing
        logger.debug(f"[API Key Injection] Failed to log missing keys state for user {user_id}: {log_error}")
async def __call__(self, request: Request, call_next: Callable): async def __call__(self, request: Request, call_next: Callable):
""" """
@@ -122,7 +68,7 @@ class APIKeyInjectionMiddleware:
# Get user-specific API keys from database # Get user-specific API keys from database
with user_api_keys(user_id) as user_keys: with user_api_keys(user_id) as user_keys:
if not user_keys: if not user_keys:
self._log_missing_keys_non_blocking(request, user_id) logger.warning(f"No API keys found for user {user_id}")
return await call_next(request) return await call_next(request)
# Save original environment values # Save original environment values
@@ -174,3 +120,4 @@ async def api_key_injection_middleware(request: Request, call_next: Callable):
""" """
middleware = APIKeyInjectionMiddleware() middleware = APIKeyInjectionMiddleware()
return await middleware(request, call_next) return await middleware(request, call_next)