Release Candidate: Production Release with Multi-Tenant & Onboarding Enhancements

This commit is contained in:
ajaysi
2026-02-28 20:06:26 +05:30
parent 08a1f4a1d8
commit 4828274cbf
162 changed files with 19489 additions and 4300 deletions

View File

@@ -121,7 +121,8 @@ class BaseALwrityAgent(ABC):
if TXTAI_AVAILABLE:
try:
if not self.llm:
self.llm = LLM(model_name)
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
self.llm = LLM(model_name, task="text-generation")
self.txtai_agent = self._create_txtai_agent()
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")

View File

@@ -4,7 +4,6 @@ Bing Webmaster Tools Analytics Handler
Handles Bing Webmaster Tools analytics data retrieval and processing.
"""
import requests
from typing import Dict, Any
from datetime import datetime, timedelta
from loguru import logger
@@ -16,13 +15,23 @@ from ..models.platform_types import PlatformType
from .base_handler import BaseAnalyticsHandler
from ..insights.bing_insights_service import BingInsightsService
from services.bing_analytics_storage_service import BingAnalyticsStorageService
import os
from services.database import get_user_db_path
class BingAnalyticsHandler(BaseAnalyticsHandler):
"""Handler for Bing Webmaster Tools analytics"""
"""
Handler for Bing Webmaster Tools analytics
NOTE (2026-02-14): Known issues and directions
- Verified sites list can be empty despite valid tokens. This leads to partial/error states and prevents storage collection.
Direction: UI now provides a manual site picker (with primary website fallback from onboarding) to trigger storage collection,
and a future improvement should accept a target_url from /api/analytics/data to influence site selection here.
- Token state mismatch (status shows connected, analytics reports expired) can happen across cache boundaries.
Direction: The frontend auto-resyncs once after OAuth success and provides a backend cache clear endpoint.
- Storage-backed summary reads rely on a selected site; when sites are missing, selected_site is None.
Direction: Allow explicit site_url parameter in the analytics orchestrator to override selected_site resolution.
"""
def __init__(self):
super().__init__(PlatformType.BING)
@@ -42,14 +51,22 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
db_url = f'sqlite:///{db_path}'
return BingInsightsService(db_url)
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
async def get_analytics(self, user_id: str, target_url: str = None, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
"""
Get Bing Webmaster analytics data using Bing Webmaster API
"""
self.log_analytics_request(user_id, "get_analytics")
# Check cache first
cached_data = analytics_cache.get('bing_analytics', user_id)
# Check cache first (include date range and target_url in key)
cache_key_parts = [user_id]
if target_url:
cache_key_parts.append(str(target_url))
if start_date:
cache_key_parts.append(str(start_date))
if end_date:
cache_key_parts.append(str(end_date))
cache_key = "_".join(cache_key_parts)
cached_data = analytics_cache.get('bing_analytics', cache_key)
if cached_data:
logger.info(f"Using cached Bing analytics for user {user_id}")
return AnalyticsData(**cached_data)
@@ -107,9 +124,22 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
site_url_for_storage = selected_site.get('Url', '') if selected_site else ''
logger.info(f"Using Bing site URL: {site_url_for_storage}")
# Determine date range (defaults to last 30 days)
if not end_date:
end_date = datetime.now().strftime('%Y-%m-%d')
if not start_date:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
# Compute days for storage/insights services (at least 1)
try:
dt_end = datetime.strptime(end_date, '%Y-%m-%d')
dt_start = datetime.strptime(start_date, '%Y-%m-%d')
days_range = max(1, (dt_end - dt_start).days + 1)
except Exception:
days_range = 30
query_stats = {}
try:
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=30)
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=days_range)
if stored and isinstance(stored, dict):
query_stats = {
'total_clicks': stored.get('summary', {}).get('total_clicks', 0),
@@ -138,19 +168,20 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
'insights': insights,
'note': 'Bing Webmaster API provides SEO insights, search performance, and index status data'
}
if (not sites) or (metrics.get('total_impressions', 0) == 0 and metrics.get('total_clicks', 0) == 0):
result = self.create_partial_response(metrics=metrics, error_message='Connected to Bing; waiting for stored analytics or site verification')
if not sites:
result = self.create_partial_response(metrics=metrics, error_message='Connected to Bing; no verified sites found')
else:
result = self.create_success_response(metrics=metrics)
result = self.create_success_response(metrics=metrics, date_range={'start': start_date, 'end': end_date})
analytics_cache.set('bing_analytics', user_id, result.__dict__)
analytics_cache.set('bing_analytics', cache_key, result.__dict__)
return result
except Exception as e:
self.log_analytics_error(user_id, "get_analytics", e)
error_result = self.create_error_response(str(e))
analytics_cache.set('bing_analytics', user_id, error_result.__dict__, ttl_override=300)
# Cache error briefly to prevent hammering but recover quickly
analytics_cache.set('bing_analytics', cache_key, error_result.__dict__, ttl_override=30)
return error_result
def _get_enhanced_insights_with_service(self, insights_service: BingInsightsService, user_id: str, site_url: str) -> Dict[str, Any]:

View File

@@ -22,7 +22,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
super().__init__(PlatformType.GSC)
self.gsc_service = GSCService()
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
async def get_analytics(self, user_id: str, target_url: str = None, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
"""
Get Google Search Console analytics data with caching
@@ -35,8 +35,16 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
self.log_analytics_request(user_id, "get_analytics")
# Check cache first - GSC API calls can be expensive
# Include target_url in cache key if provided
cache_key = f"{user_id}_{target_url}" if target_url else user_id
# Include target_url and date range in cache key if provided
cache_key_parts = [user_id]
if target_url:
cache_key_parts.append(str(target_url))
if start_date:
cache_key_parts.append(str(start_date))
if end_date:
cache_key_parts.append(str(end_date))
# Bump cache version to include page insights (v2)
cache_key = "_".join(cache_key_parts + ['v2pages'])
cached_data = analytics_cache.get('gsc_analytics', cache_key)
if cached_data:
logger.info("Using cached GSC analytics for user {user_id}", user_id=user_id)
@@ -70,9 +78,11 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
site_url = selected_site['siteUrl']
logger.info(f"Using GSC site URL: {site_url}")
# Get search analytics for last 30 days
end_date = datetime.now().strftime('%Y-%m-%d')
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
# Determine date range (defaults to last 30 days)
if not end_date:
end_date = datetime.now().strftime('%Y-%m-%d')
if not start_date:
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
logger.info(f"GSC Date range: {start_date} to {end_date}")
search_analytics = self.gsc_service.get_search_analytics(
@@ -86,10 +96,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
# Process GSC data into standardized format
processed_metrics = self._process_gsc_metrics(search_analytics)
result = self.create_success_response(
metrics=processed_metrics,
date_range={'start': start_date, 'end': end_date}
)
result = self.create_success_response(metrics=processed_metrics, date_range={'start': start_date, 'end': end_date})
# Cache the result to avoid expensive API calls
analytics_cache.set('gsc_analytics', cache_key, result.__dict__)
@@ -101,8 +108,8 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
self.log_analytics_error(user_id, "get_analytics", e)
error_result = self.create_error_response(str(e))
# Cache error result for shorter time to retry sooner
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=300) # 5 minutes
# Cache error result briefly to avoid repeated failures but allow quick recovery
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=30) # 30 seconds
return error_result
def get_connection_status(self, user_id: str) -> Dict[str, Any]:
@@ -202,18 +209,159 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
sorted_queries = sorted(top_queries_source, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
for row in sorted_queries:
clicks_val = row.get('clicks', 0) or 0
impr_val = row.get('impressions', 0) or 0
raw_ctr = row.get('ctr', None)
# Calculate CTR% robustly even if 'ctr' field is missing in row
if raw_ctr is not None:
ctr_percent = round(float(raw_ctr) * 100, 2)
else:
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
top_queries.append({
'query': self._extract_query_from_row(row),
'clicks': row.get('clicks', 0),
'impressions': row.get('impressions', 0),
'ctr': round(row.get('ctr', 0) * 100, 2),
'position': round(row.get('position', 0), 2)
'clicks': clicks_val,
'impressions': impr_val,
'ctr': ctr_percent,
'position': round(row.get('position', 0) or 0, 2)
})
# Prepare Top Pages (requires page dimension, but we only requested query dimension in gsc_service step 3)
# To get top pages, we would need another API call with dimension=['page']
# For now, we'll return empty top_pages or infer from what we have if possible (we can't from query data)
top_pages = []
# Prepare Top Pages from page_data when available
top_pages = []
try:
page_rows = search_analytics.get('page_data', {}).get('rows', [])
qp_rows = search_analytics.get('query_page_data', {}).get('rows', [])
# Build queries-by-page map
queries_by_page: Dict[str, list] = {}
if qp_rows:
for r in qp_rows:
keys = r.get('keys', [])
if not keys or len(keys) < 2:
continue
query_key = keys[0]['keys'][0] if isinstance(keys[0], dict) else str(keys[0])
page_key = keys[1]['keys'][0] if isinstance(keys[1], dict) else str(keys[1])
clicks_val = r.get('clicks', 0) or 0
impr_val = r.get('impressions', 0) or 0
raw_ctr = r.get('ctr', None)
if raw_ctr is not None:
ctr_percent = round(float(raw_ctr) * 100, 2)
else:
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
lst = queries_by_page.setdefault(page_key, [])
lst.append({
'query': query_key,
'clicks': clicks_val,
'impressions': impr_val,
'ctr': ctr_percent,
})
if page_rows:
sorted_pages = sorted(page_rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
for row in sorted_pages:
clicks_val = row.get('clicks', 0) or 0
impr_val = row.get('impressions', 0) or 0
raw_ctr = row.get('ctr', None)
if raw_ctr is not None:
ctr_percent = round(float(raw_ctr) * 100, 2)
else:
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
page_url = self._extract_page_from_row(row)
# attach top queries pointing to this page, sorted by clicks
page_queries = sorted(queries_by_page.get(page_url, []), key=lambda x: x.get('clicks', 0), reverse=True)[:5]
top_pages.append({
'page': page_url,
'clicks': clicks_val,
'impressions': impr_val,
'ctr': ctr_percent,
'position': round(row.get('position', 0) or 0, 2) if 'position' in row else None,
'queries': page_queries
})
except Exception as e:
logger.warning(f"Failed processing top_pages: {e}")
# Detect Cannibalization (query mapping to multiple pages)
cannibalization = []
try:
qp_rows = search_analytics.get('query_page_data', {}).get('rows', [])
q_rows = search_analytics.get('query_data', {}).get('rows', [])
if qp_rows:
# Determine window days for thresholding
from datetime import datetime
start_s = search_analytics.get('startDate')
end_s = search_analytics.get('endDate')
window_days = 30
try:
if start_s and end_s:
sd = datetime.strptime(start_s, "%Y-%m-%d")
ed = datetime.strptime(end_s, "%Y-%m-%d")
window_days = max((ed - sd).days + 1, 1)
except Exception:
pass
min_clicks = 10 if window_days <= 7 else (30 if window_days <= 30 else 60)
# Build map: query -> { page -> metrics }
by_query: Dict[str, Dict[str, Dict[str, float]]] = {}
for r in qp_rows:
keys = r.get('keys', [])
if not keys or len(keys) < 2:
continue
qk = keys[0]['keys'][0] if isinstance(keys[0], dict) else str(keys[0])
pk = keys[1]['keys'][0] if isinstance(keys[1], dict) else str(keys[1])
clicks_val = float(r.get('clicks', 0) or 0)
impr_val = float(r.get('impressions', 0) or 0)
raw_ctr = r.get('ctr', None)
if raw_ctr is not None:
ctr_percent = float(raw_ctr) * 100.0
else:
ctr_percent = (clicks_val / impr_val * 100.0) if impr_val > 0 else 0.0
pos_val = float(r.get('position', 0) or 0)
by_query.setdefault(qk, {}).setdefault(pk, {"clicks": 0.0, "impressions": 0.0, "ctr": 0.0, "position_sum": 0.0, "position_count": 0.0})
agg = by_query[qk][pk]
agg["clicks"] += clicks_val
agg["impressions"] += impr_val
agg["ctr"] = max(agg["ctr"], ctr_percent)
if pos_val > 0:
agg["position_sum"] += pos_val
agg["position_count"] += 1
# Use query totals for context
total_by_query: Dict[str, Dict[str, float]] = {}
for r in q_rows or []:
qk = self._extract_query_from_row(r)
total_by_query[qk] = {
"clicks": float(r.get('clicks', 0) or 0),
"impressions": float(r.get('impressions', 0) or 0),
"position": float(r.get('position', 0) or 0)
}
for qk, pages_map in by_query.items():
if len(pages_map) < 2:
continue
total_clicks = sum(p["clicks"] for p in pages_map.values())
if total_clicks < min_clicks:
continue
qpos = total_by_query.get(qk, {}).get("position", 0.0)
if not (3.0 <= qpos <= 20.0) and qpos != 0.0:
# Skip queries already ranking very well or very poorly (if pos present)
continue
pages_list = []
for pk, m in pages_map.items():
avg_pos = (m["position_sum"] / m["position_count"]) if m["position_count"] > 0 else 0.0
pages_list.append({
"page": pk,
"clicks": round(m["clicks"], 0),
"impressions": round(m["impressions"], 0),
"ctr": round(m["ctr"], 2),
"position": round(avg_pos, 2) if avg_pos > 0 else None
})
pages_list.sort(key=lambda x: x.get("clicks", 0), reverse=True)
target_page = pages_list[0]["page"] if pages_list else None
cannibalization.append({
"query": qk,
"total_clicks": int(round(total_clicks)),
"recommended_target_page": target_page,
"pages": pages_list[:3]
})
# Sort by impact
cannibalization.sort(key=lambda item: item.get("total_clicks", 0), reverse=True)
cannibalization = cannibalization[:10]
except Exception as e:
logger.warning(f"Failed computing cannibalization: {e}")
return {
'connection_status': 'connected',
@@ -224,7 +372,8 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
'avg_position': round(avg_position, 2),
'total_queries': len(top_queries_source) if top_queries_source else 0,
'top_queries': top_queries,
'top_pages': top_pages
'top_pages': top_pages,
'cannibalization': cannibalization
}
except Exception as e:
@@ -256,3 +405,18 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
except Exception as e:
logger.error(f"Error extracting query from row: {e}")
return 'Unknown'
def _extract_page_from_row(self, row: Dict[str, Any]) -> str:
"""Extract page URL from GSC API row data"""
try:
keys = row.get('keys', [])
if keys and len(keys) > 0:
first_key = keys[0]
if isinstance(first_key, dict):
return first_key.get('keys', [''])[0]
else:
return str(first_key)
return ''
except Exception as e:
logger.error(f"Error extracting page from row: {e}")
return ''

View File

@@ -21,7 +21,7 @@ class WixAnalyticsHandler(BaseAnalyticsHandler):
super().__init__(PlatformType.WIX)
self.wix_service = WixService()
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
"""
Get Wix analytics data using the Business Management API

View File

@@ -22,7 +22,7 @@ class WordPressAnalyticsHandler(BaseAnalyticsHandler):
super().__init__(PlatformType.WORDPRESS)
self.wordpress_service = WordPressOAuthService()
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
"""
Get WordPress analytics data using WordPress.com REST API

View File

@@ -42,7 +42,7 @@ class PlatformAnalyticsService:
self.summary_generator = AnalyticsSummaryGenerator()
self.cache_manager = AnalyticsCacheManager()
async def get_comprehensive_analytics(self, user_id: str, platforms: List[str] = None) -> Dict[str, AnalyticsData]:
async def get_comprehensive_analytics(self, user_id: str, platforms: List[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, AnalyticsData]:
"""
Get analytics data from all connected platforms
@@ -93,9 +93,18 @@ class PlatformAnalyticsService:
if handler:
if platform_type == PlatformType.GSC or platform_type == PlatformType.BING:
analytics_data[platform_name] = await handler.get_analytics(user_id, target_url=target_url)
analytics_data[platform_name] = await handler.get_analytics(
user_id,
target_url=target_url,
start_date=start_date,
end_date=end_date
)
else:
analytics_data[platform_name] = await handler.get_analytics(user_id)
analytics_data[platform_name] = await handler.get_analytics(
user_id,
start_date=start_date,
end_date=end_date
)
else:
logger.warning(f"Unknown platform: {platform_name}")
analytics_data[platform_name] = self._create_error_response(platform_name, f"Unknown platform: {platform_name}")

View File

@@ -237,7 +237,7 @@ class BingAnalyticsStorageService:
Dict containing analytics summary
"""
try:
db = self._get_db_session()
db = self._get_db_session(user_id)
# Date range
end_date = datetime.now()
@@ -331,7 +331,7 @@ class BingAnalyticsStorageService:
List of top queries with performance data
"""
try:
db = self._get_db_session()
db = self._get_db_session(user_id)
# Calculate date range
end_date = datetime.now()

View File

@@ -241,6 +241,9 @@ class ExaResearchProvider(BaseProvider):
for idx, result in enumerate(results):
source_type = self._determine_source_type(result.url if hasattr(result, 'url') else '')
# Extract image if available (some Exa results include image URL)
image_url = result.image if hasattr(result, 'image') else None
sources.append({
'title': result.title if hasattr(result, 'title') else '',
'url': result.url if hasattr(result, 'url') else '',
@@ -251,17 +254,21 @@ class ExaResearchProvider(BaseProvider):
'source_type': source_type,
'content': result.text if hasattr(result, 'text') else '',
'highlights': result.highlights if hasattr(result, 'highlights') else [],
'summary': result.summary if hasattr(result, 'summary') else ''
'summary': result.summary if hasattr(result, 'summary') else '',
'image': image_url,
'author': result.author if hasattr(result, 'author') else None
})
return sources
def _get_excerpt(self, result):
"""Extract excerpt from Exa result."""
"""Extract excerpt from Exa result. Prefer highlights if available."""
if hasattr(result, 'highlights') and result.highlights and len(result.highlights) > 0:
return result.highlights[0]
if hasattr(result, 'summary') and result.summary:
return result.summary
if hasattr(result, 'text') and result.text:
return result.text[:500]
elif hasattr(result, 'summary') and result.summary:
return result.summary
return ''
def _determine_source_type(self, url):
@@ -280,16 +287,30 @@ class ExaResearchProvider(BaseProvider):
return 'web'
def _aggregate_content(self, results):
"""Aggregate content from Exa results for LLM analysis."""
"""Aggregate content from Exa results for LLM analysis, including highlights."""
content_parts = []
for idx, result in enumerate(results):
part = [f"Source {idx + 1}: {result.title if hasattr(result, 'title') else 'Untitled'}"]
if hasattr(result, 'url') and result.url:
part.append(f"URL: {result.url}")
# Add highlights if available (most valuable for LLM)
if hasattr(result, 'highlights') and result.highlights:
highlights_text = "\n".join([f"- {h}" for h in result.highlights])
part.append(f"Key Highlights:\n{highlights_text}")
# Add summary if available
if hasattr(result, 'summary') and result.summary:
content_parts.append(f"Source {idx + 1}: {result.summary}")
part.append(f"Summary: {result.summary}")
# Add text snippet if highlights/summary insufficient
elif hasattr(result, 'text') and result.text:
content_parts.append(f"Source {idx + 1}: {result.text[:1000]}")
part.append(f"Excerpt: {result.text[:1000]}")
content_parts.append("\n".join(part))
return "\n\n".join(content_parts)
return "\n\n---\n\n".join(content_parts)
def track_exa_usage(self, user_id: str, cost: float):
"""Track Exa API usage after successful call."""

View File

@@ -159,14 +159,10 @@ class StyleDetectionLogic:
}}
"""
# Call the LLM for analysis
logger.debug("[StyleDetectionLogic.analyze_content_style] Sending enhanced prompt to LLM")
try:
analysis_text = llm_text_gen(prompt, user_id=user_id)
# Clean and parse the response
cleaned_json = self._clean_json_response(analysis_text)
analysis_results = json.loads(cleaned_json)
logger.info("[StyleDetectionLogic.analyze_content_style] Successfully parsed enhanced analysis results")
return {
@@ -179,7 +175,7 @@ class StyleDetectionLogic:
return {
'success': True,
'analysis': fallback_results,
'warning': 'AI analysis failed, used fallback detection'
'warning': f'AI analysis failed ({str(e)}), used fallback detection'
}
except Exception as e:

View File

@@ -145,6 +145,7 @@ def init_user_database(user_id: str):
SubscriptionBase.metadata.create_all(bind=engine)
UserBusinessInfoBase.metadata.create_all(bind=engine)
ContentAssetBase.metadata.create_all(bind=engine)
BingAnalyticsBase.metadata.create_all(bind=engine)
# Initialize default data for new databases
try:

View File

@@ -343,7 +343,11 @@ class GSCService:
if not credentials:
raise ValueError("No valid credentials found")
service = build('searchconsole', 'v1', credentials=credentials)
# Disable discovery file cache (suppress oauth2client file_cache warnings) with safe fallback
try:
service = build('searchconsole', 'v1', credentials=credentials, cache_discovery=False)
except TypeError:
service = build('searchconsole', 'v1', credentials=credentials)
logger.info(f"Authenticated GSC service created for user: {user_id}")
return service
@@ -395,9 +399,12 @@ class GSCService:
# Check cache first
cache_key = f"{user_id}_{site_url}_{start_date}_{end_date}"
cached_data = self._get_cached_data(user_id, site_url, 'analytics', cache_key)
if cached_data:
logger.info(f"Returning cached analytics data for user: {user_id}")
return cached_data
if cached_data and isinstance(cached_data, dict):
has_pages = 'page_data' in cached_data and isinstance(cached_data.get('page_data'), dict)
has_queries = 'query_data' in cached_data and isinstance(cached_data.get('query_data'), dict)
if has_pages and has_queries:
logger.info(f"Returning cached analytics data for user: {user_id} (includes page_data)")
return cached_data
try:
service = self.get_authenticated_service(user_id)
@@ -476,8 +483,54 @@ class GSCService:
).execute()
logger.info(f"GSC Query-level response for user {user_id}: {query_response}")
# Combine overall metrics with query-level data
# Step 4: Get page-level data for top pages insights
page_request = {
'startDate': start_date,
'endDate': end_date,
'dimensions': ['page'], # Get page-level data
'rowLimit': 1000
}
logger.info(f"GSC Page-level request for user {user_id}: {page_request}")
page_rows = []
page_row_count = 0
try:
page_response = service.searchanalytics().query(
siteUrl=site_url,
body=page_request
).execute()
logger.info(f"GSC Page-level response for user {user_id}: {page_response}")
page_rows = page_response.get('rows', [])
page_row_count = page_response.get('rowCount', 0)
except Exception as page_error:
logger.warning(f"GSC Page-level request failed for user {user_id}: {page_error}")
page_rows = []
page_row_count = 0
# Step 5: Get query+page combined data for mapping queries to pages
qp_rows = []
qp_row_count = 0
try:
qp_request = {
'startDate': start_date,
'endDate': end_date,
'dimensions': ['query', 'page'],
'rowLimit': 1000
}
logger.info(f"GSC Query+Page request for user {user_id}: {qp_request}")
qp_response = service.searchanalytics().query(
siteUrl=site_url,
body=qp_request
).execute()
logger.info(f"GSC Query+Page response for user {user_id}: {qp_response}")
qp_rows = qp_response.get('rows', [])
qp_row_count = qp_response.get('rowCount', 0)
except Exception as qp_error:
logger.warning(f"GSC Query+Page request failed for user {user_id}: {qp_error}")
qp_rows = []
qp_row_count = 0
# Combine overall, query, page and query+page data
analytics_data = {
'overall_metrics': {
'rows': response.get('rows', []),
@@ -487,6 +540,14 @@ class GSCService:
'rows': query_response.get('rows', []),
'rowCount': query_response.get('rowCount', 0)
},
'page_data': {
'rows': page_rows,
'rowCount': page_row_count
},
'query_page_data': {
'rows': qp_rows,
'rowCount': qp_row_count
},
'verification_data': {
'rows': verification_rows,
'rowCount': len(verification_rows)
@@ -510,6 +571,8 @@ class GSCService:
'rowCount': response.get('rowCount', 0)
},
'query_data': {'rows': [], 'rowCount': 0},
'page_data': {'rows': [], 'rowCount': 0},
'query_page_data': {'rows': [], 'rowCount': 0},
'verification_data': {
'rows': verification_rows,
'rowCount': len(verification_rows)

View File

@@ -76,7 +76,8 @@ class ALwrityAgentOrchestrator:
try:
# Initialize shared LLM
if TXTAI_AVAILABLE:
self.llm = LLM(self.config.shared_llm)
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
self.llm = LLM(self.config.shared_llm, task="text-generation")
else:
self.llm = None

View File

@@ -181,7 +181,8 @@ class BaseALwrityAgent(ABC):
try:
if not self.llm:
# Create new LLM if not provided
raw_llm = LLM(model_name)
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
raw_llm = LLM(model_name, task="text-generation")
# Wrap it
self.llm = TrackingLLMWrapper(raw_llm, self.user_id, self.model_name)
@@ -906,6 +907,11 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
"name": "task_delegator",
"description": "Delegates specific tasks to specialized agents (content, competitor, seo, social)",
"target": self._delegate_task_tool
},
{
"name": "kickoff_gsc_first_pass",
"description": "Kicks off first-pass execution by invoking SEO/Content default GSC plans",
"target": self._kickoff_gsc_first_pass_tool
}
],
max_iterations=15,
@@ -924,7 +930,9 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
Do not just plan; EXECUTE by delegating.
Always prioritize user goals and maintain safety constraints.
Coordinate multi-agent responses to market changes effectively."""
Coordinate multi-agent responses to market changes effectively.
First, call 'kickoff_gsc_first_pass' to ground the plan on live GSC signals."""
)
)
@@ -1033,6 +1041,37 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
except Exception as e:
return {"error": str(e)}
async def _kickoff_gsc_first_pass_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""Invoke SEO and Content agents' default GSC plans and combine results"""
try:
start_date = context.get("start_date")
end_date = context.get("end_date")
payload = {"start_date": start_date, "end_date": end_date}
results = {}
combined_actions = []
seo = self.sub_agents.get("seo")
if seo and hasattr(seo, "_default_seo_gsc_plan_tool"):
plan = await seo._default_seo_gsc_plan_tool(payload)
results["seo"] = plan
combined_actions.extend(plan.get("actions", []) if isinstance(plan, dict) else [])
content = self.sub_agents.get("content")
if content and hasattr(content, "_default_content_gsc_plan_tool"):
plan = await content._default_content_gsc_plan_tool(payload)
results["content"] = plan
combined_actions.extend(plan.get("actions", []) if isinstance(plan, dict) else [])
return {
"status": "ok",
"invoked": list(results.keys()),
"results": results,
"combined_actions": combined_actions,
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
return {"status": "error", "error": str(e)}
async def _strategy_synthesizer_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""Tool for synthesizing strategies"""
return {

View File

@@ -13,6 +13,7 @@ from loguru import logger
from ..txtai_service import TxtaiIntelligenceService
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, AgentAction
from services.seo_tools.content_strategy_service import ContentStrategyService
from services.analytics import PlatformAnalyticsService
from services.intelligence.sif_agents import SharedLLMWrapper, LocalLLMWrapper
try:
from services.intelligence.sif_integration import SIFIntegrationService
@@ -888,7 +889,37 @@ class ContentStrategyAgent(BaseALwrityAgent):
"name": "sitemap_analyzer",
"description": "Analyzes website structure and publishing velocity via sitemap",
"target": self._sitemap_analyzer_tool
}
},
{
"name": "gsc_low_ctr_queries",
"description": "Returns low-CTR queries with evidence from cached GSC metrics",
"target": self._cs_gsc_low_ctr_queries_tool
},
{
"name": "gsc_striking_distance_queries",
"description": "Returns striking-distance queries (positions ~820) with evidence",
"target": self._cs_gsc_striking_distance_tool
},
{
"name": "gsc_declining_queries",
"description": "Returns period-over-period declining queries with evidence",
"target": self._cs_gsc_declining_queries_tool
},
{
"name": "gsc_low_ctr_pages",
"description": "Returns low-CTR pages with top contributing queries",
"target": self._cs_gsc_low_ctr_pages_tool
},
{
"name": "gsc_cannibalization_candidates",
"description": "Returns query→multiple-pages cannibalization candidates with target recommendation",
"target": self._cs_gsc_cannibalization_candidates_tool
},
{
"name": "default_content_gsc_plan",
"description": "Runs a default first-pass plan using GSC signals (titles/meta, consolidation, refreshes)",
"target": self._default_content_gsc_plan_tool
},
],
max_iterations=8,
system=self.get_effective_system_prompt(f"""You are the Content Strategy Agent for ALwrity user {self.user_id}.
@@ -903,12 +934,153 @@ class ContentStrategyAgent(BaseALwrityAgent):
- Performance-based content improvements
Use semantic analysis (SIF) and sitemap analysis to understand content context.
Always prioritize user goals and maintain brand consistency."""
Always prioritize user goals and maintain brand consistency.
In your first pass, call 'default_content_gsc_plan' to ground your actions on live GSC signals."""
)
)
# Tool Implementations
async def _cs_fetch_gsc_analytics(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, Any]:
    """Fetch this user's GSC analytics and return its metrics and date range.

    Requests only the ``"gsc"`` platform from ``PlatformAnalyticsService``
    for ``self.user_id`` over the given (optional) date bounds.

    Returns:
        ``{"metrics": ..., "date_range": ...}`` taken from the GSC entry.

    Raises:
        RuntimeError: when there is no GSC entry or its status is not
            ``"success"``; the message carries the entry's ``error_message``
            when one is available.
    """
    service = PlatformAnalyticsService()
    by_platform = await service.get_comprehensive_analytics(
        self.user_id,
        platforms=["gsc"],
        start_date=start_date,
        end_date=end_date,
    )
    report = by_platform.get("gsc")
    if report and report.status == "success":
        return {"metrics": report.metrics, "date_range": report.date_range}
    reason = getattr(report, "error_message", None) if report else "No data"
    raise RuntimeError(f"GSC analytics unavailable: {reason}")
