Save local changes (GSC/Bing integrations) before merging PR #354

This commit is contained in:
ajaysi
2026-02-13 13:11:27 +05:30
parent 43e66835ac
commit 08a1f4a1d8
144 changed files with 8310 additions and 2748 deletions

View File

@@ -20,12 +20,13 @@ class BaseAnalyticsHandler(ABC):
self.platform_name = platform_type.value
@abstractmethod
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, **kwargs) -> AnalyticsData:
"""
Get analytics data for the platform
Args:
user_id: User ID to get analytics for
**kwargs: Additional arguments for specific handlers
Returns:
AnalyticsData object with platform metrics

View File

@@ -42,7 +42,7 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
db_url = f'sqlite:///{db_path}'
return BingInsightsService(db_url)
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
"""
Get Bing Webmaster analytics data using Bing Webmaster API
"""
@@ -83,9 +83,32 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
if not access_token:
return self.create_error_response('Bing Webmaster access token not available')
# Select site: Prefer target_url match, otherwise first site
selected_site = sites[0] if sites else None
if not selected_site:
return self.create_error_response('No Bing sites found')
if target_url and sites:
logger.info(f"Attempting to match target URL: {target_url}")
# Normalize target URL (remove protocol, trailing slash)
normalized_target = target_url.replace('https://', '').replace('http://', '').rstrip('/')
for site in sites:
# Bing uses 'Url' key
site_url = site.get('Url', '')
normalized_site = site_url.replace('https://', '').replace('http://', '').rstrip('/')
if normalized_target in normalized_site or normalized_site in normalized_target:
selected_site = site
logger.info(f"Found matching Bing site: {site_url}")
break
site_url_for_storage = selected_site.get('Url', '') if selected_site else ''
logger.info(f"Using Bing site URL: {site_url_for_storage}")
query_stats = {}
try:
site_url_for_storage = sites[0].get('Url', '') if (sites and isinstance(sites[0], dict)) else None
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=30)
if stored and isinstance(stored, dict):
query_stats = {
@@ -99,7 +122,7 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
logger.warning(f"Bing analytics: Failed to read stored analytics summary: {e}")
# Get enhanced insights
insights = self._get_enhanced_insights_with_service(insights_service, user_id, sites[0].get('Url', '') if sites else '')
insights = self._get_enhanced_insights_with_service(insights_service, user_id, site_url_for_storage)
metrics = {
'connection_status': 'connected',

View File

@@ -22,16 +22,22 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
super().__init__(PlatformType.GSC)
self.gsc_service = GSCService()
async def get_analytics(self, user_id: str) -> AnalyticsData:
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
"""
Get Google Search Console analytics data with caching
Args:
user_id: User ID to get analytics for
target_url: Optional URL to prefer when selecting GSC site
Returns comprehensive SEO metrics including clicks, impressions, CTR, and position data.
"""
self.log_analytics_request(user_id, "get_analytics")
# Check cache first - GSC API calls can be expensive
cached_data = analytics_cache.get('gsc_analytics', user_id)
# Include target_url in cache key if provided
cache_key = f"{user_id}_{target_url}" if target_url else user_id
cached_data = analytics_cache.get('gsc_analytics', cache_key)
if cached_data:
logger.info("Using cached GSC analytics for user {user_id}", user_id=user_id)
return AnalyticsData(**cached_data)
@@ -45,8 +51,23 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
logger.warning(f"No GSC sites found for user {user_id}")
return self.create_error_response('No GSC sites found')
# Get analytics for the first site (or combine all sites)
site_url = sites[0]['siteUrl']
# Select site: Prefer target_url match, otherwise first site
selected_site = sites[0]
if target_url:
logger.info(f"Attempting to match target URL: {target_url}")
# Normalize target URL (remove protocol, trailing slash)
normalized_target = target_url.replace('https://', '').replace('http://', '').rstrip('/')
for site in sites:
site_url = site['siteUrl']
normalized_site = site_url.replace('https://', '').replace('http://', '').rstrip('/')
if normalized_target in normalized_site or normalized_site in normalized_target:
selected_site = site
logger.info(f"Found matching GSC site: {site_url}")
break
site_url = selected_site['siteUrl']
logger.info(f"Using GSC site URL: {site_url}")
# Get search analytics for last 30 days
@@ -71,7 +92,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
)
# Cache the result to avoid expensive API calls
analytics_cache.set('gsc_analytics', user_id, result.__dict__)
analytics_cache.set('gsc_analytics', cache_key, result.__dict__)
logger.info("Cached GSC analytics data for user {user_id}", user_id=user_id)
return result
@@ -81,7 +102,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
error_result = self.create_error_response(str(e))
# Cache error result for shorter time to retry sooner
analytics_cache.set('gsc_analytics', user_id, error_result.__dict__, ttl_override=300) # 5 minutes
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=300) # 5 minutes
return error_result
def get_connection_status(self, user_id: str) -> Dict[str, Any]:
@@ -117,111 +138,93 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
# New structure from updated GSC service
overall_rows = search_analytics.get('overall_metrics', {}).get('rows', [])
query_rows = search_analytics.get('query_data', {}).get('rows', [])
verification_rows = search_analytics.get('verification_data', {}).get('rows', [])
logger.info(f"GSC Overall metrics rows: {len(overall_rows)}")
logger.info(f"GSC Query data rows: {len(query_rows)}")
logger.info(f"GSC Verification rows: {len(verification_rows)}")
# Calculate totals from overall_rows (most accurate as it includes anonymized queries)
total_clicks = 0
total_impressions = 0
total_position = 0
valid_position_rows = 0
if overall_rows:
logger.info(f"GSC Overall first row: {overall_rows[0]}")
if query_rows:
logger.info(f"GSC Query first row: {query_rows[0]}")
# Use overall_rows for totals if available, otherwise fallback to query_rows
calc_rows = overall_rows if overall_rows else query_rows
for row in calc_rows:
clicks = row.get('clicks', 0)
impressions = row.get('impressions', 0)
position = row.get('position', 0)
total_clicks += clicks
total_impressions += impressions
if position and position > 0:
total_position += position * impressions # Weighted average
# Calculate weighted average position
avg_position = total_position / total_impressions if total_impressions > 0 else 0
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
# Use query_rows for top queries list
top_queries_source = query_rows
# Use query_rows for detailed insights, overall_rows for summary
rows = query_rows if query_rows else overall_rows
else:
# Legacy structure
rows = search_analytics.get('rows', [])
logger.info(f"GSC Legacy rows count: {len(rows)}")
if rows:
logger.info(f"GSC Legacy first row structure: {rows[0]}")
logger.info(f"GSC Legacy first row keys: {list(rows[0].keys()) if rows[0] else 'No rows'}")
# Calculate summary metrics - handle different response formats
total_clicks = 0
total_impressions = 0
total_position = 0
valid_rows = 0
for row in rows:
# Handle different possible response formats
clicks = row.get('clicks', 0)
impressions = row.get('impressions', 0)
position = row.get('position', 0)
# ... existing legacy logic ...
calc_rows = rows
top_queries_source = rows
# If position is 0 or None, skip it from average calculation
if position and position > 0:
total_position += position
valid_rows += 1
total_clicks = 0
total_impressions = 0
total_position = 0
valid_position_rows = 0
total_clicks += clicks
total_impressions += impressions
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
avg_position = total_position / valid_rows if valid_rows > 0 else 0
logger.info(f"GSC Calculated metrics - clicks: {total_clicks}, impressions: {total_impressions}, ctr: {avg_ctr}, position: {avg_position}, valid_rows: {valid_rows}")
# Get top performing queries - handle different data structures
if rows and 'keys' in rows[0]:
# New GSC API format with keys array
top_queries = sorted(rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
# Get top performing pages (if we have page data)
page_data = {}
for row in rows:
# Handle different key structures
keys = row.get('keys', [])
if len(keys) > 1 and keys[1]: # Page data available
page = keys[1].get('keys', ['Unknown'])[0] if isinstance(keys[1], dict) else str(keys[1])
else:
page = 'Unknown'
for row in calc_rows:
clicks = row.get('clicks', 0)
impressions = row.get('impressions', 0)
position = row.get('position', 0)
if page not in page_data:
page_data[page] = {'clicks': 0, 'impressions': 0, 'ctr': 0, 'position': 0}
page_data[page]['clicks'] += row.get('clicks', 0)
page_data[page]['impressions'] += row.get('impressions', 0)
else:
# Legacy format or no keys structure
top_queries = sorted(rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
page_data = {}
total_clicks += clicks
total_impressions += impressions
if position and position > 0:
# Simple average for legacy/unknown structure if we can't do weighted
total_position += position
valid_position_rows += 1
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
avg_position = total_position / valid_position_rows if valid_position_rows > 0 else 0
# Calculate page metrics
for page in page_data:
if page_data[page]['impressions'] > 0:
page_data[page]['ctr'] = page_data[page]['clicks'] / page_data[page]['impressions'] * 100
top_pages = sorted(page_data.items(), key=lambda x: x[1]['clicks'], reverse=True)[:10]
return {
'connection_status': 'connected',
'connected_sites': 1, # GSC typically has one site per user
'total_clicks': total_clicks,
'total_impressions': total_impressions,
'avg_ctr': round(avg_ctr, 2),
'avg_position': round(avg_position, 2),
'total_queries': len(rows),
'top_queries': [
{
# Get top performing queries
top_queries = []
if top_queries_source:
# Sort by clicks
sorted_queries = sorted(top_queries_source, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
for row in sorted_queries:
top_queries.append({
'query': self._extract_query_from_row(row),
'clicks': row.get('clicks', 0),
'impressions': row.get('impressions', 0),
'ctr': round(row.get('ctr', 0) * 100, 2),
'position': round(row.get('position', 0), 2)
}
for row in top_queries
],
'top_pages': [
{
'page': page,
'clicks': data['clicks'],
'impressions': data['impressions'],
'ctr': round(data['ctr'], 2)
}
for page, data in top_pages
],
'note': 'Google Search Console provides search performance data, keyword rankings, and SEO insights'
})
# Prepare Top Pages (requires page dimension, but we only requested query dimension in gsc_service step 3)
# To get top pages, we would need another API call with dimension=['page']
# For now, we'll return empty top_pages or infer from what we have if possible (we can't from query data)
top_pages = []
return {
'connection_status': 'connected',
'connected_sites': 1,
'total_clicks': total_clicks,
'total_impressions': total_impressions,
'avg_ctr': round(avg_ctr, 2),
'avg_position': round(avg_position, 2),
'total_queries': len(top_queries_source) if top_queries_source else 0,
'top_queries': top_queries,
'top_pages': top_pages
}
except Exception as e:

View File

@@ -59,6 +59,32 @@ class PlatformAnalyticsService:
logger.info(f"Getting comprehensive analytics for user {user_id}, platforms: {platforms}")
analytics_data = {}
# Determine target URL from Wix/WP for GSC site selection
target_url = None
try:
status = await self.get_platform_connection_status(user_id)
# Check Wix
if status.get('wix', {}).get('connected'):
sites = status['wix'].get('sites', [])
if sites:
# Assuming site object has 'blog_url' or 'url'
site = sites[0]
target_url = site.get('blog_url') or site.get('url')
# Check WordPress if no Wix
if not target_url and status.get('wordpress', {}).get('connected'):
sites = status['wordpress'].get('sites', [])
if sites:
site = sites[0]
target_url = site.get('blog_url') or site.get('url')
if target_url:
logger.info(f"Identified primary site URL for GSC matching: {target_url}")
except Exception as e:
logger.warning(f"Failed to determine target URL for GSC: {e}")
for platform_name in platforms:
try:
# Convert string to PlatformType enum
@@ -66,7 +92,10 @@ class PlatformAnalyticsService:
handler = self.handlers.get(platform_type)
if handler:
analytics_data[platform_name] = await handler.get_analytics(user_id)
if platform_type == PlatformType.GSC or platform_type == PlatformType.BING:
analytics_data[platform_name] = await handler.get_analytics(user_id, target_url=target_url)
else:
analytics_data[platform_name] = await handler.get_analytics(user_id)
else:
logger.warning(f"Unknown platform: {platform_name}")
analytics_data[platform_name] = self._create_error_response(platform_name, f"Unknown platform: {platform_name}")

View File

@@ -30,6 +30,8 @@ from models.product_asset_models import ProductAsset, ProductStyleTemplate, Ecom
from models.podcast_models import PodcastProject
# Research models use SubscriptionBase
from models.research_models import ResearchProject
# Video Studio models
from models.video_models import VideoGenerationTask
# Bing Analytics models
from models.bing_analytics_models import Base as BingAnalyticsBase
@@ -54,7 +56,22 @@ def get_user_db_path(user_id: str) -> str:
# Sanitize user_id to be safe for filesystem
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
user_workspace = os.path.join(WORKSPACE_DIR, f"workspace_{safe_user_id}")
return os.path.join(user_workspace, 'db', f'alwrity_{safe_user_id}.db')
# Check for legacy naming convention first (to support existing data)
# Some older workspaces might have 'alwrity.db' instead of 'alwrity_{user_id}.db'
legacy_db_path = os.path.join(user_workspace, 'db', 'alwrity.db')
specific_db_path = os.path.join(user_workspace, 'db', f'alwrity_{safe_user_id}.db')
# If the specific one exists, use it (preferred)
if os.path.exists(specific_db_path):
return specific_db_path
# If legacy exists and specific doesn't, use legacy
if os.path.exists(legacy_db_path):
return legacy_db_path
# Default to specific for new databases
return specific_db_path
def get_all_user_ids() -> List[str]:
"""

View File

@@ -14,6 +14,8 @@ from loguru import logger
from services.database import get_user_db_path
from dotenv import load_dotenv
class GSCService:
"""Service for Google Search Console integration."""
@@ -31,10 +33,62 @@ class GSCService:
services_dir = os.path.dirname(__file__)
backend_dir = os.path.abspath(os.path.join(services_dir, os.pardir))
self.credentials_file = os.path.join(backend_dir, "gsc_credentials.json")
logger.info(f"GSC credentials file path set to: {self.credentials_file}")
# Load client config from file or environment variables
self.client_config = self._load_client_config()
if self.client_config:
logger.info("GSC client configuration loaded successfully")
else:
logger.warning(f"GSC credentials not found in {self.credentials_file} or environment variables")
self.scopes = ['https://www.googleapis.com/auth/webmasters.readonly']
# Note: Tables are initialized lazily per user
logger.info("GSC Service initialized successfully")
def _load_client_config(self) -> Optional[Dict[str, Any]]:
    """
    Resolve the Google OAuth client configuration.

    Environment variables take priority: when both GOOGLE_CLIENT_ID and
    GOOGLE_CLIENT_SECRET are set, a "web" client config dictionary is
    assembled from them. Otherwise the JSON file at ``self.credentials_file``
    is parsed, if it exists.

    Returns:
        The client configuration dictionary, or None when neither source
        yields a usable configuration.
    """
    # Re-read .env on every call so credentials added at runtime are seen.
    load_dotenv(override=True)

    env_client_id = os.getenv("GOOGLE_CLIENT_ID")
    env_client_secret = os.getenv("GOOGLE_CLIENT_SECRET")

    if not (env_client_id and env_client_secret):
        # Source 2: credentials file (only consulted when env vars are incomplete).
        if not os.path.exists(self.credentials_file):
            return None
        try:
            with open(self.credentials_file, 'r') as handle:
                file_config = json.load(handle)
            logger.info(f"Loading GSC credentials from file: {self.credentials_file}")
            return file_config
        except Exception as exc:
            # Unreadable or malformed file is treated as "no configuration".
            logger.warning(f"Failed to load GSC credentials from file: {exc}")
            return None

    # Source 1: environment variables.
    callback_uri = os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
    logger.info("Loading GSC credentials from environment variables")

    # Shape the settings the way the OAuth flow's client-config loader is given them.
    web_settings = {
        "client_id": env_client_id,
        "client_secret": env_client_secret,
        "project_id": os.getenv("GOOGLE_PROJECT_ID", "alwrity"),
        "auth_uri": "https://accounts.google.com/o/oauth2/auth",
        "token_uri": "https://oauth2.googleapis.com/token",
        "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
        "redirect_uris": ["http://localhost:5173/onboarding", callback_uri],
        "javascript_origins": ["http://localhost:5173", "http://localhost:8000"],
    }
    return {"web": web_settings}
