Compare commits
1 Commits
codex/impl
...
codex/remo
| Author | SHA1 | Date | |
|---|---|---|---|
| | f210310177 | | |
4
.gitignore
vendored
@@ -8,10 +8,6 @@ nul
|
||||
LICENSE
|
||||
CHANGELOG.md
|
||||
|
||||
.planning
|
||||
.planning/
|
||||
|
||||
|
||||
.trae/
|
||||
.trae
|
||||
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
# Roadmap: ALwrity Frontend Optimization
|
||||
|
||||
## Overview
|
||||
|
||||
Optimize the frontend build to reduce build time from 5 minutes to under 30 seconds and shrink bundle size from 8.42MB to under 1MB. First, implement code splitting with React.lazy and feature-gated loading using ALWRITY_ENABLED_FEATURES. Then migrate from Create React App to Vite for faster builds. Finally, optimize dependencies for maximum performance.
|
||||
|
||||
## Phases
|
||||
|
||||
**Phase Numbering:**
|
||||
- Integer phases (1, 2, 3): Planned work
|
||||
- All phases planned and ready for execution
|
||||
|
||||
---
|
||||
|
||||
### Phase 1: Code Splitting & Feature-Based Lazy Loading ✅ Complete
|
||||
**Goal**: Replace all static imports with React.lazy dynamic imports and add feature-gated loading using ALWRITY_ENABLED_FEATURES. Also convert MUI icon barrel imports to individual imports (moved here from Phase 3 for Vite readiness).
|
||||
**Depends on**: Nothing (first phase)
|
||||
**Requirements**: VITE-04 (code splitting), VITE-06 (dependency optimization)
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. ✅ All 31+ route components loaded via React.lazy (not static imports)
|
||||
2. ✅ Initial bundle size reduced from 8.42MB to 2.50MB (70% reduction)
|
||||
3. ✅ Disabled features (via ALWRITY_ENABLED_FEATURES) don't load their bundles
|
||||
4. ✅ All existing routes still work correctly
|
||||
5. ✅ No build warnings or errors with CRA
|
||||
6. ✅ All MUI icon imports changed from barrel to individual (111 files)
|
||||
|
||||
**Plans**: 3 plans (all complete)
|
||||
|
||||
Plans:
|
||||
- [x] 01-01: Convert 31 static imports to React.lazy with Suspense
|
||||
- [x] 01-02: Add feature-gated route loading using ALWRITY_ENABLED_FEATURES
|
||||
- [x] 01-03: Convert MUI icon barrel imports to individual imports (111 files)
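
As a rough illustration of the feature gating in plan 01-02, the sketch below parses a feature list and checks a single feature key against it. The comma-separated format, the helper names `parseEnabledFeatures` and `isFeatureEnabled`, and the rule that an unset value enables everything are assumptions made for this example, not the project's confirmed implementation.

```typescript
// Hypothetical helpers: names and the comma-separated format are assumptions.
function parseEnabledFeatures(raw: string | undefined): Set<string> | null {
  // Assumption: an unset or blank value means "no restriction" (null).
  if (raw === undefined || raw.trim() === '') return null;
  return new Set(
    raw
      .split(',')
      .map((key) => key.trim().toLowerCase())
      .filter((key) => key.length > 0),
  );
}

function isFeatureEnabled(feature: string, enabled: Set<string> | null): boolean {
  // null means every feature is allowed; otherwise the key must be listed.
  return enabled === null || enabled.has(feature.toLowerCase());
}

const enabledFeatures = parseEnabledFeatures('podcast, billing');
console.log(isFeatureEnabled('podcast', enabledFeatures));
console.log(isFeatureEnabled('seo', enabledFeatures));
```

A route whose feature key is not enabled can skip rendering its lazy component entirely, which is what keeps the corresponding chunk from being requested.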
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Migrate from CRA to Vite (Next)
|
||||
**Goal**: Migrate frontend from Create React App to Vite for fast builds
|
||||
**Depends on**: Phase 1 ✅
|
||||
**Requirements**: VITE-01, VITE-02, VITE-03
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. `npm run dev` starts Vite dev server with HMR
|
||||
2. `npm run build` completes in under 30 seconds (down from 5 minutes)
|
||||
3. All environment variables work with `VITE_*` prefix
|
||||
4. TypeScript compiles without errors
|
||||
5. Material UI theme renders correctly
|
||||
|
||||
**Plans**: 3 plans
|
||||
|
||||
Plans:
|
||||
- [ ] 02-01: Install Vite dependencies and create configuration
|
||||
- [ ] 02-02: Migrate index.html and entry point
|
||||
- [ ] 02-03: Update environment variables and scripts
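
Plan 02-03 renames environment variables from the `REACT_APP_*` prefix to `VITE_*`. The sketch below shows the key renaming on a plain object. The helper name `migrateEnvKeys` and the variable `REACT_APP_API_URL` are hypothetical and chosen only for illustration.

```typescript
// Hypothetical helper: renames REACT_APP_* keys to VITE_* and leaves other keys untouched.
function migrateEnvKeys(env: Record<string, string>): Record<string, string> {
  const CRA_PREFIX = 'REACT_APP_';
  const VITE_PREFIX = 'VITE_';
  const migrated: Record<string, string> = {};
  for (const [key, value] of Object.entries(env)) {
    const newKey = key.startsWith(CRA_PREFIX)
      ? VITE_PREFIX + key.slice(CRA_PREFIX.length)
      : key;
    migrated[newKey] = value;
  }
  return migrated;
}

// REACT_APP_API_URL is an example name, not taken from the project.
const craEnv = { REACT_APP_API_URL: 'http://localhost:8000', NODE_ENV: 'development' };
console.log(migrateEnvKeys(craEnv));
```

Renaming the keys is only half of the change: client code that reads `process.env.REACT_APP_*` under CRA reads `import.meta.env.VITE_*` under Vite, so call sites need updating as well.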
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: Dependency Cleanup & Production Validation
|
||||
**Goal**: Remove unused dependencies and deploy Vite build to production
|
||||
**Depends on**: Phase 2
|
||||
**Requirements**: VITE-07, VITE-08, VITE-09
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. Unused dependencies identified and removed
|
||||
2. Production build serves correctly (preview mode)
|
||||
3. All features tested and working (Clerk auth, Stripe, CopilotKit)
|
||||
4. Vercel deployment config updated for Vite
|
||||
5. Build time consistently under 30 seconds
|
||||
6. Total bundle size under 2MB
|
||||
|
||||
**Plans**: 2 plans (consolidated from former Phase 3 & 4)
|
||||
|
||||
Plans:
|
||||
- [ ] 03-01: Audit and remove unused dependencies, update Vercel config
|
||||
- [ ] 03-02: Full feature testing and performance validation
|
||||
|
||||
---
|
||||
|
||||
## Execution Order
|
||||
|
||||
Phases execute in numeric order: 1 → 2 → 3
|
||||
|
||||
**Key insight:** Phase 1 (code splitting) works with CRA, so bundle size drops immediately. Phase 2 (Vite) then adds faster builds on the already-split bundles. Phase 3 covers cleanup and deployment.
|
||||
|
||||
## Progress
|
||||
|
||||
| Phase | Plans Complete | Status | Completed |
|
||||
|-------|----------------|--------|-----------|
|
||||
| 1. Code Splitting & MUI Optimization | 3/3 | ✅ Complete | 2026-05-08 |
|
||||
| 2. Migrate CRA to Vite | 0/3 | ⏳ Ready | - |
|
||||
| 3. Cleanup & Production | 0/2 | ⏳ Planned | - |
|
||||
@@ -1,73 +0,0 @@
|
||||
# Project State: Alwrity
|
||||
|
||||
## Current Position
|
||||
|
||||
**Active Phase:** Phase 1 - Code Splitting & Feature-Based Lazy Loading
|
||||
**Phase Status:** ✅ Complete — Ready for Phase 2
|
||||
**Milestone:** v1.0 - Frontend Optimization
|
||||
|
||||
## Phase Progress
|
||||
|
||||
### Phase 1: Code Splitting & Feature-Based Lazy Loading
|
||||
- **Status:** ✅ Complete
|
||||
- **Plans:** 3 plans executed (01-01, 01-02, 01-03)
|
||||
|
||||
**Plans:**
|
||||
- [x] 01-01: Convert 31 static imports to React.lazy with Suspense
|
||||
- [x] 01-02: Add feature-gated route loading using ALWRITY_ENABLED_FEATURES
|
||||
- [x] 01-03: Convert MUI icon barrel imports to individual imports (111 files)
|
||||
|
||||
**Results:**
|
||||
- Main bundle: 8.42MB → 2.50MB (70% reduction via React.lazy)
|
||||
- 190+ chunk files for route-level code splitting
|
||||
- 47 routes feature-gated with ALWRITY_ENABLED_FEATURES
|
||||
- 16 feature keys in FEATURE_KEYS constant
|
||||
- 111 files converted from barrel to individual MUI icon imports
|
||||
- Zero barrel imports from @mui/icons-material remain
|
||||
|
||||
### Phase 2: Migrate CRA to Vite
|
||||
- **Status:** Ready to start (Phase 1 complete)
|
||||
- **Plans:** 3 plans created (02-01, 02-02, 02-03)
|
||||
- **Dependencies:** Phase 1 complete
|
||||
|
||||
**Plans:**
|
||||
- [ ] 02-01: Install Vite dependencies and create configuration
|
||||
- [ ] 02-02: Migrate index.html and entry point
|
||||
- [ ] 02-03: Update environment variables and scripts
|
||||
|
||||
### Phase 3: Production Validation (Planned)
|
||||
- Depends on: Phase 2
|
||||
- Focus: Vercel deploy, full feature testing
|
||||
|
||||
### Phase 4: (Removed — MUI icon optimization folded into Phase 1-03)
|
||||
|
||||
## Decisions Made
|
||||
|
||||
### Locked Decisions
|
||||
- **Code splitting first**, then Vite migration (not the other way around) ✅ Done
|
||||
- Use React.lazy for ALL route components (this is a React feature, NOT bundler-specific) ✅ Done
|
||||
- Use ALWRITY_ENABLED_FEATURES for feature-gated route loading ✅ Done
|
||||
- **MUI icon imports before Vite migration** — barrel imports converted to individual per-file default imports ✅ Done
|
||||
- Use Vite 5.x with @vitejs/plugin-react
|
||||
- Disable sourcemaps in production build for speed
|
||||
- Migrate env vars from `REACT_APP_*` to `VITE_*`
|
||||
|
||||
### Patterns Established
|
||||
- **MUI icon imports**: Always `import IconName from '@mui/icons-material/IconName'` — never barrel destructuring
|
||||
- **Route splitting**: All route components use React.lazy with Suspense
|
||||
- **Feature gating**: FeatureRoute wraps inside ProtectedRoute (auth → then feature check)
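
The guard ordering described above (auth first, then the feature check) can be sketched as a pure decision function. The type `RouteDecision` and the function `resolveRouteAccess` are hypothetical names. The real `ProtectedRoute` and `FeatureRoute` are React components whose internals are not shown in this document.

```typescript
// Hypothetical model of the guard order: authentication is checked before the feature flag.
type RouteDecision = 'sign-in-required' | 'feature-disabled' | 'render';

function resolveRouteAccess(isSignedIn: boolean, featureEnabled: boolean): RouteDecision {
  // Outer guard (ProtectedRoute in the pattern above): unauthenticated users stop here.
  if (!isSignedIn) return 'sign-in-required';
  // Inner guard (FeatureRoute): only evaluated for signed-in users.
  if (!featureEnabled) return 'feature-disabled';
  return 'render';
}

console.log(resolveRouteAccess(false, false));
console.log(resolveRouteAccess(true, false));
console.log(resolveRouteAccess(true, true));
```

Checking auth first means a signed-out visitor never learns whether a feature exists or is disabled for the deployment.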
|
||||
|
||||
## Key Insight
|
||||
|
||||
**React.lazy is a React feature (not CRA or Vite specific).** Doing code splitting first with CRA:
|
||||
1. Immediately reduces the main bundle from 8.42MB to 2.50MB (measured after Phase 1)
|
||||
2. Adds no risk (React.lazy is stable since React 16.6)
|
||||
3. Makes Vite migration smoother (bundles are already split)
|
||||
4. ALWRITY_ENABLED_FEATURES can prevent disabled feature bundles from loading at all
|
||||
|
||||
**MUI icon barrel imports eliminated** — 111 files converted to individual per-file imports. This ensures reliable tree-shaking during Vite migration and beyond.
|
||||
|
||||
---
|
||||
|
||||
*Last updated: 2026-05-08*
|
||||
*Updated by: gsd-executor*
|
||||
@@ -1,129 +0,0 @@
|
||||
---
|
||||
phase: 01-code-splitting
|
||||
plan: 03
|
||||
type: execute
|
||||
subsystem: frontend
|
||||
tags: [performance, MUI, icons, tree-shaking, barrel-imports]
|
||||
requires:
|
||||
- phase: 01-code-splitting-02
|
||||
provides: feature gating structure for route protection
|
||||
provides:
|
||||
- All MUI icon imports converted from barrel (destructured) to individual per-file default imports
|
||||
- Zero barrel imports from @mui/icons-material remain in the codebase
|
||||
affects: [02-vite-migration, build performance]
|
||||
tech-stack:
|
||||
added: []
|
||||
patterns: [individual MUI icon imports, per-file default imports for tree-shaking]
|
||||
key-files:
|
||||
created: []
|
||||
modified:
|
||||
- frontend/src/components/shared/ErrorBoundary.tsx
|
||||
- frontend/src/components/SubscriptionGuard.tsx
|
||||
- frontend/src/components/SubscriptionExpiredModal.tsx
|
||||
- frontend/src/pages/SchedulerDashboard.tsx
|
||||
- frontend/src/pages/BillingPage.tsx
|
||||
- +106 additional frontend component files
|
||||
key-decisions:
|
||||
- "All MUI icon barrel imports converted BEFORE Vite migration to eliminate Webpack 4 tree-shaking uncertainty"
|
||||
- "Used per-file default imports (import X from '@mui/icons-material/X') instead of destructured barrel imports"
|
||||
- "Aliased icons (e.g., ErrorOutline as ErrorIcon) converted to named default imports matching the alias (import ErrorIcon from '@mui/icons-material/ErrorOutline')"
|
||||
- "JSX variable names preserved — only import statements changed"
|
||||
patterns-established:
|
||||
- "MUI icon imports: always use import X from '@mui/icons-material/X' pattern, never import { X } from '@mui/icons-material'"
|
||||
duration: 45min
|
||||
completed: 2026-05-08
|
||||
---
|
||||
|
||||
# Phase 1 Plan 01-03: MUI Icon Import Optimization Summary
|
||||
|
||||
**Converted all 300+ MUI icon barrel imports to individual per-file default imports across 111 frontend files — eliminating Webpack 4 tree-shaking uncertainty before Vite migration**
|
||||
|
||||
## Performance
|
||||
|
||||
- **Duration:** ~35 min
|
||||
- **Completed:** 2026-05-08
|
||||
- **Tasks:** 10 commits across 111 files
|
||||
- **Files modified:** 111
|
||||
|
||||
## Accomplishments
|
||||
|
||||
- Converted **all barrel** `import { X } from '@mui/icons-material'` to individual `import X from '@mui/icons-material/X'` — **zero barrel imports remaining**
|
||||
- Modified **111 files** across every area: PodcastMaker, YouTubeCreator, OnboardingWizard, billing, SEO, shared components, and more
|
||||
- Handled aliased imports (`IconName as Alias`) correctly — JSX variable names preserved unchanged
|
||||
- Build verified — `npm run build:nomap` succeeds with zero new errors
|
||||
- Enables reliable tree-shaking during Phase 2 (Vite migration) — each file imports only the icons it uses
|
||||
|
||||
## Task Commits
|
||||
|
||||
Each batch was committed atomically:
|
||||
|
||||
1. **ErrorBoundary** (`components/shared/`) - `46781a0` — 5 icons
|
||||
2. **SubscriptionGuard** - `bda75cb` — 2 icons
|
||||
3. **SubscriptionExpiredModal** - `80f76b1` — 3 icons
|
||||
4. **SchedulerDashboard** - `7ffd972` — 7 icons
|
||||
5. **BillingPage** - `a76671c` — 1 icon
|
||||
6. **Billing, Blog, ContentPlanning, ErrorBoundary, Pricing, Alerts** - `a009cbb` — 8 files, 36 insertions
|
||||
7. **ImageStudio, Landing, LinkedIn, MainDashboard, OnboardingWizard** - `205e098` — 14 files, 65 insertions
|
||||
8. **PodcastMaker AnalysisPanel** - `25ce5b9` — 18 files, 58 insertions
|
||||
9. **PodcastMaker, ProductMarketing, Research, Scheduler, SEO, Shared** - `986a7e5` — 44 files, 149 insertions
|
||||
10. **StoryWriter, YouTubeCreator** - `6361255` — 22 files, 67 insertions
|
||||
|
||||
## Files Modified
|
||||
|
||||
**111 files total** across the frontend source tree:
|
||||
|
||||
- `components/billing/` — 2 files (ComprehensiveAPIBreakdown, CostOptimizationRecommendations)
|
||||
- `components/BlogWriter/` — 1 file (BlogWriterPhasesSection)
|
||||
- `components/ContentPlanningDashboard/` — 2 files (CardExpansionWrapper, StrategyErrorBoundary)
|
||||
- `components/ErrorBoundary.tsx` — 1 file (3 icons)
|
||||
- `components/ImageStudio/` — 2 files (AssetFilters, CreateStudioCostAlerts)
|
||||
- `components/Landing/` — 2 files (EnterpriseCTA, FeatureShowcase)
|
||||
- `components/LinkedInWriter/` — 1 file (FactCheckResults)
|
||||
- `components/MainDashboard/` — 1 file (MainDashboard)
|
||||
- `components/OnboardingWizard/` — 7 files (incl. VoiceAvatarPlaceholder with 22 icons)
|
||||
- `components/PodcastMaker/` — 40 files (AnalysisPanel, CreateStep, ScriptEditor, etc.)
|
||||
- `components/Pricing/` — 1 file (PricingPage)
|
||||
- `components/ProductMarketing/` — 5 files (CampaignWizard, ProductPhotoshootStudio, etc.)
|
||||
- `components/Research/` — 2 files (PersonalizationIndicator, ResearchInputContainer)
|
||||
- `components/SchedulerDashboard/` — 1 file (SchedulerCharts)
|
||||
- `components/SEODashboard/` — 3 files (AIInsightsPanel, HealthScore, MetricCard)
|
||||
- `components/shared/` — 12 files (ErrorBoundary, AlertsBadge, ProtectedRoute, etc.)
|
||||
- `components/StoryWriter/` — 3 files (AIStorySetupModal, FormFieldWithTooltip, SelectFieldWithTooltip)
|
||||
- `components/SubscriptionGuard.tsx` — 1 file
|
||||
- `components/SubscriptionExpiredModal.tsx` — 1 file
|
||||
- `components/YouTubeCreator/` — 19 files (SceneCard, RenderStep, PlanStep, etc.)
|
||||
- `pages/` — 2 files (BillingPage, ResearchDashboard/PresetsCard)
|
||||
|
||||
## Decisions Made
|
||||
|
||||
- **Convert all barrel imports now, before Vite migration** — CRA's Webpack 4 cannot reliably tree-shake barrel imports. Converting before the bundler swap reduces migration risk and ensures Vite's native ESM tree-shaking works optimally.
|
||||
- **Per-file default import pattern** — Every icon gets its own import line: `import IconName from '@mui/icons-material/IconName'`. This is the most predictable pattern and works identically in both Webpack and Vite.
|
||||
- **Alias handling** — For icons imported as `{ X as Y }`, the alias `Y` becomes the import name: `import Y from '@mui/icons-material/X'`. JSX usage unchanged.
|
||||
- **Multiple import lines preserved** — Files with separate barrel imports from `@mui/icons-material` were converted to multiple individual import blocks, preserving the original organizational structure.
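
The conversion rules above, including alias handling, can be sketched as a small string transform. This is a minimal illustration and not the tooling actually used for the 111 files. `convertBarrelImport` is a hypothetical name, and the regular expression only covers a single, simple barrel import statement.

```typescript
// Hypothetical codemod helper: handles one simple barrel import statement at a time.
function convertBarrelImport(statement: string): string[] {
  const match = statement.match(
    /^\s*import\s*\{([^}]*)\}\s*from\s*['"]@mui\/icons-material['"];?\s*$/,
  );
  if (!match) return [statement]; // Not a barrel import: leave it unchanged.

  return match[1]
    .split(',')
    .map((spec) => spec.trim())
    .filter((spec) => spec.length > 0)
    .map((spec) => {
      // "X as Y" keeps Y as the local name, so JSX usage stays the same.
      const [imported, local = imported] = spec.split(/\s+as\s+/);
      return `import ${local} from '@mui/icons-material/${imported}';`;
    });
}

console.log(
  convertBarrelImport("import { Close, ErrorOutline as ErrorIcon } from '@mui/icons-material';"),
);
```

Each specifier becomes its own default import from the icon's module path, and an aliased specifier uses the alias as the local binding.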
|
||||
|
||||
## Deviations from Plan
|
||||
|
||||
None - this was ad-hoc work not covered by an existing PLAN.md.
|
||||
|
||||
## Issues Encountered
|
||||
|
||||
- **Task agent timeout**: First attempt at parallel conversion agents failed silently for batches 1-2 (73 files). Re-launched with explicit edit instructions - succeeded on second attempt.
|
||||
- **No naming conflicts found**: Despite converting 300+ icon imports across 111 files, no variable naming collisions occurred. Each icon only appears once per file.
|
||||
|
||||
## Build Verification
|
||||
|
||||
- `npm run build:nomap` — **PASSED** with zero errors
|
||||
- Only pre-existing CRA bundle size warning remains (expected — Vite migration will resolve it in Phase 2)
|
||||
- No new build warnings introduced
|
||||
|
||||
## Next Phase Readiness
|
||||
|
||||
- Frontend is ready for **Phase 2: Vite Migration**
|
||||
- All MUI icon imports use individual default imports, so tree-shaking will work correctly with Vite's Rollup-based production build
|
||||
- User should perform manual testing of Podcast Maker with `REACT_APP_ENABLED_FEATURES=podcast` before Vite migration begins
|
||||
- After manual verification, proceed with plan 02-01 (Install Vite dependencies and create configuration)
|
||||
|
||||
---
|
||||
|
||||
*Phase: 01-code-splitting*
|
||||
*Completed: 2026-05-08*
|
||||
14
Procfile
@@ -1 +1,13 @@
|
||||
web: cd backend && python start_alwrity_backend.py --production
|
||||
web: cd backend && ALWRITY_ENABLED_FEATURES=podcast python -c "
|
||||
import os
|
||||
import sys
|
||||
# Ensure podcast mode
|
||||
os.environ.setdefault('ALWRITY_ENABLED_FEATURES', 'podcast')
|
||||
# Set HOST/PORT for Render
|
||||
port = os.getenv('PORT', '10000')
|
||||
host = os.getenv('HOST', '0.0.0.0')
|
||||
print(f'[STARTUP] Starting uvicorn on {host}:{port}', flush=True)
|
||||
sys.stdout.flush()
|
||||
import uvicorn
|
||||
uvicorn.run('app:app', host=host, port=int(port), reload=False)
|
||||
"
|
||||
|
||||
14
README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Render CLI
|
||||
|
||||
## Installation
|
||||
|
||||
- [Homebrew](https://render.com/docs/cli#homebrew-macos-linux)
|
||||
- [Direct Download](https://render.com/docs/cli#direct-download)
|
||||
|
||||
## Documentation
|
||||
|
||||
Documentation is hosted at https://render.com/docs/cli.
|
||||
|
||||
## Contributing
|
||||
|
||||
To create a new command, use the `cmd/template.go` template file as a starting point. Reference the [CLI Style Guide](docs/STYLE.md) to learn more about command naming, flags, arguments, and help text conventions.
|
||||
672
_session_backup/App.tsx
Normal file
@@ -0,0 +1,672 @@
|
||||
import React from 'react';
|
||||
import { BrowserRouter as Router, Routes, Route, Navigate, useLocation } from 'react-router-dom';
|
||||
import { Box, CircularProgress, Typography } from '@mui/material';
|
||||
import { CopilotKit } from "@copilotkit/react-core";
|
||||
import { ClerkProvider, useAuth } from '@clerk/clerk-react';
|
||||
import "@copilotkit/react-ui/styles.css";
|
||||
import Wizard from './components/OnboardingWizard/Wizard';
|
||||
import MainDashboard from './components/MainDashboard/MainDashboard';
|
||||
import SEODashboard from './components/SEODashboard/SEODashboard';
|
||||
import ContentPlanningDashboard from './components/ContentPlanningDashboard/ContentPlanningDashboard';
|
||||
import FacebookWriter from './components/FacebookWriter/FacebookWriter';
|
||||
import LinkedInWriter from './components/LinkedInWriter/LinkedInWriter';
|
||||
import BlogWriter from './components/BlogWriter/BlogWriter';
|
||||
import StoryWriter from './components/StoryWriter/StoryWriter';
|
||||
import { StoryProjectList } from './components/StoryWriter/StoryProjectList';
|
||||
import YouTubeCreator from './components/YouTubeCreator/YouTubeCreator';
|
||||
import { CreateStudio, EditStudio, UpscaleStudio, ControlStudio, SocialOptimizer, AssetLibrary, ImageStudioDashboard, FaceSwapStudio, CompressionStudio, ImageProcessingStudio } from './components/ImageStudio';
|
||||
import {
|
||||
VideoStudioDashboard,
|
||||
CreateVideo,
|
||||
AvatarVideo,
|
||||
EnhanceVideo,
|
||||
ExtendVideo,
|
||||
EditVideo,
|
||||
TransformVideo,
|
||||
SocialVideo,
|
||||
FaceSwap,
|
||||
VideoTranslate,
|
||||
VideoBackgroundRemover,
|
||||
AddAudioToVideo,
|
||||
LibraryVideo,
|
||||
} from './components/VideoStudio';
|
||||
import {
|
||||
ProductMarketingDashboard,
|
||||
ProductPhotoshootStudio,
|
||||
ProductAnimationStudio,
|
||||
ProductVideoStudio,
|
||||
ProductAvatarStudio,
|
||||
} from './components/ProductMarketing';
|
||||
import PodcastDashboard from './components/PodcastMaker/PodcastDashboard';
|
||||
import PricingPage from './components/Pricing/PricingPage';
|
||||
import WixTestPage from './components/WixTestPage/WixTestPage';
|
||||
import WixCallbackPage from './components/WixCallbackPage/WixCallbackPage';
|
||||
import WordPressCallbackPage from './components/WordPressCallbackPage/WordPressCallbackPage';
|
||||
import BingCallbackPage from './components/BingCallbackPage/BingCallbackPage';
|
||||
import BingAnalyticsStorage from './components/BingAnalyticsStorage/BingAnalyticsStorage';
|
||||
import ResearchDashboard from './pages/ResearchDashboard';
|
||||
import IntentResearchTest from './pages/IntentResearchTest';
|
||||
import SchedulerDashboard from './pages/SchedulerDashboard';
|
||||
import BillingPage from './pages/BillingPage';
|
||||
import ApprovalsPage from './pages/ApprovalsPage';
|
||||
import TeamActivityPage from './pages/TeamActivityPage';
|
||||
import StripeDisputesDashboard from './pages/StripeDisputesDashboard';
|
||||
import ProtectedRoute from './components/shared/ProtectedRoute';
|
||||
import GSCAuthCallback from './components/SEODashboard/components/GSCAuthCallback';
|
||||
import Landing from './components/Landing/Landing';
|
||||
import ErrorBoundary from './components/shared/ErrorBoundary';
|
||||
import ErrorBoundaryTest from './components/shared/ErrorBoundaryTest';
|
||||
import CopilotKitDegradedBanner from './components/shared/CopilotKitDegradedBanner';
|
||||
import { OnboardingProvider } from './contexts/OnboardingContext';
|
||||
import { SubscriptionProvider, useSubscription } from './contexts/SubscriptionContext';
|
||||
import { CopilotKitHealthProvider } from './contexts/CopilotKitHealthContext';
|
||||
import { useOAuthTokenAlerts } from './hooks/useOAuthTokenAlerts';
|
||||
|
||||
import { setAuthTokenGetter, setClerkSignOut } from './api/client';
|
||||
import { setMediaAuthTokenGetter } from './utils/fetchMediaBlobUrl';
|
||||
import { setBillingAuthTokenGetter } from './services/billingService';
|
||||
import { useOnboarding } from './contexts/OnboardingContext';
|
||||
import { useState, useEffect } from 'react';
|
||||
import ConnectionErrorPage from './components/shared/ConnectionErrorPage';
|
||||
import { isPodcastOnlyDemoMode } from './utils/demoMode';
|
||||
|
||||
// interface OnboardingStatus {
|
||||
// onboarding_required: boolean;
|
||||
// onboarding_complete: boolean;
|
||||
// current_step?: number;
|
||||
// total_steps?: number;
|
||||
// completion_percentage?: number;
|
||||
// }
|
||||
|
||||
// Conditional CopilotKit wrapper that only shows sidebar on content-planning route
|
||||
const ConditionalCopilotKit: React.FC<{ children: React.ReactNode }> = ({ children }) => {
|
||||
// Do not render CopilotSidebar here. Let specific pages/components control it.
|
||||
return <>{children}</>;
|
||||
};
|
||||
|
||||
// Wrapper to only enable CopilotKit checks/provider when user is authenticated
|
||||
// This prevents CopilotKit from running on the Landing page
|
||||
const AuthenticatedCopilotWrapper: React.FC<{
|
||||
children: React.ReactNode;
|
||||
apiKey: string;
|
||||
}> = ({ children, apiKey }) => {
|
||||
const { isSignedIn } = useAuth();
|
||||
const location = useLocation();
|
||||
|
||||
// Exclude CopilotKit from running on:
|
||||
// 1. Landing page (handled by !isSignedIn)
|
||||
// 2. Onboarding pages (to prevent health check timeouts)
|
||||
// 3. Podcast-only demo mode (CopilotKit not needed)
|
||||
const isPodcastOnly = isPodcastOnlyDemoMode();
|
||||
const shouldExcludeCopilot = !isSignedIn || location.pathname.startsWith('/onboarding') || isPodcastOnly;
|
||||
|
||||
if (shouldExcludeCopilot) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
const hasKey = apiKey && apiKey.trim();
|
||||
|
||||
if (hasKey) {
|
||||
// Enhanced error handler that updates health context
|
||||
const handleCopilotKitError = (e: any) => {
|
||||
console.error("CopilotKit Error:", e);
|
||||
|
||||
// Try to get health context if available
|
||||
// We'll use a custom event to notify health context since we can't access it directly here
|
||||
const errorMessage = e?.error?.message || e?.message || 'CopilotKit error occurred';
|
||||
const errorType = errorMessage.toLowerCase();
|
||||
|
||||
// Differentiate between fatal and transient errors
|
||||
const isFatalError =
|
||||
errorType.includes('cors') ||
|
||||
errorType.includes('ssl') ||
|
||||
errorType.includes('certificate') ||
|
||||
errorType.includes('403') ||
|
||||
errorType.includes('forbidden') ||
|
||||
errorType.includes('err_cert_common_name_invalid'); // errorType is lowercased above, so the match must be lowercase too
|
||||
|
||||
// Dispatch event for health context to listen to
|
||||
window.dispatchEvent(new CustomEvent('copilotkit-error', {
|
||||
detail: {
|
||||
error: e,
|
||||
errorMessage,
|
||||
isFatal: isFatalError,
|
||||
}
|
||||
}));
|
||||
};
|
||||
|
||||
return (
|
||||
<CopilotKitHealthProvider initialHealthStatus={true}>
|
||||
<CopilotKitDegradedBanner />
|
||||
<ErrorBoundary
|
||||
context="CopilotKit"
|
||||
showDetails={process.env.NODE_ENV === 'development'}
|
||||
fallback={
|
||||
<Box sx={{ p: 3, textAlign: 'center' }}>
|
||||
<Typography variant="h6" color="warning" gutterBottom>
|
||||
Chat Unavailable
|
||||
</Typography>
|
||||
<Typography variant="body2" color="textSecondary">
|
||||
CopilotKit encountered an error. The app continues to work with manual controls.
|
||||
</Typography>
|
||||
</Box>
|
||||
}
|
||||
>
|
||||
<CopilotKit
|
||||
publicApiKey={apiKey}
|
||||
showDevConsole={false}
|
||||
onError={handleCopilotKitError}
|
||||
>
|
||||
{children}
|
||||
</CopilotKit>
|
||||
</ErrorBoundary>
|
||||
</CopilotKitHealthProvider>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<CopilotKitHealthProvider initialHealthStatus={false}>
|
||||
<CopilotKitDegradedBanner />
|
||||
{children}
|
||||
</CopilotKitHealthProvider>
|
||||
);
|
||||
};
|
||||
|
||||
// Component to handle initial routing based on subscription and onboarding status
|
||||
// Flow: Subscription → Onboarding → Dashboard
|
||||
const InitialRouteHandler: React.FC = () => {
|
||||
const { loading, error, isOnboardingComplete, initializeOnboarding, data } = useOnboarding();
|
||||
const { subscription, loading: subscriptionLoading, checkSubscription } = useSubscription();
|
||||
const [connectionError, setConnectionError] = useState<{
|
||||
hasError: boolean;
|
||||
error: Error | null;
|
||||
}>({
|
||||
hasError: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
// Poll for OAuth token alerts and show toast notifications
|
||||
// Only enabled when user is authenticated (has subscription)
|
||||
useOAuthTokenAlerts({
|
||||
enabled: subscription?.active === true,
|
||||
interval: 60000, // Poll every 1 minute
|
||||
});
|
||||
|
||||
// Check subscription on mount (non-blocking - don't wait for it to route)
|
||||
useEffect(() => {
|
||||
// Delay subscription check slightly to allow auth token getter to be installed first
|
||||
const timeoutId = setTimeout(async () => {
|
||||
// Retry logic for initial subscription check
|
||||
const maxRetries = 3;
|
||||
for (let attempt = 0; attempt < maxRetries; attempt++) {
|
||||
try {
|
||||
await checkSubscription();
|
||||
break; // Success
|
||||
} catch (err) {
|
||||
console.error(`App: Subscription check attempt ${attempt + 1} failed:`, err);
|
||||
|
||||
// If it's a connection error and we have retries left, wait and retry
|
||||
const isConnectionError = err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError');
|
||||
|
||||
if (isConnectionError && attempt < maxRetries - 1) {
|
||||
const delay = 1000 * Math.pow(2, attempt); // 1s, 2s
|
||||
await new Promise(resolve => setTimeout(resolve, delay));
|
||||
continue;
|
||||
}
|
||||
|
||||
// If final attempt or not a connection error, handle it
|
||||
if (attempt === maxRetries - 1 || !isConnectionError) {
|
||||
if (isConnectionError) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err as Error,
|
||||
});
|
||||
}
|
||||
// Don't block routing on other errors
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 100); // Small delay to ensure TokenInstaller has run
|
||||
|
||||
return () => clearTimeout(timeoutId);
|
||||
}, []); // Remove checkSubscription dependency to prevent loop
|
||||
|
||||
// Initialize onboarding only after subscription is confirmed
|
||||
useEffect(() => {
|
||||
if (subscription && !subscriptionLoading) {
|
||||
// Check if user is new (no subscription record at all)
|
||||
const isNewUser = !subscription || subscription.plan === 'none';
|
||||
|
||||
console.log('InitialRouteHandler: Subscription data received:', {
|
||||
plan: subscription.plan,
|
||||
active: subscription.active,
|
||||
isNewUser,
|
||||
subscriptionLoading
|
||||
});
|
||||
|
||||
if (subscription.active && !isNewUser) {
|
||||
console.log('InitialRouteHandler: Subscription confirmed, initializing onboarding...');
|
||||
initializeOnboarding();
|
||||
}
|
||||
}
|
||||
}, [subscription, subscriptionLoading, initializeOnboarding]);
|
||||
|
||||
// Handle connection error - show connection error page
|
||||
if (connectionError.hasError) {
|
||||
const handleRetry = () => {
|
||||
setConnectionError({
|
||||
hasError: false,
|
||||
error: null,
|
||||
});
|
||||
// Re-trigger the subscription check using context
|
||||
checkSubscription().catch((err) => {
|
||||
if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err,
|
||||
});
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const handleGoHome = () => {
|
||||
window.location.href = '/';
|
||||
};
|
||||
|
||||
return (
|
||||
<ConnectionErrorPage
|
||||
onRetry={handleRetry}
|
||||
onGoHome={handleGoHome}
|
||||
message={connectionError.error?.message || "Backend service is not available. Please check if the server is running."}
|
||||
title="Connection Error"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Loading state - only wait for onboarding init, not subscription check
|
||||
// Subscription check is non-blocking and happens in background
|
||||
const waitingForOnboardingInit = loading || !data;
|
||||
if (waitingForOnboardingInit) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
{subscriptionLoading ? 'Checking subscription...' : 'Preparing your workspace...'}
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Error state
|
||||
if (error) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
p={3}
|
||||
>
|
||||
<Typography variant="h5" color="error" gutterBottom>
|
||||
Error
|
||||
</Typography>
|
||||
<Typography variant="body1" color="textSecondary" textAlign="center">
|
||||
{error}
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Decision tree for SIGNED-IN users:
|
||||
// Priority: Subscription → Onboarding → Dashboard (as per user flow: Landing → Subscription → Onboarding → Dashboard)
|
||||
|
||||
// 1. If subscription is still loading, show loading state
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// 2. No subscription data yet - handle gracefully
|
||||
// If onboarding is complete, allow access to dashboard (user already went through flow)
|
||||
// If onboarding not complete, check if subscription check is still loading or failed
|
||||
if (!subscription) {
|
||||
if (isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Onboarding complete but no subscription data → Dashboard (allow access)');
|
||||
return <Navigate to="/dashboard" replace />;
|
||||
}
|
||||
|
||||
// Onboarding not complete and no subscription data
|
||||
// If subscription check is still loading, show loading state
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Subscription check completed but returned null/undefined
|
||||
// This likely means no subscription - redirect to pricing
|
||||
console.log('InitialRouteHandler: No subscription data after check → Pricing page');
|
||||
return <Navigate to="/pricing" replace />;
|
||||
}
|
||||
|
||||
// 3. Check subscription status first
|
||||
const isNewUser = !subscription || subscription.plan === 'none';
|
||||
|
||||
// No active subscription → Show modal (SubscriptionContext handles this)
|
||||
// Don't redirect immediately - let the modal show first
|
||||
// User can click "Renew Subscription" button in modal to go to pricing
|
||||
// Or click "Maybe Later" to dismiss (but they still can't use features)
|
||||
if (isNewUser || !subscription.active) {
|
||||
console.log('InitialRouteHandler: No active subscription - modal will be shown by SubscriptionContext');
|
||||
// Note: SubscriptionContext will show the modal automatically when subscription is inactive
|
||||
// We still redirect to pricing for new users, but allow existing users with expired subscriptions
|
||||
// to see the modal first. The modal has a "Renew Subscription" button that navigates to pricing.
|
||||
// For new users (no subscription at all), redirect to pricing immediately
|
||||
if (isNewUser) {
|
||||
console.log('InitialRouteHandler: New user (no subscription) → Pricing page');
|
||||
return <Navigate to="/pricing" replace />;
|
||||
}
|
||||
// For existing users with inactive subscription, show modal but don't redirect immediately
|
||||
// The modal will be shown by SubscriptionContext, and user can click "Renew Subscription"
|
||||
// Allow access to dashboard (modal will be shown and block functionality)
|
||||
console.log('InitialRouteHandler: Inactive subscription - allowing access to show modal');
|
||||
// Continue to onboarding/dashboard flow - modal will be shown by SubscriptionContext
|
||||
}
|
||||
|
||||
// 4. Has active subscription, check onboarding status
|
||||
if (!isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Subscription active but onboarding incomplete → Onboarding');
|
||||
return <Navigate to="/onboarding" replace />;
|
||||
}
|
||||
|
||||
// 5. Has subscription AND completed onboarding → Dashboard
|
||||
console.log('InitialRouteHandler: All set (subscription + onboarding) → Dashboard');
|
||||
return <Navigate to="/dashboard" replace />;
|
||||
};
|
||||
|
||||
// Root route that chooses Landing (signed out) or InitialRouteHandler (signed in)
|
||||
const RootRoute: React.FC = () => {
|
||||
const { isSignedIn } = useAuth();
|
||||
if (isSignedIn) {
|
||||
return <InitialRouteHandler />;
|
||||
}
|
||||
return <Landing />;
|
||||
};
|
||||
|
||||
// Installs Clerk auth token getter into axios clients and stores user_id
|
||||
// Must render under ClerkProvider
|
||||
const TokenInstaller: React.FC = () => {
|
||||
const { getToken, userId, isSignedIn, signOut } = useAuth();
|
||||
|
||||
// Store user_id in localStorage when user signs in
|
||||
useEffect(() => {
|
||||
if (isSignedIn && userId) {
|
||||
console.log('TokenInstaller: Storing user_id in localStorage:', userId);
|
||||
localStorage.setItem('user_id', userId);
|
||||
|
||||
// Trigger event to notify SubscriptionContext that user is authenticated
|
||||
window.dispatchEvent(new CustomEvent('user-authenticated', { detail: { userId } }));
|
||||
} else if (!isSignedIn) {
|
||||
// Clear user_id when signed out
|
||||
console.log('TokenInstaller: Clearing user_id from localStorage');
|
||||
localStorage.removeItem('user_id');
|
||||
}
|
||||
}, [isSignedIn, userId]);
|
||||
|
||||
// Install token getter for API calls
|
||||
useEffect(() => {
|
||||
const tokenGetter = async () => {
|
||||
try {
|
||||
const template = process.env.REACT_APP_CLERK_JWT_TEMPLATE;
|
||||
// If a template is provided and it's not a placeholder, request a template-specific JWT
|
||||
if (template && template !== 'your_jwt_template_name_here') {
|
||||
// @ts-ignore Clerk types allow options object
|
||||
return await getToken({ template });
|
||||
}
|
||||
return await getToken();
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
// Set token getter for main API client
|
||||
setAuthTokenGetter(tokenGetter);
|
||||
|
||||
// Set token getter for billing API client (same function)
|
||||
setBillingAuthTokenGetter(tokenGetter);
|
||||
|
||||
// Set token getter for media blob URL fetcher (for authenticated image/video requests)
|
||||
setMediaAuthTokenGetter(tokenGetter);
|
||||
}, [getToken]);
|
||||
|
||||
// Install Clerk signOut function for handling expired tokens
|
||||
useEffect(() => {
|
||||
if (signOut) {
|
||||
setClerkSignOut(async () => {
|
||||
await signOut();
|
||||
});
|
||||
}
|
||||
}, [signOut]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const App: React.FC = () => {
|
||||
// React Hooks MUST be at the top before any conditionals
|
||||
const [loading, setLoading] = useState(true);
|
||||
|
||||
// Get CopilotKit key from localStorage or .env
|
||||
const [copilotApiKey, setCopilotApiKey] = useState(() => {
|
||||
const savedKey = localStorage.getItem('copilotkit_api_key');
|
||||
const envKey = process.env.REACT_APP_COPILOTKIT_API_KEY || '';
|
||||
const key = (savedKey || envKey).trim();
|
||||
|
||||
// Validate key format if present
|
||||
if (key && !key.startsWith('ck_pub_')) {
|
||||
console.warn('CopilotKit API key format invalid - must start with ck_pub_');
|
||||
}
|
||||
|
||||
return key;
|
||||
});
|
||||
|
||||
// Initialize app - loading state will be managed by InitialRouteHandler
|
||||
useEffect(() => {
|
||||
// No manual health check here - connection errors are handled by ErrorBoundary
|
||||
setLoading(false);
|
||||
}, []);
|
||||
|
||||
// Listen for CopilotKit key updates
|
||||
useEffect(() => {
|
||||
const handleKeyUpdate = (event: CustomEvent) => {
|
||||
const newKey = event.detail?.apiKey;
|
||||
if (newKey) {
|
||||
console.log('App: CopilotKit key updated, reloading...');
|
||||
setCopilotApiKey(newKey);
|
||||
setTimeout(() => window.location.reload(), 500);
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('copilotkit-key-updated', handleKeyUpdate as EventListener);
|
||||
return () => window.removeEventListener('copilotkit-key-updated', handleKeyUpdate as EventListener);
|
||||
}, []);
|
||||
|
||||
// Token installer must be inside ClerkProvider; see TokenInstaller above
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Connecting to ALwrity...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Get environment variables with fallbacks
|
||||
const clerkPublishableKey = process.env.REACT_APP_CLERK_PUBLISHABLE_KEY || '';
|
||||
const clerkJSUrl = process.env.REACT_APP_CLERK_JS_URL;
|
||||
|
||||
// Show error if required keys are missing
|
||||
if (!clerkPublishableKey) {
|
||||
return (
|
||||
<Box sx={{ p: 3, textAlign: 'center' }}>
|
||||
<Typography color="error" variant="h6">
|
||||
Missing Clerk Publishable Key
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ mt: 1 }}>
|
||||
Please add REACT_APP_CLERK_PUBLISHABLE_KEY to your .env file
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Render app with or without CopilotKit based on whether we have a key
|
||||
const renderApp = () => {
|
||||
return (
|
||||
<Router>
|
||||
<AuthenticatedCopilotWrapper apiKey={copilotApiKey}>
|
||||
<ConditionalCopilotKit>
|
||||
<TokenInstaller />
|
||||
<Routes>
|
||||
<Route path="/" element={<RootRoute />} />
|
||||
<Route
|
||||
path="/onboarding"
|
||||
element={
|
||||
<ErrorBoundary context="Onboarding Wizard" showDetails>
|
||||
<Wizard />
|
||||
</ErrorBoundary>
|
||||
}
|
||||
/>
|
||||
{/* Error Boundary Testing - Development Only */}
|
||||
{process.env.NODE_ENV === 'development' && (
|
||||
<Route path="/error-test" element={<ErrorBoundaryTest />} />
|
||||
)}
|
||||
<Route path="/dashboard" element={<ProtectedRoute><MainDashboard /></ProtectedRoute>} />
|
||||
<Route path="/seo" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
|
||||
<Route path="/seo-dashboard" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
|
||||
<Route path="/content-planning" element={<ProtectedRoute><ContentPlanningDashboard /></ProtectedRoute>} />
|
||||
<Route path="/facebook-writer" element={<ProtectedRoute><FacebookWriter /></ProtectedRoute>} />
|
||||
<Route path="/linkedin-writer" element={<ProtectedRoute><LinkedInWriter /></ProtectedRoute>} />
|
||||
<Route path="/blog-writer" element={<ProtectedRoute><BlogWriter /></ProtectedRoute>} />
|
||||
<Route path="/story-writer" element={<ProtectedRoute><StoryWriter /></ProtectedRoute>} />
|
||||
<Route path="/story-projects" element={<ProtectedRoute><StoryProjectList /></ProtectedRoute>} />
|
||||
<Route path="/youtube-creator" element={<ProtectedRoute><YouTubeCreator /></ProtectedRoute>} />
|
||||
<Route path="/podcast-maker" element={<ProtectedRoute><PodcastDashboard /></ProtectedRoute>} />
|
||||
<Route path="/image-studio" element={<ProtectedRoute><ImageStudioDashboard /></ProtectedRoute>} />
|
||||
<Route path="/video-studio" element={<ProtectedRoute><VideoStudioDashboard /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/create" element={<ProtectedRoute><CreateVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/avatar" element={<ProtectedRoute><AvatarVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/enhance" element={<ProtectedRoute><EnhanceVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/extend" element={<ProtectedRoute><ExtendVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/edit" element={<ProtectedRoute><EditVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/transform" element={<ProtectedRoute><TransformVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/social" element={<ProtectedRoute><SocialVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/face-swap" element={<ProtectedRoute><FaceSwap /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-translate" element={<ProtectedRoute><VideoTranslate /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-background-remover" element={<ProtectedRoute><VideoBackgroundRemover /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/add-audio-to-video" element={<ProtectedRoute><AddAudioToVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/library" element={<ProtectedRoute><LibraryVideo /></ProtectedRoute>} />
|
||||
<Route path="/image-generator" element={<ProtectedRoute><CreateStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-editor" element={<ProtectedRoute><EditStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-upscale" element={<ProtectedRoute><UpscaleStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-control" element={<ProtectedRoute><ControlStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/face-swap" element={<ProtectedRoute><FaceSwapStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/compress" element={<ProtectedRoute><CompressionStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/processing" element={<ProtectedRoute><ImageProcessingStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/social-optimizer" element={<ProtectedRoute><SocialOptimizer /></ProtectedRoute>} />
|
||||
<Route path="/asset-library" element={<ProtectedRoute><AssetLibrary /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator" element={<ProtectedRoute><ProductMarketingDashboard /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/photoshoot" element={<ProtectedRoute><ProductPhotoshootStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/animation" element={<ProtectedRoute><ProductAnimationStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/video" element={<ProtectedRoute><ProductVideoStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/avatar" element={<ProtectedRoute><ProductAvatarStudio /></ProtectedRoute>} />
|
||||
<Route path="/product-marketing" element={<Navigate to="/campaign-creator" replace />} />
|
||||
<Route path="/scheduler-dashboard" element={<ProtectedRoute><SchedulerDashboard /></ProtectedRoute>} />
|
||||
<Route path="/billing" element={<ProtectedRoute><BillingPage /></ProtectedRoute>} />
|
||||
<Route path="/approvals" element={<ProtectedRoute><ApprovalsPage /></ProtectedRoute>} />
|
||||
<Route path="/team-activity" element={<ProtectedRoute><TeamActivityPage /></ProtectedRoute>} />
|
||||
<Route path="/stripe-disputes" element={<ProtectedRoute><StripeDisputesDashboard /></ProtectedRoute>} />
|
||||
<Route path="/pricing" element={<PricingPage />} />
|
||||
<Route path="/research-test" element={<ResearchDashboard />} />
|
||||
<Route path="/research-dashboard" element={<ResearchDashboard />} />
|
||||
<Route path="/alwrity-researcher" element={<ResearchDashboard />} />
|
||||
<Route path="/intent-research" element={<IntentResearchTest />} />
|
||||
<Route path="/wix-test" element={<WixTestPage />} />
|
||||
<Route path="/wix-test-direct" element={<WixTestPage />} />
|
||||
<Route path="/wix/callback" element={<WixCallbackPage />} />
|
||||
<Route path="/wp/callback" element={<WordPressCallbackPage />} />
|
||||
<Route path="/gsc/callback" element={<GSCAuthCallback />} />
|
||||
<Route path="/bing/callback" element={<BingCallbackPage />} />
|
||||
<Route path="/bing-analytics-storage" element={<ProtectedRoute><BingAnalyticsStorage /></ProtectedRoute>} />
|
||||
</Routes>
|
||||
</ConditionalCopilotKit>
|
||||
</AuthenticatedCopilotWrapper>
|
||||
</Router>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<ErrorBoundary
|
||||
context="Application Root"
|
||||
showDetails={process.env.NODE_ENV === 'development'}
|
||||
onError={(error, errorInfo) => {
|
||||
// Custom error handler - send to analytics/monitoring
|
||||
console.error('Global error caught:', { error, errorInfo });
|
||||
// TODO: Send to error tracking service (Sentry, LogRocket, etc.)
|
||||
}}
|
||||
>
|
||||
<ClerkProvider publishableKey={clerkPublishableKey} clerkJSUrl={clerkJSUrl}>
|
||||
<SubscriptionProvider>
|
||||
<OnboardingProvider>
|
||||
{renderApp()}
|
||||
</OnboardingProvider>
|
||||
</SubscriptionProvider>
|
||||
</ClerkProvider>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
||||
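The signed-in decision tree in `InitialRouteHandler` above is easier to check when it is written as a pure function. The following is a minimal sketch, not the project's code: `SubscriptionLike`, `RouteInputs`, `RouteDecision`, and `resolveInitialRoute` are hypothetical names introduced for illustration. It covers only the branches visible in this diff (subscription loading, missing subscription data, plan `'none'`, inactive subscription, onboarding status) and assumes any earlier loading or error states have already been handled.

```typescript
// Hypothetical shape: the real subscription type lives in SubscriptionContext
// and is not shown in this diff. Only `plan` and `active` are read by the handler.
interface SubscriptionLike {
  plan: string;
  active: boolean;
}

interface RouteInputs {
  subscriptionLoading: boolean;
  subscription: SubscriptionLike | null | undefined;
  isOnboardingComplete: boolean;
}

type RouteDecision = "loading" | "/pricing" | "/onboarding" | "/dashboard";

function resolveInitialRoute(input: RouteInputs): RouteDecision {
  const { subscriptionLoading, subscription, isOnboardingComplete } = input;

  // 1. Subscription check still in flight: show the spinner.
  if (subscriptionLoading) return "loading";

  // 2. Check finished without data: returning users go on, everyone else sees pricing.
  if (!subscription) return isOnboardingComplete ? "/dashboard" : "/pricing";

  // 3. New user (plan 'none'): pricing first.
  if (subscription.plan === "none") return "/pricing";

  // An existing but inactive subscription intentionally falls through here;
  // in the original, SubscriptionContext shows the renewal modal.

  // 4 and 5. Onboarding gate, then dashboard.
  return isOnboardingComplete ? "/dashboard" : "/onboarding";
}
```

Written this way, the second `subscriptionLoading` check inside the `!subscription` branch of the original can never run, because step 1 has already returned for that case.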
537
_session_backup/ResearchSummary.tsx
Normal file
@@ -0,0 +1,537 @@
|
||||
import React, { useMemo, useCallback } from "react";
|
||||
import { Stack, Typography, Chip, Divider, Box, alpha, Paper, Tooltip } from "@mui/material";
|
||||
import {
|
||||
Insights as InsightsIcon,
|
||||
Search as SearchIcon,
|
||||
AttachMoney as AttachMoneyIcon,
|
||||
EditNote as EditNoteIcon,
|
||||
Article as ArticleIcon,
|
||||
AutoAwesome as AutoAwesomeIcon,
|
||||
FormatQuote as FormatQuoteIcon,
|
||||
Campaign as CampaignIcon,
|
||||
Explore as ExploreIcon,
|
||||
} from "@mui/icons-material";
|
||||
import { Research, ResearchInsight } from "../types";
|
||||
import { GlassyCard, glassyCardSx, PrimaryButton } from "../ui";
|
||||
import { FactCard } from "../FactCard";
|
||||
|
||||
interface ResearchSummaryProps {
|
||||
research: Research;
|
||||
canGenerateScript: boolean;
|
||||
onGenerateScript: () => void;
|
||||
}
|
||||
|
||||
export const ResearchSummary: React.FC<ResearchSummaryProps> = ({
|
||||
research,
|
||||
canGenerateScript,
|
||||
onGenerateScript,
|
||||
}) => {
|
||||
// Minimal Markdown renderer: handles **bold**, "- " / "* " list items, and "## " / "### " headings
|
||||
const renderMarkdown = useCallback((text: string) => {
|
||||
if (!text) return null;
|
||||
return text
|
||||
.split('\n')
|
||||
.filter(line => line.trim() !== '') // Remove empty lines
|
||||
.map((line, i) => {
|
||||
// Handle bold
|
||||
let processedLine = line.replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>');
|
||||
// Handle lists
|
||||
if (processedLine.trim().startsWith('- ') || processedLine.trim().startsWith('* ')) {
|
||||
return <li key={i} dangerouslySetInnerHTML={{ __html: processedLine.trim().substring(2) }} style={{ marginBottom: '4px', fontSize: '0.9rem' }} />;
|
||||
}
|
||||
// Handle headers - make them smaller
|
||||
if (processedLine.startsWith('### ')) {
|
||||
return <Typography key={i} variant="subtitle2" fontWeight={700} sx={{ mt: 1.5, mb: 0.5, color: '#1e293b' }}>{processedLine.substring(4)}</Typography>;
|
||||
}
|
||||
if (processedLine.startsWith('## ')) {
|
||||
return <Typography key={i} variant="subtitle1" fontWeight={700} sx={{ mt: 1.5, mb: 0.5, color: '#0f172a' }}>{processedLine.substring(3)}</Typography>;
|
||||
}
|
||||
// Paragraphs - compact spacing
|
||||
return processedLine.trim() ? <p key={i} dangerouslySetInnerHTML={{ __html: processedLine }} style={{ margin: '4px 0', fontSize: '0.9rem' }} /> : null;
|
||||
});
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<GlassyCard sx={glassyCardSx}>
|
||||
<Stack spacing={3}>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center" flexWrap="wrap" gap={2}>
|
||||
<Stack direction="row" alignItems="center" spacing={2} sx={{ flex: 1 }}>
|
||||
<Typography variant="h6" sx={{ display: "flex", alignItems: "center", gap: 1, color: "#0f172a", fontWeight: 700 }}>
|
||||
<InsightsIcon />
|
||||
Research Summary
|
||||
</Typography>
|
||||
|
||||
{/* Research Metadata - Moved alongside title */}
|
||||
<Stack direction="row" spacing={1.5} flexWrap="wrap">
|
||||
{research.searchQueries && research.searchQueries.length > 0 && (
|
||||
<Chip
|
||||
icon={<SearchIcon sx={{ fontSize: "1rem !important" }} />}
|
||||
label={`${research.searchQueries.length} search${research.searchQueries.length > 1 ? "es" : ""}`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#667eea", 0.1),
|
||||
color: "#667eea",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(102, 126, 234, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{research.searchType && (
|
||||
<Chip
|
||||
label={`${research.searchType.charAt(0).toUpperCase() + research.searchType.slice(1)} search`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#10b981", 0.1),
|
||||
color: "#059669",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(16, 185, 129, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{research.sourceCount !== undefined && (
|
||||
<Chip
|
||||
label={`${research.sourceCount} source${research.sourceCount !== 1 ? "s" : ""}`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#6366f1", 0.1),
|
||||
color: "#4f46e5",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(99, 102, 241, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{research.cost !== undefined && (
|
||||
<Chip
|
||||
icon={<AttachMoneyIcon sx={{ fontSize: "0.875rem !important" }} />}
|
||||
label={`$${research.cost.toFixed(3)}`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#f59e0b", 0.1),
|
||||
color: "#d97706",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(245, 158, 11, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Stack>
|
||||
</Stack>
|
||||
|
||||
<PrimaryButton
|
||||
onClick={onGenerateScript}
|
||||
disabled={!canGenerateScript}
|
||||
startIcon={<EditNoteIcon />}
|
||||
tooltip={!canGenerateScript ? "Complete research to generate script" : "Generate AI-powered script from research"}
|
||||
>
|
||||
Generate Script
|
||||
</PrimaryButton>
|
||||
</Stack>
|
||||
|
||||
<Box sx={{ width: "100%" }}>
|
||||
{/* Main Summary */}
|
||||
{research.summary && (
|
||||
<Paper
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
mb: 3,
|
||||
background: "#f8fafc",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Typography variant="subtitle2" sx={{ mb: 1.5, color: "#64748b", fontWeight: 700, fontSize: "0.75rem", textTransform: "uppercase", letterSpacing: "0.05em", display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<AutoAwesomeIcon fontSize="small" sx={{ color: "#667eea", fontSize: "1rem" }} />
|
||||
Executive Summary
|
||||
</Typography>
|
||||
<Box sx={{
|
||||
lineHeight: 1.6,
|
||||
fontSize: "0.9rem",
|
||||
color: "#334155",
|
||||
"& p": { m: 0, mb: 1 },
|
||||
"& ul": { m: 0, mb: 1, pl: 2.5 },
|
||||
"& li": { mb: 0.5 },
|
||||
"& strong": { color: "#0f172a", fontWeight: 600 }
|
||||
}}>
|
||||
{renderMarkdown(research.summary)}
|
||||
</Box>
|
||||
</Paper>
|
||||
)}
|
||||
|
||||
{/* Deep Insights */}
|
||||
{(research.keyInsights && research.keyInsights.length > 0) ? (
|
||||
<Box sx={{ mb: 4 }}>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<ArticleIcon sx={{ color: "#667eea" }} />
|
||||
Deep Insights
|
||||
</Typography>
|
||||
<Stack spacing={2.5}>
|
||||
{research.keyInsights.map((insight: ResearchInsight, idx: number) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "#ffffff",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
boxShadow: "0 2px 12px rgba(0,0,0,0.03)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="flex-start" sx={{ mb: 1.5 }}>
|
||||
<Typography variant="subtitle1" sx={{ color: "#0f172a", fontWeight: 700 }}>
|
||||
{insight.title}
|
||||
</Typography>
|
||||
{insight.source_indices && insight.source_indices.length > 0 && (
|
||||
<Stack direction="row" spacing={0.5}>
|
||||
{insight.source_indices.map(sIdx => {
|
||||
const sourceIdx = sIdx - 1;
|
||||
const fact = research.factCards[sourceIdx];
|
||||
const sourceUrl = fact?.url;
|
||||
const hasUrl = !!sourceUrl;
|
||||
const hue = (sIdx * 47 + 220) % 360;
|
||||
const gradientFrom = `hsl(${hue}, 70%, 55%)`;
|
||||
const gradientTo = `hsl(${(hue + 30) % 360}, 80%, 65%)`;
|
||||
return (
|
||||
<Tooltip
|
||||
key={sIdx}
|
||||
title={hasUrl ? (
|
||||
<Box sx={{ maxWidth: 300, wordBreak: "break-all" }}>
|
||||
<Typography variant="caption" sx={{ color: "#fff", fontWeight: 600 }}>Source {sIdx}</Typography>
|
||||
<br />
|
||||
<Typography variant="caption" sx={{ color: "rgba(255,255,255,0.8)", fontSize: "0.65rem" }}>{sourceUrl}</Typography>
|
||||
</Box>
|
||||
) : `Source ${sIdx}`}
|
||||
arrow
|
||||
placement="top"
|
||||
>
|
||||
<Chip
|
||||
label={hasUrl ? `S${sIdx} ↗` : `S${sIdx}`}
|
||||
size="small"
|
||||
onClick={hasUrl ? () => window.open(sourceUrl, "_blank", "noopener,noreferrer") : undefined}
|
||||
sx={{
|
||||
height: 24,
|
||||
minWidth: 36,
|
||||
fontSize: '0.7rem',
|
||||
fontWeight: 800,
|
||||
fontFamily: "'Inter', 'Roboto', monospace",
|
||||
letterSpacing: "0.02em",
|
||||
border: "none",
|
||||
background: hasUrl
|
||||
? `linear-gradient(135deg, ${gradientFrom}, ${gradientTo})`
|
||||
: `linear-gradient(135deg, ${alpha(gradientFrom, 0.3)}, ${alpha(gradientTo, 0.3)})`,
|
||||
color: hasUrl ? "#fff" : alpha("#fff", 0.7),
|
||||
cursor: hasUrl ? "pointer" : "default",
|
||||
borderRadius: "8px",
|
||||
px: 0.5,
|
||||
boxShadow: hasUrl
|
||||
? `0 2px 8px ${alpha(gradientFrom, 0.35)}, inset 0 1px 0 ${alpha("#fff", 0.2)}`
|
||||
: "none",
|
||||
transition: "all 0.2s ease",
|
||||
"&:hover": hasUrl ? {
|
||||
background: `linear-gradient(135deg, ${gradientTo}, ${gradientFrom})`,
|
||||
boxShadow: `0 4px 14px ${alpha(gradientFrom, 0.5)}, inset 0 1px 0 ${alpha("#fff", 0.3)}`,
|
||||
transform: "translateY(-1px)",
|
||||
} : {},
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</Stack>
|
||||
)}
|
||||
</Stack>
|
||||
<Box sx={{
|
||||
color: "#475569",
|
||||
lineHeight: 1.7,
|
||||
fontSize: "0.9rem",
|
||||
"& p": { m: 0, mb: 1.5 },
|
||||
"& ul": { m: 0, mb: 1.5, pl: 2 }
|
||||
}}>
|
||||
{renderMarkdown(insight.content)}
|
||||
</Box>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
) : (
|
||||
/* Fallback if keyInsights is missing but we have summary paragraphs */
|
||||
research.summary && research.summary.length > 500 && !research.keyInsights && (
|
||||
<Box sx={{ mb: 4 }}>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<ArticleIcon sx={{ color: "#667eea" }} />
|
||||
Additional Insights
|
||||
</Typography>
|
||||
<Paper
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "#ffffff",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
boxShadow: "0 2px 12px rgba(0,0,0,0.03)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Box sx={{
|
||||
color: "#475569",
|
||||
lineHeight: 1.7,
|
||||
fontSize: "0.9rem",
|
||||
}}>
|
||||
{/* Render parts of summary that might contain insights if structured data is missing */}
|
||||
{renderMarkdown(research.summary.split('\n\n').slice(1).join('\n\n'))}
|
||||
</Box>
|
||||
</Paper>
|
||||
</Box>
|
||||
)
|
||||
)}
|
||||
|
||||
{/* Expert Quotes Section */}
|
||||
{research.expertQuotes && research.expertQuotes.length > 0 && (
|
||||
<Box sx={{ mt: 4, pt: 3, borderTop: "1px solid rgba(0,0,0,0.04)" }}>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<FormatQuoteIcon sx={{ color: "#8b5cf6" }} />
|
||||
Expert Quotes ({research.expertQuotes.length})
|
||||
</Typography>
|
||||
<Stack spacing={2}>
|
||||
{research.expertQuotes.map((eq, idx) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "linear-gradient(135deg, rgba(139, 92, 246, 0.04) 0%, rgba(99, 102, 241, 0.04) 100%)",
|
||||
border: "1px solid rgba(139, 92, 246, 0.15)",
|
||||
borderLeft: "4px solid #8b5cf6",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" spacing={1.5} alignItems="flex-start">
|
||||
<FormatQuoteIcon sx={{ color: "#8b5cf6", fontSize: "1.5rem", mt: -0.5, opacity: 0.7 }} />
|
||||
<Box sx={{ flex: 1 }}>
|
||||
<Typography variant="body2" sx={{ color: "#1e293b", fontStyle: "italic", lineHeight: 1.7, fontSize: "0.95rem" }}>
|
||||
“{eq.quote}”
|
||||
</Typography>
|
||||
{eq.source_index !== undefined && (() => {
|
||||
const fact = research.factCards[eq.source_index - 1];
|
||||
const sourceUrl = fact?.url;
|
||||
const hasUrl = !!sourceUrl;
|
||||
const hue = (eq.source_index * 47 + 270) % 360;
|
||||
const gradientFrom = `hsl(${hue}, 70%, 55%)`;
|
||||
const gradientTo = `hsl(${(hue + 30) % 360}, 80%, 65%)`;
|
||||
return (
|
||||
<Box sx={{ mt: 1 }}>
|
||||
<Tooltip title={hasUrl ? (
|
||||
<Box sx={{ maxWidth: 300, wordBreak: "break-all" }}>
|
||||
<Typography variant="caption" sx={{ color: "#fff", fontWeight: 600 }}>Source {eq.source_index}</Typography>
|
||||
<br />
|
||||
<Typography variant="caption" sx={{ color: "rgba(255,255,255,0.8)", fontSize: "0.65rem" }}>{sourceUrl}</Typography>
|
||||
</Box>
|
||||
) : `Source ${eq.source_index}`} arrow placement="top">
|
||||
<Chip
|
||||
label={hasUrl ? `Source ${eq.source_index} ↗` : `Source ${eq.source_index}`}
|
||||
size="small"
|
||||
onClick={hasUrl ? () => window.open(sourceUrl, "_blank", "noopener,noreferrer") : undefined}
|
||||
sx={{
|
||||
height: 24,
|
||||
fontSize: "0.7rem",
|
||||
fontWeight: 800,
|
||||
fontFamily: "'Inter', 'Roboto', monospace",
|
||||
border: "none",
|
||||
background: hasUrl
|
||||
? `linear-gradient(135deg, ${gradientFrom}, ${gradientTo})`
|
||||
: `linear-gradient(135deg, ${alpha(gradientFrom, 0.3)}, ${alpha(gradientTo, 0.3)})`,
|
||||
color: hasUrl ? "#fff" : alpha("#fff", 0.7),
|
||||
cursor: hasUrl ? "pointer" : "default",
|
||||
borderRadius: "8px",
|
||||
px: 1,
|
||||
boxShadow: hasUrl
|
||||
? `0 2px 8px ${alpha(gradientFrom, 0.35)}, inset 0 1px 0 ${alpha("#fff", 0.2)}`
|
||||
: "none",
|
||||
transition: "all 0.2s ease",
|
||||
"&:hover": hasUrl ? {
|
||||
background: `linear-gradient(135deg, ${gradientTo}, ${gradientFrom})`,
|
||||
boxShadow: `0 4px 14px ${alpha(gradientFrom, 0.5)}, inset 0 1px 0 ${alpha("#fff", 0.3)}`,
|
||||
transform: "translateY(-1px)",
|
||||
} : {},
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
);
|
||||
})()}
|
||||
</Box>
|
||||
</Stack>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{/* Search Queries Used */}
|
||||
{research.searchQueries && research.searchQueries.length > 0 && (
|
||||
<Box sx={{ mt: 4, pt: 3, borderTop: "1px solid rgba(0,0,0,0.04)" }}>
|
||||
<Typography variant="subtitle2" sx={{ mb: 1.5, color: "#64748b", fontWeight: 700, fontSize: "0.7rem", textTransform: "uppercase", letterSpacing: "0.05em" }}>
|
||||
Search Queries Used
|
||||
</Typography>
|
||||
<Stack direction="row" spacing={1} flexWrap="wrap" useFlexGap>
|
||||
{research.searchQueries.map((query, idx) => (
|
||||
<Chip
|
||||
key={idx}
|
||||
label={query}
|
||||
size="small"
|
||||
variant="outlined"
|
||||
sx={{
|
||||
borderColor: "rgba(102, 126, 234, 0.15)",
|
||||
color: "#94a3b8",
|
||||
background: alpha("#f8fafc", 0.3),
|
||||
fontSize: "0.7rem",
|
||||
borderRadius: 1,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{research.factCards.length > 0 && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(0,0,0,0.08)" }} />
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center" sx={{ mb: 1.5, flexWrap: "wrap", gap: 1 }}>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600 }}>
|
||||
Research Sources & Facts ({research.factCards.length})
|
||||
</Typography>
|
||||
<Typography variant="caption" sx={{ color: "#64748b", fontSize: "0.75rem" }}>
|
||||
Click to expand • Hover to see source
|
||||
</Typography>
|
||||
</Stack>
|
||||
<Box
|
||||
sx={{
|
||||
display: "grid",
|
||||
gridTemplateColumns: { xs: "1fr", sm: "repeat(2, 1fr)", md: "repeat(3, 1fr)", lg: "repeat(4, 1fr)" },
|
||||
gap: 1.5,
|
||||
width: "100%",
|
||||
overflow: "hidden",
|
||||
}}
|
||||
>
|
||||
{research.factCards.map((fact) => (
|
||||
<FactCard key={fact.id} fact={fact} />
|
||||
))}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Listener CTA Section */}
|
||||
{research.listenerCta && research.listenerCta.length > 0 && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(0,0,0,0.08)" }} />
|
||||
<Box>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<CampaignIcon sx={{ color: "#f59e0b" }} />
|
||||
Listener Call-to-Action Ideas ({research.listenerCta.length})
|
||||
</Typography>
|
||||
<Stack spacing={1.5}>
|
||||
{research.listenerCta.map((cta, idx) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2,
|
||||
background: "linear-gradient(135deg, rgba(245, 158, 11, 0.05) 0%, rgba(251, 191, 36, 0.05) 100%)",
|
||||
border: "1px solid rgba(245, 158, 11, 0.15)",
|
||||
borderRadius: 2,
|
||||
display: "flex",
|
||||
alignItems: "flex-start",
|
||||
gap: 1.5,
|
||||
}}
|
||||
>
|
||||
<Chip
|
||||
label={`#${idx + 1}`}
|
||||
size="small"
|
||||
sx={{
|
||||
bgcolor: alpha("#f59e0b", 0.15),
|
||||
color: "#b45309",
|
||||
fontWeight: 700,
|
||||
fontSize: "0.7rem",
|
||||
height: 24,
|
||||
minWidth: 32,
|
||||
}}
|
||||
/>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.6, flex: 1, pt: 0.2 }}>
|
||||
{cta}
|
||||
</Typography>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Mapped Angles Section */}
|
||||
{research.mappedAngles && research.mappedAngles.length > 0 && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(0,0,0,0.08)" }} />
|
||||
<Box>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<ExploreIcon sx={{ color: "#06b6d4" }} />
|
||||
Content Angles ({research.mappedAngles.length})
|
||||
</Typography>
|
||||
<Stack spacing={2}>
|
||||
{research.mappedAngles.map((angle, idx) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "#ffffff",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
borderLeft: "4px solid #06b6d4",
|
||||
boxShadow: "0 2px 12px rgba(0,0,0,0.03)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="flex-start" sx={{ mb: 1 }}>
|
||||
<Typography variant="subtitle1" sx={{ color: "#0f172a", fontWeight: 700 }}>
|
||||
{angle.title}
|
||||
</Typography>
|
||||
{angle.mappedFactIds && angle.mappedFactIds.length > 0 && (
|
||||
<Stack direction="row" spacing={0.5}>
|
||||
{angle.mappedFactIds.slice(0, 4).map((fid: string) => (
|
||||
<Chip
|
||||
key={fid}
|
||||
label={fid.replace("fact_", "F")}
|
||||
size="small"
|
||||
variant="outlined"
|
||||
sx={{
|
||||
height: 18,
|
||||
fontSize: "0.6rem",
|
||||
fontWeight: 700,
|
||||
borderColor: alpha("#06b6d4", 0.3),
|
||||
color: "#06b6d4",
|
||||
bgcolor: alpha("#06b6d4", 0.05),
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
{angle.mappedFactIds.length > 4 && (
|
||||
<Chip
|
||||
label={`+${angle.mappedFactIds.length - 4}`}
|
||||
size="small"
|
||||
sx={{ height: 18, fontSize: "0.6rem", color: "#64748b" }}
|
||||
/>
|
||||
)}
|
||||
</Stack>
|
||||
)}
|
||||
</Stack>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7, fontSize: "0.9rem" }}>
|
||||
{angle.why}
|
||||
</Typography>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
</Stack>
|
||||
</GlassyCard>
|
||||
);
|
||||
};
|
||||
|
||||
811
_session_backup/SceneEditor.tsx
Normal file
@@ -0,0 +1,811 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Stack, Box, Typography, Divider, Chip, alpha, CircularProgress, LinearProgress, IconButton, Tooltip } from "@mui/material";
|
||||
import {
|
||||
EditNote as EditNoteIcon,
|
||||
CheckCircle as CheckCircleIcon,
|
||||
RadioButtonUnchecked as RadioButtonUncheckedIcon,
|
||||
VolumeUp as VolumeUpIcon,
|
||||
PlayArrow as PlayArrowIcon,
|
||||
Image as ImageIcon,
|
||||
Delete as DeleteIcon,
|
||||
} from "@mui/icons-material";
|
||||
import { Scene, Line, Knobs } from "../types";
|
||||
import { GlassyCard, glassyCardSx, PrimaryButton } from "../ui";
|
||||
import { LineEditor } from "./LineEditor";
|
||||
import { ImageRegenerateModal, ImageGenerationSettings } from "./ImageRegenerateModal";
|
||||
import { AudioRegenerateModal, AudioGenerationSettings } from "./AudioRegenerateModal";
|
||||
import { podcastApi } from "../../../services/podcastApi";
|
||||
import { aiApiClient } from "../../../api/client";
|
||||
import { getCachedMedia, setCachedMedia } from "../../../utils/mediaCache";
|
||||
|
||||
interface SceneEditorProps {
|
||||
scene: Scene;
|
||||
onUpdateScene: (s: Scene) => void;
|
||||
onApprove: (id: string) => Promise<void>;
|
||||
onDelete: (sceneId: string) => void;
|
||||
knobs: Knobs;
|
||||
approvingSceneId?: string | null;
|
||||
generatingAudioId?: string | null;
|
||||
onAudioGenerationStart?: (sceneId: string) => void;
|
||||
onAudioGenerated?: (sceneId: string, audioUrl: string) => void;
|
||||
idea?: string; // Podcast idea for image generation context
|
||||
avatarUrl?: string | null; // Base avatar URL for consistent scene image generation
|
||||
totalScenes?: number; // Total number of scenes in the script
|
||||
}
|
||||
|
||||
export const SceneEditor: React.FC<SceneEditorProps> = ({
|
||||
scene,
|
||||
onUpdateScene,
|
||||
onApprove,
|
||||
onDelete,
|
||||
knobs,
|
||||
approvingSceneId,
|
||||
generatingAudioId,
|
||||
onAudioGenerationStart,
|
||||
onAudioGenerated,
|
||||
idea,
|
||||
avatarUrl,
|
||||
totalScenes,
|
||||
}) => {
|
||||
const [localGenerating, setLocalGenerating] = useState(false);
|
||||
const [generatingImage, setGeneratingImage] = useState(false);
|
||||
const [imageGenerationStatus, setImageGenerationStatus] = useState<string>("");
|
||||
const [imageGenerationProgress, setImageGenerationProgress] = useState<number>(0);
|
||||
const [audioBlobUrl, setAudioBlobUrl] = useState<string | null>(null);
|
||||
const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null);
|
||||
const [imageLoading, setImageLoading] = useState(false);
|
||||
const [showRegenerateModal, setShowRegenerateModal] = useState(false);
|
||||
const [showAudioModal, setShowAudioModal] = useState(false);
|
||||
const [audioSettings, setAudioSettings] = useState<AudioGenerationSettings>({
|
||||
voiceId: "Wise_Woman",
|
||||
speed: 1.0,
|
||||
volume: 1.0,
|
||||
pitch: 0.0,
|
||||
emotion: scene.emotion || "neutral",
|
||||
englishNormalization: true,
|
||||
sampleRate: 24000,
|
||||
bitrate: 64000,
|
||||
channel: "1",
|
||||
format: "mp3",
|
||||
languageBoost: "auto",
|
||||
});
|
||||
|
||||
// Load audio as blob when audioUrl is available
|
||||
useEffect(() => {
|
||||
if (!scene.audioUrl) {
|
||||
// Clean up blob URL if audioUrl is removed
|
||||
setAudioBlobUrl((currentBlobUrl) => {
|
||||
if (currentBlobUrl) {
|
||||
URL.revokeObjectURL(currentBlobUrl);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let isMounted = true;
|
||||
const currentAudioUrl = scene.audioUrl; // Capture current value
|
||||
|
||||
const loadAudioBlob = async () => {
|
||||
try {
|
||||
// Normalize path
|
||||
let audioPath = currentAudioUrl.startsWith('/') ? currentAudioUrl : `/${currentAudioUrl}`;
|
||||
|
||||
// Convert /api/story/audio/ to /api/podcast/audio/ if needed
|
||||
if (audioPath.includes('/api/story/audio/')) {
|
||||
const filename = audioPath.split('/api/story/audio/').pop() || '';
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Ensure it's a podcast audio endpoint
|
||||
if (!audioPath.includes('/api/podcast/audio/')) {
|
||||
const filename = audioPath.split('/').pop() || currentAudioUrl;
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
audioPath = audioPath.split('?')[0];
|
||||
|
||||
const response = await aiApiClient.get(audioPath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
if (!isMounted) {
  // Effect was cleaned up (unmount or audioUrl change); discard this response.
  // The cleanup flag is the reliable signal: `scene.audioUrl` read inside this
  // closure is always the value captured when the effect ran.
  return;
}
|
||||
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
setAudioBlobUrl((prevBlobUrl) => {
|
||||
// Clean up previous blob URL if exists
|
||||
if (prevBlobUrl && prevBlobUrl !== blobUrl) {
|
||||
URL.revokeObjectURL(prevBlobUrl);
|
||||
}
|
||||
return blobUrl;
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Failed to load audio blob for scene ${scene.id}:`, error);
|
||||
// Don't set blob URL on error - will show error state
|
||||
}
|
||||
};
|
||||
|
||||
loadAudioBlob();
|
||||
|
||||
// Cleanup: only flag this run as stale; the blob URL is not revoked here.
// A previous blob URL is revoked when a new one replaces it or when audioUrl is cleared.
// Note: nothing in this effect revokes the last blob URL on unmount.
|
||||
return () => {
|
||||
isMounted = false;
|
||||
};
|
||||
}, [scene.audioUrl, scene.id]);
|
||||
|
||||
// Load image as blob when imageUrl is available
|
||||
useEffect(() => {
|
||||
if (!scene.imageUrl) {
|
||||
// Clean up blob URL if imageUrl is removed
|
||||
setImageBlobUrl((currentBlobUrl) => {
|
||||
if (currentBlobUrl && currentBlobUrl.startsWith('blob:')) {
|
||||
URL.revokeObjectURL(currentBlobUrl);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Check cache first with scene context
|
||||
const cachedUrl = getCachedMedia(scene.imageUrl, scene.id);
|
||||
if (cachedUrl) {
|
||||
console.log('[SceneEditor] Using cached image:', scene.imageUrl, `(scene: ${scene.id})`);
|
||||
setImageBlobUrl(cachedUrl);
|
||||
setImageLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
let isMounted = true;
|
||||
const currentImageUrl = scene.imageUrl; // Capture current value
|
||||
|
||||
const loadImageBlob = async () => {
|
||||
try {
|
||||
setImageLoading(true);
|
||||
|
||||
// Check cache again in case it was loaded while we were waiting
|
||||
const cachedUrl = getCachedMedia(currentImageUrl, scene.id);
|
||||
if (cachedUrl) {
|
||||
if (isMounted) {
|
||||
setImageBlobUrl(cachedUrl);
|
||||
setImageLoading(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('[SceneEditor] Loading image blob for:', currentImageUrl);
|
||||
|
||||
// Normalize path
|
||||
let imagePath = currentImageUrl.startsWith('/') ? currentImageUrl : `/${currentImageUrl}`;
|
||||
|
||||
// Convert /api/story/images/ to /api/podcast/images/ if needed
|
||||
if (imagePath.includes('/api/story/images/')) {
|
||||
const filename = imagePath.split('/api/story/images/').pop() || '';
|
||||
imagePath = `/api/podcast/images/${filename}`;
|
||||
}
|
||||
|
||||
// Ensure it's a podcast image endpoint
|
||||
if (!imagePath.includes('/api/podcast/images/')) {
|
||||
const filename = imagePath.split('/').pop() || currentImageUrl;
|
||||
imagePath = `/api/podcast/images/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
imagePath = imagePath.split('?')[0];
|
||||
|
||||
const response = await aiApiClient.get(imagePath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
if (!isMounted) {
  // Effect was cleaned up (unmount or imageUrl change); discard this response.
  // `scene.imageUrl` read inside this closure always equals currentImageUrl,
  // so the cleanup flag is the only check needed.
  return;
}
|
||||
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
// Cache the blob URL with scene context
|
||||
setCachedMedia(currentImageUrl, blobUrl, 'image', blob.size, scene.id);
|
||||
|
||||
setImageBlobUrl((prevBlobUrl) => {
|
||||
// Clean up previous blob URL if exists
|
||||
if (prevBlobUrl && prevBlobUrl !== blobUrl && prevBlobUrl.startsWith('blob:')) {
|
||||
URL.revokeObjectURL(prevBlobUrl);
|
||||
}
|
||||
return blobUrl;
|
||||
});
|
||||
console.log('[SceneEditor] Image blob loaded and cached successfully:', currentImageUrl);
|
||||
} catch (error) {
|
||||
console.error('[SceneEditor] Failed to load image blob:', error);
|
||||
if (isMounted) {
|
||||
// Try adding query token as fallback
|
||||
try {
|
||||
const token = localStorage.getItem('clerk_dashboard_token') || '';
|
||||
if (token) {
|
||||
const urlWithToken = `${currentImageUrl}?token=${encodeURIComponent(token)}`;
|
||||
setImageBlobUrl(urlWithToken);
|
||||
setCachedMedia(currentImageUrl, urlWithToken, 'image', undefined, scene.id);
|
||||
}
|
||||
} catch (fallbackError) {
|
||||
console.error('[SceneEditor] Fallback image loading failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (isMounted) {
|
||||
setImageLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
loadImageBlob();
|
||||
|
||||
return () => {
|
||||
isMounted = false;
|
||||
// Don't cleanup blob URL here - let the cache handle it
|
||||
};
|
||||
}, [scene.imageUrl, scene.id]);
|
||||
|
||||
const updateLine = (updatedLine: Line) => {
|
||||
const updated = { ...scene, lines: scene.lines.map((l) => (l.id === updatedLine.id ? updatedLine : l)) };
|
||||
onUpdateScene(updated);
|
||||
};
|
||||
|
||||
const approving = approvingSceneId === scene.id;
|
||||
const generating = generatingAudioId === scene.id || localGenerating;
|
||||
const hasAudio = Boolean(scene.audioUrl && audioBlobUrl);
|
||||
const hasImage = Boolean(scene.imageUrl);
|
||||
|
||||
const handleApproveAndGenerate = async (settings?: AudioGenerationSettings) => {
|
||||
const wasAlreadyApproved = scene.approved;
|
||||
const sceneId = scene.id;
|
||||
|
||||
try {
|
||||
// Set generating state
|
||||
setLocalGenerating(true);
|
||||
if (onAudioGenerationStart) {
|
||||
onAudioGenerationStart(sceneId);
|
||||
}
|
||||
|
||||
// If scene is not approved yet, approve it first
|
||||
// This will update the parent script state
|
||||
if (!scene.approved) {
|
||||
await onApprove(sceneId);
|
||||
// The parent's approveScene already updated the script state
|
||||
// We need to wait for React to propagate the updated scene prop
|
||||
// For now, we'll update it locally too to ensure UI updates immediately
|
||||
onUpdateScene({ ...scene, approved: true });
|
||||
}
|
||||
|
||||
// Use the current scene (which should now be approved)
|
||||
// If scene prop hasn't updated yet, use the local update we just made
|
||||
const currentScene = { ...scene, approved: true };
|
||||
|
||||
// Generate audio
|
||||
const effectiveSettings = settings || audioSettings;
|
||||
const result = await podcastApi.renderSceneAudio({
|
||||
scene: currentScene,
|
||||
voiceId: effectiveSettings.voiceId || "Wise_Woman",
|
||||
emotion: effectiveSettings.emotion || scene.emotion || knobs.voice_emotion || "neutral",
|
||||
speed: effectiveSettings.speed ?? knobs.voice_speed ?? 1.0,
|
||||
volume: effectiveSettings.volume ?? 1.0,
|
||||
pitch: effectiveSettings.pitch ?? 0.0,
|
||||
englishNormalization: effectiveSettings.englishNormalization ?? true,
|
||||
sampleRate: effectiveSettings.sampleRate,
|
||||
bitrate: effectiveSettings.bitrate,
|
||||
channel: effectiveSettings.channel,
|
||||
format: effectiveSettings.format,
|
||||
languageBoost: effectiveSettings.languageBoost,
|
||||
});
|
||||
|
||||
// Update scene with audio URL and ensure approved state
|
||||
// This will sync with parent script state
|
||||
const updatedScene = { ...currentScene, audioUrl: result.audioUrl, approved: true };
|
||||
onUpdateScene(updatedScene);
|
||||
|
||||
if (onAudioGenerated) {
|
||||
onAudioGenerated(sceneId, result.audioUrl);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to approve and generate audio:", error);
|
||||
// On error, revert approval only if we just approved it in this call
|
||||
if (!wasAlreadyApproved) {
|
||||
onUpdateScene({ ...scene, approved: false, audioUrl: undefined });
|
||||
}
|
||||
throw error;
|
||||
} finally {
|
||||
setLocalGenerating(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleGenerateImage = async (settings?: ImageGenerationSettings) => {
|
||||
const sceneId = scene.id;
|
||||
const startTime = Date.now();
|
||||
let progressInterval: ReturnType<typeof setInterval> | null = null;
|
||||
|
||||
try {
|
||||
setGeneratingImage(true);
|
||||
setShowRegenerateModal(false);
|
||||
setImageGenerationStatus("Submitting image generation request...");
|
||||
setImageGenerationProgress(10);
|
||||
|
||||
// Build scene content from lines for context
|
||||
const sceneContent = scene.lines.map((line) => line.text).join(" ");
|
||||
|
||||
// Log avatar URL for debugging
|
||||
console.log("[SceneEditor] Generating image with avatarUrl:", avatarUrl);
|
||||
console.log("[SceneEditor] Custom settings:", settings);
|
||||
|
||||
// Simulate progress updates during API call
|
||||
progressInterval = setInterval(() => {
|
||||
const elapsed = Date.now() - startTime;
|
||||
const seconds = Math.floor(elapsed / 1000);
|
||||
|
||||
// Update status based on elapsed time
|
||||
if (seconds < 5) {
|
||||
setImageGenerationStatus("Submitting request to AI service...");
|
||||
setImageGenerationProgress(15);
|
||||
} else if (seconds < 15) {
|
||||
setImageGenerationStatus("AI is generating your image...");
|
||||
setImageGenerationProgress(30);
|
||||
} else if (seconds < 30) {
|
||||
setImageGenerationStatus("Creating character-consistent scene image...");
|
||||
setImageGenerationProgress(50);
|
||||
} else if (seconds < 60) {
|
||||
setImageGenerationStatus("Rendering image details...");
|
||||
setImageGenerationProgress(70);
|
||||
} else {
|
||||
setImageGenerationStatus(`Processing... (${seconds}s elapsed)`);
|
||||
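// Continue from the 70% reached in the previous stage so the bar never moves backwards; round to keep the label an integer
setImageGenerationProgress(Math.min(90, Math.round(70 + (seconds - 60) / 2)));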
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
const result = await podcastApi.generateSceneImage({
|
||||
sceneId: scene.id,
|
||||
sceneTitle: scene.title,
|
||||
sceneContent: sceneContent,
|
||||
baseAvatarUrl: avatarUrl || undefined, // Pass base avatar URL for character consistency
|
||||
idea: idea,
|
||||
width: 1024,
|
||||
height: 1024,
|
||||
// Pass custom settings if provided
|
||||
customPrompt: settings?.prompt,
|
||||
style: settings?.style,
|
||||
renderingSpeed: settings?.renderingSpeed,
|
||||
aspectRatio: settings?.aspectRatio,
|
||||
});
|
||||
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
progressInterval = null;
|
||||
}
|
||||
|
||||
setImageGenerationStatus("Finalizing image...");
|
||||
setImageGenerationProgress(95);
|
||||
|
||||
// Update scene with image URL
|
||||
const updatedScene = { ...scene, imageUrl: result.image_url };
|
||||
onUpdateScene(updatedScene);
|
||||
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
setImageGenerationStatus(`Image generated successfully in ${elapsed}s`);
|
||||
setImageGenerationProgress(100);
|
||||
|
||||
// Clear status after a moment
|
||||
setTimeout(() => {
|
||||
setImageGenerationStatus("");
|
||||
setImageGenerationProgress(0);
|
||||
}, 2000);
|
||||
} catch (error: any) {
|
||||
// Clear interval on error
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
progressInterval = null;
|
||||
}
|
||||
|
||||
console.error("Failed to generate image:", error);
|
||||
// Extract error message from response if available
|
||||
const errorMessage = error?.response?.data?.detail?.message
|
||||
|| error?.response?.data?.detail?.error
|
||||
|| error?.response?.data?.detail
|
||||
|| error?.message
|
||||
|| "Failed to generate image. Please try again.";
|
||||
console.error("Error details:", {
|
||||
status: error?.response?.status,
|
||||
statusText: error?.response?.statusText,
|
||||
data: error?.response?.data,
|
||||
message: errorMessage,
|
||||
});
|
||||
|
||||
setImageGenerationStatus(`Error: ${errorMessage}`);
|
||||
setImageGenerationProgress(0);
|
||||
|
||||
// Show user-friendly error message
|
||||
alert(`Image generation failed: ${errorMessage}`);
|
||||
throw error;
|
||||
} finally {
|
||||
// Ensure interval is cleared
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
}
|
||||
setGeneratingImage(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRegenerateClick = () => {
|
||||
setShowRegenerateModal(true);
|
||||
};
|
||||
|
||||
const handleAudioRegenerateClick = () => {
  if (hasAudio) {
    setShowAudioModal(true);
  } else {
    // handleApproveAndGenerate logs and rethrows; swallow here to avoid an unhandled rejection
    handleApproveAndGenerate(audioSettings).catch(() => undefined);
  }
};

const handleAudioRegenerate = (settings: AudioGenerationSettings) => {
  setAudioSettings(settings);
  setShowAudioModal(false);
  // Same as above: the error is already logged inside handleApproveAndGenerate
  handleApproveAndGenerate(settings).catch(() => undefined);
};
|
||||
|
||||
return (
|
||||
<GlassyCard sx={glassyCardSx}>
|
||||
<Stack spacing={2.5}>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="flex-start">
|
||||
<Box sx={{ flex: 1 }}>
|
||||
<Typography
|
||||
variant="h6"
|
||||
sx={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: 1.5,
|
||||
mb: 1,
|
||||
color: "#0f172a",
|
||||
fontWeight: 600,
|
||||
fontSize: "1.25rem",
|
||||
letterSpacing: "-0.01em",
|
||||
}}
|
||||
>
|
||||
<EditNoteIcon fontSize="small" sx={{ color: "#667eea", fontSize: "1.5rem" }} />
|
||||
{scene.title}
|
||||
</Typography>
|
||||
<Stack direction="row" spacing={1.5} alignItems="center" flexWrap="wrap">
|
||||
<Chip
|
||||
icon={scene.approved ? <CheckCircleIcon /> : <RadioButtonUncheckedIcon />}
|
||||
label={scene.approved ? "Approved" : "Pending Approval"}
|
||||
size="small"
|
||||
color={scene.approved ? "success" : "warning"}
|
||||
sx={{
|
||||
background: scene.approved
|
||||
? "linear-gradient(135deg, rgba(16, 185, 129, 0.12) 0%, rgba(5, 150, 105, 0.12) 100%)"
|
||||
: "linear-gradient(135deg, rgba(245, 158, 11, 0.12) 0%, rgba(217, 119, 6, 0.12) 100%)",
|
||||
color: scene.approved ? "#059669" : "#d97706",
|
||||
border: scene.approved
|
||||
? "1px solid rgba(16, 185, 129, 0.25)"
|
||||
: "1px solid rgba(245, 158, 11, 0.25)",
|
||||
fontWeight: 600,
|
||||
fontSize: "0.75rem",
|
||||
height: 26,
|
||||
boxShadow: "0 1px 2px rgba(0, 0, 0, 0.05)",
|
||||
}}
|
||||
/>
|
||||
<Typography variant="caption" sx={{ color: "#64748b", fontWeight: 500, fontSize: "0.8125rem" }}>
|
||||
Duration: {scene.duration}s
|
||||
</Typography>
|
||||
</Stack>
|
||||
</Box>
|
||||
<Stack direction="row" spacing={1.5} flexWrap="wrap" useFlexGap>
|
||||
<PrimaryButton
|
||||
onClick={handleAudioRegenerateClick}
|
||||
disabled={approving || generating}
|
||||
loading={approving || generating}
|
||||
startIcon={
|
||||
hasAudio && !generating ? (
|
||||
<VolumeUpIcon />
|
||||
) : generating ? (
|
||||
<CircularProgress size={16} sx={{ color: "white" }} />
|
||||
) : (
|
||||
<PlayArrowIcon />
|
||||
)
|
||||
}
|
||||
tooltip={
|
||||
hasAudio && !generating
|
||||
? "Regenerate audio for this scene with custom settings"
|
||||
: generating
|
||||
? "Generating audio..."
|
||||
: scene.approved
|
||||
? "Generate audio for this scene"
|
||||
: "Approve scene and generate audio"
|
||||
}
|
||||
sx={{
|
||||
minWidth: 200,
|
||||
}}
|
||||
>
|
||||
{hasAudio && !generating
|
||||
? "Regenerate Audio"
|
||||
: generating
|
||||
? "Generating Audio..."
|
||||
: scene.approved
|
||||
? "Generate Audio"
|
||||
: "Approve & Generate Audio"}
|
||||
</PrimaryButton>
|
||||
<PrimaryButton
|
||||
onClick={hasImage ? handleRegenerateClick : () => handleGenerateImage()}
|
||||
disabled={generatingImage}
|
||||
loading={generatingImage}
|
||||
startIcon={
|
||||
hasImage && !generatingImage ? (
|
||||
<ImageIcon />
|
||||
) : generatingImage ? (
|
||||
<CircularProgress size={16} sx={{ color: "white" }} />
|
||||
) : (
|
||||
<ImageIcon />
|
||||
)
|
||||
}
|
||||
tooltip={
  generatingImage
    ? "Generating image..."
    : hasImage
      ? "Regenerate image for this scene"
      : "Generate image for video (optional)"
}
|
||||
sx={{
|
||||
minWidth: 180,
|
||||
background: hasImage
|
||||
? "linear-gradient(135deg, #10b981 0%, #059669 100%)"
|
||||
: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
"&:hover": {
|
||||
background: hasImage
|
||||
? "linear-gradient(135deg, #059669 0%, #047857 100%)"
|
||||
: "linear-gradient(135deg, #764ba2 0%, #667eea 100%)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
{hasImage && !generatingImage
|
||||
? "Regenerate Image"
|
||||
: generatingImage
|
||||
? "Generating Image..."
|
||||
: "Generate Image"}
|
||||
</PrimaryButton>
|
||||
|
||||
<Tooltip title={totalScenes !== undefined && totalScenes <= 1 ? "Cannot delete the last scene" : "Delete this scene"}>
  {/* Wrapper span: a disabled button fires no mouse events, so Tooltip needs a non-disabled child */}
  <span>
    <IconButton
      onClick={() => onDelete(scene.id)}
      disabled={approving || generating || (totalScenes !== undefined && totalScenes <= 1)}
      sx={{
        color: "#ef4444",
        backgroundColor: "rgba(239, 68, 68, 0.1)",
        border: "1px solid rgba(239, 68, 68, 0.2)",
        borderRadius: 2,
        padding: 1.5,
        "&:hover": {
          backgroundColor: "rgba(239, 68, 68, 0.15)",
          borderColor: "rgba(239, 68, 68, 0.3)",
        },
        "&:disabled": {
          backgroundColor: "rgba(156, 163, 175, 0.1)",
          borderColor: "rgba(156, 163, 175, 0.2)",
          color: "#9ca3af",
        },
      }}
    >
      <DeleteIcon sx={{ fontSize: "1.25rem" }} />
    </IconButton>
  </span>
</Tooltip>
|
||||
</Stack>
|
||||
</Stack>
|
||||
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1 }} />
|
||||
|
||||
<Stack spacing={2}>
|
||||
{scene.lines.map((line) => (
|
||||
<LineEditor key={line.id} line={line} onChange={updateLine} />
|
||||
))}
|
||||
</Stack>
|
||||
|
||||
{scene.audioUrl && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1, mt: 1 }} />
|
||||
<Box
|
||||
sx={{
|
||||
p: 2,
|
||||
background: hasAudio
|
||||
? "linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(5, 150, 105, 0.08) 100%)"
|
||||
: "linear-gradient(135deg, rgba(245, 158, 11, 0.08) 0%, rgba(217, 119, 6, 0.08) 100%)",
|
||||
borderRadius: 2,
|
||||
border: hasAudio
|
||||
? "1px solid rgba(16, 185, 129, 0.2)"
|
||||
: "1px solid rgba(245, 158, 11, 0.2)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5} sx={{ mb: 1.5 }}>
|
||||
<VolumeUpIcon sx={{ color: hasAudio ? "#059669" : "#d97706", fontSize: "1.25rem" }} />
|
||||
<Typography variant="subtitle2" sx={{ color: hasAudio ? "#059669" : "#d97706", fontWeight: 600 }}>
|
||||
{hasAudio ? "Audio Generated" : "Loading Audio..."}
|
||||
</Typography>
|
||||
</Stack>
|
||||
{hasAudio && audioBlobUrl ? (
|
||||
<audio key={audioBlobUrl} controls style={{ width: "100%", borderRadius: 8 }}>
|
||||
<source src={audioBlobUrl} type="audio/mpeg" />
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
) : (
|
||||
<Box sx={{ display: "flex", alignItems: "center", justifyContent: "center", py: 2 }}>
|
||||
<CircularProgress size={24} sx={{ color: "#d97706" }} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Image Generation Progress - Show when generating */}
|
||||
{generatingImage && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1, mt: 1 }} />
|
||||
<Box
|
||||
sx={{
|
||||
p: 2,
|
||||
background: "linear-gradient(135deg, rgba(102, 126, 234, 0.08) 0%, rgba(118, 75, 162, 0.08) 100%)",
|
||||
borderRadius: 2,
|
||||
border: "1px solid rgba(102, 126, 234, 0.2)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5} sx={{ mb: 1.5 }}>
|
||||
<ImageIcon sx={{ color: "#667eea", fontSize: "1.25rem" }} />
|
||||
<Typography variant="subtitle2" sx={{ color: "#667eea", fontWeight: 600 }}>
|
||||
Generating Image...
|
||||
</Typography>
|
||||
</Stack>
|
||||
|
||||
{/* Progress Bar */}
|
||||
<Box sx={{ mb: 1.5 }}>
|
||||
<LinearProgress
|
||||
variant="determinate"
|
||||
value={imageGenerationProgress}
|
||||
sx={{
|
||||
height: 8,
|
||||
borderRadius: 4,
|
||||
backgroundColor: alpha("#667eea", 0.1),
|
||||
"& .MuiLinearProgress-bar": {
|
||||
backgroundColor: "#667eea",
|
||||
borderRadius: 4,
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Typography variant="caption" sx={{ color: "#667eea", mt: 0.5, display: "block", textAlign: "right" }}>
|
||||
{imageGenerationProgress}%
|
||||
</Typography>
|
||||
</Box>
|
||||
|
||||
{/* Status Message */}
|
||||
{imageGenerationStatus && (
|
||||
<Typography variant="body2" sx={{ color: "#667eea", fontSize: "0.875rem", lineHeight: 1.6, mb: 1 }}>
|
||||
{imageGenerationStatus}
|
||||
</Typography>
|
||||
)}
|
||||
|
||||
{/* Spinner */}
|
||||
<Box sx={{ display: "flex", alignItems: "center", justifyContent: "center", mt: 1 }}>
|
||||
<CircularProgress size={32} sx={{ color: "#667eea" }} />
|
||||
</Box>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Generated Image Display - Show when image exists and not generating */}
|
||||
{scene.imageUrl && !generatingImage && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1, mt: 1 }} />
|
||||
<Box
|
||||
sx={{
|
||||
p: 2,
|
||||
background: imageBlobUrl && !imageLoading
|
||||
? "linear-gradient(135deg, rgba(102, 126, 234, 0.08) 0%, rgba(118, 75, 162, 0.08) 100%)"
|
||||
: "linear-gradient(135deg, rgba(245, 158, 11, 0.08) 0%, rgba(217, 119, 6, 0.08) 100%)",
|
||||
borderRadius: 2,
|
||||
border: imageBlobUrl && !imageLoading
|
||||
? "1px solid rgba(102, 126, 234, 0.2)"
|
||||
: "1px solid rgba(245, 158, 11, 0.2)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5} sx={{ mb: 1.5 }}>
|
||||
<ImageIcon sx={{ color: imageBlobUrl && !imageLoading ? "#667eea" : "#d97706", fontSize: "1.25rem" }} />
|
||||
<Typography variant="subtitle2" sx={{ color: imageBlobUrl && !imageLoading ? "#667eea" : "#d97706", fontWeight: 600 }}>
|
||||
{imageBlobUrl && !imageLoading ? "Image Generated" : "Loading Image..."}
|
||||
</Typography>
|
||||
</Stack>
|
||||
{imageBlobUrl && !imageLoading ? (
|
||||
<Box
|
||||
sx={{
|
||||
width: "100%",
|
||||
borderRadius: 2,
|
||||
overflow: "hidden",
|
||||
border: "1px solid rgba(102,126,234,0.2)",
|
||||
background: alpha("#667eea", 0.05),
|
||||
}}
|
||||
>
|
||||
<Box
|
||||
component="img"
|
||||
src={imageBlobUrl}
|
||||
alt={scene.title}
|
||||
sx={{
|
||||
width: "100%",
|
||||
height: "auto",
|
||||
display: "block",
|
||||
maxHeight: 400,
|
||||
objectFit: "cover",
|
||||
}}
|
||||
onError={(e) => {
|
||||
console.error('[SceneEditor] Image failed to load:', {
|
||||
src: e.currentTarget.src,
|
||||
imageUrl: scene.imageUrl,
|
||||
imageBlobUrl,
|
||||
});
|
||||
}}
|
||||
onLoad={() => {
|
||||
console.log('[SceneEditor] Image loaded successfully');
|
||||
}}
|
||||
/>
|
||||
</Box>
|
||||
) : (
|
||||
<Box sx={{ display: "flex", alignItems: "center", justifyContent: "center", py: 2 }}>
|
||||
<CircularProgress size={24} sx={{ color: "#d97706" }} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
</Stack>
|
||||
|
||||
{/* Image Regeneration Modal */}
|
||||
<ImageRegenerateModal
|
||||
open={showRegenerateModal}
|
||||
onClose={() => setShowRegenerateModal(false)}
|
||||
onRegenerate={handleGenerateImage}
|
||||
initialPrompt={(() => {
|
||||
const promptParts = [
|
||||
`Scene: ${scene.title}`,
|
||||
"Professional podcast recording studio",
|
||||
"Modern microphone setup",
|
||||
"Clean background, professional lighting",
|
||||
"16:9 aspect ratio, video-optimized composition"
|
||||
];
|
||||
if (idea) {
|
||||
promptParts.push(`Topic: ${idea.substring(0, 60)}`);
|
||||
}
|
||||
return promptParts.join(", ");
|
||||
})()}
|
||||
initialStyle="Realistic"
|
||||
initialRenderingSpeed="Quality"
|
||||
initialAspectRatio="16:9"
|
||||
isGenerating={generatingImage}
|
||||
/>
|
||||
|
||||
<AudioRegenerateModal
|
||||
open={showAudioModal}
|
||||
onClose={() => setShowAudioModal(false)}
|
||||
onRegenerate={handleAudioRegenerate}
|
||||
initialSettings={audioSettings}
|
||||
isGenerating={generating}
|
||||
/>
|
||||
</GlassyCard>
|
||||
);
|
||||
};
|
||||
|
||||
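Review note on `_session_backup/SceneEditor.tsx`: the audio and image loaders above repeat the same path normalization (add a leading slash, rewrite `/api/story/...` to `/api/podcast/...`, force the podcast endpoint, strip the query string). A minimal sketch of how that could be shared is below. The names `toPodcastMediaPath` and `MediaKind` are hypothetical and not part of the diff; the logic is meant to mirror the inline code, not replace any existing utility.

```typescript
// Hypothetical helper (not in the diff): mirrors the inline normalization
// used by loadAudioBlob and loadImageBlob in SceneEditor.tsx.
type MediaKind = "audio" | "images";

function toPodcastMediaPath(rawUrl: string, kind: MediaKind): string {
  // Normalize to a leading slash
  let path = rawUrl.startsWith("/") ? rawUrl : `/${rawUrl}`;

  const storyPrefix = `/api/story/${kind}/`;
  const podcastPrefix = `/api/podcast/${kind}/`;

  // Rewrite story endpoints to the podcast endpoint
  if (path.includes(storyPrefix)) {
    const filename = path.split(storyPrefix).pop() || "";
    path = `${podcastPrefix}${filename}`;
  }

  // Anything else is reduced to its last path segment under the podcast endpoint
  if (!path.includes(podcastPrefix)) {
    const filename = path.split("/").pop() || rawUrl;
    path = `${podcastPrefix}${filename}`;
  }

  // Drop query parameters last, as the inline code does
  return path.split("?")[0];
}

console.log(toPodcastMediaPath("api/story/audio/a.mp3", "audio"));
// -> /api/podcast/audio/a.mp3
console.log(toPodcastMediaPath("/api/podcast/images/x.png?token=abc", "images"));
// -> /api/podcast/images/x.png
console.log(toPodcastMediaPath("scene_1.png", "images"));
// -> /api/podcast/images/scene_1.png
```

With a helper like this, each loader would reduce to one call before `aiApiClient.get(...)`, and the rewrite rules could be unit-tested without rendering the component.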
818
_session_backup/ScriptEditor.tsx
Normal file
@@ -0,0 +1,818 @@
|
||||
import React, { useEffect, useState, useCallback } from "react";
|
||||
import { Box, Stack, Typography, Alert, Paper, LinearProgress, CircularProgress, alpha, Collapse, IconButton, Divider } from "@mui/material";
|
||||
import { EditNote as EditNoteIcon, CheckCircle as CheckCircleIcon, PlayArrow as PlayArrowIcon, ArrowBack as ArrowBackIcon, Info as InfoIcon, ExpandMore as ExpandMoreIcon, ExpandLess as ExpandLessIcon, Download as DownloadIcon, Refresh as RefreshIcon } from "@mui/icons-material";
|
||||
import { Script, Knobs, Scene } from "../types";
|
||||
import { BlogResearchResponse } from "../../../services/blogWriterApi";
|
||||
import { podcastApi } from "../../../services/podcastApi";
|
||||
import { GlassyCard, PrimaryButton, SecondaryButton } from "../ui";
|
||||
import { SceneEditor } from "./SceneEditor";
|
||||
import { InlineAudioPlayer } from "../InlineAudioPlayer";
|
||||
import { aiApiClient } from "../../../api/client";
|
||||
|
||||
interface ScriptEditorProps {
|
||||
projectId: string;
|
||||
idea: string;
|
||||
research: any; // Research type
|
||||
rawResearch: BlogResearchResponse | null;
|
||||
knobs: Knobs;
|
||||
speakers: number;
|
||||
durationMinutes: number;
|
||||
script: Script | null;
|
||||
onScriptChange: (script: Script) => void;
|
||||
onBackToResearch: () => void;
|
||||
onProceedToRendering: (script: Script) => void;
|
||||
onError: (message: string) => void;
|
||||
avatarUrl?: string | null; // Base avatar URL for consistent scene image generation
|
||||
analysis?: any;
|
||||
outline?: any;
|
||||
}
|
||||
|
||||
export const ScriptEditor: React.FC<ScriptEditorProps> = ({
|
||||
projectId,
|
||||
idea,
|
||||
research,
|
||||
rawResearch,
|
||||
knobs,
|
||||
speakers,
|
||||
durationMinutes,
|
||||
script: initialScript,
|
||||
onScriptChange,
|
||||
onBackToResearch,
|
||||
onProceedToRendering,
|
||||
onError,
|
||||
avatarUrl,
|
||||
analysis,
|
||||
outline,
|
||||
}) => {
|
||||
const [script, setScript] = useState<Script | null>(initialScript);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [approvingSceneId, setApprovingSceneId] = useState<string | null>(null);
|
||||
const [generatingAudioId, setGeneratingAudioId] = useState<string | null>(null);
|
||||
const [showScriptFormatInfo, setShowScriptFormatInfo] = useState(true);
|
||||
const [combiningAudio, setCombiningAudio] = useState(false);
|
||||
const [combinedAudioResult, setCombinedAudioResult] = useState<{
|
||||
url: string;
|
||||
filename: string;
|
||||
duration: number;
|
||||
sceneCount: number;
|
||||
} | null>(null);
|
||||
|
||||
// Defer upward script updates to avoid setState during render warnings
|
||||
const emitScriptChange = useCallback(
|
||||
(next: Script) => Promise.resolve().then(() => onScriptChange(next)),
|
||||
[onScriptChange]
|
||||
);
|
||||
|
||||
// Sync with parent state
|
||||
useEffect(() => {
|
||||
if (initialScript) {
|
||||
setScript(initialScript);
|
||||
}
|
||||
}, [initialScript]);
|
||||
|
||||
  useEffect(() => {
    // If script already exists, don't regenerate
    if (script) {
      return;
    }

    // Only generate if we have research data
    if (!rawResearch) {
      return;
    }

    let mounted = true;
    setLoading(true);
    setError(null);
    podcastApi
      .generateScript({
        projectId,
        idea,
        research: rawResearch,
        knobs,
        speakers,
        durationMinutes,
        analysis,
        outline,
      })
      .then((res) => {
        if (mounted) {
          setScript(res);
          emitScriptChange(res);
          setError(null);
        }
      })
      .catch((err) => {
        // Skip state updates and error reporting once the effect has been cleaned up,
        // matching the guard in the success handler above
        if (!mounted) return;
        const message = err instanceof Error ? err.message : "Failed to generate script";
        setError(message);
        onError(message);
      })
      .finally(() => mounted && setLoading(false));
    return () => {
      mounted = false;
    };
  }, [projectId, rawResearch, idea, knobs, speakers, durationMinutes, analysis, outline, emitScriptChange, onError, script]);
|
||||
|
||||
const updateScene = (updated: Scene) => {
|
||||
// Use functional update to ensure we're working with latest state
|
||||
setScript((currentScript) => {
|
||||
if (!currentScript) return currentScript;
|
||||
const updatedScript = {
|
||||
...currentScript,
|
||||
scenes: currentScript.scenes.map((s) => (s.id === updated.id ? { ...s, ...updated } : s))
|
||||
};
|
||||
emitScriptChange(updatedScript);
|
||||
return updatedScript;
|
||||
});
|
||||
};
|
||||
|
||||
const approveScene = async (sceneId: string) => {
|
||||
try {
|
||||
setApprovingSceneId(sceneId);
|
||||
await podcastApi.approveScene({ projectId, sceneId });
|
||||
// Use functional update to ensure we're working with latest state
|
||||
setScript((currentScript) => {
|
||||
if (!currentScript) return currentScript;
|
||||
const updatedScript = {
|
||||
...currentScript,
|
||||
scenes: currentScript.scenes.map((s) => (s.id === sceneId ? { ...s, approved: true } : s)),
|
||||
};
|
||||
emitScriptChange(updatedScript);
|
||||
return updatedScript;
|
||||
});
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to approve scene";
|
||||
setError(message);
|
||||
onError(message);
|
||||
throw err;
|
||||
} finally {
|
||||
setApprovingSceneId((current) => (current === sceneId ? null : current));
|
||||
}
|
||||
};
|
||||
|
||||
  const deleteScene = useCallback((sceneId: string) => {
    if (!script) return;

    // Prevent deleting if it's the last scene
    if (script.scenes.length <= 1) {
      onError("Cannot delete the last scene. At least one scene is required.");
      return;
    }

    // Look up the scene so the confirmation prompt can show its title
    const sceneToDelete = script.scenes.find(s => s.id === sceneId);
    if (!sceneToDelete) return;

    const confirmDelete = window.confirm(
      `Are you sure you want to delete "${sceneToDelete.title}"? This action cannot be undone.`
    );

    if (!confirmDelete) return;

    // Remove the scene from the script
    const updatedScenes = script.scenes.filter(s => s.id !== sceneId);
    const updatedScript = { ...script, scenes: updatedScenes };

    emitScriptChange(updatedScript);
    setScript(updatedScript);

    // Log the deletion (no user-facing success message is shown)
    console.log(`[ScriptEditor] Scene "${sceneToDelete.title}" deleted successfully`);
  }, [script, emitScriptChange, onError]);
|
||||
|
||||
const allApproved = script && script.scenes.every((s) => s.approved);
|
||||
const approvedCount = script ? script.scenes.filter((s) => s.approved).length : 0;
|
||||
const totalScenes = script ? script.scenes.length : 0;
|
||||
|
||||
// Check if all scenes have both audio and images (required for video rendering)
|
||||
const allScenesHaveAudioAndImages = script && script.scenes.every((s) => s.audioUrl && s.imageUrl);
|
||||
const scenesWithAudio = script ? script.scenes.filter((s) => s.audioUrl).length : 0;
|
||||
const allScenesHaveAudio = script && script.scenes.every((s) => s.audioUrl);
|
||||
|
||||
const combineAudio = useCallback(async () => {
|
||||
if (!script || !projectId) return;
|
||||
|
||||
try {
|
||||
setCombiningAudio(true);
|
||||
|
||||
const sceneIds: string[] = [];
|
||||
const sceneAudioUrls: string[] = [];
|
||||
|
||||
script.scenes.forEach((scene) => {
|
||||
if (scene.audioUrl) {
|
||||
// Ensure we're using the correct URL format (not blob URLs)
|
||||
const audioUrl = scene.audioUrl.startsWith('blob:') ? '' : scene.audioUrl;
|
||||
if (audioUrl) {
|
||||
sceneIds.push(scene.id);
|
||||
sceneAudioUrls.push(audioUrl);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (sceneIds.length === 0) {
|
||||
onError("No audio files found to combine.");
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await podcastApi.combineAudio({
|
||||
projectId,
|
||||
sceneIds,
|
||||
sceneAudioUrls,
|
||||
});
|
||||
|
||||
// Store combined audio result for preview
|
||||
setCombinedAudioResult({
|
||||
url: result.combined_audio_url,
|
||||
filename: result.combined_audio_filename,
|
||||
duration: result.total_duration,
|
||||
sceneCount: result.scene_count,
|
||||
});
|
||||
|
||||
// Download the combined audio as blob (for authenticated endpoints)
|
||||
try {
|
||||
// Normalize path
|
||||
let audioPath = result.combined_audio_url.startsWith('/')
|
||||
? result.combined_audio_url
|
||||
: `/${result.combined_audio_url}`;
|
||||
|
||||
// Ensure it's a podcast audio endpoint
|
||||
if (!audioPath.includes('/api/podcast/audio/')) {
|
||||
const filename = audioPath.split('/').pop() || result.combined_audio_filename;
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
audioPath = audioPath.split('?')[0];
|
||||
|
||||
// Fetch as blob using authenticated client
|
||||
const response = await aiApiClient.get(audioPath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
// Create blob URL and download
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = blobUrl;
|
||||
link.download = result.combined_audio_filename || `podcast-episode-${projectId.slice(-8)}.mp3`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
|
||||
// Clean up blob URL after a delay
|
||||
setTimeout(() => {
|
||||
URL.revokeObjectURL(blobUrl);
|
||||
}, 100);
|
||||
} catch (downloadError) {
|
||||
console.error('Failed to download combined audio:', downloadError);
|
||||
onError('Failed to download audio file. You can try downloading again from the preview.');
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "Failed to combine audio";
|
||||
onError(`Failed to combine audio: ${message}`);
|
||||
} finally {
|
||||
setCombiningAudio(false);
|
||||
}
|
||||
}, [script, projectId, onError]);
|
||||
|
||||
return (
|
||||
<Box sx={{ mt: 4 }}>
|
||||
<Stack direction="row" spacing={2} alignItems="center" sx={{ mb: 4 }}>
|
||||
<SecondaryButton onClick={onBackToResearch} startIcon={<ArrowBackIcon />}>
|
||||
Back to Research
|
||||
</SecondaryButton>
|
||||
<Box sx={{ flex: 1 }}>
|
||||
<Typography
|
||||
variant="h4"
|
||||
sx={{
|
||||
background: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
WebkitBackgroundClip: "text",
|
||||
WebkitTextFillColor: "transparent",
|
||||
fontWeight: 700,
|
||||
letterSpacing: "-0.02em",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: 1.5,
|
||||
fontSize: { xs: "1.75rem", md: "2rem" },
|
||||
}}
|
||||
>
|
||||
<EditNoteIcon sx={{ fontSize: "2rem" }} />
|
||||
Script Editor
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#64748b", mt: 0.5, ml: 5.5 }}>
|
||||
Review and refine your podcast script before rendering
|
||||
</Typography>
|
||||
</Box>
|
||||
</Stack>
|
||||
|
||||
{loading && (
|
||||
<Alert
|
||||
severity="info"
|
||||
icon={<CircularProgress size={20} />}
|
||||
sx={{
|
||||
mb: 3,
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.08) 0%, rgba(139, 92, 246, 0.08) 100%)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.2)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 1px 2px rgba(99, 102, 241, 0.05)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#6366f1",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", fontWeight: 500 }}>
|
||||
Generating script with AI... This may take a moment.
|
||||
</Typography>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<Alert
|
||||
severity="error"
|
||||
sx={{
|
||||
mb: 3,
|
||||
background: "linear-gradient(135deg, rgba(239, 68, 68, 0.08) 0%, rgba(220, 38, 38, 0.08) 100%)",
|
||||
border: "1px solid rgba(239, 68, 68, 0.2)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 1px 2px rgba(239, 68, 68, 0.05)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#ef4444",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", fontWeight: 500 }}>
|
||||
{error}
|
||||
</Typography>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{script && (
|
||||
<Stack spacing={3}>
|
||||
{/* Script Format Explanation Panel */}
|
||||
<Paper
|
||||
sx={{
|
||||
p: 3,
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.05) 0%, rgba(139, 92, 246, 0.05) 100%)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.15)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 2px 8px rgba(99, 102, 241, 0.08)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: showScriptFormatInfo ? 2 : 0 }}>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5}>
|
||||
<Box
|
||||
sx={{
|
||||
width: 40,
|
||||
height: 40,
|
||||
borderRadius: "50%",
|
||||
background: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
boxShadow: "0 2px 8px rgba(102, 126, 234, 0.3)",
|
||||
}}
|
||||
>
|
||||
<InfoIcon sx={{ color: "#ffffff", fontSize: "1.5rem" }} />
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="h6" sx={{ color: "#0f172a", fontWeight: 600, fontSize: "1.1rem" }}>
|
||||
Why This Script Format?
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#64748b", mt: 0.25 }}>
|
||||
Understanding how your script creates natural, human-like audio
|
||||
</Typography>
|
||||
</Box>
|
||||
</Stack>
|
||||
<IconButton
|
||||
onClick={() => setShowScriptFormatInfo(!showScriptFormatInfo)}
|
||||
sx={{
|
||||
color: "#6366f1",
|
||||
"&:hover": {
|
||||
background: "rgba(99, 102, 241, 0.1)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
{showScriptFormatInfo ? <ExpandLessIcon /> : <ExpandMoreIcon />}
|
||||
</IconButton>
|
||||
</Stack>
|
||||
|
||||
<Collapse in={showScriptFormatInfo}>
|
||||
<Stack spacing={2.5}>
|
||||
<Box>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", lineHeight: 1.8, mb: 2 }}>
|
||||
Our AI script generator creates scripts specifically optimized for <strong style={{ fontWeight: 600 }}>high-quality text-to-speech</strong>.
|
||||
The format you see here is designed to produce audio that sounds natural and human-like, not robotic.
|
||||
</Typography>
|
||||
</Box>
|
||||
|
||||
<Stack spacing={2}>
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
1
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Natural Pauses & Rhythm
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
The script includes strategic pauses between lines and when speakers change. This creates natural breathing patterns
|
||||
and conversation flow, just like real human speech. Without these pauses, the audio would sound rushed and robotic.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
2
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Emphasis Markers
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
Lines marked with emphasis help highlight important points, statistics, or key insights. The AI voice will naturally
|
||||
stress these parts, making your podcast more engaging and easier to follow—just like a real host would emphasize important information.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
3
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Short, Conversational Sentences
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
The script uses shorter sentences (15-20 words) written in a conversational style. This matches how people actually
|
||||
speak, making the audio sound more natural. Long, complex sentences would sound awkward when spoken aloud.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
4
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Scene-Specific Emotions
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
Each scene has an emotional tone (excited, serious, curious, etc.) that guides the AI voice's delivery. This creates
|
||||
variety and keeps listeners engaged, just like a real podcast host would vary their tone based on the topic.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
5
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Optimized for Podcast Narration
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
The script is optimized with slightly slower pacing and natural pronunciation settings specifically for podcast narration.
|
||||
This ensures clarity and makes the content easy to understand, even when listeners are multitasking.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
</Stack>
|
||||
|
||||
<Alert
|
||||
severity="info"
|
||||
sx={{
|
||||
mt: 1,
|
||||
background: "rgba(99, 102, 241, 0.06)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.15)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#6366f1",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", lineHeight: 1.7 }}>
|
||||
<strong style={{ fontWeight: 600 }}>Tip:</strong> You can edit any line or scene to match your preferences.
|
||||
The format will be preserved when rendering, ensuring your audio still sounds natural and professional.
|
||||
</Typography>
|
||||
</Alert>
|
||||
</Stack>
|
||||
</Collapse>
|
||||
</Paper>
|
||||
|
||||
<Alert
|
||||
severity="info"
|
||||
sx={{
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.08) 0%, rgba(139, 92, 246, 0.08) 100%)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.2)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 1px 2px rgba(99, 102, 241, 0.05)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#6366f1",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", fontWeight: 500, lineHeight: 1.6 }}>
|
||||
<strong style={{ fontWeight: 600 }}>Approval Required:</strong> Each scene must be approved before rendering. Review and edit lines as needed, then approve each scene.
|
||||
</Typography>
|
||||
</Alert>
|
||||
|
||||
<Stack spacing={2}>
|
||||
{script.scenes.map((scene, idx) => (
|
||||
<GlassyCard
|
||||
key={scene.id}
|
||||
initial={{ opacity: 0, y: 8 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ duration: 0.3, delay: idx * 0.1 }}
|
||||
>
|
||||
<SceneEditor
|
||||
scene={scene}
|
||||
onUpdateScene={updateScene}
|
||||
onApprove={approveScene}
|
||||
onDelete={deleteScene}
|
||||
knobs={knobs}
|
||||
approvingSceneId={approvingSceneId}
|
||||
generatingAudioId={generatingAudioId}
|
||||
totalScenes={script.scenes.length}
|
||||
onAudioGenerationStart={(sceneId) => {
|
||||
setGeneratingAudioId(sceneId);
|
||||
}}
|
||||
onAudioGenerated={async (sceneId, audioUrl) => {
|
||||
setGeneratingAudioId(null);
|
||||
// Use functional update to ensure we're working with latest state
|
||||
// Ensure scene is marked as approved and has audioUrl
|
||||
setScript((currentScript) => {
|
||||
if (!currentScript) return currentScript;
|
||||
const updatedScenes = currentScript.scenes.map((s) =>
|
||||
s.id === sceneId ? { ...s, audioUrl, approved: true } : s
|
||||
);
|
||||
const updatedScript = { ...currentScript, scenes: updatedScenes };
|
||||
emitScriptChange(updatedScript);
|
||||
return updatedScript;
|
||||
});
|
||||
}}
|
||||
idea={idea}
|
||||
avatarUrl={avatarUrl}
|
||||
/>
|
||||
</GlassyCard>
|
||||
))}
|
||||
</Stack>
|
||||
|
||||
<Paper
|
||||
sx={{
|
||||
p: 3.5,
|
||||
background: allApproved
|
||||
? "linear-gradient(135deg, rgba(16, 185, 129, 0.05) 0%, rgba(5, 150, 105, 0.05) 100%)"
|
||||
: "#ffffff",
|
||||
border: allApproved
|
||||
? "2px solid rgba(16, 185, 129, 0.25)"
|
||||
: "1px solid rgba(15, 23, 42, 0.08)",
|
||||
borderRadius: 3,
|
||||
boxShadow: allApproved
|
||||
? "0 4px 6px rgba(16, 185, 129, 0.08), 0 8px 24px rgba(16, 185, 129, 0.06)"
|
||||
: "0 1px 3px rgba(15, 23, 42, 0.06), 0 4px 12px rgba(15, 23, 42, 0.04)",
|
||||
transition: "all 0.3s cubic-bezier(0.4, 0, 0.2, 1)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center">
|
||||
<Box>
|
||||
<Typography variant="subtitle1" sx={{ mb: 1, display: "flex", alignItems: "center", gap: 1.5, color: "#0f172a", fontWeight: 600, fontSize: "1.1rem" }}>
|
||||
<CheckCircleIcon fontSize="small" sx={{ color: allApproved ? "#10b981" : "#94a3b8", fontSize: "1.25rem" }} />
|
||||
Approval Status
|
||||
</Typography>
|
||||
                <Typography variant="body2" sx={{ color: "#64748b", fontWeight: 400, lineHeight: 1.6 }}>
                  {approvedCount} of {totalScenes} scenes approved
                  {allScenesHaveAudioAndImages && " • All scenes ready for video rendering"}
                  {!allScenesHaveAudioAndImages && allApproved && " • Generate images for all scenes to enable video rendering"}
                  {!allApproved && " • Approve all scenes first"}
                </Typography>
                {/* Progress: share of scenes that already have both audio and an image.
                    This branch only renders when at least one scene is incomplete, so totalScenes > 0. */}
                {!allScenesHaveAudioAndImages && (
                  <LinearProgress
                    variant="determinate"
                    value={(script.scenes.filter((s) => s.audioUrl && s.imageUrl).length / totalScenes) * 100}
                    sx={{ mt: 1, height: 6, borderRadius: 3 }}
                  />
                )}
|
||||
</Box>
|
||||
<PrimaryButton
|
||||
onClick={() => script && onProceedToRendering(script)}
|
||||
disabled={!allScenesHaveAudioAndImages}
|
||||
startIcon={<PlayArrowIcon />}
|
||||
tooltip={
|
||||
!allScenesHaveAudioAndImages
|
||||
? "Generate audio and images for all scenes to proceed to video rendering"
|
||||
: "Proceed to video rendering (all scenes have audio and images)"
|
||||
}
|
||||
>
|
||||
Proceed to Rendering
|
||||
</PrimaryButton>
|
||||
</Stack>
|
||||
</Paper>
|
||||
|
||||
{/* Download Audio-Only Podcast Section */}
|
||||
{allScenesHaveAudio && (
|
||||
<Paper
|
||||
sx={{
|
||||
p: 3,
|
||||
background: "linear-gradient(135deg, rgba(102, 126, 234, 0.05) 0%, rgba(118, 75, 162, 0.05) 100%)",
|
||||
border: "1px solid rgba(102, 126, 234, 0.15)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack spacing={3}>
|
||||
<Typography variant="h6" sx={{ color: "#0f172a", fontWeight: 600 }}>
|
||||
Download Audio-Only Podcast
|
||||
</Typography>
|
||||
|
||||
{!combinedAudioResult ? (
|
||||
<>
|
||||
<PrimaryButton
|
||||
onClick={combineAudio}
|
||||
disabled={combiningAudio}
|
||||
loading={combiningAudio}
|
||||
startIcon={<DownloadIcon />}
|
||||
tooltip="Combine all scene audio files into a single podcast episode"
|
||||
sx={{
|
||||
minWidth: 280,
|
||||
fontSize: "1rem",
|
||||
py: 1.5,
|
||||
background: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
"&:hover": {
|
||||
background: "linear-gradient(135deg, #764ba2 0%, #667eea 100%)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
{combiningAudio ? "Combining Audio..." : "Download Audio-Only Podcast"}
|
||||
</PrimaryButton>
|
||||
<Typography variant="caption" sx={{ color: "#64748b", fontStyle: "italic" }}>
|
||||
This will combine all {scenesWithAudio} scene audio files into one complete podcast episode.
|
||||
</Typography>
|
||||
</>
|
||||
) : (
|
||||
<Stack spacing={2}>
|
||||
{/* Success Alert */}
|
||||
<Alert
|
||||
severity="success"
|
||||
sx={{
|
||||
background: alpha("#10b981", 0.1),
|
||||
border: "1px solid rgba(16,185,129,0.3)",
|
||||
"& .MuiAlert-icon": { color: "#10b981" },
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#059669", fontWeight: 500 }}>
|
||||
✅ Combined audio generated successfully! ({combinedAudioResult.sceneCount} scenes,{" "}
|
||||
{Math.round(combinedAudioResult.duration)}s)
|
||||
</Typography>
|
||||
</Alert>
|
||||
|
||||
{/* Combined Audio Preview */}
|
||||
<InlineAudioPlayer audioUrl={combinedAudioResult.url} title="Complete Podcast Episode" />
|
||||
|
||||
{/* Action Buttons */}
|
||||
<Stack direction="row" spacing={2}>
|
||||
<SecondaryButton
|
||||
onClick={async () => {
|
||||
try {
|
||||
// Normalize path
|
||||
let audioPath = combinedAudioResult.url.startsWith('/')
|
||||
? combinedAudioResult.url
|
||||
: `/${combinedAudioResult.url}`;
|
||||
|
||||
// Ensure it's a podcast audio endpoint
|
||||
if (!audioPath.includes('/api/podcast/audio/')) {
|
||||
const filename = audioPath.split('/').pop() || combinedAudioResult.filename;
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
audioPath = audioPath.split('?')[0];
|
||||
|
||||
// Fetch as blob using authenticated client
|
||||
const response = await aiApiClient.get(audioPath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
// Create blob URL and download
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = blobUrl;
|
||||
link.download = combinedAudioResult.filename || `podcast-episode-${projectId.slice(-8)}.mp3`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
|
||||
// Clean up blob URL after a delay
|
||||
setTimeout(() => {
|
||||
URL.revokeObjectURL(blobUrl);
|
||||
}, 100);
|
||||
} catch (error) {
|
||||
console.error('Failed to download audio:', error);
|
||||
onError('Failed to download audio file. Please try again.');
|
||||
}
|
||||
}}
|
||||
startIcon={<DownloadIcon />}
|
||||
tooltip="Download the combined audio file again"
|
||||
>
|
||||
Download Again
|
||||
</SecondaryButton>
|
||||
<SecondaryButton
|
||||
onClick={() => {
|
||||
setCombinedAudioResult(null);
|
||||
combineAudio();
|
||||
}}
|
||||
disabled={combiningAudio}
|
||||
loading={combiningAudio}
|
||||
startIcon={<RefreshIcon />}
|
||||
tooltip="Regenerate combined audio (useful if scenes were updated)"
|
||||
>
|
||||
Regenerate
|
||||
</SecondaryButton>
|
||||
</Stack>
|
||||
</Stack>
|
||||
)}
|
||||
</Stack>
|
||||
</Paper>
|
||||
)}
|
||||
</Stack>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
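The path handling that `combineAudio` and the "Download Again" handler both perform inline could be pulled into one pure function. The sketch below is an illustration under stated assumptions, not the component's actual code: `normalizePodcastAudioPath` and `PODCAST_AUDIO_PREFIX` are hypothetical names, and the sketch strips the query string first (the component strips it last) so that a `/` inside a query value is not mistaken for a path separator.

```typescript
// Hypothetical constant: the endpoint prefix the component checks for.
const PODCAST_AUDIO_PREFIX = "/api/podcast/audio/";

// Hypothetical helper: turns whatever URL the combine endpoint returned
// into a path under the authenticated podcast audio endpoint.
function normalizePodcastAudioPath(rawUrl: string, fallbackFilename: string): string {
  // Ensure a leading slash so the path is resolved against the API base URL.
  let path = rawUrl.startsWith("/") ? rawUrl : `/${rawUrl}`;

  // Drop query parameters before inspecting the path segments.
  path = path.split("?")[0];

  // If the path is not already a podcast audio endpoint, rebuild it from the filename.
  if (!path.includes(PODCAST_AUDIO_PREFIX)) {
    const filename = path.split("/").pop() || fallbackFilename;
    path = `${PODCAST_AUDIO_PREFIX}${filename}`;
  }

  return path;
}

// Example inputs (made up for illustration).
console.log(normalizePodcastAudioPath("api/podcast/audio/ep.mp3?t=1", "x.mp3"));
// -> /api/podcast/audio/ep.mp3
console.log(normalizePodcastAudioPath("/static/out/ep.mp3", "x.mp3"));
// -> /api/podcast/audio/ep.mp3
```

With a helper like this, both download handlers could pass the result to `aiApiClient.get(path, { responseType: "blob" })` instead of repeating the normalization.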
|
||||
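`emitScriptChange` defers the parent callback with `Promise.resolve().then(...)`, so the parent's state update runs after the current synchronous work instead of during it. A minimal sketch of that deferral in isolation follows; `deferToMicrotask` and `notifyParent` are hypothetical names used only for illustration.

```typescript
// Hypothetical helper: wraps a callback so each call is queued as a microtask
// rather than executed immediately.
function deferToMicrotask<T>(callback: (value: T) => void): (value: T) => Promise<void> {
  return (value: T) => Promise.resolve().then(() => callback(value));
}

const received: string[] = [];
const notifyParent = deferToMicrotask((label: string) => {
  received.push(label);
});

const pending = notifyParent("scene-updated");

// The callback has only been queued; it has not run yet.
console.log(received.length); // 0

pending.then(() => {
  // By now the queued callback has executed.
  console.log(received); // [ 'scene-updated' ]
});
```

The trade-off is that the parent sees the change slightly later than the local state does, which is why the component keeps its own `script` state as the source of truth while rendering.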
|
||||
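The readiness values that `ScriptEditor` derives from the script (approved count, audio coverage, audio-and-image coverage, progress bar value) can be expressed as one pure function. This is a sketch under assumptions: `SceneLike`, `SceneReadiness` and `summarizeSceneReadiness` are hypothetical names, and unlike the component's `every(...)` checks, which are true for an empty list, the sketch treats a script with no scenes as not ready.

```typescript
// Hypothetical minimal shape; the real Scene type has more fields.
interface SceneLike {
  approved?: boolean;
  audioUrl?: string | null;
  imageUrl?: string | null;
}

interface SceneReadiness {
  totalScenes: number;
  approvedCount: number;
  scenesWithAudio: number;
  renderReadyCount: number;
  allApproved: boolean;
  allScenesHaveAudio: boolean;
  allScenesHaveAudioAndImages: boolean;
  progressPercent: number;
}

function summarizeSceneReadiness(scenes: SceneLike[]): SceneReadiness {
  const totalScenes = scenes.length;
  const approvedCount = scenes.filter((s) => s.approved === true).length;
  const scenesWithAudio = scenes.filter((s) => Boolean(s.audioUrl)).length;
  // A scene is render-ready only when it has both audio and an image.
  const renderReadyCount = scenes.filter((s) => Boolean(s.audioUrl) && Boolean(s.imageUrl)).length;

  return {
    totalScenes,
    approvedCount,
    scenesWithAudio,
    renderReadyCount,
    allApproved: totalScenes > 0 && approvedCount === totalScenes,
    allScenesHaveAudio: totalScenes > 0 && scenesWithAudio === totalScenes,
    allScenesHaveAudioAndImages: totalScenes > 0 && renderReadyCount === totalScenes,
    // Guard against dividing by zero when there are no scenes.
    progressPercent: totalScenes === 0 ? 0 : (renderReadyCount / totalScenes) * 100,
  };
}

const summary = summarizeSceneReadiness([
  { approved: true, audioUrl: "a.mp3", imageUrl: "a.png" },
  { approved: true, audioUrl: "b.mp3" },
]);
console.log(summary.progressPercent); // 50
console.log(summary.allScenesHaveAudio); // true
```

Keeping these values in one function would make the status text, the progress bar, and the "Proceed to Rendering" button read from the same definition of "ready".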
334
_session_backup/analysis.py
Normal file
@@ -0,0 +1,334 @@
"""
Podcast Analysis Handlers

Analysis endpoint for podcast ideas.
"""

from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any
import json
import uuid
from sqlalchemy.orm import Session

from services.database import get_db
from middleware.auth_middleware import get_current_user
from api.story_writer.utils.auth import require_authenticated_user
from services.llm_providers.main_text_generation import llm_text_gen
from services.llm_providers.main_image_generation import generate_image
from services.podcast_bible_service import PodcastBibleService
from utils.asset_tracker import save_asset_to_library
from loguru import logger
from ..constants import PODCAST_IMAGES_DIR
from ..models import (
    PodcastAnalyzeRequest,
    PodcastAnalyzeResponse,
    PodcastEnhanceIdeaRequest,
    PodcastEnhanceIdeaResponse
)

router = APIRouter()


@router.post("/idea/enhance", response_model=PodcastEnhanceIdeaResponse)
async def enhance_podcast_idea(
    request: PodcastEnhanceIdeaRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
):
    """
    Take raw keywords/topic and use AI to craft three presentable, detailed podcast ideas.
    Uses the user's Podcast Bible for hyper-personalization if available.
    """
    user_id = require_authenticated_user(current_user)

    # Serialize Bible context if provided or generate from onboarding
    bible_context = ""
    try:
        bible_service = PodcastBibleService()
        if request.bible:
            from models.podcast_bible_models import PodcastBible
            bible_data = PodcastBible(**request.bible)
            bible_context = bible_service.serialize_bible(bible_data)
        else:
            # Generate from onboarding data directly
            bible_obj = bible_service.generate_bible(user_id, "temp_enhance")
            bible_context = bible_service.serialize_bible(bible_obj)
    except Exception as exc:
        logger.warning(f"[Podcast Enhance] Failed to parse or generate bible context: {exc}")

    # Build the optional personalization block outside the prompt f-string:
    # a backslash inside an f-string expression is a syntax error before Python 3.12.
    bible_section = (
        f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n"
        if bible_context
        else ""
    )

    prompt = f"""
You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.

{bible_section}

RAW IDEA/KEYWORDS: "{request.idea}"

TASK:
Generate 3 different enhanced versions, each with a unique angle:
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)

Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.

Return JSON with:
- enhanced_ideas: array of 3 enhanced episode pitches (in order: Professional, Storytelling, Trendy)
- rationales: array of 3 rationales explaining the approach for each version
"""

    try:
        raw = llm_text_gen(
            prompt=prompt,
            user_id=user_id,
            json_struct=None,
            preferred_provider="huggingface",
            flow_type="premium_tool",
        )

        # Normalize response
        if isinstance(raw, str):
            data = json.loads(raw)
        else:
            data = raw

        # Extract enhanced ideas and rationales with fallbacks
        enhanced_ideas = data.get("enhanced_ideas", [])
        rationales = data.get("rationales", [])

        # Ensure we have exactly 3 ideas, fallback to original if needed
        if not isinstance(enhanced_ideas, list) or len(enhanced_ideas) != 3:
            # Fallback: create 3 variations of the original idea
            base_idea = request.idea
            enhanced_ideas = [
                f"Expert insights on {base_idea}: A deep dive into industry trends and best practices.",
                f"The human side of {base_idea}: Personal stories and real-world experiences that resonate.",
                f"Modern perspectives on {base_idea}: Current trends and forward-thinking approaches."
            ]
            rationales = [
                "Professional approach focusing on expertise and authority",
                "Storytelling approach emphasizing human connection",
                "Contemporary approach highlighting current relevance"
            ]

        # Ensure rationales match the number of ideas
        if not isinstance(rationales, list) or len(rationales) != 3:
            rationales = [
                "Professional angle with expert insights",
                "Storytelling angle with human interest",
                "Trendy angle with contemporary relevance"
            ]

        return PodcastEnhanceIdeaResponse(
            enhanced_ideas=enhanced_ideas[:3],  # Ensure exactly 3
            rationales=rationales[:3]  # Ensure exactly 3
        )
    except Exception as exc:
        logger.error(f"[Podcast Enhance] Failed for user {user_id}: {exc}")
        # Fallback to basic variations of original idea
        base_idea = request.idea
        return PodcastEnhanceIdeaResponse(
            enhanced_ideas=[
                f"Expert insights on {base_idea}: A deep dive into industry trends and best practices.",
                f"The human side of {base_idea}: Personal stories and real-world experiences that resonate.",
                f"Modern perspectives on {base_idea}: Current trends and forward-thinking approaches."
            ],
            rationales=[
                "Professional approach focusing on expertise and authority",
                "Storytelling approach emphasizing human connection",
                "Contemporary approach highlighting current relevance"
            ]
        )


@router.post("/analyze", response_model=PodcastAnalyzeResponse)
|
||||
async def analyze_podcast_idea(
|
||||
request: PodcastAnalyzeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Analyze a podcast idea and return podcast-oriented outlines, keywords, and titles.
|
||||
If no avatar_url is provided, it generates one automatically based on the host's look.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Serialize Bible context if provided or generate from onboarding
|
||||
bible_context = ""
|
||||
bible_obj = None
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
if request.bible:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
bible_obj = bible_data
|
||||
else:
|
||||
# Generate from onboarding data directly
|
||||
bible_obj = bible_service.generate_bible(user_id, "temp_analyze")
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Analyze] Failed to parse or generate bible context: {exc}")
|
||||
|
||||
# --- NEW: Generate Presenter Avatar if missing ---
|
||||
final_avatar_url = request.avatar_url
|
||||
final_avatar_prompt = None
|
||||
|
||||
if not final_avatar_url:
|
||||
logger.info(f"[Podcast Analyze] No avatar_url provided, generating one for user {user_id}")
|
||||
try:
|
||||
# 1. PRE-FLIGHT VALIDATION: Check subscription limits for image generation
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
pricing_service = PricingService(db)
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1
|
||||
)
|
||||
|
||||
# 2. Build avatar prompt from Bible host look or fallback
|
||||
host_look = bible_obj.host.look if bible_obj and bible_obj.host.look else "A professional podcast host"
|
||||
visual_style = bible_obj.visual_style.style_preset if bible_obj else "Realistic Photography"
|
||||
|
||||
final_avatar_prompt = f"Professional headshot of a podcast host, {host_look}, {visual_style} style, clean background, soft studio lighting, center-focused, high resolution, sharp focus, professional photography quality, 1:1 aspect ratio."
|
||||
|
||||
# 3. Generate the image
|
||||
logger.info(f"[Podcast Analyze] Generating avatar with prompt: {final_avatar_prompt}")
|
||||
image_result = generate_image(
|
||||
prompt=final_avatar_prompt,
|
||||
user_id=user_id,
|
||||
width=1024,
|
||||
height=1024
|
||||
)
|
||||
|
||||
# 4. Save to disk and library
|
||||
if image_result and image_result.image_bytes:
|
||||
img_id = str(uuid.uuid4())[:8]
|
||||
filename = f"presenter_podcast_{user_id}_{img_id}.png"
|
||||
output_path = PODCAST_IMAGES_DIR / filename
|
||||
PODCAST_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_result.image_bytes)
|
||||
|
||||
final_avatar_url = f"/api/podcast/images/avatars/{filename}"
|
||||
|
||||
# Save to asset library for reuse
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
file_url=final_avatar_url,
|
||||
filename=filename,
|
||||
title=f"Presenter Avatar - {request.idea[:40]}",
|
||||
description=f"AI-generated podcast presenter for: {request.idea}",
|
||||
provider=image_result.provider,
|
||||
model=image_result.model,
|
||||
cost=image_result.cost
|
||||
)
|
||||
logger.info(f"[Podcast Analyze] ✅ Generated and saved avatar to {final_avatar_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast Analyze] ❌ Failed to generate avatar: {e}")
|
||||
# Non-fatal: continue analysis even if avatar generation fails
|
||||
|
||||
# --- END: Avatar Generation ---
|
||||
|
||||
# Incorporate user feedback if provided
|
||||
feedback_context = ""
|
||||
if request.feedback:
|
||||
feedback_context = f"""
|
||||
USER REGENERATION FEEDBACK:
|
||||
The user was not satisfied with the previous analysis. They provided the following instructions for improvement:
|
||||
"{request.feedback}"
|
||||
Please prioritize this feedback and adjust the analysis accordingly.
|
||||
"""
|
||||
|
||||
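# Build the optional Bible section up front; a backslash inside an f-string expression is a SyntaxError before Python 3.12.
|
||||
bible_section = f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""
|
||||
|
||||
prompt = f"""
|
||||
You are an expert podcast producer and research strategist. Given a podcast idea, craft concise podcast-ready assets
|
||||
that sound like episode plans (not fiction stories).
|
||||
|
||||
{bible_section}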
|
||||
{feedback_context}
|
||||
|
||||
Podcast Idea: "{request.idea}"
|
||||
Duration: ~{request.duration} minutes
|
||||
Speakers: {request.speakers} (host + optional guest)
|
||||
|
||||
TASK:
|
||||
1. Define the target audience and content type aligned with the Bible's "Audience DNA" and "Brand DNA".
|
||||
2. Identify 5 high-impact keywords.
|
||||
3. Propose 2 episode outlines with factual segments.
|
||||
4. Suggest 3 titles.
|
||||
5. IMPORTANT: Generate 4-6 specific research queries for Exa. These queries MUST be highly targeted to the episode's topic, the host's expertise level, and the audience's interests as defined in the Bible.
|
||||
* Do NOT use generic queries like "latest trends in X".
|
||||
* DO use queries that look for case studies, specific data points, expert opinions, or contrasting viewpoints that would make for a deep, insightful podcast conversation.
|
||||
|
||||
Return JSON with:
|
||||
- audience: short target audience description
|
||||
- content_type: podcast style/format
|
||||
- top_keywords: 5 podcast-relevant keywords/phrases
|
||||
- suggested_outlines: 2 items, each with title (<=60 chars) and 4-6 short segments (bullet-friendly, factual)
|
||||
- title_suggestions: 3 concise episode titles
|
||||
- research_queries: array of {{"query": "string", "rationale": "string"}}
|
||||
- exa_suggested_config: suggested Exa search options with:
|
||||
- exa_search_type: "auto" | "neural" | "keyword"
|
||||
- exa_category: one of ["research paper","news","company","github","tweet","personal site","pdf","financial report","linkedin profile"]
|
||||
- exa_include_domains: up to 3 reputable domains
|
||||
- exa_exclude_domains: up to 3 domains
|
||||
- max_sources: 6-10
|
||||
- include_statistics: boolean
|
||||
- date_range: one of ["last_month","last_3_months","last_year","all_time"]
|
||||
|
||||
Requirements:
|
||||
- Keep language factual, actionable, and suited for spoken audio.
|
||||
- Avoid narrative fiction tone.
|
||||
- Prefer 2024-2025 context.
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
json_struct=None,
|
||||
preferred_provider="huggingface",
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Analyze] Analysis failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Analysis failed: {exc}")
|
||||
|
||||
# Normalize response (accept dict or JSON string)
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="LLM returned non-JSON output")
|
||||
elif isinstance(raw, dict):
|
||||
data = raw
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unexpected LLM response format")
|
||||
|
||||
audience = data.get("audience") or "Growth-focused professionals"
|
||||
content_type = data.get("content_type") or "Interview + insights"
|
||||
top_keywords = data.get("top_keywords") or []
|
||||
suggested_outlines = data.get("suggested_outlines") or []
|
||||
title_suggestions = data.get("title_suggestions") or []
|
||||
research_queries = data.get("research_queries") or []
|
||||
exa_suggested_config = data.get("exa_suggested_config") or None
|
||||
|
||||
return PodcastAnalyzeResponse(
|
||||
audience=audience,
|
||||
content_type=content_type,
|
||||
top_keywords=top_keywords,
|
||||
suggested_outlines=suggested_outlines,
|
||||
title_suggestions=title_suggestions,
|
||||
research_queries=research_queries,
|
||||
exa_suggested_config=exa_suggested_config,
|
||||
bible=bible_obj.model_dump() if bible_obj else None,
|
||||
avatar_url=final_avatar_url,
|
||||
avatar_prompt=final_avatar_prompt,
|
||||
)
|
||||
|
||||
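Both endpoints above accept either a JSON string or an already-parsed dict from `llm_text_gen` and then fall back to canned values when fields are missing. A minimal sketch of that normalization for the enhance endpoint follows, using only the standard library. The helper names `fallback_ideas` and `normalize_enhanced` and the constant `FALLBACK_RATIONALES` are hypothetical and do not appear in the diff; the sketch also uses a single fallback rationale list, whereas the handler above uses two slightly different wordings.

```python
import json
from typing import Any, Dict, List

# Hypothetical constant: one shared fallback list (the handler uses two wordings).
FALLBACK_RATIONALES: List[str] = [
    "Professional angle with expert insights",
    "Storytelling angle with human interest",
    "Trendy angle with contemporary relevance",
]


def fallback_ideas(base_idea: str) -> List[str]:
    """Hypothetical helper: three canned variations of the original idea."""
    return [
        f"Expert insights on {base_idea}: A deep dive into industry trends and best practices.",
        f"The human side of {base_idea}: Personal stories and real-world experiences that resonate.",
        f"Modern perspectives on {base_idea}: Current trends and forward-thinking approaches.",
    ]


def normalize_enhanced(raw: Any, base_idea: str) -> Dict[str, List[str]]:
    """Hypothetical helper: coerce an LLM result into exactly 3 ideas and 3 rationales."""
    try:
        data = json.loads(raw) if isinstance(raw, str) else raw
    except json.JSONDecodeError:
        data = {}
    if not isinstance(data, dict):
        data = {}

    ideas = data.get("enhanced_ideas")
    rationales = data.get("rationales")

    if not isinstance(ideas, list) or len(ideas) != 3:
        # Ideas are unusable: replace both lists so they stay aligned.
        ideas = fallback_ideas(base_idea)
        rationales = list(FALLBACK_RATIONALES)
    elif not isinstance(rationales, list) or len(rationales) != 3:
        # Ideas are fine, only the rationales need a fallback.
        rationales = list(FALLBACK_RATIONALES)

    return {"enhanced_ideas": ideas, "rationales": rationales}
```

Keeping this logic in one pure function would let the happy path and the `except` branch share the same fallback text, and it can be unit-tested without calling the LLM.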
422
_session_backup/models.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
Podcast API Models
|
||||
|
||||
All Pydantic request/response models for podcast endpoints.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PodcastProjectResponse(BaseModel):
|
||||
"""Response model for podcast project."""
|
||||
id: int
|
||||
project_id: str
|
||||
user_id: str
|
||||
idea: str
|
||||
duration: int
|
||||
speakers: int
|
||||
budget_cap: float
|
||||
analysis: Optional[Dict[str, Any]] = None
|
||||
queries: Optional[List[Dict[str, Any]]] = None
|
||||
selected_queries: Optional[List[str]] = None
|
||||
research: Optional[Dict[str, Any]] = None
|
||||
raw_research: Optional[Dict[str, Any]] = None
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
script_data: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
render_jobs: Optional[List[Dict[str, Any]]] = None
|
||||
knobs: Optional[Dict[str, Any]] = None
|
||||
research_provider: Optional[str] = None
|
||||
show_script_editor: bool = False
|
||||
show_render_queue: bool = False
|
||||
current_step: Optional[str] = None
|
||||
status: str = "draft"
|
||||
is_favorite: bool = False
|
||||
final_video_url: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
avatar_prompt: Optional[str] = None
|
||||
avatar_persona_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
# Pydantic v2 style; the nested class-based Config is deprecated in v2.
|
||||
model_config = {"from_attributes": True}
|
||||
|
||||
|
||||
class PodcastAnalyzeRequest(BaseModel):
|
||||
"""Request model for podcast idea analysis."""
|
||||
idea: str = Field(..., description="Podcast topic or idea")
|
||||
duration: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
avatar_url: Optional[str] = Field(None, description="Current avatar URL if selected")
|
||||
feedback: Optional[str] = Field(None, description="User feedback for regeneration")
|
||||
|
||||
|
||||
class PodcastAnalyzeResponse(BaseModel):
|
||||
"""Response model for podcast idea analysis."""
|
||||
audience: str
|
||||
content_type: str
|
||||
top_keywords: list[str]
|
||||
suggested_outlines: list[Dict[str, Any]]
|
||||
title_suggestions: list[str]
|
||||
research_queries: Optional[List[Dict[str, str]]] = None
|
||||
exa_suggested_config: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
avatar_url: Optional[str] = None
|
||||
avatar_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaRequest(BaseModel):
|
||||
"""Request model for enhancing a podcast idea with AI."""
|
||||
idea: str = Field(..., description="The raw podcast idea or keywords")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaResponse(BaseModel):
|
||||
"""Response model for enhanced podcast idea."""
|
||||
enhanced_ideas: List[str] = Field(..., description="3 AI-enhanced topic choices")
|
||||
rationales: List[str] = Field(..., description="Rationale for each enhanced idea")
|
||||
|
||||
|
||||
class PodcastScriptRequest(BaseModel):
|
||||
"""Request model for podcast script generation."""
|
||||
idea: str = Field(..., description="Podcast idea or topic")
|
||||
duration_minutes: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
research: Optional[Dict[str, Any]] = Field(None, description="Optional research payload to ground the script")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
outline: Optional[Dict[str, Any]] = Field(None, description="The refined episode outline to follow")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="The full analysis context (audience, keywords, etc.)")
|
||||
|
||||
|
||||
class PodcastSceneLine(BaseModel):
|
||||
speaker: str
|
||||
text: str
|
||||
emphasis: Optional[bool] = False
|
||||
|
||||
|
||||
class PodcastScene(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
duration: int
|
||||
lines: list[PodcastSceneLine]
|
||||
approved: bool = False
|
||||
emotion: Optional[str] = None
|
||||
imageUrl: Optional[str] = None # Generated image URL for video generation
|
||||
|
||||
|
||||
class PodcastExaConfig(BaseModel):
|
||||
"""Exa config for podcast research."""
|
||||
exa_search_type: Optional[str] = Field(default="auto", description="auto | keyword | neural")
|
||||
exa_category: Optional[str] = None
|
||||
exa_include_domains: List[str] = []
|
||||
exa_exclude_domains: List[str] = []
|
||||
max_sources: int = 8
|
||||
include_statistics: Optional[bool] = False
|
||||
date_range: Optional[str] = Field(default=None, description="last_month | last_3_months | last_year | all_time")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_domains(self):
|
||||
if self.exa_include_domains and self.exa_exclude_domains:
|
||||
# Exa API does not allow both include and exclude domains together with contents
|
||||
# Prefer include_domains and drop exclude_domains
|
||||
self.exa_exclude_domains = []
|
||||
return self
|
||||
|
||||
|
||||
class PodcastExaResearchRequest(BaseModel):
|
||||
"""Request for podcast research using Exa directly (no blog writer)."""
|
||||
topic: str
|
||||
queries: List[str]
|
||||
exa_config: Optional[PodcastExaConfig] = None
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="Podcast analysis context (audience, content type, etc.)")
|
||||
|
||||
|
||||
class PodcastExaSource(BaseModel):
|
||||
title: str = ""
|
||||
url: str = ""
|
||||
excerpt: str = ""
|
||||
published_at: Optional[str] = None
|
||||
highlights: Optional[List[str]] = None
|
||||
summary: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
index: Optional[int] = None
|
||||
image: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastResearchInsight(BaseModel):
|
||||
"""Deep insight extracted from research."""
|
||||
title: str
|
||||
content: str
|
||||
source_indices: List[int] = []
|
||||
|
||||
|
||||
class PodcastExaResearchResponse(BaseModel):
|
||||
sources: List[PodcastExaSource]
|
||||
search_queries: List[str] = []
|
||||
summary: str = ""
|
||||
key_insights: List[PodcastResearchInsight] = []
|
||||
expert_quotes: List[Dict[str, Any]] = []
|
||||
listener_cta: List[str] = []
|
||||
mapped_angles: List[Dict[str, Any]] = []
|
||||
cost: Optional[Dict[str, Any]] = None
|
||||
search_type: Optional[str] = None
|
||||
provider: str = "exa"
|
||||
content: Optional[str] = None # Raw aggregated content (deprecated)
|
||||
|
||||
|
||||
class PodcastScriptResponse(BaseModel):
|
||||
scenes: list[PodcastScene]
|
||||
|
||||
|
||||
class PodcastAudioRequest(BaseModel):
|
||||
"""Generate TTS for a podcast scene."""
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
text: str
|
||||
voice_id: Optional[str] = "Wise_Woman"
|
||||
speed: Optional[float] = 1.0
|
||||
volume: Optional[float] = 1.0
|
||||
pitch: Optional[float] = 0.0
|
||||
emotion: Optional[str] = "neutral"
|
||||
english_normalization: Optional[bool] = False # Better number reading for statistics
|
||||
sample_rate: Optional[int] = None
|
||||
bitrate: Optional[int] = None
|
||||
channel: Optional[str] = None
|
||||
format: Optional[str] = None
|
||||
language_boost: Optional[str] = None
|
||||
enable_sync_mode: Optional[bool] = True
|
||||
|
||||
|
||||
class PodcastAudioResponse(BaseModel):
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
audio_filename: str
|
||||
audio_url: str
|
||||
provider: str
|
||||
model: str
|
||||
voice_id: str
|
||||
text_length: int
|
||||
file_size: int
|
||||
cost: float
|
||||
|
||||
|
||||
class PodcastProjectListResponse(BaseModel):
|
||||
"""Response model for project list."""
|
||||
projects: List[PodcastProjectResponse]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class CreateProjectRequest(BaseModel):
|
||||
"""Request model for creating a project."""
|
||||
project_id: str = Field(..., description="Unique project ID")
|
||||
idea: str = Field(..., description="Episode idea or URL")
|
||||
duration: int = Field(..., description="Duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
budget_cap: float = Field(default=50.0, description="Budget cap in USD")
|
||||
avatar_url: Optional[str] = Field(None, description="Optional presenter avatar URL")
|
||||
|
||||
|
||||
class UpdateProjectRequest(BaseModel):
|
||||
"""Request model for updating project state."""
|
||||
analysis: Optional[Dict[str, Any]] = None
|
||||
queries: Optional[List[Dict[str, Any]]] = None
|
||||
selected_queries: Optional[List[str]] = None
|
||||
research: Optional[Dict[str, Any]] = None
|
||||
raw_research: Optional[Dict[str, Any]] = None
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
script_data: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
render_jobs: Optional[List[Dict[str, Any]]] = None
|
||||
knobs: Optional[Dict[str, Any]] = None
|
||||
research_provider: Optional[str] = None
|
||||
show_script_editor: Optional[bool] = None
|
||||
show_render_queue: Optional[bool] = None
|
||||
current_step: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
final_video_url: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastCombineAudioRequest(BaseModel):
|
||||
"""Request model for combining podcast audio files."""
|
||||
project_id: str
|
||||
scene_ids: List[str] = Field(..., description="List of scene IDs to combine")
|
||||
scene_audio_urls: List[str] = Field(..., description="List of audio URLs for each scene")
|
||||
|
||||
|
||||
class PodcastCombineAudioResponse(BaseModel):
|
||||
"""Response model for combined podcast audio."""
|
||||
combined_audio_url: str
|
||||
combined_audio_filename: str
|
||||
total_duration: float
|
||||
file_size: int
|
||||
scene_count: int
|
||||
|
||||
|
||||
class PodcastImageRequest(BaseModel):
|
||||
"""Request for generating an image for a podcast scene."""
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
scene_content: Optional[str] = None # Optional: scene lines text for context
|
||||
idea: Optional[str] = None # Optional: podcast idea for context
|
||||
base_avatar_url: Optional[str] = None # Base avatar image URL for scene variations
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
custom_prompt: Optional[str] = None # Custom prompt from user (overrides auto-generated prompt)
|
||||
style: Optional[str] = None # "Auto", "Fiction", or "Realistic"
|
||||
rendering_speed: Optional[str] = None # "Default", "Turbo", or "Quality"
|
||||
aspect_ratio: Optional[str] = None # "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
|
||||
|
||||
class PodcastImageResponse(BaseModel):
|
||||
"""Response for podcast scene image generation."""
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
image_filename: str
|
||||
image_url: str
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
cost: float
|
||||
|
||||
|
||||
class PodcastVideoGenerationRequest(BaseModel):
|
||||
"""Request model for podcast video generation."""
|
||||
project_id: str = Field(..., description="Podcast project ID")
|
||||
scene_id: str = Field(..., description="Scene ID")
|
||||
scene_title: str = Field(..., description="Scene title")
|
||||
audio_url: str = Field(..., description="URL to the generated audio file")
|
||||
avatar_image_url: Optional[str] = Field(None, description="URL to scene image (required for video generation)")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
resolution: str = Field("720p", description="Video resolution (480p or 720p)")
|
||||
prompt: Optional[str] = Field(None, description="Optional animation prompt override")
|
||||
seed: Optional[int] = Field(-1, description="Random seed; -1 for random")
|
||||
mask_image_url: Optional[str] = Field(None, description="Optional mask image URL to specify animated region")
|
||||
|
||||
|
||||
class PodcastVideoGenerationResponse(BaseModel):
|
||||
"""Response model for podcast video generation."""
|
||||
task_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class PodcastCombineVideosRequest(BaseModel):
|
||||
"""Request to combine scene videos into final podcast"""
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
scene_video_urls: list[str] = Field(..., description="List of scene video URLs in order")
|
||||
podcast_title: str = Field(default="Podcast", description="Title for the final podcast video")
|
||||
|
||||
|
||||
class PodcastCombineVideosResponse(BaseModel):
|
||||
"""Response from combine videos endpoint"""
|
||||
task_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class AudioDubbingQuality(str, Enum):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "AudioDubbingQuality":
|
||||
if value.lower() == "high":
|
||||
return cls.HIGH
|
||||
return cls.LOW
|
||||
|
||||
|
||||
class PodcastAudioDubRequest(BaseModel):
|
||||
"""Request model for audio dubbing."""
|
||||
source_audio_url: str = Field(..., description="URL or path to source audio file")
|
||||
source_language: Optional[str] = Field(None, description="Source language code (auto-detected if None)")
|
||||
target_language: str = Field(..., description="Target language for dubbing")
|
||||
quality: str = Field(default="low", description="Translation quality: low (DeepL) or high (WaveSpeed)")
|
||||
voice_id: Optional[str] = Field(default="Wise_Woman", description="Voice ID for TTS")
|
||||
speed: Optional[float] = Field(default=1.0, ge=0.5, le=2.0, description="Speech speed (0.5-2.0)")
|
||||
emotion: Optional[str] = Field(default="happy", description="Emotion for TTS voice")
|
||||
preserve_emotion: Optional[bool] = Field(default=True, description="Preserve emotional tone in translation")
|
||||
use_voice_clone: Optional[bool] = Field(default=False, description="Use voice cloning to preserve original speaker's voice")
|
||||
custom_voice_id: Optional[str] = Field(None, description="Custom name for the cloned voice")
|
||||
voice_clone_accuracy: Optional[float] = Field(default=0.7, ge=0.1, le=1.0, description="Voice cloning accuracy (0.1-1.0)")
|
||||
|
||||
|
||||
class PodcastAudioDubResponse(BaseModel):
|
||||
"""Response model for audio dubbing task creation."""
|
||||
task_id: str
|
||||
status: str = "pending"
|
||||
message: str = "Audio dubbing task created"
|
||||
|
||||
|
||||
class PodcastAudioDubResult(BaseModel):
|
||||
"""Response model for completed audio dubbing."""
|
||||
dubbed_audio_url: str
|
||||
dubbed_audio_filename: str
|
||||
original_transcript: str
|
||||
translated_transcript: str
|
||||
source_language: str
|
||||
target_language: str
|
||||
voice_id: str
|
||||
quality: str
|
||||
duration_seconds: int
|
||||
file_size: int
|
||||
cost: float
|
||||
task_id: str
|
||||
status: str = "completed"
|
||||
voice_clone_used: Optional[bool] = Field(default=False, description="Whether voice cloning was used")
|
||||
cloned_voice_id: Optional[str] = Field(None, description="ID of the cloned voice if voice_clone_used=True")
|
||||
|
||||
|
||||
class PodcastAudioDubEstimateRequest(BaseModel):
|
||||
"""Request model for dubbing cost estimation."""
|
||||
audio_duration_seconds: float = Field(..., description="Duration of source audio in seconds")
|
||||
target_language: str = Field(..., description="Target language")
|
||||
quality: str = Field(default="low", description="Translation quality")
|
||||
use_voice_clone: Optional[bool] = Field(default=False, description="Include voice cloning cost")
|
||||
|
||||
|
||||
class PodcastAudioDubEstimateResponse(BaseModel):
|
||||
"""Response model for dubbing cost estimation."""
|
||||
estimated_characters: int
|
||||
translation_cost: float
|
||||
tts_cost: float
|
||||
voice_clone_cost: float = 0.0
|
||||
total_cost: float
|
||||
currency: str = "USD"
|
||||
|
||||
|
||||
class VoiceCloneRequest(BaseModel):
|
||||
"""Request model for voice cloning."""
|
||||
source_audio_url: str = Field(..., description="URL or path to source audio file (10-60 seconds recommended)")
|
||||
custom_voice_id: Optional[str] = Field(None, description="Custom name for the cloned voice")
|
||||
accuracy: Optional[float] = Field(default=0.7, ge=0.1, le=1.0, description="Cloning accuracy (0.1-1.0)")
|
||||
language_boost: Optional[str] = Field(None, description="Language to optimize the voice for")
|
||||
|
||||
|
||||
class VoiceCloneResponse(BaseModel):
|
||||
"""Response model for voice cloning."""
|
||||
task_id: str
|
||||
status: str = "pending"
|
||||
message: str = "Voice cloning task created"
|
||||
|
||||
|
||||
class VoiceCloneResult(BaseModel):
|
||||
"""Response model for completed voice cloning."""
|
||||
voice_id: str
|
||||
voice_url: str
|
||||
source_language: str
|
||||
accuracy: float
|
||||
file_size: int
|
||||
task_id: str
|
||||
status: str = "completed"
|
||||
|
||||
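The `validate_domains` validator on `PodcastExaConfig` above encodes one rule: when both include and exclude domain lists are set, keep the include list and drop the exclude list (the comment in the model attributes this to an Exa API restriction). The sketch below restates that rule with a standard-library dataclass so it can be run without Pydantic. The class name `ExaDomainFilter` and its field names are hypothetical and are not part of the diff.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ExaDomainFilter:
    """Hypothetical stand-in for the domain fields of PodcastExaConfig."""

    include_domains: List[str] = field(default_factory=list)
    exclude_domains: List[str] = field(default_factory=list)

    def __post_init__(self) -> None:
        # Same rule as validate_domains: prefer include, drop exclude.
        if self.include_domains and self.exclude_domains:
            self.exclude_domains = []


both = ExaDomainFilter(include_domains=["example.org"], exclude_domains=["example.com"])
only_exclude = ExaDomainFilter(exclude_domains=["example.com"])
```

In this sketch `both.exclude_domains` ends up empty while `only_exclude.exclude_domains` is left untouched, which mirrors how the model silently resolves the conflict instead of rejecting the request.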
837
_session_backup/podcastApi.ts
Normal file
@@ -0,0 +1,837 @@
|
||||
import { ResearchProvider, ResearchConfig } from "./blogWriterApi";
|
||||
import {
|
||||
storyWriterApi,
|
||||
StorySetupGenerationResponse,
|
||||
} from "./storyWriterApi";
|
||||
import { getResearchConfig, ResearchPersona } from "../api/researchConfig";
|
||||
import { aiApiClient } from "../api/client";
|
||||
import {
|
||||
CreateProjectPayload,
|
||||
CreateProjectResult,
|
||||
Fact,
|
||||
Knobs,
|
||||
PodcastAnalysis,
|
||||
PodcastEstimate,
|
||||
Query,
|
||||
RenderJobResult,
|
||||
Research,
|
||||
Scene,
|
||||
Script,
|
||||
} from "../components/PodcastMaker/types";
|
||||
import { checkPreflight, PreflightOperation } from "./billingService";
|
||||
import { TaskStatus } from "./storyWriterApi";
|
||||
|
||||
const DEFAULT_KNOBS: Knobs = {
|
||||
voice_emotion: "neutral",
|
||||
voice_speed: 1,
|
||||
resolution: "720p",
|
||||
scene_length_target: 45,
|
||||
sample_rate: 24000,
|
||||
bitrate: "standard",
|
||||
};
|
||||
|
||||
// const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
|
||||
const createId = (prefix: string) => {
|
||||
if (typeof crypto !== "undefined" && typeof crypto.randomUUID === "function") {
|
||||
return `${prefix}_${crypto.randomUUID()}`;
|
||||
}
|
||||
return `${prefix}_${Date.now()}_${Math.floor(Math.random() * 10000)}`;
|
||||
};
|
||||
|
||||
type OptionLike = StorySetupGenerationResponse["options"][0] | { plot_elements?: string; premise?: string };
|
||||
|
||||
const deriveSegments = (option?: OptionLike): string[] => {
|
||||
const segments: string[] = [];
|
||||
if (option?.plot_elements) {
|
||||
option.plot_elements
|
||||
.split(/[,.;]+/)
|
||||
.map((p) => p.trim())
|
||||
.filter(Boolean)
|
||||
.forEach((p) => segments.push(p));
|
||||
}
|
||||
if (!segments.length && "premise" in (option || {}) && (option as any)?.premise) {
|
||||
segments.push("Intro", "Key Takeaways", "Examples", "CTA");
|
||||
}
|
||||
return segments.slice(0, 5);
|
||||
};
|
||||
|
||||
const estimateCosts = ({
|
||||
minutes,
|
||||
scenes,
|
||||
chars,
|
||||
quality,
|
||||
avatars,
|
||||
queryCount = 3,
|
||||
}: {
|
||||
minutes: number;
|
||||
scenes: number;
|
||||
chars: number;
|
||||
quality: string;
|
||||
avatars: number;
|
||||
queryCount?: number;
|
||||
}): PodcastEstimate => {
|
||||
const secs = Math.max(60, minutes * 60);
|
||||
const ttsCost = (chars / 1000) * 0.05;
|
||||
const avatarCost = avatars * 0.15;
|
||||
const videoRate = quality === "hd" ? 0.06 : 0.03;
|
||||
const videoCost = secs * videoRate;
|
||||
const researchCost = +(Math.max(1, queryCount) * 0.1).toFixed(2);
|
||||
const total = +(ttsCost + avatarCost + videoCost + researchCost).toFixed(2);
|
||||
return {
|
||||
ttsCost: +ttsCost.toFixed(2),
|
||||
avatarCost: +avatarCost.toFixed(2),
|
||||
videoCost: +videoCost.toFixed(2),
|
||||
researchCost,
|
||||
total,
|
||||
};
|
||||
};
|
||||
|
||||
const mapPersonaQueries = (persona: ResearchPersona | undefined, seed: string): Query[] => {
|
||||
const baseIdea = seed || "AI marketing for small businesses";
|
||||
const personaKeywords = persona?.suggested_keywords?.filter(Boolean) || [];
|
||||
const angles = persona?.research_angles ?? [];
|
||||
const generated: Query[] = [];
|
||||
|
||||
const addQuery = (q: string, why: string, needsRecent = false) => {
|
||||
if (!q.trim()) return;
|
||||
generated.push({
|
||||
id: createId("q"),
|
||||
query: q.trim(),
|
||||
rationale: why,
|
||||
needsRecentStats: needsRecent,
|
||||
});
|
||||
};
|
||||
|
||||
if (personaKeywords.length) {
|
||||
personaKeywords.slice(0, 4).forEach((k, idx) =>
|
||||
addQuery(k, angles[idx % Math.max(1, angles.length)] || "Persona-aligned query", /202[45]|latest|trend/i.test(k))
|
||||
);
|
||||
}
|
||||
|
||||
if (!generated.length) {
|
||||
addQuery(`How is ${baseIdea} evolving in 2024?`, "Trend + outcome focus", true);
|
||||
addQuery(`Best practices for ${baseIdea}`, "Actionable guidance", false);
|
||||
addQuery(`${baseIdea} case studies with ROI`, "Proof and outcomes", true);
|
||||
addQuery(`${baseIdea} risks and objections`, "Address listener concerns", false);
|
||||
}
|
||||
|
||||
return generated.slice(0, 6);
|
||||
};
|
||||
|
||||
const mapSourcesToFacts = (sources: ExaSource[]): Fact[] => {
|
||||
if (!sources || !sources.length) return [];
|
||||
return sources.slice(0, 12).map((source: ExaSource, idx: number) => ({
|
||||
id: source.url || createId("fact"),
|
||||
quote: source.excerpt || source.title || "Insight",
|
||||
url: source.url || "",
|
||||
date: source.published_at || "Unknown",
|
||||
confidence: typeof (source as any).credibility_score === "number" ? (source as any).credibility_score : Math.max(0.5, 0.85 - idx * 0.02),
|
||||
image: source.image,
|
||||
author: source.author,
|
||||
highlights: source.highlights,
|
||||
}));
|
||||
};
|
||||
|
||||
type ExaSource = {
|
||||
title?: string;
|
||||
url?: string;
|
||||
excerpt?: string;
|
||||
published_at?: string;
|
||||
highlights?: string[];
|
||||
summary?: string;
|
||||
source_type?: string;
|
||||
index?: number;
|
||||
image?: string;
|
||||
author?: string;
|
||||
};
|
||||
|
||||
type ExaResearchResult = {
|
||||
sources: ExaSource[];
|
||||
search_queries?: string[];
|
||||
cost?: { total?: number };
|
||||
search_type?: string;
|
||||
provider?: string;
|
||||
content?: string;
|
||||
};
|
||||
|
||||
const mapExaResearchResponse = (response: any): Research => {
|
||||
const factCards = mapSourcesToFacts(response.sources);
|
||||
// Use backend summary if available, otherwise use full content (no truncation) or fallback text
|
||||
const summary = response.summary || response.content || "Research completed.";
|
||||
|
||||
const keyInsights = (response.key_insights || []).map((insight: any) => ({
|
||||
title: insight.title || "Insight",
|
||||
content: insight.content || "",
|
||||
source_indices: insight.source_indices || []
|
||||
}));
|
||||
|
||||
const expertQuotes = (response.expert_quotes || []).map((eq: any) => ({
|
||||
quote: eq.quote || eq.text || "",
|
||||
source_index: eq.source_index ?? 0
|
||||
}));
|
||||
|
||||
const listenerCta = response.listener_cta || [];
|
||||
|
||||
const mappedAngles = (response.mapped_angles || []).map((angle: any) => ({
|
||||
title: angle.title || "",
|
||||
why: angle.why || angle.rationale || "",
|
||||
mappedFactIds: angle.mapped_fact_ids || angle.mappedFactIds || []
|
||||
}));
|
||||
|
||||
return {
|
||||
summary,
|
||||
keyInsights,
|
||||
factCards,
|
||||
mappedAngles,
|
||||
expertQuotes,
|
||||
listenerCta,
|
||||
searchQueries: response.search_queries,
|
||||
searchType: response.search_type,
|
||||
provider: response.provider || "exa",
|
||||
cost: response.cost?.total,
|
||||
sourceCount: response.sources?.length || 0,
|
||||
};
|
||||
};
|
||||
|
||||
const ensurePreflight = async (operation: PreflightOperation) => {
|
||||
const result = await checkPreflight(operation);
|
||||
if (!result.can_proceed) {
|
||||
const message = result.operations[0]?.message || "Pre-flight validation failed";
|
||||
throw new Error(message);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
export const podcastApi = {
|
||||
async createProject(payload: CreateProjectPayload, bible?: any, feedback?: string): Promise<CreateProjectResult> {
|
||||
const storyIdea = payload.ideaOrUrl || "AI marketing for small businesses";
|
||||
|
||||
await ensurePreflight({
|
||||
provider: "gemini",
|
||||
operation_type: "podcast_analysis",
|
||||
tokens_requested: 1500,
|
||||
actual_provider_name: "gemini",
|
||||
});
|
||||
|
||||
// Podcast-specific analysis (not story setup)
|
||||
const analysisResp = await aiApiClient.post("/api/podcast/analyze", {
|
||||
idea: storyIdea,
|
||||
duration: payload.duration,
|
||||
speakers: payload.speakers,
|
||||
bible: bible,
|
||||
avatar_url: payload.avatarUrl,
|
||||
feedback: feedback, // Pass feedback to backend
|
||||
});
|
||||
|
||||
const outlines = (analysisResp.data?.suggested_outlines || []).map((o: any, idx: number) => ({
|
||||
id: o.id || `outline-${idx + 1}`,
|
||||
title: o.title || `Outline ${idx + 1}`,
|
||||
segments: Array.isArray(o.segments) ? o.segments : deriveSegments({ plot_elements: o.segments }),
|
||||
}));
|
||||
|
||||
const analysis: PodcastAnalysis = {
|
||||
audience: analysisResp.data?.audience || "Growth-minded pros",
|
||||
contentType: analysisResp.data?.content_type || "Podcast interview",
|
||||
topKeywords: analysisResp.data?.top_keywords || outlines[0]?.segments?.slice(0, 3) || [],
|
||||
suggestedOutlines: outlines,
|
||||
suggestedKnobs: { ...DEFAULT_KNOBS, ...payload.knobs },
|
||||
titleSuggestions: (analysisResp.data?.title_suggestions || []).filter(Boolean),
|
||||
research_queries: analysisResp.data?.research_queries || [],
|
||||
exaSuggestedConfig: analysisResp.data?.exa_suggested_config || undefined,
|
||||
};
|
||||
|
||||
const researchConfig = await getResearchConfig().catch(() => null);
|
||||
|
||||
// Use AI-generated queries if available, fallback to legacy mapping
|
||||
let queries: Query[] = [];
|
||||
if (analysis.research_queries && analysis.research_queries.length > 0) {
|
||||
queries = analysis.research_queries.map(rq => ({
|
||||
id: createId("q"),
|
||||
query: rq.query,
|
||||
rationale: rq.rationale,
|
||||
needsRecentStats: /202[45]|latest|trend/i.test(rq.query)
|
||||
}));
|
||||
} else {
|
||||
queries = mapPersonaQueries(researchConfig?.research_persona, storyIdea);
|
||||
}
|
||||
|
||||
const projectId = createId("podcast");
|
||||
const estimate = estimateCosts({
|
||||
minutes: payload.duration,
|
||||
scenes: Math.ceil((payload.duration * 60) / (payload.knobs.scene_length_target || DEFAULT_KNOBS.scene_length_target)),
|
||||
chars: Math.max(1000, payload.duration * 900),
|
||||
quality: payload.knobs.bitrate || "standard",
|
||||
avatars: payload.speakers,
|
||||
queryCount: queries.length || 3,
|
||||
});
|
||||
|
||||
return {
|
||||
projectId,
|
||||
analysis,
|
||||
estimate,
|
||||
queries,
|
||||
bible: analysisResp.data?.bible || undefined,
|
||||
avatar_url: analysisResp.data?.avatar_url || null,
|
||||
avatar_prompt: analysisResp.data?.avatar_prompt || null,
|
||||
};
|
||||
},
|
||||
|
||||
async enhanceIdea(params: { idea: string; bible?: any }): Promise<{ enhanced_ideas: string[]; rationales: string[] }> {
|
||||
const response = await aiApiClient.post("/api/podcast/idea/enhance", params);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async runResearch(params: {
|
||||
projectId: string;
|
||||
topic: string;
|
||||
approvedQueries: Query[];
|
||||
provider?: ResearchProvider;
|
||||
exaConfig?: ResearchConfig;
|
||||
bible?: any;
|
||||
analysis?: PodcastAnalysis | null;
|
||||
onProgress?: (message: string) => void;
|
||||
}): Promise<{ research: Research; raw: any }> {
|
||||
const keywords = params.approvedQueries.map((q) => q.query).filter(Boolean);
|
||||
if (!keywords.length) {
|
||||
throw new Error("At least one query must be approved for research.");
|
||||
}
|
||||
|
||||
// Exa API constraint: when requesting contents, only one of includeDomains or excludeDomains may be sent. If both are set, include wins and exclude is dropped.
|
||||
let sanitizedExaConfig: ResearchConfig | undefined = params.exaConfig;
|
||||
if (sanitizedExaConfig && sanitizedExaConfig.exa_include_domains?.length) {
|
||||
sanitizedExaConfig = {
|
||||
...sanitizedExaConfig,
|
||||
exa_exclude_domains: undefined,
|
||||
};
|
||||
} else if (sanitizedExaConfig && sanitizedExaConfig.exa_exclude_domains?.length) {
|
||||
sanitizedExaConfig = {
|
||||
...sanitizedExaConfig,
|
||||
exa_include_domains: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
await ensurePreflight({
|
||||
provider: "exa",
|
||||
operation_type: "exa_neural_search",
|
||||
tokens_requested: 0,
|
||||
actual_provider_name: "exa",
|
||||
});
|
||||
|
||||
const response = await aiApiClient.post("/api/podcast/research/exa", {
|
||||
topic: params.topic || keywords[0],
|
||||
queries: keywords,
|
||||
exa_config: sanitizedExaConfig,
|
||||
bible: params.bible,
|
||||
analysis: params.analysis,
|
||||
});
|
||||
|
||||
const exaResult = response.data as ExaResearchResult;
|
||||
if (params.onProgress) {
|
||||
params.onProgress("Deep research completed with Exa.");
|
||||
}
|
||||
const mapped = mapExaResearchResponse(exaResult);
|
||||
return { research: mapped, raw: exaResult };
|
||||
},
|
||||
|
||||
async generateScript(params: {
|
||||
projectId: string;
|
||||
idea: string;
|
||||
research?: ExaResearchResult | null;
|
||||
knobs: Knobs;
|
||||
speakers: number;
|
||||
durationMinutes: number;
|
||||
bible?: any;
|
||||
outline?: any;
|
||||
analysis?: PodcastAnalysis | null;
|
||||
}): Promise<Script> {
|
||||
await ensurePreflight({
|
||||
provider: "gemini",
|
||||
operation_type: "script_generation",
|
||||
tokens_requested: 2000,
|
||||
actual_provider_name: "gemini",
|
||||
});
|
||||
|
||||
const response = await aiApiClient.post("/api/podcast/script", {
|
||||
idea: params.idea,
|
||||
duration_minutes: params.durationMinutes,
|
||||
speakers: params.speakers,
|
||||
research: params.research,
|
||||
bible: params.bible,
|
||||
outline: params.outline,
|
||||
analysis: params.analysis,
|
||||
});
|
||||
|
||||
const scenes = response.data?.scenes || [];
|
||||
const scriptScenes: Scene[] = scenes.map((scene: any) => ({
|
||||
id: scene.id || createId("scene"),
|
||||
title: scene.title || "Scene",
|
||||
duration: scene.duration || Math.max(20, params.knobs.scene_length_target || DEFAULT_KNOBS.scene_length_target),
|
||||
lines:
|
||||
Array.isArray(scene.lines) && scene.lines.length
|
||||
? scene.lines.map((l: any) => ({
|
||||
id: createId("line"),
|
||||
speaker: l.speaker || "Host",
|
||||
text: l.text || "",
|
||||
}))
|
||||
: [
|
||||
{
|
||||
id: createId("line"),
|
||||
speaker: "Host",
|
||||
text: "Let's dive into today's topic.",
|
||||
},
|
||||
],
|
||||
approved: false,
|
||||
}));
|
||||
|
||||
return { scenes: scriptScenes };
|
||||
},
|
||||
|
||||
async previewLine(
|
||||
text: string,
|
||||
options: { voiceId?: string; speed?: number; emotion?: string } = {}
|
||||
): Promise<{ ok: boolean; message: string; audioUrl?: string }> {
|
||||
await ensurePreflight({
|
||||
provider: "audio",
|
||||
operation_type: "tts_preview",
|
||||
tokens_requested: text.length,
|
||||
actual_provider_name: "wavespeed",
|
||||
});
|
||||
|
||||
const response = await storyWriterApi.generateAIAudio({
|
||||
scene_number: 0,
|
||||
scene_title: "Preview",
|
||||
text,
|
||||
voice_id: options.voiceId || "Wise_Woman",
|
||||
speed: options.speed || 1.0,
|
||||
emotion: options.emotion || "neutral",
|
||||
});
|
||||
|
||||
if (!response.success) {
|
||||
throw new Error(response.error || "Preview failed");
|
||||
}
|
||||
|
||||
return {
|
||||
ok: true,
|
||||
message: "Preview ready – opening audio in new tab.",
|
||||
audioUrl: response.audio_url,
|
||||
};
|
||||
},
|
||||
|
||||
async renderSceneAudio(params: {
|
||||
scene: Scene;
|
||||
voiceId?: string;
|
||||
emotion?: string; // Fallback if scene doesn't have emotion
|
||||
speed?: number;
|
||||
volume?: number;
|
||||
pitch?: number;
|
||||
englishNormalization?: boolean;
|
||||
sampleRate?: number;
|
||||
bitrate?: number;
|
||||
channel?: "1" | "2";
|
||||
format?: "mp3" | "wav" | "pcm" | "flac";
|
||||
languageBoost?: string;
|
||||
}): Promise<RenderJobResult> {
|
||||
// Use scene-specific emotion if available, otherwise fallback to provided/default
|
||||
const sceneEmotion = params.scene.emotion || params.emotion || "neutral";
|
||||
|
||||
// Optimize text for Minimax Speech-02-HD TTS
|
||||
// - Strip markdown formatting (bold, italic, etc.) - TTS reads it literally
|
||||
// - Use pause markers <#x#> for natural speech rhythm
|
||||
// - Add longer pauses for speaker changes
|
||||
// - Preserve punctuation for natural breathing
|
||||
// - Add emphasis pauses for important points
|
||||
const text = params.scene.lines
|
||||
.map((line, idx) => {
|
||||
let lineText = line.text.trim();
|
||||
|
||||
// Strip markdown formatting - TTS reads asterisks and other markdown literally
|
||||
// Remove bold (**text** or __text__) and italic (*text* or _text_) markers
|
||||
lineText = lineText.replace(/\*\*([^*]+)\*\*/g, '$1'); // **bold**
|
||||
lineText = lineText.replace(/\*([^*]+)\*/g, '$1'); // *italic* (single asterisk)
|
||||
lineText = lineText.replace(/__([^_]+)__/g, '$1'); // __bold__
|
||||
lineText = lineText.replace(/_([^_]+)_/g, '$1'); // _italic_ (single underscore)
|
||||
// Remove any remaining stray asterisks or underscores
|
||||
lineText = lineText.replace(/\*+/g, ''); // Remove any remaining asterisks
|
||||
lineText = lineText.replace(/_+/g, ''); // Remove any remaining underscores
|
||||
// Clean up extra spaces
|
||||
lineText = lineText.replace(/\s+/g, ' ').trim();
|
||||
|
||||
// Preserve punctuation (Minimax uses it for natural breathing)
|
||||
// Don't strip punctuation - it helps TTS understand natural pauses
|
||||
|
||||
// Add emphasis pause after lines marked with emphasis
|
||||
if (line.emphasis) {
|
||||
// Minimal pause after emphasized content (0.15s for subtle emphasis)
|
||||
lineText = `${lineText}<#0.15#>`;
|
||||
}
|
||||
|
||||
// Check for speaker change (longer pause for natural conversation flow)
|
||||
const prevLine = idx > 0 ? params.scene.lines[idx - 1] : null;
|
||||
const isSpeakerChange = prevLine && prevLine.speaker !== line.speaker;
|
||||
|
||||
if (isSpeakerChange) {
|
||||
// Short pause for speaker changes (0.2s - enough for natural transition)
|
||||
lineText = `<#0.2#>${lineText}`;
|
||||
}
|
||||
|
||||
// Add minimal pause between lines (only between regular lines, very short)
|
||||
if (idx < params.scene.lines.length - 1) {
|
||||
if (!line.emphasis && !isSpeakerChange) {
|
||||
// Very short pause between lines (0.08s - barely noticeable but helps flow)
|
||||
lineText = `${lineText}<#0.08#>`;
|
||||
}
|
||||
// If emphasis or speaker change, the pause is already added above
|
||||
}
|
||||
|
||||
return lineText;
|
||||
})
|
||||
.join(" ");
|
||||
|
||||
// Validate character limit (Minimax max: 10,000 characters)
|
||||
const MAX_CHARS = 10000;
|
||||
let textToUse = text;
|
||||
if (text.length > MAX_CHARS) {
|
||||
console.warn(
|
||||
`[Podcast] Scene "${params.scene.title}" exceeds ${MAX_CHARS} character limit (${text.length} chars). Truncating...`
|
||||
);
|
||||
// Truncate at word boundary to avoid cutting mid-word
|
||||
const truncated = text.substring(0, MAX_CHARS);
|
||||
const lastSpace = truncated.lastIndexOf(" ");
|
||||
textToUse = lastSpace > 0 ? truncated.substring(0, lastSpace) : truncated;
|
||||
}
|
||||
|
||||
await ensurePreflight({
|
||||
provider: "audio",
|
||||
operation_type: "tts_full_render",
|
||||
tokens_requested: textToUse.length,
|
||||
actual_provider_name: "wavespeed",
|
||||
});
|
||||
|
||||
const response = await aiApiClient.post("/api/podcast/audio", {
|
||||
scene_id: params.scene.id,
|
||||
scene_title: params.scene.title,
|
||||
text: textToUse,
|
||||
voice_id: params.voiceId || "Wise_Woman",
|
||||
speed: params.speed ?? 1.0, // Normal speed (was 0.9, but too slow - causing duration issues)
|
||||
volume: params.volume ?? 1.0,
|
||||
pitch: params.pitch ?? 0.0,
|
||||
emotion: sceneEmotion,
|
||||
english_normalization: params.englishNormalization ?? true, // Better number reading for statistics
|
||||
sample_rate: params.sampleRate || null,
|
||||
bitrate: params.bitrate || null,
|
||||
channel: params.channel || null,
|
||||
format: params.format || null,
|
||||
language_boost: params.languageBoost || null,
|
||||
});
|
||||
|
||||
return {
|
||||
audioUrl: response.data.audio_url,
|
||||
audioFilename: response.data.audio_filename,
|
||||
provider: response.data.provider,
|
||||
model: response.data.model,
|
||||
cost: response.data.cost,
|
||||
voiceId: response.data.voice_id,
|
||||
fileSize: response.data.file_size,
|
||||
};
|
||||
},
|
||||
|
||||
async approveScene(params: { projectId: string; sceneId: string; notes?: string }) {
|
||||
await aiApiClient.post("/api/story/script/approve", {
|
||||
project_id: params.projectId,
|
||||
scene_id: params.sceneId,
|
||||
approved: true,
|
||||
notes: params.notes,
|
||||
});
|
||||
},
|
||||
|
||||
// Project persistence endpoints
|
||||
async saveProject(projectId: string, state: any): Promise<void> {
|
||||
try {
|
||||
await aiApiClient.put(`/api/podcast/projects/${projectId}`, state);
|
||||
} catch (error) {
|
||||
console.error("Failed to save project to database:", error);
|
||||
// Don't throw - localStorage fallback is acceptable
|
||||
}
|
||||
},
|
||||
|
||||
async loadProject(projectId: string): Promise<any> {
|
||||
const response = await aiApiClient.get(`/api/podcast/projects/${projectId}`);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async listProjects(params?: {
|
||||
status?: string;
|
||||
favorites_only?: boolean;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
order_by?: "updated_at" | "created_at";
|
||||
}): Promise<{ projects: any[]; total: number; limit: number; offset: number }> {
|
||||
const response = await aiApiClient.get("/api/podcast/projects", { params });
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async createProjectInDb(params: {
|
||||
project_id: string;
|
||||
idea: string;
|
||||
duration: number;
|
||||
speakers: number;
|
||||
budget_cap: number;
|
||||
avatar_url?: string | null;
|
||||
}): Promise<any> {
|
||||
const response = await aiApiClient.post("/api/podcast/projects", params);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async updateProject(projectId: string, updates: any): Promise<any> {
|
||||
const response = await aiApiClient.put(`/api/podcast/projects/${projectId}`, updates);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async deleteProject(projectId: string): Promise<void> {
|
||||
await aiApiClient.delete(`/api/podcast/projects/${projectId}`);
|
||||
},
|
||||
|
||||
async toggleFavorite(projectId: string): Promise<any> {
|
||||
const response = await aiApiClient.post(`/api/podcast/projects/${projectId}/favorite`);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async saveAudioToAssetLibrary(params: {
|
||||
audioUrl: string;
|
||||
filename: string;
|
||||
title: string;
|
||||
description?: string;
|
||||
projectId: string;
|
||||
sceneId?: string;
|
||||
cost?: number;
|
||||
provider?: string;
|
||||
model?: string;
|
||||
fileSize?: number;
|
||||
}): Promise<{ assetId: number }> {
|
||||
const response = await aiApiClient.post("/api/content-assets/", {
|
||||
asset_type: "audio",
|
||||
source_module: "podcast_maker",
|
||||
filename: params.filename,
|
||||
file_url: params.audioUrl,
|
||||
title: params.title,
|
||||
description: params.description || `Podcast episode audio: ${params.title}`,
|
||||
tags: ["podcast", "audio", params.projectId],
|
||||
asset_metadata: {
|
||||
project_id: params.projectId,
|
||||
scene_id: params.sceneId,
|
||||
provider: params.provider,
|
||||
model: params.model,
|
||||
},
|
||||
provider: params.provider,
|
||||
model: params.model,
|
||||
cost: params.cost || 0,
|
||||
file_size: params.fileSize,
|
||||
mime_type: "audio/mpeg",
|
||||
});
|
||||
return { assetId: response.data.id };
|
||||
},
|
||||
|
||||
async generateVideo(params: {
|
||||
projectId: string;
|
||||
sceneId: string;
|
||||
sceneTitle: string;
|
||||
audioUrl: string;
|
||||
avatarImageUrl?: string;
|
||||
bible?: any;
|
||||
resolution?: string;
|
||||
prompt?: string;
|
||||
seed?: number;
|
||||
maskImageUrl?: string;
|
||||
}): Promise<{ taskId: string; status: string; message: string }> {
|
||||
const response = await aiApiClient.post("/api/podcast/render/video", {
|
||||
project_id: params.projectId,
|
||||
scene_id: params.sceneId,
|
||||
scene_title: params.sceneTitle,
|
||||
audio_url: params.audioUrl,
|
||||
avatar_image_url: params.avatarImageUrl,
|
||||
bible: params.bible,
|
||||
resolution: params.resolution || "720p",
|
||||
prompt: params.prompt,
|
||||
seed: params.seed ?? -1,
|
||||
mask_image_url: params.maskImageUrl,
|
||||
});
|
||||
|
||||
// Backend returns snake_case (task_id); normalize to camelCase for callers
|
||||
const { task_id, status, message } = response.data || {};
|
||||
return {
|
||||
taskId: task_id,
|
||||
status,
|
||||
message,
|
||||
};
|
||||
},
|
||||
|
||||
async pollTaskStatus(taskId: string): Promise<TaskStatus | null> {
|
||||
const response = await aiApiClient.get(`/api/podcast/task/${taskId}/status`);
|
||||
// Backend returns null if task not found
|
||||
return response.data || null;
|
||||
},
|
||||
|
||||
async listVideos(projectId?: string): Promise<{
|
||||
videos: Array<{
|
||||
scene_number: number;
|
||||
filename: string;
|
||||
video_url: string;
|
||||
file_size: number;
|
||||
}>;
|
||||
}> {
|
||||
const params = projectId ? { project_id: projectId } : {};
|
||||
const response = await aiApiClient.get("/api/podcast/videos", { params });
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async combineVideos(params: {
|
||||
projectId: string;
|
||||
sceneVideoUrls: string[];
|
||||
podcastTitle?: string;
|
||||
}): Promise<{
|
||||
taskId: string;
|
||||
status: string;
|
||||
message: string;
|
||||
}> {
|
||||
const response = await aiApiClient.post("/api/podcast/render/combine-videos", {
|
||||
project_id: params.projectId,
|
||||
scene_video_urls: params.sceneVideoUrls,
|
||||
podcast_title: params.podcastTitle || "Podcast",
|
||||
});
|
||||
|
||||
const { task_id, status, message } = response.data || {};
|
||||
return {
|
||||
taskId: task_id,
|
||||
status,
|
||||
message,
|
||||
};
|
||||
},
|
||||
|
||||
async generateSceneImage(params: {
|
||||
sceneId: string;
|
||||
sceneTitle: string;
|
||||
sceneContent?: string;
|
||||
baseAvatarUrl?: string;
|
||||
bible?: any;
|
||||
idea?: string;
|
||||
width?: number;
|
||||
height?: number;
|
||||
customPrompt?: string;
|
||||
style?: "Auto" | "Fiction" | "Realistic";
|
||||
renderingSpeed?: "Default" | "Turbo" | "Quality";
|
||||
aspectRatio?: "1:1" | "16:9" | "9:16" | "4:3" | "3:4";
|
||||
}): Promise<{
|
||||
scene_id: string;
|
||||
scene_title: string;
|
||||
image_filename: string;
|
||||
image_url: string;
|
||||
width: number;
|
||||
height: number;
|
||||
provider: string;
|
||||
model?: string;
|
||||
cost: number;
|
||||
}> {
|
||||
const response = await aiApiClient.post("/api/podcast/image", {
|
||||
scene_id: params.sceneId,
|
||||
scene_title: params.sceneTitle,
|
||||
scene_content: params.sceneContent,
|
||||
base_avatar_url: params.baseAvatarUrl || null,
|
||||
bible: params.bible,
|
||||
idea: params.idea || null,
|
||||
width: params.width || 1024,
|
||||
height: params.height || 1024,
|
||||
custom_prompt: params.customPrompt || null,
|
||||
style: params.style || null,
|
||||
rendering_speed: params.renderingSpeed || null,
|
||||
aspect_ratio: params.aspectRatio || null,
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async cancelTask(taskId: string): Promise<void> {
|
||||
// Note: Task cancellation may not be fully supported by backend yet
|
||||
// This is a placeholder for future implementation
|
||||
try {
|
||||
await aiApiClient.post(`/api/story/task/${taskId}/cancel`);
|
||||
} catch (error) {
|
||||
console.warn("Task cancellation not supported:", error);
|
||||
}
|
||||
},
|
||||
|
||||
async combineAudio(params: {
|
||||
projectId: string;
|
||||
sceneIds: string[];
|
||||
sceneAudioUrls: string[];
|
||||
}): Promise<{
|
||||
combined_audio_url: string;
|
||||
combined_audio_filename: string;
|
||||
total_duration: number;
|
||||
file_size: number;
|
||||
scene_count: number;
|
||||
}> {
|
||||
const response = await aiApiClient.post("/api/podcast/combine-audio", {
|
||||
project_id: params.projectId,
|
||||
scene_ids: params.sceneIds,
|
||||
scene_audio_urls: params.sceneAudioUrls,
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async uploadAvatar(file: File, projectId?: string): Promise<{ avatar_url: string; avatar_filename: string }> {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
if (projectId) {
|
||||
formData.append('project_id', projectId);
|
||||
}
|
||||
const response = await aiApiClient.post('/api/podcast/avatar/upload', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async generatePresenters(
|
||||
speakers: number,
|
||||
projectId?: string,
|
||||
audience?: string,
|
||||
contentType?: string,
|
||||
topKeywords?: string[]
|
||||
): Promise<{
|
||||
avatars: Array<{ avatar_url: string; speaker_number: number; prompt?: string; persona_id?: string; seed?: number }>;
|
||||
persona_id?: string;
|
||||
}> {
|
||||
const formData = new FormData();
|
||||
formData.append('speakers', speakers.toString());
|
||||
if (projectId) {
|
||||
formData.append('project_id', projectId);
|
||||
}
|
||||
if (audience) {
|
||||
formData.append('audience', audience);
|
||||
}
|
||||
if (contentType) {
|
||||
formData.append('content_type', contentType);
|
||||
}
|
||||
if (topKeywords && Array.isArray(topKeywords) && topKeywords.length > 0) {
|
||||
formData.append('top_keywords', JSON.stringify(topKeywords));
|
||||
}
|
||||
const response = await aiApiClient.post('/api/podcast/avatar/generate', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async makeAvatarPresentable(avatarUrl: string, projectId?: string): Promise<{ avatar_url: string; avatar_filename: string }> {
|
||||
const formData = new FormData();
|
||||
formData.append('avatar_url', avatarUrl);
|
||||
if (projectId) {
|
||||
formData.append('project_id', projectId);
|
||||
}
|
||||
const response = await aiApiClient.post('/api/podcast/avatar/make-presentable', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
};
|
||||
|
||||
export type PodcastApi = typeof podcastApi;
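The text preparation that `renderSceneAudio` performs before calling the audio endpoint can be read as a small pure function. The sketch below is only an illustration under stated assumptions: `ScriptLine`, `stripMarkdown`, and `buildTtsText` are hypothetical names that do not appear in the diff, while the pause values (0.15s, 0.2s, 0.08s) and the `<#x#>` marker syntax are copied from the code above.

```typescript
// Hypothetical stand-alone sketch of the text preparation in renderSceneAudio.
// ScriptLine, stripMarkdown and buildTtsText are illustrative names, not part of the diff.
type ScriptLine = { speaker: string; text: string; emphasis?: boolean };

// Remove markdown markers that a TTS engine would otherwise read aloud.
function stripMarkdown(input: string): string {
  return input
    .replace(/\*\*([^*]+)\*\*/g, "$1") // **bold**
    .replace(/\*([^*]+)\*/g, "$1") // *italic*
    .replace(/__([^_]+)__/g, "$1") // __bold__
    .replace(/_([^_]+)_/g, "$1") // _italic_
    .replace(/[*_]+/g, "") // stray asterisks and underscores
    .replace(/\s+/g, " ") // collapse whitespace
    .trim();
}

// Join scene lines into one string, adding pause markers between them.
function buildTtsText(lines: ScriptLine[]): string {
  return lines
    .map((line, idx) => {
      let text = stripMarkdown(line.text);

      // Pause after emphasized content.
      if (line.emphasis) text = `${text}<#0.15#>`;

      // Pause before a line spoken by a different speaker than the previous one.
      const prev = idx > 0 ? lines[idx - 1] : undefined;
      const speakerChanged = prev !== undefined && prev.speaker !== line.speaker;
      if (speakerChanged) text = `<#0.2#>${text}`;

      // Very short pause after regular lines, except the last one.
      const isLast = idx === lines.length - 1;
      if (!isLast && !line.emphasis && !speakerChanged) text = `${text}<#0.08#>`;

      return text;
    })
    .join(" ");
}
```

Keeping this step free of network calls would make the pause rules testable without the preflight check or the audio endpoint.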
|
||||
|
||||
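The include/exclude handling in `runResearch` follows one rule: send at most one of the two domain lists, with the include list taking precedence. A minimal sketch of that rule is shown below, assuming a reduced config shape; `DomainFilters` and `keepOneDomainFilter` are hypothetical names, and only the two field names are taken from the diff.

```typescript
// Hypothetical reduced config shape; the real ResearchConfig has more fields.
type DomainFilters = {
  exa_include_domains?: string[];
  exa_exclude_domains?: string[];
};

// Return a copy of the config that carries at most one domain list.
// A non-empty include list wins; otherwise a non-empty exclude list is kept.
function keepOneDomainFilter(config: DomainFilters | undefined): DomainFilters | undefined {
  if (!config) return config;
  if (config.exa_include_domains?.length) {
    return { ...config, exa_exclude_domains: undefined };
  }
  if (config.exa_exclude_domains?.length) {
    return { ...config, exa_include_domains: undefined };
  }
  return config;
}
```

The function copies instead of mutating, so the caller's original config object stays unchanged, which matches the spread-based approach in the code above.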
244
_session_backup/research.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Podcast Research Handlers
|
||||
|
||||
Research endpoints using Exa provider and LLM summarization.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List
|
||||
from types import SimpleNamespace
|
||||
import json
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from loguru import logger
|
||||
from ..models import (
|
||||
PodcastExaResearchRequest,
|
||||
PodcastExaResearchResponse,
|
||||
PodcastExaSource,
|
||||
PodcastExaConfig,
|
||||
PodcastResearchInsight,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/research/exa", response_model=PodcastExaResearchResponse)
|
||||
async def podcast_research_exa(
|
||||
request: PodcastExaResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Run podcast research via Exa and then use LLM to extract deep insights.
|
||||
Uses Podcast Bible and Analysis context for hyper-personalization.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
queries = [q.strip() for q in request.queries if q and q.strip()]
|
||||
if not queries:
|
||||
raise HTTPException(status_code=400, detail="At least one query is required for research.")
|
||||
|
||||
exa_cfg = request.exa_config or PodcastExaConfig()
|
||||
cfg = SimpleNamespace(
|
||||
exa_search_type=exa_cfg.exa_search_type or "auto",
|
||||
exa_category=exa_cfg.exa_category,
|
||||
exa_include_domains=exa_cfg.exa_include_domains or [],
|
||||
exa_exclude_domains=exa_cfg.exa_exclude_domains or [],
|
||||
max_sources=exa_cfg.max_sources or 8,
|
||||
source_types=[],
|
||||
)
|
||||
|
||||
provider = ExaResearchProvider()
|
||||
|
||||
# --- Context Building ---
|
||||
bible_service = PodcastBibleService()
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
try:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Research] Failed to serialize bible: {exc}")
|
||||
|
||||
analysis_context = ""
|
||||
if request.analysis:
|
||||
analysis_context = f"""
|
||||
PODCAST ANALYSIS CONTEXT:
|
||||
Audience: {request.analysis.get('audience', 'General')}
|
||||
Content Type: {request.analysis.get('content_type', 'Informative')}
|
||||
Top Keywords: {', '.join(request.analysis.get('top_keywords', []))}
|
||||
"""
|
||||
|
||||
# Exa search params
|
||||
industry = request.bible.get("brand", {}).get("industry", "") if request.bible else ""
|
||||
target_audience = ""
|
||||
if request.bible:
|
||||
audience_dna = request.bible.get("audience", {})
|
||||
if audience_dna:
|
||||
interests = ", ".join(audience_dna.get("interests", []))
|
||||
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
|
||||
|
||||
try:
|
||||
# 1. RUN EXA SEARCH
|
||||
result = await provider.search(
|
||||
prompt=request.topic,
|
||||
topic=request.topic,
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
config=cfg,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Exa Research] Search failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Exa research failed: {exc}")
|
||||
|
||||
# 2. EXTRACT INSIGHTS VIA LLM
|
||||
raw_content = result.get("content", "")
|
||||
sources = result.get("sources", [])
|
||||
|
||||
summary = ""
|
||||
key_insights = []
|
||||
expert_quotes = []
|
||||
listener_cta = []
|
||||
mapped_angles = []
|
||||
|
||||
if raw_content and sources:
|
||||
logger.info(f"[Podcast Research] Extracting insights from {len(sources)} sources for user {user_id}")
|
||||
|
||||
prompt = f"""
|
||||
You are an expert research analyst for a high-end podcast production team.
|
||||
Your task is to analyze the following research data and extract deep, actionable insights for a podcast episode.
|
||||
|
||||
PODCAST CONTEXT:
|
||||
Topic: {request.topic}
|
||||
{bible_context}
|
||||
{analysis_context}
|
||||
|
||||
RESEARCH DATA (from {len(sources)} sources):
|
||||
{raw_content}
|
||||
|
||||
TASK:
|
||||
1. Provide a comprehensive summary (2-3 paragraphs) of the most important findings. Use Markdown for formatting (bolding, lists).
|
||||
2. Extract 3-5 "Key Insights". Each insight should have a title and a detailed explanation.
|
||||
3. For each insight, identify which source indices (e.g. 1, 2) it was derived from.
|
||||
4. Extract notable "Expert Quotes" - direct quotes from industry leaders, researchers, or authoritative voices found in the sources.
|
||||
5. Suggest 2-4 "Listener CTA" (call-to-action) ideas that the podcast host can use to engage the audience.
|
||||
6. Identify 3-5 "Mapped Angles" - unique content angles with rationale for why they matter for this topic.
|
||||
|
||||
NOTE: The research data includes "Key Highlights", "Summaries", and "Excerpts" from various sources.
|
||||
Pay special attention to the "Key Highlights" sections as they contain the most relevant information extracted by the neural search engine.
|
||||
|
||||
Return JSON structure:
|
||||
{{
|
||||
"summary": "Detailed markdown summary...",
|
||||
"key_insights": [
|
||||
{{
|
||||
"title": "Insight Title",
|
||||
"content": "Detailed markdown content...",
|
||||
"source_indices": [1, 2]
|
||||
}}
|
||||
],
|
||||
"expert_quotes": [
|
||||
{{
|
||||
"quote": "Exact quote from source...",
|
||||
"source_index": 1
|
||||
}}
|
||||
],
|
||||
"listener_cta": [
|
||||
"Call-to-action suggestion 1",
|
||||
"Call-to-action suggestion 2"
|
||||
],
|
||||
"mapped_angles": [
|
||||
{{
|
||||
"title": "Angle Title",
|
||||
"why": "Why this angle matters for the audience...",
|
||||
"mapped_fact_ids": ["fact_1", "fact_2"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Requirements:
|
||||
- Ensure insights are deep, not just superficial facts. Look for trends, expert opinions, and specific data points.
|
||||
- Expert quotes should be exact or near-exact quotes from the sources, with attribution.
|
||||
- Listener CTAs should be practical and engaging (e.g., "Share your experience with X on social media").
|
||||
- Mapped angles should be unique perspectives that make the episode stand out.
|
||||
- Tone should be professional, insightful, and ready for a podcast host to discuss.
|
||||
- Avoid generic filler.
|
||||
"""
|
||||
try:
|
||||
llm_response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
json_struct=None,
|
||||
preferred_provider="huggingface",
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
|
||||
# Normalize response
|
||||
if isinstance(llm_response, str):
|
||||
data = json.loads(llm_response)
|
||||
else:
|
||||
data = llm_response
|
||||
|
||||
summary = data.get("summary", "")
|
||||
key_insights = [PodcastResearchInsight(**insight) for insight in data.get("key_insights", [])]
|
||||
expert_quotes = data.get("expert_quotes", [])
|
||||
listener_cta = data.get("listener_cta", [])
|
||||
mapped_angles = data.get("mapped_angles", [])
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Research] LLM Insight extraction failed: {exc}")
|
||||
# Fallback to a basic summary if LLM fails
|
||||
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
|
||||
|
||||
# Fallback: if summary is still empty (e.g. LLM returned empty string), use raw content first paragraph or basic text
|
||||
if not summary:
|
||||
if raw_content:
|
||||
summary = raw_content[:2000] # Use first 2000 chars of raw content as summary
|
||||
else:
|
||||
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
|
||||
|
||||
# 3. TRACK USAGE
|
||||
try:
|
||||
cost_total = 0.0
|
||||
if isinstance(result, dict):
|
||||
cost_total = result.get("cost", {}).get("total", 0.005) if result.get("cost") else 0.005
|
||||
provider.track_exa_usage(user_id, cost_total)
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[Podcast Exa Research] Failed to track usage: {track_err}")
|
||||
|
||||
sources_payload = []
|
||||
for src in sources:
|
||||
try:
|
||||
sources_payload.append(PodcastExaSource(**src))
|
||||
except Exception:
|
||||
sources_payload.append(PodcastExaSource(**{
|
||||
"title": src.get("title", ""),
|
||||
"url": src.get("url", ""),
|
||||
"excerpt": src.get("excerpt", ""),
|
||||
"published_at": src.get("published_at"),
|
||||
"highlights": src.get("highlights"),
|
||||
"summary": src.get("summary"),
|
||||
"source_type": src.get("source_type"),
|
||||
"index": src.get("index"),
|
||||
"image": src.get("image"),
|
||||
"author": src.get("author"),
|
||||
}))
|
||||
|
||||
return PodcastExaResearchResponse(
|
||||
sources=sources_payload,
|
||||
search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries,
|
||||
summary=summary,
|
||||
key_insights=key_insights,
|
||||
expert_quotes=expert_quotes,
|
||||
listener_cta=listener_cta,
|
||||
mapped_angles=mapped_angles,
|
||||
cost=result.get("cost") if isinstance(result, dict) else None,
|
||||
search_type=result.get("search_type") if isinstance(result, dict) else None,
|
||||
provider=result.get("provider", "exa") if isinstance(result, dict) else "exa",
|
||||
content=raw_content,
|
||||
)
|
||||
|
||||
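The insight-extraction step above passes the model reply straight to `json.loads`, so a reply wrapped in a Markdown code fence would land in the generic fallback summary. A minimal sketch of a more tolerant normalizer follows. The helper name `parse_llm_json` and its return-empty-dict policy are assumptions for illustration, not part of the original handler.

```python
import json
import re
from typing import Any, Dict

# Hypothetical helper, not part of the original handler.
# The fence marker is built indirectly so this snippet stays fence-safe.
_FENCE = "`" * 3
_FENCE_RE = re.compile(rf"^{_FENCE}(?:json)?\s*(.*?)\s*{_FENCE}$", re.DOTALL)


def parse_llm_json(llm_response: Any) -> Dict[str, Any]:
    """Return a dict from an LLM reply, or an empty dict if it cannot be parsed."""
    if isinstance(llm_response, dict):
        # Provider already returned structured data.
        return llm_response
    if not isinstance(llm_response, str):
        return {}
    text = llm_response.strip()
    match = _FENCE_RE.match(text)
    if match:
        # Drop a surrounding Markdown code fence, if any.
        text = match.group(1)
    try:
        data = json.loads(text)
    except json.JSONDecodeError:
        return {}
    # Only a JSON object is useful to the caller's data.get(...) lookups.
    return data if isinstance(data, dict) else {}
```

With a helper like this, the handler could keep its existing `data.get("summary", "")` lookups, and an empty dict would fall through to the raw-content summary fallback already present in the code.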
183
_session_backup/script.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Podcast Script Handlers
|
||||
|
||||
Script generation endpoint.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
from loguru import logger
|
||||
from ..models import (
|
||||
PodcastScriptRequest,
|
||||
PodcastScriptResponse,
|
||||
PodcastScene,
|
||||
PodcastSceneLine,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/script", response_model=PodcastScriptResponse)
|
||||
async def generate_podcast_script(
|
||||
request: PodcastScriptRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Generate a podcast script outline (scenes + lines) using podcast-oriented prompting.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Build comprehensive research context for higher-quality scripts
|
||||
research_context = ""
|
||||
if request.research:
|
||||
try:
|
||||
key_insights = request.research.get("keyword_analysis", {}).get("key_insights") or []
|
||||
fact_cards = request.research.get("factCards", []) or []
|
||||
mapped_angles = request.research.get("mappedAngles", []) or []
|
||||
sources = request.research.get("sources", []) or []
|
||||
|
||||
top_facts = [f.get("quote", "") for f in fact_cards[:5] if f.get("quote")]
|
||||
angles_summary = [
|
||||
f"{a.get('title', '')}: {a.get('why', '')}" for a in mapped_angles[:3] if a.get("title") or a.get("why")
|
||||
]
|
||||
top_sources = [s.get("url") for s in sources[:3] if s.get("url")]
|
||||
|
||||
research_parts = []
|
||||
if key_insights:
|
||||
research_parts.append(f"Key Insights: {', '.join(key_insights[:5])}")
|
||||
if top_facts:
|
||||
research_parts.append(f"Key Facts: {', '.join(top_facts)}")
|
||||
if angles_summary:
|
||||
research_parts.append(f"Research Angles: {' | '.join(angles_summary)}")
|
||||
if top_sources:
|
||||
research_parts.append(f"Top Sources: {', '.join(top_sources)}")
|
||||
|
||||
research_context = "\n".join(research_parts)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to parse research context: {exc}")
|
||||
research_context = ""
|
||||
|
||||
# Extract Podcast Bible context for hyper-personalization
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
bible_obj = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to serialize podcast bible: {exc}")
|
||||
|
||||
# Extract Analysis and Outline context for grounding
|
||||
analysis_context = ""
|
||||
if request.analysis:
|
||||
analysis_context = f"""
|
||||
TARGET AUDIENCE: {request.analysis.get('audience', 'General')}
|
||||
CONTENT TYPE: {request.analysis.get('contentType', 'Conversational')}
|
||||
TOP KEYWORDS: {', '.join(request.analysis.get('topKeywords', []))}
|
||||
"""
|
||||
|
||||
outline_context = ""
|
||||
if request.outline:
|
||||
outline_context = f"""
|
||||
REFINED EPISODE OUTLINE (Follow this structure closely):
|
||||
Title: {request.outline.get('title', 'N/A')}
|
||||
Segments: {' | '.join(request.outline.get('segments', []))}
|
||||
"""
|
||||
|
||||
prompt = f"""You are an expert podcast script planner. Create natural, conversational podcast scenes.
|
||||
|
||||
{f"PODCAST BIBLE (Hyper-Personalization Context):\n{bible_context}\n" if bible_context else ""}
|
||||
{f"ANALYSIS CONTEXT:\n{analysis_context}\n" if analysis_context else ""}
|
||||
{f"REFINED OUTLINE:\n{outline_context}\n" if outline_context else ""}
|
||||
|
||||
Podcast Idea: "{request.idea}"
|
||||
Duration: ~{request.duration_minutes} minutes
|
||||
Speakers: {request.speakers} (Host + optional Guest)
|
||||
|
||||
{f"RESEARCH CONTEXT:\n{research_context}\n" if research_context else ""}
|
||||
|
||||
Return JSON with:
|
||||
- scenes: array of scenes. Each scene has:
|
||||
- id: string
|
||||
- title: short scene title (<= 60 chars)
|
||||
- duration: duration in seconds (evenly split across total duration)
|
||||
- emotion: string (one of: "neutral", "happy", "excited", "serious", "curious", "confident")
|
||||
- lines: array of {{"speaker": "...", "text": "...", "emphasis": boolean}}
|
||||
* Write natural, conversational dialogue
|
||||
* Each line can be a sentence or a few sentences that flow together
|
||||
* Use plain text only - no markdown formatting (no asterisks, underscores, etc.)
|
||||
* Mark "emphasis": true for key statistics or important points
|
||||
|
||||
Guidelines:
|
||||
- Write for spoken delivery: conversational, natural, with contractions.
|
||||
- Follow the interaction tone specified in the Bible.
|
||||
- Ensure the Host persona matches the background and personality traits from the Bible.
|
||||
- Structure the intro and outro scenes according to the Bible's "Intro Format" and "Outro Format".
|
||||
- Adhere to any constraints mentioned in the Bible.
|
||||
- Use insights from the Research Context to ground the conversation in facts.
|
||||
- IMPORTANT: Follow the REFINED OUTLINE segments as the primary structure for the episode.
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
json_struct=None,
|
||||
preferred_provider="huggingface",
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=f"Script generation failed: {exc}")
|
||||
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="LLM returned non-JSON output")
|
||||
elif isinstance(raw, dict):
|
||||
data = raw
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unexpected LLM response format")
|
||||
|
||||
scenes_data = data.get("scenes") or []
|
||||
if not isinstance(scenes_data, list):
|
||||
raise HTTPException(status_code=500, detail="LLM response missing scenes array")
|
||||
|
||||
valid_emotions = {"neutral", "happy", "excited", "serious", "curious", "confident"}
|
||||
|
||||
# Normalize scenes
|
||||
scenes: list[PodcastScene] = []
|
||||
for idx, scene in enumerate(scenes_data):
|
||||
title = scene.get("title") or f"Scene {idx + 1}"
|
||||
duration = int(scene.get("duration") or max(30, (request.duration_minutes * 60) // max(1, len(scenes_data))))
|
||||
emotion = scene.get("emotion") or "neutral"
|
||||
if emotion not in valid_emotions:
|
||||
emotion = "neutral"
|
||||
lines_raw = scene.get("lines") or []
|
||||
lines: list[PodcastSceneLine] = []
|
||||
for line in lines_raw:
|
||||
speaker = line.get("speaker") or ("Host" if len(lines) % max(1, request.speakers) == 0 else "Guest")
|
||||
text = line.get("text") or ""
|
||||
emphasis = line.get("emphasis", False)
|
||||
if text:
|
||||
lines.append(PodcastSceneLine(speaker=speaker, text=text, emphasis=emphasis))
|
||||
scenes.append(
|
||||
PodcastScene(
|
||||
id=scene.get("id") or f"scene-{idx + 1}",
|
||||
title=title,
|
||||
duration=duration,
|
||||
lines=lines,
|
||||
approved=False,
|
||||
emotion=emotion,
|
||||
)
|
||||
)
|
||||
|
||||
return PodcastScriptResponse(scenes=scenes)
|
||||
|
||||
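The scene normalization at the end of `generate_podcast_script` can be sketched in isolation, using plain dicts instead of the `PodcastScene` and `PodcastSceneLine` models. The function names below are hypothetical and exist only for this illustration. The `max(1, speakers)` guard against a zero speaker count is an assumption added in this sketch.

```python
from typing import Any, Dict, List

VALID_EMOTIONS = {"neutral", "happy", "excited", "serious", "curious", "confident"}


def default_scene_duration(duration_minutes: int, scene_count: int) -> int:
    """Evenly split the episode length across scenes, with a 30-second floor."""
    return max(30, (duration_minutes * 60) // max(1, scene_count))


def fallback_speaker(line_index: int, speakers: int) -> str:
    """Pick a label when the model omits the speaker (Host/Guest rule)."""
    # max(1, ...) is an added guard in this sketch to avoid division by zero.
    return "Host" if line_index % max(1, speakers) == 0 else "Guest"


def normalize_scene(
    scene: Dict[str, Any],
    idx: int,
    duration_minutes: int,
    scene_count: int,
    speakers: int,
) -> Dict[str, Any]:
    """Fill in missing scene fields and drop lines with empty text."""
    emotion = scene.get("emotion") or "neutral"
    if emotion not in VALID_EMOTIONS:
        emotion = "neutral"

    lines: List[Dict[str, Any]] = []
    for line in scene.get("lines") or []:
        text = line.get("text") or ""
        if not text:
            continue  # skip empty lines, as the handler does
        speaker = line.get("speaker") or fallback_speaker(len(lines), speakers)
        lines.append(
            {"speaker": speaker, "text": text, "emphasis": line.get("emphasis", False)}
        )

    fallback = default_scene_duration(duration_minutes, scene_count)
    return {
        "id": scene.get("id") or f"scene-{idx + 1}",
        "title": scene.get("title") or f"Scene {idx + 1}",
        "duration": int(scene.get("duration") or fallback),
        "emotion": emotion,
        "lines": lines,
    }
```

Keeping the fallbacks in small pure functions makes them easy to unit test without calling the LLM or constructing the response models.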
209
_session_backup/types.ts
Normal file
@@ -0,0 +1,209 @@
|
||||
export type Knobs = {
|
||||
voice_emotion: string;
|
||||
voice_speed: number;
|
||||
resolution: string;
|
||||
scene_length_target: number;
|
||||
sample_rate: number;
|
||||
bitrate: string;
|
||||
};
|
||||
|
||||
export type Query = {
|
||||
id: string;
|
||||
query: string;
|
||||
rationale: string;
|
||||
needsRecentStats: boolean;
|
||||
};
|
||||
|
||||
export type Fact = {
|
||||
id: string;
|
||||
quote: string;
|
||||
url: string;
|
||||
date: string;
|
||||
confidence: number;
|
||||
image?: string;
|
||||
author?: string;
|
||||
highlights?: string[];
|
||||
};
|
||||
|
||||
export type ResearchInsight = {
|
||||
title: string;
|
||||
content: string;
|
||||
source_indices: number[];
|
||||
};
|
||||
|
||||
export type Research = {
|
||||
summary: string;
|
||||
keyInsights: ResearchInsight[];
|
||||
factCards: Fact[];
|
||||
mappedAngles: {
|
||||
title: string;
|
||||
why: string;
|
||||
mappedFactIds: string[];
|
||||
}[];
|
||||
searchQueries?: string[];
|
||||
searchType?: string;
|
||||
provider?: string;
|
||||
cost?: number;
|
||||
sourceCount?: number;
|
||||
expertQuotes?: { quote: string; source_index: number }[];
|
||||
listenerCta?: string[];
|
||||
};
|
||||
|
||||
export type Line = {
|
||||
id: string;
|
||||
speaker: string;
|
||||
text: string;
|
||||
usedFactIds?: string[];
|
||||
emphasis?: boolean; // Mark lines that need vocal emphasis
|
||||
};
|
||||
|
||||
export type Scene = {
|
||||
id: string;
|
||||
title: string;
|
||||
duration: number;
|
||||
lines: Line[];
|
||||
approved?: boolean;
|
||||
emotion?: string; // Scene-specific emotion
|
||||
audioUrl?: string; // Generated audio URL for this scene
|
||||
imageUrl?: string; // Generated image URL for this scene (for video generation)
|
||||
};
|
||||
|
||||
export type Script = {
|
||||
scenes: Scene[];
|
||||
};
|
||||
|
||||
export type JobStatus =
|
||||
| "idle"
|
||||
| "previewing"
|
||||
| "queued"
|
||||
| "running"
|
||||
| "completed"
|
||||
| "cancelled"
|
||||
| "failed";
|
||||
|
||||
export type Job = {
|
||||
sceneId: string;
|
||||
title: string;
|
||||
status: JobStatus;
|
||||
progress: number;
|
||||
previewUrl?: string | null;
|
||||
finalUrl?: string | null;
|
||||
videoUrl?: string | null;
|
||||
jobId?: string | null;
|
||||
taskId?: string | null;
|
||||
cost?: number | null;
|
||||
provider?: string | null;
|
||||
voiceId?: string | null;
|
||||
fileSize?: number | null;
|
||||
avatarImageUrl?: string | null;
|
||||
imageUrl?: string | null; // Scene-specific image URL
|
||||
};
|
||||
|
||||
export type PodcastAnalysis = {
|
||||
audience: string;
|
||||
contentType: string;
|
||||
topKeywords: string[];
|
||||
suggestedOutlines: { id: number | string; title: string; segments: string[] }[];
|
||||
suggestedKnobs: Knobs;
|
||||
titleSuggestions: string[];
|
||||
research_queries?: { query: string; rationale: string }[];
|
||||
exaSuggestedConfig?: {
|
||||
exa_search_type?: "auto" | "keyword" | "neural";
|
||||
exa_category?: string;
|
||||
exa_include_domains?: string[];
|
||||
exa_exclude_domains?: string[];
|
||||
max_sources?: number;
|
||||
include_statistics?: boolean;
|
||||
date_range?: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type PodcastEstimate = {
|
||||
ttsCost: number;
|
||||
avatarCost: number;
|
||||
videoCost: number;
|
||||
researchCost: number;
|
||||
total: number;
|
||||
};
|
||||
|
||||
export type HostPersona = {
|
||||
name: string;
|
||||
background: string;
|
||||
expertise_level: string;
|
||||
personality_traits: string[];
|
||||
vocal_style: string;
|
||||
catchphrases: string[];
|
||||
};
|
||||
|
||||
export type AudienceDNA = {
|
||||
expertise_level: string;
|
||||
interests: string[];
|
||||
pain_points: string[];
|
||||
demographics?: string;
|
||||
};
|
||||
|
||||
export type BrandDNA = {
|
||||
industry: string;
|
||||
tone: string;
|
||||
communication_style: string;
|
||||
key_messages: string[];
|
||||
competitor_context?: string;
|
||||
};
|
||||
|
||||
export type PodcastBible = {
|
||||
project_id?: string;
|
||||
host: HostPersona;
|
||||
audience: AudienceDNA;
|
||||
brand: BrandDNA;
|
||||
};
|
||||
|
||||
export type CreateProjectPayload = {
|
||||
ideaOrUrl: string;
|
||||
speakers: number;
|
||||
duration: number;
|
||||
knobs: Knobs;
|
||||
budgetCap: number;
|
||||
files: { voiceFile?: File | null; avatarFile?: File | null };
|
||||
avatarUrl?: string | null;
|
||||
};
|
||||
|
||||
export type CreateProjectResult = {
|
||||
projectId: string;
|
||||
analysis: PodcastAnalysis;
|
||||
estimate: PodcastEstimate;
|
||||
queries: Query[];
|
||||
bible?: PodcastBible;
|
||||
avatar_url?: string | null;
|
||||
avatar_prompt?: string | null;
|
||||
};
|
||||
|
||||
export type RenderJobResult = {
|
||||
audioUrl: string;
|
||||
audioFilename: string;
|
||||
provider: string;
|
||||
model: string;
|
||||
cost: number;
|
||||
voiceId: string;
|
||||
fileSize: number;
|
||||
videoUrl?: string;
|
||||
videoFilename?: string;
|
||||
};
|
||||
|
||||
export interface VideoGenerationSettings {
|
||||
prompt: string;
|
||||
resolution: "480p" | "720p";
|
||||
seed?: number | null;
|
||||
maskImageUrl?: string | null;
|
||||
}
|
||||
|
||||
export type TaskStatus = {
|
||||
task_id: string;
|
||||
status: "pending" | "processing" | "completed" | "failed";
|
||||
progress?: number;
|
||||
message?: string;
|
||||
result?: any;
|
||||
error?: string;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
};
|
||||
|
||||
425
_session_backup/usePodcastWorkflow.ts
Normal file
@@ -0,0 +1,425 @@
|
||||
import { useState, useEffect, useMemo, useCallback } from "react";
|
||||
import { podcastApi } from "../../../services/podcastApi";
|
||||
import { usePreflightCheck } from "../../../hooks/usePreflightCheck";
|
||||
import { useBudgetTracking } from "../../../hooks/useBudgetTracking";
|
||||
import { CreateProjectPayload, Script } from "../types";
|
||||
import { usePodcastProjectState } from "../../../hooks/usePodcastProjectState";
|
||||
import { sanitizeExaConfig, announceError, getStepLabel } from "./utils";
|
||||
|
||||
type PodcastProjectStateReturn = ReturnType<typeof usePodcastProjectState>;
|
||||
|
||||
interface UsePodcastWorkflowProps {
|
||||
projectState: PodcastProjectStateReturn;
|
||||
onError: (message: string) => void;
|
||||
}
|
||||
|
||||
export const usePodcastWorkflow = ({ projectState, onError }: UsePodcastWorkflowProps) => {
|
||||
const {
|
||||
project,
|
||||
analysis,
|
||||
queries,
|
||||
selectedQueries,
|
||||
research,
|
||||
rawResearch,
|
||||
researchProvider,
|
||||
showScriptEditor,
|
||||
showRenderQueue,
|
||||
currentStep,
|
||||
renderJobs,
|
||||
budgetCap,
|
||||
setProject,
|
||||
setAnalysis,
|
||||
setQueries,
|
||||
setSelectedQueries,
|
||||
setResearch,
|
||||
setRawResearch,
|
||||
setEstimate,
|
||||
setScriptData,
|
||||
setShowScriptEditor,
|
||||
setShowRenderQueue,
|
||||
setKnobs,
|
||||
setResearchProvider,
|
||||
setBudgetCap,
|
||||
updateRenderJob,
|
||||
initializeProject,
|
||||
setBible,
|
||||
} = projectState;
|
||||
|
||||
const [isAnalyzing, setIsAnalyzing] = useState(false);
|
||||
const [isResearching, setIsResearching] = useState(false);
|
||||
const [announcement, setAnnouncement] = useState("");
|
||||
const [showResumeAlert, setShowResumeAlert] = useState(false);
|
||||
const [showPreflightDialog, setShowPreflightDialog] = useState(false);
|
||||
const [preflightResponse, setPreflightResponse] = useState<any>(null);
|
||||
const [preflightOperationName, setPreflightOperationName] = useState<string>("");
|
||||
|
||||
const budgetTracking = useBudgetTracking(budgetCap || 50);
|
||||
const preflightCheck = usePreflightCheck({
|
||||
onBlocked: (response) => {
|
||||
setPreflightResponse(response);
|
||||
setShowPreflightDialog(true);
|
||||
},
|
||||
});
|
||||
|
||||
// Update budget cap when project state changes
|
||||
useEffect(() => {
|
||||
if (budgetCap) {
|
||||
budgetTracking.setBudgetCap(budgetCap);
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [budgetCap]);
|
||||
|
||||
// Check if we have a saved project on mount
|
||||
useEffect(() => {
|
||||
if (project && currentStep && currentStep !== "create") {
|
||||
setShowResumeAlert(true);
|
||||
const t = setTimeout(() => setShowResumeAlert(false), 5000);
|
||||
return () => clearTimeout(t);
|
||||
}
|
||||
return undefined;
|
||||
}, [project, currentStep]);
|
||||
|
||||
useEffect(() => {
|
||||
if (announcement) {
|
||||
const t = setTimeout(() => setAnnouncement(""), 4000);
|
||||
return () => clearTimeout(t);
|
||||
}
|
||||
return undefined;
|
||||
}, [announcement]);
|
||||
|
||||
const handleCreate = useCallback(async (payload: CreateProjectPayload, feedback?: string) => {
|
||||
if (isAnalyzing) return;
|
||||
setResearch(null);
|
||||
setRawResearch(null);
|
||||
setScriptData(null);
|
||||
setShowScriptEditor(false);
|
||||
setShowRenderQueue(false);
|
||||
try {
|
||||
setIsAnalyzing(true);
|
||||
|
||||
// Use existing avatar URL if provided (e.g. brand avatar), or upload new file
|
||||
let avatarUrl: string | null = payload.avatarUrl || null;
|
||||
if (payload.files.avatarFile) {
|
||||
try {
|
||||
setAnnouncement("Uploading presenter avatar...");
|
||||
const uploadResponse = await podcastApi.uploadAvatar(payload.files.avatarFile);
|
||||
avatarUrl = uploadResponse.avatar_url;
|
||||
} catch (error) {
|
||||
console.error('Avatar upload failed:', error);
|
||||
// Continue without avatar - will generate one later
|
||||
}
|
||||
}
|
||||
|
||||
// NEW FLOW: Create project first to generate/get the Podcast Bible
|
||||
// This allows the analysis to be personalized using the Bible context
|
||||
const projectId = project?.id || `podcast_${Date.now()}_${Math.floor(Math.random() * 1000)}`;
|
||||
setAnnouncement("Initializing project and brand context...");
|
||||
const dbProject = project ? null : await initializeProject(payload, projectId, avatarUrl);
|
||||
const bible = dbProject?.bible || projectState.bible;
|
||||
|
||||
setAnnouncement(feedback ? "Regenerating analysis using your feedback..." : "Analyzing your idea — AI suggestions incoming");
|
||||
const result = await podcastApi.createProject(payload, bible, feedback);
|
||||
|
||||
if (result.bible) {
|
||||
setBible(result.bible);
|
||||
} else if (dbProject?.bible) {
|
||||
setBible(dbProject.bible);
|
||||
}
|
||||
|
||||
// Update the project in database with the analysis results
|
||||
try {
|
||||
await podcastApi.updateProject(projectId, {
|
||||
analysis: result.analysis,
|
||||
estimate: result.estimate,
|
||||
queries: result.queries,
|
||||
selected_queries: result.queries.map(q => q.id),
|
||||
avatar_url: result.avatar_url,
|
||||
avatar_prompt: result.avatar_prompt,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Failed to update project with analysis results:', error);
|
||||
}
|
||||
|
||||
setProject({
|
||||
id: projectId,
|
||||
idea: payload.ideaOrUrl,
|
||||
duration: payload.duration,
|
||||
speakers: payload.speakers,
|
||||
avatarUrl: result.avatar_url || avatarUrl,
|
||||
avatarPrompt: result.avatar_prompt || null,
|
||||
avatarPersonaId: null,
|
||||
});
|
||||
|
||||
setAnalysis(result.analysis);
|
||||
setEstimate(result.estimate);
|
||||
setQueries(result.queries);
|
||||
setSelectedQueries(new Set(result.queries.map((q) => q.id)));
|
||||
setKnobs(payload.knobs);
|
||||
setBudgetCap(payload.budgetCap);
|
||||
|
||||
// Generate presenters AFTER analysis completes (to use analysis insights)
|
||||
// This happens only if no avatar was uploaded
|
||||
if (!avatarUrl && payload.speakers > 0 && result.analysis) {
|
||||
try {
|
||||
setAnnouncement("Generating presenter avatars using AI insights...");
|
||||
const presentersResponse = await podcastApi.generatePresenters(
|
||||
payload.speakers,
|
||||
result.projectId,
|
||||
result.analysis.audience,
|
||||
result.analysis.contentType,
|
||||
result.analysis.topKeywords
|
||||
);
|
||||
if (presentersResponse.avatars && presentersResponse.avatars.length > 0) {
|
||||
// Store the first presenter avatar URL and prompt
|
||||
const firstAvatar = presentersResponse.avatars[0];
|
||||
const prompt = firstAvatar.prompt || null;
|
||||
setProject({
|
||||
id: result.projectId,
|
||||
idea: payload.ideaOrUrl,
|
||||
duration: payload.duration,
|
||||
speakers: payload.speakers,
|
||||
avatarUrl: firstAvatar.avatar_url,
|
||||
avatarPrompt: prompt,
|
||||
avatarPersonaId: firstAvatar.persona_id || presentersResponse.persona_id || null,
|
||||
});
|
||||
setAnnouncement("Analysis complete - Presenter avatars generated");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Presenter generation failed:', error);
|
||||
setAnnouncement("Analysis complete - Avatar generation will happen later");
|
||||
// Continue without presenters - can generate later
|
||||
}
|
||||
} else {
|
||||
setAnnouncement("Analysis complete");
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (error?.response?.status === 429 || error?.response?.data?.detail) {
|
||||
const errorDetail = error.response?.data?.detail ?? {};
|
||||
if (typeof errorDetail === 'object' && errorDetail.error && errorDetail.error.includes('limit')) {
|
||||
const usageInfo = errorDetail.usage_info || {};
|
||||
const blockedResponse = {
|
||||
can_proceed: false,
|
||||
estimated_cost: 0,
|
||||
operations: [{
|
||||
provider: errorDetail.provider || 'huggingface',
|
||||
operation_type: 'ai_text_generation',
|
||||
cost: 0,
|
||||
allowed: false,
|
||||
limit_info: usageInfo.limit_info || null,
|
||||
message: errorDetail.message || errorDetail.error || 'Subscription limit exceeded',
|
||||
}],
|
||||
total_cost: 0,
|
||||
usage_summary: usageInfo.usage_summary || null,
|
||||
cached: false,
|
||||
};
|
||||
setPreflightResponse(blockedResponse);
|
||||
setPreflightOperationName('Podcast Analysis');
|
||||
setShowPreflightDialog(true);
|
||||
setAnnouncement("Subscription limit reached. Please upgrade to continue.");
|
||||
} else {
|
||||
const message = typeof errorDetail === 'string' ? errorDetail : errorDetail.message || errorDetail.error || 'Request limit exceeded';
|
||||
announceError(setAnnouncement, new Error(message));
|
||||
}
|
||||
} else {
|
||||
announceError(setAnnouncement, error);
|
||||
}
|
||||
} finally {
|
||||
setIsAnalyzing(false);
|
||||
}
|
||||
}, [isAnalyzing, project, projectState.bible, setResearch, setRawResearch, setScriptData, setShowScriptEditor, setShowRenderQueue, initializeProject, setProject, setAnalysis, setEstimate, setQueries, setSelectedQueries, setKnobs, setBudgetCap, setBible]);
|
||||
|
||||
const handleRunResearch = useCallback(async () => {
|
||||
if (isResearching) return;
|
||||
if (!project) {
|
||||
setAnnouncement("Create a project first.");
|
||||
return;
|
||||
}
|
||||
if (selectedQueries.size === 0) {
|
||||
setAnnouncement("Select at least one query to research.");
|
||||
return;
|
||||
}
|
||||
|
||||
setPreflightOperationName("Research");
|
||||
const approvedQueries = queries.filter((q) => selectedQueries.has(q.id));
|
||||
const preflightResult = await preflightCheck.check({
|
||||
provider: researchProvider === "exa" ? "exa" : "gemini",
|
||||
operation_type: researchProvider === "exa" ? "exa_neural_search" : "google_grounding",
|
||||
tokens_requested: researchProvider === "exa" ? 0 : 1200,
|
||||
actual_provider_name: researchProvider || "exa",
|
||||
});
|
||||
|
||||
if (!preflightResult.can_proceed) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setIsResearching(true);
|
||||
setAnnouncement(`Starting ${researchProvider === "exa" ? "deep" : "standard"} research — this may take a moment...`);
|
||||
setResearch(null);
|
||||
setRawResearch(null);
|
||||
setScriptData(null);
|
||||
setShowScriptEditor(false);
|
||||
setShowRenderQueue(false);
|
||||
|
||||
try {
|
||||
const { research: mapped, raw } = await podcastApi.runResearch({
|
||||
projectId: project.id,
|
||||
topic: project.idea,
|
||||
approvedQueries,
|
||||
provider: researchProvider,
|
||||
exaConfig: sanitizeExaConfig(analysis?.exaSuggestedConfig),
|
||||
bible: projectState.bible,
|
||||
analysis: analysis,
|
||||
onProgress: (message) => {
|
||||
setAnnouncement(message);
|
||||
},
|
||||
});
|
||||
setResearch(mapped);
|
||||
setRawResearch(raw);
|
||||
setAnnouncement("Research complete — review fact cards below");
|
||||
} catch (researchError) {
|
||||
const errorMessage = researchError instanceof Error
|
||||
? researchError.message
|
||||
: "Research failed. Please try again or switch to Standard Research.";
|
||||
|
||||
if (errorMessage.includes("Exa") || errorMessage.includes("exa")) {
|
||||
setAnnouncement(`Deep research failed: ${errorMessage}. Try Standard Research instead.`);
|
||||
} else if (errorMessage.includes("timeout")) {
|
||||
setAnnouncement("Research timed out. Please try again with fewer queries.");
|
||||
} else {
|
||||
setAnnouncement(`Research failed: ${errorMessage}`);
|
||||
}
|
||||
|
||||
console.error("Research error:", researchError);
|
||||
throw researchError;
|
||||
}
|
||||
} catch (error) {
|
||||
announceError(setAnnouncement, error);
|
||||
} finally {
|
||||
setIsResearching(false);
|
||||
}
|
||||
}, [isResearching, project, selectedQueries, queries, researchProvider, preflightCheck, analysis, setResearch, setRawResearch, setScriptData, setShowScriptEditor, setShowRenderQueue, projectState.bible]);
|
||||
|
||||
const handleGenerateScript = useCallback(async () => {
|
||||
if (showScriptEditor) return;
|
||||
if (!project || !research) {
|
||||
setAnnouncement("Project or research missing — cannot generate script");
|
||||
return;
|
||||
}
|
||||
|
||||
setPreflightOperationName("Script Generation");
|
||||
const preflightResult = await preflightCheck.check({
|
||||
provider: "gemini",
|
||||
operation_type: "script_generation",
|
||||
tokens_requested: 2000,
|
||||
actual_provider_name: "gemini",
|
||||
});
|
||||
|
||||
if (!preflightResult.can_proceed) {
|
||||
return;
|
||||
}
|
||||
|
||||
setScriptData(null);
|
||||
setShowRenderQueue(false);
|
||||
setShowScriptEditor(true);
|
||||
|
||||
try {
|
||||
const result = await podcastApi.generateScript({
|
||||
projectId: project.id,
|
||||
idea: project.idea,
|
||||
research: rawResearch,
|
||||
knobs: projectState.knobs,
|
||||
speakers: project.speakers,
|
||||
durationMinutes: project.duration,
|
||||
bible: projectState.bible,
|
||||
outline: analysis?.suggestedOutlines?.[0], // Pass the first (possibly refined) outline
|
||||
analysis: analysis, // Pass full analysis context
|
||||
});
|
||||
|
||||
setScriptData(result);
|
||||
} catch (error) {
|
||||
announceError(setAnnouncement, error);
|
||||
}
|
||||
}, [showScriptEditor, project, research, analysis, preflightCheck, setScriptData, setShowRenderQueue, setShowScriptEditor, rawResearch, projectState.knobs, projectState.bible]);
|
||||
|
||||
const handleProceedToRendering = useCallback((script: Script) => {
|
||||
setScriptData(script);
|
||||
if (renderJobs.length === 0) {
|
||||
script.scenes.forEach((scene) => {
|
||||
const hasExistingAudio = Boolean(scene.audioUrl);
|
||||
updateRenderJob(scene.id, {
|
||||
sceneId: scene.id,
|
||||
title: scene.title,
|
||||
status: hasExistingAudio ? ("completed" as const) : ("idle" as const),
|
||||
progress: hasExistingAudio ? 100 : 0,
|
||||
previewUrl: null,
|
||||
finalUrl: hasExistingAudio ? scene.audioUrl : null,
|
||||
jobId: null,
|
||||
});
|
||||
});
|
||||
}
|
||||
setShowRenderQueue(true);
|
||||
setShowScriptEditor(false);
|
||||
}, [renderJobs.length, setScriptData, updateRenderJob, setShowRenderQueue, setShowScriptEditor]);
|
||||
|
||||
const toggleQuery = useCallback((id: string) => {
|
||||
if (isResearching) return;
|
||||
const current = selectedQueries;
|
||||
const next = new Set<string>(current);
|
||||
if (next.has(id)) next.delete(id);
|
||||
else next.add(id);
|
||||
setSelectedQueries(next);
|
||||
}, [isResearching, selectedQueries, setSelectedQueries]);
|
||||
|
||||
const activeStep = useMemo(() => {
|
||||
if (showRenderQueue) return 3;
|
||||
if (showScriptEditor) return 2;
|
||||
if (currentStep === 'research' || research) return 1;
|
||||
if (currentStep === 'analysis' || analysis) return 0;
|
||||
return -1;
|
||||
}, [showRenderQueue, showScriptEditor, currentStep, research, analysis]);
|
||||
|
||||
const canGenerateScript = Boolean(project && research && rawResearch);
|
||||
|
||||
const handleRegenerate = useCallback(async (feedback?: string) => {
|
||||
if (!project) return;
|
||||
|
||||
// Prepare the payload from existing project state
|
||||
const payload: CreateProjectPayload = {
|
||||
ideaOrUrl: project.idea,
|
||||
duration: project.duration,
|
||||
speakers: project.speakers,
|
||||
knobs: projectState.knobs,
|
||||
budgetCap: projectState.budgetCap,
|
||||
avatarUrl: project.avatarUrl,
|
||||
files: {} // No new files for regeneration
|
||||
};
|
||||
|
||||
await handleCreate(payload, feedback);
|
||||
}, [project, projectState.knobs, projectState.budgetCap, handleCreate]);
|
||||
|
||||
return {
|
||||
// State
|
||||
isAnalyzing,
|
||||
isResearching,
|
||||
announcement,
|
||||
showResumeAlert,
|
||||
showPreflightDialog,
|
||||
preflightResponse,
|
||||
preflightOperationName,
|
||||
activeStep,
|
||||
canGenerateScript,
|
||||
// Handlers
|
||||
handleCreate,
|
||||
handleRegenerate,
|
||||
handleRunResearch,
|
||||
handleGenerateScript,
|
||||
handleProceedToRendering,
|
||||
toggleQuery,
|
||||
setAnnouncement,
|
||||
setShowResumeAlert,
|
||||
setShowPreflightDialog,
|
||||
setPreflightResponse,
|
||||
setResearchProvider,
|
||||
getStepLabel,
|
||||
};
|
||||
};
|
||||
|
||||
184
add_missing_columns.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
"""
Migration script to add missing columns to usage_summaries table.
Run this once to fix the database schema.

Usage:
    python add_missing_columns.py
"""

import sqlite3
from pathlib import Path

def get_db_path():
    """Find the database path."""
    possible_paths = [
        Path(__file__).parent / "backend" / "alwrity.db",
        Path(__file__).parent.parent / "backend" / "alwrity.db",
        Path("C:/Users/diksha rawat/Desktop/ALwrity_github/windsurf/ALwrity/backend/alwrity.db"),
    ]

    for db_path in possible_paths:
        if db_path.exists():
            print(f"Using database: {db_path}")
            return db_path

    backend_dir = Path(__file__).parent / "backend"
    if backend_dir.exists():
        db_files = list(backend_dir.glob("*.db"))
        if db_files:
            print(f"Found database: {db_files[0]}")
            return db_files[0]

    raise FileNotFoundError(f"Database not found. Searched: {possible_paths}")

def create_usage_summaries_table(cursor):
    """Create the usage_summaries table if it doesn't exist."""
    cursor.execute("""
        CREATE TABLE IF NOT EXISTS usage_summaries (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user_id VARCHAR(100) NOT NULL,
            billing_period VARCHAR(20) NOT NULL,

            -- API Call Counts
            gemini_calls INTEGER DEFAULT 0,
            openai_calls INTEGER DEFAULT 0,
            anthropic_calls INTEGER DEFAULT 0,
            mistral_calls INTEGER DEFAULT 0,
            wavespeed_calls INTEGER DEFAULT 0,
            tavily_calls INTEGER DEFAULT 0,
            serper_calls INTEGER DEFAULT 0,
            metaphor_calls INTEGER DEFAULT 0,
            firecrawl_calls INTEGER DEFAULT 0,
            stability_calls INTEGER DEFAULT 0,
            exa_calls INTEGER DEFAULT 0,
            video_calls INTEGER DEFAULT 0,
            image_edit_calls INTEGER DEFAULT 0,
            audio_calls INTEGER DEFAULT 0,

            -- Token Usage
            gemini_tokens INTEGER DEFAULT 0,
            openai_tokens INTEGER DEFAULT 0,
            anthropic_tokens INTEGER DEFAULT 0,
            mistral_tokens INTEGER DEFAULT 0,
            wavespeed_tokens INTEGER DEFAULT 0,

            -- Cost Tracking
            gemini_cost REAL DEFAULT 0.0,
            openai_cost REAL DEFAULT 0.0,
            anthropic_cost REAL DEFAULT 0.0,
            mistral_cost REAL DEFAULT 0.0,
            wavespeed_cost REAL DEFAULT 0.0,
            tavily_cost REAL DEFAULT 0.0,
            serper_cost REAL DEFAULT 0.0,
            metaphor_cost REAL DEFAULT 0.0,
            firecrawl_cost REAL DEFAULT 0.0,
            stability_cost REAL DEFAULT 0.0,
            exa_cost REAL DEFAULT 0.0,
            video_cost REAL DEFAULT 0.0,
            image_edit_cost REAL DEFAULT 0.0,
            audio_cost REAL DEFAULT 0.0,

            -- Totals
            total_calls INTEGER DEFAULT 0,
            total_tokens INTEGER DEFAULT 0,
            total_cost REAL DEFAULT 0.0,

            -- Performance Metrics
            avg_response_time REAL DEFAULT 0.0,
            error_rate REAL DEFAULT 0.0,
            usage_status VARCHAR(20) DEFAULT 'active',
            warnings_sent INTEGER DEFAULT 0,

            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,

            UNIQUE(user_id, billing_period)
        )
    """)
    print("Created usage_summaries table")

def add_missing_columns():
    db_path = get_db_path()
    print(f"Using database: {db_path}")

    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()

    # Check what tables exist
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
    tables = [row[0] for row in cursor.fetchall()]
    print(f"Tables in database: {tables}")

    # Check if usage_summaries exists
    if "usage_summaries" not in tables:
        print("usage_summaries table doesn't exist. Creating it...")
        create_usage_summaries_table(cursor)
        conn.commit()
        conn.close()
        print("Done! Table created successfully.")
        return

    # Get existing columns
    cursor.execute("PRAGMA table_info(usage_summaries)")
    existing_columns = {row[1] for row in cursor.fetchall()}
    print(f"Existing columns in usage_summaries: {len(existing_columns)}")

    # Columns to add (name, type, default)
    columns_to_add = [
        # Call counts
        ("wavespeed_calls", "INTEGER", "0"),
        ("tavily_calls", "INTEGER", "0"),
        ("serper_calls", "INTEGER", "0"),
        ("metaphor_calls", "INTEGER", "0"),
        ("firecrawl_calls", "INTEGER", "0"),
        ("stability_calls", "INTEGER", "0"),
        ("exa_calls", "INTEGER", "0"),
        ("video_calls", "INTEGER", "0"),
        ("image_edit_calls", "INTEGER", "0"),
        ("audio_calls", "INTEGER", "0"),
        # Token usage
        ("wavespeed_tokens", "INTEGER", "0"),
        # Cost tracking
        ("wavespeed_cost", "REAL", "0.0"),
        ("tavily_cost", "REAL", "0.0"),
        ("serper_cost", "REAL", "0.0"),
        ("metaphor_cost", "REAL", "0.0"),
        ("firecrawl_cost", "REAL", "0.0"),
        ("stability_cost", "REAL", "0.0"),
        ("exa_cost", "REAL", "0.0"),
        ("video_cost", "REAL", "0.0"),
        ("image_edit_cost", "REAL", "0.0"),
        ("audio_cost", "REAL", "0.0"),
    ]

    added = []
    skipped = []

    for col_name, col_type, default in columns_to_add:
        if col_name in existing_columns:
            skipped.append(col_name)
            continue

        try:
            sql = f"ALTER TABLE usage_summaries ADD COLUMN {col_name} {col_type} DEFAULT {default}"
            cursor.execute(sql)
            added.append(col_name)
            print(f"  Added: {col_name}")
        except sqlite3.Error as e:
            print(f"  Error adding {col_name}: {e}")

    conn.commit()
    conn.close()

    print("\nSummary:")
    print(f"  Added: {len(added)} columns")
    print(f"  Skipped (already exist): {len(skipped)} columns")

    if added:
        print(f"\nColumns added: {', '.join(added)}")
    if skipped:
        print(f"Already existed: {', '.join(skipped)}")

if __name__ == "__main__":
    add_missing_columns()
|
||||
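Review note: the migration above is safe to re-run because it reads the current columns with PRAGMA table_info and only issues ALTER TABLE for the ones that are absent. A minimal sketch of that pattern against an in-memory database follows. The helper name add_columns_if_missing and the two-column table are illustrative assumptions, not code from this change, and the table and column names are interpolated into SQL only because they are trusted constants.

```python
import sqlite3


def add_columns_if_missing(conn, table, columns):
    """Add each (name, type, default) column the table lacks; return the names added.

    Hypothetical helper for illustration. Identifiers are interpolated into SQL,
    so they must come from trusted constants, never from user input.
    """
    # PRAGMA table_info yields one row per column; index 1 is the column name.
    existing = {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}
    added = []
    for name, col_type, default in columns:
        if name in existing:
            continue  # already present, so a second run is a no-op
        conn.execute(
            f"ALTER TABLE {table} ADD COLUMN {name} {col_type} DEFAULT {default}"
        )
        added.append(name)
    conn.commit()
    return added


# Demo on a throwaway in-memory database (assumed, simplified schema).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE usage_summaries (id INTEGER PRIMARY KEY, user_id TEXT)")
wanted = [("exa_calls", "INTEGER", "0"), ("exa_cost", "REAL", "0.0")]
first_run = add_columns_if_missing(conn, "usage_summaries", wanted)
second_run = add_columns_if_missing(conn, "usage_summaries", wanted)
conn.close()
```

Running the helper twice adds the columns on the first call and nothing on the second, which is the property a one-off schema fix needs if it is ever executed again by mistake.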
@@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Add _get_all_historical_usage method to usage_tracking_service.py
|
||||
|
||||
with open('services/subscription/usage_tracking_service.py', 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find where to insert (before get_usage_trends)
|
||||
insert_idx = None
|
||||
for i, line in enumerate(lines):
|
||||
if ' def get_usage_trends(' in line:
|
||||
insert_idx = i
|
||||
break
|
||||
|
||||
if insert_idx is None:
|
||||
print("Error: Could not find insertion point")
|
||||
raise SystemExit(1)
|
||||
|
||||
print(f"Inserting at line {insert_idx + 1}")
|
||||
|
||||
# Method to insert
|
||||
new_method = ''' def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
|
||||
# Get all usage summaries for the user
|
||||
all_summaries = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
if not all_summaries:
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': 'active',
|
||||
'total_calls': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0,
|
||||
'avg_response_time': 0.0,
|
||||
'error_rate': 0.0,
|
||||
'limits': self.pricing_service.get_user_limits(user_id),
|
||||
'provider_breakdown': {},
|
||||
'usage_percentages': {},
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Aggregate all data from UsageSummary
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
# Calculate weighted average response time
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
# Calculate overall error rate
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Map database columns to frontend keys
|
||||
provider_mapping = {
|
||||
'gemini_calls': 'gemini',
|
||||
'openai_calls': 'openai',
|
||||
'anthropic_calls': 'anthropic',
|
||||
'mistral_calls': 'huggingface',
|
||||
'wavespeed_calls': 'wavespeed',
|
||||
'exa_calls': 'exa',
|
||||
'video_calls': 'video',
|
||||
'image_edit_calls': 'image_edit',
|
||||
'audio_calls': 'audio',
|
||||
}
|
||||
|
||||
# Build provider_breakdown for frontend
|
||||
provider_breakdown = {}
|
||||
for db_col, frontend_key in provider_mapping.items():
|
||||
total_provider_calls = sum(getattr(s, db_col, 0) or 0 for s in all_summaries)
|
||||
provider_breakdown[frontend_key] = {
|
||||
'calls': total_provider_calls,
|
||||
'cost': 0,
|
||||
'tokens': 0
|
||||
}
|
||||
|
||||
# Calculate usage_percentages based on limits
|
||||
usage_percentages = {}
|
||||
if limits and limits.get('limits'):
|
||||
# Gemini calls percentage
|
||||
gemini_calls = provider_breakdown.get('gemini', {}).get('calls', 0)
|
||||
gemini_limit = limits.get('limits', {}).get('gemini_calls', 0) or 0
|
||||
if gemini_limit > 0:
|
||||
usage_percentages['gemini_calls'] = (gemini_calls / gemini_limit) * 100
|
||||
|
||||
# HuggingFace calls percentage (from mistral_calls)
|
||||
huggingface_calls = provider_breakdown.get('huggingface', {}).get('calls', 0)
|
||||
huggingface_limit = limits.get('limits', {}).get('mistral_calls', 0) or 0
|
||||
if huggingface_limit > 0:
|
||||
usage_percentages['huggingface_calls'] = (huggingface_calls / huggingface_limit) * 100
|
||||
|
||||
# Cost percentage
|
||||
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
usage_percentages['cost'] = (total_cost / cost_limit) * 100
|
||||
|
||||
# Build historical breakdown
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except AttributeError:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
'billing_period': s.billing_period,
|
||||
'total_calls': s.total_calls or 0,
|
||||
'total_tokens': s.total_tokens or 0,
|
||||
'total_cost': float(s.total_cost or 0),
|
||||
'usage_status': status_val,
|
||||
'updated_at': s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
# Determine overall status
|
||||
usage_status = 'active'
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status = s.usage_status.value
|
||||
except AttributeError:
|
||||
status = str(s.usage_status)
|
||||
if status == 'limit_reached':
|
||||
usage_status = 'limit_reached'
|
||||
break
|
||||
elif status == 'warning' and usage_status != 'limit_reached':
|
||||
usage_status = 'warning'
|
||||
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': usage_status,
|
||||
'total_calls': total_calls,
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': round(total_cost, 2),
|
||||
'avg_response_time': round(avg_response_time, 2),
|
||||
'error_rate': round(error_rate, 2),
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': historical_breakdown,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
'''
|
||||
|
||||
# Insert the new method
|
||||
new_lines = lines[:insert_idx] + [new_method] + lines[insert_idx:]
|
||||
|
||||
# Write back
|
||||
with open('services/subscription/usage_tracking_service.py', 'w', encoding='utf-8') as f:
|
||||
f.writelines(new_lines)
|
||||
|
||||
print("Successfully added _get_all_historical_usage method")
|
||||
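Review note: the removed method combines per-period averages by weighting each billing period by its call count, rather than averaging the averages. A small sketch of that arithmetic is below, assuming plain dictionaries in place of the UsageSummary rows. The function name aggregate_periods and the sample numbers are made up for illustration.

```python
def aggregate_periods(periods):
    """Combine per-period stats into call-weighted overall figures.

    Each period is a dict with total_calls, avg_response_time and
    error_rate (a percentage). Hypothetical stand-in for UsageSummary rows.
    """
    total_calls = sum(p["total_calls"] for p in periods)
    if total_calls == 0:
        return {"total_calls": 0, "avg_response_time": 0.0, "error_rate": 0.0}

    # Weight each period's average response time by how many calls it covers.
    weighted_time = sum(p["avg_response_time"] * p["total_calls"] for p in periods)
    # Convert each percentage back into an absolute error count before summing.
    total_errors = sum(p["total_calls"] * p["error_rate"] / 100 for p in periods)

    return {
        "total_calls": total_calls,
        "avg_response_time": round(weighted_time / total_calls, 2),
        "error_rate": round(total_errors / total_calls * 100, 2),
    }


periods = [
    {"total_calls": 100, "avg_response_time": 2.0, "error_rate": 10.0},
    {"total_calls": 300, "avg_response_time": 4.0, "error_rate": 2.0},
]
summary = aggregate_periods(periods)
```

With these sample periods, a naive mean of the two averages would report 3.0 seconds and 6% errors, while the weighted figures are 3.5 seconds and 4%, because the busier period dominates.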
@@ -5,8 +5,8 @@ Modular utilities for ALwrity backend startup and configuration.
|
||||
|
||||
import os
|
||||
|
||||
# Check feature mode early to skip heavy imports
|
||||
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
|
||||
# Check podcast mode early to skip heavy imports
|
||||
_is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
|
||||
from .dependency_manager import DependencyManager
|
||||
from .environment_setup import EnvironmentSetup
|
||||
@@ -26,25 +26,41 @@ from .feature_runtime import (
|
||||
)
|
||||
|
||||
# Lazy load OnboardingManager - it triggers heavy imports (aiohttp, etc.)
|
||||
if _is_full_mode:
|
||||
if not _is_podcast:
|
||||
from .onboarding_manager import OnboardingManager
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
else:
|
||||
OnboardingManager = None
|
||||
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
|
||||
@@ -51,13 +51,6 @@ FEATURE_GROUPS: Dict[str, FeatureGroup] = {
|
||||
"api.content_planning.strategy_copilot:router",
|
||||
),
|
||||
),
|
||||
"blog_writer": FeatureGroup(
|
||||
features=("blog_writer",),
|
||||
routers=(
|
||||
"api.blog_writer.router:router",
|
||||
"api.blog_writer.seo_analysis:router",
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +59,5 @@ PROFILE_GROUP_MAP: Dict[str, Tuple[str, ...]] = {
|
||||
"core": ("core",),
|
||||
"podcast": ("core", "podcast"),
|
||||
"youtube": ("core", "youtube"),
|
||||
"blog_writer": ("core", "blog_writer"),
|
||||
"planning": ("core", "content_planning"),
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
|
||||
CORE_ROUTER_REGISTRY = [
|
||||
{"name": "component_logic", "module": "api.component_logic", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog_writer", "youtube"}},
|
||||
{"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog-writer", "youtube"}},
|
||||
{"name": "step3_research", "module": "api.onboarding_utils.step3_routes", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "step4_assets", "module": "api.onboarding_utils.step4_asset_routes", "attr": "router", "features": {"all", "core", "podcast"}},
|
||||
{"name": "step4_persona", "module": "api.onboarding_utils.step4_persona_routes_optimized", "attr": "router", "features": {"all", "core"}},
|
||||
@@ -29,31 +29,31 @@ CORE_ROUTER_REGISTRY = [
|
||||
{"name": "linkedin_image", "module": "api.linkedin_image_generation", "attr": "router", "features": {"all", "core", "linkedin"}},
|
||||
{"name": "brainstorm", "module": "api.brainstorm", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "hallucination_detector", "module": "api.hallucination_detector", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content_planning"}},
|
||||
{"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content_planning"}},
|
||||
{"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content-planning"}},
|
||||
{"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content-planning"}},
|
||||
{"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "platform_analytics", "module": "routers.platform_analytics", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "bing_insights", "module": "routers.bing_insights", "attr": "router", "features": {"all", "core", "seo"}},
|
||||
{"name": "background_jobs", "module": "routers.background_jobs", "attr": "router", "features": {"all", "core"}},
|
||||
]
|
||||
|
||||
OPTIONAL_ROUTER_REGISTRY = [
|
||||
{"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story_writer"}},
|
||||
{"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog-writer"}},
|
||||
{"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story-writer"}},
|
||||
{"name": "wix", "module": "api.wix_routes", "attr": "router", "features": {"all"}},
|
||||
{"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog-writer"}},
|
||||
{"name": "persona", "module": "api.persona_routes", "attr": "router", "features": {"all", "persona"}},
|
||||
{"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video_studio"}},
|
||||
{"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product_marketing"}},
|
||||
{"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video-studio"}},
|
||||
{"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product-marketing"}},
|
||||
{"name": "campaign_creator", "module": "routers.campaign_creator", "attr": "router", "features": {"all"}},
|
||||
{"name": "content_assets", "module": "api.content_assets.router", "attr": "router", "features": {"all"}},
|
||||
{"name": "podcast", "module": "api.podcast.router", "attr": "router", "features": {"all", "podcast"}},
|
||||
|
||||
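Review note: the registry hunks above show the same feature spelled with underscores on one side and hyphens on the other (for example blog_writer and blog-writer). One way to make matching insensitive to that spelling is to normalise both the environment value and the registry tags before comparing. The sketch below is an assumption about how such matching could work, not the project's actual logic: the helper names, the comma-separated format of ALWRITY_ENABLED_FEATURES, and the rule that an empty value means "all" are all hypothetical here.

```python
import os


def enabled_features(raw=None):
    """Parse a feature list into a normalised set (hypothetical helper).

    Assumes a comma-separated value; an empty value is treated as "all".
    """
    if raw is None:
        raw = os.getenv("ALWRITY_ENABLED_FEATURES", "")
    tokens = {
        token.strip().lower().replace("-", "_")
        for token in raw.split(",")
        if token.strip()
    }
    return tokens or {"all"}


def router_enabled(router_features, active):
    """Return True if a router tagged with router_features should load."""
    normalized = {feature.lower().replace("-", "_") for feature in router_features}
    # "all" in the active set enables everything; otherwise require an overlap.
    return "all" in active or bool(normalized & active)


active = enabled_features("Blog-Writer, podcast")
blog_router_on = router_enabled({"all", "blog-writer"}, active)
wix_router_on = router_enabled({"all"}, active)
```

Under these assumptions a router tagged "blog-writer" loads when the environment says "blog_writer" or "Blog-Writer", while a router tagged only "all" stays off in a feature-limited mode.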
@@ -7,11 +7,12 @@ The onboarding endpoints are re-exported from a stable module
|
||||
|
||||
import os
|
||||
|
||||
# In feature-only modes, don't import heavy onboarding endpoints
|
||||
# They trigger heavy dependencies (exa_py, etc.)
|
||||
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
|
||||
# Check podcast mode early
|
||||
_is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
|
||||
if not _is_full_mode:
|
||||
# In podcast mode, don't import heavy onboarding endpoints
|
||||
# They trigger heavy dependencies (exa_py, etc.)
|
||||
if _is_podcast:
|
||||
__all__ = []
|
||||
else:
|
||||
from .onboarding_endpoints import (
|
||||
|
||||
@@ -1,104 +1,52 @@
|
||||
"""
|
||||
Assets Serving Router
|
||||
|
||||
Serves user-uploaded assets (avatars, voice samples) from workspace storage.
|
||||
Uses authenticated or query-token access for security.
|
||||
Audio MIME types are set correctly based on file extension so browsers
|
||||
can play voice clone previews without NotSupportedError.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
import os
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
from typing import Dict, Any
|
||||
|
||||
from middleware.auth_middleware import get_current_user_with_query_token
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from utils.storage_paths import get_repo_root, sanitize_user_id
|
||||
from services.database import WORKSPACE_DIR, get_user_db_path
|
||||
|
||||
router = APIRouter(prefix="/api/assets", tags=["Assets Serving"])
|
||||
|
||||
MIME_MAP = {
|
||||
".wav": "audio/wav",
|
||||
".mp3": "audio/mpeg",
|
||||
".ogg": "audio/ogg",
|
||||
".opus": "audio/opus",
|
||||
".webm": "audio/webm",
|
||||
".m4a": "audio/mp4",
|
||||
".aac": "audio/aac",
|
||||
".flac": "audio/flac",
|
||||
".png": "image/png",
|
||||
".jpg": "image/jpeg",
|
||||
".jpeg": "image/jpeg",
|
||||
".gif": "image/gif",
|
||||
".webp": "image/webp",
|
||||
".svg": "image/svg+xml",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_asset_path(user_id: str, category: str, filename: str) -> Path:
|
||||
"""Resolve asset path in user workspace with path-traversal protection."""
|
||||
safe_user_id = sanitize_user_id(user_id)
|
||||
repo_root = get_repo_root()
|
||||
|
||||
file_path = (repo_root / "workspace" / f"workspace_{safe_user_id}" / "assets" / category / filename).resolve()
|
||||
|
||||
workspace_dir = (repo_root / "workspace" / f"workspace_{safe_user_id}").resolve()
|
||||
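# Compare path components, not string prefixes, so a sibling such as workspace_abcd cannot pass as workspace_abc
if workspace_dir not in file_path.parents: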
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
def _get_media_type(filename: str) -> str:
|
||||
"""Determine MIME type from file extension, with fallback."""
|
||||
ext = Path(filename).suffix.lower()
|
||||
return MIME_MAP.get(ext, "application/octet-stream")
|
||||
|
||||
|
||||
@router.get("/{user_id}/avatars/{filename}")
|
||||
async def serve_avatar(
|
||||
user_id: str,
|
||||
filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
|
||||
):
|
||||
"""Serve avatar images. Supports auth via Authorization header or ?token= query param."""
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
async def serve_avatar(user_id: str, filename: str):
|
||||
"""
|
||||
Serve avatar images directly.
|
||||
Public endpoint relying on unguessable filenames.
|
||||
"""
|
||||
# Sanitize user_id (simple check to prevent directory traversal)
|
||||
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
|
||||
if safe_user_id != user_id:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
# Sanitize filename
|
||||
safe_filename = os.path.basename(filename)
|
||||
file_path = _resolve_asset_path(user_id, "avatars", safe_filename)
|
||||
|
||||
|
||||
# Construct path
|
||||
# workspace/workspace_{user_id}/assets/avatars/{filename}
|
||||
file_path = Path(WORKSPACE_DIR) / f"workspace_{safe_user_id}" / "assets" / "avatars" / safe_filename
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
media_type = _get_media_type(safe_filename)
|
||||
return FileResponse(file_path, media_type=media_type)
|
||||
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
@router.get("/{user_id}/voice_samples/{filename}")
|
||||
async def serve_voice_sample(
|
||||
user_id: str,
|
||||
filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
|
||||
):
|
||||
"""Serve voice sample audio files.
|
||||
|
||||
Supports auth via Authorization header or ?token= query param.
|
||||
The ?token= param is essential for <audio> elements and new Audio()
|
||||
which cannot send Authorization headers.
|
||||
async def serve_voice_sample(user_id: str, filename: str):
|
||||
"""
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
Serve voice sample audio files directly.
|
||||
"""
|
||||
# Sanitize user_id
|
||||
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
|
||||
if safe_user_id != user_id:
|
||||
raise HTTPException(status_code=400, detail="Invalid user ID")
|
||||
|
||||
# Sanitize filename
|
||||
safe_filename = os.path.basename(filename)
|
||||
file_path = _resolve_asset_path(user_id, "voice_samples", safe_filename)
|
||||
|
||||
|
||||
# Construct path
|
||||
# workspace/workspace_{user_id}/assets/voice_samples/{filename}
|
||||
file_path = Path(WORKSPACE_DIR) / f"workspace_{safe_user_id}" / "assets" / "voice_samples" / safe_filename
|
||||
|
||||
if not file_path.exists():
|
||||
logger.info(f"[Assets] Voice sample not found: {file_path}")
|
||||
raise HTTPException(status_code=404, detail="Asset not found")
|
||||
|
||||
media_type = _get_media_type(safe_filename)
|
||||
file_size = file_path.stat().st_size
|
||||
logger.warning(f"[Assets] Serving voice sample: {safe_filename} ({media_type}, {file_size} bytes)")
|
||||
return FileResponse(file_path, media_type=media_type)
|
||||
|
||||
return FileResponse(file_path)
|
||||
|
||||
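Review note: one side of this hunk resolves the requested file and checks that it still sits inside the user's workspace directory. The sketch below shows one way to express that containment check with pathlib, and why comparing path components is stricter than comparing string prefixes. The function name resolve_inside and the directory names are illustrative assumptions, not code from this change.

```python
import tempfile
from pathlib import Path


def resolve_inside(base_dir: Path, *parts: str) -> Path:
    """Join parts onto base_dir and refuse results that leave base_dir.

    Hypothetical helper: resolves both paths, then checks that the base is
    one of the candidate's parent directories.
    """
    base = base_dir.resolve()
    candidate = base.joinpath(*parts).resolve()
    if base not in candidate.parents:
        raise PermissionError(f"{candidate} escapes {base}")
    return candidate


with tempfile.TemporaryDirectory() as tmp:
    base = Path(tmp) / "workspace_abc"

    # A normal request stays inside the workspace.
    ok = resolve_inside(base, "assets", "avatars", "me.png")

    # A sibling directory whose name merely starts with the same characters.
    sibling = (base / ".." / "workspace_abcd" / "me.png").resolve()
    prefix_check_passes = str(sibling).startswith(str(base.resolve()))

    try:
        resolve_inside(base, "..", "workspace_abcd", "me.png")
        escaped = True
    except PermissionError:
        escaped = False
```

The string-prefix comparison accepts the sibling path because "workspace_abcd" begins with "workspace_abc", whereas the parent-directory comparison rejects it.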
@@ -1195,68 +1195,3 @@ async def generate_introductions(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate introductions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Save Complete Blog Asset
|
||||
# ---------------------------
|
||||
|
||||
|
||||
class SaveCompleteBlogAssetRequest(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
seo_title: Optional[str] = None
|
||||
meta_description: Optional[str] = None
|
||||
focus_keyword: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
categories: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.post("/save-complete-asset")
|
||||
async def save_complete_blog_asset(
|
||||
request: SaveCompleteBlogAssetRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
"""Save the complete blog content as a single asset in the asset library."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
full_content = f"# {request.title}\n\n{request.content}"
|
||||
|
||||
asset_id = save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=full_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Published Blog: {request.title[:60]}",
|
||||
description=request.meta_description or f"Complete published blog post: {request.title}",
|
||||
prompt=f"SEO Title: {request.seo_title or request.title}\nFocus Keyword: {request.focus_keyword or ''}",
|
||||
tags=["blog", "published"] + [t for t in (request.tags or []) if t],
|
||||
asset_metadata={
|
||||
"status": "published",
|
||||
"focus_keyword": request.focus_keyword,
|
||||
"categories": request.categories,
|
||||
"word_count": len(full_content.split()),
|
||||
},
|
||||
subdirectory="published",
|
||||
file_extension=".md"
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"✅ Complete blog asset saved to library: ID={asset_id}")
|
||||
return {"success": True, "asset_id": asset_id}
|
||||
else:
|
||||
logger.warning("save_and_track_text_content returned None for published blog")
|
||||
return {"success": False, "error": "Failed to save blog asset"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save complete blog asset: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import Any, Dict, List
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from services.database import get_session_for_user
|
||||
from services.database import SessionLocal, get_session_for_user
|
||||
|
||||
from models.blog_models import (
|
||||
BlogResearchRequest,
|
||||
@@ -264,7 +264,7 @@ class TaskManager:
|
||||
raise ValueError("Global target words exceed 1000; medium generation not allowed")
|
||||
|
||||
# Create a sync session for asset saving
|
||||
db_session = get_session_for_user(user_id)
|
||||
db_session = SessionLocal()
|
||||
try:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
@@ -326,7 +326,6 @@ class TaskManager:
|
||||
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
|
||||
self.task_storage[task_id]["status"] = "failed"
|
||||
self.task_storage[task_id]["error"] = str(e)
|
||||
self.task_storage[task_id]["error_data"] = {"error_message": str(e), "error_type": type(e).__name__}
|
||||
|
||||
|
||||
# Global task manager instance
|
||||
|
||||
@@ -9,27 +9,13 @@ from fastapi.responses import FileResponse
|
||||
from sqlalchemy.orm import Session
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from .step4_persona_routes import _extract_user_id
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
|
||||
def _extract_user_id(user: Dict[str, Any]) -> str:
|
||||
"""Extract a stable user ID from Clerk-authenticated user payloads.
|
||||
Prefers 'clerk_user_id' or 'id', falls back to 'user_id', else 'unknown'.
|
||||
"""
|
||||
if not isinstance(user, dict):
|
||||
return 'unknown'
|
||||
return (
|
||||
user.get('clerk_user_id')
|
||||
or user.get('id')
|
||||
or user.get('user_id')
|
||||
or 'unknown'
|
||||
)
|
||||
import base64
|
||||
import os
|
||||
from pathlib import Path
|
||||
from utils.file_storage import save_file_safely, generate_unique_filename
|
||||
from services.database import get_db
|
||||
from utils.storage_paths import get_user_workspace, sanitize_user_id
|
||||
from services.database import get_db, WORKSPACE_DIR
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from models.content_asset_models import ContentAsset, AssetType, AssetSource
|
||||
from sqlalchemy import desc
|
||||
@@ -87,8 +73,6 @@ async def get_latest_avatar(
|
||||
try:
|
||||
user_id = _extract_user_id(current_user)
|
||||
|
||||
logger.warning(f"[latest-avatar] Looking for avatar for user_id: {user_id}")
|
||||
|
||||
# Search for assets that are either:
|
||||
# 1. Saved with source_module=BRAND_AVATAR_GENERATOR (new)
|
||||
# 2. Saved with source_module=STORY_WRITER but have metadata category='brand_avatar' (legacy)
|
||||
@@ -103,8 +87,6 @@ async def get_latest_avatar(
|
||||
])
|
||||
).order_by(desc(ContentAsset.created_at)).limit(50).all()
|
||||
|
||||
logger.warning(f"[latest-avatar] Found {len(candidates)} candidate(s)")
|
||||
|
||||
asset = None
|
||||
for candidate in candidates:
|
||||
# Check for direct match (new assets)
|
||||
@@ -185,7 +167,7 @@ async def generate_avatar(
|
||||
try:
|
||||
user_id = _extract_user_id(current_user)
|
||||
|
||||
logger.warning(f"Generating avatar for user {user_id} with prompt: {request.prompt}")
|
||||
logger.info(f"Generating avatar for user {user_id} with prompt: {request.prompt}")
|
||||
|
||||
# 1. Generate Image
|
||||
result = await generate_image_with_provider(
|
||||
@@ -235,7 +217,7 @@ async def generate_avatar(
|
||||
content_to_save = base64.b64decode(image_data) if isinstance(image_data, str) else image_data
|
||||
|
||||
# Construct user assets directory
|
||||
user_assets_dir = get_user_workspace(user_id) / "assets" / "avatars"
|
||||
user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
|
||||
|
||||
saved_path, error = save_file_safely(
|
||||
content_to_save,
|
||||
@@ -288,7 +270,7 @@ async def enhance_prompt_route(
|
||||
"""Enhance a simple prompt into a detailed midjourney-style prompt."""
|
||||
try:
|
||||
user_id = _extract_user_id(current_user)
|
||||
logger.warning(f"Enhancing prompt for user {user_id}: {request.prompt}")
|
||||
logger.info(f"Enhancing prompt for user {user_id}: {request.prompt}")
|
||||
|
||||
enhanced_prompt = await enhance_image_prompt(request.prompt, user_id=user_id)
|
||||
|
||||
@@ -312,7 +294,7 @@ async def create_variation_route(
|
||||
"""Generate a variation of an existing avatar."""
|
||||
try:
|
||||
user_id = _extract_user_id(current_user)
|
||||
logger.warning(f"Creating variation for user {user_id} with prompt: {prompt}")
|
||||
logger.info(f"Creating variation for user {user_id} with prompt: {prompt}")
|
||||
|
||||
# Read file
|
||||
file_content = await file.read()
|
||||
@@ -333,7 +315,7 @@ async def create_variation_route(
|
||||
content_to_save = base64.b64decode(image_data)
|
||||
|
||||
# Construct user assets directory
|
||||
user_assets_dir = get_user_workspace(user_id) / "assets" / "avatars"
|
||||
user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
|
||||
|
||||
saved_path, error = save_file_safely(
|
||||
content_to_save,
|
||||
@@ -387,7 +369,7 @@ async def enhance_avatar_route(
|
||||
"""Enhance/Upscale an existing avatar."""
|
||||
try:
|
||||
user_id = _extract_user_id(current_user)
|
||||
logger.warning(f"Enhancing avatar for user {user_id}")
|
||||
logger.info(f"Enhancing avatar for user {user_id}")
|
||||
|
||||
# Read file
|
||||
file_content = await file.read()
|
||||
@@ -407,7 +389,7 @@ async def enhance_avatar_route(
|
||||
content_to_save = base64.b64decode(image_data)
|
||||
|
||||
# Construct user assets directory
|
||||
user_assets_dir = get_user_workspace(user_id) / "assets" / "avatars"
|
||||
user_assets_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "avatars"
|
||||
|
||||
saved_path, error = save_file_safely(
|
||||
content_to_save,
|
||||
@@ -464,13 +446,13 @@ async def create_voice_clone(
|
||||
"""Create a voice clone from an audio file."""
|
||||
try:
|
||||
user_id = _extract_user_id(current_user)
|
||||
logger.warning(f"[VoiceClone] Creating voice clone '{voice_name}' (engine={engine}) for user {user_id}")
|
||||
logger.info(f"Creating voice clone '{voice_name}' (engine={engine}) for user {user_id}")
|
||||
|
||||
# 1. Save uploaded audio file
|
||||
file_content = await file.read()
|
||||
filename = generate_unique_filename("voice_sample", Path(file.filename).suffix.lstrip("."))
|
||||
|
||||
user_voice_dir = get_user_workspace(user_id) / "assets" / "voice_samples"
|
||||
user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
|
||||
saved_path, error = save_file_safely(file_content, user_voice_dir, filename)
|
||||
|
||||
if error or not saved_path:
|
||||
@@ -492,7 +474,7 @@ async def create_voice_clone(
|
||||
random_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
|
||||
custom_voice_id = f"vc_{random_suffix}"
|
||||
|
||||
logger.warning(f"Cloning voice with Minimax, ID: {custom_voice_id}")
|
||||
logger.info(f"Cloning voice with Minimax, ID: {custom_voice_id}")
|
||||
|
||||
# Run blocking call in executor
|
||||
result = await loop.run_in_executor(
|
||||
@@ -507,7 +489,7 @@ async def create_voice_clone(
|
||||
preview_audio_bytes = result.preview_audio_bytes
|
||||
|
||||
elif engine.lower() == "cosyvoice":
|
||||
logger.warning("Cloning voice with CosyVoice")
|
||||
logger.info("Cloning voice with CosyVoice")
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: cosyvoice_voice_clone(
|
||||
@@ -522,7 +504,7 @@ async def create_voice_clone(
|
||||
custom_voice_id = f"vc_cosy_{asset_uuid}"
|
||||
|
||||
else: # qwen3 (default)
|
||||
logger.warning("Cloning voice with Qwen3")
|
||||
logger.info("Cloning voice with Qwen3")
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: qwen3_voice_clone(
|
||||
@@ -538,48 +520,27 @@ async def create_voice_clone(
|
||||
|
||||
# 3. Save Preview Audio (if generated)
|
||||
preview_url = None
|
||||
preview_mime_type = "audio/wav"
|
||||
actual_filename = None # Default if preview save fails
|
||||
|
||||
if preview_audio_bytes and len(preview_audio_bytes) > 0:
|
||||
from utils.media_utils import detect_audio_format, ensure_audio_extension
|
||||
if preview_audio_bytes:
|
||||
preview_filename = f"preview_{filename}"
|
||||
# Ensure it ends with .wav
|
||||
if not preview_filename.endswith(".wav"):
|
||||
preview_filename = str(Path(preview_filename).with_suffix('.wav'))
|
||||
|
||||
detected_fmt, preview_mime_type = detect_audio_format(preview_audio_bytes)
|
||||
logger.warning(f"[VoiceClone] Detected preview audio format: {detected_fmt} ({preview_mime_type}), {len(preview_audio_bytes)} bytes")
|
||||
|
||||
# Build filename with correct extension based on actual content format
|
||||
original_stem = Path(filename).stem
|
||||
preview_filename = f"preview_{original_stem}"
|
||||
preview_filename = ensure_audio_extension(preview_filename, preview_audio_bytes)
|
||||
|
||||
user_voice_dir = get_user_workspace(user_id) / "assets" / "voice_samples"
|
||||
user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
|
||||
saved_preview_path, error = save_file_safely(preview_audio_bytes, user_voice_dir, preview_filename)
|
||||
|
||||
if not error and saved_preview_path:
|
||||
# Use actual saved filename (may have UUID suffix added by save_file_safely)
|
||||
actual_filename = saved_preview_path.name
|
||||
preview_url = f"/api/assets/{user_id}/voice_samples/{actual_filename}"
|
||||
logger.warning(f"[VoiceClone] Saved preview: {actual_filename} ({saved_preview_path.stat().st_size} bytes, {preview_mime_type})")
|
||||
|
||||
# Verify file exists
|
||||
if not saved_preview_path.exists():
|
||||
logger.warning(f"[VoiceClone] Preview file does not exist after save: {saved_preview_path}")
|
||||
preview_url = None
|
||||
else:
|
||||
logger.warning(f"[VoiceClone] Failed to save preview audio: {error}")
|
||||
preview_url = f"/api/assets/{user_id}/voice_samples/{preview_filename}"
|
||||
|
||||
# 4. Save to Asset Library
|
||||
# Use the preview file (with corrected .wav extension) as the main asset file
|
||||
has_valid_preview = preview_audio_bytes and len(preview_audio_bytes) > 0 and saved_preview_path
|
||||
stored_filename = actual_filename if has_valid_preview else filename
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
file_path=file_path,
|
||||
asset_type="audio",
|
||||
source_module="voice_cloner",
|
||||
filename=stored_filename,
|
||||
file_url=f"/api/assets/{user_id}/voice_samples/{stored_filename}",
|
||||
filename=filename,
|
||||
file_url=f"/api/assets/{user_id}/voice_samples/{filename}",
|
||||
asset_metadata={
|
||||
"voice_name": voice_name,
|
||||
"engine": engine,
|
||||
@@ -594,7 +555,7 @@ async def create_voice_clone(
|
||||
return {
|
||||
"success": True,
|
||||
"custom_voice_id": custom_voice_id,
|
||||
"preview_audio_url": preview_url or f"/api/assets/{user_id}/voice_samples/{stored_filename}",
|
||||
"preview_audio_url": preview_url or f"/api/assets/{user_id}/voice_samples/{filename}",
|
||||
"asset_id": asset_id,
|
||||
"message": "Voice clone created successfully"
|
||||
}
|
||||
@@ -613,7 +574,7 @@ async def create_voice_design(
|
||||
"""Create a voice from text description (Voice Design)."""
|
||||
try:
|
||||
user_id = _extract_user_id(current_user)
|
||||
logger.warning(f"Designing voice for user {user_id}")
|
||||
logger.info(f"Designing voice for user {user_id}")
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
@@ -627,15 +588,9 @@ async def create_voice_design(
|
||||
)
|
||||
)
|
||||
|
||||
# Save the result to a file with correct extension based on content
|
||||
from utils.media_utils import detect_audio_format, ensure_audio_extension
|
||||
detected_fmt, mime_type = detect_audio_format(result.preview_audio_bytes)
|
||||
logger.warning(f"[VoiceDesign] Detected audio format: {detected_fmt} ({mime_type})")
|
||||
|
||||
filename = generate_unique_filename("voice_design_preview", detected_fmt)
|
||||
filename = ensure_audio_extension(filename, result.preview_audio_bytes)
|
||||
|
||||
user_voice_dir = get_user_workspace(user_id) / "assets" / "voice_samples"
|
||||
# Save the result to a temporary file
|
||||
filename = generate_unique_filename("voice_design_preview", "wav")
|
||||
user_voice_dir = Path(WORKSPACE_DIR) / f"workspace_{user_id}" / "assets" / "voice_samples"
|
||||
saved_path, error = save_file_safely(result.preview_audio_bytes, user_voice_dir, filename)
|
||||
|
||||
if error or not saved_path:
|
||||
|
||||
666
backend/api/podcast/broll_temp/README.md
Normal file
@@ -0,0 +1,666 @@
|
||||
# Programmatic B-Roll Composer
|
||||
|
||||
A layered video composition pipeline that assembles AI-generated images, programmatic data charts, Pillow text overlays, and circular-masked avatar videos into a single output MP4. Driven by structured JSON from an LLM, exposed via a FastAPI server.
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [Architecture overview](#1-architecture-overview)
|
||||
2. [File structure](#2-file-structure)
|
||||
3. [Installation](#3-installation)
|
||||
4. [Core concepts](#4-core-concepts)
|
||||
- 4.1 [The Insight dataclass](#41-the-insight-dataclass)
|
||||
- 4.2 [The SceneAssets dataclass](#42-the-sceneassets-dataclass)
|
||||
- 4.3 [The layer stack](#43-the-layer-stack)
|
||||
- 4.4 [The JSON bridge](#44-the-json-bridge)
|
||||
5. [Asset generators](#5-asset-generators)
|
||||
- 5.1 [Bar chart — make_bar_chart](#51-bar-chart--make_bar_chart)
|
||||
- 5.2 [Line trend — make_line_trend](#52-line-trend--make_line_trend)
|
||||
- 5.3 [Bullet overlay — make_bullet_overlay](#53-bullet-overlay--make_bullet_overlay)
|
||||
- 5.4 [Insight card — make_insight_card](#54-insight-card--make_insight_card)
|
||||
6. [Video effects](#6-video-effects)
|
||||
- 6.1 [Circular avatar mask — apply_circle_mask](#61-circular-avatar-mask--apply_circle_mask)
|
||||
- 6.2 [Ken Burns zoom — ken_burns](#62-ken-burns-zoom--ken_burns)
|
||||
7. [Scene builders](#7-scene-builders)
|
||||
- 7.1 [Data scene — build_data_scene](#71-data-scene--build_data_scene)
|
||||
- 7.2 [Bullet scene — build_bullet_scene](#72-bullet-scene--build_bullet_scene)
|
||||
- 7.3 [Full avatar scene — build_full_avatar_scene](#73-full-avatar-scene--build_full_avatar_scene)
|
||||
8. [Scene dispatcher — dispatch_scene](#8-scene-dispatcher--dispatch_scene)
|
||||
9. [Crossfade transitions](#9-crossfade-transitions)
|
||||
- 9.1 [How crossfade_concat works](#91-how-crossfade_concat-works)
|
||||
- 9.2 [The set_duration gotcha](#92-the-set_duration-gotcha)
|
||||
10. [Master compositor — compose_video](#10-master-compositor--compose_video)
|
||||
11. [FastAPI server](#11-fastapi-server)
|
||||
- 11.1 [Request models](#111-request-models)
|
||||
- 11.2 [Job lifecycle](#112-job-lifecycle)
|
||||
- 11.3 [API endpoints](#113-api-endpoints)
|
||||
12. [Running the project](#12-running-the-project)
|
||||
- 12.1 [Smoke test (no media files needed)](#121-smoke-test-no-media-files-needed)
|
||||
- 12.2 [Full video composition](#122-full-video-composition)
|
||||
- 12.3 [API server](#123-api-server)
|
||||
13. [Calling the API](#13-calling-the-api)
|
||||
14. [Production notes](#14-production-notes)
|
||||
15. [Extending the pipeline](#15-extending-the-pipeline)
|
||||
|
||||
---
|
||||
|
||||
## 1. Architecture overview
|
||||
|
||||
The pipeline follows a **Layered Composition** model. Rather than generating video in one pass, it assembles independent visual layers — each produced by the cheapest appropriate tool — into a single timeline using MoviePy as the compositor.
|
||||
|
||||
```
LLM JSON output
      │
      ▼
dispatch_scene()              ← routes visual_cue → builder function
      │
      ├─ build_data_scene()
      │     ├─ ImageClip (background)      ← AI-generated image
      │     ├─ ImageClip (chart PNG)       ← Matplotlib, transparent bg
      │     ├─ ImageClip (insight card)    ← Pillow RGBA
      │     └─ VideoFileClip (avatar)      ← circular numpy mask
      │
      ├─ build_bullet_scene()
      │     ├─ ImageClip (background)
      │     ├─ ImageClip (bullet overlay)  ← Pillow RGBA
      │     └─ VideoFileClip (avatar)
      │
      └─ build_full_avatar_scene()
            └─ VideoFileClip (full-screen)
      │
      ▼
crossfade_concat()            ← dissolve between scenes
      │
      ▼
write_videofile()             ← H.264 MP4 via ffmpeg
```
|
||||
|
||||
The key design decision: charts and text are **never** rendered by a generative model. Matplotlib produces pixel-perfect data graphics from real numbers; Pillow renders crisp, deterministic text. Only the background and the talking-head avatar come from AI generation, minimising both cost and hallucination risk.
|
||||
|
||||
---
|
||||
|
||||
## 2. File structure
|
||||
|
||||
```
|
||||
.
|
||||
├── broll_composer.py # Core library — all composition logic
|
||||
├── api_server.py # FastAPI wrapper — HTTP interface to the pipeline
|
||||
└── requirements.txt # Python dependencies
|
||||
```
|
||||
|
||||
`broll_composer.py` has no FastAPI dependency and can be imported and called directly from scripts, notebooks, or other web frameworks.
|
||||
|
||||
---
|
||||
|
||||
## 3. Installation
|
||||
|
||||
```bash
|
||||
# System dependency — must be on PATH
|
||||
apt-get install ffmpeg
|
||||
|
||||
# Python packages
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
**requirements.txt**
|
||||
|
||||
```
|
||||
moviepy==1.0.3
|
||||
Pillow>=10.0
|
||||
matplotlib>=3.8
|
||||
numpy>=1.26
|
||||
fastapi>=0.111
|
||||
uvicorn[standard]>=0.29
|
||||
python-multipart>=0.0.9
|
||||
```
|
||||
|
||||
MoviePy 1.0.3 is pinned because 2.x introduced breaking API changes to `CompositeVideoClip` and the effects interface. The rest can float within the specified lower bounds.
|
||||
|
||||
---
|
||||
|
||||
## 4. Core concepts
|
||||
|
||||
### 4.1 The Insight dataclass
|
||||
|
||||
Every scene is driven by a single `Insight` object. This is the contract between the LLM and the composition pipeline:
|
||||
|
||||
```python
@dataclass
class Insight:
    key_insight: str      # Headline text rendered on the insight card
    supporting_stat: str  # Sub-headline rendered below the headline
    visual_cue: str       # Selects which scene builder to use (see §8)
    audio_tone: str       # Passed through for downstream TTS / audio selection
    chart_data: dict      # Data payload consumed by chart generators (see §5)
    duration: float       # Scene length in seconds, default 10.0
```
|
||||
|
||||
The `audio_tone` field is not used by the video pipeline itself — it is metadata for whatever system generates or selects the voiceover audio track for the scene.
|
||||
|
||||
### 4.2 The SceneAssets dataclass
|
||||
|
||||
`SceneAssets` carries file paths to the media assets for a given scene:
|
||||
|
||||
```python
@dataclass
class SceneAssets:
    background_img: str          # Required: path to JPEG or PNG background
    chart_img: Optional[str]     # Populated by dispatch_scene after chart generation
    avatar_video: Optional[str]  # Optional: path to MP4 avatar clip
    bullet_img: Optional[str]    # Reserved for pre-rendered bullet overlays
```
|
||||
|
||||
`chart_img` starts as `None` and is written to by `dispatch_scene` after it generates the Matplotlib PNG, so the scene builders receive a fully-populated `SceneAssets` by the time they run.
|
||||
|
||||
### 4.3 The layer stack
|
||||
|
||||
Every scene is a `CompositeVideoClip` — a MoviePy object that renders multiple clips on a shared canvas by alpha-compositing them bottom-to-top. The layer order is consistent across all scene types:
|
||||
|
||||
| Z-order | Layer | Source | Notes |
|
||||
|---------|-------|--------|-------|
|
||||
| 0 (bottom) | Background | AI image + Ken Burns | Darkened to make overlays legible |
|
||||
| 1 | Chart or bullet overlay | Matplotlib or Pillow PNG | Transparent background; fades in |
|
||||
| 2 | Insight card | Pillow RGBA | Positioned at y=820 (near bottom) |
|
||||
| 3 (top) | Avatar circle | MP4 + numpy mask | Bottom-right corner |
|
||||
|
||||
### 4.4 The JSON bridge
|
||||
|
||||
The LLM is prompted to return a structured JSON object — not prose — so the pipeline can consume it without parsing ambiguity:
|
||||
|
||||
```json
|
||||
{
|
||||
"key_insight": "AI tools reduced content cycles by 40%",
|
||||
"supporting_stat": "HubSpot 2026 report — 12% lift in CTR",
|
||||
"visual_cue": "bar_chart_comparison",
|
||||
"audio_tone": "authoritative_and_surprising",
|
||||
"duration": 10.0,
|
||||
"chart_data": {
|
||||
"labels": ["Content Velocity", "CTR", "Engagement", "Cost/Lead"],
|
||||
"before": [30, 22, 18, 60],
|
||||
"after": [72, 34, 41, 38]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`pipeline_from_json()` is the single-call entry point that accepts this JSON string, constructs the dataclasses, runs `dispatch_scene`, and writes the output MP4.
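
As a rough illustration of the parsing step only, the sketch below turns such a JSON string into an `Insight`-shaped dataclass. It is not the project's `pipeline_from_json()`: the helper name `insight_from_json`, the field defaults, and the choice to reject unknown keys are all assumptions made for this example.

```python
import json
from dataclasses import dataclass, field, fields


@dataclass
class Insight:
    # Mirrors the fields documented in §4.1; the defaults are illustrative.
    key_insight: str
    supporting_stat: str
    visual_cue: str
    audio_tone: str = ""
    chart_data: dict = field(default_factory=dict)
    duration: float = 10.0


def insight_from_json(raw: str) -> Insight:
    """Hypothetical helper: parse an LLM JSON string into an Insight."""
    payload = json.loads(raw)
    allowed = {f.name for f in fields(Insight)}
    unknown = set(payload) - allowed
    if unknown:
        # Fail loudly rather than silently dropping keys the LLM invented.
        raise ValueError(f"unexpected keys: {sorted(unknown)}")
    return Insight(**payload)


raw = (
    '{"key_insight": "A", "supporting_stat": "B", '
    '"visual_cue": "line_trend", "chart_data": {"x": [1, 2], "y": [3, 4]}}'
)
insight = insight_from_json(raw)
print(insight.visual_cue, insight.duration)
```

Rejecting unknown keys is one possible policy; a more lenient bridge could ignore them instead, at the cost of hiding prompt drift.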
|
||||
|
||||
---
|
||||
|
||||
## 5. Asset generators
|
||||
|
||||
These functions produce static image files (PNG with alpha transparency) that are loaded as `ImageClip` objects in the scene builders. They are completely independent of MoviePy and can be called and previewed without assembling any video.
|
||||
|
||||
### 5.1 Bar chart — `make_bar_chart`
|
||||
|
||||
```python
|
||||
make_bar_chart(data: dict, out_path: str, title: str = "") -> str
|
||||
```
|
||||
|
||||
Produces a side-by-side "before vs after" bar chart using Matplotlib. The critical detail is the renderer configuration and save parameters:
|
||||
|
||||
```python
|
||||
matplotlib.use("Agg") # Non-interactive backend — no display required
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none") # Transparent axes background
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
```
|
||||
|
||||
Setting both `facecolor="none"` on the figure and `transparent=True` on `savefig` is necessary because they control different things: the figure background and the PNG alpha channel respectively. Without both, a white box appears behind the chart when it is composited over the video background.
|
||||
|
||||
**Expected `data` shape:**
|
||||
|
||||
```python
|
||||
{
|
||||
"labels": ["Category A", "Category B"], # X-axis labels
|
||||
"before": [30, 22], # Bar heights (left bars)
|
||||
"after": [72, 34] # Bar heights (right bars)
|
||||
}
|
||||
```
|
||||
|
||||
### 5.2 Line trend — `make_line_trend`
|
||||
|
||||
```python
|
||||
make_line_trend(data: dict, out_path: str, title: str = "") -> str
|
||||
```
|
||||
|
||||
Produces a time-series line chart with a translucent fill under the curve (`alpha=0.12`). Suited for growth trends, adoption curves, and any metric tracked over sequential time periods.
|
||||
|
||||
**Expected `data` shape:**
|
||||
|
||||
```python
|
||||
{
|
||||
"x": [2021, 2022, 2023, 2024, 2025], # X-axis values (numeric or strings)
|
||||
"y": [10, 18, 30, 45, 72] # Y-axis values
|
||||
}
|
||||
```
|
||||
|
||||
### 5.3 Bullet overlay — `make_bullet_overlay`
|
||||
|
||||
```python
|
||||
make_bullet_overlay(lines: list[str], out_path: str,
|
||||
width: int = 900, font_size: int = 32) -> str
|
||||
```
|
||||
|
||||
Renders a list of bullet-point strings onto a semi-transparent dark rounded rectangle using Pillow. The image height is computed dynamically from the number of lines:
|
||||
|
||||
```python
|
||||
img_h = padding * 2 + len(lines) * line_h + 12
|
||||
```
|
||||
|
||||
The fill colour `(10, 10, 10, 185)` gives roughly 73% opacity — dark enough for text legibility over any background, light enough that the background remains visible. The bullet character (`•`) is prepended in Python rather than in the font, so no special Unicode font support is required.
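
The height formula and the opacity figure can be checked with a few lines of arithmetic. This is a sketch under stated assumptions: the `padding` and `line_h` defaults below are placeholder values, not the ones used in `broll_composer.py`, and both function names are invented for the example.

```python
def overlay_height(num_lines: int, padding: int = 28, line_h: int = 48) -> int:
    """Same formula as above; the padding and line_h defaults are assumed values."""
    return padding * 2 + num_lines * line_h + 12


def alpha_to_percent(alpha: int) -> float:
    """Convert an 8-bit alpha value (0-255) to percent opacity."""
    return alpha / 255 * 100


print(overlay_height(3))             # 28*2 + 3*48 + 12
print(round(alpha_to_percent(185)))  # bullet overlay fill
print(round(alpha_to_percent(200)))  # insight card fill (see §5.4)
```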
|
||||
|
||||
Font loading tries the DejaVu Sans Bold path common on Debian/Ubuntu systems, falling back to Pillow's built-in bitmap font if the TTF is absent.
|
||||
|
||||
### 5.4 Insight card — `make_insight_card`
|
||||
|
||||
```python
|
||||
make_insight_card(insight: str, stat: str, out_path: str,
|
||||
width: int = 960, height: int = 200) -> str
|
||||
```
|
||||
|
||||
Renders a two-line card: a large bold headline (`font_size=34`) and a smaller supporting stat line (`font_size=20`). A solid red rectangle (`#E63946`) is drawn as a left-edge accent bar — a visual device borrowed from print editorial design that gives the card a distinct identity when overlaid on varied backgrounds.
|
||||
|
||||
The card uses `fill=(10, 10, 10, 200)` — approximately 78% opacity — slightly more opaque than the bullet overlay because the headline text is denser.
|
||||
|
||||
---
|
||||
|
||||
## 6. Video effects
|
||||
|
||||
### 6.1 Circular avatar mask — `apply_circle_mask`
|
||||
|
||||
```python
|
||||
apply_circle_mask(clip: VideoFileClip, diameter: int) -> VideoFileClip
|
||||
```
|
||||
|
||||
Takes an MP4 avatar clip and returns it with a circular alpha mask applied, so only the circle region is visible when the clip is composited over other layers.
|
||||
|
||||
The mask is built using NumPy's `ogrid`, which creates coordinate arrays without materialising a full mesh:
|
||||
|
||||
```python
|
||||
Y, X = np.ogrid[:h, :w]
|
||||
cx, cy = w / 2, h / 2
|
||||
mask_arr = ((X - cx)**2 + (Y - cy)**2 <= (min(w, h) / 2)**2).astype(float)
|
||||
```
|
||||
|
||||
This produces a 2D float array (values 0.0 or 1.0) where all pixels within the inscribed circle are 1 (opaque) and all pixels outside are 0 (transparent). MoviePy requires mask arrays in this float format — it does not accept uint8 or boolean arrays directly.
|
||||
|
||||
The mask array is wrapped in an `ImageClip` with `ismask=True` and the duration is set to match the source clip before calling `clip.set_mask()`.
|
||||
|
||||
**Why not use imagemagick or a pre-made circular PNG?** The numpy approach has no subprocess dependency, works for any input resolution, and the mask is computed once and reused for every frame without disk I/O.
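
To make the circle inequality concrete without any dependencies, the sketch below evaluates the same test on a tiny grid in plain Python. It only illustrates the geometry; the pipeline itself uses the vectorised NumPy expression shown earlier in this section, and the name `circle_mask_rows` is made up for this example.

```python
def circle_mask_rows(w: int, h: int) -> list[list[float]]:
    """Return an h x w grid of 0.0/1.0 using the same inequality as the NumPy mask."""
    cx, cy = w / 2, h / 2
    r2 = (min(w, h) / 2) ** 2
    return [
        [1.0 if (x - cx) ** 2 + (y - cy) ** 2 <= r2 else 0.0 for x in range(w)]
        for y in range(h)
    ]


mask = circle_mask_rows(6, 6)
for row in mask:
    # '#' marks opaque pixels, '.' marks transparent ones.
    print("".join("#" if v else "." for v in row))
```

Printing the grid for a small size is a quick way to sanity-check the mask shape before applying it to a real clip.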
|
||||
|
||||
### 6.2 Ken Burns zoom — `ken_burns`
|
||||
|
||||
```python
|
||||
ken_burns(clip: ImageClip, zoom_ratio: float = 0.08) -> ImageClip
|
||||
```
|
||||
|
||||
Applies a slow continuous zoom-in to a static image clip, creating the illusion of camera movement. This prevents the background from looking visually "dead" during the scene.
|
||||
|
||||
The implementation uses `clip.fl()`, MoviePy's frame-level transform function, which receives both `get_frame` (a callable that returns the frame array at time `t`) and the current time `t`:
|
||||
|
||||
```python
def zoom_frame(get_frame, t):
    frame = get_frame(t)
    frac = 1 + zoom_ratio * (t / clip.duration)   # grows from 1.0 to 1+zoom_ratio
    h, w = frame.shape[:2]
    new_h, new_w = int(h / frac), int(w / frac)   # shrink crop window
    y1 = (h - new_h) // 2                         # center the crop
    x1 = (w - new_w) // 2
    cropped = frame[y1:y1 + new_h, x1:x1 + new_w]
    return np.array(Image.fromarray(cropped).resize((w, h), Image.LANCZOS))
```
|
||||
|
||||
At `t=0`, `frac=1.0` so the crop is the full frame. At `t=duration`, `frac=1+zoom_ratio` so the crop is slightly smaller, and upscaling it back to full resolution creates the zoom effect. `zoom_ratio=0.08` means an 8% zoom over the full duration — perceptible but not distracting.
|
||||
|
||||
`apply_to=["mask"]` passes the same transform to the mask channel if one is present, keeping the mask geometrically in sync with the image.
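
The crop arithmetic inside `zoom_frame` can be isolated and inspected without MoviePy or Pillow. The sketch below reproduces only the window calculation; `crop_window` is a hypothetical helper name, and the 1920x1080 frame size is an assumption.

```python
def crop_window(w: int, h: int, t: float, duration: float,
                zoom_ratio: float = 0.08) -> tuple[int, int, int, int]:
    """Return (x1, y1, new_w, new_h) of the centred crop at time t."""
    frac = 1 + zoom_ratio * (t / duration)
    new_h, new_w = int(h / frac), int(w / frac)
    y1 = (h - new_h) // 2
    x1 = (w - new_w) // 2
    return x1, y1, new_w, new_h


for t in (0.0, 5.0, 10.0):
    print(t, crop_window(1920, 1080, t, duration=10.0))
```

At `t=0` the window covers the whole frame; as `t` grows the window shrinks around the centre, which is what reads as a zoom once the crop is resized back to `(w, h)`.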
|
||||
|
||||
---
|
||||
|
||||
## 7. Scene builders
|
||||
|
||||
Scene builders assemble the layers for a given `visual_cue` type into a `CompositeVideoClip`. Each builder follows the same pattern: build layers bottom-to-top, append to a list, return `CompositeVideoClip(layers, size=bg.size).set_duration(d)`.
|
||||
|
||||
The explicit `.set_duration(d)` on the return value is mandatory — see [§9.2](#92-the-set_duration-gotcha) for why.
|
||||
|
||||
### 7.1 Data scene — `build_data_scene`
|
||||
|
||||
Used for `visual_cue` values `bar_chart_comparison` and `line_trend`. The most information-dense layout:
|
||||
|
||||
- **Background**: full-canvas `ImageClip`, Ken Burns zoom at 8%, brightness reduced by 40 units via `vfx.lum_contrast(0, -40)`.
|
||||
- **Chart**: resized to 700px wide, centred horizontally, positioned 180px from the top. Fades in over 0.6s starting at `t=0.5` and fades out over 0.4s at the end.
|
||||
- **Insight card**: centred horizontally at y=820 (roughly the lower quarter of a 1080p frame). Fades in over 0.5s.
|
||||
- **Avatar**: circular-masked at 240px diameter, positioned 40px from the bottom-right corner (`bg.w - 280, bg.h - 280`).
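
The avatar offset in the list above is simply the diameter plus the margin. A minimal sketch of that calculation, assuming a 1920x1080 canvas and a hypothetical helper name `avatar_top_left`:

```python
def avatar_top_left(canvas_w: int, canvas_h: int,
                    diameter: int, margin: int = 40) -> tuple[int, int]:
    """Top-left corner that leaves `margin` px to the right and bottom edges."""
    offset = diameter + margin
    return canvas_w - offset, canvas_h - offset


print(avatar_top_left(1920, 1080, diameter=240))  # data scene avatar
print(avatar_top_left(1920, 1080, diameter=200))  # bullet scene avatar (see §7.2)
```

Deriving the position from the diameter keeps the 40px margin intact if the avatar size changes.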
|
||||
|
||||
### 7.2 Bullet scene — `build_bullet_scene`
|
||||
|
||||
Used for `visual_cue` value `bullet_points`. A simpler layout suited to lists of supporting facts:
|
||||
|
||||
- **Background**: Ken Burns at 5% zoom (slower than the data scene — more contemplative pacing), brightness reduced by 50 units.
|
||||
- **Bullet overlay**: rendered by `make_bullet_overlay`, centred both horizontally and vertically, fades in over 0.7s.
|
||||
- **Avatar**: circular-masked at 200px diameter (slightly smaller than in the data scene), positioned 40px from the bottom-right corner.
|
||||
|
||||
If `bullet_lines` is not provided by the caller, the builder falls back to using `insight.key_insight` and `insight.supporting_stat` as two bullet points.
|
||||
|
||||
### 7.3 Full avatar scene — `build_full_avatar_scene`
|
||||
|
||||
Used for `visual_cue` value `full_avatar`. The "Hook" scene — designed to open a piece with a direct-to-camera delivery that grabs attention before the data arrives. No overlays; the avatar fills the entire frame:
|
||||
|
||||
```python
|
||||
avatar = VideoFileClip(assets.avatar_video).subclip(0, d)
|
||||
return avatar.resize(height=1080).set_duration(d)
|
||||
```
|
||||
|
||||
This is the only scene type that does not use a `CompositeVideoClip` — it returns a `VideoFileClip` directly. The explicit `.set_duration(d)` is still applied (see §9.2).
|
||||
|
||||
---
|
||||
|
||||
## 8. Scene dispatcher — `dispatch_scene`
|
||||
|
||||
```python
|
||||
dispatch_scene(insight: Insight, assets: SceneAssets,
|
||||
bullet_lines: Optional[list[str]] = None) -> CompositeVideoClip
|
||||
```
|
||||
|
||||
The dispatcher is the JSON bridge's execution layer. It reads `insight.visual_cue` and routes to the correct builder, generating any intermediate assets (charts) along the way:
|
||||
|
||||
```
visual_cue value           Action
─────────────────────────────────────────────────────
"full_avatar"            → build_full_avatar_scene()
"bar_chart_comparison"   → make_bar_chart() → build_data_scene()
"line_trend"             → make_line_trend() → build_data_scene()
"bullet_points"          → build_bullet_scene()
<anything else>          → build_data_scene() with no chart (fallback)
```
|
||||
|
||||
Chart PNGs are written to `/tmp/chart.png`. This is intentionally a fixed path — each call overwrites the previous chart, which is fine because `dispatch_scene` is called sequentially per scene. If scenes are ever parallelised, use a `job_id`-prefixed temp path instead.
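
The routing table and the `job_id`-prefixed path suggestion can be sketched together. This is not the actual `dispatch_scene` body: builders are represented by their names as strings so the example runs without MoviePy, and `route`, `chart_path_for`, and the filename pattern are assumptions for illustration.

```python
import os
import tempfile
import uuid

# visual_cue -> (chart generator name or None, scene builder name).
ROUTES = {
    "full_avatar": (None, "build_full_avatar_scene"),
    "bar_chart_comparison": ("make_bar_chart", "build_data_scene"),
    "line_trend": ("make_line_trend", "build_data_scene"),
    "bullet_points": (None, "build_bullet_scene"),
}
FALLBACK = (None, "build_data_scene")


def route(visual_cue: str) -> tuple:
    """Look up the route, falling back to a chart-less data scene."""
    return ROUTES.get(visual_cue, FALLBACK)


def chart_path_for(job_id: str, scene_index: int) -> str:
    """Hypothetical per-job chart path, so parallel scenes never collide."""
    name = f"{job_id}_scene{scene_index}_chart.png"
    return os.path.join(tempfile.gettempdir(), name)


job_id = uuid.uuid4().hex
print(route("line_trend"), route("something_new"))
print(chart_path_for(job_id, 0))
```

Using `tempfile.gettempdir()` instead of a hard-coded `/tmp` also keeps the path valid on platforms where the temp directory lives elsewhere.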
|
||||
|
||||
---
|
||||
|
||||
## 9. Crossfade transitions
|
||||
|
||||
### 9.1 How `crossfade_concat` works
|
||||
|
||||
```python
def crossfade_concat(scenes: list, fade_dur: float = 0.5) -> CompositeVideoClip:
    faded = []
    for i, clip in enumerate(scenes):
        c = clip
        if i > 0:
            c = c.fx(vfx.crossfadein, fade_dur)
        faded.append(c)
    return concatenate_videoclips(faded, padding=-fade_dur, method="compose")
```
|
||||
|
||||
`vfx.crossfadein` makes a clip's opacity ramp from 0 to 1 over `fade_dur` seconds from its start point. This handles the incoming side of the dissolve.
|
||||
|
||||
`padding=-fade_dur` is the critical parameter. By default, `concatenate_videoclips` places each clip immediately after the previous one ends. A negative padding shifts each clip left by `fade_dur` seconds, so it starts while the previous clip is still playing. The overlap window is exactly `fade_dur` seconds, which matches the duration of the `crossfadein` effect — this is what produces a dissolve rather than a hard cut or a gap.
|
||||
|
||||
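`method="compose"` tells MoviePy to build a `CompositeVideoClip` in which each scene is placed at its own start time. Overlapping scenes are therefore layered on top of each other, and the alpha ramp from `crossfadein` is respected. With the default `method="chain"`, only one clip is shown at any instant, so the overlap would render as a hard cut.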
|
||||
|
||||
The default `fade_dur` of `0.5s` is appropriate for fast-paced content. Increase to `0.8–1.0s` for a more cinematic feel. The total output duration is `sum(scene.duration for scene in scenes) - (len(scenes) - 1) * fade_dur`.
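The duration formula can be checked with a few lines of plain Python. This is a minimal sketch; `crossfaded_duration` is an illustrative name and does not exist in `broll_composer.py`.

```python
def crossfaded_duration(scene_durations: list[float], fade_dur: float = 0.5) -> float:
    """Total timeline length when consecutive scenes overlap by fade_dur."""
    if not scene_durations:
        return 0.0
    overlaps = len(scene_durations) - 1  # one overlap per scene boundary
    return sum(scene_durations) - overlaps * fade_dur


# Three scenes of 10 s, 8 s and 6 s with the default 0.5 s fade:
# 24.0 - 2 * 0.5 = 23.0
print(crossfaded_duration([10.0, 8.0, 6.0]))  # 23.0
```

This is useful when a voiceover has to be timed against the final video, since the output is shorter than the sum of the scene durations.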
|
||||
|
||||
### 9.2 The `set_duration` gotcha
|
||||
|
||||
`CompositeVideoClip` infers its total duration by scanning the durations of all constituent clips. When sub-clips have `set_start` offsets — such as the chart clip which starts at `t=0.5` and has a duration of `d - 1.5`, and the insight card which starts at `t=0.5` with a duration of `d - 1.0` — MoviePy computes the composite's duration as `max(clip.start + clip.duration for clip in layers)`.
|
||||
|
||||
In most cases this yields a value slightly larger than `d` due to floating-point arithmetic on the offset calculations, or occasionally slightly smaller if a sub-clip ends fractionally before the background. Either error causes `crossfade_concat`'s `padding=-fade_dur` overlap to be miscalculated, typically producing a black flash frame at each scene boundary.
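The inference rule itself can be illustrated without MoviePy. The sketch below only mimics the "latest layer end" rule with plain numbers; `inferred_duration` is an illustrative name, and the `(start, duration)` pairs mirror the offsets described above for a scene with `d = 10.0`.

```python
def inferred_duration(layers: list[tuple[float, float]]) -> float:
    """Mimic the rule: a composite ends when its last layer ends."""
    return max(start + duration for start, duration in layers)


d = 10.0
layers = [
    (0.0, d),        # background: spans the whole scene
    (0.5, d - 1.5),  # chart: ends at d - 1.0
    (0.5, d - 1.0),  # insight card: ends at d - 0.5
]

print(inferred_duration(layers))      # 10.0
print(inferred_duration(layers[1:]))  # 9.5 (no full-length layer present)
```

The second call shows why relying on inference is fragile: as soon as no layer spans the full scene, the inferred duration no longer equals the duration requested in the `Insight`.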
|
||||
|
||||
The fix is to explicitly call `.set_duration(d)` on every scene builder's return value, overriding the inferred value with the authoritative duration from the `Insight`:
|
||||
|
||||
```python
|
||||
return CompositeVideoClip(layers, size=bg.size).set_duration(d)
|
||||
```
|
||||
|
||||
This must be applied to all three builders, including `build_full_avatar_scene`, because a `resize()` call on a `VideoFileClip` creates a new clip object whose duration is re-derived from the source — it does not inherit the `subclip(0, d)` duration reliably on all platforms.
|
||||
|
||||
---
|
||||
|
||||
## 10. Master compositor — `compose_video`
|
||||
|
||||
```python
def compose_video(scenes: list, output_path: str = "output.mp4",
                  fps: int = 24, fade_dur: float = 0.5) -> str
```
|
||||
|
||||
The final assembly step. Calls `crossfade_concat` to produce the dissolved timeline, then writes to an H.264 MP4 via MoviePy's `write_videofile`:
|
||||
|
||||
```python
final.write_videofile(
    output_path,
    fps=fps,
    codec="libx264",
    audio_codec="aac",
    threads=4,
    preset="fast",
    logger=None,
)
```
|
||||
|
||||
`preset="fast"` is a reasonable default for a production pipeline — it is significantly faster than `slow` or `medium` with only a marginal quality difference at typical web streaming bitrates. Change to `slow` for archive-quality output. `logger=None` suppresses the verbose ffmpeg progress output; remove it during debugging.
|
||||
|
||||
`threads=4` maps to ffmpeg's `-threads` flag. Increase if the host has more cores available. This affects the encoding step only — MoviePy's frame rendering is single-threaded.
|
||||
|
||||
---
|
||||
|
||||
## 11. FastAPI server
|
||||
|
||||
`api_server.py` wraps the composition pipeline behind an HTTP API, enabling it to be called from any frontend, automation script, or orchestration system.
|
||||
|
||||
### 11.1 Request models
|
||||
|
||||
**`InsightPayload`** — mirrors the `Insight` dataclass with Pydantic validation:
|
||||
|
||||
| Field | Type | Constraints | Description |
|-------|------|-------------|-------------|
| `key_insight` | str | required | Headline text |
| `supporting_stat` | str | required | Sub-headline text |
| `visual_cue` | str | required | Scene template selector |
| `audio_tone` | str | required | Downstream audio metadata |
| `duration` | float | 3.0–60.0, default 10.0 | Scene length in seconds |
| `chart_data` | dict | optional | Data payload for chart generators |
| `bullet_lines` | list[str] | optional | Explicit bullet text (overrides defaults) |
|
||||
|
||||
**`ComposeRequest`** — the top-level request body:
|
||||
|
||||
| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `insights` | list[InsightPayload] | required | Ordered list of scenes |
| `fps` | int | 24 | Output frame rate (12–60) |
| `fade_dur` | float | 0.5 | Crossfade duration in seconds (0.0–2.0) |
|
||||
|
||||
**`JobStatus`** — the response model for job tracking:
|
||||
|
||||
| Field | Values | Description |
|-------|--------|-------------|
| `job_id` | UUID hex string | Unique identifier for polling |
| `status` | `queued`, `processing`, `done`, `error` | Current state |
| `output_url` | `/download/{job_id}` or null | Available when `status == "done"` |
| `error` | string or null | Error message when `status == "error"` |
|
||||
|
||||
### 11.2 Job lifecycle
|
||||
|
||||
Video composition is CPU-intensive and typically takes 30–120 seconds for a multi-scene piece. The API uses FastAPI's `BackgroundTasks` to run composition asynchronously so the HTTP response is immediate:
|
||||
|
||||
```
POST /compose
  │
  ├─ Validates payload, saves uploaded files to /tmp/broll_jobs/{job_id}/
  ├─ Creates JobStatus(status="queued")
  ├─ Registers BackgroundTask → _compose_worker()
  └─ Returns 202 Accepted with job_id

_compose_worker()  (background)
  │
  ├─ Sets status = "processing"
  ├─ Runs _sync_compose() in a thread pool (loop.run_in_executor)
  │    └─ Iterates insights → dispatch_scene() → compose_video()
  ├─ On success: status = "done", output_url = "/download/{job_id}"
  └─ On error:   status = "error", error = str(exc)

GET /status/{job_id}    ← poll until status == "done" or "error"

GET /download/{job_id}  ← returns MP4 file
```
|
||||
|
||||
`loop.run_in_executor(None, _sync_compose)` is important: MoviePy's frame rendering and ffmpeg's encoding are blocking operations. Running them directly in an `async` function would block the entire event loop. `run_in_executor` offloads the work to a thread pool, keeping the server responsive to other requests during composition.
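The effect of `run_in_executor` can be seen in a small standalone sketch. Nothing here comes from `api_server.py`: `blocking_render` is a stand-in for the MoviePy/ffmpeg work, and `heartbeat` stands in for other requests the server must keep serving.

```python
import asyncio
import time


def blocking_render(seconds: float) -> str:
    """Stand-in for MoviePy/ffmpeg work: blocks the calling thread."""
    time.sleep(seconds)
    return "rendered"


async def heartbeat(ticks: list, interval: float, count: int) -> None:
    """Record a tick per interval; this only progresses while the loop is free."""
    for _ in range(count):
        await asyncio.sleep(interval)
        ticks.append(time.monotonic())


async def main() -> tuple:
    loop = asyncio.get_running_loop()
    ticks: list = []
    beat = asyncio.create_task(heartbeat(ticks, 0.05, 3))
    # Offload the blocking call to the default thread pool so the
    # heartbeat task keeps running while the "render" is in progress.
    result = await loop.run_in_executor(None, blocking_render, 0.3)
    await beat
    return result, len(ticks)


if __name__ == "__main__":
    print(asyncio.run(main()))
```

If `blocking_render(0.3)` were called directly inside `main`, the heartbeat could not tick until the call returned, which is the stall the worker avoids.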
|
||||
|
||||
The job store is currently a plain Python dict (`_jobs`). This is appropriate for a single-worker development server. For multi-worker or multi-instance deployments, replace it with Redis via `redis-py`, whose `redis.asyncio` module supersedes the deprecated standalone `aioredis` package.
|
||||
|
||||
### 11.3 API endpoints
|
||||
|
||||
| Method | Path | Description |
|--------|------|-------------|
| `POST` | `/compose` | Start a composition job (multipart form) |
| `GET` | `/status/{job_id}` | Poll job status |
| `GET` | `/download/{job_id}` | Download finished MP4 |
| `POST` | `/preview/chart` | Generate and return a chart PNG (no video) |
| `GET` | `/health` | Liveness check |
|
||||
|
||||
Interactive documentation is available at `http://localhost:8000/docs` once the server is running (FastAPI's built-in Swagger UI).
|
||||
|
||||
---
|
||||
|
||||
## 12. Running the project
|
||||
|
||||
### 12.1 Smoke test (no media files needed)
|
||||
|
||||
The smoke test validates all asset generators — chart PNGs, bullet overlays, and insight cards — without requiring any background images or avatar videos:
|
||||
|
||||
```bash
|
||||
python broll_composer.py
|
||||
```
|
||||
|
||||
Expected output:
|
||||
|
||||
```
|
||||
Chart saved → /tmp/demo_chart.png
|
||||
Bullets saved → /tmp/demo_bullets.png
|
||||
Insight card saved → /tmp/demo_card.png
|
||||
|
||||
Sample Insight JSON: { ... }
|
||||
|
||||
All asset generation tests passed.
|
||||
To run full video composition, supply real background_img and avatar_video paths.
|
||||
```
|
||||
|
||||
Inspect the PNG files in `/tmp/` to visually verify chart rendering before running the full pipeline.
|
||||
|
||||
### 12.2 Full video composition
|
||||
|
||||
```python
from broll_composer import pipeline_from_json

insight_json = """{
    "key_insight": "AI reduced production time by 40%",
    "supporting_stat": "HubSpot 2026: 12% CTR lift",
    "visual_cue": "bar_chart_comparison",
    "audio_tone": "authoritative_and_surprising",
    "duration": 10.0,
    "chart_data": {
        "labels": ["Content Velocity", "CTR", "Engagement", "Cost/Lead"],
        "before": [30, 22, 18, 60],
        "after": [72, 34, 41, 38]
    }
}"""

output_path = pipeline_from_json(
    insight_json,
    background_img="path/to/background.jpg",
    avatar_video="path/to/avatar.mp4",  # optional
)
print(f"Video written to {output_path}")
```
|
||||
|
||||
### 12.3 API server
|
||||
|
||||
```bash
|
||||
uvicorn api_server:app --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
For development with auto-reload:
|
||||
|
||||
```bash
|
||||
uvicorn api_server:app --reload
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 13. Calling the API
|
||||
|
||||
The `/compose` endpoint accepts `multipart/form-data` with three parts: `payload` (JSON string), `background` (image file), and optionally `avatar` (video file).
|
||||
|
||||
```bash
curl -X POST http://localhost:8000/compose \
  -F 'payload={
    "insights": [{
      "key_insight": "AI reduced production time by 40%",
      "supporting_stat": "HubSpot 2026: 12% CTR lift",
      "visual_cue": "bar_chart_comparison",
      "audio_tone": "authoritative_and_surprising",
      "duration": 10.0,
      "chart_data": {
        "labels": ["Velocity","CTR","Engagement","Cost/Lead"],
        "before": [30, 22, 18, 60],
        "after": [72, 34, 41, 38]
      }
    }],
    "fps": 24,
    "fade_dur": 0.5
  }' \
  -F 'background=@./bg.jpg' \
  -F 'avatar=@./avatar.mp4'
```
|
||||
|
||||
This returns a `JobStatus` with a `job_id`. Poll for completion:
|
||||
|
||||
```bash
|
||||
curl http://localhost:8000/status/{job_id}
|
||||
# → {"job_id": "...", "status": "done", "output_url": "/download/..."}
|
||||
```
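The same polling loop can be written in Python with only the standard library. This is a sketch under assumptions: `wait_for_job` and `fetch_status` are illustrative names, the poll interval and timeout are arbitrary defaults, and the `fetch` parameter exists only so the loop can be exercised without a running server.

```python
import json
import time
import urllib.request
from typing import Callable, Optional


def fetch_status(base_url: str, job_id: str) -> dict:
    """GET /status/{job_id} and decode the JSON body."""
    with urllib.request.urlopen(f"{base_url}/status/{job_id}") as resp:
        return json.load(resp)


def wait_for_job(job_id: str, base_url: str = "http://localhost:8000",
                 interval: float = 2.0, timeout: float = 300.0,
                 fetch: Optional[Callable[[str, str], dict]] = None) -> dict:
    """Poll until the job reaches a terminal state ("done" or "error")."""
    fetch = fetch or fetch_status
    deadline = time.monotonic() + timeout
    while True:
        status = fetch(base_url, job_id)
        if status["status"] in ("done", "error"):
            return status
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"job {job_id} still {status['status']} after {timeout}s")
        time.sleep(interval)
```

A caller would check the returned `status` field and, on `"done"`, request the path given in `output_url`.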
|
||||
|
||||
Download the finished video:
|
||||
|
||||
```bash
|
||||
curl -O http://localhost:8000/download/{job_id}
|
||||
```
|
||||
|
||||
Preview a chart without video assembly:
|
||||
|
||||
```bash
curl -X POST "http://localhost:8000/preview/chart?title=My+Chart&chart_type=bar_chart_comparison" \
  -H "Content-Type: application/json" \
  -d '{"labels":["A","B"],"before":[30,22],"after":[72,34]}' \
  --output preview.png
```
|
||||
|
||||
---
|
||||
|
||||
## 14. Production notes
|
||||
|
||||
**Concurrency**: FastAPI's `BackgroundTasks` run in the same process as the web server. Under concurrent load, multiple composition jobs share the same thread pool, which can cause memory pressure because each MoviePy render buffers several seconds of uncompressed video. For production, move composition to a dedicated worker queue (Celery + Redis, or ARQ) and have the API server dispatch jobs to it rather than running them in-process.
|
||||
|
||||
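**Temp file isolation**: Chart PNGs and insight card PNGs are written to fixed paths under `/tmp/` (`/tmp/chart.png` and `/tmp/insight_card.png`). This is safe for sequential processing but causes race conditions once jobs run in parallel, which can already happen when two `/compose` requests overlap, because each job runs in the thread-pool executor. Prefix all temp file paths with the `job_id` to isolate them: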
|
||||
|
||||
```python
|
||||
chart_path = f"/tmp/{job_id}_chart.png"
|
||||
```
|
||||
|
||||
**Memory**: MoviePy loads entire video clips into memory for compositing. For scenes longer than ~30 seconds with a high-resolution avatar, memory use can reach several GB. If this is a concern, render scenes individually and use ffmpeg's `concat` demuxer to join them in a second pass rather than compositing them all in Python.
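A second-pass join with ffmpeg's `concat` demuxer could look like the sketch below. The helper names are hypothetical, the per-scene MP4 files are assumed to have been rendered already with identical codec settings, and the command is only built, not executed. Note that a stream-copy concat produces hard cuts, so crossfades would need to be handled separately.

```python
from pathlib import Path


def write_concat_list(scene_files: list[str], list_path: str) -> str:
    """Write the file list that ffmpeg's concat demuxer reads."""
    lines = []
    for f in scene_files:
        # Single quotes inside a path are escaped as '\'' in the list file.
        escaped = str(Path(f).resolve()).replace("'", "'\\''")
        lines.append(f"file '{escaped}'")
    Path(list_path).write_text("\n".join(lines) + "\n", encoding="utf-8")
    return list_path


def concat_command(list_path: str, output_path: str) -> list[str]:
    """Build (but do not run) the ffmpeg command; -c copy avoids re-encoding."""
    return ["ffmpeg", "-y", "-f", "concat", "-safe", "0",
            "-i", list_path, "-c", "copy", output_path]
```

The returned list can be passed to `subprocess.run(..., check=True)` once the scene files exist.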
|
||||
|
||||
**ffmpeg version**: MoviePy 1.0.3 delegates encoding to ffmpeg. ffmpeg versions prior to 4.x may not support all `preset` values or the `aac` codec without additional flags. The pipeline is tested against ffmpeg 5.x and 6.x.
|
||||
|
||||
**File cleanup**: Completed job files accumulate in `/tmp/broll_jobs/`. Add a cleanup background task or cron job that deletes job directories older than a configurable TTL (e.g. 1 hour).
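A TTL-based cleanup pass might look like the following sketch. `cleanup_old_jobs` is a hypothetical helper; it assumes that everything directly under the work directory (job directories, finished MP4s, preview PNGs) is disposable once its modification time is older than the TTL.

```python
import shutil
import time
from pathlib import Path
from typing import Optional


def cleanup_old_jobs(work_dir: str, ttl_seconds: float = 3600.0,
                     now: Optional[float] = None) -> list[str]:
    """Delete entries in work_dir whose mtime is older than ttl_seconds."""
    now = time.time() if now is None else now
    removed: list[str] = []
    root = Path(work_dir)
    if not root.is_dir():
        return removed
    for entry in root.iterdir():
        if now - entry.stat().st_mtime <= ttl_seconds:
            continue  # still fresh, keep it
        if entry.is_dir():
            shutil.rmtree(entry, ignore_errors=True)
        else:
            entry.unlink(missing_ok=True)
        removed.append(entry.name)
    return removed
```

The TTL should comfortably exceed the longest expected composition time, otherwise the uploads of a job that is still rendering could be deleted.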
|
||||
|
||||
---
|
||||
|
||||
## 15. Extending the pipeline
|
||||
|
||||
**Adding a new scene template**: add a builder function following the `build_*_scene` naming convention, then add a `visual_cue` string → function mapping in `dispatch_scene`. No other changes are needed.
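One way to keep that mapping in a single place is a small registry, sketched below. This is an assumption about structure, not a description of how `dispatch_scene` is actually written: the builders are stand-ins that return a label instead of a MoviePy clip, and `quote_card` is a hypothetical new template.

```python
from typing import Callable, Dict

# Maps a visual_cue string to a builder function.
SCENE_BUILDERS: Dict[str, Callable[[dict], str]] = {}


def register_scene(visual_cue: str):
    """Decorator that records a builder under its visual_cue key."""
    def decorator(fn):
        SCENE_BUILDERS[visual_cue] = fn
        return fn
    return decorator


@register_scene("full_avatar")
def build_full_avatar_scene(insight: dict) -> str:
    return "full_avatar scene"


@register_scene("quote_card")  # hypothetical new template
def build_quote_scene(insight: dict) -> str:
    return "quote_card scene"


def build_fallback_scene(insight: dict) -> str:
    """Used for any visual_cue without a registered builder."""
    return "data scene without chart"


def dispatch(insight: dict) -> str:
    builder = SCENE_BUILDERS.get(insight.get("visual_cue"), build_fallback_scene)
    return builder(insight)
```

With this layout, adding a template is one decorated function, and unknown cues still fall through to the chart-less data scene.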
|
||||
|
||||
**Adding a new chart type**: add a `make_*` function that writes a transparent PNG, then handle the new `visual_cue` in `dispatch_scene` by calling it before passing `assets` to a builder.
|
||||
|
||||
**Supporting multiple backgrounds per script**: `SceneAssets` currently takes a single `background_img`. To vary the background per scene, add a `background_img` field to `InsightPayload` in the API model and pass it through to `SceneAssets` in the compose worker.
|
||||
|
||||
**Audio**: the pipeline produces silent video. Attach a voiceover by loading it as a MoviePy `AudioFileClip`, setting its start time to align with each scene, and passing the composite audio to `final.set_audio()` before calling `write_videofile`.
|
||||
229
backend/api/podcast/broll_temp/api_server.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""
|
||||
FastAPI wrapper for the B-Roll Composer pipeline.
|
||||
POST /compose → triggers scene assembly, returns video download URL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
|
||||
from fastapi import FastAPI, UploadFile, File, Form, BackgroundTasks, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from broll_composer import (
|
||||
Insight, SceneAssets, dispatch_scene, compose_video,
|
||||
make_bar_chart, make_line_trend, make_bullet_overlay,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# App setup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
app = FastAPI(
|
||||
title="B-Roll Composer API",
|
||||
description="Programmatic video composition: Background + Chart + Avatar Circle",
|
||||
version="1.0.0",
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
WORK_DIR = Path("/tmp/broll_jobs")
|
||||
WORK_DIR.mkdir(exist_ok=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Request / Response models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class InsightPayload(BaseModel):
|
||||
key_insight: str = Field(..., example="AI tools reduced content cycles by 40% in 2025.")
|
||||
supporting_stat: str = Field(..., example="HubSpot 2026 report cites a 12% lift in CTR.")
|
||||
visual_cue: str = Field(
|
||||
...,
|
||||
example="bar_chart_comparison",
|
||||
description="bar_chart_comparison | line_trend | bullet_points | full_avatar",
|
||||
)
|
||||
audio_tone: str = Field(..., example="authoritative_and_surprising")
|
||||
duration: float = Field(default=10.0, ge=3.0, le=60.0)
|
||||
chart_data: dict = Field(default_factory=dict)
|
||||
bullet_lines: Optional[List[str]] = None
|
||||
|
||||
|
||||
class ComposeRequest(BaseModel):
|
||||
insights: List[InsightPayload]
|
||||
fps: int = Field(default=24, ge=12, le=60)
|
||||
fade_dur: float = Field(default=0.5, ge=0.0, le=2.0,
|
||||
description="Crossfade duration in seconds between scenes")
|
||||
|
||||
|
||||
class JobStatus(BaseModel):
|
||||
job_id: str
|
||||
status: str # queued | processing | done | error
|
||||
output_url: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# In-memory job store (replace with Redis in production)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_jobs: dict[str, JobStatus] = {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Background task: composition worker
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _compose_worker(
|
||||
job_id: str,
|
||||
request: ComposeRequest,
|
||||
bg_path: str,
|
||||
avatar_path: Optional[str],
|
||||
):
|
||||
job = _jobs[job_id]
|
||||
job.status = "processing"
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
out_path = str(WORK_DIR / f"{job_id}.mp4")
|
||||
|
||||
def _sync_compose():
|
||||
scenes = []
|
||||
for i, payload in enumerate(request.insights):
|
||||
insight = Insight(
|
||||
key_insight=payload.key_insight,
|
||||
supporting_stat=payload.supporting_stat,
|
||||
visual_cue=payload.visual_cue,
|
||||
audio_tone=payload.audio_tone,
|
||||
chart_data=payload.chart_data,
|
||||
duration=payload.duration,
|
||||
)
|
||||
assets = SceneAssets(
|
||||
background_img=bg_path,
|
||||
avatar_video=avatar_path,
|
||||
)
|
||||
scene = dispatch_scene(insight, assets, payload.bullet_lines)
|
||||
scenes.append(scene)
|
||||
|
||||
compose_video(scenes, output_path=out_path, fps=request.fps,
|
||||
fade_dur=request.fade_dur)
|
||||
return out_path
|
||||
|
||||
await loop.run_in_executor(None, _sync_compose)
|
||||
job.status = "done"
|
||||
job.output_url = f"/download/{job_id}"
|
||||
|
||||
except Exception as exc:
|
||||
job.status = "error"
|
||||
job.error = str(exc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.post("/compose", response_model=JobStatus, status_code=202)
|
||||
async def start_compose(
|
||||
background_tasks: BackgroundTasks,
|
||||
payload: str = Form(..., description="JSON string matching ComposeRequest schema"),
|
||||
background: UploadFile = File(..., description="Background image (JPEG/PNG)"),
|
||||
avatar: Optional[UploadFile] = File(None, description="Avatar video (MP4) — optional"),
|
||||
):
|
||||
"""
|
||||
Kick off a video composition job.
|
||||
- **payload**: JSON body (ComposeRequest)
|
||||
- **background**: background image file
|
||||
- **avatar**: optional avatar video file
|
||||
Returns a job_id — poll GET /status/{job_id} for progress.
|
||||
"""
|
||||
try:
|
||||
request = ComposeRequest(**json.loads(payload))
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=422, detail=f"Invalid payload: {e}")
|
||||
|
||||
job_id = uuid.uuid4().hex
|
||||
|
||||
    # Save uploads. Only the final component of the client-supplied filename
    # is used, so a name containing directories (e.g. "../../x") is reduced
    # to its last path component before being joined to job_dir.
    job_dir = WORK_DIR / job_id
    job_dir.mkdir(exist_ok=True)

    bg_path = str(job_dir / Path(background.filename or "background").name)
    with open(bg_path, "wb") as f:
        f.write(await background.read())

    avatar_path = None
    if avatar:
        avatar_path = str(job_dir / Path(avatar.filename or "avatar.mp4").name)
        with open(avatar_path, "wb") as f:
            f.write(await avatar.read())
|
||||
|
||||
# Register job
|
||||
job = JobStatus(job_id=job_id, status="queued")
|
||||
_jobs[job_id] = job
|
||||
|
||||
# Launch background worker
|
||||
background_tasks.add_task(
|
||||
_compose_worker, job_id, request, bg_path, avatar_path
|
||||
)
|
||||
|
||||
return job
|
||||
|
||||
|
||||
@app.get("/status/{job_id}", response_model=JobStatus)
|
||||
async def get_status(job_id: str):
|
||||
"""Poll composition job status."""
|
||||
job = _jobs.get(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
return job
|
||||
|
||||
|
||||
@app.get("/download/{job_id}")
|
||||
async def download_video(job_id: str):
|
||||
"""Download the finished video."""
|
||||
job = _jobs.get(job_id)
|
||||
if not job:
|
||||
raise HTTPException(status_code=404, detail="Job not found")
|
||||
if job.status != "done":
|
||||
raise HTTPException(status_code=409, detail=f"Job status: {job.status}")
|
||||
|
||||
out_path = WORK_DIR / f"{job_id}.mp4"
|
||||
if not out_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Output file missing")
|
||||
|
||||
return FileResponse(
|
||||
path=str(out_path),
|
||||
media_type="video/mp4",
|
||||
filename=f"broll_{job_id}.mp4",
|
||||
)
|
||||
|
||||
|
||||
@app.post("/preview/chart")
|
||||
async def preview_chart(
|
||||
chart_data: dict,
|
||||
title: str = "",
|
||||
chart_type: str = "bar_chart_comparison",
|
||||
):
|
||||
"""Generate and return a chart PNG for preview (no video assembly)."""
|
||||
out = str(WORK_DIR / f"preview_{uuid.uuid4().hex}.png")
|
||||
if chart_type == "bar_chart_comparison":
|
||||
make_bar_chart(chart_data, out, title)
|
||||
else:
|
||||
make_line_trend(chart_data, out, title)
|
||||
return FileResponse(path=out, media_type="image/png")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
456
backend/api/podcast/broll_temp/broll_composer.py
Normal file
@@ -0,0 +1,456 @@
|
||||
"""
|
||||
Programmatic B-Roll Composer
|
||||
Layered composition pipeline: Background + Chart + Avatar Circle + Text Overlays
|
||||
"""
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
from pathlib import Path
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from moviepy.editor import (
|
||||
VideoFileClip, ImageClip, CompositeVideoClip,
|
||||
TextClip, ColorClip, concatenate_videoclips,
|
||||
)
|
||||
import moviepy.video.fx.all as vfx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Crossfade concat (Option 1: crossfadein + negative padding)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def crossfade_concat(scenes: list, fade_dur: float = 0.5) -> CompositeVideoClip:
|
||||
"""
|
||||
Concatenate scenes with a dissolve transition between each pair.
|
||||
|
||||
Each clip (except the first) gets a crossfadein effect.
|
||||
padding=-fade_dur overlaps consecutive clips so the fade actually fires
|
||||
instead of creating a black gap. set_duration on every scene is
|
||||
mandatory — CompositeVideoClip.duration can be ambiguous without it,
|
||||
which makes the overlap math wrong.
|
||||
"""
|
||||
faded = []
|
||||
for i, clip in enumerate(scenes):
|
||||
c = clip
|
||||
if i > 0:
|
||||
c = c.fx(vfx.crossfadein, fade_dur)
|
||||
faded.append(c)
|
||||
return concatenate_videoclips(faded, padding=-fade_dur, method="compose")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@dataclass
|
||||
class Insight:
|
||||
key_insight: str
|
||||
supporting_stat: str
|
||||
visual_cue: str # bar_chart_comparison | line_trend | bullet_points | full_avatar
|
||||
audio_tone: str
|
||||
chart_data: dict = field(default_factory=dict)
|
||||
duration: float = 10.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneAssets:
|
||||
background_img: str
|
||||
chart_img: Optional[str] = None
|
||||
avatar_video: Optional[str] = None
|
||||
bullet_img: Optional[str] = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Chart generator (Matplotlib → PNG with transparency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CHART_STYLE = {
|
||||
"bg": "#0D0D0D",
|
||||
"bar_before": "#2E4057",
|
||||
"bar_after": "#E63946",
|
||||
"text": "#F1F1EF",
|
||||
"grid": "#2A2A2A",
|
||||
"accent": "#E63946",
|
||||
}
|
||||
|
||||
|
||||
def make_bar_chart(data: dict, out_path: str, title: str = "") -> str:
|
||||
"""Render a side-by-side comparison bar chart. Returns output path."""
|
||||
labels = data.get("labels", [])
|
||||
before = data.get("before", [])
|
||||
after = data.get("after", [])
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
|
||||
x = np.arange(len(labels))
|
||||
w = 0.35
|
||||
bars_b = ax.bar(x - w / 2, before, w, color=CHART_STYLE["bar_before"],
|
||||
label="Before", zorder=3, edgecolor="none")
|
||||
bars_a = ax.bar(x + w / 2, after, w, color=CHART_STYLE["bar_after"],
|
||||
label="After", zorder=3, edgecolor="none")
|
||||
|
||||
ax.set_xticks(x)
|
||||
ax.set_xticklabels(labels, color=CHART_STYLE["text"], fontsize=11)
|
||||
ax.tick_params(axis="y", colors=CHART_STYLE["text"])
|
||||
ax.spines[:].set_visible(False)
|
||||
ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
||||
ax.set_axisbelow(True)
|
||||
|
||||
# Value labels on bars
|
||||
for bar in [*bars_b, *bars_a]:
|
||||
h = bar.get_height()
|
||||
ax.text(bar.get_x() + bar.get_width() / 2, h + 0.5, f"{h:.0f}%",
|
||||
ha="center", va="bottom", color=CHART_STYLE["text"], fontsize=9,
|
||||
fontweight="bold")
|
||||
|
||||
legend = ax.legend(frameon=False, labelcolor=CHART_STYLE["text"],
|
||||
fontsize=10, loc="upper left")
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
|
||||
fig.tight_layout(pad=0.5)
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
def make_line_trend(data: dict, out_path: str, title: str = "") -> str:
|
||||
"""Render a trend line chart. Returns output path."""
|
||||
x_vals = data.get("x", [])
|
||||
y_vals = data.get("y", [])
|
||||
|
||||
fig, ax = plt.subplots(figsize=(8, 4.5), facecolor="none")
|
||||
ax.set_facecolor("none")
|
||||
ax.plot(x_vals, y_vals, color=CHART_STYLE["accent"],
|
||||
linewidth=2.5, marker="o", markersize=7, zorder=3)
|
||||
ax.fill_between(x_vals, y_vals, alpha=0.12, color=CHART_STYLE["accent"])
|
||||
ax.spines[:].set_visible(False)
|
||||
ax.tick_params(colors=CHART_STYLE["text"])
|
||||
ax.yaxis.grid(True, color=CHART_STYLE["grid"], linewidth=0.6, zorder=0)
|
||||
if title:
|
||||
ax.set_title(title, color=CHART_STYLE["text"], fontsize=13,
|
||||
fontweight="bold", pad=12)
|
||||
fig.tight_layout(pad=0.5)
|
||||
fig.savefig(out_path, dpi=150, transparent=True, bbox_inches="tight")
|
||||
plt.close(fig)
|
||||
return out_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Text / Bullet overlay (Pillow → PNG)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def make_bullet_overlay(lines: list[str], out_path: str,
|
||||
width: int = 900, font_size: int = 32) -> str:
|
||||
"""Render bullet points on a semi-transparent dark pill. Returns path."""
|
||||
padding = 32
|
||||
line_h = font_size + 16
|
||||
img_h = padding * 2 + len(lines) * line_h + 12
|
||||
img = Image.new("RGBA", (width, img_h), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
|
||||
# Semi-transparent background pill
|
||||
draw.rounded_rectangle([0, 0, width - 1, img_h - 1],
|
||||
radius=18, fill=(10, 10, 10, 185))
|
||||
|
||||
try:
|
||||
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf",
|
||||
font_size)
|
||||
except OSError:
|
||||
font = ImageFont.load_default()
|
||||
|
||||
y = padding
|
||||
for line in lines:
|
||||
draw.text((padding + 18, y), f"• {line}", font=font, fill=(241, 241, 239, 255))
|
||||
y += line_h
|
||||
|
||||
img.save(out_path, format="PNG")
|
||||
return out_path
|
||||
|
||||
|
||||
def make_insight_card(insight: str, stat: str, out_path: str,
|
||||
width: int = 960, height: int = 200) -> str:
|
||||
"""Render a bold insight card (headline + supporting stat). Returns path."""
|
||||
img = Image.new("RGBA", (width, height), (0, 0, 0, 0))
|
||||
draw = ImageDraw.Draw(img)
|
||||
draw.rounded_rectangle([0, 0, width - 1, height - 1],
|
||||
radius=14, fill=(10, 10, 10, 200))
|
||||
|
||||
# Red accent bar
|
||||
draw.rectangle([28, 24, 36, height - 24], fill=(230, 57, 70, 255))
|
||||
|
||||
try:
|
||||
font_lg = ImageFont.truetype(
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 34)
|
||||
font_sm = ImageFont.truetype(
|
||||
"/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
|
||||
except OSError:
|
||||
font_lg = font_sm = ImageFont.load_default()
|
||||
|
||||
draw.text((58, 36), insight, font=font_lg, fill=(241, 241, 239, 255))
|
||||
draw.text((58, 90), stat, font=font_sm, fill=(180, 180, 178, 230))
|
||||
|
||||
img.save(out_path, format="PNG")
|
||||
return out_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Circular avatar mask
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def apply_circle_mask(clip: VideoFileClip, diameter: int) -> VideoFileClip:
|
||||
"""Resize clip and apply a circular alpha mask."""
|
||||
clip = clip.resize(height=diameter)
|
||||
w, h = clip.size
|
||||
|
||||
# Build a circular mask array (1 = opaque, 0 = transparent)
|
||||
Y, X = np.ogrid[:h, :w]
|
||||
cx, cy = w / 2, h / 2
|
||||
mask_arr = ((X - cx) ** 2 + (Y - cy) ** 2 <= (min(w, h) / 2) ** 2).astype(float)
|
||||
|
||||
mask_clip = ImageClip(mask_arr, ismask=True).set_duration(clip.duration)
|
||||
return clip.set_mask(mask_clip)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ken Burns zoom effect
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def ken_burns(clip: ImageClip, zoom_ratio: float = 0.08) -> ImageClip:
|
||||
"""Apply a slow zoom-in over the clip duration."""
|
||||
def zoom_frame(get_frame, t):
|
||||
frame = get_frame(t)
|
||||
frac = 1 + zoom_ratio * (t / clip.duration)
|
||||
h, w = frame.shape[:2]
|
||||
new_h, new_w = int(h / frac), int(w / frac)
|
||||
y1 = (h - new_h) // 2
|
||||
x1 = (w - new_w) // 2
|
||||
cropped = frame[y1:y1 + new_h, x1:x1 + new_w]
|
||||
return np.array(Image.fromarray(cropped).resize((w, h), Image.LANCZOS))
|
||||
|
||||
return clip.fl(zoom_frame, apply_to=["mask"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scene builders (one per visual_cue type)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def build_data_scene(assets: SceneAssets, insight: Insight) -> CompositeVideoClip:
|
||||
"""
|
||||
Layout: Background (Ken Burns) + Chart (fade-in) + Avatar circle (corner) + Insight card
|
||||
"""
|
||||
d = insight.duration
|
||||
layers = []
|
||||
|
||||
# 1. Background
|
||||
bg = (ImageClip(assets.background_img)
|
||||
.set_duration(d)
|
||||
.resize(height=1080))
|
||||
bg = ken_burns(bg)
|
||||
bg = bg.fx(vfx.lum_contrast, lum=-40)  # darken by 40 luminance units (the second positional argument is contrast, not luminance)
|
||||
layers.append(bg)
|
||||
|
||||
# 2. Programmatic chart
|
||||
if assets.chart_img:
|
||||
chart = (ImageClip(assets.chart_img)
|
||||
.set_duration(d - 1.5)
|
||||
.set_start(0.5)
|
||||
.resize(width=700)
|
||||
.set_position(("center", 180))
|
||||
.fx(vfx.fadein, 0.6)
|
||||
.fx(vfx.fadeout, 0.4))
|
||||
layers.append(chart)
|
||||
|
||||
# 3. Insight card at bottom
|
||||
card_path = "/tmp/insight_card.png"
|
||||
make_insight_card(insight.key_insight, insight.supporting_stat, card_path)
|
||||
card = (ImageClip(card_path)
|
||||
.set_duration(d - 1)
|
||||
.set_start(0.5)
|
||||
.set_position(("center", 820))
|
||||
.fx(vfx.fadein, 0.5))
|
||||
layers.append(card)
|
||||
|
||||
# 4. Avatar circle (bottom-right corner)
|
||||
if assets.avatar_video:
|
||||
avatar_raw = VideoFileClip(assets.avatar_video).subclip(0, d)
|
||||
avatar = apply_circle_mask(avatar_raw, diameter=240)
|
||||
avatar = avatar.set_position((bg.w - 280, bg.h - 280))
|
||||
layers.append(avatar)
|
||||
|
||||
# set_duration is required: CompositeVideoClip infers duration from its
|
||||
# constituent clips, which can be ambiguous when sub-clips have set_start
|
||||
# offsets. Without this, crossfade_concat's overlap math goes wrong.
|
||||
return CompositeVideoClip(layers, size=bg.size).set_duration(d)
|
||||
|
||||
|
||||
def build_bullet_scene(assets: SceneAssets, insight: Insight,
|
||||
bullets: list[str]) -> CompositeVideoClip:
|
||||
"""
|
||||
Layout: AI image (Ken Burns) + Bullet overlay + Avatar circle
|
||||
"""
|
||||
d = insight.duration
|
||||
layers = []
|
||||
|
||||
bg = (ImageClip(assets.background_img)
|
||||
.set_duration(d)
|
||||
.resize(height=1080))
|
||||
bg = ken_burns(bg, zoom_ratio=0.05)
|
||||
bg = bg.fx(vfx.lum_contrast, lum=-50)  # darken by 50 luminance units
|
||||
layers.append(bg)
|
||||
|
||||
bullet_path = "/tmp/bullets.png"
|
||||
make_bullet_overlay(bullets, bullet_path, width=860)
|
||||
bullets_clip = (ImageClip(bullet_path)
|
||||
.set_duration(d - 1)
|
||||
.set_start(0.5)
|
||||
.set_position(("center", "center"))
|
||||
.fx(vfx.fadein, 0.7))
|
||||
layers.append(bullets_clip)
|
||||
|
||||
if assets.avatar_video:
|
||||
avatar_raw = VideoFileClip(assets.avatar_video).subclip(0, d)
|
||||
avatar = apply_circle_mask(avatar_raw, diameter=200)
|
||||
avatar = avatar.set_position((bg.w - 240, bg.h - 240))
|
||||
layers.append(avatar)
|
||||
|
||||
return CompositeVideoClip(layers, size=bg.size).set_duration(d)
|
||||
|
||||
|
||||
def build_full_avatar_scene(assets: SceneAssets, insight: Insight) -> VideoFileClip:
|
||||
"""Full-screen avatar — the expensive 'Hook' scene. No overlay."""
|
||||
d = insight.duration
|
||||
avatar = VideoFileClip(assets.avatar_video).subclip(0, d)
|
||||
return avatar.resize(height=1080).set_duration(d)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scene dispatcher — maps visual_cue → builder
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def dispatch_scene(insight: Insight, assets: SceneAssets,
|
||||
bullet_lines: Optional[list[str]] = None) -> CompositeVideoClip:
|
||||
cue = insight.visual_cue
|
||||
|
||||
if cue == "full_avatar":
|
||||
return build_full_avatar_scene(assets, insight)
|
||||
|
||||
elif cue in ("bar_chart_comparison", "line_trend"):
|
||||
chart_path = "/tmp/chart.png"
|
||||
if cue == "bar_chart_comparison":
|
||||
make_bar_chart(insight.chart_data, chart_path,
|
||||
title=insight.key_insight)
|
||||
else:
|
||||
make_line_trend(insight.chart_data, chart_path,
|
||||
title=insight.key_insight)
|
||||
assets.chart_img = chart_path
|
||||
return build_data_scene(assets, insight)
|
||||
|
||||
elif cue == "bullet_points":
|
||||
lines = bullet_lines or [insight.key_insight, insight.supporting_stat]
|
||||
return build_bullet_scene(assets, insight, lines)
|
||||
|
||||
else:
|
||||
# Fallback: data scene without chart
|
||||
return build_data_scene(assets, insight)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Master compositor — assembles all scenes into one video
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def compose_video(scenes: list, output_path: str = "output.mp4",
|
||||
fps: int = 24, fade_dur: float = 0.5) -> str:
|
||||
"""Concatenate scenes with crossfade transitions and write final video file."""
|
||||
final = crossfade_concat(scenes, fade_dur=fade_dur)
|
||||
final.write_videofile(
|
||||
output_path,
|
||||
fps=fps,
|
||||
codec="libx264",
|
||||
audio_codec="aac",
|
||||
threads=4,
|
||||
preset="fast",
|
||||
logger=None,
|
||||
)
|
||||
return output_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JSON bridge — LLM insight → assets + scene
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def pipeline_from_json(insight_json: str,
|
||||
background_img: str,
|
||||
avatar_video: Optional[str] = None) -> str:
|
||||
"""
|
||||
Full pipeline:
|
||||
1. Parse LLM insight JSON
|
||||
2. Generate chart / overlay assets
|
||||
3. Build scene
|
||||
4. Write video
|
||||
Returns path to output video.
|
||||
"""
|
||||
data = json.loads(insight_json)
|
||||
insight = Insight(**{k: data[k] for k in Insight.__dataclass_fields__ if k in data})
|
||||
assets = SceneAssets(background_img=background_img, avatar_video=avatar_video)
|
||||
scene = dispatch_scene(insight, assets,
|
||||
bullet_lines=data.get("bullet_lines"))
|
||||
out = f"/tmp/scene_{insight.visual_cue}.mp4"
|
||||
compose_video([scene], output_path=out)
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Demo / smoke-test (no real media files needed for chart generation)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
|
||||
# --- Test 1: Chart PNG generation only ---
|
||||
sample_bar_data = {
|
||||
"labels": ["Content Velocity", "CTR", "Engagement", "Cost/Lead"],
|
||||
"before": [30, 22, 18, 60],
|
||||
"after": [72, 34, 41, 38],
|
||||
}
|
||||
chart_out = make_bar_chart(
|
||||
sample_bar_data,
|
||||
"/tmp/demo_chart.png",
|
||||
title="AI Tools Impact: Before vs After (2025)",
|
||||
)
|
||||
print(f"Chart saved → {chart_out}")
|
||||
|
||||
# --- Test 2: Bullet overlay PNG ---
|
||||
bullets = [
|
||||
"AI reduced content cycles by 40% in 2025",
|
||||
"HubSpot: 12% lift in CTR with AI-assisted copy",
|
||||
"Video production cost down 3x with hybrid pipeline",
|
||||
]
|
||||
bullet_out = make_bullet_overlay(bullets, "/tmp/demo_bullets.png")
|
||||
print(f"Bullets saved → {bullet_out}")
|
||||
|
||||
# --- Test 3: Insight card PNG ---
|
||||
card_out = make_insight_card(
|
||||
"AI tools reduced content cycles by 40%",
|
||||
"HubSpot 2026 report — 12% lift in CTR",
|
||||
"/tmp/demo_card.png",
|
||||
)
|
||||
print(f"Insight card saved → {card_out}")
|
||||
|
||||
# --- Test 4: JSON bridge (chart only, no video files required) ---
|
||||
sample_json = json.dumps({
|
||||
"key_insight": "AI reduced production time by 40%",
|
||||
"supporting_stat": "HubSpot 2026: 12% CTR lift",
|
||||
"visual_cue": "bar_chart_comparison",
|
||||
"audio_tone": "authoritative_and_surprising",
|
||||
"duration": 8.0,
|
||||
"chart_data": sample_bar_data,
|
||||
})
|
||||
print("\nSample Insight JSON:\n", sample_json)
|
||||
print("\nAll asset generation tests passed.")
|
||||
print("To run full video composition, supply real background_img and avatar_video paths.")
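The geometry behind `apply_circle_mask` and `ken_burns` can be checked without any media files. The following is a minimal sketch that uses NumPy only; `circle_mask` and `zoom_crop_box` are hypothetical helper names introduced here for illustration and are not part of the removed module.

```python
import numpy as np


def circle_mask(w: int, h: int) -> np.ndarray:
    """Return an (h, w) float array: 1.0 inside the inscribed circle, 0.0 outside."""
    # ogrid gives a column of row indices and a row of column indices that
    # broadcast to the full (h, w) grid without allocating it twice.
    Y, X = np.ogrid[:h, :w]
    cx, cy = w / 2, h / 2
    radius = min(w, h) / 2
    return ((X - cx) ** 2 + (Y - cy) ** 2 <= radius ** 2).astype(float)


def zoom_crop_box(w: int, h: int, t: float, duration: float,
                  zoom_ratio: float = 0.08) -> tuple:
    """Return (x1, y1, crop_w, crop_h) of the centred crop at time t.

    The crop shrinks linearly in 1/frac, so resizing it back to (w, h)
    produces a slow zoom-in over the clip duration.
    """
    frac = 1 + zoom_ratio * (t / duration)
    crop_h, crop_w = int(h / frac), int(w / frac)
    y1 = (h - crop_h) // 2
    x1 = (w - crop_w) // 2
    return x1, y1, crop_w, crop_h


mask = circle_mask(240, 240)
start_box = zoom_crop_box(1920, 1080, t=0.0, duration=8.0)
end_box = zoom_crop_box(1920, 1080, t=8.0, duration=8.0)
```

At `t=0` the crop covers the whole frame; by the end of the clip it has shrunk toward the centre, which is what makes the resized output appear to zoom in.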
|
||||
@@ -2,26 +2,34 @@
|
||||
Podcast API Constants
|
||||
|
||||
Centralized constants and directory configuration for podcast module.
|
||||
All workspace paths use utils.storage_paths for root resolution.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
from loguru import logger
|
||||
from services.story_writer.audio_generation_service import StoryAudioGenerationService
|
||||
from utils.storage_paths import get_repo_root, sanitize_user_id as _sanitize_user_id
|
||||
|
||||
ROOT_DIR = get_repo_root()
|
||||
# Directory paths
|
||||
# router.py is at: backend/api/podcast/router.py
|
||||
# parents[0] = backend/api/podcast/
|
||||
# parents[1] = backend/api/
|
||||
# parents[2] = backend/
|
||||
# parents[3] = root/
|
||||
ROOT_DIR = Path(__file__).resolve().parents[3] # root/
|
||||
DATA_MEDIA_DIR = ROOT_DIR / "data" / "media"
|
||||
|
||||
# Video subdirectory (relative to workspace media dir)
|
||||
PODCAST_AUDIO_DIR = (DATA_MEDIA_DIR / "podcast_audio").resolve()
|
||||
PODCAST_IMAGES_DIR = (DATA_MEDIA_DIR / "podcast_images").resolve()
|
||||
PODCAST_VIDEOS_DIR = (DATA_MEDIA_DIR / "podcast_videos").resolve()
|
||||
|
||||
# Video subdirectory
|
||||
AI_VIDEO_SUBDIR = Path("AI_Videos")
|
||||
|
||||
# Legacy constants - DEPRECATED, use get_podcast_media_dir() instead
|
||||
# Kept for backward compatibility with some handlers
|
||||
PODCAST_AVATARS_SUBDIR = Path("avatars")
|
||||
MediaType = Literal["audio", "image", "video"]
|
||||
|
||||
MediaType = Literal["audio", "image", "video", "chart"]
|
||||
|
||||
def _sanitize_user_id(user_id: str) -> str:
|
||||
return "".join(c for c in user_id if c.isalnum() or c in ("-", "_"))
|
||||
|
||||
|
||||
def get_podcast_media_dir(
|
||||
@@ -30,30 +38,21 @@ def get_podcast_media_dir(
|
||||
*,
|
||||
ensure_exists: bool = False,
|
||||
) -> Path:
|
||||
"""
|
||||
Resolve podcast media directory (workspace-only for multi-tenant isolation).
|
||||
|
||||
Requires user_id for tenant isolation. Falls back to default workspace
|
||||
only if no user_id provided (for backward compat in development).
|
||||
Logs a warning in production when user_id is missing.
|
||||
"""
|
||||
"""Resolve podcast media directory (tenant workspace first, legacy global fallback)."""
|
||||
media_subdir = {
|
||||
"audio": "podcast_audio",
|
||||
"image": "podcast_images",
|
||||
"video": "podcast_videos",
|
||||
"chart": "podcast_charts",
|
||||
}[media_type]
|
||||
|
||||
if user_id:
|
||||
sanitized = _sanitize_user_id(user_id)
|
||||
resolved_dir = (
|
||||
ROOT_DIR / "workspace" / f"workspace_{sanitized}" / "media" / media_subdir
|
||||
).resolve()
|
||||
tenant_media_dir = ROOT_DIR / "workspace" / f"workspace_{sanitized}" / "media" / media_subdir
|
||||
resolved_dir = tenant_media_dir.resolve()
|
||||
else:
|
||||
logger.warning(f"[Podcast] get_podcast_media_dir called without user_id for {media_type} — using default workspace. This should not happen in production.")
|
||||
resolved_dir = (
|
||||
ROOT_DIR / "workspace" / "workspace_alwrity" / "media" / media_subdir
|
||||
).resolve()
|
||||
resolved_dir = (DATA_MEDIA_DIR / media_subdir).resolve()
|
||||
|
||||
logger.debug(f"[Podcast] get_podcast_media_dir: type={media_type}, user_id={user_id}, sanitized={user_id and _sanitize_user_id(user_id)}, resolved={resolved_dir}")
|
||||
|
||||
if ensure_exists:
|
||||
resolved_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -62,11 +61,14 @@ def get_podcast_media_dir(
|
||||
|
||||
|
||||
def get_podcast_media_read_dirs(media_type: MediaType, user_id: str | None = None) -> list[Path]:
|
||||
"""
|
||||
Return directories to search for podcast media.
|
||||
Now workspace-only (no legacy fallback).
|
||||
"""
|
||||
return [get_podcast_media_dir(media_type, user_id)]
|
||||
"""Return ordered directories to search (tenant path first, then legacy global path)."""
|
||||
dirs: list[Path] = []
|
||||
if user_id:
|
||||
dirs.append(get_podcast_media_dir(media_type, user_id))
|
||||
logger.debug(f"[Podcast] get_podcast_media_read_dirs: added user dir for {user_id}")
|
||||
dirs.append(get_podcast_media_dir(media_type, None))
|
||||
logger.debug(f"[Podcast] get_podcast_media_read_dirs: dirs={dirs}")
|
||||
return dirs
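The tenant-first lookup order described by `get_podcast_media_read_dirs` can be sketched in isolation. This is an illustration under stated assumptions, not the module's implementation: `media_read_dirs` and `find_media` are hypothetical names, and the root path is passed in explicitly instead of being resolved by `get_repo_root()`.

```python
from pathlib import Path
from typing import List, Optional


def sanitize_user_id(user_id: str) -> str:
    """Keep only alphanumerics, '-' and '_' so the id is safe in a path segment."""
    return "".join(c for c in user_id if c.isalnum() or c in ("-", "_"))


def media_read_dirs(root: Path, media_subdir: str,
                    user_id: Optional[str] = None) -> List[Path]:
    """Return search directories: tenant workspace first, legacy global path last."""
    dirs: List[Path] = []
    if user_id:
        workspace = f"workspace_{sanitize_user_id(user_id)}"
        dirs.append(root / "workspace" / workspace / "media" / media_subdir)
    dirs.append(root / "data" / "media" / media_subdir)
    return dirs


def find_media(filename: str, dirs: List[Path]) -> Optional[Path]:
    """Return the first existing file among the ordered directories, else None."""
    for directory in dirs:
        candidate = directory / filename
        if candidate.is_file():
            return candidate
    return None


dirs = media_read_dirs(Path("/srv/app"), "podcast_audio", user_id="a/b")
```

Because the tenant directory comes first, a file that exists in both locations resolves to the tenant copy, and the legacy global directory only serves as a fallback.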
|
||||
|
||||
|
||||
def get_podcast_audio_service(user_id: str | None = None) -> StoryAudioGenerationService:
|
||||
|
||||
@@ -1,216 +0,0 @@
|
||||
"""
|
||||
Podcast cost estimation helpers.
|
||||
|
||||
Builds user-facing podcast estimates from the subscription pricing catalog
|
||||
instead of hard-coded frontend heuristics.
|
||||
|
||||
Supports multiple models for each component:
|
||||
- Audio TTS: minimax/speech-02-hd (default), qwen3-tts, cosyvoice-tts
|
||||
- Voice Clone: qwen3, cosyvoice, minimax
|
||||
- Image: qwen-image (default), ideogram-v3-turbo
|
||||
- Video: wan-2.5 (default), kling-v2.5, infinitetalk
|
||||
- LLM: gemini-2.5-flash (default)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.subscription_models import APIProvider
|
||||
from services.subscription.pricing_service import PricingService
|
||||
|
||||
|
||||
def _round_money(value: float) -> float:
|
||||
return round(float(value), 4)
|
||||
|
||||
|
||||
def _load_pricing(
|
||||
pricing_service: PricingService,
|
||||
provider: APIProvider,
|
||||
preferred_model: str,
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""Load pricing for a provider and model, with fallback to default."""
|
||||
pricing = pricing_service.get_pricing_for_provider_model(provider, preferred_model)
|
||||
if pricing:
|
||||
return pricing
|
||||
# Fallback to provider default model row (if configured).
|
||||
return pricing_service.get_pricing_for_provider_model(provider, "default")
|
||||
|
||||
|
||||
# Default models used in podcast generation
|
||||
DEFAULT_MODELS = {
|
||||
"gemini": "gemini-2.5-flash",
|
||||
"exa": "exa-search",
|
||||
"audio_tts": "minimax/speech-02-hd",
|
||||
"voice_clone": "wavespeed-ai/qwen3-tts/voice-clone",
|
||||
"image": "qwen-image",
|
||||
"video": "wan-2.5",
|
||||
}
|
||||
|
||||
|
||||
def estimate_podcast_cost(
|
||||
*,
|
||||
db: Session,
|
||||
duration_minutes: int,
|
||||
speakers: int,
|
||||
query_count: int,
|
||||
include_avatar_phase: bool = True,
|
||||
# Optional model overrides
|
||||
gemini_model: str = "gemini-2.5-flash",
|
||||
audio_tts_model: str = "minimax/speech-02-hd",
|
||||
voice_clone_engine: str = "qwen3",
|
||||
image_model: str = "qwen-image",
|
||||
video_model: str = "wan-2.5",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Compute a backend estimate for podcast creation.
|
||||
|
||||
Supports customizable models for each component.
|
||||
Uses pricing_catalog for accurate cost calculation.
|
||||
"""
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# Load pricing for each component and model
|
||||
gemini_pricing = _load_pricing(pricing_service, APIProvider.GEMINI, gemini_model)
|
||||
exa_pricing = _load_pricing(pricing_service, APIProvider.EXA, "exa-search")
|
||||
|
||||
# Audio TTS pricing (minimax/speech-02-hd)
|
||||
audio_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, audio_tts_model)
|
||||
|
||||
# Voice clone pricing (different engines)
|
||||
voice_clone_model = f"wavespeed-ai/{voice_clone_engine}-tts/voice-clone"
|
||||
voice_clone_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, voice_clone_model)
|
||||
if not voice_clone_pricing:
|
||||
# Try alternate model names
|
||||
voice_clone_pricing = _load_pricing(pricing_service, APIProvider.AUDIO, f"{voice_clone_engine}/voice-clone")
|
||||
|
||||
# Image pricing (qwen-image or ideogram)
|
||||
image_pricing = _load_pricing(pricing_service, APIProvider.STABILITY, image_model)
|
||||
|
||||
# Video pricing (wan-2.5, kling, or infinitetalk)
|
||||
video_pricing = _load_pricing(pricing_service, APIProvider.VIDEO, video_model)
|
||||
|
||||
# Return None if critical pricing unavailable (fail fast)
|
||||
if not gemini_pricing:
|
||||
return None
|
||||
|
||||
# Configuration
|
||||
minutes = max(1, int(duration_minutes or 1))
|
||||
speaker_count = max(1, int(speakers or 1))
|
||||
research_queries = max(1, int(query_count or 1))
|
||||
|
||||
# Token usage assumptions per phase
|
||||
analysis_input_tokens = 1800
|
||||
analysis_output_tokens = 1000
|
||||
research_synthesis_input_tokens = 2200
|
||||
research_synthesis_output_tokens = 900
|
||||
script_input_tokens = max(1800, minutes * 300)
|
||||
script_output_tokens = max(2200, minutes * 700)
|
||||
|
||||
# TTS: ~900 chars per minute per speaker
|
||||
estimated_tts_tokens = max(900, minutes * 900 * speaker_count)
|
||||
|
||||
# Voice clone: 1 clone operation per speaker
|
||||
voice_clone_count = speaker_count
|
||||
|
||||
# ===== COST CALCULATIONS =====
|
||||
|
||||
# 1. Analysis phase (LLM)
|
||||
analysis_cost = (
|
||||
analysis_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ analysis_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
|
||||
)
|
||||
|
||||
# 2. Research phase
|
||||
# 2a. LLM for research synthesis
|
||||
research_llm_cost = (
|
||||
research_synthesis_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ research_synthesis_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
|
||||
)
|
||||
# 2b. Search API (Exa)
|
||||
research_search_cost = 0.0
|
||||
if exa_pricing:
|
||||
research_search_cost = research_queries * float(exa_pricing.get("cost_per_request") or 0.0)
|
||||
research_cost = research_search_cost + research_llm_cost
|
||||
|
||||
# 3. Script generation (LLM)
|
||||
script_cost = (
|
||||
script_input_tokens * float(gemini_pricing.get("cost_per_input_token") or 0.0)
|
||||
+ script_output_tokens * float(gemini_pricing.get("cost_per_output_token") or 0.0)
|
||||
)
|
||||
|
||||
# 4. Audio TTS
|
||||
tts_cost = 0.0
|
||||
if audio_pricing:
|
||||
tts_cost = estimated_tts_tokens * float(audio_pricing.get("cost_per_input_token") or 0.0)
|
||||
|
||||
# 5. Voice cloning (if needed)
|
||||
voice_clone_cost = 0.0
|
||||
if voice_clone_pricing:
|
||||
voice_clone_cost = voice_clone_count * (
|
||||
float(voice_clone_pricing.get("cost_per_request") or 0.0)
|
||||
+ estimated_tts_tokens * float(voice_clone_pricing.get("cost_per_input_token") or 0.0)
|
||||
)
|
||||
|
||||
# 6. Avatar image generation
|
||||
avatar_cost = 0.0
|
||||
if include_avatar_phase and image_pricing:
|
||||
image_unit = float(image_pricing.get("cost_per_image") or image_pricing.get("cost_per_request") or 0.0)
|
||||
avatar_cost = speaker_count * image_unit
|
||||
|
||||
# 7. Video rendering
|
||||
video_cost = 0.0
|
||||
if video_pricing:
|
||||
# Assume 1 video render per minute (upper bound)
|
||||
video_cost = minutes * float(video_pricing.get("cost_per_request") or 0.0)
|
||||
|
||||
# ===== TOTALS =====
|
||||
llm_total = analysis_cost + research_llm_cost + script_cost
|
||||
audio_total = tts_cost + voice_clone_cost
|
||||
media_total = avatar_cost + video_cost
|
||||
total = llm_total + research_search_cost + audio_total + media_total
|
||||
|
||||
return {
|
||||
# Cost breakdown
|
||||
"analysisCost": _round_money(analysis_cost),
|
||||
"researchCost": _round_money(research_cost),
|
||||
"researchSearchCost": _round_money(research_search_cost),
|
||||
"researchLlmCost": _round_money(research_llm_cost),
|
||||
"scriptCost": _round_money(script_cost),
|
||||
"ttsCost": _round_money(tts_cost),
|
||||
"voiceCloneCost": _round_money(voice_clone_cost),
|
||||
"avatarCost": _round_money(avatar_cost),
|
||||
"videoCost": _round_money(video_cost),
|
||||
"total": _round_money(total),
|
||||
# Totals by category
|
||||
"llmCost": _round_money(llm_total),
|
||||
"audioCost": _round_money(audio_total),
|
||||
"mediaCost": _round_money(media_total),
|
||||
# Currency
|
||||
"currency": "USD",
|
||||
"source": "pricing_catalog",
|
||||
# Models used for this estimate
|
||||
"models": {
|
||||
"llm": gemini_model,
|
||||
"research": "exa-search",
|
||||
"audio_tts": audio_tts_model,
|
||||
"voice_clone": voice_clone_model,
|
||||
"image": image_model,
|
||||
"video": video_model,
|
||||
},
|
||||
# Assumptions used
|
||||
"assumptions": {
|
||||
"analysis_input_tokens": analysis_input_tokens,
|
||||
"analysis_output_tokens": analysis_output_tokens,
|
||||
"research_synthesis_input_tokens": research_synthesis_input_tokens,
|
||||
"research_synthesis_output_tokens": research_synthesis_output_tokens,
|
||||
"script_input_tokens": script_input_tokens,
|
||||
"script_output_tokens": script_output_tokens,
|
||||
"estimated_tts_tokens": estimated_tts_tokens,
|
||||
"research_queries": research_queries,
|
||||
"voice_clone_count": voice_clone_count,
|
||||
"video_requests": minutes,
|
||||
"avatar_requests": speaker_count if include_avatar_phase else 0,
|
||||
},
|
||||
}
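The LLM phases in `estimate_podcast_cost` all use the same formula: input tokens times the per-input-token price plus output tokens times the per-output-token price. The sketch below works through it for a 10-minute episode. `llm_phase_cost` is a hypothetical helper, and the two prices are the hard-coded fallback values that appear elsewhere in this diff, not real catalog rows.

```python
from typing import Any, Dict


def llm_phase_cost(input_tokens: int, output_tokens: int,
                   pricing: Dict[str, Any]) -> float:
    """Token-based cost; a missing or None price is treated as 0.0."""
    return (
        input_tokens * float(pricing.get("cost_per_input_token") or 0.0)
        + output_tokens * float(pricing.get("cost_per_output_token") or 0.0)
    )


# Assumed prices, for illustration only.
pricing = {"cost_per_input_token": 0.00000015, "cost_per_output_token": 0.0000006}
minutes = 10

# Analysis uses fixed token assumptions (1800 in, 1000 out).
analysis_cost = llm_phase_cost(1800, 1000, pricing)

# Script tokens scale with duration but never drop below the floor values.
script_input_tokens = max(1800, minutes * 300)
script_output_tokens = max(2200, minutes * 700)
script_cost = llm_phase_cost(script_input_tokens, script_output_tokens, pricing)

llm_total = analysis_cost + script_cost
```

With these assumed prices the analysis phase costs about 0.00087 USD and the script phase about 0.00465 USD, so output tokens dominate the LLM share of the estimate.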
|
||||
@@ -4,9 +4,8 @@ Podcast Analysis Handlers
|
||||
Analysis endpoint for podcast ideas.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, Optional, List
|
||||
from datetime import datetime
|
||||
import json
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -18,99 +17,101 @@ from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
import os
|
||||
from ..constants import get_podcast_media_dir
|
||||
from ..prompts import get_enhance_topic_prompt, format_website_context
|
||||
from ..constants import PODCAST_IMAGES_DIR
|
||||
from ..models import (
|
||||
PodcastAnalyzeRequest,
|
||||
PodcastAnalyzeResponse,
|
||||
PodcastEnhanceIdeaRequest,
|
||||
PodcastEnhanceIdeaResponse,
|
||||
ExtractUrlRequest,
|
||||
ExtractUrlResponse,
|
||||
WebsiteAnalysisRequest,
|
||||
WebsiteAnalysisResponse,
|
||||
PodcastPreEstimateRequest,
|
||||
PodcastPreEstimateResponse,
|
||||
PodcastEnhanceIdeaResponse
|
||||
)
|
||||
from ..cost_estimator import estimate_podcast_cost
|
||||
|
||||
# Check if running in podcast-only demo mode
|
||||
def _is_podcast_only_mode() -> bool:
|
||||
"""Check if podcast-only demo mode is enabled."""
|
||||
return os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
def _estimate_tokens(text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return max(1, len(text) // 4)
|
||||
|
||||
|
||||
@router.post("/pre-estimate", response_model=PodcastPreEstimateResponse)
|
||||
async def pre_estimate_cost(
|
||||
request: PodcastPreEstimateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
def _build_analysis_estimate(
|
||||
db: Session,
|
||||
idea: str,
|
||||
duration: int,
|
||||
speakers: int,
|
||||
has_avatar: bool,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Lightweight endpoint to estimate podcast creation cost before analysis.
|
||||
|
||||
Takes user configuration (duration, speakers, query_count, podcast_mode) and returns
|
||||
a cost estimate WITHOUT running full analysis.
|
||||
|
||||
Optional model overrides can be specified to estimate with different models.
|
||||
Build a user-facing estimate from pricing catalog and phase-level assumptions.
|
||||
"""
|
||||
# Defaults if catalog lookup fails
|
||||
gemini_in_token = 0.00000015
|
||||
gemini_out_token = 0.0000006
|
||||
exa_per_request = 0.005
|
||||
image_per_request = 0.01
|
||||
video_per_request = 0.01
|
||||
audio_per_request = 0.005
|
||||
|
||||
try:
|
||||
include_avatar_phase = request.podcast_mode != "audio_only"
|
||||
|
||||
estimate = estimate_podcast_cost(
|
||||
db=db,
|
||||
duration_minutes=request.duration,
|
||||
speakers=request.speakers,
|
||||
query_count=request.query_count,
|
||||
include_avatar_phase=include_avatar_phase,
|
||||
# Model overrides if provided
|
||||
gemini_model=request.gemini_model or "gemini-2.5-flash",
|
||||
audio_tts_model=request.audio_tts_model or "minimax/speech-02-hd",
|
||||
voice_clone_engine=request.voice_clone_engine or "qwen3",
|
||||
image_model=request.image_model or "qwen-image",
|
||||
video_model=request.video_model or "wan-2.5",
|
||||
)
|
||||
|
||||
# Debug: get pricing row count and providers
|
||||
from models.subscription_models import APIProviderPricing
|
||||
pricing_count = db.query(APIProviderPricing).count()
|
||||
providers = db.query(APIProviderPricing.provider).distinct().all()
|
||||
provider_list = sorted([p[0].value for p in providers]) if providers else []
|
||||
|
||||
debug_info = {
|
||||
"pricing_rows": pricing_count,
|
||||
"providers": provider_list,
|
||||
}
|
||||
|
||||
# Log pricing debug info at warning level
|
||||
logger.warning(f"[PRE-ESTIMATE] Pricing debug: rows={pricing_count}, providers={provider_list}")
|
||||
logger.warning(f"[PRE-ESTIMATE] Models: llm={request.gemini_model}, tts={request.audio_tts_model}, video={request.video_model}")
|
||||
|
||||
if estimate is None:
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=None,
|
||||
error="Pricing data unavailable. Please try again later.",
|
||||
pricing_available=False,
|
||||
debug=debug_info,
|
||||
)
|
||||
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=estimate,
|
||||
error=None,
|
||||
pricing_available=True,
|
||||
debug=debug_info,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Pre-estimate error: {e}")
|
||||
return PodcastPreEstimateResponse(
|
||||
estimate=None,
|
||||
error=str(e),
|
||||
)
|
||||
pricing_service = PricingService(db)
|
||||
gemini_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.GEMINI, "gemini-2.5-flash") or {}
|
||||
gemini_in_token = float(gemini_pricing.get("cost_per_input_token") or gemini_in_token)
|
||||
gemini_out_token = float(gemini_pricing.get("cost_per_output_token") or gemini_out_token)
|
||||
exa_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.EXA, "exa-search") or {}
|
||||
exa_per_request = float(exa_pricing.get("cost_per_request") or exa_per_request)
|
||||
img_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.STABILITY, "stable-image-ultra") or {}
|
||||
image_per_request = float(img_pricing.get("cost_per_request") or image_per_request)
|
||||
video_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.VIDEO, "minimax-video-01") or {}
|
||||
video_per_request = float(video_pricing.get("cost_per_request") or video_per_request)
|
||||
audio_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.AUDIO, "gemini-2.5-flash-preview-tts") or {}
|
||||
audio_per_request = float(audio_pricing.get("cost_per_request") or audio_per_request)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Analyze] Pricing catalog lookup failed, using defaults: {exc}")
|
||||
|
||||
# Phase assumptions
|
||||
query_count = 5
|
||||
analyze_in = _estimate_tokens(idea) + 240
|
||||
analyze_out = 750
|
||||
analyze_cost = (analyze_in * gemini_in_token) + (analyze_out * gemini_out_token)
|
||||
|
||||
gather_cost = query_count * exa_per_request
|
||||
|
||||
script_chars = max(1000, duration * 900)
|
||||
write_in = _estimate_tokens(idea) + _estimate_tokens(str(script_chars)) + 320
|
||||
write_out = max(900, int(duration * 220))
|
||||
write_cost = (write_in * gemini_in_token) + (write_out * gemini_out_token)
|
||||
|
||||
tts_cost = max(1, speakers) * audio_per_request
|
||||
avatar_cost = 0.0 if has_avatar else image_per_request
|
||||
video_cost = max(1, duration) * video_per_request
|
||||
produce_cost = tts_cost + avatar_cost + video_cost
|
||||
|
||||
breakdown = [
|
||||
{"phase": "Analyze", "cost": round(analyze_cost, 6)},
|
||||
{"phase": "Gather", "cost": round(gather_cost, 6)},
|
||||
{"phase": "Write", "cost": round(write_cost, 6)},
|
||||
{"phase": "Produce", "cost": round(produce_cost, 6)},
|
||||
]
|
||||
total = round(sum(item["cost"] for item in breakdown), 6)
|
||||
return {
|
||||
"ttsCost": round(tts_cost, 6),
|
||||
"avatarCost": round(avatar_cost, 6),
|
||||
"videoCost": round(video_cost, 6),
|
||||
"researchCost": round(gather_cost, 6),
|
||||
"total": total,
|
||||
"breakdown": breakdown,
|
||||
"currency": "USD",
|
||||
}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/idea/enhance", response_model=PodcastEnhanceIdeaResponse)
|
||||
@@ -153,27 +154,39 @@ async def enhance_podcast_idea(
|
||||
except Exception as exc:
|
||||
logger.debug(f"[Podcast Enhance] Bible parsing skipped in podcast mode: {exc}")
|
||||
|
||||
# Log what's being used for context
|
||||
context_used = []
|
||||
if bible_context:
|
||||
context_used.append("Podcast Bible")
|
||||
if request.website_data:
|
||||
context_used.append("Website Extraction")
|
||||
if request.topic_context:
|
||||
category = request.topic_context.get("category", "unknown")
|
||||
context_used.append(f"Category Research ({category})")
|
||||
|
||||
logger.warning(f"[Podcast Enhance] Generating with context: {', '.join(context_used) if context_used else 'basic idea only'}")
|
||||
prompt = f"""
|
||||
You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
|
||||
# Use new context builder for prompt generation
|
||||
from services.podcast_context_builder import context_builder
|
||||
context_result = context_builder.build_enhance_context(
|
||||
idea=request.idea,
|
||||
bible_context=bible_context,
|
||||
website_data=request.website_data,
|
||||
topic_context=request.topic_context,
|
||||
)
|
||||
prompt = context_result["prompt"]
|
||||
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{request.idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}}
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
@@ -311,8 +324,7 @@ async def analyze_podcast_idea(
|
||||
if image_result and image_result.image_bytes:
|
||||
img_id = str(uuid.uuid4())[:8]
|
||||
filename = f"presenter_podcast_{user_id}_{img_id}.png"
|
||||
images_dir = get_podcast_media_dir("image", user_id, ensure_exists=True)
|
||||
avatars_dir = images_dir / "avatars"
|
||||
avatars_dir = PODCAST_IMAGES_DIR / "avatars"
|
||||
avatars_dir.mkdir(parents=True, exist_ok=True)
|
||||
output_path = avatars_dir / filename
|
||||
|
||||
@@ -438,13 +450,6 @@ Requirements:
|
||||
listener_cta = data.get("listener_cta") or ""
|
||||
research_queries = data.get("research_queries") or []
|
||||
exa_suggested_config = data.get("exa_suggested_config") or None
|
||||
estimate = estimate_podcast_cost(
|
||||
db=db,
|
||||
duration_minutes=request.duration,
|
||||
speakers=request.speakers,
|
||||
query_count=len(research_queries) if isinstance(research_queries, list) else 0,
|
||||
include_avatar_phase=podcast_mode != "audio_only",
|
||||
)
|
||||
|
||||
return PodcastAnalyzeResponse(
|
||||
audience=audience,
|
||||
@@ -461,7 +466,13 @@ Requirements:
|
||||
bible=bible_obj.model_dump() if bible_obj else None,
|
||||
avatar_url=final_avatar_url,
|
||||
avatar_prompt=final_avatar_prompt,
|
||||
estimate=estimate,
|
||||
estimate=_build_analysis_estimate(
|
||||
db=db,
|
||||
idea=request.idea,
|
||||
duration=request.duration,
|
||||
speakers=request.speakers,
|
||||
has_avatar=bool(final_avatar_url),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -566,316 +577,3 @@ Requirements:
|
||||
except Exception as exc:
|
||||
logger.error(f"[Regenerate Queries] Failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Regenerate queries failed: {exc}")
|
||||
|
||||
|
||||
@router.post("/extract-url", response_model=ExtractUrlResponse)
|
||||
async def extract_url_content(
|
||||
request: ExtractUrlRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Extract content from a URL using Exa's get_contents API.
|
||||
|
||||
This allows users to paste a blog post or article URL as their podcast topic,
|
||||
and we'll extract the content to use as the podcast idea.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
from exa_py import Exa
|
||||
import os
|
||||
|
||||
api_key = os.getenv("EXA_API_KEY")
|
||||
if not api_key:
|
||||
raise HTTPException(status_code=500, detail="EXA_API_KEY not configured")
|
||||
|
||||
exa = Exa(api_key)
|
||||
|
||||
logger.warning(f"[ExtractUrl] Extracting content from: {request.url} for user {user_id}")
|
||||
|
||||
try:
|
||||
result = exa.get_contents(
|
||||
urls=[request.url],
|
||||
text=True,
|
||||
highlights=True,
|
||||
summary=True,
|
||||
subpages=2,
|
||||
)
|
||||
except Exception as exa_error:
|
||||
logger.error(f"[ExtractUrl] Exa call error: {exa_error}")
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error=f"Exa API error: {str(exa_error)}"
|
||||
)
|
||||
|
||||
# Check for errors using the correct attribute (statuses is array of status objects)
|
||||
if hasattr(result, 'statuses') and result.statuses:
|
||||
for status in result.statuses:
|
||||
if status.status == "error":
|
||||
logger.error(f"[ExtractUrl] Failed to extract {status.id}: {status.error.tag if hasattr(status.error, 'tag') else 'unknown'}")
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error=f"Failed to extract content: {status.error.tag if hasattr(status.error, 'tag') else 'unknown error'}"
|
||||
)
|
||||
|
||||
if not result.results:
|
||||
return ExtractUrlResponse(
|
||||
success=False,
|
||||
url=request.url,
|
||||
error="No content found at the provided URL"
|
||||
)
|
||||
|
||||
# Extract content - safe to access result now
|
||||
content = result.results[0]
|
||||
|
||||
# Extract all available fields from Exa response
|
||||
extracted_text = content.text or ""
|
||||
extracted_summary = getattr(content, 'summary', "") or ""
|
||||
extracted_title = content.title or ""
|
||||
|
||||
# Highlights - extract from content.highlights array if available
|
||||
highlights = []
|
||||
if hasattr(content, 'highlights') and content.highlights:
|
||||
highlights = [h for h in content.highlights if h]
|
||||
|
||||
# Additional fields from Exa response
|
||||
image = getattr(content, 'image', None)
|
||||
favicon = getattr(content, 'favicon', None)
|
||||
|
||||
# Subpages - extract with their own content
|
||||
subpages = []
|
||||
if hasattr(content, 'subpages') and content.subpages:
|
||||
for sp in content.subpages:
|
||||
subpages.append({
|
||||
'id': sp.get('id', ''),
|
||||
'title': sp.get('title', ''),
|
||||
'url': sp.get('url', ''),
|
||||
'summary': sp.get('summary', ''),
|
||||
'text': sp.get('text', '')[:500] if sp.get('text') else '', # First 500 chars
|
||||
})
|
||||
|
||||
logger.warning(f"[ExtractUrl] Successfully extracted {len(extracted_text)} chars from {request.url}")
|
||||
logger.warning(f"[ExtractUrl] title={extracted_title[:50]}, summary={extracted_summary[:50]}, highlights={len(highlights)}, subpages={len(subpages)}")
|
||||
|
||||
return ExtractUrlResponse(
|
||||
success=True,
|
||||
title=extracted_title,
|
||||
text=extracted_text,
|
||||
summary=extracted_summary,
|
||||
author=getattr(content, 'author', None),
|
||||
highlights=highlights,
|
||||
url=request.url,
|
||||
image=image,
|
||||
favicon=favicon,
|
||||
subpages=subpages,
|
||||
)
|
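Review note on the subpage handling in `extract_url_content` above: the loop calls `sp.get(...)`, which only works if each subpage is a dict. This diff does not confirm whether `exa_py` returns subpages as dicts or as result objects, so the following is a minimal sketch of a reader that tolerates both shapes. `read_field` and `summarize_subpage` are hypothetical helper names, not part of the codebase.

```python
from types import SimpleNamespace
from typing import Any, Dict


def read_field(item: Any, name: str, default: Any = "") -> Any:
    """Read `name` from a dict or an attribute-style object (hypothetical helper)."""
    if isinstance(item, dict):
        value = item.get(name, default)
    else:
        value = getattr(item, name, default)
    # Treat an explicit None the same as a missing field.
    return default if value is None else value


def summarize_subpage(sp: Any, max_text: int = 500) -> Dict[str, str]:
    """Build the same five-key summary the endpoint returns for each subpage."""
    return {
        "id": read_field(sp, "id"),
        "title": read_field(sp, "title"),
        "url": read_field(sp, "url"),
        "summary": read_field(sp, "summary"),
        "text": str(read_field(sp, "text"))[:max_text],  # keep the first max_text chars
    }


# Both input shapes produce the same output structure.
as_dict = summarize_subpage({"id": "a", "title": "T", "text": "x" * 600})
as_obj = summarize_subpage(SimpleNamespace(id="b", title=None, text="short"))
```

If the SDK is known to return only one shape, the branch that is never hit can be dropped.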
||||
|
||||
|
||||
@router.post("/website-analysis", response_model=WebsiteAnalysisResponse)
|
||||
async def save_website_analysis(
|
||||
request: WebsiteAnalysisRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save the user's website analysis for reuse in future podcasts."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.user_data_service import user_data_service
|
||||
|
||||
website_data = {
|
||||
"website_url": request.website_url,
|
||||
"extracted_at": datetime.now().isoformat(),
|
||||
"exa_content": request.exa_content,
|
||||
"full_analysis": None,
|
||||
"analysis_status": "pending",
|
||||
}
|
||||
|
||||
success = user_data_service.save_user_data(
|
||||
user_id=user_id,
|
||||
data_key="website_analysis",
|
||||
data_value=website_data,
|
||||
)
|
||||
|
||||
if success:
|
||||
logger.warning(f"[WebsiteAnalysis] Saved analysis for user {user_id}: {request.website_url}")
|
||||
return WebsiteAnalysisResponse(
|
||||
success=True,
|
||||
website_url=request.website_url,
|
||||
message="Website analysis saved successfully",
|
||||
)
|
||||
else:
|
||||
return WebsiteAnalysisResponse(
|
||||
success=False,
|
||||
error="Failed to save website analysis",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteAnalysis] Failed to save for user {user_id}: {exc}")
|
||||
return WebsiteAnalysisResponse(
|
||||
success=False,
|
||||
error=f"Failed to save: {str(exc)}"
|
||||
)
|
||||
|
||||
|
||||
@router.get("/website-extraction")
|
||||
async def get_saved_website_extraction(request: Request = None):
|
||||
"""Get previously saved website extraction data for this user."""
|
||||
try:
|
||||
# Safely get current_user from Depends
|
||||
if request is None or not hasattr(request, 'state'):
|
||||
logger.warning("[WebsiteExtraction] No request or state - user not authenticated")
|
||||
return {"success": False, "data": None, "error": "Not authenticated"}
|
||||
|
||||
current_user = getattr(request.state, 'user', None)
|
||||
if not current_user:
|
||||
logger.warning("[WebsiteExtraction] No user in request state")
|
||||
return {"success": False, "data": None, "error": "Not authenticated"}
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
from services.user_data_service import UserDataService
|
||||
from services.database import get_db
|
||||
db = next(get_db())
|
||||
|
||||
user_service = UserDataService(db)
|
||||
extraction = user_service.get_website_extraction(user_id)
|
||||
|
||||
if extraction:
|
||||
logger.info(f"[WebsiteExtraction] Found saved data for user {user_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"data": extraction
|
||||
}
|
||||
else:
|
||||
logger.info(f"[WebsiteExtraction] No saved data for user {user_id}")
|
||||
return {
|
||||
"success": False,
|
||||
"data": None
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteExtraction] Failed for user: {exc}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/website-extraction")
|
||||
async def save_website_extraction(
|
||||
extraction: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save website extraction data for future use."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.user_data_service import UserDataService
|
||||
from services.database import get_db
|
||||
db = next(get_db())
|
||||
|
||||
user_service = UserDataService(db)
|
||||
success = user_service.save_website_extraction(user_id, extraction)
|
||||
|
||||
if success:
|
||||
logger.info(f"[WebsiteExtraction] Saved for user {user_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Website extraction saved"
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Failed to save"
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[WebsiteExtraction] Save failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.post("/project/{project_id}/topic-context")
|
||||
async def save_topic_context(
|
||||
project_id: str,
|
||||
topic_context: Dict[str, Any],
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Save topic context (category research) to a podcast project."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from models.podcast_models import PodcastProject
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
# Find the project
|
||||
project = db.query(PodcastProject).filter(
|
||||
PodcastProject.project_id == project_id,
|
||||
PodcastProject.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not project:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Project not found"
|
||||
}
|
||||
|
||||
# Update topic context
|
||||
project.topic_context = topic_context
|
||||
db.commit()
|
||||
|
||||
logger.info(f"[TopicContext] Saved for project {project_id}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "Topic context saved"
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[TopicContext] Save failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
||||
|
||||
@router.get("/project/{project_id}/topic-context")
|
||||
async def get_topic_context(
|
||||
project_id: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Get topic context from a podcast project."""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from models.podcast_models import PodcastProject
|
||||
|
||||
db = next(get_db())
|
||||
|
||||
project = db.query(PodcastProject).filter(
|
||||
PodcastProject.project_id == project_id,
|
||||
PodcastProject.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not project:
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Project not found"
|
||||
}
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"data": project.topic_context
|
||||
}
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"[TopicContext] Get failed: {exc}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(exc)
|
||||
}
|
||||
|
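Review note on the `db = next(get_db())` pattern used by the website-extraction and topic-context endpoints above: the generator is advanced once and never closed, so any cleanup after its `yield` may not run promptly. Assuming `get_db` is a generator-style dependency (not shown in this diff), a minimal sketch of a wrapper that always closes it could look like this. `session_from` and `fake_get_db` are hypothetical names used only for illustration.

```python
from contextlib import contextmanager
from typing import Any, Callable, Generator, Iterator, List


@contextmanager
def session_from(dependency: Callable[[], Generator[Any, None, None]]) -> Iterator[Any]:
    """Yield the session produced by a generator dependency, then close the generator."""
    gen = dependency()
    session = next(gen)
    try:
        yield session
    finally:
        # close() raises GeneratorExit inside the dependency, so its own
        # finally block (for example a session close) gets a chance to run.
        gen.close()


closed: List[bool] = []


def fake_get_db() -> Generator[object, None, None]:
    """Stand-in for the real dependency; records when cleanup runs."""
    session = object()
    try:
        yield session
    finally:
        closed.append(True)


with session_from(fake_get_db) as db:
    in_use = db is not None
```

Inside a FastAPI route, declaring the session as a `Depends(get_db)` parameter achieves the same cleanup without a helper.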
||||
@@ -12,15 +12,7 @@ from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
import tempfile
|
||||
import uuid
|
||||
import hashlib
|
||||
import time
|
||||
import shutil
|
||||
import requests
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
@@ -39,124 +31,6 @@ from ..models import (
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Thread pool for CPU/IO-intensive voice clone operations
|
||||
_audio_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="podcast_audio")
|
||||
|
||||
# In-memory LRU cache for voice samples (per user) to avoid re-downloading
|
||||
_voice_sample_cache: dict[str, tuple[float, bytes]] = {}
|
||||
_VOICE_SAMPLE_CACHE_TTL = 1800 # 30 minutes
|
||||
|
||||
|
||||
def _get_cached_voice_sample(cache_key: str) -> Optional[bytes]:
|
||||
"""Get voice sample bytes from in-memory cache if fresh."""
|
||||
if cache_key in _voice_sample_cache:
|
||||
ts, data = _voice_sample_cache[cache_key]
|
||||
if time.time() - ts < _VOICE_SAMPLE_CACHE_TTL:
|
||||
logger.debug(f"[Podcast] Voice sample cache hit for {cache_key[:16]}...")
|
||||
return data
|
||||
del _voice_sample_cache[cache_key]
|
||||
return None
|
||||
|
||||
|
||||
def _cache_voice_sample(cache_key: str, data: bytes) -> None:
|
||||
"""Store voice sample bytes in in-memory cache."""
|
||||
# Evict oldest entries if cache grows too large
|
||||
if len(_voice_sample_cache) > 50:
|
||||
oldest_key = min(_voice_sample_cache, key=lambda k: _voice_sample_cache[k][0])
|
||||
del _voice_sample_cache[oldest_key]
|
||||
_voice_sample_cache[cache_key] = (time.time(), data)
|
||||
|
||||
|
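Review note on the voice sample cache above: the comment calls it an LRU cache, but `_cache_voice_sample` evicts the entry with the oldest store timestamp, and reads do not refresh an entry's position. If true least-recently-used behaviour is wanted, one option is sketched below with `collections.OrderedDict`. The class name `VoiceSampleCache` and the injectable `clock` parameter are assumptions for illustration and testing, not part of the codebase.

```python
import time
from collections import OrderedDict
from typing import Callable, Optional


class VoiceSampleCache:
    """Small TTL cache with least-recently-used eviction (illustrative sketch)."""

    def __init__(self, max_entries: int = 50, ttl_seconds: float = 1800.0,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self._entries = OrderedDict()  # key -> (stored_at, data)
        self._max_entries = max_entries
        self._ttl = ttl_seconds
        self._clock = clock

    def get(self, key: str) -> Optional[bytes]:
        entry = self._entries.get(key)
        if entry is None:
            return None
        stored_at, data = entry
        if self._clock() - stored_at >= self._ttl:
            del self._entries[key]  # expired
            return None
        self._entries.move_to_end(key)  # mark as most recently used
        return data

    def put(self, key: str, data: bytes) -> None:
        self._entries[key] = (self._clock(), data)
        self._entries.move_to_end(key)
        while len(self._entries) > self._max_entries:
            self._entries.popitem(last=False)  # drop the least recently used entry


now = [0.0]  # fake clock so the example does not need to sleep
cache = VoiceSampleCache(max_entries=2, ttl_seconds=10.0, clock=lambda: now[0])
cache.put("a", b"A")
cache.put("b", b"B")
cache.get("a")        # touch "a" so "b" becomes least recently used
cache.put("c", b"C")  # exceeds max_entries and evicts "b"
```

Like the original module-level dict, this sketch is not thread-safe; a lock would be needed if it is shared with the executor threads.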
||||
def _get_latest_voice_sample_url(user_id: str, db) -> Optional[str]:
|
||||
"""Get the latest voice sample URL for a user from their voice clone assets."""
|
||||
try:
|
||||
from models.content_asset_models import ContentAsset, AssetType, AssetSource
|
||||
from sqlalchemy import desc
|
||||
|
||||
asset = db.query(ContentAsset).filter(
|
||||
ContentAsset.user_id == user_id,
|
||||
ContentAsset.asset_type == AssetType.AUDIO,
|
||||
ContentAsset.source_module == AssetSource.VOICE_CLONER,
|
||||
).order_by(desc(ContentAsset.created_at)).first()
|
||||
|
||||
if asset and asset.file_url:
|
||||
logger.info(f"[Podcast] Found voice sample for user {user_id}: {asset.file_url}")
|
||||
return asset.file_url
|
||||
|
||||
logger.warning(f"[Podcast] No voice sample asset found for user {user_id}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Error fetching voice sample URL: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_voice_sample(voice_sample_url: str, user_id: str) -> Optional[bytes]:
|
||||
"""Fetch voice sample audio bytes from URL, with caching."""
|
||||
cache_key = hashlib.md5(f"{user_id}:{voice_sample_url}".encode()).hexdigest()
|
||||
|
||||
# Check in-memory cache first
|
||||
cached = _get_cached_voice_sample(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
from utils.media_utils import resolve_media_path
|
||||
|
||||
# Try resolving as a local workspace path first (fastest)
|
||||
if "/api/assets/" in voice_sample_url:
|
||||
# Resolve user workspace path directly
|
||||
sanitized_uid = "".join(c for c in user_id if c.isalnum() or c in ("-", "_"))
|
||||
from api.podcast.constants import ROOT_DIR
|
||||
parts = voice_sample_url.split("/")
|
||||
# Expected: /api/assets/{user_id}/voice_samples/{filename}
|
||||
try:
|
||||
idx = parts.index("voice_samples")
|
||||
filename = parts[idx + 1].split("?")[0]
|
||||
local_path = ROOT_DIR / "workspace" / f"workspace_{sanitized_uid}" / "assets" / "voice_samples" / filename
|
||||
if local_path.exists():
|
||||
data = local_path.read_bytes()
|
||||
_cache_voice_sample(cache_key, data)
|
||||
logger.info(f"[Podcast] Voice sample loaded from workspace: {local_path}")
|
||||
return data
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
|
||||
# Fall back to media utils resolver
|
||||
local_path = resolve_media_path(voice_sample_url)
|
||||
if local_path and local_path.exists():
|
||||
data = local_path.read_bytes()
|
||||
_cache_voice_sample(cache_key, data)
|
||||
return data
|
||||
|
||||
# Try resolving as a podcast audio file
|
||||
if "/api/podcast/audio/" in voice_sample_url:
|
||||
filename = voice_sample_url.split("/api/podcast/audio/")[-1].split("?")[0]
|
||||
try:
|
||||
audio_dir = get_podcast_media_dir("audio", user_id)
|
||||
local_path = audio_dir / filename
|
||||
if local_path.exists():
|
||||
data = local_path.read_bytes()
|
||||
_cache_voice_sample(cache_key, data)
|
||||
return data
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Try direct HTTP fetch as fallback
|
||||
if voice_sample_url.startswith("http"):
|
||||
logger.info(f"[Podcast] Fetching voice sample via HTTP: {voice_sample_url[:80]}...")
|
||||
resp = requests.get(voice_sample_url, timeout=30)
|
||||
if resp.status_code == 200:
|
||||
data = resp.content
|
||||
_cache_voice_sample(cache_key, data)
|
||||
logger.info(f"[Podcast] Voice sample fetched via HTTP ({len(data)} bytes)")
|
||||
return data
|
||||
|
||||
logger.warning(f"[Podcast] Could not fetch voice sample from: {voice_sample_url}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Error fetching voice sample: {e}")
|
||||
return None
|
||||
|
||||
|
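Review note on the workspace lookup in `_fetch_voice_sample` above: the filename is taken from the URL with `split("/")` and joined onto a local directory without checking for `.` or `..` segments. A minimal sketch of a stricter extraction is shown below. The function name `voice_sample_filename` is hypothetical, and the expected URL layout is the one stated in the source comment (`/api/assets/{user_id}/voice_samples/{filename}`).

```python
from typing import Optional
from urllib.parse import unquote, urlparse


def voice_sample_filename(url: str) -> Optional[str]:
    """Return the segment after 'voice_samples', or None if it looks unsafe."""
    # urlparse drops the query string; unquote exposes encoded separators.
    path = unquote(urlparse(url).path)
    parts = [p for p in path.split("/") if p]
    try:
        candidate = parts[parts.index("voice_samples") + 1]
    except (ValueError, IndexError):
        return None
    # Reject traversal segments and Windows-style separators.
    if candidate in (".", "..") or "\\" in candidate:
        return None
    return candidate


ok = voice_sample_filename("/api/assets/u1/voice_samples/a.wav?x=1")
traversal = voice_sample_filename("/api/assets/u1/voice_samples/..%2F..%2Fsecret")
missing = voice_sample_filename("/api/assets/u1/voice_samples")
other = voice_sample_filename("/api/assets/u1/images/a.png")
```

A further safeguard, not shown here, is to resolve the final path and confirm it is still inside the user's `voice_samples` directory.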
||||
@router.post("/audio/upload")
|
||||
async def upload_podcast_audio(
|
||||
@@ -251,190 +125,36 @@ async def generate_podcast_audio(
|
||||
raise HTTPException(status_code=400, detail="Text is required")
|
||||
|
||||
try:
|
||||
# Determine if we should use voice clone path
|
||||
# Voice clone is used when: explicitly requested, OR when voice_id/custom_voice_id indicates a clone
|
||||
# (cloned voice IDs start with "vc_" or match the placeholder "MY_VOICE_CLONE")
|
||||
_vid = request.voice_id or ""
|
||||
_cvid = request.custom_voice_id or ""
|
||||
is_voice_clone = request.use_voice_clone or (
|
||||
_cvid.startswith("vc_") or _cvid == "MY_VOICE_CLONE"
|
||||
) or (
|
||||
_vid.startswith("vc_") or _vid == "MY_VOICE_CLONE"
|
||||
audio_service = get_podcast_audio_service(user_id)
|
||||
logger.warning(f"[Podcast] Generating audio with service dir: {audio_service.output_dir}")
|
||||
result: StoryAudioResult = audio_service.generate_ai_audio(
|
||||
scene_number=0,
|
||||
scene_title=request.scene_title,
|
||||
text=request.text.strip(),
|
||||
user_id=user_id,
|
||||
voice_id=request.voice_id or "Wise_Woman",
|
||||
custom_voice_id=request.custom_voice_id,
|
||||
speed=request.speed or 1.0, # Normal speed (was 0.9, but too slow - causing duration issues)
|
||||
volume=request.volume or 1.0,
|
||||
pitch=request.pitch or 0.0, # Normal pitch (0.0 = neutral)
|
||||
emotion=request.emotion or "neutral",
|
||||
english_normalization=request.english_normalization or False,
|
||||
sample_rate=request.sample_rate,
|
||||
bitrate=request.bitrate,
|
||||
channel=request.channel,
|
||||
format=request.format,
|
||||
language_boost=request.language_boost,
|
||||
enable_sync_mode=request.enable_sync_mode,
|
||||
)
|
||||
|
||||
# If voice_id is a clone ID, normalize it to use Wise_Woman for TTS fallback
|
||||
effective_voice_id = _vid if not (_vid.startswith("vc_") or _vid == "MY_VOICE_CLONE") else "Wise_Woman"
|
||||
|
||||
logger.warning(f"[Podcast] Audio request: use_voice_clone={request.use_voice_clone}, voice_id={request.voice_id}, custom_voice_id={request.custom_voice_id}, is_voice_clone={is_voice_clone}, voice_sample_url={request.voice_sample_url}, voice_clone_engine={request.voice_clone_engine}")
|
||||
|
||||
# Voice clone path: use user's voice sample with scene text as reference
|
||||
if is_voice_clone:
|
||||
# If no voice_sample_url provided, try to fetch it from the user's latest voice clone
|
||||
voice_sample_url = request.voice_sample_url
|
||||
if not voice_sample_url:
|
||||
try:
|
||||
voice_sample_url = _get_latest_voice_sample_url(user_id, db)
|
||||
logger.warning(f"[Podcast] DB fallback voice sample URL for user {user_id}: {voice_sample_url}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] Could not fetch voice sample URL: {e}")
|
||||
|
||||
if voice_sample_url:
|
||||
from services.llm_providers.main_audio_generation import qwen3_voice_clone, cosyvoice_voice_clone
|
||||
from utils.media_utils import detect_audio_format
|
||||
|
||||
engine = (request.voice_clone_engine or "qwen3").lower()
|
||||
logger.warning(f"[Podcast] 🔊 Voice clone path: engine={engine}, scene='{request.scene_title}', voice_sample_url={voice_sample_url[:80]}...")
|
||||
|
||||
# Download voice sample from URL (with caching)
|
||||
logger.warning(f"[Podcast] Fetching voice sample from: {voice_sample_url}")
|
||||
try:
|
||||
voice_sample_bytes = _fetch_voice_sample(voice_sample_url, user_id)
|
||||
except Exception as fetch_err:
|
||||
logger.error(f"[Podcast] ❌ Failed to fetch voice sample: {fetch_err}", exc_info=True)
|
||||
raise HTTPException(status_code=400, detail=f"Could not fetch voice sample: {str(fetch_err)}")
|
||||
logger.warning(f"[Podcast] Voice sample fetch result: {len(voice_sample_bytes) if voice_sample_bytes else 0} bytes")
|
||||
if not voice_sample_bytes:
|
||||
raise HTTPException(status_code=400, detail=f"Could not fetch voice sample from {voice_sample_url}")
|
||||
|
||||
# Detect actual audio format from bytes (may differ from file extension)
|
||||
detected_fmt, detected_mime = detect_audio_format(voice_sample_bytes)
|
||||
logger.warning(f"[Podcast] 🔊 Detected voice sample format: {detected_fmt} ({detected_mime}), {len(voice_sample_bytes)} bytes")
|
||||
voice_mime_type = detected_mime or "audio/wav"
|
||||
|
||||
scene_text = request.text.strip()
|
||||
if len(scene_text) > 4000:
|
||||
scene_text = scene_text[:4000]
|
||||
|
||||
# Run voice clone in thread pool to avoid blocking the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
try:
|
||||
if engine == "minimax":
|
||||
from services.llm_providers.main_audio_generation import clone_voice
|
||||
import random
|
||||
import string
|
||||
random_suffix = ''.join(random.choices(string.ascii_letters + string.digits, k=8))
|
||||
custom_vid = request.custom_voice_id or f"vc_{random_suffix}"
|
||||
|
||||
result_obj = await loop.run_in_executor(
|
||||
_audio_executor,
|
||||
lambda cv=custom_vid: clone_voice(
|
||||
audio_bytes=voice_sample_bytes,
|
||||
custom_voice_id=cv,
|
||||
text=scene_text,
|
||||
user_id=user_id,
|
||||
),
|
||||
)
|
||||
audio_bytes = result_obj.preview_audio_bytes
|
||||
provider = "minimax"
|
||||
model = "minimax/voice-clone"
|
||||
elif engine == "cosyvoice":
|
||||
result_obj = await loop.run_in_executor(
|
||||
_audio_executor,
|
||||
lambda: cosyvoice_voice_clone(
|
||||
audio_bytes=voice_sample_bytes,
|
||||
text=scene_text,
|
||||
user_id=user_id,
|
||||
audio_mime_type=voice_mime_type,
|
||||
),
|
||||
)
|
||||
audio_bytes = result_obj.preview_audio_bytes
|
||||
provider = "wavespeed-ai"
|
||||
model = "wavespeed-ai/cosyvoice-tts/voice-clone"
|
||||
else:
|
||||
result_obj = await loop.run_in_executor(
|
||||
_audio_executor,
|
||||
lambda: qwen3_voice_clone(
|
||||
audio_bytes=voice_sample_bytes,
|
||||
text=scene_text,
|
||||
user_id=user_id,
|
||||
audio_mime_type=voice_mime_type,
|
||||
),
|
||||
)
|
||||
audio_bytes = result_obj.preview_audio_bytes
|
||||
provider = "wavespeed-ai"
|
||||
model = "wavespeed-ai/qwen3-tts/voice-clone"
|
||||
|
||||
logger.warning(f"[Podcast] 🔊 Voice clone result: {len(audio_bytes) if audio_bytes else 0} bytes, provider={provider}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as clone_err:
|
||||
logger.error(f"[Podcast] ❌ Voice clone failed: {clone_err}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Voice clone generation failed: {str(clone_err)}")
|
||||
|
||||
# Save audio bytes to file
|
||||
audio_service = get_podcast_audio_service(user_id)
|
||||
audio_filename = f"scene_{request.scene_id}_{uuid.uuid4().hex[:8]}.mp3"
|
||||
audio_path = audio_service.output_dir / audio_filename
|
||||
|
||||
with open(audio_path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
|
||||
file_size = len(audio_bytes)
|
||||
audio_url = f"/api/podcast/audio/{audio_filename}"
|
||||
cost = max(0.005, 0.005 * (len(scene_text) / 100.0))
|
||||
|
||||
result = {
|
||||
"audio_path": str(audio_path),
|
||||
"audio_filename": audio_filename,
|
||||
"audio_url": audio_url,
|
||||
"file_size": file_size,
|
||||
"provider": provider,
|
||||
"model": model,
|
||||
"cost": cost,
|
||||
"scene_number": 0,
|
||||
"scene_title": request.scene_title,
|
||||
}
|
||||
|
||||
else:
|
||||
# Standard TTS path - but NOT if custom_voice_id is a clone ID
|
||||
# Clone IDs (vc_*, MY_VOICE_CLONE) are not valid for minimax TTS
|
||||
if is_voice_clone:
|
||||
logger.warning(f"[Podcast] ⚠️ Voice clone detected but no voice sample available - falling back to standard TTS with voice_id={effective_voice_id}")
|
||||
effective_custom_voice_id = request.custom_voice_id
|
||||
if effective_custom_voice_id and (
|
||||
effective_custom_voice_id.startswith("vc_") or
|
||||
effective_custom_voice_id == "MY_VOICE_CLONE"
|
||||
):
|
||||
logger.warning(f"[Podcast] Ignoring clone ID '{effective_custom_voice_id}' in standard TTS path - no voice sample URL available")
|
||||
effective_custom_voice_id = None
|
||||
|
||||
audio_service = get_podcast_audio_service(user_id)
|
||||
logger.warning(f"[Podcast] Standard TTS path: voice_id={effective_voice_id}, custom_voice_id={effective_custom_voice_id}")
|
||||
result: StoryAudioResult = audio_service.generate_ai_audio(
|
||||
scene_number=0,
|
||||
scene_title=request.scene_title,
|
||||
text=request.text.strip(),
|
||||
user_id=user_id,
|
||||
voice_id=effective_voice_id,
|
||||
custom_voice_id=effective_custom_voice_id,
|
||||
speed=request.speed or 1.0, # Normal speed (was 0.9, but too slow - causing duration issues)
|
||||
volume=request.volume or 1.0,
|
||||
pitch=request.pitch or 0.0, # Normal pitch (0.0 = neutral)
|
||||
emotion=request.emotion or "neutral",
|
||||
english_normalization=request.english_normalization or False,
|
||||
sample_rate=request.sample_rate,
|
||||
bitrate=request.bitrate,
|
||||
channel=request.channel,
|
||||
format=request.format,
|
||||
language_boost=request.language_boost,
|
||||
enable_sync_mode=request.enable_sync_mode,
|
||||
)
|
||||
|
||||
# Override URL to use podcast endpoint instead of story endpoint
|
||||
if result.get("audio_url") and "/api/story/audio/" in result.get("audio_url", ""):
|
||||
audio_filename = result.get("audio_filename", "")
|
||||
result["audio_url"] = f"/api/podcast/audio/{audio_filename}"
|
||||
|
||||
logger.warning(f"[Podcast] Audio generated - path: {result.get('audio_path')}, url: {result.get('audio_url')}")
|
||||
except HTTPException:
|
||||
raise
|
||||
# Override URL to use podcast endpoint instead of story endpoint
|
||||
if result.get("audio_url") and "/api/story/audio/" in result.get("audio_url", ""):
|
||||
audio_filename = result.get("audio_filename", "")
|
||||
result["audio_url"] = f"/api/podcast/audio/{audio_filename}"
|
||||
|
||||
logger.warning(f"[Podcast] Audio generated - path: {result.get('audio_path')}, url: {result.get('audio_url')}")
|
||||
except Exception as exc:
|
||||
exc_type = type(exc).__name__
|
||||
exc_msg = str(exc)[:500]
|
||||
logger.error(f"[Podcast] Audio generation failed ({exc_type}): {exc_msg}")
|
||||
logger.error(f"[Podcast] Audio generation traceback:", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Audio generation failed ({exc_type}): {exc_msg}")
|
||||
raise HTTPException(status_code=500, detail=f"Audio generation failed: {exc}")
|
||||
|
||||
# Save to asset library (podcast module)
|
||||
try:
|
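Review note on the voice clone branch in the hunk above: the test for a clone ID (`vc_` prefix or the `MY_VOICE_CLONE` placeholder) is repeated in several places, and the cost formula is inline. The sketch below gathers both into small helpers so the rules live in one place. The names `is_clone_voice_id` and `estimate_clone_cost` are hypothetical; the constants are copied from the diff.

```python
from typing import Optional

CLONE_PLACEHOLDER = "MY_VOICE_CLONE"  # placeholder value used in the diff
CLONE_PREFIX = "vc_"                  # prefix of cloned voice IDs in the diff


def is_clone_voice_id(voice_id: Optional[str]) -> bool:
    """True when the ID refers to a cloned voice rather than a stock TTS voice."""
    vid = voice_id or ""
    return vid.startswith(CLONE_PREFIX) or vid == CLONE_PLACEHOLDER


def estimate_clone_cost(text: str, rate_per_100_chars: float = 0.005,
                        minimum: float = 0.005) -> float:
    """Same formula as the diff: a per-100-character rate with a floor."""
    return max(minimum, rate_per_100_chars * (len(text) / 100.0))


short_cost = estimate_clone_cost("x" * 50)    # below the floor, so the minimum applies
long_cost = estimate_clone_cost("x" * 1000)   # ten blocks of 100 characters
```

With these helpers, `is_voice_clone` reduces to `request.use_voice_clone or is_clone_voice_id(request.custom_voice_id) or is_clone_voice_id(request.voice_id)`.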
||||
@@ -671,10 +391,7 @@ async def serve_podcast_audio(
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
logger.info(f"[Podcast] serve_podcast_audio: filename={filename}, user_id={user_id}")
|
||||
|
||||
audio_path = _resolve_podcast_media_file(filename, "audio", user_id)
|
||||
logger.info(f"[Podcast] Audio resolved path: {audio_path}, exists={audio_path.exists()}")
|
||||
logger.debug(f"[Podcast] serve_podcast_audio called: user_id={user_id}, filename={filename}")
|
||||
audio_path = _resolve_podcast_media_file(filename, "audio", user_id)
|
||||
logger.debug(f"[Podcast] Resolved audio path: {audio_path}")
|
||||
|
||||
|
||||
@@ -12,39 +12,22 @@ from pathlib import Path
|
||||
import uuid
|
||||
import hashlib
|
||||
|
||||
from services.database import get_db, get_session_for_user
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.llm_providers.main_image_editing import edit_image
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
from ..constants import get_podcast_media_dir, PODCAST_AVATARS_SUBDIR
|
||||
from ..constants import PODCAST_IMAGES_DIR
|
||||
from ..presenter_personas import choose_persona_id, get_persona
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Avatar subdirectory
|
||||
AVATAR_SUBDIR = PODCAST_AVATARS_SUBDIR
|
||||
|
||||
|
||||
async def _get_db_or_none(current_user: Dict[str, Any]):
|
||||
"""Try to get a database session, returning None on failure (non-fatal for uploads)."""
|
||||
try:
|
||||
user_id = current_user.get('id') or current_user.get('clerk_user_id')
|
||||
if not user_id:
|
||||
return None
|
||||
return get_session_for_user(user_id)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] DB session unavailable (non-fatal): {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_podcast_avatars_dir(user_id: str) -> Path:
|
||||
"""Get podcast avatars directory for a user (workspace-aware)."""
|
||||
avatars_dir = get_podcast_media_dir("image", user_id, ensure_exists=True) / AVATAR_SUBDIR
|
||||
avatars_dir.mkdir(parents=True, exist_ok=True)
|
||||
return avatars_dir
|
||||
AVATAR_SUBDIR = "avatars"
|
||||
PODCAST_AVATARS_DIR = PODCAST_IMAGES_DIR / AVATAR_SUBDIR
|
||||
PODCAST_AVATARS_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
@router.post("/avatar/upload")
|
||||
@@ -58,16 +41,8 @@ async def upload_podcast_avatar(
|
||||
Upload a presenter avatar image for a podcast project.
|
||||
Returns the avatar URL for use in scene image generation.
|
||||
"""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Avatar upload auth failed: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=401, detail="Authentication failed")
|
||||
|
||||
logger.info(f"[Podcast] Avatar upload request - user_id={user_id}, project_id={project_id}, content_type={file.content_type}")
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Validate file type
|
||||
if not file.content_type or not file.content_type.startswith('image/'):
|
||||
raise HTTPException(status_code=400, detail="File must be an image")
|
||||
@@ -82,21 +57,19 @@ async def upload_podcast_avatar(
|
||||
file_ext = Path(file.filename).suffix or '.png'
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
avatar_filename = f"avatar_{project_id or 'temp'}_{unique_id}{file_ext}"
|
||||
avatars_dir = _get_podcast_avatars_dir(user_id)
|
||||
logger.info(f"[Podcast] Saving avatar to: {avatars_dir / avatar_filename}")
|
||||
avatar_path = avatars_dir / avatar_filename
|
||||
avatar_path = PODCAST_AVATARS_DIR / avatar_filename
|
||||
|
||||
# Save file
|
||||
with open(avatar_path, "wb") as f:
|
||||
f.write(file_content)
|
||||
|
||||
logger.info(f"[Podcast] Avatar uploaded successfully: {avatar_path}")
|
||||
logger.info(f"[Podcast] Avatar uploaded: {avatar_path}")
|
||||
|
||||
# Create avatar URL
|
||||
avatar_url = f"/api/podcast/images/{AVATAR_SUBDIR}/{avatar_filename}"
|
||||
|
||||
# Save to asset library if project_id provided and DB session available
|
||||
if project_id and db:
|
||||
# Save to asset library if project_id provided
|
||||
if project_id:
|
||||
try:
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
@@ -118,17 +91,13 @@ async def upload_podcast_avatar(
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] Failed to save avatar asset (non-fatal): {e}")
|
||||
elif project_id and not db:
|
||||
logger.warning(f"[Podcast] DB session unavailable, skipping asset library save for avatar")
|
||||
logger.warning(f"[Podcast] Failed to save avatar asset: {e}")
|
||||
|
||||
return {
|
||||
"avatar_url": avatar_url,
|
||||
"avatar_filename": avatar_filename,
|
||||
"message": "Avatar uploaded successfully"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast] Avatar upload failed: {exc}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Avatar upload failed: {str(exc)}")
|
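The upload handler in this hunk builds the stored filename from the project id, a short random suffix, and the extension of the uploaded file. A minimal sketch of that naming step follows. `build_avatar_filename` is a hypothetical helper that does not appear in the diff, and keeping only the final path component of the client-supplied name is an extra precaution assumed here, not something the diff does.

```python
import uuid
from pathlib import Path
from typing import Optional


def build_avatar_filename(original_name: Optional[str], project_id: Optional[str]) -> str:
    """Hypothetical helper mirroring the naming scheme used by upload_podcast_avatar."""
    # Assumption: keep only the last path component of the client-supplied name
    # before reading its extension; fall back to .png when there is none.
    suffix = Path(Path(original_name or "").name).suffix or ".png"
    # First 8 characters of a random UUID4 string (all hexadecimal digits).
    unique_id = str(uuid.uuid4())[:8]
    return f"avatar_{project_id or 'temp'}_{unique_id}{suffix}"
```

The 8-character suffix keeps names short but is not collision-proof; a caller that cannot tolerate an overwrite could check for an existing file before writing.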
||||
@@ -155,7 +124,7 @@ async def make_avatar_presentable(
|
||||
# Load the uploaded avatar image
|
||||
from ..utils import load_podcast_image_bytes
|
||||
logger.info(f"[Podcast] Loading avatar image from {avatar_url}")
|
||||
avatar_bytes = load_podcast_image_bytes(avatar_url, user_id=user_id)
|
||||
avatar_bytes = load_podcast_image_bytes(avatar_url)
|
||||
logger.info(f"[Podcast] Avatar loaded successfully - size={len(avatar_bytes)} bytes")
|
||||
|
||||
logger.info(f"[Podcast] Transforming avatar to podcast presenter for project {project_id}")
|
||||
@@ -194,8 +163,7 @@ async def make_avatar_presentable(
|
||||
# Save transformed avatar
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
transformed_filename = f"presenter_transformed_{project_id or 'temp'}_{unique_id}.png"
|
||||
avatars_dir = _get_podcast_avatars_dir(user_id)
|
||||
transformed_path = avatars_dir / transformed_filename
|
||||
transformed_path = PODCAST_AVATARS_DIR / transformed_filename
|
||||
|
||||
with open(transformed_path, "wb") as f:
|
||||
f.write(result.image_bytes)
|
||||
@@ -377,8 +345,7 @@ async def generate_podcast_presenters(
|
||||
# Save avatar
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
avatar_filename = f"presenter_{project_id or 'temp'}_{i+1}_{unique_id}.png"
|
||||
avatars_dir = _get_podcast_avatars_dir(user_id)
|
||||
avatar_path = avatars_dir / avatar_filename
|
||||
avatar_path = PODCAST_AVATARS_DIR / avatar_filename
|
||||
|
||||
with open(avatar_path, "wb") as f:
|
||||
f.write(result.image_bytes)
|
||||
|
||||
@@ -4,125 +4,19 @@ B-Roll Handlers
|
||||
API endpoints for B-roll chart preview and video generation.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, BackgroundTasks
|
||||
from fastapi.responses import FileResponse
|
||||
from typing import Dict, Any, Optional, List
|
||||
from pydantic import BaseModel, Field
|
||||
from pathlib import Path
|
||||
import uuid
|
||||
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from api.story_writer.task_manager import task_manager
|
||||
from api.podcast.utils import _resolve_podcast_media_file
|
||||
from services.podcast.broll_service import get_broll_service
|
||||
from utils.media_utils import resolve_media_path
|
||||
from loguru import logger
|
||||
|
||||
|
||||
router = APIRouter(prefix="/broll", tags=["B-Roll"])
|
||||
|
||||
|
||||
def _resolve_broll_background_image_path(background_image_url: str) -> str:
|
||||
"""Resolve background image URL/path to a local file path."""
|
||||
resolved = resolve_media_path(background_image_url)
|
||||
if not resolved:
|
||||
raise HTTPException(status_code=404, detail=f"Background image not found: {background_image_url}")
|
||||
return str(resolved)
|
||||
|
||||
|
||||
def _resolve_broll_avatar_video_path(avatar_video_url: Optional[str], user_id: str) -> Optional[str]:
|
||||
"""Resolve optional avatar video URL/path to a local file path."""
|
||||
if not avatar_video_url:
|
||||
return None
|
||||
|
||||
parsed = urlparse(avatar_video_url)
|
||||
path = parsed.path if parsed.scheme else avatar_video_url
|
||||
|
||||
if "/api/podcast/videos/" in path:
|
||||
filename = path.split("/api/podcast/videos/", 1)[1].split("?", 1)[0].strip()
|
||||
if not filename:
|
||||
raise HTTPException(status_code=400, detail="Invalid avatar video URL")
|
||||
return str(_resolve_podcast_media_file(filename, "video", user_id))
|
||||
|
||||
local_path = Path(path).expanduser().resolve()
|
||||
if local_path.exists() and local_path.is_file():
|
||||
return str(local_path)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
"Unsupported avatar video URL format. "
|
||||
"Use /api/podcast/videos/{filename} or a valid local file path."
|
||||
),
|
||||
)
|
||||
|
||||
|
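The URL handling in `_resolve_broll_avatar_video_path` can be illustrated in isolation. The sketch below uses only the standard library; `extract_video_filename` and `VIDEO_PREFIX` are hypothetical names, and the rejection of separators and `..` in the extracted name is an added assumption rather than behavior shown in the diff (there, the filename is handed to `_resolve_podcast_media_file`, whose checks are not visible here).

```python
from typing import Optional
from urllib.parse import urlparse

VIDEO_PREFIX = "/api/podcast/videos/"  # route prefix taken from the diff


def extract_video_filename(avatar_video_url: str) -> Optional[str]:
    """Return the filename part of a podcast video URL, or None if it does not match."""
    parsed = urlparse(avatar_video_url)
    # For absolute URLs use only the path component; otherwise treat the input as a path.
    path = parsed.path if parsed.scheme else avatar_video_url
    if VIDEO_PREFIX not in path:
        return None
    # Drop any query string that survived (relative inputs keep it in `path`).
    filename = path.split(VIDEO_PREFIX, 1)[1].split("?", 1)[0].strip()
    # Assumption: reject empty names and anything that could leave the media directory.
    if not filename or "/" in filename or "\\" in filename or ".." in filename:
        return None
    return filename
```

Returning `None` instead of raising keeps the helper framework-agnostic; an endpoint would translate `None` into the 400 response used in the diff.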
||||
def _execute_broll_scene_task(
|
||||
task_id: str,
|
||||
*,
|
||||
scene_id: str,
|
||||
key_insight: str,
|
||||
supporting_stat: str,
|
||||
chart_data: Optional[Dict[str, Any]],
|
||||
visual_cue: str,
|
||||
duration: float,
|
||||
background_img_path: str,
|
||||
avatar_video_path: Optional[str],
|
||||
):
|
||||
"""Background task for rendering a B-roll scene."""
|
||||
try:
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=10.0,
|
||||
message="Starting B-roll scene render...",
|
||||
)
|
||||
|
||||
broll_service = get_broll_service()
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=35.0,
|
||||
message="Composing scene layers and overlays...",
|
||||
)
|
||||
|
||||
video_path = broll_service.generate_scene_broll(
|
||||
scene_id=scene_id,
|
||||
key_insight=key_insight,
|
||||
supporting_stat=supporting_stat,
|
||||
chart_data=chart_data,
|
||||
visual_cue=visual_cue,
|
||||
duration=duration,
|
||||
background_img_path=background_img_path,
|
||||
avatar_video_path=avatar_video_path,
|
||||
)
|
||||
|
||||
filename = Path(video_path).name
|
||||
video_url = f"/api/podcast/broll/final/{filename}"
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"completed",
|
||||
progress=100.0,
|
||||
message="B-roll scene render completed.",
|
||||
result={
|
||||
"scene_id": scene_id,
|
||||
"broll_video_path": video_path,
|
||||
"broll_video_url": video_url,
|
||||
},
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Broll] Task {task_id} failed: {exc}")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"failed",
|
||||
error=f"B-roll scene render failed: {str(exc)}",
|
||||
error_status=500,
|
||||
)
|
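The background-task pattern used by `_execute_broll_scene_task` (report progress, run the render, record either a result or a failure) can be sketched without the project's dependencies. `InMemoryTaskManager` below is a hypothetical stand-in written for this example; it only imitates the `create_task` and `update_task_status` calls seen in the diff and is not the real `task_manager` implementation.

```python
import uuid
from typing import Any, Callable, Dict, Optional


class InMemoryTaskManager:
    """Hypothetical stand-in for the project's task_manager, for illustration only."""

    def __init__(self) -> None:
        self.tasks: Dict[str, Dict[str, Any]] = {}

    def create_task(self, task_type: str, metadata: Optional[Dict[str, Any]] = None) -> str:
        task_id = uuid.uuid4().hex
        self.tasks[task_id] = {
            "type": task_type,
            "status": "pending",
            "progress": 0.0,
            "metadata": metadata or {},
        }
        return task_id

    def update_task_status(self, task_id: str, status: str, **fields: Any) -> None:
        # Merge the new status and any extra fields (progress, result, error, ...).
        self.tasks[task_id].update(status=status, **fields)


def run_render_task(manager: InMemoryTaskManager, task_id: str, render: Callable[[], str]) -> None:
    """Run `render` and record the outcome on the task instead of raising."""
    try:
        manager.update_task_status(task_id, "processing", progress=10.0)
        video_path = render()
        manager.update_task_status(
            task_id, "completed", progress=100.0, result={"broll_video_path": video_path}
        )
    except Exception as exc:
        # A background task has no HTTP response to fail, so the error is stored for polling.
        manager.update_task_status(task_id, "failed", error=f"B-roll scene render failed: {exc}")
```

The endpoint then returns the task id immediately and the client polls the status route, which is why failures are stored on the task rather than raised.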
||||
router = APIRouter()
|
||||
|
||||
|
||||
class ChartPreviewRequest(BaseModel):
|
||||
@@ -148,7 +42,7 @@ class BrollSceneRequest(BaseModel):
|
||||
key_insight: str
|
||||
supporting_stat: str
|
||||
chart_data: Optional[Dict[str, Any]] = None
|
||||
visual_cue: str = Field(default="bar_comparison", description="bar_comparison | bar_horizontal | line_trend | pie | stacked_bar | bullet_points | full_avatar")
|
||||
visual_cue: str = Field(default="bar_chart_comparison", description="bar_chart_comparison | bullet_points")
|
||||
duration: float = Field(default=10.0, ge=3.0, le=60.0)
|
||||
background_image_url: str
|
||||
avatar_video_url: Optional[str] = None
|
||||
@@ -157,11 +51,8 @@ class BrollSceneRequest(BaseModel):
|
||||
class BrollSceneResponse(BaseModel):
|
||||
"""Response for B-roll scene generation."""
|
||||
scene_id: str
|
||||
broll_video_url: str = ""
|
||||
broll_video_path: str = ""
|
||||
task_id: Optional[str] = None
|
||||
status: str = "completed"
|
||||
message: Optional[str] = None
|
||||
broll_video_url: str
|
||||
broll_video_path: str
|
||||
|
||||
|
||||
class BrollComposeRequest(BaseModel):
|
||||
@@ -191,34 +82,21 @@ async def generate_chart_preview(
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Debug logging
|
||||
logger.warning(f"[Broll] Chart preview request: type={request.chart_type}, title={request.title}, chart_data keys={list(request.chart_data.keys())}, user_id={user_id}")
|
||||
|
||||
try:
|
||||
broll_service = get_broll_service(user_id=user_id)
|
||||
chart_id = uuid.uuid4().hex[:8]
|
||||
broll_service = get_broll_service()
|
||||
|
||||
preview_path = broll_service.generate_chart_preview(
|
||||
chart_data=request.chart_data,
|
||||
chart_type=request.chart_type,
|
||||
title=request.title,
|
||||
subtitle=request.subtitle or "",
|
||||
chart_id=chart_id,
|
||||
)
|
||||
|
||||
# If chart generation failed (empty path), return a placeholder instead of 500
|
||||
if not preview_path:
|
||||
# Return a fallback response so frontend doesn't crash
|
||||
logger.warning(f"[Broll] Chart preview skipped - invalid data for type: {request.chart_type}")
|
||||
return ChartPreviewResponse(
|
||||
preview_url="",
|
||||
chart_id=chart_id,
|
||||
)
|
||||
raise HTTPException(status_code=500, detail="Failed to generate chart preview")
|
||||
|
||||
preview_filename = Path(preview_path).name
|
||||
preview_url = f"/api/podcast/broll/preview/{chart_id}/{preview_filename}"
|
||||
|
||||
logger.warning(f"[Broll] Chart preview generated: chart_id={chart_id}, path={preview_path}, url={preview_url}")
|
||||
chart_id = uuid.uuid4().hex[:8]
|
||||
preview_url = f"/api/podcast/broll/preview/{chart_id}/{preview_path.split('/')[-1]}"
|
||||
|
||||
return ChartPreviewResponse(
|
||||
preview_url=preview_url,
|
||||
@@ -251,42 +129,23 @@ async def generate_broll_scene(
|
||||
|
||||
try:
|
||||
# Validate visual_cue
|
||||
valid_cues = ["bar_comparison", "bar_chart_comparison", "bar_horizontal", "line_trend", "pie", "stacked_bar", "bullet_points", "full_avatar"]
|
||||
valid_cues = ["bar_chart_comparison", "bullet_points", "full_avatar"]
|
||||
if request.visual_cue not in valid_cues:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Invalid visual_cue. Must be one of: {valid_cues}"
|
||||
)
|
||||
|
||||
background_img_path = _resolve_broll_background_image_path(request.background_image_url)
|
||||
avatar_video_path = _resolve_broll_avatar_video_path(request.avatar_video_url, user_id)
|
||||
|
||||
# For now, return a placeholder - full video generation requires
|
||||
# resolving image/video URLs to actual file paths
|
||||
# In V2, this will integrate with the actual video generation
|
||||
|
||||
logger.info(f"[Broll] B-roll scene request for scene: {request.scene_id}")
|
||||
|
||||
# Scene rendering can be expensive, so use task manager/background execution.
|
||||
task_id = task_manager.create_task(
|
||||
"podcast_broll_scene_generation",
|
||||
metadata={"owner_user_id": user_id, "scene_id": request.scene_id},
|
||||
)
|
||||
|
||||
background_tasks.add_task(
|
||||
_execute_broll_scene_task,
|
||||
task_id=task_id,
|
||||
scene_id=request.scene_id,
|
||||
key_insight=request.key_insight,
|
||||
supporting_stat=request.supporting_stat,
|
||||
chart_data=request.chart_data,
|
||||
visual_cue=request.visual_cue,
|
||||
duration=request.duration,
|
||||
background_img_path=background_img_path,
|
||||
avatar_video_path=avatar_video_path,
|
||||
)
|
||||
|
||||
|
||||
return BrollSceneResponse(
|
||||
scene_id=request.scene_id,
|
||||
task_id=task_id,
|
||||
status="pending",
|
||||
message="B-roll scene render started. Poll /api/podcast/task/{task_id}/status for progress.",
|
||||
broll_video_url="",
|
||||
broll_video_path="",
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
@@ -335,35 +194,19 @@ async def compose_broll_videos(
|
||||
async def serve_chart_preview(
|
||||
chart_id: str,
|
||||
filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Serve chart preview PNG files.
|
||||
"""Serve chart preview PNG files."""
|
||||
from pathlib import Path
|
||||
|
||||
Uses authentication via Authorization header or token query parameter,
|
||||
matching the pattern used by /api/podcast/images/ for browser <img> tags.
|
||||
"""
|
||||
from api.podcast.constants import get_podcast_media_dir
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Validate filename to prevent directory traversal
|
||||
if ".." in filename or "/" in filename or "\\" in filename:
|
||||
raise HTTPException(status_code=400, detail="Invalid filename")
|
||||
|
||||
logger.warning(f"[Broll] serve_chart_preview: chart_id={chart_id}, filename={filename}, user_id={user_id}")
|
||||
|
||||
charts_dir = get_podcast_media_dir("chart", user_id)
|
||||
file_path = charts_dir / filename
|
||||
|
||||
logger.warning(f"[Broll] serve_chart_preview: resolved path={file_path}, exists={file_path.exists()}")
|
||||
broll_service = get_broll_service()
|
||||
file_path = broll_service.output_dir / f"chart_preview_{chart_id}.png"
|
||||
|
||||
if not file_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Chart preview not found")
|
||||
|
||||
# Security: ensure resolved path is within charts_dir
|
||||
if not str(file_path.resolve()).startswith(str(charts_dir.resolve())):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
return FileResponse(
|
||||
path=str(file_path),
|
||||
media_type="image/png",
|
||||
@@ -395,4 +238,4 @@ async def serve_final_broll(
|
||||
@router.get("/health")
|
||||
async def broll_health():
|
||||
"""Health check for B-roll service."""
|
||||
return {"status": "ok", "service": "broll"}
|
||||
return {"status": "ok", "service": "broll"}
|
||||
@@ -17,7 +17,7 @@ from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_image_generation import generate_image, generate_character_image
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
from ..constants import get_podcast_media_dir
|
||||
from ..constants import PODCAST_IMAGES_DIR
|
||||
from ..models import PodcastImageRequest, PodcastImageResponse
|
||||
|
||||
router = APIRouter()
|
||||
@@ -69,7 +69,7 @@ async def generate_podcast_scene_image(
|
||||
from ..utils import load_podcast_image_bytes
|
||||
try:
|
||||
logger.info(f"[Podcast] Attempting to load base avatar from: {request.base_avatar_url}")
|
||||
base_avatar_bytes = load_podcast_image_bytes(request.base_avatar_url, user_id=user_id)
|
||||
base_avatar_bytes = load_podcast_image_bytes(request.base_avatar_url)
|
||||
logger.info(f"[Podcast] ✅ Successfully loaded base avatar ({len(base_avatar_bytes)} bytes) for scene {request.scene_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] ❌ Failed to load base avatar from {request.base_avatar_url}: {e}", exc_info=True)
|
||||
@@ -377,14 +377,14 @@ async def generate_podcast_scene_image(
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Save image to podcast images directory (workspace-aware)
|
||||
images_dir = get_podcast_media_dir("image", user_id, ensure_exists=True)
|
||||
# Save image to podcast images directory
|
||||
PODCAST_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Generate filename
|
||||
clean_title = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in request.scene_title[:30])
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
image_filename = f"scene_{request.scene_id}_{clean_title}_{unique_id}.png"
|
||||
image_path = images_dir / image_filename
|
||||
image_path = PODCAST_IMAGES_DIR / image_filename
|
||||
|
||||
# Save image
|
||||
with open(image_path, "wb") as f:
|
||||
@@ -470,17 +470,16 @@ async def serve_podcast_image(
|
||||
Query parameter is useful for HTML elements like <img> that cannot send custom headers.
|
||||
Supports subdirectories like avatars/
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
require_authenticated_user(current_user)
|
||||
|
||||
# Security check: ensure path doesn't contain path traversal or absolute paths
|
||||
if ".." in path or path.startswith("/"):
|
||||
raise HTTPException(status_code=400, detail="Invalid path")
|
||||
|
||||
images_dir = get_podcast_media_dir("image", user_id)
|
||||
image_path = (images_dir / path).resolve()
|
||||
image_path = (PODCAST_IMAGES_DIR / path).resolve()
|
||||
|
||||
# Security check: ensure resolved path is within images_dir
|
||||
if not str(image_path).startswith(str(images_dir)):
|
||||
# Security check: ensure resolved path is within PODCAST_IMAGES_DIR
|
||||
if not str(image_path).startswith(str(PODCAST_IMAGES_DIR)):
|
||||
raise HTTPException(status_code=403, detail="Access denied")
|
||||
|
||||
if not image_path.exists():
|
||||
|
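Both versions of `serve_podcast_image` check containment by comparing string prefixes. A string prefix test also accepts a sibling directory whose name merely starts with the same characters (for example `/data/images_private` starts with `/data/images`). A minimal sketch of a component-wise check follows; `resolve_within` is a hypothetical helper, not code from this repository.

```python
from pathlib import Path


def resolve_within(base_dir: Path, relative_path: str) -> Path:
    """Resolve relative_path under base_dir, raising ValueError if it escapes base_dir."""
    base = base_dir.resolve()
    candidate = (base / relative_path).resolve()
    try:
        # relative_to compares whole path components, not raw string prefixes.
        candidate.relative_to(base)
    except ValueError:
        raise ValueError(f"Path escapes base directory: {relative_path}") from None
    return candidate
```

An endpoint could catch the `ValueError` and answer with the same 403 the diff uses. An absolute `relative_path` is rejected as well, because joining it replaces the base and the component check then fails.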
||||
@@ -11,7 +11,6 @@ from typing import Optional, Dict, Any
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.podcast_service import PodcastService
|
||||
from loguru import logger
|
||||
from ..models import (
|
||||
PodcastProjectResponse,
|
||||
CreateProjectRequest,
|
||||
@@ -107,57 +106,25 @@ async def update_project(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Update a podcast project state."""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
logger.error(f"[Podcast] update_project: No user_id found in current_user: {current_user}")
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
# Get only field names being updated (not full data to avoid console flooding)
|
||||
request_dict = request.model_dump(exclude_none=True)
|
||||
updated_fields = list(request_dict.keys())
|
||||
|
||||
logger.warning(f"[Podcast] ===== UPDATE_PROJECT_START =====")
|
||||
logger.warning(f"[Podcast] project_id={project_id}, user_id={user_id}, fields={updated_fields}")
|
||||
|
||||
service = PodcastService(db)
|
||||
|
||||
# Check if project exists; if not, create it (upsert behavior for resilience)
|
||||
existing = service.get_project(user_id, project_id)
|
||||
if not existing:
|
||||
logger.warning(f"[Podcast] Project {project_id} not found for user {user_id}, creating new project with default values")
|
||||
# Try to create the project - this handles cases where create succeeded but wasn't found later
|
||||
# (can happen with user_id mismatch or after session refresh)
|
||||
try:
|
||||
project = service.create_project(
|
||||
user_id=user_id,
|
||||
project_id=project_id,
|
||||
idea="Untitled Podcast",
|
||||
status="scripting",
|
||||
duration=10,
|
||||
speakers=1,
|
||||
budget_cap=0.0,
|
||||
)
|
||||
except Exception as create_err:
|
||||
logger.error(f"[Podcast] Failed to create project {project_id}: {create_err}")
|
||||
raise HTTPException(status_code=404, detail=f"Project {project_id} not found and could not create: {create_err}")
|
||||
else:
|
||||
# Convert request to dict, excluding None values
|
||||
updates = request.model_dump(exclude_unset=True)
|
||||
project = service.update_project(user_id, project_id, **updates)
|
||||
# Convert request to dict, excluding None values
|
||||
updates = request.model_dump(exclude_unset=True)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[Podcast] ===== UPDATE_PROJECT_END (took {duration_ms}ms) =====")
|
||||
project = service.update_project(user_id, project_id, **updates)
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail=f"Project {project_id} not found")
|
||||
|
||||
return PodcastProjectResponse.model_validate(project)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"[Podcast] ===== UPDATE_PROJECT_ERROR (took {duration_ms}ms): {str(e)} =====")
|
||||
raise HTTPException(status_code=500, detail=f"Error updating project: {str(e)}")
|
||||
|
||||
|
||||
|
||||
@@ -9,13 +9,10 @@ from typing import Dict, Any, List
|
||||
from types import SimpleNamespace
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.database import get_db
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
@@ -23,7 +20,6 @@ from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
from loguru import logger
|
||||
from ..cost_estimator import estimate_podcast_cost
|
||||
from ..models import (
|
||||
PodcastExaResearchRequest,
|
||||
PodcastExaResearchResponse,
|
||||
@@ -64,7 +60,6 @@ def _build_research_cost_estimate(
|
||||
raw_content: str,
|
||||
sources_count: int,
|
||||
provider_result: Dict[str, Any],
|
||||
user_id: str = "default",
|
||||
) -> PodcastCostEst:
|
||||
# Fallback defaults mirror current catalog defaults.
|
||||
exa_per_request = 0.005
|
||||
@@ -72,19 +67,17 @@ def _build_research_cost_estimate(
|
||||
gemini_out_token = 0.0000006
|
||||
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
exa_per_request = _get_price_from_catalog(
|
||||
pricing_service, APIProvider.EXA, "exa-search", "cost_per_request", exa_per_request
|
||||
)
|
||||
gemini_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.GEMINI, "gemini-2.5-flash") or {}
|
||||
gemini_in_token = float(gemini_pricing.get("cost_per_input_token") or gemini_in_token)
|
||||
gemini_out_token = float(gemini_pricing.get("cost_per_output_token") or gemini_out_token)
|
||||
finally:
|
||||
db.close()
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
exa_per_request = _get_price_from_catalog(
|
||||
pricing_service, APIProvider.EXA, "exa-search", "cost_per_request", exa_per_request
|
||||
)
|
||||
gemini_pricing = pricing_service.get_pricing_for_provider_model(APIProvider.GEMINI, "gemini-2.5-flash") or {}
|
||||
gemini_in_token = float(gemini_pricing.get("cost_per_input_token") or gemini_in_token)
|
||||
gemini_out_token = float(gemini_pricing.get("cost_per_output_token") or gemini_out_token)
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as pricing_err:
|
||||
logger.warning(f"[Podcast Research] Failed loading pricing catalog; using defaults: {pricing_err}")
|
||||
|
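The pricing lookup in `_build_research_cost_estimate` falls back to hard-coded defaults with the `float(value or default)` idiom. The sketch below isolates that idiom; `price_or_default` is a hypothetical helper, and the extra handling of unparsable values is an assumption added for the example.

```python
from typing import Any, Dict, Optional


def price_or_default(pricing: Optional[Dict[str, Any]], key: str, default: float) -> float:
    """Return pricing[key] as a float, or default when missing, falsy, or unparsable."""
    try:
        # `or default` treats None, "", and 0 alike: all fall back to the default.
        return float((pricing or {}).get(key) or default)
    except (TypeError, ValueError):
        # Assumption: a malformed catalog entry should not break cost estimation.
        return default
```

One consequence of this idiom is that a catalog price of exactly 0 is indistinguishable from a missing price, so a free tier would still be estimated at the default rate.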
||||
@@ -133,18 +126,15 @@ def _build_research_cost_estimate(
|
||||
async def podcast_research_exa(
|
||||
request: PodcastExaResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Run podcast research via Exa and then use LLM to extract deep insights.
|
||||
Uses Podcast Bible and Analysis context for hyper-personalization.
|
||||
"""
|
||||
start_time = time.time()
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Log only essential info, not full request data
|
||||
logger.warning(f"[Podcast Research] ===== RESEARCH_START =====")
|
||||
logger.warning(f"[Podcast Research] user={user_id}, topic='{request.topic[:50]}...', queries={len(request.queries) if request.queries else 0}")
|
||||
logger.warning(f"[Podcast Research] ========== REQUEST START ==========")
|
||||
logger.warning(f"[Podcast Research] User: {user_id}, Topic: {request.topic[:80]}...")
|
||||
logger.warning(f"[Podcast Research] Queries count: {len(request.queries) if request.queries else 0}")
|
||||
|
||||
|
||||
queries = [q.strip() for q in request.queries if q and q.strip()]
|
||||
@@ -202,26 +192,6 @@ Listener CTA: {request.analysis.get('listener_cta', 'N/A')}
|
||||
interests = ", ".join(audience_dna.get("interests", []))
|
||||
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
|
||||
|
||||
# Preflight subscription check for Exa
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'exa', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[Podcast Research] Preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast Research] Preflight check failed: {e}")
|
||||
|
||||
try:
|
||||
# 1. RUN EXA SEARCH
|
||||
logger.warning(f"[Podcast Research] Calling Exa search with topic: {request.topic[:100]}...")
|
||||
@@ -244,9 +214,6 @@ Listener CTA: {request.analysis.get('listener_cta', 'N/A')}
|
||||
|
||||
summary = ""
|
||||
key_insights = []
|
||||
expert_quotes = []
|
||||
listener_cta_suggestions = []
|
||||
mapped_angles = []
|
||||
|
||||
if raw_content and sources:
|
||||
logger.warning(f"[Podcast Research] Extracting insights from {len(sources)} sources for user {user_id}")
|
||||
@@ -366,22 +333,13 @@ QUALITY STANDARDS:
|
||||
try:
|
||||
summary = data.get("summary", "")
|
||||
key_insights = [PodcastResearchInsight(**insight) for insight in data.get("key_insights", [])]
|
||||
expert_quotes = data.get("expert_quotes", [])
|
||||
listener_cta_suggestions = data.get("listener_cta_suggestions", [])
|
||||
mapped_angles = data.get("mapped_angles", [])
|
||||
except Exception as insight_err:
|
||||
logger.warning(f"[Podcast Research] Failed to parse insights: {insight_err}. Data keys: {list(data.keys()) if isinstance(data, dict) else 'not a dict'}")
|
||||
summary = data.get("summary", "") if isinstance(data, dict) else ""
|
||||
key_insights = []
|
||||
expert_quotes = data.get("expert_quotes", []) if isinstance(data, dict) else []
|
||||
listener_cta_suggestions = data.get("listener_cta_suggestions", []) if isinstance(data, dict) else []
|
||||
mapped_angles = data.get("mapped_angles", []) if isinstance(data, dict) else []
|
||||
else:
|
||||
summary = ""
|
||||
key_insights = []
|
||||
expert_quotes = []
|
||||
listener_cta_suggestions = []
|
||||
mapped_angles = []
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
@@ -433,24 +391,6 @@ QUALITY STANDARDS:
|
||||
"credibility_score": src.get("credibility_score"),
|
||||
}))
|
||||
|
||||
duration_minutes = 10
|
||||
speakers = 1
|
||||
if request.analysis:
|
||||
duration_minutes = int(request.analysis.get("duration", 10) or 10)
|
||||
speakers = int(request.analysis.get("speakers", 1) or 1)
|
||||
|
||||
estimate = estimate_podcast_cost(
|
||||
db=db,
|
||||
duration_minutes=duration_minutes,
|
||||
speakers=speakers,
|
||||
query_count=len(queries),
|
||||
include_avatar_phase=True,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[Podcast Research] ===== RESEARCH_END (took {duration_ms}ms) =====")
|
||||
logger.warning(f"[Podcast Research] sources={len(sources_payload)}, insights={len(key_insights)}, summary_len={len(summary)}")
|
||||
|
||||
return PodcastExaResearchResponse(
|
||||
sources=sources_payload,
|
||||
search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries,
|
||||
@@ -461,13 +401,8 @@ QUALITY STANDARDS:
|
||||
raw_content=raw_content,
|
||||
sources_count=len(sources_payload),
|
||||
provider_result=result if isinstance(result, dict) else {},
|
||||
user_id=user_id,
|
||||
),
|
||||
search_type=result.get("search_type") if isinstance(result, dict) else None,
|
||||
provider=result.get("provider", "exa") if isinstance(result, dict) else "exa",
|
||||
content=raw_content,
|
||||
mapped_angles=mapped_angles,
|
||||
expert_quotes=expert_quotes,
|
||||
listener_cta_suggestions=listener_cta_suggestions,
|
||||
estimate=estimate,
|
||||
)
|
||||
|
||||
@@ -8,8 +8,6 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
@@ -25,8 +23,6 @@ from ..models import (
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
MAX_TTS_CHARS_PER_REQUEST = 10_000
|
||||
TARGET_TTS_CHARS_PER_SCENE = 8_500
|
||||
|
||||
|
||||
class SceneApprovalRequest(BaseModel):
|
||||
@@ -61,46 +57,31 @@ async def generate_podcast_script(
|
||||
Generate a podcast script outline (scenes + lines) using podcast-oriented prompting.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
start_time = time.time()
|
||||
logger.warning(f"[ScriptGen] ===== SCRIPT_GEN_START =====")
|
||||
logger.warning(f"[ScriptGen] user={user_id}, topic='{request.idea[:50]}...', duration={request.duration_minutes}min, speakers={request.speakers}")
|
||||
podcast_mode = (request.podcast_mode or "video_only").strip().lower()
|
||||
logger.warning(f"[ScriptGen] research={bool(request.research)}, bible={bool(request.bible)}, analysis={bool(request.analysis)}, mode={podcast_mode}")
|
||||
research_fact_cards = request.research.get("factCards", []) if request.research else []
|
||||
logger.warning(f"[ScriptGen] ========== SCRIPT GENERATION START ==========")
|
||||
logger.warning(f"[ScriptGen] Topic: {request.idea[:60]}...")
|
||||
logger.warning(f"[ScriptGen] Duration: {request.duration_minutes} min, Speakers: {request.speakers}")
|
||||
logger.warning(f"[ScriptGen] Has research: {bool(request.research)}, Has bible: {bool(request.bible)}, Has analysis: {bool(request.analysis)}")
|
||||
|
||||
# Build comprehensive research context for higher-quality scripts
|
||||
research_context = ""
|
||||
if request.research:
|
||||
try:
|
||||
key_insights = request.research.get("keyword_analysis", {}).get("key_insights") or []
|
||||
fact_cards = research_fact_cards or []
|
||||
fact_cards = request.research.get("factCards", []) or []
|
||||
mapped_angles = request.research.get("mappedAngles", []) or []
|
||||
sources = request.research.get("sources", []) or []
|
||||
|
||||
top_facts = [
|
||||
f"[{f.get('id') or f'fact_{idx + 1}'}] {f.get('quote', '')}"
|
||||
for idx, f in enumerate(fact_cards[:10])
|
||||
if f.get("quote")
|
||||
]
|
||||
top_facts = [f.get("quote", "") for f in fact_cards[:5] if f.get("quote")]
|
||||
angles_summary = [
|
||||
f"{a.get('title', '')}: {a.get('why', '')}" for a in mapped_angles[:3] if a.get("title") or a.get("why")
|
||||
]
|
||||
top_sources = [s.get("url") for s in sources[:3] if s.get("url")]
|
||||
numeric_signals = []
|
||||
for f in fact_cards[:12]:
|
||||
quote = (f.get("quote") or "").strip()
|
||||
if any(ch.isdigit() for ch in quote):
|
||||
numeric_signals.append(quote[:180])
|
||||
if len(numeric_signals) >= 5:
|
||||
break
|
||||
|
||||
research_parts = []
|
||||
if key_insights:
|
||||
research_parts.append(f"Key Insights: {', '.join(key_insights[:5])}")
|
||||
if top_facts:
|
||||
research_parts.append(f"Key Facts: {', '.join(top_facts)}")
|
||||
if numeric_signals:
|
||||
research_parts.append(f"Numeric Signals (prefer for chart scenes): {' | '.join(numeric_signals)}")
|
||||
if angles_summary:
|
||||
research_parts.append(f"Research Angles: {' | '.join(angles_summary)}")
|
||||
if top_sources:
|
||||
@@ -111,53 +92,6 @@ async def generate_podcast_script(
|
||||
logger.warning(f"Failed to parse research context: {exc}")
|
||||
research_context = ""
|
||||
|
||||
def _normalize_fact_ids(value: Any) -> Optional[list[str]]:
|
||||
if not value:
|
||||
return None
|
||||
if isinstance(value, list):
|
||||
cleaned = [str(v).strip() for v in value if str(v).strip()]
|
||||
return cleaned or None
|
||||
if isinstance(value, str) and value.strip():
|
||||
return [value.strip()]
|
||||
return None
|
||||
|
||||
def _default_chart_data(scene_title: str) -> Dict[str, Any]:
|
||||
numeric_pairs: list[tuple[str, float]] = []
|
||||
for fact in research_fact_cards[:12]:
|
||||
quote = (fact.get("quote") or "").strip()
|
||||
if not quote:
|
||||
continue
|
||||
nums = re.findall(r"\d+(?:\.\d+)?", quote.replace(",", ""))
|
||||
if not nums:
|
||||
continue
|
||||
label = quote[:48] + ("…" if len(quote) > 48 else "")
|
||||
try:
|
||||
numeric_pairs.append((label, float(nums[0])))
|
||||
except ValueError:
|
||||
continue
|
||||
if len(numeric_pairs) >= 5:
|
||||
break
|
||||
|
||||
if numeric_pairs:
|
||||
labels = [p[0] for p in numeric_pairs]
|
||||
values = [p[1] for p in numeric_pairs]
|
||||
sources = [f.get("url", f.get("source", "")) for f in research_fact_cards[:12] if f.get("url") or f.get("source")]
|
||||
return {
|
||||
"type": "bar_comparison",
|
||||
"title": scene_title,
|
||||
"labels": labels,
|
||||
"values": values,
|
||||
"takeaway": "Data points sourced from research facts used in this scene.",
|
||||
"source": sources[0] if sources else "",
|
||||
}
|
||||
|
||||
return {
|
||||
"type": "bullet_points",
|
||||
"title": scene_title,
|
||||
"bullet_points": ["Key point 1", "Key point 2", "Key point 3"],
|
||||
"takeaway": "Narration summary for this scene.",
|
||||
}
|
||||
|
||||
# Extract Podcast Bible context for hyper-personalization
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
@@ -188,62 +122,25 @@ async def generate_podcast_script(
|
||||
except:
|
||||
pass
|
||||
|
||||
mode_instructions = ""
|
||||
if podcast_mode == "audio_only":
|
||||
mode_instructions = f"""
|
||||
AUDIO-ONLY MODE RULES (CRITICAL):
|
||||
- This is an audio-only episode. Do NOT include avatar/image/camera instructions.
|
||||
- Keep each scene's total dialogue under {TARGET_TTS_CHARS_PER_SCENE} chars to stay below TTS max request size ({MAX_TTS_CHARS_PER_REQUEST}).
|
||||
- For every scene include chart_data so B-roll charts can be generated while narration plays.
|
||||
- Build script STRICTLY from RESEARCH context and cite fact linkage via usedFactIds.
|
||||
- If evidence is weak, say uncertainty explicitly rather than inventing facts.
|
||||
- Add natural TTS pacing in dialogue with markers like [pause:300ms], [pause:700ms], [emote:curious], [emote:serious].
|
||||
"""
|
||||
elif podcast_mode == "audio_video":
|
||||
mode_instructions = """
|
||||
AUDIO+VIDEO MODE:
|
||||
- Include rich narration that works for both listening and visual storytelling.
|
||||
- Use a balanced pace suitable for TTS and scene visuals.
|
||||
"""
|
||||
else:
|
||||
mode_instructions = """
|
||||
VIDEO-ONLY MODE:
|
||||
- Prioritize visual rhythm and concise narration per scene.
|
||||
"""
|
||||
|
||||
prompt = f"""Create a podcast script with scenes and dialogue.
|
||||
|
||||
{f"BIBLE: {bible_context[:1500]}" if bible_context else ""}
|
||||
{f"{analysis_context}" if analysis_context else ""}
|
||||
{f"{outline_context}" if outline_context else ""}
|
||||
{f"RESEARCH: {research_context[:2500]}" if research_context else ""}
|
||||
{mode_instructions}
|
||||
{f"RESEARCH: {research_context[:1200]}" if research_context else ""}
|
||||
|
||||
Topic: "{request.idea}"
|
||||
Duration: {request.duration_minutes} min | Speakers: {request.speakers}
|
||||
Podcast mode: {podcast_mode}
|
||||
|
||||
Return JSON with scenes array. Each scene:
|
||||
- id: string
|
||||
- title: short title (<=50 chars)
|
||||
- duration: seconds (total/5)
|
||||
- emotion: neutral|happy|excited|serious|curious|confident
|
||||
- lines: array of {{speaker, text, emphasis, usedFactIds, ttsHints}}
|
||||
- lines: array of {{speaker, text, emphasis}}
|
||||
- Use 2-4 LINES PER SCENE (shorter script = lower TTS costs)
|
||||
- Each line: 1-3 sentences, conversational
|
||||
- usedFactIds: include related fact ids when research facts are available (example: ["fact_1", "fact_3"])
|
||||
- ttsHints: optional list from [pause_300ms, pause_700ms, smile, serious_tone, emphasize_data]
|
||||
- Plain text only, no markdown
|
||||
- chart_data: object for B-roll mapping (required in audio_only)
|
||||
- type: bar_comparison|bar_horizontal|line_trend|pie|stacked_bar|bullet_points
|
||||
- title: short chart title
|
||||
- labels: list
|
||||
- values: list (same length as labels, required for bar/line/pie)
|
||||
- before/after: parallel lists of numbers (for bar_comparison only)
|
||||
- segments: list of {{name, values}} (for stacked_bar only)
|
||||
- bullet_points: list of strings (for bullet_points only)
|
||||
- takeaway: one sentence tying chart to narration
|
||||
- source: URL or citation for the data (e.g. "Research fact #3" or a URL from the research context)
|
||||
|
||||
COST OPTIMIZATION:
|
||||
- 5-6 scenes max for {request.duration_minutes} min episode
|
||||
@@ -334,8 +231,7 @@ COST OPTIMIZATION:
|
||||
line_id = line.get("id") or f"line-{idx + 1}-{line_idx + 1}"
|
||||
|
||||
# Get used fact IDs if provided
|
||||
used_fact_ids = _normalize_fact_ids(line.get("usedFactIds") or line.get("used_fact_ids"))
|
||||
tts_hints = line.get("ttsHints") or line.get("tts_hints") or None
|
||||
used_fact_ids = line.get("usedFactIds") or line.get("used_fact_ids") or None
|
||||
|
||||
if text:
|
||||
lines.append(PodcastSceneLine(
|
||||
@@ -343,8 +239,7 @@ COST OPTIMIZATION:
|
||||
text=text,
|
||||
emphasis=emphasis,
|
||||
id=line_id,
|
||||
usedFactIds=used_fact_ids,
|
||||
ttsHints=tts_hints if isinstance(tts_hints, list) else None,
|
||||
usedFactIds=used_fact_ids
|
||||
))
|
||||
total_lines_output += 1
|
||||
else:
|
||||
@@ -360,33 +255,6 @@ COST OPTIMIZATION:
|
||||
if audio_url_raw:
|
||||
logger.warning(f"[ScriptGen] Scene {idx} has audioUrl - will be reset to None")
|
||||
|
||||
# Keep each scene under TTS request size to prevent failures
|
||||
scene_char_count = sum(len((l.text or "").strip()) for l in lines)
|
||||
if scene_char_count > TARGET_TTS_CHARS_PER_SCENE and lines:
|
||||
logger.warning(
|
||||
f"[ScriptGen] Scene {idx} text too long ({scene_char_count} chars). "
|
||||
f"Trimming to {TARGET_TTS_CHARS_PER_SCENE} target."
|
||||
)
|
||||
trimmed_lines: list[PodcastSceneLine] = []
|
||||
remaining = TARGET_TTS_CHARS_PER_SCENE
|
||||
for l in lines:
|
||||
if remaining <= 0:
|
||||
break
|
||||
line_text = (l.text or "").strip()
|
||||
if len(line_text) <= remaining:
|
||||
trimmed_lines.append(l)
|
||||
remaining -= len(line_text)
|
||||
continue
|
||||
l.text = f"{line_text[:max(0, remaining - 1)].rstrip()}…"
|
||||
trimmed_lines.append(l)
|
||||
remaining = 0
|
||||
lines = trimmed_lines
|
||||
|
||||
chart_data = scene.get("chart_data") or scene.get("chartData") or None
|
||||
if podcast_mode == "audio_only" and not chart_data:
|
||||
# Ensure audio-only always has a B-roll mapping fallback
|
||||
chart_data = _default_chart_data(title)
|
||||
|
||||
scenes.append(
|
||||
PodcastScene(
|
||||
id=scene.get("id") or f"scene-{idx + 1}",
|
||||
@@ -398,7 +266,6 @@ COST OPTIMIZATION:
|
||||
imageUrl=None, # Will be generated later
|
||||
audioUrl=None, # Will be generated later
|
||||
imagePrompt=None, # Will be generated during image generation
|
||||
chart_data=chart_data if isinstance(chart_data, dict) else None,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -406,8 +273,6 @@ COST OPTIMIZATION:
|
||||
logger.warning(f"[ScriptGen] Script generated: {len(scenes)} scenes, {total_lines_output}/{total_lines_input} lines")
|
||||
if dropped_empty_lines > 0:
|
||||
logger.warning(f"[ScriptGen] Dropped {dropped_empty_lines} empty lines")
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(f"[ScriptGen] ===== SCRIPT_GEN_END (took {duration_ms}ms) =====")
|
||||
|
||||
return PodcastScriptResponse(scenes=scenes)
|
||||
|
||||
|
||||
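Reviewer note: the scene-trimming loop shown in the hunk above can be read as a character-budget pass over a scene's lines. The following is a minimal standalone sketch, assuming plain strings instead of `PodcastSceneLine` objects. `trim_lines_to_budget` and its parameters are hypothetical names for illustration and do not exist in the codebase.

```python
from typing import List


def trim_lines_to_budget(texts: List[str], budget: int) -> List[str]:
    """Keep whole lines while they fit, truncate the first overflowing line, then stop.

    Hypothetical helper: mirrors the loop in the diff, but works on plain
    strings and returns new values instead of mutating line objects.
    """
    kept: List[str] = []
    remaining = budget
    for raw in texts:
        if remaining <= 0:
            break
        text = (raw or "").strip()
        if len(text) <= remaining:
            kept.append(text)
            remaining -= len(text)
            continue
        # Reserve one character for the ellipsis so the cut line stays in budget.
        kept.append(text[: max(0, remaining - 1)].rstrip() + "…")
        remaining = 0
    return kept


if __name__ == "__main__":
    print(trim_lines_to_budget(["hello", "world wide"], 8))
```

One property worth noting: every line after the first truncated one is dropped entirely, so a scene that overflows early can lose whole speaker turns, not just trailing characters.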
@@ -1,338 +0,0 @@
|
||||
"""
|
||||
Category Research Handlers
|
||||
|
||||
Research endpoints using Tavily or Exa for category-based topic discovery.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from types import SimpleNamespace
|
||||
from sqlalchemy import text
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.research.tavily_service import TavilyService
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
router = APIRouter(prefix="/research", tags=["Podcast Category Research"])
|
||||
|
||||
CATEGORY_PROVIDER_MAP = {
|
||||
"news": "tavily",
|
||||
"finance": "tavily",
|
||||
"research-paper": "exa",
|
||||
"personal-site": "exa",
|
||||
}
|
||||
|
||||
EXA_CATEGORY_MAP = {
|
||||
"research-paper": "research paper",
|
||||
"personal-site": "personal site",
|
||||
}
|
||||
|
||||
|
||||
def _preflight_check(user_id: str, provider: APIProvider, provider_name: str):
|
||||
"""Check subscription limits before making a research API call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=0,
|
||||
actual_provider_name=provider_name,
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': provider_name, 'usage_info': usage_info or {}
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[CategoryResearch] Preflight check failed for {provider_name}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_research_usage(user_id: str, provider_name: str, cost: float, calls_column: str, cost_column: str):
|
||||
"""Track research API usage after successful call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[CategoryResearch] Could not get DB session for user {user_id}")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
update_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {calls_column} = COALESCE({calls_column}, 0) + 1,
|
||||
{cost_column} = COALESCE({cost_column}, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[CategoryResearch] Tracked {provider_name} usage: user={user_id}, cost=${cost}")
|
||||
|
||||
# Clear dashboard cache so header stats update immediately
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[CategoryResearch] Failed to clear dashboard cache: {cache_err}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Failed to track {provider_name} usage: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class CategoryResearchRequest(BaseModel):
|
||||
category: str
|
||||
keyword: Optional[str] = None
|
||||
max_results: Optional[int] = 8
|
||||
website_url: Optional[str] = None
|
||||
|
||||
|
||||
class CategoryTopic(BaseModel):
|
||||
title: str
|
||||
url: str
|
||||
snippet: str
|
||||
score: float
|
||||
favicon: Optional[str] = None
|
||||
|
||||
|
||||
class CategoryResearchResponse(BaseModel):
|
||||
success: bool
|
||||
category: str
|
||||
provider: str
|
||||
topics: List[CategoryTopic]
|
||||
query: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def _normalize_tavily_results(results: List[Dict]) -> List[CategoryTopic]:
|
||||
topics = []
|
||||
for item in results:
|
||||
topics.append(CategoryTopic(
|
||||
title=item.get("title", ""),
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("content", ""),
|
||||
score=item.get("score", 0.0),
|
||||
favicon=item.get("favicon"),
|
||||
))
|
||||
return topics
|
||||
|
||||
|
||||
def _normalize_exa_results(results: List[Dict], query: str) -> List[CategoryTopic]:
|
||||
topics = []
|
||||
for idx, item in enumerate(results):
|
||||
score = 1.0 - (idx * 0.1)
|
||||
topics.append(CategoryTopic(
|
||||
title=item.get("title", "") or f"Result {idx + 1}",
|
||||
url=item.get("url", ""),
|
||||
snippet=item.get("summary", "") or item.get("text", "") or "",
|
||||
score=max(0.5, score),
|
||||
favicon=None,
|
||||
))
|
||||
return topics
|
||||
|
||||
|
||||
async def _search_tavily(category: str, keyword: str, max_results: int, user_id: str) -> CategoryResearchResponse:
|
||||
logger.info(f"[CategoryResearch] Using Tavily for category={category}, keyword={keyword}")
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.TAVILY, "tavily")
|
||||
|
||||
try:
|
||||
tavily = TavilyService()
|
||||
result = await tavily.search(
|
||||
query=keyword,
|
||||
topic=category,
|
||||
search_depth="basic",
|
||||
max_results=max_results,
|
||||
include_favicon=True,
|
||||
)
|
||||
|
||||
if not result.get("success"):
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=result.get("error", "Tavily search failed")
|
||||
)
|
||||
|
||||
topics = _normalize_tavily_results(result.get("results", []))
|
||||
logger.info(f"[CategoryResearch] Tavily found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.001 # basic search = 1 credit
|
||||
_track_research_usage(user_id, "tavily", cost, "tavily_calls", "tavily_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
provider="tavily",
|
||||
topics=topics,
|
||||
query=keyword,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Tavily error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _search_exa(category: str, keyword: str, max_results: int, user_id: str, website_url: Optional[str] = None) -> CategoryResearchResponse:
|
||||
exa_category = EXA_CATEGORY_MAP.get(category, category)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa: category={category}, exa_category={exa_category}, keyword={keyword}, website_url={website_url}")
|
||||
|
||||
try:
|
||||
# Import exa directly for more control
|
||||
import os
|
||||
from urllib.parse import urlparse
|
||||
exa_api_key = os.getenv("EXA_API_KEY")
|
||||
if not exa_api_key:
|
||||
raise HTTPException(status_code=500, detail="EXA_API_KEY not configured")
|
||||
|
||||
from exa_py import Exa
|
||||
exa = Exa(exa_api_key)
|
||||
logger.info(f"[CategoryResearch] Exa client initialized")
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.EXA, "exa")
|
||||
|
||||
# Build search parameters
|
||||
search_params = {
|
||||
"num_results": max_results,
|
||||
"category": exa_category,
|
||||
}
|
||||
|
||||
# For personal-site, extract domain from URL if provided
|
||||
include_domains = None
|
||||
if category == "personal-site" and website_url:
|
||||
try:
|
||||
parsed = urlparse(website_url)
|
||||
if parsed.netloc:
|
||||
include_domains = [parsed.netloc]
|
||||
logger.info(f"[CategoryResearch] Personal site - limiting to domain: {parsed.netloc}")
|
||||
elif parsed.path and "." in parsed.path:
|
||||
# Could be domain without protocol
|
||||
include_domains = [parsed.path]
|
||||
logger.info(f"[CategoryResearch] Personal site - using as domain: {parsed.path}")
|
||||
except Exception as url_err:
|
||||
logger.warning(f"[CategoryResearch] Failed to parse website_url: {url_err}")
|
||||
|
||||
logger.info(f"[CategoryResearch] Calling Exa with params: {search_params}, include_domains={include_domains}")
|
||||
|
||||
# Make the search call
|
||||
results = exa.search_and_contents(
|
||||
query=keyword,
|
||||
type="auto" if category != "personal-site" else "neural",
|
||||
num_results=max_results,
|
||||
category=exa_category,
|
||||
text=True,
|
||||
summary=True,
|
||||
include_domains=include_domains,
|
||||
)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa search completed, got results")
|
||||
|
||||
# Transform results to our format
|
||||
topics = []
|
||||
if results and hasattr(results, 'results'):
|
||||
for item in results.results:
|
||||
title = getattr(item, 'title', 'Untitled')
|
||||
url = getattr(item, 'url', '')
|
||||
snippet = getattr(item, 'summary', '') or getattr(item, 'text', '') or ''
|
||||
score = 0.8 # Default score for Exa results
|
||||
|
||||
topics.append(CategoryTopic(
|
||||
title=title,
|
||||
url=url,
|
||||
snippet=snippet[:300] if snippet else '',
|
||||
score=score,
|
||||
favicon=None,
|
||||
))
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.005 # Default Exa cost for 1-25 results
|
||||
_track_research_usage(user_id, "exa", cost, "exa_calls", "exa_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
provider="exa",
|
||||
topics=topics,
|
||||
query=keyword,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
import traceback
|
||||
logger.error(f"[CategoryResearch] Exa error: {type(e).__name__}: {e}")
|
||||
logger.error(f"[CategoryResearch] Stack: {traceback.format_exc()}")
|
||||
raise HTTPException(status_code=500, detail=f"Exa search failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/tavily-category", response_model=CategoryResearchResponse)
|
||||
async def research_by_category(
|
||||
request: CategoryResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Research topics by category using Tavily or Exa.
|
||||
|
||||
Categories:
|
||||
- news, finance: Uses Tavily
|
||||
- research-paper, personal-site: Uses Exa
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
category = request.category.lower()
|
||||
valid_categories = list(CATEGORY_PROVIDER_MAP.keys())
|
||||
|
||||
logger.info(f"[CategoryResearch] Full request payload: category={request.category}, keyword={request.keyword}, website_url={request.website_url}")
|
||||
|
||||
if category not in valid_categories:
|
||||
logger.error(f"[CategoryResearch] Invalid category: {category}, valid: {valid_categories}")
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Category must be one of: {', '.join(valid_categories)}"
|
||||
)
|
||||
|
||||
keyword = request.keyword or category
|
||||
max_results = min(max(request.max_results or 8, 5), 10)
|
||||
website_url = request.website_url
|
||||
|
||||
logger.info(f"[CategoryResearch] Processing: category={category}, keyword={keyword}, max_results={max_results}, website_url={website_url}")
|
||||
|
||||
provider = CATEGORY_PROVIDER_MAP.get(category, "tavily")
|
||||
logger.info(f"[CategoryResearch] Selected provider: {provider} for category: {category}")
|
||||
|
||||
try:
|
||||
if provider == "tavily":
|
||||
return await _search_tavily(category, keyword, max_results, user_id)
|
||||
elif provider == "exa":
|
||||
return await _search_exa(category, keyword, max_results, user_id, website_url)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unknown provider")
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Outer error: {type(e).__name__}: {e}", exc_info=True)
|
||||
raise
|
||||
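Reviewer note: the routing in `research_by_category` above comes down to three steps: normalize the category, look up the provider, and clamp `max_results` to the 5..10 range. The following is a minimal sketch of that logic with no framework dependencies. `resolve_category_request` is a hypothetical name, and it raises `ValueError` where the handler raises `HTTPException(status_code=400)`.

```python
from typing import Optional, Tuple

# Copied from the handler in the diff above.
CATEGORY_PROVIDER_MAP = {
    "news": "tavily",
    "finance": "tavily",
    "research-paper": "exa",
    "personal-site": "exa",
}


def resolve_category_request(
    category: str,
    keyword: Optional[str] = None,
    max_results: Optional[int] = 8,
) -> Tuple[str, str, int]:
    """Return (provider, effective_keyword, clamped_max_results).

    Hypothetical helper: the real handler raises an HTTP 400 instead of
    ValueError and dispatches to the provider directly.
    """
    normalized = category.lower()
    if normalized not in CATEGORY_PROVIDER_MAP:
        valid = ", ".join(CATEGORY_PROVIDER_MAP)
        raise ValueError(f"Category must be one of: {valid}")
    provider = CATEGORY_PROVIDER_MAP[normalized]
    # A missing or zero max_results falls back to 8, then is clamped to 5..10.
    clamped = min(max(max_results or 8, 5), 10)
    # With no keyword, the category name itself is used as the query.
    return provider, keyword or normalized, clamped


if __name__ == "__main__":
    print(resolve_category_request("News", None, 3))
```

Because of the `or 8` fallback, a caller that sends `max_results=0` gets 8 results, not the lower bound of 5.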
@@ -1,119 +0,0 @@
|
||||
"""
|
||||
Podcast Trends Handler
|
||||
|
||||
Endpoints for fetching Google Trends data relevant to podcast topics.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
from loguru import logger
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/trends", tags=["Podcast Trends"])
|
||||
|
||||
# Module-level shared instance (singleton pattern)
|
||||
_trends_service_instance = None
|
||||
_trends_service_lock = None
|
||||
|
||||
|
||||
def get_trends_service():
|
||||
"""Get or create shared GoogleTrendsService instance."""
|
||||
global _trends_service_instance, _trends_service_lock
|
||||
if _trends_service_instance is None:
|
||||
try:
|
||||
from services.research.trends import GoogleTrendsService
|
||||
_trends_service_instance = GoogleTrendsService()
|
||||
_trends_service_lock = asyncio.Lock()
|
||||
logger.info("[Podcast Trends] Created shared GoogleTrendsService instance")
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] Failed to create GoogleTrendsService: {e}")
|
||||
raise
|
||||
return _trends_service_instance
|
||||
|
||||
|
||||
class PodcastTrendsRequest(BaseModel):
|
||||
keywords: List[str] = Field(..., min_length=1, max_length=5, description="1-5 keywords to analyze")
|
||||
timeframe: str = Field(default="today 12-m", description="Timeframe: 'today 3-m', 'today 12-m', 'today 5-y', 'all'")
|
||||
geo: str = Field(default="US", description="Country code: 'US', 'GB', 'IN', etc.")
|
||||
source: str = Field(default="web", description="Data source: 'web' (Google), 'podcast' (YouTube)")
|
||||
|
||||
|
||||
class PodcastTrendsResponse(BaseModel):
|
||||
success: bool
|
||||
data: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
@router.post("", response_model=PodcastTrendsResponse)
|
||||
async def get_podcast_trends(
|
||||
request: PodcastTrendsRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""Fetch Google Trends data for podcast topic keywords."""
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
try:
|
||||
service = get_trends_service()
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] GoogleTrendsService unavailable: {e}")
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Google Trends service is currently unavailable. Please try again later."
|
||||
)
|
||||
|
||||
try:
|
||||
# Map 'source' to 'gprop' - 'podcast' uses YouTube for video/podcast relevance
|
||||
gprop_map = {"": "", "web": "", "podcast": "youtube", "news": "news", "images": "images", "shopping": "froogle"}
|
||||
gprop = gprop_map.get(request.source, "")
|
||||
|
||||
result = await service.analyze_trends(
|
||||
keywords=request.keywords,
|
||||
timeframe=request.timeframe,
|
||||
geo=request.geo,
|
||||
gprop=gprop,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
has_error = result.get("error")
|
||||
has_data = (
|
||||
len(result.get("interest_over_time", [])) > 0
|
||||
or len(result.get("interest_by_region", [])) > 0
|
||||
or len(result.get("related_topics", {}).get("top", [])) > 0
|
||||
or len(result.get("related_topics", {}).get("rising", [])) > 0
|
||||
or len(result.get("related_queries", {}).get("top", [])) > 0
|
||||
or len(result.get("related_queries", {}).get("rising", [])) > 0
|
||||
)
|
||||
|
||||
# Return an error when the service reported an error AND returned no data (likely blocked or rate limited)
|
||||
if has_error and not has_data:
|
||||
error_msg = result.get("error", "")
|
||||
cooldown_active = result.get("cooldown_active", False)
|
||||
logger.warning(f"[Trends] No data or error: {error_msg[:100]}")
|
||||
# Provide helpful message during cooldown
|
||||
if cooldown_active:
|
||||
return PodcastTrendsResponse(
|
||||
success=False,
|
||||
data=result,
|
||||
error="Google is rate limiting requests. Try using 'Get Trending Topics' instead, or wait 30 minutes."
|
||||
)
|
||||
return PodcastTrendsResponse(success=False, data=result, error=error_msg or "No trends data available. Google may be blocking requests.")
|
||||
|
||||
# Even if no error but empty data - return error
|
||||
if not has_data:
|
||||
logger.warning("[Trends] Empty data returned")
|
||||
return PodcastTrendsResponse(success=False, data=result, error="No trends data available. Please try different keywords.")
|
||||
|
||||
return PodcastTrendsResponse(success=True, data=result)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast Trends] Error fetching trends for {request.keywords}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to fetch trends data: {str(e)}"
|
||||
)
|
||||
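Reviewer note: the success/failure decision in `get_podcast_trends` above depends only on the shape of the service result, so it can be read as a pure function. The following is a minimal sketch under that assumption. `has_trend_data` and `classify_trends_result` are hypothetical names. The message strings and the `gprop` mapping are copied from the handler, and the sketch additionally tolerates `None` sections, which the handler does not.

```python
from typing import Any, Dict, Optional, Tuple

# Copied from the handler: the request's `source` value selects the Google property.
GPROP_BY_SOURCE = {
    "": "", "web": "", "podcast": "youtube",
    "news": "news", "images": "images", "shopping": "froogle",
}

RATE_LIMIT_MESSAGE = (
    "Google is rate limiting requests. Try using 'Get Trending Topics' "
    "instead, or wait 30 minutes."
)
NO_DATA_MESSAGE = "No trends data available. Please try different keywords."


def has_trend_data(result: Dict[str, Any]) -> bool:
    """True if any of the six data sections checked by the handler is non-empty."""
    if any(result.get(name) for name in ("interest_over_time", "interest_by_region")):
        return True
    return any(
        (result.get(name) or {}).get(bucket)
        for name in ("related_topics", "related_queries")
        for bucket in ("top", "rising")
    )


def classify_trends_result(result: Dict[str, Any]) -> Tuple[bool, Optional[str]]:
    """Return (success, error_message), following the handler's branch order."""
    if has_trend_data(result):
        # Data wins: an error alongside usable data is still reported as success.
        return True, None
    error_msg = result.get("error") or ""
    if error_msg and result.get("cooldown_active", False):
        return False, RATE_LIMIT_MESSAGE
    if error_msg:
        return False, error_msg
    return False, NO_DATA_MESSAGE


if __name__ == "__main__":
    print(classify_trends_result({"error": "blocked", "cooldown_active": True}))
```

Written this way, it is easier to see that a result carrying both an `error` and some data is returned with `success=True`, which may or may not be the intended behavior.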
@@ -321,7 +321,7 @@ async def generate_podcast_video(
|
||||
|
||||
# Load image bytes (scene image is required for video generation)
|
||||
if body.avatar_image_url:
|
||||
image_bytes = load_podcast_image_bytes(body.avatar_image_url, user_id=user_id)
|
||||
image_bytes = load_podcast_image_bytes(body.avatar_image_url)
|
||||
else:
|
||||
# Scene-specific image should be generated before video generation
|
||||
raise HTTPException(
|
||||
@@ -332,7 +332,7 @@ async def generate_podcast_video(
|
||||
mask_image_bytes = None
|
||||
if body.mask_image_url:
|
||||
try:
|
||||
mask_image_bytes = load_podcast_image_bytes(body.mask_image_url, user_id=user_id)
|
||||
mask_image_bytes = load_podcast_image_bytes(body.mask_image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Failed to load mask image: {e}")
|
||||
raise HTTPException(
|
||||
|
||||
@@ -80,14 +80,6 @@ class PodcastEnhanceIdeaRequest(BaseModel):
|
||||
"""Request model for enhancing a podcast idea with AI."""
|
||||
idea: str = Field(..., description="The raw podcast idea or keywords")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
website_data: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Optional website extraction data for enriched context (title, summary, highlights, subpages, url)"
|
||||
)
|
||||
topic_context: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Optional category research context (category, topics, selected_topic)"
|
||||
)
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaResponse(BaseModel):
|
||||
@@ -105,7 +97,6 @@ class PodcastScriptRequest(BaseModel):
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
outline: Optional[Dict[str, Any]] = Field(None, description="The refined episode outline to follow")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="The full analysis context (audience, keywords, etc.)")
|
||||
podcast_mode: Optional[str] = Field(default="video_only", description="Podcast mode: audio_only, video_only, or audio_video")
|
||||
|
||||
|
||||
class PodcastSceneLine(BaseModel):
|
||||
@@ -114,7 +105,6 @@ class PodcastSceneLine(BaseModel):
|
||||
emphasis: Optional[bool] = False
|
||||
id: Optional[str] = None # Optional line ID for frontend tracking
|
||||
usedFactIds: Optional[List[str]] = None # Facts referenced in this line
|
||||
ttsHints: Optional[List[str]] = None # Optional TTS hints, e.g. pause_300ms, smile, emphasize_data
|
||||
|
||||
|
||||
class PodcastScene(BaseModel):
|
||||
@@ -127,7 +117,6 @@ class PodcastScene(BaseModel):
|
||||
imageUrl: Optional[str] = None # Generated image URL for video generation
|
||||
audioUrl: Optional[str] = None # Generated audio URL for this scene
|
||||
imagePrompt: Optional[str] = None # Original image generation prompt for video context
|
||||
chart_data: Optional[Dict[str, Any]] = None # Optional chart mapping for B-roll scenes
|
||||
|
||||
|
||||
class PodcastExaConfig(BaseModel):
|
||||
@@ -217,7 +206,6 @@ class PodcastExaResearchResponse(BaseModel):
|
||||
mapped_angles: List[Dict[str, Any]] = [] # Content angles for the episode
|
||||
expert_quotes: List[Dict[str, Any]] = [] # Expert quotes from research
|
||||
listener_cta_suggestions: List[str] = [] # CTA suggestions
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class PodcastScriptResponse(BaseModel):
|
||||
@@ -231,9 +219,6 @@ class PodcastAudioRequest(BaseModel):
|
||||
text: str
|
||||
voice_id: Optional[str] = "Wise_Woman"
|
||||
custom_voice_id: Optional[str] = None # Voice clone ID for custom voice
|
||||
use_voice_clone: Optional[bool] = False # If True, use voice clone with voice_sample_url
|
||||
voice_sample_url: Optional[str] = None # URL to user's voice sample for cloning
|
||||
voice_clone_engine: Optional[str] = None # Engine: "qwen3", "minimax", "cosyvoice"
|
||||
speed: Optional[float] = 1.0
|
||||
volume: Optional[float] = 1.0
|
||||
pitch: Optional[float] = 0.0
|
||||
@@ -478,59 +463,3 @@ class VoiceCloneResult(BaseModel):
|
||||
file_size: int
|
||||
task_id: str
|
||||
status: str = "completed"
|
||||
|
||||
|
||||
class ExtractUrlRequest(BaseModel):
|
||||
"""Request to extract content from a URL using Exa."""
|
||||
url: str = Field(..., description="URL to extract content from")
|
||||
|
||||
|
||||
class ExtractUrlResponse(BaseModel):
|
||||
"""Response with extracted content from URL."""
|
||||
success: bool
|
||||
title: Optional[str] = None
|
||||
text: Optional[str] = None
|
||||
summary: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
highlights: Optional[List[str]] = Field(default_factory=list, description="Key highlights from the content")
|
||||
url: str
|
||||
image: Optional[str] = None
|
||||
favicon: Optional[str] = None
|
||||
subpages: Optional[List[Dict[str, Any]]] = Field(default_factory=list, description="Subpages with their own content")
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class WebsiteAnalysisRequest(BaseModel):
|
||||
"""Request to save user's website analysis."""
|
||||
website_url: str = Field(..., description="The website URL")
|
||||
exa_content: Dict[str, Any] = Field(default_factory=dict, description="Exa extracted content")
|
||||
|
||||
|
||||
class WebsiteAnalysisResponse(BaseModel):
|
||||
"""Response for website analysis."""
|
||||
success: bool
|
||||
website_url: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastPreEstimateRequest(BaseModel):
|
||||
"""Request model for pre-analysis cost estimate."""
|
||||
duration: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
query_count: int = Field(default=3, description="Number of research queries")
|
||||
podcast_mode: str = Field(default="audio_video", description="Podcast mode: audio_only, video_only, or audio_video")
|
||||
# Optional model overrides for cost estimation
|
||||
gemini_model: Optional[str] = Field(default=None, description="LLM model: gemini-2.5-flash, gemini-1.5-flash, etc.")
|
||||
audio_tts_model: Optional[str] = Field(default=None, description="Audio TTS model: minimax/speech-02-hd")
|
||||
voice_clone_engine: Optional[str] = Field(default=None, description="Voice clone engine: qwen3, cosyvoice, minimax")
|
||||
image_model: Optional[str] = Field(default=None, description="Image model: qwen-image, ideogram-v3-turbo")
|
||||
video_model: Optional[str] = Field(default=None, description="Video model: wan-2.5, kling-v2.5-turbo-std-5s, wavespeed-ai/infinitetalk")
|
||||
|
||||
|
||||
class PodcastPreEstimateResponse(BaseModel):
|
||||
"""Response model for pre-analysis cost estimate."""
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
pricing_available: bool = Field(default=False, description="Whether pricing data is available in DB")
|
||||
debug: Optional[Dict[str, Any]] = Field(default=None, description="Debug info: pricing rows count, providers")
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
"""
|
||||
Prompts module for podcast topic enhancement.
|
||||
"""
|
||||
|
||||
from .website_enhance_prompts import (
|
||||
get_enhance_topic_prompt,
|
||||
format_website_context,
|
||||
STANDARD_ENHANCE_PROMPT,
|
||||
WEBSITE_AWARE_ENHANCE_PROMPT,
|
||||
)
|
||||
|
||||
from services.podcast_context_builder import (
|
||||
PodcastContextBuilder,
|
||||
context_builder,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"get_enhance_topic_prompt",
|
||||
"format_website_context",
|
||||
"STANDARD_ENHANCE_PROMPT",
|
||||
"WEBSITE_AWARE_ENHANCE_PROMPT",
|
||||
"PodcastContextBuilder",
|
||||
"context_builder",
|
||||
]
|
||||
@@ -1,187 +0,0 @@
|
||||
"""
|
||||
Website-aware prompts for podcast topic enhancement.
|
||||
|
||||
This module provides prompts for enhancing podcast topics with optional
|
||||
website extraction data for richer context.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from string import Template
|
||||
|
||||
|
||||
# Standard prompt for when no website data is available
|
||||
STANDARD_ENHANCE_PROMPT = Template("""">You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
|
||||
${bible_context}
|
||||
|
||||
RAW IDEA/KEYWORDS: "$idea"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
# Website-aware prompt for when website data is available
|
||||
WEBSITE_AWARE_ENHANCE_PROMPT = Template("""">You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea, enriched with website content analysis.
|
||||
|
||||
${bible_context}
|
||||
|
||||
WEBSITE CONTENT ANALYSIS:
|
||||
${website_context}
|
||||
|
||||
RAW IDEA/KEYWORDS: "$idea"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle, that INCORPORATE the website content context:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise from the website)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections tied to the brand)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance leveraging the site's focus areas)
|
||||
|
||||
Each version should:
|
||||
- Be 2-3 sentences
|
||||
- Reference specific elements from the website content when relevant
|
||||
- Be audience-focused and align with host persona if provided
|
||||
- NOT just repeat the website summary - create fresh podcast angles
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 strings, each string being a complete episode pitch (NOT objects, just plain strings)
|
||||
- rationales: array of 3 strings explaining the approach for each version
|
||||
|
||||
IMPORTANT: enhanced_ideas must be an array of plain strings, NOT objects. Example:
|
||||
{
|
||||
"enhanced_ideas": [
|
||||
"Your expert guide to AI advancement: A practical look at how AI is transforming industries...",
|
||||
"The human stories behind AI innovation: From Silicon Valley to your daily life...",
|
||||
"AI in 2026: What's trending and what's next in artificial intelligence..."
|
||||
],
|
||||
"rationales": [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
def get_enhance_topic_prompt(
|
||||
idea: str,
|
||||
bible_context: str = "",
|
||||
website_data: Optional[Dict[str, Any]] = None
|
||||
) -> str:
|
||||
"""
|
||||
Returns the appropriate prompt based on available context.
|
||||
|
||||
Args:
|
||||
idea: The raw podcast idea or keywords
|
||||
bible_context: Optional Podcast Bible context string
|
||||
website_data: Optional website extraction data
|
||||
|
||||
Returns:
|
||||
Formatted prompt string with appropriate context
|
||||
"""
|
||||
# Build bible context section
|
||||
bible_section = f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""
|
||||
|
||||
if website_data:
|
||||
# Build website context section
|
||||
website_context_parts = []
|
||||
if website_data.get('url'):
|
||||
website_context_parts.append(f"Source: {website_data.get('url')}")
|
||||
if website_data.get('title'):
|
||||
website_context_parts.append(f"Company/Organization: {website_data.get('title')}")
|
||||
if website_data.get('summary'):
|
||||
website_context_parts.append(f"About: {website_data.get('summary')}")
|
||||
if website_data.get('highlights'):
|
||||
highlights_str = ', '.join(website_data.get('highlights', [])[:3])
|
||||
website_context_parts.append(f"Key Highlights: {highlights_str}")
|
||||
if website_data.get('subpages'):
|
||||
subpages_str = ', '.join([
|
||||
sp.get('title', sp.get('url', ''))
|
||||
for sp in website_data.get('subpages', [])[:3]
|
||||
])
|
||||
website_context_parts.append(f"Subpages: {subpages_str}")
|
||||
|
||||
website_context_str = "\n".join(website_context_parts)
|
||||
|
||||
return WEBSITE_AWARE_ENHANCE_PROMPT.substitute(
|
||||
idea=idea,
|
||||
bible_context=bible_section,
|
||||
website_context=website_context_str
|
||||
)
|
||||
else:
|
||||
return STANDARD_ENHANCE_PROMPT.substitute(
|
||||
idea=idea,
|
||||
bible_context=bible_section
|
||||
)
|
||||
|
||||
|
||||
def format_website_context(website_data: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Format website data for inclusion in progress messages.
|
||||
|
||||
Args:
|
||||
website_data: Website extraction data
|
||||
|
||||
Returns:
|
||||
Formatted string describing what's being used
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if website_data.get('title'):
|
||||
parts.append(f"• {website_data['title']}")
|
||||
|
||||
if website_data.get('summary'):
|
||||
summary_preview = website_data['summary'][:100]
|
||||
parts.append(f"• Summary: {summary_preview}...")
|
||||
|
||||
if website_data.get('highlights'):
|
||||
parts.append(f"• {len(website_data['highlights'])} key highlights")
|
||||
|
||||
if website_data.get('subpages'):
|
||||
parts.append(f"• {len(website_data['subpages'])} subpages analyzed")
|
||||
|
||||
if website_data.get('url'):
|
||||
parts.append(f"• Source: {website_data['url']}")
|
||||
|
||||
return "\n".join(parts) if parts else "Basic website analysis"
|
||||
|
||||
if website_data.get('title'):
|
||||
parts.append(f"• {website_data['title']}")
|
||||
|
||||
if website_data.get('summary'):
|
||||
summary_preview = website_data['summary'][:100]
|
||||
parts.append(f"• Summary: {summary_preview}...")
|
||||
|
||||
if website_data.get('highlights'):
|
||||
parts.append(f"• {len(website_data['highlights'])} key highlights")
|
||||
|
||||
if website_data.get('subpages'):
|
||||
parts.append(f"• {len(website_data['subpages'])} subpages analyzed")
|
||||
|
||||
if website_data.get('url'):
|
||||
parts.append(f"• Source: {website_data['url']}")
|
||||
|
||||
return "\n".join(parts) if parts else "Basic website analysis"
|
||||
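Review note: the removed module builds its prompts with `string.Template` rather than `str.format`, which lets the literal JSON braces in the example output sit in the prompt unescaped. The sketch below illustrates that behaviour only. `DEMO_PROMPT` and `render_demo_prompt` are hypothetical, shortened stand-ins and are not the prompt text from the diff.

```python
from string import Template

# Hypothetical, shortened stand-in for the prompts in the diff above.
# Literal JSON braces need no escaping because Template only reacts to "$".
DEMO_PROMPT = Template(
    '${bible_context}\n'
    'RAW IDEA/KEYWORDS: "$idea"\n'
    'Return JSON like {"enhanced_ideas": ["..."], "rationales": ["..."]}\n'
)


def render_demo_prompt(idea: str, bible_context: str = "") -> str:
    """Fill both placeholders; substitute() raises KeyError if one is missing."""
    return DEMO_PROMPT.substitute(idea=idea, bible_context=bible_context)


rendered = render_demo_prompt("AI in healthcare")

# safe_substitute() leaves unknown placeholders in place instead of raising.
partial = DEMO_PROMPT.safe_substitute(idea="AI in healthcare")
```

Using `substitute()` (as the removed code does) fails loudly when a placeholder is forgotten, whereas `safe_substitute()` would silently send `${bible_context}` to the model.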
@@ -12,7 +12,7 @@ from api.story_writer.utils.auth import require_authenticated_user
|
||||
from api.story_writer.task_manager import task_manager
|
||||
|
||||
# Import all handler routers
|
||||
from .handlers import projects, analysis, research, script, audio, images, video, avatar, dubbing, broll, trends, tavily_category_research
|
||||
from .handlers import projects, analysis, research, script, audio, images, video, avatar, dubbing
|
||||
|
||||
# Create main router
|
||||
router = APIRouter(prefix="/api/podcast", tags=["Podcast Maker"])
|
||||
@@ -27,9 +27,6 @@ router.include_router(images.router)
|
||||
router.include_router(video.router)
|
||||
router.include_router(avatar.router)
|
||||
router.include_router(dubbing.router)
|
||||
router.include_router(broll.router)
|
||||
router.include_router(trends.router)
|
||||
router.include_router(tavily_category_research.router)
|
||||
|
||||
|
||||
@router.get("/task/{task_id}/status")
|
||||
|
||||
@@ -67,32 +67,15 @@ def load_podcast_audio_bytes(audio_url: str, user_id: str | None = None) -> byte
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load audio: {str(exc)}")
|
||||
|
||||
|
||||
def load_podcast_image_bytes(image_url: str, user_id: str | None = None) -> bytes:
|
||||
"""Load podcast image bytes from URL. Resolves from workspace first."""
|
||||
def load_podcast_image_bytes(image_url: str) -> bytes:
|
||||
"""Load podcast image bytes from URL. Uses centralized media loader."""
|
||||
if not image_url:
|
||||
raise HTTPException(status_code=400, detail="Image URL is required")
|
||||
|
||||
logger.info(f"[Podcast] Loading image from URL: {image_url}")
|
||||
|
||||
try:
|
||||
# Extract filename from URL path
|
||||
prefix = "/api/podcast/images/"
|
||||
if prefix in image_url:
|
||||
filename = image_url.split(prefix, 1)[1].split("?", 1)[0].strip()
|
||||
# Handle subdirectories like avatars/
|
||||
subdir = None
|
||||
if "/" in filename:
|
||||
subdir_part = filename.rsplit("/", 1)[0]
|
||||
subdir = Path(subdir_part)
|
||||
filename = filename.rsplit("/", 1)[1]
|
||||
|
||||
try:
|
||||
image_path = _resolve_podcast_media_file(filename, "image", user_id, subdir=subdir)
|
||||
return image_path.read_bytes()
|
||||
except HTTPException:
|
||||
pass # Fall through to centralized loader
|
||||
|
||||
# Fall back to centralized media loader
|
||||
# REUSE: Use centralized media loader which handles cross-module lookups
|
||||
image_bytes = load_media_bytes(image_url)
|
||||
|
||||
if not image_bytes:
|
||||
|
||||
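Review note: the workspace-first variant of `load_podcast_image_bytes` in this hunk derives a filename and an optional subdirectory from the image URL before resolving the file. A minimal sketch of just that parsing step follows. The helper name `split_image_url` is hypothetical, and the resolver it would feed (`_resolve_podcast_media_file` in the diff) is not reproduced here.

```python
from pathlib import Path
from typing import Optional, Tuple

# Same prefix the hunk above looks for in the URL.
PREFIX = "/api/podcast/images/"


def split_image_url(image_url: str) -> Optional[Tuple[Optional[Path], str]]:
    """Return (subdir, filename), or None when the URL is not a podcast image URL.

    Hypothetical helper: mirrors the split/rsplit steps shown in the diff.
    """
    if PREFIX not in image_url:
        return None
    # Drop everything up to the prefix, then drop any query string.
    filename = image_url.split(PREFIX, 1)[1].split("?", 1)[0].strip()
    subdir: Optional[Path] = None
    if "/" in filename:
        # e.g. "avatars/host.png" -> subdir "avatars", filename "host.png"
        subdir_part, filename = filename.rsplit("/", 1)
        subdir = Path(subdir_part)
    return subdir, filename
```

A URL that does not contain the prefix yields `None`, which corresponds to the path in the diff that goes straight to the centralized `load_media_bytes` loader.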
@@ -12,7 +12,7 @@ import sqlite3
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
from models.subscription_models import UsageAlert, UserSubscription
|
||||
from models.subscription_models import UsageAlert
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from ..dependencies import verify_user_access
|
||||
from ..cache import get_cached_dashboard, set_cached_dashboard
|
||||
@@ -27,9 +27,7 @@ async def get_dashboard_data(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive dashboard data for usage monitoring.
|
||||
Returns all-time total + current period usage by default.
|
||||
When billing_period is specified, returns that period's data only."""
|
||||
"""Get comprehensive dashboard data for usage monitoring."""
|
||||
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
@@ -37,23 +35,17 @@ async def get_dashboard_data(
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
|
||||
# Check cache first (only for default view, skip when a specific period is requested)
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data and not billing_period:
|
||||
return cached_data
|
||||
# Check cache first (skip if billing_period is specified)
|
||||
if not billing_period:
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# When a specific billing_period is requested, show only that period's data
|
||||
# Otherwise show all-time total + current period usage
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
# Get current usage stats (for the requested period)
|
||||
current_usage = usage_service.get_user_usage_stats(user_id, billing_period)
|
||||
|
||||
# Get usage trends (last 6 months)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
@@ -84,44 +76,13 @@ async def get_dashboard_data(
|
||||
]
|
||||
|
||||
# Calculate cost projections (only relevant for current month)
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
|
||||
# Determine if viewing current period based on subscription, not calendar
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
# Use subscription's billing period or fallback to calendar
|
||||
if subscription and subscription.current_period_start:
|
||||
sub_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Check if we have data for subscription period or calendar period
|
||||
from models.subscription_models import UsageSummary
|
||||
sub_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == sub_period
|
||||
).first()
|
||||
|
||||
# Determine which period to use for "current"
|
||||
if sub_data_exists:
|
||||
effective_period = sub_period
|
||||
else:
|
||||
# Check calendar period for backward compatibility
|
||||
cal_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
effective_period = calendar_period if cal_data_exists else sub_period
|
||||
|
||||
is_current_period = not billing_period or billing_period == effective_period
|
||||
else:
|
||||
is_current_period = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
|
||||
if is_current_period:
|
||||
# Only project costs if viewing current month
|
||||
is_current_month = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
if is_current_month:
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
else:
|
||||
projected_cost = current_cost # For past months, projected is actual
|
||||
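Review note: both versions of this hunk project the monthly cost by linear extrapolation from the current day of the month over a fixed 30-day period. The sketch below isolates that arithmetic. The function name and signature are hypothetical, while the formula is the one shown in the diff.

```python
def project_monthly_cost(
    current_cost: float,
    current_day: int,
    days_in_period: int = 30,
    is_current_period: bool = True,
) -> float:
    """Linear projection: (cost so far / days elapsed) * days in period.

    Hypothetical helper wrapping the expression from the diff.
    """
    if not is_current_period:
        # For past periods, the projected cost is simply the actual cost.
        return current_cost
    if current_day <= 0:
        return 0.0
    return (current_cost / current_day) * days_in_period


mid_month = project_monthly_cost(15.0, 10)   # 1.5 per day over 30 days
past_month = project_monthly_cost(15.0, 10, is_current_period=False)
day_31 = project_monthly_cost(31.0, 31)      # fixed 30-day period
```

Because `days_in_period` is fixed at 30, a projection made on the 31st comes out lower than the cost already incurred; that may be worth keeping in mind when reading the dashboard numbers.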
@@ -129,8 +90,7 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -140,9 +100,9 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
@@ -171,13 +131,7 @@ async def get_dashboard_data(
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
@@ -198,7 +152,7 @@ async def get_dashboard_data(
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
@@ -206,8 +160,7 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -217,17 +170,16 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Cache the response after successful retry (only for default view)
|
||||
if not billing_period:
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
# Cache the response after successful retry
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
return response_payload
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
@@ -235,8 +187,7 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(retry_err),
|
||||
"data": {
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
@@ -250,8 +201,7 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"data": {
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
|
||||
@@ -14,21 +14,13 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"""
|
||||
Format subscription plan limits for API response.
|
||||
|
||||
Includes _zero_means metadata per field to disambiguate:
|
||||
- 'disabled': 0 means the feature is not available (Free tier)
|
||||
- 'unlimited': 0 means unlimited usage (Enterprise tier)
|
||||
- 'limited': >0 means numerical limit applies
|
||||
|
||||
Args:
|
||||
plan: SubscriptionPlan model instance
|
||||
|
||||
Returns:
|
||||
Dictionary with formatted limits and _zero_means metadata
|
||||
Dictionary with formatted limits
|
||||
"""
|
||||
tier = plan.tier.value if hasattr(plan.tier, 'value') else str(plan.tier)
|
||||
is_enterprise = tier == 'enterprise'
|
||||
|
||||
limit_fields = {
|
||||
return {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
@@ -43,43 +35,11 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0) or 0,
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0) or 0,
|
||||
"exa_calls": getattr(plan, 'exa_calls_limit', 0) or 0,
|
||||
"wavespeed_calls": getattr(plan, 'wavespeed_calls_limit', 0) or 0,
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
"mistral_tokens": plan.mistral_tokens_limit,
|
||||
"monthly_cost": plan.monthly_cost_limit,
|
||||
}
|
||||
|
||||
# Build _zero_means metadata: indicates whether 0 means 'disabled' or 'unlimited'
|
||||
zero_means = {}
|
||||
for field, value in limit_fields.items():
|
||||
if field == "monthly_cost":
|
||||
zero_means[field] = "disabled"
|
||||
elif is_enterprise:
|
||||
# Enterprise: 0 means unlimited for all call/token fields
|
||||
zero_means[field] = "unlimited"
|
||||
else:
|
||||
# Free/Basic/Pro: determine per-field
|
||||
# Fields that are 0=disabled on Free tier but 0=unlimited on Basic/Pro
|
||||
call_and_token_fields = {
|
||||
"gemini_calls", "openai_calls", "anthropic_calls", "mistral_calls",
|
||||
"tavily_calls", "serper_calls", "metaphor_calls", "firecrawl_calls",
|
||||
"stability_calls", "video_calls", "image_edit_calls", "audio_calls",
|
||||
"exa_calls", "wavespeed_calls", "ai_text_generation_calls",
|
||||
"gemini_tokens", "openai_tokens", "anthropic_tokens", "mistral_tokens",
|
||||
}
|
||||
if field in call_and_token_fields:
|
||||
if value == 0:
|
||||
zero_means[field] = "disabled" if tier == "free" else "unlimited"
|
||||
else:
|
||||
zero_means[field] = "limited"
|
||||
else:
|
||||
zero_means[field] = "limited" if value > 0 else "disabled"
|
||||
|
||||
return {
|
||||
**limit_fields,
|
||||
"_zero_means": zero_means,
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
|
||||
|
||||
|
||||
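Review note: the `_zero_means` metadata in this hunk exists because a limit of `0` is ambiguous across tiers. The sketch below condenses the per-field decision into one function. It is a simplification under stated assumptions: `zero_meaning` is a hypothetical name, and every field other than `monthly_cost` is treated as a call or token field, which matches the fields visible in the hunk but may not cover the whole file.

```python
def zero_meaning(field: str, value: float, tier: str) -> str:
    """Classify a plan limit as 'disabled', 'unlimited', or 'limited'.

    Hypothetical condensation of the loop in the diff; assumes every field
    except "monthly_cost" is a call/token field.
    """
    if field == "monthly_cost":
        # The diff tags monthly_cost as "disabled" regardless of its value.
        return "disabled"
    if tier == "enterprise":
        # Enterprise is tagged "unlimited" for all call/token fields.
        return "unlimited"
    if value == 0:
        # Zero disables the feature on Free, but lifts the cap on Basic/Pro.
        return "disabled" if tier == "free" else "unlimited"
    return "limited"


examples = {
    ("gemini_calls", 0, "free"): zero_meaning("gemini_calls", 0, "free"),
    ("gemini_calls", 0, "pro"): zero_meaning("gemini_calls", 0, "pro"),
    ("gemini_calls", 50, "basic"): zero_meaning("gemini_calls", 50, "basic"),
}
```

With `_zero_means` gone from the response, API clients again receive bare numbers and would have to infer the meaning of `0` from the plan tier themselves.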
@@ -1,10 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from services.writing_assistant import WritingAssistantService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
@@ -12,6 +11,7 @@ router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
|
||||
class SuggestRequest(BaseModel):
|
||||
text: str
|
||||
max_results: int | None = 1
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
|
||||
@@ -38,10 +38,9 @@ assistant_service = WritingAssistantService()
|
||||
|
||||
|
||||
@router.post("/suggest", response_model=SuggestResponse)
|
||||
async def suggest_endpoint(req: SuggestRequest, current_user: Dict[str, Any] = Depends(get_current_user)) -> SuggestResponse:
|
||||
async def suggest_endpoint(req: SuggestRequest) -> SuggestResponse:
|
||||
try:
|
||||
user_id = current_user.get("id")
|
||||
suggestions = await assistant_service.suggest(req.text, user_id=user_id)
|
||||
suggestions = await assistant_service.suggest(req.text, req.max_results or 1)
|
||||
return SuggestResponse(
|
||||
success=True,
|
||||
suggestions=[
|
||||
|
||||
273
backend/app.py
@@ -27,11 +27,11 @@ load_dotenv(backend_dir / '.env', override=False)
|
||||
load_dotenv(project_root / '.env', override=False)
|
||||
load_dotenv(override=False)
|
||||
|
||||
# Set LOG_LEVEL early to WARNING in feature-only modes to suppress DEBUG persona logs
|
||||
# Set LOG_LEVEL early to WARNING to suppress DEBUG persona logs in podcast mode
|
||||
import os
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast":
|
||||
os.environ["LOG_LEVEL"] = "WARNING"
|
||||
|
||||
|
||||
print(f"[app.py] Starting... ALWRITY_ENABLED_FEATURES={os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
|
||||
@@ -43,21 +43,22 @@ def get_enabled_features() -> set:
|
||||
return {f.strip() for f in env_value.split(",") if f.strip()}
|
||||
|
||||
|
||||
def _is_full_mode() -> bool:
|
||||
"""Check if running in full mode (all features enabled)."""
|
||||
enabled = get_enabled_features()
|
||||
return "all" in enabled
|
||||
|
||||
|
||||
def _is_feature_enabled(feature: str) -> bool:
|
||||
"""Check if a specific feature is enabled (including in 'all' mode)."""
|
||||
enabled = get_enabled_features()
|
||||
return feature in enabled or "all" in enabled
|
||||
|
||||
|
||||
# Print env var IMMEDIATELY at module start
|
||||
print(f"[app.py] ALWRITY_ENABLED_FEATURES at start: {os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
def is_podcast_only_demo_mode() -> bool:
|
||||
"""Check if podcast-only mode is enabled."""
|
||||
import os
|
||||
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "all")
|
||||
enabled = get_enabled_features()
|
||||
result = "podcast" in enabled and "all" not in enabled
|
||||
print(f"[DEBUG] is_podcast_only_demo_mode: ALWRITY_ENABLED_FEATURES={env_val}, enabled={enabled}, result={result}", flush=True)
|
||||
return result
|
||||
|
||||
|
||||
# Podcast-only check BEFORE heavy imports
|
||||
PODCAST_ONLY_DEMO_MODE = is_podcast_only_demo_mode()
|
||||
|
||||
|
||||
# Import onboarding models (after env is loaded, before heavy imports)
|
||||
from models.onboarding import APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData, CompetitorAnalysis
|
||||
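Review note: the hunk above contains three different predicates over the same `ALWRITY_ENABLED_FEATURES` set, and they do not agree for every value. The sketch below puts them side by side as pure functions. The function names and the explicit `enabled` parameter are hypothetical (the diff reads the environment variable inside each function), and only the comma-splitting expression and the set tests are taken from the diff.

```python
from typing import Set


def parse_enabled_features(raw: str) -> Set[str]:
    """Split a comma-separated flag string, dropping blanks (as in the diff)."""
    return {f.strip() for f in raw.split(",") if f.strip()}


def is_full_mode(enabled: Set[str]) -> bool:
    """Full mode: the special value 'all' is present."""
    return "all" in enabled


def is_feature_enabled(feature: str, enabled: Set[str]) -> bool:
    """A feature is on if listed explicitly or if 'all' is present."""
    return feature in enabled or "all" in enabled


def is_podcast_only(enabled: Set[str]) -> bool:
    """Podcast-only: 'podcast' is listed and 'all' is not."""
    return "podcast" in enabled and "all" not in enabled


mixed = parse_enabled_features("podcast, core")
```

For a value such as `podcast, core`, `is_full_mode` is false and `is_podcast_only` is true, so swapping `if _is_full_mode():` for `if not is_podcast_only_demo_mode():` changes behaviour only for flag sets that contain neither `all` nor `podcast`.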
@@ -89,18 +90,28 @@ _log_memory_usage()
|
||||
logger.info("app.py: Early memory checkpoint after env load")
|
||||
|
||||
|
||||
# Import modular utilities (skip OnboardingManager import in feature-only modes)
|
||||
# Import modular utilities (skip OnboardingManager import in podcast-only mode)
|
||||
from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager
|
||||
if _is_full_mode():
|
||||
if not is_podcast_only_demo_mode():
|
||||
from alwrity_utils import OnboardingManager
|
||||
|
||||
# Skip monitoring middleware in feature-only modes to save memory
|
||||
if _is_full_mode():
|
||||
# Skip monitoring middleware in podcast-only mode to save memory
|
||||
if not is_podcast_only_demo_mode():
|
||||
from services.subscription import monitoring_middleware
|
||||
else:
|
||||
monitoring_middleware = None
|
||||
|
||||
|
||||
def should_include_non_podcast_features() -> bool:
|
||||
"""Check if non-podcast features should be included."""
|
||||
enabled = get_enabled_features()
|
||||
return "all" in enabled or "core" in enabled
|
||||
|
||||
|
||||
# Legacy constant for backwards compatibility
|
||||
PODCAST_ONLY_DEMO_MODE = is_podcast_only_demo_mode()
|
||||
|
||||
|
||||
# Set up clean logging for end users
|
||||
from logging_config import setup_clean_logging
|
||||
setup_clean_logging()
|
||||
@@ -108,27 +119,27 @@ setup_clean_logging()
|
||||
# Import middleware
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
# Import component logic endpoints (skip in feature-only modes - uses seo_analyzer)
|
||||
# Import component logic endpoints (skip in podcast-only mode - uses seo_analyzer)
|
||||
component_logic_router = None
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.component_logic import router as component_logic_router
|
||||
|
||||
# Import subscription API endpoints
|
||||
from api.subscription import router as subscription_router
|
||||
|
||||
# Import Step 3 onboarding routes (skip in feature-only modes)
|
||||
# Import Step 3 onboarding routes (skip in podcast-only mode)
|
||||
step3_routes = None
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.onboarding_utils.step3_routes import router as step3_routes
|
||||
|
||||
# Import SEO tools router (skip in feature-only modes - uses seo_analyzer)
|
||||
# Import SEO tools router (skip in podcast-only mode - uses seo_analyzer)
|
||||
seo_tools_router = None
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from routers.seo_tools import router as seo_tools_router
|
||||
|
||||
# Skip Facebook Writer, LinkedIn, and other non-essential routes in feature-only modes
|
||||
# Skip Facebook Writer, LinkedIn, and other non-podcast routes in podcast-only mode
|
||||
# Also skip other heavy services that trigger PersonaAnalysisService initialization
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.facebook_writer.routers import facebook_router
|
||||
from routers.linkedin import router as linkedin_router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
@@ -139,7 +150,7 @@ if _is_full_mode():
|
||||
from routers.product_marketing import router as product_marketing_router
|
||||
from routers.campaign_creator import router as campaign_creator_router
|
||||
else:
|
||||
# In feature-only modes, only load essential assets router
|
||||
# In podcast-only mode, only load essential podcast assets router
|
||||
from api.assets_serving import router as assets_serving_router
|
||||
brainstorm_router = None
|
||||
images_router = None
|
||||
@@ -147,31 +158,31 @@ else:
|
||||
product_marketing_router = None
|
||||
campaign_creator_router = None
|
||||
|
||||
# Import hallucination detector router (skip in feature-only modes - triggers heavy ML)
|
||||
if _is_full_mode():
|
||||
# Import hallucination detector router (skip in podcast-only mode - triggers heavy ML)
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
else:
|
||||
hallucination_detector_router = None
|
||||
writing_assistant_router = None
|
||||
|
||||
# Import research configuration router (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# Import research configuration router (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.research_config import router as research_config_router
|
||||
else:
|
||||
research_config_router = None
|
||||
|
||||
# Import user data endpoints
|
||||
# Import content planning endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# Import content planning endpoints (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.content_planning.api.router import router as content_planning_router
|
||||
from api.content_planning.strategy_copilot import router as strategy_copilot_router
|
||||
else:
|
||||
content_planning_router = None
|
||||
strategy_copilot_router = None
|
||||
|
||||
# Import user data endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
# Import user data endpoints (skip in podcast-only mode to save memory)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.user_data import router as user_data_router
|
||||
else:
|
||||
user_data_router = None
|
||||
@@ -186,14 +197,14 @@ from services.startup_health import (
|
||||
|
||||
# Trigger reload for monitoring fix
|
||||
|
||||
# Import OAuth token monitoring routes (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# Import OAuth token monitoring routes (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.oauth_token_monitoring_routes import router as oauth_token_monitoring_router
|
||||
else:
|
||||
oauth_token_monitoring_router = None
|
||||
|
||||
# Import SEO Dashboard endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
# Import SEO Dashboard endpoints (skip in podcast-only mode to save memory)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.seo_dashboard import (
|
||||
get_seo_dashboard_data,
|
||||
get_seo_health_score,
|
||||
@@ -307,8 +318,8 @@ router_manager = RouterManager(app)
|
||||
router_group_status: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
onboarding_manager = None
|
||||
# Only create OnboardingManager in full mode
|
||||
if _is_full_mode():
|
||||
# Only create OnboardingManager if NOT in podcast-only mode
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from alwrity_utils import OnboardingManager
|
||||
onboarding_manager = OnboardingManager(app)
|
||||
|
||||
@@ -335,8 +346,7 @@ app.middleware("http")(api_key_injection_middleware)
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
health_data = health_checker.basic_health_check()
|
||||
health_data["feature_mode"] = "single" if not _is_full_mode() else "full"
|
||||
health_data["enabled_features"] = list(get_enabled_features())
|
||||
health_data["podcast_only_demo_mode"] = PODCAST_ONLY_DEMO_MODE
|
||||
return health_data
|
||||
|
||||
@app.get("/health/database")
|
||||
@@ -353,8 +363,7 @@ async def comprehensive_health():
|
||||
async def readiness(current_user: dict = Depends(get_current_user)):
|
||||
"""Readiness check that validates tenant DB resolution/session under auth context."""
|
||||
return {
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"startup": get_startup_status(),
|
||||
"tenant": readiness_under_auth_context(current_user),
|
||||
}
|
||||
@@ -386,8 +395,7 @@ async def router_status():
|
||||
status = router_manager.get_router_status()
|
||||
status.update(
|
||||
{
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"router_groups": router_group_status,
|
||||
}
|
||||
)
|
||||
@@ -402,19 +410,35 @@ async def feature_profile_status():
|
||||
@app.get("/api/onboarding/status")
|
||||
async def onboarding_status():
|
||||
"""Get onboarding manager status (or demo-mode disabled state)."""
|
||||
if not _is_full_mode():
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
return {
|
||||
"enabled": False,
|
||||
"status": "disabled",
|
||||
"message": f"Onboarding is disabled in feature-only mode. Enabled features: {list(get_enabled_features())}",
|
||||
"feature_mode": "single",
|
||||
"message": "Onboarding is disabled for podcast-only demo mode.",
|
||||
"demo_mode": "podcast_only",
|
||||
}
|
||||
return onboarding_manager.get_onboarding_status()
|
||||
|
||||
# Include routers using modular utilities
|
||||
enabled_features = get_enabled_features()
|
||||
if "all" in enabled_features:
|
||||
# Full mode: load all core and optional routers
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
# In podcast-only mode, include only podcast-enabled routers from core registry
|
||||
from alwrity_utils.router_manager import CORE_ROUTER_REGISTRY
|
||||
podcast_routers = [r for r in CORE_ROUTER_REGISTRY if "podcast" in r.get("features", set())]
|
||||
for entry in podcast_routers:
|
||||
try:
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.warning(f"{entry['name']} router not mounted: {e}")
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": True,
|
||||
"reason": "Podcast routers only in podcast-only mode",
|
||||
}
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
else:
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": router_manager.include_core_routers(),
|
||||
"reason": "Full mode",
|
||||
@@ -423,72 +447,6 @@ if "all" in enabled_features:
|
||||
"mounted": router_manager.include_optional_routers(),
|
||||
"reason": "Full mode",
|
||||
}
|
||||
else:
|
||||
# Feature-only mode: load only routers matching enabled features
|
||||
from alwrity_utils.router_manager import CORE_ROUTER_REGISTRY
|
||||
|
||||
# Filter core routers that match any enabled feature
|
||||
matching_core = [
|
||||
r for r in CORE_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
logger.info(
|
||||
f"[FEATURE-MODE] Enabled features: {enabled_features}, "
|
||||
f"matching {len(matching_core)} core routers: {[r['name'] for r in matching_core]}"
|
||||
)
|
||||
|
||||
# Try to include step4_assets for voice cloning (may fail if nltk not installed)
|
||||
step4_entry = next((r for r in matching_core if r.get("name") == "step4_assets"), None)
|
||||
if step4_entry:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Attempting to load step4_assets")
|
||||
router = router_manager._load_router_from_registry(step4_entry)
|
||||
router_manager.include_router_safely(router, step4_entry["name"], step4_entry.get("include_kwargs"))
|
||||
except ImportError as e:
|
||||
logger.warning(f"[FEATURE-MODE] Skipping step4_assets (missing optional dependency): {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount step4_assets: {e}")
|
||||
|
||||
# Load other matching core routers
|
||||
for entry in matching_core:
|
||||
if entry.get("name") == "step4_assets":
|
||||
continue # Already loaded above
|
||||
if entry.get("name") == "subscription":
|
||||
continue # Loaded separately below
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Load optional routers matching enabled features
|
||||
from alwrity_utils.router_manager import OPTIONAL_ROUTER_REGISTRY
|
||||
matching_optional = [
|
||||
r for r in OPTIONAL_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
for entry in matching_optional:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading optional router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount optional {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Safety net: explicitly include hallucination detector (router_manager may skip silently)
|
||||
if hallucination_detector_router:
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
|
||||
# Log startup summary
|
||||
router_manager.log_startup_summary()
|
||||
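The removed feature-only branch selects routers by intersecting each registry entry's `features` set with the enabled feature set. A minimal sketch of that filter follows. The `REGISTRY` contents and the `select_routers` helper are hypothetical stand-ins for illustration, not the real `CORE_ROUTER_REGISTRY` from `alwrity_utils.router_manager`.

```python
from typing import Any, Dict, List, Set

# Hypothetical registry entries, for illustration only.
REGISTRY: List[Dict[str, Any]] = [
    {"name": "subscription", "features": {"podcast", "blog_writer"}},
    {"name": "podcast", "features": {"podcast"}},
    {"name": "blog_writer", "features": {"blog_writer"}},
    {"name": "misc"},  # no "features" key, so it never matches
]


def select_routers(registry: List[Dict[str, Any]], enabled: Set[str]) -> List[Dict[str, Any]]:
    """Keep entries whose feature set shares at least one member with `enabled`."""
    # A non-empty set intersection is truthy; a missing key defaults to an empty set.
    return [r for r in registry if r.get("features", set()) & enabled]


if __name__ == "__main__":
    print([r["name"] for r in select_routers(REGISTRY, {"podcast"})])
```

Entries without a `features` key are silently excluded under this rule, which is probably why the diff handles the `"all"` case in a separate branch instead of relying on the intersection.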
@@ -504,8 +462,8 @@ router_group_status["assets_serving"] = {
|
||||
"reason": "Required for podcast media assets",
|
||||
}
|
||||
|
||||
# SEO Dashboard endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# SEO Dashboard endpoints (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
@app.get("/api/seo-dashboard/data")
|
||||
async def seo_dashboard_data():
|
||||
"""Get complete SEO dashboard data."""
|
||||
@@ -643,7 +601,7 @@ if _is_full_mode():
|
||||
return await analyze_urls_ai(request, current_user)
|
||||
|
||||
# Include platform analytics router
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
# Include Bing Analytics Storage router to expose storage-backed endpoints
|
||||
@@ -668,38 +626,24 @@ if _is_full_mode():
|
||||
else:
|
||||
router_group_status["platform_extensions"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in feature-only mode",
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
|
||||
# Include Podcast Maker router (only when podcast feature is enabled)
|
||||
if _is_feature_enabled("podcast") and "all" not in get_enabled_features():
|
||||
from api.podcast.router import router as podcast_router
|
||||
logger.info(f"[ROUTER] Including podcast_router")
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Podcast feature enabled",
|
||||
}
|
||||
elif "all" in get_enabled_features():
|
||||
# In full mode, podcast is loaded via optional router registry
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Full mode (loaded via registry)",
|
||||
}
|
||||
else:
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": False,
|
||||
"reason": "Podcast feature not enabled",
|
||||
}
|
||||
# Include Podcast Maker router (always needed for podcast mode)
|
||||
from api.podcast.router import router as podcast_router
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Always mounted",
|
||||
}
|
||||
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
# Include YouTube Creator Studio router
|
||||
from api.youtube.router import router as youtube_router
|
||||
app.include_router(youtube_router, prefix="/api")
|
||||
|
||||
# Include research configuration router
|
||||
if research_config_router:
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
|
||||
# Include Research Engine router (standalone AI research module)
|
||||
from api.research.router import router as research_engine_router
|
||||
@@ -725,7 +669,7 @@ if _is_full_mode():
|
||||
else:
|
||||
router_group_status["advanced_workflows"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in feature-only mode",
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
|
||||
# Setup frontend serving using modular utilities
|
||||
@@ -749,26 +693,20 @@ async def startup_event():
|
||||
try:
|
||||
_log_memory_usage()
|
||||
|
||||
# Note: Pricing is initialized per-user in services/database.py:init_user_database()
|
||||
# which runs on first database access for each user. No global seeding needed at startup.
|
||||
|
||||
enabled_features = get_enabled_features()
|
||||
is_single_mode = "all" not in enabled_features
|
||||
|
||||
# Skip startup health checks in feature-only modes to avoid unnecessary DB errors
|
||||
if _is_full_mode():
|
||||
# Skip startup health checks in podcast-only mode to avoid unnecessary DB errors
|
||||
if not is_podcast_only_demo_mode():
|
||||
startup_report = run_startup_health_routine(app)
|
||||
if startup_report.get("status") != "healthy":
|
||||
logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}")
|
||||
else:
|
||||
logger.info(f"[FEATURE-MODE] Skipping startup health routine (features: {enabled_features})")
|
||||
logger.info("[Podcast] Skipping startup health routine (podcast-only mode)")
|
||||
|
||||
# Start task scheduler only in full mode
|
||||
if _is_full_mode():
|
||||
# Start task scheduler only if NOT in podcast-only mode
|
||||
if not is_podcast_only_demo_mode():
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().start()
|
||||
else:
|
||||
logger.info(f"[FEATURE-MODE] Skipping scheduler startup (features: {enabled_features})")
|
||||
logger.info("[Podcast] Skipping scheduler startup (podcast-only mode)")
|
||||
|
||||
# Check Wix API key configuration
|
||||
wix_api_key = os.getenv('WIX_API_KEY')
|
||||
@@ -780,12 +718,9 @@ async def startup_event():
|
||||
elapsed = time.time() - startup_start
|
||||
logger.info(f"ALwrity backend started successfully in {elapsed:.1f}s")
|
||||
|
||||
# Critical router mount assertions for feature-only modes
|
||||
# Critical router mount assertions for podcast-only demo mode
|
||||
_assert_router_mounted("subscription")
|
||||
if _is_feature_enabled("podcast"):
|
||||
_assert_router_mounted("podcast")
|
||||
if _is_feature_enabled("blog_writer"):
|
||||
_assert_router_mounted("blog_writer")
|
||||
_assert_router_mounted("podcast")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during startup: {e}")
|
||||
# Don't raise - let the server start anyway
|
||||
@@ -800,7 +735,6 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
router_path_indicators = {
|
||||
"subscription": ["/api/subscription/plans", "/api/subscription/preflight"],
|
||||
"podcast": ["/api/podcast/projects", "/api/podcast/"],
|
||||
"blog_writer": ["/api/blog/health", "/api/blog/research/start"],
|
||||
}
|
||||
|
||||
expected_paths = router_path_indicators.get(router_name, [])
|
||||
@@ -811,9 +745,10 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
else:
|
||||
error_msg = f"❌ CRITICAL: Router '{router_name}' is NOT mounted! Expected paths: {expected_paths}"
|
||||
logger.error(error_msg)
|
||||
# In feature-only mode, only fail if the feature is expected
|
||||
if not _is_full_mode() and _is_feature_enabled(router_name):
|
||||
raise RuntimeError(error_msg)
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
# In demo mode, podcast router MUST be mounted
|
||||
if router_name == "podcast":
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Shutdown event
|
||||
@app.on_event("shutdown")
|
||||
|
||||
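The mount assertion above maps a router name to indicator paths and raises only when the check is strict for that router. The hunks do not show how paths are compared, so the sketch below assumes a simple prefix match against a plain list of mounted paths. `is_router_mounted`, `assert_router_mounted`, and the `strict` flag are hypothetical names.

```python
from typing import Dict, Iterable, List

# Indicator paths copied from the diff; the matching rule below is an assumption.
ROUTER_PATH_INDICATORS: Dict[str, List[str]] = {
    "subscription": ["/api/subscription/plans", "/api/subscription/preflight"],
    "podcast": ["/api/podcast/projects", "/api/podcast/"],
}


def is_router_mounted(router_name: str, mounted_paths: Iterable[str]) -> bool:
    """Return True if any mounted path starts with one of the router's indicators."""
    expected = ROUTER_PATH_INDICATORS.get(router_name, [])
    paths = list(mounted_paths)
    return any(path.startswith(ind) for ind in expected for path in paths)


def assert_router_mounted(router_name: str, mounted_paths: Iterable[str], strict: bool) -> None:
    """Raise RuntimeError for a missing router only when `strict` is set."""
    if is_router_mounted(router_name, mounted_paths):
        return
    if strict:
        raise RuntimeError(f"Router '{router_name}' is NOT mounted")
```

In the new code path, `strict` would correspond to `PODCAST_ONLY_DEMO_MODE and router_name == "podcast"`. In the removed path, it corresponded to the feature being enabled outside full mode.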
@@ -252,8 +252,6 @@ router_manager.include_core_routers()
|
||||
# Safety net: keep subscription routes available even if core inclusion flow changes
|
||||
# in special modes (e.g., demo mode). De-dup is handled by RouterManager.
|
||||
router_manager.include_router_safely(subscription_router, "subscription")
|
||||
# Include hallucination detector explicitly (router_manager may skip silently on import failure)
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
router_manager.include_optional_routers()
|
||||
|
||||
# SEO Dashboard endpoints
|
||||
|
||||
@@ -45,9 +45,6 @@ class PodcastProject(Base):
|
||||
knobs = Column(JSON, nullable=True) # Knobs settings
|
||||
research_provider = Column(String(50), nullable=True, default="google") # Research provider
|
||||
|
||||
# Project-specific topic context (category research, selected topics)
|
||||
topic_context = Column(JSON, nullable=True) # { category: "news"|"finance", topics: [...], selected_topic: {...} }
|
||||
|
||||
# UI state
|
||||
show_script_editor = Column(Boolean, default=False)
|
||||
show_render_queue = Column(Boolean, default=False)
|
||||
|
||||
@@ -80,7 +80,6 @@ class SubscriptionPlan(Base):
|
||||
video_calls_limit = Column(Integer, default=0) # AI video generation
|
||||
image_edit_calls_limit = Column(Integer, default=0) # AI image editing
|
||||
audio_calls_limit = Column(Integer, default=0) # AI audio generation (text-to-speech)
|
||||
wavespeed_calls_limit = Column(Integer, default=0) # WaveSpeed API calls (LLM + TTS + video + image)
|
||||
|
||||
# Token Limits (for LLM providers)
|
||||
gemini_tokens_limit = Column(Integer, default=0)
|
||||
|
||||
@@ -11,30 +11,17 @@ echo "📦 Checking ALWRITY_ENABLED_FEATURES..."
|
||||
ENABLED_FEATURES="${ALWRITY_ENABLED_FEATURES:-all}"
|
||||
echo "DEBUG: ENABLED_FEATURES='$ENABLED_FEATURES'"
|
||||
|
||||
case "$ENABLED_FEATURES" in
|
||||
all)
|
||||
echo "📦 Full mode: Installing all requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
# Download spaCy/NLTK models for full mode
|
||||
echo "🧠 Installing spaCy and NLTK models..."
|
||||
python -m spacy download en_core_web_sm
|
||||
python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
|
||||
;;
|
||||
podcast)
|
||||
echo "🔊 Podcast-only mode: Installing lean requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements-podcast.txt --only-binary :all: --retries 10 --timeout 120
|
||||
;;
|
||||
*)
|
||||
echo "🎯 Feature-limited mode ($ENABLED_FEATURES): Installing requirements..."
|
||||
req_file="requirements-${ENABLED_FEATURES}.txt"
|
||||
if [[ -f "$req_file" ]]; then
|
||||
python -m pip install --no-cache-dir -r "$req_file" --only-binary :all: --retries 10 --timeout 120
|
||||
else
|
||||
echo "⚠️ No feature-specific requirements file found ($req_file), installing full requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
if [[ "$ENABLED_FEATURES" == "podcast" ]]; then
|
||||
echo "🔊 Podcast-only mode: Installing lean requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements-podcast.txt --only-binary :all: --retries 10 --timeout 120
|
||||
else
|
||||
echo "📦 Full mode: Installing all requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
# Download spaCy/NLTK models for full mode
|
||||
echo "🧠 Installing spaCy and NLTK models..."
|
||||
python -m spacy download en_core_web_sm
|
||||
python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
|
||||
fi
|
||||
|
||||
# 3. Clean up unnecessary build artifacts
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
|
||||
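The removed `case` statement chose a requirements file from `ALWRITY_ENABLED_FEATURES`, falling back to the full file when no feature-specific file exists. The same decision can be sketched in Python. `pick_requirements_file` and its `exists` callback are hypothetical; the callback stands in for the shell's `[[ -f ... ]]` test so the logic can be exercised without touching the filesystem.

```python
from typing import Callable


def pick_requirements_file(enabled_features: str, exists: Callable[[str], bool]) -> str:
    """Mirror the removed shell `case`: all, podcast, then a per-feature fallback."""
    if enabled_features == "all":
        return "requirements.txt"
    if enabled_features == "podcast":
        # The script installs this file without checking that it exists.
        return "requirements-podcast.txt"
    candidate = f"requirements-{enabled_features}.txt"
    # Fall back to the full requirements when no feature-specific file is present.
    return candidate if exists(candidate) else "requirements.txt"
```

The replacement `if`/`else` in the diff collapses this to two outcomes: `podcast` uses the lean file and every other value installs the full requirements.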
@@ -47,7 +47,6 @@ pandas>=2.0.0
|
||||
|
||||
# Image/media for podcast
|
||||
Pillow>=10.0.0
|
||||
matplotlib>=3.7.0
|
||||
huggingface_hub>=1.1.4
|
||||
|
||||
# TTS for podcast
|
||||
|
||||
@@ -45,7 +45,6 @@ numpy>=1.24.0
|
||||
|
||||
# Image/media for podcast
|
||||
Pillow>=10.0.0
|
||||
matplotlib>=3.8.0
|
||||
huggingface_hub>=1.1.4
|
||||
|
||||
# TTS for podcast
|
||||
|
||||
1620
backend/routers/image_studio.py
Normal file
File diff suppressed because it is too large
@@ -1,34 +0,0 @@
|
||||
"""Image Studio API router package.
|
||||
|
||||
Composed from modular sub-routers. Same prefix and tags as the original monolithic file.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .health import router as health_router
|
||||
from .upscale import router as upscale_router
|
||||
from .control import router as control_router
|
||||
from .social import router as social_router
|
||||
from .edit import router as edit_router
|
||||
from .face_swap import router as face_swap_router
|
||||
from .create import router as create_router
|
||||
from .transform import router as transform_router
|
||||
from .compress import router as compress_router
|
||||
from .convert import router as convert_router
|
||||
from .save import router as save_router
|
||||
|
||||
router = APIRouter(prefix="/api/image-studio", tags=["image-studio"])
|
||||
|
||||
router.include_router(health_router)
|
||||
router.include_router(upscale_router)
|
||||
router.include_router(control_router)
|
||||
router.include_router(social_router)
|
||||
router.include_router(edit_router)
|
||||
router.include_router(face_swap_router)
|
||||
router.include_router(create_router)
|
||||
router.include_router(transform_router)
|
||||
router.include_router(compress_router)
|
||||
router.include_router(convert_router)
|
||||
router.include_router(save_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
@@ -1,158 +0,0 @@
|
||||
"""Compression Studio endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import (
|
||||
CompressImageRequest, CompressImageResponse,
|
||||
CompressBatchRequest, CompressBatchResponse,
|
||||
CompressionEstimateRequest, CompressionEstimateResponse,
|
||||
CompressionFormatsResponse, CompressionPresetsResponse,
|
||||
)
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/compress", response_model=CompressImageResponse, summary="Compress an image")
|
||||
async def compress_image(
|
||||
request: CompressImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Compress an image with specified quality and format settings."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image compression")
|
||||
logger.info(f"[Compression] Request from user {user_id}: format={request.format}, quality={request.quality}")
|
||||
|
||||
from services.image_studio.compression_service import CompressionRequest as ServiceRequest
|
||||
|
||||
compression_request = ServiceRequest(
|
||||
image_base64=request.image_base64,
|
||||
quality=request.quality,
|
||||
format=request.format,
|
||||
target_size_kb=request.target_size_kb,
|
||||
strip_metadata=request.strip_metadata,
|
||||
progressive=request.progressive,
|
||||
optimize=request.optimize,
|
||||
)
|
||||
|
||||
result = await studio_manager.compress_image(compression_request, user_id=user_id)
|
||||
|
||||
return CompressImageResponse(
|
||||
success=result.success,
|
||||
image_base64=result.image_base64,
|
||||
original_size_kb=result.original_size_kb,
|
||||
compressed_size_kb=result.compressed_size_kb,
|
||||
compression_ratio=result.compression_ratio,
|
||||
format=result.format,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
quality_used=result.quality_used,
|
||||
metadata_stripped=result.metadata_stripped,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image compression failed: {e}")
|
||||
|
||||
|
||||
@router.post("/compress/batch", response_model=CompressBatchResponse, summary="Compress multiple images")
|
||||
async def compress_batch(
|
||||
request: CompressBatchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Compress multiple images with the same or individual settings."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "batch compression")
|
||||
logger.info(f"[Compression] Batch request from user {user_id}: {len(request.images)} images")
|
||||
|
||||
from services.image_studio.compression_service import CompressionRequest as ServiceRequest
|
||||
|
||||
compression_requests = [
|
||||
ServiceRequest(
|
||||
image_base64=img.image_base64,
|
||||
quality=img.quality,
|
||||
format=img.format,
|
||||
target_size_kb=img.target_size_kb,
|
||||
strip_metadata=img.strip_metadata,
|
||||
progressive=img.progressive,
|
||||
optimize=img.optimize,
|
||||
)
|
||||
for img in request.images
|
||||
]
|
||||
|
||||
results = await studio_manager.compress_batch(compression_requests, user_id=user_id)
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = len(results) - successful
|
||||
|
||||
return CompressBatchResponse(
|
||||
success=failed == 0,
|
||||
results=[
|
||||
CompressImageResponse(
|
||||
success=r.success,
|
||||
image_base64=r.image_base64,
|
||||
original_size_kb=r.original_size_kb,
|
||||
compressed_size_kb=r.compressed_size_kb,
|
||||
compression_ratio=r.compression_ratio,
|
||||
format=r.format,
|
||||
width=r.width,
|
||||
height=r.height,
|
||||
quality_used=r.quality_used,
|
||||
metadata_stripped=r.metadata_stripped,
|
||||
)
|
||||
for r in results
|
||||
],
|
||||
total_images=len(results),
|
||||
successful=successful,
|
||||
failed=failed,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] ❌ Batch error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Batch compression failed: {e}")
|
||||
|
||||
|
||||
@router.post("/compress/estimate", response_model=CompressionEstimateResponse, summary="Estimate compression results")
|
||||
async def estimate_compression(
|
||||
request: CompressionEstimateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Estimate compression results without actually compressing the image."""
|
||||
try:
|
||||
result = await studio_manager.estimate_compression(
|
||||
request.image_base64,
|
||||
request.format,
|
||||
request.quality,
|
||||
)
|
||||
return CompressionEstimateResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] ❌ Estimate error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Compression estimation failed: {e}")
|
||||
|
||||
|
||||
@router.get("/compress/formats", response_model=CompressionFormatsResponse, summary="Get supported compression formats")
|
||||
async def get_compression_formats(
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get list of supported compression formats with their capabilities."""
|
||||
formats = studio_manager.get_compression_formats()
|
||||
return CompressionFormatsResponse(formats=formats)
|
||||
|
||||
|
||||
@router.get("/compress/presets", response_model=CompressionPresetsResponse, summary="Get compression presets")
|
||||
async def get_compression_presets(
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get predefined compression presets for common use cases."""
|
||||
presets = studio_manager.get_compression_presets()
|
||||
return CompressionPresetsResponse(presets=presets)
|
||||
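The batch endpoint in this removed file derives its top-level `success` flag from per-image results. A minimal sketch of that aggregation, using a hypothetical `Result` dataclass in place of the service's result type:

```python
from dataclasses import dataclass
from typing import Any, Dict, List


@dataclass
class Result:
    """Hypothetical stand-in for a per-image compression result."""
    success: bool


def summarize_batch(results: List[Result]) -> Dict[str, Any]:
    """Count successes and failures the way compress_batch does."""
    successful = sum(1 for r in results if r.success)
    failed = len(results) - successful
    return {
        "success": failed == 0,  # True only when no image failed
        "total_images": len(results),
        "successful": successful,
        "failed": failed,
    }
```

One consequence of `failed == 0` is that an empty batch reports `success` as true, so a caller may want to validate that `images` is non-empty before calling the endpoint.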
@@ -1,64 +0,0 @@
|
||||
"""Control Studio endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import ControlImageRequest, ControlImageResponse, ControlOperationsResponse
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, ControlStudioRequest
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/control/process", response_model=ControlImageResponse, summary="Process Control Studio request")
|
||||
async def process_control_image(
|
||||
request: ControlImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Perform Control Studio operations such as sketch-to-image, structure control, style control, and style transfer."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image control")
|
||||
logger.info(f"[Control Image] Request from user {user_id}: operation={request.operation}")
|
||||
|
||||
control_request = ControlStudioRequest(
|
||||
operation=request.operation,
|
||||
prompt=request.prompt,
|
||||
control_image_base64=request.control_image_base64,
|
||||
style_image_base64=request.style_image_base64,
|
||||
negative_prompt=request.negative_prompt,
|
||||
control_strength=request.control_strength,
|
||||
fidelity=request.fidelity,
|
||||
style_strength=request.style_strength,
|
||||
composition_fidelity=request.composition_fidelity,
|
||||
change_strength=request.change_strength,
|
||||
aspect_ratio=request.aspect_ratio,
|
||||
style_preset=request.style_preset,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
|
||||
result = await studio_manager.control_image(control_request, user_id=user_id)
|
||||
return ControlImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Control Image] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image control failed: {e}")
|
||||
|
||||
|
||||
@router.get("/control/operations", response_model=ControlOperationsResponse, summary="List Control Studio operations")
|
||||
async def get_control_operations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Return metadata for supported Control Studio operations."""
|
||||
try:
|
||||
operations = studio_manager.get_control_operations()
|
||||
return ControlOperationsResponse(operations=operations)
|
||||
except Exception as e:
|
||||
logger.error(f"[Control Operations] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to load control operations")
|
||||
@@ -1,143 +0,0 @@
"""Format Converter endpoints."""

from typing import Dict, Any
from fastapi import APIRouter, Depends, HTTPException, Query

from .models import (
    ConvertFormatRequest, ConvertFormatResponse,
    ConvertFormatBatchRequest, ConvertFormatBatchResponse,
    SupportedFormatsResponse, FormatRecommendationsResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/convert-format", response_model=ConvertFormatResponse, summary="Convert image format")
async def convert_format(
    request: ConvertFormatRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Convert an image to a different format."""
    try:
        user_id = _require_user_id(current_user, "format conversion")
        logger.info(f"[Format Converter] Request from user {user_id}: {request.target_format}")

        from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest

        conversion_request = ServiceRequest(
            image_base64=request.image_base64,
            target_format=request.target_format,
            preserve_transparency=request.preserve_transparency,
            quality=request.quality,
            color_space=request.color_space,
            strip_metadata=request.strip_metadata,
            optimize=request.optimize,
            progressive=request.progressive,
        )

        result = await studio_manager.convert_format(conversion_request, user_id=user_id)

        return ConvertFormatResponse(
            success=result.success,
            image_base64=result.image_base64,
            original_format=result.original_format,
            target_format=result.target_format,
            original_size_kb=result.original_size_kb,
            converted_size_kb=result.converted_size_kb,
            width=result.width,
            height=result.height,
            transparency_preserved=result.transparency_preserved,
            metadata_preserved=result.metadata_preserved,
            color_space=result.color_space,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Format Converter] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Format conversion failed: {e}")


@router.post("/convert-format/batch", response_model=ConvertFormatBatchResponse, summary="Convert multiple images")
async def convert_format_batch(
    request: ConvertFormatBatchRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Convert multiple images to different formats."""
    try:
        user_id = _require_user_id(current_user, "batch format conversion")
        logger.info(f"[Format Converter] Batch request from user {user_id}: {len(request.images)} images")

        from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest

        conversion_requests = [
            ServiceRequest(
                image_base64=img.image_base64,
                target_format=img.target_format,
                preserve_transparency=img.preserve_transparency,
                quality=img.quality,
                color_space=img.color_space,
                strip_metadata=img.strip_metadata,
                optimize=img.optimize,
                progressive=img.progressive,
            )
            for img in request.images
        ]

        results = await studio_manager.convert_format_batch(conversion_requests, user_id=user_id)

        successful = sum(1 for r in results if r.success)
        failed = len(results) - successful

        return ConvertFormatBatchResponse(
            success=failed == 0,
            results=[
                ConvertFormatResponse(
                    success=r.success,
                    image_base64=r.image_base64,
                    original_format=r.original_format,
                    target_format=r.target_format,
                    original_size_kb=r.original_size_kb,
                    converted_size_kb=r.converted_size_kb,
                    width=r.width,
                    height=r.height,
                    transparency_preserved=r.transparency_preserved,
                    metadata_preserved=r.metadata_preserved,
                    color_space=r.color_space,
                )
                for r in results
            ],
            total_images=len(results),
            successful=successful,
            failed=failed,
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Format Converter] ❌ Batch error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Batch format conversion failed: {e}")


@router.get("/convert-format/supported", response_model=SupportedFormatsResponse, summary="Get supported formats")
async def get_supported_formats(
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get list of supported conversion formats with their capabilities."""
    formats = studio_manager.get_supported_formats()
    return SupportedFormatsResponse(formats=formats)


@router.get("/convert-format/recommendations", response_model=FormatRecommendationsResponse, summary="Get format recommendations")
async def get_format_recommendations(
    source_format: str = Query(..., description="Source format"),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get format recommendations based on source format."""
    recommendations = studio_manager.get_format_recommendations(source_format)
    return FormatRecommendationsResponse(recommendations=recommendations)
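The batch endpoint above derives its summary fields from the per-image results alone. The counting rule can be sketched in isolation as follows; `ConversionResult` and `summarize_batch` are hypothetical names introduced only for this illustration, and the real result objects carry many more fields.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class ConversionResult:
    """Hypothetical stand-in for one per-image result; only `success` matters here."""

    success: bool


def summarize_batch(results: List[ConversionResult]) -> Dict[str, object]:
    """Apply the same counting as convert_format_batch to a list of results."""
    successful = sum(1 for r in results if r.success)
    failed = len(results) - successful
    return {
        "success": failed == 0,  # a single failed image marks the whole batch unsuccessful
        "total_images": len(results),
        "successful": successful,
        "failed": failed,
    }


summary = summarize_batch(
    [ConversionResult(True), ConversionResult(False), ConversionResult(True)]
)
print(summary)
```

One consequence of `failed == 0` is that an empty batch also reports `success=True`, which callers may want to treat as a separate case.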
@@ -1,231 +0,0 @@
"""Create Studio, Templates, Providers, Cost Estimation, and Platform Specs endpoints."""

import base64
from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException

from .models import CreateImageRequest, CostEstimationRequest
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager, CreateStudioRequest
from services.image_studio.templates import Platform, TemplateCategory
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/create", summary="Generate Image")
async def create_image(
    request: CreateImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Generate image(s) using Create Studio."""
    try:
        user_id = _require_user_id(current_user, "image generation")
        logger.info(f"[Create Image] Request from user {user_id}: {request.prompt[:100]}")

        studio_request = CreateStudioRequest(
            prompt=request.prompt,
            template_id=request.template_id,
            provider=request.provider,
            model=request.model,
            width=request.width,
            height=request.height,
            aspect_ratio=request.aspect_ratio,
            style_preset=request.style_preset,
            quality=request.quality,
            negative_prompt=request.negative_prompt,
            guidance_scale=request.guidance_scale,
            steps=request.steps,
            seed=request.seed,
            num_variations=request.num_variations,
            enhance_prompt=request.enhance_prompt,
            use_persona=request.use_persona,
            persona_id=request.persona_id,
        )

        result = await studio_manager.create_image(studio_request, user_id=user_id)

        for idx, img_result in enumerate(result["results"]):
            if "image_bytes" in img_result:
                img_result["image_base64"] = base64.b64encode(img_result["image_bytes"]).decode("utf-8")
                del img_result["image_bytes"]

        logger.info(f"[Create Image] ✅ Success: {result['total_generated']} images generated")
        return result

    except ValueError as e:
        logger.error(f"[Create Image] ❌ Validation error: {str(e)}")
        raise HTTPException(status_code=400, detail=str(e))
    except RuntimeError as e:
        logger.error(f"[Create Image] ❌ Generation error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
    except Exception as e:
        logger.error(f"[Create Image] ❌ Unexpected error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")


@router.get("/templates", summary="Get Templates")
async def get_templates(
    platform: Optional[Platform] = None,
    category: Optional[TemplateCategory] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get available image templates."""
    try:
        templates = studio_manager.get_templates(platform=platform, category=category)
        templates_dict = [
            {
                "id": t.id,
                "name": t.name,
                "category": t.category.value,
                "platform": t.platform.value if t.platform else None,
                "aspect_ratio": {
                    "ratio": t.aspect_ratio.ratio,
                    "width": t.aspect_ratio.width,
                    "height": t.aspect_ratio.height,
                    "label": t.aspect_ratio.label,
                },
                "description": t.description,
                "recommended_provider": t.recommended_provider,
                "style_preset": t.style_preset,
                "quality": t.quality,
                "use_cases": t.use_cases or [],
            }
            for t in templates
        ]
        return {"templates": templates_dict, "total": len(templates_dict)}
    except Exception as e:
        logger.error(f"[Get Templates] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/templates/search", summary="Search Templates")
async def search_templates(
    query: str,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Search templates by query."""
    try:
        templates = studio_manager.search_templates(query)
        templates_dict = [
            {
                "id": t.id,
                "name": t.name,
                "category": t.category.value,
                "platform": t.platform.value if t.platform else None,
                "aspect_ratio": {
                    "ratio": t.aspect_ratio.ratio,
                    "width": t.aspect_ratio.width,
                    "height": t.aspect_ratio.height,
                    "label": t.aspect_ratio.label,
                },
                "description": t.description,
                "recommended_provider": t.recommended_provider,
                "style_preset": t.style_preset,
                "quality": t.quality,
                "use_cases": t.use_cases or [],
            }
            for t in templates
        ]
        return {"templates": templates_dict, "total": len(templates_dict), "query": query}
    except Exception as e:
        logger.error(f"[Search Templates] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/templates/recommend", summary="Recommend Templates")
async def recommend_templates(
    use_case: str,
    platform: Optional[Platform] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Recommend templates based on use case."""
    try:
        templates = studio_manager.recommend_templates(use_case, platform=platform)
        templates_dict = [
            {
                "id": t.id,
                "name": t.name,
                "category": t.category.value,
                "platform": t.platform.value if t.platform else None,
                "aspect_ratio": {
                    "ratio": t.aspect_ratio.ratio,
                    "width": t.aspect_ratio.width,
                    "height": t.aspect_ratio.height,
                    "label": t.aspect_ratio.label,
                },
                "description": t.description,
                "recommended_provider": t.recommended_provider,
                "style_preset": t.style_preset,
                "quality": t.quality,
                "use_cases": t.use_cases or [],
            }
            for t in templates
        ]
        return {"templates": templates_dict, "total": len(templates_dict), "use_case": use_case}
    except Exception as e:
        logger.error(f"[Recommend Templates] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/providers", summary="Get Providers")
async def get_providers(
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get available AI providers and their capabilities."""
    try:
        providers = studio_manager.get_providers()
        return {"providers": providers}
    except Exception as e:
        logger.error(f"[Get Providers] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/estimate-cost", summary="Estimate Cost")
async def estimate_cost(
    request: CostEstimationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Estimate cost for image generation operations."""
    try:
        resolution = None
        if request.width and request.height:
            resolution = (request.width, request.height)
        estimate = studio_manager.estimate_cost(
            provider=request.provider,
            model=request.model,
            operation=request.operation,
            num_images=request.num_images,
            resolution=resolution
        )
        return estimate
    except Exception as e:
        logger.error(f"[Estimate Cost] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))


@router.get("/platform-specs/{platform}", summary="Get Platform Specifications")
async def get_platform_specs(
    platform: Platform,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager)
):
    """Get specifications and requirements for a specific platform."""
    try:
        specs = studio_manager.get_platform_specs(platform)
        if not specs:
            raise HTTPException(status_code=404, detail=f"Specifications not found for platform: {platform}")
        return specs
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Get Platform Specs] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
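The three template endpoints above (`/templates`, `/templates/search`, `/templates/recommend`) build the same dictionary for every template. One way that mapping could be expressed once is sketched below. Every class here (`Category`, `PlatformName`, `AspectRatio`, `Template`) and the helper `template_to_dict` are hypothetical stand-ins with assumed enum values; the removed module did not define such a helper.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, List, Optional


class Category(Enum):
    """Hypothetical stand-in for TemplateCategory; the value is assumed."""

    SOCIAL = "social"


class PlatformName(Enum):
    """Hypothetical stand-in for Platform; the value is assumed."""

    INSTAGRAM = "instagram"


@dataclass
class AspectRatio:
    ratio: str
    width: int
    height: int
    label: str


@dataclass
class Template:
    id: str
    name: str
    category: Category
    aspect_ratio: AspectRatio
    description: str
    recommended_provider: str
    style_preset: str
    quality: str
    platform: Optional[PlatformName] = None
    use_cases: Optional[List[str]] = None


def template_to_dict(t: Template) -> Dict[str, Any]:
    """Serialize one template with the same keys the endpoints return."""
    return {
        "id": t.id,
        "name": t.name,
        "category": t.category.value,
        "platform": t.platform.value if t.platform else None,  # platform is optional
        "aspect_ratio": {
            "ratio": t.aspect_ratio.ratio,
            "width": t.aspect_ratio.width,
            "height": t.aspect_ratio.height,
            "label": t.aspect_ratio.label,
        },
        "description": t.description,
        "recommended_provider": t.recommended_provider,
        "style_preset": t.style_preset,
        "quality": t.quality,
        "use_cases": t.use_cases or [],  # None becomes an empty list
    }


example = Template(
    id="t1",
    name="Square post",
    category=Category.SOCIAL,
    aspect_ratio=AspectRatio("1:1", 1080, 1080, "Square"),
    description="Illustrative template",
    recommended_provider="auto",
    style_preset="photographic",
    quality="standard",
)
serialized = template_to_dict(example)
```

With a helper of this kind, each endpoint would reduce to a list comprehension over `template_to_dict(t)` plus its own extra keys (`query` or `use_case`).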
@@ -1,35 +0,0 @@
"""Shared dependencies for Image Studio API endpoints."""

from typing import Dict, Any
from fastapi import Depends, HTTPException, status

from services.image_studio import ImageStudioManager
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")


def get_studio_manager() -> ImageStudioManager:
    """Get Image Studio Manager instance."""
    return ImageStudioManager()


def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
    """Ensure user_id is available for protected operations."""
    user_id = (
        current_user.get("sub")
        or current_user.get("user_id")
        or current_user.get("id")
        or current_user.get("clerk_user_id")
    )
    if not user_id:
        logger.error(
            "[Image Studio] ❌ Missing user_id for %s operation - blocking request",
            operation,
        )
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authenticated user required for image operations.",
        )
    return user_id
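`_require_user_id` above resolves the user ID through a chain of `or` lookups, so the first key holding a truthy value wins. The precedence can be sketched without FastAPI as below; `USER_ID_KEYS` and `resolve_user_id` are hypothetical names for this illustration, and the sketch returns `None` where the original raises a 401.

```python
from typing import Any, Dict, Optional

# Key order copied from the `or` chain in _require_user_id.
USER_ID_KEYS = ("sub", "user_id", "id", "clerk_user_id")


def resolve_user_id(current_user: Dict[str, Any]) -> Optional[str]:
    """Return the first truthy ID among the known keys, or None if all are missing or empty."""
    for key in USER_ID_KEYS:
        value = current_user.get(key)
        if value:  # empty strings and None fall through to the next key
            return value
    return None
```

For example, a payload with both `sub` and `id` set resolves to the `sub` value, and an empty `sub` falls through to whichever later key is populated.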
@@ -1,122 +0,0 @@
"""Edit Studio endpoints."""

from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException

from .models import (
    EditImageRequest, EditImageResponse, EditOperationsResponse,
    EditModelsResponse, EditModelRecommendationRequest, EditModelRecommendationResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager, EditStudioRequest
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/edit/process", response_model=EditImageResponse, summary="Process Edit Studio request")
async def process_edit_image(
    request: EditImageRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Perform Edit Studio operations such as remove background, inpaint, or recolor."""
    try:
        user_id = _require_user_id(current_user, "image editing")
        logger.info(f"[Edit Image] Request from user {user_id}: operation={request.operation}")

        edit_request = EditStudioRequest(
            image_base64=request.image_base64,
            operation=request.operation,
            prompt=request.prompt,
            negative_prompt=request.negative_prompt,
            mask_base64=request.mask_base64,
            search_prompt=request.search_prompt,
            select_prompt=request.select_prompt,
            background_image_base64=request.background_image_base64,
            lighting_image_base64=request.lighting_image_base64,
            expand_left=request.expand_left,
            expand_right=request.expand_right,
            expand_up=request.expand_up,
            expand_down=request.expand_down,
            provider=request.provider,
            model=request.model,
            style_preset=request.style_preset,
            guidance_scale=request.guidance_scale,
            steps=request.steps,
            seed=request.seed,
            output_format=request.output_format,
            options=request.options or {},
        )

        result = await studio_manager.edit_image(edit_request, user_id=user_id)
        return EditImageResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Edit Image] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Image editing failed: {e}")


@router.get("/edit/operations", response_model=EditOperationsResponse, summary="List Edit Studio operations")
async def get_edit_operations(
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Return metadata for supported Edit Studio operations."""
    try:
        operations = studio_manager.get_edit_operations()
        return EditOperationsResponse(operations=operations)
    except Exception as e:
        logger.error(f"[Edit Operations] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load edit operations")


@router.get("/edit/models", response_model=EditModelsResponse, summary="List available editing models")
async def get_edit_models(
    operation: Optional[str] = None,
    tier: Optional[str] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get available WaveSpeed editing models with metadata.

    Query Parameters:
    - operation: Filter by operation type (e.g., "general_edit")
    - tier: Filter by tier ("budget", "mid", "premium")
    """
    try:
        result = studio_manager.get_edit_models(operation=operation, tier=tier)
        return EditModelsResponse(**result)
    except Exception as e:
        logger.error(f"[Edit Models] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load editing models")


@router.post("/edit/recommend", response_model=EditModelRecommendationResponse, summary="Get model recommendation")
async def recommend_edit_model(
    request: EditModelRecommendationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get recommended editing model based on operation, image resolution, and user preferences.

    Auto-detects best model when user doesn't specify one.
    """
    try:
        user_tier = request.user_tier
        if not user_tier and current_user:
            user_tier = current_user.get("tier") or current_user.get("subscription_tier")

        result = studio_manager.recommend_edit_model(
            operation=request.operation,
            image_resolution=request.image_resolution,
            user_tier=user_tier,
            preferences=request.preferences,
        )
        return EditModelRecommendationResponse(**result)
    except Exception as e:
        logger.error(f"[Edit Recommend] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
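`recommend_edit_model` above prefers the tier sent in the request and only falls back to the authenticated user's record when the request leaves it empty. That fallback can be sketched in isolation as follows; `resolve_user_tier` is a hypothetical helper name, not a function from the removed module.

```python
from typing import Any, Dict, Optional


def resolve_user_tier(
    requested_tier: Optional[str],
    current_user: Optional[Dict[str, Any]],
) -> Optional[str]:
    """Pick the tier from the request first, then from the user's `tier` or `subscription_tier`."""
    user_tier = requested_tier
    if not user_tier and current_user:
        # `tier` takes precedence over `subscription_tier` when both are present.
        user_tier = current_user.get("tier") or current_user.get("subscription_tier")
    return user_tier
```

Under this rule an explicit request value always overrides the stored subscription, so any entitlement check would have to happen downstream of this step.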
@@ -1,89 +0,0 @@
"""Face Swap Studio endpoints."""

from typing import Dict, Any, Optional
from fastapi import APIRouter, Depends, HTTPException

from .models import (
    FaceSwapRequest, FaceSwapResponse, FaceSwapModelsResponse,
    FaceSwapModelRecommendationRequest, FaceSwapModelRecommendationResponse,
)
from .deps import get_studio_manager, _require_user_id
from services.image_studio import ImageStudioManager
from services.image_studio.face_swap_service import FaceSwapStudioRequest
from middleware.auth_middleware import get_current_user
from utils.logger_utils import get_service_logger

logger = get_service_logger("api.image_studio")
router = APIRouter(tags=["image-studio"])


@router.post("/face-swap/process", response_model=FaceSwapResponse, summary="Process Face Swap")
async def process_face_swap(
    request: FaceSwapRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Process face swap request with auto-detection and model selection."""
    try:
        user_id = _require_user_id(current_user, "face swap")
        face_swap_request = FaceSwapStudioRequest(
            base_image_base64=request.base_image_base64,
            face_image_base64=request.face_image_base64,
            model=request.model,
            target_face_index=request.target_face_index,
            target_gender=request.target_gender,
            options=request.options,
        )
        result = await studio_manager.face_swap(face_swap_request, user_id=user_id)
        return FaceSwapResponse(**result)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[Face Swap] ❌ Error: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Face swap failed: {e}")


@router.get("/face-swap/models", response_model=FaceSwapModelsResponse, summary="List available face swap models")
async def get_face_swap_models(
    tier: Optional[str] = None,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get available WaveSpeed face swap models with metadata.

    Query Parameters:
    - tier: Filter by tier ("budget", "mid", "premium")
    """
    try:
        result = studio_manager.get_face_swap_models(tier=tier)
        return FaceSwapModelsResponse(**result)
    except Exception as e:
        logger.error(f"[Face Swap Models] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail="Failed to load face swap models")


@router.post("/face-swap/recommend", response_model=FaceSwapModelRecommendationResponse, summary="Get face swap model recommendation")
async def recommend_face_swap_model(
    request: FaceSwapModelRecommendationRequest,
    current_user: Dict[str, Any] = Depends(get_current_user),
    studio_manager: ImageStudioManager = Depends(get_studio_manager),
):
    """Get recommended face swap model based on image resolutions and user preferences.

    Auto-detects best model when user doesn't specify one.
    """
    try:
        user_tier = request.user_tier
        if not user_tier and current_user:
            user_tier = current_user.get("tier") or current_user.get("subscription_tier")

        result = studio_manager.recommend_face_swap_model(
            base_image_resolution=request.base_image_resolution,
            face_image_resolution=request.face_image_resolution,
            user_tier=user_tier,
            preferences=request.preferences,
        )
        return FaceSwapModelRecommendationResponse(**result)
    except Exception as e:
        logger.error(f"[Face Swap Recommend] ❌ Error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
@@ -1,21 +0,0 @@
"""Health check endpoint."""

from fastapi import APIRouter

router = APIRouter(tags=["image-studio"])


@router.get("/health", summary="Health Check")
async def health_check():
    """Health check endpoint for Image Studio."""
    return {
        "status": "healthy",
        "service": "image_studio",
        "version": "1.0.0",
        "modules": {
            "create_studio": "available",
            "templates": "available",
            "providers": "available",
            "compression": "available",
        }
    }
@@ -1,372 +0,0 @@
"""Pydantic request/response models for Image Studio API."""

from typing import Optional, List, Dict, Any, Literal
from pydantic import BaseModel, Field


# ==================== Create Studio ====================

class CreateImageRequest(BaseModel):
    prompt: str = Field(..., description="Image generation prompt")
    template_id: Optional[str] = Field(None, description="Template ID to use")
    provider: Optional[str] = Field("auto", description="Provider: auto, stability, wavespeed, huggingface, gemini")
    model: Optional[str] = Field(None, description="Specific model to use")
    width: Optional[int] = Field(None, description="Image width in pixels")
    height: Optional[int] = Field(None, description="Image height in pixels")
    aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (e.g., '1:1', '16:9')")
    style_preset: Optional[str] = Field(None, description="Style preset")
    quality: str = Field("standard", description="Quality: draft, standard, premium")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt")
    guidance_scale: Optional[float] = Field(None, description="Guidance scale")
    steps: Optional[int] = Field(None, description="Number of inference steps")
    seed: Optional[int] = Field(None, description="Random seed")
    num_variations: int = Field(1, ge=1, le=10, description="Number of variations (1-10)")
    enhance_prompt: bool = Field(True, description="Enhance prompt with AI")
    use_persona: bool = Field(False, description="Use persona for brand consistency")
    persona_id: Optional[str] = Field(None, description="Persona ID")


class CostEstimationRequest(BaseModel):
    provider: str = Field(..., description="Provider name")
    model: Optional[str] = Field(None, description="Model name")
    operation: str = Field("generate", description="Operation type")
    num_images: int = Field(1, ge=1, description="Number of images")
    width: Optional[int] = Field(None, description="Image width")
    height: Optional[int] = Field(None, description="Image height")


# ==================== Edit Studio ====================

class EditImageRequest(BaseModel):
    image_base64: str = Field(..., description="Primary image payload (base64 or data URL)")
    operation: Literal[
        "remove_background",
        "inpaint",
        "outpaint",
        "search_replace",
        "search_recolor",
        "general_edit",
    ] = Field(..., description="Edit operation to perform")
    prompt: Optional[str] = Field(None, description="Primary prompt/instruction")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt for providers that support it")
    mask_base64: Optional[str] = Field(None, description="Optional mask image in base64")
    search_prompt: Optional[str] = Field(None, description="Search prompt for replace operations")
    select_prompt: Optional[str] = Field(None, description="Select prompt for recolor operations")
    background_image_base64: Optional[str] = Field(None, description="Reference background image")
    lighting_image_base64: Optional[str] = Field(None, description="Reference lighting image")
    expand_left: Optional[int] = Field(0, description="Outpaint expansion in pixels (left)")
    expand_right: Optional[int] = Field(0, description="Outpaint expansion in pixels (right)")
    expand_up: Optional[int] = Field(0, description="Outpaint expansion in pixels (up)")
    expand_down: Optional[int] = Field(0, description="Outpaint expansion in pixels (down)")
    provider: Optional[str] = Field(None, description="Explicit provider override")
    model: Optional[str] = Field(None, description="Explicit model override")
    style_preset: Optional[str] = Field(None, description="Style preset for Stability helpers")
    guidance_scale: Optional[float] = Field(None, description="Guidance scale for general edits")
    steps: Optional[int] = Field(None, description="Inference steps")
    seed: Optional[int] = Field(None, description="Random seed for reproducibility")
    output_format: str = Field("png", description="Output format for edited image")
    options: Optional[Dict[str, Any]] = Field(None, description="Advanced provider-specific options (e.g., grow_mask)")


class EditImageResponse(BaseModel):
    success: bool
    operation: str
    provider: str
    image_base64: str
    width: int
    height: int
    metadata: Dict[str, Any]


class EditOperationsResponse(BaseModel):
    operations: Dict[str, Dict[str, Any]]


class EditModelsResponse(BaseModel):
    models: List[Dict[str, Any]]
    total: int


class EditModelRecommendationRequest(BaseModel):
    operation: str
    image_resolution: Optional[Dict[str, int]] = None
    user_tier: Optional[str] = None
    preferences: Optional[Dict[str, Any]] = None


class EditModelRecommendationResponse(BaseModel):
    recommended_model: str
    reason: str
    alternatives: List[Dict[str, Any]]


# ==================== Face Swap Studio ====================

class FaceSwapRequest(BaseModel):
    base_image_base64: str
    face_image_base64: str
    model: Optional[str] = None
    target_face_index: Optional[int] = None
    target_gender: Optional[str] = None
    options: Optional[Dict[str, Any]] = None


class FaceSwapResponse(BaseModel):
    success: bool
    image_base64: str
    width: int
    height: int
    provider: str
    model: str
    metadata: Dict[str, Any]


class FaceSwapModelsResponse(BaseModel):
    models: List[Dict[str, Any]]
    total: int


class FaceSwapModelRecommendationRequest(BaseModel):
    base_image_resolution: Optional[Dict[str, int]] = None
    face_image_resolution: Optional[Dict[str, int]] = None
    user_tier: Optional[str] = None
    preferences: Optional[Dict[str, Any]] = None


class FaceSwapModelRecommendationResponse(BaseModel):
    recommended_model: str
    reason: str
    alternatives: List[Dict[str, Any]]


# ==================== Upscale Studio ====================

class UpscaleImageRequest(BaseModel):
    image_base64: str
    mode: Literal["fast", "conservative", "creative", "auto"] = "auto"
    target_width: Optional[int] = Field(None, description="Target width in pixels")
    target_height: Optional[int] = Field(None, description="Target height in pixels")
    preset: Optional[str] = Field(None, description="Named preset (web, print, social)")
    prompt: Optional[str] = Field(None, description="Prompt for conservative/creative modes")


class UpscaleImageResponse(BaseModel):
    success: bool
    mode: str
    image_base64: str
    width: int
    height: int
    metadata: Dict[str, Any]


# ==================== Control Studio ====================

class ControlImageRequest(BaseModel):
    control_image_base64: str = Field(..., description="Control image (sketch/structure/style) in base64")
    operation: Literal["sketch", "structure", "style", "style_transfer"] = Field(..., description="Control operation")
    prompt: str = Field(..., description="Text prompt for generation")
    style_image_base64: Optional[str] = Field(None, description="Style reference image (for style_transfer only)")
    negative_prompt: Optional[str] = Field(None, description="Negative prompt")
    control_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Control strength (sketch/structure)")
|
||||
fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style fidelity (style operation)")
|
||||
style_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style strength (style_transfer)")
|
||||
composition_fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Composition fidelity (style_transfer)")
|
||||
change_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Change strength (style_transfer)")
|
||||
aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (style operation)")
|
||||
style_preset: Optional[str] = Field(None, description="Style preset")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
output_format: str = Field("png", description="Output format")
|
||||
|
||||
|
||||
class ControlImageResponse(BaseModel):
|
||||
success: bool
|
||||
operation: str
|
||||
provider: str
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class ControlOperationsResponse(BaseModel):
|
||||
operations: Dict[str, Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Social Optimizer ====================
|
||||
|
||||
class SocialOptimizeRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Source image in base64 or data URL")
|
||||
platforms: List[str] = Field(..., description="List of platforms to optimize for")
|
||||
format_names: Optional[Dict[str, str]] = Field(None, description="Specific format per platform")
|
||||
show_safe_zones: bool = Field(False, description="Include safe zone overlay in output")
|
||||
crop_mode: str = Field("smart", description="Crop mode: smart, center, or fit")
|
||||
focal_point: Optional[Dict[str, float]] = Field(None, description="Focal point for smart crop (x, y as 0-1)")
|
||||
output_format: str = Field("png", description="Output format (png or jpg)")
|
||||
|
||||
|
||||
class SocialOptimizeResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[Dict[str, Any]]
|
||||
total_optimized: int
|
||||
|
||||
|
||||
class PlatformFormatsResponse(BaseModel):
|
||||
formats: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Transform Studio ====================
|
||||
|
||||
class TransformImageToVideoRequestModel(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
prompt: str = Field(..., description="Text prompt describing the video")
|
||||
audio_base64: Optional[str] = Field(None, description="Optional audio file (wav/mp3, 3-30s, ≤15MB)")
|
||||
resolution: Literal["480p", "720p", "1080p"] = Field("720p", description="Output resolution")
|
||||
duration: Literal[5, 10] = Field(5, description="Video duration in seconds")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
|
||||
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
|
||||
enable_prompt_expansion: bool = Field(True, description="Enable prompt optimizer")
|
||||
|
||||
|
||||
class TalkingAvatarRequestModel(BaseModel):
|
||||
image_base64: str = Field(..., description="Person image in base64 or data URL")
|
||||
audio_base64: str = Field(..., description="Audio file in base64 or data URL (wav/mp3, max 10 minutes)")
|
||||
resolution: Literal["480p", "720p"] = Field("720p", description="Output resolution")
|
||||
prompt: Optional[str] = Field(None, description="Optional prompt for expression/style")
|
||||
mask_image_base64: Optional[str] = Field(None, description="Optional mask for animatable regions")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
|
||||
|
||||
class TransformVideoResponse(BaseModel):
|
||||
success: bool
|
||||
video_url: Optional[str] = None
|
||||
video_base64: Optional[str] = None
|
||||
duration: float
|
||||
resolution: str
|
||||
width: int
|
||||
height: int
|
||||
file_size: int
|
||||
cost: float
|
||||
provider: str
|
||||
model: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class TransformCostEstimateRequest(BaseModel):
|
||||
operation: Literal["image-to-video", "talking-avatar"] = Field(..., description="Operation type")
|
||||
resolution: str = Field(..., description="Output resolution")
|
||||
duration: Optional[int] = Field(None, description="Video duration in seconds (for image-to-video)")
|
||||
|
||||
|
||||
class TransformCostEstimateResponse(BaseModel):
|
||||
estimated_cost: float
|
||||
breakdown: Dict[str, Any]
|
||||
currency: str
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
|
||||
# ==================== Compression ====================
|
||||
|
||||
class CompressImageRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
quality: int = Field(85, ge=1, le=100, description="Compression quality (1-100)")
|
||||
format: str = Field("jpeg", description="Output format: jpeg, png, webp")
|
||||
target_size_kb: Optional[int] = Field(None, ge=10, description="Target file size in KB")
|
||||
strip_metadata: bool = Field(True, description="Remove EXIF metadata")
|
||||
progressive: bool = Field(True, description="Progressive JPEG encoding")
|
||||
optimize: bool = Field(True, description="Optimize encoding")
|
||||
|
||||
|
||||
class CompressImageResponse(BaseModel):
|
||||
success: bool
|
||||
image_base64: str
|
||||
original_size_kb: float
|
||||
compressed_size_kb: float
|
||||
compression_ratio: float
|
||||
format: str
|
||||
width: int
|
||||
height: int
|
||||
quality_used: int
|
||||
metadata_stripped: bool
|
||||
|
||||
|
||||
class CompressBatchRequest(BaseModel):
|
||||
images: List[CompressImageRequest] = Field(..., description="List of images to compress")
|
||||
|
||||
|
||||
class CompressBatchResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[CompressImageResponse]
|
||||
total_images: int
|
||||
successful: int
|
||||
failed: int
|
||||
|
||||
|
||||
class CompressionEstimateRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
format: str = Field("jpeg", description="Output format")
|
||||
quality: int = Field(85, ge=1, le=100, description="Quality level")
|
||||
|
||||
|
||||
class CompressionEstimateResponse(BaseModel):
|
||||
original_size_kb: float
|
||||
estimated_size_kb: float
|
||||
estimated_reduction_percent: float
|
||||
width: int
|
||||
height: int
|
||||
format: str
|
||||
|
||||
|
||||
class CompressionFormatsResponse(BaseModel):
|
||||
formats: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class CompressionPresetsResponse(BaseModel):
|
||||
presets: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Format Converter ====================
|
||||
|
||||
class ConvertFormatRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
target_format: str = Field(..., description="Target format: png, jpeg, jpg, webp, gif, bmp, tiff")
|
||||
preserve_transparency: bool = Field(True, description="Preserve transparency when possible")
|
||||
quality: Optional[int] = Field(None, ge=1, le=100, description="Quality for lossy formats (1-100)")
|
||||
color_space: Optional[str] = Field(None, description="Color space: sRGB, Adobe RGB")
|
||||
strip_metadata: bool = Field(False, description="Remove EXIF metadata")
|
||||
optimize: bool = Field(True, description="Optimize encoding")
|
||||
progressive: bool = Field(True, description="Progressive JPEG encoding")
|
||||
|
||||
|
||||
class ConvertFormatResponse(BaseModel):
|
||||
success: bool
|
||||
image_base64: str
|
||||
original_format: str
|
||||
target_format: str
|
||||
original_size_kb: float
|
||||
converted_size_kb: float
|
||||
width: int
|
||||
height: int
|
||||
transparency_preserved: bool
|
||||
metadata_preserved: bool
|
||||
color_space: Optional[str] = None
|
||||
|
||||
|
||||
class ConvertFormatBatchRequest(BaseModel):
|
||||
images: List[ConvertFormatRequest] = Field(..., description="List of images to convert")
|
||||
|
||||
|
||||
class ConvertFormatBatchResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[ConvertFormatResponse]
|
||||
total_images: int
|
||||
successful: int
|
||||
failed: int
|
||||
|
||||
|
||||
class SupportedFormatsResponse(BaseModel):
|
||||
formats: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class FormatRecommendationsResponse(BaseModel):
|
||||
recommendations: List[Dict[str, Any]]
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Save generated images to the unified asset library."""
|
||||
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .deps import _require_user_id
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.storage_paths import get_repo_root, sanitize_user_id
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
class SaveToLibraryRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Base64-encoded image (or data URL)")
|
||||
prompt: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
cost: Optional[float] = None
|
||||
operation: str = Field("image-generation", description="Operation type for labelling")
|
||||
output_format: str = Field("png", description="Output image format")
|
||||
|
||||
|
||||
@router.post("/save-to-library")
|
||||
async def save_to_library(
|
||||
req: SaveToLibraryRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Save a generated image to the asset library.
|
||||
|
||||
Decodes base64 image data, saves to workspace disk storage,
|
||||
and creates a record in the ContentAsset database table.
|
||||
"""
|
||||
user_id = _require_user_id(current_user, "save-to-library")
|
||||
|
||||
# Decode base64 payload
|
||||
try:
|
||||
b64data = req.image_base64
|
||||
if "base64," in b64data:
|
||||
b64data = b64data.split("base64,")[1]
|
||||
image_bytes = base64.b64decode(b64data)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid base64 image data")
|
||||
|
||||
# Generate file path under workspace
|
||||
safe_user = sanitize_user_id(user_id)
|
||||
repo_root = get_repo_root()
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"generated_{timestamp}.{req.output_format or 'png'}"
|
||||
|
||||
assets_dir = repo_root / "workspace" / f"workspace_{safe_user}" / "assets" / "images"
|
||||
assets_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = assets_dir / filename
|
||||
file_path.write_bytes(image_bytes)
|
||||
|
||||
# Build serving URL (assets_serving.py serves /{user_id}/avatars/{filename})
|
||||
file_url = f"/api/assets/{safe_user}/avatars/{filename}"
|
||||
|
||||
# Save to unified asset library via existing utility
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="image_studio",
|
||||
filename=filename,
|
||||
file_url=file_url,
|
||||
file_path=str(file_path),
|
||||
file_size=len(image_bytes),
|
||||
mime_type=f"image/{req.output_format or 'png'}",
|
||||
title=f"Generated Image - {timestamp}",
|
||||
prompt=req.prompt,
|
||||
provider=req.provider,
|
||||
model=req.model,
|
||||
cost=req.cost,
|
||||
)
|
||||
|
||||
if not asset_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to save to asset library")
|
||||
|
||||
logger.info(f"[Save to Library] ✅ Image saved: asset_id={asset_id}, user={user_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_id": asset_id,
|
||||
"file_url": file_url,
|
||||
"filename": filename,
|
||||
"file_size": len(image_bytes),
|
||||
}
|
||||
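The base64 handling in `save_to_library` accepts either a bare base64 string or a data URL. A minimal standalone sketch of that decoding step follows, using only the standard library. The helper name `decode_image_payload` is hypothetical and does not appear in the removed module. Passing `validate=True` is an added assumption: it makes `b64decode` reject characters outside the base64 alphabet instead of silently discarding them.

```python
import base64
import binascii


def decode_image_payload(payload: str) -> bytes:
    """Decode a raw base64 string or a data URL into bytes.

    Hypothetical helper for illustration; not part of the removed module.
    """
    # A data URL looks like "data:image/png;base64,<data>"; keep only the data part.
    marker = "base64,"
    if marker in payload:
        payload = payload.split(marker, 1)[1]
    try:
        # validate=True raises on characters outside the base64 alphabet.
        return base64.b64decode(payload, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise ValueError("Invalid base64 image data") from exc


raw = b"\x89PNG\r\n\x1a\n"  # the 8-byte PNG signature, used as sample data
encoded = base64.b64encode(raw).decode("ascii")
assert decode_image_payload(encoded) == raw
assert decode_image_payload(f"data:image/png;base64,{encoded}") == raw
```

In an endpoint, the `ValueError` would typically be mapped to a 400 response, as the removed handler does with its broad `except`.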
@@ -1,88 +0,0 @@
|
||||
"""Social Optimizer endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import SocialOptimizeRequest, SocialOptimizeResponse, PlatformFormatsResponse
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, SocialOptimizerRequest
|
||||
from services.image_studio.templates import Platform
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/social/optimize", response_model=SocialOptimizeResponse, summary="Optimize image for social platforms")
|
||||
async def optimize_for_social(
|
||||
request: SocialOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Optimize an image for multiple social media platforms with smart cropping and safe zones."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "social optimization")
|
||||
logger.info(f"[Social Optimizer] Request from user {user_id}: platforms={request.platforms}")
|
||||
|
||||
platforms = []
|
||||
for platform_str in request.platforms:
|
||||
try:
|
||||
platforms.append(Platform(platform_str.lower()))
|
||||
except ValueError:
|
||||
logger.warning(f"[Social Optimizer] Invalid platform: {platform_str}")
|
||||
continue
|
||||
|
||||
if not platforms:
|
||||
raise HTTPException(status_code=400, detail="No valid platforms provided")
|
||||
|
||||
format_names = None
|
||||
if request.format_names:
|
||||
format_names = {}
|
||||
for platform_str, format_name in request.format_names.items():
|
||||
try:
|
||||
platform = Platform(platform_str.lower())
|
||||
format_names[platform] = format_name
|
||||
except ValueError:
|
||||
logger.warning(f"[Social Optimizer] Invalid platform in format_names: {platform_str}")
|
||||
|
||||
social_request = SocialOptimizerRequest(
|
||||
image_base64=request.image_base64,
|
||||
platforms=platforms,
|
||||
format_names=format_names,
|
||||
show_safe_zones=request.show_safe_zones,
|
||||
crop_mode=request.crop_mode,
|
||||
focal_point=request.focal_point,
|
||||
output_format=request.output_format,
|
||||
options={},
|
||||
)
|
||||
|
||||
result = await studio_manager.optimize_for_social(social_request, user_id=user_id)
|
||||
return SocialOptimizeResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Social Optimizer] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Social optimization failed: {e}")
|
||||
|
||||
|
||||
@router.get("/social/platforms/{platform}/formats", response_model=PlatformFormatsResponse, summary="Get platform formats")
|
||||
async def get_platform_formats(
|
||||
platform: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get available formats for a social media platform."""
|
||||
try:
|
||||
try:
|
||||
platform_enum = Platform(platform.lower())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid platform: {platform}")
|
||||
|
||||
formats = studio_manager.get_social_platform_formats(platform_enum)
|
||||
return PlatformFormatsResponse(formats=formats)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Platform Formats] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load platform formats: {e}")
|
||||
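Both Social Optimizer endpoints convert user-supplied platform names into `Platform` enum members and treat a `ValueError` as "unknown platform". The sketch below isolates that pattern. The enum members shown are placeholders (the real `Platform` lives in `services.image_studio.templates` and its members are not visible in this diff), and `parse_platforms` is a hypothetical helper name.

```python
from enum import Enum
from typing import Iterable, List, Tuple


class Platform(str, Enum):
    # Placeholder members for illustration only.
    INSTAGRAM = "instagram"
    LINKEDIN = "linkedin"


def parse_platforms(names: Iterable[str]) -> Tuple[List[Platform], List[str]]:
    """Split names into recognised Platform members and rejected strings."""
    valid: List[Platform] = []
    rejected: List[str] = []
    for name in names:
        try:
            # Enum lookup by value raises ValueError for unknown names.
            valid.append(Platform(name.lower()))
        except ValueError:
            rejected.append(name)
    return valid, rejected


valid, rejected = parse_platforms(["Instagram", "myspace"])
assert valid == [Platform.INSTAGRAM]
assert rejected == ["myspace"]
```

Returning the rejected names alongside the valid ones would let a caller report them in the response rather than only logging a warning.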
@@ -1,158 +0,0 @@
|
||||
"""Transform Studio endpoints — image-to-video, talking avatar, and video serving."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .models import (
|
||||
TransformImageToVideoRequestModel, TalkingAvatarRequestModel,
|
||||
TransformVideoResponse, TransformCostEstimateRequest, TransformCostEstimateResponse,
|
||||
)
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, TransformImageToVideoRequest, TalkingAvatarRequest
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/transform/image-to-video", response_model=TransformVideoResponse, summary="Transform Image to Video")
|
||||
async def transform_image_to_video(
|
||||
request: TransformImageToVideoRequestModel,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Transform an image into a video using WAN 2.5."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image-to-video transformation")
|
||||
logger.info(f"[Transform Studio] Image-to-video request from user {user_id}: resolution={request.resolution}, duration={request.duration}s")
|
||||
|
||||
transform_request = TransformImageToVideoRequest(
|
||||
image_base64=request.image_base64,
|
||||
prompt=request.prompt,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
enable_prompt_expansion=request.enable_prompt_expansion,
|
||||
)
|
||||
|
||||
result = await studio_manager.transform_image_to_video(transform_request, user_id=user_id)
|
||||
|
||||
logger.info(f"[Transform Studio] ✅ Image-to-video completed: cost=${result['cost']:.2f}")
|
||||
return TransformVideoResponse(**result)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/transform/talking-avatar", response_model=TransformVideoResponse, summary="Create Talking Avatar")
|
||||
async def create_talking_avatar(
|
||||
request: TalkingAvatarRequestModel,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Create a talking avatar video using InfiniteTalk."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "talking avatar generation")
|
||||
logger.info(f"[Transform Studio] Talking avatar request from user {user_id}: resolution={request.resolution}")
|
||||
|
||||
avatar_request = TalkingAvatarRequest(
|
||||
image_base64=request.image_base64,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
prompt=request.prompt,
|
||||
mask_image_base64=request.mask_image_base64,
|
||||
seed=request.seed,
|
||||
)
|
||||
|
||||
result = await studio_manager.create_talking_avatar(avatar_request, user_id=user_id)
|
||||
|
||||
logger.info(f"[Transform Studio] ✅ Talking avatar completed: cost=${result['cost']:.2f}")
|
||||
return TransformVideoResponse(**result)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Talking avatar generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/transform/estimate-cost", response_model=TransformCostEstimateResponse, summary="Estimate Transform Cost")
|
||||
async def estimate_transform_cost(
|
||||
request: TransformCostEstimateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Estimate cost for transform operations."""
|
||||
try:
|
||||
estimate = studio_manager.estimate_transform_cost(
|
||||
operation=request.operation,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
)
|
||||
return TransformCostEstimateResponse(**estimate)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Cost estimation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/videos/{user_id}/{video_filename:path}", summary="Serve Transform Studio Video")
|
||||
async def serve_transform_video(
|
||||
user_id: str,
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
|
||||
):
|
||||
"""Serve a generated Transform Studio video file."""
|
||||
try:
|
||||
authenticated_user_id = _require_user_id(current_user, "video access")
|
||||
if authenticated_user_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied: You can only access your own videos"
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
transform_videos_dir = base_dir / "transform_videos"
|
||||
video_path = transform_videos_dir / user_id / video_filename
|
||||
|
||||
try:
|
||||
resolved_video_path = video_path.resolve()
|
||||
resolved_base = (transform_videos_dir / user_id).resolve()
|
||||
resolved_video_path.relative_to(resolved_base)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Invalid video path: path traversal detected"
|
||||
)
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Video not found")
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] Failed to serve video: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
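The containment check in `serve_transform_video` relies on `Path.resolve()` followed by `relative_to()`, which raises `ValueError` when the resolved path is not under the given base. A minimal standalone sketch of that technique, scoped to a per-user directory, is shown below. The function name `resolve_user_file` and the use of `PermissionError` are assumptions made for illustration, and `user_id` is assumed to have been checked against the authenticated user already.

```python
import tempfile
from pathlib import Path


def resolve_user_file(base_dir: Path, user_id: str, filename: str) -> Path:
    """Return the resolved path if it stays inside base_dir/user_id, else raise.

    Hypothetical helper for illustration; not part of the removed module.
    """
    user_root = (base_dir / user_id).resolve()
    candidate = (user_root / filename).resolve()
    try:
        # relative_to raises ValueError when candidate is outside user_root.
        candidate.relative_to(user_root)
    except ValueError as exc:
        raise PermissionError("path escapes the user's directory") from exc
    return candidate


base = Path(tempfile.mkdtemp())
inside = resolve_user_file(base, "alice", "clip.mp4")
assert inside == (base / "alice" / "clip.mp4").resolve()

try:
    resolve_user_file(base, "alice", "../bob/clip.mp4")
except PermissionError:
    pass  # traversal into another user's directory is rejected
else:
    raise AssertionError("traversal was not rejected")
```

Resolving against the per-user root rather than the shared videos directory means a filename such as `../bob/clip.mp4` is rejected even though it stays inside the shared directory.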
@@ -1,40 +0,0 @@
|
||||
"""Upscale Studio endpoint."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import UpscaleImageRequest, UpscaleImageResponse
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager
|
||||
from services.image_studio.upscale_service import UpscaleStudioRequest
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/upscale", response_model=UpscaleImageResponse, summary="Upscale Image")
|
||||
async def upscale_image(
|
||||
request: UpscaleImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Upscale an image using Stability AI pipelines."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image upscaling")
|
||||
upscale_request = UpscaleStudioRequest(
|
||||
image_base64=request.image_base64,
|
||||
mode=request.mode,
|
||||
target_width=request.target_width,
|
||||
target_height=request.target_height,
|
||||
preset=request.preset,
|
||||
prompt=request.prompt,
|
||||
)
|
||||
result = await studio_manager.upscale_image(upscale_request, user_id=user_id)
|
||||
return UpscaleImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Upscale Image] ❌ Error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image upscaling failed: {e}")
|
||||
@@ -1,182 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import RedirectResponse
|
||||
from loguru import logger
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.database import get_db
|
||||
|
||||
router = APIRouter(prefix="/v1/social-proxy", tags=["social-proxy"])
|
||||
|
||||
|
||||
def _utc_now_iso() -> str:
|
||||
return datetime.now(timezone.utc).isoformat()
|
||||
|
||||
|
||||
def _ensure_tables(db: Session) -> None:
|
||||
# Keep this router backward-compatible on tenant DBs without migrations.
|
||||
db.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS oauth_nonce_sessions (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
state TEXT NOT NULL UNIQUE,
|
||||
nonce TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
channel_id INTEGER,
|
||||
consumed_at TEXT,
|
||||
expires_at TEXT,
|
||||
created_at TEXT NOT NULL
|
||||
)
|
||||
"""))
|
||||
db.execute(text("""
|
||||
CREATE TABLE IF NOT EXISTS social_channels (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id TEXT NOT NULL,
|
||||
platform TEXT NOT NULL,
|
||||
platform_account_id TEXT NOT NULL,
|
||||
token_bundle TEXT NOT NULL,
|
||||
token_version INTEGER NOT NULL DEFAULT 1,
|
||||
publication_linkage TEXT,
|
||||
is_connected INTEGER NOT NULL DEFAULT 1,
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
UNIQUE(platform, platform_account_id)
|
||||
)
|
||||
"""))
|
||||
|
||||
|
||||
def _build_redirect(base_url: str, code: str, message: str, channel_id: Optional[int] = None) -> RedirectResponse:
|
||||
params = {"code": code, "message": message}
|
||||
if channel_id is not None:
|
||||
params["channel_id"] = str(channel_id)
|
||||
return RedirectResponse(url=f"{base_url}?{urlencode(params)}", status_code=303)
|
||||
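`_build_redirect` appends URL-encoded parameters to whatever `base_url` it is given, and in this router that value comes from the `ui_redirect` query parameter. A hedged sketch of restricting the target to same-site paths before building the URL is shown below. The helper name `safe_redirect_url` and the fallback path are assumptions for illustration, not code from the removed module.

```python
from urllib.parse import urlencode, urlsplit


def safe_redirect_url(
    target: str,
    params: dict,
    fallback: str = "/dashboard/connections",
) -> str:
    """Build a redirect URL, falling back when target is not a local path.

    Hypothetical helper for illustration; not part of the removed module.
    """
    parts = urlsplit(target)
    is_local = (
        not parts.scheme               # rejects "https://..." and similar
        and not parts.netloc           # rejects "//host/path"
        and target.startswith("/")
        and not target.startswith("//")
        and "\\" not in target         # some browsers treat "\" like "/"
    )
    base = target if is_local else fallback
    return f"{base}?{urlencode(params)}"


assert (
    safe_redirect_url("/dashboard/connections", {"code": "connected"})
    == "/dashboard/connections?code=connected"
)
assert (
    safe_redirect_url("https://evil.example/x", {"code": "connected"})
    == "/dashboard/connections?code=connected"
)
```

Like the original helper, this sketch assumes the target carries no query string of its own.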
|
||||
|
||||
@router.get("/oauth/callback")
|
||||
def oauth_callback(
|
||||
state: str = Query(...),
|
||||
platform: str = Query(...),
|
||||
account_id: str = Query(...),
|
||||
token_bundle: str = Query(..., description="Serialized token payload"),
|
||||
ui_redirect: str = Query("/dashboard/connections"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Consume OAuth callback, bind to user/platform, and upsert social channel connection."""
|
||||
_ensure_tables(db)
|
||||
|
||||
record = db.execute(
|
||||
text("""
|
||||
SELECT id, nonce, user_id, platform, channel_id, consumed_at, expires_at
|
||||
FROM oauth_nonce_sessions WHERE state = :state
|
||||
"""),
|
||||
{"state": state},
|
||||
).mappings().first()
|
||||
|
||||
if not record:
|
||||
return _build_redirect(ui_redirect, "invalid_state", "Missing OAuth session")
|
||||
|
||||
if record["consumed_at"] is not None:
|
||||
return _build_redirect(ui_redirect, "state_reused", "OAuth state already consumed")
|
||||
|
||||
if record["platform"] != platform:
|
||||
return _build_redirect(ui_redirect, "platform_mismatch", "Platform mismatch")
|
||||
|
||||
if record["expires_at"] and record["expires_at"] < _utc_now_iso():
|
||||
return _build_redirect(ui_redirect, "state_expired", "OAuth session expired")
|
||||
|
||||
user_id = record["user_id"]
|
||||
|
||||
# Validate token payload is JSON.
|
||||
try:
|
||||
parsed_bundle = json.loads(token_bundle)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise HTTPException(status_code=400, detail="Invalid token_bundle JSON") from exc
|
||||
|
||||
now = _utc_now_iso()
|
||||
|
||||
existing = db.execute(
|
||||
text("""
|
||||
SELECT id, publication_linkage, token_version
|
||||
FROM social_channels
|
||||
WHERE platform = :platform AND platform_account_id = :account_id
|
||||
"""),
|
||||
{"platform": platform, "account_id": account_id},
|
||||
).mappings().first()
|
||||
|
||||
if existing:
|
||||
# Reconnect path: preserve publication linkage and bump token version.
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE social_channels
|
||||
SET user_id = :user_id,
|
||||
token_bundle = :token_bundle,
|
||||
token_version = :token_version,
|
||||
is_connected = 1,
|
||||
updated_at = :updated_at
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{
|
||||
"id": existing["id"],
|
||||
"user_id": user_id,
|
||||
"token_bundle": json.dumps(parsed_bundle),
|
||||
"token_version": int(existing["token_version"] or 0) + 1,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
channel_id = existing["id"]
|
||||
result_code = "reconnected"
|
||||
result_message = "Channel reconnected"
|
||||
else:
|
||||
db.execute(
|
||||
text("""
|
||||
INSERT INTO social_channels (
|
||||
user_id, platform, platform_account_id, token_bundle,
|
||||
token_version, publication_linkage, is_connected, created_at, updated_at
|
||||
) VALUES (
|
||||
:user_id, :platform, :account_id, :token_bundle,
|
||||
1, :publication_linkage, 1, :created_at, :updated_at
|
||||
)
|
||||
"""),
|
||||
{
|
||||
"user_id": user_id,
|
||||
"platform": platform,
|
||||
"account_id": account_id,
|
||||
"token_bundle": json.dumps(parsed_bundle),
|
||||
"publication_linkage": None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
},
|
||||
)
|
||||
channel_id = db.execute(text("SELECT last_insert_rowid()")).scalar_one()
|
||||
result_code = "connected"
|
||||
result_message = "Channel connected"
|
||||
|
||||
# Bind callback session to concrete channel/user/platform and mark consumed.
|
||||
db.execute(
|
||||
text("""
|
||||
UPDATE oauth_nonce_sessions
|
||||
SET consumed_at = :consumed_at,
|
||||
channel_id = :channel_id,
|
||||
user_id = :user_id,
|
||||
platform = :platform
|
||||
WHERE id = :id
|
||||
"""),
|
||||
{
|
||||
"id": record["id"],
|
||||
"consumed_at": now,
|
||||
"channel_id": channel_id,
|
||||
"user_id": user_id,
|
||||
"platform": platform,
|
||||
},
|
||||
)
|
||||
|
||||
db.commit()
|
||||
logger.info(f"OAuth callback complete user={user_id} platform={platform} channel_id={channel_id}")
|
||||
return _build_redirect(ui_redirect, result_code, result_message, channel_id)
|
||||
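The callback above either reconnects an existing channel (bumping `token_version`) or inserts a new one. The sketch below shows that branch in isolation against an in-memory SQLite database. It is a minimal illustration, not the project's code: the `upsert_channel` and `make_demo_db` names, the reduced table schema (no `publication_linkage`), and the use of the stdlib `sqlite3` module instead of SQLAlchemy are all assumptions made for the example.

```python
import json
import sqlite3
from datetime import datetime, timezone


def make_demo_db():
    """Create an in-memory table with a reduced, hypothetical schema."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE social_channels ("
        "id INTEGER PRIMARY KEY, user_id TEXT, platform TEXT, "
        "platform_account_id TEXT, token_bundle TEXT, token_version INTEGER, "
        "is_connected INTEGER, created_at TEXT, updated_at TEXT)"
    )
    return conn


def upsert_channel(conn, user_id, platform, account_id, token_bundle):
    """Insert a channel, or reconnect an existing one and bump token_version."""
    parsed = json.loads(token_bundle)  # raises json.JSONDecodeError on bad input
    now = datetime.now(timezone.utc).isoformat()
    row = conn.execute(
        "SELECT id, token_version FROM social_channels "
        "WHERE platform = ? AND platform_account_id = ?",
        (platform, account_id),
    ).fetchone()
    if row:
        # Reconnect path: keep the row, replace the tokens, bump the version.
        channel_id, version = row
        conn.execute(
            "UPDATE social_channels SET user_id = ?, token_bundle = ?, "
            "token_version = ?, is_connected = 1, updated_at = ? WHERE id = ?",
            (user_id, json.dumps(parsed), int(version or 0) + 1, now, channel_id),
        )
        code = "reconnected"
    else:
        # First connection: token_version starts at 1.
        cur = conn.execute(
            "INSERT INTO social_channels (user_id, platform, platform_account_id, "
            "token_bundle, token_version, is_connected, created_at, updated_at) "
            "VALUES (?, ?, ?, ?, 1, 1, ?, ?)",
            (user_id, platform, account_id, json.dumps(parsed), now, now),
        )
        channel_id = cur.lastrowid
        code = "connected"
    conn.commit()
    return channel_id, code
```

Calling `upsert_channel` twice with the same platform and account id should return the same channel id, first with `"connected"` and then with `"reconnected"`.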
@@ -2,10 +2,6 @@
|
||||
"""
|
||||
Initialize Alpha Tester Subscription Tiers
|
||||
Creates subscription plans for alpha testing with appropriate limits.
|
||||
|
||||
NOTE: Pricing is seeded via PricingService.initialize_default_pricing()
|
||||
which runs in services/database.py:init_user_database()
|
||||
NOT via this script.
|
||||
"""
|
||||
|
||||
import sys
|
||||
@@ -14,7 +10,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from models.subscription_models import (
|
||||
SubscriptionPlan, SubscriptionTier
|
||||
SubscriptionPlan, SubscriptionTier, APIProviderPricing, APIProvider
|
||||
)
|
||||
from services.database import get_db_session
|
||||
from datetime import datetime
|
||||
@@ -28,7 +24,7 @@ def create_alpha_subscription_tiers():
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error("Could not get database session")
|
||||
logger.error("❌ Could not get database session")
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -42,12 +38,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Free tier for alpha testing - Limited usage",
|
||||
"features": ["blog_writer", "basic_seo", "content_planning"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 50,
|
||||
"gemini_tokens_limit": 10000,
|
||||
"tavily_calls_limit": 20,
|
||||
"serper_calls_limit": 10,
|
||||
"stability_calls_limit": 5,
|
||||
"monthly_cost_limit": 5.0
|
||||
"gemini_calls_limit": 50, # 50 calls per day
|
||||
"gemini_tokens_limit": 10000, # 10k tokens per day
|
||||
"tavily_calls_limit": 20, # 20 searches per day
|
||||
"serper_calls_limit": 10, # 10 SEO searches per day
|
||||
"stability_calls_limit": 5, # 5 images per day
|
||||
"monthly_cost_limit": 5.0 # $5 monthly limit
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -58,12 +54,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Basic alpha tier - Moderate usage for testing",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 200,
|
||||
"gemini_tokens_limit": 50000,
|
||||
"tavily_calls_limit": 100,
|
||||
"serper_calls_limit": 50,
|
||||
"stability_calls_limit": 25,
|
||||
"monthly_cost_limit": 25.0
|
||||
"gemini_calls_limit": 200, # 200 calls per day
|
||||
"gemini_tokens_limit": 50000, # 50k tokens per day
|
||||
"tavily_calls_limit": 100, # 100 searches per day
|
||||
"serper_calls_limit": 50, # 50 SEO searches per day
|
||||
"stability_calls_limit": 25, # 25 images per day
|
||||
"monthly_cost_limit": 25.0 # $25 monthly limit
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -74,12 +70,12 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Pro alpha tier - High usage for power users",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot", "advanced_analytics"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 500,
|
||||
"gemini_tokens_limit": 150000,
|
||||
"tavily_calls_limit": 300,
|
||||
"serper_calls_limit": 150,
|
||||
"stability_calls_limit": 100,
|
||||
"monthly_cost_limit": 100.0
|
||||
"gemini_calls_limit": 500, # 500 calls per day
|
||||
"gemini_tokens_limit": 150000, # 150k tokens per day
|
||||
"tavily_calls_limit": 300, # 300 searches per day
|
||||
"serper_calls_limit": 150, # 150 SEO searches per day
|
||||
"stability_calls_limit": 100, # 100 images per day
|
||||
"monthly_cost_limit": 100.0 # $100 monthly limit
|
||||
}
|
||||
},
|
||||
{
|
||||
@@ -90,31 +86,34 @@ def create_alpha_subscription_tiers():
|
||||
"description": "Enterprise alpha tier - Unlimited usage for enterprise testing",
|
||||
"features": ["blog_writer", "seo_analysis", "content_planning", "strategy_copilot", "advanced_analytics", "custom_integrations"],
|
||||
"limits": {
|
||||
"gemini_calls_limit": 0,
|
||||
"gemini_tokens_limit": 0,
|
||||
"tavily_calls_limit": 0,
|
||||
"serper_calls_limit": 0,
|
||||
"stability_calls_limit": 0,
|
||||
"monthly_cost_limit": 500.0
|
||||
"gemini_calls_limit": 0, # Unlimited calls
|
||||
"gemini_tokens_limit": 0, # Unlimited tokens
|
||||
"tavily_calls_limit": 0, # Unlimited searches
|
||||
"serper_calls_limit": 0, # Unlimited SEO searches
|
||||
"stability_calls_limit": 0, # Unlimited images
|
||||
"monthly_cost_limit": 500.0 # $500 monthly limit
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Create subscription plans
|
||||
for tier_data in alpha_tiers:
|
||||
# Check if plan already exists
|
||||
existing_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.name == tier_data["name"]
|
||||
).first()
|
||||
|
||||
if existing_plan:
|
||||
logger.info(f"Plan '{tier_data['name']}' already exists, updating...")
|
||||
logger.info(f"✅ Plan '{tier_data['name']}' already exists, updating...")
|
||||
# Update existing plan
|
||||
for key, value in tier_data["limits"].items():
|
||||
setattr(existing_plan, key, value)
|
||||
existing_plan.description = tier_data["description"]
|
||||
existing_plan.features = tier_data["features"]
|
||||
existing_plan.updated_at = datetime.utcnow()
|
||||
else:
|
||||
logger.info(f"Creating new plan: {tier_data['name']}")
|
||||
logger.info(f"🆕 Creating new plan: {tier_data['name']}")
|
||||
# Create new plan
|
||||
plan = SubscriptionPlan(
|
||||
name=tier_data["name"],
|
||||
tier=tier_data["tier"],
|
||||
@@ -127,17 +126,106 @@ def create_alpha_subscription_tiers():
|
||||
db.add(plan)
|
||||
|
||||
db.commit()
|
||||
logger.info("Alpha subscription tiers created/updated successfully!")
|
||||
logger.info("✅ Alpha subscription tiers created/updated successfully!")
|
||||
|
||||
# Create API provider pricing
|
||||
create_api_pricing(db)
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating alpha subscription tiers: {e}")
|
||||
logger.error(f"❌ Error creating alpha subscription tiers: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
def create_api_pricing(db: Session):
|
||||
"""Create API provider pricing configuration."""
|
||||
|
||||
try:
|
||||
# Gemini pricing (based on current Google AI pricing)
|
||||
gemini_pricing = [
|
||||
{
|
||||
"model_name": "gemini-2.0-flash-exp",
|
||||
"cost_per_input_token": 0.00000075, # $0.75 per 1M tokens
|
||||
"cost_per_output_token": 0.000003, # $3 per 1M tokens
|
||||
"description": "Gemini 2.0 Flash Experimental"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-1.5-flash",
|
||||
"cost_per_input_token": 0.00000075, # $0.75 per 1M tokens
|
||||
"cost_per_output_token": 0.000003, # $3 per 1M tokens
|
||||
"description": "Gemini 1.5 Flash"
|
||||
},
|
||||
{
|
||||
"model_name": "gemini-1.5-pro",
|
||||
"cost_per_input_token": 0.00000125, # $1.25 per 1M tokens
|
||||
"cost_per_output_token": 0.000005, # $5 per 1M tokens
|
||||
"description": "Gemini 1.5 Pro"
|
||||
}
|
||||
]
|
||||
|
||||
# Tavily pricing
|
||||
tavily_pricing = [
|
||||
{
|
||||
"model_name": "search",
|
||||
"cost_per_search": 0.001, # $0.001 per search
|
||||
"description": "Tavily Search API"
|
||||
}
|
||||
]
|
||||
|
||||
# Serper pricing
|
||||
serper_pricing = [
|
||||
{
|
||||
"model_name": "search",
|
||||
"cost_per_search": 0.001, # $0.001 per search
|
||||
"description": "Serper Google Search API"
|
||||
}
|
||||
]
|
||||
|
||||
# Stability AI pricing
|
||||
stability_pricing = [
|
||||
{
|
||||
"model_name": "stable-diffusion-xl",
|
||||
"cost_per_image": 0.01, # $0.01 per image
|
||||
"description": "Stable Diffusion XL"
|
||||
}
|
||||
]
|
||||
|
||||
# Create pricing records
|
||||
pricing_configs = [
|
||||
(APIProvider.GEMINI, gemini_pricing),
|
||||
(APIProvider.TAVILY, tavily_pricing),
|
||||
(APIProvider.SERPER, serper_pricing),
|
||||
(APIProvider.STABILITY, stability_pricing)
|
||||
]
|
||||
|
||||
for provider, pricing_list in pricing_configs:
|
||||
for pricing_data in pricing_list:
|
||||
# Check if pricing already exists
|
||||
existing_pricing = db.query(APIProviderPricing).filter(
|
||||
APIProviderPricing.provider == provider,
|
||||
APIProviderPricing.model_name == pricing_data["model_name"]
|
||||
).first()
|
||||
|
||||
if existing_pricing:
|
||||
logger.info(f"✅ Pricing for {provider.value}/{pricing_data['model_name']} already exists")
|
||||
else:
|
||||
logger.info(f"🆕 Creating pricing for {provider.value}/{pricing_data['model_name']}")
|
||||
pricing = APIProviderPricing(
|
||||
provider=provider,
|
||||
**pricing_data
|
||||
)
|
||||
db.add(pricing)
|
||||
|
||||
db.commit()
|
||||
logger.info("✅ API provider pricing created successfully!")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating API pricing: {e}")
|
||||
db.rollback()
|
||||
|
||||
def assign_default_plan_to_users():
|
||||
"""Assign Free Alpha plan to all existing users."""
|
||||
if os.getenv('ENABLE_ALPHA', 'false').lower() not in {'1','true','yes','on'}:
|
||||
@@ -146,28 +234,32 @@ def assign_default_plan_to_users():
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error("Could not get database session")
|
||||
logger.error("❌ Could not get database session")
|
||||
return False
|
||||
|
||||
try:
|
||||
# Get Free Alpha plan
|
||||
free_plan = db.query(SubscriptionPlan).filter(
|
||||
SubscriptionPlan.name == "Free Alpha"
|
||||
).first()
|
||||
|
||||
if not free_plan:
|
||||
logger.error("Free Alpha plan not found")
|
||||
logger.error("❌ Free Alpha plan not found")
|
||||
return False
|
||||
|
||||
from models.subscription_models import UserSubscription, BillingCycle, UsageStatus
|
||||
from datetime import timedelta
|
||||
|
||||
# For now, we'll create a default user subscription
|
||||
# In a real system, you'd query actual users
|
||||
from models.subscription_models import UserSubscription, BillingCycle, UsageStatus
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
# Create default user subscription for testing
|
||||
default_user_id = "default_user"
|
||||
existing_subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == default_user_id
|
||||
).first()
|
||||
|
||||
if not existing_subscription:
|
||||
logger.info(f"Creating default subscription for {default_user_id}")
|
||||
logger.info(f"🆕 Creating default subscription for {default_user_id}")
|
||||
subscription = UserSubscription(
|
||||
user_id=default_user_id,
|
||||
plan_id=free_plan.id,
|
||||
@@ -180,32 +272,33 @@ def assign_default_plan_to_users():
|
||||
)
|
||||
db.add(subscription)
|
||||
db.commit()
|
||||
logger.info(f"Default subscription created for {default_user_id}")
|
||||
logger.info(f"✅ Default subscription created for {default_user_id}")
|
||||
else:
|
||||
logger.info(f"Default subscription already exists for {default_user_id}")
|
||||
logger.info(f"✅ Default subscription already exists for {default_user_id}")
|
||||
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error assigning default plan: {e}")
|
||||
logger.error(f"❌ Error assigning default plan: {e}")
|
||||
db.rollback()
|
||||
return False
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Initializing Alpha Subscription Tiers...")
|
||||
logger.info("🚀 Initializing Alpha Subscription Tiers...")
|
||||
|
||||
success = create_alpha_subscription_tiers()
|
||||
if success:
|
||||
logger.info("Subscription tiers created successfully!")
|
||||
logger.info("✅ Subscription tiers created successfully!")
|
||||
|
||||
# Assign default plan
|
||||
assign_success = assign_default_plan_to_users()
|
||||
if assign_success:
|
||||
logger.info("Default plan assigned successfully!")
|
||||
logger.info("✅ Default plan assigned successfully!")
|
||||
else:
|
||||
logger.error("Failed to assign default plan")
|
||||
logger.error("❌ Failed to assign default plan")
|
||||
else:
|
||||
logger.error("Failed to create subscription tiers")
|
||||
logger.error("❌ Failed to create subscription tiers")
|
||||
|
||||
logger.info("Alpha subscription system initialization complete!")
|
||||
logger.info("🎉 Alpha subscription system initialization complete!")
|
||||
|
||||
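The tier definitions in this script use `0` to mean "unlimited" and express Gemini pricing as a cost per single token. A minimal sketch of how such values could be interpreted follows. The helper names `is_within_limit` and `estimate_token_cost` are hypothetical and do not come from the repository; the real enforcement logic is not shown in this diff.

```python
def is_within_limit(used: int, limit: int) -> bool:
    """Return True if one more call is allowed; a limit of 0 means unlimited."""
    if limit == 0:
        return True
    return used < limit


def estimate_token_cost(
    input_tokens: int,
    output_tokens: int,
    cost_per_input_token: float,
    cost_per_output_token: float,
) -> float:
    """Estimate the cost of one call from per-token rates."""
    return input_tokens * cost_per_input_token + output_tokens * cost_per_output_token


# Example with the gemini-1.5-flash rates listed above:
# 0.00000075 per input token is $0.75 per 1M tokens,
# 0.000003 per output token is $3 per 1M tokens.
example_cost = estimate_token_cost(1_000_000, 1_000_000, 0.00000075, 0.000003)
```

With these rates, one million input tokens plus one million output tokens comes to roughly $3.75.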
@@ -9,7 +9,6 @@ import json
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.blog_models import (
|
||||
MediumBlogGenerateRequest,
|
||||
@@ -27,7 +26,7 @@ class MediumBlogGenerator:
|
||||
def __init__(self):
|
||||
self.cache = persistent_content_cache
|
||||
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str, db: Session = None) -> MediumBlogGenerateResult:
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -499,7 +499,7 @@ class DatabaseTaskManager:
|
||||
)
|
||||
blog_writer_logger.log_error(e, "outline_generation_task", context={"task_id": task_id})
|
||||
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest):
|
||||
"""Background task to generate a medium blog using a single structured JSON call."""
|
||||
try:
|
||||
await self.update_progress(task_id, "📦 Packaging outline and metadata...", 0)
|
||||
@@ -512,7 +512,7 @@ class DatabaseTaskManager:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
task_id,
|
||||
user_id,
|
||||
user_id=request.user_id if hasattr(request, 'user_id') else (await self.get_task_status(task_id))['user_id'],
|
||||
db=self.db
|
||||
)
|
||||
|
||||
|
||||
@@ -70,22 +70,22 @@ STRATEGIC REQUIREMENTS:
|
||||
- Ensure engaging, actionable content throughout
|
||||
|
||||
Return JSON format:
|
||||
{{
|
||||
{
|
||||
"title_options": [
|
||||
"Title option 1",
|
||||
"Title option 2",
|
||||
"Title option 3"
|
||||
],
|
||||
"outline": [
|
||||
{{
|
||||
{
|
||||
"heading": "Section heading with primary keyword",
|
||||
"subheadings": ["Subheading 1", "Subheading 2", "Subheading 3"],
|
||||
"key_points": ["Key point 1", "Key point 2", "Key point 3"],
|
||||
"target_words": 300,
|
||||
"keywords": ["primary keyword", "secondary keyword"]
|
||||
}}
|
||||
}
|
||||
]
|
||||
}}"""
|
||||
}"""
|
||||
|
||||
def get_outline_schema(self) -> Dict[str, Any]:
|
||||
"""Get the structured JSON schema for outline generation."""
|
||||
|
||||
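The hunk above switches the JSON example in the prompt between doubled braces (`{{`, `}}`) and single braces. Which form is correct depends on whether the prompt is an f-string (or is passed through `str.format`), which this diff does not show. The sketch below assumes an f-string and uses a hypothetical `build_outline_prompt` function to show that doubled braces are rendered as literal single braces.

```python
import json


def build_outline_prompt(topic: str) -> str:
    """Build a prompt whose JSON example survives f-string interpolation."""
    # Inside an f-string, "{{" and "}}" produce literal "{" and "}".
    # Single braces would be parsed as replacement fields instead.
    return f"""Write an outline for: {topic}

Return JSON format:
{{
  "title_options": ["Title option 1", "Title option 2"],
  "outline": [
    {{"heading": "Section heading", "target_words": 300}}
  ]
}}"""


prompt = build_outline_prompt("AI writing tools")
# The rendered example is valid JSON once the leading instructions are cut off.
example = json.loads(prompt[prompt.index("{"):])
```

If the prompt is a plain (non-f) triple-quoted string, single braces are the right choice and doubled braces would reach the model verbatim.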
@@ -5,8 +5,8 @@ Enhances individual outline sections for better engagement and value.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import BlogOutlineSection
|
||||
import json
|
||||
|
||||
|
||||
class SectionEnhancer:
|
||||
@@ -73,45 +73,14 @@ class SectionEnhancer:
|
||||
"required": ["heading", "subheadings", "key_points", "target_words", "keywords"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
enhanced_data = llm_text_gen(
|
||||
prompt=enhancement_prompt,
|
||||
json_struct=enhancement_schema,
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
enhanced_data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
enhanced_data = json.loads(json_match.group(0))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Section enhancement returned invalid JSON: {e}")
|
||||
return section
|
||||
else:
|
||||
logger.warning(f"Section enhancement returned non-JSON string: {cleaned[:200]}")
|
||||
return section
|
||||
elif isinstance(raw, dict):
|
||||
enhanced_data = raw
|
||||
else:
|
||||
logger.warning(f"Unexpected LLM response type: {type(raw)}")
|
||||
return section
|
||||
|
||||
if 'error' in enhanced_data:
|
||||
logger.warning(f"AI section enhancement failed: {enhanced_data.get('error', 'Unknown error')}")
|
||||
else:
|
||||
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
|
||||
return BlogOutlineSection(
|
||||
id=section.id,
|
||||
heading=enhanced_data.get('heading', section.heading),
|
||||
|
||||
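The same parsing steps (accept a dict as-is, strip Markdown code fences from a string, then fall back to a regex search for the outermost braces) recur in several analyzers in this diff. They could be factored into one helper, sketched below. The name `parse_llm_json` is hypothetical and the helper is not part of the repository; it only mirrors the logic visible in the hunks.

```python
import json
import re
from typing import Any, Dict, Union

FENCE = "`" * 3  # Markdown code-fence marker, built this way to keep the example readable


def parse_llm_json(raw: Union[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Normalize an LLM response (dict or string) into a dict."""
    if isinstance(raw, dict):
        return raw
    if not isinstance(raw, str):
        raise ValueError(f"Unexpected LLM response type: {type(raw)}")

    cleaned = raw.strip()
    # Strip an opening fence, with or without a "json" language tag.
    if cleaned.startswith(FENCE + "json"):
        cleaned = cleaned[len(FENCE) + 4:]
    elif cleaned.startswith(FENCE):
        cleaned = cleaned[len(FENCE):]
    # Strip a closing fence.
    if cleaned.endswith(FENCE):
        cleaned = cleaned[:-len(FENCE)]
    cleaned = cleaned.strip()

    try:
        parsed = json.loads(cleaned)
    except json.JSONDecodeError:
        # Fall back to the first "{" through the last "}" in the text.
        match = re.search(r"\{.*\}", cleaned, re.DOTALL)
        if match is None:
            raise ValueError(f"Response is not JSON: {cleaned[:200]}")
        parsed = json.loads(match.group(0))

    if not isinstance(parsed, dict):
        raise ValueError(f"Expected a JSON object, got {type(parsed)}")
    return parsed
```

Callers that want to degrade gracefully, as `SectionEnhancer` does by returning the original section, can catch `ValueError`; `json.JSONDecodeError` is a subclass of it.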
@@ -6,7 +6,6 @@ Extracts competitor insights and market intelligence from research content.
|
||||
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CompetitorAnalyzer:
|
||||
@@ -23,7 +22,7 @@ class CompetitorAnalyzer:
|
||||
Extract and analyze:
|
||||
1. Top competitors mentioned (companies, brands, platforms)
|
||||
2. Content gaps (what competitors are missing)
|
||||
3. Opportunities (untapped areas)
|
||||
3. Market opportunities (untapped areas)
|
||||
4. Competitive advantages (what makes content unique)
|
||||
5. Market positioning insights
|
||||
6. Industry leaders and their strategies
|
||||
@@ -56,38 +55,18 @@ class CompetitorAnalyzer:
|
||||
"required": ["top_competitors", "content_gaps", "opportunities", "competitive_advantages", "market_positioning", "industry_leaders", "analysis_notes"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
competitor_analysis = llm_text_gen(
|
||||
prompt=competitor_prompt,
|
||||
json_struct=competitor_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
competitor_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
competitor_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Competitor analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
competitor_analysis = raw
|
||||
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
else:
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in competitor_analysis:
|
||||
raise ValueError(f"Competitor analysis failed: {competitor_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = competitor_analysis.get('error', 'Unknown error') if isinstance(competitor_analysis, dict) else str(competitor_analysis)
|
||||
logger.error(f"AI competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Competitor analysis failed: {error_msg}")
|
||||
|
||||
|
||||
@@ -63,41 +63,18 @@ class ContentAngleGenerator:
|
||||
"required": ["content_angles"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
angles_result = llm_text_gen(
|
||||
prompt=angles_prompt,
|
||||
json_struct=angles_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import json, re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
angles_result = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
angles_result = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Content angles returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
angles_result = raw
|
||||
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
else:
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in angles_result:
|
||||
raise ValueError(f"Content angles generation failed: {angles_result.get('error', 'Unknown error')}")
|
||||
|
||||
if 'content_angles' not in angles_result:
|
||||
raise ValueError(f"Content angles missing from response")
|
||||
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = angles_result.get('error', 'Unknown error') if isinstance(angles_result, dict) else str(angles_result)
|
||||
logger.error(f"AI content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Content angles generation failed: {error_msg}")
|
||||
|
||||
|
||||
@@ -314,14 +314,11 @@ class ExaResearchProvider(BaseProvider):
|
||||
|
||||
def track_exa_usage(self, user_id: str, cost: float):
|
||||
"""Track Exa API usage after successful call."""
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[track_exa_usage] Could not get DB session for user {user_id}")
|
||||
return
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
@@ -6,7 +6,6 @@ Extracts and analyzes keywords from research content using structured AI respons
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class KeywordAnalyzer:
|
||||
@@ -63,38 +62,18 @@ class KeywordAnalyzer:
|
||||
"required": ["primary", "secondary", "long_tail", "search_intent", "difficulty", "content_gaps", "semantic_keywords", "trending_terms", "analysis_insights"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
keyword_analysis = llm_text_gen(
|
||||
prompt=keyword_prompt,
|
||||
json_struct=keyword_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
keyword_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
keyword_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Keyword analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
keyword_analysis = raw
|
||||
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
else:
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in keyword_analysis:
|
||||
raise ValueError(f"Keyword analysis failed: {keyword_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = keyword_analysis.get('error', 'Unknown error') if isinstance(keyword_analysis, dict) else str(keyword_analysis)
|
||||
logger.error(f"AI keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Keyword analysis failed: {error_msg}")
|
||||
|
||||
|
||||
@@ -111,22 +111,19 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
|
||||
finally:
|
||||
if db_val:
|
||||
db_val.close()
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
api_start_time = time.time()
|
||||
@@ -165,15 +162,13 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation (similar to Exa)
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -434,16 +429,14 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
@@ -453,8 +446,7 @@ class ResearchService:
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
finally:
|
||||
if db_val:
|
||||
db_val.close()
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
|
||||
@@ -493,16 +485,14 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Tavily AI search...")
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -539,8 +529,7 @@ class ResearchService:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Tavily limits: {e}")
|
||||
finally:
|
||||
if db_val:
|
||||
db_val.close()
|
||||
db_val.close()
|
||||
|
||||
# Execute Tavily search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Tavily AI search...")
|
||||
|
||||
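These hunks alternate between `get_session_for_user(user_id)` and `next(get_db())`. One of the comments notes that `get_db` is a FastAPI dependency, so outside a request it is just a generator function: `next()` returns the session, but the generator's cleanup only runs once the generator is closed or exhausted. The sketch below shows that behaviour with stand-in names (`FakeSession`, `get_db_dependency`, `session_scope`), all hypothetical; it does not use FastAPI or SQLAlchemy.

```python
from contextlib import contextmanager


class FakeSession:
    """Stand-in for a database session that records whether it was closed."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def get_db_dependency():
    """Generator-style dependency: yield a session, close it afterwards."""
    db = FakeSession()
    try:
        yield db
    finally:
        db.close()


@contextmanager
def session_scope():
    """Use the generator dependency outside a request with guaranteed cleanup."""
    gen = get_db_dependency()
    db = next(gen)  # runs the generator up to its yield
    try:
        yield db
    finally:
        gen.close()  # raises GeneratorExit at the yield, so the finally block runs
```

A caller would write `with session_scope() as db: ...` and get the same cleanup a request-scoped dependency provides. Note that the real `get_db` shown later in this diff takes `current_user` through `Depends`, so calling it directly without arguments would not receive a real user dictionary.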
@@ -135,14 +135,11 @@ class TavilyResearchProvider(BaseProvider):
|
||||
|
||||
def track_tavily_usage(self, user_id: str, cost: float, search_depth: str):
|
||||
"""Track Tavily API usage after successful call."""
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[Tavily] Could not get DB session for user {user_id}, skipping usage tracking")
|
||||
return
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
@@ -92,7 +92,6 @@ class BlogSEORecommendationApplier:
|
||||
None,
|
||||
schema,
|
||||
user_id, # Pass user_id for subscription checking
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
|
||||
@@ -7,7 +7,6 @@ import os
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker, Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from typing import Optional, List
|
||||
|
||||
@@ -387,15 +386,12 @@ def get_db(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
user_id = current_user.get('id') or current_user.get('clerk_user_id')
|
||||
if not user_id:
|
||||
# Fallback or error? For now log error
|
||||
logger.error("No user ID found in context for DB connection")
|
||||
raise HTTPException(status_code=401, detail="User ID required for database access")
|
||||
# Could raise exception, but let's try to be safe
|
||||
raise Exception("User ID required for database access")
|
||||
|
||||
try:
|
||||
engine = get_engine_for_user(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[DB] Failed to create engine for user {user_id}: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable")
|
||||
|
||||
engine = get_engine_for_user(user_id)
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
db = SessionLocal()
|
||||
try:
|
||||
|
||||
@@ -237,21 +237,6 @@ class ControlStudioService:
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_bytes_to_metadata(image_bytes)
|
||||
|
||||
# Track usage
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"control-{operation}",
|
||||
operation_type="image-control",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/control/process",
|
||||
log_prefix="[Control Studio]"
|
||||
)
|
||||
|
||||
metadata.update(
|
||||
{
|
||||
"operation": operation,
|
||||
|
||||
@@ -514,19 +514,6 @@ class EditStudioService:
|
||||
background_bytes=background_bytes,
|
||||
lighting_bytes=lighting_bytes,
|
||||
)
|
||||
# Track usage for Stability operations
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"edit-{operation}",
|
||||
operation_type="image-edit",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/edit/process",
|
||||
log_prefix="[Edit Studio]"
|
||||
)
|
||||
else:
|
||||
image_bytes = await self._handle_general_edit(
|
||||
request=request,
|
||||
|
||||
@@ -88,20 +88,6 @@ class UpscaleStudioService:
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_metadata(image_bytes)
|
||||
|
||||
# Track usage
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"upscale-{mode}",
|
||||
operation_type="image-upscale",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/upscale",
|
||||
log_prefix="[Upscale Studio]"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"mode": mode,
|
||||
|
||||
@@ -233,7 +233,7 @@ def create_blog_post(
|
||||
|
||||
# BACK TO BASICS MODE: Try simplest possible structure FIRST
|
||||
# Since posting worked before Ricos/SEO, let's test with absolute minimum
|
||||
BACK_TO_BASICS_MODE = False # Disabled: full Ricos conversion now produces valid output
|
||||
BACK_TO_BASICS_MODE = True # Set to True to test with simplest structure
|
||||
|
||||
wix_logger.reset()
|
||||
wix_logger.log_operation_start("Blog Post Creation", title=title[:50] if title else None, member_id=member_id[:20] if member_id else None)
|
||||
@@ -257,7 +257,8 @@ def create_blog_post(
|
||||
'text': (content[:500] if content else "This is a post from ALwrity.").strip(),
|
||||
'decorations': []
|
||||
}
|
||||
}]
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
@@ -256,16 +256,17 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
quote_content = ' '.join(quote_lines)
|
||||
text_nodes = parse_markdown_inline(quote_content)
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
blockquote_node = {
|
||||
'id': node_id,
|
||||
'type': 'BLOCKQUOTE',
|
||||
'nodes': [paragraph_node],
|
||||
'blockquoteData': {}
|
||||
}
|
||||
nodes.append(blockquote_node)
|
||||
|
||||
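Review note: the blockquote hunk above wraps TEXT nodes in a PARAGRAPH node before placing them inside a BLOCKQUOTE, and the hunk differs only in whether empty data objects such as `paragraphData: {}` are emitted. A minimal sketch of that nesting follows. The helper names and the `include_empty_data` flag are hypothetical and not part of this change set; which variant the Wix API accepts is not established by this diff.

```python
import uuid
from typing import Any, Dict, List


def make_paragraph_node(text_nodes: List[Dict[str, Any]],
                        include_empty_data: bool = True) -> Dict[str, Any]:
    """Build a PARAGRAPH node around already-parsed inline text nodes."""
    node: Dict[str, Any] = {
        "id": str(uuid.uuid4()),
        "type": "PARAGRAPH",
        "nodes": text_nodes,
    }
    if include_empty_data:
        # Hypothetical switch: emit the empty data object or omit it.
        node["paragraphData"] = {}
    return node


def make_blockquote_node(text_nodes: List[Dict[str, Any]],
                         include_empty_data: bool = True) -> Dict[str, Any]:
    """Build a BLOCKQUOTE whose only child is a PARAGRAPH, as in the hunk."""
    node: Dict[str, Any] = {
        "id": str(uuid.uuid4()),
        "type": "BLOCKQUOTE",
        "nodes": [make_paragraph_node(text_nodes, include_empty_data)],
    }
    if include_empty_data:
        node["blockquoteData"] = {}
    return node
```

Keeping the switch in one place would make it cheap to test both payload shapes against the API instead of editing every node literal.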
@@ -331,6 +332,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -343,6 +345,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'BULLETED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'bulletedListData': {}
|
||||
}
|
||||
nodes.append(bulleted_list_node)
|
||||
|
||||
@@ -370,6 +373,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -382,6 +386,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'ORDERED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'orderedListData': {}
|
||||
}
|
||||
nodes.append(ordered_list_node)
|
||||
|
||||
@@ -437,6 +442,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(paragraph_node)
|
||||
|
||||
@@ -455,6 +461,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(fallback_paragraph)
|
||||
|
||||
|
||||
@@ -20,14 +20,13 @@ class SemanticHarvesterService:
|
||||
"last_harvest_time": None
|
||||
}
|
||||
|
||||
async def harvest_website(self, website_url: str, limit: int = 100, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
async def harvest_website(self, website_url: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Deep crawl a website using Exa AI.
|
||||
|
||||
Args:
|
||||
website_url: The root URL to crawl.
|
||||
limit: Maximum number of pages to retrieve.
|
||||
user_id: Optional user ID for usage tracking and preflight checks.
|
||||
|
||||
Returns:
|
||||
List of pages with content and metadata.
|
||||
@@ -60,30 +59,6 @@ class SemanticHarvesterService:
|
||||
logger.warning("[SemanticHarvester] Exa service disabled. Returning placeholder data.")
|
||||
return self._get_placeholder_data(website_url)
|
||||
|
||||
# Preflight subscription check if user_id provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
logger.warning(f"[SemanticHarvester] Exa blocked for user {user_id}: {message}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[SemanticHarvester] Preflight check failed: {e}")
|
||||
|
||||
# Use Exa to search for all pages in this domain
|
||||
search_response = self.exa_service.exa.search_and_contents(
|
||||
query=f"site:{website_url}",
|
||||
@@ -107,38 +82,6 @@ class SemanticHarvesterService:
|
||||
})
|
||||
|
||||
logger.info(f"[SemanticHarvester] Successfully harvested {len(results)} pages from {website_url}")
|
||||
|
||||
# Track Exa usage if user_id provided
|
||||
if user_id and results:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Exa search cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET exa_calls = COALESCE(exa_calls, 0) + 1,
|
||||
exa_cost = COALESCE(exa_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[SemanticHarvester] Tracked Exa usage: user={user_id}, cost=${cost}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[SemanticHarvester] Failed to track Exa usage: {track_err}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
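Review note: the Exa tracking block in the hunk above increments counters inside a single `UPDATE ... SET col = COALESCE(col, 0) + 1` statement, so the database does the addition and NULL columns are treated as zero. The sketch below reproduces that statement against an in-memory SQLite table to show the behaviour. The table layout is reduced to the columns named in the hunk, and `open_demo_db` and `track_exa_usage` are hypothetical names used for illustration only.

```python
import sqlite3


def open_demo_db() -> sqlite3.Connection:
    """Create an in-memory table holding only the columns used in the hunk."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE usage_summaries ("
        "user_id TEXT, billing_period TEXT, "
        "exa_calls INTEGER, exa_cost REAL, "
        "total_calls INTEGER, total_cost REAL)"
    )
    # Counters start as NULL, which is the case COALESCE guards against.
    conn.execute(
        "INSERT INTO usage_summaries (user_id, billing_period) VALUES (?, ?)",
        ("u1", "2025-01"),
    )
    conn.commit()
    return conn


def track_exa_usage(conn: sqlite3.Connection, user_id: str,
                    period: str, cost: float) -> int:
    """Increment call and cost counters in one statement; return rows updated."""
    cur = conn.execute(
        """
        UPDATE usage_summaries
        SET exa_calls = COALESCE(exa_calls, 0) + 1,
            exa_cost = COALESCE(exa_cost, 0) + :cost,
            total_calls = COALESCE(total_calls, 0) + 1,
            total_cost = COALESCE(total_cost, 0) + :cost
        WHERE user_id = :user_id AND billing_period = :period
        """,
        {"cost": cost, "user_id": user_id, "period": period},
    )
    conn.commit()
    return cur.rowcount
```

A return value of 0 means no summary row existed for that user and period, which the code in the hunk does not check for.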
@@ -67,11 +67,10 @@ import sys
|
||||
from pathlib import Path
|
||||
import google.genai as genai
|
||||
from google.genai import types
|
||||
from dotenv import load_dotenv
|
||||
|
||||
|
||||
from loguru import logger
|
||||
from utils.logger_utils import get_service_logger
|
||||
from services.api_key_manager import APIKeyManager
|
||||
|
||||
# Use service-specific logger to avoid conflicts
|
||||
logger = get_service_logger("gemini_audio_text")
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
"""Image editing operations — generate_image_edit and related helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import ImageEditOptions, ImageGenerationResult, ImageEditProvider
|
||||
from .wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.edit")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||
"""Get editing provider instance by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider()
|
||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Generate edited image with pre-flight validation and usage tracking.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Image editing failed", "message": str(e)}
|
||||
)
|
||||
|
||||
# 6. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
estimated_cost = 0.02 if provider_name == "wavespeed" else 0.05
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=result.model or model or "unknown",
|
||||
operation_type="image-edit",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation/edit",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||
|
||||
# 7. Return result
|
||||
return result
|
||||
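Review note: step 2 of `generate_image_edit` in the file above picks the provider from `options["provider"]` and then forces `"wavespeed"` when the model name starts with one of four prefixes, using a chain of `startswith` calls. Since `str.startswith` accepts a tuple of prefixes, the same rule can be sketched more compactly. `resolve_edit_provider` and `WAVESPEED_MODEL_PREFIXES` are hypothetical names, not identifiers from this change set.

```python
from typing import Any, Dict, Optional

# Prefixes taken from the condition in generate_image_edit.
WAVESPEED_MODEL_PREFIXES = ("wavespeed", "qwen", "flux", "nano-banana")


def resolve_edit_provider(model: Optional[str],
                          options: Optional[Dict[str, Any]] = None) -> str:
    """Return the provider name, letting a known model prefix override options."""
    opts = options or {}
    provider_name = opts.get("provider", "wavespeed")
    # str.startswith accepts a tuple and returns True if any prefix matches.
    if model and model.startswith(WAVESPEED_MODEL_PREFIXES):
        provider_name = "wavespeed"
    return provider_name
```

Note that a matching model prefix silently overrides an explicitly requested provider, which is the same behaviour as the original condition.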
@@ -1,105 +0,0 @@
|
||||
"""Face swap operations — generate_face_swap and related helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import FaceSwapOptions, FaceSwapProvider, ImageGenerationResult
|
||||
from .wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.face_swap")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Generate face swap with pre-flight validation and usage tracking.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
num_operations=1,
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025)
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None,
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap failed", "message": str(api_error)}
|
||||
)
|
||||
@@ -1,200 +0,0 @@
|
||||
"""Shared helpers for image generation operations — validation and usage tracking."""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.helpers")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""Reusable pre-flight validation helper for all image operations."""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id}")
|
||||
except HTTPException:
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""Reusable usage tracking helper for all image operations."""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(user_id=user_id, billing_period=current_period)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Map provider to DB column names
|
||||
provider_column_map = {
|
||||
"stability": ("stability_calls", "stability_cost"),
|
||||
"wavespeed": ("wavespeed_calls", "wavespeed_cost"),
|
||||
"gemini": ("gemini_calls", "gemini_cost"),
|
||||
"openai": ("openai_calls", "openai_cost"),
|
||||
"huggingface": ("total_calls", "total_cost"), # no dedicated columns
|
||||
}
|
||||
calls_col, cost_col = provider_column_map.get(provider, ("total_calls", "total_cost"))
|
||||
|
||||
current_calls_before = getattr(summary, calls_col, 0) or 0
|
||||
current_cost_before = getattr(summary, cost_col, 0.0) or 0.0
|
||||
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {calls_col} = :new_calls,
|
||||
{cost_col} = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Map provider to APIProvider enum
|
||||
provider_api_map = {
|
||||
"stability": APIProvider.STABILITY,
|
||||
"wavespeed": APIProvider.WAVESPEED,
|
||||
"gemini": APIProvider.GEMINI,
|
||||
"openai": APIProvider.OPENAI,
|
||||
"image_edit": APIProvider.IMAGE_EDIT,
|
||||
"video": APIProvider.VIDEO,
|
||||
"audio": APIProvider.AUDIO,
|
||||
}
|
||||
api_provider = provider_api_map.get(provider, APIProvider.STABILITY)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time,
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
provider_limit = limits['limits'].get(calls_col, 0) if limits else 0
|
||||
provider_limit_display = provider_limit if (provider_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {provider_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {"current_calls": new_calls, "cost": cost, "total_cost": new_cost}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
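Review note: `_track_image_operation_usage` in the helpers file above interpolates `calls_col` and `cost_col` into the UPDATE text with an f-string. That is only safe because both names come from the fixed `provider_column_map` (with a `total_*` fallback) and never from caller input. A minimal sketch of that lookup and statement construction follows. `columns_for_provider` and `build_usage_update_sql` are hypothetical helper names, and the map omits the `huggingface` entry because it resolves to the same fallback.

```python
from typing import Dict, Tuple

# Provider names map to a fixed (calls column, cost column) pair.
PROVIDER_COLUMN_MAP: Dict[str, Tuple[str, str]] = {
    "stability": ("stability_calls", "stability_cost"),
    "wavespeed": ("wavespeed_calls", "wavespeed_cost"),
    "gemini": ("gemini_calls", "gemini_cost"),
    "openai": ("openai_calls", "openai_cost"),
}
FALLBACK_COLUMNS: Tuple[str, str] = ("total_calls", "total_cost")


def columns_for_provider(provider: str) -> Tuple[str, str]:
    """Look up column names; unknown providers fall back to the totals."""
    return PROVIDER_COLUMN_MAP.get(provider, FALLBACK_COLUMNS)


def build_usage_update_sql(provider: str) -> str:
    """Build the UPDATE text; only whitelisted column names are interpolated."""
    calls_col, cost_col = columns_for_provider(provider)
    return (
        f"UPDATE usage_summaries "
        f"SET {calls_col} = :new_calls, {cost_col} = :new_cost "
        f"WHERE user_id = :user_id AND billing_period = :period"
    )
```

One thing worth checking in the original: when the fallback columns are used, the function also increments `summary.total_calls` and `summary.total_cost` on the ORM object afterwards, so the totals are written twice on that path.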
@@ -133,9 +133,9 @@ def edit_image(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
# In feature-limited mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in feature-limited mode - allowing operation to continue")
|
||||
# In podcast-only mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES") == "podcast":
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in podcast mode - allowing operation to continue")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
|
||||
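Review note: the `edit_image` hunk above contains two variants of the bypass condition, one that treats any value of `ALWRITY_ENABLED_FEATURES` other than empty or `all` as feature-limited, and one that matches only the exact string `podcast`. The sketch below isolates the first variant so the difference is easy to test. `is_feature_limited` is a hypothetical helper name, and the optional `env` parameter exists only so the check can be exercised without touching the real environment.

```python
import os
from typing import Mapping, Optional


def is_feature_limited(env: Optional[Mapping[str, str]] = None) -> bool:
    """True when ALWRITY_ENABLED_FEATURES names a subset of features.

    An unset, empty, or "all" value (case-insensitive, surrounding
    whitespace ignored) means every feature is enabled.
    """
    source = os.environ if env is None else env
    value = source.get("ALWRITY_ENABLED_FEATURES", "").strip().lower()
    return value not in ("", "all")
```

Under this variant a value such as `podcast,blog` counts as feature-limited, whereas the exact-match variant would not treat it as podcast mode.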
@@ -18,9 +18,9 @@ from .image_generation import (
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from .image_generation.helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from .image_generation.edit import generate_image_edit
|
||||
from .image_generation.face_swap import generate_face_swap
|
||||
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
@@ -53,6 +53,259 @@ def _get_provider(provider_name: str, user_id: Optional[str] = None):
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||
"""Get editing provider instance.
|
||||
|
||||
Args:
|
||||
provider_name: Provider name ("wavespeed", "stability", etc.)
|
||||
|
||||
Returns:
|
||||
ImageEditProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider()
|
||||
# TODO: Add Stability edit provider if needed
|
||||
# elif provider_name == "stability":
|
||||
# return StabilityEditProvider()
|
||||
else:
|
||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""
|
||||
Reusable pre-flight validation helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for subscription checking
|
||||
operation_type: Type of operation (for logging)
|
||||
num_operations: Number of operations to validate (default: 1)
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails (subscription limits exceeded, etc.)
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for tracking
|
||||
provider: Provider name (e.g., "wavespeed", "stability")
|
||||
model: Model name used
|
||||
operation_type: Type of operation (for logging)
|
||||
result_bytes: Generated/processed image bytes
|
||||
cost: Cost of the operation
|
||||
prompt: Optional prompt text (for request size calculation)
|
||||
endpoint: API endpoint path (for logging)
|
||||
metadata: Optional additional metadata
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking information (current_calls, cost, etc.)
|
||||
"""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost and total call count
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Stability, HuggingFace, etc.)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, Stability, etc.)
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {actual_provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {
|
||||
"current_calls": new_calls,
|
||||
"cost": cost,
|
||||
"total_cost": new_cost,
|
||||
}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
@@ -247,7 +500,165 @@ def generate_character_image(
|
||||
)
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate edited image - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If the requested edit provider is not supported
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
# If model is specified and starts with a WaveSpeed-hosted prefix ("wavespeed", "qwen", "flux", "nano-banana"), use wavespeed provider
|
||||
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider (REUSES provider pattern)
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}") from e
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Image editing failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate face swap - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
image_base64=base_image_base64, # Use base image for validation
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
# Get model cost
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None, # Face swap doesn't use prompts
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -45,7 +45,6 @@ def llm_text_gen(
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
preferred_provider: Optional[str] = None,
|
||||
flow_type: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
@@ -76,8 +75,7 @@ def llm_text_gen(
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
temperature = 0.7
|
||||
if max_tokens is None:
|
||||
max_tokens = 4000
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
n = 1
|
||||
fp = 16
|
||||
@@ -373,27 +371,16 @@ def llm_text_gen(
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "wavespeed":
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
llm_start = time.time()
|
||||
if json_struct:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_structured_json_response
|
||||
response_text = wavespeed_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
llm_ms = (time.time() - llm_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] LLM API call took {llm_ms:.0f}ms for user {user_id} (wavespeed)")
|
||||
else:
|
||||
|
||||
Some files were not shown because too many files have changed in this diff