Release Candidate: Production Release with Multi-Tenant & Onboarding Enhancements

This commit is contained in:
ajaysi
2026-02-28 20:06:26 +05:30
parent 08a1f4a1d8
commit 4828274cbf
162 changed files with 19489 additions and 4300 deletions

View File

@@ -641,25 +641,18 @@ async def complete_style_detection(
return await loop.run_in_executor(None, partial(style_logic.perform_seo_audit, request.url, crawl_result['content']))
async def run_sitemap_analysis():
"""Run AI sitemap analysis for home page"""
if not request.url:
return None
try:
# Discover sitemap URL
sitemap_url = await sitemap_service.discover_sitemap_url(request.url)
if sitemap_url:
# Analyze sitemap with AI insights
return await sitemap_service.analyze_sitemap(
sitemap_url=sitemap_url,
analyze_content_trends=True,
analyze_publishing_patterns=True,
include_ai_insights=True,
user_id=user_id
)
return None
except Exception as e:
logger.error(f"Sitemap analysis failed: {e}")
return None
sitemap_url = await sitemap_service.discover_sitemap_url(request.url)
if sitemap_url:
return await sitemap_service.analyze_sitemap(
sitemap_url=sitemap_url,
analyze_content_trends=True,
analyze_publishing_patterns=True,
include_ai_insights=True,
user_id=user_id
)
return None
# Execute style, patterns, SEO analysis and sitemap analysis in parallel
style_analysis, patterns_result, seo_audit_result, sitemap_result = await asyncio.gather(
@@ -710,15 +703,17 @@ async def complete_style_detection(
elif isinstance(seo_audit_result, Exception):
logger.warning(f"SEO audit failed: {seo_audit_result}")
# Process sitemap analysis result
sitemap_analysis = None
sitemap_warning = None
if sitemap_result and not isinstance(sitemap_result, Exception):
sitemap_analysis = sitemap_result
elif isinstance(sitemap_result, Exception):
logger.warning(f"Sitemap analysis failed: {sitemap_result}")
sitemap_warning = f"Sitemap analysis failed: {sitemap_result}"
# Step 4: Generate guidelines (depends on style_analysis, must run after)
style_guidelines = None
guidelines_result = None
if request.include_guidelines:
loop = asyncio.get_event_loop()
guidelines_result = await loop.run_in_executor(
@@ -728,10 +723,14 @@ async def complete_style_detection(
if guidelines_result and guidelines_result.get('success'):
style_guidelines = guidelines_result.get('guidelines')
# Check if there's a warning about fallback data
warning = None
warning_parts = []
if style_analysis and 'warning' in style_analysis:
warning = style_analysis['warning']
warning_parts.append(style_analysis['warning'])
if request.include_guidelines and guidelines_result and not guidelines_result.get('success') and guidelines_result.get('error'):
warning_parts.append(f"Guidelines generation failed: {guidelines_result.get('error')}")
if sitemap_warning:
warning_parts.append(sitemap_warning)
warning = " | ".join(warning_parts) if warning_parts else None
# Prepare response data
response_data = {
@@ -1000,4 +999,4 @@ async def get_style_detection_configuration():
}
except Exception as e:
logger.error(f"[get_style_detection_configuration] Error: {str(e)}")
return {"error": f"Configuration error: {str(e)}"}
return {"error": f"Configuration error: {str(e)}"}

View File

@@ -5,6 +5,7 @@ Handles the complex logic for completing the onboarding process.
from typing import Dict, Any, List
from datetime import datetime, timedelta
import os
from fastapi import HTTPException
from loguru import logger
@@ -306,18 +307,20 @@ class OnboardingCompletionService:
db = get_session_for_user(user_id)
integration_service = OnboardingDataIntegrationService()
# Debug logging
logger.info(f"Validating steps for user {user_id}")
# Get integrated data
integrated_data = await integration_service.process_onboarding_data(user_id, db)
db.close()
from services.onboarding.progress_service import OnboardingProgressService
progress_service = OnboardingProgressService()
status = progress_service.get_onboarding_status(user_id)
current_step = status.get("current_step", 1)
# Check each required step
for step_num in self.required_steps:
step_completed = False
if step_num == 1: # API Keys
if step_num == 1:
api_keys_data = integrated_data.get('api_keys_data', {})
logger.info(f"Step 1 - API Keys: {api_keys_data}")
step_completed = bool(
@@ -325,26 +328,49 @@ class OnboardingCompletionService:
api_keys_data.get('anthropic_api_key') or
api_keys_data.get('google_api_key')
)
if not step_completed:
has_global_providers = bool(
os.getenv("EXA_API_KEY") or
os.getenv("GEMINI_API_KEY") or
os.getenv("OPENAI_API_KEY") or
os.getenv("ANTHROPIC_API_KEY") or
os.getenv("GOOGLE_API_KEY")
)
if has_global_providers:
step_completed = True
logger.info(f"Step 1 completed: {step_completed}")
elif step_num == 2: # Website Analysis
elif step_num == 2:
website = integrated_data.get('website_analysis', {})
logger.info(f"Step 2 - Website Analysis: {website}")
step_completed = bool(website and (website.get('website_url') or website.get('writing_style')))
logger.info(f"Step 2 completed: {step_completed}")
elif step_num == 3: # Research Preferences
elif step_num == 3:
research = integrated_data.get('research_preferences', {})
logger.info(f"Step 3 - Research Preferences: {research}")
step_completed = bool(research and (research.get('research_depth') or research.get('content_types')))
logger.info(f"Step 3 completed: {step_completed}")
elif step_num == 4: # Persona Generation
elif step_num == 4:
persona = integrated_data.get('persona_data', {})
logger.info(f"Step 4 - Persona Data: {persona}")
step_completed = bool(persona and (persona.get('corePersona') or persona.get('platformPersonas')))
if not step_completed:
website = integrated_data.get('website_analysis', {})
research = integrated_data.get('research_preferences', {})
basic_ready = bool(
website and (website.get('website_url') or website.get('writing_style'))
) and bool(research)
if basic_ready:
step_completed = True
logger.info(f"Step 4 completed: {step_completed}")
elif step_num == 5: # Integrations
# For now, consider this always completed if we reach this point
elif step_num == 5:
step_completed = True
logger.info(f"Step 5 completed: {step_completed}")
if not step_completed and current_step >= step_num:
step_completed = True
logger.info(
f"Step {step_num} marked completed based on progress service (current_step={current_step})"
)
if not step_completed:
missing_steps.append(f"Step {step_num}")
@@ -357,20 +383,34 @@ class OnboardingCompletionService:
return ["Validation error"]
async def _validate_api_keys(self, user_id: str):
"""Validate that API keys are configured for the current user (SSOT)."""
"""Validate that API keys are configured for the current user (SSOT or environment)."""
try:
db = get_session_for_user(user_id)
integration_service = OnboardingDataIntegrationService()
integrated_data = await integration_service.process_onboarding_data(user_id, db)
db.close()
try:
integration_service = OnboardingDataIntegrationService()
integrated_data = await integration_service.process_onboarding_data(user_id, db)
finally:
db.close()
api_keys_data = integrated_data.get('api_keys_data', {})
api_keys_data = integrated_data.get('api_keys_data', {}) if integrated_data else {}
has_keys = bool(
has_user_keys = bool(
api_keys_data.get('openai_api_key') or
api_keys_data.get('anthropic_api_key') or
api_keys_data.get('google_api_key')
api_keys_data.get('google_api_key') or
api_keys_data.get('exa_api_key') or
api_keys_data.get('gemini_api_key')
)
has_env_keys = bool(
os.getenv("OPENAI_API_KEY") or
os.getenv("ANTHROPIC_API_KEY") or
os.getenv("GOOGLE_API_KEY") or
os.getenv("EXA_API_KEY") or
os.getenv("GEMINI_API_KEY")
)
has_keys = has_user_keys or has_env_keys
if not has_keys:
raise HTTPException(

View File

@@ -8,7 +8,7 @@ from fastapi import HTTPException
from loguru import logger
from services.onboarding.api_key_manager import get_api_key_manager
from services.database import get_db
from services.database import get_session_for_user
from services.website_analysis_service import WebsiteAnalysisService
from services.research_preferences_service import ResearchPreferencesService
from services.persona_analysis_service import PersonaAnalysisService
@@ -32,10 +32,13 @@ class OnboardingSummaryService:
async def get_onboarding_summary(self) -> Dict[str, Any]:
"""Get comprehensive onboarding summary for FinalStep."""
try:
# Get integrated data via SSOT
db = next(get_db())
integrated_data = await self.integration_service.process_onboarding_data(self.user_id, db)
db.close()
db = get_session_for_user(self.user_id)
if not db:
raise HTTPException(status_code=500, detail="Database session could not be created")
try:
integrated_data = await self.integration_service.process_onboarding_data(self.user_id, db)
finally:
db.close()
# Extract components from integrated data
website_analysis = integrated_data.get('website_analysis', {})
@@ -152,15 +155,34 @@ class OnboardingSummaryService:
return capabilities
async def get_website_analysis_data(self) -> Dict[str, Any]:
"""Get website analysis data for the user (Step 2 output)."""
try:
db = get_session_for_user(self.user_id)
if not db:
raise HTTPException(status_code=500, detail="Database session could not be created")
try:
integrated_data = await self.integration_service.process_onboarding_data(self.user_id, db)
website_analysis = integrated_data.get("website_analysis") or {}
return website_analysis
finally:
db.close()
except Exception as e:
logger.error(f"Error getting website analysis data: {e}")
raise
async def get_research_preferences_data(self) -> Dict[str, Any]:
"""Get research preferences data for the user."""
try:
db = next(get_db())
research_prefs_service = ResearchPreferencesService(db)
# Use the new method that accepts user_id directly
result = research_prefs_service.get_research_preferences_by_user_id(self.user_id)
db.close()
return result
db = get_session_for_user(self.user_id)
if not db:
raise HTTPException(status_code=500, detail="Database session could not be created")
try:
research_prefs_service = ResearchPreferencesService(db)
result = research_prefs_service.get_research_preferences_by_user_id(self.user_id)
return result
finally:
db.close()
except Exception as e:
logger.error(f"Error getting research preferences data: {e}")
raise

View File

@@ -7,54 +7,238 @@ Analysis endpoint for podcast ideas.
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any
import json
import uuid
from sqlalchemy.orm import Session
from services.database import get_db
from middleware.auth_middleware import get_current_user
from api.story_writer.utils.auth import require_authenticated_user
from services.llm_providers.main_text_generation import llm_text_gen
from services.llm_providers.main_image_generation import generate_image
from services.podcast_bible_service import PodcastBibleService
from utils.asset_tracker import save_asset_to_library
from loguru import logger
from ..models import PodcastAnalyzeRequest, PodcastAnalyzeResponse
from ..constants import PODCAST_IMAGES_DIR
from ..models import (
PodcastAnalyzeRequest,
PodcastAnalyzeResponse,
PodcastEnhanceIdeaRequest,
PodcastEnhanceIdeaResponse
)
router = APIRouter()
@router.post("/idea/enhance", response_model=PodcastEnhanceIdeaResponse)
async def enhance_podcast_idea(
request: PodcastEnhanceIdeaRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""
Take raw keywords/topic and use AI to craft a presentable, detailed podcast idea.
Uses the user's Podcast Bible for hyper-personalization if available.
"""
user_id = require_authenticated_user(current_user)
# Serialize Bible context if provided or generate from onboarding
bible_context = ""
try:
bible_service = PodcastBibleService()
if request.bible:
from models.podcast_bible_models import PodcastBible
bible_data = PodcastBible(**request.bible)
bible_context = bible_service.serialize_bible(bible_data)
else:
# Generate from onboarding data directly
bible_obj = bible_service.generate_bible(user_id, "temp_enhance")
bible_context = bible_service.serialize_bible(bible_obj)
except Exception as exc:
logger.warning(f"[Podcast Enhance] Failed to parse or generate bible context: {exc}")
prompt = f"""
You are a creative podcast producer. Your goal is to take a simple podcast idea or keywords
and transform it into a compelling, professional, and detailed episode concept.
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
RAW IDEA/KEYWORDS: "{request.idea}"
TASK:
1. Rewrite the idea into a professional, presentable 2-3 sentence episode pitch.
2. Focus on making it sound expert-led and audience-focused.
3. Ensure it aligns with the host's persona and target audience interests if context was provided.
4. Keep it concise but information-rich.
Return JSON with:
- enhanced_idea: the rewritten, professional episode pitch
- rationale: 1 sentence explaining why this version works better for the target audience
"""
try:
raw = llm_text_gen(prompt=prompt, user_id=user_id, json_struct=None)
# Normalize response
if isinstance(raw, str):
data = json.loads(raw)
else:
data = raw
return PodcastEnhanceIdeaResponse(
enhanced_idea=data.get("enhanced_idea", request.idea),
rationale=data.get("rationale", "Made it more professional and listener-focused.")
)
except Exception as exc:
logger.error(f"[Podcast Enhance] Failed for user {user_id}: {exc}")
return PodcastEnhanceIdeaResponse(
enhanced_idea=request.idea,
rationale="Failed to enhance idea with AI, using original."
)
@router.post("/analyze", response_model=PodcastAnalyzeResponse)
async def analyze_podcast_idea(
request: PodcastAnalyzeRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
db: Session = Depends(get_db),
):
"""
Analyze a podcast idea and return podcast-oriented outlines, keywords, and titles.
This uses the shared LLM provider but with a podcast-specific prompt (not story format).
If no avatar_url is provided, it generates one automatically based on the host's look.
"""
user_id = require_authenticated_user(current_user)
# Serialize Bible context if provided or generate from onboarding
bible_context = ""
bible_obj = None
try:
bible_service = PodcastBibleService()
if request.bible:
from models.podcast_bible_models import PodcastBible
bible_data = PodcastBible(**request.bible)
bible_context = bible_service.serialize_bible(bible_data)
bible_obj = bible_data
else:
# Generate from onboarding data directly
bible_obj = bible_service.generate_bible(user_id, "temp_analyze")
bible_context = bible_service.serialize_bible(bible_obj)
bible_obj = bible_obj
except Exception as exc:
logger.warning(f"[Podcast Analyze] Failed to parse or generate bible context: {exc}")
# --- NEW: Generate Presenter Avatar if missing ---
final_avatar_url = request.avatar_url
final_avatar_prompt = None
if not final_avatar_url:
logger.info(f"[Podcast Analyze] No avatar_url provided, generating one for user {user_id}")
try:
# 1. PRE-FLIGHT VALIDATION: Check subscription limits for image generation
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_image_generation_operations
pricing_service = PricingService(db)
validate_image_generation_operations(
pricing_service=pricing_service,
user_id=user_id,
num_images=1
)
# 2. Build avatar prompt from Bible host look or fallback
host_look = bible_obj.host.look if bible_obj and bible_obj.host.look else "A professional podcast host"
visual_style = bible_obj.visual_style.style_preset if bible_obj else "Realistic Photography"
final_avatar_prompt = f"Professional headshot of a podcast host, {host_look}, {visual_style} style, clean background, soft studio lighting, center-focused, high resolution, sharp focus, professional photography quality, 16:9 aspect ratio."
# 3. Generate the image
logger.info(f"[Podcast Analyze] Generating avatar with prompt: {final_avatar_prompt}")
image_result = generate_image(
prompt=final_avatar_prompt,
user_id=user_id,
width=1024,
height=1024
)
# 4. Save to disk and library
if image_result and image_result.image_bytes:
img_id = str(uuid.uuid4())[:8]
filename = f"presenter_podcast_{user_id}_{img_id}.png"
output_path = PODCAST_IMAGES_DIR / filename
PODCAST_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
with open(output_path, "wb") as f:
f.write(image_result.image_bytes)
final_avatar_url = f"/api/podcast/images/avatars/{filename}"
# Save to asset library for reuse
save_asset_to_library(
db=db,
user_id=user_id,
asset_type="image",
file_url=final_avatar_url,
filename=filename,
title=f"Presenter Avatar - {request.idea[:40]}",
description=f"AI-generated podcast presenter for: {request.idea}",
provider=image_result.provider,
model=image_result.model,
cost=image_result.cost
)
logger.info(f"[Podcast Analyze] ✅ Generated and saved avatar to {final_avatar_url}")
except Exception as e:
logger.error(f"[Podcast Analyze] ❌ Failed to generate avatar: {e}")
# Non-fatal: continue analysis even if avatar generation fails
# --- END: Avatar Generation ---
# Incorporate user feedback if provided
feedback_context = ""
if request.feedback:
feedback_context = f"""
USER REGENERATION FEEDBACK:
The user was not satisfied with the previous analysis. They provided the following instructions for improvement:
"{request.feedback}"
Please prioritize this feedback and adjust the analysis accordingly.
"""
prompt = f"""
You are an expert podcast producer. Given a podcast idea, craft concise podcast-ready assets
You are an expert podcast producer and research strategist. Given a podcast idea, craft concise podcast-ready assets
that sound like episode plans (not fiction stories).
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
{feedback_context}
Podcast Idea: "{request.idea}"
Duration: ~{request.duration} minutes
Speakers: {request.speakers} (host + optional guest)
TASK:
1. Define the target audience and content type aligned with the Bible's "Audience DNA" and "Brand DNA".
2. Identify 5 high-impact keywords.
3. Propose 2 episode outlines with factual segments.
4. Suggest 3 titles.
5. IMPORTANT: Generate 4-6 specific research queries for Exa. These queries MUST be highly targeted to the episode's topic, the host's expertise level, and the audience's interests as defined in the Bible.
* Do NOT use generic queries like "latest trends in X".
* DO use queries that look for case studies, specific data points, expert opinions, or contrasting viewpoints that would make for a deep, insightful podcast conversation.
Return JSON with:
- audience: short target audience description
- content_type: podcast style/format
- top_keywords: 5 podcast-relevant keywords/phrases
- suggested_outlines: 2 items, each with title (<=60 chars) and 4-6 short segments (bullet-friendly, factual)
- title_suggestions: 3 concise episode titles (no cliffhanger storytelling)
- exa_suggested_config: suggested Exa search options to power research (keep conservative defaults to control cost), with:
- exa_search_type: "auto" | "neural" | "keyword" (prefer "auto" unless clearly news-heavy)
- title_suggestions: 3 concise episode titles
- research_queries: array of {{"query": "string", "rationale": "string"}}
- exa_suggested_config: suggested Exa search options with:
- exa_search_type: "auto" | "neural" | "keyword"
- exa_category: one of ["research paper","news","company","github","tweet","personal site","pdf","financial report","linkedin profile"]
- exa_include_domains: up to 3 reputable domains to prioritize (optional)
- exa_exclude_domains: up to 3 domains to avoid (optional)
- exa_include_domains: up to 3 reputable domains
- exa_exclude_domains: up to 3 domains
- max_sources: 6-10
- include_statistics: boolean (true if topic needs fresh stats)
- date_range: one of ["last_month","last_3_months","last_year","all_time"] (pick recent if time-sensitive)
- include_statistics: boolean
- date_range: one of ["last_month","last_3_months","last_year","all_time"]
Requirements:
- Keep language factual, actionable, and suited for spoken audio.
- Avoid narrative fiction tone; focus on insights, hooks, objections, and takeaways.
- Prefer 2024-2025 context when relevant.
- Avoid narrative fiction tone.
- Prefer 2024-2025 context.
"""
try:
@@ -82,7 +266,7 @@ Requirements:
top_keywords = data.get("top_keywords") or []
suggested_outlines = data.get("suggested_outlines") or []
title_suggestions = data.get("title_suggestions") or []
research_queries = data.get("research_queries") or []
exa_suggested_config = data.get("exa_suggested_config") or None
return PodcastAnalyzeResponse(
@@ -91,6 +275,10 @@ Requirements:
top_keywords=top_keywords,
suggested_outlines=suggested_outlines,
title_suggestions=title_suggestions,
research_queries=research_queries,
exa_suggested_config=exa_suggested_config,
bible=bible_obj.model_dump() if bible_obj else None,
avatar_url=final_avatar_url,
avatar_prompt=final_avatar_prompt,
)

View File

@@ -86,6 +86,19 @@ async def generate_podcast_scene_image(
logger.info(f"[Podcast] No base avatar URL provided, will generate from scratch")
base_avatar_bytes = None
# Extract Podcast Bible context for hyper-personalization
bible_context = ""
bible_obj = None
if request.bible:
try:
from services.podcast_bible_service import PodcastBibleService
from models.podcast_bible_models import PodcastBible
bible_service = PodcastBibleService()
bible_obj = PodcastBible(**request.bible)
bible_context = bible_service.serialize_bible(bible_obj)
except Exception as exc:
logger.warning(f"[Podcast Image] Failed to serialize podcast bible: {exc}")
# Build optimized prompt for scene image generation
# When base avatar is provided, use Ideogram Character to maintain consistency
# Otherwise, generate from scratch with podcast-optimized prompt
@@ -106,6 +119,14 @@ async def generate_podcast_scene_image(
if request.scene_title:
prompt_parts.append(f"Scene: {request.scene_title}")
# Use Bible visual style if available
if bible_obj:
prompt_parts.append(f"Style: {bible_obj.visual_style.style_preset}")
prompt_parts.append(f"Environment: {bible_obj.visual_style.environment}")
prompt_parts.append(f"Lighting: {bible_obj.visual_style.lighting}")
if bible_obj.host.look:
prompt_parts.append(f"Host Look: {bible_obj.host.look}")
# Scene content insights for visual context
if request.scene_content:
content_preview = request.scene_content[:200].replace("\n", " ").strip()
@@ -127,12 +148,14 @@ async def generate_podcast_scene_image(
prompt_parts.append(f"Topic: {idea_preview}")
# Studio setting (maintains podcast aesthetic)
prompt_parts.extend([
"Professional podcast recording studio",
"Modern microphone setup",
"Clean background, professional lighting",
"16:9 aspect ratio, video-optimized composition"
])
if not bible_obj:
prompt_parts.extend([
"Professional podcast recording studio",
"Modern microphone setup",
"Clean background, professional lighting"
])
prompt_parts.append("16:9 aspect ratio, video-optimized composition")
image_prompt = ", ".join(prompt_parts)
@@ -221,14 +244,22 @@ async def generate_podcast_scene_image(
# Standard generation from scratch (no base avatar provided)
prompt_parts = []
# Core podcast studio elements
prompt_parts.extend([
"Professional podcast recording studio",
"Modern podcast setup with high-quality microphone",
"Clean, minimalist background suitable for video",
"Professional studio lighting with soft, even illumination",
"Podcast host environment, professional and inviting"
])
# Use Bible visual style if available
if bible_obj:
prompt_parts.append(f"Style: {bible_obj.visual_style.style_preset}")
prompt_parts.append(f"Environment: {bible_obj.visual_style.environment}")
prompt_parts.append(f"Lighting: {bible_obj.visual_style.lighting}")
if bible_obj.host.look:
prompt_parts.append(f"Host Look: {bible_obj.host.look}")
else:
# Core podcast studio elements
prompt_parts.extend([
"Professional podcast recording studio",
"Modern podcast setup with high-quality microphone",
"Clean, minimalist background suitable for video",
"Professional studio lighting with soft, even illumination",
"Podcast host environment, professional and inviting"
])
# Scene-specific context
if request.scene_title:
@@ -264,12 +295,13 @@ async def generate_podcast_scene_image(
])
# Style constraints
prompt_parts.extend([
"Realistic photography style, not illustration or cartoon",
"Professional broadcast quality",
"Warm, inviting atmosphere",
"Clean composition with breathing room for avatar placement"
])
if not bible_obj:
prompt_parts.extend([
"Realistic photography style, not illustration or cartoon",
"Professional broadcast quality",
"Warm, inviting atmosphere",
"Clean composition with breathing room for avatar placement"
])
image_prompt = ", ".join(prompt_parts)

View File

@@ -47,6 +47,7 @@ async def create_project(
duration=request.duration,
speakers=request.speakers,
budget_cap=request.budget_cap,
avatar_url=request.avatar_url,
)
return PodcastProjectResponse.model_validate(project)

View File

@@ -1,22 +1,26 @@
"""
Podcast Research Handlers
Research endpoints using Exa provider.
Research endpoints using Exa provider and LLM summarization.
"""
from fastapi import APIRouter, Depends, HTTPException
from typing import Dict, Any
from typing import Dict, Any, List
from types import SimpleNamespace
import json
from middleware.auth_middleware import get_current_user
from api.story_writer.utils.auth import require_authenticated_user
from services.blog_writer.research.exa_provider import ExaResearchProvider
from services.llm_providers.main_text_generation import llm_text_gen
from services.podcast_bible_service import PodcastBibleService
from loguru import logger
from ..models import (
PodcastExaResearchRequest,
PodcastExaResearchResponse,
PodcastExaSource,
PodcastExaConfig,
PodcastResearchInsight,
)
router = APIRouter()
@@ -28,7 +32,8 @@ async def podcast_research_exa(
current_user: Dict[str, Any] = Depends(get_current_user),
):
"""
Run podcast research directly via Exa (no blog writer pipeline).
Run podcast research via Exa and then use LLM to extract deep insights.
Uses Podcast Bible and Analysis context for hyper-personalization.
"""
user_id = require_authenticated_user(current_user)
@@ -47,22 +52,121 @@ async def podcast_research_exa(
)
provider = ExaResearchProvider()
prompt = request.topic
# --- Context Building ---
bible_service = PodcastBibleService()
bible_context = ""
if request.bible:
try:
from models.podcast_bible_models import PodcastBible
bible_data = PodcastBible(**request.bible)
bible_context = bible_service.serialize_bible(bible_data)
except Exception as exc:
logger.warning(f"[Podcast Research] Failed to serialize bible: {exc}")
analysis_context = ""
if request.analysis:
analysis_context = f"""
PODCAST ANALYSIS CONTEXT:
Audience: {request.analysis.get('audience', 'General')}
Content Type: {request.analysis.get('content_type', 'Informative')}
Top Keywords: {', '.join(request.analysis.get('top_keywords', []))}
"""
# Exa search params
industry = request.bible.get("brand", {}).get("industry", "") if request.bible else ""
target_audience = ""
if request.bible:
audience_dna = request.bible.get("audience", {})
if audience_dna:
interests = ", ".join(audience_dna.get("interests", []))
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
try:
# 1. RUN EXA SEARCH
result = await provider.search(
prompt=prompt,
prompt=request.topic,
topic=request.topic,
industry="",
target_audience="",
industry=industry,
target_audience=target_audience,
config=cfg,
user_id=user_id,
)
except Exception as exc:
logger.error(f"[Podcast Exa Research] Failed for user {user_id}: {exc}")
logger.error(f"[Podcast Exa Research] Search failed for user {user_id}: {exc}")
raise HTTPException(status_code=500, detail=f"Exa research failed: {exc}")
# Track usage if available
# 2. EXTRACT INSIGHTS VIA LLM
raw_content = result.get("content", "")
sources = result.get("sources", [])
summary = ""
key_insights = []
if raw_content and sources:
logger.info(f"[Podcast Research] Extracting insights from {len(sources)} sources for user {user_id}")
prompt = f"""
You are an expert research analyst for a high-end podcast production team.
Your task is to analyze the following research data and extract deep, actionable insights for a podcast episode.
PODCAST CONTEXT:
Topic: {request.topic}
{bible_context}
{analysis_context}
RESEARCH DATA (from {len(sources)} sources):
{raw_content}
TASK:
1. Provide a comprehensive summary (2-3 paragraphs) of the most important findings. Use Markdown for formatting (bolding, lists).
2. Extract 3-5 "Key Insights". Each insight should have a title and a detailed explanation.
3. For each insight, identify which source indices (e.g. 1, 2) it was derived from.
NOTE: The research data includes "Key Highlights", "Summaries", and "Excerpts" from various sources.
Pay special attention to the "Key Highlights" sections as they contain the most relevant information extracted by the neural search engine.
Return JSON structure:
{{
"summary": "Detailed markdown summary...",
"key_insights": [
{{
"title": "Insight Title",
"content": "Detailed markdown content...",
"source_indices": [1, 2]
}}
]
}}
Requirements:
- Ensure insights are deep, not just superficial facts. Look for trends, expert opinions, and specific data points.
- Tone should be professional, insightful, and ready for a podcast host to discuss.
- Avoid generic filler.
"""
try:
llm_response = llm_text_gen(prompt=prompt, user_id=user_id, json_struct=None)
# Normalize response
if isinstance(llm_response, str):
data = json.loads(llm_response)
else:
data = llm_response
summary = data.get("summary", "")
key_insights = [PodcastResearchInsight(**insight) for insight in data.get("key_insights", [])]
except Exception as exc:
logger.error(f"[Podcast Research] LLM Insight extraction failed: {exc}")
# Fallback to a basic summary if LLM fails
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
# Fallback: if summary is still empty (e.g. LLM returned empty string), use raw content first paragraph or basic text
if not summary:
if raw_content:
summary = raw_content[:2000] # Use first 2000 chars of raw content as summary
else:
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
# 3. TRACK USAGE
try:
cost_total = 0.0
if isinstance(result, dict):
@@ -72,28 +176,31 @@ async def podcast_research_exa(
logger.warning(f"[Podcast Exa Research] Failed to track usage: {track_err}")
sources_payload = []
if isinstance(result, dict):
for src in result.get("sources", []) or []:
try:
sources_payload.append(PodcastExaSource(**src))
except Exception:
sources_payload.append(PodcastExaSource(**{
"title": src.get("title", ""),
"url": src.get("url", ""),
"excerpt": src.get("excerpt", ""),
"published_at": src.get("published_at"),
"highlights": src.get("highlights"),
"summary": src.get("summary"),
"source_type": src.get("source_type"),
"index": src.get("index"),
}))
for src in sources:
try:
sources_payload.append(PodcastExaSource(**src))
except Exception:
sources_payload.append(PodcastExaSource(**{
"title": src.get("title", ""),
"url": src.get("url", ""),
"excerpt": src.get("excerpt", ""),
"published_at": src.get("published_at"),
"highlights": src.get("highlights"),
"summary": src.get("summary"),
"source_type": src.get("source_type"),
"index": src.get("index"),
"image": src.get("image"),
"author": src.get("author"),
}))
return PodcastExaResearchResponse(
sources=sources_payload,
search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries,
summary=summary,
key_insights=key_insights,
cost=result.get("cost") if isinstance(result, dict) else None,
search_type=result.get("search_type") if isinstance(result, dict) else None,
provider=result.get("provider", "exa") if isinstance(result, dict) else "exa",
content=result.get("content") if isinstance(result, dict) else None,
content=raw_content,
)

View File

@@ -11,6 +11,8 @@ import json
from middleware.auth_middleware import get_current_user
from api.story_writer.utils.auth import require_authenticated_user
from services.llm_providers.main_text_generation import llm_text_gen
from services.podcast_bible_service import PodcastBibleService
from models.podcast_bible_models import PodcastBible
from loguru import logger
from ..models import (
PodcastScriptRequest,
@@ -62,8 +64,39 @@ async def generate_podcast_script(
logger.warning(f"Failed to parse research context: {exc}")
research_context = ""
# Extract Podcast Bible context for hyper-personalization
bible_context = ""
if request.bible:
try:
bible_service = PodcastBibleService()
bible_obj = PodcastBible(**request.bible)
bible_context = bible_service.serialize_bible(bible_obj)
except Exception as exc:
logger.warning(f"Failed to serialize podcast bible: {exc}")
# Extract Analysis and Outline context for grounding
analysis_context = ""
if request.analysis:
analysis_context = f"""
TARGET AUDIENCE: {request.analysis.get('audience', 'General')}
CONTENT TYPE: {request.analysis.get('contentType', 'Conversational')}
TOP KEYWORDS: {', '.join(request.analysis.get('topKeywords', []))}
"""
outline_context = ""
if request.outline:
outline_context = f"""
REFINED EPISODE OUTLINE (Follow this structure closely):
Title: {request.outline.get('title', 'N/A')}
Segments: {' | '.join(request.outline.get('segments', []))}
"""
prompt = f"""You are an expert podcast script planner. Create natural, conversational podcast scenes.
{f"PODCAST BIBLE (Hyper-Personalization Context):\n{bible_context}\n" if bible_context else ""}
{f"ANALYSIS CONTEXT:\n{analysis_context}\n" if analysis_context else ""}
{f"REFINED OUTLINE:\n{outline_context}\n" if outline_context else ""}
Podcast Idea: "{request.idea}"
Duration: ~{request.duration_minutes} minutes
Speakers: {request.speakers} (Host + optional Guest)
@@ -83,11 +116,13 @@ Return JSON with:
* Mark "emphasis": true for key statistics or important points
Guidelines:
- Write for spoken delivery: conversational, natural, with contractions
- Use research insights naturally - weave statistics into dialogue, don't just list them
- Vary emotion per scene based on content
- Ensure scenes match target duration: aim for ~2.5 words per second of audio
- Keep it engaging and informative, like a real podcast conversation
- Write for spoken delivery: conversational, natural, with contractions.
- Follow the interaction tone specified in the Bible.
- Ensure the Host persona matches the background and personality traits from the Bible.
- Structure the intro and outro scenes according to the Bible's "Intro Format" and "Outro Format".
- Adhere to any constraints mentioned in the Bible.
- Use insights from the Research Context to ground the conversation in facts.
- IMPORTANT: Follow the REFINED OUTLINE segments as the primary structure for the episode.
"""
try:

View File

@@ -14,7 +14,7 @@ import re
import json
from concurrent.futures import ThreadPoolExecutor
from services.database import get_db
from services.database import get_session_for_user
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
from api.story_writer.utils.auth import require_authenticated_user
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
@@ -105,6 +105,34 @@ def _execute_podcast_video_task(
scene_number_match = re.search(r'\d+', request.scene_id)
scene_number = int(scene_number_match.group()) if scene_number_match else 0
# Fetch project context (Bible & Analysis) from DB if not provided in request
from services.database import get_session_for_user
from services.podcast_service import PodcastService
project_bible = request.bible
project_analysis = None
try:
# Create a dedicated session for this background task
db = get_session_for_user(user_id)
try:
podcast_service = PodcastService(db)
# Fetch project directly from DB to get latest analysis/bible
project = podcast_service.get_project(user_id, request.project_id)
if project:
# Use project bible if request didn't provide one
if not project_bible and project.bible:
project_bible = project.bible
# Get analysis for better context
if project.analysis:
project_analysis = project.analysis
logger.info(f"[Podcast] Loaded analysis for video context: {list(project_analysis.keys())}")
finally:
db.close()
except Exception as e:
logger.warning(f"[Podcast] Failed to fetch project context for video generation: {e}")
# Prepare scene data for animation
scene_data = {
"scene_number": scene_number,
@@ -114,6 +142,8 @@ def _execute_podcast_video_task(
story_context = {
"project_id": request.project_id,
"type": "podcast",
"bible": project_bible,
"analysis": project_analysis,
}
animation_result = animate_scene_with_voiceover(
@@ -207,8 +237,8 @@ def _execute_podcast_video_task(
@router.post("/render/video", response_model=PodcastVideoGenerationResponse)
async def generate_podcast_video(
request_obj: Request,
request: PodcastVideoGenerationRequest,
request: Request,
body: PodcastVideoGenerationRequest,
background_tasks: BackgroundTasks,
current_user: Dict[str, Any] = Depends(get_current_user),
):
@@ -216,22 +246,46 @@ async def generate_podcast_video(
Generate video for a podcast scene using WaveSpeed InfiniteTalk (avatar image + audio).
Returns task_id for polling since InfiniteTalk can take up to 10 minutes.
"""
# Debug logging to identify "Depends object has no attribute get" error source
logger.info(f"[Podcast] generate_podcast_video called. current_user type: {type(current_user)}")
# Check if current_user is a Depends object (FastAPI injection failure)
if hasattr(current_user, "dependency"):
logger.error(f"[Podcast] CRITICAL: current_user is a Depends object! Dependency injection failed.")
# Attempt to manually resolve or fail gracefully
auth_header = None
try:
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
auth_header = request.headers.get("Authorization")
except:
pass
if auth_header and auth_header.startswith("Bearer "):
token = auth_header.replace("Bearer ", "").strip()
# Manually verify token if dependency injection failed
from middleware.auth_middleware import clerk_auth
current_user = await clerk_auth.verify_token(token)
if not current_user:
raise HTTPException(status_code=401, detail="Authentication failed (manual recovery)")
else:
raise HTTPException(status_code=401, detail="Authentication failed (injection error)")
user_id = require_authenticated_user(current_user)
logger.info(
f"[Podcast] Starting video generation for project {request.project_id}, scene {request.scene_id}"
f"[Podcast] Starting video generation for project {body.project_id}, scene {body.scene_id}"
)
# Load audio bytes
audio_bytes = load_podcast_audio_bytes(request.audio_url)
audio_bytes = load_podcast_audio_bytes(body.audio_url)
# Validate resolution
if request.resolution not in {"480p", "720p"}:
if body.resolution not in {"480p", "720p"}:
raise HTTPException(status_code=400, detail="Resolution must be '480p' or '720p'.")
# Load image bytes (scene image is required for video generation)
if request.avatar_image_url:
image_bytes = load_podcast_image_bytes(request.avatar_image_url)
if body.avatar_image_url:
image_bytes = load_podcast_image_bytes(body.avatar_image_url)
else:
# Scene-specific image should be generated before video generation
raise HTTPException(
@@ -240,9 +294,9 @@ async def generate_podcast_video(
)
mask_image_bytes = None
if request.mask_image_url:
if body.mask_image_url:
try:
mask_image_bytes = load_podcast_image_bytes(request.mask_image_url)
mask_image_bytes = load_podcast_image_bytes(body.mask_image_url)
except Exception as e:
logger.error(f"[Podcast] Failed to load mask image: {e}")
raise HTTPException(
@@ -251,7 +305,9 @@ async def generate_podcast_video(
)
# Validate subscription limits
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise HTTPException(status_code=500, detail="Database session unavailable for user.")
try:
pricing_service = PricingService(db)
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
@@ -260,16 +316,20 @@ async def generate_podcast_video(
# Extract token for authenticated URL building
auth_token = None
auth_header = request_obj.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.replace("Bearer ", "").strip()
try:
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.replace("Bearer ", "").strip()
except Exception as e:
logger.warning(f"[Podcast] Failed to extract auth token from headers: {e}")
# Create async task
task_id = task_manager.create_task("podcast_video_generation")
background_tasks.add_task(
_execute_podcast_video_task,
task_id=task_id,
request=request,
request=body,
user_id=user_id,
image_bytes=image_bytes,
audio_bytes=audio_bytes,

View File

@@ -25,6 +25,7 @@ class PodcastProjectResponse(BaseModel):
raw_research: Optional[Dict[str, Any]] = None
estimate: Optional[Dict[str, Any]] = None
script_data: Optional[Dict[str, Any]] = None
bible: Optional[Dict[str, Any]] = None
render_jobs: Optional[List[Dict[str, Any]]] = None
knobs: Optional[Dict[str, Any]] = None
research_provider: Optional[str] = None
@@ -34,6 +35,9 @@ class PodcastProjectResponse(BaseModel):
status: str = "draft"
is_favorite: bool = False
final_video_url: Optional[str] = None
avatar_url: Optional[str] = None
avatar_prompt: Optional[str] = None
avatar_persona_id: Optional[str] = None
created_at: datetime
updated_at: datetime
@@ -46,6 +50,9 @@ class PodcastAnalyzeRequest(BaseModel):
idea: str = Field(..., description="Podcast topic or idea")
duration: int = Field(default=10, description="Target duration in minutes")
speakers: int = Field(default=1, description="Number of speakers")
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
avatar_url: Optional[str] = Field(None, description="Current avatar URL if selected")
feedback: Optional[str] = Field(None, description="User feedback for regeneration")
class PodcastAnalyzeResponse(BaseModel):
@@ -55,7 +62,23 @@ class PodcastAnalyzeResponse(BaseModel):
top_keywords: list[str]
suggested_outlines: list[Dict[str, Any]]
title_suggestions: list[str]
research_queries: Optional[List[Dict[str, str]]] = None
exa_suggested_config: Optional[Dict[str, Any]] = None
bible: Optional[Dict[str, Any]] = None
avatar_url: Optional[str] = None
avatar_prompt: Optional[str] = None
class PodcastEnhanceIdeaRequest(BaseModel):
"""Request model for enhancing a podcast idea with AI."""
idea: str = Field(..., description="The raw podcast idea or keywords")
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
class PodcastEnhanceIdeaResponse(BaseModel):
"""Response model for enhanced podcast idea."""
enhanced_idea: str
rationale: str
class PodcastScriptRequest(BaseModel):
@@ -64,6 +87,9 @@ class PodcastScriptRequest(BaseModel):
duration_minutes: int = Field(default=10, description="Target duration in minutes")
speakers: int = Field(default=1, description="Number of speakers")
research: Optional[Dict[str, Any]] = Field(None, description="Optional research payload to ground the script")
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
outline: Optional[Dict[str, Any]] = Field(None, description="The refined episode outline to follow")
analysis: Optional[Dict[str, Any]] = Field(None, description="The full analysis context (audience, keywords, etc.)")
class PodcastSceneLine(BaseModel):
@@ -106,6 +132,8 @@ class PodcastExaResearchRequest(BaseModel):
topic: str
queries: List[str]
exa_config: Optional[PodcastExaConfig] = None
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
analysis: Optional[Dict[str, Any]] = Field(None, description="Podcast analysis context (audience, content type, etc.)")
class PodcastExaSource(BaseModel):
@@ -117,15 +145,26 @@ class PodcastExaSource(BaseModel):
summary: Optional[str] = None
source_type: Optional[str] = None
index: Optional[int] = None
image: Optional[str] = None
author: Optional[str] = None
class PodcastResearchInsight(BaseModel):
"""Deep insight extracted from research."""
title: str
content: str
source_indices: List[int] = []
class PodcastExaResearchResponse(BaseModel):
sources: List[PodcastExaSource]
search_queries: List[str] = []
summary: str = ""
key_insights: List[PodcastResearchInsight] = []
cost: Optional[Dict[str, Any]] = None
search_type: Optional[str] = None
provider: str = "exa"
content: Optional[str] = None
content: Optional[str] = None # Raw aggregated content (deprecated)
class PodcastScriptResponse(BaseModel):
@@ -191,6 +230,7 @@ class UpdateProjectRequest(BaseModel):
raw_research: Optional[Dict[str, Any]] = None
estimate: Optional[Dict[str, Any]] = None
script_data: Optional[Dict[str, Any]] = None
bible: Optional[Dict[str, Any]] = None
render_jobs: Optional[List[Dict[str, Any]]] = None
knobs: Optional[Dict[str, Any]] = None
research_provider: Optional[str] = None
@@ -224,6 +264,7 @@ class PodcastImageRequest(BaseModel):
scene_content: Optional[str] = None # Optional: scene lines text for context
idea: Optional[str] = None # Optional: podcast idea for context
base_avatar_url: Optional[str] = None # Base avatar image URL for scene variations
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
width: int = 1024
height: int = 1024
custom_prompt: Optional[str] = None # Custom prompt from user (overrides auto-generated prompt)
@@ -252,6 +293,7 @@ class PodcastVideoGenerationRequest(BaseModel):
scene_title: str = Field(..., description="Scene title")
audio_url: str = Field(..., description="URL to the generated audio file")
avatar_image_url: Optional[str] = Field(None, description="URL to scene image (required for video generation)")
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
resolution: str = Field("720p", description="Video resolution (480p or 720p)")
prompt: Optional[str] = Field(None, description="Optional animation prompt override")
seed: Optional[int] = Field(-1, description="Random seed; -1 for random")

View File

@@ -524,6 +524,80 @@ async def get_semantic_cache_stats(current_user: dict = Depends(get_current_user
"memory_usage_mb": 0.0
}
async def get_sif_indexing_health(current_user: dict = Depends(get_current_user)) -> Dict[str, Any]:
try:
from models.website_analysis_monitoring_models import SIFIndexingTask, SIFIndexingExecutionLog
user_id = str(current_user.get("id"))
db = get_session_for_user(user_id)
if not db:
raise HTTPException(status_code=500, detail="Database connection unavailable")
try:
tasks = (
db.query(SIFIndexingTask)
.filter(SIFIndexingTask.user_id == user_id)
.order_by(SIFIndexingTask.created_at.desc())
.all()
)
if not tasks:
return {
"has_task": False,
"status": "not_scheduled",
"message": "SIF indexing task not yet scheduled for this website.",
}
latest = tasks[0]
latest_log = (
db.query(SIFIndexingExecutionLog)
.filter(SIFIndexingExecutionLog.task_id == latest.id)
.order_by(SIFIndexingExecutionLog.execution_date.desc())
.first()
)
last_run_status = latest_log.status if latest_log else None
last_run_time = (
latest_log.execution_date.isoformat() if latest_log and latest_log.execution_date else None
)
last_error = (
(latest_log.error_message or "")[:500] if latest_log and latest_log.error_message else None
)
overall_status = "healthy"
if latest.consecutive_failures and latest.consecutive_failures > 0:
overall_status = "warning"
if latest.status in {"needs_intervention"}:
overall_status = "critical"
return {
"has_task": True,
"status": overall_status,
"task": {
"id": latest.id,
"website_url": latest.website_url,
"raw_status": latest.status,
"next_execution": latest.next_execution.isoformat() if latest.next_execution else None,
"last_success": latest.last_success.isoformat() if latest.last_success else None,
"last_failure": latest.last_failure.isoformat() if latest.last_failure else None,
"consecutive_failures": latest.consecutive_failures or 0,
"failure_pattern": latest.failure_pattern,
},
"last_run": {
"status": last_run_status,
"time": last_run_time,
"error_message": last_error,
},
}
finally:
db.close()
except HTTPException:
raise
except Exception as e:
logger.error(f"Failed to get SIF indexing health: {e}")
raise HTTPException(status_code=500, detail="Failed to get SIF indexing health")
# New comprehensive SEO analysis endpoints
async def analyze_seo_comprehensive(request: SEOAnalysisRequest) -> SEOAnalysisResponse:
"""

View File

@@ -0,0 +1,73 @@
"""
Story Project API Models
Pydantic models for Story Studio project endpoints.
"""
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field
class StoryProjectResponse(BaseModel):
id: int
project_id: str
user_id: str
title: Optional[str] = None
story_mode: Optional[str] = None
story_template: Optional[str] = None
setup: Optional[Dict[str, Any]] = None
outline: Optional[Dict[str, Any]] = None
scenes: Optional[List[Dict[str, Any]]] = None
story_content: Optional[Dict[str, Any]] = None
anime_bible: Optional[Dict[str, Any]] = None
media_state: Optional[Dict[str, Any]] = None
current_phase: Optional[str] = None
status: str = "draft"
is_favorite: bool = False
is_complete: bool = False
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class StoryProjectListResponse(BaseModel):
projects: List[StoryProjectResponse]
total: int
limit: int
offset: int
class CreateStoryProjectRequest(BaseModel):
project_id: str = Field(..., description="Unique story project ID")
title: Optional[str] = Field(None, description="Optional story project title or idea")
story_mode: Optional[str] = Field(
None, description="Story mode (marketing or pure) if provided by the UI"
)
story_template: Optional[str] = Field(
None,
description="Optional story template identifier (e.g. product_story, anime_fiction)",
)
setup: Optional[Dict[str, Any]] = Field(
None,
description="Initial story setup payload to persist with the project",
)
class UpdateStoryProjectRequest(BaseModel):
title: Optional[str] = None
story_mode: Optional[str] = None
story_template: Optional[str] = None
setup: Optional[Dict[str, Any]] = None
outline: Optional[Dict[str, Any]] = None
scenes: Optional[List[Dict[str, Any]]] = None
story_content: Optional[Dict[str, Any]] = None
anime_bible: Optional[Dict[str, Any]] = None
media_state: Optional[Dict[str, Any]] = None
current_phase: Optional[str] = None
status: Optional[str] = None
is_complete: Optional[bool] = None

View File

@@ -14,6 +14,7 @@ from .routes import (
media_generation,
scene_animation,
story_content,
story_projects,
story_setup,
story_tasks,
video_generation,
@@ -24,6 +25,7 @@ router = APIRouter(prefix="/api/story", tags=["Story Writer"])
# Include modular routers (order preserved roughly by workflow)
router.include_router(story_setup.router)
router.include_router(story_content.router)
router.include_router(story_projects.router)
router.include_router(story_tasks.router)
router.include_router(media_generation.router)
router.include_router(scene_animation.router)

View File

@@ -65,7 +65,7 @@ async def generate_scene_images(
scene_number=result.get("scene_number", 0),
scene_title=result.get("scene_title", "Untitled"),
image_filename=result.get("image_filename", ""),
image_url=result.get("image_url", ""),
image_url=result.get("image_url") or "",
width=result.get("width", 1024),
height=result.get("height", 1024),
provider=result.get("provider", "unknown"),
@@ -148,7 +148,7 @@ async def regenerate_scene_image(
scene_number=result.get("scene_number", request.scene_number),
scene_title=result.get("scene_title", request.scene_title),
image_filename=result.get("image_filename", ""),
image_url=result.get("image_url", ""),
image_url=result.get("image_url") or "",
width=result.get("width", request.width or 1024),
height=result.get("height", request.height or 1024),
provider=result.get("provider", "unknown"),

View File

@@ -12,6 +12,10 @@ from models.story_models import (
StoryScene,
StoryContinueRequest,
StoryContinueResponse,
AnimeSceneTextRequest,
AnimeSceneTextResponse,
AnimeSceneGenerateRequest,
AnimeSceneGenerateResponse,
)
from services.story_writer.story_service import StoryWriterService
@@ -107,6 +111,7 @@ async def generate_story_start(
content_rating=request.content_rating,
ending_preference=request.ending_preference,
story_length=story_length,
anime_bible=getattr(request, "anime_bible", None),
user_id=user_id,
)
@@ -211,6 +216,7 @@ async def continue_story(
audience_age_group=request.audience_age_group,
content_rating=request.content_rating,
ending_preference=request.ending_preference,
anime_bible=getattr(request, "anime_bible", None),
story_length=story_length,
user_id=user_id,
)
@@ -245,6 +251,105 @@ async def continue_story(
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/anime/scene-text", response_model=AnimeSceneTextResponse)
async def refine_anime_scene_text(
request: AnimeSceneTextRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimeSceneTextResponse:
try:
user_id = require_authenticated_user(current_user)
scene_dict = request.scene.dict()
if not scene_dict.get("title") and not scene_dict.get("description"):
raise HTTPException(status_code=400, detail="Scene title or description is required")
refined = story_service.refine_anime_scene_text(
scene=scene_dict,
persona=request.persona,
story_setting=request.story_setting,
character_input=request.character_input,
plot_elements=request.plot_elements,
writing_style=request.writing_style,
story_tone=request.story_tone,
narrative_pov=request.narrative_pov,
audience_age_group=request.audience_age_group,
content_rating=request.content_rating,
anime_bible=request.anime_bible,
user_id=user_id,
)
refined_scene = StoryScene(
scene_number=refined.get("scene_number", request.scene.scene_number),
title=refined.get("title", request.scene.title),
description=refined.get("description", request.scene.description),
image_prompt=refined.get("image_prompt", request.scene.image_prompt),
audio_narration=refined.get("audio_narration", request.scene.audio_narration),
character_descriptions=refined.get(
"character_descriptions", request.scene.character_descriptions
),
key_events=refined.get("key_events", request.scene.key_events),
)
return AnimeSceneTextResponse(scene=refined_scene, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to refine anime scene text: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/anime/scene-generate", response_model=AnimeSceneGenerateResponse)
async def generate_anime_scene_from_bible(
request: AnimeSceneGenerateRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> AnimeSceneGenerateResponse:
try:
user_id = require_authenticated_user(current_user)
if not request.anime_bible:
raise HTTPException(status_code=400, detail="Anime story bible is required")
previous_scenes_payload: Optional[List[Dict[str, Any]]] = None
if request.previous_scenes:
previous_scenes_payload = [scene.dict() for scene in request.previous_scenes]
generated = story_service.generate_anime_scene_from_bible(
premise=request.premise,
persona=request.persona,
story_setting=request.story_setting,
character_input=request.character_input,
plot_elements=request.plot_elements,
writing_style=request.writing_style,
story_tone=request.story_tone,
narrative_pov=request.narrative_pov,
audience_age_group=request.audience_age_group,
content_rating=request.content_rating,
anime_bible=request.anime_bible,
previous_scenes=previous_scenes_payload,
target_scene_number=request.target_scene_number,
user_id=user_id,
)
scene = StoryScene(
scene_number=generated.get("scene_number"),
title=generated.get("title", ""),
description=generated.get("description", ""),
image_prompt=generated.get("image_prompt", ""),
audio_narration=generated.get("audio_narration", ""),
character_descriptions=generated.get("character_descriptions") or [],
key_events=generated.get("key_events") or [],
)
return AnimeSceneGenerateResponse(scene=scene, success=True)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to generate anime scene from bible: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
class SceneApprovalRequest(BaseModel):
project_id: str = Field(..., min_length=1)
scene_id: str = Field(..., min_length=1)

View File

@@ -0,0 +1,189 @@
from typing import Any, Dict, Optional
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from middleware.auth_middleware import get_current_user
from services.database import get_db
from services.story_writer.story_project_service import StoryProjectService
from ..models_projects import (
CreateStoryProjectRequest,
StoryProjectListResponse,
StoryProjectResponse,
UpdateStoryProjectRequest,
)
router = APIRouter()
@router.post("/projects", response_model=StoryProjectResponse, status_code=201)
async def create_story_project(
request: CreateStoryProjectRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryProjectResponse:
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = StoryProjectService(db)
existing = service.get_project(user_id, request.project_id)
if existing:
raise HTTPException(status_code=400, detail="Project ID already exists")
project = service.create_project(
user_id=user_id,
project_id=request.project_id,
title=request.title,
story_mode=request.story_mode,
story_template=request.story_template,
setup=request.setup,
)
return StoryProjectResponse.model_validate(project)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error creating story project: {str(e)}")
@router.get("/projects/{project_id}", response_model=StoryProjectResponse)
async def get_story_project(
project_id: str,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryProjectResponse:
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = StoryProjectService(db)
project = service.get_project(user_id, project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return StoryProjectResponse.model_validate(project)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error fetching story project: {str(e)}")
@router.put("/projects/{project_id}", response_model=StoryProjectResponse)
async def update_story_project(
project_id: str,
request: UpdateStoryProjectRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryProjectResponse:
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = StoryProjectService(db)
updates = request.model_dump(exclude_unset=True)
project = service.update_project(user_id, project_id, **updates)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return StoryProjectResponse.model_validate(project)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error updating story project: {str(e)}")
@router.get("/projects", response_model=StoryProjectListResponse)
async def list_story_projects(
status: Optional[str] = Query(None, description="Filter by status"),
favorites_only: bool = Query(False, description="Only favorites"),
limit: int = Query(50, ge=1, le=200),
offset: int = Query(0, ge=0),
order_by: str = Query("updated_at", description="Order by: updated_at or created_at"),
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryProjectListResponse:
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
if order_by not in ["updated_at", "created_at"]:
raise HTTPException(status_code=400, detail="order_by must be 'updated_at' or 'created_at'")
service = StoryProjectService(db)
projects, total = service.list_projects(
user_id=user_id,
status=status,
favorites_only=favorites_only,
limit=limit,
offset=offset,
order_by=order_by,
)
return StoryProjectListResponse(
projects=[StoryProjectResponse.model_validate(p) for p in projects],
total=total,
limit=limit,
offset=offset,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error listing story projects: {str(e)}")
@router.delete("/projects/{project_id}", status_code=204)
async def delete_story_project(
project_id: str,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
) -> None:
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = StoryProjectService(db)
deleted = service.delete_project(user_id, project_id)
if not deleted:
raise HTTPException(status_code=404, detail="Project not found")
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error deleting story project: {str(e)}")
@router.post("/projects/{project_id}/favorite", response_model=StoryProjectResponse)
async def toggle_story_project_favorite(
project_id: str,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryProjectResponse:
try:
user_id = current_user.get("user_id") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found")
service = StoryProjectService(db)
project = service.toggle_favorite(user_id, project_id)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return StoryProjectResponse.model_validate(project)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"Error toggling story project favorite: {str(e)}")

View File

@@ -2,6 +2,8 @@ from typing import Any, Dict, List
from fastapi import APIRouter, Depends, HTTPException
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy import desc
from middleware.auth_middleware import get_current_user
from models.story_models import (
@@ -13,8 +15,14 @@ from models.story_models import (
StoryScene,
StoryStartRequest,
StoryPremiseResponse,
StoryIdeaEnhanceRequest,
StoryIdeaEnhanceResponse,
StoryIdeaEnhanceSuggestion,
)
from services.story_writer.story_service import StoryWriterService
from api.onboarding_utils.onboarding_summary_service import OnboardingSummaryService
from services.database import get_session_for_user
from models.content_asset_models import ContentAsset, AssetType, AssetSource
from ..utils.auth import require_authenticated_user
@@ -39,6 +47,9 @@ async def generate_story_setup(
options = story_service.generate_story_setup_options(
story_idea=request.story_idea,
story_mode=request.story_mode,
story_template=request.story_template,
brand_context=request.brand_context,
user_id=user_id,
)
@@ -52,6 +63,152 @@ async def generate_story_setup(
raise HTTPException(status_code=500, detail=str(exc))
@router.post("/enhance-idea", response_model=StoryIdeaEnhanceResponse)
async def enhance_story_idea(
request: StoryIdeaEnhanceRequest,
current_user: Dict[str, Any] = Depends(get_current_user),
) -> StoryIdeaEnhanceResponse:
try:
user_id = require_authenticated_user(current_user)
if not request.story_idea or not request.story_idea.strip():
raise HTTPException(status_code=400, detail="Story idea is required")
logger.info(f"[StoryWriter] Enhancing story idea for user {user_id}")
suggestions = story_service.enhance_story_idea(
story_idea=request.story_idea,
story_mode=request.story_mode,
story_template=request.story_template,
brand_context=request.brand_context,
user_id=user_id,
fiction_variant=request.fiction_variant,
narrative_energy=request.narrative_energy,
)
return StoryIdeaEnhanceResponse(
suggestions=[StoryIdeaEnhanceSuggestion(**s) for s in suggestions],
success=True,
)
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to enhance story idea: {exc}")
raise HTTPException(status_code=500, detail=str(exc))
@router.get("/context")
async def get_story_context(
current_user: Dict[str, Any] = Depends(get_current_user),
) -> Dict[str, Any]:
"""Return onboarding-based story context for the current user."""
try:
user_id = require_authenticated_user(current_user)
summary_service = OnboardingSummaryService(user_id)
summary = await summary_service.get_onboarding_summary()
canonical_profile = summary.get("canonical_profile") or {}
persona_readiness = summary.get("persona_readiness") or {}
capabilities = summary.get("capabilities") or {}
website_url = summary.get("website_url")
style_analysis = summary.get("style_analysis") or {}
research_preferences = summary.get("research_preferences") or {}
brand_name = None
if isinstance(style_analysis, dict):
brand_name = style_analysis.get("brand_name") or style_analysis.get("site_title")
writing_tone = canonical_profile.get("writing_tone")
target_audience = canonical_profile.get("target_audience")
brand_context = {
"brand_name": brand_name,
"writing_tone": writing_tone,
"target_audience": target_audience,
}
avatar_url = None
voice_preview_url = None
custom_voice_id = None
db: Session | None = get_session_for_user(user_id)
if db:
try:
avatar_asset = (
db.query(ContentAsset)
.filter(
ContentAsset.user_id == user_id,
ContentAsset.asset_type == AssetType.IMAGE,
ContentAsset.source_module.in_(
[AssetSource.BRAND_AVATAR_GENERATOR, AssetSource.STORY_WRITER]
),
)
.order_by(desc(ContentAsset.created_at))
.limit(50)
.all()
)
selected_avatar = None
for candidate in avatar_asset:
if candidate.source_module == AssetSource.BRAND_AVATAR_GENERATOR:
selected_avatar = candidate
break
meta = candidate.asset_metadata or {}
if meta.get("category") == "brand_avatar":
selected_avatar = candidate
break
if selected_avatar:
avatar_url = selected_avatar.file_url
voice_asset = (
db.query(ContentAsset)
.filter(
ContentAsset.user_id == user_id,
ContentAsset.asset_type == AssetType.AUDIO,
ContentAsset.source_module == AssetSource.VOICE_CLONER,
)
.order_by(desc(ContentAsset.created_at))
.first()
)
if voice_asset:
meta = voice_asset.asset_metadata or {}
voice_preview_url = meta.get("preview_url") or voice_asset.file_url
custom_voice_id = meta.get("custom_voice_id")
finally:
db.close()
persona_enabled = bool(persona_readiness.get("ready")) and bool(
capabilities.get("persona_generation")
)
has_persona_context = persona_enabled and bool(
brand_name or writing_tone or target_audience or avatar_url or voice_preview_url
)
return {
"canonical_profile": canonical_profile,
"website_url": website_url,
"research_preferences": research_preferences,
"brand_context": brand_context,
"brand_assets": {
"avatar_url": avatar_url,
"voice_preview_url": voice_preview_url,
"custom_voice_id": custom_voice_id,
},
"persona_enabled": persona_enabled,
"has_persona_context": has_persona_context,
}
except HTTPException:
raise
except Exception as exc:
logger.error(f"[StoryWriter] Failed to get story context: {exc}")
raise HTTPException(status_code=500, detail="Failed to load story context")
@router.post("/generate-premise", response_model=StoryPremiseResponse)
async def generate_premise(
request: StoryGenerationRequest,
@@ -108,6 +265,9 @@ async def generate_outline(
request.story_tone,
)
# For now, treat all outlines as potentially anime-aware. The downstream
# generation logic will decide whether to actually create a bible based
# on how the prompt is interpreted (e.g., anime templates in persona).
outline = story_service.generate_outline(
premise=request.premise,
persona=request.persona,
@@ -122,15 +282,37 @@ async def generate_outline(
ending_preference=request.ending_preference,
user_id=user_id,
use_structured_output=use_structured,
include_anime_bible=True,
)
if isinstance(outline, list):
scenes: List[StoryScene] = [
StoryScene(**scene) if isinstance(scene, dict) else scene for scene in outline
]
return StoryOutlineResponse(outline=scenes, success=True, is_structured=True)
anime_bible: Dict[str, Any] | None = None
outline_payload: Any = outline
return StoryOutlineResponse(outline=str(outline), success=True, is_structured=False)
if isinstance(outline, dict):
if "anime_bible" in outline:
anime_bible = outline.get("anime_bible")
if "scenes" in outline:
outline_payload = outline.get("scenes")
elif "outline" in outline:
outline_payload = outline.get("outline")
if isinstance(outline_payload, list):
scenes: List[StoryScene] = [
StoryScene(**scene) if isinstance(scene, dict) else scene for scene in outline_payload
]
return StoryOutlineResponse(
outline=scenes,
success=True,
is_structured=True,
anime_bible=anime_bible,
)
return StoryOutlineResponse(
outline=str(outline_payload),
success=True,
is_structured=False,
anime_bible=anime_bible,
)
except HTTPException:
raise

View File

@@ -1,6 +1,7 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
from concurrent.futures import ThreadPoolExecutor
import json
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
from fastapi.responses import FileResponse
@@ -350,9 +351,21 @@ def execute_complete_video_generation(
Runs in a background task and performs blocking operations.
"""
try:
task_manager.update_task_status(task_id, "processing", progress=5.0, message="Starting complete video generation...")
task_manager.update_task_status(
task_id,
"processing",
progress=5.0,
message="Starting complete video generation...",
)
task_manager.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
anime_bible = request_data.get("anime_bible")
task_manager.update_task_status(
task_id,
"processing",
progress=10.0,
message="Generating story premise...",
)
premise = story_service.generate_premise(
persona=request_data["persona"],
story_setting=request_data["story_setting"],
@@ -367,7 +380,12 @@ def execute_complete_video_generation(
user_id=user_id,
)
task_manager.update_task_status(task_id, "processing", progress=20.0, message="Generating structured outline with scenes...")
task_manager.update_task_status(
task_id,
"processing",
progress=20.0,
message="Generating structured outline with scenes...",
)
outline_scenes = story_service.generate_outline(
premise=premise,
persona=request_data["persona"],
@@ -401,6 +419,7 @@ def execute_complete_video_generation(
height=request_data.get("image_height", 1024),
model=request_data.get("image_model"),
progress_callback=image_progress_callback,
anime_bible=anime_bible,
)
task_manager.update_task_status(task_id, "processing", progress=50.0, message="Generating audio narration for scenes...")

View File

@@ -140,7 +140,7 @@ class TaskManager:
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
user_id=user_id,
)
# Step 2: Generate outline
@@ -157,7 +157,7 @@ class TaskManager:
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
user_id=user_id,
)
# Step 3: Generate story start
@@ -175,7 +175,8 @@ class TaskManager:
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
anime_bible=request_data.get("anime_bible"),
user_id=user_id,
)
# Step 4: Continue story
@@ -208,7 +209,8 @@ class TaskManager:
audience_age_group=request_data["audience_age_group"],
content_rating=request_data["content_rating"],
ending_preference=request_data["ending_preference"],
user_id=user_id
anime_bible=request_data.get("anime_bible"),
user_id=user_id,
)
if continuation:

View File

@@ -8,7 +8,7 @@ def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
Validates the current user dictionary provided by Clerk middleware and
returns the normalized user_id. Raises HTTP 401 if authentication fails.
"""
if not current_user:
if not current_user or not isinstance(current_user, dict):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
user_id = str(current_user.get("id", "")).strip()

View File

@@ -12,13 +12,11 @@ from services.user_workspace_manager import UserWorkspaceManager
BASE_DIR = Path(__file__).resolve().parents[4] # root/
DATA_MEDIA_DIR = BASE_DIR / "workspace" / "media"
# Default global media directory matches story image/audio services (root/data/media)
DATA_MEDIA_DIR = BASE_DIR / "data" / "media"
STORY_IMAGES_DIR = (DATA_MEDIA_DIR / "story_images").resolve()
# STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True) # Disabled global creation
STORY_AUDIO_DIR = (DATA_MEDIA_DIR / "story_audio").resolve()
# STORY_AUDIO_DIR.mkdir(parents=True, exist_ok=True) # Disabled global creation
def _get_user_media_path(user_id: str, media_type: str) -> Optional[Path]:

View File

@@ -12,7 +12,10 @@ from .routes import (
alerts,
dashboard,
logs,
preflight
preflight,
payment,
disputes,
fraud_warnings,
)
# Create main router
@@ -26,5 +29,8 @@ router.include_router(alerts.router, tags=["subscription"])
router.include_router(dashboard.router, tags=["subscription"])
router.include_router(logs.router, tags=["subscription"])
router.include_router(preflight.router, tags=["subscription"])
router.include_router(payment.router, tags=["subscription"])
router.include_router(disputes.router, tags=["subscription"])
router.include_router(fraud_warnings.router, tags=["subscription"])
__all__ = ["router"]

View File

@@ -3,6 +3,6 @@ Subscription API Routes
All route modules are imported here for easy access.
"""
from . import usage, plans, subscriptions, alerts, dashboard, logs, preflight
from . import usage, plans, subscriptions, alerts, dashboard, logs, preflight, payment, disputes
__all__ = ["usage", "plans", "subscriptions", "alerts", "dashboard", "logs", "preflight"]
__all__ = ["usage", "plans", "subscriptions", "alerts", "dashboard", "logs", "preflight", "payment", "disputes"]

View File

@@ -0,0 +1,142 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional
from pydantic import BaseModel
from services.database import get_db
from middleware.auth_middleware import get_current_user
from loguru import logger
import stripe
import os
router = APIRouter()
def _ensure_admin(current_user: Dict[str, Any]) -> None:
disable_auth = os.getenv("DISABLE_AUTH", "false").lower() == "true"
if disable_auth:
return
email = (current_user.get("email") or "").lower()
role = None
public_metadata = current_user.get("public_metadata")
if isinstance(public_metadata, dict):
role = public_metadata.get("role") or current_user.get("role")
else:
role = current_user.get("role")
admin_emails_raw = os.getenv("ADMIN_EMAILS", "")
admin_emails = {
e.strip().lower() for e in admin_emails_raw.split(",") if e.strip()
}
admin_domain = (os.getenv("ADMIN_EMAIL_DOMAIN") or "").lower().strip()
is_admin_email = email and email in admin_emails
is_admin_domain = email and admin_domain and email.endswith("@" + admin_domain)
is_admin_role = role == "admin"
if not (is_admin_email or is_admin_domain or is_admin_role):
raise HTTPException(status_code=403, detail="Admin access required for dispute operations")
def _get_stripe_client() -> None:
api_key = os.getenv("STRIPE_SECRET_KEY")
if not api_key:
logger.error("STRIPE_SECRET_KEY is not configured; dispute operations are disabled")
raise HTTPException(status_code=500, detail="Payment service not configured")
stripe.api_key = api_key
class DisputeEvidenceUpdateRequest(BaseModel):
evidence: Optional[Dict[str, Any]] = None
@router.get("/disputes")
async def list_disputes(
limit: int = 10,
starting_after: Optional[str] = None,
ending_before: Optional[str] = None,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_get_stripe_client()
_ensure_admin(current_user)
try:
params: Dict[str, Any] = {"limit": max(1, min(limit, 100))}
if starting_after:
params["starting_after"] = starting_after
if ending_before:
params["ending_before"] = ending_before
disputes = stripe.Dispute.list(**params)
return {"data": disputes}
except Exception as e:
logger.error(f"Error listing disputes: {e}")
raise HTTPException(status_code=500, detail="Failed to list disputes")
@router.get("/disputes/{dispute_id}")
async def get_dispute(
dispute_id: str,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_get_stripe_client()
_ensure_admin(current_user)
try:
dispute = stripe.Dispute.retrieve(dispute_id)
return {"data": dispute}
except stripe.error.InvalidRequestError as e:
logger.warning(f"Invalid dispute id {dispute_id}: {e}")
raise HTTPException(status_code=404, detail="Dispute not found")
except Exception as e:
logger.error(f"Error retrieving dispute {dispute_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to retrieve dispute")
@router.post("/disputes/{dispute_id}")
async def update_dispute(
dispute_id: str,
payload: DisputeEvidenceUpdateRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_get_stripe_client()
_ensure_admin(current_user)
if not payload.evidence:
raise HTTPException(status_code=400, detail="Evidence payload is required to update a dispute")
try:
dispute = stripe.Dispute.modify(
dispute_id,
evidence=payload.evidence,
)
return {"data": dispute}
except stripe.error.InvalidRequestError as e:
logger.warning(f"Invalid dispute id {dispute_id} during update: {e}")
raise HTTPException(status_code=404, detail="Dispute not found")
except Exception as e:
logger.error(f"Error updating dispute {dispute_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to update dispute")
@router.post("/disputes/{dispute_id}/close")
async def close_dispute(
dispute_id: str,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_get_stripe_client()
_ensure_admin(current_user)
try:
dispute = stripe.Dispute.close(dispute_id)
return {"data": dispute}
except stripe.error.InvalidRequestError as e:
logger.warning(f"Invalid dispute id {dispute_id} during close: {e}")
raise HTTPException(status_code=404, detail="Dispute not found")
except Exception as e:
logger.error(f"Error closing dispute {dispute_id}: {e}")
raise HTTPException(status_code=500, detail="Failed to close dispute")

View File

@@ -0,0 +1,209 @@
from fastapi import APIRouter, Depends, HTTPException
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional
from pydantic import BaseModel
from services.database import get_db
from middleware.auth_middleware import get_current_user
from loguru import logger
import stripe
import os
from datetime import datetime
from models.subscription_models import FraudWarning
router = APIRouter()
def _ensure_admin(current_user: Dict[str, Any]) -> None:
disable_auth = os.getenv("DISABLE_AUTH", "false").lower() == "true"
if disable_auth:
return
email = (current_user.get("email") or "").lower()
role = None
public_metadata = current_user.get("public_metadata")
if isinstance(public_metadata, dict):
role = public_metadata.get("role") or current_user.get("role")
else:
role = current_user.get("role")
admin_emails_raw = os.getenv("ADMIN_EMAILS", "")
admin_emails = {
e.strip().lower() for e in admin_emails_raw.split(",") if e.strip()
}
admin_domain = (os.getenv("ADMIN_EMAIL_DOMAIN") or "").lower().strip()
is_admin_email = email and email in admin_emails
is_admin_domain = email and admin_domain and email.endswith("@" + admin_domain)
is_admin_role = role == "admin"
if not (is_admin_email or is_admin_domain or is_admin_role):
raise HTTPException(status_code=403, detail="Admin access required for fraud warning operations")
def _get_stripe_client() -> None:
api_key = os.getenv("STRIPE_SECRET_KEY")
if not api_key:
logger.error("STRIPE_SECRET_KEY is not configured; fraud warning operations are disabled")
raise HTTPException(status_code=500, detail="Payment service not configured")
stripe.api_key = api_key
class FraudWarningRefundRequest(BaseModel):
notes: Optional[str] = None
class FraudWarningIgnoreRequest(BaseModel):
notes: Optional[str] = None
@router.get("/fraud-warnings")
async def list_fraud_warnings(
status: Optional[str] = "open",
limit: int = 20,
offset: int = 0,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_ensure_admin(current_user)
query = db.query(FraudWarning)
if status:
query = query.filter(FraudWarning.status == status)
limit = max(1, min(limit, 100))
items = (
query.order_by(FraudWarning.created_at.desc())
.offset(max(0, offset))
.limit(limit)
.all()
)
data = []
for fw in items:
data.append(
{
"id": fw.id,
"charge_id": fw.charge_id,
"payment_intent_id": fw.payment_intent_id,
"user_id": fw.user_id,
"amount": fw.amount,
"currency": fw.currency,
"status": fw.status,
"action": fw.action,
"action_at": fw.action_at.isoformat() if fw.action_at else None,
"created_at": fw.created_at.isoformat() if fw.created_at else None,
}
)
return {"data": data}
@router.get("/fraud-warnings/{warning_id}")
async def get_fraud_warning(
warning_id: str,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_ensure_admin(current_user)
fw = db.query(FraudWarning).filter(FraudWarning.id == warning_id).first()
if not fw:
raise HTTPException(status_code=404, detail="Fraud warning not found")
payload: Dict[str, Any] = {
"id": fw.id,
"charge_id": fw.charge_id,
"payment_intent_id": fw.payment_intent_id,
"user_id": fw.user_id,
"amount": fw.amount,
"currency": fw.currency,
"status": fw.status,
"action": fw.action,
"action_at": fw.action_at.isoformat() if fw.action_at else None,
"reason_notes": fw.reason_notes,
"created_at": fw.created_at.isoformat() if fw.created_at else None,
"meta_info": fw.meta_info,
}
return {"data": payload}
@router.post("/fraud-warnings/{warning_id}/refund")
async def refund_fraud_warning(
warning_id: str,
payload: FraudWarningRefundRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_ensure_admin(current_user)
_get_stripe_client()
fw = db.query(FraudWarning).filter(FraudWarning.id == warning_id).first()
if not fw:
raise HTTPException(status_code=404, detail="Fraud warning not found")
if fw.status == "refunded":
raise HTTPException(status_code=400, detail="Fraud warning already refunded")
try:
stripe.Refund.create(charge=fw.charge_id)
except stripe.error.InvalidRequestError as e:
logger.warning(f"Refund failed for fraud warning {warning_id}, charge {fw.charge_id}: {e}")
raise HTTPException(status_code=400, detail="Refund failed for this charge")
except Exception as e:
logger.error(f"Unexpected error refunding fraud warning {warning_id}: {e}")
raise HTTPException(status_code=500, detail="Unexpected error while processing refund")
fw.status = "refunded"
fw.action = "refund_full"
fw.action_at = datetime.utcnow()
if payload and payload.notes:
fw.reason_notes = payload.notes
db.commit()
db.refresh(fw)
return {
"data": {
"id": fw.id,
"status": fw.status,
"action": fw.action,
"action_at": fw.action_at.isoformat() if fw.action_at else None,
"reason_notes": fw.reason_notes,
}
}
@router.post("/fraud-warnings/{warning_id}/ignore")
async def ignore_fraud_warning(
warning_id: str,
payload: FraudWarningIgnoreRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
):
_ensure_admin(current_user)
fw = db.query(FraudWarning).filter(FraudWarning.id == warning_id).first()
if not fw:
raise HTTPException(status_code=404, detail="Fraud warning not found")
fw.status = "ignored"
fw.action = "ignored"
fw.action_at = datetime.utcnow()
if payload and payload.notes:
fw.reason_notes = payload.notes
db.commit()
db.refresh(fw)
return {
"data": {
"id": fw.id,
"status": fw.status,
"action": fw.action,
"action_at": fw.action_at.isoformat() if fw.action_at else None,
"reason_notes": fw.reason_notes,
}
}

View File

@@ -0,0 +1,125 @@
from fastapi import APIRouter, Depends, HTTPException, Request, Header, BackgroundTasks
from sqlalchemy.orm import Session
from typing import Dict, Any, Optional
from pydantic import BaseModel
from services.database import get_db
from services.subscription.stripe_service import StripeService
from middleware.auth_middleware import get_current_user
from loguru import logger
from models.subscription_models import SubscriptionTier, BillingCycle
import time
from collections import defaultdict
router = APIRouter()
class CreateCheckoutSessionRequest(BaseModel):
tier: SubscriptionTier
billing_cycle: BillingCycle
success_url: str
cancel_url: str
class CreatePortalSessionRequest(BaseModel):
return_url: str
_checkout_rate_limit_window_seconds = 60
_checkout_rate_limit_max_requests = 10
_checkout_attempts_by_user: Dict[str, Any] = defaultdict(list)
@router.post("/create-checkout-session")
async def create_checkout_session(
payload: CreateCheckoutSessionRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user),
request: Request = None
):
"""
Create a Stripe Checkout Session for subscription.
"""
user_id = current_user.get("sub") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User not authenticated")
now = time.time()
attempts = _checkout_attempts_by_user[user_id]
window_start = now - _checkout_rate_limit_window_seconds
attempts[:] = [ts for ts in attempts if ts >= window_start]
attempts.append(now)
_checkout_attempts_by_user[user_id] = attempts
if len(attempts) > _checkout_rate_limit_max_requests:
client_ip = request.client.host if request and request.client else "unknown"
logger.warning(f"Checkout rate limit exceeded for user_id={user_id}, ip={client_ip}, attempts={len(attempts)} in { _checkout_rate_limit_window_seconds }s")
raise HTTPException(status_code=429, detail="Too many checkout attempts. Please try again shortly.")
user_email = current_user.get("email")
stripe_service = StripeService(db)
try:
url = stripe_service.create_checkout_session(
user_id=user_id,
tier=payload.tier,
billing_cycle=payload.billing_cycle,
success_url=payload.success_url,
cancel_url=payload.cancel_url,
user_email=user_email
)
return {"url": url}
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"Error creating checkout session: {e}")
raise HTTPException(status_code=500, detail="Failed to initiate checkout")
@router.post("/create-portal-session")
async def create_portal_session(
payload: CreatePortalSessionRequest,
db: Session = Depends(get_db),
current_user: Dict[str, Any] = Depends(get_current_user)
):
"""
Create a Stripe Customer Portal session for managing billing.
"""
user_id = current_user.get("sub") or current_user.get("id")
if not user_id:
raise HTTPException(status_code=401, detail="User not authenticated")
stripe_service = StripeService(db)
try:
url = stripe_service.create_portal_session(
user_id=user_id,
return_url=payload.return_url
)
return {"url": url}
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"Error creating portal session: {e}")
raise HTTPException(status_code=500, detail="Failed to access billing portal")
@router.post("/webhook")
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(None),
db: Session = Depends(get_db)
):
"""
Handle Stripe webhooks.
"""
if not stripe_signature:
raise HTTPException(status_code=400, detail="Missing stripe-signature header")
payload = await request.body()
stripe_service = StripeService(db)
try:
# We need to run this potentially in background or await it
# Since it's async, we can await it directly.
await stripe_service.handle_webhook(payload, stripe_signature)
return {"status": "success"}
except HTTPException as e:
raise e
except Exception as e:
logger.error(f"Error processing webhook: {e}")
raise HTTPException(status_code=500, detail="Webhook processing failed")