async def _cs_gsc_low_ctr_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 100)); min_clicks = int(context.get("min_clicks", 10)); ctr_threshold = float(context.get("ctr_threshold", 1.5))
start_date = context.get("start_date"); end_date = context.get("end_date")
try:
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
tq = result["metrics"].get("top_queries", []) or []
items = [
{"query": r.get("query"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position")}
for r in tq
if (r.get("impressions", 0) >= min_impr and r.get("clicks", 0) >= min_clicks and float(r.get("ctr", 0.0)) < ctr_threshold)
]
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
except Exception as e:
logger.error(f"cs low_ctr_queries failed: {e}"); return {"error": str(e)}
async def _cs_gsc_striking_distance_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 100)); start_date = context.get("start_date"); end_date = context.get("end_date")
try:
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
tq = result["metrics"].get("top_queries", []) or []
items = [
{"query": r.get("query"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position")}
for r in tq
if (r.get("impressions", 0) >= min_impr and r.get("position") is not None and 8.0 <= float(r.get("position")) <= 20.0)
]
items.sort(key=lambda x: (x.get("position") if x.get("position") is not None else 999, -x.get("impressions", 0)))
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
except Exception as e:
logger.error(f"cs striking_distance failed: {e}"); return {"error": str(e)}
async def _cs_gsc_declining_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10)); min_prev_clicks = int(context.get("min_prev_clicks", 10)); min_drop_pct = float(context.get("min_drop_pct", 30.0))
start_date = context.get("start_date"); end_date = context.get("end_date")
try:
curr = await self._cs_fetch_gsc_analytics(start_date, end_date)
curr_range = curr["date_range"]; s = curr_range.get("start"); e = curr_range.get("end")
from datetime import datetime, timedelta; fmt = "%Y-%m-%d"
sd = datetime.strptime(s, fmt) if s else datetime.utcnow() - timedelta(days=30); ed = datetime.strptime(e, fmt) if e else datetime.utcnow()
days = max((ed - sd).days + 1, 1); prev_end = sd - timedelta(days=1); prev_start = prev_end - timedelta(days=days - 1)
prev = await self._cs_fetch_gsc_analytics(prev_start.strftime(fmt), prev_end.strftime(fmt))
curr_queries = {r.get("query"): r for r in (curr["metrics"].get("top_queries", []) or [])}
prev_queries = {r.get("query"): r for r in (prev["metrics"].get("top_queries", []) or [])}
items = []
for q, prev_row in prev_queries.items():
curr_row = curr_queries.get(q);
if not curr_row: continue
prev_clicks = int(prev_row.get("clicks", 0) or 0); curr_clicks = int(curr_row.get("clicks", 0) or 0)
if prev_clicks >= min_prev_clicks and curr_clicks < prev_clicks:
drop_pct = ((prev_clicks - curr_clicks) / prev_clicks) * 100.0
if drop_pct >= min_drop_pct:
items.append({"query": q, "prev_clicks": prev_clicks, "curr_clicks": curr_clicks, "drop_pct": round(drop_pct, 2)})
items.sort(key=lambda x: (x.get("drop_pct", 0), x.get("prev_clicks", 0)), reverse=True)
return {"items": items[:limit], "range": curr_range, "previous_range": prev["date_range"], "source": "gsc_cache"}
except Exception as e:
logger.error(f"cs declining_queries failed: {e}"); return {"error": str(e)}
async def _cs_gsc_low_ctr_pages_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 200)); ctr_threshold = float(context.get("ctr_threshold", 1.5))
start_date = context.get("start_date"); end_date = context.get("end_date")
try:
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
tp = result["metrics"].get("top_pages", []) or []
items = []
for r in tp:
if (r.get("impressions", 0) >= min_impr and float(r.get("ctr", 0.0)) < ctr_threshold):
items.append({"page": r.get("page"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position"), "evidence_queries": r.get("queries", [])[:5]})
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
except Exception as e:
logger.error(f"cs low_ctr_pages failed: {e}"); return {"error": str(e)}
async def _cs_gsc_cannibalization_candidates_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10)); start_date = context.get("start_date"); end_date = context.get("end_date")
try:
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
candidates = result["metrics"].get("cannibalization", []) or []
return {"items": candidates[:limit], "range": result["date_range"], "source": "gsc_cache"}
except Exception as e:
logger.error(f"cs cannibalization_candidates failed: {e}"); return {"error": str(e)}
async def _default_content_gsc_plan_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
start_date = context.get("start_date"); end_date = context.get("end_date")
try:
low_ctr_pages = await self._cs_gsc_low_ctr_pages_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
cannibals = await self._cs_gsc_cannibalization_candidates_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
striking = await self._cs_gsc_striking_distance_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
declining = await self._cs_gsc_declining_queries_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
actions = []
for p in low_ctr_pages.get("items", []):
actions.append({
"type": "improve_titles_meta",
"target": p.get("page"),
"reason": f"Low CTR {p.get('ctr')}% with {p.get('impressions')} impressions",
"evidence": p.get("evidence_queries", [])
})
for c in cannibals.get("items", []):
actions.append({
"type": "consolidate/internal_link",
"target": c.get("recommended_target_page"),
"reason": f"Cannibalization on query '{c.get('query')}'",
"pages": c.get("pages", [])
})
for q in striking.get("items", []):
actions.append({
"type": "refresh_content",
"target": "query",
"query": q.get("query"),
"reason": f"Striking distance at position {q.get('position')} with {q.get('impressions')} impressions"
})
for q in declining.get("items", []):
actions.append({
"type": "refresh_content",
"target": "query",
"query": q.get("query"),
"reason": f"Clicks decline {q.get('prev_clicks')}{q.get('curr_clicks')} ({q.get('drop_pct')}%)"
})
return {
"plan_name": "Default Content Plan from GSC",
"range": {"current": {"start": start_date, "end": end_date}},
"actions": actions,
"source": "gsc_cache",
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"default_content_gsc_plan failed: {e}")
return {"error": str(e)}
async def _sitemap_analyzer_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""Sitemap analysis tool using ContentStrategyService"""
website_url = context.get('website_url')
@@ -1324,7 +1496,37 @@ class SEOOptimizationAgent(BaseALwrityAgent):
"name": "query_seo_knowledge_base",
"description": "Queries the SIF knowledge base for SEO dashboard data, GSC/Bing metrics, and semantic insights",
"target": self._query_seo_knowledge_base_tool
}
},
{
"name": "gsc_low_ctr_queries",
"description": "Returns low-CTR queries with evidence from cached GSC metrics",
"target": self._gsc_low_ctr_queries_tool
},
{
"name": "gsc_striking_distance_queries",
"description": "Returns striking-distance queries (positions ~820) with evidence",
"target": self._gsc_striking_distance_tool
},
{
"name": "gsc_declining_queries",
"description": "Returns period-over-period declining queries with evidence",
"target": self._gsc_declining_queries_tool
},
{
"name": "gsc_low_ctr_pages",
"description": "Returns low-CTR pages with top contributing queries",
"target": self._gsc_low_ctr_pages_tool
},
{
"name": "gsc_cannibalization_candidates",
"description": "Returns query→multiple-pages cannibalization candidates with target recommendation",
"target": self._gsc_cannibalization_candidates_tool
},
{
"name": "default_seo_gsc_plan",
"description": "Runs a default first-pass SEO plan using GSC signals (titles/meta, consolidation, refreshes)",
"target": self._default_seo_gsc_plan_tool
},
],
max_iterations=15,
system=self.get_effective_system_prompt(f"""You are the SEO Optimization Agent for ALwrity user {self.user_id}.
@@ -1340,6 +1542,7 @@ class SEOOptimizationAgent(BaseALwrityAgent):
- Deep semantic search of SEO data (GSC, Bing, Audits)
Focus on high-impact, low-effort optimizations first.
In your first pass, call 'default_seo_gsc_plan' to ground your actions on live GSC signals.
Always maintain SEO best practices and user experience."""
)
)
@@ -1666,6 +1869,223 @@ class SEOOptimizationAgent(BaseALwrityAgent):
"timestamp": datetime.utcnow().isoformat()
}
# GSC Insights Tools (Option B)
async def _fetch_gsc_analytics(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, Any]:
    """Fetch the GSC slice of the platform analytics for this agent's user.

    Args:
        start_date: Optional start of the reporting window, forwarded unchanged.
        end_date: Optional end of the reporting window, forwarded unchanged.

    Returns:
        A dict with the ``metrics`` and ``date_range`` of the GSC result.

    Raises:
        RuntimeError: If no GSC entry is returned or its status is not
            ``"success"``.
    """
    analytics_by_platform = await PlatformAnalyticsService().get_comprehensive_analytics(
        self.user_id,
        platforms=["gsc"],
        start_date=start_date,
        end_date=end_date,
    )
    gsc_result = analytics_by_platform.get("gsc")
    if not gsc_result:
        raise RuntimeError("GSC analytics unavailable: No data")
    if gsc_result.status != "success":
        reason = getattr(gsc_result, "error_message", None)
        raise RuntimeError(f"GSC analytics unavailable: {reason}")
    return {"metrics": gsc_result.metrics, "date_range": gsc_result.date_range}
async def _gsc_low_ctr_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10))
min_impr = int(context.get("min_impressions", 100))
min_clicks = int(context.get("min_clicks", 10))
ctr_threshold = float(context.get("ctr_threshold", 1.5))
start_date = context.get("start_date")
end_date = context.get("end_date")
try:
result = await self._fetch_gsc_analytics(start_date, end_date)
tq = result["metrics"].get("top_queries", []) or []
items = [
{
"query": r.get("query"),
"clicks": r.get("clicks", 0),
"impressions": r.get("impressions", 0),
"ctr": r.get("ctr", 0.0),
"position": r.get("position")
}
for r in tq
if (r.get("impressions", 0) >= min_impr and r.get("clicks", 0) >= min_clicks and float(r.get("ctr", 0.0)) < ctr_threshold)
]
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
return {
"items": items[:limit],
"range": result["date_range"],
"source": "gsc_cache"
}
except Exception as e:
logger.error(f"low_ctr_queries tool failed: {e}")
return {"error": str(e)}
async def _gsc_striking_distance_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10))
min_impr = int(context.get("min_impressions", 100))
start_date = context.get("start_date")
end_date = context.get("end_date")
try:
result = await self._fetch_gsc_analytics(start_date, end_date)
tq = result["metrics"].get("top_queries", []) or []
items = [
{
"query": r.get("query"),
"clicks": r.get("clicks", 0),
"impressions": r.get("impressions", 0),
"ctr": r.get("ctr", 0.0),
"position": r.get("position")
}
for r in tq
if (r.get("impressions", 0) >= min_impr and r.get("position") is not None and 8.0 <= float(r.get("position")) <= 20.0)
]
items.sort(key=lambda x: (x.get("position") if x.get("position") is not None else 999, -x.get("impressions", 0)))
return {
"items": items[:limit],
"range": result["date_range"],
"source": "gsc_cache"
}
except Exception as e:
logger.error(f"striking_distance tool failed: {e}")
return {"error": str(e)}
async def _gsc_declining_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10))
min_prev_clicks = int(context.get("min_prev_clicks", 10))
min_drop_pct = float(context.get("min_drop_pct", 30.0))
start_date = context.get("start_date")
end_date = context.get("end_date")
try:
curr = await self._fetch_gsc_analytics(start_date, end_date)
curr_range = curr["date_range"]
s = curr_range.get("start")
e = curr_range.get("end")
from datetime import datetime, timedelta
fmt = "%Y-%m-%d"
sd = datetime.strptime(s, fmt) if s else datetime.utcnow() - timedelta(days=30)
ed = datetime.strptime(e, fmt) if e else datetime.utcnow()
days = max((ed - sd).days + 1, 1)
prev_end = sd - timedelta(days=1)
prev_start = prev_end - timedelta(days=days - 1)
prev = await self._fetch_gsc_analytics(prev_start.strftime(fmt), prev_end.strftime(fmt))
curr_queries = {r.get("query"): r for r in (curr["metrics"].get("top_queries", []) or [])}
prev_queries = {r.get("query"): r for r in (prev["metrics"].get("top_queries", []) or [])}
items = []
for q, prev_row in prev_queries.items():
curr_row = curr_queries.get(q)
if not curr_row:
continue
prev_clicks = int(prev_row.get("clicks", 0) or 0)
curr_clicks = int(curr_row.get("clicks", 0) or 0)
if prev_clicks >= min_prev_clicks and curr_clicks < prev_clicks:
drop_pct = ((prev_clicks - curr_clicks) / prev_clicks) * 100.0
if drop_pct >= min_drop_pct:
items.append({
"query": q,
"prev_clicks": prev_clicks,
"curr_clicks": curr_clicks,
"drop_pct": round(drop_pct, 2)
})
items.sort(key=lambda x: (x.get("drop_pct", 0), x.get("prev_clicks", 0)), reverse=True)
return {
"items": items[:limit],
"range": curr_range,
"previous_range": prev["date_range"],
"source": "gsc_cache"
}
except Exception as e:
logger.error(f"declining_queries tool failed: {e}")
return {"error": str(e)}
async def _gsc_low_ctr_pages_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10))
min_impr = int(context.get("min_impressions", 200))
ctr_threshold = float(context.get("ctr_threshold", 1.5))
start_date = context.get("start_date")
end_date = context.get("end_date")
try:
result = await self._fetch_gsc_analytics(start_date, end_date)
tp = result["metrics"].get("top_pages", []) or []
items = []
for r in tp:
if (r.get("impressions", 0) >= min_impr and float(r.get("ctr", 0.0)) < ctr_threshold):
items.append({
"page": r.get("page"),
"clicks": r.get("clicks", 0),
"impressions": r.get("impressions", 0),
"ctr": r.get("ctr", 0.0),
"position": r.get("position"),
"evidence_queries": r.get("queries", [])[:5]
})
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
return {
"items": items[:limit],
"range": result["date_range"],
"source": "gsc_cache"
}
except Exception as e:
logger.error(f"low_ctr_pages tool failed: {e}")
return {"error": str(e)}
async def _gsc_cannibalization_candidates_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
limit = int(context.get("limit", 10))
start_date = context.get("start_date")
end_date = context.get("end_date")
try:
result = await self._fetch_gsc_analytics(start_date, end_date)
candidates = result["metrics"].get("cannibalization", []) or []
return {
"items": candidates[:limit],
"range": result["date_range"],
"source": "gsc_cache"
}
except Exception as e:
logger.error(f"cannibalization_candidates tool failed: {e}")
return {"error": str(e)}
async def _default_seo_gsc_plan_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
start_date = context.get("start_date")
end_date = context.get("end_date")
try:
low_ctr_pages = await self._gsc_low_ctr_pages_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
cannibals = await self._gsc_cannibalization_candidates_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
striking = await self._gsc_striking_distance_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
declining = await self._gsc_declining_queries_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
actions = []
for p in low_ctr_pages.get("items", []):
actions.append({
"type": "update_titles_meta",
"target_page": p.get("page"),
"justification": f"Low CTR {p.get('ctr')}% with {p.get('impressions')} impressions",
"evidence": p.get("evidence_queries", [])
})
for c in cannibals.get("items", []):
actions.append({
"type": "consolidate/internal_link",
"target_page": c.get("recommended_target_page"),
"justification": f"Cannibalization on query '{c.get('query')}'",
"pages": c.get("pages", [])
})
for q in striking.get("items", []):
actions.append({
"type": "refresh_content",
"target": "query",
"query": q.get("query"),
"justification": f"Striking distance at position {q.get('position')} with {q.get('impressions')} impressions"
})
for q in declining.get("items", []):
actions.append({
"type": "refresh_content",
"target": "query",
"query": q.get("query"),
"justification": f"Clicks decline {q.get('prev_clicks')}{q.get('curr_clicks')} ({q.get('drop_pct')}%)"
})
return {
"plan_name": "Default SEO Plan from GSC",
"range": {"current": {"start": start_date, "end": end_date}},
"actions": actions,
"source": "gsc_cache",
"timestamp": datetime.utcnow().isoformat()
}
except Exception as e:
logger.error(f"default_seo_gsc_plan failed: {e}")
return {"error": str(e)}
class SocialAmplificationAgent(BaseALwrityAgent):
"""

View File

@@ -14,9 +14,9 @@ from .txtai_service import TxtaiIntelligenceService, TXTAI_AVAILABLE
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent
from services.llm_providers.main_text_generation import llm_text_gen
# Optional txtai imports
# Optional txtai imports (align with core agent framework)
try:
from txtai.pipeline import Agent, LLM
from txtai import Agent, LLM
except ImportError:
Agent = None
LLM = None
@@ -28,9 +28,13 @@ class SharedLLMWrapper:
def generate(self, prompt: str, **kwargs) -> str:
"""Generate text using the shared LLM provider."""
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
# but we could map them if needed.
return llm_text_gen(prompt, user_id=self.user_id)
try:
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
# but we could map them if needed.
return llm_text_gen(prompt, user_id=self.user_id)
except Exception as e:
logger.error(f"SharedLLMWrapper failed to generate text: {e}")
return f"[ERROR: Shared LLM generation failed for user {self.user_id}]"
def __call__(self, prompt: str, **kwargs) -> str:
    """Make the wrapper callable by delegating to ``generate`` with the same arguments."""
    generated = self.generate(prompt, **kwargs)
    return generated
@@ -40,8 +44,9 @@ class LocalLLMWrapper:
Lazily loads a local LLM via txtai.
This prevents blocking server startup with heavy model loads.
"""
def __init__(self, model_path: str):
def __init__(self, model_path: str, task: str = "text-generation"):
self.model_path = model_path
self.task = task
self._llm = None
@property
@@ -49,8 +54,9 @@ class LocalLLMWrapper:
if self._llm is None:
if LLM is None:
raise ImportError("txtai.pipeline.LLM is not available")
logger.info(f"Loading local LLM: {self.model_path}")
self._llm = LLM(path=self.model_path)
logger.info(f"Loading local LLM: {self.model_path} with task: {self.task}")
# Explicitly set task to avoid 'text2text-generation' default failures
self._llm = LLM(path=self.model_path, task=self.task)
return self._llm
def __call__(self, prompt: str, **kwargs) -> str:
@@ -67,11 +73,12 @@ class SIFBaseAgent(BaseALwrityAgent):
# 2. Local LLM for internal agent work (default for SIF agents)
if llm is None:
if TXTAI_AVAILABLE:
# Use Lazy Local LLM
llm = LocalLLMWrapper(model_name)
if TXTAI_AVAILABLE and LLM is not None:
# Use Lazy Local LLM when txtai LLM is available
# Hardening: Specify 'text-generation' task to avoid text2text defaults
llm = LocalLLMWrapper(model_name, task="text-generation")
else:
# Fallback to Shared if txtai not available
# Fallback to Shared if txtai or LLM is not available
llm = self.shared_llm
super().__init__(user_id, agent_type, model_name, llm)
@@ -85,14 +92,18 @@ class SIFBaseAgent(BaseALwrityAgent):
def _create_txtai_agent(self):
"""
SIF agents use the intelligence service directly, but we can expose
capabilities via a standard agent interface if needed.
SIF agents primarily use the intelligence service directly, but we can expose
capabilities via a standard agent interface if available.
"""
if not TXTAI_AVAILABLE:
return None
# Return a simple agent that can use the LLM
return Agent(llm=self.llm, tools=[])
if not TXTAI_AVAILABLE or Agent is None:
logger.debug(f"[{self.__class__.__name__}] txtai Agent not available, using fallback agent")
return self._create_fallback_agent()
try:
return Agent(llm=self.llm, tools=[])
except Exception as e:
logger.warning(f"[{self.__class__.__name__}] Failed to create txtai Agent: {e}")
return self._create_fallback_agent()
class StrategyArchitectAgent(SIFBaseAgent):
"""Agent for discovering content pillars and identifying strategic gaps."""

View File

@@ -25,7 +25,18 @@ except ImportError:
TXTAI_AVAILABLE = False
class TxtaiIntelligenceService:
_instances = {}
def __new__(cls, user_id: str, *args, **kwargs):
    """Return the instance cached for ``user_id``, creating and caching it on first use."""
    try:
        return cls._instances[user_id]
    except KeyError:
        instance = super(TxtaiIntelligenceService, cls).__new__(cls)
        cls._instances[user_id] = instance
        return instance
def __init__(self, user_id: str, model_path: Optional[str] = None, enable_caching: bool = True):
# Singleton: prevent re-initialization if already initialized
if getattr(self, "_singleton_initialized", False):
return
self.user_id = user_id
self.model_path = model_path or "sentence-transformers/all-MiniLM-L6-v2"
self.index_path = f"workspace/workspace_{user_id}/indices/txtai"
@@ -33,6 +44,11 @@ class TxtaiIntelligenceService:
self._initialized = False
self.enable_caching = enable_caching
self.cache_manager = semantic_cache_manager if enable_caching else None
self._backend = "faiss" # Default backend
# Mark as initialized for singleton pattern
self._singleton_initialized = True
# Lazy initialization - do not initialize embeddings on startup
# self._initialize_embeddings()
@@ -52,17 +68,26 @@ class TxtaiIntelligenceService:
logger.debug(f"Model path: {self.model_path}")
logger.debug(f"Index path: {self.index_path}")
# Close existing embeddings if any to release file locks
if self.embeddings:
try:
if hasattr(self.embeddings, 'close'):
self.embeddings.close()
self.embeddings = None
except Exception as close_err:
logger.warning(f"Error closing existing embeddings: {close_err}")
# Ensure directory exists
os.makedirs(os.path.dirname(self.index_path), exist_ok=True)
logger.debug(f"Created index directory: {os.path.dirname(self.index_path)}")
# Initialize embeddings with optimal configuration for ALwrity use case
# Hardening: Disabling quantization by default as it causes 'IndexIDMap' attribute errors with small indices on Windows
self.embeddings = Embeddings({
"path": self.model_path,
"content": True, # Enable content storage for retrieval
"objects": True, # Enable object storage for metadata
"backend": "faiss", # Use Faiss for efficient similarity search
"quantize": True, # Enable quantization for memory efficiency
"backend": self._backend, # Use Faiss for efficient similarity search
"batch": 32, # Batch size for processing
"gpu": False, # Force CPU usage for compatibility
"limit": 1000 # Maximum number of results for queries
@@ -76,7 +101,12 @@ class TxtaiIntelligenceService:
try:
self.embeddings.load(self.index_path)
logger.info(f"Successfully loaded existing txtai index for user {self.user_id}")
logger.debug(f"Index contains {len(self.embeddings)} items")
# Try to log count, handle if not supported
try:
count = self.embeddings.count() if hasattr(self.embeddings, 'count') else "unknown"
logger.debug(f"Index contains {count} items")
except:
logger.debug("Index loaded (count unavailable)")
except Exception as load_error:
logger.warning(f"Failed to load existing index: {load_error}. Creating new index.")
# Reset embeddings to create new index
@@ -84,8 +114,7 @@ class TxtaiIntelligenceService:
"path": self.model_path,
"content": True,
"objects": True,
"backend": "faiss",
"quantize": True,
"backend": self._backend,
"batch": 32,
"gpu": False,
"limit": 1000
@@ -146,8 +175,15 @@ class TxtaiIntelligenceService:
logger.error(f"Error indexing content for user {self.user_id}: {e}")
logger.error(f"Full traceback: {traceback.format_exc()}")
logger.error(f"Items count: {len(items) if items else 0}")
if items and len(items) > 0:
logger.error(f"Sample item structure: {type(items[0])}")
message = str(e)
is_windows_lock_error = isinstance(e, PermissionError) or "WinError 32" in message
if is_windows_lock_error:
logger.warning(
f"Txtai index save skipped for user {self.user_id} due to file lock. "
f"The index will be retried on a future run."
)
return
raise
async def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
@@ -172,7 +208,20 @@ class TxtaiIntelligenceService:
logger.debug(f"Cache miss for search query: '{query}'")
logger.debug(f"Searching for query: '{query}' with limit: {limit}")
results = self.embeddings.search(query, limit=limit)
try:
results = self.embeddings.search(query, limit=limit)
except AttributeError as ae:
if "nprobe" in str(ae):
logger.error(f"Detected known txtai/faiss IndexIDMap/nprobe incompatibility for user {self.user_id}. Attempting re-init with numpy backend fallback...")
# Switch to numpy backend which doesn't have this issue
self._backend = "numpy"
self._initialize_embeddings()
if self.embeddings:
results = self.embeddings.search(query, limit=limit)
else:
raise ae
else:
raise ae
# Cache the results if caching is enabled
if self.enable_caching and self.cache_manager and results:
@@ -216,7 +265,19 @@ class TxtaiIntelligenceService:
logger.debug(f"Cache miss for similarity calculation")
logger.debug(f"Calculating similarity between texts: '{text1[:50]}...' and '{text2[:50]}...'")
similarity = self.embeddings.similarity(text1, text2)
try:
similarity = self.embeddings.similarity(text1, text2)
except AttributeError as ae:
if "nprobe" in str(ae):
logger.error(f"Detected IndexIDMap nprobe error in similarity for user {self.user_id}. Falling back to numpy backend...")
self._backend = "numpy"
self._initialize_embeddings()
if self.embeddings:
similarity = self.embeddings.similarity(text1, text2)
else:
raise ae
else:
raise ae
# Cache the similarity result
if self.enable_caching and self.cache_manager:
@@ -272,7 +333,19 @@ class TxtaiIntelligenceService:
# Use graph-based clustering if available
# Perform a search to get graph structure
sample_query = "content marketing digital strategy"
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
try:
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
except AttributeError as ae:
if "nprobe" in str(ae):
logger.error(f"Detected IndexIDMap nprobe error in cluster for user {self.user_id}. Falling back to numpy backend...")
self._backend = "numpy"
self._initialize_embeddings()
if self.embeddings:
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
else:
raise ae
else:
raise ae
if not graph_results:
logger.warning(f"No graph results for clustering user {self.user_id}")
@@ -306,7 +379,7 @@ class TxtaiIntelligenceService:
logger.error(f"Full traceback: {traceback.format_exc()}")
return self._fallback_clustering(min_score)
def _fallback_clustering(self, min_score: float) -> List[List[int]]:
async def _fallback_clustering(self, min_score: float) -> List[List[int]]:
"""Fallback clustering method when graph clustering is not available."""
logger.info(f"Using fallback clustering for user {self.user_id}")
@@ -318,7 +391,8 @@ class TxtaiIntelligenceService:
all_clusters = []
for query in sample_queries:
results = self.embeddings.search(query, limit=5)
# Use our search wrapper for hardening
results = await self.search(query, limit=5)
if results and results[0].get("score", 0) >= min_score:
# Create a cluster from similar results
cluster = [i for i, result in enumerate(results) if result.get("score", 0) >= min_score]
@@ -393,9 +467,13 @@ class TxtaiIntelligenceService:
return {"status": "not_initialized", "user_id": self.user_id}
try:
# Get count of indexed items - txtai doesn't have a direct len() method
# We'll estimate based on available data or return a placeholder
index_size = getattr(self.embeddings, 'count', 0) or "unknown"
# Get count of indexed items
index_size = "unknown"
if hasattr(self.embeddings, 'count'):
try:
index_size = self.embeddings.count()
except:
pass
return {
"status": "active",
@@ -410,5 +488,7 @@ class TxtaiIntelligenceService:
return {"status": "error", "user_id": self.user_id, "error": str(e)}
def is_initialized(self) -> bool:
    """Check if the service is properly initialized, triggering lazy init if needed.

    Calls ``_ensure_initialized`` only when the initialized flag is not yet
    set, then reports whether the flag is set and embeddings exist.
    """
    if self._initialized:
        return self.embeddings is not None
    self._ensure_initialized()
    return self._initialized and self.embeddings is not None

View File

@@ -369,6 +369,12 @@ def huggingface_structured_json_response(
response_text = re.sub(r'```\n?', '', response_text)
response_text = response_text.strip()
# Fix common markdown artefacts that break JSON, e.g. lines starting with **"key":
# **"narration": "text"
# becomes:
# "narration": "text"
response_text = re.sub(r'^\s*\*\*(?=\s*")', '', response_text, flags=re.MULTILINE)
try:
parsed_json = json.loads(response_text)
logger.info("✅ Hugging Face structured JSON response parsed from text")

View File

@@ -648,11 +648,13 @@ async def ai_video_generate(
# PRE-FLIGHT VALIDATION: Validate video generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
from services.database import get_db
from services.database import get_session_for_user
from services.subscription.preflight_validator import validate_video_generation_operations
from fastapi import HTTPException
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Database session unavailable for user.")
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
@@ -762,9 +764,11 @@ def track_video_usage(
from datetime import datetime
from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
from services.database import get_db
from services.database import get_session_for_user
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
return {}
try:
logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
pricing_service_track = PricingService(db_track)

View File

@@ -527,6 +527,11 @@ class APIKeyManager:
def __init__(self):
    """Start with an empty key store and populate it via ``_load_from_env``."""
    self.api_keys = dict()
    self._load_from_env()
def load_api_keys(self):
    """Discard the current keys, reload them via ``_load_from_env``, and return the mapping."""
    self.api_keys = dict()
    self._load_from_env()
    refreshed = self.api_keys
    return refreshed
def _load_from_env(self):
"""Load API keys from environment variables."""

View File

@@ -27,6 +27,12 @@ async def generate_facebook_persona_task(user_id: str):
try:
logger.info(f"Scheduled Facebook persona generation started for user {user_id}")
# Ensure we have a valid session factory before trying to get session
from services.database import SessionLocal
if not SessionLocal:
logger.error("Database session factory not initialized")
return
db = get_db_session()
if not db:
logger.error(f"Failed to get database session for Facebook persona generation (user: {user_id})")

View File

@@ -0,0 +1,177 @@
from typing import Dict, Any, Optional
from loguru import logger
from services.product_marketing.personalization_service import PersonalizationService
from models.podcast_bible_models import (
PodcastBible,
HostPersona,
AudienceDNA,
BrandDNA,
VisualStyle,
AudioEnvironment,
ShowRules
)
class PodcastBibleService:
    """Service for generating and managing the Podcast Bible.

    A Podcast Bible bundles the host persona, audience, brand, visual style,
    audio environment and show rules for one podcast project. It is derived
    from the preferences returned by
    ``PersonalizationService.get_user_preferences``.
    """

    def __init__(self):
        self.personalization_service = PersonalizationService()

    @staticmethod
    def _lookup(source: Any, key: str) -> Any:
        """Return ``source.get(key)``, or ``None`` when *source* has no ``get``."""
        getter = getattr(source, "get", None)
        return getter(key) if callable(getter) else None

    @staticmethod
    def _section(preferences: Any, key: str) -> Dict[str, Any]:
        """Return the nested preference mapping under *key*, or ``{}``.

        A missing key, ``None`` or any value without a ``get`` method yields an
        empty dict so later lookups simply fall back to their defaults.
        """
        value = PodcastBibleService._lookup(preferences, key)
        return value if callable(getattr(value, "get", None)) else {}

    @staticmethod
    def _pref_label(source: Any, key: str, default: str, capitalize: bool = True) -> str:
        """Return the string preference under *key*, or *default* verbatim.

        Missing keys, ``None``, blank strings and non-string values all fall
        back to *default*, so one malformed field cannot abort the whole
        mapping. The default is returned untouched because ``str.capitalize``
        lowercases everything after the first character (it would turn
        "Professional Studio" into "Professional studio").

        Args:
            source: Mapping to read from (anything without ``get`` counts as empty).
            key: Preference name.
            default: Value used when the preference is unusable.
            capitalize: Apply ``str.capitalize`` to a user-supplied value.
        """
        value = PodcastBibleService._lookup(source, key)
        if not isinstance(value, str) or not value.strip():
            return default
        value = value.strip()
        return value.capitalize() if capitalize else value

    @staticmethod
    def _pref_list(source: Any, key: str, default: Any) -> list:
        """Return the list preference under *key*, or a copy of *default*.

        An explicitly empty list is kept as-is; only a missing or non-list
        value triggers the default.
        """
        value = PodcastBibleService._lookup(source, key)
        if isinstance(value, (list, tuple)):
            return list(value)
        return list(default)

    @staticmethod
    def _join(items: Any) -> str:
        """Comma-join *items* for the prompt block.

        ``None`` or an empty collection yields ``""``; a bare string is
        returned unchanged rather than being split into characters.
        """
        if not items:
            return ""
        if isinstance(items, str):
            return items
        return ", ".join(str(item) for item in items)

    def generate_bible(self, user_id: str, project_id: str) -> PodcastBible:
        """Generate a Podcast Bible from onboarding data.

        Args:
            user_id: User whose preferences are read.
            project_id: Project the Bible is attached to.

        Returns:
            The personalized Bible, or the result of ``_get_default_bible``
            if anything raises while building it.
        """
        logger.info(f"Generating Podcast Bible for user {user_id}")
        try:
            preferences = self.personalization_service.get_user_preferences(user_id)
            writing_style = self._section(preferences, "writing_style")
            style_prefs = self._section(preferences, "style_preferences")
            target_audience = self._section(preferences, "target_audience")
            label = self._pref_label

            industry = label(preferences, "industry", "General Business", capitalize=False)
            tone = label(writing_style, "tone", "Professional")
            # Used verbatim (not capitalized) inside sentences and trait lists.
            raw_tone = label(writing_style, "tone", "Professional", capitalize=False)
            raw_voice = label(writing_style, "voice", "Steady", capitalize=False)

            # 1. Map Host Persona
            host = HostPersona(
                name="Your AI Host",
                background=f"Expert in {industry}",
                expertise_level=label(writing_style, "complexity", "Expert"),
                personality_traits=[
                    tone,
                    label(writing_style, "engagement_level", "Informative"),
                ],
                vocal_style=label(writing_style, "voice", "Authoritative"),
                vocal_characteristics=["Clear", "Articulate", raw_voice],
                look=f"A professional individual dressed in business-casual attire, fitting the {industry} industry aesthetic.",
                catchphrases=[],
            )

            # 2. Map Audience DNA
            audience = AudienceDNA(
                expertise_level=label(target_audience, "expertise_level", "Intermediate"),
                interests=self._pref_list(
                    target_audience, "interests", ["Industry Trends", "Innovation"]
                ),
                pain_points=self._pref_list(
                    target_audience, "pain_points", ["Staying ahead of competition", "Efficiency"]
                ),
                demographics=None,
            )

            # 3. Map Brand DNA
            brand = BrandDNA(
                industry=industry,
                tone=tone,
                communication_style=label(writing_style, "engagement_level", "Informative"),
                key_messages=self._pref_list(preferences, "brand_values", []),
                competitor_context=None,
            )

            # 4. Map Visual Style
            visual = VisualStyle(
                style_preset=label(style_prefs, "aesthetic", "Professional Studio"),
                environment=f"A modern {industry}-themed podcast studio with professional equipment.",
                lighting="Soft, warm studio lighting with subtle rim lights.",
                color_palette=self._pref_list(preferences, "brand_colors", ["#1e293b", "#3b82f6"]),
                camera_style="Dynamic mid-shots with occasional close-ups for emphasis.",
            )

            # 5. Map Audio Environment
            music_energy = label(writing_style, "engagement_level", "Upbeat")
            audio_env = AudioEnvironment(
                soundscape="Pristine studio environment with deep, warm acoustics.",
                music_mood=f"{tone} & {music_energy}",
                sfx_style="Modern, clean interface-inspired sounds.",
            )

            # 6. Map Show Rules
            show_rules = ShowRules(
                intro_format=f"Start with a high-energy hook about the episode topic, followed by a warm welcome and an overview of the {industry} insights to be shared.",
                outro_format="Summarize the key takeaways, provide a clear call to action, and sign off with a professional closing.",
                interaction_tone=label(writing_style, "engagement_level", "Conversational"),
                constraints=[
                    "Avoid overly technical jargon unless defined",
                    "Keep segments concise and factual",
                    f"Maintain a {raw_tone} tone at all times",
                ],
            )

            bible = PodcastBible(
                project_id=project_id,
                host=host,
                audience=audience,
                brand=brand,
                visual_style=visual,
                audio_environment=audio_env,
                show_rules=show_rules,
            )
            logger.info(f"Podcast Bible generated successfully for project {project_id}")
            return bible
        except Exception as e:
            # Broad on purpose: project creation must not fail because of the
            # Bible. logger.exception keeps the traceback so the fallback is
            # still diagnosable.
            logger.exception(f"Error generating Podcast Bible: {str(e)}")
            # Return a default bible if something goes wrong to ensure project creation doesn't fail
            return self._get_default_bible(project_id)

    def _get_default_bible(self, project_id: str) -> PodcastBible:
        """Return a sensible default Bible.

        Only a subset of fields is supplied; everything else relies on the
        defaults declared by the Bible models.
        """
        return PodcastBible(
            project_id=project_id,
            host=HostPersona(
                name="AI Host",
                background="Industry Professional",
                expertise_level="Expert",
                vocal_style="Authoritative",
                vocal_characteristics=["Deep", "Steady"],
            ),
            audience=AudienceDNA(
                expertise_level="Intermediate",
                interests=["Industry Trends", "Technology"],
                pain_points=["Staying Competitive", "Operational Efficiency"],
            ),
            brand=BrandDNA(
                industry="General Business",
                tone="Professional",
                communication_style="Analytical",
            ),
            visual_style=VisualStyle(
                environment="Professional modern office studio",
                color_palette=["#000000", "#FFFFFF"],
            ),
            audio_environment=AudioEnvironment(),
            show_rules=ShowRules(
                intro_format="Standard welcome and topic introduction.",
                outro_format="Summary and sign-off.",
            ),
        )

    def serialize_bible(self, bible: PodcastBible) -> str:
        """Serialize the Bible into a prompt-friendly text block.

        List-valued fields are comma-joined via ``_join`` so a field left at
        ``None`` renders as an empty value instead of raising.

        Returns:
            The text wrapped in ``<podcast_bible>`` tags, with a leading and a
            trailing newline.
        """
        host = bible.host
        audience = bible.audience
        brand = bible.brand
        visual = bible.visual_style
        audio = bible.audio_environment
        rules = bible.show_rules
        lines = [
            "",
            "<podcast_bible>",
            "HOST PERSONA:",
            f"- Name: {host.name}",
            f"- Background: {host.background}",
            f"- Expertise Level: {host.expertise_level}",
            f"- Personality: {self._join(host.personality_traits)}",
            f"- Vocal Style: {host.vocal_style}",
            f"- Vocal Characteristics: {self._join(host.vocal_characteristics)}",
            f"- Visual Look: {host.look}",
            "TARGET AUDIENCE:",
            f"- Expertise: {audience.expertise_level}",
            f"- Interests: {self._join(audience.interests)}",
            f"- Pain Points: {self._join(audience.pain_points)}",
            "BRAND & STYLE:",
            f"- Industry: {brand.industry}",
            f"- Tone: {brand.tone}",
            f"- Communication Style: {brand.communication_style}",
            f"- Visual Style Preset: {visual.style_preset}",
            f"- Environment: {visual.environment}",
            f"- Lighting: {visual.lighting}",
            "AUDIO ENVIRONMENT:",
            f"- Soundscape: {audio.soundscape}",
            f"- Music Mood: {audio.music_mood}",
            "SHOW RULES & STRUCTURE:",
            f"- Intro Format: {rules.intro_format}",
            f"- Outro Format: {rules.outro_format}",
            f"- Interaction Tone: {rules.interaction_tone}",
            f"- Constraints: {self._join(rules.constraints)}",
            "</podcast_bible>",
            "",
        ]
        return "\n".join(lines)

View File

@@ -11,6 +11,7 @@ from datetime import datetime
import uuid
from models.podcast_models import PodcastProject
from services.podcast_bible_service import PodcastBibleService
class PodcastService:
@@ -18,6 +19,7 @@ class PodcastService:
def __init__(self, db: Session):
self.db = db
self.bible_service = PodcastBibleService()
def create_project(
self,
@@ -30,6 +32,9 @@ class PodcastService:
**kwargs
) -> PodcastProject:
"""Create a new podcast project."""
# Generate Podcast Bible automatically from onboarding data
bible = self.bible_service.generate_bible(user_id, project_id)
project = PodcastProject(
project_id=project_id,
user_id=user_id,
@@ -37,6 +42,7 @@ class PodcastService:
duration=duration,
speakers=speakers,
budget_cap=budget_cap,
bible=bible.model_dump() if bible else None,
status="draft",
current_step="create",
**kwargs

View File

@@ -5,13 +5,15 @@ Pluggable task scheduler that can work with any task model.
import asyncio
import logging
import os
from typing import Dict, Any, Optional, List, Callable
from datetime import datetime
from datetime import datetime, timedelta
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.triggers.date import DateTrigger
from sqlalchemy.orm import Session
from sqlalchemy import text
from .executor_interface import TaskExecutor, TaskExecutionResult
from .task_registry import TaskRegistry
@@ -19,8 +21,10 @@ from .exception_handler import (
SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError,
TaskLoaderError, SchedulerConfigError
)
from services.database import get_all_user_ids, get_session_for_user
from utils.logger_utils import get_service_logger
from ..utils.user_job_store import get_user_job_store_name
from models.scheduler_models import SchedulerEventLog
from .interval_manager import determine_optimal_interval, adjust_check_interval_if_needed
@@ -86,6 +90,9 @@ class TaskScheduler:
}
)
# Configure APScheduler to use unified logging system
self._configure_apscheduler_logging()
# Task executor registry
self.registry = TaskRegistry()
@@ -115,6 +122,21 @@ class TaskScheduler:
}
self._running = False
# Local Desktop App: Always leader, no advisory locks needed
self._leader_lock_key = int(os.getenv("SCHEDULER_LEADER_LOCK_KEY", "84321017"))
self._leadership_check_interval_seconds = int(os.getenv("SCHEDULER_LEADERSHIP_CHECK_INTERVAL", "15"))
self._leader_session = None
self._is_leader = True # Always leader in local desktop app
self._execution_enabled = True # Always enabled
self._leader_since = datetime.utcnow().isoformat()
self._last_leadership_check = None
self._last_leadership_error = None
# Execution lease registry (prevents duplicate redispatch across check cycles)
self._task_leases: Dict[str, str] = {}
self._task_lease_ttl_seconds = int(os.getenv("SCHEDULER_TASK_LEASE_TTL_SECONDS", "900"))
def _get_trigger_for_interval(self, interval_minutes: int):
"""
@@ -153,6 +175,144 @@ class TaskScheduler:
self.registry.register(task_type, executor, task_loader)
logger.info(f"Registered executor for task type: {task_type}")
def _configure_apscheduler_logging(self):
    """Configure APScheduler to use unified logging system.

    Installs one forwarding handler on the ``apscheduler`` logger and its
    ``scheduler``/``executors``/``jobstores`` children, sets them to DEBUG and
    disables propagation so records are not emitted a second time by
    ancestor loggers.

    The stdlib loggers are process-global, so a handler installed by an
    earlier call (for example by another scheduler instance) is removed first;
    otherwise every record would be forwarded once per call.
    """
    logger_names = (
        "apscheduler",
        "apscheduler.scheduler",
        "apscheduler.executors",
        "apscheduler.jobstores",
    )

    # Create a custom handler that redirects to unified logger
    class APSchedulerUnifiedHandler(logging.Handler):
        """Forward stdlib log records to the unified service logger."""

        # Marker for recognising handlers from earlier calls: this class is
        # re-created on every call, so isinstance() cannot be used for that.
        is_unified_apscheduler_handler = True

        def __init__(self, service_logger):
            super().__init__()
            self.service_logger = service_logger

        def emit(self, record):
            try:
                # Format the message
                msg = self.format(record)

                # Map APScheduler log levels to unified logger
                if record.levelno >= logging.ERROR:
                    self.service_logger.error(f"[APScheduler] {msg}")
                elif record.levelno >= logging.WARNING:
                    self.service_logger.warning(f"[APScheduler] {msg}")
                elif record.levelno >= logging.INFO:
                    self.service_logger.info(f"[APScheduler] {msg}")
                else:
                    self.service_logger.debug(f"[APScheduler] {msg}")
            except Exception:
                # Don't let logging errors break the scheduler
                pass

    # Create the handler once and share it between all APScheduler loggers.
    unified_handler = APSchedulerUnifiedHandler(logger)
    unified_handler.setLevel(logging.DEBUG)

    for name in logger_names:
        target = logging.getLogger(name)
        # Drop handlers left behind by a previous call to stay idempotent.
        for existing in list(target.handlers):
            if getattr(existing, "is_unified_apscheduler_handler", False):
                target.removeHandler(existing)
        target.addHandler(unified_handler)
        # Capture all logs
        target.setLevel(logging.DEBUG)
        # Prevent propagation to avoid duplicate logs
        target.propagate = False

    logger.info("APScheduler logging configured to use unified logging system")
def _scheduler_identity(self) -> str:
    """Return a ``<host>-<pid>`` label for this scheduler process.

    The host part is the ``HOSTNAME`` environment variable, or ``"local"``
    when it is not set; the pid is the current process id.
    """
    host_label = os.getenv('HOSTNAME', 'local')
    process_id = os.getpid()
    return "{}-{}".format(host_label, process_id)
def _acquire_leadership(self) -> bool:
    """Mark this instance as leader; always succeeds for the local desktop app.

    Sets the leader and execution flags, records ``_leader_since`` only if it
    is not already set, and refreshes ``_last_leadership_check``. Both
    timestamps are naive-UTC ISO strings.

    Returns:
        Always ``True``.
    """
    self._is_leader = True
    self._execution_enabled = True
    # Keep the original start-of-leadership time across repeated ticks.
    self._leader_since = self._leader_since or datetime.utcnow().isoformat()
    self._last_leadership_check = datetime.utcnow().isoformat()
    return True
def _release_leadership(self):
    """No-op for local desktop app."""
    # _acquire_leadership() only sets in-memory flags, so there is nothing to
    # release here; the method exists so callers can invoke it unconditionally.
    pass
def _sync_check_due_tasks_job(self):
    """Make the ``check_due_tasks`` job match the current leadership state.

    When this instance is leader with execution enabled, the job is added if
    it is missing (using the trigger for ``current_check_interval_minutes``).
    Otherwise the job is removed if it is present. Nothing happens when the
    job already matches the desired state.
    """
    existing_job = self.scheduler.get_job('check_due_tasks')
    should_run = self._is_leader and self._execution_enabled

    if not should_run:
        if existing_job is not None:
            self.scheduler.remove_job('check_due_tasks')
        return

    if existing_job is not None:
        return

    interval_trigger = self._get_trigger_for_interval(self.current_check_interval_minutes)
    self.scheduler.add_job(
        self._check_and_execute_due_tasks,
        trigger=interval_trigger,
        id='check_due_tasks',
        replace_existing=True
    )
async def _leadership_tick(self):
    """Periodic leadership check/renewal (stub for the local app).

    Does nothing while the scheduler is not running. Otherwise it re-asserts
    leadership and then reconciles the ``check_due_tasks`` job with it.
    """
    if self._running:
        self._acquire_leadership()
        self._sync_check_due_tasks_job()
def _acquire_task_lease(self, task_key: str) -> bool:
    """Try to take the in-memory lease for *task_key*.

    A lease is stored in ``self._task_leases`` as the ISO string of its
    naive-UTC expiry time. The lease is granted when none is stored, when the
    stored one has expired, or when the stored value cannot be parsed.

    Returns:
        ``True`` if the lease was written (expiring
        ``_task_lease_ttl_seconds`` from now), ``False`` if an unexpired lease
        already exists.
    """
    current_time = datetime.utcnow()
    stored_expiry = self._task_leases.get(task_key)

    if stored_expiry:
        try:
            still_held = datetime.fromisoformat(stored_expiry) > current_time
        except Exception:
            # Corrupted lease value: treat as free and overwrite below.
            still_held = False
        if still_held:
            return False

    lease_ttl = timedelta(seconds=self._task_lease_ttl_seconds)
    self._task_leases[task_key] = (current_time + lease_ttl).isoformat()
    return True
def _release_task_lease(self, task_key: str):
    """Drop the lease stored for *task_key*; a missing key is ignored."""
    self._task_leases.pop(task_key, None)
def _is_task_leased(self, task_key: str) -> bool:
    """Report whether *task_key* holds a lease that has not yet expired.

    An expired or unparseable lease is removed through
    ``_release_task_lease`` as a side effect of the check.

    Returns:
        ``True`` only when a stored expiry lies in the future (naive UTC).
    """
    stored_expiry = self._task_leases.get(task_key)
    if not stored_expiry:
        return False

    try:
        lease_active = datetime.fromisoformat(stored_expiry) > datetime.utcnow()
    except Exception:
        lease_active = False

    if lease_active:
        return True

    # Expired/corrupt lease gets cleaned up lazily
    self._release_task_lease(task_key)
    return False
async def start(self):
"""Start the scheduler with intelligent interval adjustment."""
if self._running:
@@ -168,16 +328,21 @@ class TaskScheduler:
)
self.current_check_interval_minutes = initial_interval
# Add periodic job to check for due tasks
self.scheduler.add_job(
self._check_and_execute_due_tasks,
trigger=self._get_trigger_for_interval(initial_interval),
id='check_due_tasks',
replace_existing=True
)
self.scheduler.start()
self._running = True
# Leadership monitor runs on all replicas; only leader executes due-task loop.
self.scheduler.add_job(
self._leadership_tick,
trigger=IntervalTrigger(seconds=self._leadership_check_interval_seconds),
id='leadership_monitor',
replace_existing=True,
max_instances=1,
coalesce=True
)
# Initial leader election
await self._leadership_tick()
# Check for and execute any missed jobs that are still within grace period
await self._execute_missed_jobs()
@@ -206,7 +371,7 @@ class TaskScheduler:
registered_types = self.registry.get_registered_types()
active_strategies = self.stats.get('active_strategies_count', 0)
# Count OAuth token monitoring tasks from database (recurring weekly tasks)
# Count tasks per user (Multi-tenant SQLite)
oauth_tasks_count = 0
website_analysis_tasks_count = 0
platform_insights_tasks_count = 0
@@ -323,126 +488,6 @@ class TaskScheduler:
startup_lines.append(f"{prefix} Job: {job.id} | Trigger: {trigger_type} | Next Run: {next_run}{user_context}")
# Add OAuth token monitoring tasks details
# Show ALL OAuth tasks (active and inactive) for complete visibility
if total_oauth_tasks > 0:
try:
user_ids = get_all_user_ids()
for user_id in user_ids:
try:
db = get_session_for_user(user_id)
if db:
from models.oauth_token_monitoring_models import OAuthTokenMonitoringTask
# Get ALL tasks for this user
oauth_tasks = db.query(OAuthTokenMonitoringTask).all()
for idx, task in enumerate(oauth_tasks):
is_last = idx == len(oauth_tasks) - 1 and website_analysis_tasks_count == 0 and platform_insights_tasks_count == 0 and len(all_jobs) == 0 and user_id == user_ids[-1]
prefix = " ├─" # Simplified prefix logic for multi-user list
try:
user_job_store = get_user_job_store_name(task.user_id, db)
if user_job_store == 'default':
logger.debug(
f"[Scheduler] Job store extraction returned 'default' for user {task.user_id}. "
f"This may indicate no onboarding data or website URL not found."
)
except Exception as e:
logger.warning(
f"[Scheduler] Could not extract job store name for user {task.user_id}: {e}. "
f"Using 'default'. Error type: {type(e).__name__}"
)
user_job_store = 'default'
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
# Include status in the log line for visibility
status_indicator = "" if task.status == 'active' else f"[{task.status}]"
startup_lines.append(
f"{prefix} Job: oauth_token_monitoring_{task.platform}_{task.user_id} | "
f"Trigger: CronTrigger (Weekly) | Next Run: {next_check} | "
f"User: {task.user_id} | Store: {user_job_store} | Platform: {task.platform} {status_indicator}"
)
db.close()
except Exception as e:
logger.warning(f"Error checking OAuth tasks for user {user_id}: {e}")
except Exception as e:
logger.debug(f"Could not get OAuth token monitoring task details: {e}")
# Add website analysis tasks details
if website_analysis_tasks_count > 0:
try:
user_ids = get_all_user_ids()
for user_id in user_ids:
try:
db = get_session_for_user(user_id)
if db:
from models.website_analysis_monitoring_models import WebsiteAnalysisTask
website_analysis_tasks = db.query(WebsiteAnalysisTask).all()
for idx, task in enumerate(website_analysis_tasks):
is_last = idx == len(website_analysis_tasks) - 1 and platform_insights_tasks_count == 0 and len(all_jobs) == 0 and total_oauth_tasks == 0 and user_id == user_ids[-1]
prefix = " ├─" # Simplified
try:
user_job_store = get_user_job_store_name(task.user_id, db)
except Exception as e:
logger.debug(f"Could not extract job store name for user {task.user_id}: {e}")
user_job_store = 'default'
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
frequency = f"Every {task.frequency_days} days"
task_type_label = "User Website" if task.task_type == 'user_website' else "Competitor"
status_indicator = "" if task.status == 'active' else f"[{task.status}]"
website_display = task.website_url[:50] + "..." if task.website_url and len(task.website_url) > 50 else (task.website_url or 'N/A')
startup_lines.append(
f"{prefix} Job: website_analysis_{task.task_type}_{task.user_id}_{task.id} | "
f"Trigger: CronTrigger ({frequency}) | Next Run: {next_check} | "
f"User: {task.user_id} | Store: {user_job_store} | Type: {task_type_label} | URL: {website_display} {status_indicator}"
)
db.close()
except Exception as e:
logger.warning(f"Error checking website analysis tasks for user {user_id}: {e}")
except Exception as e:
logger.debug(f"Could not get website analysis task details: {e}")
# Add platform insights tasks details
if platform_insights_tasks_count > 0:
try:
user_ids = get_all_user_ids()
for user_id in user_ids:
try:
db = get_session_for_user(user_id)
if db:
from models.platform_insights_monitoring_models import PlatformInsightsTask
platform_insights_tasks = db.query(PlatformInsightsTask).all()
for idx, task in enumerate(platform_insights_tasks):
is_last = idx == len(platform_insights_tasks) - 1 and len(all_jobs) == 0 and total_oauth_tasks == 0 and website_analysis_tasks_count == 0 and user_id == user_ids[-1]
prefix = " ├─" # Simplified
try:
user_job_store = get_user_job_store_name(task.user_id, db)
except Exception as e:
logger.debug(f"Could not extract job store name for user {task.user_id}: {e}")
user_job_store = 'default'
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
platform_label = task.platform.upper() if task.platform else 'Unknown'
site_display = task.site_url[:50] + "..." if task.site_url and len(task.site_url) > 50 else (task.site_url or 'N/A')
status_indicator = "" if task.status == 'active' else f"[{task.status}]"
startup_lines.append(
f"{prefix} Job: platform_insights_{task.platform}_{task.user_id} | "
f"Trigger: CronTrigger (Weekly) | Next Run: {next_check} | "
f"User: {task.user_id} | Store: {user_job_store} | Platform: {platform_label} | Site: {site_display} {status_indicator}"
)
db.close()
except Exception as e:
logger.warning(f"Error checking platform insights tasks for user {user_id}: {e}")
except Exception as e:
logger.debug(f"Could not get platform insights task details: {e}")
# Add Advertools tasks details
if advertools_tasks_count > 0:
try:
@@ -518,7 +563,15 @@ class TaskScheduler:
# Get final job count before shutdown
all_jobs_before = self.scheduler.get_jobs()
# Release leadership lock and stop leadership monitor
try:
if self.scheduler.get_job('leadership_monitor') is not None:
self.scheduler.remove_job('leadership_monitor')
except Exception:
pass
self._release_leadership()
# Shutdown scheduler
self.scheduler.shutdown(wait=True)
self._running = False
@@ -569,6 +622,10 @@ class TaskScheduler:
Main scheduler loop: check for due tasks and execute them.
This runs periodically with intelligent interval adjustment based on active strategies.
"""
if not self._execution_enabled or not self._is_leader:
logger.debug("[Scheduler] Skipping due-task loop on standby replica")
return
await check_and_execute_due_tasks(self)
async def _adjust_check_interval_if_needed(self, db: Session):
@@ -614,309 +671,156 @@ class TaskScheduler:
except Exception as e:
logger.warning(f"[Scheduler] Error checking for missed jobs: {e}")
async def trigger_interval_adjustment(self):
    """Run an interval-adjustment check right away.

    Intended to be called when a strategy is activated or deactivated so the
    check interval reflects the current set of active strategies without
    waiting for the next cycle. Skipped (with a debug log) while the scheduler
    is not running; failures are logged as warnings and not re-raised.
    """
    if self._running:
        try:
            # Multi-tenant aware adjustment (iterates all users internally)
            await adjust_check_interval_if_needed(self)
        except Exception as e:
            logger.warning(f"Error triggering interval adjustment: {e}")
    else:
        logger.debug("Scheduler not running, skipping interval adjustment")
async def _validate_and_rebuild_cumulative_stats(self):
"""
Validate cumulative stats on scheduler startup and rebuild if needed.
This ensures cumulative stats are accurate after restarts.
NOTE: Disabled in multi-tenant mode as there is no global database for cumulative stats.
TODO: Implement per-user cumulative stats or a global admin database.
Validate and rebuild cumulative stats if needed.
Currently a placeholder for future implementation.
"""
logger.info("[Scheduler] Cumulative stats validation skipped (multi-tenant mode)")
return
async def _process_task_type(self, task_type: str, db: Session, cycle_summary: Dict[str, Any] = None, user_id: str = None) -> Optional[Dict[str, Any]]:
"""
Process due tasks for a specific task type.
Returns:
Summary dict with 'found', 'executed', 'failed' counts, or None if no tasks
"""
summary = {'found': 0, 'executed': 0, 'failed': 0}
pass
async def _process_task_type(
self,
task_type: str,
db: Session,
cycle_summary: Dict[str, Any],
user_id: Optional[str] = None
) -> Dict[str, int]:
summary = {"found": 0, "executed": 0, "failed": 0}
try:
# Get task loader for this type
try:
task_loader = self.registry.get_task_loader(task_type)
except Exception as e:
error = TaskLoaderError(
message=f"Failed to get task loader for type {task_type}: {str(e)}",
task_type=task_type,
original_error=e
)
self.exception_handler.handle_exception(error)
return None
# Load due tasks (with error handling)
try:
due_tasks = task_loader(db)
except Exception as e:
error = TaskLoaderError(
message=f"Failed to load due tasks for type {task_type}: {str(e)}",
task_type=task_type,
original_error=e
)
self.exception_handler.handle_exception(error)
return None
if not due_tasks:
return None
summary['found'] = len(due_tasks)
self.stats['tasks_found'] += len(due_tasks)
# Execute tasks (with concurrency limit)
execution_tasks = []
skipped_count = 0
for task in due_tasks:
if len(self.active_executions) >= self.max_concurrent_executions:
skipped_count = len(due_tasks) - len(execution_tasks)
logger.warning(
f"[Scheduler] ⚠️ Max concurrent executions reached ({self.max_concurrent_executions}), "
f"skipping {skipped_count} tasks for {task_type}"
)
break
# Execute task asynchronously
# Note: Each task gets its own database session to prevent concurrent access issues
execution_task = asyncio.create_task(
execute_task_async(self, task_type, task, summary, user_id=user_id)
)
task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
self.active_executions[task_id] = execution_task
execution_tasks.append(execution_task)
# Wait for executions to complete (with timeout per task)
if execution_tasks:
await asyncio.wait(execution_tasks, timeout=300)
return summary
task_loader = self.registry.get_task_loader(task_type)
except Exception as e:
error = TaskLoaderError(
message=f"Error processing task type {task_type}: {str(e)}",
task_type=task_type,
message=f"Failed to get task loader for type {task_type}: {str(e)}",
user_id=user_id,
context={"task_type": task_type},
original_error=e
)
self.exception_handler.handle_exception(error)
self.stats["tasks_failed"] += 1
return summary
def _update_user_stats(self, user_id: Optional[int], success: bool):
"""
Update per-user statistics for user isolation tracking.
Args:
user_id: User ID (None if user context not available)
success: Whether task execution was successful
"""
if user_id is None:
try:
tasks = task_loader(db)
if not tasks:
return summary
summary["found"] = len(tasks)
max_concurrent = self.max_concurrent_executions
for task in tasks:
task_id = getattr(task, "id", None)
lease_key = f"{task_type}_{task_id or id(task)}"
if self._is_task_leased(lease_key):
continue
if len(self.active_executions) >= max_concurrent:
break
if not self._acquire_task_lease(lease_key):
continue
execution_task = asyncio.create_task(
execute_task_async(
self,
task_type,
task,
summary,
execution_source="scheduler",
user_id=user_id,
)
)
self.active_executions[lease_key] = execution_task
cycle_summary.setdefault("tasks_found_by_type", {})
cycle_summary.setdefault("tasks_executed_by_type", {})
cycle_summary.setdefault("tasks_failed_by_type", {})
cycle_summary["tasks_found_by_type"][task_type] = (
cycle_summary["tasks_found_by_type"].get(task_type, 0)
+ summary["found"]
)
cycle_summary["tasks_executed_by_type"][task_type] = (
cycle_summary["tasks_executed_by_type"].get(task_type, 0)
+ summary["executed"]
)
cycle_summary["tasks_failed_by_type"][task_type] = (
cycle_summary["tasks_failed_by_type"].get(task_type, 0)
+ summary["failed"]
)
return summary
except Exception as e:
error = TaskLoaderError(
message=f"Error processing task type {task_type}: {str(e)}",
user_id=user_id,
context={"task_type": task_type},
original_error=e
)
self.exception_handler.handle_exception(error)
self.stats["tasks_failed"] += 1
return summary
def _update_user_stats(self, user_id: Optional[str], success: bool):
if not user_id:
return
if user_id not in self.stats['per_user_stats']:
self.stats['per_user_stats'][user_id] = {
'executed': 0,
'failed': 0,
'success_rate': 0.0
}
user_stats = self.stats['per_user_stats'][user_id]
per_user = self.stats.setdefault("per_user_stats", {})
user_stats = per_user.setdefault(
user_id,
{
"tasks_executed": 0,
"tasks_failed": 0,
"last_update": None,
},
)
if success:
user_stats['executed'] += 1
user_stats["tasks_executed"] += 1
else:
user_stats['failed'] += 1
# Calculate success rate
total = user_stats['executed'] + user_stats['failed']
if total > 0:
user_stats['success_rate'] = (user_stats['executed'] / total) * 100.0
async def _schedule_retry(self, task: Any, delay_seconds: int):
"""Schedule a retry for a failed task."""
# This would update the task's next_execution time
# For now, just log - could be enhanced to update next_execution
logger.debug(f"Scheduling retry for task in {delay_seconds}s")
def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]:
"""
Get scheduler statistics with optional user filtering.
Args:
user_id: Optional user ID to filter statistics for specific user
Returns:
Dictionary with scheduler statistics
"""
base_stats = {
**{k: v for k, v in self.stats.items() if k not in ['per_user_stats']},
'active_executions': len(self.active_executions),
'registered_types': self.registry.get_registered_types(),
'running': self._running,
'check_interval_minutes': self.current_check_interval_minutes,
'min_check_interval_minutes': self.min_check_interval_minutes,
'max_check_interval_minutes': self.max_check_interval_minutes,
'intelligent_scheduling': True
}
# Include per-user stats (all users or filtered)
if user_id is not None:
if user_id in self.stats['per_user_stats']:
base_stats['user_stats'] = self.stats['per_user_stats'][user_id]
else:
base_stats['user_stats'] = {
'executed': 0,
'failed': 0,
'success_rate': 0.0
}
else:
# Include all per-user stats (for admin/debugging)
base_stats['per_user_stats'] = self.stats['per_user_stats']
return base_stats
user_stats["tasks_failed"] += 1
user_stats["last_update"] = datetime.utcnow().isoformat()
async def _schedule_retry(self, task: Any, retry_delay: int):
try:
task_id = getattr(task, "id", None)
logger.warning(
f"[Scheduler] Retry requested for task {task_id} in {retry_delay}s, "
f"using loader-based retry semantics."
)
except Exception:
pass
def schedule_one_time_task(
self,
func: Callable,
run_date: datetime,
job_id: str,
args: tuple = (),
kwargs: Dict[str, Any] = None,
kwargs: Optional[Dict[str, Any]] = None,
replace_existing: bool = True
) -> str:
"""
Schedule a one-time task to run at a specific datetime.
Schedule a one-time task execution.
Args:
func: Async function to execute
run_date: Datetime when the task should run (must be timezone-aware UTC)
job_id: Unique identifier for this job
args: Positional arguments to pass to func
kwargs: Keyword arguments to pass to func
replace_existing: If True, replace existing job with same ID
func: Function to execute
run_date: Date/time to run the task
job_id: Unique job ID
kwargs: Keyword arguments for the function
replace_existing: Whether to replace existing job with same ID
Returns:
Job ID
"""
if not self._running:
logger.warning(
f"Scheduler not running, but scheduling job {job_id} anyway. "
"APScheduler will start automatically when needed."
)
try:
# Ensure run_date is timezone-aware (UTC)
if run_date.tzinfo is None:
from datetime import timezone
run_date = run_date.replace(tzinfo=timezone.utc)
logger.debug(f"Added UTC timezone to run_date: {run_date}")
self.scheduler.add_job(
func,
trigger=DateTrigger(run_date=run_date),
args=args,
kwargs=kwargs or {},
id=job_id,
kwargs=kwargs or {},
replace_existing=replace_existing,
misfire_grace_time=3600 # 1 hour grace period for missed jobs
misfire_grace_time=3600 # 1 hour grace period
)
# Get updated job count
all_jobs = self.scheduler.get_jobs()
one_time_jobs = [j for j in all_jobs if j.id != 'check_due_tasks']
# Extract user_id from kwargs if available for logging and job store
user_id = kwargs.get('user_id', None) if kwargs else None
func_name = func.__name__ if hasattr(func, '__name__') else str(func)
# Get job store name for user (if user_id provided)
job_store_name = 'default'
if user_id:
try:
db = get_session_for_user(user_id)
if db:
job_store_name = get_user_job_store_name(user_id, db)
db.close()
except Exception as e:
logger.warning(f"Could not determine job store for user {user_id}: {e}")
# Note: APScheduler doesn't support dynamic job store creation
# We use 'default' for all jobs but log the user's job store name for debugging
# The actual user isolation is handled through task filtering by user_id
# Log detailed one-time task scheduling information (use WARNING level for visibility)
log_message = (
f"[Scheduler] 📅 Scheduled One-Time Task\n"
f" ├─ Job ID: {job_id}\n"
f" ├─ Function: {func_name}\n"
f" ├─ User ID: {user_id or 'system'}\n"
f" ├─ Job Store: {job_store_name} (user context)\n"
f" ├─ Scheduled For: {run_date}\n"
f" ├─ Replace Existing: {replace_existing}\n"
f" ├─ Total One-Time Jobs: {len(one_time_jobs)}\n"
f" └─ Total Scheduled Jobs: {len(all_jobs)}"
)
logger.warning(log_message)
# Log job scheduling to event log for dashboard
if user_id:
try:
event_db = get_session_for_user(user_id)
if event_db:
event_log = SchedulerEventLog(
event_type='job_scheduled',
event_date=datetime.utcnow(),
job_id=job_id,
job_type='one_time',
user_id=user_id,
event_data={
'function_name': func_name,
'job_store': job_store_name,
'scheduled_for': run_date.isoformat(),
'replace_existing': replace_existing
}
)
event_db.add(event_log)
event_db.commit()
event_db.close()
except Exception as e:
logger.debug(f"Failed to log job scheduling event: {e}")
logger.info(f"Scheduled one-time task {job_id} at {run_date}")
return job_id
except Exception as e:
logger.error(f"Failed to schedule one-time task {job_id}: {e}")
raise
def is_running(self) -> bool:
    """Check if scheduler is running.

    Returns:
        The current value of the ``_running`` flag.
    """
    return self._running
async def execute_task_by_type(self, task_type: str, user_id: str, payload: Dict[str, Any]):
    """
    Run a task of the given type right away, outside the regular schedule.

    Used for one-time tasks triggered by system events. No stored task
    record is involved: a lightweight stand-in exposing ``user_id``,
    ``payload`` and ``id`` is built and handed to ``execute_task_async``.

    Args:
        task_type: Task type forwarded to ``execute_task_async``.
        user_id: Owner of the task; stored on the stand-in task.
        payload: Task payload; stored on the stand-in task.
    """
    from collections import namedtuple

    # The stand-in only carries the three attributes set here.
    TaskStub = namedtuple('TaskStub', 'user_id payload id')
    # The id is "manual_" plus the timestamp of the current utcnow() value.
    manual_id = f"manual_{datetime.utcnow().timestamp()}"
    stub = TaskStub(user_id, payload, manual_id)
    await execute_task_async(self, task_type, stub, execution_source="manual")

View File

@@ -67,6 +67,77 @@ class StoryImageGenerationService:
clean_title = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in scene_title[:30])
unique_id = str(uuid.uuid4())[:8]
return f"scene_{scene_number}_{clean_title}_{unique_id}.png"
def _refine_image_prompt_with_bible(
self,
image_prompt: str,
scene: Dict[str, Any],
anime_bible: Optional[Dict[str, Any]] = None,
) -> str:
"""
Lightweight image prompt refinement using the anime story bible.
Takes the existing scene image_prompt and enriches it with visual_style,
world, and cast hints from the bible. This is deterministic and avoids
extra LLM calls.
"""
if not image_prompt or not isinstance(image_prompt, str):
return image_prompt
if not anime_bible or not isinstance(anime_bible, dict):
return image_prompt
visual_style = anime_bible.get("visual_style") or {}
world = anime_bible.get("world") or {}
main_cast = anime_bible.get("main_cast") or []
parts: List[str] = []
style_preset = visual_style.get("style_preset")
if style_preset:
parts.append(f"{style_preset} anime illustration style")
camera_style = visual_style.get("camera_style")
if camera_style:
parts.append(f"framing and camera style: {camera_style}")
color_mood = visual_style.get("color_mood")
if color_mood:
parts.append(f"color mood: {color_mood}")
lighting = visual_style.get("lighting")
if lighting:
parts.append(f"lighting: {lighting}")
line_style = visual_style.get("line_style")
if line_style:
parts.append(f"line style: {line_style}")
extra_tags = visual_style.get("extra_tags") or []
if isinstance(extra_tags, (list, tuple)):
extra_text = ", ".join(str(tag) for tag in extra_tags[:6] if tag)
if extra_text:
parts.append(extra_text)
setting = world.get("setting") if isinstance(world, dict) else None
if setting:
parts.append(f"world setting: {setting}")
if isinstance(main_cast, list):
names = [
c.get("name")
for c in main_cast
if isinstance(c, dict) and c.get("name")
]
if names:
joined = ", ".join(names[:4])
parts.append(f"keep character designs consistent for: {joined}")
if not parts:
return image_prompt
suffix = ", " + ", ".join(parts)
return image_prompt.strip() + suffix
def generate_scene_image(
self,
@@ -75,7 +146,8 @@ class StoryImageGenerationService:
provider: Optional[str] = None,
width: int = 1024,
height: int = 1024,
model: Optional[str] = None
model: Optional[str] = None,
anime_bible: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Generate an image for a single story scene.
@@ -94,6 +166,16 @@ class StoryImageGenerationService:
scene_number = scene.get("scene_number", 0)
scene_title = scene.get("title", "Untitled")
image_prompt = scene.get("image_prompt", "")
if anime_bible:
try:
image_prompt = self._refine_image_prompt_with_bible(
image_prompt=image_prompt,
scene=scene,
anime_bible=anime_bible,
)
except Exception as e:
logger.warning(f"[StoryImageGeneration] Failed to refine image prompt with bible: {e}")
if not image_prompt:
raise ValueError(f"Scene {scene_number} ({scene_title}) has no image_prompt")
@@ -156,7 +238,8 @@ class StoryImageGenerationService:
height: int = 1024,
model: Optional[str] = None,
progress_callback: Optional[callable] = None,
db: Optional[Session] = None
db: Optional[Session] = None,
anime_bible: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""
Generate images for multiple story scenes.
@@ -192,7 +275,7 @@ class StoryImageGenerationService:
width=width,
height=height,
model=model,
db=db
anime_bible=anime_bible,
)
image_results.append(image_result)
@@ -295,4 +378,3 @@ class StoryImageGenerationService:
except Exception as e:
logger.error(f"[StoryImageGeneration] Error regenerating image for scene {scene_number}: {e}")
raise RuntimeError(f"Failed to regenerate image for scene {scene_number}: {str(e)}") from e

View File

@@ -57,6 +57,7 @@ class StoryOutlineMixin(StoryServiceBase):
ending_preference: str,
user_id: str,
use_structured_output: bool = True,
include_anime_bible: bool = False,
) -> Any:
"""Generate a story outline with optional structured JSON output."""
persona_prompt = self.build_persona_prompt(

View File

@@ -145,20 +145,45 @@ Write ONLY the premise sentence(s). Do not write anything else.
"reasoning",
],
},
"minItems": 1,
"maxItems": 1,
}
},
"required": ["options"],
}
def _build_idea_enhance_schema(self) -> Dict[str, Any]:
return {
"type": "object",
"properties": {
"suggestions": {
"type": "array",
"items": {
"type": "object",
"properties": {
"idea": {"type": "string"},
"whats_missing": {"type": "string"},
"why_choose": {"type": "string"},
},
"required": ["idea", "whats_missing", "why_choose"],
},
"minItems": 3,
"maxItems": 3,
}
},
"required": ["options"],
"required": ["suggestions"],
}
def generate_story_setup_options(
self,
*,
story_idea: str,
story_mode: str | None,
story_template: str | None,
brand_context: Dict[str, Any] | None,
user_id: str,
) -> List[Dict[str, Any]]:
"""Generate 3 story setup options from a user's story idea."""
"""Generate a single story setup option from a user's story idea."""
suggested_writing_styles = ['Formal', 'Casual', 'Poetic', 'Humorous', 'Academic', 'Journalistic', 'Narrative']
suggested_story_tones = ['Dark', 'Uplifting', 'Suspenseful', 'Whimsical', 'Melancholic', 'Mysterious', 'Romantic', 'Adventurous']
@@ -167,12 +192,59 @@ Write ONLY the premise sentence(s). Do not write anything else.
suggested_content_ratings = ['G', 'PG', 'PG-13', 'R']
suggested_ending_preferences = ['Happy', 'Tragic', 'Cliffhanger', 'Twist', 'Open-ended', 'Bittersweet']
mode_label = None
if story_mode == "marketing":
mode_label = "Non-fiction marketing story (brand or product campaign)"
elif story_mode == "pure":
mode_label = "Fiction story"
template_label = None
if story_template == "product_story":
template_label = "Product Story"
elif story_template == "brand_manifesto":
template_label = "Brand Manifesto"
elif story_template == "founder_story":
template_label = "Founder Story"
elif story_template == "customer_story":
template_label = "Customer Story"
elif story_template == "short_fiction":
template_label = "Short Fiction"
elif story_template == "long_fiction":
template_label = "Long Fiction"
elif story_template == "anime_fiction":
template_label = "Anime Fiction"
elif story_template == "experimental_fiction":
template_label = "Experimental Fiction"
brand_name = None
writing_tone = None
audience_description = None
if isinstance(brand_context, dict):
brand_name = brand_context.get("brand_name")
writing_tone = brand_context.get("writing_tone")
target_audience = brand_context.get("target_audience")
if isinstance(target_audience, dict):
audience_description = target_audience.get("description") or target_audience.get("summary")
elif isinstance(target_audience, str):
audience_description = target_audience
setup_prompt = f"""\
You are an expert story writer and creative writing assistant. A user has provided the following story idea or information:
You are an expert story writer and creative writing assistant.
{"This is a " + mode_label + "." if mode_label else ""}
{("The user selected the template: " + template_label + ".") if template_label else ""}
The story should stay consistent with the brand and audience context below when relevant:
- Brand name or site: {brand_name or "Not specified"}
- Headline/overall writing tone: {writing_tone or "Not specified"}
- Audience description: {audience_description or "Not specified"}
The user has provided the following story idea or information:
{story_idea}
Based on this story idea, generate exactly 3 different, well-thought-out story setup options. Each option should be CREATIVE, PERSONALIZED, and perfectly tailored to the user's specific story idea.
Based on this story idea, generate exactly 1 well-thought-out story setup option. The setup should be CREATIVE, PERSONALIZED, and perfectly tailored to the user's specific story idea.
**CRITICAL - Creative Freedom:**
- You have COMPLETE FREEDOM to craft personalized values that best fit the user's story idea
@@ -183,7 +255,7 @@ Based on this story idea, generate exactly 3 different, well-thought-out story s
- Narrative POV: "Second Person (You)" or "Omniscient Narrator as Guide" (not just standard options)
- The goal is to create the PERFECT setup for THIS specific story, not to fit into generic categories
Each option should:
The setup should:
1. Have a unique and creative persona that fits the story idea perfectly
2. Define a compelling story setting that brings the idea to life
3. Describe interesting and engaging characters
@@ -212,23 +284,23 @@ Each option should:
**Remember:** These are ONLY suggestions. If a custom value better serves the story idea, CREATE IT!
Return exactly 3 options as a JSON array. Each option must include a "premise" field with the story premise.
Return exactly 1 option as a JSON array with a single object in "options". The object must include a "premise" field with the story premise.
"""
setup_schema = self._build_setup_schema()
try:
logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")
logger.info(f"[StoryWriter] Generating story setup option for user {user_id}")
response = self.load_json_response(
llm_text_gen(prompt=setup_prompt, json_struct=setup_schema, user_id=user_id)
)
options = response.get("options", [])
if len(options) != 3:
logger.warning(f"[StoryWriter] Expected 3 options but got {len(options)}, correcting count")
if len(options) < 3:
raise ValueError(f"Expected 3 options but got {len(options)}")
options = options[:3]
if len(options) != 1:
logger.warning(f"[StoryWriter] Expected 1 option but got {len(options)}, correcting count")
if len(options) < 1:
raise ValueError(f"Expected 1 option but got {len(options)}")
options = options[:1]
for idx, option in enumerate(options):
if not option.get("premise") or not option.get("premise", "").strip():
@@ -262,7 +334,7 @@ Return exactly 3 options as a JSON array. Each option must include a "premise" f
premise += "."
option["premise"] = premise
logger.info(f"[StoryWriter] Generated {len(options)} story setup options with premises for user {user_id}")
logger.info(f"[StoryWriter] Generated {len(options)} story setup option(s) with premise for user {user_id}")
return options
except HTTPException:
raise
@@ -273,3 +345,119 @@ Return exactly 3 options as a JSON array. Each option must include a "premise" f
logger.error(f"[StoryWriter] Error generating story setup options: {exc}")
raise RuntimeError(f"Failed to generate story setup options: {exc}") from exc
def enhance_story_idea(
    self,
    *,
    story_idea: str,
    story_mode: str | None,
    story_template: str | None,
    brand_context: Dict[str, Any] | None,
    user_id: str,
    fiction_variant: str | None = None,
    narrative_energy: str | None = None,
) -> List[Dict[str, Any]]:
    """
    Ask the LLM for three enhanced alternatives to a user's story idea.

    Args:
        story_idea: The user's raw idea text, embedded verbatim in the prompt.
        story_mode: "marketing" or "pure"; any other value adds no mode line.
        story_template: Template key (e.g. "product_story"); unknown keys
            add no template line.
        brand_context: Optional dict read for "brand_name", "writing_tone"
            and "target_audience" (a string, or a dict with "description"
            / "summary").
        user_id: Forwarded to the LLM call and used in log messages.
        fiction_variant: Optional creative focus mentioned in the prompt.
        narrative_energy: Optional target energy mentioned in the prompt.

    Returns:
        Exactly three suggestion items taken from the response's
        "suggestions" array (extra items are dropped).

    Raises:
        HTTPException: Re-raised untouched when the LLM layer raises it.
        RuntimeError: For JSON parsing failures, for fewer than three
            suggestions, and for any other error.
    """
    mode_choices = (
        ("marketing", "Non-fiction marketing story (brand or product campaign)"),
        ("pure", "Fiction story"),
    )
    mode_label = next(
        (label for key, label in mode_choices if story_mode == key), None
    )

    template_choices = (
        ("product_story", "Product Story"),
        ("brand_manifesto", "Brand Manifesto"),
        ("founder_story", "Founder Story"),
        ("customer_story", "Customer Story"),
        ("short_fiction", "Short Fiction"),
        ("long_fiction", "Long Fiction"),
        ("anime_fiction", "Anime Fiction"),
        ("experimental_fiction", "Experimental Fiction"),
    )
    template_label = next(
        (label for key, label in template_choices if story_template == key), None
    )

    # Brand details are optional; anything missing shows as "Not specified".
    brand_name = writing_tone = audience_description = None
    if isinstance(brand_context, dict):
        brand_name = brand_context.get("brand_name")
        writing_tone = brand_context.get("writing_tone")
        target_audience = brand_context.get("target_audience")
        if isinstance(target_audience, str):
            audience_description = target_audience
        elif isinstance(target_audience, dict):
            audience_description = (
                target_audience.get("description") or target_audience.get("summary")
            )

    fiction_focus_line = (
        f'Treat the story as "{fiction_variant}" and lean into that creative focus.'
        if fiction_variant
        else ""
    )
    energy_line = (
        f'Target narrative energy: {narrative_energy}.' if narrative_energy else ""
    )
    mode_sentence = ("This is a " + mode_label + ".") if mode_label else ""
    template_sentence = (
        ("The user selected the template: " + template_label + ".")
        if template_label
        else ""
    )

    enhance_prompt = f"""You are a creative writing coach helping a user refine and expand a story idea.
{mode_sentence}
{template_sentence}
{fiction_focus_line}
{energy_line}
When relevant, keep the idea aligned with this brand and audience context:
- Brand name or site: {brand_name or "Not specified"}
- Headline/overall writing tone: {writing_tone or "Not specified"}
- Audience description: {audience_description or "Not specified"}
The user has written the following story idea or concept:
{story_idea}
Your task is to propose exactly 3 alternative enhanced story idea options.
Each option must:
- Preserve the user's core premise and intent.
- Make the premise clearer and more compelling.
- Surface the central conflict or tension.
- Clarify the main characters and their goals.
- Strengthen the setting and stakes.
- Stay at the "idea" level, not a full outline or beat-by-beat breakdown.
For each option, return three fields:
- "idea": 2-4 sentences describing the improved story idea, suitable for a single textarea input.
- "whats_missing": 2-4 sentences explaining what important details are missing or underspecified in the current brief. Focus on gaps such as: protagonist details, antagonist or opposing force, stakes, setting and time period, audience/age group, subgenre or type of fiction (for example, anime vs grounded sci-fi), language or tone preferences, and any format constraints.
- "why_choose": 1-3 sentences explaining how this option interprets the original idea and why it might be a strong direction for the story.
Do not write a full story outline.
Do not output numbered lists or markdown formatting.
Return a single JSON object with a "suggestions" array of 3 items, where each item has the keys "idea", "whats_missing", and "why_choose"."""

    schema = self._build_idea_enhance_schema()

    try:
        logger.info(f"[StoryWriter] Enhancing story idea with structured suggestions for user {user_id}")
        raw_response = llm_text_gen(prompt=enhance_prompt, json_struct=schema, user_id=user_id)
        response = self.load_json_response(raw_response)
        suggestions = response.get("suggestions", [])
        if len(suggestions) != 3:
            logger.warning(
                f"[StoryWriter] Expected 3 idea suggestions but got {len(suggestions)}, correcting count"
            )
            # Too few cannot be corrected; too many are trimmed to three.
            if len(suggestions) < 3:
                raise ValueError(f"Expected 3 suggestions but got {len(suggestions)}")
            suggestions = suggestions[:3]
        return suggestions
    except HTTPException:
        raise
    except json.JSONDecodeError as exc:
        logger.error(f"[StoryWriter] Failed to parse JSON response for story idea enhancement: {exc}")
        raise RuntimeError(f"Failed to parse story idea enhancement suggestions: {exc}") from exc
    except Exception as exc:
        logger.error(f"[StoryWriter] Error enhancing story idea: {exc}")
        raise RuntimeError(f"Failed to enhance story idea: {exc}") from exc

