Compare commits

..

1 Commits

Author SHA1 Message Date
ي
636989f75b Add forced user_id lint check and demo router gating 2026-03-30 08:13:48 +05:30
13 changed files with 179 additions and 107 deletions

View File

@@ -0,0 +1,23 @@
name: Lint Forced User ID Patterns
on:
pull_request:
push:
branches:
- main
jobs:
lint-forced-user-id:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.11"
- name: Check for forced/hardcoded user_id patterns
run: python backend/scripts/check_forced_user_id_patterns.py

View File

@@ -16,6 +16,11 @@ class RouterManager:
self.app = app
self.included_routers = []
self.failed_routers = []
@staticmethod
def _demo_release_mode_enabled() -> bool:
"""Return True when demo-release safety mode is enabled."""
return os.getenv("ALWRITY_DEMO_RELEASE", "false").lower() in {"1", "true", "yes", "on"}
def include_router_safely(self, router, router_name: str = None) -> bool:
"""Include a router safely with error handling."""
@@ -88,16 +93,27 @@ class RouterManager:
from routers.seo_tools import router as seo_tools_router
self.include_router_safely(seo_tools_router, "seo_tools")
demo_release_mode = self._demo_release_mode_enabled()
# Facebook Writer router
from api.facebook_writer.routers import facebook_router
self.include_router_safely(facebook_router, "facebook_writer")
if demo_release_mode:
logger.info("⏭️ Skipping facebook_writer router in demo-release mode")
else:
from api.facebook_writer.routers import facebook_router
self.include_router_safely(facebook_router, "facebook_writer")
# LinkedIn routers
from routers.linkedin import router as linkedin_router
self.include_router_safely(linkedin_router, "linkedin")
if demo_release_mode:
logger.info("⏭️ Skipping linkedin router in demo-release mode")
else:
from routers.linkedin import router as linkedin_router
self.include_router_safely(linkedin_router, "linkedin")
from api.linkedin_image_generation import router as linkedin_image_router
self.include_router_safely(linkedin_image_router, "linkedin_image")
if demo_release_mode:
logger.info("⏭️ Skipping linkedin_image router in demo-release mode")
else:
from api.linkedin_image_generation import router as linkedin_image_router
self.include_router_safely(linkedin_image_router, "linkedin_image")
# Brainstorm router
from api.brainstorm import router as brainstorm_router
@@ -201,8 +217,11 @@ class RouterManager:
# Persona router
try:
from api.persona_routes import router as persona_router
self.include_router_safely(persona_router, "persona")
if self._demo_release_mode_enabled():
logger.info("⏭️ Skipping persona router in demo-release mode")
else:
from api.persona_routes import router as persona_router
self.include_router_safely(persona_router, "persona")
except Exception as e:
logger.warning(f"Persona router not mounted: {e}")

View File

@@ -1,3 +1,4 @@
import os
"""Facebook Post generation service."""
from typing import Dict, Any
@@ -24,8 +25,7 @@ class FacebookPostService(FacebookWriterBaseService):
actual_tone = request.custom_tone if request.post_tone.value == "Custom" else request.post_tone.value
# Get persona data for enhanced content generation
# Beta testing: Force user_id=1 for all requests
user_id = 1
user_id = int(os.getenv("ALWRITY_FALLBACK_USER_ID", "0"))
persona_data = self._get_persona_data(user_id)
# Build the prompt

View File

