Compare commits

..

1 Commits

Author SHA1 Message Date
ي
6a182aecaf Support multi-source content asset filtering end-to-end 2026-05-18 14:36:16 +05:30
7 changed files with 158 additions and 196 deletions

View File

@@ -5,7 +5,7 @@ API endpoints for managing unified content assets across all modules.
from fastapi import APIRouter, Depends, HTTPException, Query, Body
from sqlalchemy.orm import Session
from typing import List, Optional, Dict, Any
from typing import List, Optional, Dict, Any, Set
from pydantic import BaseModel, Field
from datetime import datetime
@@ -47,6 +47,33 @@ class AssetResponse(BaseModel):
from_attributes = True
def _parse_source_modules(source_module: Optional[List[str]]) -> Optional[List[AssetSource]]:
"""Parse source_module query values from repeated params and/or comma-separated values."""
if not source_module:
return None
parsed_values: List[AssetSource] = []
seen: Set[AssetSource] = set()
for raw_value in source_module:
for value in raw_value.split(","):
normalized = value.strip().lower()
if not normalized:
continue
try:
module = AssetSource(normalized)
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid source module: {value.strip()}")
if module not in seen:
seen.add(module)
parsed_values.append(module)
return parsed_values or None
class AssetListResponse(BaseModel):
"""Response model for asset list."""
assets: List[AssetResponse]
@@ -58,7 +85,7 @@ class AssetListResponse(BaseModel):
@router.get("/", response_model=AssetListResponse)
async def get_assets(
asset_type: Optional[str] = Query(None, description="Filter by asset type"),
source_module: Optional[str] = Query(None, description="Filter by source module"),
source_module: Optional[List[str]] = Query(None, description="Filter by source module(s); supports repeated params and comma-separated values"),
search: Optional[str] = Query(None, description="Search query"),
tags: Optional[str] = Query(None, description="Comma-separated tags"),
favorites_only: bool = Query(False, description="Only favorites"),
@@ -89,12 +116,7 @@ async def get_assets(
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid asset type: {asset_type}")
source_module_enum = None
if source_module:
try:
source_module_enum = AssetSource(source_module.lower())
except ValueError:
raise HTTPException(status_code=400, detail=f"Invalid source module: {source_module}")
source_modules_enum = _parse_source_modules(source_module)
tags_list = None
if tags:
@@ -126,7 +148,7 @@ async def get_assets(
assets, total = service.get_user_assets(
user_id=user_id,
asset_type=asset_type_enum,
source_module=source_module_enum,
source_modules=source_modules_enum,
search_query=search,
tags=tags_list,
favorites_only=favorites_only,
@@ -200,7 +222,7 @@ async def create_asset(
asset = service.create_asset(
user_id=user_id,
asset_type=asset_type_enum,
source_module=source_module_enum,
source_modules=source_modules_enum,
filename=asset_data.filename,
file_url=asset_data.file_url,
file_path=asset_data.file_path,

View File

@@ -1,182 +0,0 @@
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Optional
from urllib.parse import urlencode
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import RedirectResponse
from loguru import logger
from sqlalchemy import text
from sqlalchemy.orm import Session
from services.database import get_db
router = APIRouter(prefix="/v1/social-proxy", tags=["social-proxy"])
def _utc_now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _ensure_tables(db: Session) -> None:
# Keep this router backward-compatible on tenant DBs without migrations.
db.execute(text("""
CREATE TABLE IF NOT EXISTS oauth_nonce_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
state TEXT NOT NULL UNIQUE,
nonce TEXT NOT NULL,
user_id TEXT NOT NULL,
platform TEXT NOT NULL,
channel_id INTEGER,
consumed_at TEXT,
expires_at TEXT,
created_at TEXT NOT NULL
)
"""))
db.execute(text("""
CREATE TABLE IF NOT EXISTS social_channels (
id INTEGER PRIMARY KEY AUTOINCREMENT,
user_id TEXT NOT NULL,
platform TEXT NOT NULL,
platform_account_id TEXT NOT NULL,
token_bundle TEXT NOT NULL,
token_version INTEGER NOT NULL DEFAULT 1,
publication_linkage TEXT,
is_connected INTEGER NOT NULL DEFAULT 1,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
UNIQUE(platform, platform_account_id)
)
"""))
def _build_redirect(base_url: str, code: str, message: str, channel_id: Optional[int] = None) -> RedirectResponse:
params = {"code": code, "message": message}
if channel_id is not None:
params["channel_id"] = str(channel_id)
return RedirectResponse(url=f"{base_url}?{urlencode(params)}", status_code=303)
@router.get("/oauth/callback")
def oauth_callback(
state: str = Query(...),
platform: str = Query(...),
account_id: str = Query(...),
token_bundle: str = Query(..., description="Serialized token payload"),
ui_redirect: str = Query("/dashboard/connections"),
db: Session = Depends(get_db),
):
"""Consume OAuth callback, bind to user/platform, and upsert social channel connection."""
_ensure_tables(db)
record = db.execute(
text("""
SELECT id, nonce, user_id, platform, channel_id, consumed_at, expires_at
FROM oauth_nonce_sessions WHERE state = :state
"""),
{"state": state},
).mappings().first()
if not record:
return _build_redirect(ui_redirect, "invalid_state", "Missing OAuth session")
if record["consumed_at"] is not None:
return _build_redirect(ui_redirect, "state_reused", "OAuth state already consumed")
if record["platform"] != platform:
return _build_redirect(ui_redirect, "platform_mismatch", "Platform mismatch")
if record["expires_at"] and record["expires_at"] < _utc_now_iso():
return _build_redirect(ui_redirect, "state_expired", "OAuth session expired")
user_id = record["user_id"]
# Validate token payload is JSON.
try:
parsed_bundle = json.loads(token_bundle)
except json.JSONDecodeError as exc:
raise HTTPException(status_code=400, detail="Invalid token_bundle JSON") from exc
now = _utc_now_iso()
existing = db.execute(
text("""
SELECT id, publication_linkage, token_version
FROM social_channels
WHERE platform = :platform AND platform_account_id = :account_id
"""),
{"platform": platform, "account_id": account_id},
).mappings().first()
if existing:
# Reconnect path: preserve publication linkage and bump token version.
db.execute(
text("""
UPDATE social_channels
SET user_id = :user_id,
token_bundle = :token_bundle,
token_version = :token_version,
is_connected = 1,
updated_at = :updated_at
WHERE id = :id
"""),
{
"id": existing["id"],
"user_id": user_id,
"token_bundle": json.dumps(parsed_bundle),
"token_version": int(existing["token_version"] or 0) + 1,
"updated_at": now,
},
)
channel_id = existing["id"]
result_code = "reconnected"
result_message = "Channel reconnected"
else:
db.execute(
text("""
INSERT INTO social_channels (
user_id, platform, platform_account_id, token_bundle,
token_version, publication_linkage, is_connected, created_at, updated_at
) VALUES (
:user_id, :platform, :account_id, :token_bundle,
1, :publication_linkage, 1, :created_at, :updated_at
)
"""),
{
"user_id": user_id,
"platform": platform,
"account_id": account_id,
"token_bundle": json.dumps(parsed_bundle),
"publication_linkage": None,
"created_at": now,
"updated_at": now,
},
)
channel_id = db.execute(text("SELECT last_insert_rowid()")).scalar_one()
result_code = "connected"
result_message = "Channel connected"
# Bind callback session to concrete channel/user/platform and mark consumed.
db.execute(
text("""
UPDATE oauth_nonce_sessions
SET consumed_at = :consumed_at,
channel_id = :channel_id,
user_id = :user_id,
platform = :platform
WHERE id = :id
"""),
{
"id": record["id"],
"consumed_at": now,
"channel_id": channel_id,
"user_id": user_id,
"platform": platform,
},
)
db.commit()
logger.info(f"OAuth callback complete user={user_id} platform={platform} channel_id={channel_id}")
return _build_redirect(ui_redirect, result_code, result_message, channel_id)

View File

@@ -107,6 +107,7 @@ class ContentAssetService:
user_id: str,
asset_type: Optional[AssetType] = None,
source_module: Optional[AssetSource] = None,
source_modules: Optional[List[AssetSource]] = None,
search_query: Optional[str] = None,
tags: Optional[List[str]] = None,
favorites_only: bool = False,
@@ -125,6 +126,7 @@ class ContentAssetService:
user_id: Clerk user ID
asset_type: Filter by asset type (optional)
source_module: Filter by source module (optional)
source_modules: Filter by multiple source modules (optional)
search_query: Search in title, description, prompt (optional)
tags: Filter by tags (optional)
favorites_only: Only return favorites (optional)
@@ -142,7 +144,9 @@ class ContentAssetService:
if asset_type:
query = query.filter(ContentAsset.asset_type == asset_type)
if source_module:
if source_modules:
query = query.filter(ContentAsset.source_module.in_(source_modules))
elif source_module:
query = query.filter(ContentAsset.source_module == source_module)
if favorites_only:

View File

@@ -0,0 +1,31 @@
import importlib.util
from pathlib import Path
from fastapi import HTTPException
ROOT = Path(__file__).resolve().parents[3]
ROUTER_PATH = ROOT / 'backend' / 'api' / 'content_assets' / 'router.py'
MODELS_PATH = ROOT / 'backend' / 'models' / 'content_asset_models.py'
models_spec = importlib.util.spec_from_file_location('content_asset_models', MODELS_PATH)
models = importlib.util.module_from_spec(models_spec)
models_spec.loader.exec_module(models)
AssetSource = models.AssetSource
router_spec = importlib.util.spec_from_file_location('content_assets_router', ROUTER_PATH)
router = importlib.util.module_from_spec(router_spec)
router_spec.loader.exec_module(router)
def test_parse_source_modules_supports_repeated_and_csv_values():
parsed = router._parse_source_modules(["blog_writer", "youtube,podcast"])
assert parsed == [AssetSource.BLOG_WRITER, AssetSource.YOUTUBE, AssetSource.PODCAST]
def test_parse_source_modules_raises_for_invalid_values():
try:
router._parse_source_modules(["blog_writer,unknown"])
except HTTPException as exc:
assert exc.status_code == 400
assert "Invalid source module" in exc.detail
else:
raise AssertionError("Expected HTTPException for invalid source module")

View File

@@ -0,0 +1,50 @@
import importlib.util
from pathlib import Path
ROOT = Path(__file__).resolve().parents[3]
SERVICE_PATH = ROOT / 'backend' / 'services' / 'content_asset_service.py'
MODELS_PATH = ROOT / 'backend' / 'models' / 'content_asset_models.py'
models_spec = importlib.util.spec_from_file_location('content_asset_models', MODELS_PATH)
models = importlib.util.module_from_spec(models_spec)
models_spec.loader.exec_module(models)
AssetSource = models.AssetSource
service_spec = importlib.util.spec_from_file_location('content_asset_service', SERVICE_PATH)
service_module = importlib.util.module_from_spec(service_spec)
service_spec.loader.exec_module(service_module)
ContentAssetService = service_module.ContentAssetService
class DummyQuery:
def __init__(self):
self.filters = []
def filter(self, expr):
self.filters.append(expr)
return self
def count(self): return 0
def order_by(self, *_args, **_kwargs): return self
def limit(self, *_args, **_kwargs): return self
def offset(self, *_args, **_kwargs): return self
def all(self): return []
class DummyDB:
def __init__(self): self.query_obj = DummyQuery()
def query(self, *_args, **_kwargs): return self.query_obj
def test_get_user_assets_accepts_multiple_source_modules_filter():
db = DummyDB()
service = ContentAssetService(db)
assets, total = service.get_user_assets(
user_id="user-1",
source_modules=[AssetSource.BLOG_WRITER, AssetSource.YOUTUBE],
)
assert assets == []
assert total == 0
assert len(db.query_obj.filters) >= 2

View File

@@ -0,0 +1,35 @@
import { renderHook, waitFor } from '@testing-library/react';
import { useContentAssets } from '../useContentAssets';
const getTokenMock = jest.fn();
jest.mock('@clerk/clerk-react', () => ({
useAuth: () => ({ getToken: getTokenMock }),
}));
describe('useContentAssets', () => {
beforeEach(() => {
getTokenMock.mockResolvedValue('test-token');
global.fetch = jest.fn().mockResolvedValue({
ok: true,
json: async () => ({ assets: [], total: 0, limit: 100, offset: 0 }),
} as Response);
});
afterEach(() => {
jest.clearAllMocks();
});
it('sends all source_module values as repeated query params', async () => {
renderHook(() =>
useContentAssets({ source_module: ['blog_writer', 'youtube'], limit: 50, offset: 0 })
);
await waitFor(() => expect(global.fetch).toHaveBeenCalled());
const calledUrl = (global.fetch as jest.Mock).mock.calls[0][0] as string;
const params = new URL(calledUrl).searchParams;
expect(params.getAll('source_module')).toEqual(['blog_writer', 'youtube']);
});
});

View File

@@ -29,7 +29,7 @@ export interface ContentAsset {
export interface AssetFilters {
asset_type?: 'text' | 'image' | 'video' | 'audio';
source_module?: string | string[]; // Support single or multiple source modules
source_module?: string | string[]; // Supports single or multiple source modules
search?: string;
tags?: string[];
favorites_only?: boolean;
@@ -146,8 +146,10 @@ export const useContentAssets = (filters: AssetFilters = {}) => {
if (currentFilters.source_module) {
// Handle both string and array cases
if (Array.isArray(currentFilters.source_module)) {
// For arrays, use the first value (backend doesn't support multiple yet)
params.append('source_module', currentFilters.source_module[0]);
// Send every selected source module as repeated query params
currentFilters.source_module.forEach((module) => {
params.append('source_module', module);
});
} else {
params.append('source_module', currentFilters.source_module);
}