View File

@@ -3,10 +3,12 @@
from __future__ import annotations
from typing import Any, Dict, List, Optional
import json
from fastapi import HTTPException
from loguru import logger
from services.llm_providers.main_text_generation import llm_text_gen
from services.story_writer.image_generation_service import StoryImageGenerationService
from .base import StoryServiceBase
@@ -36,6 +38,7 @@ class StoryContentMixin(StoryOutlineMixin):
content_rating: str,
ending_preference: str,
story_length: str = "Medium",
anime_bible: Optional[Dict[str, Any]] = None,
user_id: str,
) -> str:
"""Generate the starting section (or full short story)."""
@@ -52,6 +55,19 @@ class StoryContentMixin(StoryOutlineMixin):
ending_preference,
)
anime_bible_context = ""
if anime_bible:
try:
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
except Exception:
serialized_bible = str(anime_bible)
anime_bible_context = f"""
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
{serialized_bible}
"""
outline_text = self._format_outline_for_prompt(outline)
story_length_lower = story_length.lower()
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
@@ -61,6 +77,8 @@ class StoryContentMixin(StoryOutlineMixin):
short_story_prompt = f"""\
{persona_prompt}
{anime_bible_context}
You have a gripping premise in mind:
{premise}
@@ -154,6 +172,285 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
logger.error(f"Story Start Generation Error: {exc}")
raise RuntimeError(f"Failed to generate story start: {exc}") from exc
# ------------------------------------------------------------------ #
# Anime scene refinement
# ------------------------------------------------------------------ #
def refine_anime_scene_text(
    self,
    *,
    scene: Dict[str, Any],
    persona: str,
    story_setting: str,
    character_input: str,
    plot_elements: str,
    writing_style: str,
    story_tone: str,
    narrative_pov: str,
    audience_age_group: str,
    content_rating: str,
    anime_bible: Optional[Dict[str, Any]],
    user_id: str,
) -> Dict[str, Any]:
    """
    Refine one scene's text fields so they respect the anime story bible.

    This is a best-effort operation: if the LLM call fails, the response
    cannot be parsed, or the parsed response is not a JSON object, the
    scene's current values are returned unchanged instead of raising.

    Args:
        scene: Scene dict; read for "scene_number", "title",
            "description", "image_prompt", "audio_narration",
            "character_descriptions" and "key_events".
        persona, story_setting, character_input, plot_elements,
        writing_style, story_tone, narrative_pov, audience_age_group,
        content_rating: Forwarded to ``build_persona_prompt`` (the ending
            preference is fixed to "Neutral").
        anime_bible: Optional bible dict serialized into the prompt.
        user_id: Forwarded to the LLM call.

    Returns:
        A dict with "scene_number" plus the six text fields. Each field
        holds the refined value when the LLM supplied one, otherwise the
        scene's current value.
    """
    persona_prompt = self.build_persona_prompt(
        persona,
        story_setting,
        character_input,
        plot_elements,
        writing_style,
        story_tone,
        narrative_pov,
        audience_age_group,
        content_rating,
        "Neutral",
    )

    anime_bible_context = ""
    if anime_bible:
        try:
            serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
        except Exception:
            # Deliberate best effort: a bible that is not JSON-serializable
            # is still shown to the model via its str() form.
            serialized_bible = str(anime_bible)
        anime_bible_context = f"""
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
{serialized_bible}
"""

    current_title = scene.get("title", "")
    current_description = scene.get("description", "")
    current_image_prompt = scene.get("image_prompt", "")
    current_audio_narration = scene.get("audio_narration", "")
    current_character_descriptions = scene.get("character_descriptions") or []
    current_key_events = scene.get("key_events") or []

    # What is returned whenever refinement cannot be applied.
    unrefined: Dict[str, Any] = {
        "scene_number": scene.get("scene_number"),
        "title": current_title,
        "description": current_description,
        "image_prompt": current_image_prompt,
        "audio_narration": current_audio_narration,
        "character_descriptions": current_character_descriptions,
        "key_events": current_key_events,
    }

    scene_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
            "description": {"type": "string"},
            "image_prompt": {"type": "string"},
            "audio_narration": {"type": "string"},
            "character_descriptions": {
                "type": "array",
                "items": {"type": "string"},
            },
            "key_events": {
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["title", "description", "image_prompt", "audio_narration"],
    }

    prompt = f"""
{persona_prompt}
{anime_bible_context}
You are refining a single anime story scene so that it fully respects the anime story bible for characters, world rules, and visual style.
Current scene:
- Title: {current_title}
- Description: {current_description}
- Image prompt: {current_image_prompt}
- Audio narration: {current_audio_narration}
- Character descriptions: {current_character_descriptions}
- Key events: {current_key_events}
Refine the scene so that:
- Title is concise and evocative
- Description clearly describes what happens in the scene
- Image prompt is vivid, visual, and aligned with the anime bible style and cast
- Audio narration is natural, spoken-friendly text matching the scene
- Character descriptions highlight key visual and personality traits relevant to this moment
- Key events list the main beats of the scene
Respond with JSON matching this schema:
{scene_schema}
"""

    try:
        raw = llm_text_gen(
            prompt=prompt.strip(),
            json_struct=scene_schema,
            user_id=user_id,
        )
        data = self.load_json_response(raw)
    except Exception as exc:
        logger.warning(f"[StoryWriter] Failed to refine anime scene text via LLM: {exc}")
        return unrefined

    # A parsed payload that is not an object (list, string, null) cannot
    # be read field by field; keep the scene as is rather than crash.
    if not isinstance(data, dict):
        logger.warning(
            f"[StoryWriter] Failed to refine anime scene text via LLM: "
            f"expected a JSON object but got {type(data).__name__}"
        )
        return unrefined

    refined = dict(unrefined)
    for field in (
        "title",
        "description",
        "image_prompt",
        "audio_narration",
        "character_descriptions",
        "key_events",
    ):
        value = data.get(field)
        # A missing or null field must not erase the existing scene text.
        if value is not None:
            refined[field] = value
    return refined
# ------------------------------------------------------------------ #
# Anime scene generation from bible
# ------------------------------------------------------------------ #
def generate_anime_scene_from_bible(
    self,
    *,
    premise: str,
    persona: str,
    story_setting: str,
    character_input: str,
    plot_elements: str,
    writing_style: str,
    story_tone: str,
    narrative_pov: str,
    audience_age_group: str,
    content_rating: str,
    anime_bible: Dict[str, Any],
    previous_scenes: Optional[List[Dict[str, Any]]],
    target_scene_number: Optional[int],
    user_id: str,
) -> Dict[str, Any]:
    """
    Generate the next anime scene, constrained by the anime story bible.

    Args:
        premise: Overall story premise embedded in the prompt.
        persona, story_setting, character_input, plot_elements,
        writing_style, story_tone, narrative_pov, audience_age_group,
        content_rating: Forwarded to ``build_persona_prompt`` (the ending
            preference is fixed to "Neutral").
        anime_bible: Bible dict serialized into the prompt.
        previous_scenes: Earlier scenes, oldest first; the most recent six
            are summarized in the prompt for continuity.
        target_scene_number: Scene number to assign. When None it is
            derived from the last previous scene (or 1 with no history).
        user_id: Forwarded to the LLM call.

    Returns:
        A dict with "scene_number", the stripped string fields "title",
        "description", "image_prompt" and "audio_narration", and the list
        fields "character_descriptions" and "key_events".

    Raises:
        RuntimeError: If the LLM call or response parsing fails, or the
            parsed response is not a JSON object.
    """
    persona_prompt = self.build_persona_prompt(
        persona,
        story_setting,
        character_input,
        plot_elements,
        writing_style,
        story_tone,
        narrative_pov,
        audience_age_group,
        content_rating,
        "Neutral",
    )

    try:
        serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
    except Exception:
        # Deliberate best effort: fall back to str() for bibles that are
        # not JSON-serializable.
        serialized_bible = str(anime_bible)

    anime_bible_context = f"""
You have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. You MUST treat it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
{serialized_bible}
"""

    previous_summary_lines: List[str] = []
    if previous_scenes:
        # The new scene continues from the END of the story so far, so the
        # most recent scenes are summarized, not the first ones.
        for earlier in previous_scenes[-6:]:
            if not isinstance(earlier, dict):
                continue
            num = earlier.get("scene_number")
            title = str(earlier.get("title") or "")
            summary = str(earlier.get("description") or "")
            if len(summary) > 200:
                summary = summary[:197] + "..."
            # Keep title and description visibly separated.
            detail = " - ".join(text for text in (title, summary) if text)
            previous_summary_lines.append(f"- Scene {num}: {detail}".strip())

    previous_block = ""
    if previous_summary_lines:
        previous_block = (
            "\nPrevious scenes so far (for continuity, do NOT contradict):\n"
            + "\n".join(previous_summary_lines)
        )

    scene_schema: Dict[str, Any] = {
        "type": "object",
        "properties": {
            "title": {"type": "string"},
            "description": {"type": "string"},
            "image_prompt": {"type": "string"},
            "audio_narration": {"type": "string"},
            "character_descriptions": {
                "type": "array",
                "items": {"type": "string"},
            },
            "key_events": {
                "type": "array",
                "items": {"type": "string"},
            },
        },
        "required": ["title", "description", "image_prompt", "audio_narration"],
    }

    prompt = f"""
{persona_prompt}
{anime_bible_context}
You are generating a brand new anime story scene that must fully respect the anime story bible for characters, world rules, and visual style.
Overall premise:
{premise}
{previous_block}
Your task:
- Create the NEXT SCENE in this story.
- It must be consistent with the anime bible (cast, world rules, visual style).
- It must logically follow from any previous scenes given above.
Design the scene so that:
- Title is concise and evocative.
- Description clearly describes what happens in the scene.
- Image prompt is vivid, visual, and aligned with the anime bible style and cast.
- Audio narration is natural, spoken-friendly text matching the scene.
- Character descriptions highlight key visual and personality traits relevant to this moment.
- Key events list the main beats of the scene.
Respond with JSON matching this schema:
{scene_schema}
"""

    try:
        raw = llm_text_gen(
            prompt=prompt.strip(),
            json_struct=scene_schema,
            user_id=user_id,
        )
        data = self.load_json_response(raw)
    except Exception as exc:
        logger.error(f"[StoryWriter] Failed to generate anime scene from bible: {exc}")
        raise RuntimeError(f"Failed to generate anime scene from bible: {exc}") from exc

    # Report a malformed payload through the same RuntimeError contract
    # instead of leaking an AttributeError from the field reads below.
    if not isinstance(data, dict):
        message = f"expected a JSON object but got {type(data).__name__}"
        logger.error(f"[StoryWriter] Failed to generate anime scene from bible: {message}")
        raise RuntimeError(f"Failed to generate anime scene from bible: {message}")

    next_scene_number = target_scene_number
    if next_scene_number is None:
        if previous_scenes:
            last = previous_scenes[-1]
            try:
                last_num = int(last.get("scene_number") or 0)
            except (AttributeError, TypeError, ValueError):
                # Unusable scene_number: fall back to the scene count.
                last_num = len(previous_scenes)
            next_scene_number = last_num + 1
        else:
            next_scene_number = 1

    def _clean_text(field: str) -> str:
        # A null or non-string field becomes a stripped string, never a crash.
        return str(data.get(field) or "").strip()

    return {
        "scene_number": next_scene_number,
        "title": _clean_text("title"),
        "description": _clean_text("description"),
        "image_prompt": _clean_text("image_prompt"),
        "audio_narration": _clean_text("audio_narration"),
        "character_descriptions": data.get("character_descriptions") or [],
        "key_events": data.get("key_events") or [],
    }
# ------------------------------------------------------------------ #
# Continuation
# ------------------------------------------------------------------ #
@@ -174,6 +471,7 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
audience_age_group: str,
content_rating: str,
ending_preference: str,
anime_bible: Optional[Dict[str, Any]] = None,
story_length: str = "Medium",
user_id: str,
) -> str:
@@ -191,6 +489,19 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
ending_preference,
)
anime_bible_context = ""
if anime_bible:
try:
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
except Exception:
serialized_bible = str(anime_bible)
anime_bible_context = f"""
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
{serialized_bible}
"""
outline_text = self._format_outline_for_prompt(outline)
_, continuation_word_count = self._get_story_length_guidance(story_length)
current_word_count = len(story_text.split()) if story_text else 0
@@ -227,6 +538,8 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
continuation_prompt = f"""\
{persona_prompt}
{anime_bible_context}
You have a gripping premise in mind:
{premise}
@@ -298,6 +611,7 @@ You have written approximately {current_word_count} words so far, leaving approx
audience_age_group: str,
content_rating: str,
ending_preference: str,
anime_bible: Optional[Dict[str, Any]] = None,
user_id: str,
max_iterations: int = 10,
) -> Dict[str, Any]:
@@ -352,6 +666,7 @@ You have written approximately {current_word_count} words so far, leaving approx
audience_age_group=audience_age_group,
content_rating=content_rating,
ending_preference=ending_preference,
anime_bible=anime_bible,
user_id=user_id,
)
if not draft:
@@ -375,6 +690,7 @@ You have written approximately {current_word_count} words so far, leaving approx
audience_age_group=audience_age_group,
content_rating=content_rating,
ending_preference=ending_preference,
anime_bible=anime_bible,
user_id=user_id,
)
if continuation:
@@ -420,6 +736,7 @@ You have written approximately {current_word_count} words so far, leaving approx
height: int = 1024,
model: Optional[str] = None,
db: Optional[Session] = None,
anime_bible: Optional[Dict[str, Any]] = None,
) -> List[Dict[str, Any]]:
"""Generate images for story scenes."""
image_service = StoryImageGenerationService()
@@ -431,5 +748,6 @@ You have written approximately {current_word_count} words so far, leaving approx
height=height,
model=model,
db=db,
anime_bible=anime_bible,
)