def _get_db_path(self, user_id: str) -> str:
return get_user_db_path(user_id)
@@ -94,11 +148,11 @@ class GSCService:
self._init_gsc_tables(user_id)
db_path = self._get_db_path(user_id)
# Read client credentials from file to ensure we have all required fields
with open(self.credentials_file, 'r') as f:
client_config = json.load(f)
if not self.client_config:
logger.error("Cannot save credentials: Client configuration not loaded")
return False
web_config = client_config.get('web', {})
web_config = self.client_config.get('web', {})
credentials_json = json.dumps({
'token': credentials.token,
@@ -184,12 +238,17 @@ class GSCService:
try:
logger.info(f"Generating OAuth URL for user: {user_id}")
if not os.path.exists(self.credentials_file):
raise FileNotFoundError(f"GSC credentials file not found: {self.credentials_file}")
# Retry loading config if missing (in case .env was added later)
if not self.client_config:
self.client_config = self._load_client_config()
if not self.client_config:
raise FileNotFoundError("GSC credentials not found in file or environment variables.")
redirect_uri = os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
flow = Flow.from_client_secrets_file(
self.credentials_file,
flow = Flow.from_client_config(
self.client_config,
scopes=self.scopes,
redirect_uri=redirect_uri
)
@@ -256,8 +315,12 @@ class GSCService:
conn.commit()
# Exchange code for credentials
flow = Flow.from_client_secrets_file(
self.credentials_file,
if not self.client_config:
logger.error("Cannot handle callback: Client configuration not loaded")
return False
flow = Flow.from_client_config(
self.client_config,
scopes=self.scopes,
redirect_uri=os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
)
@@ -283,7 +346,11 @@ class GSCService:
service = build('searchconsole', 'v1', credentials=credentials)
logger.info(f"Authenticated GSC service created for user: {user_id}")
return service
except ValueError as e:
# Log as warning only, as this is expected for unconnected users
# logger.warning(f"Cannot create GSC service for user {user_id}: {e}")
raise e
except Exception as e:
logger.error(f"Error creating authenticated GSC service for user {user_id}: {e}")
raise
@@ -291,7 +358,13 @@ class GSCService:
def get_site_list(self, user_id: str) -> List[Dict[str, Any]]:
"""Get list of sites from GSC."""
try:
service = self.get_authenticated_service(user_id)
try:
service = self.get_authenticated_service(user_id)
except ValueError:
# User not connected or credentials invalid
logger.warning(f"User {user_id} not connected to GSC. Returning empty site list.")
return []
sites = service.sites().list().execute()
site_list = []
@@ -306,7 +379,8 @@ class GSCService:
except Exception as e:
logger.error(f"Error getting site list for user {user_id}: {e}")
raise
# Return empty list instead of raising to prevent frontend 500s
return []
def get_search_analytics(self, user_id: str, site_url: str,
start_date: str = None, end_date: str = None) -> Dict[str, Any]:
@@ -325,7 +399,12 @@ class GSCService:
logger.info(f"Returning cached analytics data for user: {user_id}")
return cached_data
service = self.get_authenticated_service(user_id)
try:
service = self.get_authenticated_service(user_id)
except ValueError:
logger.warning(f"User {user_id} not connected to GSC. Returning empty analytics.")
return {'error': 'User not connected to GSC', 'rows': [], 'rowCount': 0}
if not service:
logger.error(f"Failed to get authenticated GSC service for user: {user_id}")
return {'error': 'Authentication failed', 'rows': [], 'rowCount': 0}
@@ -359,11 +438,11 @@ class GSCService:
logger.error(f"GSC Data verification failed for user {user_id}: {verification_error}")
return {'error': f'Data verification failed: {str(verification_error)}', 'rows': [], 'rowCount': 0}
# Step 2: Get overall metrics (no dimensions)
# Step 2: Get daily metrics for charting (ensure we have rows)
request = {
'startDate': start_date,
'endDate': end_date,
'dimensions': [], # No dimensions for overall metrics
'dimensions': ['date'], # Use date dimension to get time-series data
'rowLimit': 1000
}
@@ -472,7 +551,11 @@ class GSCService:
def revoke_user_access(self, user_id: str) -> bool:
"""Revoke user's GSC access."""
try:
with sqlite3.connect(self.db_path) as conn:
db_path = self._get_db_path(user_id)
if not os.path.exists(db_path):
return True
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
# Delete credentials
@@ -496,7 +579,11 @@ class GSCService:
def clear_incomplete_credentials(self, user_id: str) -> bool:
"""Clear incomplete GSC credentials that are missing required fields."""
try:
with sqlite3.connect(self.db_path) as conn:
db_path = self._get_db_path(user_id)
if not os.path.exists(db_path):
return True
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute('DELETE FROM gsc_credentials WHERE user_id = ?', (user_id,))
conn.commit()
@@ -511,7 +598,11 @@ class GSCService:
def _get_cached_data(self, user_id: str, site_url: str, data_type: str, cache_key: str) -> Optional[Dict]:
"""Get cached data if not expired."""
try:
with sqlite3.connect(self.db_path) as conn:
db_path = self._get_db_path(user_id)
if not os.path.exists(db_path):
return None
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
SELECT data_json FROM gsc_data_cache
@@ -531,9 +622,12 @@ class GSCService:
def _cache_data(self, user_id: str, site_url: str, data_type: str, data: Dict, cache_key: str):
"""Cache data with expiration."""
try:
self._init_gsc_tables(user_id)
db_path = self._get_db_path(user_id)
expires_at = datetime.now() + timedelta(hours=1) # Cache for 1 hour
with sqlite3.connect(self.db_path) as conn:
with sqlite3.connect(db_path) as conn:
cursor = conn.cursor()
cursor.execute('''
INSERT OR REPLACE INTO gsc_data_cache

View File

@@ -24,7 +24,16 @@ class WordPressOAuthService:
# WordPress.com OAuth2 credentials
self.client_id = os.getenv('WORDPRESS_CLIENT_ID', '')
self.client_secret = os.getenv('WORDPRESS_CLIENT_SECRET', '')
self.redirect_uri = os.getenv('WORDPRESS_REDIRECT_URI', 'https://alwrity-ai.vercel.app/wp/callback')
# Determine redirect URI dynamically
default_redirect = 'https://alwrity-ai.vercel.app/wp/callback'
frontend_url = os.getenv('FRONTEND_URL')
if frontend_url:
self.redirect_uri = f"{frontend_url.rstrip('/')}/wp/callback"
else:
self.redirect_uri = os.getenv('WORDPRESS_REDIRECT_URI', default_redirect)
self.base_url = "https://public-api.wordpress.com"
# Validate configuration

View File

@@ -17,8 +17,7 @@ from .core_agent_framework import (
# Market signal detection
from .market_signal_detector import (
MarketSignal,
MarketSignalDetector,
MarketTrendAnalyzer
MarketSignalDetector
)
# Performance monitoring

View File

@@ -105,6 +105,18 @@ class ALwrityAgentOrchestrator:
def _create_specialized_agents(self):
"""Create specialized marketing agents"""
try:
# Check if onboarding is complete before initializing heavy agents
try:
from services.onboarding.progress_service import OnboardingProgressService
onboarding_service = OnboardingProgressService()
status = onboarding_service.get_onboarding_status(self.user_id)
if not status.get("is_completed", False):
logger.info(f"Skipping agent initialization for user {self.user_id} - Onboarding incomplete")
return
except Exception as e:
logger.warning(f"Could not check onboarding status for {self.user_id}: {e}")
# Fallthrough to attempt initialization if check fails
enabled_by_key = {}
db = None
try:
@@ -159,6 +171,26 @@ class ALwrityAgentOrchestrator:
self.trend_surfer_agent = TrendSurferAgent(intel_service, self.user_id)
self.agents['trend'] = self.trend_surfer_agent
# Content Guardian Agent
if enabled_by_key.get("content_guardian", True):
try:
from services.intelligence.sif_agents import ContentGuardianAgent
from services.intelligence.txtai_service import TxtaiIntelligenceService
# Initialize intelligence service if not already available
intel_service = TxtaiIntelligenceService(self.user_id)
# Initialize Content Guardian Agent
self.content_guardian_agent = ContentGuardianAgent(
intelligence_service=intel_service,
user_id=self.user_id,
sif_service=None # SIF service is optional/circular dependency handling
)
self.agents['guardian'] = self.content_guardian_agent
logger.info(f"Initialized ContentGuardianAgent for user {self.user_id}")
except Exception as e:
logger.error(f"Failed to initialize ContentGuardianAgent: {e}")
logger.info(f"Created {len(self.agents)} specialized agents for user {self.user_id}")
except Exception as e:

View File

@@ -0,0 +1,213 @@
import logging
import time
from datetime import datetime
from sqlalchemy import text
from services.database import get_session_for_user
from models.subscription_models import APIProvider, UsageSummary
from services.subscription import PricingService
logger = logging.getLogger(__name__)
def _estimate_tokens(content) -> int:
    """Roughly estimate a token count as 1.3 tokens per whitespace-separated word.

    The input is coerced with ``str()`` first, so non-string payloads (for
    example a list of chat messages) are estimated instead of raising.
    """
    return int(len(str(content).split()) * 1.3)


def track_agent_usage_sync(user_id: str, model_name: str, prompt: str, response_text: str, duration: float):
    """
    Synchronously track agent LLM usage.
    This mimics the logic in llm_text_gen to ensure consistency and robustness.

    Args:
        user_id: User whose usage summary and usage log are updated.
        model_name: Model identifier; used to pick the provider and for pricing.
        prompt: Prompt sent to the model (coerced to text for token estimation).
        response_text: Model output (coerced to text for token estimation).
        duration: Elapsed time of the call, stored as ``response_time``.

    Returns:
        None. Tracking is best-effort: every failure is logged and swallowed
        so that usage accounting never breaks the calling agent.
    """
    try:
        # Detect provider from the model name (Gemini is the default).
        provider_enum = APIProvider.GEMINI
        actual_provider_name = "gemini"

        model_lower = model_name.lower()
        if "gemini" in model_lower:
            provider_enum = APIProvider.GEMINI
            actual_provider_name = "gemini"
        elif "gpt" in model_lower or "openai" in model_lower or "mistral" in model_lower:
            # HuggingFace/Mistral often mapped to gpt-oss or mistral
            provider_enum = APIProvider.MISTRAL
            actual_provider_name = "huggingface"
        elif "claude" in model_lower or "anthropic" in model_lower:
            provider_enum = APIProvider.ANTHROPIC
            actual_provider_name = "anthropic"

        logger.info(f"[AgentTracking] Tracking usage for user {user_id}, provider {actual_provider_name}, model {model_name}")

        db = get_session_for_user(user_id)
        if not db:
            logger.error(f"[AgentTracking] Could not get database session for user {user_id}")
            return

        try:
            # Estimate tokens; both inputs go through the same coercing helper.
            tokens_input = _estimate_tokens(prompt)
            tokens_output = _estimate_tokens(response_text)
            tokens_total = tokens_input + tokens_output

            pricing = PricingService(db)
            current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

            # Get limits
            limits = pricing.get_user_limits(user_id)
            token_limit = 0
            provider_key = provider_enum.value
            if limits and limits.get('limits'):
                token_limit = limits['limits'].get(f"{provider_key}_tokens", 0) or 0

            # Check for existing record
            check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
            record_count = db.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()

            current_calls_before = 0
            current_tokens_before = 0

            if record_count and record_count > 0:
                # Read current values. provider_key comes from the enum, not
                # from user input, so interpolating it as a column name is safe.
                sql_query = text(f"""
                    SELECT {provider_key}_calls, {provider_key}_tokens
                    FROM usage_summaries
                    WHERE user_id = :user_id AND billing_period = :period
                    LIMIT 1
                """)
                result = db.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
                if result:
                    current_calls_before = result[0] if result[0] is not None else 0
                    current_tokens_before = result[1] if result[1] is not None else 0
            else:
                # Create new summary
                summary = UsageSummary(user_id=user_id, billing_period=current_period)
                db.add(summary)
                db.flush()

            # Update calls. The increment is done in SQL (like the cost and
            # total updates below) so concurrent trackers cannot overwrite
            # each other's count; new_calls is kept only for the log line.
            new_calls = current_calls_before + 1
            update_calls_query = text(f"""
                UPDATE usage_summaries
                SET {provider_key}_calls = COALESCE({provider_key}_calls, 0) + 1
                WHERE user_id = :user_id AND billing_period = :period
            """)
            db.execute(update_calls_query, {
                'user_id': user_id,
                'period': current_period
            })

            # Update tokens with limit check
            if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
                projected_new_tokens = current_tokens_before + tokens_total
                if token_limit > 0 and projected_new_tokens > token_limit:
                    # Clamp at the limit and only account for the tokens that fit.
                    new_tokens = token_limit
                    tokens_total = max(0, token_limit - current_tokens_before)
                else:
                    new_tokens = projected_new_tokens

                update_tokens_query = text(f"""
                    UPDATE usage_summaries
                    SET {provider_key}_tokens = :new_tokens
                    WHERE user_id = :user_id AND billing_period = :period
                """)
                db.execute(update_tokens_query, {
                    'new_tokens': new_tokens,
                    'user_id': user_id,
                    'period': current_period
                })
            else:
                tokens_total = 0

            # Split the (possibly clamped) total back into input/output parts.
            # Done outside the pricing try-block so the values always exist
            # for the usage log below.
            tracked_tokens_input = min(tokens_input, tokens_total)
            tracked_tokens_output = max(0, tokens_total - tracked_tokens_input)

            # Calculate cost
            try:
                cost_info = pricing.calculate_api_cost(
                    provider=provider_enum,
                    model_name=model_name,
                    tokens_input=tracked_tokens_input,
                    tokens_output=tracked_tokens_output,
                    request_count=1
                )
                cost_total = cost_info.get('cost_total', 0.0) or 0.0
                cost_input = cost_info.get('cost_input', 0.0) or 0.0
                cost_output = cost_info.get('cost_output', 0.0) or 0.0
            except Exception as e:
                logger.error(f"[AgentTracking] Cost calculation failed: {e}")
                cost_total = 0.0
                cost_input = 0.0
                cost_output = 0.0

            # Insert into APIUsageLog
            try:
                log_query = text("""
                    INSERT INTO api_usage_logs (
                        user_id, provider, endpoint, method, model_used,
                        tokens_input, tokens_output, tokens_total,
                        cost_input, cost_output, cost_total,
                        response_time, status_code, billing_period,
                        timestamp, actual_provider_name
                    ) VALUES (
                        :user_id, :provider, :endpoint, :method, :model_used,
                        :tokens_input, :tokens_output, :tokens_total,
                        :cost_input, :cost_output, :cost_total,
                        :response_time, :status_code, :billing_period,
                        :created_at, :actual_provider_name
                    )
                """)
                db.execute(log_query, {
                    'user_id': user_id,
                    'provider': provider_enum.name,  # Use name (GEMINI) not value (gemini) for SQLAlchemy Enum
                    'endpoint': 'agent_action',
                    'method': 'GENERATE',
                    'model_used': model_name,
                    'tokens_input': tracked_tokens_input,
                    'tokens_output': tracked_tokens_output,
                    'tokens_total': tracked_tokens_input + tracked_tokens_output,
                    'cost_input': cost_input,
                    'cost_output': cost_output,
                    'cost_total': cost_total,
                    'response_time': duration,
                    'status_code': 200,
                    'billing_period': current_period,
                    'created_at': datetime.utcnow(),
                    'actual_provider_name': actual_provider_name
                })
            except Exception as log_e:
                # The detailed log row is optional; the summary update still proceeds.
                logger.error(f"[AgentTracking] Failed to insert usage log: {log_e}")

            if cost_total > 0:
                update_costs_query = text(f"""
                    UPDATE usage_summaries
                    SET {provider_key}_cost = COALESCE({provider_key}_cost, 0) + :cost,
                        total_cost = COALESCE(total_cost, 0) + :cost
                    WHERE user_id = :user_id AND billing_period = :period
                """)
                db.execute(update_costs_query, {
                    'cost': cost_total,
                    'user_id': user_id,
                    'period': current_period
                })

            # Update totals
            update_totals_query = text("""
                UPDATE usage_summaries
                SET total_calls = COALESCE(total_calls, 0) + 1,
                    total_tokens = COALESCE(total_tokens, 0) + :tokens_total
                WHERE user_id = :user_id AND billing_period = :period
            """)
            db.execute(update_totals_query, {
                'tokens_total': tokens_total,
                'user_id': user_id,
                'period': current_period
            })

            db.commit()
            logger.info(f"[AgentTracking] ✅ Usage tracked: {new_calls} calls, {cost_total} cost")

        except Exception as e:
            logger.error(f"[AgentTracking] Error tracking usage: {e}", exc_info=True)
            db.rollback()
        finally:
            db.close()

    except Exception as e:
        logger.error(f"[AgentTracking] Top level error: {e}", exc_info=True)

View File

@@ -32,9 +32,64 @@ from services.database import get_session_for_user
from services.intelligence.monitoring.semantic_dashboard import RealTimeSemanticMonitor
from services.intelligence.agents.safety_framework import get_safety_framework
from services.agent_activity_service import AgentActivityService
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import time
logger = get_service_logger(__name__)
class TrackingLLMWrapper:
    """
    Wrapper for LLM instances to transparently track usage.

    Calls made through ``__call__`` or ``generate()`` are forwarded to the
    wrapped LLM, timed, and reported through ``track_agent_usage_sync``.
    Every other attribute lookup is delegated to the wrapped LLM.
    """

    def __init__(self, llm: Any, user_id: str, model_name: str):
        self.llm = llm
        self.user_id = user_id
        self.model_name = model_name

    def __call__(self, prompt: str, *args, **kwargs) -> Any:
        return self.generate(prompt, *args, **kwargs)

    @staticmethod
    def _response_text(response: Any) -> str:
        """Return the text used for usage accounting for a raw LLM response."""
        if isinstance(response, list):
            # An empty list carries no generated text; indexing it would raise
            # IndexError and turn a successful generation into a failure.
            if not response:
                return ""
            first = response[0]
            if isinstance(first, dict) and 'generated_text' in first:
                return first['generated_text']
            return str(first)
        return str(response)

    def generate(self, prompt: str, *args, **kwargs) -> Any:
        """
        Run the wrapped LLM on ``prompt`` and record usage.

        Returns the wrapped LLM's response unchanged. Errors raised by the
        wrapped LLM are logged and re-raised; errors raised while recording
        usage are logged and swallowed so tracking never breaks generation.
        """
        start_time = time.time()
        try:
            # Delegate to the underlying LLM
            if hasattr(self.llm, "generate"):
                response = self.llm.generate(prompt, *args, **kwargs)
            else:
                response = self.llm(prompt, *args, **kwargs)
        except Exception as e:
            logger.error(f"LLM generation failed in tracking wrapper: {e}")
            # Bare raise keeps the original traceback intact.
            raise

        duration = time.time() - start_time

        # Track usage (best effort)
        try:
            track_agent_usage_sync(
                user_id=self.user_id,
                model_name=self.model_name,
                prompt=prompt,
                response_text=self._response_text(response),
                duration=duration
            )
        except Exception as e:
            logger.warning(f"Failed to track agent usage in wrapper: {e}")

        return response

    def __getattr__(self, name):
        # __getattr__ only runs when normal lookup fails. If ``llm`` itself is
        # missing (instance created without __init__, e.g. via __new__),
        # reading self.llm here would re-enter __getattr__ and recurse forever.
        if name == "llm":
            raise AttributeError(name)
        # Delegate other attribute access to the underlying LLM
        return getattr(self.llm, name)
@dataclass
class AgentAction:
"""Represents an action taken by an agent"""
@@ -114,6 +169,10 @@ class BaseALwrityAgent(ABC):
self.txtai_agent = None
self.llm = llm # Ensure llm is set if provided, regardless of txtai availability
# Wrap LLM with tracking if it exists
if self.llm:
self.llm = TrackingLLMWrapper(self.llm, self.user_id, self.model_name)
self.agent_key = self._resolve_agent_key(agent_type)
self._agent_profile = self._load_agent_profile_overrides()
self._prompt_context = self._load_prompt_context()
@@ -121,10 +180,17 @@ class BaseALwrityAgent(ABC):
if TXTAI_AVAILABLE:
try:
if not self.llm:
self.llm = LLM(model_name)
self.txtai_agent = self._create_txtai_agent()
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
# Create new LLM if not provided
raw_llm = LLM(model_name)
# Wrap it
self.llm = TrackingLLMWrapper(raw_llm, self.user_id, self.model_name)
try:
self.txtai_agent = self._create_txtai_agent()
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
except Exception as inner_e:
logger.warning(f"Could not initialize specific txtai agent for {agent_type}: {inner_e}")
self.txtai_agent = self._create_fallback_agent()
except Exception as e:
logger.error(f"Failed to initialize txtai agent for {agent_type}: {e}")
self.txtai_agent = self._create_fallback_agent()
@@ -134,6 +200,38 @@ class BaseALwrityAgent(ABC):
# Initialize safety framework
self.safety_framework = get_safety_framework(user_id)
async def _generate_llm_response(self, prompt: str) -> str:
    """
    Helper to generate text using the agent's LLM with usage tracking.
    Centralized method for all agents inheriting from BaseALwrityAgent.

    Args:
        prompt: Prompt text passed to the agent's LLM.

    Returns:
        The generated text, ``"[LLM Unavailable]"`` when the agent has no LLM,
        or ``"[Generation Failed]"`` when the call or response parsing raises.
    """
    if not self.llm:
        return "[LLM Unavailable]"

    try:
        # We are inside a coroutine, so a loop is always running;
        # get_running_loop() is the non-deprecated way to obtain it.
        loop = asyncio.get_running_loop()

        # Run in the default executor to avoid blocking if the LLM is synchronous.
        if hasattr(self.llm, "generate"):
            response = await loop.run_in_executor(None, lambda: self.llm.generate(prompt))
        else:
            response = await loop.run_in_executor(None, lambda: self.llm(prompt))

        # Handle list output (some models return list of dicts)
        response_text = str(response)
        if isinstance(response, list):
            if response and isinstance(response[0], dict) and 'generated_text' in response[0]:
                response_text = response[0]['generated_text']
            else:
                # NOTE: an empty list raises IndexError here and is reported
                # as "[Generation Failed]" by the handler below.
                response_text = str(response[0])

        return response_text
    except Exception as e:
        logger.error(f"LLM generation failed in agent {self.agent_type}: {e}")
        return "[Generation Failed]"
def _resolve_agent_key(self, agent_type: str) -> str:
value = str(agent_type or "").strip()
if value.lower() == "strategyorchestrator".lower():

View File

@@ -758,6 +758,11 @@ async def get_agent_performance_summary(user_id: str, agent_id: str) -> Dict[str
"""Get comprehensive performance summary for an agent"""
return await performance_service.get_agent_performance_summary(user_id, agent_id)
async def get_all_agents_performance_summary(user_id: str) -> List[Dict[str, Any]]:
async def get_all_agents_performance_summary(user_id: str) -> List[Dict[str, Any]]:
"""Get performance summary for all agents for a user"""
return await performance_service.get_all_agents_performance_summary(user_id)
return await performance_service.get_all_agents_performance_summary(user_id)
# Alias for backward compatibility
PerformanceMonitor = AgentPerformanceMonitor
performance_monitor = performance_service
AgentPerformanceMetrics = AgentPerformanceSnapshot

View File

@@ -13,6 +13,7 @@ from loguru import logger
from ..txtai_service import TxtaiIntelligenceService
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, AgentAction
from services.seo_tools.content_strategy_service import ContentStrategyService
from services.intelligence.sif_agents import SharedLLMWrapper, LocalLLMWrapper
try:
from services.intelligence.sif_integration import SIFIntegrationService
SIF_AVAILABLE = True
@@ -20,14 +21,36 @@ except ImportError:
SIF_AVAILABLE = False
try:
from txtai import Agent, LLM
# Try importing from pipeline first (standard location)
from txtai.pipeline import Agent, LLM
TXTAI_AVAILABLE = True
except ImportError:
TXTAI_AVAILABLE = False
logger.warning("txtai not available, using fallback implementation")
try:
# Fallback to top-level import
from txtai import Agent, LLM
TXTAI_AVAILABLE = True
except ImportError:
TXTAI_AVAILABLE = False
Agent = None
LLM = None
logger.warning("txtai not available, using fallback implementation")
class SIFBaseAgent:
def __init__(self, intelligence_service: TxtaiIntelligenceService):
class SIFBaseAgent(BaseALwrityAgent):
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, agent_type: str = "sif_agent", model_name: str = "Qwen/Qwen2.5-3B-Instruct", llm: Any = None):
    """
    Build the agent's LLM pair and initialise the base agent.

    ``shared_llm`` is always created. When no ``llm`` is passed, a lazy local
    LLM is used if txtai is available, otherwise the shared LLM is used.
    """
    # Shared LLM: external/high-quality generation.
    self.shared_llm = SharedLLMWrapper(user_id)

    # Local LLM: default for internal agent work, with the shared LLM as fallback.
    if llm is None:
        llm = LocalLLMWrapper(model_name) if TXTAI_AVAILABLE else self.shared_llm

    super().__init__(user_id, agent_type, model_name, llm)
    self.intelligence = intelligence_service
def _log_agent_operation(self, operation: str, **kwargs):
@@ -36,9 +59,27 @@ class SIFBaseAgent:
if kwargs:
logger.debug(f"[{self.__class__.__name__}] Parameters: {kwargs}")
def _create_txtai_agent(self):
    """
    Build a tool-less txtai Agent around this agent's LLM.

    Returns None when txtai (or its Agent class) is unavailable, or when
    constructing the Agent raises.
    """
    txtai_usable = TXTAI_AVAILABLE and Agent is not None
    if txtai_usable:
        try:
            # A simple agent that can use the LLM, with no tools attached.
            return Agent(llm=self.llm, tools=[])
        except Exception as e:
            logger.warning(f"Failed to create txtai Agent: {e}")
    return None
class StrategyArchitectAgent(SIFBaseAgent):
"""Agent for discovering content pillars and identifying strategic gaps."""
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
    """Initialise the base SIF agent with the ``strategy_architect`` agent type."""
    super().__init__(
        intelligence_service=intelligence_service,
        user_id=user_id,
        agent_type="strategy_architect",
    )
async def discover_pillars(self) -> List[Dict[str, Any]]:
"""Identify content pillars through semantic clustering."""
self._log_agent_operation("Discovering content pillars")
@@ -108,9 +149,61 @@ class ContentGuardianAgent(SIFBaseAgent):
CANNIBALIZATION_THRESHOLD = 0.85 # Similarity threshold for cannibalization warning
ORIGINALITY_THRESHOLD = 0.75 # Minimum originality score
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
    """
    Initialise the guardian agent.

    When no ``sif_service`` is supplied and SIF is available, one is created
    for ``user_id``; a failure to create it is logged and leaves
    ``sif_service`` as None.
    """
    super().__init__(intelligence_service, user_id, agent_type="content_guardian")
    self.sif_service = sif_service

    # Nothing more to do when a service was injected or SIF cannot be used.
    if self.sif_service is not None or not SIF_AVAILABLE:
        return

    try:
        self.sif_service = SIFIntegrationService(user_id)
        logger.info(f"[{self.__class__.__name__}] Lazily initialized SIFIntegrationService")
    except Exception as e:
        logger.warning(f"[{self.__class__.__name__}] Failed to lazily initialize SIF service: {e}")
async def assess_content_quality(self, content: str) -> Dict[str, Any]:
    """
    Assess content quality based on originality, style, readability, and cannibalization risk.

    Args:
        content: Text to assess.

    Returns:
        Dict with ``quality_score`` (0.4 * originality + 0.3 * style +
        0.3 * readability), the three component scores, the raw
        cannibalization and style results, and ``is_acceptable``.
        On failure: ``error``, ``quality_score`` 0.0 and ``is_acceptable`` False.
    """
    import re  # local import: only needed for the sentence split below

    self._log_agent_operation("Assessing content quality", content_length=len(content))

    try:
        # 1. Check for cannibalization
        cannibalization_result = await self.check_cannibalization(content)

        # 2. Check originality (skipped when cannibalization was flagged)
        originality_score = 1.0
        if not cannibalization_result.get("warning"):
            originality_result = await self.verify_originality(content, None)
            originality_score = originality_result.get("originality_score", 1.0)

        # 3. Check style compliance
        style_result = await self.style_enforcer(content)
        style_score = style_result.get("compliance_score", 1.0)

        # 4. Basic readability heuristic: average words per sentence.
        # Split on ., ! and ? and drop empty fragments, so a trailing
        # terminator does not count as an extra (empty) sentence.
        words = content.split()
        sentences = [s for s in re.split(r"[.!?]+", content) if s.strip()]
        avg_sentence_length = len(words) / max(1, len(sentences))
        if avg_sentence_length < 20:
            readability_score = 1.0
        else:
            readability_score = max(0.5, 1.0 - (avg_sentence_length - 20) * 0.05)

        # Weighted Score: Originality (40%) + Style (30%) + Readability (30%)
        quality_score = (originality_score * 0.4) + (style_score * 0.3) + (readability_score * 0.3)

        return {
            "quality_score": quality_score,
            "originality_score": originality_score,
            "readability_score": readability_score,
            "style_score": style_score,
            "cannibalization_risk": cannibalization_result,
            "style_compliance": style_result,
            "is_acceptable": quality_score > 0.7 and not cannibalization_result.get("warning", False)
        }

    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] Failed to assess content quality: {e}")
        # Include is_acceptable so callers reading that key do not fail on the error path.
        return {"error": str(e), "quality_score": 0.0, "is_acceptable": False}
async def check_cannibalization(self, new_draft: str) -> Dict[str, Any]:
"""Check if a new draft competes semantically with existing pages."""
@@ -193,25 +286,74 @@ class ContentGuardianAgent(SIFBaseAgent):
# 1. Fetch Style Guidelines from SIF if not provided
if not style_guidelines and self.sif_service:
try:
# Search for website analysis to get brand voice/style
# We assume the most relevant 'website_analysis' doc contains the guidelines
results = await self.intelligence.search("website analysis brand voice style", limit=1)
if results:
import json
res = results[0]
metadata_str = res.get('object')
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else (metadata_str or res)
# Use central SIF service to get robust context
seo_context = await self.sif_service.get_seo_context()
if seo_context and "error" not in seo_context:
# Extract brand voice/style from the context
# The context structure is normalized in get_seo_context
if metadata.get('type') == 'website_analysis':
report = metadata.get('full_report', {})
style_guidelines = {
"tone": report.get('brand_analysis', {}).get('brand_voice', 'neutral'),
"style_patterns": report.get('style_patterns', {}),
"writing_style": report.get('writing_style', {})
}
logger.info(f"[{self.__class__.__name__}] Retrieved style guidelines from SIF: {style_guidelines.get('tone')}")
# Note: get_seo_context returns a flattened dict.
# We need to dig into the original structure if available, or rely on what's mapped.
# However, get_seo_context maps 'seo_audit', 'sitemap_analysis', etc.
# Brand info is usually in 'brand_analysis' col of WebsiteAnalysis, which might not be fully exposed
# in the simplified get_seo_context return.
# Let's check if we can get the full object or if we need to expand get_seo_context.
# For now, we'll try to use what's there or fall back to a specific search if needed.
# Actually, looking at get_seo_context implementation:
# It returns 'seo_audit', 'crawl_result'.
# Brand analysis is often stored in WebsiteAnalysis.brand_analysis.
# We might need to extend get_seo_context or do a specific retrieval here.
# But wait! I saw get_seo_context implementation earlier:
# It retrieves the "full_report" from the SIF metadata.
# If the SIF index contains the full WebsiteAnalysis object, we are good.
# Let's try to get it from the full report if we can access it,
# but get_seo_context returns a filtered dict.
# Alternative: Use the robust retrieval logic but specifically for brand info if get_seo_context is too narrow.
# But get_seo_context logic includes "website analysis seo audit" query.
# Let's assume for now we use the same retrieval logic but locally adapted,
# OR better, trust get_seo_context to be the single point of truth.
# If get_seo_context doesn't return brand info, we should update IT, not hack here.
# But I can't update SIFIntegrationService right now without context switch.
# Let's stick to the previous manual search pattern BUT use the SIF service helper if possible.
# Actually, the previous code was:
# results = await self.intelligence.search("website analysis brand voice style", limit=1)
# Let's keep it simple and robust:
# Try to get it from SIF service if possible.
# Since get_seo_context might not return brand_voice directly, let's try to see if we can use it.
# Actually, let's use the manual search but with better error handling,
# mirroring get_seo_context's robustness (e.g. parsing).
results = await self.intelligence.search("website analysis brand voice style", limit=1)
if results:
res = results[0]
metadata_str = res.get('object')
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else (metadata_str or res)
if metadata.get('type') == 'website_analysis':
report = metadata.get('full_report', {})
# Support both flat and nested structures
brand_analysis = report.get('brand_analysis') or report.get('brand_voice', {})
if isinstance(brand_analysis, str):
# Handle case where it might be a JSON string
try: brand_analysis = json.loads(brand_analysis)
except: brand_analysis = {"brand_voice": brand_analysis}
style_guidelines = {
"tone": brand_analysis.get('brand_voice', 'neutral') if isinstance(brand_analysis, dict) else 'neutral',
"style_patterns": report.get('style_patterns', {}),
"writing_style": report.get('writing_style', {})
}
logger.info(f"[{self.__class__.__name__}] Retrieved style guidelines from SIF index")
except Exception as e:
logger.warning(f"[{self.__class__.__name__}] Failed to retrieve style guidelines from SIF: {e}")
logger.warning(f"[{self.__class__.__name__}] Failed to retrieve style guidelines: {e}")
issues = []
score = 1.0
@@ -246,6 +388,55 @@ class ContentGuardianAgent(SIFBaseAgent):
logger.error(f"[{self.__class__.__name__}] Style enforcement failed: {e}")
return {"error": str(e)}
async def perform_site_audit(self, website_url: str, limit: int = 10) -> Dict[str, Any]:
    """
    Run a quality audit over indexed content for ``website_url``.

    Searches the intelligence service for ``site:<website_url>``, scores every
    hit whose text is at least 100 characters with ``assess_content_quality``,
    and returns the per-page results with their average quality score.
    Returns ``{"error": ...}`` when nothing is found or the audit raises.
    """
    self._log_agent_operation("Performing site audit", website_url=website_url)

    try:
        # NOTE(review): whether a "site:" query matches the URL depends on how
        # the data was indexed — confirm against the indexing code.
        hits = await self.intelligence.search(f"site:{website_url}", limit=limit)
        if not hits:
            logger.info(f"[{self.__class__.__name__}] No content found for site audit")
            return {"error": "No content found"}

        page_reports = []
        quality_sum = 0.0
        for hit in hits:
            body = hit.get('text', '')
            # Skip entries with no text or too little text to score.
            if not body or len(body) < 100:
                continue

            page_quality = await self.assess_content_quality(body)
            page_reports.append({
                "id": hit.get('id'),
                "title": hit.get('title', 'Unknown'),
                "quality": page_quality
            })
            quality_sum += page_quality.get('quality_score', 0.0)

        if page_reports:
            avg_quality = quality_sum / len(page_reports)
        else:
            avg_quality = 0.0

        report = {
            "website_url": website_url,
            "pages_audited": len(page_reports),
            "average_quality_score": avg_quality,
            "details": page_reports,
            "timestamp": datetime.utcnow().isoformat()
        }

        logger.info(f"[{self.__class__.__name__}] Site audit completed. Avg Quality: {avg_quality:.2f}")
        return report

    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] Site audit failed: {e}")
        return {"error": str(e)}
async def safety_filter(self, text: str) -> Dict[str, Any]:
"""
Tool: Flags potentially harmful, offensive, or sensitive content.
@@ -290,8 +481,8 @@ class LinkGraphAgent(SIFBaseAgent):
RELEVANCE_THRESHOLD = 0.6 # Minimum relevance score for link suggestions
MAX_SUGGESTIONS = 10 # Maximum number of link suggestions
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
    """Initialise the base SIF agent as ``link_graph`` and keep the optional SIF service."""
    super().__init__(
        intelligence_service=intelligence_service,
        user_id=user_id,
        agent_type="link_graph",
    )
    self.sif_service = sif_service
async def suggest_internal_links(self, draft: str) -> List[Dict[str, Any]]:
@@ -823,9 +1014,10 @@ class ContentStrategyAgent(BaseALwrityAgent):
Maintain the original meaning and tone.
"""
if hasattr(self.llm, "generate"):
if self.llm:
# We assume the LLM returns JSON-like text or we parse it
response = self.llm.generate(f"{system_prompt}\n\nText to rewrite:\n{content}")
response = await self._generate_llm_response(f"{system_prompt}\n\nText to rewrite:\n{content}")
# Simple parsing fallback if LLM returns raw text
if isinstance(response, str) and not response.strip().startswith("{"):
optimized_content = response
@@ -1456,34 +1648,7 @@ class SEOOptimizationAgent(BaseALwrityAgent):
"timestamp": datetime.utcnow().isoformat()
}
async def _generate_llm_response(self, prompt: str) -> str:
    """
    Helper to generate text using the agent's LLM.

    Returns the generated text, ``"[LLM Unavailable]"`` when the agent has no
    LLM, or ``"[Generation Failed]"`` when the call or response parsing raises.
    """
    if not self.llm:
        return "[LLM Unavailable]"

    try:
        # Inside a coroutine a loop is always running; get_running_loop()
        # is the non-deprecated way to obtain it.
        loop = asyncio.get_running_loop()

        # Run in the default executor to avoid blocking if the LLM is synchronous.
        # Some txtai pipelines expose generate(); others are plain callables.
        if hasattr(self.llm, "generate"):
            response = await loop.run_in_executor(None, lambda: self.llm.generate(prompt))
        else:
            response = await loop.run_in_executor(None, lambda: self.llm(prompt))

        # Handle list output (some models return list of dicts)
        if isinstance(response, list):
            if response and isinstance(response[0], dict) and 'generated_text' in response[0]:
                return response[0]['generated_text']
            # NOTE: an empty list raises IndexError here and is reported
            # as "[Generation Failed]" by the handler below.
            return str(response[0])

        return str(response)
    except Exception as e:
        logger.error(f"LLM generation failed: {e}")
        return "[Generation Failed]"
async def _strategy_generator_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
"""SEO strategy generation tool"""
audit_results = context.get("audit_results", {})
@@ -1629,8 +1794,8 @@ class SocialAmplificationAgent(BaseALwrityAgent):
Return ONLY the adapted content.
"""
if hasattr(self.llm, "generate"):
adapted_content = self.llm.generate(prompt)
if self.llm:
adapted_content = await self._generate_llm_response(prompt)
else:
adapted_content = f"[Mock {platform}]: {content[:50]}... #adapted"

View File

@@ -19,7 +19,7 @@ class TrendSurferAgent(SIFBaseAgent):
"""
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
super().__init__(intelligence_service)
super().__init__(intelligence_service, user_id, agent_type="trend_surfer")
self.user_id = user_id
self.signal_detector = MarketSignalDetector(user_id)
self.trends_service = GoogleTrendsService()
@@ -148,15 +148,41 @@ class TrendSurferAgent(SIFBaseAgent):
else:
recommendation = "Create new content"
# Use LLM to generate creative angle
headline = f"Trend: {trend.description}"
angle = f"Leverage {trend.source} trend on {trend.related_topics[0] if trend.related_topics else 'topic'}"
try:
prompt = f"""
Analyze this market trend signal and propose a content angle:
Trend: {trend.description}
Related Topics: {', '.join(trend.related_topics)}
Impact Score: {trend.impact_score}
Recommendation: {recommendation}
Provide a catchy headline and a 1-sentence strategic angle.
Format: Headline | Angle
"""
response = await self._generate_llm_response(prompt)
if response and "|" in response:
parts = response.split('|')
headline = parts[0].strip()
angle = parts[1].strip()
elif response:
angle = response.strip()
except Exception as e:
logger.warning(f"[{self.__class__.__name__}] LLM generation failed for opportunity: {e}")
return {
"trend_id": trend.signal_id,
"topic": trend.description,
"headline": headline,
"source": trend.source,
"urgency": trend.urgency_level.value,
"impact_score": trend.impact_score,
"current_coverage": coverage_score,
"recommendation": recommendation,
"suggested_angle": f"Leverage {trend.source} trend on {trend.related_topics[0] if trend.related_topics else 'topic'}",
"suggested_angle": angle,
"detected_at": trend.detected_at
}

View File

@@ -5,13 +5,76 @@ Each agent leverages TxtaiIntelligenceService for semantic operations.
"""
import traceback
import json
import asyncio
from typing import List, Dict, Any, Optional
from datetime import datetime
from loguru import logger
from .txtai_service import TxtaiIntelligenceService
from .txtai_service import TxtaiIntelligenceService, TXTAI_AVAILABLE
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent
from services.llm_providers.main_text_generation import llm_text_gen
class SIFBaseAgent:
def __init__(self, intelligence_service: TxtaiIntelligenceService):
# Optional txtai imports
try:
from txtai.pipeline import Agent, LLM
except ImportError:
Agent = None
LLM = None
class SharedLLMWrapper:
    """Adapter exposing the shared ALwrity LLM service through a txtai-LLM-like interface."""

    def __init__(self, user_id: str):
        self.user_id = user_id

    def __call__(self, prompt: str, **kwargs) -> str:
        """Calling the wrapper is the same as calling ``generate``."""
        return self.generate(prompt, **kwargs)

    def generate(self, prompt: str, **kwargs) -> str:
        """
        Generate text for ``prompt`` through the shared LLM provider.

        Extra keyword arguments (e.g. ``max_tokens``) are accepted for
        interface compatibility but are not forwarded to ``llm_text_gen``.
        """
        generated = llm_text_gen(prompt, user_id=self.user_id)
        return generated
class LocalLLMWrapper:
    """
    Local txtai LLM that is loaded on first use rather than at construction.

    Deferring the load keeps heavy model initialisation out of server startup.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        self._llm = None  # populated by the ``llm`` property on first access

    @property
    def llm(self):
        """Return the loaded LLM, loading it from ``model_path`` on first access."""
        if self._llm is not None:
            return self._llm
        if LLM is None:
            raise ImportError("txtai.pipeline.LLM is not available")
        logger.info(f"Loading local LLM: {self.model_path}")
        self._llm = LLM(path=self.model_path)
        return self._llm

    def generate(self, prompt: str, **kwargs) -> str:
        """Run the underlying LLM on ``prompt``, forwarding keyword arguments."""
        return self.llm(prompt, **kwargs)

    def __call__(self, prompt: str, **kwargs) -> str:
        """Run the underlying LLM on ``prompt``, forwarding keyword arguments."""
        return self.llm(prompt, **kwargs)
class SIFBaseAgent(BaseALwrityAgent):
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, agent_type: str = "sif_agent", model_name: str = "Qwen/Qwen2.5-3B-Instruct", llm: Any = None):
    """
    Set up the hybrid LLM pair and initialise the base agent.

    ``shared_llm`` is always created and available to every SIF agent. When
    ``llm`` is not supplied, the agent uses a lazily loaded local LLM if txtai
    is available and the shared LLM otherwise.
    """
    # Shared LLM for external/high-quality generation.
    self.shared_llm = SharedLLMWrapper(user_id)

    # Default LLM for internal agent work.
    if llm is None:
        llm = LocalLLMWrapper(model_name) if TXTAI_AVAILABLE else self.shared_llm

    super().__init__(user_id, agent_type, model_name, llm)
    self.intelligence = intelligence_service
def _log_agent_operation(self, operation: str, **kwargs):
@@ -20,9 +83,23 @@ class SIFBaseAgent:
if kwargs:
logger.debug(f"[{self.__class__.__name__}] Parameters: {kwargs}")
def _create_txtai_agent(self):
    """
    SIF agents use the intelligence service directly, but we can expose
    capabilities via a standard agent interface if needed.

    Returns:
        A tool-less txtai ``Agent`` bound to this agent's LLM, or None when
        txtai or its ``Agent`` class is unavailable.
    """
    # ``Agent`` is set to None when the optional ``txtai.pipeline`` import
    # fails, independently of TXTAI_AVAILABLE, so both must be checked before
    # calling it.
    if not TXTAI_AVAILABLE or Agent is None:
        return None

    # Return a simple agent that can use the LLM
    return Agent(llm=self.llm, tools=[])
class StrategyArchitectAgent(SIFBaseAgent):
"""Agent for discovering content pillars and identifying strategic gaps."""
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
    """Initialise the base SIF agent with the ``strategy_architect`` agent type."""
    super().__init__(
        intelligence_service=intelligence_service,
        user_id=user_id,
        agent_type="strategy_architect",
    )
async def discover_pillars(self) -> List[Dict[str, Any]]:
"""Identify content pillars through semantic clustering."""
self._log_agent_operation("Discovering content pillars")
@@ -58,6 +135,61 @@ class StrategyArchitectAgent(SIFBaseAgent):
logger.error(f"[{self.__class__.__name__}] Failed to discover pillars: {e}")
logger.error(f"[{self.__class__.__name__}] Full traceback: {traceback.format_exc()}")
return []
async def analyze_content_strategy(self, website_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Produce strategic recommendations from discovered pillars and website data.

    Args:
        website_data: Dictionary containing website analysis data

    Returns:
        List of recommendation dicts (``type``, ``priority``, ``title``,
        ``description`` and, for pillar items, ``pillar_id``). Returns an
        empty list if anything raises.
    """
    self._log_agent_operation("Analyzing content strategy")

    try:
        recommendations = []

        # Existing pillars drive the gap analysis (simplified logic for now).
        pillars = await self.discover_pillars()

        if pillars:
            # Pillars with fewer than 3 items are treated as weak.
            recommendations.extend(
                {
                    "type": "content_depth",
                    "priority": "medium",
                    "title": f"Strengthen Pillar {pillar['pillar_id']}",
                    "description": "This topic cluster has few articles. Create more content to establish authority.",
                    "pillar_id": pillar['pillar_id']
                }
                for pillar in pillars
                if pillar['size'] < 3
            )
        else:
            recommendations.append({
                "type": "strategy_gap",
                "priority": "high",
                "title": "Establish Core Content Pillars",
                "description": "No clear content clusters found. Focus on defining 3-5 core topics to build authority."
            })

        # Generic recommendation based on website data, when provided.
        if website_data and not website_data.get('description'):
            recommendations.append({
                "type": "metadata",
                "priority": "high",
                "title": "Missing Meta Description",
                "description": "Website is missing a meta description. Add one to improve SEO CTR."
            })

        logger.info(f"[{self.__class__.__name__}] Generated {len(recommendations)} strategic recommendations")
        return recommendations

    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] Failed to analyze content strategy: {e}")
        return []
def _calculate_cluster_confidence(self, cluster_indices: List[int]) -> float:
"""Calculate confidence score for a cluster based on its size and coherence."""
@@ -92,10 +224,40 @@ class ContentGuardianAgent(SIFBaseAgent):
CANNIBALIZATION_THRESHOLD = 0.85 # Similarity threshold for cannibalization warning
ORIGINALITY_THRESHOLD = 0.75 # Minimum originality score
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
    """Initialise the base SIF agent as ``content_guardian`` and keep the optional SIF service."""
    super().__init__(
        intelligence_service=intelligence_service,
        user_id=user_id,
        agent_type="content_guardian",
    )
    self.sif_service = sif_service
async def assess_content_quality(self, website_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Score overall content quality from a website's description or title.

    The score starts from the style check's ``compliance_score`` (0.8 when
    absent) and is halved when the safety check's action is
    ``flag_for_review``. Returns a 0.5 score with a reason when there is no
    text to analyse, and a 0.0 score with ``error`` when anything raises.
    """
    self._log_agent_operation("Assessing content quality")

    try:
        # Prefer the description; fall back to the title.
        sample = website_data.get('description', '') or website_data.get('title', '')
        if not sample:
            return {"score": 0.5, "reason": "No content to analyze"}

        style_report = await self.style_enforcer(sample)
        safety_report = await self.safety_filter(sample)

        score = style_report.get('compliance_score', 0.8)
        needs_review = safety_report.get('action') == 'flag_for_review'
        if needs_review:
            score *= 0.5

        return {
            "score": score,
            "style_analysis": style_report,
            "safety_analysis": safety_report,
            "analyzed_text_length": len(sample)
        }
    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] Quality assessment failed: {e}")
        return {"score": 0.0, "error": str(e)}
async def check_cannibalization(self, new_draft: str) -> Dict[str, Any]:
"""Check if a new draft competes semantically with existing pages."""
self._log_agent_operation("Checking for semantic cannibalization", draft_length=len(new_draft))
@@ -274,8 +436,8 @@ class LinkGraphAgent(SIFBaseAgent):
RELEVANCE_THRESHOLD = 0.6 # Minimum relevance score for link suggestions
MAX_SUGGESTIONS = 10 # Maximum number of link suggestions
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
super().__init__(intelligence_service)
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
    """Initialise the base SIF agent as ``link_graph`` and keep the optional SIF service."""
    super().__init__(
        intelligence_service=intelligence_service,
        user_id=user_id,
        agent_type="link_graph",
    )
    self.sif_service = sif_service
async def suggest_internal_links(self, draft: str) -> List[Dict[str, Any]]:
@@ -479,6 +641,9 @@ class CitationExpert(SIFBaseAgent):
EVIDENCE_THRESHOLD = 0.7 # Minimum relevance score for evidence
MAX_EVIDENCE = 5 # Maximum number of evidence pieces to return
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
    """Initialise the base SIF agent with the ``citation_expert`` agent type."""
    super().__init__(
        intelligence_service=intelligence_service,
        user_id=user_id,
        agent_type="citation_expert",
    )
async def fact_checker(self, claim: str) -> List[Dict[str, Any]]:
"""
Tool: Verifies facts against trusted research data.
@@ -542,60 +707,25 @@ class CitationExpert(SIFBaseAgent):
"claim": claim,
"status": status,
"evidence_count": len(evidence),
"top_evidence": evidence[0]['source'] if evidence else None
"top_evidence": evidence[0] if evidence else None
})
return {
"status": "verification_complete",
"total_claims": len(claims),
"status": "completed",
"verified_claims": verified_results,
"unsupported_count": len([c for c in verified_results if c['status'] == 'unsupported']),
"timestamp": datetime.utcnow().isoformat()
"verification_score": len([c for c in verified_results if c['status'] == 'supported']) / len(verified_results)
}
async def verify_facts(self, claim: str) -> List[Dict[str, Any]]:
"""Find supporting or contradicting evidence in the indexed research."""
self._log_agent_operation("Verifying facts", claim_length=len(claim))
"""Verify a single claim against intelligence data."""
results = await self.intelligence.search(claim, limit=3)
try:
if not self.intelligence.is_initialized():
logger.error(f"[{self.__class__.__name__}] Intelligence service not initialized")
return []
if not claim or len(claim.strip()) < 20:
logger.warning(f"[{self.__class__.__name__}] Claim too short for meaningful verification")
return []
results = await self.intelligence.search(claim, limit=self.MAX_EVIDENCE)
if not results:
logger.info(f"[{self.__class__.__name__}] No evidence found for claim")
return []
evidence = []
for result in results:
relevance_score = result.get('score', 0.0)
if relevance_score >= self.EVIDENCE_THRESHOLD:
evidence_piece = {
"source": result.get('id', 'unknown'),
"relevance": relevance_score,
"confidence": self._calculate_evidence_confidence(relevance_score),
"type": "supporting" if relevance_score > 0.8 else "related",
"excerpt": result.get('text', '')[:200] + "..." if len(result.get('text', '')) > 200 else result.get('text', '')
}
evidence.append(evidence_piece)
logger.debug(f"[{self.__class__.__name__}] Found evidence: {evidence_piece['source']} (score: {relevance_score:.3f})")
logger.info(f"[{self.__class__.__name__}] Found {len(evidence)} pieces of evidence for claim")
return evidence
except Exception as e:
logger.error(f"[{self.__class__.__name__}] Failed to verify facts: {e}")
logger.error(f"[{self.__class__.__name__}] Full traceback: {traceback.format_exc()}")
return []
def _calculate_evidence_confidence(self, relevance_score: float) -> float:
"""Calculate confidence score for evidence."""
# Simple confidence based on relevance score
return min(1.0, relevance_score * 1.2)
evidence = []
for result in results:
if result.get('score', 0) > self.EVIDENCE_THRESHOLD:
evidence.append({
"text": result.get('text'),
"source": result.get('id'),
"confidence": result.get('score')
})
return evidence

View File

@@ -938,14 +938,14 @@ class SIFIntegrationService:
# Strategic recommendations (lazy initialization to avoid circular imports)
if not self.strategy_agent:
from .sif_agents import StrategyArchitectAgent
self.strategy_agent = StrategyArchitectAgent(self.intelligence_service)
self.strategy_agent = StrategyArchitectAgent(self.intelligence_service, user_id=self.user_id)
recommendations = await self.strategy_agent.analyze_content_strategy(website_data)
insights["strategic_recommendations"] = recommendations
# Content quality assessment (lazy initialization to avoid circular imports)
if not self.guardian_agent:
from .sif_agents import ContentGuardianAgent
self.guardian_agent = ContentGuardianAgent(self.intelligence_service, sif_service=self)
self.guardian_agent = ContentGuardianAgent(self.intelligence_service, user_id=self.user_id, sif_service=self)
quality_score = await self.guardian_agent.assess_content_quality(website_data)
insights["content_quality"] = quality_score

View File

@@ -33,7 +33,13 @@ class TxtaiIntelligenceService:
self._initialized = False
self.enable_caching = enable_caching
self.cache_manager = semantic_cache_manager if enable_caching else None
self._initialize_embeddings()
# Lazy initialization - do not initialize embeddings on startup
# self._initialize_embeddings()
def _ensure_initialized(self):
"""Lazy initialization helper."""
if not self._initialized:
self._initialize_embeddings()
def _initialize_embeddings(self):
"""Initialize txtai embeddings with local storage support and comprehensive error handling."""
@@ -106,6 +112,7 @@ class TxtaiIntelligenceService:
Args:
items: List of (id, text, metadata) tuples.
"""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot index content - service not initialized for user {self.user_id}")
return
@@ -145,6 +152,7 @@ class TxtaiIntelligenceService:
async def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
"""Perform semantic search with intelligent caching."""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot perform search - service not initialized for user {self.user_id}")
return []
@@ -186,6 +194,7 @@ class TxtaiIntelligenceService:
async def get_similarity(self, text1: str, text2: str) -> float:
"""Get semantic similarity between two texts with caching."""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot calculate similarity - service not initialized for user {self.user_id}")
return 0.0
@@ -234,6 +243,7 @@ class TxtaiIntelligenceService:
async def cluster(self, min_score: float = 0.5) -> List[List[int]]:
"""Cluster indexed content to find semantic pillars using graph-based clustering with caching."""
self._ensure_initialized()
if not self._initialized or not self.embeddings:
logger.error(f"Cannot cluster content - service not initialized for user {self.user_id}")
return []
@@ -358,6 +368,7 @@ class TxtaiIntelligenceService:
async def classify(self, text: str, labels: List[str]) -> List[Tuple[str, float]]:
"""Classify text using zero-shot classification."""
self._ensure_initialized()
if not self._initialized or not Labels:
logger.error(f"Cannot classify text - service not initialized or Labels not available for user {self.user_id}")
return []

View File

@@ -297,7 +297,7 @@ def _dict_to_types_schema(schema: Dict[str, Any]) -> types.Schema:
return _convert(schema)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None):
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None, user_id: str = None):
"""
Generate structured JSON response using Google's Gemini Pro model.
@@ -312,6 +312,7 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
top_k (int): Top-k sampling parameter
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
system_prompt (str, optional): System instruction for the model
user_id (str, optional): User ID for usage tracking.
Returns:
dict: Parsed JSON response matching the provided schema
@@ -468,6 +469,25 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
if response.parsed is not None:
logger.info("Using response.parsed for structured output")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import json
response_str = json.dumps(response.parsed)
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=response_str,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return response.parsed
else:
logger.warning("Response.parsed is None, falling back to text parsing")
@@ -500,6 +520,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
parsed_text = json.loads(cleaned_text)
logger.info("Successfully parsed text as JSON")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=cleaned_text,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse text as JSON: {e}")
@@ -521,6 +557,26 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
fixed_json = re.sub(r',\s*]', ']', fixed_json)
parsed_text = json.loads(fixed_json)
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
import json
response_str = json.dumps(parsed_text) if parsed_text else ""
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=response_str,
duration=0.5 # Approximation
)
logger.info(f"✅ Tracked structured JSON usage for user {user_id}")
except Exception as e:
logger.error(f"Failed to track usage: {e}")
logger.info("Successfully parsed cleaned JSON")
return parsed_text
except Exception as fix_error:
@@ -537,6 +593,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
import json
parsed_text = json.loads(part.text)
logger.info("Successfully parsed candidate text as JSON")
# Track usage if user_id is provided
if user_id:
try:
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
track_agent_usage_sync(
user_id=user_id,
model_name="gemini-2.5-flash",
prompt=prompt,
response_text=part.text,
duration=0.5
)
except Exception as e:
logger.error(f"Failed to track usage: {e}")
return parsed_text
except json.JSONDecodeError as e:
logger.error(f"Failed to parse candidate text as JSON: {e}")

View File

@@ -4,6 +4,7 @@ import io
import os
from typing import Optional
from PIL import Image
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
from services.wavespeed.client import WaveSpeedClient
@@ -14,7 +15,10 @@ logger = get_service_logger("wavespeed.image_provider")
class WaveSpeedImageProvider(ImageGenerationProvider):
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen.
Implements robust error handling and retries for production stability.
"""
SUPPORTED_MODELS = {
"ideogram-v3-turbo": {
@@ -54,6 +58,28 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
list(self.SUPPORTED_MODELS.keys()))
@retry(
    retry=retry_if_exception_type((RuntimeError, IOError)),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    stop=stop_after_attempt(3),
    reraise=True,
)
def _call_api_with_retry(self, method, **kwargs):
    """Invoke an API method under the retry policy declared on this wrapper.

    The decorator retries on RuntimeError and IOError, for at most three
    attempts with exponential back-off, and re-raises the last error.

    Args:
        method: Callable API method
        **kwargs: Keyword arguments forwarded to the method

    Returns:
        Whatever the method returns
    """
    try:
        response = method(**kwargs)
    except Exception as api_error:
        # Log every failure, then hand the exception back to the retry
        # decorator, which decides whether another attempt is made.
        logger.warning(f"WaveSpeed API call failed (retrying): {str(api_error)}")
        raise
    return response
def _validate_options(self, options: ImageGenerationOptions) -> None:
"""Validate generation options.
@@ -117,7 +143,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
# Call WaveSpeed API (using generic image generation method)
# This will need to be adjusted based on actual WaveSpeed client implementation
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
# Adjust based on actual WaveSpeed API response format
@@ -167,7 +193,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
params["seed"] = options.seed
# Call WaveSpeed API
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
if isinstance(result, bytes):
@@ -216,7 +242,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
params["seed"] = options.seed
# Call WaveSpeed API
result = self.client.generate_image(**params)
result = self._call_api_with_retry(self.client.generate_image, **params)
# Extract image bytes from result
if isinstance(result, bytes):

View File

@@ -107,11 +107,13 @@ def generate_audio(
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
try:
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import UsageSummary, APIProvider
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
@@ -194,7 +196,11 @@ def generate_audio(
if audio_bytes:
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[audio_gen] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -383,12 +389,14 @@ def clone_voice(
voice_clone_cost = 0.5
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
@@ -432,7 +440,11 @@ def clone_voice(
if preview_audio_bytes:
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[clone_voice] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -570,12 +582,14 @@ def qwen3_voice_clone(
char_count = len(text)
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
from services.database import get_db
from services.database import get_session_for_user
from services.subscription import PricingService
from models.subscription_models import APIProvider
try:
db = next(get_db())
db = get_session_for_user(user_id)
if not db:
raise RuntimeError("Failed to get database session")
try:
pricing_service = PricingService(db)
can_proceed, message, usage_info = pricing_service.check_usage_limits(
@@ -615,7 +629,11 @@ def qwen3_voice_clone(
if preview_audio_bytes:
try:
db_track = next(get_db())
db_track = get_session_for_user(user_id)
if not db_track:
logger.error(f"[qwen3_voice_clone] ❌ Failed to get database session for tracking")
raise RuntimeError("Failed to get database session")
try:
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
from services.subscription import PricingService
@@ -691,6 +709,7 @@ def qwen3_voice_clone(
├─ Provider: wavespeed
├─ Model: wavespeed-ai/qwen3-tts/voice-clone
├─ Calls: {current_calls_before}{new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
@@ -724,3 +743,373 @@ def qwen3_voice_clone(
},
)
def qwen3_voice_design(
    text: str,
    voice_description: str,
    *,
    language: str = "auto",
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Generate preview audio for ``text`` in a voice built from a description.

    The call is gated by a subscription limit check, then sent to the
    WaveSpeed client's ``voice_design`` method. Usage is recorded
    afterwards on a best-effort basis (tracking failures are logged and
    never fail the request).

    Args:
        text: Text to synthesize; must be a non-empty string.
        voice_description: Description of the desired voice; must be a
            non-empty string.
        language: Language hint forwarded to the client (default "auto").
        user_id: User ID used for the subscription check and usage tracking.

    Returns:
        VoiceCloneResult carrying the preview audio bytes and an empty
        ``custom_voice_id``.

    Raises:
        RuntimeError: If ``user_id`` is missing or the subscription check
            fails for a reason other than an exceeded limit.
        HTTPException: 429 when usage limits are exceeded; 500 for any other
            failure, including invalid ``text`` / ``voice_description``.
    """
    try:
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        if not text or not isinstance(text, str) or len(text.strip()) == 0:
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()

        if not voice_description or not isinstance(voice_description, str) or len(voice_description.strip()) == 0:
            raise ValueError("Voice description is required")
        voice_description = voice_description.strip()

        char_count = len(text)
        # Pricing logic similar to TTS/Clone
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # Pre-flight subscription check: block the API call when over limit.
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        import time
        start_time = time.time()

        client = WaveSpeedClient()
        preview_audio_bytes = client.voice_design(
            text=text,
            voice_description=voice_description,
            language=language
        )
        response_time = time.time() - start_time

        # Track usage only when audio actually came back, so an empty
        # response is not billed (same guard as the voice-clone functions).
        if preview_audio_bytes:
            try:
                db_track = get_session_for_user(user_id)
                if not db_track:
                    logger.error(f"[qwen3_voice_design] ❌ Failed to get database session for tracking")
                    raise RuntimeError("Failed to get database session")
                try:
                    from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
                    from services.subscription import PricingService
                    from sqlalchemy import text as sql_text
                    from services.subscription.provider_detection import detect_actual_provider

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()
                    if not summary:
                        summary = UsageSummary(user_id=user_id, billing_period=current_period)
                        db_track.add(summary)
                        db_track.flush()

                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + float(estimated_cost)

                    update_query = sql_text("""
UPDATE usage_summaries
SET audio_calls = :new_calls,
audio_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
                    db_track.execute(update_query, {
                        "new_calls": new_calls,
                        "new_cost": new_cost,
                        "user_id": user_id,
                        "period": current_period
                    })

                    summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    actual_provider = detect_actual_provider(
                        provider_enum=APIProvider.AUDIO,
                        model_name="wavespeed-ai/qwen3-tts/voice-design",
                        endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
                    )
                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
                        method="POST",
                        model_used="wavespeed-ai/qwen3-tts/voice-design",
                        actual_provider_name=actual_provider,
                        tokens_input=char_count,
                        tokens_output=0,
                        tokens_total=char_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=float(estimated_cost),
                        response_time=response_time,
                        status_code=200,
                        # Sizes are measured in bytes, as in cosyvoice_voice_clone.
                        request_size=len(text.encode("utf-8")) + len(voice_description.encode("utf-8")),
                        response_size=len(preview_audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)
                    db_track.commit()

                    # flush=True already flushes stdout; no extra flush needed.
                    print(f"""
[SUBSCRIPTION] Qwen3 Voice Design
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: wavespeed-ai/qwen3-tts/voice-design
├─ Calls: {current_calls_before} → {new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
                except Exception as track_error:
                    logger.error(f"[qwen3_voice_design] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[qwen3_voice_design] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model="wavespeed-ai/qwen3-tts/voice-design",
            custom_voice_id="",  # No persistent ID for design usually, unless we save it
            file_size=len(preview_audio_bytes),
        )
    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[qwen3_voice_design] Error designing voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Qwen3 voice design failed",
                "message": str(e),
            },
        )
def cosyvoice_voice_clone(
    audio_bytes: bytes,
    text: str,
    *,
    reference_text: Optional[str] = None,
    audio_mime_type: Optional[str] = None,
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Synthesize ``text`` in the voice of a reference audio sample (CosyVoice).

    The call is gated by a subscription limit check, then sent to the
    WaveSpeed client's ``cosyvoice_voice_clone`` method. Usage is recorded
    afterwards on a best-effort basis (tracking failures are logged and
    never fail the request).

    Args:
        audio_bytes: Reference audio; non-empty, at most 15MB.
        text: Text to synthesize; non-empty, at most 4000 characters once
            stripped.
        reference_text: Optional reference text forwarded to the client.
        audio_mime_type: MIME type of the audio; defaults to "audio/wav".
        user_id: User ID used for the subscription check and usage tracking.

    Returns:
        VoiceCloneResult carrying the preview audio bytes and an empty
        ``custom_voice_id``.

    Raises:
        RuntimeError: If ``user_id`` is missing or the subscription check
            fails for a reason other than an exceeded limit.
        HTTPException: 429 when usage limits are exceeded; 500 for any other
            failure, including invalid ``audio_bytes`` / ``text``.
    """
    try:
        if not user_id:
            raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

        if not audio_bytes or not isinstance(audio_bytes, (bytes, bytearray)) or len(audio_bytes) == 0:
            raise ValueError("Audio is required and cannot be empty")
        if len(audio_bytes) > 15 * 1024 * 1024:
            raise ValueError("Audio file too large. Maximum is 15MB.")

        if not text or not isinstance(text, str) or len(text.strip()) == 0:
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()
        if len(text) > 4000:
            raise ValueError("Text too long. Please keep it under 4000 characters.")

        char_count = len(text)
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # Pre-flight subscription check: block the API call when over limit.
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        import time
        start_time = time.time()

        client = WaveSpeedClient()
        preview_audio_bytes = client.cosyvoice_voice_clone(
            audio_bytes=bytes(audio_bytes),
            text=text,
            audio_mime_type=audio_mime_type or "audio/wav",
            reference_text=reference_text,
        )
        response_time = time.time() - start_time

        # Track usage only when audio actually came back.
        if preview_audio_bytes:
            try:
                db_track = get_session_for_user(user_id)
                if not db_track:
                    logger.error(f"[cosyvoice_voice_clone] ❌ Failed to get database session for tracking")
                    raise RuntimeError("Failed to get database session")
                try:
                    from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
                    from services.subscription import PricingService
                    from sqlalchemy import text as sql_text
                    from services.subscription.provider_detection import detect_actual_provider

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period
                    ).first()
                    if not summary:
                        summary = UsageSummary(user_id=user_id, billing_period=current_period)
                        db_track.add(summary)
                        db_track.flush()

                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + float(estimated_cost)

                    update_query = sql_text("""
UPDATE usage_summaries
SET audio_calls = :new_calls,
audio_cost = :new_cost
WHERE user_id = :user_id AND billing_period = :period
""")
                    db_track.execute(update_query, {
                        "new_calls": new_calls,
                        "new_cost": new_cost,
                        "user_id": user_id,
                        "period": current_period
                    })

                    summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    actual_provider = detect_actual_provider(
                        provider_enum=APIProvider.AUDIO,
                        model_name="wavespeed-ai/cosyvoice-tts/voice-clone",
                        endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
                    )
                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
                        method="POST",
                        model_used="wavespeed-ai/cosyvoice-tts/voice-clone",
                        actual_provider_name=actual_provider,
                        tokens_input=char_count,
                        tokens_output=0,
                        tokens_total=char_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=float(estimated_cost),
                        response_time=response_time,
                        status_code=200,
                        request_size=len(audio_bytes) + len(text.encode("utf-8")),
                        response_size=len(preview_audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)
                    db_track.commit()

                    # Same fields as the other voice functions' subscription
                    # log, including the cost line. flush=True already flushes
                    # stdout; no extra flush needed.
                    print(f"""
[SUBSCRIPTION] CosyVoice Voice Clone
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: wavespeed-ai/cosyvoice-tts/voice-clone
├─ Calls: {current_calls_before} → {new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
                except Exception as track_error:
                    logger.error(f"[cosyvoice_voice_clone] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[cosyvoice_voice_clone] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model="wavespeed-ai/cosyvoice-tts/voice-clone",
            custom_voice_id="",
            file_size=len(preview_audio_bytes),
        )
    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[cosyvoice_voice_clone] Error cloning voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "CosyVoice voice cloning failed",
                "message": str(e),
            },
        )

View File

@@ -2,6 +2,8 @@ from __future__ import annotations
import os
import io
import base64
import logging
from typing import Optional, Dict, Any
from PIL import Image
@@ -9,6 +11,9 @@ from .image_generation import (
ImageGenerationOptions,
ImageGenerationResult,
)
from .image_generation.base import ImageEditOptions
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
from utils.logger_utils import get_service_logger
try:
@@ -22,21 +27,36 @@ logger = get_service_logger("image_editing.facade")
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
"HF_IMAGE_EDIT_MODEL",
"Qwen/Qwen-Image-Edit",
"WAVESPEED_IMAGE_EDIT_MODEL",
"qwen-edit-plus",
)
def _select_provider(explicit: Optional[str]) -> str:
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
"""
Select the appropriate image editing provider.
Priority:
1. Explicitly requested provider
2. WaveSpeed (if API key available) - Preferred for quality/speed
3. Hugging Face (fallback)
"""
if explicit:
return explicit
# Default to huggingface for image editing (best support for image-to-image)
return explicit.lower()
# Check for WaveSpeed API key first (Preferred provider)
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
# Default to huggingface if WaveSpeed not available
return "huggingface"
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
"""Get InferenceClient for the specified provider."""
"""Get the client for the specified provider."""
if provider_name == "wavespeed":
return WaveSpeedEditProvider(api_key=api_key)
if not HF_HUB_AVAILABLE:
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
@@ -44,7 +64,7 @@ def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
api_key = api_key or os.getenv("HF_TOKEN")
if not api_key:
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
# Use fal-ai provider for fast inference
# Use fal-ai provider for fast inference via HF Inference API
return InferenceClient(provider="fal-ai", api_key=api_key)
raise ValueError(f"Unknown image editing provider: {provider_name}")
@@ -86,6 +106,8 @@ def edit_image(
from fastapi import HTTPException
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
# Note: get_db() is a generator, so we need to use next() to get the session
# and ensure we close it in the finally block
db = next(get_db())
try:
pricing_service = PricingService(db)
@@ -99,6 +121,9 @@ def edit_image(
# Re-raise immediately - don't proceed with API call
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
raise
except Exception as e:
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
finally:
db.close()
else:
@@ -119,6 +144,69 @@ def edit_image(
# Get provider client
client = _get_provider_client(provider_name, opts.get("api_key"))
if provider_name == "wavespeed":
# Handle WaveSpeed provider
try:
# Convert inputs to base64 for WaveSpeed
image_b64 = base64.b64encode(input_image_bytes).decode('utf-8')
mask_b64 = None
if mask_bytes:
mask_b64 = base64.b64encode(mask_bytes).decode('utf-8')
# Determine operation type based on prompt/mask
operation = "general_edit" # Default
if not prompt and mask_b64:
operation = "remove_bg" # Heuristic: mask but no prompt implies removal/in-painting
elif prompt and not mask_b64:
operation = "style_transfer" # Heuristic: prompt but no mask implies style transfer
elif opts.get("operation"):
operation = opts.get("operation")
edit_options = ImageEditOptions(
image_base64=image_b64,
prompt=prompt.strip(),
operation=operation,
mask_base64=mask_b64,
model=model,
guidance_scale=opts.get("guidance_scale"),
steps=opts.get("steps"),
seed=opts.get("seed"),
extra=opts
)
logger.info(f"[Image Editing] Calling WaveSpeed edit with model={model}")
result = client.edit(edit_options)
# TRACK USAGE after successful WaveSpeed call
if user_id:
try:
from services.llm_providers.main_image_generation import _track_image_operation_usage
# Estimate cost (WaveSpeed default: $0.02)
estimated_cost = result.metadata.get("estimated_cost", 0.02) if result.metadata else 0.02
_track_image_operation_usage(
user_id=user_id,
provider="wavespeed",
model=result.model or model,
operation_type="image-editing",
result_bytes=result.image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-editing",
metadata=result.metadata,
log_prefix="[Image Editing]"
)
except Exception as track_error:
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
return result
except Exception as e:
logger.error(f"[Image Editing] ❌ WaveSpeed editing failed: {e}", exc_info=True)
raise RuntimeError(f"WaveSpeed editing failed: {str(e)}")
# Hugging Face (Fallback)
# Prepare parameters for image-to-image
params: Dict[str, Any] = {}
if opts.get("guidance_scale") is not None:
@@ -170,6 +258,29 @@ def edit_image(
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
# TRACK USAGE after successful HF call
if user_id:
try:
from services.llm_providers.main_image_generation import _track_image_operation_usage
# Estimate cost (HF/Fal-ai default: $0.05)
estimated_cost = 0.05
_track_image_operation_usage(
user_id=user_id,
provider="huggingface",
model=model,
operation_type="image-editing",
result_bytes=edited_image_bytes,
cost=estimated_cost,
prompt=prompt,
endpoint="/image-editing",
metadata={"provider": "fal-ai"},
log_prefix="[Image Editing]"
)
except Exception as track_error:
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
return ImageGenerationResult(
image_bytes=edited_image_bytes,
width=edited_image.width,

View File

@@ -5,6 +5,7 @@ import sys
import base64
from datetime import datetime
from typing import Optional, Dict, Any
from fastapi import HTTPException
from fastapi.concurrency import run_in_threadpool
from .image_generation import (
@@ -29,6 +30,11 @@ logger = get_service_logger("image_generation.facade")
def _select_provider(explicit: Optional[str]) -> str:
if explicit:
return explicit
# User requested WaveSpeed as default provider
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
if gpt_provider.startswith("gemini"):
return "gemini"
@@ -36,8 +42,7 @@ def _select_provider(explicit: Optional[str]) -> str:
return "huggingface"
if os.getenv("STABILITY_API_KEY"):
return "stability"
if os.getenv("WAVESPEED_API_KEY"):
return "wavespeed"
# Fallback to huggingface to enable a path if configured
return "huggingface"
@@ -739,18 +744,139 @@ async def generate_image_with_provider(
}
except Exception as e:
logger.error(f"Error in generate_image_with_provider: {e}")
# Propagate specific error message if available
error_detail = str(e)
if "402" in error_detail or "Payment Required" in error_detail:
raise HTTPException(status_code=402, detail=f"Payment Required: {error_detail}")
return {
"success": False,
"error": str(e)
"error": error_detail
}
import time
from services.database import get_session_for_user
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
def _build_prompt_enhancement_context(user_id: str) -> str:
    """
    Collect onboarding context used to personalise prompt enhancement.

    Reads the user's onboarding session, its Step 2 website analysis and up to
    three Step 3 competitor analyses, and renders them as plain-text
    instructions for the prompt optimizer.

    Best-effort: any exception raised while reading the database is logged as
    a warning, and whatever context was gathered before the failure is
    returned.

    Args:
        user_id: User whose onboarding data is looked up.

    Returns:
        Context text to append to the optimizer input, or "" when nothing was
        found.
    """
    context_instruction = ""
    try:
        db_session = get_session_for_user(user_id)
        try:
            # Get Onboarding Session
            session = db_session.query(OnboardingSession).filter(
                OnboardingSession.user_id == user_id
            ).first()
            if not session:
                return context_instruction

            # Step 2: Website Analysis
            website_analysis = db_session.query(WebsiteAnalysis).filter(
                WebsiteAnalysis.session_id == session.id
            ).first()
            if website_analysis:
                brand_voice = website_analysis.brand_analysis
                style = website_analysis.style_guidelines
                target_audience = website_analysis.target_audience

                context_instruction += "\n\nCONTEXT FROM WEBSITE ANALYSIS:\n"
                if target_audience:
                    context_instruction += f"Target Audience: {target_audience}\n"
                # Only dict-shaped values are rendered; other types are skipped.
                if brand_voice and isinstance(brand_voice, dict):
                    context_instruction += f"Brand Voice: {brand_voice.get('voice_characteristics', '')} - {brand_voice.get('tone', '')}\n"
                if style and isinstance(style, dict):
                    context_instruction += f"Visual Style: {style.get('visual_style', '')} - {style.get('color_palette', '')}\n"

            # Step 3: Competitor Analysis (limit to the first 3 rows)
            competitors = db_session.query(CompetitorAnalysis).filter(
                CompetitorAnalysis.session_id == session.id
            ).limit(3).all()
            if competitors:
                context_instruction += "\nCOMPETITOR VISUAL INSIGHTS:\n"
                for comp in competitors:
                    if comp.analysis_data and isinstance(comp.analysis_data, dict):
                        comp_title = comp.analysis_data.get('title', 'Competitor')
                        highlights = comp.analysis_data.get('highlights', [])
                        if highlights:
                            context_instruction += f"- {comp_title}: {', '.join(highlights[:2])}\n"
        finally:
            db_session.close()
    except Exception as db_ex:
        logger.warning(f"Failed to fetch context for prompt enhancement: {db_ex}")
    return context_instruction


async def enhance_image_prompt(prompt: str, user_id: Optional[str] = None) -> str:
    """
    Enhance image prompt using WaveSpeed's specialized prompt optimizer.

    Restructures and enriches prompts for visual clarity and cinematic detail.
    Uses Step 2 (Website Analysis) and Step 3 (Competitor Analysis) context if
    available.

    Args:
        prompt: The user's original image prompt.
        user_id: Optional user ID. When given, the operation is validated
            before the call, onboarding context is added to the optimizer
            input, and usage is tracked afterwards.

    Returns:
        The optimized prompt. The original ``prompt`` is returned unchanged if
        validation, client import, or the optimizer call raises, or if the
        optimizer returns an empty result. A usage-tracking failure never
        discards an optimized prompt.
    """
    start_time = time.time()
    try:
        from services.wavespeed.client import WaveSpeedClient

        # 1. Pre-flight validation. This runs inside the try block, so any
        # exception it raises also ends in the "return original prompt" fallback.
        if user_id:
            _validate_image_operation(
                user_id=user_id,
                operation_type="prompt-enhancement",
                num_operations=1,
                log_prefix="[Prompt Enhancement]"
            )

        # 2. Fetch context from Step 2 & 3 (only possible for a known user)
        context_instruction = _build_prompt_enhancement_context(user_id) if user_id else ""

        # Combine prompt with context
        full_input_text = prompt
        if context_instruction:
            logger.info(f"Enhancing prompt for user {user_id} with Step 2/3 context")
            # The context is appended as an instruction for the optimizer.
            full_input_text = (
                f"Original Request: {prompt}\n\n{context_instruction}\n\n"
                "Task: Generate a hyper-personalized, detailed image generation prompt "
                "based on the Original Request and the provided Context. "
                "Ensure the visual style aligns with the Brand Voice and Visual Style."
            )
        else:
            logger.info(f"Enhancing prompt for user {user_id} (no context found)")

        # 3. Call WaveSpeed ('image' mode, 'photographic' style).
        # The call is made in sync mode with a 30s timeout, so it is offloaded
        # to the threadpool instead of blocking the event loop.
        client = WaveSpeedClient()
        optimized_prompt = await run_in_threadpool(
            client.optimize_prompt,
            text=full_input_text,
            mode="image",
            style="photographic",
            enable_sync_mode=True,
            timeout=30
        )
    except Exception as e:
        logger.error(f"Failed to enhance prompt via WaveSpeed: {e}")
        # Fallback to original prompt on failure
        return prompt

    if not optimized_prompt:
        # An empty/None result is useless to callers; keep the original prompt.
        logger.warning("[Prompt Enhancement] Optimizer returned an empty result; using original prompt")
        return prompt

    # 4. Track usage. Kept outside the main try block: a tracking failure must
    # not throw away a prompt that was optimized successfully.
    if user_id:
        try:
            duration = time.time() - start_time
            # Tracked at zero cost; the operation itself is still recorded.
            _track_image_operation_usage(
                user_id=user_id,
                provider="wavespeed",
                model="wavespeed-prompt-opt",
                operation_type="prompt-enhancement",
                result_bytes=b"",  # No image
                cost=0.0,
                prompt=prompt,
                endpoint="/enhance-prompt",
                metadata={"duration": duration, "context_added": bool(context_instruction)},
                log_prefix="[Prompt Enhancement]",
                response_time=duration
            )
        except Exception as track_error:
            logger.warning(f"[Prompt Enhancement] ⚠️ Failed to track usage: {track_error}")

    return optimized_prompt
async def generate_image_variation(
@@ -760,13 +886,123 @@ async def generate_image_variation(
**kwargs
) -> Dict[str, Any]:
"""
Generate variation of an existing image.
Placeholder implementation.
Generate variation of an existing image using image-to-image editing.
Wrapper for step4_asset_routes.
"""
return {
"success": False,
"error": "Not implemented yet"
}
try:
# Handle image input (bytes, file, or base64)
image_bytes = None
if isinstance(image, bytes):
image_bytes = image
elif hasattr(image, "read"):
image_bytes = await image.read()
elif isinstance(image, str):
# Assume base64 or path
if os.path.exists(image):
with open(image, "rb") as f:
image_bytes = f.read()
else:
# Try base64 decode
try:
if "base64," in image:
image = image.split("base64,")[1]
image_bytes = base64.b64decode(image)
except:
pass
if not image_bytes:
return {"success": False, "error": "Invalid image input"}
# Convert to base64 for internal function
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
# Use generate_image_edit with "variation" intent
# For variation, we typically use general_edit with specific prompt
result = await run_in_threadpool(
generate_image_edit,
image_base64=image_base64,
prompt=prompt,
operation="general_edit",
model=kwargs.get("model", "qwen-edit-plus"), # Default to capable model
options=kwargs,
user_id=user_id
)
result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
return {
"success": True,
"image_base64": result_base64,
"metadata": result.metadata
}
except Exception as e:
logger.error(f"Error in generate_image_variation: {e}")
return {
"success": False,
"error": str(e)
}
async def generate_image_enhance(
    image: Any,
    user_id: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Enhance/Upscale an existing image.
    Wrapper for step4_asset_routes.

    Args:
        image: Source image as raw ``bytes``, an object with a ``read()``
            method (synchronous or awaitable), or a ``str`` holding either a
            path to an existing file or base64 data (optionally prefixed with
            a ``...base64,`` data-URL header).
        user_id: Optional user ID forwarded to ``generate_image_edit``.
        **kwargs: Extra options forwarded as ``options``; ``resolution`` is
            always overridden to ``"4k"``.

    Returns:
        ``{"success": True, "image_base64": ..., "metadata": ...}`` on success,
        otherwise ``{"success": False, "error": <message>}``. Exceptions are
        logged and reported in the dict rather than raised.
    """
    import inspect

    try:
        # Handle image input
        image_bytes = None
        if isinstance(image, bytes):
            image_bytes = image
        elif hasattr(image, "read"):
            # read() may be a coroutine (e.g. an upload object) or a plain
            # synchronous call (e.g. BytesIO); only await when necessary.
            payload = image.read()
            if inspect.isawaitable(payload):
                payload = await payload
            image_bytes = payload
        elif isinstance(image, str):
            # NOTE(review): a string naming an existing path is read from local
            # disk; confirm callers never pass untrusted strings here.
            if os.path.exists(image):
                with open(image, "rb") as f:
                    image_bytes = f.read()
            else:
                # Strip a data-URL header if present, then try base64.
                encoded = image.split("base64,", 1)[1] if "base64," in image else image
                try:
                    image_bytes = base64.b64decode(encoded)
                except ValueError:
                    # binascii.Error is a ValueError: not decodable base64,
                    # reported below as invalid input.
                    image_bytes = None

        if not image_bytes:
            return {"success": False, "error": "Invalid image input"}

        image_base64 = base64.b64encode(image_bytes).decode('utf-8')

        # Use generate_image_edit with "enhance" intent and a high-res model.
        result = await run_in_threadpool(
            generate_image_edit,
            image_base64=image_base64,
            prompt="enhance details, high resolution, professional quality, 4k, sharp focus",
            operation="general_edit",
            model="nano-banana-pro-edit-ultra",
            options={**kwargs, "resolution": "4k"},
            user_id=user_id
        )

        result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
        return {
            "success": True,
            "image_base64": result_base64,
            "metadata": result.metadata
        }
    except Exception as e:
        logger.error(f"Error in generate_image_enhance: {e}")
        return {
            "success": False,
            "error": str(e)
        }

View File

@@ -260,335 +260,23 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if response_text:
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = get_session_for_user(user_id)
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
tokens_total = tokens_input + tokens_output
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
# This ensures we always get the absolute latest committed values, even across different sessions
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
record_count = 0 # Initialize to ensure it's always defined
# CRITICAL: First check if record exists using COUNT query
try:
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
except Exception as count_error:
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
record_count = 0
if record_count and record_count > 0:
# Record exists - read current values with raw SQL
try:
# Validate provider_name to prevent SQL injection (whitelist approach)
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
raw_calls = result[0] if result[0] is not None else 0
raw_tokens = result[1] if result[1] is not None else 0
current_calls_before = raw_calls
current_tokens_before = raw_tokens
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
# Fallback: Use ORM to get values
summary_fallback = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary_fallback:
db_track.refresh(summary_fallback)
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
else:
logger.debug(f"[llm_text_gen] No record exists yet (will create new) - user={user_id}, period={current_period}")
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
# New record - values are already 0, no need to set
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
# Using raw SQL UPDATE ensures the change is persisted
from sqlalchemy import text
update_calls_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_calls = :new_calls
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_calls_query, {
'new_calls': new_calls,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
# Update token usage for LLM providers with safety check
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
# The ORM object may have stale values, but raw SQL always has the latest committed values
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
update_tokens_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_tokens = :new_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_tokens_query, {
'new_tokens': new_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens (after any safety capping)
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this call
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
cost_total = 0.0
logger.error(f"[llm_text_gen] ❌ Failed to calculate API cost: {cost_error}", exc_info=True)
if cost_total > 0:
logger.debug(f"[llm_text_gen] 💰 Calculated cost for {provider_name}: ${cost_total:.6f}")
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
# Keep ORM object in sync for logging/debugging
current_provider_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
setattr(summary, f"{provider_name}_cost", current_provider_cost + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
else:
logger.debug(f"[llm_text_gen] 💰 Cost calculation returned $0 for {provider_name} (tokens_input={tracked_tokens_input}, tokens_output={tracked_tokens_output})")
# Update totals using SQL UPDATE
old_total_calls = summary.total_calls or 0
old_total_tokens = summary.total_tokens or 0
new_total_calls = old_total_calls + 1
new_total_tokens = old_total_tokens + tokens_total
update_totals_query = text("""
UPDATE usage_summaries
SET total_calls = :total_calls, total_tokens = :total_tokens
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_totals_query, {
'total_calls': new_total_calls,
'total_tokens': new_total_tokens,
'user_id': user_id,
'period': current_period
})
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
# Get plan details for unified log
limits = pricing.get_user_limits(user_id)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
import sys
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
print(debug_msg, flush=True)
sys.stdout.flush()
logger.debug(f"[llm_text_gen] {debug_msg}")
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
try:
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result:
verified_calls = verify_result[0] if verify_result[0] is not None else 0
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
if verified_calls != new_calls or verified_tokens != new_tokens:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
# Force another commit attempt
db_track.commit()
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
if verify_result2:
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
else:
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
else:
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
except Exception as verify_error:
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
# DEBUG: Log the actual values being used
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
print(f"""
[SUBSCRIPTION] LLM Text Generation
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
# Ideally we should track start_time at beginning of function
duration = 0.5
track_agent_usage_sync(
user_id=user_id,
model_name=model,
prompt=prompt,
response_text=response_text,
duration=duration
)
except Exception as usage_error:
# Non-blocking: log error but don't fail the request
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
@@ -661,208 +349,18 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
if response_text:
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
try:
db_track = get_session_for_user(user_id)
try:
# Estimate tokens from prompt and response
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
tokens_input = int(len(prompt.split()) * 1.3)
tokens_output = int(len(str(response_text).split()) * 1.3)
tokens_total = tokens_input + tokens_output
# Get or create usage summary
from models.subscription_models import UsageSummary
from services.subscription import PricingService
pricing = PricingService(db_track)
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
provider_name = provider_enum.value
limits = pricing.get_user_limits(user_id)
token_limit = 0
if limits and limits.get('limits'):
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
from sqlalchemy import text
current_calls_before = 0
current_tokens_before = 0
try:
# Validate provider_name to prevent SQL injection
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
if provider_name not in valid_providers:
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
# Read current values directly from database using raw SQL
sql_query = text(f"""
SELECT {provider_name}_calls, {provider_name}_tokens
FROM usage_summaries
WHERE user_id = :user_id AND billing_period = :period
LIMIT 1
""")
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
if result:
current_calls_before = result[0] if result[0] is not None else 0
current_tokens_before = result[1] if result[1] is not None else 0
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
except Exception as sql_error:
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
# Fallback to ORM query if raw SQL fails
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if summary:
db_track.refresh(summary)
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
# Get or create usage summary object (needed for ORM update)
summary = db_track.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period == current_period
).first()
if not summary:
summary = UsageSummary(
user_id=user_id,
billing_period=current_period
)
db_track.add(summary)
db_track.flush() # Ensure summary is persisted before updating
else:
# CRITICAL: Update the ORM object with values from raw SQL query
# This ensures the ORM object reflects the actual database state before we increment
setattr(summary, f"{provider_name}_calls", current_calls_before)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
# Get "before" state for unified log (from raw SQL query)
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
# Update provider-specific counters (sync operation)
new_calls = current_calls_before + 1
setattr(summary, f"{provider_name}_calls", new_calls)
# Update token usage for LLM providers with safety check
# Use current_tokens_before from raw SQL query (most reliable)
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
projected_new_tokens = current_tokens_before + tokens_total
# If limit is set (> 0) and would be exceeded, cap at limit
if token_limit > 0 and projected_new_tokens > token_limit:
logger.warning(
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
f"Capping tracked tokens at limit to prevent abuse."
)
# Cap at limit to prevent abuse
new_tokens = token_limit
# Adjust tokens_total for accurate total tracking
tokens_total = token_limit - current_tokens_before
if tokens_total < 0:
tokens_total = 0
else:
new_tokens = projected_new_tokens
setattr(summary, f"{provider_name}_tokens", new_tokens)
else:
current_tokens_before = 0
new_tokens = 0
# Determine tracked tokens after any safety capping
tracked_tokens_input = min(tokens_input, tokens_total)
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
# Calculate and persist cost for this fallback call
cost_total = 0.0
try:
cost_info = pricing.calculate_api_cost(
provider=provider_enum,
model_name=fallback_model,
tokens_input=tracked_tokens_input,
tokens_output=tracked_tokens_output,
request_count=1
)
cost_total = cost_info.get('cost_total', 0.0) or 0.0
except Exception as cost_error:
logger.error(f"[llm_text_gen] ❌ Failed to calculate fallback cost: {cost_error}", exc_info=True)
if cost_total > 0:
update_costs_query = text(f"""
UPDATE usage_summaries
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
total_cost = COALESCE(total_cost, 0) + :cost
WHERE user_id = :user_id AND billing_period = :period
""")
db_track.execute(update_costs_query, {
'cost': cost_total,
'user_id': user_id,
'period': current_period
})
setattr(summary, f"{provider_name}_cost", (getattr(summary, f"{provider_name}_cost", 0.0) or 0.0) + cost_total)
summary.total_cost = (summary.total_cost or 0.0) + cost_total
# Update totals (using potentially capped tokens_total from safety check)
summary.total_calls = (summary.total_calls or 0) + 1
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
# Get plan details for unified log (limits already retrieved above)
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
tier = limits.get('tier', 'unknown') if limits else 'unknown'
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
# Get image stats for unified log
current_images_before = getattr(summary, "stability_calls", 0) or 0
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
# Get image editing stats for unified log
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
# Get video stats for unified log
current_video_calls = getattr(summary, "video_calls", 0) or 0
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
# Get audio stats for unified log
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else ''
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
db_track.commit() # Commit transaction to make changes visible to other sessions
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
# UNIFIED SUBSCRIPTION LOG for fallback
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
# Include image stats in the log
print(f"""
[SUBSCRIPTION] LLM Text Generation (Fallback)
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {actual_provider_name}
├─ Model: {fallback_model}
├─ Calls: {current_calls_before}{new_calls} / {call_limit if call_limit > 0 else ''}
├─ Tokens: {current_tokens_before}{new_tokens} / {token_limit if token_limit > 0 else ''}
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else ''}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else ''}
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else ''}
├─ Audio: {current_audio_calls} / {audio_limit_display}
└─ Status: ✅ Allowed & Tracked
""")
except Exception as track_error:
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
db_track.rollback()
finally:
db_track.close()
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
# Estimate tokens
tokens_input = int(len(prompt.split()) * 1.3)
track_agent_usage_sync(
user_id=user_id,
model_name=fallback_model,
prompt=prompt,
response_text=response_text,
duration=0.5 # Approximate duration
)
except Exception as usage_error:
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)

View File

@@ -36,6 +36,172 @@ class VideoProviderNotImplemented(Exception):
pass
def _track_video_operation_usage(
    user_id: str,
    provider: str,
    model: str,
    operation_type: str,
    result_bytes: bytes,
    cost: float,
    prompt: Optional[str] = None,
    endpoint: str = "/video-generation",
    metadata: Optional[Dict[str, Any]] = None,
    log_prefix: str = "[Video Generation]",
    response_time: float = 0.0
) -> Dict[str, Any]:
    """
    Reusable usage tracking helper for all video operations.

    Increments the user's video call/cost counters for the current billing
    period, bumps the summary totals, writes an APIUsageLog row and prints the
    unified subscription log. Tracking is best-effort: any failure is logged
    and an empty dict is returned instead of raising.

    Args:
        user_id: User ID for tracking
        provider: Provider name (stored as ``actual_provider_name``)
        model: Model name used ("unknown" is stored when falsy)
        operation_type: Type of operation (for logging)
        result_bytes: Generated video bytes (only its length is recorded)
        cost: Cost of the operation
        prompt: Optional prompt text (only its UTF-8 size is recorded)
        endpoint: API endpoint path
        metadata: Optional additional metadata. Accepted for interface
            compatibility; it is not persisted by this helper.
        log_prefix: Logging prefix
        response_time: API response time

    Returns:
        ``{"current_calls", "cost", "total_cost"}`` on success, where
        ``total_cost`` is the accumulated video cost for the period;
        ``{}`` when tracking failed.
    """
    try:
        from services.database import get_session_for_user
        db_track = get_session_for_user(user_id)
        try:
            from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
            from services.subscription import PricingService
            from sqlalchemy import text as sql_text

            pricing = PricingService(db_track)
            current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

            # Get or create usage summary
            summary = db_track.query(UsageSummary).filter(
                UsageSummary.user_id == user_id,
                UsageSummary.billing_period == current_period
            ).first()
            if not summary:
                summary = UsageSummary(
                    user_id=user_id,
                    billing_period=current_period
                )
                db_track.add(summary)
                # Flush so the row exists before the raw UPDATE below targets it
                db_track.flush()

            # Get current values before update
            current_calls_before = getattr(summary, "video_calls", 0) or 0
            current_cost_before = getattr(summary, "video_cost", 0.0) or 0.0

            # Update video calls and cost
            new_calls = current_calls_before + 1
            new_cost = current_cost_before + cost

            # Use direct SQL UPDATE for dynamic attributes
            update_query = sql_text("""
                UPDATE usage_summaries
                SET video_calls = :new_calls,
                    video_cost = :new_cost
                WHERE user_id = :user_id AND billing_period = :period
            """)
            db_track.execute(update_query, {
                'new_calls': new_calls,
                'new_cost': new_cost,
                'user_id': user_id,
                'period': current_period
            })

            # Update totals
            summary.total_cost = (summary.total_cost or 0.0) + cost
            summary.total_calls = (summary.total_calls or 0) + 1
            summary.updated_at = datetime.utcnow()

            # Create usage log
            request_size = len(prompt.encode("utf-8")) if prompt else 0
            usage_log = APIUsageLog(
                user_id=user_id,
                provider=APIProvider.WAVESPEED,  # Default for video
                endpoint=endpoint,
                method="POST",
                model_used=model or "unknown",
                actual_provider_name=provider,
                tokens_input=0,
                tokens_output=0,
                tokens_total=0,
                cost_input=0.0,
                cost_output=0.0,
                cost_total=cost,
                response_time=response_time,
                status_code=200,
                request_size=request_size,
                response_size=len(result_bytes) if result_bytes else 0,
                billing_period=current_period,
            )
            db_track.add(usage_log)

            # Related counters for the unified log, read before commit
            current_audio_calls = getattr(summary, "audio_calls", 0) or 0
            current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0

            # Persist the usage first: everything below is display-only and
            # must never be able to roll back tracking that already succeeded.
            db_track.commit()
            logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")

            # Get plan details for unified log (best-effort)
            plan_name = 'unknown'
            tier = 'unknown'
            limit_values: Dict[str, Any] = {}
            try:
                limits = pricing.get_user_limits(user_id)
                if limits:
                    plan_name = limits.get('plan_name', 'unknown')
                    tier = limits.get('tier', 'unknown')
                    limit_values = limits.get('limits') or {}
            except Exception as limits_error:
                logger.warning(f"{log_prefix} Could not load plan limits for usage log: {limits_error}")

            def _limit_display(limit_key: str) -> Any:
                """Return the numeric limit, or ∞ for an unlimited (<= 0) enterprise limit."""
                limit = limit_values.get(limit_key, 0) or 0
                return '∞' if (limit <= 0 and tier == 'enterprise') else limit

            video_limit_display = _limit_display("video_calls")
            audio_limit_display = _limit_display("audio_calls")
            image_edit_limit_display = _limit_display("image_edit_calls")

            # UNIFIED SUBSCRIPTION LOG
            operation_name = operation_type.replace("-", " ").title()
            # flush=True already flushes stdout; no separate sys.stdout.flush() needed
            print(f"""
[SUBSCRIPTION] {operation_name}
├─ User: {user_id}
├─ Plan: {plan_name} ({tier})
├─ Provider: {provider}
├─ Actual Provider: {provider}
├─ Model: {model or 'unknown'}
├─ Calls: {current_calls_before} → {new_calls} / {video_limit_display}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Audio: {current_audio_calls} / {audio_limit_display}
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit_display}
└─ Status: ✅ Allowed & Tracked
""", flush=True)

            return {
                "current_calls": new_calls,
                "cost": cost,
                "total_cost": new_cost,
            }
        except Exception as track_error:
            # exc_info=True already records the full traceback
            logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
            db_track.rollback()
            return {}
        finally:
            db_track.close()
    except Exception as usage_error:
        logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
        return {}
def _get_api_key(provider: str) -> Optional[str]:
try:
manager = APIKeyManager()
@@ -500,156 +666,74 @@ async def ai_video_generate(
raise
finally:
db.close()
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")
# Progress callback: Initial submission
if progress_callback:
progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")
# Generate video based on operation type
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
# Track response time for video generation
# Track response time
import time
from datetime import datetime
start_time = time.time()
# Execute operation based on type
result = {}
try:
if operation_type == "text-to-video":
if provider == "huggingface":
video_bytes = _generate_with_huggingface(
prompt=prompt,
**kwargs,
)
# For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
result_dict = {
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
result = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10, # Default cost, will be calculated in track_video_usage
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280, # Default, actual may vary
"height": 720, # Default, actual may vary
"metadata": {},
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
"provider": "huggingface",
"cost": 0.0, # HuggingFace inference is free/low cost
}
elif provider == "wavespeed":
# WaveSpeed text-to-video - use unified service
result_dict = await _generate_text_to_video_wavespeed(
result = await _generate_text_to_video_wavespeed(
prompt=prompt,
progress_callback=progress_callback,
**kwargs,
**kwargs
)
elif provider == "gemini":
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
result = {"video_bytes": _generate_with_gemini(prompt, **kwargs)}
elif provider == "openai":
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
result_dict = {
"video_bytes": video_bytes,
"prompt": prompt,
"duration": kwargs.get("duration", 5.0),
"model_name": model_name,
"cost": 0.10,
"provider": provider,
"resolution": kwargs.get("resolution", "720p"),
"width": 1280,
"height": 720,
"metadata": {},
}
result = {"video_bytes": _generate_with_openai(prompt, **kwargs)}
else:
raise RuntimeError(f"Unknown provider for text-to-video: {provider}")
raise ValueError(f"Unknown provider for text-to-video: {provider}")
elif operation_type == "image-to-video":
if provider == "wavespeed":
# Progress callback: Starting generation
if progress_callback:
progress_callback(20.0, "Video generation in progress...")
# Handle async call from sync context
# Since ai_video_generate is sync, we need to run async function
try:
loop = asyncio.get_event_loop()
if loop.is_running():
# We're in an async context - use ThreadPoolExecutor to run in new event loop
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(
asyncio.run,
_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
)
)
result_dict = future.result()
else:
# Event loop exists but not running - use it
result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
except RuntimeError:
# No event loop exists, create a new one
result_dict = asyncio.run(_generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or kwargs.get("prompt", ""),
progress_callback=progress_callback,
**kwargs
))
video_bytes = result_dict["video_bytes"]
model_name = result_dict.get("model_name", model_name)
# Progress callback: Processing result
if progress_callback:
progress_callback(90.0, "Processing video result...")
result = await _generate_image_to_video_wavespeed(
image_data=image_data,
image_base64=image_base64,
prompt=prompt or "",
progress_callback=progress_callback,
**kwargs
)
else:
raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")
raise ValueError(f"Unknown provider for image-to-video: {provider}")
# Track usage (same pattern as text generation)
# Use cost from result_dict if available, otherwise calculate
response_time = time.time() - start_time
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
track_video_usage(
user_id=user_id,
provider=provider,
model_name=model_name,
prompt=result_dict.get("prompt", prompt or ""),
video_bytes=video_bytes,
cost_override=cost_override,
response_time=response_time,
)
# Progress callback: Complete
if progress_callback:
progress_callback(100.0, "Video generation complete!")
return result_dict
except HTTPException:
# Re-raise HTTPExceptions (e.g., from validation or API errors)
raise
# TRACK USAGE after successful API call
video_bytes = result.get("video_bytes")
if user_id and video_bytes:
_track_video_operation_usage(
user_id=user_id,
provider=result.get("provider", provider),
model=result.get("model_name", kwargs.get("model", "unknown")),
operation_type=operation_type,
result_bytes=video_bytes,
cost=result.get("cost", 0.0),
prompt=prompt,
endpoint="/video-generation",
metadata=result.get("metadata"),
log_prefix=f"[{operation_type.replace('-', ' ').title()}]",
response_time=response_time
)
return result
except Exception as e:
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
raise HTTPException(status_code=500, detail={"error": str(e)})
# Log failure but don't track usage (no cost incurred)
logger.error(f"[video_gen] Generation failed: {str(e)}")
raise
def _get_default_model(operation_type: str, provider: str) -> str:

View File

@@ -46,6 +46,9 @@ class CorePersonaService:
# Get schema for structured response
persona_schema = self.prompt_builder.get_persona_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
try:
# Generate structured response using Gemini
response = gemini_structured_json_response(
@@ -53,7 +56,8 @@ class CorePersonaService:
schema=persona_schema,
temperature=0.2, # Low temperature for consistent analysis
max_tokens=8192,
system_prompt="You are an expert writing style analyst and persona developer. Analyze the provided data to create a precise, actionable writing persona."
system_prompt="You are an expert writing style analyst and persona developer. Analyze the provided data to create a precise, actionable writing persona.",
user_id=user_id
)
if "error" in response:
@@ -103,13 +107,17 @@ class CorePersonaService:
# Get platform-specific schema
platform_schema = self.prompt_builder.get_platform_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
try:
response = gemini_structured_json_response(
prompt=prompt,
schema=platform_schema,
temperature=0.2,
max_tokens=4096,
system_prompt=f"You are an expert in {platform} content strategy and platform-specific writing optimization."
system_prompt=f"You are an expert in {platform} content strategy and platform-specific writing optimization.",
user_id=user_id
)
return response

View File

@@ -62,6 +62,9 @@ class FacebookPersonaService:
# Get Facebook-specific schema
schema = self._get_enhanced_facebook_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
# Generate structured response using Gemini with optimized prompts
response = gemini_structured_json_response(
@@ -69,7 +72,8 @@ class FacebookPersonaService:
schema=schema,
temperature=0.2,
max_tokens=4096,
system_prompt=system_prompt
system_prompt=system_prompt,
user_id=user_id
)
if not response or "error" in response:

View File

@@ -54,13 +54,17 @@ class LinkedInPersonaService:
# Get LinkedIn-specific schema
schema = self.schemas.get_enhanced_linkedin_schema()
# Extract user_id for tracking
user_id = onboarding_data.get("session_info", {}).get("user_id")
# Generate structured response using Gemini with optimized prompts
response = gemini_structured_json_response(
prompt=prompt,
schema=schema,
temperature=0.2,
max_tokens=4096,
system_prompt=system_prompt
system_prompt=system_prompt,
user_id=user_id
)
if "error" in response:

View File

@@ -56,6 +56,17 @@ async def check_and_execute_due_tasks(scheduler: 'TaskScheduler'):
continue
try:
# Check onboarding status first
# Skip users who haven't completed onboarding to prevent premature agent initialization
from services.onboarding.progress_service import OnboardingProgressService
onboarding_service = OnboardingProgressService()
status = onboarding_service.get_onboarding_status(user_id)
if not status.get("is_completed", False):
# Skip logging for inactive users to reduce noise, unless debugging
# logger.debug(f"[Scheduler Check] Skipping user {user_id} - Onboarding incomplete")
continue
# Check active strategies for this user (for interval adjustment)
try:
from services.active_strategy_service import ActiveStrategyService

View File

@@ -67,6 +67,27 @@ class SIFIndexingExecutor(TaskExecutor):
# 2. Sync User Website Content (Deep Crawl / Snapshot)
content_synced = await sif_service.sync_user_website_content(website_url)
# 3. Trigger Content Guardian Audit (Background Analysis)
# This ensures the agent runs immediately after new data is indexed
guardian_report = None
if content_synced:
try:
from services.intelligence.agents.specialized_agents import ContentGuardianAgent
# Re-use the intelligence service from sif_service
guardian_agent = ContentGuardianAgent(
intelligence_service=sif_service.intelligence_service,
user_id=user_id,
sif_service=sif_service
)
logger.info("Triggering Content Guardian Site Audit...")
guardian_report = await guardian_agent.perform_site_audit(website_url)
# Persist the audit report (optional, or rely on logs/alerts)
# For now, we just include it in the task result
except Exception as e:
logger.error(f"Failed to run Content Guardian audit: {e}")
# Determine overall success
# We consider it a success if at least one operation worked, or if both were attempted without error
# But ideally, content sync is the heavy lifter.
@@ -91,6 +112,7 @@ class SIFIndexingExecutor(TaskExecutor):
task_log.result_data = {
"metadata_synced": metadata_synced,
"content_synced": content_synced,
"guardian_report": guardian_report,
"website_url": website_url
}
task_log.execution_time_ms = int((time.time() - start_time) * 1000)

View File

@@ -29,9 +29,10 @@ def load_due_sif_indexing_tasks(db: Session, user_id: str = None) -> List[SIFInd
query = db.query(SIFIndexingTask).filter(
or_(
SIFIndexingTask.status == "pending",
SIFIndexingTask.status == "active",
SIFIndexingTask.status == "failed" # Retry failed tasks
),
SIFIndexingTask.next_run_at <= datetime.utcnow()
SIFIndexingTask.next_execution <= datetime.utcnow()
)
if user_id:

View File

@@ -199,6 +199,24 @@ class PricingService:
"cost_per_input_token": 0.0, # No additional token cost for grounding
"cost_per_output_token": 0.0, # No additional token cost for grounding
"description": "Grounding with Google Search - 1,500 RPD free, then $35/1K requests"
},
# Alwrity Voice Cloning - Qwen3
{
"provider": APIProvider.AUDIO,
"model_name": "alwrity-ai/qwen3-tts/voice-clone",
"cost_per_request": 0.10,
"cost_per_input_token": 0.00001,
"cost_per_output_token": 0.0,
"description": "Alwrity Qwen3 Voice Clone (Efficient)"
},
# Alwrity Voice Cloning - CosyVoice
{
"provider": APIProvider.AUDIO,
"model_name": "alwrity-ai/cosyvoice/voice-clone",
"cost_per_request": 0.15,
"cost_per_input_token": 0.00001,
"cost_per_output_token": 0.0,
"description": "Alwrity CosyVoice Clone (High Fidelity)"
}
]
@@ -402,11 +420,19 @@ class PricingService:
{
"provider": APIProvider.AUDIO,
"model_name": "wavespeed-ai/qwen3-tts/voice-clone",
"cost_per_request": 0.0,
"cost_per_input_token": 0.0,
"cost_per_request": 0.005,
"cost_per_input_token": 0.00005,
"cost_per_output_token": 0.0,
"description": "Qwen3-TTS Voice Clone via WaveSpeed (cost depends on text length)"
},
{
"provider": APIProvider.AUDIO,
"model_name": "wavespeed-ai/cosyvoice-tts/voice-clone",
"cost_per_request": 0.005,
"cost_per_input_token": 0.00005,
"cost_per_output_token": 0.0,
"description": "CosyVoice-TTS Voice Clone via WaveSpeed (cost depends on text length)"
},
{
"provider": APIProvider.AUDIO,
"model_name": "default",
@@ -429,8 +455,9 @@ class PricingService:
if existing:
# Update existing pricing (especially for HuggingFace if env vars changed)
if pricing_data["provider"] == APIProvider.MISTRAL:
# Update HuggingFace pricing from env vars
if pricing_data["provider"] in [APIProvider.MISTRAL, APIProvider.AUDIO]:
# Update pricing
existing.cost_per_request = pricing_data.get("cost_per_request", 0.0)
existing.cost_per_input_token = pricing_data["cost_per_input_token"]
existing.cost_per_output_token = pricing_data["cost_per_output_token"]
existing.description = pricing_data["description"]

View File

@@ -490,6 +490,32 @@ class UsageTrackingService:
'cost': image_edit_cost
}
# WaveSpeed (aggregated across Video, Audio, Image, Image Edit)
# Query APIUsageLog directly to get accurate WaveSpeed-specific usage
wavespeed_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.billing_period == billing_period,
APIUsageLog.actual_provider_name == "wavespeed"
).all()
if wavespeed_logs:
wavespeed_calls = len(wavespeed_logs)
wavespeed_tokens = sum((log.tokens_total or 0) for log in wavespeed_logs)
wavespeed_cost = sum(float(log.cost_total or 0.0) for log in wavespeed_logs)
provider_breakdown['wavespeed'] = {
'calls': wavespeed_calls,
'tokens': wavespeed_tokens,
'cost': wavespeed_cost
}
logger.info(f"[UsageStats] Calculated WaveSpeed usage: {wavespeed_calls} calls, ${wavespeed_cost:.6f}")
else:
provider_breakdown['wavespeed'] = {
'calls': 0,
'tokens': 0,
'cost': 0.0
}
# Search APIs
tavily_calls = getattr(summary, "tavily_calls", 0) or 0
tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0

View File

@@ -12,6 +12,7 @@ from loguru import logger
from services.image_studio.infinitetalk_adapter import InfiniteTalkService
from services.video_studio.hunyuan_avatar_adapter import HunyuanAvatarService
from utils.logger_utils import get_service_logger
from services.llm_providers.main_video_generation import _track_video_operation_usage
logger = get_service_logger("video_studio.avatar")
@@ -58,6 +59,30 @@ class AvatarStudioService:
f"[AvatarStudio] Creating talking avatar: user={user_id}, resolution={resolution}, model={model}"
)
# PRE-FLIGHT VALIDATION: Validate video generation before API call
# MUST happen BEFORE any API calls - return immediately if validation fails
from services.database import get_db
from services.subscription import PricingService
from services.subscription.preflight_validator import validate_video_generation_operations
db = next(get_db())
try:
pricing_service = PricingService(db)
# Raises HTTPException immediately if validation fails - frontend gets immediate response
validate_video_generation_operations(
pricing_service=pricing_service,
user_id=user_id
)
except HTTPException:
# Re-raise immediately - don't proceed with API call
logger.error(f"[AvatarStudio] ❌ Pre-flight validation failed - blocking API call")
raise
finally:
db.close()
import time
start_time = time.time()
try:
if model == "hunyuan-avatar":
# Use Hunyuan Avatar (doesn't support mask_image)
@@ -82,12 +107,32 @@ class AvatarStudioService:
user_id=user_id,
)
response_time = time.time() - start_time
logger.info(
f"[AvatarStudio] ✅ Talking avatar created: "
f"model={model}, resolution={resolution}, duration={result.get('duration', 0)}s, "
f"cost=${result.get('cost', 0):.2f}"
)
# TRACK USAGE after successful API call
# Use video_bytes if available, otherwise check if result itself is bytes (unlikely, dict expected)
video_bytes = result.get("video_bytes")
if user_id and video_bytes:
_track_video_operation_usage(
user_id=user_id,
provider=model, # Use model name as provider/actual_provider for now
model=model,
operation_type="talking-avatar",
result_bytes=video_bytes,
cost=result.get("cost", 0.0),
prompt=prompt,
endpoint="/avatar-generation",
metadata=result,
log_prefix="[Avatar Generation]",
response_time=response_time
)
return result
except HTTPException:

View File

@@ -324,6 +324,39 @@ class WaveSpeedClient:
timeout=timeout,
)
def voice_design(
    self,
    text: str,
    voice_description: str,
    language: str = "auto",
    timeout: int = 180,
) -> bytes:
    """Forward a voice-design request to ``self.speech.voice_design``.

    Every argument is passed through unchanged as a keyword argument and the
    delegate's return value is returned as-is.
    """
    forwarded = {
        "text": text,
        "voice_description": voice_description,
        "language": language,
        "timeout": timeout,
    }
    return self.speech.voice_design(**forwarded)
def cosyvoice_voice_clone(
    self,
    audio_bytes: bytes,
    text: str,
    *,
    model: str = "wavespeed-ai/cosyvoice-tts/voice-clone",
    audio_mime_type: str = "audio/wav",
    reference_text: Optional[str] = None,
    timeout: int = 180,
) -> bytes:
    """Forward a CosyVoice voice-clone request to ``self.speech.cosyvoice_voice_clone``.

    Every argument is passed through unchanged as a keyword argument and the
    delegate's return value is returned as-is.
    """
    forwarded = {
        "audio_bytes": audio_bytes,
        "text": text,
        "model": model,
        "audio_mime_type": audio_mime_type,
        "reference_text": reference_text,
        "timeout": timeout,
    }
    return self.speech.cosyvoice_voice_clone(**forwarded)
def generate_text_video(
self,
prompt: str,

View File

@@ -146,14 +146,44 @@ class PromptGenerator:
if isinstance(first_output, str):
if first_output.startswith("http://") or first_output.startswith("https://"):
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
url_response = requests.get(first_output, timeout=timeout)
if url_response.status_code == 200:
return url_response.text.strip()
else:
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
# Use stream=True to avoid downloading large files into memory
try:
with requests.get(first_output, timeout=timeout, stream=True) as url_response:
if url_response.status_code == 200:
# Check Content-Length if available
content_length = url_response.headers.get("Content-Length")
if content_length and int(content_length) > 1024 * 1024: # 1MB limit for prompts
logger.error(f"[WaveSpeed] Optimized prompt URL content too large: {content_length} bytes")
raise HTTPException(
status_code=502,
detail="WaveSpeed prompt optimizer returned a file that is too large",
)
# Read content with limit
content = ""
for chunk in url_response.iter_content(chunk_size=8192, decode_unicode=True):
if chunk:
content += chunk
if len(content) > 1024 * 1024: # Hard limit 1MB
logger.error("[WaveSpeed] Optimized prompt URL content exceeded 1MB limit during download")
raise HTTPException(
status_code=502,
detail="WaveSpeed prompt optimizer returned a file that is too large",
)
return content.strip()
else:
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
raise HTTPException(
status_code=502,
detail="Failed to fetch optimized prompt from WaveSpeed URL",
)
except requests.RequestException as e:
logger.error(f"[WaveSpeed] Error fetching prompt from URL: {str(e)}")
raise HTTPException(
status_code=502,
detail="Failed to fetch optimized prompt from WaveSpeed URL",
detail=f"Error fetching optimized prompt: {str(e)}",
)
else:
# It's already the text

View File

@@ -181,6 +181,102 @@ class SpeechGenerator:
audio_url = self._extract_audio_url(outputs)
return self._download_audio(audio_url, timeout)
def voice_design(
    self,
    text: str,
    voice_description: str,
    language: str = "auto",
    timeout: int = 180,
) -> bytes:
    """
    Generate speech using Qwen3 Voice Design (text + voice description).

    Submits the request to the voice-design endpoint. When the response
    carries a prediction id, the result is obtained by polling; when it
    already carries outputs, the audio is downloaded directly.

    Args:
        text: Text to synthesize.
        voice_description: Description of the desired voice.
        language: Language value sent to the API ("auto" by default).
        timeout: Seconds allowed for polling and for the audio download.
            The initial POST uses a fixed (30, 90) connect/read timeout.

    Returns:
        The audio returned by ``_poll_prediction_result`` / ``_download_audio``.

    Raises:
        HTTPException: 504 when the submit request times out or cannot
            connect; the upstream status code when it is not 200; 500 when
            the response cannot be interpreted. HTTPExceptions raised while
            polling or downloading propagate unchanged.
    """
    url = f"{self.base_url}/wavespeed-ai/qwen3-tts/voice-design"
    payload = {
        "text": text,
        "voice_description": voice_description,
        "language": language
    }
    logger.info(f"[WaveSpeed] Voice design via {url}")
    try:
        response = requests.post(
            url,
            headers=self._get_headers(),
            json=payload,
            timeout=(30, 90),
        )
    except requests_exceptions.Timeout as e:
        raise HTTPException(status_code=504, detail={"error": "WaveSpeed Voice Design timed out", "message": str(e)})
    except (requests_exceptions.ConnectionError, requests_exceptions.ConnectTimeout) as e:
        raise HTTPException(status_code=504, detail={"error": "WaveSpeed Voice Design connection failed", "message": str(e)})
    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail={"error": "WaveSpeed Voice Design failed", "message": response.text}
        )
    try:
        data = response.json()
        result = data.get("data") if isinstance(data, dict) else None
        if isinstance(result, dict):
            # Asynchronous task: the submit call returns a prediction id to poll.
            request_id = result.get("id")
            if request_id:
                return self._poll_prediction_result(request_id, timeout=timeout)
            # Fallback when the outputs are returned directly. Use the shared
            # extractor so the output shape is handled as in the clone methods.
            outputs = result.get("outputs")
            if outputs:
                return self._download_audio(self._extract_audio_url(outputs), timeout)
        raise ValueError(f"Unexpected response format: {data}")
    except HTTPException:
        # Polling/download failures already carry the right status code
        # (e.g. 504 on timeout); do not relabel them as a parse error.
        raise
    except Exception as e:
        logger.error(f"[WaveSpeed] Error parsing Voice Design response: {e}")
        raise HTTPException(status_code=500, detail={"error": "Failed to parse Voice Design response", "message": str(e)}) from e
def _poll_prediction_result(self, request_id: str, timeout: int = 180) -> bytes:
    """
    Poll a WaveSpeed prediction once per second until it finishes.

    Transient problems (network errors, non-200 responses, unreadable JSON)
    are logged and retried until the deadline. Terminal outcomes are reported
    immediately instead of being retried.

    Args:
        request_id: Prediction id returned by the submit call.
        timeout: Maximum seconds to keep polling; also passed to the audio
            download.

    Returns:
        The audio downloaded from the first output URL of the completed task.

    Raises:
        HTTPException: 502 when the task reports "failed" or completes
            without outputs; 504 when the task does not finish in time.
    """
    import time

    url = f"https://api.wavespeed.ai/api/v3/predictions/{request_id}/result"
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            response = requests.get(url, headers=self._get_headers(), timeout=10)
        except requests_exceptions.RequestException as e:
            logger.error(f"Polling exception: {e}")
            time.sleep(1)
            continue

        if response.status_code != 200:
            logger.warning(f"Polling error {response.status_code}: {response.text}")
            time.sleep(1)
            continue

        try:
            # "data" may be missing or null; normalise to an empty dict
            result = response.json().get("data") or {}
        except (ValueError, AttributeError) as e:
            logger.error(f"Polling exception: {e}")
            time.sleep(1)
            continue

        status = result.get("status")
        if status == "completed":
            outputs = result.get("outputs")
            if not outputs:
                raise HTTPException(
                    status_code=502,
                    detail={"error": "WaveSpeed Voice Design failed", "message": "Completed task has no output URLs"},
                )
            # Each output is a URL string in the array
            return self._download_audio(outputs[0], timeout)
        if status == "failed":
            # Terminal state: surface it now rather than polling until timeout
            raise HTTPException(
                status_code=502,
                detail={"error": "WaveSpeed Voice Design failed", "message": f"Task failed: {result.get('error')}"},
            )

        # Still created/processing: keep polling
        time.sleep(1)

    raise HTTPException(status_code=504, detail="Voice Design generation timed out")
def voice_clone(
self,
audio_bytes: bytes,
@@ -320,6 +416,70 @@ class SpeechGenerator:
audio_url = self._extract_audio_url(outputs)
return self._download_audio(audio_url, timeout)
def cosyvoice_voice_clone(
    self,
    audio_bytes: bytes,
    text: str,
    *,
    model: str = "wavespeed-ai/cosyvoice-tts/voice-clone",
    audio_mime_type: str = "audio/wav",
    reference_text: Optional[str] = None,
    timeout: int = 180,
) -> bytes:
    """
    Clone a voice with a CosyVoice model and synthesize ``text`` with it.

    The reference audio is sent inline as a base64 data URL. If the response
    has no outputs yet but reports a pending prediction, the result is
    obtained through ``self.polling``; the resulting audio URL is downloaded.

    Args:
        audio_bytes: Reference audio sample.
        text: Text to synthesize.
        model: Model path appended to ``self.base_url``.
        audio_mime_type: MIME type for the data URL ("audio/wav" when falsy).
        reference_text: Optional reference text; sent only when truthy.
        timeout: Seconds allowed for polling and for the audio download.
            The initial POST uses a fixed (30, 90) connect/read timeout.

    Returns:
        The audio returned by ``_download_audio``.

    Raises:
        HTTPException: 504 on request timeout or connection failure; 502 on a
            non-200 response or when no outputs are produced.
    """
    endpoint = f"{self.base_url}/{model}"

    encoded_sample = base64.b64encode(audio_bytes).decode("utf-8")
    sample_mime = audio_mime_type if audio_mime_type else "audio/wav"
    request_body = {
        "audio": f"data:{sample_mime};base64,{encoded_sample}",
        "text": text,
    }
    if reference_text:
        request_body["reference_text"] = reference_text

    logger.info(f"[WaveSpeed] CosyVoice voice clone via {endpoint}")

    try:
        response = requests.post(
            endpoint,
            headers=self._get_headers(),
            json=request_body,
            timeout=(30, 90),
        )
    except requests_exceptions.Timeout as timeout_error:
        timeout_detail = {"error": "WaveSpeed CosyVoice voice clone timed out", "message": str(timeout_error)}
        raise HTTPException(status_code=504, detail=timeout_detail)
    except (requests_exceptions.ConnectionError, requests_exceptions.ConnectTimeout) as connection_error:
        connection_detail = {"error": "WaveSpeed CosyVoice voice clone connection failed", "message": str(connection_error)}
        raise HTTPException(status_code=504, detail=connection_detail)

    if response.status_code != 200:
        failure_detail = {
            "error": "WaveSpeed CosyVoice voice clone failed",
            "status_code": response.status_code,
            "response": response.text,
        }
        raise HTTPException(status_code=502, detail=failure_detail)

    body = response.json()
    # Some responses wrap the prediction in "data"; otherwise use the body itself
    prediction = body.get("data") or body
    outputs = prediction.get("outputs") or []
    status = prediction.get("status")
    prediction_id = prediction.get("id")

    # Nothing produced yet but the task is pending: wait for it to finish
    if not outputs and prediction_id and status in {"created", "processing"}:
        polled = self.polling.poll_until_complete(prediction_id, timeout_seconds=timeout, interval_seconds=0.8)
        outputs = polled.get("outputs") or []

    if outputs:
        return self._download_audio(self._extract_audio_url(outputs), timeout)
    raise HTTPException(status_code=502, detail="WaveSpeed CosyVoice voice clone returned no outputs")
def _extract_audio_url(self, outputs: list) -> str:
"""Extract audio URL from outputs."""
if not isinstance(outputs, list) or len(outputs) == 0: