Compare commits
162 Commits
dependabot
...
codex/add-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
87925c8fdc | ||
|
|
928c2f20aa | ||
|
|
7385100017 | ||
|
|
93a1985d9f | ||
|
|
4fdc7d3ea0 | ||
|
|
85d6cc1d20 | ||
|
|
0d20dcb801 | ||
|
|
463cfdc5cf | ||
|
|
19a5af9682 | ||
|
|
ca725b77e7 | ||
|
|
bc311cfdf6 | ||
|
|
6c740ee63f | ||
|
|
05e84d6089 | ||
|
|
f46465cd97 | ||
|
|
ebdd1edfa0 | ||
|
|
45bd1eada9 | ||
|
|
ef7b3d2b49 | ||
|
|
98cfb03cf7 | ||
|
|
993000a540 | ||
|
|
b3e2f4382c | ||
|
|
638e785ad4 | ||
|
|
98a1cc91a2 | ||
|
|
ab827e9ab9 | ||
|
|
8ee042bd2c | ||
|
|
4df1adfbe2 | ||
|
|
3f984e8d0c | ||
|
|
a7d2ef1c09 | ||
|
|
fc47445181 | ||
|
|
d518365c87 | ||
|
|
ba94ee30bc | ||
|
|
8b79099b15 | ||
|
|
fbbfe81ed7 | ||
|
|
d7319c981e | ||
|
|
3c4965462a | ||
|
|
26ccb2f609 | ||
|
|
cbd68fa43f | ||
|
|
641143a7d6 | ||
|
|
dd7f8515a4 | ||
|
|
5e205d52cd | ||
|
|
b9f2123ce9 | ||
|
|
00f46ecbed | ||
|
|
973dd501fe | ||
|
|
efff72f4bd | ||
|
|
913e59a0a8 | ||
|
|
02d13716f3 | ||
|
|
c5d625945f | ||
|
|
6e9c11744c | ||
|
|
b1ca29f7f7 | ||
|
|
91b2f996fd | ||
|
|
7637babd7d | ||
|
|
1deed48484 | ||
|
|
afdbc78779 | ||
|
|
294c64877d | ||
|
|
4a4b8c5a24 | ||
|
|
625dd550d3 | ||
|
|
7f7279f903 | ||
|
|
e68c289901 | ||
|
|
f748c081c2 | ||
|
|
cf70261658 | ||
|
|
7241874545 | ||
|
|
35ebf8c077 | ||
|
|
7aead3ae7d | ||
|
|
80cdd7ff29 | ||
|
|
a9dd9afba1 | ||
|
|
eaea1ee793 | ||
|
|
6db378beff | ||
|
|
7c2a185a29 | ||
|
|
17c046c51e | ||
|
|
ba9ddbf368 | ||
|
|
bfa1b028b3 | ||
|
|
0cac25751f | ||
|
|
a486f4c4fa | ||
|
|
34f82c43dd | ||
|
|
95edd7d470 | ||
|
|
280159669b | ||
|
|
5f13ee5f7b | ||
|
|
e71cf65802 | ||
|
|
196ea65af9 | ||
|
|
bcf62017aa | ||
|
|
0732887c09 | ||
|
|
e704aa7d87 | ||
|
|
79f26c815b | ||
|
|
e2726805f3 | ||
|
|
ff61708e29 | ||
|
|
63767d72b3 | ||
|
|
d85a1ee561 | ||
|
|
18bed36e2b | ||
|
|
24d932d2b5 | ||
|
|
cd53680523 | ||
|
|
edf3f32b3c | ||
|
|
e59c77b221 | ||
|
|
1a456b21b7 | ||
|
|
813f9acc34 | ||
|
|
60b6b0904b | ||
|
|
80838ed028 | ||
|
|
e66311ea44 | ||
|
|
cf2d3a51e8 | ||
|
|
8dd1c13f85 | ||
|
|
ad97dc0d3b | ||
|
|
45231625fd | ||
|
|
23bf709c10 | ||
|
|
3f1d5cbb09 | ||
|
|
12960a22ea | ||
|
|
45d2b0b693 | ||
|
|
348839be36 | ||
|
|
b5ab46a749 | ||
|
|
d12fe6348e | ||
|
|
0e3a611e57 | ||
|
|
b24d39349d | ||
|
|
0d0d964605 | ||
|
|
03d43fb54b | ||
|
|
c361bd127d | ||
|
|
6ac880e61e | ||
|
|
92a27270aa | ||
|
|
cc03567d2f | ||
|
|
3c79073a10 | ||
|
|
71c0e2ed46 | ||
|
|
11663b0142 | ||
|
|
4ca58084fd | ||
|
|
6c99b26140 | ||
|
|
13e25cec3b | ||
|
|
724832c688 | ||
|
|
917be873df | ||
|
|
429689bdcb | ||
|
|
6cf5d0396d | ||
|
|
27147d50a5 | ||
|
|
2b025673d6 | ||
|
|
3f3575cc18 | ||
|
|
c0a5f5fdeb | ||
|
|
1f139e3167 | ||
|
|
1bdf0d4b93 | ||
|
|
f1e8cdb0d8 | ||
|
|
0680bf98a2 | ||
|
|
cc2443cf5b | ||
|
|
6cef24289f | ||
|
|
f6795100ac | ||
|
|
aa2317c359 | ||
|
|
bba56a1940 | ||
|
|
0f34048c6a | ||
|
|
1cf3ae96ce | ||
|
|
a697b869ab | ||
|
|
9e3867ca61 | ||
|
|
b567a32136 | ||
|
|
88deabb9fc | ||
|
|
f30f6c5346 | ||
|
|
2ab4471632 | ||
|
|
a43c229809 | ||
|
|
0e8953b538 | ||
|
|
6579f60d7d | ||
|
|
08f08a1a52 | ||
|
|
ab78a6a158 | ||
|
|
22c31e6c77 | ||
|
|
249a1962d4 | ||
|
|
dcb7d28e03 | ||
|
|
26e1f08ebb | ||
|
|
fcf00cd20d | ||
|
|
b8ffda1cbb | ||
|
|
6d5ae8d2fa | ||
|
|
c5e2fc3514 | ||
|
|
a3e4f5231a | ||
|
|
a8c80c5b75 | ||
|
|
027638dfb9 |
14
.gitignore
vendored
@@ -4,15 +4,27 @@ __pycache__/
|
||||
*.db
|
||||
*.sqlite*
|
||||
|
||||
nul
|
||||
LICENSE
|
||||
CHANGELOG.md
|
||||
|
||||
.planning
|
||||
.planning/
|
||||
|
||||
|
||||
.trae/
|
||||
.trae
|
||||
|
||||
workspace/
|
||||
workspace/*
|
||||
|
||||
.windsurf
|
||||
artifacts
|
||||
|
||||
.opencode
|
||||
|
||||
data/
|
||||
data/*
|
||||
|
||||
.trae/
|
||||
/backend/database/migrations/*
|
||||
@@ -21,7 +33,7 @@ backend/*.db
|
||||
backend\youtube_audio
|
||||
youtube_avatars
|
||||
backend\youtube_images
|
||||
|
||||
data/media/podcast_videos/AI_Videos
|
||||
backend/.trae_*
|
||||
|
||||
# Onboarding progress files
|
||||
|
||||
88
.planning/ROADMAP.md
Normal file
@@ -0,0 +1,88 @@
|
||||
# Roadmap: ALwrity Frontend Optimization
|
||||
|
||||
## Overview
|
||||
|
||||
Optimize the frontend build to reduce build time from 5 minutes to under 30 seconds and shrink bundle size from 8.42MB to under 1MB. First, implement code splitting with React.lazy and feature-gated loading using ALWRITY_ENABLED_FEATURES. Then migrate from Create React App to Vite for faster builds. Finally, optimize dependencies for maximum performance.
|
||||
|
||||
## Phases
|
||||
|
||||
**Phase Numbering:**
|
||||
- Integer phases (1, 2, 3): Planned work
|
||||
- All phases planned and ready for execution
|
||||
|
||||
---
|
||||
|
||||
### Phase 1: Code Splitting & Feature-Based Lazy Loading ✅ Complete
|
||||
**Goal**: Replace all static imports with React.lazy dynamic imports and add feature-gated loading using ALWRITY_ENABLED_FEATURES. Also convert MUI icon barrel imports to individual imports (moved here from Phase 3 for Vite readiness).
|
||||
**Depends on**: Nothing (first phase)
|
||||
**Requirements**: VITE-04 (code splitting), VITE-06 (dependency optimization)
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. ✅ All 31+ route components loaded via React.lazy (not static imports)
|
||||
2. ✅ Initial bundle size reduced from 8.42MB to 2.50MB (70% reduction)
|
||||
3. ✅ Disabled features (via ALWRITY_ENABLED_FEATURES) don't load their bundles
|
||||
4. ✅ All existing routes still work correctly
|
||||
5. ✅ No build warnings or errors with CRA
|
||||
6. ✅ All MUI icon imports changed from barrel to individual (111 files)
|
||||
|
||||
**Plans**: 3 plans (all complete)
|
||||
|
||||
Plans:
|
||||
- [x] 01-01: Convert 31 static imports to React.lazy with Suspense
|
||||
- [x] 01-02: Add feature-gated route loading using ALWRITY_ENABLED_FEATURES
|
||||
- [x] 01-03: Convert MUI icon barrel imports to individual imports (111 files)
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Migrate from CRA to Vite (Next)
|
||||
**Goal**: Migrate frontend from Create React App to Vite for fast builds
|
||||
**Depends on**: Phase 1 ✅
|
||||
**Requirements**: VITE-01, VITE-02, VITE-03
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. `npm run dev` starts Vite dev server with HMR
|
||||
2. `npm run build` completes in under 30 seconds (down from 5 minutes)
|
||||
3. All environment variables work with `VITE_*` prefix
|
||||
4. TypeScript compiles without errors
|
||||
5. Material UI theme renders correctly
|
||||
|
||||
**Plans**: 3 plans
|
||||
|
||||
Plans:
|
||||
- [ ] 02-01: Install Vite dependencies and create configuration
|
||||
- [ ] 02-02: Migrate index.html and entry point
|
||||
- [ ] 02-03: Update environment variables and scripts
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: Dependency Cleanup & Production Validation
|
||||
**Goal**: Remove unused dependencies and deploy Vite build to production
|
||||
**Depends on**: Phase 2
|
||||
**Requirements**: VITE-07, VITE-08, VITE-09
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. Unused dependencies identified and removed
|
||||
2. Production build serves correctly (preview mode)
|
||||
3. All features tested and working (Clerk auth, Stripe, CopilotKit)
|
||||
4. Vercel deployment config updated for Vite
|
||||
5. Build time consistently under 30 seconds
|
||||
6. Total bundle size under 2MB
|
||||
|
||||
**Plans**: 2 plans (consolidated from former Phase 3 & 4)
|
||||
|
||||
Plans:
|
||||
- [ ] 03-01: Audit and remove unused dependencies, update Vercel config
|
||||
- [ ] 03-02: Full feature testing and performance validation
|
||||
|
||||
---
|
||||
|
||||
## Execution Order
|
||||
|
||||
Phases execute in numeric order: 1 → 2 → 3
|
||||
|
||||
**Key insight:** Phase 1 (code splitting) works with CRA, so bundle size drops immediately. Phase 2 (Vite) then adds faster builds on top of the already-split bundles. Phase 3 covers cleanup and deployment.
|
||||
|
||||
## Progress
|
||||
|
||||
| Phase | Plans Complete | Status | Completed |
|
||||
|-------|----------------|--------|-----------|
|
||||
| 1. Code Splitting & MUI Optimization | 3/3 | ✅ Complete | 2026-05-08 |
|
||||
| 2. Migrate CRA to Vite | 0/3 | ⏳ Ready | - |
|
||||
| 3. Cleanup & Production | 0/2 | ⏳ Planned | - |
|
||||
73
.planning/STATE.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# Project State: Alwrity
|
||||
|
||||
## Current Position
|
||||
|
||||
**Active Phase:** Phase 1 - Code Splitting & Feature-Based Lazy Loading
|
||||
**Phase Status:** ✅ Complete — Ready for Phase 2
|
||||
**Milestone:** v1.0 - Frontend Optimization
|
||||
|
||||
## Phase Progress
|
||||
|
||||
### Phase 1: Code Splitting & Feature-Based Lazy Loading
|
||||
- **Status:** ✅ Complete
|
||||
- **Plans:** 3 plans executed (01-01, 01-02, 01-03)
|
||||
|
||||
**Plans:**
|
||||
- [x] 01-01: Convert 31 static imports to React.lazy with Suspense
|
||||
- [x] 01-02: Add feature-gated route loading using ALWRITY_ENABLED_FEATURES
|
||||
- [x] 01-03: Convert MUI icon barrel imports to individual imports (111 files)
|
||||
|
||||
**Results:**
|
||||
- Main bundle: 8.42MB → 2.50MB (70% reduction via React.lazy)
|
||||
- 190+ chunk files for route-level code splitting
|
||||
- 47 routes feature-gated with ALWRITY_ENABLED_FEATURES
|
||||
- 16 feature keys in FEATURE_KEYS constant
|
||||
- 111 files converted from barrel to individual MUI icon imports
|
||||
- Zero barrel imports from @mui/icons-material remain
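
The 70% figure can be checked from the two bundle sizes reported above. The following is a minimal sketch; the helper name `percent_reduction` is illustrative and not part of the project.

```python
def percent_reduction(before_mb: float, after_mb: float) -> float:
    """Return the size reduction as a percentage of the original size."""
    if before_mb <= 0:
        raise ValueError("before_mb must be positive")
    return (before_mb - after_mb) / before_mb * 100


# Main bundle sizes reported for Phase 1, in MB.
reduction = percent_reduction(8.42, 2.50)
print(f"{reduction:.1f}% reduction")  # prints "70.3% reduction"
```

The same helper can be reused to track progress against the later bundle-size targets once Phase 2 and Phase 3 produce new measurements.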
|
||||
|
||||
### Phase 2: Migrate CRA to Vite
|
||||
- **Status:** Ready to start (Phase 1 complete)
|
||||
- **Plans:** 3 plans created (02-01, 02-02, 02-03)
|
||||
- **Dependencies:** Phase 1 complete
|
||||
|
||||
**Plans:**
|
||||
- [ ] 02-01: Install Vite dependencies and create configuration
|
||||
- [ ] 02-02: Migrate index.html and entry point
|
||||
- [ ] 02-03: Update environment variables and scripts
|
||||
|
||||
### Phase 3: Production Validation (Planned)
|
||||
- Depends on: Phase 2
|
||||
- Focus: Vercel deploy, full feature testing
|
||||
|
||||
### Phase 4: (Removed — MUI icon optimization folded into Phase 1-03)
|
||||
|
||||
## Decisions Made
|
||||
|
||||
### Locked Decisions
|
||||
- **Code splitting first**, then Vite migration (not the other way around) ✅ Done
|
||||
- Use React.lazy for ALL route components (this is a React feature, NOT bundler-specific) ✅ Done
|
||||
- Use ALWRITY_ENABLED_FEATURES for feature-gated route loading ✅ Done
|
||||
- **MUI icon imports before Vite migration** — barrel imports converted to individual per-file default imports ✅ Done
|
||||
- Use Vite 5.x with @vitejs/plugin-react
|
||||
- Disable sourcemaps in production build for speed
|
||||
- Migrate env vars from `REACT_APP_*` to `VITE_*`
|
||||
|
||||
### Patterns Established
|
||||
- **MUI icon imports**: Always `import IconName from '@mui/icons-material/IconName'` — never barrel destructuring
|
||||
- **Route splitting**: All route components use React.lazy with Suspense
|
||||
- **Feature gating**: FeatureRoute wraps inside ProtectedRoute (auth → then feature check)
|
||||
|
||||
## Key Insight
|
||||
|
||||
**React.lazy is a React feature (not CRA or Vite specific).** Doing code splitting first with CRA:
|
||||
1. Immediately reduces main bundle from 8.42MB → 2.50MB (measured in Phase 1)
|
||||
2. Adds no risk (React.lazy is stable since React 16.6)
|
||||
3. Makes Vite migration smoother (bundles are already split)
|
||||
4. ALWRITY_ENABLED_FEATURES can prevent disabled feature bundles from loading at all
|
||||
|
||||
**MUI icon barrel imports eliminated** — 111 files converted to individual per-file imports. This ensures reliable tree-shaking during Vite migration and beyond.
|
||||
|
||||
---
|
||||
|
||||
*Last updated: 2026-05-08*
|
||||
*Updated by: gsd-executor*
|
||||
129
.planning/phases/01-code-splitting/01-03-SUMMARY.md
Normal file
@@ -0,0 +1,129 @@
|
||||
---
|
||||
phase: 01-code-splitting
|
||||
plan: 03
|
||||
type: execute
|
||||
subsystem: frontend
|
||||
tags: [performance, MUI, icons, tree-shaking, barrel-imports]
|
||||
requires:
|
||||
- phase: 01-code-splitting-02
|
||||
provides: feature gating structure for route protection
|
||||
provides:
|
||||
- All MUI icon imports converted from barrel (destructured) to individual per-file default imports
|
||||
- Zero barrel imports from @mui/icons-material remain in the codebase
|
||||
affects: [02-vite-migration, build performance]
|
||||
tech-stack:
|
||||
added: []
|
||||
patterns: [individual MUI icon imports, per-file default imports for tree-shaking]
|
||||
key-files:
|
||||
created: []
|
||||
modified:
|
||||
- frontend/src/components/shared/ErrorBoundary.tsx
|
||||
- frontend/src/components/SubscriptionGuard.tsx
|
||||
- frontend/src/components/SubscriptionExpiredModal.tsx
|
||||
- frontend/src/pages/SchedulerDashboard.tsx
|
||||
- frontend/src/pages/BillingPage.tsx
|
||||
- +106 additional frontend component files
|
||||
key-decisions:
|
||||
- "All MUI icon barrel imports converted BEFORE Vite migration to eliminate Webpack 4 tree-shaking uncertainty"
|
||||
- "Used per-file default imports (import X from '@mui/icons-material/X') instead of destructured barrel imports"
|
||||
- "Aliased icons (e.g., ErrorOutline as ErrorIcon) converted to named default imports matching the alias (import ErrorIcon from '@mui/icons-material/ErrorOutline')"
|
||||
- "JSX variable names preserved — only import statements changed"
|
||||
patterns-established:
|
||||
- "MUI icon imports: always use import X from '@mui/icons-material/X' pattern, never import { X } from '@mui/icons-material'"
|
||||
duration: 45min
|
||||
completed: 2026-05-08
|
||||
---
|
||||
|
||||
# Phase 1 Plan 01-03: MUI Icon Import Optimization Summary
|
||||
|
||||
**Converted all 300+ MUI icon barrel imports to individual per-file default imports across 111 frontend files — eliminating Webpack 4 tree-shaking uncertainty before Vite migration**
|
||||
|
||||
## Performance
|
||||
|
||||
- **Duration:** ~35 min
|
||||
- **Completed:** 2026-05-08
|
||||
- **Tasks:** 10 commits across 111 files
|
||||
- **Files modified:** 111
|
||||
|
||||
## Accomplishments
|
||||
|
||||
- Converted **all barrel** `import { X } from '@mui/icons-material'` to individual `import X from '@mui/icons-material/X'` — **zero barrel imports remaining**
|
||||
- Modified **111 files** across every area: PodcastMaker, YouTubeCreator, OnboardingWizard, billing, SEO, shared components, and more
|
||||
- Handled aliased imports (`IconName as Alias`) correctly — JSX variable names preserved unchanged
|
||||
- Build verified — `npm run build:nomap` succeeds with zero new errors
|
||||
- Enables reliable tree-shaking during Phase 2 (Vite migration) — each file imports only the icons it uses
|
||||
|
||||
## Task Commits
|
||||
|
||||
Each batch was committed atomically:
|
||||
|
||||
1. **ErrorBoundary** (`components/shared/`) - `46781a0` — 5 icons
|
||||
2. **SubscriptionGuard** - `bda75cb` — 2 icons
|
||||
3. **SubscriptionExpiredModal** - `80f76b1` — 3 icons
|
||||
4. **SchedulerDashboard** - `7ffd972` — 7 icons
|
||||
5. **BillingPage** - `a76671c` — 1 icon
|
||||
6. **Billing, Blog, ContentPlanning, ErrorBoundary, Pricing, Alerts** - `a009cbb` — 8 files, 36 insertions
|
||||
7. **ImageStudio, Landing, LinkedIn, MainDashboard, OnboardingWizard** - `205e098` — 14 files, 65 insertions
|
||||
8. **PodcastMaker AnalysisPanel** - `25ce5b9` — 18 files, 58 insertions
|
||||
9. **PodcastMaker, ProductMarketing, Research, Scheduler, SEO, Shared** - `986a7e5` — 44 files, 149 insertions
|
||||
10. **StoryWriter, YouTubeCreator** - `6361255` — 22 files, 67 insertions
|
||||
|
||||
## Files Modified
|
||||
|
||||
**111 files total** across the frontend source tree:
|
||||
|
||||
- `components/billing/` — 2 files (ComprehensiveAPIBreakdown, CostOptimizationRecommendations)
|
||||
- `components/BlogWriter/` — 1 file (BlogWriterPhasesSection)
|
||||
- `components/ContentPlanningDashboard/` — 2 files (CardExpansionWrapper, StrategyErrorBoundary)
|
||||
- `components/ErrorBoundary.tsx` — 1 file (3 icons)
|
||||
- `components/ImageStudio/` — 2 files (AssetFilters, CreateStudioCostAlerts)
|
||||
- `components/Landing/` — 2 files (EnterpriseCTA, FeatureShowcase)
|
||||
- `components/LinkedInWriter/` — 1 file (FactCheckResults)
|
||||
- `components/MainDashboard/` — 1 file (MainDashboard)
|
||||
- `components/OnboardingWizard/` — 7 files (incl. VoiceAvatarPlaceholder with 22 icons)
|
||||
- `components/PodcastMaker/` — 40 files (AnalysisPanel, CreateStep, ScriptEditor, etc.)
|
||||
- `components/Pricing/` — 1 file (PricingPage)
|
||||
- `components/ProductMarketing/` — 5 files (CampaignWizard, ProductPhotoshootStudio, etc.)
|
||||
- `components/Research/` — 2 files (PersonalizationIndicator, ResearchInputContainer)
|
||||
- `components/SchedulerDashboard/` — 1 file (SchedulerCharts)
|
||||
- `components/SEODashboard/` — 3 files (AIInsightsPanel, HealthScore, MetricCard)
|
||||
- `components/shared/` — 12 files (ErrorBoundary, AlertsBadge, ProtectedRoute, etc.)
|
||||
- `components/StoryWriter/` — 3 files (AIStorySetupModal, FormFieldWithTooltip, SelectFieldWithTooltip)
|
||||
- `components/SubscriptionGuard.tsx` — 1 file
|
||||
- `components/SubscriptionExpiredModal.tsx` — 1 file
|
||||
- `components/YouTubeCreator/` — 19 files (SceneCard, RenderStep, PlanStep, etc.)
|
||||
- `pages/` — 2 files (BillingPage, ResearchDashboard/PresetsCard)
|
||||
|
||||
## Decisions Made
|
||||
|
||||
- **Convert all barrel imports now, before Vite migration** — CRA's Webpack 4 cannot reliably tree-shake barrel imports. Converting before the bundler swap reduces migration risk and ensures Vite's native ESM tree-shaking works optimally.
|
||||
- **Per-file default import pattern** — Every icon gets its own import line: `import IconName from '@mui/icons-material/IconName'`. This is the most predictable pattern and works identically in both Webpack and Vite.
|
||||
- **Alias handling** — For icons imported as `{ X as Y }`, the alias `Y` becomes the import name: `import Y from '@mui/icons-material/X'`. JSX usage unchanged.
|
||||
- **Multiple import lines preserved** — Files with separate barrel imports from `@mui/icons-material` were converted to multiple individual import blocks, preserving the original organizational structure.
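
The import rewrite and alias rule described above can be sketched as a small codemod. This is an illustration only: the summary does not say which tooling performed the conversion, and the names `BARREL_IMPORT` and `convert_barrel_import` are hypothetical. The sketch handles single-line barrel imports; multi-line import blocks would need a real parser.

```python
import re
from typing import List

# Matches a single-line barrel import such as:
#   import { Add, ErrorOutline as ErrorIcon } from '@mui/icons-material';
BARREL_IMPORT = re.compile(
    r"""^import\s*\{(?P<names>[^}]*)\}\s*from\s*['"]@mui/icons-material['"];?\s*$"""
)


def convert_barrel_import(line: str) -> List[str]:
    """Rewrite one barrel import line into per-icon default imports.

    Lines that are not barrel imports are returned unchanged.
    """
    match = BARREL_IMPORT.match(line.strip())
    if match is None:
        return [line]

    converted = []
    for spec in match.group("names").split(","):
        spec = spec.strip()
        if not spec:
            continue  # tolerate a trailing comma
        # "X as Y" keeps Y as the local name, so JSX usage is untouched.
        icon, _, alias = spec.partition(" as ")
        local_name = alias.strip() or icon.strip()
        converted.append(
            f"import {local_name} from '@mui/icons-material/{icon.strip()}';"
        )
    return converted


for out in convert_barrel_import(
    "import { Add, ErrorOutline as ErrorIcon } from '@mui/icons-material';"
):
    print(out)
```

Keeping the alias as the local binding is what lets the JSX stay untouched: only the import statement changes.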
|
||||
|
||||
## Deviations from Plan
|
||||
|
||||
None - this was ad-hoc work not covered by an existing PLAN.md.
|
||||
|
||||
## Issues Encountered
|
||||
|
||||
- **Task agent timeout**: First attempt at parallel conversion agents failed silently for batches 1-2 (73 files). Re-launched with explicit edit instructions - succeeded on second attempt.
|
||||
- **No naming conflicts found**: Despite converting 300+ icon imports across 111 files, no variable naming collisions occurred. Each icon only appears once per file.
|
||||
|
||||
## Build Verification
|
||||
|
||||
- `npm run build:nomap` — **PASSED** with zero errors
|
||||
- Only pre-existing CRA bundle size warning remains (expected — Vite migration will resolve it in Phase 2)
|
||||
- No new build warnings introduced
|
||||
|
||||
## Next Phase Readiness
|
||||
|
||||
- Frontend is ready for **Phase 2: Vite Migration**
|
||||
- All MUI icon imports use individual default imports, so tree-shaking will work correctly with Vite's Rollup-based production build
|
||||
- User should perform manual testing of Podcast Maker with `REACT_APP_ENABLED_FEATURES=podcast` before Vite migration begins
|
||||
- After manual verification, proceed with [Phase 2-01: Install Vite dependencies and create configuration]
|
||||
|
||||
---
|
||||
|
||||
*Phase: 01-code-splitting*
|
||||
*Completed: 2026-05-08*
|
||||
1
Procfile
Normal file
@@ -0,0 +1 @@
|
||||
web: cd backend && python start_alwrity_backend.py --production
|
||||
@@ -1,43 +0,0 @@
|
||||
{
|
||||
"preflight": {
|
||||
"success": true,
|
||||
"can_proceed": true,
|
||||
"estimated_cost": 0.3
|
||||
},
|
||||
"operations": {
|
||||
"analysis_title_suggestions": [
|
||||
"AI Agents in 2026",
|
||||
"Ship Faster with AI",
|
||||
"Startup AI Playbook"
|
||||
],
|
||||
"research_provider": "exa",
|
||||
"research_cost": 0.015,
|
||||
"video_task_status": "completed"
|
||||
},
|
||||
"dashboard_deltas": {
|
||||
"total_calls_before": 1,
|
||||
"total_calls_after": 5,
|
||||
"delta_calls": 4,
|
||||
"total_cost_before": 0.09,
|
||||
"total_cost_after": 0.488,
|
||||
"delta_cost": 0.398,
|
||||
"projected_monthly_cost_before": 0.09,
|
||||
"projected_monthly_cost_after": 0.49,
|
||||
"delta_projected_monthly_cost": 0.4
|
||||
},
|
||||
"provider_cost_deltas": {
|
||||
"exa": 0.005,
|
||||
"huggingface": 0.003,
|
||||
"wavespeed": 0.39
|
||||
},
|
||||
"acceptance": {
|
||||
"passed": true,
|
||||
"criteria": {
|
||||
"preflight_success": true,
|
||||
"usage_cost_incremented": true,
|
||||
"usage_call_incremented": true,
|
||||
"projection_incremented": true,
|
||||
"provider_delta_present": true
|
||||
}
|
||||
}
|
||||
}
|
||||
2
backend/Procfile
Normal file
@@ -0,0 +1,2 @@
|
||||
# Use start_alwrity_backend.py for deployment
|
||||
web: python start_alwrity_backend.py --production
|
||||
157
backend/add_method.py
Normal file
@@ -0,0 +1,157 @@
|
||||
#!/usr/bin/env python
|
||||
# Add _get_all_historical_usage method to usage_tracking_service.py
|
||||
|
||||
with open('services/subscription/usage_tracking_service.py', 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find where to insert (before get_usage_trends)
|
||||
insert_idx = None
|
||||
for i, line in enumerate(lines):
|
||||
if ' def get_usage_trends(' in line:
|
||||
insert_idx = i
|
||||
break
|
||||
|
||||
if insert_idx is None:
|
||||
print("Error: Could not find insertion point")
|
||||
raise SystemExit(1)
|
||||
|
||||
print(f"Inserting at line {insert_idx + 1}")
|
||||
|
||||
# Method to insert
|
||||
new_method = ''' def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
|
||||
# Get all usage summaries for the user
|
||||
all_summaries = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
if not all_summaries:
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': 'active',
|
||||
'total_calls': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0,
|
||||
'avg_response_time': 0.0,
|
||||
'error_rate': 0.0,
|
||||
'limits': self.pricing_service.get_user_limits(user_id),
|
||||
'provider_breakdown': {},
|
||||
'usage_percentages': {},
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Aggregate all data from UsageSummary
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
# Calculate weighted average response time
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
# Calculate overall error rate
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Map database columns to frontend keys
|
||||
provider_mapping = {
|
||||
'gemini_calls': 'gemini',
|
||||
'openai_calls': 'openai',
|
||||
'anthropic_calls': 'anthropic',
|
||||
'mistral_calls': 'huggingface',
|
||||
'wavespeed_calls': 'wavespeed',
|
||||
'exa_calls': 'exa',
|
||||
'video_calls': 'video',
|
||||
'image_edit_calls': 'image_edit',
|
||||
'audio_calls': 'audio',
|
||||
}
|
||||
|
||||
# Build provider_breakdown for frontend
|
||||
provider_breakdown = {}
|
||||
for db_col, frontend_key in provider_mapping.items():
|
||||
total_provider_calls = sum(getattr(s, db_col, 0) or 0 for s in all_summaries)
|
||||
provider_breakdown[frontend_key] = {
|
||||
'calls': total_provider_calls,
|
||||
'cost': 0,
|
||||
'tokens': 0
|
||||
}
|
||||
|
||||
# Calculate usage_percentages based on limits
|
||||
usage_percentages = {}
|
||||
if limits and limits.get('limits'):
|
||||
# Gemini calls percentage
|
||||
gemini_calls = provider_breakdown.get('gemini', {}).get('calls', 0)
|
||||
gemini_limit = limits.get('limits', {}).get('gemini_calls', 0) or 0
|
||||
if gemini_limit > 0:
|
||||
usage_percentages['gemini_calls'] = (gemini_calls / gemini_limit) * 100
|
||||
|
||||
# HuggingFace calls percentage (from mistral_calls)
|
||||
huggingface_calls = provider_breakdown.get('huggingface', {}).get('calls', 0)
|
||||
huggingface_limit = limits.get('limits', {}).get('mistral_calls', 0) or 0
|
||||
if huggingface_limit > 0:
|
||||
usage_percentages['huggingface_calls'] = (huggingface_calls / huggingface_limit) * 100
|
||||
|
||||
# Cost percentage
|
||||
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
usage_percentages['cost'] = (total_cost / cost_limit) * 100
|
||||
|
||||
# Build historical breakdown
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except AttributeError:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
'billing_period': s.billing_period,
|
||||
'total_calls': s.total_calls or 0,
|
||||
'total_tokens': s.total_tokens or 0,
|
||||
'total_cost': float(s.total_cost or 0),
|
||||
'usage_status': status_val,
|
||||
'updated_at': s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
# Determine overall status
|
||||
usage_status = 'active'
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status = s.usage_status.value
|
||||
except AttributeError:
|
||||
status = str(s.usage_status)
|
||||
if status == 'limit_reached':
|
||||
usage_status = 'limit_reached'
|
||||
break
|
||||
elif status == 'warning' and usage_status != 'limit_reached':
|
||||
usage_status = 'warning'
|
||||
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': usage_status,
|
||||
'total_calls': total_calls,
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': round(total_cost, 2),
|
||||
'avg_response_time': round(avg_response_time, 2),
|
||||
'error_rate': round(error_rate, 2),
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': historical_breakdown,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
'''
|
||||
|
||||
# Insert the new method
|
||||
new_lines = lines[:insert_idx] + [new_method] + lines[insert_idx:]
|
||||
|
||||
# Write back
|
||||
with open('services/subscription/usage_tracking_service.py', 'w', encoding='utf-8') as f:
|
||||
f.writelines(new_lines)
|
||||
|
||||
print("Successfully added _get_all_historical_usage method")
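
The aggregation inside `_get_all_historical_usage` weights each period's average response time by its call count, and converts each period's error percentage back into an error count before summing. The standalone sketch below reproduces only that arithmetic. `PeriodSummary` and `aggregate` are hypothetical stand-ins for the `UsageSummary` rows and the method body, and the sample numbers are invented for illustration.

```python
from dataclasses import dataclass
from typing import Iterable, Tuple


@dataclass
class PeriodSummary:
    """Hypothetical stand-in for one UsageSummary row."""
    total_calls: int
    avg_response_time: float  # per-period average
    error_rate: float         # percent of calls that failed, 0-100


def aggregate(summaries: Iterable[PeriodSummary]) -> Tuple[float, float]:
    """Return (call-weighted avg response time, overall error rate in percent)."""
    rows = list(summaries)
    total_calls = sum(r.total_calls for r in rows)
    if total_calls == 0:
        return 0.0, 0.0
    weighted_time = sum(r.avg_response_time * r.total_calls for r in rows)
    # Convert each period's percentage back to an error count before summing.
    total_errors = sum(r.total_calls * r.error_rate / 100 for r in rows)
    return weighted_time / total_calls, total_errors / total_calls * 100


periods = [
    PeriodSummary(total_calls=100, avg_response_time=2.0, error_rate=10.0),
    PeriodSummary(total_calls=300, avg_response_time=4.0, error_rate=2.0),
]
avg_time, error_rate = aggregate(periods)
```

With these sample rows, a plain mean of the two averages would give 3.0, while the call-weighted result is 3.5. Likewise, averaging the two error rates would give 6.0, whereas 16 errors over 400 calls is 4.0 percent.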
|
||||
@@ -3,6 +3,11 @@ ALwrity Utilities Package
|
||||
Modular utilities for ALwrity backend startup and configuration.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
# Check feature mode early to skip heavy imports
|
||||
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
|
||||
|
||||
from .dependency_manager import DependencyManager
|
||||
from .environment_setup import EnvironmentSetup
|
||||
from .database_setup import DatabaseSetup
|
||||
@@ -11,7 +16,6 @@ from .health_checker import HealthChecker
|
||||
from .rate_limiter import RateLimiter
|
||||
from .frontend_serving import FrontendServing
|
||||
from .router_manager import RouterManager
|
||||
from .onboarding_manager import OnboardingManager
|
||||
from .feature_runtime import (
|
||||
get_active_profiles,
|
||||
get_enabled_groups,
|
||||
@@ -21,6 +25,12 @@ from .feature_runtime import (
|
||||
is_enabled,
|
||||
)
|
||||
|
||||
# Lazy load OnboardingManager - it triggers heavy imports (aiohttp, etc.)
|
||||
if _is_full_mode:
|
||||
from .onboarding_manager import OnboardingManager
|
||||
else:
|
||||
OnboardingManager = None
|
||||
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
|
||||
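The hunk above imports `OnboardingManager` only in full mode and otherwise binds the name to `None`, so any caller has to cope with the `None` case. The sketch below mirrors the mode check and adds a guard. `is_full_mode` and `require_onboarding_manager` are hypothetical helpers, not functions from the package.

```python
import os
from typing import Mapping, Optional


def is_full_mode(env: Optional[Mapping[str, str]] = None) -> bool:
    """Mirror the check in the hunk: unset, empty, or "all" means full mode."""
    env = os.environ if env is None else env
    return env.get("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")


def require_onboarding_manager(manager_cls: Optional[type]) -> type:
    """Hypothetical guard for callers: fail loudly when the class was skipped."""
    if manager_cls is None:
        raise RuntimeError(
            "OnboardingManager is unavailable when ALWRITY_ENABLED_FEATURES "
            "selects a subset of features"
        )
    return manager_cls
```

A guard like this turns a confusing `'NoneType' object is not callable` error into a message that names the environment variable responsible.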
@@ -55,22 +55,28 @@ class EnvironmentSetup:
|
||||
print("🔧 Setting up environment variables...")
|
||||
|
||||
# Production environment variables
|
||||
# IMPORTANT: Don't override PORT if already set by Render cloud
|
||||
render_port = os.getenv("PORT")
|
||||
|
||||
if self.production_mode:
|
||||
env_vars = {
|
||||
"HOST": "0.0.0.0",
|
||||
"PORT": "8000",
|
||||
"RELOAD": "false",
|
||||
"LOG_LEVEL": "INFO",
|
||||
"DEBUG": "false"
|
||||
}
|
||||
# Only set PORT if not already provided by cloud (Render sets PORT)
|
||||
if not render_port:
|
||||
env_vars["PORT"] = "8000"
|
||||
else:
|
||||
env_vars = {
|
||||
"HOST": "0.0.0.0",
|
||||
"PORT": "8000",
|
||||
"RELOAD": "true",
|
||||
"LOG_LEVEL": "DEBUG",
|
||||
"DEBUG": "true"
|
||||
}
|
||||
if not render_port:
|
||||
env_vars["PORT"] = "8000"
|
||||
|
||||
for key, value in env_vars.items():
|
||||
os.environ.setdefault(key, value)
|
||||
|
||||
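The intent of the hunk above is that a `PORT` already assigned by the hosting platform is never overwritten. Because the values are applied with `os.environ.setdefault`, an existing `PORT` survives even if the defaults dictionary contains one. The sketch below demonstrates that behaviour on a plain mapping. `apply_defaults` is a hypothetical helper, and the port `10000` is an invented example value.

```python
from typing import Dict, MutableMapping


def apply_defaults(env: MutableMapping[str, str], production: bool) -> Dict[str, str]:
    """Fill in server defaults without overriding values already present."""
    defaults = {
        "HOST": "0.0.0.0",
        "PORT": "8000",
        "RELOAD": "false" if production else "true",
        "LOG_LEVEL": "INFO" if production else "DEBUG",
        "DEBUG": "false" if production else "true",
    }
    for key, value in defaults.items():
        # setdefault leaves an existing value (e.g. a platform-assigned PORT) alone.
        env.setdefault(key, value)
    return dict(env)


# Simulate a platform that has already assigned a port.
cloud_env = apply_defaults({"PORT": "10000"}, production=True)
# Simulate a local run with nothing preset.
local_env = apply_defaults({}, production=False)
```

Here `cloud_env["PORT"]` stays `"10000"` while `local_env["PORT"]` falls back to `"8000"`.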
@@ -51,6 +51,13 @@ FEATURE_GROUPS: Dict[str, FeatureGroup] = {
|
||||
"api.content_planning.strategy_copilot:router",
|
||||
),
|
||||
),
|
||||
"blog_writer": FeatureGroup(
|
||||
features=("blog_writer",),
|
||||
routers=(
|
||||
"api.blog_writer.router:router",
|
||||
"api.blog_writer.seo_analysis:router",
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -59,5 +66,6 @@ PROFILE_GROUP_MAP: Dict[str, Tuple[str, ...]] = {
|
||||
"core": ("core",),
|
||||
"podcast": ("core", "podcast"),
|
||||
"youtube": ("core", "youtube"),
|
||||
"blog_writer": ("core", "blog_writer"),
|
||||
"planning": ("core", "content_planning"),
|
||||
}
|
||||
|
||||
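The two mappings in the hunks above suggest a two-step lookup: a profile name selects feature groups, and each group contributes routers. The sketch below shows one way that resolution could work. The `FeatureGroup` dataclass is a simplified stand-in, the `"core"` router path is a placeholder, and `routers_for_profiles` is a hypothetical function; the real resolution logic is not shown in this diff.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, List, Tuple


@dataclass(frozen=True)
class FeatureGroup:
    """Simplified stand-in with the two fields visible in the hunk."""
    features: Tuple[str, ...]
    routers: Tuple[str, ...]


FEATURE_GROUPS: Dict[str, FeatureGroup] = {
    # The "core" entry and its router path are placeholders for illustration.
    "core": FeatureGroup(features=("core",), routers=("api.example_core:router",)),
    "blog_writer": FeatureGroup(
        features=("blog_writer",),
        routers=(
            "api.blog_writer.router:router",
            "api.blog_writer.seo_analysis:router",
        ),
    ),
}

PROFILE_GROUP_MAP: Dict[str, Tuple[str, ...]] = {
    "core": ("core",),
    "blog_writer": ("core", "blog_writer"),
}


def routers_for_profiles(profiles: Iterable[str]) -> List[str]:
    """Collect router paths for the given profiles, de-duplicated, in order."""
    seen: Dict[str, None] = {}  # dict preserves insertion order
    for profile in profiles:
        for group in PROFILE_GROUP_MAP.get(profile, ()):
            for router in FEATURE_GROUPS[group].routers:
                seen.setdefault(router, None)
    return list(seen)
```

Because every profile includes `"core"`, de-duplication matters as soon as two profiles are enabled together. Unknown profile names resolve to no routers in this sketch, which is a design choice rather than confirmed project behaviour.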
@@ -39,9 +39,10 @@ class ProductionOptimizer:
|
||||
def _set_production_env_vars(self) -> None:
|
||||
"""Set production-specific environment variables."""
|
||||
production_vars = {
|
||||
# Note: PORT is NOT set here - it's provided by the deployment platform (e.g., Render)
|
||||
# Don't override PORT as it must come from the environment
|
||||
# Note: HOST is not set here - it's auto-detected by start_backend()
|
||||
# Based on deployment environment (cloud vs local)
|
||||
'PORT': '8000',
|
||||
'RELOAD': 'false',
|
||||
'LOG_LEVEL': 'INFO',
|
||||
'DEBUG': 'false',
|
||||
|
||||
@@ -14,9 +14,9 @@ from loguru import logger

CORE_ROUTER_REGISTRY = [
    {"name": "component_logic", "module": "api.component_logic", "attr": "router", "features": {"all", "core"}},
    {"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog-writer", "youtube"}},
    {"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog_writer", "youtube"}},
    {"name": "step3_research", "module": "api.onboarding_utils.step3_routes", "attr": "router", "features": {"all", "core"}},
    {"name": "step4_assets", "module": "api.onboarding_utils.step4_asset_routes", "attr": "router", "features": {"all", "core"}},
    {"name": "step4_assets", "module": "api.onboarding_utils.step4_asset_routes", "attr": "router", "features": {"all", "core", "podcast"}},
    {"name": "step4_persona", "module": "api.onboarding_utils.step4_persona_routes_optimized", "attr": "router", "features": {"all", "core"}},
    {"name": "gsc_auth", "module": "routers.gsc_auth", "attr": "router", "features": {"all", "core", "seo"}},
    {"name": "wordpress_oauth", "module": "routers.wordpress_oauth", "attr": "router", "features": {"all", "core"}},
@@ -29,31 +29,31 @@ CORE_ROUTER_REGISTRY = [
    {"name": "linkedin_image", "module": "api.linkedin_image_generation", "attr": "router", "features": {"all", "core", "linkedin"}},
    {"name": "brainstorm", "module": "api.brainstorm", "attr": "router", "features": {"all", "core"}},
    {"name": "hallucination_detector", "module": "api.hallucination_detector", "attr": "router", "features": {"all", "core"}},
    {"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core"}},
    {"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content-planning"}},
    {"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core"}},
    {"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core"}},
    {"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content-planning"}},
    {"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core"}},
    {"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core"}},
    {"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core", "blog_writer"}},
    {"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content_planning"}},
    {"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core", "blog_writer"}},
    {"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core", "blog_writer"}},
    {"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content_planning"}},
    {"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core", "blog_writer"}},
    {"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core", "blog_writer"}},
    {"name": "platform_analytics", "module": "routers.platform_analytics", "attr": "router", "features": {"all", "core"}},
    {"name": "bing_insights", "module": "routers.bing_insights", "attr": "router", "features": {"all", "core", "seo"}},
    {"name": "background_jobs", "module": "routers.background_jobs", "attr": "router", "features": {"all", "core"}},
]

OPTIONAL_ROUTER_REGISTRY = [
    {"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog-writer"}},
    {"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story-writer"}},
    {"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog_writer"}},
    {"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story_writer"}},
    {"name": "wix", "module": "api.wix_routes", "attr": "router", "features": {"all"}},
    {"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog-writer"}},
    {"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog_writer"}},
    {"name": "persona", "module": "api.persona_routes", "attr": "router", "features": {"all", "persona"}},
    {"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video-studio"}},
    {"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image-studio"}},
    {"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image-studio"}},
    {"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image-studio"}},
    {"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image-studio"}},
    {"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image-studio"}},
    {"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product-marketing"}},
    {"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video_studio"}},
    {"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image_studio"}},
    {"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image_studio"}},
    {"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image_studio"}},
    {"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image_studio"}},
    {"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image_studio"}},
    {"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product_marketing"}},
    {"name": "campaign_creator", "module": "routers.campaign_creator", "attr": "router", "features": {"all"}},
    {"name": "content_assets", "module": "api.content_assets.router", "attr": "router", "features": {"all"}},
    {"name": "podcast", "module": "api.podcast.router", "attr": "router", "features": {"all", "podcast"}},
@@ -116,10 +116,6 @@ class RouterManager:
        if "all" in enabled_features:
            return True

        # Skip core routers in podcast-only mode (they require non-podcast features)
        if enabled_features == {"podcast"}:
            return False

        # If no required features specified, include by default
        if not required_features:
            return True

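The registry entries and the `RouterManager` check above gate each router on a set of feature tags, and this change renames the hyphenated tags (`blog-writer`) to underscore form (`blog_writer`). The following is a minimal sketch of how such gating could work. The helper names `parse_enabled_features` and `should_include_router` are hypothetical, and the final set-intersection rule is an assumption, since the hunk does not show the rest of the method.

```python
from typing import Iterable, Set


def parse_enabled_features(raw: str) -> Set[str]:
    """Turn a comma-separated feature string into a normalized set.

    Assumption: an empty value means every feature is enabled.
    """
    # Normalize case and hyphens so "Blog-Writer" matches "blog_writer".
    names = {part.strip().lower().replace("-", "_") for part in raw.split(",")}
    names.discard("")
    return names or {"all"}


def should_include_router(required: Iterable[str], enabled: Set[str]) -> bool:
    """Decide whether a router with the given feature tags should be mounted."""
    if "all" in enabled:
        return True
    required_set = set(required)
    if not required_set:
        # No tags on the router: include it by default, as in the hunk above.
        return True
    # Assumption: any overlap between router tags and enabled features suffices.
    return bool(required_set & enabled)


if __name__ == "__main__":
    enabled = parse_enabled_features("podcast")
    print(should_include_router({"all", "core", "podcast"}, enabled))  # True
    print(should_include_router({"all", "core"}, enabled))  # False
```

Under this reading, removing the podcast-only early return is safe only because the core routers that podcast mode needs (for example `subscription` and `step4_assets`) now carry the `podcast` tag themselves.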
@@ -5,50 +5,59 @@ The onboarding endpoints are re-exported from a stable module
`onboarding.py`.
"""

from .onboarding_endpoints import (
    health_check,
    get_onboarding_status,
    get_onboarding_progress_full,
    get_step_data,
    complete_step,
    skip_step,
    validate_step_access,
    get_api_keys,
    save_api_key,
    validate_api_keys,
    start_onboarding,
    complete_onboarding,
    reset_onboarding,
    get_resume_info,
    get_onboarding_config,
    generate_writing_personas,
    generate_writing_personas_async,
    get_persona_task_status,
    assess_persona_quality,
    regenerate_persona,
    get_persona_generation_options
)
import os

__all__ = [
    'health_check',
    'get_onboarding_status',
    'get_onboarding_progress_full',
    'get_step_data',
    'complete_step',
    'skip_step',
    'validate_step_access',
    'get_api_keys',
    'save_api_key',
    'validate_api_keys',
    'start_onboarding',
    'complete_onboarding',
    'reset_onboarding',
    'get_resume_info',
    'get_onboarding_config',
    'generate_writing_personas',
    'generate_writing_personas_async',
    'get_persona_task_status',
    'assess_persona_quality',
    'regenerate_persona',
    'get_persona_generation_options'
]
# In feature-only modes, don't import heavy onboarding endpoints
# They trigger heavy dependencies (exa_py, etc.)
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")

if not _is_full_mode:
    __all__ = []
else:
    from .onboarding_endpoints import (
        health_check,
        get_onboarding_status,
        get_onboarding_progress_full,
        get_step_data,
        complete_step,
        skip_step,
        validate_step_access,
        get_api_keys,
        save_api_key,
        validate_api_keys,
        start_onboarding,
        complete_onboarding,
        reset_onboarding,
        get_resume_info,
        get_onboarding_config,
        generate_writing_personas,
        generate_writing_personas_async,
        get_persona_task_status,
        assess_persona_quality,
        regenerate_persona,
        get_persona_generation_options
    )

    __all__ = [
        'health_check',
        'get_onboarding_status',
        'get_onboarding_progress_full',
        'get_step_data',
        'complete_step',
        'skip_step',
        'validate_step_access',
        'get_api_keys',
        'save_api_key',
        'validate_api_keys',
        'start_onboarding',
        'complete_onboarding',
        'reset_onboarding',
        'get_resume_info',
        'get_onboarding_config',
        'generate_writing_personas',
        'generate_writing_personas_async',
        'get_persona_task_status',
        'assess_persona_quality',
        'regenerate_persona',
        'get_persona_generation_options'
    ]
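The `_is_full_mode` comparison above decides whether the heavy onboarding endpoints are imported at all. The sketch below isolates that comparison so its behaviour on different values is easy to see. The function names are hypothetical; only the environment variable name and the comparison itself come from the hunk.

```python
import os


def is_full_mode(raw_value: str) -> bool:
    """Mirror the comparison in the hunk: only '' and 'all' count as full mode."""
    return raw_value.strip().lower() in ("", "all")


def onboarding_exports_enabled() -> bool:
    """Read the same environment variable the hunk reads."""
    return is_full_mode(os.getenv("ALWRITY_ENABLED_FEATURES", ""))
```

Because the comparison is an exact match on the whole string, a combined value such as `all,podcast` is not treated as full mode here, so the onboarding re-exports would be skipped for it.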
@@ -1,52 +1,104 @@
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse
"""
Assets Serving Router

Serves user-uploaded assets (avatars, voice samples) from workspace storage.
Uses authenticated or query-token access for security.
Audio MIME types are set correctly based on file extension so browsers
can play voice clone previews without NotSupportedError.
"""

import os
from pathlib import Path
from services.database import WORKSPACE_DIR, get_user_db_path
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import FileResponse
from loguru import logger
from typing import Dict, Any

from middleware.auth_middleware import get_current_user_with_query_token
from api.story_writer.utils.auth import require_authenticated_user
from utils.storage_paths import get_repo_root, sanitize_user_id

router = APIRouter(prefix="/api/assets", tags=["Assets Serving"])

MIME_MAP = {
    ".wav": "audio/wav",
    ".mp3": "audio/mpeg",
    ".ogg": "audio/ogg",
    ".opus": "audio/opus",
    ".webm": "audio/webm",
    ".m4a": "audio/mp4",
    ".aac": "audio/aac",
    ".flac": "audio/flac",
    ".png": "image/png",
    ".jpg": "image/jpeg",
    ".jpeg": "image/jpeg",
    ".gif": "image/gif",
    ".webp": "image/webp",
    ".svg": "image/svg+xml",
}


def _resolve_asset_path(user_id: str, category: str, filename: str) -> Path:
    """Resolve asset path in user workspace with path-traversal protection."""
    safe_user_id = sanitize_user_id(user_id)
    repo_root = get_repo_root()

    file_path = (repo_root / "workspace" / f"workspace_{safe_user_id}" / "assets" / category / filename).resolve()

    workspace_dir = (repo_root / "workspace" / f"workspace_{safe_user_id}").resolve()
    if not str(file_path).startswith(str(workspace_dir)):
        raise HTTPException(status_code=403, detail="Access denied")

    return file_path


def _get_media_type(filename: str) -> str:
    """Determine MIME type from file extension, with fallback."""
    ext = Path(filename).suffix.lower()
    return MIME_MAP.get(ext, "application/octet-stream")


@router.get("/{user_id}/avatars/{filename}")
async def serve_avatar(user_id: str, filename: str):
    """
    Serve avatar images directly.
    Public endpoint relying on unguessable filenames.
    """
    # Sanitize user_id (simple check to prevent directory traversal)
    safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
    if safe_user_id != user_id:
        raise HTTPException(status_code=400, detail="Invalid user ID")

    # Sanitize filename
async def serve_avatar(
    user_id: str,
    filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """Serve avatar images. Supports auth via Authorization header or ?token= query param."""
    require_authenticated_user(current_user)

    safe_filename = os.path.basename(filename)

    # Construct path
    # workspace/workspace_{user_id}/assets/avatars/{filename}
    file_path = Path(WORKSPACE_DIR) / f"workspace_{safe_user_id}" / "assets" / "avatars" / safe_filename

    file_path = _resolve_asset_path(user_id, "avatars", safe_filename)

    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Asset not found")

    return FileResponse(file_path)

    media_type = _get_media_type(safe_filename)
    return FileResponse(file_path, media_type=media_type)


@router.get("/{user_id}/voice_samples/{filename}")
async def serve_voice_sample(user_id: str, filename: str):
async def serve_voice_sample(
    user_id: str,
    filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """Serve voice sample audio files.

    Supports auth via Authorization header or ?token= query param.
    The ?token= param is essential for <audio> elements and new Audio()
    which cannot send Authorization headers.
    """
    Serve voice sample audio files directly.
    """
    # Sanitize user_id
    safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
    if safe_user_id != user_id:
        raise HTTPException(status_code=400, detail="Invalid user ID")

    # Sanitize filename
    require_authenticated_user(current_user)

    safe_filename = os.path.basename(filename)

    # Construct path
    # workspace/workspace_{user_id}/assets/voice_samples/{filename}
    file_path = Path(WORKSPACE_DIR) / f"workspace_{safe_user_id}" / "assets" / "voice_samples" / safe_filename

    file_path = _resolve_asset_path(user_id, "voice_samples", safe_filename)

    if not file_path.exists():
        logger.info(f"[Assets] Voice sample not found: {file_path}")
        raise HTTPException(status_code=404, detail="Asset not found")

    return FileResponse(file_path)

    media_type = _get_media_type(safe_filename)
    file_size = file_path.stat().st_size
    logger.warning(f"[Assets] Serving voice sample: {safe_filename} ({media_type}, {file_size} bytes)")
    return FileResponse(file_path, media_type=media_type)
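`_resolve_asset_path` above guards against traversal by comparing string prefixes of resolved paths. A plain prefix comparison also accepts a sibling directory whose name merely starts with the same characters, for example `workspace_alice2` when the base is `workspace_alice`. In this router the filename is already reduced with `os.path.basename`, so the following is only a defence-in-depth sketch of a stricter containment check built on `Path.relative_to`. The helper name `is_within` is hypothetical and not part of the diff.

```python
from pathlib import Path


def is_within(base: Path, candidate: Path) -> bool:
    """Return True only if candidate resolves to base itself or a path under it."""
    base_resolved = base.resolve()
    candidate_resolved = candidate.resolve()
    try:
        # relative_to compares whole path components, not raw string prefixes.
        candidate_resolved.relative_to(base_resolved)
    except ValueError:
        return False
    return True


if __name__ == "__main__":
    base = Path("/srv/workspace/workspace_alice")
    sibling = Path("/srv/workspace/workspace_alice2/assets/a.png")
    print(str(sibling).startswith(str(base)))  # prefix check accepts the sibling
    print(is_within(base, sibling))            # component check rejects it
```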
@@ -1195,3 +1195,68 @@ async def generate_introductions(
    except Exception as e:
        logger.error(f"Failed to generate introductions: {e}")
        raise HTTPException(status_code=500, detail=str(e))


# ---------------------------
# Save Complete Blog Asset
# ---------------------------


class SaveCompleteBlogAssetRequest(BaseModel):
    title: str
    content: str
    seo_title: Optional[str] = None
    meta_description: Optional[str] = None
    focus_keyword: Optional[str] = None
    tags: List[str] = Field(default_factory=list)
    categories: List[str] = Field(default_factory=list)


@router.post("/save-complete-asset")
async def save_complete_blog_asset(
    request: SaveCompleteBlogAssetRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: Session = Depends(get_db),
) -> Dict[str, Any]:
    """Save the complete blog content as a single asset in the asset library."""
    try:
        if not current_user:
            raise HTTPException(status_code=401, detail="Authentication required")

        user_id = str(current_user.get('id', ''))
        if not user_id:
            raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")

        full_content = f"# {request.title}\n\n{request.content}"

        asset_id = save_and_track_text_content(
            db=db,
            user_id=user_id,
            content=full_content,
            source_module="blog_writer",
            title=f"Published Blog: {request.title[:60]}",
            description=request.meta_description or f"Complete published blog post: {request.title}",
            prompt=f"SEO Title: {request.seo_title or request.title}\nFocus Keyword: {request.focus_keyword or ''}",
            tags=["blog", "published"] + [t for t in (request.tags or []) if t],
            asset_metadata={
                "status": "published",
                "focus_keyword": request.focus_keyword,
                "categories": request.categories,
                "word_count": len(full_content.split()),
            },
            subdirectory="published",
            file_extension=".md"
        )

        if asset_id:
            logger.info(f"✅ Complete blog asset saved to library: ID={asset_id}")
            return {"success": True, "asset_id": asset_id}
        else:
            logger.warning("save_and_track_text_content returned None for published blog")
            return {"success": False, "error": "Failed to save blog asset"}

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to save complete blog asset: {e}")
        raise HTTPException(status_code=500, detail=str(e))

@@ -13,7 +13,7 @@ from typing import Any, Dict, List
from fastapi import HTTPException
from loguru import logger
from sqlalchemy.orm import Session
from services.database import SessionLocal, get_session_for_user
from services.database import get_session_for_user

from models.blog_models import (
    BlogResearchRequest,
@@ -264,7 +264,7 @@ class TaskManager:
                raise ValueError("Global target words exceed 1000; medium generation not allowed")

            # Create a sync session for asset saving
            db_session = SessionLocal()
            db_session = get_session_for_user(user_id)
            try:
                result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
                    request,
@@ -326,6 +326,7 @@ class TaskManager:
            await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
            self.task_storage[task_id]["status"] = "failed"
            self.task_storage[task_id]["error"] = str(e)
            self.task_storage[task_id]["error_data"] = {"error_message": str(e), "error_type": type(e).__name__}


# Global task manager instance

@@ -9,13 +9,27 @@ from fastapi.responses import FileResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from loguru import logger
from .step4_persona_routes import _extract_user_id
from middleware.auth_middleware import get_current_user


def _extract_user_id(user: Dict[str, Any]) -> str:
    """Extract a stable user ID from Clerk-authenticated user payloads.
    Prefers 'clerk_user_id' or 'id', falls back to 'user_id', else 'unknown'.
    """
    if not isinstance(user, dict):
        return 'unknown'
    return (
        user.get('clerk_user_id')
        or user.get('id')
        or user.get('user_id')
        or 'unknown'
    )
import base64
import os
from pathlib import Path
from utils.file_storage import save_file_safely, generate_unique_filename
from services.database import get_db, WORKSPACE_DIR
from services.database import get_db
from utils.storage_paths import get_user_workspace, sanitize_user_id
from utils.asset_tracker import save_asset_to_library
from models.content_asset_models import ContentAsset, AssetType, AssetSource
from sqlalchemy import desc
@@ -73,6 +87,8 @@ async def get_latest_avatar(
    try:
        user_id = _extract_user_id(current_user)

        logger.warning(f"[latest-avatar] Looking for avatar for user_id: {user_id}")

        # Search for assets that are either:
        # 1. Saved with source_module=BRAND_AVATAR_GENERATOR (new)
        # 2. Saved with source_module=STORY_WRITER but have metadata category='brand_avatar' (legacy)
@@ -87,6 +103,8 @@ async def get_latest_avatar(
            ])
        ).order_by(desc(ContentAsset.created_at)).limit(50).all()

        logger.warning(f"[latest-avatar] Found {len(candidates)} candidate(s)")

        asset = None
        for candidate in candidates:
            # Check for direct match (new assets)
@@ -167,7 +185,7 @@ async def generate_avatar(
    try:
        user_id = _extract_user_id(current_user)

        logger.info(f"Generating avatar for user {user_id} with prompt: {request.prompt}")
        logger.warning(f"Generating avatar for user {user_id} with prompt: {request.prompt}")

        # 1. Generate Image
        result = await generate_image_with_provider(
@@ -217,7 +235,7 @@ async def generate_avatar(
        content_to_save = base64.b64decode(image_data) if isinstance(image_data, str) else image_data

        # Construct user assets directory
        user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
        user_assets_dir = get_user_workspace(user_id) / "assets" / "avatars"

        saved_path, error = save_file_safely(
            content_to_save,
@@ -270,7 +288,7 @@ async def enhance_prompt_route(
    """Enhance a simple prompt into a detailed midjourney-style prompt."""
    try:
        user_id = _extract_user_id(current_user)
        logger.info(f"Enhancing prompt for user {user_id}: {request.prompt}")
        logger.warning(f"Enhancing prompt for user {user_id}: {request.prompt}")

        enhanced_prompt = await enhance_image_prompt(request.prompt, user_id=user_id)

@@ -294,7 +312,7 @@ async def create_variation_route(
    """Generate a variation of an existing avatar."""
    try:
        user_id = _extract_user_id(current_user)
        logger.info(f"Creating variation for user {user_id} with prompt: {prompt}")
        logger.warning(f"Creating variation for user {user_id} with prompt: {prompt}")

        # Read file
        file_content = await file.read()
@@ -315,7 +333,7 @@ async def create_variation_route(
        content_to_save = base64.b64decode(image_data)

        # Construct user assets directory
        user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
        user_assets_dir = get_user_workspace(user_id) / "assets" / "avatars"

        saved_path, error = save_file_safely(
            content_to_save,
@@ -369,7 +387,7 @@ async def enhance_avatar_route(
    """Enhance/Upscale an existing avatar."""
    try:
        user_id = _extract_user_id(current_user)
        logger.info(f"Enhancing avatar for user {user_id}")
        logger.warning(f"Enhancing avatar for user {user_id}")

        # Read file
        file_content = await file.read()
@@ -389,7 +407,7 @@ async def enhance_avatar_route(
        content_to_save = base64.b64decode(image_data)

        # Construct user assets directory
        user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
        user_assets_dir = get_user_workspace(user_id) / "assets" / "avatars"

        saved_path, error = save_file_safely(
            content_to_save,
@@ -446,13 +464,13 @@ async def create_voice_clone(
    """Create a voice clone from an audio file."""
    try:
        user_id = _extract_user_id(current_user)
        logger.info(f"Creating voice clone '{voice_name}' (engine={engine}) for user {user_id}")
        logger.warning(f"[VoiceClone] Creating voice clone '{voice_name}' (engine={engine}) for user {user_id}")

        # 1. Save uploaded audio file
        file_content = await file.read()
        filename = generate_unique_filename("voice_sample", Path(file.filename).suffix.lstrip("."))

        user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
        user_voice_dir = get_user_workspace(user_id) / "assets" / "voice_samples"
        saved_path, error = save_file_safely(file_content, user_voice_dir, filename)

        if error or not saved_path:
@@ -474,7 +492,7 @@ async def create_voice_clone(
            random_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
            custom_voice_id = f"vc_{random_suffix}"

            logger.info(f"Cloning voice with Minimax, ID: {custom_voice_id}")
            logger.warning(f"Cloning voice with Minimax, ID: {custom_voice_id}")

            # Run blocking call in executor
            result = await loop.run_in_executor(
@@ -489,7 +507,7 @@ async def create_voice_clone(
            preview_audio_bytes = result.preview_audio_bytes

        elif engine.lower() == "cosyvoice":
            logger.info("Cloning voice with CosyVoice")
            logger.warning("Cloning voice with CosyVoice")
            result = await loop.run_in_executor(
                None,
                lambda: cosyvoice_voice_clone(
@@ -504,7 +522,7 @@ async def create_voice_clone(
            custom_voice_id = f"vc_cosy_{asset_uuid}"

        else: # qwen3 (default)
            logger.info("Cloning voice with Qwen3")
            logger.warning("Cloning voice with Qwen3")
            result = await loop.run_in_executor(
                None,
                lambda: qwen3_voice_clone(
@@ -520,27 +538,48 @@ async def create_voice_clone(

        # 3. Save Preview Audio (if generated)
        preview_url = None
        if preview_audio_bytes:
            preview_filename = f"preview_{filename}"
            # Ensure it ends with .wav
            if not preview_filename.endswith(".wav"):
                preview_filename = str(Path(preview_filename).with_suffix('.wav'))
        preview_mime_type = "audio/wav"
        actual_filename = None # Default if preview save fails

        if preview_audio_bytes and len(preview_audio_bytes) > 0:
            from utils.media_utils import detect_audio_format, ensure_audio_extension

            user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
            detected_fmt, preview_mime_type = detect_audio_format(preview_audio_bytes)
            logger.warning(f"[VoiceClone] Detected preview audio format: {detected_fmt} ({preview_mime_type}), {len(preview_audio_bytes)} bytes")

            # Build filename with correct extension based on actual content format
            original_stem = Path(filename).stem
            preview_filename = f"preview_{original_stem}"
            preview_filename = ensure_audio_extension(preview_filename, preview_audio_bytes)

            user_voice_dir = get_user_workspace(user_id) / "assets" / "voice_samples"
            saved_preview_path, error = save_file_safely(preview_audio_bytes, user_voice_dir, preview_filename)

            if not error and saved_preview_path:
                preview_url = f"/api/assets/{user_id}/voice_samples/{preview_filename}"
                # Use actual saved filename (may have UUID suffix added by save_file_safely)
                actual_filename = saved_preview_path.name
                preview_url = f"/api/assets/{user_id}/voice_samples/{actual_filename}"
                logger.warning(f"[VoiceClone] Saved preview: {actual_filename} ({saved_preview_path.stat().st_size} bytes, {preview_mime_type})")

                # Verify file exists
                if not saved_preview_path.exists():
                    logger.warning(f"[VoiceClone] Preview file does not exist after save: {saved_preview_path}")
                    preview_url = None
            else:
                logger.warning(f"[VoiceClone] Failed to save preview audio: {error}")

        # 4. Save to Asset Library
        # Use the preview file (with corrected .wav extension) as the main asset file
        has_valid_preview = preview_audio_bytes and len(preview_audio_bytes) > 0 and saved_preview_path
        stored_filename = actual_filename if has_valid_preview else filename
        asset_id = save_asset_to_library(
            db=db,
            user_id=user_id,
            file_path=file_path,
            asset_type="audio",
            source_module="voice_cloner",
            filename=filename,
            file_url=f"/api/assets/{user_id}/voice_samples/{filename}",
            filename=stored_filename,
            file_url=f"/api/assets/{user_id}/voice_samples/{stored_filename}",
            asset_metadata={
                "voice_name": voice_name,
                "engine": engine,
@@ -555,7 +594,7 @@ async def create_voice_clone(
        return {
            "success": True,
            "custom_voice_id": custom_voice_id,
            "preview_audio_url": preview_url or f"/api/assets/{user_id}/voice_samples/{filename}",
            "preview_audio_url": preview_url or f"/api/assets/{user_id}/voice_samples/{stored_filename}",
            "asset_id": asset_id,
            "message": "Voice clone created successfully"
        }
@@ -574,7 +613,7 @@ async def create_voice_design(
    """Create a voice from text description (Voice Design)."""
    try:
        user_id = _extract_user_id(current_user)
        logger.info(f"Designing voice for user {user_id}")
        logger.warning(f"Designing voice for user {user_id}")

        loop = asyncio.get_event_loop()

@@ -588,9 +627,15 @@ async def create_voice_design(
            )
        )

        # Save the result to a temporary file
        filename = generate_unique_filename("voice_design_preview", "wav")
        user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
        # Save the result to a file with correct extension based on content
        from utils.media_utils import detect_audio_format, ensure_audio_extension
        detected_fmt, mime_type = detect_audio_format(result.preview_audio_bytes)
        logger.warning(f"[VoiceDesign] Detected audio format: {detected_fmt} ({mime_type})")

        filename = generate_unique_filename("voice_design_preview", detected_fmt)
        filename = ensure_audio_extension(filename, result.preview_audio_bytes)

        user_voice_dir = get_user_workspace(user_id) / "assets" / "voice_samples"
        saved_path, error = save_file_safely(result.preview_audio_bytes, user_voice_dir, filename)

        if error or not saved_path:

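The voice clone and voice design hunks above call `detect_audio_format` from `utils.media_utils` so that the saved preview gets an extension and MIME type matching its real content. That helper's implementation is not part of this diff. The sketch below shows one common way to do content sniffing, by checking leading magic bytes. The function name `sniff_audio_format`, the set of formats covered, and the fallback value are all assumptions.

```python
from typing import Tuple


def sniff_audio_format(data: bytes) -> Tuple[str, str]:
    """Guess (extension, MIME type) from the first bytes of an audio payload."""
    # WAV: RIFF container whose form type at offset 8 is "WAVE".
    if len(data) >= 12 and data[:4] == b"RIFF" and data[8:12] == b"WAVE":
        return "wav", "audio/wav"
    # Ogg container (Vorbis or Opus streams both start with this capture pattern).
    if data[:4] == b"OggS":
        return "ogg", "audio/ogg"
    # Native FLAC stream marker.
    if data[:4] == b"fLaC":
        return "flac", "audio/flac"
    # MP3 with an ID3v2 tag at the start.
    if data[:3] == b"ID3":
        return "mp3", "audio/mpeg"
    # Bare MPEG audio frame: 11 sync bits set. ADTS AAC uses a similar sync
    # word, so this sketch may label such streams as MP3.
    if len(data) >= 2 and data[0] == 0xFF and (data[1] & 0xE0) == 0xE0:
        return "mp3", "audio/mpeg"
    # Assumed fallback for unknown content.
    return "bin", "application/octet-stream"
```

Sniffing the bytes instead of trusting the requested extension matters here because the providers may return MP3 data even when the caller expects WAV, and a mismatched `Content-Type` is what causes the browser playback error mentioned in the assets router docstring.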
@@ -2,34 +2,26 @@
Podcast API Constants

Centralized constants and directory configuration for podcast module.
All workspace paths use utils.storage_paths for root resolution.
"""

import os
from pathlib import Path
from typing import Literal
from loguru import logger
from services.story_writer.audio_generation_service import StoryAudioGenerationService
from utils.storage_paths import get_repo_root, sanitize_user_id as _sanitize_user_id

# Directory paths
# router.py is at: backend/api/podcast/router.py
# parents[0] = backend/api/podcast/
# parents[1] = backend/api/
# parents[2] = backend/
# parents[3] = root/
ROOT_DIR = Path(__file__).resolve().parents[3] # root/
DATA_MEDIA_DIR = ROOT_DIR / "data" / "media"
ROOT_DIR = get_repo_root()

PODCAST_AUDIO_DIR = (DATA_MEDIA_DIR / "podcast_audio").resolve()
PODCAST_IMAGES_DIR = (DATA_MEDIA_DIR / "podcast_images").resolve()
PODCAST_VIDEOS_DIR = (DATA_MEDIA_DIR / "podcast_videos").resolve()

# Video subdirectory
# Video subdirectory (relative to workspace media dir)
AI_VIDEO_SUBDIR = Path("AI_Videos")

MediaType = Literal["audio", "image", "video"]
# Legacy constants - DEPRECATED, use get_podcast_media_dir() instead
# Kept for backward compatibility with some handlers
PODCAST_AVATARS_SUBDIR = Path("avatars")


def _sanitize_user_id(user_id: str) -> str:
    return "".join(c for c in user_id if c.isalnum() or c in ("-", "_"))
MediaType = Literal["audio", "image", "video", "chart"]


def get_podcast_media_dir(
@@ -38,21 +30,30 @@ def get_podcast_media_dir(
|
||||
*,
|
||||
ensure_exists: bool = False,
|
||||
) -> Path:
|
||||
"""Resolve podcast media directory (tenant workspace first, legacy global fallback)."""
|
||||
"""
|
||||
Resolve podcast media directory (workspace-only for multi-tenant isolation).
|
||||
|
||||
Requires user_id for tenant isolation. Falls back to default workspace
|
||||
only if no user_id provided (for backward compat in development).
|
||||
Logs a warning in production when user_id is missing.
|
||||
"""
|
||||
media_subdir = {
|
||||
"audio": "podcast_audio",
|
||||
"image": "podcast_images",
|
||||
"video": "podcast_videos",
|
||||
"chart": "podcast_charts",
|
||||
}[media_type]
|
||||
|
||||
if user_id:
|
||||
sanitized = _sanitize_user_id(user_id)
|
||||
tenant_media_dir = ROOT_DIR / "workspace" / f"workspace_{sanitized}" / "media" / media_subdir
|
||||
resolved_dir = tenant_media_dir.resolve()
|
||||
resolved_dir = (
|
||||
ROOT_DIR / "workspace" / f"workspace_{sanitized}" / "media" / media_subdir
|
||||
).resolve()
|
||||
else:
|
||||
resolved_dir = (DATA_MEDIA_DIR / media_subdir).resolve()
|
||||
|
||||
logger.debug(f"[Podcast] get_podcast_media_dir: type={media_type}, user_id={user_id}, sanitized={user_id and _sanitize_user_id(user_id)}, resolved={resolved_dir}")
|
||||
logger.warning(f"[Podcast] get_podcast_media_dir called without user_id for {media_type} — using default workspace. This should not happen in production.")
|
||||
resolved_dir = (
|
||||
ROOT_DIR / "workspace" / "workspace_alwrity" / "media" / media_subdir
|
||||
).resolve()
|
||||
|
||||
if ensure_exists:
|
||||
resolved_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -61,14 +62,11 @@ def get_podcast_media_dir(
|
||||
|
||||
|
||||
def get_podcast_media_read_dirs(media_type: MediaType, user_id: str | None = None) -> list[Path]:
|
||||
"""Return ordered directories to search (tenant path first, then legacy global path)."""
|
||||
dirs: list[Path] = []
|
||||
if user_id:
|
||||
dirs.append(get_podcast_media_dir(media_type, user_id))
|
||||
logger.debug(f"[Podcast] get_podcast_media_read_dirs: added user dir for {user_id}")
|
||||
dirs.append(get_podcast_media_dir(media_type, None))
|
||||
logger.debug(f"[Podcast] get_podcast_media_read_dirs: dirs={dirs}")
|
||||
return dirs
|
||||
"""
|
||||
Return directories to search for podcast media.
|
||||
Now workspace-only (no legacy fallback).
|
||||
"""
|
||||
return [get_podcast_media_dir(media_type, user_id)]
|
||||
|
||||
|
||||
def get_podcast_audio_service(user_id: str | None = None) -> StoryAudioGenerationService:
|
||||
|
||||
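Review note: after this change `get_podcast_media_dir` resolves every media type under a per-user workspace, with `workspace_alwrity` as the fallback when no `user_id` is given. The sketch below reproduces that resolution in isolation. It assumes the sanitizer keeps only alphanumerics, `-`, and `_`, as the removed local `_sanitize_user_id` did; the behavior of `utils.storage_paths.sanitize_user_id` is not shown in this diff. `resolve_media_dir` and its explicit `root` argument are hypothetical stand-ins for the real function and `get_repo_root()`.

```python
from pathlib import Path
from typing import Literal, Optional

MediaType = Literal["audio", "image", "video", "chart"]

# Mirrors the media_subdir mapping shown in the diff.
MEDIA_SUBDIRS = {
    "audio": "podcast_audio",
    "image": "podcast_images",
    "video": "podcast_videos",
    "chart": "podcast_charts",
}


def sanitize_user_id(user_id: str) -> str:
    """Assumed behavior: drop everything except alphanumerics, '-' and '_'."""
    return "".join(c for c in user_id if c.isalnum() or c in ("-", "_"))


def resolve_media_dir(
    root: Path, media_type: MediaType, user_id: Optional[str] = None
) -> Path:
    """Hypothetical stand-in for get_podcast_media_dir with an explicit root."""
    if user_id:
        workspace = f"workspace_{sanitize_user_id(user_id)}"
    else:
        # Default workspace used by the diff when user_id is missing.
        workspace = "workspace_alwrity"
    return (root / "workspace" / workspace / "media" / MEDIA_SUBDIRS[media_type]).resolve()


if __name__ == "__main__":
    root = Path("/tmp/repo")
    # Path separators and dots are stripped before the ID reaches the path.
    print(resolve_media_dir(root, "chart", "user_../../etc"))
    print(resolve_media_dir(root, "audio"))
```

One thing the sketch makes visible: an ID made up only of stripped characters collapses to `workspace_`, so callers may want to reject an empty sanitized value.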
216
backend/api/podcast/cost_estimator.py
Normal file
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
Podcast cost estimation helpers.

Builds user-facing podcast estimates from the subscription pricing catalog
instead of hard-coded frontend heuristics.

Supports multiple models for each component:
- Audio TTS: minimax/speech-02-hd (default), qwen3-tts, cosyvoice-tts
- Voice Clone: qwen3, cosyvoice, minimax
- Image: qwen-image (default), ideogram-v3-turbo
- Video: wan-2.5 (default), kling-v2.5, infinitetalk
- LLM: gemini-2.5-flash (default)
"""

from __future__ import annotations

from typing import Any, Dict, Optional
from sqlalchemy.orm import Session

from models.subscription_models import APIProvider
from services.subscription.pricing_service import PricingService


def _round_money(value: float) -> float:
    return round(float(value), 4)
|
||||
|
||||
|
||||
def _load_pricing(
    pricing_service: PricingService,
    provider: APIProvider,
    preferred_model: str,
) -> Optional[Dict[str, Any]]:
    """Load pricing for a provider and model, with fallback to default."""
    pricing = pricing_service.get_pricing_for_provider_model(provider, preferred_model)
    if pricing:
        return pricing
    # Fallback to provider default model row (if configured).
    return pricing_service.get_pricing_for_provider_model(provider, "default")
|
||||
|
||||
|
||||
# Default models used in podcast generation
DEFAULT_MODELS = {
    "gemini": "gemini-2.5-flash",
    "exa": "exa-search",
    "audio_tts": "minimax/speech-02-hd",
    "voice_clone": "wavespeed-ai/qwen3-tts/voice-clone",
    "image": "qwen-image",
    "video": "wan-2.5",
}
|
||||
|
||||
|
||||
def estimate_podcast_cost(
    *,
    db: Session,
    duration_minutes: int,
    speakers: int,
    query_count: int,
    include_avatar_phase: bool = True,
    # Optional model overrides
    gemini_model: str = "gemini-2.5-flash",
    audio_tts_model: str = "minimax/speech-02-hd",
    voice_clone_engine: str = "qwen3",
    image_model: str = "qwen-image",
    video_model: str = "wan-2.5",
) -> Optional[Dict[str, Any]]:
    """
    Compute a backend estimate for podcast creation.

    Supports customizable models for each component.
    Uses pricing_catalog for accurate cost calculation.
    """
    pricing_service = PricingService(db)

    # Load pricing for each component and model
    gemini_pricing = _load_pricing(pricing_service, APIProvider.GEMINI, gemini_model)
    exa_pricing = _load_pricing(pricing_service, APIProvider.EXA, "exa-search")

    # Audio TTS pricing (minimax/speech-02-hd)
    audio_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, audio_tts_model)

    # Voice clone pricing (different engines)
    voice_clone_model = f"wavespeed-ai/{voice_clone_engine}-tts/voice-clone"
    voice_clone_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, voice_clone_model)
    if not voice_clone_pricing:
        # Try alternate model names
        voice_clone_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, f"{voice_clone_engine}/voice-clone")

    # Image pricing (qwen-image or ideogram)
    image_pricing = _load_pricing(pricing_service, APIProvider.STABILITY, image_model)

    # Video pricing (wan-2.5, kling, or infinitetalk)
    video_pricing = _load_pricing(pricing_service, APIProvider.VIDEO, video_model)

    # Return None if critical pricing unavailable (fail fast)
    if not gemini_pricing:
        return None

    # Configuration
    minutes = max(1, int(duration_minutes or 1))
    speaker_count = max(1, int(speakers or 1))
    research_queries = max(1, int(query_count or 1))
|
||||
|
||||
    # Token usage assumptions per phase
    analysis_input_tokens = 1800
    analysis_output_tokens = 1000
    research_synthesis_input_tokens = 2200
    research_synthesis_output_tokens = 900
    script_input_tokens = max(1800, minutes * 300)
    script_output_tokens = max(2200, minutes * 700)

    # TTS: ~900 chars per minute per speaker
    estimated_tts_tokens = max(900, minutes * 900 * speaker_count)

    # Voice clone: 1 clone operation per speaker
    voice_clone_count = speaker_count
|
||||
|
||||
    # ===== COST CALCULATIONS =====

    # 1. Analysis phase (LLM)
    analysis_cost = (
        analysis_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
        + analysis_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
    )

    # 2. Research phase
    # 2a. LLM for research synthesis
    research_llm_cost = (
        research_synthesis_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
        + research_synthesis_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
    )
    # 2b. Search API (Exa)
    research_search_cost = 0.0
    if exa_pricing:
        research_search_cost = research_queries * float(exa_pricing.get("cost_per_request") or 0.0)
    research_cost = research_search_cost + research_llm_cost

    # 3. Script generation (LLM)
    script_cost = (
        script_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
        + script_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
    )

    # 4. Audio TTS
    tts_cost = 0.0
    if audio_pricing:
        tts_cost = estimated_tts_tokens * float(audio_pricing.get("cost_per_input_token") or 0.0)

    # 5. Voice cloning (if needed)
    voice_clone_cost = 0.0
    if voice_clone_pricing:
        voice_clone_cost = voice_clone_count * (
            float(voice_clone_pricing.get("cost_per_request") or 0.0)
            + estimated_tts_tokens * float(voice_clone_pricing.get("cost_per_input_token") or 0.0)
        )

    # 6. Avatar image generation
    avatar_cost = 0.0
    if include_avatar_phase and image_pricing:
        image_unit = float(image_pricing.get("cost_per_image") or image_pricing.get("cost_per_request") or 0.0)
        avatar_cost = speaker_count * image_unit

    # 7. Video rendering
    video_cost = 0.0
    if video_pricing:
        # Assume 1 video render per minute (upper bound)
        video_cost = minutes * float(video_pricing.get("cost_per_request") or 0.0)
|
||||
|
||||
    # ===== TOTALS =====
    llm_total = analysis_cost + research_llm_cost + script_cost
    audio_total = tts_cost + voice_clone_cost
    media_total = avatar_cost + video_cost
    total = llm_total + research_search_cost + audio_total + media_total

    return {
        # Cost breakdown
        "analysisCost": _round_money(analysis_cost),
        "researchCost": _round_money(research_cost),
        "researchSearchCost": _round_money(research_search_cost),
        "researchLlmCost": _round_money(research_llm_cost),
        "scriptCost": _round_money(script_cost),
        "ttsCost": _round_money(tts_cost),
        "voiceCloneCost": _round_money(voice_clone_cost),
        "avatarCost": _round_money(avatar_cost),
        "videoCost": _round_money(video_cost),
        "total": _round_money(total),
        # Totals by category
        "llmCost": _round_money(llm_total),
        "audioCost": _round_money(audio_total),
        "mediaCost": _round_money(media_total),
        # Currency
        "currency": "USD",
        "source": "pricing_catalog",
        # Models used for this estimate
        "models": {
            "llm": gemini_model,
            "research": "exa-search",
            "audio_tts": audio_tts_model,
            "voice_clone": voice_clone_model,
            "image": image_model,
            "video": video_model,
        },
        # Assumptions used
        "assumptions": {
            "analysis_input_tokens": analysis_input_tokens,
            "analysis_output_tokens": analysis_output_tokens,
            "research_synthesis_input_tokens": research_synthesis_input_tokens,
            "research_synthesis_output_tokens": research_synthesis_output_tokens,
            "script_input_tokens": script_input_tokens,
            "script_output_tokens": script_output_tokens,
            "estimated_tts_tokens": estimated_tts_tokens,
            "research_queries": research_queries,
            "voice_clone_count": voice_clone_count,
            "video_requests": minutes,
            "avatar_requests": speaker_count if include_avatar_phase else 0,
        },
    }
|
||||
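Review note: the estimator's arithmetic is easier to check with concrete numbers. The sketch below isolates the token-based phase cost and the duration-driven assumptions from `estimate_podcast_cost`. The helper names (`llm_phase_cost`, `script_tokens`, `tts_units`) are hypothetical, and the prices are invented for illustration; they are not values from the pricing catalog.

```python
from typing import Dict, Tuple


def llm_phase_cost(input_tokens: int, output_tokens: int, pricing: Dict[str, float]) -> float:
    """Token-based cost, same formula as the analysis, research and script phases."""
    return (
        input_tokens * float(pricing.get("cost_per_input_token") or 0.0)
        + output_tokens * float(pricing.get("cost_per_output_token") or 0.0)
    )


def script_tokens(minutes: int) -> Tuple[int, int]:
    """Script token assumptions: 300 in / 700 out per minute, floors of 1800 / 2200."""
    return max(1800, minutes * 300), max(2200, minutes * 700)


def tts_units(minutes: int, speakers: int) -> int:
    """Roughly 900 characters per minute per speaker, with a floor of 900."""
    return max(900, minutes * 900 * speakers)


if __name__ == "__main__":
    # Hypothetical prices, NOT taken from the real pricing catalog.
    pricing = {"cost_per_input_token": 1e-6, "cost_per_output_token": 2e-6}
    minutes, speakers = 10, 2

    s_in, s_out = script_tokens(minutes)
    print("script tokens:", s_in, s_out)
    print("tts units:", tts_units(minutes, speakers))
    print("script cost:", round(llm_phase_cost(s_in, s_out, pricing), 4))
```

For a 10-minute, two-speaker episode the script phase is sized at 3000 input and 7000 output tokens and the TTS volume at 18000 units; missing pricing keys fall back to a zero rate, which matches the `or 0.0` guards in the diff.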
@@ -4,8 +4,9 @@ Podcast Analysis Handlers
|
||||
Analysis endpoint for podcast ideas.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -19,17 +20,99 @@ from services.llm_providers.main_image_generation import generate_image
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
from ..constants import PODCAST_IMAGES_DIR
|
||||
import os
|
||||
from ..constants import get_podcast_media_dir
|
||||
from ..prompts import get_enhance_topic_prompt, format_website_context
|
||||
from ..models import (
|
||||
PodcastAnalyzeRequest,
|
||||
PodcastAnalyzeResponse,
|
||||
PodcastEnhanceIdeaRequest,
|
||||
PodcastEnhanceIdeaResponse
|
||||
PodcastEnhanceIdeaResponse,
|
||||
ExtractUrlRequest,
|
||||
ExtractUrlResponse,
|
||||
WebsiteAnalysisRequest,
|
||||
WebsiteAnalysisResponse,
|
||||
PodcastPreEstimateRequest,
|
||||
PodcastPreEstimateResponse,
|
||||
)
|
||||
from ..cost_estimator import estimate_podcast_cost
|
||||
|
||||
# Check if running in podcast-only demo mode
|
||||
def _is_podcast_only_mode() -> bool:
|
||||
"""Check if podcast-only demo mode is enabled."""
|
||||
return os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/pre-estimate", response_model=PodcastPreEstimateResponse)
|
||||
async def pre_estimate_cost(
|
||||
request: PodcastPreEstimateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Lightweight endpoint to estimate podcast creation cost before analysis.
|
||||
|
||||
Takes user configuration (duration, speakers, query_count, podcast_mode) and returns
|
||||
a cost estimate WITHOUT running full analysis.
|
||||
|
||||
Optional model overrides can be specified to estimate with different models.
|
||||
"""
|
||||
try:
|
||||
include_avatar_phase = request.podcast_mode != "audio_only"
|
||||
|
||||
estimate = estimate_podcast_cost(
|
||||
db=db,
|
||||
duration_minutes=request.duration,
|
||||
speakers=request.speakers,
|
||||
query_count=request.query_count,
|
||||
include_avatar_phase=include_avatar_phase,
|
||||
# Model overrides if provided
|
||||
gemini_model=request.gemini_model or "gemini-2.5-flash",
|
||||
audio_tts_model=request.audio_tts_model or "minimax/speech-02-hd",
|
||||
voice_clone_engine=request.voice_clone_engine or "qwen3",
|
||||
image_model=request.image_model or "qwen-image",
|
||||
video_model=request.video_model or "wan-2.5",
|
||||
)
|
||||
|
||||
# Debug: get pricing row count and providers
|
||||
from models.subscription_models import APIProviderPricing
|
||||
pricing_count = db.query(APIProviderPricing).count()
|
||||
providers = db.query(APIProviderPricing.provider).distinct().all()
|
||||
provider_list = sorted([p[0].value for p in providers]) if providers else []
|
||||
|
||||
debug_info = {
|
||||
"pricing_rows": pricing_count,
|
||||
"providers": provider_list,
|
||||
}
|
||||
|
||||
# Log pricing debug info at warning level
|
||||
logger.warning(f"[PRE-ESTIMATE] Pricing debug: rows={pricing_count}, providers={provider_list}")
|
||||
logger.warning(f"[PRE-ESTIMATE] Models: llm={request.gemini_model}, tts={request.audio_tts_model}, video={request.video_model}")
|
||||
|
||||
if estimate is None:
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=None,
|
||||
error="Pricing data unavailable. Please try again later.",
|
||||
pricing_available=False,
|
||||
debug=debug_info,
|
||||
)
|
||||
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=estimate,
|
||||
error=None,
|
||||
pricing_available=True,
|
||||
debug=debug_info,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pre-estimate error: {e}")
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=None,
|
||||
error=str(e),
|
||||
)
|
||||
|
||||
|
||||
@router.post("/idea/enhance", response_model=PodcastEnhanceIdeaResponse)
|
||||
async def enhance_podcast_idea(
|
||||
request: PodcastEnhanceIdeaRequest,
|
||||
@@ -42,39 +125,55 @@ async def enhance_podcast_idea(
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Serialize Bible context if provided or generate from onboarding
|
||||
# In podcast-only mode, skip bible generation since onboarding is disabled
|
||||
bible_context = ""
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
if not _is_podcast_only_mode():
|
||||
logger.warning(f"[Podcast Enhance] Podcast mode=full — attempting Bible generation for user {user_id}")
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
if request.bible:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
else:
|
||||
# Generate from onboarding data directly
|
||||
bible_obj = bible_service.generate_bible(user_id, "temp_enhance")
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Enhance] Failed to parse or generate bible context: {exc}")
|
||||
else:
|
||||
# In podcast-only mode, use the provided bible directly if available
|
||||
logger.warning(f"[Podcast Enhance] Podcast mode=podcast_only — skipping Bible generation for user {user_id}")
|
||||
if request.bible:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
else:
|
||||
# Generate from onboarding data directly
|
||||
bible_obj = bible_service.generate_bible(user_id, "temp_enhance")
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Enhance] Failed to parse or generate bible context: {exc}")
|
||||
try:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_service = PodcastBibleService()
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
except Exception as exc:
|
||||
logger.debug(f"[Podcast Enhance] Bible parsing skipped in podcast mode: {exc}")
|
||||
|
||||
prompt = f"""
|
||||
You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
# Log what's being used for context
|
||||
context_used = []
|
||||
if bible_context:
|
||||
context_used.append("Podcast Bible")
|
||||
if request.website_data:
|
||||
context_used.append("Website Extraction")
|
||||
if request.topic_context:
|
||||
category = request.topic_context.get("category", "unknown")
|
||||
context_used.append(f"Category Research ({category})")
|
||||
|
||||
logger.warning(f"[Podcast Enhance] Generating with context: {', '.join(context_used) if context_used else 'basic idea only'}")
|
||||
|
||||
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{request.idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 enhanced episode pitches (in order: Professional, Storytelling, Trendy)
|
||||
- rationales: array of 3 rationales explaining the approach for each version
|
||||
"""
|
||||
# Use new context builder for prompt generation
|
||||
from services.podcast_context_builder import context_builder
|
||||
context_result = context_builder.build_enhance_context(
|
||||
idea=request.idea,
|
||||
bible_context=bible_context,
|
||||
website_data=request.website_data,
|
||||
topic_context=request.topic_context,
|
||||
)
|
||||
prompt = context_result["prompt"]
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
@@ -95,6 +194,19 @@ Return JSON with:
|
||||
enhanced_ideas = data.get("enhanced_ideas", [])
|
||||
rationales = data.get("rationales", [])
|
||||
|
||||
# Handle case where LLM returns objects instead of strings
|
||||
normalized_ideas = []
|
||||
for idea in enhanced_ideas:
|
||||
if isinstance(idea, dict):
|
||||
# Extract title and description from object
|
||||
title = idea.get("title", "")
|
||||
description = idea.get("description", "") or idea.get("content", "")
|
||||
normalized_ideas.append(f"{title}: {description}" if description else title)
|
||||
elif isinstance(idea, str):
|
||||
normalized_ideas.append(idea)
|
||||
|
||||
enhanced_ideas = normalized_ideas
|
||||
|
||||
# Ensure we have exactly 3 ideas, fallback to original if needed
|
||||
if not isinstance(enhanced_ideas, list) or len(enhanced_ideas) != 3:
|
||||
# Fallback: create 3 variations of the original idea
|
||||
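Review note: the normalization added in this hunk accepts either plain strings or objects with `title` and `description`/`content` keys. A standalone sketch of the same logic follows; `normalize_ideas` is a hypothetical name, and the up-front `isinstance` guard is an addition of this sketch (the diff iterates before checking that the value is a list).

```python
from typing import Any, List


def normalize_ideas(enhanced_ideas: Any) -> List[str]:
    """Coerce a list of strings or {title, description|content} dicts to strings."""
    # Guard added in this sketch: a non-list payload yields no ideas.
    if not isinstance(enhanced_ideas, list):
        return []

    normalized: List[str] = []
    for idea in enhanced_ideas:
        if isinstance(idea, dict):
            title = idea.get("title", "")
            description = idea.get("description", "") or idea.get("content", "")
            normalized.append(f"{title}: {description}" if description else title)
        elif isinstance(idea, str):
            normalized.append(idea)
        # Any other type is dropped, as in the diff.
    return normalized


if __name__ == "__main__":
    raw = ["Plain pitch", {"title": "Expert view", "description": "Interview format"}]
    print(normalize_ideas(raw))
```

Because dropped items shorten the list, the existing "exactly 3 ideas" check that follows still decides whether the fallback variations are used.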
@@ -164,7 +276,11 @@ async def analyze_podcast_idea(
|
||||
final_avatar_url = request.avatar_url
|
||||
final_avatar_prompt = None
|
||||
|
||||
if not final_avatar_url:
|
||||
# Skip avatar generation for audio_only mode
|
||||
podcast_mode = getattr(request, 'podcast_mode', None) or 'video_only'
|
||||
should_generate_avatar = not final_avatar_url and podcast_mode != 'audio_only'
|
||||
|
||||
if should_generate_avatar:
|
||||
logger.info(f"[Podcast Analyze] No avatar_url provided, generating one for user {user_id}")
|
||||
try:
|
||||
# 1. PRE-FLIGHT VALIDATION: Check subscription limits for image generation
|
||||
@@ -195,8 +311,10 @@ async def analyze_podcast_idea(
|
||||
if image_result and image_result.image_bytes:
|
||||
img_id = str(uuid.uuid4())[:8]
|
||||
filename = f"presenter_podcast_{user_id}_{img_id}.png"
|
||||
output_path = PODCAST_IMAGES_DIR / filename
|
||||
PODCAST_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
images_dir = get_podcast_media_dir("image", user_id, ensure_exists=True)
|
||||
avatars_dir = images_dir / "avatars"
|
||||
avatars_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = avatars_dir / filename
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_result.image_bytes)
|
||||
@@ -208,13 +326,14 @@ async def analyze_podcast_idea(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
file_url=final_avatar_url,
|
||||
source_module="podcast_analysis",
|
||||
filename=filename,
|
||||
file_url=final_avatar_url,
|
||||
title=f"Presenter Avatar - {request.idea[:40]}",
|
||||
description=f"AI-generated podcast presenter for: {request.idea}",
|
||||
provider=image_result.provider,
|
||||
model=image_result.model,
|
||||
cost=image_result.cost
|
||||
cost=0.0 # Cost tracked in generate_image
|
||||
)
|
||||
logger.info(f"[Podcast Analyze] ✅ Generated and saved avatar to {final_avatar_url}")
|
||||
except Exception as e:
|
||||
@@ -319,6 +438,13 @@ Requirements:
|
||||
listener_cta = data.get("listener_cta") or ""
|
||||
research_queries = data.get("research_queries") or []
|
||||
exa_suggested_config = data.get("exa_suggested_config") or None
|
||||
estimate = estimate_podcast_cost(
|
||||
db=db,
|
||||
duration_minutes=request.duration,
|
||||
speakers=request.speakers,
|
||||
query_count=len(research_queries) if isinstance(research_queries, list) else 0,
|
||||
include_avatar_phase=podcast_mode != "audio_only",
|
||||
)
|
||||
|
||||
return PodcastAnalyzeResponse(
|
||||
audience=audience,
|
||||
@@ -335,6 +461,7 @@ Requirements:
|
||||
bible=bible_obj.model_dump() if bible_obj else None,
|
||||
avatar_url=final_avatar_url,
|
||||
avatar_prompt=final_avatar_prompt,
|
||||
estimate=estimate,
|
||||
)
|
||||
|
||||
|
||||
@@ -440,3 +567,315 @@ Requirements:
|
||||
logger.error(f"[Regenerate Queries] Failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Regenerate queries failed: {exc}")
|
||||
|
||||
|
||||
@router.post("/extract-url", response_model=ExtractUrlResponse)
|
||||
async def extract_url_content(
|
||||
request: ExtractUrlRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Extract content from a URL using Exa's get_contents API.
|
||||
|
||||
This allows users to paste a blog post or article URL as their podcast topic,
|
||||
and we'll extract the content to use as the podcast idea.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
from exa_py import Exa
|
||||
import os
|
||||
|
||||
api_key = os.getenv("EXA_API_KEY")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="EXA_API_KEY not configured")
|
||||
|
||||
exa = Exa(api_key)
|
||||
|
||||
logger.warning(f"[ExtractUrl] Extracting content from: {request.url} for user {user_id}")
|
||||
|
||||
try:
|
||||
result = exa.get_contents(
|
||||
urls=[request.url],
|
||||
text=True,
|
||||
highlights=True,
|
||||
summary=True,
|
||||
subpages=2,
|
||||
)
|
||||
except Exception as exa_error:
|
||||
logger.error(f"[ExtractUrl] Exa call error: {exa_error}")
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error=f"Exa API error: {str(exa_error)}"
|
||||
)
|
||||
|
||||
# Check for errors using the correct attribute (statuses is array of status objects)
|
||||
if hasattr(result, 'statuses') and result.statuses:
|
||||
for status in result.statuses:
|
||||
if status.status == "error":
|
||||
logger.error(f"[ExtractUrl] Failed to extract {status.id}: {status.error.tag if hasattr(status.error, 'tag') else 'unknown'}")
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error=f"Failed to extract content: {status.error.tag if hasattr(status.error, 'tag') else 'unknown error'}"
|
||||
)
|
||||
|
||||
if not result.results:
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error="No content found at the provided URL"
|
||||
)
|
||||
|
||||
# Extract content - safe to access result now
|
||||
content = result.results[0]
|
||||
|
||||
# Extract all available fields from Exa response
|
||||
extracted_text = content.text or ""
|
||||
extracted_summary = getattr(content, 'summary', "") or ""
|
||||
extracted_title = content.title or ""
|
||||
|
||||
# Highlights - extract from content.highlights array if available
|
||||
highlights = []
|
||||
if hasattr(content, 'highlights') and content.highlights:
|
||||
highlights = [h for h in content.highlights if h]
|
||||
|
||||
# Additional fields from Exa response
|
||||
image = getattr(content, 'image', None)
|
||||
favicon = getattr(content, 'favicon', None)
|
||||
|
||||
# Subpages - extract with their own content
|
||||
subpages = []
|
||||
if hasattr(content, 'subpages') and content.subpages:
|
||||
for sp in content.subpages:
|
||||
subpages.append({
|
||||
'id': sp.get('id', ''),
|
||||
'title': sp.get('title', ''),
|
||||
'url': sp.get('url', ''),
|
||||
'summary': sp.get('summary', ''),
|
||||
'text': sp.get('text', '')[:500] if sp.get('text') else '', # First 500 chars
|
||||
})
|
||||
|
||||
logger.warning(f"[ExtractUrl] Successfully extracted {len(extracted_text)} chars from {request.url}")
|
||||
logger.warning(f"[ExtractUrl] title={extracted_title[:50]}, summary={extracted_summary[:50]}, highlights={len(highlights)}, subpages={len(subpages)}")
|
||||
|
||||
return ExtractUrlResponse(
|
||||
success=True,
|
||||
title=extracted_title,
|
||||
text=extracted_text,
|
||||
summary=extracted_summary,
|
||||
author=getattr(content, 'author', None),
|
||||
highlights=highlights,
|
||||
url=request.url,
|
||||
image=image,
|
||||
favicon=favicon,
|
||||
subpages=subpages,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/website-analysis", response_model=WebsiteAnalysisResponse)
|
||||
async def save_website_analysis(
|
||||
request: WebsiteAnalysisRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save the user's website analysis for reuse in future podcasts."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.user_data_service import user_data_service
|
||||
|
||||
website_data = {
|
||||
"website_url": request.website_url,
|
||||
"extracted_at": datetime.now().isoformat(),
|
||||
"exa_content": request.exa_content,
|
||||
"full_analysis": None,
|
||||
"analysis_status": "pending",
|
||||
}
|
||||
|
||||
success = user_data_service.save_user_data(
|
||||
user_id=user_id,
|
||||
data_key="website_analysis",
|
||||
data_value=website_data,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.warning(f"[WebsiteAnalysis] Saved analysis for user {user_id}: {request.website_url}")
|
||||
return WebsiteAnalysisResponse(
|
||||
success=True,
|
||||
website_url=request.website_url,
|
||||
message="Website analysis saved successfully",
|
||||
)
|
||||
else:
|
||||
return WebsiteAnalysisResponse(
|
||||
success=False,
|
||||
error="Failed to save website analysis",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteAnalysis] Failed to save for user {user_id}: {exc}")
|
||||
return WebsiteAnalysisResponse(
|
||||
success=False,
|
||||
error=f"Failed to save: {str(exc)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/website-extraction")
|
||||
async def get_saved_website_extraction(request: Request = None):
|
||||
"""Get previously saved website extraction data for this user."""
|
||||
try:
|
||||
# Read current_user from request.state (this endpoint takes no Depends(get_current_user))
|
||||
if request is None or not hasattr(request, 'state'):
|
||||
logger.warning("[WebsiteExtraction] No request or state - user not authenticated")
|
||||
return {"success": False, "data": None, "error": "Not authenticated"}
|
||||
|
||||
current_user = getattr(request.state, 'user', None)
|
||||
if not current_user:
|
||||
logger.warning("[WebsiteExtraction] No user in request state")
|
||||
return {"success": False, "data": None, "error": "Not authenticated"}
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
from services.user_data_service import UserDataService
|
||||
from services.database import get_db
|
||||
db = next(get_db())
|
||||
|
||||
user_service = UserDataService(db)
|
||||
extraction = user_service.get_website_extraction(user_id)
|
||||
|
||||
if extraction:
|
||||
logger.info(f"[WebsiteExtraction] Found saved data for user {user_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": extraction
|
||||
}
|
||||
else:
|
||||
logger.info(f"[WebsiteExtraction] No saved data for user {user_id}")
|
||||
return {
|
||||
"success": False,
|
||||
"data": None
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
# loguru has no exc_info keyword; opt(exception=True) attaches the traceback
logger.opt(exception=True).error(f"[WebsiteExtraction] Failed for user: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/website-extraction")
|
||||
async def save_website_extraction(
|
||||
extraction: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save website extraction data for future use."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.user_data_service import UserDataService
|
||||
from services.database import get_db
|
||||
db = next(get_db())
|
||||
|
||||
user_service = UserDataService(db)
|
||||
success = user_service.save_website_extraction(user_id, extraction)
|
||||
|
||||
if success:
|
||||
logger.info(f"[WebsiteExtraction] Saved for user {user_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Website extraction saved"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Failed to save"
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteExtraction] Save failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/project/{project_id}/topic-context")
|
||||
async def save_topic_context(
|
||||
project_id: str,
|
||||
topic_context: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save topic context (category research) to a podcast project."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from models.podcast_models import PodcastProject
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
# Find the project
|
||||
project = db.query(PodcastProject).filter(
|
||||
PodcastProject.project_id == project_id,
|
||||
PodcastProject.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not project:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Project not found"
|
||||
}
|
||||
|
||||
# Update topic context
|
||||
project.topic_context = topic_context
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[TopicContext] Saved for project {project_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Topic context saved"
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[TopicContext] Save failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/project/{project_id}/topic-context")
|
||||
async def get_topic_context(
|
||||
project_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get topic context from a podcast project."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from models.podcast_models import PodcastProject
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
project = db.query(PodcastProject).filter(
|
||||
PodcastProject.project_id == project_id,
|
||||
PodcastProject.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not project:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Project not found"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": project.topic_context
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[TopicContext] Get failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
@@ -12,7 +12,15 @@ from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
import tempfile
|
||||
import uuid
|
||||
import hashlib
|
||||
import time
|
||||
import shutil
|
||||
import requests
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
@@ -31,6 +39,124 @@ from ..models import (
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Thread pool for CPU/IO-intensive voice clone operations
|
||||
_audio_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="podcast_audio")
|
||||
|
||||
# In-memory TTL cache for voice samples (keyed per user and URL) to avoid re-downloading; eviction is by oldest insertion time, not true LRU
|
||||
_voice_sample_cache: dict[str, tuple[float, bytes]] = {}
|
||||
_VOICE_SAMPLE_CACHE_TTL = 1800 # 30 minutes
|
||||
|
||||
|
||||
def _get_cached_voice_sample(cache_key: str) -> Optional[bytes]:
|
||||
"""Get voice sample bytes from in-memory cache if fresh."""
|
||||
if cache_key in _voice_sample_cache:
|
||||
ts, data = _voice_sample_cache[cache_key]
|
||||
if time.time() - ts < _VOICE_SAMPLE_CACHE_TTL:
|
||||
logger.debug(f"[Podcast] Voice sample cache hit for {cache_key[:16]}...")
|
||||
return data
|
||||
del _voice_sample_cache[cache_key]
|
||||
return None
|
||||
|
||||
|
||||
def _cache_voice_sample(cache_key: str, data: bytes) -> None:
|
||||
"""Store voice sample bytes in in-memory cache."""
|
||||
# Evict oldest entries if cache grows too large
|
||||
if len(_voice_sample_cache) > 50:
|
||||
oldest_key = min(_voice_sample_cache, key=lambda k: _voice_sample_cache[k][0])
|
||||
del _voice_sample_cache[oldest_key]
|
||||
_voice_sample_cache[cache_key] = (time.time(), data)
|
||||
|
||||
|
||||
def _get_latest_voice_sample_url(user_id: str, db) -> Optional[str]:
|
||||
"""Get the latest voice sample URL for a user from their voice clone assets."""
|
||||
try:
|
||||
from models.content_asset_models import ContentAsset, AssetType, AssetSource
|
||||
from sqlalchemy import desc
|
||||
|
||||
asset = db.query(ContentAsset).filter(
|
||||
ContentAsset.user_id == user_id,
|
||||
ContentAsset.asset_type == AssetType.AUDIO,
|
||||
ContentAsset.source_module == AssetSource.VOICE_CLONER,
|
||||
).order_by(desc(ContentAsset.created_at)).first()
|
||||
|
||||
if asset and asset.file_url:
|
||||
logger.info(f"[Podcast] Found voice sample for user {user_id}: {asset.file_url}")
|
||||
return asset.file_url
|
||||
|
||||
logger.warning(f"[Podcast] No voice sample asset found for user {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Error fetching voice sample URL: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_voice_sample(voice_sample_url: str, user_id: str) -> Optional[bytes]:
|
||||
"""Fetch voice sample audio bytes from URL, with caching."""
|
||||
cache_key = hashlib.md5(f"{user_id}:{voice_sample_url}".encode()).hexdigest()
|
||||
|
||||
# Check in-memory cache first
|
||||
cached = _get_cached_voice_sample(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
from utils.media_utils import resolve_media_path
|
||||
|
||||
# Try resolving as a local workspace path first (fastest)
|
||||
if "/api/assets/" in voice_sample_url:
|
||||
# Resolve user workspace path directly
|
||||
sanitized_uid = "".join(c for c in user_id if c.isalnum() or c in ("-", "_"))
|
||||
from api.podcast.constants import ROOT_DIR
|
||||
parts = voice_sample_url.split("/")
|
||||
# Expected: /api/assets/{user_id}/voice_samples/{filename}
|
||||
try:
|
||||
idx = parts.index("voice_samples")
|
||||
filename = parts[idx + 1].split("?")[0]
|
||||
local_path = ROOT_DIR / "workspace" / f"workspace_{sanitized_uid}" / "assets" / "voice_samples" / filename
|
||||
if local_path.exists():
|
||||
data = local_path.read_bytes()
|
||||
_cache_voice_sample(cache_key, data)
|
||||
logger.info(f"[Podcast] Voice sample loaded from workspace: {local_path}")
|
||||
return data
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
# Fall back to media utils resolver
|
||||
local_path = resolve_media_path(voice_sample_url)
|
||||
if local_path and local_path.exists():
|
||||
data = local_path.read_bytes()
|
||||
_cache_voice_sample(cache_key, data)
|
||||
return data
|
||||
|
||||
# Try resolving as a podcast audio file
|
||||
if "/api/podcast/audio/" in voice_sample_url:
|
||||
filename = voice_sample_url.split("/api/podcast/audio/")[-1].split("?")[0]
|
||||
try:
|
||||
audio_dir = get_podcast_media_dir("audio", user_id)
|
||||
local_path = audio_dir / filename
|
||||
if local_path.exists():
|
||||
data = local_path.read_bytes()
|
||||
_cache_voice_sample(cache_key, data)
|
||||
return data
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try direct HTTP fetch as fallback
|
||||
if voice_sample_url.startswith("http"):
|
||||
logger.info(f"[Podcast] Fetching voice sample via HTTP: {voice_sample_url[:80]}...")
|
||||
resp = requests.get(voice_sample_url, timeout=30)
|
||||
if resp.status_code == 200:
|
||||
data = resp.content
|
||||
_cache_voice_sample(cache_key, data)
|
||||
logger.info(f"[Podcast] Voice sample fetched via HTTP ({len(data)} bytes)")
|
||||
return data
|
||||
|
||||
logger.warning(f"[Podcast] Could not fetch voice sample from: {voice_sample_url}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Error fetching voice sample: {e}")
|
||||
return None
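Review note: the `/api/assets/` branch of `_fetch_voice_sample` takes the path segment after `voice_samples` and strips the query string by hand. The sketch below shows one way the same extraction could be done with `urllib.parse`, with an extra check that rejects empty names and traversal segments. The function name `extract_voice_sample_filename` is hypothetical and the validation rules are an assumption, not something the diff implements.

```python
from typing import Optional
from urllib.parse import unquote, urlparse


def extract_voice_sample_filename(url: str) -> Optional[str]:
    """Return the segment following 'voice_samples' in the URL path, or None."""
    # urlparse separates the query string, so no manual split on '?' is needed.
    parts = urlparse(url).path.split("/")
    try:
        idx = parts.index("voice_samples")
        name = unquote(parts[idx + 1])
    except (ValueError, IndexError):
        return None
    # Reject empty names and anything that could escape the target directory.
    if not name or name in (".", "..") or "/" in name or "\\" in name:
        return None
    return name
```

A caller would join the returned name onto the user's `voice_samples` directory only when the result is not `None`.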
|
||||
|
||||
|
||||
@router.post("/audio/upload")
|
||||
async def upload_podcast_audio(
|
||||
@@ -125,36 +251,190 @@ async def generate_podcast_audio(
|
||||
raise HTTPException(status_code=400, detail="Text is required")
|
||||
|
||||
try:
|
||||
audio_service = get_podcast_audio_service(user_id)
|
||||
logger.warning(f"[Podcast] Generating audio with service dir: {audio_service.output_dir}")
|
||||
result: StoryAudioResult = audio_service.generate_ai_audio(
|
||||
scene_number=0,
|
||||
scene_title=request.scene_title,
|
||||
text=request.text.strip(),
|
||||
user_id=user_id,
|
||||
voice_id=request.voice_id or "Wise_Woman",
|
||||
custom_voice_id=request.custom_voice_id,
|
||||
speed=request.speed or 1.0, # Normal speed (was 0.9, but too slow - causing duration issues)
|
||||
volume=request.volume or 1.0,
|
||||
pitch=request.pitch or 0.0, # Normal pitch (0.0 = neutral)
|
||||
emotion=request.emotion or "neutral",
|
||||
english_normalization=request.english_normalization or False,
|
||||
sample_rate=request.sample_rate,
|
||||
bitrate=request.bitrate,
|
||||
channel=request.channel,
|
||||
format=request.format,
|
||||
language_boost=request.language_boost,
|
||||
enable_sync_mode=request.enable_sync_mode,
|
||||
# Determine if we should use voice clone path
|
||||
# Voice clone is used when: explicitly requested, OR when voice_id/custom_voice_id indicates a clone
|
||||
# (cloned voice IDs start with "vc_" or match the placeholder "MY_VOICE_CLONE")
|
||||
_vid = request.voice_id or ""
|
||||
_cvid = request.custom_voice_id or ""
|
||||
is_voice_clone = request.use_voice_clone or (
|
||||
_cvid.startswith("vc_") or _cvid == "MY_VOICE_CLONE"
|
||||
) or (
|
||||
_vid.startswith("vc_") or _vid == "MY_VOICE_CLONE"
|
||||
)
|
||||
|
||||
# Override URL to use podcast endpoint instead of story endpoint
|
||||
if result.get("audio_url") and "/api/story/audio/" in result.get("audio_url", ""):
|
||||
audio_filename = result.get("audio_filename", "")
|
||||
result["audio_url"] = f"/api/podcast/audio/{audio_filename}"
|
||||
|
||||
logger.warning(f"[Podcast] Audio generated - path: {result.get('audio_path')}, url: {result.get('audio_url')}")
|
||||
# If voice_id is a clone ID, normalize it to use Wise_Woman for TTS fallback
|
||||
effective_voice_id = _vid if not (_vid.startswith("vc_") or _vid == "MY_VOICE_CLONE") else "Wise_Woman"
|
||||
|
||||
logger.warning(f"[Podcast] Audio request: use_voice_clone={request.use_voice_clone}, voice_id={request.voice_id}, custom_voice_id={request.custom_voice_id}, is_voice_clone={is_voice_clone}, voice_sample_url={request.voice_sample_url}, voice_clone_engine={request.voice_clone_engine}")
|
||||
|
||||
# Voice clone path: use user's voice sample with scene text as reference
|
||||
if is_voice_clone:
|
||||
# If no voice_sample_url provided, try to fetch it from the user's latest voice clone
|
||||
voice_sample_url = request.voice_sample_url
|
||||
if not voice_sample_url:
|
||||
try:
|
||||
voice_sample_url = _get_latest_voice_sample_url(user_id, db)
|
||||
logger.warning(f"[Podcast] DB fallback voice sample URL for user {user_id}: {voice_sample_url}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] Could not fetch voice sample URL: {e}")
|
||||
|
||||
if voice_sample_url:
|
||||
from services.llm_providers.main_audio_generation import qwen3_voice_clone, cosyvoice_voice_clone
|
||||
from utils.media_utils import detect_audio_format
|
||||
|
||||
engine = (request.voice_clone_engine or "qwen3").lower()
|
||||
logger.warning(f"[Podcast] 🔊 Voice clone path: engine={engine}, scene='{request.scene_title}', voice_sample_url={voice_sample_url[:80]}...")
|
||||
|
||||
# Download voice sample from URL (with caching)
|
||||
logger.warning(f"[Podcast] Fetching voice sample from: {voice_sample_url}")
|
||||
try:
|
||||
voice_sample_bytes = _fetch_voice_sample(voice_sample_url, user_id)
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"[Podcast] ❌ Failed to fetch voice sample: {fetch_err}", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail=f"Could not fetch voice sample: {str(fetch_err)}")
|
||||
logger.warning(f"[Podcast] Voice sample fetch result: {len(voice_sample_bytes) if voice_sample_bytes else 0} bytes")
|
||||
if not voice_sample_bytes:
|
||||
raise HTTPException(status_code=400, detail=f"Could not fetch voice sample from {voice_sample_url}")
|
||||
|
||||
# Detect actual audio format from bytes (may differ from file extension)
|
||||
detected_fmt, detected_mime = detect_audio_format(voice_sample_bytes)
|
||||
logger.warning(f"[Podcast] 🔊 Detected voice sample format: {detected_fmt} ({detected_mime}), {len(voice_sample_bytes)} bytes")
|
||||
voice_mime_type = detected_mime or "audio/wav"
|
||||
|
||||
scene_text = request.text.strip()
|
||||
if len(scene_text) > 4000:
|
||||
scene_text = scene_text[:4000]
|
||||
|
||||
# Run voice clone in thread pool to avoid blocking the event loop
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
try:
|
||||
if engine == "minimax":
|
||||
from services.llm_providers.main_audio_generation import clone_voice
|
||||
import random
|
||||
import string
|
||||
random_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
|
||||
custom_vid = request.custom_voice_id or f"vc_{random_suffix}"
|
||||
|
||||
result_obj = await loop.run_in_executor(
|
||||
_audio_executor,
|
||||
lambda cv=custom_vid: clone_voice(
|
||||
audio_bytes=voice_sample_bytes,
|
||||
custom_voice_id=cv,
|
||||
text=scene_text,
|
||||
user_id=user_id,
|
||||
),
|
||||
)
|
||||
audio_bytes = result_obj.preview_audio_bytes
|
||||
provider = "minimax"
|
||||
model = "minimax/voice-clone"
|
||||
elif engine == "cosyvoice":
|
||||
result_obj = await loop.run_in_executor(
|
||||
_audio_executor,
|
||||
lambda: cosyvoice_voice_clone(
|
||||
audio_bytes=voice_sample_bytes,
|
||||
text=scene_text,
|
||||
user_id=user_id,
|
||||
audio_mime_type=voice_mime_type,
|
||||
),
|
||||
)
|
||||
audio_bytes = result_obj.preview_audio_bytes
|
||||
provider = "wavespeed-ai"
|
||||
model = "wavespeed-ai/cosyvoice-tts/voice-clone"
|
||||
else:
|
||||
result_obj = await loop.run_in_executor(
|
||||
_audio_executor,
|
||||
lambda: qwen3_voice_clone(
|
||||
audio_bytes=voice_sample_bytes,
|
||||
text=scene_text,
|
||||
user_id=user_id,
|
||||
audio_mime_type=voice_mime_type,
|
||||
),
|
||||
)
|
||||
audio_bytes = result_obj.preview_audio_bytes
|
||||
provider = "wavespeed-ai"
|
||||
model = "wavespeed-ai/qwen3-tts/voice-clone"
|
||||
|
||||
logger.warning(f"[Podcast] 🔊 Voice clone result: {len(audio_bytes) if audio_bytes else 0} bytes, provider={provider}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as clone_err:
|
||||
logger.error(f"[Podcast] ❌ Voice clone failed: {clone_err}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Voice clone generation failed: {str(clone_err)}")
|
||||
|
||||
# Save audio bytes to file
|
||||
audio_service = get_podcast_audio_service(user_id)
|
||||
audio_filename = f"scene_{request.scene_id}_{uuid.uuid4().hex[:8]}.mp3"
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
|
||||
with open(audio_path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
file_size = len(audio_bytes)
|
||||
audio_url = f"/api/podcast/audio/{audio_filename}"
|
||||
cost = max(0.005, 0.005 * (len(scene_text) / 100.0))
|
||||
|
||||
result = {
|
||||
"audio_path": str(audio_path),
|
||||
"audio_filename": audio_filename,
|
||||
"audio_url": audio_url,
|
||||
"file_size": file_size,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"cost": cost,
|
||||
"scene_number": 0,
|
||||
"scene_title": request.scene_title,
|
||||
}
|
||||
|
||||
else:
|
||||
# Standard TTS path - but NOT if custom_voice_id is a clone ID
|
||||
# Clone IDs (vc_*, MY_VOICE_CLONE) are not valid for minimax TTS
|
||||
if is_voice_clone:
|
||||
logger.warning(f"[Podcast] ⚠️ Voice clone detected but no voice sample available - falling back to standard TTS with voice_id={effective_voice_id}")
|
||||
effective_custom_voice_id = request.custom_voice_id
|
||||
if effective_custom_voice_id and (
|
||||
effective_custom_voice_id.startswith("vc_") or
|
||||
effective_custom_voice_id == "MY_VOICE_CLONE"
|
||||
):
|
||||
logger.warning(f"[Podcast] Ignoring clone ID '{effective_custom_voice_id}' in standard TTS path - no voice sample URL available")
|
||||
effective_custom_voice_id = None
|
||||
|
||||
audio_service = get_podcast_audio_service(user_id)
|
||||
logger.warning(f"[Podcast] Standard TTS path: voice_id={effective_voice_id}, custom_voice_id={effective_custom_voice_id}")
|
||||
result: StoryAudioResult = audio_service.generate_ai_audio(
|
||||
scene_number=0,
|
||||
scene_title=request.scene_title,
|
||||
text=request.text.strip(),
|
||||
user_id=user_id,
|
||||
voice_id=effective_voice_id,
|
||||
custom_voice_id=effective_custom_voice_id,
|
||||
speed=request.speed or 1.0, # Normal speed (was 0.9, but too slow - causing duration issues)
|
||||
volume=request.volume or 1.0,
|
||||
pitch=request.pitch or 0.0, # Normal pitch (0.0 = neutral)
|
||||
emotion=request.emotion or "neutral",
|
||||
english_normalization=request.english_normalization or False,
|
||||
sample_rate=request.sample_rate,
|
||||
bitrate=request.bitrate,
|
||||
channel=request.channel,
|
||||
format=request.format,
|
||||
language_boost=request.language_boost,
|
||||
enable_sync_mode=request.enable_sync_mode,
|
||||
)
|
||||
|
||||
# Override URL to use podcast endpoint instead of story endpoint
|
||||
if result.get("audio_url") and "/api/story/audio/" in result.get("audio_url", ""):
|
||||
audio_filename = result.get("audio_filename", "")
|
||||
result["audio_url"] = f"/api/podcast/audio/{audio_filename}"
|
||||
|
||||
logger.warning(f"[Podcast] Audio generated - path: {result.get('audio_path')}, url: {result.get('audio_url')}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=f"Audio generation failed: {exc}")
|
||||
exc_type = type(exc).__name__
|
||||
exc_msg = str(exc)[:500]
|
||||
logger.error(f"[Podcast] Audio generation failed ({exc_type}): {exc_msg}")
|
||||
logger.error(f"[Podcast] Audio generation traceback:", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Audio generation failed ({exc_type}): {exc_msg}")
|
||||
|
||||
# Save to asset library (podcast module)
|
||||
try:
|
||||
@@ -391,9 +671,12 @@ async def serve_podcast_audio(
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
logger.warning(f"[Podcast] serve_podcast_audio called: user_id={user_id}, filename={filename}")
|
||||
logger.info(f"[Podcast] serve_podcast_audio: filename={filename}, user_id={user_id}")
|
||||
|
||||
audio_path = _resolve_podcast_media_file(filename, "audio", user_id)
|
||||
logger.warning(f"[Podcast] Resolved audio path: {audio_path}")
|
||||
logger.info(f"[Podcast] Audio resolved path: {audio_path}, exists={audio_path.exists()}")
|
||||
audio_path = _resolve_podcast_media_file(filename, "audio", user_id)
|
||||
logger.debug(f"[Podcast] Resolved audio path: {audio_path}")
|
||||
|
||||
return FileResponse(audio_path, media_type="audio/mpeg")
|
||||
|
||||
|
||||
@@ -12,22 +12,39 @@ from pathlib import Path
|
||||
import uuid
|
||||
import hashlib
|
||||
|
||||
from services.database import get_db
|
||||
from services.database import get_db, get_session_for_user
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_image_editing import edit_image
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
from ..constants import PODCAST_IMAGES_DIR
|
||||
from ..constants import get_podcast_media_dir, PODCAST_AVATARS_SUBDIR
|
||||
from ..presenter_personas import choose_persona_id, get_persona
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Avatar subdirectory
|
||||
AVATAR_SUBDIR = "avatars"
|
||||
PODCAST_AVATARS_DIR = PODCAST_IMAGES_DIR / AVATAR_SUBDIR
|
||||
PODCAST_AVATARS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
AVATAR_SUBDIR = PODCAST_AVATARS_SUBDIR
|
||||
|
||||
|
||||
async def _get_db_or_none(current_user: Dict[str, Any]):
|
||||
"""Try to get a database session, returning None on failure (non-fatal for uploads)."""
|
||||
try:
|
||||
user_id = current_user.get('id') or current_user.get('clerk_user_id')
|
||||
if not user_id:
|
||||
return None
|
||||
return get_session_for_user(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] DB session unavailable (non-fatal): {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_podcast_avatars_dir(user_id: str) -> Path:
|
||||
"""Get podcast avatars directory for a user (workspace-aware)."""
|
||||
avatars_dir = get_podcast_media_dir("image", user_id, ensure_exists=True) / AVATAR_SUBDIR
|
||||
avatars_dir.mkdir(parents=True, exist_ok=True)
|
||||
return avatars_dir
|
||||
|
||||
|
||||
@router.post("/avatar/upload")
|
||||
@@ -41,8 +58,16 @@ async def upload_podcast_avatar(
|
||||
Upload a presenter avatar image for a podcast project.
|
||||
Returns the avatar URL for use in scene image generation.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Avatar upload auth failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
||||
|
||||
logger.info(f"[Podcast] Avatar upload request - user_id={user_id}, project_id={project_id}, content_type={file.content_type}")
|
||||
|
||||
# Validate file type
|
||||
if not file.content_type or not file.content_type.startswith('image/'):
|
||||
raise HTTPException(status_code=400, detail="File must be an image")
|
||||
@@ -57,19 +82,21 @@ async def upload_podcast_avatar(
|
||||
file_ext = Path(file.filename).suffix or '.png'
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
avatar_filename = f"avatar_{project_id or 'temp'}_{unique_id}{file_ext}"
|
||||
avatar_path = PODCAST_AVATARS_DIR / avatar_filename
|
||||
avatars_dir = _get_podcast_avatars_dir(user_id)
|
||||
logger.info(f"[Podcast] Saving avatar to: {avatars_dir / avatar_filename}")
|
||||
avatar_path = avatars_dir / avatar_filename
|
||||
|
||||
# Save file
|
||||
with open(avatar_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
logger.info(f"[Podcast] Avatar uploaded: {avatar_path}")
|
||||
logger.info(f"[Podcast] Avatar uploaded successfully: {avatar_path}")
|
||||
|
||||
# Create avatar URL
|
||||
avatar_url = f"/api/podcast/images/{AVATAR_SUBDIR}/{avatar_filename}"
|
||||
|
||||
# Save to asset library if project_id provided
|
||||
if project_id:
|
||||
# Save to asset library if project_id provided and DB session available
|
||||
if project_id and db:
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
@@ -91,13 +118,17 @@ async def upload_podcast_avatar(
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] Failed to save avatar asset: {e}")
|
||||
logger.warning(f"[Podcast] Failed to save avatar asset (non-fatal): {e}")
|
||||
elif project_id and not db:
|
||||
logger.warning("[Podcast] DB session unavailable, skipping asset library save for avatar")
|
||||
|
||||
return {
|
||||
"avatar_url": avatar_url,
|
||||
"avatar_filename": avatar_filename,
|
||||
"message": "Avatar uploaded successfully"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast] Avatar upload failed: {exc}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Avatar upload failed: {str(exc)}")
|
||||
@@ -114,12 +145,18 @@ async def make_avatar_presentable(
|
||||
Transform an uploaded avatar image into a podcast-appropriate presenter.
|
||||
Uses AI image editing to convert the uploaded photo into a professional podcast presenter.
|
||||
"""
|
||||
# CRITICAL: Log at the very start before any logic
|
||||
logger.info("[Podcast] ===== MAKE PRESENTABLE ENDPOINT START =====")
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
logger.info(f"[Podcast] Make presentable request received - user_id={user_id}, avatar_url={avatar_url}, project_id={project_id}")
|
||||
|
||||
try:
|
||||
# Load the uploaded avatar image
|
||||
from ..utils import load_podcast_image_bytes
|
||||
avatar_bytes = load_podcast_image_bytes(avatar_url)
|
||||
logger.info(f"[Podcast] Loading avatar image from {avatar_url}")
|
||||
avatar_bytes = load_podcast_image_bytes(avatar_url, user_id=user_id)
|
||||
logger.info(f"[Podcast] Avatar loaded successfully - size={len(avatar_bytes)} bytes")
|
||||
|
||||
logger.info(f"[Podcast] Transforming avatar to podcast presenter for project {project_id}")
|
||||
|
||||
@@ -141,17 +178,24 @@ async def make_avatar_presentable(
|
||||
"model": None, # Use default model
|
||||
}
|
||||
|
||||
result = edit_image(
|
||||
input_image_bytes=avatar_bytes,
|
||||
prompt=transformation_prompt,
|
||||
options=image_options,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Podcast] Calling edit_image with user_id={user_id}")
|
||||
try:
|
||||
result = edit_image(
|
||||
input_image_bytes=avatar_bytes,
|
||||
prompt=transformation_prompt,
|
||||
options=image_options,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Podcast] edit_image completed successfully - provider={result.provider}, model={result.model}")
|
||||
except Exception as edit_err:
|
||||
logger.error(f"[Podcast] edit_image failed: {edit_err}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image editing failed: {str(edit_err)}")
|
||||
|
||||
# Save transformed avatar
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
transformed_filename = f"presenter_transformed_{project_id or 'temp'}_{unique_id}.png"
|
||||
transformed_path = PODCAST_AVATARS_DIR / transformed_filename
|
||||
avatars_dir = _get_podcast_avatars_dir(user_id)
|
||||
transformed_path = avatars_dir / transformed_filename
|
||||
|
||||
with open(transformed_path, "wb") as f:
|
||||
f.write(result.image_bytes)
|
||||
@@ -194,6 +238,16 @@ async def make_avatar_presentable(
|
||||
"avatar_filename": transformed_filename,
|
||||
"message": "Avatar transformed into podcast presenter successfully"
|
||||
}
|
||||
except HTTPException:
|
||||
# Re-raise HTTP exceptions as-is
|
||||
raise
|
||||
except RuntimeError as rt_err:
|
||||
# Handle missing API keys or configuration errors
|
||||
logger.error(f"[Podcast] Avatar transformation configuration error: {rt_err}")
|
||||
raise HTTPException(
|
||||
status_code=503, # Service Unavailable
|
||||
detail=f"Image editing service not configured: {str(rt_err)}. Please contact support."
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast] Avatar transformation failed: {exc}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Avatar transformation failed: {str(exc)}")
|
||||
@@ -323,7 +377,8 @@ async def generate_podcast_presenters(
|
||||
# Save avatar
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
avatar_filename = f"presenter_{project_id or 'temp'}_{i+1}_{unique_id}.png"
|
||||
avatar_path = PODCAST_AVATARS_DIR / avatar_filename
|
||||
avatars_dir = _get_podcast_avatars_dir(user_id)
|
||||
avatar_path = avatars_dir / avatar_filename
|
||||
|
||||
with open(avatar_path, "wb") as f:
|
||||
f.write(result.image_bytes)
|
||||
|
||||
398
backend/api/podcast/handlers/broll.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
B-Roll Handlers
|
||||
|
||||
API endpoints for B-roll chart preview and video generation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from api.story_writer.task_manager import task_manager
|
||||
from api.podcast.utils import _resolve_podcast_media_file
|
||||
from services.podcast.broll_service import get_broll_service
|
||||
from utils.media_utils import resolve_media_path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/broll", tags=["B-Roll"])
|
||||
|
||||
|
||||
def _resolve_broll_background_image_path(background_image_url: str) -> str:
|
||||
"""Resolve background image URL/path to a local file path."""
|
||||
resolved = resolve_media_path(background_image_url)
|
||||
if not resolved:
|
||||
raise HTTPException(status_code=404, detail=f"Background image not found: {background_image_url}")
|
||||
return str(resolved)
|
||||
|
||||
|
||||
def _resolve_broll_avatar_video_path(avatar_video_url: Optional[str], user_id: str) -> Optional[str]:
|
||||
"""Resolve optional avatar video URL/path to a local file path."""
|
||||
if not avatar_video_url:
|
||||
return None
|
||||
|
||||
parsed = urlparse(avatar_video_url)
|
||||
path = parsed.path if parsed.scheme else avatar_video_url
|
||||
|
||||
if "/api/podcast/videos/" in path:
|
||||
filename = path.split("/api/podcast/videos/", 1)[1].split("?", 1)[0].strip()
|
||||
if not filename:
|
||||
raise HTTPException(status_code=400, detail="Invalid avatar video URL")
|
||||
return str(_resolve_podcast_media_file(filename, "video", user_id))
|
||||
|
||||
local_path = Path(path).expanduser().resolve()
|
||||
if local_path.exists() and local_path.is_file():
|
||||
return str(local_path)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"Unsupported avatar video URL format. "
|
||||
"Use /api/podcast/videos/{filename} or a valid local file path."
|
||||
),
|
||||
)
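Review note: the local-path fallback in `_resolve_broll_avatar_video_path` accepts any existing file that the resolved path points to. If the intent is to serve only files under a known media directory, a containment check along the following lines might be added. This is a sketch under that assumption. `resolve_within` and its `base_dir` argument are hypothetical and do not appear in the diff.

```python
from pathlib import Path
from typing import Optional


def resolve_within(base_dir: Path, candidate: str) -> Optional[Path]:
    """Resolve candidate against base_dir; return it only if it is a file inside base_dir."""
    base = base_dir.resolve()
    # Joining with an absolute candidate yields the candidate itself,
    # so absolute paths outside base are rejected by the check below.
    target = (base / candidate).resolve()
    try:
        target.relative_to(base)
    except ValueError:
        return None
    return target if target.is_file() else None
```

The handler could then raise the existing 400 error whenever this helper returns `None`.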
|
||||
|
||||
|
||||
def _execute_broll_scene_task(
|
||||
task_id: str,
|
||||
*,
|
||||
scene_id: str,
|
||||
key_insight: str,
|
||||
supporting_stat: str,
|
||||
chart_data: Optional[Dict[str, Any]],
|
||||
visual_cue: str,
|
||||
duration: float,
|
||||
background_img_path: str,
|
||||
avatar_video_path: Optional[str],
|
||||
):
|
||||
"""Background task for rendering a B-roll scene."""
|
||||
try:
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=10.0,
|
||||
message="Starting B-roll scene render...",
|
||||
)
|
||||
|
||||
broll_service = get_broll_service()
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=35.0,
|
||||
message="Composing scene layers and overlays...",
|
||||
)
|
||||
|
||||
video_path = broll_service.generate_scene_broll(
|
||||
scene_id=scene_id,
|
||||
key_insight=key_insight,
|
||||
supporting_stat=supporting_stat,
|
||||
chart_data=chart_data,
|
||||
visual_cue=visual_cue,
|
||||
duration=duration,
|
||||
background_img_path=background_img_path,
|
||||
avatar_video_path=avatar_video_path,
|
||||
)
|
||||
|
||||
filename = Path(video_path).name
|
||||
video_url = f"/api/podcast/broll/final/{filename}"
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="B-roll scene render completed.",
|
||||
result={
|
||||
"scene_id": scene_id,
|
||||
"broll_video_path": video_path,
|
||||
"broll_video_url": video_url,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Broll] Task {task_id} failed: {exc}")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=f"B-roll scene render failed: {str(exc)}",
|
||||
error_status=500,
|
||||
)
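Review note: `_execute_broll_scene_task` follows a common pattern for background work: mark the task as processing, run the job, then record either the result or the error. The sketch below reproduces that flow against an in-memory store so it can be run in isolation. `InMemoryTaskStore` and `run_tracked` are hypothetical stand-ins, and the real `task_manager.update_task_status` signature is assumed only as far as the calls shown in the diff.

```python
from typing import Any, Callable, Dict


class InMemoryTaskStore:
    """Hypothetical stand-in for task_manager; keeps status records in a dict."""

    def __init__(self) -> None:
        self.tasks: Dict[str, Dict[str, Any]] = {}

    def update(self, task_id: str, status: str, **fields: Any) -> None:
        record = self.tasks.setdefault(task_id, {})
        record.update(status=status, **fields)


def run_tracked(store: InMemoryTaskStore, task_id: str, work: Callable[[], Any]) -> None:
    """Run work() and record processing, completed, or failed status for task_id."""
    try:
        store.update(task_id, "processing", progress=10.0)
        result = work()
        store.update(task_id, "completed", progress=100.0, result=result)
    except Exception as exc:
        # Failures are recorded rather than raised, so a polling client can see them.
        store.update(task_id, "failed", error=str(exc))
```

Usage: `run_tracked(store, "t1", lambda: render_scene(...))`, where `render_scene` is whatever blocking job the task wraps.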
|
||||
|
||||
|
||||
class ChartPreviewRequest(BaseModel):
    """Request model for chart preview generation."""
    chart_data: Dict[str, Any] = Field(..., description="Chart data (labels, before/after, etc.)")
    chart_type: str = Field(
        default="bar_comparison",
        description="bar_comparison | bar_horizontal | line_trend | pie | stacked_bar | bullet"
    )
    title: str = Field(default="", description="Chart title")
    subtitle: Optional[str] = Field(default="", description="Optional subtitle at bottom")


class ChartPreviewResponse(BaseModel):
    """Response for chart preview."""
    preview_url: str
    chart_id: str


class BrollSceneRequest(BaseModel):
    """Request for generating B-roll video for a scene."""
    scene_id: str
    key_insight: str
    supporting_stat: str
    chart_data: Optional[Dict[str, Any]] = None
    visual_cue: str = Field(default="bar_comparison", description="bar_comparison | bar_horizontal | line_trend | pie | stacked_bar | bullet_points | full_avatar")
    duration: float = Field(default=10.0, ge=3.0, le=60.0)
    background_image_url: str
    avatar_video_url: Optional[str] = None


class BrollSceneResponse(BaseModel):
    """Response for B-roll scene generation."""
    scene_id: str
    broll_video_url: str = ""
    broll_video_path: str = ""
    task_id: Optional[str] = None
    status: str = "completed"
    message: Optional[str] = None


class BrollComposeRequest(BaseModel):
    """Request for composing multiple B-roll videos."""
    scene_video_paths: List[str]
    output_filename: str = "final_broll.mp4"
    fade_dur: float = Field(default=0.5, ge=0.0, le=2.0)
    fps: int = Field(default=24, ge=12, le=60)


class BrollComposeResponse(BaseModel):
    """Response for B-roll composition."""
    final_video_url: str
    final_video_path: str


@router.post("/preview/chart", response_model=ChartPreviewResponse)
async def generate_chart_preview(
    request: ChartPreviewRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """
    Generate a chart PNG preview (static image for Write phase).

    This endpoint is called from the Write phase to show users chart previews
    before they commit to B-roll video generation.
    """
    user_id = require_authenticated_user(current_user)

    # Debug logging
    logger.warning(f"[Broll] Chart preview request: type={request.chart_type}, title={request.title}, chart_data keys={list(request.chart_data.keys())}, user_id={user_id}")

    try:
        broll_service = get_broll_service(user_id=user_id)
        chart_id = uuid.uuid4().hex[:8]

        preview_path = broll_service.generate_chart_preview(
            chart_data=request.chart_data,
            chart_type=request.chart_type,
            title=request.title,
            subtitle=request.subtitle or "",
            chart_id=chart_id,
        )

        # If chart generation failed (empty path), return a placeholder instead of 500
        if not preview_path:
            # Return a fallback response so frontend doesn't crash
            logger.warning(f"[Broll] Chart preview skipped - invalid data for type: {request.chart_type}")
            return ChartPreviewResponse(
                preview_url="",
                chart_id=chart_id,
            )

        preview_filename = Path(preview_path).name
        preview_url = f"/api/podcast/broll/preview/{chart_id}/{preview_filename}"

        logger.warning(f"[Broll] Chart preview generated: chart_id={chart_id}, path={preview_path}, url={preview_url}")

        return ChartPreviewResponse(
            preview_url=preview_url,
            chart_id=chart_id,
        )

    except Exception as e:
        logger.error(f"[Broll] Chart preview generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Chart preview failed: {str(e)}")


@router.post("/render/broll-scene", response_model=BrollSceneResponse)
async def generate_broll_scene(
    request: BrollSceneRequest,
    background_tasks: BackgroundTasks,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """
    Generate a B-roll video for a single scene.

    This creates a programmatic video with:
    - Background image with Ken Burns effect
    - Chart overlay (if chart_data provided)
    - Avatar circle in corner (if avatar_video_url provided)
    - Insight card at bottom

    Returns a task_id for polling since video generation can take time.
    """
    user_id = require_authenticated_user(current_user)

    try:
        # Validate visual_cue
        valid_cues = ["bar_comparison", "bar_chart_comparison", "bar_horizontal", "line_trend", "pie", "stacked_bar", "bullet_points", "full_avatar"]
        if request.visual_cue not in valid_cues:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid visual_cue. Must be one of: {valid_cues}"
            )

        background_img_path = _resolve_broll_background_image_path(request.background_image_url)
        avatar_video_path = _resolve_broll_avatar_video_path(request.avatar_video_url, user_id)

        logger.info(f"[Broll] B-roll scene request for scene: {request.scene_id}")

        # Scene rendering can be expensive, so use task manager/background execution.
        task_id = task_manager.create_task(
            "podcast_broll_scene_generation",
            metadata={"owner_user_id": user_id, "scene_id": request.scene_id},
        )

        background_tasks.add_task(
            _execute_broll_scene_task,
            task_id=task_id,
            scene_id=request.scene_id,
            key_insight=request.key_insight,
            supporting_stat=request.supporting_stat,
            chart_data=request.chart_data,
            visual_cue=request.visual_cue,
            duration=request.duration,
            background_img_path=background_img_path,
            avatar_video_path=avatar_video_path,
        )

        return BrollSceneResponse(
            scene_id=request.scene_id,
            task_id=task_id,
            status="pending",
            # f-string so the actual task_id is interpolated into the polling path
            message=f"B-roll scene render started. Poll /api/podcast/task/{task_id}/status for progress.",
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Broll] B-roll scene generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"B-roll generation failed: {str(e)}")


@router.post("/render/broll-compose", response_model=BrollComposeResponse)
async def compose_broll_videos(
    request: BrollComposeRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """
    Compose multiple B-roll scene videos into a final video.

    Applies crossfade transitions between scenes.
    """
    user_id = require_authenticated_user(current_user)

    try:
        broll_service = get_broll_service()

        final_path = broll_service.compose_final_video(
            video_paths=request.scene_video_paths,
            output_filename=request.output_filename,
            fade_dur=request.fade_dur,
            fps=request.fps,
        )

        # Use Path(...).name (as in the scene task) so this is not tied to "/" separators
        final_filename = Path(final_path).name
        final_url = f"/api/podcast/broll/final/{final_filename}"

        return BrollComposeResponse(
            final_video_url=final_url,
            final_video_path=final_path,
        )

    except Exception as e:
        logger.error(f"[Broll] Video composition failed: {e}")
        raise HTTPException(status_code=500, detail=f"Video composition failed: {str(e)}")


@router.get("/preview/{chart_id}/{filename}")
async def serve_chart_preview(
    chart_id: str,
    filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """
    Serve chart preview PNG files.

    Uses authentication via Authorization header or token query parameter,
    matching the pattern used by /api/podcast/images/ for browser <img> tags.
    """
    from api.podcast.constants import get_podcast_media_dir
    user_id = require_authenticated_user(current_user)

    # Validate filename to prevent directory traversal
    if ".." in filename or "/" in filename or "\\" in filename:
        raise HTTPException(status_code=400, detail="Invalid filename")

    logger.warning(f"[Broll] serve_chart_preview: chart_id={chart_id}, filename={filename}, user_id={user_id}")

    charts_dir = get_podcast_media_dir("chart", user_id)
    file_path = charts_dir / filename

    logger.warning(f"[Broll] serve_chart_preview: resolved path={file_path}, exists={file_path.exists()}")

    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Chart preview not found")

    # Security: ensure resolved path is within charts_dir.
    # Compare path components rather than string prefixes, so a sibling
    # directory whose name merely starts with the charts_dir name is rejected.
    if charts_dir.resolve() not in file_path.resolve().parents:
        raise HTTPException(status_code=403, detail="Access denied")

    return FileResponse(
        path=str(file_path),
        media_type="image/png",
        filename=filename,
    )


@router.get("/final/{filename}")
async def serve_final_broll(
    filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """Serve final composed B-roll video files."""
    user_id = require_authenticated_user(current_user)

    broll_service = get_broll_service()
    file_path = broll_service.output_dir / filename

    if not file_path.exists():
        raise HTTPException(status_code=404, detail="Video not found")

    return FileResponse(
        path=str(file_path),
        media_type="video/mp4",
        filename=filename,
    )


@router.get("/health")
async def broll_health():
    """Health check for B-roll service."""
    return {"status": "ok", "service": "broll"}
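
The file-serving routes in this file build a path from a user-supplied `filename` and check that it stays inside a media directory. The following is a minimal sketch of a component-wise containment check, compared with a plain string-prefix check. The helper names `is_within` and `prefix_check` and the `/tmp/media/...` paths are hypothetical and are not part of this PR.

```python
from pathlib import Path


def is_within(base_dir: Path, candidate: Path) -> bool:
    """Return True only if candidate resolves to a location inside base_dir.

    Hypothetical helper for illustration; not part of the PR.
    """
    base = base_dir.resolve()
    target = candidate.resolve()
    # Comparing whole path components avoids the prefix pitfall where
    # ".../charts_old" starts with ".../charts".
    return base in target.parents


def prefix_check(base_dir: Path, candidate: Path) -> bool:
    """String-prefix variant, shown only for comparison."""
    return str(candidate.resolve()).startswith(str(base_dir.resolve()))


# Illustrative paths; they do not need to exist for resolve() to work.
base = Path("/tmp/media/charts")
inside = base / "chart_ab12cd34.png"
sibling = Path("/tmp/media/charts_old/secret.png")
escaped = base / ".." / "other.png"
```

With these inputs the prefix check accepts the sibling directory, while the component-wise check rejects it. Both checks reject the `..` escape once the path is resolved.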
@@ -17,7 +17,7 @@ from api.story_writer.utils.auth import require_authenticated_user
 from services.llm_providers.main_image_generation import generate_image, generate_character_image
 from utils.asset_tracker import save_asset_to_library
 from loguru import logger
-from ..constants import PODCAST_IMAGES_DIR
+from ..constants import get_podcast_media_dir
 from ..models import PodcastImageRequest, PodcastImageResponse
 
 router = APIRouter()
@@ -69,7 +69,7 @@ async def generate_podcast_scene_image(
|
||||
from ..utils import load_podcast_image_bytes
|
||||
try:
|
||||
logger.info(f"[Podcast] Attempting to load base avatar from: {request.base_avatar_url}")
|
||||
-base_avatar_bytes = load_podcast_image_bytes(request.base_avatar_url)
+base_avatar_bytes = load_podcast_image_bytes(request.base_avatar_url, user_id=user_id)
|
||||
logger.info(f"[Podcast] ✅ Successfully loaded base avatar ({len(base_avatar_bytes)} bytes) for scene {request.scene_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] ❌ Failed to load base avatar from {request.base_avatar_url}: {e}", exc_info=True)
|
||||
@@ -377,14 +377,14 @@ async def generate_podcast_scene_image(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
-# Save image to podcast images directory
-PODCAST_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
+# Save image to podcast images directory (workspace-aware)
+images_dir = get_podcast_media_dir("image", user_id, ensure_exists=True)
|
||||
|
||||
# Generate filename
|
||||
clean_title = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in request.scene_title[:30])
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
image_filename = f"scene_{request.scene_id}_{clean_title}_{unique_id}.png"
|
||||
-image_path = PODCAST_IMAGES_DIR / image_filename
+image_path = images_dir / image_filename
|
||||
|
||||
# Save image
|
||||
with open(image_path, "wb") as f:
|
||||
@@ -470,16 +470,17 @@ async def serve_podcast_image(
|
||||
Query parameter is useful for HTML elements like <img> that cannot send custom headers.
|
||||
Supports subdirectories like avatars/
|
||||
"""
|
||||
-require_authenticated_user(current_user)
+user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Security check: ensure path doesn't contain path traversal or absolute paths
|
||||
if ".." in path or path.startswith("/"):
|
||||
raise HTTPException(status_code=400, detail="Invalid path")
|
||||
|
||||
-image_path = (PODCAST_IMAGES_DIR / path).resolve()
+images_dir = get_podcast_media_dir("image", user_id)
+image_path = (images_dir / path).resolve()
|
||||
|
||||
-# Security check: ensure resolved path is within PODCAST_IMAGES_DIR
-if not str(image_path).startswith(str(PODCAST_IMAGES_DIR)):
+# Security check: ensure resolved path is within images_dir
+if not str(image_path).startswith(str(images_dir)):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
if not image_path.exists():
|
||||
|
||||
@@ -11,6 +11,7 @@ from typing import Optional, Dict, Any
 from services.database import get_db
 from middleware.auth_middleware import get_current_user
 from services.podcast_service import PodcastService
+from loguru import logger
 from ..models import (
     PodcastProjectResponse,
     CreateProjectRequest,
@@ -106,25 +107,57 @@ async def update_project(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Update a podcast project state."""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
logger.error(f"[Podcast] update_project: No user_id found in current_user: {current_user}")
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
# Get only field names being updated (not full data to avoid console flooding)
|
||||
request_dict = request.model_dump(exclude_none=True)
|
||||
updated_fields = list(request_dict.keys())
|
||||
|
||||
logger.warning(f"[Podcast] ===== UPDATE_PROJECT_START =====")
|
||||
logger.warning(f"[Podcast] project_id={project_id}, user_id={user_id}, fields={updated_fields}")
|
||||
|
||||
service = PodcastService(db)
|
||||
|
||||
-# Convert request to dict, excluding None values
-updates = request.model_dump(exclude_unset=True)
|
||||
# Check if project exists; if not, create it (upsert behavior for resilience)
|
||||
existing = service.get_project(user_id, project_id)
|
||||
if not existing:
|
||||
logger.warning(f"[Podcast] Project {project_id} not found for user {user_id}, creating new project with default values")
|
||||
# Try to create the project - this handles cases where create succeeded but wasn't found later
|
||||
# (can happen with user_id mismatch or after session refresh)
|
||||
try:
|
||||
project = service.create_project(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
idea="Untitled Podcast",
|
||||
status="scripting",
|
||||
duration=10,
|
||||
speakers=1,
|
||||
budget_cap=0.0,
|
||||
)
|
||||
except Exception as create_err:
|
||||
logger.error(f"[Podcast] Failed to create project {project_id}: {create_err}")
|
||||
raise HTTPException(status_code=404, detail=f"Project {project_id} not found and could not create: {create_err}")
|
||||
else:
|
||||
# Convert request to dict, keeping only the fields the client explicitly set
|
||||
updates = request.model_dump(exclude_unset=True)
|
||||
project = service.update_project(user_id, project_id, **updates)
|
||||
|
||||
-project = service.update_project(user_id, project_id, **updates)
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[Podcast] ===== UPDATE_PROJECT_END (took {duration_ms}ms) =====")
|
||||
|
||||
return PodcastProjectResponse.model_validate(project)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"[Podcast] ===== UPDATE_PROJECT_ERROR (took {duration_ms}ms): {str(e)} =====")
|
||||
raise HTTPException(status_code=500, detail=f"Error updating project: {str(e)}")
|
||||
|
||||
|
||||
|
||||
@@ -9,37 +9,142 @@ from typing import Dict, Any, List
|
||||
from types import SimpleNamespace
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.database import get_db
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
from loguru import logger
|
||||
from ..cost_estimator import estimate_podcast_cost
|
||||
from ..models import (
|
||||
PodcastExaResearchRequest,
|
||||
PodcastExaResearchResponse,
|
||||
PodcastExaSource,
|
||||
PodcastExaConfig,
|
||||
PodcastResearchInsight,
|
||||
PodcastResearchOutput,
|
||||
PodcastCostEst,
|
||||
PodcastCostBreakdownItem,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _estimate_tokens(text: str) -> int:
    if not text:
        return 0
    return max(1, len(text) // 4)


def _get_price_from_catalog(
    pricing_service: PricingService,
    provider: APIProvider,
    model_name: str,
    key: str,
    fallback: float = 0.0,
) -> float:
    try:
        pricing = pricing_service.get_pricing_for_provider_model(provider, model_name) or {}
        value = pricing.get(key)
        return float(value or fallback)
    except Exception:
        return fallback
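
One property of the `float(value or fallback)` pattern in the helper above is that any falsy catalog value, including a price of `0`, is treated as missing. The following is a minimal sketch of that behavior next to a variant that falls back only on `None`. The names `price_or_fallback` and `price_none_only` are hypothetical, and the plain dict stands in for whatever `get_pricing_for_provider_model` returns.

```python
from typing import Any, Dict


def price_or_fallback(pricing: Dict[str, Any], key: str, fallback: float = 0.0) -> float:
    """Mirrors the `float(value or fallback)` pattern used in the helper."""
    value = pricing.get(key)
    return float(value or fallback)


def price_none_only(pricing: Dict[str, Any], key: str, fallback: float = 0.0) -> float:
    """Hypothetical variant: fall back only when the key is absent or None."""
    value = pricing.get(key)
    return float(fallback if value is None else value)


# Illustrative catalog entry with an explicit zero price (not real pricing data).
catalog = {"cost_per_request": 0.0}
```

Whether a zero price should win over the fallback depends on how the catalog is populated, so this only illustrates the difference and is not a recommended change.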


def _build_research_cost_estimate(
    request: PodcastExaResearchRequest,
    raw_content: str,
    sources_count: int,
    provider_result: Dict[str, Any],
    user_id: str = "default",
) -> PodcastCostEst:
    # Fallback defaults mirror current catalog defaults.
    exa_per_request = 0.005
    gemini_in_token = 0.00000015
    gemini_out_token = 0.0000006

    try:
        from services.database import get_session_for_user
        db = get_session_for_user(user_id)
        if db:
            try:
                pricing_service = PricingService(db)
                exa_per_request = _get_price_from_catalog(
                    pricing_service, APIProvider.EXA, "exa-search", "cost_per_request", exa_per_request
                )
                gemini_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.GEMINI, "gemini-2.5-flash") or {}
                gemini_in_token = float(gemini_pricing.get("cost_per_input_token") or gemini_in_token)
                gemini_out_token = float(gemini_pricing.get("cost_per_output_token") or gemini_out_token)
            finally:
                db.close()
    except Exception as pricing_err:
        logger.warning(f"[Podcast Research] Failed loading pricing catalog; using defaults: {pricing_err}")

    query_count = max(1, len(request.queries or []))
    source_count = max(1, sources_count)

    analyze_tokens = _estimate_tokens(request.topic) + sum(_estimate_tokens(q) for q in request.queries or [])
    gather_search_calls = max(1, query_count)
    gather_cost = gather_search_calls * exa_per_request

    write_input_tokens = _estimate_tokens(raw_content) + _estimate_tokens(request.topic) + (query_count * 40)
    write_output_tokens = max(500, int(write_input_tokens * 0.22))
    write_cost = (write_input_tokens * gemini_in_token) + (write_output_tokens * gemini_out_token)

    # "Produce" is shaping the final API payload and mapped artifacts.
    produce_tokens = max(120, source_count * 30)
    produce_cost = (produce_tokens * gemini_in_token) + (produce_tokens * 0.5 * gemini_out_token)

    analyze_cost = analyze_tokens * gemini_in_token

    provider_total = 0.0
    if isinstance(provider_result, dict):
        provider_total = float((provider_result.get("cost") or {}).get("total") or 0.0)

    # Prefer transparent estimate built from catalog + usage. If provider reports a higher measured value, keep it.
    estimated_total = analyze_cost + gather_cost + write_cost + produce_cost
    scale = (provider_total / estimated_total) if estimated_total > 0 and provider_total > estimated_total else 1.0

    breakdown = [
        PodcastCostBreakdownItem(phase="Analyze", cost=round(analyze_cost * scale, 6)),
        PodcastCostBreakdownItem(phase="Gather", cost=round(gather_cost * scale, 6)),
        PodcastCostBreakdownItem(phase="Write", cost=round(write_cost * scale, 6)),
        PodcastCostBreakdownItem(phase="Produce", cost=round(produce_cost * scale, 6)),
    ]
    total = round(sum(item.cost for item in breakdown), 6)

    return PodcastCostEst(
        total=total,
        breakdown=breakdown,
        currency="USD",
        last_updated=datetime.now(timezone.utc),
    )
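
The estimate above combines a four-characters-per-token heuristic with a scale factor that raises the per-phase costs only when the provider reports a higher total. The following is a minimal standalone sketch of those two pieces. The names `estimate_tokens` and `scale_breakdown` and the phase amounts are hypothetical, chosen for illustration, and are not real catalog prices.

```python
from typing import Dict


def estimate_tokens(text: str) -> int:
    """Rough heuristic mirroring the diff: about four characters per token."""
    if not text:
        return 0
    return max(1, len(text) // 4)


def scale_breakdown(phase_costs: Dict[str, float], provider_total: float) -> Dict[str, float]:
    """Scale per-phase estimates up to a higher provider-reported total.

    When the provider total is not higher than the estimate, the
    estimated costs are kept unchanged (scale factor 1.0).
    """
    estimated_total = sum(phase_costs.values())
    if estimated_total > 0 and provider_total > estimated_total:
        scale = provider_total / estimated_total
    else:
        scale = 1.0
    return {phase: round(cost * scale, 6) for phase, cost in phase_costs.items()}


# Illustrative numbers only.
phases = {"Analyze": 0.001, "Gather": 0.010, "Write": 0.004, "Produce": 0.001}
scaled = scale_breakdown(phases, provider_total=0.032)
unscaled = scale_breakdown(phases, provider_total=0.010)
```

Because every phase is multiplied by the same factor, the relative split between phases is preserved while the rounded total tracks the provider-reported figure.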
|
||||
|
||||
|
||||
@router.post("/research/exa", response_model=PodcastExaResearchResponse)
|
||||
async def podcast_research_exa(
|
||||
request: PodcastExaResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Run podcast research via Exa and then use LLM to extract deep insights.
|
||||
Uses Podcast Bible and Analysis context for hyper-personalization.
|
||||
"""
|
||||
start_time = time.time()
|
||||
user_id = require_authenticated_user(current_user)
|
||||
logger.warning(f"[Podcast Research] ========== REQUEST START ==========")
|
||||
logger.warning(f"[Podcast Research] User: {user_id}, Topic: {request.topic[:80]}...")
|
||||
logger.warning(f"[Podcast Research] Queries count: {len(request.queries) if request.queries else 0}")
|
||||
|
||||
# Log only essential info, not full request data
|
||||
logger.warning(f"[Podcast Research] ===== RESEARCH_START =====")
|
||||
logger.warning(f"[Podcast Research] user={user_id}, topic='{request.topic[:50]}...', queries={len(request.queries) if request.queries else 0}")
|
||||
|
||||
|
||||
queries = [q.strip() for q in request.queries if q and q.strip()]
|
||||
@@ -97,6 +202,26 @@ Listener CTA: {request.analysis.get('listener_cta', 'N/A')}
|
||||
interests = ", ".join(audience_dna.get("interests", []))
|
||||
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
|
||||
|
||||
# Preflight subscription check for Exa
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'exa', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[Podcast Research] Preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast Research] Preflight check failed: {e}")
|
||||
|
||||
try:
|
||||
# 1. RUN EXA SEARCH
|
||||
logger.warning(f"[Podcast Research] Calling Exa search with topic: {request.topic[:100]}...")
|
||||
@@ -119,6 +244,9 @@ Listener CTA: {request.analysis.get('listener_cta', 'N/A')}
|
||||
|
||||
summary = ""
|
||||
key_insights = []
|
||||
expert_quotes = []
|
||||
listener_cta_suggestions = []
|
||||
mapped_angles = []
|
||||
|
||||
if raw_content and sources:
|
||||
logger.warning(f"[Podcast Research] Extracting insights from {len(sources)} sources for user {user_id}")
|
||||
@@ -159,43 +287,50 @@ As a podcast research expert, analyze this data and create content that will:
|
||||
4. Include a compelling call-to-action for listeners
|
||||
|
||||
REQUIRED OUTPUT (JSON):
|
||||
-=======================
+======================
{{
-"summary": "2-3 paragraph comprehensive summary in Markdown. Start with a hook that matches the episode intro. Include specific data points, expert quotes, and trends.",
+"summary": "2-3 paragraph comprehensive summary in Markdown. Start with a hook that matches the episode intro.",
"key_insights": [
{{
-"title": "Catchy, engaging title for this insight",
-"content": "3-4 sentences with specific facts, quotes, or data. Write in a conversational tone suitable for a podcast host to discuss.",
-"source_indices": [1, 2, 3],
-"podcast_talking_points": ["Point 1 host can expand on", "Counter-point or follow-up", "Question to ask guest"]
+"title": "Insight title",
+"content": "3-4 sentences with specific facts, quotes, or data for podcast host.",
+"source_indices": [1, 2],
+"podcast_talking_points": ["Point host can expand on", "Counter-point"]
}}
],
"expert_quotes": [
{{
-"quote": "Direct quote from source",
+"quote": "Direct quote from source text",
"source_index": 1,
"context": "Why this quote matters for the podcast"
}}
],
-"listener_cta_suggestions": ["Specific action listener can take", "Resource to share", "Next episode preview"]
+"listener_cta_suggestions": ["Action listener can take", "Resource to share", "Next episode preview"],
|
||||
"mapped_angles": [
|
||||
{{
|
||||
"title": "Content angle title",
|
||||
"why": "Why compelling for audience",
|
||||
"mapped_fact_ids": [1, 2]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
IMPORTANT: You must include ALL fields above with valid data. expert_quotes, listener_cta_suggestions, and mapped_angles must have content - do NOT leave them empty!
|
||||
|
||||
QUALITY STANDARDS:
|
||||
-==================
-- INSIGHTS MUST BE DEEP, not superficial - avoid generic statements
-- Include SPECIFIC DATA POINTS, percentages, statistics when available
-- Extract EXPERT QUOTES that hosts can reference
-- Identify GAPS in the research where more depth is needed
-- Make content naturally flow into the planned episode hook and CTA
-- Write in a CONVERSATIONAL tone - how a host would actually speak
-- Flag any CONTROVERSIAL or debatable claims for host to address
+=================
+- Include at least 2 expert_quotes with source_index
+- Include at least 2 listener_cta_suggestions
+- Include at least 2 mapped_angles
+- Include specific data points, percentages, statistics
+- Write in conversational tone
|
||||
"""
|
||||
try:
|
||||
-logger.warning(f"[Podcast Research] Calling LLM for insight extraction...")
+logger.warning(f"[Podcast Research] Calling LLM with json_struct...")
|
||||
llm_response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
-json_struct=None,
+json_struct=PodcastResearchOutput.model_json_schema(),
|
||||
preferred_provider=None,
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
@@ -231,13 +366,22 @@ QUALITY STANDARDS:
|
||||
try:
|
||||
summary = data.get("summary", "")
|
||||
key_insights = [PodcastResearchInsight(**insight) for insight in data.get("key_insights", [])]
|
||||
expert_quotes = data.get("expert_quotes", [])
|
||||
listener_cta_suggestions = data.get("listener_cta_suggestions", [])
|
||||
mapped_angles = data.get("mapped_angles", [])
|
||||
except Exception as insight_err:
|
||||
logger.warning(f"[Podcast Research] Failed to parse insights: {insight_err}. Data keys: {list(data.keys()) if isinstance(data, dict) else 'not a dict'}")
|
||||
summary = data.get("summary", "") if isinstance(data, dict) else ""
|
||||
key_insights = []
|
||||
expert_quotes = data.get("expert_quotes", []) if isinstance(data, dict) else []
|
||||
listener_cta_suggestions = data.get("listener_cta_suggestions", []) if isinstance(data, dict) else []
|
||||
mapped_angles = data.get("mapped_angles", []) if isinstance(data, dict) else []
|
||||
else:
|
||||
summary = ""
|
||||
key_insights = []
|
||||
expert_quotes = []
|
||||
listener_cta_suggestions = []
|
||||
mapped_angles = []
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -289,14 +433,41 @@ QUALITY STANDARDS:
|
||||
"credibility_score": src.get("credibility_score"),
|
||||
}))
|
||||
|
||||
duration_minutes = 10
|
||||
speakers = 1
|
||||
if request.analysis:
|
||||
duration_minutes = int(request.analysis.get("duration", 10) or 10)
|
||||
speakers = int(request.analysis.get("speakers", 1) or 1)
|
||||
|
||||
estimate = estimate_podcast_cost(
|
||||
db=db,
|
||||
duration_minutes=duration_minutes,
|
||||
speakers=speakers,
|
||||
query_count=len(queries),
|
||||
include_avatar_phase=True,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[Podcast Research] ===== RESEARCH_END (took {duration_ms}ms) =====")
|
||||
logger.warning(f"[Podcast Research] sources={len(sources_payload)}, insights={len(key_insights)}, summary_len={len(summary)}")
|
||||
|
||||
return PodcastExaResearchResponse(
|
||||
sources=sources_payload,
|
||||
search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries,
|
||||
summary=summary,
|
||||
key_insights=key_insights,
|
||||
cost=result.get("cost") if isinstance(result, dict) else None,
|
||||
cost_est=_build_research_cost_estimate(
|
||||
request=request,
|
||||
raw_content=raw_content,
|
||||
sources_count=len(sources_payload),
|
||||
provider_result=result if isinstance(result, dict) else {},
|
||||
user_id=user_id,
|
||||
),
|
||||
search_type=result.get("search_type") if isinstance(result, dict) else None,
|
||||
provider=result.get("provider", "exa") if isinstance(result, dict) else "exa",
|
||||
content=raw_content,
|
||||
mapped_angles=mapped_angles,
|
||||
expert_quotes=expert_quotes,
|
||||
listener_cta_suggestions=listener_cta_suggestions,
|
||||
estimate=estimate,
|
||||
)
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
@@ -23,6 +25,8 @@ from ..models import (
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
MAX_TTS_CHARS_PER_REQUEST = 10_000
|
||||
TARGET_TTS_CHARS_PER_SCENE = 8_500
|
||||
|
||||
|
||||
class SceneApprovalRequest(BaseModel):
|
||||
@@ -57,31 +61,46 @@ async def generate_podcast_script(
|
||||
Generate a podcast script outline (scenes + lines) using podcast-oriented prompting.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
-logger.warning(f"[ScriptGen] ========== SCRIPT GENERATION START ==========")
-logger.warning(f"[ScriptGen] Topic: {request.idea[:60]}...")
-logger.warning(f"[ScriptGen] Duration: {request.duration_minutes} min, Speakers: {request.speakers}")
-logger.warning(f"[ScriptGen] Has research: {bool(request.research)}, Has bible: {bool(request.bible)}, Has analysis: {bool(request.analysis)}")
+start_time = time.time()
+logger.warning(f"[ScriptGen] ===== SCRIPT_GEN_START =====")
+logger.warning(f"[ScriptGen] user={user_id}, topic='{request.idea[:50]}...', duration={request.duration_minutes}min, speakers={request.speakers}")
+podcast_mode = (request.podcast_mode or "video_only").strip().lower()
+logger.warning(f"[ScriptGen] research={bool(request.research)}, bible={bool(request.bible)}, analysis={bool(request.analysis)}, mode={podcast_mode}")
|
||||
research_fact_cards = request.research.get("factCards", []) if request.research else []
|
||||
|
||||
# Build comprehensive research context for higher-quality scripts
|
||||
research_context = ""
|
||||
if request.research:
|
||||
try:
|
||||
key_insights = request.research.get("keyword_analysis", {}).get("key_insights") or []
|
||||
fact_cards = request.research.get("factCards", []) or []
|
||||
fact_cards = research_fact_cards or []
|
||||
mapped_angles = request.research.get("mappedAngles", []) or []
|
||||
sources = request.research.get("sources", []) or []
|
||||
|
||||
top_facts = [f.get("quote", "") for f in fact_cards[:5] if f.get("quote")]
|
||||
top_facts = [
|
||||
f"[{f.get('id') or f'fact_{idx + 1}'}] {f.get('quote', '')}"
|
||||
for idx, f in enumerate(fact_cards[:10])
|
||||
if f.get("quote")
|
||||
]
|
||||
angles_summary = [
|
||||
f"{a.get('title', '')}: {a.get('why', '')}" for a in mapped_angles[:3] if a.get("title") or a.get("why")
|
||||
]
|
||||
top_sources = [s.get("url") for s in sources[:3] if s.get("url")]
|
||||
numeric_signals = []
|
||||
for f in fact_cards[:12]:
|
||||
quote = (f.get("quote") or "").strip()
|
||||
if any(ch.isdigit() for ch in quote):
|
||||
numeric_signals.append(quote[:180])
|
||||
if len(numeric_signals) >= 5:
|
||||
break
|
||||
|
||||
research_parts = []
|
||||
if key_insights:
|
||||
research_parts.append(f"Key Insights: {', '.join(key_insights[:5])}")
|
||||
if top_facts:
|
||||
research_parts.append(f"Key Facts: {', '.join(top_facts)}")
|
||||
if numeric_signals:
|
||||
research_parts.append(f"Numeric Signals (prefer for chart scenes): {' | '.join(numeric_signals)}")
|
||||
if angles_summary:
|
||||
research_parts.append(f"Research Angles: {' | '.join(angles_summary)}")
|
||||
if top_sources:
|
||||
@@ -92,6 +111,53 @@ async def generate_podcast_script(
|
||||
logger.warning(f"Failed to parse research context: {exc}")
|
||||
research_context = ""
|
||||
|
||||
def _normalize_fact_ids(value: Any) -> Optional[list[str]]:
|
||||
if not value:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
cleaned = [str(v).strip() for v in value if str(v).strip()]
|
||||
return cleaned or None
|
||||
if isinstance(value, str) and value.strip():
|
||||
return [value.strip()]
|
||||
return None
|
||||
|
||||
def _default_chart_data(scene_title: str) -> Dict[str, Any]:
|
||||
numeric_pairs: list[tuple[str, float]] = []
|
||||
for fact in research_fact_cards[:12]:
|
||||
quote = (fact.get("quote") or "").strip()
|
||||
if not quote:
|
||||
continue
|
||||
nums = re.findall(r"\d+(?:\.\d+)?", quote.replace(",", ""))
|
||||
if not nums:
|
||||
continue
|
||||
label = quote[:48] + ("…" if len(quote) > 48 else "")
|
||||
try:
|
||||
numeric_pairs.append((label, float(nums[0])))
|
||||
except ValueError:
|
||||
continue
|
||||
if len(numeric_pairs) >= 5:
|
||||
break
|
||||
|
||||
if numeric_pairs:
|
||||
labels = [p[0] for p in numeric_pairs]
|
||||
values = [p[1] for p in numeric_pairs]
|
||||
sources = [f.get("url", f.get("source", "")) for f in research_fact_cards[:12] if f.get("url") or f.get("source")]
|
||||
return {
|
||||
"type": "bar_comparison",
|
||||
"title": scene_title,
|
||||
"labels": labels,
|
||||
"values": values,
|
||||
"takeaway": "Data points sourced from research facts used in this scene.",
|
||||
"source": sources[0] if sources else "",
|
||||
}
|
||||
|
||||
return {
|
||||
"type": "bullet_points",
|
||||
"title": scene_title,
|
||||
"bullet_points": ["Key point 1", "Key point 2", "Key point 3"],
|
||||
"takeaway": "Narration summary for this scene.",
|
||||
}
|
||||
|
||||
# Extract Podcast Bible context for hyper-personalization
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
@@ -122,25 +188,62 @@ async def generate_podcast_script(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
mode_instructions = ""
|
||||
if podcast_mode == "audio_only":
|
||||
mode_instructions = f"""
|
||||
AUDIO-ONLY MODE RULES (CRITICAL):
|
||||
- This is an audio-only episode. Do NOT include avatar/image/camera instructions.
|
||||
- Keep each scene's total dialogue under {TARGET_TTS_CHARS_PER_SCENE} chars to stay below TTS max request size ({MAX_TTS_CHARS_PER_REQUEST}).
|
||||
- For every scene include chart_data so B-roll charts can be generated while narration plays.
|
||||
- Build script STRICTLY from RESEARCH context and cite fact linkage via usedFactIds.
|
||||
- If evidence is weak, say uncertainty explicitly rather than inventing facts.
|
||||
- Add natural TTS pacing in dialogue with markers like [pause:300ms], [pause:700ms], [emote:curious], [emote:serious].
|
||||
"""
|
||||
elif podcast_mode == "audio_video":
|
||||
mode_instructions = """
|
||||
AUDIO+VIDEO MODE:
|
||||
- Include rich narration that works for both listening and visual storytelling.
|
||||
- Use a balanced pace suitable for TTS and scene visuals.
|
||||
"""
|
||||
else:
|
||||
mode_instructions = """
|
||||
VIDEO-ONLY MODE:
|
||||
- Prioritize visual rhythm and concise narration per scene.
|
||||
"""
|
||||
|
||||
prompt = f"""Create a podcast script with scenes and dialogue.
|
||||
|
||||
{f"BIBLE: {bible_context[:1500]}" if bible_context else ""}
|
||||
{f"{analysis_context}" if analysis_context else ""}
|
||||
{f"{outline_context}" if outline_context else ""}
|
||||
{f"RESEARCH: {research_context[:1200]}" if research_context else ""}
|
||||
{f"RESEARCH: {research_context[:2500]}" if research_context else ""}
|
||||
{mode_instructions}
|
||||
|
||||
Topic: "{request.idea}"
|
||||
Duration: {request.duration_minutes} min | Speakers: {request.speakers}
|
||||
Podcast mode: {podcast_mode}
|
||||
|
||||
Return JSON with scenes array. Each scene:
|
||||
- id: string
|
||||
- title: short title (<=50 chars)
|
||||
- duration: seconds (total/5)
|
||||
- emotion: neutral|happy|excited|serious|curious|confident
|
||||
- lines: array of {{speaker, text, emphasis}}
|
||||
- lines: array of {{speaker, text, emphasis, usedFactIds, ttsHints}}
|
||||
- Use 2-4 LINES PER SCENE (shorter script = lower TTS costs)
|
||||
- Each line: 1-3 sentences, conversational
|
||||
- usedFactIds: include related fact ids when research facts are available (example: ["fact_1", "fact_3"])
|
||||
- ttsHints: optional list from [pause_300ms, pause_700ms, smile, serious_tone, emphasize_data]
|
||||
- Plain text only, no markdown
|
||||
- chart_data: object for B-roll mapping (required in audio_only)
|
||||
- type: bar_comparison|bar_horizontal|line_trend|pie|stacked_bar|bullet_points
|
||||
- title: short chart title
|
||||
- labels: list
|
||||
- values: list (same length as labels, required for bar/line/pie)
|
||||
- before/after: parallel lists of numbers (for bar_comparison only)
|
||||
- segments: list of {{name, values}} (for stacked_bar only)
|
||||
- bullet_points: list of strings (for bullet_points only)
|
||||
- takeaway: one sentence tying chart to narration
|
||||
- source: URL or citation for the data (e.g. "Research fact #3" or a URL from the research context)
|
||||
|
||||
COST OPTIMIZATION:
|
||||
- 5-6 scenes max for {request.duration_minutes} min episode
|
||||
@@ -178,25 +281,112 @@ COST OPTIMIZATION:
|
||||
scenes_data = data.get("scenes") or []
|
||||
if not isinstance(scenes_data, list):
|
||||
raise HTTPException(status_code=500, detail="LLM response missing scenes array")
|
||||
|
||||
if len(scenes_data) == 0:
|
||||
logger.warning("[ScriptGen] LLM returned empty scenes array")
|
||||
raise HTTPException(status_code=500, detail="LLM returned no scenes - please try again")
|
||||
|
||||
logger.warning(f"[ScriptGen] Processing {len(scenes_data)} scenes from LLM response")
|
||||
|
||||
valid_emotions = {"neutral", "happy", "excited", "serious", "curious", "confident"}
|
||||
|
||||
# Normalize scenes
|
||||
scenes: list[PodcastScene] = []
|
||||
total_lines_input = 0
|
||||
total_lines_output = 0
|
||||
dropped_empty_lines = 0
|
||||
|
||||
for idx, scene in enumerate(scenes_data):
|
||||
if not isinstance(scene, dict):
|
||||
logger.warning(f"[ScriptGen] Scene {idx} is not a dict, skipping")
|
||||
continue
|
||||
|
||||
title = scene.get("title") or f"Scene {idx + 1}"
|
||||
duration = int(scene.get("duration") or max(30, (request.duration_minutes * 60) // max(1, len(scenes_data))))
|
||||
emotion = scene.get("emotion") or "neutral"
|
||||
if emotion not in valid_emotions:
|
||||
logger.warning(f"[ScriptGen] Invalid emotion '{emotion}' in scene {idx}, defaulting to 'neutral'")
|
||||
emotion = "neutral"
|
||||
lines_raw = scene.get("lines") or []
|
||||
total_lines_input += len(lines_raw)
|
||||
lines: list[PodcastSceneLine] = []
|
||||
for line in lines_raw:
|
||||
|
||||
for line_idx, line in enumerate(lines_raw):
|
||||
if not isinstance(line, dict):
|
||||
logger.warning(f"[ScriptGen] Line {line_idx} in scene {idx} is not a dict, skipping")
|
||||
continue
|
||||
|
||||
speaker = line.get("speaker") or ("Host" if len(lines) % max(1, request.speakers) == 0 else "Guest")
|
||||
text = line.get("text") or ""
|
||||
emphasis = line.get("emphasis", False)
|
||||
|
||||
# Handle emphasis - convert various values to boolean
|
||||
emphasis_raw = line.get("emphasis", False)
|
||||
if isinstance(emphasis_raw, bool):
|
||||
emphasis = emphasis_raw
|
||||
elif isinstance(emphasis_raw, str):
|
||||
emphasis = emphasis_raw.lower() in ("true", "yes", "1")
|
||||
if emphasis_raw.lower() not in ("true", "false", "yes", "no", "1", "0"):
|
||||
logger.debug(f"[ScriptGen] Unusual emphasis value '{emphasis_raw}' converted to {emphasis}")
|
||||
else:
|
||||
emphasis = bool(emphasis_raw)
|
||||
|
||||
# Generate line ID if not provided
|
||||
line_id = line.get("id") or f"line-{idx + 1}-{line_idx + 1}"
|
||||
|
||||
# Get used fact IDs if provided
|
||||
used_fact_ids = _normalize_fact_ids(line.get("usedFactIds") or line.get("used_fact_ids"))
|
||||
tts_hints = line.get("ttsHints") or line.get("tts_hints") or None
|
||||
|
||||
if text:
|
||||
lines.append(PodcastSceneLine(speaker=speaker, text=text, emphasis=emphasis))
|
||||
lines.append(PodcastSceneLine(
|
||||
speaker=speaker,
|
||||
text=text,
|
||||
emphasis=emphasis,
|
||||
id=line_id,
|
||||
usedFactIds=used_fact_ids,
|
||||
ttsHints=tts_hints if isinstance(tts_hints, list) else None,
|
||||
))
|
||||
total_lines_output += 1
|
||||
else:
|
||||
dropped_empty_lines += 1
|
||||
logger.debug(f"[ScriptGen] Dropped empty line {line_idx} in scene {idx}")
|
||||
|
||||
# Log scene status
|
||||
if scenes_data and isinstance(scene, dict):
|
||||
image_url_raw = scene.get("imageUrl") or scene.get("image_url")
|
||||
audio_url_raw = scene.get("audioUrl") or scene.get("audio_url")
|
||||
if image_url_raw:
|
||||
logger.warning(f"[ScriptGen] Scene {idx} has imageUrl - will be reset to None")
|
||||
if audio_url_raw:
|
||||
logger.warning(f"[ScriptGen] Scene {idx} has audioUrl - will be reset to None")
|
||||
|
||||
# Keep each scene under TTS request size to prevent failures
|
||||
scene_char_count = sum(len((l.text or "").strip()) for l in lines)
|
||||
if scene_char_count > TARGET_TTS_CHARS_PER_SCENE and lines:
|
||||
logger.warning(
|
||||
f"[ScriptGen] Scene {idx} text too long ({scene_char_count} chars). "
|
||||
f"Trimming to {TARGET_TTS_CHARS_PER_SCENE} target."
|
||||
)
|
||||
trimmed_lines: list[PodcastSceneLine] = []
|
||||
remaining = TARGET_TTS_CHARS_PER_SCENE
|
||||
for l in lines:
|
||||
if remaining <= 0:
|
||||
break
|
||||
line_text = (l.text or "").strip()
|
||||
if len(line_text) <= remaining:
|
||||
trimmed_lines.append(l)
|
||||
remaining -= len(line_text)
|
||||
continue
|
||||
l.text = f"{line_text[:max(0, remaining - 1)].rstrip()}…"
|
||||
trimmed_lines.append(l)
|
||||
remaining = 0
|
||||
lines = trimmed_lines
|
||||
|
||||
chart_data = scene.get("chart_data") or scene.get("chartData") or None
|
||||
if podcast_mode == "audio_only" and not chart_data:
|
||||
# Ensure audio-only always has a B-roll mapping fallback
|
||||
chart_data = _default_chart_data(title)
|
||||
|
||||
scenes.append(
|
||||
PodcastScene(
|
||||
id=scene.get("id") or f"scene-{idx + 1}",
|
||||
@@ -205,8 +395,19 @@ COST OPTIMIZATION:
|
||||
lines=lines,
|
||||
approved=False,
|
||||
emotion=emotion,
|
||||
imageUrl=None, # Will be generated later
|
||||
audioUrl=None, # Will be generated later
|
||||
imagePrompt=None, # Will be generated during image generation
|
||||
chart_data=chart_data if isinstance(chart_data, dict) else None,
|
||||
)
|
||||
)
|
||||
|
||||
# Summary logging
|
||||
logger.warning(f"[ScriptGen] Script generated: {len(scenes)} scenes, {total_lines_output}/{total_lines_input} lines")
|
||||
if dropped_empty_lines > 0:
|
||||
logger.warning(f"[ScriptGen] Dropped {dropped_empty_lines} empty lines")
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[ScriptGen] ===== SCRIPT_GEN_END (took {duration_ms}ms) =====")
|
||||
|
||||
return PodcastScriptResponse(scenes=scenes)
|
||||
|
||||
|
||||
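Review note: the per-scene trimming step above keeps whole lines while they fit under `TARGET_TTS_CHARS_PER_SCENE` and truncates the first line that overflows. The following is a minimal standalone sketch of that idea, not the handler's code. `Line` and `trim_lines_to_budget` are hypothetical names introduced for illustration, and unlike the diff this version returns new objects instead of mutating the input lines.

```python
from dataclasses import dataclass
from typing import List

TARGET_TTS_CHARS_PER_SCENE = 8_500  # same value as the constant in the diff


@dataclass
class Line:
    """Hypothetical stand-in for PodcastSceneLine (only the fields used here)."""
    speaker: str
    text: str


def trim_lines_to_budget(lines: List[Line], budget: int = TARGET_TTS_CHARS_PER_SCENE) -> List[Line]:
    """Keep whole lines while they fit; truncate the first line that does not."""
    trimmed: List[Line] = []
    remaining = budget
    for line in lines:
        if remaining <= 0:
            break  # budget exhausted, drop the rest of the scene
        text = (line.text or "").strip()
        if len(text) <= remaining:
            trimmed.append(Line(line.speaker, text))
            remaining -= len(text)
            continue
        # Reserve one character for the ellipsis that marks the cut.
        cut = text[: max(0, remaining - 1)].rstrip()
        trimmed.append(Line(line.speaker, cut + "…"))
        remaining = 0
    return trimmed
```

Returning a new list keeps the original LLM output available for logging or a retry; whether that matters for this endpoint is an assumption.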
338
backend/api/podcast/handlers/tavily_category_research.py
Normal file
@@ -0,0 +1,338 @@
|
||||
"""
|
||||
Category Research Handlers
|
||||
|
||||
Research endpoints using Tavily or Exa for category-based topic discovery.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from types import SimpleNamespace
|
||||
from sqlalchemy import text
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.research.tavily_service import TavilyService
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
router = APIRouter(prefix="/research", tags=["Podcast Category Research"])
|
||||
|
||||
CATEGORY_PROVIDER_MAP = {
|
||||
"news": "tavily",
|
||||
"finance": "tavily",
|
||||
"research-paper": "exa",
|
||||
"personal-site": "exa",
|
||||
}
|
||||
|
||||
EXA_CATEGORY_MAP = {
|
||||
"research-paper": "research paper",
|
||||
"personal-site": "personal site",
|
||||
}
|
||||
|
||||
|
||||
def _preflight_check(user_id: str, provider: APIProvider, provider_name: str):
|
||||
"""Check subscription limits before making a research API call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=0,
|
||||
actual_provider_name=provider_name,
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': provider_name, 'usage_info': usage_info or {}
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[CategoryResearch] Preflight check failed for {provider_name}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_research_usage(user_id: str, provider_name: str, cost: float, calls_column: str, cost_column: str):
|
||||
"""Track research API usage after successful call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[CategoryResearch] Could not get DB session for user {user_id}")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
update_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {calls_column} = COALESCE({calls_column}, 0) + 1,
|
||||
{cost_column} = COALESCE({cost_column}, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[CategoryResearch] Tracked {provider_name} usage: user={user_id}, cost=${cost}")
|
||||
|
||||
# Clear dashboard cache so header stats update immediately
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[CategoryResearch] Failed to clear dashboard cache: {cache_err}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Failed to track {provider_name} usage: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class CategoryResearchRequest(BaseModel):
|
||||
category: str
|
||||
keyword: Optional[str] = None
|
||||
max_results: Optional[int] = 8
|
||||
website_url: Optional[str] = None
|
||||
|
||||
|
||||
class CategoryTopic(BaseModel):
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
score: float
|
||||
favicon: Optional[str] = None
|
||||
|
||||
|
||||
class CategoryResearchResponse(BaseModel):
|
||||
success: bool
|
||||
category: str
|
||||
provider: str
|
||||
topics: List[CategoryTopic]
|
||||
query: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def _normalize_tavily_results(results: List[Dict]) -> List[CategoryTopic]:
|
||||
topics = []
|
||||
for item in results:
|
||||
topics.append(CategoryTopic(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("content", ""),
|
||||
score=item.get("score", 0.0),
|
||||
favicon=item.get("favicon"),
|
||||
))
|
||||
return topics
|
||||
|
||||
|
||||
def _normalize_exa_results(results: List[Dict], query: str) -> List[CategoryTopic]:
|
||||
topics = []
|
||||
for idx, item in enumerate(results):
|
||||
score = 1.0 - (idx * 0.1)
|
||||
topics.append(CategoryTopic(
|
||||
title=item.get("title", "") or f"Result {idx + 1}",
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("summary", "") or item.get("text", "") or "",
|
||||
score=max(0.5, score),
|
||||
favicon=None,
|
||||
))
|
||||
return topics
|
||||
|
||||
|
||||
async def _search_tavily(category: str, keyword: str, max_results: int, user_id: str) -> CategoryResearchResponse:
|
||||
logger.info(f"[CategoryResearch] Using Tavily for category={category}, keyword={keyword}")
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.TAVILY, "tavily")
|
||||
|
||||
try:
|
||||
tavily = TavilyService()
|
||||
result = await tavily.search(
|
||||
query=keyword,
|
||||
topic=category,
|
||||
search_depth="basic",
|
||||
max_results=max_results,
|
||||
include_favicon=True,
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=result.get("error", "Tavily search failed")
|
||||
)
|
||||
|
||||
topics = _normalize_tavily_results(result.get("results", []))
|
||||
logger.info(f"[CategoryResearch] Tavily found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.001 # basic search = 1 credit
|
||||
_track_research_usage(user_id, "tavily", cost, "tavily_calls", "tavily_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
provider="tavily",
|
||||
topics=topics,
|
||||
query=keyword,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.opt(exception=True).error(f"[CategoryResearch] Tavily error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _search_exa(category: str, keyword: str, max_results: int, user_id: str, website_url: Optional[str] = None) -> CategoryResearchResponse:
|
||||
exa_category = EXA_CATEGORY_MAP.get(category, category)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa: category={category}, exa_category={exa_category}, keyword={keyword}, website_url={website_url}")
|
||||
|
||||
try:
|
||||
# Import exa directly for more control
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
exa_api_key = os.getenv("EXA_API_KEY")
|
||||
if not exa_api_key:
|
||||
raise HTTPException(status_code=500, detail="EXA_API_KEY not configured")
|
||||
|
||||
from exa_py import Exa
|
||||
exa = Exa(exa_api_key)
|
||||
logger.info(f"[CategoryResearch] Exa client initialized")
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.EXA, "exa")
|
||||
|
||||
# Build search parameters
|
||||
search_params = {
|
||||
"num_results": max_results,
|
||||
"category": exa_category,
|
||||
}
|
||||
|
||||
# For personal-site, extract domain from URL if provided
|
||||
include_domains = None
|
||||
if category == "personal-site" and website_url:
|
||||
try:
|
||||
parsed = urlparse(website_url)
|
||||
if parsed.netloc:
|
||||
include_domains = [parsed.netloc]
|
||||
logger.info(f"[CategoryResearch] Personal site - limiting to domain: {parsed.netloc}")
|
||||
elif parsed.path and "." in parsed.path:
|
||||
# Could be domain without protocol
|
||||
include_domains = [parsed.path]
|
||||
logger.info(f"[CategoryResearch] Personal site - using as domain: {parsed.path}")
|
||||
except Exception as url_err:
|
||||
logger.warning(f"[CategoryResearch] Failed to parse website_url: {url_err}")
|
||||
|
||||
logger.info(f"[CategoryResearch] Calling Exa with params: {search_params}, include_domains={include_domains}")
|
||||
|
||||
# Make the search call
|
||||
results = exa.search_and_contents(
|
||||
query=keyword,
|
||||
type="auto" if category != "personal-site" else "neural",
|
||||
num_results=max_results,
|
||||
category=exa_category,
|
||||
text=True,
|
||||
summary=True,
|
||||
include_domains=include_domains,
|
||||
)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa search completed, got results")
|
||||
|
||||
# Transform results to our format
|
||||
topics = []
|
||||
if results and hasattr(results, 'results'):
|
||||
for item in results.results:
|
||||
title = getattr(item, 'title', 'Untitled')
|
||||
url = getattr(item, 'url', '')
|
||||
snippet = getattr(item, 'summary', '') or getattr(item, 'text', '') or ''
|
||||
score = 0.8 # Default score for Exa results
|
||||
|
||||
topics.append(CategoryTopic(
|
||||
title=title,
|
||||
url=url,
|
||||
snippet=snippet[:300] if snippet else '',
|
||||
score=score,
|
||||
favicon=None,
|
||||
))
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.005 # Default Exa cost for 1-25 results
|
||||
_track_research_usage(user_id, "exa", cost, "exa_calls", "exa_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
provider="exa",
|
||||
topics=topics,
|
||||
query=keyword,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"[CategoryResearch] Exa error: {type(e).__name__}: {e}")
|
||||
logger.error(f"[CategoryResearch] Stack: {traceback.format_exc()}")
|
||||
raise HTTPException(status_code=500, detail=f"Exa search failed: {str(e)}")
|
||||
|
||||
|
||||
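Review note: in the personal-site branch above, a `website_url` without a scheme (for example `example.com/blog`) leaves `netloc` empty, so the fallback puts the whole `parsed.path`, including `/blog`, into `include_domains`. A possible alternative, sketched under the assumption that only the bare hostname is wanted, is shown below. `extract_domain` is a hypothetical helper, not part of the diff.

```python
from typing import Optional
from urllib.parse import urlparse


def extract_domain(website_url: str) -> Optional[str]:
    """Return the lowercase hostname of a URL, with or without a scheme."""
    raw = (website_url or "").strip()
    if not raw:
        return None
    # urlparse only fills netloc/hostname when a scheme or a leading // is present.
    if "://" not in raw and not raw.startswith("//"):
        raw = "//" + raw
    host = urlparse(raw).hostname  # drops port and credentials, lowercases
    # Require a dot so bare words such as "localhost" are not treated as domains.
    return host if host and "." in host else None
```

With such a helper the handler could set `include_domains = [domain]` only when a domain is returned, and skip the domain filter otherwise.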
@router.post("/tavily-category", response_model=CategoryResearchResponse)
|
||||
async def research_by_category(
|
||||
request: CategoryResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Research topics by category using Tavily or Exa.
|
||||
|
||||
Categories:
|
||||
- news, finance: Uses Tavily
|
||||
- research-paper, personal-site: Uses Exa
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
category = request.category.lower()
|
||||
valid_categories = list(CATEGORY_PROVIDER_MAP.keys())
|
||||
|
||||
logger.info(f"[CategoryResearch] Full request payload: category={request.category}, keyword={request.keyword}, website_url={request.website_url}")
|
||||
|
||||
if category not in valid_categories:
|
||||
logger.error(f"[CategoryResearch] Invalid category: {category}, valid: {valid_categories}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Category must be one of: {', '.join(valid_categories)}"
|
||||
)
|
||||
|
||||
keyword = request.keyword or category
|
||||
max_results = min(max(request.max_results or 8, 5), 10)
|
||||
website_url = request.website_url
|
||||
|
||||
logger.info(f"[CategoryResearch] Processing: category={category}, keyword={keyword}, max_results={max_results}, website_url={website_url}")
|
||||
|
||||
provider = CATEGORY_PROVIDER_MAP.get(category, "tavily")
|
||||
logger.info(f"[CategoryResearch] Selected provider: {provider} for category: {category}")
|
||||
|
||||
try:
|
||||
if provider == "tavily":
|
||||
return await _search_tavily(category, keyword, max_results, user_id)
|
||||
elif provider == "exa":
|
||||
return await _search_exa(category, keyword, max_results, user_id, website_url)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unknown provider")
|
||||
except Exception as e:
|
||||
logger.opt(exception=True).error(f"[CategoryResearch] Outer error: {type(e).__name__}: {e}")
|
||||
raise
|
||||
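Review note: `_track_research_usage` interpolates `calls_column` and `cost_column` straight into the SQL text. Both call sites in this file pass string literals, so this is safe as written, but the function signature would accept any string. One way to make that explicit is an allow-list keyed by provider, sketched below. `ALLOWED_USAGE_COLUMNS` and `build_usage_update` are hypothetical names; the table and column names are taken from the diff.

```python
# Hypothetical allow-list; the two pairs below are the ones the diff passes in.
ALLOWED_USAGE_COLUMNS = {
    "tavily": ("tavily_calls", "tavily_cost"),
    "exa": ("exa_calls", "exa_cost"),
}


def build_usage_update(provider_name: str) -> str:
    """Build the UPDATE statement using only allow-listed column names."""
    try:
        calls_column, cost_column = ALLOWED_USAGE_COLUMNS[provider_name]
    except KeyError:
        raise ValueError(f"Unsupported provider: {provider_name!r}") from None
    # Values (:cost, :user_id, :period) stay as bound parameters.
    return (
        "UPDATE usage_summaries "
        f"SET {calls_column} = COALESCE({calls_column}, 0) + 1, "
        f"{cost_column} = COALESCE({cost_column}, 0) + :cost, "
        "total_calls = COALESCE(total_calls, 0) + 1, "
        "total_cost = COALESCE(total_cost, 0) + :cost "
        "WHERE user_id = :user_id AND billing_period = :period"
    )
```

The returned string could be wrapped in `sqlalchemy.text(...)` and executed with the same parameter dict the handler already builds.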
119
backend/api/podcast/handlers/trends.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""
|
||||
Podcast Trends Handler
|
||||
|
||||
Endpoints for fetching Google Trends data relevant to podcast topics.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/trends", tags=["Podcast Trends"])
|
||||
|
||||
# Module-level shared instance (singleton pattern)
|
||||
_trends_service_instance = None
|
||||
_trends_service_lock = None
|
||||
|
||||
|
||||
def get_trends_service():
|
||||
"""Get or create shared GoogleTrendsService instance."""
|
||||
global _trends_service_instance, _trends_service_lock
|
||||
if _trends_service_instance is None:
|
||||
try:
|
||||
from services.research.trends import GoogleTrendsService
|
||||
_trends_service_instance = GoogleTrendsService()
|
||||
_trends_service_lock = asyncio.Lock()
|
||||
logger.info("[Podcast Trends] Created shared GoogleTrendsService instance")
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] Failed to create GoogleTrendsService: {e}")
|
||||
raise
|
||||
return _trends_service_instance
|
||||
|
||||
|
||||
class PodcastTrendsRequest(BaseModel):
|
||||
keywords: List[str] = Field(..., min_length=1, max_length=5, description="1-5 keywords to analyze")
|
||||
timeframe: str = Field(default="today 12-m", description="Timeframe: 'today 3-m', 'today 12-m', 'today 5-y', 'all'")
|
||||
geo: str = Field(default="US", description="Country code: 'US', 'GB', 'IN', etc.")
|
||||
source: str = Field(default="web", description="Data source: 'web' (Google), 'podcast' (YouTube)")
|
||||
|
||||
|
||||
class PodcastTrendsResponse(BaseModel):
|
||||
success: bool
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("", response_model=PodcastTrendsResponse)
|
||||
async def get_podcast_trends(
|
||||
request: PodcastTrendsRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Fetch Google Trends data for podcast topic keywords."""
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
try:
|
||||
service = get_trends_service()
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] GoogleTrendsService unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Google Trends service is currently unavailable. Please try again later."
|
||||
)
|
||||
|
||||
try:
|
||||
# Map 'source' to 'gprop' - 'podcast' uses YouTube for video/podcast relevance
|
||||
gprop_map = {"": "", "web": "", "podcast": "youtube", "news": "news", "images": "images", "shopping": "froogle"}
|
||||
gprop = gprop_map.get(request.source, "")
|
||||
|
||||
result = await service.analyze_trends(
|
||||
keywords=request.keywords,
|
||||
timeframe=request.timeframe,
|
||||
geo=request.geo,
|
||||
gprop=gprop,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
has_error = result.get("error")
|
||||
has_data = (
|
||||
len(result.get("interest_over_time", [])) > 0
|
||||
or len(result.get("interest_by_region", [])) > 0
|
||||
or len(result.get("related_topics", {}).get("top", [])) > 0
|
||||
or len(result.get("related_topics", {}).get("rising", [])) > 0
|
||||
or len(result.get("related_queries", {}).get("top", [])) > 0
|
||||
or len(result.get("related_queries", {}).get("rising", [])) > 0
|
||||
)
|
||||
|
||||
# Return an error response when the service reported an error AND returned no data (blocked/empty)
|
||||
if has_error and not has_data:
|
||||
error_msg = result.get("error", "")
|
||||
cooldown_active = result.get("cooldown_active", False)
|
||||
logger.warning(f"[Trends] No data or error: {error_msg[:100]}")
|
||||
# Provide helpful message during cooldown
|
||||
if cooldown_active:
|
||||
return PodcastTrendsResponse(
|
||||
success=False,
|
||||
data=result,
|
||||
error="Google is rate limiting requests. Try using 'Get Trending Topics' instead, or wait 30 minutes."
|
||||
)
|
||||
return PodcastTrendsResponse(success=False, data=result, error=error_msg or "No trends data available. Google may be blocking requests.")
|
||||
|
||||
# Even if no error but empty data - return error
|
||||
if not has_data:
|
||||
logger.warning("[Trends] Empty data returned")
|
||||
return PodcastTrendsResponse(success=False, data=result, error="No trends data available. Please try different keywords.")
|
||||
|
||||
return PodcastTrendsResponse(success=True, data=result)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast Trends] Error fetching trends for {request.keywords}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch trends data: {str(e)}"
|
||||
)
|
||||
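Review note: the `has_data` expression in the trends handler chains `.get(..., {}).get(...)`, which would raise if the service ever returned `None` for `related_topics` or `related_queries`. Below is a small sketch of the same check as a helper that tolerates missing or `None` sections. `has_trends_data` is a hypothetical name, and the result keys are the ones used in the diff; it assumes list-like values, as the handler's `len()` calls do.

```python
from typing import Any, Dict


def has_trends_data(result: Dict[str, Any]) -> bool:
    """True if any of the trends sections contains at least one entry."""
    flat_keys = ("interest_over_time", "interest_by_region")
    nested_keys = ("related_topics", "related_queries")
    if any(result.get(key) for key in flat_keys):
        return True
    for key in nested_keys:
        section = result.get(key) or {}  # treat None like an empty section
        if section.get("top") or section.get("rising"):
            return True
    return False
```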
@@ -321,7 +321,7 @@ async def generate_podcast_video(
|
||||
|
||||
# Load image bytes (scene image is required for video generation)
|
||||
if body.avatar_image_url:
|
||||
image_bytes = load_podcast_image_bytes(body.avatar_image_url)
|
||||
image_bytes = load_podcast_image_bytes(body.avatar_image_url, user_id=user_id)
|
||||
else:
|
||||
# Scene-specific image should be generated before video generation
|
||||
raise HTTPException(
|
||||
@@ -332,7 +332,7 @@ async def generate_podcast_video(
|
||||
mask_image_bytes = None
|
||||
if body.mask_image_url:
|
||||
try:
|
||||
mask_image_bytes = load_podcast_image_bytes(body.mask_image_url)
|
||||
mask_image_bytes = load_podcast_image_bytes(body.mask_image_url, user_id=user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Failed to load mask image: {e}")
|
||||
raise HTTPException(
|
||||
|
||||
@@ -5,7 +5,7 @@ All Pydantic request/response models for podcast endpoints.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import List, Optional, Dict, Any
|
||||
from typing import List, Optional, Dict, Any, Literal
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
@@ -54,6 +54,7 @@ class PodcastAnalyzeRequest(BaseModel):
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
avatar_url: Optional[str] = Field(None, description="Current avatar URL if selected")
|
||||
feedback: Optional[str] = Field(None, description="User feedback for regeneration")
|
||||
podcast_mode: Optional[str] = Field(None, description="Podcast mode: audio_only, video_only, or audio_video")
|
||||
|
||||
|
||||
class PodcastAnalyzeResponse(BaseModel):
|
||||
@@ -72,12 +73,21 @@ class PodcastAnalyzeResponse(BaseModel):
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
avatar_url: Optional[str] = None
|
||||
avatar_prompt: Optional[str] = None
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaRequest(BaseModel):
|
||||
"""Request model for enhancing a podcast idea with AI."""
|
||||
idea: str = Field(..., description="The raw podcast idea or keywords")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
website_data: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Optional website extraction data for enriched context (title, summary, highlights, subpages, url)"
|
||||
)
|
||||
topic_context: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Optional category research context (category, topics, selected_topic)"
|
||||
)
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaResponse(BaseModel):
|
||||
@@ -95,12 +105,16 @@ class PodcastScriptRequest(BaseModel):
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
outline: Optional[Dict[str, Any]] = Field(None, description="The refined episode outline to follow")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="The full analysis context (audience, keywords, etc.)")
|
||||
podcast_mode: Optional[str] = Field(default="video_only", description="Podcast mode: audio_only, video_only, or audio_video")
|
||||
|
||||
|
||||
class PodcastSceneLine(BaseModel):
|
||||
speaker: str
|
||||
text: str
|
||||
emphasis: Optional[bool] = False
|
||||
id: Optional[str] = None # Optional line ID for frontend tracking
|
||||
usedFactIds: Optional[List[str]] = None # Facts referenced in this line
|
||||
ttsHints: Optional[List[str]] = None # Optional TTS hints, e.g. pause_300ms, smile, emphasize_data
|
||||
|
||||
|
||||
class PodcastScene(BaseModel):
|
||||
@@ -111,6 +125,9 @@ class PodcastScene(BaseModel):
|
||||
approved: bool = False
|
||||
emotion: Optional[str] = None
|
||||
imageUrl: Optional[str] = None # Generated image URL for video generation
|
||||
audioUrl: Optional[str] = None # Generated audio URL for this scene
|
||||
imagePrompt: Optional[str] = None # Original image generation prompt for video context
|
||||
chart_data: Optional[Dict[str, Any]] = None # Optional chart mapping for B-roll scenes
|
||||
|
||||
|
||||
class PodcastExaConfig(BaseModel):
|
||||
@@ -167,15 +184,40 @@ class PodcastResearchInsight(BaseModel):
|
||||
listener_cta_suggestions: Optional[List[str]] = [] # CTA suggestions
|
||||
|
||||
|
||||
class PodcastResearchOutput(BaseModel):
|
||||
"""Structured JSON output for LLM research extraction using json_struct."""
|
||||
summary: str = ""
|
||||
key_insights: List[PodcastResearchInsight] = []
|
||||
expert_quotes: List[Dict[str, Any]] = [] # [{"quote": str, "source_index": int, "context": str}]
|
||||
listener_cta_suggestions: List[str] = [] # List of CTA suggestions
|
||||
mapped_angles: List[Dict[str, Any]] = [] # [{"title": str, "why": str, "mapped_fact_ids": []}]
|
||||
|
||||
|
||||
class PodcastCostBreakdownItem(BaseModel):
|
||||
phase: Literal["Analyze", "Gather", "Write", "Produce"]
|
||||
cost: float
|
||||
|
||||
|
||||
class PodcastCostEst(BaseModel):
|
||||
total: float
|
||||
breakdown: List[PodcastCostBreakdownItem]
|
||||
currency: Literal["USD"] = "USD"
|
||||
last_updated: datetime
|
||||
|
||||
|
||||
class PodcastExaResearchResponse(BaseModel):
|
||||
sources: List[PodcastExaSource]
|
||||
search_queries: List[str] = []
|
||||
summary: str = ""
|
||||
key_insights: List[PodcastResearchInsight] = []
|
||||
cost: Optional[Dict[str, Any]] = None
|
||||
cost_est: PodcastCostEst
|
||||
search_type: Optional[str] = None
|
||||
provider: str = "exa"
|
||||
content: Optional[str] = None # Raw aggregated content (deprecated)
|
||||
mapped_angles: List[Dict[str, Any]] = [] # Content angles for the episode
|
||||
expert_quotes: List[Dict[str, Any]] = [] # Expert quotes from research
|
||||
listener_cta_suggestions: List[str] = [] # CTA suggestions
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class PodcastScriptResponse(BaseModel):
|
||||
@@ -189,6 +231,9 @@ class PodcastAudioRequest(BaseModel):
|
||||
text: str
|
||||
voice_id: Optional[str] = "Wise_Woman"
|
||||
custom_voice_id: Optional[str] = None # Voice clone ID for custom voice
|
||||
use_voice_clone: Optional[bool] = False # If True, use voice clone with voice_sample_url
|
||||
voice_sample_url: Optional[str] = None # URL to user's voice sample for cloning
|
||||
voice_clone_engine: Optional[str] = None # Engine: "qwen3", "minimax", "cosyvoice"
|
||||
speed: Optional[float] = 1.0
|
||||
volume: Optional[float] = 1.0
|
||||
pitch: Optional[float] = 0.0
|
||||
@@ -434,3 +479,58 @@ class VoiceCloneResult(BaseModel):
|
||||
task_id: str
|
||||
status: str = "completed"
|
||||
|
||||
|
||||
class ExtractUrlRequest(BaseModel):
|
||||
"""Request to extract content from a URL using Exa."""
|
||||
url: str = Field(..., description="URL to extract content from")
|
||||
|
||||
|
||||
class ExtractUrlResponse(BaseModel):
|
||||
"""Response with extracted content from URL."""
|
||||
success: bool
|
||||
title: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
highlights: Optional[List[str]] = Field(default_factory=list, description="Key highlights from the content")
|
||||
url: str
|
||||
image: Optional[str] = None
|
||||
favicon: Optional[str] = None
|
||||
subpages: Optional[List[Dict[str, Any]]] = Field(default_factory=list, description="Subpages with their own content")
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class WebsiteAnalysisRequest(BaseModel):
|
||||
"""Request to save user's website analysis."""
|
||||
website_url: str = Field(..., description="The website URL")
|
||||
exa_content: Dict[str, Any] = Field(default_factory=dict, description="Exa extracted content")
|
||||
|
||||
|
||||
class WebsiteAnalysisResponse(BaseModel):
|
||||
"""Response for website analysis."""
|
||||
success: bool
|
||||
website_url: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastPreEstimateRequest(BaseModel):
|
||||
"""Request model for pre-analysis cost estimate."""
|
||||
duration: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
query_count: int = Field(default=3, description="Number of research queries")
|
||||
podcast_mode: str = Field(default="audio_video", description="Podcast mode: audio_only, video_only, or audio_video")
|
||||
# Optional model overrides for cost estimation
|
||||
gemini_model: Optional[str] = Field(default=None, description="LLM model: gemini-2.5-flash, gemini-1.5-flash, etc.")
|
||||
audio_tts_model: Optional[str] = Field(default=None, description="Audio TTS model: minimax/speech-02-hd")
|
||||
voice_clone_engine: Optional[str] = Field(default=None, description="Voice clone engine: qwen3, cosyvoice, minimax")
|
||||
image_model: Optional[str] = Field(default=None, description="Image model: qwen-image, ideogram-v3-turbo")
|
||||
video_model: Optional[str] = Field(default=None, description="Video model: wan-2.5, kling-v2.5-turbo-std-5s, wavespeed-ai/infinitetalk")
|
||||
|
||||
|
||||
class PodcastPreEstimateResponse(BaseModel):
|
||||
"""Response model for pre-analysis cost estimate."""
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
pricing_available: bool = Field(default=False, description="Whether pricing data is available in DB")
|
||||
debug: Optional[Dict[str, Any]] = Field(default=None, description="Debug info: pricing rows count, providers")
|
||||
|
||||
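Review note: the request models in this file describe `podcast_mode` as one of `audio_only`, `video_only`, or `audio_video`, but type it as a plain string with three different defaults (`None`, `"video_only"`, `"audio_video"`), whereas `PodcastCostBreakdownItem.phase` is constrained with `Literal`. Below is a minimal, stdlib-only sketch of the same kind of constraint applied to the mode. `PodcastMode`, `ALLOWED_MODES`, and `normalize_podcast_mode` are hypothetical names and are not part of this diff.

```python
from typing import Literal, Optional, get_args

# Hypothetical alias; the diff itself types podcast_mode as a plain string.
PodcastMode = Literal["audio_only", "video_only", "audio_video"]
ALLOWED_MODES = get_args(PodcastMode)


def normalize_podcast_mode(value: Optional[str], default: str = "audio_video") -> str:
    """Return a validated mode, falling back to `default` when value is None."""
    if value is None:
        return default
    mode = value.strip().lower()
    if mode not in ALLOWED_MODES:
        raise ValueError(f"podcast_mode must be one of {ALLOWED_MODES}, got {value!r}")
    return mode


print(normalize_podcast_mode(None))
print(normalize_podcast_mode(" Video_Only "))
```

Using one shared alias would also make the differing defaults across the three models easier to spot, assuming the difference is intentional.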
24
backend/api/podcast/prompts/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
Prompts module for podcast topic enhancement.
|
||||
"""
|
||||
|
||||
from .website_enhance_prompts import (
|
||||
get_enhance_topic_prompt,
|
||||
format_website_context,
|
||||
STANDARD_ENHANCE_PROMPT,
|
||||
WEBSITE_AWARE_ENHANCE_PROMPT,
|
||||
)
|
||||
|
||||
from services.podcast_context_builder import (
|
||||
PodcastContextBuilder,
|
||||
context_builder,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_enhance_topic_prompt",
|
||||
"format_website_context",
|
||||
"STANDARD_ENHANCE_PROMPT",
|
||||
"WEBSITE_AWARE_ENHANCE_PROMPT",
|
||||
"PodcastContextBuilder",
|
||||
"context_builder",
|
||||
]
|
||||
187
backend/api/podcast/prompts/website_enhance_prompts.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
Website-aware prompts for podcast topic enhancement.
|
||||
|
||||
This module provides prompts for enhancing podcast topics with optional
|
||||
website extraction data for richer context.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from string import Template
|
||||
|
||||
|
||||
# Standard prompt for when no website data is available
|
||||
STANDARD_ENHANCE_PROMPT = Template("""You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
|
||||
${bible_context}
|
||||
|
||||
RAW IDEA/KEYWORDS: "$idea"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
# Website-aware prompt for when website data is available
|
||||
WEBSITE_AWARE_ENHANCE_PROMPT = Template("""You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea, enriched with website content analysis.
|
||||
|
||||
${bible_context}
|
||||
|
||||
WEBSITE CONTENT ANALYSIS:
|
||||
${website_context}
|
||||
|
||||
RAW IDEA/KEYWORDS: "$idea"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle, that INCORPORATE the website content context:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise from the website)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections tied to the brand)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance leveraging the site's focus areas)
|
||||
|
||||
Each version should:
|
||||
- Be 2-3 sentences
|
||||
- Reference specific elements from the website content when relevant
|
||||
- Be audience-focused and align with host persona if provided
|
||||
- NOT just repeat the website summary - create fresh podcast angles
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
def get_enhance_topic_prompt(
|
||||
idea: str,
|
||||
bible_context: str = "",
|
||||
website_data: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Returns the appropriate prompt based on available context.
|
||||
|
||||
Args:
|
||||
idea: The raw podcast idea or keywords
|
||||
bible_context: Optional Podcast Bible context string
|
||||
website_data: Optional website extraction data
|
||||
|
||||
Returns:
|
||||
Formatted prompt string with appropriate context
|
||||
"""
|
||||
# Build bible context section
|
||||
bible_section = f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""
|
||||
|
||||
if website_data:
|
||||
# Build website context section
|
||||
website_context_parts = []
|
||||
if website_data.get('url'):
|
||||
website_context_parts.append(f"Source: {website_data.get('url')}")
|
||||
if website_data.get('title'):
|
||||
website_context_parts.append(f"Company/Organization: {website_data.get('title')}")
|
||||
if website_data.get('summary'):
|
||||
website_context_parts.append(f"About: {website_data.get('summary')}")
|
||||
if website_data.get('highlights'):
|
||||
highlights_str = ', '.join(website_data.get('highlights', [])[:3])
|
||||
website_context_parts.append(f"Key Highlights: {highlights_str}")
|
||||
if website_data.get('subpages'):
|
||||
subpages_str = ', '.join([
|
||||
sp.get('title', sp.get('url', ''))
|
||||
for sp in website_data.get('subpages', [])[:3]
|
||||
])
|
||||
website_context_parts.append(f"Subpages: {subpages_str}")
|
||||
|
||||
website_context_str = "\n".join(website_context_parts)
|
||||
|
||||
return WEBSITE_AWARE_ENHANCE_PROMPT.substitute(
|
||||
idea=idea,
|
||||
bible_context=bible_section,
|
||||
website_context=website_context_str
|
||||
)
|
||||
else:
|
||||
return STANDARD_ENHANCE_PROMPT.substitute(
|
||||
idea=idea,
|
||||
bible_context=bible_section
|
||||
)
|
||||
|
||||
|
||||
def format_website_context(website_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format website data for inclusion in progress messages.
|
||||
|
||||
Args:
|
||||
website_data: Website extraction data
|
||||
|
||||
Returns:
|
||||
Formatted string describing what's being used
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if website_data.get('title'):
|
||||
parts.append(f"• {website_data['title']}")
|
||||
|
||||
if website_data.get('summary'):
|
||||
summary_preview = website_data['summary'][:100]
|
||||
parts.append(f"• Summary: {summary_preview}...")
|
||||
|
||||
if website_data.get('highlights'):
|
||||
parts.append(f"• {len(website_data['highlights'])} key highlights")
|
||||
|
||||
if website_data.get('subpages'):
|
||||
parts.append(f"• {len(website_data['subpages'])} subpages analyzed")
|
||||
|
||||
if website_data.get('url'):
|
||||
parts.append(f"• Source: {website_data['url']}")
|
||||
|
||||
return "\n".join(parts) if parts else "Basic website analysis"
|
||||
|
||||
if website_data.get('title'):
|
||||
parts.append(f"• {website_data['title']}")
|
||||
|
||||
if website_data.get('summary'):
|
||||
summary_preview = website_data['summary'][:100]
|
||||
parts.append(f"• Summary: {summary_preview}...")
|
||||
|
||||
if website_data.get('highlights'):
|
||||
parts.append(f"• {len(website_data['highlights'])} key highlights")
|
||||
|
||||
if website_data.get('subpages'):
|
||||
parts.append(f"• {len(website_data['subpages'])} subpages analyzed")
|
||||
|
||||
if website_data.get('url'):
|
||||
parts.append(f"• Source: {website_data['url']}")
|
||||
|
||||
return "\n".join(parts) if parts else "Basic website analysis"
|
||||
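Review note: both prompts embed a literal JSON example, which is presumably why the module uses `string.Template` rather than `str.format`: `$`-style placeholders leave braces untouched. The sketch below uses a hypothetical `DEMO_PROMPT` (not the real prompt text) to show that behavior, plus the `KeyError` that `substitute` raises when a placeholder is not supplied.

```python
from string import Template

# Hypothetical, shortened stand-in for the prompts in this file.
DEMO_PROMPT = Template(
    '${bible_context}RAW IDEA/KEYWORDS: "$idea"\n'
    'Return JSON like {"enhanced_ideas": []}'
)

# Braces in the JSON example need no escaping with Template.
filled = DEMO_PROMPT.substitute(idea="AI in podcasting", bible_context="")
print(filled)

# substitute() is strict: a missing placeholder raises KeyError.
try:
    DEMO_PROMPT.substitute(idea="AI in podcasting")
    missing = None
except KeyError as exc:
    missing = exc.args[0]
print("missing placeholder:", missing)

# safe_substitute() leaves unknown placeholders in place instead of raising.
partial = DEMO_PROMPT.safe_substitute(idea="AI in podcasting")
print(partial)
```

Substituted values are not scanned again, so a user idea containing `$` (for example `"$5 plans"`) is inserted as-is.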
@@ -12,7 +12,7 @@ from api.story_writer.utils.auth import require_authenticated_user
|
||||
from api.story_writer.task_manager import task_manager
|
||||
|
||||
# Import all handler routers
|
||||
from .handlers import projects, analysis, research, script, audio, images, video, avatar, dubbing
|
||||
from .handlers import projects, analysis, research, script, audio, images, video, avatar, dubbing, broll, trends, tavily_category_research
|
||||
|
||||
# Create main router
|
||||
router = APIRouter(prefix="/api/podcast", tags=["Podcast Maker"])
|
||||
@@ -27,6 +27,9 @@ router.include_router(images.router)
|
||||
router.include_router(video.router)
|
||||
router.include_router(avatar.router)
|
||||
router.include_router(dubbing.router)
|
||||
router.include_router(broll.router)
|
||||
router.include_router(trends.router)
|
||||
router.include_router(tavily_category_research.router)
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/status")
|
||||
|
||||
@@ -67,15 +67,32 @@ def load_podcast_audio_bytes(audio_url: str, user_id: str | None = None) -> byte
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load audio: {str(exc)}")
|
||||
|
||||
|
||||
def load_podcast_image_bytes(image_url: str) -> bytes:
|
||||
"""Load podcast image bytes from URL. Uses centralized media loader."""
|
||||
def load_podcast_image_bytes(image_url: str, user_id: str | None = None) -> bytes:
|
||||
"""Load podcast image bytes from URL. Resolves from workspace first."""
|
||||
if not image_url:
|
||||
raise HTTPException(status_code=400, detail="Image URL is required")
|
||||
|
||||
logger.info(f"[Podcast] Loading image from URL: {image_url}")
|
||||
|
||||
try:
|
||||
# REUSE: Use centralized media loader which handles cross-module lookups
|
||||
# Extract filename from URL path
|
||||
prefix = "/api/podcast/images/"
|
||||
if prefix in image_url:
|
||||
filename = image_url.split(prefix, 1)[1].split("?", 1)[0].strip()
|
||||
# Handle subdirectories like avatars/
|
||||
subdir = None
|
||||
if "/" in filename:
|
||||
subdir_part = filename.rsplit("/", 1)[0]
|
||||
subdir = Path(subdir_part)
|
||||
filename = filename.rsplit("/", 1)[1]
|
||||
|
||||
try:
|
||||
image_path = _resolve_podcast_media_file(filename, "image", user_id, subdir=subdir)
|
||||
return image_path.read_bytes()
|
||||
except HTTPException:
|
||||
pass # Fall through to centralized loader
|
||||
|
||||
# Fall back to centralized media loader
|
||||
image_bytes = load_media_bytes(image_url)
|
||||
|
||||
if not image_bytes:
|
||||
|
||||
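Review note: the new workspace-first lookup in `load_podcast_image_bytes` parses the URL by hand. The sketch below isolates just that parsing step so it can be checked on its own. `split_podcast_image_url` is a hypothetical helper, not a function in this diff, and it returns the pieces instead of calling `_resolve_podcast_media_file`.

```python
from pathlib import Path
from typing import Optional, Tuple

PREFIX = "/api/podcast/images/"


def split_podcast_image_url(image_url: str) -> Optional[Tuple[Optional[Path], str]]:
    """Return (subdir, filename) for URLs under PREFIX, or None for other URLs."""
    if PREFIX not in image_url:
        return None  # caller would fall back to the centralized loader
    # Keep the part after the prefix and drop any query string.
    relative = image_url.split(PREFIX, 1)[1].split("?", 1)[0].strip()
    if "/" not in relative:
        return None, relative
    # Everything before the last slash is treated as a subdirectory, e.g. avatars/.
    subdir_part, filename = relative.rsplit("/", 1)
    return Path(subdir_part), filename


print(split_podcast_image_url("/api/podcast/images/avatars/host.png?token=abc"))
print(split_podcast_image_url("https://cdn.example.com/a.png"))
```

The sketch does not reject `..` segments in the subdirectory part; whether `_resolve_podcast_media_file` does so is not visible in this diff.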
@@ -8,9 +8,14 @@ def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
|
||||
Validates the current user dictionary provided by Clerk middleware and
|
||||
returns the normalized user_id. Raises HTTP 401 if authentication fails.
|
||||
"""
|
||||
if not current_user or not isinstance(current_user, dict):
|
||||
# Guard against dependency injection issues where Depends object might be passed
|
||||
if current_user is None or not isinstance(current_user, dict):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
|
||||
# Additional check: ensure it's actually a dict and not a Depends object or other type
|
||||
if not hasattr(current_user, 'get') or not callable(getattr(current_user, 'get')):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid authentication context")
|
||||
|
||||
user_id = str(current_user.get("id", "")).strip()
|
||||
if not user_id:
|
||||
raise HTTPException(
|
||||
|
||||
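Review note: in this hunk, `not current_user` becomes `current_user is None`, so an empty dict now passes the first guard and is rejected later by the `user_id` check. Once `isinstance(current_user, dict)` has passed, the added `hasattr(current_user, 'get')` check cannot fail. The sketch below shows the guard reduced to a single type check. `AuthError`, `FakeDepends`, and `require_user_id` are hypothetical stand-ins, and the explicit `None` handling for `id` is an addition of the sketch, since `str(None)` would otherwise yield the non-empty string `"None"`.

```python
from typing import Any


class AuthError(Exception):
    """Stand-in for HTTPException(status_code=401) in this sketch."""


class FakeDepends:
    """Hypothetical placeholder for an unresolved dependency object."""


def require_user_id(current_user: Any) -> str:
    # isinstance(None, dict) is False, so one check covers None and non-dict objects.
    if not isinstance(current_user, dict):
        raise AuthError("Authentication required")
    raw_id = current_user.get("id")
    # Sketch-only addition: treat a missing or None id as empty.
    user_id = "" if raw_id is None else str(raw_id).strip()
    if not user_id:
        raise AuthError("Missing user id")
    return user_id


print(require_user_id({"id": " user_123 "}))
```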
@@ -12,7 +12,7 @@ import sqlite3
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
from models.subscription_models import UsageAlert
|
||||
from models.subscription_models import UsageAlert, UserSubscription
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from ..dependencies import verify_user_access
|
||||
from ..cache import get_cached_dashboard, set_cached_dashboard
|
||||
@@ -27,7 +27,9 @@ async def get_dashboard_data(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive dashboard data for usage monitoring."""
|
||||
"""Get comprehensive dashboard data for usage monitoring.
|
||||
Returns all-time total + current period usage by default.
|
||||
When billing_period is specified, returns that period's data only."""
|
||||
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
@@ -35,17 +37,23 @@ async def get_dashboard_data(
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
|
||||
# Check cache first (skip if billing_period is specified)
|
||||
if not billing_period:
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
# Check cache first (only for default view, skip when a specific period is requested)
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data and not billing_period:
|
||||
return cached_data
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Get current usage stats (for the requested period)
|
||||
current_usage = usage_service.get_user_usage_stats(user_id, billing_period)
|
||||
# When a specific billing_period is requested, show only that period's data
|
||||
# Otherwise show all-time total + current period usage
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
|
||||
# Get usage trends (last 6 months)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
@@ -76,13 +84,44 @@ async def get_dashboard_data(
|
||||
]
|
||||
|
||||
# Calculate cost projections (only relevant for current month)
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
|
||||
# Only project costs if viewing current month
|
||||
is_current_month = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
if is_current_month:
|
||||
# Determine if viewing current period based on subscription, not calendar
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
# Use subscription's billing period or fallback to calendar
|
||||
if subscription and subscription.current_period_start:
|
||||
sub_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Check if we have data for subscription period or calendar period
|
||||
from models.subscription_models import UsageSummary
|
||||
sub_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == sub_period
|
||||
).first()
|
||||
|
||||
# Determine which period to use for "current"
|
||||
if sub_data_exists:
|
||||
effective_period = sub_period
|
||||
else:
|
||||
# Check calendar period for backward compatibility
|
||||
cal_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
effective_period = calendar_period if cal_data_exists else sub_period
|
||||
|
||||
is_current_period = not billing_period or billing_period == effective_period
|
||||
else:
|
||||
is_current_period = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
|
||||
if is_current_period:
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
else:
|
||||
projected_cost = current_cost # For past months, projected is actual
|
||||
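Review note: this hunk picks an "effective" current period (subscription period if it has usage rows, otherwise the calendar month if that has rows, otherwise the subscription period) and then projects cost as `(current_cost / current_day) * days_in_period`. The sketch below restates both steps as pure functions. `pick_effective_period` and `project_cost` are hypothetical names, the set of periods stands in for the `UsageSummary` queries, and the default month length comes from `calendar.monthrange` rather than the fixed `30` used in the diff.

```python
import calendar
from datetime import date
from typing import Optional, Set


def pick_effective_period(sub_period: str, calendar_period: str, periods_with_data: Set[str]) -> str:
    """Mirror the fallback order in the hunk above."""
    if sub_period in periods_with_data:
        return sub_period
    return calendar_period if calendar_period in periods_with_data else sub_period


def project_cost(current_cost: float, today: date, days_in_period: Optional[int] = None) -> float:
    """Linear projection from the cost so far to the end of the period."""
    if days_in_period is None:
        # Sketch-only choice: actual month length instead of a fixed 30.
        days_in_period = calendar.monthrange(today.year, today.month)[1]
    return (current_cost / today.day) * days_in_period


print(pick_effective_period("2025-01", "2025-02", {"2025-02"}))
print(project_cost(10.0, date(2025, 1, 10), days_in_period=30))
```

Two points may be worth checking in the real code: in the default view `current_cost` is now read from `total_usage`, which the docstring describes as the all-time total, and dividing by the day of the month assumes the billing period starts on the 1st even when the subscription period is used.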
@@ -90,7 +129,8 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -100,9 +140,9 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
@@ -131,7 +171,13 @@ async def get_dashboard_data(
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
@@ -152,7 +198,7 @@ async def get_dashboard_data(
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
@@ -160,7 +206,8 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"current_usage": current_usage,
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -170,16 +217,17 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Cache the response after successful retry
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
# Cache the response after successful retry (only for default view)
|
||||
if not billing_period:
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
return response_payload
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
@@ -187,7 +235,8 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(retry_err),
|
||||
"data": {
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
@@ -201,7 +250,8 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"data": {
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
Pre-flight check endpoints for operation validation and cost estimation.
|
||||
"""
|
||||
|
||||
import time
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any
|
||||
@@ -34,6 +35,7 @@ async def preflight_check(
|
||||
|
||||
Uses caching to minimize DB load (< 100ms with cache hit).
|
||||
"""
|
||||
start_time = time.time()
|
||||
try:
|
||||
user_id = get_user_id_from_token(current_user)
|
||||
|
||||
@@ -229,13 +231,19 @@ async def preflight_check(
|
||||
'remaining': max(0, video_limit - video_current) if video_limit > 0 else float('inf')
|
||||
}
|
||||
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.warning(f"[PreflightCheck] Completed in {elapsed_ms:.0f}ms for user {user_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": response_data
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.warning(f"[PreflightCheck] HTTP error after {elapsed_ms:.0f}ms")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error in pre-flight check: {e}", exc_info=True)
|
||||
elapsed_ms = (time.time() - start_time) * 1000
|
||||
logger.error(f"[PreflightCheck] Error after {elapsed_ms:.0f}ms: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Pre-flight check failed: {str(e)}")
|
||||
|
||||
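Review note: the pre-flight hunk repeats `(time.time() - start_time) * 1000` in three branches. The sketch below is one way to centralize that measurement. The `elapsed_ms` context manager is hypothetical, and it uses `time.perf_counter()` in place of `time.time()`, since the former is a monotonic clock intended for measuring short intervals.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def elapsed_ms() -> Iterator[Dict[str, float]]:
    """Yield a dict whose 'ms' entry is filled in when the block exits."""
    result = {"ms": 0.0}
    start = time.perf_counter()
    try:
        yield result
    finally:
        # Runs on success and on exceptions, so every branch gets a timing.
        result["ms"] = (time.perf_counter() - start) * 1000


with elapsed_ms() as timer:
    time.sleep(0.01)  # stand-in for the pre-flight work
print(f"[PreflightCheck] Completed in {timer['ms']:.0f}ms")
```

With this shape, the success, `HTTPException`, and generic error branches could each log `timer['ms']` without recomputing it.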
@@ -14,13 +14,21 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"""
|
||||
Format subscription plan limits for API response.
|
||||
|
||||
Includes _zero_means metadata per field to disambiguate:
|
||||
- 'disabled': 0 means the feature is not available (Free tier)
|
||||
- 'unlimited': 0 means unlimited usage (Enterprise tier)
|
||||
- 'limited': >0 means numerical limit applies
|
||||
|
||||
Args:
|
||||
plan: SubscriptionPlan model instance
|
||||
|
||||
Returns:
|
||||
Dictionary with formatted limits
|
||||
Dictionary with formatted limits and _zero_means metadata
|
||||
"""
|
||||
return {
|
||||
tier = plan.tier.value if hasattr(plan.tier, 'value') else str(plan.tier)
|
||||
is_enterprise = tier == 'enterprise'
|
||||
|
||||
limit_fields = {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
@@ -35,11 +43,43 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0) or 0,
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0) or 0,
|
||||
"exa_calls": getattr(plan, 'exa_calls_limit', 0) or 0,
|
||||
"wavespeed_calls": getattr(plan, 'wavespeed_calls_limit', 0) or 0,
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
"mistral_tokens": plan.mistral_tokens_limit,
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
"monthly_cost": plan.monthly_cost_limit,
|
||||
}
|
||||
|
||||
# Build _zero_means metadata: indicates whether 0 means 'disabled' or 'unlimited'
|
||||
zero_means = {}
|
||||
for field, value in limit_fields.items():
|
||||
if field == "monthly_cost":
|
||||
zero_means[field] = "disabled"
|
||||
elif is_enterprise:
|
||||
# Enterprise: 0 means unlimited for all call/token fields
|
||||
zero_means[field] = "unlimited"
|
||||
else:
|
||||
# Free/Basic/Pro: determine per-field
|
||||
# Fields that are 0=disabled on Free tier but 0=unlimited on Basic/Pro
|
||||
call_and_token_fields = {
|
||||
"gemini_calls", "openai_calls", "anthropic_calls", "mistral_calls",
|
||||
"tavily_calls", "serper_calls", "metaphor_calls", "firecrawl_calls",
|
||||
"stability_calls", "video_calls", "image_edit_calls", "audio_calls",
|
||||
"exa_calls", "wavespeed_calls", "ai_text_generation_calls",
|
||||
"gemini_tokens", "openai_tokens", "anthropic_tokens", "mistral_tokens",
|
||||
}
|
||||
if field in call_and_token_fields:
|
||||
if value == 0:
|
||||
zero_means[field] = "disabled" if tier == "free" else "unlimited"
|
||||
else:
|
||||
zero_means[field] = "limited"
|
||||
else:
|
||||
zero_means[field] = "limited" if value > 0 else "disabled"
|
||||
|
||||
return {
|
||||
**limit_fields,
|
||||
"_zero_means": zero_means,
|
||||
}
|
||||
|
||||
|
||||
|
||||
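Reviewer note: a minimal sketch of how a consumer of `format_plan_limits` might read the `_zero_means` metadata. The helper `effective_limit` and the sample payload are hypothetical. Only the `_zero_means` key and its three values (`disabled`, `unlimited`, `limited`) come from the diff. Because the Enterprise branch marks every call/token field as `unlimited` even when its value is non-zero, the sketch checks the numeric value first and consults the marker only when the value is 0.

```python
import math
from typing import Any, Dict


def effective_limit(limits: Dict[str, Any], field: str) -> float:
    """Resolve a limit field to a number: 0.0 if disabled, math.inf if unlimited."""
    value = limits[field] or 0  # treat None as 0 (assumption: limits may be nullable)
    if value > 0:
        return float(value)  # a positive value is always a real numeric limit
    meaning = limits.get("_zero_means", {}).get(field, "disabled")
    if meaning == "unlimited":
        return math.inf
    return 0.0  # "disabled" or an unknown marker: fail closed


# Hypothetical payload shaped like the function's return value.
sample = {
    "gemini_calls": 0,
    "openai_calls": 50,
    "monthly_cost": 0,
    "_zero_means": {
        "gemini_calls": "unlimited",
        "openai_calls": "limited",
        "monthly_cost": "disabled",
    },
}

print(effective_limit(sample, "gemini_calls"))
print(effective_limit(sample, "openai_calls"))
print(effective_limit(sample, "monthly_cost"))
```

Defaulting a missing marker to `disabled` is a design choice of this sketch: an older response without `_zero_means` then never grants unlimited usage by accident.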
@@ -1,9 +1,10 @@
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from services.writing_assistant import WritingAssistantService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
@@ -11,7 +12,6 @@ router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
|
||||
class SuggestRequest(BaseModel):
|
||||
text: str
|
||||
max_results: int | None = 1
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
|
||||
@@ -38,9 +38,10 @@ assistant_service = WritingAssistantService()
|
||||
|
||||
|
||||
@router.post("/suggest", response_model=SuggestResponse)
|
||||
async def suggest_endpoint(req: SuggestRequest) -> SuggestResponse:
|
||||
async def suggest_endpoint(req: SuggestRequest, current_user: Dict[str, Any] = Depends(get_current_user)) -> SuggestResponse:
|
||||
try:
|
||||
suggestions = await assistant_service.suggest(req.text, req.max_results or 1)
|
||||
user_id = current_user.get("id")
|
||||
suggestions = await assistant_service.suggest(req.text, user_id=user_id)
|
||||
return SuggestResponse(
|
||||
success=True,
|
||||
suggestions=[
|
||||
|
||||
714
backend/app.py
@@ -1,6 +1,12 @@
|
||||
# Ensure typing constructs and models are available globally for FastAPI type annotation evaluation
|
||||
import os
|
||||
|
||||
# Print env vars immediately - BEFORE any imports
|
||||
print(f"[app.py] EARLY - PORT={os.getenv('PORT')}, HOST={os.getenv('HOST')}", flush=True)
|
||||
|
||||
import typing
|
||||
import builtins
|
||||
import builtins
|
||||
|
||||
# Make common typing constructs available globally
|
||||
builtins.Optional = typing.Optional
|
||||
@@ -14,15 +20,20 @@ from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
backend_dir = Path(__file__).parent
|
||||
project_root = backend_dir.parent
|
||||
load_dotenv(backend_dir / '.env')
|
||||
load_dotenv(project_root / '.env')
|
||||
load_dotenv()
|
||||
|
||||
# Set LOG_LEVEL early to WARNING to suppress DEBUG persona logs in podcast mode
|
||||
# Load .env but DON'T override existing environment variables (especially PORT from Render)
|
||||
# Use override=False to preserve Render-provided PORT
|
||||
load_dotenv(backend_dir / '.env', override=False)
|
||||
load_dotenv(project_root / '.env', override=False)
|
||||
load_dotenv(override=False)
|
||||
|
||||
# Set LOG_LEVEL early to WARNING in feature-only modes to suppress DEBUG persona logs
|
||||
import os
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast":
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
os.environ["LOG_LEVEL"] = "WARNING"
|
||||
|
||||
print(f"[app.py] Starting... ALWRITY_ENABLED_FEATURES={os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
|
||||
def get_enabled_features() -> set:
|
||||
"""Get enabled features from ALWRITY_ENABLED_FEATURES env var."""
|
||||
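Reviewer note: the `override=False` change above is about precedence: values already present in the process environment (such as a platform-provided `PORT`) should win over values from `.env` files. A minimal stdlib sketch of that precedence rule follows. `apply_env_defaults` is a hypothetical stand-in and does not parse files. To my knowledge `override` already defaults to `False` in python-dotenv, so the explicit argument mostly documents intent. That is worth confirming against the installed version.

```python
from typing import Dict, MutableMapping


def apply_env_defaults(file_values: Dict[str, str], environ: MutableMapping[str, str]) -> None:
    """Copy values into environ without replacing keys that are already set."""
    for key, value in file_values.items():
        environ.setdefault(key, value)  # existing keys keep their current value


# Hypothetical values: the platform has set PORT, the .env file also defines it.
env = {"PORT": "10000"}
apply_env_defaults({"PORT": "8000", "HOST": "0.0.0.0"}, env)

print(env)
```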
@@ -32,13 +43,23 @@ def get_enabled_features() -> set:
|
||||
return {f.strip() for f in env_value.split(",") if f.strip()}
|
||||
|
||||
|
||||
def is_podcast_only_demo_mode() -> bool:
|
||||
"""Check if podcast-only mode is enabled."""
|
||||
def _is_full_mode() -> bool:
|
||||
"""Check if running in full mode (all features enabled)."""
|
||||
enabled = get_enabled_features()
|
||||
return "podcast" in enabled and "all" not in enabled
|
||||
return "all" in enabled
|
||||
|
||||
|
||||
# Import onboarding models (after env is loaded)
|
||||
def _is_feature_enabled(feature: str) -> bool:
|
||||
"""Check if a specific feature is enabled (including in 'all' mode)."""
|
||||
enabled = get_enabled_features()
|
||||
return feature in enabled or "all" in enabled
|
||||
|
||||
|
||||
# Print env var IMMEDIATELY at module start
|
||||
print(f"[app.py] ALWRITY_ENABLED_FEATURES at start: {os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
|
||||
# Import onboarding models (after env is loaded, before heavy imports)
|
||||
from models.onboarding import APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData, CompetitorAnalysis
|
||||
|
||||
|
||||
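Reviewer note: a minimal, testable sketch of the feature-flag helpers introduced in this hunk. The body of `get_enabled_features` is only partly visible in the diff, so two behaviours here are assumptions: an unset or empty value is treated as `{"all"}`, and values are lower-cased. The function names are hypothetical variants that take the raw string as a parameter instead of reading `ALWRITY_ENABLED_FEATURES` directly.

```python
from typing import Optional, Set


def parse_features(raw: Optional[str]) -> Set[str]:
    """Parse a comma-separated feature list."""
    value = (raw or "").strip().lower()
    if not value:
        return {"all"}  # ASSUMPTION: unset/empty means every feature is enabled
    return {f.strip() for f in value.split(",") if f.strip()}


def is_full_mode(enabled: Set[str]) -> bool:
    """Full mode means the 'all' flag is present."""
    return "all" in enabled


def is_feature_enabled(feature: str, enabled: Set[str]) -> bool:
    """A feature is on if named explicitly or implied by 'all'."""
    return feature in enabled or "all" in enabled


print(parse_features("podcast, video"))
print(is_full_mode(parse_features("podcast")))
print(is_feature_enabled("podcast", parse_features("all")))
```

Passing the raw value in keeps the helpers free of global state, which makes the full-mode versus feature-only branches easy to cover in unit tests.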
@@ -54,24 +75,30 @@ import asyncio
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
|
||||
def _log_memory_usage():
|
||||
try:
|
||||
import psutil
|
||||
mem_mb = psutil.Process().memory_info().rss // (1024 * 1024)
|
||||
logger.info(f"Memory usage (MB): {mem_mb}")
|
||||
except Exception:
|
||||
# psutil not available or failed; skip silently
|
||||
pass
|
||||
|
||||
# Import modular utilities (skip OnboardingManager import in podcast-only mode)
|
||||
# Log memory early in app.py startup
|
||||
_log_memory_usage()
|
||||
logger.info("app.py: Early memory checkpoint after env load")
|
||||
|
||||
|
||||
# Import modular utilities (skip OnboardingManager import in feature-only modes)
|
||||
from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager
|
||||
if not is_podcast_only_demo_mode():
|
||||
if _is_full_mode():
|
||||
from alwrity_utils import OnboardingManager
|
||||
|
||||
# Import monitoring middleware
|
||||
from services.subscription import monitoring_middleware
|
||||
|
||||
|
||||
def should_include_non_podcast_features() -> bool:
|
||||
"""Check if non-podcast features should be included."""
|
||||
enabled = get_enabled_features()
|
||||
return "all" in enabled or "core" in enabled
|
||||
|
||||
|
||||
# Legacy constant for backwards compatibility
|
||||
PODCAST_ONLY_DEMO_MODE = is_podcast_only_demo_mode()
|
||||
# Skip monitoring middleware in feature-only modes to save memory
|
||||
if _is_full_mode():
|
||||
from services.subscription import monitoring_middleware
|
||||
else:
|
||||
monitoring_middleware = None
|
||||
|
||||
|
||||
# Set up clean logging for end users
|
||||
@@ -81,49 +108,73 @@ setup_clean_logging()
|
||||
# Import middleware
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
# Import component logic endpoints (needs OnboardingSession, so import after models)
|
||||
from api.component_logic import router as component_logic_router
|
||||
# Import component logic endpoints (skip in feature-only modes - uses seo_analyzer)
|
||||
component_logic_router = None
|
||||
if _is_full_mode():
|
||||
from api.component_logic import router as component_logic_router
|
||||
|
||||
# Import subscription API endpoints
|
||||
from api.subscription import router as subscription_router
|
||||
|
||||
# Import Step 3 onboarding routes (skip in podcast-only mode)
|
||||
# Import Step 3 onboarding routes (skip in feature-only modes)
|
||||
step3_routes = None
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
from api.onboarding_utils.step3_routes import router as step3_routes
|
||||
|
||||
# Import SEO tools router
|
||||
from routers.seo_tools import router as seo_tools_router
|
||||
# Import Facebook Writer endpoints
|
||||
from api.facebook_writer.routers import facebook_router
|
||||
# Import LinkedIn content generation router
|
||||
from routers.linkedin import router as linkedin_router
|
||||
# Import LinkedIn image generation router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
from api.brainstorm import router as brainstorm_router
|
||||
from api.images import router as images_router
|
||||
from api.assets_serving import router as assets_serving_router
|
||||
from routers.image_studio import router as image_studio_router
|
||||
from routers.product_marketing import router as product_marketing_router
|
||||
from routers.campaign_creator import router as campaign_creator_router
|
||||
# Import SEO tools router (skip in feature-only modes - uses seo_analyzer)
|
||||
seo_tools_router = None
|
||||
if _is_full_mode():
|
||||
from routers.seo_tools import router as seo_tools_router
|
||||
|
||||
# Import hallucination detector router
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
# Skip Facebook Writer, LinkedIn, and other non-essential routes in feature-only modes
|
||||
# Also skip other heavy services that trigger PersonaAnalysisService initialization
|
||||
if _is_full_mode():
|
||||
from api.facebook_writer.routers import facebook_router
|
||||
from routers.linkedin import router as linkedin_router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
from api.brainstorm import router as brainstorm_router
|
||||
from api.images import router as images_router
|
||||
from api.assets_serving import router as assets_serving_router
|
||||
from routers.image_studio import router as image_studio_router
|
||||
from routers.product_marketing import router as product_marketing_router
|
||||
from routers.campaign_creator import router as campaign_creator_router
|
||||
else:
|
||||
# In feature-only modes, only load essential assets router
|
||||
from api.assets_serving import router as assets_serving_router
|
||||
brainstorm_router = None
|
||||
images_router = None
|
||||
image_studio_router = None
|
||||
product_marketing_router = None
|
||||
campaign_creator_router = None
|
||||
|
||||
# Import research configuration router
|
||||
from api.research_config import router as research_config_router
|
||||
# Import hallucination detector router (skip in feature-only modes - triggers heavy ML)
|
||||
if _is_full_mode():
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
else:
|
||||
hallucination_detector_router = None
|
||||
writing_assistant_router = None
|
||||
|
||||
# Import research configuration router (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
from api.research_config import router as research_config_router
|
||||
else:
|
||||
research_config_router = None
|
||||
|
||||
# Import user data endpoints
|
||||
# Import content planning endpoints
|
||||
from api.content_planning.api.router import router as content_planning_router
|
||||
from api.user_data import router as user_data_router
|
||||
# Import content planning endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
from api.content_planning.api.router import router as content_planning_router
|
||||
from api.content_planning.strategy_copilot import router as strategy_copilot_router
|
||||
else:
|
||||
content_planning_router = None
|
||||
strategy_copilot_router = None
|
||||
|
||||
# Import user environment endpoints
|
||||
from api.user_environment import router as user_environment_router
|
||||
|
||||
# Import strategy copilot endpoints
|
||||
from api.content_planning.strategy_copilot import router as strategy_copilot_router
|
||||
# Import user data endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
from api.user_data import router as user_data_router
|
||||
else:
|
||||
user_data_router = None
|
||||
|
||||
# Import database service
|
||||
from services.database import close_database
|
||||
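Reviewer note: this hunk repeats an `if _is_full_mode(): from ... import ... else: name = None` pattern many times. A minimal sketch of a helper that could express it in one line per router, using `importlib` from the standard library. `import_if` is hypothetical, and `json.dumps` stands in for a real router module, which is not importable here.

```python
import importlib
from typing import Any, Optional


def import_if(enabled: bool, module: str, attr: str) -> Optional[Any]:
    """Import module.attr only when enabled; otherwise return None without importing."""
    if not enabled:
        return None
    return getattr(importlib.import_module(module), attr)


# Stdlib stand-ins for e.g. ("routers.seo_tools", "router").
dumps = import_if(True, "json", "dumps")
skipped = import_if(False, "json", "dumps")

print(dumps({"a": 1}))
print(skipped)
```

One benefit is that every name is bound on both paths. In the hunk above, the feature-only `else` branch assigns `None` to five routers but not to `facebook_router`, `linkedin_router`, or `linkedin_image_router`, which would raise `NameError` if those names are referenced later in feature-only mode.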
@@ -135,39 +186,71 @@ from services.startup_health import (
|
||||
|
||||
# Trigger reload for monitoring fix
|
||||
|
||||
# Import OAuth token monitoring routes
|
||||
from api.oauth_token_monitoring_routes import router as oauth_token_monitoring_router
|
||||
# Import OAuth token monitoring routes (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
from api.oauth_token_monitoring_routes import router as oauth_token_monitoring_router
|
||||
else:
|
||||
oauth_token_monitoring_router = None
|
||||
|
||||
# Import SEO Dashboard endpoints
|
||||
from api.seo_dashboard import (
|
||||
get_seo_dashboard_data,
|
||||
get_seo_health_score,
|
||||
get_seo_metrics,
|
||||
get_platform_status,
|
||||
get_ai_insights,
|
||||
seo_dashboard_health_check,
|
||||
analyze_seo_comprehensive,
|
||||
analyze_seo_full,
|
||||
get_seo_metrics_detailed,
|
||||
get_analysis_summary,
|
||||
batch_analyze_urls,
|
||||
SEOAnalysisRequest,
|
||||
get_seo_dashboard_overview,
|
||||
get_gsc_raw_data,
|
||||
get_bing_raw_data,
|
||||
get_competitive_insights,
|
||||
get_deep_competitor_analysis,
|
||||
run_strategic_insights,
|
||||
get_strategic_insights_history,
|
||||
refresh_analytics_data,
|
||||
analyze_urls_ai,
|
||||
AnalyzeURLsRequest,
|
||||
get_analyzed_pages,
|
||||
get_semantic_health,
|
||||
get_semantic_cache_stats,
|
||||
get_sif_indexing_health,
|
||||
get_onboarding_task_health,
|
||||
)
|
||||
# Import SEO Dashboard endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
from api.seo_dashboard import (
|
||||
get_seo_dashboard_data,
|
||||
get_seo_health_score,
|
||||
get_seo_metrics,
|
||||
get_platform_status,
|
||||
get_ai_insights,
|
||||
seo_dashboard_health_check,
|
||||
analyze_seo_comprehensive,
|
||||
analyze_seo_full,
|
||||
get_seo_metrics_detailed,
|
||||
get_analysis_summary,
|
||||
batch_analyze_urls,
|
||||
SEOAnalysisRequest,
|
||||
get_seo_dashboard_overview,
|
||||
get_gsc_raw_data,
|
||||
get_bing_raw_data,
|
||||
get_competitive_insights,
|
||||
get_deep_competitor_analysis,
|
||||
run_strategic_insights,
|
||||
get_strategic_insights_history,
|
||||
refresh_analytics_data,
|
||||
analyze_urls_ai,
|
||||
AnalyzeURLsRequest,
|
||||
get_analyzed_pages,
|
||||
get_semantic_health,
|
||||
get_semantic_cache_stats,
|
||||
get_sif_indexing_health,
|
||||
get_onboarding_task_health,
|
||||
)
|
||||
else:
|
||||
get_seo_dashboard_data = None
|
||||
get_seo_health_score = None
|
||||
get_seo_metrics = None
|
||||
get_platform_status = None
|
||||
get_ai_insights = None
|
||||
seo_dashboard_health_check = None
|
||||
analyze_seo_comprehensive = None
|
||||
analyze_seo_full = None
|
||||
get_seo_metrics_detailed = None
|
||||
get_analysis_summary = None
|
||||
batch_analyze_urls = None
|
||||
SEOAnalysisRequest = None
|
||||
get_seo_dashboard_overview = None
|
||||
get_gsc_raw_data = None
|
||||
get_bing_raw_data = None
|
||||
get_competitive_insights = None
|
||||
get_deep_competitor_analysis = None
|
||||
run_strategic_insights = None
|
||||
get_strategic_insights_history = None
|
||||
refresh_analytics_data = None
|
||||
analyze_urls_ai = None
|
||||
AnalyzeURLsRequest = None
|
||||
get_analyzed_pages = None
|
||||
get_semantic_health = None
|
||||
get_semantic_cache_stats = None
|
||||
get_sif_indexing_health = None
|
||||
get_onboarding_task_health = None
|
||||
|
||||
|
||||
# Initialize FastAPI app
|
||||
@@ -184,12 +267,23 @@ default_allowed_origins = [
|
||||
"http://localhost:8000", # Backend dev server
|
||||
"http://localhost:3001", # Alternative React port
|
||||
"https://alwrity-ai.vercel.app", # Vercel frontend
|
||||
"https://alwrity-5vac2n9su-ajsis-projects.vercel.app", # Current Vercel deployment
|
||||
"https://alwrity.vercel.app", # Vercel app
|
||||
]
|
||||
|
||||
# Optional dynamic origins from environment (comma-separated)
|
||||
env_origins = os.getenv("ALWRITY_ALLOWED_ORIGINS", "").split(",") if os.getenv("ALWRITY_ALLOWED_ORIGINS") else []
|
||||
env_origins = [o.strip() for o in env_origins if o.strip()]
|
||||
|
||||
# Convenience: NGROK_URL env var (single origin)
|
||||
ngrok_origin = os.getenv("NGROK_URL")
|
||||
if ngrok_origin:
|
||||
env_origins.append(ngrok_origin.strip())
|
||||
|
||||
# Optional dynamic origins from environment (comma-separated)
|
||||
env_origins = os.getenv("ALWRITY_ALLOWED_ORIGINS", "").split(",") if os.getenv("ALWRITY_ALLOWED_ORIGINS") else []
|
||||
env_origins = [o.strip() for o in env_origins if o.strip()]
|
||||
|
||||
# Convenience: NGROK_URL env var (single origin)
|
||||
ngrok_origin = os.getenv("NGROK_URL")
|
||||
if ngrok_origin:
|
||||
@@ -213,8 +307,8 @@ router_manager = RouterManager(app)
|
||||
router_group_status: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
onboarding_manager = None
|
||||
# Only create OnboardingManager if NOT in podcast-only mode
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
# Only create OnboardingManager in full mode
|
||||
if _is_full_mode():
|
||||
from alwrity_utils import OnboardingManager
|
||||
onboarding_manager = OnboardingManager(app)
|
||||
|
||||
@@ -222,8 +316,9 @@ if not PODCAST_ONLY_DEMO_MODE:
|
||||
# Registration order: 1. Monitoring 2. Rate Limit 3. API Key Injection
|
||||
# Execution order: 1. API Key Injection (sets user_id) 2. Rate Limit 3. Monitoring (uses user_id)
|
||||
|
||||
# 1. FIRST REGISTERED (runs LAST) - Monitoring middleware
|
||||
app.middleware("http")(monitoring_middleware)
|
||||
# 1. FIRST REGISTERED (runs LAST) - Monitoring middleware (skipped in feature-only modes)
|
||||
if monitoring_middleware:
|
||||
app.middleware("http")(monitoring_middleware)
|
||||
|
||||
# 2. SECOND REGISTERED (runs SECOND) - Rate limiting
|
||||
@app.middleware("http")
|
||||
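Reviewer note: the comments in this hunk state that the first registered middleware runs last. A simplified pure-Python model of why that happens follows: each registration wraps the existing handler, so the most recently registered layer ends up outermost. This is a sketch of the wrapping idea only, not Starlette's or FastAPI's actual implementation. All names here are hypothetical.

```python
from typing import Callable, List

Handler = Callable[[str, List[str]], str]


def endpoint(request: str, trace: List[str]) -> str:
    """Innermost handler: the actual route."""
    trace.append("endpoint")
    return "ok"


def make_middleware(name: str) -> Callable[[Handler], Handler]:
    """Build a layer that records its name, then delegates to the wrapped handler."""
    def register(inner: Handler) -> Handler:
        def wrapped(request: str, trace: List[str]) -> str:
            trace.append(name)
            return inner(request, trace)
        return wrapped
    return register


# Register in the order given in the diff's comments.
app: Handler = endpoint
for layer in ["monitoring", "rate_limit", "api_key_injection"]:
    app = make_middleware(layer)(app)

trace: List[str] = []
app("GET /health", trace)
print(trace)
```

In this model the request passes through `api_key_injection`, then `rate_limit`, then `monitoring`, which matches the execution order described in the comments. Skipping the monitoring registration simply removes the innermost layer.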
@@ -240,7 +335,8 @@ app.middleware("http")(api_key_injection_middleware)
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
health_data = health_checker.basic_health_check()
|
||||
health_data["podcast_only_demo_mode"] = PODCAST_ONLY_DEMO_MODE
|
||||
health_data["feature_mode"] = "single" if not _is_full_mode() else "full"
|
||||
health_data["enabled_features"] = list(get_enabled_features())
|
||||
return health_data
|
||||
|
||||
@app.get("/health/database")
|
||||
@@ -257,7 +353,8 @@ async def comprehensive_health():
|
||||
async def readiness(current_user: dict = Depends(get_current_user)):
|
||||
"""Readiness check that validates tenant DB resolution/session under auth context."""
|
||||
return {
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"startup": get_startup_status(),
|
||||
"tenant": readiness_under_auth_context(current_user),
|
||||
}
|
||||
@@ -289,7 +386,8 @@ async def router_status():
|
||||
status = router_manager.get_router_status()
|
||||
status.update(
|
||||
{
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"router_groups": router_group_status,
|
||||
}
|
||||
)
|
||||
@@ -304,26 +402,19 @@ async def feature_profile_status():
|
||||
@app.get("/api/onboarding/status")
|
||||
async def onboarding_status():
|
||||
"""Get onboarding manager status (or demo-mode disabled state)."""
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
if not _is_full_mode():
|
||||
return {
|
||||
"enabled": False,
|
||||
"status": "disabled",
|
||||
"message": "Onboarding is disabled for podcast-only demo mode.",
|
||||
"demo_mode": "podcast_only",
|
||||
"message": f"Onboarding is disabled in feature-only mode. Enabled features: {list(get_enabled_features())}",
|
||||
"feature_mode": "single",
|
||||
}
|
||||
return onboarding_manager.get_onboarding_status()
|
||||
|
||||
# Include routers using modular utilities
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
else:
|
||||
enabled_features = get_enabled_features()
|
||||
if "all" in enabled_features:
|
||||
# Full mode: load all core and optional routers
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": router_manager.include_core_routers(),
|
||||
"reason": "Full mode",
|
||||
@@ -332,6 +423,72 @@ else:
|
||||
"mounted": router_manager.include_optional_routers(),
|
||||
"reason": "Full mode",
|
||||
}
|
||||
else:
|
||||
# Feature-only mode: load only routers matching enabled features
|
||||
from alwrity_utils.router_manager import CORE_ROUTER_REGISTRY
|
||||
|
||||
# Filter core routers that match any enabled feature
|
||||
matching_core = [
|
||||
r for r in CORE_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
logger.info(
|
||||
f"[FEATURE-MODE] Enabled features: {enabled_features}, "
|
||||
f"matching {len(matching_core)} core routers: {[r['name'] for r in matching_core]}"
|
||||
)
|
||||
|
||||
# Try to include step4_assets for voice cloning (may fail if nltk not installed)
|
||||
step4_entry = next((r for r in matching_core if r.get("name") == "step4_assets"), None)
|
||||
if step4_entry:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Attempting to load step4_assets")
|
||||
router = router_manager._load_router_from_registry(step4_entry)
|
||||
router_manager.include_router_safely(router, step4_entry["name"], step4_entry.get("include_kwargs"))
|
||||
except ImportError as e:
|
||||
logger.warning(f"[FEATURE-MODE] Skipping step4_assets (missing optional dependency): {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount step4_assets: {e}")
|
||||
|
||||
# Load other matching core routers
|
||||
for entry in matching_core:
|
||||
if entry.get("name") == "step4_assets":
|
||||
continue # Already loaded above
|
||||
if entry.get("name") == "subscription":
|
||||
continue # Loaded separately below
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Load optional routers matching enabled features
|
||||
from alwrity_utils.router_manager import OPTIONAL_ROUTER_REGISTRY
|
||||
matching_optional = [
|
||||
r for r in OPTIONAL_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
for entry in matching_optional:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading optional router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount optional {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Safety net: explicitly include hallucination detector (router_manager may skip silently)
|
||||
if hallucination_detector_router:
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
|
||||
# Log startup summary
|
||||
router_manager.log_startup_summary()
|
||||
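Reviewer note: the feature-only branch selects routers with `r.get("features", set()) & enabled_features`, that is, a non-empty set intersection. A minimal sketch of that selection follows, assuming each registry entry stores `features` as a `set` (a list would make `&` raise `TypeError`). The registry contents and the `select_routers` helper are hypothetical. Only the `name` and `features` keys appear in the diff.

```python
from typing import Any, Dict, FrozenSet, List, Set

# Hypothetical registry entries shaped like those in CORE_ROUTER_REGISTRY.
REGISTRY: List[Dict[str, Any]] = [
    {"name": "podcast", "features": {"podcast"}},
    {"name": "seo_tools", "features": {"seo"}},
    {"name": "subscription", "features": {"podcast", "seo"}},
    {"name": "legacy"},  # no "features" key: never matches in feature-only mode
]


def select_routers(
    registry: List[Dict[str, Any]],
    enabled: Set[str],
    skip: FrozenSet[str] = frozenset(),
) -> List[str]:
    """Return names of entries sharing at least one feature with `enabled`."""
    return [
        r["name"]
        for r in registry
        if r.get("features", set()) & enabled and r["name"] not in skip
    ]


print(select_routers(REGISTRY, {"podcast"}, skip=frozenset({"subscription"})))
print(select_routers(REGISTRY, {"seo"}))
```

A variant could also collect the names that fail to load, so that `router_group_status` reports partial mounts instead of an unconditional `"mounted": True`.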
@@ -347,157 +504,159 @@ router_group_status["assets_serving"] = {
|
||||
"reason": "Required for podcast media assets",
|
||||
}
|
||||
|
||||
# SEO Dashboard endpoints
|
||||
@app.get("/api/seo-dashboard/data")
|
||||
async def seo_dashboard_data():
|
||||
"""Get complete SEO dashboard data."""
|
||||
return await get_seo_dashboard_data()
|
||||
# SEO Dashboard endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
@app.get("/api/seo-dashboard/data")
|
||||
async def seo_dashboard_data():
|
||||
"""Get complete SEO dashboard data."""
|
||||
return await get_seo_dashboard_data()
|
||||
|
||||
@app.get("/api/seo-dashboard/health-score")
|
||||
async def seo_health_score():
|
||||
"""Get SEO health score."""
|
||||
return await get_seo_health_score()
|
||||
@app.get("/api/seo-dashboard/health-score")
|
||||
async def seo_health_score():
|
||||
"""Get SEO health score."""
|
||||
return await get_seo_health_score()
|
||||
|
||||
@app.get("/api/seo-dashboard/metrics")
|
||||
async def seo_metrics():
|
||||
"""Get SEO metrics."""
|
||||
return await get_seo_metrics()
|
||||
@app.get("/api/seo-dashboard/metrics")
|
||||
async def seo_metrics():
|
||||
"""Get SEO metrics."""
|
||||
return await get_seo_metrics()
|
||||
|
||||
@app.get("/api/seo-dashboard/platforms")
|
||||
async def seo_platforms(current_user: dict = Depends(get_current_user)):
|
||||
"""Get platform status."""
|
||||
return await get_platform_status(current_user)
|
||||
@app.get("/api/seo-dashboard/platforms")
|
||||
async def seo_platforms(current_user: dict = Depends(get_current_user)):
|
||||
"""Get platform status."""
|
||||
return await get_platform_status(current_user)
|
||||
|
||||
@app.get("/api/seo-dashboard/insights")
|
||||
async def seo_insights():
|
||||
"""Get AI insights."""
|
||||
return await get_ai_insights()
|
||||
@app.get("/api/seo-dashboard/insights")
|
||||
async def seo_insights():
|
||||
"""Get AI insights."""
|
||||
return await get_ai_insights()
|
||||
|
||||
# New SEO Dashboard endpoints with real data
|
||||
@app.get("/api/seo-dashboard/overview")
|
||||
async def seo_dashboard_overview_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get comprehensive SEO dashboard overview with real GSC/Bing data."""
|
||||
return await get_seo_dashboard_overview(current_user, site_url)
|
||||
@app.get("/api/seo-dashboard/overview")
|
||||
async def seo_dashboard_overview_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get comprehensive SEO dashboard overview with real GSC/Bing data."""
|
||||
return await get_seo_dashboard_overview(current_user, site_url)
|
||||
|
||||
@app.get("/api/seo-dashboard/gsc/raw")
|
||||
async def gsc_raw_data_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get raw GSC data for the specified site."""
|
||||
return await get_gsc_raw_data(current_user, site_url)
|
||||
@app.get("/api/seo-dashboard/gsc/raw")
|
||||
async def gsc_raw_data_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get raw GSC data for the specified site."""
|
||||
return await get_gsc_raw_data(current_user, site_url)
|
||||
|
||||
@app.get("/api/seo-dashboard/bing/raw")
|
||||
async def bing_raw_data_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get raw Bing data for the specified site."""
|
||||
return await get_bing_raw_data(current_user, site_url)
|
||||
@app.get("/api/seo-dashboard/bing/raw")
|
||||
async def bing_raw_data_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get raw Bing data for the specified site."""
|
||||
return await get_bing_raw_data(current_user, site_url)
|
||||
|
||||
@app.get("/api/seo-dashboard/competitive-insights")
|
||||
async def competitive_insights_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get competitive insights from onboarding step 3 data."""
|
||||
return await get_competitive_insights(current_user, site_url)
|
||||
@app.get("/api/seo-dashboard/competitive-insights")
|
||||
async def competitive_insights_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get competitive insights from onboarding step 3 data."""
|
||||
return await get_competitive_insights(current_user, site_url)
|
||||
|
||||
@app.get("/api/seo-dashboard/deep-competitor-analysis")
|
||||
async def deep_competitor_analysis_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get deep competitor analysis results (auto-scheduled post-onboarding)."""
|
||||
return await get_deep_competitor_analysis(current_user, site_url)
|
||||
@app.get("/api/seo-dashboard/deep-competitor-analysis")
|
||||
async def deep_competitor_analysis_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get deep competitor analysis results (auto-scheduled post-onboarding)."""
|
||||
return await get_deep_competitor_analysis(current_user, site_url)
|
||||
|
||||
@app.post("/api/seo-dashboard/strategic-insights/run")
|
||||
async def run_strategic_insights_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""Run AI-powered strategic insights analysis manually."""
|
||||
return await run_strategic_insights(current_user)
|
||||
@app.post("/api/seo-dashboard/strategic-insights/run")
|
||||
async def run_strategic_insights_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""Run AI-powered strategic insights analysis manually."""
|
||||
return await run_strategic_insights(current_user)
|
||||
|
||||
@app.get("/api/seo-dashboard/strategic-insights/history")
|
||||
async def get_strategic_insights_history_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""Fetch the history of strategic insights for the user."""
|
||||
return await get_strategic_insights_history(current_user)
|
||||
@app.get("/api/seo-dashboard/strategic-insights/history")
|
||||
async def get_strategic_insights_history_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""Fetch the history of strategic insights for the user."""
|
||||
return await get_strategic_insights_history(current_user)
|
||||
|
||||
@app.post("/api/seo-dashboard/refresh")
|
||||
async def refresh_analytics_data_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Refresh analytics data by invalidating cache and fetching fresh data."""
|
||||
return await refresh_analytics_data(current_user, site_url)
|
||||
@app.post("/api/seo-dashboard/refresh")
|
||||
async def refresh_analytics_data_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Refresh analytics data by invalidating cache and fetching fresh data."""
|
||||
return await refresh_analytics_data(current_user, site_url)
|
||||
|
||||
|
||||
@app.get("/api/seo-dashboard/onboarding-task-health")
|
||||
async def onboarding_task_health_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get consolidated health for onboarding-scheduled SEO tasks."""
|
||||
return await get_onboarding_task_health(current_user, site_url)
|
||||
|
||||
@app.get("/api/seo-dashboard/onboarding-task-health")
|
||||
async def onboarding_task_health_endpoint(current_user: dict = Depends(get_current_user), site_url: str = None):
|
||||
"""Get consolidated health for onboarding-scheduled SEO tasks."""
|
||||
return await get_onboarding_task_health(current_user, site_url)
|
||||
@app.get("/api/seo-dashboard/health")
|
||||
async def seo_dashboard_health():
|
||||
"""Health check for SEO dashboard."""
|
||||
return await seo_dashboard_health_check()
|
||||
|
||||
@app.get("/api/seo-dashboard/health")
|
||||
async def seo_dashboard_health():
|
||||
"""Health check for SEO dashboard."""
|
||||
return await seo_dashboard_health_check()
|
||||
|
||||
# Phase 2B: Semantic health monitoring endpoint (24-hour polling)
|
||||
@app.get("/api/seo-dashboard/semantic-health")
|
||||
async def semantic_health_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
Get real-time semantic health metrics for content and competitors.
|
||||
This endpoint provides Phase 2B semantic intelligence monitoring data.
|
||||
|
||||
Returns semantic health score, status, and recommendations.
|
||||
Data is cached and updated every 24 hours via scheduler.
|
||||
"""
|
||||
return await get_semantic_health(current_user)
|
||||
@app.get("/api/seo-dashboard/semantic-health")
|
||||
async def semantic_health_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
Get real-time semantic health metrics for content and competitors.
|
||||
This endpoint provides Phase 2B semantic intelligence monitoring data.
|
||||
|
||||
Returns semantic health score, status, and recommendations.
|
||||
Data is cached and updated every 24 hours via scheduler.
|
||||
"""
|
||||
return await get_semantic_health(current_user)
|
||||
|
||||
|
||||
@app.get("/api/seo-dashboard/cache-stats")
|
||||
async def semantic_cache_stats_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
Get semantic cache performance statistics.
|
||||
Returns hit rate, memory usage, and eviction counts.
|
||||
"""
|
||||
return await get_semantic_cache_stats(current_user)
|
||||
@app.get("/api/seo-dashboard/cache-stats")
|
||||
async def semantic_cache_stats_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
Get semantic cache performance statistics.
|
||||
Returns hit rate, memory usage, and eviction counts.
|
||||
"""
|
||||
return await get_semantic_cache_stats(current_user)
|
||||
|
||||
|
||||
@app.get("/api/seo-dashboard/sif-health")
|
||||
async def sif_indexing_health_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
Get SIF indexing health summary for the current user.
|
||||
Used by the Semantic Indexing Status widget on the dashboard.
|
||||
"""
|
||||
return await get_sif_indexing_health(current_user)
|
||||
@app.get("/api/seo-dashboard/sif-health")
|
||||
async def sif_indexing_health_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
Get SIF indexing health summary for the current user.
|
||||
Used by the Semantic Indexing Status widget on the dashboard.
|
||||
"""
|
||||
return await get_sif_indexing_health(current_user)
|
||||
|
||||
# Comprehensive SEO Analysis endpoints
|
||||
@app.post("/api/seo-dashboard/analyze-comprehensive")
|
||||
async def analyze_seo_comprehensive_endpoint(request: SEOAnalysisRequest):
|
||||
"""Analyze a URL for comprehensive SEO performance."""
|
||||
return await analyze_seo_comprehensive(request)
|
||||
# Comprehensive SEO Analysis endpoints
|
||||
@app.post("/api/seo-dashboard/analyze-comprehensive")
|
||||
async def analyze_seo_comprehensive_endpoint(request: SEOAnalysisRequest):
|
||||
"""Analyze a URL for comprehensive SEO performance."""
|
||||
return await analyze_seo_comprehensive(request)
|
||||
|
||||
@app.post("/api/seo-dashboard/analyze-full")
|
||||
async def analyze_seo_full_endpoint(request: SEOAnalysisRequest):
|
||||
"""Analyze a URL for comprehensive SEO performance."""
|
||||
return await analyze_seo_full(request)
|
||||
@app.post("/api/seo-dashboard/analyze-full")
|
||||
async def analyze_seo_full_endpoint(request: SEOAnalysisRequest):
|
||||
"""Analyze a URL for comprehensive SEO performance."""
|
||||
return await analyze_seo_full(request)
|
||||
|
||||
@app.get("/api/seo-dashboard/metrics-detailed")
|
||||
async def seo_metrics_detailed(url: str):
|
||||
"""Get detailed SEO metrics for a URL."""
|
||||
return await get_seo_metrics_detailed(url)
|
||||
@app.get("/api/seo-dashboard/metrics-detailed")
|
||||
async def seo_metrics_detailed(url: str):
|
||||
"""Get detailed SEO metrics for a URL."""
|
||||
return await get_seo_metrics_detailed(url)
|
||||
|
||||
@app.get("/api/seo-dashboard/analysis-summary")
|
||||
async def seo_analysis_summary(url: str):
|
||||
"""Get a quick summary of SEO analysis for a URL."""
|
||||
return await get_analysis_summary(url)
|
||||
@app.get("/api/seo-dashboard/analysis-summary")
|
||||
async def seo_analysis_summary(url: str):
|
||||
"""Get a quick summary of SEO analysis for a URL."""
|
||||
return await get_analysis_summary(url)
|
||||
|
||||
@app.post("/api/seo-dashboard/batch-analyze")
|
||||
async def batch_analyze_urls_endpoint(urls: list[str]):
|
||||
"""Analyze multiple URLs in batch."""
|
||||
return await batch_analyze_urls(urls)
|
||||
@app.post("/api/seo-dashboard/batch-analyze")
|
||||
async def batch_analyze_urls_endpoint(urls: list[str]):
|
||||
"""Analyze multiple URLs in batch."""
|
||||
return await batch_analyze_urls(urls)
|
||||
|
||||
@app.post("/api/seo-dashboard/analyze-urls-ai")
|
||||
async def analyze_urls_ai_endpoint(request: AnalyzeURLsRequest, current_user: dict = Depends(get_current_user)):
|
||||
"""Run AI-powered SEO analysis on selected URLs."""
|
||||
return await analyze_urls_ai(request, current_user)
|
||||
@app.post("/api/seo-dashboard/analyze-urls-ai")
|
||||
async def analyze_urls_ai_endpoint(request: AnalyzeURLsRequest, current_user: dict = Depends(get_current_user)):
|
||||
"""Run AI-powered SEO analysis on selected URLs."""
|
||||
return await analyze_urls_ai(request, current_user)
|
||||
|
||||
# Include platform analytics router
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
# Include Bing Analytics Storage router to expose storage-backed endpoints
|
||||
from routers.bing_analytics_storage import router as bing_analytics_storage_router
|
||||
app.include_router(bing_analytics_storage_router)
|
||||
app.include_router(images_router)
|
||||
app.include_router(image_studio_router)
|
||||
app.include_router(product_marketing_router)
|
||||
app.include_router(campaign_creator_router)
|
||||
if images_router:
|
||||
app.include_router(images_router)
|
||||
if image_studio_router:
|
||||
app.include_router(image_studio_router)
|
||||
if product_marketing_router:
|
||||
app.include_router(product_marketing_router)
|
||||
if campaign_creator_router:
|
||||
app.include_router(campaign_creator_router)
|
||||
|
||||
# Include content assets router
|
||||
from api.content_assets.router import router as content_assets_router
|
||||
@@ -509,24 +668,38 @@ if not PODCAST_ONLY_DEMO_MODE:
|
||||
else:
|
||||
router_group_status["platform_extensions"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
"reason": "Skipped in feature-only mode",
|
||||
}
|
||||
|
||||
# Include Podcast Maker router
|
||||
from api.podcast.router import router as podcast_router
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Always mounted",
|
||||
}
|
||||
# Include Podcast Maker router (only when podcast feature is enabled)
|
||||
if _is_feature_enabled("podcast") and "all" not in get_enabled_features():
|
||||
from api.podcast.router import router as podcast_router
|
||||
logger.info(f"[ROUTER] Including podcast_router")
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Podcast feature enabled",
|
||||
}
|
||||
elif "all" in get_enabled_features():
|
||||
# In full mode, podcast is loaded via optional router registry
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Full mode (loaded via registry)",
|
||||
}
|
||||
else:
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": False,
|
||||
"reason": "Podcast feature not enabled",
|
||||
}
|
||||
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
if _is_full_mode():
|
||||
# Include YouTube Creator Studio router
|
||||
from api.youtube.router import router as youtube_router
|
||||
app.include_router(youtube_router, prefix="/api")
|
||||
|
||||
# Include research configuration router
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
if research_config_router:
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
|
||||
# Include Research Engine router (standalone AI research module)
|
||||
from api.research.router import router as research_engine_router
|
||||
@@ -535,7 +708,8 @@ if not PODCAST_ONLY_DEMO_MODE:
|
||||
# Scheduler dashboard routes
|
||||
from api.scheduler_dashboard import router as scheduler_dashboard_router
|
||||
app.include_router(scheduler_dashboard_router)
|
||||
app.include_router(oauth_token_monitoring_router)
|
||||
if oauth_token_monitoring_router:
|
||||
app.include_router(oauth_token_monitoring_router)
|
||||
|
||||
# Autonomous Agents API routes (Phase 3A)
|
||||
from api.agents_api import router as agents_router
|
||||
@@ -551,7 +725,7 @@ if not PODCAST_ONLY_DEMO_MODE:
|
||||
else:
|
||||
router_group_status["advanced_workflows"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
"reason": "Skipped in feature-only mode",
|
||||
}
|
||||
|
||||
# Setup frontend serving using modular utilities
|
||||
@@ -563,21 +737,38 @@ async def serve_frontend():
|
||||
"""Serve the React frontend."""
|
||||
return frontend_serving.serve_frontend()
|
||||
|
||||
# Startup event
|
||||
# Startup event - fires AFTER port is bound
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
"""Initialize services on startup."""
|
||||
import time
|
||||
startup_start = time.time()
|
||||
|
||||
logger.info("[STARTUP] Server port bound, beginning background initialization...")
|
||||
|
||||
try:
|
||||
startup_report = run_startup_health_routine(app)
|
||||
if startup_report.get("status") != "healthy":
|
||||
logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}")
|
||||
_log_memory_usage()
|
||||
|
||||
# Note: Pricing is initialized per-user in services/database.py:init_user_database()
|
||||
# which runs on first database access for each user. No global seeding needed at startup.
|
||||
|
||||
enabled_features = get_enabled_features()
|
||||
is_single_mode = "all" not in enabled_features
|
||||
|
||||
# Skip startup health checks in feature-only modes to avoid unnecessary DB errors
|
||||
if _is_full_mode():
|
||||
startup_report = run_startup_health_routine(app)
|
||||
if startup_report.get("status") != "healthy":
|
||||
logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}")
|
||||
else:
|
||||
logger.info(f"[FEATURE-MODE] Skipping startup health routine (features: {enabled_features})")
|
||||
|
||||
# Start task scheduler only if NOT in podcast-only mode
|
||||
if not is_podcast_only_demo_mode():
|
||||
# Start task scheduler only in full mode
|
||||
if _is_full_mode():
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().start()
|
||||
else:
|
||||
logger.info("[Podcast] Skipping scheduler startup (podcast-only mode)")
|
||||
logger.info(f"[FEATURE-MODE] Skipping scheduler startup (features: {enabled_features})")
|
||||
|
||||
# Check Wix API key configuration
|
||||
wix_api_key = os.getenv('WIX_API_KEY')
|
||||
@@ -586,14 +777,18 @@ async def startup_event():
|
||||
else:
|
||||
logger.warning("⚠️ WIX_API_KEY not found in environment - Wix publishing may fail")
|
||||
|
||||
logger.info("ALwrity backend started successfully")
|
||||
elapsed = time.time() - startup_start
|
||||
logger.info(f"ALwrity backend started successfully in {elapsed:.1f}s")
|
||||
|
||||
# Critical router mount assertions for podcast-only demo mode
|
||||
# Critical router mount assertions for feature-only modes
|
||||
_assert_router_mounted("subscription")
|
||||
_assert_router_mounted("podcast")
|
||||
if _is_feature_enabled("podcast"):
|
||||
_assert_router_mounted("podcast")
|
||||
if _is_feature_enabled("blog_writer"):
|
||||
_assert_router_mounted("blog_writer")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during startup: {e}")
|
||||
raise
|
||||
# Don't raise - let the server start anyway
|
||||
|
||||
|
||||
def _assert_router_mounted(router_name: str) -> None:
|
||||
@@ -605,6 +800,7 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
router_path_indicators = {
|
||||
"subscription": ["/api/subscription/plans", "/api/subscription/preflight"],
|
||||
"podcast": ["/api/podcast/projects", "/api/podcast/"],
|
||||
"blog_writer": ["/api/blog/health", "/api/blog/research/start"],
|
||||
}
|
||||
|
||||
expected_paths = router_path_indicators.get(router_name, [])
|
||||
@@ -615,10 +811,9 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
else:
|
||||
error_msg = f"❌ CRITICAL: Router '{router_name}' is NOT mounted! Expected paths: {expected_paths}"
|
||||
logger.error(error_msg)
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
# In demo mode, podcast router MUST be mounted
|
||||
if router_name == "podcast":
|
||||
raise RuntimeError(error_msg)
|
||||
# In feature-only mode, only fail if the feature is expected
|
||||
if not _is_full_mode() and _is_feature_enabled(router_name):
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Shutdown event
|
||||
@app.on_event("shutdown")
|
||||
@@ -633,4 +828,19 @@ async def shutdown_event():
|
||||
close_database()
|
||||
logger.info("ALwrity backend shutdown successfully")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during shutdown: {e}")
|
||||
logger.error(f"Error during shutdown: {e}")
|
||||
|
||||
|
||||
# Add main block to allow running directly with: python app.py
# This also helps Gunicorn work correctly
if __name__ == "__main__":
    import uvicorn
    port = int(os.environ.get("PORT", "10000"))
    host = os.environ.get("HOST", "0.0.0.0")

    print(f"[app.py] ====================", flush=True)
    print(f"[app.py] DIRECT STARTUP", flush=True)
    print(f"[app.py] PORT={port}, HOST={host}", flush=True)
    print(f"[app.py] ====================", flush=True)

    uvicorn.run(app, host=host, port=port)
197
backend/docs/AGENT_FLAT_CONTEXT_REVIEW.md
Normal file
@@ -0,0 +1,197 @@
|
||||
# Agent Flat-File Context System Review

## Scope
This review documents the **current implementation** of ALwrity's onboarding flat-file context system and compares it to the proposed **Direct-to-File Virtual Shell (VFS)** model.

---

## 1) Present Implementation (What Exists Today)

### 1.1 Storage model
- Context is stored per user under:
  - `backend/workspace/workspace_<safe_user_id>/agent_context/`
- Files are JSON documents, one per onboarding domain:
  - `step2_website_analysis.json`
  - `step3_research_preferences.json`
  - `step4_persona_data.json`
  - `step5_integrations.json`
  - `context_manifest.json`
### 1.2 Writer and reader
- `AgentFlatContextStore` is the core component that:
  - sanitizes user IDs for path safety,
  - writes documents atomically (`tempfile` + `os.replace`),
  - sets restrictive file permissions (`0600` best effort),
  - generates structured `agent_summary` objects,
  - updates a manifest index of available documents.
- Data is loaded by direct file reads from the same class (`load_stepX_context_document`).
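
The atomic-write behaviour listed above (`tempfile` + `os.replace`, best-effort `0600`) can be sketched roughly as follows. This is a minimal illustration, not the code of `AgentFlatContextStore`; the function name `write_json_atomically` is hypothetical.

```python
import json
import os
import tempfile
from pathlib import Path


def write_json_atomically(target: Path, payload: dict) -> None:
    """Hypothetical helper: write payload so readers never see a partial file."""
    target.parent.mkdir(parents=True, exist_ok=True)
    # Create the temp file in the target directory so os.replace stays on one filesystem.
    fd, tmp_name = tempfile.mkstemp(dir=target.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)
            handle.flush()
            os.fsync(handle.fileno())
        try:
            os.chmod(tmp_name, 0o600)  # best effort; limited effect on some platforms
        except OSError:
            pass
        os.replace(tmp_name, target)  # atomically swap the new document into place
    except BaseException:
        # Clean up the temp file if anything failed before the rename.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
```

Writing to a sibling temp file and renaming it means a concurrent reader sees either the old document or the new one, never a half-written file.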

### 1.3 Read-path fallback chain
`SIFIntegrationService` uses a strict fallback sequence for onboarding context retrieval:
1. **flat file** (`AgentFlatContextStore`)
2. **database** (`WebsiteAnalysis`, `ResearchPreferences`, `PersonaData`, etc.)
3. **SIF semantic index** (`TxtaiIntelligenceService.search`)

Step 5 uses `flat_file -> sif_semantic`.
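
The fallback sequence above can be pictured as an ordered list of loaders that is walked until one returns data. The sketch below is an assumption about the shape of that logic, not the code of `SIFIntegrationService`; `load_with_fallback` and the loader callables are hypothetical names.

```python
from typing import Any, Callable, Optional

Loader = Callable[[str], Optional[dict[str, Any]]]


def load_with_fallback(user_id: str, sources: list[tuple[str, Loader]]) -> dict[str, Any]:
    """Return the first non-empty result, tagged with the source that produced it."""
    for source_name, loader in sources:
        try:
            data = loader(user_id)
        except Exception:
            # A failing source should not block the remaining fallbacks.
            data = None
        if data:
            return {"source": source_name, "data": data}
    return {"source": None, "data": None}
```

Under this sketch, steps 2 to 4 would pass three loaders in the order `flat_file`, `database`, `sif_semantic`, while Step 5 would pass only `flat_file` and `sif_semantic`.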

### 1.4 Producer flow (onboarding persistence)
`StepManagementService` persists canonical snapshots to flat context when onboarding steps are saved:
- Step 2 website analysis
- Step 3 research preferences (and later competitor-enriched refresh)
- Step 4 persona data
- Step 5 integrations

### 1.5 Context optimization currently implemented
- Sensitive-key redaction in nested payloads (`api_key`, `token`, `secret`, etc.).
- Size budgeting with trimming (`DEFAULT_MAX_BYTES = 300_000`) and trim metadata.
- Generated summaries include:
  - quick facts,
  - retrieval hints (high-signal terms and suggested agent queries),
  - domain-specific focus blocks.
- Document context includes audience, retrieval contract, journey stage, related documents, and context-window guidance.
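
As an illustration of the nested sensitive-key redaction mentioned above, a recursive pass over dicts and lists might look like this. The marker tuple and the `[REDACTED]` placeholder are assumptions; the review only names `api_key`, `token`, and `secret` as examples.

```python
from typing import Any

# Assumed marker list; the review names only these three explicitly.
SENSITIVE_MARKERS = ("api_key", "token", "secret")


def redact(value: Any) -> Any:
    """Return a copy of value with sensitive-looking keys masked at any depth."""
    if isinstance(value, dict):
        return {
            key: "[REDACTED]"
            if any(marker in str(key).lower() for marker in SENSITIVE_MARKERS)
            else redact(item)
            for key, item in value.items()
        }
    if isinstance(value, list):
        return [redact(item) for item in value]
    return value
```

The function builds new containers instead of mutating its input, so the caller's original payload is left untouched.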

---

## 2) Comparison vs Proposed Direct-to-File VFS

### Strong alignment
The current system already matches the proposal in important ways:
- **Direct-to-file persistence** instead of DB-backed retrieval for fast reads.
- **Manifest/index concept** (`context_manifest.json`) that can act like a precomputed path map.
- **Agent-first retrieval semantics** (summary-first contract and fallback policy).
- **Operational safety controls** (atomic writes, redaction, path sanitization).

### Gaps vs full virtual shell abstraction
The following pieces are not fully implemented as described in your proposed architecture:
- No explicit **virtual shell provider** (`IFileSystem`) exposing `ls/cat/grep/find` commands.
- No always-live, process-level **in-memory `Map<virtualPath, absolutePath>`** for path lookups.
- No native glob/query command layer for agent shell UX.
- Not currently **read-only enforced at API surface** (writes are intentionally allowed by onboarding services to refresh context).

---
## 3) Practical Recommendation: Incremental VFS Evolution

1. **Introduce a read-only VFS facade for agents**
    - Keep `AgentFlatContextStore` as the write path for trusted onboarding services.
    - Add `AgentContextVFS` read adapter exposing:
        - `ls(path)` from manifest,
        - `cat(path)` mapped to underlying JSON,
        - `find(glob)` on virtual keys,
        - `grep(query)` with path prefilter + stream scan.

2. **Promote manifest to a first-class path map**
    - Build and cache an in-memory map on service startup or first access.
    - Refresh map when manifest `updated_at` changes.

3. **Add explicit write policy boundaries**
    - Agent-facing interface: hard read-only (`EROFS`).
    - Internal system service interface: allow writes for onboarding synchronization.

4. **Metadata strategy for grep ranking**
    - Prioritize in order:
        1) `agent_summary.quick_facts`
        2) `agent_summary.retrieval_hints.high_signal_terms`
        3) `document_context.context_type` and `journey.stage`
        4) full `data` body
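
A read-only facade along the lines of recommendations 1 to 3 could be sketched as below. This is a hedged illustration only: the class name `ReadOnlyContextVFS`, the virtual path keys, and the assumed manifest shape (`{"documents": {"<virtual path>": "<file name>"}}`) are hypothetical and are not taken from the actual `context_manifest.json` format.

```python
import errno
import json
from pathlib import Path


class ReadOnlyContextVFS:
    """Hypothetical read adapter over one user's agent_context directory."""

    def __init__(self, context_dir: Path) -> None:
        self._dir = context_dir
        manifest_path = context_dir / "context_manifest.json"
        manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
        # Assumed manifest shape: {"documents": {"<virtual path>": "<file name>"}}.
        self._paths = dict(manifest.get("documents", {}))

    def ls(self) -> list[str]:
        """List the virtual paths known to the in-memory path map."""
        return sorted(self._paths)

    def cat(self, virtual_path: str) -> dict:
        """Load the JSON document behind a virtual path (KeyError if unknown)."""
        file_name = self._paths[virtual_path]
        return json.loads((self._dir / file_name).read_text(encoding="utf-8"))

    def write(self, virtual_path: str, data: dict) -> None:
        """Reject every write on the agent-facing surface."""
        raise OSError(errno.EROFS, "agent context is read-only", virtual_path)
```

Because lookups go through the cached path map, an agent can only reach files the manifest names, and any write attempt fails with `EROFS` regardless of file permissions.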

---

## 4) Response to the Metadata Header Question

> "Does your current `.txt` optimization include specific metadata headers (like YAML frontmatter) that the grep tool should prioritize?"

For this implementation, context is currently persisted as structured JSON (not `.txt` with YAML frontmatter). Equivalent high-value metadata already exists and should be prioritized for search/ranking:
- `context_type`
- `updated_at`
- `agent_summary.quick_facts`
- `agent_summary.retrieval_hints.high_signal_terms`
- `document_context.journey.stage`
- `document_context.related_documents`

If you later move to `.txt` transport files, mirror these as frontmatter fields to preserve retrieval quality.
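
One way to prioritize these metadata fields during search is to score a document by the highest-priority field in which the query appears. The sketch below is an assumption about how such ranking could work: the numeric weights, the exact field locations, and the function name `score_document` are hypothetical, and only the relative ordering follows this review.

```python
import json
from typing import Any


def _contains(value: Any, needle: str) -> bool:
    """Case-insensitive substring test over any JSON-serializable value."""
    return needle in json.dumps(value, ensure_ascii=False).lower()


def score_document(doc: dict[str, Any], query: str) -> int:
    """Higher score means the match came from a higher-priority field."""
    needle = query.lower()
    summary = doc.get("agent_summary", {})
    context = doc.get("document_context", {})
    # Assumed field locations; drop missing values so they cannot match.
    labels = [doc.get("context_type"), context.get("journey", {}).get("stage")]
    tiers = [
        (4, summary.get("quick_facts")),
        (3, summary.get("retrieval_hints", {}).get("high_signal_terms")),
        (2, [label for label in labels if label]),
        (1, doc.get("data")),
    ]
    for weight, value in tiers:
        if value and _contains(value, needle):
            return weight
    return 0
```

Returning on the first matching tier keeps the score a simple priority label; a fuller ranker might also fold in freshness from `updated_at`.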

---

## 5) Bottom line
Your current onboarding flat-file context implementation is already a strong "shim" architecture and close to the proposed model. The biggest missing piece is a dedicated virtual-shell read interface (`ls/cat/grep/find`) backed by a persistent path-map cache and a clear read-only contract for agent execution contexts.

---
## 6) Implemented Follow-up (VFS Adapter + Workspace Guide)

The following enhancements are now implemented:

1. **Auto-generated workspace map**
    - The system now generates `workspace_<user>/README.md` whenever `context_manifest.json` is updated.
    - The README includes:
        - available context files,
        - key signal hints from `agent_summary.retrieval_hints.high_signal_terms`,
        - journey-stage hints,
        - virtual path mappings and retrieval strategy guidance.

2. **Read-only VFS facade**
    - Added `AgentContextVFS` with:
        - `list_context()` (`ls` equivalent),
        - `search_context()` (`grep` equivalent; prioritizes `high_signal_terms` and `quick_facts`),
        - `read_context_file()` (`cat` equivalent; large-file summary mode + subkey drilldown),
        - explicit write rejection (`EROFS`).

3. **Virtual path support**
    - `/env/summary` maps to `AgentFlatContextStore.generate_total_summary()`.
    - `/steps/website`, `/steps/research`, `/steps/persona`, `/steps/integrations` map to step documents.

4. **System-prompt helper**
    - Added `build_filesystem_header(user_id)` to inject a compact file availability + priority hint block into agent startup prompts.

5. **Merged context helper in SIF integration**
    - `SIFIntegrationService.get_merged_flat_context()` now provides a unified view across all available flat files while preserving existing per-step retrieval methods.

6. **Basic file-level security hardening**
    - Workspace and context directories are now explicitly forced to `0700`.
    - Context and workspace files are written with strict `0600`.
    - Added path sandboxing to ensure requested paths cannot escape user workspace roots.
    - Restricted context-file loading to an allowlist of known onboarding context documents.
    - Added deterministic per-user secret derivation from `.env` (`FILE_ENCRYPTION_SALT` + `safe_user_id`) with non-sensitive fingerprints for audit/debug and future encryption-at-rest rollout.
7. **Tool-logic enhancement (coarse-to-fine search)**
    - `search_context` now performs a two-pass retrieval:
        1) high-relevance summary match pass (`high_signal_terms`, `quick_facts`),
        2) parallelized stream scan pass over sandboxed allowlisted files for supporting details.
    - Results include relevance labels, snippets, and line numbers for body matches.
    - Large-result behavior now reports truncation guidance (show top 10 and suggest narrower keywords).
    - `inspect_file` now provides token-saving behavior: full return for small files, or `agent_summary` + top-level keys for larger files, with key-level zoom-in support.

8. **Retrieval robustness roadmap (next hardening phase)**
    - **Query normalization:** Add synonym expansion and typo-tolerant matching (e.g., `tone` ≈ `brand voice`) before coarse/fine passes.
    - **Confidence scoring:** Return confidence tiers that blend source freshness (`updated_at`), summary-match strength, and match density.
    - **Field-aware boosting:** Weight matches by field priority (`high_signal_terms` > `quick_facts` > `data`) and document recency.
    - **Deduplicated evidence:** Collapse repeated hits from the same file/key into one clustered result with a single best snippet and hit count.
    - **Fallback query reformulation:** If zero hits, automatically retry with narrow/expanded variants and return attempted queries.
    - **Answerability contract:** Add a lightweight `can_answer` signal in search responses so orchestrators can decide whether to ask follow-up questions or fetch more context.
    - **Evaluation harness:** Track retrieval metrics over golden queries (`precision@k`, `MRR`, zero-hit rate, stale-hit rate) in CI to prevent relevance regressions.

9. **Collaborative VFS namespace (shared memory mode)**
    - Added optional `project_id` support to `AgentContextVFS` with isolated root: `workspace/project_<project_id>/`.
    - Introduced `scratchpad/` for collaborative writes while keeping onboarding `agent_context` read-first.
    - Added `write_shared_note(...)` with advisory locking (`flock`) and strict filename/path validation.
    - Added append-only `activity_log.jsonl` via `append_activity_log(...)` for watchdog/event-driven coordination.
    - Maintains owner-only permissions (`0700` scratchpad dir, `0600` files) and audit trails for shared writes.

10. **Testing readiness upgrades**
    - Added automated tests for:
        - query reformulation + `can_answer` behavior in `search_context`,
        - large-file progressive disclosure behavior in `inspect_file`,
        - collaborative write path (`write_shared_note`) and append-only activity logging.
    - Test module: `backend/tests/test_agent_context_vfs.py`.
    - These tests provide a baseline regression harness for VFS retrieval quality and shared-memory safety.

11. **Static + Structural retrieval hardening**
    - Added a **static triage layer** in `search_context`:
        - keyword-density scoring,
        - `low_probability` flags for likely-noisy hits,
        - `triage_top5` shortlist for router-style pre-filtering.
    - Added `read_struct(filename, path_query)`:
        - resolves dot/bracket JSON paths to return node-level data only,
        - includes lightweight dependency injection (e.g., Step 4 persona reads include Step 2 brand voice context when available),
        - keeps output token-efficient for downstream agents.
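
To make the dot/bracket path resolution in item 11 concrete, here is a minimal sketch of how a query such as `data.topics[0].title` could be walked over a loaded JSON document. It is not the implementation of `read_struct`; the helper name `resolve_path` and the accepted path grammar are assumptions.

```python
import re
from typing import Any

# One token is either a bare key (no dots or brackets) or a numeric index like [0].
_TOKEN = re.compile(r"([^.\[\]]+)|\[(\d+)\]")


def resolve_path(document: Any, path_query: str) -> Any:
    """Follow a dot/bracket path and return only the addressed node."""
    node = document
    for match in _TOKEN.finditer(path_query):
        key, index = match.group(1), match.group(2)
        if index is not None:
            node = node[int(index)]  # list index, e.g. [0]
        else:
            node = node[key]  # dict key, e.g. data
    return node
```

A missing key or out-of-range index raises `KeyError` or `IndexError`; a production version would likely translate those into a structured "not found" response for the agent.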
|
||||
1
backend/emojis.txt
Normal file
@@ -0,0 +1 @@
|
||||
{'🎙', '🛑', '🚀', '📖', '💳', '📈', '🌐', '📊', '📦', '🔧', '🔍'}
|
||||
46
backend/gunicorn_config.py
Normal file
@@ -0,0 +1,46 @@
"""Gunicorn configuration for Render deployment."""
import os
import multiprocessing

# Bind to the port Render provides
bind = f"0.0.0.0:{os.getenv('PORT', '10000')}"

# Use uvicorn workers
worker_class = "uvicorn.workers.UvicornWorker"

# Single worker for memory efficiency on free tier
workers = 1

# Timeout for slow startup (10 minutes to allow for model loading)
timeout = 600

# Graceful timeout
graceful_timeout = 30

# Keepalive
keepalive = 5

# Logging
accesslog = "-"
errorlog = "-"
loglevel = os.getenv("LOG_LEVEL", "info").lower()

# Don't preload - bind to port FIRST, then load worker
preload_app = False

# Use the startup script that handles all the logic
factory = False  # app:app is not a factory, it's the app object

def on_starting(server):
    """Called just before the master process is initialized."""
    print(f"[GUNICORN] Starting on {bind}", flush=True)


def on_reload(server):
    """Called when worker is reloaded."""
    print(f"[GUNICORN] Reloading workers", flush=True)


def when_ready(server):
    """Called just after the server is started."""
    print(f"[GUNICORN] Server is ready. Accepting connections.", flush=True)
@@ -252,6 +252,8 @@ router_manager.include_core_routers()
|
||||
# Safety net: keep subscription routes available even if core inclusion flow changes
|
||||
# in special modes (e.g., demo mode). De-dup is handled by RouterManager.
|
||||
router_manager.include_router_safely(subscription_router, "subscription")
|
||||
# Include hallucination detector explicitly (router_manager may skip silently on import failure)
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
router_manager.include_optional_routers()
|
||||
|
||||
# SEO Dashboard endpoints
|
||||
|
||||
@@ -45,6 +45,9 @@ class PodcastProject(Base):
|
||||
knobs = Column(JSON, nullable=True) # Knobs settings
|
||||
research_provider = Column(String(50), nullable=True, default="google") # Research provider
|
||||
|
||||
# Project-specific topic context (category research, selected topics)
|
||||
topic_context = Column(JSON, nullable=True) # { category: "news"|"finance", topics: [...], selected_topic: {...} }
|
||||
|
||||
# UI state
|
||||
show_script_editor = Column(Boolean, default=False)
|
||||
show_render_queue = Column(Boolean, default=False)
|
||||
|
||||
@@ -80,6 +80,7 @@ class SubscriptionPlan(Base):
|
||||
video_calls_limit = Column(Integer, default=0) # AI video generation
|
||||
image_edit_calls_limit = Column(Integer, default=0) # AI image editing
|
||||
audio_calls_limit = Column(Integer, default=0) # AI audio generation (text-to-speech)
|
||||
wavespeed_calls_limit = Column(Integer, default=0) # WaveSpeed API calls (LLM + TTS + video + image)
|
||||
|
||||
# Token Limits (for LLM providers)
|
||||
gemini_tokens_limit = Column(Integer, default=0)
|
||||
|
||||
@@ -1,9 +1,43 @@
#!/usr/bin/env bash
set -euo pipefail

python -m pip install --upgrade pip setuptools wheel
python -m pip install --retries 10 --timeout 120 -r requirements.txt
echo "🚀 Starting ALwrity Build Process..."

# Download required NLTK and spaCy models during build phase
python -m spacy download en_core_web_sm
python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
# 1. Update pip and essential build tools
python -m pip install --upgrade pip setuptools wheel

# 2. Install requirements based on mode
echo "📦 Checking ALWRITY_ENABLED_FEATURES..."
ENABLED_FEATURES="${ALWRITY_ENABLED_FEATURES:-all}"
echo "DEBUG: ENABLED_FEATURES='$ENABLED_FEATURES'"

case "$ENABLED_FEATURES" in
  all)
    echo "📦 Full mode: Installing all requirements..."
    python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
    # Download spaCy/NLTK models for full mode
    echo "🧠 Installing spaCy and NLTK models..."
    python -m spacy download en_core_web_sm
    python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
    ;;
  podcast)
    echo "🔊 Podcast-only mode: Installing lean requirements..."
    python -m pip install --no-cache-dir -r requirements-podcast.txt --only-binary :all: --retries 10 --timeout 120
    ;;
  *)
    echo "🎯 Feature-limited mode ($ENABLED_FEATURES): Installing requirements..."
    req_file="requirements-${ENABLED_FEATURES}.txt"
    if [[ -f "$req_file" ]]; then
      python -m pip install --no-cache-dir -r "$req_file" --only-binary :all: --retries 10 --timeout 120
    else
      echo "⚠️ No feature-specific requirements file found ($req_file), installing full requirements..."
      python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
    fi
    ;;
esac

# 3. Clean up unnecessary build artifacts
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
rm -rf /root/.cache/pip 2>/dev/null || true

echo "✅ Build Complete!"

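The mode dispatch in the build script above can be summarized in a few lines of Python. This is only an illustrative sketch and not part of the change: `select_requirements_file` and its `available_files` parameter are hypothetical names, with `available_files` standing in for the script's `[[ -f "$req_file" ]]` check.

```python
from typing import Iterable


def select_requirements_file(enabled_features: str, available_files: Iterable[str]) -> str:
    """Return the requirements file the build script's case statement would install.

    Hypothetical helper for illustration only; `available_files` plays the role
    of the file-existence test in the shell script.
    """
    files = set(available_files)
    if enabled_features == "all":
        return "requirements.txt"
    if enabled_features == "podcast":
        # The podcast branch installs this file without checking that it exists.
        return "requirements-podcast.txt"
    candidate = f"requirements-{enabled_features}.txt"
    # Any other mode falls back to the full requirements when no file matches.
    return candidate if candidate in files else "requirements.txt"


if __name__ == "__main__":
    present = {"requirements.txt", "requirements-podcast.txt"}
    for mode in ("all", "podcast", "seo"):
        print(mode, "->", select_requirements_file(mode, present))
```

In the shell script the mode itself comes from `ALWRITY_ENABLED_FEATURES`, which defaults to `all` when unset or empty.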
82
backend/requirements-podcast.txt
Normal file
@@ -0,0 +1,82 @@
# =====================================================
# ALwrity Podcast-Only Requirements
# Lean subset for podcast-only demo mode
# =====================================================

# Core Web Server
fastapi>=0.115.14
starlette>=0.40.0,<0.47.0
sse-starlette<3.0.0
uvicorn>=0.24.0
uvicorn[standard]>=0.24.0
gunicorn>=21.0.0

# Server utilities
python-multipart>=0.0.6
python-dotenv>=1.0.0
loguru>=0.7.2
tenacity>=8.2.3
pydantic>=2.5.2,<3.0.0
typing-extensions>=4.8.0
setuptools>=65.0.0

# Auth & Database
fastapi-clerk-auth>=0.0.7
sqlalchemy>=2.0.25

# Payment
stripe>=8.0.0

# HTTP clients
httpx>=0.28.1
aiohttp>=3.9.0
requests>=2.31.0

# AI - needed for podcast
openai>=1.3.0
google-genai>=1.0.0
exa-py==1.9.1

# Text processing (minimal)
markdown>=3.5.0
beautifulsoup4>=4.12.0

# Data processing (numpy needed for moviepy, pandas for usage tracking)
numpy>=1.24.0
pandas>=2.0.0

# Image/media for podcast
Pillow>=10.0.0
matplotlib>=3.7.0
huggingface_hub>=1.1.4

# TTS for podcast
gtts>=2.4.0
pyttsx3>=2.90

# Video composition
moviepy==2.1.2
imageio>=2.31.0
imageio-ffmpeg>=0.4.9

# Testing
pytest>=7.4.0
pytest-asyncio>=0.21.0

# Task scheduling
apscheduler>=3.10.0

# Utilities
redis>=5.0.0
schedule>=1.2.0
aiofiles>=23.2.0
psutil>=5.9.0

# Google APIs
google-api-python-client>=2.100.0
google-auth>=2.23.0
google-auth-oauthlib>=1.0.0

# Other utilities
python-dateutil>=2.8.0
jinja2>=3.1.0
@@ -1,93 +1,81 @@
# Core dependencies
# Core dependencies - needed for all modes
fastapi>=0.115.14
starlette>=0.40.0,<0.47.0
sse-starlette<3.0.0
uvicorn>=0.24.0
uvicorn[standard]>=0.24.0
gunicorn>=21.0.0
python-multipart>=0.0.6
python-dotenv>=1.0.0
loguru>=0.7.2
tenacity>=8.2.3
pydantic>=2.5.2,<3.0.0
typing-extensions>=4.8.0

# Authentication and security
# Auth
PyJWT>=2.8.0
cryptography>=41.0.0
fastapi-clerk-auth>=0.0.7

# Database dependencies
# Database
sqlalchemy>=2.0.25

# Payment processing
# Payment
stripe>=8.0.0

# CopilotKit and Research
copilotkit
exa-py==1.9.1
httpx>=0.27.2,<0.28.0
# HTTP clients
httpx>=0.28.1
aiohttp>=3.9.0
requests>=2.31.0

# AI/ML dependencies - Windows-compatible versions
# AI - needed for podcast
openai>=1.3.0
google-genai>=1.0.0
sentence-transformers>=2.2.2
exa-py==1.9.1

# txtai with Windows-compatible dependencies
txtai[agent]>=7.0.0


google-api-python-client>=2.100.0
google-auth>=2.23.0
google-auth-oauthlib>=1.0.0

# Web scraping and content processing
# Text processing
markdown>=3.5.0
beautifulsoup4>=4.12.0
requests>=2.31.0
urllib3<2.0.0
chardet>=5.0.0
charset-normalizer<3.0.0
lxml>=4.9.0
html5lib>=1.1
aiohttp>=3.9.0
advertools>=0.14.0

# Data processing
pandas>=2.0.0
numpy>=1.24.0
markdown>=3.5.0

# SEO Analysis dependencies
advertools>=0.14.0
textstat>=0.7.3
pyspellchecker>=0.7.2
aiofiles>=23.2.0
crawl4ai>=0.2.0

# Linguistic Analysis dependencies (Required for persona generation)
spacy>=3.7.0
nltk>=3.8.0

# Image and audio processing for Stability AI
# Image/media for podcast
Pillow>=10.0.0
matplotlib>=3.8.0
huggingface_hub>=1.1.4

# Text-to-Speech (TTS) dependencies
# TTS for podcast
gtts>=2.4.0
pyttsx3>=2.90

# Video composition dependencies
# Video composition
moviepy==2.1.2
imageio>=2.31.0
imageio-ffmpeg>=0.4.9

# Testing dependencies
# Testing
pytest>=7.4.0
pytest-asyncio>=0.21.0

# Utilities
pydantic>=2.5.2,<3.0.0
typing-extensions>=4.8.0

# Task scheduling
apscheduler>=3.10.0

# Optional dependencies (for enhanced features)
# Utilities
redis>=5.0.0
schedule>=1.2.0
pytrends>=4.9.0
schedule>=1.2.0
aiofiles>=23.2.0
psutil>=5.9.0

# Google APIs
google-api-python-client>=2.100.0
google-auth>=2.23.0
google-auth-oauthlib>=1.0.0

# Other utilities
python-dateutil>=2.8.0
jinja2>=3.1.0
pydantic-settings>=2.0.0

File diff suppressed because it is too large
34
backend/routers/image_studio/__init__.py
Normal file
@@ -0,0 +1,34 @@
"""Image Studio API router package.

Composed from modular sub-routers. Same prefix and tags as the original monolithic file.
"""

from fastapi import APIRouter

from .health import router as health_router
from .upscale import router as upscale_router
from .control import router as control_router
from .social import router as social_router
from .edit import router as edit_router
from .face_swap import router as face_swap_router
from .create import router as create_router
from .transform import router as transform_router
from .compress import router as compress_router
from .convert import router as convert_router
from .save import router as save_router

router = APIRouter(prefix="/api/image-studio", tags=["image-studio"])

router.include_router(health_router)
router.include_router(upscale_router)
router.include_router(control_router)
router.include_router(social_router)
router.include_router(edit_router)
router.include_router(face_swap_router)
router.include_router(create_router)
router.include_router(transform_router)
router.include_router(compress_router)
router.include_router(convert_router)
router.include_router(save_router)

__all__ = ["router"]
158
backend/routers/image_studio/compress.py
Normal file
@@ -0,0 +1,158 @@
"""Compression Studio endpoints."""

from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException

from .models import (
    CompressImageRequest, CompressImageResponse,
    CompressBatchRequest, CompressBatchResponse,
    CompressionEstimateRequest, CompressionEstimateResponse,
    CompressionFormatsResponse, CompressionPresetsResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/compress", response_model=CompressImageResponse, summary="Compress an image")
async def compress_image(
    request: CompressImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Compress an image with specified quality and format settings."""
    try:
        user_id = _require_user_id(current_user, "image compression")
        logger.info(f"[Compression] Request from user {user_id}: format={request.format}, quality={request.quality}")

        from services.image_studio.compression_service import CompressionRequest as ServiceRequest

        compression_request = ServiceRequest(
            image_base64=request.image_base64,
            quality=request.quality,
            format=request.format,
            target_size_kb=request.target_size_kb,
            strip_metadata=request.strip_metadata,
            progressive=request.progressive,
            optimize=request.optimize,
        )

        result = await studio_manager.compress_image(compression_request, user_id=user_id)

        return CompressImageResponse(
            success=result.success,
            image_base64=result.image_base64,
            original_size_kb=result.original_size_kb,
            compressed_size_kb=result.compressed_size_kb,
            compression_ratio=result.compression_ratio,
            format=result.format,
            width=result.width,
            height=result.height,
            quality_used=result.quality_used,
            metadata_stripped=result.metadata_stripped,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Compression] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Image compression failed: {e}")


@router.post("/compress/batch", response_model=CompressBatchResponse, summary="Compress multiple images")
async def compress_batch(
    request: CompressBatchRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Compress multiple images with the same or individual settings."""
    try:
        user_id = _require_user_id(current_user, "batch compression")
        logger.info(f"[Compression] Batch request from user {user_id}: {len(request.images)} images")

        from services.image_studio.compression_service import CompressionRequest as ServiceRequest

        compression_requests = [
            ServiceRequest(
                image_base64=img.image_base64,
                quality=img.quality,
                format=img.format,
                target_size_kb=img.target_size_kb,
                strip_metadata=img.strip_metadata,
                progressive=img.progressive,
                optimize=img.optimize,
            )
            for img in request.images
        ]

        results = await studio_manager.compress_batch(compression_requests, user_id=user_id)

        successful = sum(1 for r in results if r.success)
        failed = len(results) - successful

        return CompressBatchResponse(
            success=failed == 0,
            results=[
                CompressImageResponse(
                    success=r.success,
                    image_base64=r.image_base64,
                    original_size_kb=r.original_size_kb,
                    compressed_size_kb=r.compressed_size_kb,
                    compression_ratio=r.compression_ratio,
                    format=r.format,
                    width=r.width,
                    height=r.height,
                    quality_used=r.quality_used,
                    metadata_stripped=r.metadata_stripped,
                )
                for r in results
            ],
            total_images=len(results),
            successful=successful,
            failed=failed,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Compression] ❌ Batch error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Batch compression failed: {e}")


@router.post("/compress/estimate", response_model=CompressionEstimateResponse, summary="Estimate compression results")
async def estimate_compression(
    request: CompressionEstimateRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Estimate compression results without actually compressing the image."""
    try:
        result = await studio_manager.estimate_compression(
            request.image_base64,
            request.format,
            request.quality,
        )
        return CompressionEstimateResponse(**result)
    except Exception as e:
        logger.error(f"[Compression] ❌ Estimate error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Compression estimation failed: {e}")


@router.get("/compress/formats", response_model=CompressionFormatsResponse, summary="Get supported compression formats")
async def get_compression_formats(
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get list of supported compression formats with their capabilities."""
    formats = studio_manager.get_compression_formats()
    return CompressionFormatsResponse(formats=formats)


@router.get("/compress/presets", response_model=CompressionPresetsResponse, summary="Get compression presets")
async def get_compression_presets(
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get predefined compression presets for common use cases."""
    presets = studio_manager.get_compression_presets()
    return CompressionPresetsResponse(presets=presets)
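The batch endpoint in `compress.py` reports overall success only when every image succeeds. A minimal, framework-free sketch of that tally follows; `Result` and `summarize_batch` are hypothetical stand-ins for the service result objects and the response fields, not names from this change.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Result:
    """Hypothetical stand-in for a per-image service result."""
    success: bool


def summarize_batch(results: List[Result]) -> Dict[str, object]:
    """Compute the summary fields that the batch response carries."""
    successful = sum(1 for r in results if r.success)
    failed = len(results) - successful
    return {
        "success": failed == 0,  # a single failure marks the whole batch unsuccessful
        "total_images": len(results),
        "successful": successful,
        "failed": failed,
    }


if __name__ == "__main__":
    print(summarize_batch([Result(True), Result(False), Result(True)]))
```

Under this rule an empty result list also counts as success, because `failed` is 0.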
64
backend/routers/image_studio/control.py
Normal file
@@ -0,0 +1,64 @@
"""Control Studio endpoints."""

from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException

from .models import ControlImageRequest, ControlImageResponse, ControlOperationsResponse
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager, ControlStudioRequest
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/control/process", response_model=ControlImageResponse, summary="Process Control Studio request")
async def process_control_image(
    request: ControlImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Perform Control Studio operations such as sketch-to-image, structure control, style control, and style transfer."""
    try:
        user_id = _require_user_id(current_user, "image control")
        logger.info(f"[Control Image] Request from user {user_id}: operation={request.operation}")

        control_request = ControlStudioRequest(
            operation=request.operation,
            prompt=request.prompt,
            control_image_base64=request.control_image_base64,
            style_image_base64=request.style_image_base64,
            negative_prompt=request.negative_prompt,
            control_strength=request.control_strength,
            fidelity=request.fidelity,
            style_strength=request.style_strength,
            composition_fidelity=request.composition_fidelity,
            change_strength=request.change_strength,
            aspect_ratio=request.aspect_ratio,
            style_preset=request.style_preset,
            seed=request.seed,
            output_format=request.output_format,
        )

        result = await studio_manager.control_image(control_request, user_id=user_id)
        return ControlImageResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Control Image] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Image control failed: {e}")


@router.get("/control/operations", response_model=ControlOperationsResponse, summary="List Control Studio operations")
async def get_control_operations(
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Return metadata for supported Control Studio operations."""
    try:
        operations = studio_manager.get_control_operations()
        return ControlOperationsResponse(operations=operations)
    except Exception as e:
        logger.error(f"[Control Operations] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load control operations")
143
backend/routers/image_studio/convert.py
Normal file
@@ -0,0 +1,143 @@
"""Format Converter endpoints."""

from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query

from .models import (
    ConvertFormatRequest, ConvertFormatResponse,
    ConvertFormatBatchRequest, ConvertFormatBatchResponse,
    SupportedFormatsResponse, FormatRecommendationsResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/convert-format", response_model=ConvertFormatResponse, summary="Convert image format")
async def convert_format(
    request: ConvertFormatRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Convert an image to a different format."""
    try:
        user_id = _require_user_id(current_user, "format conversion")
        logger.info(f"[Format Converter] Request from user {user_id}: {request.target_format}")

        from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest

        conversion_request = ServiceRequest(
            image_base64=request.image_base64,
            target_format=request.target_format,
            preserve_transparency=request.preserve_transparency,
            quality=request.quality,
            color_space=request.color_space,
            strip_metadata=request.strip_metadata,
            optimize=request.optimize,
            progressive=request.progressive,
        )

        result = await studio_manager.convert_format(conversion_request, user_id=user_id)

        return ConvertFormatResponse(
            success=result.success,
            image_base64=result.image_base64,
            original_format=result.original_format,
            target_format=result.target_format,
            original_size_kb=result.original_size_kb,
            converted_size_kb=result.converted_size_kb,
            width=result.width,
            height=result.height,
            transparency_preserved=result.transparency_preserved,
            metadata_preserved=result.metadata_preserved,
            color_space=result.color_space,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Format Converter] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Format conversion failed: {e}")


@router.post("/convert-format/batch", response_model=ConvertFormatBatchResponse, summary="Convert multiple images")
async def convert_format_batch(
    request: ConvertFormatBatchRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Convert multiple images to different formats."""
    try:
        user_id = _require_user_id(current_user, "batch format conversion")
        logger.info(f"[Format Converter] Batch request from user {user_id}: {len(request.images)} images")

        from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest

        conversion_requests = [
            ServiceRequest(
                image_base64=img.image_base64,
                target_format=img.target_format,
                preserve_transparency=img.preserve_transparency,
                quality=img.quality,
                color_space=img.color_space,
                strip_metadata=img.strip_metadata,
                optimize=img.optimize,
                progressive=img.progressive,
            )
            for img in request.images
        ]

        results = await studio_manager.convert_format_batch(conversion_requests, user_id=user_id)

        successful = sum(1 for r in results if r.success)
        failed = len(results) - successful

        return ConvertFormatBatchResponse(
            success=failed == 0,
            results=[
                ConvertFormatResponse(
                    success=r.success,
                    image_base64=r.image_base64,
                    original_format=r.original_format,
                    target_format=r.target_format,
                    original_size_kb=r.original_size_kb,
                    converted_size_kb=r.converted_size_kb,
                    width=r.width,
                    height=r.height,
                    transparency_preserved=r.transparency_preserved,
                    metadata_preserved=r.metadata_preserved,
                    color_space=r.color_space,
                )
                for r in results
            ],
            total_images=len(results),
            successful=successful,
            failed=failed,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Format Converter] ❌ Batch error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Batch format conversion failed: {e}")


@router.get("/convert-format/supported", response_model=SupportedFormatsResponse, summary="Get supported formats")
async def get_supported_formats(
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get list of supported conversion formats with their capabilities."""
    formats = studio_manager.get_supported_formats()
    return SupportedFormatsResponse(formats=formats)


@router.get("/convert-format/recommendations", response_model=FormatRecommendationsResponse, summary="Get format recommendations")
async def get_format_recommendations(
    source_format: str = Query(..., description="Source format"),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get format recommendations based on source format."""
    recommendations = studio_manager.get_format_recommendations(source_format)
    return FormatRecommendationsResponse(recommendations=recommendations)
231
backend/routers/image_studio/create.py
Normal file
@@ -0,0 +1,231 @@
|
||||
"""Create Studio, Templates, Providers, Cost Estimation, and Platform Specs endpoints."""
|
||||
|
||||
import base64
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import CreateImageRequest, CostEstimationRequest
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, CreateStudioRequest
|
||||
from services.image_studio.templates import Platform, TemplateCategory
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/create", summary="Generate Image")
|
||||
async def create_image(
|
||||
request: CreateImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Generate image(s) using Create Studio."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image generation")
|
||||
logger.info(f"[Create Image] Request from user {user_id}: {request.prompt[:100]}")
|
||||
|
||||
studio_request = CreateStudioRequest(
|
||||
prompt=request.prompt,
|
||||
template_id=request.template_id,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
aspect_ratio=request.aspect_ratio,
|
||||
style_preset=request.style_preset,
|
||||
quality=request.quality,
|
||||
negative_prompt=request.negative_prompt,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed,
|
||||
num_variations=request.num_variations,
|
||||
enhance_prompt=request.enhance_prompt,
|
||||
use_persona=request.use_persona,
|
||||
persona_id=request.persona_id,
|
||||
)
|
||||
|
||||
result = await studio_manager.create_image(studio_request, user_id=user_id)
|
||||
|
||||
for idx, img_result in enumerate(result["results"]):
|
||||
if "image_bytes" in img_result:
|
||||
img_result["image_base64"] = base64.b64encode(img_result["image_bytes"]).decode("utf-8")
|
||||
del img_result["image_bytes"]
|
||||
|
||||
logger.info(f"[Create Image] ✅ Success: {result['total_generated']} images generated")
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Create Image] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error(f"[Create Image] ❌ Generation error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Create Image] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/templates", summary="Get Templates")
|
||||
async def get_templates(
|
||||
platform: Optional[Platform] = None,
|
||||
category: Optional[TemplateCategory] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Get available image templates."""
|
||||
try:
|
||||
templates = studio_manager.get_templates(platform=platform, category=category)
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
return {"templates": templates_dict, "total": len(templates_dict)}
|
||||
except Exception as e:
|
||||
logger.error(f"[Get Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/templates/search", summary="Search Templates")
|
||||
async def search_templates(
|
||||
query: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Search templates by query."""
|
||||
try:
|
||||
templates = studio_manager.search_templates(query)
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
return {"templates": templates_dict, "total": len(templates_dict), "query": query}
|
||||
except Exception as e:
|
||||
logger.error(f"[Search Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/templates/recommend", summary="Recommend Templates")
|
||||
async def recommend_templates(
|
||||
use_case: str,
|
||||
platform: Optional[Platform] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Recommend templates based on use case."""
|
||||
try:
|
||||
templates = studio_manager.recommend_templates(use_case, platform=platform)
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
                "category": t.category.value,
                "platform": t.platform.value if t.platform else None,
                "aspect_ratio": {
                    "ratio": t.aspect_ratio.ratio,
                    "width": t.aspect_ratio.width,
                    "height": t.aspect_ratio.height,
                    "label": t.aspect_ratio.label,
                },
                "description": t.description,
                "recommended_provider": t.recommended_provider,
                "style_preset": t.style_preset,
                "quality": t.quality,
                "use_cases": t.use_cases or [],
            }
            for t in templates
        ]
        return {"templates": templates_dict, "total": len(templates_dict), "use_case": use_case}
    except Exception as e:
        logger.error(f"[Recommend Templates] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/providers", summary="Get Providers")
async def get_providers(
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get available AI providers and their capabilities."""
    try:
        providers = studio_manager.get_providers()
        return {"providers": providers}
    except Exception as e:
        logger.error(f"[Get Providers] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/estimate-cost", summary="Estimate Cost")
async def estimate_cost(
    request: CostEstimationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Estimate cost for image generation operations."""
    try:
        resolution = None
        if request.width and request.height:
            resolution = (request.width, request.height)
        estimate = studio_manager.estimate_cost(
            provider=request.provider,
            model=request.model,
            operation=request.operation,
            num_images=request.num_images,
            resolution=resolution
        )
        return estimate
    except Exception as e:
        logger.error(f"[Estimate Cost] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/platform-specs/{platform}", summary="Get Platform Specifications")
async def get_platform_specs(
    platform: Platform,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get specifications and requirements for a specific platform."""
    try:
        specs = studio_manager.get_platform_specs(platform)
        if not specs:
            raise HTTPException(status_code=404, detail=f"Specifications not found for platform: {platform}")
        return specs
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Get Platform Specs] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
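The `estimate_cost` handler forwards a `(width, height)` tuple only when both dimensions are set, and passes `None` otherwise. A minimal standalone sketch of that rule follows; `build_resolution` is a hypothetical helper used for illustration and is not part of this diff.

```python
from typing import Optional, Tuple


def build_resolution(width: Optional[int], height: Optional[int]) -> Optional[Tuple[int, int]]:
    """Return (width, height) only when both values are set and non-zero."""
    # Same truthiness check as the handler: None and 0 both count as "missing".
    if width and height:
        return (width, height)
    return None


print(build_resolution(1024, 768))   # both dimensions present
print(build_resolution(1024, None))  # height missing, so no resolution is sent
```

Because the check relies on truthiness, a width or height of `0` is treated the same as an omitted value.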
35
backend/routers/image_studio/deps.py
Normal file
@@ -0,0 +1,35 @@
"""Shared dependencies for Image Studio API endpoints."""

from typing import Dict, Any
from fastapi import Depends, HTTPException, status

from services.image_studio import ImageStudioManager
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")


def get_studio_manager() -> ImageStudioManager:
    """Get Image Studio Manager instance."""
    return ImageStudioManager()


def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
    """Ensure user_id is available for protected operations."""
    user_id = (
        current_user.get("sub")
        or current_user.get("user_id")
        or current_user.get("id")
        or current_user.get("clerk_user_id")
    )
    if not user_id:
        logger.error(
            "[Image Studio] ❌ Missing user_id for %s operation - blocking request",
            operation,
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authenticated user required for image operations.",
        )
    return user_id
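`_require_user_id` tries four claim keys in a fixed order and rejects the request when none yields a truthy value. The sketch below reproduces that lookup without FastAPI so it can be run on its own. `resolve_user_id`, `CLAIM_KEYS`, and `MissingUserError` are hypothetical names, and the exception only stands in for the HTTP 401 raised by the real dependency.

```python
from typing import Any, Dict

# Claim keys checked in order, matching the chain of .get() calls in _require_user_id.
CLAIM_KEYS = ("sub", "user_id", "id", "clerk_user_id")


class MissingUserError(Exception):
    """Hypothetical stand-in for the HTTP 401 raised by the real dependency."""


def resolve_user_id(current_user: Dict[str, Any]) -> str:
    """Return the first truthy claim value, or raise if none is present."""
    for key in CLAIM_KEYS:
        value = current_user.get(key)
        if value:  # None and empty strings fall through to the next key
            return value
    raise MissingUserError("Authenticated user required for image operations.")


print(resolve_user_id({"sub": "", "clerk_user_id": "user_123"}))
```

An empty `sub` claim does not block the request here; the lookup simply continues to the next key, which matches the `or` chain in the original function.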
122
backend/routers/image_studio/edit.py
Normal file
@@ -0,0 +1,122 @@
"""Edit Studio endpoints."""

from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException

from .models import (
    EditImageRequest, EditImageResponse, EditOperationsResponse,
    EditModelsResponse, EditModelRecommendationRequest, EditModelRecommendationResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager, EditStudioRequest
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/edit/process", response_model=EditImageResponse, summary="Process Edit Studio request")
async def process_edit_image(
    request: EditImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Perform Edit Studio operations such as remove background, inpaint, or recolor."""
    try:
        user_id = _require_user_id(current_user, "image editing")
        logger.info(f"[Edit Image] Request from user {user_id}: operation={request.operation}")

        edit_request = EditStudioRequest(
            image_base64=request.image_base64,
            operation=request.operation,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            mask_base64=request.mask_base64,
            search_prompt=request.search_prompt,
            select_prompt=request.select_prompt,
            background_image_base64=request.background_image_base64,
            lighting_image_base64=request.lighting_image_base64,
            expand_left=request.expand_left,
            expand_right=request.expand_right,
            expand_up=request.expand_up,
            expand_down=request.expand_down,
            provider=request.provider,
            model=request.model,
            style_preset=request.style_preset,
            guidance_scale=request.guidance_scale,
            steps=request.steps,
            seed=request.seed,
            output_format=request.output_format,
            options=request.options or {},
        )

        result = await studio_manager.edit_image(edit_request, user_id=user_id)
        return EditImageResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Edit Image] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Image editing failed: {e}")


@router.get("/edit/operations", response_model=EditOperationsResponse, summary="List Edit Studio operations")
async def get_edit_operations(
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Return metadata for supported Edit Studio operations."""
    try:
        operations = studio_manager.get_edit_operations()
        return EditOperationsResponse(operations=operations)
    except Exception as e:
        logger.error(f"[Edit Operations] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load edit operations")


@router.get("/edit/models", response_model=EditModelsResponse, summary="List available editing models")
async def get_edit_models(
    operation: Optional[str] = None,
    tier: Optional[str] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get available WaveSpeed editing models with metadata.

    Query Parameters:
    - operation: Filter by operation type (e.g., "general_edit")
    - tier: Filter by tier ("budget", "mid", "premium")
    """
    try:
        result = studio_manager.get_edit_models(operation=operation, tier=tier)
        return EditModelsResponse(**result)
    except Exception as e:
        logger.error(f"[Edit Models] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load editing models")


@router.post("/edit/recommend", response_model=EditModelRecommendationResponse, summary="Get model recommendation")
async def recommend_edit_model(
    request: EditModelRecommendationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get recommended editing model based on operation, image resolution, and user preferences.

    Auto-detects best model when user doesn't specify one.
    """
    try:
        user_tier = request.user_tier
        if not user_tier and current_user:
            user_tier = current_user.get("tier") or current_user.get("subscription_tier")

        result = studio_manager.recommend_edit_model(
            operation=request.operation,
            image_resolution=request.image_resolution,
            user_tier=user_tier,
            preferences=request.preferences,
        )
        return EditModelRecommendationResponse(**result)
    except Exception as e:
        logger.error(f"[Edit Recommend] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
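For orientation, a client-side body for `POST /edit/process` could be assembled roughly as follows. This is a sketch only: `build_edit_payload` is a hypothetical helper, the field names are taken from the attributes the handler reads off `request`, and the router's mount prefix and auth headers are not shown in this diff.

```python
import base64
import json
from typing import Any, Dict, Optional


def build_edit_payload(
    image_bytes: bytes,
    operation: str,
    prompt: Optional[str] = None,
    mask_bytes: Optional[bytes] = None,
) -> Dict[str, Any]:
    """Build a JSON-serialisable body using field names read by process_edit_image."""
    payload: Dict[str, Any] = {
        "image_base64": base64.b64encode(image_bytes).decode("ascii"),
        "operation": operation,
    }
    # Optional fields are omitted entirely when not supplied.
    if prompt is not None:
        payload["prompt"] = prompt
    if mask_bytes is not None:
        payload["mask_base64"] = base64.b64encode(mask_bytes).decode("ascii")
    return payload


# Placeholder bytes; a real call would read an actual image file.
body = json.dumps(build_edit_payload(b"abc", "inpaint", prompt="remove the lamp"))
print(body)
```

Leaving optional keys out, rather than sending explicit nulls, keeps the request small; whether the server treats the two cases identically depends on the request model's defaults.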
89
backend/routers/image_studio/face_swap.py
Normal file
@@ -0,0 +1,89 @@
"""Face Swap Studio endpoints."""

from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException

from .models import (
    FaceSwapRequest, FaceSwapResponse, FaceSwapModelsResponse,
    FaceSwapModelRecommendationRequest, FaceSwapModelRecommendationResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager
from services.image_studio.face_swap_service import FaceSwapStudioRequest
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/face-swap/process", response_model=FaceSwapResponse, summary="Process Face Swap")
async def process_face_swap(
    request: FaceSwapRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Process face swap request with auto-detection and model selection."""
    try:
        user_id = _require_user_id(current_user, "face swap")
        face_swap_request = FaceSwapStudioRequest(
            base_image_base64=request.base_image_base64,
            face_image_base64=request.face_image_base64,
            model=request.model,
            target_face_index=request.target_face_index,
            target_gender=request.target_gender,
            options=request.options,
        )
        result = await studio_manager.face_swap(face_swap_request, user_id=user_id)
        return FaceSwapResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Face Swap] ❌ Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Face swap failed: {e}")


@router.get("/face-swap/models", response_model=FaceSwapModelsResponse, summary="List available face swap models")
async def get_face_swap_models(
    tier: Optional[str] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get available WaveSpeed face swap models with metadata.

    Query Parameters:
    - tier: Filter by tier ("budget", "mid", "premium")
    """
    try:
        result = studio_manager.get_face_swap_models(tier=tier)
        return FaceSwapModelsResponse(**result)
    except Exception as e:
        logger.error(f"[Face Swap Models] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load face swap models")


@router.post("/face-swap/recommend", response_model=FaceSwapModelRecommendationResponse, summary="Get face swap model recommendation")
async def recommend_face_swap_model(
    request: FaceSwapModelRecommendationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get recommended face swap model based on image resolutions and user preferences.

    Auto-detects best model when user doesn't specify one.
    """
    try:
        user_tier = request.user_tier
        if not user_tier and current_user:
            user_tier = current_user.get("tier") or current_user.get("subscription_tier")

        result = studio_manager.recommend_face_swap_model(
            base_image_resolution=request.base_image_resolution,
            face_image_resolution=request.face_image_resolution,
            user_tier=user_tier,
            preferences=request.preferences,
        )
        return FaceSwapModelRecommendationResponse(**result)
    except Exception as e:
        logger.error(f"[Face Swap Recommend] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
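Both recommendation handlers (`/edit/recommend` and `/face-swap/recommend`) resolve the user tier the same way: an explicit `user_tier` in the request wins, otherwise the value falls back to the `tier` or `subscription_tier` claim. A minimal sketch of that precedence, with `resolve_user_tier` as a hypothetical helper that does not exist in this diff:

```python
from typing import Any, Dict, Optional


def resolve_user_tier(
    request_tier: Optional[str],
    current_user: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Prefer the tier sent in the request, then fall back to the auth claims."""
    user_tier = request_tier
    if not user_tier and current_user:
        # "tier" takes priority over "subscription_tier" when both are present.
        user_tier = current_user.get("tier") or current_user.get("subscription_tier")
    return user_tier


print(resolve_user_tier("premium", {"tier": "budget"}))
print(resolve_user_tier(None, {"subscription_tier": "mid"}))
```

Factoring the lookup into one function like this would be one way to keep the two routers from drifting apart, though whether that refactor is wanted is a project decision.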
21
backend/routers/image_studio/health.py
Normal file
@@ -0,0 +1,21 @@
"""Health check endpoint."""

from fastapi import APIRouter

router = APIRouter(tags=["image-studio"])


@router.get("/health", summary="Health Check")
async def health_check():
    """Health check endpoint for Image Studio."""
    return {
        "status": "healthy",
        "service": "image_studio",
        "version": "1.0.0",
        "modules": {
            "create_studio": "available",
            "templates": "available",
            "providers": "available",
            "compression": "available",
        }
    }
372
backend/routers/image_studio/models.py
Normal file
@@ -0,0 +1,372 @@
"""Pydantic request/response models for Image Studio API."""

from typing import Optional, List, Dict, Any, Literal
from pydantic import BaseModel, Field


# ==================== Create Studio ====================

class CreateImageRequest(BaseModel):
    prompt: str = Field(..., description="Image generation prompt")
    template_id: Optional[str] = Field(None, description="Template ID to use")
    provider: Optional[str] = Field("auto", description="Provider: auto, stability, wavespeed, huggingface, gemini")
    model: Optional[str] = Field(None, description="Specific model to use")
    width: Optional[int] = Field(None, description="Image width in pixels")
    height: Optional[int] = Field(None, description="Image height in pixels")
    aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (e.g., '1:1', '16:9')")
    style_preset: Optional[str] = Field(None, description="Style preset")
    quality: str = Field("standard", description="Quality: draft, standard, premium")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt")
    guidance_scale: Optional[float] = Field(None, description="Guidance scale")
    steps: Optional[int] = Field(None, description="Number of inference steps")
    seed: Optional[int] = Field(None, description="Random seed")
    num_variations: int = Field(1, ge=1, le=10, description="Number of variations (1-10)")
    enhance_prompt: bool = Field(True, description="Enhance prompt with AI")
    use_persona: bool = Field(False, description="Use persona for brand consistency")
    persona_id: Optional[str] = Field(None, description="Persona ID")


class CostEstimationRequest(BaseModel):
    provider: str = Field(..., description="Provider name")
    model: Optional[str] = Field(None, description="Model name")
    operation: str = Field("generate", description="Operation type")
    num_images: int = Field(1, ge=1, description="Number of images")
    width: Optional[int] = Field(None, description="Image width")
    height: Optional[int] = Field(None, description="Image height")


# ==================== Edit Studio ====================

class EditImageRequest(BaseModel):
    image_base64: str = Field(..., description="Primary image payload (base64 or data URL)")
    operation: Literal[
        "remove_background",
        "inpaint",
        "outpaint",
        "search_replace",
        "search_recolor",
        "general_edit",
    ] = Field(..., description="Edit operation to perform")
    prompt: Optional[str] = Field(None, description="Primary prompt/instruction")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt for providers that support it")
    mask_base64: Optional[str] = Field(None, description="Optional mask image in base64")
    search_prompt: Optional[str] = Field(None, description="Search prompt for replace operations")
    select_prompt: Optional[str] = Field(None, description="Select prompt for recolor operations")
    background_image_base64: Optional[str] = Field(None, description="Reference background image")
    lighting_image_base64: Optional[str] = Field(None, description="Reference lighting image")
    expand_left: Optional[int] = Field(0, description="Outpaint expansion in pixels (left)")
    expand_right: Optional[int] = Field(0, description="Outpaint expansion in pixels (right)")
    expand_up: Optional[int] = Field(0, description="Outpaint expansion in pixels (up)")
    expand_down: Optional[int] = Field(0, description="Outpaint expansion in pixels (down)")
    provider: Optional[str] = Field(None, description="Explicit provider override")
    model: Optional[str] = Field(None, description="Explicit model override")
    style_preset: Optional[str] = Field(None, description="Style preset for Stability helpers")
    guidance_scale: Optional[float] = Field(None, description="Guidance scale for general edits")
    steps: Optional[int] = Field(None, description="Inference steps")
    seed: Optional[int] = Field(None, description="Random seed for reproducibility")
    output_format: str = Field("png", description="Output format for edited image")
    options: Optional[Dict[str, Any]] = Field(None, description="Advanced provider-specific options (e.g., grow_mask)")


class EditImageResponse(BaseModel):
    success: bool
    operation: str
    provider: str
    image_base64: str
    width: int
    height: int
    metadata: Dict[str, Any]


class EditOperationsResponse(BaseModel):
    operations: Dict[str, Dict[str, Any]]


class EditModelsResponse(BaseModel):
    models: List[Dict[str, Any]]
    total: int


class EditModelRecommendationRequest(BaseModel):
    operation: str
    image_resolution: Optional[Dict[str, int]] = None
    user_tier: Optional[str] = None
    preferences: Optional[Dict[str, Any]] = None


class EditModelRecommendationResponse(BaseModel):
    recommended_model: str
    reason: str
    alternatives: List[Dict[str, Any]]


# ==================== Face Swap Studio ====================

class FaceSwapRequest(BaseModel):
    base_image_base64: str
    face_image_base64: str
    model: Optional[str] = None
    target_face_index: Optional[int] = None
    target_gender: Optional[str] = None
    options: Optional[Dict[str, Any]] = None


class FaceSwapResponse(BaseModel):
    success: bool
    image_base64: str
    width: int
    height: int
    provider: str
    model: str
    metadata: Dict[str, Any]


class FaceSwapModelsResponse(BaseModel):
    models: List[Dict[str, Any]]
    total: int


class FaceSwapModelRecommendationRequest(BaseModel):
    base_image_resolution: Optional[Dict[str, int]] = None
    face_image_resolution: Optional[Dict[str, int]] = None
    user_tier: Optional[str] = None
    preferences: Optional[Dict[str, Any]] = None


class FaceSwapModelRecommendationResponse(BaseModel):
    recommended_model: str
    reason: str
    alternatives: List[Dict[str, Any]]


# ==================== Upscale Studio ====================

class UpscaleImageRequest(BaseModel):
    image_base64: str
    mode: Literal["fast", "conservative", "creative", "auto"] = "auto"
    target_width: Optional[int] = Field(None, description="Target width in pixels")
    target_height: Optional[int] = Field(None, description="Target height in pixels")
    preset: Optional[str] = Field(None, description="Named preset (web, print, social)")
    prompt: Optional[str] = Field(None, description="Prompt for conservative/creative modes")


class UpscaleImageResponse(BaseModel):
    success: bool
    mode: str
    image_base64: str
    width: int
    height: int
    metadata: Dict[str, Any]


# ==================== Control Studio ====================

class ControlImageRequest(BaseModel):
    control_image_base64: str = Field(..., description="Control image (sketch/structure/style) in base64")
    operation: Literal["sketch", "structure", "style", "style_transfer"] = Field(..., description="Control operation")
    prompt: str = Field(..., description="Text prompt for generation")
    style_image_base64: Optional[str] = Field(None, description="Style reference image (for style_transfer only)")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt")
    control_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Control strength (sketch/structure)")
    fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style fidelity (style operation)")
    style_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style strength (style_transfer)")
    composition_fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Composition fidelity (style_transfer)")
    change_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Change strength (style_transfer)")
    aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (style operation)")
    style_preset: Optional[str] = Field(None, description="Style preset")
    seed: Optional[int] = Field(None, description="Random seed")
    output_format: str = Field("png", description="Output format")


class ControlImageResponse(BaseModel):
    success: bool
    operation: str
    provider: str
    image_base64: str
    width: int
    height: int
    metadata: Dict[str, Any]


class ControlOperationsResponse(BaseModel):
    operations: Dict[str, Dict[str, Any]]


# ==================== Social Optimizer ====================

class SocialOptimizeRequest(BaseModel):
    image_base64: str = Field(..., description="Source image in base64 or data URL")
    platforms: List[str] = Field(..., description="List of platforms to optimize for")
    format_names: Optional[Dict[str, str]] = Field(None, description="Specific format per platform")
    show_safe_zones: bool = Field(False, description="Include safe zone overlay in output")
    crop_mode: str = Field("smart", description="Crop mode: smart, center, or fit")
    focal_point: Optional[Dict[str, float]] = Field(None, description="Focal point for smart crop (x, y as 0-1)")
    output_format: str = Field("png", description="Output format (png or jpg)")


class SocialOptimizeResponse(BaseModel):
    success: bool
    results: List[Dict[str, Any]]
    total_optimized: int


class PlatformFormatsResponse(BaseModel):
    formats: List[Dict[str, Any]]


# ==================== Transform Studio ====================

class TransformImageToVideoRequestModel(BaseModel):
    image_base64: str = Field(..., description="Image in base64 or data URL format")
    prompt: str = Field(..., description="Text prompt describing the video")
    audio_base64: Optional[str] = Field(None, description="Optional audio file (wav/mp3, 3-30s, ≤15MB)")
    resolution: Literal["480p", "720p", "1080p"] = Field("720p", description="Output resolution")
    duration: Literal[5, 10] = Field(5, description="Video duration in seconds")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt")
    seed: Optional[int] = Field(None, description="Random seed for reproducibility")
    enable_prompt_expansion: bool = Field(True, description="Enable prompt optimizer")


class TalkingAvatarRequestModel(BaseModel):
    image_base64: str = Field(..., description="Person image in base64 or data URL")
    audio_base64: str = Field(..., description="Audio file in base64 or data URL (wav/mp3, max 10 minutes)")
    resolution: Literal["480p", "720p"] = Field("720p", description="Output resolution")
    prompt: Optional[str] = Field(None, description="Optional prompt for expression/style")
    mask_image_base64: Optional[str] = Field(None, description="Optional mask for animatable regions")
    seed: Optional[int] = Field(None, description="Random seed")


class TransformVideoResponse(BaseModel):
    success: bool
    video_url: Optional[str] = None
    video_base64: Optional[str] = None
    duration: float
    resolution: str
    width: int
    height: int
    file_size: int
    cost: float
    provider: str
    model: str
    metadata: Dict[str, Any]


class TransformCostEstimateRequest(BaseModel):
    operation: Literal["image-to-video", "talking-avatar"] = Field(..., description="Operation type")
    resolution: str = Field(..., description="Output resolution")
    duration: Optional[int] = Field(None, description="Video duration in seconds (for image-to-video)")


class TransformCostEstimateResponse(BaseModel):
    estimated_cost: float
    breakdown: Dict[str, Any]
    currency: str
    provider: str
    model: str


# ==================== Compression ====================

class CompressImageRequest(BaseModel):
    image_base64: str = Field(..., description="Image in base64 or data URL format")
    quality: int = Field(85, ge=1, le=100, description="Compression quality (1-100)")
    format: str = Field("jpeg", description="Output format: jpeg, png, webp")
    target_size_kb: Optional[int] = Field(None, ge=10, description="Target file size in KB")
    strip_metadata: bool = Field(True, description="Remove EXIF metadata")
    progressive: bool = Field(True, description="Progressive JPEG encoding")
    optimize: bool = Field(True, description="Optimize encoding")


class CompressImageResponse(BaseModel):
    success: bool
    image_base64: str
    original_size_kb: float
    compressed_size_kb: float
    compression_ratio: float
    format: str
    width: int
    height: int
    quality_used: int
    metadata_stripped: bool


class CompressBatchRequest(BaseModel):
    images: List[CompressImageRequest] = Field(..., description="List of images to compress")


class CompressBatchResponse(BaseModel):
    success: bool
    results: List[CompressImageResponse]
    total_images: int
    successful: int
    failed: int


class CompressionEstimateRequest(BaseModel):
    image_base64: str = Field(..., description="Image in base64 or data URL format")
    format: str = Field("jpeg", description="Output format")
    quality: int = Field(85, ge=1, le=100, description="Quality level")


class CompressionEstimateResponse(BaseModel):
    original_size_kb: float
    estimated_size_kb: float
    estimated_reduction_percent: float
    width: int
    height: int
    format: str


class CompressionFormatsResponse(BaseModel):
    formats: List[Dict[str, Any]]


class CompressionPresetsResponse(BaseModel):
    presets: List[Dict[str, Any]]


# ==================== Format Converter ====================

class ConvertFormatRequest(BaseModel):
    image_base64: str = Field(..., description="Image in base64 or data URL format")
    target_format: str = Field(..., description="Target format: png, jpeg, jpg, webp, gif, bmp, tiff")
    preserve_transparency: bool = Field(True, description="Preserve transparency when possible")
    quality: Optional[int] = Field(None, ge=1, le=100, description="Quality for lossy formats (1-100)")
    color_space: Optional[str] = Field(None, description="Color space: sRGB, Adobe RGB")
    strip_metadata: bool = Field(False, description="Remove EXIF metadata")
    optimize: bool = Field(True, description="Optimize encoding")
    progressive: bool = Field(True, description="Progressive JPEG encoding")


class ConvertFormatResponse(BaseModel):
    success: bool
    image_base64: str
    original_format: str
    target_format: str
    original_size_kb: float
    converted_size_kb: float
    width: int
    height: int
    transparency_preserved: bool
    metadata_preserved: bool
    color_space: Optional[str] = None


class ConvertFormatBatchRequest(BaseModel):
    images: List[ConvertFormatRequest] = Field(..., description="List of images to convert")


class ConvertFormatBatchResponse(BaseModel):
    success: bool
    results: List[ConvertFormatResponse]
    total_images: int
    successful: int
    failed: int


class SupportedFormatsResponse(BaseModel):
    formats: List[Dict[str, Any]]


class FormatRecommendationsResponse(BaseModel):
    recommendations: List[Dict[str, Any]]
100
backend/routers/image_studio/save.py
Normal file
@@ -0,0 +1,100 @@
|
||||
"""Save generated images to the unified asset library."""
|
||||
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .deps import _require_user_id
from middleware.auth_middleware import get_current_user
from services.database import get_db
from utils.logger_utils import get_service_logger
from utils.storage_paths import get_repo_root, sanitize_user_id

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


class SaveToLibraryRequest(BaseModel):
    image_base64: str = Field(..., description="Base64-encoded image (or data URL)")
    prompt: Optional[str] = None
    provider: Optional[str] = None
    model: Optional[str] = None
    cost: Optional[float] = None
    operation: str = Field("image-generation", description="Operation type for labelling")
    output_format: str = Field("png", description="Output image format")


@router.post("/save-to-library")
async def save_to_library(
    req: SaveToLibraryRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    db: Session = Depends(get_db),
):
    """Save a generated image to the asset library.

    Decodes base64 image data, saves to workspace disk storage,
    and creates a record in the ContentAsset database table.
    """
    user_id = _require_user_id(current_user, "save-to-library")

    # Decode base64 payload
    try:
        b64data = req.image_base64
        if "base64," in b64data:
            b64data = b64data.split("base64,")[1]
        image_bytes = base64.b64decode(b64data)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 image data")

    # Generate file path under workspace
    safe_user = sanitize_user_id(user_id)
    repo_root = get_repo_root()
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    filename = f"generated_{timestamp}.{req.output_format or 'png'}"

    assets_dir = repo_root / "workspace" / f"workspace_{safe_user}" / "assets" / "images"
    assets_dir.mkdir(parents=True, exist_ok=True)

    file_path = assets_dir / filename
    file_path.write_bytes(image_bytes)

    # Build serving URL (assets_serving.py serves /{user_id}/avatars/{filename})
    file_url = f"/api/assets/{safe_user}/avatars/{filename}"

    # Save to unified asset library via existing utility
    from utils.asset_tracker import save_asset_to_library

    asset_id = save_asset_to_library(
        db=db,
        user_id=user_id,
        asset_type="image",
        source_module="image_studio",
        filename=filename,
        file_url=file_url,
        file_path=str(file_path),
        file_size=len(image_bytes),
        mime_type=f"image/{req.output_format or 'png'}",
        title=f"Generated Image - {timestamp}",
        prompt=req.prompt,
        provider=req.provider,
        model=req.model,
        cost=req.cost,
    )

    if not asset_id:
        raise HTTPException(status_code=500, detail="Failed to save to asset library")

    logger.info(f"[Save to Library] ✅ Image saved: asset_id={asset_id}, user={user_id}")

    return {
        "success": True,
        "asset_id": asset_id,
        "file_url": file_url,
        "filename": filename,
        "file_size": len(image_bytes),
    }
|
||||
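Review note: the handler above interpolates `req.output_format` directly into the filename and MIME type, and `base64.b64decode` without `validate=True` silently skips characters outside the base64 alphabet. A minimal sketch of stricter handling follows. `ALLOWED_FORMATS`, `decode_image_payload`, and `safe_extension` are hypothetical names that do not appear in this PR, and the allow-list contents are an assumption.

```python
import base64
import binascii

# Hypothetical allow-list; the PR itself accepts any output_format string.
ALLOWED_FORMATS = {"png", "jpeg", "jpg", "webp"}


def decode_image_payload(data: str) -> bytes:
    """Decode a raw base64 string or a data URL into bytes."""
    # Drop an optional "data:image/png;base64," style prefix.
    if "base64," in data:
        data = data.split("base64,", 1)[1]
    try:
        # validate=True rejects characters outside the base64 alphabet
        # instead of silently discarding them.
        return base64.b64decode(data, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise ValueError("Invalid base64 image data") from exc


def safe_extension(output_format: str) -> str:
    """Normalize the requested format and reject anything not allow-listed."""
    fmt = (output_format or "png").strip().lower()
    if fmt not in ALLOWED_FORMATS:
        raise ValueError(f"Unsupported output format: {output_format!r}")
    return fmt
```

In the endpoint, both `ValueError`s would map naturally onto the existing 400 response, so a value such as `../x` can never reach the filename.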
88
backend/routers/image_studio/social.py
Normal file
@@ -0,0 +1,88 @@
"""Social Optimizer endpoints."""

from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException

from .models import SocialOptimizeRequest, SocialOptimizeResponse, PlatformFormatsResponse
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager, SocialOptimizerRequest
from services.image_studio.templates import Platform
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/social/optimize", response_model=SocialOptimizeResponse, summary="Optimize image for social platforms")
async def optimize_for_social(
    request: SocialOptimizeRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Optimize an image for multiple social media platforms with smart cropping and safe zones."""
    try:
        user_id = _require_user_id(current_user, "social optimization")
        logger.info(f"[Social Optimizer] Request from user {user_id}: platforms={request.platforms}")

        platforms = []
        for platform_str in request.platforms:
            try:
                platforms.append(Platform(platform_str.lower()))
            except ValueError:
                logger.warning(f"[Social Optimizer] Invalid platform: {platform_str}")
                continue

        if not platforms:
            raise HTTPException(status_code=400, detail="No valid platforms provided")

        format_names = None
        if request.format_names:
            format_names = {}
            for platform_str, format_name in request.format_names.items():
                try:
                    platform = Platform(platform_str.lower())
                    format_names[platform] = format_name
                except ValueError:
                    logger.warning(f"[Social Optimizer] Invalid platform in format_names: {platform_str}")

        social_request = SocialOptimizerRequest(
            image_base64=request.image_base64,
            platforms=platforms,
            format_names=format_names,
            show_safe_zones=request.show_safe_zones,
            crop_mode=request.crop_mode,
            focal_point=request.focal_point,
            output_format=request.output_format,
            options={},
        )

        result = await studio_manager.optimize_for_social(social_request, user_id=user_id)
        return SocialOptimizeResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Social Optimizer] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Social optimization failed: {e}")


@router.get("/social/platforms/{platform}/formats", response_model=PlatformFormatsResponse, summary="Get platform formats")
async def get_platform_formats(
    platform: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get available formats for a social media platform."""
    try:
        try:
            platform_enum = Platform(platform.lower())
        except ValueError:
            raise HTTPException(status_code=400, detail=f"Invalid platform: {platform}")

        formats = studio_manager.get_social_platform_formats(platform_enum)
        return PlatformFormatsResponse(formats=formats)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Platform Formats] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to load platform formats: {e}")
|
||||
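Review note: `optimize_for_social` converts platform strings to enum members twice (once for `platforms`, once for `format_names`) and only logs the names it rejects. The conversion can be sketched as one helper that also returns the rejected names. The `Platform` class below is a stand-in whose members are assumed; the real one lives in `services.image_studio.templates` and is not shown in this diff. `parse_platforms` is a hypothetical name.

```python
from enum import Enum
from typing import Iterable, List, Tuple


class Platform(str, Enum):
    """Stand-in enum; the member list here is an assumption."""
    INSTAGRAM = "instagram"
    LINKEDIN = "linkedin"


def parse_platforms(names: Iterable[str]) -> Tuple[List[Platform], List[str]]:
    """Split raw platform names into valid enum members and rejected strings."""
    valid: List[Platform] = []
    invalid: List[str] = []
    for name in names:
        try:
            # Enum lookup by value raises ValueError for unknown names.
            valid.append(Platform(name.lower()))
        except ValueError:
            invalid.append(name)
    return valid, invalid
```

Returning the rejected names would let the endpoint include them in the 400 detail when no valid platform remains, rather than leaving them only in the log.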
158
backend/routers/image_studio/transform.py
Normal file
@@ -0,0 +1,158 @@
"""Transform Studio endpoints — image-to-video, talking avatar, and video serving."""

from pathlib import Path
from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import FileResponse

from .models import (
    TransformImageToVideoRequestModel, TalkingAvatarRequestModel,
    TransformVideoResponse, TransformCostEstimateRequest, TransformCostEstimateResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager, TransformImageToVideoRequest, TalkingAvatarRequest
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/transform/image-to-video", response_model=TransformVideoResponse, summary="Transform Image to Video")
async def transform_image_to_video(
    request: TransformImageToVideoRequestModel,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Transform an image into a video using WAN 2.5."""
    try:
        user_id = _require_user_id(current_user, "image-to-video transformation")
        logger.info(f"[Transform Studio] Image-to-video request from user {user_id}: resolution={request.resolution}, duration={request.duration}s")

        transform_request = TransformImageToVideoRequest(
            image_base64=request.image_base64,
            prompt=request.prompt,
            audio_base64=request.audio_base64,
            resolution=request.resolution,
            duration=request.duration,
            negative_prompt=request.negative_prompt,
            seed=request.seed,
            enable_prompt_expansion=request.enable_prompt_expansion,
        )

        result = await studio_manager.transform_image_to_video(transform_request, user_id=user_id)

        logger.info(f"[Transform Studio] ✅ Image-to-video completed: cost=${result['cost']:.2f}")
        return TransformVideoResponse(**result)

    except ValueError as e:
        logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")


@router.post("/transform/talking-avatar", response_model=TransformVideoResponse, summary="Create Talking Avatar")
async def create_talking_avatar(
    request: TalkingAvatarRequestModel,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Create a talking avatar video using InfiniteTalk."""
    try:
        user_id = _require_user_id(current_user, "talking avatar generation")
        logger.info(f"[Transform Studio] Talking avatar request from user {user_id}: resolution={request.resolution}")

        avatar_request = TalkingAvatarRequest(
            image_base64=request.image_base64,
            audio_base64=request.audio_base64,
            resolution=request.resolution,
            prompt=request.prompt,
            mask_image_base64=request.mask_image_base64,
            seed=request.seed,
        )

        result = await studio_manager.create_talking_avatar(avatar_request, user_id=user_id)

        logger.info(f"[Transform Studio] ✅ Talking avatar completed: cost=${result['cost']:.2f}")
        return TransformVideoResponse(**result)

    except ValueError as e:
        logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Talking avatar generation failed: {str(e)}")


@router.post("/transform/estimate-cost", response_model=TransformCostEstimateResponse, summary="Estimate Transform Cost")
async def estimate_transform_cost(
    request: TransformCostEstimateRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Estimate cost for transform operations."""
    try:
        estimate = studio_manager.estimate_transform_cost(
            operation=request.operation,
            resolution=request.resolution,
            duration=request.duration,
        )
        return TransformCostEstimateResponse(**estimate)

    except ValueError as e:
        logger.error(f"[Transform Studio] ❌ Cost estimation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"[Transform Studio] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/videos/{user_id}/{video_filename:path}", summary="Serve Transform Studio Video")
async def serve_transform_video(
    user_id: str,
    video_filename: str,
    current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
):
    """Serve a generated Transform Studio video file."""
    try:
        authenticated_user_id = _require_user_id(current_user, "video access")
        if authenticated_user_id != user_id:
            raise HTTPException(
                status_code=403,
                detail="Access denied: You can only access your own videos"
            )

        base_dir = Path(__file__).parent.parent.parent
        transform_videos_dir = base_dir / "transform_videos"
        video_path = transform_videos_dir / user_id / video_filename

        try:
            resolved_video_path = video_path.resolve()
            resolved_base = transform_videos_dir.resolve()
            resolved_video_path.relative_to(resolved_base)
        except ValueError:
            raise HTTPException(
                status_code=403,
                detail="Invalid video path: path traversal detected"
            )

        if not video_path.exists():
            raise HTTPException(status_code=404, detail="Video not found")

        return FileResponse(
            path=str(video_path),
            media_type="video/mp4",
            filename=video_filename
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Transform Studio] Failed to serve video: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
|
||||
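Review note: as far as this diff shows, `serve_transform_video` checks that the resolved path stays under `transform_videos_dir`, not under the requesting user's own subdirectory. Because `video_filename` is a `:path` parameter, a value that climbs one level and descends into another user's folder would still appear to pass that check. A minimal sketch of a per-user containment check follows; `resolve_user_file` is a hypothetical helper, not part of the PR.

```python
from pathlib import Path


def resolve_user_file(base_dir: Path, user_id: str, relative_name: str) -> Path:
    """Return the resolved file path, or raise ValueError if it leaves the user's directory."""
    user_root = (base_dir / user_id).resolve()
    candidate = (user_root / relative_name).resolve()
    # Path.relative_to raises ValueError when candidate is not under user_root,
    # which covers "../" segments and absolute paths alike.
    candidate.relative_to(user_root)
    return candidate
```

The endpoint could then serve the returned (resolved) path rather than the unresolved `video_path`, so the file that was checked is the file that is sent.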
40
backend/routers/image_studio/upscale.py
Normal file
@@ -0,0 +1,40 @@
"""Upscale Studio endpoint."""

from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException

from .models import UpscaleImageRequest, UpscaleImageResponse
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager
from services.image_studio.upscale_service import UpscaleStudioRequest
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/upscale", response_model=UpscaleImageResponse, summary="Upscale Image")
async def upscale_image(
    request: UpscaleImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Upscale an image using Stability AI pipelines."""
    try:
        user_id = _require_user_id(current_user, "image upscaling")
        upscale_request = UpscaleStudioRequest(
            image_base64=request.image_base64,
            mode=request.mode,
            target_width=request.target_width,
            target_height=request.target_height,
            preset=request.preset,
            prompt=request.prompt,
        )
        result = await studio_manager.upscale_image(upscale_request, user_id=user_id)
        return UpscaleImageResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Upscale Image] ❌ Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Image upscaling failed: {e}")
|
||||
@@ -2,6 +2,10 @@
|
||||
"""
|
||||
Initialize Alpha Tester Subscription Tiers
|
||||
Creates subscription plans for alpha testing with appropriate limits.
|
||||
|
||||
NOTE: Pricing is seeded via PricingService.initialize_default_pricing()
|
||||
which runs in services/database.py:init_user_database()
|
||||
NOT via this script.
|
||||
"""
|
||||
|
||||
import sys
|
||||
@@ -10,7 +14,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from models.subscription_models import (
|
||||
SubscriptionPlan, SubscriptionTier, APIProviderPricing, APIProvider
|
||||
SubscriptionPlan, SubscriptionTier
|
||||
)
|
||||
from services.database import get_db_session
|
||||
from datetime import datetime
|
||||
@@ -24,7 +28,7 @@ def create_alpha_subscription_tiers():
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error("❌ Could not get database session")
|
||||
logger.error("Could not get database session")
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -38,12 +42,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Free tier for alpha testing - Limited usage",
|
||||
"features": ["blog_writer", "basic_seo", "content_planning"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 50, # 50 calls per day
|
||||
"gemini_tokens_limit": 10000, # 10k tokens per day
|
||||
"tavily_calls_limit": 20, # 20 searches per day
|
||||
"serper_calls_limit": 10, # 10 SEO searches per day
|
||||
"stability_calls_limit": 5, # 5 images per day
|
||||
"monthly_cost_limit": 5.0 # $5 monthly limit
|
||||
"gemini_calls_limit": 50,
|
||||
"gemini_tokens_limit": 10000,
|
||||
"tavily_calls_limit": 20,
|
||||
"serper_calls_limit": 10,
|
||||
"stability_calls_limit": 5,
|
||||
"monthly_cost_limit": 5.0
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -54,12 +58,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Basic alpha tier - Moderate usage for testing",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 200, # 200 calls per day
|
||||
"gemini_tokens_limit": 50000, # 50k tokens per day
|
||||
"tavily_calls_limit": 100, # 100 searches per day
|
||||
"serper_calls_limit": 50, # 50 SEO searches per day
|
||||
"stability_calls_limit": 25, # 25 images per day
|
||||
"monthly_cost_limit": 25.0 # $25 monthly limit
|
||||
"gemini_calls_limit": 200,
|
||||
"gemini_tokens_limit": 50000,
|
||||
"tavily_calls_limit": 100,
|
||||
"serper_calls_limit": 50,
|
||||
"stability_calls_limit": 25,
|
||||
"monthly_cost_limit": 25.0
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -70,12 +74,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Pro alpha tier - High usage for power users",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot", "advanced_analytics"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 500, # 500 calls per day
|
||||
"gemini_tokens_limit": 150000, # 150k tokens per day
|
||||
"tavily_calls_limit": 300, # 300 searches per day
|
||||
"serper_calls_limit": 150, # 150 SEO searches per day
|
||||
"stability_calls_limit": 100, # 100 images per day
|
||||
"monthly_cost_limit": 100.0 # $100 monthly limit
|
||||
"gemini_calls_limit": 500,
|
||||
"gemini_tokens_limit": 150000,
|
||||
"tavily_calls_limit": 300,
|
||||
"serper_calls_limit": 150,
|
||||
"stability_calls_limit": 100,
|
||||
"monthly_cost_limit": 100.0
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -86,34 +90,31 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Enterprise alpha tier - Unlimited usage for enterprise testing",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot", "advanced_analytics", "custom_integrations"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 0, # Unlimited calls
|
||||
"gemini_tokens_limit": 0, # Unlimited tokens
|
||||
"tavily_calls_limit": 0, # Unlimited searches
|
||||
"serper_calls_limit": 0, # Unlimited SEO searches
|
||||
"stability_calls_limit": 0, # Unlimited images
|
||||
"monthly_cost_limit": 500.0 # $500 monthly limit
|
||||
"gemini_calls_limit": 0,
|
||||
"gemini_tokens_limit": 0,
|
||||
"tavily_calls_limit": 0,
|
||||
"serper_calls_limit": 0,
|
||||
"stability_calls_limit": 0,
|
||||
"monthly_cost_limit": 500.0
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Create subscription plans
|
||||
for tier_data in alpha_tiers:
|
||||
# Check if plan already exists
|
||||
existing_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.name == tier_data["name"]
|
||||
).first()
|
||||
|
||||
if existing_plan:
|
||||
logger.info(f"✅ Plan '{tier_data['name']}' already exists, updating...")
|
||||
# Update existing plan
|
||||
logger.info(f"Plan '{tier_data['name']}' already exists, updating...")
|
||||
for key, value in tier_data["limits"].items():
|
||||
setattr(existing_plan, key, value)
|
||||
existing_plan.description = tier_data["description"]
|
||||
existing_plan.features = tier_data["features"]
|
||||
existing_plan.updated_at = datetime.utcnow()
|
||||
else:
|
||||
logger.info(f"🆕 Creating new plan: {tier_data['name']}")
|
||||
# Create new plan
|
||||
logger.info(f"Creating new plan: {tier_data['name']}")
|
||||
plan = SubscriptionPlan(
|
||||
name=tier_data["name"],
|
||||
tier=tier_data["tier"],
|
||||
@@ -126,106 +127,17 @@ def create_alpha_subscription_tiers():
|
||||
db.add(plan)
|
||||
|
||||
db.commit()
|
||||
logger.info("✅ Alpha subscription tiers created/updated successfully!")
|
||||
|
||||
# Create API provider pricing
|
||||
create_api_pricing(db)
|
||||
logger.info("Alpha subscription tiers created/updated successfully!")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating alpha subscription tiers: {e}")
|
||||
logger.error(f"Error creating alpha subscription tiers: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def create_api_pricing(db: Session):
|
||||
"""Create API provider pricing configuration."""
|
||||
|
||||
try:
|
||||
# Gemini pricing (based on current Google AI pricing)
|
||||
gemini_pricing = [
|
||||
{
|
||||
"model_name": "gemini-2.0-flash-exp",
|
||||
"cost_per_input_token": 0.00000075, # $0.75 per 1M tokens
|
||||
"cost_per_output_token": 0.000003, # $3 per 1M tokens
|
||||
"description": "Gemini 2.0 Flash Experimental"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-1.5-flash",
|
||||
"cost_per_input_token": 0.00000075, # $0.75 per 1M tokens
|
||||
"cost_per_output_token": 0.000003, # $3 per 1M tokens
|
||||
"description": "Gemini 1.5 Flash"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-1.5-pro",
|
||||
"cost_per_input_token": 0.00000125, # $1.25 per 1M tokens
|
||||
"cost_per_output_token": 0.000005, # $5 per 1M tokens
|
||||
"description": "Gemini 1.5 Pro"
|
||||
}
|
||||
]
|
||||
|
||||
# Tavily pricing
|
||||
tavily_pricing = [
|
||||
{
|
||||
"model_name": "search",
|
||||
"cost_per_search": 0.001, # $0.001 per search
|
||||
"description": "Tavily Search API"
|
||||
}
|
||||
]
|
||||
|
||||
# Serper pricing
|
||||
serper_pricing = [
|
||||
{
|
||||
"model_name": "search",
|
||||
"cost_per_search": 0.001, # $0.001 per search
|
||||
"description": "Serper Google Search API"
|
||||
}
|
||||
]
|
||||
|
||||
# Stability AI pricing
|
||||
stability_pricing = [
|
||||
{
|
||||
"model_name": "stable-diffusion-xl",
|
||||
"cost_per_image": 0.01, # $0.01 per image
|
||||
"description": "Stable Diffusion XL"
|
||||
}
|
||||
]
|
||||
|
||||
# Create pricing records
|
||||
pricing_configs = [
|
||||
(APIProvider.GEMINI, gemini_pricing),
|
||||
(APIProvider.TAVILY, tavily_pricing),
|
||||
(APIProvider.SERPER, serper_pricing),
|
||||
(APIProvider.STABILITY, stability_pricing)
|
||||
]
|
||||
|
||||
for provider, pricing_list in pricing_configs:
|
||||
for pricing_data in pricing_list:
|
||||
# Check if pricing already exists
|
||||
existing_pricing = db.query(APIProviderPricing).filter(
|
||||
APIProviderPricing.provider == provider,
|
||||
APIProviderPricing.model_name == pricing_data["model_name"]
|
||||
).first()
|
||||
|
||||
if existing_pricing:
|
||||
logger.info(f"✅ Pricing for {provider.value}/{pricing_data['model_name']} already exists")
|
||||
else:
|
||||
logger.info(f"🆕 Creating pricing for {provider.value}/{pricing_data['model_name']}")
|
||||
pricing = APIProviderPricing(
|
||||
provider=provider,
|
||||
**pricing_data
|
||||
)
|
||||
db.add(pricing)
|
||||
|
||||
db.commit()
|
||||
logger.info("✅ API provider pricing created successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating API pricing: {e}")
|
||||
db.rollback()
|
||||
|
||||
def assign_default_plan_to_users():
|
||||
"""Assign Free Alpha plan to all existing users."""
|
||||
if os.getenv('ENABLE_ALPHA', 'false').lower() not in {'1','true','yes','on'}:
|
||||
@@ -234,32 +146,28 @@ def assign_default_plan_to_users():
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error("❌ Could not get database session")
|
||||
logger.error("Could not get database session")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get Free Alpha plan
|
||||
free_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.name == "Free Alpha"
|
||||
).first()
|
||||
|
||||
if not free_plan:
|
||||
logger.error("❌ Free Alpha plan not found")
|
||||
logger.error("Free Alpha plan not found")
|
||||
return False
|
||||
|
||||
# For now, we'll create a default user subscription
|
||||
# In a real system, you'd query actual users
|
||||
|
||||
from models.subscription_models import UserSubscription, BillingCycle, UsageStatus
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import timedelta
|
||||
|
||||
# Create default user subscription for testing
|
||||
default_user_id = "default_user"
|
||||
existing_subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == default_user_id
|
||||
).first()
|
||||
|
||||
if not existing_subscription:
|
||||
logger.info(f"🆕 Creating default subscription for {default_user_id}")
|
||||
logger.info(f"Creating default subscription for {default_user_id}")
|
||||
subscription = UserSubscription(
|
||||
user_id=default_user_id,
|
||||
plan_id=free_plan.id,
|
||||
@@ -272,33 +180,32 @@ def assign_default_plan_to_users():
|
||||
)
|
||||
db.add(subscription)
|
||||
db.commit()
|
||||
logger.info(f"✅ Default subscription created for {default_user_id}")
|
||||
logger.info(f"Default subscription created for {default_user_id}")
|
||||
else:
|
||||
logger.info(f"✅ Default subscription already exists for {default_user_id}")
|
||||
logger.info(f"Default subscription already exists for {default_user_id}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error assigning default plan: {e}")
|
||||
logger.error(f"Error assigning default plan: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 Initializing Alpha Subscription Tiers...")
|
||||
logger.info("Initializing Alpha Subscription Tiers...")
|
||||
|
||||
success = create_alpha_subscription_tiers()
|
||||
if success:
|
||||
logger.info("✅ Subscription tiers created successfully!")
|
||||
logger.info("Subscription tiers created successfully!")
|
||||
|
||||
# Assign default plan
|
||||
assign_success = assign_default_plan_to_users()
|
||||
if assign_success:
|
||||
logger.info("✅ Default plan assigned successfully!")
|
||||
logger.info("Default plan assigned successfully!")
|
||||
else:
|
||||
logger.error("❌ Failed to assign default plan")
|
||||
logger.error("Failed to assign default plan")
|
||||
else:
|
||||
logger.error("❌ Failed to create subscription tiers")
|
||||
logger.error("Failed to create subscription tiers")
|
||||
|
||||
logger.info("🎉 Alpha subscription system initialization complete!")
|
||||
logger.info("Alpha subscription system initialization complete!")
|
||||
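Review note: the tier hunks above drop the inline comments from the `limits` dictionaries, including the ones stating that a limit of `0` means unlimited on the Enterprise tier. The convention can be sketched as a small check. How the real enforcement code reads these values is not shown in this diff, so `is_within_limit` is a hypothetical helper that only illustrates the convention the removed comments described.

```python
def is_within_limit(used: int, limit: int) -> bool:
    """Return True if one more call is allowed; a limit of 0 is treated as unlimited."""
    if limit == 0:
        return True
    return used < limit
```

Keeping a single comment or helper like this near the tier definitions would preserve the meaning of `0` now that the per-line comments are gone.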
@@ -9,6 +9,7 @@ import json
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.blog_models import (
|
||||
MediumBlogGenerateRequest,
|
||||
@@ -26,7 +27,7 @@ class MediumBlogGenerator:
|
||||
def __init__(self):
|
||||
self.cache = persistent_content_cache
|
||||
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str, db: Session = None) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -499,7 +499,7 @@ class DatabaseTaskManager:
|
||||
)
|
||||
blog_writer_logger.log_error(e, "outline_generation_task", context={"task_id": task_id})
|
||||
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest):
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
|
||||
"""Background task to generate a medium blog using a single structured JSON call."""
|
||||
try:
|
||||
await self.update_progress(task_id, "📦 Packaging outline and metadata...", 0)
|
||||
@@ -512,7 +512,7 @@ class DatabaseTaskManager:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
task_id,
|
||||
user_id=request.user_id if hasattr(request, 'user_id') else (await self.get_task_status(task_id))['user_id'],
|
||||
user_id,
|
||||
db=self.db
|
||||
)
|
||||
|
||||
|
||||
@@ -70,22 +70,22 @@ STRATEGIC REQUIREMENTS:
|
||||
- Ensure engaging, actionable content throughout
|
||||
|
||||
Return JSON format:
|
||||
{
|
||||
{{
|
||||
"title_options": [
|
||||
"Title option 1",
|
||||
"Title option 2",
|
||||
"Title option 3"
|
||||
],
|
||||
"outline": [
|
||||
{
|
||||
{{
|
||||
"heading": "Section heading with primary keyword",
|
||||
"subheadings": ["Subheading 1", "Subheading 2", "Subheading 3"],
|
||||
"key_points": ["Key point 1", "Key point 2", "Key point 3"],
|
||||
"target_words": 300,
|
||||
"keywords": ["primary keyword", "secondary keyword"]
|
||||
}
|
||||
}}
|
||||
]
|
||||
}"""
|
||||
}}"""
|
||||
|
||||
def get_outline_schema(self) -> Dict[str, Any]:
|
||||
"""Get the structured JSON schema for outline generation."""
|
||||
|
||||
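Review note: the hunk above doubles every brace in the JSON example, presumably because the prompt template is an f-string or is passed through `str.format`, where a single `{` starts a replacement field. A minimal sketch of the behavior follows; `build_outline_prompt` and its shortened prompt text are hypothetical, and only the brace handling mirrors the change.

```python
def build_outline_prompt(topic: str) -> str:
    """Build a prompt whose JSON example survives f-string formatting."""
    # {topic} is substituted; {{ and }} are emitted as literal { and }.
    return f"""Create an outline for: {topic}

Return JSON format:
{{
  "title_options": ["Title option 1"],
  "outline": [{{"heading": "Section heading", "target_words": 300}}]
}}"""
```

With single braces, an f-string would try to evaluate the JSON body as an expression, and `str.format` would typically fail with a `KeyError` or `ValueError`, which is consistent with this being a bug fix rather than a style change.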
@@ -5,8 +5,8 @@ Enhances individual outline sections for better engagement and value.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import BlogOutlineSection
|
||||
import json
|
||||
|
||||
|
||||
class SectionEnhancer:
|
||||
@@ -73,14 +73,45 @@ class SectionEnhancer:
|
||||
"required": ["heading", "subheadings", "key_points", "target_words", "keywords"]
|
||||
}
|
||||
|
||||
enhanced_data = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=enhancement_prompt,
|
||||
json_struct=enhancement_schema,
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
enhanced_data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
enhanced_data = json.loads(json_match.group(0))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Section enhancement returned invalid JSON: {e}")
|
||||
return section
|
||||
else:
|
||||
logger.warning(f"Section enhancement returned non-JSON string: {cleaned[:200]}")
|
||||
return section
|
||||
elif isinstance(raw, dict):
|
||||
enhanced_data = raw
|
||||
else:
|
||||
logger.warning(f"Unexpected LLM response type: {type(raw)}")
|
||||
return section
|
||||
|
||||
if 'error' in enhanced_data:
|
||||
logger.warning(f"AI section enhancement failed: {enhanced_data.get('error', 'Unknown error')}")
|
||||
else:
|
||||
return BlogOutlineSection(
|
||||
id=section.id,
|
||||
heading=enhanced_data.get('heading', section.heading),
|
||||
|
||||
@@ -6,6 +6,7 @@ Extracts competitor insights and market intelligence from research content.
|
||||
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CompetitorAnalyzer:
|
||||
@@ -22,7 +23,7 @@ class CompetitorAnalyzer:
|
||||
Extract and analyze:
|
||||
1. Top competitors mentioned (companies, brands, platforms)
|
||||
2. Content gaps (what competitors are missing)
|
||||
3. Market opportunities (untapped areas)
|
||||
3. Opportunities (untapped areas)
|
||||
4. Competitive advantages (what makes content unique)
|
||||
5. Market positioning insights
|
||||
6. Industry leaders and their strategies
|
||||
@@ -55,18 +56,38 @@ class CompetitorAnalyzer:
|
||||
"required": ["top_competitors", "content_gaps", "opportunities", "competitive_advantages", "market_positioning", "industry_leaders", "analysis_notes"]
|
||||
}
|
||||
|
||||
competitor_analysis = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=competitor_prompt,
|
||||
json_struct=competitor_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
competitor_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
competitor_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Competitor analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
competitor_analysis = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = competitor_analysis.get('error', 'Unknown error') if isinstance(competitor_analysis, dict) else str(competitor_analysis)
|
||||
logger.error(f"AI competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in competitor_analysis:
|
||||
raise ValueError(f"Competitor analysis failed: {competitor_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
|
||||
|
||||
@@ -63,18 +63,41 @@ class ContentAngleGenerator:
|
||||
"required": ["content_angles"]
|
||||
}
|
||||
|
||||
angles_result = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=angles_prompt,
|
||||
json_struct=angles_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import json, re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
angles_result = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
angles_result = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Content angles returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
angles_result = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = angles_result.get('error', 'Unknown error') if isinstance(angles_result, dict) else str(angles_result)
|
||||
logger.error(f"AI content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in angles_result:
|
||||
raise ValueError(f"Content angles generation failed: {angles_result.get('error', 'Unknown error')}")
|
||||
|
||||
if 'content_angles' not in angles_result:
|
||||
raise ValueError("Content angles missing from response")
|
||||
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
|
||||
|
||||
@@ -314,11 +314,14 @@ class ExaResearchProvider(BaseProvider):
|
||||
|
||||
def track_exa_usage(self, user_id: str, cost: float):
|
||||
"""Track Exa API usage after successful call."""
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[track_exa_usage] Could not get DB session for user {user_id}")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
@@ -6,6 +6,7 @@ Extracts and analyzes keywords from research content using structured AI respons
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class KeywordAnalyzer:
|
||||
@@ -62,18 +63,38 @@ class KeywordAnalyzer:
|
||||
"required": ["primary", "secondary", "long_tail", "search_intent", "difficulty", "content_gaps", "semantic_keywords", "trending_terms", "analysis_insights"]
|
||||
}
|
||||
|
||||
keyword_analysis = llm_text_gen(
|
||||
raw = llm_text_gen(
|
||||
prompt=keyword_prompt,
|
||||
json_struct=keyword_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
keyword_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
keyword_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Keyword analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
keyword_analysis = raw
|
||||
else:
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = keyword_analysis.get('error', 'Unknown error') if isinstance(keyword_analysis, dict) else str(keyword_analysis)
|
||||
logger.error(f"AI keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in keyword_analysis:
|
||||
raise ValueError(f"Keyword analysis failed: {keyword_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
|
||||
|
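Review note: the same fence-stripping and JSON-recovery sequence now appears in CompetitorAnalyzer, ContentAngleGenerator, and KeywordAnalyzer. A minimal sketch of how it could be factored into one helper is below. `parse_llm_json` and its `label` parameter are hypothetical names, not part of this change set, and the sketch assumes `llm_text_gen` returns either a `str` or a `dict`, as the hunks above suggest.

```python
import json
import re
from typing import Any, Dict


def parse_llm_json(raw: Any, label: str = "LLM response") -> Dict[str, Any]:
    """Normalize an LLM result (str or dict) into a dict, or raise ValueError.

    Hypothetical helper: mirrors the inline parsing in the analyzers above.
    """
    if isinstance(raw, dict):
        data = raw
    elif isinstance(raw, str):
        cleaned = raw.strip()
        # Strip a leading ```json or ``` fence and a trailing ``` fence, if present.
        if cleaned.startswith("```json"):
            cleaned = cleaned[7:]
        elif cleaned.startswith("```"):
            cleaned = cleaned[3:]
        if cleaned.endswith("```"):
            cleaned = cleaned[:-3]
        cleaned = cleaned.strip()
        try:
            data = json.loads(cleaned)
        except json.JSONDecodeError:
            # Fall back to the outermost {...} span embedded in surrounding text.
            match = re.search(r"\{.*\}", cleaned, re.DOTALL)
            if not match:
                raise ValueError(f"{label} returned non-JSON string: {cleaned[:200]}")
            try:
                data = json.loads(match.group(0))
            except json.JSONDecodeError as exc:
                raise ValueError(f"{label} returned invalid JSON: {exc}") from exc
    else:
        raise ValueError(f"Unexpected LLM response type: {type(raw)}")

    if not isinstance(data, dict):
        raise ValueError(f"{label} did not decode to a JSON object")
    if "error" in data:
        raise ValueError(f"{label} failed: {data.get('error', 'Unknown error')}")
    return data
```

With a helper like this, each call site would reduce to something like `keyword_analysis = parse_llm_json(raw, "Keyword analysis")`. Unlike the inline versions, the sketch also wraps the second `json.loads` call, so a malformed `{...}` span is reported with the same labelled `ValueError` as the other failure paths.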
||||
@@ -111,19 +111,22 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
api_start_time = time.time()
|
||||
@@ -162,13 +165,15 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation (similar to Exa)
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -429,14 +434,16 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
@@ -446,7 +453,8 @@ class ResearchService:
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
|
||||
@@ -485,14 +493,16 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Tavily AI search...")
|
||||
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -529,7 +539,8 @@ class ResearchService:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Tavily limits: {e}")
|
||||
finally:
|
||||
db_val.close()
|
||||
if db_val:
|
||||
db_val.close()
|
||||
|
||||
# Execute Tavily search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Tavily AI search...")
|
||||
|
||||
@@ -135,11 +135,14 @@ class TavilyResearchProvider(BaseProvider):
|
||||
|
||||
def track_tavily_usage(self, user_id: str, cost: float, search_depth: str):
|
||||
"""Track Tavily API usage after successful call."""
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[Tavily] Could not get DB session for user {user_id}, skipping usage tracking")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
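Review note: the `get_session_for_user(...)`, `if not db`, `try ... finally: db.close()` sequence is repeated across ResearchService and both providers. One possible way to keep the open/close pairing in a single place is a context manager, sketched below. `user_session` and `SessionUnavailable` are hypothetical names, and `session_factory` stands in for `get_session_for_user`, which is assumed here to return either `None` or an object with a `close()` method.

```python
from contextlib import contextmanager
from typing import Any, Callable, Iterator, Optional


class SessionUnavailable(RuntimeError):
    """Raised when no database session can be obtained (hypothetical)."""


@contextmanager
def user_session(
    user_id: str,
    session_factory: Callable[[str], Optional[Any]],
) -> Iterator[Any]:
    """Yield a per-user session and always close it afterwards."""
    db = session_factory(user_id)
    if db is None:
        # Callers decide how to surface this, e.g. an HTTP 503 or a logged skip.
        raise SessionUnavailable(f"No DB session for user {user_id}")
    try:
        yield db
    finally:
        # Runs on normal exit and when the body raises.
        db.close()
```

A call site could then read `with user_session(user_id, get_session_for_user) as db_val: ...`, translating `SessionUnavailable` into the 503 response in the request paths and into a logged early return in the usage-tracking paths.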
||||
@@ -92,6 +92,7 @@ class BlogSEORecommendationApplier:
|
||||
None,
|
||||
schema,
|
||||
user_id, # Pass user_id for subscription checking
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
|
||||
@@ -7,6 +7,7 @@ import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from typing import Optional, List
|
||||
|
||||
@@ -351,16 +352,15 @@ def init_database():
|
||||
|
||||
try:
|
||||
# Create all tables for all models using default engine
|
||||
OnboardingBase.metadata.create_all(bind=default_engine)
|
||||
SEOAnalysisBase.metadata.create_all(bind=default_engine)
|
||||
ContentPlanningBase.metadata.create_all(bind=default_engine)
|
||||
EnhancedStrategyBase.metadata.create_all(bind=default_engine)
|
||||
MonitoringBase.metadata.create_all(bind=default_engine)
|
||||
APIMonitoringBase.metadata.create_all(bind=default_engine)
|
||||
PersonaBase.metadata.create_all(bind=default_engine)
|
||||
SubscriptionBase.metadata.create_all(bind=default_engine)
|
||||
UserBusinessInfoBase.metadata.create_all(bind=default_engine)
|
||||
ContentAssetBase.metadata.create_all(bind=default_engine)
|
||||
# Use checkfirst=True (default) to avoid errors for existing tables
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
# Create tables with checkfirst=True explicitly to handle existing objects
|
||||
for base in [OnboardingBase, SEOAnalysisBase, ContentPlanningBase,
|
||||
EnhancedStrategyBase, MonitoringBase, APIMonitoringBase,
|
||||
PersonaBase, SubscriptionBase, UserBusinessInfoBase, ContentAssetBase]:
|
||||
base.metadata.create_all(bind=default_engine, checkfirst=True)
|
||||
logger.info("Global database initialized successfully")
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"Error initializing global database: {str(e)}")
|
||||
@@ -387,12 +387,15 @@ def get_db(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
user_id = current_user.get('id') or current_user.get('clerk_user_id')
|
||||
if not user_id:
|
||||
# Fallback or error? For now log error
|
||||
logger.error("No user ID found in context for DB connection")
|
||||
# Could raise exception, but let's try to be safe
|
||||
raise Exception("User ID required for database access")
|
||||
raise HTTPException(status_code=401, detail="User ID required for database access")
|
||||
|
||||
engine = get_engine_for_user(user_id)
|
||||
try:
|
||||
engine = get_engine_for_user(user_id)
|
||||
except Exception as e:
|
||||
# loguru has no exc_info flag; extra kwargs are treated as str.format() arguments.
logger.opt(exception=True).error(f"[DB] Failed to create engine for user {user_id}: {e}")
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
@@ -237,6 +237,21 @@ class ControlStudioService:
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_bytes_to_metadata(image_bytes)
|
||||
|
||||
# Track usage
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"control-{operation}",
|
||||
operation_type="image-control",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/control/process",
|
||||
log_prefix="[Control Studio]"
|
||||
)
|
||||
|
||||
metadata.update(
|
||||
{
|
||||
"operation": operation,
|
||||
|
||||
@@ -514,6 +514,19 @@ class EditStudioService:
|
||||
background_bytes=background_bytes,
|
||||
lighting_bytes=lighting_bytes,
|
||||
)
|
||||
# Track usage for Stability operations
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"edit-{operation}",
|
||||
operation_type="image-edit",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/edit/process",
|
||||
log_prefix="[Edit Studio]"
|
||||
)
|
||||
else:
|
||||
image_bytes = await self._handle_general_edit(
|
||||
request=request,
|
||||
|
||||
@@ -88,6 +88,20 @@ class UpscaleStudioService:
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_metadata(image_bytes)
|
||||
|
||||
# Track usage
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"upscale-{mode}",
|
||||
operation_type="image-upscale",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/upscale",
|
||||
log_prefix="[Upscale Studio]"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"mode": mode,
|
||||
|
||||
@@ -233,7 +233,7 @@ def create_blog_post(
|
||||
|
||||
# BACK TO BASICS MODE: Try simplest possible structure FIRST
|
||||
# Since posting worked before Ricos/SEO, let's test with absolute minimum
|
||||
BACK_TO_BASICS_MODE = True # Set to True to test with simplest structure
|
||||
BACK_TO_BASICS_MODE = False # Disabled: full Ricos conversion now produces valid output
|
||||
|
||||
wix_logger.reset()
|
||||
wix_logger.log_operation_start("Blog Post Creation", title=title[:50] if title else None, member_id=member_id[:20] if member_id else None)
|
||||
@@ -257,8 +257,7 @@ def create_blog_post(
|
||||
'text': (content[:500] if content else "This is a post from ALwrity.").strip(),
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}]
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
@@ -256,17 +256,16 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
quote_content = ' '.join(quote_lines)
|
||||
text_nodes = parse_markdown_inline(quote_content)
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
blockquote_node = {
|
||||
'id': node_id,
|
||||
'type': 'BLOCKQUOTE',
|
||||
'nodes': [paragraph_node],
|
||||
'blockquoteData': {}
|
||||
}
|
||||
nodes.append(blockquote_node)
|
||||
|
||||
@@ -332,7 +331,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -345,7 +343,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'BULLETED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'bulletedListData': {}
|
||||
}
|
||||
nodes.append(bulleted_list_node)
|
||||
|
||||
@@ -373,7 +370,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -386,7 +382,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'ORDERED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'orderedListData': {}
|
||||
}
|
||||
nodes.append(ordered_list_node)
|
||||
|
||||
@@ -442,7 +437,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(paragraph_node)
|
||||
|
||||
@@ -461,7 +455,6 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(fallback_paragraph)
|
||||
|
||||
|
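Review note: the hunks above drop empty `paragraphData`, `bulletedListData`, and `orderedListData` objects one node type at a time, following the in-code comment that empty data objects should be omitted. A minimal sketch of doing the same thing in one pass over a finished node tree is below. `prune_empty_data` is a hypothetical helper, and the rule it applies (drop any key ending in `Data` whose value is `{}`) is an assumption generalized from these hunks, not a documented Wix requirement.

```python
from typing import Any


def prune_empty_data(node: Any) -> Any:
    """Return a copy of a Ricos-like node tree without empty `*Data` dicts."""
    if isinstance(node, list):
        return [prune_empty_data(child) for child in node]
    if isinstance(node, dict):
        return {
            key: prune_empty_data(value)
            for key, value in node.items()
            # Keep everything except keys such as 'paragraphData': {}.
            if not (key.endswith("Data") and value == {})
        }
    # Scalars (str, int, None, ...) are returned unchanged.
    return node
```

The function builds new containers instead of mutating its input, so the original tree stays available for logging or comparison. Non-empty objects such as `textData` and empty lists such as `'decorations': []` are left as they are.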
||||
745
backend/services/intelligence/agent_context_vfs.py
Normal file
@@ -0,0 +1,745 @@
|
||||
"""Read-only virtual filesystem facade for agent flat context documents.
|
||||
|
||||
This adapter provides shell-like primitives (`list_context`, `search_context`,
|
||||
`read_context_file`) over the JSON documents managed by AgentFlatContextStore.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import os
|
||||
import fcntl
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from collections import deque
|
||||
from fnmatch import fnmatch
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from services.intelligence.agent_flat_context import AgentFlatContextStore
|
||||
|
||||
|
||||
class SmartGrepEngine:
|
||||
"""Streaming grep engine with regex fallback and contextual snippets."""
|
||||
|
||||
def __init__(self, context_window: int = 1):
|
||||
self.context_window = max(0, int(context_window))
|
||||
|
||||
@staticmethod
|
||||
def _compile_pattern(pattern: str) -> re.Pattern:
|
||||
try:
|
||||
return re.compile(pattern, re.IGNORECASE)
|
||||
except re.error:
|
||||
return re.compile(re.escape(pattern), re.IGNORECASE)
|
||||
|
||||
@staticmethod
|
||||
def _truncate(text: str, limit: int = 180) -> str:
|
||||
text = " ".join(text.split())
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[:limit] + "..."
|
||||
|
||||
def stream_file(self, file_path: Path, pattern: str, *, path_label: str) -> List[Dict[str, Any]]:
|
||||
regex = self._compile_pattern(pattern)
|
||||
matches: List[Dict[str, Any]] = []
|
||||
prev = deque(maxlen=self.context_window)
|
||||
active: List[Dict[str, Any]] = []
|
||||
|
||||
with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
|
||||
for line_no, line in enumerate(f, start=1):
|
||||
# Fill trailing context for active matches.
|
||||
for item in active:
|
||||
if item["remaining_after"] > 0:
|
||||
item["after"].append(line.rstrip("\n"))
|
||||
item["remaining_after"] -= 1
|
||||
|
||||
# Detect a new match on current line.
|
||||
if regex.search(line):
|
||||
current = line.rstrip("\n")
|
||||
record = {
|
||||
"path": path_label,
|
||||
"line": line_no,
|
||||
"before": list(prev),
|
||||
"match_line": current,
|
||||
"after": [],
|
||||
"remaining_after": self.context_window,
|
||||
}
|
||||
active.append(record)
|
||||
matches.append(record)
|
||||
|
||||
prev.append(line.rstrip("\n"))
|
||||
|
||||
formatted: List[Dict[str, Any]] = []
|
||||
for m in matches:
|
||||
snippet_parts = [*m["before"], m["match_line"], *m["after"]]
|
||||
snippet = self._truncate(" | ".join([p for p in snippet_parts if p is not None]))
|
||||
line_l = m["match_line"].lower()
|
||||
is_high_signal = any(k in line_l for k in ("agent_summary", "high_signal_terms", "quick_facts"))
|
||||
formatted.append(
|
||||
{
|
||||
"path": m["path"],
|
||||
"line": m["line"],
|
||||
"snippet": snippet,
|
||||
"relevance": "High Relevance" if is_high_signal else "Supporting Detail",
|
||||
"reason": "matched summary field in stream" if is_high_signal else "matched streamed body line",
|
||||
"score": 70 if is_high_signal else 50,
|
||||
}
|
||||
)
|
||||
return formatted
|
||||
|
||||
|
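Review note: the before/after context capture in `SmartGrepEngine.stream_file` can be illustrated in isolation. The sketch below is a simplified stand-in, not the class itself: `grep_with_context` is a hypothetical function that works on any iterable of lines and omits the snippet formatting and scoring. It keeps the same regex-with-literal-fallback behaviour, and it additionally retires a match once its trailing context is full, which `stream_file` as shown does not do.

```python
import re
from collections import deque
from typing import Any, Deque, Dict, Iterable, List


def grep_with_context(lines: Iterable[str], pattern: str, window: int = 1) -> List[Dict[str, Any]]:
    """Return matches with up to `window` lines of context on each side."""
    try:
        regex = re.compile(pattern, re.IGNORECASE)
    except re.error:
        # Invalid regex: fall back to a literal, case-insensitive search.
        regex = re.compile(re.escape(pattern), re.IGNORECASE)

    before: Deque[str] = deque(maxlen=window)  # rolling leading context
    pending: List[Dict[str, Any]] = []         # matches still collecting trailing context
    matches: List[Dict[str, Any]] = []

    for line_no, raw in enumerate(lines, start=1):
        line = raw.rstrip("\n")
        # Feed the current line to matches that still need trailing context.
        for rec in pending:
            rec["after"].append(line)
        pending = [rec for rec in pending if len(rec["after"]) < window]

        if regex.search(line):
            rec = {"line": line_no, "before": list(before), "match": line, "after": []}
            matches.append(rec)
            if window > 0:
                pending.append(rec)
        before.append(line)
    return matches
```

Because only a bounded deque and the pending list are held in memory, the input can be an open file object and is never read in full.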
||||
class AgentContextVFS:
|
||||
"""Read-only adapter that maps virtual paths to flat context documents."""
|
||||
|
||||
VIRTUAL_MAP = {
|
||||
"/steps/website": AgentFlatContextStore.STEP2_FILENAME,
|
||||
"/steps/research": AgentFlatContextStore.STEP3_FILENAME,
|
||||
"/steps/persona": AgentFlatContextStore.STEP4_FILENAME,
|
||||
"/steps/integrations": AgentFlatContextStore.STEP5_FILENAME,
|
||||
}
|
||||
HIGH_SIGNAL_MARKERS = ("agent_summary", "high_signal_terms", "quick_facts", "context_type")
|
||||
|
||||
def __init__(self, user_id: str, project_id: Optional[str] = None):
|
||||
self.user_id = user_id
|
||||
self.project_id = project_id
|
||||
self.store = AgentFlatContextStore(user_id)
|
||||
self.grep_engine = SmartGrepEngine(context_window=1)
|
||||
|
||||
@staticmethod
|
||||
def _safe_slug(value: Optional[str], fallback: str) -> str:
|
||||
raw = str(value or "").strip()
|
||||
safe = "".join(c for c in raw if c.isalnum() or c in ("-", "_"))
|
||||
return safe or fallback
|
||||
|
||||
def _manifest_docs(self) -> List[Dict[str, Any]]:
|
||||
manifest = self.store.load_context_manifest() or {"documents": []}
|
||||
docs = manifest.get("documents")
|
||||
return docs if isinstance(docs, list) else []
|
||||
|
||||
def _workspace_root(self) -> Path:
|
||||
if self.project_id:
|
||||
root_dir = Path(__file__).resolve().parents[3]
|
||||
safe_project = self._safe_slug(self.project_id, "default_project")
|
||||
project_root = root_dir / "workspace" / f"project_{safe_project}"
|
||||
project_root.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(project_root, 0o700)
|
||||
return project_root
|
||||
return self.store._workspace_dir()
|
||||
|
||||
def _scratchpad_dir(self) -> Path:
|
||||
scratch = self._workspace_root() / "scratchpad"
|
||||
scratch.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(scratch, 0o700)
|
||||
return scratch
|
||||
|
||||
def _allowlisted_workspace_files(self) -> List[Path]:
|
||||
"""Return sandboxed files eligible for streaming search."""
|
||||
files: List[Path] = []
|
||||
workspace = self._workspace_root()
|
||||
context_dir = self.store._context_dir()
|
||||
|
||||
# 1) manifest-backed onboarding context files
|
||||
for item in self._manifest_docs():
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
rel = str(item.get("path") or "")
|
||||
if not rel:
|
||||
continue
|
||||
try:
|
||||
candidate = self.store._safe_resolve_under(context_dir, rel)
|
||||
if candidate.exists() and candidate.is_file():
|
||||
files.append(candidate)
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
# 2) workspace text artifacts (README, operator notes, etc.)
|
||||
for candidate in workspace.glob("*.txt"):
|
||||
if candidate.is_file():
|
||||
files.append(candidate.resolve())
|
||||
readme = workspace / "README.md"
|
||||
if readme.exists() and readme.is_file():
|
||||
files.append(readme.resolve())
|
||||
|
||||
# dedupe
|
||||
seen = set()
|
||||
unique: List[Path] = []
|
||||
for p in files:
|
||||
rp = str(p)
|
||||
if rp in seen:
|
||||
continue
|
||||
seen.add(rp)
|
||||
unique.append(p)
|
||||
return unique
|
||||
|
||||
@staticmethod
|
||||
def _query_variants(query: str) -> List[str]:
|
||||
"""Generate normalized and synonym-expanded query variants."""
|
||||
base = (query or "").strip().lower()
|
||||
if not base:
|
||||
return []
|
||||
synonyms = {
|
||||
"tone": ["brand voice", "writing tone"],
|
||||
"voice": ["brand voice", "writing style"],
|
||||
"competitor": ["competition", "rival"],
|
||||
"seo": ["search", "metadata"],
|
||||
"persona": ["audience profile", "target audience"],
|
||||
}
|
||||
variants = [base]
|
||||
tokens = base.split()
|
||||
for idx, tok in enumerate(tokens):
|
||||
if tok in synonyms:
|
||||
for repl in synonyms[tok]:
|
||||
new_tokens = tokens.copy()
|
||||
new_tokens[idx] = repl
|
||||
variants.append(" ".join(new_tokens))
|
||||
variants.extend([base.replace("-", " "), base.replace("_", " ")])
|
||||
# dedupe, preserve order
|
||||
seen = set()
|
||||
out: List[str] = []
|
||||
for v in variants:
|
||||
vv = v.strip()
|
||||
if not vv or vv in seen:
|
||||
continue
|
||||
seen.add(vv)
|
||||
out.append(vv)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def _freshness_score(updated_at: Optional[str]) -> float:
|
||||
if not updated_at:
|
||||
return 0.3
|
||||
try:
|
||||
from datetime import datetime, timezone
|
||||
|
||||
ts = datetime.fromisoformat(str(updated_at).replace("Z", "+00:00"))
|
||||
if ts.tzinfo is None:
|
||||
ts = ts.replace(tzinfo=timezone.utc)
|
||||
days = max(0.0, (datetime.now(timezone.utc) - ts).total_seconds() / 86400.0)
|
||||
if days <= 1:
|
||||
return 1.0
|
||||
if days <= 7:
|
||||
return 0.9
|
||||
if days <= 30:
|
||||
return 0.75
|
||||
if days <= 90:
|
||||
return 0.6
|
||||
return 0.4
|
||||
except Exception:
|
||||
return 0.3
|
||||
|
||||
def _cluster_results(self, results: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
||||
"""Deduplicate repeated hits by file + reason and keep strongest evidence."""
|
||||
buckets: Dict[Tuple[str, str], Dict[str, Any]] = {}
|
||||
for r in results:
|
||||
path = str(r.get("path") or "")
|
||||
reason = str(r.get("reason") or "")
|
||||
key = (path, reason)
|
||||
existing = buckets.get(key)
|
||||
if not existing:
|
||||
buckets[key] = {**r, "hit_count": 1}
|
||||
continue
|
||||
existing["hit_count"] = int(existing.get("hit_count", 1)) + 1
|
||||
if int(r.get("score", 0)) > int(existing.get("score", 0)):
|
||||
existing.update({k: v for k, v in r.items() if k != "hit_count"})
|
||||
existing["hit_count"] = int(existing.get("hit_count", 1))
|
||||
clustered = list(buckets.values())
|
||||
clustered.sort(key=lambda r: (-int(r.get("score", 0)), str(r.get("path") or "")))
|
||||
return clustered
|
||||
|
||||
def _keyword_density(self, snippet: str, query: str) -> float:
|
||||
if not snippet or not query:
|
||||
return 0.0
|
||||
query_tokens = [t for t in query.lower().split() if t]
|
||||
if not query_tokens:
|
||||
return 0.0
|
||||
text = snippet.lower()
|
||||
hits = sum(text.count(tok) for tok in query_tokens)
|
||||
words = max(1, len(text.split()))
|
||||
return hits / words
|
||||
|
||||
def _static_triage(self, results: List[Dict[str, Any]], query: str) -> List[Dict[str, Any]]:
|
||||
"""Semgrep-style static heuristic triage before main agent consumption."""
|
||||
triaged: List[Dict[str, Any]] = []
|
||||
for r in results:
|
||||
snippet = str(r.get("snippet") or "")
|
||||
density = self._keyword_density(snippet, query)
|
||||
marker_hit = any(marker in snippet.lower() for marker in self.HIGH_SIGNAL_MARKERS)
|
||||
low_probability = bool(density < 0.01 and not marker_hit)
|
||||
item = dict(r)
|
||||
item["keyword_density"] = round(density, 4)
|
||||
item["low_probability"] = low_probability
|
||||
triaged.append(item)
|
||||
triaged.sort(
|
||||
key=lambda x: (
|
||||
bool(x.get("low_probability")),
|
||||
-float(x.get("confidence", 0)),
|
||||
-int(x.get("score", 0)),
|
||||
)
|
||||
)
|
||||
return triaged
|
||||
|
||||
@staticmethod
|
||||
def _llm_router_stub(results: List[Dict[str, Any]], top_k: int = 5) -> List[Dict[str, Any]]:
|
||||
"""Fast local triage stub (drop low-probability first; keep strongest candidates)."""
|
||||
ranked = sorted(
|
||||
results,
|
||||
key=lambda x: (
|
||||
bool(x.get("low_probability")),
|
||||
-float(x.get("confidence", 0)),
|
||||
-int(x.get("score", 0)),
|
||||
),
|
||||
)
|
||||
return ranked[: max(1, top_k)]
|
||||
|
||||
@staticmethod
|
||||
def _resolve_json_path(data: Any, path_query: str) -> Any:
|
||||
"""Resolve dot/bracket JSON path such as 'data.seo_audit.recommendations[0]'."""
|
||||
if not path_query:
|
||||
return data
|
||||
|
||||
current = data
|
||||
query = path_query.strip()
|
||||
parts: List[str] = []
|
||||
buf = ""
|
||||
in_brackets = False
|
||||
for ch in query:
|
||||
if ch == "." and not in_brackets:
|
||||
if buf:
|
||||
parts.append(buf)
|
||||
buf = ""
|
||||
continue
|
||||
if ch == "[":
|
||||
in_brackets = True
|
||||
elif ch == "]":
|
||||
in_brackets = False
|
||||
buf += ch
|
||||
if buf:
|
||||
parts.append(buf)
|
||||
|
||||
for part in parts:
|
||||
if "[" in part and part.endswith("]"):
|
||||
key, idx_raw = part.split("[", 1)
|
||||
idx = int(idx_raw[:-1])
|
||||
if key:
|
||||
if not isinstance(current, dict):
|
||||
raise KeyError(key)
|
||||
current = current[key]
|
||||
if not isinstance(current, list):
|
||||
raise IndexError(idx)
|
||||
current = current[idx]
|
||||
else:
|
||||
if not isinstance(current, dict):
|
||||
raise KeyError(part)
|
||||
current = current[part]
|
||||
return current
|
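Review note: a usage sketch for the dot/bracket resolver above. `resolve_json_path` is a standalone restatement of the same parsing logic written for illustration; the function name and the sample document are assumptions, not part of the change set.

```python
from typing import Any, List


def resolve_json_path(data: Any, path_query: str) -> Any:
    """Resolve a dot/bracket path such as 'data.items[0]' against nested dicts and lists."""
    if not path_query:
        return data

    # Split on dots that are not inside brackets.
    parts: List[str] = []
    buf = ""
    in_brackets = False
    for ch in path_query.strip():
        if ch == "." and not in_brackets:
            if buf:
                parts.append(buf)
            buf = ""
            continue
        if ch == "[":
            in_brackets = True
        elif ch == "]":
            in_brackets = False
        buf += ch
    if buf:
        parts.append(buf)

    current = data
    for part in parts:
        if "[" in part and part.endswith("]"):
            key, idx_raw = part.split("[", 1)
            idx = int(idx_raw[:-1])  # one index per segment; 'a[0][1]' raises ValueError
            if key:
                if not isinstance(current, dict):
                    raise KeyError(key)
                current = current[key]
            if not isinstance(current, list):
                raise IndexError(idx)
            current = current[idx]
        else:
            if not isinstance(current, dict):
                raise KeyError(part)
            current = current[part]
    return current


sample = {"data": {"seo_audit": {"recommendations": ["fix titles", "add alt text"]}}}
first = resolve_json_path(sample, "data.seo_audit.recommendations[0]")
```

Two behaviors are worth knowing when calling it: an empty path returns the document unchanged, and chained indices such as `recommendations[0][1]` are not supported because each segment is parsed as a single integer index.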
||||
|
||||
def _resolve_path(self, path: str) -> Tuple[str, Optional[str]]:
|
||||
normalized = (path or "").strip()
|
||||
if not normalized:
|
||||
return "", None
|
||||
if normalized == "/env/summary":
|
||||
return "virtual_summary", None
|
||||
if normalized in self.VIRTUAL_MAP:
|
||||
return "file", self.VIRTUAL_MAP[normalized]
|
||||
if ".." in normalized or "\\" in normalized:
|
||||
return "", None
|
||||
if normalized.startswith("/"):
|
||||
candidate = normalized.rsplit("/", 1)[-1]
|
||||
else:
|
||||
candidate = normalized
|
||||
if "/" in candidate:
|
||||
return "", None
|
||||
allowed = AgentFlatContextStore.ALLOWED_CONTEXT_FILES - {AgentFlatContextStore.MANIFEST_FILENAME}
|
||||
if candidate not in allowed:
|
||||
return "", None
|
||||
return "file", candidate
|
||||
|
||||
def list_context(self) -> Dict[str, Any]:
|
||||
"""List available context files (ls-equivalent)."""
|
||||
docs = self._manifest_docs()
|
||||
items = []
|
||||
for d in docs:
|
||||
if not isinstance(d, dict):
|
||||
continue
|
||||
items.append(
|
||||
{
|
||||
"path": d.get("path"),
|
||||
"type": d.get("type"),
|
||||
"updated_at": d.get("updated_at"),
|
||||
"size_bytes": d.get("size_bytes", 0),
|
||||
}
|
||||
)
|
||||
items.sort(key=lambda x: str(x.get("path") or ""))
|
||||
result = {
|
||||
"workspace_hint": "Use this list to see which onboarding steps are complete.",
|
||||
"tip": "Use `search_context` to find specific keywords across all steps.",
|
||||
"virtual_paths": ["/env/summary", *sorted(self.VIRTUAL_MAP.keys())],
|
||||
"files": items,
|
||||
"collaboration": {
|
||||
"scratchpad_dir": str(self._scratchpad_dir()),
|
||||
"activity_log": "scratchpad/activity_log.jsonl",
|
||||
},
|
||||
}
|
||||
logger.info(f"[vfs_audit] user={self.store.safe_user_id} action=list_context files={len(items)}")
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _flatten_strings(data: Any, limit: int = 2000) -> str:
|
||||
pieces: List[str] = []
|
||||
|
||||
def walk(v: Any) -> None:
|
||||
if len(pieces) >= limit:
|
||||
return
|
||||
if isinstance(v, dict):
|
||||
for key, value in v.items():
|
||||
pieces.append(str(key))
|
||||
walk(value)
|
||||
elif isinstance(v, list):
|
||||
for item in v:
|
||||
walk(item)
|
||||
elif isinstance(v, (str, int, float, bool)):
|
||||
pieces.append(str(v))
|
||||
|
||||
walk(data)
|
||||
return " ".join(pieces)
|
||||
|
||||
@staticmethod
|
||||
def _extract_search_fields(doc: Dict[str, Any]) -> Tuple[List[str], Dict[str, Any], str]:
|
||||
summary = doc.get("agent_summary") if isinstance(doc.get("agent_summary"), dict) else {}
|
||||
hints = summary.get("retrieval_hints") if isinstance(summary.get("retrieval_hints"), dict) else {}
|
||||
quick_facts = summary.get("quick_facts") if isinstance(summary.get("quick_facts"), dict) else {}
|
||||
high_terms = hints.get("high_signal_terms") if isinstance(hints.get("high_signal_terms"), list) else []
|
||||
body = AgentContextVFS._flatten_strings(doc.get("data") if isinstance(doc.get("data"), dict) else {})
|
||||
return [str(t).lower() for t in high_terms], quick_facts, body.lower()
|
||||
|
||||
def search_context(self, query: str, *, limit: int = 10, path_glob: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Smart grep with coarse-to-fine ranking and parallel stream scans."""
|
||||
normalized = (query or "").strip()
|
||||
if not normalized:
|
||||
return {"query": query, "results": []}
|
||||
self.store._audit_event("vfs_search", normalized, "started")
|
||||
try:
|
||||
variants = self._query_variants(normalized)
|
||||
attempted_queries: List[str] = []
|
||||
scored: List[Dict[str, Any]] = []
|
||||
|
||||
for candidate_query in variants:
|
||||
attempted_queries.append(candidate_query)
|
||||
needle = candidate_query.lower()
|
||||
|
||||
# Pass 1: summary-first ranking (high relevance)
|
||||
docs = self._manifest_docs()
|
||||
variant_scored: List[Dict[str, Any]] = []
|
||||
for item in docs:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
path = str(item.get("path") or "")
|
||||
if not path:
|
||||
continue
|
||||
if path_glob and not fnmatch(path, path_glob):
|
||||
continue
|
||||
doc = self.store.load_context_document(path) or {}
|
||||
high_terms, quick_facts, _ = self._extract_search_fields(doc)
|
||||
|
||||
high_match = any(needle in term for term in high_terms)
|
||||
quick_match = any(needle in str(v).lower() for v in quick_facts.values()) if isinstance(quick_facts, dict) else False
|
||||
if not (high_match or quick_match):
|
||||
continue
|
||||
|
||||
score = 100 if high_match else 80
|
||||
reason = "matched high_signal_terms" if high_match else "matched quick_facts"
|
||||
variant_scored.append(
|
||||
{
|
||||
"path": path,
|
||||
"line": None,
|
||||
"snippet": f"{reason}: {candidate_query}"[:100],
|
||||
"type": item.get("type"),
|
||||
"updated_at": item.get("updated_at"),
|
||||
"relevance": "High Relevance",
|
||||
"reason": reason,
|
||||
"score": score,
|
||||
}
|
||||
)
|
||||
|
||||
# Pass 2: parallelized stream scan over allowlisted workspace files.
|
||||
allowlisted = self._allowlisted_workspace_files()
|
||||
body_matches: List[Dict[str, Any]] = []
|
||||
if allowlisted:
|
||||
with ThreadPoolExecutor(max_workers=min(8, max(1, len(allowlisted)))) as pool:
|
||||
future_map = {}
|
||||
for p in allowlisted:
|
||||
path_label = p.name
|
||||
if path_glob and not fnmatch(path_label, path_glob):
|
||||
continue
|
||||
future = pool.submit(self.grep_engine.stream_file, p, candidate_query, path_label=path_label)
|
||||
future_map[future] = path_label
|
||||
|
||||
for future in as_completed(future_map):
|
||||
try:
|
||||
body_matches.extend(future.result() or [])
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
variant_scored.extend(body_matches)
|
||||
if variant_scored:
|
||||
scored = variant_scored
|
||||
break
|
||||
|
||||
scored = self._cluster_results(scored)
|
||||
|
||||
# Add confidence based on score + freshness + hit density.
|
||||
for r in scored:
|
||||
base = min(1.0, max(0.0, float(r.get("score", 0)) / 100.0))
|
||||
freshness = self._freshness_score(r.get("updated_at"))
|
||||
density = min(1.0, 0.2 + (int(r.get("hit_count", 1)) * 0.1))
|
||||
confidence = round((base * 0.6) + (freshness * 0.25) + (density * 0.15), 3)
|
||||
r["confidence"] = confidence
|
||||
|
||||
scored.sort(key=lambda r: (-int(r.get("score", 0)), str(r.get("path") or "")))
|
||||
matched_files = sorted({str(r.get("path") or "") for r in scored if r.get("path")})
|
||||
capped_results = scored[: max(1, limit)]
|
||||
notice = None
|
||||
if len(matched_files) > 10:
|
||||
notice = f"Found {len(matched_files)} matching files. Showing the top 10 results. Use a more specific keyword to narrow down."
|
||||
capped_results = scored[:10]
|
||||
|
||||
# Token/length budgeting (~2000 tokens ~= ~8000 chars).
|
||||
budget_chars = 8000
|
||||
bounded_results = []
|
||||
used = 0
|
||||
for r in capped_results:
|
||||
snippet = str(r.get("snippet") or "")
|
||||
cost = len(snippet) + 120 # account for metadata fields
|
||||
if bounded_results and used + cost > budget_chars:
|
||||
break
|
||||
bounded_results.append(r)
|
||||
used += cost
|
||||
|
||||
result = {
|
||||
"query": normalized,
|
||||
"attempted_queries": attempted_queries,
|
||||
"matched_files_count": len(matched_files),
|
||||
"results": self._static_triage(bounded_results, normalized),
|
||||
"notice": notice,
|
||||
"char_budget_used": used,
|
||||
"can_answer": bool(bounded_results),
|
||||
}
|
||||
result["triage_top5"] = self._llm_router_stub(result["results"], top_k=5)
|
||||
logger.info(
|
||||
f"[vfs_audit] user={self.store.safe_user_id} action=search_context query={normalized!r} results={len(result['results'])}"
|
||||
)
|
||||
self.store._audit_event("vfs_search", normalized, f"success_{len(result['results'])}_hits")
|
||||
return result
|
||||
except Exception as exc:
|
||||
self.store._audit_event("vfs_search", normalized, f"failed_{exc.__class__.__name__}")
|
||||
return {"query": normalized, "matched_files_count": 0, "results": [], "notice": "Search failed.", "can_answer": False}
|
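Review note: the confidence value attached to each search result is a weighted blend of score (60%), freshness (25%), and hit density (15%). The sketch below isolates that arithmetic. `blend_confidence` is a hypothetical name, and since `_freshness_score` is not shown in this hunk, freshness is taken as an input assumed to lie in [0, 1].

```python
def blend_confidence(score: float, freshness: float, hit_count: int = 1) -> float:
    """Combine score, freshness, and hit density into a value in [0, 1]."""
    base = min(1.0, max(0.0, score / 100.0))    # score is clamped to the 0..100 range
    density = min(1.0, 0.2 + hit_count * 0.1)   # saturates at 8 or more hits
    return round(base * 0.6 + freshness * 0.25 + density * 0.15, 3)


# A top-scoring, fully fresh result with a single hit.
single_hit = blend_confidence(100, 1.0, hit_count=1)

# A quick_facts-level score, half-fresh, with many hits.
many_hits = blend_confidence(80, 0.5, hit_count=8)
```

Because the weights sum to 1.0, the blend stays in [0, 1] as long as freshness does. Score dominates, so a stale document that matches `high_signal_terms` still outranks a fresh body-only match in most cases.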
||||
|
||||
@staticmethod
|
||||
def _strip_technical_metadata(doc: Dict[str, Any]) -> Dict[str, Any]:
|
||||
sanitized = {
|
||||
"context_type": doc.get("context_type"),
|
||||
"updated_at": doc.get("updated_at"),
|
||||
"journey": ((doc.get("document_context") or {}).get("journey") or {}) if isinstance(doc.get("document_context"), dict) else {},
|
||||
"agent_summary": doc.get("agent_summary") if isinstance(doc.get("agent_summary"), dict) else {},
|
||||
"data": doc.get("data") if isinstance(doc.get("data"), dict) else {},
|
||||
}
|
||||
return sanitized
|
||||
|
||||
def inspect_file(self, path: str, *, key: Optional[str] = None, small_file_bytes: int = 5 * 1024) -> Dict[str, Any]:
|
||||
"""Smart reader (cat/head equivalent) with summary-first behavior."""
|
||||
kind, resolved = self._resolve_path(path)
|
||||
if kind == "virtual_summary":
|
||||
result = {
|
||||
"path": "/env/summary",
|
||||
"mode": "summary",
|
||||
"data": self.store.generate_total_summary(),
|
||||
}
|
||||
logger.info(f"[vfs_audit] user={self.store.safe_user_id} action=inspect_file path=/env/summary mode=summary")
|
||||
return result
|
||||
|
||||
if not resolved:
|
||||
logger.info(f"[vfs_audit] user={self.store.safe_user_id} action=inspect_file path={path!r} status=rejected")
|
||||
return {"error": "File not found", "path": path}
|
||||
|
||||
# JSON context doc path
|
||||
doc = self.store.load_context_document(resolved)
|
||||
if doc:
|
||||
view = self._strip_technical_metadata(doc)
|
||||
data = view.get("data") if isinstance(view.get("data"), dict) else {}
|
||||
raw_size = self.store.estimate_size_bytes(view)
|
||||
|
||||
if key:
|
||||
if key in data:
|
||||
result = {
|
||||
"path": resolved,
|
||||
"mode": "key",
|
||||
"key": key,
|
||||
"agent_summary": view.get("agent_summary"),
|
||||
"data": data.get(key),
|
||||
}
|
||||
logger.info(f"[vfs_audit] user={self.store.safe_user_id} action=inspect_file path={resolved} mode=key")
|
||||
return result
|
||||
logger.info(
|
||||
f"[vfs_audit] user={self.store.safe_user_id} action=inspect_file path={resolved} mode=key_missing key={key}"
|
||||
)
|
||||
return {
|
||||
"path": resolved,
|
||||
"mode": "key_missing",
|
||||
"key": key,
|
||||
"available_keys": sorted(list(data.keys())),
|
||||
"message": "Requested key not found. Choose one of available_keys.",
|
||||
}
|
||||
|
||||
if raw_size <= small_file_bytes:
|
||||
result = {
|
||||
"path": resolved,
|
||||
"mode": "full",
|
||||
"data": view,
|
||||
}
|
||||
logger.info(f"[vfs_audit] user={self.store.safe_user_id} action=inspect_file path={resolved} mode=full")
|
||||
return result
|
||||
|
||||
result = {
|
||||
"path": resolved,
|
||||
"mode": "summary_plus_keys",
|
||||
"size_bytes": raw_size,
|
||||
"agent_summary": view.get("agent_summary"),
|
||||
"keys": sorted(list(data.keys())),
|
||||
"message": "File is large. Re-run with key to inspect a specific section.",
|
||||
}
|
||||
logger.info(f"[vfs_audit] user={self.store.safe_user_id} action=inspect_file path={resolved} mode=summary_plus_keys")
|
||||
return result
|
||||
|
||||
logger.info(f"[vfs_audit] user={self.store.safe_user_id} action=inspect_file path={resolved} status=not_found")
|
||||
return {"error": "File not found", "path": path, "resolved": resolved}
|
||||
|
    def read_context_file(self, path: str, *, subkey: Optional[str] = None) -> Dict[str, Any]:
        """Backward-compatible alias for inspect_file."""
        return self.inspect_file(path, key=subkey)

    def write_context_file(self, *_args: Any, **_kwargs: Any) -> None:
        """Disallow writes from the agent-facing VFS."""
        raise OSError("EROFS: read-only file system")

    # Backward-compat function name requested in design docs.
    inspect = inspect_file
||||
|
||||
def write_shared_note(self, note: str, *, agent_id: str = "agent", filename: str = "collaboration.md") -> Dict[str, Any]:
|
||||
"""Append a shared project note with advisory locking in scratchpad."""
|
||||
safe_name = Path(filename).name
|
||||
if safe_name != filename or ".." in filename or "/" in filename or "\\" in filename:
|
||||
self.store._audit_event("write_shared_note", filename, "rejected_filename")
|
||||
return {"ok": False, "error": "Invalid filename"}
|
||||
|
||||
scratch = self._scratchpad_dir()
|
||||
target = (scratch / safe_name).resolve()
|
||||
if scratch.resolve() not in target.parents:
|
||||
self.store._audit_event("write_shared_note", filename, "rejected_path")
|
||||
return {"ok": False, "error": "Unsafe path"}
|
||||
|
||||
lock_path = scratch / f".{safe_name}.lock"
|
||||
ts = datetime.now(timezone.utc).isoformat()
|
||||
header = f"\n## {ts} | {self._safe_slug(agent_id, 'agent')}\n"
|
||||
payload = header + str(note).rstrip() + "\n"
|
||||
|
||||
try:
|
||||
with open(lock_path, "w", encoding="utf-8") as lf:
|
||||
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
|
||||
with open(target, "a", encoding="utf-8") as tf:
|
||||
tf.write(payload)
|
||||
tf.flush()
|
||||
os.fsync(tf.fileno())
|
||||
os.chmod(target, 0o600)
|
||||
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
|
||||
self.store._audit_event("write_shared_note", safe_name, "success")
|
||||
self.append_activity_log(
|
||||
event_type="shared_note_written",
|
||||
actor=agent_id,
|
||||
details={"file": safe_name, "bytes": len(payload.encode("utf-8"))},
|
||||
)
|
||||
return {"ok": True, "file": safe_name, "bytes_written": len(payload.encode("utf-8"))}
|
||||
except Exception as exc:
|
||||
self.store._audit_event("write_shared_note", safe_name, f"failed_{exc.__class__.__name__}")
|
||||
return {"ok": False, "error": str(exc)}
|
||||
|
||||
def append_activity_log(self, *, event_type: str, actor: str, details: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Append one entry to the project activity log (JSONL, append-only)."""
|
||||
scratch = self._scratchpad_dir()
|
||||
target = (scratch / "activity_log.jsonl").resolve()
|
||||
lock_path = scratch / ".activity_log.jsonl.lock"
|
||||
entry = {
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
"event_type": str(event_type),
|
||||
"actor": self._safe_slug(actor, "agent"),
|
||||
"project_id": self._safe_slug(self.project_id, "none") if self.project_id else None,
|
||||
"details": details or {},
|
||||
}
|
||||
line = json.dumps(entry, ensure_ascii=False) + "\n"
|
||||
try:
|
||||
with open(lock_path, "w", encoding="utf-8") as lf:
|
||||
fcntl.flock(lf.fileno(), fcntl.LOCK_EX)
|
||||
with open(target, "a", encoding="utf-8") as tf:
|
||||
tf.write(line)
|
||||
tf.flush()
|
||||
os.fsync(tf.fileno())
|
||||
os.chmod(target, 0o600)
|
||||
fcntl.flock(lf.fileno(), fcntl.LOCK_UN)
|
||||
return {"ok": True}
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to append activity log: {exc}")
|
||||
return {"ok": False, "error": str(exc)}
|
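Review note: both scratchpad writers above use the same pattern, an exclusive `fcntl.flock` on a sidecar lock file around an append, flush, and fsync. A minimal sketch of that pattern follows, assuming a POSIX platform (`fcntl` is not available on Windows). `locked_append` and the sample file name are hypothetical.

```python
import fcntl
import os
import tempfile
from pathlib import Path


def locked_append(target: Path, text: str) -> int:
    """Append text to target under an advisory lock; return the bytes written."""
    lock_path = target.parent / f".{target.name}.lock"
    with open(lock_path, "w", encoding="utf-8") as lock_file:
        # Blocks until no other cooperating process holds the lock.
        fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
        try:
            with open(target, "a", encoding="utf-8") as out:
                out.write(text)
                out.flush()
                os.fsync(out.fileno())  # push the append to disk before unlocking
            os.chmod(target, 0o600)     # owner read/write only
        finally:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
    return len(text.encode("utf-8"))


with tempfile.TemporaryDirectory() as tmp:
    log = Path(tmp) / "activity_log.jsonl"
    written = locked_append(log, '{"event": "a"}\n')
    locked_append(log, '{"event": "b"}\n')
    lines = log.read_text(encoding="utf-8").splitlines()
    mode = log.stat().st_mode & 0o777
```

The lock is advisory: it only serializes writers that take the same lock file, so a process that opens the target directly bypasses it. Releasing in a `finally` block keeps the unlock explicit even when the write raises, although closing the lock file would release it as well.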
||||
|
||||
def read_struct(self, filename: str, path_query: str) -> Dict[str, Any]:
|
||||
"""AST-style structural reader for JSON context files with dependency context injection."""
|
||||
resolved_kind, resolved = self._resolve_path(filename)
|
||||
if resolved_kind == "virtual_summary" or not resolved:
|
||||
return {"ok": False, "error": "Invalid file"}
|
||||
|
||||
doc = self.store.load_context_document(resolved)
|
||||
if not isinstance(doc, dict):
|
||||
return {"ok": False, "error": "File not found"}
|
||||
|
||||
try:
|
||||
extracted = self._resolve_json_path(doc, path_query)
|
||||
except Exception as exc:
|
||||
return {"ok": False, "error": f"path_query resolution failed: {exc}"}
|
||||
|
||||
# Lightweight dependency context: inject brand voice from step2 when reading persona structures.
|
||||
dependency_context: Dict[str, Any] = {}
|
||||
if "persona" in path_query.lower() or resolved == AgentFlatContextStore.STEP4_FILENAME:
|
||||
step2 = self.store.load_step2_context_document() or {}
|
||||
step2_data = step2.get("data") if isinstance(step2.get("data"), dict) else {}
|
||||
brand = step2_data.get("brand_analysis") if isinstance(step2_data.get("brand_analysis"), dict) else {}
|
||||
dependency_context["brand_voice"] = brand.get("brand_voice")
|
||||
|
||||
return {
|
||||
"ok": True,
|
||||
"file": resolved,
|
||||
"path_query": path_query,
|
||||
"data": extracted,
|
||||
"dependency_context": dependency_context,
|
||||
"context": "Extracted via structural parse to save tokens.",
|
||||
}
|
||||
|
||||
|
||||
|
||||
def build_filesystem_header(user_id: str) -> str:
|
||||
"""Generate compact prompt header with available files and priority hints."""
|
||||
try:
|
||||
store = AgentFlatContextStore(user_id)
|
||||
manifest = store.load_context_manifest() or {"documents": []}
|
||||
docs = manifest.get("documents") if isinstance(manifest.get("documents"), list) else []
|
||||
available = [str(d.get("path")) for d in docs if isinstance(d, dict) and d.get("path")]
|
||||
files = ", ".join(sorted(available)) if available else "none"
|
||||
return (
|
||||
"Workspace Context: You have access to a local flat-file store. "
|
||||
f"Available Files: {files}. "
|
||||
"Instructions: For style guidelines, prioritize step4_persona_data.json. "
|
||||
"For technical site data, prioritize step2_website_analysis.json."
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to build filesystem header for user {user_id}: {exc}")
|
||||
return "Workspace Context: local flat-file store unavailable."
|
||||
@@ -9,6 +9,8 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import hmac
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
@@ -25,6 +27,14 @@ class AgentFlatContextStore:
|
||||
STEP4_FILENAME = "step4_persona_data.json"
|
||||
STEP5_FILENAME = "step5_integrations.json"
|
||||
MANIFEST_FILENAME = "context_manifest.json"
|
||||
WORKSPACE_README = "README.md"
|
||||
ALLOWED_CONTEXT_FILES = {
|
||||
STEP2_FILENAME,
|
||||
STEP3_FILENAME,
|
||||
STEP4_FILENAME,
|
||||
STEP5_FILENAME,
|
||||
MANIFEST_FILENAME,
|
||||
}
|
||||
|
||||
SCHEMA_VERSION = "1.3"
|
||||
DEFAULT_MAX_BYTES = 300_000
|
||||
@@ -33,12 +43,53 @@ class AgentFlatContextStore:
|
||||
def __init__(self, user_id: str):
|
||||
self.user_id = user_id
|
||||
self.safe_user_id = self._sanitize_user_id(user_id)
|
||||
self._ensure_workspace_permissions()
|
||||
|
||||
def _ensure_workspace_permissions(self) -> None:
|
||||
"""Ensure workspace and context directories exist with owner-only permissions."""
|
||||
workspace_dir = self._workspace_dir()
|
||||
context_dir = workspace_dir / self.CONTEXT_DIRNAME
|
||||
workspace_dir.mkdir(parents=True, exist_ok=True)
|
||||
context_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(workspace_dir, 0o700)
|
||||
os.chmod(context_dir, 0o700)
|
||||
|
||||
@staticmethod
|
||||
def _safe_resolve_under(base_dir: Path, requested_path: str) -> Path:
|
||||
"""Resolve path and ensure it remains inside base_dir (path sandboxing)."""
|
||||
base_real = base_dir.resolve()
|
||||
candidate = (base_dir / requested_path).resolve()
|
||||
if candidate == base_real or base_real in candidate.parents:
|
||||
return candidate
|
||||
raise ValueError("Unsafe path access attempt outside sandbox")
|
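Review note: a standalone sketch of the sandbox check above, run against a temporary directory. The function body mirrors `_safe_resolve_under`; the sample paths are illustrative assumptions.

```python
import tempfile
from pathlib import Path


def safe_resolve_under(base_dir: Path, requested_path: str) -> Path:
    """Resolve requested_path and reject anything that lands outside base_dir."""
    base_real = base_dir.resolve()
    candidate = (base_dir / requested_path).resolve()
    if candidate == base_real or base_real in candidate.parents:
        return candidate
    raise ValueError("Unsafe path access attempt outside sandbox")


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp)
    inside = safe_resolve_under(base, "context/step2_website_analysis.json")

    # Parent traversal resolves to a location above the base and is rejected.
    try:
        safe_resolve_under(base, "../outside.txt")
        traversal_allowed = True
    except ValueError:
        traversal_allowed = False

    # An absolute path replaces the base when joined, so it is rejected too.
    try:
        safe_resolve_under(base, "/etc/passwd")
        absolute_allowed = True
    except ValueError:
        absolute_allowed = False
```

Resolving both sides before comparing is what makes the check hold for symlinks and `..` segments: the comparison happens on the final filesystem location rather than on the string the caller supplied.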
||||
|
||||
@staticmethod
|
||||
def _sanitize_user_id(user_id: str) -> str:
|
||||
safe = "".join(c for c in str(user_id) if c.isalnum() or c in ("-", "_"))
|
||||
return safe or "unknown_user"
|
||||
|
||||
def _master_salt(self) -> str:
|
||||
return os.getenv("FILE_ENCRYPTION_SALT", "")
|
||||
|
||||
def derive_user_secret(self) -> bytes:
|
||||
"""Derive deterministic per-user secret from env salt + safe user id."""
|
||||
salt = self._master_salt()
|
||||
if not salt:
|
||||
return b""
|
||||
return hmac.new(salt.encode("utf-8"), self.safe_user_id.encode("utf-8"), hashlib.sha256).digest()
|
||||
|
||||
def user_secret_fingerprint(self) -> str:
|
||||
"""Short fingerprint used for diagnostics/audit only (not a key)."""
|
||||
secret = self.derive_user_secret()
|
||||
if not secret:
|
||||
return "salt_not_configured"
|
||||
return hashlib.sha256(secret).hexdigest()[:16]
|
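Review note: the per-user secret above is an HMAC-SHA256 of the sanitized user id keyed by the environment salt, and the fingerprint is a truncated SHA-256 of that secret. A minimal sketch with the salt passed in explicitly follows; the salt and user ids are made-up sample values.

```python
import hashlib
import hmac


def derive_user_secret(salt: str, safe_user_id: str) -> bytes:
    """Deterministic per-user secret; empty when no salt is configured."""
    if not salt:
        return b""
    return hmac.new(salt.encode("utf-8"), safe_user_id.encode("utf-8"), hashlib.sha256).digest()


def fingerprint(secret: bytes) -> str:
    """Short diagnostic identifier for a secret; not usable as a key."""
    if not secret:
        return "salt_not_configured"
    return hashlib.sha256(secret).hexdigest()[:16]


secret_a = derive_user_secret("example-salt", "user_1")
secret_b = derive_user_secret("example-salt", "user_2")
fp_a = fingerprint(secret_a)
```

The same salt and user id always give the same 32-byte secret, while different user ids give unrelated secrets. Hashing the secret again before truncating means the value written to the manifest does not expose any bytes of the secret itself.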
||||
|
||||
def _audit_event(self, action: str, target: str, status: str) -> None:
|
||||
logger.info(
|
||||
f"[flat_context_audit] user={self.safe_user_id} action={action} target={target} status={status}"
|
||||
)
|
||||
|
||||
def _workspace_dir(self) -> Path:
|
||||
root_dir = Path(__file__).resolve().parents[3]
|
||||
return root_dir / "workspace" / f"workspace_{self.safe_user_id}"
|
||||
@@ -47,7 +98,10 @@ class AgentFlatContextStore:
|
||||
return self._workspace_dir() / self.CONTEXT_DIRNAME
|
||||
|
||||
def _context_file(self, filename: str) -> Path:
|
||||
return self._context_dir() / filename
|
||||
return self._safe_resolve_under(self._context_dir(), str(filename))
|
||||
|
||||
def _workspace_file(self, filename: str) -> Path:
|
||||
return self._safe_resolve_under(self._workspace_dir(), str(filename))
|
||||
|
||||
@staticmethod
|
||||
def _estimate_size_bytes(value: Any) -> int:
|
||||
@@ -56,6 +110,10 @@ class AgentFlatContextStore:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def estimate_size_bytes(self, value: Any) -> int:
|
||||
"""Public size estimate helper for adapter layers."""
|
||||
return self._estimate_size_bytes(value)
|
||||
|
||||
@staticmethod
|
||||
def _to_context_list(value: Any) -> Any:
|
||||
if value is None:
|
||||
@@ -143,6 +201,12 @@ class AgentFlatContextStore:
|
||||
"preferred": "flat_file",
|
||||
"fallback_order": fallback_order,
|
||||
},
|
||||
"security": {
|
||||
"path_sandboxing": True,
|
||||
"file_permissions": "0600",
|
||||
"directory_permissions": "0700",
|
||||
"user_secret_fingerprint": self.user_secret_fingerprint(),
|
||||
},
|
||||
"context_window_guidance": {
|
||||
"max_raw_bytes": self.DEFAULT_MAX_BYTES,
|
||||
"total_bytes": total_size,
|
||||
@@ -343,6 +407,7 @@ class AgentFlatContextStore:
|
||||
|
||||
def _atomic_write_json(self, target_file: Path, data: Dict[str, Any]) -> None:
|
||||
target_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(target_file.parent, 0o700)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(target_file.parent), prefix=f".{target_file.name}.", suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
@@ -361,6 +426,108 @@ class AgentFlatContextStore:
|
||||
pass
|
||||
raise
|
||||
|
||||
def _atomic_write_text(self, target_file: Path, content: str) -> None:
|
||||
target_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
os.chmod(target_file.parent, 0o700)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=str(target_file.parent), prefix=f".{target_file.name}.", suffix=".tmp")
|
||||
try:
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as f:
|
||||
f.write(content)
|
||||
f.flush()
|
||||
os.fsync(f.fileno())
|
||||
os.replace(tmp_path, target_file)
|
||||
try:
|
||||
os.chmod(target_file, 0o600)
|
||||
except Exception:
|
||||
pass
|
||||
except Exception:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
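Review note: `_atomic_write_text` follows the usual write-to-temp-then-rename pattern. A standalone sketch follows, with the temp file created in the target's own directory so that `os.replace` stays on one filesystem. `atomic_write_text` and the sample paths are illustrative names.

```python
import os
import tempfile
from pathlib import Path


def atomic_write_text(target: Path, content: str) -> None:
    """Write content so readers see either the old file or the new one, never a partial file."""
    target.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_path = tempfile.mkstemp(
        dir=str(target.parent), prefix=f".{target.name}.", suffix=".tmp"
    )
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            f.write(content)
            f.flush()
            os.fsync(f.fileno())      # data reaches disk before the rename
        os.replace(tmp_path, target)  # atomic when source and target share a filesystem
    except Exception:
        try:
            os.unlink(tmp_path)       # do not leave a stray temp file behind
        except OSError:
            pass
        raise
    os.chmod(target, 0o600)


with tempfile.TemporaryDirectory() as tmp:
    readme = Path(tmp) / "workspace" / "README.md"
    atomic_write_text(readme, "first\n")
    atomic_write_text(readme, "second\n")
    final_text = readme.read_text(encoding="utf-8")
    leftovers = [p.name for p in readme.parent.iterdir() if p.name.endswith(".tmp")]
```

`tempfile.mkstemp` already creates the file readable and writable only by the owner, so the final `chmod` mainly documents the intended mode.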
||||
|
||||
@staticmethod
|
||||
def _collect_signal_terms(doc: Dict[str, Any], limit: int = 6) -> list:
|
||||
summary = doc.get("agent_summary") if isinstance(doc, dict) else {}
|
||||
hints = summary.get("retrieval_hints") if isinstance(summary, dict) else {}
|
||||
terms = hints.get("high_signal_terms") if isinstance(hints, dict) else []
|
||||
if not isinstance(terms, list):
|
||||
return []
|
||||
normalized = [str(t).strip() for t in terms if str(t).strip()]
|
||||
return normalized[:limit]
|
||||
|
||||
@staticmethod
|
||||
def _extract_journey_stage(doc: Dict[str, Any]) -> str:
|
||||
dctx = doc.get("document_context") if isinstance(doc, dict) else {}
|
||||
journey = dctx.get("journey") if isinstance(dctx, dict) else {}
|
||||
stage = journey.get("stage") if isinstance(journey, dict) else ""
|
||||
return str(stage or "").strip()
|
||||
|
||||
@staticmethod
|
||||
def _context_description(filename: str) -> str:
|
||||
descriptions = {
|
||||
AgentFlatContextStore.STEP2_FILENAME: "Primary SEO and site structure context",
|
||||
AgentFlatContextStore.STEP3_FILENAME: "Research depth, competitors, and content preferences",
|
||||
AgentFlatContextStore.STEP4_FILENAME: "Persona profiles, voice adaptation, and platform strategy",
|
||||
AgentFlatContextStore.STEP5_FILENAME: "Connected integrations and provider readiness",
|
||||
}
|
||||
return descriptions.get(filename, "Context document")
|
||||
|
||||
def _generate_workspace_readme(self, manifest: Dict[str, Any]) -> str:
|
||||
docs = manifest.get("documents") if isinstance(manifest, dict) and isinstance(manifest.get("documents"), list) else []
|
||||
|
||||
lines = [
|
||||
"# Agent Workspace Map",
|
||||
"",
|
||||
"You are in a restricted read-only VFS. Use `list_context`, `read_context_file`, and `search_context` to navigate.",
|
||||
"",
|
||||
"## Core Context Files",
|
||||
]
|
||||
|
||||
for item in sorted(docs, key=lambda d: str(d.get("path") or "") if isinstance(d, dict) else ""):
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
path = item.get("path") or ""
|
||||
if not path:
|
||||
continue
|
||||
doc = self._load_context_document(path) or {}
|
||||
signals = self._collect_signal_terms(doc)
|
||||
journey_stage = self._extract_journey_stage(doc)
|
||||
updated_at = str(item.get("updated_at") or "")
|
||||
lines.append(f"- `{path}`: {self._context_description(path)}.")
|
||||
if signals:
|
||||
lines.append(f" - **Key Signals:** {', '.join(signals)}")
|
||||
if journey_stage:
|
||||
lines.append(f" - **Journey Stage:** {journey_stage}")
|
||||
if updated_at:
|
||||
lines.append(f" - **Updated:** {updated_at}")
|
||||
|
||||
lines.extend(
|
||||
[
|
||||
"",
|
||||
"## Retrieval Strategy",
|
||||
"1. Run `list_context` to check which onboarding steps are available.",
|
||||
"2. Run `search_context` for targeted terms (for example: \"competitor\", \"tone\", \"integrations\").",
|
||||
"3. Run `read_context_file` and ingest `agent_summary` before expanding full `data`.",
|
||||
"",
|
||||
"## Virtual Paths",
|
||||
"- `/env/summary` -> consolidated summary generated from all available context docs",
|
||||
f"- `/steps/website` -> `{self.STEP2_FILENAME}`",
|
||||
f"- `/steps/research` -> `{self.STEP3_FILENAME}`",
|
||||
f"- `/steps/persona` -> `{self.STEP4_FILENAME}`",
|
||||
f"- `/steps/integrations` -> `{self.STEP5_FILENAME}`",
|
||||
]
|
||||
)
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
def _update_workspace_readme(self, manifest: Dict[str, Any]) -> None:
|
||||
try:
|
||||
content = self._generate_workspace_readme(manifest)
|
||||
self._atomic_write_text(self._workspace_file(self.WORKSPACE_README), content)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to update workspace README for user {self.user_id}: {exc}")
|
||||
|
||||
def _update_manifest(self, context_type: str, filename: str, doc: Dict[str, Any]) -> None:
|
||||
manifest_file = self._context_file(self.MANIFEST_FILENAME)
|
||||
existing = {}
|
||||
@@ -390,6 +557,7 @@ class AgentFlatContextStore:
|
||||
"documents": items,
|
||||
}
|
||||
self._atomic_write_json(manifest_file, manifest)
|
||||
self._update_workspace_readme(manifest)
|
||||
|
||||
def _save_context_document(
|
||||
self,
|
||||
@@ -436,9 +604,11 @@ class AgentFlatContextStore:
|
||||
|
||||
self._atomic_write_json(target_file, context_doc)
|
||||
self._update_manifest(context_type, filename, context_doc)
|
||||
self._audit_event("write_context", filename, "success")
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error(f"Failed to save context for user {self.user_id} ({context_type}): {exc}")
|
||||
self._audit_event("write_context", filename, "error")
|
||||
return False
|
||||
|
||||
def save_step2_website_analysis(self, payload: Dict[str, Any], *, source: str = "onboarding_step2") -> bool:
|
||||
@@ -483,19 +653,31 @@ class AgentFlatContextStore:
|
||||
|
||||
def _load_context_document(self, filename: str) -> Optional[Dict[str, Any]]:
|
||||
try:
|
||||
if str(filename) not in self.ALLOWED_CONTEXT_FILES:
|
||||
logger.warning(f"Rejected non-allowed context filename for user {self.user_id}: {filename}")
|
||||
self._audit_event("read_context", str(filename), "rejected_filename")
|
||||
return None
|
||||
target_file = self._context_file(filename)
|
||||
if not target_file.exists():
|
||||
self._audit_event("read_context", str(filename), "not_found")
|
||||
return None
|
||||
with open(target_file, "r", encoding="utf-8") as f:
|
||||
doc = json.load(f)
|
||||
if isinstance(doc, dict) and str(doc.get("user_id")) != str(self.user_id):
|
||||
logger.warning(f"Context user mismatch for {filename} (expected {self.user_id})")
|
||||
self._audit_event("read_context", str(filename), "user_mismatch")
|
||||
return None
|
||||
self._audit_event("read_context", str(filename), "success")
|
||||
return doc if isinstance(doc, dict) else None
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to load context document for user {self.user_id} ({filename}): {exc}")
|
||||
self._audit_event("read_context", str(filename), "error")
|
||||
return None
|
||||
|
||||
def load_context_document(self, filename: str) -> Optional[Dict[str, Any]]:
|
||||
"""Public loader for a named context document file."""
|
||||
return self._load_context_document(filename)
|
||||
|
||||
def load_context_manifest(self) -> Optional[Dict[str, Any]]:
|
||||
return self._load_context_document(self.MANIFEST_FILENAME)
|
||||
|
||||
@@ -526,3 +708,35 @@ class AgentFlatContextStore:
|
||||
def load_step5_integrations(self) -> Optional[Dict[str, Any]]:
|
||||
doc = self.load_step5_context_document()
|
||||
return doc.get("data") if isinstance(doc, dict) and isinstance(doc.get("data"), dict) else None
|
||||
|
||||
def generate_total_summary(self) -> Dict[str, Any]:
|
||||
"""Build a lightweight consolidated summary across available context documents."""
|
||||
manifest = self.load_context_manifest() or {"documents": []}
|
||||
docs = manifest.get("documents") if isinstance(manifest.get("documents"), list) else []
|
||||
overview = []
|
||||
for item in docs:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
path = str(item.get("path") or "")
|
||||
if not path:
|
||||
continue
|
||||
doc = self._load_context_document(path) or {}
|
||||
summary = doc.get("agent_summary") if isinstance(doc.get("agent_summary"), dict) else {}
|
||||
quick_facts = summary.get("quick_facts") if isinstance(summary.get("quick_facts"), dict) else {}
|
||||
hints = summary.get("retrieval_hints") if isinstance(summary.get("retrieval_hints"), dict) else {}
|
||||
overview.append(
|
||||
{
|
||||
"path": path,
|
||||
"context_type": doc.get("context_type"),
|
||||
"updated_at": doc.get("updated_at") or item.get("updated_at"),
|
||||
"journey_stage": self._extract_journey_stage(doc),
|
||||
"high_signal_terms": hints.get("high_signal_terms") if isinstance(hints.get("high_signal_terms"), list) else [],
|
||||
"quick_facts": quick_facts,
|
||||
}
|
||||
)
|
||||
return {
|
||||
"user_id": str(self.user_id),
|
||||
"generated_at": datetime.utcnow().isoformat(),
|
||||
"document_count": len(overview),
|
||||
"documents": overview,
|
||||
}
|
||||
|
||||
@@ -99,6 +99,17 @@ class OptimizationRecommendation:
|
||||
expires = datetime.utcnow().timestamp() + (7 * 24 * 60 * 60)
|
||||
self.expires_at = datetime.fromtimestamp(expires).isoformat()
|
||||
|
||||
|
||||
@dataclass
|
||||
class TierPolicyConfig:
|
||||
"""Structured policy for anomaly tiers and remediation controls."""
|
||||
tier: int
|
||||
trigger_metrics: List[str]
|
||||
thresholds: Dict[str, float]
|
||||
max_iterations: int
|
||||
lock_criteria: Dict[str, Any]
|
||||
|
||||
|
||||
class AgentPerformanceMonitor:
|
||||
"""Main performance monitoring system for agents"""
|
||||
|
||||
@@ -108,6 +119,32 @@ class AgentPerformanceMonitor:
|
||||
self.agent_snapshots: Dict[str, AgentPerformanceSnapshot] = {}
|
||||
self.recommendations: List[OptimizationRecommendation] = []
|
||||
self.performance_history: deque = deque(maxlen=1000) # Keep last 1000 data points
|
||||
self.systemic_alerts: List[Dict[str, Any]] = []
|
||||
|
||||
        # Structured tier policy config
        self.tier_policy_config: Dict[int, TierPolicyConfig] = {
            1: TierPolicyConfig(
                tier=1,
                trigger_metrics=["success_rate", "efficiency_score", "response_time"],
                thresholds={"success_rate": 0.80, "efficiency_score": 0.65, "response_time": 45.0},
                max_iterations=3,
                lock_criteria={"min_confidence": 0.85, "consecutive_failures": 6}
            ),
            2: TierPolicyConfig(
                tier=2,
                trigger_metrics=["success_rate", "efficiency_score", "response_time", "market_impact"],
                thresholds={"success_rate": 0.70, "efficiency_score": 0.50, "response_time": 60.0, "market_impact": 0.35},
                max_iterations=2,
                lock_criteria={"min_confidence": 0.75, "consecutive_failures": 4}
            ),
            3: TierPolicyConfig(
                tier=3,
                trigger_metrics=["success_rate", "efficiency_score", "response_time", "market_impact"],
                thresholds={"success_rate": 0.55, "efficiency_score": 0.35, "response_time": 90.0, "market_impact": 0.25},
                max_iterations=1,
                lock_criteria={"min_confidence": 0.65, "consecutive_failures": 3}
            )
        }
|
||||
|
||||
# Performance thresholds and targets
|
||||
self.performance_targets = {
|
||||
@@ -513,6 +550,54 @@ class AgentPerformanceMonitor:
|
||||
}
|
||||
return priority_weights.get(priority, 0)
|
||||
|
||||
    def _build_recommended_action_payload(self, agent_id: str, snapshot: AgentPerformanceSnapshot) -> Dict[str, Any]:
        """Build recommended action payload including tier and confidence."""
        tier = 1
        if (snapshot.success_rate <= self.tier_policy_config[3].thresholds["success_rate"] or
                snapshot.efficiency_score <= self.tier_policy_config[3].thresholds["efficiency_score"] or
                snapshot.average_response_time >= self.tier_policy_config[3].thresholds["response_time"] or
                snapshot.market_impact_score <= self.tier_policy_config[3].thresholds["market_impact"]):
            tier = 3
        elif (snapshot.success_rate <= self.tier_policy_config[2].thresholds["success_rate"] or
                snapshot.efficiency_score <= self.tier_policy_config[2].thresholds["efficiency_score"] or
                snapshot.average_response_time >= self.tier_policy_config[2].thresholds["response_time"] or
                snapshot.market_impact_score <= self.tier_policy_config[2].thresholds["market_impact"]):
            tier = 2

        confidence = round(max(0.0, min(1.0, 1.0 - abs(0.75 - self._calculate_health_score(snapshot)))), 2)
        policy = self.tier_policy_config[tier]

        return {
            "agent_id": agent_id,
            "tier": tier,
            "confidence": confidence,
            "max_iterations": policy.max_iterations,
            "lock_criteria": policy.lock_criteria,
            "trigger_metrics": policy.trigger_metrics
        }
|
||||
|
||||
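The tier selection in `_build_recommended_action_payload` checks the tier 3 thresholds first, then tier 2, and otherwise falls back to tier 1 (the tier 1 thresholds are not consulted there). A standalone sketch of that ordering follows. `Snapshot`, `breaches`, and `classify_tier` are hypothetical names used only for illustration; the threshold values are copied from the tier 2 and tier 3 policies in the diff.

```python
from dataclasses import dataclass
from typing import Dict


@dataclass(frozen=True)
class Snapshot:
    """Hypothetical stand-in for AgentPerformanceSnapshot with only the fields used here."""
    success_rate: float
    efficiency_score: float
    average_response_time: float
    market_impact_score: float


# Thresholds copied from the tier 3 and tier 2 policies.
THRESHOLDS: Dict[int, Dict[str, float]] = {
    3: {"success_rate": 0.55, "efficiency_score": 0.35, "response_time": 90.0, "market_impact": 0.25},
    2: {"success_rate": 0.70, "efficiency_score": 0.50, "response_time": 60.0, "market_impact": 0.35},
}


def breaches(s: Snapshot, t: Dict[str, float]) -> bool:
    """True when any single metric crosses its threshold (response time is 'higher is worse')."""
    return (
        s.success_rate <= t["success_rate"]
        or s.efficiency_score <= t["efficiency_score"]
        or s.average_response_time >= t["response_time"]
        or s.market_impact_score <= t["market_impact"]
    )


def classify_tier(s: Snapshot) -> int:
    """Check the most severe tier first; tier 1 is the fallback."""
    for tier in (3, 2):
        if breaches(s, THRESHOLDS[tier]):
            return tier
    return 1
```

Because the conditions are joined with `or`, one degraded metric is enough to escalate an agent to a higher tier.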
    def _route_tier3_systemic_alert(self, action_payload: Dict[str, Any], alerts: List[Dict[str, Any]]) -> None:
        """Route Tier 3 systemic anomalies to alerting subsystem with diagnostic brief."""
        diagnostic_brief = {
            "type": "systemic_anomaly",
            "severity": "critical",
            "tier": 3,
            "confidence": action_payload.get("confidence", 0.0),
            "agent_id": action_payload.get("agent_id"),
            "timestamp": datetime.utcnow().isoformat(),
            "diagnostic_brief": {
                "trigger_metrics": action_payload.get("trigger_metrics", []),
                "alerts": alerts,
                "max_iterations": action_payload.get("max_iterations"),
                "lock_criteria": action_payload.get("lock_criteria", {})
            }
        }
        self.systemic_alerts.append(diagnostic_brief)
        if len(self.systemic_alerts) > 200:
            self.systemic_alerts = self.systemic_alerts[-200:]
        logger.critical(f"[ALERTING_SUBSYSTEM] Tier 3 systemic anomaly routed: {json.dumps(diagnostic_brief)}")
|
||||
|
||||
|
||||
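`_route_tier3_systemic_alert` caps `systemic_alerts` at 200 entries by re-slicing the list after each append. The same class already uses `deque(maxlen=1000)` for `performance_history`, so a bounded deque is one possible alternative here. The sketch below only demonstrates the eviction behavior and is not a change made in this diff; note that a deque does not support slicing, so callers that slice `systemic_alerts` would need adjusting.

```python
from collections import deque

# A deque with maxlen silently drops the oldest entry once it is full.
systemic_alerts = deque(maxlen=200)

for seq in range(250):
    systemic_alerts.append({"type": "systemic_anomaly", "seq": seq})

oldest = systemic_alerts[0]["seq"]
newest = systemic_alerts[-1]["seq"]
```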
async def get_performance_alerts(self, agent_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get performance alerts for an agent"""
|
||||
alerts = []
|
||||
@@ -574,6 +659,13 @@ class AgentPerformanceMonitor:
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
})
|
||||
|
||||
action_payload = self._build_recommended_action_payload(agent_id, snapshot)
|
||||
if action_payload["tier"] == 3:
|
||||
self._route_tier3_systemic_alert(action_payload, alerts)
|
||||
|
||||
for alert in alerts:
|
||||
alert["recommended_action"] = action_payload
|
||||
|
||||
return alerts
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -84,6 +84,17 @@ class SafetyValidation:
|
||||
if self.validation_timestamp is None:
|
||||
self.validation_timestamp = datetime.utcnow().isoformat()
|
||||
|
||||
|
||||
@dataclass
class SafetyArbitrationDecision:
    """Explicit allow/deny/lock decision with reasons."""
    decision: str
    reasons: List[str]
    tier: int
    confidence: float
    lock_state_active: bool
|
||||
|
||||
|
||||
class SafetyConstraintManager:
|
||||
"""Manages safety constraints for agent actions"""
|
||||
|
||||
@@ -92,6 +103,8 @@ class SafetyConstraintManager:
|
||||
self.constraints: Dict[str, SafetyConstraint] = {}
|
||||
self.action_history: List[Dict[str, Any]] = []
|
||||
self.violation_history: List[Dict[str, Any]] = []
|
||||
self.lock_state_active: bool = False
|
||||
self.lock_state_reason: Optional[str] = None
|
||||
|
||||
# Initialize default constraints
|
||||
self._initialize_default_constraints()
|
||||
@@ -163,6 +176,17 @@ class SafetyConstraintManager:
|
||||
"""Validate an action against safety constraints"""
|
||||
try:
|
||||
logger.info(f"Validating action for user {self.user_id}: {action_data.get('action_type', 'unknown')}")
|
||||
|
||||
if self.lock_state_active and action_data.get("autonomous_modification", True):
|
||||
reason = self.lock_state_reason or "Safety lock is active due to Tier 3 systemic anomaly"
|
||||
return SafetyValidation(
|
||||
is_valid=False,
|
||||
risk_level=RiskLevel.CRITICAL,
|
||||
violations=["Autonomous modifications blocked while lock state is active"],
|
||||
recommendations=[reason],
|
||||
requires_approval=True,
|
||||
confidence_score=1.0
|
||||
)
|
||||
|
||||
violations = []
|
||||
recommendations = []
|
||||
@@ -207,19 +231,29 @@ class SafetyConstraintManager:
|
||||
|
||||
# Final validation
|
||||
is_valid = len(violations) == 0 and not requires_approval
|
||||
|
||||
-logger.info(f"Action validation completed for user {self.user_id}. Valid: {is_valid}, Risk: {risk_level.value}, Violations: {len(violations)}")
|
||||
|
||||
confidence_score = max(0.0, min(1.0, confidence_score))
|
||||
arbitration = self._arbitrate_decision(action_data, risk_level, violations, requires_approval, confidence_score)
|
||||
|
||||
if arbitration.decision == "lock":
|
||||
self.lock_state_active = True
|
||||
self.lock_state_reason = "; ".join(arbitration.reasons)
|
||||
is_valid = False
|
||||
requires_approval = True
|
||||
|
||||
recommendations.extend([f"Arbitration decision: {arbitration.decision}", *arbitration.reasons])
|
||||
|
||||
logger.info(f"Action validation completed for user {self.user_id}. Decision: {arbitration.decision}, Valid: {is_valid}, Risk: {risk_level.value}, Violations: {len(violations)}")
|
||||
|
||||
# Record in history
|
||||
await self._record_validation_history(action_data, is_valid, violations)
|
||||
|
||||
|
||||
return SafetyValidation(
|
||||
is_valid=is_valid,
|
||||
risk_level=risk_level,
|
||||
violations=violations,
|
||||
recommendations=recommendations,
|
||||
requires_approval=requires_approval,
|
||||
-confidence_score=max(0.0, min(1.0, confidence_score))
+confidence_score=confidence_score
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -235,6 +269,30 @@ class SafetyConstraintManager:
|
||||
confidence_score=0.0
|
||||
)
|
||||
|
||||
    def _arbitrate_decision(
        self,
        action_data: Dict[str, Any],
        risk_level: RiskLevel,
        violations: List[str],
        requires_approval: bool,
        confidence_score: float,
    ) -> SafetyArbitrationDecision:
        """Arbitrate allow/deny/lock with explicit reasons."""
        reasons: List[str] = []
        tier = int(action_data.get("recommended_tier", 1))

        if self.lock_state_active:
            reasons.append("Existing lock state is active")
            return SafetyArbitrationDecision("lock", reasons, tier, confidence_score, True)

        if tier >= 3 or risk_level == RiskLevel.CRITICAL:
            reasons.append("Tier 3 systemic anomaly or critical risk detected")
            if violations:
                reasons.extend(violations)
            return SafetyArbitrationDecision("lock", reasons, 3, confidence_score, True)

        if violations or requires_approval:
            reasons.append("Safety policy violation or approval requirement triggered")
            reasons.extend(violations)
            return SafetyArbitrationDecision("deny", reasons, tier, confidence_score, False)

        reasons.append("No policy violations detected")
        return SafetyArbitrationDecision("allow", reasons, tier, confidence_score, False)
|
||||
|
||||
|
||||
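The arbitration order in `_arbitrate_decision` is: an existing lock wins, then tier 3 or critical risk locks, then any violation or approval requirement denies, and only a clean action is allowed. A minimal sketch of that precedence as a pure function is given below so it can be exercised without a manager instance. `Decision` and `arbitrate` are hypothetical names, and the `critical` flag stands in for `risk_level == RiskLevel.CRITICAL`.

```python
from typing import List, NamedTuple


class Decision(NamedTuple):
    """Hypothetical, reduced form of SafetyArbitrationDecision."""
    decision: str
    reasons: List[str]


def arbitrate(locked: bool, tier: int, critical: bool,
              violations: List[str], requires_approval: bool) -> Decision:
    # 1. An active lock short-circuits everything else.
    if locked:
        return Decision("lock", ["Existing lock state is active"])
    # 2. Tier 3 or critical risk escalates to a lock, carrying any violations as reasons.
    if tier >= 3 or critical:
        return Decision("lock", ["Tier 3 systemic anomaly or critical risk detected", *violations])
    # 3. Lesser problems deny the action without locking.
    if violations or requires_approval:
        return Decision("deny", ["Safety policy violation or approval requirement triggered", *violations])
    # 4. Nothing flagged.
    return Decision("allow", ["No policy violations detected"])
```

Keeping the precedence in a side-effect-free function makes it straightforward to unit test each branch; the manager would still be responsible for setting `lock_state_active` afterwards.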
def _determine_action_category(self, action_type: str) -> ActionCategory:
|
||||
"""Determine the category of an action"""
|
||||
action_type_lower = action_type.lower()
|
||||
|
||||
@@ -20,13 +20,14 @@ class SemanticHarvesterService:
|
||||
"last_harvest_time": None
|
||||
}
|
||||
|
||||
-async def harvest_website(self, website_url: str, limit: int = 100) -> List[Dict[str, Any]]:
+async def harvest_website(self, website_url: str, limit: int = 100, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Deep crawl a website using Exa AI.
|
||||
|
||||
Args:
|
||||
website_url: The root URL to crawl.
|
||||
limit: Maximum number of pages to retrieve.
|
||||
user_id: Optional user ID for usage tracking and preflight checks.
|
||||
|
||||
Returns:
|
||||
List of pages with content and metadata.
|
||||
@@ -59,6 +60,30 @@ class SemanticHarvesterService:
|
||||
logger.warning("[SemanticHarvester] Exa service disabled. Returning placeholder data.")
|
||||
return self._get_placeholder_data(website_url)
|
||||
|
||||
# Preflight subscription check if user_id provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
logger.warning(f"[SemanticHarvester] Exa blocked for user {user_id}: {message}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[SemanticHarvester] Preflight check failed: {e}")
|
||||
|
||||
# Use Exa to search for all pages in this domain
|
||||
search_response = self.exa_service.exa.search_and_contents(
|
||||
query=f"site:{website_url}",
|
||||
@@ -82,6 +107,38 @@ class SemanticHarvesterService:
|
||||
})
|
||||
|
||||
logger.info(f"[SemanticHarvester] Successfully harvested {len(results)} pages from {website_url}")
|
||||
|
||||
# Track Exa usage if user_id provided
|
||||
if user_id and results:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Exa search cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET exa_calls = COALESCE(exa_calls, 0) + 1,
|
||||
exa_cost = COALESCE(exa_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[SemanticHarvester] Tracked Exa usage: user={user_id}, cost=${cost}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[SemanticHarvester] Failed to track Exa usage: {track_err}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
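The usage-tracking step above increments counters with `COALESCE(col, 0) + 1` in a single `UPDATE`. One property worth keeping in mind is that an `UPDATE` matches zero rows when no `usage_summaries` row exists yet for the billing period, in which case nothing is recorded. The sketch below reproduces this with the standard-library `sqlite3` module and a reduced, hypothetical table; the `track` helper, the user IDs, and the period value are illustrative only.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE usage_summaries "
    "(user_id TEXT, billing_period TEXT, exa_calls INTEGER, exa_cost REAL)"
)
# One existing row with NULL counters, as a freshly created summary might have.
conn.execute("INSERT INTO usage_summaries VALUES ('u1', '2025-01', NULL, NULL)")

UPDATE_SQL = """
    UPDATE usage_summaries
    SET exa_calls = COALESCE(exa_calls, 0) + 1,
        exa_cost = COALESCE(exa_cost, 0) + :cost
    WHERE user_id = :user_id AND billing_period = :period
"""


def track(user_id: str, period: str, cost: float = 0.005) -> int:
    """Apply the increment and return how many rows were actually updated."""
    cur = conn.execute(UPDATE_SQL, {"cost": cost, "user_id": user_id, "period": period})
    conn.commit()
    return cur.rowcount


hit = track("u1", "2025-01")   # row exists: counters move from NULL to 1 / 0.005
miss = track("u2", "2025-01")  # no row for this user: nothing is updated
row = conn.execute(
    "SELECT exa_calls, exa_cost FROM usage_summaries WHERE user_id = 'u1'"
).fetchone()
```

Checking the affected-row count (or creating the summary row first) is one way to notice the silent no-op case.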
@@ -340,6 +340,46 @@ class SIFIntegrationService:
|
||||
logger.warning(f"Failed to load flat context manifest for user {self.user_id}: {e}")
|
||||
return {"source": "none", "data": {"documents": []}}
|
||||
|
||||
async def get_merged_flat_context(self) -> Dict[str, Any]:
|
||||
"""Return merged onboarding context from all available flat context documents.
|
||||
|
||||
This is an aggregation helper; step-specific APIs still return one-by-one files.
|
||||
"""
|
||||
store = AgentFlatContextStore(self.user_id)
|
||||
manifest = store.load_context_manifest() or {"documents": []}
|
||||
docs = manifest.get("documents") if isinstance(manifest.get("documents"), list) else []
|
||||
|
||||
merged: Dict[str, Any] = {
|
||||
"source": "flat_file",
|
||||
"user_id": self.user_id,
|
||||
"manifest_updated_at": manifest.get("updated_at"),
|
||||
"steps": {},
|
||||
"agent_summaries": {},
|
||||
"documents": [],
|
||||
}
|
||||
|
||||
for item in docs:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
path = item.get("path")
|
||||
if not path:
|
||||
continue
|
||||
doc = store.load_context_document(str(path)) or {}
|
||||
context_type = str(doc.get("context_type") or item.get("type") or path)
|
||||
merged["documents"].append(
|
||||
{
|
||||
"path": path,
|
||||
"context_type": context_type,
|
||||
"updated_at": doc.get("updated_at") or item.get("updated_at"),
|
||||
"size_bytes": item.get("size_bytes"),
|
||||
}
|
||||
)
|
||||
merged["steps"][context_type] = doc.get("data") if isinstance(doc.get("data"), dict) else {}
|
||||
merged["agent_summaries"][context_type] = doc.get("agent_summary") if isinstance(doc.get("agent_summary"), dict) else {}
|
||||
|
||||
merged["document_count"] = len(merged["documents"])
|
||||
return merged
|
||||
|
||||
async def index_market_trends_run(self, trends_result: Dict[str, Any], run_id: str) -> bool:
|
||||
try:
|
||||
latest_id = f"market_trends_latest:{self.user_id}"
|
||||
|
||||
@@ -67,10 +67,11 @@ import sys
|
||||
from pathlib import Path
|
||||
import google.genai as genai
|
||||
from google.genai import types
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
from services.api_key_manager import APIKeyManager
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("gemini_audio_text")
|
||||
|
||||
@@ -250,10 +250,6 @@ def huggingface_text_response(
|
||||
|
||||
logger.info("🚀 Making Hugging Face API call (chat completion)...")
|
||||
|
||||
-# Add rate limiting to prevent expensive API calls
-import time
-time.sleep(1)  # 1 second delay between API calls
|
||||
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model):
|
||||
@@ -403,10 +399,6 @@ def huggingface_structured_json_response(
|
||||
json_schema_str = json.dumps(schema, indent=2)
|
||||
messages[-1]["content"] += f"\n\nJSON Schema:\n{json_schema_str}"
|
||||
|
||||
-# Add rate limiting to prevent expensive API calls
-import time
-time.sleep(1)  # 1 second delay between API calls
|
||||
|
||||
try:
|
||||
response = None
|
||||
last_error = None
|
||||
|
||||
120
backend/services/llm_providers/image_generation/edit.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Image editing operations — generate_image_edit and related helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import ImageEditOptions, ImageGenerationResult, ImageEditProvider
|
||||
from .wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.edit")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
    """Get editing provider instance by name."""
    if provider_name == "wavespeed":
        return WaveSpeedEditProvider()
    raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Generate edited image with pre-flight validation and usage tracking.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
if model and model.startswith(("wavespeed", "qwen", "flux", "nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}") from e
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Image editing failed", "message": str(e)}
|
||||
)
|
||||
|
||||
# 6. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
estimated_cost = 0.02 if provider_name == "wavespeed" else 0.05
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=result.model or model or "unknown",
|
||||
operation_type="image-edit",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation/edit",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||
|
||||
# 7. Return result
|
||||
return result
|
||||
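Step 2 of `generate_image_edit` picks the provider from `options["provider"]` (defaulting to `"wavespeed"`) and then forces `"wavespeed"` when the model ID starts with a known prefix. A minimal sketch of that resolution as a standalone function follows. `resolve_edit_provider` and `WAVESPEED_MODEL_PREFIXES` are hypothetical names, and the model IDs in the usage lines are made up for illustration.

```python
from typing import Any, Dict, Optional

# Prefixes taken from the check in generate_image_edit.
WAVESPEED_MODEL_PREFIXES = ("wavespeed", "qwen", "flux", "nano-banana")


def resolve_edit_provider(model: Optional[str], options: Optional[Dict[str, Any]] = None) -> str:
    """Return the provider name: explicit option first, overridden by a known model prefix."""
    provider = (options or {}).get("provider", "wavespeed")
    # str.startswith accepts a tuple and returns True if any prefix matches.
    if model and model.startswith(WAVESPEED_MODEL_PREFIXES):
        provider = "wavespeed"
    return provider


default_provider = resolve_edit_provider(None)
forced_provider = resolve_edit_provider("flux-example", {"provider": "other"})
explicit_provider = resolve_edit_provider("custom-model", {"provider": "other"})
```

Note that a matching model prefix overrides an explicit `provider` option, which mirrors the order of the checks in the function above.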
105
backend/services/llm_providers/image_generation/face_swap.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Face swap operations — generate_face_swap and related helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import FaceSwapOptions, FaceSwapProvider, ImageGenerationResult
|
||||
from .wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.face_swap")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Generate face swap with pre-flight validation and usage tracking.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
num_operations=1,
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025)
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None,
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap failed", "message": str(api_error)}
|
||||
)
|
||||
200
backend/services/llm_providers/image_generation/helpers.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Shared helpers for image generation operations — validation and usage tracking."""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.helpers")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""Reusable pre-flight validation helper for all image operations."""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id}")
|
||||
except HTTPException:
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""Reusable usage tracking helper for all image operations."""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(user_id=user_id, billing_period=current_period)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Map provider to DB column names
|
||||
provider_column_map = {
|
||||
"stability": ("stability_calls", "stability_cost"),
|
||||
"wavespeed": ("wavespeed_calls", "wavespeed_cost"),
|
||||
"gemini": ("gemini_calls", "gemini_cost"),
|
||||
"openai": ("openai_calls", "openai_cost"),
|
||||
"huggingface": ("total_calls", "total_cost"), # no dedicated columns
|
||||
}
|
||||
calls_col, cost_col = provider_column_map.get(provider, ("total_calls", "total_cost"))
|
||||
|
||||
current_calls_before = getattr(summary, calls_col, 0) or 0
|
||||
current_cost_before = getattr(summary, cost_col, 0.0) or 0.0
|
||||
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {calls_col} = :new_calls,
|
||||
{cost_col} = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Map provider to APIProvider enum
|
||||
provider_api_map = {
|
||||
"stability": APIProvider.STABILITY,
|
||||
"wavespeed": APIProvider.WAVESPEED,
|
||||
"gemini": APIProvider.GEMINI,
|
||||
"openai": APIProvider.OPENAI,
|
||||
"image_edit": APIProvider.IMAGE_EDIT,
|
||||
"video": APIProvider.VIDEO,
|
||||
"audio": APIProvider.AUDIO,
|
||||
}
|
||||
api_provider = provider_api_map.get(provider, APIProvider.STABILITY)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time,
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
limits = pricing.get_user_limits(user_id)
# Display-only values: tolerate a missing or None 'limits' entry so a lookup error cannot abort the commit below.
plan_limits = (limits.get('limits') or {}) if limits else {}
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
provider_limit = plan_limits.get(calls_col) or 0
provider_limit_display = provider_limit if (provider_limit > 0 or tier != 'enterprise') else '∞'

current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = plan_limits.get("audio_calls") or 0
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = plan_limits.get("image_edit_calls") or 0
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = plan_limits.get("video_calls") or 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {actual_provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {provider_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {"current_calls": new_calls, "cost": cost, "total_cost": new_cost}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
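Review note: the unified subscription log above formats limits in two ways. The per-provider limit shows `∞` only for the enterprise tier when the limit is 0, while the audio, image-editing and video lines show `∞` for any non-positive limit. A minimal sketch of a single helper that reproduces the provider rule is given here; `format_limit` and `unlimited_tiers` are hypothetical names that do not appear in this diff.

```python
from typing import Tuple, Union


def format_limit(
    limit: int,
    tier: str,
    unlimited_tiers: Tuple[str, ...] = ("enterprise",),
) -> Union[int, str]:
    """Return the limit for display, or '∞' when a non-positive limit means unlimited.

    Mirrors the expression used for provider_limit_display in the diff:
    a positive limit is shown as-is, and a non-positive limit is shown as
    '∞' only for tiers treated as unlimited (an assumption: enterprise only).
    """
    if limit > 0:
        return limit
    return "∞" if tier in unlimited_tiers else limit


if __name__ == "__main__":
    for limit, tier in [(0, "enterprise"), (0, "free"), (50, "pro")]:
        print(f"{tier}: {format_limit(limit, tier)}")
```

Routing all four display lines through one helper of this kind would make a zero limit on a non-enterprise plan read as `0` everywhere rather than as `∞` on some lines.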
@@ -55,6 +55,9 @@ def _select_provider(explicit: Optional[str]) -> str:
|
||||
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get the client for the specified provider."""
|
||||
if provider_name == "wavespeed":
|
||||
api_key = api_key or os.getenv("WAVESPEED_API_KEY")
|
||||
if not api_key:
|
||||
raise RuntimeError("WAVESPEED_API_KEY is required for WaveSpeed image editing. Set it in your .env file.")
|
||||
return WaveSpeedEditProvider(api_key=api_key)
|
||||
|
||||
if not HF_HUB_AVAILABLE:
|
||||
@@ -63,7 +66,7 @@ def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
if provider_name == "huggingface":
|
||||
api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing. Set it in your .env file.")
|
||||
# Use fal-ai provider for fast inference via HF Inference API
|
||||
return InferenceClient(provider="fal-ai", api_key=api_key)
|
||||
|
||||
@@ -99,35 +102,53 @@ def edit_image(
|
||||
"""
|
||||
# PRE-FLIGHT VALIDATION: Validate image editing before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
if user_id:
|
||||
from services.database import get_db
|
||||
# Skip validation in podcast-only demo mode or if explicitly disabled
|
||||
skip_validation = os.getenv("ALWRITY_SKIP_IMAGE_EDITING_VALIDATION", "false").lower() in ("true", "1", "yes")
|
||||
|
||||
if user_id and not skip_validation:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_editing_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
# Note: get_db() is a generator, so we need to use next() to get the session
|
||||
# and ensure we close it in the finally block
|
||||
db = next(get_db())
|
||||
|
||||
db = None
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_editing_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Image Editing] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image editing")
|
||||
# Use get_session_for_user instead of get_db() since we're outside FastAPI DI
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[Image Editing] ⚠️ Could not get DB session for user {user_id} - skipping validation")
|
||||
else:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_editing_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
logger.info(f"[Image Editing] ✅ Pre-flight validation passed for user_id={user_id} - proceeding with image editing")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
# In feature-limited mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in feature-limited mode - allowing operation to continue")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
if db:
|
||||
try:
|
||||
db.close()
|
||||
except Exception as close_err:
|
||||
logger.warning(f"[Image Editing] Error closing DB session: {close_err}")
|
||||
else:
|
||||
logger.warning(f"[Image Editing] ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
if skip_validation:
|
||||
logger.info(f"[Image Editing] ⚡ Skipping pre-flight validation (ALWRITY_SKIP_IMAGE_EDITING_VALIDATION=true)")
|
||||
else:
|
||||
logger.warning(f"[Image Editing] ⚠️ No user_id provided - skipping pre-flight validation")
|
||||
|
||||
# Validate input
|
||||
if not input_image_bytes:
|
||||
|
||||
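Review note: the hunk above reads two environment switches inline, `ALWRITY_SKIP_IMAGE_EDITING_VALIDATION` (truthy strings) and `ALWRITY_ENABLED_FEATURES` (anything other than empty or `all` means feature-limited mode). A minimal sketch of the same checks as small, testable functions follows; `env_flag` and `is_feature_limited` are hypothetical helper names, and the extra `.strip()` in `env_flag` is an assumption that is not present in the diff.

```python
import os

# Strings treated as "enabled", matching the tuple used in the diff.
TRUTHY = ("true", "1", "yes")


def env_flag(name: str, default: str = "false") -> bool:
    """Return True when the named environment variable holds a truthy string."""
    # Assumption: surrounding whitespace is ignored (the diff only lowercases).
    return os.getenv(name, default).strip().lower() in TRUTHY


def is_feature_limited(enabled_features: str) -> bool:
    """Return True unless the value is empty or 'all', as in the diff's check."""
    return enabled_features.strip().lower() not in ("", "all")


if __name__ == "__main__":
    skip = env_flag("ALWRITY_SKIP_IMAGE_EDITING_VALIDATION")
    limited = is_feature_limited(os.getenv("ALWRITY_ENABLED_FEATURES", ""))
    print(f"skip_validation={skip}, feature_limited={limited}")
```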
@@ -18,9 +18,9 @@ from .image_generation import (
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from .image_generation.helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from .image_generation.edit import generate_image_edit
|
||||
from .image_generation.face_swap import generate_face_swap
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
@@ -53,259 +53,6 @@ def _get_provider(provider_name: str, user_id: Optional[str] = None):
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||
"""Get editing provider instance.
|
||||
|
||||
Args:
|
||||
provider_name: Provider name ("wavespeed", "stability", etc.)
|
||||
|
||||
Returns:
|
||||
ImageEditProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider()
|
||||
# TODO: Add Stability edit provider if needed
|
||||
# elif provider_name == "stability":
|
||||
# return StabilityEditProvider()
|
||||
else:
|
||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""
|
||||
Reusable pre-flight validation helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for subscription checking
|
||||
operation_type: Type of operation (for logging)
|
||||
num_operations: Number of operations to validate (default: 1)
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails (subscription limits exceeded, etc.)
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for tracking
|
||||
provider: Provider name (e.g., "wavespeed", "stability")
|
||||
model: Model name used
|
||||
operation_type: Type of operation (for logging)
|
||||
result_bytes: Generated/processed image bytes
|
||||
cost: Cost of the operation
|
||||
prompt: Optional prompt text (for request size calculation)
|
||||
endpoint: API endpoint path (for logging)
|
||||
metadata: Optional additional metadata
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking information (current_calls, cost, etc.)
|
||||
"""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Stability, HuggingFace, etc.)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, Stability, etc.)
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {
|
||||
"current_calls": new_calls,
|
||||
"cost": cost,
|
||||
"total_cost": new_cost,
|
||||
}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
@@ -500,165 +247,7 @@ def generate_character_image(
|
||||
)
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate edited image - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
# If the model name starts with "wavespeed", "qwen", "flux", or "nano-banana", use the wavespeed provider
|
||||
if model and model.startswith(("wavespeed", "qwen", "flux", "nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider (REUSES provider pattern)
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Image editing failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate face swap - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
num_operations=1, # _validate_image_operation() accepts no image argument; one face swap counts as one operation
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
# Get model cost
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None, # Face swap doesn't use prompts
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -6,6 +6,7 @@ migrated from the legacy lib/gpt_providers/text_generation/main_text_generation.
|
||||
|
||||
import os
|
||||
import json
|
||||
import time
|
||||
from typing import Optional, Dict, Any, List
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
@@ -44,6 +45,7 @@ def llm_text_gen(
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
preferred_provider: Optional[str] = None,
|
||||
flow_type: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
@@ -74,7 +76,8 @@ def llm_text_gen(
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
temperature = 0.7
|
||||
max_tokens = 4000
|
||||
if max_tokens is None:
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
n = 1
|
||||
fp = 16
|
||||
@@ -211,7 +214,7 @@ def llm_text_gen(
|
||||
provider_enum = APIProvider.MISTRAL # HuggingFace maps to Mistral enum for usage tracking
|
||||
actual_provider_name = "huggingface" # Keep actual provider name for logs
|
||||
elif gpt_provider == "wavespeed":
|
||||
provider_enum = APIProvider.OPENAI # Map to OpenAI for tracking purposes
|
||||
provider_enum = APIProvider.WAVESPEED
|
||||
actual_provider_name = "wavespeed"
|
||||
elif gpt_provider == "openai":
|
||||
provider_enum = APIProvider.OPENAI
|
||||
@@ -225,6 +228,8 @@ def llm_text_gen(
|
||||
if not user_id:
|
||||
raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")
|
||||
|
||||
sub_check_start = time.time()
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] Subscription check START for user {user_id}")
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
@@ -286,6 +291,8 @@ def llm_text_gen(
|
||||
logger.info(f"[llm_text_gen] Subscription check passed for user {user_id}: provider={actual_provider_name or gpt_provider}, tokens_requested={estimated_total_tokens}, new_user_no_usage_record")
|
||||
|
||||
finally:
|
||||
sub_check_ms = (time.time() - sub_check_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] Subscription check took {sub_check_ms:.0f}ms for user {user_id}")
|
||||
db.close()
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
|
||||
@@ -295,7 +302,8 @@ def llm_text_gen(
|
||||
raise
|
||||
except Exception as sub_error:
|
||||
# STRICT: Fail on subscription check errors
|
||||
logger.error(f"[llm_text_gen] Subscription check failed for user {user_id}: {sub_error}")
|
||||
sub_check_ms = (time.time() - sub_check_start) * 1000
|
||||
logger.error(f"[llm_text_gen][{flow_tag}] Subscription check FAILED after {sub_check_ms:.0f}ms for user {user_id}: {sub_error}")
|
||||
raise RuntimeError(f"Subscription check failed: {str(sub_error)}")
|
||||
|
||||
# Construct the system prompt if not provided
|
||||
@@ -365,15 +373,29 @@ def llm_text_gen(
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "wavespeed":
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
llm_start = time.time()
|
||||
if json_struct:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_structured_json_response
|
||||
response_text = wavespeed_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
llm_ms = (time.time() - llm_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] LLM API call took {llm_ms:.0f}ms for user {user_id} (wavespeed)")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] Unknown provider: {gpt_provider}")
|
||||
raise RuntimeError(f"Unknown LLM provider: {gpt_provider}. Supported providers: google, huggingface, wavespeed")
|
||||
|
||||
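Review note: the `llm_text_gen` hunks above repeat the same `time.time()` start/stop arithmetic around the subscription check and the WaveSpeed call. A minimal sketch of that pattern as a context manager is shown here; `timed` and its `sink` dictionary are hypothetical and not part of the diff, and `time.perf_counter()` is swapped in for `time.time()` because it is monotonic.

```python
import time
from contextlib import contextmanager
from typing import Dict, Iterator


@contextmanager
def timed(label: str, sink: Dict[str, float]) -> Iterator[None]:
    """Record the elapsed time of the wrapped block, in milliseconds, under `label`."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Runs on success and on exceptions alike, like the diff's finally block.
        sink[label] = (time.perf_counter() - start) * 1000


if __name__ == "__main__":
    timings: Dict[str, float] = {}
    with timed("subscription_check", timings):
        time.sleep(0.01)  # stand-in for the real check
    print(f"Subscription check took {timings['subscription_check']:.0f}ms")
```

Because the measurement is taken in `finally`, the failure path gets a duration without a second copy of the arithmetic in the `except` branch.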
Some files were not shown because too many files have changed in this diff.