View File

@@ -0,0 +1,133 @@
"""
Story Project Service
Service layer for managing Story Studio project persistence.
Modeled after PodcastService for a consistent project API.
"""
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple
from sqlalchemy import and_, desc
from sqlalchemy.orm import Session
from models.story_project_models import StoryProject
class StoryProjectService:
    """
    Persistence operations for Story Studio projects.

    Every lookup is scoped by both ``user_id`` and ``project_id``, and
    every write is committed on the session passed to the constructor.
    """

    def __init__(self, db: Session) -> None:
        """Keep the session used by all operations of this service."""
        self.db = db

    def create_project(
        self,
        user_id: str,
        project_id: str,
        title: Optional[str] = None,
        story_mode: Optional[str] = None,
        story_template: Optional[str] = None,
        **kwargs: Any,
    ) -> StoryProject:
        """
        Insert a new project with status "draft" and phase "setup".

        Extra keyword arguments are passed straight to the ``StoryProject``
        constructor. Returns the committed, refreshed project.
        """
        new_project = StoryProject(
            project_id=project_id,
            user_id=user_id,
            title=title,
            story_mode=story_mode,
            story_template=story_template,
            status="draft",
            current_phase="setup",
            **kwargs,
        )
        session = self.db
        session.add(new_project)
        session.commit()
        session.refresh(new_project)
        return new_project

    def get_project(self, user_id: str, project_id: str) -> Optional[StoryProject]:
        """Return the project matching both ids, or None when there is none."""
        ownership = and_(
            StoryProject.project_id == project_id,
            StoryProject.user_id == user_id,
        )
        matching = self.db.query(StoryProject).filter(ownership)
        return matching.first()

    def update_project(
        self,
        user_id: str,
        project_id: str,
        **updates: Any,
    ) -> Optional[StoryProject]:
        """
        Apply attribute updates to a project and commit them.

        Keys that are not attributes of the project are ignored.
        ``updated_at`` is always set to the current ``datetime.utcnow()``.
        Returns the refreshed project, or None when it does not exist.
        """
        project = self.get_project(user_id, project_id)
        if not project:
            return None

        for field_name in updates:
            if not hasattr(project, field_name):
                continue
            setattr(project, field_name, updates[field_name])

        project.updated_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(project)
        return project

    def list_projects(
        self,
        user_id: str,
        status: Optional[str] = None,
        favorites_only: bool = False,
        limit: int = 50,
        offset: int = 0,
        order_by: str = "updated_at",
    ) -> Tuple[List[StoryProject], int]:
        """
        List a user's projects, newest first, with a total count.

        Args:
            user_id: Owner whose projects are listed.
            status: When truthy, keep only projects with this status.
            favorites_only: When True, keep only favorite projects.
            limit: Maximum number of projects returned.
            offset: Number of projects skipped before the page starts.
            order_by: "created_at" sorts by creation time; any other value
                sorts by ``updated_at``.

        Returns:
            ``(projects, total)`` where ``total`` counts all projects
            matching the filters, before offset/limit are applied.
        """
        query = self.db.query(StoryProject).filter(StoryProject.user_id == user_id)
        if status:
            query = query.filter(StoryProject.status == status)
        if favorites_only:
            query = query.filter(StoryProject.is_favorite.is_(True))

        # Counted before ordering and pagination.
        total = query.count()

        sort_column = (
            StoryProject.created_at
            if order_by == "created_at"
            else StoryProject.updated_at
        )
        page = query.order_by(desc(sort_column)).offset(offset).limit(limit).all()
        return page, total

    def delete_project(self, user_id: str, project_id: str) -> bool:
        """Delete the project; return False when it does not exist."""
        project = self.get_project(user_id, project_id)
        if not project:
            return False

        self.db.delete(project)
        self.db.commit()
        return True

    def toggle_favorite(self, user_id: str, project_id: str) -> Optional[StoryProject]:
        """
        Flip the project's ``is_favorite`` flag and commit.

        Also sets ``updated_at``. Returns the refreshed project, or None
        when it does not exist.
        """
        project = self.get_project(user_id, project_id)
        if not project:
            return None

        currently_favorite = project.is_favorite
        project.is_favorite = not currently_favorite
        project.updated_at = datetime.utcnow()
        self.db.commit()
        self.db.refresh(project)
        return project

    def update_status(
        self,
        user_id: str,
        project_id: str,
        status: str,
    ) -> Optional[StoryProject]:
        """Set the project's status by delegating to ``update_project``."""
        return self.update_project(user_id, project_id, status=status)

