Release Candidate: Production Release with Multi-Tenant & Onboarding Enhancements
This commit is contained in:
@@ -641,25 +641,18 @@ async def complete_style_detection(
|
||||
return await loop.run_in_executor(None, partial(style_logic.perform_seo_audit, request.url, crawl_result['content']))
|
||||
|
||||
async def run_sitemap_analysis():
|
||||
"""Run AI sitemap analysis for home page"""
|
||||
if not request.url:
|
||||
return None
|
||||
try:
|
||||
# Discover sitemap URL
|
||||
sitemap_url = await sitemap_service.discover_sitemap_url(request.url)
|
||||
if sitemap_url:
|
||||
# Analyze sitemap with AI insights
|
||||
return await sitemap_service.analyze_sitemap(
|
||||
sitemap_url=sitemap_url,
|
||||
analyze_content_trends=True,
|
||||
analyze_publishing_patterns=True,
|
||||
include_ai_insights=True,
|
||||
user_id=user_id
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Sitemap analysis failed: {e}")
|
||||
return None
|
||||
sitemap_url = await sitemap_service.discover_sitemap_url(request.url)
|
||||
if sitemap_url:
|
||||
return await sitemap_service.analyze_sitemap(
|
||||
sitemap_url=sitemap_url,
|
||||
analyze_content_trends=True,
|
||||
analyze_publishing_patterns=True,
|
||||
include_ai_insights=True,
|
||||
user_id=user_id
|
||||
)
|
||||
return None
|
||||
|
||||
# Execute style, patterns, SEO analysis and sitemap analysis in parallel
|
||||
style_analysis, patterns_result, seo_audit_result, sitemap_result = await asyncio.gather(
|
||||
@@ -710,15 +703,17 @@ async def complete_style_detection(
|
||||
elif isinstance(seo_audit_result, Exception):
|
||||
logger.warning(f"SEO audit failed: {seo_audit_result}")
|
||||
|
||||
# Process sitemap analysis result
|
||||
sitemap_analysis = None
|
||||
sitemap_warning = None
|
||||
if sitemap_result and not isinstance(sitemap_result, Exception):
|
||||
sitemap_analysis = sitemap_result
|
||||
elif isinstance(sitemap_result, Exception):
|
||||
logger.warning(f"Sitemap analysis failed: {sitemap_result}")
|
||||
sitemap_warning = f"Sitemap analysis failed: {sitemap_result}"
|
||||
|
||||
# Step 4: Generate guidelines (depends on style_analysis, must run after)
|
||||
style_guidelines = None
|
||||
guidelines_result = None
|
||||
if request.include_guidelines:
|
||||
loop = asyncio.get_event_loop()
|
||||
guidelines_result = await loop.run_in_executor(
|
||||
@@ -728,10 +723,14 @@ async def complete_style_detection(
|
||||
if guidelines_result and guidelines_result.get('success'):
|
||||
style_guidelines = guidelines_result.get('guidelines')
|
||||
|
||||
# Check if there's a warning about fallback data
|
||||
warning = None
|
||||
warning_parts = []
|
||||
if style_analysis and 'warning' in style_analysis:
|
||||
warning = style_analysis['warning']
|
||||
warning_parts.append(style_analysis['warning'])
|
||||
if request.include_guidelines and guidelines_result and not guidelines_result.get('success') and guidelines_result.get('error'):
|
||||
warning_parts.append(f"Guidelines generation failed: {guidelines_result.get('error')}")
|
||||
if sitemap_warning:
|
||||
warning_parts.append(sitemap_warning)
|
||||
warning = " | ".join(warning_parts) if warning_parts else None
|
||||
|
||||
# Prepare response data
|
||||
response_data = {
|
||||
@@ -1000,4 +999,4 @@ async def get_style_detection_configuration():
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"[get_style_detection_configuration] Error: {str(e)}")
|
||||
return {"error": f"Configuration error: {str(e)}"}
|
||||
return {"error": f"Configuration error: {str(e)}"}
|
||||
|
||||
@@ -5,6 +5,7 @@ Handles the complex logic for completing the onboarding process.
|
||||
|
||||
from typing import Dict, Any, List
|
||||
from datetime import datetime, timedelta
|
||||
import os
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
@@ -306,18 +307,20 @@ class OnboardingCompletionService:
|
||||
db = get_session_for_user(user_id)
|
||||
integration_service = OnboardingDataIntegrationService()
|
||||
|
||||
# Debug logging
|
||||
logger.info(f"Validating steps for user {user_id}")
|
||||
|
||||
# Get integrated data
|
||||
integrated_data = await integration_service.process_onboarding_data(user_id, db)
|
||||
db.close()
|
||||
|
||||
from services.onboarding.progress_service import OnboardingProgressService
|
||||
progress_service = OnboardingProgressService()
|
||||
status = progress_service.get_onboarding_status(user_id)
|
||||
current_step = status.get("current_step", 1)
|
||||
|
||||
# Check each required step
|
||||
for step_num in self.required_steps:
|
||||
step_completed = False
|
||||
|
||||
if step_num == 1: # API Keys
|
||||
if step_num == 1:
|
||||
api_keys_data = integrated_data.get('api_keys_data', {})
|
||||
logger.info(f"Step 1 - API Keys: {api_keys_data}")
|
||||
step_completed = bool(
|
||||
@@ -325,26 +328,49 @@ class OnboardingCompletionService:
|
||||
api_keys_data.get('anthropic_api_key') or
|
||||
api_keys_data.get('google_api_key')
|
||||
)
|
||||
if not step_completed:
|
||||
has_global_providers = bool(
|
||||
os.getenv("EXA_API_KEY") or
|
||||
os.getenv("GEMINI_API_KEY") or
|
||||
os.getenv("OPENAI_API_KEY") or
|
||||
os.getenv("ANTHROPIC_API_KEY") or
|
||||
os.getenv("GOOGLE_API_KEY")
|
||||
)
|
||||
if has_global_providers:
|
||||
step_completed = True
|
||||
logger.info(f"Step 1 completed: {step_completed}")
|
||||
elif step_num == 2: # Website Analysis
|
||||
elif step_num == 2:
|
||||
website = integrated_data.get('website_analysis', {})
|
||||
logger.info(f"Step 2 - Website Analysis: {website}")
|
||||
step_completed = bool(website and (website.get('website_url') or website.get('writing_style')))
|
||||
logger.info(f"Step 2 completed: {step_completed}")
|
||||
elif step_num == 3: # Research Preferences
|
||||
elif step_num == 3:
|
||||
research = integrated_data.get('research_preferences', {})
|
||||
logger.info(f"Step 3 - Research Preferences: {research}")
|
||||
step_completed = bool(research and (research.get('research_depth') or research.get('content_types')))
|
||||
logger.info(f"Step 3 completed: {step_completed}")
|
||||
elif step_num == 4: # Persona Generation
|
||||
elif step_num == 4:
|
||||
persona = integrated_data.get('persona_data', {})
|
||||
logger.info(f"Step 4 - Persona Data: {persona}")
|
||||
step_completed = bool(persona and (persona.get('corePersona') or persona.get('platformPersonas')))
|
||||
if not step_completed:
|
||||
website = integrated_data.get('website_analysis', {})
|
||||
research = integrated_data.get('research_preferences', {})
|
||||
basic_ready = bool(
|
||||
website and (website.get('website_url') or website.get('writing_style'))
|
||||
) and bool(research)
|
||||
if basic_ready:
|
||||
step_completed = True
|
||||
logger.info(f"Step 4 completed: {step_completed}")
|
||||
elif step_num == 5: # Integrations
|
||||
# For now, consider this always completed if we reach this point
|
||||
elif step_num == 5:
|
||||
step_completed = True
|
||||
logger.info(f"Step 5 completed: {step_completed}")
|
||||
|
||||
if not step_completed and current_step >= step_num:
|
||||
step_completed = True
|
||||
logger.info(
|
||||
f"Step {step_num} marked completed based on progress service (current_step={current_step})"
|
||||
)
|
||||
|
||||
if not step_completed:
|
||||
missing_steps.append(f"Step {step_num}")
|
||||
@@ -357,20 +383,34 @@ class OnboardingCompletionService:
|
||||
return ["Validation error"]
|
||||
|
||||
async def _validate_api_keys(self, user_id: str):
|
||||
"""Validate that API keys are configured for the current user (SSOT)."""
|
||||
"""Validate that API keys are configured for the current user (SSOT or environment)."""
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
integration_service = OnboardingDataIntegrationService()
|
||||
integrated_data = await integration_service.process_onboarding_data(user_id, db)
|
||||
db.close()
|
||||
try:
|
||||
integration_service = OnboardingDataIntegrationService()
|
||||
integrated_data = await integration_service.process_onboarding_data(user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
api_keys_data = integrated_data.get('api_keys_data', {})
|
||||
api_keys_data = integrated_data.get('api_keys_data', {}) if integrated_data else {}
|
||||
|
||||
has_keys = bool(
|
||||
has_user_keys = bool(
|
||||
api_keys_data.get('openai_api_key') or
|
||||
api_keys_data.get('anthropic_api_key') or
|
||||
api_keys_data.get('google_api_key')
|
||||
api_keys_data.get('google_api_key') or
|
||||
api_keys_data.get('exa_api_key') or
|
||||
api_keys_data.get('gemini_api_key')
|
||||
)
|
||||
|
||||
has_env_keys = bool(
|
||||
os.getenv("OPENAI_API_KEY") or
|
||||
os.getenv("ANTHROPIC_API_KEY") or
|
||||
os.getenv("GOOGLE_API_KEY") or
|
||||
os.getenv("EXA_API_KEY") or
|
||||
os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
|
||||
has_keys = has_user_keys or has_env_keys
|
||||
|
||||
if not has_keys:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -8,7 +8,7 @@ from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.onboarding.api_key_manager import get_api_key_manager
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.website_analysis_service import WebsiteAnalysisService
|
||||
from services.research_preferences_service import ResearchPreferencesService
|
||||
from services.persona_analysis_service import PersonaAnalysisService
|
||||
@@ -32,10 +32,13 @@ class OnboardingSummaryService:
|
||||
async def get_onboarding_summary(self) -> Dict[str, Any]:
|
||||
"""Get comprehensive onboarding summary for FinalStep."""
|
||||
try:
|
||||
# Get integrated data via SSOT
|
||||
db = next(get_db())
|
||||
integrated_data = await self.integration_service.process_onboarding_data(self.user_id, db)
|
||||
db.close()
|
||||
db = get_session_for_user(self.user_id)
|
||||
if not db:
|
||||
raise HTTPException(status_code=500, detail="Database session could not be created")
|
||||
try:
|
||||
integrated_data = await self.integration_service.process_onboarding_data(self.user_id, db)
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Extract components from integrated data
|
||||
website_analysis = integrated_data.get('website_analysis', {})
|
||||
@@ -152,15 +155,34 @@ class OnboardingSummaryService:
|
||||
|
||||
return capabilities
|
||||
|
||||
async def get_website_analysis_data(self) -> Dict[str, Any]:
|
||||
"""Get website analysis data for the user (Step 2 output)."""
|
||||
try:
|
||||
db = get_session_for_user(self.user_id)
|
||||
if not db:
|
||||
raise HTTPException(status_code=500, detail="Database session could not be created")
|
||||
try:
|
||||
integrated_data = await self.integration_service.process_onboarding_data(self.user_id, db)
|
||||
website_analysis = integrated_data.get("website_analysis") or {}
|
||||
return website_analysis
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting website analysis data: {e}")
|
||||
raise
|
||||
|
||||
async def get_research_preferences_data(self) -> Dict[str, Any]:
|
||||
"""Get research preferences data for the user."""
|
||||
try:
|
||||
db = next(get_db())
|
||||
research_prefs_service = ResearchPreferencesService(db)
|
||||
# Use the new method that accepts user_id directly
|
||||
result = research_prefs_service.get_research_preferences_by_user_id(self.user_id)
|
||||
db.close()
|
||||
return result
|
||||
db = get_session_for_user(self.user_id)
|
||||
if not db:
|
||||
raise HTTPException(status_code=500, detail="Database session could not be created")
|
||||
try:
|
||||
research_prefs_service = ResearchPreferencesService(db)
|
||||
result = research_prefs_service.get_research_preferences_by_user_id(self.user_id)
|
||||
return result
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting research preferences data: {e}")
|
||||
raise
|
||||
|
||||
@@ -7,54 +7,238 @@ Analysis endpoint for podcast ideas.
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any
|
||||
import json
|
||||
import uuid
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.llm_providers.main_image_generation import generate_image
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from utils.asset_tracker import save_asset_to_library
|
||||
from loguru import logger
|
||||
from ..models import PodcastAnalyzeRequest, PodcastAnalyzeResponse
|
||||
from ..constants import PODCAST_IMAGES_DIR
|
||||
from ..models import (
|
||||
PodcastAnalyzeRequest,
|
||||
PodcastAnalyzeResponse,
|
||||
PodcastEnhanceIdeaRequest,
|
||||
PodcastEnhanceIdeaResponse
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/idea/enhance", response_model=PodcastEnhanceIdeaResponse)
|
||||
async def enhance_podcast_idea(
|
||||
request: PodcastEnhanceIdeaRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Take raw keywords/topic and use AI to craft a presentable, detailed podcast idea.
|
||||
Uses the user's Podcast Bible for hyper-personalization if available.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Serialize Bible context if provided or generate from onboarding
|
||||
bible_context = ""
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
if request.bible:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
else:
|
||||
# Generate from onboarding data directly
|
||||
bible_obj = bible_service.generate_bible(user_id, "temp_enhance")
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Enhance] Failed to parse or generate bible context: {exc}")
|
||||
|
||||
prompt = f"""
|
||||
You are a creative podcast producer. Your goal is to take a simple podcast idea or keywords
|
||||
and transform it into a compelling, professional, and detailed episode concept.
|
||||
|
||||
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
|
||||
|
||||
RAW IDEA/KEYWORDS: "{request.idea}"
|
||||
|
||||
TASK:
|
||||
1. Rewrite the idea into a professional, presentable 2-3 sentence episode pitch.
|
||||
2. Focus on making it sound expert-led and audience-focused.
|
||||
3. Ensure it aligns with the host's persona and target audience interests if context was provided.
|
||||
4. Keep it concise but information-rich.
|
||||
|
||||
Return JSON with:
|
||||
- enhanced_idea: the rewritten, professional episode pitch
|
||||
- rationale: 1 sentence explaining why this version works better for the target audience
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(prompt=prompt, user_id=user_id, json_struct=None)
|
||||
|
||||
# Normalize response
|
||||
if isinstance(raw, str):
|
||||
data = json.loads(raw)
|
||||
else:
|
||||
data = raw
|
||||
|
||||
return PodcastEnhanceIdeaResponse(
|
||||
enhanced_idea=data.get("enhanced_idea", request.idea),
|
||||
rationale=data.get("rationale", "Made it more professional and listener-focused.")
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Enhance] Failed for user {user_id}: {exc}")
|
||||
return PodcastEnhanceIdeaResponse(
|
||||
enhanced_idea=request.idea,
|
||||
rationale="Failed to enhance idea with AI, using original."
|
||||
)
|
||||
|
||||
|
||||
@router.post("/analyze", response_model=PodcastAnalyzeResponse)
|
||||
async def analyze_podcast_idea(
|
||||
request: PodcastAnalyzeRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Analyze a podcast idea and return podcast-oriented outlines, keywords, and titles.
|
||||
This uses the shared LLM provider but with a podcast-specific prompt (not story format).
|
||||
If no avatar_url is provided, it generates one automatically based on the host's look.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
# Serialize Bible context if provided or generate from onboarding
|
||||
bible_context = ""
|
||||
bible_obj = None
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
if request.bible:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
bible_obj = bible_data
|
||||
else:
|
||||
# Generate from onboarding data directly
|
||||
bible_obj = bible_service.generate_bible(user_id, "temp_analyze")
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
bible_obj = bible_obj
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Analyze] Failed to parse or generate bible context: {exc}")
|
||||
|
||||
# --- NEW: Generate Presenter Avatar if missing ---
|
||||
final_avatar_url = request.avatar_url
|
||||
final_avatar_prompt = None
|
||||
|
||||
if not final_avatar_url:
|
||||
logger.info(f"[Podcast Analyze] No avatar_url provided, generating one for user {user_id}")
|
||||
try:
|
||||
# 1. PRE-FLIGHT VALIDATION: Check subscription limits for image generation
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_image_generation_operations
|
||||
pricing_service = PricingService(db)
|
||||
validate_image_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id,
|
||||
num_images=1
|
||||
)
|
||||
|
||||
# 2. Build avatar prompt from Bible host look or fallback
|
||||
host_look = bible_obj.host.look if bible_obj and bible_obj.host.look else "A professional podcast host"
|
||||
visual_style = bible_obj.visual_style.style_preset if bible_obj else "Realistic Photography"
|
||||
|
||||
final_avatar_prompt = f"Professional headshot of a podcast host, {host_look}, {visual_style} style, clean background, soft studio lighting, center-focused, high resolution, sharp focus, professional photography quality, 16:9 aspect ratio."
|
||||
|
||||
# 3. Generate the image
|
||||
logger.info(f"[Podcast Analyze] Generating avatar with prompt: {final_avatar_prompt}")
|
||||
image_result = generate_image(
|
||||
prompt=final_avatar_prompt,
|
||||
user_id=user_id,
|
||||
width=1024,
|
||||
height=1024
|
||||
)
|
||||
|
||||
# 4. Save to disk and library
|
||||
if image_result and image_result.image_bytes:
|
||||
img_id = str(uuid.uuid4())[:8]
|
||||
filename = f"presenter_podcast_{user_id}_{img_id}.png"
|
||||
output_path = PODCAST_IMAGES_DIR / filename
|
||||
PODCAST_IMAGES_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(image_result.image_bytes)
|
||||
|
||||
final_avatar_url = f"/api/podcast/images/avatars/{filename}"
|
||||
|
||||
# Save to asset library for reuse
|
||||
save_asset_to_library(
|
||||
db=db,
|
||||
user_id=user_id,
|
||||
asset_type="image",
|
||||
file_url=final_avatar_url,
|
||||
filename=filename,
|
||||
title=f"Presenter Avatar - {request.idea[:40]}",
|
||||
description=f"AI-generated podcast presenter for: {request.idea}",
|
||||
provider=image_result.provider,
|
||||
model=image_result.model,
|
||||
cost=image_result.cost
|
||||
)
|
||||
logger.info(f"[Podcast Analyze] ✅ Generated and saved avatar to {final_avatar_url}")
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast Analyze] ❌ Failed to generate avatar: {e}")
|
||||
# Non-fatal: continue analysis even if avatar generation fails
|
||||
|
||||
# --- END: Avatar Generation ---
|
||||
|
||||
# Incorporate user feedback if provided
|
||||
feedback_context = ""
|
||||
if request.feedback:
|
||||
feedback_context = f"""
|
||||
USER REGENERATION FEEDBACK:
|
||||
The user was not satisfied with the previous analysis. They provided the following instructions for improvement:
|
||||
"{request.feedback}"
|
||||
Please prioritize this feedback and adjust the analysis accordingly.
|
||||
"""
|
||||
|
||||
prompt = f"""
|
||||
You are an expert podcast producer. Given a podcast idea, craft concise podcast-ready assets
|
||||
You are an expert podcast producer and research strategist. Given a podcast idea, craft concise podcast-ready assets
|
||||
that sound like episode plans (not fiction stories).
|
||||
|
||||
{f"USER PERSONALIZATION CONTEXT (Podcast Bible):\n{bible_context}\n" if bible_context else ""}
|
||||
{feedback_context}
|
||||
|
||||
Podcast Idea: "{request.idea}"
|
||||
Duration: ~{request.duration} minutes
|
||||
Speakers: {request.speakers} (host + optional guest)
|
||||
|
||||
TASK:
|
||||
1. Define the target audience and content type aligned with the Bible's "Audience DNA" and "Brand DNA".
|
||||
2. Identify 5 high-impact keywords.
|
||||
3. Propose 2 episode outlines with factual segments.
|
||||
4. Suggest 3 titles.
|
||||
5. IMPORTANT: Generate 4-6 specific research queries for Exa. These queries MUST be highly targeted to the episode's topic, the host's expertise level, and the audience's interests as defined in the Bible.
|
||||
* Do NOT use generic queries like "latest trends in X".
|
||||
* DO use queries that look for case studies, specific data points, expert opinions, or contrasting viewpoints that would make for a deep, insightful podcast conversation.
|
||||
|
||||
Return JSON with:
|
||||
- audience: short target audience description
|
||||
- content_type: podcast style/format
|
||||
- top_keywords: 5 podcast-relevant keywords/phrases
|
||||
- suggested_outlines: 2 items, each with title (<=60 chars) and 4-6 short segments (bullet-friendly, factual)
|
||||
- title_suggestions: 3 concise episode titles (no cliffhanger storytelling)
|
||||
- exa_suggested_config: suggested Exa search options to power research (keep conservative defaults to control cost), with:
|
||||
- exa_search_type: "auto" | "neural" | "keyword" (prefer "auto" unless clearly news-heavy)
|
||||
- title_suggestions: 3 concise episode titles
|
||||
- research_queries: array of {{"query": "string", "rationale": "string"}}
|
||||
- exa_suggested_config: suggested Exa search options with:
|
||||
- exa_search_type: "auto" | "neural" | "keyword"
|
||||
- exa_category: one of ["research paper","news","company","github","tweet","personal site","pdf","financial report","linkedin profile"]
|
||||
- exa_include_domains: up to 3 reputable domains to prioritize (optional)
|
||||
- exa_exclude_domains: up to 3 domains to avoid (optional)
|
||||
- exa_include_domains: up to 3 reputable domains
|
||||
- exa_exclude_domains: up to 3 domains
|
||||
- max_sources: 6-10
|
||||
- include_statistics: boolean (true if topic needs fresh stats)
|
||||
- date_range: one of ["last_month","last_3_months","last_year","all_time"] (pick recent if time-sensitive)
|
||||
- include_statistics: boolean
|
||||
- date_range: one of ["last_month","last_3_months","last_year","all_time"]
|
||||
|
||||
Requirements:
|
||||
- Keep language factual, actionable, and suited for spoken audio.
|
||||
- Avoid narrative fiction tone; focus on insights, hooks, objections, and takeaways.
|
||||
- Prefer 2024-2025 context when relevant.
|
||||
- Avoid narrative fiction tone.
|
||||
- Prefer 2024-2025 context.
|
||||
"""
|
||||
|
||||
try:
|
||||
@@ -82,7 +266,7 @@ Requirements:
|
||||
top_keywords = data.get("top_keywords") or []
|
||||
suggested_outlines = data.get("suggested_outlines") or []
|
||||
title_suggestions = data.get("title_suggestions") or []
|
||||
|
||||
research_queries = data.get("research_queries") or []
|
||||
exa_suggested_config = data.get("exa_suggested_config") or None
|
||||
|
||||
return PodcastAnalyzeResponse(
|
||||
@@ -91,6 +275,10 @@ Requirements:
|
||||
top_keywords=top_keywords,
|
||||
suggested_outlines=suggested_outlines,
|
||||
title_suggestions=title_suggestions,
|
||||
research_queries=research_queries,
|
||||
exa_suggested_config=exa_suggested_config,
|
||||
bible=bible_obj.model_dump() if bible_obj else None,
|
||||
avatar_url=final_avatar_url,
|
||||
avatar_prompt=final_avatar_prompt,
|
||||
)
|
||||
|
||||
|
||||
@@ -86,6 +86,19 @@ async def generate_podcast_scene_image(
|
||||
logger.info(f"[Podcast] No base avatar URL provided, will generate from scratch")
|
||||
base_avatar_bytes = None
|
||||
|
||||
# Extract Podcast Bible context for hyper-personalization
|
||||
bible_context = ""
|
||||
bible_obj = None
|
||||
if request.bible:
|
||||
try:
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_service = PodcastBibleService()
|
||||
bible_obj = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Image] Failed to serialize podcast bible: {exc}")
|
||||
|
||||
# Build optimized prompt for scene image generation
|
||||
# When base avatar is provided, use Ideogram Character to maintain consistency
|
||||
# Otherwise, generate from scratch with podcast-optimized prompt
|
||||
@@ -106,6 +119,14 @@ async def generate_podcast_scene_image(
|
||||
if request.scene_title:
|
||||
prompt_parts.append(f"Scene: {request.scene_title}")
|
||||
|
||||
# Use Bible visual style if available
|
||||
if bible_obj:
|
||||
prompt_parts.append(f"Style: {bible_obj.visual_style.style_preset}")
|
||||
prompt_parts.append(f"Environment: {bible_obj.visual_style.environment}")
|
||||
prompt_parts.append(f"Lighting: {bible_obj.visual_style.lighting}")
|
||||
if bible_obj.host.look:
|
||||
prompt_parts.append(f"Host Look: {bible_obj.host.look}")
|
||||
|
||||
# Scene content insights for visual context
|
||||
if request.scene_content:
|
||||
content_preview = request.scene_content[:200].replace("\n", " ").strip()
|
||||
@@ -127,12 +148,14 @@ async def generate_podcast_scene_image(
|
||||
prompt_parts.append(f"Topic: {idea_preview}")
|
||||
|
||||
# Studio setting (maintains podcast aesthetic)
|
||||
prompt_parts.extend([
|
||||
"Professional podcast recording studio",
|
||||
"Modern microphone setup",
|
||||
"Clean background, professional lighting",
|
||||
"16:9 aspect ratio, video-optimized composition"
|
||||
])
|
||||
if not bible_obj:
|
||||
prompt_parts.extend([
|
||||
"Professional podcast recording studio",
|
||||
"Modern microphone setup",
|
||||
"Clean background, professional lighting"
|
||||
])
|
||||
|
||||
prompt_parts.append("16:9 aspect ratio, video-optimized composition")
|
||||
|
||||
image_prompt = ", ".join(prompt_parts)
|
||||
|
||||
@@ -221,14 +244,22 @@ async def generate_podcast_scene_image(
|
||||
# Standard generation from scratch (no base avatar provided)
|
||||
prompt_parts = []
|
||||
|
||||
# Core podcast studio elements
|
||||
prompt_parts.extend([
|
||||
"Professional podcast recording studio",
|
||||
"Modern podcast setup with high-quality microphone",
|
||||
"Clean, minimalist background suitable for video",
|
||||
"Professional studio lighting with soft, even illumination",
|
||||
"Podcast host environment, professional and inviting"
|
||||
])
|
||||
# Use Bible visual style if available
|
||||
if bible_obj:
|
||||
prompt_parts.append(f"Style: {bible_obj.visual_style.style_preset}")
|
||||
prompt_parts.append(f"Environment: {bible_obj.visual_style.environment}")
|
||||
prompt_parts.append(f"Lighting: {bible_obj.visual_style.lighting}")
|
||||
if bible_obj.host.look:
|
||||
prompt_parts.append(f"Host Look: {bible_obj.host.look}")
|
||||
else:
|
||||
# Core podcast studio elements
|
||||
prompt_parts.extend([
|
||||
"Professional podcast recording studio",
|
||||
"Modern podcast setup with high-quality microphone",
|
||||
"Clean, minimalist background suitable for video",
|
||||
"Professional studio lighting with soft, even illumination",
|
||||
"Podcast host environment, professional and inviting"
|
||||
])
|
||||
|
||||
# Scene-specific context
|
||||
if request.scene_title:
|
||||
@@ -264,12 +295,13 @@ async def generate_podcast_scene_image(
|
||||
])
|
||||
|
||||
# Style constraints
|
||||
prompt_parts.extend([
|
||||
"Realistic photography style, not illustration or cartoon",
|
||||
"Professional broadcast quality",
|
||||
"Warm, inviting atmosphere",
|
||||
"Clean composition with breathing room for avatar placement"
|
||||
])
|
||||
if not bible_obj:
|
||||
prompt_parts.extend([
|
||||
"Realistic photography style, not illustration or cartoon",
|
||||
"Professional broadcast quality",
|
||||
"Warm, inviting atmosphere",
|
||||
"Clean composition with breathing room for avatar placement"
|
||||
])
|
||||
|
||||
image_prompt = ", ".join(prompt_parts)
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ async def create_project(
|
||||
duration=request.duration,
|
||||
speakers=request.speakers,
|
||||
budget_cap=request.budget_cap,
|
||||
avatar_url=request.avatar_url,
|
||||
)
|
||||
|
||||
return PodcastProjectResponse.model_validate(project)
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
"""
|
||||
Podcast Research Handlers
|
||||
|
||||
Research endpoints using Exa provider.
|
||||
Research endpoints using Exa provider and LLM summarization.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from typing import Dict, Any
|
||||
from typing import Dict, Any, List
|
||||
from types import SimpleNamespace
|
||||
import json
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.blog_writer.research.exa_provider import ExaResearchProvider
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from loguru import logger
|
||||
from ..models import (
|
||||
PodcastExaResearchRequest,
|
||||
PodcastExaResearchResponse,
|
||||
PodcastExaSource,
|
||||
PodcastExaConfig,
|
||||
PodcastResearchInsight,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
@@ -28,7 +32,8 @@ async def podcast_research_exa(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
"""
|
||||
Run podcast research directly via Exa (no blog writer pipeline).
|
||||
Run podcast research via Exa and then use LLM to extract deep insights.
|
||||
Uses Podcast Bible and Analysis context for hyper-personalization.
|
||||
"""
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
@@ -47,22 +52,121 @@ async def podcast_research_exa(
|
||||
)
|
||||
|
||||
provider = ExaResearchProvider()
|
||||
prompt = request.topic
|
||||
|
||||
# --- Context Building ---
|
||||
bible_service = PodcastBibleService()
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
try:
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
bible_data = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_data)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[Podcast Research] Failed to serialize bible: {exc}")
|
||||
|
||||
analysis_context = ""
|
||||
if request.analysis:
|
||||
analysis_context = f"""
|
||||
PODCAST ANALYSIS CONTEXT:
|
||||
Audience: {request.analysis.get('audience', 'General')}
|
||||
Content Type: {request.analysis.get('content_type', 'Informative')}
|
||||
Top Keywords: {', '.join(request.analysis.get('top_keywords', []))}
|
||||
"""
|
||||
|
||||
# Exa search params
|
||||
industry = request.bible.get("brand", {}).get("industry", "") if request.bible else ""
|
||||
target_audience = ""
|
||||
if request.bible:
|
||||
audience_dna = request.bible.get("audience", {})
|
||||
if audience_dna:
|
||||
interests = ", ".join(audience_dna.get("interests", []))
|
||||
target_audience = f"Expertise: {audience_dna.get('expertise_level', '')}. Interests: {interests}."
|
||||
|
||||
try:
|
||||
# 1. RUN EXA SEARCH
|
||||
result = await provider.search(
|
||||
prompt=prompt,
|
||||
prompt=request.topic,
|
||||
topic=request.topic,
|
||||
industry="",
|
||||
target_audience="",
|
||||
industry=industry,
|
||||
target_audience=target_audience,
|
||||
config=cfg,
|
||||
user_id=user_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Exa Research] Failed for user {user_id}: {exc}")
|
||||
logger.error(f"[Podcast Exa Research] Search failed for user {user_id}: {exc}")
|
||||
raise HTTPException(status_code=500, detail=f"Exa research failed: {exc}")
|
||||
|
||||
# Track usage if available
|
||||
# 2. EXTRACT INSIGHTS VIA LLM
|
||||
raw_content = result.get("content", "")
|
||||
sources = result.get("sources", [])
|
||||
|
||||
summary = ""
|
||||
key_insights = []
|
||||
|
||||
if raw_content and sources:
|
||||
logger.info(f"[Podcast Research] Extracting insights from {len(sources)} sources for user {user_id}")
|
||||
|
||||
prompt = f"""
|
||||
You are an expert research analyst for a high-end podcast production team.
|
||||
Your task is to analyze the following research data and extract deep, actionable insights for a podcast episode.
|
||||
|
||||
PODCAST CONTEXT:
|
||||
Topic: {request.topic}
|
||||
{bible_context}
|
||||
{analysis_context}
|
||||
|
||||
RESEARCH DATA (from {len(sources)} sources):
|
||||
{raw_content}
|
||||
|
||||
TASK:
|
||||
1. Provide a comprehensive summary (2-3 paragraphs) of the most important findings. Use Markdown for formatting (bolding, lists).
|
||||
2. Extract 3-5 "Key Insights". Each insight should have a title and a detailed explanation.
|
||||
3. For each insight, identify which source indices (e.g. 1, 2) it was derived from.
|
||||
|
||||
NOTE: The research data includes "Key Highlights", "Summaries", and "Excerpts" from various sources.
|
||||
Pay special attention to the "Key Highlights" sections as they contain the most relevant information extracted by the neural search engine.
|
||||
|
||||
Return JSON structure:
|
||||
{{
|
||||
"summary": "Detailed markdown summary...",
|
||||
"key_insights": [
|
||||
{{
|
||||
"title": "Insight Title",
|
||||
"content": "Detailed markdown content...",
|
||||
"source_indices": [1, 2]
|
||||
}}
|
||||
]
|
||||
}}
|
||||
|
||||
Requirements:
|
||||
- Ensure insights are deep, not just superficial facts. Look for trends, expert opinions, and specific data points.
|
||||
- Tone should be professional, insightful, and ready for a podcast host to discuss.
|
||||
- Avoid generic filler.
|
||||
"""
|
||||
try:
|
||||
llm_response = llm_text_gen(prompt=prompt, user_id=user_id, json_struct=None)
|
||||
|
||||
# Normalize response
|
||||
if isinstance(llm_response, str):
|
||||
data = json.loads(llm_response)
|
||||
else:
|
||||
data = llm_response
|
||||
|
||||
summary = data.get("summary", "")
|
||||
key_insights = [PodcastResearchInsight(**insight) for insight in data.get("key_insights", [])]
|
||||
except Exception as exc:
|
||||
logger.error(f"[Podcast Research] LLM Insight extraction failed: {exc}")
|
||||
# Fallback to a basic summary if LLM fails
|
||||
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
|
||||
|
||||
# Fallback: if summary is still empty (e.g. LLM returned empty string), use raw content first paragraph or basic text
|
||||
if not summary:
|
||||
if raw_content:
|
||||
summary = raw_content[:2000] # Use first 2000 chars of raw content as summary
|
||||
else:
|
||||
summary = f"Research completed for '{request.topic}'. Found {len(sources)} sources."
|
||||
|
||||
# 3. TRACK USAGE
|
||||
try:
|
||||
cost_total = 0.0
|
||||
if isinstance(result, dict):
|
||||
@@ -72,28 +176,31 @@ async def podcast_research_exa(
|
||||
logger.warning(f"[Podcast Exa Research] Failed to track usage: {track_err}")
|
||||
|
||||
sources_payload = []
|
||||
if isinstance(result, dict):
|
||||
for src in result.get("sources", []) or []:
|
||||
try:
|
||||
sources_payload.append(PodcastExaSource(**src))
|
||||
except Exception:
|
||||
sources_payload.append(PodcastExaSource(**{
|
||||
"title": src.get("title", ""),
|
||||
"url": src.get("url", ""),
|
||||
"excerpt": src.get("excerpt", ""),
|
||||
"published_at": src.get("published_at"),
|
||||
"highlights": src.get("highlights"),
|
||||
"summary": src.get("summary"),
|
||||
"source_type": src.get("source_type"),
|
||||
"index": src.get("index"),
|
||||
}))
|
||||
for src in sources:
|
||||
try:
|
||||
sources_payload.append(PodcastExaSource(**src))
|
||||
except Exception:
|
||||
sources_payload.append(PodcastExaSource(**{
|
||||
"title": src.get("title", ""),
|
||||
"url": src.get("url", ""),
|
||||
"excerpt": src.get("excerpt", ""),
|
||||
"published_at": src.get("published_at"),
|
||||
"highlights": src.get("highlights"),
|
||||
"summary": src.get("summary"),
|
||||
"source_type": src.get("source_type"),
|
||||
"index": src.get("index"),
|
||||
"image": src.get("image"),
|
||||
"author": src.get("author"),
|
||||
}))
|
||||
|
||||
return PodcastExaResearchResponse(
|
||||
sources=sources_payload,
|
||||
search_queries=result.get("search_queries", queries) if isinstance(result, dict) else queries,
|
||||
summary=summary,
|
||||
key_insights=key_insights,
|
||||
cost=result.get("cost") if isinstance(result, dict) else None,
|
||||
search_type=result.get("search_type") if isinstance(result, dict) else None,
|
||||
provider=result.get("provider", "exa") if isinstance(result, dict) else "exa",
|
||||
content=result.get("content") if isinstance(result, dict) else None,
|
||||
content=raw_content,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,6 +11,8 @@ import json
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
from models.podcast_bible_models import PodcastBible
|
||||
from loguru import logger
|
||||
from ..models import (
|
||||
PodcastScriptRequest,
|
||||
@@ -62,8 +64,39 @@ async def generate_podcast_script(
|
||||
logger.warning(f"Failed to parse research context: {exc}")
|
||||
research_context = ""
|
||||
|
||||
# Extract Podcast Bible context for hyper-personalization
|
||||
bible_context = ""
|
||||
if request.bible:
|
||||
try:
|
||||
bible_service = PodcastBibleService()
|
||||
bible_obj = PodcastBible(**request.bible)
|
||||
bible_context = bible_service.serialize_bible(bible_obj)
|
||||
except Exception as exc:
|
||||
logger.warning(f"Failed to serialize podcast bible: {exc}")
|
||||
|
||||
# Extract Analysis and Outline context for grounding
|
||||
analysis_context = ""
|
||||
if request.analysis:
|
||||
analysis_context = f"""
|
||||
TARGET AUDIENCE: {request.analysis.get('audience', 'General')}
|
||||
CONTENT TYPE: {request.analysis.get('contentType', 'Conversational')}
|
||||
TOP KEYWORDS: {', '.join(request.analysis.get('topKeywords', []))}
|
||||
"""
|
||||
|
||||
outline_context = ""
|
||||
if request.outline:
|
||||
outline_context = f"""
|
||||
REFINED EPISODE OUTLINE (Follow this structure closely):
|
||||
Title: {request.outline.get('title', 'N/A')}
|
||||
Segments: {' | '.join(request.outline.get('segments', []))}
|
||||
"""
|
||||
|
||||
prompt = f"""You are an expert podcast script planner. Create natural, conversational podcast scenes.
|
||||
|
||||
{f"PODCAST BIBLE (Hyper-Personalization Context):\n{bible_context}\n" if bible_context else ""}
|
||||
{f"ANALYSIS CONTEXT:\n{analysis_context}\n" if analysis_context else ""}
|
||||
{f"REFINED OUTLINE:\n{outline_context}\n" if outline_context else ""}
|
||||
|
||||
Podcast Idea: "{request.idea}"
|
||||
Duration: ~{request.duration_minutes} minutes
|
||||
Speakers: {request.speakers} (Host + optional Guest)
|
||||
@@ -83,11 +116,13 @@ Return JSON with:
|
||||
* Mark "emphasis": true for key statistics or important points
|
||||
|
||||
Guidelines:
|
||||
- Write for spoken delivery: conversational, natural, with contractions
|
||||
- Use research insights naturally - weave statistics into dialogue, don't just list them
|
||||
- Vary emotion per scene based on content
|
||||
- Ensure scenes match target duration: aim for ~2.5 words per second of audio
|
||||
- Keep it engaging and informative, like a real podcast conversation
|
||||
- Write for spoken delivery: conversational, natural, with contractions.
|
||||
- Follow the interaction tone specified in the Bible.
|
||||
- Ensure the Host persona matches the background and personality traits from the Bible.
|
||||
- Structure the intro and outro scenes according to the Bible's "Intro Format" and "Outro Format".
|
||||
- Adhere to any constraints mentioned in the Bible.
|
||||
- Use insights from the Research Context to ground the conversation in facts.
|
||||
- IMPORTANT: Follow the REFINED OUTLINE segments as the primary structure for the episode.
|
||||
"""
|
||||
|
||||
try:
|
||||
|
||||
@@ -14,7 +14,7 @@ import re
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from middleware.auth_middleware import get_current_user, get_current_user_with_query_token
|
||||
from api.story_writer.utils.auth import require_authenticated_user
|
||||
from services.wavespeed.infinitetalk import animate_scene_with_voiceover
|
||||
@@ -105,6 +105,34 @@ def _execute_podcast_video_task(
|
||||
scene_number_match = re.search(r'\d+', request.scene_id)
|
||||
scene_number = int(scene_number_match.group()) if scene_number_match else 0
|
||||
|
||||
# Fetch project context (Bible & Analysis) from DB if not provided in request
|
||||
from services.database import get_session_for_user
|
||||
from services.podcast_service import PodcastService
|
||||
|
||||
project_bible = request.bible
|
||||
project_analysis = None
|
||||
|
||||
try:
|
||||
# Create a dedicated session for this background task
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
podcast_service = PodcastService(db)
|
||||
# Fetch project directly from DB to get latest analysis/bible
|
||||
project = podcast_service.get_project(user_id, request.project_id)
|
||||
if project:
|
||||
# Use project bible if request didn't provide one
|
||||
if not project_bible and project.bible:
|
||||
project_bible = project.bible
|
||||
|
||||
# Get analysis for better context
|
||||
if project.analysis:
|
||||
project_analysis = project.analysis
|
||||
logger.info(f"[Podcast] Loaded analysis for video context: {list(project_analysis.keys())}")
|
||||
finally:
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] Failed to fetch project context for video generation: {e}")
|
||||
|
||||
# Prepare scene data for animation
|
||||
scene_data = {
|
||||
"scene_number": scene_number,
|
||||
@@ -114,6 +142,8 @@ def _execute_podcast_video_task(
|
||||
story_context = {
|
||||
"project_id": request.project_id,
|
||||
"type": "podcast",
|
||||
"bible": project_bible,
|
||||
"analysis": project_analysis,
|
||||
}
|
||||
|
||||
animation_result = animate_scene_with_voiceover(
|
||||
@@ -207,8 +237,8 @@ def _execute_podcast_video_task(
|
||||
|
||||
@router.post("/render/video", response_model=PodcastVideoGenerationResponse)
|
||||
async def generate_podcast_video(
|
||||
request_obj: Request,
|
||||
request: PodcastVideoGenerationRequest,
|
||||
request: Request,
|
||||
body: PodcastVideoGenerationRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
@@ -216,22 +246,46 @@ async def generate_podcast_video(
|
||||
Generate video for a podcast scene using WaveSpeed InfiniteTalk (avatar image + audio).
|
||||
Returns task_id for polling since InfiniteTalk can take up to 10 minutes.
|
||||
"""
|
||||
# Debug logging to identify "Depends object has no attribute get" error source
|
||||
logger.info(f"[Podcast] generate_podcast_video called. current_user type: {type(current_user)}")
|
||||
|
||||
# Check if current_user is a Depends object (FastAPI injection failure)
|
||||
if hasattr(current_user, "dependency"):
|
||||
logger.error(f"[Podcast] CRITICAL: current_user is a Depends object! Dependency injection failed.")
|
||||
# Attempt to manually resolve or fail gracefully
|
||||
auth_header = None
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
auth_header = request.headers.get("Authorization")
|
||||
except:
|
||||
pass
|
||||
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
token = auth_header.replace("Bearer ", "").strip()
|
||||
# Manually verify token if dependency injection failed
|
||||
from middleware.auth_middleware import clerk_auth
|
||||
current_user = await clerk_auth.verify_token(token)
|
||||
if not current_user:
|
||||
raise HTTPException(status_code=401, detail="Authentication failed (manual recovery)")
|
||||
else:
|
||||
raise HTTPException(status_code=401, detail="Authentication failed (injection error)")
|
||||
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
logger.info(
|
||||
f"[Podcast] Starting video generation for project {request.project_id}, scene {request.scene_id}"
|
||||
f"[Podcast] Starting video generation for project {body.project_id}, scene {body.scene_id}"
|
||||
)
|
||||
|
||||
# Load audio bytes
|
||||
audio_bytes = load_podcast_audio_bytes(request.audio_url)
|
||||
audio_bytes = load_podcast_audio_bytes(body.audio_url)
|
||||
|
||||
# Validate resolution
|
||||
if request.resolution not in {"480p", "720p"}:
|
||||
if body.resolution not in {"480p", "720p"}:
|
||||
raise HTTPException(status_code=400, detail="Resolution must be '480p' or '720p'.")
|
||||
|
||||
# Load image bytes (scene image is required for video generation)
|
||||
if request.avatar_image_url:
|
||||
image_bytes = load_podcast_image_bytes(request.avatar_image_url)
|
||||
if body.avatar_image_url:
|
||||
image_bytes = load_podcast_image_bytes(body.avatar_image_url)
|
||||
else:
|
||||
# Scene-specific image should be generated before video generation
|
||||
raise HTTPException(
|
||||
@@ -240,9 +294,9 @@ async def generate_podcast_video(
|
||||
)
|
||||
|
||||
mask_image_bytes = None
|
||||
if request.mask_image_url:
|
||||
if body.mask_image_url:
|
||||
try:
|
||||
mask_image_bytes = load_podcast_image_bytes(request.mask_image_url)
|
||||
mask_image_bytes = load_podcast_image_bytes(body.mask_image_url)
|
||||
except Exception as e:
|
||||
logger.error(f"[Podcast] Failed to load mask image: {e}")
|
||||
raise HTTPException(
|
||||
@@ -251,7 +305,9 @@ async def generate_podcast_video(
|
||||
)
|
||||
|
||||
# Validate subscription limits
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise HTTPException(status_code=500, detail="Database session unavailable for user.")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
validate_scene_animation_operation(pricing_service=pricing_service, user_id=user_id)
|
||||
@@ -260,16 +316,20 @@ async def generate_podcast_video(
|
||||
|
||||
# Extract token for authenticated URL building
|
||||
auth_token = None
|
||||
auth_header = request_obj.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.replace("Bearer ", "").strip()
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header and auth_header.startswith("Bearer "):
|
||||
auth_token = auth_header.replace("Bearer ", "").strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[Podcast] Failed to extract auth token from headers: {e}")
|
||||
|
||||
# Create async task
|
||||
task_id = task_manager.create_task("podcast_video_generation")
|
||||
background_tasks.add_task(
|
||||
_execute_podcast_video_task,
|
||||
task_id=task_id,
|
||||
request=request,
|
||||
request=body,
|
||||
user_id=user_id,
|
||||
image_bytes=image_bytes,
|
||||
audio_bytes=audio_bytes,
|
||||
|
||||
@@ -25,6 +25,7 @@ class PodcastProjectResponse(BaseModel):
|
||||
raw_research: Optional[Dict[str, Any]] = None
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
script_data: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
render_jobs: Optional[List[Dict[str, Any]]] = None
|
||||
knobs: Optional[Dict[str, Any]] = None
|
||||
research_provider: Optional[str] = None
|
||||
@@ -34,6 +35,9 @@ class PodcastProjectResponse(BaseModel):
|
||||
status: str = "draft"
|
||||
is_favorite: bool = False
|
||||
final_video_url: Optional[str] = None
|
||||
avatar_url: Optional[str] = None
|
||||
avatar_prompt: Optional[str] = None
|
||||
avatar_persona_id: Optional[str] = None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
@@ -46,6 +50,9 @@ class PodcastAnalyzeRequest(BaseModel):
|
||||
idea: str = Field(..., description="Podcast topic or idea")
|
||||
duration: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
avatar_url: Optional[str] = Field(None, description="Current avatar URL if selected")
|
||||
feedback: Optional[str] = Field(None, description="User feedback for regeneration")
|
||||
|
||||
|
||||
class PodcastAnalyzeResponse(BaseModel):
|
||||
@@ -55,7 +62,23 @@ class PodcastAnalyzeResponse(BaseModel):
|
||||
top_keywords: list[str]
|
||||
suggested_outlines: list[Dict[str, Any]]
|
||||
title_suggestions: list[str]
|
||||
research_queries: Optional[List[Dict[str, str]]] = None
|
||||
exa_suggested_config: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
avatar_url: Optional[str] = None
|
||||
avatar_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaRequest(BaseModel):
|
||||
"""Request model for enhancing a podcast idea with AI."""
|
||||
idea: str = Field(..., description="The raw podcast idea or keywords")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Optional Podcast Bible for context")
|
||||
|
||||
|
||||
class PodcastEnhanceIdeaResponse(BaseModel):
|
||||
"""Response model for enhanced podcast idea."""
|
||||
enhanced_idea: str
|
||||
rationale: str
|
||||
|
||||
|
||||
class PodcastScriptRequest(BaseModel):
|
||||
@@ -64,6 +87,9 @@ class PodcastScriptRequest(BaseModel):
|
||||
duration_minutes: int = Field(default=10, description="Target duration in minutes")
|
||||
speakers: int = Field(default=1, description="Number of speakers")
|
||||
research: Optional[Dict[str, Any]] = Field(None, description="Optional research payload to ground the script")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
outline: Optional[Dict[str, Any]] = Field(None, description="The refined episode outline to follow")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="The full analysis context (audience, keywords, etc.)")
|
||||
|
||||
|
||||
class PodcastSceneLine(BaseModel):
|
||||
@@ -106,6 +132,8 @@ class PodcastExaResearchRequest(BaseModel):
|
||||
topic: str
|
||||
queries: List[str]
|
||||
exa_config: Optional[PodcastExaConfig] = None
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
analysis: Optional[Dict[str, Any]] = Field(None, description="Podcast analysis context (audience, content type, etc.)")
|
||||
|
||||
|
||||
class PodcastExaSource(BaseModel):
|
||||
@@ -117,15 +145,26 @@ class PodcastExaSource(BaseModel):
|
||||
summary: Optional[str] = None
|
||||
source_type: Optional[str] = None
|
||||
index: Optional[int] = None
|
||||
image: Optional[str] = None
|
||||
author: Optional[str] = None
|
||||
|
||||
|
||||
class PodcastResearchInsight(BaseModel):
|
||||
"""Deep insight extracted from research."""
|
||||
title: str
|
||||
content: str
|
||||
source_indices: List[int] = []
|
||||
|
||||
|
||||
class PodcastExaResearchResponse(BaseModel):
|
||||
sources: List[PodcastExaSource]
|
||||
search_queries: List[str] = []
|
||||
summary: str = ""
|
||||
key_insights: List[PodcastResearchInsight] = []
|
||||
cost: Optional[Dict[str, Any]] = None
|
||||
search_type: Optional[str] = None
|
||||
provider: str = "exa"
|
||||
content: Optional[str] = None
|
||||
content: Optional[str] = None # Raw aggregated content (deprecated)
|
||||
|
||||
|
||||
class PodcastScriptResponse(BaseModel):
|
||||
@@ -191,6 +230,7 @@ class UpdateProjectRequest(BaseModel):
|
||||
raw_research: Optional[Dict[str, Any]] = None
|
||||
estimate: Optional[Dict[str, Any]] = None
|
||||
script_data: Optional[Dict[str, Any]] = None
|
||||
bible: Optional[Dict[str, Any]] = None
|
||||
render_jobs: Optional[List[Dict[str, Any]]] = None
|
||||
knobs: Optional[Dict[str, Any]] = None
|
||||
research_provider: Optional[str] = None
|
||||
@@ -224,6 +264,7 @@ class PodcastImageRequest(BaseModel):
|
||||
scene_content: Optional[str] = None # Optional: scene lines text for context
|
||||
idea: Optional[str] = None # Optional: podcast idea for context
|
||||
base_avatar_url: Optional[str] = None # Base avatar image URL for scene variations
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
custom_prompt: Optional[str] = None # Custom prompt from user (overrides auto-generated prompt)
|
||||
@@ -252,6 +293,7 @@ class PodcastVideoGenerationRequest(BaseModel):
|
||||
scene_title: str = Field(..., description="Scene title")
|
||||
audio_url: str = Field(..., description="URL to the generated audio file")
|
||||
avatar_image_url: Optional[str] = Field(None, description="URL to scene image (required for video generation)")
|
||||
bible: Optional[Dict[str, Any]] = Field(None, description="Podcast Bible for hyper-personalization")
|
||||
resolution: str = Field("720p", description="Video resolution (480p or 720p)")
|
||||
prompt: Optional[str] = Field(None, description="Optional animation prompt override")
|
||||
seed: Optional[int] = Field(-1, description="Random seed; -1 for random")
|
||||
|
||||
@@ -524,6 +524,80 @@ async def get_semantic_cache_stats(current_user: dict = Depends(get_current_user
|
||||
"memory_usage_mb": 0.0
|
||||
}
|
||||
|
||||
|
||||
async def get_sif_indexing_health(current_user: dict = Depends(get_current_user)) -> Dict[str, Any]:
|
||||
try:
|
||||
from models.website_analysis_monitoring_models import SIFIndexingTask, SIFIndexingExecutionLog
|
||||
|
||||
user_id = str(current_user.get("id"))
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise HTTPException(status_code=500, detail="Database connection unavailable")
|
||||
|
||||
try:
|
||||
tasks = (
|
||||
db.query(SIFIndexingTask)
|
||||
.filter(SIFIndexingTask.user_id == user_id)
|
||||
.order_by(SIFIndexingTask.created_at.desc())
|
||||
.all()
|
||||
)
|
||||
|
||||
if not tasks:
|
||||
return {
|
||||
"has_task": False,
|
||||
"status": "not_scheduled",
|
||||
"message": "SIF indexing task not yet scheduled for this website.",
|
||||
}
|
||||
|
||||
latest = tasks[0]
|
||||
latest_log = (
|
||||
db.query(SIFIndexingExecutionLog)
|
||||
.filter(SIFIndexingExecutionLog.task_id == latest.id)
|
||||
.order_by(SIFIndexingExecutionLog.execution_date.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
last_run_status = latest_log.status if latest_log else None
|
||||
last_run_time = (
|
||||
latest_log.execution_date.isoformat() if latest_log and latest_log.execution_date else None
|
||||
)
|
||||
last_error = (
|
||||
(latest_log.error_message or "")[:500] if latest_log and latest_log.error_message else None
|
||||
)
|
||||
|
||||
overall_status = "healthy"
|
||||
if latest.consecutive_failures and latest.consecutive_failures > 0:
|
||||
overall_status = "warning"
|
||||
if latest.status in {"needs_intervention"}:
|
||||
overall_status = "critical"
|
||||
|
||||
return {
|
||||
"has_task": True,
|
||||
"status": overall_status,
|
||||
"task": {
|
||||
"id": latest.id,
|
||||
"website_url": latest.website_url,
|
||||
"raw_status": latest.status,
|
||||
"next_execution": latest.next_execution.isoformat() if latest.next_execution else None,
|
||||
"last_success": latest.last_success.isoformat() if latest.last_success else None,
|
||||
"last_failure": latest.last_failure.isoformat() if latest.last_failure else None,
|
||||
"consecutive_failures": latest.consecutive_failures or 0,
|
||||
"failure_pattern": latest.failure_pattern,
|
||||
},
|
||||
"last_run": {
|
||||
"status": last_run_status,
|
||||
"time": last_run_time,
|
||||
"error_message": last_error,
|
||||
},
|
||||
}
|
||||
finally:
|
||||
db.close()
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get SIF indexing health: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to get SIF indexing health")
|
||||
|
||||
# New comprehensive SEO analysis endpoints
|
||||
async def analyze_seo_comprehensive(request: SEOAnalysisRequest) -> SEOAnalysisResponse:
|
||||
"""
|
||||
|
||||
73
backend/api/story_writer/models_projects.py
Normal file
73
backend/api/story_writer/models_projects.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""
|
||||
Story Project API Models
|
||||
|
||||
Pydantic models for Story Studio project endpoints.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StoryProjectResponse(BaseModel):
|
||||
id: int
|
||||
project_id: str
|
||||
user_id: str
|
||||
title: Optional[str] = None
|
||||
story_mode: Optional[str] = None
|
||||
story_template: Optional[str] = None
|
||||
setup: Optional[Dict[str, Any]] = None
|
||||
outline: Optional[Dict[str, Any]] = None
|
||||
scenes: Optional[List[Dict[str, Any]]] = None
|
||||
story_content: Optional[Dict[str, Any]] = None
|
||||
anime_bible: Optional[Dict[str, Any]] = None
|
||||
media_state: Optional[Dict[str, Any]] = None
|
||||
current_phase: Optional[str] = None
|
||||
status: str = "draft"
|
||||
is_favorite: bool = False
|
||||
is_complete: bool = False
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class StoryProjectListResponse(BaseModel):
|
||||
projects: List[StoryProjectResponse]
|
||||
total: int
|
||||
limit: int
|
||||
offset: int
|
||||
|
||||
|
||||
class CreateStoryProjectRequest(BaseModel):
|
||||
project_id: str = Field(..., description="Unique story project ID")
|
||||
title: Optional[str] = Field(None, description="Optional story project title or idea")
|
||||
story_mode: Optional[str] = Field(
|
||||
None, description="Story mode (marketing or pure) if provided by the UI"
|
||||
)
|
||||
story_template: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional story template identifier (e.g. product_story, anime_fiction)",
|
||||
)
|
||||
setup: Optional[Dict[str, Any]] = Field(
|
||||
None,
|
||||
description="Initial story setup payload to persist with the project",
|
||||
)
|
||||
|
||||
|
||||
class UpdateStoryProjectRequest(BaseModel):
|
||||
title: Optional[str] = None
|
||||
story_mode: Optional[str] = None
|
||||
story_template: Optional[str] = None
|
||||
setup: Optional[Dict[str, Any]] = None
|
||||
outline: Optional[Dict[str, Any]] = None
|
||||
scenes: Optional[List[Dict[str, Any]]] = None
|
||||
story_content: Optional[Dict[str, Any]] = None
|
||||
anime_bible: Optional[Dict[str, Any]] = None
|
||||
media_state: Optional[Dict[str, Any]] = None
|
||||
current_phase: Optional[str] = None
|
||||
status: Optional[str] = None
|
||||
is_complete: Optional[bool] = None
|
||||
|
||||
@@ -14,6 +14,7 @@ from .routes import (
|
||||
media_generation,
|
||||
scene_animation,
|
||||
story_content,
|
||||
story_projects,
|
||||
story_setup,
|
||||
story_tasks,
|
||||
video_generation,
|
||||
@@ -24,6 +25,7 @@ router = APIRouter(prefix="/api/story", tags=["Story Writer"])
|
||||
# Include modular routers (order preserved roughly by workflow)
|
||||
router.include_router(story_setup.router)
|
||||
router.include_router(story_content.router)
|
||||
router.include_router(story_projects.router)
|
||||
router.include_router(story_tasks.router)
|
||||
router.include_router(media_generation.router)
|
||||
router.include_router(scene_animation.router)
|
||||
|
||||
@@ -65,7 +65,7 @@ async def generate_scene_images(
|
||||
scene_number=result.get("scene_number", 0),
|
||||
scene_title=result.get("scene_title", "Untitled"),
|
||||
image_filename=result.get("image_filename", ""),
|
||||
image_url=result.get("image_url", ""),
|
||||
image_url=result.get("image_url") or "",
|
||||
width=result.get("width", 1024),
|
||||
height=result.get("height", 1024),
|
||||
provider=result.get("provider", "unknown"),
|
||||
@@ -148,7 +148,7 @@ async def regenerate_scene_image(
|
||||
scene_number=result.get("scene_number", request.scene_number),
|
||||
scene_title=result.get("scene_title", request.scene_title),
|
||||
image_filename=result.get("image_filename", ""),
|
||||
image_url=result.get("image_url", ""),
|
||||
image_url=result.get("image_url") or "",
|
||||
width=result.get("width", request.width or 1024),
|
||||
height=result.get("height", request.height or 1024),
|
||||
provider=result.get("provider", "unknown"),
|
||||
|
||||
@@ -12,6 +12,10 @@ from models.story_models import (
|
||||
StoryScene,
|
||||
StoryContinueRequest,
|
||||
StoryContinueResponse,
|
||||
AnimeSceneTextRequest,
|
||||
AnimeSceneTextResponse,
|
||||
AnimeSceneGenerateRequest,
|
||||
AnimeSceneGenerateResponse,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
|
||||
@@ -107,6 +111,7 @@ async def generate_story_start(
|
||||
content_rating=request.content_rating,
|
||||
ending_preference=request.ending_preference,
|
||||
story_length=story_length,
|
||||
anime_bible=getattr(request, "anime_bible", None),
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -211,6 +216,7 @@ async def continue_story(
|
||||
audience_age_group=request.audience_age_group,
|
||||
content_rating=request.content_rating,
|
||||
ending_preference=request.ending_preference,
|
||||
anime_bible=getattr(request, "anime_bible", None),
|
||||
story_length=story_length,
|
||||
user_id=user_id,
|
||||
)
|
||||
@@ -245,6 +251,105 @@ async def continue_story(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/anime/scene-text", response_model=AnimeSceneTextResponse)
|
||||
async def refine_anime_scene_text(
|
||||
request: AnimeSceneTextRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimeSceneTextResponse:
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
scene_dict = request.scene.dict()
|
||||
if not scene_dict.get("title") and not scene_dict.get("description"):
|
||||
raise HTTPException(status_code=400, detail="Scene title or description is required")
|
||||
|
||||
refined = story_service.refine_anime_scene_text(
|
||||
scene=scene_dict,
|
||||
persona=request.persona,
|
||||
story_setting=request.story_setting,
|
||||
character_input=request.character_input,
|
||||
plot_elements=request.plot_elements,
|
||||
writing_style=request.writing_style,
|
||||
story_tone=request.story_tone,
|
||||
narrative_pov=request.narrative_pov,
|
||||
audience_age_group=request.audience_age_group,
|
||||
content_rating=request.content_rating,
|
||||
anime_bible=request.anime_bible,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
refined_scene = StoryScene(
|
||||
scene_number=refined.get("scene_number", request.scene.scene_number),
|
||||
title=refined.get("title", request.scene.title),
|
||||
description=refined.get("description", request.scene.description),
|
||||
image_prompt=refined.get("image_prompt", request.scene.image_prompt),
|
||||
audio_narration=refined.get("audio_narration", request.scene.audio_narration),
|
||||
character_descriptions=refined.get(
|
||||
"character_descriptions", request.scene.character_descriptions
|
||||
),
|
||||
key_events=refined.get("key_events", request.scene.key_events),
|
||||
)
|
||||
|
||||
return AnimeSceneTextResponse(scene=refined_scene, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to refine anime scene text: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/anime/scene-generate", response_model=AnimeSceneGenerateResponse)
|
||||
async def generate_anime_scene_from_bible(
|
||||
request: AnimeSceneGenerateRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> AnimeSceneGenerateResponse:
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.anime_bible:
|
||||
raise HTTPException(status_code=400, detail="Anime story bible is required")
|
||||
|
||||
previous_scenes_payload: Optional[List[Dict[str, Any]]] = None
|
||||
if request.previous_scenes:
|
||||
previous_scenes_payload = [scene.dict() for scene in request.previous_scenes]
|
||||
|
||||
generated = story_service.generate_anime_scene_from_bible(
|
||||
premise=request.premise,
|
||||
persona=request.persona,
|
||||
story_setting=request.story_setting,
|
||||
character_input=request.character_input,
|
||||
plot_elements=request.plot_elements,
|
||||
writing_style=request.writing_style,
|
||||
story_tone=request.story_tone,
|
||||
narrative_pov=request.narrative_pov,
|
||||
audience_age_group=request.audience_age_group,
|
||||
content_rating=request.content_rating,
|
||||
anime_bible=request.anime_bible,
|
||||
previous_scenes=previous_scenes_payload,
|
||||
target_scene_number=request.target_scene_number,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
scene = StoryScene(
|
||||
scene_number=generated.get("scene_number"),
|
||||
title=generated.get("title", ""),
|
||||
description=generated.get("description", ""),
|
||||
image_prompt=generated.get("image_prompt", ""),
|
||||
audio_narration=generated.get("audio_narration", ""),
|
||||
character_descriptions=generated.get("character_descriptions") or [],
|
||||
key_events=generated.get("key_events") or [],
|
||||
)
|
||||
|
||||
return AnimeSceneGenerateResponse(scene=scene, success=True)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate anime scene from bible: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
class SceneApprovalRequest(BaseModel):
|
||||
project_id: str = Field(..., min_length=1)
|
||||
scene_id: str = Field(..., min_length=1)
|
||||
|
||||
189
backend/api/story_writer/routes/story_projects.py
Normal file
189
backend/api/story_writer/routes/story_projects.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.database import get_db
|
||||
from services.story_writer.story_project_service import StoryProjectService
|
||||
from ..models_projects import (
|
||||
CreateStoryProjectRequest,
|
||||
StoryProjectListResponse,
|
||||
StoryProjectResponse,
|
||||
UpdateStoryProjectRequest,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/projects", response_model=StoryProjectResponse, status_code=201)
|
||||
async def create_story_project(
|
||||
request: CreateStoryProjectRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryProjectResponse:
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = StoryProjectService(db)
|
||||
|
||||
existing = service.get_project(user_id, request.project_id)
|
||||
if existing:
|
||||
raise HTTPException(status_code=400, detail="Project ID already exists")
|
||||
|
||||
project = service.create_project(
|
||||
user_id=user_id,
|
||||
project_id=request.project_id,
|
||||
title=request.title,
|
||||
story_mode=request.story_mode,
|
||||
story_template=request.story_template,
|
||||
setup=request.setup,
|
||||
)
|
||||
|
||||
return StoryProjectResponse.model_validate(project)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error creating story project: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=StoryProjectResponse)
|
||||
async def get_story_project(
|
||||
project_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryProjectResponse:
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = StoryProjectService(db)
|
||||
project = service.get_project(user_id, project_id)
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
return StoryProjectResponse.model_validate(project)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error fetching story project: {str(e)}")
|
||||
|
||||
|
||||
@router.put("/projects/{project_id}", response_model=StoryProjectResponse)
|
||||
async def update_story_project(
|
||||
project_id: str,
|
||||
request: UpdateStoryProjectRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryProjectResponse:
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = StoryProjectService(db)
|
||||
|
||||
updates = request.model_dump(exclude_unset=True)
|
||||
|
||||
project = service.update_project(user_id, project_id, **updates)
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
return StoryProjectResponse.model_validate(project)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error updating story project: {str(e)}")
|
||||
|
||||
|
||||
@router.get("/projects", response_model=StoryProjectListResponse)
|
||||
async def list_story_projects(
|
||||
status: Optional[str] = Query(None, description="Filter by status"),
|
||||
favorites_only: bool = Query(False, description="Only favorites"),
|
||||
limit: int = Query(50, ge=1, le=200),
|
||||
offset: int = Query(0, ge=0),
|
||||
order_by: str = Query("updated_at", description="Order by: updated_at or created_at"),
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryProjectListResponse:
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
if order_by not in ["updated_at", "created_at"]:
|
||||
raise HTTPException(status_code=400, detail="order_by must be 'updated_at' or 'created_at'")
|
||||
|
||||
service = StoryProjectService(db)
|
||||
projects, total = service.list_projects(
|
||||
user_id=user_id,
|
||||
status=status,
|
||||
favorites_only=favorites_only,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
order_by=order_by,
|
||||
)
|
||||
|
||||
return StoryProjectListResponse(
|
||||
projects=[StoryProjectResponse.model_validate(p) for p in projects],
|
||||
total=total,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error listing story projects: {str(e)}")
|
||||
|
||||
|
||||
@router.delete("/projects/{project_id}", status_code=204)
|
||||
async def delete_story_project(
|
||||
project_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> None:
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = StoryProjectService(db)
|
||||
deleted = service.delete_project(user_id, project_id)
|
||||
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error deleting story project: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/favorite", response_model=StoryProjectResponse)
|
||||
async def toggle_story_project_favorite(
|
||||
project_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryProjectResponse:
|
||||
try:
|
||||
user_id = current_user.get("user_id") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found")
|
||||
|
||||
service = StoryProjectService(db)
|
||||
project = service.toggle_favorite(user_id, project_id)
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
return StoryProjectResponse.model_validate(project)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error toggling story project favorite: {str(e)}")
|
||||
|
||||
@@ -2,6 +2,8 @@ from typing import Any, Dict, List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import desc
|
||||
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from models.story_models import (
|
||||
@@ -13,8 +15,14 @@ from models.story_models import (
|
||||
StoryScene,
|
||||
StoryStartRequest,
|
||||
StoryPremiseResponse,
|
||||
StoryIdeaEnhanceRequest,
|
||||
StoryIdeaEnhanceResponse,
|
||||
StoryIdeaEnhanceSuggestion,
|
||||
)
|
||||
from services.story_writer.story_service import StoryWriterService
|
||||
from api.onboarding_utils.onboarding_summary_service import OnboardingSummaryService
|
||||
from services.database import get_session_for_user
|
||||
from models.content_asset_models import ContentAsset, AssetType, AssetSource
|
||||
|
||||
from ..utils.auth import require_authenticated_user
|
||||
|
||||
@@ -39,6 +47,9 @@ async def generate_story_setup(
|
||||
|
||||
options = story_service.generate_story_setup_options(
|
||||
story_idea=request.story_idea,
|
||||
story_mode=request.story_mode,
|
||||
story_template=request.story_template,
|
||||
brand_context=request.brand_context,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
@@ -52,6 +63,152 @@ async def generate_story_setup(
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/enhance-idea", response_model=StoryIdeaEnhanceResponse)
|
||||
async def enhance_story_idea(
|
||||
request: StoryIdeaEnhanceRequest,
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> StoryIdeaEnhanceResponse:
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
|
||||
if not request.story_idea or not request.story_idea.strip():
|
||||
raise HTTPException(status_code=400, detail="Story idea is required")
|
||||
|
||||
logger.info(f"[StoryWriter] Enhancing story idea for user {user_id}")
|
||||
|
||||
suggestions = story_service.enhance_story_idea(
|
||||
story_idea=request.story_idea,
|
||||
story_mode=request.story_mode,
|
||||
story_template=request.story_template,
|
||||
brand_context=request.brand_context,
|
||||
user_id=user_id,
|
||||
fiction_variant=request.fiction_variant,
|
||||
narrative_energy=request.narrative_energy,
|
||||
)
|
||||
|
||||
return StoryIdeaEnhanceResponse(
|
||||
suggestions=[StoryIdeaEnhanceSuggestion(**s) for s in suggestions],
|
||||
success=True,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to enhance story idea: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/context")
|
||||
async def get_story_context(
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
) -> Dict[str, Any]:
|
||||
"""Return onboarding-based story context for the current user."""
|
||||
try:
|
||||
user_id = require_authenticated_user(current_user)
|
||||
summary_service = OnboardingSummaryService(user_id)
|
||||
summary = await summary_service.get_onboarding_summary()
|
||||
|
||||
canonical_profile = summary.get("canonical_profile") or {}
|
||||
persona_readiness = summary.get("persona_readiness") or {}
|
||||
capabilities = summary.get("capabilities") or {}
|
||||
|
||||
website_url = summary.get("website_url")
|
||||
style_analysis = summary.get("style_analysis") or {}
|
||||
research_preferences = summary.get("research_preferences") or {}
|
||||
|
||||
brand_name = None
|
||||
if isinstance(style_analysis, dict):
|
||||
brand_name = style_analysis.get("brand_name") or style_analysis.get("site_title")
|
||||
|
||||
writing_tone = canonical_profile.get("writing_tone")
|
||||
target_audience = canonical_profile.get("target_audience")
|
||||
|
||||
brand_context = {
|
||||
"brand_name": brand_name,
|
||||
"writing_tone": writing_tone,
|
||||
"target_audience": target_audience,
|
||||
}
|
||||
|
||||
avatar_url = None
|
||||
voice_preview_url = None
|
||||
custom_voice_id = None
|
||||
|
||||
db: Session | None = get_session_for_user(user_id)
|
||||
if db:
|
||||
try:
|
||||
avatar_asset = (
|
||||
db.query(ContentAsset)
|
||||
.filter(
|
||||
ContentAsset.user_id == user_id,
|
||||
ContentAsset.asset_type == AssetType.IMAGE,
|
||||
ContentAsset.source_module.in_(
|
||||
[AssetSource.BRAND_AVATAR_GENERATOR, AssetSource.STORY_WRITER]
|
||||
),
|
||||
)
|
||||
.order_by(desc(ContentAsset.created_at))
|
||||
.limit(50)
|
||||
.all()
|
||||
)
|
||||
|
||||
selected_avatar = None
|
||||
for candidate in avatar_asset:
|
||||
if candidate.source_module == AssetSource.BRAND_AVATAR_GENERATOR:
|
||||
selected_avatar = candidate
|
||||
break
|
||||
meta = candidate.asset_metadata or {}
|
||||
if meta.get("category") == "brand_avatar":
|
||||
selected_avatar = candidate
|
||||
break
|
||||
|
||||
if selected_avatar:
|
||||
avatar_url = selected_avatar.file_url
|
||||
|
||||
voice_asset = (
|
||||
db.query(ContentAsset)
|
||||
.filter(
|
||||
ContentAsset.user_id == user_id,
|
||||
ContentAsset.asset_type == AssetType.AUDIO,
|
||||
ContentAsset.source_module == AssetSource.VOICE_CLONER,
|
||||
)
|
||||
.order_by(desc(ContentAsset.created_at))
|
||||
.first()
|
||||
)
|
||||
|
||||
if voice_asset:
|
||||
meta = voice_asset.asset_metadata or {}
|
||||
voice_preview_url = meta.get("preview_url") or voice_asset.file_url
|
||||
custom_voice_id = meta.get("custom_voice_id")
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
persona_enabled = bool(persona_readiness.get("ready")) and bool(
|
||||
capabilities.get("persona_generation")
|
||||
)
|
||||
has_persona_context = persona_enabled and bool(
|
||||
brand_name or writing_tone or target_audience or avatar_url or voice_preview_url
|
||||
)
|
||||
|
||||
return {
|
||||
"canonical_profile": canonical_profile,
|
||||
"website_url": website_url,
|
||||
"research_preferences": research_preferences,
|
||||
"brand_context": brand_context,
|
||||
"brand_assets": {
|
||||
"avatar_url": avatar_url,
|
||||
"voice_preview_url": voice_preview_url,
|
||||
"custom_voice_id": custom_voice_id,
|
||||
},
|
||||
"persona_enabled": persona_enabled,
|
||||
"has_persona_context": has_persona_context,
|
||||
}
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to get story context: {exc}")
|
||||
raise HTTPException(status_code=500, detail="Failed to load story context")
|
||||
|
||||
|
||||
@router.post("/generate-premise", response_model=StoryPremiseResponse)
|
||||
async def generate_premise(
|
||||
request: StoryGenerationRequest,
|
||||
@@ -108,6 +265,9 @@ async def generate_outline(
|
||||
request.story_tone,
|
||||
)
|
||||
|
||||
# For now, treat all outlines as potentially anime-aware. The downstream
|
||||
# generation logic will decide whether to actually create a bible based
|
||||
# on how the prompt is interpreted (e.g., anime templates in persona).
|
||||
outline = story_service.generate_outline(
|
||||
premise=request.premise,
|
||||
persona=request.persona,
|
||||
@@ -122,15 +282,37 @@ async def generate_outline(
|
||||
ending_preference=request.ending_preference,
|
||||
user_id=user_id,
|
||||
use_structured_output=use_structured,
|
||||
include_anime_bible=True,
|
||||
)
|
||||
|
||||
if isinstance(outline, list):
|
||||
scenes: List[StoryScene] = [
|
||||
StoryScene(**scene) if isinstance(scene, dict) else scene for scene in outline
|
||||
]
|
||||
return StoryOutlineResponse(outline=scenes, success=True, is_structured=True)
|
||||
anime_bible: Dict[str, Any] | None = None
|
||||
outline_payload: Any = outline
|
||||
|
||||
return StoryOutlineResponse(outline=str(outline), success=True, is_structured=False)
|
||||
if isinstance(outline, dict):
|
||||
if "anime_bible" in outline:
|
||||
anime_bible = outline.get("anime_bible")
|
||||
if "scenes" in outline:
|
||||
outline_payload = outline.get("scenes")
|
||||
elif "outline" in outline:
|
||||
outline_payload = outline.get("outline")
|
||||
|
||||
if isinstance(outline_payload, list):
|
||||
scenes: List[StoryScene] = [
|
||||
StoryScene(**scene) if isinstance(scene, dict) else scene for scene in outline_payload
|
||||
]
|
||||
return StoryOutlineResponse(
|
||||
outline=scenes,
|
||||
success=True,
|
||||
is_structured=True,
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
|
||||
return StoryOutlineResponse(
|
||||
outline=str(outline_payload),
|
||||
success=True,
|
||||
is_structured=False,
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
@@ -350,9 +351,21 @@ def execute_complete_video_generation(
|
||||
Runs in a background task and performs blocking operations.
|
||||
"""
|
||||
try:
|
||||
task_manager.update_task_status(task_id, "processing", progress=5.0, message="Starting complete video generation...")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=5.0,
|
||||
message="Starting complete video generation...",
|
||||
)
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=10.0, message="Generating story premise...")
|
||||
anime_bible = request_data.get("anime_bible")
|
||||
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=10.0,
|
||||
message="Generating story premise...",
|
||||
)
|
||||
premise = story_service.generate_premise(
|
||||
persona=request_data["persona"],
|
||||
story_setting=request_data["story_setting"],
|
||||
@@ -367,7 +380,12 @@ def execute_complete_video_generation(
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=20.0, message="Generating structured outline with scenes...")
|
||||
task_manager.update_task_status(
|
||||
task_id,
|
||||
"processing",
|
||||
progress=20.0,
|
||||
message="Generating structured outline with scenes...",
|
||||
)
|
||||
outline_scenes = story_service.generate_outline(
|
||||
premise=premise,
|
||||
persona=request_data["persona"],
|
||||
@@ -401,6 +419,7 @@ def execute_complete_video_generation(
|
||||
height=request_data.get("image_height", 1024),
|
||||
model=request_data.get("image_model"),
|
||||
progress_callback=image_progress_callback,
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
|
||||
task_manager.update_task_status(task_id, "processing", progress=50.0, message="Generating audio narration for scenes...")
|
||||
|
||||
@@ -140,7 +140,7 @@ class TaskManager:
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Step 2: Generate outline
|
||||
@@ -157,7 +157,7 @@ class TaskManager:
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Step 3: Generate story start
|
||||
@@ -175,7 +175,8 @@ class TaskManager:
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
anime_bible=request_data.get("anime_bible"),
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
# Step 4: Continue story
|
||||
@@ -208,7 +209,8 @@ class TaskManager:
|
||||
audience_age_group=request_data["audience_age_group"],
|
||||
content_rating=request_data["content_rating"],
|
||||
ending_preference=request_data["ending_preference"],
|
||||
user_id=user_id
|
||||
anime_bible=request_data.get("anime_bible"),
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
if continuation:
|
||||
|
||||
@@ -8,7 +8,7 @@ def require_authenticated_user(current_user: Dict[str, Any] | None) -> str:
|
||||
Validates the current user dictionary provided by Clerk middleware and
|
||||
returns the normalized user_id. Raises HTTP 401 if authentication fails.
|
||||
"""
|
||||
if not current_user:
|
||||
if not current_user or not isinstance(current_user, dict):
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required")
|
||||
|
||||
user_id = str(current_user.get("id", "")).strip()
|
||||
|
||||
@@ -12,13 +12,11 @@ from services.user_workspace_manager import UserWorkspaceManager
|
||||
|
||||
|
||||
BASE_DIR = Path(__file__).resolve().parents[4] # root/
|
||||
DATA_MEDIA_DIR = BASE_DIR / "workspace" / "media"
|
||||
# Default global media directory matches story image/audio services (root/data/media)
|
||||
DATA_MEDIA_DIR = BASE_DIR / "data" / "media"
|
||||
|
||||
STORY_IMAGES_DIR = (DATA_MEDIA_DIR / "story_images").resolve()
|
||||
# STORY_IMAGES_DIR.mkdir(parents=True, exist_ok=True) # Disabled global creation
|
||||
|
||||
STORY_AUDIO_DIR = (DATA_MEDIA_DIR / "story_audio").resolve()
|
||||
# STORY_AUDIO_DIR.mkdir(parents=True, exist_ok=True) # Disabled global creation
|
||||
|
||||
|
||||
def _get_user_media_path(user_id: str, media_type: str) -> Optional[Path]:
|
||||
|
||||
@@ -12,7 +12,10 @@ from .routes import (
|
||||
alerts,
|
||||
dashboard,
|
||||
logs,
|
||||
preflight
|
||||
preflight,
|
||||
payment,
|
||||
disputes,
|
||||
fraud_warnings,
|
||||
)
|
||||
|
||||
# Create main router
|
||||
@@ -26,5 +29,8 @@ router.include_router(alerts.router, tags=["subscription"])
|
||||
router.include_router(dashboard.router, tags=["subscription"])
|
||||
router.include_router(logs.router, tags=["subscription"])
|
||||
router.include_router(preflight.router, tags=["subscription"])
|
||||
router.include_router(payment.router, tags=["subscription"])
|
||||
router.include_router(disputes.router, tags=["subscription"])
|
||||
router.include_router(fraud_warnings.router, tags=["subscription"])
|
||||
|
||||
__all__ = ["router"]
|
||||
|
||||
@@ -3,6 +3,6 @@ Subscription API Routes
|
||||
All route modules are imported here for easy access.
|
||||
"""
|
||||
|
||||
from . import usage, plans, subscriptions, alerts, dashboard, logs, preflight
|
||||
from . import usage, plans, subscriptions, alerts, dashboard, logs, preflight, payment, disputes
|
||||
|
||||
__all__ = ["usage", "plans", "subscriptions", "alerts", "dashboard", "logs", "preflight"]
|
||||
__all__ = ["usage", "plans", "subscriptions", "alerts", "dashboard", "logs", "preflight", "payment", "disputes"]
|
||||
|
||||
142
backend/api/subscription/routes/disputes.py
Normal file
142
backend/api/subscription/routes/disputes.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from loguru import logger
|
||||
import stripe
|
||||
import os
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _ensure_admin(current_user: Dict[str, Any]) -> None:
|
||||
disable_auth = os.getenv("DISABLE_AUTH", "false").lower() == "true"
|
||||
if disable_auth:
|
||||
return
|
||||
|
||||
email = (current_user.get("email") or "").lower()
|
||||
role = None
|
||||
public_metadata = current_user.get("public_metadata")
|
||||
if isinstance(public_metadata, dict):
|
||||
role = public_metadata.get("role") or current_user.get("role")
|
||||
else:
|
||||
role = current_user.get("role")
|
||||
|
||||
admin_emails_raw = os.getenv("ADMIN_EMAILS", "")
|
||||
admin_emails = {
|
||||
e.strip().lower() for e in admin_emails_raw.split(",") if e.strip()
|
||||
}
|
||||
admin_domain = (os.getenv("ADMIN_EMAIL_DOMAIN") or "").lower().strip()
|
||||
|
||||
is_admin_email = email and email in admin_emails
|
||||
is_admin_domain = email and admin_domain and email.endswith("@" + admin_domain)
|
||||
is_admin_role = role == "admin"
|
||||
|
||||
if not (is_admin_email or is_admin_domain or is_admin_role):
|
||||
raise HTTPException(status_code=403, detail="Admin access required for dispute operations")
|
||||
|
||||
|
||||
def _get_stripe_client() -> None:
|
||||
api_key = os.getenv("STRIPE_SECRET_KEY")
|
||||
if not api_key:
|
||||
logger.error("STRIPE_SECRET_KEY is not configured; dispute operations are disabled")
|
||||
raise HTTPException(status_code=500, detail="Payment service not configured")
|
||||
stripe.api_key = api_key
|
||||
|
||||
|
||||
class DisputeEvidenceUpdateRequest(BaseModel):
|
||||
evidence: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
@router.get("/disputes")
|
||||
async def list_disputes(
|
||||
limit: int = 10,
|
||||
starting_after: Optional[str] = None,
|
||||
ending_before: Optional[str] = None,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_get_stripe_client()
|
||||
_ensure_admin(current_user)
|
||||
|
||||
try:
|
||||
params: Dict[str, Any] = {"limit": max(1, min(limit, 100))}
|
||||
if starting_after:
|
||||
params["starting_after"] = starting_after
|
||||
if ending_before:
|
||||
params["ending_before"] = ending_before
|
||||
|
||||
disputes = stripe.Dispute.list(**params)
|
||||
return {"data": disputes}
|
||||
except Exception as e:
|
||||
logger.error(f"Error listing disputes: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to list disputes")
|
||||
|
||||
|
||||
@router.get("/disputes/{dispute_id}")
|
||||
async def get_dispute(
|
||||
dispute_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_get_stripe_client()
|
||||
_ensure_admin(current_user)
|
||||
|
||||
try:
|
||||
dispute = stripe.Dispute.retrieve(dispute_id)
|
||||
return {"data": dispute}
|
||||
except stripe.error.InvalidRequestError as e:
|
||||
logger.warning(f"Invalid dispute id {dispute_id}: {e}")
|
||||
raise HTTPException(status_code=404, detail="Dispute not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving dispute {dispute_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve dispute")
|
||||
|
||||
|
||||
@router.post("/disputes/{dispute_id}")
|
||||
async def update_dispute(
|
||||
dispute_id: str,
|
||||
payload: DisputeEvidenceUpdateRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_get_stripe_client()
|
||||
_ensure_admin(current_user)
|
||||
|
||||
if not payload.evidence:
|
||||
raise HTTPException(status_code=400, detail="Evidence payload is required to update a dispute")
|
||||
|
||||
try:
|
||||
dispute = stripe.Dispute.modify(
|
||||
dispute_id,
|
||||
evidence=payload.evidence,
|
||||
)
|
||||
return {"data": dispute}
|
||||
except stripe.error.InvalidRequestError as e:
|
||||
logger.warning(f"Invalid dispute id {dispute_id} during update: {e}")
|
||||
raise HTTPException(status_code=404, detail="Dispute not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating dispute {dispute_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to update dispute")
|
||||
|
||||
|
||||
@router.post("/disputes/{dispute_id}/close")
|
||||
async def close_dispute(
|
||||
dispute_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_get_stripe_client()
|
||||
_ensure_admin(current_user)
|
||||
|
||||
try:
|
||||
dispute = stripe.Dispute.close(dispute_id)
|
||||
return {"data": dispute}
|
||||
except stripe.error.InvalidRequestError as e:
|
||||
logger.warning(f"Invalid dispute id {dispute_id} during close: {e}")
|
||||
raise HTTPException(status_code=404, detail="Dispute not found")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing dispute {dispute_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to close dispute")
|
||||
209
backend/api/subscription/routes/fraud_warnings.py
Normal file
209
backend/api/subscription/routes/fraud_warnings.py
Normal file
@@ -0,0 +1,209 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
from services.database import get_db
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from loguru import logger
|
||||
import stripe
|
||||
import os
|
||||
from datetime import datetime
|
||||
from models.subscription_models import FraudWarning
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _ensure_admin(current_user: Dict[str, Any]) -> None:
|
||||
disable_auth = os.getenv("DISABLE_AUTH", "false").lower() == "true"
|
||||
if disable_auth:
|
||||
return
|
||||
|
||||
email = (current_user.get("email") or "").lower()
|
||||
role = None
|
||||
public_metadata = current_user.get("public_metadata")
|
||||
if isinstance(public_metadata, dict):
|
||||
role = public_metadata.get("role") or current_user.get("role")
|
||||
else:
|
||||
role = current_user.get("role")
|
||||
|
||||
admin_emails_raw = os.getenv("ADMIN_EMAILS", "")
|
||||
admin_emails = {
|
||||
e.strip().lower() for e in admin_emails_raw.split(",") if e.strip()
|
||||
}
|
||||
admin_domain = (os.getenv("ADMIN_EMAIL_DOMAIN") or "").lower().strip()
|
||||
|
||||
is_admin_email = email and email in admin_emails
|
||||
is_admin_domain = email and admin_domain and email.endswith("@" + admin_domain)
|
||||
is_admin_role = role == "admin"
|
||||
|
||||
if not (is_admin_email or is_admin_domain or is_admin_role):
|
||||
raise HTTPException(status_code=403, detail="Admin access required for fraud warning operations")
|
||||
|
||||
|
||||
def _get_stripe_client() -> None:
|
||||
api_key = os.getenv("STRIPE_SECRET_KEY")
|
||||
if not api_key:
|
||||
logger.error("STRIPE_SECRET_KEY is not configured; fraud warning operations are disabled")
|
||||
raise HTTPException(status_code=500, detail="Payment service not configured")
|
||||
stripe.api_key = api_key
|
||||
|
||||
|
||||
class FraudWarningRefundRequest(BaseModel):
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
class FraudWarningIgnoreRequest(BaseModel):
|
||||
notes: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/fraud-warnings")
|
||||
async def list_fraud_warnings(
|
||||
status: Optional[str] = "open",
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_ensure_admin(current_user)
|
||||
|
||||
query = db.query(FraudWarning)
|
||||
if status:
|
||||
query = query.filter(FraudWarning.status == status)
|
||||
|
||||
limit = max(1, min(limit, 100))
|
||||
items = (
|
||||
query.order_by(FraudWarning.created_at.desc())
|
||||
.offset(max(0, offset))
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
data = []
|
||||
for fw in items:
|
||||
data.append(
|
||||
{
|
||||
"id": fw.id,
|
||||
"charge_id": fw.charge_id,
|
||||
"payment_intent_id": fw.payment_intent_id,
|
||||
"user_id": fw.user_id,
|
||||
"amount": fw.amount,
|
||||
"currency": fw.currency,
|
||||
"status": fw.status,
|
||||
"action": fw.action,
|
||||
"action_at": fw.action_at.isoformat() if fw.action_at else None,
|
||||
"created_at": fw.created_at.isoformat() if fw.created_at else None,
|
||||
}
|
||||
)
|
||||
|
||||
return {"data": data}
|
||||
|
||||
|
||||
@router.get("/fraud-warnings/{warning_id}")
|
||||
async def get_fraud_warning(
|
||||
warning_id: str,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_ensure_admin(current_user)
|
||||
|
||||
fw = db.query(FraudWarning).filter(FraudWarning.id == warning_id).first()
|
||||
if not fw:
|
||||
raise HTTPException(status_code=404, detail="Fraud warning not found")
|
||||
|
||||
payload: Dict[str, Any] = {
|
||||
"id": fw.id,
|
||||
"charge_id": fw.charge_id,
|
||||
"payment_intent_id": fw.payment_intent_id,
|
||||
"user_id": fw.user_id,
|
||||
"amount": fw.amount,
|
||||
"currency": fw.currency,
|
||||
"status": fw.status,
|
||||
"action": fw.action,
|
||||
"action_at": fw.action_at.isoformat() if fw.action_at else None,
|
||||
"reason_notes": fw.reason_notes,
|
||||
"created_at": fw.created_at.isoformat() if fw.created_at else None,
|
||||
"meta_info": fw.meta_info,
|
||||
}
|
||||
|
||||
return {"data": payload}
|
||||
|
||||
|
||||
@router.post("/fraud-warnings/{warning_id}/refund")
|
||||
async def refund_fraud_warning(
|
||||
warning_id: str,
|
||||
payload: FraudWarningRefundRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_ensure_admin(current_user)
|
||||
_get_stripe_client()
|
||||
|
||||
fw = db.query(FraudWarning).filter(FraudWarning.id == warning_id).first()
|
||||
if not fw:
|
||||
raise HTTPException(status_code=404, detail="Fraud warning not found")
|
||||
|
||||
if fw.status == "refunded":
|
||||
raise HTTPException(status_code=400, detail="Fraud warning already refunded")
|
||||
|
||||
try:
|
||||
stripe.Refund.create(charge=fw.charge_id)
|
||||
except stripe.error.InvalidRequestError as e:
|
||||
logger.warning(f"Refund failed for fraud warning {warning_id}, charge {fw.charge_id}: {e}")
|
||||
raise HTTPException(status_code=400, detail="Refund failed for this charge")
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error refunding fraud warning {warning_id}: {e}")
|
||||
raise HTTPException(status_code=500, detail="Unexpected error while processing refund")
|
||||
|
||||
fw.status = "refunded"
|
||||
fw.action = "refund_full"
|
||||
fw.action_at = datetime.utcnow()
|
||||
if payload and payload.notes:
|
||||
fw.reason_notes = payload.notes
|
||||
|
||||
db.commit()
|
||||
db.refresh(fw)
|
||||
|
||||
return {
|
||||
"data": {
|
||||
"id": fw.id,
|
||||
"status": fw.status,
|
||||
"action": fw.action,
|
||||
"action_at": fw.action_at.isoformat() if fw.action_at else None,
|
||||
"reason_notes": fw.reason_notes,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@router.post("/fraud-warnings/{warning_id}/ignore")
|
||||
async def ignore_fraud_warning(
|
||||
warning_id: str,
|
||||
payload: FraudWarningIgnoreRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
):
|
||||
_ensure_admin(current_user)
|
||||
|
||||
fw = db.query(FraudWarning).filter(FraudWarning.id == warning_id).first()
|
||||
if not fw:
|
||||
raise HTTPException(status_code=404, detail="Fraud warning not found")
|
||||
|
||||
fw.status = "ignored"
|
||||
fw.action = "ignored"
|
||||
fw.action_at = datetime.utcnow()
|
||||
if payload and payload.notes:
|
||||
fw.reason_notes = payload.notes
|
||||
|
||||
db.commit()
|
||||
db.refresh(fw)
|
||||
|
||||
return {
|
||||
"data": {
|
||||
"id": fw.id,
|
||||
"status": fw.status,
|
||||
"action": fw.action,
|
||||
"action_at": fw.action_at.isoformat() if fw.action_at else None,
|
||||
"reason_notes": fw.reason_notes,
|
||||
}
|
||||
}
|
||||
|
||||
125
backend/api/subscription/routes/payment.py
Normal file
125
backend/api/subscription/routes/payment.py
Normal file
@@ -0,0 +1,125 @@
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Header, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Dict, Any, Optional
|
||||
from pydantic import BaseModel
|
||||
from services.database import get_db
|
||||
from services.subscription.stripe_service import StripeService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from loguru import logger
|
||||
from models.subscription_models import SubscriptionTier, BillingCycle
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
class CreateCheckoutSessionRequest(BaseModel):
|
||||
tier: SubscriptionTier
|
||||
billing_cycle: BillingCycle
|
||||
success_url: str
|
||||
cancel_url: str
|
||||
|
||||
class CreatePortalSessionRequest(BaseModel):
|
||||
return_url: str
|
||||
|
||||
|
||||
_checkout_rate_limit_window_seconds = 60
|
||||
_checkout_rate_limit_max_requests = 10
|
||||
_checkout_attempts_by_user: Dict[str, Any] = defaultdict(list)
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
payload: CreateCheckoutSessionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user),
|
||||
request: Request = None
|
||||
):
|
||||
"""
|
||||
Create a Stripe Checkout Session for subscription.
|
||||
"""
|
||||
user_id = current_user.get("sub") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
|
||||
now = time.time()
|
||||
attempts = _checkout_attempts_by_user[user_id]
|
||||
window_start = now - _checkout_rate_limit_window_seconds
|
||||
attempts[:] = [ts for ts in attempts if ts >= window_start]
|
||||
attempts.append(now)
|
||||
_checkout_attempts_by_user[user_id] = attempts
|
||||
if len(attempts) > _checkout_rate_limit_max_requests:
|
||||
client_ip = request.client.host if request and request.client else "unknown"
|
||||
logger.warning(f"Checkout rate limit exceeded for user_id={user_id}, ip={client_ip}, attempts={len(attempts)} in { _checkout_rate_limit_window_seconds }s")
|
||||
raise HTTPException(status_code=429, detail="Too many checkout attempts. Please try again shortly.")
|
||||
|
||||
user_email = current_user.get("email")
|
||||
|
||||
stripe_service = StripeService(db)
|
||||
|
||||
try:
|
||||
url = stripe_service.create_checkout_session(
|
||||
user_id=user_id,
|
||||
tier=payload.tier,
|
||||
billing_cycle=payload.billing_cycle,
|
||||
success_url=payload.success_url,
|
||||
cancel_url=payload.cancel_url,
|
||||
user_email=user_email
|
||||
)
|
||||
return {"url": url}
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating checkout session: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to initiate checkout")
|
||||
|
||||
@router.post("/create-portal-session")
|
||||
async def create_portal_session(
|
||||
payload: CreatePortalSessionRequest,
|
||||
db: Session = Depends(get_db),
|
||||
current_user: Dict[str, Any] = Depends(get_current_user)
|
||||
):
|
||||
"""
|
||||
Create a Stripe Customer Portal session for managing billing.
|
||||
"""
|
||||
user_id = current_user.get("sub") or current_user.get("id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
|
||||
stripe_service = StripeService(db)
|
||||
|
||||
try:
|
||||
url = stripe_service.create_portal_session(
|
||||
user_id=user_id,
|
||||
return_url=payload.return_url
|
||||
)
|
||||
return {"url": url}
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating portal session: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to access billing portal")
|
||||
|
||||
@router.post("/webhook")
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
stripe_signature: str = Header(None),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Handle Stripe webhooks.
|
||||
"""
|
||||
if not stripe_signature:
|
||||
raise HTTPException(status_code=400, detail="Missing stripe-signature header")
|
||||
|
||||
payload = await request.body()
|
||||
stripe_service = StripeService(db)
|
||||
|
||||
try:
|
||||
# We need to run this potentially in background or await it
|
||||
# Since it's async, we can await it directly.
|
||||
await stripe_service.handle_webhook(payload, stripe_signature)
|
||||
return {"status": "success"}
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing webhook: {e}")
|
||||
raise HTTPException(status_code=500, detail="Webhook processing failed")
|
||||
@@ -376,6 +376,9 @@ async def analyze_urls_ai_endpoint(request: AnalyzeURLsRequest, current_user: di
|
||||
# Include platform analytics router
|
||||
from routers.platform_analytics import router as platform_analytics_router
|
||||
app.include_router(platform_analytics_router)
|
||||
# Include Bing Analytics Storage router to expose storage-backed endpoints
|
||||
from routers.bing_analytics_storage import router as bing_analytics_storage_router
|
||||
app.include_router(bing_analytics_storage_router)
|
||||
app.include_router(images_router)
|
||||
app.include_router(image_studio_router)
|
||||
app.include_router(product_marketing_router)
|
||||
|
||||
56
backend/debug_analytics_api.py
Normal file
56
backend/debug_analytics_api.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import date, timedelta
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
base_url = os.environ.get("ALWRITY_API_BASE_URL", "http://localhost:8000")
|
||||
token = os.environ.get("ALWRITY_API_TOKEN")
|
||||
|
||||
today = date.today()
|
||||
start = today - timedelta(days=29)
|
||||
|
||||
params = {
|
||||
"platforms": "gsc",
|
||||
"start_date": start.isoformat(),
|
||||
"end_date": today.isoformat(),
|
||||
}
|
||||
|
||||
headers = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
async with httpx.AsyncClient(base_url=base_url, headers=headers, timeout=60.0) as client:
|
||||
resp = await client.get("/api/analytics/data", params=params)
|
||||
print(f"Status: {resp.status_code}")
|
||||
try:
|
||||
data = resp.json()
|
||||
except Exception:
|
||||
print("Non‑JSON response body:")
|
||||
print(resp.text)
|
||||
return
|
||||
|
||||
print("Raw JSON response:")
|
||||
print(data)
|
||||
|
||||
summary = data.get("summary") or {}
|
||||
platforms = data.get("data") or {}
|
||||
gsc = platforms.get("gsc") or {}
|
||||
gsc_metrics = gsc.get("metrics") or {}
|
||||
|
||||
print("\nSummary snapshot:")
|
||||
print(f" total_clicks: {summary.get('total_clicks')}")
|
||||
print(f" total_impressions: {summary.get('total_impressions')}")
|
||||
print(f" overall_ctr: {summary.get('overall_ctr')}")
|
||||
|
||||
print("\nGSC metrics snapshot:")
|
||||
print(f" total_clicks: {gsc_metrics.get('total_clicks')}")
|
||||
print(f" total_impressions: {gsc_metrics.get('total_impressions')}")
|
||||
print(f" avg_ctr: {gsc_metrics.get('avg_ctr')}")
|
||||
print(f" avg_position: {gsc_metrics.get('avg_position')}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -129,7 +129,8 @@ from api.seo_dashboard import (
|
||||
analyze_urls_ai,
|
||||
AnalyzeURLsRequest,
|
||||
get_analyzed_pages,
|
||||
get_semantic_health # Phase 2B: Semantic health monitoring
|
||||
get_semantic_health,
|
||||
get_sif_indexing_health
|
||||
)
|
||||
|
||||
# Initialize FastAPI app
|
||||
@@ -337,6 +338,15 @@ async def semantic_cache_stats_endpoint(current_user: dict = Depends(get_current
|
||||
"""
|
||||
return await get_semantic_cache_stats(current_user)
|
||||
|
||||
|
||||
@app.get("/api/seo-dashboard/sif-health")
|
||||
async def sif_indexing_health_endpoint(current_user: dict = Depends(get_current_user)):
|
||||
"""
|
||||
Get SIF indexing health summary for the current user.
|
||||
Used by the Semantic Indexing Status widget on the dashboard.
|
||||
"""
|
||||
return await get_sif_indexing_health(current_user)
|
||||
|
||||
# Comprehensive SEO Analysis endpoints
|
||||
@app.post("/api/seo-dashboard/analyze-comprehensive")
|
||||
async def analyze_seo_comprehensive_endpoint(request: SEOAnalysisRequest):
|
||||
|
||||
@@ -227,7 +227,10 @@ class ClerkAuthMiddleware:
|
||||
'last_name': last_name,
|
||||
'clerk_user_id': user_id
|
||||
}
|
||||
logger.error("Fallback decoding is disabled in production.")
|
||||
# In production mode, treat fallback as a soft failure:
|
||||
# log at warning level (once per process) and let the caller
|
||||
# handle this as an authentication failure without spamming logs.
|
||||
logger.warning("Fallback decoding is disabled in production.")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
@@ -247,21 +250,33 @@ async def get_current_user(
|
||||
) -> Dict[str, Any]:
|
||||
"""Get current authenticated user."""
|
||||
try:
|
||||
# Safe header access
|
||||
auth_header = None
|
||||
user_agent = "unknown"
|
||||
all_headers = {}
|
||||
|
||||
try:
|
||||
if hasattr(request, 'headers'):
|
||||
if hasattr(request.headers, 'get'):
|
||||
auth_header = request.headers.get('authorization') or request.headers.get('Authorization')
|
||||
user_agent = request.headers.get('user-agent', 'unknown')
|
||||
|
||||
if hasattr(request.headers, 'items'):
|
||||
all_headers = {k: v[:50] if len(v) > 50 else v for k, v in request.headers.items()}
|
||||
except:
|
||||
pass
|
||||
|
||||
if not credentials:
|
||||
# CRITICAL: Log as ERROR since this is a security issue - authenticated endpoint accessed without credentials
|
||||
endpoint_path = f"{request.method} {request.url.path}"
|
||||
|
||||
# DEBUG: Log all headers to see what's actually being received
|
||||
auth_header = request.headers.get('authorization') or request.headers.get('Authorization')
|
||||
all_headers = {k: v[:50] if len(v) > 50 else v for k, v in request.headers.items()}
|
||||
|
||||
logger.error(
|
||||
f"🔒 AUTHENTICATION ERROR: No credentials provided for authenticated endpoint: {endpoint_path} "
|
||||
f"(client_ip={request.client.host if request.client else 'unknown'}, "
|
||||
f"auth_header_received={'YES' if auth_header else 'NO'}, "
|
||||
f"auth_header_value={auth_header[:50] + '...' if auth_header and len(auth_header) > 50 else (auth_header or 'None')}, "
|
||||
f"all_headers={list(all_headers.keys())}, "
|
||||
f"user_agent={request.headers.get('user-agent', 'unknown')})"
|
||||
f"user_agent={user_agent})"
|
||||
)
|
||||
|
||||
# Get caller information for better debugging
|
||||
@@ -328,11 +343,19 @@ async def get_current_user(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Safe header access for logging
|
||||
safe_user_agent = "unknown"
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
safe_user_agent = request.headers.get('user-agent', 'unknown')
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(
|
||||
f"🔒 AUTHENTICATION ERROR: Token verification failed for endpoint: {endpoint_path} "
|
||||
f"(client_ip={request.client.host if request.client else 'unknown'}, "
|
||||
f"caller={caller_info}, "
|
||||
f"user_agent={request.headers.get('user-agent', 'unknown')})"
|
||||
f"user_agent={safe_user_agent})"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -369,7 +392,7 @@ async def get_current_user(
|
||||
f"🔒 AUTHENTICATION ERROR: Unexpected error during authentication for endpoint: {endpoint_path}: {e} "
|
||||
f"(client_ip={request.client.host if request.client else 'unknown'}, "
|
||||
f"caller={caller_info}, "
|
||||
f"user_agent={request.headers.get('user-agent', 'unknown')})",
|
||||
f"user_agent={user_agent})",
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
@@ -420,7 +443,12 @@ async def get_current_user_with_query_token(
|
||||
token_to_verify = credentials.credentials
|
||||
else:
|
||||
# Fall back to query parameter if no header
|
||||
query_token = request.query_params.get("token")
|
||||
query_token = None
|
||||
try:
|
||||
if hasattr(request, 'query_params') and hasattr(request.query_params, 'get'):
|
||||
query_token = request.query_params.get("token")
|
||||
except:
|
||||
pass
|
||||
if query_token:
|
||||
token_to_verify = query_token
|
||||
|
||||
@@ -428,6 +456,14 @@ async def get_current_user_with_query_token(
|
||||
# CRITICAL: Log as ERROR since this is a security issue
|
||||
endpoint_path = f"{request.method} {request.url.path}"
|
||||
|
||||
# Safe user agent access
|
||||
user_agent = "unknown"
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
user_agent = request.headers.get('user-agent', 'unknown')
|
||||
except:
|
||||
pass
|
||||
|
||||
# Get caller information
|
||||
caller_frame = inspect.currentframe()
|
||||
caller_info = "unknown"
|
||||
@@ -446,12 +482,20 @@ async def get_current_user_with_query_token(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Safe header access for logging
|
||||
safe_user_agent = "unknown"
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
safe_user_agent = request.headers.get('user-agent', 'unknown')
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(
|
||||
f"🔒 AUTHENTICATION ERROR: No credentials provided (neither header nor query parameter) "
|
||||
f"for authenticated endpoint: {endpoint_path} "
|
||||
f"(client_ip={request.client.host if request.client else 'unknown'}, "
|
||||
f"caller={caller_info}, "
|
||||
f"user_agent={request.headers.get('user-agent', 'unknown')})"
|
||||
f"user_agent={safe_user_agent})"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -482,11 +526,19 @@ async def get_current_user_with_query_token(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Safe header access for logging
|
||||
safe_user_agent = "unknown"
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
safe_user_agent = request.headers.get('user-agent', 'unknown')
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(
|
||||
f"🔒 AUTHENTICATION ERROR: Token verification failed for endpoint: {endpoint_path} "
|
||||
f"(client_ip={request.client.host if request.client else 'unknown'}, "
|
||||
f"caller={caller_info}, "
|
||||
f"user_agent={request.headers.get('user-agent', 'unknown')})"
|
||||
f"user_agent={safe_user_agent})"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@@ -519,11 +571,19 @@ async def get_current_user_with_query_token(
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Safe header access for logging
|
||||
safe_user_agent = "unknown"
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
safe_user_agent = request.headers.get('user-agent', 'unknown')
|
||||
except:
|
||||
pass
|
||||
|
||||
logger.error(
|
||||
f"🔒 AUTHENTICATION ERROR: Unexpected error during authentication for endpoint: {endpoint_path}: {e} "
|
||||
f"(client_ip={request.client.host if request.client else 'unknown'}, "
|
||||
f"caller={caller_info}, "
|
||||
f"user_agent={request.headers.get('user-agent', 'unknown')})",
|
||||
f"user_agent={safe_user_agent})",
|
||||
exc_info=True
|
||||
)
|
||||
raise HTTPException(
|
||||
|
||||
65
backend/models/podcast_bible_models.py
Normal file
65
backend/models/podcast_bible_models.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""
|
||||
Podcast Bible Models
|
||||
|
||||
Pydantic models for the structured Podcast Bible, used for hyper-personalization.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional, Dict, Any
|
||||
|
||||
class HostPersona(BaseModel):
|
||||
"""Details about the podcast host persona."""
|
||||
name: str = Field(..., description="Name of the podcast host")
|
||||
background: str = Field(..., description="Professional background and expertise")
|
||||
expertise_level: str = Field(..., description="Level of expertise (e.g., Expert, Practitioner, Enthusiast)")
|
||||
personality_traits: List[str] = Field(default_factory=list, description="Personality traits (e.g., Witty, Authoritative, Empathetic)")
|
||||
vocal_style: str = Field(..., description="Description of the vocal style and delivery")
|
||||
vocal_characteristics: List[str] = Field(default_factory=list, description="Specific vocal traits (e.g., Deep, Raspy, Energetic, Calm)")
|
||||
look: Optional[str] = Field(None, description="Visual description of the host (for avatar generation)")
|
||||
catchphrases: List[str] = Field(default_factory=list, description="Commonly used phrases or sign-offs")
|
||||
|
||||
class VisualStyle(BaseModel):
|
||||
"""Visual aesthetic for the podcast videos and avatars."""
|
||||
style_preset: str = Field(default="Professional Studio", description="Visual style (e.g., 3D Cartoon, Cinematic, Minimalist)")
|
||||
environment: str = Field(..., description="The studio or setting where the podcast takes place")
|
||||
lighting: str = Field(default="Soft Studio Lighting", description="Lighting mood and setup")
|
||||
color_palette: List[str] = Field(default_factory=list, description="Primary brand colors for the visual elements")
|
||||
camera_style: str = Field(default="Static Mid-shot", description="Preferred camera framing and movement")
|
||||
|
||||
class AudioEnvironment(BaseModel):
|
||||
"""The soundscape and audio characteristics of the podcast."""
|
||||
soundscape: str = Field(default="Quiet Studio", description="Acoustics and ambient noise level")
|
||||
music_mood: str = Field(default="Professional & Subtle", description="Genre and mood of background music")
|
||||
sfx_style: str = Field(default="Minimalist", description="Style of sound effects used (e.g., tech-inspired, natural)")
|
||||
|
||||
class ShowRules(BaseModel):
|
||||
"""Consistency rules for the podcast narrative and structure."""
|
||||
intro_format: str = Field(..., description="Standard way to start the episode")
|
||||
outro_format: str = Field(..., description="Standard way to end the episode")
|
||||
interaction_tone: str = Field(default="Conversational", description="Tone between hosts or with audience")
|
||||
constraints: List[str] = Field(default_factory=list, description="Specific things to always do or avoid")
|
||||
|
||||
class AudienceDNA(BaseModel):
|
||||
"""Details about the target audience."""
|
||||
expertise_level: str = Field(..., description="Target audience expertise level (Beginner, Intermediate, Expert)")
|
||||
interests: List[str] = Field(default_factory=list, description="Primary interests of the audience")
|
||||
pain_points: List[str] = Field(default_factory=list, description="Common challenges or problems the audience faces")
|
||||
demographics: Optional[str] = Field(None, description="General demographic information")
|
||||
|
||||
class BrandDNA(BaseModel):
|
||||
"""Details about the brand and industry context."""
|
||||
industry: str = Field(..., description="Primary industry or niche")
|
||||
tone: str = Field(..., description="Overall brand tone (e.g., Professional, Casual, Inspirational)")
|
||||
communication_style: str = Field(..., description="Preferred communication style (e.g., Socratic, Storytelling, Analytical)")
|
||||
key_messages: List[str] = Field(default_factory=list, description="Core messages the brand wants to convey")
|
||||
competitor_context: Optional[str] = Field(None, description="Context on how to differentiate from competitors")
|
||||
|
||||
class PodcastBible(BaseModel):
|
||||
"""The complete structured Podcast Bible SSOT."""
|
||||
project_id: Optional[str] = Field(default=None, description="Associated project ID")
|
||||
host: HostPersona = Field(..., description="Host persona details")
|
||||
audience: AudienceDNA = Field(..., description="Target audience details")
|
||||
brand: BrandDNA = Field(..., description="Brand and industry context")
|
||||
visual_style: VisualStyle = Field(..., description="Visual aesthetic and environment")
|
||||
audio_environment: AudioEnvironment = Field(..., description="Soundscape and music details")
|
||||
show_rules: ShowRules = Field(..., description="Consistency and structural rules")
|
||||
@@ -33,6 +33,7 @@ class PodcastProject(Base):
|
||||
|
||||
# Project state (stored as JSON)
|
||||
# This mirrors the PodcastProjectState interface from frontend
|
||||
bible = Column(JSON, nullable=True) # PodcastBible structured data
|
||||
analysis = Column(JSON, nullable=True) # PodcastAnalysis
|
||||
queries = Column(JSON, nullable=True) # List[Query]
|
||||
selected_queries = Column(JSON, nullable=True) # Array of query IDs
|
||||
@@ -56,6 +57,11 @@ class PodcastProject(Base):
|
||||
# Final combined video URL (persisted for reloads)
|
||||
final_video_url = Column(String(1000), nullable=True) # URL to final combined podcast video
|
||||
|
||||
# Avatar details
|
||||
avatar_url = Column(String(1000), nullable=True)
|
||||
avatar_prompt = Column(Text, nullable=True)
|
||||
avatar_persona_id = Column(String(255), nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
@@ -40,11 +40,53 @@ class StoryGenerationRequest(BaseModel):
|
||||
audio_lang: str = Field(default="en", description="Language code for TTS")
|
||||
audio_slow: bool = Field(default=False, description="Whether to speak slowly (gTTS only)")
|
||||
audio_rate: int = Field(default=150, description="Speech rate (pyttsx3 only)")
|
||||
anime_bible: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional structured anime story bible for anime fiction templates",
|
||||
)
|
||||
|
||||
|
||||
class StorySetupGenerationRequest(BaseModel):
|
||||
"""Request model for AI story setup generation."""
|
||||
story_idea: str = Field(..., description="Basic story idea or information from the user")
|
||||
story_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Story mode (marketing or pure) if provided by the UI",
|
||||
)
|
||||
story_template: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional story template identifier (e.g. product_story, brand_manifesto)",
|
||||
)
|
||||
brand_context: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional high-signal brand context derived from onboarding",
|
||||
)
|
||||
|
||||
|
||||
class StoryIdeaEnhanceRequest(BaseModel):
|
||||
"""Request model for AI story idea enhancement."""
|
||||
story_idea: str = Field(..., description="Original story idea or concept text from the user")
|
||||
story_mode: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Story mode (marketing or pure) if provided by the UI",
|
||||
)
|
||||
story_template: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional story template identifier (e.g. product_story, brand_manifesto)",
|
||||
)
|
||||
brand_context: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional high-signal brand context derived from onboarding",
|
||||
)
|
||||
|
||||
fiction_variant: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional fiction-specific focus label (e.g. high-concept twist, shonen action)",
|
||||
)
|
||||
narrative_energy: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Optional narrative energy or pacing hint (e.g. grounded, balanced, cinematic)",
|
||||
)
|
||||
|
||||
|
||||
class StorySetupOption(BaseModel):
|
||||
@@ -78,6 +120,43 @@ class StorySetupOption(BaseModel):
|
||||
audio_lang: str = Field(default="en", description="Language code for TTS")
|
||||
audio_slow: bool = Field(default=False, description="Whether to speak slowly (gTTS only)")
|
||||
audio_rate: int = Field(default=150, description="Speech rate (pyttsx3 only)")
|
||||
anime_bible: Optional["AnimeStoryBible"] = Field(
|
||||
default=None,
|
||||
description="Optional structured anime story bible for anime fiction templates",
|
||||
)
|
||||
|
||||
|
||||
class AnimeCharacter(BaseModel):
|
||||
id: str = Field(..., description="Stable identifier for the character (snake_case)")
|
||||
name: str = Field(..., description="Character name")
|
||||
age_range: str = Field(..., description="Approximate age range (e.g., 'late teens', '30s')")
|
||||
role: str = Field(..., description="Narrative role (protagonist, antagonist, mentor, etc.)")
|
||||
look: str = Field(..., description="Key visual details (hair, build, notable traits)")
|
||||
outfit_palette: str = Field(..., description="Main outfit colors and style")
|
||||
personality_tags: List[str] = Field(default_factory=list, description="Short tags describing personality")
|
||||
|
||||
|
||||
class AnimeWorld(BaseModel):
|
||||
setting: str = Field(..., description="World description and primary locations")
|
||||
era: str = Field(..., description="Time period (near-future, far future, alt 1990s, etc.)")
|
||||
tech_or_magic_level: str = Field(..., description="Technology or magic sophistication level")
|
||||
core_rules: List[str] = Field(default_factory=list, description="Key world rules and constraints")
|
||||
|
||||
|
||||
class AnimeVisualStyle(BaseModel):
|
||||
style_preset: str = Field(..., description="High level style preset (anime_manga, cinematic_anime, cozy_slice_of_life)")
|
||||
camera_style: str = Field(..., description="Typical camera behaviour and framing")
|
||||
color_mood: str = Field(..., description="Dominant color palette and contrast")
|
||||
lighting: str = Field(..., description="Lighting style")
|
||||
line_style: str = Field(..., description="Line art style (thick, thin, rough, etc.)")
|
||||
extra_tags: List[str] = Field(default_factory=list, description="Additional style tags")
|
||||
|
||||
|
||||
class AnimeStoryBible(BaseModel):
|
||||
story_id: Optional[str] = Field(default=None, description="Optional story identifier")
|
||||
main_cast: List[AnimeCharacter] = Field(default_factory=list, description="Main cast of characters")
|
||||
world: AnimeWorld = Field(..., description="World and rules description")
|
||||
visual_style: AnimeVisualStyle = Field(..., description="Visual style anchors for images and video")
|
||||
|
||||
|
||||
class StorySetupGenerationResponse(BaseModel):
|
||||
@@ -86,8 +165,28 @@ class StorySetupGenerationResponse(BaseModel):
|
||||
success: bool = Field(default=True, description="Whether the generation was successful")
|
||||
|
||||
|
||||
class StoryIdeaEnhanceSuggestion(BaseModel):
|
||||
"""A single enhanced story idea suggestion."""
|
||||
idea: str = Field(..., description="AI-enhanced story idea text")
|
||||
whats_missing: str = Field(
|
||||
...,
|
||||
description="Concise explanation of missing or underspecified plot/context elements",
|
||||
)
|
||||
why_choose: str = Field(
|
||||
...,
|
||||
description="Why this idea is a strong direction based on the original input",
|
||||
)
|
||||
|
||||
|
||||
class StoryIdeaEnhanceResponse(BaseModel):
|
||||
"""Response model for story idea enhancement."""
|
||||
suggestions: List[StoryIdeaEnhanceSuggestion] = Field(
|
||||
..., description="List of enhanced story idea suggestions"
|
||||
)
|
||||
success: bool = Field(default=True, description="Whether the enhancement was successful")
|
||||
|
||||
|
||||
class StoryScene(BaseModel):
|
||||
"""Model for a story scene."""
|
||||
scene_number: int = Field(..., description="Scene number")
|
||||
title: str = Field(..., description="Scene title")
|
||||
description: str = Field(..., description="Scene description")
|
||||
@@ -97,6 +196,58 @@ class StoryScene(BaseModel):
|
||||
key_events: List[str] = Field(default_factory=list, description="Key events in the scene")
|
||||
|
||||
|
||||
class AnimeSceneTextRequest(BaseModel):
|
||||
scene: StoryScene = Field(..., description="Scene to refine using the anime bible")
|
||||
persona: str = Field(..., description="Persona context for the scene")
|
||||
story_setting: str = Field(..., description="Story setting")
|
||||
character_input: str = Field(..., description="Characters description from story setup")
|
||||
plot_elements: str = Field(..., description="Plot elements from story setup")
|
||||
writing_style: str = Field(..., description="Writing style")
|
||||
story_tone: str = Field(..., description="Story tone")
|
||||
narrative_pov: str = Field(..., description="Narrative point of view")
|
||||
audience_age_group: str = Field(..., description="Audience age group")
|
||||
content_rating: str = Field(..., description="Content rating")
|
||||
anime_bible: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional anime story bible used to refine the scene",
|
||||
)
|
||||
|
||||
|
||||
class AnimeSceneTextResponse(BaseModel):
|
||||
scene: StoryScene = Field(..., description="Refined scene with bible-aware text and prompts")
|
||||
success: bool = Field(default=True, description="Whether the refinement was successful")
|
||||
|
||||
|
||||
class AnimeSceneGenerateRequest(BaseModel):
|
||||
premise: str = Field(..., description="Overall story premise for context")
|
||||
persona: str = Field(..., description="Persona context for the scene")
|
||||
story_setting: str = Field(..., description="Story setting")
|
||||
character_input: str = Field(..., description="Characters description from story setup")
|
||||
plot_elements: str = Field(..., description="Plot elements from story setup")
|
||||
writing_style: str = Field(..., description="Writing style")
|
||||
story_tone: str = Field(..., description="Story tone")
|
||||
narrative_pov: str = Field(..., description="Narrative point of view")
|
||||
audience_age_group: str = Field(..., description="Audience age group")
|
||||
content_rating: str = Field(..., description="Content rating")
|
||||
anime_bible: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="Anime story bible used as a hard constraint for generation",
|
||||
)
|
||||
previous_scenes: Optional[List[StoryScene]] = Field(
|
||||
default=None,
|
||||
description="Optional list of previous scenes for continuity context",
|
||||
)
|
||||
target_scene_number: Optional[int] = Field(
|
||||
default=None,
|
||||
description="Optional target scene number for the new scene",
|
||||
)
|
||||
|
||||
|
||||
class AnimeSceneGenerateResponse(BaseModel):
|
||||
scene: StoryScene = Field(..., description="Newly generated anime scene based on the bible")
|
||||
success: bool = Field(default=True, description="Whether the scene generation was successful")
|
||||
|
||||
|
||||
class StoryStartRequest(StoryGenerationRequest):
|
||||
"""Request model for story start generation."""
|
||||
premise: str = Field(..., description="The story premise")
|
||||
@@ -116,6 +267,10 @@ class StoryOutlineResponse(BaseModel):
|
||||
success: bool = Field(default=True, description="Whether the generation was successful")
|
||||
task_id: Optional[str] = Field(None, description="Task ID for async operations")
|
||||
is_structured: bool = Field(default=False, description="Whether the outline is structured (scenes) or plain text")
|
||||
anime_bible: Optional[AnimeStoryBible] = Field(
|
||||
default=None,
|
||||
description="Optional structured anime story bible generated from final story setup",
|
||||
)
|
||||
|
||||
|
||||
class StoryContentResponse(BaseModel):
|
||||
@@ -156,6 +311,10 @@ class StoryContinueRequest(BaseModel):
|
||||
content_rating: str = Field(..., description="The content rating")
|
||||
ending_preference: str = Field(..., description="The preferred ending")
|
||||
story_length: str = Field(default="Medium", description="Target story length (Short: >1000 words, Medium: >5000 words, Long: >10000 words)")
|
||||
anime_bible: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description="Optional structured anime story bible for anime fiction templates",
|
||||
)
|
||||
|
||||
|
||||
class StoryContinueResponse(BaseModel):
|
||||
|
||||
55
backend/models/story_project_models.py
Normal file
55
backend/models/story_project_models.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""
|
||||
Story Project Models
|
||||
|
||||
Database models for Story Studio project persistence and state management.
|
||||
Modeled after PodcastProject and ResearchProject for cross-device resume.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, JSON, Index
|
||||
|
||||
from models.subscription_models import Base
|
||||
|
||||
|
||||
class StoryProject(Base):
|
||||
"""
|
||||
Database model for Story Studio project state.
|
||||
Stores complete story project state to enable cross-device resume.
|
||||
"""
|
||||
|
||||
__tablename__ = "story_projects"
|
||||
|
||||
# Primary fields
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
project_id = Column(String(255), unique=True, nullable=False, index=True)
|
||||
user_id = Column(String(255), nullable=False, index=True)
|
||||
|
||||
# Project metadata
|
||||
title = Column(String(500), nullable=True)
|
||||
story_mode = Column(String(50), nullable=True)
|
||||
story_template = Column(String(100), nullable=True)
|
||||
|
||||
# Story state (stored as JSON)
|
||||
setup = Column(JSON, nullable=True)
|
||||
outline = Column(JSON, nullable=True)
|
||||
scenes = Column(JSON, nullable=True)
|
||||
story_content = Column(JSON, nullable=True)
|
||||
anime_bible = Column(JSON, nullable=True)
|
||||
media_state = Column(JSON, nullable=True)
|
||||
|
||||
# UI/progress state
|
||||
current_phase = Column(String(50), nullable=True)
|
||||
status = Column(String(50), default="draft", nullable=False, index=True)
|
||||
is_favorite = Column(Boolean, default=False, index=True)
|
||||
is_complete = Column(Boolean, default=False)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime, default=datetime.utcnow, nullable=False, index=True)
|
||||
updated_at = Column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index("idx_story_user_status_created", "user_id", "status", "created_at"),
|
||||
Index("idx_story_user_favorite_updated", "user_id", "is_favorite", "updated_at"),
|
||||
)
|
||||
|
||||
@@ -26,6 +26,8 @@ class UsageStatus(enum.Enum):
|
||||
WARNING = "warning" # 80% usage
|
||||
LIMIT_REACHED = "limit_reached" # 100% usage
|
||||
SUSPENDED = "suspended"
|
||||
CANCELLED = "cancelled"
|
||||
PAST_DUE = "past_due"
|
||||
|
||||
class APIProvider(enum.Enum):
|
||||
GEMINI = "gemini"
|
||||
@@ -389,4 +391,20 @@ class SubscriptionRenewalHistory(Base):
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
{'mysql_engine': 'InnoDB'},
|
||||
)
|
||||
)
|
||||
|
||||
class FraudWarning(Base):
|
||||
__tablename__ = "fraud_warnings"
|
||||
|
||||
id = Column(String(100), primary_key=True)
|
||||
charge_id = Column(String(100), nullable=False)
|
||||
payment_intent_id = Column(String(100), nullable=True)
|
||||
user_id = Column(String(100), nullable=True)
|
||||
amount = Column(Integer, nullable=False, default=0)
|
||||
currency = Column(String(10), nullable=False, default="")
|
||||
status = Column(String(20), nullable=False, default="open")
|
||||
action = Column(String(20), nullable=False, default="none")
|
||||
action_at = Column(DateTime, nullable=True)
|
||||
reason_notes = Column(Text, nullable=True)
|
||||
meta_info = Column(JSON, nullable=True)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
|
||||
@@ -14,6 +14,9 @@ fastapi-clerk-auth>=0.0.7
|
||||
# Database dependencies
|
||||
sqlalchemy>=2.0.25
|
||||
|
||||
# Payment processing
|
||||
stripe>=8.0.0
|
||||
|
||||
# CopilotKit and Research
|
||||
copilotkit
|
||||
exa-py==1.9.1
|
||||
|
||||
@@ -11,7 +11,7 @@ from loguru import logger
|
||||
from services.integrations.bing_oauth import BingOAuthService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/bing", tags=["Bing Analytics"])
|
||||
router = APIRouter(prefix="/api/bing", tags=["Bing Analytics"])
|
||||
|
||||
# Initialize Bing OAuth service
|
||||
bing_service = BingOAuthService()
|
||||
@@ -26,7 +26,7 @@ async def get_query_stats(
|
||||
):
|
||||
"""Get search query statistics for a Bing Webmaster site."""
|
||||
try:
|
||||
user_id = current_user.get("user_id")
|
||||
user_id = current_user.get("id") or current_user.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
|
||||
@@ -67,7 +67,7 @@ async def get_user_sites(
|
||||
):
|
||||
"""Get list of user's verified sites from Bing Webmaster."""
|
||||
try:
|
||||
user_id = current_user.get("user_id")
|
||||
user_id = current_user.get("id") or current_user.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
|
||||
@@ -98,7 +98,7 @@ async def get_query_stats_summary(
|
||||
):
|
||||
"""Get summarized query statistics for a Bing Webmaster site."""
|
||||
try:
|
||||
user_id = current_user.get("user_id")
|
||||
user_id = current_user.get("id") or current_user.get("user_id")
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
|
||||
|
||||
@@ -6,17 +6,46 @@ Provides endpoints for retrieving analytics data from connected platforms.
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Depends, Query
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
|
||||
from services.analytics import PlatformAnalyticsService
|
||||
from middleware.auth_middleware import get_current_user
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
router = APIRouter(prefix="/api/analytics", tags=["Platform Analytics"])
|
||||
|
||||
# Initialize analytics service
|
||||
analytics_service = PlatformAnalyticsService()
|
||||
|
||||
@router.post("/cache/clear")
|
||||
async def clear_analytics_cache(
|
||||
platform: Optional[str] = Query(None, description="Specific platform to clear (e.g., 'bing', 'gsc')"),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Clear analytics cache for the current user.
|
||||
If 'platform' is provided, clears only that platform's cache; otherwise clears all and connection status.
|
||||
"""
|
||||
try:
|
||||
user_id = current_user.get('id')
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="User ID not found")
|
||||
|
||||
if platform:
|
||||
analytics_service.invalidate_platform_cache(user_id, platform)
|
||||
else:
|
||||
analytics_service.invalidate_platform_cache(user_id)
|
||||
|
||||
# Always refresh connection status cache as well
|
||||
analytics_service.invalidate_connection_cache(user_id)
|
||||
|
||||
return { "success": True, "message": "Analytics cache cleared", "platform": platform or "all" }
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clear analytics cache: {e}")
|
||||
return { "success": False, "error": str(e) }
|
||||
|
||||
|
||||
class AnalyticsRequest(BaseModel):
|
||||
"""Request model for analytics data"""
|
||||
@@ -65,7 +94,9 @@ async def get_platform_connection_status(current_user: dict = Depends(get_curren
|
||||
|
||||
@router.get("/data")
|
||||
async def get_analytics_data(
|
||||
platforms: Optional[str] = Query(None, description="Comma-separated list of platforms (gsc,wix,wordpress)"),
|
||||
platforms: Optional[str] = Query(None, description="Comma-separated list of platforms (gsc,bing,wix,wordpress)"),
|
||||
start_date: Optional[str] = Query(None, description="Start date (YYYY-MM-DD)"),
|
||||
end_date: Optional[str] = Query(None, description="End date (YYYY-MM-DD)"),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> AnalyticsResponse:
|
||||
"""
|
||||
@@ -88,15 +119,31 @@ async def get_analytics_data(
|
||||
if platforms:
|
||||
platform_list = [p.strip() for p in platforms.split(',') if p.strip()]
|
||||
|
||||
logger.info(f"Getting analytics data for user: {user_id}, platforms: {platform_list}")
|
||||
logger.info(f"Getting analytics data for user: {user_id}, platforms: {platform_list}, start_date: {start_date}, end_date: {end_date}")
|
||||
|
||||
# Get analytics data
|
||||
analytics_data = await analytics_service.get_comprehensive_analytics(user_id, platform_list)
|
||||
|
||||
# Generate summary
|
||||
analytics_data = await analytics_service.get_comprehensive_analytics(user_id, platform_list, start_date=start_date, end_date=end_date)
|
||||
summary = analytics_service.get_analytics_summary(analytics_data)
|
||||
|
||||
# Convert AnalyticsData objects to dictionaries
|
||||
logger.warning(
|
||||
"Analytics summary for user {user}: total_clicks={clicks}, total_impressions={impr}, overall_ctr={ctr}, platforms={platforms}",
|
||||
user=user_id,
|
||||
clicks=summary.get("total_clicks"),
|
||||
impr=summary.get("total_impressions"),
|
||||
ctr=summary.get("overall_ctr"),
|
||||
platforms=list(analytics_data.keys()),
|
||||
)
|
||||
for platform_name, data in analytics_data.items():
|
||||
try:
|
||||
logger.warning(
|
||||
"Analytics platform snapshot {platform}: status={status}, total_clicks={clicks}, total_impressions={impr}",
|
||||
platform=platform_name,
|
||||
status=data.status,
|
||||
clicks=data.get_total_clicks(),
|
||||
impr=data.get_total_impressions(),
|
||||
)
|
||||
except Exception as log_err:
|
||||
logger.warning(f"Failed to log platform snapshot for {platform_name}: {log_err}")
|
||||
|
||||
data_dict = {}
|
||||
for platform, data in analytics_data.items():
|
||||
data_dict[platform] = {
|
||||
@@ -148,7 +195,14 @@ async def get_analytics_data_post(
|
||||
logger.info(f"Getting analytics data for user: {user_id}, platforms: {request.platforms}")
|
||||
|
||||
# Get analytics data
|
||||
analytics_data = await analytics_service.get_comprehensive_analytics(user_id, request.platforms)
|
||||
# Extract optional dates
|
||||
start_date = None
|
||||
end_date = None
|
||||
if request.date_range and isinstance(request.date_range, dict):
|
||||
start_date = request.date_range.get('start')
|
||||
end_date = request.date_range.get('end')
|
||||
|
||||
analytics_data = await analytics_service.get_comprehensive_analytics(user_id, request.platforms, start_date=start_date, end_date=end_date)
|
||||
|
||||
# Generate summary
|
||||
summary = analytics_service.get_analytics_summary(analytics_data)
|
||||
@@ -250,12 +304,196 @@ async def get_analytics_summary(current_user: dict = Depends(get_current_user))
|
||||
"platforms_connected": summary['connected_platforms'],
|
||||
"platforms_total": summary['total_platforms']
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get analytics summary: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/ai-insights")
|
||||
async def get_ai_insights(
|
||||
start_date: Optional[str] = Query(None),
|
||||
end_date: Optional[str] = Query(None),
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> Dict[str, Any]:
|
||||
try:
|
||||
user_id = current_user.get('id')
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=400, detail="User ID not found")
|
||||
sd = start_date
|
||||
ed = end_date
|
||||
if not sd or not ed:
|
||||
today = datetime.utcnow().date()
|
||||
ed = today.isoformat()
|
||||
sd = (today - timedelta(days=29)).isoformat()
|
||||
analytics = await analytics_service.get_comprehensive_analytics(user_id, ['gsc'], start_date=sd, end_date=ed)
|
||||
gsc = analytics.get('gsc')
|
||||
if not gsc or gsc.status != 'success':
|
||||
return {"success": False, "error": gsc.error_message if gsc else "GSC data unavailable"}
|
||||
metrics = gsc.metrics or {}
|
||||
tq = metrics.get('top_queries') or []
|
||||
tp = metrics.get('top_pages') or []
|
||||
cannib = metrics.get('cannibalization') or []
|
||||
sdt = datetime.strptime(sd, "%Y-%m-%d").date()
|
||||
edt = datetime.strptime(ed, "%Y-%m-%d").date()
|
||||
window_days = max((edt - sdt).days + 1, 1)
|
||||
def thr_impr():
|
||||
if window_days <= 7:
|
||||
return 100
|
||||
if window_days <= 30:
|
||||
return 500
|
||||
return 1500
|
||||
def thr_clicks():
|
||||
if window_days <= 7:
|
||||
return 10
|
||||
if window_days <= 30:
|
||||
return 30
|
||||
return 60
|
||||
low_ctr_queries = []
|
||||
for r in tq:
|
||||
imp = float(r.get('impressions', 0) or 0)
|
||||
ctr = float(r.get('ctr', 0) or 0)
|
||||
if imp >= thr_impr() and ctr <= 2.5:
|
||||
low_ctr_queries.append({
|
||||
"query": r.get('query'),
|
||||
"impressions": int(round(imp)),
|
||||
"ctr": round(ctr, 2),
|
||||
"clicks": int(round(float(r.get('clicks', 0) or 0))),
|
||||
"position": round(float(r.get('position', 0) or 0), 2) if 'position' in r else None
|
||||
})
|
||||
striking_distance = []
|
||||
for r in tq:
|
||||
pos = float(r.get('position', 0) or 0)
|
||||
imp = float(r.get('impressions', 0) or 0)
|
||||
if 8.0 <= pos <= 20.0 and imp >= (80 if window_days <= 7 else (300 if window_days <= 30 else 1000)):
|
||||
striking_distance.append({
|
||||
"query": r.get('query'),
|
||||
"impressions": int(round(imp)),
|
||||
"position": round(pos, 2),
|
||||
"clicks": int(round(float(r.get('clicks', 0) or 0)))
|
||||
})
|
||||
low_ctr_pages = []
|
||||
for p in tp:
|
||||
imp = float(p.get('impressions', 0) or 0)
|
||||
ctr = float(p.get('ctr', 0) or 0)
|
||||
if imp >= thr_impr() and ctr <= 2.0:
|
||||
low_ctr_pages.append({
|
||||
"page": p.get('page'),
|
||||
"impressions": int(round(imp)),
|
||||
"ctr": round(ctr, 2),
|
||||
"clicks": int(round(float(p.get('clicks', 0) or 0)))
|
||||
})
|
||||
serp_feature_loss = []
|
||||
for r in tq:
|
||||
pos = float(r.get('position', 0) or 0)
|
||||
imp = float(r.get('impressions', 0) or 0)
|
||||
ctr = float(r.get('ctr', 0) or 0)
|
||||
if pos > 0 and pos <= 5.0 and imp >= thr_impr() and ctr <= 2.0:
|
||||
serp_feature_loss.append({
|
||||
"query": r.get('query'),
|
||||
"impressions": int(round(imp)),
|
||||
"position": round(pos, 2),
|
||||
"ctr": round(ctr, 2),
|
||||
"clicks": int(round(float(r.get('clicks', 0) or 0)))
|
||||
})
|
||||
def build_map(rows):
|
||||
m = {}
|
||||
for r in rows:
|
||||
k = r.get('query')
|
||||
if not k:
|
||||
continue
|
||||
m[k] = {
|
||||
"clicks": float(r.get('clicks', 0) or 0),
|
||||
"impressions": float(r.get('impressions', 0) or 0)
|
||||
}
|
||||
return m
|
||||
prev_end = (sdt - timedelta(days=1)).isoformat()
|
||||
prev_start = (sdt - timedelta(days=window_days)).isoformat()
|
||||
prev_analytics = await analytics_service.get_comprehensive_analytics(user_id, ['gsc'], start_date=prev_start, end_date=prev_end)
|
||||
prev_gsc = prev_analytics.get('gsc')
|
||||
prev_tq = prev_gsc.metrics.get('top_queries') if prev_gsc and prev_gsc.metrics else []
|
||||
curr_map = build_map(tq)
|
||||
prev_map = build_map(prev_tq)
|
||||
declining_queries = []
|
||||
for q, v in curr_map.items():
|
||||
pv = prev_map.get(q) or {"clicks": 0.0, "impressions": 0.0}
|
||||
dc = int(round(v["clicks"] - pv["clicks"]))
|
||||
di = int(round(v["impressions"] - pv["impressions"]))
|
||||
if dc < 0 or di < 0:
|
||||
if abs(dc) >= 5 or abs(di) >= thr_impr() * 0.2:
|
||||
declining_queries.append({
|
||||
"query": q,
|
||||
"delta_clicks": dc,
|
||||
"delta_impressions": di,
|
||||
"prev_clicks": int(round(pv["clicks"])),
|
||||
"prev_impressions": int(round(pv["impressions"]))
|
||||
})
|
||||
low_ctr_queries = sorted(low_ctr_queries, key=lambda x: (-x["impressions"], x["ctr"]))[:10]
|
||||
striking_distance = sorted(striking_distance, key=lambda x: -x["impressions"])[:10]
|
||||
low_ctr_pages = sorted(low_ctr_pages, key=lambda x: (-x["impressions"], x["ctr"]))[:10]
|
||||
cannib_list = cannib[:10]
|
||||
serp_feature_loss = sorted(serp_feature_loss, key=lambda x: -x["impressions"])[:10]
|
||||
payload = {
|
||||
"context": {
|
||||
"site_url": None,
|
||||
"date_range": {"start": sd, "end": ed},
|
||||
"window_days": window_days
|
||||
},
|
||||
"signals": {
|
||||
"low_ctr_queries": low_ctr_queries,
|
||||
"striking_distance": striking_distance,
|
||||
"declining_queries": declining_queries[:10],
|
||||
"low_ctr_pages": low_ctr_pages,
|
||||
"cannibalization": cannib_list,
|
||||
"serp_feature_loss": serp_feature_loss
|
||||
},
|
||||
"limits": {
|
||||
"max_items_per_signal": 10,
|
||||
"language": "en",
|
||||
"tone": "simple"
|
||||
}
|
||||
}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"quick_summary": {"type": "string"},
|
||||
"prioritized_findings": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"severity": {"type": "string"},
|
||||
"audience_note": {"type": "string"},
|
||||
"evidence": {"type": "string"},
|
||||
"why_it_matters": {"type": "string"},
|
||||
"actions": {"type": "array", "items": {"type": "string"}},
|
||||
"effort": {"type": "string"}
|
||||
}
|
||||
}
|
||||
},
|
||||
"playbooks": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title_meta_fixes": {"type": "array", "items": {"type": "object"}},
|
||||
"consolidation": {"type": "array", "items": {"type": "object"}},
|
||||
"refreshes": {"type": "array", "items": {"type": "object"}},
|
||||
"internal_linking": {"type": "array", "items": {"type": "object"}}
|
||||
}
|
||||
},
|
||||
"metrics": {"type": "object"}
|
||||
}
|
||||
}
|
||||
system_prompt = "You are an SEO assistant for non-technical creators. Use simple language and concrete actions. Only use provided numbers. Return a single JSON object matching the schema."
|
||||
prompt = "Analyze the following GSC-derived signals and produce prioritized findings and playbooks.\n\n" + str(payload)
|
||||
ai = llm_text_gen(prompt=prompt, json_struct=schema, system_prompt=system_prompt, user_id=user_id)
|
||||
return {"success": True, "insights": ai}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"AI insights failed: {e}")
|
||||
return {"success": False, "error": str(e)}
|
||||
|
||||
|
||||
@router.get("/cache/test")
|
||||
async def test_cache_endpoint(current_user: dict = Depends(get_current_user)) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
141
backend/scripts/create_story_project_tables.py
Normal file
141
backend/scripts/create_story_project_tables.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
Database Migration Script for Story Studio
|
||||
Creates the story_projects table for cross-device story project persistence.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
from sqlalchemy import create_engine, text
|
||||
import traceback
|
||||
|
||||
# Add the backend directory to Python path
|
||||
backend_dir = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(backend_dir))
|
||||
|
||||
from models.subscription_models import Base as SubscriptionBase
|
||||
from models.story_project_models import StoryProject # noqa: F401
|
||||
from services.database import DATABASE_URL
|
||||
|
||||
|
||||
def create_story_tables() -> None:
|
||||
"""Create story-related project tables."""
|
||||
try:
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
|
||||
logger.info("Creating Story Studio project tables...")
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
logger.info("✅ Story project tables created successfully")
|
||||
|
||||
display_setup_summary(engine)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Error creating story project tables: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
raise
|
||||
|
||||
|
||||
def display_setup_summary(engine) -> None:
|
||||
"""Display a summary of the created tables."""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("STORY STUDIO PROJECT SETUP SUMMARY")
|
||||
logger.info("=" * 60)
|
||||
|
||||
check_query = text(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='story_projects'
|
||||
"""
|
||||
)
|
||||
|
||||
result = conn.execute(check_query)
|
||||
table_exists = result.fetchone()
|
||||
|
||||
if table_exists:
|
||||
logger.info("✅ Table 'story_projects' created successfully")
|
||||
|
||||
schema_query = text(
|
||||
"""
|
||||
SELECT sql FROM sqlite_master
|
||||
WHERE type='table' AND name='story_projects'
|
||||
"""
|
||||
)
|
||||
result = conn.execute(schema_query)
|
||||
schema = result.fetchone()
|
||||
if schema:
|
||||
logger.info("\n📋 Table Schema:")
|
||||
logger.info(schema[0])
|
||||
|
||||
indexes_query = text(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='index' AND tbl_name='story_projects'
|
||||
"""
|
||||
)
|
||||
result = conn.execute(indexes_query)
|
||||
indexes = result.fetchall()
|
||||
|
||||
if indexes:
|
||||
logger.info(f"\n📊 Indexes ({len(indexes)}):")
|
||||
for idx in indexes:
|
||||
logger.info(f" • {idx[0]}")
|
||||
else:
|
||||
logger.warning("⚠️ Table 'story_projects' not found after creation")
|
||||
|
||||
logger.info("\n" + "=" * 60)
|
||||
logger.info("NEXT STEPS:")
|
||||
logger.info("=" * 60)
|
||||
logger.info("1. The story_projects table is ready for use")
|
||||
logger.info("2. Story Studio projects will sync to database via new endpoints")
|
||||
logger.info("3. Users will be able to resume Story Studio sessions across devices")
|
||||
logger.info("=" * 60)
|
||||
except Exception as e:
|
||||
logger.error(f"Error displaying Story Studio setup summary: {e}")
|
||||
|
||||
|
||||
def check_existing_table(engine) -> bool:
|
||||
"""Check if story_projects table already exists."""
|
||||
try:
|
||||
with engine.connect() as conn:
|
||||
check_query = text(
|
||||
"""
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name='story_projects'
|
||||
"""
|
||||
)
|
||||
|
||||
result = conn.execute(check_query)
|
||||
table_exists = result.fetchone()
|
||||
|
||||
if table_exists:
|
||||
logger.info("ℹ️ Table 'story_projects' already exists")
|
||||
logger.info(" Running migration will ensure schema is up to date...")
|
||||
return True
|
||||
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking existing Story Studio table: {e}")
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("🚀 Starting Story Studio database migration...")
|
||||
|
||||
try:
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
|
||||
check_existing_table(engine)
|
||||
|
||||
create_story_tables()
|
||||
|
||||
logger.info("✅ Story Studio migration completed successfully!")
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Migration cancelled by user")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Story Studio migration failed: {e}")
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
|
||||
@@ -121,7 +121,8 @@ class BaseALwrityAgent(ABC):
|
||||
if TXTAI_AVAILABLE:
|
||||
try:
|
||||
if not self.llm:
|
||||
self.llm = LLM(model_name)
|
||||
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
|
||||
self.llm = LLM(model_name, task="text-generation")
|
||||
|
||||
self.txtai_agent = self._create_txtai_agent()
|
||||
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
|
||||
|
||||
@@ -4,7 +4,6 @@ Bing Webmaster Tools Analytics Handler
|
||||
Handles Bing Webmaster Tools analytics data retrieval and processing.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from loguru import logger
|
||||
@@ -16,13 +15,23 @@ from ..models.platform_types import PlatformType
|
||||
from .base_handler import BaseAnalyticsHandler
|
||||
from ..insights.bing_insights_service import BingInsightsService
|
||||
from services.bing_analytics_storage_service import BingAnalyticsStorageService
|
||||
import os
|
||||
|
||||
|
||||
from services.database import get_user_db_path
|
||||
|
||||
class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
"""Handler for Bing Webmaster Tools analytics"""
|
||||
"""
|
||||
Handler for Bing Webmaster Tools analytics
|
||||
|
||||
NOTE (2026-02-14): Known issues and directions
|
||||
- Verified sites list can be empty despite valid tokens. This leads to partial/error states and prevents storage collection.
|
||||
Direction: UI now provides a manual site picker (with primary website fallback from onboarding) to trigger storage collection,
|
||||
and a future improvement should accept a target_url from /api/analytics/data to influence site selection here.
|
||||
- Token state mismatch (status shows connected, analytics reports expired) can happen across cache boundaries.
|
||||
Direction: The frontend auto-resyncs once after OAuth success and provides a backend cache clear endpoint.
|
||||
- Storage-backed summary reads rely on a selected site; when sites are missing, selected_site is None.
|
||||
Direction: Allow explicit site_url parameter in the analytics orchestrator to override selected_site resolution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(PlatformType.BING)
|
||||
@@ -42,14 +51,22 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
db_url = f'sqlite:///{db_path}'
|
||||
return BingInsightsService(db_url)
|
||||
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Bing Webmaster analytics data using Bing Webmaster API
|
||||
"""
|
||||
self.log_analytics_request(user_id, "get_analytics")
|
||||
|
||||
# Check cache first
|
||||
cached_data = analytics_cache.get('bing_analytics', user_id)
|
||||
# Check cache first (include date range and target_url in key)
|
||||
cache_key_parts = [user_id]
|
||||
if target_url:
|
||||
cache_key_parts.append(str(target_url))
|
||||
if start_date:
|
||||
cache_key_parts.append(str(start_date))
|
||||
if end_date:
|
||||
cache_key_parts.append(str(end_date))
|
||||
cache_key = "_".join(cache_key_parts)
|
||||
cached_data = analytics_cache.get('bing_analytics', cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Using cached Bing analytics for user {user_id}")
|
||||
return AnalyticsData(**cached_data)
|
||||
@@ -107,9 +124,22 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
site_url_for_storage = selected_site.get('Url', '') if selected_site else ''
|
||||
logger.info(f"Using Bing site URL: {site_url_for_storage}")
|
||||
|
||||
# Determine date range (defaults to last 30 days)
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
# Compute days for storage/insights services (at least 1)
|
||||
try:
|
||||
dt_end = datetime.strptime(end_date, '%Y-%m-%d')
|
||||
dt_start = datetime.strptime(start_date, '%Y-%m-%d')
|
||||
days_range = max(1, (dt_end - dt_start).days + 1)
|
||||
except Exception:
|
||||
days_range = 30
|
||||
|
||||
query_stats = {}
|
||||
try:
|
||||
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=30)
|
||||
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=days_range)
|
||||
if stored and isinstance(stored, dict):
|
||||
query_stats = {
|
||||
'total_clicks': stored.get('summary', {}).get('total_clicks', 0),
|
||||
@@ -138,19 +168,20 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
'insights': insights,
|
||||
'note': 'Bing Webmaster API provides SEO insights, search performance, and index status data'
|
||||
}
|
||||
|
||||
if (not sites) or (metrics.get('total_impressions', 0) == 0 and metrics.get('total_clicks', 0) == 0):
|
||||
result = self.create_partial_response(metrics=metrics, error_message='Connected to Bing; waiting for stored analytics or site verification')
|
||||
|
||||
if not sites:
|
||||
result = self.create_partial_response(metrics=metrics, error_message='Connected to Bing; no verified sites found')
|
||||
else:
|
||||
result = self.create_success_response(metrics=metrics)
|
||||
result = self.create_success_response(metrics=metrics, date_range={'start': start_date, 'end': end_date})
|
||||
|
||||
analytics_cache.set('bing_analytics', user_id, result.__dict__)
|
||||
analytics_cache.set('bing_analytics', cache_key, result.__dict__)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.log_analytics_error(user_id, "get_analytics", e)
|
||||
error_result = self.create_error_response(str(e))
|
||||
analytics_cache.set('bing_analytics', user_id, error_result.__dict__, ttl_override=300)
|
||||
# Cache error briefly to prevent hammering but recover quickly
|
||||
analytics_cache.set('bing_analytics', cache_key, error_result.__dict__, ttl_override=30)
|
||||
return error_result
|
||||
|
||||
def _get_enhanced_insights_with_service(self, insights_service: BingInsightsService, user_id: str, site_url: str) -> Dict[str, Any]:
|
||||
|
||||
@@ -22,7 +22,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
super().__init__(PlatformType.GSC)
|
||||
self.gsc_service = GSCService()
|
||||
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Google Search Console analytics data with caching
|
||||
|
||||
@@ -35,8 +35,16 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
self.log_analytics_request(user_id, "get_analytics")
|
||||
|
||||
# Check cache first - GSC API calls can be expensive
|
||||
# Include target_url in cache key if provided
|
||||
cache_key = f"{user_id}_{target_url}" if target_url else user_id
|
||||
# Include target_url and date range in cache key if provided
|
||||
cache_key_parts = [user_id]
|
||||
if target_url:
|
||||
cache_key_parts.append(str(target_url))
|
||||
if start_date:
|
||||
cache_key_parts.append(str(start_date))
|
||||
if end_date:
|
||||
cache_key_parts.append(str(end_date))
|
||||
# Bump cache version to include page insights (v2)
|
||||
cache_key = "_".join(cache_key_parts + ['v2pages'])
|
||||
cached_data = analytics_cache.get('gsc_analytics', cache_key)
|
||||
if cached_data:
|
||||
logger.info("Using cached GSC analytics for user {user_id}", user_id=user_id)
|
||||
@@ -70,9 +78,11 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
site_url = selected_site['siteUrl']
|
||||
logger.info(f"Using GSC site URL: {site_url}")
|
||||
|
||||
# Get search analytics for last 30 days
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
# Determine date range (defaults to last 30 days)
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
logger.info(f"GSC Date range: {start_date} to {end_date}")
|
||||
|
||||
search_analytics = self.gsc_service.get_search_analytics(
|
||||
@@ -86,10 +96,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
# Process GSC data into standardized format
|
||||
processed_metrics = self._process_gsc_metrics(search_analytics)
|
||||
|
||||
result = self.create_success_response(
|
||||
metrics=processed_metrics,
|
||||
date_range={'start': start_date, 'end': end_date}
|
||||
)
|
||||
result = self.create_success_response(metrics=processed_metrics, date_range={'start': start_date, 'end': end_date})
|
||||
|
||||
# Cache the result to avoid expensive API calls
|
||||
analytics_cache.set('gsc_analytics', cache_key, result.__dict__)
|
||||
@@ -101,8 +108,8 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
self.log_analytics_error(user_id, "get_analytics", e)
|
||||
error_result = self.create_error_response(str(e))
|
||||
|
||||
# Cache error result for shorter time to retry sooner
|
||||
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=300) # 5 minutes
|
||||
# Cache error result briefly to avoid repeated failures but allow quick recovery
|
||||
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=30) # 30 seconds
|
||||
return error_result
|
||||
|
||||
def get_connection_status(self, user_id: str) -> Dict[str, Any]:
|
||||
@@ -202,18 +209,159 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
sorted_queries = sorted(top_queries_source, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
|
||||
|
||||
for row in sorted_queries:
|
||||
clicks_val = row.get('clicks', 0) or 0
|
||||
impr_val = row.get('impressions', 0) or 0
|
||||
raw_ctr = row.get('ctr', None)
|
||||
# Calculate CTR% robustly even if 'ctr' field is missing in row
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = round(float(raw_ctr) * 100, 2)
|
||||
else:
|
||||
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
|
||||
top_queries.append({
|
||||
'query': self._extract_query_from_row(row),
|
||||
'clicks': row.get('clicks', 0),
|
||||
'impressions': row.get('impressions', 0),
|
||||
'ctr': round(row.get('ctr', 0) * 100, 2),
|
||||
'position': round(row.get('position', 0), 2)
|
||||
'clicks': clicks_val,
|
||||
'impressions': impr_val,
|
||||
'ctr': ctr_percent,
|
||||
'position': round(row.get('position', 0) or 0, 2)
|
||||
})
|
||||
|
||||
# Prepare Top Pages (requires page dimension, but we only requested query dimension in gsc_service step 3)
|
||||
# To get top pages, we would need another API call with dimension=['page']
|
||||
# For now, we'll return empty top_pages or infer from what we have if possible (we can't from query data)
|
||||
top_pages = []
|
||||
# Prepare Top Pages from page_data when available
|
||||
top_pages = []
|
||||
try:
|
||||
page_rows = search_analytics.get('page_data', {}).get('rows', [])
|
||||
qp_rows = search_analytics.get('query_page_data', {}).get('rows', [])
|
||||
# Build queries-by-page map
|
||||
queries_by_page: Dict[str, list] = {}
|
||||
if qp_rows:
|
||||
for r in qp_rows:
|
||||
keys = r.get('keys', [])
|
||||
if not keys or len(keys) < 2:
|
||||
continue
|
||||
query_key = keys[0]['keys'][0] if isinstance(keys[0], dict) else str(keys[0])
|
||||
page_key = keys[1]['keys'][0] if isinstance(keys[1], dict) else str(keys[1])
|
||||
clicks_val = r.get('clicks', 0) or 0
|
||||
impr_val = r.get('impressions', 0) or 0
|
||||
raw_ctr = r.get('ctr', None)
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = round(float(raw_ctr) * 100, 2)
|
||||
else:
|
||||
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
|
||||
lst = queries_by_page.setdefault(page_key, [])
|
||||
lst.append({
|
||||
'query': query_key,
|
||||
'clicks': clicks_val,
|
||||
'impressions': impr_val,
|
||||
'ctr': ctr_percent,
|
||||
})
|
||||
if page_rows:
|
||||
sorted_pages = sorted(page_rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
|
||||
for row in sorted_pages:
|
||||
clicks_val = row.get('clicks', 0) or 0
|
||||
impr_val = row.get('impressions', 0) or 0
|
||||
raw_ctr = row.get('ctr', None)
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = round(float(raw_ctr) * 100, 2)
|
||||
else:
|
||||
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
|
||||
page_url = self._extract_page_from_row(row)
|
||||
# attach top queries pointing to this page, sorted by clicks
|
||||
page_queries = sorted(queries_by_page.get(page_url, []), key=lambda x: x.get('clicks', 0), reverse=True)[:5]
|
||||
top_pages.append({
|
||||
'page': page_url,
|
||||
'clicks': clicks_val,
|
||||
'impressions': impr_val,
|
||||
'ctr': ctr_percent,
|
||||
'position': round(row.get('position', 0) or 0, 2) if 'position' in row else None,
|
||||
'queries': page_queries
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed processing top_pages: {e}")
|
||||
|
||||
# Detect Cannibalization (query mapping to multiple pages)
|
||||
cannibalization = []
|
||||
try:
|
||||
qp_rows = search_analytics.get('query_page_data', {}).get('rows', [])
|
||||
q_rows = search_analytics.get('query_data', {}).get('rows', [])
|
||||
if qp_rows:
|
||||
# Determine window days for thresholding
|
||||
from datetime import datetime
|
||||
start_s = search_analytics.get('startDate')
|
||||
end_s = search_analytics.get('endDate')
|
||||
window_days = 30
|
||||
try:
|
||||
if start_s and end_s:
|
||||
sd = datetime.strptime(start_s, "%Y-%m-%d")
|
||||
ed = datetime.strptime(end_s, "%Y-%m-%d")
|
||||
window_days = max((ed - sd).days + 1, 1)
|
||||
except Exception:
|
||||
pass
|
||||
min_clicks = 10 if window_days <= 7 else (30 if window_days <= 30 else 60)
|
||||
# Build map: query -> { page -> metrics }
|
||||
by_query: Dict[str, Dict[str, Dict[str, float]]] = {}
|
||||
for r in qp_rows:
|
||||
keys = r.get('keys', [])
|
||||
if not keys or len(keys) < 2:
|
||||
continue
|
||||
qk = keys[0]['keys'][0] if isinstance(keys[0], dict) else str(keys[0])
|
||||
pk = keys[1]['keys'][0] if isinstance(keys[1], dict) else str(keys[1])
|
||||
clicks_val = float(r.get('clicks', 0) or 0)
|
||||
impr_val = float(r.get('impressions', 0) or 0)
|
||||
raw_ctr = r.get('ctr', None)
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = float(raw_ctr) * 100.0
|
||||
else:
|
||||
ctr_percent = (clicks_val / impr_val * 100.0) if impr_val > 0 else 0.0
|
||||
pos_val = float(r.get('position', 0) or 0)
|
||||
by_query.setdefault(qk, {}).setdefault(pk, {"clicks": 0.0, "impressions": 0.0, "ctr": 0.0, "position_sum": 0.0, "position_count": 0.0})
|
||||
agg = by_query[qk][pk]
|
||||
agg["clicks"] += clicks_val
|
||||
agg["impressions"] += impr_val
|
||||
agg["ctr"] = max(agg["ctr"], ctr_percent)
|
||||
if pos_val > 0:
|
||||
agg["position_sum"] += pos_val
|
||||
agg["position_count"] += 1
|
||||
# Use query totals for context
|
||||
total_by_query: Dict[str, Dict[str, float]] = {}
|
||||
for r in q_rows or []:
|
||||
qk = self._extract_query_from_row(r)
|
||||
total_by_query[qk] = {
|
||||
"clicks": float(r.get('clicks', 0) or 0),
|
||||
"impressions": float(r.get('impressions', 0) or 0),
|
||||
"position": float(r.get('position', 0) or 0)
|
||||
}
|
||||
for qk, pages_map in by_query.items():
|
||||
if len(pages_map) < 2:
|
||||
continue
|
||||
total_clicks = sum(p["clicks"] for p in pages_map.values())
|
||||
if total_clicks < min_clicks:
|
||||
continue
|
||||
qpos = total_by_query.get(qk, {}).get("position", 0.0)
|
||||
if not (3.0 <= qpos <= 20.0) and qpos != 0.0:
|
||||
# Skip queries already ranking very well or very poorly (if pos present)
|
||||
continue
|
||||
pages_list = []
|
||||
for pk, m in pages_map.items():
|
||||
avg_pos = (m["position_sum"] / m["position_count"]) if m["position_count"] > 0 else 0.0
|
||||
pages_list.append({
|
||||
"page": pk,
|
||||
"clicks": round(m["clicks"], 0),
|
||||
"impressions": round(m["impressions"], 0),
|
||||
"ctr": round(m["ctr"], 2),
|
||||
"position": round(avg_pos, 2) if avg_pos > 0 else None
|
||||
})
|
||||
pages_list.sort(key=lambda x: x.get("clicks", 0), reverse=True)
|
||||
target_page = pages_list[0]["page"] if pages_list else None
|
||||
cannibalization.append({
|
||||
"query": qk,
|
||||
"total_clicks": int(round(total_clicks)),
|
||||
"recommended_target_page": target_page,
|
||||
"pages": pages_list[:3]
|
||||
})
|
||||
# Sort by impact
|
||||
cannibalization.sort(key=lambda item: item.get("total_clicks", 0), reverse=True)
|
||||
cannibalization = cannibalization[:10]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed computing cannibalization: {e}")
|
||||
|
||||
return {
|
||||
'connection_status': 'connected',
|
||||
@@ -224,7 +372,8 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
'avg_position': round(avg_position, 2),
|
||||
'total_queries': len(top_queries_source) if top_queries_source else 0,
|
||||
'top_queries': top_queries,
|
||||
'top_pages': top_pages
|
||||
'top_pages': top_pages,
|
||||
'cannibalization': cannibalization
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -256,3 +405,18 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting query from row: {e}")
|
||||
return 'Unknown'
|
||||
|
||||
def _extract_page_from_row(self, row: Dict[str, Any]) -> str:
|
||||
"""Extract page URL from GSC API row data"""
|
||||
try:
|
||||
keys = row.get('keys', [])
|
||||
if keys and len(keys) > 0:
|
||||
first_key = keys[0]
|
||||
if isinstance(first_key, dict):
|
||||
return first_key.get('keys', [''])[0]
|
||||
else:
|
||||
return str(first_key)
|
||||
return ''
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting page from row: {e}")
|
||||
return ''
|
||||
|
||||
@@ -21,7 +21,7 @@ class WixAnalyticsHandler(BaseAnalyticsHandler):
|
||||
super().__init__(PlatformType.WIX)
|
||||
self.wix_service = WixService()
|
||||
|
||||
async def get_analytics(self, user_id: str) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Wix analytics data using the Business Management API
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class WordPressAnalyticsHandler(BaseAnalyticsHandler):
|
||||
super().__init__(PlatformType.WORDPRESS)
|
||||
self.wordpress_service = WordPressOAuthService()
|
||||
|
||||
async def get_analytics(self, user_id: str) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get WordPress analytics data using WordPress.com REST API
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class PlatformAnalyticsService:
|
||||
self.summary_generator = AnalyticsSummaryGenerator()
|
||||
self.cache_manager = AnalyticsCacheManager()
|
||||
|
||||
async def get_comprehensive_analytics(self, user_id: str, platforms: List[str] = None) -> Dict[str, AnalyticsData]:
|
||||
async def get_comprehensive_analytics(self, user_id: str, platforms: List[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, AnalyticsData]:
|
||||
"""
|
||||
Get analytics data from all connected platforms
|
||||
|
||||
@@ -93,9 +93,18 @@ class PlatformAnalyticsService:
|
||||
|
||||
if handler:
|
||||
if platform_type == PlatformType.GSC or platform_type == PlatformType.BING:
|
||||
analytics_data[platform_name] = await handler.get_analytics(user_id, target_url=target_url)
|
||||
analytics_data[platform_name] = await handler.get_analytics(
|
||||
user_id,
|
||||
target_url=target_url,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
else:
|
||||
analytics_data[platform_name] = await handler.get_analytics(user_id)
|
||||
analytics_data[platform_name] = await handler.get_analytics(
|
||||
user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unknown platform: {platform_name}")
|
||||
analytics_data[platform_name] = self._create_error_response(platform_name, f"Unknown platform: {platform_name}")
|
||||
|
||||
@@ -237,7 +237,7 @@ class BingAnalyticsStorageService:
|
||||
Dict containing analytics summary
|
||||
"""
|
||||
try:
|
||||
db = self._get_db_session()
|
||||
db = self._get_db_session(user_id)
|
||||
|
||||
# Date range
|
||||
end_date = datetime.now()
|
||||
@@ -331,7 +331,7 @@ class BingAnalyticsStorageService:
|
||||
List of top queries with performance data
|
||||
"""
|
||||
try:
|
||||
db = self._get_db_session()
|
||||
db = self._get_db_session(user_id)
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.now()
|
||||
|
||||
@@ -241,6 +241,9 @@ class ExaResearchProvider(BaseProvider):
|
||||
for idx, result in enumerate(results):
|
||||
source_type = self._determine_source_type(result.url if hasattr(result, 'url') else '')
|
||||
|
||||
# Extract image if available (some Exa results include image URL)
|
||||
image_url = result.image if hasattr(result, 'image') else None
|
||||
|
||||
sources.append({
|
||||
'title': result.title if hasattr(result, 'title') else '',
|
||||
'url': result.url if hasattr(result, 'url') else '',
|
||||
@@ -251,17 +254,21 @@ class ExaResearchProvider(BaseProvider):
|
||||
'source_type': source_type,
|
||||
'content': result.text if hasattr(result, 'text') else '',
|
||||
'highlights': result.highlights if hasattr(result, 'highlights') else [],
|
||||
'summary': result.summary if hasattr(result, 'summary') else ''
|
||||
'summary': result.summary if hasattr(result, 'summary') else '',
|
||||
'image': image_url,
|
||||
'author': result.author if hasattr(result, 'author') else None
|
||||
})
|
||||
|
||||
return sources
|
||||
|
||||
def _get_excerpt(self, result):
|
||||
"""Extract excerpt from Exa result."""
|
||||
"""Extract excerpt from Exa result. Prefer highlights if available."""
|
||||
if hasattr(result, 'highlights') and result.highlights and len(result.highlights) > 0:
|
||||
return result.highlights[0]
|
||||
if hasattr(result, 'summary') and result.summary:
|
||||
return result.summary
|
||||
if hasattr(result, 'text') and result.text:
|
||||
return result.text[:500]
|
||||
elif hasattr(result, 'summary') and result.summary:
|
||||
return result.summary
|
||||
return ''
|
||||
|
||||
def _determine_source_type(self, url):
|
||||
@@ -280,16 +287,30 @@ class ExaResearchProvider(BaseProvider):
|
||||
return 'web'
|
||||
|
||||
def _aggregate_content(self, results):
|
||||
"""Aggregate content from Exa results for LLM analysis."""
|
||||
"""Aggregate content from Exa results for LLM analysis, including highlights."""
|
||||
content_parts = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
part = [f"Source {idx + 1}: {result.title if hasattr(result, 'title') else 'Untitled'}"]
|
||||
if hasattr(result, 'url') and result.url:
|
||||
part.append(f"URL: {result.url}")
|
||||
|
||||
# Add highlights if available (most valuable for LLM)
|
||||
if hasattr(result, 'highlights') and result.highlights:
|
||||
highlights_text = "\n".join([f"- {h}" for h in result.highlights])
|
||||
part.append(f"Key Highlights:\n{highlights_text}")
|
||||
|
||||
# Add summary if available
|
||||
if hasattr(result, 'summary') and result.summary:
|
||||
content_parts.append(f"Source {idx + 1}: {result.summary}")
|
||||
part.append(f"Summary: {result.summary}")
|
||||
|
||||
# Add text snippet if highlights/summary insufficient
|
||||
elif hasattr(result, 'text') and result.text:
|
||||
content_parts.append(f"Source {idx + 1}: {result.text[:1000]}")
|
||||
part.append(f"Excerpt: {result.text[:1000]}")
|
||||
|
||||
content_parts.append("\n".join(part))
|
||||
|
||||
return "\n\n".join(content_parts)
|
||||
return "\n\n---\n\n".join(content_parts)
|
||||
|
||||
def track_exa_usage(self, user_id: str, cost: float):
|
||||
"""Track Exa API usage after successful call."""
|
||||
|
||||
@@ -159,14 +159,10 @@ class StyleDetectionLogic:
|
||||
}}
|
||||
"""
|
||||
|
||||
# Call the LLM for analysis
|
||||
logger.debug("[StyleDetectionLogic.analyze_content_style] Sending enhanced prompt to LLM")
|
||||
try:
|
||||
analysis_text = llm_text_gen(prompt, user_id=user_id)
|
||||
|
||||
# Clean and parse the response
|
||||
cleaned_json = self._clean_json_response(analysis_text)
|
||||
|
||||
analysis_results = json.loads(cleaned_json)
|
||||
logger.info("[StyleDetectionLogic.analyze_content_style] Successfully parsed enhanced analysis results")
|
||||
return {
|
||||
@@ -179,7 +175,7 @@ class StyleDetectionLogic:
|
||||
return {
|
||||
'success': True,
|
||||
'analysis': fallback_results,
|
||||
'warning': 'AI analysis failed, used fallback detection'
|
||||
'warning': f'AI analysis failed ({str(e)}), used fallback detection'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -145,6 +145,7 @@ def init_user_database(user_id: str):
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
UserBusinessInfoBase.metadata.create_all(bind=engine)
|
||||
ContentAssetBase.metadata.create_all(bind=engine)
|
||||
BingAnalyticsBase.metadata.create_all(bind=engine)
|
||||
|
||||
# Initialize default data for new databases
|
||||
try:
|
||||
|
||||
@@ -343,7 +343,11 @@ class GSCService:
|
||||
if not credentials:
|
||||
raise ValueError("No valid credentials found")
|
||||
|
||||
service = build('searchconsole', 'v1', credentials=credentials)
|
||||
# Disable discovery file cache (suppress oauth2client file_cache warnings) with safe fallback
|
||||
try:
|
||||
service = build('searchconsole', 'v1', credentials=credentials, cache_discovery=False)
|
||||
except TypeError:
|
||||
service = build('searchconsole', 'v1', credentials=credentials)
|
||||
logger.info(f"Authenticated GSC service created for user: {user_id}")
|
||||
return service
|
||||
|
||||
@@ -395,9 +399,12 @@ class GSCService:
|
||||
# Check cache first
|
||||
cache_key = f"{user_id}_{site_url}_{start_date}_{end_date}"
|
||||
cached_data = self._get_cached_data(user_id, site_url, 'analytics', cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Returning cached analytics data for user: {user_id}")
|
||||
return cached_data
|
||||
if cached_data and isinstance(cached_data, dict):
|
||||
has_pages = 'page_data' in cached_data and isinstance(cached_data.get('page_data'), dict)
|
||||
has_queries = 'query_data' in cached_data and isinstance(cached_data.get('query_data'), dict)
|
||||
if has_pages and has_queries:
|
||||
logger.info(f"Returning cached analytics data for user: {user_id} (includes page_data)")
|
||||
return cached_data
|
||||
|
||||
try:
|
||||
service = self.get_authenticated_service(user_id)
|
||||
@@ -476,8 +483,54 @@ class GSCService:
|
||||
).execute()
|
||||
|
||||
logger.info(f"GSC Query-level response for user {user_id}: {query_response}")
|
||||
|
||||
# Combine overall metrics with query-level data
|
||||
|
||||
# Step 4: Get page-level data for top pages insights
|
||||
page_request = {
|
||||
'startDate': start_date,
|
||||
'endDate': end_date,
|
||||
'dimensions': ['page'], # Get page-level data
|
||||
'rowLimit': 1000
|
||||
}
|
||||
logger.info(f"GSC Page-level request for user {user_id}: {page_request}")
|
||||
page_rows = []
|
||||
page_row_count = 0
|
||||
try:
|
||||
page_response = service.searchanalytics().query(
|
||||
siteUrl=site_url,
|
||||
body=page_request
|
||||
).execute()
|
||||
logger.info(f"GSC Page-level response for user {user_id}: {page_response}")
|
||||
page_rows = page_response.get('rows', [])
|
||||
page_row_count = page_response.get('rowCount', 0)
|
||||
except Exception as page_error:
|
||||
logger.warning(f"GSC Page-level request failed for user {user_id}: {page_error}")
|
||||
page_rows = []
|
||||
page_row_count = 0
|
||||
|
||||
# Step 5: Get query+page combined data for mapping queries to pages
|
||||
qp_rows = []
|
||||
qp_row_count = 0
|
||||
try:
|
||||
qp_request = {
|
||||
'startDate': start_date,
|
||||
'endDate': end_date,
|
||||
'dimensions': ['query', 'page'],
|
||||
'rowLimit': 1000
|
||||
}
|
||||
logger.info(f"GSC Query+Page request for user {user_id}: {qp_request}")
|
||||
qp_response = service.searchanalytics().query(
|
||||
siteUrl=site_url,
|
||||
body=qp_request
|
||||
).execute()
|
||||
logger.info(f"GSC Query+Page response for user {user_id}: {qp_response}")
|
||||
qp_rows = qp_response.get('rows', [])
|
||||
qp_row_count = qp_response.get('rowCount', 0)
|
||||
except Exception as qp_error:
|
||||
logger.warning(f"GSC Query+Page request failed for user {user_id}: {qp_error}")
|
||||
qp_rows = []
|
||||
qp_row_count = 0
|
||||
|
||||
# Combine overall, query, page and query+page data
|
||||
analytics_data = {
|
||||
'overall_metrics': {
|
||||
'rows': response.get('rows', []),
|
||||
@@ -487,6 +540,14 @@ class GSCService:
|
||||
'rows': query_response.get('rows', []),
|
||||
'rowCount': query_response.get('rowCount', 0)
|
||||
},
|
||||
'page_data': {
|
||||
'rows': page_rows,
|
||||
'rowCount': page_row_count
|
||||
},
|
||||
'query_page_data': {
|
||||
'rows': qp_rows,
|
||||
'rowCount': qp_row_count
|
||||
},
|
||||
'verification_data': {
|
||||
'rows': verification_rows,
|
||||
'rowCount': len(verification_rows)
|
||||
@@ -510,6 +571,8 @@ class GSCService:
|
||||
'rowCount': response.get('rowCount', 0)
|
||||
},
|
||||
'query_data': {'rows': [], 'rowCount': 0},
|
||||
'page_data': {'rows': [], 'rowCount': 0},
|
||||
'query_page_data': {'rows': [], 'rowCount': 0},
|
||||
'verification_data': {
|
||||
'rows': verification_rows,
|
||||
'rowCount': len(verification_rows)
|
||||
|
||||
@@ -76,7 +76,8 @@ class ALwrityAgentOrchestrator:
|
||||
try:
|
||||
# Initialize shared LLM
|
||||
if TXTAI_AVAILABLE:
|
||||
self.llm = LLM(self.config.shared_llm)
|
||||
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
|
||||
self.llm = LLM(self.config.shared_llm, task="text-generation")
|
||||
else:
|
||||
self.llm = None
|
||||
|
||||
|
||||
@@ -181,7 +181,8 @@ class BaseALwrityAgent(ABC):
|
||||
try:
|
||||
if not self.llm:
|
||||
# Create new LLM if not provided
|
||||
raw_llm = LLM(model_name)
|
||||
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
|
||||
raw_llm = LLM(model_name, task="text-generation")
|
||||
# Wrap it
|
||||
self.llm = TrackingLLMWrapper(raw_llm, self.user_id, self.model_name)
|
||||
|
||||
@@ -906,6 +907,11 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
|
||||
"name": "task_delegator",
|
||||
"description": "Delegates specific tasks to specialized agents (content, competitor, seo, social)",
|
||||
"target": self._delegate_task_tool
|
||||
},
|
||||
{
|
||||
"name": "kickoff_gsc_first_pass",
|
||||
"description": "Kicks off first-pass execution by invoking SEO/Content default GSC plans",
|
||||
"target": self._kickoff_gsc_first_pass_tool
|
||||
}
|
||||
],
|
||||
max_iterations=15,
|
||||
@@ -924,7 +930,9 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
|
||||
Do not just plan; EXECUTE by delegating.
|
||||
|
||||
Always prioritize user goals and maintain safety constraints.
|
||||
Coordinate multi-agent responses to market changes effectively."""
|
||||
Coordinate multi-agent responses to market changes effectively.
|
||||
|
||||
First, call 'kickoff_gsc_first_pass' to ground the plan on live GSC signals."""
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1033,6 +1041,37 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _kickoff_gsc_first_pass_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Invoke SEO and Content agents' default GSC plans and combine results"""
|
||||
try:
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
payload = {"start_date": start_date, "end_date": end_date}
|
||||
results = {}
|
||||
combined_actions = []
|
||||
|
||||
seo = self.sub_agents.get("seo")
|
||||
if seo and hasattr(seo, "_default_seo_gsc_plan_tool"):
|
||||
plan = await seo._default_seo_gsc_plan_tool(payload)
|
||||
results["seo"] = plan
|
||||
combined_actions.extend(plan.get("actions", []) if isinstance(plan, dict) else [])
|
||||
|
||||
content = self.sub_agents.get("content")
|
||||
if content and hasattr(content, "_default_content_gsc_plan_tool"):
|
||||
plan = await content._default_content_gsc_plan_tool(payload)
|
||||
results["content"] = plan
|
||||
combined_actions.extend(plan.get("actions", []) if isinstance(plan, dict) else [])
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"invoked": list(results.keys()),
|
||||
"results": results,
|
||||
"combined_actions": combined_actions,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
async def _strategy_synthesizer_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Tool for synthesizing strategies"""
|
||||
return {
|
||||
|
||||
@@ -13,6 +13,7 @@ from loguru import logger
|
||||
from ..txtai_service import TxtaiIntelligenceService
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, AgentAction
|
||||
from services.seo_tools.content_strategy_service import ContentStrategyService
|
||||
from services.analytics import PlatformAnalyticsService
|
||||
from services.intelligence.sif_agents import SharedLLMWrapper, LocalLLMWrapper
|
||||
try:
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
@@ -888,7 +889,37 @@ class ContentStrategyAgent(BaseALwrityAgent):
|
||||
"name": "sitemap_analyzer",
|
||||
"description": "Analyzes website structure and publishing velocity via sitemap",
|
||||
"target": self._sitemap_analyzer_tool
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_queries",
|
||||
"description": "Returns low-CTR queries with evidence from cached GSC metrics",
|
||||
"target": self._cs_gsc_low_ctr_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_striking_distance_queries",
|
||||
"description": "Returns striking-distance queries (positions ~8–20) with evidence",
|
||||
"target": self._cs_gsc_striking_distance_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_declining_queries",
|
||||
"description": "Returns period-over-period declining queries with evidence",
|
||||
"target": self._cs_gsc_declining_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_pages",
|
||||
"description": "Returns low-CTR pages with top contributing queries",
|
||||
"target": self._cs_gsc_low_ctr_pages_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_cannibalization_candidates",
|
||||
"description": "Returns query→multiple-pages cannibalization candidates with target recommendation",
|
||||
"target": self._cs_gsc_cannibalization_candidates_tool
|
||||
},
|
||||
{
|
||||
"name": "default_content_gsc_plan",
|
||||
"description": "Runs a default first-pass plan using GSC signals (titles/meta, consolidation, refreshes)",
|
||||
"target": self._default_content_gsc_plan_tool
|
||||
},
|
||||
],
|
||||
max_iterations=8,
|
||||
system=self.get_effective_system_prompt(f"""You are the Content Strategy Agent for ALwrity user {self.user_id}.
|
||||
@@ -903,12 +934,153 @@ class ContentStrategyAgent(BaseALwrityAgent):
|
||||
- Performance-based content improvements
|
||||
|
||||
Use semantic analysis (SIF) and sitemap analysis to understand content context.
|
||||
Always prioritize user goals and maintain brand consistency."""
|
||||
Always prioritize user goals and maintain brand consistency.
|
||||
|
||||
In your first pass, call 'default_content_gsc_plan' to ground your actions on live GSC signals."""
|
||||
)
|
||||
)
|
||||
|
||||
# Tool Implementations
|
||||
|
||||
async def _cs_fetch_gsc_analytics(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, Any]:
|
||||
svc = PlatformAnalyticsService()
|
||||
data = await svc.get_comprehensive_analytics(self.user_id, platforms=["gsc"], start_date=start_date, end_date=end_date)
|
||||
gsc = data.get("gsc")
|
||||
if not gsc or gsc.status != "success":
|
||||
err = getattr(gsc, "error_message", None) if gsc else "No data"
|
||||
raise RuntimeError(f"GSC analytics unavailable: {err}")
|
||||
return {"metrics": gsc.metrics, "date_range": gsc.date_range}
|
||||
|
||||
async def _cs_gsc_low_ctr_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 100)); min_clicks = int(context.get("min_clicks", 10)); ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{"query": r.get("query"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position")}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("clicks", 0) >= min_clicks and float(r.get("ctr", 0.0)) < ctr_threshold)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs low_ctr_queries failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_striking_distance_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 100)); start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{"query": r.get("query"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position")}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("position") is not None and 8.0 <= float(r.get("position")) <= 20.0)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("position") if x.get("position") is not None else 999, -x.get("impressions", 0)))
|
||||
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs striking_distance failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_declining_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_prev_clicks = int(context.get("min_prev_clicks", 10)); min_drop_pct = float(context.get("min_drop_pct", 30.0))
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
curr = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
curr_range = curr["date_range"]; s = curr_range.get("start"); e = curr_range.get("end")
|
||||
from datetime import datetime, timedelta; fmt = "%Y-%m-%d"
|
||||
sd = datetime.strptime(s, fmt) if s else datetime.utcnow() - timedelta(days=30); ed = datetime.strptime(e, fmt) if e else datetime.utcnow()
|
||||
days = max((ed - sd).days + 1, 1); prev_end = sd - timedelta(days=1); prev_start = prev_end - timedelta(days=days - 1)
|
||||
prev = await self._cs_fetch_gsc_analytics(prev_start.strftime(fmt), prev_end.strftime(fmt))
|
||||
curr_queries = {r.get("query"): r for r in (curr["metrics"].get("top_queries", []) or [])}
|
||||
prev_queries = {r.get("query"): r for r in (prev["metrics"].get("top_queries", []) or [])}
|
||||
items = []
|
||||
for q, prev_row in prev_queries.items():
|
||||
curr_row = curr_queries.get(q);
|
||||
if not curr_row: continue
|
||||
prev_clicks = int(prev_row.get("clicks", 0) or 0); curr_clicks = int(curr_row.get("clicks", 0) or 0)
|
||||
if prev_clicks >= min_prev_clicks and curr_clicks < prev_clicks:
|
||||
drop_pct = ((prev_clicks - curr_clicks) / prev_clicks) * 100.0
|
||||
if drop_pct >= min_drop_pct:
|
||||
items.append({"query": q, "prev_clicks": prev_clicks, "curr_clicks": curr_clicks, "drop_pct": round(drop_pct, 2)})
|
||||
items.sort(key=lambda x: (x.get("drop_pct", 0), x.get("prev_clicks", 0)), reverse=True)
|
||||
return {"items": items[:limit], "range": curr_range, "previous_range": prev["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs declining_queries failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_low_ctr_pages_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 200)); ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
tp = result["metrics"].get("top_pages", []) or []
|
||||
items = []
|
||||
for r in tp:
|
||||
if (r.get("impressions", 0) >= min_impr and float(r.get("ctr", 0.0)) < ctr_threshold):
|
||||
items.append({"page": r.get("page"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position"), "evidence_queries": r.get("queries", [])[:5]})
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs low_ctr_pages failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_cannibalization_candidates_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
candidates = result["metrics"].get("cannibalization", []) or []
|
||||
return {"items": candidates[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs cannibalization_candidates failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _default_content_gsc_plan_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
low_ctr_pages = await self._cs_gsc_low_ctr_pages_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
cannibals = await self._cs_gsc_cannibalization_candidates_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
striking = await self._cs_gsc_striking_distance_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
declining = await self._cs_gsc_declining_queries_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
|
||||
actions = []
|
||||
for p in low_ctr_pages.get("items", []):
|
||||
actions.append({
|
||||
"type": "improve_titles_meta",
|
||||
"target": p.get("page"),
|
||||
"reason": f"Low CTR {p.get('ctr')}% with {p.get('impressions')} impressions",
|
||||
"evidence": p.get("evidence_queries", [])
|
||||
})
|
||||
for c in cannibals.get("items", []):
|
||||
actions.append({
|
||||
"type": "consolidate/internal_link",
|
||||
"target": c.get("recommended_target_page"),
|
||||
"reason": f"Cannibalization on query '{c.get('query')}'",
|
||||
"pages": c.get("pages", [])
|
||||
})
|
||||
for q in striking.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"reason": f"Striking distance at position {q.get('position')} with {q.get('impressions')} impressions"
|
||||
})
|
||||
for q in declining.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"reason": f"Clicks decline {q.get('prev_clicks')}→{q.get('curr_clicks')} ({q.get('drop_pct')}%)"
|
||||
})
|
||||
|
||||
return {
|
||||
"plan_name": "Default Content Plan from GSC",
|
||||
"range": {"current": {"start": start_date, "end": end_date}},
|
||||
"actions": actions,
|
||||
"source": "gsc_cache",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"default_content_gsc_plan failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _sitemap_analyzer_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sitemap analysis tool using ContentStrategyService"""
|
||||
website_url = context.get('website_url')
|
||||
@@ -1324,7 +1496,37 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
"name": "query_seo_knowledge_base",
|
||||
"description": "Queries the SIF knowledge base for SEO dashboard data, GSC/Bing metrics, and semantic insights",
|
||||
"target": self._query_seo_knowledge_base_tool
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_queries",
|
||||
"description": "Returns low-CTR queries with evidence from cached GSC metrics",
|
||||
"target": self._gsc_low_ctr_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_striking_distance_queries",
|
||||
"description": "Returns striking-distance queries (positions ~8–20) with evidence",
|
||||
"target": self._gsc_striking_distance_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_declining_queries",
|
||||
"description": "Returns period-over-period declining queries with evidence",
|
||||
"target": self._gsc_declining_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_pages",
|
||||
"description": "Returns low-CTR pages with top contributing queries",
|
||||
"target": self._gsc_low_ctr_pages_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_cannibalization_candidates",
|
||||
"description": "Returns query→multiple-pages cannibalization candidates with target recommendation",
|
||||
"target": self._gsc_cannibalization_candidates_tool
|
||||
},
|
||||
{
|
||||
"name": "default_seo_gsc_plan",
|
||||
"description": "Runs a default first-pass SEO plan using GSC signals (titles/meta, consolidation, refreshes)",
|
||||
"target": self._default_seo_gsc_plan_tool
|
||||
},
|
||||
],
|
||||
max_iterations=15,
|
||||
system=self.get_effective_system_prompt(f"""You are the SEO Optimization Agent for ALwrity user {self.user_id}.
|
||||
@@ -1340,6 +1542,7 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
- Deep semantic search of SEO data (GSC, Bing, Audits)
|
||||
|
||||
Focus on high-impact, low-effort optimizations first.
|
||||
In your first pass, call 'default_seo_gsc_plan' to ground your actions on live GSC signals.
|
||||
Always maintain SEO best practices and user experience."""
|
||||
)
|
||||
)
|
||||
@@ -1666,6 +1869,223 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# GSC Insights Tools (Option B)
|
||||
async def _fetch_gsc_analytics(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, Any]:
|
||||
svc = PlatformAnalyticsService()
|
||||
data = await svc.get_comprehensive_analytics(self.user_id, platforms=["gsc"], start_date=start_date, end_date=end_date)
|
||||
gsc = data.get("gsc")
|
||||
if not gsc or gsc.status != "success":
|
||||
err = getattr(gsc, "error_message", None) if gsc else "No data"
|
||||
raise RuntimeError(f"GSC analytics unavailable: {err}")
|
||||
return {
|
||||
"metrics": gsc.metrics,
|
||||
"date_range": gsc.date_range
|
||||
}
|
||||
|
||||
async def _gsc_low_ctr_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_impr = int(context.get("min_impressions", 100))
|
||||
min_clicks = int(context.get("min_clicks", 10))
|
||||
ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{
|
||||
"query": r.get("query"),
|
||||
"clicks": r.get("clicks", 0),
|
||||
"impressions": r.get("impressions", 0),
|
||||
"ctr": r.get("ctr", 0.0),
|
||||
"position": r.get("position")
|
||||
}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("clicks", 0) >= min_clicks and float(r.get("ctr", 0.0)) < ctr_threshold)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"low_ctr_queries tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_striking_distance_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_impr = int(context.get("min_impressions", 100))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{
|
||||
"query": r.get("query"),
|
||||
"clicks": r.get("clicks", 0),
|
||||
"impressions": r.get("impressions", 0),
|
||||
"ctr": r.get("ctr", 0.0),
|
||||
"position": r.get("position")
|
||||
}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("position") is not None and 8.0 <= float(r.get("position")) <= 20.0)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("position") if x.get("position") is not None else 999, -x.get("impressions", 0)))
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"striking_distance tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_declining_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_prev_clicks = int(context.get("min_prev_clicks", 10))
|
||||
min_drop_pct = float(context.get("min_drop_pct", 30.0))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
curr = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
curr_range = curr["date_range"]
|
||||
s = curr_range.get("start")
|
||||
e = curr_range.get("end")
|
||||
from datetime import datetime, timedelta
|
||||
fmt = "%Y-%m-%d"
|
||||
sd = datetime.strptime(s, fmt) if s else datetime.utcnow() - timedelta(days=30)
|
||||
ed = datetime.strptime(e, fmt) if e else datetime.utcnow()
|
||||
days = max((ed - sd).days + 1, 1)
|
||||
prev_end = sd - timedelta(days=1)
|
||||
prev_start = prev_end - timedelta(days=days - 1)
|
||||
prev = await self._fetch_gsc_analytics(prev_start.strftime(fmt), prev_end.strftime(fmt))
|
||||
curr_queries = {r.get("query"): r for r in (curr["metrics"].get("top_queries", []) or [])}
|
||||
prev_queries = {r.get("query"): r for r in (prev["metrics"].get("top_queries", []) or [])}
|
||||
items = []
|
||||
for q, prev_row in prev_queries.items():
|
||||
curr_row = curr_queries.get(q)
|
||||
if not curr_row:
|
||||
continue
|
||||
prev_clicks = int(prev_row.get("clicks", 0) or 0)
|
||||
curr_clicks = int(curr_row.get("clicks", 0) or 0)
|
||||
if prev_clicks >= min_prev_clicks and curr_clicks < prev_clicks:
|
||||
drop_pct = ((prev_clicks - curr_clicks) / prev_clicks) * 100.0
|
||||
if drop_pct >= min_drop_pct:
|
||||
items.append({
|
||||
"query": q,
|
||||
"prev_clicks": prev_clicks,
|
||||
"curr_clicks": curr_clicks,
|
||||
"drop_pct": round(drop_pct, 2)
|
||||
})
|
||||
items.sort(key=lambda x: (x.get("drop_pct", 0), x.get("prev_clicks", 0)), reverse=True)
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": curr_range,
|
||||
"previous_range": prev["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"declining_queries tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_low_ctr_pages_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_impr = int(context.get("min_impressions", 200))
|
||||
ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
tp = result["metrics"].get("top_pages", []) or []
|
||||
items = []
|
||||
for r in tp:
|
||||
if (r.get("impressions", 0) >= min_impr and float(r.get("ctr", 0.0)) < ctr_threshold):
|
||||
items.append({
|
||||
"page": r.get("page"),
|
||||
"clicks": r.get("clicks", 0),
|
||||
"impressions": r.get("impressions", 0),
|
||||
"ctr": r.get("ctr", 0.0),
|
||||
"position": r.get("position"),
|
||||
"evidence_queries": r.get("queries", [])[:5]
|
||||
})
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"low_ctr_pages tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_cannibalization_candidates_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
candidates = result["metrics"].get("cannibalization", []) or []
|
||||
return {
|
||||
"items": candidates[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"cannibalization_candidates tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _default_seo_gsc_plan_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
low_ctr_pages = await self._gsc_low_ctr_pages_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
cannibals = await self._gsc_cannibalization_candidates_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
striking = await self._gsc_striking_distance_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
declining = await self._gsc_declining_queries_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
|
||||
actions = []
|
||||
for p in low_ctr_pages.get("items", []):
|
||||
actions.append({
|
||||
"type": "update_titles_meta",
|
||||
"target_page": p.get("page"),
|
||||
"justification": f"Low CTR {p.get('ctr')}% with {p.get('impressions')} impressions",
|
||||
"evidence": p.get("evidence_queries", [])
|
||||
})
|
||||
for c in cannibals.get("items", []):
|
||||
actions.append({
|
||||
"type": "consolidate/internal_link",
|
||||
"target_page": c.get("recommended_target_page"),
|
||||
"justification": f"Cannibalization on query '{c.get('query')}'",
|
||||
"pages": c.get("pages", [])
|
||||
})
|
||||
for q in striking.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"justification": f"Striking distance at position {q.get('position')} with {q.get('impressions')} impressions"
|
||||
})
|
||||
for q in declining.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"justification": f"Clicks decline {q.get('prev_clicks')}→{q.get('curr_clicks')} ({q.get('drop_pct')}%)"
|
||||
})
|
||||
|
||||
return {
|
||||
"plan_name": "Default SEO Plan from GSC",
|
||||
"range": {"current": {"start": start_date, "end": end_date}},
|
||||
"actions": actions,
|
||||
"source": "gsc_cache",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"default_seo_gsc_plan failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
class SocialAmplificationAgent(BaseALwrityAgent):
|
||||
"""
|
||||
|
||||
@@ -14,9 +14,9 @@ from .txtai_service import TxtaiIntelligenceService, TXTAI_AVAILABLE
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
# Optional txtai imports
|
||||
# Optional txtai imports (align with core agent framework)
|
||||
try:
|
||||
from txtai.pipeline import Agent, LLM
|
||||
from txtai import Agent, LLM
|
||||
except ImportError:
|
||||
Agent = None
|
||||
LLM = None
|
||||
@@ -28,9 +28,13 @@ class SharedLLMWrapper:
|
||||
|
||||
def generate(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the shared LLM provider."""
|
||||
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
|
||||
# but we could map them if needed.
|
||||
return llm_text_gen(prompt, user_id=self.user_id)
|
||||
try:
|
||||
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
|
||||
# but we could map them if needed.
|
||||
return llm_text_gen(prompt, user_id=self.user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"SharedLLMWrapper failed to generate text: {e}")
|
||||
return f"[ERROR: Shared LLM generation failed for user {self.user_id}]"
|
||||
|
||||
def __call__(self, prompt: str, **kwargs) -> str:
|
||||
return self.generate(prompt, **kwargs)
|
||||
@@ -40,8 +44,9 @@ class LocalLLMWrapper:
|
||||
Lazily loads a local LLM via txtai.
|
||||
This prevents blocking server startup with heavy model loads.
|
||||
"""
|
||||
def __init__(self, model_path: str):
|
||||
def __init__(self, model_path: str, task: str = "text-generation"):
|
||||
self.model_path = model_path
|
||||
self.task = task
|
||||
self._llm = None
|
||||
|
||||
@property
|
||||
@@ -49,8 +54,9 @@ class LocalLLMWrapper:
|
||||
if self._llm is None:
|
||||
if LLM is None:
|
||||
raise ImportError("txtai.pipeline.LLM is not available")
|
||||
logger.info(f"Loading local LLM: {self.model_path}")
|
||||
self._llm = LLM(path=self.model_path)
|
||||
logger.info(f"Loading local LLM: {self.model_path} with task: {self.task}")
|
||||
# Explicitly set task to avoid 'text2text-generation' default failures
|
||||
self._llm = LLM(path=self.model_path, task=self.task)
|
||||
return self._llm
|
||||
|
||||
def __call__(self, prompt: str, **kwargs) -> str:
|
||||
@@ -67,11 +73,12 @@ class SIFBaseAgent(BaseALwrityAgent):
|
||||
|
||||
# 2. Local LLM for internal agent work (default for SIF agents)
|
||||
if llm is None:
|
||||
if TXTAI_AVAILABLE:
|
||||
# Use Lazy Local LLM
|
||||
llm = LocalLLMWrapper(model_name)
|
||||
if TXTAI_AVAILABLE and LLM is not None:
|
||||
# Use Lazy Local LLM when txtai LLM is available
|
||||
# Hardening: Specify 'text-generation' task to avoid text2text defaults
|
||||
llm = LocalLLMWrapper(model_name, task="text-generation")
|
||||
else:
|
||||
# Fallback to Shared if txtai not available
|
||||
# Fallback to Shared if txtai or LLM is not available
|
||||
llm = self.shared_llm
|
||||
|
||||
super().__init__(user_id, agent_type, model_name, llm)
|
||||
@@ -85,14 +92,18 @@ class SIFBaseAgent(BaseALwrityAgent):
|
||||
|
||||
def _create_txtai_agent(self):
|
||||
"""
|
||||
SIF agents use the intelligence service directly, but we can expose
|
||||
capabilities via a standard agent interface if needed.
|
||||
SIF agents primarily use the intelligence service directly, but we can expose
|
||||
capabilities via a standard agent interface if available.
|
||||
"""
|
||||
if not TXTAI_AVAILABLE:
|
||||
return None
|
||||
|
||||
# Return a simple agent that can use the LLM
|
||||
return Agent(llm=self.llm, tools=[])
|
||||
if not TXTAI_AVAILABLE or Agent is None:
|
||||
logger.debug(f"[{self.__class__.__name__}] txtai Agent not available, using fallback agent")
|
||||
return self._create_fallback_agent()
|
||||
|
||||
try:
|
||||
return Agent(llm=self.llm, tools=[])
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.__class__.__name__}] Failed to create txtai Agent: {e}")
|
||||
return self._create_fallback_agent()
|
||||
|
||||
class StrategyArchitectAgent(SIFBaseAgent):
|
||||
"""Agent for discovering content pillars and identifying strategic gaps."""
|
||||
|
||||
@@ -25,7 +25,18 @@ except ImportError:
|
||||
TXTAI_AVAILABLE = False
|
||||
|
||||
class TxtaiIntelligenceService:
|
||||
_instances = {}
|
||||
|
||||
def __new__(cls, user_id: str, *args, **kwargs):
|
||||
if user_id not in cls._instances:
|
||||
cls._instances[user_id] = super(TxtaiIntelligenceService, cls).__new__(cls)
|
||||
return cls._instances[user_id]
|
||||
|
||||
def __init__(self, user_id: str, model_path: Optional[str] = None, enable_caching: bool = True):
|
||||
# Singleton: prevent re-initialization if already initialized
|
||||
if getattr(self, "_singleton_initialized", False):
|
||||
return
|
||||
|
||||
self.user_id = user_id
|
||||
self.model_path = model_path or "sentence-transformers/all-MiniLM-L6-v2"
|
||||
self.index_path = f"workspace/workspace_{user_id}/indices/txtai"
|
||||
@@ -33,6 +44,11 @@ class TxtaiIntelligenceService:
|
||||
self._initialized = False
|
||||
self.enable_caching = enable_caching
|
||||
self.cache_manager = semantic_cache_manager if enable_caching else None
|
||||
self._backend = "faiss" # Default backend
|
||||
|
||||
# Mark as initialized for singleton pattern
|
||||
self._singleton_initialized = True
|
||||
|
||||
# Lazy initialization - do not initialize embeddings on startup
|
||||
# self._initialize_embeddings()
|
||||
|
||||
@@ -52,17 +68,26 @@ class TxtaiIntelligenceService:
|
||||
logger.debug(f"Model path: {self.model_path}")
|
||||
logger.debug(f"Index path: {self.index_path}")
|
||||
|
||||
# Close existing embeddings if any to release file locks
|
||||
if self.embeddings:
|
||||
try:
|
||||
if hasattr(self.embeddings, 'close'):
|
||||
self.embeddings.close()
|
||||
self.embeddings = None
|
||||
except Exception as close_err:
|
||||
logger.warning(f"Error closing existing embeddings: {close_err}")
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(self.index_path), exist_ok=True)
|
||||
logger.debug(f"Created index directory: {os.path.dirname(self.index_path)}")
|
||||
|
||||
# Initialize embeddings with optimal configuration for ALwrity use case
|
||||
# Hardening: Disabling quantization by default as it causes 'IndexIDMap' attribute errors with small indices on Windows
|
||||
self.embeddings = Embeddings({
|
||||
"path": self.model_path,
|
||||
"content": True, # Enable content storage for retrieval
|
||||
"objects": True, # Enable object storage for metadata
|
||||
"backend": "faiss", # Use Faiss for efficient similarity search
|
||||
"quantize": True, # Enable quantization for memory efficiency
|
||||
"backend": self._backend, # Use Faiss for efficient similarity search
|
||||
"batch": 32, # Batch size for processing
|
||||
"gpu": False, # Force CPU usage for compatibility
|
||||
"limit": 1000 # Maximum number of results for queries
|
||||
@@ -76,7 +101,12 @@ class TxtaiIntelligenceService:
|
||||
try:
|
||||
self.embeddings.load(self.index_path)
|
||||
logger.info(f"Successfully loaded existing txtai index for user {self.user_id}")
|
||||
logger.debug(f"Index contains {len(self.embeddings)} items")
|
||||
# Try to log count, handle if not supported
|
||||
try:
|
||||
count = self.embeddings.count() if hasattr(self.embeddings, 'count') else "unknown"
|
||||
logger.debug(f"Index contains {count} items")
|
||||
except:
|
||||
logger.debug("Index loaded (count unavailable)")
|
||||
except Exception as load_error:
|
||||
logger.warning(f"Failed to load existing index: {load_error}. Creating new index.")
|
||||
# Reset embeddings to create new index
|
||||
@@ -84,8 +114,7 @@ class TxtaiIntelligenceService:
|
||||
"path": self.model_path,
|
||||
"content": True,
|
||||
"objects": True,
|
||||
"backend": "faiss",
|
||||
"quantize": True,
|
||||
"backend": self._backend,
|
||||
"batch": 32,
|
||||
"gpu": False,
|
||||
"limit": 1000
|
||||
@@ -146,8 +175,15 @@ class TxtaiIntelligenceService:
|
||||
logger.error(f"Error indexing content for user {self.user_id}: {e}")
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
logger.error(f"Items count: {len(items) if items else 0}")
|
||||
if items and len(items) > 0:
|
||||
logger.error(f"Sample item structure: {type(items[0])}")
|
||||
|
||||
message = str(e)
|
||||
is_windows_lock_error = isinstance(e, PermissionError) or "WinError 32" in message
|
||||
if is_windows_lock_error:
|
||||
logger.warning(
|
||||
f"Txtai index save skipped for user {self.user_id} due to file lock. "
|
||||
f"The index will be retried on a future run."
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
async def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
@@ -172,7 +208,20 @@ class TxtaiIntelligenceService:
|
||||
logger.debug(f"Cache miss for search query: '{query}'")
|
||||
|
||||
logger.debug(f"Searching for query: '{query}' with limit: {limit}")
|
||||
results = self.embeddings.search(query, limit=limit)
|
||||
try:
|
||||
results = self.embeddings.search(query, limit=limit)
|
||||
except AttributeError as ae:
|
||||
if "nprobe" in str(ae):
|
||||
logger.error(f"Detected known txtai/faiss IndexIDMap/nprobe incompatibility for user {self.user_id}. Attempting re-init with numpy backend fallback...")
|
||||
# Switch to numpy backend which doesn't have this issue
|
||||
self._backend = "numpy"
|
||||
self._initialize_embeddings()
|
||||
if self.embeddings:
|
||||
results = self.embeddings.search(query, limit=limit)
|
||||
else:
|
||||
raise ae
|
||||
else:
|
||||
raise ae
|
||||
|
||||
# Cache the results if caching is enabled
|
||||
if self.enable_caching and self.cache_manager and results:
|
||||
@@ -216,7 +265,19 @@ class TxtaiIntelligenceService:
|
||||
logger.debug(f"Cache miss for similarity calculation")
|
||||
|
||||
logger.debug(f"Calculating similarity between texts: '{text1[:50]}...' and '{text2[:50]}...'")
|
||||
similarity = self.embeddings.similarity(text1, text2)
|
||||
try:
|
||||
similarity = self.embeddings.similarity(text1, text2)
|
||||
except AttributeError as ae:
|
||||
if "nprobe" in str(ae):
|
||||
logger.error(f"Detected IndexIDMap nprobe error in similarity for user {self.user_id}. Falling back to numpy backend...")
|
||||
self._backend = "numpy"
|
||||
self._initialize_embeddings()
|
||||
if self.embeddings:
|
||||
similarity = self.embeddings.similarity(text1, text2)
|
||||
else:
|
||||
raise ae
|
||||
else:
|
||||
raise ae
|
||||
|
||||
# Cache the similarity result
|
||||
if self.enable_caching and self.cache_manager:
|
||||
@@ -272,7 +333,19 @@ class TxtaiIntelligenceService:
|
||||
# Use graph-based clustering if available
|
||||
# Perform a search to get graph structure
|
||||
sample_query = "content marketing digital strategy"
|
||||
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
|
||||
try:
|
||||
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
|
||||
except AttributeError as ae:
|
||||
if "nprobe" in str(ae):
|
||||
logger.error(f"Detected IndexIDMap nprobe error in cluster for user {self.user_id}. Falling back to numpy backend...")
|
||||
self._backend = "numpy"
|
||||
self._initialize_embeddings()
|
||||
if self.embeddings:
|
||||
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
|
||||
else:
|
||||
raise ae
|
||||
else:
|
||||
raise ae
|
||||
|
||||
if not graph_results:
|
||||
logger.warning(f"No graph results for clustering user {self.user_id}")
|
||||
@@ -306,7 +379,7 @@ class TxtaiIntelligenceService:
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
return self._fallback_clustering(min_score)
|
||||
|
||||
def _fallback_clustering(self, min_score: float) -> List[List[int]]:
|
||||
async def _fallback_clustering(self, min_score: float) -> List[List[int]]:
|
||||
"""Fallback clustering method when graph clustering is not available."""
|
||||
logger.info(f"Using fallback clustering for user {self.user_id}")
|
||||
|
||||
@@ -318,7 +391,8 @@ class TxtaiIntelligenceService:
|
||||
all_clusters = []
|
||||
|
||||
for query in sample_queries:
|
||||
results = self.embeddings.search(query, limit=5)
|
||||
# Use our search wrapper for hardening
|
||||
results = await self.search(query, limit=5)
|
||||
if results and results[0].get("score", 0) >= min_score:
|
||||
# Create a cluster from similar results
|
||||
cluster = [i for i, result in enumerate(results) if result.get("score", 0) >= min_score]
|
||||
@@ -393,9 +467,13 @@ class TxtaiIntelligenceService:
|
||||
return {"status": "not_initialized", "user_id": self.user_id}
|
||||
|
||||
try:
|
||||
# Get count of indexed items - txtai doesn't have a direct len() method
|
||||
# We'll estimate based on available data or return a placeholder
|
||||
index_size = getattr(self.embeddings, 'count', 0) or "unknown"
|
||||
# Get count of indexed items
|
||||
index_size = "unknown"
|
||||
if hasattr(self.embeddings, 'count'):
|
||||
try:
|
||||
index_size = self.embeddings.count()
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"status": "active",
|
||||
@@ -410,5 +488,7 @@ class TxtaiIntelligenceService:
|
||||
return {"status": "error", "user_id": self.user_id, "error": str(e)}
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the service is properly initialized."""
|
||||
"""Check if the service is properly initialized, triggering lazy init if needed."""
|
||||
if not self._initialized:
|
||||
self._ensure_initialized()
|
||||
return self._initialized and self.embeddings is not None
|
||||
|
||||
@@ -369,6 +369,12 @@ def huggingface_structured_json_response(
|
||||
response_text = re.sub(r'```\n?', '', response_text)
|
||||
response_text = response_text.strip()
|
||||
|
||||
# Fix common markdown artefacts that break JSON, e.g. lines starting with **"key":
|
||||
# **"narration": "text"
|
||||
# becomes:
|
||||
# "narration": "text"
|
||||
response_text = re.sub(r'^\s*\*\*(?=\s*")', '', response_text, flags=re.MULTILINE)
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
logger.info("✅ Hugging Face structured JSON response parsed from text")
|
||||
|
||||
@@ -648,11 +648,13 @@ async def ai_video_generate(
|
||||
|
||||
# PRE-FLIGHT VALIDATION: Validate video generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Database session unavailable for user.")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
@@ -762,9 +764,11 @@ def track_video_usage(
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
return {}
|
||||
try:
|
||||
logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
|
||||
pricing_service_track = PricingService(db_track)
|
||||
|
||||
@@ -527,6 +527,11 @@ class APIKeyManager:
|
||||
def __init__(self):
|
||||
self.api_keys = {}
|
||||
self._load_from_env()
|
||||
|
||||
def load_api_keys(self):
|
||||
self.api_keys = {}
|
||||
self._load_from_env()
|
||||
return self.api_keys
|
||||
|
||||
def _load_from_env(self):
|
||||
"""Load API keys from environment variables."""
|
||||
|
||||
@@ -27,6 +27,12 @@ async def generate_facebook_persona_task(user_id: str):
|
||||
try:
|
||||
logger.info(f"Scheduled Facebook persona generation started for user {user_id}")
|
||||
|
||||
# Ensure we have a valid session factory before trying to get session
|
||||
from services.database import SessionLocal
|
||||
if not SessionLocal:
|
||||
logger.error("Database session factory not initialized")
|
||||
return
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error(f"Failed to get database session for Facebook persona generation (user: {user_id})")
|
||||
|
||||
177
backend/services/podcast_bible_service.py
Normal file
177
backend/services/podcast_bible_service.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
from services.product_marketing.personalization_service import PersonalizationService
|
||||
from models.podcast_bible_models import (
|
||||
PodcastBible,
|
||||
HostPersona,
|
||||
AudienceDNA,
|
||||
BrandDNA,
|
||||
VisualStyle,
|
||||
AudioEnvironment,
|
||||
ShowRules
|
||||
)
|
||||
|
||||
class PodcastBibleService:
|
||||
"""Service for generating and managing the Podcast Bible."""
|
||||
|
||||
def __init__(self):
|
||||
self.personalization_service = PersonalizationService()
|
||||
|
||||
def generate_bible(self, user_id: str, project_id: str) -> PodcastBible:
|
||||
"""Generate a Podcast Bible from onboarding data."""
|
||||
logger.info(f"Generating Podcast Bible for user {user_id}")
|
||||
|
||||
try:
|
||||
preferences = self.personalization_service.get_user_preferences(user_id)
|
||||
writing_style = preferences.get("writing_style", {})
|
||||
style_prefs = preferences.get("style_preferences", {})
|
||||
target_audience = preferences.get("target_audience", {})
|
||||
industry = preferences.get("industry", "General Business")
|
||||
|
||||
# 1. Map Host Persona
|
||||
host = HostPersona(
|
||||
name="Your AI Host",
|
||||
background=f"Expert in {industry}",
|
||||
expertise_level=writing_style.get("complexity", "Expert").capitalize(),
|
||||
personality_traits=[
|
||||
writing_style.get("tone", "Professional").capitalize(),
|
||||
writing_style.get("engagement_level", "Informative").capitalize()
|
||||
],
|
||||
vocal_style=writing_style.get("voice", "Authoritative").capitalize(),
|
||||
vocal_characteristics=["Clear", "Articulate", writing_style.get("voice", "Steady")],
|
||||
look=f"A professional individual dressed in business-casual attire, fitting the {industry} industry aesthetic.",
|
||||
catchphrases=[]
|
||||
)
|
||||
|
||||
# 2. Map Audience DNA
|
||||
audience = AudienceDNA(
|
||||
expertise_level=target_audience.get("expertise_level", "Intermediate").capitalize(),
|
||||
interests=target_audience.get("interests", ["Industry Trends", "Innovation"]),
|
||||
pain_points=target_audience.get("pain_points", ["Staying ahead of competition", "Efficiency"]),
|
||||
demographics=None
|
||||
)
|
||||
|
||||
# 3. Map Brand DNA
|
||||
brand = BrandDNA(
|
||||
industry=industry,
|
||||
tone=writing_style.get("tone", "Professional").capitalize(),
|
||||
communication_style=writing_style.get("engagement_level", "Informative").capitalize(),
|
||||
key_messages=preferences.get("brand_values", []),
|
||||
competitor_context=None
|
||||
)
|
||||
|
||||
# 4. Map Visual Style
|
||||
visual = VisualStyle(
|
||||
style_preset=style_prefs.get("aesthetic", "Professional Studio").capitalize(),
|
||||
environment=f"A modern {industry}-themed podcast studio with professional equipment.",
|
||||
lighting="Soft, warm studio lighting with subtle rim lights.",
|
||||
color_palette=preferences.get("brand_colors", ["#1e293b", "#3b82f6"]),
|
||||
camera_style="Dynamic mid-shots with occasional close-ups for emphasis."
|
||||
)
|
||||
|
||||
# 5. Map Audio Environment
|
||||
audio_env = AudioEnvironment(
|
||||
soundscape="Pristine studio environment with deep, warm acoustics.",
|
||||
music_mood=f"{writing_style.get('tone', 'Professional').capitalize()} & {writing_style.get('engagement_level', 'Upbeat').capitalize()}",
|
||||
sfx_style="Modern, clean interface-inspired sounds."
|
||||
)
|
||||
|
||||
# 6. Map Show Rules
|
||||
show_rules = ShowRules(
|
||||
intro_format=f"Start with a high-energy hook about the episode topic, followed by a warm welcome and an overview of the {industry} insights to be shared.",
|
||||
outro_format="Summarize the key takeaways, provide a clear call to action, and sign off with a professional closing.",
|
||||
interaction_tone=writing_style.get("engagement_level", "Conversational").capitalize(),
|
||||
constraints=[
|
||||
"Avoid overly technical jargon unless defined",
|
||||
"Keep segments concise and factual",
|
||||
f"Maintain a {writing_style.get('tone', 'Professional')} tone at all times"
|
||||
]
|
||||
)
|
||||
|
||||
bible = PodcastBible(
|
||||
project_id=project_id,
|
||||
host=host,
|
||||
audience=audience,
|
||||
brand=brand,
|
||||
visual_style=visual,
|
||||
audio_environment=audio_env,
|
||||
show_rules=show_rules
|
||||
)
|
||||
|
||||
logger.info(f"Podcast Bible generated successfully for project {project_id}")
|
||||
return bible
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Podcast Bible: {str(e)}")
|
||||
# Return a default bible if something goes wrong to ensure project creation doesn't fail
|
||||
return self._get_default_bible(project_id)
|
||||
|
||||
def _get_default_bible(self, project_id: str) -> PodcastBible:
|
||||
"""Return a sensible default Bible."""
|
||||
return PodcastBible(
|
||||
project_id=project_id,
|
||||
host=HostPersona(
|
||||
name="AI Host",
|
||||
background="Industry Professional",
|
||||
expertise_level="Expert",
|
||||
vocal_style="Authoritative",
|
||||
vocal_characteristics=["Deep", "Steady"]
|
||||
),
|
||||
audience=AudienceDNA(
|
||||
expertise_level="Intermediate",
|
||||
interests=["Industry Trends", "Technology"],
|
||||
pain_points=["Staying Competitive", "Operational Efficiency"]
|
||||
),
|
||||
brand=BrandDNA(
|
||||
industry="General Business",
|
||||
tone="Professional",
|
||||
communication_style="Analytical"
|
||||
),
|
||||
visual_style=VisualStyle(
|
||||
environment="Professional modern office studio",
|
||||
color_palette=["#000000", "#FFFFFF"]
|
||||
),
|
||||
audio_environment=AudioEnvironment(),
|
||||
show_rules=ShowRules(
|
||||
intro_format="Standard welcome and topic introduction.",
|
||||
outro_format="Summary and sign-off."
|
||||
)
|
||||
)
|
||||
|
||||
def serialize_bible(self, bible: PodcastBible) -> str:
|
||||
"""Serialize the Bible into a prompt-friendly text block."""
|
||||
return f"""
|
||||
<podcast_bible>
|
||||
HOST PERSONA:
|
||||
- Name: {bible.host.name}
|
||||
- Background: {bible.host.background}
|
||||
- Expertise Level: {bible.host.expertise_level}
|
||||
- Personality: {', '.join(bible.host.personality_traits)}
|
||||
- Vocal Style: {bible.host.vocal_style}
|
||||
- Vocal Characteristics: {', '.join(bible.host.vocal_characteristics)}
|
||||
- Visual Look: {bible.host.look}
|
||||
|
||||
TARGET AUDIENCE:
|
||||
- Expertise: {bible.audience.expertise_level}
|
||||
- Interests: {', '.join(bible.audience.interests)}
|
||||
- Pain Points: {', '.join(bible.audience.pain_points)}
|
||||
|
||||
BRAND & STYLE:
|
||||
- Industry: {bible.brand.industry}
|
||||
- Tone: {bible.brand.tone}
|
||||
- Communication Style: {bible.brand.communication_style}
|
||||
- Visual Style Preset: {bible.visual_style.style_preset}
|
||||
- Environment: {bible.visual_style.environment}
|
||||
- Lighting: {bible.visual_style.lighting}
|
||||
|
||||
AUDIO ENVIRONMENT:
|
||||
- Soundscape: {bible.audio_environment.soundscape}
|
||||
- Music Mood: {bible.audio_environment.music_mood}
|
||||
|
||||
SHOW RULES & STRUCTURE:
|
||||
- Intro Format: {bible.show_rules.intro_format}
|
||||
- Outro Format: {bible.show_rules.outro_format}
|
||||
- Interaction Tone: {bible.show_rules.interaction_tone}
|
||||
- Constraints: {', '.join(bible.show_rules.constraints)}
|
||||
</podcast_bible>
|
||||
"""
|
||||
@@ -11,6 +11,7 @@ from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from models.podcast_models import PodcastProject
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
|
||||
|
||||
class PodcastService:
|
||||
@@ -18,6 +19,7 @@ class PodcastService:
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.bible_service = PodcastBibleService()
|
||||
|
||||
def create_project(
|
||||
self,
|
||||
@@ -30,6 +32,9 @@ class PodcastService:
|
||||
**kwargs
|
||||
) -> PodcastProject:
|
||||
"""Create a new podcast project."""
|
||||
# Generate Podcast Bible automatically from onboarding data
|
||||
bible = self.bible_service.generate_bible(user_id, project_id)
|
||||
|
||||
project = PodcastProject(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
@@ -37,6 +42,7 @@ class PodcastService:
|
||||
duration=duration,
|
||||
speakers=speakers,
|
||||
budget_cap=budget_cap,
|
||||
bible=bible.model_dump() if bible else None,
|
||||
status="draft",
|
||||
current_step="create",
|
||||
**kwargs
|
||||
|
||||
@@ -5,13 +5,15 @@ Pluggable task scheduler that can work with any task model.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List, Callable
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from .executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from .task_registry import TaskRegistry
|
||||
@@ -19,8 +21,10 @@ from .exception_handler import (
|
||||
SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError,
|
||||
TaskLoaderError, SchedulerConfigError
|
||||
)
|
||||
|
||||
from services.database import get_all_user_ids, get_session_for_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from ..utils.user_job_store import get_user_job_store_name
|
||||
from models.scheduler_models import SchedulerEventLog
|
||||
from .interval_manager import determine_optimal_interval, adjust_check_interval_if_needed
|
||||
@@ -86,6 +90,9 @@ class TaskScheduler:
|
||||
}
|
||||
)
|
||||
|
||||
# Configure APScheduler to use unified logging system
|
||||
self._configure_apscheduler_logging()
|
||||
|
||||
# Task executor registry
|
||||
self.registry = TaskRegistry()
|
||||
|
||||
@@ -115,6 +122,21 @@ class TaskScheduler:
|
||||
}
|
||||
|
||||
self._running = False
|
||||
|
||||
# Local Desktop App: Always leader, no advisory locks needed
|
||||
self._leader_lock_key = int(os.getenv("SCHEDULER_LEADER_LOCK_KEY", "84321017"))
|
||||
self._leadership_check_interval_seconds = int(os.getenv("SCHEDULER_LEADERSHIP_CHECK_INTERVAL", "15"))
|
||||
self._leader_session = None
|
||||
self._is_leader = True # Always leader in local desktop app
|
||||
self._execution_enabled = True # Always enabled
|
||||
self._leader_since = datetime.utcnow().isoformat()
|
||||
self._last_leadership_check = None
|
||||
self._last_leadership_error = None
|
||||
|
||||
|
||||
# Execution lease registry (prevents duplicate redispatch across check cycles)
|
||||
self._task_leases: Dict[str, str] = {}
|
||||
self._task_lease_ttl_seconds = int(os.getenv("SCHEDULER_TASK_LEASE_TTL_SECONDS", "900"))
|
||||
|
||||
def _get_trigger_for_interval(self, interval_minutes: int):
|
||||
"""
|
||||
@@ -153,6 +175,144 @@ class TaskScheduler:
|
||||
self.registry.register(task_type, executor, task_loader)
|
||||
logger.info(f"Registered executor for task type: {task_type}")
|
||||
|
||||
def _configure_apscheduler_logging(self):
|
||||
"""Configure APScheduler to use unified logging system."""
|
||||
import logging
|
||||
|
||||
# Get APScheduler loggers and redirect them to unified logging
|
||||
apscheduler_logger = logging.getLogger("apscheduler")
|
||||
apscheduler_scheduler_logger = logging.getLogger("apscheduler.scheduler")
|
||||
apscheduler_executors_logger = logging.getLogger("apscheduler.executors")
|
||||
apscheduler_jobstores_logger = logging.getLogger("apscheduler.jobstores")
|
||||
|
||||
# Create a custom handler that redirects to unified logger
|
||||
class APSchedulerUnifiedHandler(logging.Handler):
|
||||
def __init__(self, service_logger):
|
||||
super().__init__()
|
||||
self.service_logger = service_logger
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
# Format the message
|
||||
msg = self.format(record)
|
||||
|
||||
# Map APScheduler log levels to unified logger
|
||||
if record.levelno >= logging.ERROR:
|
||||
self.service_logger.error(f"[APScheduler] {msg}")
|
||||
elif record.levelno >= logging.WARNING:
|
||||
self.service_logger.warning(f"[APScheduler] {msg}")
|
||||
elif record.levelno >= logging.INFO:
|
||||
self.service_logger.info(f"[APScheduler] {msg}")
|
||||
else:
|
||||
self.service_logger.debug(f"[APScheduler] {msg}")
|
||||
except Exception:
|
||||
# Don't let logging errors break the scheduler
|
||||
pass
|
||||
|
||||
# Create and add the handler
|
||||
unified_handler = APSchedulerUnifiedHandler(logger)
|
||||
unified_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# Add handler to all APScheduler loggers
|
||||
apscheduler_logger.addHandler(unified_handler)
|
||||
apscheduler_scheduler_logger.addHandler(unified_handler)
|
||||
apscheduler_executors_logger.addHandler(unified_handler)
|
||||
apscheduler_jobstores_logger.addHandler(unified_handler)
|
||||
|
||||
# Set levels to capture all logs
|
||||
apscheduler_logger.setLevel(logging.DEBUG)
|
||||
apscheduler_scheduler_logger.setLevel(logging.DEBUG)
|
||||
apscheduler_executors_logger.setLevel(logging.DEBUG)
|
||||
apscheduler_jobstores_logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Prevent propagation to avoid duplicate logs
|
||||
apscheduler_logger.propagate = False
|
||||
apscheduler_scheduler_logger.propagate = False
|
||||
apscheduler_executors_logger.propagate = False
|
||||
apscheduler_jobstores_logger.propagate = False
|
||||
|
||||
logger.info("APScheduler logging configured to use unified logging system")
|
||||
|
||||
|
||||
def _scheduler_identity(self) -> str:
|
||||
return f"{os.getenv('HOSTNAME', 'local')}-{os.getpid()}"
|
||||
|
||||
def _acquire_leadership(self) -> bool:
|
||||
"""Always return True for local desktop app (no HA needed)."""
|
||||
self._is_leader = True
|
||||
self._execution_enabled = True
|
||||
if not self._leader_since:
|
||||
self._leader_since = datetime.utcnow().isoformat()
|
||||
self._last_leadership_check = datetime.utcnow().isoformat()
|
||||
return True
|
||||
|
||||
def _release_leadership(self):
|
||||
"""No-op for local desktop app."""
|
||||
pass
|
||||
|
||||
def _sync_check_due_tasks_job(self):
|
||||
"""Ensure check_due_tasks job exists only for leader."""
|
||||
job = self.scheduler.get_job('check_due_tasks')
|
||||
if self._is_leader and self._execution_enabled:
|
||||
if job is None:
|
||||
self.scheduler.add_job(
|
||||
self._check_and_execute_due_tasks,
|
||||
trigger=self._get_trigger_for_interval(self.current_check_interval_minutes),
|
||||
id='check_due_tasks',
|
||||
replace_existing=True
|
||||
)
|
||||
else:
|
||||
if job is not None:
|
||||
self.scheduler.remove_job('check_due_tasks')
|
||||
|
||||
async def _leadership_tick(self):
|
||||
"""Periodic leadership check/renewal (Stub for local)."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._acquire_leadership()
|
||||
self._sync_check_due_tasks_job()
|
||||
|
||||
def _acquire_task_lease(self, task_key: str) -> bool:
|
||||
"""Acquire in-memory lease for a task key if available/expired."""
|
||||
now = datetime.utcnow()
|
||||
expiry_str = self._task_leases.get(task_key)
|
||||
|
||||
if expiry_str:
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expiry_str)
|
||||
if expiry > now:
|
||||
return False
|
||||
except Exception:
|
||||
# Corrupted lease value: overwrite safely
|
||||
pass
|
||||
|
||||
expiry = now + timedelta(seconds=self._task_lease_ttl_seconds)
|
||||
self._task_leases[task_key] = expiry.isoformat()
|
||||
return True
|
||||
|
||||
def _release_task_lease(self, task_key: str):
|
||||
"""Release lease for task key."""
|
||||
if task_key in self._task_leases:
|
||||
del self._task_leases[task_key]
|
||||
|
||||
def _is_task_leased(self, task_key: str) -> bool:
|
||||
"""Check whether task key is currently leased and not expired."""
|
||||
expiry_str = self._task_leases.get(task_key)
|
||||
if not expiry_str:
|
||||
return False
|
||||
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expiry_str)
|
||||
if expiry > datetime.utcnow():
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Expired/corrupt lease gets cleaned up lazily
|
||||
self._release_task_lease(task_key)
|
||||
return False
|
||||
|
||||
async def start(self):
|
||||
"""Start the scheduler with intelligent interval adjustment."""
|
||||
if self._running:
|
||||
@@ -168,16 +328,21 @@ class TaskScheduler:
|
||||
)
|
||||
self.current_check_interval_minutes = initial_interval
|
||||
|
||||
# Add periodic job to check for due tasks
|
||||
self.scheduler.add_job(
|
||||
self._check_and_execute_due_tasks,
|
||||
trigger=self._get_trigger_for_interval(initial_interval),
|
||||
id='check_due_tasks',
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
self.scheduler.start()
|
||||
self._running = True
|
||||
|
||||
# Leadership monitor runs on all replicas; only leader executes due-task loop.
|
||||
self.scheduler.add_job(
|
||||
self._leadership_tick,
|
||||
trigger=IntervalTrigger(seconds=self._leadership_check_interval_seconds),
|
||||
id='leadership_monitor',
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True
|
||||
)
|
||||
|
||||
# Initial leader election
|
||||
await self._leadership_tick()
|
||||
|
||||
# Check for and execute any missed jobs that are still within grace period
|
||||
await self._execute_missed_jobs()
|
||||
@@ -206,7 +371,7 @@ class TaskScheduler:
|
||||
registered_types = self.registry.get_registered_types()
|
||||
active_strategies = self.stats.get('active_strategies_count', 0)
|
||||
|
||||
# Count OAuth token monitoring tasks from database (recurring weekly tasks)
|
||||
# Count tasks per user (Multi-tenant SQLite)
|
||||
oauth_tasks_count = 0
|
||||
website_analysis_tasks_count = 0
|
||||
platform_insights_tasks_count = 0
|
||||
@@ -323,126 +488,6 @@ class TaskScheduler:
|
||||
|
||||
startup_lines.append(f"{prefix} Job: {job.id} | Trigger: {trigger_type} | Next Run: {next_run}{user_context}")
|
||||
|
||||
# Add OAuth token monitoring tasks details
|
||||
# Show ALL OAuth tasks (active and inactive) for complete visibility
|
||||
if total_oauth_tasks > 0:
|
||||
try:
|
||||
user_ids = get_all_user_ids()
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
from models.oauth_token_monitoring_models import OAuthTokenMonitoringTask
|
||||
# Get ALL tasks for this user
|
||||
oauth_tasks = db.query(OAuthTokenMonitoringTask).all()
|
||||
|
||||
for idx, task in enumerate(oauth_tasks):
|
||||
is_last = idx == len(oauth_tasks) - 1 and website_analysis_tasks_count == 0 and platform_insights_tasks_count == 0 and len(all_jobs) == 0 and user_id == user_ids[-1]
|
||||
prefix = " ├─" # Simplified prefix logic for multi-user list
|
||||
|
||||
try:
|
||||
user_job_store = get_user_job_store_name(task.user_id, db)
|
||||
if user_job_store == 'default':
|
||||
logger.debug(
|
||||
f"[Scheduler] Job store extraction returned 'default' for user {task.user_id}. "
|
||||
f"This may indicate no onboarding data or website URL not found."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Scheduler] Could not extract job store name for user {task.user_id}: {e}. "
|
||||
f"Using 'default'. Error type: {type(e).__name__}"
|
||||
)
|
||||
user_job_store = 'default'
|
||||
|
||||
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
|
||||
# Include status in the log line for visibility
|
||||
status_indicator = "✅" if task.status == 'active' else f"[{task.status}]"
|
||||
startup_lines.append(
|
||||
f"{prefix} Job: oauth_token_monitoring_{task.platform}_{task.user_id} | "
|
||||
f"Trigger: CronTrigger (Weekly) | Next Run: {next_check} | "
|
||||
f"User: {task.user_id} | Store: {user_job_store} | Platform: {task.platform} {status_indicator}"
|
||||
)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking OAuth tasks for user {user_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get OAuth token monitoring task details: {e}")
|
||||
|
||||
# Add website analysis tasks details
|
||||
if website_analysis_tasks_count > 0:
|
||||
try:
|
||||
user_ids = get_all_user_ids()
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
from models.website_analysis_monitoring_models import WebsiteAnalysisTask
|
||||
website_analysis_tasks = db.query(WebsiteAnalysisTask).all()
|
||||
|
||||
for idx, task in enumerate(website_analysis_tasks):
|
||||
is_last = idx == len(website_analysis_tasks) - 1 and platform_insights_tasks_count == 0 and len(all_jobs) == 0 and total_oauth_tasks == 0 and user_id == user_ids[-1]
|
||||
prefix = " ├─" # Simplified
|
||||
|
||||
try:
|
||||
user_job_store = get_user_job_store_name(task.user_id, db)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract job store name for user {task.user_id}: {e}")
|
||||
user_job_store = 'default'
|
||||
|
||||
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
|
||||
frequency = f"Every {task.frequency_days} days"
|
||||
task_type_label = "User Website" if task.task_type == 'user_website' else "Competitor"
|
||||
status_indicator = "✅" if task.status == 'active' else f"[{task.status}]"
|
||||
website_display = task.website_url[:50] + "..." if task.website_url and len(task.website_url) > 50 else (task.website_url or 'N/A')
|
||||
|
||||
startup_lines.append(
|
||||
f"{prefix} Job: website_analysis_{task.task_type}_{task.user_id}_{task.id} | "
|
||||
f"Trigger: CronTrigger ({frequency}) | Next Run: {next_check} | "
|
||||
f"User: {task.user_id} | Store: {user_job_store} | Type: {task_type_label} | URL: {website_display} {status_indicator}"
|
||||
)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking website analysis tasks for user {user_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get website analysis task details: {e}")
|
||||
|
||||
# Add platform insights tasks details
|
||||
if platform_insights_tasks_count > 0:
|
||||
try:
|
||||
user_ids = get_all_user_ids()
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
from models.platform_insights_monitoring_models import PlatformInsightsTask
|
||||
platform_insights_tasks = db.query(PlatformInsightsTask).all()
|
||||
|
||||
for idx, task in enumerate(platform_insights_tasks):
|
||||
is_last = idx == len(platform_insights_tasks) - 1 and len(all_jobs) == 0 and total_oauth_tasks == 0 and website_analysis_tasks_count == 0 and user_id == user_ids[-1]
|
||||
prefix = " ├─" # Simplified
|
||||
|
||||
try:
|
||||
user_job_store = get_user_job_store_name(task.user_id, db)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract job store name for user {task.user_id}: {e}")
|
||||
user_job_store = 'default'
|
||||
|
||||
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
|
||||
platform_label = task.platform.upper() if task.platform else 'Unknown'
|
||||
site_display = task.site_url[:50] + "..." if task.site_url and len(task.site_url) > 50 else (task.site_url or 'N/A')
|
||||
status_indicator = "✅" if task.status == 'active' else f"[{task.status}]"
|
||||
|
||||
startup_lines.append(
|
||||
f"{prefix} Job: platform_insights_{task.platform}_{task.user_id} | "
|
||||
f"Trigger: CronTrigger (Weekly) | Next Run: {next_check} | "
|
||||
f"User: {task.user_id} | Store: {user_job_store} | Platform: {platform_label} | Site: {site_display} {status_indicator}"
|
||||
)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking platform insights tasks for user {user_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get platform insights task details: {e}")
|
||||
|
||||
# Add Advertools tasks details
|
||||
if advertools_tasks_count > 0:
|
||||
try:
|
||||
@@ -518,7 +563,15 @@ class TaskScheduler:
|
||||
|
||||
# Get final job count before shutdown
|
||||
all_jobs_before = self.scheduler.get_jobs()
|
||||
|
||||
|
||||
# Release leadership lock and stop leadership monitor
|
||||
try:
|
||||
if self.scheduler.get_job('leadership_monitor') is not None:
|
||||
self.scheduler.remove_job('leadership_monitor')
|
||||
except Exception:
|
||||
pass
|
||||
self._release_leadership()
|
||||
|
||||
# Shutdown scheduler
|
||||
self.scheduler.shutdown(wait=True)
|
||||
self._running = False
|
||||
@@ -569,6 +622,10 @@ class TaskScheduler:
|
||||
Main scheduler loop: check for due tasks and execute them.
|
||||
This runs periodically with intelligent interval adjustment based on active strategies.
|
||||
"""
|
||||
if not self._execution_enabled or not self._is_leader:
|
||||
logger.debug("[Scheduler] Skipping due-task loop on standby replica")
|
||||
return
|
||||
|
||||
await check_and_execute_due_tasks(self)
|
||||
|
||||
async def _adjust_check_interval_if_needed(self, db: Session):
|
||||
@@ -614,309 +671,156 @@ class TaskScheduler:
|
||||
except Exception as e:
|
||||
logger.warning(f"[Scheduler] Error checking for missed jobs: {e}")
|
||||
|
||||
async def trigger_interval_adjustment(self):
|
||||
"""
|
||||
Trigger immediate interval adjustment check.
|
||||
|
||||
This should be called when a strategy is activated or deactivated
|
||||
to immediately adjust the scheduler interval based on current active strategies.
|
||||
"""
|
||||
if not self._running:
|
||||
logger.debug("Scheduler not running, skipping interval adjustment")
|
||||
return
|
||||
|
||||
try:
|
||||
# Multi-tenant aware adjustment (iterates all users internally)
|
||||
await adjust_check_interval_if_needed(self)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error triggering interval adjustment: {e}")
|
||||
|
||||
async def _validate_and_rebuild_cumulative_stats(self):
|
||||
"""
|
||||
Validate cumulative stats on scheduler startup and rebuild if needed.
|
||||
This ensures cumulative stats are accurate after restarts.
|
||||
|
||||
NOTE: Disabled in multi-tenant mode as there is no global database for cumulative stats.
|
||||
TODO: Implement per-user cumulative stats or a global admin database.
|
||||
Validate and rebuild cumulative stats if needed.
|
||||
Currently a placeholder for future implementation.
|
||||
"""
|
||||
logger.info("[Scheduler] Cumulative stats validation skipped (multi-tenant mode)")
|
||||
return
|
||||
|
||||
async def _process_task_type(self, task_type: str, db: Session, cycle_summary: Dict[str, Any] = None, user_id: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Process due tasks for a specific task type.
|
||||
|
||||
Returns:
|
||||
Summary dict with 'found', 'executed', 'failed' counts, or None if no tasks
|
||||
"""
|
||||
summary = {'found': 0, 'executed': 0, 'failed': 0}
|
||||
|
||||
pass
|
||||
|
||||
async def _process_task_type(
|
||||
self,
|
||||
task_type: str,
|
||||
db: Session,
|
||||
cycle_summary: Dict[str, Any],
|
||||
user_id: Optional[str] = None
|
||||
) -> Dict[str, int]:
|
||||
summary = {"found": 0, "executed": 0, "failed": 0}
|
||||
try:
|
||||
# Get task loader for this type
|
||||
try:
|
||||
task_loader = self.registry.get_task_loader(task_type)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Failed to get task loader for type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
return None
|
||||
|
||||
# Load due tasks (with error handling)
|
||||
try:
|
||||
due_tasks = task_loader(db)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Failed to load due tasks for type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
return None
|
||||
|
||||
if not due_tasks:
|
||||
return None
|
||||
|
||||
summary['found'] = len(due_tasks)
|
||||
self.stats['tasks_found'] += len(due_tasks)
|
||||
|
||||
# Execute tasks (with concurrency limit)
|
||||
execution_tasks = []
|
||||
skipped_count = 0
|
||||
for task in due_tasks:
|
||||
if len(self.active_executions) >= self.max_concurrent_executions:
|
||||
skipped_count = len(due_tasks) - len(execution_tasks)
|
||||
logger.warning(
|
||||
f"[Scheduler] ⚠️ Max concurrent executions reached ({self.max_concurrent_executions}), "
|
||||
f"skipping {skipped_count} tasks for {task_type}"
|
||||
)
|
||||
break
|
||||
|
||||
# Execute task asynchronously
|
||||
# Note: Each task gets its own database session to prevent concurrent access issues
|
||||
execution_task = asyncio.create_task(
|
||||
execute_task_async(self, task_type, task, summary, user_id=user_id)
|
||||
)
|
||||
|
||||
task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
|
||||
self.active_executions[task_id] = execution_task
|
||||
|
||||
execution_tasks.append(execution_task)
|
||||
|
||||
# Wait for executions to complete (with timeout per task)
|
||||
if execution_tasks:
|
||||
await asyncio.wait(execution_tasks, timeout=300)
|
||||
|
||||
return summary
|
||||
|
||||
task_loader = self.registry.get_task_loader(task_type)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Error processing task type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
message=f"Failed to get task loader for type {task_type}: {str(e)}",
|
||||
user_id=user_id,
|
||||
context={"task_type": task_type},
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
self.stats["tasks_failed"] += 1
|
||||
return summary
|
||||
|
||||
|
||||
def _update_user_stats(self, user_id: Optional[int], success: bool):
|
||||
"""
|
||||
Update per-user statistics for user isolation tracking.
|
||||
|
||||
Args:
|
||||
user_id: User ID (None if user context not available)
|
||||
success: Whether task execution was successful
|
||||
"""
|
||||
if user_id is None:
|
||||
|
||||
try:
|
||||
tasks = task_loader(db)
|
||||
if not tasks:
|
||||
return summary
|
||||
|
||||
summary["found"] = len(tasks)
|
||||
max_concurrent = self.max_concurrent_executions
|
||||
|
||||
for task in tasks:
|
||||
task_id = getattr(task, "id", None)
|
||||
lease_key = f"{task_type}_{task_id or id(task)}"
|
||||
|
||||
if self._is_task_leased(lease_key):
|
||||
continue
|
||||
|
||||
if len(self.active_executions) >= max_concurrent:
|
||||
break
|
||||
|
||||
if not self._acquire_task_lease(lease_key):
|
||||
continue
|
||||
|
||||
execution_task = asyncio.create_task(
|
||||
execute_task_async(
|
||||
self,
|
||||
task_type,
|
||||
task,
|
||||
summary,
|
||||
execution_source="scheduler",
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
self.active_executions[lease_key] = execution_task
|
||||
|
||||
cycle_summary.setdefault("tasks_found_by_type", {})
|
||||
cycle_summary.setdefault("tasks_executed_by_type", {})
|
||||
cycle_summary.setdefault("tasks_failed_by_type", {})
|
||||
|
||||
cycle_summary["tasks_found_by_type"][task_type] = (
|
||||
cycle_summary["tasks_found_by_type"].get(task_type, 0)
|
||||
+ summary["found"]
|
||||
)
|
||||
cycle_summary["tasks_executed_by_type"][task_type] = (
|
||||
cycle_summary["tasks_executed_by_type"].get(task_type, 0)
|
||||
+ summary["executed"]
|
||||
)
|
||||
cycle_summary["tasks_failed_by_type"][task_type] = (
|
||||
cycle_summary["tasks_failed_by_type"].get(task_type, 0)
|
||||
+ summary["failed"]
|
||||
)
|
||||
|
||||
return summary
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Error processing task type {task_type}: {str(e)}",
|
||||
user_id=user_id,
|
||||
context={"task_type": task_type},
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
self.stats["tasks_failed"] += 1
|
||||
return summary
|
||||
|
||||
def _update_user_stats(self, user_id: Optional[str], success: bool):
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
if user_id not in self.stats['per_user_stats']:
|
||||
self.stats['per_user_stats'][user_id] = {
|
||||
'executed': 0,
|
||||
'failed': 0,
|
||||
'success_rate': 0.0
|
||||
}
|
||||
|
||||
user_stats = self.stats['per_user_stats'][user_id]
|
||||
per_user = self.stats.setdefault("per_user_stats", {})
|
||||
user_stats = per_user.setdefault(
|
||||
user_id,
|
||||
{
|
||||
"tasks_executed": 0,
|
||||
"tasks_failed": 0,
|
||||
"last_update": None,
|
||||
},
|
||||
)
|
||||
if success:
|
||||
user_stats['executed'] += 1
|
||||
user_stats["tasks_executed"] += 1
|
||||
else:
|
||||
user_stats['failed'] += 1
|
||||
|
||||
# Calculate success rate
|
||||
total = user_stats['executed'] + user_stats['failed']
|
||||
if total > 0:
|
||||
user_stats['success_rate'] = (user_stats['executed'] / total) * 100.0
|
||||
|
||||
async def _schedule_retry(self, task: Any, delay_seconds: int):
|
||||
"""Schedule a retry for a failed task."""
|
||||
# This would update the task's next_execution time
|
||||
# For now, just log - could be enhanced to update next_execution
|
||||
logger.debug(f"Scheduling retry for task in {delay_seconds}s")
|
||||
|
||||
def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get scheduler statistics with optional user filtering.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter statistics for specific user
|
||||
|
||||
Returns:
|
||||
Dictionary with scheduler statistics
|
||||
"""
|
||||
base_stats = {
|
||||
**{k: v for k, v in self.stats.items() if k not in ['per_user_stats']},
|
||||
'active_executions': len(self.active_executions),
|
||||
'registered_types': self.registry.get_registered_types(),
|
||||
'running': self._running,
|
||||
'check_interval_minutes': self.current_check_interval_minutes,
|
||||
'min_check_interval_minutes': self.min_check_interval_minutes,
|
||||
'max_check_interval_minutes': self.max_check_interval_minutes,
|
||||
'intelligent_scheduling': True
|
||||
}
|
||||
|
||||
# Include per-user stats (all users or filtered)
|
||||
if user_id is not None:
|
||||
if user_id in self.stats['per_user_stats']:
|
||||
base_stats['user_stats'] = self.stats['per_user_stats'][user_id]
|
||||
else:
|
||||
base_stats['user_stats'] = {
|
||||
'executed': 0,
|
||||
'failed': 0,
|
||||
'success_rate': 0.0
|
||||
}
|
||||
else:
|
||||
# Include all per-user stats (for admin/debugging)
|
||||
base_stats['per_user_stats'] = self.stats['per_user_stats']
|
||||
|
||||
return base_stats
|
||||
|
||||
user_stats["tasks_failed"] += 1
|
||||
user_stats["last_update"] = datetime.utcnow().isoformat()
|
||||
|
||||
async def _schedule_retry(self, task: Any, retry_delay: int):
|
||||
try:
|
||||
task_id = getattr(task, "id", None)
|
||||
logger.warning(
|
||||
f"[Scheduler] Retry requested for task {task_id} in {retry_delay}s, "
|
||||
f"using loader-based retry semantics."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def schedule_one_time_task(
|
||||
self,
|
||||
func: Callable,
|
||||
run_date: datetime,
|
||||
job_id: str,
|
||||
args: tuple = (),
|
||||
kwargs: Dict[str, Any] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
replace_existing: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Schedule a one-time task to run at a specific datetime.
|
||||
Schedule a one-time task execution.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
run_date: Datetime when the task should run (must be timezone-aware UTC)
|
||||
job_id: Unique identifier for this job
|
||||
args: Positional arguments to pass to func
|
||||
kwargs: Keyword arguments to pass to func
|
||||
replace_existing: If True, replace existing job with same ID
|
||||
func: Function to execute
|
||||
run_date: Date/time to run the task
|
||||
job_id: Unique job ID
|
||||
kwargs: Keyword arguments for the function
|
||||
replace_existing: Whether to replace existing job with same ID
|
||||
|
||||
Returns:
|
||||
Job ID
|
||||
"""
|
||||
if not self._running:
|
||||
logger.warning(
|
||||
f"Scheduler not running, but scheduling job {job_id} anyway. "
|
||||
"APScheduler will start automatically when needed."
|
||||
)
|
||||
|
||||
try:
|
||||
# Ensure run_date is timezone-aware (UTC)
|
||||
if run_date.tzinfo is None:
|
||||
from datetime import timezone
|
||||
run_date = run_date.replace(tzinfo=timezone.utc)
|
||||
logger.debug(f"Added UTC timezone to run_date: {run_date}")
|
||||
|
||||
self.scheduler.add_job(
|
||||
func,
|
||||
trigger=DateTrigger(run_date=run_date),
|
||||
args=args,
|
||||
kwargs=kwargs or {},
|
||||
id=job_id,
|
||||
kwargs=kwargs or {},
|
||||
replace_existing=replace_existing,
|
||||
misfire_grace_time=3600 # 1 hour grace period for missed jobs
|
||||
misfire_grace_time=3600 # 1 hour grace period
|
||||
)
|
||||
|
||||
# Get updated job count
|
||||
all_jobs = self.scheduler.get_jobs()
|
||||
one_time_jobs = [j for j in all_jobs if j.id != 'check_due_tasks']
|
||||
|
||||
# Extract user_id from kwargs if available for logging and job store
|
||||
user_id = kwargs.get('user_id', None) if kwargs else None
|
||||
func_name = func.__name__ if hasattr(func, '__name__') else str(func)
|
||||
|
||||
# Get job store name for user (if user_id provided)
|
||||
job_store_name = 'default'
|
||||
if user_id:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
job_store_name = get_user_job_store_name(user_id, db)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not determine job store for user {user_id}: {e}")
|
||||
|
||||
# Note: APScheduler doesn't support dynamic job store creation
|
||||
# We use 'default' for all jobs but log the user's job store name for debugging
|
||||
# The actual user isolation is handled through task filtering by user_id
|
||||
|
||||
# Log detailed one-time task scheduling information (use WARNING level for visibility)
|
||||
log_message = (
|
||||
f"[Scheduler] 📅 Scheduled One-Time Task\n"
|
||||
f" ├─ Job ID: {job_id}\n"
|
||||
f" ├─ Function: {func_name}\n"
|
||||
f" ├─ User ID: {user_id or 'system'}\n"
|
||||
f" ├─ Job Store: {job_store_name} (user context)\n"
|
||||
f" ├─ Scheduled For: {run_date}\n"
|
||||
f" ├─ Replace Existing: {replace_existing}\n"
|
||||
f" ├─ Total One-Time Jobs: {len(one_time_jobs)}\n"
|
||||
f" └─ Total Scheduled Jobs: {len(all_jobs)}"
|
||||
)
|
||||
logger.warning(log_message)
|
||||
|
||||
# Log job scheduling to event log for dashboard
|
||||
if user_id:
|
||||
try:
|
||||
event_db = get_session_for_user(user_id)
|
||||
if event_db:
|
||||
event_log = SchedulerEventLog(
|
||||
event_type='job_scheduled',
|
||||
event_date=datetime.utcnow(),
|
||||
job_id=job_id,
|
||||
job_type='one_time',
|
||||
user_id=user_id,
|
||||
event_data={
|
||||
'function_name': func_name,
|
||||
'job_store': job_store_name,
|
||||
'scheduled_for': run_date.isoformat(),
|
||||
'replace_existing': replace_existing
|
||||
}
|
||||
)
|
||||
event_db.add(event_log)
|
||||
event_db.commit()
|
||||
event_db.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to log job scheduling event: {e}")
|
||||
|
||||
logger.info(f"Scheduled one-time task {job_id} at {run_date}")
|
||||
return job_id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to schedule one-time task {job_id}: {e}")
|
||||
raise
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if scheduler is running."""
|
||||
return self._running
|
||||
|
||||
async def execute_task_by_type(self, task_type: str, user_id: str, payload: Dict[str, Any]):
|
||||
"""
|
||||
Execute a task by type and payload immediately.
|
||||
Used for one-time tasks triggered by system events.
|
||||
"""
|
||||
from collections import namedtuple
|
||||
TaskStub = namedtuple('TaskStub', ['user_id', 'payload', 'id'])
|
||||
task_stub = TaskStub(user_id=user_id, payload=payload, id=f"manual_{datetime.utcnow().timestamp()}")
|
||||
|
||||
await execute_task_async(self, task_type, task_stub, execution_source="manual")
|
||||
|
||||
|
||||
@@ -67,6 +67,77 @@ class StoryImageGenerationService:
|
||||
clean_title = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in scene_title[:30])
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
return f"scene_{scene_number}_{clean_title}_{unique_id}.png"
|
||||
|
||||
def _refine_image_prompt_with_bible(
|
||||
self,
|
||||
image_prompt: str,
|
||||
scene: Dict[str, Any],
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Lightweight image prompt refinement using the anime story bible.
|
||||
|
||||
Takes the existing scene image_prompt and enriches it with visual_style,
|
||||
world, and cast hints from the bible. This is deterministic and avoids
|
||||
extra LLM calls.
|
||||
"""
|
||||
if not image_prompt or not isinstance(image_prompt, str):
|
||||
return image_prompt
|
||||
|
||||
if not anime_bible or not isinstance(anime_bible, dict):
|
||||
return image_prompt
|
||||
|
||||
visual_style = anime_bible.get("visual_style") or {}
|
||||
world = anime_bible.get("world") or {}
|
||||
main_cast = anime_bible.get("main_cast") or []
|
||||
|
||||
parts: List[str] = []
|
||||
|
||||
style_preset = visual_style.get("style_preset")
|
||||
if style_preset:
|
||||
parts.append(f"{style_preset} anime illustration style")
|
||||
|
||||
camera_style = visual_style.get("camera_style")
|
||||
if camera_style:
|
||||
parts.append(f"framing and camera style: {camera_style}")
|
||||
|
||||
color_mood = visual_style.get("color_mood")
|
||||
if color_mood:
|
||||
parts.append(f"color mood: {color_mood}")
|
||||
|
||||
lighting = visual_style.get("lighting")
|
||||
if lighting:
|
||||
parts.append(f"lighting: {lighting}")
|
||||
|
||||
line_style = visual_style.get("line_style")
|
||||
if line_style:
|
||||
parts.append(f"line style: {line_style}")
|
||||
|
||||
extra_tags = visual_style.get("extra_tags") or []
|
||||
if isinstance(extra_tags, (list, tuple)):
|
||||
extra_text = ", ".join(str(tag) for tag in extra_tags[:6] if tag)
|
||||
if extra_text:
|
||||
parts.append(extra_text)
|
||||
|
||||
setting = world.get("setting") if isinstance(world, dict) else None
|
||||
if setting:
|
||||
parts.append(f"world setting: {setting}")
|
||||
|
||||
if isinstance(main_cast, list):
|
||||
names = [
|
||||
c.get("name")
|
||||
for c in main_cast
|
||||
if isinstance(c, dict) and c.get("name")
|
||||
]
|
||||
if names:
|
||||
joined = ", ".join(names[:4])
|
||||
parts.append(f"keep character designs consistent for: {joined}")
|
||||
|
||||
if not parts:
|
||||
return image_prompt
|
||||
|
||||
suffix = ", " + ", ".join(parts)
|
||||
return image_prompt.strip() + suffix
|
||||
|
||||
def generate_scene_image(
|
||||
self,
|
||||
@@ -75,7 +146,8 @@ class StoryImageGenerationService:
|
||||
provider: Optional[str] = None,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
model: Optional[str] = None
|
||||
model: Optional[str] = None,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an image for a single story scene.
|
||||
@@ -94,6 +166,16 @@ class StoryImageGenerationService:
|
||||
scene_number = scene.get("scene_number", 0)
|
||||
scene_title = scene.get("title", "Untitled")
|
||||
image_prompt = scene.get("image_prompt", "")
|
||||
|
||||
if anime_bible:
|
||||
try:
|
||||
image_prompt = self._refine_image_prompt_with_bible(
|
||||
image_prompt=image_prompt,
|
||||
scene=scene,
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryImageGeneration] Failed to refine image prompt with bible: {e}")
|
||||
|
||||
if not image_prompt:
|
||||
raise ValueError(f"Scene {scene_number} ({scene_title}) has no image_prompt")
|
||||
@@ -156,7 +238,8 @@ class StoryImageGenerationService:
|
||||
height: int = 1024,
|
||||
model: Optional[str] = None,
|
||||
progress_callback: Optional[callable] = None,
|
||||
db: Optional[Session] = None
|
||||
db: Optional[Session] = None,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate images for multiple story scenes.
|
||||
@@ -192,7 +275,7 @@ class StoryImageGenerationService:
|
||||
width=width,
|
||||
height=height,
|
||||
model=model,
|
||||
db=db
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
|
||||
image_results.append(image_result)
|
||||
@@ -295,4 +378,3 @@ class StoryImageGenerationService:
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryImageGeneration] Error regenerating image for scene {scene_number}: {e}")
|
||||
raise RuntimeError(f"Failed to regenerate image for scene {scene_number}: {str(e)}") from e
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ class StoryOutlineMixin(StoryServiceBase):
|
||||
ending_preference: str,
|
||||
user_id: str,
|
||||
use_structured_output: bool = True,
|
||||
include_anime_bible: bool = False,
|
||||
) -> Any:
|
||||
"""Generate a story outline with optional structured JSON output."""
|
||||
persona_prompt = self.build_persona_prompt(
|
||||
|
||||
@@ -145,20 +145,45 @@ Write ONLY the premise sentence(s). Do not write anything else.
|
||||
"reasoning",
|
||||
],
|
||||
},
|
||||
"minItems": 1,
|
||||
"maxItems": 1,
|
||||
}
|
||||
},
|
||||
"required": ["options"],
|
||||
}
|
||||
|
||||
def _build_idea_enhance_schema(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"suggestions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"idea": {"type": "string"},
|
||||
"whats_missing": {"type": "string"},
|
||||
"why_choose": {"type": "string"},
|
||||
},
|
||||
"required": ["idea", "whats_missing", "why_choose"],
|
||||
},
|
||||
"minItems": 3,
|
||||
"maxItems": 3,
|
||||
}
|
||||
},
|
||||
"required": ["options"],
|
||||
"required": ["suggestions"],
|
||||
}
|
||||
|
||||
def generate_story_setup_options(
|
||||
self,
|
||||
*,
|
||||
story_idea: str,
|
||||
story_mode: str | None,
|
||||
story_template: str | None,
|
||||
brand_context: Dict[str, Any] | None,
|
||||
user_id: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate 3 story setup options from a user's story idea."""
|
||||
"""Generate a single story setup option from a user's story idea."""
|
||||
|
||||
suggested_writing_styles = ['Formal', 'Casual', 'Poetic', 'Humorous', 'Academic', 'Journalistic', 'Narrative']
|
||||
suggested_story_tones = ['Dark', 'Uplifting', 'Suspenseful', 'Whimsical', 'Melancholic', 'Mysterious', 'Romantic', 'Adventurous']
|
||||
@@ -167,12 +192,59 @@ Write ONLY the premise sentence(s). Do not write anything else.
|
||||
suggested_content_ratings = ['G', 'PG', 'PG-13', 'R']
|
||||
suggested_ending_preferences = ['Happy', 'Tragic', 'Cliffhanger', 'Twist', 'Open-ended', 'Bittersweet']
|
||||
|
||||
mode_label = None
|
||||
if story_mode == "marketing":
|
||||
mode_label = "Non-fiction marketing story (brand or product campaign)"
|
||||
elif story_mode == "pure":
|
||||
mode_label = "Fiction story"
|
||||
|
||||
template_label = None
|
||||
if story_template == "product_story":
|
||||
template_label = "Product Story"
|
||||
elif story_template == "brand_manifesto":
|
||||
template_label = "Brand Manifesto"
|
||||
elif story_template == "founder_story":
|
||||
template_label = "Founder Story"
|
||||
elif story_template == "customer_story":
|
||||
template_label = "Customer Story"
|
||||
elif story_template == "short_fiction":
|
||||
template_label = "Short Fiction"
|
||||
elif story_template == "long_fiction":
|
||||
template_label = "Long Fiction"
|
||||
elif story_template == "anime_fiction":
|
||||
template_label = "Anime Fiction"
|
||||
elif story_template == "experimental_fiction":
|
||||
template_label = "Experimental Fiction"
|
||||
|
||||
brand_name = None
|
||||
writing_tone = None
|
||||
audience_description = None
|
||||
if isinstance(brand_context, dict):
|
||||
brand_name = brand_context.get("brand_name")
|
||||
writing_tone = brand_context.get("writing_tone")
|
||||
target_audience = brand_context.get("target_audience")
|
||||
if isinstance(target_audience, dict):
|
||||
audience_description = target_audience.get("description") or target_audience.get("summary")
|
||||
elif isinstance(target_audience, str):
|
||||
audience_description = target_audience
|
||||
|
||||
setup_prompt = f"""\
|
||||
You are an expert story writer and creative writing assistant. A user has provided the following story idea or information:
|
||||
You are an expert story writer and creative writing assistant.
|
||||
|
||||
{"This is a " + mode_label + "." if mode_label else ""}
|
||||
{("The user selected the template: " + template_label + ".") if template_label else ""}
|
||||
|
||||
The story should stay consistent with the brand and audience context below when relevant:
|
||||
|
||||
- Brand name or site: {brand_name or "Not specified"}
|
||||
- Headline/overall writing tone: {writing_tone or "Not specified"}
|
||||
- Audience description: {audience_description or "Not specified"}
|
||||
|
||||
The user has provided the following story idea or information:
|
||||
|
||||
{story_idea}
|
||||
|
||||
Based on this story idea, generate exactly 3 different, well-thought-out story setup options. Each option should be CREATIVE, PERSONALIZED, and perfectly tailored to the user's specific story idea.
|
||||
Based on this story idea, generate exactly 1 well-thought-out story setup option. The setup should be CREATIVE, PERSONALIZED, and perfectly tailored to the user's specific story idea.
|
||||
|
||||
**CRITICAL - Creative Freedom:**
|
||||
- You have COMPLETE FREEDOM to craft personalized values that best fit the user's story idea
|
||||
@@ -183,7 +255,7 @@ Based on this story idea, generate exactly 3 different, well-thought-out story s
|
||||
- Narrative POV: "Second Person (You)" or "Omniscient Narrator as Guide" (not just standard options)
|
||||
- The goal is to create the PERFECT setup for THIS specific story, not to fit into generic categories
|
||||
|
||||
Each option should:
|
||||
The setup should:
|
||||
1. Have a unique and creative persona that fits the story idea perfectly
|
||||
2. Define a compelling story setting that brings the idea to life
|
||||
3. Describe interesting and engaging characters
|
||||
@@ -212,23 +284,23 @@ Each option should:
|
||||
|
||||
**Remember:** These are ONLY suggestions. If a custom value better serves the story idea, CREATE IT!
|
||||
|
||||
Return exactly 3 options as a JSON array. Each option must include a "premise" field with the story premise.
|
||||
Return exactly 1 option as a JSON array with a single object in "options". The object must include a "premise" field with the story premise.
|
||||
"""
|
||||
|
||||
setup_schema = self._build_setup_schema()
|
||||
|
||||
try:
|
||||
logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")
|
||||
logger.info(f"[StoryWriter] Generating story setup option for user {user_id}")
|
||||
response = self.load_json_response(
|
||||
llm_text_gen(prompt=setup_prompt, json_struct=setup_schema, user_id=user_id)
|
||||
)
|
||||
|
||||
options = response.get("options", [])
|
||||
if len(options) != 3:
|
||||
logger.warning(f"[StoryWriter] Expected 3 options but got {len(options)}, correcting count")
|
||||
if len(options) < 3:
|
||||
raise ValueError(f"Expected 3 options but got {len(options)}")
|
||||
options = options[:3]
|
||||
if len(options) != 1:
|
||||
logger.warning(f"[StoryWriter] Expected 1 option but got {len(options)}, correcting count")
|
||||
if len(options) < 1:
|
||||
raise ValueError(f"Expected 1 option but got {len(options)}")
|
||||
options = options[:1]
|
||||
|
||||
for idx, option in enumerate(options):
|
||||
if not option.get("premise") or not option.get("premise", "").strip():
|
||||
@@ -262,7 +334,7 @@ Return exactly 3 options as a JSON array. Each option must include a "premise" f
|
||||
premise += "."
|
||||
option["premise"] = premise
|
||||
|
||||
logger.info(f"[StoryWriter] Generated {len(options)} story setup options with premises for user {user_id}")
|
||||
logger.info(f"[StoryWriter] Generated {len(options)} story setup option(s) with premise for user {user_id}")
|
||||
return options
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -273,3 +345,119 @@ Return exactly 3 options as a JSON array. Each option must include a "premise" f
|
||||
logger.error(f"[StoryWriter] Error generating story setup options: {exc}")
|
||||
raise RuntimeError(f"Failed to generate story setup options: {exc}") from exc
|
||||
|
||||
def enhance_story_idea(
|
||||
self,
|
||||
*,
|
||||
story_idea: str,
|
||||
story_mode: str | None,
|
||||
story_template: str | None,
|
||||
brand_context: Dict[str, Any] | None,
|
||||
user_id: str,
|
||||
fiction_variant: str | None = None,
|
||||
narrative_energy: str | None = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
mode_label = None
|
||||
if story_mode == "marketing":
|
||||
mode_label = "Non-fiction marketing story (brand or product campaign)"
|
||||
elif story_mode == "pure":
|
||||
mode_label = "Fiction story"
|
||||
|
||||
template_label = None
|
||||
if story_template == "product_story":
|
||||
template_label = "Product Story"
|
||||
elif story_template == "brand_manifesto":
|
||||
template_label = "Brand Manifesto"
|
||||
elif story_template == "founder_story":
|
||||
template_label = "Founder Story"
|
||||
elif story_template == "customer_story":
|
||||
template_label = "Customer Story"
|
||||
elif story_template == "short_fiction":
|
||||
template_label = "Short Fiction"
|
||||
elif story_template == "long_fiction":
|
||||
template_label = "Long Fiction"
|
||||
elif story_template == "anime_fiction":
|
||||
template_label = "Anime Fiction"
|
||||
elif story_template == "experimental_fiction":
|
||||
template_label = "Experimental Fiction"
|
||||
|
||||
brand_name = None
|
||||
writing_tone = None
|
||||
audience_description = None
|
||||
if isinstance(brand_context, dict):
|
||||
brand_name = brand_context.get("brand_name")
|
||||
writing_tone = brand_context.get("writing_tone")
|
||||
target_audience = brand_context.get("target_audience")
|
||||
if isinstance(target_audience, dict):
|
||||
audience_description = target_audience.get("description") or target_audience.get("summary")
|
||||
elif isinstance(target_audience, str):
|
||||
audience_description = target_audience
|
||||
|
||||
fiction_focus_line = ""
|
||||
if fiction_variant:
|
||||
fiction_focus_line = f'Treat the story as "{fiction_variant}" and lean into that creative focus.'
|
||||
|
||||
energy_line = ""
|
||||
if narrative_energy:
|
||||
energy_line = f'Target narrative energy: {narrative_energy}.'
|
||||
|
||||
enhance_prompt = f"""You are a creative writing coach helping a user refine and expand a story idea.
|
||||
|
||||
{"This is a " + mode_label + "." if mode_label else ""}
|
||||
{("The user selected the template: " + template_label + ".") if template_label else ""}
|
||||
{fiction_focus_line}
|
||||
{energy_line}
|
||||
|
||||
When relevant, keep the idea aligned with this brand and audience context:
|
||||
- Brand name or site: {brand_name or "Not specified"}
|
||||
- Headline/overall writing tone: {writing_tone or "Not specified"}
|
||||
- Audience description: {audience_description or "Not specified"}
|
||||
|
||||
The user has written the following story idea or concept:
|
||||
|
||||
{story_idea}
|
||||
|
||||
Your task is to propose exactly 3 alternative enhanced story idea options.
|
||||
|
||||
Each option must:
|
||||
- Preserve the user's core premise and intent.
|
||||
- Make the premise clearer and more compelling.
|
||||
- Surface the central conflict or tension.
|
||||
- Clarify the main characters and their goals.
|
||||
- Strengthen the setting and stakes.
|
||||
- Stay at the "idea" level, not a full outline or beat-by-beat breakdown.
|
||||
|
||||
For each option, return three fields:
|
||||
- "idea": 2-4 sentences describing the improved story idea, suitable for a single textarea input.
|
||||
- "whats_missing": 2-4 sentences explaining what important details are missing or underspecified in the current brief. Focus on gaps such as: protagonist details, antagonist or opposing force, stakes, setting and time period, audience/age group, subgenre or type of fiction (for example, anime vs grounded sci-fi), language or tone preferences, and any format constraints.
|
||||
- "why_choose": 1-3 sentences explaining how this option interprets the original idea and why it might be a strong direction for the story.
|
||||
|
||||
Do not write a full story outline.
|
||||
Do not output numbered lists or markdown formatting.
|
||||
|
||||
Return a single JSON object with a "suggestions" array of 3 items, where each item has the keys "idea", "whats_missing", and "why_choose"."""
|
||||
|
||||
schema = self._build_idea_enhance_schema()
|
||||
|
||||
try:
|
||||
logger.info(f"[StoryWriter] Enhancing story idea with structured suggestions for user {user_id}")
|
||||
response = self.load_json_response(
|
||||
llm_text_gen(prompt=enhance_prompt, json_struct=schema, user_id=user_id)
|
||||
)
|
||||
suggestions = response.get("suggestions", [])
|
||||
if len(suggestions) != 3:
|
||||
logger.warning(
|
||||
f"[StoryWriter] Expected 3 idea suggestions but got {len(suggestions)}, correcting count"
|
||||
)
|
||||
if len(suggestions) < 3:
|
||||
raise ValueError(f"Expected 3 suggestions but got {len(suggestions)}")
|
||||
suggestions = suggestions[:3]
|
||||
return suggestions
|
||||
except HTTPException:
|
||||
raise
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.error(f"[StoryWriter] Failed to parse JSON response for story idea enhancement: {exc}")
|
||||
raise RuntimeError(f"Failed to parse story idea enhancement suggestions: {exc}") from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Error enhancing story idea: {exc}")
|
||||
raise RuntimeError(f"Failed to enhance story idea: {exc}") from exc
|
||||
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
|
||||
from .base import StoryServiceBase
|
||||
@@ -36,6 +38,7 @@ class StoryContentMixin(StoryOutlineMixin):
|
||||
content_rating: str,
|
||||
ending_preference: str,
|
||||
story_length: str = "Medium",
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
user_id: str,
|
||||
) -> str:
|
||||
"""Generate the starting section (or full short story)."""
|
||||
@@ -52,6 +55,19 @@ class StoryContentMixin(StoryOutlineMixin):
|
||||
ending_preference,
|
||||
)
|
||||
|
||||
anime_bible_context = ""
|
||||
if anime_bible:
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
anime_bible_context = f"""
|
||||
|
||||
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
outline_text = self._format_outline_for_prompt(outline)
|
||||
story_length_lower = story_length.lower()
|
||||
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
|
||||
@@ -61,6 +77,8 @@ class StoryContentMixin(StoryOutlineMixin):
|
||||
short_story_prompt = f"""\
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You have a gripping premise in mind:
|
||||
|
||||
{premise}
|
||||
@@ -154,6 +172,285 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
logger.error(f"Story Start Generation Error: {exc}")
|
||||
raise RuntimeError(f"Failed to generate story start: {exc}") from exc
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Anime scene refinement
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def refine_anime_scene_text(
|
||||
self,
|
||||
*,
|
||||
scene: Dict[str, Any],
|
||||
persona: str,
|
||||
story_setting: str,
|
||||
character_input: str,
|
||||
plot_elements: str,
|
||||
writing_style: str,
|
||||
story_tone: str,
|
||||
narrative_pov: str,
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
anime_bible: Optional[Dict[str, Any]],
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
persona_prompt = self.build_persona_prompt(
|
||||
persona,
|
||||
story_setting,
|
||||
character_input,
|
||||
plot_elements,
|
||||
writing_style,
|
||||
story_tone,
|
||||
narrative_pov,
|
||||
audience_age_group,
|
||||
content_rating,
|
||||
"Neutral",
|
||||
)
|
||||
|
||||
anime_bible_context = ""
|
||||
if anime_bible:
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
anime_bible_context = f"""
|
||||
|
||||
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
current_title = scene.get("title", "")
|
||||
current_description = scene.get("description", "")
|
||||
current_image_prompt = scene.get("image_prompt", "")
|
||||
current_audio_narration = scene.get("audio_narration", "")
|
||||
current_character_descriptions = scene.get("character_descriptions") or []
|
||||
current_key_events = scene.get("key_events") or []
|
||||
|
||||
scene_schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"image_prompt": {"type": "string"},
|
||||
"audio_narration": {"type": "string"},
|
||||
"character_descriptions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"key_events": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["title", "description", "image_prompt", "audio_narration"],
|
||||
}
|
||||
|
||||
prompt = f"""
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You are refining a single anime story scene so that it fully respects the anime story bible for characters, world rules, and visual style.
|
||||
|
||||
Current scene:
|
||||
- Title: {current_title}
|
||||
- Description: {current_description}
|
||||
- Image prompt: {current_image_prompt}
|
||||
- Audio narration: {current_audio_narration}
|
||||
- Character descriptions: {current_character_descriptions}
|
||||
- Key events: {current_key_events}
|
||||
|
||||
Refine the scene so that:
|
||||
- Title is concise and evocative
|
||||
- Description clearly describes what happens in the scene
|
||||
- Image prompt is vivid, visual, and aligned with the anime bible style and cast
|
||||
- Audio narration is natural, spoken-friendly text matching the scene
|
||||
- Character descriptions highlight key visual and personality traits relevant to this moment
|
||||
- Key events list the main beats of the scene
|
||||
|
||||
Respond with JSON matching this schema:
|
||||
{scene_schema}
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt.strip(),
|
||||
json_struct=scene_schema,
|
||||
user_id=user_id,
|
||||
)
|
||||
data = self.load_json_response(raw)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[StoryWriter] Failed to refine anime scene text via LLM: {exc}")
|
||||
return {
|
||||
"scene_number": scene.get("scene_number"),
|
||||
"title": current_title,
|
||||
"description": current_description,
|
||||
"image_prompt": current_image_prompt,
|
||||
"audio_narration": current_audio_narration,
|
||||
"character_descriptions": current_character_descriptions,
|
||||
"key_events": current_key_events,
|
||||
}
|
||||
|
||||
refined = {
|
||||
"scene_number": scene.get("scene_number"),
|
||||
"title": data.get("title", current_title),
|
||||
"description": data.get("description", current_description),
|
||||
"image_prompt": data.get("image_prompt", current_image_prompt),
|
||||
"audio_narration": data.get("audio_narration", current_audio_narration),
|
||||
"character_descriptions": data.get(
|
||||
"character_descriptions", current_character_descriptions
|
||||
),
|
||||
"key_events": data.get("key_events", current_key_events),
|
||||
}
|
||||
return refined
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Anime scene generation from bible
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def generate_anime_scene_from_bible(
|
||||
self,
|
||||
*,
|
||||
premise: str,
|
||||
persona: str,
|
||||
story_setting: str,
|
||||
character_input: str,
|
||||
plot_elements: str,
|
||||
writing_style: str,
|
||||
story_tone: str,
|
||||
narrative_pov: str,
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
anime_bible: Dict[str, Any],
|
||||
previous_scenes: Optional[List[Dict[str, Any]]],
|
||||
target_scene_number: Optional[int],
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
persona_prompt = self.build_persona_prompt(
|
||||
persona,
|
||||
story_setting,
|
||||
character_input,
|
||||
plot_elements,
|
||||
writing_style,
|
||||
story_tone,
|
||||
narrative_pov,
|
||||
audience_age_group,
|
||||
content_rating,
|
||||
"Neutral",
|
||||
)
|
||||
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
|
||||
anime_bible_context = f"""
|
||||
|
||||
You have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. You MUST treat it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
previous_summary_lines: List[str] = []
|
||||
if previous_scenes:
|
||||
for s in previous_scenes[:6]:
|
||||
num = s.get("scene_number")
|
||||
title = s.get("title") or ""
|
||||
desc = s.get("description") or ""
|
||||
summary = desc
|
||||
if len(summary) > 200:
|
||||
summary = summary[:197] + "..."
|
||||
previous_summary_lines.append(
|
||||
f"- Scene {num}: {title} — {summary}".strip()
|
||||
)
|
||||
|
||||
previous_block = ""
|
||||
if previous_summary_lines:
|
||||
previous_block = (
|
||||
"\nPrevious scenes so far (for continuity, do NOT contradict):\n"
|
||||
+ "\n".join(previous_summary_lines)
|
||||
)
|
||||
|
||||
scene_schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"image_prompt": {"type": "string"},
|
||||
"audio_narration": {"type": "string"},
|
||||
"character_descriptions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"key_events": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["title", "description", "image_prompt", "audio_narration"],
|
||||
}
|
||||
|
||||
prompt = f"""
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You are generating a brand new anime story scene that must fully respect the anime story bible for characters, world rules, and visual style.
|
||||
|
||||
Overall premise:
|
||||
{premise}
|
||||
{previous_block}
|
||||
|
||||
Your task:
|
||||
- Create the NEXT SCENE in this story.
|
||||
- It must be consistent with the anime bible (cast, world rules, visual style).
|
||||
- It must logically follow from any previous scenes given above.
|
||||
|
||||
Design the scene so that:
|
||||
- Title is concise and evocative.
|
||||
- Description clearly describes what happens in the scene.
|
||||
- Image prompt is vivid, visual, and aligned with the anime bible style and cast.
|
||||
- Audio narration is natural, spoken-friendly text matching the scene.
|
||||
- Character descriptions highlight key visual and personality traits relevant to this moment.
|
||||
- Key events list the main beats of the scene.
|
||||
|
||||
Respond with JSON matching this schema:
|
||||
{scene_schema}
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt.strip(),
|
||||
json_struct=scene_schema,
|
||||
user_id=user_id,
|
||||
)
|
||||
data = self.load_json_response(raw)
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate anime scene from bible: {exc}")
|
||||
raise RuntimeError(f"Failed to generate anime scene from bible: {exc}") from exc
|
||||
|
||||
next_scene_number = target_scene_number
|
||||
if next_scene_number is None:
|
||||
if previous_scenes and len(previous_scenes) > 0:
|
||||
last = previous_scenes[-1]
|
||||
try:
|
||||
last_num = int(last.get("scene_number") or 0)
|
||||
except Exception:
|
||||
last_num = len(previous_scenes)
|
||||
next_scene_number = last_num + 1
|
||||
else:
|
||||
next_scene_number = 1
|
||||
|
||||
result = {
|
||||
"scene_number": next_scene_number,
|
||||
"title": data.get("title", "").strip(),
|
||||
"description": data.get("description", "").strip(),
|
||||
"image_prompt": data.get("image_prompt", "").strip(),
|
||||
"audio_narration": data.get("audio_narration", "").strip(),
|
||||
"character_descriptions": data.get("character_descriptions") or [],
|
||||
"key_events": data.get("key_events") or [],
|
||||
}
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Continuation
|
||||
# ------------------------------------------------------------------ #
|
||||
@@ -174,6 +471,7 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
ending_preference: str,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
story_length: str = "Medium",
|
||||
user_id: str,
|
||||
) -> str:
|
||||
@@ -191,6 +489,19 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
ending_preference,
|
||||
)
|
||||
|
||||
anime_bible_context = ""
|
||||
if anime_bible:
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
anime_bible_context = f"""
|
||||
|
||||
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
outline_text = self._format_outline_for_prompt(outline)
|
||||
_, continuation_word_count = self._get_story_length_guidance(story_length)
|
||||
current_word_count = len(story_text.split()) if story_text else 0
|
||||
@@ -227,6 +538,8 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
continuation_prompt = f"""\
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You have a gripping premise in mind:
|
||||
|
||||
{premise}
|
||||
@@ -298,6 +611,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
ending_preference: str,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
user_id: str,
|
||||
max_iterations: int = 10,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -352,6 +666,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
audience_age_group=audience_age_group,
|
||||
content_rating=content_rating,
|
||||
ending_preference=ending_preference,
|
||||
anime_bible=anime_bible,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not draft:
|
||||
@@ -375,6 +690,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
audience_age_group=audience_age_group,
|
||||
content_rating=content_rating,
|
||||
ending_preference=ending_preference,
|
||||
anime_bible=anime_bible,
|
||||
user_id=user_id,
|
||||
)
|
||||
if continuation:
|
||||
@@ -420,6 +736,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
height: int = 1024,
|
||||
model: Optional[str] = None,
|
||||
db: Optional[Session] = None,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate images for story scenes."""
|
||||
image_service = StoryImageGenerationService()
|
||||
@@ -431,5 +748,6 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
height=height,
|
||||
model=model,
|
||||
db=db,
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
|
||||
|
||||
133
backend/services/story_writer/story_project_service.py
Normal file
133
backend/services/story_writer/story_project_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Story Project Service
|
||||
|
||||
Service layer for managing Story Studio project persistence.
|
||||
Modeled after PodcastService for a consistent project API.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.story_project_models import StoryProject
|
||||
|
||||
|
||||
class StoryProjectService:
|
||||
"""Service for managing Story Studio projects."""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def create_project(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
title: Optional[str] = None,
|
||||
story_mode: Optional[str] = None,
|
||||
story_template: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> StoryProject:
|
||||
project = StoryProject(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
story_mode=story_mode,
|
||||
story_template=story_template,
|
||||
status="draft",
|
||||
current_phase="setup",
|
||||
**kwargs,
|
||||
)
|
||||
self.db.add(project)
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def get_project(self, user_id: str, project_id: str) -> Optional[StoryProject]:
|
||||
return (
|
||||
self.db.query(StoryProject)
|
||||
.filter(
|
||||
and_(
|
||||
StoryProject.project_id == project_id,
|
||||
StoryProject.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def update_project(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
**updates: Any,
|
||||
) -> Optional[StoryProject]:
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return None
|
||||
|
||||
for key, value in updates.items():
|
||||
if hasattr(project, key):
|
||||
setattr(project, key, value)
|
||||
|
||||
project.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def list_projects(
|
||||
self,
|
||||
user_id: str,
|
||||
status: Optional[str] = None,
|
||||
favorites_only: bool = False,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
order_by: str = "updated_at",
|
||||
) -> Tuple[List[StoryProject], int]:
|
||||
query = self.db.query(StoryProject).filter(StoryProject.user_id == user_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(StoryProject.status == status)
|
||||
|
||||
if favorites_only:
|
||||
query = query.filter(StoryProject.is_favorite.is_(True))
|
||||
|
||||
total = query.count()
|
||||
|
||||
if order_by == "created_at":
|
||||
query = query.order_by(desc(StoryProject.created_at))
|
||||
else:
|
||||
query = query.order_by(desc(StoryProject.updated_at))
|
||||
|
||||
projects = query.offset(offset).limit(limit).all()
|
||||
|
||||
return projects, total
|
||||
|
||||
def delete_project(self, user_id: str, project_id: str) -> bool:
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return False
|
||||
|
||||
self.db.delete(project)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
def toggle_favorite(self, user_id: str, project_id: str) -> Optional[StoryProject]:
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return None
|
||||
|
||||
project.is_favorite = not project.is_favorite
|
||||
project.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
status: str,
|
||||
) -> Optional[StoryProject]:
|
||||
return self.update_project(user_id, project_id, status=status)
|
||||
|
||||
@@ -149,7 +149,7 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
try:
|
||||
path = request.url.path
|
||||
except Exception:
|
||||
pass
|
||||
path = ""
|
||||
|
||||
db = None
|
||||
try:
|
||||
@@ -159,8 +159,16 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
|
||||
# Safe User-Agent access
|
||||
user_agent = None
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
user_agent = request.headers.get('user-agent')
|
||||
except:
|
||||
pass
|
||||
|
||||
# Detect if this is an API call that should be rate limited
|
||||
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
|
||||
api_provider = api_monitor.detect_api_provider(path, user_agent)
|
||||
if not api_provider:
|
||||
return None
|
||||
|
||||
@@ -236,9 +244,28 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
user_id = None
|
||||
try:
|
||||
# PRIORITY 1: Check request.state.user_id (set by API key injection middleware)
|
||||
if hasattr(request.state, 'user_id') and request.state.user_id:
|
||||
user_id = request.state.user_id
|
||||
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
|
||||
if hasattr(request.state, 'user_id'):
|
||||
# Directly check and convert without accessing attribute if None
|
||||
raw_user_id = request.state.user_id
|
||||
|
||||
# Defensive check for Depends object or other complex types
|
||||
if raw_user_id is not None:
|
||||
# If it's a string, use it
|
||||
if isinstance(raw_user_id, str):
|
||||
user_id = raw_user_id
|
||||
# If it has a dependency attribute (likely a Depends object), ignore it
|
||||
elif hasattr(raw_user_id, 'dependency'):
|
||||
logger.warning(f"Monitoring: request.state.user_id is a Depends object, ignoring.")
|
||||
user_id = None
|
||||
# Try to convert to string if it's a simple type
|
||||
else:
|
||||
try:
|
||||
user_id = str(raw_user_id)
|
||||
except:
|
||||
user_id = None
|
||||
|
||||
if user_id:
|
||||
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
|
||||
|
||||
# PRIORITY 2: Check query parameters
|
||||
elif hasattr(request, 'query_params') and 'user_id' in request.query_params:
|
||||
@@ -247,20 +274,23 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
user_id = request.path_params['user_id']
|
||||
|
||||
# PRIORITY 3: Check headers for user identification
|
||||
elif 'x-user-id' in request.headers:
|
||||
user_id = request.headers['x-user-id']
|
||||
elif 'x-user-email' in request.headers:
|
||||
user_id = request.headers['x-user-email'] # Use email as user identifier
|
||||
elif 'x-session-id' in request.headers:
|
||||
user_id = request.headers['x-session-id'] # Use session as fallback
|
||||
|
||||
# Check for authorization header with user info
|
||||
elif 'authorization' in request.headers:
|
||||
# Auth middleware should have set request.state.user_id
|
||||
# If not, this indicates an authentication failure (likely expired token)
|
||||
# Log at debug level to reduce noise - expired tokens are expected
|
||||
# But we can try to decode token if we really needed to, but let's rely on auth middleware
|
||||
pass
|
||||
elif hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
try:
|
||||
if request.headers.get('x-user-id'):
|
||||
user_id = request.headers.get('x-user-id')
|
||||
elif request.headers.get('x-user-email'):
|
||||
user_id = request.headers.get('x-user-email')
|
||||
elif request.headers.get('x-session-id'):
|
||||
user_id = request.headers.get('x-session-id')
|
||||
|
||||
# Check for authorization header with user info
|
||||
elif request.headers.get('authorization'):
|
||||
# Auth middleware should have set request.state.user_id
|
||||
# If not, this indicates an authentication failure (likely expired token)
|
||||
# Log at debug level to reduce noise - expired tokens are expected
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Error accessing request headers: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting user ID: {e}")
|
||||
@@ -269,7 +299,11 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
# Get database session if user identified
|
||||
db = None
|
||||
if user_id:
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database session for user {user_id}: {e}")
|
||||
db = None
|
||||
|
||||
# Capture request body for usage tracking (read once, safely)
|
||||
request_body = None
|
||||
@@ -291,29 +325,52 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
request_body = None
|
||||
|
||||
# Check usage limits before processing
|
||||
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
|
||||
if limit_response:
|
||||
if db: db.close()
|
||||
return limit_response
|
||||
# Skip for OPTIONS requests
|
||||
try:
|
||||
if request.method != "OPTIONS":
|
||||
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
|
||||
if limit_response:
|
||||
if db: db.close()
|
||||
return limit_response
|
||||
except Exception as e:
|
||||
logger.error(f"Error in usage limits middleware: {e}")
|
||||
# Continue processing if usage check fails (fail open)
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Capture response body for usage tracking
|
||||
# Extract response body safely for usage tracking
|
||||
response_body = None
|
||||
try:
|
||||
if hasattr(response, 'body'):
|
||||
response_body = response.body.decode('utf-8') if response.body else None
|
||||
elif hasattr(response, '_content'):
|
||||
response_body = response._content.decode('utf-8') if response._content else None
|
||||
except:
|
||||
pass
|
||||
|
||||
if hasattr(response, 'body'):
|
||||
response_body = response.body.decode('utf-8') if response.body else None
|
||||
elif hasattr(response, '_content'):
|
||||
response_body = response._content.decode('utf-8') if response._content else None
|
||||
|
||||
# Track API usage if this is an API call to external providers
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
|
||||
|
||||
# Safe URL path access
|
||||
try:
|
||||
path = request.url.path
|
||||
except:
|
||||
path = ""
|
||||
|
||||
# Safe User-Agent access - handle case where headers might be a Depends object
|
||||
user_agent = None
|
||||
try:
|
||||
# Defensive check: ensure request.headers is a valid headers object
|
||||
# Some dependency injection failures replace request attributes with Depends objects
|
||||
if hasattr(request, 'headers'):
|
||||
headers_obj = request.headers
|
||||
# Check if it has a 'get' method (like a dict or Headers object)
|
||||
if hasattr(headers_obj, 'get') and callable(headers_obj.get):
|
||||
user_agent = headers_obj.get('user-agent')
|
||||
except:
|
||||
pass
|
||||
|
||||
api_provider = api_monitor.detect_api_provider(path, user_agent)
|
||||
if api_provider and user_id:
|
||||
logger.info(f"Detected API call: {request.url.path} -> {api_provider.value} for user: {user_id}")
|
||||
try:
|
||||
@@ -326,7 +383,7 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
await usage_service.track_api_usage(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=request.url.path,
|
||||
endpoint=path,
|
||||
method=request.method,
|
||||
model_used=usage_metrics.get('model_used'),
|
||||
tokens_input=usage_metrics.get('tokens_input', 0),
|
||||
@@ -335,7 +392,7 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
status_code=status_code,
|
||||
request_size=len(request_body) if request_body else None,
|
||||
response_size=len(response_body) if response_body else None,
|
||||
user_agent=request.headers.get('user-agent'),
|
||||
user_agent=user_agent,
|
||||
ip_address=request.client.host if request.client else None,
|
||||
search_count=usage_metrics.get('search_count', 0),
|
||||
image_count=usage_metrics.get('image_count', 0),
|
||||
|
||||
487
backend/services/subscription/stripe_service.py
Normal file
487
backend/services/subscription/stripe_service.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import os
|
||||
import stripe
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from datetime import datetime
|
||||
|
||||
STRIPE_PLAN_PRICE_MAPPING = {
|
||||
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value): "price_1T2lWHR2EuR7zQJepLIVQ1EJ",
|
||||
(SubscriptionTier.PRO.value, BillingCycle.MONTHLY.value): "price_1T2ljDR2EuR7zQJeuS317KCj",
|
||||
}
|
||||
|
||||
STRIPE_PRICE_TO_PLAN = {
|
||||
price_id: {"tier": SubscriptionTier(tier), "billing_cycle": BillingCycle(billing_cycle)}
|
||||
for (tier, billing_cycle), price_id in STRIPE_PLAN_PRICE_MAPPING.items()
|
||||
}
|
||||
|
||||
class StripeService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.api_key = os.getenv("STRIPE_SECRET_KEY")
|
||||
self.webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")
|
||||
if not self.api_key:
|
||||
logger.warning("STRIPE_SECRET_KEY is not set. Stripe integration will not work.")
|
||||
else:
|
||||
stripe.api_key = self.api_key
|
||||
|
||||
def _get_price_id_for_plan(self, tier: SubscriptionTier, billing_cycle: BillingCycle) -> str:
|
||||
key = (tier.value, billing_cycle.value)
|
||||
price_id = STRIPE_PLAN_PRICE_MAPPING.get(key)
|
||||
if not price_id:
|
||||
logger.error(f"No Stripe price configured for tier={tier.value}, billing_cycle={billing_cycle.value}")
|
||||
raise HTTPException(status_code=400, detail="Payment plan is not configured")
|
||||
return price_id
|
||||
|
||||
def _get_plan_for_price_id(self, price_id: str) -> tuple[SubscriptionPlan, BillingCycle]:
|
||||
mapping = STRIPE_PRICE_TO_PLAN.get(price_id)
|
||||
if not mapping:
|
||||
logger.error(f"Unknown Stripe price_id: {price_id}")
|
||||
raise HTTPException(status_code=400, detail="Unknown payment price configuration")
|
||||
tier = mapping["tier"]
|
||||
billing_cycle = mapping["billing_cycle"]
|
||||
plan = (
|
||||
self.db.query(SubscriptionPlan)
|
||||
.filter(SubscriptionPlan.tier == tier, SubscriptionPlan.is_active == True)
|
||||
.order_by(SubscriptionPlan.price_monthly)
|
||||
.first()
|
||||
)
|
||||
if not plan:
|
||||
logger.error(f"No subscription plan found for tier={tier.value}")
|
||||
raise HTTPException(status_code=400, detail="Subscription plan not found for payment price")
|
||||
return plan, billing_cycle
|
||||
|
||||
def _get_or_create_customer(self, user_id: str, email: Optional[str] = None) -> str:
|
||||
"""
|
||||
Get existing Stripe customer ID for user, or create a new one.
|
||||
"""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id
|
||||
).first()
|
||||
|
||||
if subscription and subscription.stripe_customer_id:
|
||||
return subscription.stripe_customer_id
|
||||
|
||||
# Search Stripe for existing customer by email (if provided) or metadata
|
||||
try:
|
||||
# If we have an email, search by email first
|
||||
if email:
|
||||
existing_customers = stripe.Customer.list(email=email, limit=1)
|
||||
if existing_customers and len(existing_customers.data) > 0:
|
||||
customer = existing_customers.data[0]
|
||||
# Update DB
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer.id
|
||||
self.db.commit()
|
||||
return customer.id
|
||||
|
||||
# Search by metadata user_id
|
||||
existing_customers = stripe.Customer.search(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
limit=1
|
||||
)
|
||||
if existing_customers and len(existing_customers.data) > 0:
|
||||
customer = existing_customers.data[0]
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer.id
|
||||
self.db.commit()
|
||||
return customer.id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching Stripe customer: {e}")
|
||||
|
||||
# Create new customer
|
||||
try:
|
||||
customer_data = {
|
||||
"metadata": {"user_id": user_id},
|
||||
}
|
||||
if email:
|
||||
customer_data["email"] = email
|
||||
|
||||
customer = stripe.Customer.create(**customer_data)
|
||||
|
||||
# Update DB
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer.id
|
||||
else:
|
||||
# Create a placeholder subscription record if none exists (usually created on signup/free tier)
|
||||
# But typically we expect a free tier record to exist.
|
||||
pass
|
||||
|
||||
self.db.commit()
|
||||
return customer.id
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Stripe customer: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create payment profile")
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: SubscriptionTier,
|
||||
billing_cycle: BillingCycle,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
user_email: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Stripe Checkout Session for a subscription.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise HTTPException(status_code=500, detail="Payment service not configured")
|
||||
|
||||
price_id = self._get_price_id_for_plan(tier, billing_cycle)
|
||||
customer_id = self._get_or_create_customer(user_id, user_email)
|
||||
|
||||
line_item: Dict[str, Any] = {"price": price_id}
|
||||
try:
|
||||
price = stripe.Price.retrieve(price_id)
|
||||
recurring = getattr(price, "recurring", None)
|
||||
usage_type = None
|
||||
if recurring:
|
||||
if isinstance(recurring, dict):
|
||||
usage_type = recurring.get("usage_type")
|
||||
else:
|
||||
usage_type = getattr(recurring, "usage_type", None)
|
||||
if usage_type != "metered":
|
||||
line_item["quantity"] = 1
|
||||
else:
|
||||
logger.info(f"Detected metered price {price_id}; omitting quantity in Checkout line item")
|
||||
except Exception as e:
|
||||
logger.error(f"Error inspecting Stripe price {price_id}: {e}")
|
||||
line_item["quantity"] = 1
|
||||
|
||||
try:
|
||||
checkout_session = stripe.checkout.Session.create(
|
||||
customer=customer_id,
|
||||
payment_method_types=["card"],
|
||||
line_items=[line_item],
|
||||
mode="subscription",
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
metadata={
|
||||
"user_id": user_id,
|
||||
"price_id": price_id,
|
||||
},
|
||||
subscription_data={
|
||||
"metadata": {
|
||||
"user_id": user_id,
|
||||
}
|
||||
},
|
||||
allow_promotion_codes=True,
|
||||
)
|
||||
return checkout_session.url
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating checkout session: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def create_portal_session(self, user_id: str, return_url: str) -> str:
|
||||
"""
|
||||
Create a Stripe Customer Portal session for managing billing.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise HTTPException(status_code=500, detail="Payment service not configured")
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not subscription or not subscription.stripe_customer_id:
|
||||
# Try to find customer by user_id if not in DB
|
||||
try:
|
||||
customers = stripe.Customer.search(query=f"metadata['user_id']:'{user_id}'", limit=1)
|
||||
if customers and len(customers.data) > 0:
|
||||
customer_id = customers.data[0].id
|
||||
# Update DB while we're at it
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer_id
|
||||
self.db.commit()
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="No billing profile found for this user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding customer for portal: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to access billing portal")
|
||||
else:
|
||||
customer_id = subscription.stripe_customer_id
|
||||
|
||||
try:
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
return portal_session.url
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating portal session: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def handle_webhook(self, payload: bytes, sig_header: str):
|
||||
"""
|
||||
Handle Stripe webhooks.
|
||||
"""
|
||||
if not self.webhook_secret:
|
||||
logger.warning("STRIPE_WEBHOOK_SECRET not set. Ignoring webhook.")
|
||||
return
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, self.webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid payload: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
logger.error(f"Invalid signature: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
event_type = event["type"]
|
||||
data = event["data"]["object"]
|
||||
|
||||
logger.info(f"Received Stripe webhook: {event_type}")
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
await self._handle_checkout_completed(data)
|
||||
elif event_type == "invoice.payment_succeeded":
|
||||
await self._handle_invoice_payment_succeeded(data)
|
||||
elif event_type == "invoice.payment_failed":
|
||||
await self._handle_invoice_payment_failed(data)
|
||||
elif event_type == "customer.subscription.updated":
|
||||
await self._handle_subscription_updated(data)
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
await self._handle_subscription_deleted(data)
|
||||
elif event_type.startswith("radar.early_fraud_warning."):
|
||||
await self._handle_early_fraud_warning(data)
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
async def _handle_checkout_completed(self, session: Dict[str, Any]):
|
||||
"""
|
||||
Handle successful checkout.
|
||||
"""
|
||||
user_id = session.get("metadata", {}).get("user_id")
|
||||
customer_id = session.get("customer")
|
||||
subscription_id = session.get("subscription")
|
||||
|
||||
if not user_id:
|
||||
logger.error("No user_id in checkout session metadata")
|
||||
return
|
||||
|
||||
logger.info(f"Checkout completed for user {user_id}")
|
||||
|
||||
# Retrieve subscription details to get the plan/price
|
||||
if subscription_id:
|
||||
try:
|
||||
sub = stripe.Subscription.retrieve(subscription_id)
|
||||
price_id = sub['items']['data'][0]['price']['id']
|
||||
# Map price_id to internal plan_id
|
||||
# Note: You need a way to map Stripe Price IDs to your Plan IDs.
|
||||
# For now, we'll assume the metadata or a lookup.
|
||||
# Ideally, store price_id in SubscriptionPlan table or config.
|
||||
|
||||
# Update DB
|
||||
self._update_user_subscription(
|
||||
user_id,
|
||||
stripe_customer_id=customer_id,
|
||||
stripe_subscription_id=subscription_id,
|
||||
status="active",
|
||||
price_id=price_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkout subscription: {e}")
|
||||
|
||||
async def _handle_invoice_payment_succeeded(self, invoice: Dict[str, Any]):
|
||||
"""
|
||||
Handle recurring payment success.
|
||||
"""
|
||||
subscription_id = invoice.get("subscription")
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if not subscription_id:
|
||||
return
|
||||
|
||||
# Find user by stripe_subscription_id or customer_id
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
(UserSubscription.stripe_subscription_id == subscription_id) |
|
||||
(UserSubscription.stripe_customer_id == customer_id)
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.info(f"Payment succeeded for user {subscription.user_id}")
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
# Update period end based on invoice lines period
|
||||
if invoice.get('lines'):
|
||||
period_end = invoice['lines']['data'][0]['period']['end']
|
||||
subscription.current_period_end = datetime.fromtimestamp(period_end)
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
|
||||
subscription_id = invoice.get("subscription")
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if not subscription_id:
|
||||
return
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
(UserSubscription.stripe_subscription_id == subscription_id) |
|
||||
(UserSubscription.stripe_customer_id == customer_id)
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.warning(f"Payment failed for user {subscription.user_id}")
|
||||
subscription.status = UsageStatus.PAST_DUE
|
||||
subscription.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_subscription_updated(self, subscription_obj: Dict[str, Any]):
|
||||
"""
|
||||
Handle subscription updates (cancellations, changes).
|
||||
"""
|
||||
stripe_sub_id = subscription_obj.get("id")
|
||||
status = subscription_obj.get("status")
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.stripe_subscription_id == stripe_sub_id
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.info(f"Subscription {stripe_sub_id} updated to {status}")
|
||||
if status in ["active", "trialing"]:
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
elif status in ["past_due", "unpaid", "incomplete", "incomplete_expired"]:
|
||||
subscription.status = UsageStatus.PAST_DUE
|
||||
subscription.is_active = False
|
||||
elif status in ["canceled"]:
|
||||
subscription.status = UsageStatus.CANCELLED
|
||||
subscription.is_active = False
|
||||
subscription.auto_renew = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
|
||||
"""
|
||||
Handle subscription cancellation (immediate).
|
||||
"""
|
||||
stripe_sub_id = subscription_obj.get("id")
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.stripe_subscription_id == stripe_sub_id
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.info(f"Subscription {stripe_sub_id} deleted")
|
||||
subscription.status = UsageStatus.CANCELLED # Need to check if this enum value exists
|
||||
subscription.is_active = False
|
||||
subscription.auto_renew = False
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_early_fraud_warning(self, warning_obj: Dict[str, Any]):
|
||||
efw_id = warning_obj.get("id")
|
||||
if not efw_id:
|
||||
return
|
||||
|
||||
charge_id = warning_obj.get("charge")
|
||||
payment_intent_id = warning_obj.get("payment_intent")
|
||||
created_ts = warning_obj.get("created")
|
||||
created_at = datetime.utcfromtimestamp(created_ts) if created_ts else datetime.utcnow()
|
||||
|
||||
amount = 0
|
||||
currency = ""
|
||||
user_id = None
|
||||
charge_data: Dict[str, Any] = {}
|
||||
|
||||
if charge_id and self.api_key:
|
||||
try:
|
||||
charge = stripe.Charge.retrieve(charge_id)
|
||||
charge_data = charge.to_dict() if hasattr(charge, "to_dict") else dict(charge)
|
||||
amount = charge_data.get("amount") or 0
|
||||
currency = charge_data.get("currency") or ""
|
||||
metadata = charge_data.get("metadata") or {}
|
||||
user_id = metadata.get("user_id")
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving charge for early fraud warning {efw_id}: {e}")
|
||||
|
||||
if not amount:
|
||||
amount = warning_obj.get("amount") or 0
|
||||
if not currency:
|
||||
currency = warning_obj.get("currency") or ""
|
||||
|
||||
existing = self.db.query(FraudWarning).filter(FraudWarning.id == efw_id).first()
|
||||
|
||||
metadata_payload: Dict[str, Any] = {
|
||||
"early_fraud_warning": warning_obj,
|
||||
}
|
||||
if charge_data:
|
||||
metadata_payload["charge"] = charge_data
|
||||
|
||||
if existing:
|
||||
existing.charge_id = charge_id or existing.charge_id
|
||||
existing.payment_intent_id = payment_intent_id or existing.payment_intent_id
|
||||
if user_id:
|
||||
existing.user_id = user_id
|
||||
if amount:
|
||||
existing.amount = amount
|
||||
if currency:
|
||||
existing.currency = currency
|
||||
existing.status = "open"
|
||||
existing.meta_info = metadata_payload
|
||||
else:
|
||||
if not charge_id:
|
||||
return
|
||||
warning = FraudWarning(
|
||||
id=efw_id,
|
||||
charge_id=charge_id,
|
||||
payment_intent_id=payment_intent_id,
|
||||
user_id=user_id,
|
||||
amount=amount or 0,
|
||||
currency=currency or "",
|
||||
status="open",
|
||||
action="none",
|
||||
meta_info=metadata_payload,
|
||||
created_at=created_at,
|
||||
)
|
||||
self.db.add(warning)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def _update_user_subscription(
|
||||
self,
|
||||
user_id: str,
|
||||
stripe_customer_id: str,
|
||||
stripe_subscription_id: str,
|
||||
status: str,
|
||||
price_id: str,
|
||||
):
|
||||
plan, billing_cycle = self._get_plan_for_price_id(price_id)
|
||||
|
||||
subscription = (
|
||||
self.db.query(UserSubscription)
|
||||
.filter(UserSubscription.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
if not subscription:
|
||||
subscription = UserSubscription(
|
||||
user_id=user_id,
|
||||
plan_id=plan.id,
|
||||
billing_cycle=billing_cycle,
|
||||
current_period_start=now,
|
||||
current_period_end=now,
|
||||
status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
|
||||
is_active=status == "active",
|
||||
auto_renew=True,
|
||||
)
|
||||
self.db.add(subscription)
|
||||
else:
|
||||
subscription.plan_id = plan.id
|
||||
subscription.billing_cycle = billing_cycle
|
||||
subscription.is_active = status == "active"
|
||||
|
||||
subscription.stripe_customer_id = stripe_customer_id
|
||||
subscription.stripe_subscription_id = stripe_subscription_id
|
||||
|
||||
self.db.commit()
|
||||
@@ -39,9 +39,34 @@ def _generate_simple_infinitetalk_prompt(
|
||||
# Build a balanced prompt: scene description + simple motion hint
|
||||
parts = []
|
||||
|
||||
# Start with the main subject/scene
|
||||
# Add scene context
|
||||
if title and len(title) > 5 and title.lower() not in ("scene", "podcast", "episode"):
|
||||
parts.append(title)
|
||||
|
||||
# Add analysis context
|
||||
analysis = story_context.get("analysis", {})
|
||||
if analysis:
|
||||
content_type = analysis.get("content_type")
|
||||
if content_type:
|
||||
parts.append(f"Style: {content_type}")
|
||||
|
||||
# Audience helps define the formality/vibe
|
||||
audience = analysis.get("audience")
|
||||
if audience:
|
||||
# Just use first few words of audience to keep it short
|
||||
short_audience = " ".join(audience.split()[:3])
|
||||
parts.append(f"For: {short_audience}")
|
||||
|
||||
# Add bible context if available
|
||||
bible = story_context.get("bible", {})
|
||||
if bible:
|
||||
host_persona = bible.get("host_persona")
|
||||
tone = bible.get("tone")
|
||||
if host_persona:
|
||||
parts.append(f"Host: {host_persona}")
|
||||
if tone:
|
||||
parts.append(f"Tone: {tone}")
|
||||
|
||||
elif description:
|
||||
# Take first sentence or first 60 chars
|
||||
desc_part = description.split('.')[0][:60].strip()
|
||||
|
||||
@@ -52,6 +52,46 @@ def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str,
|
||||
image_prompt = (scene_data.get("image_prompt") or "").strip()
|
||||
tone = (story_context.get("story_tone") or "story").strip()
|
||||
setting = (story_context.get("story_setting") or "the scene").strip()
|
||||
anime_bible = story_context.get("anime_bible") or {}
|
||||
|
||||
anime_style_parts = []
|
||||
if isinstance(anime_bible, dict):
|
||||
visual_style = anime_bible.get("visual_style") or {}
|
||||
world = anime_bible.get("world") or {}
|
||||
main_cast = anime_bible.get("main_cast") or []
|
||||
|
||||
style_preset = visual_style.get("style_preset")
|
||||
camera_style = visual_style.get("camera_style")
|
||||
color_mood = visual_style.get("color_mood")
|
||||
lighting = visual_style.get("lighting")
|
||||
line_style = visual_style.get("line_style")
|
||||
extra_tags = visual_style.get("extra_tags") or []
|
||||
|
||||
if style_preset:
|
||||
anime_style_parts.append(f"Follow {style_preset} anime visual style.")
|
||||
if camera_style:
|
||||
anime_style_parts.append(f"Use camera style: {camera_style}.")
|
||||
if color_mood:
|
||||
anime_style_parts.append(f"Color mood: {color_mood}.")
|
||||
if lighting:
|
||||
anime_style_parts.append(f"Lighting: {lighting}.")
|
||||
if line_style:
|
||||
anime_style_parts.append(f"Line art: {line_style}.")
|
||||
if extra_tags:
|
||||
anime_style_parts.append("Style tags: " + ", ".join(str(tag) for tag in extra_tags[:6]))
|
||||
|
||||
if world:
|
||||
setting_desc = world.get("setting")
|
||||
if setting_desc:
|
||||
anime_style_parts.append(f"World context: {setting_desc}.")
|
||||
|
||||
if main_cast:
|
||||
names = [c.get("name") for c in main_cast if isinstance(c, dict) and c.get("name")]
|
||||
if names:
|
||||
joined = ", ".join(names[:4])
|
||||
anime_style_parts.append(f"Keep character designs consistent for: {joined}.")
|
||||
|
||||
anime_style_text = " ".join(anime_style_parts).strip()
|
||||
|
||||
parts = [
|
||||
f"{title} cinematic motion shot.",
|
||||
@@ -60,6 +100,7 @@ def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str,
|
||||
f"Maintain a {tone} mood with natural lighting accents.",
|
||||
f"Honor the original illustration details: {image_prompt[:200]}." if image_prompt else "",
|
||||
"5-second sequence, gentle push-in, flowing cloth and atmospheric particles.",
|
||||
anime_style_text,
|
||||
]
|
||||
fallback_prompt = " ".join(filter(None, parts))
|
||||
return fallback_prompt.strip()
|
||||
@@ -142,6 +183,66 @@ def generate_animation_prompt(
|
||||
title = scene_data.get("title", "")
|
||||
tone = story_context.get("story_tone") or story_context.get("story_tone", "")
|
||||
setting = story_context.get("story_setting") or story_context.get("story_setting", "")
|
||||
anime_bible = story_context.get("anime_bible") or {}
|
||||
|
||||
anime_bible_block = ""
|
||||
if isinstance(anime_bible, dict) and anime_bible:
|
||||
try:
|
||||
visual_style = anime_bible.get("visual_style") or {}
|
||||
world = anime_bible.get("world") or {}
|
||||
main_cast = anime_bible.get("main_cast") or []
|
||||
|
||||
style_lines = []
|
||||
if visual_style:
|
||||
style_preset = visual_style.get("style_preset")
|
||||
camera_style = visual_style.get("camera_style")
|
||||
color_mood = visual_style.get("color_mood")
|
||||
lighting = visual_style.get("lighting")
|
||||
line_style = visual_style.get("line_style")
|
||||
extra_tags = visual_style.get("extra_tags") or []
|
||||
|
||||
if style_preset:
|
||||
style_lines.append(f"- Visual style preset: {style_preset}")
|
||||
if camera_style:
|
||||
style_lines.append(f"- Preferred camera style: {camera_style}")
|
||||
if color_mood:
|
||||
style_lines.append(f"- Color mood: {color_mood}")
|
||||
if lighting:
|
||||
style_lines.append(f"- Lighting: {lighting}")
|
||||
if line_style:
|
||||
style_lines.append(f"- Line art style: {line_style}")
|
||||
if extra_tags:
|
||||
style_lines.append(
|
||||
"- Extra style tags: " + ", ".join(str(tag) for tag in extra_tags[:6])
|
||||
)
|
||||
|
||||
cast_line = ""
|
||||
if main_cast:
|
||||
names = [c.get("name") for c in main_cast if isinstance(c, dict) and c.get("name")]
|
||||
if names:
|
||||
cast_line = "- Main cast to keep visually consistent: " + ", ".join(names[:4])
|
||||
|
||||
world_line = ""
|
||||
if world:
|
||||
setting_desc = world.get("setting")
|
||||
if setting_desc:
|
||||
world_line = "- World/setting context: " + str(setting_desc)
|
||||
|
||||
detail_lines = []
|
||||
if cast_line:
|
||||
detail_lines.append(cast_line)
|
||||
if world_line:
|
||||
detail_lines.append(world_line)
|
||||
detail_lines.extend(style_lines)
|
||||
|
||||
if detail_lines:
|
||||
anime_bible_block = (
|
||||
"\nANIME STORY BIBLE VISUAL GUIDANCE:\n"
|
||||
+ "\n".join(detail_lines)
|
||||
+ "\nAlways respect these constraints in the motion description."
|
||||
)
|
||||
except Exception:
|
||||
anime_bible_block = ""
|
||||
|
||||
prompt = f"""
|
||||
Create a concise animation prompt (2-3 sentences) for a 5-second cinematic clip.
|
||||
@@ -151,6 +252,7 @@ Description: {description}
|
||||
Existing Image Prompt: {image_prompt}
|
||||
Story Tone: {tone}
|
||||
Setting: {setting}
|
||||
{anime_bible_block}
|
||||
|
||||
Focus on:
|
||||
- Motion of characters/objects
|
||||
|
||||
@@ -132,7 +132,19 @@ class YouTubeSceneBuilderService:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate scenes from video plan using AI."""
|
||||
|
||||
content_outline = video_plan.get("content_outline", [])
|
||||
raw_content_outline = video_plan.get("content_outline", [])
|
||||
content_outline: List[Dict[str, Any]] = []
|
||||
for item in raw_content_outline:
|
||||
if isinstance(item, dict):
|
||||
content_outline.append(item)
|
||||
else:
|
||||
content_outline.append(
|
||||
{
|
||||
"section": str(item),
|
||||
"description": "",
|
||||
"duration_estimate": 0,
|
||||
}
|
||||
)
|
||||
hook_strategy = video_plan.get("hook_strategy", "")
|
||||
call_to_action = video_plan.get("call_to_action", "")
|
||||
visual_style = video_plan.get("visual_style", "cinematic")
|
||||
@@ -263,16 +275,32 @@ Write narration that:
|
||||
# Normalize scene data
|
||||
normalized_scenes = []
|
||||
for idx, scene in enumerate(scenes, 1):
|
||||
normalized_scenes.append({
|
||||
"scene_number": scene.get("scene_number", idx),
|
||||
"title": scene.get("title", f"Scene {idx}"),
|
||||
"narration": scene.get("narration", ""),
|
||||
"visual_description": scene.get("visual_description", ""),
|
||||
"duration_estimate": scene.get("duration_estimate", scene_duration_range[0]),
|
||||
"emphasis": scene.get("emphasis", "main_content"),
|
||||
"visual_cues": scene.get("visual_cues", []),
|
||||
"visual_prompt": scene.get("visual_description", ""), # Initial prompt
|
||||
})
|
||||
if isinstance(scene, dict):
|
||||
scene_data = scene
|
||||
else:
|
||||
scene_data = {
|
||||
"scene_number": idx,
|
||||
"title": f"Scene {idx}",
|
||||
"narration": str(scene),
|
||||
"visual_description": "",
|
||||
"duration_estimate": scene_duration_range[0],
|
||||
"emphasis": "main_content",
|
||||
"visual_cues": [],
|
||||
}
|
||||
normalized_scenes.append(
|
||||
{
|
||||
"scene_number": scene_data.get("scene_number", idx),
|
||||
"title": scene_data.get("title", f"Scene {idx}"),
|
||||
"narration": scene_data.get("narration", ""),
|
||||
"visual_description": scene_data.get("visual_description", ""),
|
||||
"duration_estimate": scene_data.get(
|
||||
"duration_estimate", scene_duration_range[0]
|
||||
),
|
||||
"emphasis": scene_data.get("emphasis", "main_content"),
|
||||
"visual_cues": scene_data.get("visual_cues", []),
|
||||
"visual_prompt": scene_data.get("visual_description", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return normalized_scenes
|
||||
|
||||
@@ -287,16 +315,32 @@ Write narration that:
|
||||
|
||||
normalized_scenes = []
|
||||
for idx, scene in enumerate(scenes, 1):
|
||||
normalized_scenes.append({
|
||||
"scene_number": scene.get("scene_number", idx),
|
||||
"title": scene.get("title", f"Scene {idx}"),
|
||||
"narration": scene.get("narration", ""),
|
||||
"visual_description": scene.get("visual_description", ""),
|
||||
"duration_estimate": scene.get("duration_estimate", scene_duration_range[0]),
|
||||
"emphasis": scene.get("emphasis", "main_content"),
|
||||
"visual_cues": scene.get("visual_cues", []),
|
||||
"visual_prompt": scene.get("visual_description", ""), # Initial prompt
|
||||
})
|
||||
if isinstance(scene, dict):
|
||||
scene_data = scene
|
||||
else:
|
||||
scene_data = {
|
||||
"scene_number": idx,
|
||||
"title": f"Scene {idx}",
|
||||
"narration": str(scene),
|
||||
"visual_description": "",
|
||||
"duration_estimate": scene_duration_range[0],
|
||||
"emphasis": "main_content",
|
||||
"visual_cues": [],
|
||||
}
|
||||
normalized_scenes.append(
|
||||
{
|
||||
"scene_number": scene_data.get("scene_number", idx),
|
||||
"title": scene_data.get("title", f"Scene {idx}"),
|
||||
"narration": scene_data.get("narration", ""),
|
||||
"visual_description": scene_data.get("visual_description", ""),
|
||||
"duration_estimate": scene_data.get(
|
||||
"duration_estimate", scene_duration_range[0]
|
||||
),
|
||||
"emphasis": scene_data.get("emphasis", "main_content"),
|
||||
"visual_cues": scene_data.get("visual_cues", []),
|
||||
"visual_prompt": scene_data.get("visual_description", ""),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[YouTubeSceneBuilder] ✅ Normalized {len(normalized_scenes)} scenes "
|
||||
|
||||
@@ -6,10 +6,13 @@ Promotes reuse between Podcast, YouTube, and other media-heavy modules.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from services.database import WORKSPACE_DIR
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -58,6 +61,23 @@ def resolve_media_path(media_url_or_path: str) -> Optional[Path]:
|
||||
if not filename:
|
||||
return None
|
||||
|
||||
# Handle workspace avatar assets: /api/assets/{user_id}/avatars/{filename}
|
||||
if "/api/assets/" in media_url_or_path and "/avatars/" in media_url_or_path:
|
||||
try:
|
||||
parsed_path = urlparse(media_url_or_path).path
|
||||
parts = parsed_path.split("/")
|
||||
if len(parts) >= 6:
|
||||
user_id = parts[3]
|
||||
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ("-", "_"))
|
||||
if safe_user_id == user_id:
|
||||
safe_filename = os.path.basename(filename)
|
||||
assets_path = Path(WORKSPACE_DIR) / f"workspace_{safe_user_id}" / "assets" / "avatars" / safe_filename
|
||||
if assets_path.exists() and assets_path.is_file():
|
||||
logger.debug(f"[MediaUtils] Resolved assets avatar {media_url_or_path} to {assets_path}")
|
||||
return assets_path
|
||||
except Exception as exc:
|
||||
logger.error(f"[MediaUtils] Error resolving assets avatar path: {exc}")
|
||||
|
||||
# Define search paths in order of likelihood
|
||||
# We search all avatar/image directories
|
||||
search_paths: List[Path] = [
|
||||
|
||||
Reference in New Issue
Block a user