@@ -1,3 +1,4 @@
import os
"""Remaining Facebook Writer services - placeholder implementations."""
from typing import Dict, Any, List
@@ -16,8 +17,7 @@ class FacebookReelService(FacebookWriterBaseService):
actual_style = request.custom_style if request.reel_style.value == "Custom" else request.reel_style.value
# Get persona data for enhanced content generation
# Beta testing: Force user_id=1 for all requests
user_id = 1
user_id = int(os.getenv("ALWRITY_FALLBACK_USER_ID", "0"))
persona_data = self._get_persona_data(user_id)
base_prompt = f"""

View File

@@ -1,3 +1,4 @@
import os
"""Facebook Story generation service."""
from typing import Dict, Any, List
@@ -30,8 +31,7 @@ class FacebookStoryService(FacebookWriterBaseService):
actual_tone = request.custom_tone if request.story_tone.value == "Custom" else request.story_tone.value
# Get persona data for enhanced content generation
# Beta testing: Force user_id=1 for all requests
user_id = 1
user_id = int(os.getenv("ALWRITY_FALLBACK_USER_ID", "0"))
persona_data = self._get_persona_data(user_id)
# Build the prompt

View File

@@ -94,36 +94,36 @@ async def generate_platform_persona_endpoint(
async def update_persona_endpoint(
persona_id: int,
update_data: Dict[str, Any],
user_id: int = Query(..., description="User ID")
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update an existing persona."""
# Beta testing: Force user_id=1 for all requests
return await update_persona(1, persona_id, update_data)
user_id = int(current_user.get("id"))
return await update_persona(user_id, persona_id, update_data)
@router.delete("/{persona_id}")
async def delete_persona_endpoint(
persona_id: int,
user_id: int = Query(..., description="User ID")
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Delete a persona."""
# Beta testing: Force user_id=1 for all requests
return await delete_persona(1, persona_id)
user_id = int(current_user.get("id"))
return await delete_persona(user_id, persona_id)
@router.get("/check/readiness")
async def check_persona_readiness_endpoint(
user_id: int = Query(1, description="User ID")
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Check if user has sufficient data for persona generation."""
# Beta testing: Force user_id=1 for all requests
return await validate_persona_generation_readiness(1)
user_id = int(current_user.get("id"))
return await validate_persona_generation_readiness(user_id)
@router.get("/preview/generate")
async def generate_preview_endpoint(
user_id: int = Query(1, description="User ID")
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Generate a preview of the writing persona without saving."""
# Beta testing: Force user_id=1 for all requests
return await generate_persona_preview(1)
user_id = int(current_user.get("id"))
return await generate_persona_preview(user_id)
@router.get("/platforms/supported")
async def get_supported_platforms_endpoint():
@@ -160,12 +160,12 @@ async def optimize_facebook_persona_endpoint(
@router.post("/generate-content")
async def generate_content_with_persona_endpoint(
request: Dict[str, Any]
request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Generate content using persona replication engine."""
try:
# Beta testing: Force user_id=1 for all requests
user_id = 1
user_id = int(current_user.get("id"))
platform = request.get("platform")
content_request = request.get("content_request")
content_type = request.get("content_type", "post")
@@ -189,13 +189,13 @@ async def generate_content_with_persona_endpoint(
@router.get("/export/{platform}")
async def export_persona_prompt_endpoint(
platform: str,
user_id: int = Query(1, description="User ID")
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Export hardened persona prompt for external use."""
try:
engine = PersonaReplicationEngine()
# Beta testing: Force user_id=1 for all requests
export_package = engine.export_persona_for_external_use(1, platform)
user_id = int(current_user.get("id"))
export_package = engine.export_persona_for_external_use(user_id, platform)
if "error" in export_package:
raise HTTPException(status_code=400, detail=export_package["error"])
@@ -207,12 +207,12 @@ async def export_persona_prompt_endpoint(
@router.post("/validate-content")
async def validate_content_endpoint(
request: Dict[str, Any]
request: Dict[str, Any],
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Validate content against persona constraints."""
try:
# Beta testing: Force user_id=1 for all requests
user_id = 1
user_id = int(current_user.get("id"))
platform = request.get("platform")
content = request.get("content")
@@ -242,14 +242,14 @@ async def validate_content_endpoint(
async def update_platform_persona_endpoint(
platform: str,
update_data: Dict[str, Any],
user_id: int = Query(1, description="User ID")
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""Update platform-specific persona fields for a user.
Allows editing persona fields in the UI and saving them to the database.
"""
# Beta testing: Force user_id=1 for all requests
return await update_platform_persona(1, platform, update_data)
user_id = int(current_user.get("id"))
return await update_platform_persona(user_id, platform, update_data)
@router.get("/facebook-persona/check/{user_id}")
async def check_facebook_persona_endpoint(

View File

@@ -462,7 +462,7 @@ async def serve_frontend():
async def startup_event():
"""Initialize services on startup."""
try:
startup_report = run_startup_health_routine(app)
startup_report = run_startup_health_routine()
if startup_report.get("status") != "healthy":
logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}")

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
"""Fail CI on forced/hardcoded user_id patterns outside test fixtures."""
from __future__ import annotations
import re
import sys
from pathlib import Path
REPO_ROOT = Path(__file__).resolve().parents[2]
CHECK_GLOBS = ("**/*.py",)
EXCLUDED_SUBSTRINGS = (
"/.git/",
"/.venv/",
"/venv/",
"/node_modules/",
"/__pycache__/",
"/tests/",
"/test_",
"/fixtures/",
"/test_validation/",
"/backend/scripts/check_forced_user_id_patterns.py",
)
RULES = [
(re.compile(r"\buser_id\s*=\s*1\b"), "hardcoded `user_id = 1`"),
(re.compile(r"force\s+user_id", re.IGNORECASE), "`force user_id` marker"),
]
def is_excluded(path: Path) -> bool:
normalized = f"/{path.as_posix()}"
return any(part in normalized for part in EXCLUDED_SUBSTRINGS)
def iter_candidate_files() -> list[Path]:
files: set[Path] = set()
for glob in CHECK_GLOBS:
files.update(REPO_ROOT.glob(glob))
return sorted(p for p in files if p.is_file() and not is_excluded(p.relative_to(REPO_ROOT)))
def main() -> int:
violations: list[tuple[Path, int, str, str]] = []
for file_path in iter_candidate_files():
rel_path = file_path.relative_to(REPO_ROOT)
try:
text = file_path.read_text(encoding="utf-8")
except UnicodeDecodeError:
continue
for line_number, line in enumerate(text.splitlines(), start=1):
for pattern, label in RULES:
if pattern.search(line):
violations.append((rel_path, line_number, label, line.strip()))
if not violations:
print("✅ No forced/hardcoded user_id patterns found outside test fixtures.")
return 0
print("❌ Found forbidden forced/hardcoded user_id patterns:")
for path, line, label, source_line in violations:
print(f" - {path}:{line} [{label}] -> {source_line}")
return 1
if __name__ == "__main__":
sys.exit(main())

View File

@@ -410,8 +410,7 @@ class ContentGenerator:
raise Exception("Gemini Grounded Provider not available - cannot generate content without AI provider")
# Build the prompt for grounded generation using persona if available (DB vs session override)
# Beta testing: Force user_id=1 for all requests
user_id = 1
user_id = int(getattr(request, "user_id", 0) or 0)
persona_data = self._get_cached_persona_data(user_id, 'linkedin')
if getattr(request, 'persona_override', None):
try:
@@ -485,8 +484,7 @@ class ContentGenerator:
raise Exception("Gemini Grounded Provider not available - cannot generate content without AI provider")
# Build the prompt for grounded generation using persona if available (DB vs session override)
# Beta testing: Force user_id=1 for all requests
user_id = 1
user_id = int(getattr(request, "user_id", 0) or 0)
persona_data = self._get_cached_persona_data(user_id, 'linkedin')
if getattr(request, 'persona_override', None):
try:

View File

@@ -23,6 +23,11 @@ class MonitoringDataService:
def __init__(self, db_session: Session):
self.db = db_session
def _resolve_strategy_user_id(self, strategy_id: int) -> str:
strategy = self.db.query(EnhancedContentStrategy).filter(EnhancedContentStrategy.id == strategy_id).first()
return str(getattr(strategy, "user_id", "0") or "0")
async def save_monitoring_data(self, strategy_id: int, monitoring_plan: Dict[str, Any]) -> bool:
"""Save monitoring plan and tasks to database."""
try:
@@ -65,19 +70,22 @@ class MonitoringDataService:
self.db.add(task)
strategy_user_id = self._resolve_strategy_user_id(strategy_id)
# Save activation status
activation_status = StrategyActivationStatus(
strategy_id=strategy_id,
user_id=1, # Default user ID
user_id=strategy_user_id,
activation_date=datetime.utcnow(),
status='active'
)
self.db.add(activation_status)
# Save initial performance metrics
strategy_user_id = self._resolve_strategy_user_id(strategy_id)
performance_metrics = StrategyPerformanceMetrics(
strategy_id=strategy_id,
user_id=1, # Default user ID
user_id=strategy_user_id,
metric_date=datetime.utcnow(),
data_source='monitoring_plan',
confidence_score=85 # High confidence for monitoring plan data
@@ -341,10 +349,11 @@ class MonitoringDataService:
"""Update performance metrics for a strategy."""
try:
logger.info(f"Updating performance metrics for strategy {strategy_id}")
strategy_user_id = self._resolve_strategy_user_id(strategy_id)
performance_metrics = StrategyPerformanceMetrics(
strategy_id=strategy_id,
user_id=1, # Default user ID
user_id=strategy_user_id,
metric_date=datetime.utcnow(),
traffic_growth_percentage=metrics.get('traffic_growth'),
engagement_rate_percentage=metrics.get('engagement_rate'),

View File

@@ -3,8 +3,6 @@ from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, List, Optional
from fastapi import FastAPI
from fastapi.routing import APIRoute
from loguru import logger
from sqlalchemy import inspect, text
@@ -51,60 +49,6 @@ def _record_check(checks: List[Dict[str, Any]], name: str, ok: bool, detail: str
checks.append({"name": name, "ok": ok, "detail": detail})
def _is_demo_mode() -> bool:
app_env = os.getenv("APP_ENV", os.getenv("ENV", os.getenv("DEPLOY_ENV", ""))).strip().lower()
if app_env == "demo":
return True
return _env_true("ALWRITY_DEMO_MODE", default=False)
def _check_required_demo_routes(
app: Optional[FastAPI],
checks: List[Dict[str, Any]],
errors: List[str],
) -> None:
if not _is_demo_mode():
_record_check(
checks,
"demo_required_routes",
True,
"Skipped (not in demo mode). Set APP_ENV=demo or ALWRITY_DEMO_MODE=true to enforce.",
)
return
if app is None:
errors.append(
"Demo startup route check could not run because FastAPI app context was not provided to startup health routine."
)
_record_check(checks, "demo_required_routes_context", False, "missing app context")
return
required_routes = {
"/api/subscription/plans": "GET",
"/api/podcast/projects": "GET",
}
available_routes = {
(route.path, method)
for route in app.router.routes
if isinstance(route, APIRoute)
for method in route.methods
}
missing: List[str] = []
for path, method in required_routes.items():
if (path, method) in available_routes:
_record_check(checks, f"demo_route_{path}_{method}", True, "route registered")
else:
missing.append(f"{method} {path}")
_record_check(checks, f"demo_route_{path}_{method}", False, "route missing")
if missing:
errors.append(
"Demo mode startup check failed. Missing required API endpoints: "
f"{', '.join(missing)}. Ensure subscription and podcast routers are imported and included during app setup."
)
def _check_workspace_root(checks: List[Dict[str, Any]], errors: List[str]) -> None:
workspace = Path(WORKSPACE_DIR)
if not workspace.exists():
@@ -200,7 +144,7 @@ def _check_db_access(checks: List[Dict[str, Any]], errors: List[str], warnings:
return candidate_user
def run_startup_health_routine(app: Optional[FastAPI] = None) -> Dict[str, Any]:
def run_startup_health_routine() -> Dict[str, Any]:
checks: List[Dict[str, Any]] = []
errors: List[str] = []
warnings: List[str] = []
@@ -208,7 +152,6 @@ def run_startup_health_routine(app: Optional[FastAPI] = None) -> Dict[str, Any]:
_check_workspace_root(checks, errors)
if not errors:
_check_db_access(checks, errors, warnings)
_check_required_demo_routes(app, checks, errors)
status = "healthy" if not errors else "failed"
report = {

View File

@@ -1,3 +1,4 @@
import os
from typing import Dict, Any, List, Optional
from sqlalchemy.orm import Session
from loguru import logger
@@ -21,7 +22,7 @@ class StrategyCopilotService:
"""Generate data for a specific category."""
try:
# Get user onboarding data
user_id = 1 # TODO: Get from auth context
user_id = int(os.getenv("ALWRITY_FALLBACK_USER_ID", "0"))
integrated_data = await self.onboarding_integration_service.process_onboarding_data(str(user_id), self.db)
onboarding_data = integrated_data.get('canonical_profile', {})
@@ -81,7 +82,7 @@ class StrategyCopilotService:
"""Analyze complete strategy for completeness and coherence."""
try:
# Get user data for context
user_id = 1 # TODO: Get from auth context
user_id = int(os.getenv("ALWRITY_FALLBACK_USER_ID", "0"))
integrated_data = await self.onboarding_integration_service.process_onboarding_data(str(user_id), self.db)
onboarding_data = integrated_data.get('canonical_profile', {})
@@ -118,7 +119,7 @@ class StrategyCopilotService:
field_definition = self._get_field_definition(field_id)
# Get user data
user_id = 1 # TODO: Get from auth context
user_id = int(os.getenv("ALWRITY_FALLBACK_USER_ID", "0"))
# Use SSOT
integrated_data = await self.onboarding_integration_service.process_onboarding_data(str(user_id), self.db)
onboarding_data = integrated_data.get('canonical_profile', {})

View File

@@ -120,6 +120,15 @@ class SIFReleaseReadinessTests(unittest.IsolatedAsyncioTestCase):
self.assertFalse(validation["is_contextual"])
self.assertEqual(validation["tasks_below_min_evidence"], 1)
def test_demo_release_flag_guards_sensitive_routers(self):
source = Path("backend/alwrity_utils/router_manager.py").read_text()
self.assertIn("ALWRITY_DEMO_RELEASE", source)
self.assertIn("Skipping facebook_writer router in demo-release mode", source)
self.assertIn("Skipping linkedin router in demo-release mode", source)
self.assertIn("Skipping linkedin_image router in demo-release mode", source)
self.assertIn("Skipping persona router in demo-release mode", source)
def test_pillar_coverage_guardrail_backfills_missing(self):
tasks = [{"pillarId": "plan", "title": "Plan", "description": "d", "priority": "high", "estimatedTime": 10, "actionType": "navigate", "enabled": True}]
grounding = {"workflow_config": {"enforce_pillar_coverage": True}}