Compare commits
1 Commits
codex/impl
...
fix/add-ma
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3150941c36 |
4
.gitignore
vendored
4
.gitignore
vendored
@@ -8,10 +8,6 @@ nul
|
||||
LICENSE
|
||||
CHANGELOG.md
|
||||
|
||||
.planning
|
||||
.planning/
|
||||
|
||||
|
||||
.trae/
|
||||
.trae
|
||||
|
||||
|
||||
46
.planning/PROJECT.md
Normal file
46
.planning/PROJECT.md
Normal file
@@ -0,0 +1,46 @@
|
||||
# ALwrity Project
|
||||
|
||||
## What This Is
|
||||
ALwrity is an AI-powered content creation platform that helps users generate various types of content including podcasts, videos, blogs, and social media content. The platform features a React frontend and a FastAPI backend with onboarding workflows, API key management, and content generation capabilities.
|
||||
|
||||
## Core Value
|
||||
To provide an all-in-one AI content creation suite that simplifies the content production process for creators, marketers, and businesses.
|
||||
|
||||
## Current Focus
|
||||
Based on recent git commits, the team has been working on:
|
||||
- Podcast production features (voice cloning, avatar generation, B-roll integration)
|
||||
- Onboarding flow improvements
|
||||
- Backend stability and debugging
|
||||
- Frontend UI/UX enhancements
|
||||
|
||||
## Requirements
|
||||
|
||||
### Validated
|
||||
- User authentication (Clerk)
|
||||
- API key management for AI providers
|
||||
- Basic podcast generation workflow
|
||||
- File storage and media handling
|
||||
|
||||
### Active
|
||||
- Podcast script generation and editing
|
||||
- Voice cloning and avatar creation
|
||||
- B-roll scene rendering and integration
|
||||
- Onboarding flow completion tracking
|
||||
- API endpoint stability and debugging
|
||||
|
||||
### Out of Scope
|
||||
- Mobile applications (currently web-only)
|
||||
- Enterprise team collaboration features
|
||||
- Advanced analytics dashboard
|
||||
|
||||
## Key Decisions
|
||||
- Using FastAPI for backend performance
|
||||
- React with Material-UI for frontend consistency
|
||||
- Modular API design for extensibility
|
||||
- Database-first approach for persistence
|
||||
|
||||
## Constraints
|
||||
- Must maintain backward compatibility with existing API
|
||||
- Deployment targets include both development and production environments
|
||||
- Must support multiple AI providers (OpenAI, HuggingFace, etc.)
|
||||
- Budget-conscious resource usage for AI API calls
|
||||
@@ -1,88 +0,0 @@
|
||||
# Roadmap: Alwrity - ALwrity Frontend Optimization
|
||||
|
||||
## Overview
|
||||
|
||||
Optimize the frontend build to reduce build time from 5 minutes to under 30 seconds and shrink bundle size from 8.42MB to under 1MB. First, implement code splitting with React.lazy and feature-gated loading using ALWRITY_ENABLED_FEATURES. Then migrate from Create React App to Vite for faster builds. Finally, optimize dependencies for maximum performance.
|
||||
|
||||
## Phases
|
||||
|
||||
**Phase Numbering:**
|
||||
- Integer phases (1, 2, 3, 4): Planned work
|
||||
- All phases planned and ready for execution
|
||||
|
||||
---
|
||||
|
||||
### Phase 1: Code Splitting & Feature-Based Lazy Loading ✅ Complete
|
||||
**Goal**: Replace all static imports with React.lazy dynamic imports and add feature-gated loading using ALWRITY_ENABLED_FEATURES. Also convert MUI icon barrel imports to individual imports (moved here from Phase 3 for Vite readiness).
|
||||
**Depends on**: Nothing (first phase)
|
||||
**Requirements**: VITE-04 (code splitting), VITE-06 (dependency optimization)
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. ✅ All 31+ route components loaded via React.lazy (not static imports)
|
||||
2. ✅ Initial bundle size reduced from 8.42MB to 2.50MB (70% reduction)
|
||||
3. ✅ Disabled features (via ALWRITY_ENABLED_FEATURES) don't load their bundles
|
||||
4. ✅ All existing routes still work correctly
|
||||
5. ✅ No build warnings or errors with CRA
|
||||
6. ✅ All MUI icon imports changed from barrel to individual (111 files)
|
||||
|
||||
**Plans**: 3 plans (all complete)
|
||||
|
||||
Plans:
|
||||
- [x] 01-01: Convert 31 static imports to React.lazy with Suspense
|
||||
- [x] 01-02: Add feature-gated route loading using ALWRITY_ENABLED_FEATURES
|
||||
- [x] 01-03: Convert MUI icon barrel imports to individual imports (111 files)
|
||||
|
||||
---
|
||||
|
||||
### Phase 2: Migrate from CRA to Vite (Next)
|
||||
**Goal**: Migrate frontend from Create React App to Vite for fast builds
|
||||
**Depends on**: Phase 1 ✅
|
||||
**Requirements**: VITE-01, VITE-02, VITE-03
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. `npm run dev` starts Vite dev server with HMR
|
||||
2. `npm run build` completes in under 30 seconds (down from 5 minutes)
|
||||
3. All environment variables work with `VITE_*` prefix
|
||||
4. TypeScript compiles without errors
|
||||
5. Material UI theme renders correctly
|
||||
|
||||
**Plans**: 3 plans
|
||||
|
||||
Plans:
|
||||
- [ ] 02-01: Install Vite dependencies and create configuration
|
||||
- [ ] 02-02: Migrate index.html and entry point
|
||||
- [ ] 02-03: Update environment variables and scripts
|
||||
|
||||
---
|
||||
|
||||
### Phase 3: Dependency Cleanup & Production Validation
|
||||
**Goal**: Remove unused dependencies and deploy Vite build to production
|
||||
**Depends on**: Phase 2
|
||||
**Requirements**: VITE-07, VITE-08, VITE-09
|
||||
**Success Criteria** (what must be TRUE):
|
||||
1. Unused dependencies identified and removed
|
||||
2. Production build serves correctly (preview mode)
|
||||
3. All features tested and working (Clerk auth, Stripe, CopilotKit)
|
||||
4. Vercel deployment config updated for Vite
|
||||
5. Build time consistently under 30 seconds
|
||||
6. Total bundle size under 2MB
|
||||
|
||||
**Plans**: 2 plans (consolidated from former Phase 3 & 4)
|
||||
|
||||
Plans:
|
||||
- [ ] 03-01: Audit and remove unused dependencies, update Vercel config
|
||||
- [ ] 03-02: Full feature testing and performance validation
|
||||
|
||||
---
|
||||
|
||||
## Execution Order
|
||||
|
||||
Phases execute in numeric order: 1 → 2 → 3
|
||||
|
||||
**Key insight:** Phase 1 (code splitting) works with CRA, so we immediately reduce bundle size. Phase 2 (Vite) gives build speed bonus on already-split bundles. Phase 3 is cleanup and deployment.
|
||||
|
||||
## Progress
|
||||
|
||||
| Phase | Plans Complete | Status | Completed |
|
||||
|-------|----------------|--------|-----------|
|
||||
| 1. Code Splitting & MUI Optimization | 3/3 | ✅ Complete | 2026-05-08 |
|
||||
| 2. Migrate CRA to Vite | 0/3 | ⏳ Ready | - |
|
||||
| 3. Cleanup & Production | 0/2 | ⏳ Planned | - |
|
||||
@@ -1,73 +1,40 @@
|
||||
# Project State: Alwrity
|
||||
# Project State
|
||||
|
||||
## Project Reference
|
||||
**Core Value**: ALwrity is an AI-powered content creation platform that helps users generate various types of content including podcasts, videos, blogs, and social media content.
|
||||
|
||||
**Current Focus**: Based on recent development activity, the team is implementing Phase 2 of the WaveSpeed AI integration roadmap - Hyper-Personalization features for the Persona system, including voice training and avatar creation.
|
||||
|
||||
## Current Position
|
||||
**Phase**: 2 of 3 - Hyper-Personalization
|
||||
**Plan**: 3 of 5 - Persona Avatar Creation & Integration
|
||||
**Status**: In Progress - Working on avatar service implementation and frontend UI for avatar creation
|
||||
|
||||
**Active Phase:** Phase 1 - Code Splitting & Feature-Based Lazy Loading
|
||||
**Phase Status:** ✅ Complete — Ready for Phase 2
|
||||
**Milestone:** v1.0 - Frontend Optimization
|
||||
## Progress
|
||||
Progress: [███████░░] 70%
|
||||
|
||||
## Phase Progress
|
||||
## Recent Decisions
|
||||
1. **Avatar Service Architecture**: Decided to create a shared avatar service in backend/services/wavespeed/avatar/ for reuse across LinkedIn and Persona modules
|
||||
2. **UI Framework**: Continuing with Material-UI (MUI) for consistent avatar creation interface
|
||||
3. **Storage Strategy**: Using cloud storage for avatar assets with metadata tracking in PostgreSQL
|
||||
4. **Generation Queue**: Implementing asynchronous processing for avatar generation to prevent API timeouts
|
||||
|
||||
### Phase 1: Code Splitting & Feature-Based Lazy Loading
|
||||
- **Status:** ✅ Complete
|
||||
- **Plans:** 3 plans executed (01-01, 01-02, 01-03)
|
||||
## Pending Todos
|
||||
- [ ] Complete avatar generation API endpoints
|
||||
- [ ] Implement avatar library management UI
|
||||
- [ ] Add avatar preview functionality
|
||||
- [ ] Create avatar upload/download capabilities
|
||||
- [ ] Integrate avatar selection into Persona dashboard
|
||||
- [ ] Add usage tracking and cost estimation for avatar generation
|
||||
- [ ] Write comprehensive tests for avatar service
|
||||
- [ ] Update documentation for avatar feature
|
||||
|
||||
**Plans:**
|
||||
- [x] 01-01: Convert 31 static imports to React.lazy with Suspense
|
||||
- [x] 01-02: Add feature-gated route loading using ALWRITY_ENABLED_FEATURES
|
||||
- [x] 01-03: Convert MUI icon barrel imports to individual imports (111 files)
|
||||
## Blockers/Concerns
|
||||
- **WaveSpeed API Rate Limits**: Need to implement proper queuing and retry mechanisms
|
||||
- **Storage Costs**: Avatar storage could become expensive at scale - need to implement cleanup policies
|
||||
- **Generation Time**: Avatar generation can take 30-60 seconds - need to improve user experience during wait
|
||||
- **Quality Consistency**: Ensuring generated avatars maintain consistent quality across different inputs
|
||||
|
||||
**Results:**
|
||||
- Main bundle: 8.42MB → 2.50MB (70% reduction via React.lazy)
|
||||
- 190+ chunk files for route-level code splitting
|
||||
- 47 routes feature-gated with ALWRITY_ENABLED_FEATURES
|
||||
- 16 feature keys in FEATURE_KEYS constant
|
||||
- 111 files converted from barrel to individual MUI icon imports
|
||||
- Zero barrel imports from @mui/icons-material remain
|
||||
|
||||
### Phase 2: Migrate CRA to Vite
|
||||
- **Status:** Ready to start (Phase 1 complete)
|
||||
- **Plans:** 3 plans created (02-01, 02-02, 02-03)
|
||||
- **Dependencies:** Phase 1 complete
|
||||
|
||||
**Plans:**
|
||||
- [ ] 02-01: Install Vite dependencies and create configuration
|
||||
- [ ] 02-02: Migrate index.html and entry point
|
||||
- [ ] 02-03: Update environment variables and scripts
|
||||
|
||||
### Phase 3: Production Validation (Planned)
|
||||
- Depends on: Phase 2
|
||||
- Focus: Vercel deploy, full feature testing
|
||||
|
||||
### Phase 4: (Removed — MUI icon optimization folded into Phase 1-03)
|
||||
|
||||
## Decisions Made
|
||||
|
||||
### Locked Decisions
|
||||
- **Code splitting first**, then Vite migration (not the other way around) ✅ Done
|
||||
- Use React.lazy for ALL route components (this is a React feature, NOT bundler-specific) ✅ Done
|
||||
- Use ALWRITY_ENABLED_FEATURES for feature-gated route loading ✅ Done
|
||||
- **MUI icon imports before Vite migration** — barrel imports converted to individual per-file default imports ✅ Done
|
||||
- Use Vite 5.x with @vitejs/plugin-react
|
||||
- Disable sourcemaps in production build for speed
|
||||
- Migrate env vars from `REACT_APP_*` to `VITE_*`
|
||||
|
||||
### Patterns Established
|
||||
- **MUI icon imports**: Always `import IconName from '@mui/icons-material/IconName'` — never barrel destructuring
|
||||
- **Route splitting**: All route components use React.lazy with Suspense
|
||||
- **Feature gating**: FeatureRoute wraps inside ProtectedRoute (auth → then feature check)
|
||||
|
||||
## Key Insight
|
||||
|
||||
**React.lazy is a React feature (not CRA or Vite specific).** Doing code splitting first with CRA:
|
||||
1. Immediately reduces main bundle from 8.42MB → ~1-2MB
|
||||
2. Adds no risk (React.lazy is stable since React 16.6)
|
||||
3. Makes Vite migration smoother (bundles are already split)
|
||||
4. ALWRITY_ENABLED_FEATURES can prevent disabled feature bundles from loading at all
|
||||
|
||||
**MUI icon barrel imports eliminated** — 111 files converted to individual per-file imports. This ensures reliable tree-shaking during Vite migration and beyond.
|
||||
|
||||
---
|
||||
|
||||
*Last updated: 2026-05-08*
|
||||
*Updated by: gsd-executor*
|
||||
Last session: 2026-04-21 07:02:08
|
||||
Stopped at: Session resumed, proceeding to discuss Phase 2 context
|
||||
Resume file: [updated if applicable]
|
||||
|
||||
@@ -1,129 +0,0 @@
|
||||
---
|
||||
phase: 01-code-splitting
|
||||
plan: 03
|
||||
type: execute
|
||||
subsystem: frontend
|
||||
tags: [performance, MUI, icons, tree-shaking, barrel-imports]
|
||||
requires:
|
||||
- phase: 01-code-splitting-02
|
||||
provides: feature gating structure for route protection
|
||||
provides:
|
||||
- All MUI icon imports converted from barrel (destructured) to individual per-file default imports
|
||||
- Zero barrel imports from @mui/icons-material remain in the codebase
|
||||
affects: [02-vite-migration, build performance]
|
||||
tech-stack:
|
||||
added: []
|
||||
patterns: [individual MUI icon imports, per-file default imports for tree-shaking]
|
||||
key-files:
|
||||
created: []
|
||||
modified:
|
||||
- frontend/src/components/shared/ErrorBoundary.tsx
|
||||
- frontend/src/components/SubscriptionGuard.tsx
|
||||
- frontend/src/components/SubscriptionExpiredModal.tsx
|
||||
- frontend/src/pages/SchedulerDashboard.tsx
|
||||
- frontend/src/pages/BillingPage.tsx
|
||||
- +106 additional frontend component files
|
||||
key-decisions:
|
||||
- "All MUI icon barrel imports converted BEFORE Vite migration to eliminate Webpack 4 tree-shaking uncertainty"
|
||||
- "Used per-file default imports (import X from '@mui/icons-material/X') instead of destructured barrel imports"
|
||||
- "Aliased icons (e.g., ErrorOutline as ErrorIcon) converted to named default imports matching the alias (import ErrorIcon from '@mui/icons-material/ErrorOutline')"
|
||||
- "JSX variable names preserved — only import statements changed"
|
||||
patterns-established:
|
||||
- "MUI icon imports: always use import X from '@mui/icons-material/X' pattern, never import { X } from '@mui/icons-material'"
|
||||
duration: 45min
|
||||
completed: 2026-05-08
|
||||
---
|
||||
|
||||
# Phase 1 Plan 01-03: MUI Icon Import Optimization Summary
|
||||
|
||||
**Converted all 300+ MUI icon barrel imports to individual per-file default imports across 111 frontend files — eliminating Webpack 4 tree-shaking uncertainty before Vite migration**
|
||||
|
||||
## Performance
|
||||
|
||||
- **Duration:** ~35 min
|
||||
- **Completed:** 2026-05-08
|
||||
- **Tasks:** 10 commits across 111 files
|
||||
- **Files modified:** 111
|
||||
|
||||
## Accomplishments
|
||||
|
||||
- Converted **all barrel** `import { X } from '@mui/icons-material'` to individual `import X from '@mui/icons-material/X'` — **zero barrel imports remaining**
|
||||
- Modified **111 files** across every area: PodcastMaker, YouTubeCreator, OnboardingWizard, billing, SEO, shared components, and more
|
||||
- Handled aliased imports (`IconName as Alias`) correctly — JSX variable names preserved unchanged
|
||||
- Build verified — `npm run build:nomap` succeeds with zero new errors
|
||||
- Enables reliable tree-shaking during Phase 2 (Vite migration) — each file imports only the icons it uses
|
||||
|
||||
## Task Commits
|
||||
|
||||
Each batch was committed atomically:
|
||||
|
||||
1. **ErrorBoundary** (`components/shared/`) - `46781a0` — 5 icons
|
||||
2. **SubscriptionGuard** - `bda75cb` — 2 icons
|
||||
3. **SubscriptionExpiredModal** - `80f76b1` — 3 icons
|
||||
4. **SchedulerDashboard** - `7ffd972` — 7 icons
|
||||
5. **BillingPage** - `a76671c` — 1 icon
|
||||
6. **Billing, Blog, ContentPlanning, ErrorBoundary, Pricing, Alerts** - `a009cbb` — 8 files, 36 insertions
|
||||
7. **ImageStudio, Landing, LinkedIn, MainDashboard, OnboardingWizard** - `205e098` — 14 files, 65 insertions
|
||||
8. **PodcastMaker AnalysisPanel** - `25ce5b9` — 18 files, 58 insertions
|
||||
9. **PodcastMaker, ProductMarketing, Research, Scheduler, SEO, Shared** - `986a7e5` — 44 files, 149 insertions
|
||||
10. **StoryWriter, YouTubeCreator** - `6361255` — 22 files, 67 insertions
|
||||
|
||||
## Files Modified
|
||||
|
||||
**111 files total** across the frontend source tree:
|
||||
|
||||
- `components/billing/` — 2 files (ComprehensiveAPIBreakdown, CostOptimizationRecommendations)
|
||||
- `components/BlogWriter/` — 1 file (BlogWriterPhasesSection)
|
||||
- `components/ContentPlanningDashboard/` — 2 files (CardExpansionWrapper, StrategyErrorBoundary)
|
||||
- `components/ErrorBoundary.tsx` — 1 file (3 icons)
|
||||
- `components/ImageStudio/` — 2 files (AssetFilters, CreateStudioCostAlerts)
|
||||
- `components/Landing/` — 2 files (EnterpriseCTA, FeatureShowcase)
|
||||
- `components/LinkedInWriter/` — 1 file (FactCheckResults)
|
||||
- `components/MainDashboard/` — 1 file (MainDashboard)
|
||||
- `components/OnboardingWizard/` — 7 files (incl. VoiceAvatarPlaceholder with 22 icons)
|
||||
- `components/PodcastMaker/` — 40 files (AnalysisPanel, CreateStep, ScriptEditor, etc.)
|
||||
- `components/Pricing/` — 1 file (PricingPage)
|
||||
- `components/ProductMarketing/` — 5 files (CampaignWizard, ProductPhotoshootStudio, etc.)
|
||||
- `components/Research/` — 2 files (PersonalizationIndicator, ResearchInputContainer)
|
||||
- `components/SchedulerDashboard/` — 1 file (SchedulerCharts)
|
||||
- `components/SEODashboard/` — 3 files (AIInsightsPanel, HealthScore, MetricCard)
|
||||
- `components/shared/` — 12 files (ErrorBoundary, AlertsBadge, ProtectedRoute, etc.)
|
||||
- `components/StoryWriter/` — 3 files (AIStorySetupModal, FormFieldWithTooltip, SelectFieldWithTooltip)
|
||||
- `components/SubscriptionGuard.tsx` — 1 file
|
||||
- `components/SubscriptionExpiredModal.tsx` — 1 file
|
||||
- `components/YouTubeCreator/` — 19 files (SceneCard, RenderStep, PlanStep, etc.)
|
||||
- `pages/` — 2 files (BillingPage, ResearchDashboard/PresetsCard)
|
||||
|
||||
## Decisions Made
|
||||
|
||||
- **Convert all barrel imports now, before Vite migration** — CRA's Webpack 4 cannot reliably tree-shake barrel imports. Converting before the bundler swap reduces migration risk and ensures Vite's native ESM tree-shaking works optimally.
|
||||
- **Per-file default import pattern** — Every icon gets its own import line: `import IconName from '@mui/icons-material/IconName'`. This is the most predictable pattern and works identically in both Webpack and Vite.
|
||||
- **Alias handling** — For icons imported as `{ X as Y }`, the alias `Y` becomes the import name: `import Y from '@mui/icons-material/X'`. JSX usage unchanged.
|
||||
- **Multiple import lines preserved** — Files with separate barrel imports from `@mui/icons-material` were converted to multiple individual import blocks, preserving the original organizational structure.
|
||||
|
||||
## Deviations from Plan
|
||||
|
||||
None - this was ad-hoc work not covered by an existing PLAN.md.
|
||||
|
||||
## Issues Encountered
|
||||
|
||||
- **Task agent timeout**: First attempt at parallel conversion agents failed silently for batches 1-2 (73 files). Re-launched with explicit edit instructions - succeeded on second attempt.
|
||||
- **No naming conflicts found**: Despite converting 300+ icon imports across 111 files, no variable naming collisions occurred. Each icon only appears once per file.
|
||||
|
||||
## Build Verification
|
||||
|
||||
- `npm run build:nomap` — **PASSED** with zero errors
|
||||
- Only pre-existing CRA bundle size warning remains (expected — Vite migration will resolve it in Phase 2)
|
||||
- No new build warnings introduced
|
||||
|
||||
## Next Phase Readiness
|
||||
|
||||
- Frontend is ready for **Phase 2: Vite Migration**
|
||||
- All MUI icon imports use individual default imports — tree-shaking will work correctly with Vite's rollup
|
||||
- User should perform manual testing of Podcast Maker with `REACT_APP_ENABLED_FEATURES=podcast` before Vite migration begins
|
||||
- After manual verification, proceed with [Phase 2-01: Install Vite dependencies and create configuration]
|
||||
|
||||
---
|
||||
|
||||
*Phase: 01-code-splitting*
|
||||
*Completed: 2026-05-08*
|
||||
14
Procfile
14
Procfile
@@ -1 +1,13 @@
|
||||
web: cd backend && python start_alwrity_backend.py --production
|
||||
web: cd backend && ALWRITY_ENABLED_FEATURES=podcast python -c "
|
||||
import os
|
||||
import sys
|
||||
# Ensure podcast mode
|
||||
os.environ.setdefault('ALWRITY_ENABLED_FEATURES', 'podcast')
|
||||
# Set HOST/PORT for Render
|
||||
port = os.getenv('PORT', '10000')
|
||||
host = os.getenv('HOST', '0.0.0.0')
|
||||
print(f'[STARTUP] Starting uvicorn on {host}:{port}', flush=True)
|
||||
sys.stdout.flush()
|
||||
import uvicorn
|
||||
uvicorn.run('app:app', host=host, port=int(port), reload=False)
|
||||
"
|
||||
|
||||
14
README.md
Normal file
14
README.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Render CLI
|
||||
|
||||
## Installation
|
||||
|
||||
- [Homebrew](https://render.com/docs/cli#homebrew-macos-linux)
|
||||
- [Direct Download](https://render.com/docs/cli#direct-download)
|
||||
|
||||
## Documentation
|
||||
|
||||
Documentation is hosted at https://render.com/docs/cli.
|
||||
|
||||
## Contributing
|
||||
|
||||
To create a new command, use the `cmd/template.go` template file as a starting point. Reference the [CLI Style Guide](docs/STYLE.md) to learn more about command naming, flags, arguments, and help text conventions.
|
||||
672
_session_backup/App.tsx
Normal file
672
_session_backup/App.tsx
Normal file
@@ -0,0 +1,672 @@
|
||||
import React from 'react';
|
||||
import { BrowserRouter as Router, Routes, Route, Navigate, useLocation } from 'react-router-dom';
|
||||
import { Box, CircularProgress, Typography } from '@mui/material';
|
||||
import { CopilotKit } from "@copilotkit/react-core";
|
||||
import { ClerkProvider, useAuth } from '@clerk/clerk-react';
|
||||
import "@copilotkit/react-ui/styles.css";
|
||||
import Wizard from './components/OnboardingWizard/Wizard';
|
||||
import MainDashboard from './components/MainDashboard/MainDashboard';
|
||||
import SEODashboard from './components/SEODashboard/SEODashboard';
|
||||
import ContentPlanningDashboard from './components/ContentPlanningDashboard/ContentPlanningDashboard';
|
||||
import FacebookWriter from './components/FacebookWriter/FacebookWriter';
|
||||
import LinkedInWriter from './components/LinkedInWriter/LinkedInWriter';
|
||||
import BlogWriter from './components/BlogWriter/BlogWriter';
|
||||
import StoryWriter from './components/StoryWriter/StoryWriter';
|
||||
import { StoryProjectList } from './components/StoryWriter/StoryProjectList';
|
||||
import YouTubeCreator from './components/YouTubeCreator/YouTubeCreator';
|
||||
import { CreateStudio, EditStudio, UpscaleStudio, ControlStudio, SocialOptimizer, AssetLibrary, ImageStudioDashboard, FaceSwapStudio, CompressionStudio, ImageProcessingStudio } from './components/ImageStudio';
|
||||
import {
|
||||
VideoStudioDashboard,
|
||||
CreateVideo,
|
||||
AvatarVideo,
|
||||
EnhanceVideo,
|
||||
ExtendVideo,
|
||||
EditVideo,
|
||||
TransformVideo,
|
||||
SocialVideo,
|
||||
FaceSwap,
|
||||
VideoTranslate,
|
||||
VideoBackgroundRemover,
|
||||
AddAudioToVideo,
|
||||
LibraryVideo,
|
||||
} from './components/VideoStudio';
|
||||
import {
|
||||
ProductMarketingDashboard,
|
||||
ProductPhotoshootStudio,
|
||||
ProductAnimationStudio,
|
||||
ProductVideoStudio,
|
||||
ProductAvatarStudio,
|
||||
} from './components/ProductMarketing';
|
||||
import PodcastDashboard from './components/PodcastMaker/PodcastDashboard';
|
||||
import PricingPage from './components/Pricing/PricingPage';
|
||||
import WixTestPage from './components/WixTestPage/WixTestPage';
|
||||
import WixCallbackPage from './components/WixCallbackPage/WixCallbackPage';
|
||||
import WordPressCallbackPage from './components/WordPressCallbackPage/WordPressCallbackPage';
|
||||
import BingCallbackPage from './components/BingCallbackPage/BingCallbackPage';
|
||||
import BingAnalyticsStorage from './components/BingAnalyticsStorage/BingAnalyticsStorage';
|
||||
import ResearchDashboard from './pages/ResearchDashboard';
|
||||
import IntentResearchTest from './pages/IntentResearchTest';
|
||||
import SchedulerDashboard from './pages/SchedulerDashboard';
|
||||
import BillingPage from './pages/BillingPage';
|
||||
import ApprovalsPage from './pages/ApprovalsPage';
|
||||
import TeamActivityPage from './pages/TeamActivityPage';
|
||||
import StripeDisputesDashboard from './pages/StripeDisputesDashboard';
|
||||
import ProtectedRoute from './components/shared/ProtectedRoute';
|
||||
import GSCAuthCallback from './components/SEODashboard/components/GSCAuthCallback';
|
||||
import Landing from './components/Landing/Landing';
|
||||
import ErrorBoundary from './components/shared/ErrorBoundary';
|
||||
import ErrorBoundaryTest from './components/shared/ErrorBoundaryTest';
|
||||
import CopilotKitDegradedBanner from './components/shared/CopilotKitDegradedBanner';
|
||||
import { OnboardingProvider } from './contexts/OnboardingContext';
|
||||
import { SubscriptionProvider, useSubscription } from './contexts/SubscriptionContext';
|
||||
import { CopilotKitHealthProvider } from './contexts/CopilotKitHealthContext';
|
||||
import { useOAuthTokenAlerts } from './hooks/useOAuthTokenAlerts';
|
||||
|
||||
import { setAuthTokenGetter, setClerkSignOut } from './api/client';
|
||||
import { setMediaAuthTokenGetter } from './utils/fetchMediaBlobUrl';
|
||||
import { setBillingAuthTokenGetter } from './services/billingService';
|
||||
import { useOnboarding } from './contexts/OnboardingContext';
|
||||
import { useState, useEffect } from 'react';
|
||||
import ConnectionErrorPage from './components/shared/ConnectionErrorPage';
|
||||
import { isPodcastOnlyDemoMode } from './utils/demoMode';
|
||||
|
||||
// interface OnboardingStatus {
|
||||
// onboarding_required: boolean;
|
||||
// onboarding_complete: boolean;
|
||||
// current_step?: number;
|
||||
// total_steps?: number;
|
||||
// completion_percentage?: number;
|
||||
// }
|
||||
|
||||
// Conditional CopilotKit wrapper that only shows sidebar on content-planning route
|
||||
const ConditionalCopilotKit: React.FC<{ children: React.ReactNode }> = ({ children }) => {
|
||||
// Do not render CopilotSidebar here. Let specific pages/components control it.
|
||||
return <>{children}</>;
|
||||
};
|
||||
|
||||
// Wrapper to only enable CopilotKit checks/provider when user is authenticated
|
||||
// This prevents CopilotKit from running on the Landing page
|
||||
const AuthenticatedCopilotWrapper: React.FC<{
|
||||
children: React.ReactNode;
|
||||
apiKey: string;
|
||||
}> = ({ children, apiKey }) => {
|
||||
const { isSignedIn } = useAuth();
|
||||
const location = useLocation();
|
||||
|
||||
// Exclude CopilotKit from running on:
|
||||
// 1. Landing page (handled by !isSignedIn)
|
||||
// 2. Onboarding pages (to prevent health check timeouts)
|
||||
// 3. Podcast-only demo mode (CopilotKit not needed)
|
||||
const isPodcastOnly = isPodcastOnlyDemoMode();
|
||||
const shouldExcludeCopilot = !isSignedIn || location.pathname.startsWith('/onboarding') || isPodcastOnly;
|
||||
|
||||
if (shouldExcludeCopilot) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
const hasKey = apiKey && apiKey.trim();
|
||||
|
||||
if (hasKey) {
|
||||
// Enhanced error handler that updates health context
|
||||
const handleCopilotKitError = (e: any) => {
|
||||
console.error("CopilotKit Error:", e);
|
||||
|
||||
// Try to get health context if available
|
||||
// We'll use a custom event to notify health context since we can't access it directly here
|
||||
const errorMessage = e?.error?.message || e?.message || 'CopilotKit error occurred';
|
||||
const errorType = errorMessage.toLowerCase();
|
||||
|
||||
// Differentiate between fatal and transient errors
|
||||
const isFatalError =
|
||||
errorType.includes('cors') ||
|
||||
errorType.includes('ssl') ||
|
||||
errorType.includes('certificate') ||
|
||||
errorType.includes('403') ||
|
||||
errorType.includes('forbidden') ||
|
||||
errorType.includes('ERR_CERT_COMMON_NAME_INVALID');
|
||||
|
||||
// Dispatch event for health context to listen to
|
||||
window.dispatchEvent(new CustomEvent('copilotkit-error', {
|
||||
detail: {
|
||||
error: e,
|
||||
errorMessage,
|
||||
isFatal: isFatalError,
|
||||
}
|
||||
}));
|
||||
};
|
||||
|
||||
return (
|
||||
<CopilotKitHealthProvider initialHealthStatus={true}>
|
||||
<CopilotKitDegradedBanner />
|
||||
<ErrorBoundary
|
||||
context="CopilotKit"
|
||||
showDetails={process.env.NODE_ENV === 'development'}
|
||||
fallback={
|
||||
<Box sx={{ p: 3, textAlign: 'center' }}>
|
||||
<Typography variant="h6" color="warning" gutterBottom>
|
||||
Chat Unavailable
|
||||
</Typography>
|
||||
<Typography variant="body2" color="textSecondary">
|
||||
CopilotKit encountered an error. The app continues to work with manual controls.
|
||||
</Typography>
|
||||
</Box>
|
||||
}
|
||||
>
|
||||
<CopilotKit
|
||||
publicApiKey={apiKey}
|
||||
showDevConsole={false}
|
||||
onError={handleCopilotKitError}
|
||||
>
|
||||
{children}
|
||||
</CopilotKit>
|
||||
</ErrorBoundary>
|
||||
</CopilotKitHealthProvider>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<CopilotKitHealthProvider initialHealthStatus={false}>
|
||||
<CopilotKitDegradedBanner />
|
||||
{children}
|
||||
</CopilotKitHealthProvider>
|
||||
);
|
||||
};
|
||||
|
||||
// Component to handle initial routing based on subscription and onboarding status
|
||||
// Flow: Subscription → Onboarding → Dashboard
|
||||
const InitialRouteHandler: React.FC = () => {
|
||||
const { loading, error, isOnboardingComplete, initializeOnboarding, data } = useOnboarding();
|
||||
const { subscription, loading: subscriptionLoading, checkSubscription } = useSubscription();
|
||||
const [connectionError, setConnectionError] = useState<{
|
||||
hasError: boolean;
|
||||
error: Error | null;
|
||||
}>({
|
||||
hasError: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
// Poll for OAuth token alerts and show toast notifications
|
||||
// Only enabled when user is authenticated (has subscription)
|
||||
useOAuthTokenAlerts({
|
||||
enabled: subscription?.active === true,
|
||||
interval: 60000, // Poll every 1 minute
|
||||
});
|
||||
|
||||
// Check subscription on mount (non-blocking - don't wait for it to route)
|
||||
useEffect(() => {
|
||||
// Delay subscription check slightly to allow auth token getter to be installed first
|
||||
const timeoutId = setTimeout(async () => {
|
||||
// Retry logic for initial subscription check
|
||||
const maxRetries = 3;
|
||||
for (let attempt = 0; attempt < maxRetries; attempt++) {
|
||||
try {
|
||||
await checkSubscription();
|
||||
break; // Success
|
||||
} catch (err) {
|
||||
console.error(`App: Subscription check attempt ${attempt + 1} failed:`, err);
|
||||
|
||||
// If it's a connection error and we have retries left, wait and retry
|
||||
const isConnectionError = err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError');
|
||||
|
||||
if (isConnectionError && attempt < maxRetries - 1) {
|
||||
const delay = 1000 * Math.pow(2, attempt); // 1s, 2s
|
||||
await new Promise(resolve => setTimeout(resolve, delay));
|
||||
continue;
|
||||
}
|
||||
|
||||
// If final attempt or not a connection error, handle it
|
||||
if (attempt === maxRetries - 1 || !isConnectionError) {
|
||||
if (isConnectionError) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err as Error,
|
||||
});
|
||||
}
|
||||
// Don't block routing on other errors
|
||||
}
|
||||
}
|
||||
}
|
||||
}, 100); // Small delay to ensure TokenInstaller has run
|
||||
|
||||
return () => clearTimeout(timeoutId);
|
||||
}, []); // Remove checkSubscription dependency to prevent loop
|
||||
|
||||
// Initialize onboarding only after subscription is confirmed
|
||||
useEffect(() => {
|
||||
if (subscription && !subscriptionLoading) {
|
||||
// Check if user is new (no subscription record at all)
|
||||
const isNewUser = !subscription || subscription.plan === 'none';
|
||||
|
||||
console.log('InitialRouteHandler: Subscription data received:', {
|
||||
plan: subscription.plan,
|
||||
active: subscription.active,
|
||||
isNewUser,
|
||||
subscriptionLoading
|
||||
});
|
||||
|
||||
if (subscription.active && !isNewUser) {
|
||||
console.log('InitialRouteHandler: Subscription confirmed, initializing onboarding...');
|
||||
initializeOnboarding();
|
||||
}
|
||||
}
|
||||
}, [subscription, subscriptionLoading, initializeOnboarding]);
|
||||
|
||||
// Handle connection error - show connection error page
|
||||
if (connectionError.hasError) {
|
||||
const handleRetry = () => {
|
||||
setConnectionError({
|
||||
hasError: false,
|
||||
error: null,
|
||||
});
|
||||
// Re-trigger the subscription check using context
|
||||
checkSubscription().catch((err) => {
|
||||
if (err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError')) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err,
|
||||
});
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
const handleGoHome = () => {
|
||||
window.location.href = '/';
|
||||
};
|
||||
|
||||
return (
|
||||
<ConnectionErrorPage
|
||||
onRetry={handleRetry}
|
||||
onGoHome={handleGoHome}
|
||||
message={connectionError.error?.message || "Backend service is not available. Please check if the server is running."}
|
||||
title="Connection Error"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
// Loading state - only wait for onboarding init, not subscription check
|
||||
// Subscription check is non-blocking and happens in background
|
||||
const waitingForOnboardingInit = loading || !data;
|
||||
if (loading || waitingForOnboardingInit) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
{subscriptionLoading ? 'Checking subscription...' : 'Preparing your workspace...'}
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Error state
|
||||
if (error) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
p={3}
|
||||
>
|
||||
<Typography variant="h5" color="error" gutterBottom>
|
||||
Error
|
||||
</Typography>
|
||||
<Typography variant="body1" color="textSecondary" textAlign="center">
|
||||
{error}
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Decision tree for SIGNED-IN users:
|
||||
// Priority: Subscription → Onboarding → Dashboard (as per user flow: Landing → Subscription → Onboarding → Dashboard)
|
||||
|
||||
// 1. If subscription is still loading, show loading state
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// 2. No subscription data yet - handle gracefully
|
||||
// If onboarding is complete, allow access to dashboard (user already went through flow)
|
||||
// If onboarding not complete, check if subscription check is still loading or failed
|
||||
if (!subscription) {
|
||||
if (isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Onboarding complete but no subscription data → Dashboard (allow access)');
|
||||
return <Navigate to="/dashboard" replace />;
|
||||
}
|
||||
|
||||
// Onboarding not complete and no subscription data
|
||||
// If subscription check is still loading, show loading state
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Subscription check completed but returned null/undefined
|
||||
// This likely means no subscription - redirect to pricing
|
||||
console.log('InitialRouteHandler: No subscription data after check → Pricing page');
|
||||
return <Navigate to="/pricing" replace />;
|
||||
}
|
||||
|
||||
// 3. Check subscription status first
|
||||
const isNewUser = !subscription || subscription.plan === 'none';
|
||||
|
||||
// No active subscription → Show modal (SubscriptionContext handles this)
|
||||
// Don't redirect immediately - let the modal show first
|
||||
// User can click "Renew Subscription" button in modal to go to pricing
|
||||
// Or click "Maybe Later" to dismiss (but they still can't use features)
|
||||
if (isNewUser || !subscription.active) {
|
||||
console.log('InitialRouteHandler: No active subscription - modal will be shown by SubscriptionContext');
|
||||
// Note: SubscriptionContext will show the modal automatically when subscription is inactive
|
||||
// We still redirect to pricing for new users, but allow existing users with expired subscriptions
|
||||
// to see the modal first. The modal has a "Renew Subscription" button that navigates to pricing.
|
||||
// For new users (no subscription at all), redirect to pricing immediately
|
||||
if (isNewUser) {
|
||||
console.log('InitialRouteHandler: New user (no subscription) → Pricing page');
|
||||
return <Navigate to="/pricing" replace />;
|
||||
}
|
||||
// For existing users with inactive subscription, show modal but don't redirect immediately
|
||||
// The modal will be shown by SubscriptionContext, and user can click "Renew Subscription"
|
||||
// Allow access to dashboard (modal will be shown and block functionality)
|
||||
console.log('InitialRouteHandler: Inactive subscription - allowing access to show modal');
|
||||
// Continue to onboarding/dashboard flow - modal will be shown by SubscriptionContext
|
||||
}
|
||||
|
||||
// 4. Has active subscription, check onboarding status
|
||||
if (!isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Subscription active but onboarding incomplete → Onboarding');
|
||||
return <Navigate to="/onboarding" replace />;
|
||||
}
|
||||
|
||||
// 5. Has subscription AND completed onboarding → Dashboard
|
||||
console.log('InitialRouteHandler: All set (subscription + onboarding) → Dashboard');
|
||||
return <Navigate to="/dashboard" replace />;
|
||||
};
|
||||
|
||||
// Root route that chooses Landing (signed out) or InitialRouteHandler (signed in)
|
||||
const RootRoute: React.FC = () => {
|
||||
const { isSignedIn } = useAuth();
|
||||
if (isSignedIn) {
|
||||
return <InitialRouteHandler />;
|
||||
}
|
||||
return <Landing />;
|
||||
};
|
||||
|
||||
// Installs Clerk auth token getter into axios clients and stores user_id
|
||||
// Must render under ClerkProvider
|
||||
const TokenInstaller: React.FC = () => {
|
||||
const { getToken, userId, isSignedIn, signOut } = useAuth();
|
||||
|
||||
// Store user_id in localStorage when user signs in
|
||||
useEffect(() => {
|
||||
if (isSignedIn && userId) {
|
||||
console.log('TokenInstaller: Storing user_id in localStorage:', userId);
|
||||
localStorage.setItem('user_id', userId);
|
||||
|
||||
// Trigger event to notify SubscriptionContext that user is authenticated
|
||||
window.dispatchEvent(new CustomEvent('user-authenticated', { detail: { userId } }));
|
||||
} else if (!isSignedIn) {
|
||||
// Clear user_id when signed out
|
||||
console.log('TokenInstaller: Clearing user_id from localStorage');
|
||||
localStorage.removeItem('user_id');
|
||||
}
|
||||
}, [isSignedIn, userId]);
|
||||
|
||||
// Install token getter for API calls
|
||||
useEffect(() => {
|
||||
const tokenGetter = async () => {
|
||||
try {
|
||||
const template = process.env.REACT_APP_CLERK_JWT_TEMPLATE;
|
||||
// If a template is provided and it's not a placeholder, request a template-specific JWT
|
||||
if (template && template !== 'your_jwt_template_name_here') {
|
||||
// @ts-ignore Clerk types allow options object
|
||||
return await getToken({ template });
|
||||
}
|
||||
return await getToken();
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
// Set token getter for main API client
|
||||
setAuthTokenGetter(tokenGetter);
|
||||
|
||||
// Set token getter for billing API client (same function)
|
||||
setBillingAuthTokenGetter(tokenGetter);
|
||||
|
||||
// Set token getter for media blob URL fetcher (for authenticated image/video requests)
|
||||
setMediaAuthTokenGetter(tokenGetter);
|
||||
}, [getToken]);
|
||||
|
||||
// Install Clerk signOut function for handling expired tokens
|
||||
useEffect(() => {
|
||||
if (signOut) {
|
||||
setClerkSignOut(async () => {
|
||||
await signOut();
|
||||
});
|
||||
}
|
||||
}, [signOut]);
|
||||
|
||||
return null;
|
||||
};
|
||||
|
||||
const App: React.FC = () => {
|
||||
// React Hooks MUST be at the top before any conditionals
|
||||
const [loading, setLoading] = useState(true);
|
||||
|
||||
// Get CopilotKit key from localStorage or .env
|
||||
const [copilotApiKey, setCopilotApiKey] = useState(() => {
|
||||
const savedKey = localStorage.getItem('copilotkit_api_key');
|
||||
const envKey = process.env.REACT_APP_COPILOTKIT_API_KEY || '';
|
||||
const key = (savedKey || envKey).trim();
|
||||
|
||||
// Validate key format if present
|
||||
if (key && !key.startsWith('ck_pub_')) {
|
||||
console.warn('CopilotKit API key format invalid - must start with ck_pub_');
|
||||
}
|
||||
|
||||
return key;
|
||||
});
|
||||
|
||||
// Initialize app - loading state will be managed by InitialRouteHandler
|
||||
useEffect(() => {
|
||||
// Remove manual health check - connection errors are handled by ErrorBoundary
|
||||
setLoading(false);
|
||||
}, []);
|
||||
|
||||
// Listen for CopilotKit key updates
|
||||
useEffect(() => {
|
||||
const handleKeyUpdate = (event: CustomEvent) => {
|
||||
const newKey = event.detail?.apiKey;
|
||||
if (newKey) {
|
||||
console.log('App: CopilotKit key updated, reloading...');
|
||||
setCopilotApiKey(newKey);
|
||||
setTimeout(() => window.location.reload(), 500);
|
||||
}
|
||||
};
|
||||
|
||||
window.addEventListener('copilotkit-key-updated', handleKeyUpdate as EventListener);
|
||||
return () => window.removeEventListener('copilotkit-key-updated', handleKeyUpdate as EventListener);
|
||||
}, []);
|
||||
|
||||
// Token installer must be inside ClerkProvider; see TokenInstaller below
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Connecting to ALwrity...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// Get environment variables with fallbacks
|
||||
const clerkPublishableKey = process.env.REACT_APP_CLERK_PUBLISHABLE_KEY || '';
|
||||
const clerkJSUrl = process.env.REACT_APP_CLERK_JS_URL;
|
||||
|
||||
// Show error if required keys are missing
|
||||
if (!clerkPublishableKey) {
|
||||
return (
|
||||
<Box sx={{ p: 3, textAlign: 'center' }}>
|
||||
<Typography color="error" variant="h6">
|
||||
Missing Clerk Publishable Key
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ mt: 1 }}>
|
||||
Please add REACT_APP_CLERK_PUBLISHABLE_KEY to your .env file
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Render app with or without CopilotKit based on whether we have a key
|
||||
const renderApp = () => {
|
||||
return (
|
||||
<Router>
|
||||
<AuthenticatedCopilotWrapper apiKey={copilotApiKey}>
|
||||
<ConditionalCopilotKit>
|
||||
<TokenInstaller />
|
||||
<Routes>
|
||||
<Route path="/" element={<RootRoute />} />
|
||||
<Route
|
||||
path="/onboarding"
|
||||
element={
|
||||
<ErrorBoundary context="Onboarding Wizard" showDetails>
|
||||
<Wizard />
|
||||
</ErrorBoundary>
|
||||
}
|
||||
/>
|
||||
{/* Error Boundary Testing - Development Only */}
|
||||
{process.env.NODE_ENV === 'development' && (
|
||||
<Route path="/error-test" element={<ErrorBoundaryTest />} />
|
||||
)}
|
||||
<Route path="/dashboard" element={<ProtectedRoute><MainDashboard /></ProtectedRoute>} />
|
||||
<Route path="/seo" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
|
||||
<Route path="/seo-dashboard" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
|
||||
<Route path="/content-planning" element={<ProtectedRoute><ContentPlanningDashboard /></ProtectedRoute>} />
|
||||
<Route path="/facebook-writer" element={<ProtectedRoute><FacebookWriter /></ProtectedRoute>} />
|
||||
<Route path="/linkedin-writer" element={<ProtectedRoute><LinkedInWriter /></ProtectedRoute>} />
|
||||
<Route path="/blog-writer" element={<ProtectedRoute><BlogWriter /></ProtectedRoute>} />
|
||||
<Route path="/story-writer" element={<ProtectedRoute><StoryWriter /></ProtectedRoute>} />
|
||||
<Route path="/story-projects" element={<ProtectedRoute><StoryProjectList /></ProtectedRoute>} />
|
||||
<Route path="/youtube-creator" element={<ProtectedRoute><YouTubeCreator /></ProtectedRoute>} />
|
||||
<Route path="/podcast-maker" element={<ProtectedRoute><PodcastDashboard /></ProtectedRoute>} />
|
||||
<Route path="/image-studio" element={<ProtectedRoute><ImageStudioDashboard /></ProtectedRoute>} />
|
||||
<Route path="/video-studio" element={<ProtectedRoute><VideoStudioDashboard /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/create" element={<ProtectedRoute><CreateVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/avatar" element={<ProtectedRoute><AvatarVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/enhance" element={<ProtectedRoute><EnhanceVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/extend" element={<ProtectedRoute><ExtendVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/edit" element={<ProtectedRoute><EditVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/transform" element={<ProtectedRoute><TransformVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/social" element={<ProtectedRoute><SocialVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/face-swap" element={<ProtectedRoute><FaceSwap /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-translate" element={<ProtectedRoute><VideoTranslate /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-background-remover" element={<ProtectedRoute><VideoBackgroundRemover /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/add-audio-to-video" element={<ProtectedRoute><AddAudioToVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/library" element={<ProtectedRoute><LibraryVideo /></ProtectedRoute>} />
|
||||
<Route path="/image-generator" element={<ProtectedRoute><CreateStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-editor" element={<ProtectedRoute><EditStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-upscale" element={<ProtectedRoute><UpscaleStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-control" element={<ProtectedRoute><ControlStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/face-swap" element={<ProtectedRoute><FaceSwapStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/compress" element={<ProtectedRoute><CompressionStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/processing" element={<ProtectedRoute><ImageProcessingStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/social-optimizer" element={<ProtectedRoute><SocialOptimizer /></ProtectedRoute>} />
|
||||
<Route path="/asset-library" element={<ProtectedRoute><AssetLibrary /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator" element={<ProtectedRoute><ProductMarketingDashboard /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/photoshoot" element={<ProtectedRoute><ProductPhotoshootStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/animation" element={<ProtectedRoute><ProductAnimationStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/video" element={<ProtectedRoute><ProductVideoStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/avatar" element={<ProtectedRoute><ProductAvatarStudio /></ProtectedRoute>} />
|
||||
<Route path="/product-marketing" element={<Navigate to="/campaign-creator" replace />} />
|
||||
<Route path="/scheduler-dashboard" element={<ProtectedRoute><SchedulerDashboard /></ProtectedRoute>} />
|
||||
<Route path="/billing" element={<ProtectedRoute><BillingPage /></ProtectedRoute>} />
|
||||
<Route path="/approvals" element={<ProtectedRoute><ApprovalsPage /></ProtectedRoute>} />
|
||||
<Route path="/team-activity" element={<ProtectedRoute><TeamActivityPage /></ProtectedRoute>} />
|
||||
<Route path="/stripe-disputes" element={<ProtectedRoute><StripeDisputesDashboard /></ProtectedRoute>} />
|
||||
<Route path="/pricing" element={<PricingPage />} />
|
||||
<Route path="/research-test" element={<ResearchDashboard />} />
|
||||
<Route path="/research-dashboard" element={<ResearchDashboard />} />
|
||||
<Route path="/alwrity-researcher" element={<ResearchDashboard />} />
|
||||
<Route path="/intent-research" element={<IntentResearchTest />} />
|
||||
<Route path="/wix-test" element={<WixTestPage />} />
|
||||
<Route path="/wix-test-direct" element={<WixTestPage />} />
|
||||
<Route path="/wix/callback" element={<WixCallbackPage />} />
|
||||
<Route path="/wp/callback" element={<WordPressCallbackPage />} />
|
||||
<Route path="/gsc/callback" element={<GSCAuthCallback />} />
|
||||
<Route path="/bing/callback" element={<BingCallbackPage />} />
|
||||
<Route path="/bing-analytics-storage" element={<ProtectedRoute><BingAnalyticsStorage /></ProtectedRoute>} />
|
||||
</Routes>
|
||||
</ConditionalCopilotKit>
|
||||
</AuthenticatedCopilotWrapper>
|
||||
</Router>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<ErrorBoundary
|
||||
context="Application Root"
|
||||
showDetails={process.env.NODE_ENV === 'development'}
|
||||
onError={(error, errorInfo) => {
|
||||
// Custom error handler - send to analytics/monitoring
|
||||
console.error('Global error caught:', { error, errorInfo });
|
||||
// TODO: Send to error tracking service (Sentry, LogRocket, etc.)
|
||||
}}
|
||||
>
|
||||
<ClerkProvider publishableKey={clerkPublishableKey} clerkJSUrl={clerkJSUrl}>
|
||||
<SubscriptionProvider>
|
||||
<OnboardingProvider>
|
||||
{renderApp()}
|
||||
</OnboardingProvider>
|
||||
</SubscriptionProvider>
|
||||
</ClerkProvider>
|
||||
</ErrorBoundary>
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
||||
537
_session_backup/ResearchSummary.tsx
Normal file
537
_session_backup/ResearchSummary.tsx
Normal file
@@ -0,0 +1,537 @@
|
||||
import React, { useMemo, useCallback } from "react";
|
||||
import { Stack, Typography, Chip, Divider, Box, alpha, Paper, Tooltip } from "@mui/material";
|
||||
import {
|
||||
Insights as InsightsIcon,
|
||||
Search as SearchIcon,
|
||||
AttachMoney as AttachMoneyIcon,
|
||||
EditNote as EditNoteIcon,
|
||||
Article as ArticleIcon,
|
||||
AutoAwesome as AutoAwesomeIcon,
|
||||
FormatQuote as FormatQuoteIcon,
|
||||
Campaign as CampaignIcon,
|
||||
Explore as ExploreIcon,
|
||||
} from "@mui/icons-material";
|
||||
import { Research, ResearchInsight } from "../types";
|
||||
import { GlassyCard, glassyCardSx, PrimaryButton } from "../ui";
|
||||
import { FactCard } from "../FactCard";
|
||||
|
||||
interface ResearchSummaryProps {
|
||||
research: Research;
|
||||
canGenerateScript: boolean;
|
||||
onGenerateScript: () => void;
|
||||
}
|
||||
|
||||
export const ResearchSummary: React.FC<ResearchSummaryProps> = ({
|
||||
research,
|
||||
canGenerateScript,
|
||||
onGenerateScript,
|
||||
}) => {
|
||||
// Simple markdown-to-HTML converter
|
||||
const renderMarkdown = useCallback((text: string) => {
|
||||
if (!text) return null;
|
||||
return text
|
||||
.split('\n')
|
||||
.filter(line => line.trim() !== '') // Remove empty lines
|
||||
.map((line, i) => {
|
||||
// Handle bold
|
||||
let processedLine = line.replace(/\*\*(.*?)\*\*/g, '<strong>$1</strong>');
|
||||
// Handle lists
|
||||
if (processedLine.trim().startsWith('- ') || processedLine.trim().startsWith('* ')) {
|
||||
return <li key={i} dangerouslySetInnerHTML={{ __html: processedLine.trim().substring(2) }} style={{ marginBottom: '4px', fontSize: '0.9rem' }} />;
|
||||
}
|
||||
// Handle headers - make them smaller
|
||||
if (processedLine.startsWith('### ')) {
|
||||
return <Typography key={i} variant="subtitle2" fontWeight={700} sx={{ mt: 1.5, mb: 0.5, color: '#1e293b' }}>{processedLine.substring(4)}</Typography>;
|
||||
}
|
||||
if (processedLine.startsWith('## ')) {
|
||||
return <Typography key={i} variant="subtitle1" fontWeight={700} sx={{ mt: 1.5, mb: 0.5, color: '#0f172a' }}>{processedLine.substring(3)}</Typography>;
|
||||
}
|
||||
// Paragraphs - compact spacing
|
||||
return processedLine.trim() ? <p key={i} dangerouslySetInnerHTML={{ __html: processedLine }} style={{ margin: '4px 0', fontSize: '0.9rem' }} /> : null;
|
||||
});
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<GlassyCard sx={glassyCardSx}>
|
||||
<Stack spacing={3}>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center" flexWrap="wrap" gap={2}>
|
||||
<Stack direction="row" alignItems="center" spacing={2} sx={{ flex: 1 }}>
|
||||
<Typography variant="h6" sx={{ display: "flex", alignItems: "center", gap: 1, color: "#0f172a", fontWeight: 700 }}>
|
||||
<InsightsIcon />
|
||||
Research Summary
|
||||
</Typography>
|
||||
|
||||
{/* Research Metadata - Moved alongside title */}
|
||||
<Stack direction="row" spacing={1.5} flexWrap="wrap">
|
||||
{research.searchQueries && research.searchQueries.length > 0 && (
|
||||
<Chip
|
||||
icon={<SearchIcon sx={{ fontSize: "1rem !important" }} />}
|
||||
label={`${research.searchQueries.length} search${research.searchQueries.length > 1 ? "es" : ""}`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#667eea", 0.1),
|
||||
color: "#667eea",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(102, 126, 234, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{research.searchType && (
|
||||
<Chip
|
||||
label={`${research.searchType.charAt(0).toUpperCase() + research.searchType.slice(1)} search`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#10b981", 0.1),
|
||||
color: "#059669",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(16, 185, 129, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{research.sourceCount !== undefined && (
|
||||
<Chip
|
||||
label={`${research.sourceCount} source${research.sourceCount !== 1 ? "s" : ""}`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#6366f1", 0.1),
|
||||
color: "#4f46e5",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(99, 102, 241, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{research.cost !== undefined && (
|
||||
<Chip
|
||||
icon={<AttachMoneyIcon sx={{ fontSize: "0.875rem !important" }} />}
|
||||
label={`$${research.cost.toFixed(3)}`}
|
||||
size="small"
|
||||
sx={{
|
||||
background: alpha("#f59e0b", 0.1),
|
||||
color: "#d97706",
|
||||
fontWeight: 600,
|
||||
border: "1px solid rgba(245, 158, 11, 0.2)",
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</Stack>
|
||||
</Stack>
|
||||
|
||||
<PrimaryButton
|
||||
onClick={onGenerateScript}
|
||||
disabled={!canGenerateScript}
|
||||
startIcon={<EditNoteIcon />}
|
||||
tooltip={!canGenerateScript ? "Complete research to generate script" : "Generate AI-powered script from research"}
|
||||
>
|
||||
Generate Script
|
||||
</PrimaryButton>
|
||||
</Stack>
|
||||
|
||||
<Box sx={{ width: "100%" }}>
|
||||
{/* Main Summary */}
|
||||
{research.summary && (
|
||||
<Paper
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
mb: 3,
|
||||
background: "#f8fafc",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Typography variant="subtitle2" sx={{ mb: 1.5, color: "#64748b", fontWeight: 700, fontSize: "0.75rem", textTransform: "uppercase", letterSpacing: "0.05em", display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<AutoAwesomeIcon fontSize="small" sx={{ color: "#667eea", fontSize: "1rem" }} />
|
||||
Executive Summary
|
||||
</Typography>
|
||||
<Box sx={{
|
||||
lineHeight: 1.6,
|
||||
fontSize: "0.9rem",
|
||||
color: "#334155",
|
||||
"& p": { m: 0, mb: 1 },
|
||||
"& ul": { m: 0, mb: 1, pl: 2.5 },
|
||||
"& li": { mb: 0.5 },
|
||||
"& strong": { color: "#0f172a", fontWeight: 600 }
|
||||
}}>
|
||||
{renderMarkdown(research.summary)}
|
||||
</Box>
|
||||
</Paper>
|
||||
)}
|
||||
|
||||
{/* Deep Insights */}
|
||||
{(research.keyInsights && research.keyInsights.length > 0) ? (
|
||||
<Box sx={{ mb: 4 }}>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<ArticleIcon sx={{ color: "#667eea" }} />
|
||||
Deep Insights
|
||||
</Typography>
|
||||
<Stack spacing={2.5}>
|
||||
{research.keyInsights.map((insight: ResearchInsight, idx: number) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "#ffffff",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
boxShadow: "0 2px 12px rgba(0,0,0,0.03)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="flex-start" sx={{ mb: 1.5 }}>
|
||||
<Typography variant="subtitle1" sx={{ color: "#0f172a", fontWeight: 700 }}>
|
||||
{insight.title}
|
||||
</Typography>
|
||||
{insight.source_indices && insight.source_indices.length > 0 && (
|
||||
<Stack direction="row" spacing={0.5}>
|
||||
{insight.source_indices.map(sIdx => {
|
||||
const sourceIdx = sIdx - 1;
|
||||
const fact = research.factCards[sourceIdx];
|
||||
const sourceUrl = fact?.url;
|
||||
const hasUrl = !!sourceUrl;
|
||||
const hue = (sIdx * 47 + 220) % 360;
|
||||
const gradientFrom = `hsl(${hue}, 70%, 55%)`;
|
||||
const gradientTo = `hsl(${(hue + 30) % 360}, 80%, 65%)`;
|
||||
return (
|
||||
<Tooltip
|
||||
key={sIdx}
|
||||
title={hasUrl ? (
|
||||
<Box sx={{ maxWidth: 300, wordBreak: "break-all" }}>
|
||||
<Typography variant="caption" sx={{ color: "#fff", fontWeight: 600 }}>Source {sIdx}</Typography>
|
||||
<br />
|
||||
<Typography variant="caption" sx={{ color: "rgba(255,255,255,0.8)", fontSize: "0.65rem" }}>{sourceUrl}</Typography>
|
||||
</Box>
|
||||
) : `Source ${sIdx}`}
|
||||
arrow
|
||||
placement="top"
|
||||
>
|
||||
<Chip
|
||||
label={hasUrl ? `S${sIdx} ↗` : `S${sIdx}`}
|
||||
size="small"
|
||||
onClick={hasUrl ? () => window.open(sourceUrl, "_blank", "noopener,noreferrer") : undefined}
|
||||
sx={{
|
||||
height: 24,
|
||||
minWidth: 36,
|
||||
fontSize: '0.7rem',
|
||||
fontWeight: 800,
|
||||
fontFamily: "'Inter', 'Roboto', monospace",
|
||||
letterSpacing: "0.02em",
|
||||
border: "none",
|
||||
background: hasUrl
|
||||
? `linear-gradient(135deg, ${gradientFrom}, ${gradientTo})`
|
||||
: `linear-gradient(135deg, ${alpha(gradientFrom, 0.3)}, ${alpha(gradientTo, 0.3)})`,
|
||||
color: hasUrl ? "#fff" : alpha("#fff", 0.7),
|
||||
cursor: hasUrl ? "pointer" : "default",
|
||||
borderRadius: "8px",
|
||||
px: 0.5,
|
||||
boxShadow: hasUrl
|
||||
? `0 2px 8px ${alpha(gradientFrom, 0.35)}, inset 0 1px 0 ${alpha("#fff", 0.2)}`
|
||||
: "none",
|
||||
transition: "all 0.2s ease",
|
||||
"&:hover": hasUrl ? {
|
||||
background: `linear-gradient(135deg, ${gradientTo}, ${gradientFrom})`,
|
||||
boxShadow: `0 4px 14px ${alpha(gradientFrom, 0.5)}, inset 0 1px 0 ${alpha("#fff", 0.3)}`,
|
||||
transform: "translateY(-1px)",
|
||||
} : {},
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
);
|
||||
})}
|
||||
</Stack>
|
||||
)}
|
||||
</Stack>
|
||||
<Box sx={{
|
||||
color: "#475569",
|
||||
lineHeight: 1.7,
|
||||
fontSize: "0.9rem",
|
||||
"& p": { m: 0, mb: 1.5 },
|
||||
"& ul": { m: 0, mb: 1.5, pl: 2 }
|
||||
}}>
|
||||
{renderMarkdown(insight.content)}
|
||||
</Box>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
) : (
|
||||
/* Fallback if keyInsights is missing but we have summary paragraphs */
|
||||
research.summary && research.summary.length > 500 && !research.keyInsights && (
|
||||
<Box sx={{ mb: 4 }}>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<ArticleIcon sx={{ color: "#667eea" }} />
|
||||
Additional Insights
|
||||
</Typography>
|
||||
<Paper
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "#ffffff",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
boxShadow: "0 2px 12px rgba(0,0,0,0.03)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Box sx={{
|
||||
color: "#475569",
|
||||
lineHeight: 1.7,
|
||||
fontSize: "0.9rem",
|
||||
}}>
|
||||
{/* Render parts of summary that might contain insights if structured data is missing */}
|
||||
{renderMarkdown(research.summary.split('\n\n').slice(1).join('\n\n'))}
|
||||
</Box>
|
||||
</Paper>
|
||||
</Box>
|
||||
)
|
||||
)}
|
||||
|
||||
{/* Expert Quotes Section */}
|
||||
{research.expertQuotes && research.expertQuotes.length > 0 && (
|
||||
<Box sx={{ mt: 4, pt: 3, borderTop: "1px solid rgba(0,0,0,0.04)" }}>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<FormatQuoteIcon sx={{ color: "#8b5cf6" }} />
|
||||
Expert Quotes ({research.expertQuotes.length})
|
||||
</Typography>
|
||||
<Stack spacing={2}>
|
||||
{research.expertQuotes.map((eq, idx) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "linear-gradient(135deg, rgba(139, 92, 246, 0.04) 0%, rgba(99, 102, 241, 0.04) 100%)",
|
||||
border: "1px solid rgba(139, 92, 246, 0.15)",
|
||||
borderLeft: "4px solid #8b5cf6",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" spacing={1.5} alignItems="flex-start">
|
||||
<FormatQuoteIcon sx={{ color: "#8b5cf6", fontSize: "1.5rem", mt: -0.5, opacity: 0.7 }} />
|
||||
<Box sx={{ flex: 1 }}>
|
||||
<Typography variant="body2" sx={{ color: "#1e293b", fontStyle: "italic", lineHeight: 1.7, fontSize: "0.95rem" }}>
|
||||
“{eq.quote}”
|
||||
</Typography>
|
||||
{eq.source_index !== undefined && (() => {
|
||||
const fact = research.factCards[eq.source_index - 1];
|
||||
const sourceUrl = fact?.url;
|
||||
const hasUrl = !!sourceUrl;
|
||||
const hue = (eq.source_index * 47 + 270) % 360;
|
||||
const gradientFrom = `hsl(${hue}, 70%, 55%)`;
|
||||
const gradientTo = `hsl(${(hue + 30) % 360}, 80%, 65%)`;
|
||||
return (
|
||||
<Box sx={{ mt: 1 }}>
|
||||
<Tooltip title={hasUrl ? (
|
||||
<Box sx={{ maxWidth: 300, wordBreak: "break-all" }}>
|
||||
<Typography variant="caption" sx={{ color: "#fff", fontWeight: 600 }}>Source {eq.source_index}</Typography>
|
||||
<br />
|
||||
<Typography variant="caption" sx={{ color: "rgba(255,255,255,0.8)", fontSize: "0.65rem" }}>{sourceUrl}</Typography>
|
||||
</Box>
|
||||
) : `Source ${eq.source_index}`} arrow placement="top">
|
||||
<Chip
|
||||
label={hasUrl ? `Source ${eq.source_index} ↗` : `Source ${eq.source_index}`}
|
||||
size="small"
|
||||
onClick={hasUrl ? () => window.open(sourceUrl, "_blank", "noopener,noreferrer") : undefined}
|
||||
sx={{
|
||||
height: 24,
|
||||
fontSize: "0.7rem",
|
||||
fontWeight: 800,
|
||||
fontFamily: "'Inter', 'Roboto', monospace",
|
||||
border: "none",
|
||||
background: hasUrl
|
||||
? `linear-gradient(135deg, ${gradientFrom}, ${gradientTo})`
|
||||
: `linear-gradient(135deg, ${alpha(gradientFrom, 0.3)}, ${alpha(gradientTo, 0.3)})`,
|
||||
color: hasUrl ? "#fff" : alpha("#fff", 0.7),
|
||||
cursor: hasUrl ? "pointer" : "default",
|
||||
borderRadius: "8px",
|
||||
px: 1,
|
||||
boxShadow: hasUrl
|
||||
? `0 2px 8px ${alpha(gradientFrom, 0.35)}, inset 0 1px 0 ${alpha("#fff", 0.2)}`
|
||||
: "none",
|
||||
transition: "all 0.2s ease",
|
||||
"&:hover": hasUrl ? {
|
||||
background: `linear-gradient(135deg, ${gradientTo}, ${gradientFrom})`,
|
||||
boxShadow: `0 4px 14px ${alpha(gradientFrom, 0.5)}, inset 0 1px 0 ${alpha("#fff", 0.3)}`,
|
||||
transform: "translateY(-1px)",
|
||||
} : {},
|
||||
}}
|
||||
/>
|
||||
</Tooltip>
|
||||
</Box>
|
||||
);
|
||||
})()}
|
||||
</Box>
|
||||
</Stack>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
)}
|
||||
|
||||
{/* Search Queries Used */}
|
||||
{research.searchQueries && research.searchQueries.length > 0 && (
|
||||
<Box sx={{ mt: 4, pt: 3, borderTop: "1px solid rgba(0,0,0,0.04)" }}>
|
||||
<Typography variant="subtitle2" sx={{ mb: 1.5, color: "#64748b", fontWeight: 700, fontSize: "0.7rem", textTransform: "uppercase", letterSpacing: "0.05em" }}>
|
||||
Search Queries Used
|
||||
</Typography>
|
||||
<Stack direction="row" spacing={1} flexWrap="wrap" useFlexGap>
|
||||
{research.searchQueries.map((query, idx) => (
|
||||
<Chip
|
||||
key={idx}
|
||||
label={query}
|
||||
size="small"
|
||||
variant="outlined"
|
||||
sx={{
|
||||
borderColor: "rgba(102, 126, 234, 0.15)",
|
||||
color: "#94a3b8",
|
||||
background: alpha("#f8fafc", 0.3),
|
||||
fontSize: "0.7rem",
|
||||
borderRadius: 1,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
|
||||
{research.factCards.length > 0 && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(0,0,0,0.08)" }} />
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center" sx={{ mb: 1.5, flexWrap: "wrap", gap: 1 }}>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600 }}>
|
||||
Research Sources & Facts ({research.factCards.length})
|
||||
</Typography>
|
||||
<Typography variant="caption" sx={{ color: "#64748b", fontSize: "0.75rem" }}>
|
||||
Click to expand • Hover to see source
|
||||
</Typography>
|
||||
</Stack>
|
||||
<Box
|
||||
sx={{
|
||||
display: "grid",
|
||||
gridTemplateColumns: { xs: "1fr", sm: "repeat(2, 1fr)", md: "repeat(3, 1fr)", lg: "repeat(4, 1fr)" },
|
||||
gap: 1.5,
|
||||
width: "100%",
|
||||
overflow: "hidden",
|
||||
}}
|
||||
>
|
||||
{research.factCards.map((fact) => (
|
||||
<FactCard key={fact.id} fact={fact} />
|
||||
))}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Listener CTA Section */}
|
||||
{research.listenerCta && research.listenerCta.length > 0 && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(0,0,0,0.08)" }} />
|
||||
<Box>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<CampaignIcon sx={{ color: "#f59e0b" }} />
|
||||
Listener Call-to-Action Ideas ({research.listenerCta.length})
|
||||
</Typography>
|
||||
<Stack spacing={1.5}>
|
||||
{research.listenerCta.map((cta, idx) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2,
|
||||
background: "linear-gradient(135deg, rgba(245, 158, 11, 0.05) 0%, rgba(251, 191, 36, 0.05) 100%)",
|
||||
border: "1px solid rgba(245, 158, 11, 0.15)",
|
||||
borderRadius: 2,
|
||||
display: "flex",
|
||||
alignItems: "flex-start",
|
||||
gap: 1.5,
|
||||
}}
|
||||
>
|
||||
<Chip
|
||||
label={`#${idx + 1}`}
|
||||
size="small"
|
||||
sx={{
|
||||
bgcolor: alpha("#f59e0b", 0.15),
|
||||
color: "#b45309",
|
||||
fontWeight: 700,
|
||||
fontSize: "0.7rem",
|
||||
height: 24,
|
||||
minWidth: 32,
|
||||
}}
|
||||
/>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.6, flex: 1, pt: 0.2 }}>
|
||||
{cta}
|
||||
</Typography>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Mapped Angles Section */}
|
||||
{research.mappedAngles && research.mappedAngles.length > 0 && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(0,0,0,0.08)" }} />
|
||||
<Box>
|
||||
<Typography variant="h6" sx={{ mb: 2, color: "#0f172a", fontWeight: 700, display: "flex", alignItems: "center", gap: 1 }}>
|
||||
<ExploreIcon sx={{ color: "#06b6d4" }} />
|
||||
Content Angles ({research.mappedAngles.length})
|
||||
</Typography>
|
||||
<Stack spacing={2}>
|
||||
{research.mappedAngles.map((angle, idx) => (
|
||||
<Paper
|
||||
key={idx}
|
||||
elevation={0}
|
||||
sx={{
|
||||
p: 2.5,
|
||||
background: "#ffffff",
|
||||
border: "1px solid rgba(0,0,0,0.06)",
|
||||
borderLeft: "4px solid #06b6d4",
|
||||
boxShadow: "0 2px 12px rgba(0,0,0,0.03)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="flex-start" sx={{ mb: 1 }}>
|
||||
<Typography variant="subtitle1" sx={{ color: "#0f172a", fontWeight: 700 }}>
|
||||
{angle.title}
|
||||
</Typography>
|
||||
{angle.mappedFactIds && angle.mappedFactIds.length > 0 && (
|
||||
<Stack direction="row" spacing={0.5}>
|
||||
{angle.mappedFactIds.slice(0, 4).map((fid: string) => (
|
||||
<Chip
|
||||
key={fid}
|
||||
label={fid.replace("fact_", "F")}
|
||||
size="small"
|
||||
variant="outlined"
|
||||
sx={{
|
||||
height: 18,
|
||||
fontSize: "0.6rem",
|
||||
fontWeight: 700,
|
||||
borderColor: alpha("#06b6d4", 0.3),
|
||||
color: "#06b6d4",
|
||||
bgcolor: alpha("#06b6d4", 0.05),
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
{angle.mappedFactIds.length > 4 && (
|
||||
<Chip
|
||||
label={`+${angle.mappedFactIds.length - 4}`}
|
||||
size="small"
|
||||
sx={{ height: 18, fontSize: "0.6rem", color: "#64748b" }}
|
||||
/>
|
||||
)}
|
||||
</Stack>
|
||||
)}
|
||||
</Stack>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7, fontSize: "0.9rem" }}>
|
||||
{angle.why}
|
||||
</Typography>
|
||||
</Paper>
|
||||
))}
|
||||
</Stack>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
</Stack>
|
||||
</GlassyCard>
|
||||
);
|
||||
};
|
||||
|
||||
811
_session_backup/SceneEditor.tsx
Normal file
811
_session_backup/SceneEditor.tsx
Normal file
@@ -0,0 +1,811 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Stack, Box, Typography, Divider, Chip, alpha, CircularProgress, LinearProgress, IconButton, Tooltip } from "@mui/material";
|
||||
import {
|
||||
EditNote as EditNoteIcon,
|
||||
CheckCircle as CheckCircleIcon,
|
||||
RadioButtonUnchecked as RadioButtonUncheckedIcon,
|
||||
VolumeUp as VolumeUpIcon,
|
||||
PlayArrow as PlayArrowIcon,
|
||||
Image as ImageIcon,
|
||||
Delete as DeleteIcon,
|
||||
} from "@mui/icons-material";
|
||||
import { Scene, Line, Knobs } from "../types";
|
||||
import { GlassyCard, glassyCardSx, PrimaryButton } from "../ui";
|
||||
import { LineEditor } from "./LineEditor";
|
||||
import { ImageRegenerateModal, ImageGenerationSettings } from "./ImageRegenerateModal";
|
||||
import { AudioRegenerateModal, AudioGenerationSettings } from "./AudioRegenerateModal";
|
||||
import { podcastApi } from "../../../services/podcastApi";
|
||||
import { aiApiClient } from "../../../api/client";
|
||||
import { getCachedMedia, setCachedMedia } from "../../../utils/mediaCache";
|
||||
|
||||
interface SceneEditorProps {
|
||||
scene: Scene;
|
||||
onUpdateScene: (s: Scene) => void;
|
||||
onApprove: (id: string) => Promise<void>;
|
||||
onDelete: (sceneId: string) => void;
|
||||
knobs: Knobs;
|
||||
approvingSceneId?: string | null;
|
||||
generatingAudioId?: string | null;
|
||||
onAudioGenerationStart?: (sceneId: string) => void;
|
||||
onAudioGenerated?: (sceneId: string, audioUrl: string) => void;
|
||||
idea?: string; // Podcast idea for image generation context
|
||||
avatarUrl?: string | null; // Base avatar URL for consistent scene image generation
|
||||
totalScenes?: number; // Total number of scenes in the script
|
||||
}
|
||||
|
||||
export const SceneEditor: React.FC<SceneEditorProps> = ({
|
||||
scene,
|
||||
onUpdateScene,
|
||||
onApprove,
|
||||
onDelete,
|
||||
knobs,
|
||||
approvingSceneId,
|
||||
generatingAudioId,
|
||||
onAudioGenerationStart,
|
||||
onAudioGenerated,
|
||||
idea,
|
||||
avatarUrl,
|
||||
totalScenes,
|
||||
}) => {
|
||||
const [localGenerating, setLocalGenerating] = useState(false);
|
||||
const [generatingImage, setGeneratingImage] = useState(false);
|
||||
const [imageGenerationStatus, setImageGenerationStatus] = useState<string>("");
|
||||
const [imageGenerationProgress, setImageGenerationProgress] = useState<number>(0);
|
||||
const [audioBlobUrl, setAudioBlobUrl] = useState<string | null>(null);
|
||||
const [imageBlobUrl, setImageBlobUrl] = useState<string | null>(null);
|
||||
const [imageLoading, setImageLoading] = useState(false);
|
||||
const [showRegenerateModal, setShowRegenerateModal] = useState(false);
|
||||
const [showAudioModal, setShowAudioModal] = useState(false);
|
||||
const [audioSettings, setAudioSettings] = useState<AudioGenerationSettings>({
|
||||
voiceId: "Wise_Woman",
|
||||
speed: 1.0,
|
||||
volume: 1.0,
|
||||
pitch: 0.0,
|
||||
emotion: scene.emotion || "neutral",
|
||||
englishNormalization: true,
|
||||
sampleRate: 24000,
|
||||
bitrate: 64000,
|
||||
channel: "1",
|
||||
format: "mp3",
|
||||
languageBoost: "auto",
|
||||
});
|
||||
|
||||
// Load audio as blob when audioUrl is available
|
||||
useEffect(() => {
|
||||
if (!scene.audioUrl) {
|
||||
// Clean up blob URL if audioUrl is removed
|
||||
setAudioBlobUrl((currentBlobUrl) => {
|
||||
if (currentBlobUrl) {
|
||||
URL.revokeObjectURL(currentBlobUrl);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
let isMounted = true;
|
||||
const currentAudioUrl = scene.audioUrl; // Capture current value
|
||||
|
||||
const loadAudioBlob = async () => {
|
||||
try {
|
||||
// Normalize path
|
||||
let audioPath = currentAudioUrl.startsWith('/') ? currentAudioUrl : `/${currentAudioUrl}`;
|
||||
|
||||
// Convert /api/story/audio/ to /api/podcast/audio/ if needed
|
||||
if (audioPath.includes('/api/story/audio/')) {
|
||||
const filename = audioPath.split('/api/story/audio/').pop() || '';
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Ensure it's a podcast audio endpoint
|
||||
if (!audioPath.includes('/api/podcast/audio/')) {
|
||||
const filename = audioPath.split('/').pop() || currentAudioUrl;
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
audioPath = audioPath.split('?')[0];
|
||||
|
||||
const response = await aiApiClient.get(audioPath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
if (!isMounted) {
|
||||
// Component unmounted or audioUrl changed, don't set blob URL
|
||||
return;
|
||||
}
|
||||
|
||||
// Double-check that audioUrl hasn't changed
|
||||
if (scene.audioUrl !== currentAudioUrl) {
|
||||
return;
|
||||
}
|
||||
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
setAudioBlobUrl((prevBlobUrl) => {
|
||||
// Clean up previous blob URL if exists
|
||||
if (prevBlobUrl && prevBlobUrl !== blobUrl) {
|
||||
URL.revokeObjectURL(prevBlobUrl);
|
||||
}
|
||||
return blobUrl;
|
||||
});
|
||||
} catch (error) {
|
||||
console.error(`Failed to load audio blob for scene ${scene.id}:`, error);
|
||||
// Don't set blob URL on error - will show error state
|
||||
}
|
||||
};
|
||||
|
||||
loadAudioBlob();
|
||||
|
||||
// Cleanup: only mark as unmounted, don't revoke blob URL here
|
||||
// The blob URL will be cleaned up when audioUrl changes (new effect) or component unmounts
|
||||
return () => {
|
||||
isMounted = false;
|
||||
};
|
||||
}, [scene.audioUrl, scene.id]);
|
||||
|
||||
// Load image as blob when imageUrl is available
|
||||
useEffect(() => {
|
||||
if (!scene.imageUrl) {
|
||||
// Clean up blob URL if imageUrl is removed
|
||||
setImageBlobUrl((currentBlobUrl) => {
|
||||
if (currentBlobUrl && currentBlobUrl.startsWith('blob:')) {
|
||||
URL.revokeObjectURL(currentBlobUrl);
|
||||
}
|
||||
return null;
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Check cache first with scene context
|
||||
const cachedUrl = getCachedMedia(scene.imageUrl, scene.id);
|
||||
if (cachedUrl) {
|
||||
console.log('[SceneEditor] Using cached image:', scene.imageUrl, `(scene: ${scene.id})`);
|
||||
setImageBlobUrl(cachedUrl);
|
||||
setImageLoading(false);
|
||||
return;
|
||||
}
|
||||
|
||||
let isMounted = true;
|
||||
const currentImageUrl = scene.imageUrl; // Capture current value
|
||||
|
||||
const loadImageBlob = async () => {
|
||||
try {
|
||||
setImageLoading(true);
|
||||
|
||||
// Check cache again in case it was loaded while we were waiting
|
||||
const cachedUrl = getCachedMedia(currentImageUrl, scene.id);
|
||||
if (cachedUrl) {
|
||||
if (isMounted) {
|
||||
setImageBlobUrl(cachedUrl);
|
||||
setImageLoading(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
console.log('[SceneEditor] Loading image blob for:', currentImageUrl);
|
||||
|
||||
// Normalize path
|
||||
let imagePath = currentImageUrl.startsWith('/') ? currentImageUrl : `/${currentImageUrl}`;
|
||||
|
||||
// Convert /api/story/images/ to /api/podcast/images/ if needed
|
||||
if (imagePath.includes('/api/story/images/')) {
|
||||
const filename = imagePath.split('/api/story/images/').pop() || '';
|
||||
imagePath = `/api/podcast/images/${filename}`;
|
||||
}
|
||||
|
||||
// Ensure it's a podcast image endpoint
|
||||
if (!imagePath.includes('/api/podcast/images/')) {
|
||||
const filename = imagePath.split('/').pop() || currentImageUrl;
|
||||
imagePath = `/api/podcast/images/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
imagePath = imagePath.split('?')[0];
|
||||
|
||||
const response = await aiApiClient.get(imagePath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
if (!isMounted) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Double-check that imageUrl hasn't changed
|
||||
if (scene.imageUrl !== currentImageUrl) {
|
||||
return;
|
||||
}
|
||||
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
// Cache the blob URL with scene context
|
||||
setCachedMedia(currentImageUrl, blobUrl, 'image', blob.size, scene.id);
|
||||
|
||||
setImageBlobUrl((prevBlobUrl) => {
|
||||
// Clean up previous blob URL if exists
|
||||
if (prevBlobUrl && prevBlobUrl !== blobUrl && prevBlobUrl.startsWith('blob:')) {
|
||||
URL.revokeObjectURL(prevBlobUrl);
|
||||
}
|
||||
return blobUrl;
|
||||
});
|
||||
console.log('[SceneEditor] Image blob loaded and cached successfully:', currentImageUrl);
|
||||
} catch (error) {
|
||||
console.error('[SceneEditor] Failed to load image blob:', error);
|
||||
if (isMounted) {
|
||||
// Try adding query token as fallback
|
||||
try {
|
||||
const token = localStorage.getItem('clerk_dashboard_token') || '';
|
||||
if (token) {
|
||||
const urlWithToken = `${currentImageUrl}?token=${encodeURIComponent(token)}`;
|
||||
setImageBlobUrl(urlWithToken);
|
||||
setCachedMedia(currentImageUrl, urlWithToken, 'image', undefined, scene.id);
|
||||
}
|
||||
} catch (fallbackError) {
|
||||
console.error('[SceneEditor] Fallback image loading failed:', fallbackError);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
if (isMounted) {
|
||||
setImageLoading(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
loadImageBlob();
|
||||
|
||||
return () => {
|
||||
isMounted = false;
|
||||
// Don't cleanup blob URL here - let the cache handle it
|
||||
};
|
||||
}, [scene.imageUrl]);
|
||||
|
||||
const updateLine = (updatedLine: Line) => {
|
||||
const updated = { ...scene, lines: scene.lines.map((l) => (l.id === updatedLine.id ? updatedLine : l)) };
|
||||
onUpdateScene(updated);
|
||||
};
|
||||
|
||||
const approving = approvingSceneId === scene.id;
|
||||
const generating = generatingAudioId === scene.id || localGenerating;
|
||||
const hasAudio = Boolean(scene.audioUrl && audioBlobUrl);
|
||||
const hasImage = Boolean(scene.imageUrl);
|
||||
|
||||
const handleApproveAndGenerate = async (settings?: AudioGenerationSettings) => {
|
||||
const wasAlreadyApproved = scene.approved;
|
||||
const sceneId = scene.id;
|
||||
|
||||
try {
|
||||
// Set generating state
|
||||
setLocalGenerating(true);
|
||||
if (onAudioGenerationStart) {
|
||||
onAudioGenerationStart(sceneId);
|
||||
}
|
||||
|
||||
// If scene is not approved yet, approve it first
|
||||
// This will update the parent script state
|
||||
if (!scene.approved) {
|
||||
await onApprove(sceneId);
|
||||
// The parent's approveScene already updated the script state
|
||||
// We need to wait for React to propagate the updated scene prop
|
||||
// For now, we'll update it locally too to ensure UI updates immediately
|
||||
onUpdateScene({ ...scene, approved: true });
|
||||
}
|
||||
|
||||
// Use the current scene (which should now be approved)
|
||||
// If scene prop hasn't updated yet, use the local update we just made
|
||||
const currentScene = { ...scene, approved: true };
|
||||
|
||||
// Generate audio
|
||||
const effectiveSettings = settings || audioSettings;
|
||||
const result = await podcastApi.renderSceneAudio({
|
||||
scene: currentScene,
|
||||
voiceId: effectiveSettings.voiceId || "Wise_Woman",
|
||||
emotion: effectiveSettings.emotion || scene.emotion || knobs.voice_emotion || "neutral",
|
||||
speed: effectiveSettings.speed ?? knobs.voice_speed ?? 1.0,
|
||||
volume: effectiveSettings.volume ?? 1.0,
|
||||
pitch: effectiveSettings.pitch ?? 0.0,
|
||||
englishNormalization: effectiveSettings.englishNormalization ?? true,
|
||||
sampleRate: effectiveSettings.sampleRate,
|
||||
bitrate: effectiveSettings.bitrate,
|
||||
channel: effectiveSettings.channel,
|
||||
format: effectiveSettings.format,
|
||||
languageBoost: effectiveSettings.languageBoost,
|
||||
});
|
||||
|
||||
// Update scene with audio URL and ensure approved state
|
||||
// This will sync with parent script state
|
||||
const updatedScene = { ...currentScene, audioUrl: result.audioUrl, approved: true };
|
||||
onUpdateScene(updatedScene);
|
||||
|
||||
if (onAudioGenerated) {
|
||||
onAudioGenerated(sceneId, result.audioUrl);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Failed to approve and generate audio:", error);
|
||||
// On error, revert approval only if we just approved it in this call
|
||||
if (!wasAlreadyApproved) {
|
||||
onUpdateScene({ ...scene, approved: false, audioUrl: undefined });
|
||||
}
|
||||
throw error;
|
||||
} finally {
|
||||
setLocalGenerating(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleGenerateImage = async (settings?: ImageGenerationSettings) => {
|
||||
const sceneId = scene.id;
|
||||
const startTime = Date.now();
|
||||
let progressInterval: NodeJS.Timeout | null = null;
|
||||
|
||||
try {
|
||||
setGeneratingImage(true);
|
||||
setShowRegenerateModal(false);
|
||||
setImageGenerationStatus("Submitting image generation request...");
|
||||
setImageGenerationProgress(10);
|
||||
|
||||
// Build scene content from lines for context
|
||||
const sceneContent = scene.lines.map((line) => line.text).join(" ");
|
||||
|
||||
// Log avatar URL for debugging
|
||||
console.log("[SceneEditor] Generating image with avatarUrl:", avatarUrl);
|
||||
console.log("[SceneEditor] Custom settings:", settings);
|
||||
|
||||
// Simulate progress updates during API call
|
||||
progressInterval = setInterval(() => {
|
||||
const elapsed = Date.now() - startTime;
|
||||
const seconds = Math.floor(elapsed / 1000);
|
||||
|
||||
// Update status based on elapsed time
|
||||
if (seconds < 5) {
|
||||
setImageGenerationStatus("Submitting request to AI service...");
|
||||
setImageGenerationProgress(15);
|
||||
} else if (seconds < 15) {
|
||||
setImageGenerationStatus("AI is generating your image...");
|
||||
setImageGenerationProgress(30);
|
||||
} else if (seconds < 30) {
|
||||
setImageGenerationStatus("Creating character-consistent scene image...");
|
||||
setImageGenerationProgress(50);
|
||||
} else if (seconds < 60) {
|
||||
setImageGenerationStatus("Rendering image details...");
|
||||
setImageGenerationProgress(70);
|
||||
} else {
|
||||
setImageGenerationStatus(`Processing... (${seconds}s elapsed)`);
|
||||
setImageGenerationProgress(Math.min(90, 50 + (seconds - 30) / 2));
|
||||
}
|
||||
}, 1000);
|
||||
|
||||
const result = await podcastApi.generateSceneImage({
|
||||
sceneId: scene.id,
|
||||
sceneTitle: scene.title,
|
||||
sceneContent: sceneContent,
|
||||
baseAvatarUrl: avatarUrl || undefined, // Pass base avatar URL for character consistency
|
||||
idea: idea,
|
||||
width: 1024,
|
||||
height: 1024,
|
||||
// Pass custom settings if provided
|
||||
customPrompt: settings?.prompt,
|
||||
style: settings?.style,
|
||||
renderingSpeed: settings?.renderingSpeed,
|
||||
aspectRatio: settings?.aspectRatio,
|
||||
});
|
||||
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
progressInterval = null;
|
||||
}
|
||||
|
||||
setImageGenerationStatus("Finalizing image...");
|
||||
setImageGenerationProgress(95);
|
||||
|
||||
// Update scene with image URL
|
||||
const updatedScene = { ...scene, imageUrl: result.image_url };
|
||||
onUpdateScene(updatedScene);
|
||||
|
||||
const elapsed = Math.floor((Date.now() - startTime) / 1000);
|
||||
setImageGenerationStatus(`Image generated successfully in ${elapsed}s`);
|
||||
setImageGenerationProgress(100);
|
||||
|
||||
// Clear status after a moment
|
||||
setTimeout(() => {
|
||||
setImageGenerationStatus("");
|
||||
setImageGenerationProgress(0);
|
||||
}, 2000);
|
||||
} catch (error: any) {
|
||||
// Clear interval on error
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
progressInterval = null;
|
||||
}
|
||||
|
||||
console.error("Failed to generate image:", error);
|
||||
// Extract error message from response if available
|
||||
const errorMessage = error?.response?.data?.detail?.message
|
||||
|| error?.response?.data?.detail?.error
|
||||
|| error?.response?.data?.detail
|
||||
|| error?.message
|
||||
|| "Failed to generate image. Please try again.";
|
||||
console.error("Error details:", {
|
||||
status: error?.response?.status,
|
||||
statusText: error?.response?.statusText,
|
||||
data: error?.response?.data,
|
||||
message: errorMessage,
|
||||
});
|
||||
|
||||
setImageGenerationStatus(`Error: ${errorMessage}`);
|
||||
setImageGenerationProgress(0);
|
||||
|
||||
// Show user-friendly error message
|
||||
alert(`Image generation failed: ${errorMessage}`);
|
||||
throw error;
|
||||
} finally {
|
||||
// Ensure interval is cleared
|
||||
if (progressInterval) {
|
||||
clearInterval(progressInterval);
|
||||
}
|
||||
setGeneratingImage(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleRegenerateClick = () => {
|
||||
setShowRegenerateModal(true);
|
||||
};
|
||||
|
||||
const handleAudioRegenerateClick = () => {
|
||||
if (hasAudio) {
|
||||
setShowAudioModal(true);
|
||||
} else {
|
||||
handleApproveAndGenerate(audioSettings);
|
||||
}
|
||||
};
|
||||
|
||||
const handleAudioRegenerate = (settings: AudioGenerationSettings) => {
|
||||
setAudioSettings(settings);
|
||||
setShowAudioModal(false);
|
||||
handleApproveAndGenerate(settings);
|
||||
};
|
||||
|
||||
return (
|
||||
<GlassyCard sx={glassyCardSx}>
|
||||
<Stack spacing={2.5}>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="flex-start">
|
||||
<Box sx={{ flex: 1 }}>
|
||||
<Typography
|
||||
variant="h6"
|
||||
sx={{
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: 1.5,
|
||||
mb: 1,
|
||||
color: "#0f172a",
|
||||
fontWeight: 600,
|
||||
fontSize: "1.25rem",
|
||||
letterSpacing: "-0.01em",
|
||||
}}
|
||||
>
|
||||
<EditNoteIcon fontSize="small" sx={{ color: "#667eea", fontSize: "1.5rem" }} />
|
||||
{scene.title}
|
||||
</Typography>
|
||||
<Stack direction="row" spacing={1.5} alignItems="center" flexWrap="wrap">
|
||||
<Chip
|
||||
icon={scene.approved ? <CheckCircleIcon /> : <RadioButtonUncheckedIcon />}
|
||||
label={scene.approved ? "Approved" : "Pending Approval"}
|
||||
size="small"
|
||||
color={scene.approved ? "success" : "warning"}
|
||||
sx={{
|
||||
background: scene.approved
|
||||
? "linear-gradient(135deg, rgba(16, 185, 129, 0.12) 0%, rgba(5, 150, 105, 0.12) 100%)"
|
||||
: "linear-gradient(135deg, rgba(245, 158, 11, 0.12) 0%, rgba(217, 119, 6, 0.12) 100%)",
|
||||
color: scene.approved ? "#059669" : "#d97706",
|
||||
border: scene.approved
|
||||
? "1px solid rgba(16, 185, 129, 0.25)"
|
||||
: "1px solid rgba(245, 158, 11, 0.25)",
|
||||
fontWeight: 600,
|
||||
fontSize: "0.75rem",
|
||||
height: 26,
|
||||
boxShadow: "0 1px 2px rgba(0, 0, 0, 0.05)",
|
||||
}}
|
||||
/>
|
||||
<Typography variant="caption" sx={{ color: "#64748b", fontWeight: 500, fontSize: "0.8125rem" }}>
|
||||
Duration: {scene.duration}s
|
||||
</Typography>
|
||||
</Stack>
|
||||
</Box>
|
||||
<Stack direction="row" spacing={1.5} flexWrap="wrap" useFlexGap>
|
||||
<PrimaryButton
|
||||
onClick={handleAudioRegenerateClick}
|
||||
disabled={approving || generating}
|
||||
loading={approving || generating}
|
||||
startIcon={
|
||||
hasAudio && !generating ? (
|
||||
<VolumeUpIcon />
|
||||
) : generating ? (
|
||||
<CircularProgress size={16} sx={{ color: "white" }} />
|
||||
) : (
|
||||
<PlayArrowIcon />
|
||||
)
|
||||
}
|
||||
tooltip={
|
||||
hasAudio && !generating
|
||||
? "Regenerate audio for this scene with custom settings"
|
||||
: generating
|
||||
? "Generating audio..."
|
||||
: scene.approved
|
||||
? "Generate audio for this scene"
|
||||
: "Approve scene and generate audio"
|
||||
}
|
||||
sx={{
|
||||
minWidth: 200,
|
||||
}}
|
||||
>
|
||||
{hasAudio && !generating
|
||||
? "Regenerate Audio"
|
||||
: generating
|
||||
? "Generating Audio..."
|
||||
: scene.approved
|
||||
? "Generate Audio"
|
||||
: "Approve & Generate Audio"}
|
||||
</PrimaryButton>
|
||||
<PrimaryButton
|
||||
onClick={hasImage ? handleRegenerateClick : () => handleGenerateImage()}
|
||||
disabled={generatingImage}
|
||||
loading={generatingImage}
|
||||
startIcon={
|
||||
hasImage && !generatingImage ? (
|
||||
<ImageIcon />
|
||||
) : generatingImage ? (
|
||||
<CircularProgress size={16} sx={{ color: "white" }} />
|
||||
) : (
|
||||
<ImageIcon />
|
||||
)
|
||||
}
|
||||
tooltip={
|
||||
hasImage
|
||||
? "Regenerate image for this scene"
|
||||
: generatingImage
|
||||
? "Generating image..."
|
||||
: "Generate image for video (optional)"
|
||||
}
|
||||
sx={{
|
||||
minWidth: 180,
|
||||
background: hasImage
|
||||
? "linear-gradient(135deg, #10b981 0%, #059669 100%)"
|
||||
: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
"&:hover": {
|
||||
background: hasImage
|
||||
? "linear-gradient(135deg, #059669 0%, #047857 100%)"
|
||||
: "linear-gradient(135deg, #764ba2 0%, #667eea 100%)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
{hasImage && !generatingImage
|
||||
? "Regenerate Image"
|
||||
: generatingImage
|
||||
? "Generating Image..."
|
||||
: "Generate Image"}
|
||||
</PrimaryButton>
|
||||
|
||||
<Tooltip title={totalScenes && totalScenes <= 1 ? "Cannot delete the last scene" : "Delete this scene"}>
|
||||
<IconButton
|
||||
onClick={() => onDelete(scene.id)}
|
||||
disabled={approving || generating || (totalScenes !== undefined && totalScenes <= 1)}
|
||||
sx={{
|
||||
color: "#ef4444",
|
||||
backgroundColor: "rgba(239, 68, 68, 0.1)",
|
||||
border: "1px solid rgba(239, 68, 68, 0.2)",
|
||||
borderRadius: 2,
|
||||
padding: 1.5,
|
||||
"&:hover": {
|
||||
backgroundColor: "rgba(239, 68, 68, 0.15)",
|
||||
borderColor: "rgba(239, 68, 68, 0.3)",
|
||||
},
|
||||
"&:disabled": {
|
||||
backgroundColor: "rgba(156, 163, 175, 0.1)",
|
||||
borderColor: "rgba(156, 163, 175, 0.2)",
|
||||
color: "#9ca3af",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<DeleteIcon sx={{ fontSize: "1.25rem" }} />
|
||||
</IconButton>
|
||||
</Tooltip>
|
||||
</Stack>
|
||||
</Stack>
|
||||
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1 }} />
|
||||
|
||||
<Stack spacing={2}>
|
||||
{scene.lines.map((line) => (
|
||||
<LineEditor key={line.id} line={line} onChange={updateLine} />
|
||||
))}
|
||||
</Stack>
|
||||
|
||||
{scene.audioUrl && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1, mt: 1 }} />
|
||||
<Box
|
||||
sx={{
|
||||
p: 2,
|
||||
background: hasAudio
|
||||
? "linear-gradient(135deg, rgba(16, 185, 129, 0.08) 0%, rgba(5, 150, 105, 0.08) 100%)"
|
||||
: "linear-gradient(135deg, rgba(245, 158, 11, 0.08) 0%, rgba(217, 119, 6, 0.08) 100%)",
|
||||
borderRadius: 2,
|
||||
border: hasAudio
|
||||
? "1px solid rgba(16, 185, 129, 0.2)"
|
||||
: "1px solid rgba(245, 158, 11, 0.2)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5} sx={{ mb: 1.5 }}>
|
||||
<VolumeUpIcon sx={{ color: hasAudio ? "#059669" : "#d97706", fontSize: "1.25rem" }} />
|
||||
<Typography variant="subtitle2" sx={{ color: hasAudio ? "#059669" : "#d97706", fontWeight: 600 }}>
|
||||
{hasAudio ? "Audio Generated" : "Loading Audio..."}
|
||||
</Typography>
|
||||
</Stack>
|
||||
{hasAudio && audioBlobUrl ? (
|
||||
<audio controls style={{ width: "100%", borderRadius: 8 }}>
|
||||
<source src={audioBlobUrl} type="audio/mpeg" />
|
||||
Your browser does not support the audio element.
|
||||
</audio>
|
||||
) : (
|
||||
<Box sx={{ display: "flex", alignItems: "center", justifyContent: "center", py: 2 }}>
|
||||
<CircularProgress size={24} sx={{ color: "#d97706" }} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Image Generation Progress - Show when generating */}
|
||||
{generatingImage && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1, mt: 1 }} />
|
||||
<Box
|
||||
sx={{
|
||||
p: 2,
|
||||
background: "linear-gradient(135deg, rgba(102, 126, 234, 0.08) 0%, rgba(118, 75, 162, 0.08) 100%)",
|
||||
borderRadius: 2,
|
||||
border: "1px solid rgba(102, 126, 234, 0.2)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5} sx={{ mb: 1.5 }}>
|
||||
<ImageIcon sx={{ color: "#667eea", fontSize: "1.25rem" }} />
|
||||
<Typography variant="subtitle2" sx={{ color: "#667eea", fontWeight: 600 }}>
|
||||
Generating Image...
|
||||
</Typography>
|
||||
</Stack>
|
||||
|
||||
{/* Progress Bar */}
|
||||
<Box sx={{ mb: 1.5 }}>
|
||||
<LinearProgress
|
||||
variant="determinate"
|
||||
value={imageGenerationProgress}
|
||||
sx={{
|
||||
height: 8,
|
||||
borderRadius: 4,
|
||||
backgroundColor: alpha("#667eea", 0.1),
|
||||
"& .MuiLinearProgress-bar": {
|
||||
backgroundColor: "#667eea",
|
||||
borderRadius: 4,
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<Typography variant="caption" sx={{ color: "#667eea", mt: 0.5, display: "block", textAlign: "right" }}>
|
||||
{imageGenerationProgress}%
|
||||
</Typography>
|
||||
</Box>
|
||||
|
||||
{/* Status Message */}
|
||||
{imageGenerationStatus && (
|
||||
<Typography variant="body2" sx={{ color: "#667eea", fontSize: "0.875rem", lineHeight: 1.6, mb: 1 }}>
|
||||
{imageGenerationStatus}
|
||||
</Typography>
|
||||
)}
|
||||
|
||||
{/* Spinner */}
|
||||
<Box sx={{ display: "flex", alignItems: "center", justifyContent: "center", mt: 1 }}>
|
||||
<CircularProgress size={32} sx={{ color: "#667eea" }} />
|
||||
</Box>
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Generated Image Display - Show when image exists and not generating */}
|
||||
{scene.imageUrl && !generatingImage && (
|
||||
<>
|
||||
<Divider sx={{ borderColor: "rgba(15, 23, 42, 0.08)", borderWidth: 1, mt: 1 }} />
|
||||
<Box
|
||||
sx={{
|
||||
p: 2,
|
||||
background: imageBlobUrl && !imageLoading
|
||||
? "linear-gradient(135deg, rgba(102, 126, 234, 0.08) 0%, rgba(118, 75, 162, 0.08) 100%)"
|
||||
: "linear-gradient(135deg, rgba(245, 158, 11, 0.08) 0%, rgba(217, 119, 6, 0.08) 100%)",
|
||||
borderRadius: 2,
|
||||
border: imageBlobUrl && !imageLoading
|
||||
? "1px solid rgba(102, 126, 234, 0.2)"
|
||||
: "1px solid rgba(245, 158, 11, 0.2)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5} sx={{ mb: 1.5 }}>
|
||||
<ImageIcon sx={{ color: imageBlobUrl && !imageLoading ? "#667eea" : "#d97706", fontSize: "1.25rem" }} />
|
||||
<Typography variant="subtitle2" sx={{ color: imageBlobUrl && !imageLoading ? "#667eea" : "#d97706", fontWeight: 600 }}>
|
||||
{imageBlobUrl && !imageLoading ? "Image Generated" : "Loading Image..."}
|
||||
</Typography>
|
||||
</Stack>
|
||||
{imageBlobUrl && !imageLoading ? (
|
||||
<Box
|
||||
sx={{
|
||||
width: "100%",
|
||||
borderRadius: 2,
|
||||
overflow: "hidden",
|
||||
border: "1px solid rgba(102,126,234,0.2)",
|
||||
background: alpha("#667eea", 0.05),
|
||||
}}
|
||||
>
|
||||
<Box
|
||||
component="img"
|
||||
src={imageBlobUrl}
|
||||
alt={scene.title}
|
||||
sx={{
|
||||
width: "100%",
|
||||
height: "auto",
|
||||
display: "block",
|
||||
maxHeight: 400,
|
||||
objectFit: "cover",
|
||||
}}
|
||||
onError={(e) => {
|
||||
console.error('[SceneEditor] Image failed to load:', {
|
||||
src: e.currentTarget.src,
|
||||
imageUrl: scene.imageUrl,
|
||||
imageBlobUrl,
|
||||
});
|
||||
}}
|
||||
onLoad={() => {
|
||||
console.log('[SceneEditor] Image loaded successfully');
|
||||
}}
|
||||
/>
|
||||
</Box>
|
||||
) : (
|
||||
<Box sx={{ display: "flex", alignItems: "center", justifyContent: "center", py: 2 }}>
|
||||
<CircularProgress size={24} sx={{ color: "#d97706" }} />
|
||||
</Box>
|
||||
)}
|
||||
</Box>
|
||||
</>
|
||||
)}
|
||||
</Stack>
|
||||
|
||||
{/* Image Regeneration Modal */}
|
||||
<ImageRegenerateModal
|
||||
open={showRegenerateModal}
|
||||
onClose={() => setShowRegenerateModal(false)}
|
||||
onRegenerate={handleGenerateImage}
|
||||
initialPrompt={(() => {
|
||||
const promptParts = [
|
||||
`Scene: ${scene.title}`,
|
||||
"Professional podcast recording studio",
|
||||
"Modern microphone setup",
|
||||
"Clean background, professional lighting",
|
||||
"16:9 aspect ratio, video-optimized composition"
|
||||
];
|
||||
if (idea) {
|
||||
promptParts.push(`Topic: ${idea.substring(0, 60)}`);
|
||||
}
|
||||
return promptParts.join(", ");
|
||||
})()}
|
||||
initialStyle="Realistic"
|
||||
initialRenderingSpeed="Quality"
|
||||
initialAspectRatio="16:9"
|
||||
isGenerating={generatingImage}
|
||||
/>
|
||||
|
||||
<AudioRegenerateModal
|
||||
open={showAudioModal}
|
||||
onClose={() => setShowAudioModal(false)}
|
||||
onRegenerate={handleAudioRegenerate}
|
||||
initialSettings={audioSettings}
|
||||
isGenerating={generating}
|
||||
/>
|
||||
</GlassyCard>
|
||||
);
|
||||
};
|
||||
|
||||
818
_session_backup/ScriptEditor.tsx
Normal file
818
_session_backup/ScriptEditor.tsx
Normal file
@@ -0,0 +1,818 @@
|
||||
import React, { useEffect, useState, useCallback } from "react";
|
||||
import { Box, Stack, Typography, Alert, Paper, LinearProgress, CircularProgress, alpha, Collapse, IconButton, Divider } from "@mui/material";
|
||||
import { EditNote as EditNoteIcon, CheckCircle as CheckCircleIcon, PlayArrow as PlayArrowIcon, ArrowBack as ArrowBackIcon, Info as InfoIcon, ExpandMore as ExpandMoreIcon, ExpandLess as ExpandLessIcon, Download as DownloadIcon, Refresh as RefreshIcon } from "@mui/icons-material";
|
||||
import { Script, Knobs, Scene } from "../types";
|
||||
import { BlogResearchResponse } from "../../../services/blogWriterApi";
|
||||
import { podcastApi } from "../../../services/podcastApi";
|
||||
import { GlassyCard, PrimaryButton, SecondaryButton } from "../ui";
|
||||
import { SceneEditor } from "./SceneEditor";
|
||||
import { InlineAudioPlayer } from "../InlineAudioPlayer";
|
||||
import { aiApiClient } from "../../../api/client";
|
||||
|
||||
interface ScriptEditorProps {
|
||||
projectId: string;
|
||||
idea: string;
|
||||
research: any; // Research type
|
||||
rawResearch: BlogResearchResponse | null;
|
||||
knobs: Knobs;
|
||||
speakers: number;
|
||||
durationMinutes: number;
|
||||
script: Script | null;
|
||||
onScriptChange: (script: Script) => void;
|
||||
onBackToResearch: () => void;
|
||||
onProceedToRendering: (script: Script) => void;
|
||||
onError: (message: string) => void;
|
||||
avatarUrl?: string | null; // Base avatar URL for consistent scene image generation
|
||||
analysis?: any;
|
||||
outline?: any;
|
||||
}
|
||||
|
||||
export const ScriptEditor: React.FC<ScriptEditorProps> = ({
|
||||
projectId,
|
||||
idea,
|
||||
research,
|
||||
rawResearch,
|
||||
knobs,
|
||||
speakers,
|
||||
durationMinutes,
|
||||
script: initialScript,
|
||||
onScriptChange,
|
||||
onBackToResearch,
|
||||
onProceedToRendering,
|
||||
onError,
|
||||
avatarUrl,
|
||||
analysis,
|
||||
outline,
|
||||
}) => {
|
||||
const [script, setScript] = useState<Script | null>(initialScript);
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [approvingSceneId, setApprovingSceneId] = useState<string | null>(null);
|
||||
const [generatingAudioId, setGeneratingAudioId] = useState<string | null>(null);
|
||||
const [showScriptFormatInfo, setShowScriptFormatInfo] = useState(true);
|
||||
const [combiningAudio, setCombiningAudio] = useState(false);
|
||||
const [combinedAudioResult, setCombinedAudioResult] = useState<{
|
||||
url: string;
|
||||
filename: string;
|
||||
duration: number;
|
||||
sceneCount: number;
|
||||
} | null>(null);
|
||||
|
||||
// Defer upward script updates to avoid setState during render warnings
|
||||
const emitScriptChange = useCallback(
|
||||
(next: Script) => Promise.resolve().then(() => onScriptChange(next)),
|
||||
[onScriptChange]
|
||||
);
|
||||
|
||||
// Sync with parent state
|
||||
useEffect(() => {
|
||||
if (initialScript) {
|
||||
setScript(initialScript);
|
||||
}
|
||||
}, [initialScript]);
|
||||
|
||||
useEffect(() => {
|
||||
// If script already exists, don't regenerate
|
||||
if (script) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only generate if we have research data
|
||||
if (!rawResearch) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mounted = true;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
podcastApi
|
||||
.generateScript({
|
||||
projectId,
|
||||
idea,
|
||||
research: rawResearch,
|
||||
knobs,
|
||||
speakers,
|
||||
durationMinutes,
|
||||
analysis,
|
||||
outline,
|
||||
})
|
||||
.then((res) => {
|
||||
if (mounted) {
|
||||
setScript(res);
|
||||
emitScriptChange(res);
|
||||
setError(null);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
const message = err instanceof Error ? err.message : "Failed to generate script";
|
||||
setError(message);
|
||||
onError(message);
|
||||
})
|
||||
.finally(() => mounted && setLoading(false));
|
||||
return () => {
|
||||
mounted = false;
|
||||
};
|
||||
}, [projectId, rawResearch, idea, knobs, speakers, durationMinutes, analysis, outline, emitScriptChange, onError, script]);
|
||||
|
||||
const updateScene = (updated: Scene) => {
|
||||
// Use functional update to ensure we're working with latest state
|
||||
setScript((currentScript) => {
|
||||
if (!currentScript) return currentScript;
|
||||
const updatedScript = {
|
||||
...currentScript,
|
||||
scenes: currentScript.scenes.map((s) => (s.id === updated.id ? { ...s, ...updated } : s))
|
||||
};
|
||||
emitScriptChange(updatedScript);
|
||||
return updatedScript;
|
||||
});
|
||||
};
|
||||
|
||||
const approveScene = async (sceneId: string) => {
|
||||
try {
|
||||
setApprovingSceneId(sceneId);
|
||||
await podcastApi.approveScene({ projectId, sceneId });
|
||||
// Use functional update to ensure we're working with latest state
|
||||
setScript((currentScript) => {
|
||||
if (!currentScript) return currentScript;
|
||||
const updatedScript = {
|
||||
...currentScript,
|
||||
scenes: currentScript.scenes.map((s) => (s.id === sceneId ? { ...s, approved: true } : s)),
|
||||
};
|
||||
emitScriptChange(updatedScript);
|
||||
return updatedScript;
|
||||
});
|
||||
} catch (err) {
|
||||
const message = err instanceof Error ? err.message : "Failed to approve scene";
|
||||
setError(message);
|
||||
onError(message);
|
||||
throw err;
|
||||
} finally {
|
||||
setApprovingSceneId((current) => (current === sceneId ? null : current));
|
||||
}
|
||||
};
|
||||
|
||||
const deleteScene = useCallback((sceneId: string) => {
|
||||
if (!script) return;
|
||||
|
||||
// Prevent deleting if it's the last scene
|
||||
if (script.scenes.length <= 1) {
|
||||
onError("Cannot delete the last scene. At least one scene is required.");
|
||||
return;
|
||||
}
|
||||
|
||||
// Add confirmation dialog
|
||||
const sceneToDelete = script.scenes.find(s => s.id === sceneId);
|
||||
if (!sceneToDelete) return;
|
||||
|
||||
const confirmDelete = window.confirm(
|
||||
`Are you sure you want to delete "${sceneToDelete.title}"? This action cannot be undone.`
|
||||
);
|
||||
|
||||
if (!confirmDelete) return;
|
||||
|
||||
// Remove the scene from the script
|
||||
const updatedScenes = script.scenes.filter(s => s.id !== sceneId);
|
||||
const updatedScript = { ...script, scenes: updatedScenes };
|
||||
|
||||
emitScriptChange(updatedScript);
|
||||
setScript(updatedScript);
|
||||
|
||||
// Show success message
|
||||
console.log(`[ScriptEditor] Scene "${sceneToDelete.title}" deleted successfully`);
|
||||
}, [script, emitScriptChange, onError]);
|
||||
|
||||
const allApproved = script && script.scenes.every((s) => s.approved);
|
||||
const approvedCount = script ? script.scenes.filter((s) => s.approved).length : 0;
|
||||
const totalScenes = script ? script.scenes.length : 0;
|
||||
|
||||
// Check if all scenes have both audio and images (required for video rendering)
|
||||
const allScenesHaveAudioAndImages = script && script.scenes.every((s) => s.audioUrl && s.imageUrl);
|
||||
const scenesWithAudio = script ? script.scenes.filter((s) => s.audioUrl).length : 0;
|
||||
const allScenesHaveAudio = script && script.scenes.every((s) => s.audioUrl);
|
||||
|
||||
const combineAudio = useCallback(async () => {
|
||||
if (!script || !projectId) return;
|
||||
|
||||
try {
|
||||
setCombiningAudio(true);
|
||||
|
||||
const sceneIds: string[] = [];
|
||||
const sceneAudioUrls: string[] = [];
|
||||
|
||||
script.scenes.forEach((scene) => {
|
||||
if (scene.audioUrl) {
|
||||
// Ensure we're using the correct URL format (not blob URLs)
|
||||
const audioUrl = scene.audioUrl.startsWith('blob:') ? '' : scene.audioUrl;
|
||||
if (audioUrl) {
|
||||
sceneIds.push(scene.id);
|
||||
sceneAudioUrls.push(audioUrl);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (sceneIds.length === 0) {
|
||||
onError("No audio files found to combine.");
|
||||
return;
|
||||
}
|
||||
|
||||
const result = await podcastApi.combineAudio({
|
||||
projectId,
|
||||
sceneIds,
|
||||
sceneAudioUrls,
|
||||
});
|
||||
|
||||
// Store combined audio result for preview
|
||||
setCombinedAudioResult({
|
||||
url: result.combined_audio_url,
|
||||
filename: result.combined_audio_filename,
|
||||
duration: result.total_duration,
|
||||
sceneCount: result.scene_count,
|
||||
});
|
||||
|
||||
// Download the combined audio as blob (for authenticated endpoints)
|
||||
try {
|
||||
// Normalize path
|
||||
let audioPath = result.combined_audio_url.startsWith('/')
|
||||
? result.combined_audio_url
|
||||
: `/${result.combined_audio_url}`;
|
||||
|
||||
// Ensure it's a podcast audio endpoint
|
||||
if (!audioPath.includes('/api/podcast/audio/')) {
|
||||
const filename = audioPath.split('/').pop() || result.combined_audio_filename;
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
audioPath = audioPath.split('?')[0];
|
||||
|
||||
// Fetch as blob using authenticated client
|
||||
const response = await aiApiClient.get(audioPath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
// Create blob URL and download
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = blobUrl;
|
||||
link.download = result.combined_audio_filename || `podcast-episode-${projectId.slice(-8)}.mp3`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
|
||||
// Clean up blob URL after a delay
|
||||
setTimeout(() => {
|
||||
URL.revokeObjectURL(blobUrl);
|
||||
}, 100);
|
||||
} catch (downloadError) {
|
||||
console.error('Failed to download combined audio:', downloadError);
|
||||
onError('Failed to download audio file. You can try downloading again from the preview.');
|
||||
}
|
||||
} catch (error) {
|
||||
const message = error instanceof Error ? error.message : "Failed to combine audio";
|
||||
onError(`Failed to combine audio: ${message}`);
|
||||
} finally {
|
||||
setCombiningAudio(false);
|
||||
}
|
||||
}, [script, projectId, onError]);
|
||||
|
||||
return (
|
||||
<Box sx={{ mt: 4 }}>
|
||||
<Stack direction="row" spacing={2} alignItems="center" sx={{ mb: 4 }}>
|
||||
<SecondaryButton onClick={onBackToResearch} startIcon={<ArrowBackIcon />}>
|
||||
Back to Research
|
||||
</SecondaryButton>
|
||||
<Box sx={{ flex: 1 }}>
|
||||
<Typography
|
||||
variant="h4"
|
||||
sx={{
|
||||
background: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
WebkitBackgroundClip: "text",
|
||||
WebkitTextFillColor: "transparent",
|
||||
fontWeight: 700,
|
||||
letterSpacing: "-0.02em",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
gap: 1.5,
|
||||
fontSize: { xs: "1.75rem", md: "2rem" },
|
||||
}}
|
||||
>
|
||||
<EditNoteIcon sx={{ fontSize: "2rem" }} />
|
||||
Script Editor
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#64748b", mt: 0.5, ml: 5.5 }}>
|
||||
Review and refine your podcast script before rendering
|
||||
</Typography>
|
||||
</Box>
|
||||
</Stack>
|
||||
|
||||
{loading && (
|
||||
<Alert
|
||||
severity="info"
|
||||
icon={<CircularProgress size={20} />}
|
||||
sx={{
|
||||
mb: 3,
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.08) 0%, rgba(139, 92, 246, 0.08) 100%)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.2)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 1px 2px rgba(99, 102, 241, 0.05)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#6366f1",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", fontWeight: 500 }}>
|
||||
Generating script with AI... This may take a moment.
|
||||
</Typography>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<Alert
|
||||
severity="error"
|
||||
sx={{
|
||||
mb: 3,
|
||||
background: "linear-gradient(135deg, rgba(239, 68, 68, 0.08) 0%, rgba(220, 38, 38, 0.08) 100%)",
|
||||
border: "1px solid rgba(239, 68, 68, 0.2)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 1px 2px rgba(239, 68, 68, 0.05)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#ef4444",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", fontWeight: 500 }}>
|
||||
{error}
|
||||
</Typography>
|
||||
</Alert>
|
||||
)}
|
||||
|
||||
{script && (
|
||||
<Stack spacing={3}>
|
||||
{/* Script Format Explanation Panel */}
|
||||
<Paper
|
||||
sx={{
|
||||
p: 3,
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.05) 0%, rgba(139, 92, 246, 0.05) 100%)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.15)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 2px 8px rgba(99, 102, 241, 0.08)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" alignItems="center" justifyContent="space-between" sx={{ mb: showScriptFormatInfo ? 2 : 0 }}>
|
||||
<Stack direction="row" alignItems="center" spacing={1.5}>
|
||||
<Box
|
||||
sx={{
|
||||
width: 40,
|
||||
height: 40,
|
||||
borderRadius: "50%",
|
||||
background: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
boxShadow: "0 2px 8px rgba(102, 126, 234, 0.3)",
|
||||
}}
|
||||
>
|
||||
<InfoIcon sx={{ color: "#ffffff", fontSize: "1.5rem" }} />
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="h6" sx={{ color: "#0f172a", fontWeight: 600, fontSize: "1.1rem" }}>
|
||||
Why This Script Format?
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#64748b", mt: 0.25 }}>
|
||||
Understanding how your script creates natural, human-like audio
|
||||
</Typography>
|
||||
</Box>
|
||||
</Stack>
|
||||
<IconButton
|
||||
onClick={() => setShowScriptFormatInfo(!showScriptFormatInfo)}
|
||||
sx={{
|
||||
color: "#6366f1",
|
||||
"&:hover": {
|
||||
background: "rgba(99, 102, 241, 0.1)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
{showScriptFormatInfo ? <ExpandLessIcon /> : <ExpandMoreIcon />}
|
||||
</IconButton>
|
||||
</Stack>
|
||||
|
||||
<Collapse in={showScriptFormatInfo}>
|
||||
<Stack spacing={2.5}>
|
||||
<Box>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", lineHeight: 1.8, mb: 2 }}>
|
||||
Our AI script generator creates scripts specifically optimized for <strong style={{ fontWeight: 600 }}>high-quality text-to-speech</strong>.
|
||||
The format you see here is designed to produce audio that sounds natural and human-like, not robotic.
|
||||
</Typography>
|
||||
</Box>
|
||||
|
||||
<Stack spacing={2}>
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
1
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Natural Pauses & Rhythm
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
The script includes strategic pauses between lines and when speakers change. This creates natural breathing patterns
|
||||
and conversation flow, just like real human speech. Without these pauses, the audio would sound rushed and robotic.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
2
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Emphasis Markers
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
Lines marked with emphasis help highlight important points, statistics, or key insights. The AI voice will naturally
|
||||
stress these parts, making your podcast more engaging and easier to follow—just like a real host would emphasize important information.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
3
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Short, Conversational Sentences
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
The script uses shorter sentences (15-20 words) written in a conversational style. This matches how people actually
|
||||
speak, making the audio sound more natural. Long, complex sentences would sound awkward when spoken aloud.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
4
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Scene-Specific Emotions
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
Each scene has an emotional tone (excited, serious, curious, etc.) that guides the AI voice's delivery. This creates
|
||||
variety and keeps listeners engaged, just like a real podcast host would vary their tone based on the topic.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
|
||||
<Box sx={{ display: "flex", gap: 2 }}>
|
||||
<Box
|
||||
sx={{
|
||||
minWidth: 32,
|
||||
height: 32,
|
||||
borderRadius: "8px",
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%)",
|
||||
display: "flex",
|
||||
alignItems: "center",
|
||||
justifyContent: "center",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#6366f1", fontWeight: 700 }}>
|
||||
5
|
||||
</Typography>
|
||||
</Box>
|
||||
<Box>
|
||||
<Typography variant="subtitle2" sx={{ color: "#0f172a", fontWeight: 600, mb: 0.5 }}>
|
||||
Optimized for Podcast Narration
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#475569", lineHeight: 1.7 }}>
|
||||
The script is optimized with slightly slower pacing and natural pronunciation settings specifically for podcast narration.
|
||||
This ensures clarity and makes the content easy to understand, even when listeners are multitasking.
|
||||
</Typography>
|
||||
</Box>
|
||||
</Box>
|
||||
</Stack>
|
||||
|
||||
<Alert
|
||||
severity="info"
|
||||
sx={{
|
||||
mt: 1,
|
||||
background: "rgba(99, 102, 241, 0.06)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.15)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#6366f1",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", lineHeight: 1.7 }}>
|
||||
<strong style={{ fontWeight: 600 }}>Tip:</strong> You can edit any line or scene to match your preferences.
|
||||
The format will be preserved when rendering, ensuring your audio still sounds natural and professional.
|
||||
</Typography>
|
||||
</Alert>
|
||||
</Stack>
|
||||
</Collapse>
|
||||
</Paper>
|
||||
|
||||
<Alert
|
||||
severity="info"
|
||||
sx={{
|
||||
background: "linear-gradient(135deg, rgba(99, 102, 241, 0.08) 0%, rgba(139, 92, 246, 0.08) 100%)",
|
||||
border: "1px solid rgba(99, 102, 241, 0.2)",
|
||||
borderRadius: 2,
|
||||
boxShadow: "0 1px 2px rgba(99, 102, 241, 0.05)",
|
||||
"& .MuiAlert-icon": {
|
||||
color: "#6366f1",
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#0f172a", fontWeight: 500, lineHeight: 1.6 }}>
|
||||
<strong style={{ fontWeight: 600 }}>Approval Required:</strong> Each scene must be approved before rendering. Review and edit lines as needed, then approve each scene.
|
||||
</Typography>
|
||||
</Alert>
|
||||
|
||||
<Stack spacing={2}>
|
||||
{script.scenes.map((scene, idx) => (
|
||||
<GlassyCard
|
||||
key={scene.id}
|
||||
initial={{ opacity: 0, y: 8 }}
|
||||
animate={{ opacity: 1, y: 0 }}
|
||||
transition={{ duration: 0.3, delay: idx * 0.1 }}
|
||||
>
|
||||
<SceneEditor
|
||||
scene={scene}
|
||||
onUpdateScene={updateScene}
|
||||
onApprove={approveScene}
|
||||
onDelete={deleteScene}
|
||||
knobs={knobs}
|
||||
approvingSceneId={approvingSceneId}
|
||||
generatingAudioId={generatingAudioId}
|
||||
totalScenes={script.scenes.length}
|
||||
onAudioGenerationStart={(sceneId) => {
|
||||
setGeneratingAudioId(sceneId);
|
||||
}}
|
||||
onAudioGenerated={async (sceneId, audioUrl) => {
|
||||
setGeneratingAudioId(null);
|
||||
// Use functional update to ensure we're working with latest state
|
||||
// Ensure scene is marked as approved and has audioUrl
|
||||
setScript((currentScript) => {
|
||||
if (!currentScript) return currentScript;
|
||||
const updatedScenes = currentScript.scenes.map((s) =>
|
||||
s.id === sceneId ? { ...s, audioUrl, approved: true } : s
|
||||
);
|
||||
const updatedScript = { ...currentScript, scenes: updatedScenes };
|
||||
emitScriptChange(updatedScript);
|
||||
return updatedScript;
|
||||
});
|
||||
}}
|
||||
idea={idea}
|
||||
avatarUrl={avatarUrl}
|
||||
/>
|
||||
</GlassyCard>
|
||||
))}
|
||||
</Stack>
|
||||
|
||||
<Paper
|
||||
sx={{
|
||||
p: 3.5,
|
||||
background: allApproved
|
||||
? "linear-gradient(135deg, rgba(16, 185, 129, 0.05) 0%, rgba(5, 150, 105, 0.05) 100%)"
|
||||
: "#ffffff",
|
||||
border: allApproved
|
||||
? "2px solid rgba(16, 185, 129, 0.25)"
|
||||
: "1px solid rgba(15, 23, 42, 0.08)",
|
||||
borderRadius: 3,
|
||||
boxShadow: allApproved
|
||||
? "0 4px 6px rgba(16, 185, 129, 0.08), 0 8px 24px rgba(16, 185, 129, 0.06)"
|
||||
: "0 1px 3px rgba(15, 23, 42, 0.06), 0 4px 12px rgba(15, 23, 42, 0.04)",
|
||||
transition: "all 0.3s cubic-bezier(0.4, 0, 0.2, 1)",
|
||||
}}
|
||||
>
|
||||
<Stack direction="row" justifyContent="space-between" alignItems="center">
|
||||
<Box>
|
||||
<Typography variant="subtitle1" sx={{ mb: 1, display: "flex", alignItems: "center", gap: 1.5, color: "#0f172a", fontWeight: 600, fontSize: "1.1rem" }}>
|
||||
<CheckCircleIcon fontSize="small" sx={{ color: allApproved ? "#10b981" : "#94a3b8", fontSize: "1.25rem" }} />
|
||||
Approval Status
|
||||
</Typography>
|
||||
<Typography variant="body2" sx={{ color: "#64748b", fontWeight: 400, lineHeight: 1.6 }}>
|
||||
{approvedCount} of {totalScenes} scenes approved
|
||||
{allScenesHaveAudioAndImages && " • All scenes ready for video rendering"}
|
||||
{!allScenesHaveAudioAndImages && allApproved && " • Generate images for all scenes to enable video rendering"}
|
||||
{!allApproved && " — Approve all scenes first"}
|
||||
</Typography>
|
||||
{!allScenesHaveAudioAndImages && (
|
||||
<LinearProgress
|
||||
variant="determinate"
|
||||
value={
|
||||
allScenesHaveAudioAndImages
|
||||
? 100
|
||||
: script
|
||||
? (script.scenes.filter((s) => s.audioUrl && s.imageUrl).length / totalScenes) * 100
|
||||
: 0
|
||||
}
|
||||
sx={{ mt: 1, height: 6, borderRadius: 3 }}
|
||||
/>
|
||||
)}
|
||||
</Box>
|
||||
<PrimaryButton
|
||||
onClick={() => script && onProceedToRendering(script)}
|
||||
disabled={!allScenesHaveAudioAndImages}
|
||||
startIcon={<PlayArrowIcon />}
|
||||
tooltip={
|
||||
!allScenesHaveAudioAndImages
|
||||
? "Generate audio and images for all scenes to proceed to video rendering"
|
||||
: "Proceed to video rendering (all scenes have audio and images)"
|
||||
}
|
||||
>
|
||||
Proceed to Rendering
|
||||
</PrimaryButton>
|
||||
</Stack>
|
||||
</Paper>
|
||||
|
||||
{/* Download Audio-Only Podcast Section */}
|
||||
{allScenesHaveAudio && (
|
||||
<Paper
|
||||
sx={{
|
||||
p: 3,
|
||||
background: "linear-gradient(135deg, rgba(102, 126, 234, 0.05) 0%, rgba(118, 75, 162, 0.05) 100%)",
|
||||
border: "1px solid rgba(102, 126, 234, 0.15)",
|
||||
borderRadius: 2,
|
||||
}}
|
||||
>
|
||||
<Stack spacing={3}>
|
||||
<Typography variant="h6" sx={{ color: "#0f172a", fontWeight: 600 }}>
|
||||
Download Audio-Only Podcast
|
||||
</Typography>
|
||||
|
||||
{!combinedAudioResult ? (
|
||||
<>
|
||||
<PrimaryButton
|
||||
onClick={combineAudio}
|
||||
disabled={combiningAudio}
|
||||
loading={combiningAudio}
|
||||
startIcon={<DownloadIcon />}
|
||||
tooltip="Combine all scene audio files into a single podcast episode"
|
||||
sx={{
|
||||
minWidth: 280,
|
||||
fontSize: "1rem",
|
||||
py: 1.5,
|
||||
background: "linear-gradient(135deg, #667eea 0%, #764ba2 100%)",
|
||||
"&:hover": {
|
||||
background: "linear-gradient(135deg, #764ba2 0%, #667eea 100%)",
|
||||
},
|
||||
}}
|
||||
>
|
||||
{combiningAudio ? "Combining Audio..." : "Download Audio-Only Podcast"}
|
||||
</PrimaryButton>
|
||||
<Typography variant="caption" sx={{ color: "#64748b", fontStyle: "italic" }}>
|
||||
This will combine all {scenesWithAudio} scene audio files into one complete podcast episode.
|
||||
</Typography>
|
||||
</>
|
||||
) : (
|
||||
<Stack spacing={2}>
|
||||
{/* Success Alert */}
|
||||
<Alert
|
||||
severity="success"
|
||||
sx={{
|
||||
background: alpha("#10b981", 0.1),
|
||||
border: "1px solid rgba(16,185,129,0.3)",
|
||||
"& .MuiAlert-icon": { color: "#10b981" },
|
||||
}}
|
||||
>
|
||||
<Typography variant="body2" sx={{ color: "#059669", fontWeight: 500 }}>
|
||||
✅ Combined audio generated successfully! ({combinedAudioResult.sceneCount} scenes,{" "}
|
||||
{Math.round(combinedAudioResult.duration)}s)
|
||||
</Typography>
|
||||
</Alert>
|
||||
|
||||
{/* Combined Audio Preview */}
|
||||
<InlineAudioPlayer audioUrl={combinedAudioResult.url} title="Complete Podcast Episode" />
|
||||
|
||||
{/* Action Buttons */}
|
||||
<Stack direction="row" spacing={2}>
|
||||
<SecondaryButton
|
||||
onClick={async () => {
|
||||
try {
|
||||
// Normalize path
|
||||
let audioPath = combinedAudioResult.url.startsWith('/')
|
||||
? combinedAudioResult.url
|
||||
: `/${combinedAudioResult.url}`;
|
||||
|
||||
// Ensure it's a podcast audio endpoint
|
||||
if (!audioPath.includes('/api/podcast/audio/')) {
|
||||
const filename = audioPath.split('/').pop() || combinedAudioResult.filename;
|
||||
audioPath = `/api/podcast/audio/${filename}`;
|
||||
}
|
||||
|
||||
// Remove query parameters if present
|
||||
audioPath = audioPath.split('?')[0];
|
||||
|
||||
// Fetch as blob using authenticated client
|
||||
const response = await aiApiClient.get(audioPath, {
|
||||
responseType: 'blob',
|
||||
});
|
||||
|
||||
// Create blob URL and download
|
||||
const blob = response.data;
|
||||
const blobUrl = URL.createObjectURL(blob);
|
||||
|
||||
const link = document.createElement("a");
|
||||
link.href = blobUrl;
|
||||
link.download = combinedAudioResult.filename || `podcast-episode-${projectId.slice(-8)}.mp3`;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
|
||||
// Clean up blob URL after a delay
|
||||
setTimeout(() => {
|
||||
URL.revokeObjectURL(blobUrl);
|
||||
}, 100);
|
||||
} catch (error) {
|
||||
console.error('Failed to download audio:', error);
|
||||
onError('Failed to download audio file. Please try again.');
|
||||
}
|
||||
}}
|
||||
startIcon={<DownloadIcon />}
|
||||
tooltip="Download the combined audio file again"
|
||||
>
|
||||
Download Again
|
||||
</SecondaryButton>
|
||||
<SecondaryButton
|
||||
onClick={() => {
|
||||
setCombinedAudioResult(null);
|
||||
combineAudio();
|
||||
}}
|
||||
disabled={combiningAudio}
|
||||
loading={combiningAudio}
|
||||
startIcon={<RefreshIcon />}
|
||||
tooltip="Regenerate combined audio (useful if scenes were updated)"
|
||||
>
|
||||
Regenerate
|
||||
</SecondaryButton>
|
||||
</Stack>
|
||||
</Stack>
|
||||
)}
|
||||
</Stack>
|
||||
</Paper>
|
||||
)}
|
||||
</Stack>
|
||||
)}
|
||||
</Box>
|
||||
);
|
||||
};
|
||||
|
||||
334
_session_backup/analysis.py
Normal file
334
_session_backup/analysis.py
Normal file
@@ -0,0 +1,334 @@
|
||||
"""
|
||||
Podcast Analysis Handlers
|
||||
|
||||
Analysis endpoint for podcast ideas.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
from ..constants import PODCAST_IMAGES_DIR
|
||||
from ..models import (
|
||||
PodcastAnalyzeRequest,
|
||||
PodcastAnalyzeResponse,
|
||||
PodcastEnhanceIdeaRequest,
|
||||
PodcastEnhanceIdeaResponse
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/idea/enhance", response_model=PodcastEnhanceIdeaResponse)
|
||||
async def enhance_podcast_idea(
|
||||
request: PodcastEnhanceIdeaRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Take raw keywords/topic and use AI to craft a presentable, detailed podcast idea.
|
||||
Uses the user's Podcast Bible for hyper-personalization if available.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Serialize Bible context if provided or generate from onboarding
|
||||
bible_context = ""
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
if request.bible:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
else:
|
||||
# Generate from onboarding data directly
|
||||
bible_obj = bible_service.generate_bible(user_id, "temp_enhance")
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Enhance] Failed to parse or generate bible context: {exc}")
|
||||
|
||||
prompt = f"""
|
||||
You are a creative podcast producer. Generate 3 distinct, compelling podcast episode concepts from the raw idea.
|
||||
|
||||
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{request.idea}"
|
||||
|
||||
TASK:
|
||||
Generate 3 different enhanced versions, each with a unique angle:
|
||||
1. Professional & Expert-led angle (focus on authority, insights, and expertise)
|
||||
2. Storytelling & Human interest angle (focus on narratives, emotions, and personal connections)
|
||||
3. Trendy & Contemporary angle (focus on current trends, modern perspectives, and relevance)
|
||||
|
||||
Each version should be 2-3 sentences, audience-focused, and align with host persona if provided.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_ideas: array of 3 enhanced episode pitches (in order: Professional, Storytelling, Trendy)
|
||||
- rationales: array of 3 rationales explaining the approach for each version
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
json_struct=None,
|
||||
preferred_provider="huggingface",
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
|
||||
# Normalize response
|
||||
if isinstance(raw, str):
|
||||
data = json.loads(raw)
|
||||
else:
|
||||
data = raw
|
||||
|
||||
# Extract enhanced ideas and rationales with fallbacks
|
||||
enhanced_ideas = data.get("enhanced_ideas", [])
|
||||
rationales = data.get("rationales", [])
|
||||
|
||||
# Ensure we have exactly 3 ideas, fallback to original if needed
|
||||
if not isinstance(enhanced_ideas, list) or len(enhanced_ideas) != 3:
|
||||
# Fallback: create 3 variations of the original idea
|
||||
base_idea = request.idea
|
||||
enhanced_ideas = [
|
||||
f"Expert insights on {base_idea}: A deep dive into industry trends and best practices.",
|
||||
f"The human side of {base_idea}: Personal stories and real-world experiences that resonate.",
|
||||
f"Modern perspectives on {base_idea}: Current trends and forward-thinking approaches."
|
||||
]
|
||||
rationales = [
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
|
||||
# Ensure rationales match the number of ideas
|
||||
if not isinstance(rationales, list) or len(rationales) != 3:
|
||||
rationales = [
|
||||
"Professional angle with expert insights",
|
||||
"Storytelling angle with human interest",
|
||||
"Trendy angle with contemporary relevance"
|
||||
]
|
||||
|
||||
return PodcastEnhanceIdeaResponse(
|
||||
enhanced_ideas=enhanced_ideas[:3], # Ensure exactly 3
|
||||
rationales=rationales[:3] # Ensure exactly 3
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Enhance] Failed for user {user_id}: {exc}")
|
||||
# Fallback to basic variations of original idea
|
||||
base_idea = request.idea
|
||||
return PodcastEnhanceIdeaResponse(
|
||||
enhanced_ideas=[
|
||||
f"Expert insights on {base_idea}: A deep dive into industry trends and best practices.",
|
||||
f"The human side of {base_idea}: Personal stories and real-world experiences that resonate.",
|
||||
f"Modern perspectives on {base_idea}: Current trends and forward-thinking approaches."
|
||||
],
|
||||
rationales=[
|
||||
"Professional approach focusing on expertise and authority",
|
||||
"Storytelling approach emphasizing human connection",
|
||||
"Contemporary approach highlighting current relevance"
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=PodcastAnalyzeResponse)
|
||||
async def analyze_podcast_idea(
|
||||
request: PodcastAnalyzeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Analyze a podcast idea and return podcast-oriented outlines, keywords, and titles.
|
||||
If no avatar_url is provided, it generates one automatically based on the host's look.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Serialize Bible context if provided or generate from onboarding
|
||||
bible_context = ""
|
||||
bible_obj = None
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
if request.bible:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
bible_obj = bible_data
|
||||
else:
|
||||
# Generate from onboarding data directly
|
||||
bible_obj = bible_service.generate_bible(user_id, "temp_analyze")
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
bible_obj = bible_obj
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Analyze] Failed to parse or generate bible context: {exc}")
|
||||
|
||||
# --- NEW: Generate Presenter Avatar if missing ---
|
||||
final_avatar_url = request.avatar_url
|
||||
final_avatar_prompt = None
|
||||
|
||||
if not final_avatar_url:
|
||||
logger.info(f"[Podcast Analyze] No avatar_url provided, generating one for user {user_id}")
|
||||
try:
|
||||
# 1. PRE-FLIGHT VALIDATION: Check subscription limits for image generation
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
pricing_service = PricingService(db)
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1
|
||||
)
|
||||
|
||||
# 2. Build avatar prompt from Bible host look or fallback
|
||||
host_look = bible_obj.host.look if bible_obj and bible_obj.host.look else "A professional podcast host"
|
||||
visual_style = bible_obj.visual_style.style_preset if bible_obj else "Realistic Photography"
|
||||
|
||||
final_avatar_prompt = f"Professional headshot of a podcast host, {host_look}, {visual_style} style, clean background, soft studio lighting, center-focused, high resolution, sharp focus, professional photography quality, 16:9 aspect ratio."
|
||||
|
||||
# 3. Generate the image
|
||||
logger.info(f"[Podcast Analyze] Generating avatar with prompt: {final_avatar_prompt}")
|
||||
image_result = generate_image(
|
||||
prompt=final_avatar_prompt,
|
||||
user_id=user_id,
|
||||
width=1024,
|
||||
height=1024
|
||||
)
|
||||
|
||||
# 4. Save to disk and library
|
||||
if image_result and image_result.image_bytes:
|
||||
img_id = str(uuid.uuid4())[:8]
|
||||
filename = f"presenter_podcast_{user_id}_{img_id}.png"
|
||||
output_path = PODCAST_IMAGES_DIR / filename
|
||||
PODCAST_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_result.image_bytes)
|
||||
|
||||
final_avatar_url = f"/api/podcast/images/avatars/{filename}"
|
||||
|
||||
# Save to asset library for reuse
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
file_url=final_avatar_url,
|
||||
filename=filename,
|
||||
title=f"Presenter Avatar - {request.idea[:40]}",
|
||||
description=f"AI-generated podcast presenter for: {request.idea}",
|
||||
provider=image_result.provider,
|
||||
model=image_result.model,
|
||||
cost=image_result.cost
|
||||
)
|
||||
logger.info(f"[Podcast Analyze] ✅ Generated and saved avatar to {final_avatar_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast Analyze] ❌ Failed to generate avatar: {e}")
|
||||
# Non-fatal: continue analysis even if avatar generation fails
|
||||
|
||||
# --- END: Avatar Generation ---
|
||||
|
||||
# Incorporate user feedback if provided
|
||||
feedback_context = ""
|
||||
if request.feedback:
|
||||
feedback_context = f"""
|
||||
USER REGENERATION FEEDBACK:
|
||||
The user was not satisfied with the previous analysis. They provided the following instructions for improvement:
|
||||
"{request.feedback}"
|
||||
Please prioritize this feedback and adjust the analysis accordingly.
|
||||
"""
|
||||
|
||||
prompt = f"""
|
||||
You are an expert podcast producer and research strategist. Given a podcast idea, craft concise podcast-ready assets
|
||||
that sound like episode plans (not fiction stories).
|
||||
|
||||
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
|
||||
{feedback_context}
|
||||
|
||||
Podcast Idea: "{request.idea}"
|
||||
Duration: ~{request.duration} minutes
|
||||
Speakers: {request.speakers} (host + optional guest)
|
||||
|
||||
TASK:
|
||||
1. Define the target audience and content type aligned with the Bible's "Audience DNA" and "Brand DNA".
|
||||
2. Identify 5 high-impact keywords.
|
||||
3. Propose 2 episode outlines with factual segments.
|
||||
4. Suggest 3 titles.
|
||||
5. IMPORTANT: Generate 4-6 specific research queries for Exa. These queries MUST be highly targeted to the episode's topic, the host's expertise level, and the audience's interests as defined in the Bible.
|
||||
* Do NOT use generic queries like "latest trends in X".
|
||||
* DO use queries that look for case studies, specific data points, expert opinions, or contrasting viewpoints that would make for a deep, insightful podcast conversation.
|
||||
|
||||
Return JSON with:
|
||||
- audience: short target audience description
|
||||
- content_type: podcast style/format
|
||||
- top_keywords: 5 podcast-relevant keywords/phrases
|
||||
- suggested_outlines: 2 items, each with title (<=60 chars) and 4-6 short segments (bullet-friendly, factual)
|
||||
- title_suggestions: 3 concise episode titles
|
||||
- research_queries: array of {{"query": "string", "rationale": "string"}}
|
||||
- exa_suggested_config: suggested Exa search options with:
|
||||
- exa_search_type: "auto" | "neural" | "keyword"
|
||||
- exa_category: one of ["research paper","news","company","github","tweet","personal site","pdf","financial report","linkedin profile"]
|
||||
- exa_include_domains: up to 3 reputable domains
|
||||
- exa_exclude_domains: up to 3 domains
|
||||
- max_sources: 6-10
|
||||
- include_statistics: boolean
|
||||
- date_range: one of ["last_month","last_3_months","last_year","all_time"]
|
||||
|
||||
Requirements:
|
||||
- Keep language factual, actionable, and suited for spoken audio.
|
||||
- Avoid narrative fiction tone.
|
||||
- Prefer 2024-2025 context.
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
json_struct=None,
|
||||
preferred_provider="huggingface",
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., 429 subscription limit) - preserve error details
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Analyze] Analysis failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Analysis failed: {exc}")
|
||||
|
||||
# Normalize response (accept dict or JSON string)
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="LLM returned non-JSON output")
|
||||
elif isinstance(raw, dict):
|
||||
data = raw
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unexpected LLM response format")
|
||||
|
||||
audience = data.get("audience") or "Growth-focused professionals"
|
||||
content_type = data.get("content_type") or "Interview + insights"
|
||||
top_keywords = data.get("top_keywords") or []
|
||||
suggested_outlines = data.get("suggested_outlines") or []
|
||||
title_suggestions = data.get("title_suggestions") or []
|
||||
research_queries = data.get("research_queries") or []
|
||||
exa_suggested_config = data.get("exa_suggested_config") or None
|
||||
|
||||
return PodcastAnalyzeResponse(
|
||||
audience=audience,
|
||||
content_type=content_type,
|
||||
top_keywords=top_keywords,
|
||||
suggested_outlines=suggested_outlines,
|
||||
title_suggestions=title_suggestions,
|
||||
research_queries=research_queries,
|
||||
exa_suggested_config=exa_suggested_config,
|
||||
bible=bible_obj.model_dump() if bible_obj else None,
|
||||
avatar_url=final_avatar_url,
|
||||
avatar_prompt=final_avatar_prompt,
|
||||
)
|
||||
|
||||
422
_session_backup/models.py
Normal file
422
_session_backup/models.py
Normal file
@@ -0,0 +1,422 @@
|
||||
"""
|
||||
Podcast API Models
|
||||
|
||||
All Pydantic request/response models for podcast endpoints.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import List, Optional, Dict, Any
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PodcastProjectResponse(BaseModel):
|
||||
"""Response model for podcast project."""
|
||||
id: int
|
||||
project_id: str
|
||||
user_id: str
|
||||
idea: str
|
||||
duration: int
|
||||
speakers: int
|
||||
budget_cap: float
|
||||
analysis: Optional[Dict[str, Any]] = None
|
||||
queries: Optional[List[Dict[str, Any]]] = None
|
||||
selected_queries: Optional[List[str]] = None
|
||||
research: Optional[Dict[str, Any]] = None
|
||||
raw_research: Optional[Dict[str, Any]] = None
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
script_data: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
render_jobs: Optional[List[Dict[str, Any]]] = None
|
||||
knobs: Optional[Dict[str, Any]] = None
|
||||
research_provider: Optional[str] = None
|
||||
show_script_editor: bool = False
|
||||
show_render_queue: bool = False
|
||||
current_step: Optional[str] = None
|
||||
status: str = "draft"
|
||||
is_favorite: bool = False
|
||||
final_video_url: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
avatar_prompt: Optional[str] = None
|
||||
avatar_persona_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PodcastAnalyzeRequest(BaseModel):
|
||||
"""Request model for podcast idea analysis."""
|
||||
idea: str = Field(..., description="Podcast topic or idea")
|
||||
duration: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
avatar_url: Optional[str] = Field(None, description="Current avatar URL if selected")
|
||||
feedback: Optional[str] = Field(None, description="User feedback for regeneration")
|
||||
|
||||
|
||||
class PodcastAnalyzeResponse(BaseModel):
|
||||
"""Response model for podcast idea analysis."""
|
||||
audience: str
|
||||
content_type: str
|
||||
top_keywords: list[str]
|
||||
suggested_outlines: list[Dict[str, Any]]
|
||||
title_suggestions: list[str]
|
||||
research_queries: Optional[List[Dict[str, str]]] = None
|
||||
exa_suggested_config: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
avatar_url: Optional[str] = None
|
||||
avatar_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaRequest(BaseModel):
|
||||
"""Request model for enhancing a podcast idea with AI."""
|
||||
idea: str = Field(..., description="The raw podcast idea or keywords")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaResponse(BaseModel):
|
||||
"""Response model for enhanced podcast idea."""
|
||||
enhanced_ideas: List[str] = Field(..., description="3 AI-enhanced topic choices")
|
||||
rationales: List[str] = Field(..., description="Rationale for each enhanced idea")
|
||||
|
||||
|
||||
class PodcastScriptRequest(BaseModel):
|
||||
"""Request model for podcast script generation."""
|
||||
idea: str = Field(..., description="Podcast idea or topic")
|
||||
duration_minutes: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
research: Optional[Dict[str, Any]] = Field(None, description="Optional research payload to ground the script")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
outline: Optional[Dict[str, Any]] = Field(None, description="The refined episode outline to follow")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="The full analysis context (audience, keywords, etc.)")
|
||||
|
||||
|
||||
class PodcastSceneLine(BaseModel):
|
||||
speaker: str
|
||||
text: str
|
||||
emphasis: Optional[bool] = False
|
||||
|
||||
|
||||
class PodcastScene(BaseModel):
|
||||
id: str
|
||||
title: str
|
||||
duration: int
|
||||
lines: list[PodcastSceneLine]
|
||||
approved: bool = False
|
||||
emotion: Optional[str] = None
|
||||
imageUrl: Optional[str] = None # Generated image URL for video generation
|
||||
|
||||
|
||||
class PodcastExaConfig(BaseModel):
|
||||
"""Exa config for podcast research."""
|
||||
exa_search_type: Optional[str] = Field(default="auto", description="auto | keyword | neural")
|
||||
exa_category: Optional[str] = None
|
||||
exa_include_domains: List[str] = []
|
||||
exa_exclude_domains: List[str] = []
|
||||
max_sources: int = 8
|
||||
include_statistics: Optional[bool] = False
|
||||
date_range: Optional[str] = Field(default=None, description="last_month | last_3_months | last_year | all_time")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_domains(self):
|
||||
if self.exa_include_domains and self.exa_exclude_domains:
|
||||
# Exa API does not allow both include and exclude domains together with contents
|
||||
# Prefer include_domains and drop exclude_domains
|
||||
self.exa_exclude_domains = []
|
||||
return self
|
||||
|
||||
|
||||
class PodcastExaResearchRequest(BaseModel):
|
||||
"""Request for podcast research using Exa directly (no blog writer)."""
|
||||
topic: str
|
||||
queries: List[str]
|
||||
exa_config: Optional[PodcastExaConfig] = None
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="Podcast analysis context (audience, content type, etc.)")
|
||||
|
||||
|
||||
class PodcastExaSource(BaseModel):
|
||||
title: str = ""
|
||||
url: str = ""
|
||||
excerpt: str = ""
|
||||
published_at: Optional[str] = None
|
||||
highlights: Optional[List[str]] = None
|
||||
summary: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
index: Optional[int] = None
|
||||
image: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastResearchInsight(BaseModel):
|
||||
"""Deep insight extracted from research."""
|
||||
title: str
|
||||
content: str
|
||||
source_indices: List[int] = []
|
||||
|
||||
|
||||
class PodcastExaResearchResponse(BaseModel):
|
||||
sources: List[PodcastExaSource]
|
||||
search_queries: List[str] = []
|
||||
summary: str = ""
|
||||
key_insights: List[PodcastResearchInsight] = []
|
||||
expert_quotes: List[Dict[str, Any]] = []
|
||||
listener_cta: List[str] = []
|
||||
mapped_angles: List[Dict[str, Any]] = []
|
||||
cost: Optional[Dict[str, Any]] = None
|
||||
search_type: Optional[str] = None
|
||||
provider: str = "exa"
|
||||
content: Optional[str] = None # Raw aggregated content (deprecated)
|
||||
|
||||
|
||||
class PodcastScriptResponse(BaseModel):
|
||||
scenes: list[PodcastScene]
|
||||
|
||||
|
||||
class PodcastAudioRequest(BaseModel):
|
||||
"""Generate TTS for a podcast scene."""
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
text: str
|
||||
voice_id: Optional[str] = "Wise_Woman"
|
||||
speed: Optional[float] = 1.0
|
||||
volume: Optional[float] = 1.0
|
||||
pitch: Optional[float] = 0.0
|
||||
emotion: Optional[str] = "neutral"
|
||||
english_normalization: Optional[bool] = False # Better number reading for statistics
|
||||
sample_rate: Optional[int] = None
|
||||
bitrate: Optional[int] = None
|
||||
channel: Optional[str] = None
|
||||
format: Optional[str] = None
|
||||
language_boost: Optional[str] = None
|
||||
enable_sync_mode: Optional[bool] = True
|
||||
|
||||
|
||||
class PodcastAudioResponse(BaseModel):
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
audio_filename: str
|
||||
audio_url: str
|
||||
provider: str
|
||||
model: str
|
||||
voice_id: str
|
||||
text_length: int
|
||||
file_size: int
|
||||
cost: float
|
||||
|
||||
|
||||
class PodcastProjectListResponse(BaseModel):
|
||||
"""Response model for project list."""
|
||||
projects: List[PodcastProjectResponse]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class CreateProjectRequest(BaseModel):
|
||||
"""Request model for creating a project."""
|
||||
project_id: str = Field(..., description="Unique project ID")
|
||||
idea: str = Field(..., description="Episode idea or URL")
|
||||
duration: int = Field(..., description="Duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
budget_cap: float = Field(default=50.0, description="Budget cap in USD")
|
||||
avatar_url: Optional[str] = Field(None, description="Optional presenter avatar URL")
|
||||
|
||||
|
||||
class UpdateProjectRequest(BaseModel):
|
||||
"""Request model for updating project state."""
|
||||
analysis: Optional[Dict[str, Any]] = None
|
||||
queries: Optional[List[Dict[str, Any]]] = None
|
||||
selected_queries: Optional[List[str]] = None
|
||||
research: Optional[Dict[str, Any]] = None
|
||||
raw_research: Optional[Dict[str, Any]] = None
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
script_data: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
render_jobs: Optional[List[Dict[str, Any]]] = None
|
||||
knobs: Optional[Dict[str, Any]] = None
|
||||
research_provider: Optional[str] = None
|
||||
show_script_editor: Optional[bool] = None
|
||||
show_render_queue: Optional[bool] = None
|
||||
current_step: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
final_video_url: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastCombineAudioRequest(BaseModel):
|
||||
"""Request model for combining podcast audio files."""
|
||||
project_id: str
|
||||
scene_ids: List[str] = Field(..., description="List of scene IDs to combine")
|
||||
scene_audio_urls: List[str] = Field(..., description="List of audio URLs for each scene")
|
||||
|
||||
|
||||
class PodcastCombineAudioResponse(BaseModel):
|
||||
"""Response model for combined podcast audio."""
|
||||
combined_audio_url: str
|
||||
combined_audio_filename: str
|
||||
total_duration: float
|
||||
file_size: int
|
||||
scene_count: int
|
||||
|
||||
|
||||
class PodcastImageRequest(BaseModel):
|
||||
"""Request for generating an image for a podcast scene."""
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
scene_content: Optional[str] = None # Optional: scene lines text for context
|
||||
idea: Optional[str] = None # Optional: podcast idea for context
|
||||
base_avatar_url: Optional[str] = None # Base avatar image URL for scene variations
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
custom_prompt: Optional[str] = None # Custom prompt from user (overrides auto-generated prompt)
|
||||
style: Optional[str] = None # "Auto", "Fiction", or "Realistic"
|
||||
rendering_speed: Optional[str] = None # "Default", "Turbo", or "Quality"
|
||||
aspect_ratio: Optional[str] = None # "1:1", "16:9", "9:16", "4:3", "3:4"
|
||||
|
||||
|
||||
class PodcastImageResponse(BaseModel):
|
||||
"""Response for podcast scene image generation."""
|
||||
scene_id: str
|
||||
scene_title: str
|
||||
image_filename: str
|
||||
image_url: str
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
model: Optional[str] = None
|
||||
cost: float
|
||||
|
||||
|
||||
class PodcastVideoGenerationRequest(BaseModel):
|
||||
"""Request model for podcast video generation."""
|
||||
project_id: str = Field(..., description="Podcast project ID")
|
||||
scene_id: str = Field(..., description="Scene ID")
|
||||
scene_title: str = Field(..., description="Scene title")
|
||||
audio_url: str = Field(..., description="URL to the generated audio file")
|
||||
avatar_image_url: Optional[str] = Field(None, description="URL to scene image (required for video generation)")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
resolution: str = Field("720p", description="Video resolution (480p or 720p)")
|
||||
prompt: Optional[str] = Field(None, description="Optional animation prompt override")
|
||||
seed: Optional[int] = Field(-1, description="Random seed; -1 for random")
|
||||
mask_image_url: Optional[str] = Field(None, description="Optional mask image URL to specify animated region")
|
||||
|
||||
|
||||
class PodcastVideoGenerationResponse(BaseModel):
|
||||
"""Response model for podcast video generation."""
|
||||
task_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class PodcastCombineVideosRequest(BaseModel):
|
||||
"""Request to combine scene videos into final podcast"""
|
||||
project_id: str = Field(..., description="Project ID")
|
||||
scene_video_urls: list[str] = Field(..., description="List of scene video URLs in order")
|
||||
podcast_title: str = Field(default="Podcast", description="Title for the final podcast video")
|
||||
|
||||
|
||||
class PodcastCombineVideosResponse(BaseModel):
|
||||
"""Response from combine videos endpoint"""
|
||||
task_id: str
|
||||
status: str
|
||||
message: str
|
||||
|
||||
|
||||
class AudioDubbingQuality(str, Enum):
|
||||
LOW = "low"
|
||||
HIGH = "high"
|
||||
|
||||
@classmethod
|
||||
def from_string(cls, value: str) -> "AudioDubbingQuality":
|
||||
if value.lower() == "high":
|
||||
return cls.HIGH
|
||||
return cls.LOW
|
||||
|
||||
|
||||
class PodcastAudioDubRequest(BaseModel):
|
||||
"""Request model for audio dubbing."""
|
||||
source_audio_url: str = Field(..., description="URL or path to source audio file")
|
||||
source_language: Optional[str] = Field(None, description="Source language code (auto-detected if None)")
|
||||
target_language: str = Field(..., description="Target language for dubbing")
|
||||
quality: str = Field(default="low", description="Translation quality: low (DeepL) or high (WaveSpeed)")
|
||||
voice_id: Optional[str] = Field(default="Wise_Woman", description="Voice ID for TTS")
|
||||
speed: Optional[float] = Field(default=1.0, ge=0.5, le=2.0, description="Speech speed (0.5-2.0)")
|
||||
emotion: Optional[str] = Field(default="happy", description="Emotion for TTS voice")
|
||||
preserve_emotion: Optional[bool] = Field(default=True, description="Preserve emotional tone in translation")
|
||||
use_voice_clone: Optional[bool] = Field(default=False, description="Use voice cloning to preserve original speaker's voice")
|
||||
custom_voice_id: Optional[str] = Field(None, description="Custom name for the cloned voice")
|
||||
voice_clone_accuracy: Optional[float] = Field(default=0.7, ge=0.1, le=1.0, description="Voice cloning accuracy (0.1-1.0)")
|
||||
|
||||
|
||||
class PodcastAudioDubResponse(BaseModel):
|
||||
"""Response model for audio dubbing task creation."""
|
||||
task_id: str
|
||||
status: str = "pending"
|
||||
message: str = "Audio dubbing task created"
|
||||
|
||||
|
||||
class PodcastAudioDubResult(BaseModel):
|
||||
"""Response model for completed audio dubbing."""
|
||||
dubbed_audio_url: str
|
||||
dubbed_audio_filename: str
|
||||
original_transcript: str
|
||||
translated_transcript: str
|
||||
source_language: str
|
||||
target_language: str
|
||||
voice_id: str
|
||||
quality: str
|
||||
duration_seconds: int
|
||||
file_size: int
|
||||
cost: float
|
||||
task_id: str
|
||||
status: str = "completed"
|
||||
voice_clone_used: Optional[bool] = Field(default=False, description="Whether voice cloning was used")
|
||||
cloned_voice_id: Optional[str] = Field(None, description="ID of the cloned voice if voice_clone_used=True")
|
||||
|
||||
|
||||
class PodcastAudioDubEstimateRequest(BaseModel):
|
||||
"""Request model for dubbing cost estimation."""
|
||||
audio_duration_seconds: float = Field(..., description="Duration of source audio in seconds")
|
||||
target_language: str = Field(..., description="Target language")
|
||||
quality: str = Field(default="low", description="Translation quality")
|
||||
use_voice_clone: Optional[bool] = Field(default=False, description="Include voice cloning cost")
|
||||
|
||||
|
||||
class PodcastAudioDubEstimateResponse(BaseModel):
|
||||
"""Response model for dubbing cost estimation."""
|
||||
estimated_characters: int
|
||||
translation_cost: float
|
||||
tts_cost: float
|
||||
voice_clone_cost: float = 0.0
|
||||
total_cost: float
|
||||
currency: str = "USD"
|
||||
|
||||
|
||||
class VoiceCloneRequest(BaseModel):
|
||||
"""Request model for voice cloning."""
|
||||
source_audio_url: str = Field(..., description="URL or path to source audio file (10-60 seconds recommended)")
|
||||
custom_voice_id: Optional[str] = Field(None, description="Custom name for the cloned voice")
|
||||
accuracy: Optional[float] = Field(default=0.7, ge=0.1, le=1.0, description="Cloning accuracy (0.1-1.0)")
|
||||
language_boost: Optional[str] = Field(None, description="Language to optimize the voice for")
|
||||
|
||||
|
||||
class VoiceCloneResponse(BaseModel):
|
||||
"""Response model for voice cloning."""
|
||||
task_id: str
|
||||
status: str = "pending"
|
||||
message: str = "Voice cloning task created"
|
||||
|
||||
|
||||
class VoiceCloneResult(BaseModel):
|
||||
"""Response model for completed voice cloning."""
|
||||
voice_id: str
|
||||
voice_url: str
|
||||
source_language: str
|
||||
accuracy: float
|
||||
file_size: int
|
||||
task_id: str
|
||||
status: str = "completed"
|
||||
|
||||
837
_session_backup/podcastApi.ts
Normal file
837
_session_backup/podcastApi.ts
Normal file
@@ -0,0 +1,837 @@
|
||||
import { ResearchProvider, ResearchConfig } from "./blogWriterApi";
|
||||
import {
|
||||
storyWriterApi,
|
||||
StorySetupGenerationResponse,
|
||||
} from "./storyWriterApi";
|
||||
import { getResearchConfig, ResearchPersona } from "../api/researchConfig";
|
||||
import { aiApiClient } from "../api/client";
|
||||
import {
|
||||
CreateProjectPayload,
|
||||
CreateProjectResult,
|
||||
Fact,
|
||||
Knobs,
|
||||
PodcastAnalysis,
|
||||
PodcastEstimate,
|
||||
Query,
|
||||
RenderJobResult,
|
||||
Research,
|
||||
Scene,
|
||||
Script,
|
||||
} from "../components/PodcastMaker/types";
|
||||
import { checkPreflight, PreflightOperation } from "./billingService";
|
||||
import { TaskStatus } from "./storyWriterApi";
|
||||
|
||||
const DEFAULT_KNOBS: Knobs = {
|
||||
voice_emotion: "neutral",
|
||||
voice_speed: 1,
|
||||
resolution: "720p",
|
||||
scene_length_target: 45,
|
||||
sample_rate: 24000,
|
||||
bitrate: "standard",
|
||||
};
|
||||
|
||||
// const sleep = (ms: number) => new Promise((resolve) => setTimeout(resolve, ms));
|
||||
|
||||
const createId = (prefix: string) => {
|
||||
if (typeof crypto !== "undefined" && typeof crypto.randomUUID === "function") {
|
||||
return `${prefix}_${crypto.randomUUID()}`;
|
||||
}
|
||||
return `${prefix}_${Date.now()}_${Math.floor(Math.random() * 10000)}`;
|
||||
};
|
||||
|
||||
type OptionLike = StorySetupGenerationResponse["options"][0] | { plot_elements?: string; premise?: string };
|
||||
|
||||
const deriveSegments = (option?: OptionLike): string[] => {
|
||||
const segments: string[] = [];
|
||||
if (option?.plot_elements) {
|
||||
option.plot_elements
|
||||
.split(/[,.;]+/)
|
||||
.map((p) => p.trim())
|
||||
.filter(Boolean)
|
||||
.forEach((p) => segments.push(p));
|
||||
}
|
||||
if (!segments.length && "premise" in (option || {}) && (option as any)?.premise) {
|
||||
segments.push("Intro", "Key Takeaways", "Examples", "CTA");
|
||||
}
|
||||
return segments.slice(0, 5);
|
||||
};
|
||||
|
||||
const estimateCosts = ({
|
||||
minutes,
|
||||
scenes,
|
||||
chars,
|
||||
quality,
|
||||
avatars,
|
||||
queryCount = 3,
|
||||
}: {
|
||||
minutes: number;
|
||||
scenes: number;
|
||||
chars: number;
|
||||
quality: string;
|
||||
avatars: number;
|
||||
queryCount?: number;
|
||||
}): PodcastEstimate => {
|
||||
const secs = Math.max(60, minutes * 60);
|
||||
const ttsCost = (chars / 1000) * 0.05;
|
||||
const avatarCost = avatars * 0.15;
|
||||
const videoRate = quality === "hd" ? 0.06 : 0.03;
|
||||
const videoCost = secs * videoRate;
|
||||
const researchCost = +(Math.max(1, queryCount) * 0.1).toFixed(2);
|
||||
const total = +(ttsCost + avatarCost + videoCost + researchCost).toFixed(2);
|
||||
return {
|
||||
ttsCost: +ttsCost.toFixed(2),
|
||||
avatarCost: +avatarCost.toFixed(2),
|
||||
videoCost: +videoCost.toFixed(2),
|
||||
researchCost,
|
||||
total,
|
||||
};
|
||||
};
|
||||
|
||||
const mapPersonaQueries = (persona: ResearchPersona | undefined, seed: string): Query[] => {
|
||||
const baseIdea = seed || "AI marketing for small businesses";
|
||||
const personaKeywords = persona?.suggested_keywords?.filter(Boolean) || [];
|
||||
const angles = persona?.research_angles ?? [];
|
||||
const generated: Query[] = [];
|
||||
|
||||
const addQuery = (q: string, why: string, needsRecent = false) => {
|
||||
if (!q.trim()) return;
|
||||
generated.push({
|
||||
id: createId("q"),
|
||||
query: q.trim(),
|
||||
rationale: why,
|
||||
needsRecentStats: needsRecent,
|
||||
});
|
||||
};
|
||||
|
||||
if (personaKeywords.length) {
|
||||
personaKeywords.slice(0, 4).forEach((k, idx) =>
|
||||
addQuery(k, angles[idx % Math.max(1, angles.length)] || "Persona-aligned query", /202[45]|latest|trend/i.test(k))
|
||||
);
|
||||
}
|
||||
|
||||
if (!generated.length) {
|
||||
addQuery(`How is ${baseIdea} evolving in 2024?`, "Trend + outcome focus", true);
|
||||
addQuery(`Best practices for ${baseIdea}`, "Actionable guidance", false);
|
||||
addQuery(`${baseIdea} case studies with ROI`, "Proof and outcomes", true);
|
||||
addQuery(`${baseIdea} risks and objections`, "Address listener concerns", false);
|
||||
}
|
||||
|
||||
return generated.slice(0, 6);
|
||||
};
|
||||
|
||||
const mapSourcesToFacts = (sources: ExaSource[]): Fact[] => {
|
||||
if (!sources || !sources.length) return [];
|
||||
return sources.slice(0, 12).map((source: ExaSource, idx: number) => ({
|
||||
id: source.url || createId("fact"),
|
||||
quote: source.excerpt || source.title || "Insight",
|
||||
url: source.url || "",
|
||||
date: source.published_at || "Unknown",
|
||||
confidence: typeof (source as any).credibility_score === "number" ? (source as any).credibility_score : Math.max(0.5, 0.85 - idx * 0.02),
|
||||
image: source.image,
|
||||
author: source.author,
|
||||
highlights: source.highlights,
|
||||
}));
|
||||
};
|
||||
|
||||
type ExaSource = {
|
||||
title?: string;
|
||||
url?: string;
|
||||
excerpt?: string;
|
||||
published_at?: string;
|
||||
highlights?: string[];
|
||||
summary?: string;
|
||||
source_type?: string;
|
||||
index?: number;
|
||||
image?: string;
|
||||
author?: string;
|
||||
};
|
||||
|
||||
type ExaResearchResult = {
|
||||
sources: ExaSource[];
|
||||
search_queries?: string[];
|
||||
cost?: { total?: number };
|
||||
search_type?: string;
|
||||
provider?: string;
|
||||
content?: string;
|
||||
};
|
||||
|
||||
const mapExaResearchResponse = (response: any): Research => {
|
||||
const factCards = mapSourcesToFacts(response.sources);
|
||||
// Use backend summary if available, otherwise use full content (no truncation) or fallback text
|
||||
const summary = response.summary || response.content || "Research completed.";
|
||||
|
||||
const keyInsights = (response.key_insights || []).map((insight: any) => ({
|
||||
title: insight.title || "Insight",
|
||||
content: insight.content || "",
|
||||
source_indices: insight.source_indices || []
|
||||
}));
|
||||
|
||||
const expertQuotes = (response.expert_quotes || []).map((eq: any) => ({
|
||||
quote: eq.quote || eq.text || "",
|
||||
source_index: eq.source_index ?? 0
|
||||
}));
|
||||
|
||||
const listenerCta = response.listener_cta || [];
|
||||
|
||||
const mappedAngles = (response.mapped_angles || []).map((angle: any) => ({
|
||||
title: angle.title || "",
|
||||
why: angle.why || angle.rationale || "",
|
||||
mappedFactIds: angle.mapped_fact_ids || angle.mappedFactIds || []
|
||||
}));
|
||||
|
||||
return {
|
||||
summary,
|
||||
keyInsights,
|
||||
factCards,
|
||||
mappedAngles,
|
||||
expertQuotes,
|
||||
listenerCta,
|
||||
searchQueries: response.search_queries,
|
||||
searchType: response.search_type,
|
||||
provider: response.provider || "exa",
|
||||
cost: response.cost?.total,
|
||||
sourceCount: response.sources?.length || 0,
|
||||
};
|
||||
};
|
||||
|
||||
const ensurePreflight = async (operation: PreflightOperation) => {
|
||||
const result = await checkPreflight(operation);
|
||||
if (!result.can_proceed) {
|
||||
const message = result.operations[0]?.message || "Pre-flight validation failed";
|
||||
throw new Error(message);
|
||||
}
|
||||
return result;
|
||||
};
|
||||
|
||||
export const podcastApi = {
|
||||
async createProject(payload: CreateProjectPayload, bible?: any, feedback?: string): Promise<CreateProjectResult> {
|
||||
const storyIdea = payload.ideaOrUrl || "AI marketing for small businesses";
|
||||
|
||||
await ensurePreflight({
|
||||
provider: "gemini",
|
||||
operation_type: "podcast_analysis",
|
||||
tokens_requested: 1500,
|
||||
actual_provider_name: "gemini",
|
||||
});
|
||||
|
||||
// Podcast-specific analysis (not story setup)
|
||||
const analysisResp = await aiApiClient.post("/api/podcast/analyze", {
|
||||
idea: storyIdea,
|
||||
duration: payload.duration,
|
||||
speakers: payload.speakers,
|
||||
bible: bible,
|
||||
avatar_url: payload.avatarUrl,
|
||||
feedback: feedback, // Pass feedback to backend
|
||||
});
|
||||
|
||||
const outlines = (analysisResp.data?.suggested_outlines || []).map((o: any, idx: number) => ({
|
||||
id: o.id || `outline-${idx + 1}`,
|
||||
title: o.title || `Outline ${idx + 1}`,
|
||||
segments: Array.isArray(o.segments) ? o.segments : deriveSegments({ plot_elements: o.segments }),
|
||||
}));
|
||||
|
||||
const analysis: PodcastAnalysis = {
|
||||
audience: analysisResp.data?.audience || "Growth-minded pros",
|
||||
contentType: analysisResp.data?.content_type || "Podcast interview",
|
||||
topKeywords: analysisResp.data?.top_keywords || outlines[0]?.segments?.slice(0, 3) || [],
|
||||
suggestedOutlines: outlines,
|
||||
suggestedKnobs: { ...DEFAULT_KNOBS, ...payload.knobs },
|
||||
titleSuggestions: (analysisResp.data?.title_suggestions || []).filter(Boolean),
|
||||
research_queries: analysisResp.data?.research_queries || [],
|
||||
exaSuggestedConfig: analysisResp.data?.exa_suggested_config || undefined,
|
||||
};
|
||||
|
||||
const researchConfig = await getResearchConfig().catch(() => null);
|
||||
|
||||
// Use AI-generated queries if available, fallback to legacy mapping
|
||||
let queries: Query[] = [];
|
||||
if (analysis.research_queries && analysis.research_queries.length > 0) {
|
||||
queries = analysis.research_queries.map(rq => ({
|
||||
id: createId("q"),
|
||||
query: rq.query,
|
||||
rationale: rq.rationale,
|
||||
needsRecentStats: /202[45]|latest|trend/i.test(rq.query)
|
||||
}));
|
||||
} else {
|
||||
queries = mapPersonaQueries(researchConfig?.research_persona, storyIdea);
|
||||
}
|
||||
|
||||
const projectId = createId("podcast");
|
||||
const estimate = estimateCosts({
|
||||
minutes: payload.duration,
|
||||
scenes: Math.ceil((payload.duration * 60) / (payload.knobs.scene_length_target || DEFAULT_KNOBS.scene_length_target)),
|
||||
chars: Math.max(1000, payload.duration * 900),
|
||||
quality: payload.knobs.bitrate || "standard",
|
||||
avatars: payload.speakers,
|
||||
queryCount: queries.length || 3,
|
||||
});
|
||||
|
||||
return {
|
||||
projectId,
|
||||
analysis,
|
||||
estimate,
|
||||
queries,
|
||||
bible: analysisResp.data?.bible || undefined,
|
||||
avatar_url: analysisResp.data?.avatar_url || null,
|
||||
avatar_prompt: analysisResp.data?.avatar_prompt || null,
|
||||
};
|
||||
},
|
||||
|
||||
async enhanceIdea(params: { idea: string; bible?: any }): Promise<{ enhanced_ideas: string[]; rationales: string[] }> {
|
||||
const response = await aiApiClient.post("/api/podcast/idea/enhance", params);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async runResearch(params: {
|
||||
projectId: string;
|
||||
topic: string;
|
||||
approvedQueries: Query[];
|
||||
provider?: ResearchProvider;
|
||||
exaConfig?: ResearchConfig;
|
||||
bible?: any;
|
||||
analysis?: PodcastAnalysis | null;
|
||||
onProgress?: (message: string) => void;
|
||||
}): Promise<{ research: Research; raw: any }> {
|
||||
const keywords = params.approvedQueries.map((q) => q.query).filter(Boolean);
|
||||
if (!keywords.length) {
|
||||
throw new Error("At least one query must be approved for research.");
|
||||
}
|
||||
|
||||
// Ensure Exa payload respects API constraint: when requesting contents, only one of includeDomains or excludeDomains.
|
||||
let sanitizedExaConfig: ResearchConfig | undefined = params.exaConfig;
|
||||
if (sanitizedExaConfig && sanitizedExaConfig.exa_include_domains?.length) {
|
||||
sanitizedExaConfig = {
|
||||
...sanitizedExaConfig,
|
||||
exa_exclude_domains: undefined,
|
||||
};
|
||||
} else if (sanitizedExaConfig && sanitizedExaConfig.exa_exclude_domains?.length) {
|
||||
sanitizedExaConfig = {
|
||||
...sanitizedExaConfig,
|
||||
exa_include_domains: undefined,
|
||||
};
|
||||
}
|
||||
|
||||
await ensurePreflight({
|
||||
provider: "exa",
|
||||
operation_type: "exa_neural_search",
|
||||
tokens_requested: 0,
|
||||
actual_provider_name: "exa",
|
||||
});
|
||||
|
||||
const response = await aiApiClient.post("/api/podcast/research/exa", {
|
||||
topic: params.topic || keywords[0],
|
||||
queries: keywords,
|
||||
exa_config: sanitizedExaConfig,
|
||||
bible: params.bible,
|
||||
analysis: params.analysis,
|
||||
});
|
||||
|
||||
const exaResult = response.data as ExaResearchResult;
|
||||
if (params.onProgress) {
|
||||
params.onProgress("Deep research completed with Exa.");
|
||||
}
|
||||
const mapped = mapExaResearchResponse(exaResult);
|
||||
return { research: mapped, raw: exaResult };
|
||||
},
|
||||
|
||||
async generateScript(params: {
|
||||
projectId: string;
|
||||
idea: string;
|
||||
research?: ExaResearchResult | null;
|
||||
knobs: Knobs;
|
||||
speakers: number;
|
||||
durationMinutes: number;
|
||||
bible?: any;
|
||||
outline?: any;
|
||||
analysis?: PodcastAnalysis | null;
|
||||
}): Promise<Script> {
|
||||
await ensurePreflight({
|
||||
provider: "gemini",
|
||||
operation_type: "script_generation",
|
||||
tokens_requested: 2000,
|
||||
actual_provider_name: "gemini",
|
||||
});
|
||||
|
||||
const response = await aiApiClient.post("/api/podcast/script", {
|
||||
idea: params.idea,
|
||||
duration_minutes: params.durationMinutes,
|
||||
speakers: params.speakers,
|
||||
research: params.research,
|
||||
bible: params.bible,
|
||||
outline: params.outline,
|
||||
analysis: params.analysis,
|
||||
});
|
||||
|
||||
const scenes = response.data?.scenes || [];
|
||||
const scriptScenes: Scene[] = scenes.map((scene: any) => ({
|
||||
id: scene.id || createId("scene"),
|
||||
title: scene.title || "Scene",
|
||||
duration: scene.duration || Math.max(20, params.knobs.scene_length_target || DEFAULT_KNOBS.scene_length_target),
|
||||
lines:
|
||||
Array.isArray(scene.lines) && scene.lines.length
|
||||
? scene.lines.map((l: any) => ({
|
||||
id: createId("line"),
|
||||
speaker: l.speaker || "Host",
|
||||
text: l.text || "",
|
||||
}))
|
||||
: [
|
||||
{
|
||||
id: createId("line"),
|
||||
speaker: "Host",
|
||||
text: "Let's dive into today's topic.",
|
||||
},
|
||||
],
|
||||
approved: false,
|
||||
}));
|
||||
|
||||
return { scenes: scriptScenes };
|
||||
},
|
||||
|
||||
async previewLine(
|
||||
text: string,
|
||||
options: { voiceId?: string; speed?: number; emotion?: string } = {}
|
||||
): Promise<{ ok: boolean; message: string; audioUrl?: string }> {
|
||||
await ensurePreflight({
|
||||
provider: "audio",
|
||||
operation_type: "tts_preview",
|
||||
tokens_requested: text.length,
|
||||
actual_provider_name: "wavespeed",
|
||||
});
|
||||
|
||||
const response = await storyWriterApi.generateAIAudio({
|
||||
scene_number: 0,
|
||||
scene_title: "Preview",
|
||||
text,
|
||||
voice_id: options.voiceId || "Wise_Woman",
|
||||
speed: options.speed || 1.0,
|
||||
emotion: options.emotion || "neutral",
|
||||
});
|
||||
|
||||
if (!response.success) {
|
||||
throw new Error(response.error || "Preview failed");
|
||||
}
|
||||
|
||||
return {
|
||||
ok: true,
|
||||
message: "Preview ready – opening audio in new tab.",
|
||||
audioUrl: response.audio_url,
|
||||
};
|
||||
},
|
||||
|
||||
async renderSceneAudio(params: {
|
||||
scene: Scene;
|
||||
voiceId?: string;
|
||||
emotion?: string; // Fallback if scene doesn't have emotion
|
||||
speed?: number;
|
||||
volume?: number;
|
||||
pitch?: number;
|
||||
englishNormalization?: boolean;
|
||||
sampleRate?: number;
|
||||
bitrate?: number;
|
||||
channel?: "1" | "2";
|
||||
format?: "mp3" | "wav" | "pcm" | "flac";
|
||||
languageBoost?: string;
|
||||
}): Promise<RenderJobResult> {
|
||||
// Use scene-specific emotion if available, otherwise fallback to provided/default
|
||||
const sceneEmotion = params.scene.emotion || params.emotion || "neutral";
|
||||
|
||||
// Optimize text for Minimax Speech-02-HD TTS
|
||||
// - Strip markdown formatting (bold, italic, etc.) - TTS reads it literally
|
||||
// - Use pause markers <#x#> for natural speech rhythm
|
||||
// - Add longer pauses for speaker changes
|
||||
// - Preserve punctuation for natural breathing
|
||||
// - Add emphasis pauses for important points
|
||||
const text = params.scene.lines
|
||||
.map((line, idx) => {
|
||||
let lineText = line.text.trim();
|
||||
|
||||
// Strip markdown formatting - TTS reads asterisks and other markdown literally
|
||||
// Remove bold (**text** or __text__)
|
||||
lineText = lineText.replace(/\*\*([^*]+)\*\*/g, '$1'); // **bold**
|
||||
lineText = lineText.replace(/\*([^*]+)\*/g, '$1'); // *bold* (single asterisk)
|
||||
lineText = lineText.replace(/__([^_]+)__/g, '$1'); // __bold__
|
||||
lineText = lineText.replace(/_([^_]+)_/g, '$1'); // _italic_ (single underscore)
|
||||
// Remove any remaining stray asterisks or underscores
|
||||
lineText = lineText.replace(/\*+/g, ''); // Remove any remaining asterisks
|
||||
lineText = lineText.replace(/_+/g, ''); // Remove any remaining underscores
|
||||
// Clean up extra spaces
|
||||
lineText = lineText.replace(/\s+/g, ' ').trim();
|
||||
|
||||
// Preserve punctuation (Minimax uses it for natural breathing)
|
||||
// Don't strip punctuation - it helps TTS understand natural pauses
|
||||
|
||||
// Add emphasis pause after lines marked with emphasis
|
||||
if (line.emphasis) {
|
||||
// Minimal pause after emphasized content (0.15s for subtle emphasis)
|
||||
lineText = `${lineText}<#0.15#>`;
|
||||
}
|
||||
|
||||
// Check for speaker change (longer pause for natural conversation flow)
|
||||
const prevLine = idx > 0 ? params.scene.lines[idx - 1] : null;
|
||||
const isSpeakerChange = prevLine && prevLine.speaker !== line.speaker;
|
||||
|
||||
if (isSpeakerChange) {
|
||||
// Short pause for speaker changes (0.2s - enough for natural transition)
|
||||
lineText = `<#0.2#>${lineText}`;
|
||||
}
|
||||
|
||||
// Add minimal pause between lines (only between regular lines, very short)
|
||||
if (idx < params.scene.lines.length - 1) {
|
||||
if (!line.emphasis && !isSpeakerChange) {
|
||||
// Very short pause between lines (0.08s - barely noticeable but helps flow)
|
||||
lineText = `${lineText}<#0.08#>`;
|
||||
}
|
||||
// If emphasis or speaker change, the pause is already added above
|
||||
}
|
||||
|
||||
return lineText;
|
||||
})
|
||||
.join(" ");
|
||||
|
||||
// Validate character limit (Minimax max: 10,000 characters)
|
||||
const MAX_CHARS = 10000;
|
||||
let textToUse = text;
|
||||
if (text.length > MAX_CHARS) {
|
||||
console.warn(
|
||||
`[Podcast] Scene "${params.scene.title}" exceeds ${MAX_CHARS} character limit (${text.length} chars). Truncating...`
|
||||
);
|
||||
// Truncate at word boundary to avoid cutting mid-word
|
||||
const truncated = text.substring(0, MAX_CHARS);
|
||||
const lastSpace = truncated.lastIndexOf(" ");
|
||||
textToUse = lastSpace > 0 ? truncated.substring(0, lastSpace) : truncated;
|
||||
}
|
||||
|
||||
await ensurePreflight({
|
||||
provider: "audio",
|
||||
operation_type: "tts_full_render",
|
||||
tokens_requested: textToUse.length,
|
||||
actual_provider_name: "wavespeed",
|
||||
});
|
||||
|
||||
const response = await aiApiClient.post("/api/podcast/audio", {
|
||||
scene_id: params.scene.id,
|
||||
scene_title: params.scene.title,
|
||||
text: textToUse,
|
||||
voice_id: params.voiceId || "Wise_Woman",
|
||||
speed: params.speed ?? 1.0, // Normal speed (was 0.9, but too slow - causing duration issues)
|
||||
volume: params.volume ?? 1.0,
|
||||
pitch: params.pitch ?? 0.0,
|
||||
emotion: sceneEmotion,
|
||||
english_normalization: params.englishNormalization ?? true, // Better number reading for statistics
|
||||
sample_rate: params.sampleRate || null,
|
||||
bitrate: params.bitrate || null,
|
||||
channel: params.channel || null,
|
||||
format: params.format || null,
|
||||
language_boost: params.languageBoost || null,
|
||||
});
|
||||
|
||||
return {
|
||||
audioUrl: response.data.audio_url,
|
||||
audioFilename: response.data.audio_filename,
|
||||
provider: response.data.provider,
|
||||
model: response.data.model,
|
||||
cost: response.data.cost,
|
||||
voiceId: response.data.voice_id,
|
||||
fileSize: response.data.file_size,
|
||||
};
|
||||
},
|
||||
|
||||
async approveScene(params: { projectId: string; sceneId: string; notes?: string }) {
|
||||
await aiApiClient.post("/api/story/script/approve", {
|
||||
project_id: params.projectId,
|
||||
scene_id: params.sceneId,
|
||||
approved: true,
|
||||
notes: params.notes,
|
||||
});
|
||||
},
|
||||
|
||||
// Project persistence endpoints
|
||||
async saveProject(projectId: string, state: any): Promise<void> {
|
||||
try {
|
||||
await aiApiClient.put(`/api/podcast/projects/${projectId}`, state);
|
||||
} catch (error) {
|
||||
console.error("Failed to save project to database:", error);
|
||||
// Don't throw - localStorage fallback is acceptable
|
||||
}
|
||||
},
|
||||
|
||||
async loadProject(projectId: string): Promise<any> {
|
||||
const response = await aiApiClient.get(`/api/podcast/projects/${projectId}`);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async listProjects(params?: {
|
||||
status?: string;
|
||||
favorites_only?: boolean;
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
order_by?: "updated_at" | "created_at";
|
||||
}): Promise<{ projects: any[]; total: number; limit: number; offset: number }> {
|
||||
const response = await aiApiClient.get("/api/podcast/projects", { params });
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async createProjectInDb(params: {
|
||||
project_id: string;
|
||||
idea: string;
|
||||
duration: number;
|
||||
speakers: number;
|
||||
budget_cap: number;
|
||||
avatar_url?: string | null;
|
||||
}): Promise<any> {
|
||||
const response = await aiApiClient.post("/api/podcast/projects", params);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async updateProject(projectId: string, updates: any): Promise<any> {
|
||||
const response = await aiApiClient.put(`/api/podcast/projects/${projectId}`, updates);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async deleteProject(projectId: string): Promise<void> {
|
||||
await aiApiClient.delete(`/api/podcast/projects/${projectId}`);
|
||||
},
|
||||
|
||||
async toggleFavorite(projectId: string): Promise<any> {
|
||||
const response = await aiApiClient.post(`/api/podcast/projects/${projectId}/favorite`);
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async saveAudioToAssetLibrary(params: {
|
||||
audioUrl: string;
|
||||
filename: string;
|
||||
title: string;
|
||||
description?: string;
|
||||
projectId: string;
|
||||
sceneId?: string;
|
||||
cost?: number;
|
||||
provider?: string;
|
||||
model?: string;
|
||||
fileSize?: number;
|
||||
}): Promise<{ assetId: number }> {
|
||||
const response = await aiApiClient.post("/api/content-assets/", {
|
||||
asset_type: "audio",
|
||||
source_module: "podcast_maker",
|
||||
filename: params.filename,
|
||||
file_url: params.audioUrl,
|
||||
title: params.title,
|
||||
description: params.description || `Podcast episode audio: ${params.title}`,
|
||||
tags: ["podcast", "audio", params.projectId],
|
||||
asset_metadata: {
|
||||
project_id: params.projectId,
|
||||
scene_id: params.sceneId,
|
||||
provider: params.provider,
|
||||
model: params.model,
|
||||
},
|
||||
provider: params.provider,
|
||||
model: params.model,
|
||||
cost: params.cost || 0,
|
||||
file_size: params.fileSize,
|
||||
mime_type: "audio/mpeg",
|
||||
});
|
||||
return { assetId: response.data.id };
|
||||
},
|
||||
|
||||
async generateVideo(params: {
|
||||
projectId: string;
|
||||
sceneId: string;
|
||||
sceneTitle: string;
|
||||
audioUrl: string;
|
||||
avatarImageUrl?: string;
|
||||
bible?: any;
|
||||
resolution?: string;
|
||||
prompt?: string;
|
||||
seed?: number;
|
||||
maskImageUrl?: string;
|
||||
}): Promise<{ taskId: string; status: string; message: string }> {
|
||||
const response = await aiApiClient.post("/api/podcast/render/video", {
|
||||
project_id: params.projectId,
|
||||
scene_id: params.sceneId,
|
||||
scene_title: params.sceneTitle,
|
||||
audio_url: params.audioUrl,
|
||||
avatar_image_url: params.avatarImageUrl,
|
||||
bible: params.bible,
|
||||
resolution: params.resolution || "720p",
|
||||
prompt: params.prompt,
|
||||
seed: params.seed ?? -1,
|
||||
mask_image_url: params.maskImageUrl,
|
||||
});
|
||||
|
||||
// Backend returns snake_case (task_id); normalize to camelCase for callers
|
||||
const { task_id, status, message } = response.data || {};
|
||||
return {
|
||||
taskId: task_id,
|
||||
status,
|
||||
message,
|
||||
};
|
||||
},
|
||||
|
||||
async pollTaskStatus(taskId: string): Promise<TaskStatus | null> {
|
||||
const response = await aiApiClient.get(`/api/podcast/task/${taskId}/status`);
|
||||
// Backend returns null if task not found
|
||||
return response.data || null;
|
||||
},
|
||||
|
||||
async listVideos(projectId?: string): Promise<{
|
||||
videos: Array<{
|
||||
scene_number: number;
|
||||
filename: string;
|
||||
video_url: string;
|
||||
file_size: number;
|
||||
}>;
|
||||
}> {
|
||||
const params = projectId ? { project_id: projectId } : {};
|
||||
const response = await aiApiClient.get("/api/podcast/videos", { params });
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async combineVideos(params: {
|
||||
projectId: string;
|
||||
sceneVideoUrls: string[];
|
||||
podcastTitle?: string;
|
||||
}): Promise<{
|
||||
taskId: string;
|
||||
status: string;
|
||||
message: string;
|
||||
}> {
|
||||
const response = await aiApiClient.post("/api/podcast/render/combine-videos", {
|
||||
project_id: params.projectId,
|
||||
scene_video_urls: params.sceneVideoUrls,
|
||||
podcast_title: params.podcastTitle || "Podcast",
|
||||
});
|
||||
|
||||
const { task_id, status, message } = response.data || {};
|
||||
return {
|
||||
taskId: task_id,
|
||||
status,
|
||||
message,
|
||||
};
|
||||
},
|
||||
|
||||
async generateSceneImage(params: {
|
||||
sceneId: string;
|
||||
sceneTitle: string;
|
||||
sceneContent?: string;
|
||||
baseAvatarUrl?: string;
|
||||
bible?: any;
|
||||
idea?: string;
|
||||
width?: number;
|
||||
height?: number;
|
||||
customPrompt?: string;
|
||||
style?: "Auto" | "Fiction" | "Realistic";
|
||||
renderingSpeed?: "Default" | "Turbo" | "Quality";
|
||||
aspectRatio?: "1:1" | "16:9" | "9:16" | "4:3" | "3:4";
|
||||
}): Promise<{
|
||||
scene_id: string;
|
||||
scene_title: string;
|
||||
image_filename: string;
|
||||
image_url: string;
|
||||
width: number;
|
||||
height: number;
|
||||
provider: string;
|
||||
model?: string;
|
||||
cost: number;
|
||||
}> {
|
||||
const response = await aiApiClient.post("/api/podcast/image", {
|
||||
scene_id: params.sceneId,
|
||||
scene_title: params.sceneTitle,
|
||||
scene_content: params.sceneContent,
|
||||
base_avatar_url: params.baseAvatarUrl || null,
|
||||
bible: params.bible,
|
||||
idea: params.idea || null,
|
||||
width: params.width || 1024,
|
||||
height: params.height || 1024,
|
||||
custom_prompt: params.customPrompt || null,
|
||||
style: params.style || null,
|
||||
rendering_speed: params.renderingSpeed || null,
|
||||
aspect_ratio: params.aspectRatio || null,
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async cancelTask(taskId: string): Promise<void> {
|
||||
// Note: Task cancellation may not be fully supported by backend yet
|
||||
// This is a placeholder for future implementation
|
||||
try {
|
||||
await aiApiClient.post(`/api/story/task/${taskId}/cancel`);
|
||||
} catch (error) {
|
||||
console.warn("Task cancellation not supported:", error);
|
||||
}
|
||||
},
|
||||
|
||||
async combineAudio(params: {
|
||||
projectId: string;
|
||||
sceneIds: string[];
|
||||
sceneAudioUrls: string[];
|
||||
}): Promise<{
|
||||
combined_audio_url: string;
|
||||
combined_audio_filename: string;
|
||||
total_duration: number;
|
||||
file_size: number;
|
||||
scene_count: number;
|
||||
}> {
|
||||
const response = await aiApiClient.post("/api/podcast/combine-audio", {
|
||||
project_id: params.projectId,
|
||||
scene_ids: params.sceneIds,
|
||||
scene_audio_urls: params.sceneAudioUrls,
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async uploadAvatar(file: File, projectId?: string): Promise<{ avatar_url: string; avatar_filename: string }> {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
if (projectId) {
|
||||
formData.append('project_id', projectId);
|
||||
}
|
||||
const response = await aiApiClient.post('/api/podcast/avatar/upload', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async generatePresenters(
|
||||
speakers: number,
|
||||
projectId?: string,
|
||||
audience?: string,
|
||||
contentType?: string,
|
||||
topKeywords?: string[]
|
||||
): Promise<{
|
||||
avatars: Array<{ avatar_url: string; speaker_number: number; prompt?: string; persona_id?: string; seed?: number }>;
|
||||
persona_id?: string;
|
||||
}> {
|
||||
const formData = new FormData();
|
||||
formData.append('speakers', speakers.toString());
|
||||
if (projectId) {
|
||||
formData.append('project_id', projectId);
|
||||
}
|
||||
if (audience) {
|
||||
formData.append('audience', audience);
|
||||
}
|
||||
if (contentType) {
|
||||
formData.append('content_type', contentType);
|
||||
}
|
||||
if (topKeywords && Array.isArray(topKeywords) && topKeywords.length > 0) {
|
||||
formData.append('top_keywords', JSON.stringify(topKeywords));
|
||||
}
|
||||
const response = await aiApiClient.post('/api/podcast/avatar/generate', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
|
||||
async makeAvatarPresentable(avatarUrl: string, projectId?: string): Promise<{ avatar_url: string; avatar_filename: string }> {
|
||||
const formData = new FormData();
|
||||
formData.append('avatar_url', avatarUrl);
|
||||
if (projectId) {
|
||||
formData.append('project_id', projectId);
|
||||
}
|
||||
const response = await aiApiClient.post('/api/podcast/avatar/make-presentable', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
});
|
||||
return response.data;
|
||||
},
|
||||
};
|
||||
|
||||
export type PodcastApi = typeof podcastApi;
|
||||
|
||||
244
_session_backup/research.py
Normal file
244
_session_backup/research.py
Normal file
@@ -0,0 +1,244 @@
|
||||
"""
|
||||
Podcast Research Handlers
|
||||
|
||||
Research endpoints using Exa provider and LLM summarization.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List
|
||||
from types import SimpleNamespace
|
||||
import json
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from loguru import logger
|
||||
from ..models import (
|
||||
PodcastExaResearchRequest,
|
||||
PodcastExaResearchResponse,
|
||||
PodcastExaSource,
|
||||
PodcastExaConfig,
|
||||
PodcastResearchInsight,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/research/exa", response_model=PodcastExaResearchResponse)
|
||||
async def podcast_research_exa(
|
||||
request: PodcastExaResearchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Run podcast research via Exa and then use LLM to extract deep insights.
|
||||
Uses Podcast Bible and Analysis context for hyper-personalization.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
queries = [q.strip() for q in request.queries if q and q.strip()]
|
||||
if not queries:
|
||||
raise HTTPException(status_code=400, detail="At least one query is required for research.")
|
||||
|
||||
exa_cfg = request.exa_config or PodcastExaConfig()
|
||||
cfg = SimpleNamespace(
|
||||
exa_search_type=exa_cfg.exa_search_type or "auto",
|
||||
exa_category=exa_cfg.exa_category,
|
||||
exa_include_domains=exa_cfg.exa_include_domains or [],
|
||||
exa_exclude_domains=exa_cfg.exa_exclude_domains or [],
|
||||
max_sources=exa_cfg.max_sources or 8,
|
||||
source_types=[],
|
||||
)
|
||||
|
||||
provider = ExaResearchProvider()
|
||||
|
||||
# --- Context Building ---
|
||||
bible_service = PodcastBibleService()
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
try:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Research] Failed to serialize bible: {exc}")
|
||||
|
||||
analysis_context = ""
|
||||
if request.analysis:
|
||||
analysis_context = f"""
|
||||
PODCAST ANALYSIS CONTEXT:
|
||||
Audience: {request.analysis.get('audience', 'General')}
|
||||
Content Type: {request.analysis.get('content_type', 'Informative')}
|
||||
Top Keywords: {', '.join(request.analysis.get('top_keywords', []))}
|
||||
"""
|
||||
|
||||
# Exa search params
|
||||
industry = request.bible.get("brand", {}).get("industry", "") if request.bible else ""
|
||||
target_audience = ""
|
||||
if request.bible:
|
||||
audience_dna = request.bible.get("audience", {})
|
||||
if audience_dna:
|
||||
interests = ", ".join(audience_dna.get("interests", []))
|
||||
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
|
||||
|
||||
try:
|
||||
# 1. RUN EXA SEARCH
|
||||
result = await provider.search(
|
||||
prompt=request.topic,
|
||||
topic=request.topic,
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
config=cfg,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Exa Research] Search failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Exa research failed: {exc}")
|
||||
|
||||
# 2. EXTRACT INSIGHTS VIA LLM
|
||||
raw_content = result.get("content", "")
|
||||
sources = result.get("sources", [])
|
||||
|
||||
summary = ""
|
||||
key_insights = []
|
||||
expert_quotes = []
|
||||
listener_cta = []
|
||||
mapped_angles = []
|
||||
|
||||
if raw_content and sources:
|
||||
logger.info(f"[Podcast Research] Extracting insights from {len(sources)} sources for user {user_id}")
|
||||
|
||||
prompt = f"""
|
||||
You are an expert research analyst for a high-end podcast production team.
|
||||
Your task is to analyze the following research data and extract deep, actionable insights for a podcast episode.
|
||||
|
||||
PODCAST CONTEXT:
|
||||
Topic: {request.topic}
|
||||
{bible_context}
|
||||
{analysis_context}
|
||||
|
||||
RESEARCH DATA (from {len(sources)} sources):
|
||||
{raw_content}
|
||||
|
||||
TASK:
|
||||
1. Provide a comprehensive summary (2-3 paragraphs) of the most important findings. Use Markdown for formatting (bolding, lists).
|
||||
2. Extract 3-5 "Key Insights". Each insight should have a title and a detailed explanation.
|
||||
3. For each insight, identify which source indices (e.g. 1, 2) it was derived from.
|
||||
4. Extract notable "Expert Quotes" - direct quotes from industry leaders, researchers, or authoritative voices found in the sources.
|
||||
5. Suggest 2-4 "Listener CTA" (call-to-action) ideas that the podcast host can use to engage the audience.
|
||||
6. Identify 3-5 "Mapped Angles" - unique content angles with rationale for why they matter for this topic.
|
||||
|
||||
NOTE: The research data includes "Key Highlights", "Summaries", and "Excerpts" from various sources.
|
||||
Pay special attention to the "Key Highlights" sections as they contain the most relevant information extracted by the neural search engine.
|
||||
|
||||
Return JSON structure:
|
||||
{{
|
||||
"summary": "Detailed markdown summary...",
|
||||
"key_insights": [
|
||||
{{
|
||||
"title": "Insight Title",
|
||||
"content": "Detailed markdown content...",
|
||||
"source_indices": [1, 2]
|
||||
}}
|
||||
],
|
||||
"expert_quotes": [
|
||||
{{
|
||||
"quote": "Exact quote from source...",
|
||||
"source_index": 1
|
||||
}}
|
||||
],
|
||||
"listener_cta": [
|
||||
"Call-to-action suggestion 1",
|
||||
"Call-to-action suggestion 2"
|
||||
],
|
||||
"mapped_angles": [
|
||||
{{
|
||||
"title": "Angle Title",
|
||||
"why": "Why this angle matters for the audience...",
|
||||
"mapped_fact_ids": ["fact_1", "fact_2"]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Requirements:
|
||||
- Ensure insights are deep, not just superficial facts. Look for trends, expert opinions, and specific data points.
|
||||
- Expert quotes should be exact or near-exact quotes from the sources, with attribution.
|
||||
- Listener CTAs should be practical and engaging (e.g., "Share your experience with X on social media").
|
||||
- Mapped angles should be unique perspectives that make the episode stand out.
|
||||
- Tone should be professional, insightful, and ready for a podcast host to discuss.
|
||||
- Avoid generic filler.
|
||||
"""
|
||||
try:
|
||||
llm_response = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
json_struct=None,
|
||||
preferred_provider="huggingface",
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
|
||||
# Normalize response
|
||||
if isinstance(llm_response, str):
|
||||
data = json.loads(llm_response)
|
||||
else:
|
||||
data = llm_response
|
||||
|
||||
summary = data.get("summary", "")
|
||||
key_insights = [PodcastResearchInsight(**insight) for insight in data.get("key_insights", [])]
|
||||
expert_quotes = data.get("expert_quotes", [])
|
||||
listener_cta = data.get("listener_cta", [])
|
||||
mapped_angles = data.get("mapped_angles", [])
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Research] LLM Insight extraction failed: {exc}")
|
||||
# Fallback to a basic summary if LLM fails
|
||||
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
|
||||
|
||||
# Fallback: if summary is still empty (e.g. LLM returned empty string), use raw content first paragraph or basic text
|
||||
if not summary:
|
||||
if raw_content:
|
||||
summary = raw_content[:2000] # Use first 2000 chars of raw content as summary
|
||||
else:
|
||||
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
|
||||
|
||||
# 3. TRACK USAGE
|
||||
try:
|
||||
cost_total = 0.0
|
||||
if isinstance(result, dict):
|
||||
cost_total = result.get("cost", {}).get("total", 0.005) if result.get("cost") else 0.005
|
||||
provider.track_exa_usage(user_id, cost_total)
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[Podcast Exa Research] Failed to track usage: {track_err}")
|
||||
|
||||
sources_payload = []
|
||||
for src in sources:
|
||||
try:
|
||||
sources_payload.append(PodcastExaSource(**src))
|
||||
except Exception:
|
||||
sources_payload.append(PodcastExaSource(**{
|
||||
"title": src.get("title", ""),
|
||||
"url": src.get("url", ""),
|
||||
"excerpt": src.get("excerpt", ""),
|
||||
"published_at": src.get("published_at"),
|
||||
"highlights": src.get("highlights"),
|
||||
"summary": src.get("summary"),
|
||||
"source_type": src.get("source_type"),
|
||||
"index": src.get("index"),
|
||||
"image": src.get("image"),
|
||||
"author": src.get("author"),
|
||||
}))
|
||||
|
||||
return PodcastExaResearchResponse(
|
||||
sources=sources_payload,
|
||||
search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries,
|
||||
summary=summary,
|
||||
key_insights=key_insights,
|
||||
expert_quotes=expert_quotes,
|
||||
listener_cta=listener_cta,
|
||||
mapped_angles=mapped_angles,
|
||||
cost=result.get("cost") if isinstance(result, dict) else None,
|
||||
search_type=result.get("search_type") if isinstance(result, dict) else None,
|
||||
provider=result.get("provider", "exa") if isinstance(result, dict) else "exa",
|
||||
content=raw_content,
|
||||
)
|
||||
|
||||
183
_session_backup/script.py
Normal file
183
_session_backup/script.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""
|
||||
Podcast Script Handlers
|
||||
|
||||
Script generation endpoint.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
from loguru import logger
|
||||
from ..models import (
|
||||
PodcastScriptRequest,
|
||||
PodcastScriptResponse,
|
||||
PodcastScene,
|
||||
PodcastSceneLine,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/script", response_model=PodcastScriptResponse)
|
||||
async def generate_podcast_script(
|
||||
request: PodcastScriptRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Generate a podcast script outline (scenes + lines) using podcast-oriented prompting.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Build comprehensive research context for higher-quality scripts
|
||||
research_context = ""
|
||||
if request.research:
|
||||
try:
|
||||
key_insights = request.research.get("keyword_analysis", {}).get("key_insights") or []
|
||||
fact_cards = request.research.get("factCards", []) or []
|
||||
mapped_angles = request.research.get("mappedAngles", []) or []
|
||||
sources = request.research.get("sources", []) or []
|
||||
|
||||
top_facts = [f.get("quote", "") for f in fact_cards[:5] if f.get("quote")]
|
||||
angles_summary = [
|
||||
f"{a.get('title', '')}: {a.get('why', '')}" for a in mapped_angles[:3] if a.get("title") or a.get("why")
|
||||
]
|
||||
top_sources = [s.get("url") for s in sources[:3] if s.get("url")]
|
||||
|
||||
research_parts = []
|
||||
if key_insights:
|
||||
research_parts.append(f"Key Insights: {', '.join(key_insights[:5])}")
|
||||
if top_facts:
|
||||
research_parts.append(f"Key Facts: {', '.join(top_facts)}")
|
||||
if angles_summary:
|
||||
research_parts.append(f"Research Angles: {' | '.join(angles_summary)}")
|
||||
if top_sources:
|
||||
research_parts.append(f"Top Sources: {', '.join(top_sources)}")
|
||||
|
||||
research_context = "\n".join(research_parts)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to parse research context: {exc}")
|
||||
research_context = ""
|
||||
|
||||
# Extract Podcast Bible context for hyper-personalization
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
bible_obj = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to serialize podcast bible: {exc}")
|
||||
|
||||
# Extract Analysis and Outline context for grounding
|
||||
analysis_context = ""
|
||||
if request.analysis:
|
||||
analysis_context = f"""
|
||||
TARGET AUDIENCE: {request.analysis.get('audience', 'General')}
|
||||
CONTENT TYPE: {request.analysis.get('contentType', 'Conversational')}
|
||||
TOP KEYWORDS: {', '.join(request.analysis.get('topKeywords', []))}
|
||||
"""
|
||||
|
||||
outline_context = ""
|
||||
if request.outline:
|
||||
outline_context = f"""
|
||||
REFINED EPISODE OUTLINE (Follow this structure closely):
|
||||
Title: {request.outline.get('title', 'N/A')}
|
||||
Segments: {' | '.join(request.outline.get('segments', []))}
|
||||
"""
|
||||
|
||||
prompt = f"""You are an expert podcast script planner. Create natural, conversational podcast scenes.
|
||||
|
||||
{f"PODCAST BIBLE (Hyper-Personalization Context):\n{bible_context}\n" if bible_context else ""}
|
||||
{f"ANALYSIS CONTEXT:\n{analysis_context}\n" if analysis_context else ""}
|
||||
{f"REFINED OUTLINE:\n{outline_context}\n" if outline_context else ""}
|
||||
|
||||
Podcast Idea: "{request.idea}"
|
||||
Duration: ~{request.duration_minutes} minutes
|
||||
Speakers: {request.speakers} (Host + optional Guest)
|
||||
|
||||
{f"RESEARCH CONTEXT:\n{research_context}\n" if research_context else ""}
|
||||
|
||||
Return JSON with:
|
||||
- scenes: array of scenes. Each scene has:
|
||||
- id: string
|
||||
- title: short scene title (<= 60 chars)
|
||||
- duration: duration in seconds (evenly split across total duration)
|
||||
- emotion: string (one of: "neutral", "happy", "excited", "serious", "curious", "confident")
|
||||
- lines: array of {{"speaker": "...", "text": "...", "emphasis": boolean}}
|
||||
* Write natural, conversational dialogue
|
||||
* Each line can be a sentence or a few sentences that flow together
|
||||
* Use plain text only - no markdown formatting (no asterisks, underscores, etc.)
|
||||
* Mark "emphasis": true for key statistics or important points
|
||||
|
||||
Guidelines:
|
||||
- Write for spoken delivery: conversational, natural, with contractions.
|
||||
- Follow the interaction tone specified in the Bible.
|
||||
- Ensure the Host persona matches the background and personality traits from the Bible.
|
||||
- Structure the intro and outro scenes according to the Bible's "Intro Format" and "Outro Format".
|
||||
- Adhere to any constraints mentioned in the Bible.
|
||||
- Use insights from the Research Context to ground the conversation in facts.
|
||||
- IMPORTANT: Follow the REFINED OUTLINE segments as the primary structure for the episode.
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt,
|
||||
user_id=user_id,
|
||||
json_struct=None,
|
||||
preferred_provider="huggingface",
|
||||
flow_type="premium_tool",
|
||||
)
|
||||
except Exception as exc:
|
||||
raise HTTPException(status_code=500, detail=f"Script generation failed: {exc}")
|
||||
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
raise HTTPException(status_code=500, detail="LLM returned non-JSON output")
|
||||
elif isinstance(raw, dict):
|
||||
data = raw
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unexpected LLM response format")
|
||||
|
||||
scenes_data = data.get("scenes") or []
|
||||
if not isinstance(scenes_data, list):
|
||||
raise HTTPException(status_code=500, detail="LLM response missing scenes array")
|
||||
|
||||
valid_emotions = {"neutral", "happy", "excited", "serious", "curious", "confident"}
|
||||
|
||||
# Normalize scenes
|
||||
scenes: list[PodcastScene] = []
|
||||
for idx, scene in enumerate(scenes_data):
|
||||
title = scene.get("title") or f"Scene {idx + 1}"
|
||||
duration = int(scene.get("duration") or max(30, (request.duration_minutes * 60) // max(1, len(scenes_data))))
|
||||
emotion = scene.get("emotion") or "neutral"
|
||||
if emotion not in valid_emotions:
|
||||
emotion = "neutral"
|
||||
lines_raw = scene.get("lines") or []
|
||||
lines: list[PodcastSceneLine] = []
|
||||
for line in lines_raw:
|
||||
speaker = line.get("speaker") or ("Host" if len(lines) % request.speakers == 0 else "Guest")
|
||||
text = line.get("text") or ""
|
||||
emphasis = line.get("emphasis", False)
|
||||
if text:
|
||||
lines.append(PodcastSceneLine(speaker=speaker, text=text, emphasis=emphasis))
|
||||
scenes.append(
|
||||
PodcastScene(
|
||||
id=scene.get("id") or f"scene-{idx + 1}",
|
||||
title=title,
|
||||
duration=duration,
|
||||
lines=lines,
|
||||
approved=False,
|
||||
emotion=emotion,
|
||||
)
|
||||
)
|
||||
|
||||
return PodcastScriptResponse(scenes=scenes)
|
||||
|
||||
209
_session_backup/types.ts
Normal file
209
_session_backup/types.ts
Normal file
@@ -0,0 +1,209 @@
|
||||
export type Knobs = {
|
||||
voice_emotion: string;
|
||||
voice_speed: number;
|
||||
resolution: string;
|
||||
scene_length_target: number;
|
||||
sample_rate: number;
|
||||
bitrate: string;
|
||||
};
|
||||
|
||||
export type Query = {
|
||||
id: string;
|
||||
query: string;
|
||||
rationale: string;
|
||||
needsRecentStats: boolean;
|
||||
};
|
||||
|
||||
export type Fact = {
|
||||
id: string;
|
||||
quote: string;
|
||||
url: string;
|
||||
date: string;
|
||||
confidence: number;
|
||||
image?: string;
|
||||
author?: string;
|
||||
highlights?: string[];
|
||||
};
|
||||
|
||||
export type ResearchInsight = {
|
||||
title: string;
|
||||
content: string;
|
||||
source_indices: number[];
|
||||
};
|
||||
|
||||
export type Research = {
|
||||
summary: string;
|
||||
keyInsights: ResearchInsight[];
|
||||
factCards: Fact[];
|
||||
mappedAngles: {
|
||||
title: string;
|
||||
why: string;
|
||||
mappedFactIds: string[];
|
||||
}[];
|
||||
searchQueries?: string[];
|
||||
searchType?: string;
|
||||
provider?: string;
|
||||
cost?: number;
|
||||
sourceCount?: number;
|
||||
expertQuotes?: { quote: string; source_index: number }[];
|
||||
listenerCta?: string[];
|
||||
};
|
||||
|
||||
export type Line = {
|
||||
id: string;
|
||||
speaker: string;
|
||||
text: string;
|
||||
usedFactIds?: string[];
|
||||
emphasis?: boolean; // Mark lines that need vocal emphasis
|
||||
};
|
||||
|
||||
export type Scene = {
|
||||
id: string;
|
||||
title: string;
|
||||
duration: number;
|
||||
lines: Line[];
|
||||
approved?: boolean;
|
||||
emotion?: string; // Scene-specific emotion
|
||||
audioUrl?: string; // Generated audio URL for this scene
|
||||
imageUrl?: string; // Generated image URL for this scene (for video generation)
|
||||
};
|
||||
|
||||
export type Script = {
|
||||
scenes: Scene[];
|
||||
};
|
||||
|
||||
export type JobStatus =
|
||||
| "idle"
|
||||
| "previewing"
|
||||
| "queued"
|
||||
| "running"
|
||||
| "completed"
|
||||
| "cancelled"
|
||||
| "failed";
|
||||
|
||||
export type Job = {
|
||||
sceneId: string;
|
||||
title: string;
|
||||
status: JobStatus;
|
||||
progress: number;
|
||||
previewUrl?: string | null;
|
||||
finalUrl?: string | null;
|
||||
videoUrl?: string | null;
|
||||
jobId?: string | null;
|
||||
taskId?: string | null;
|
||||
cost?: number | null;
|
||||
provider?: string | null;
|
||||
voiceId?: string | null;
|
||||
fileSize?: number | null;
|
||||
avatarImageUrl?: string | null;
|
||||
imageUrl?: string | null; // Scene-specific image URL
|
||||
};
|
||||
|
||||
export type PodcastAnalysis = {
|
||||
audience: string;
|
||||
contentType: string;
|
||||
topKeywords: string[];
|
||||
suggestedOutlines: { id: number | string; title: string; segments: string[] }[];
|
||||
suggestedKnobs: Knobs;
|
||||
titleSuggestions: string[];
|
||||
research_queries?: { query: string; rationale: string }[];
|
||||
exaSuggestedConfig?: {
|
||||
exa_search_type?: "auto" | "keyword" | "neural";
|
||||
exa_category?: string;
|
||||
exa_include_domains?: string[];
|
||||
exa_exclude_domains?: string[];
|
||||
max_sources?: number;
|
||||
include_statistics?: boolean;
|
||||
date_range?: string;
|
||||
};
|
||||
};
|
||||
|
||||
export type PodcastEstimate = {
|
||||
ttsCost: number;
|
||||
avatarCost: number;
|
||||
videoCost: number;
|
||||
researchCost: number;
|
||||
total: number;
|
||||
};
|
||||
|
||||
export type HostPersona = {
|
||||
name: string;
|
||||
background: string;
|
||||
expertise_level: string;
|
||||
personality_traits: string[];
|
||||
vocal_style: string;
|
||||
catchphrases: string[];
|
||||
};
|
||||
|
||||
export type AudienceDNA = {
|
||||
expertise_level: string;
|
||||
interests: string[];
|
||||
pain_points: string[];
|
||||
demographics?: string;
|
||||
};
|
||||
|
||||
export type BrandDNA = {
|
||||
industry: string;
|
||||
tone: string;
|
||||
communication_style: string;
|
||||
key_messages: string[];
|
||||
competitor_context?: string;
|
||||
};
|
||||
|
||||
export type PodcastBible = {
|
||||
project_id?: string;
|
||||
host: HostPersona;
|
||||
audience: AudienceDNA;
|
||||
brand: BrandDNA;
|
||||
};
|
||||
|
||||
export type CreateProjectPayload = {
|
||||
ideaOrUrl: string;
|
||||
speakers: number;
|
||||
duration: number;
|
||||
knobs: Knobs;
|
||||
budgetCap: number;
|
||||
files: { voiceFile?: File | null; avatarFile?: File | null };
|
||||
avatarUrl?: string | null;
|
||||
};
|
||||
|
||||
export type CreateProjectResult = {
|
||||
projectId: string;
|
||||
analysis: PodcastAnalysis;
|
||||
estimate: PodcastEstimate;
|
||||
queries: Query[];
|
||||
bible?: PodcastBible;
|
||||
avatar_url?: string | null;
|
||||
avatar_prompt?: string | null;
|
||||
};
|
||||
|
||||
export type RenderJobResult = {
|
||||
audioUrl: string;
|
||||
audioFilename: string;
|
||||
provider: string;
|
||||
model: string;
|
||||
cost: number;
|
||||
voiceId: string;
|
||||
fileSize: number;
|
||||
videoUrl?: string;
|
||||
videoFilename?: string;
|
||||
};
|
||||
|
||||
export interface VideoGenerationSettings {
|
||||
prompt: string;
|
||||
resolution: "480p" | "720p";
|
||||
seed?: number | null;
|
||||
maskImageUrl?: string | null;
|
||||
}
|
||||
|
||||
export type TaskStatus = {
|
||||
task_id: string;
|
||||
status: "pending" | "processing" | "completed" | "failed";
|
||||
progress?: number;
|
||||
message?: string;
|
||||
result?: any;
|
||||
error?: string;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
};
|
||||
|
||||
425
_session_backup/usePodcastWorkflow.ts
Normal file
425
_session_backup/usePodcastWorkflow.ts
Normal file
@@ -0,0 +1,425 @@
|
||||
import { useState, useEffect, useMemo, useCallback } from "react";
|
||||
import { podcastApi } from "../../../services/podcastApi";
|
||||
import { usePreflightCheck } from "../../../hooks/usePreflightCheck";
|
||||
import { useBudgetTracking } from "../../../hooks/useBudgetTracking";
|
||||
import { CreateProjectPayload, Script } from "../types";
|
||||
import { usePodcastProjectState } from "../../../hooks/usePodcastProjectState";
|
||||
import { sanitizeExaConfig, announceError, getStepLabel } from "./utils";
|
||||
|
||||
type PodcastProjectStateReturn = ReturnType<typeof usePodcastProjectState>;
|
||||
|
||||
interface UsePodcastWorkflowProps {
|
||||
projectState: PodcastProjectStateReturn;
|
||||
onError: (message: string) => void;
|
||||
}
|
||||
|
||||
export const usePodcastWorkflow = ({ projectState, onError }: UsePodcastWorkflowProps) => {
|
||||
const {
|
||||
project,
|
||||
analysis,
|
||||
queries,
|
||||
selectedQueries,
|
||||
research,
|
||||
rawResearch,
|
||||
researchProvider,
|
||||
showScriptEditor,
|
||||
showRenderQueue,
|
||||
currentStep,
|
||||
renderJobs,
|
||||
budgetCap,
|
||||
setProject,
|
||||
setAnalysis,
|
||||
setQueries,
|
||||
setSelectedQueries,
|
||||
setResearch,
|
||||
setRawResearch,
|
||||
setEstimate,
|
||||
setScriptData,
|
||||
setShowScriptEditor,
|
||||
setShowRenderQueue,
|
||||
setKnobs,
|
||||
setResearchProvider,
|
||||
setBudgetCap,
|
||||
updateRenderJob,
|
||||
initializeProject,
|
||||
setBible,
|
||||
} = projectState;
|
||||
|
||||
const [isAnalyzing, setIsAnalyzing] = useState(false);
|
||||
const [isResearching, setIsResearching] = useState(false);
|
||||
const [announcement, setAnnouncement] = useState("");
|
||||
const [showResumeAlert, setShowResumeAlert] = useState(false);
|
||||
const [showPreflightDialog, setShowPreflightDialog] = useState(false);
|
||||
const [preflightResponse, setPreflightResponse] = useState<any>(null);
|
||||
const [preflightOperationName, setPreflightOperationName] = useState<string>("");
|
||||
|
||||
const budgetTracking = useBudgetTracking(budgetCap || 50);
|
||||
const preflightCheck = usePreflightCheck({
|
||||
onBlocked: (response) => {
|
||||
setPreflightResponse(response);
|
||||
setShowPreflightDialog(true);
|
||||
},
|
||||
});
|
||||
|
||||
// Update budget cap when project state changes
|
||||
useEffect(() => {
|
||||
if (budgetCap) {
|
||||
budgetTracking.setBudgetCap(budgetCap);
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [budgetCap]);
|
||||
|
||||
// Check if we have a saved project on mount
|
||||
useEffect(() => {
|
||||
if (project && currentStep && currentStep !== "create") {
|
||||
setShowResumeAlert(true);
|
||||
setTimeout(() => setShowResumeAlert(false), 5000);
|
||||
}
|
||||
}, [project, currentStep]);
|
||||
|
||||
useEffect(() => {
|
||||
if (announcement) {
|
||||
const t = setTimeout(() => setAnnouncement(""), 4000);
|
||||
return () => clearTimeout(t);
|
||||
}
|
||||
return undefined;
|
||||
}, [announcement]);
|
||||
|
||||
const handleCreate = useCallback(async (payload: CreateProjectPayload, feedback?: string) => {
|
||||
if (isAnalyzing) return;
|
||||
setResearch(null);
|
||||
setRawResearch(null);
|
||||
setScriptData(null);
|
||||
setShowScriptEditor(false);
|
||||
setShowRenderQueue(false);
|
||||
try {
|
||||
setIsAnalyzing(true);
|
||||
|
||||
// Use existing avatar URL if provided (e.g. brand avatar), or upload new file
|
||||
let avatarUrl: string | null = payload.avatarUrl || null;
|
||||
if (payload.files.avatarFile) {
|
||||
try {
|
||||
setAnnouncement("Uploading presenter avatar...");
|
||||
const uploadResponse = await podcastApi.uploadAvatar(payload.files.avatarFile);
|
||||
avatarUrl = uploadResponse.avatar_url;
|
||||
} catch (error) {
|
||||
console.error('Avatar upload failed:', error);
|
||||
// Continue without avatar - will generate one later
|
||||
}
|
||||
}
|
||||
|
||||
// NEW FLOW: Create project first to generate/get the Podcast Bible
|
||||
// This allows the analysis to be personalized using the Bible context
|
||||
const projectId = project?.id || `podcast_${Date.now()}_${Math.floor(Math.random() * 1000)}`;
|
||||
setAnnouncement("Initializing project and brand context...");
|
||||
const dbProject = project ? null : await initializeProject(payload, projectId, avatarUrl);
|
||||
const bible = dbProject?.bible || projectState.bible;
|
||||
|
||||
setAnnouncement(feedback ? "Regenerating analysis using your feedback..." : "Analyzing your idea — AI suggestions incoming");
|
||||
const result = await podcastApi.createProject(payload, bible, feedback);
|
||||
|
||||
if (result.bible) {
|
||||
setBible(result.bible);
|
||||
} else if (dbProject?.bible) {
|
||||
setBible(dbProject.bible);
|
||||
}
|
||||
|
||||
// Update the project in database with the analysis results
|
||||
try {
|
||||
await podcastApi.updateProject(projectId, {
|
||||
analysis: result.analysis,
|
||||
estimate: result.estimate,
|
||||
queries: result.queries,
|
||||
selected_queries: result.queries.map(q => q.id),
|
||||
avatar_url: result.avatar_url,
|
||||
avatar_prompt: result.avatar_prompt,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error('Failed to update project with analysis results:', error);
|
||||
}
|
||||
|
||||
setProject({
|
||||
id: projectId,
|
||||
idea: payload.ideaOrUrl,
|
||||
duration: payload.duration,
|
||||
speakers: payload.speakers,
|
||||
avatarUrl: result.avatar_url || avatarUrl,
|
||||
avatarPrompt: result.avatar_prompt || null,
|
||||
avatarPersonaId: null,
|
||||
});
|
||||
|
||||
setAnalysis(result.analysis);
|
||||
setEstimate(result.estimate);
|
||||
setQueries(result.queries);
|
||||
setSelectedQueries(new Set(result.queries.map((q) => q.id)));
|
||||
setKnobs(payload.knobs);
|
||||
setBudgetCap(payload.budgetCap);
|
||||
|
||||
// Generate presenters AFTER analysis completes (to use analysis insights)
|
||||
// This happens only if no avatar was uploaded
|
||||
if (!avatarUrl && payload.speakers > 0 && result.analysis) {
|
||||
try {
|
||||
setAnnouncement("Generating presenter avatars using AI insights...");
|
||||
const presentersResponse = await podcastApi.generatePresenters(
|
||||
payload.speakers,
|
||||
result.projectId,
|
||||
result.analysis.audience,
|
||||
result.analysis.contentType,
|
||||
result.analysis.topKeywords
|
||||
);
|
||||
if (presentersResponse.avatars && presentersResponse.avatars.length > 0) {
|
||||
// Store the first presenter avatar URL and prompt
|
||||
const firstAvatar = presentersResponse.avatars[0];
|
||||
const prompt = firstAvatar.prompt || null;
|
||||
setProject({
|
||||
id: result.projectId,
|
||||
idea: payload.ideaOrUrl,
|
||||
duration: payload.duration,
|
||||
speakers: payload.speakers,
|
||||
avatarUrl: firstAvatar.avatar_url,
|
||||
avatarPrompt: prompt,
|
||||
avatarPersonaId: firstAvatar.persona_id || presentersResponse.persona_id || null,
|
||||
});
|
||||
setAnnouncement("Analysis complete - Presenter avatars generated");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Presenter generation failed:', error);
|
||||
setAnnouncement("Analysis complete - Avatar generation will happen later");
|
||||
// Continue without presenters - can generate later
|
||||
}
|
||||
} else {
|
||||
setAnnouncement("Analysis complete");
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (error?.response?.status === 429 || error?.response?.data?.detail) {
|
||||
const errorDetail = error.response.data.detail;
|
||||
if (typeof errorDetail === 'object' && errorDetail.error && errorDetail.error.includes('limit')) {
|
||||
const usageInfo = errorDetail.usage_info || {};
|
||||
const blockedResponse = {
|
||||
can_proceed: false,
|
||||
estimated_cost: 0,
|
||||
operations: [{
|
||||
provider: errorDetail.provider || 'huggingface',
|
||||
operation_type: 'ai_text_generation',
|
||||
cost: 0,
|
||||
allowed: false,
|
||||
limit_info: usageInfo.limit_info || null,
|
||||
message: errorDetail.message || errorDetail.error || 'Subscription limit exceeded',
|
||||
}],
|
||||
total_cost: 0,
|
||||
usage_summary: usageInfo.usage_summary || null,
|
||||
cached: false,
|
||||
};
|
||||
setPreflightResponse(blockedResponse);
|
||||
setPreflightOperationName('Podcast Analysis');
|
||||
setShowPreflightDialog(true);
|
||||
setAnnouncement("Subscription limit reached. Please upgrade to continue.");
|
||||
} else {
|
||||
const message = typeof errorDetail === 'string' ? errorDetail : errorDetail.message || errorDetail.error || 'Request limit exceeded';
|
||||
announceError(setAnnouncement, new Error(message));
|
||||
}
|
||||
} else {
|
||||
announceError(setAnnouncement, error);
|
||||
}
|
||||
} finally {
|
||||
setIsAnalyzing(false);
|
||||
}
|
||||
}, [isAnalyzing, setResearch, setRawResearch, setScriptData, setShowScriptEditor, setShowRenderQueue, initializeProject, setProject, setAnalysis, setEstimate, setQueries, setSelectedQueries, setKnobs, setBudgetCap, setBible]);
|
||||
|
||||
const handleRunResearch = useCallback(async () => {
|
||||
if (isResearching) return;
|
||||
if (!project) {
|
||||
setAnnouncement("Create a project first.");
|
||||
return;
|
||||
}
|
||||
if (selectedQueries.size === 0) {
|
||||
setAnnouncement("Select at least one query to research.");
|
||||
return;
|
||||
}
|
||||
|
||||
setPreflightOperationName("Research");
|
||||
const approvedQueries = queries.filter((q) => selectedQueries.has(q.id));
|
||||
const preflightResult = await preflightCheck.check({
|
||||
provider: researchProvider === "exa" ? "exa" : "gemini",
|
||||
operation_type: researchProvider === "exa" ? "exa_neural_search" : "google_grounding",
|
||||
tokens_requested: researchProvider === "exa" ? 0 : 1200,
|
||||
actual_provider_name: researchProvider || "exa",
|
||||
});
|
||||
|
||||
if (!preflightResult.can_proceed) {
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setIsResearching(true);
|
||||
setAnnouncement(`Starting ${researchProvider === "exa" ? "deep" : "standard"} research — this may take a moment...`);
|
||||
setResearch(null);
|
||||
setRawResearch(null);
|
||||
setScriptData(null);
|
||||
setShowScriptEditor(false);
|
||||
setShowRenderQueue(false);
|
||||
|
||||
try {
|
||||
const { research: mapped, raw } = await podcastApi.runResearch({
|
||||
projectId: project.id,
|
||||
topic: project.idea,
|
||||
approvedQueries,
|
||||
provider: researchProvider,
|
||||
exaConfig: sanitizeExaConfig(analysis?.exaSuggestedConfig),
|
||||
bible: projectState.bible,
|
||||
analysis: analysis,
|
||||
onProgress: (message) => {
|
||||
setAnnouncement(message);
|
||||
},
|
||||
});
|
||||
setResearch(mapped);
|
||||
setRawResearch(raw);
|
||||
setAnnouncement("Research complete — review fact cards below");
|
||||
} catch (researchError) {
|
||||
const errorMessage = researchError instanceof Error
|
||||
? researchError.message
|
||||
: "Research failed. Please try again or switch to Standard Research.";
|
||||
|
||||
if (errorMessage.includes("Exa") || errorMessage.includes("exa")) {
|
||||
setAnnouncement(`Deep research failed: ${errorMessage}. Try Standard Research instead.`);
|
||||
} else if (errorMessage.includes("timeout")) {
|
||||
setAnnouncement("Research timed out. Please try again with fewer queries.");
|
||||
} else {
|
||||
setAnnouncement(`Research failed: ${errorMessage}`);
|
||||
}
|
||||
|
||||
console.error("Research error:", researchError);
|
||||
throw researchError;
|
||||
}
|
||||
} catch (error) {
|
||||
announceError(setAnnouncement, error);
|
||||
} finally {
|
||||
setIsResearching(false);
|
||||
}
|
||||
}, [isResearching, project, selectedQueries, queries, researchProvider, preflightCheck, analysis, setResearch, setRawResearch, setScriptData, setShowScriptEditor, setShowRenderQueue, projectState.bible]);
|
||||
|
||||
const handleGenerateScript = useCallback(async () => {
|
||||
if (showScriptEditor) return;
|
||||
if (!project || !research) {
|
||||
setAnnouncement("Project or research missing — cannot generate script");
|
||||
return;
|
||||
}
|
||||
|
||||
setPreflightOperationName("Script Generation");
|
||||
const preflightResult = await preflightCheck.check({
|
||||
provider: "gemini",
|
||||
operation_type: "script_generation",
|
||||
tokens_requested: 2000,
|
||||
actual_provider_name: "gemini",
|
||||
});
|
||||
|
||||
if (!preflightResult.can_proceed) {
|
||||
return;
|
||||
}
|
||||
|
||||
setScriptData(null);
|
||||
setShowRenderQueue(false);
|
||||
setShowScriptEditor(true);
|
||||
|
||||
try {
|
||||
const result = await podcastApi.generateScript({
|
||||
projectId: project.id,
|
||||
idea: project.idea,
|
||||
research: rawResearch,
|
||||
knobs: projectState.knobs,
|
||||
speakers: project.speakers,
|
||||
durationMinutes: project.duration,
|
||||
bible: projectState.bible,
|
||||
outline: analysis?.suggestedOutlines?.[0], // Pass the first (possibly refined) outline
|
||||
analysis: analysis, // Pass full analysis context
|
||||
});
|
||||
|
||||
setScriptData(result);
|
||||
} catch (error) {
|
||||
announceError(setAnnouncement, error);
|
||||
}
|
||||
}, [showScriptEditor, project, research, preflightCheck, setScriptData, setShowRenderQueue, setShowScriptEditor, rawResearch, projectState.knobs, projectState.bible])
|
||||
|
||||
const handleProceedToRendering = useCallback((script: Script) => {
|
||||
setScriptData(script);
|
||||
if (renderJobs.length === 0) {
|
||||
script.scenes.forEach((scene) => {
|
||||
const hasExistingAudio = Boolean(scene.audioUrl);
|
||||
updateRenderJob(scene.id, {
|
||||
sceneId: scene.id,
|
||||
title: scene.title,
|
||||
status: hasExistingAudio ? ("completed" as const) : ("idle" as const),
|
||||
progress: hasExistingAudio ? 100 : 0,
|
||||
previewUrl: null,
|
||||
finalUrl: hasExistingAudio ? scene.audioUrl : null,
|
||||
jobId: null,
|
||||
});
|
||||
});
|
||||
}
|
||||
setShowRenderQueue(true);
|
||||
setShowScriptEditor(false);
|
||||
}, [renderJobs.length, setScriptData, updateRenderJob, setShowRenderQueue, setShowScriptEditor]);
|
||||
|
||||
const toggleQuery = useCallback((id: string) => {
|
||||
if (isResearching) return;
|
||||
const current = selectedQueries;
|
||||
const next = new Set<string>(current);
|
||||
if (next.has(id)) next.delete(id);
|
||||
else next.add(id);
|
||||
setSelectedQueries(next);
|
||||
}, [isResearching, selectedQueries, setSelectedQueries]);
|
||||
|
||||
const activeStep = useMemo(() => {
|
||||
if (showRenderQueue) return 3;
|
||||
if (showScriptEditor) return 2;
|
||||
if (currentStep === 'research' || research) return 1;
|
||||
if (currentStep === 'analysis' || analysis) return 0;
|
||||
return -1;
|
||||
}, [showRenderQueue, showScriptEditor, currentStep, research, analysis]);
|
||||
|
||||
const canGenerateScript = Boolean(project && research && rawResearch);
|
||||
|
||||
const handleRegenerate = useCallback(async (feedback?: string) => {
|
||||
if (!project) return;
|
||||
|
||||
// Prepare the payload from existing project state
|
||||
const payload: CreateProjectPayload = {
|
||||
ideaOrUrl: project.idea,
|
||||
duration: project.duration,
|
||||
speakers: project.speakers,
|
||||
knobs: projectState.knobs,
|
||||
budgetCap: projectState.budgetCap,
|
||||
avatarUrl: project.avatarUrl,
|
||||
files: {} // No new files for regeneration
|
||||
};
|
||||
|
||||
await handleCreate(payload, feedback);
|
||||
}, [project, projectState.knobs, projectState.budgetCap, handleCreate]);
|
||||
|
||||
return {
|
||||
// State
|
||||
isAnalyzing,
|
||||
isResearching,
|
||||
announcement,
|
||||
showResumeAlert,
|
||||
showPreflightDialog,
|
||||
preflightResponse,
|
||||
preflightOperationName,
|
||||
activeStep,
|
||||
canGenerateScript,
|
||||
// Handlers
|
||||
handleCreate,
|
||||
handleRegenerate,
|
||||
handleRunResearch,
|
||||
handleGenerateScript,
|
||||
handleProceedToRendering,
|
||||
toggleQuery,
|
||||
setAnnouncement,
|
||||
setShowResumeAlert,
|
||||
setShowPreflightDialog,
|
||||
setPreflightResponse,
|
||||
setResearchProvider,
|
||||
getStepLabel,
|
||||
};
|
||||
};
|
||||
|
||||
184
add_missing_columns.py
Normal file
184
add_missing_columns.py
Normal file
@@ -0,0 +1,184 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Migration script to add missing columns to usage_summaries table.
|
||||
Run this once to fix the database schema.
|
||||
|
||||
Usage:
|
||||
python add_missing_columns.py
|
||||
"""
|
||||
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
def get_db_path():
|
||||
"""Find the database path."""
|
||||
possible_paths = [
|
||||
Path(__file__).parent / "backend" / "alwrity.db",
|
||||
Path(__file__).parent.parent / "backend" / "alwrity.db",
|
||||
Path("C:/Users/diksha rawat/Desktop/ALwrity_github/windsurf/ALwrity/backend/alwrity.db"),
|
||||
]
|
||||
|
||||
for db_path in possible_paths:
|
||||
if db_path.exists():
|
||||
print(f"Using database: {db_path}")
|
||||
return db_path
|
||||
|
||||
backend_dir = Path(__file__).parent / "backend"
|
||||
if backend_dir.exists():
|
||||
db_files = list(backend_dir.glob("*.db"))
|
||||
if db_files:
|
||||
print(f"Found database: {db_files[0]}")
|
||||
return db_files[0]
|
||||
|
||||
raise FileNotFoundError(f"Database not found. Searched: {possible_paths}")
|
||||
|
||||
def create_usage_summaries_table(cursor):
|
||||
"""Create the usage_summaries table if it doesn't exist."""
|
||||
cursor.execute("""
|
||||
CREATE TABLE IF NOT EXISTS usage_summaries (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
user_id VARCHAR(100) NOT NULL,
|
||||
billing_period VARCHAR(20) NOT NULL,
|
||||
|
||||
-- API Call Counts
|
||||
gemini_calls INTEGER DEFAULT 0,
|
||||
openai_calls INTEGER DEFAULT 0,
|
||||
anthropic_calls INTEGER DEFAULT 0,
|
||||
mistral_calls INTEGER DEFAULT 0,
|
||||
wavespeed_calls INTEGER DEFAULT 0,
|
||||
tavily_calls INTEGER DEFAULT 0,
|
||||
serper_calls INTEGER DEFAULT 0,
|
||||
metaphor_calls INTEGER DEFAULT 0,
|
||||
firecrawl_calls INTEGER DEFAULT 0,
|
||||
stability_calls INTEGER DEFAULT 0,
|
||||
exa_calls INTEGER DEFAULT 0,
|
||||
video_calls INTEGER DEFAULT 0,
|
||||
image_edit_calls INTEGER DEFAULT 0,
|
||||
audio_calls INTEGER DEFAULT 0,
|
||||
|
||||
-- Token Usage
|
||||
gemini_tokens INTEGER DEFAULT 0,
|
||||
openai_tokens INTEGER DEFAULT 0,
|
||||
anthropic_tokens INTEGER DEFAULT 0,
|
||||
mistral_tokens INTEGER DEFAULT 0,
|
||||
wavespeed_tokens INTEGER DEFAULT 0,
|
||||
|
||||
-- Cost Tracking
|
||||
gemini_cost REAL DEFAULT 0.0,
|
||||
openai_cost REAL DEFAULT 0.0,
|
||||
anthropic_cost REAL DEFAULT 0.0,
|
||||
mistral_cost REAL DEFAULT 0.0,
|
||||
wavespeed_cost REAL DEFAULT 0.0,
|
||||
tavily_cost REAL DEFAULT 0.0,
|
||||
serper_cost REAL DEFAULT 0.0,
|
||||
metaphor_cost REAL DEFAULT 0.0,
|
||||
firecrawl_cost REAL DEFAULT 0.0,
|
||||
stability_cost REAL DEFAULT 0.0,
|
||||
exa_cost REAL DEFAULT 0.0,
|
||||
video_cost REAL DEFAULT 0.0,
|
||||
image_edit_cost REAL DEFAULT 0.0,
|
||||
audio_cost REAL DEFAULT 0.0,
|
||||
|
||||
-- Totals
|
||||
total_calls INTEGER DEFAULT 0,
|
||||
total_tokens INTEGER DEFAULT 0,
|
||||
total_cost REAL DEFAULT 0.0,
|
||||
|
||||
-- Performance Metrics
|
||||
avg_response_time REAL DEFAULT 0.0,
|
||||
error_rate REAL DEFAULT 0.0,
|
||||
usage_status VARCHAR(20) DEFAULT 'active',
|
||||
warnings_sent INTEGER DEFAULT 0,
|
||||
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
UNIQUE(user_id, billing_period)
|
||||
)
|
||||
""")
|
||||
print("Created usage_summaries table")
|
||||
|
||||
def add_missing_columns():
|
||||
db_path = get_db_path()
|
||||
print(f"Using database: {db_path}")
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Check what tables exist
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
print(f"Tables in database: {tables}")
|
||||
|
||||
# Check if usage_summaries exists
|
||||
if "usage_summaries" not in tables:
|
||||
print("usage_summaries table doesn't exist. Creating it...")
|
||||
create_usage_summaries_table(cursor)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
print("Done! Table created successfully.")
|
||||
return
|
||||
|
||||
# Get existing columns
|
||||
cursor.execute("PRAGMA table_info(usage_summaries)")
|
||||
existing_columns = {row[1] for row in cursor.fetchall()}
|
||||
print(f"Existing columns in usage_summaries: {len(existing_columns)}")
|
||||
|
||||
# Columns to add (name, type, default)
|
||||
columns_to_add = [
|
||||
# Call counts
|
||||
("wavespeed_calls", "INTEGER", "0"),
|
||||
("tavily_calls", "INTEGER", "0"),
|
||||
("serper_calls", "INTEGER", "0"),
|
||||
("metaphor_calls", "INTEGER", "0"),
|
||||
("firecrawl_calls", "INTEGER", "0"),
|
||||
("stability_calls", "INTEGER", "0"),
|
||||
("exa_calls", "INTEGER", "0"),
|
||||
("video_calls", "INTEGER", "0"),
|
||||
("image_edit_calls", "INTEGER", "0"),
|
||||
("audio_calls", "INTEGER", "0"),
|
||||
# Token usage
|
||||
("wavespeed_tokens", "INTEGER", "0"),
|
||||
# Cost tracking
|
||||
("wavespeed_cost", "REAL", "0.0"),
|
||||
("tavily_cost", "REAL", "0.0"),
|
||||
("serper_cost", "REAL", "0.0"),
|
||||
("metaphor_cost", "REAL", "0.0"),
|
||||
("firecrawl_cost", "REAL", "0.0"),
|
||||
("stability_cost", "REAL", "0.0"),
|
||||
("exa_cost", "REAL", "0.0"),
|
||||
("video_cost", "REAL", "0.0"),
|
||||
("image_edit_cost", "REAL", "0.0"),
|
||||
("audio_cost", "REAL", "0.0"),
|
||||
]
|
||||
|
||||
added = []
|
||||
skipped = []
|
||||
|
||||
for col_name, col_type, default in columns_to_add:
|
||||
if col_name in existing_columns:
|
||||
skipped.append(col_name)
|
||||
continue
|
||||
|
||||
try:
|
||||
sql = f"ALTER TABLE usage_summaries ADD COLUMN {col_name} {col_type} DEFAULT {default}"
|
||||
cursor.execute(sql)
|
||||
added.append(col_name)
|
||||
print(f" Added: {col_name}")
|
||||
except sqlite3.Error as e:
|
||||
print(f" Error adding {col_name}: {e}")
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
print(f"\nSummary:")
|
||||
print(f" Added: {len(added)} columns")
|
||||
print(f" Skipped (already exist): {len(skipped)} columns")
|
||||
|
||||
if added:
|
||||
print(f"\nColumns added: {', '.join(added)}")
|
||||
if skipped:
|
||||
print(f"Already existed: {', '.join(skipped)}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
add_missing_columns()
|
||||
@@ -1,157 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# Add _get_all_historical_usage method to usage_tracking_service.py
|
||||
|
||||
with open('services/subscription/usage_tracking_service.py', 'r', encoding='utf-8') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Find where to insert (before get_usage_trends)
|
||||
insert_idx = None
|
||||
for i, line in enumerate(lines):
|
||||
if ' def get_usage_trends(' in line:
|
||||
insert_idx = i
|
||||
break
|
||||
|
||||
if insert_idx is None:
|
||||
print("Error: Could not find insertion point")
|
||||
exit(1)
|
||||
|
||||
print(f"Inserting at line {insert_idx + 1}")
|
||||
|
||||
# Method to insert
|
||||
new_method = ''' def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
|
||||
# Get all usage summaries for the user
|
||||
all_summaries = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
if not all_summaries:
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': 'active',
|
||||
'total_calls': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0,
|
||||
'avg_response_time': 0.0,
|
||||
'error_rate': 0.0,
|
||||
'limits': self.pricing_service.get_user_limits(user_id),
|
||||
'provider_breakdown': {},
|
||||
'usage_percentages': {},
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Aggregate all data from UsageSummary
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
# Calculate weighted average response time
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
# Calculate overall error rate
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Map database columns to frontend keys
|
||||
provider_mapping = {
|
||||
'gemini_calls': 'gemini',
|
||||
'openai_calls': 'openai',
|
||||
'anthropic_calls': 'anthropic',
|
||||
'mistral_calls': 'huggingface',
|
||||
'wavespeed_calls': 'wavespeed',
|
||||
'exa_calls': 'exa',
|
||||
'video_calls': 'video',
|
||||
'image_edit_calls': 'image_edit',
|
||||
'audio_calls': 'audio',
|
||||
}
|
||||
|
||||
# Build provider_breakdown for frontend
|
||||
provider_breakdown = {}
|
||||
for db_col, frontend_key in provider_mapping.items():
|
||||
total_provider_calls = sum(getattr(s, db_col, 0) or 0 for s in all_summaries)
|
||||
provider_breakdown[frontend_key] = {
|
||||
'calls': total_provider_calls,
|
||||
'cost': 0,
|
||||
'tokens': 0
|
||||
}
|
||||
|
||||
# Calculate usage_percentages based on limits
|
||||
usage_percentages = {}
|
||||
if limits and limits.get('limits'):
|
||||
# Gemini calls percentage
|
||||
gemini_calls = provider_breakdown.get('gemini', {}).get('calls', 0)
|
||||
gemini_limit = limits.get('limits', {}).get('gemini_calls', 0) or 0
|
||||
if gemini_limit > 0:
|
||||
usage_percentages['gemini_calls'] = (gemini_calls / gemini_limit) * 100
|
||||
|
||||
# HuggingFace calls percentage (from mistral_calls)
|
||||
huggingface_calls = provider_breakdown.get('huggingface', {}).get('calls', 0)
|
||||
huggingface_limit = limits.get('limits', {}).get('mistral_calls', 0) or 0
|
||||
if huggingface_limit > 0:
|
||||
usage_percentages['huggingface_calls'] = (huggingface_calls / huggingface_limit) * 100
|
||||
|
||||
# Cost percentage
|
||||
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
usage_percentages['cost'] = (total_cost / cost_limit) * 100
|
||||
|
||||
# Build historical breakdown
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
'billing_period': s.billing_period,
|
||||
'total_calls': s.total_calls or 0,
|
||||
'total_tokens': s.total_tokens or 0,
|
||||
'total_cost': float(s.total_cost or 0),
|
||||
'usage_status': status_val,
|
||||
'updated_at': s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
# Determine overall status
|
||||
usage_status = 'active'
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status = s.usage_status.value
|
||||
except:
|
||||
status = str(s.usage_status)
|
||||
if status == 'limit_reached':
|
||||
usage_status = 'limit_reached'
|
||||
break
|
||||
elif status == 'warning' and usage_status != 'limit_reached':
|
||||
usage_status = 'warning'
|
||||
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': usage_status,
|
||||
'total_calls': total_calls,
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': round(total_cost, 2),
|
||||
'avg_response_time': round(avg_response_time, 2),
|
||||
'error_rate': round(error_rate, 2),
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': historical_breakdown,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
'''
|
||||
|
||||
# Insert the new method
|
||||
new_lines = lines[:insert_idx] + [new_method] + lines[insert_idx:]
|
||||
|
||||
# Write back
|
||||
with open('services/subscription/usage_tracking_service.py', 'w', encoding='utf-8') as f:
|
||||
f.writelines(new_lines)
|
||||
|
||||
print("Successfully added _get_all_historical_usage method")
|
||||
@@ -5,8 +5,8 @@ Modular utilities for ALwrity backend startup and configuration.
|
||||
|
||||
import os
|
||||
|
||||
# Check feature mode early to skip heavy imports
|
||||
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
|
||||
# Check podcast mode early to skip heavy imports
|
||||
_is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
|
||||
from .dependency_manager import DependencyManager
|
||||
from .environment_setup import EnvironmentSetup
|
||||
@@ -26,25 +26,41 @@ from .feature_runtime import (
|
||||
)
|
||||
|
||||
# Lazy load OnboardingManager - it triggers heavy imports (aiohttp, etc.)
|
||||
if _is_full_mode:
|
||||
if not _is_podcast:
|
||||
from .onboarding_manager import OnboardingManager
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
else:
|
||||
OnboardingManager = None
|
||||
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
__all__ = [
|
||||
'DependencyManager',
|
||||
'EnvironmentSetup',
|
||||
'DatabaseSetup',
|
||||
'ProductionOptimizer',
|
||||
'HealthChecker',
|
||||
'RateLimiter',
|
||||
'FrontendServing',
|
||||
'RouterManager',
|
||||
'OnboardingManager',
|
||||
'get_active_profiles',
|
||||
'get_enabled_groups',
|
||||
'get_enabled_optional_services',
|
||||
'get_enabled_routers',
|
||||
'get_enabled_startup_hooks',
|
||||
'is_enabled'
|
||||
]
|
||||
|
||||
@@ -51,13 +51,6 @@ FEATURE_GROUPS: Dict[str, FeatureGroup] = {
|
||||
"api.content_planning.strategy_copilot:router",
|
||||
),
|
||||
),
|
||||
"blog_writer": FeatureGroup(
|
||||
features=("blog_writer",),
|
||||
routers=(
|
||||
"api.blog_writer.router:router",
|
||||
"api.blog_writer.seo_analysis:router",
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
@@ -66,6 +59,5 @@ PROFILE_GROUP_MAP: Dict[str, Tuple[str, ...]] = {
|
||||
"core": ("core",),
|
||||
"podcast": ("core", "podcast"),
|
||||
"youtube": ("core", "youtube"),
|
||||
"blog_writer": ("core", "blog_writer"),
|
||||
"planning": ("core", "content_planning"),
|
||||
}
|
||||
|
||||
@@ -14,7 +14,7 @@ from loguru import logger
|
||||
|
||||
CORE_ROUTER_REGISTRY = [
|
||||
{"name": "component_logic", "module": "api.component_logic", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog_writer", "youtube"}},
|
||||
{"name": "subscription", "module": "api.subscription", "attr": "router", "features": {"all", "core", "podcast", "blog-writer", "youtube"}},
|
||||
{"name": "step3_research", "module": "api.onboarding_utils.step3_routes", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "step4_assets", "module": "api.onboarding_utils.step4_asset_routes", "attr": "router", "features": {"all", "core", "podcast"}},
|
||||
{"name": "step4_persona", "module": "api.onboarding_utils.step4_persona_routes_optimized", "attr": "router", "features": {"all", "core"}},
|
||||
@@ -29,31 +29,31 @@ CORE_ROUTER_REGISTRY = [
|
||||
{"name": "linkedin_image", "module": "api.linkedin_image_generation", "attr": "router", "features": {"all", "core", "linkedin"}},
|
||||
{"name": "brainstorm", "module": "api.brainstorm", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "hallucination_detector", "module": "api.hallucination_detector", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content_planning"}},
|
||||
{"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content_planning"}},
|
||||
{"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core", "blog_writer"}},
|
||||
{"name": "writing_assistant", "module": "api.writing_assistant", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "content_planning", "module": "api.content_planning.api.router", "attr": "router", "features": {"all", "core", "content-planning"}},
|
||||
{"name": "user_data", "module": "api.user_data", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "user_environment", "module": "api.user_environment", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "strategy_copilot", "module": "api.content_planning.strategy_copilot", "attr": "router", "features": {"all", "core", "content-planning"}},
|
||||
{"name": "error_logging", "module": "routers.error_logging", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "frontend_env_manager", "module": "routers.frontend_env_manager", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "platform_analytics", "module": "routers.platform_analytics", "attr": "router", "features": {"all", "core"}},
|
||||
{"name": "bing_insights", "module": "routers.bing_insights", "attr": "router", "features": {"all", "core", "seo"}},
|
||||
{"name": "background_jobs", "module": "routers.background_jobs", "attr": "router", "features": {"all", "core"}},
|
||||
]
|
||||
|
||||
OPTIONAL_ROUTER_REGISTRY = [
|
||||
{"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story_writer"}},
|
||||
{"name": "blog_writer", "module": "api.blog_writer.router", "attr": "router", "features": {"all", "blog-writer"}},
|
||||
{"name": "story_writer", "module": "api.story_writer.router", "attr": "router", "features": {"all", "story-writer"}},
|
||||
{"name": "wix", "module": "api.wix_routes", "attr": "router", "features": {"all"}},
|
||||
{"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog_writer"}},
|
||||
{"name": "blog_seo_analysis", "module": "api.blog_writer.seo_analysis", "attr": "router", "features": {"all", "blog-writer"}},
|
||||
{"name": "persona", "module": "api.persona_routes", "attr": "router", "features": {"all", "persona"}},
|
||||
{"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video_studio"}},
|
||||
{"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image_studio"}},
|
||||
{"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product_marketing"}},
|
||||
{"name": "video_studio", "module": "api.video_studio.router", "attr": "router", "features": {"all", "video-studio"}},
|
||||
{"name": "stability", "module": "routers.stability", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "stability_advanced", "module": "routers.stability_advanced", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "stability_admin", "module": "routers.stability_admin", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "images", "module": "api.images", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "image_studio", "module": "routers.image_studio", "attr": "router", "features": {"all", "image-studio"}},
|
||||
{"name": "product_marketing", "module": "routers.product_marketing", "attr": "router", "features": {"all", "product-marketing"}},
|
||||
{"name": "campaign_creator", "module": "routers.campaign_creator", "attr": "router", "features": {"all"}},
|
||||
{"name": "content_assets", "module": "api.content_assets.router", "attr": "router", "features": {"all"}},
|
||||
{"name": "podcast", "module": "api.podcast.router", "attr": "router", "features": {"all", "podcast"}},
|
||||
|
||||
@@ -7,11 +7,12 @@ The onboarding endpoints are re-exported from a stable module
|
||||
|
||||
import os
|
||||
|
||||
# In feature-only modes, don't import heavy onboarding endpoints
|
||||
# They trigger heavy dependencies (exa_py, etc.)
|
||||
_is_full_mode = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() in ("", "all")
|
||||
# Check podcast mode early
|
||||
_is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
|
||||
if not _is_full_mode:
|
||||
# In podcast mode, don't import heavy onboarding endpoints
|
||||
# They trigger heavy dependencies (exa_py, etc.)
|
||||
if _is_podcast:
|
||||
__all__ = []
|
||||
else:
|
||||
from .onboarding_endpoints import (
|
||||
|
||||
@@ -1195,68 +1195,3 @@ async def generate_introductions(
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to generate introductions: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ---------------------------
|
||||
# Save Complete Blog Asset
|
||||
# ---------------------------
|
||||
|
||||
|
||||
class SaveCompleteBlogAssetRequest(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
seo_title: Optional[str] = None
|
||||
meta_description: Optional[str] = None
|
||||
focus_keyword: Optional[str] = None
|
||||
tags: List[str] = Field(default_factory=list)
|
||||
categories: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
@router.post("/save-complete-asset")
|
||||
async def save_complete_blog_asset(
|
||||
request: SaveCompleteBlogAssetRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
) -> Dict[str, Any]:
|
||||
"""Save the complete blog content as a single asset in the asset library."""
|
||||
try:
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get('id', ''))
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="Invalid user ID in authentication token")
|
||||
|
||||
full_content = f"# {request.title}\n\n{request.content}"
|
||||
|
||||
asset_id = save_and_track_text_content(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
content=full_content,
|
||||
source_module="blog_writer",
|
||||
title=f"Published Blog: {request.title[:60]}",
|
||||
description=request.meta_description or f"Complete published blog post: {request.title}",
|
||||
prompt=f"SEO Title: {request.seo_title or request.title}\nFocus Keyword: {request.focus_keyword or ''}",
|
||||
tags=["blog", "published"] + [t for t in (request.tags or []) if t],
|
||||
asset_metadata={
|
||||
"status": "published",
|
||||
"focus_keyword": request.focus_keyword,
|
||||
"categories": request.categories,
|
||||
"word_count": len(full_content.split()),
|
||||
},
|
||||
subdirectory="published",
|
||||
file_extension=".md"
|
||||
)
|
||||
|
||||
if asset_id:
|
||||
logger.info(f"✅ Complete blog asset saved to library: ID={asset_id}")
|
||||
return {"success": True, "asset_id": asset_id}
|
||||
else:
|
||||
logger.warning("save_and_track_text_content returned None for published blog")
|
||||
return {"success": False, "error": "Failed to save blog asset"}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save complete blog asset: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -13,7 +13,7 @@ from typing import Any, Dict, List
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from services.database import get_session_for_user
|
||||
from services.database import SessionLocal, get_session_for_user
|
||||
|
||||
from models.blog_models import (
|
||||
BlogResearchRequest,
|
||||
@@ -264,7 +264,7 @@ class TaskManager:
|
||||
raise ValueError("Global target words exceed 1000; medium generation not allowed")
|
||||
|
||||
# Create a sync session for asset saving
|
||||
db_session = get_session_for_user(user_id)
|
||||
db_session = SessionLocal()
|
||||
try:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
@@ -326,7 +326,6 @@ class TaskManager:
|
||||
await self.update_progress(task_id, f"❌ Medium generation failed: {str(e)}")
|
||||
self.task_storage[task_id]["status"] = "failed"
|
||||
self.task_storage[task_id]["error"] = str(e)
|
||||
self.task_storage[task_id]["error_data"] = {"error_message": str(e), "error_type": type(e).__name__}
|
||||
|
||||
|
||||
# Global task manager instance
|
||||
|
||||
@@ -202,26 +202,6 @@ Listener CTA: {request.analysis.get('listener_cta', 'N/A')}
|
||||
interests = ", ".join(audience_dna.get("interests", []))
|
||||
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
|
||||
|
||||
# Preflight subscription check for Exa
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'exa', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[Podcast Research] Preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast Research] Preflight check failed: {e}")
|
||||
|
||||
try:
|
||||
# 1. RUN EXA SEARCH
|
||||
logger.warning(f"[Podcast Research] Calling Exa search with topic: {request.topic[:100]}...")
|
||||
|
||||
@@ -9,13 +9,10 @@ from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
from types import SimpleNamespace
|
||||
from sqlalchemy import text
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.research.tavily_service import TavilyService
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
|
||||
router = APIRouter(prefix="/research", tags=["Podcast Category Research"])
|
||||
|
||||
@@ -32,75 +29,6 @@ EXA_CATEGORY_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _preflight_check(user_id: str, provider: APIProvider, provider_name: str):
|
||||
"""Check subscription limits before making a research API call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=0,
|
||||
actual_provider_name=provider_name,
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': provider_name, 'usage_info': usage_info or {}
|
||||
})
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[CategoryResearch] Preflight check failed for {provider_name}: {e}")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_research_usage(user_id: str, provider_name: str, cost: float, calls_column: str, cost_column: str):
|
||||
"""Track research API usage after successful call."""
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[CategoryResearch] Could not get DB session for user {user_id}")
|
||||
return
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
update_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {calls_column} = COALESCE({calls_column}, 0) + 1,
|
||||
{cost_column} = COALESCE({cost_column}, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[CategoryResearch] Tracked {provider_name} usage: user={user_id}, cost=${cost}")
|
||||
|
||||
# Clear dashboard cache so header stats update immediately
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"[CategoryResearch] Failed to clear dashboard cache: {cache_err}")
|
||||
except Exception as e:
|
||||
logger.error(f"[CategoryResearch] Failed to track {provider_name} usage: {e}")
|
||||
db.rollback()
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
class CategoryResearchRequest(BaseModel):
|
||||
category: str
|
||||
keyword: Optional[str] = None
|
||||
@@ -152,12 +80,9 @@ def _normalize_exa_results(results: List[Dict], query: str) -> List[CategoryTopi
|
||||
return topics
|
||||
|
||||
|
||||
async def _search_tavily(category: str, keyword: str, max_results: int, user_id: str) -> CategoryResearchResponse:
|
||||
async def _search_tavily(category: str, keyword: str, max_results: int) -> CategoryResearchResponse:
|
||||
logger.info(f"[CategoryResearch] Using Tavily for category={category}, keyword={keyword}")
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.TAVILY, "tavily")
|
||||
|
||||
|
||||
try:
|
||||
tavily = TavilyService()
|
||||
result = await tavily.search(
|
||||
@@ -177,10 +102,6 @@ async def _search_tavily(category: str, keyword: str, max_results: int, user_id:
|
||||
topics = _normalize_tavily_results(result.get("results", []))
|
||||
logger.info(f"[CategoryResearch] Tavily found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.001 # basic search = 1 credit
|
||||
_track_research_usage(user_id, "tavily", cost, "tavily_calls", "tavily_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
@@ -196,7 +117,7 @@ async def _search_tavily(category: str, keyword: str, max_results: int, user_id:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _search_exa(category: str, keyword: str, max_results: int, user_id: str, website_url: Optional[str] = None) -> CategoryResearchResponse:
|
||||
async def _search_exa(category: str, keyword: str, max_results: int, website_url: Optional[str] = None) -> CategoryResearchResponse:
|
||||
exa_category = EXA_CATEGORY_MAP.get(category, category)
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa: category={category}, exa_category={exa_category}, keyword={keyword}, website_url={website_url}")
|
||||
@@ -212,9 +133,6 @@ async def _search_exa(category: str, keyword: str, max_results: int, user_id: st
|
||||
from exa_py import Exa
|
||||
exa = Exa(exa_api_key)
|
||||
logger.info(f"[CategoryResearch] Exa client initialized")
|
||||
|
||||
# Preflight subscription check
|
||||
_preflight_check(user_id, APIProvider.EXA, "exa")
|
||||
|
||||
# Build search parameters
|
||||
search_params = {
|
||||
@@ -271,10 +189,6 @@ async def _search_exa(category: str, keyword: str, max_results: int, user_id: st
|
||||
|
||||
logger.info(f"[CategoryResearch] Exa found {len(topics)} topics")
|
||||
|
||||
# Track usage
|
||||
cost = 0.005 # Default Exa cost for 1-25 results
|
||||
_track_research_usage(user_id, "exa", cost, "exa_calls", "exa_cost")
|
||||
|
||||
return CategoryResearchResponse(
|
||||
success=True,
|
||||
category=category,
|
||||
@@ -304,7 +218,6 @@ async def research_by_category(
|
||||
- news, finance: Uses Tavily
|
||||
- research-paper, personal-site: Uses Exa
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
category = request.category.lower()
|
||||
valid_categories = list(CATEGORY_PROVIDER_MAP.keys())
|
||||
|
||||
@@ -328,9 +241,9 @@ async def research_by_category(
|
||||
|
||||
try:
|
||||
if provider == "tavily":
|
||||
return await _search_tavily(category, keyword, max_results, user_id)
|
||||
return await _search_tavily(category, keyword, max_results)
|
||||
elif provider == "exa":
|
||||
return await _search_exa(category, keyword, max_results, user_id, website_url)
|
||||
return await _search_exa(category, keyword, max_results, website_url)
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail="Unknown provider")
|
||||
except Exception as e:
|
||||
|
||||
@@ -4,7 +4,6 @@ Podcast Trends Handler
|
||||
Endpoints for fetching Google Trends data relevant to podcast topics.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any, List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -14,25 +13,6 @@ from middleware.auth_middleware import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/trends", tags=["Podcast Trends"])
|
||||
|
||||
# Module-level shared instance (singleton pattern)
|
||||
_trends_service_instance = None
|
||||
_trends_service_lock = None
|
||||
|
||||
|
||||
def get_trends_service():
|
||||
"""Get or create shared GoogleTrendsService instance."""
|
||||
global _trends_service_instance, _trends_service_lock
|
||||
if _trends_service_instance is None:
|
||||
try:
|
||||
from services.research.trends import GoogleTrendsService
|
||||
_trends_service_instance = GoogleTrendsService()
|
||||
_trends_service_lock = asyncio.Lock()
|
||||
logger.info("[Podcast Trends] Created shared GoogleTrendsService instance")
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] Failed to create GoogleTrendsService: {e}")
|
||||
raise
|
||||
return _trends_service_instance
|
||||
|
||||
|
||||
class PodcastTrendsRequest(BaseModel):
|
||||
keywords: List[str] = Field(..., min_length=1, max_length=5, description="1-5 keywords to analyze")
|
||||
@@ -58,7 +38,7 @@ async def get_podcast_trends(
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
try:
|
||||
service = get_trends_service()
|
||||
from services.research.trends import GoogleTrendsService
|
||||
except (ImportError, RuntimeError) as e:
|
||||
logger.error(f"[Podcast Trends] GoogleTrendsService unavailable: {e}")
|
||||
raise HTTPException(
|
||||
@@ -67,10 +47,11 @@ async def get_podcast_trends(
|
||||
)
|
||||
|
||||
try:
|
||||
service = GoogleTrendsService()
|
||||
# Map 'source' to 'gprop' - 'podcast' uses YouTube for video/podcast relevance
|
||||
gprop_map = {"": "", "web": "", "podcast": "youtube", "news": "news", "images": "images", "shopping": "froogle"}
|
||||
gprop = gprop_map.get(request.source, "")
|
||||
|
||||
|
||||
result = await service.analyze_trends(
|
||||
keywords=request.keywords,
|
||||
timeframe=request.timeframe,
|
||||
@@ -92,15 +73,7 @@ async def get_podcast_trends(
|
||||
# Return error if: has error OR no data (meaning blocked/empty)
|
||||
if has_error and not has_data:
|
||||
error_msg = result.get("error", "")
|
||||
cooldown_active = result.get("cooldown_active", False)
|
||||
logger.warning(f"[Trends] No data or error: {error_msg[:100]}")
|
||||
# Provide helpful message during cooldown
|
||||
if cooldown_active:
|
||||
return PodcastTrendsResponse(
|
||||
success=False,
|
||||
data=result,
|
||||
error="Google is rate limiting requests. Try using 'Get Trending Topics' instead, or wait 30 minutes."
|
||||
)
|
||||
return PodcastTrendsResponse(success=False, data=result, error=error_msg or "No trends data available. Google may be blocking requests.")
|
||||
|
||||
# Even if no error but empty data - return error
|
||||
|
||||
@@ -12,7 +12,7 @@ import sqlite3
|
||||
from services.database import get_db
|
||||
from services.subscription import UsageTrackingService, PricingService
|
||||
from services.subscription.schema_utils import ensure_subscription_plan_columns, ensure_usage_summaries_columns
|
||||
from models.subscription_models import UsageAlert, UserSubscription
|
||||
from models.subscription_models import UsageAlert
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from ..dependencies import verify_user_access
|
||||
from ..cache import get_cached_dashboard, set_cached_dashboard
|
||||
@@ -27,9 +27,7 @@ async def get_dashboard_data(
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""Get comprehensive dashboard data for usage monitoring.
|
||||
Returns all-time total + current period usage by default.
|
||||
When billing_period is specified, returns that period's data only."""
|
||||
"""Get comprehensive dashboard data for usage monitoring."""
|
||||
|
||||
verify_user_access(user_id, current_user)
|
||||
|
||||
@@ -37,23 +35,17 @@ async def get_dashboard_data(
|
||||
ensure_subscription_plan_columns(db)
|
||||
ensure_usage_summaries_columns(db)
|
||||
|
||||
# Check cache first (only for default view, skip when a specific period is requested)
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data and not billing_period:
|
||||
return cached_data
|
||||
# Check cache first (skip if billing_period is specified)
|
||||
if not billing_period:
|
||||
cached_data = get_cached_dashboard(user_id)
|
||||
if cached_data:
|
||||
return cached_data
|
||||
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
# When a specific billing_period is requested, show only that period's data
|
||||
# Otherwise show all-time total + current period usage
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
# Get current usage stats (for the requested period)
|
||||
current_usage = usage_service.get_user_usage_stats(user_id, billing_period)
|
||||
|
||||
# Get usage trends (last 6 months)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
@@ -84,44 +76,13 @@ async def get_dashboard_data(
|
||||
]
|
||||
|
||||
# Calculate cost projections (only relevant for current month)
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
|
||||
# Determine if viewing current period based on subscription, not calendar
|
||||
subscription = db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
# Use subscription's billing period or fallback to calendar
|
||||
if subscription and subscription.current_period_start:
|
||||
sub_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Check if we have data for subscription period or calendar period
|
||||
from models.subscription_models import UsageSummary
|
||||
sub_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == sub_period
|
||||
).first()
|
||||
|
||||
# Determine which period to use for "current"
|
||||
if sub_data_exists:
|
||||
effective_period = sub_period
|
||||
else:
|
||||
# Check calendar period for backward compatibility
|
||||
cal_data_exists = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
effective_period = calendar_period if cal_data_exists else sub_period
|
||||
|
||||
is_current_period = not billing_period or billing_period == effective_period
|
||||
else:
|
||||
is_current_period = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
|
||||
if is_current_period:
|
||||
# Only project costs if viewing current month
|
||||
is_current_month = not billing_period or billing_period == datetime.now().strftime("%Y-%m")
|
||||
if is_current_month:
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
else:
|
||||
projected_cost = current_cost # For past months, projected is actual
|
||||
@@ -129,8 +90,7 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -140,9 +100,9 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
@@ -171,13 +131,7 @@ async def get_dashboard_data(
|
||||
usage_service = UsageTrackingService(db)
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
if billing_period:
|
||||
period_usage = usage_service.get_usage_for_period(user_id, billing_period)
|
||||
total_usage = period_usage
|
||||
current_period_usage = period_usage
|
||||
else:
|
||||
total_usage = usage_service.get_user_usage_stats(user_id, None)
|
||||
current_period_usage = usage_service.get_current_period_usage(user_id)
|
||||
current_usage = usage_service.get_user_usage_stats(user_id)
|
||||
trends = usage_service.get_usage_trends(user_id, 6)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
@@ -198,7 +152,7 @@ async def get_dashboard_data(
|
||||
for alert in alerts
|
||||
]
|
||||
|
||||
current_cost = total_usage.get('total_cost', 0)
|
||||
current_cost = current_usage.get('total_cost', 0)
|
||||
days_in_period = 30
|
||||
current_day = datetime.now().day
|
||||
projected_cost = (current_cost / current_day) * days_in_period if current_day > 0 else 0
|
||||
@@ -206,8 +160,7 @@ async def get_dashboard_data(
|
||||
response_payload = {
|
||||
"success": True,
|
||||
"data": {
|
||||
"total_usage": total_usage,
|
||||
"current_period_usage": current_period_usage,
|
||||
"current_usage": current_usage,
|
||||
"trends": trends,
|
||||
"limits": limits,
|
||||
"alerts": alerts_data,
|
||||
@@ -217,17 +170,16 @@ async def get_dashboard_data(
|
||||
"projected_usage_percentage": (projected_cost / max(limits.get('limits', {}).get('monthly_cost', 1), 1)) * 100 if limits else 0
|
||||
},
|
||||
"summary": {
|
||||
"total_api_calls_this_month": total_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": total_usage.get('total_cost', 0),
|
||||
"usage_status": total_usage.get('usage_status', 'active'),
|
||||
"total_api_calls_this_month": current_usage.get('total_calls', 0),
|
||||
"total_cost_this_month": current_usage.get('total_cost', 0),
|
||||
"usage_status": current_usage.get('usage_status', 'active'),
|
||||
"unread_alerts": len(alerts_data)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
# Cache the response after successful retry (only for default view)
|
||||
if not billing_period:
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
# Cache the response after successful retry
|
||||
set_cached_dashboard(user_id, response_payload)
|
||||
return response_payload
|
||||
except Exception as retry_err:
|
||||
logger.error(f"Schema fix and retry failed: {retry_err}")
|
||||
@@ -235,8 +187,7 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(retry_err),
|
||||
"data": {
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
@@ -250,8 +201,7 @@ async def get_dashboard_data(
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"data": {
|
||||
"total_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"current_period_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}, "usage_percentages": {}},
|
||||
"current_usage": {"total_calls": 0, "total_cost": 0, "usage_status": "error", "provider_breakdown": {}},
|
||||
"trends": [],
|
||||
"limits": {"limits": {"monthly_cost": 0}},
|
||||
"alerts": [],
|
||||
|
||||
@@ -14,21 +14,13 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"""
|
||||
Format subscription plan limits for API response.
|
||||
|
||||
Includes _zero_means metadata per field to disambiguate:
|
||||
- 'disabled': 0 means the feature is not available (Free tier)
|
||||
- 'unlimited': 0 means unlimited usage (Enterprise tier)
|
||||
- 'limited': >0 means numerical limit applies
|
||||
|
||||
Args:
|
||||
plan: SubscriptionPlan model instance
|
||||
|
||||
Returns:
|
||||
Dictionary with formatted limits and _zero_means metadata
|
||||
Dictionary with formatted limits
|
||||
"""
|
||||
tier = plan.tier.value if hasattr(plan.tier, 'value') else str(plan.tier)
|
||||
is_enterprise = tier == 'enterprise'
|
||||
|
||||
limit_fields = {
|
||||
return {
|
||||
"ai_text_generation_calls": getattr(plan, 'ai_text_generation_calls_limit', None) or 0,
|
||||
"gemini_calls": plan.gemini_calls_limit,
|
||||
"openai_calls": plan.openai_calls_limit,
|
||||
@@ -43,43 +35,11 @@ def format_plan_limits(plan: SubscriptionPlan) -> Dict[str, Any]:
|
||||
"image_edit_calls": getattr(plan, 'image_edit_calls_limit', 0) or 0,
|
||||
"audio_calls": getattr(plan, 'audio_calls_limit', 0) or 0,
|
||||
"exa_calls": getattr(plan, 'exa_calls_limit', 0) or 0,
|
||||
"wavespeed_calls": getattr(plan, 'wavespeed_calls_limit', 0) or 0,
|
||||
"gemini_tokens": plan.gemini_tokens_limit,
|
||||
"openai_tokens": plan.openai_tokens_limit,
|
||||
"anthropic_tokens": plan.anthropic_tokens_limit,
|
||||
"mistral_tokens": plan.mistral_tokens_limit,
|
||||
"monthly_cost": plan.monthly_cost_limit,
|
||||
}
|
||||
|
||||
# Build _zero_means metadata: indicates whether 0 means 'disabled' or 'unlimited'
|
||||
zero_means = {}
|
||||
for field, value in limit_fields.items():
|
||||
if field == "monthly_cost":
|
||||
zero_means[field] = "disabled"
|
||||
elif is_enterprise:
|
||||
# Enterprise: 0 means unlimited for all call/token fields
|
||||
zero_means[field] = "unlimited"
|
||||
else:
|
||||
# Free/Basic/Pro: determine per-field
|
||||
# Fields that are 0=disabled on Free tier but 0=unlimited on Basic/Pro
|
||||
call_and_token_fields = {
|
||||
"gemini_calls", "openai_calls", "anthropic_calls", "mistral_calls",
|
||||
"tavily_calls", "serper_calls", "metaphor_calls", "firecrawl_calls",
|
||||
"stability_calls", "video_calls", "image_edit_calls", "audio_calls",
|
||||
"exa_calls", "wavespeed_calls", "ai_text_generation_calls",
|
||||
"gemini_tokens", "openai_tokens", "anthropic_tokens", "mistral_tokens",
|
||||
}
|
||||
if field in call_and_token_fields:
|
||||
if value == 0:
|
||||
zero_means[field] = "disabled" if tier == "free" else "unlimited"
|
||||
else:
|
||||
zero_means[field] = "limited"
|
||||
else:
|
||||
zero_means[field] = "limited" if value > 0 else "disabled"
|
||||
|
||||
return {
|
||||
**limit_fields,
|
||||
"_zero_means": zero_means,
|
||||
"monthly_cost": plan.monthly_cost_limit
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Any, Dict
|
||||
from loguru import logger
|
||||
|
||||
from services.writing_assistant import WritingAssistantService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
|
||||
router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
@@ -12,6 +11,7 @@ router = APIRouter(prefix="/api/writing-assistant", tags=["writing-assistant"])
|
||||
|
||||
class SuggestRequest(BaseModel):
|
||||
text: str
|
||||
max_results: int | None = 1
|
||||
|
||||
|
||||
class SourceModel(BaseModel):
|
||||
@@ -38,10 +38,9 @@ assistant_service = WritingAssistantService()
|
||||
|
||||
|
||||
@router.post("/suggest", response_model=SuggestResponse)
|
||||
async def suggest_endpoint(req: SuggestRequest, current_user: Dict[str, Any] = Depends(get_current_user)) -> SuggestResponse:
|
||||
async def suggest_endpoint(req: SuggestRequest) -> SuggestResponse:
|
||||
try:
|
||||
user_id = current_user.get("id")
|
||||
suggestions = await assistant_service.suggest(req.text, user_id=user_id)
|
||||
suggestions = await assistant_service.suggest(req.text, req.max_results or 1)
|
||||
return SuggestResponse(
|
||||
success=True,
|
||||
suggestions=[
|
||||
|
||||
289
backend/app.py
289
backend/app.py
@@ -27,11 +27,11 @@ load_dotenv(backend_dir / '.env', override=False)
|
||||
load_dotenv(project_root / '.env', override=False)
|
||||
load_dotenv(override=False)
|
||||
|
||||
# Set LOG_LEVEL early to WARNING in feature-only modes to suppress DEBUG persona logs
|
||||
# Set LOG_LEVEL early to WARNING to suppress DEBUG persona logs in podcast mode
|
||||
import os
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast":
|
||||
os.environ["LOG_LEVEL"] = "WARNING"
|
||||
|
||||
|
||||
print(f"[app.py] Starting... ALWRITY_ENABLED_FEATURES={os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
|
||||
@@ -43,21 +43,22 @@ def get_enabled_features() -> set:
|
||||
return {f.strip() for f in env_value.split(",") if f.strip()}
|
||||
|
||||
|
||||
def _is_full_mode() -> bool:
|
||||
"""Check if running in full mode (all features enabled)."""
|
||||
enabled = get_enabled_features()
|
||||
return "all" in enabled
|
||||
|
||||
|
||||
def _is_feature_enabled(feature: str) -> bool:
|
||||
"""Check if a specific feature is enabled (including in 'all' mode)."""
|
||||
enabled = get_enabled_features()
|
||||
return feature in enabled or "all" in enabled
|
||||
|
||||
|
||||
# Print env var IMMEDIATELY at module start
|
||||
print(f"[app.py] ALWRITY_ENABLED_FEATURES at start: {os.getenv('ALWRITY_ENABLED_FEATURES')}", flush=True)
|
||||
|
||||
def is_podcast_only_demo_mode() -> bool:
|
||||
"""Check if podcast-only mode is enabled."""
|
||||
import os
|
||||
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "all")
|
||||
enabled = get_enabled_features()
|
||||
result = "podcast" in enabled and "all" not in enabled
|
||||
# Removed debug print - too verbose during startup
|
||||
return result
|
||||
|
||||
|
||||
# Podcast-only check BEFORE heavy imports
|
||||
PODCAST_ONLY_DEMO_MODE = is_podcast_only_demo_mode()
|
||||
|
||||
|
||||
# Import onboarding models (after env is loaded, before heavy imports)
|
||||
from models.onboarding import APIKey, WebsiteAnalysis, ResearchPreferences, PersonaData, CompetitorAnalysis
|
||||
@@ -89,18 +90,28 @@ _log_memory_usage()
|
||||
logger.info("app.py: Early memory checkpoint after env load")
|
||||
|
||||
|
||||
# Import modular utilities (skip OnboardingManager import in feature-only modes)
|
||||
# Import modular utilities (skip OnboardingManager import in podcast-only mode)
|
||||
from alwrity_utils import HealthChecker, RateLimiter, FrontendServing, RouterManager
|
||||
if _is_full_mode():
|
||||
if not is_podcast_only_demo_mode():
|
||||
from alwrity_utils import OnboardingManager
|
||||
|
||||
# Skip monitoring middleware in feature-only modes to save memory
|
||||
if _is_full_mode():
|
||||
# Skip monitoring middleware in podcast-only mode to save memory
|
||||
if not is_podcast_only_demo_mode():
|
||||
from services.subscription import monitoring_middleware
|
||||
else:
|
||||
monitoring_middleware = None
|
||||
|
||||
|
||||
def should_include_non_podcast_features() -> bool:
|
||||
"""Check if non-podcast features should be included."""
|
||||
enabled = get_enabled_features()
|
||||
return "all" in enabled or "core" in enabled
|
||||
|
||||
|
||||
# Legacy constant for backwards compatibility
|
||||
PODCAST_ONLY_DEMO_MODE = is_podcast_only_demo_mode()
|
||||
|
||||
|
||||
# Set up clean logging for end users
|
||||
from logging_config import setup_clean_logging
|
||||
setup_clean_logging()
|
||||
@@ -108,27 +119,27 @@ setup_clean_logging()
|
||||
# Import middleware
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
# Import component logic endpoints (skip in feature-only modes - uses seo_analyzer)
|
||||
# Import component logic endpoints (skip in podcast-only mode - uses seo_analyzer)
|
||||
component_logic_router = None
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.component_logic import router as component_logic_router
|
||||
|
||||
# Import subscription API endpoints
|
||||
from api.subscription import router as subscription_router
|
||||
|
||||
# Import Step 3 onboarding routes (skip in feature-only modes)
|
||||
# Import Step 3 onboarding routes (skip in podcast-only mode)
|
||||
step3_routes = None
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.onboarding_utils.step3_routes import router as step3_routes
|
||||
|
||||
# Import SEO tools router (skip in feature-only modes - uses seo_analyzer)
|
||||
# Import SEO tools router (skip in podcast-only mode - uses seo_analyzer)
|
||||
seo_tools_router = None
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from routers.seo_tools import router as seo_tools_router
|
||||
|
||||
# Skip Facebook Writer, LinkedIn, and other non-essential routes in feature-only modes
|
||||
# Skip Facebook Writer, LinkedIn, and other non-podcast routes in podcast-only mode
|
||||
# Also skip other heavy services that trigger PersonaAnalysisService initialization
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.facebook_writer.routers import facebook_router
|
||||
from routers.linkedin import router as linkedin_router
|
||||
from api.linkedin_image_generation import router as linkedin_image_router
|
||||
@@ -139,7 +150,7 @@ if _is_full_mode():
|
||||
from routers.product_marketing import router as product_marketing_router
|
||||
from routers.campaign_creator import router as campaign_creator_router
|
||||
else:
|
||||
# In feature-only modes, only load essential assets router
|
||||
# In podcast-only mode, only load essential podcast assets router
|
||||
from api.assets_serving import router as assets_serving_router
|
||||
brainstorm_router = None
|
||||
images_router = None
|
||||
@@ -147,31 +158,31 @@ else:
|
||||
product_marketing_router = None
|
||||
campaign_creator_router = None
|
||||
|
||||
# Import hallucination detector router (skip in feature-only modes - triggers heavy ML)
|
||||
if _is_full_mode():
|
||||
# Import hallucination detector router (skip in podcast-only mode - triggers heavy ML)
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from api.hallucination_detector import router as hallucination_detector_router
|
||||
from api.writing_assistant import router as writing_assistant_router
|
||||
else:
|
||||
hallucination_detector_router = None
|
||||
writing_assistant_router = None
|
||||
|
||||
# Import research configuration router (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# Import research configuration router (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.research_config import router as research_config_router
|
||||
else:
|
||||
research_config_router = None
|
||||
|
||||
# Import user data endpoints
|
||||
# Import content planning endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# Import content planning endpoints (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.content_planning.api.router import router as content_planning_router
|
||||
from api.content_planning.strategy_copilot import router as strategy_copilot_router
|
||||
else:
|
||||
content_planning_router = None
|
||||
strategy_copilot_router = None
|
||||
|
||||
# Import user data endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
# Import user data endpoints (skip in podcast-only mode to save memory)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.user_data import router as user_data_router
|
||||
else:
|
||||
user_data_router = None
|
||||
@@ -186,14 +197,14 @@ from services.startup_health import (
|
||||
|
||||
# Trigger reload for monitoring fix
|
||||
|
||||
# Import OAuth token monitoring routes (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# Import OAuth token monitoring routes (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.oauth_token_monitoring_routes import router as oauth_token_monitoring_router
|
||||
else:
|
||||
oauth_token_monitoring_router = None
|
||||
|
||||
# Import SEO Dashboard endpoints (skip in feature-only modes to save memory)
|
||||
if _is_full_mode():
|
||||
# Import SEO Dashboard endpoints (skip in podcast-only mode to save memory)
|
||||
if not is_podcast_only_demo_mode():
|
||||
from api.seo_dashboard import (
|
||||
get_seo_dashboard_data,
|
||||
get_seo_health_score,
|
||||
@@ -307,8 +318,8 @@ router_manager = RouterManager(app)
|
||||
router_group_status: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
onboarding_manager = None
|
||||
# Only create OnboardingManager in full mode
|
||||
if _is_full_mode():
|
||||
# Only create OnboardingManager if NOT in podcast-only mode
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from alwrity_utils import OnboardingManager
|
||||
onboarding_manager = OnboardingManager(app)
|
||||
|
||||
@@ -335,8 +346,7 @@ app.middleware("http")(api_key_injection_middleware)
|
||||
async def health():
|
||||
"""Health check endpoint."""
|
||||
health_data = health_checker.basic_health_check()
|
||||
health_data["feature_mode"] = "single" if not _is_full_mode() else "full"
|
||||
health_data["enabled_features"] = list(get_enabled_features())
|
||||
health_data["podcast_only_demo_mode"] = PODCAST_ONLY_DEMO_MODE
|
||||
return health_data
|
||||
|
||||
@app.get("/health/database")
|
||||
@@ -353,8 +363,7 @@ async def comprehensive_health():
|
||||
async def readiness(current_user: dict = Depends(get_current_user)):
|
||||
"""Readiness check that validates tenant DB resolution/session under auth context."""
|
||||
return {
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"startup": get_startup_status(),
|
||||
"tenant": readiness_under_auth_context(current_user),
|
||||
}
|
||||
@@ -386,8 +395,7 @@ async def router_status():
|
||||
status = router_manager.get_router_status()
|
||||
status.update(
|
||||
{
|
||||
"feature_mode": "single" if not _is_full_mode() else "full",
|
||||
"enabled_features": list(get_enabled_features()),
|
||||
"podcast_only_demo_mode": PODCAST_ONLY_DEMO_MODE,
|
||||
"router_groups": router_group_status,
|
||||
}
|
||||
)
|
||||
@@ -402,19 +410,53 @@ async def feature_profile_status():
|
||||
@app.get("/api/onboarding/status")
|
||||
async def onboarding_status():
|
||||
"""Get onboarding manager status (or demo-mode disabled state)."""
|
||||
if not _is_full_mode():
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
return {
|
||||
"enabled": False,
|
||||
"status": "disabled",
|
||||
"message": f"Onboarding is disabled in feature-only mode. Enabled features: {list(get_enabled_features())}",
|
||||
"feature_mode": "single",
|
||||
"message": "Onboarding is disabled for podcast-only demo mode.",
|
||||
"demo_mode": "podcast_only",
|
||||
}
|
||||
return onboarding_manager.get_onboarding_status()
|
||||
|
||||
# Include routers using modular utilities
|
||||
enabled_features = get_enabled_features()
|
||||
if "all" in enabled_features:
|
||||
# Full mode: load all core and optional routers
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
# In podcast-only mode, include only podcast-enabled routers from core registry
|
||||
from alwrity_utils.router_manager import CORE_ROUTER_REGISTRY
|
||||
podcast_routers = [r for r in CORE_ROUTER_REGISTRY if "podcast" in r.get("features", set())]
|
||||
logger.info(f"[PODCAST-ONLY] Found {len(podcast_routers)} podcast routers: {[r['name'] for r in podcast_routers]}")
|
||||
|
||||
# Try to include step4_assets for voice cloning (may fail if nltk not installed)
|
||||
step4_entry = next((r for r in CORE_ROUTER_REGISTRY if r.get("name") == "step4_assets"), None)
|
||||
if step4_entry:
|
||||
try:
|
||||
logger.info(f"[PODCAST-ONLY] Attempting to load step4_assets for voice cloning")
|
||||
router = router_manager._load_router_from_registry(step4_entry)
|
||||
router_manager.include_router_safely(router, step4_entry["name"], step4_entry.get("include_kwargs"))
|
||||
except ImportError as e:
|
||||
logger.warning(f"[PODCAST-ONLY] Skipping step4_assets (missing optional dependency): {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[PODCAST-ONLY] Failed to mount step4_assets: {e}")
|
||||
|
||||
# Load other podcast routers
|
||||
for entry in podcast_routers:
|
||||
if entry.get("name") == "step4_assets":
|
||||
continue # Already loaded above
|
||||
try:
|
||||
logger.info(f"[PODCAST-ONLY] Loading router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[PODCAST-ONLY] Failed to mount {entry.get('name', 'unknown')}: {e}")
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": True,
|
||||
"reason": "Podcast routers only in podcast-only mode",
|
||||
}
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
else:
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": router_manager.include_core_routers(),
|
||||
"reason": "Full mode",
|
||||
@@ -423,72 +465,6 @@ if "all" in enabled_features:
|
||||
"mounted": router_manager.include_optional_routers(),
|
||||
"reason": "Full mode",
|
||||
}
|
||||
else:
|
||||
# Feature-only mode: load only routers matching enabled features
|
||||
from alwrity_utils.router_manager import CORE_ROUTER_REGISTRY
|
||||
|
||||
# Filter core routers that match any enabled feature
|
||||
matching_core = [
|
||||
r for r in CORE_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
logger.info(
|
||||
f"[FEATURE-MODE] Enabled features: {enabled_features}, "
|
||||
f"matching {len(matching_core)} core routers: {[r['name'] for r in matching_core]}"
|
||||
)
|
||||
|
||||
# Try to include step4_assets for voice cloning (may fail if nltk not installed)
|
||||
step4_entry = next((r for r in matching_core if r.get("name") == "step4_assets"), None)
|
||||
if step4_entry:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Attempting to load step4_assets")
|
||||
router = router_manager._load_router_from_registry(step4_entry)
|
||||
router_manager.include_router_safely(router, step4_entry["name"], step4_entry.get("include_kwargs"))
|
||||
except ImportError as e:
|
||||
logger.warning(f"[FEATURE-MODE] Skipping step4_assets (missing optional dependency): {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount step4_assets: {e}")
|
||||
|
||||
# Load other matching core routers
|
||||
for entry in matching_core:
|
||||
if entry.get("name") == "step4_assets":
|
||||
continue # Already loaded above
|
||||
if entry.get("name") == "subscription":
|
||||
continue # Loaded separately below
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_core"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Load optional routers matching enabled features
|
||||
from alwrity_utils.router_manager import OPTIONAL_ROUTER_REGISTRY
|
||||
matching_optional = [
|
||||
r for r in OPTIONAL_ROUTER_REGISTRY
|
||||
if r.get("features", set()) & enabled_features
|
||||
]
|
||||
for entry in matching_optional:
|
||||
try:
|
||||
logger.info(f"[FEATURE-MODE] Loading optional router: {entry['name']}")
|
||||
router = router_manager._load_router_from_registry(entry)
|
||||
router_manager.include_router_safely(router, entry["name"], entry.get("include_kwargs"))
|
||||
except Exception as e:
|
||||
logger.error(f"[FEATURE-MODE] Failed to mount optional {entry.get('name', 'unknown')}: {e}")
|
||||
|
||||
router_group_status["modular_optional"] = {
|
||||
"mounted": True,
|
||||
"reason": f"Feature-only mode: {enabled_features}",
|
||||
}
|
||||
|
||||
# Safety net: explicitly include hallucination detector (router_manager may skip silently)
|
||||
if hallucination_detector_router:
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
|
||||
# Log startup summary
|
||||
router_manager.log_startup_summary()
|
||||
@@ -504,8 +480,8 @@ router_group_status["assets_serving"] = {
|
||||
"reason": "Required for podcast media assets",
|
||||
}
|
||||
|
||||
# SEO Dashboard endpoints (skip in feature-only modes)
|
||||
if _is_full_mode():
|
||||
# SEO Dashboard endpoints (skip in podcast-only mode)
|
||||
if not is_podcast_only_demo_mode():
|
||||
@app.get("/api/seo-dashboard/data")
|
||||
async def seo_dashboard_data():
|
||||
"""Get complete SEO dashboard data."""
|
||||
@@ -643,7 +619,7 @@ if _is_full_mode():
|
||||
return await analyze_urls_ai(request, current_user)
|
||||
|
||||
# Include platform analytics router
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
# Include Bing Analytics Storage router to expose storage-backed endpoints
|
||||
@@ -668,38 +644,25 @@ if _is_full_mode():
|
||||
else:
|
||||
router_group_status["platform_extensions"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in feature-only mode",
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
|
||||
# Include Podcast Maker router (only when podcast feature is enabled)
|
||||
if _is_feature_enabled("podcast") and "all" not in get_enabled_features():
|
||||
from api.podcast.router import router as podcast_router
|
||||
logger.info(f"[ROUTER] Including podcast_router")
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Podcast feature enabled",
|
||||
}
|
||||
elif "all" in get_enabled_features():
|
||||
# In full mode, podcast is loaded via optional router registry
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Full mode (loaded via registry)",
|
||||
}
|
||||
else:
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": False,
|
||||
"reason": "Podcast feature not enabled",
|
||||
}
|
||||
# Include Podcast Maker router (always needed for podcast mode)
|
||||
from api.podcast.router import router as podcast_router
|
||||
logger.info(f"[PODCAST] Including podcast_router with prefixes: {podcast_router.routes}")
|
||||
app.include_router(podcast_router)
|
||||
router_group_status["podcast_maker"] = {
|
||||
"mounted": True,
|
||||
"reason": "Always mounted",
|
||||
}
|
||||
|
||||
if _is_full_mode():
|
||||
if not PODCAST_ONLY_DEMO_MODE:
|
||||
# Include YouTube Creator Studio router
|
||||
from api.youtube.router import router as youtube_router
|
||||
app.include_router(youtube_router, prefix="/api")
|
||||
|
||||
# Include research configuration router
|
||||
if research_config_router:
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
app.include_router(research_config_router, prefix="/api/research", tags=["research"])
|
||||
|
||||
# Include Research Engine router (standalone AI research module)
|
||||
from api.research.router import router as research_engine_router
|
||||
@@ -725,7 +688,7 @@ if _is_full_mode():
|
||||
else:
|
||||
router_group_status["advanced_workflows"] = {
|
||||
"mounted": False,
|
||||
"reason": "Skipped in feature-only mode",
|
||||
"reason": "Skipped in podcast-only demo mode",
|
||||
}
|
||||
|
||||
# Setup frontend serving using modular utilities
|
||||
@@ -752,23 +715,20 @@ async def startup_event():
|
||||
# Note: Pricing is initialized per-user in services/database.py:init_user_database()
|
||||
# which runs on first database access for each user. No global seeding needed at startup.
|
||||
|
||||
enabled_features = get_enabled_features()
|
||||
is_single_mode = "all" not in enabled_features
|
||||
|
||||
# Skip startup health checks in feature-only modes to avoid unnecessary DB errors
|
||||
if _is_full_mode():
|
||||
# Skip startup health checks in podcast-only mode to avoid unnecessary DB errors
|
||||
if not is_podcast_only_demo_mode():
|
||||
startup_report = run_startup_health_routine(app)
|
||||
if startup_report.get("status") != "healthy":
|
||||
logger.error(f"Startup readiness finished with failures: {startup_report.get('errors', [])}")
|
||||
else:
|
||||
logger.info(f"[FEATURE-MODE] Skipping startup health routine (features: {enabled_features})")
|
||||
logger.info("[Podcast] Skipping startup health routine (podcast-only mode)")
|
||||
|
||||
# Start task scheduler only in full mode
|
||||
if _is_full_mode():
|
||||
# Start task scheduler only if NOT in podcast-only mode
|
||||
if not is_podcast_only_demo_mode():
|
||||
from services.scheduler import get_scheduler
|
||||
await get_scheduler().start()
|
||||
else:
|
||||
logger.info(f"[FEATURE-MODE] Skipping scheduler startup (features: {enabled_features})")
|
||||
logger.info("[Podcast] Skipping scheduler startup (podcast-only mode)")
|
||||
|
||||
# Check Wix API key configuration
|
||||
wix_api_key = os.getenv('WIX_API_KEY')
|
||||
@@ -780,12 +740,9 @@ async def startup_event():
|
||||
elapsed = time.time() - startup_start
|
||||
logger.info(f"ALwrity backend started successfully in {elapsed:.1f}s")
|
||||
|
||||
# Critical router mount assertions for feature-only modes
|
||||
# Critical router mount assertions for podcast-only demo mode
|
||||
_assert_router_mounted("subscription")
|
||||
if _is_feature_enabled("podcast"):
|
||||
_assert_router_mounted("podcast")
|
||||
if _is_feature_enabled("blog_writer"):
|
||||
_assert_router_mounted("blog_writer")
|
||||
_assert_router_mounted("podcast")
|
||||
except Exception as e:
|
||||
logger.error(f"Error during startup: {e}")
|
||||
# Don't raise - let the server start anyway
|
||||
@@ -800,7 +757,6 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
router_path_indicators = {
|
||||
"subscription": ["/api/subscription/plans", "/api/subscription/preflight"],
|
||||
"podcast": ["/api/podcast/projects", "/api/podcast/"],
|
||||
"blog_writer": ["/api/blog/health", "/api/blog/research/start"],
|
||||
}
|
||||
|
||||
expected_paths = router_path_indicators.get(router_name, [])
|
||||
@@ -811,9 +767,10 @@ def _assert_router_mounted(router_name: str) -> None:
|
||||
else:
|
||||
error_msg = f"❌ CRITICAL: Router '{router_name}' is NOT mounted! Expected paths: {expected_paths}"
|
||||
logger.error(error_msg)
|
||||
# In feature-only mode, only fail if the feature is expected
|
||||
if not _is_full_mode() and _is_feature_enabled(router_name):
|
||||
raise RuntimeError(error_msg)
|
||||
if PODCAST_ONLY_DEMO_MODE:
|
||||
# In demo mode, podcast router MUST be mounted
|
||||
if router_name == "podcast":
|
||||
raise RuntimeError(error_msg)
|
||||
|
||||
# Shutdown event
|
||||
@app.on_event("shutdown")
|
||||
|
||||
@@ -252,8 +252,6 @@ router_manager.include_core_routers()
|
||||
# Safety net: keep subscription routes available even if core inclusion flow changes
|
||||
# in special modes (e.g., demo mode). De-dup is handled by RouterManager.
|
||||
router_manager.include_router_safely(subscription_router, "subscription")
|
||||
# Include hallucination detector explicitly (router_manager may skip silently on import failure)
|
||||
router_manager.include_router_safely(hallucination_detector_router, "hallucination_detector")
|
||||
router_manager.include_optional_routers()
|
||||
|
||||
# SEO Dashboard endpoints
|
||||
|
||||
@@ -11,30 +11,17 @@ echo "📦 Checking ALWRITY_ENABLED_FEATURES..."
|
||||
ENABLED_FEATURES="${ALWRITY_ENABLED_FEATURES:-all}"
|
||||
echo "DEBUG: ENABLED_FEATURES='$ENABLED_FEATURES'"
|
||||
|
||||
case "$ENABLED_FEATURES" in
|
||||
all)
|
||||
echo "📦 Full mode: Installing all requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
# Download spaCy/NLTK models for full mode
|
||||
echo "🧠 Installing spaCy and NLTK models..."
|
||||
python -m spacy download en_core_web_sm
|
||||
python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
|
||||
;;
|
||||
podcast)
|
||||
echo "🔊 Podcast-only mode: Installing lean requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements-podcast.txt --only-binary :all: --retries 10 --timeout 120
|
||||
;;
|
||||
*)
|
||||
echo "🎯 Feature-limited mode ($ENABLED_FEATURES): Installing requirements..."
|
||||
req_file="requirements-${ENABLED_FEATURES}.txt"
|
||||
if [[ -f "$req_file" ]]; then
|
||||
python -m pip install --no-cache-dir -r "$req_file" --only-binary :all: --retries 10 --timeout 120
|
||||
else
|
||||
echo "⚠️ No feature-specific requirements file found ($req_file), installing full requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
fi
|
||||
;;
|
||||
esac
|
||||
if [[ "$ENABLED_FEATURES" == "podcast" ]]; then
|
||||
echo "🔊 Podcast-only mode: Installing lean requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements-podcast.txt --only-binary :all: --retries 10 --timeout 120
|
||||
else
|
||||
echo "📦 Full mode: Installing all requirements..."
|
||||
python -m pip install --no-cache-dir -r requirements.txt --only-binary :all: --retries 10 --timeout 120
|
||||
# Download spaCy/NLTK models for full mode
|
||||
echo "🧠 Installing spaCy and NLTK models..."
|
||||
python -m spacy download en_core_web_sm
|
||||
python -m nltk.downloader punkt_tab stopwords averaged_perceptron_tagger
|
||||
fi
|
||||
|
||||
# 3. Clean up unnecessary build artifacts
|
||||
find . -type d -name "__pycache__" -exec rm -rf {} + 2>/dev/null || true
|
||||
|
||||
1620
backend/routers/image_studio.py
Normal file
1620
backend/routers/image_studio.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,34 +0,0 @@
|
||||
"""Image Studio API router package.
|
||||
|
||||
Composed from modular sub-routers. Same prefix and tags as the original monolithic file.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
from .health import router as health_router
|
||||
from .upscale import router as upscale_router
|
||||
from .control import router as control_router
|
||||
from .social import router as social_router
|
||||
from .edit import router as edit_router
|
||||
from .face_swap import router as face_swap_router
|
||||
from .create import router as create_router
|
||||
from .transform import router as transform_router
|
||||
from .compress import router as compress_router
|
||||
from .convert import router as convert_router
|
||||
from .save import router as save_router
|
||||
|
||||
router = APIRouter(prefix="/api/image-studio", tags=["image-studio"])
|
||||
|
||||
router.include_router(health_router)
|
||||
router.include_router(upscale_router)
|
||||
router.include_router(control_router)
|
||||
router.include_router(social_router)
|
||||
router.include_router(edit_router)
|
||||
router.include_router(face_swap_router)
|
||||
router.include_router(create_router)
|
||||
router.include_router(transform_router)
|
||||
router.include_router(compress_router)
|
||||
router.include_router(convert_router)
|
||||
router.include_router(save_router)
|
||||
|
||||
__all__ = ["router"]
|
||||
@@ -1,158 +0,0 @@
|
||||
"""Compression Studio endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import (
|
||||
CompressImageRequest, CompressImageResponse,
|
||||
CompressBatchRequest, CompressBatchResponse,
|
||||
CompressionEstimateRequest, CompressionEstimateResponse,
|
||||
CompressionFormatsResponse, CompressionPresetsResponse,
|
||||
)
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/compress", response_model=CompressImageResponse, summary="Compress an image")
|
||||
async def compress_image(
|
||||
request: CompressImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Compress an image with specified quality and format settings."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image compression")
|
||||
logger.info(f"[Compression] Request from user {user_id}: format={request.format}, quality={request.quality}")
|
||||
|
||||
from services.image_studio.compression_service import CompressionRequest as ServiceRequest
|
||||
|
||||
compression_request = ServiceRequest(
|
||||
image_base64=request.image_base64,
|
||||
quality=request.quality,
|
||||
format=request.format,
|
||||
target_size_kb=request.target_size_kb,
|
||||
strip_metadata=request.strip_metadata,
|
||||
progressive=request.progressive,
|
||||
optimize=request.optimize,
|
||||
)
|
||||
|
||||
result = await studio_manager.compress_image(compression_request, user_id=user_id)
|
||||
|
||||
return CompressImageResponse(
|
||||
success=result.success,
|
||||
image_base64=result.image_base64,
|
||||
original_size_kb=result.original_size_kb,
|
||||
compressed_size_kb=result.compressed_size_kb,
|
||||
compression_ratio=result.compression_ratio,
|
||||
format=result.format,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
quality_used=result.quality_used,
|
||||
metadata_stripped=result.metadata_stripped,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image compression failed: {e}")
|
||||
|
||||
|
||||
@router.post("/compress/batch", response_model=CompressBatchResponse, summary="Compress multiple images")
|
||||
async def compress_batch(
|
||||
request: CompressBatchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Compress multiple images with the same or individual settings."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "batch compression")
|
||||
logger.info(f"[Compression] Batch request from user {user_id}: {len(request.images)} images")
|
||||
|
||||
from services.image_studio.compression_service import CompressionRequest as ServiceRequest
|
||||
|
||||
compression_requests = [
|
||||
ServiceRequest(
|
||||
image_base64=img.image_base64,
|
||||
quality=img.quality,
|
||||
format=img.format,
|
||||
target_size_kb=img.target_size_kb,
|
||||
strip_metadata=img.strip_metadata,
|
||||
progressive=img.progressive,
|
||||
optimize=img.optimize,
|
||||
)
|
||||
for img in request.images
|
||||
]
|
||||
|
||||
results = await studio_manager.compress_batch(compression_requests, user_id=user_id)
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = len(results) - successful
|
||||
|
||||
return CompressBatchResponse(
|
||||
success=failed == 0,
|
||||
results=[
|
||||
CompressImageResponse(
|
||||
success=r.success,
|
||||
image_base64=r.image_base64,
|
||||
original_size_kb=r.original_size_kb,
|
||||
compressed_size_kb=r.compressed_size_kb,
|
||||
compression_ratio=r.compression_ratio,
|
||||
format=r.format,
|
||||
width=r.width,
|
||||
height=r.height,
|
||||
quality_used=r.quality_used,
|
||||
metadata_stripped=r.metadata_stripped,
|
||||
)
|
||||
for r in results
|
||||
],
|
||||
total_images=len(results),
|
||||
successful=successful,
|
||||
failed=failed,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] ❌ Batch error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Batch compression failed: {e}")
|
||||
|
||||
|
||||
@router.post("/compress/estimate", response_model=CompressionEstimateResponse, summary="Estimate compression results")
|
||||
async def estimate_compression(
|
||||
request: CompressionEstimateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Estimate compression results without actually compressing the image."""
|
||||
try:
|
||||
result = await studio_manager.estimate_compression(
|
||||
request.image_base64,
|
||||
request.format,
|
||||
request.quality,
|
||||
)
|
||||
return CompressionEstimateResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"[Compression] ❌ Estimate error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Compression estimation failed: {e}")
|
||||
|
||||
|
||||
@router.get("/compress/formats", response_model=CompressionFormatsResponse, summary="Get supported compression formats")
|
||||
async def get_compression_formats(
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get list of supported compression formats with their capabilities."""
|
||||
formats = studio_manager.get_compression_formats()
|
||||
return CompressionFormatsResponse(formats=formats)
|
||||
|
||||
|
||||
@router.get("/compress/presets", response_model=CompressionPresetsResponse, summary="Get compression presets")
|
||||
async def get_compression_presets(
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get predefined compression presets for common use cases."""
|
||||
presets = studio_manager.get_compression_presets()
|
||||
return CompressionPresetsResponse(presets=presets)
|
||||
@@ -1,64 +0,0 @@
|
||||
"""Control Studio endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import ControlImageRequest, ControlImageResponse, ControlOperationsResponse
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, ControlStudioRequest
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/control/process", response_model=ControlImageResponse, summary="Process Control Studio request")
|
||||
async def process_control_image(
|
||||
request: ControlImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Perform Control Studio operations such as sketch-to-image, structure control, style control, and style transfer."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image control")
|
||||
logger.info(f"[Control Image] Request from user {user_id}: operation={request.operation}")
|
||||
|
||||
control_request = ControlStudioRequest(
|
||||
operation=request.operation,
|
||||
prompt=request.prompt,
|
||||
control_image_base64=request.control_image_base64,
|
||||
style_image_base64=request.style_image_base64,
|
||||
negative_prompt=request.negative_prompt,
|
||||
control_strength=request.control_strength,
|
||||
fidelity=request.fidelity,
|
||||
style_strength=request.style_strength,
|
||||
composition_fidelity=request.composition_fidelity,
|
||||
change_strength=request.change_strength,
|
||||
aspect_ratio=request.aspect_ratio,
|
||||
style_preset=request.style_preset,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
)
|
||||
|
||||
result = await studio_manager.control_image(control_request, user_id=user_id)
|
||||
return ControlImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Control Image] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image control failed: {e}")
|
||||
|
||||
|
||||
@router.get("/control/operations", response_model=ControlOperationsResponse, summary="List Control Studio operations")
|
||||
async def get_control_operations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Return metadata for supported Control Studio operations."""
|
||||
try:
|
||||
operations = studio_manager.get_control_operations()
|
||||
return ControlOperationsResponse(operations=operations)
|
||||
except Exception as e:
|
||||
logger.error(f"[Control Operations] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to load control operations")
|
||||
@@ -1,143 +0,0 @@
|
||||
"""Format Converter endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
from .models import (
|
||||
ConvertFormatRequest, ConvertFormatResponse,
|
||||
ConvertFormatBatchRequest, ConvertFormatBatchResponse,
|
||||
SupportedFormatsResponse, FormatRecommendationsResponse,
|
||||
)
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/convert-format", response_model=ConvertFormatResponse, summary="Convert image format")
|
||||
async def convert_format(
|
||||
request: ConvertFormatRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Convert an image to a different format."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "format conversion")
|
||||
logger.info(f"[Format Converter] Request from user {user_id}: {request.target_format}")
|
||||
|
||||
from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest
|
||||
|
||||
conversion_request = ServiceRequest(
|
||||
image_base64=request.image_base64,
|
||||
target_format=request.target_format,
|
||||
preserve_transparency=request.preserve_transparency,
|
||||
quality=request.quality,
|
||||
color_space=request.color_space,
|
||||
strip_metadata=request.strip_metadata,
|
||||
optimize=request.optimize,
|
||||
progressive=request.progressive,
|
||||
)
|
||||
|
||||
result = await studio_manager.convert_format(conversion_request, user_id=user_id)
|
||||
|
||||
return ConvertFormatResponse(
|
||||
success=result.success,
|
||||
image_base64=result.image_base64,
|
||||
original_format=result.original_format,
|
||||
target_format=result.target_format,
|
||||
original_size_kb=result.original_size_kb,
|
||||
converted_size_kb=result.converted_size_kb,
|
||||
width=result.width,
|
||||
height=result.height,
|
||||
transparency_preserved=result.transparency_preserved,
|
||||
metadata_preserved=result.metadata_preserved,
|
||||
color_space=result.color_space,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Format Converter] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Format conversion failed: {e}")
|
||||
|
||||
|
||||
@router.post("/convert-format/batch", response_model=ConvertFormatBatchResponse, summary="Convert multiple images")
|
||||
async def convert_format_batch(
|
||||
request: ConvertFormatBatchRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Convert multiple images to different formats."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "batch format conversion")
|
||||
logger.info(f"[Format Converter] Batch request from user {user_id}: {len(request.images)} images")
|
||||
|
||||
from services.image_studio.format_converter_service import FormatConversionRequest as ServiceRequest
|
||||
|
||||
conversion_requests = [
|
||||
ServiceRequest(
|
||||
image_base64=img.image_base64,
|
||||
target_format=img.target_format,
|
||||
preserve_transparency=img.preserve_transparency,
|
||||
quality=img.quality,
|
||||
color_space=img.color_space,
|
||||
strip_metadata=img.strip_metadata,
|
||||
optimize=img.optimize,
|
||||
progressive=img.progressive,
|
||||
)
|
||||
for img in request.images
|
||||
]
|
||||
|
||||
results = await studio_manager.convert_format_batch(conversion_requests, user_id=user_id)
|
||||
|
||||
successful = sum(1 for r in results if r.success)
|
||||
failed = len(results) - successful
|
||||
|
||||
return ConvertFormatBatchResponse(
|
||||
success=failed == 0,
|
||||
results=[
|
||||
ConvertFormatResponse(
|
||||
success=r.success,
|
||||
image_base64=r.image_base64,
|
||||
original_format=r.original_format,
|
||||
target_format=r.target_format,
|
||||
original_size_kb=r.original_size_kb,
|
||||
converted_size_kb=r.converted_size_kb,
|
||||
width=r.width,
|
||||
height=r.height,
|
||||
transparency_preserved=r.transparency_preserved,
|
||||
metadata_preserved=r.metadata_preserved,
|
||||
color_space=r.color_space,
|
||||
)
|
||||
for r in results
|
||||
],
|
||||
total_images=len(results),
|
||||
successful=successful,
|
||||
failed=failed,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Format Converter] ❌ Batch error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Batch format conversion failed: {e}")
|
||||
|
||||
|
||||
@router.get("/convert-format/supported", response_model=SupportedFormatsResponse, summary="Get supported formats")
|
||||
async def get_supported_formats(
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get list of supported conversion formats with their capabilities."""
|
||||
formats = studio_manager.get_supported_formats()
|
||||
return SupportedFormatsResponse(formats=formats)
|
||||
|
||||
|
||||
@router.get("/convert-format/recommendations", response_model=FormatRecommendationsResponse, summary="Get format recommendations")
|
||||
async def get_format_recommendations(
|
||||
source_format: str = Query(..., description="Source format"),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get format recommendations based on source format."""
|
||||
recommendations = studio_manager.get_format_recommendations(source_format)
|
||||
return FormatRecommendationsResponse(recommendations=recommendations)
|
||||
@@ -1,231 +0,0 @@
|
||||
"""Create Studio, Templates, Providers, Cost Estimation, and Platform Specs endpoints."""
|
||||
|
||||
import base64
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import CreateImageRequest, CostEstimationRequest
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, CreateStudioRequest
|
||||
from services.image_studio.templates import Platform, TemplateCategory
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/create", summary="Generate Image")
|
||||
async def create_image(
|
||||
request: CreateImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Generate image(s) using Create Studio."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image generation")
|
||||
logger.info(f"[Create Image] Request from user {user_id}: {request.prompt[:100]}")
|
||||
|
||||
studio_request = CreateStudioRequest(
|
||||
prompt=request.prompt,
|
||||
template_id=request.template_id,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
width=request.width,
|
||||
height=request.height,
|
||||
aspect_ratio=request.aspect_ratio,
|
||||
style_preset=request.style_preset,
|
||||
quality=request.quality,
|
||||
negative_prompt=request.negative_prompt,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed,
|
||||
num_variations=request.num_variations,
|
||||
enhance_prompt=request.enhance_prompt,
|
||||
use_persona=request.use_persona,
|
||||
persona_id=request.persona_id,
|
||||
)
|
||||
|
||||
result = await studio_manager.create_image(studio_request, user_id=user_id)
|
||||
|
||||
for idx, img_result in enumerate(result["results"]):
|
||||
if "image_bytes" in img_result:
|
||||
img_result["image_base64"] = base64.b64encode(img_result["image_bytes"]).decode("utf-8")
|
||||
del img_result["image_bytes"]
|
||||
|
||||
logger.info(f"[Create Image] ✅ Success: {result['total_generated']} images generated")
|
||||
return result
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Create Image] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except RuntimeError as e:
|
||||
logger.error(f"[Create Image] ❌ Generation error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Image generation failed: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Create Image] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/templates", summary="Get Templates")
|
||||
async def get_templates(
|
||||
platform: Optional[Platform] = None,
|
||||
category: Optional[TemplateCategory] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Get available image templates."""
|
||||
try:
|
||||
templates = studio_manager.get_templates(platform=platform, category=category)
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
return {"templates": templates_dict, "total": len(templates_dict)}
|
||||
except Exception as e:
|
||||
logger.error(f"[Get Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/templates/search", summary="Search Templates")
|
||||
async def search_templates(
|
||||
query: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Search templates by query."""
|
||||
try:
|
||||
templates = studio_manager.search_templates(query)
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
return {"templates": templates_dict, "total": len(templates_dict), "query": query}
|
||||
except Exception as e:
|
||||
logger.error(f"[Search Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/templates/recommend", summary="Recommend Templates")
|
||||
async def recommend_templates(
|
||||
use_case: str,
|
||||
platform: Optional[Platform] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Recommend templates based on use case."""
|
||||
try:
|
||||
templates = studio_manager.recommend_templates(use_case, platform=platform)
|
||||
templates_dict = [
|
||||
{
|
||||
"id": t.id,
|
||||
"name": t.name,
|
||||
"category": t.category.value,
|
||||
"platform": t.platform.value if t.platform else None,
|
||||
"aspect_ratio": {
|
||||
"ratio": t.aspect_ratio.ratio,
|
||||
"width": t.aspect_ratio.width,
|
||||
"height": t.aspect_ratio.height,
|
||||
"label": t.aspect_ratio.label,
|
||||
},
|
||||
"description": t.description,
|
||||
"recommended_provider": t.recommended_provider,
|
||||
"style_preset": t.style_preset,
|
||||
"quality": t.quality,
|
||||
"use_cases": t.use_cases or [],
|
||||
}
|
||||
for t in templates
|
||||
]
|
||||
return {"templates": templates_dict, "total": len(templates_dict), "use_case": use_case}
|
||||
except Exception as e:
|
||||
logger.error(f"[Recommend Templates] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/providers", summary="Get Providers")
|
||||
async def get_providers(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Get available AI providers and their capabilities."""
|
||||
try:
|
||||
providers = studio_manager.get_providers()
|
||||
return {"providers": providers}
|
||||
except Exception as e:
|
||||
logger.error(f"[Get Providers] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/estimate-cost", summary="Estimate Cost")
|
||||
async def estimate_cost(
|
||||
request: CostEstimationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Estimate cost for image generation operations."""
|
||||
try:
|
||||
resolution = None
|
||||
if request.width and request.height:
|
||||
resolution = (request.width, request.height)
|
||||
estimate = studio_manager.estimate_cost(
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
operation=request.operation,
|
||||
num_images=request.num_images,
|
||||
resolution=resolution
|
||||
)
|
||||
return estimate
|
||||
except Exception as e:
|
||||
logger.error(f"[Estimate Cost] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/platform-specs/{platform}", summary="Get Platform Specifications")
|
||||
async def get_platform_specs(
|
||||
platform: Platform,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager)
|
||||
):
|
||||
"""Get specifications and requirements for a specific platform."""
|
||||
try:
|
||||
specs = studio_manager.get_platform_specs(platform)
|
||||
if not specs:
|
||||
raise HTTPException(status_code=404, detail=f"Specifications not found for platform: {platform}")
|
||||
return specs
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Get Platform Specs] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,35 +0,0 @@
|
||||
"""Shared dependencies for Image Studio API endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import Depends, HTTPException, status
|
||||
|
||||
from services.image_studio import ImageStudioManager
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
|
||||
|
||||
def get_studio_manager() -> ImageStudioManager:
|
||||
"""Get Image Studio Manager instance."""
|
||||
return ImageStudioManager()
|
||||
|
||||
|
||||
def _require_user_id(current_user: Dict[str, Any], operation: str) -> str:
|
||||
"""Ensure user_id is available for protected operations."""
|
||||
user_id = (
|
||||
current_user.get("sub")
|
||||
or current_user.get("user_id")
|
||||
or current_user.get("id")
|
||||
or current_user.get("clerk_user_id")
|
||||
)
|
||||
if not user_id:
|
||||
logger.error(
|
||||
"[Image Studio] ❌ Missing user_id for %s operation - blocking request",
|
||||
operation,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Authenticated user required for image operations.",
|
||||
)
|
||||
return user_id
|
||||
@@ -1,122 +0,0 @@
|
||||
"""Edit Studio endpoints."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import (
|
||||
EditImageRequest, EditImageResponse, EditOperationsResponse,
|
||||
EditModelsResponse, EditModelRecommendationRequest, EditModelRecommendationResponse,
|
||||
)
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, EditStudioRequest
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/edit/process", response_model=EditImageResponse, summary="Process Edit Studio request")
|
||||
async def process_edit_image(
|
||||
request: EditImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Perform Edit Studio operations such as remove background, inpaint, or recolor."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image editing")
|
||||
logger.info(f"[Edit Image] Request from user {user_id}: operation={request.operation}")
|
||||
|
||||
edit_request = EditStudioRequest(
|
||||
image_base64=request.image_base64,
|
||||
operation=request.operation,
|
||||
prompt=request.prompt,
|
||||
negative_prompt=request.negative_prompt,
|
||||
mask_base64=request.mask_base64,
|
||||
search_prompt=request.search_prompt,
|
||||
select_prompt=request.select_prompt,
|
||||
background_image_base64=request.background_image_base64,
|
||||
lighting_image_base64=request.lighting_image_base64,
|
||||
expand_left=request.expand_left,
|
||||
expand_right=request.expand_right,
|
||||
expand_up=request.expand_up,
|
||||
expand_down=request.expand_down,
|
||||
provider=request.provider,
|
||||
model=request.model,
|
||||
style_preset=request.style_preset,
|
||||
guidance_scale=request.guidance_scale,
|
||||
steps=request.steps,
|
||||
seed=request.seed,
|
||||
output_format=request.output_format,
|
||||
options=request.options or {},
|
||||
)
|
||||
|
||||
result = await studio_manager.edit_image(edit_request, user_id=user_id)
|
||||
return EditImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Edit Image] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image editing failed: {e}")
|
||||
|
||||
|
||||
@router.get("/edit/operations", response_model=EditOperationsResponse, summary="List Edit Studio operations")
|
||||
async def get_edit_operations(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Return metadata for supported Edit Studio operations."""
|
||||
try:
|
||||
operations = studio_manager.get_edit_operations()
|
||||
return EditOperationsResponse(operations=operations)
|
||||
except Exception as e:
|
||||
logger.error(f"[Edit Operations] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to load edit operations")
|
||||
|
||||
|
||||
@router.get("/edit/models", response_model=EditModelsResponse, summary="List available editing models")
|
||||
async def get_edit_models(
|
||||
operation: Optional[str] = None,
|
||||
tier: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get available WaveSpeed editing models with metadata.
|
||||
|
||||
Query Parameters:
|
||||
- operation: Filter by operation type (e.g., "general_edit")
|
||||
- tier: Filter by tier ("budget", "mid", "premium")
|
||||
"""
|
||||
try:
|
||||
result = studio_manager.get_edit_models(operation=operation, tier=tier)
|
||||
return EditModelsResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"[Edit Models] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to load editing models")
|
||||
|
||||
|
||||
@router.post("/edit/recommend", response_model=EditModelRecommendationResponse, summary="Get model recommendation")
|
||||
async def recommend_edit_model(
|
||||
request: EditModelRecommendationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get recommended editing model based on operation, image resolution, and user preferences.
|
||||
|
||||
Auto-detects best model when user doesn't specify one.
|
||||
"""
|
||||
try:
|
||||
user_tier = request.user_tier
|
||||
if not user_tier and current_user:
|
||||
user_tier = current_user.get("tier") or current_user.get("subscription_tier")
|
||||
|
||||
result = studio_manager.recommend_edit_model(
|
||||
operation=request.operation,
|
||||
image_resolution=request.image_resolution,
|
||||
user_tier=user_tier,
|
||||
preferences=request.preferences,
|
||||
)
|
||||
return EditModelRecommendationResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"[Edit Recommend] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
|
||||
@@ -1,89 +0,0 @@
|
||||
"""Face Swap Studio endpoints."""
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import (
|
||||
FaceSwapRequest, FaceSwapResponse, FaceSwapModelsResponse,
|
||||
FaceSwapModelRecommendationRequest, FaceSwapModelRecommendationResponse,
|
||||
)
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager
|
||||
from services.image_studio.face_swap_service import FaceSwapStudioRequest
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/face-swap/process", response_model=FaceSwapResponse, summary="Process Face Swap")
|
||||
async def process_face_swap(
|
||||
request: FaceSwapRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Process face swap request with auto-detection and model selection."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "face swap")
|
||||
face_swap_request = FaceSwapStudioRequest(
|
||||
base_image_base64=request.base_image_base64,
|
||||
face_image_base64=request.face_image_base64,
|
||||
model=request.model,
|
||||
target_face_index=request.target_face_index,
|
||||
target_gender=request.target_gender,
|
||||
options=request.options,
|
||||
)
|
||||
result = await studio_manager.face_swap(face_swap_request, user_id=user_id)
|
||||
return FaceSwapResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Face Swap] ❌ Error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Face swap failed: {e}")
|
||||
|
||||
|
||||
@router.get("/face-swap/models", response_model=FaceSwapModelsResponse, summary="List available face swap models")
|
||||
async def get_face_swap_models(
|
||||
tier: Optional[str] = None,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get available WaveSpeed face swap models with metadata.
|
||||
|
||||
Query Parameters:
|
||||
- tier: Filter by tier ("budget", "mid", "premium")
|
||||
"""
|
||||
try:
|
||||
result = studio_manager.get_face_swap_models(tier=tier)
|
||||
return FaceSwapModelsResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"[Face Swap Models] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail="Failed to load face swap models")
|
||||
|
||||
|
||||
@router.post("/face-swap/recommend", response_model=FaceSwapModelRecommendationResponse, summary="Get face swap model recommendation")
|
||||
async def recommend_face_swap_model(
|
||||
request: FaceSwapModelRecommendationRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get recommended face swap model based on image resolutions and user preferences.
|
||||
|
||||
Auto-detects best model when user doesn't specify one.
|
||||
"""
|
||||
try:
|
||||
user_tier = request.user_tier
|
||||
if not user_tier and current_user:
|
||||
user_tier = current_user.get("tier") or current_user.get("subscription_tier")
|
||||
|
||||
result = studio_manager.recommend_face_swap_model(
|
||||
base_image_resolution=request.base_image_resolution,
|
||||
face_image_resolution=request.face_image_resolution,
|
||||
user_tier=user_tier,
|
||||
preferences=request.preferences,
|
||||
)
|
||||
return FaceSwapModelRecommendationResponse(**result)
|
||||
except Exception as e:
|
||||
logger.error(f"[Face Swap Recommend] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to get recommendation: {e}")
|
||||
@@ -1,21 +0,0 @@
|
||||
"""Health check endpoint."""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.get("/health", summary="Health Check")
|
||||
async def health_check():
|
||||
"""Health check endpoint for Image Studio."""
|
||||
return {
|
||||
"status": "healthy",
|
||||
"service": "image_studio",
|
||||
"version": "1.0.0",
|
||||
"modules": {
|
||||
"create_studio": "available",
|
||||
"templates": "available",
|
||||
"providers": "available",
|
||||
"compression": "available",
|
||||
}
|
||||
}
|
||||
@@ -1,372 +0,0 @@
|
||||
"""Pydantic request/response models for Image Studio API."""
|
||||
|
||||
from typing import Optional, List, Dict, Any, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ==================== Create Studio ====================
|
||||
|
||||
class CreateImageRequest(BaseModel):
|
||||
prompt: str = Field(..., description="Image generation prompt")
|
||||
template_id: Optional[str] = Field(None, description="Template ID to use")
|
||||
provider: Optional[str] = Field("auto", description="Provider: auto, stability, wavespeed, huggingface, gemini")
|
||||
model: Optional[str] = Field(None, description="Specific model to use")
|
||||
width: Optional[int] = Field(None, description="Image width in pixels")
|
||||
height: Optional[int] = Field(None, description="Image height in pixels")
|
||||
aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (e.g., '1:1', '16:9')")
|
||||
style_preset: Optional[str] = Field(None, description="Style preset")
|
||||
quality: str = Field("standard", description="Quality: draft, standard, premium")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
|
||||
guidance_scale: Optional[float] = Field(None, description="Guidance scale")
|
||||
steps: Optional[int] = Field(None, description="Number of inference steps")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
num_variations: int = Field(1, ge=1, le=10, description="Number of variations (1-10)")
|
||||
enhance_prompt: bool = Field(True, description="Enhance prompt with AI")
|
||||
use_persona: bool = Field(False, description="Use persona for brand consistency")
|
||||
persona_id: Optional[str] = Field(None, description="Persona ID")
|
||||
|
||||
|
||||
class CostEstimationRequest(BaseModel):
|
||||
provider: str = Field(..., description="Provider name")
|
||||
model: Optional[str] = Field(None, description="Model name")
|
||||
operation: str = Field("generate", description="Operation type")
|
||||
num_images: int = Field(1, ge=1, description="Number of images")
|
||||
width: Optional[int] = Field(None, description="Image width")
|
||||
height: Optional[int] = Field(None, description="Image height")
|
||||
|
||||
|
||||
# ==================== Edit Studio ====================
|
||||
|
||||
class EditImageRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Primary image payload (base64 or data URL)")
|
||||
operation: Literal[
|
||||
"remove_background",
|
||||
"inpaint",
|
||||
"outpaint",
|
||||
"search_replace",
|
||||
"search_recolor",
|
||||
"general_edit",
|
||||
] = Field(..., description="Edit operation to perform")
|
||||
prompt: Optional[str] = Field(None, description="Primary prompt/instruction")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt for providers that support it")
|
||||
mask_base64: Optional[str] = Field(None, description="Optional mask image in base64")
|
||||
search_prompt: Optional[str] = Field(None, description="Search prompt for replace operations")
|
||||
select_prompt: Optional[str] = Field(None, description="Select prompt for recolor operations")
|
||||
background_image_base64: Optional[str] = Field(None, description="Reference background image")
|
||||
lighting_image_base64: Optional[str] = Field(None, description="Reference lighting image")
|
||||
expand_left: Optional[int] = Field(0, description="Outpaint expansion in pixels (left)")
|
||||
expand_right: Optional[int] = Field(0, description="Outpaint expansion in pixels (right)")
|
||||
expand_up: Optional[int] = Field(0, description="Outpaint expansion in pixels (up)")
|
||||
expand_down: Optional[int] = Field(0, description="Outpaint expansion in pixels (down)")
|
||||
provider: Optional[str] = Field(None, description="Explicit provider override")
|
||||
model: Optional[str] = Field(None, description="Explicit model override")
|
||||
style_preset: Optional[str] = Field(None, description="Style preset for Stability helpers")
|
||||
guidance_scale: Optional[float] = Field(None, description="Guidance scale for general edits")
|
||||
steps: Optional[int] = Field(None, description="Inference steps")
|
||||
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
|
||||
output_format: str = Field("png", description="Output format for edited image")
|
||||
options: Optional[Dict[str, Any]] = Field(None, description="Advanced provider-specific options (e.g., grow_mask)")
|
||||
|
||||
|
||||
class EditImageResponse(BaseModel):
|
||||
success: bool
|
||||
operation: str
|
||||
provider: str
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class EditOperationsResponse(BaseModel):
|
||||
operations: Dict[str, Dict[str, Any]]
|
||||
|
||||
|
||||
class EditModelsResponse(BaseModel):
|
||||
models: List[Dict[str, Any]]
|
||||
total: int
|
||||
|
||||
|
||||
class EditModelRecommendationRequest(BaseModel):
|
||||
operation: str
|
||||
image_resolution: Optional[Dict[str, int]] = None
|
||||
user_tier: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class EditModelRecommendationResponse(BaseModel):
|
||||
recommended_model: str
|
||||
reason: str
|
||||
alternatives: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Face Swap Studio ====================
|
||||
|
||||
class FaceSwapRequest(BaseModel):
|
||||
base_image_base64: str
|
||||
face_image_base64: str
|
||||
model: Optional[str] = None
|
||||
target_face_index: Optional[int] = None
|
||||
target_gender: Optional[str] = None
|
||||
options: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class FaceSwapResponse(BaseModel):
|
||||
success: bool
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
provider: str
|
||||
model: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class FaceSwapModelsResponse(BaseModel):
|
||||
models: List[Dict[str, Any]]
|
||||
total: int
|
||||
|
||||
|
||||
class FaceSwapModelRecommendationRequest(BaseModel):
|
||||
base_image_resolution: Optional[Dict[str, int]] = None
|
||||
face_image_resolution: Optional[Dict[str, int]] = None
|
||||
user_tier: Optional[str] = None
|
||||
preferences: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class FaceSwapModelRecommendationResponse(BaseModel):
|
||||
recommended_model: str
|
||||
reason: str
|
||||
alternatives: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Upscale Studio ====================
|
||||
|
||||
class UpscaleImageRequest(BaseModel):
|
||||
image_base64: str
|
||||
mode: Literal["fast", "conservative", "creative", "auto"] = "auto"
|
||||
target_width: Optional[int] = Field(None, description="Target width in pixels")
|
||||
target_height: Optional[int] = Field(None, description="Target height in pixels")
|
||||
preset: Optional[str] = Field(None, description="Named preset (web, print, social)")
|
||||
prompt: Optional[str] = Field(None, description="Prompt for conservative/creative modes")
|
||||
|
||||
|
||||
class UpscaleImageResponse(BaseModel):
|
||||
success: bool
|
||||
mode: str
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
# ==================== Control Studio ====================
|
||||
|
||||
class ControlImageRequest(BaseModel):
|
||||
control_image_base64: str = Field(..., description="Control image (sketch/structure/style) in base64")
|
||||
operation: Literal["sketch", "structure", "style", "style_transfer"] = Field(..., description="Control operation")
|
||||
prompt: str = Field(..., description="Text prompt for generation")
|
||||
style_image_base64: Optional[str] = Field(None, description="Style reference image (for style_transfer only)")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
|
||||
control_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Control strength (sketch/structure)")
|
||||
fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style fidelity (style operation)")
|
||||
style_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Style strength (style_transfer)")
|
||||
composition_fidelity: Optional[float] = Field(None, ge=0.0, le=1.0, description="Composition fidelity (style_transfer)")
|
||||
change_strength: Optional[float] = Field(None, ge=0.0, le=1.0, description="Change strength (style_transfer)")
|
||||
aspect_ratio: Optional[str] = Field(None, description="Aspect ratio (style operation)")
|
||||
style_preset: Optional[str] = Field(None, description="Style preset")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
output_format: str = Field("png", description="Output format")
|
||||
|
||||
|
||||
class ControlImageResponse(BaseModel):
|
||||
success: bool
|
||||
operation: str
|
||||
provider: str
|
||||
image_base64: str
|
||||
width: int
|
||||
height: int
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class ControlOperationsResponse(BaseModel):
|
||||
operations: Dict[str, Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Social Optimizer ====================
|
||||
|
||||
class SocialOptimizeRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Source image in base64 or data URL")
|
||||
platforms: List[str] = Field(..., description="List of platforms to optimize for")
|
||||
format_names: Optional[Dict[str, str]] = Field(None, description="Specific format per platform")
|
||||
show_safe_zones: bool = Field(False, description="Include safe zone overlay in output")
|
||||
crop_mode: str = Field("smart", description="Crop mode: smart, center, or fit")
|
||||
focal_point: Optional[Dict[str, float]] = Field(None, description="Focal point for smart crop (x, y as 0-1)")
|
||||
output_format: str = Field("png", description="Output format (png or jpg)")
|
||||
|
||||
|
||||
class SocialOptimizeResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[Dict[str, Any]]
|
||||
total_optimized: int
|
||||
|
||||
|
||||
class PlatformFormatsResponse(BaseModel):
|
||||
formats: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Transform Studio ====================
|
||||
|
||||
class TransformImageToVideoRequestModel(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
prompt: str = Field(..., description="Text prompt describing the video")
|
||||
audio_base64: Optional[str] = Field(None, description="Optional audio file (wav/mp3, 3-30s, ≤15MB)")
|
||||
resolution: Literal["480p", "720p", "1080p"] = Field("720p", description="Output resolution")
|
||||
duration: Literal[5, 10] = Field(5, description="Video duration in seconds")
|
||||
negative_prompt: Optional[str] = Field(None, description="Negative prompt")
|
||||
seed: Optional[int] = Field(None, description="Random seed for reproducibility")
|
||||
enable_prompt_expansion: bool = Field(True, description="Enable prompt optimizer")
|
||||
|
||||
|
||||
class TalkingAvatarRequestModel(BaseModel):
|
||||
image_base64: str = Field(..., description="Person image in base64 or data URL")
|
||||
audio_base64: str = Field(..., description="Audio file in base64 or data URL (wav/mp3, max 10 minutes)")
|
||||
resolution: Literal["480p", "720p"] = Field("720p", description="Output resolution")
|
||||
prompt: Optional[str] = Field(None, description="Optional prompt for expression/style")
|
||||
mask_image_base64: Optional[str] = Field(None, description="Optional mask for animatable regions")
|
||||
seed: Optional[int] = Field(None, description="Random seed")
|
||||
|
||||
|
||||
class TransformVideoResponse(BaseModel):
|
||||
success: bool
|
||||
video_url: Optional[str] = None
|
||||
video_base64: Optional[str] = None
|
||||
duration: float
|
||||
resolution: str
|
||||
width: int
|
||||
height: int
|
||||
file_size: int
|
||||
cost: float
|
||||
provider: str
|
||||
model: str
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class TransformCostEstimateRequest(BaseModel):
|
||||
operation: Literal["image-to-video", "talking-avatar"] = Field(..., description="Operation type")
|
||||
resolution: str = Field(..., description="Output resolution")
|
||||
duration: Optional[int] = Field(None, description="Video duration in seconds (for image-to-video)")
|
||||
|
||||
|
||||
class TransformCostEstimateResponse(BaseModel):
|
||||
estimated_cost: float
|
||||
breakdown: Dict[str, Any]
|
||||
currency: str
|
||||
provider: str
|
||||
model: str
|
||||
|
||||
|
||||
# ==================== Compression ====================
|
||||
|
||||
class CompressImageRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
quality: int = Field(85, ge=1, le=100, description="Compression quality (1-100)")
|
||||
format: str = Field("jpeg", description="Output format: jpeg, png, webp")
|
||||
target_size_kb: Optional[int] = Field(None, ge=10, description="Target file size in KB")
|
||||
strip_metadata: bool = Field(True, description="Remove EXIF metadata")
|
||||
progressive: bool = Field(True, description="Progressive JPEG encoding")
|
||||
optimize: bool = Field(True, description="Optimize encoding")
|
||||
|
||||
|
||||
class CompressImageResponse(BaseModel):
|
||||
success: bool
|
||||
image_base64: str
|
||||
original_size_kb: float
|
||||
compressed_size_kb: float
|
||||
compression_ratio: float
|
||||
format: str
|
||||
width: int
|
||||
height: int
|
||||
quality_used: int
|
||||
metadata_stripped: bool
|
||||
|
||||
|
||||
class CompressBatchRequest(BaseModel):
|
||||
images: List[CompressImageRequest] = Field(..., description="List of images to compress")
|
||||
|
||||
|
||||
class CompressBatchResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[CompressImageResponse]
|
||||
total_images: int
|
||||
successful: int
|
||||
failed: int
|
||||
|
||||
|
||||
class CompressionEstimateRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
format: str = Field("jpeg", description="Output format")
|
||||
quality: int = Field(85, ge=1, le=100, description="Quality level")
|
||||
|
||||
|
||||
class CompressionEstimateResponse(BaseModel):
|
||||
original_size_kb: float
|
||||
estimated_size_kb: float
|
||||
estimated_reduction_percent: float
|
||||
width: int
|
||||
height: int
|
||||
format: str
|
||||
|
||||
|
||||
class CompressionFormatsResponse(BaseModel):
|
||||
formats: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class CompressionPresetsResponse(BaseModel):
|
||||
presets: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# ==================== Format Converter ====================
|
||||
|
||||
class ConvertFormatRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Image in base64 or data URL format")
|
||||
target_format: str = Field(..., description="Target format: png, jpeg, jpg, webp, gif, bmp, tiff")
|
||||
preserve_transparency: bool = Field(True, description="Preserve transparency when possible")
|
||||
quality: Optional[int] = Field(None, ge=1, le=100, description="Quality for lossy formats (1-100)")
|
||||
color_space: Optional[str] = Field(None, description="Color space: sRGB, Adobe RGB")
|
||||
strip_metadata: bool = Field(False, description="Remove EXIF metadata")
|
||||
optimize: bool = Field(True, description="Optimize encoding")
|
||||
progressive: bool = Field(True, description="Progressive JPEG encoding")
|
||||
|
||||
|
||||
class ConvertFormatResponse(BaseModel):
|
||||
success: bool
|
||||
image_base64: str
|
||||
original_format: str
|
||||
target_format: str
|
||||
original_size_kb: float
|
||||
converted_size_kb: float
|
||||
width: int
|
||||
height: int
|
||||
transparency_preserved: bool
|
||||
metadata_preserved: bool
|
||||
color_space: Optional[str] = None
|
||||
|
||||
|
||||
class ConvertFormatBatchRequest(BaseModel):
|
||||
images: List[ConvertFormatRequest] = Field(..., description="List of images to convert")
|
||||
|
||||
|
||||
class ConvertFormatBatchResponse(BaseModel):
|
||||
success: bool
|
||||
results: List[ConvertFormatResponse]
|
||||
total_images: int
|
||||
successful: int
|
||||
failed: int
|
||||
|
||||
|
||||
class SupportedFormatsResponse(BaseModel):
|
||||
formats: List[Dict[str, Any]]
|
||||
|
||||
|
||||
class FormatRecommendationsResponse(BaseModel):
|
||||
recommendations: List[Dict[str, Any]]
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Save generated images to the unified asset library."""
|
||||
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Dict, Any, Optional
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from .deps import _require_user_id
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from utils.logger_utils import get_service_logger
|
||||
from utils.storage_paths import get_repo_root, sanitize_user_id
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
class SaveToLibraryRequest(BaseModel):
|
||||
image_base64: str = Field(..., description="Base64-encoded image (or data URL)")
|
||||
prompt: Optional[str] = None
|
||||
provider: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
cost: Optional[float] = None
|
||||
operation: str = Field("image-generation", description="Operation type for labelling")
|
||||
output_format: str = Field("png", description="Output image format")
|
||||
|
||||
|
||||
@router.post("/save-to-library")
|
||||
async def save_to_library(
|
||||
req: SaveToLibraryRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Save a generated image to the asset library.
|
||||
|
||||
Decodes base64 image data, saves to workspace disk storage,
|
||||
and creates a record in the ContentAsset database table.
|
||||
"""
|
||||
user_id = _require_user_id(current_user, "save-to-library")
|
||||
|
||||
# Decode base64 payload
|
||||
try:
|
||||
b64data = req.image_base64
|
||||
if "base64," in b64data:
|
||||
b64data = b64data.split("base64,")[1]
|
||||
image_bytes = base64.b64decode(b64data)
|
||||
except Exception:
|
||||
raise HTTPException(status_code=400, detail="Invalid base64 image data")
|
||||
|
||||
# Generate file path under workspace
|
||||
safe_user = sanitize_user_id(user_id)
|
||||
repo_root = get_repo_root()
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
||||
filename = f"generated_{timestamp}.{req.output_format or 'png'}"
|
||||
|
||||
assets_dir = repo_root / "workspace" / f"workspace_{safe_user}" / "assets" / "images"
|
||||
assets_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
file_path = assets_dir / filename
|
||||
file_path.write_bytes(image_bytes)
|
||||
|
||||
# Build serving URL (assets_serving.py serves /{user_id}/avatars/{filename})
|
||||
file_url = f"/api/assets/{safe_user}/avatars/{filename}"
|
||||
|
||||
# Save to unified asset library via existing utility
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
|
||||
asset_id = save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
source_module="image_studio",
|
||||
filename=filename,
|
||||
file_url=file_url,
|
||||
file_path=str(file_path),
|
||||
file_size=len(image_bytes),
|
||||
mime_type=f"image/{req.output_format or 'png'}",
|
||||
title=f"Generated Image - {timestamp}",
|
||||
prompt=req.prompt,
|
||||
provider=req.provider,
|
||||
model=req.model,
|
||||
cost=req.cost,
|
||||
)
|
||||
|
||||
if not asset_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to save to asset library")
|
||||
|
||||
logger.info(f"[Save to Library] ✅ Image saved: asset_id={asset_id}, user={user_id}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"asset_id": asset_id,
|
||||
"file_url": file_url,
|
||||
"filename": filename,
|
||||
"file_size": len(image_bytes),
|
||||
}
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Social Optimizer endpoints."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import SocialOptimizeRequest, SocialOptimizeResponse, PlatformFormatsResponse
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, SocialOptimizerRequest
|
||||
from services.image_studio.templates import Platform
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/social/optimize", response_model=SocialOptimizeResponse, summary="Optimize image for social platforms")
|
||||
async def optimize_for_social(
|
||||
request: SocialOptimizeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Optimize an image for multiple social media platforms with smart cropping and safe zones."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "social optimization")
|
||||
logger.info(f"[Social Optimizer] Request from user {user_id}: platforms={request.platforms}")
|
||||
|
||||
platforms = []
|
||||
for platform_str in request.platforms:
|
||||
try:
|
||||
platforms.append(Platform(platform_str.lower()))
|
||||
except ValueError:
|
||||
logger.warning(f"[Social Optimizer] Invalid platform: {platform_str}")
|
||||
continue
|
||||
|
||||
if not platforms:
|
||||
raise HTTPException(status_code=400, detail="No valid platforms provided")
|
||||
|
||||
format_names = None
|
||||
if request.format_names:
|
||||
format_names = {}
|
||||
for platform_str, format_name in request.format_names.items():
|
||||
try:
|
||||
platform = Platform(platform_str.lower())
|
||||
format_names[platform] = format_name
|
||||
except ValueError:
|
||||
logger.warning(f"[Social Optimizer] Invalid platform in format_names: {platform_str}")
|
||||
|
||||
social_request = SocialOptimizerRequest(
|
||||
image_base64=request.image_base64,
|
||||
platforms=platforms,
|
||||
format_names=format_names,
|
||||
show_safe_zones=request.show_safe_zones,
|
||||
crop_mode=request.crop_mode,
|
||||
focal_point=request.focal_point,
|
||||
output_format=request.output_format,
|
||||
options={},
|
||||
)
|
||||
|
||||
result = await studio_manager.optimize_for_social(social_request, user_id=user_id)
|
||||
return SocialOptimizeResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Social Optimizer] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Social optimization failed: {e}")
|
||||
|
||||
|
||||
@router.get("/social/platforms/{platform}/formats", response_model=PlatformFormatsResponse, summary="Get platform formats")
|
||||
async def get_platform_formats(
|
||||
platform: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Get available formats for a social media platform."""
|
||||
try:
|
||||
try:
|
||||
platform_enum = Platform(platform.lower())
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=400, detail=f"Invalid platform: {platform}")
|
||||
|
||||
formats = studio_manager.get_social_platform_formats(platform_enum)
|
||||
return PlatformFormatsResponse(formats=formats)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Platform Formats] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to load platform formats: {e}")
|
||||
@@ -1,158 +0,0 @@
|
||||
"""Transform Studio endpoints — image-to-video, talking avatar, and video serving."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import FileResponse
|
||||
|
||||
from .models import (
|
||||
TransformImageToVideoRequestModel, TalkingAvatarRequestModel,
|
||||
TransformVideoResponse, TransformCostEstimateRequest, TransformCostEstimateResponse,
|
||||
)
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager, TransformImageToVideoRequest, TalkingAvatarRequest
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/transform/image-to-video", response_model=TransformVideoResponse, summary="Transform Image to Video")
|
||||
async def transform_image_to_video(
|
||||
request: TransformImageToVideoRequestModel,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Transform an image into a video using WAN 2.5."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image-to-video transformation")
|
||||
logger.info(f"[Transform Studio] Image-to-video request from user {user_id}: resolution={request.resolution}, duration={request.duration}s")
|
||||
|
||||
transform_request = TransformImageToVideoRequest(
|
||||
image_base64=request.image_base64,
|
||||
prompt=request.prompt,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
negative_prompt=request.negative_prompt,
|
||||
seed=request.seed,
|
||||
enable_prompt_expansion=request.enable_prompt_expansion,
|
||||
)
|
||||
|
||||
result = await studio_manager.transform_image_to_video(transform_request, user_id=user_id)
|
||||
|
||||
logger.info(f"[Transform Studio] ✅ Image-to-video completed: cost=${result['cost']:.2f}")
|
||||
return TransformVideoResponse(**result)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Video generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/transform/talking-avatar", response_model=TransformVideoResponse, summary="Create Talking Avatar")
|
||||
async def create_talking_avatar(
|
||||
request: TalkingAvatarRequestModel,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Create a talking avatar video using InfiniteTalk."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "talking avatar generation")
|
||||
logger.info(f"[Transform Studio] Talking avatar request from user {user_id}: resolution={request.resolution}")
|
||||
|
||||
avatar_request = TalkingAvatarRequest(
|
||||
image_base64=request.image_base64,
|
||||
audio_base64=request.audio_base64,
|
||||
resolution=request.resolution,
|
||||
prompt=request.prompt,
|
||||
mask_image_base64=request.mask_image_base64,
|
||||
seed=request.seed,
|
||||
)
|
||||
|
||||
result = await studio_manager.create_talking_avatar(avatar_request, user_id=user_id)
|
||||
|
||||
logger.info(f"[Transform Studio] ✅ Talking avatar completed: cost=${result['cost']:.2f}")
|
||||
return TransformVideoResponse(**result)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Validation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Unexpected error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Talking avatar generation failed: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/transform/estimate-cost", response_model=TransformCostEstimateResponse, summary="Estimate Transform Cost")
|
||||
async def estimate_transform_cost(
|
||||
request: TransformCostEstimateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Estimate cost for transform operations."""
|
||||
try:
|
||||
estimate = studio_manager.estimate_transform_cost(
|
||||
operation=request.operation,
|
||||
resolution=request.resolution,
|
||||
duration=request.duration,
|
||||
)
|
||||
return TransformCostEstimateResponse(**estimate)
|
||||
|
||||
except ValueError as e:
|
||||
logger.error(f"[Transform Studio] ❌ Cost estimation error: {str(e)}")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] ❌ Error: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/videos/{user_id}/{video_filename:path}", summary="Serve Transform Studio Video")
|
||||
async def serve_transform_video(
|
||||
user_id: str,
|
||||
video_filename: str,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user_with_query_token),
|
||||
):
|
||||
"""Serve a generated Transform Studio video file."""
|
||||
try:
|
||||
authenticated_user_id = _require_user_id(current_user, "video access")
|
||||
if authenticated_user_id != user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Access denied: You can only access your own videos"
|
||||
)
|
||||
|
||||
base_dir = Path(__file__).parent.parent.parent
|
||||
transform_videos_dir = base_dir / "transform_videos"
|
||||
video_path = transform_videos_dir / user_id / video_filename
|
||||
|
||||
try:
|
||||
resolved_video_path = video_path.resolve()
|
||||
resolved_base = transform_videos_dir.resolve()
|
||||
resolved_video_path.relative_to(resolved_base)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Invalid video path: path traversal detected"
|
||||
)
|
||||
|
||||
if not video_path.exists():
|
||||
raise HTTPException(status_code=404, detail="Video not found")
|
||||
|
||||
return FileResponse(
|
||||
path=str(video_path),
|
||||
media_type="video/mp4",
|
||||
filename=video_filename
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Transform Studio] Failed to serve video: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Upscale Studio endpoint."""
|
||||
|
||||
from typing import Dict, Any
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
from .models import UpscaleImageRequest, UpscaleImageResponse
|
||||
from .deps import get_studio_manager, _require_user_id
|
||||
from services.image_studio import ImageStudioManager
|
||||
from services.image_studio.upscale_service import UpscaleStudioRequest
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("api.image_studio")
|
||||
router = APIRouter(tags=["image-studio"])
|
||||
|
||||
|
||||
@router.post("/upscale", response_model=UpscaleImageResponse, summary="Upscale Image")
|
||||
async def upscale_image(
|
||||
request: UpscaleImageRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
studio_manager: ImageStudioManager = Depends(get_studio_manager),
|
||||
):
|
||||
"""Upscale an image using Stability AI pipelines."""
|
||||
try:
|
||||
user_id = _require_user_id(current_user, "image upscaling")
|
||||
upscale_request = UpscaleStudioRequest(
|
||||
image_base64=request.image_base64,
|
||||
mode=request.mode,
|
||||
target_width=request.target_width,
|
||||
target_height=request.target_height,
|
||||
preset=request.preset,
|
||||
prompt=request.prompt,
|
||||
)
|
||||
result = await studio_manager.upscale_image(upscale_request, user_id=user_id)
|
||||
return UpscaleImageResponse(**result)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Upscale Image] ❌ Error: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"Image upscaling failed: {e}")
|
||||
@@ -9,7 +9,6 @@ import json
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.blog_models import (
|
||||
MediumBlogGenerateRequest,
|
||||
@@ -27,7 +26,7 @@ class MediumBlogGenerator:
|
||||
def __init__(self):
|
||||
self.cache = persistent_content_cache
|
||||
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str, db: Session = None) -> MediumBlogGenerateResult:
|
||||
async def generate_medium_blog_with_progress(self, req: MediumBlogGenerateRequest, task_id: str, user_id: str) -> MediumBlogGenerateResult:
|
||||
"""Use Gemini structured JSON to generate a medium-length blog in one call.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -499,7 +499,7 @@ class DatabaseTaskManager:
|
||||
)
|
||||
blog_writer_logger.log_error(e, "outline_generation_task", context={"task_id": task_id})
|
||||
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest, user_id: str):
|
||||
async def _run_medium_generation_task(self, task_id: str, request: MediumBlogGenerateRequest):
|
||||
"""Background task to generate a medium blog using a single structured JSON call."""
|
||||
try:
|
||||
await self.update_progress(task_id, "📦 Packaging outline and metadata...", 0)
|
||||
@@ -512,7 +512,7 @@ class DatabaseTaskManager:
|
||||
result: MediumBlogGenerateResult = await self.service.generate_medium_blog_with_progress(
|
||||
request,
|
||||
task_id,
|
||||
user_id,
|
||||
user_id=request.user_id if hasattr(request, 'user_id') else (await self.get_task_status(task_id))['user_id'],
|
||||
db=self.db
|
||||
)
|
||||
|
||||
|
||||
@@ -70,22 +70,22 @@ STRATEGIC REQUIREMENTS:
|
||||
- Ensure engaging, actionable content throughout
|
||||
|
||||
Return JSON format:
|
||||
{{
|
||||
{
|
||||
"title_options": [
|
||||
"Title option 1",
|
||||
"Title option 2",
|
||||
"Title option 3"
|
||||
],
|
||||
"outline": [
|
||||
{{
|
||||
{
|
||||
"heading": "Section heading with primary keyword",
|
||||
"subheadings": ["Subheading 1", "Subheading 2", "Subheading 3"],
|
||||
"key_points": ["Key point 1", "Key point 2", "Key point 3"],
|
||||
"target_words": 300,
|
||||
"keywords": ["primary keyword", "secondary keyword"]
|
||||
}}
|
||||
}
|
||||
]
|
||||
}}"""
|
||||
}"""
|
||||
|
||||
def get_outline_schema(self) -> Dict[str, Any]:
|
||||
"""Get the structured JSON schema for outline generation."""
|
||||
|
||||
@@ -5,8 +5,8 @@ Enhances individual outline sections for better engagement and value.
|
||||
"""
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from models.blog_models import BlogOutlineSection
|
||||
import json
|
||||
|
||||
|
||||
class SectionEnhancer:
|
||||
@@ -73,45 +73,14 @@ class SectionEnhancer:
|
||||
"required": ["heading", "subheadings", "key_points", "target_words", "keywords"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
enhanced_data = llm_text_gen(
|
||||
prompt=enhancement_prompt,
|
||||
json_struct=enhancement_schema,
|
||||
system_prompt=None,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
enhanced_data = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
enhanced_data = json.loads(json_match.group(0))
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Section enhancement returned invalid JSON: {e}")
|
||||
return section
|
||||
else:
|
||||
logger.warning(f"Section enhancement returned non-JSON string: {cleaned[:200]}")
|
||||
return section
|
||||
elif isinstance(raw, dict):
|
||||
enhanced_data = raw
|
||||
else:
|
||||
logger.warning(f"Unexpected LLM response type: {type(raw)}")
|
||||
return section
|
||||
|
||||
if 'error' in enhanced_data:
|
||||
logger.warning(f"AI section enhancement failed: {enhanced_data.get('error', 'Unknown error')}")
|
||||
else:
|
||||
if isinstance(enhanced_data, dict) and 'error' not in enhanced_data:
|
||||
return BlogOutlineSection(
|
||||
id=section.id,
|
||||
heading=enhanced_data.get('heading', section.heading),
|
||||
|
||||
@@ -6,7 +6,6 @@ Extracts competitor insights and market intelligence from research content.
|
||||
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class CompetitorAnalyzer:
|
||||
@@ -23,7 +22,7 @@ class CompetitorAnalyzer:
|
||||
Extract and analyze:
|
||||
1. Top competitors mentioned (companies, brands, platforms)
|
||||
2. Content gaps (what competitors are missing)
|
||||
3. Opportunities (untapped areas)
|
||||
3. Market opportunities (untapped areas)
|
||||
4. Competitive advantages (what makes content unique)
|
||||
5. Market positioning insights
|
||||
6. Industry leaders and their strategies
|
||||
@@ -56,38 +55,18 @@ class CompetitorAnalyzer:
|
||||
"required": ["top_competitors", "content_gaps", "opportunities", "competitive_advantages", "market_positioning", "industry_leaders", "analysis_notes"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
competitor_analysis = llm_text_gen(
|
||||
prompt=competitor_prompt,
|
||||
json_struct=competitor_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
competitor_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
competitor_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Competitor analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
competitor_analysis = raw
|
||||
if isinstance(competitor_analysis, dict) and 'error' not in competitor_analysis:
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
else:
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in competitor_analysis:
|
||||
raise ValueError(f"Competitor analysis failed: {competitor_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI competitor analysis completed successfully")
|
||||
return competitor_analysis
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = competitor_analysis.get('error', 'Unknown error') if isinstance(competitor_analysis, dict) else str(competitor_analysis)
|
||||
logger.error(f"AI competitor analysis failed: {error_msg}")
|
||||
raise ValueError(f"Competitor analysis failed: {error_msg}")
|
||||
|
||||
|
||||
@@ -63,41 +63,18 @@ class ContentAngleGenerator:
|
||||
"required": ["content_angles"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
angles_result = llm_text_gen(
|
||||
prompt=angles_prompt,
|
||||
json_struct=angles_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import json, re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
angles_result = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
angles_result = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Content angles returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
angles_result = raw
|
||||
if isinstance(angles_result, dict) and 'content_angles' in angles_result:
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
else:
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in angles_result:
|
||||
raise ValueError(f"Content angles generation failed: {angles_result.get('error', 'Unknown error')}")
|
||||
|
||||
if 'content_angles' not in angles_result:
|
||||
raise ValueError(f"Content angles missing from response")
|
||||
|
||||
logger.info("✅ AI content angles generation completed successfully")
|
||||
return angles_result['content_angles'][:7]
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = angles_result.get('error', 'Unknown error') if isinstance(angles_result, dict) else str(angles_result)
|
||||
logger.error(f"AI content angles generation failed: {error_msg}")
|
||||
raise ValueError(f"Content angles generation failed: {error_msg}")
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ Extracts and analyzes keywords from research content using structured AI respons
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
|
||||
class KeywordAnalyzer:
|
||||
@@ -63,38 +62,18 @@ class KeywordAnalyzer:
|
||||
"required": ["primary", "secondary", "long_tail", "search_intent", "difficulty", "content_gaps", "semantic_keywords", "trending_terms", "analysis_insights"]
|
||||
}
|
||||
|
||||
raw = llm_text_gen(
|
||||
keyword_analysis = llm_text_gen(
|
||||
prompt=keyword_prompt,
|
||||
json_struct=keyword_schema,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
# Parse JSON from LLM response (works with both string and dict return types)
|
||||
import re
|
||||
if isinstance(raw, str):
|
||||
cleaned = raw.strip()
|
||||
if cleaned.startswith('```json'):
|
||||
cleaned = cleaned[7:]
|
||||
if cleaned.startswith('```'):
|
||||
cleaned = cleaned[3:]
|
||||
if cleaned.endswith('```'):
|
||||
cleaned = cleaned[:-3]
|
||||
cleaned = cleaned.strip()
|
||||
try:
|
||||
keyword_analysis = json.loads(cleaned)
|
||||
except json.JSONDecodeError:
|
||||
json_match = re.search(r'\{.*\}', cleaned, re.DOTALL)
|
||||
if json_match:
|
||||
keyword_analysis = json.loads(json_match.group(0))
|
||||
else:
|
||||
raise ValueError(f"Keyword analysis returned non-JSON string: {cleaned[:200]}")
|
||||
elif isinstance(raw, dict):
|
||||
keyword_analysis = raw
|
||||
if isinstance(keyword_analysis, dict) and 'error' not in keyword_analysis:
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
else:
|
||||
raise ValueError(f"Unexpected LLM response type: {type(raw)}")
|
||||
|
||||
if 'error' in keyword_analysis:
|
||||
raise ValueError(f"Keyword analysis failed: {keyword_analysis.get('error', 'Unknown error')}")
|
||||
|
||||
logger.info("✅ AI keyword analysis completed successfully")
|
||||
return keyword_analysis
|
||||
# Fail gracefully - no fallback data
|
||||
error_msg = keyword_analysis.get('error', 'Unknown error') if isinstance(keyword_analysis, dict) else str(keyword_analysis)
|
||||
logger.error(f"AI keyword analysis failed: {error_msg}")
|
||||
raise ValueError(f"Keyword analysis failed: {error_msg}")
|
||||
|
||||
|
||||
@@ -111,22 +111,19 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
validate_exa_research_operations(pricing_service, user_id, gpt_provider)
|
||||
finally:
|
||||
if db_val:
|
||||
db_val.close()
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
api_start_time = time.time()
|
||||
@@ -165,15 +162,13 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
import time
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation (similar to Exa)
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -434,16 +429,14 @@ class ResearchService:
|
||||
# Exa research workflow
|
||||
from .exa_provider import ExaResearchProvider
|
||||
from services.subscription.preflight_validator import validate_exa_research_operations
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Exa neural search...")
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
gpt_provider = os.getenv("GPT_PROVIDER", "google")
|
||||
@@ -453,8 +446,7 @@ class ResearchService:
|
||||
await task_manager.update_progress(task_id, f"❌ Subscription limit exceeded: {http_error.detail.get('message', str(http_error.detail)) if isinstance(http_error.detail, dict) else str(http_error.detail)}")
|
||||
raise
|
||||
finally:
|
||||
if db_val:
|
||||
db_val.close()
|
||||
db_val.close()
|
||||
|
||||
# Execute Exa search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Exa neural search...")
|
||||
@@ -493,16 +485,14 @@ class ResearchService:
|
||||
elif config.provider == ResearchProvider.TAVILY:
|
||||
# Tavily research workflow
|
||||
from .tavily_provider import TavilyResearchProvider
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
import os
|
||||
|
||||
await task_manager.update_progress(task_id, "🌐 Connecting to Tavily AI search...")
|
||||
|
||||
# Pre-flight validation (use get_session_for_user since get_db is a FastAPI dependency)
|
||||
db_val = get_session_for_user(user_id)
|
||||
if not db_val:
|
||||
raise HTTPException(status_code=503, detail="Database temporarily unavailable. Please try again.")
|
||||
# Pre-flight validation
|
||||
db_val = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db_val)
|
||||
# Check Tavily usage limits
|
||||
@@ -539,8 +529,7 @@ class ResearchService:
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking Tavily limits: {e}")
|
||||
finally:
|
||||
if db_val:
|
||||
db_val.close()
|
||||
db_val.close()
|
||||
|
||||
# Execute Tavily search
|
||||
await task_manager.update_progress(task_id, "🤖 Executing Tavily AI search...")
|
||||
|
||||
@@ -135,14 +135,11 @@ class TavilyResearchProvider(BaseProvider):
|
||||
|
||||
def track_tavily_usage(self, user_id: str, cost: float, search_depth: str):
|
||||
"""Track Tavily API usage after successful call."""
|
||||
from services.database import get_session_for_user
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
logger.warning(f"[Tavily] Could not get DB session for user {user_id}, skipping usage tracking")
|
||||
return
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
@@ -92,7 +92,6 @@ class BlogSEORecommendationApplier:
|
||||
None,
|
||||
schema,
|
||||
user_id, # Pass user_id for subscription checking
|
||||
max_tokens=8192,
|
||||
)
|
||||
|
||||
if not result or result.get("error"):
|
||||
|
||||
@@ -237,21 +237,6 @@ class ControlStudioService:
|
||||
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_bytes_to_metadata(image_bytes)
|
||||
|
||||
# Track usage
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"control-{operation}",
|
||||
operation_type="image-control",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/control/process",
|
||||
log_prefix="[Control Studio]"
|
||||
)
|
||||
|
||||
metadata.update(
|
||||
{
|
||||
"operation": operation,
|
||||
|
||||
@@ -514,19 +514,6 @@ class EditStudioService:
|
||||
background_bytes=background_bytes,
|
||||
lighting_bytes=lighting_bytes,
|
||||
)
|
||||
# Track usage for Stability operations
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"edit-{operation}",
|
||||
operation_type="image-edit",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/edit/process",
|
||||
log_prefix="[Edit Studio]"
|
||||
)
|
||||
else:
|
||||
image_bytes = await self._handle_general_edit(
|
||||
request=request,
|
||||
|
||||
@@ -88,20 +88,6 @@ class UpscaleStudioService:
|
||||
image_bytes = self._extract_image_bytes(result)
|
||||
metadata = self._image_metadata(image_bytes)
|
||||
|
||||
# Track usage
|
||||
if user_id:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="stability",
|
||||
model=f"upscale-{mode}",
|
||||
operation_type="image-upscale",
|
||||
result_bytes=image_bytes,
|
||||
cost=0.04,
|
||||
endpoint="/image-studio/upscale",
|
||||
log_prefix="[Upscale Studio]"
|
||||
)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"mode": mode,
|
||||
|
||||
@@ -233,7 +233,7 @@ def create_blog_post(
|
||||
|
||||
# BACK TO BASICS MODE: Try simplest possible structure FIRST
|
||||
# Since posting worked before Ricos/SEO, let's test with absolute minimum
|
||||
BACK_TO_BASICS_MODE = False # Disabled: full Ricos conversion now produces valid output
|
||||
BACK_TO_BASICS_MODE = True # Set to True to test with simplest structure
|
||||
|
||||
wix_logger.reset()
|
||||
wix_logger.log_operation_start("Blog Post Creation", title=title[:50] if title else None, member_id=member_id[:20] if member_id else None)
|
||||
@@ -257,7 +257,8 @@ def create_blog_post(
|
||||
'text': (content[:500] if content else "This is a post from ALwrity.").strip(),
|
||||
'decorations': []
|
||||
}
|
||||
}]
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}]
|
||||
}
|
||||
|
||||
|
||||
@@ -256,16 +256,17 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
quote_content = ' '.join(quote_lines)
|
||||
text_nodes = parse_markdown_inline(quote_content)
|
||||
# CRITICAL: TEXT nodes must be wrapped in PARAGRAPH nodes within BLOCKQUOTE
|
||||
# Wix API: omit empty data objects, don't include them as {}
|
||||
paragraph_node = {
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
blockquote_node = {
|
||||
'id': node_id,
|
||||
'type': 'BLOCKQUOTE',
|
||||
'nodes': [paragraph_node],
|
||||
'blockquoteData': {}
|
||||
}
|
||||
nodes.append(blockquote_node)
|
||||
|
||||
@@ -331,6 +332,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -343,6 +345,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'BULLETED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'bulletedListData': {}
|
||||
}
|
||||
nodes.append(bulleted_list_node)
|
||||
|
||||
@@ -370,6 +373,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': str(uuid.uuid4()),
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
list_item_node = {
|
||||
'id': item_node_id,
|
||||
@@ -382,6 +386,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'ORDERED_LIST',
|
||||
'nodes': list_node_items,
|
||||
'orderedListData': {}
|
||||
}
|
||||
nodes.append(ordered_list_node)
|
||||
|
||||
@@ -437,6 +442,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'id': node_id,
|
||||
'type': 'PARAGRAPH',
|
||||
'nodes': text_nodes,
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(paragraph_node)
|
||||
|
||||
@@ -455,6 +461,7 @@ def convert_content_to_ricos(content: str, images: List[str] = None) -> Dict[str
|
||||
'decorations': []
|
||||
}
|
||||
}],
|
||||
'paragraphData': {}
|
||||
}
|
||||
nodes.append(fallback_paragraph)
|
||||
|
||||
|
||||
@@ -20,14 +20,13 @@ class SemanticHarvesterService:
|
||||
"last_harvest_time": None
|
||||
}
|
||||
|
||||
async def harvest_website(self, website_url: str, limit: int = 100, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
async def harvest_website(self, website_url: str, limit: int = 100) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Deep crawl a website using Exa AI.
|
||||
|
||||
Args:
|
||||
website_url: The root URL to crawl.
|
||||
limit: Maximum number of pages to retrieve.
|
||||
user_id: Optional user ID for usage tracking and preflight checks.
|
||||
|
||||
Returns:
|
||||
List of pages with content and metadata.
|
||||
@@ -60,30 +59,6 @@ class SemanticHarvesterService:
|
||||
logger.warning("[SemanticHarvester] Exa service disabled. Returning placeholder data.")
|
||||
return self._get_placeholder_data(website_url)
|
||||
|
||||
# Preflight subscription check if user_id provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
logger.warning(f"[SemanticHarvester] Exa blocked for user {user_id}: {message}")
|
||||
return []
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[SemanticHarvester] Preflight check failed: {e}")
|
||||
|
||||
# Use Exa to search for all pages in this domain
|
||||
search_response = self.exa_service.exa.search_and_contents(
|
||||
query=f"site:{website_url}",
|
||||
@@ -107,38 +82,6 @@ class SemanticHarvesterService:
|
||||
})
|
||||
|
||||
logger.info(f"[SemanticHarvester] Successfully harvested {len(results)} pages from {website_url}")
|
||||
|
||||
# Track Exa usage if user_id provided
|
||||
if user_id and results:
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Exa search cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET exa_calls = COALESCE(exa_calls, 0) + 1,
|
||||
exa_cost = COALESCE(exa_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[SemanticHarvester] Tracked Exa usage: user={user_id}, cost=${cost}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[SemanticHarvester] Failed to track Exa usage: {track_err}")
|
||||
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,120 +0,0 @@
|
||||
"""Image editing operations — generate_image_edit and related helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import ImageEditOptions, ImageGenerationResult, ImageEditProvider
|
||||
from .wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.edit")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||
"""Get editing provider instance by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider()
|
||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Generate edited image with pre-flight validation and usage tracking.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Image editing failed", "message": str(e)}
|
||||
)
|
||||
|
||||
# 6. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Image Edit] ✅ API call successful, tracking usage for user {user_id}")
|
||||
estimated_cost = 0.0
|
||||
if result.metadata and "estimated_cost" in result.metadata:
|
||||
estimated_cost = float(result.metadata["estimated_cost"])
|
||||
else:
|
||||
estimated_cost = 0.02 if provider_name == "wavespeed" else 0.05
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=result.model or model or "unknown",
|
||||
operation_type="image-edit",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-generation/edit",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Image Edit] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||
|
||||
# 7. Return result
|
||||
return result
|
||||
@@ -1,105 +0,0 @@
|
||||
"""Face swap operations — generate_face_swap and related helpers."""
|
||||
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
|
||||
from .base import FaceSwapOptions, FaceSwapProvider, ImageGenerationResult
|
||||
from .wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from .helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.face_swap")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""Generate face swap with pre-flight validation and usage tracking.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
num_operations=1,
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025)
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None,
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as api_error:
|
||||
logger.error(f"[Face Swap] Face swap API failed: {api_error}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={"error": "Face swap failed", "message": str(api_error)}
|
||||
)
|
||||
@@ -1,200 +0,0 @@
|
||||
"""Shared helpers for image generation operations — validation and usage tracking."""
|
||||
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
logger = get_service_logger("image_generation.helpers")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""Reusable pre-flight validation helper for all image operations."""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id}")
|
||||
except HTTPException:
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""Reusable usage tracking helper for all image operations."""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(user_id=user_id, billing_period=current_period)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Map provider to DB column names
|
||||
provider_column_map = {
|
||||
"stability": ("stability_calls", "stability_cost"),
|
||||
"wavespeed": ("wavespeed_calls", "wavespeed_cost"),
|
||||
"gemini": ("gemini_calls", "gemini_cost"),
|
||||
"openai": ("openai_calls", "openai_cost"),
|
||||
"huggingface": ("total_calls", "total_cost"), # no dedicated columns
|
||||
}
|
||||
calls_col, cost_col = provider_column_map.get(provider, ("total_calls", "total_cost"))
|
||||
|
||||
current_calls_before = getattr(summary, calls_col, 0) or 0
|
||||
current_cost_before = getattr(summary, cost_col, 0.0) or 0.0
|
||||
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {calls_col} = :new_calls,
|
||||
{cost_col} = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Map provider to APIProvider enum
|
||||
provider_api_map = {
|
||||
"stability": APIProvider.STABILITY,
|
||||
"wavespeed": APIProvider.WAVESPEED,
|
||||
"gemini": APIProvider.GEMINI,
|
||||
"openai": APIProvider.OPENAI,
|
||||
"image_edit": APIProvider.IMAGE_EDIT,
|
||||
"video": APIProvider.VIDEO,
|
||||
"audio": APIProvider.AUDIO,
|
||||
}
|
||||
api_provider = provider_api_map.get(provider, APIProvider.STABILITY)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time,
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
provider_limit = limits['limits'].get(calls_col, 0) if limits else 0
|
||||
provider_limit_display = provider_limit if (provider_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {provider_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {"current_calls": new_calls, "cost": cost, "total_cost": new_cost}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
@@ -133,9 +133,9 @@ def edit_image(
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
# In feature-limited mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all"):
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in feature-limited mode - allowing operation to continue")
|
||||
# In podcast-only mode, allow the operation to continue on validation errors
|
||||
if os.getenv("ALWRITY_ENABLED_FEATURES") == "podcast":
|
||||
logger.warning(f"[Image Editing] ⚠️ Validation error in podcast mode - allowing operation to continue")
|
||||
else:
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
|
||||
@@ -18,9 +18,9 @@ from .image_generation import (
|
||||
StabilityImageProvider,
|
||||
WaveSpeedImageProvider,
|
||||
)
|
||||
from .image_generation.helpers import _validate_image_operation, _track_image_operation_usage
|
||||
from .image_generation.edit import generate_image_edit
|
||||
from .image_generation.face_swap import generate_face_swap
|
||||
from .image_generation.base import FaceSwapOptions, FaceSwapProvider
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
from .image_generation.wavespeed_face_swap_provider import WaveSpeedFaceSwapProvider
|
||||
from utils.logger_utils import get_service_logger
|
||||
from .tenant_provider_config import tenant_provider_config_resolver
|
||||
|
||||
@@ -53,6 +53,259 @@ def _get_provider(provider_name: str, user_id: Optional[str] = None):
|
||||
raise ValueError(f"Unknown image provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_face_swap_provider(provider_name: str) -> FaceSwapProvider:
|
||||
"""Get face swap provider by name."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedFaceSwapProvider()
|
||||
raise ValueError(f"Unknown face swap provider: {provider_name}")
|
||||
|
||||
|
||||
def _get_edit_provider(provider_name: str) -> ImageEditProvider:
|
||||
"""Get editing provider instance.
|
||||
|
||||
Args:
|
||||
provider_name: Provider name ("wavespeed", "stability", etc.)
|
||||
|
||||
Returns:
|
||||
ImageEditProvider instance
|
||||
|
||||
Raises:
|
||||
ValueError: If provider is not supported
|
||||
"""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider()
|
||||
# TODO: Add Stability edit provider if needed
|
||||
# elif provider_name == "stability":
|
||||
# return StabilityEditProvider()
|
||||
else:
|
||||
raise ValueError(f"Unknown edit provider: {provider_name}")
|
||||
|
||||
|
||||
def _validate_image_operation(
|
||||
user_id: Optional[str],
|
||||
operation_type: str = "image-generation",
|
||||
num_operations: int = 1,
|
||||
log_prefix: str = "[Image Generation]"
|
||||
) -> None:
|
||||
"""
|
||||
Reusable pre-flight validation helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for subscription checking
|
||||
operation_type: Type of operation (for logging)
|
||||
num_operations: Number of operations to validate (default: 1)
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails (subscription limits exceeded, etc.)
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning(f"{log_prefix} ⚠️ No user_id provided - skipping pre-flight validation (this should not happen in production)")
|
||||
return
|
||||
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"{log_prefix} 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=num_operations
|
||||
)
|
||||
logger.info(f"{log_prefix} ✅ Pre-flight validation passed for user_id={user_id} - proceeding with operation")
|
||||
except HTTPException as http_ex:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"{log_prefix} ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def _track_image_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/image-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Image Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all image operations.
|
||||
|
||||
Extracted from generate_image() to be reused across all image operation functions.
|
||||
|
||||
Args:
|
||||
user_id: User ID for tracking
|
||||
provider: Provider name (e.g., "wavespeed", "stability")
|
||||
model: Model name used
|
||||
operation_type: Type of operation (for logging)
|
||||
result_bytes: Generated/processed image bytes
|
||||
cost: Cost of the operation
|
||||
prompt: Optional prompt text (for request size calculation)
|
||||
endpoint: API endpoint path (for logging)
|
||||
metadata: Optional additional metadata
|
||||
log_prefix: Logging prefix for operation-specific logs
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking information (current_calls, cost, etc.)
|
||||
"""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "stability_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
|
||||
# Update image calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET stability_calls = :new_calls,
|
||||
stability_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Determine API provider based on actual provider
|
||||
api_provider = APIProvider.STABILITY # Default for image generation
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Stability, HuggingFace, etc.)
|
||||
actual_provider = detect_actual_provider(
|
||||
provider_enum=api_provider,
|
||||
model_name=model,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=actual_provider, # Track actual provider (WaveSpeed, Stability, etc.)
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time, # Use actual response time
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes),
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
image_limit_display = image_limit if (image_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {image_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit if audio_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {
|
||||
"current_calls": new_calls,
|
||||
"cost": cost,
|
||||
"total_cost": new_cost,
|
||||
}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
|
||||
|
||||
def generate_image(prompt: str, options: Optional[Dict[str, Any]] = None, user_id: Optional[str] = None) -> ImageGenerationResult:
|
||||
"""Generate image with pre-flight validation.
|
||||
|
||||
@@ -247,7 +500,165 @@ def generate_character_image(
|
||||
)
|
||||
|
||||
|
||||
def generate_image_edit(
|
||||
image_base64: str,
|
||||
prompt: str,
|
||||
operation: str = "general_edit",
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate edited image - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
image_base64: Base64-encoded input image (or data URI)
|
||||
prompt: Edit instruction prompt
|
||||
operation: Type of edit operation (e.g., "general_edit", "inpaint", "outpaint")
|
||||
model: Model ID to use (default: auto-select based on provider)
|
||||
options: Additional options (mask_base64, negative_prompt, width, height, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with edited image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or editing fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="image-edit",
|
||||
num_operations=1,
|
||||
log_prefix="[Image Edit]"
|
||||
)
|
||||
|
||||
# 2. Determine provider from model or default to wavespeed
|
||||
opts = options or {}
|
||||
provider_name = opts.get("provider", "wavespeed")
|
||||
|
||||
# If model is specified and starts with "wavespeed", use wavespeed provider
|
||||
if model and (model.startswith("wavespeed") or model.startswith("qwen") or model.startswith("flux") or model.startswith("nano-banana")):
|
||||
provider_name = "wavespeed"
|
||||
|
||||
# 3. Get provider (REUSES provider pattern)
|
||||
try:
|
||||
provider = _get_edit_provider(provider_name)
|
||||
except ValueError as e:
|
||||
logger.error(f"[Image Edit] ❌ Provider error: {str(e)}")
|
||||
raise ValueError(f"Unsupported edit provider: {provider_name}")
|
||||
|
||||
# 4. Prepare edit options
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation=operation,
|
||||
mask_base64=opts.get("mask_base64"),
|
||||
negative_prompt=opts.get("negative_prompt"),
|
||||
model=model,
|
||||
width=opts.get("width"),
|
||||
height=opts.get("height"),
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts.get("extra"),
|
||||
)
|
||||
|
||||
# 5. Edit image
|
||||
logger.info(f"[Image Edit] Starting edit: operation={operation}, model={model}, provider={provider_name}")
|
||||
try:
|
||||
result = provider.edit(edit_options)
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Edit] ❌ Edit failed: {str(e)}", exc_info=True)
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail={
|
||||
"error": "Image editing failed",
|
||||
"message": str(e)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def generate_face_swap(
|
||||
base_image_base64: str,
|
||||
face_image_base64: str,
|
||||
model: Optional[str] = None,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
user_id: Optional[str] = None
|
||||
) -> ImageGenerationResult:
|
||||
"""
|
||||
Generate face swap - REUSES validation and tracking helpers.
|
||||
|
||||
Args:
|
||||
base_image_base64: Base64-encoded base image (or data URI)
|
||||
face_image_base64: Base64-encoded face image to swap (or data URI)
|
||||
model: Model ID to use (default: auto-select)
|
||||
options: Additional options (target_face_index, target_gender, etc.)
|
||||
user_id: User ID for validation and tracking
|
||||
|
||||
Returns:
|
||||
ImageGenerationResult with swapped face image
|
||||
|
||||
Raises:
|
||||
HTTPException: If validation fails or face swap fails
|
||||
ValueError: If options are invalid
|
||||
"""
|
||||
# 1. REUSE: Validation helper
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="face-swap",
|
||||
image_base64=base_image_base64, # Use base image for validation
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
|
||||
# 2. Get provider (default to wavespeed)
|
||||
provider_name = "wavespeed"
|
||||
provider = _get_face_swap_provider(provider_name)
|
||||
|
||||
# 3. Prepare options
|
||||
face_swap_options = FaceSwapOptions(
|
||||
base_image_base64=base_image_base64,
|
||||
face_image_base64=face_image_base64,
|
||||
model=model,
|
||||
target_face_index=options.get("target_face_index") if options else None,
|
||||
target_gender=options.get("target_gender") if options else None,
|
||||
extra=options,
|
||||
)
|
||||
|
||||
# 4. Swap face
|
||||
try:
|
||||
result = provider.swap_face(face_swap_options)
|
||||
|
||||
# 5. REUSE: Tracking helper
|
||||
if user_id and result and result.image_bytes:
|
||||
logger.info(f"[Face Swap] ✅ API call successful, tracking usage for user {user_id}")
|
||||
|
||||
# Get model cost
|
||||
model_id = model or (list(WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.keys())[0] if WaveSpeedFaceSwapProvider.SUPPORTED_MODELS else "unknown")
|
||||
model_info = WaveSpeedFaceSwapProvider.SUPPORTED_MODELS.get(model_id, {})
|
||||
estimated_cost = model_info.get("cost", 0.025) # Default to Pro cost
|
||||
|
||||
# Reuse tracking helper
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=provider_name,
|
||||
model=model_id,
|
||||
operation_type="face-swap",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=None, # Face swap doesn't use prompts
|
||||
endpoint="/image-studio/face-swap/process",
|
||||
metadata={
|
||||
"base_image_size": len(base_image_base64),
|
||||
"face_image_size": len(face_image_base64),
|
||||
},
|
||||
log_prefix="[Face Swap]"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"[Face Swap] ⚠️ Skipping usage tracking: user_id={user_id}, image_bytes={len(result.image_bytes) if result and result.image_bytes else 0} bytes")
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -45,7 +45,6 @@ def llm_text_gen(
|
||||
preferred_hf_models: Optional[List[str]] = None,
|
||||
preferred_provider: Optional[str] = None,
|
||||
flow_type: Optional[str] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Generate text using Language Model (LLM) based on the provided prompt.
|
||||
@@ -76,8 +75,7 @@ def llm_text_gen(
|
||||
gpt_provider = "google" # Default to Google Gemini
|
||||
model = "gemini-2.0-flash-001"
|
||||
temperature = 0.7
|
||||
if max_tokens is None:
|
||||
max_tokens = 4000
|
||||
max_tokens = 4000
|
||||
top_p = 0.9
|
||||
n = 1
|
||||
fp = 16
|
||||
@@ -373,27 +371,16 @@ def llm_text_gen(
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
elif gpt_provider == "wavespeed":
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
llm_start = time.time()
|
||||
if json_struct:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_structured_json_response
|
||||
response_text = wavespeed_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=json_struct,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
else:
|
||||
from services.llm_providers.wavespeed_provider import wavespeed_text_response
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
response_text = wavespeed_text_response(
|
||||
prompt=prompt,
|
||||
model=model or "openai/gpt-oss-120b",
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
top_p=top_p,
|
||||
system_prompt=system_instructions
|
||||
)
|
||||
llm_ms = (time.time() - llm_start) * 1000
|
||||
logger.warning(f"[llm_text_gen][{flow_tag}] LLM API call took {llm_ms:.0f}ms for user {user_id} (wavespeed)")
|
||||
else:
|
||||
|
||||
@@ -179,43 +179,6 @@ def get_wavespeed_api_key() -> str:
|
||||
|
||||
return api_key
|
||||
|
||||
|
||||
def _retry_with_increased_tokens(
|
||||
client: "OpenAI",
|
||||
messages: List[Dict[str, str]],
|
||||
model: str,
|
||||
fallback_models: Optional[List[str]],
|
||||
temperature: float,
|
||||
max_tokens: int,
|
||||
) -> Optional[str]:
|
||||
"""Retry the API call with increased max_tokens when JSON parsing fails due to truncation."""
|
||||
max_tokens = min(max_tokens, 16384)
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
text = response.choices[0].message.content
|
||||
text = text.strip() if text else ""
|
||||
if text.startswith("```json"):
|
||||
text = text[7:]
|
||||
if text.startswith("```"):
|
||||
text = text[3:]
|
||||
if text.endswith("```"):
|
||||
text = text[:-3]
|
||||
return text.strip()
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
continue
|
||||
if last_error:
|
||||
logger.error(f"All fallback models failed on retry with increased tokens: {last_error}")
|
||||
return None
|
||||
|
||||
|
||||
@retry(
|
||||
retry=retry_if_exception(_should_retry_wavespeed_error),
|
||||
wait=wait_random_exponential(min=1, max=60),
|
||||
@@ -483,69 +446,24 @@ def wavespeed_structured_json_response(
|
||||
raise last_error or Exception("WaveSpeed structured generation failed: all fallback models failed")
|
||||
|
||||
response_text = response.choices[0].message.content
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
|
||||
# If response_format returned empty content, retry without it
|
||||
if not response_text:
|
||||
logger.warning("WaveSpeed structured call returned empty content with response_format, retrying without it...")
|
||||
response = None
|
||||
last_error = None
|
||||
for candidate_model in _fallback_model_sequence(model, fallback_models):
|
||||
try:
|
||||
response = client.chat.completions.create(
|
||||
model=candidate_model,
|
||||
messages=messages,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
break
|
||||
except NotFoundError as nf_err:
|
||||
last_error = nf_err
|
||||
continue
|
||||
if response is not None:
|
||||
response_text = response.choices[0].message.content
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
|
||||
# Clean up response text if needed
|
||||
response_text = response_text.strip()
|
||||
if response_text.startswith("```json"):
|
||||
response_text = response_text[7:]
|
||||
if response_text.startswith("```"):
|
||||
response_text = response_text[3:]
|
||||
if response_text.endswith("```"):
|
||||
response_text = response_text[:-3]
|
||||
response_text = response_text.strip()
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(response_text) if response_text else None
|
||||
if parsed_json is not None:
|
||||
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
|
||||
return parsed_json
|
||||
parsed_json = json.loads(response_text)
|
||||
logger.info("✅ WaveSpeed structured JSON response parsed successfully")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError as json_err:
|
||||
logger.error(f"❌ JSON parsing failed: {json_err}")
|
||||
# Retry once with increased max_tokens — likely a truncation issue
|
||||
if max_tokens < 16384:
|
||||
logger.warning(f"Retrying with increased max_tokens ({max_tokens} → {max_tokens * 2}) due to JSON parse failure")
|
||||
response_text = _retry_with_increased_tokens(
|
||||
client=client,
|
||||
messages=messages,
|
||||
model=model,
|
||||
fallback_models=fallback_models,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens * 2,
|
||||
)
|
||||
if response_text:
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
if parsed_json is not None:
|
||||
logger.info("✅ WaveSpeed structured JSON parsed successfully after max_tokens increase")
|
||||
return parsed_json
|
||||
except json.JSONDecodeError:
|
||||
logger.error("❌ JSON parsing failed even after max_tokens increase")
|
||||
|
||||
logger.error(f"Raw response: {response_text}")
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
if response_text:
|
||||
|
||||
# Try to extract JSON from the response using regex
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
@@ -554,8 +472,8 @@ def wavespeed_structured_json_response(
|
||||
return extracted_json
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ WaveSpeed API call failed: {e}")
|
||||
@@ -583,24 +501,14 @@ def wavespeed_structured_json_response(
|
||||
if response is None:
|
||||
raise last_error or e
|
||||
response_text = response.choices[0].message.content
|
||||
response_text = response_text.strip() if response_text else ""
|
||||
# Parse JSON with robust cleaning
|
||||
if response_text.startswith("```json"):
|
||||
response_text = response_text[7:]
|
||||
if response_text.startswith("```"):
|
||||
response_text = response_text[3:]
|
||||
if response_text.endswith("```"):
|
||||
response_text = response_text[:-3]
|
||||
response_text = response_text.strip()
|
||||
# ... (same parsing logic would apply, simplified here for brevity)
|
||||
try:
|
||||
return json.loads(response_text) if response_text else {"error": "Empty response"}
|
||||
except json.JSONDecodeError:
|
||||
return json.loads(response_text)
|
||||
except:
|
||||
# Regex fallback
|
||||
json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
|
||||
if json_match:
|
||||
try:
|
||||
return json.loads(json_match.group())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return json.loads(json_match.group())
|
||||
return {"error": "Failed to parse JSON response", "raw_response": response_text}
|
||||
raise e
|
||||
|
||||
|
||||
@@ -19,11 +19,11 @@ from services.database import get_db_session
|
||||
from models.onboarding import OnboardingSession, WebsiteAnalysis, ResearchPreferences
|
||||
from models.persona_models import WritingPersona, PlatformPersona, PersonaAnalysisResult
|
||||
|
||||
def _is_feature_limited_mode():
|
||||
"""Check if running in feature-limited mode to skip heavy initialization."""
|
||||
def _get_podcast_mode():
|
||||
"""Check if running in podcast-only mode to skip heavy initialization."""
|
||||
import os
|
||||
env_val = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower()
|
||||
return env_val not in ("", "all")
|
||||
return env_val == "podcast"
|
||||
|
||||
class PersonaAnalysisService:
|
||||
"""Service for analyzing onboarding data and generating writing personas using Gemini AI."""
|
||||
@@ -40,9 +40,9 @@ class PersonaAnalysisService:
|
||||
def __init__(self):
|
||||
"""Initialize the persona analysis service (only once)."""
|
||||
if not self._initialized:
|
||||
# Skip heavy initialization in feature-limited mode
|
||||
if _is_feature_limited_mode():
|
||||
logger.debug(f"PersonaAnalysisService: Skipping heavy init in feature-limited mode")
|
||||
# Skip heavy initialization in podcast-only mode
|
||||
if _get_podcast_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
|
||||
self._initialized = True
|
||||
return
|
||||
|
||||
@@ -55,8 +55,8 @@ class PersonaAnalysisService:
|
||||
return
|
||||
|
||||
# Check again in case mode changed
|
||||
if _is_feature_limited_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in feature-limited mode")
|
||||
if _get_podcast_mode():
|
||||
logger.debug("PersonaAnalysisService: Skipping heavy init in podcast mode")
|
||||
self._heavy_init_done = True
|
||||
return
|
||||
|
||||
@@ -89,9 +89,9 @@ class PersonaAnalysisService:
|
||||
# Ensure heavy services are initialized
|
||||
self._ensure_heavy_init()
|
||||
|
||||
# Check if heavy init failed (feature-limited mode)
|
||||
# Check if heavy init failed (podcast mode)
|
||||
if not getattr(self, '_heavy_init_done', False):
|
||||
return {"error": "Persona service unavailable in feature-limited mode"}
|
||||
return {"error": "Persona service unavailable in podcast-only mode"}
|
||||
|
||||
try:
|
||||
logger.info(f"Generating persona for user {user_id}")
|
||||
|
||||
@@ -13,7 +13,7 @@ from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from services.database import SessionLocal
|
||||
from fastapi import HTTPException
|
||||
@@ -113,7 +113,12 @@ class ProductImageService:
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize Product Image Service."""
|
||||
logger.info("[Product Image Service] Initialized")
|
||||
try:
|
||||
self.wavespeed_client = WaveSpeedClient()
|
||||
logger.info("[Product Image Service] Initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"[Product Image Service] Failed to initialize WaveSpeed client: {str(e)}")
|
||||
raise ProductImageServiceError(f"Failed to initialize service: {str(e)}") from e
|
||||
|
||||
def validate_request(self, request: ProductImageRequest) -> None:
|
||||
"""
|
||||
@@ -255,7 +260,77 @@ class ProductImageService:
|
||||
|
||||
return full_prompt
|
||||
|
||||
def generate_product_image(
|
||||
def _generate_image_with_retry(
|
||||
self,
|
||||
model: str,
|
||||
prompt: str,
|
||||
width: int,
|
||||
height: int,
|
||||
max_retries: int = 3,
|
||||
retry_delay: float = 2.0
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate image with retry logic for transient failures.
|
||||
|
||||
Args:
|
||||
model: Model to use
|
||||
prompt: Generation prompt
|
||||
width: Image width
|
||||
height: Image height
|
||||
max_retries: Maximum number of retries
|
||||
retry_delay: Delay between retries in seconds
|
||||
|
||||
Returns:
|
||||
Generated image bytes
|
||||
|
||||
Raises:
|
||||
ImageGenerationError: If generation fails after retries
|
||||
"""
|
||||
last_error = None
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
logger.info(f"[Product Image Service] Image generation attempt {attempt + 1}/{max_retries}")
|
||||
|
||||
image_bytes = self.wavespeed_client.generate_image(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
width=width,
|
||||
height=height,
|
||||
enable_sync_mode=True,
|
||||
timeout=120,
|
||||
)
|
||||
|
||||
if not image_bytes:
|
||||
raise ValueError("Image generation returned empty result")
|
||||
|
||||
if len(image_bytes) < 100: # Sanity check: image should be at least 100 bytes
|
||||
raise ValueError(f"Generated image too small: {len(image_bytes)} bytes")
|
||||
|
||||
logger.info(f"[Product Image Service] ✅ Image generated successfully: {len(image_bytes)} bytes")
|
||||
return image_bytes
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
error_msg = str(e)
|
||||
logger.warning(f"[Product Image Service] Attempt {attempt + 1} failed: {error_msg}")
|
||||
|
||||
# Don't retry on validation errors or client errors (4xx)
|
||||
if "4" in error_msg or "validation" in error_msg.lower() or "invalid" in error_msg.lower():
|
||||
logger.error(f"[Product Image Service] Non-retryable error: {error_msg}")
|
||||
raise ImageGenerationError(f"Image generation failed: {error_msg}") from e
|
||||
|
||||
# Retry on transient errors
|
||||
if attempt < max_retries - 1:
|
||||
logger.info(f"[Product Image Service] Retrying in {retry_delay} seconds...")
|
||||
time.sleep(retry_delay)
|
||||
retry_delay *= 1.5 # Exponential backoff
|
||||
else:
|
||||
logger.error(f"[Product Image Service] All retry attempts failed")
|
||||
|
||||
raise ImageGenerationError(f"Image generation failed after {max_retries} attempts: {str(last_error)}") from last_error
|
||||
|
||||
async def generate_product_image(
|
||||
self,
|
||||
request: ProductImageRequest,
|
||||
user_id: str,
|
||||
@@ -299,18 +374,15 @@ class ProductImageService:
|
||||
|
||||
# Generate image using WaveSpeed with retry logic
|
||||
try:
|
||||
result = generate_image(
|
||||
image_bytes = self._generate_image_with_retry(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
options={
|
||||
"provider": "wavespeed",
|
||||
"model": model,
|
||||
"width": width,
|
||||
"height": height,
|
||||
},
|
||||
user_id=user_id,
|
||||
width=width,
|
||||
height=height,
|
||||
max_retries=3,
|
||||
retry_delay=2.0
|
||||
)
|
||||
image_bytes = result.image_bytes
|
||||
except Exception as e:
|
||||
except ImageGenerationError as e:
|
||||
logger.error(f"[Product Image Service] Image generation failed: {str(e)}")
|
||||
generation_time = time.time() - start_time
|
||||
return ProductImageResult(
|
||||
|
||||
@@ -296,33 +296,6 @@ class ResearchEngine:
|
||||
target_audience = request.target_audience or "General"
|
||||
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Preflight subscription check
|
||||
try:
|
||||
db = self._db_session
|
||||
if not db:
|
||||
from services.database import get_db_session
|
||||
db = get_db_session()
|
||||
if db:
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.EXA,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="exa",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'exa', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[ResearchEngine] Exa preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[ResearchEngine] Exa preflight check failed: {e}")
|
||||
|
||||
# Execute Exa search
|
||||
try:
|
||||
@@ -368,33 +341,6 @@ class ResearchEngine:
|
||||
target_audience = request.target_audience or "General"
|
||||
|
||||
research_prompt = strategy.build_research_prompt(topic, industry, target_audience, config)
|
||||
|
||||
# Preflight subscription check
|
||||
try:
|
||||
db = self._db_session
|
||||
if not db:
|
||||
from services.database import get_db_session
|
||||
db = get_db_session()
|
||||
if db:
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.TAVILY,
|
||||
tokens_requested=0,
|
||||
actual_provider_name="tavily",
|
||||
)
|
||||
if not can_proceed:
|
||||
raise HTTPException(status_code=429, detail={
|
||||
'error': message, 'message': message,
|
||||
'provider': 'tavily', 'usage_info': usage_info or {}
|
||||
})
|
||||
logger.info(f"[ResearchEngine] Tavily preflight check passed for user {user_id}")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"[ResearchEngine] Tavily preflight check failed: {e}")
|
||||
|
||||
# Execute Tavily search
|
||||
try:
|
||||
|
||||
@@ -83,30 +83,6 @@ class DeepCrawlService:
|
||||
tavily_results.append(res)
|
||||
|
||||
logger.info(f"Found {len(tavily_urls)} URLs from Tavily")
|
||||
|
||||
# Track Tavily usage
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
from sqlalchemy import text
|
||||
pricing_service = PricingService(db)
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
cost = 0.005 # Tavily crawl cost estimate
|
||||
|
||||
update_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET tavily_calls = COALESCE(tavily_calls, 0) + 1,
|
||||
tavily_cost = COALESCE(tavily_cost, 0) + :cost,
|
||||
total_calls = COALESCE(total_calls, 0) + 1,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db.execute(update_query, {
|
||||
'cost': cost, 'user_id': user_id, 'period': current_period,
|
||||
})
|
||||
db.commit()
|
||||
logger.info(f"[DeepCrawl] Tracked Tavily crawl usage: user={user_id}, cost=${cost}")
|
||||
except Exception as track_err:
|
||||
logger.warning(f"[DeepCrawl] Failed to track Tavily usage: {track_err}")
|
||||
except Exception as e:
|
||||
logger.warning(f"Tavily crawl failed: {e}")
|
||||
|
||||
|
||||
@@ -49,11 +49,9 @@ except Exception as _patch_err:
|
||||
# Now safe to import pytrends
|
||||
try:
|
||||
from pytrends.request import TrendReq as _TrendReq
|
||||
from pytrends.exceptions import TooManyRequestsError as _TooManyRequestsError
|
||||
PYTrends_AVAILABLE = True
|
||||
except ImportError:
|
||||
PYTrends_AVAILABLE = False
|
||||
_TooManyRequestsError = None
|
||||
logger.warning("pytrends not installed. Google Trends features will be unavailable.")
|
||||
|
||||
# Patch 2: pytrends related_topics() and related_queries() use keyword[0]
|
||||
@@ -141,8 +139,6 @@ class GoogleTrendsService:
|
||||
Uses TrendReq with no retries (fail-fast) to avoid hitting CAPTCHA on blocks.
|
||||
429 retry handling (1s, 2s, 4s backoff). Random user-agent is set
|
||||
per instance to reduce fingerprinting.
|
||||
|
||||
Rate limiter is shared across all instances to enforce global rate limiting.
|
||||
"""
|
||||
|
||||
USER_AGENTS = [
|
||||
@@ -154,28 +150,15 @@ class GoogleTrendsService:
|
||||
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36 Edg/124.0.0.0",
|
||||
]
|
||||
|
||||
# Class-level shared resources (shared across all instances)
|
||||
_shared_rate_limiter = None
|
||||
_shared_cache = None
|
||||
_cache_ttl = timedelta(hours=24)
|
||||
_last_429_time = 0 # Timestamp of last 429 error (Unix epoch)
|
||||
_429_cooldown_period = 1800 # 30 minutes cooldown after 429
|
||||
|
||||
def __init__(self):
|
||||
if not PYTrends_AVAILABLE:
|
||||
raise RuntimeError("pytrends library is required. Install with: pip install pytrends")
|
||||
|
||||
# Initialize shared rate limiter at class level (lazy init)
|
||||
if self.__class__._shared_rate_limiter is None:
|
||||
self.__class__._shared_rate_limiter = RateLimiter(max_calls=1, period=3.0) # 1 call per 3 seconds
|
||||
if self.__class__._shared_cache is None:
|
||||
self.__class__._shared_cache = {}
|
||||
self.rate_limiter = RateLimiter(max_calls=1, period=1.0)
|
||||
self.cache: Dict[str, Any] = {}
|
||||
self.cache_ttl = timedelta(hours=24)
|
||||
|
||||
self.rate_limiter = self.__class__._shared_rate_limiter
|
||||
self.cache = self.__class__._shared_cache
|
||||
self.cache_ttl = self._cache_ttl
|
||||
|
||||
logger.info("GoogleTrendsService initialized (pytrends 4.9.2, shared rate limiter, 3s period, shared cache, 30min 429 cooldown)")
|
||||
logger.info("GoogleTrendsService initialized (pytrends 4.9.2, fail-fast, 2s delays)")
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Public API
|
||||
@@ -190,7 +173,7 @@ class GoogleTrendsService:
|
||||
user_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Comprehensive trends analysis with retry logic for 429 errors.
|
||||
Comprehensive trends analysis.
|
||||
|
||||
Args:
|
||||
keywords: List of keywords to analyze (1-5)
|
||||
@@ -210,97 +193,11 @@ class GoogleTrendsService:
|
||||
keywords = keywords[:5]
|
||||
|
||||
cache_key = self._build_cache_key(keywords, timeframe, geo)
|
||||
|
||||
# Check if we're in a 429 cooldown period
|
||||
now = time.time()
|
||||
if now - self.__class__._last_429_time < self.__class__._429_cooldown_period:
|
||||
remaining_cooldown = int(self.__class__._429_cooldown_period - (now - self.__class__._last_429_time))
|
||||
logger.warning(
|
||||
f"[Trends] In 429 cooldown period. {remaining_cooldown}s remaining. "
|
||||
f"Returning cached data if available."
|
||||
)
|
||||
cached_data = self._get_from_cache(cache_key, ignore_ttl=True) # Use stale cache
|
||||
if cached_data:
|
||||
logger.info(f"[Trends] Returning stale cached data for {keywords} during cooldown")
|
||||
return {**cached_data, "cached": True, "cooldown_active": True}
|
||||
return self._create_fallback_response(
|
||||
keywords, timeframe, geo, gprop,
|
||||
f"Rate limited by Google. Cooldown active for {remaining_cooldown}s. Try again later."
|
||||
)
|
||||
|
||||
# Check fresh cache
|
||||
cached_data = self._get_from_cache(cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Returning cached trends data for: {keywords}")
|
||||
return {**cached_data, "cached": True}
|
||||
|
||||
# Retry logic for 429 errors
|
||||
max_retries = 3
|
||||
retry_delays = [30, 60, 120] # Longer delays: 30s, 60s, 120s
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await self._do_analyze_trends(
|
||||
keywords, timeframe, geo, gprop, cache_key, attempt, max_retries
|
||||
)
|
||||
except Exception as e:
|
||||
# Check if this is a 429 error (pytrends raises TooManyRequestsError)
|
||||
is_429 = False
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
is_429 = True
|
||||
else:
|
||||
error_str = str(e).lower()
|
||||
is_429 = "429" in error_str or "rate limit" in error_str or "too many requests" in error_str
|
||||
|
||||
if is_429:
|
||||
# Update the last 429 time for cooldown
|
||||
self.__class__._last_429_time = time.time()
|
||||
|
||||
if attempt < max_retries:
|
||||
delay = retry_delays[attempt]
|
||||
logger.warning(
|
||||
f"[Trends] 429 rate limit hit (attempt {attempt + 1}/{max_retries + 1}), "
|
||||
f"retrying in {delay}s..."
|
||||
)
|
||||
await asyncio.sleep(delay)
|
||||
continue
|
||||
else:
|
||||
# Out of retries - enter cooldown
|
||||
logger.error(
|
||||
f"[Trends] 429 rate limit persisted after {max_retries + 1} attempts. "
|
||||
f"Entering {self.__class__._429_cooldown_period}s cooldown period."
|
||||
)
|
||||
# Try to return stale cache
|
||||
stale_cache = self._get_from_cache(cache_key, ignore_ttl=True)
|
||||
if stale_cache:
|
||||
logger.info(f"[Trends] Returning stale cache after 429 exhaustion for {keywords}")
|
||||
result = {**stale_cache}
|
||||
result["cached"] = True
|
||||
result["cooldown_active"] = True
|
||||
return result
|
||||
return self._create_fallback_response(
|
||||
keywords, timeframe, geo, gprop,
|
||||
f"Google is rate limiting requests. Cooldown active for {self.__class__._429_cooldown_period}s. Try again later."
|
||||
)
|
||||
else:
|
||||
# Non-429 error
|
||||
logger.error(f"Google Trends analysis failed after {attempt + 1} attempts: {e}")
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
|
||||
|
||||
# Should not reach here, but just in case
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, "Max retries exceeded")
|
||||
|
||||
async def _do_analyze_trends(
|
||||
self,
|
||||
keywords: List[str],
|
||||
timeframe: str,
|
||||
geo: str,
|
||||
gprop: str,
|
||||
cache_key: str,
|
||||
attempt: int,
|
||||
max_retries: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Internal method to perform the actual trends analysis."""
|
||||
await self.rate_limiter.acquire()
|
||||
|
||||
total_start = time.monotonic()
|
||||
@@ -310,63 +207,95 @@ class GoogleTrendsService:
|
||||
related_topics: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
related_queries: Dict[str, List[Dict[str, Any]]] = {"top": [], "rising": []}
|
||||
|
||||
logger.info(
|
||||
f"[Trends] ===== START analyze_trends (attempt {attempt + 1}/{max_retries + 1}) ===== "
|
||||
f"keywords={keywords} timeframe={timeframe} geo={geo}"
|
||||
)
|
||||
try:
|
||||
logger.info(f"[Trends] ===== START analyze_trends ===== keywords={keywords} timeframe={timeframe} geo={geo}")
|
||||
|
||||
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
|
||||
init_start = time.monotonic()
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._create_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo,
|
||||
gprop,
|
||||
)
|
||||
init_ms = int((time.monotonic() - init_start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
|
||||
# Initialize TrendReq with gprop (youtube for video/podcast relevance)
|
||||
init_start = time.monotonic()
|
||||
pytrends = await asyncio.to_thread(
|
||||
self._create_pytrends,
|
||||
keywords,
|
||||
timeframe,
|
||||
geo,
|
||||
gprop,
|
||||
)
|
||||
init_ms = int((time.monotonic() - init_start) * 1000)
|
||||
logger.info(f"[Trends] TrendReq init + build_payload took {init_ms}ms")
|
||||
|
||||
# --- Interest Over Time ONLY (skip others to avoid 429) ---
|
||||
await self.rate_limiter.acquire() # Rate limit check BEFORE each request
|
||||
iot_start = time.monotonic()
|
||||
interest_over_time = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_over_time(pytrends)
|
||||
)
|
||||
iot_ms = int((time.monotonic() - iot_start) * 1000)
|
||||
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
|
||||
# --- Interest Over Time ---
|
||||
iot_start = time.monotonic()
|
||||
interest_over_time = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_over_time(pytrends)
|
||||
)
|
||||
iot_ms = int((time.monotonic() - iot_start) * 1000)
|
||||
logger.info(f"[Trends] interest_over_time took {iot_ms}ms, returned {len(interest_over_time)} points")
|
||||
|
||||
# Skip other requests to avoid 429 - only fetch interest_over_time for now
|
||||
logger.info(f"[Trends] Skipping other requests to avoid 429 (interest_by_region, related_topics, related_queries)")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
total_ms = int((time.monotonic() - total_start) * 1000)
|
||||
logger.info(
|
||||
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
|
||||
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
|
||||
f"rt_top={rt_top} rq_top={rq_top}"
|
||||
)
|
||||
# --- Interest By Region ---
|
||||
ibr_start = time.monotonic()
|
||||
interest_by_region = await asyncio.to_thread(
|
||||
lambda: self._fetch_interest_by_region(pytrends)
|
||||
)
|
||||
ibr_ms = int((time.monotonic() - ibr_start) * 1000)
|
||||
logger.info(f"[Trends] interest_by_region took {ibr_ms}ms, returned {len(interest_by_region)} regions")
|
||||
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
"related_topics": related_topics,
|
||||
"related_queries": related_queries,
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
}
|
||||
await asyncio.sleep(2)
|
||||
|
||||
self._save_to_cache(cache_key, result)
|
||||
# --- Related Topics ---
|
||||
rt_start = time.monotonic()
|
||||
related_topics = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_topics(pytrends)
|
||||
)
|
||||
rt_ms = int((time.monotonic() - rt_start) * 1000)
|
||||
rt_top = len(related_topics.get("top", []))
|
||||
rt_rising = len(related_topics.get("rising", []))
|
||||
logger.info(f"[Trends] related_topics took {rt_ms}ms, top={rt_top} rising={rt_rising}")
|
||||
|
||||
logger.info(
|
||||
f"Google Trends data fetched successfully: "
|
||||
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
|
||||
)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
return result
|
||||
# --- Related Queries ---
|
||||
rq_start = time.monotonic()
|
||||
related_queries = await asyncio.to_thread(
|
||||
lambda: self._fetch_related_queries(pytrends)
|
||||
)
|
||||
rq_ms = int((time.monotonic() - rq_start) * 1000)
|
||||
rq_top = len(related_queries.get("top", []))
|
||||
rq_rising = len(related_queries.get("rising", []))
|
||||
logger.info(f"[Trends] related_queries took {rq_ms}ms, top={rq_top} rising={rq_rising}")
|
||||
|
||||
total_ms = int((time.monotonic() - total_start) * 1000)
|
||||
logger.info(
|
||||
f"[Trends] ===== DONE analyze_trends ===== total={total_ms}ms "
|
||||
f"iot={len(interest_over_time)} ibr={len(interest_by_region)} "
|
||||
f"rt_top={rt_top} rq_top={rq_top}"
|
||||
)
|
||||
|
||||
result = {
|
||||
"interest_over_time": interest_over_time,
|
||||
"interest_by_region": interest_by_region,
|
||||
"related_topics": related_topics,
|
||||
"related_queries": related_queries,
|
||||
"timeframe": timeframe,
|
||||
"geo": geo,
|
||||
"keywords": keywords,
|
||||
"source": "web" if gprop == "" else "podcast" if gprop == "youtube" else gprop,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"cached": False,
|
||||
}
|
||||
|
||||
self._save_to_cache(cache_key, result)
|
||||
|
||||
logger.info(
|
||||
f"Google Trends data fetched successfully: "
|
||||
f"{len(interest_over_time)} time points, {len(interest_by_region)} regions"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Google Trends analysis failed: {e}")
|
||||
return self._create_fallback_response(keywords, timeframe, geo, gprop, str(e))
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# TrendReq factory
|
||||
@@ -417,12 +346,6 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] interest_over_time failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
@@ -440,12 +363,6 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] interest_by_region failed in {elapsed}ms: {e}")
|
||||
return []
|
||||
|
||||
@@ -492,12 +409,6 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] related_topics failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
@@ -541,12 +452,6 @@ class GoogleTrendsService:
|
||||
return result
|
||||
except Exception as e:
|
||||
elapsed = int((time.monotonic() - start) * 1000)
|
||||
# Re-raise 429 errors so retry logic can handle them
|
||||
if _TooManyRequestsError and isinstance(e, _TooManyRequestsError):
|
||||
raise
|
||||
error_str = str(e).lower()
|
||||
if "429" in error_str or "rate limit" in error_str or "too many requests" in error_str:
|
||||
raise
|
||||
logger.error(f"[Trends] related_queries failed in {elapsed}ms: {e}")
|
||||
return result
|
||||
|
||||
@@ -598,18 +503,14 @@ class GoogleTrendsService:
|
||||
keywords_str = ":".join(sorted(keywords))
|
||||
return f"google_trends:{keywords_str}:{timeframe}:{geo}"
|
||||
|
||||
def _get_from_cache(self, cache_key: str, ignore_ttl: bool = False) -> Optional[Dict[str, Any]]:
|
||||
"""Get cached data. If ignore_ttl=True, return stale data too (for 429 cooldown)."""
|
||||
def _get_from_cache(self, cache_key: str) -> Optional[Dict[str, Any]]:
|
||||
if cache_key not in self.cache:
|
||||
return None
|
||||
cached_entry = self.cache[cache_key]
|
||||
|
||||
if not ignore_ttl:
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
|
||||
cached_time = datetime.fromisoformat(cached_entry.get("timestamp", ""))
|
||||
if datetime.utcnow() - cached_time > self.cache_ttl:
|
||||
del self.cache[cache_key]
|
||||
return None
|
||||
result = {**cached_entry}
|
||||
result.pop("cached", None)
|
||||
return result
|
||||
|
||||
@@ -1,271 +0,0 @@
|
||||
"""Self-healing executor for social post engagement recovery.
|
||||
|
||||
Implements:
|
||||
- Per-post evaluation windows and cooldown timers
|
||||
- Stagnation trigger evaluation with tiered action selection
|
||||
- Action idempotency keys for edit/comment/thread operations
|
||||
- Duplicate and over-frequency suppression within cooldown boundaries
|
||||
- Outcome persistence and safe retry policy for transient failures
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field, asdict
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from enum import Enum
|
||||
import hashlib
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
||||
class ActionType(str, Enum):
|
||||
EDIT = "edit"
|
||||
COMMENT = "comment"
|
||||
THREAD = "thread"
|
||||
|
||||
|
||||
class ActionTier(str, Enum):
|
||||
TIER_1 = "tier_1" # low-intensity nudge (comment)
|
||||
TIER_2 = "tier_2" # medium-intensity enhancement (edit)
|
||||
TIER_3 = "tier_3" # high-intensity amplification (thread)
|
||||
|
||||
|
||||
SAFE_TRANSIENT_ERROR_CODES = {
|
||||
"timeout",
|
||||
"rate_limit",
|
||||
"service_unavailable",
|
||||
"network_error",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvaluationConfig:
|
||||
per_post_window_minutes: int = 90
|
||||
min_samples_required: int = 3
|
||||
cooldown_by_action_seconds: Dict[ActionType, int] = field(
|
||||
default_factory=lambda: {
|
||||
ActionType.COMMENT: 30 * 60,
|
||||
ActionType.EDIT: 2 * 60 * 60,
|
||||
ActionType.THREAD: 3 * 60 * 60,
|
||||
}
|
||||
)
|
||||
max_actions_per_window: int = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostMetricsPoint:
|
||||
timestamp: datetime
|
||||
impressions: int
|
||||
engagements: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActionRecord:
|
||||
idempotency_key: str
|
||||
post_id: str
|
||||
action_type: ActionType
|
||||
tier: ActionTier
|
||||
initiated_at: datetime
|
||||
status: str
|
||||
attempts: int = 1
|
||||
outcome: Optional[Dict[str, Any]] = None
|
||||
error_code: Optional[str] = None
|
||||
|
||||
def to_json(self) -> Dict[str, Any]:
|
||||
payload = asdict(self)
|
||||
payload["action_type"] = self.action_type.value
|
||||
payload["tier"] = self.tier.value
|
||||
payload["initiated_at"] = self.initiated_at.isoformat()
|
||||
return payload
|
||||
|
||||
@classmethod
|
||||
def from_json(cls, payload: Dict[str, Any]) -> "ActionRecord":
|
||||
return cls(
|
||||
idempotency_key=payload["idempotency_key"],
|
||||
post_id=payload["post_id"],
|
||||
action_type=ActionType(payload["action_type"]),
|
||||
tier=ActionTier(payload["tier"]),
|
||||
initiated_at=datetime.fromisoformat(payload["initiated_at"]),
|
||||
status=payload["status"],
|
||||
attempts=payload.get("attempts", 1),
|
||||
outcome=payload.get("outcome"),
|
||||
error_code=payload.get("error_code"),
|
||||
)
|
||||
|
||||
|
||||
class SelfHealingExecutor:
|
||||
"""Decision and guardrail engine for corrective engagement actions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: Optional[EvaluationConfig] = None,
|
||||
persistence_path: str = "backend/data/self_healing_action_history.json",
|
||||
) -> None:
|
||||
self.config = config or EvaluationConfig()
|
||||
self.persistence_path = Path(persistence_path)
|
||||
self._history: List[ActionRecord] = self._load_history()
|
||||
|
||||
def evaluate_and_plan(
|
||||
self,
|
||||
post_id: str,
|
||||
metrics: List[PostMetricsPoint],
|
||||
now: Optional[datetime] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Evaluate stagnation for a post and plan a single best next action."""
|
||||
now = now or datetime.now(timezone.utc)
|
||||
window_metrics = self._filter_window(metrics, now)
|
||||
|
||||
if len(window_metrics) < self.config.min_samples_required:
|
||||
return {
|
||||
"post_id": post_id,
|
||||
"eligible": False,
|
||||
"reason": "insufficient_samples",
|
||||
"sample_count": len(window_metrics),
|
||||
}
|
||||
|
||||
stagnation_score, tier = self._evaluate_stagnation(window_metrics)
|
||||
action_type = self._choose_action_type(tier)
|
||||
idempotency_key = self.generate_idempotency_key(post_id, action_type, tier)
|
||||
|
||||
if self._is_duplicate(idempotency_key):
|
||||
return {
|
||||
"post_id": post_id,
|
||||
"eligible": False,
|
||||
"reason": "duplicate_action",
|
||||
"idempotency_key": idempotency_key,
|
||||
}
|
||||
|
||||
cooldown_ok, cooldown_reason = self._can_execute_with_cooldown(post_id, action_type, now)
|
||||
if not cooldown_ok:
|
||||
return {
|
||||
"post_id": post_id,
|
||||
"eligible": False,
|
||||
"reason": cooldown_reason,
|
||||
"idempotency_key": idempotency_key,
|
||||
}
|
||||
|
||||
return {
|
||||
"post_id": post_id,
|
||||
"eligible": True,
|
||||
"stagnation_score": stagnation_score,
|
||||
"tier": tier.value,
|
||||
"action_type": action_type.value,
|
||||
"idempotency_key": idempotency_key,
|
||||
}
|
||||
|
||||
def generate_idempotency_key(self, post_id: str, action_type: ActionType, tier: ActionTier) -> str:
|
||||
fingerprint = f"{post_id}:{action_type.value}:{tier.value}".encode("utf-8")
|
||||
digest = hashlib.sha256(fingerprint).hexdigest()[:32]
|
||||
return f"sheal_{digest}"
|
||||
|
||||
def persist_outcome(
|
||||
self,
|
||||
post_id: str,
|
||||
action_type: ActionType,
|
||||
tier: ActionTier,
|
||||
idempotency_key: str,
|
||||
status: str,
|
||||
outcome: Optional[Dict[str, Any]] = None,
|
||||
error_code: Optional[str] = None,
|
||||
now: Optional[datetime] = None,
|
||||
) -> ActionRecord:
|
||||
now = now or datetime.now(timezone.utc)
|
||||
|
||||
existing = next((h for h in self._history if h.idempotency_key == idempotency_key), None)
|
||||
if existing:
|
||||
existing.status = status
|
||||
existing.outcome = outcome
|
||||
existing.error_code = error_code
|
||||
existing.attempts += 1
|
||||
existing.initiated_at = now
|
||||
record = existing
|
||||
else:
|
||||
record = ActionRecord(
|
||||
idempotency_key=idempotency_key,
|
||||
post_id=post_id,
|
||||
action_type=action_type,
|
||||
tier=tier,
|
||||
initiated_at=now,
|
||||
status=status,
|
||||
outcome=outcome,
|
||||
error_code=error_code,
|
||||
)
|
||||
self._history.append(record)
|
||||
|
||||
self._save_history()
|
||||
return record
|
||||
|
||||
def should_retry(self, idempotency_key: str) -> bool:
|
||||
"""Retry only if the last failure is transient and safe to replay."""
|
||||
rec = next((h for h in self._history if h.idempotency_key == idempotency_key), None)
|
||||
if not rec or rec.status != "failed":
|
||||
return False
|
||||
|
||||
if rec.error_code not in SAFE_TRANSIENT_ERROR_CODES:
|
||||
return False
|
||||
|
||||
return rec.action_type in {ActionType.COMMENT, ActionType.EDIT, ActionType.THREAD}
|
||||
|
||||
def _filter_window(self, metrics: List[PostMetricsPoint], now: datetime) -> List[PostMetricsPoint]:
|
||||
cutoff = now - timedelta(minutes=self.config.per_post_window_minutes)
|
||||
return [m for m in metrics if m.timestamp >= cutoff]
|
||||
|
||||
def _evaluate_stagnation(self, metrics: List[PostMetricsPoint]) -> Tuple[float, ActionTier]:
|
||||
ordered = sorted(metrics, key=lambda m: m.timestamp)
|
||||
first, last = ordered[0], ordered[-1]
|
||||
|
||||
imp_delta = max(0, last.impressions - first.impressions)
|
||||
eng_delta = max(0, last.engagements - first.engagements)
|
||||
eng_rate = eng_delta / imp_delta if imp_delta > 0 else 0.0
|
||||
|
||||
stagnation_score = 1.0 - min(1.0, eng_rate * 20)
|
||||
if stagnation_score >= 0.8:
|
||||
return stagnation_score, ActionTier.TIER_3
|
||||
if stagnation_score >= 0.55:
|
||||
return stagnation_score, ActionTier.TIER_2
|
||||
return stagnation_score, ActionTier.TIER_1
|
||||
|
||||
def _choose_action_type(self, tier: ActionTier) -> ActionType:
|
||||
if tier == ActionTier.TIER_1:
|
||||
return ActionType.COMMENT
|
||||
if tier == ActionTier.TIER_2:
|
||||
return ActionType.EDIT
|
||||
return ActionType.THREAD
|
||||
|
||||
def _is_duplicate(self, idempotency_key: str) -> bool:
|
||||
return any(h.idempotency_key == idempotency_key and h.status in {"success", "running"} for h in self._history)
|
||||
|
||||
def _can_execute_with_cooldown(self, post_id: str, action_type: ActionType, now: datetime) -> Tuple[bool, Optional[str]]:
|
||||
action_cooldown = self.config.cooldown_by_action_seconds[action_type]
|
||||
|
||||
same_post = [h for h in self._history if h.post_id == post_id]
|
||||
recent_in_window = [
|
||||
h for h in same_post
|
||||
if h.initiated_at >= now - timedelta(minutes=self.config.per_post_window_minutes)
|
||||
]
|
||||
if len(recent_in_window) >= self.config.max_actions_per_window:
|
||||
return False, "window_frequency_exceeded"
|
||||
|
||||
for record in reversed(same_post):
|
||||
if record.action_type != action_type:
|
||||
continue
|
||||
if (now - record.initiated_at).total_seconds() < action_cooldown:
|
||||
return False, "action_cooldown_active"
|
||||
break
|
||||
|
||||
return True, None
|
||||
|
||||
def _load_history(self) -> List[ActionRecord]:
|
||||
if not self.persistence_path.exists():
|
||||
return []
|
||||
try:
|
||||
payload = json.loads(self.persistence_path.read_text(encoding="utf-8"))
|
||||
return [ActionRecord.from_json(item) for item in payload]
|
||||
except (json.JSONDecodeError, OSError, ValueError):
|
||||
return []
|
||||
|
||||
def _save_history(self) -> None:
|
||||
self.persistence_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = [item.to_json() for item in self._history]
|
||||
self.persistence_path.write_text(json.dumps(payload, indent=2), encoding="utf-8")
|
||||
@@ -157,10 +157,10 @@ def _check_production_api_key_loading(
|
||||
_record_check(checks, "production_api_key_loading", True, "skipped in local deploy mode")
|
||||
return
|
||||
|
||||
# Skip when in feature-limited mode (no production API keys needed)
|
||||
# Also skip in podcast-only mode (no production API keys needed)
|
||||
enabled_features = os.getenv("ALWRITY_ENABLED_FEATURES", "all").strip().lower()
|
||||
if enabled_features and enabled_features not in ("", "all"):
|
||||
_record_check(checks, "production_api_key_loading", True, f"skipped in feature-limited mode: {enabled_features}")
|
||||
if enabled_features == "podcast":
|
||||
_record_check(checks, "production_api_key_loading", True, "skipped in podcast-only mode")
|
||||
return
|
||||
|
||||
test_tenant_id = os.getenv("ALWRITY_STARTUP_TEST_TENANT_ID", "").strip()
|
||||
|
||||
@@ -12,7 +12,7 @@ from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from models.subscription_models import APIProvider, UsageAlert, UserSubscription
|
||||
from models.subscription_models import APIProvider, UsageAlert
|
||||
|
||||
class SubscriptionErrorType(Enum):
|
||||
USAGE_LIMIT_EXCEEDED = "usage_limit_exceeded"
|
||||
@@ -248,18 +248,6 @@ class SubscriptionExceptionHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
# Get billing period from subscription, fallback to calendar month
|
||||
billing_period = datetime.now().strftime("%Y-%m") # default
|
||||
try:
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == error.user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
if subscription and subscription.current_period_start:
|
||||
billing_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
except:
|
||||
pass # Use default calendar period
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=error.user_id,
|
||||
alert_type="system_error",
|
||||
@@ -268,7 +256,7 @@ class SubscriptionExceptionHandler:
|
||||
title=f"System Error: {error.error_type.value}",
|
||||
message=error.message,
|
||||
severity=error.severity.value,
|
||||
billing_period=billing_period
|
||||
billing_period=datetime.now().strftime("%Y-%m")
|
||||
)
|
||||
|
||||
self.db.add(alert)
|
||||
|
||||
@@ -157,38 +157,39 @@ class LimitValidator:
|
||||
user_tier = limits.get('tier', 'free') if limits else 'free'
|
||||
|
||||
# Get current usage for this billing period with error handling
|
||||
# Use subscription period, not calendar month
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id)
|
||||
|
||||
# Only expire specific objects that might have changed after renewal
|
||||
# (subscription was already checked above; plan was expired above)
|
||||
# The usage record is the main object we need fresh, and we query it directly below
|
||||
if subscription:
|
||||
self.db.expire(subscription)
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
# Use targeted expiry instead of expire_all() to avoid nuking the entire session cache
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Only expire specific objects that might have changed after renewal
|
||||
# (subscription was already checked above; plan was expired above)
|
||||
# The usage record is the main object we need fresh, and we query it directly below
|
||||
if subscription:
|
||||
self.db.expire(subscription)
|
||||
|
||||
# Use raw SQL query first to bypass ORM cache, fallback to ORM if SQL fails
|
||||
usage = None
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
sql_query = text("SELECT * FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
result = self.db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
# Map result to UsageSummary object
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
except Exception as sql_error:
|
||||
logger.debug(f"[Subscription Check] Raw SQL query failed, using ORM: {sql_error}")
|
||||
# Fallback to ORM query
|
||||
usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if usage:
|
||||
self.db.refresh(usage) # Ensure fresh data
|
||||
|
||||
if not usage:
|
||||
# First usage this period, create summary
|
||||
@@ -447,7 +448,7 @@ class LimitValidator:
|
||||
logger.info(f"[Pre-flight Check] 📋 Validating {len(operations)} operation(s) before making any API calls")
|
||||
|
||||
# Get current usage and limits once
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id)
|
||||
current_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.info(f"[Pre-flight Check] 📅 Billing Period: {current_period} (for user {user_id})")
|
||||
|
||||
|
||||
@@ -67,56 +67,15 @@ class PricingService:
|
||||
self.db.rollback()
|
||||
return True
|
||||
|
||||
def get_current_billing_period(self, user_id: str) -> str:
|
||||
"""Return current billing period key (YYYY-MM) based on subscription, not calendar.
|
||||
Maintains backward compatibility with existing calendar-month data."""
|
||||
def get_current_billing_period(self, user_id: str) -> Optional[str]:
|
||||
"""Return current billing period key (YYYY-MM) after ensuring subscription is current."""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id,
|
||||
UserSubscription.is_active == True
|
||||
).first()
|
||||
|
||||
# Ensure subscription is current (advance if auto_renew)
|
||||
self._ensure_subscription_current(subscription)
|
||||
|
||||
# Use subscription's billing period, NOT calendar month
|
||||
if subscription and subscription.current_period_start:
|
||||
sub_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
|
||||
# Check if usage data exists for this subscription period
|
||||
from models.subscription_models import UsageSummary
|
||||
usage_exists = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == sub_period
|
||||
).first()
|
||||
|
||||
if usage_exists:
|
||||
return sub_period
|
||||
|
||||
# If no data for subscription period, check for calendar month data
|
||||
# This handles backward compatibility for existing users
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
if calendar_period != sub_period:
|
||||
calendar_usage = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
if calendar_usage:
|
||||
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {sub_period} has no data)")
|
||||
return calendar_period
|
||||
|
||||
return sub_period
|
||||
|
||||
# Fallback: Check if user has any usage summary and use that period
|
||||
from models.subscription_models import UsageSummary
|
||||
latest_summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).first()
|
||||
|
||||
if latest_summary:
|
||||
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period}")
|
||||
return latest_summary.billing_period
|
||||
|
||||
# Last fallback to calendar month for free tier / no data
|
||||
# Continue to use YYYY-MM for summaries
|
||||
return datetime.now().strftime("%Y-%m")
|
||||
|
||||
@classmethod
|
||||
@@ -871,7 +830,6 @@ class PricingService:
|
||||
'serper_calls': plan.serper_calls_limit,
|
||||
'metaphor_calls': plan.metaphor_calls_limit,
|
||||
'firecrawl_calls': plan.firecrawl_calls_limit,
|
||||
'exa_calls': getattr(plan, 'exa_calls_limit', 0), # Exa research API
|
||||
'stability_calls': plan.stability_calls_limit,
|
||||
'video_calls': getattr(plan, 'video_calls_limit', 0), # Support missing column
|
||||
'image_edit_calls': getattr(plan, 'image_edit_calls_limit', 0), # Support missing column
|
||||
|
||||
@@ -8,7 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning, ProcessedStripeEvent
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
|
||||
REQUIRED_STRIPE_PLAN_KEYS = {
|
||||
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value),
|
||||
@@ -421,6 +421,10 @@ class StripeService:
|
||||
try:
|
||||
sub = stripe.Subscription.retrieve(subscription_id)
|
||||
price_id = sub['items']['data'][0]['price']['id']
|
||||
# Map price_id to internal plan_id
|
||||
# Note: You need a way to map Stripe Price IDs to your Plan IDs.
|
||||
# For now, we'll assume the metadata or a lookup.
|
||||
# Ideally, store price_id in SubscriptionPlan table or config.
|
||||
|
||||
# Update DB
|
||||
self._update_user_subscription(
|
||||
@@ -430,24 +434,6 @@ class StripeService:
|
||||
status="active",
|
||||
price_id=price_id
|
||||
)
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after checkout for user {user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(user_id)
|
||||
logger.info(f"Cleared dashboard cache for user {user_id} after checkout")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear cache after checkout for user {user_id}: {cache_err}")
|
||||
|
||||
# Expire all SQLAlchemy objects to force fresh reads
|
||||
self.db.expire_all()
|
||||
logger.info(f"Expired all SQLAlchemy objects for user {user_id} after checkout")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkout subscription: {e}")
|
||||
|
||||
@@ -471,28 +457,11 @@ class StripeService:
|
||||
logger.info(f"Payment succeeded for user {subscription.user_id}")
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
subscription.auto_renew = True
|
||||
# Update period start/end based on invoice lines period
|
||||
# Update period end based on invoice lines period
|
||||
if invoice.get('lines'):
|
||||
period_start = invoice['lines']['data'][0]['period']['start']
|
||||
period_end = invoice['lines']['data'][0]['period']['end']
|
||||
subscription.current_period_start = datetime.fromtimestamp(period_start)
|
||||
subscription.current_period_end = datetime.fromtimestamp(period_end)
|
||||
self.db.commit()
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(subscription.user_id)
|
||||
logger.info(f"Cleared subscription cache for user {subscription.user_id} after payment success")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after payment success for user {subscription.user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(subscription.user_id)
|
||||
except Exception as dash_cache_err:
|
||||
logger.warning(f"Failed to clear dashboard cache after payment success for user {subscription.user_id}: {dash_cache_err}")
|
||||
self.db.expire_all()
|
||||
|
||||
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
|
||||
subscription_id = invoice.get("subscription")
|
||||
@@ -528,12 +497,6 @@ class StripeService:
|
||||
if status in ["active", "trialing"]:
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
subscription.auto_renew = True
|
||||
# Update period boundaries from Stripe event
|
||||
current_period = subscription_obj.get("current_period", {})
|
||||
if current_period:
|
||||
subscription.current_period_start = datetime.fromtimestamp(current_period.get("start", 0))
|
||||
subscription.current_period_end = datetime.fromtimestamp(current_period.get("end", 0))
|
||||
elif status in ["past_due", "unpaid", "incomplete", "incomplete_expired"]:
|
||||
subscription.status = UsageStatus.PAST_DUE
|
||||
subscription.is_active = False
|
||||
@@ -543,20 +506,6 @@ class StripeService:
|
||||
subscription.auto_renew = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
# Clear PricingService cache so next status check returns updated limits
|
||||
try:
|
||||
from services.subscription import PricingService
|
||||
PricingService.clear_user_cache(subscription.user_id)
|
||||
logger.info(f"Cleared subscription cache for user {subscription.user_id} after subscription update")
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear user cache after subscription update for user {subscription.user_id}: {cache_err}")
|
||||
try:
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
clear_dashboard_cache(subscription.user_id)
|
||||
except Exception as dash_cache_err:
|
||||
logger.warning(f"Failed to clear dashboard cache after subscription update for user {subscription.user_id}: {dash_cache_err}")
|
||||
self.db.expire_all()
|
||||
|
||||
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
|
||||
"""
|
||||
@@ -661,11 +610,6 @@ class StripeService:
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
# Calculate billing period end based on cycle
|
||||
if billing_cycle == BillingCycle.YEARLY:
|
||||
period_end = now + timedelta(days=365)
|
||||
else:
|
||||
period_end = now + timedelta(days=30)
|
||||
|
||||
if not subscription:
|
||||
subscription = UserSubscription(
|
||||
@@ -673,7 +617,7 @@ class StripeService:
|
||||
plan_id=plan.id,
|
||||
billing_cycle=billing_cycle,
|
||||
current_period_start=now,
|
||||
current_period_end=period_end,
|
||||
current_period_end=now,
|
||||
status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
|
||||
is_active=status == "active",
|
||||
auto_renew=True,
|
||||
@@ -683,11 +627,6 @@ class StripeService:
|
||||
subscription.plan_id = plan.id
|
||||
subscription.billing_cycle = billing_cycle
|
||||
subscription.is_active = status == "active"
|
||||
subscription.status = UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED
|
||||
# Reset billing period on upgrade/plan change
|
||||
subscription.current_period_start = now
|
||||
subscription.current_period_end = period_end
|
||||
subscription.auto_renew = True
|
||||
|
||||
subscription.stripe_customer_id = stripe_customer_id
|
||||
subscription.stripe_subscription_id = stripe_subscription_id
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
"""
|
||||
Usage tracking modules package.
|
||||
Split from the monolithic usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from .historical_usage import get_all_historical_usage, get_current_period_usage, get_usage_for_period
|
||||
from .usage_stats import get_user_usage_stats
|
||||
from .usage_trends import get_usage_trends
|
||||
from .limits_enforcement import enforce_usage_limits
|
||||
from .alerts import check_usage_alerts, create_usage_alert
|
||||
|
||||
__all__ = [
|
||||
'get_all_historical_usage',
|
||||
'get_current_period_usage',
|
||||
'get_usage_for_period',
|
||||
'get_user_usage_stats',
|
||||
'get_usage_trends',
|
||||
'enforce_usage_limits',
|
||||
'check_usage_alerts',
|
||||
'create_usage_alert',
|
||||
]
|
||||
@@ -1,101 +0,0 @@
|
||||
"""
|
||||
Usage alert functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import UsageAlert, UsageSummary, APIProvider, UsageStatus
|
||||
|
||||
|
||||
def check_usage_alerts(user_id: str, provider: APIProvider,
|
||||
billing_period: str, db: Session, pricing_service):
|
||||
"""Check if usage alerts should be sent."""
|
||||
# Get current usage
|
||||
period_keys = {'billing_period': billing_period, 'lookup_periods': [billing_period]}
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return
|
||||
|
||||
# Get user limits
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check for alert thresholds (80%, 90%, 100%)
|
||||
thresholds = [80, 90, 100]
|
||||
|
||||
for threshold in thresholds:
|
||||
# Check if alert already sent for this threshold
|
||||
existing_alert = db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.threshold_percentage == threshold,
|
||||
UsageAlert.provider == provider,
|
||||
UsageAlert.is_sent == True
|
||||
).first()
|
||||
|
||||
if existing_alert:
|
||||
continue
|
||||
|
||||
# Check if threshold is reached
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
usage_percentage = (current_calls / call_limit) * 100
|
||||
|
||||
if usage_percentage >= threshold:
|
||||
create_usage_alert(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
threshold=threshold,
|
||||
current_usage=current_calls,
|
||||
limit=call_limit,
|
||||
billing_period=billing_period,
|
||||
db=db
|
||||
)
|
||||
|
||||
|
||||
def create_usage_alert(user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str, db: Session):
|
||||
"""Create a usage alert."""
|
||||
|
||||
# Determine alert type and severity
|
||||
if threshold >= 100:
|
||||
alert_type = "limit_reached"
|
||||
severity = "error"
|
||||
title = f"API Limit Reached - {provider.value.title()}"
|
||||
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
|
||||
elif threshold >= 90:
|
||||
alert_type = "usage_warning"
|
||||
severity = "warning"
|
||||
title = f"API Usage Warning - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
else:
|
||||
alert_type = "usage_warning"
|
||||
severity = "info"
|
||||
title = f"API Usage Notice - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=user_id,
|
||||
alert_type=alert_type,
|
||||
threshold_percentage=threshold,
|
||||
provider=provider,
|
||||
title=title,
|
||||
message=message,
|
||||
severity=severity,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
db.add(alert)
|
||||
logger.info(f"Created usage alert for {user_id}: {title}")
|
||||
@@ -1,250 +0,0 @@
|
||||
"""
|
||||
Historical usage aggregation functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import UsageSummary, UsageStatus
|
||||
|
||||
|
||||
# Shared provider mapping: DB column → frontend key
|
||||
PROVIDER_MAPPING = {
|
||||
'gemini_calls': 'gemini',
|
||||
'openai_calls': 'openai',
|
||||
'anthropic_calls': 'anthropic',
|
||||
'mistral_calls': 'huggingface', # HuggingFace stored as mistral
|
||||
'wavespeed_calls': 'wavespeed',
|
||||
'exa_calls': 'exa',
|
||||
'tavily_calls': 'tavily',
|
||||
'serper_calls': 'serper',
|
||||
'firecrawl_calls': 'firecrawl',
|
||||
'metaphor_calls': 'metaphor',
|
||||
'stability_calls': 'stability',
|
||||
'video_calls': 'video',
|
||||
'image_edit_calls': 'image_edit',
|
||||
'audio_calls': 'audio',
|
||||
}
|
||||
|
||||
|
||||
def _build_provider_breakdown(summaries: list, mapping: dict) -> dict:
|
||||
"""Build provider_breakdown dict from a list of UsageSummary records."""
|
||||
breakdown = {}
|
||||
for db_col, frontend_key in mapping.items():
|
||||
total = sum(getattr(s, db_col, 0) or 0 for s in summaries)
|
||||
breakdown[frontend_key] = {'calls': total, 'cost': 0, 'tokens': 0}
|
||||
return breakdown
|
||||
|
||||
|
||||
def _build_usage_percentages(provider_breakdown: dict, limits: dict) -> dict:
|
||||
"""Build usage_percentages dict from provider_breakdown and per-period limits."""
|
||||
pcts = {}
|
||||
if not limits or not limits.get('limits'):
|
||||
return pcts
|
||||
|
||||
limit_map = {
|
||||
'gemini_calls': ('gemini', 'gemini_calls'),
|
||||
'huggingface_calls': ('huggingface', 'mistral_calls'),
|
||||
'stability_calls': ('stability', 'stability_calls'),
|
||||
'video_calls': ('video', 'video_calls'),
|
||||
'audio_calls': ('audio', 'audio_calls'),
|
||||
'image_edit_calls': ('image_edit', 'image_edit_calls'),
|
||||
'wavespeed_calls': ('wavespeed', 'wavespeed_calls'),
|
||||
'tavily_calls': ('tavily', 'tavily_calls'),
|
||||
'serper_calls': ('serper', 'serper_calls'),
|
||||
'firecrawl_calls': ('firecrawl', 'firecrawl_calls'),
|
||||
'metaphor_calls': ('metaphor', 'metaphor_calls'),
|
||||
'exa_calls': ('exa', 'exa_calls'),
|
||||
}
|
||||
|
||||
for pct_key, (bk_key, limit_key) in limit_map.items():
|
||||
used = provider_breakdown.get(bk_key, {}).get('calls', 0)
|
||||
limit_val = limits.get('limits', {}).get(limit_key, 0) or 0
|
||||
if limit_val > 0:
|
||||
pcts[pct_key] = (used / limit_val) * 100
|
||||
|
||||
# Cost percentage
|
||||
total_cost = provider_breakdown.get('total_cost', 0)
|
||||
cost_limit = limits.get('limits', {}).get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
pcts['cost'] = (total_cost / cost_limit) * 100
|
||||
|
||||
return pcts
|
||||
|
||||
|
||||
def _summaries_usage_status(summaries: list) -> str:
|
||||
"""Derive overall usage_status from a list of summaries."""
|
||||
status = 'active'
|
||||
for s in summaries:
|
||||
try:
|
||||
st = s.usage_status.value
|
||||
except Exception:
|
||||
st = str(s.usage_status)
|
||||
if st == 'limit_reached':
|
||||
return 'limit_reached'
|
||||
if st == 'warning' and status != 'limit_reached':
|
||||
status = 'warning'
|
||||
return status
|
||||
|
||||
|
||||
def _empty_usage_response(billing_period: str, limits: dict) -> Dict[str, Any]:
|
||||
"""Return a zeroed UsageStats-shaped response."""
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': 'active',
|
||||
'total_calls': 0,
|
||||
'total_tokens': 0,
|
||||
'total_cost': 0.0,
|
||||
'avg_response_time': 0.0,
|
||||
'error_rate': 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': {},
|
||||
'usage_percentages': {},
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_all_historical_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
all_summaries = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
if not all_summaries:
|
||||
return _empty_usage_response('all', limits)
|
||||
|
||||
# Aggregate
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
provider_breakdown = _build_provider_breakdown(all_summaries, PROVIDER_MAPPING)
|
||||
|
||||
# Historical breakdown per period
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
'billing_period': s.billing_period,
|
||||
'total_calls': s.total_calls or 0,
|
||||
'total_tokens': s.total_tokens or 0,
|
||||
'total_cost': float(s.total_cost or 0),
|
||||
'usage_status': status_val,
|
||||
'updated_at': s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
return {
|
||||
'billing_period': 'all',
|
||||
'usage_status': _summaries_usage_status(all_summaries),
|
||||
'total_calls': total_calls,
|
||||
'total_tokens': total_tokens,
|
||||
'total_cost': round(total_cost, 2),
|
||||
'avg_response_time': round(avg_response_time, 2),
|
||||
'error_rate': round(error_rate, 2),
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': {}, # misleading for all-time vs per-period limits
|
||||
'historical_breakdown': historical_breakdown,
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_current_period_usage(user_id: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get current billing period usage data with correct per-period limit percentages.
|
||||
|
||||
Returns a UsageStats-shaped dict with provider_breakdown and usage_percentages
|
||||
computed against the plan's per-period limits.
|
||||
"""
|
||||
current_period = pricing_service.get_current_billing_period(user_id)
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
result = _empty_usage_response(current_period, limits)
|
||||
result['usage_percentages'] = _build_usage_percentages({}, limits)
|
||||
return result
|
||||
|
||||
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
|
||||
|
||||
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
|
||||
|
||||
try:
|
||||
status_val = summary.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(summary.usage_status)
|
||||
|
||||
return {
|
||||
'billing_period': current_period,
|
||||
'usage_status': status_val,
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': round(float(summary.total_cost or 0), 2),
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
|
||||
|
||||
def get_usage_for_period(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get usage data for a specific billing period.
|
||||
|
||||
Returns a UsageStats-shaped dict with that period's provider_breakdown
|
||||
and usage_percentages computed against plan limits.
|
||||
"""
|
||||
limits = pricing_service.get_user_limits(user_id)
|
||||
|
||||
summary = db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == billing_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
result = _empty_usage_response(billing_period, limits)
|
||||
result['usage_percentages'] = _build_usage_percentages({}, limits)
|
||||
return result
|
||||
|
||||
provider_breakdown = _build_provider_breakdown([summary], PROVIDER_MAPPING)
|
||||
usage_percentages = _build_usage_percentages(provider_breakdown, limits)
|
||||
|
||||
try:
|
||||
status_val = summary.usage_status.value
|
||||
except Exception:
|
||||
status_val = str(summary.usage_status)
|
||||
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': status_val,
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': round(float(summary.total_cost or 0), 2),
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'usage_percentages': usage_percentages,
|
||||
'historical_breakdown': [],
|
||||
'last_updated': datetime.now().isoformat()
|
||||
}
|
||||
@@ -1,38 +0,0 @@
|
||||
"""
|
||||
Usage limit enforcement functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Tuple, Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
from models.subscription_models import APIProvider
|
||||
from services.subscription.pricing_service import PricingService
|
||||
|
||||
|
||||
def enforce_usage_limits(user_id: str, provider: APIProvider,
|
||||
tokens_requested: int, db: Session,
|
||||
pricing_service: PricingService) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
# Check short-lived cache first (30s)
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
|
||||
# This would need access to self._enforce_cache
|
||||
# For now, keeping the structure
|
||||
|
||||
result = pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=tokens_requested
|
||||
)
|
||||
|
||||
# Cache the result
|
||||
# self._enforce_cache[cache_key] = {
|
||||
# 'result': result,
|
||||
# 'expires_at': now + timedelta(seconds=30)
|
||||
# }
|
||||
|
||||
return tuple(result)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
Usage statistics functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import UsageSummary, UsageStatus, APIProvider
|
||||
from services.subscription.usage_tracking_modules.historical_usage import get_all_historical_usage, get_usage_for_period
|
||||
|
||||
|
||||
def get_user_usage_stats(user_id: str, billing_period: str, db: Session, pricing_service) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user.
|
||||
When no billing_period is specified, returns ALL historical usage data.
|
||||
When a specific period is given, returns only that period's data."""
|
||||
|
||||
if not user_id:
|
||||
logger.error("get_user_usage_stats called without user_id")
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
# If no billing_period requested, return ALL historical data
|
||||
if not billing_period:
|
||||
return get_all_historical_usage(user_id, db, pricing_service)
|
||||
|
||||
# Return data for the specific billing period
|
||||
return get_usage_for_period(user_id, billing_period, db, pricing_service)
|
||||
@@ -1,18 +0,0 @@
|
||||
"""
|
||||
Usage trends functions.
|
||||
Extracted from usage_tracking_service.py for better maintainability.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def get_usage_trends(user_id: str, months: int, db: Session) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
from services.subscription.usage_tracking_helpers import build_billing_periods, query_usage_summaries, self_heal_summaries_from_logs, build_usage_trends_response
|
||||
|
||||
periods = build_billing_periods(months)
|
||||
summary_dict = query_usage_summaries(db, user_id, periods)
|
||||
self_heal_summaries_from_logs(db, user_id, periods, summary_dict)
|
||||
return build_usage_trends_response(periods, summary_dict)
|
||||
@@ -1,60 +1,41 @@
|
||||
"""
|
||||
Usage Tracking Service - Refactored into modular components.
|
||||
|
||||
This file now serves as a facade that delegates to specialized modules
|
||||
in the usage_tracking_modules package.
|
||||
|
||||
Modules:
|
||||
- historical_usage: Functions for aggregating historical usage data
|
||||
- usage_stats: Functions for getting user usage statistics
|
||||
- usage_trends: Functions for usage trend analysis
|
||||
- limit_enforcement: Functions for enforcing usage limits
|
||||
- alerts: Functions for usage alerts
|
||||
Usage Tracking Service
|
||||
Comprehensive tracking of API usage, costs, and subscription limits.
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, Tuple, Optional
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
from loguru import logger
|
||||
# Ensure Optional is available in global scope for dynamic imports
|
||||
from typing import Optional
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Any, List, Tuple
|
||||
from datetime import datetime, timedelta
|
||||
import time
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
from loguru import logger
|
||||
import json
|
||||
from api.subscription.cache import clear_dashboard_cache
|
||||
|
||||
from models.subscription_models import (
|
||||
APIProvider, UsageStatus, UserSubscription,
|
||||
UsageSummary, APIUsageLog, UsageAlert
|
||||
APIUsageLog, UsageSummary, APIProvider, UsageAlert,
|
||||
UserSubscription, UsageStatus
|
||||
)
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from services.subscription.provider_detection import detect_actual_provider
|
||||
from services.subscription.usage_tracking_helpers import (
|
||||
build_provider_breakdown,
|
||||
from .pricing_service import PricingService
|
||||
from .provider_detection import detect_actual_provider
|
||||
from .usage_tracking_helpers import (
|
||||
build_billing_periods,
|
||||
build_default_usage_percentages,
|
||||
build_empty_usage_response,
|
||||
build_provider_breakdown,
|
||||
build_usage_trends_response,
|
||||
calculate_final_total_cost,
|
||||
maybe_persist_reconciled_costs,
|
||||
build_usage_trends_response,
|
||||
build_billing_periods,
|
||||
query_usage_summaries,
|
||||
self_heal_summaries_from_logs,
|
||||
reset_usage_summary_counters,
|
||||
self_heal_summaries_from_logs,
|
||||
)
|
||||
# Import clear_dashboard_cache lazily to avoid circular import
|
||||
def _clear_dashboard_cache_for_user(user_id: str):
|
||||
from api.subscription.cache import clear_dashboard_cache as _clear
|
||||
return _clear(user_id)
|
||||
|
||||
from .usage_tracking_modules import (
|
||||
get_all_historical_usage,
|
||||
get_current_period_usage,
|
||||
get_usage_for_period,
|
||||
get_user_usage_stats,
|
||||
get_usage_trends,
|
||||
enforce_usage_limits,
|
||||
check_usage_alerts,
|
||||
create_usage_alert,
|
||||
)
|
||||
|
||||
|
||||
class UsageTrackingService:
|
||||
"""Service for tracking API usage and managing billing information."""
|
||||
"""Service for tracking API usage and managing subscription limits."""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
@@ -62,14 +43,13 @@ class UsageTrackingService:
|
||||
# TTL cache (30s) for enforcement results to cut DB chatter
|
||||
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
|
||||
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
|
||||
def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""Return authoritative billing period lookup keys. Always uses subscription period for consistency.
|
||||
Maintains backward compatibility with existing calendar-month data."""
|
||||
"""Return authoritative billing period lookup keys. Always uses calendar month for consistency."""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id
|
||||
).first()
|
||||
|
||||
|
||||
# If caller explicitly requested a billing period, use it
|
||||
if billing_period:
|
||||
return {
|
||||
@@ -78,125 +58,26 @@ class UsageTrackingService:
|
||||
"period_start": subscription.current_period_start if subscription else None,
|
||||
"period_end": subscription.current_period_end if subscription else None,
|
||||
}
|
||||
|
||||
# ALWAYS use current calendar month for billing period to ensure consistency
|
||||
# This prevents data loss when subscription spans month boundaries
|
||||
current_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get subscription period if available
|
||||
subscription_period = None
|
||||
if subscription and subscription.current_period_start:
|
||||
subscription_period = subscription.current_period_start.strftime("%Y-%m")
|
||||
|
||||
# Get calendar period
|
||||
calendar_period = datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Check which period has usage data
|
||||
from models.subscription_models import UsageSummary
|
||||
|
||||
if subscription_period:
|
||||
# Check if data exists for subscription period
|
||||
sub_data = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == subscription_period
|
||||
).first()
|
||||
|
||||
if sub_data:
|
||||
# Use subscription period (has data)
|
||||
return {
|
||||
"billing_period": subscription_period,
|
||||
"lookup_periods": [subscription_period],
|
||||
"period_start": subscription.current_period_start,
|
||||
"period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
# No data for subscription period, check calendar period (backward compatibility)
|
||||
if calendar_period != subscription_period:
|
||||
cal_data = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == calendar_period
|
||||
).first()
|
||||
|
||||
if cal_data:
|
||||
logger.info(f"Using calendar period {calendar_period} for backward compatibility (subscription period {subscription_period} has no data)")
|
||||
return {
|
||||
"billing_period": calendar_period,
|
||||
"lookup_periods": [calendar_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# No data in either period, use subscription period
|
||||
return {
|
||||
"billing_period": subscription_period,
|
||||
"lookup_periods": [subscription_period],
|
||||
"period_start": subscription.current_period_start,
|
||||
"period_end": subscription.current_period_end,
|
||||
}
|
||||
|
||||
# No subscription, check for any existing data
|
||||
latest_summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).first()
|
||||
|
||||
if latest_summary:
|
||||
logger.info(f"Using latest billing period from UsageSummary: {latest_summary.billing_period} for user {user_id}")
|
||||
return {
|
||||
"billing_period": latest_summary.billing_period,
|
||||
"lookup_periods": [latest_summary.billing_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
}
|
||||
|
||||
# Last fallback to calendar month for free tier / no subscription
|
||||
return {
|
||||
"billing_period": calendar_period,
|
||||
"lookup_periods": [calendar_period],
|
||||
"period_start": None,
|
||||
"period_end": None,
|
||||
"billing_period": current_period,
|
||||
"lookup_periods": [current_period],
|
||||
"period_start": subscription.current_period_start if subscription else None,
|
||||
"period_end": subscription.current_period_end if subscription else None,
|
||||
}
|
||||
|
||||
# Delegate to modular functions
|
||||
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user."""
|
||||
return get_user_usage_stats(user_id, billing_period, self.db, self.pricing_service)
|
||||
|
||||
def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get ALL historical usage data aggregated across all billing periods."""
|
||||
return get_all_historical_usage(user_id, self.db, self.pricing_service)
|
||||
|
||||
def get_current_period_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Get current billing period usage with correct per-period limit percentages."""
|
||||
return get_current_period_usage(user_id, self.db, self.pricing_service)
|
||||
|
||||
def get_usage_for_period(self, user_id: str, billing_period: str) -> Dict[str, Any]:
|
||||
"""Get usage for a specific billing period."""
|
||||
return get_usage_for_period(user_id, billing_period, self.db, self.pricing_service)
|
||||
|
||||
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
return get_usage_trends(user_id, months, self.db)
|
||||
|
||||
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
return enforce_usage_limits(user_id, provider, tokens_requested, self.db, self.pricing_service)
|
||||
|
||||
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
|
||||
"""Check if usage alerts should be sent."""
|
||||
check_usage_alerts(user_id, provider, billing_period, self.db, self.pricing_service)
|
||||
|
||||
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str):
|
||||
"""Create a usage alert."""
|
||||
create_usage_alert(user_id, provider, threshold, current_usage, limit, billing_period, self.db)
|
||||
|
||||
# Keep the track_api_usage method here as it's the core functionality
|
||||
async def track_api_usage(self, user_id: str, provider: APIProvider,
|
||||
endpoint: str, method: str, model_used: str = None,
|
||||
tokens_input: int = 0, tokens_output: int = 0,
|
||||
response_time: float = 0.0, status_code: int = 200,
|
||||
request_size: int = None, response_size: int = None,
|
||||
user_agent: str = None, ip_address: str = None,
|
||||
error_message: str = None, retry_count: int = 0,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
endpoint: str, method: str, model_used: str = None,
|
||||
tokens_input: int = 0, tokens_output: int = 0,
|
||||
response_time: float = 0.0, status_code: int = 200,
|
||||
request_size: int = None, response_size: int = None,
|
||||
user_agent: str = None, ip_address: str = None,
|
||||
error_message: str = None, retry_count: int = 0,
|
||||
**kwargs) -> Dict[str, Any]:
|
||||
"""Track an API usage event and update billing information."""
|
||||
|
||||
try:
|
||||
@@ -284,81 +165,394 @@ class UsageTrackingService:
|
||||
|
||||
# Invalidate dashboard cache so header stats update immediately
|
||||
try:
|
||||
_clear_dashboard_cache_for_user(user_id)
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.warning(f"Failed to clear dashboard cache: {cache_err}")
|
||||
logger.debug(f"Could not clear dashboard cache: {cache_err}")
|
||||
|
||||
logger.info(f"Tracked API usage: {user_id} -> {provider.value} -> ${cost_data['cost_total']:.6f}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"cost": cost_data['cost_total'],
|
||||
"tokens": (tokens_input or 0) + (tokens_output or 0),
|
||||
"billing_period": billing_period
|
||||
'usage_logged': True,
|
||||
'cost': cost_data['cost_total'],
|
||||
'tokens_used': (tokens_input or 0) + (tokens_output or 0),
|
||||
'billing_period': billing_period
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track API usage: {e}")
|
||||
logger.error(f"Error tracking API usage: {str(e)}")
|
||||
self.db.rollback()
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
'usage_logged': False,
|
||||
'error': str(e)
|
||||
}
|
||||
|
||||
async def _update_usage_summary(self, user_id: str, provider: APIProvider,
|
||||
tokens_used: int, cost: float,
|
||||
billing_period: str,
|
||||
response_time: float = 0.0,
|
||||
is_error: bool = False):
|
||||
"""Update or create usage summary for the billing period."""
|
||||
tokens_used: int, cost: float, billing_period: str,
|
||||
response_time: float, is_error: bool):
|
||||
"""Update the usage summary for a user."""
|
||||
|
||||
# Get or create summary
|
||||
# Get or create usage summary
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == billing_period
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.info(f"[UsageTracking] Creating new UsageSummary for user={user_id}, period={period_keys['billing_period']}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=billing_period,
|
||||
usage_status=UsageStatus.ACTIVE,
|
||||
total_calls=0,
|
||||
total_tokens=0,
|
||||
total_cost=0.0
|
||||
billing_period=period_keys["billing_period"]
|
||||
)
|
||||
self.db.add(summary)
|
||||
else:
|
||||
logger.debug(f"[UsageTracking] Found existing UsageSummary for user={user_id}, period={summary.billing_period}, calls={summary.total_calls}")
|
||||
|
||||
# Update counts
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_used
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
|
||||
# Update provider-specific counts
|
||||
# Update provider-specific counters
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
setattr(summary, f"{provider_name}_calls", current_calls + 1)
|
||||
|
||||
# Update provider-specific tokens
|
||||
tokens_attr = f"{provider_name}_tokens"
|
||||
if hasattr(summary, tokens_attr):
|
||||
current_tokens = getattr(summary, tokens_attr, 0) or 0
|
||||
setattr(summary, tokens_attr, current_tokens + tokens_used)
|
||||
# Update token usage for LLM providers
|
||||
if provider in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL, APIProvider.WAVESPEED]:
|
||||
current_tokens = getattr(summary, f"{provider_name}_tokens", 0)
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens + tokens_used)
|
||||
|
||||
# Update provider-specific cost
|
||||
cost_attr = f"{provider_name}_cost"
|
||||
if hasattr(summary, cost_attr):
|
||||
current_cost = getattr(summary, cost_attr, 0.0) or 0.0
|
||||
setattr(summary, cost_attr, current_cost + cost)
|
||||
# Update cost
|
||||
current_cost = getattr(summary, f"{provider_name}_cost", 0.0)
|
||||
setattr(summary, f"{provider_name}_cost", current_cost + cost)
|
||||
|
||||
# Update response time (rolling average)
|
||||
if response_time > 0:
|
||||
current_avg = summary.avg_response_time or 0.0
|
||||
current_calls = summary.total_calls or 1
|
||||
summary.avg_response_time = ((current_avg * (current_calls - 1)) + response_time) / current_calls
|
||||
# Update totals
|
||||
summary.total_calls += 1
|
||||
summary.total_tokens += tokens_used
|
||||
summary.total_cost += cost
|
||||
|
||||
# Update error rate
|
||||
if is_error:
|
||||
summary.error_count = (summary.error_count or 0) + 1
|
||||
total_calls = summary.total_calls or 1
|
||||
summary.error_rate = (summary.error_count / total_calls) * 100
|
||||
# Update performance metrics
|
||||
if summary.total_calls > 0:
|
||||
# Update average response time
|
||||
total_response_time = summary.avg_response_time * (summary.total_calls - 1) + response_time
|
||||
summary.avg_response_time = total_response_time / summary.total_calls
|
||||
|
||||
# Update error rate
|
||||
if is_error:
|
||||
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100) + 1
|
||||
summary.error_rate = (error_count / summary.total_calls) * 100
|
||||
else:
|
||||
error_count = int(summary.error_rate * (summary.total_calls - 1) / 100)
|
||||
summary.error_rate = (error_count / summary.total_calls) * 100
|
||||
|
||||
# Update usage status based on limits
|
||||
await self._update_usage_status(summary)
|
||||
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
async def _update_usage_status(self, summary: UsageSummary):
|
||||
"""Update usage status based on subscription limits."""
|
||||
|
||||
limits = self.pricing_service.get_user_limits(summary.user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check various limits and determine status
|
||||
max_usage_percentage = 0.0
|
||||
|
||||
# Check cost limit
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0)
|
||||
if cost_limit > 0:
|
||||
cost_usage_pct = (summary.total_cost / cost_limit) * 100
|
||||
max_usage_percentage = max(max_usage_percentage, cost_usage_pct)
|
||||
|
||||
# Check call limits for each provider
|
||||
for provider in APIProvider:
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
call_usage_pct = (current_calls / call_limit) * 100
|
||||
max_usage_percentage = max(max_usage_percentage, call_usage_pct)
|
||||
|
||||
# Update status based on highest usage percentage
|
||||
if max_usage_percentage >= 100:
|
||||
summary.usage_status = UsageStatus.LIMIT_REACHED
|
||||
elif max_usage_percentage >= 80:
|
||||
summary.usage_status = UsageStatus.WARNING
|
||||
else:
|
||||
summary.usage_status = UsageStatus.ACTIVE
|
||||
|
||||
async def _check_usage_alerts(self, user_id: str, provider: APIProvider, billing_period: str):
|
||||
"""Check if usage alerts should be sent."""
|
||||
|
||||
# Get current usage
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
if not limits:
|
||||
return
|
||||
|
||||
# Check for alert thresholds (80%, 90%, 100%)
|
||||
thresholds = [80, 90, 100]
|
||||
|
||||
for threshold in thresholds:
|
||||
# Check if alert already sent for this threshold
|
||||
existing_alert = self.db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.threshold_percentage == threshold,
|
||||
UsageAlert.provider == provider,
|
||||
UsageAlert.is_sent == True
|
||||
).first()
|
||||
|
||||
if existing_alert:
|
||||
continue
|
||||
|
||||
# Check if threshold is reached
|
||||
provider_name = provider.value
|
||||
current_calls = getattr(summary, f"{provider_name}_calls", 0)
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0)
|
||||
|
||||
if call_limit > 0:
|
||||
usage_percentage = (current_calls / call_limit) * 100
|
||||
|
||||
if usage_percentage >= threshold:
|
||||
await self._create_usage_alert(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
threshold=threshold,
|
||||
current_usage=current_calls,
|
||||
limit=call_limit,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
async def _create_usage_alert(self, user_id: str, provider: APIProvider,
|
||||
threshold: int, current_usage: int, limit: int,
|
||||
billing_period: str):
|
||||
"""Create a usage alert."""
|
||||
|
||||
# Determine alert type and severity
|
||||
if threshold >= 100:
|
||||
alert_type = "limit_reached"
|
||||
severity = "error"
|
||||
title = f"API Limit Reached - {provider.value.title()}"
|
||||
message = f"You have reached your {provider.value} API limit of {limit:,} calls for this billing period."
|
||||
elif threshold >= 90:
|
||||
alert_type = "usage_warning"
|
||||
severity = "warning"
|
||||
title = f"API Usage Warning - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
else:
|
||||
alert_type = "usage_warning"
|
||||
severity = "info"
|
||||
title = f"API Usage Notice - {provider.value.title()}"
|
||||
message = f"You have used {current_usage:,} of {limit:,} {provider.value} API calls ({threshold}% of your limit)."
|
||||
|
||||
alert = UsageAlert(
|
||||
user_id=user_id,
|
||||
alert_type=alert_type,
|
||||
threshold_percentage=threshold,
|
||||
provider=provider,
|
||||
title=title,
|
||||
message=message,
|
||||
severity=severity,
|
||||
billing_period=billing_period
|
||||
)
|
||||
|
||||
self.db.add(alert)
|
||||
logger.info(f"Created usage alert for {user_id}: {title}")
|
||||
|
||||
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
||||
"""Get comprehensive usage statistics for a user."""
|
||||
|
||||
if not user_id:
|
||||
logger.error("get_user_usage_stats called without user_id")
|
||||
raise ValueError("user_id is required")
|
||||
|
||||
requested_billing_period = billing_period
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period)
|
||||
billing_period = period_keys["billing_period"]
|
||||
|
||||
logger.debug(f"[get_user_usage_stats] user={user_id}, billing_period={billing_period}, lookup_periods={period_keys['lookup_periods']}")
|
||||
|
||||
# Get usage summary
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if summary:
|
||||
logger.debug(f"[get_user_usage_stats] Found summary: period={summary.billing_period}, calls={summary.total_calls}, cost={summary.total_cost}")
|
||||
else:
|
||||
logger.debug(f"[get_user_usage_stats] No summary found for user={user_id}, period={billing_period}")
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Get recent alerts
|
||||
alerts = self.db.query(UsageAlert).filter(
|
||||
UsageAlert.user_id == user_id,
|
||||
UsageAlert.billing_period == billing_period,
|
||||
UsageAlert.is_read == False
|
||||
).order_by(UsageAlert.created_at.desc()).limit(10).all()
|
||||
|
||||
if not summary:
|
||||
# If no summary exists for current period, we should initialize it
|
||||
# This handles the "start of month" case where a user logs in but hasn't made calls yet
|
||||
if not requested_billing_period:
|
||||
logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=billing_period,
|
||||
usage_status=UsageStatus.ACTIVE,
|
||||
total_calls=0,
|
||||
total_tokens=0,
|
||||
total_cost=0.0
|
||||
)
|
||||
try:
|
||||
self.db.add(summary)
|
||||
self.db.commit()
|
||||
self.db.refresh(summary)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize summary: {e}")
|
||||
self.db.rollback()
|
||||
# Fallback to zero-struct return if DB write fails
|
||||
pass
|
||||
|
||||
if not summary: # Still no summary after attempt
|
||||
return build_empty_usage_response(
|
||||
billing_period=billing_period,
|
||||
limits=limits,
|
||||
providers=APIProvider,
|
||||
)
|
||||
|
||||
# Provider breakdown - calculate costs first, then use for percentages
|
||||
# Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum)
|
||||
provider_breakdown, resolved_costs, core_counts = build_provider_breakdown(
|
||||
db=self.db,
|
||||
user_id=user_id,
|
||||
billing_period=billing_period,
|
||||
summary=summary,
|
||||
)
|
||||
|
||||
summary_total_cost = summary.total_cost or 0.0
|
||||
calculated_total_cost, final_total_cost = calculate_final_total_cost(
|
||||
summary_total_cost=summary_total_cost,
|
||||
resolved_costs=resolved_costs,
|
||||
)
|
||||
|
||||
maybe_persist_reconciled_costs(
|
||||
db=self.db,
|
||||
summary=summary,
|
||||
summary_total_cost=summary_total_cost,
|
||||
calculated_total_cost=calculated_total_cost,
|
||||
final_total_cost=final_total_cost,
|
||||
resolved_costs=resolved_costs,
|
||||
)
|
||||
|
||||
# Calculate usage percentages - only for Gemini and HuggingFace
|
||||
# Use the calculated costs for accurate percentages
|
||||
usage_percentages = build_default_usage_percentages(APIProvider)
|
||||
if limits:
|
||||
# Gemini
|
||||
gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0
|
||||
if gemini_call_limit > 0:
|
||||
usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100
|
||||
|
||||
# HuggingFace (stored as mistral in database)
|
||||
mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0
|
||||
if mistral_call_limit > 0:
|
||||
usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100
|
||||
|
||||
# Cost usage percentage - use final_total_cost (calculated from logs if needed)
|
||||
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
|
||||
if cost_limit > 0:
|
||||
usage_percentages['cost'] = (final_total_cost / cost_limit) * 100
|
||||
|
||||
return {
|
||||
'billing_period': billing_period,
|
||||
'usage_status': summary.usage_status.value if hasattr(summary.usage_status, 'value') else str(summary.usage_status),
|
||||
'total_calls': summary.total_calls or 0,
|
||||
'total_tokens': summary.total_tokens or 0,
|
||||
'total_cost': final_total_cost,
|
||||
'avg_response_time': summary.avg_response_time or 0.0,
|
||||
'error_rate': summary.error_rate or 0.0,
|
||||
'limits': limits,
|
||||
'provider_breakdown': provider_breakdown,
|
||||
'alerts': [
|
||||
{
|
||||
'id': alert.id,
|
||||
'type': alert.alert_type,
|
||||
'title': alert.title,
|
||||
'message': alert.message,
|
||||
'severity': alert.severity,
|
||||
'created_at': alert.created_at.isoformat()
|
||||
}
|
||||
for alert in alerts
|
||||
],
|
||||
'usage_percentages': usage_percentages,
|
||||
'last_updated': summary.updated_at.isoformat()
|
||||
}
|
||||
|
||||
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
|
||||
"""Get usage trends over time with self-healing from logs."""
|
||||
periods = build_billing_periods(months)
|
||||
summary_dict = query_usage_summaries(self.db, user_id, periods)
|
||||
self_heal_summaries_from_logs(self.db, user_id, periods, summary_dict)
|
||||
return build_usage_trends_response(periods, summary_dict)
|
||||
|
||||
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
|
||||
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
|
||||
"""Enforce usage limits before making an API call."""
|
||||
# Check short-lived cache first (30s)
|
||||
cache_key = f"{user_id}:{provider.value}"
|
||||
now = datetime.utcnow()
|
||||
cached = self._enforce_cache.get(cache_key)
|
||||
if cached and cached.get('expires_at') and cached['expires_at'] > now:
|
||||
return tuple(cached['result']) # type: ignore
|
||||
|
||||
result = self.pricing_service.check_usage_limits(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
tokens_requested=tokens_requested
|
||||
)
|
||||
self._enforce_cache[cache_key] = {
|
||||
'result': result,
|
||||
'expires_at': now + timedelta(seconds=30)
|
||||
}
|
||||
return result
|
||||
|
||||
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
||||
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
|
||||
period_keys = self._get_authoritative_billing_period_keys(user_id)
|
||||
billing_period = period_keys["billing_period"]
|
||||
summary = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
return {"reset": False, "reason": "no_summary"}
|
||||
|
||||
try:
|
||||
reset_usage_summary_counters(summary)
|
||||
self.db.commit()
|
||||
|
||||
# Invalidate dashboard cache so header stats update after reset
|
||||
try:
|
||||
clear_dashboard_cache(user_id)
|
||||
except Exception as cache_err:
|
||||
logger.debug(f"Could not clear dashboard cache: {cache_err}")
|
||||
|
||||
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
|
||||
return {"reset": True, "counters_reset": True}
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
logger.error(f"Error resetting usage status: {e}")
|
||||
return {"reset": False, "error": str(e)}
|
||||
|
||||
@@ -2,8 +2,9 @@ import os
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
from dataclasses import dataclass
|
||||
import httpx
|
||||
import requests
|
||||
from loguru import logger
|
||||
import time
|
||||
import random
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
@@ -60,26 +61,30 @@ class WritingAssistantService:
|
||||
logger.info(f"Writing assistant API call #{self.daily_api_calls}/{self.daily_limit} today")
|
||||
return True
|
||||
|
||||
async def suggest(self, text: str, user_id: str | None = None) -> List[WritingSuggestion]:
|
||||
async def suggest(self, text: str, max_results: int = 1) -> List[WritingSuggestion]:
|
||||
if not text or len(text.strip()) < 6:
|
||||
return []
|
||||
|
||||
# COST OPTIMIZATION: Use cached/static suggestions for common patterns
|
||||
# This reduces API calls by 90%+ while maintaining usefulness
|
||||
cached_suggestion = self._get_cached_suggestion(text)
|
||||
if cached_suggestion:
|
||||
return [cached_suggestion]
|
||||
|
||||
# COST CONTROL: Check daily usage limits
|
||||
if not self._check_daily_limit():
|
||||
logger.warning("Daily API limit reached for writing assistant")
|
||||
return []
|
||||
|
||||
if len(text.strip()) < 50:
|
||||
# Only make expensive API calls for unique, substantial content
|
||||
if len(text.strip()) < 50: # Skip API calls for very short text
|
||||
return []
|
||||
|
||||
# 1) Find relevant sources via Exa
|
||||
# 1) Find relevant sources via Exa (reduced results for cost)
|
||||
sources = await self._search_sources(text)
|
||||
|
||||
# 2) Generate continuation suggestion via LLM grounded in sources
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources, user_id=user_id)
|
||||
# 2) Generate continuation suggestion via Gemini
|
||||
suggestion_text, confidence = await self._generate_continuation(text, sources)
|
||||
|
||||
if not suggestion_text:
|
||||
return []
|
||||
@@ -105,12 +110,12 @@ class WritingAssistantService:
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.http_timeout_seconds) as client:
|
||||
resp = await client.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
resp = requests.post(
|
||||
"https://api.exa.ai/search",
|
||||
headers={"x-api-key": self.exa_api_key, "Content-Type": "application/json"},
|
||||
json=payload,
|
||||
timeout=self.http_timeout_seconds,
|
||||
)
|
||||
if resp.status_code != 200:
|
||||
raise Exception(f"Exa error {resp.status_code}: {resp.text}")
|
||||
data = resp.json()
|
||||
@@ -135,7 +140,8 @@ class WritingAssistantService:
|
||||
logger.error(f"WritingAssistant _search_sources error: {e}")
|
||||
raise
|
||||
|
||||
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]], user_id: str | None = None) -> tuple[str, float]:
|
||||
async def _generate_continuation(self, text: str, sources: List[Dict[str, Any]]) -> tuple[str, float]:
|
||||
# Build compact sources context block
|
||||
source_blocks: List[str] = []
|
||||
for i, s in enumerate(sources[:5]):
|
||||
excerpt = (s.get("text", "") or "")
|
||||
@@ -143,14 +149,16 @@ class WritingAssistantService:
|
||||
source_blocks.append(
|
||||
f"Source {i+1}: {s.get('title','') or 'Source'}\nURL: {s.get('url','')}\nExcerpt: {excerpt}"
|
||||
)
|
||||
sources_text = "\n\n".join(source_blocks)
|
||||
sources_text = "\n\n".join(source_blocks) if source_blocks else "(No sources)"
|
||||
|
||||
# Provider-agnostic behavior: short continuation with one inline citation hint
|
||||
system_prompt = (
|
||||
"You are an assistive writing continuation bot. "
|
||||
"Only produce 1-2 SHORT sentences. Do not repeat or paraphrase the user's stub. "
|
||||
"Match tone and topic. Prefer concrete, current facts from the provided sources. "
|
||||
"Include exactly one brief citation hint in parentheses with an author (or 'Source') and URL in square brackets, e.g., ((Doe, 2021)[https://example.com])."
|
||||
)
|
||||
|
||||
user_prompt = (
|
||||
f"User text to continue (do not repeat):\n{text}\n\n"
|
||||
f"Relevant sources to inform your continuation:\n{sources_text}\n\n"
|
||||
@@ -158,13 +166,13 @@ class WritingAssistantService:
|
||||
)
|
||||
|
||||
try:
|
||||
await asyncio.sleep(random.uniform(0.05, 0.15))
|
||||
# Inter-call jitter to reduce burst rate limits
|
||||
time.sleep(random.uniform(0.05, 0.15))
|
||||
|
||||
ai_resp = llm_text_gen(
|
||||
prompt=user_prompt,
|
||||
json_struct=None,
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id,
|
||||
)
|
||||
if isinstance(ai_resp, dict) and ai_resp.get("text"):
|
||||
suggestion = (ai_resp.get("text", "") or "").strip()
|
||||
@@ -172,10 +180,12 @@ class WritingAssistantService:
|
||||
suggestion = (str(ai_resp or "")).strip()
|
||||
if not suggestion:
|
||||
raise Exception("Assistive writer returned empty suggestion")
|
||||
confidence = 0.7
|
||||
# naive confidence from number of sources present
|
||||
confidence = 0.7 if sources else 0.5
|
||||
return suggestion, confidence
|
||||
except Exception as e:
|
||||
logger.error(f"WritingAssistant _generate_continuation error: {e}")
|
||||
# Propagate to ensure frontend does not show stale/generic content
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ def should_bootstrap_linguistic_models() -> bool:
|
||||
}
|
||||
|
||||
# Check if any linguistic-required feature is enabled
|
||||
linguistic_features = {"content_planning", "facebook", "linkedin", "blog_writer", "persona"}
|
||||
linguistic_features = {"content_planning", "facebook", "linkedin", "blog-writer", "persona"}
|
||||
return bool(enabled_features & linguistic_features)
|
||||
|
||||
|
||||
@@ -287,16 +287,12 @@ from alwrity_utils import (
|
||||
def start_backend(enable_reload=False, production_mode=False):
|
||||
"""Start the backend server."""
|
||||
print("==> Starting ALwrity Backend...")
|
||||
# Check for legacy podcast-only demo mode env vars (backward compat)
|
||||
is_legacy_podcast_mode = os.getenv("ALWRITY_PODCAST_ONLY_DEMO_MODE", os.getenv("PODCAST_ONLY_DEMO_MODE", "false")).lower() in {"1", "true", "yes", "on"}
|
||||
enabled = get_enabled_features()
|
||||
is_feature_limited = "all" not in enabled
|
||||
podcast_only_demo_mode = os.getenv("ALWRITY_PODCAST_ONLY_DEMO_MODE", os.getenv("PODCAST_ONLY_DEMO_MODE", "false")).lower() in {"1", "true", "yes", "on"}
|
||||
|
||||
if is_legacy_podcast_mode or is_feature_limited:
|
||||
mode_label = "legacy podcast-only" if is_legacy_podcast_mode else f"feature-limited ({', '.join(sorted(enabled))})"
|
||||
print(f"\n{'=' * 60}")
|
||||
print(f"==> {mode_label.upper()} MODE ACTIVE")
|
||||
print(" Non-matching router groups are intentionally skipped.")
|
||||
if podcast_only_demo_mode:
|
||||
print("\n" + "=" * 60)
|
||||
print("==> PODCAST-ONLY DEMO MODE ACTIVE")
|
||||
print(" Non-podcast router groups are intentionally skipped.")
|
||||
print("=" * 60)
|
||||
|
||||
# Set host based on environment and mode
|
||||
@@ -389,12 +385,12 @@ def start_backend(enable_reload=False, production_mode=False):
|
||||
print(f"[DEBUG] Starting uvicorn with host={host} port={port}", flush=True)
|
||||
print("[DEBUG] >>> ABOUT TO CALL UVICORN.RUN() <<<", flush=True)
|
||||
|
||||
# Skip video preflight in feature-limited mode to save memory/time
|
||||
is_feature_limited = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() not in ("", "all")
|
||||
print(f"[DEBUG] Feature-limited mode check: {is_feature_limited}", flush=True)
|
||||
# Skip video preflight in podcast-only mode to save memory/time
|
||||
is_podcast = os.getenv("ALWRITY_ENABLED_FEATURES", "").strip().lower() == "podcast"
|
||||
print(f"[DEBUG] Podcast mode check: {is_podcast}", flush=True)
|
||||
|
||||
if is_feature_limited:
|
||||
print("[DEBUG] Feature-limited mode - skipping video preflight", flush=True)
|
||||
if is_podcast:
|
||||
print("[DEBUG] Podcast mode - skipping video preflight", flush=True)
|
||||
else:
|
||||
# Log diagnostics and assert versions (fail fast if misconfigured)
|
||||
try:
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
def _get_all_historical_usage(self, user_id: str) -> Dict[str, Any]:
|
||||
\ \\Get ALL historical usage data aggregated across all billing periods.\\\
|
||||
|
||||
# Get all usage summaries for the user
|
||||
all_summaries = self.db.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id
|
||||
).order_by(UsageSummary.billing_period.desc()).all()
|
||||
|
||||
if not all_summaries:
|
||||
return {
|
||||
\billing_period\: \all\,
|
||||
\usage_status\: \active\,
|
||||
\total_calls\: 0,
|
||||
\total_tokens\: 0,
|
||||
\total_cost\: 0.0,
|
||||
\avg_response_time\: 0.0,
|
||||
\error_rate\: 0.0,
|
||||
\limits\: self.pricing_service.get_user_limits(user_id),
|
||||
\provider_breakdown\: {},
|
||||
\usage_percentages\: {},
|
||||
\historical_breakdown\: [],
|
||||
\last_updated\: datetime.now().isoformat()
|
||||
}
|
||||
|
||||
# Aggregate all data
|
||||
total_calls = sum(s.total_calls or 0 for s in all_summaries)
|
||||
total_tokens = sum(s.total_tokens or 0 for s in all_summaries)
|
||||
total_cost = sum(float(s.total_cost or 0) for s in all_summaries)
|
||||
|
||||
# Calculate weighted average response time
|
||||
total_weighted_time = sum((s.avg_response_time or 0) * (s.total_calls or 0) for s in all_summaries)
|
||||
avg_response_time = total_weighted_time / total_calls if total_calls > 0 else 0.0
|
||||
|
||||
# Calculate overall error rate
|
||||
total_errors = sum((s.total_calls or 0) * (s.error_rate or 0) / 100 for s in all_summaries)
|
||||
error_rate = (total_errors / total_calls * 100) if total_calls > 0 else 0.0
|
||||
|
||||
# Get user limits
|
||||
limits = self.pricing_service.get_user_limits(user_id)
|
||||
|
||||
# Build historical breakdown
|
||||
historical_breakdown = []
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status_val = s.usage_status.value
|
||||
except:
|
||||
status_val = str(s.usage_status)
|
||||
historical_breakdown.append({
|
||||
\billing_period\: s.billing_period,
|
||||
\total_calls\: s.total_calls or 0,
|
||||
\total_tokens\: s.total_tokens or 0,
|
||||
\total_cost\: float(s.total_cost or 0),
|
||||
\usage_status\: status_val,
|
||||
\updated_at\: s.updated_at.isoformat() if s.updated_at else None
|
||||
})
|
||||
|
||||
# Determine overall status
|
||||
usage_status = \active\
|
||||
for s in all_summaries:
|
||||
try:
|
||||
status = s.usage_status.value
|
||||
except:
|
||||
status = str(s.usage_status)
|
||||
if status == \limit_reached\:
|
||||
usage_status = \limit_reached\
|
||||
break
|
||||
elif status == \warning\ and usage_status != \limit_reached\:
|
||||
usage_status = \warning\
|
||||
|
||||
return {
|
||||
\billing_period\: \all\,
|
||||
\usage_status\: usage_status,
|
||||
\total_calls\: total_calls,
|
||||
\total_tokens\: total_tokens,
|
||||
\total_cost\: round(total_cost, 2),
|
||||
\avg_response_time\: round(avg_response_time, 2),
|
||||
\error_rate\: round(error_rate, 2),
|
||||
\limits\: limits,
|
||||
\provider_breakdown\: {},
|
||||
\usage_percentages\: {},
|
||||
\historical_breakdown\: historical_breakdown,
|
||||
\last_updated\: datetime.now().isoformat()
|
||||
}
|
||||
@@ -1,82 +1,72 @@
|
||||
import React, { useState, useEffect, Suspense } from 'react';
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { BrowserRouter as Router, Routes, Route, Navigate } from 'react-router-dom';
|
||||
import { Box, CircularProgress, Typography } from '@mui/material';
|
||||
import { ClerkProvider, useAuth } from '@clerk/clerk-react';
|
||||
import Wizard from './components/OnboardingWizard/Wizard';
|
||||
import MainDashboard from './components/MainDashboard/MainDashboard';
|
||||
import SEODashboard from './components/SEODashboard/SEODashboard';
|
||||
import ContentPlanningDashboard from './components/ContentPlanningDashboard/ContentPlanningDashboard';
|
||||
import FacebookWriter from './components/FacebookWriter/FacebookWriter';
|
||||
import LinkedInWriter from './components/LinkedInWriter/LinkedInWriter';
|
||||
import BlogWriter from './components/BlogWriter/BlogWriter';
|
||||
import StoryWriter from './components/StoryWriter/StoryWriter';
|
||||
import { StoryProjectList } from './components/StoryWriter/StoryProjectList';
|
||||
import YouTubeCreator from './components/YouTubeCreator/YouTubeCreator';
|
||||
import { CreateStudio, EditStudio, UpscaleStudio, ControlStudio, SocialOptimizer, AssetLibrary, ImageStudioDashboard, FaceSwapStudio, CompressionStudio, ImageProcessingStudio } from './components/ImageStudio';
|
||||
import {
|
||||
VideoStudioDashboard,
|
||||
CreateVideo,
|
||||
AvatarVideo,
|
||||
EnhanceVideo,
|
||||
ExtendVideo,
|
||||
EditVideo,
|
||||
TransformVideo,
|
||||
SocialVideo,
|
||||
FaceSwap,
|
||||
VideoTranslate,
|
||||
VideoBackgroundRemover,
|
||||
AddAudioToVideo,
|
||||
LibraryVideo,
|
||||
} from './components/VideoStudio';
|
||||
import {
|
||||
ProductMarketingDashboard,
|
||||
ProductPhotoshootStudio,
|
||||
ProductAnimationStudio,
|
||||
ProductVideoStudio,
|
||||
ProductAvatarStudio,
|
||||
} from './components/ProductMarketing';
|
||||
import PodcastDashboard from './components/PodcastMaker/PodcastDashboard';
|
||||
import PricingPage from './components/Pricing/PricingPage';
|
||||
import WixTestPage from './components/WixTestPage/WixTestPage';
|
||||
import WixCallbackPage from './components/WixCallbackPage/WixCallbackPage';
|
||||
import WordPressCallbackPage from './components/WordPressCallbackPage/WordPressCallbackPage';
|
||||
import BingCallbackPage from './components/BingCallbackPage/BingCallbackPage';
|
||||
import BingAnalyticsStorage from './components/BingAnalyticsStorage/BingAnalyticsStorage';
|
||||
import ResearchDashboard from './pages/ResearchDashboard';
|
||||
import IntentResearchTest from './pages/IntentResearchTest';
|
||||
import SchedulerDashboard from './pages/SchedulerDashboard';
|
||||
import BillingPage from './pages/BillingPage';
|
||||
import ApprovalsPage from './pages/ApprovalsPage';
|
||||
import TeamActivityPage from './pages/TeamActivityPage';
|
||||
import StripeDisputesDashboard from './pages/StripeDisputesDashboard';
|
||||
import ProtectedRoute from './components/shared/ProtectedRoute';
|
||||
import GSCAuthCallback from './components/SEODashboard/components/GSCAuthCallback';
|
||||
import Landing from './components/Landing/Landing';
|
||||
import ErrorBoundary from './components/shared/ErrorBoundary';
|
||||
import ErrorBoundaryTest from './components/shared/ErrorBoundaryTest';
|
||||
import { OnboardingProvider } from './contexts/OnboardingContext';
|
||||
import { SubscriptionProvider } from './contexts/SubscriptionContext';
|
||||
import InitialRouteHandler from './components/App/InitialRouteHandler';
|
||||
import TokenInstaller from './components/App/TokenInstaller';
|
||||
import { ConditionalCopilotKit, AuthenticatedCopilotWrapper } from './components/App/CopilotWrappers';
|
||||
import Landing from './components/Landing/Landing';
|
||||
import LazyLoadingFallback from './components/shared/LazyLoadingFallback';
|
||||
import FeatureRoute from './components/shared/FeatureRoute';
|
||||
|
||||
// ─── Lazy loaded route components ───────────────────────────────────────────
|
||||
// Default exports
|
||||
const Wizard = React.lazy(() => import('./components/OnboardingWizard/Wizard'));
|
||||
const MainDashboard = React.lazy(() => import('./components/MainDashboard/MainDashboard'));
|
||||
const SEODashboard = React.lazy(() => import('./components/SEODashboard/SEODashboard'));
|
||||
const ContentPlanningDashboard = React.lazy(() => import('./components/ContentPlanningDashboard/ContentPlanningDashboard'));
|
||||
const FacebookWriter = React.lazy(() => import('./components/FacebookWriter/FacebookWriter'));
|
||||
const LinkedInWriter = React.lazy(() => import('./components/LinkedInWriter/LinkedInWriter'));
|
||||
const BlogWriter = React.lazy(() => import('./components/BlogWriter/BlogWriter'));
|
||||
const StoryWriter = React.lazy(() => import('./components/StoryWriter/StoryWriter'));
|
||||
const YouTubeCreator = React.lazy(() => import('./components/YouTubeCreator/YouTubeCreator'));
|
||||
const PodcastDashboard = React.lazy(() => import('./components/PodcastMaker/PodcastDashboard'));
|
||||
const PricingPage = React.lazy(() => import('./components/Pricing/PricingPage'));
|
||||
const WixTestPage = React.lazy(() => import('./components/WixTestPage/WixTestPage'));
|
||||
const WixCallbackPage = React.lazy(() => import('./components/WixCallbackPage/WixCallbackPage'));
|
||||
const WordPressCallbackPage = React.lazy(() => import('./components/WordPressCallbackPage/WordPressCallbackPage'));
|
||||
const BingCallbackPage = React.lazy(() => import('./components/BingCallbackPage/BingCallbackPage'));
|
||||
const BingAnalyticsStorage = React.lazy(() => import('./components/BingAnalyticsStorage/BingAnalyticsStorage'));
|
||||
const ResearchDashboard = React.lazy(() => import('./pages/ResearchDashboard'));
|
||||
const IntentResearchTest = React.lazy(() => import('./pages/IntentResearchTest'));
|
||||
const SchedulerDashboard = React.lazy(() => import('./pages/SchedulerDashboard'));
|
||||
const BillingPage = React.lazy(() => import('./pages/BillingPage'));
|
||||
const ApprovalsPage = React.lazy(() => import('./pages/ApprovalsPage'));
|
||||
const TeamActivityPage = React.lazy(() => import('./pages/TeamActivityPage'));
|
||||
const StripeDisputesDashboard = React.lazy(() => import('./pages/StripeDisputesDashboard'));
|
||||
const GSCAuthCallback = React.lazy(() => import('./components/SEODashboard/components/GSCAuthCallback'));
|
||||
const ErrorBoundaryTest = React.lazy(() => import('./components/shared/ErrorBoundaryTest'));
|
||||
|
||||
// Named exports — need .then() wrapper to resolve default
|
||||
const StoryProjectList = React.lazy(() => import('./components/StoryWriter/StoryProjectList').then(m => ({ default: m.StoryProjectList })));
|
||||
|
||||
// ImageStudio barrel (10 named exports)
|
||||
const CreateStudio = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.CreateStudio })));
|
||||
const EditStudio = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.EditStudio })));
|
||||
const UpscaleStudio = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.UpscaleStudio })));
|
||||
const ControlStudio = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.ControlStudio })));
|
||||
const SocialOptimizer = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.SocialOptimizer })));
|
||||
const AssetLibrary = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.AssetLibrary })));
|
||||
const ImageStudioDashboard = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.ImageStudioDashboard })));
|
||||
const FaceSwapStudio = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.FaceSwapStudio })));
|
||||
const CompressionStudio = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.CompressionStudio })));
|
||||
const ImageProcessingStudio = React.lazy(() => import('./components/ImageStudio').then(m => ({ default: m.ImageProcessingStudio })));
|
||||
|
||||
// VideoStudio barrel (13 named exports)
|
||||
const VideoStudioDashboard = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.VideoStudioDashboard })));
|
||||
const CreateVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.CreateVideo })));
|
||||
const AvatarVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.AvatarVideo })));
|
||||
const EnhanceVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.EnhanceVideo })));
|
||||
const ExtendVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.ExtendVideo })));
|
||||
const EditVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.EditVideo })));
|
||||
const TransformVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.TransformVideo })));
|
||||
const SocialVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.SocialVideo })));
|
||||
const FaceSwap = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.FaceSwap })));
|
||||
const VideoTranslate = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.VideoTranslate })));
|
||||
const VideoBackgroundRemover = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.VideoBackgroundRemover })));
|
||||
const AddAudioToVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.AddAudioToVideo })));
|
||||
const LibraryVideo = React.lazy(() => import('./components/VideoStudio').then(m => ({ default: m.LibraryVideo })));
|
||||
|
||||
// ProductMarketing barrel (5 named exports)
|
||||
const ProductMarketingDashboard = React.lazy(() => import('./components/ProductMarketing').then(m => ({ default: m.ProductMarketingDashboard })));
|
||||
const ProductPhotoshootStudio = React.lazy(() => import('./components/ProductMarketing').then(m => ({ default: m.ProductPhotoshootStudio })));
|
||||
const ProductAnimationStudio = React.lazy(() => import('./components/ProductMarketing').then(m => ({ default: m.ProductAnimationStudio })));
|
||||
const ProductVideoStudio = React.lazy(() => import('./components/ProductMarketing').then(m => ({ default: m.ProductVideoStudio })));
|
||||
const ProductAvatarStudio = React.lazy(() => import('./components/ProductMarketing').then(m => ({ default: m.ProductAvatarStudio })));
|
||||
// interface OnboardingStatus {
|
||||
// onboarding_required: boolean;
|
||||
// onboarding_complete: boolean;
|
||||
// current_step?: number;
|
||||
// total_steps?: number;
|
||||
// completion_percentage?: number;
|
||||
// }
|
||||
|
||||
// Root route that chooses Landing (signed out) or InitialRouteHandler (signed in)
|
||||
const RootRoute: React.FC = () => {
|
||||
@@ -172,81 +162,78 @@ const App: React.FC = () => {
|
||||
<AuthenticatedCopilotWrapper apiKey={copilotApiKey}>
|
||||
<ConditionalCopilotKit>
|
||||
<TokenInstaller />
|
||||
<Suspense fallback={<LazyLoadingFallback />}>
|
||||
<Routes>
|
||||
<Route path="/" element={<RootRoute />} />
|
||||
<Route
|
||||
path="/onboarding"
|
||||
element={
|
||||
<ErrorBoundary context="Onboarding Wizard" showDetails>
|
||||
<Wizard />
|
||||
</ErrorBoundary>
|
||||
}
|
||||
/>
|
||||
{/* Error Boundary Testing - Development Only */}
|
||||
{process.env.NODE_ENV === 'development' && (
|
||||
<Route path="/error-test" element={<ErrorBoundaryTest />} />
|
||||
)}
|
||||
<Route path="/dashboard" element={<ProtectedRoute><MainDashboard /></ProtectedRoute>} />
|
||||
<Route path="/seo" element={<ProtectedRoute><FeatureRoute feature="seo"><SEODashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/seo-dashboard" element={<ProtectedRoute><FeatureRoute feature="seo"><SEODashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/content-planning" element={<ProtectedRoute><FeatureRoute feature="content-planning"><ContentPlanningDashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/facebook-writer" element={<ProtectedRoute><FeatureRoute feature="social"><FacebookWriter /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/linkedin-writer" element={<ProtectedRoute><FeatureRoute feature="social"><LinkedInWriter /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/blog-writer" element={<ProtectedRoute><FeatureRoute feature="blog_writer"><BlogWriter /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/story-writer" element={<ProtectedRoute><FeatureRoute feature="story"><StoryWriter /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/story-projects" element={<ProtectedRoute><FeatureRoute feature="story"><StoryProjectList /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/youtube-creator" element={<ProtectedRoute><FeatureRoute feature="youtube"><YouTubeCreator /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/podcast-maker" element={<ProtectedRoute><FeatureRoute feature="podcast"><PodcastDashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-studio" element={<ProtectedRoute><FeatureRoute feature="image"><ImageStudioDashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio" element={<ProtectedRoute><FeatureRoute feature="video"><VideoStudioDashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/create" element={<ProtectedRoute><FeatureRoute feature="video"><CreateVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/avatar" element={<ProtectedRoute><FeatureRoute feature="video"><AvatarVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/enhance" element={<ProtectedRoute><FeatureRoute feature="video"><EnhanceVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/extend" element={<ProtectedRoute><FeatureRoute feature="video"><ExtendVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/edit" element={<ProtectedRoute><FeatureRoute feature="video"><EditVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/transform" element={<ProtectedRoute><FeatureRoute feature="video"><TransformVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/social" element={<ProtectedRoute><FeatureRoute feature="video"><SocialVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/face-swap" element={<ProtectedRoute><FeatureRoute feature="video"><FaceSwap /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-translate" element={<ProtectedRoute><FeatureRoute feature="video"><VideoTranslate /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-background-remover" element={<ProtectedRoute><FeatureRoute feature="video"><VideoBackgroundRemover /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/add-audio-to-video" element={<ProtectedRoute><FeatureRoute feature="video"><AddAudioToVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/video-studio/library" element={<ProtectedRoute><FeatureRoute feature="video"><LibraryVideo /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-generator" element={<ProtectedRoute><FeatureRoute feature="image"><CreateStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-editor" element={<ProtectedRoute><FeatureRoute feature="image"><EditStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-upscale" element={<ProtectedRoute><FeatureRoute feature="image"><UpscaleStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-control" element={<ProtectedRoute><FeatureRoute feature="image"><ControlStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-studio/face-swap" element={<ProtectedRoute><FeatureRoute feature="image"><FaceSwapStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-studio/compress" element={<ProtectedRoute><FeatureRoute feature="image"><CompressionStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-studio/processing" element={<ProtectedRoute><FeatureRoute feature="image"><ImageProcessingStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/image-studio/social-optimizer" element={<ProtectedRoute><FeatureRoute feature="image"><SocialOptimizer /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/asset-library" element={<ProtectedRoute><FeatureRoute feature="asset-library"><AssetLibrary /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator" element={<ProtectedRoute><FeatureRoute feature="campaign"><ProductMarketingDashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/photoshoot" element={<ProtectedRoute><FeatureRoute feature="campaign"><ProductPhotoshootStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/animation" element={<ProtectedRoute><FeatureRoute feature="campaign"><ProductAnimationStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/video" element={<ProtectedRoute><FeatureRoute feature="campaign"><ProductVideoStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/avatar" element={<ProtectedRoute><FeatureRoute feature="campaign"><ProductAvatarStudio /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/product-marketing" element={<Navigate to="/campaign-creator" replace />} />
|
||||
<Route path="/scheduler-dashboard" element={<ProtectedRoute><FeatureRoute feature="scheduler"><SchedulerDashboard /></FeatureRoute></ProtectedRoute>} />
|
||||
<Route path="/billing" element={<ProtectedRoute><BillingPage /></ProtectedRoute>} />
|
||||
<Route path="/approvals" element={<ProtectedRoute><ApprovalsPage /></ProtectedRoute>} />
|
||||
<Route path="/team-activity" element={<ProtectedRoute><TeamActivityPage /></ProtectedRoute>} />
|
||||
<Route path="/stripe-disputes" element={<ProtectedRoute><StripeDisputesDashboard /></ProtectedRoute>} />
|
||||
<Route path="/pricing" element={<PricingPage />} />
|
||||
<Route path="/research-test" element={<FeatureRoute feature="research"><ResearchDashboard /></FeatureRoute>} />
|
||||
<Route path="/research-dashboard" element={<FeatureRoute feature="research"><ResearchDashboard /></FeatureRoute>} />
|
||||
<Route path="/alwrity-researcher" element={<FeatureRoute feature="research"><ResearchDashboard /></FeatureRoute>} />
|
||||
<Route path="/intent-research" element={<FeatureRoute feature="research"><IntentResearchTest /></FeatureRoute>} />
|
||||
<Route path="/wix-test" element={<FeatureRoute feature="wix"><WixTestPage /></FeatureRoute>} />
|
||||
<Route path="/wix-test-direct" element={<FeatureRoute feature="wix"><WixTestPage /></FeatureRoute>} />
|
||||
{/* Auth callbacks — always accessible (needed for OAuth flow) */}
|
||||
<Route path="/wix/callback" element={<WixCallbackPage />} />
|
||||
<Route path="/wp/callback" element={<WordPressCallbackPage />} />
|
||||
<Route path="/gsc/callback" element={<GSCAuthCallback />} />
|
||||
<Route path="/bing/callback" element={<BingCallbackPage />} />
|
||||
<Route path="/bing-analytics-storage" element={<ProtectedRoute><FeatureRoute feature="bing"><BingAnalyticsStorage /></FeatureRoute></ProtectedRoute>} />
|
||||
</Routes>
|
||||
</Suspense>
|
||||
<Routes>
|
||||
<Route path="/" element={<RootRoute />} />
|
||||
<Route
|
||||
path="/onboarding"
|
||||
element={
|
||||
<ErrorBoundary context="Onboarding Wizard" showDetails>
|
||||
<Wizard />
|
||||
</ErrorBoundary>
|
||||
}
|
||||
/>
|
||||
{/* Error Boundary Testing - Development Only */}
|
||||
{process.env.NODE_ENV === 'development' && (
|
||||
<Route path="/error-test" element={<ErrorBoundaryTest />} />
|
||||
)}
|
||||
<Route path="/dashboard" element={<ProtectedRoute><MainDashboard /></ProtectedRoute>} />
|
||||
<Route path="/seo" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
|
||||
<Route path="/seo-dashboard" element={<ProtectedRoute><SEODashboard /></ProtectedRoute>} />
|
||||
<Route path="/content-planning" element={<ProtectedRoute><ContentPlanningDashboard /></ProtectedRoute>} />
|
||||
<Route path="/facebook-writer" element={<ProtectedRoute><FacebookWriter /></ProtectedRoute>} />
|
||||
<Route path="/linkedin-writer" element={<ProtectedRoute><LinkedInWriter /></ProtectedRoute>} />
|
||||
<Route path="/blog-writer" element={<ProtectedRoute><BlogWriter /></ProtectedRoute>} />
|
||||
<Route path="/story-writer" element={<ProtectedRoute><StoryWriter /></ProtectedRoute>} />
|
||||
<Route path="/story-projects" element={<ProtectedRoute><StoryProjectList /></ProtectedRoute>} />
|
||||
<Route path="/youtube-creator" element={<ProtectedRoute><YouTubeCreator /></ProtectedRoute>} />
|
||||
<Route path="/podcast-maker" element={<ProtectedRoute><PodcastDashboard /></ProtectedRoute>} />
|
||||
<Route path="/image-studio" element={<ProtectedRoute><ImageStudioDashboard /></ProtectedRoute>} />
|
||||
<Route path="/video-studio" element={<ProtectedRoute><VideoStudioDashboard /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/create" element={<ProtectedRoute><CreateVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/avatar" element={<ProtectedRoute><AvatarVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/enhance" element={<ProtectedRoute><EnhanceVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/extend" element={<ProtectedRoute><ExtendVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/edit" element={<ProtectedRoute><EditVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/transform" element={<ProtectedRoute><TransformVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/social" element={<ProtectedRoute><SocialVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/face-swap" element={<ProtectedRoute><FaceSwap /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-translate" element={<ProtectedRoute><VideoTranslate /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/video-background-remover" element={<ProtectedRoute><VideoBackgroundRemover /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/add-audio-to-video" element={<ProtectedRoute><AddAudioToVideo /></ProtectedRoute>} />
|
||||
<Route path="/video-studio/library" element={<ProtectedRoute><LibraryVideo /></ProtectedRoute>} />
|
||||
<Route path="/image-generator" element={<ProtectedRoute><CreateStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-editor" element={<ProtectedRoute><EditStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-upscale" element={<ProtectedRoute><UpscaleStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-control" element={<ProtectedRoute><ControlStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/face-swap" element={<ProtectedRoute><FaceSwapStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/compress" element={<ProtectedRoute><CompressionStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/processing" element={<ProtectedRoute><ImageProcessingStudio /></ProtectedRoute>} />
|
||||
<Route path="/image-studio/social-optimizer" element={<ProtectedRoute><SocialOptimizer /></ProtectedRoute>} />
|
||||
<Route path="/asset-library" element={<ProtectedRoute><AssetLibrary /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator" element={<ProtectedRoute><ProductMarketingDashboard /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/photoshoot" element={<ProtectedRoute><ProductPhotoshootStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/animation" element={<ProtectedRoute><ProductAnimationStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/video" element={<ProtectedRoute><ProductVideoStudio /></ProtectedRoute>} />
|
||||
<Route path="/campaign-creator/avatar" element={<ProtectedRoute><ProductAvatarStudio /></ProtectedRoute>} />
|
||||
<Route path="/product-marketing" element={<Navigate to="/campaign-creator" replace />} />
|
||||
<Route path="/scheduler-dashboard" element={<ProtectedRoute><SchedulerDashboard /></ProtectedRoute>} />
|
||||
<Route path="/billing" element={<ProtectedRoute><BillingPage /></ProtectedRoute>} />
|
||||
<Route path="/approvals" element={<ProtectedRoute><ApprovalsPage /></ProtectedRoute>} />
|
||||
<Route path="/team-activity" element={<ProtectedRoute><TeamActivityPage /></ProtectedRoute>} />
|
||||
<Route path="/stripe-disputes" element={<ProtectedRoute><StripeDisputesDashboard /></ProtectedRoute>} />
|
||||
<Route path="/pricing" element={<PricingPage />} />
|
||||
<Route path="/research-test" element={<ResearchDashboard />} />
|
||||
<Route path="/research-dashboard" element={<ResearchDashboard />} />
|
||||
<Route path="/alwrity-researcher" element={<ResearchDashboard />} />
|
||||
<Route path="/intent-research" element={<IntentResearchTest />} />
|
||||
<Route path="/wix-test" element={<WixTestPage />} />
|
||||
<Route path="/wix-test-direct" element={<WixTestPage />} />
|
||||
<Route path="/wix/callback" element={<WixCallbackPage />} />
|
||||
<Route path="/wp/callback" element={<WordPressCallbackPage />} />
|
||||
<Route path="/gsc/callback" element={<GSCAuthCallback />} />
|
||||
<Route path="/bing/callback" element={<BingCallbackPage />} />
|
||||
<Route path="/bing-analytics-storage" element={<ProtectedRoute><BingAnalyticsStorage /></ProtectedRoute>} />
|
||||
</Routes>
|
||||
</ConditionalCopilotKit>
|
||||
</AuthenticatedCopilotWrapper>
|
||||
</Router>
|
||||
@@ -274,4 +261,4 @@ const App: React.FC = () => {
|
||||
);
|
||||
};
|
||||
|
||||
export default App;
|
||||
export default App;
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
|
||||
import { ResearchMode, ResearchProvider } from '../services/blogWriterApi';
|
||||
import { apiClient } from './client';
|
||||
import { isFeatureOnlyMode } from '../utils/demoMode';
|
||||
import { isPodcastOnlyDemoMode } from '../utils/demoMode';
|
||||
|
||||
export interface ProviderAvailability {
|
||||
google_available: boolean;
|
||||
@@ -130,9 +130,9 @@ let pendingConfigRequest: Promise<ResearchConfigResponse> | null = null;
|
||||
* and research persona from the unified /api/research/config endpoint.
|
||||
*/
|
||||
export const getResearchConfig = async (): Promise<ResearchConfigResponse> => {
|
||||
// Skip in feature-limited mode — backend always provides AI-generated research_queries
|
||||
if (isFeatureOnlyMode()) {
|
||||
throw new Error('Research config not available in feature-limited mode');
|
||||
// Skip in podcast-only mode — backend always provides AI-generated research_queries
|
||||
if (isPodcastOnlyDemoMode()) {
|
||||
throw new Error('Research config not available in podcast-only mode');
|
||||
}
|
||||
|
||||
// If a request is already in flight, return the same promise
|
||||
|
||||
@@ -5,6 +5,7 @@ import { CopilotKit } from "@copilotkit/react-core";
|
||||
import { CopilotKitHealthProvider } from '../../contexts/CopilotKitHealthContext';
|
||||
import CopilotKitDegradedBanner from '../shared/CopilotKitDegradedBanner';
|
||||
import ErrorBoundary from '../shared/ErrorBoundary';
|
||||
import { isPodcastOnlyDemoMode } from '../../utils/demoMode';
|
||||
|
||||
interface ConditionalCopilotKitProps {
|
||||
children: React.ReactNode;
|
||||
@@ -23,12 +24,10 @@ export const AuthenticatedCopilotWrapper: React.FC<AuthenticatedCopilotWrapperPr
|
||||
const { isSignedIn } = useAuth();
|
||||
const location = useLocation();
|
||||
|
||||
// Only fully exclude CopilotKit when user is not signed in or on onboarding
|
||||
// Feature-limited mode (blog_writer, etc.) still needs CopilotKit providers
|
||||
// because BlogWriter uses useCopilotAction and useCopilotKitHealth hooks
|
||||
const shouldExcludeCopilotKit = !isSignedIn || location.pathname.startsWith('/onboarding');
|
||||
const isPodcastOnly = isPodcastOnlyDemoMode();
|
||||
const shouldExcludeCopilot = !isSignedIn || location.pathname.startsWith('/onboarding') || isPodcastOnly;
|
||||
|
||||
if (shouldExcludeCopilotKit) {
|
||||
if (shouldExcludeCopilot) {
|
||||
return <>{children}</>;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import { Navigate, useLocation } from 'react-router-dom';
|
||||
import { Box, CircularProgress, Typography } from '@mui/material';
|
||||
import { useOnboarding } from '../../contexts/OnboardingContext';
|
||||
import { useSubscription } from '../../contexts/SubscriptionContext';
|
||||
import { useOAuthTokenAlerts } from '../../hooks/useOAuthTokenAlerts';
|
||||
import { shouldSkipOnboarding, getDefaultLandingRoute, isFeatureOnlyMode, getSingleFeature } from '../../utils/demoMode';
|
||||
import { restoreNavigationState } from '../../utils/navigationState';
|
||||
import { shouldSkipOnboarding } from '../../utils/demoMode';
|
||||
import ConnectionErrorPage from '../shared/ConnectionErrorPage';
|
||||
|
||||
const CHECKOUT_POLL_INTERVAL_MS = 2000;
|
||||
const CHECKOUT_POLL_MAX_ATTEMPTS = 10;
|
||||
|
||||
const InitialRouteHandler: React.FC = () => {
|
||||
// Helper to log and navigate in a single place
|
||||
const navigateAndLog = (to: string) => {
|
||||
console.log(`InitialRouteHandler: Redirecting to ${to}`);
|
||||
return <Navigate to={to} replace />;
|
||||
@@ -26,24 +23,12 @@ const InitialRouteHandler: React.FC = () => {
|
||||
hasError: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
// Post-checkout polling state
|
||||
const [checkoutPolling, setCheckoutPolling] = useState(false);
|
||||
const checkoutPollAttempts = useRef(0);
|
||||
// Track whether the initial subscription check has completed
|
||||
// Prevents premature routing decisions before we know the user's plan
|
||||
const [initialCheckDone, setInitialCheckDone] = useState(false);
|
||||
|
||||
const urlParams = new URLSearchParams(location.search);
|
||||
const isCheckoutSuccess = urlParams.get('subscription') === 'success';
|
||||
const returnTo = urlParams.get('return_to');
|
||||
|
||||
|
||||
useOAuthTokenAlerts({
|
||||
enabled: subscription?.active === true,
|
||||
interval: 60000,
|
||||
});
|
||||
|
||||
// Initial subscription check with retries
|
||||
useEffect(() => {
|
||||
const timeoutId = setTimeout(async () => {
|
||||
const maxRetries = 3;
|
||||
@@ -53,91 +38,47 @@ const InitialRouteHandler: React.FC = () => {
|
||||
break;
|
||||
} catch (err) {
|
||||
console.error(`App: Subscription check attempt ${attempt + 1} failed:`, err);
|
||||
|
||||
|
||||
const isConnectionError = err instanceof Error && (err.name === 'NetworkError' || err.name === 'ConnectionError');
|
||||
|
||||
|
||||
if (isConnectionError && attempt < maxRetries - 1) {
|
||||
const delay = 1000 * Math.pow(2, attempt);
|
||||
await new Promise(resolve => setTimeout(resolve, delay));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (attempt === maxRetries - 1 || !isConnectionError) {
|
||||
if (isConnectionError) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err as Error,
|
||||
});
|
||||
}
|
||||
}
|
||||
const delay = 1000 * Math.pow(2, attempt);
|
||||
await new Promise(resolve => setTimeout(resolve, delay));
|
||||
continue;
|
||||
}
|
||||
|
||||
if (attempt === maxRetries - 1 || !isConnectionError) {
|
||||
if (isConnectionError) {
|
||||
setConnectionError({
|
||||
hasError: true,
|
||||
error: err as Error,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Mark initial check as done regardless of success/failure
|
||||
setInitialCheckDone(true);
|
||||
}, 100);
|
||||
|
||||
|
||||
return () => clearTimeout(timeoutId);
|
||||
}, []);
|
||||
|
||||
// Handle post-checkout: when Stripe redirects back with ?subscription=success,
|
||||
// the webhook may not have processed yet. Poll until subscription becomes active.
|
||||
useEffect(() => {
|
||||
if (!isCheckoutSuccess) return;
|
||||
if (subscription?.active && subscription.plan !== 'none' && subscription.plan !== 'free') {
|
||||
// Webhook has processed — subscription is active, stop polling
|
||||
if (checkoutPolling) {
|
||||
console.log('InitialRouteHandler: Checkout success — subscription confirmed active, stopping poll');
|
||||
setCheckoutPolling(false);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// Start polling if webhook hasn't processed yet
|
||||
if (!checkoutPolling && checkoutPollAttempts.current === 0) {
|
||||
console.log('InitialRouteHandler: Checkout success — subscription not yet active, starting poll');
|
||||
setCheckoutPolling(true);
|
||||
}
|
||||
}, [isCheckoutSuccess, subscription, checkoutPolling]);
|
||||
|
||||
// Polling effect for post-checkout
|
||||
useEffect(() => {
|
||||
if (!checkoutPolling) return;
|
||||
|
||||
if (checkoutPollAttempts.current >= CHECKOUT_POLL_MAX_ATTEMPTS) {
|
||||
console.log('InitialRouteHandler: Checkout polling exhausted — proceeding with current state');
|
||||
setCheckoutPolling(false);
|
||||
return;
|
||||
}
|
||||
|
||||
const timer = setTimeout(async () => {
|
||||
checkoutPollAttempts.current += 1;
|
||||
console.log(`InitialRouteHandler: Checkout poll attempt ${checkoutPollAttempts.current}/${CHECKOUT_POLL_MAX_ATTEMPTS}`);
|
||||
try {
|
||||
await checkSubscription();
|
||||
} catch (err) {
|
||||
console.error('InitialRouteHandler: Checkout poll check failed:', err);
|
||||
}
|
||||
}, CHECKOUT_POLL_INTERVAL_MS);
|
||||
|
||||
return () => clearTimeout(timer);
|
||||
}, [checkoutPolling, checkSubscription]);
|
||||
|
||||
// Initialize onboarding when subscription is confirmed (but not on checkout success — let redirect happen)
|
||||
const urlParams = new URLSearchParams(location.search);
|
||||
const isCheckoutSuccess = urlParams.get('subscription') === 'success';
|
||||
|
||||
useEffect(() => {
|
||||
if (subscription && !subscriptionLoading) {
|
||||
const isNewUser = !subscription || subscription.plan === 'none';
|
||||
|
||||
|
||||
console.log('InitialRouteHandler: Subscription data received:', {
|
||||
plan: subscription.plan,
|
||||
active: subscription.active,
|
||||
isNewUser,
|
||||
subscriptionLoading,
|
||||
isCheckoutSuccess,
|
||||
subscriptionLoading
|
||||
});
|
||||
|
||||
|
||||
if (subscription.active && !isNewUser) {
|
||||
console.log('InitialRouteHandler: Subscription confirmed, initializing onboarding...');
|
||||
|
||||
|
||||
if (!isCheckoutSuccess) {
|
||||
initializeOnboarding();
|
||||
}
|
||||
@@ -145,85 +86,9 @@ const InitialRouteHandler: React.FC = () => {
|
||||
}
|
||||
}, [subscription, subscriptionLoading, initializeOnboarding, isCheckoutSuccess]);
|
||||
|
||||
// --- Render decisions ---
|
||||
|
||||
// Wait for initial subscription check before making routing decisions.
|
||||
// Without this, a null subscription (before API response) can trigger
|
||||
// incorrect redirects (e.g., to feature routes instead of /pricing).
|
||||
if (!initialCheckDone && !connectionError.hasError) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Show polling spinner during post-checkout webhook wait
|
||||
if (checkoutPolling) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Activating your subscription...
|
||||
</Typography>
|
||||
<Typography variant="body2" color="textSecondary">
|
||||
This may take a few seconds.
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
// Post-checkout: subscription is now active (or poll exhausted)
|
||||
if (isCheckoutSuccess && subscription?.active && subscription.plan !== 'none' && subscription.plan !== 'free') {
|
||||
// Restore navigation state (saved before Stripe redirect)
|
||||
const navState = restoreNavigationState();
|
||||
const redirectTo = returnTo || navState?.path;
|
||||
|
||||
if (redirectTo && redirectTo !== '/pricing' && redirectTo !== '/onboarding') {
|
||||
console.log(`InitialRouteHandler: Checkout success — redirecting to saved page: ${redirectTo}`);
|
||||
return navigateAndLog(redirectTo);
|
||||
}
|
||||
|
||||
if (shouldSkipOnboarding()) {
|
||||
const route = getDefaultLandingRoute();
|
||||
console.log(`InitialRouteHandler: Checkout success in demo mode → ${route}`);
|
||||
return navigateAndLog(route);
|
||||
}
|
||||
|
||||
if (!isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Checkout success — onboarding incomplete → Onboarding');
|
||||
return navigateAndLog('/onboarding');
|
||||
}
|
||||
|
||||
console.log('InitialRouteHandler: Checkout success → Dashboard');
|
||||
return navigateAndLog('/dashboard');
|
||||
}
|
||||
|
||||
// Checkout success but subscription still not active after polling — treat as inactive
|
||||
// SubscriptionContext will show the expired modal
|
||||
if (isCheckoutSuccess && (!subscription?.active || subscription.plan === 'none' || subscription.plan === 'free')) {
|
||||
console.log('InitialRouteHandler: Checkout success but subscription not yet active — showing pricing');
|
||||
if (shouldSkipOnboarding()) {
|
||||
return navigateAndLog(getDefaultLandingRoute());
|
||||
}
|
||||
return <Navigate to="/pricing" replace />;
|
||||
if (isCheckoutSuccess && subscription?.active && shouldSkipOnboarding()) {
|
||||
console.log('InitialRouteHandler: Early redirect - Stripe checkout success in demo mode → Podcast Maker');
|
||||
return navigateAndLog("/podcast-maker");
|
||||
}
|
||||
|
||||
if (connectionError.hasError) {
|
||||
@@ -263,9 +128,9 @@ const InitialRouteHandler: React.FC = () => {
|
||||
subscription: subscription ? { plan: subscription.plan, active: subscription.active } : null,
|
||||
subscriptionLoading,
|
||||
loading,
|
||||
data: !!data,
|
||||
data: !!data
|
||||
});
|
||||
const isActiveSubscriber = Boolean(subscription && subscription.active && subscription.plan !== 'none' && subscription.plan !== 'free');
|
||||
const isActiveSubscriber = Boolean(subscription && subscription.active && subscription.plan !== 'none');
|
||||
console.log('InitialRouteHandler: isActiveSubscriber =', isActiveSubscriber);
|
||||
const waitingForOnboardingInit = !isDemoMode && isActiveSubscriber && (loading || !data);
|
||||
if (waitingForOnboardingInit) {
|
||||
@@ -327,15 +192,10 @@ const InitialRouteHandler: React.FC = () => {
|
||||
|
||||
if (!subscription) {
|
||||
if (isOnboardingComplete) {
|
||||
if (isDemoMode) {
|
||||
const route = getDefaultLandingRoute();
|
||||
console.log(`InitialRouteHandler: Onboarding complete, no sub, demo mode → ${route}`);
|
||||
return navigateAndLog(route);
|
||||
}
|
||||
console.log('InitialRouteHandler: Onboarding complete but no subscription data → Dashboard (allow access)');
|
||||
return navigateAndLog("/dashboard");
|
||||
}
|
||||
|
||||
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
@@ -353,19 +213,43 @@ const InitialRouteHandler: React.FC = () => {
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
if (shouldSkipOnboarding()) {
|
||||
const route = getDefaultLandingRoute();
|
||||
console.log(`InitialRouteHandler: Demo mode - no subscription but allowing access to ${route}`);
|
||||
return navigateAndLog(route);
|
||||
|
||||
if (!subscription) {
|
||||
if (isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: Onboarding complete but no subscription data → Dashboard (allow access)');
|
||||
return navigateAndLog("/dashboard");
|
||||
}
|
||||
|
||||
if (subscriptionLoading) {
|
||||
return (
|
||||
<Box
|
||||
display="flex"
|
||||
flexDirection="column"
|
||||
alignItems="center"
|
||||
justifyContent="center"
|
||||
minHeight="100vh"
|
||||
gap={2}
|
||||
>
|
||||
<CircularProgress size={60} />
|
||||
<Typography variant="h6" color="textSecondary">
|
||||
Checking subscription...
|
||||
</Typography>
|
||||
</Box>
|
||||
);
|
||||
}
|
||||
|
||||
if (shouldSkipOnboarding()) {
|
||||
console.log('InitialRouteHandler: Demo mode - no subscription but allowing access to podcast-maker');
|
||||
return navigateAndLog("/podcast-maker");
|
||||
}
|
||||
|
||||
console.log('InitialRouteHandler: No subscription data after check → Pricing page');
|
||||
return navigateAndLog("/pricing");
|
||||
}
|
||||
|
||||
console.log('InitialRouteHandler: No subscription data after check → Pricing page');
|
||||
return navigateAndLog("/pricing");
|
||||
}
|
||||
|
||||
const isNewUser = !subscription || subscription.plan === 'none' || subscription.plan === 'free';
|
||||
|
||||
const isNewUser = !subscription || subscription.plan === 'none';
|
||||
|
||||
if (isNewUser || !subscription.active) {
|
||||
console.log('InitialRouteHandler: No active subscription - modal will be shown by SubscriptionContext');
|
||||
if (isNewUser) {
|
||||
@@ -378,21 +262,15 @@ const InitialRouteHandler: React.FC = () => {
|
||||
if (!isOnboardingComplete) {
|
||||
console.log('InitialRouteHandler: isOnboardingComplete = false, shouldSkipOnboarding() =', shouldSkipOnboarding());
|
||||
if (shouldSkipOnboarding()) {
|
||||
const route = getDefaultLandingRoute();
|
||||
console.log(`InitialRouteHandler: Demo mode - skipping onboarding → ${route}`);
|
||||
return navigateAndLog(route);
|
||||
console.log('InitialRouteHandler: Demo mode - skipping onboarding → Podcast Maker');
|
||||
return navigateAndLog("/podcast-maker");
|
||||
}
|
||||
console.log('InitialRouteHandler: Subscription active but onboarding incomplete → Onboarding');
|
||||
return navigateAndLog("/onboarding");
|
||||
}
|
||||
|
||||
if (isDemoMode) {
|
||||
const route = getDefaultLandingRoute();
|
||||
console.log(`InitialRouteHandler: All set in demo mode → ${route}`);
|
||||
return navigateAndLog(route);
|
||||
}
|
||||
console.log('InitialRouteHandler: All set (subscription + onboarding) → Dashboard');
|
||||
return navigateAndLog("/dashboard");
|
||||
};
|
||||
|
||||
export default InitialRouteHandler;
|
||||
export default InitialRouteHandler;
|
||||
|
||||
@@ -1,11 +1,4 @@
|
||||
import React, { useRef, useCallback, useState } from 'react';
|
||||
import { useNavigate } from 'react-router-dom';
|
||||
import Dialog from '@mui/material/Dialog';
|
||||
import DialogTitle from '@mui/material/DialogTitle';
|
||||
import DialogContent from '@mui/material/DialogContent';
|
||||
import DialogContentText from '@mui/material/DialogContentText';
|
||||
import DialogActions from '@mui/material/DialogActions';
|
||||
import Button from '@mui/material/Button';
|
||||
import React, { useRef, useCallback } from 'react';
|
||||
import { debug } from '../../utils/debug';
|
||||
import WriterCopilotSidebar from './BlogWriterUtils/WriterCopilotSidebar';
|
||||
import { blogWriterApi } from '../../services/blogWriterApi';
|
||||
@@ -35,7 +28,7 @@ import { useBlogWriterRefs } from './BlogWriterUtils/useBlogWriterRefs';
|
||||
import { BlogWriterLandingSection } from './BlogWriterUtils/BlogWriterLandingSection';
|
||||
import { CopilotKitComponents } from './BlogWriterUtils/CopilotKitComponents';
|
||||
|
||||
const BlogWriter: React.FC = () => {
|
||||
export const BlogWriter: React.FC = () => {
|
||||
// Add light theme class to body/html on mount, remove on unmount
|
||||
React.useEffect(() => {
|
||||
document.body.classList.add('blog-writer-page');
|
||||
@@ -51,8 +44,6 @@ const BlogWriter: React.FC = () => {
|
||||
enabled: true, // Enable health checking
|
||||
});
|
||||
|
||||
const navigate = useNavigate();
|
||||
|
||||
// Use custom hook for all state management
|
||||
const {
|
||||
research,
|
||||
@@ -76,7 +67,6 @@ const BlogWriter: React.FC = () => {
|
||||
flowAnalysisCompleted,
|
||||
flowAnalysisResults,
|
||||
sectionImages,
|
||||
setResearch,
|
||||
setOutline,
|
||||
setTitleOptions,
|
||||
setSelectedTitle,
|
||||
@@ -87,7 +77,6 @@ const BlogWriter: React.FC = () => {
|
||||
setContinuityRefresh,
|
||||
setOutlineTaskId,
|
||||
setContentConfirmed,
|
||||
setOutlineConfirmed,
|
||||
setFlowAnalysisCompleted,
|
||||
setFlowAnalysisResults,
|
||||
setSectionImages,
|
||||
@@ -302,48 +291,6 @@ const BlogWriter: React.FC = () => {
|
||||
}
|
||||
}, [navigateToPhase, seoAnalysis, research, runSEOAnalysisDirect, setIsSEOAnalysisModalOpen]);
|
||||
|
||||
const handleNewBlog = useCallback(() => {
|
||||
setResearch(null);
|
||||
setOutline([]);
|
||||
setSections({});
|
||||
setSeoAnalysis(null);
|
||||
setSeoMetadata(null);
|
||||
setContentConfirmed(false);
|
||||
setOutlineConfirmed(false);
|
||||
setSelectedTitle('');
|
||||
setTitleOptions([]);
|
||||
setCurrentPhase('');
|
||||
try {
|
||||
localStorage.removeItem('blog_outline');
|
||||
localStorage.removeItem('blog_title_options');
|
||||
localStorage.removeItem('blog_selected_title');
|
||||
localStorage.removeItem('blogwriter_current_phase');
|
||||
localStorage.removeItem('blogwriter_user_selected_phase');
|
||||
localStorage.removeItem('blog_content_confirmed');
|
||||
localStorage.removeItem('blog_seo_recommendations_applied');
|
||||
} catch {
|
||||
// ignore localStorage errors
|
||||
}
|
||||
}, [setResearch, setOutline, setSections, setSeoAnalysis, setSeoMetadata,
|
||||
setContentConfirmed, setOutlineConfirmed, setSelectedTitle, setTitleOptions,
|
||||
setCurrentPhase]);
|
||||
|
||||
const handleMyBlogs = useCallback(() => {
|
||||
navigate('/asset-library?source_module=blog_writer&asset_type=text');
|
||||
}, [navigate]);
|
||||
|
||||
const [newBlogDialogOpen, setNewBlogDialogOpen] = useState(false);
|
||||
|
||||
const hasExistingWork = !!(research || outline.length > 0 || Object.keys(sections).length > 0);
|
||||
|
||||
const confirmNewBlog = useCallback(() => {
|
||||
if (hasExistingWork) {
|
||||
setNewBlogDialogOpen(true);
|
||||
} else {
|
||||
handleNewBlog();
|
||||
}
|
||||
}, [hasExistingWork, handleNewBlog]);
|
||||
|
||||
const outlineGenRef = useRef<any>(null);
|
||||
|
||||
// Callback to handle cached outline completion
|
||||
@@ -385,7 +332,6 @@ const BlogWriter: React.FC = () => {
|
||||
setIsSEOAnalysisModalOpen,
|
||||
setIsSEOMetadataModalOpen,
|
||||
runSEOAnalysisDirect,
|
||||
onResearchComplete: handleResearchComplete,
|
||||
onOutlineComplete: handleCachedOutlineComplete,
|
||||
onContentComplete: handleCachedContentComplete,
|
||||
});
|
||||
@@ -497,7 +443,6 @@ const BlogWriter: React.FC = () => {
|
||||
/>
|
||||
|
||||
{/* Phase navigation header - always visible as default interface */}
|
||||
<div style={{ flexShrink: 0 }}>
|
||||
<HeaderBar
|
||||
phases={phases}
|
||||
currentPhase={currentPhase}
|
||||
@@ -519,11 +464,7 @@ const BlogWriter: React.FC = () => {
|
||||
hasSEOAnalysis={!!seoAnalysis}
|
||||
seoRecommendationsApplied={seoRecommendationsApplied}
|
||||
hasSEOMetadata={!!seoMetadata}
|
||||
onNewBlog={confirmNewBlog}
|
||||
onMyBlogs={handleMyBlogs}
|
||||
onHelp={() => window.open('/docs', '_blank')}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Landing section - extracted to BlogWriterLandingSection */}
|
||||
<BlogWriterLandingSection
|
||||
@@ -619,26 +560,6 @@ const BlogWriter: React.FC = () => {
|
||||
// Publisher component will use this metadata when calling publish API
|
||||
}}
|
||||
/>
|
||||
|
||||
{/* New Blog confirmation dialog */}
|
||||
<Dialog
|
||||
open={newBlogDialogOpen}
|
||||
onClose={() => setNewBlogDialogOpen(false)}
|
||||
aria-labelledby="new-blog-dialog-title"
|
||||
>
|
||||
<DialogTitle id="new-blog-dialog-title">Start New Blog?</DialogTitle>
|
||||
<DialogContent>
|
||||
<DialogContentText>
|
||||
This will clear all your current work and start a new blog. This action cannot be undone.
|
||||
</DialogContentText>
|
||||
</DialogContent>
|
||||
<DialogActions>
|
||||
<Button onClick={() => setNewBlogDialogOpen(false)}>Cancel</Button>
|
||||
<Button onClick={() => { handleNewBlog(); setNewBlogDialogOpen(false); }} color="primary" variant="contained">
|
||||
Start New
|
||||
</Button>
|
||||
</DialogActions>
|
||||
</Dialog>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import React, { useState } from 'react';
|
||||
import { Container, Grid, Card, CardContent, Typography, Box, Stack, Chip } from '@mui/material';
|
||||
import CheckCircle from '@mui/icons-material/CheckCircle';
|
||||
import AutoAwesome from '@mui/icons-material/AutoAwesome';
|
||||
import { CheckCircle, AutoAwesome } from '@mui/icons-material';
|
||||
|
||||
interface PhaseFeature {
|
||||
title: string;
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user