View File

@@ -149,7 +149,7 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
try:
path = request.url.path
except Exception:
pass
path = ""
db = None
try:
@@ -159,8 +159,16 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
api_monitor = DatabaseAPIMonitor()
# Safe User-Agent access
user_agent = None
try:
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
user_agent = request.headers.get('user-agent')
except:
pass
# Detect if this is an API call that should be rate limited
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
api_provider = api_monitor.detect_api_provider(path, user_agent)
if not api_provider:
return None
@@ -236,9 +244,28 @@ async def monitoring_middleware(request: Request, call_next):
user_id = None
try:
# PRIORITY 1: Check request.state.user_id (set by API key injection middleware)
if hasattr(request.state, 'user_id') and request.state.user_id:
user_id = request.state.user_id
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
if hasattr(request.state, 'user_id'):
# Directly check and convert without accessing attribute if None
raw_user_id = request.state.user_id
# Defensive check for Depends object or other complex types
if raw_user_id is not None:
# If it's a string, use it
if isinstance(raw_user_id, str):
user_id = raw_user_id
# If it has a dependency attribute (likely a Depends object), ignore it
elif hasattr(raw_user_id, 'dependency'):
logger.warning(f"Monitoring: request.state.user_id is a Depends object, ignoring.")
user_id = None
# Try to convert to string if it's a simple type
else:
try:
user_id = str(raw_user_id)
except:
user_id = None
if user_id:
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
# PRIORITY 2: Check query parameters
elif hasattr(request, 'query_params') and 'user_id' in request.query_params:
@@ -247,20 +274,23 @@ async def monitoring_middleware(request: Request, call_next):
user_id = request.path_params['user_id']
# PRIORITY 3: Check headers for user identification
elif 'x-user-id' in request.headers:
user_id = request.headers['x-user-id']
elif 'x-user-email' in request.headers:
user_id = request.headers['x-user-email'] # Use email as user identifier
elif 'x-session-id' in request.headers:
user_id = request.headers['x-session-id'] # Use session as fallback
# Check for authorization header with user info
elif 'authorization' in request.headers:
# Auth middleware should have set request.state.user_id
# If not, this indicates an authentication failure (likely expired token)
# Log at debug level to reduce noise - expired tokens are expected
# But we can try to decode token if we really needed to, but let's rely on auth middleware
pass
elif hasattr(request, 'headers') and hasattr(request.headers, 'get'):
try:
if request.headers.get('x-user-id'):
user_id = request.headers.get('x-user-id')
elif request.headers.get('x-user-email'):
user_id = request.headers.get('x-user-email')
elif request.headers.get('x-session-id'):
user_id = request.headers.get('x-session-id')
# Check for authorization header with user info
elif request.headers.get('authorization'):
# Auth middleware should have set request.state.user_id
# If not, this indicates an authentication failure (likely expired token)
# Log at debug level to reduce noise - expired tokens are expected
pass
except Exception as e:
logger.debug(f"Error accessing request headers: {e}")
except Exception as e:
logger.debug(f"Error extracting user ID: {e}")
@@ -269,7 +299,11 @@ async def monitoring_middleware(request: Request, call_next):
# Get database session if user identified
db = None
if user_id:
db = get_session_for_user(user_id)
try:
db = get_session_for_user(user_id)
except Exception as e:
logger.error(f"Failed to get database session for user {user_id}: {e}")
db = None
# Capture request body for usage tracking (read once, safely)
request_body = None
@@ -291,29 +325,52 @@ async def monitoring_middleware(request: Request, call_next):
request_body = None
# Check usage limits before processing
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
if limit_response:
if db: db.close()
return limit_response
# Skip for OPTIONS requests
try:
if request.method != "OPTIONS":
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
if limit_response:
if db: db.close()
return limit_response
except Exception as e:
logger.error(f"Error in usage limits middleware: {e}")
# Continue processing if usage check fails (fail open)
try:
response = await call_next(request)
status_code = response.status_code
duration = time.time() - start_time
# Capture response body for usage tracking
# Extract response body safely for usage tracking
response_body = None
try:
if hasattr(response, 'body'):
response_body = response.body.decode('utf-8') if response.body else None
elif hasattr(response, '_content'):
response_body = response._content.decode('utf-8') if response._content else None
except:
pass
if hasattr(response, 'body'):
response_body = response.body.decode('utf-8') if response.body else None
elif hasattr(response, '_content'):
response_body = response._content.decode('utf-8') if response._content else None
# Track API usage if this is an API call to external providers
api_monitor = DatabaseAPIMonitor()
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
# Safe URL path access
try:
path = request.url.path
except:
path = ""
# Safe User-Agent access - handle case where headers might be a Depends object
user_agent = None
try:
# Defensive check: ensure request.headers is a valid headers object
# Some dependency injection failures replace request attributes with Depends objects
if hasattr(request, 'headers'):
headers_obj = request.headers
# Check if it has a 'get' method (like a dict or Headers object)
if hasattr(headers_obj, 'get') and callable(headers_obj.get):
user_agent = headers_obj.get('user-agent')
except:
pass
api_provider = api_monitor.detect_api_provider(path, user_agent)
if api_provider and user_id:
logger.info(f"Detected API call: {request.url.path} -> {api_provider.value} for user: {user_id}")
try:
@@ -326,7 +383,7 @@ async def monitoring_middleware(request: Request, call_next):
await usage_service.track_api_usage(
user_id=user_id,
provider=api_provider,
endpoint=request.url.path,
endpoint=path,
method=request.method,
model_used=usage_metrics.get('model_used'),
tokens_input=usage_metrics.get('tokens_input', 0),
@@ -335,7 +392,7 @@ async def monitoring_middleware(request: Request, call_next):
status_code=status_code,
request_size=len(request_body) if request_body else None,
response_size=len(response_body) if response_body else None,
user_agent=request.headers.get('user-agent'),
user_agent=user_agent,
ip_address=request.client.host if request.client else None,
search_count=usage_metrics.get('search_count', 0),
image_count=usage_metrics.get('image_count', 0),

View File

@@ -0,0 +1,487 @@
import os
import stripe
from typing import Optional, Dict, Any
from loguru import logger
from fastapi import HTTPException
from sqlalchemy.orm import Session
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning
from services.subscription.pricing_service import PricingService
from datetime import datetime
# Stripe Price IDs keyed by (tier value, billing-cycle value).
# Only the combinations listed here can be resolved to a Stripe price.
STRIPE_PLAN_PRICE_MAPPING = {
    (SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value): "price_1T2lWHR2EuR7zQJepLIVQ1EJ",
    (SubscriptionTier.PRO.value, BillingCycle.MONTHLY.value): "price_1T2ljDR2EuR7zQJeuS317KCj",
}

# Reverse lookup derived from the mapping above: Stripe Price ID -> the
# tier / billing-cycle enum members it represents.
STRIPE_PRICE_TO_PLAN = {
    stripe_price: {
        "tier": SubscriptionTier(tier_value),
        "billing_cycle": BillingCycle(cycle_value),
    }
    for (tier_value, cycle_value), stripe_price in STRIPE_PLAN_PRICE_MAPPING.items()
}
class StripeService:
def __init__(self, db: Session):
    """Bind the service to a DB session and read Stripe credentials from the environment.

    Args:
        db: SQLAlchemy session used for all subscription reads and writes.

    Reads ``STRIPE_SECRET_KEY`` and ``STRIPE_WEBHOOK_SECRET``. When the secret
    key is present it is assigned to ``stripe.api_key`` (an attribute of the
    imported ``stripe`` module); when it is missing only a warning is logged.
    """
    self.db = db
    self.api_key = os.environ.get("STRIPE_SECRET_KEY")
    self.webhook_secret = os.environ.get("STRIPE_WEBHOOK_SECRET")

    if self.api_key:
        stripe.api_key = self.api_key
        return

    logger.warning("STRIPE_SECRET_KEY is not set. Stripe integration will not work.")
def _get_price_id_for_plan(self, tier: SubscriptionTier, billing_cycle: BillingCycle) -> str:
    """Return the Stripe Price ID configured for a tier / billing-cycle pair.

    Args:
        tier: Subscription tier being purchased.
        billing_cycle: Billing cycle being purchased.

    Returns:
        The Price ID stored in ``STRIPE_PLAN_PRICE_MAPPING``.

    Raises:
        HTTPException: 400 when the pair has no (non-empty) price configured.
    """
    configured_price = STRIPE_PLAN_PRICE_MAPPING.get((tier.value, billing_cycle.value))
    if configured_price:
        return configured_price

    logger.error(f"No Stripe price configured for tier={tier.value}, billing_cycle={billing_cycle.value}")
    raise HTTPException(status_code=400, detail="Payment plan is not configured")
def _get_plan_for_price_id(self, price_id: str) -> tuple[SubscriptionPlan, BillingCycle]:
    """Resolve a Stripe Price ID to an internal plan row and billing cycle.

    Args:
        price_id: Stripe Price ID, e.g. taken from a subscription item.

    Returns:
        ``(plan, billing_cycle)`` where ``plan`` is the active plan of the
        mapped tier with the lowest ``price_monthly``.

    Raises:
        HTTPException: 400 when the price is not in ``STRIPE_PRICE_TO_PLAN``
            or no active plan exists for the mapped tier.
    """
    entry = STRIPE_PRICE_TO_PLAN.get(price_id)
    if not entry:
        logger.error(f"Unknown Stripe price_id: {price_id}")
        raise HTTPException(status_code=400, detail="Unknown payment price configuration")

    tier, billing_cycle = entry["tier"], entry["billing_cycle"]

    active_plans_for_tier = self.db.query(SubscriptionPlan).filter(
        SubscriptionPlan.tier == tier,
        SubscriptionPlan.is_active == True,  # noqa: E712 - SQLAlchemy column comparison
    )
    plan = active_plans_for_tier.order_by(SubscriptionPlan.price_monthly).first()
    if plan:
        return plan, billing_cycle

    logger.error(f"No subscription plan found for tier={tier.value}")
    raise HTTPException(status_code=400, detail="Subscription plan not found for payment price")
def _get_or_create_customer(self, user_id: str, email: Optional[str] = None) -> str:
    """Return the Stripe customer ID for a user, creating the customer if needed.

    Lookup order: the ID stored on the user's subscription row, then a Stripe
    customer with the given email, then a Stripe customer whose
    ``metadata['user_id']`` matches, and finally a newly created customer.
    Whenever a subscription row exists, the resolved ID is written back to it.

    Args:
        user_id: Internal user identifier.
        email: Optional email used both for lookup and for the new customer.

    Returns:
        The Stripe customer ID.

    Raises:
        HTTPException: 500 when creating the customer fails.
    """
    subscription = (
        self.db.query(UserSubscription)
        .filter(UserSubscription.user_id == user_id)
        .first()
    )
    if subscription and subscription.stripe_customer_id:
        return subscription.stripe_customer_id

    def link_existing(found) -> str:
        # Persist the discovered customer on the subscription row, if there is one.
        if subscription:
            subscription.stripe_customer_id = found.id
            self.db.commit()
        return found.id

    # Search failures are logged only; we then fall through to creation.
    try:
        if email:
            by_email = stripe.Customer.list(email=email, limit=1)
            if by_email and len(by_email.data) > 0:
                return link_existing(by_email.data[0])

        by_metadata = stripe.Customer.search(
            query=f"metadata['user_id']:'{user_id}'",
            limit=1
        )
        if by_metadata and len(by_metadata.data) > 0:
            return link_existing(by_metadata.data[0])
    except Exception as e:
        logger.error(f"Error searching Stripe customer: {e}")

    try:
        create_kwargs = {"metadata": {"user_id": user_id}}
        if email:
            create_kwargs["email"] = email
        created = stripe.Customer.create(**create_kwargs)

        # No subscription row is created here when none exists; the ID is
        # only stored if a row is already present.
        if subscription:
            subscription.stripe_customer_id = created.id
        self.db.commit()
        return created.id
    except Exception as e:
        logger.error(f"Error creating Stripe customer: {e}")
        raise HTTPException(status_code=500, detail="Failed to create payment profile")
def create_checkout_session(
    self,
    user_id: str,
    tier: SubscriptionTier,
    billing_cycle: BillingCycle,
    success_url: str,
    cancel_url: str,
    user_email: Optional[str] = None,
) -> str:
    """Create a Stripe Checkout Session in subscription mode and return its URL.

    Args:
        user_id: Internal user identifier; stored in session and subscription metadata.
        tier: Tier being purchased.
        billing_cycle: Billing cycle being purchased.
        success_url: Redirect target after successful payment.
        cancel_url: Redirect target when the user aborts checkout.
        user_email: Optional email forwarded to the customer lookup/creation.

    Returns:
        The hosted Checkout URL.

    Raises:
        HTTPException: 500 when Stripe is not configured or session creation
            fails; errors from price/customer resolution propagate unchanged.
    """
    if not self.api_key:
        raise HTTPException(status_code=500, detail="Payment service not configured")

    price_id = self._get_price_id_for_plan(tier, billing_cycle)
    customer_id = self._get_or_create_customer(user_id, user_email)

    line_item: Dict[str, Any] = {"price": price_id}

    # A quantity is attached unless the price is detected as metered; if the
    # price cannot be inspected we fall back to quantity 1.
    try:
        recurring = getattr(stripe.Price.retrieve(price_id), "recurring", None)
        if not recurring:
            usage_type = None
        elif isinstance(recurring, dict):
            usage_type = recurring.get("usage_type")
        else:
            usage_type = getattr(recurring, "usage_type", None)

        if usage_type == "metered":
            logger.info(f"Detected metered price {price_id}; omitting quantity in Checkout line item")
        else:
            line_item["quantity"] = 1
    except Exception as e:
        logger.error(f"Error inspecting Stripe price {price_id}: {e}")
        line_item["quantity"] = 1

    session_metadata = {
        "user_id": user_id,
        "price_id": price_id,
    }
    subscription_metadata = {
        "metadata": {
            "user_id": user_id,
        }
    }
    try:
        checkout_session = stripe.checkout.Session.create(
            customer=customer_id,
            payment_method_types=["card"],
            line_items=[line_item],
            mode="subscription",
            success_url=success_url,
            cancel_url=cancel_url,
            metadata=session_metadata,
            subscription_data=subscription_metadata,
            allow_promotion_codes=True,
        )
        return checkout_session.url
    except Exception as e:
        logger.error(f"Error creating checkout session: {e}")
        raise HTTPException(status_code=500, detail=str(e))
def create_portal_session(self, user_id: str, return_url: str) -> str:
    """Create a Stripe Customer Portal session and return its URL.

    The customer ID is taken from the user's subscription row. When it is
    missing, Stripe is searched for a customer whose ``metadata['user_id']``
    matches, and the result is written back to the row if one exists.

    Args:
        user_id: Internal user identifier.
        return_url: URL Stripe redirects to when the user leaves the portal.

    Returns:
        The hosted billing-portal URL.

    Raises:
        HTTPException: 400 when no Stripe customer can be found for the user;
            500 when Stripe is not configured, the lookup fails, or the
            portal session cannot be created.
    """
    if not self.api_key:
        raise HTTPException(status_code=500, detail="Payment service not configured")

    subscription = self.db.query(UserSubscription).filter(
        UserSubscription.user_id == user_id
    ).first()

    if subscription and subscription.stripe_customer_id:
        customer_id = subscription.stripe_customer_id
    else:
        # Try to find the customer by user_id if it is not in the DB.
        try:
            customers = stripe.Customer.search(query=f"metadata['user_id']:'{user_id}'", limit=1)
            if customers and len(customers.data) > 0:
                customer_id = customers.data[0].id
                # Update DB while we're at it
                if subscription:
                    subscription.stripe_customer_id = customer_id
                    self.db.commit()
            else:
                raise HTTPException(status_code=400, detail="No billing profile found for this user")
        except HTTPException:
            # Let the deliberate 400 above reach the client instead of being
            # rewritten into a generic 500 by the handler below.
            raise
        except Exception as e:
            logger.error(f"Error finding customer for portal: {e}")
            raise HTTPException(status_code=500, detail="Failed to access billing portal") from e

    try:
        portal_session = stripe.billing_portal.Session.create(
            customer=customer_id,
            return_url=return_url,
        )
        return portal_session.url
    except Exception as e:
        logger.error(f"Error creating portal session: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def handle_webhook(self, payload: bytes, sig_header: str):
    """Verify a Stripe webhook and dispatch it to the matching handler.

    Args:
        payload: Raw request body as received from Stripe.
        sig_header: Value of the Stripe signature header.

    Returns:
        ``{"status": "success"}`` for every verified event, including event
        types that have no handler; ``None`` when no webhook secret is
        configured (the event is ignored).

    Raises:
        HTTPException: 400 when the payload or its signature is invalid.
    """
    if not self.webhook_secret:
        logger.warning("STRIPE_WEBHOOK_SECRET not set. Ignoring webhook.")
        return

    try:
        event = stripe.Webhook.construct_event(payload, sig_header, self.webhook_secret)
    except ValueError as e:
        logger.error(f"Invalid payload: {e}")
        raise HTTPException(status_code=400, detail="Invalid payload")
    except stripe.error.SignatureVerificationError as e:
        logger.error(f"Invalid signature: {e}")
        raise HTTPException(status_code=400, detail="Invalid signature")

    event_type = event["type"]
    data = event["data"]["object"]
    logger.info(f"Received Stripe webhook: {event_type}")

    # Exact-match event types; fraud warnings are matched by prefix below.
    exact_handlers = {
        "checkout.session.completed": self._handle_checkout_completed,
        "invoice.payment_succeeded": self._handle_invoice_payment_succeeded,
        "invoice.payment_failed": self._handle_invoice_payment_failed,
        "customer.subscription.updated": self._handle_subscription_updated,
        "customer.subscription.deleted": self._handle_subscription_deleted,
    }
    handler = exact_handlers.get(event_type)
    if handler is None and event_type.startswith("radar.early_fraud_warning."):
        handler = self._handle_early_fraud_warning

    if handler is not None:
        await handler(data)

    return {"status": "success"}
async def _handle_checkout_completed(self, session: Dict[str, Any]):
    """Activate the user's subscription after a completed Checkout Session.

    Reads ``user_id`` from the session metadata, retrieves the Stripe
    subscription to find the purchased price, and records it via
    ``_update_user_subscription``. Sessions without a ``user_id`` or without
    a subscription are ignored. Processing errors are logged, not raised.

    Args:
        session: The ``checkout.session`` object from the webhook event.
    """
    # ``metadata`` may be present but null; ``.get("metadata", {})`` would
    # return None in that case and crash on the chained ``.get``.
    metadata = session.get("metadata") or {}
    user_id = metadata.get("user_id")
    customer_id = session.get("customer")
    subscription_id = session.get("subscription")

    if not user_id:
        logger.error("No user_id in checkout session metadata")
        return

    logger.info(f"Checkout completed for user {user_id}")

    if not subscription_id:
        return

    try:
        # The first subscription item carries the purchased price; the
        # price -> plan mapping is resolved inside _update_user_subscription.
        sub = stripe.Subscription.retrieve(subscription_id)
        price_id = sub['items']['data'][0]['price']['id']

        self._update_user_subscription(
            user_id,
            stripe_customer_id=customer_id,
            stripe_subscription_id=subscription_id,
            status="active",
            price_id=price_id
        )
    except Exception as e:
        logger.error(f"Error processing checkout subscription: {e}")
async def _handle_invoice_payment_succeeded(self, invoice: Dict[str, Any]):
    """Mark a subscription active after a successful invoice payment.

    The subscription row is located by Stripe subscription ID or, when the
    invoice carries one, by Stripe customer ID. The period end is refreshed
    from the first invoice line when available. Invoices without a
    subscription are ignored.

    Args:
        invoice: The ``invoice`` object from the webhook event.
    """
    subscription_id = invoice.get("subscription")
    customer_id = invoice.get("customer")

    if not subscription_id:
        return

    # Only match on customer when one is present: comparing the column with
    # None renders ``IS NULL`` and would match rows of unrelated users that
    # simply have no Stripe customer yet.
    criteria = UserSubscription.stripe_subscription_id == subscription_id
    if customer_id:
        criteria = criteria | (UserSubscription.stripe_customer_id == customer_id)

    subscription = self.db.query(UserSubscription).filter(criteria).first()
    if not subscription:
        return

    logger.info(f"Payment succeeded for user {subscription.user_id}")
    subscription.status = UsageStatus.ACTIVE
    subscription.is_active = True

    # Update period end from the first invoice line, if there is one.
    lines = invoice.get("lines")
    line_items = lines["data"] if lines else []
    if line_items:
        period_end = line_items[0]["period"]["end"]
        # Stripe timestamps are epoch seconds; store naive UTC to agree with
        # the utcnow()/utcfromtimestamp() values written elsewhere in this file.
        subscription.current_period_end = datetime.utcfromtimestamp(period_end)

    self.db.commit()
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
    """Mark a subscription past due after a failed invoice payment.

    The subscription row is located by Stripe subscription ID or, when the
    invoice carries one, by Stripe customer ID. Invoices without a
    subscription are ignored.

    Args:
        invoice: The ``invoice`` object from the webhook event.
    """
    subscription_id = invoice.get("subscription")
    customer_id = invoice.get("customer")

    if not subscription_id:
        return

    # Only match on customer when one is present: comparing the column with
    # None renders ``IS NULL`` and could deactivate an unrelated user's row.
    criteria = UserSubscription.stripe_subscription_id == subscription_id
    if customer_id:
        criteria = criteria | (UserSubscription.stripe_customer_id == customer_id)

    subscription = self.db.query(UserSubscription).filter(criteria).first()
    if not subscription:
        return

    logger.warning(f"Payment failed for user {subscription.user_id}")
    subscription.status = UsageStatus.PAST_DUE
    subscription.is_active = False
    self.db.commit()
async def _handle_subscription_updated(self, subscription_obj: Dict[str, Any]):
    """Mirror a Stripe subscription status change onto the local row.

    Active/trialing subscriptions become ACTIVE; past-due, unpaid and
    incomplete ones become PAST_DUE and inactive; canceled ones become
    CANCELLED, inactive and non-renewing. Other statuses leave the row
    untouched. Nothing happens when no row matches the Stripe ID.

    Args:
        subscription_obj: The ``subscription`` object from the webhook event.
    """
    stripe_sub_id = subscription_obj.get("id")
    status = subscription_obj.get("status")

    record = (
        self.db.query(UserSubscription)
        .filter(UserSubscription.stripe_subscription_id == stripe_sub_id)
        .first()
    )
    if not record:
        return

    logger.info(f"Subscription {stripe_sub_id} updated to {status}")

    if status in ("active", "trialing"):
        record.status = UsageStatus.ACTIVE
        record.is_active = True
    elif status in ("past_due", "unpaid", "incomplete", "incomplete_expired"):
        record.status = UsageStatus.PAST_DUE
        record.is_active = False
    elif status == "canceled":
        record.status = UsageStatus.CANCELLED
        record.is_active = False
        record.auto_renew = False

    self.db.commit()
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
    """Mark the local subscription cancelled when Stripe deletes it.

    Nothing happens when no row matches the Stripe subscription ID.

    Args:
        subscription_obj: The ``subscription`` object from the webhook event.
    """
    stripe_sub_id = subscription_obj.get("id")

    record = (
        self.db.query(UserSubscription)
        .filter(UserSubscription.stripe_subscription_id == stripe_sub_id)
        .first()
    )
    if not record:
        return

    logger.info(f"Subscription {stripe_sub_id} deleted")
    # NOTE(review): original author flagged that UsageStatus.CANCELLED should
    # be confirmed to exist on the enum.
    record.status = UsageStatus.CANCELLED
    record.is_active = False
    record.auto_renew = False
    self.db.commit()
async def _handle_early_fraud_warning(self, warning_obj: Dict[str, Any]):
    """Create or refresh a FraudWarning row from a Radar early-fraud-warning event.

    The related charge is fetched (when a charge ID and API key are
    available) to fill in amount, currency and the ``user_id`` from the
    charge metadata; retrieval errors are logged and the warning's own
    amount/currency are used instead. An existing row is updated and set
    back to ``"open"``; otherwise a new row is inserted, but only when the
    warning references a charge. Warnings without an ID are ignored.

    Args:
        warning_obj: The early-fraud-warning object from the webhook event.
    """
    efw_id = warning_obj.get("id")
    if not efw_id:
        return

    charge_id = warning_obj.get("charge")
    payment_intent_id = warning_obj.get("payment_intent")

    created_ts = warning_obj.get("created")
    if created_ts:
        created_at = datetime.utcfromtimestamp(created_ts)
    else:
        created_at = datetime.utcnow()

    amount = 0
    currency = ""
    user_id = None
    charge_data: Dict[str, Any] = {}

    if charge_id and self.api_key:
        try:
            charge = stripe.Charge.retrieve(charge_id)
            charge_data = charge.to_dict() if hasattr(charge, "to_dict") else dict(charge)
            amount = charge_data.get("amount") or 0
            currency = charge_data.get("currency") or ""
            charge_metadata = charge_data.get("metadata") or {}
            user_id = charge_metadata.get("user_id")
        except Exception as e:
            logger.error(f"Error retrieving charge for early fraud warning {efw_id}: {e}")

    # Fall back to the values on the warning itself when the charge gave none.
    amount = amount or (warning_obj.get("amount") or 0)
    currency = currency or (warning_obj.get("currency") or "")

    existing = self.db.query(FraudWarning).filter(FraudWarning.id == efw_id).first()

    metadata_payload: Dict[str, Any] = {"early_fraud_warning": warning_obj}
    if charge_data:
        metadata_payload["charge"] = charge_data

    if not existing:
        if not charge_id:
            return
        self.db.add(
            FraudWarning(
                id=efw_id,
                charge_id=charge_id,
                payment_intent_id=payment_intent_id,
                user_id=user_id,
                amount=amount or 0,
                currency=currency or "",
                status="open",
                action="none",
                meta_info=metadata_payload,
                created_at=created_at,
            )
        )
    else:
        # Keep previously stored values wherever this event provides nothing new.
        existing.charge_id = charge_id or existing.charge_id
        existing.payment_intent_id = payment_intent_id or existing.payment_intent_id
        if user_id:
            existing.user_id = user_id
        if amount:
            existing.amount = amount
        if currency:
            existing.currency = currency
        existing.status = "open"
        existing.meta_info = metadata_payload

    self.db.commit()
def _update_user_subscription(
    self,
    user_id: str,
    stripe_customer_id: str,
    stripe_subscription_id: str,
    status: str,
    price_id: str,
):
    """Create or update the user's subscription row after a Stripe purchase.

    Args:
        user_id: Internal user identifier.
        stripe_customer_id: Stripe customer to link to the row.
        stripe_subscription_id: Stripe subscription to link to the row.
        status: Stripe-side status; ``"active"`` maps to ACTIVE, anything
            else to SUSPENDED.
        price_id: Stripe Price ID used to resolve the internal plan.

    Raises:
        HTTPException: 400 from ``_get_plan_for_price_id`` when the price or
            its plan cannot be resolved.
    """
    plan, billing_cycle = self._get_plan_for_price_id(price_id)

    subscription = (
        self.db.query(UserSubscription)
        .filter(UserSubscription.user_id == user_id)
        .first()
    )

    is_active = status == "active"
    usage_status = UsageStatus.ACTIVE if is_active else UsageStatus.SUSPENDED

    if not subscription:
        now = datetime.utcnow()
        subscription = UserSubscription(
            user_id=user_id,
            plan_id=plan.id,
            billing_cycle=billing_cycle,
            current_period_start=now,
            current_period_end=now,
            status=usage_status,
            is_active=is_active,
            auto_renew=True,
        )
        self.db.add(subscription)
    else:
        subscription.plan_id = plan.id
        subscription.billing_cycle = billing_cycle
        # Keep status in step with is_active; otherwise a re-subscribing user
        # would stay CANCELLED/PAST_DUE while is_active is True.
        subscription.status = usage_status
        subscription.is_active = is_active
        if is_active:
            # Match the new-row branch, which always starts with auto_renew on.
            subscription.auto_renew = True

    subscription.stripe_customer_id = stripe_customer_id
    subscription.stripe_subscription_id = stripe_subscription_id

    self.db.commit()

View File

@@ -39,9 +39,34 @@ def _generate_simple_infinitetalk_prompt(
# Build a balanced prompt: scene description + simple motion hint
parts = []
# Start with the main subject/scene
# Add scene context
if title and len(title) > 5 and title.lower() not in ("scene", "podcast", "episode"):
parts.append(title)
# Add analysis context
analysis = story_context.get("analysis", {})
if analysis:
content_type = analysis.get("content_type")
if content_type:
parts.append(f"Style: {content_type}")
# Audience helps define the formality/vibe
audience = analysis.get("audience")
if audience:
# Just use first few words of audience to keep it short
short_audience = " ".join(audience.split()[:3])
parts.append(f"For: {short_audience}")
# Add bible context if available
bible = story_context.get("bible", {})
if bible:
host_persona = bible.get("host_persona")
tone = bible.get("tone")
if host_persona:
parts.append(f"Host: {host_persona}")
if tone:
parts.append(f"Tone: {tone}")
elif description:
# Take first sentence or first 60 chars
desc_part = description.split('.')[0][:60].strip()

View File

@@ -52,6 +52,46 @@ def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str,
image_prompt = (scene_data.get("image_prompt") or "").strip()
tone = (story_context.get("story_tone") or "story").strip()
setting = (story_context.get("story_setting") or "the scene").strip()
anime_bible = story_context.get("anime_bible") or {}
anime_style_parts = []
if isinstance(anime_bible, dict):
visual_style = anime_bible.get("visual_style") or {}
world = anime_bible.get("world") or {}
main_cast = anime_bible.get("main_cast") or []
style_preset = visual_style.get("style_preset")
camera_style = visual_style.get("camera_style")
color_mood = visual_style.get("color_mood")
lighting = visual_style.get("lighting")
line_style = visual_style.get("line_style")
extra_tags = visual_style.get("extra_tags") or []
if style_preset:
anime_style_parts.append(f"Follow {style_preset} anime visual style.")
if camera_style:
anime_style_parts.append(f"Use camera style: {camera_style}.")
if color_mood:
anime_style_parts.append(f"Color mood: {color_mood}.")
if lighting:
anime_style_parts.append(f"Lighting: {lighting}.")
if line_style:
anime_style_parts.append(f"Line art: {line_style}.")
if extra_tags:
anime_style_parts.append("Style tags: " + ", ".join(str(tag) for tag in extra_tags[:6]))
if world:
setting_desc = world.get("setting")
if setting_desc:
anime_style_parts.append(f"World context: {setting_desc}.")
if main_cast:
names = [c.get("name") for c in main_cast if isinstance(c, dict) and c.get("name")]
if names:
joined = ", ".join(names[:4])
anime_style_parts.append(f"Keep character designs consistent for: {joined}.")
anime_style_text = " ".join(anime_style_parts).strip()
parts = [
f"{title} cinematic motion shot.",
@@ -60,6 +100,7 @@ def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str,
f"Maintain a {tone} mood with natural lighting accents.",
f"Honor the original illustration details: {image_prompt[:200]}." if image_prompt else "",
"5-second sequence, gentle push-in, flowing cloth and atmospheric particles.",
anime_style_text,
]
fallback_prompt = " ".join(filter(None, parts))
return fallback_prompt.strip()
@@ -142,6 +183,66 @@ def generate_animation_prompt(
title = scene_data.get("title", "")
tone = story_context.get("story_tone") or story_context.get("story_tone", "")
setting = story_context.get("story_setting") or story_context.get("story_setting", "")
anime_bible = story_context.get("anime_bible") or {}
anime_bible_block = ""
if isinstance(anime_bible, dict) and anime_bible:
try:
visual_style = anime_bible.get("visual_style") or {}
world = anime_bible.get("world") or {}
main_cast = anime_bible.get("main_cast") or []
style_lines = []
if visual_style:
style_preset = visual_style.get("style_preset")
camera_style = visual_style.get("camera_style")
color_mood = visual_style.get("color_mood")
lighting = visual_style.get("lighting")
line_style = visual_style.get("line_style")
extra_tags = visual_style.get("extra_tags") or []
if style_preset:
style_lines.append(f"- Visual style preset: {style_preset}")
if camera_style:
style_lines.append(f"- Preferred camera style: {camera_style}")
if color_mood:
style_lines.append(f"- Color mood: {color_mood}")
if lighting:
style_lines.append(f"- Lighting: {lighting}")
if line_style:
style_lines.append(f"- Line art style: {line_style}")
if extra_tags:
style_lines.append(
"- Extra style tags: " + ", ".join(str(tag) for tag in extra_tags[:6])
)
cast_line = ""
if main_cast:
names = [c.get("name") for c in main_cast if isinstance(c, dict) and c.get("name")]
if names:
cast_line = "- Main cast to keep visually consistent: " + ", ".join(names[:4])
world_line = ""
if world:
setting_desc = world.get("setting")
if setting_desc:
world_line = "- World/setting context: " + str(setting_desc)
detail_lines = []
if cast_line:
detail_lines.append(cast_line)
if world_line:
detail_lines.append(world_line)
detail_lines.extend(style_lines)
if detail_lines:
anime_bible_block = (
"\nANIME STORY BIBLE VISUAL GUIDANCE:\n"
+ "\n".join(detail_lines)
+ "\nAlways respect these constraints in the motion description."
)
except Exception:
anime_bible_block = ""
prompt = f"""
Create a concise animation prompt (2-3 sentences) for a 5-second cinematic clip.
@@ -151,6 +252,7 @@ Description: {description}
Existing Image Prompt: {image_prompt}
Story Tone: {tone}
Setting: {setting}
{anime_bible_block}
Focus on:
- Motion of characters/objects

View File

@@ -132,7 +132,19 @@ class YouTubeSceneBuilderService:
) -> List[Dict[str, Any]]:
"""Generate scenes from video plan using AI."""
content_outline = video_plan.get("content_outline", [])
raw_content_outline = video_plan.get("content_outline", [])
content_outline: List[Dict[str, Any]] = []
for item in raw_content_outline:
if isinstance(item, dict):
content_outline.append(item)
else:
content_outline.append(
{
"section": str(item),
"description": "",
"duration_estimate": 0,
}
)
hook_strategy = video_plan.get("hook_strategy", "")
call_to_action = video_plan.get("call_to_action", "")
visual_style = video_plan.get("visual_style", "cinematic")
@@ -263,16 +275,32 @@ Write narration that:
# Normalize scene data
normalized_scenes = []
for idx, scene in enumerate(scenes, 1):
normalized_scenes.append({
"scene_number": scene.get("scene_number", idx),
"title": scene.get("title", f"Scene {idx}"),
"narration": scene.get("narration", ""),
"visual_description": scene.get("visual_description", ""),
"duration_estimate": scene.get("duration_estimate", scene_duration_range[0]),
"emphasis": scene.get("emphasis", "main_content"),
"visual_cues": scene.get("visual_cues", []),
"visual_prompt": scene.get("visual_description", ""), # Initial prompt
})
if isinstance(scene, dict):
scene_data = scene
else:
scene_data = {
"scene_number": idx,
"title": f"Scene {idx}",
"narration": str(scene),
"visual_description": "",
"duration_estimate": scene_duration_range[0],
"emphasis": "main_content",
"visual_cues": [],
}
normalized_scenes.append(
{
"scene_number": scene_data.get("scene_number", idx),
"title": scene_data.get("title", f"Scene {idx}"),
"narration": scene_data.get("narration", ""),
"visual_description": scene_data.get("visual_description", ""),
"duration_estimate": scene_data.get(
"duration_estimate", scene_duration_range[0]
),
"emphasis": scene_data.get("emphasis", "main_content"),
"visual_cues": scene_data.get("visual_cues", []),
"visual_prompt": scene_data.get("visual_description", ""),
}
)
return normalized_scenes
@@ -287,16 +315,32 @@ Write narration that:
normalized_scenes = []
for idx, scene in enumerate(scenes, 1):
normalized_scenes.append({
"scene_number": scene.get("scene_number", idx),
"title": scene.get("title", f"Scene {idx}"),
"narration": scene.get("narration", ""),
"visual_description": scene.get("visual_description", ""),
"duration_estimate": scene.get("duration_estimate", scene_duration_range[0]),
"emphasis": scene.get("emphasis", "main_content"),
"visual_cues": scene.get("visual_cues", []),
"visual_prompt": scene.get("visual_description", ""), # Initial prompt
})
if isinstance(scene, dict):
scene_data = scene
else:
scene_data = {
"scene_number": idx,
"title": f"Scene {idx}",
"narration": str(scene),
"visual_description": "",
"duration_estimate": scene_duration_range[0],
"emphasis": "main_content",
"visual_cues": [],
}
normalized_scenes.append(
{
"scene_number": scene_data.get("scene_number", idx),
"title": scene_data.get("title", f"Scene {idx}"),
"narration": scene_data.get("narration", ""),
"visual_description": scene_data.get("visual_description", ""),
"duration_estimate": scene_data.get(
"duration_estimate", scene_duration_range[0]
),
"emphasis": scene_data.get("emphasis", "main_content"),
"visual_cues": scene_data.get("visual_cues", []),
"visual_prompt": scene_data.get("visual_description", ""),
}
)
logger.info(
f"[YouTubeSceneBuilder] ✅ Normalized {len(normalized_scenes)} scenes "