Release Candidate: Production Release with Multi-Tenant & Onboarding Enhancements
This commit is contained in:
@@ -121,7 +121,8 @@ class BaseALwrityAgent(ABC):
|
||||
if TXTAI_AVAILABLE:
|
||||
try:
|
||||
if not self.llm:
|
||||
self.llm = LLM(model_name)
|
||||
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
|
||||
self.llm = LLM(model_name, task="text-generation")
|
||||
|
||||
self.txtai_agent = self._create_txtai_agent()
|
||||
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
|
||||
|
||||
@@ -4,7 +4,6 @@ Bing Webmaster Tools Analytics Handler
|
||||
Handles Bing Webmaster Tools analytics data retrieval and processing.
|
||||
"""
|
||||
|
||||
import requests
|
||||
from typing import Dict, Any
|
||||
from datetime import datetime, timedelta
|
||||
from loguru import logger
|
||||
@@ -16,13 +15,23 @@ from ..models.platform_types import PlatformType
|
||||
from .base_handler import BaseAnalyticsHandler
|
||||
from ..insights.bing_insights_service import BingInsightsService
|
||||
from services.bing_analytics_storage_service import BingAnalyticsStorageService
|
||||
import os
|
||||
|
||||
|
||||
from services.database import get_user_db_path
|
||||
|
||||
class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
"""Handler for Bing Webmaster Tools analytics"""
|
||||
"""
|
||||
Handler for Bing Webmaster Tools analytics
|
||||
|
||||
NOTE (2026-02-14): Known issues and directions
|
||||
- Verified sites list can be empty despite valid tokens. This leads to partial/error states and prevents storage collection.
|
||||
Direction: UI now provides a manual site picker (with primary website fallback from onboarding) to trigger storage collection,
|
||||
and a future improvement should accept a target_url from /api/analytics/data to influence site selection here.
|
||||
- Token state mismatch (status shows connected, analytics reports expired) can happen across cache boundaries.
|
||||
Direction: The frontend auto-resyncs once after OAuth success and provides a backend cache clear endpoint.
|
||||
- Storage-backed summary reads rely on a selected site; when sites are missing, selected_site is None.
|
||||
Direction: Allow explicit site_url parameter in the analytics orchestrator to override selected_site resolution.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(PlatformType.BING)
|
||||
@@ -42,14 +51,22 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
db_url = f'sqlite:///{db_path}'
|
||||
return BingInsightsService(db_url)
|
||||
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Bing Webmaster analytics data using Bing Webmaster API
|
||||
"""
|
||||
self.log_analytics_request(user_id, "get_analytics")
|
||||
|
||||
# Check cache first
|
||||
cached_data = analytics_cache.get('bing_analytics', user_id)
|
||||
# Check cache first (include date range and target_url in key)
|
||||
cache_key_parts = [user_id]
|
||||
if target_url:
|
||||
cache_key_parts.append(str(target_url))
|
||||
if start_date:
|
||||
cache_key_parts.append(str(start_date))
|
||||
if end_date:
|
||||
cache_key_parts.append(str(end_date))
|
||||
cache_key = "_".join(cache_key_parts)
|
||||
cached_data = analytics_cache.get('bing_analytics', cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Using cached Bing analytics for user {user_id}")
|
||||
return AnalyticsData(**cached_data)
|
||||
@@ -107,9 +124,22 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
site_url_for_storage = selected_site.get('Url', '') if selected_site else ''
|
||||
logger.info(f"Using Bing site URL: {site_url_for_storage}")
|
||||
|
||||
# Determine date range (defaults to last 30 days)
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
# Compute days for storage/insights services (at least 1)
|
||||
try:
|
||||
dt_end = datetime.strptime(end_date, '%Y-%m-%d')
|
||||
dt_start = datetime.strptime(start_date, '%Y-%m-%d')
|
||||
days_range = max(1, (dt_end - dt_start).days + 1)
|
||||
except Exception:
|
||||
days_range = 30
|
||||
|
||||
query_stats = {}
|
||||
try:
|
||||
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=30)
|
||||
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=days_range)
|
||||
if stored and isinstance(stored, dict):
|
||||
query_stats = {
|
||||
'total_clicks': stored.get('summary', {}).get('total_clicks', 0),
|
||||
@@ -138,19 +168,20 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
'insights': insights,
|
||||
'note': 'Bing Webmaster API provides SEO insights, search performance, and index status data'
|
||||
}
|
||||
|
||||
if (not sites) or (metrics.get('total_impressions', 0) == 0 and metrics.get('total_clicks', 0) == 0):
|
||||
result = self.create_partial_response(metrics=metrics, error_message='Connected to Bing; waiting for stored analytics or site verification')
|
||||
|
||||
if not sites:
|
||||
result = self.create_partial_response(metrics=metrics, error_message='Connected to Bing; no verified sites found')
|
||||
else:
|
||||
result = self.create_success_response(metrics=metrics)
|
||||
result = self.create_success_response(metrics=metrics, date_range={'start': start_date, 'end': end_date})
|
||||
|
||||
analytics_cache.set('bing_analytics', user_id, result.__dict__)
|
||||
analytics_cache.set('bing_analytics', cache_key, result.__dict__)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
self.log_analytics_error(user_id, "get_analytics", e)
|
||||
error_result = self.create_error_response(str(e))
|
||||
analytics_cache.set('bing_analytics', user_id, error_result.__dict__, ttl_override=300)
|
||||
# Cache error briefly to prevent hammering but recover quickly
|
||||
analytics_cache.set('bing_analytics', cache_key, error_result.__dict__, ttl_override=30)
|
||||
return error_result
|
||||
|
||||
def _get_enhanced_insights_with_service(self, insights_service: BingInsightsService, user_id: str, site_url: str) -> Dict[str, Any]:
|
||||
|
||||
@@ -22,7 +22,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
super().__init__(PlatformType.GSC)
|
||||
self.gsc_service = GSCService()
|
||||
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Google Search Console analytics data with caching
|
||||
|
||||
@@ -35,8 +35,16 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
self.log_analytics_request(user_id, "get_analytics")
|
||||
|
||||
# Check cache first - GSC API calls can be expensive
|
||||
# Include target_url in cache key if provided
|
||||
cache_key = f"{user_id}_{target_url}" if target_url else user_id
|
||||
# Include target_url and date range in cache key if provided
|
||||
cache_key_parts = [user_id]
|
||||
if target_url:
|
||||
cache_key_parts.append(str(target_url))
|
||||
if start_date:
|
||||
cache_key_parts.append(str(start_date))
|
||||
if end_date:
|
||||
cache_key_parts.append(str(end_date))
|
||||
# Bump cache version to include page insights (v2)
|
||||
cache_key = "_".join(cache_key_parts + ['v2pages'])
|
||||
cached_data = analytics_cache.get('gsc_analytics', cache_key)
|
||||
if cached_data:
|
||||
logger.info("Using cached GSC analytics for user {user_id}", user_id=user_id)
|
||||
@@ -70,9 +78,11 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
site_url = selected_site['siteUrl']
|
||||
logger.info(f"Using GSC site URL: {site_url}")
|
||||
|
||||
# Get search analytics for last 30 days
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
# Determine date range (defaults to last 30 days)
|
||||
if not end_date:
|
||||
end_date = datetime.now().strftime('%Y-%m-%d')
|
||||
if not start_date:
|
||||
start_date = (datetime.now() - timedelta(days=30)).strftime('%Y-%m-%d')
|
||||
logger.info(f"GSC Date range: {start_date} to {end_date}")
|
||||
|
||||
search_analytics = self.gsc_service.get_search_analytics(
|
||||
@@ -86,10 +96,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
# Process GSC data into standardized format
|
||||
processed_metrics = self._process_gsc_metrics(search_analytics)
|
||||
|
||||
result = self.create_success_response(
|
||||
metrics=processed_metrics,
|
||||
date_range={'start': start_date, 'end': end_date}
|
||||
)
|
||||
result = self.create_success_response(metrics=processed_metrics, date_range={'start': start_date, 'end': end_date})
|
||||
|
||||
# Cache the result to avoid expensive API calls
|
||||
analytics_cache.set('gsc_analytics', cache_key, result.__dict__)
|
||||
@@ -101,8 +108,8 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
self.log_analytics_error(user_id, "get_analytics", e)
|
||||
error_result = self.create_error_response(str(e))
|
||||
|
||||
# Cache error result for shorter time to retry sooner
|
||||
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=300) # 5 minutes
|
||||
# Cache error result briefly to avoid repeated failures but allow quick recovery
|
||||
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=30) # 30 seconds
|
||||
return error_result
|
||||
|
||||
def get_connection_status(self, user_id: str) -> Dict[str, Any]:
|
||||
@@ -202,18 +209,159 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
sorted_queries = sorted(top_queries_source, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
|
||||
|
||||
for row in sorted_queries:
|
||||
clicks_val = row.get('clicks', 0) or 0
|
||||
impr_val = row.get('impressions', 0) or 0
|
||||
raw_ctr = row.get('ctr', None)
|
||||
# Calculate CTR% robustly even if 'ctr' field is missing in row
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = round(float(raw_ctr) * 100, 2)
|
||||
else:
|
||||
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
|
||||
top_queries.append({
|
||||
'query': self._extract_query_from_row(row),
|
||||
'clicks': row.get('clicks', 0),
|
||||
'impressions': row.get('impressions', 0),
|
||||
'ctr': round(row.get('ctr', 0) * 100, 2),
|
||||
'position': round(row.get('position', 0), 2)
|
||||
'clicks': clicks_val,
|
||||
'impressions': impr_val,
|
||||
'ctr': ctr_percent,
|
||||
'position': round(row.get('position', 0) or 0, 2)
|
||||
})
|
||||
|
||||
# Prepare Top Pages (requires page dimension, but we only requested query dimension in gsc_service step 3)
|
||||
# To get top pages, we would need another API call with dimension=['page']
|
||||
# For now, we'll return empty top_pages or infer from what we have if possible (we can't from query data)
|
||||
top_pages = []
|
||||
# Prepare Top Pages from page_data when available
|
||||
top_pages = []
|
||||
try:
|
||||
page_rows = search_analytics.get('page_data', {}).get('rows', [])
|
||||
qp_rows = search_analytics.get('query_page_data', {}).get('rows', [])
|
||||
# Build queries-by-page map
|
||||
queries_by_page: Dict[str, list] = {}
|
||||
if qp_rows:
|
||||
for r in qp_rows:
|
||||
keys = r.get('keys', [])
|
||||
if not keys or len(keys) < 2:
|
||||
continue
|
||||
query_key = keys[0]['keys'][0] if isinstance(keys[0], dict) else str(keys[0])
|
||||
page_key = keys[1]['keys'][0] if isinstance(keys[1], dict) else str(keys[1])
|
||||
clicks_val = r.get('clicks', 0) or 0
|
||||
impr_val = r.get('impressions', 0) or 0
|
||||
raw_ctr = r.get('ctr', None)
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = round(float(raw_ctr) * 100, 2)
|
||||
else:
|
||||
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
|
||||
lst = queries_by_page.setdefault(page_key, [])
|
||||
lst.append({
|
||||
'query': query_key,
|
||||
'clicks': clicks_val,
|
||||
'impressions': impr_val,
|
||||
'ctr': ctr_percent,
|
||||
})
|
||||
if page_rows:
|
||||
sorted_pages = sorted(page_rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
|
||||
for row in sorted_pages:
|
||||
clicks_val = row.get('clicks', 0) or 0
|
||||
impr_val = row.get('impressions', 0) or 0
|
||||
raw_ctr = row.get('ctr', None)
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = round(float(raw_ctr) * 100, 2)
|
||||
else:
|
||||
ctr_percent = round(((clicks_val / impr_val) * 100), 2) if impr_val > 0 else 0.0
|
||||
page_url = self._extract_page_from_row(row)
|
||||
# attach top queries pointing to this page, sorted by clicks
|
||||
page_queries = sorted(queries_by_page.get(page_url, []), key=lambda x: x.get('clicks', 0), reverse=True)[:5]
|
||||
top_pages.append({
|
||||
'page': page_url,
|
||||
'clicks': clicks_val,
|
||||
'impressions': impr_val,
|
||||
'ctr': ctr_percent,
|
||||
'position': round(row.get('position', 0) or 0, 2) if 'position' in row else None,
|
||||
'queries': page_queries
|
||||
})
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed processing top_pages: {e}")
|
||||
|
||||
# Detect Cannibalization (query mapping to multiple pages)
|
||||
cannibalization = []
|
||||
try:
|
||||
qp_rows = search_analytics.get('query_page_data', {}).get('rows', [])
|
||||
q_rows = search_analytics.get('query_data', {}).get('rows', [])
|
||||
if qp_rows:
|
||||
# Determine window days for thresholding
|
||||
from datetime import datetime
|
||||
start_s = search_analytics.get('startDate')
|
||||
end_s = search_analytics.get('endDate')
|
||||
window_days = 30
|
||||
try:
|
||||
if start_s and end_s:
|
||||
sd = datetime.strptime(start_s, "%Y-%m-%d")
|
||||
ed = datetime.strptime(end_s, "%Y-%m-%d")
|
||||
window_days = max((ed - sd).days + 1, 1)
|
||||
except Exception:
|
||||
pass
|
||||
min_clicks = 10 if window_days <= 7 else (30 if window_days <= 30 else 60)
|
||||
# Build map: query -> { page -> metrics }
|
||||
by_query: Dict[str, Dict[str, Dict[str, float]]] = {}
|
||||
for r in qp_rows:
|
||||
keys = r.get('keys', [])
|
||||
if not keys or len(keys) < 2:
|
||||
continue
|
||||
qk = keys[0]['keys'][0] if isinstance(keys[0], dict) else str(keys[0])
|
||||
pk = keys[1]['keys'][0] if isinstance(keys[1], dict) else str(keys[1])
|
||||
clicks_val = float(r.get('clicks', 0) or 0)
|
||||
impr_val = float(r.get('impressions', 0) or 0)
|
||||
raw_ctr = r.get('ctr', None)
|
||||
if raw_ctr is not None:
|
||||
ctr_percent = float(raw_ctr) * 100.0
|
||||
else:
|
||||
ctr_percent = (clicks_val / impr_val * 100.0) if impr_val > 0 else 0.0
|
||||
pos_val = float(r.get('position', 0) or 0)
|
||||
by_query.setdefault(qk, {}).setdefault(pk, {"clicks": 0.0, "impressions": 0.0, "ctr": 0.0, "position_sum": 0.0, "position_count": 0.0})
|
||||
agg = by_query[qk][pk]
|
||||
agg["clicks"] += clicks_val
|
||||
agg["impressions"] += impr_val
|
||||
agg["ctr"] = max(agg["ctr"], ctr_percent)
|
||||
if pos_val > 0:
|
||||
agg["position_sum"] += pos_val
|
||||
agg["position_count"] += 1
|
||||
# Use query totals for context
|
||||
total_by_query: Dict[str, Dict[str, float]] = {}
|
||||
for r in q_rows or []:
|
||||
qk = self._extract_query_from_row(r)
|
||||
total_by_query[qk] = {
|
||||
"clicks": float(r.get('clicks', 0) or 0),
|
||||
"impressions": float(r.get('impressions', 0) or 0),
|
||||
"position": float(r.get('position', 0) or 0)
|
||||
}
|
||||
for qk, pages_map in by_query.items():
|
||||
if len(pages_map) < 2:
|
||||
continue
|
||||
total_clicks = sum(p["clicks"] for p in pages_map.values())
|
||||
if total_clicks < min_clicks:
|
||||
continue
|
||||
qpos = total_by_query.get(qk, {}).get("position", 0.0)
|
||||
if not (3.0 <= qpos <= 20.0) and qpos != 0.0:
|
||||
# Skip queries already ranking very well or very poorly (if pos present)
|
||||
continue
|
||||
pages_list = []
|
||||
for pk, m in pages_map.items():
|
||||
avg_pos = (m["position_sum"] / m["position_count"]) if m["position_count"] > 0 else 0.0
|
||||
pages_list.append({
|
||||
"page": pk,
|
||||
"clicks": round(m["clicks"], 0),
|
||||
"impressions": round(m["impressions"], 0),
|
||||
"ctr": round(m["ctr"], 2),
|
||||
"position": round(avg_pos, 2) if avg_pos > 0 else None
|
||||
})
|
||||
pages_list.sort(key=lambda x: x.get("clicks", 0), reverse=True)
|
||||
target_page = pages_list[0]["page"] if pages_list else None
|
||||
cannibalization.append({
|
||||
"query": qk,
|
||||
"total_clicks": int(round(total_clicks)),
|
||||
"recommended_target_page": target_page,
|
||||
"pages": pages_list[:3]
|
||||
})
|
||||
# Sort by impact
|
||||
cannibalization.sort(key=lambda item: item.get("total_clicks", 0), reverse=True)
|
||||
cannibalization = cannibalization[:10]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed computing cannibalization: {e}")
|
||||
|
||||
return {
|
||||
'connection_status': 'connected',
|
||||
@@ -224,7 +372,8 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
'avg_position': round(avg_position, 2),
|
||||
'total_queries': len(top_queries_source) if top_queries_source else 0,
|
||||
'top_queries': top_queries,
|
||||
'top_pages': top_pages
|
||||
'top_pages': top_pages,
|
||||
'cannibalization': cannibalization
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
@@ -256,3 +405,18 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting query from row: {e}")
|
||||
return 'Unknown'
|
||||
|
||||
def _extract_page_from_row(self, row: Dict[str, Any]) -> str:
|
||||
"""Extract page URL from GSC API row data"""
|
||||
try:
|
||||
keys = row.get('keys', [])
|
||||
if keys and len(keys) > 0:
|
||||
first_key = keys[0]
|
||||
if isinstance(first_key, dict):
|
||||
return first_key.get('keys', [''])[0]
|
||||
else:
|
||||
return str(first_key)
|
||||
return ''
|
||||
except Exception as e:
|
||||
logger.error(f"Error extracting page from row: {e}")
|
||||
return ''
|
||||
|
||||
@@ -21,7 +21,7 @@ class WixAnalyticsHandler(BaseAnalyticsHandler):
|
||||
super().__init__(PlatformType.WIX)
|
||||
self.wix_service = WixService()
|
||||
|
||||
async def get_analytics(self, user_id: str) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Wix analytics data using the Business Management API
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class WordPressAnalyticsHandler(BaseAnalyticsHandler):
|
||||
super().__init__(PlatformType.WORDPRESS)
|
||||
self.wordpress_service = WordPressOAuthService()
|
||||
|
||||
async def get_analytics(self, user_id: str) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, start_date: str = None, end_date: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get WordPress analytics data using WordPress.com REST API
|
||||
|
||||
|
||||
@@ -42,7 +42,7 @@ class PlatformAnalyticsService:
|
||||
self.summary_generator = AnalyticsSummaryGenerator()
|
||||
self.cache_manager = AnalyticsCacheManager()
|
||||
|
||||
async def get_comprehensive_analytics(self, user_id: str, platforms: List[str] = None) -> Dict[str, AnalyticsData]:
|
||||
async def get_comprehensive_analytics(self, user_id: str, platforms: List[str] = None, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, AnalyticsData]:
|
||||
"""
|
||||
Get analytics data from all connected platforms
|
||||
|
||||
@@ -93,9 +93,18 @@ class PlatformAnalyticsService:
|
||||
|
||||
if handler:
|
||||
if platform_type == PlatformType.GSC or platform_type == PlatformType.BING:
|
||||
analytics_data[platform_name] = await handler.get_analytics(user_id, target_url=target_url)
|
||||
analytics_data[platform_name] = await handler.get_analytics(
|
||||
user_id,
|
||||
target_url=target_url,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
else:
|
||||
analytics_data[platform_name] = await handler.get_analytics(user_id)
|
||||
analytics_data[platform_name] = await handler.get_analytics(
|
||||
user_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Unknown platform: {platform_name}")
|
||||
analytics_data[platform_name] = self._create_error_response(platform_name, f"Unknown platform: {platform_name}")
|
||||
|
||||
@@ -237,7 +237,7 @@ class BingAnalyticsStorageService:
|
||||
Dict containing analytics summary
|
||||
"""
|
||||
try:
|
||||
db = self._get_db_session()
|
||||
db = self._get_db_session(user_id)
|
||||
|
||||
# Date range
|
||||
end_date = datetime.now()
|
||||
@@ -331,7 +331,7 @@ class BingAnalyticsStorageService:
|
||||
List of top queries with performance data
|
||||
"""
|
||||
try:
|
||||
db = self._get_db_session()
|
||||
db = self._get_db_session(user_id)
|
||||
|
||||
# Calculate date range
|
||||
end_date = datetime.now()
|
||||
|
||||
@@ -241,6 +241,9 @@ class ExaResearchProvider(BaseProvider):
|
||||
for idx, result in enumerate(results):
|
||||
source_type = self._determine_source_type(result.url if hasattr(result, 'url') else '')
|
||||
|
||||
# Extract image if available (some Exa results include image URL)
|
||||
image_url = result.image if hasattr(result, 'image') else None
|
||||
|
||||
sources.append({
|
||||
'title': result.title if hasattr(result, 'title') else '',
|
||||
'url': result.url if hasattr(result, 'url') else '',
|
||||
@@ -251,17 +254,21 @@ class ExaResearchProvider(BaseProvider):
|
||||
'source_type': source_type,
|
||||
'content': result.text if hasattr(result, 'text') else '',
|
||||
'highlights': result.highlights if hasattr(result, 'highlights') else [],
|
||||
'summary': result.summary if hasattr(result, 'summary') else ''
|
||||
'summary': result.summary if hasattr(result, 'summary') else '',
|
||||
'image': image_url,
|
||||
'author': result.author if hasattr(result, 'author') else None
|
||||
})
|
||||
|
||||
return sources
|
||||
|
||||
def _get_excerpt(self, result):
|
||||
"""Extract excerpt from Exa result."""
|
||||
"""Extract excerpt from Exa result. Prefer highlights if available."""
|
||||
if hasattr(result, 'highlights') and result.highlights and len(result.highlights) > 0:
|
||||
return result.highlights[0]
|
||||
if hasattr(result, 'summary') and result.summary:
|
||||
return result.summary
|
||||
if hasattr(result, 'text') and result.text:
|
||||
return result.text[:500]
|
||||
elif hasattr(result, 'summary') and result.summary:
|
||||
return result.summary
|
||||
return ''
|
||||
|
||||
def _determine_source_type(self, url):
|
||||
@@ -280,16 +287,30 @@ class ExaResearchProvider(BaseProvider):
|
||||
return 'web'
|
||||
|
||||
def _aggregate_content(self, results):
|
||||
"""Aggregate content from Exa results for LLM analysis."""
|
||||
"""Aggregate content from Exa results for LLM analysis, including highlights."""
|
||||
content_parts = []
|
||||
|
||||
for idx, result in enumerate(results):
|
||||
part = [f"Source {idx + 1}: {result.title if hasattr(result, 'title') else 'Untitled'}"]
|
||||
if hasattr(result, 'url') and result.url:
|
||||
part.append(f"URL: {result.url}")
|
||||
|
||||
# Add highlights if available (most valuable for LLM)
|
||||
if hasattr(result, 'highlights') and result.highlights:
|
||||
highlights_text = "\n".join([f"- {h}" for h in result.highlights])
|
||||
part.append(f"Key Highlights:\n{highlights_text}")
|
||||
|
||||
# Add summary if available
|
||||
if hasattr(result, 'summary') and result.summary:
|
||||
content_parts.append(f"Source {idx + 1}: {result.summary}")
|
||||
part.append(f"Summary: {result.summary}")
|
||||
|
||||
# Add text snippet if highlights/summary insufficient
|
||||
elif hasattr(result, 'text') and result.text:
|
||||
content_parts.append(f"Source {idx + 1}: {result.text[:1000]}")
|
||||
part.append(f"Excerpt: {result.text[:1000]}")
|
||||
|
||||
content_parts.append("\n".join(part))
|
||||
|
||||
return "\n\n".join(content_parts)
|
||||
return "\n\n---\n\n".join(content_parts)
|
||||
|
||||
def track_exa_usage(self, user_id: str, cost: float):
|
||||
"""Track Exa API usage after successful call."""
|
||||
|
||||
@@ -159,14 +159,10 @@ class StyleDetectionLogic:
|
||||
}}
|
||||
"""
|
||||
|
||||
# Call the LLM for analysis
|
||||
logger.debug("[StyleDetectionLogic.analyze_content_style] Sending enhanced prompt to LLM")
|
||||
try:
|
||||
analysis_text = llm_text_gen(prompt, user_id=user_id)
|
||||
|
||||
# Clean and parse the response
|
||||
cleaned_json = self._clean_json_response(analysis_text)
|
||||
|
||||
analysis_results = json.loads(cleaned_json)
|
||||
logger.info("[StyleDetectionLogic.analyze_content_style] Successfully parsed enhanced analysis results")
|
||||
return {
|
||||
@@ -179,7 +175,7 @@ class StyleDetectionLogic:
|
||||
return {
|
||||
'success': True,
|
||||
'analysis': fallback_results,
|
||||
'warning': 'AI analysis failed, used fallback detection'
|
||||
'warning': f'AI analysis failed ({str(e)}), used fallback detection'
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -145,6 +145,7 @@ def init_user_database(user_id: str):
|
||||
SubscriptionBase.metadata.create_all(bind=engine)
|
||||
UserBusinessInfoBase.metadata.create_all(bind=engine)
|
||||
ContentAssetBase.metadata.create_all(bind=engine)
|
||||
BingAnalyticsBase.metadata.create_all(bind=engine)
|
||||
|
||||
# Initialize default data for new databases
|
||||
try:
|
||||
|
||||
@@ -343,7 +343,11 @@ class GSCService:
|
||||
if not credentials:
|
||||
raise ValueError("No valid credentials found")
|
||||
|
||||
service = build('searchconsole', 'v1', credentials=credentials)
|
||||
# Disable discovery file cache (suppress oauth2client file_cache warnings) with safe fallback
|
||||
try:
|
||||
service = build('searchconsole', 'v1', credentials=credentials, cache_discovery=False)
|
||||
except TypeError:
|
||||
service = build('searchconsole', 'v1', credentials=credentials)
|
||||
logger.info(f"Authenticated GSC service created for user: {user_id}")
|
||||
return service
|
||||
|
||||
@@ -395,9 +399,12 @@ class GSCService:
|
||||
# Check cache first
|
||||
cache_key = f"{user_id}_{site_url}_{start_date}_{end_date}"
|
||||
cached_data = self._get_cached_data(user_id, site_url, 'analytics', cache_key)
|
||||
if cached_data:
|
||||
logger.info(f"Returning cached analytics data for user: {user_id}")
|
||||
return cached_data
|
||||
if cached_data and isinstance(cached_data, dict):
|
||||
has_pages = 'page_data' in cached_data and isinstance(cached_data.get('page_data'), dict)
|
||||
has_queries = 'query_data' in cached_data and isinstance(cached_data.get('query_data'), dict)
|
||||
if has_pages and has_queries:
|
||||
logger.info(f"Returning cached analytics data for user: {user_id} (includes page_data)")
|
||||
return cached_data
|
||||
|
||||
try:
|
||||
service = self.get_authenticated_service(user_id)
|
||||
@@ -476,8 +483,54 @@ class GSCService:
|
||||
).execute()
|
||||
|
||||
logger.info(f"GSC Query-level response for user {user_id}: {query_response}")
|
||||
|
||||
# Combine overall metrics with query-level data
|
||||
|
||||
# Step 4: Get page-level data for top pages insights
|
||||
page_request = {
|
||||
'startDate': start_date,
|
||||
'endDate': end_date,
|
||||
'dimensions': ['page'], # Get page-level data
|
||||
'rowLimit': 1000
|
||||
}
|
||||
logger.info(f"GSC Page-level request for user {user_id}: {page_request}")
|
||||
page_rows = []
|
||||
page_row_count = 0
|
||||
try:
|
||||
page_response = service.searchanalytics().query(
|
||||
siteUrl=site_url,
|
||||
body=page_request
|
||||
).execute()
|
||||
logger.info(f"GSC Page-level response for user {user_id}: {page_response}")
|
||||
page_rows = page_response.get('rows', [])
|
||||
page_row_count = page_response.get('rowCount', 0)
|
||||
except Exception as page_error:
|
||||
logger.warning(f"GSC Page-level request failed for user {user_id}: {page_error}")
|
||||
page_rows = []
|
||||
page_row_count = 0
|
||||
|
||||
# Step 5: Get query+page combined data for mapping queries to pages
|
||||
qp_rows = []
|
||||
qp_row_count = 0
|
||||
try:
|
||||
qp_request = {
|
||||
'startDate': start_date,
|
||||
'endDate': end_date,
|
||||
'dimensions': ['query', 'page'],
|
||||
'rowLimit': 1000
|
||||
}
|
||||
logger.info(f"GSC Query+Page request for user {user_id}: {qp_request}")
|
||||
qp_response = service.searchanalytics().query(
|
||||
siteUrl=site_url,
|
||||
body=qp_request
|
||||
).execute()
|
||||
logger.info(f"GSC Query+Page response for user {user_id}: {qp_response}")
|
||||
qp_rows = qp_response.get('rows', [])
|
||||
qp_row_count = qp_response.get('rowCount', 0)
|
||||
except Exception as qp_error:
|
||||
logger.warning(f"GSC Query+Page request failed for user {user_id}: {qp_error}")
|
||||
qp_rows = []
|
||||
qp_row_count = 0
|
||||
|
||||
# Combine overall, query, page and query+page data
|
||||
analytics_data = {
|
||||
'overall_metrics': {
|
||||
'rows': response.get('rows', []),
|
||||
@@ -487,6 +540,14 @@ class GSCService:
|
||||
'rows': query_response.get('rows', []),
|
||||
'rowCount': query_response.get('rowCount', 0)
|
||||
},
|
||||
'page_data': {
|
||||
'rows': page_rows,
|
||||
'rowCount': page_row_count
|
||||
},
|
||||
'query_page_data': {
|
||||
'rows': qp_rows,
|
||||
'rowCount': qp_row_count
|
||||
},
|
||||
'verification_data': {
|
||||
'rows': verification_rows,
|
||||
'rowCount': len(verification_rows)
|
||||
@@ -510,6 +571,8 @@ class GSCService:
|
||||
'rowCount': response.get('rowCount', 0)
|
||||
},
|
||||
'query_data': {'rows': [], 'rowCount': 0},
|
||||
'page_data': {'rows': [], 'rowCount': 0},
|
||||
'query_page_data': {'rows': [], 'rowCount': 0},
|
||||
'verification_data': {
|
||||
'rows': verification_rows,
|
||||
'rowCount': len(verification_rows)
|
||||
|
||||
@@ -76,7 +76,8 @@ class ALwrityAgentOrchestrator:
|
||||
try:
|
||||
# Initialize shared LLM
|
||||
if TXTAI_AVAILABLE:
|
||||
self.llm = LLM(self.config.shared_llm)
|
||||
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
|
||||
self.llm = LLM(self.config.shared_llm, task="text-generation")
|
||||
else:
|
||||
self.llm = None
|
||||
|
||||
|
||||
@@ -181,7 +181,8 @@ class BaseALwrityAgent(ABC):
|
||||
try:
|
||||
if not self.llm:
|
||||
# Create new LLM if not provided
|
||||
raw_llm = LLM(model_name)
|
||||
# Hardening: Explicitly set task to avoid 'text2text-generation' default failures
|
||||
raw_llm = LLM(model_name, task="text-generation")
|
||||
# Wrap it
|
||||
self.llm = TrackingLLMWrapper(raw_llm, self.user_id, self.model_name)
|
||||
|
||||
@@ -906,6 +907,11 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
|
||||
"name": "task_delegator",
|
||||
"description": "Delegates specific tasks to specialized agents (content, competitor, seo, social)",
|
||||
"target": self._delegate_task_tool
|
||||
},
|
||||
{
|
||||
"name": "kickoff_gsc_first_pass",
|
||||
"description": "Kicks off first-pass execution by invoking SEO/Content default GSC plans",
|
||||
"target": self._kickoff_gsc_first_pass_tool
|
||||
}
|
||||
],
|
||||
max_iterations=15,
|
||||
@@ -924,7 +930,9 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
|
||||
Do not just plan; EXECUTE by delegating.
|
||||
|
||||
Always prioritize user goals and maintain safety constraints.
|
||||
Coordinate multi-agent responses to market changes effectively."""
|
||||
Coordinate multi-agent responses to market changes effectively.
|
||||
|
||||
First, call 'kickoff_gsc_first_pass' to ground the plan on live GSC signals."""
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1033,6 +1041,37 @@ class StrategyOrchestratorAgent(BaseALwrityAgent):
|
||||
except Exception as e:
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _kickoff_gsc_first_pass_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Invoke SEO and Content agents' default GSC plans and combine results"""
|
||||
try:
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
payload = {"start_date": start_date, "end_date": end_date}
|
||||
results = {}
|
||||
combined_actions = []
|
||||
|
||||
seo = self.sub_agents.get("seo")
|
||||
if seo and hasattr(seo, "_default_seo_gsc_plan_tool"):
|
||||
plan = await seo._default_seo_gsc_plan_tool(payload)
|
||||
results["seo"] = plan
|
||||
combined_actions.extend(plan.get("actions", []) if isinstance(plan, dict) else [])
|
||||
|
||||
content = self.sub_agents.get("content")
|
||||
if content and hasattr(content, "_default_content_gsc_plan_tool"):
|
||||
plan = await content._default_content_gsc_plan_tool(payload)
|
||||
results["content"] = plan
|
||||
combined_actions.extend(plan.get("actions", []) if isinstance(plan, dict) else [])
|
||||
|
||||
return {
|
||||
"status": "ok",
|
||||
"invoked": list(results.keys()),
|
||||
"results": results,
|
||||
"combined_actions": combined_actions,
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
return {"status": "error", "error": str(e)}
|
||||
|
||||
async def _strategy_synthesizer_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Tool for synthesizing strategies"""
|
||||
return {
|
||||
|
||||
@@ -13,6 +13,7 @@ from loguru import logger
|
||||
from ..txtai_service import TxtaiIntelligenceService
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, AgentAction
|
||||
from services.seo_tools.content_strategy_service import ContentStrategyService
|
||||
from services.analytics import PlatformAnalyticsService
|
||||
from services.intelligence.sif_agents import SharedLLMWrapper, LocalLLMWrapper
|
||||
try:
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
@@ -888,7 +889,37 @@ class ContentStrategyAgent(BaseALwrityAgent):
|
||||
"name": "sitemap_analyzer",
|
||||
"description": "Analyzes website structure and publishing velocity via sitemap",
|
||||
"target": self._sitemap_analyzer_tool
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_queries",
|
||||
"description": "Returns low-CTR queries with evidence from cached GSC metrics",
|
||||
"target": self._cs_gsc_low_ctr_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_striking_distance_queries",
|
||||
"description": "Returns striking-distance queries (positions ~8–20) with evidence",
|
||||
"target": self._cs_gsc_striking_distance_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_declining_queries",
|
||||
"description": "Returns period-over-period declining queries with evidence",
|
||||
"target": self._cs_gsc_declining_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_pages",
|
||||
"description": "Returns low-CTR pages with top contributing queries",
|
||||
"target": self._cs_gsc_low_ctr_pages_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_cannibalization_candidates",
|
||||
"description": "Returns query→multiple-pages cannibalization candidates with target recommendation",
|
||||
"target": self._cs_gsc_cannibalization_candidates_tool
|
||||
},
|
||||
{
|
||||
"name": "default_content_gsc_plan",
|
||||
"description": "Runs a default first-pass plan using GSC signals (titles/meta, consolidation, refreshes)",
|
||||
"target": self._default_content_gsc_plan_tool
|
||||
},
|
||||
],
|
||||
max_iterations=8,
|
||||
system=self.get_effective_system_prompt(f"""You are the Content Strategy Agent for ALwrity user {self.user_id}.
|
||||
@@ -903,12 +934,153 @@ class ContentStrategyAgent(BaseALwrityAgent):
|
||||
- Performance-based content improvements
|
||||
|
||||
Use semantic analysis (SIF) and sitemap analysis to understand content context.
|
||||
Always prioritize user goals and maintain brand consistency."""
|
||||
Always prioritize user goals and maintain brand consistency.
|
||||
|
||||
In your first pass, call 'default_content_gsc_plan' to ground your actions on live GSC signals."""
|
||||
)
|
||||
)
|
||||
|
||||
# Tool Implementations
|
||||
|
||||
async def _cs_fetch_gsc_analytics(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, Any]:
|
||||
svc = PlatformAnalyticsService()
|
||||
data = await svc.get_comprehensive_analytics(self.user_id, platforms=["gsc"], start_date=start_date, end_date=end_date)
|
||||
gsc = data.get("gsc")
|
||||
if not gsc or gsc.status != "success":
|
||||
err = getattr(gsc, "error_message", None) if gsc else "No data"
|
||||
raise RuntimeError(f"GSC analytics unavailable: {err}")
|
||||
return {"metrics": gsc.metrics, "date_range": gsc.date_range}
|
||||
|
||||
async def _cs_gsc_low_ctr_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 100)); min_clicks = int(context.get("min_clicks", 10)); ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{"query": r.get("query"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position")}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("clicks", 0) >= min_clicks and float(r.get("ctr", 0.0)) < ctr_threshold)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs low_ctr_queries failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_striking_distance_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 100)); start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{"query": r.get("query"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position")}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("position") is not None and 8.0 <= float(r.get("position")) <= 20.0)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("position") if x.get("position") is not None else 999, -x.get("impressions", 0)))
|
||||
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs striking_distance failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_declining_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_prev_clicks = int(context.get("min_prev_clicks", 10)); min_drop_pct = float(context.get("min_drop_pct", 30.0))
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
curr = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
curr_range = curr["date_range"]; s = curr_range.get("start"); e = curr_range.get("end")
|
||||
from datetime import datetime, timedelta; fmt = "%Y-%m-%d"
|
||||
sd = datetime.strptime(s, fmt) if s else datetime.utcnow() - timedelta(days=30); ed = datetime.strptime(e, fmt) if e else datetime.utcnow()
|
||||
days = max((ed - sd).days + 1, 1); prev_end = sd - timedelta(days=1); prev_start = prev_end - timedelta(days=days - 1)
|
||||
prev = await self._cs_fetch_gsc_analytics(prev_start.strftime(fmt), prev_end.strftime(fmt))
|
||||
curr_queries = {r.get("query"): r for r in (curr["metrics"].get("top_queries", []) or [])}
|
||||
prev_queries = {r.get("query"): r for r in (prev["metrics"].get("top_queries", []) or [])}
|
||||
items = []
|
||||
for q, prev_row in prev_queries.items():
|
||||
curr_row = curr_queries.get(q);
|
||||
if not curr_row: continue
|
||||
prev_clicks = int(prev_row.get("clicks", 0) or 0); curr_clicks = int(curr_row.get("clicks", 0) or 0)
|
||||
if prev_clicks >= min_prev_clicks and curr_clicks < prev_clicks:
|
||||
drop_pct = ((prev_clicks - curr_clicks) / prev_clicks) * 100.0
|
||||
if drop_pct >= min_drop_pct:
|
||||
items.append({"query": q, "prev_clicks": prev_clicks, "curr_clicks": curr_clicks, "drop_pct": round(drop_pct, 2)})
|
||||
items.sort(key=lambda x: (x.get("drop_pct", 0), x.get("prev_clicks", 0)), reverse=True)
|
||||
return {"items": items[:limit], "range": curr_range, "previous_range": prev["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs declining_queries failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_low_ctr_pages_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); min_impr = int(context.get("min_impressions", 200)); ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
tp = result["metrics"].get("top_pages", []) or []
|
||||
items = []
|
||||
for r in tp:
|
||||
if (r.get("impressions", 0) >= min_impr and float(r.get("ctr", 0.0)) < ctr_threshold):
|
||||
items.append({"page": r.get("page"), "clicks": r.get("clicks", 0), "impressions": r.get("impressions", 0), "ctr": r.get("ctr", 0.0), "position": r.get("position"), "evidence_queries": r.get("queries", [])[:5]})
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {"items": items[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs low_ctr_pages failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _cs_gsc_cannibalization_candidates_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10)); start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._cs_fetch_gsc_analytics(start_date, end_date)
|
||||
candidates = result["metrics"].get("cannibalization", []) or []
|
||||
return {"items": candidates[:limit], "range": result["date_range"], "source": "gsc_cache"}
|
||||
except Exception as e:
|
||||
logger.error(f"cs cannibalization_candidates failed: {e}"); return {"error": str(e)}
|
||||
|
||||
async def _default_content_gsc_plan_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start_date = context.get("start_date"); end_date = context.get("end_date")
|
||||
try:
|
||||
low_ctr_pages = await self._cs_gsc_low_ctr_pages_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
cannibals = await self._cs_gsc_cannibalization_candidates_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
striking = await self._cs_gsc_striking_distance_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
declining = await self._cs_gsc_declining_queries_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
|
||||
actions = []
|
||||
for p in low_ctr_pages.get("items", []):
|
||||
actions.append({
|
||||
"type": "improve_titles_meta",
|
||||
"target": p.get("page"),
|
||||
"reason": f"Low CTR {p.get('ctr')}% with {p.get('impressions')} impressions",
|
||||
"evidence": p.get("evidence_queries", [])
|
||||
})
|
||||
for c in cannibals.get("items", []):
|
||||
actions.append({
|
||||
"type": "consolidate/internal_link",
|
||||
"target": c.get("recommended_target_page"),
|
||||
"reason": f"Cannibalization on query '{c.get('query')}'",
|
||||
"pages": c.get("pages", [])
|
||||
})
|
||||
for q in striking.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"reason": f"Striking distance at position {q.get('position')} with {q.get('impressions')} impressions"
|
||||
})
|
||||
for q in declining.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"reason": f"Clicks decline {q.get('prev_clicks')}→{q.get('curr_clicks')} ({q.get('drop_pct')}%)"
|
||||
})
|
||||
|
||||
return {
|
||||
"plan_name": "Default Content Plan from GSC",
|
||||
"range": {"current": {"start": start_date, "end": end_date}},
|
||||
"actions": actions,
|
||||
"source": "gsc_cache",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"default_content_gsc_plan failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _sitemap_analyzer_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Sitemap analysis tool using ContentStrategyService"""
|
||||
website_url = context.get('website_url')
|
||||
@@ -1324,7 +1496,37 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
"name": "query_seo_knowledge_base",
|
||||
"description": "Queries the SIF knowledge base for SEO dashboard data, GSC/Bing metrics, and semantic insights",
|
||||
"target": self._query_seo_knowledge_base_tool
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_queries",
|
||||
"description": "Returns low-CTR queries with evidence from cached GSC metrics",
|
||||
"target": self._gsc_low_ctr_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_striking_distance_queries",
|
||||
"description": "Returns striking-distance queries (positions ~8–20) with evidence",
|
||||
"target": self._gsc_striking_distance_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_declining_queries",
|
||||
"description": "Returns period-over-period declining queries with evidence",
|
||||
"target": self._gsc_declining_queries_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_low_ctr_pages",
|
||||
"description": "Returns low-CTR pages with top contributing queries",
|
||||
"target": self._gsc_low_ctr_pages_tool
|
||||
},
|
||||
{
|
||||
"name": "gsc_cannibalization_candidates",
|
||||
"description": "Returns query→multiple-pages cannibalization candidates with target recommendation",
|
||||
"target": self._gsc_cannibalization_candidates_tool
|
||||
},
|
||||
{
|
||||
"name": "default_seo_gsc_plan",
|
||||
"description": "Runs a default first-pass SEO plan using GSC signals (titles/meta, consolidation, refreshes)",
|
||||
"target": self._default_seo_gsc_plan_tool
|
||||
},
|
||||
],
|
||||
max_iterations=15,
|
||||
system=self.get_effective_system_prompt(f"""You are the SEO Optimization Agent for ALwrity user {self.user_id}.
|
||||
@@ -1340,6 +1542,7 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
- Deep semantic search of SEO data (GSC, Bing, Audits)
|
||||
|
||||
Focus on high-impact, low-effort optimizations first.
|
||||
In your first pass, call 'default_seo_gsc_plan' to ground your actions on live GSC signals.
|
||||
Always maintain SEO best practices and user experience."""
|
||||
)
|
||||
)
|
||||
@@ -1666,6 +1869,223 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
# GSC Insights Tools (Option B)
|
||||
async def _fetch_gsc_analytics(self, start_date: Optional[str] = None, end_date: Optional[str] = None) -> Dict[str, Any]:
|
||||
svc = PlatformAnalyticsService()
|
||||
data = await svc.get_comprehensive_analytics(self.user_id, platforms=["gsc"], start_date=start_date, end_date=end_date)
|
||||
gsc = data.get("gsc")
|
||||
if not gsc or gsc.status != "success":
|
||||
err = getattr(gsc, "error_message", None) if gsc else "No data"
|
||||
raise RuntimeError(f"GSC analytics unavailable: {err}")
|
||||
return {
|
||||
"metrics": gsc.metrics,
|
||||
"date_range": gsc.date_range
|
||||
}
|
||||
|
||||
async def _gsc_low_ctr_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_impr = int(context.get("min_impressions", 100))
|
||||
min_clicks = int(context.get("min_clicks", 10))
|
||||
ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{
|
||||
"query": r.get("query"),
|
||||
"clicks": r.get("clicks", 0),
|
||||
"impressions": r.get("impressions", 0),
|
||||
"ctr": r.get("ctr", 0.0),
|
||||
"position": r.get("position")
|
||||
}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("clicks", 0) >= min_clicks and float(r.get("ctr", 0.0)) < ctr_threshold)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"low_ctr_queries tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_striking_distance_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_impr = int(context.get("min_impressions", 100))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
tq = result["metrics"].get("top_queries", []) or []
|
||||
items = [
|
||||
{
|
||||
"query": r.get("query"),
|
||||
"clicks": r.get("clicks", 0),
|
||||
"impressions": r.get("impressions", 0),
|
||||
"ctr": r.get("ctr", 0.0),
|
||||
"position": r.get("position")
|
||||
}
|
||||
for r in tq
|
||||
if (r.get("impressions", 0) >= min_impr and r.get("position") is not None and 8.0 <= float(r.get("position")) <= 20.0)
|
||||
]
|
||||
items.sort(key=lambda x: (x.get("position") if x.get("position") is not None else 999, -x.get("impressions", 0)))
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"striking_distance tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_declining_queries_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_prev_clicks = int(context.get("min_prev_clicks", 10))
|
||||
min_drop_pct = float(context.get("min_drop_pct", 30.0))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
curr = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
curr_range = curr["date_range"]
|
||||
s = curr_range.get("start")
|
||||
e = curr_range.get("end")
|
||||
from datetime import datetime, timedelta
|
||||
fmt = "%Y-%m-%d"
|
||||
sd = datetime.strptime(s, fmt) if s else datetime.utcnow() - timedelta(days=30)
|
||||
ed = datetime.strptime(e, fmt) if e else datetime.utcnow()
|
||||
days = max((ed - sd).days + 1, 1)
|
||||
prev_end = sd - timedelta(days=1)
|
||||
prev_start = prev_end - timedelta(days=days - 1)
|
||||
prev = await self._fetch_gsc_analytics(prev_start.strftime(fmt), prev_end.strftime(fmt))
|
||||
curr_queries = {r.get("query"): r for r in (curr["metrics"].get("top_queries", []) or [])}
|
||||
prev_queries = {r.get("query"): r for r in (prev["metrics"].get("top_queries", []) or [])}
|
||||
items = []
|
||||
for q, prev_row in prev_queries.items():
|
||||
curr_row = curr_queries.get(q)
|
||||
if not curr_row:
|
||||
continue
|
||||
prev_clicks = int(prev_row.get("clicks", 0) or 0)
|
||||
curr_clicks = int(curr_row.get("clicks", 0) or 0)
|
||||
if prev_clicks >= min_prev_clicks and curr_clicks < prev_clicks:
|
||||
drop_pct = ((prev_clicks - curr_clicks) / prev_clicks) * 100.0
|
||||
if drop_pct >= min_drop_pct:
|
||||
items.append({
|
||||
"query": q,
|
||||
"prev_clicks": prev_clicks,
|
||||
"curr_clicks": curr_clicks,
|
||||
"drop_pct": round(drop_pct, 2)
|
||||
})
|
||||
items.sort(key=lambda x: (x.get("drop_pct", 0), x.get("prev_clicks", 0)), reverse=True)
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": curr_range,
|
||||
"previous_range": prev["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"declining_queries tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_low_ctr_pages_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
min_impr = int(context.get("min_impressions", 200))
|
||||
ctr_threshold = float(context.get("ctr_threshold", 1.5))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
tp = result["metrics"].get("top_pages", []) or []
|
||||
items = []
|
||||
for r in tp:
|
||||
if (r.get("impressions", 0) >= min_impr and float(r.get("ctr", 0.0)) < ctr_threshold):
|
||||
items.append({
|
||||
"page": r.get("page"),
|
||||
"clicks": r.get("clicks", 0),
|
||||
"impressions": r.get("impressions", 0),
|
||||
"ctr": r.get("ctr", 0.0),
|
||||
"position": r.get("position"),
|
||||
"evidence_queries": r.get("queries", [])[:5]
|
||||
})
|
||||
items.sort(key=lambda x: (x.get("impressions", 0), -x.get("ctr", 100.0)), reverse=True)
|
||||
return {
|
||||
"items": items[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"low_ctr_pages tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _gsc_cannibalization_candidates_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
limit = int(context.get("limit", 10))
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
result = await self._fetch_gsc_analytics(start_date, end_date)
|
||||
candidates = result["metrics"].get("cannibalization", []) or []
|
||||
return {
|
||||
"items": candidates[:limit],
|
||||
"range": result["date_range"],
|
||||
"source": "gsc_cache"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"cannibalization_candidates tool failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def _default_seo_gsc_plan_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
start_date = context.get("start_date")
|
||||
end_date = context.get("end_date")
|
||||
try:
|
||||
low_ctr_pages = await self._gsc_low_ctr_pages_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
cannibals = await self._gsc_cannibalization_candidates_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
striking = await self._gsc_striking_distance_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
declining = await self._gsc_declining_queries_tool({"start_date": start_date, "end_date": end_date, "limit": 10})
|
||||
|
||||
actions = []
|
||||
for p in low_ctr_pages.get("items", []):
|
||||
actions.append({
|
||||
"type": "update_titles_meta",
|
||||
"target_page": p.get("page"),
|
||||
"justification": f"Low CTR {p.get('ctr')}% with {p.get('impressions')} impressions",
|
||||
"evidence": p.get("evidence_queries", [])
|
||||
})
|
||||
for c in cannibals.get("items", []):
|
||||
actions.append({
|
||||
"type": "consolidate/internal_link",
|
||||
"target_page": c.get("recommended_target_page"),
|
||||
"justification": f"Cannibalization on query '{c.get('query')}'",
|
||||
"pages": c.get("pages", [])
|
||||
})
|
||||
for q in striking.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"justification": f"Striking distance at position {q.get('position')} with {q.get('impressions')} impressions"
|
||||
})
|
||||
for q in declining.get("items", []):
|
||||
actions.append({
|
||||
"type": "refresh_content",
|
||||
"target": "query",
|
||||
"query": q.get("query"),
|
||||
"justification": f"Clicks decline {q.get('prev_clicks')}→{q.get('curr_clicks')} ({q.get('drop_pct')}%)"
|
||||
})
|
||||
|
||||
return {
|
||||
"plan_name": "Default SEO Plan from GSC",
|
||||
"range": {"current": {"start": start_date, "end": end_date}},
|
||||
"actions": actions,
|
||||
"source": "gsc_cache",
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"default_seo_gsc_plan failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
|
||||
class SocialAmplificationAgent(BaseALwrityAgent):
|
||||
"""
|
||||
|
||||
@@ -14,9 +14,9 @@ from .txtai_service import TxtaiIntelligenceService, TXTAI_AVAILABLE
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
# Optional txtai imports
|
||||
# Optional txtai imports (align with core agent framework)
|
||||
try:
|
||||
from txtai.pipeline import Agent, LLM
|
||||
from txtai import Agent, LLM
|
||||
except ImportError:
|
||||
Agent = None
|
||||
LLM = None
|
||||
@@ -28,9 +28,13 @@ class SharedLLMWrapper:
|
||||
|
||||
def generate(self, prompt: str, **kwargs) -> str:
|
||||
"""Generate text using the shared LLM provider."""
|
||||
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
|
||||
# but we could map them if needed.
|
||||
return llm_text_gen(prompt, user_id=self.user_id)
|
||||
try:
|
||||
# We ignore kwargs like 'max_tokens' as llm_text_gen handles defaults,
|
||||
# but we could map them if needed.
|
||||
return llm_text_gen(prompt, user_id=self.user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"SharedLLMWrapper failed to generate text: {e}")
|
||||
return f"[ERROR: Shared LLM generation failed for user {self.user_id}]"
|
||||
|
||||
def __call__(self, prompt: str, **kwargs) -> str:
|
||||
return self.generate(prompt, **kwargs)
|
||||
@@ -40,8 +44,9 @@ class LocalLLMWrapper:
|
||||
Lazily loads a local LLM via txtai.
|
||||
This prevents blocking server startup with heavy model loads.
|
||||
"""
|
||||
def __init__(self, model_path: str):
|
||||
def __init__(self, model_path: str, task: str = "text-generation"):
|
||||
self.model_path = model_path
|
||||
self.task = task
|
||||
self._llm = None
|
||||
|
||||
@property
|
||||
@@ -49,8 +54,9 @@ class LocalLLMWrapper:
|
||||
if self._llm is None:
|
||||
if LLM is None:
|
||||
raise ImportError("txtai.pipeline.LLM is not available")
|
||||
logger.info(f"Loading local LLM: {self.model_path}")
|
||||
self._llm = LLM(path=self.model_path)
|
||||
logger.info(f"Loading local LLM: {self.model_path} with task: {self.task}")
|
||||
# Explicitly set task to avoid 'text2text-generation' default failures
|
||||
self._llm = LLM(path=self.model_path, task=self.task)
|
||||
return self._llm
|
||||
|
||||
def __call__(self, prompt: str, **kwargs) -> str:
|
||||
@@ -67,11 +73,12 @@ class SIFBaseAgent(BaseALwrityAgent):
|
||||
|
||||
# 2. Local LLM for internal agent work (default for SIF agents)
|
||||
if llm is None:
|
||||
if TXTAI_AVAILABLE:
|
||||
# Use Lazy Local LLM
|
||||
llm = LocalLLMWrapper(model_name)
|
||||
if TXTAI_AVAILABLE and LLM is not None:
|
||||
# Use Lazy Local LLM when txtai LLM is available
|
||||
# Hardening: Specify 'text-generation' task to avoid text2text defaults
|
||||
llm = LocalLLMWrapper(model_name, task="text-generation")
|
||||
else:
|
||||
# Fallback to Shared if txtai not available
|
||||
# Fallback to Shared if txtai or LLM is not available
|
||||
llm = self.shared_llm
|
||||
|
||||
super().__init__(user_id, agent_type, model_name, llm)
|
||||
@@ -85,14 +92,18 @@ class SIFBaseAgent(BaseALwrityAgent):
|
||||
|
||||
def _create_txtai_agent(self):
|
||||
"""
|
||||
SIF agents use the intelligence service directly, but we can expose
|
||||
capabilities via a standard agent interface if needed.
|
||||
SIF agents primarily use the intelligence service directly, but we can expose
|
||||
capabilities via a standard agent interface if available.
|
||||
"""
|
||||
if not TXTAI_AVAILABLE:
|
||||
return None
|
||||
|
||||
# Return a simple agent that can use the LLM
|
||||
return Agent(llm=self.llm, tools=[])
|
||||
if not TXTAI_AVAILABLE or Agent is None:
|
||||
logger.debug(f"[{self.__class__.__name__}] txtai Agent not available, using fallback agent")
|
||||
return self._create_fallback_agent()
|
||||
|
||||
try:
|
||||
return Agent(llm=self.llm, tools=[])
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.__class__.__name__}] Failed to create txtai Agent: {e}")
|
||||
return self._create_fallback_agent()
|
||||
|
||||
class StrategyArchitectAgent(SIFBaseAgent):
|
||||
"""Agent for discovering content pillars and identifying strategic gaps."""
|
||||
|
||||
@@ -25,7 +25,18 @@ except ImportError:
|
||||
TXTAI_AVAILABLE = False
|
||||
|
||||
class TxtaiIntelligenceService:
|
||||
_instances = {}
|
||||
|
||||
def __new__(cls, user_id: str, *args, **kwargs):
|
||||
if user_id not in cls._instances:
|
||||
cls._instances[user_id] = super(TxtaiIntelligenceService, cls).__new__(cls)
|
||||
return cls._instances[user_id]
|
||||
|
||||
def __init__(self, user_id: str, model_path: Optional[str] = None, enable_caching: bool = True):
|
||||
# Singleton: prevent re-initialization if already initialized
|
||||
if getattr(self, "_singleton_initialized", False):
|
||||
return
|
||||
|
||||
self.user_id = user_id
|
||||
self.model_path = model_path or "sentence-transformers/all-MiniLM-L6-v2"
|
||||
self.index_path = f"workspace/workspace_{user_id}/indices/txtai"
|
||||
@@ -33,6 +44,11 @@ class TxtaiIntelligenceService:
|
||||
self._initialized = False
|
||||
self.enable_caching = enable_caching
|
||||
self.cache_manager = semantic_cache_manager if enable_caching else None
|
||||
self._backend = "faiss" # Default backend
|
||||
|
||||
# Mark as initialized for singleton pattern
|
||||
self._singleton_initialized = True
|
||||
|
||||
# Lazy initialization - do not initialize embeddings on startup
|
||||
# self._initialize_embeddings()
|
||||
|
||||
@@ -52,17 +68,26 @@ class TxtaiIntelligenceService:
|
||||
logger.debug(f"Model path: {self.model_path}")
|
||||
logger.debug(f"Index path: {self.index_path}")
|
||||
|
||||
# Close existing embeddings if any to release file locks
|
||||
if self.embeddings:
|
||||
try:
|
||||
if hasattr(self.embeddings, 'close'):
|
||||
self.embeddings.close()
|
||||
self.embeddings = None
|
||||
except Exception as close_err:
|
||||
logger.warning(f"Error closing existing embeddings: {close_err}")
|
||||
|
||||
# Ensure directory exists
|
||||
os.makedirs(os.path.dirname(self.index_path), exist_ok=True)
|
||||
logger.debug(f"Created index directory: {os.path.dirname(self.index_path)}")
|
||||
|
||||
# Initialize embeddings with optimal configuration for ALwrity use case
|
||||
# Hardening: Disabling quantization by default as it causes 'IndexIDMap' attribute errors with small indices on Windows
|
||||
self.embeddings = Embeddings({
|
||||
"path": self.model_path,
|
||||
"content": True, # Enable content storage for retrieval
|
||||
"objects": True, # Enable object storage for metadata
|
||||
"backend": "faiss", # Use Faiss for efficient similarity search
|
||||
"quantize": True, # Enable quantization for memory efficiency
|
||||
"backend": self._backend, # Use Faiss for efficient similarity search
|
||||
"batch": 32, # Batch size for processing
|
||||
"gpu": False, # Force CPU usage for compatibility
|
||||
"limit": 1000 # Maximum number of results for queries
|
||||
@@ -76,7 +101,12 @@ class TxtaiIntelligenceService:
|
||||
try:
|
||||
self.embeddings.load(self.index_path)
|
||||
logger.info(f"Successfully loaded existing txtai index for user {self.user_id}")
|
||||
logger.debug(f"Index contains {len(self.embeddings)} items")
|
||||
# Try to log count, handle if not supported
|
||||
try:
|
||||
count = self.embeddings.count() if hasattr(self.embeddings, 'count') else "unknown"
|
||||
logger.debug(f"Index contains {count} items")
|
||||
except:
|
||||
logger.debug("Index loaded (count unavailable)")
|
||||
except Exception as load_error:
|
||||
logger.warning(f"Failed to load existing index: {load_error}. Creating new index.")
|
||||
# Reset embeddings to create new index
|
||||
@@ -84,8 +114,7 @@ class TxtaiIntelligenceService:
|
||||
"path": self.model_path,
|
||||
"content": True,
|
||||
"objects": True,
|
||||
"backend": "faiss",
|
||||
"quantize": True,
|
||||
"backend": self._backend,
|
||||
"batch": 32,
|
||||
"gpu": False,
|
||||
"limit": 1000
|
||||
@@ -146,8 +175,15 @@ class TxtaiIntelligenceService:
|
||||
logger.error(f"Error indexing content for user {self.user_id}: {e}")
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
logger.error(f"Items count: {len(items) if items else 0}")
|
||||
if items and len(items) > 0:
|
||||
logger.error(f"Sample item structure: {type(items[0])}")
|
||||
|
||||
message = str(e)
|
||||
is_windows_lock_error = isinstance(e, PermissionError) or "WinError 32" in message
|
||||
if is_windows_lock_error:
|
||||
logger.warning(
|
||||
f"Txtai index save skipped for user {self.user_id} due to file lock. "
|
||||
f"The index will be retried on a future run."
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
async def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
@@ -172,7 +208,20 @@ class TxtaiIntelligenceService:
|
||||
logger.debug(f"Cache miss for search query: '{query}'")
|
||||
|
||||
logger.debug(f"Searching for query: '{query}' with limit: {limit}")
|
||||
results = self.embeddings.search(query, limit=limit)
|
||||
try:
|
||||
results = self.embeddings.search(query, limit=limit)
|
||||
except AttributeError as ae:
|
||||
if "nprobe" in str(ae):
|
||||
logger.error(f"Detected known txtai/faiss IndexIDMap/nprobe incompatibility for user {self.user_id}. Attempting re-init with numpy backend fallback...")
|
||||
# Switch to numpy backend which doesn't have this issue
|
||||
self._backend = "numpy"
|
||||
self._initialize_embeddings()
|
||||
if self.embeddings:
|
||||
results = self.embeddings.search(query, limit=limit)
|
||||
else:
|
||||
raise ae
|
||||
else:
|
||||
raise ae
|
||||
|
||||
# Cache the results if caching is enabled
|
||||
if self.enable_caching and self.cache_manager and results:
|
||||
@@ -216,7 +265,19 @@ class TxtaiIntelligenceService:
|
||||
logger.debug(f"Cache miss for similarity calculation")
|
||||
|
||||
logger.debug(f"Calculating similarity between texts: '{text1[:50]}...' and '{text2[:50]}...'")
|
||||
similarity = self.embeddings.similarity(text1, text2)
|
||||
try:
|
||||
similarity = self.embeddings.similarity(text1, text2)
|
||||
except AttributeError as ae:
|
||||
if "nprobe" in str(ae):
|
||||
logger.error(f"Detected IndexIDMap nprobe error in similarity for user {self.user_id}. Falling back to numpy backend...")
|
||||
self._backend = "numpy"
|
||||
self._initialize_embeddings()
|
||||
if self.embeddings:
|
||||
similarity = self.embeddings.similarity(text1, text2)
|
||||
else:
|
||||
raise ae
|
||||
else:
|
||||
raise ae
|
||||
|
||||
# Cache the similarity result
|
||||
if self.enable_caching and self.cache_manager:
|
||||
@@ -272,7 +333,19 @@ class TxtaiIntelligenceService:
|
||||
# Use graph-based clustering if available
|
||||
# Perform a search to get graph structure
|
||||
sample_query = "content marketing digital strategy"
|
||||
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
|
||||
try:
|
||||
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
|
||||
except AttributeError as ae:
|
||||
if "nprobe" in str(ae):
|
||||
logger.error(f"Detected IndexIDMap nprobe error in cluster for user {self.user_id}. Falling back to numpy backend...")
|
||||
self._backend = "numpy"
|
||||
self._initialize_embeddings()
|
||||
if self.embeddings:
|
||||
graph_results = self.embeddings.search(sample_query, limit=10, graph=True)
|
||||
else:
|
||||
raise ae
|
||||
else:
|
||||
raise ae
|
||||
|
||||
if not graph_results:
|
||||
logger.warning(f"No graph results for clustering user {self.user_id}")
|
||||
@@ -306,7 +379,7 @@ class TxtaiIntelligenceService:
|
||||
logger.error(f"Full traceback: {traceback.format_exc()}")
|
||||
return self._fallback_clustering(min_score)
|
||||
|
||||
def _fallback_clustering(self, min_score: float) -> List[List[int]]:
|
||||
async def _fallback_clustering(self, min_score: float) -> List[List[int]]:
|
||||
"""Fallback clustering method when graph clustering is not available."""
|
||||
logger.info(f"Using fallback clustering for user {self.user_id}")
|
||||
|
||||
@@ -318,7 +391,8 @@ class TxtaiIntelligenceService:
|
||||
all_clusters = []
|
||||
|
||||
for query in sample_queries:
|
||||
results = self.embeddings.search(query, limit=5)
|
||||
# Use our search wrapper for hardening
|
||||
results = await self.search(query, limit=5)
|
||||
if results and results[0].get("score", 0) >= min_score:
|
||||
# Create a cluster from similar results
|
||||
cluster = [i for i, result in enumerate(results) if result.get("score", 0) >= min_score]
|
||||
@@ -393,9 +467,13 @@ class TxtaiIntelligenceService:
|
||||
return {"status": "not_initialized", "user_id": self.user_id}
|
||||
|
||||
try:
|
||||
# Get count of indexed items - txtai doesn't have a direct len() method
|
||||
# We'll estimate based on available data or return a placeholder
|
||||
index_size = getattr(self.embeddings, 'count', 0) or "unknown"
|
||||
# Get count of indexed items
|
||||
index_size = "unknown"
|
||||
if hasattr(self.embeddings, 'count'):
|
||||
try:
|
||||
index_size = self.embeddings.count()
|
||||
except:
|
||||
pass
|
||||
|
||||
return {
|
||||
"status": "active",
|
||||
@@ -410,5 +488,7 @@ class TxtaiIntelligenceService:
|
||||
return {"status": "error", "user_id": self.user_id, "error": str(e)}
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
"""Check if the service is properly initialized."""
|
||||
"""Check if the service is properly initialized, triggering lazy init if needed."""
|
||||
if not self._initialized:
|
||||
self._ensure_initialized()
|
||||
return self._initialized and self.embeddings is not None
|
||||
|
||||
@@ -369,6 +369,12 @@ def huggingface_structured_json_response(
|
||||
response_text = re.sub(r'```\n?', '', response_text)
|
||||
response_text = response_text.strip()
|
||||
|
||||
# Fix common markdown artefacts that break JSON, e.g. lines starting with **"key":
|
||||
# **"narration": "text"
|
||||
# becomes:
|
||||
# "narration": "text"
|
||||
response_text = re.sub(r'^\s*\*\*(?=\s*")', '', response_text, flags=re.MULTILINE)
|
||||
|
||||
try:
|
||||
parsed_json = json.loads(response_text)
|
||||
logger.info("✅ Hugging Face structured JSON response parsed from text")
|
||||
|
||||
@@ -648,11 +648,13 @@ async def ai_video_generate(
|
||||
|
||||
# PRE-FLIGHT VALIDATION: Validate video generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
from fastapi import HTTPException
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Database session unavailable for user.")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
@@ -762,9 +764,11 @@ def track_video_usage(
|
||||
from datetime import datetime
|
||||
|
||||
from models.subscription_models import APIProvider, APIUsageLog, UsageSummary
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
return {}
|
||||
try:
|
||||
logger.info(f"[video_gen] Starting usage tracking for user={user_id}, provider={provider}, model={model_name}")
|
||||
pricing_service_track = PricingService(db_track)
|
||||
|
||||
@@ -527,6 +527,11 @@ class APIKeyManager:
|
||||
def __init__(self):
|
||||
self.api_keys = {}
|
||||
self._load_from_env()
|
||||
|
||||
def load_api_keys(self):
|
||||
self.api_keys = {}
|
||||
self._load_from_env()
|
||||
return self.api_keys
|
||||
|
||||
def _load_from_env(self):
|
||||
"""Load API keys from environment variables."""
|
||||
|
||||
@@ -27,6 +27,12 @@ async def generate_facebook_persona_task(user_id: str):
|
||||
try:
|
||||
logger.info(f"Scheduled Facebook persona generation started for user {user_id}")
|
||||
|
||||
# Ensure we have a valid session factory before trying to get session
|
||||
from services.database import SessionLocal
|
||||
if not SessionLocal:
|
||||
logger.error("Database session factory not initialized")
|
||||
return
|
||||
|
||||
db = get_db_session()
|
||||
if not db:
|
||||
logger.error(f"Failed to get database session for Facebook persona generation (user: {user_id})")
|
||||
|
||||
177
backend/services/podcast_bible_service.py
Normal file
177
backend/services/podcast_bible_service.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from typing import Dict, Any, Optional
|
||||
from loguru import logger
|
||||
from services.product_marketing.personalization_service import PersonalizationService
|
||||
from models.podcast_bible_models import (
|
||||
PodcastBible,
|
||||
HostPersona,
|
||||
AudienceDNA,
|
||||
BrandDNA,
|
||||
VisualStyle,
|
||||
AudioEnvironment,
|
||||
ShowRules
|
||||
)
|
||||
|
||||
class PodcastBibleService:
|
||||
"""Service for generating and managing the Podcast Bible."""
|
||||
|
||||
def __init__(self):
|
||||
self.personalization_service = PersonalizationService()
|
||||
|
||||
def generate_bible(self, user_id: str, project_id: str) -> PodcastBible:
|
||||
"""Generate a Podcast Bible from onboarding data."""
|
||||
logger.info(f"Generating Podcast Bible for user {user_id}")
|
||||
|
||||
try:
|
||||
preferences = self.personalization_service.get_user_preferences(user_id)
|
||||
writing_style = preferences.get("writing_style", {})
|
||||
style_prefs = preferences.get("style_preferences", {})
|
||||
target_audience = preferences.get("target_audience", {})
|
||||
industry = preferences.get("industry", "General Business")
|
||||
|
||||
# 1. Map Host Persona
|
||||
host = HostPersona(
|
||||
name="Your AI Host",
|
||||
background=f"Expert in {industry}",
|
||||
expertise_level=writing_style.get("complexity", "Expert").capitalize(),
|
||||
personality_traits=[
|
||||
writing_style.get("tone", "Professional").capitalize(),
|
||||
writing_style.get("engagement_level", "Informative").capitalize()
|
||||
],
|
||||
vocal_style=writing_style.get("voice", "Authoritative").capitalize(),
|
||||
vocal_characteristics=["Clear", "Articulate", writing_style.get("voice", "Steady")],
|
||||
look=f"A professional individual dressed in business-casual attire, fitting the {industry} industry aesthetic.",
|
||||
catchphrases=[]
|
||||
)
|
||||
|
||||
# 2. Map Audience DNA
|
||||
audience = AudienceDNA(
|
||||
expertise_level=target_audience.get("expertise_level", "Intermediate").capitalize(),
|
||||
interests=target_audience.get("interests", ["Industry Trends", "Innovation"]),
|
||||
pain_points=target_audience.get("pain_points", ["Staying ahead of competition", "Efficiency"]),
|
||||
demographics=None
|
||||
)
|
||||
|
||||
# 3. Map Brand DNA
|
||||
brand = BrandDNA(
|
||||
industry=industry,
|
||||
tone=writing_style.get("tone", "Professional").capitalize(),
|
||||
communication_style=writing_style.get("engagement_level", "Informative").capitalize(),
|
||||
key_messages=preferences.get("brand_values", []),
|
||||
competitor_context=None
|
||||
)
|
||||
|
||||
# 4. Map Visual Style
|
||||
visual = VisualStyle(
|
||||
style_preset=style_prefs.get("aesthetic", "Professional Studio").capitalize(),
|
||||
environment=f"A modern {industry}-themed podcast studio with professional equipment.",
|
||||
lighting="Soft, warm studio lighting with subtle rim lights.",
|
||||
color_palette=preferences.get("brand_colors", ["#1e293b", "#3b82f6"]),
|
||||
camera_style="Dynamic mid-shots with occasional close-ups for emphasis."
|
||||
)
|
||||
|
||||
# 5. Map Audio Environment
|
||||
audio_env = AudioEnvironment(
|
||||
soundscape="Pristine studio environment with deep, warm acoustics.",
|
||||
music_mood=f"{writing_style.get('tone', 'Professional').capitalize()} & {writing_style.get('engagement_level', 'Upbeat').capitalize()}",
|
||||
sfx_style="Modern, clean interface-inspired sounds."
|
||||
)
|
||||
|
||||
# 6. Map Show Rules
|
||||
show_rules = ShowRules(
|
||||
intro_format=f"Start with a high-energy hook about the episode topic, followed by a warm welcome and an overview of the {industry} insights to be shared.",
|
||||
outro_format="Summarize the key takeaways, provide a clear call to action, and sign off with a professional closing.",
|
||||
interaction_tone=writing_style.get("engagement_level", "Conversational").capitalize(),
|
||||
constraints=[
|
||||
"Avoid overly technical jargon unless defined",
|
||||
"Keep segments concise and factual",
|
||||
f"Maintain a {writing_style.get('tone', 'Professional')} tone at all times"
|
||||
]
|
||||
)
|
||||
|
||||
bible = PodcastBible(
|
||||
project_id=project_id,
|
||||
host=host,
|
||||
audience=audience,
|
||||
brand=brand,
|
||||
visual_style=visual,
|
||||
audio_environment=audio_env,
|
||||
show_rules=show_rules
|
||||
)
|
||||
|
||||
logger.info(f"Podcast Bible generated successfully for project {project_id}")
|
||||
return bible
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating Podcast Bible: {str(e)}")
|
||||
# Return a default bible if something goes wrong to ensure project creation doesn't fail
|
||||
return self._get_default_bible(project_id)
|
||||
|
||||
def _get_default_bible(self, project_id: str) -> PodcastBible:
|
||||
"""Return a sensible default Bible."""
|
||||
return PodcastBible(
|
||||
project_id=project_id,
|
||||
host=HostPersona(
|
||||
name="AI Host",
|
||||
background="Industry Professional",
|
||||
expertise_level="Expert",
|
||||
vocal_style="Authoritative",
|
||||
vocal_characteristics=["Deep", "Steady"]
|
||||
),
|
||||
audience=AudienceDNA(
|
||||
expertise_level="Intermediate",
|
||||
interests=["Industry Trends", "Technology"],
|
||||
pain_points=["Staying Competitive", "Operational Efficiency"]
|
||||
),
|
||||
brand=BrandDNA(
|
||||
industry="General Business",
|
||||
tone="Professional",
|
||||
communication_style="Analytical"
|
||||
),
|
||||
visual_style=VisualStyle(
|
||||
environment="Professional modern office studio",
|
||||
color_palette=["#000000", "#FFFFFF"]
|
||||
),
|
||||
audio_environment=AudioEnvironment(),
|
||||
show_rules=ShowRules(
|
||||
intro_format="Standard welcome and topic introduction.",
|
||||
outro_format="Summary and sign-off."
|
||||
)
|
||||
)
|
||||
|
||||
def serialize_bible(self, bible: PodcastBible) -> str:
|
||||
"""Serialize the Bible into a prompt-friendly text block."""
|
||||
return f"""
|
||||
<podcast_bible>
|
||||
HOST PERSONA:
|
||||
- Name: {bible.host.name}
|
||||
- Background: {bible.host.background}
|
||||
- Expertise Level: {bible.host.expertise_level}
|
||||
- Personality: {', '.join(bible.host.personality_traits)}
|
||||
- Vocal Style: {bible.host.vocal_style}
|
||||
- Vocal Characteristics: {', '.join(bible.host.vocal_characteristics)}
|
||||
- Visual Look: {bible.host.look}
|
||||
|
||||
TARGET AUDIENCE:
|
||||
- Expertise: {bible.audience.expertise_level}
|
||||
- Interests: {', '.join(bible.audience.interests)}
|
||||
- Pain Points: {', '.join(bible.audience.pain_points)}
|
||||
|
||||
BRAND & STYLE:
|
||||
- Industry: {bible.brand.industry}
|
||||
- Tone: {bible.brand.tone}
|
||||
- Communication Style: {bible.brand.communication_style}
|
||||
- Visual Style Preset: {bible.visual_style.style_preset}
|
||||
- Environment: {bible.visual_style.environment}
|
||||
- Lighting: {bible.visual_style.lighting}
|
||||
|
||||
AUDIO ENVIRONMENT:
|
||||
- Soundscape: {bible.audio_environment.soundscape}
|
||||
- Music Mood: {bible.audio_environment.music_mood}
|
||||
|
||||
SHOW RULES & STRUCTURE:
|
||||
- Intro Format: {bible.show_rules.intro_format}
|
||||
- Outro Format: {bible.show_rules.outro_format}
|
||||
- Interaction Tone: {bible.show_rules.interaction_tone}
|
||||
- Constraints: {', '.join(bible.show_rules.constraints)}
|
||||
</podcast_bible>
|
||||
"""
|
||||
@@ -11,6 +11,7 @@ from datetime import datetime
|
||||
import uuid
|
||||
|
||||
from models.podcast_models import PodcastProject
|
||||
from services.podcast_bible_service import PodcastBibleService
|
||||
|
||||
|
||||
class PodcastService:
|
||||
@@ -18,6 +19,7 @@ class PodcastService:
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.bible_service = PodcastBibleService()
|
||||
|
||||
def create_project(
|
||||
self,
|
||||
@@ -30,6 +32,9 @@ class PodcastService:
|
||||
**kwargs
|
||||
) -> PodcastProject:
|
||||
"""Create a new podcast project."""
|
||||
# Generate Podcast Bible automatically from onboarding data
|
||||
bible = self.bible_service.generate_bible(user_id, project_id)
|
||||
|
||||
project = PodcastProject(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
@@ -37,6 +42,7 @@ class PodcastService:
|
||||
duration=duration,
|
||||
speakers=speakers,
|
||||
budget_cap=budget_cap,
|
||||
bible=bible.model_dump() if bible else None,
|
||||
status="draft",
|
||||
current_step="create",
|
||||
**kwargs
|
||||
|
||||
@@ -5,13 +5,15 @@ Pluggable task scheduler that can work with any task model.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Dict, Any, Optional, List, Callable
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timedelta
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from apscheduler.triggers.cron import CronTrigger
|
||||
from apscheduler.triggers.interval import IntervalTrigger
|
||||
from apscheduler.triggers.date import DateTrigger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import text
|
||||
|
||||
from .executor_interface import TaskExecutor, TaskExecutionResult
|
||||
from .task_registry import TaskRegistry
|
||||
@@ -19,8 +21,10 @@ from .exception_handler import (
|
||||
SchedulerExceptionHandler, SchedulerException, TaskExecutionError, DatabaseError,
|
||||
TaskLoaderError, SchedulerConfigError
|
||||
)
|
||||
|
||||
from services.database import get_all_user_ids, get_session_for_user
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
from ..utils.user_job_store import get_user_job_store_name
|
||||
from models.scheduler_models import SchedulerEventLog
|
||||
from .interval_manager import determine_optimal_interval, adjust_check_interval_if_needed
|
||||
@@ -86,6 +90,9 @@ class TaskScheduler:
|
||||
}
|
||||
)
|
||||
|
||||
# Configure APScheduler to use unified logging system
|
||||
self._configure_apscheduler_logging()
|
||||
|
||||
# Task executor registry
|
||||
self.registry = TaskRegistry()
|
||||
|
||||
@@ -115,6 +122,21 @@ class TaskScheduler:
|
||||
}
|
||||
|
||||
self._running = False
|
||||
|
||||
# Local Desktop App: Always leader, no advisory locks needed
|
||||
self._leader_lock_key = int(os.getenv("SCHEDULER_LEADER_LOCK_KEY", "84321017"))
|
||||
self._leadership_check_interval_seconds = int(os.getenv("SCHEDULER_LEADERSHIP_CHECK_INTERVAL", "15"))
|
||||
self._leader_session = None
|
||||
self._is_leader = True # Always leader in local desktop app
|
||||
self._execution_enabled = True # Always enabled
|
||||
self._leader_since = datetime.utcnow().isoformat()
|
||||
self._last_leadership_check = None
|
||||
self._last_leadership_error = None
|
||||
|
||||
|
||||
# Execution lease registry (prevents duplicate redispatch across check cycles)
|
||||
self._task_leases: Dict[str, str] = {}
|
||||
self._task_lease_ttl_seconds = int(os.getenv("SCHEDULER_TASK_LEASE_TTL_SECONDS", "900"))
|
||||
|
||||
def _get_trigger_for_interval(self, interval_minutes: int):
|
||||
"""
|
||||
@@ -153,6 +175,144 @@ class TaskScheduler:
|
||||
self.registry.register(task_type, executor, task_loader)
|
||||
logger.info(f"Registered executor for task type: {task_type}")
|
||||
|
||||
def _configure_apscheduler_logging(self):
|
||||
"""Configure APScheduler to use unified logging system."""
|
||||
import logging
|
||||
|
||||
# Get APScheduler loggers and redirect them to unified logging
|
||||
apscheduler_logger = logging.getLogger("apscheduler")
|
||||
apscheduler_scheduler_logger = logging.getLogger("apscheduler.scheduler")
|
||||
apscheduler_executors_logger = logging.getLogger("apscheduler.executors")
|
||||
apscheduler_jobstores_logger = logging.getLogger("apscheduler.jobstores")
|
||||
|
||||
# Create a custom handler that redirects to unified logger
|
||||
class APSchedulerUnifiedHandler(logging.Handler):
|
||||
def __init__(self, service_logger):
|
||||
super().__init__()
|
||||
self.service_logger = service_logger
|
||||
|
||||
def emit(self, record):
|
||||
try:
|
||||
# Format the message
|
||||
msg = self.format(record)
|
||||
|
||||
# Map APScheduler log levels to unified logger
|
||||
if record.levelno >= logging.ERROR:
|
||||
self.service_logger.error(f"[APScheduler] {msg}")
|
||||
elif record.levelno >= logging.WARNING:
|
||||
self.service_logger.warning(f"[APScheduler] {msg}")
|
||||
elif record.levelno >= logging.INFO:
|
||||
self.service_logger.info(f"[APScheduler] {msg}")
|
||||
else:
|
||||
self.service_logger.debug(f"[APScheduler] {msg}")
|
||||
except Exception:
|
||||
# Don't let logging errors break the scheduler
|
||||
pass
|
||||
|
||||
# Create and add the handler
|
||||
unified_handler = APSchedulerUnifiedHandler(logger)
|
||||
unified_handler.setLevel(logging.DEBUG)
|
||||
|
||||
# Add handler to all APScheduler loggers
|
||||
apscheduler_logger.addHandler(unified_handler)
|
||||
apscheduler_scheduler_logger.addHandler(unified_handler)
|
||||
apscheduler_executors_logger.addHandler(unified_handler)
|
||||
apscheduler_jobstores_logger.addHandler(unified_handler)
|
||||
|
||||
# Set levels to capture all logs
|
||||
apscheduler_logger.setLevel(logging.DEBUG)
|
||||
apscheduler_scheduler_logger.setLevel(logging.DEBUG)
|
||||
apscheduler_executors_logger.setLevel(logging.DEBUG)
|
||||
apscheduler_jobstores_logger.setLevel(logging.DEBUG)
|
||||
|
||||
# Prevent propagation to avoid duplicate logs
|
||||
apscheduler_logger.propagate = False
|
||||
apscheduler_scheduler_logger.propagate = False
|
||||
apscheduler_executors_logger.propagate = False
|
||||
apscheduler_jobstores_logger.propagate = False
|
||||
|
||||
logger.info("APScheduler logging configured to use unified logging system")
|
||||
|
||||
|
||||
def _scheduler_identity(self) -> str:
|
||||
return f"{os.getenv('HOSTNAME', 'local')}-{os.getpid()}"
|
||||
|
||||
def _acquire_leadership(self) -> bool:
|
||||
"""Always return True for local desktop app (no HA needed)."""
|
||||
self._is_leader = True
|
||||
self._execution_enabled = True
|
||||
if not self._leader_since:
|
||||
self._leader_since = datetime.utcnow().isoformat()
|
||||
self._last_leadership_check = datetime.utcnow().isoformat()
|
||||
return True
|
||||
|
||||
def _release_leadership(self):
|
||||
"""No-op for local desktop app."""
|
||||
pass
|
||||
|
||||
def _sync_check_due_tasks_job(self):
|
||||
"""Ensure check_due_tasks job exists only for leader."""
|
||||
job = self.scheduler.get_job('check_due_tasks')
|
||||
if self._is_leader and self._execution_enabled:
|
||||
if job is None:
|
||||
self.scheduler.add_job(
|
||||
self._check_and_execute_due_tasks,
|
||||
trigger=self._get_trigger_for_interval(self.current_check_interval_minutes),
|
||||
id='check_due_tasks',
|
||||
replace_existing=True
|
||||
)
|
||||
else:
|
||||
if job is not None:
|
||||
self.scheduler.remove_job('check_due_tasks')
|
||||
|
||||
async def _leadership_tick(self):
|
||||
"""Periodic leadership check/renewal (Stub for local)."""
|
||||
if not self._running:
|
||||
return
|
||||
|
||||
self._acquire_leadership()
|
||||
self._sync_check_due_tasks_job()
|
||||
|
||||
def _acquire_task_lease(self, task_key: str) -> bool:
|
||||
"""Acquire in-memory lease for a task key if available/expired."""
|
||||
now = datetime.utcnow()
|
||||
expiry_str = self._task_leases.get(task_key)
|
||||
|
||||
if expiry_str:
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expiry_str)
|
||||
if expiry > now:
|
||||
return False
|
||||
except Exception:
|
||||
# Corrupted lease value: overwrite safely
|
||||
pass
|
||||
|
||||
expiry = now + timedelta(seconds=self._task_lease_ttl_seconds)
|
||||
self._task_leases[task_key] = expiry.isoformat()
|
||||
return True
|
||||
|
||||
def _release_task_lease(self, task_key: str):
|
||||
"""Release lease for task key."""
|
||||
if task_key in self._task_leases:
|
||||
del self._task_leases[task_key]
|
||||
|
||||
def _is_task_leased(self, task_key: str) -> bool:
|
||||
"""Check whether task key is currently leased and not expired."""
|
||||
expiry_str = self._task_leases.get(task_key)
|
||||
if not expiry_str:
|
||||
return False
|
||||
|
||||
try:
|
||||
expiry = datetime.fromisoformat(expiry_str)
|
||||
if expiry > datetime.utcnow():
|
||||
return True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Expired/corrupt lease gets cleaned up lazily
|
||||
self._release_task_lease(task_key)
|
||||
return False
|
||||
|
||||
async def start(self):
|
||||
"""Start the scheduler with intelligent interval adjustment."""
|
||||
if self._running:
|
||||
@@ -168,16 +328,21 @@ class TaskScheduler:
|
||||
)
|
||||
self.current_check_interval_minutes = initial_interval
|
||||
|
||||
# Add periodic job to check for due tasks
|
||||
self.scheduler.add_job(
|
||||
self._check_and_execute_due_tasks,
|
||||
trigger=self._get_trigger_for_interval(initial_interval),
|
||||
id='check_due_tasks',
|
||||
replace_existing=True
|
||||
)
|
||||
|
||||
self.scheduler.start()
|
||||
self._running = True
|
||||
|
||||
# Leadership monitor runs on all replicas; only leader executes due-task loop.
|
||||
self.scheduler.add_job(
|
||||
self._leadership_tick,
|
||||
trigger=IntervalTrigger(seconds=self._leadership_check_interval_seconds),
|
||||
id='leadership_monitor',
|
||||
replace_existing=True,
|
||||
max_instances=1,
|
||||
coalesce=True
|
||||
)
|
||||
|
||||
# Initial leader election
|
||||
await self._leadership_tick()
|
||||
|
||||
# Check for and execute any missed jobs that are still within grace period
|
||||
await self._execute_missed_jobs()
|
||||
@@ -206,7 +371,7 @@ class TaskScheduler:
|
||||
registered_types = self.registry.get_registered_types()
|
||||
active_strategies = self.stats.get('active_strategies_count', 0)
|
||||
|
||||
# Count OAuth token monitoring tasks from database (recurring weekly tasks)
|
||||
# Count tasks per user (Multi-tenant SQLite)
|
||||
oauth_tasks_count = 0
|
||||
website_analysis_tasks_count = 0
|
||||
platform_insights_tasks_count = 0
|
||||
@@ -323,126 +488,6 @@ class TaskScheduler:
|
||||
|
||||
startup_lines.append(f"{prefix} Job: {job.id} | Trigger: {trigger_type} | Next Run: {next_run}{user_context}")
|
||||
|
||||
# Add OAuth token monitoring tasks details
|
||||
# Show ALL OAuth tasks (active and inactive) for complete visibility
|
||||
if total_oauth_tasks > 0:
|
||||
try:
|
||||
user_ids = get_all_user_ids()
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
from models.oauth_token_monitoring_models import OAuthTokenMonitoringTask
|
||||
# Get ALL tasks for this user
|
||||
oauth_tasks = db.query(OAuthTokenMonitoringTask).all()
|
||||
|
||||
for idx, task in enumerate(oauth_tasks):
|
||||
is_last = idx == len(oauth_tasks) - 1 and website_analysis_tasks_count == 0 and platform_insights_tasks_count == 0 and len(all_jobs) == 0 and user_id == user_ids[-1]
|
||||
prefix = " ├─" # Simplified prefix logic for multi-user list
|
||||
|
||||
try:
|
||||
user_job_store = get_user_job_store_name(task.user_id, db)
|
||||
if user_job_store == 'default':
|
||||
logger.debug(
|
||||
f"[Scheduler] Job store extraction returned 'default' for user {task.user_id}. "
|
||||
f"This may indicate no onboarding data or website URL not found."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[Scheduler] Could not extract job store name for user {task.user_id}: {e}. "
|
||||
f"Using 'default'. Error type: {type(e).__name__}"
|
||||
)
|
||||
user_job_store = 'default'
|
||||
|
||||
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
|
||||
# Include status in the log line for visibility
|
||||
status_indicator = "✅" if task.status == 'active' else f"[{task.status}]"
|
||||
startup_lines.append(
|
||||
f"{prefix} Job: oauth_token_monitoring_{task.platform}_{task.user_id} | "
|
||||
f"Trigger: CronTrigger (Weekly) | Next Run: {next_check} | "
|
||||
f"User: {task.user_id} | Store: {user_job_store} | Platform: {task.platform} {status_indicator}"
|
||||
)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking OAuth tasks for user {user_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get OAuth token monitoring task details: {e}")
|
||||
|
||||
# Add website analysis tasks details
|
||||
if website_analysis_tasks_count > 0:
|
||||
try:
|
||||
user_ids = get_all_user_ids()
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
from models.website_analysis_monitoring_models import WebsiteAnalysisTask
|
||||
website_analysis_tasks = db.query(WebsiteAnalysisTask).all()
|
||||
|
||||
for idx, task in enumerate(website_analysis_tasks):
|
||||
is_last = idx == len(website_analysis_tasks) - 1 and platform_insights_tasks_count == 0 and len(all_jobs) == 0 and total_oauth_tasks == 0 and user_id == user_ids[-1]
|
||||
prefix = " ├─" # Simplified
|
||||
|
||||
try:
|
||||
user_job_store = get_user_job_store_name(task.user_id, db)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract job store name for user {task.user_id}: {e}")
|
||||
user_job_store = 'default'
|
||||
|
||||
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
|
||||
frequency = f"Every {task.frequency_days} days"
|
||||
task_type_label = "User Website" if task.task_type == 'user_website' else "Competitor"
|
||||
status_indicator = "✅" if task.status == 'active' else f"[{task.status}]"
|
||||
website_display = task.website_url[:50] + "..." if task.website_url and len(task.website_url) > 50 else (task.website_url or 'N/A')
|
||||
|
||||
startup_lines.append(
|
||||
f"{prefix} Job: website_analysis_{task.task_type}_{task.user_id}_{task.id} | "
|
||||
f"Trigger: CronTrigger ({frequency}) | Next Run: {next_check} | "
|
||||
f"User: {task.user_id} | Store: {user_job_store} | Type: {task_type_label} | URL: {website_display} {status_indicator}"
|
||||
)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking website analysis tasks for user {user_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get website analysis task details: {e}")
|
||||
|
||||
# Add platform insights tasks details
|
||||
if platform_insights_tasks_count > 0:
|
||||
try:
|
||||
user_ids = get_all_user_ids()
|
||||
for user_id in user_ids:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
from models.platform_insights_monitoring_models import PlatformInsightsTask
|
||||
platform_insights_tasks = db.query(PlatformInsightsTask).all()
|
||||
|
||||
for idx, task in enumerate(platform_insights_tasks):
|
||||
is_last = idx == len(platform_insights_tasks) - 1 and len(all_jobs) == 0 and total_oauth_tasks == 0 and website_analysis_tasks_count == 0 and user_id == user_ids[-1]
|
||||
prefix = " ├─" # Simplified
|
||||
|
||||
try:
|
||||
user_job_store = get_user_job_store_name(task.user_id, db)
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not extract job store name for user {task.user_id}: {e}")
|
||||
user_job_store = 'default'
|
||||
|
||||
next_check = task.next_check.isoformat() if task.next_check else 'Not scheduled'
|
||||
platform_label = task.platform.upper() if task.platform else 'Unknown'
|
||||
site_display = task.site_url[:50] + "..." if task.site_url and len(task.site_url) > 50 else (task.site_url or 'N/A')
|
||||
status_indicator = "✅" if task.status == 'active' else f"[{task.status}]"
|
||||
|
||||
startup_lines.append(
|
||||
f"{prefix} Job: platform_insights_{task.platform}_{task.user_id} | "
|
||||
f"Trigger: CronTrigger (Weekly) | Next Run: {next_check} | "
|
||||
f"User: {task.user_id} | Store: {user_job_store} | Platform: {platform_label} | Site: {site_display} {status_indicator}"
|
||||
)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error checking platform insights tasks for user {user_id}: {e}")
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not get platform insights task details: {e}")
|
||||
|
||||
# Add Advertools tasks details
|
||||
if advertools_tasks_count > 0:
|
||||
try:
|
||||
@@ -518,7 +563,15 @@ class TaskScheduler:
|
||||
|
||||
# Get final job count before shutdown
|
||||
all_jobs_before = self.scheduler.get_jobs()
|
||||
|
||||
|
||||
# Release leadership lock and stop leadership monitor
|
||||
try:
|
||||
if self.scheduler.get_job('leadership_monitor') is not None:
|
||||
self.scheduler.remove_job('leadership_monitor')
|
||||
except Exception:
|
||||
pass
|
||||
self._release_leadership()
|
||||
|
||||
# Shutdown scheduler
|
||||
self.scheduler.shutdown(wait=True)
|
||||
self._running = False
|
||||
@@ -569,6 +622,10 @@ class TaskScheduler:
|
||||
Main scheduler loop: check for due tasks and execute them.
|
||||
This runs periodically with intelligent interval adjustment based on active strategies.
|
||||
"""
|
||||
if not self._execution_enabled or not self._is_leader:
|
||||
logger.debug("[Scheduler] Skipping due-task loop on standby replica")
|
||||
return
|
||||
|
||||
await check_and_execute_due_tasks(self)
|
||||
|
||||
async def _adjust_check_interval_if_needed(self, db: Session):
|
||||
@@ -614,309 +671,156 @@ class TaskScheduler:
|
||||
except Exception as e:
|
||||
logger.warning(f"[Scheduler] Error checking for missed jobs: {e}")
|
||||
|
||||
async def trigger_interval_adjustment(self):
|
||||
"""
|
||||
Trigger immediate interval adjustment check.
|
||||
|
||||
This should be called when a strategy is activated or deactivated
|
||||
to immediately adjust the scheduler interval based on current active strategies.
|
||||
"""
|
||||
if not self._running:
|
||||
logger.debug("Scheduler not running, skipping interval adjustment")
|
||||
return
|
||||
|
||||
try:
|
||||
# Multi-tenant aware adjustment (iterates all users internally)
|
||||
await adjust_check_interval_if_needed(self)
|
||||
except Exception as e:
|
||||
logger.warning(f"Error triggering interval adjustment: {e}")
|
||||
|
||||
async def _validate_and_rebuild_cumulative_stats(self):
|
||||
"""
|
||||
Validate cumulative stats on scheduler startup and rebuild if needed.
|
||||
This ensures cumulative stats are accurate after restarts.
|
||||
|
||||
NOTE: Disabled in multi-tenant mode as there is no global database for cumulative stats.
|
||||
TODO: Implement per-user cumulative stats or a global admin database.
|
||||
Validate and rebuild cumulative stats if needed.
|
||||
Currently a placeholder for future implementation.
|
||||
"""
|
||||
logger.info("[Scheduler] Cumulative stats validation skipped (multi-tenant mode)")
|
||||
return
|
||||
|
||||
async def _process_task_type(self, task_type: str, db: Session, cycle_summary: Dict[str, Any] = None, user_id: str = None) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Process due tasks for a specific task type.
|
||||
|
||||
Returns:
|
||||
Summary dict with 'found', 'executed', 'failed' counts, or None if no tasks
|
||||
"""
|
||||
summary = {'found': 0, 'executed': 0, 'failed': 0}
|
||||
|
||||
pass
|
||||
|
||||
async def _process_task_type(
|
||||
self,
|
||||
task_type: str,
|
||||
db: Session,
|
||||
cycle_summary: Dict[str, Any],
|
||||
user_id: Optional[str] = None
|
||||
) -> Dict[str, int]:
|
||||
summary = {"found": 0, "executed": 0, "failed": 0}
|
||||
try:
|
||||
# Get task loader for this type
|
||||
try:
|
||||
task_loader = self.registry.get_task_loader(task_type)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Failed to get task loader for type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
return None
|
||||
|
||||
# Load due tasks (with error handling)
|
||||
try:
|
||||
due_tasks = task_loader(db)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Failed to load due tasks for type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
return None
|
||||
|
||||
if not due_tasks:
|
||||
return None
|
||||
|
||||
summary['found'] = len(due_tasks)
|
||||
self.stats['tasks_found'] += len(due_tasks)
|
||||
|
||||
# Execute tasks (with concurrency limit)
|
||||
execution_tasks = []
|
||||
skipped_count = 0
|
||||
for task in due_tasks:
|
||||
if len(self.active_executions) >= self.max_concurrent_executions:
|
||||
skipped_count = len(due_tasks) - len(execution_tasks)
|
||||
logger.warning(
|
||||
f"[Scheduler] ⚠️ Max concurrent executions reached ({self.max_concurrent_executions}), "
|
||||
f"skipping {skipped_count} tasks for {task_type}"
|
||||
)
|
||||
break
|
||||
|
||||
# Execute task asynchronously
|
||||
# Note: Each task gets its own database session to prevent concurrent access issues
|
||||
execution_task = asyncio.create_task(
|
||||
execute_task_async(self, task_type, task, summary, user_id=user_id)
|
||||
)
|
||||
|
||||
task_id = f"{task_type}_{getattr(task, 'id', id(task))}"
|
||||
self.active_executions[task_id] = execution_task
|
||||
|
||||
execution_tasks.append(execution_task)
|
||||
|
||||
# Wait for executions to complete (with timeout per task)
|
||||
if execution_tasks:
|
||||
await asyncio.wait(execution_tasks, timeout=300)
|
||||
|
||||
return summary
|
||||
|
||||
task_loader = self.registry.get_task_loader(task_type)
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Error processing task type {task_type}: {str(e)}",
|
||||
task_type=task_type,
|
||||
message=f"Failed to get task loader for type {task_type}: {str(e)}",
|
||||
user_id=user_id,
|
||||
context={"task_type": task_type},
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
self.stats["tasks_failed"] += 1
|
||||
return summary
|
||||
|
||||
|
||||
def _update_user_stats(self, user_id: Optional[int], success: bool):
|
||||
"""
|
||||
Update per-user statistics for user isolation tracking.
|
||||
|
||||
Args:
|
||||
user_id: User ID (None if user context not available)
|
||||
success: Whether task execution was successful
|
||||
"""
|
||||
if user_id is None:
|
||||
|
||||
try:
|
||||
tasks = task_loader(db)
|
||||
if not tasks:
|
||||
return summary
|
||||
|
||||
summary["found"] = len(tasks)
|
||||
max_concurrent = self.max_concurrent_executions
|
||||
|
||||
for task in tasks:
|
||||
task_id = getattr(task, "id", None)
|
||||
lease_key = f"{task_type}_{task_id or id(task)}"
|
||||
|
||||
if self._is_task_leased(lease_key):
|
||||
continue
|
||||
|
||||
if len(self.active_executions) >= max_concurrent:
|
||||
break
|
||||
|
||||
if not self._acquire_task_lease(lease_key):
|
||||
continue
|
||||
|
||||
execution_task = asyncio.create_task(
|
||||
execute_task_async(
|
||||
self,
|
||||
task_type,
|
||||
task,
|
||||
summary,
|
||||
execution_source="scheduler",
|
||||
user_id=user_id,
|
||||
)
|
||||
)
|
||||
self.active_executions[lease_key] = execution_task
|
||||
|
||||
cycle_summary.setdefault("tasks_found_by_type", {})
|
||||
cycle_summary.setdefault("tasks_executed_by_type", {})
|
||||
cycle_summary.setdefault("tasks_failed_by_type", {})
|
||||
|
||||
cycle_summary["tasks_found_by_type"][task_type] = (
|
||||
cycle_summary["tasks_found_by_type"].get(task_type, 0)
|
||||
+ summary["found"]
|
||||
)
|
||||
cycle_summary["tasks_executed_by_type"][task_type] = (
|
||||
cycle_summary["tasks_executed_by_type"].get(task_type, 0)
|
||||
+ summary["executed"]
|
||||
)
|
||||
cycle_summary["tasks_failed_by_type"][task_type] = (
|
||||
cycle_summary["tasks_failed_by_type"].get(task_type, 0)
|
||||
+ summary["failed"]
|
||||
)
|
||||
|
||||
return summary
|
||||
except Exception as e:
|
||||
error = TaskLoaderError(
|
||||
message=f"Error processing task type {task_type}: {str(e)}",
|
||||
user_id=user_id,
|
||||
context={"task_type": task_type},
|
||||
original_error=e
|
||||
)
|
||||
self.exception_handler.handle_exception(error)
|
||||
self.stats["tasks_failed"] += 1
|
||||
return summary
|
||||
|
||||
def _update_user_stats(self, user_id: Optional[str], success: bool):
|
||||
if not user_id:
|
||||
return
|
||||
|
||||
if user_id not in self.stats['per_user_stats']:
|
||||
self.stats['per_user_stats'][user_id] = {
|
||||
'executed': 0,
|
||||
'failed': 0,
|
||||
'success_rate': 0.0
|
||||
}
|
||||
|
||||
user_stats = self.stats['per_user_stats'][user_id]
|
||||
per_user = self.stats.setdefault("per_user_stats", {})
|
||||
user_stats = per_user.setdefault(
|
||||
user_id,
|
||||
{
|
||||
"tasks_executed": 0,
|
||||
"tasks_failed": 0,
|
||||
"last_update": None,
|
||||
},
|
||||
)
|
||||
if success:
|
||||
user_stats['executed'] += 1
|
||||
user_stats["tasks_executed"] += 1
|
||||
else:
|
||||
user_stats['failed'] += 1
|
||||
|
||||
# Calculate success rate
|
||||
total = user_stats['executed'] + user_stats['failed']
|
||||
if total > 0:
|
||||
user_stats['success_rate'] = (user_stats['executed'] / total) * 100.0
|
||||
|
||||
async def _schedule_retry(self, task: Any, delay_seconds: int):
|
||||
"""Schedule a retry for a failed task."""
|
||||
# This would update the task's next_execution time
|
||||
# For now, just log - could be enhanced to update next_execution
|
||||
logger.debug(f"Scheduling retry for task in {delay_seconds}s")
|
||||
|
||||
def get_stats(self, user_id: Optional[int] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Get scheduler statistics with optional user filtering.
|
||||
|
||||
Args:
|
||||
user_id: Optional user ID to filter statistics for specific user
|
||||
|
||||
Returns:
|
||||
Dictionary with scheduler statistics
|
||||
"""
|
||||
base_stats = {
|
||||
**{k: v for k, v in self.stats.items() if k not in ['per_user_stats']},
|
||||
'active_executions': len(self.active_executions),
|
||||
'registered_types': self.registry.get_registered_types(),
|
||||
'running': self._running,
|
||||
'check_interval_minutes': self.current_check_interval_minutes,
|
||||
'min_check_interval_minutes': self.min_check_interval_minutes,
|
||||
'max_check_interval_minutes': self.max_check_interval_minutes,
|
||||
'intelligent_scheduling': True
|
||||
}
|
||||
|
||||
# Include per-user stats (all users or filtered)
|
||||
if user_id is not None:
|
||||
if user_id in self.stats['per_user_stats']:
|
||||
base_stats['user_stats'] = self.stats['per_user_stats'][user_id]
|
||||
else:
|
||||
base_stats['user_stats'] = {
|
||||
'executed': 0,
|
||||
'failed': 0,
|
||||
'success_rate': 0.0
|
||||
}
|
||||
else:
|
||||
# Include all per-user stats (for admin/debugging)
|
||||
base_stats['per_user_stats'] = self.stats['per_user_stats']
|
||||
|
||||
return base_stats
|
||||
|
||||
user_stats["tasks_failed"] += 1
|
||||
user_stats["last_update"] = datetime.utcnow().isoformat()
|
||||
|
||||
async def _schedule_retry(self, task: Any, retry_delay: int):
|
||||
try:
|
||||
task_id = getattr(task, "id", None)
|
||||
logger.warning(
|
||||
f"[Scheduler] Retry requested for task {task_id} in {retry_delay}s, "
|
||||
f"using loader-based retry semantics."
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def schedule_one_time_task(
|
||||
self,
|
||||
func: Callable,
|
||||
run_date: datetime,
|
||||
job_id: str,
|
||||
args: tuple = (),
|
||||
kwargs: Dict[str, Any] = None,
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
replace_existing: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Schedule a one-time task to run at a specific datetime.
|
||||
Schedule a one-time task execution.
|
||||
|
||||
Args:
|
||||
func: Async function to execute
|
||||
run_date: Datetime when the task should run (must be timezone-aware UTC)
|
||||
job_id: Unique identifier for this job
|
||||
args: Positional arguments to pass to func
|
||||
kwargs: Keyword arguments to pass to func
|
||||
replace_existing: If True, replace existing job with same ID
|
||||
func: Function to execute
|
||||
run_date: Date/time to run the task
|
||||
job_id: Unique job ID
|
||||
kwargs: Keyword arguments for the function
|
||||
replace_existing: Whether to replace existing job with same ID
|
||||
|
||||
Returns:
|
||||
Job ID
|
||||
"""
|
||||
if not self._running:
|
||||
logger.warning(
|
||||
f"Scheduler not running, but scheduling job {job_id} anyway. "
|
||||
"APScheduler will start automatically when needed."
|
||||
)
|
||||
|
||||
try:
|
||||
# Ensure run_date is timezone-aware (UTC)
|
||||
if run_date.tzinfo is None:
|
||||
from datetime import timezone
|
||||
run_date = run_date.replace(tzinfo=timezone.utc)
|
||||
logger.debug(f"Added UTC timezone to run_date: {run_date}")
|
||||
|
||||
self.scheduler.add_job(
|
||||
func,
|
||||
trigger=DateTrigger(run_date=run_date),
|
||||
args=args,
|
||||
kwargs=kwargs or {},
|
||||
id=job_id,
|
||||
kwargs=kwargs or {},
|
||||
replace_existing=replace_existing,
|
||||
misfire_grace_time=3600 # 1 hour grace period for missed jobs
|
||||
misfire_grace_time=3600 # 1 hour grace period
|
||||
)
|
||||
|
||||
# Get updated job count
|
||||
all_jobs = self.scheduler.get_jobs()
|
||||
one_time_jobs = [j for j in all_jobs if j.id != 'check_due_tasks']
|
||||
|
||||
# Extract user_id from kwargs if available for logging and job store
|
||||
user_id = kwargs.get('user_id', None) if kwargs else None
|
||||
func_name = func.__name__ if hasattr(func, '__name__') else str(func)
|
||||
|
||||
# Get job store name for user (if user_id provided)
|
||||
job_store_name = 'default'
|
||||
if user_id:
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
if db:
|
||||
job_store_name = get_user_job_store_name(user_id, db)
|
||||
db.close()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not determine job store for user {user_id}: {e}")
|
||||
|
||||
# Note: APScheduler doesn't support dynamic job store creation
|
||||
# We use 'default' for all jobs but log the user's job store name for debugging
|
||||
# The actual user isolation is handled through task filtering by user_id
|
||||
|
||||
# Log detailed one-time task scheduling information (use WARNING level for visibility)
|
||||
log_message = (
|
||||
f"[Scheduler] 📅 Scheduled One-Time Task\n"
|
||||
f" ├─ Job ID: {job_id}\n"
|
||||
f" ├─ Function: {func_name}\n"
|
||||
f" ├─ User ID: {user_id or 'system'}\n"
|
||||
f" ├─ Job Store: {job_store_name} (user context)\n"
|
||||
f" ├─ Scheduled For: {run_date}\n"
|
||||
f" ├─ Replace Existing: {replace_existing}\n"
|
||||
f" ├─ Total One-Time Jobs: {len(one_time_jobs)}\n"
|
||||
f" └─ Total Scheduled Jobs: {len(all_jobs)}"
|
||||
)
|
||||
logger.warning(log_message)
|
||||
|
||||
# Log job scheduling to event log for dashboard
|
||||
if user_id:
|
||||
try:
|
||||
event_db = get_session_for_user(user_id)
|
||||
if event_db:
|
||||
event_log = SchedulerEventLog(
|
||||
event_type='job_scheduled',
|
||||
event_date=datetime.utcnow(),
|
||||
job_id=job_id,
|
||||
job_type='one_time',
|
||||
user_id=user_id,
|
||||
event_data={
|
||||
'function_name': func_name,
|
||||
'job_store': job_store_name,
|
||||
'scheduled_for': run_date.isoformat(),
|
||||
'replace_existing': replace_existing
|
||||
}
|
||||
)
|
||||
event_db.add(event_log)
|
||||
event_db.commit()
|
||||
event_db.close()
|
||||
except Exception as e:
|
||||
logger.debug(f"Failed to log job scheduling event: {e}")
|
||||
|
||||
logger.info(f"Scheduled one-time task {job_id} at {run_date}")
|
||||
return job_id
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to schedule one-time task {job_id}: {e}")
|
||||
raise
|
||||
|
||||
def is_running(self) -> bool:
|
||||
"""Check if scheduler is running."""
|
||||
return self._running
|
||||
|
||||
async def execute_task_by_type(self, task_type: str, user_id: str, payload: Dict[str, Any]):
|
||||
"""
|
||||
Execute a task by type and payload immediately.
|
||||
Used for one-time tasks triggered by system events.
|
||||
"""
|
||||
from collections import namedtuple
|
||||
TaskStub = namedtuple('TaskStub', ['user_id', 'payload', 'id'])
|
||||
task_stub = TaskStub(user_id=user_id, payload=payload, id=f"manual_{datetime.utcnow().timestamp()}")
|
||||
|
||||
await execute_task_async(self, task_type, task_stub, execution_source="manual")
|
||||
|
||||
|
||||
@@ -67,6 +67,77 @@ class StoryImageGenerationService:
|
||||
clean_title = "".join(c if c.isalnum() or c in ('-', '_') else '_' for c in scene_title[:30])
|
||||
unique_id = str(uuid.uuid4())[:8]
|
||||
return f"scene_{scene_number}_{clean_title}_{unique_id}.png"
|
||||
|
||||
def _refine_image_prompt_with_bible(
|
||||
self,
|
||||
image_prompt: str,
|
||||
scene: Dict[str, Any],
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Lightweight image prompt refinement using the anime story bible.
|
||||
|
||||
Takes the existing scene image_prompt and enriches it with visual_style,
|
||||
world, and cast hints from the bible. This is deterministic and avoids
|
||||
extra LLM calls.
|
||||
"""
|
||||
if not image_prompt or not isinstance(image_prompt, str):
|
||||
return image_prompt
|
||||
|
||||
if not anime_bible or not isinstance(anime_bible, dict):
|
||||
return image_prompt
|
||||
|
||||
visual_style = anime_bible.get("visual_style") or {}
|
||||
world = anime_bible.get("world") or {}
|
||||
main_cast = anime_bible.get("main_cast") or []
|
||||
|
||||
parts: List[str] = []
|
||||
|
||||
style_preset = visual_style.get("style_preset")
|
||||
if style_preset:
|
||||
parts.append(f"{style_preset} anime illustration style")
|
||||
|
||||
camera_style = visual_style.get("camera_style")
|
||||
if camera_style:
|
||||
parts.append(f"framing and camera style: {camera_style}")
|
||||
|
||||
color_mood = visual_style.get("color_mood")
|
||||
if color_mood:
|
||||
parts.append(f"color mood: {color_mood}")
|
||||
|
||||
lighting = visual_style.get("lighting")
|
||||
if lighting:
|
||||
parts.append(f"lighting: {lighting}")
|
||||
|
||||
line_style = visual_style.get("line_style")
|
||||
if line_style:
|
||||
parts.append(f"line style: {line_style}")
|
||||
|
||||
extra_tags = visual_style.get("extra_tags") or []
|
||||
if isinstance(extra_tags, (list, tuple)):
|
||||
extra_text = ", ".join(str(tag) for tag in extra_tags[:6] if tag)
|
||||
if extra_text:
|
||||
parts.append(extra_text)
|
||||
|
||||
setting = world.get("setting") if isinstance(world, dict) else None
|
||||
if setting:
|
||||
parts.append(f"world setting: {setting}")
|
||||
|
||||
if isinstance(main_cast, list):
|
||||
names = [
|
||||
c.get("name")
|
||||
for c in main_cast
|
||||
if isinstance(c, dict) and c.get("name")
|
||||
]
|
||||
if names:
|
||||
joined = ", ".join(names[:4])
|
||||
parts.append(f"keep character designs consistent for: {joined}")
|
||||
|
||||
if not parts:
|
||||
return image_prompt
|
||||
|
||||
suffix = ", " + ", ".join(parts)
|
||||
return image_prompt.strip() + suffix
|
||||
|
||||
def generate_scene_image(
|
||||
self,
|
||||
@@ -75,7 +146,8 @@ class StoryImageGenerationService:
|
||||
provider: Optional[str] = None,
|
||||
width: int = 1024,
|
||||
height: int = 1024,
|
||||
model: Optional[str] = None
|
||||
model: Optional[str] = None,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate an image for a single story scene.
|
||||
@@ -94,6 +166,16 @@ class StoryImageGenerationService:
|
||||
scene_number = scene.get("scene_number", 0)
|
||||
scene_title = scene.get("title", "Untitled")
|
||||
image_prompt = scene.get("image_prompt", "")
|
||||
|
||||
if anime_bible:
|
||||
try:
|
||||
image_prompt = self._refine_image_prompt_with_bible(
|
||||
image_prompt=image_prompt,
|
||||
scene=scene,
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[StoryImageGeneration] Failed to refine image prompt with bible: {e}")
|
||||
|
||||
if not image_prompt:
|
||||
raise ValueError(f"Scene {scene_number} ({scene_title}) has no image_prompt")
|
||||
@@ -156,7 +238,8 @@ class StoryImageGenerationService:
|
||||
height: int = 1024,
|
||||
model: Optional[str] = None,
|
||||
progress_callback: Optional[callable] = None,
|
||||
db: Optional[Session] = None
|
||||
db: Optional[Session] = None,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Generate images for multiple story scenes.
|
||||
@@ -192,7 +275,7 @@ class StoryImageGenerationService:
|
||||
width=width,
|
||||
height=height,
|
||||
model=model,
|
||||
db=db
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
|
||||
image_results.append(image_result)
|
||||
@@ -295,4 +378,3 @@ class StoryImageGenerationService:
|
||||
except Exception as e:
|
||||
logger.error(f"[StoryImageGeneration] Error regenerating image for scene {scene_number}: {e}")
|
||||
raise RuntimeError(f"Failed to regenerate image for scene {scene_number}: {str(e)}") from e
|
||||
|
||||
|
||||
@@ -57,6 +57,7 @@ class StoryOutlineMixin(StoryServiceBase):
|
||||
ending_preference: str,
|
||||
user_id: str,
|
||||
use_structured_output: bool = True,
|
||||
include_anime_bible: bool = False,
|
||||
) -> Any:
|
||||
"""Generate a story outline with optional structured JSON output."""
|
||||
persona_prompt = self.build_persona_prompt(
|
||||
|
||||
@@ -145,20 +145,45 @@ Write ONLY the premise sentence(s). Do not write anything else.
|
||||
"reasoning",
|
||||
],
|
||||
},
|
||||
"minItems": 1,
|
||||
"maxItems": 1,
|
||||
}
|
||||
},
|
||||
"required": ["options"],
|
||||
}
|
||||
|
||||
def _build_idea_enhance_schema(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"suggestions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"idea": {"type": "string"},
|
||||
"whats_missing": {"type": "string"},
|
||||
"why_choose": {"type": "string"},
|
||||
},
|
||||
"required": ["idea", "whats_missing", "why_choose"],
|
||||
},
|
||||
"minItems": 3,
|
||||
"maxItems": 3,
|
||||
}
|
||||
},
|
||||
"required": ["options"],
|
||||
"required": ["suggestions"],
|
||||
}
|
||||
|
||||
def generate_story_setup_options(
|
||||
self,
|
||||
*,
|
||||
story_idea: str,
|
||||
story_mode: str | None,
|
||||
story_template: str | None,
|
||||
brand_context: Dict[str, Any] | None,
|
||||
user_id: str,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate 3 story setup options from a user's story idea."""
|
||||
"""Generate a single story setup option from a user's story idea."""
|
||||
|
||||
suggested_writing_styles = ['Formal', 'Casual', 'Poetic', 'Humorous', 'Academic', 'Journalistic', 'Narrative']
|
||||
suggested_story_tones = ['Dark', 'Uplifting', 'Suspenseful', 'Whimsical', 'Melancholic', 'Mysterious', 'Romantic', 'Adventurous']
|
||||
@@ -167,12 +192,59 @@ Write ONLY the premise sentence(s). Do not write anything else.
|
||||
suggested_content_ratings = ['G', 'PG', 'PG-13', 'R']
|
||||
suggested_ending_preferences = ['Happy', 'Tragic', 'Cliffhanger', 'Twist', 'Open-ended', 'Bittersweet']
|
||||
|
||||
mode_label = None
|
||||
if story_mode == "marketing":
|
||||
mode_label = "Non-fiction marketing story (brand or product campaign)"
|
||||
elif story_mode == "pure":
|
||||
mode_label = "Fiction story"
|
||||
|
||||
template_label = None
|
||||
if story_template == "product_story":
|
||||
template_label = "Product Story"
|
||||
elif story_template == "brand_manifesto":
|
||||
template_label = "Brand Manifesto"
|
||||
elif story_template == "founder_story":
|
||||
template_label = "Founder Story"
|
||||
elif story_template == "customer_story":
|
||||
template_label = "Customer Story"
|
||||
elif story_template == "short_fiction":
|
||||
template_label = "Short Fiction"
|
||||
elif story_template == "long_fiction":
|
||||
template_label = "Long Fiction"
|
||||
elif story_template == "anime_fiction":
|
||||
template_label = "Anime Fiction"
|
||||
elif story_template == "experimental_fiction":
|
||||
template_label = "Experimental Fiction"
|
||||
|
||||
brand_name = None
|
||||
writing_tone = None
|
||||
audience_description = None
|
||||
if isinstance(brand_context, dict):
|
||||
brand_name = brand_context.get("brand_name")
|
||||
writing_tone = brand_context.get("writing_tone")
|
||||
target_audience = brand_context.get("target_audience")
|
||||
if isinstance(target_audience, dict):
|
||||
audience_description = target_audience.get("description") or target_audience.get("summary")
|
||||
elif isinstance(target_audience, str):
|
||||
audience_description = target_audience
|
||||
|
||||
setup_prompt = f"""\
|
||||
You are an expert story writer and creative writing assistant. A user has provided the following story idea or information:
|
||||
You are an expert story writer and creative writing assistant.
|
||||
|
||||
{"This is a " + mode_label + "." if mode_label else ""}
|
||||
{("The user selected the template: " + template_label + ".") if template_label else ""}
|
||||
|
||||
The story should stay consistent with the brand and audience context below when relevant:
|
||||
|
||||
- Brand name or site: {brand_name or "Not specified"}
|
||||
- Headline/overall writing tone: {writing_tone or "Not specified"}
|
||||
- Audience description: {audience_description or "Not specified"}
|
||||
|
||||
The user has provided the following story idea or information:
|
||||
|
||||
{story_idea}
|
||||
|
||||
Based on this story idea, generate exactly 3 different, well-thought-out story setup options. Each option should be CREATIVE, PERSONALIZED, and perfectly tailored to the user's specific story idea.
|
||||
Based on this story idea, generate exactly 1 well-thought-out story setup option. The setup should be CREATIVE, PERSONALIZED, and perfectly tailored to the user's specific story idea.
|
||||
|
||||
**CRITICAL - Creative Freedom:**
|
||||
- You have COMPLETE FREEDOM to craft personalized values that best fit the user's story idea
|
||||
@@ -183,7 +255,7 @@ Based on this story idea, generate exactly 3 different, well-thought-out story s
|
||||
- Narrative POV: "Second Person (You)" or "Omniscient Narrator as Guide" (not just standard options)
|
||||
- The goal is to create the PERFECT setup for THIS specific story, not to fit into generic categories
|
||||
|
||||
Each option should:
|
||||
The setup should:
|
||||
1. Have a unique and creative persona that fits the story idea perfectly
|
||||
2. Define a compelling story setting that brings the idea to life
|
||||
3. Describe interesting and engaging characters
|
||||
@@ -212,23 +284,23 @@ Each option should:
|
||||
|
||||
**Remember:** These are ONLY suggestions. If a custom value better serves the story idea, CREATE IT!
|
||||
|
||||
Return exactly 3 options as a JSON array. Each option must include a "premise" field with the story premise.
|
||||
Return exactly 1 option as a JSON array with a single object in "options". The object must include a "premise" field with the story premise.
|
||||
"""
|
||||
|
||||
setup_schema = self._build_setup_schema()
|
||||
|
||||
try:
|
||||
logger.info(f"[StoryWriter] Generating story setup options for user {user_id}")
|
||||
logger.info(f"[StoryWriter] Generating story setup option for user {user_id}")
|
||||
response = self.load_json_response(
|
||||
llm_text_gen(prompt=setup_prompt, json_struct=setup_schema, user_id=user_id)
|
||||
)
|
||||
|
||||
options = response.get("options", [])
|
||||
if len(options) != 3:
|
||||
logger.warning(f"[StoryWriter] Expected 3 options but got {len(options)}, correcting count")
|
||||
if len(options) < 3:
|
||||
raise ValueError(f"Expected 3 options but got {len(options)}")
|
||||
options = options[:3]
|
||||
if len(options) != 1:
|
||||
logger.warning(f"[StoryWriter] Expected 1 option but got {len(options)}, correcting count")
|
||||
if len(options) < 1:
|
||||
raise ValueError(f"Expected 1 option but got {len(options)}")
|
||||
options = options[:1]
|
||||
|
||||
for idx, option in enumerate(options):
|
||||
if not option.get("premise") or not option.get("premise", "").strip():
|
||||
@@ -262,7 +334,7 @@ Return exactly 3 options as a JSON array. Each option must include a "premise" f
|
||||
premise += "."
|
||||
option["premise"] = premise
|
||||
|
||||
logger.info(f"[StoryWriter] Generated {len(options)} story setup options with premises for user {user_id}")
|
||||
logger.info(f"[StoryWriter] Generated {len(options)} story setup option(s) with premise for user {user_id}")
|
||||
return options
|
||||
except HTTPException:
|
||||
raise
|
||||
@@ -273,3 +345,119 @@ Return exactly 3 options as a JSON array. Each option must include a "premise" f
|
||||
logger.error(f"[StoryWriter] Error generating story setup options: {exc}")
|
||||
raise RuntimeError(f"Failed to generate story setup options: {exc}") from exc
|
||||
|
||||
def enhance_story_idea(
|
||||
self,
|
||||
*,
|
||||
story_idea: str,
|
||||
story_mode: str | None,
|
||||
story_template: str | None,
|
||||
brand_context: Dict[str, Any] | None,
|
||||
user_id: str,
|
||||
fiction_variant: str | None = None,
|
||||
narrative_energy: str | None = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
mode_label = None
|
||||
if story_mode == "marketing":
|
||||
mode_label = "Non-fiction marketing story (brand or product campaign)"
|
||||
elif story_mode == "pure":
|
||||
mode_label = "Fiction story"
|
||||
|
||||
template_label = None
|
||||
if story_template == "product_story":
|
||||
template_label = "Product Story"
|
||||
elif story_template == "brand_manifesto":
|
||||
template_label = "Brand Manifesto"
|
||||
elif story_template == "founder_story":
|
||||
template_label = "Founder Story"
|
||||
elif story_template == "customer_story":
|
||||
template_label = "Customer Story"
|
||||
elif story_template == "short_fiction":
|
||||
template_label = "Short Fiction"
|
||||
elif story_template == "long_fiction":
|
||||
template_label = "Long Fiction"
|
||||
elif story_template == "anime_fiction":
|
||||
template_label = "Anime Fiction"
|
||||
elif story_template == "experimental_fiction":
|
||||
template_label = "Experimental Fiction"
|
||||
|
||||
brand_name = None
|
||||
writing_tone = None
|
||||
audience_description = None
|
||||
if isinstance(brand_context, dict):
|
||||
brand_name = brand_context.get("brand_name")
|
||||
writing_tone = brand_context.get("writing_tone")
|
||||
target_audience = brand_context.get("target_audience")
|
||||
if isinstance(target_audience, dict):
|
||||
audience_description = target_audience.get("description") or target_audience.get("summary")
|
||||
elif isinstance(target_audience, str):
|
||||
audience_description = target_audience
|
||||
|
||||
fiction_focus_line = ""
|
||||
if fiction_variant:
|
||||
fiction_focus_line = f'Treat the story as "{fiction_variant}" and lean into that creative focus.'
|
||||
|
||||
energy_line = ""
|
||||
if narrative_energy:
|
||||
energy_line = f'Target narrative energy: {narrative_energy}.'
|
||||
|
||||
enhance_prompt = f"""You are a creative writing coach helping a user refine and expand a story idea.
|
||||
|
||||
{"This is a " + mode_label + "." if mode_label else ""}
|
||||
{("The user selected the template: " + template_label + ".") if template_label else ""}
|
||||
{fiction_focus_line}
|
||||
{energy_line}
|
||||
|
||||
When relevant, keep the idea aligned with this brand and audience context:
|
||||
- Brand name or site: {brand_name or "Not specified"}
|
||||
- Headline/overall writing tone: {writing_tone or "Not specified"}
|
||||
- Audience description: {audience_description or "Not specified"}
|
||||
|
||||
The user has written the following story idea or concept:
|
||||
|
||||
{story_idea}
|
||||
|
||||
Your task is to propose exactly 3 alternative enhanced story idea options.
|
||||
|
||||
Each option must:
|
||||
- Preserve the user's core premise and intent.
|
||||
- Make the premise clearer and more compelling.
|
||||
- Surface the central conflict or tension.
|
||||
- Clarify the main characters and their goals.
|
||||
- Strengthen the setting and stakes.
|
||||
- Stay at the "idea" level, not a full outline or beat-by-beat breakdown.
|
||||
|
||||
For each option, return three fields:
|
||||
- "idea": 2-4 sentences describing the improved story idea, suitable for a single textarea input.
|
||||
- "whats_missing": 2-4 sentences explaining what important details are missing or underspecified in the current brief. Focus on gaps such as: protagonist details, antagonist or opposing force, stakes, setting and time period, audience/age group, subgenre or type of fiction (for example, anime vs grounded sci-fi), language or tone preferences, and any format constraints.
|
||||
- "why_choose": 1-3 sentences explaining how this option interprets the original idea and why it might be a strong direction for the story.
|
||||
|
||||
Do not write a full story outline.
|
||||
Do not output numbered lists or markdown formatting.
|
||||
|
||||
Return a single JSON object with a "suggestions" array of 3 items, where each item has the keys "idea", "whats_missing", and "why_choose"."""
|
||||
|
||||
schema = self._build_idea_enhance_schema()
|
||||
|
||||
try:
|
||||
logger.info(f"[StoryWriter] Enhancing story idea with structured suggestions for user {user_id}")
|
||||
response = self.load_json_response(
|
||||
llm_text_gen(prompt=enhance_prompt, json_struct=schema, user_id=user_id)
|
||||
)
|
||||
suggestions = response.get("suggestions", [])
|
||||
if len(suggestions) != 3:
|
||||
logger.warning(
|
||||
f"[StoryWriter] Expected 3 idea suggestions but got {len(suggestions)}, correcting count"
|
||||
)
|
||||
if len(suggestions) < 3:
|
||||
raise ValueError(f"Expected 3 suggestions but got {len(suggestions)}")
|
||||
suggestions = suggestions[:3]
|
||||
return suggestions
|
||||
except HTTPException:
|
||||
raise
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.error(f"[StoryWriter] Failed to parse JSON response for story idea enhancement: {exc}")
|
||||
raise RuntimeError(f"Failed to parse story idea enhancement suggestions: {exc}") from exc
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Error enhancing story idea: {exc}")
|
||||
raise RuntimeError(f"Failed to enhance story idea: {exc}") from exc
|
||||
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
import json
|
||||
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
from services.story_writer.image_generation_service import StoryImageGenerationService
|
||||
|
||||
from .base import StoryServiceBase
|
||||
@@ -36,6 +38,7 @@ class StoryContentMixin(StoryOutlineMixin):
|
||||
content_rating: str,
|
||||
ending_preference: str,
|
||||
story_length: str = "Medium",
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
user_id: str,
|
||||
) -> str:
|
||||
"""Generate the starting section (or full short story)."""
|
||||
@@ -52,6 +55,19 @@ class StoryContentMixin(StoryOutlineMixin):
|
||||
ending_preference,
|
||||
)
|
||||
|
||||
anime_bible_context = ""
|
||||
if anime_bible:
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
anime_bible_context = f"""
|
||||
|
||||
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
outline_text = self._format_outline_for_prompt(outline)
|
||||
story_length_lower = story_length.lower()
|
||||
is_short_story = "short" in story_length_lower or "1000" in story_length_lower
|
||||
@@ -61,6 +77,8 @@ class StoryContentMixin(StoryOutlineMixin):
|
||||
short_story_prompt = f"""\
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You have a gripping premise in mind:
|
||||
|
||||
{premise}
|
||||
@@ -154,6 +172,285 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
logger.error(f"Story Start Generation Error: {exc}")
|
||||
raise RuntimeError(f"Failed to generate story start: {exc}") from exc
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Anime scene refinement
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def refine_anime_scene_text(
|
||||
self,
|
||||
*,
|
||||
scene: Dict[str, Any],
|
||||
persona: str,
|
||||
story_setting: str,
|
||||
character_input: str,
|
||||
plot_elements: str,
|
||||
writing_style: str,
|
||||
story_tone: str,
|
||||
narrative_pov: str,
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
anime_bible: Optional[Dict[str, Any]],
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
persona_prompt = self.build_persona_prompt(
|
||||
persona,
|
||||
story_setting,
|
||||
character_input,
|
||||
plot_elements,
|
||||
writing_style,
|
||||
story_tone,
|
||||
narrative_pov,
|
||||
audience_age_group,
|
||||
content_rating,
|
||||
"Neutral",
|
||||
)
|
||||
|
||||
anime_bible_context = ""
|
||||
if anime_bible:
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
anime_bible_context = f"""
|
||||
|
||||
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
current_title = scene.get("title", "")
|
||||
current_description = scene.get("description", "")
|
||||
current_image_prompt = scene.get("image_prompt", "")
|
||||
current_audio_narration = scene.get("audio_narration", "")
|
||||
current_character_descriptions = scene.get("character_descriptions") or []
|
||||
current_key_events = scene.get("key_events") or []
|
||||
|
||||
scene_schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"image_prompt": {"type": "string"},
|
||||
"audio_narration": {"type": "string"},
|
||||
"character_descriptions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"key_events": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["title", "description", "image_prompt", "audio_narration"],
|
||||
}
|
||||
|
||||
prompt = f"""
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You are refining a single anime story scene so that it fully respects the anime story bible for characters, world rules, and visual style.
|
||||
|
||||
Current scene:
|
||||
- Title: {current_title}
|
||||
- Description: {current_description}
|
||||
- Image prompt: {current_image_prompt}
|
||||
- Audio narration: {current_audio_narration}
|
||||
- Character descriptions: {current_character_descriptions}
|
||||
- Key events: {current_key_events}
|
||||
|
||||
Refine the scene so that:
|
||||
- Title is concise and evocative
|
||||
- Description clearly describes what happens in the scene
|
||||
- Image prompt is vivid, visual, and aligned with the anime bible style and cast
|
||||
- Audio narration is natural, spoken-friendly text matching the scene
|
||||
- Character descriptions highlight key visual and personality traits relevant to this moment
|
||||
- Key events list the main beats of the scene
|
||||
|
||||
Respond with JSON matching this schema:
|
||||
{scene_schema}
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt.strip(),
|
||||
json_struct=scene_schema,
|
||||
user_id=user_id,
|
||||
)
|
||||
data = self.load_json_response(raw)
|
||||
except Exception as exc:
|
||||
logger.warning(f"[StoryWriter] Failed to refine anime scene text via LLM: {exc}")
|
||||
return {
|
||||
"scene_number": scene.get("scene_number"),
|
||||
"title": current_title,
|
||||
"description": current_description,
|
||||
"image_prompt": current_image_prompt,
|
||||
"audio_narration": current_audio_narration,
|
||||
"character_descriptions": current_character_descriptions,
|
||||
"key_events": current_key_events,
|
||||
}
|
||||
|
||||
refined = {
|
||||
"scene_number": scene.get("scene_number"),
|
||||
"title": data.get("title", current_title),
|
||||
"description": data.get("description", current_description),
|
||||
"image_prompt": data.get("image_prompt", current_image_prompt),
|
||||
"audio_narration": data.get("audio_narration", current_audio_narration),
|
||||
"character_descriptions": data.get(
|
||||
"character_descriptions", current_character_descriptions
|
||||
),
|
||||
"key_events": data.get("key_events", current_key_events),
|
||||
}
|
||||
return refined
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Anime scene generation from bible
|
||||
# ------------------------------------------------------------------ #
|
||||
|
||||
def generate_anime_scene_from_bible(
|
||||
self,
|
||||
*,
|
||||
premise: str,
|
||||
persona: str,
|
||||
story_setting: str,
|
||||
character_input: str,
|
||||
plot_elements: str,
|
||||
writing_style: str,
|
||||
story_tone: str,
|
||||
narrative_pov: str,
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
anime_bible: Dict[str, Any],
|
||||
previous_scenes: Optional[List[Dict[str, Any]]],
|
||||
target_scene_number: Optional[int],
|
||||
user_id: str,
|
||||
) -> Dict[str, Any]:
|
||||
persona_prompt = self.build_persona_prompt(
|
||||
persona,
|
||||
story_setting,
|
||||
character_input,
|
||||
plot_elements,
|
||||
writing_style,
|
||||
story_tone,
|
||||
narrative_pov,
|
||||
audience_age_group,
|
||||
content_rating,
|
||||
"Neutral",
|
||||
)
|
||||
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
|
||||
anime_bible_context = f"""
|
||||
|
||||
You have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. You MUST treat it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
previous_summary_lines: List[str] = []
|
||||
if previous_scenes:
|
||||
for s in previous_scenes[:6]:
|
||||
num = s.get("scene_number")
|
||||
title = s.get("title") or ""
|
||||
desc = s.get("description") or ""
|
||||
summary = desc
|
||||
if len(summary) > 200:
|
||||
summary = summary[:197] + "..."
|
||||
previous_summary_lines.append(
|
||||
f"- Scene {num}: {title} — {summary}".strip()
|
||||
)
|
||||
|
||||
previous_block = ""
|
||||
if previous_summary_lines:
|
||||
previous_block = (
|
||||
"\nPrevious scenes so far (for continuity, do NOT contradict):\n"
|
||||
+ "\n".join(previous_summary_lines)
|
||||
)
|
||||
|
||||
scene_schema: Dict[str, Any] = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"description": {"type": "string"},
|
||||
"image_prompt": {"type": "string"},
|
||||
"audio_narration": {"type": "string"},
|
||||
"character_descriptions": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
"key_events": {
|
||||
"type": "array",
|
||||
"items": {"type": "string"},
|
||||
},
|
||||
},
|
||||
"required": ["title", "description", "image_prompt", "audio_narration"],
|
||||
}
|
||||
|
||||
prompt = f"""
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You are generating a brand new anime story scene that must fully respect the anime story bible for characters, world rules, and visual style.
|
||||
|
||||
Overall premise:
|
||||
{premise}
|
||||
{previous_block}
|
||||
|
||||
Your task:
|
||||
- Create the NEXT SCENE in this story.
|
||||
- It must be consistent with the anime bible (cast, world rules, visual style).
|
||||
- It must logically follow from any previous scenes given above.
|
||||
|
||||
Design the scene so that:
|
||||
- Title is concise and evocative.
|
||||
- Description clearly describes what happens in the scene.
|
||||
- Image prompt is vivid, visual, and aligned with the anime bible style and cast.
|
||||
- Audio narration is natural, spoken-friendly text matching the scene.
|
||||
- Character descriptions highlight key visual and personality traits relevant to this moment.
|
||||
- Key events list the main beats of the scene.
|
||||
|
||||
Respond with JSON matching this schema:
|
||||
{scene_schema}
|
||||
"""
|
||||
|
||||
try:
|
||||
raw = llm_text_gen(
|
||||
prompt=prompt.strip(),
|
||||
json_struct=scene_schema,
|
||||
user_id=user_id,
|
||||
)
|
||||
data = self.load_json_response(raw)
|
||||
except Exception as exc:
|
||||
logger.error(f"[StoryWriter] Failed to generate anime scene from bible: {exc}")
|
||||
raise RuntimeError(f"Failed to generate anime scene from bible: {exc}") from exc
|
||||
|
||||
next_scene_number = target_scene_number
|
||||
if next_scene_number is None:
|
||||
if previous_scenes and len(previous_scenes) > 0:
|
||||
last = previous_scenes[-1]
|
||||
try:
|
||||
last_num = int(last.get("scene_number") or 0)
|
||||
except Exception:
|
||||
last_num = len(previous_scenes)
|
||||
next_scene_number = last_num + 1
|
||||
else:
|
||||
next_scene_number = 1
|
||||
|
||||
result = {
|
||||
"scene_number": next_scene_number,
|
||||
"title": data.get("title", "").strip(),
|
||||
"description": data.get("description", "").strip(),
|
||||
"image_prompt": data.get("image_prompt", "").strip(),
|
||||
"audio_narration": data.get("audio_narration", "").strip(),
|
||||
"character_descriptions": data.get("character_descriptions") or [],
|
||||
"key_events": data.get("key_events") or [],
|
||||
}
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# Continuation
|
||||
# ------------------------------------------------------------------ #
|
||||
@@ -174,6 +471,7 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
ending_preference: str,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
story_length: str = "Medium",
|
||||
user_id: str,
|
||||
) -> str:
|
||||
@@ -191,6 +489,19 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
ending_preference,
|
||||
)
|
||||
|
||||
anime_bible_context = ""
|
||||
if anime_bible:
|
||||
try:
|
||||
serialized_bible = json.dumps(anime_bible, ensure_ascii=False, indent=2)
|
||||
except Exception:
|
||||
serialized_bible = str(anime_bible)
|
||||
anime_bible_context = f"""
|
||||
|
||||
You also have a structured ANIME STORY BIBLE that defines the main cast, world rules, and visual style. Use it as a hard constraint for character consistency, worldbuilding, and visual storytelling:
|
||||
|
||||
{serialized_bible}
|
||||
"""
|
||||
|
||||
outline_text = self._format_outline_for_prompt(outline)
|
||||
_, continuation_word_count = self._get_story_length_guidance(story_length)
|
||||
current_word_count = len(story_text.split()) if story_text else 0
|
||||
@@ -227,6 +538,8 @@ on establishing the setting, characters, and beginning of the plot in {initial_w
|
||||
continuation_prompt = f"""\
|
||||
{persona_prompt}
|
||||
|
||||
{anime_bible_context}
|
||||
|
||||
You have a gripping premise in mind:
|
||||
|
||||
{premise}
|
||||
@@ -298,6 +611,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
audience_age_group: str,
|
||||
content_rating: str,
|
||||
ending_preference: str,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
user_id: str,
|
||||
max_iterations: int = 10,
|
||||
) -> Dict[str, Any]:
|
||||
@@ -352,6 +666,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
audience_age_group=audience_age_group,
|
||||
content_rating=content_rating,
|
||||
ending_preference=ending_preference,
|
||||
anime_bible=anime_bible,
|
||||
user_id=user_id,
|
||||
)
|
||||
if not draft:
|
||||
@@ -375,6 +690,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
audience_age_group=audience_age_group,
|
||||
content_rating=content_rating,
|
||||
ending_preference=ending_preference,
|
||||
anime_bible=anime_bible,
|
||||
user_id=user_id,
|
||||
)
|
||||
if continuation:
|
||||
@@ -420,6 +736,7 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
height: int = 1024,
|
||||
model: Optional[str] = None,
|
||||
db: Optional[Session] = None,
|
||||
anime_bible: Optional[Dict[str, Any]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate images for story scenes."""
|
||||
image_service = StoryImageGenerationService()
|
||||
@@ -431,5 +748,6 @@ You have written approximately {current_word_count} words so far, leaving approx
|
||||
height=height,
|
||||
model=model,
|
||||
db=db,
|
||||
anime_bible=anime_bible,
|
||||
)
|
||||
|
||||
|
||||
133
backend/services/story_writer/story_project_service.py
Normal file
133
backend/services/story_writer/story_project_service.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
Story Project Service
|
||||
|
||||
Service layer for managing Story Studio project persistence.
|
||||
Modeled after PodcastService for a consistent project API.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from sqlalchemy import and_, desc
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.story_project_models import StoryProject
|
||||
|
||||
|
||||
class StoryProjectService:
|
||||
"""Service for managing Story Studio projects."""
|
||||
|
||||
def __init__(self, db: Session) -> None:
|
||||
self.db = db
|
||||
|
||||
def create_project(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
title: Optional[str] = None,
|
||||
story_mode: Optional[str] = None,
|
||||
story_template: Optional[str] = None,
|
||||
**kwargs: Any,
|
||||
) -> StoryProject:
|
||||
project = StoryProject(
|
||||
project_id=project_id,
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
story_mode=story_mode,
|
||||
story_template=story_template,
|
||||
status="draft",
|
||||
current_phase="setup",
|
||||
**kwargs,
|
||||
)
|
||||
self.db.add(project)
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def get_project(self, user_id: str, project_id: str) -> Optional[StoryProject]:
|
||||
return (
|
||||
self.db.query(StoryProject)
|
||||
.filter(
|
||||
and_(
|
||||
StoryProject.project_id == project_id,
|
||||
StoryProject.user_id == user_id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
def update_project(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
**updates: Any,
|
||||
) -> Optional[StoryProject]:
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return None
|
||||
|
||||
for key, value in updates.items():
|
||||
if hasattr(project, key):
|
||||
setattr(project, key, value)
|
||||
|
||||
project.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def list_projects(
|
||||
self,
|
||||
user_id: str,
|
||||
status: Optional[str] = None,
|
||||
favorites_only: bool = False,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
order_by: str = "updated_at",
|
||||
) -> Tuple[List[StoryProject], int]:
|
||||
query = self.db.query(StoryProject).filter(StoryProject.user_id == user_id)
|
||||
|
||||
if status:
|
||||
query = query.filter(StoryProject.status == status)
|
||||
|
||||
if favorites_only:
|
||||
query = query.filter(StoryProject.is_favorite.is_(True))
|
||||
|
||||
total = query.count()
|
||||
|
||||
if order_by == "created_at":
|
||||
query = query.order_by(desc(StoryProject.created_at))
|
||||
else:
|
||||
query = query.order_by(desc(StoryProject.updated_at))
|
||||
|
||||
projects = query.offset(offset).limit(limit).all()
|
||||
|
||||
return projects, total
|
||||
|
||||
def delete_project(self, user_id: str, project_id: str) -> bool:
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return False
|
||||
|
||||
self.db.delete(project)
|
||||
self.db.commit()
|
||||
return True
|
||||
|
||||
def toggle_favorite(self, user_id: str, project_id: str) -> Optional[StoryProject]:
|
||||
project = self.get_project(user_id, project_id)
|
||||
if not project:
|
||||
return None
|
||||
|
||||
project.is_favorite = not project.is_favorite
|
||||
project.updated_at = datetime.utcnow()
|
||||
self.db.commit()
|
||||
self.db.refresh(project)
|
||||
return project
|
||||
|
||||
def update_status(
|
||||
self,
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
status: str,
|
||||
) -> Optional[StoryProject]:
|
||||
return self.update_project(user_id, project_id, status=status)
|
||||
|
||||
@@ -149,7 +149,7 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
try:
|
||||
path = request.url.path
|
||||
except Exception:
|
||||
pass
|
||||
path = ""
|
||||
|
||||
db = None
|
||||
try:
|
||||
@@ -159,8 +159,16 @@ async def check_usage_limits_middleware(request: Request, user_id: str, request_
|
||||
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
|
||||
# Safe User-Agent access
|
||||
user_agent = None
|
||||
try:
|
||||
if hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
user_agent = request.headers.get('user-agent')
|
||||
except:
|
||||
pass
|
||||
|
||||
# Detect if this is an API call that should be rate limited
|
||||
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
|
||||
api_provider = api_monitor.detect_api_provider(path, user_agent)
|
||||
if not api_provider:
|
||||
return None
|
||||
|
||||
@@ -236,9 +244,28 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
user_id = None
|
||||
try:
|
||||
# PRIORITY 1: Check request.state.user_id (set by API key injection middleware)
|
||||
if hasattr(request.state, 'user_id') and request.state.user_id:
|
||||
user_id = request.state.user_id
|
||||
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
|
||||
if hasattr(request.state, 'user_id'):
|
||||
# Directly check and convert without accessing attribute if None
|
||||
raw_user_id = request.state.user_id
|
||||
|
||||
# Defensive check for Depends object or other complex types
|
||||
if raw_user_id is not None:
|
||||
# If it's a string, use it
|
||||
if isinstance(raw_user_id, str):
|
||||
user_id = raw_user_id
|
||||
# If it has a dependency attribute (likely a Depends object), ignore it
|
||||
elif hasattr(raw_user_id, 'dependency'):
|
||||
logger.warning(f"Monitoring: request.state.user_id is a Depends object, ignoring.")
|
||||
user_id = None
|
||||
# Try to convert to string if it's a simple type
|
||||
else:
|
||||
try:
|
||||
user_id = str(raw_user_id)
|
||||
except:
|
||||
user_id = None
|
||||
|
||||
if user_id:
|
||||
logger.debug(f"Monitoring: Using user_id from request.state: {user_id}")
|
||||
|
||||
# PRIORITY 2: Check query parameters
|
||||
elif hasattr(request, 'query_params') and 'user_id' in request.query_params:
|
||||
@@ -247,20 +274,23 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
user_id = request.path_params['user_id']
|
||||
|
||||
# PRIORITY 3: Check headers for user identification
|
||||
elif 'x-user-id' in request.headers:
|
||||
user_id = request.headers['x-user-id']
|
||||
elif 'x-user-email' in request.headers:
|
||||
user_id = request.headers['x-user-email'] # Use email as user identifier
|
||||
elif 'x-session-id' in request.headers:
|
||||
user_id = request.headers['x-session-id'] # Use session as fallback
|
||||
|
||||
# Check for authorization header with user info
|
||||
elif 'authorization' in request.headers:
|
||||
# Auth middleware should have set request.state.user_id
|
||||
# If not, this indicates an authentication failure (likely expired token)
|
||||
# Log at debug level to reduce noise - expired tokens are expected
|
||||
# But we can try to decode token if we really needed to, but let's rely on auth middleware
|
||||
pass
|
||||
elif hasattr(request, 'headers') and hasattr(request.headers, 'get'):
|
||||
try:
|
||||
if request.headers.get('x-user-id'):
|
||||
user_id = request.headers.get('x-user-id')
|
||||
elif request.headers.get('x-user-email'):
|
||||
user_id = request.headers.get('x-user-email')
|
||||
elif request.headers.get('x-session-id'):
|
||||
user_id = request.headers.get('x-session-id')
|
||||
|
||||
# Check for authorization header with user info
|
||||
elif request.headers.get('authorization'):
|
||||
# Auth middleware should have set request.state.user_id
|
||||
# If not, this indicates an authentication failure (likely expired token)
|
||||
# Log at debug level to reduce noise - expired tokens are expected
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.debug(f"Error accessing request headers: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.debug(f"Error extracting user ID: {e}")
|
||||
@@ -269,7 +299,11 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
# Get database session if user identified
|
||||
db = None
|
||||
if user_id:
|
||||
db = get_session_for_user(user_id)
|
||||
try:
|
||||
db = get_session_for_user(user_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get database session for user {user_id}: {e}")
|
||||
db = None
|
||||
|
||||
# Capture request body for usage tracking (read once, safely)
|
||||
request_body = None
|
||||
@@ -291,29 +325,52 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
request_body = None
|
||||
|
||||
# Check usage limits before processing
|
||||
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
|
||||
if limit_response:
|
||||
if db: db.close()
|
||||
return limit_response
|
||||
# Skip for OPTIONS requests
|
||||
try:
|
||||
if request.method != "OPTIONS":
|
||||
limit_response = await check_usage_limits_middleware(request, user_id, request_body)
|
||||
if limit_response:
|
||||
if db: db.close()
|
||||
return limit_response
|
||||
except Exception as e:
|
||||
logger.error(f"Error in usage limits middleware: {e}")
|
||||
# Continue processing if usage check fails (fail open)
|
||||
|
||||
try:
|
||||
response = await call_next(request)
|
||||
status_code = response.status_code
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Capture response body for usage tracking
|
||||
# Extract response body safely for usage tracking
|
||||
response_body = None
|
||||
try:
|
||||
if hasattr(response, 'body'):
|
||||
response_body = response.body.decode('utf-8') if response.body else None
|
||||
elif hasattr(response, '_content'):
|
||||
response_body = response._content.decode('utf-8') if response._content else None
|
||||
except:
|
||||
pass
|
||||
|
||||
if hasattr(response, 'body'):
|
||||
response_body = response.body.decode('utf-8') if response.body else None
|
||||
elif hasattr(response, '_content'):
|
||||
response_body = response._content.decode('utf-8') if response._content else None
|
||||
|
||||
# Track API usage if this is an API call to external providers
|
||||
api_monitor = DatabaseAPIMonitor()
|
||||
api_provider = api_monitor.detect_api_provider(request.url.path, request.headers.get('user-agent'))
|
||||
|
||||
# Safe URL path access
|
||||
try:
|
||||
path = request.url.path
|
||||
except:
|
||||
path = ""
|
||||
|
||||
# Safe User-Agent access - handle case where headers might be a Depends object
|
||||
user_agent = None
|
||||
try:
|
||||
# Defensive check: ensure request.headers is a valid headers object
|
||||
# Some dependency injection failures replace request attributes with Depends objects
|
||||
if hasattr(request, 'headers'):
|
||||
headers_obj = request.headers
|
||||
# Check if it has a 'get' method (like a dict or Headers object)
|
||||
if hasattr(headers_obj, 'get') and callable(headers_obj.get):
|
||||
user_agent = headers_obj.get('user-agent')
|
||||
except:
|
||||
pass
|
||||
|
||||
api_provider = api_monitor.detect_api_provider(path, user_agent)
|
||||
if api_provider and user_id:
|
||||
logger.info(f"Detected API call: {request.url.path} -> {api_provider.value} for user: {user_id}")
|
||||
try:
|
||||
@@ -326,7 +383,7 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
await usage_service.track_api_usage(
|
||||
user_id=user_id,
|
||||
provider=api_provider,
|
||||
endpoint=request.url.path,
|
||||
endpoint=path,
|
||||
method=request.method,
|
||||
model_used=usage_metrics.get('model_used'),
|
||||
tokens_input=usage_metrics.get('tokens_input', 0),
|
||||
@@ -335,7 +392,7 @@ async def monitoring_middleware(request: Request, call_next):
|
||||
status_code=status_code,
|
||||
request_size=len(request_body) if request_body else None,
|
||||
response_size=len(response_body) if response_body else None,
|
||||
user_agent=request.headers.get('user-agent'),
|
||||
user_agent=user_agent,
|
||||
ip_address=request.client.host if request.client else None,
|
||||
search_count=usage_metrics.get('search_count', 0),
|
||||
image_count=usage_metrics.get('image_count', 0),
|
||||
|
||||
487
backend/services/subscription/stripe_service.py
Normal file
487
backend/services/subscription/stripe_service.py
Normal file
@@ -0,0 +1,487 @@
|
||||
import os
|
||||
import stripe
|
||||
from typing import Optional, Dict, Any
|
||||
from loguru import logger
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
from models.subscription_models import UserSubscription, SubscriptionPlan, SubscriptionTier, BillingCycle, UsageStatus, FraudWarning
|
||||
from services.subscription.pricing_service import PricingService
|
||||
from datetime import datetime
|
||||
|
||||
STRIPE_PLAN_PRICE_MAPPING = {
|
||||
(SubscriptionTier.BASIC.value, BillingCycle.MONTHLY.value): "price_1T2lWHR2EuR7zQJepLIVQ1EJ",
|
||||
(SubscriptionTier.PRO.value, BillingCycle.MONTHLY.value): "price_1T2ljDR2EuR7zQJeuS317KCj",
|
||||
}
|
||||
|
||||
STRIPE_PRICE_TO_PLAN = {
|
||||
price_id: {"tier": SubscriptionTier(tier), "billing_cycle": BillingCycle(billing_cycle)}
|
||||
for (tier, billing_cycle), price_id in STRIPE_PLAN_PRICE_MAPPING.items()
|
||||
}
|
||||
|
||||
class StripeService:
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
self.api_key = os.getenv("STRIPE_SECRET_KEY")
|
||||
self.webhook_secret = os.getenv("STRIPE_WEBHOOK_SECRET")
|
||||
if not self.api_key:
|
||||
logger.warning("STRIPE_SECRET_KEY is not set. Stripe integration will not work.")
|
||||
else:
|
||||
stripe.api_key = self.api_key
|
||||
|
||||
def _get_price_id_for_plan(self, tier: SubscriptionTier, billing_cycle: BillingCycle) -> str:
|
||||
key = (tier.value, billing_cycle.value)
|
||||
price_id = STRIPE_PLAN_PRICE_MAPPING.get(key)
|
||||
if not price_id:
|
||||
logger.error(f"No Stripe price configured for tier={tier.value}, billing_cycle={billing_cycle.value}")
|
||||
raise HTTPException(status_code=400, detail="Payment plan is not configured")
|
||||
return price_id
|
||||
|
||||
def _get_plan_for_price_id(self, price_id: str) -> tuple[SubscriptionPlan, BillingCycle]:
|
||||
mapping = STRIPE_PRICE_TO_PLAN.get(price_id)
|
||||
if not mapping:
|
||||
logger.error(f"Unknown Stripe price_id: {price_id}")
|
||||
raise HTTPException(status_code=400, detail="Unknown payment price configuration")
|
||||
tier = mapping["tier"]
|
||||
billing_cycle = mapping["billing_cycle"]
|
||||
plan = (
|
||||
self.db.query(SubscriptionPlan)
|
||||
.filter(SubscriptionPlan.tier == tier, SubscriptionPlan.is_active == True)
|
||||
.order_by(SubscriptionPlan.price_monthly)
|
||||
.first()
|
||||
)
|
||||
if not plan:
|
||||
logger.error(f"No subscription plan found for tier={tier.value}")
|
||||
raise HTTPException(status_code=400, detail="Subscription plan not found for payment price")
|
||||
return plan, billing_cycle
|
||||
|
||||
def _get_or_create_customer(self, user_id: str, email: Optional[str] = None) -> str:
|
||||
"""
|
||||
Get existing Stripe customer ID for user, or create a new one.
|
||||
"""
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id
|
||||
).first()
|
||||
|
||||
if subscription and subscription.stripe_customer_id:
|
||||
return subscription.stripe_customer_id
|
||||
|
||||
# Search Stripe for existing customer by email (if provided) or metadata
|
||||
try:
|
||||
# If we have an email, search by email first
|
||||
if email:
|
||||
existing_customers = stripe.Customer.list(email=email, limit=1)
|
||||
if existing_customers and len(existing_customers.data) > 0:
|
||||
customer = existing_customers.data[0]
|
||||
# Update DB
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer.id
|
||||
self.db.commit()
|
||||
return customer.id
|
||||
|
||||
# Search by metadata user_id
|
||||
existing_customers = stripe.Customer.search(
|
||||
query=f"metadata['user_id']:'{user_id}'",
|
||||
limit=1
|
||||
)
|
||||
if existing_customers and len(existing_customers.data) > 0:
|
||||
customer = existing_customers.data[0]
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer.id
|
||||
self.db.commit()
|
||||
return customer.id
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error searching Stripe customer: {e}")
|
||||
|
||||
# Create new customer
|
||||
try:
|
||||
customer_data = {
|
||||
"metadata": {"user_id": user_id},
|
||||
}
|
||||
if email:
|
||||
customer_data["email"] = email
|
||||
|
||||
customer = stripe.Customer.create(**customer_data)
|
||||
|
||||
# Update DB
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer.id
|
||||
else:
|
||||
# Create a placeholder subscription record if none exists (usually created on signup/free tier)
|
||||
# But typically we expect a free tier record to exist.
|
||||
pass
|
||||
|
||||
self.db.commit()
|
||||
return customer.id
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating Stripe customer: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to create payment profile")
|
||||
|
||||
def create_checkout_session(
|
||||
self,
|
||||
user_id: str,
|
||||
tier: SubscriptionTier,
|
||||
billing_cycle: BillingCycle,
|
||||
success_url: str,
|
||||
cancel_url: str,
|
||||
user_email: Optional[str] = None,
|
||||
) -> str:
|
||||
"""
|
||||
Create a Stripe Checkout Session for a subscription.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise HTTPException(status_code=500, detail="Payment service not configured")
|
||||
|
||||
price_id = self._get_price_id_for_plan(tier, billing_cycle)
|
||||
customer_id = self._get_or_create_customer(user_id, user_email)
|
||||
|
||||
line_item: Dict[str, Any] = {"price": price_id}
|
||||
try:
|
||||
price = stripe.Price.retrieve(price_id)
|
||||
recurring = getattr(price, "recurring", None)
|
||||
usage_type = None
|
||||
if recurring:
|
||||
if isinstance(recurring, dict):
|
||||
usage_type = recurring.get("usage_type")
|
||||
else:
|
||||
usage_type = getattr(recurring, "usage_type", None)
|
||||
if usage_type != "metered":
|
||||
line_item["quantity"] = 1
|
||||
else:
|
||||
logger.info(f"Detected metered price {price_id}; omitting quantity in Checkout line item")
|
||||
except Exception as e:
|
||||
logger.error(f"Error inspecting Stripe price {price_id}: {e}")
|
||||
line_item["quantity"] = 1
|
||||
|
||||
try:
|
||||
checkout_session = stripe.checkout.Session.create(
|
||||
customer=customer_id,
|
||||
payment_method_types=["card"],
|
||||
line_items=[line_item],
|
||||
mode="subscription",
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
metadata={
|
||||
"user_id": user_id,
|
||||
"price_id": price_id,
|
||||
},
|
||||
subscription_data={
|
||||
"metadata": {
|
||||
"user_id": user_id,
|
||||
}
|
||||
},
|
||||
allow_promotion_codes=True,
|
||||
)
|
||||
return checkout_session.url
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating checkout session: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
def create_portal_session(self, user_id: str, return_url: str) -> str:
|
||||
"""
|
||||
Create a Stripe Customer Portal session for managing billing.
|
||||
"""
|
||||
if not self.api_key:
|
||||
raise HTTPException(status_code=500, detail="Payment service not configured")
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.user_id == user_id
|
||||
).first()
|
||||
|
||||
if not subscription or not subscription.stripe_customer_id:
|
||||
# Try to find customer by user_id if not in DB
|
||||
try:
|
||||
customers = stripe.Customer.search(query=f"metadata['user_id']:'{user_id}'", limit=1)
|
||||
if customers and len(customers.data) > 0:
|
||||
customer_id = customers.data[0].id
|
||||
# Update DB while we're at it
|
||||
if subscription:
|
||||
subscription.stripe_customer_id = customer_id
|
||||
self.db.commit()
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="No billing profile found for this user")
|
||||
except Exception as e:
|
||||
logger.error(f"Error finding customer for portal: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to access billing portal")
|
||||
else:
|
||||
customer_id = subscription.stripe_customer_id
|
||||
|
||||
try:
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=customer_id,
|
||||
return_url=return_url,
|
||||
)
|
||||
return portal_session.url
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating portal session: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
async def handle_webhook(self, payload: bytes, sig_header: str):
|
||||
"""
|
||||
Handle Stripe webhooks.
|
||||
"""
|
||||
if not self.webhook_secret:
|
||||
logger.warning("STRIPE_WEBHOOK_SECRET not set. Ignoring webhook.")
|
||||
return
|
||||
|
||||
try:
|
||||
event = stripe.Webhook.construct_event(
|
||||
payload, sig_header, self.webhook_secret
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Invalid payload: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid payload")
|
||||
except stripe.error.SignatureVerificationError as e:
|
||||
logger.error(f"Invalid signature: {e}")
|
||||
raise HTTPException(status_code=400, detail="Invalid signature")
|
||||
|
||||
event_type = event["type"]
|
||||
data = event["data"]["object"]
|
||||
|
||||
logger.info(f"Received Stripe webhook: {event_type}")
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
await self._handle_checkout_completed(data)
|
||||
elif event_type == "invoice.payment_succeeded":
|
||||
await self._handle_invoice_payment_succeeded(data)
|
||||
elif event_type == "invoice.payment_failed":
|
||||
await self._handle_invoice_payment_failed(data)
|
||||
elif event_type == "customer.subscription.updated":
|
||||
await self._handle_subscription_updated(data)
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
await self._handle_subscription_deleted(data)
|
||||
elif event_type.startswith("radar.early_fraud_warning."):
|
||||
await self._handle_early_fraud_warning(data)
|
||||
|
||||
return {"status": "success"}
|
||||
|
||||
async def _handle_checkout_completed(self, session: Dict[str, Any]):
|
||||
"""
|
||||
Handle successful checkout.
|
||||
"""
|
||||
user_id = session.get("metadata", {}).get("user_id")
|
||||
customer_id = session.get("customer")
|
||||
subscription_id = session.get("subscription")
|
||||
|
||||
if not user_id:
|
||||
logger.error("No user_id in checkout session metadata")
|
||||
return
|
||||
|
||||
logger.info(f"Checkout completed for user {user_id}")
|
||||
|
||||
# Retrieve subscription details to get the plan/price
|
||||
if subscription_id:
|
||||
try:
|
||||
sub = stripe.Subscription.retrieve(subscription_id)
|
||||
price_id = sub['items']['data'][0]['price']['id']
|
||||
# Map price_id to internal plan_id
|
||||
# Note: You need a way to map Stripe Price IDs to your Plan IDs.
|
||||
# For now, we'll assume the metadata or a lookup.
|
||||
# Ideally, store price_id in SubscriptionPlan table or config.
|
||||
|
||||
# Update DB
|
||||
self._update_user_subscription(
|
||||
user_id,
|
||||
stripe_customer_id=customer_id,
|
||||
stripe_subscription_id=subscription_id,
|
||||
status="active",
|
||||
price_id=price_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing checkout subscription: {e}")
|
||||
|
||||
async def _handle_invoice_payment_succeeded(self, invoice: Dict[str, Any]):
|
||||
"""
|
||||
Handle recurring payment success.
|
||||
"""
|
||||
subscription_id = invoice.get("subscription")
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if not subscription_id:
|
||||
return
|
||||
|
||||
# Find user by stripe_subscription_id or customer_id
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
(UserSubscription.stripe_subscription_id == subscription_id) |
|
||||
(UserSubscription.stripe_customer_id == customer_id)
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.info(f"Payment succeeded for user {subscription.user_id}")
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
# Update period end based on invoice lines period
|
||||
if invoice.get('lines'):
|
||||
period_end = invoice['lines']['data'][0]['period']['end']
|
||||
subscription.current_period_end = datetime.fromtimestamp(period_end)
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_invoice_payment_failed(self, invoice: Dict[str, Any]):
|
||||
subscription_id = invoice.get("subscription")
|
||||
customer_id = invoice.get("customer")
|
||||
|
||||
if not subscription_id:
|
||||
return
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
(UserSubscription.stripe_subscription_id == subscription_id) |
|
||||
(UserSubscription.stripe_customer_id == customer_id)
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.warning(f"Payment failed for user {subscription.user_id}")
|
||||
subscription.status = UsageStatus.PAST_DUE
|
||||
subscription.is_active = False
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_subscription_updated(self, subscription_obj: Dict[str, Any]):
|
||||
"""
|
||||
Handle subscription updates (cancellations, changes).
|
||||
"""
|
||||
stripe_sub_id = subscription_obj.get("id")
|
||||
status = subscription_obj.get("status")
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.stripe_subscription_id == stripe_sub_id
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.info(f"Subscription {stripe_sub_id} updated to {status}")
|
||||
if status in ["active", "trialing"]:
|
||||
subscription.status = UsageStatus.ACTIVE
|
||||
subscription.is_active = True
|
||||
elif status in ["past_due", "unpaid", "incomplete", "incomplete_expired"]:
|
||||
subscription.status = UsageStatus.PAST_DUE
|
||||
subscription.is_active = False
|
||||
elif status in ["canceled"]:
|
||||
subscription.status = UsageStatus.CANCELLED
|
||||
subscription.is_active = False
|
||||
subscription.auto_renew = False
|
||||
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_subscription_deleted(self, subscription_obj: Dict[str, Any]):
|
||||
"""
|
||||
Handle subscription cancellation (immediate).
|
||||
"""
|
||||
stripe_sub_id = subscription_obj.get("id")
|
||||
|
||||
subscription = self.db.query(UserSubscription).filter(
|
||||
UserSubscription.stripe_subscription_id == stripe_sub_id
|
||||
).first()
|
||||
|
||||
if subscription:
|
||||
logger.info(f"Subscription {stripe_sub_id} deleted")
|
||||
subscription.status = UsageStatus.CANCELLED # Need to check if this enum value exists
|
||||
subscription.is_active = False
|
||||
subscription.auto_renew = False
|
||||
self.db.commit()
|
||||
|
||||
async def _handle_early_fraud_warning(self, warning_obj: Dict[str, Any]):
|
||||
efw_id = warning_obj.get("id")
|
||||
if not efw_id:
|
||||
return
|
||||
|
||||
charge_id = warning_obj.get("charge")
|
||||
payment_intent_id = warning_obj.get("payment_intent")
|
||||
created_ts = warning_obj.get("created")
|
||||
created_at = datetime.utcfromtimestamp(created_ts) if created_ts else datetime.utcnow()
|
||||
|
||||
amount = 0
|
||||
currency = ""
|
||||
user_id = None
|
||||
charge_data: Dict[str, Any] = {}
|
||||
|
||||
if charge_id and self.api_key:
|
||||
try:
|
||||
charge = stripe.Charge.retrieve(charge_id)
|
||||
charge_data = charge.to_dict() if hasattr(charge, "to_dict") else dict(charge)
|
||||
amount = charge_data.get("amount") or 0
|
||||
currency = charge_data.get("currency") or ""
|
||||
metadata = charge_data.get("metadata") or {}
|
||||
user_id = metadata.get("user_id")
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving charge for early fraud warning {efw_id}: {e}")
|
||||
|
||||
if not amount:
|
||||
amount = warning_obj.get("amount") or 0
|
||||
if not currency:
|
||||
currency = warning_obj.get("currency") or ""
|
||||
|
||||
existing = self.db.query(FraudWarning).filter(FraudWarning.id == efw_id).first()
|
||||
|
||||
metadata_payload: Dict[str, Any] = {
|
||||
"early_fraud_warning": warning_obj,
|
||||
}
|
||||
if charge_data:
|
||||
metadata_payload["charge"] = charge_data
|
||||
|
||||
if existing:
|
||||
existing.charge_id = charge_id or existing.charge_id
|
||||
existing.payment_intent_id = payment_intent_id or existing.payment_intent_id
|
||||
if user_id:
|
||||
existing.user_id = user_id
|
||||
if amount:
|
||||
existing.amount = amount
|
||||
if currency:
|
||||
existing.currency = currency
|
||||
existing.status = "open"
|
||||
existing.meta_info = metadata_payload
|
||||
else:
|
||||
if not charge_id:
|
||||
return
|
||||
warning = FraudWarning(
|
||||
id=efw_id,
|
||||
charge_id=charge_id,
|
||||
payment_intent_id=payment_intent_id,
|
||||
user_id=user_id,
|
||||
amount=amount or 0,
|
||||
currency=currency or "",
|
||||
status="open",
|
||||
action="none",
|
||||
meta_info=metadata_payload,
|
||||
created_at=created_at,
|
||||
)
|
||||
self.db.add(warning)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def _update_user_subscription(
|
||||
self,
|
||||
user_id: str,
|
||||
stripe_customer_id: str,
|
||||
stripe_subscription_id: str,
|
||||
status: str,
|
||||
price_id: str,
|
||||
):
|
||||
plan, billing_cycle = self._get_plan_for_price_id(price_id)
|
||||
|
||||
subscription = (
|
||||
self.db.query(UserSubscription)
|
||||
.filter(UserSubscription.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
now = datetime.utcnow()
|
||||
|
||||
if not subscription:
|
||||
subscription = UserSubscription(
|
||||
user_id=user_id,
|
||||
plan_id=plan.id,
|
||||
billing_cycle=billing_cycle,
|
||||
current_period_start=now,
|
||||
current_period_end=now,
|
||||
status=UsageStatus.ACTIVE if status == "active" else UsageStatus.SUSPENDED,
|
||||
is_active=status == "active",
|
||||
auto_renew=True,
|
||||
)
|
||||
self.db.add(subscription)
|
||||
else:
|
||||
subscription.plan_id = plan.id
|
||||
subscription.billing_cycle = billing_cycle
|
||||
subscription.is_active = status == "active"
|
||||
|
||||
subscription.stripe_customer_id = stripe_customer_id
|
||||
subscription.stripe_subscription_id = stripe_subscription_id
|
||||
|
||||
self.db.commit()
|
||||
@@ -39,9 +39,34 @@ def _generate_simple_infinitetalk_prompt(
|
||||
# Build a balanced prompt: scene description + simple motion hint
|
||||
parts = []
|
||||
|
||||
# Start with the main subject/scene
|
||||
# Add scene context
|
||||
if title and len(title) > 5 and title.lower() not in ("scene", "podcast", "episode"):
|
||||
parts.append(title)
|
||||
|
||||
# Add analysis context
|
||||
analysis = story_context.get("analysis", {})
|
||||
if analysis:
|
||||
content_type = analysis.get("content_type")
|
||||
if content_type:
|
||||
parts.append(f"Style: {content_type}")
|
||||
|
||||
# Audience helps define the formality/vibe
|
||||
audience = analysis.get("audience")
|
||||
if audience:
|
||||
# Just use first few words of audience to keep it short
|
||||
short_audience = " ".join(audience.split()[:3])
|
||||
parts.append(f"For: {short_audience}")
|
||||
|
||||
# Add bible context if available
|
||||
bible = story_context.get("bible", {})
|
||||
if bible:
|
||||
host_persona = bible.get("host_persona")
|
||||
tone = bible.get("tone")
|
||||
if host_persona:
|
||||
parts.append(f"Host: {host_persona}")
|
||||
if tone:
|
||||
parts.append(f"Tone: {tone}")
|
||||
|
||||
elif description:
|
||||
# Take first sentence or first 60 chars
|
||||
desc_part = description.split('.')[0][:60].strip()
|
||||
|
||||
@@ -52,6 +52,46 @@ def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str,
|
||||
image_prompt = (scene_data.get("image_prompt") or "").strip()
|
||||
tone = (story_context.get("story_tone") or "story").strip()
|
||||
setting = (story_context.get("story_setting") or "the scene").strip()
|
||||
anime_bible = story_context.get("anime_bible") or {}
|
||||
|
||||
anime_style_parts = []
|
||||
if isinstance(anime_bible, dict):
|
||||
visual_style = anime_bible.get("visual_style") or {}
|
||||
world = anime_bible.get("world") or {}
|
||||
main_cast = anime_bible.get("main_cast") or []
|
||||
|
||||
style_preset = visual_style.get("style_preset")
|
||||
camera_style = visual_style.get("camera_style")
|
||||
color_mood = visual_style.get("color_mood")
|
||||
lighting = visual_style.get("lighting")
|
||||
line_style = visual_style.get("line_style")
|
||||
extra_tags = visual_style.get("extra_tags") or []
|
||||
|
||||
if style_preset:
|
||||
anime_style_parts.append(f"Follow {style_preset} anime visual style.")
|
||||
if camera_style:
|
||||
anime_style_parts.append(f"Use camera style: {camera_style}.")
|
||||
if color_mood:
|
||||
anime_style_parts.append(f"Color mood: {color_mood}.")
|
||||
if lighting:
|
||||
anime_style_parts.append(f"Lighting: {lighting}.")
|
||||
if line_style:
|
||||
anime_style_parts.append(f"Line art: {line_style}.")
|
||||
if extra_tags:
|
||||
anime_style_parts.append("Style tags: " + ", ".join(str(tag) for tag in extra_tags[:6]))
|
||||
|
||||
if world:
|
||||
setting_desc = world.get("setting")
|
||||
if setting_desc:
|
||||
anime_style_parts.append(f"World context: {setting_desc}.")
|
||||
|
||||
if main_cast:
|
||||
names = [c.get("name") for c in main_cast if isinstance(c, dict) and c.get("name")]
|
||||
if names:
|
||||
joined = ", ".join(names[:4])
|
||||
anime_style_parts.append(f"Keep character designs consistent for: {joined}.")
|
||||
|
||||
anime_style_text = " ".join(anime_style_parts).strip()
|
||||
|
||||
parts = [
|
||||
f"{title} cinematic motion shot.",
|
||||
@@ -60,6 +100,7 @@ def _build_fallback_prompt(scene_data: Dict[str, Any], story_context: Dict[str,
|
||||
f"Maintain a {tone} mood with natural lighting accents.",
|
||||
f"Honor the original illustration details: {image_prompt[:200]}." if image_prompt else "",
|
||||
"5-second sequence, gentle push-in, flowing cloth and atmospheric particles.",
|
||||
anime_style_text,
|
||||
]
|
||||
fallback_prompt = " ".join(filter(None, parts))
|
||||
return fallback_prompt.strip()
|
||||
@@ -142,6 +183,66 @@ def generate_animation_prompt(
|
||||
title = scene_data.get("title", "")
|
||||
tone = story_context.get("story_tone") or story_context.get("story_tone", "")
|
||||
setting = story_context.get("story_setting") or story_context.get("story_setting", "")
|
||||
anime_bible = story_context.get("anime_bible") or {}
|
||||
|
||||
anime_bible_block = ""
|
||||
if isinstance(anime_bible, dict) and anime_bible:
|
||||
try:
|
||||
visual_style = anime_bible.get("visual_style") or {}
|
||||
world = anime_bible.get("world") or {}
|
||||
main_cast = anime_bible.get("main_cast") or []
|
||||
|
||||
style_lines = []
|
||||
if visual_style:
|
||||
style_preset = visual_style.get("style_preset")
|
||||
camera_style = visual_style.get("camera_style")
|
||||
color_mood = visual_style.get("color_mood")
|
||||
lighting = visual_style.get("lighting")
|
||||
line_style = visual_style.get("line_style")
|
||||
extra_tags = visual_style.get("extra_tags") or []
|
||||
|
||||
if style_preset:
|
||||
style_lines.append(f"- Visual style preset: {style_preset}")
|
||||
if camera_style:
|
||||
style_lines.append(f"- Preferred camera style: {camera_style}")
|
||||
if color_mood:
|
||||
style_lines.append(f"- Color mood: {color_mood}")
|
||||
if lighting:
|
||||
style_lines.append(f"- Lighting: {lighting}")
|
||||
if line_style:
|
||||
style_lines.append(f"- Line art style: {line_style}")
|
||||
if extra_tags:
|
||||
style_lines.append(
|
||||
"- Extra style tags: " + ", ".join(str(tag) for tag in extra_tags[:6])
|
||||
)
|
||||
|
||||
cast_line = ""
|
||||
if main_cast:
|
||||
names = [c.get("name") for c in main_cast if isinstance(c, dict) and c.get("name")]
|
||||
if names:
|
||||
cast_line = "- Main cast to keep visually consistent: " + ", ".join(names[:4])
|
||||
|
||||
world_line = ""
|
||||
if world:
|
||||
setting_desc = world.get("setting")
|
||||
if setting_desc:
|
||||
world_line = "- World/setting context: " + str(setting_desc)
|
||||
|
||||
detail_lines = []
|
||||
if cast_line:
|
||||
detail_lines.append(cast_line)
|
||||
if world_line:
|
||||
detail_lines.append(world_line)
|
||||
detail_lines.extend(style_lines)
|
||||
|
||||
if detail_lines:
|
||||
anime_bible_block = (
|
||||
"\nANIME STORY BIBLE VISUAL GUIDANCE:\n"
|
||||
+ "\n".join(detail_lines)
|
||||
+ "\nAlways respect these constraints in the motion description."
|
||||
)
|
||||
except Exception:
|
||||
anime_bible_block = ""
|
||||
|
||||
prompt = f"""
|
||||
Create a concise animation prompt (2-3 sentences) for a 5-second cinematic clip.
|
||||
@@ -151,6 +252,7 @@ Description: {description}
|
||||
Existing Image Prompt: {image_prompt}
|
||||
Story Tone: {tone}
|
||||
Setting: {setting}
|
||||
{anime_bible_block}
|
||||
|
||||
Focus on:
|
||||
- Motion of characters/objects
|
||||
|
||||
@@ -132,7 +132,19 @@ class YouTubeSceneBuilderService:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate scenes from video plan using AI."""
|
||||
|
||||
content_outline = video_plan.get("content_outline", [])
|
||||
raw_content_outline = video_plan.get("content_outline", [])
|
||||
content_outline: List[Dict[str, Any]] = []
|
||||
for item in raw_content_outline:
|
||||
if isinstance(item, dict):
|
||||
content_outline.append(item)
|
||||
else:
|
||||
content_outline.append(
|
||||
{
|
||||
"section": str(item),
|
||||
"description": "",
|
||||
"duration_estimate": 0,
|
||||
}
|
||||
)
|
||||
hook_strategy = video_plan.get("hook_strategy", "")
|
||||
call_to_action = video_plan.get("call_to_action", "")
|
||||
visual_style = video_plan.get("visual_style", "cinematic")
|
||||
@@ -263,16 +275,32 @@ Write narration that:
|
||||
# Normalize scene data
|
||||
normalized_scenes = []
|
||||
for idx, scene in enumerate(scenes, 1):
|
||||
normalized_scenes.append({
|
||||
"scene_number": scene.get("scene_number", idx),
|
||||
"title": scene.get("title", f"Scene {idx}"),
|
||||
"narration": scene.get("narration", ""),
|
||||
"visual_description": scene.get("visual_description", ""),
|
||||
"duration_estimate": scene.get("duration_estimate", scene_duration_range[0]),
|
||||
"emphasis": scene.get("emphasis", "main_content"),
|
||||
"visual_cues": scene.get("visual_cues", []),
|
||||
"visual_prompt": scene.get("visual_description", ""), # Initial prompt
|
||||
})
|
||||
if isinstance(scene, dict):
|
||||
scene_data = scene
|
||||
else:
|
||||
scene_data = {
|
||||
"scene_number": idx,
|
||||
"title": f"Scene {idx}",
|
||||
"narration": str(scene),
|
||||
"visual_description": "",
|
||||
"duration_estimate": scene_duration_range[0],
|
||||
"emphasis": "main_content",
|
||||
"visual_cues": [],
|
||||
}
|
||||
normalized_scenes.append(
|
||||
{
|
||||
"scene_number": scene_data.get("scene_number", idx),
|
||||
"title": scene_data.get("title", f"Scene {idx}"),
|
||||
"narration": scene_data.get("narration", ""),
|
||||
"visual_description": scene_data.get("visual_description", ""),
|
||||
"duration_estimate": scene_data.get(
|
||||
"duration_estimate", scene_duration_range[0]
|
||||
),
|
||||
"emphasis": scene_data.get("emphasis", "main_content"),
|
||||
"visual_cues": scene_data.get("visual_cues", []),
|
||||
"visual_prompt": scene_data.get("visual_description", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return normalized_scenes
|
||||
|
||||
@@ -287,16 +315,32 @@ Write narration that:
|
||||
|
||||
normalized_scenes = []
|
||||
for idx, scene in enumerate(scenes, 1):
|
||||
normalized_scenes.append({
|
||||
"scene_number": scene.get("scene_number", idx),
|
||||
"title": scene.get("title", f"Scene {idx}"),
|
||||
"narration": scene.get("narration", ""),
|
||||
"visual_description": scene.get("visual_description", ""),
|
||||
"duration_estimate": scene.get("duration_estimate", scene_duration_range[0]),
|
||||
"emphasis": scene.get("emphasis", "main_content"),
|
||||
"visual_cues": scene.get("visual_cues", []),
|
||||
"visual_prompt": scene.get("visual_description", ""), # Initial prompt
|
||||
})
|
||||
if isinstance(scene, dict):
|
||||
scene_data = scene
|
||||
else:
|
||||
scene_data = {
|
||||
"scene_number": idx,
|
||||
"title": f"Scene {idx}",
|
||||
"narration": str(scene),
|
||||
"visual_description": "",
|
||||
"duration_estimate": scene_duration_range[0],
|
||||
"emphasis": "main_content",
|
||||
"visual_cues": [],
|
||||
}
|
||||
normalized_scenes.append(
|
||||
{
|
||||
"scene_number": scene_data.get("scene_number", idx),
|
||||
"title": scene_data.get("title", f"Scene {idx}"),
|
||||
"narration": scene_data.get("narration", ""),
|
||||
"visual_description": scene_data.get("visual_description", ""),
|
||||
"duration_estimate": scene_data.get(
|
||||
"duration_estimate", scene_duration_range[0]
|
||||
),
|
||||
"emphasis": scene_data.get("emphasis", "main_content"),
|
||||
"visual_cues": scene_data.get("visual_cues", []),
|
||||
"visual_prompt": scene_data.get("visual_description", ""),
|
||||
}
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[YouTubeSceneBuilder] ✅ Normalized {len(normalized_scenes)} scenes "
|
||||
|
||||
Reference in New Issue
Block a user