Save local changes (GSC/Bing integrations) before merging PR #354
This commit is contained in:
@@ -20,12 +20,13 @@ class BaseAnalyticsHandler(ABC):
|
||||
self.platform_name = platform_type.value
|
||||
|
||||
@abstractmethod
|
||||
async def get_analytics(self, user_id: str) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get analytics data for the platform
|
||||
|
||||
Args:
|
||||
user_id: User ID to get analytics for
|
||||
**kwargs: Additional arguments for specific handlers
|
||||
|
||||
Returns:
|
||||
AnalyticsData object with platform metrics
|
||||
|
||||
@@ -42,7 +42,7 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
db_url = f'sqlite:///{db_path}'
|
||||
return BingInsightsService(db_url)
|
||||
|
||||
async def get_analytics(self, user_id: str) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Bing Webmaster analytics data using Bing Webmaster API
|
||||
"""
|
||||
@@ -83,9 +83,32 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
if not access_token:
|
||||
return self.create_error_response('Bing Webmaster access token not available')
|
||||
|
||||
# Select site: Prefer target_url match, otherwise first site
|
||||
selected_site = sites[0] if sites else None
|
||||
|
||||
if not selected_site:
|
||||
return self.create_error_response('No Bing sites found')
|
||||
|
||||
if target_url and sites:
|
||||
logger.info(f"Attempting to match target URL: {target_url}")
|
||||
# Normalize target URL (remove protocol, trailing slash)
|
||||
normalized_target = target_url.replace('https://', '').replace('http://', '').rstrip('/')
|
||||
|
||||
for site in sites:
|
||||
# Bing uses 'Url' key
|
||||
site_url = site.get('Url', '')
|
||||
normalized_site = site_url.replace('https://', '').replace('http://', '').rstrip('/')
|
||||
|
||||
if normalized_target in normalized_site or normalized_site in normalized_target:
|
||||
selected_site = site
|
||||
logger.info(f"Found matching Bing site: {site_url}")
|
||||
break
|
||||
|
||||
site_url_for_storage = selected_site.get('Url', '') if selected_site else ''
|
||||
logger.info(f"Using Bing site URL: {site_url_for_storage}")
|
||||
|
||||
query_stats = {}
|
||||
try:
|
||||
site_url_for_storage = sites[0].get('Url', '') if (sites and isinstance(sites[0], dict)) else None
|
||||
stored = storage_service.get_analytics_summary(user_id, site_url_for_storage, days=30)
|
||||
if stored and isinstance(stored, dict):
|
||||
query_stats = {
|
||||
@@ -99,7 +122,7 @@ class BingAnalyticsHandler(BaseAnalyticsHandler):
|
||||
logger.warning(f"Bing analytics: Failed to read stored analytics summary: {e}")
|
||||
|
||||
# Get enhanced insights
|
||||
insights = self._get_enhanced_insights_with_service(insights_service, user_id, sites[0].get('Url', '') if sites else '')
|
||||
insights = self._get_enhanced_insights_with_service(insights_service, user_id, site_url_for_storage)
|
||||
|
||||
metrics = {
|
||||
'connection_status': 'connected',
|
||||
|
||||
@@ -22,16 +22,22 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
super().__init__(PlatformType.GSC)
|
||||
self.gsc_service = GSCService()
|
||||
|
||||
async def get_analytics(self, user_id: str) -> AnalyticsData:
|
||||
async def get_analytics(self, user_id: str, target_url: str = None, **kwargs) -> AnalyticsData:
|
||||
"""
|
||||
Get Google Search Console analytics data with caching
|
||||
|
||||
Args:
|
||||
user_id: User ID to get analytics for
|
||||
target_url: Optional URL to prefer when selecting GSC site
|
||||
|
||||
Returns comprehensive SEO metrics including clicks, impressions, CTR, and position data.
|
||||
"""
|
||||
self.log_analytics_request(user_id, "get_analytics")
|
||||
|
||||
# Check cache first - GSC API calls can be expensive
|
||||
cached_data = analytics_cache.get('gsc_analytics', user_id)
|
||||
# Include target_url in cache key if provided
|
||||
cache_key = f"{user_id}_{target_url}" if target_url else user_id
|
||||
cached_data = analytics_cache.get('gsc_analytics', cache_key)
|
||||
if cached_data:
|
||||
logger.info("Using cached GSC analytics for user {user_id}", user_id=user_id)
|
||||
return AnalyticsData(**cached_data)
|
||||
@@ -45,8 +51,23 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
logger.warning(f"No GSC sites found for user {user_id}")
|
||||
return self.create_error_response('No GSC sites found')
|
||||
|
||||
# Get analytics for the first site (or combine all sites)
|
||||
site_url = sites[0]['siteUrl']
|
||||
# Select site: Prefer target_url match, otherwise first site
|
||||
selected_site = sites[0]
|
||||
if target_url:
|
||||
logger.info(f"Attempting to match target URL: {target_url}")
|
||||
# Normalize target URL (remove protocol, trailing slash)
|
||||
normalized_target = target_url.replace('https://', '').replace('http://', '').rstrip('/')
|
||||
|
||||
for site in sites:
|
||||
site_url = site['siteUrl']
|
||||
normalized_site = site_url.replace('https://', '').replace('http://', '').rstrip('/')
|
||||
|
||||
if normalized_target in normalized_site or normalized_site in normalized_target:
|
||||
selected_site = site
|
||||
logger.info(f"Found matching GSC site: {site_url}")
|
||||
break
|
||||
|
||||
site_url = selected_site['siteUrl']
|
||||
logger.info(f"Using GSC site URL: {site_url}")
|
||||
|
||||
# Get search analytics for last 30 days
|
||||
@@ -71,7 +92,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
)
|
||||
|
||||
# Cache the result to avoid expensive API calls
|
||||
analytics_cache.set('gsc_analytics', user_id, result.__dict__)
|
||||
analytics_cache.set('gsc_analytics', cache_key, result.__dict__)
|
||||
logger.info("Cached GSC analytics data for user {user_id}", user_id=user_id)
|
||||
|
||||
return result
|
||||
@@ -81,7 +102,7 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
error_result = self.create_error_response(str(e))
|
||||
|
||||
# Cache error result for shorter time to retry sooner
|
||||
analytics_cache.set('gsc_analytics', user_id, error_result.__dict__, ttl_override=300) # 5 minutes
|
||||
analytics_cache.set('gsc_analytics', cache_key, error_result.__dict__, ttl_override=300) # 5 minutes
|
||||
return error_result
|
||||
|
||||
def get_connection_status(self, user_id: str) -> Dict[str, Any]:
|
||||
@@ -117,111 +138,93 @@ class GSCAnalyticsHandler(BaseAnalyticsHandler):
|
||||
# New structure from updated GSC service
|
||||
overall_rows = search_analytics.get('overall_metrics', {}).get('rows', [])
|
||||
query_rows = search_analytics.get('query_data', {}).get('rows', [])
|
||||
verification_rows = search_analytics.get('verification_data', {}).get('rows', [])
|
||||
|
||||
logger.info(f"GSC Overall metrics rows: {len(overall_rows)}")
|
||||
logger.info(f"GSC Query data rows: {len(query_rows)}")
|
||||
logger.info(f"GSC Verification rows: {len(verification_rows)}")
|
||||
# Calculate totals from overall_rows (most accurate as it includes anonymized queries)
|
||||
total_clicks = 0
|
||||
total_impressions = 0
|
||||
total_position = 0
|
||||
valid_position_rows = 0
|
||||
|
||||
if overall_rows:
|
||||
logger.info(f"GSC Overall first row: {overall_rows[0]}")
|
||||
if query_rows:
|
||||
logger.info(f"GSC Query first row: {query_rows[0]}")
|
||||
# Use overall_rows for totals if available, otherwise fallback to query_rows
|
||||
calc_rows = overall_rows if overall_rows else query_rows
|
||||
|
||||
for row in calc_rows:
|
||||
clicks = row.get('clicks', 0)
|
||||
impressions = row.get('impressions', 0)
|
||||
position = row.get('position', 0)
|
||||
|
||||
total_clicks += clicks
|
||||
total_impressions += impressions
|
||||
|
||||
if position and position > 0:
|
||||
total_position += position * impressions # Weighted average
|
||||
|
||||
# Calculate weighted average position
|
||||
avg_position = total_position / total_impressions if total_impressions > 0 else 0
|
||||
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
|
||||
|
||||
# Use query_rows for top queries list
|
||||
top_queries_source = query_rows
|
||||
|
||||
# Use query_rows for detailed insights, overall_rows for summary
|
||||
rows = query_rows if query_rows else overall_rows
|
||||
else:
|
||||
# Legacy structure
|
||||
rows = search_analytics.get('rows', [])
|
||||
logger.info(f"GSC Legacy rows count: {len(rows)}")
|
||||
if rows:
|
||||
logger.info(f"GSC Legacy first row structure: {rows[0]}")
|
||||
logger.info(f"GSC Legacy first row keys: {list(rows[0].keys()) if rows[0] else 'No rows'}")
|
||||
|
||||
# Calculate summary metrics - handle different response formats
|
||||
total_clicks = 0
|
||||
total_impressions = 0
|
||||
total_position = 0
|
||||
valid_rows = 0
|
||||
|
||||
for row in rows:
|
||||
# Handle different possible response formats
|
||||
clicks = row.get('clicks', 0)
|
||||
impressions = row.get('impressions', 0)
|
||||
position = row.get('position', 0)
|
||||
# ... existing legacy logic ...
|
||||
calc_rows = rows
|
||||
top_queries_source = rows
|
||||
|
||||
# If position is 0 or None, skip it from average calculation
|
||||
if position and position > 0:
|
||||
total_position += position
|
||||
valid_rows += 1
|
||||
total_clicks = 0
|
||||
total_impressions = 0
|
||||
total_position = 0
|
||||
valid_position_rows = 0
|
||||
|
||||
total_clicks += clicks
|
||||
total_impressions += impressions
|
||||
|
||||
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
|
||||
avg_position = total_position / valid_rows if valid_rows > 0 else 0
|
||||
|
||||
logger.info(f"GSC Calculated metrics - clicks: {total_clicks}, impressions: {total_impressions}, ctr: {avg_ctr}, position: {avg_position}, valid_rows: {valid_rows}")
|
||||
|
||||
# Get top performing queries - handle different data structures
|
||||
if rows and 'keys' in rows[0]:
|
||||
# New GSC API format with keys array
|
||||
top_queries = sorted(rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
|
||||
|
||||
# Get top performing pages (if we have page data)
|
||||
page_data = {}
|
||||
for row in rows:
|
||||
# Handle different key structures
|
||||
keys = row.get('keys', [])
|
||||
if len(keys) > 1 and keys[1]: # Page data available
|
||||
page = keys[1].get('keys', ['Unknown'])[0] if isinstance(keys[1], dict) else str(keys[1])
|
||||
else:
|
||||
page = 'Unknown'
|
||||
for row in calc_rows:
|
||||
clicks = row.get('clicks', 0)
|
||||
impressions = row.get('impressions', 0)
|
||||
position = row.get('position', 0)
|
||||
|
||||
if page not in page_data:
|
||||
page_data[page] = {'clicks': 0, 'impressions': 0, 'ctr': 0, 'position': 0}
|
||||
page_data[page]['clicks'] += row.get('clicks', 0)
|
||||
page_data[page]['impressions'] += row.get('impressions', 0)
|
||||
else:
|
||||
# Legacy format or no keys structure
|
||||
top_queries = sorted(rows, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
|
||||
page_data = {}
|
||||
total_clicks += clicks
|
||||
total_impressions += impressions
|
||||
|
||||
if position and position > 0:
|
||||
# Simple average for legacy/unknown structure if we can't do weighted
|
||||
total_position += position
|
||||
valid_position_rows += 1
|
||||
|
||||
avg_ctr = (total_clicks / total_impressions * 100) if total_impressions > 0 else 0
|
||||
avg_position = total_position / valid_position_rows if valid_position_rows > 0 else 0
|
||||
|
||||
|
||||
# Calculate page metrics
|
||||
for page in page_data:
|
||||
if page_data[page]['impressions'] > 0:
|
||||
page_data[page]['ctr'] = page_data[page]['clicks'] / page_data[page]['impressions'] * 100
|
||||
|
||||
top_pages = sorted(page_data.items(), key=lambda x: x[1]['clicks'], reverse=True)[:10]
|
||||
|
||||
return {
|
||||
'connection_status': 'connected',
|
||||
'connected_sites': 1, # GSC typically has one site per user
|
||||
'total_clicks': total_clicks,
|
||||
'total_impressions': total_impressions,
|
||||
'avg_ctr': round(avg_ctr, 2),
|
||||
'avg_position': round(avg_position, 2),
|
||||
'total_queries': len(rows),
|
||||
'top_queries': [
|
||||
{
|
||||
# Get top performing queries
|
||||
top_queries = []
|
||||
if top_queries_source:
|
||||
# Sort by clicks
|
||||
sorted_queries = sorted(top_queries_source, key=lambda x: x.get('clicks', 0), reverse=True)[:10]
|
||||
|
||||
for row in sorted_queries:
|
||||
top_queries.append({
|
||||
'query': self._extract_query_from_row(row),
|
||||
'clicks': row.get('clicks', 0),
|
||||
'impressions': row.get('impressions', 0),
|
||||
'ctr': round(row.get('ctr', 0) * 100, 2),
|
||||
'position': round(row.get('position', 0), 2)
|
||||
}
|
||||
for row in top_queries
|
||||
],
|
||||
'top_pages': [
|
||||
{
|
||||
'page': page,
|
||||
'clicks': data['clicks'],
|
||||
'impressions': data['impressions'],
|
||||
'ctr': round(data['ctr'], 2)
|
||||
}
|
||||
for page, data in top_pages
|
||||
],
|
||||
'note': 'Google Search Console provides search performance data, keyword rankings, and SEO insights'
|
||||
})
|
||||
|
||||
# Prepare Top Pages (requires page dimension, but we only requested query dimension in gsc_service step 3)
|
||||
# To get top pages, we would need another API call with dimension=['page']
|
||||
# For now, we'll return empty top_pages or infer from what we have if possible (we can't from query data)
|
||||
top_pages = []
|
||||
|
||||
return {
|
||||
'connection_status': 'connected',
|
||||
'connected_sites': 1,
|
||||
'total_clicks': total_clicks,
|
||||
'total_impressions': total_impressions,
|
||||
'avg_ctr': round(avg_ctr, 2),
|
||||
'avg_position': round(avg_position, 2),
|
||||
'total_queries': len(top_queries_source) if top_queries_source else 0,
|
||||
'top_queries': top_queries,
|
||||
'top_pages': top_pages
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -59,6 +59,32 @@ class PlatformAnalyticsService:
|
||||
logger.info(f"Getting comprehensive analytics for user {user_id}, platforms: {platforms}")
|
||||
analytics_data = {}
|
||||
|
||||
# Determine target URL from Wix/WP for GSC site selection
|
||||
target_url = None
|
||||
try:
|
||||
status = await self.get_platform_connection_status(user_id)
|
||||
|
||||
# Check Wix
|
||||
if status.get('wix', {}).get('connected'):
|
||||
sites = status['wix'].get('sites', [])
|
||||
if sites:
|
||||
# Assuming site object has 'blog_url' or 'url'
|
||||
site = sites[0]
|
||||
target_url = site.get('blog_url') or site.get('url')
|
||||
|
||||
# Check WordPress if no Wix
|
||||
if not target_url and status.get('wordpress', {}).get('connected'):
|
||||
sites = status['wordpress'].get('sites', [])
|
||||
if sites:
|
||||
site = sites[0]
|
||||
target_url = site.get('blog_url') or site.get('url')
|
||||
|
||||
if target_url:
|
||||
logger.info(f"Identified primary site URL for GSC matching: {target_url}")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to determine target URL for GSC: {e}")
|
||||
|
||||
for platform_name in platforms:
|
||||
try:
|
||||
# Convert string to PlatformType enum
|
||||
@@ -66,7 +92,10 @@ class PlatformAnalyticsService:
|
||||
handler = self.handlers.get(platform_type)
|
||||
|
||||
if handler:
|
||||
analytics_data[platform_name] = await handler.get_analytics(user_id)
|
||||
if platform_type == PlatformType.GSC or platform_type == PlatformType.BING:
|
||||
analytics_data[platform_name] = await handler.get_analytics(user_id, target_url=target_url)
|
||||
else:
|
||||
analytics_data[platform_name] = await handler.get_analytics(user_id)
|
||||
else:
|
||||
logger.warning(f"Unknown platform: {platform_name}")
|
||||
analytics_data[platform_name] = self._create_error_response(platform_name, f"Unknown platform: {platform_name}")
|
||||
|
||||
@@ -30,6 +30,8 @@ from models.product_asset_models import ProductAsset, ProductStyleTemplate, Ecom
|
||||
from models.podcast_models import PodcastProject
|
||||
# Research models use SubscriptionBase
|
||||
from models.research_models import ResearchProject
|
||||
# Video Studio models
|
||||
from models.video_models import VideoGenerationTask
|
||||
# Bing Analytics models
|
||||
from models.bing_analytics_models import Base as BingAnalyticsBase
|
||||
|
||||
@@ -54,7 +56,22 @@ def get_user_db_path(user_id: str) -> str:
|
||||
# Sanitize user_id to be safe for filesystem
|
||||
safe_user_id = "".join(c for c in user_id if c.isalnum() or c in ('-', '_'))
|
||||
user_workspace = os.path.join(WORKSPACE_DIR, f"workspace_{safe_user_id}")
|
||||
return os.path.join(user_workspace, 'db', f'alwrity_{safe_user_id}.db')
|
||||
|
||||
# Check for legacy naming convention first (to support existing data)
|
||||
# Some older workspaces might have 'alwrity.db' instead of 'alwrity_{user_id}.db'
|
||||
legacy_db_path = os.path.join(user_workspace, 'db', 'alwrity.db')
|
||||
specific_db_path = os.path.join(user_workspace, 'db', f'alwrity_{safe_user_id}.db')
|
||||
|
||||
# If the specific one exists, use it (preferred)
|
||||
if os.path.exists(specific_db_path):
|
||||
return specific_db_path
|
||||
|
||||
# If legacy exists and specific doesn't, use legacy
|
||||
if os.path.exists(legacy_db_path):
|
||||
return legacy_db_path
|
||||
|
||||
# Default to specific for new databases
|
||||
return specific_db_path
|
||||
|
||||
def get_all_user_ids() -> List[str]:
|
||||
"""
|
||||
|
||||
@@ -14,6 +14,8 @@ from loguru import logger
|
||||
|
||||
from services.database import get_user_db_path
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
class GSCService:
|
||||
"""Service for Google Search Console integration."""
|
||||
|
||||
@@ -31,10 +33,62 @@ class GSCService:
|
||||
services_dir = os.path.dirname(__file__)
|
||||
backend_dir = os.path.abspath(os.path.join(services_dir, os.pardir))
|
||||
self.credentials_file = os.path.join(backend_dir, "gsc_credentials.json")
|
||||
logger.info(f"GSC credentials file path set to: {self.credentials_file}")
|
||||
|
||||
# Load client config from file or environment variables
|
||||
self.client_config = self._load_client_config()
|
||||
|
||||
if self.client_config:
|
||||
logger.info("GSC client configuration loaded successfully")
|
||||
else:
|
||||
logger.warning(f"GSC credentials not found in {self.credentials_file} or environment variables")
|
||||
|
||||
self.scopes = ['https://www.googleapis.com/auth/webmasters.readonly']
|
||||
# Note: Tables are initialized lazily per user
|
||||
logger.info("GSC Service initialized successfully")
|
||||
|
||||
def _load_client_config(self) -> Optional[Dict[str, Any]]:
    """Load the Google OAuth client configuration.

    Sources are tried in priority order:

    1. The ``GOOGLE_CLIENT_ID`` / ``GOOGLE_CLIENT_SECRET`` environment
       variables (``.env`` is re-read first so runtime edits are picked up).
    2. The JSON file at ``self.credentials_file``.

    Returns:
        The client configuration dict (the value later handed to
        ``Flow.from_client_config``), or ``None`` when neither source
        provides a usable configuration. Failures while reading or parsing
        the credentials file are logged and reported as ``None`` rather
        than raised, because the constructor calls this method.
    """
    # Reload environment variables to catch any runtime changes (e.g. .env updates)
    load_dotenv(override=True)

    # 1. Check Environment Variables (Priority)
    client_id = os.getenv("GOOGLE_CLIENT_ID")
    client_secret = os.getenv("GOOGLE_CLIENT_SECRET")

    if client_id and client_secret:
        redirect_uri = os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
        logger.info("Loading GSC credentials from environment variables")
        # Construct the config dictionary expected by google_auth_oauthlib
        return {
            "web": {
                "client_id": client_id,
                "client_secret": client_secret,
                "project_id": os.getenv("GOOGLE_PROJECT_ID", "alwrity"),
                "auth_uri": "https://accounts.google.com/o/oauth2/auth",
                "token_uri": "https://oauth2.googleapis.com/token",
                "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
                "redirect_uris": [
                    "http://localhost:5173/onboarding",
                    redirect_uri,
                ],
                "javascript_origins": [
                    "http://localhost:5173",
                    "http://localhost:8000",
                ],
            }
        }

    # 2. Fallback to File
    try:
        # JSON is UTF-8 by specification; do not depend on the platform locale.
        with open(self.credentials_file, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        # A missing file is the normal "not configured" case; the caller
        # logs its own warning when no configuration was found.
        return None
    except Exception as e:
        # Deliberately broad: an unreadable or malformed credentials file
        # must not prevent the service from being constructed.
        logger.warning(f"Failed to load GSC credentials from file: {e}")
        return None

    if not isinstance(config, dict):
        # Callers use dict methods on the config (e.g. ``.get('web', {})``),
        # so any other JSON payload is treated as "no configuration".
        logger.warning(
            f"Failed to load GSC credentials from file: expected a JSON object, "
            f"got {type(config).__name__}"
        )
        return None

    logger.info(f"Loading GSC credentials from file: {self.credentials_file}")
    return config
|
||||
|
||||
def _get_db_path(self, user_id: str) -> str:
|
||||
return get_user_db_path(user_id)
|
||||
@@ -94,11 +148,11 @@ class GSCService:
|
||||
self._init_gsc_tables(user_id)
|
||||
db_path = self._get_db_path(user_id)
|
||||
|
||||
# Read client credentials from file to ensure we have all required fields
|
||||
with open(self.credentials_file, 'r') as f:
|
||||
client_config = json.load(f)
|
||||
if not self.client_config:
|
||||
logger.error("Cannot save credentials: Client configuration not loaded")
|
||||
return False
|
||||
|
||||
web_config = client_config.get('web', {})
|
||||
web_config = self.client_config.get('web', {})
|
||||
|
||||
credentials_json = json.dumps({
|
||||
'token': credentials.token,
|
||||
@@ -184,12 +238,17 @@ class GSCService:
|
||||
try:
|
||||
logger.info(f"Generating OAuth URL for user: {user_id}")
|
||||
|
||||
if not os.path.exists(self.credentials_file):
|
||||
raise FileNotFoundError(f"GSC credentials file not found: {self.credentials_file}")
|
||||
# Retry loading config if missing (in case .env was added later)
|
||||
if not self.client_config:
|
||||
self.client_config = self._load_client_config()
|
||||
|
||||
if not self.client_config:
|
||||
raise FileNotFoundError("GSC credentials not found in file or environment variables.")
|
||||
|
||||
redirect_uri = os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
|
||||
flow = Flow.from_client_secrets_file(
|
||||
self.credentials_file,
|
||||
|
||||
flow = Flow.from_client_config(
|
||||
self.client_config,
|
||||
scopes=self.scopes,
|
||||
redirect_uri=redirect_uri
|
||||
)
|
||||
@@ -256,8 +315,12 @@ class GSCService:
|
||||
conn.commit()
|
||||
|
||||
# Exchange code for credentials
|
||||
flow = Flow.from_client_secrets_file(
|
||||
self.credentials_file,
|
||||
if not self.client_config:
|
||||
logger.error("Cannot handle callback: Client configuration not loaded")
|
||||
return False
|
||||
|
||||
flow = Flow.from_client_config(
|
||||
self.client_config,
|
||||
scopes=self.scopes,
|
||||
redirect_uri=os.getenv('GSC_REDIRECT_URI', 'http://localhost:8000/gsc/callback')
|
||||
)
|
||||
@@ -283,7 +346,11 @@ class GSCService:
|
||||
service = build('searchconsole', 'v1', credentials=credentials)
|
||||
logger.info(f"Authenticated GSC service created for user: {user_id}")
|
||||
return service
|
||||
|
||||
|
||||
except ValueError as e:
|
||||
# Log as warning only, as this is expected for unconnected users
|
||||
# logger.warning(f"Cannot create GSC service for user {user_id}: {e}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating authenticated GSC service for user {user_id}: {e}")
|
||||
raise
|
||||
@@ -291,7 +358,13 @@ class GSCService:
|
||||
def get_site_list(self, user_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get list of sites from GSC."""
|
||||
try:
|
||||
service = self.get_authenticated_service(user_id)
|
||||
try:
|
||||
service = self.get_authenticated_service(user_id)
|
||||
except ValueError:
|
||||
# User not connected or credentials invalid
|
||||
logger.warning(f"User {user_id} not connected to GSC. Returning empty site list.")
|
||||
return []
|
||||
|
||||
sites = service.sites().list().execute()
|
||||
|
||||
site_list = []
|
||||
@@ -306,7 +379,8 @@ class GSCService:
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting site list for user {user_id}: {e}")
|
||||
raise
|
||||
# Return empty list instead of raising to prevent frontend 500s
|
||||
return []
|
||||
|
||||
def get_search_analytics(self, user_id: str, site_url: str,
|
||||
start_date: str = None, end_date: str = None) -> Dict[str, Any]:
|
||||
@@ -325,7 +399,12 @@ class GSCService:
|
||||
logger.info(f"Returning cached analytics data for user: {user_id}")
|
||||
return cached_data
|
||||
|
||||
service = self.get_authenticated_service(user_id)
|
||||
try:
|
||||
service = self.get_authenticated_service(user_id)
|
||||
except ValueError:
|
||||
logger.warning(f"User {user_id} not connected to GSC. Returning empty analytics.")
|
||||
return {'error': 'User not connected to GSC', 'rows': [], 'rowCount': 0}
|
||||
|
||||
if not service:
|
||||
logger.error(f"Failed to get authenticated GSC service for user: {user_id}")
|
||||
return {'error': 'Authentication failed', 'rows': [], 'rowCount': 0}
|
||||
@@ -359,11 +438,11 @@ class GSCService:
|
||||
logger.error(f"GSC Data verification failed for user {user_id}: {verification_error}")
|
||||
return {'error': f'Data verification failed: {str(verification_error)}', 'rows': [], 'rowCount': 0}
|
||||
|
||||
# Step 2: Get overall metrics (no dimensions)
|
||||
# Step 2: Get daily metrics for charting (ensure we have rows)
|
||||
request = {
|
||||
'startDate': start_date,
|
||||
'endDate': end_date,
|
||||
'dimensions': [], # No dimensions for overall metrics
|
||||
'dimensions': ['date'], # Use date dimension to get time-series data
|
||||
'rowLimit': 1000
|
||||
}
|
||||
|
||||
@@ -472,7 +551,11 @@ class GSCService:
|
||||
def revoke_user_access(self, user_id: str) -> bool:
|
||||
"""Revoke user's GSC access."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
db_path = self._get_db_path(user_id)
|
||||
if not os.path.exists(db_path):
|
||||
return True
|
||||
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Delete credentials
|
||||
@@ -496,7 +579,11 @@ class GSCService:
|
||||
def clear_incomplete_credentials(self, user_id: str) -> bool:
|
||||
"""Clear incomplete GSC credentials that are missing required fields."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
db_path = self._get_db_path(user_id)
|
||||
if not os.path.exists(db_path):
|
||||
return True
|
||||
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('DELETE FROM gsc_credentials WHERE user_id = ?', (user_id,))
|
||||
conn.commit()
|
||||
@@ -511,7 +598,11 @@ class GSCService:
|
||||
def _get_cached_data(self, user_id: str, site_url: str, data_type: str, cache_key: str) -> Optional[Dict]:
|
||||
"""Get cached data if not expired."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
db_path = self._get_db_path(user_id)
|
||||
if not os.path.exists(db_path):
|
||||
return None
|
||||
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT data_json FROM gsc_data_cache
|
||||
@@ -531,9 +622,12 @@ class GSCService:
|
||||
def _cache_data(self, user_id: str, site_url: str, data_type: str, data: Dict, cache_key: str):
|
||||
"""Cache data with expiration."""
|
||||
try:
|
||||
self._init_gsc_tables(user_id)
|
||||
db_path = self._get_db_path(user_id)
|
||||
|
||||
expires_at = datetime.now() + timedelta(hours=1) # Cache for 1 hour
|
||||
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
INSERT OR REPLACE INTO gsc_data_cache
|
||||
|
||||
@@ -24,7 +24,16 @@ class WordPressOAuthService:
|
||||
# WordPress.com OAuth2 credentials
|
||||
self.client_id = os.getenv('WORDPRESS_CLIENT_ID', '')
|
||||
self.client_secret = os.getenv('WORDPRESS_CLIENT_SECRET', '')
|
||||
self.redirect_uri = os.getenv('WORDPRESS_REDIRECT_URI', 'https://alwrity-ai.vercel.app/wp/callback')
|
||||
|
||||
# Determine redirect URI dynamically
|
||||
default_redirect = 'https://alwrity-ai.vercel.app/wp/callback'
|
||||
frontend_url = os.getenv('FRONTEND_URL')
|
||||
|
||||
if frontend_url:
|
||||
self.redirect_uri = f"{frontend_url.rstrip('/')}/wp/callback"
|
||||
else:
|
||||
self.redirect_uri = os.getenv('WORDPRESS_REDIRECT_URI', default_redirect)
|
||||
|
||||
self.base_url = "https://public-api.wordpress.com"
|
||||
|
||||
# Validate configuration
|
||||
|
||||
@@ -17,8 +17,7 @@ from .core_agent_framework import (
|
||||
# Market signal detection
|
||||
from .market_signal_detector import (
|
||||
MarketSignal,
|
||||
MarketSignalDetector,
|
||||
MarketTrendAnalyzer
|
||||
MarketSignalDetector
|
||||
)
|
||||
|
||||
# Performance monitoring
|
||||
|
||||
@@ -105,6 +105,18 @@ class ALwrityAgentOrchestrator:
|
||||
def _create_specialized_agents(self):
|
||||
"""Create specialized marketing agents"""
|
||||
try:
|
||||
# Check if onboarding is complete before initializing heavy agents
|
||||
try:
|
||||
from services.onboarding.progress_service import OnboardingProgressService
|
||||
onboarding_service = OnboardingProgressService()
|
||||
status = onboarding_service.get_onboarding_status(self.user_id)
|
||||
if not status.get("is_completed", False):
|
||||
logger.info(f"Skipping agent initialization for user {self.user_id} - Onboarding incomplete")
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not check onboarding status for {self.user_id}: {e}")
|
||||
# Fallthrough to attempt initialization if check fails
|
||||
|
||||
enabled_by_key = {}
|
||||
db = None
|
||||
try:
|
||||
@@ -159,6 +171,26 @@ class ALwrityAgentOrchestrator:
|
||||
self.trend_surfer_agent = TrendSurferAgent(intel_service, self.user_id)
|
||||
self.agents['trend'] = self.trend_surfer_agent
|
||||
|
||||
# Content Guardian Agent
|
||||
if enabled_by_key.get("content_guardian", True):
|
||||
try:
|
||||
from services.intelligence.sif_agents import ContentGuardianAgent
|
||||
from services.intelligence.txtai_service import TxtaiIntelligenceService
|
||||
|
||||
# Initialize intelligence service if not already available
|
||||
intel_service = TxtaiIntelligenceService(self.user_id)
|
||||
|
||||
# Initialize Content Guardian Agent
|
||||
self.content_guardian_agent = ContentGuardianAgent(
|
||||
intelligence_service=intel_service,
|
||||
user_id=self.user_id,
|
||||
sif_service=None # SIF service is optional/circular dependency handling
|
||||
)
|
||||
self.agents['guardian'] = self.content_guardian_agent
|
||||
logger.info(f"Initialized ContentGuardianAgent for user {self.user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize ContentGuardianAgent: {e}")
|
||||
|
||||
logger.info(f"Created {len(self.agents)} specialized agents for user {self.user_id}")
|
||||
|
||||
except Exception as e:
|
||||
|
||||
213
backend/services/intelligence/agents/agent_usage_tracking.py
Normal file
213
backend/services/intelligence/agents/agent_usage_tracking.py
Normal file
@@ -0,0 +1,213 @@
|
||||
import logging
|
||||
import time
|
||||
from datetime import datetime
|
||||
from sqlalchemy import text
|
||||
from services.database import get_session_for_user
|
||||
from models.subscription_models import APIProvider, UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def _detect_provider(model_name):
    """Map a model name to ``(APIProvider member, actual provider label)``.

    Matching is by case-insensitive substring, checked in this order:
    "gemini"; then "gpt" / "openai" / "mistral"; then "claude" / "anthropic".
    Anything else (including an empty or missing name) falls back to Gemini.
    """
    model_lower = str(model_name or "").lower()
    if "gemini" in model_lower:
        return APIProvider.GEMINI, "gemini"
    if any(tag in model_lower for tag in ("gpt", "openai", "mistral")):
        # HuggingFace/Mistral often mapped to gpt-oss or mistral
        return APIProvider.MISTRAL, "huggingface"
    if "claude" in model_lower or "anthropic" in model_lower:
        return APIProvider.ANTHROPIC, "anthropic"
    return APIProvider.GEMINI, "gemini"  # Default


def _estimate_tokens(text_value):
    """Estimate a token count as 1.3 x the whitespace-separated word count.

    ``None`` counts as zero tokens; any other value is coerced with ``str()``
    so a non-string prompt or response cannot crash the tracker.
    """
    if text_value is None:
        return 0
    return int(len(str(text_value).split()) * 1.3)


def track_agent_usage_sync(user_id: str, model_name: str, prompt: str, response_text: str, duration: float):
    """
    Synchronously track agent LLM usage.
    This mimics the logic in llm_text_gen to ensure consistency and robustness.

    Args:
        user_id: User whose usage summary and usage log are updated.
        model_name: Model identifier; used to detect the provider and stored
            as ``model_used`` in the usage log.
        prompt: Prompt sent to the model (token count is estimated from it).
        response_text: Model output (token count is estimated from it).
        duration: Response time, stored as ``response_time`` in the usage log.

    Returns:
        None. This is best-effort bookkeeping: exceptions are logged rather
        than propagated, and the session is rolled back on failure.
    """
    try:
        provider_enum, actual_provider_name = _detect_provider(model_name)
        provider_key = provider_enum.value

        logger.info(f"[AgentTracking] Tracking usage for user {user_id}, provider {actual_provider_name}, model {model_name}")

        db = get_session_for_user(user_id)
        if not db:
            logger.error(f"[AgentTracking] Could not get database session for user {user_id}")
            return

        try:
            # Estimate tokens
            tokens_input = _estimate_tokens(prompt)
            tokens_output = _estimate_tokens(response_text)
            tokens_total = tokens_input + tokens_output

            pricing = PricingService(db)
            current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
            period_params = {'user_id': user_id, 'period': current_period}

            # Get limits
            limits = pricing.get_user_limits(user_id)
            token_limit = 0
            if limits and limits.get('limits'):
                token_limit = limits['limits'].get(f"{provider_key}_tokens", 0) or 0

            # Read the current counters; a missing row means no summary exists
            # yet for this billing period (one SELECT instead of COUNT + SELECT).
            # provider_key comes from the APIProvider enum, not from user input,
            # so interpolating it into the column name is safe.
            current_row = db.execute(
                text(f"""
                    SELECT {provider_key}_calls, {provider_key}_tokens
                    FROM usage_summaries
                    WHERE user_id = :user_id AND billing_period = :period
                    LIMIT 1
                """),
                period_params,
            ).first()

            current_calls_before = 0
            current_tokens_before = 0
            if current_row is None:
                # Create new summary
                summary = UsageSummary(user_id=user_id, billing_period=current_period)
                db.add(summary)
                db.flush()
            else:
                current_calls_before = current_row[0] if current_row[0] is not None else 0
                current_tokens_before = current_row[1] if current_row[1] is not None else 0

            # Update calls atomically, matching how total_calls is incremented
            # below, so concurrent trackers cannot overwrite each other's count.
            new_calls = current_calls_before + 1  # used for the log line only
            db.execute(
                text(f"""
                    UPDATE usage_summaries
                    SET {provider_key}_calls = COALESCE({provider_key}_calls, 0) + 1
                    WHERE user_id = :user_id AND billing_period = :period
                """),
                period_params,
            )

            # Update tokens with limit check
            if provider_enum in (APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL):
                projected_new_tokens = current_tokens_before + tokens_total

                if token_limit > 0 and projected_new_tokens > token_limit:
                    # Cap the stored counter at the limit and only account for
                    # the tokens that still fit under it.
                    new_tokens = token_limit
                    tokens_total = max(0, token_limit - current_tokens_before)
                else:
                    new_tokens = projected_new_tokens

                db.execute(
                    text(f"""
                        UPDATE usage_summaries
                        SET {provider_key}_tokens = :new_tokens
                        WHERE user_id = :user_id AND billing_period = :period
                    """),
                    {'new_tokens': new_tokens, **period_params},
                )
            else:
                tokens_total = 0

            # Split the (possibly capped) total back into input/output parts,
            # filling the input share first.
            tracked_tokens_input = min(tokens_input, tokens_total)
            tracked_tokens_output = max(0, tokens_total - tracked_tokens_input)

            # Calculate cost
            try:
                cost_info = pricing.calculate_api_cost(
                    provider=provider_enum,
                    model_name=model_name,
                    tokens_input=tracked_tokens_input,
                    tokens_output=tracked_tokens_output,
                    request_count=1
                )
                cost_total = cost_info.get('cost_total', 0.0) or 0.0
                cost_input = cost_info.get('cost_input', 0.0) or 0.0
                cost_output = cost_info.get('cost_output', 0.0) or 0.0
            except Exception as e:
                logger.error(f"[AgentTracking] Cost calculation failed: {e}")
                cost_total = 0.0
                cost_input = 0.0
                cost_output = 0.0

            # Insert into APIUsageLog
            # NOTE(review): a failed INSERT is swallowed so the summary updates
            # can still commit; on backends that abort the transaction after a
            # failed statement this would need a savepoint - confirm backend.
            try:
                log_query = text("""
                    INSERT INTO api_usage_logs (
                        user_id, provider, endpoint, method, model_used,
                        tokens_input, tokens_output, tokens_total,
                        cost_input, cost_output, cost_total,
                        response_time, status_code, billing_period,
                        timestamp, actual_provider_name
                    ) VALUES (
                        :user_id, :provider, :endpoint, :method, :model_used,
                        :tokens_input, :tokens_output, :tokens_total,
                        :cost_input, :cost_output, :cost_total,
                        :response_time, :status_code, :billing_period,
                        :created_at, :actual_provider_name
                    )
                """)

                db.execute(log_query, {
                    'user_id': user_id,
                    'provider': provider_enum.name,  # Use name (GEMINI) not value (gemini) for SQLAlchemy Enum
                    'endpoint': 'agent_action',
                    'method': 'GENERATE',
                    'model_used': model_name,
                    'tokens_input': tracked_tokens_input,
                    'tokens_output': tracked_tokens_output,
                    'tokens_total': tracked_tokens_input + tracked_tokens_output,
                    'cost_input': cost_input,
                    'cost_output': cost_output,
                    'cost_total': cost_total,
                    'response_time': duration,
                    'status_code': 200,
                    'billing_period': current_period,
                    'created_at': datetime.utcnow(),
                    'actual_provider_name': actual_provider_name
                })
            except Exception as log_e:
                logger.error(f"[AgentTracking] Failed to insert usage log: {log_e}")

            if cost_total > 0:
                db.execute(
                    text(f"""
                        UPDATE usage_summaries
                        SET {provider_key}_cost = COALESCE({provider_key}_cost, 0) + :cost,
                            total_cost = COALESCE(total_cost, 0) + :cost
                        WHERE user_id = :user_id AND billing_period = :period
                    """),
                    {'cost': cost_total, **period_params},
                )

            # Update totals
            db.execute(
                text("""
                    UPDATE usage_summaries
                    SET total_calls = COALESCE(total_calls, 0) + 1,
                        total_tokens = COALESCE(total_tokens, 0) + :tokens_total
                    WHERE user_id = :user_id AND billing_period = :period
                """),
                {'tokens_total': tokens_total, **period_params},
            )

            db.commit()
            logger.info(f"[AgentTracking] ✅ Usage tracked: {new_calls} calls, {cost_total} cost")

        except Exception as e:
            logger.error(f"[AgentTracking] Error tracking usage: {e}", exc_info=True)
            db.rollback()
        finally:
            db.close()

    except Exception as e:
        logger.error(f"[AgentTracking] Top level error: {e}", exc_info=True)
|
||||
@@ -32,9 +32,64 @@ from services.database import get_session_for_user
|
||||
from services.intelligence.monitoring.semantic_dashboard import RealTimeSemanticMonitor
|
||||
from services.intelligence.agents.safety_framework import get_safety_framework
|
||||
from services.agent_activity_service import AgentActivityService
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
import time
|
||||
|
||||
logger = get_service_logger(__name__)
|
||||
|
||||
class TrackingLLMWrapper:
    """
    Wrapper for LLM instances to transparently track usage.
    Intercepts calls to __call__ and generate() to log metrics.

    Any other attribute access is delegated to the wrapped LLM, so the wrapper
    can be used wherever the raw LLM object was used.
    """

    def __init__(self, llm: Any, user_id: str, model_name: str):
        self.llm = llm
        self.user_id = user_id
        self.model_name = model_name

    def __call__(self, prompt: str, *args, **kwargs) -> Any:
        return self.generate(prompt, *args, **kwargs)

    @staticmethod
    def _extract_text(response: Any) -> str:
        """Return the text used for usage accounting.

        A list response is represented by its first element (its
        ``generated_text`` entry when that element is a dict carrying one);
        an empty list counts as empty text. Anything else is stringified.
        """
        if isinstance(response, list):
            if not response:
                return ""
            first = response[0]
            if isinstance(first, dict) and 'generated_text' in first:
                return str(first['generated_text'])
            return str(first)
        return str(response)

    def generate(self, prompt: str, *args, **kwargs) -> Any:
        """Delegate to the wrapped LLM, record usage, and return its raw response.

        Errors raised by the wrapped LLM are logged and re-raised unchanged.
        Errors in the usage bookkeeping are logged and never affect the result.
        """
        start_time = time.perf_counter()
        try:
            # Delegate to the underlying LLM
            if hasattr(self.llm, "generate"):
                response = self.llm.generate(prompt, *args, **kwargs)
            else:
                response = self.llm(prompt, *args, **kwargs)
        except Exception as e:
            logger.error(f"LLM generation failed in tracking wrapper: {e}")
            raise  # bare raise keeps the original traceback intact

        # Track usage. Response normalisation happens inside this guard so a
        # bookkeeping problem cannot turn a successful generation into a failure.
        duration = time.perf_counter() - start_time
        try:
            track_agent_usage_sync(
                user_id=self.user_id,
                model_name=self.model_name,
                prompt=prompt,
                response_text=self._extract_text(response),
                duration=duration
            )
        except Exception as e:
            logger.warning(f"Failed to track agent usage in wrapper: {e}")

        return response

    def __getattr__(self, name):
        # Only called when normal lookup fails. If 'llm' itself is missing
        # (instance built via __new__, copy or unpickling before __init__ ran),
        # delegating would look up self.llm again and recurse forever.
        if name == "llm":
            raise AttributeError(name)
        # Delegate other attribute access to the underlying LLM
        return getattr(self.llm, name)
|
||||
|
||||
@dataclass
|
||||
class AgentAction:
|
||||
"""Represents an action taken by an agent"""
|
||||
@@ -114,6 +169,10 @@ class BaseALwrityAgent(ABC):
|
||||
self.txtai_agent = None
|
||||
self.llm = llm # Ensure llm is set if provided, regardless of txtai availability
|
||||
|
||||
# Wrap LLM with tracking if it exists
|
||||
if self.llm:
|
||||
self.llm = TrackingLLMWrapper(self.llm, self.user_id, self.model_name)
|
||||
|
||||
self.agent_key = self._resolve_agent_key(agent_type)
|
||||
self._agent_profile = self._load_agent_profile_overrides()
|
||||
self._prompt_context = self._load_prompt_context()
|
||||
@@ -121,10 +180,17 @@ class BaseALwrityAgent(ABC):
|
||||
if TXTAI_AVAILABLE:
|
||||
try:
|
||||
if not self.llm:
|
||||
self.llm = LLM(model_name)
|
||||
|
||||
self.txtai_agent = self._create_txtai_agent()
|
||||
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
|
||||
# Create new LLM if not provided
|
||||
raw_llm = LLM(model_name)
|
||||
# Wrap it
|
||||
self.llm = TrackingLLMWrapper(raw_llm, self.user_id, self.model_name)
|
||||
|
||||
try:
|
||||
self.txtai_agent = self._create_txtai_agent()
|
||||
logger.info(f"Initialized txtai agent for {agent_type} - {self.agent_id}")
|
||||
except Exception as inner_e:
|
||||
logger.warning(f"Could not initialize specific txtai agent for {agent_type}: {inner_e}")
|
||||
self.txtai_agent = self._create_fallback_agent()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize txtai agent for {agent_type}: {e}")
|
||||
self.txtai_agent = self._create_fallback_agent()
|
||||
@@ -134,6 +200,38 @@ class BaseALwrityAgent(ABC):
|
||||
# Initialize safety framework
|
||||
self.safety_framework = get_safety_framework(user_id)
|
||||
|
||||
async def _generate_llm_response(self, prompt: str) -> str:
    """
    Helper to generate text using the agent's LLM with usage tracking.
    Centralized method for all agents inheriting from BaseALwrityAgent.

    Args:
        prompt: Prompt passed to the LLM.

    Returns:
        The generated text. Returns "[LLM Unavailable]" when the agent has no
        LLM, and "[Generation Failed]" when generation raises or yields an
        empty list; this method does not propagate generation errors.
    """
    if not self.llm:
        return "[LLM Unavailable]"

    try:
        # We are inside a coroutine, so the running loop is the right one;
        # get_event_loop() is deprecated for this use.
        loop = asyncio.get_running_loop()

        # Use the LLM's generate method when it has one, otherwise call it.
        # Run in the default executor to avoid blocking if LLM is synchronous.
        if hasattr(self.llm, "generate"):
            target = self.llm.generate
        else:
            target = self.llm
        response = await loop.run_in_executor(None, target, prompt)

        # Handle list output (some models return list of dicts)
        if isinstance(response, list):
            if not response:
                # Previously this reached the same sentinel through an
                # accidental IndexError; make the outcome explicit.
                logger.error(f"LLM generation failed in agent {self.agent_type}: empty response list")
                return "[Generation Failed]"
            first = response[0]
            if isinstance(first, dict) and 'generated_text' in first:
                return first['generated_text']
            return str(first)

        return str(response)

    except Exception as e:
        logger.error(f"LLM generation failed in agent {self.agent_type}: {e}")
        return "[Generation Failed]"
|
||||
|
||||
def _resolve_agent_key(self, agent_type: str) -> str:
|
||||
value = str(agent_type or "").strip()
|
||||
if value.lower() == "strategyorchestrator".lower():
|
||||
|
||||
@@ -758,6 +758,11 @@ async def get_agent_performance_summary(user_id: str, agent_id: str) -> Dict[str
|
||||
"""Get comprehensive performance summary for an agent"""
|
||||
return await performance_service.get_agent_performance_summary(user_id, agent_id)
|
||||
|
||||
async def get_all_agents_performance_summary(user_id: str) -> List[Dict[str, Any]]:
|
||||
async def get_all_agents_performance_summary(user_id: str) -> List[Dict[str, Any]]:
|
||||
"""Get performance summary for all agents for a user"""
|
||||
return await performance_service.get_all_agents_performance_summary(user_id)
|
||||
return await performance_service.get_all_agents_performance_summary(user_id)
|
||||
|
||||
# Alias for backward compatibility
|
||||
PerformanceMonitor = AgentPerformanceMonitor
|
||||
performance_monitor = performance_service
|
||||
AgentPerformanceMetrics = AgentPerformanceSnapshot
|
||||
@@ -13,6 +13,7 @@ from loguru import logger
|
||||
from ..txtai_service import TxtaiIntelligenceService
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent, AgentAction
|
||||
from services.seo_tools.content_strategy_service import ContentStrategyService
|
||||
from services.intelligence.sif_agents import SharedLLMWrapper, LocalLLMWrapper
|
||||
try:
|
||||
from services.intelligence.sif_integration import SIFIntegrationService
|
||||
SIF_AVAILABLE = True
|
||||
@@ -20,14 +21,36 @@ except ImportError:
|
||||
SIF_AVAILABLE = False
|
||||
|
||||
try:
|
||||
from txtai import Agent, LLM
|
||||
# Try importing from pipeline first (standard location)
|
||||
from txtai.pipeline import Agent, LLM
|
||||
TXTAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
TXTAI_AVAILABLE = False
|
||||
logger.warning("txtai not available, using fallback implementation")
|
||||
try:
|
||||
# Fallback to top-level import
|
||||
from txtai import Agent, LLM
|
||||
TXTAI_AVAILABLE = True
|
||||
except ImportError:
|
||||
TXTAI_AVAILABLE = False
|
||||
Agent = None
|
||||
LLM = None
|
||||
logger.warning("txtai not available, using fallback implementation")
|
||||
|
||||
class SIFBaseAgent:
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService):
|
||||
class SIFBaseAgent(BaseALwrityAgent):
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, agent_type: str = "sif_agent", model_name: str = "Qwen/Qwen2.5-3B-Instruct", llm: Any = None):
    """Set up the hybrid LLM pair and initialise the base agent.

    A SharedLLMWrapper is always created and kept on ``self.shared_llm``.
    When no ``llm`` is passed, a LocalLLMWrapper for ``model_name`` is used if
    txtai is available; otherwise the shared wrapper doubles as the agent LLM.
    """
    # Shared LLM: external / high-quality generation
    shared = SharedLLMWrapper(user_id)
    self.shared_llm = shared

    # Local LLM: default for internal agent work when the caller gave none
    if llm is None:
        llm = LocalLLMWrapper(model_name) if TXTAI_AVAILABLE else shared

    super().__init__(user_id, agent_type, model_name, llm)
    self.intelligence = intelligence_service
|
||||
|
||||
def _log_agent_operation(self, operation: str, **kwargs):
|
||||
@@ -36,9 +59,27 @@ class SIFBaseAgent:
|
||||
if kwargs:
|
||||
logger.debug(f"[{self.__class__.__name__}] Parameters: {kwargs}")
|
||||
|
||||
def _create_txtai_agent(self):
    """
    Build a minimal txtai Agent around this agent's LLM, with no tools.

    SIF agents use the intelligence service directly; this only exposes the
    LLM through the standard agent interface. Returns None when txtai (or its
    Agent class) is unavailable, or when constructing the Agent fails.
    """
    if TXTAI_AVAILABLE and Agent is not None:
        try:
            # A simple agent that can use the LLM
            return Agent(llm=self.llm, tools=[])
        except Exception as e:
            logger.warning(f"Failed to create txtai Agent: {e}")
    return None
|
||||
|
||||
class StrategyArchitectAgent(SIFBaseAgent):
|
||||
"""Agent for discovering content pillars and identifying strategic gaps."""
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
|
||||
super().__init__(intelligence_service, user_id, agent_type="strategy_architect")
|
||||
|
||||
async def discover_pillars(self) -> List[Dict[str, Any]]:
|
||||
"""Identify content pillars through semantic clustering."""
|
||||
self._log_agent_operation("Discovering content pillars")
|
||||
@@ -108,9 +149,61 @@ class ContentGuardianAgent(SIFBaseAgent):
|
||||
CANNIBALIZATION_THRESHOLD = 0.85 # Similarity threshold for cannibalization warning
|
||||
ORIGINALITY_THRESHOLD = 0.75 # Minimum originality score
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
|
||||
super().__init__(intelligence_service)
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
    """Initialise the guardian agent and resolve its SIF service.

    If no ``sif_service`` is supplied and SIF is available, a
    SIFIntegrationService is created for ``user_id`` here in the constructor;
    a failure to create it is logged and leaves ``self.sif_service`` as None.
    """
    super().__init__(intelligence_service, user_id, agent_type="content_guardian")
    self.sif_service = sif_service

    # Nothing more to do when a service was injected or SIF is not importable
    if self.sif_service is not None or not SIF_AVAILABLE:
        return

    # Fall back to building the SIF service ourselves
    try:
        self.sif_service = SIFIntegrationService(user_id)
        logger.info(f"[{self.__class__.__name__}] Lazily initialized SIFIntegrationService")
    except Exception as e:
        logger.warning(f"[{self.__class__.__name__}] Failed to lazily initialize SIF service: {e}")
|
||||
|
||||
async def assess_content_quality(self, content: str) -> Dict[str, Any]:
    """
    Assess content quality based on originality, readability, and cannibalization risks.

    Args:
        content: Text to assess.

    Returns:
        On success, a dict with ``quality_score`` (0.4 * originality +
        0.3 * style + 0.3 * readability), the three component scores, the raw
        cannibalization and style results, and ``is_acceptable`` (score above
        0.7 and no cannibalization warning). For empty content or on failure,
        a dict with ``error`` and ``quality_score`` 0.0.
    """
    self._log_agent_operation("Assessing content quality", content_length=len(content or ""))

    # Blank input has nothing to assess; without this guard it would be
    # scored as perfectly readable (zero words per sentence).
    if not content or not content.strip():
        return {"error": "Empty content", "quality_score": 0.0}

    try:
        # 1. Check for cannibalization
        cannibalization_result = await self.check_cannibalization(content)

        # 2. Check originality (if not cannibalized)
        originality_score = 1.0
        if not cannibalization_result.get("warning"):
            originality_result = await self.verify_originality(content, None)
            originality_score = originality_result.get("originality_score", 1.0)

        # 3. Check Style Compliance (a missing compliance_score defaults to 1.0)
        style_result = await self.style_enforcer(content)
        style_score = style_result.get("compliance_score", 1.0)

        # 4. Basic readability heuristic from average sentence length.
        # Sentences end at '.', '!' or '?'; empty segments (e.g. the one after
        # a trailing period) are not counted, otherwise the average is deflated.
        words = content.split()
        normalized = content.replace('!', '.').replace('?', '.')
        sentence_count = sum(1 for segment in normalized.split('.') if segment.strip())
        avg_sentence_length = len(words) / max(1, sentence_count)
        if avg_sentence_length < 20:
            readability_score = 1.0
        else:
            # Lose 0.05 per word above 20, floored at 0.5
            readability_score = max(0.5, 1.0 - (avg_sentence_length - 20) * 0.05)

        # Weighted Score: Originality (40%) + Style (30%) + Readability (30%)
        quality_score = (originality_score * 0.4) + (style_score * 0.3) + (readability_score * 0.3)

        return {
            "quality_score": quality_score,
            "originality_score": originality_score,
            "readability_score": readability_score,
            "style_score": style_score,
            "cannibalization_risk": cannibalization_result,
            "style_compliance": style_result,
            "is_acceptable": quality_score > 0.7 and not cannibalization_result.get("warning", False)
        }

    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] Failed to assess content quality: {e}")
        return {"error": str(e), "quality_score": 0.0}
|
||||
|
||||
async def check_cannibalization(self, new_draft: str) -> Dict[str, Any]:
|
||||
"""Check if a new draft competes semantically with existing pages."""
|
||||
@@ -193,25 +286,74 @@ class ContentGuardianAgent(SIFBaseAgent):
|
||||
# 1. Fetch Style Guidelines from SIF if not provided
|
||||
if not style_guidelines and self.sif_service:
|
||||
try:
|
||||
# Search for website analysis to get brand voice/style
|
||||
# We assume the most relevant 'website_analysis' doc contains the guidelines
|
||||
results = await self.intelligence.search("website analysis brand voice style", limit=1)
|
||||
if results:
|
||||
import json
|
||||
res = results[0]
|
||||
metadata_str = res.get('object')
|
||||
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else (metadata_str or res)
|
||||
# Use central SIF service to get robust context
|
||||
seo_context = await self.sif_service.get_seo_context()
|
||||
|
||||
if seo_context and "error" not in seo_context:
|
||||
# Extract brand voice/style from the context
|
||||
# The context structure is normalized in get_seo_context
|
||||
|
||||
if metadata.get('type') == 'website_analysis':
|
||||
report = metadata.get('full_report', {})
|
||||
style_guidelines = {
|
||||
"tone": report.get('brand_analysis', {}).get('brand_voice', 'neutral'),
|
||||
"style_patterns": report.get('style_patterns', {}),
|
||||
"writing_style": report.get('writing_style', {})
|
||||
}
|
||||
logger.info(f"[{self.__class__.__name__}] Retrieved style guidelines from SIF: {style_guidelines.get('tone')}")
|
||||
# NOTE: get_seo_context() returns a filtered/flattened dict ('seo_audit',
# 'sitemap_analysis', 'crawl_result', ...). Brand information is usually stored
# in WebsiteAnalysis.brand_analysis, which may not be fully exposed in that
# simplified result.
# The proper fix is to extend get_seo_context() in SIFIntegrationService so it
# stays the single point of truth, rather than working around it here.
# Until then, keep the earlier manual index search for brand voice/style, but
# with better error handling, mirroring get_seo_context's robustness (e.g. parsing).
|
||||
|
||||
results = await self.intelligence.search("website analysis brand voice style", limit=1)
|
||||
if results:
|
||||
res = results[0]
|
||||
metadata_str = res.get('object')
|
||||
metadata = json.loads(metadata_str) if isinstance(metadata_str, str) else (metadata_str or res)
|
||||
|
||||
if metadata.get('type') == 'website_analysis':
|
||||
report = metadata.get('full_report', {})
|
||||
# Support both flat and nested structures
|
||||
brand_analysis = report.get('brand_analysis') or report.get('brand_voice', {})
|
||||
if isinstance(brand_analysis, str):
|
||||
# Handle case where it might be a JSON string
|
||||
try: brand_analysis = json.loads(brand_analysis)
|
||||
except: brand_analysis = {"brand_voice": brand_analysis}
|
||||
|
||||
style_guidelines = {
|
||||
"tone": brand_analysis.get('brand_voice', 'neutral') if isinstance(brand_analysis, dict) else 'neutral',
|
||||
"style_patterns": report.get('style_patterns', {}),
|
||||
"writing_style": report.get('writing_style', {})
|
||||
}
|
||||
logger.info(f"[{self.__class__.__name__}] Retrieved style guidelines from SIF index")
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.__class__.__name__}] Failed to retrieve style guidelines from SIF: {e}")
|
||||
logger.warning(f"[{self.__class__.__name__}] Failed to retrieve style guidelines: {e}")
|
||||
|
||||
issues = []
|
||||
score = 1.0
|
||||
@@ -246,6 +388,55 @@ class ContentGuardianAgent(SIFBaseAgent):
|
||||
logger.error(f"[{self.__class__.__name__}] Style enforcement failed: {e}")
|
||||
return {"error": str(e)}
|
||||
|
||||
async def perform_site_audit(self, website_url: str, limit: int = 10) -> Dict[str, Any]:
    """
    Perform a quality audit on the user's website content.

    Searches the intelligence index for ``site:<website_url>`` (at most
    ``limit`` hits), runs assess_content_quality on every hit whose text is at
    least 100 characters long, and averages the resulting quality scores.

    Returns:
        A report dict with ``website_url``, ``pages_audited``,
        ``average_quality_score``, per-page ``details`` and a UTC ``timestamp``;
        or ``{"error": ...}`` when nothing is found or the audit fails.
    """
    self._log_agent_operation("Performing site audit", website_url=website_url)

    try:
        # Note: this query only works if indexed documents are searchable by
        # site URL; that depends on how the data was indexed.
        hits = await self.intelligence.search(f"site:{website_url}", limit=limit)
        if not hits:
            logger.info(f"[{self.__class__.__name__}] No content found for site audit")
            return {"error": "No content found"}

        page_reports = []
        score_sum = 0.0
        for hit in hits:
            body = hit.get('text', '')
            # Only audit hits with a meaningful amount of text
            if body and len(body) >= 100:
                assessment = await self.assess_content_quality(body)
                page_reports.append({
                    "id": hit.get('id'),
                    "title": hit.get('title', 'Unknown'),
                    "quality": assessment
                })
                score_sum += assessment.get('quality_score', 0.0)

        if page_reports:
            avg_quality = score_sum / len(page_reports)
        else:
            avg_quality = 0.0

        report = {
            "website_url": website_url,
            "pages_audited": len(page_reports),
            "average_quality_score": avg_quality,
            "details": page_reports,
            "timestamp": datetime.utcnow().isoformat()
        }
        logger.info(f"[{self.__class__.__name__}] Site audit completed. Avg Quality: {avg_quality:.2f}")
        return report

    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] Site audit failed: {e}")
        return {"error": str(e)}
|
||||
|
||||
async def safety_filter(self, text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Tool: Flags potentially harmful, offensive, or sensitive content.
|
||||
@@ -290,8 +481,8 @@ class LinkGraphAgent(SIFBaseAgent):
|
||||
RELEVANCE_THRESHOLD = 0.6 # Minimum relevance score for link suggestions
|
||||
MAX_SUGGESTIONS = 10 # Maximum number of link suggestions
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
|
||||
super().__init__(intelligence_service)
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
|
||||
super().__init__(intelligence_service, user_id, agent_type="link_graph")
|
||||
self.sif_service = sif_service
|
||||
|
||||
async def suggest_internal_links(self, draft: str) -> List[Dict[str, Any]]:
|
||||
@@ -823,9 +1014,10 @@ class ContentStrategyAgent(BaseALwrityAgent):
|
||||
Maintain the original meaning and tone.
|
||||
"""
|
||||
|
||||
if hasattr(self.llm, "generate"):
|
||||
if self.llm:
|
||||
# We assume the LLM returns JSON-like text or we parse it
|
||||
response = self.llm.generate(f"{system_prompt}\n\nText to rewrite:\n{content}")
|
||||
response = await self._generate_llm_response(f"{system_prompt}\n\nText to rewrite:\n{content}")
|
||||
|
||||
# Simple parsing fallback if LLM returns raw text
|
||||
if isinstance(response, str) and not response.strip().startswith("{"):
|
||||
optimized_content = response
|
||||
@@ -1456,34 +1648,7 @@ class SEOOptimizationAgent(BaseALwrityAgent):
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
async def _generate_llm_response(self, prompt: str) -> str:
    """Helper to generate text using the agent's LLM.

    Returns the generated text, "[LLM Unavailable]" when the agent has no LLM,
    or "[Generation Failed]" when generation raises or yields an empty list.
    Generation errors are logged, not propagated.
    """
    if not self.llm:
        return "[LLM Unavailable]"

    try:
        # Inside a coroutine the running loop is the correct one to use;
        # get_event_loop() is deprecated for this purpose.
        loop = asyncio.get_running_loop()

        # Some txtai pipelines use generate, some are just called.
        # Run in the default executor to avoid blocking if LLM is synchronous.
        if hasattr(self.llm, "generate"):
            target = self.llm.generate
        else:
            target = self.llm
        response = await loop.run_in_executor(None, target, prompt)

        # Handle list output (some models return list of dicts)
        if isinstance(response, list):
            if not response:
                # Same sentinel as before, but reached explicitly instead of
                # through an IndexError on response[0].
                logger.error("LLM generation failed: empty response list")
                return "[Generation Failed]"
            first = response[0]
            if isinstance(first, dict) and 'generated_text' in first:
                return first['generated_text']
            return str(first)

        return str(response)
    except Exception as e:
        logger.error(f"LLM generation failed: {e}")
        return "[Generation Failed]"
|
||||
|
||||
|
||||
async def _strategy_generator_tool(self, context: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""SEO strategy generation tool"""
|
||||
audit_results = context.get("audit_results", {})
|
||||
@@ -1629,8 +1794,8 @@ class SocialAmplificationAgent(BaseALwrityAgent):
|
||||
Return ONLY the adapted content.
|
||||
"""
|
||||
|
||||
if hasattr(self.llm, "generate"):
|
||||
adapted_content = self.llm.generate(prompt)
|
||||
if self.llm:
|
||||
adapted_content = await self._generate_llm_response(prompt)
|
||||
else:
|
||||
adapted_content = f"[Mock {platform}]: {content[:50]}... #adapted"
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ class TrendSurferAgent(SIFBaseAgent):
|
||||
"""
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
|
||||
super().__init__(intelligence_service)
|
||||
super().__init__(intelligence_service, user_id, agent_type="trend_surfer")
|
||||
self.user_id = user_id
|
||||
self.signal_detector = MarketSignalDetector(user_id)
|
||||
self.trends_service = GoogleTrendsService()
|
||||
@@ -148,15 +148,41 @@ class TrendSurferAgent(SIFBaseAgent):
|
||||
else:
|
||||
recommendation = "Create new content"
|
||||
|
||||
# Use LLM to generate creative angle
|
||||
headline = f"Trend: {trend.description}"
|
||||
angle = f"Leverage {trend.source} trend on {trend.related_topics[0] if trend.related_topics else 'topic'}"
|
||||
|
||||
try:
|
||||
prompt = f"""
|
||||
Analyze this market trend signal and propose a content angle:
|
||||
Trend: {trend.description}
|
||||
Related Topics: {', '.join(trend.related_topics)}
|
||||
Impact Score: {trend.impact_score}
|
||||
Recommendation: {recommendation}
|
||||
|
||||
Provide a catchy headline and a 1-sentence strategic angle.
|
||||
Format: Headline | Angle
|
||||
"""
|
||||
response = await self._generate_llm_response(prompt)
|
||||
if response and "|" in response:
|
||||
parts = response.split('|')
|
||||
headline = parts[0].strip()
|
||||
angle = parts[1].strip()
|
||||
elif response:
|
||||
angle = response.strip()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.__class__.__name__}] LLM generation failed for opportunity: {e}")
|
||||
|
||||
return {
|
||||
"trend_id": trend.signal_id,
|
||||
"topic": trend.description,
|
||||
"headline": headline,
|
||||
"source": trend.source,
|
||||
"urgency": trend.urgency_level.value,
|
||||
"impact_score": trend.impact_score,
|
||||
"current_coverage": coverage_score,
|
||||
"recommendation": recommendation,
|
||||
"suggested_angle": f"Leverage {trend.source} trend on {trend.related_topics[0] if trend.related_topics else 'topic'}",
|
||||
"suggested_angle": angle,
|
||||
"detected_at": trend.detected_at
|
||||
}
|
||||
|
||||
|
||||
@@ -5,13 +5,76 @@ Each agent leverages TxtaiIntelligenceService for semantic operations.
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import json
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Optional
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from .txtai_service import TxtaiIntelligenceService
|
||||
from .txtai_service import TxtaiIntelligenceService, TXTAI_AVAILABLE
|
||||
from services.intelligence.agents.core_agent_framework import BaseALwrityAgent
|
||||
from services.llm_providers.main_text_generation import llm_text_gen
|
||||
|
||||
class SIFBaseAgent:
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService):
|
||||
# Optional txtai imports
|
||||
try:
|
||||
from txtai.pipeline import Agent, LLM
|
||||
except ImportError:
|
||||
Agent = None
|
||||
LLM = None
|
||||
|
||||
class SharedLLMWrapper:
    """Wraps the shared ALwrity LLM service to look like a txtai LLM.

    The wrapper exposes both ``generate`` and ``__call__`` so it can be used
    by code that checks for a ``generate`` method as well as by code that
    simply calls the LLM object.
    """

    def __init__(self, user_id: str):
        # Forwarded to ``llm_text_gen`` on every generation call.
        self.user_id = user_id

    def generate(self, prompt: str, **kwargs) -> str:
        """Generate text using the shared LLM provider.

        Args:
            prompt: Prompt text passed to ``llm_text_gen``.
            **kwargs: Accepted for interface compatibility (e.g.
                ``max_tokens``) but intentionally NOT forwarded;
                ``llm_text_gen`` applies its own defaults.

        Returns:
            Whatever ``llm_text_gen`` returns for the prompt.
        """
        return llm_text_gen(prompt, user_id=self.user_id)

    def __call__(self, prompt: str, **kwargs) -> str:
        """Make the wrapper callable; delegates to :meth:`generate`."""
        return self.generate(prompt, **kwargs)
|
||||
|
||||
class LocalLLMWrapper:
    """
    Lazily loads a local LLM via txtai.
    This prevents blocking server startup with heavy model loads.

    The model is created on the first access of :attr:`llm` and cached on the
    instance afterwards.
    """

    def __init__(self, model_path: str):
        self.model_path = model_path
        # Filled in on first access of ``llm``; None means "not loaded yet".
        self._llm = None

    @property
    def llm(self):
        """Return the underlying txtai LLM, instantiating it on first use.

        Raises:
            ImportError: if ``txtai.pipeline.LLM`` could not be imported
                (the module-level fallback leaves ``LLM`` as ``None``).
        """
        if self._llm is None:
            if LLM is None:
                raise ImportError("txtai.pipeline.LLM is not available")
            logger.info(f"Loading local LLM: {self.model_path}")
            self._llm = LLM(path=self.model_path)
        return self._llm

    def __call__(self, prompt: str, **kwargs) -> str:
        """Run the prompt through the (lazily loaded) LLM."""
        return self.llm(prompt, **kwargs)

    def generate(self, prompt: str, **kwargs) -> str:
        """Alias of ``__call__`` for callers that expect a ``generate`` method."""
        return self(prompt, **kwargs)
|
||||
|
||||
class SIFBaseAgent(BaseALwrityAgent):
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, agent_type: str = "sif_agent", model_name: str = "Qwen/Qwen2.5-3B-Instruct", llm: Any = None):
|
||||
# Hybrid LLM Strategy:
|
||||
# 1. Shared LLM for external/high-quality generation (available to all agents)
|
||||
self.shared_llm = SharedLLMWrapper(user_id)
|
||||
|
||||
# 2. Local LLM for internal agent work (default for SIF agents)
|
||||
if llm is None:
|
||||
if TXTAI_AVAILABLE:
|
||||
# Use Lazy Local LLM
|
||||
llm = LocalLLMWrapper(model_name)
|
||||
else:
|
||||
# Fallback to Shared if txtai not available
|
||||
llm = self.shared_llm
|
||||
|
||||
super().__init__(user_id, agent_type, model_name, llm)
|
||||
self.intelligence = intelligence_service
|
||||
|
||||
def _log_agent_operation(self, operation: str, **kwargs):
|
||||
@@ -20,9 +83,23 @@ class SIFBaseAgent:
|
||||
if kwargs:
|
||||
logger.debug(f"[{self.__class__.__name__}] Parameters: {kwargs}")
|
||||
|
||||
def _create_txtai_agent(self):
    """
    SIF agents use the intelligence service directly, but we can expose
    capabilities via a standard agent interface if needed.

    Returns:
        A txtai ``Agent`` built around ``self.llm`` with no tools, or ``None``
        when txtai (or its ``Agent`` pipeline) is unavailable.
    """
    # TXTAI_AVAILABLE is imported from txtai_service, while Agent comes from
    # this module's own optional import and is None when that import failed.
    # Both must hold, otherwise Agent(...) would raise TypeError.
    if not TXTAI_AVAILABLE or Agent is None:
        return None

    # Return a simple agent that can use the LLM
    return Agent(llm=self.llm, tools=[])
|
||||
|
||||
class StrategyArchitectAgent(SIFBaseAgent):
|
||||
"""Agent for discovering content pillars and identifying strategic gaps."""
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
|
||||
super().__init__(intelligence_service, user_id, agent_type="strategy_architect")
|
||||
|
||||
async def discover_pillars(self) -> List[Dict[str, Any]]:
|
||||
"""Identify content pillars through semantic clustering."""
|
||||
self._log_agent_operation("Discovering content pillars")
|
||||
@@ -58,6 +135,61 @@ class StrategyArchitectAgent(SIFBaseAgent):
|
||||
logger.error(f"[{self.__class__.__name__}] Failed to discover pillars: {e}")
|
||||
logger.error(f"[{self.__class__.__name__}] Full traceback: {traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
async def analyze_content_strategy(self, website_data: Dict[str, Any]) -> List[Dict[str, Any]]:
    """
    Analyze content strategy based on website data and semantic insights.

    Args:
        website_data: Dictionary containing website analysis data. Only the
            'description' key is inspected here; a falsy ``website_data``
            skips the metadata check entirely.

    Returns:
        List of strategic recommendations. Each item is a dict with 'type',
        'priority', 'title' and 'description' (plus 'pillar_id' for
        pillar-specific items). An empty list is returned if any step raises.
    """
    self._log_agent_operation("Analyzing content strategy")

    try:
        recommendations = []

        # 1. Discover existing pillars
        pillars = await self.discover_pillars()

        # 2. Analyze gaps based on pillars (simplified logic for now)
        if not pillars:
            recommendations.append({
                "type": "strategy_gap",
                "priority": "high",
                "title": "Establish Core Content Pillars",
                "description": "No clear content clusters found. Focus on defining 3-5 core topics to build authority."
            })
        else:
            # Suggest strengthening weak pillars (fewer than 3 items)
            for pillar in pillars:
                if pillar['size'] < 3:
                    recommendations.append({
                        "type": "content_depth",
                        "priority": "medium",
                        "title": f"Strengthen Pillar {pillar['pillar_id']}",
                        "description": "This topic cluster has few articles. Create more content to establish authority.",
                        "pillar_id": pillar['pillar_id']
                    })

        # 3. Add generic recommendations based on website data if available
        if website_data and not website_data.get('description'):
            recommendations.append({
                "type": "metadata",
                "priority": "high",
                "title": "Missing Meta Description",
                "description": "Website is missing a meta description. Add one to improve SEO CTR."
            })

        logger.info(f"[{self.__class__.__name__}] Generated {len(recommendations)} strategic recommendations")
        return recommendations

    except Exception as e:
        logger.error(f"[{self.__class__.__name__}] Failed to analyze content strategy: {e}")
        # Keep the traceback so failures here are diagnosable; the pillar
        # discovery error handler in this class logs it the same way.
        logger.error(f"[{self.__class__.__name__}] Full traceback: {traceback.format_exc()}")
        return []
|
||||
|
||||
def _calculate_cluster_confidence(self, cluster_indices: List[int]) -> float:
|
||||
"""Calculate confidence score for a cluster based on its size and coherence."""
|
||||
@@ -92,10 +224,40 @@ class ContentGuardianAgent(SIFBaseAgent):
|
||||
CANNIBALIZATION_THRESHOLD = 0.85 # Similarity threshold for cannibalization warning
|
||||
ORIGINALITY_THRESHOLD = 0.75 # Minimum originality score
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
|
||||
super().__init__(intelligence_service)
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
|
||||
super().__init__(intelligence_service, user_id, agent_type="content_guardian")
|
||||
self.sif_service = sif_service
|
||||
|
||||
async def assess_content_quality(self, website_data: Dict[str, Any]) -> Dict[str, Any]:
    """Assess overall content quality based on website data.

    Args:
        website_data: Website analysis dict. The 'description' value is
            analysed, falling back to 'title' when it is empty or missing.

    Returns:
        On success, a dict with 'score', 'style_analysis', 'safety_analysis'
        and 'analyzed_text_length'. When there is no text to analyse,
        ``{"score": 0.5, "reason": "No content to analyze"}``. On any
        exception, ``{"score": 0.0, "error": <message>}``.
    """
    self._log_agent_operation("Assessing content quality")
    try:
        # Extract sample text or description from website_data
        text_to_analyze = website_data.get('description', '') or website_data.get('title', '')
        if not text_to_analyze:
            return {"score": 0.5, "reason": "No content to analyze"}

        # Run style check
        style_result = await self.style_enforcer(text_to_analyze)

        # Run safety check
        safety_result = await self.safety_filter(text_to_analyze)

        # Calculate aggregate score: start from style compliance (0.8 when
        # the style result carries no score) and halve it if safety flagged.
        base_score = style_result.get('compliance_score', 0.8)
        if safety_result.get('action') == 'flag_for_review':
            base_score *= 0.5

        return {
            "score": base_score,
            "style_analysis": style_result,
            "safety_analysis": safety_result,
            "analyzed_text_length": len(text_to_analyze)
        }
    except Exception as e:
        # Best-effort: callers receive a zero score plus the error text
        # instead of an exception.
        logger.error(f"[{self.__class__.__name__}] Quality assessment failed: {e}")
        return {"score": 0.0, "error": str(e)}
|
||||
|
||||
async def check_cannibalization(self, new_draft: str) -> Dict[str, Any]:
|
||||
"""Check if a new draft competes semantically with existing pages."""
|
||||
self._log_agent_operation("Checking for semantic cannibalization", draft_length=len(new_draft))
|
||||
@@ -274,8 +436,8 @@ class LinkGraphAgent(SIFBaseAgent):
|
||||
RELEVANCE_THRESHOLD = 0.6 # Minimum relevance score for link suggestions
|
||||
MAX_SUGGESTIONS = 10 # Maximum number of link suggestions
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, sif_service: Any = None):
|
||||
super().__init__(intelligence_service)
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str, sif_service: Any = None):
|
||||
super().__init__(intelligence_service, user_id, agent_type="link_graph")
|
||||
self.sif_service = sif_service
|
||||
|
||||
async def suggest_internal_links(self, draft: str) -> List[Dict[str, Any]]:
|
||||
@@ -479,6 +641,9 @@ class CitationExpert(SIFBaseAgent):
|
||||
EVIDENCE_THRESHOLD = 0.7 # Minimum relevance score for evidence
|
||||
MAX_EVIDENCE = 5 # Maximum number of evidence pieces to return
|
||||
|
||||
def __init__(self, intelligence_service: TxtaiIntelligenceService, user_id: str):
|
||||
super().__init__(intelligence_service, user_id, agent_type="citation_expert")
|
||||
|
||||
async def fact_checker(self, claim: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Tool: Verifies facts against trusted research data.
|
||||
@@ -542,60 +707,25 @@ class CitationExpert(SIFBaseAgent):
|
||||
"claim": claim,
|
||||
"status": status,
|
||||
"evidence_count": len(evidence),
|
||||
"top_evidence": evidence[0]['source'] if evidence else None
|
||||
"top_evidence": evidence[0] if evidence else None
|
||||
})
|
||||
|
||||
return {
|
||||
"status": "verification_complete",
|
||||
"total_claims": len(claims),
|
||||
"status": "completed",
|
||||
"verified_claims": verified_results,
|
||||
"unsupported_count": len([c for c in verified_results if c['status'] == 'unsupported']),
|
||||
"timestamp": datetime.utcnow().isoformat()
|
||||
"verification_score": len([c for c in verified_results if c['status'] == 'supported']) / len(verified_results)
|
||||
}
|
||||
|
||||
async def verify_facts(self, claim: str) -> List[Dict[str, Any]]:
|
||||
"""Find supporting or contradicting evidence in the indexed research."""
|
||||
self._log_agent_operation("Verifying facts", claim_length=len(claim))
|
||||
"""Verify a single claim against intelligence data."""
|
||||
results = await self.intelligence.search(claim, limit=3)
|
||||
|
||||
try:
|
||||
if not self.intelligence.is_initialized():
|
||||
logger.error(f"[{self.__class__.__name__}] Intelligence service not initialized")
|
||||
return []
|
||||
|
||||
if not claim or len(claim.strip()) < 20:
|
||||
logger.warning(f"[{self.__class__.__name__}] Claim too short for meaningful verification")
|
||||
return []
|
||||
|
||||
results = await self.intelligence.search(claim, limit=self.MAX_EVIDENCE)
|
||||
|
||||
if not results:
|
||||
logger.info(f"[{self.__class__.__name__}] No evidence found for claim")
|
||||
return []
|
||||
|
||||
evidence = []
|
||||
for result in results:
|
||||
relevance_score = result.get('score', 0.0)
|
||||
|
||||
if relevance_score >= self.EVIDENCE_THRESHOLD:
|
||||
evidence_piece = {
|
||||
"source": result.get('id', 'unknown'),
|
||||
"relevance": relevance_score,
|
||||
"confidence": self._calculate_evidence_confidence(relevance_score),
|
||||
"type": "supporting" if relevance_score > 0.8 else "related",
|
||||
"excerpt": result.get('text', '')[:200] + "..." if len(result.get('text', '')) > 200 else result.get('text', '')
|
||||
}
|
||||
evidence.append(evidence_piece)
|
||||
logger.debug(f"[{self.__class__.__name__}] Found evidence: {evidence_piece['source']} (score: {relevance_score:.3f})")
|
||||
|
||||
logger.info(f"[{self.__class__.__name__}] Found {len(evidence)} pieces of evidence for claim")
|
||||
return evidence
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.__class__.__name__}] Failed to verify facts: {e}")
|
||||
logger.error(f"[{self.__class__.__name__}] Full traceback: {traceback.format_exc()}")
|
||||
return []
|
||||
|
||||
def _calculate_evidence_confidence(self, relevance_score: float) -> float:
|
||||
"""Calculate confidence score for evidence."""
|
||||
# Simple confidence based on relevance score
|
||||
return min(1.0, relevance_score * 1.2)
|
||||
evidence = []
|
||||
for result in results:
|
||||
if result.get('score', 0) > self.EVIDENCE_THRESHOLD:
|
||||
evidence.append({
|
||||
"text": result.get('text'),
|
||||
"source": result.get('id'),
|
||||
"confidence": result.get('score')
|
||||
})
|
||||
return evidence
|
||||
|
||||
@@ -938,14 +938,14 @@ class SIFIntegrationService:
|
||||
# Strategic recommendations (lazy initialization to avoid circular imports)
|
||||
if not self.strategy_agent:
|
||||
from .sif_agents import StrategyArchitectAgent
|
||||
self.strategy_agent = StrategyArchitectAgent(self.intelligence_service)
|
||||
self.strategy_agent = StrategyArchitectAgent(self.intelligence_service, user_id=self.user_id)
|
||||
recommendations = await self.strategy_agent.analyze_content_strategy(website_data)
|
||||
insights["strategic_recommendations"] = recommendations
|
||||
|
||||
# Content quality assessment (lazy initialization to avoid circular imports)
|
||||
if not self.guardian_agent:
|
||||
from .sif_agents import ContentGuardianAgent
|
||||
self.guardian_agent = ContentGuardianAgent(self.intelligence_service, sif_service=self)
|
||||
self.guardian_agent = ContentGuardianAgent(self.intelligence_service, user_id=self.user_id, sif_service=self)
|
||||
quality_score = await self.guardian_agent.assess_content_quality(website_data)
|
||||
insights["content_quality"] = quality_score
|
||||
|
||||
|
||||
@@ -33,7 +33,13 @@ class TxtaiIntelligenceService:
|
||||
self._initialized = False
|
||||
self.enable_caching = enable_caching
|
||||
self.cache_manager = semantic_cache_manager if enable_caching else None
|
||||
self._initialize_embeddings()
|
||||
# Lazy initialization - do not initialize embeddings on startup
|
||||
# self._initialize_embeddings()
|
||||
|
||||
def _ensure_initialized(self):
    """Lazy initialization helper.

    Runs ``_initialize_embeddings`` only while ``self._initialized`` is
    falsy, so embeddings are set up on first use rather than in the
    constructor.
    """
    if self._initialized:
        return
    # NOTE(review): presumably _initialize_embeddings sets _initialized on
    # success; if it leaves the flag False, every call will retry it.
    self._initialize_embeddings()
|
||||
|
||||
def _initialize_embeddings(self):
|
||||
"""Initialize txtai embeddings with local storage support and comprehensive error handling."""
|
||||
@@ -106,6 +112,7 @@ class TxtaiIntelligenceService:
|
||||
Args:
|
||||
items: List of (id, text, metadata) tuples.
|
||||
"""
|
||||
self._ensure_initialized()
|
||||
if not self._initialized or not self.embeddings:
|
||||
logger.error(f"Cannot index content - service not initialized for user {self.user_id}")
|
||||
return
|
||||
@@ -145,6 +152,7 @@ class TxtaiIntelligenceService:
|
||||
|
||||
async def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
|
||||
"""Perform semantic search with intelligent caching."""
|
||||
self._ensure_initialized()
|
||||
if not self._initialized or not self.embeddings:
|
||||
logger.error(f"Cannot perform search - service not initialized for user {self.user_id}")
|
||||
return []
|
||||
@@ -186,6 +194,7 @@ class TxtaiIntelligenceService:
|
||||
|
||||
async def get_similarity(self, text1: str, text2: str) -> float:
|
||||
"""Get semantic similarity between two texts with caching."""
|
||||
self._ensure_initialized()
|
||||
if not self._initialized or not self.embeddings:
|
||||
logger.error(f"Cannot calculate similarity - service not initialized for user {self.user_id}")
|
||||
return 0.0
|
||||
@@ -234,6 +243,7 @@ class TxtaiIntelligenceService:
|
||||
|
||||
async def cluster(self, min_score: float = 0.5) -> List[List[int]]:
|
||||
"""Cluster indexed content to find semantic pillars using graph-based clustering with caching."""
|
||||
self._ensure_initialized()
|
||||
if not self._initialized or not self.embeddings:
|
||||
logger.error(f"Cannot cluster content - service not initialized for user {self.user_id}")
|
||||
return []
|
||||
@@ -358,6 +368,7 @@ class TxtaiIntelligenceService:
|
||||
|
||||
async def classify(self, text: str, labels: List[str]) -> List[Tuple[str, float]]:
|
||||
"""Classify text using zero-shot classification."""
|
||||
self._ensure_initialized()
|
||||
if not self._initialized or not Labels:
|
||||
logger.error(f"Cannot classify text - service not initialized or Labels not available for user {self.user_id}")
|
||||
return []
|
||||
|
||||
@@ -297,7 +297,7 @@ def _dict_to_types_schema(schema: Dict[str, Any]) -> types.Schema:
|
||||
return _convert(schema)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None):
|
||||
def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9, top_k=40, max_tokens=8192, system_prompt=None, user_id: str = None):
|
||||
"""
|
||||
Generate structured JSON response using Google's Gemini Pro model.
|
||||
|
||||
@@ -312,6 +312,7 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
top_k (int): Top-k sampling parameter
|
||||
max_tokens (int): Maximum tokens in response. Use 8192 for complex outputs
|
||||
system_prompt (str, optional): System instruction for the model
|
||||
user_id (str, optional): User ID for usage tracking.
|
||||
|
||||
Returns:
|
||||
dict: Parsed JSON response matching the provided schema
|
||||
@@ -468,6 +469,25 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
logger.info(f"Response has parsed attribute: {response.parsed is not None}")
|
||||
if response.parsed is not None:
|
||||
logger.info("Using response.parsed for structured output")
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
import json
|
||||
|
||||
response_str = json.dumps(response.parsed)
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=response_str,
|
||||
duration=0.5
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
return response.parsed
|
||||
else:
|
||||
logger.warning("Response.parsed is None, falling back to text parsing")
|
||||
@@ -500,6 +520,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
|
||||
parsed_text = json.loads(cleaned_text)
|
||||
logger.info("Successfully parsed text as JSON")
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=cleaned_text,
|
||||
duration=0.5
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
return parsed_text
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse text as JSON: {e}")
|
||||
@@ -521,6 +557,26 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
fixed_json = re.sub(r',\s*]', ']', fixed_json)
|
||||
|
||||
parsed_text = json.loads(fixed_json)
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
import json
|
||||
|
||||
response_str = json.dumps(parsed_text) if parsed_text else ""
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=response_str,
|
||||
duration=0.5 # Approximation
|
||||
)
|
||||
logger.info(f"✅ Tracked structured JSON usage for user {user_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
logger.info("Successfully parsed cleaned JSON")
|
||||
return parsed_text
|
||||
except Exception as fix_error:
|
||||
@@ -537,6 +593,22 @@ def gemini_structured_json_response(prompt, schema, temperature=0.7, top_p=0.9,
|
||||
import json
|
||||
parsed_text = json.loads(part.text)
|
||||
logger.info("Successfully parsed candidate text as JSON")
|
||||
|
||||
# Track usage if user_id is provided
|
||||
if user_id:
|
||||
try:
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name="gemini-2.5-flash",
|
||||
prompt=prompt,
|
||||
response_text=part.text,
|
||||
duration=0.5
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to track usage: {e}")
|
||||
|
||||
return parsed_text
|
||||
except json.JSONDecodeError as e:
|
||||
logger.error(f"Failed to parse candidate text as JSON: {e}")
|
||||
|
||||
@@ -4,6 +4,7 @@ import io
|
||||
import os
|
||||
from typing import Optional
|
||||
from PIL import Image
|
||||
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
||||
|
||||
from .base import ImageGenerationProvider, ImageGenerationOptions, ImageGenerationResult
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
@@ -14,7 +15,10 @@ logger = get_service_logger("wavespeed.image_provider")
|
||||
|
||||
|
||||
class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen."""
|
||||
"""WaveSpeed AI image generation provider supporting Ideogram V3 and Qwen.
|
||||
|
||||
Implements robust error handling and retries for production stability.
|
||||
"""
|
||||
|
||||
SUPPORTED_MODELS = {
|
||||
"ideogram-v3-turbo": {
|
||||
@@ -54,6 +58,28 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
logger.info("[WaveSpeed Image Provider] Initialized with available models: %s",
|
||||
list(self.SUPPORTED_MODELS.keys()))
|
||||
|
||||
@retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(multiplier=1, min=2, max=10),
    retry=retry_if_exception_type((RuntimeError, IOError)),
    reraise=True
)
def _call_api_with_retry(self, method, **kwargs):
    """Execute API call with retry logic.

    The decorator retries only ``RuntimeError`` / ``IOError``, makes up to 3
    attempts with exponential backoff, and re-raises the last exception.

    Args:
        method: Callable API method
        **kwargs: Arguments for the method

    Returns:
        API response

    Raises:
        Exception: whatever ``method`` raised, after logging a warning.
    """
    try:
        return method(**kwargs)
    except (RuntimeError, IOError) as e:
        # These types match the decorator's retry predicate; a retry follows
        # unless this was the final attempt.
        logger.warning(f"WaveSpeed API call failed (may retry): {e}")
        raise
    except Exception as e:
        # Any other exception type is not retried by the decorator, so the
        # log must not claim a retry is coming.
        logger.warning(f"WaveSpeed API call failed (not retryable): {e}")
        raise
|
||||
|
||||
def _validate_options(self, options: ImageGenerationOptions) -> None:
|
||||
"""Validate generation options.
|
||||
|
||||
@@ -117,7 +143,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
|
||||
# Call WaveSpeed API (using generic image generation method)
|
||||
# This will need to be adjusted based on actual WaveSpeed client implementation
|
||||
result = self.client.generate_image(**params)
|
||||
result = self._call_api_with_retry(self.client.generate_image, **params)
|
||||
|
||||
# Extract image bytes from result
|
||||
# Adjust based on actual WaveSpeed API response format
|
||||
@@ -167,7 +193,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
result = self._call_api_with_retry(self.client.generate_image, **params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
@@ -216,7 +242,7 @@ class WaveSpeedImageProvider(ImageGenerationProvider):
|
||||
params["seed"] = options.seed
|
||||
|
||||
# Call WaveSpeed API
|
||||
result = self.client.generate_image(**params)
|
||||
result = self._call_api_with_retry(self.client.generate_image, **params)
|
||||
|
||||
# Extract image bytes from result
|
||||
if isinstance(result, bytes):
|
||||
|
||||
@@ -107,11 +107,13 @@ def generate_audio(
|
||||
estimated_cost = (character_count / 1000.0) * cost_per_1000_chars
|
||||
|
||||
try:
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import UsageSummary, APIProvider
|
||||
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Failed to get database session")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
|
||||
@@ -194,7 +196,11 @@ def generate_audio(
|
||||
if audio_bytes:
|
||||
logger.info(f"[audio_gen] ✅ API call successful, tracking usage for user {user_id}")
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
logger.error(f"[audio_gen] ❌ Failed to get database session for tracking")
|
||||
raise RuntimeError("Failed to get database session")
|
||||
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
@@ -383,12 +389,14 @@ def clone_voice(
|
||||
|
||||
voice_clone_cost = 0.5
|
||||
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Failed to get database session")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
@@ -432,7 +440,11 @@ def clone_voice(
|
||||
|
||||
if preview_audio_bytes:
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
logger.error(f"[clone_voice] ❌ Failed to get database session for tracking")
|
||||
raise RuntimeError("Failed to get database session")
|
||||
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
@@ -570,12 +582,14 @@ def qwen3_voice_clone(
|
||||
char_count = len(text)
|
||||
estimated_cost = max(0.005, 0.005 * (char_count / 100.0))
|
||||
|
||||
from services.database import get_db
|
||||
from services.database import get_session_for_user
|
||||
from services.subscription import PricingService
|
||||
from models.subscription_models import APIProvider
|
||||
|
||||
try:
|
||||
db = next(get_db())
|
||||
db = get_session_for_user(user_id)
|
||||
if not db:
|
||||
raise RuntimeError("Failed to get database session")
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
can_proceed, message, usage_info = pricing_service.check_usage_limits(
|
||||
@@ -615,7 +629,11 @@ def qwen3_voice_clone(
|
||||
|
||||
if preview_audio_bytes:
|
||||
try:
|
||||
db_track = next(get_db())
|
||||
db_track = get_session_for_user(user_id)
|
||||
if not db_track:
|
||||
logger.error(f"[qwen3_voice_clone] ❌ Failed to get database session for tracking")
|
||||
raise RuntimeError("Failed to get database session")
|
||||
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
@@ -691,6 +709,7 @@ def qwen3_voice_clone(
|
||||
├─ Provider: wavespeed
|
||||
├─ Model: wavespeed-ai/qwen3-tts/voice-clone
|
||||
├─ Calls: {current_calls_before} → {new_calls}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Text chars: {char_count}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
@@ -724,3 +743,373 @@ def qwen3_voice_clone(
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def qwen3_voice_design(
    text: str,
    voice_description: str,
    *,
    language: str = "auto",
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Synthesize a preview clip with a voice designed from a text description.

    Flow: validate the inputs, run the subscription pre-flight check, call
    ``WaveSpeedClient.voice_design``, then record usage on a best-effort basis
    (tracking failures are logged and never fail the request).

    Args:
        text: Text to speak. It is stripped and must be non-empty.
        voice_description: Description of the desired voice. It is stripped
            and must be non-empty.
        language: Passed through to the WaveSpeed client unchanged.
        user_id: Required; used for the usage-limit check and usage tracking.

    Returns:
        A ``VoiceCloneResult`` carrying the generated audio bytes, with
        ``custom_voice_id`` left empty.

    Raises:
        RuntimeError: ``user_id`` is missing, no database session could be
            obtained for the pre-flight check, or that check failed with a
            non-HTTP error. Any ``RuntimeError`` raised further down also
            propagates unchanged.
        HTTPException: 429 when the usage-limit check denies the call; 500 for
            every other failure (including invalid ``text`` or
            ``voice_description``).
    """
    # Guard clause: this RuntimeError was re-raised untouched by the handlers
    # below, so raising it before the try block is equivalent.
    if not user_id:
        raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

    try:
        if not isinstance(text, str) or not text.strip():
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()

        if not isinstance(voice_description, str) or not voice_description.strip():
            raise ValueError("Voice description is required")
        voice_description = voice_description.strip()

        char_count = len(text)
        # Pricing logic similar to TTS/Clone: 0.005 per 100 characters,
        # with a floor of 0.005 per call.
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # --- Pre-flight subscription check -------------------------------
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            # Chain explicitly so the root cause stays attached to the traceback.
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        # --- Provider call -----------------------------------------------
        import time

        start_time = time.time()
        client = WaveSpeedClient()
        preview_audio_bytes = client.voice_design(
            text=text,
            voice_description=voice_description,
            language=language,
        )
        response_time = time.time() - start_time

        # Fail with a clear message instead of an opaque len(None) TypeError
        # later on; this still surfaces as the HTTP 500 produced below.
        if preview_audio_bytes is None:
            raise ValueError("WaveSpeed voice design returned no audio data")

        # --- Usage tracking (best effort, never blocks the response) ------
        try:
            db_track = get_session_for_user(user_id)
            if not db_track:
                logger.error("[qwen3_voice_design] ❌ Failed to get database session for tracking")
                raise RuntimeError("Failed to get database session")

            try:
                from models.subscription_models import UsageSummary, APIUsageLog
                from sqlalchemy import text as sql_text
                from services.subscription.provider_detection import detect_actual_provider

                pricing = PricingService(db_track)
                current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                summary = db_track.query(UsageSummary).filter(
                    UsageSummary.user_id == user_id,
                    UsageSummary.billing_period == current_period,
                ).first()

                if not summary:
                    # Flush so the row exists before the raw UPDATE below targets it.
                    summary = UsageSummary(user_id=user_id, billing_period=current_period)
                    db_track.add(summary)
                    db_track.flush()

                current_calls_before = getattr(summary, "audio_calls", 0) or 0
                current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
                new_calls = current_calls_before + 1
                new_cost = current_cost_before + float(estimated_cost)

                update_query = sql_text("""
                    UPDATE usage_summaries
                    SET audio_calls = :new_calls,
                        audio_cost = :new_cost
                    WHERE user_id = :user_id AND billing_period = :period
                """)
                db_track.execute(update_query, {
                    "new_calls": new_calls,
                    "new_cost": new_cost,
                    "user_id": user_id,
                    "period": current_period,
                })

                summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                summary.total_calls = (summary.total_calls or 0) + 1
                summary.updated_at = datetime.utcnow()

                actual_provider = detect_actual_provider(
                    provider_enum=APIProvider.AUDIO,
                    model_name="wavespeed-ai/qwen3-tts/voice-design",
                    endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
                )

                usage_log = APIUsageLog(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    endpoint="/audio-generation/wavespeed/qwen3-tts/voice-design",
                    method="POST",
                    model_used="wavespeed-ai/qwen3-tts/voice-design",
                    actual_provider_name=actual_provider,
                    tokens_input=char_count,
                    tokens_output=0,
                    tokens_total=char_count,
                    cost_input=0.0,
                    cost_output=0.0,
                    cost_total=float(estimated_cost),
                    response_time=response_time,
                    status_code=200,
                    # Sizes are byte counts: encode so non-ASCII text is not under-reported.
                    request_size=len(text.encode("utf-8")) + len(voice_description.encode("utf-8")),
                    response_size=len(preview_audio_bytes),
                    billing_period=current_period,
                )
                db_track.add(usage_log)
                db_track.commit()

                # flush=True already flushes stdout; no separate flush call is needed.
                print(f"""
[SUBSCRIPTION] Qwen3 Voice Design
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: wavespeed-ai/qwen3-tts/voice-design
├─ Calls: {current_calls_before} → {new_calls}
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
            except Exception as track_error:
                logger.error(f"[qwen3_voice_design] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                db_track.rollback()
            finally:
                db_track.close()
        except Exception as usage_error:
            logger.error(f"[qwen3_voice_design] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model="wavespeed-ai/qwen3-tts/voice-design",
            custom_voice_id="",  # No persistent ID for design usually, unless we save it
            file_size=len(preview_audio_bytes),
        )
    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[qwen3_voice_design] Error designing voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "Qwen3 voice design failed",
                "message": str(e),
            },
        ) from e
|
||||
|
||||
|
||||
def cosyvoice_voice_clone(
    audio_bytes: bytes,
    text: str,
    *,
    reference_text: Optional[str] = None,
    audio_mime_type: Optional[str] = None,
    user_id: Optional[str] = None,
) -> VoiceCloneResult:
    """Clone a voice from a reference recording and speak ``text`` with it (CosyVoice via WaveSpeed).

    Flow: validate the inputs, run the subscription pre-flight check, call
    ``WaveSpeedClient.cosyvoice_voice_clone``, then record usage on a
    best-effort basis (tracking failures are logged and never fail the request).

    Args:
        audio_bytes: Reference audio (``bytes`` or ``bytearray``), non-empty
            and at most 15 MB.
        text: Text to speak. It is stripped, must be non-empty, and may be at
            most 4000 characters after stripping.
        reference_text: Passed through to the WaveSpeed client unchanged.
        audio_mime_type: MIME type of ``audio_bytes``; ``"audio/wav"`` is used
            when omitted.
        user_id: Required; used for the usage-limit check and usage tracking.

    Returns:
        A ``VoiceCloneResult`` carrying the generated audio bytes, with
        ``custom_voice_id`` left empty.

    Raises:
        RuntimeError: ``user_id`` is missing, no database session could be
            obtained for the pre-flight check, or that check failed with a
            non-HTTP error. Any ``RuntimeError`` raised further down also
            propagates unchanged.
        HTTPException: 429 when the usage-limit check denies the call; 500 for
            every other failure (including invalid ``audio_bytes`` or ``text``).
    """
    # Guard clause: this RuntimeError was re-raised untouched by the handlers
    # below, so raising it before the try block is equivalent.
    if not user_id:
        raise RuntimeError("user_id is required for subscription checking. Please provide Clerk user ID.")

    try:
        if not isinstance(audio_bytes, (bytes, bytearray)) or not audio_bytes:
            raise ValueError("Audio is required and cannot be empty")

        if len(audio_bytes) > 15 * 1024 * 1024:
            raise ValueError("Audio file too large. Maximum is 15MB.")

        if not isinstance(text, str) or not text.strip():
            raise ValueError("Text is required and cannot be empty")
        text = text.strip()
        if len(text) > 4000:
            raise ValueError("Text too long. Please keep it under 4000 characters.")

        char_count = len(text)
        # 0.005 per 100 characters, with a floor of 0.005 per call.
        estimated_cost = max(0.005, 0.005 * (char_count / 100.0))

        from services.database import get_session_for_user
        from services.subscription import PricingService
        from models.subscription_models import APIProvider

        # --- Pre-flight subscription check -------------------------------
        try:
            db = get_session_for_user(user_id)
            if not db:
                raise RuntimeError("Failed to get database session")
            try:
                pricing_service = PricingService(db)
                can_proceed, message, usage_info = pricing_service.check_usage_limits(
                    user_id=user_id,
                    provider=APIProvider.AUDIO,
                    tokens_requested=char_count,
                    actual_provider_name="wavespeed",
                )
                if not can_proceed:
                    raise HTTPException(
                        status_code=429,
                        detail={
                            "error": message,
                            "message": message,
                            "provider": "wavespeed",
                            "usage_info": usage_info if usage_info else {},
                        },
                    )
            finally:
                db.close()
        except HTTPException:
            raise
        except Exception as sub_error:
            # Chain explicitly so the root cause stays attached to the traceback.
            raise RuntimeError(f"Subscription check failed: {str(sub_error)}") from sub_error

        # --- Provider call -----------------------------------------------
        import time

        start_time = time.time()
        client = WaveSpeedClient()
        preview_audio_bytes = client.cosyvoice_voice_clone(
            audio_bytes=bytes(audio_bytes),
            text=text,
            audio_mime_type=audio_mime_type or "audio/wav",
            reference_text=reference_text,
        )
        response_time = time.time() - start_time

        # Fail with a clear message instead of an opaque len(None) TypeError
        # at the return statement; this still surfaces as the HTTP 500 below.
        if preview_audio_bytes is None:
            raise ValueError("WaveSpeed CosyVoice returned no audio data")

        # --- Usage tracking (best effort; skipped when no audio came back) --
        if preview_audio_bytes:
            try:
                db_track = get_session_for_user(user_id)
                if not db_track:
                    logger.error("[cosyvoice_voice_clone] ❌ Failed to get database session for tracking")
                    raise RuntimeError("Failed to get database session")

                try:
                    from models.subscription_models import UsageSummary, APIUsageLog
                    from sqlalchemy import text as sql_text
                    from services.subscription.provider_detection import detect_actual_provider

                    pricing = PricingService(db_track)
                    current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")

                    summary = db_track.query(UsageSummary).filter(
                        UsageSummary.user_id == user_id,
                        UsageSummary.billing_period == current_period,
                    ).first()

                    if not summary:
                        # Flush so the row exists before the raw UPDATE below targets it.
                        summary = UsageSummary(user_id=user_id, billing_period=current_period)
                        db_track.add(summary)
                        db_track.flush()

                    current_calls_before = getattr(summary, "audio_calls", 0) or 0
                    current_cost_before = getattr(summary, "audio_cost", 0.0) or 0.0
                    new_calls = current_calls_before + 1
                    new_cost = current_cost_before + float(estimated_cost)

                    update_query = sql_text("""
                        UPDATE usage_summaries
                        SET audio_calls = :new_calls,
                            audio_cost = :new_cost
                        WHERE user_id = :user_id AND billing_period = :period
                    """)
                    db_track.execute(update_query, {
                        "new_calls": new_calls,
                        "new_cost": new_cost,
                        "user_id": user_id,
                        "period": current_period,
                    })

                    summary.total_cost = (summary.total_cost or 0.0) + float(estimated_cost)
                    summary.total_calls = (summary.total_calls or 0) + 1
                    summary.updated_at = datetime.utcnow()

                    actual_provider = detect_actual_provider(
                        provider_enum=APIProvider.AUDIO,
                        model_name="wavespeed-ai/cosyvoice-tts/voice-clone",
                        endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
                    )

                    usage_log = APIUsageLog(
                        user_id=user_id,
                        provider=APIProvider.AUDIO,
                        endpoint="/audio-generation/wavespeed/cosyvoice-tts/voice-clone",
                        method="POST",
                        model_used="wavespeed-ai/cosyvoice-tts/voice-clone",
                        actual_provider_name=actual_provider,
                        tokens_input=char_count,
                        tokens_output=0,
                        tokens_total=char_count,
                        cost_input=0.0,
                        cost_output=0.0,
                        cost_total=float(estimated_cost),
                        response_time=response_time,
                        status_code=200,
                        request_size=len(audio_bytes) + len(text.encode("utf-8")),
                        response_size=len(preview_audio_bytes),
                        billing_period=current_period,
                    )
                    db_track.add(usage_log)
                    db_track.commit()

                    # flush=True already flushes stdout; no separate flush call is needed.
                    print(f"""
[SUBSCRIPTION] CosyVoice Voice Clone
├─ User: {user_id}
├─ Provider: wavespeed
├─ Model: wavespeed-ai/cosyvoice-tts/voice-clone
├─ Calls: {current_calls_before} → {new_calls}
├─ Text chars: {char_count}
└─ Status: ✅ Allowed & Tracked
""", flush=True)
                except Exception as track_error:
                    logger.error(f"[cosyvoice_voice_clone] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
                    db_track.rollback()
                finally:
                    db_track.close()
            except Exception as usage_error:
                logger.error(f"[cosyvoice_voice_clone] ❌ Failed to track usage: {usage_error}", exc_info=True)

        return VoiceCloneResult(
            preview_audio_bytes=preview_audio_bytes,
            provider="wavespeed",
            model="wavespeed-ai/cosyvoice-tts/voice-clone",
            custom_voice_id="",
            file_size=len(preview_audio_bytes),
        )
    except HTTPException:
        raise
    except RuntimeError:
        raise
    except Exception as e:
        logger.error(f"[cosyvoice_voice_clone] Error cloning voice: {e}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail={
                "error": "CosyVoice voice cloning failed",
                "message": str(e),
            },
        ) from e
|
||||
|
||||
|
||||
@@ -2,6 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import os
|
||||
import io
|
||||
import base64
|
||||
import logging
|
||||
from typing import Optional, Dict, Any
|
||||
from PIL import Image
|
||||
|
||||
@@ -9,6 +11,9 @@ from .image_generation import (
|
||||
ImageGenerationOptions,
|
||||
ImageGenerationResult,
|
||||
)
|
||||
from .image_generation.base import ImageEditOptions
|
||||
from .image_generation.wavespeed_edit_provider import WaveSpeedEditProvider
|
||||
|
||||
from utils.logger_utils import get_service_logger
|
||||
|
||||
try:
|
||||
@@ -22,21 +27,36 @@ logger = get_service_logger("image_editing.facade")
|
||||
|
||||
|
||||
DEFAULT_IMAGE_EDIT_MODEL = os.getenv(
|
||||
"HF_IMAGE_EDIT_MODEL",
|
||||
"Qwen/Qwen-Image-Edit",
|
||||
"WAVESPEED_IMAGE_EDIT_MODEL",
|
||||
"qwen-edit-plus",
|
||||
)
|
||||
|
||||
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
"""Select provider for image editing. Defaults to huggingface with fal-ai."""
|
||||
"""
|
||||
Select the appropriate image editing provider.
|
||||
|
||||
Priority:
|
||||
1. Explicitly requested provider
|
||||
2. WaveSpeed (if API key available) - Preferred for quality/speed
|
||||
3. Hugging Face (fallback)
|
||||
"""
|
||||
if explicit:
|
||||
return explicit
|
||||
# Default to huggingface for image editing (best support for image-to-image)
|
||||
return explicit.lower()
|
||||
|
||||
# Check for WaveSpeed API key first (Preferred provider)
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
# Default to huggingface if WaveSpeed not available
|
||||
return "huggingface"
|
||||
|
||||
|
||||
def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
"""Get InferenceClient for the specified provider."""
|
||||
"""Get the client for the specified provider."""
|
||||
if provider_name == "wavespeed":
|
||||
return WaveSpeedEditProvider(api_key=api_key)
|
||||
|
||||
if not HF_HUB_AVAILABLE:
|
||||
raise RuntimeError("huggingface_hub is not installed. Install with: pip install huggingface_hub")
|
||||
|
||||
@@ -44,7 +64,7 @@ def _get_provider_client(provider_name: str, api_key: Optional[str] = None):
|
||||
api_key = api_key or os.getenv("HF_TOKEN")
|
||||
if not api_key:
|
||||
raise RuntimeError("HF_TOKEN is required for Hugging Face image editing")
|
||||
# Use fal-ai provider for fast inference
|
||||
# Use fal-ai provider for fast inference via HF Inference API
|
||||
return InferenceClient(provider="fal-ai", api_key=api_key)
|
||||
|
||||
raise ValueError(f"Unknown image editing provider: {provider_name}")
|
||||
@@ -86,6 +106,8 @@ def edit_image(
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger.info(f"[Image Editing] 🔍 Starting pre-flight validation for user_id={user_id}")
|
||||
# Note: get_db() is a generator, so we need to use next() to get the session
|
||||
# and ensure we close it in the finally block
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
@@ -99,6 +121,9 @@ def edit_image(
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[Image Editing] ❌ Pre-flight validation failed for user_id={user_id} - blocking API call: {http_ex.detail}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ Unexpected error during pre-flight validation: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"Image editing validation failed: {str(e)}")
|
||||
finally:
|
||||
db.close()
|
||||
else:
|
||||
@@ -119,6 +144,69 @@ def edit_image(
|
||||
# Get provider client
|
||||
client = _get_provider_client(provider_name, opts.get("api_key"))
|
||||
|
||||
if provider_name == "wavespeed":
|
||||
# Handle WaveSpeed provider
|
||||
try:
|
||||
# Convert inputs to base64 for WaveSpeed
|
||||
image_b64 = base64.b64encode(input_image_bytes).decode('utf-8')
|
||||
mask_b64 = None
|
||||
if mask_bytes:
|
||||
mask_b64 = base64.b64encode(mask_bytes).decode('utf-8')
|
||||
|
||||
# Determine operation type based on prompt/mask
|
||||
operation = "general_edit" # Default
|
||||
if not prompt and mask_b64:
|
||||
operation = "remove_bg" # Heuristic: mask but no prompt implies removal/in-painting
|
||||
elif prompt and not mask_b64:
|
||||
operation = "style_transfer" # Heuristic: prompt but no mask implies style transfer
|
||||
elif opts.get("operation"):
|
||||
operation = opts.get("operation")
|
||||
|
||||
edit_options = ImageEditOptions(
|
||||
image_base64=image_b64,
|
||||
prompt=prompt.strip(),
|
||||
operation=operation,
|
||||
mask_base64=mask_b64,
|
||||
model=model,
|
||||
guidance_scale=opts.get("guidance_scale"),
|
||||
steps=opts.get("steps"),
|
||||
seed=opts.get("seed"),
|
||||
extra=opts
|
||||
)
|
||||
|
||||
logger.info(f"[Image Editing] Calling WaveSpeed edit with model={model}")
|
||||
result = client.edit(edit_options)
|
||||
|
||||
# TRACK USAGE after successful WaveSpeed call
|
||||
if user_id:
|
||||
try:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
|
||||
# Estimate cost (WaveSpeed default: $0.02)
|
||||
estimated_cost = result.metadata.get("estimated_cost", 0.02) if result.metadata else 0.02
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="wavespeed",
|
||||
model=result.model or model,
|
||||
operation_type="image-editing",
|
||||
result_bytes=result.image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-editing",
|
||||
metadata=result.metadata,
|
||||
log_prefix="[Image Editing]"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Image Editing] ❌ WaveSpeed editing failed: {e}", exc_info=True)
|
||||
raise RuntimeError(f"WaveSpeed editing failed: {str(e)}")
|
||||
|
||||
# Hugging Face (Fallback)
|
||||
# Prepare parameters for image-to-image
|
||||
params: Dict[str, Any] = {}
|
||||
if opts.get("guidance_scale") is not None:
|
||||
@@ -170,6 +258,29 @@ def edit_image(
|
||||
|
||||
logger.info(f"[Image Editing] ✅ Successfully edited image: {len(edited_image_bytes)} bytes")
|
||||
|
||||
# TRACK USAGE after successful HF call
|
||||
if user_id:
|
||||
try:
|
||||
from services.llm_providers.main_image_generation import _track_image_operation_usage
|
||||
|
||||
# Estimate cost (HF/Fal-ai default: $0.05)
|
||||
estimated_cost = 0.05
|
||||
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="huggingface",
|
||||
model=model,
|
||||
operation_type="image-editing",
|
||||
result_bytes=edited_image_bytes,
|
||||
cost=estimated_cost,
|
||||
prompt=prompt,
|
||||
endpoint="/image-editing",
|
||||
metadata={"provider": "fal-ai"},
|
||||
log_prefix="[Image Editing]"
|
||||
)
|
||||
except Exception as track_error:
|
||||
logger.warning(f"[Image Editing] ⚠️ Failed to track usage: {track_error}")
|
||||
|
||||
return ImageGenerationResult(
|
||||
image_bytes=edited_image_bytes,
|
||||
width=edited_image.width,
|
||||
|
||||
@@ -5,6 +5,7 @@ import sys
|
||||
import base64
|
||||
from datetime import datetime
|
||||
from typing import Optional, Dict, Any
|
||||
from fastapi import HTTPException
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
|
||||
from .image_generation import (
|
||||
@@ -29,6 +30,11 @@ logger = get_service_logger("image_generation.facade")
|
||||
def _select_provider(explicit: Optional[str]) -> str:
|
||||
if explicit:
|
||||
return explicit
|
||||
|
||||
# User requested WaveSpeed as default provider
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
gpt_provider = (os.getenv("GPT_PROVIDER") or "").lower()
|
||||
if gpt_provider.startswith("gemini"):
|
||||
return "gemini"
|
||||
@@ -36,8 +42,7 @@ def _select_provider(explicit: Optional[str]) -> str:
|
||||
return "huggingface"
|
||||
if os.getenv("STABILITY_API_KEY"):
|
||||
return "stability"
|
||||
if os.getenv("WAVESPEED_API_KEY"):
|
||||
return "wavespeed"
|
||||
|
||||
# Fallback to huggingface to enable a path if configured
|
||||
return "huggingface"
|
||||
|
||||
@@ -739,18 +744,139 @@ async def generate_image_with_provider(
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate_image_with_provider: {e}")
|
||||
# Propagate specific error message if available
|
||||
error_detail = str(e)
|
||||
if "402" in error_detail or "Payment Required" in error_detail:
|
||||
raise HTTPException(status_code=402, detail=f"Payment Required: {error_detail}")
|
||||
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
"error": error_detail
|
||||
}
|
||||
|
||||
|
||||
import time
|
||||
from services.database import get_session_for_user
|
||||
from models.onboarding import WebsiteAnalysis, OnboardingSession, CompetitorAnalysis
|
||||
|
||||
async def enhance_image_prompt(prompt: str, user_id: Optional[str] = None) -> str:
|
||||
"""
|
||||
Enhance image prompt using LLM.
|
||||
Placeholder implementation.
|
||||
Enhance image prompt using WaveSpeed's specialized prompt optimizer.
|
||||
Restructures and enriches prompts for visual clarity and cinematic detail.
|
||||
Uses Step 2 (Website Analysis) and Step 3 (Competitor Analysis) context if available.
|
||||
"""
|
||||
return prompt
|
||||
start_time = time.time()
|
||||
try:
|
||||
from services.wavespeed.client import WaveSpeedClient
|
||||
|
||||
# 1. Pre-flight Validation
|
||||
if user_id:
|
||||
_validate_image_operation(
|
||||
user_id=user_id,
|
||||
operation_type="prompt-enhancement",
|
||||
num_operations=1,
|
||||
log_prefix="[Prompt Enhancement]"
|
||||
)
|
||||
|
||||
# 2. Fetch Context from Step 2 & 3
|
||||
context_instruction = ""
|
||||
if user_id:
|
||||
try:
|
||||
db_session = get_session_for_user(user_id)
|
||||
try:
|
||||
# Get Onboarding Session
|
||||
session = db_session.query(OnboardingSession).filter(
|
||||
OnboardingSession.user_id == user_id
|
||||
).first()
|
||||
|
||||
if session:
|
||||
# Step 2: Website Analysis
|
||||
website_analysis = db_session.query(WebsiteAnalysis).filter(
|
||||
WebsiteAnalysis.session_id == session.id
|
||||
).first()
|
||||
|
||||
if website_analysis:
|
||||
# Handle potential JSON or dict types
|
||||
brand_voice = website_analysis.brand_analysis
|
||||
style = website_analysis.style_guidelines
|
||||
target_audience = website_analysis.target_audience
|
||||
|
||||
context_instruction += "\n\nCONTEXT FROM WEBSITE ANALYSIS:\n"
|
||||
if target_audience:
|
||||
context_instruction += f"Target Audience: {target_audience}\n"
|
||||
|
||||
if brand_voice and isinstance(brand_voice, dict):
|
||||
context_instruction += f"Brand Voice: {brand_voice.get('voice_characteristics', '')} - {brand_voice.get('tone', '')}\n"
|
||||
|
||||
if style and isinstance(style, dict):
|
||||
context_instruction += f"Visual Style: {style.get('visual_style', '')} - {style.get('color_palette', '')}\n"
|
||||
|
||||
# Step 3: Competitor Analysis (Limit to top 3)
|
||||
competitors = db_session.query(CompetitorAnalysis).filter(
|
||||
CompetitorAnalysis.session_id == session.id
|
||||
).limit(3).all()
|
||||
|
||||
if competitors:
|
||||
context_instruction += "\nCOMPETITOR VISUAL INSIGHTS:\n"
|
||||
for comp in competitors:
|
||||
if comp.analysis_data and isinstance(comp.analysis_data, dict):
|
||||
comp_title = comp.analysis_data.get('title', 'Competitor')
|
||||
# Try to extract visual/content insights if available
|
||||
highlights = comp.analysis_data.get('highlights', [])
|
||||
if highlights:
|
||||
context_instruction += f"- {comp_title}: {', '.join(highlights[:2])}\n"
|
||||
|
||||
finally:
|
||||
db_session.close()
|
||||
except Exception as db_ex:
|
||||
logger.warning(f"Failed to fetch context for prompt enhancement: {db_ex}")
|
||||
|
||||
# Combine prompt with context
|
||||
full_input_text = prompt
|
||||
if context_instruction:
|
||||
logger.info(f"Enhancing prompt for user {user_id} with Step 2/3 context")
|
||||
# We append context as instruction for the optimizer
|
||||
full_input_text = f"Original Request: {prompt}\n\n{context_instruction}\n\nTask: Generate a hyper-personalized, detailed image generation prompt based on the Original Request and the provided Context. Ensure the visual style aligns with the Brand Voice and Visual Style."
|
||||
else:
|
||||
logger.info(f"Enhancing prompt for user {user_id} (no context found)")
|
||||
|
||||
# 3. Call WaveSpeed
|
||||
client = WaveSpeedClient()
|
||||
# Use 'image' mode for avatar/image generation workflows
|
||||
# Use 'photographic' style as requested for avatars
|
||||
optimized_prompt = client.optimize_prompt(
|
||||
text=full_input_text,
|
||||
mode="image",
|
||||
style="photographic",
|
||||
enable_sync_mode=True,
|
||||
timeout=30
|
||||
)
|
||||
|
||||
# 4. Track Usage
|
||||
if user_id:
|
||||
duration = time.time() - start_time
|
||||
# Track as 0 cost for now unless we have specific pricing for prompt opt
|
||||
# But we track it as an operation
|
||||
_track_image_operation_usage(
|
||||
user_id=user_id,
|
||||
provider="wavespeed",
|
||||
model="wavespeed-prompt-opt",
|
||||
operation_type="prompt-enhancement",
|
||||
result_bytes=b"", # No image
|
||||
cost=0.0,
|
||||
prompt=prompt,
|
||||
endpoint="/enhance-prompt",
|
||||
metadata={"duration": duration, "context_added": bool(context_instruction)},
|
||||
log_prefix="[Prompt Enhancement]",
|
||||
response_time=duration
|
||||
)
|
||||
|
||||
return optimized_prompt
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to enhance prompt via WaveSpeed: {e}")
|
||||
# Fallback to original prompt on failure
|
||||
return prompt
|
||||
|
||||
|
||||
async def generate_image_variation(
|
||||
@@ -760,13 +886,123 @@ async def generate_image_variation(
|
||||
**kwargs
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate variation of an existing image.
|
||||
Placeholder implementation.
|
||||
Generate variation of an existing image using image-to-image editing.
|
||||
Wrapper for step4_asset_routes.
|
||||
"""
|
||||
return {
|
||||
"success": False,
|
||||
"error": "Not implemented yet"
|
||||
}
|
||||
try:
|
||||
# Handle image input (bytes, file, or base64)
|
||||
image_bytes = None
|
||||
if isinstance(image, bytes):
|
||||
image_bytes = image
|
||||
elif hasattr(image, "read"):
|
||||
image_bytes = await image.read()
|
||||
elif isinstance(image, str):
|
||||
# Assume base64 or path
|
||||
if os.path.exists(image):
|
||||
with open(image, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
else:
|
||||
# Try base64 decode
|
||||
try:
|
||||
if "base64," in image:
|
||||
image = image.split("base64,")[1]
|
||||
image_bytes = base64.b64decode(image)
|
||||
except:
|
||||
pass
|
||||
|
||||
if not image_bytes:
|
||||
return {"success": False, "error": "Invalid image input"}
|
||||
|
||||
# Convert to base64 for internal function
|
||||
image_base64 = base64.b64encode(image_bytes).decode('utf-8')
|
||||
|
||||
# Use generate_image_edit with "variation" intent
|
||||
# For variation, we typically use general_edit with specific prompt
|
||||
result = await run_in_threadpool(
|
||||
generate_image_edit,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt,
|
||||
operation="general_edit",
|
||||
model=kwargs.get("model", "qwen-edit-plus"), # Default to capable model
|
||||
options=kwargs,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"image_base64": result_base64,
|
||||
"metadata": result.metadata
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in generate_image_variation: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}
|
||||
|
||||
|
||||
async def generate_image_enhance(
    image: Any,
    user_id: Optional[str] = None,
    **kwargs
) -> Dict[str, Any]:
    """
    Enhance/Upscale an existing image.
    Wrapper for step4_asset_routes.

    The image is normalized to raw bytes, re-encoded as base64 and handed to
    ``generate_image_edit`` (run in a thread pool) with a fixed enhancement
    prompt, ``operation="general_edit"``, the ``nano-banana-pro-edit-ultra``
    model and ``resolution="4k"`` merged into the caller's options.

    Args:
        image: One of: ``bytes``/``bytearray``; an object with a ``read()``
            method (sync or async); or a ``str`` that is either an existing
            file path or base64 data (optionally a ``data:...;base64,`` URI).
        user_id: Forwarded to ``generate_image_edit``.
        **kwargs: Forwarded as the ``options`` dict (plus ``resolution``).

    Returns:
        ``{"success": True, "image_base64": ..., "metadata": ...}`` on success,
        otherwise ``{"success": False, "error": <message>}``. Errors are
        reported in the dict rather than raised.
    """
    try:
        # Handle image input
        image_bytes = None
        if isinstance(image, (bytes, bytearray)):
            image_bytes = bytes(image)
        elif hasattr(image, "read"):
            import inspect

            # Support both async readers (upload objects) and plain file-like
            # objects: only await when read() actually returned an awaitable.
            data = image.read()
            if inspect.isawaitable(data):
                data = await data
            image_bytes = data
        elif isinstance(image, str):
            # NOTE(review): a string naming an existing path is read from local
            # disk — confirm callers never pass untrusted strings here.
            if os.path.exists(image):
                with open(image, "rb") as f:
                    image_bytes = f.read()
            else:
                try:
                    if "base64," in image:
                        image = image.split("base64,")[1]
                    image_bytes = base64.b64decode(image)
                except ValueError:
                    # binascii.Error (malformed base64) subclasses ValueError;
                    # fall through to the "Invalid image input" response.
                    image_bytes = None

        if not image_bytes:
            return {"success": False, "error": "Invalid image input"}

        image_base64 = base64.b64encode(image_bytes).decode('utf-8')

        # Use generate_image_edit with "enhance" intent
        # Use high-res model like nano-banana-pro-edit-ultra
        result = await run_in_threadpool(
            generate_image_edit,
            image_base64=image_base64,
            prompt="enhance details, high resolution, professional quality, 4k, sharp focus",
            operation="general_edit",
            model="nano-banana-pro-edit-ultra",
            options={**kwargs, "resolution": "4k"},
            user_id=user_id
        )

        result_base64 = base64.b64encode(result.image_bytes).decode('utf-8')

        return {
            "success": True,
            "image_base64": result_base64,
            "metadata": result.metadata
        }

    except Exception as e:
        logger.error(f"Error in generate_image_enhance: {e}")
        return {
            "success": False,
            "error": str(e)
        }
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -260,335 +260,23 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3) # Estimate output tokens
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
logger.debug(f"[llm_text_gen] Token estimates: input={tokens_input}, output={tokens_output}, total={tokens_total}")
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Looking for usage summary: user_id={user_id}, period={current_period}")
|
||||
|
||||
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
|
||||
provider_name = provider_enum.value
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
token_limit = 0
|
||||
if limits and limits.get('limits'):
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
|
||||
# This ensures we always get the absolute latest committed values, even across different sessions
|
||||
from sqlalchemy import text
|
||||
current_calls_before = 0
|
||||
current_tokens_before = 0
|
||||
record_count = 0 # Initialize to ensure it's always defined
|
||||
|
||||
# CRITICAL: First check if record exists using COUNT query
|
||||
try:
|
||||
check_query = text("SELECT COUNT(*) FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period")
|
||||
record_count = db_track.execute(check_query, {'user_id': user_id, 'period': current_period}).scalar()
|
||||
logger.debug(f"[llm_text_gen] 🔍 DEBUG: Record count check - found {record_count} record(s) for user={user_id}, period={current_period}")
|
||||
except Exception as count_error:
|
||||
logger.error(f"[llm_text_gen] ❌ COUNT query failed: {count_error}", exc_info=True)
|
||||
record_count = 0
|
||||
|
||||
if record_count and record_count > 0:
|
||||
# Record exists - read current values with raw SQL
|
||||
try:
|
||||
# Validate provider_name to prevent SQL injection (whitelist approach)
|
||||
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
if provider_name not in valid_providers:
|
||||
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
|
||||
|
||||
# Read current values directly from database using raw SQL
|
||||
# CRITICAL: This bypasses SQLAlchemy's session cache and gets absolute latest values
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_name}_calls, {provider_name}_tokens
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Executing raw SQL for EXISTING record: SELECT {provider_name}_calls, {provider_name}_tokens WHERE user_id={user_id}, period={current_period}")
|
||||
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
raw_calls = result[0] if result[0] is not None else 0
|
||||
raw_tokens = result[1] if result[1] is not None else 0
|
||||
current_calls_before = raw_calls
|
||||
current_tokens_before = raw_tokens
|
||||
logger.debug(f"[llm_text_gen] ✅ Raw SQL SUCCESS: Found EXISTING record - calls={current_calls_before}, tokens={current_tokens_before} (provider={provider_name}, column={provider_name}_calls/{provider_name}_tokens)")
|
||||
logger.debug(f"[llm_text_gen] 🔍 Raw SQL returned row: {result}, extracted calls={raw_calls}, tokens={raw_tokens}")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL BUG: Record EXISTS (count={record_count}) but SELECT query returned None! Query: {sql_query}")
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
logger.warning(f"[llm_text_gen] ⚠️ Using ORM fallback: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
except Exception as sql_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Raw SQL query failed: {sql_error}", exc_info=True)
|
||||
# Fallback: Use ORM to get values
|
||||
summary_fallback = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary_fallback:
|
||||
db_track.refresh(summary_fallback)
|
||||
current_calls_before = getattr(summary_fallback, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary_fallback, f"{provider_name}_tokens", 0) or 0
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ℹ️ No record exists yet (will create new) - user={user_id}, period={current_period}")
|
||||
|
||||
# Get or create usage summary object (needed for ORM update)
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
logger.debug(f"[llm_text_gen] Creating NEW usage summary for user {user_id}, period {current_period}")
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
# New record - values are already 0, no need to set
|
||||
logger.debug(f"[llm_text_gen] ✅ New summary created - starting from 0")
|
||||
else:
|
||||
# CRITICAL: Update the ORM object with values from raw SQL query
|
||||
# This ensures the ORM object reflects the actual database state before we increment
|
||||
logger.debug(f"[llm_text_gen] 🔄 Existing summary found - syncing with raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
setattr(summary, f"{provider_name}_calls", current_calls_before)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
|
||||
logger.debug(f"[llm_text_gen] ✅ Synchronized ORM object: {provider_name}_calls={current_calls_before}, {provider_name}_tokens={current_tokens_before}")
|
||||
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (raw SQL): {current_calls_before}")
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
# SQLAlchemy doesn't detect changes when using setattr() on dynamic attributes
|
||||
# Using raw SQL UPDATE ensures the change is persisted
|
||||
from sqlalchemy import text
|
||||
update_calls_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_calls = :new_calls
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_calls_query, {
|
||||
'new_calls': new_calls,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_calls via SQL: {current_calls_before} -> {new_calls}")
|
||||
|
||||
# Update token usage for LLM providers with safety check
|
||||
# CRITICAL: Use current_tokens_before from raw SQL query (NOT from ORM object)
|
||||
# The ORM object may have stale values, but raw SQL always has the latest committed values
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (raw SQL): {current_tokens_before}")
|
||||
|
||||
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
|
||||
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
|
||||
projected_new_tokens = current_tokens_before + tokens_total
|
||||
|
||||
# If limit is set (> 0) and would be exceeded, cap at limit
|
||||
if token_limit > 0 and projected_new_tokens > token_limit:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate. "
|
||||
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
|
||||
f"Capping tracked tokens at limit to prevent abuse."
|
||||
)
|
||||
# Cap at limit to prevent abuse
|
||||
new_tokens = token_limit
|
||||
# Adjust tokens_total for accurate total tracking
|
||||
tokens_total = token_limit - current_tokens_before
|
||||
if tokens_total < 0:
|
||||
tokens_total = 0
|
||||
else:
|
||||
new_tokens = projected_new_tokens
|
||||
|
||||
# CRITICAL: Use direct SQL UPDATE instead of ORM setattr for dynamic attributes
|
||||
update_tokens_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_tokens = :new_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_tokens_query, {
|
||||
'new_tokens': new_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated {provider_name}_tokens via SQL: {current_tokens_before} -> {new_tokens}")
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Determine tracked tokens (after any safety capping)
|
||||
tracked_tokens_input = min(tokens_input, tokens_total)
|
||||
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
|
||||
|
||||
# Calculate and persist cost for this call
|
||||
try:
|
||||
cost_info = pricing.calculate_api_cost(
|
||||
provider=provider_enum,
|
||||
model_name=model,
|
||||
tokens_input=tracked_tokens_input,
|
||||
tokens_output=tracked_tokens_output,
|
||||
request_count=1
|
||||
)
|
||||
cost_total = cost_info.get('cost_total', 0.0) or 0.0
|
||||
except Exception as cost_error:
|
||||
cost_total = 0.0
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to calculate API cost: {cost_error}", exc_info=True)
|
||||
|
||||
if cost_total > 0:
|
||||
logger.debug(f"[llm_text_gen] 💰 Calculated cost for {provider_name}: ${cost_total:.6f}")
|
||||
update_costs_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_costs_query, {
|
||||
'cost': cost_total,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Keep ORM object in sync for logging/debugging
|
||||
current_provider_cost = getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
|
||||
setattr(summary, f"{provider_name}_cost", current_provider_cost + cost_total)
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost_total
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] 💰 Cost calculation returned $0 for {provider_name} (tokens_input={tracked_tokens_input}, tokens_output={tracked_tokens_output})")
|
||||
|
||||
# Update totals using SQL UPDATE
|
||||
old_total_calls = summary.total_calls or 0
|
||||
old_total_tokens = summary.total_tokens or 0
|
||||
new_total_calls = old_total_calls + 1
|
||||
new_total_tokens = old_total_tokens + tokens_total
|
||||
|
||||
update_totals_query = text("""
|
||||
UPDATE usage_summaries
|
||||
SET total_calls = :total_calls, total_tokens = :total_tokens
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_totals_query, {
|
||||
'total_calls': new_total_calls,
|
||||
'total_tokens': new_total_tokens,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
logger.debug(f"[llm_text_gen] Updated totals via SQL: calls {old_total_calls} -> {new_total_calls}, tokens {old_total_tokens} -> {new_total_tokens}")
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info BEFORE commit (always visible, flushed immediately)
|
||||
import sys
|
||||
debug_msg = f"[DEBUG] BEFORE COMMIT - Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}, Period: {current_period}, New calls will be: {new_calls}, New tokens will be: {new_tokens}"
|
||||
print(debug_msg, flush=True)
|
||||
sys.stdout.flush()
|
||||
logger.debug(f"[llm_text_gen] {debug_msg}")
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
logger.debug(f"[llm_text_gen] ✅ Successfully tracked usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (COMMITTED to DB)")
|
||||
logger.debug(f"[llm_text_gen] Database state after commit: {provider_name}_calls={new_calls}, {provider_name}_tokens={new_tokens} (should be visible to next session)")
|
||||
|
||||
# CRITICAL: Verify commit worked by reading back from DB immediately after commit
|
||||
try:
|
||||
verify_query = text(f"SELECT {provider_name}_calls, {provider_name}_tokens FROM usage_summaries WHERE user_id = :user_id AND billing_period = :period LIMIT 1")
|
||||
verify_result = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result:
|
||||
verified_calls = verify_result[0] if verify_result[0] is not None else 0
|
||||
verified_tokens = verify_result[1] if verify_result[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] ✅ VERIFICATION AFTER COMMIT: Read back calls={verified_calls}, tokens={verified_tokens} (expected: calls={new_calls}, tokens={new_tokens})")
|
||||
if verified_calls != new_calls or verified_tokens != new_tokens:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Expected calls={new_calls}, tokens={new_tokens}, but DB has calls={verified_calls}, tokens={verified_tokens}")
|
||||
# Force another commit attempt
|
||||
db_track.commit()
|
||||
verify_result2 = db_track.execute(verify_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if verify_result2:
|
||||
verified_calls2 = verify_result2[0] if verify_result2[0] is not None else 0
|
||||
verified_tokens2 = verify_result2[1] if verify_result2[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] 🔄 After second commit attempt: calls={verified_calls2}, tokens={verified_tokens2}")
|
||||
else:
|
||||
logger.debug(f"[llm_text_gen] ✅ COMMIT VERIFICATION PASSED: Values match expected values")
|
||||
else:
|
||||
logger.error(f"[llm_text_gen] ❌ CRITICAL: COMMIT VERIFICATION FAILED! Record not found after commit!")
|
||||
except Exception as verify_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error verifying commit: {verify_error}", exc_info=True)
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG - Shows before/after state in one message
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
# DEBUG: Log the actual values being used
|
||||
logger.debug(f"[llm_text_gen] 📊 FINAL VALUES FOR LOG: calls_before={current_calls_before}, calls_after={new_calls}, tokens_before={current_tokens_before}, tokens_after={new_tokens}, provider={provider_name}, enum={provider_enum}")
|
||||
|
||||
# CRITICAL DEBUG: Print diagnostic info to stdout (always visible)
|
||||
print(f"[DEBUG] Record count: {record_count}, Raw SQL values: calls={current_calls_before}, tokens={current_tokens_before}, Provider: {provider_name}")
|
||||
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
# Estimate tokens
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
|
||||
# Calculate duration (mocking it since we didn't track start time explicitly in this function)
|
||||
# Ideally we should track start_time at beginning of function
|
||||
duration = 0.5
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name=model,
|
||||
prompt=prompt,
|
||||
response_text=response_text,
|
||||
duration=duration
|
||||
)
|
||||
|
||||
except Exception as usage_error:
|
||||
# Non-blocking: log error but don't fail the request
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
@@ -661,208 +349,18 @@ def llm_text_gen(prompt: str, system_prompt: Optional[str] = None, json_struct:
|
||||
if response_text:
|
||||
logger.info(f"[llm_text_gen] ✅ Fallback API call successful, tracking usage for user {user_id}, provider {provider_enum.value}")
|
||||
try:
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
# Estimate tokens from prompt and response
|
||||
# Recalculate input tokens from prompt (consistent with pre-flight estimation)
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
tokens_output = int(len(str(response_text).split()) * 1.3)
|
||||
tokens_total = tokens_input + tokens_output
|
||||
|
||||
# Get or create usage summary
|
||||
from models.subscription_models import UsageSummary
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get limits once for safety check (to prevent exceeding limits even if actual usage > estimate)
|
||||
provider_name = provider_enum.value
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
token_limit = 0
|
||||
if limits and limits.get('limits'):
|
||||
token_limit = limits['limits'].get(f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# CRITICAL: Use raw SQL to read current values directly from DB, bypassing SQLAlchemy cache
|
||||
from sqlalchemy import text
|
||||
current_calls_before = 0
|
||||
current_tokens_before = 0
|
||||
|
||||
try:
|
||||
# Validate provider_name to prevent SQL injection
|
||||
valid_providers = ['gemini', 'openai', 'anthropic', 'mistral']
|
||||
if provider_name not in valid_providers:
|
||||
raise ValueError(f"Invalid provider_name for SQL query: {provider_name}")
|
||||
|
||||
# Read current values directly from database using raw SQL
|
||||
sql_query = text(f"""
|
||||
SELECT {provider_name}_calls, {provider_name}_tokens
|
||||
FROM usage_summaries
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
LIMIT 1
|
||||
""")
|
||||
result = db_track.execute(sql_query, {'user_id': user_id, 'period': current_period}).first()
|
||||
if result:
|
||||
current_calls_before = result[0] if result[0] is not None else 0
|
||||
current_tokens_before = result[1] if result[1] is not None else 0
|
||||
logger.debug(f"[llm_text_gen] Raw SQL read current values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
except Exception as sql_error:
|
||||
logger.warning(f"[llm_text_gen] Raw SQL query failed (fallback), falling back to ORM: {sql_error}")
|
||||
# Fallback to ORM query if raw SQL fails
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
if summary:
|
||||
db_track.refresh(summary)
|
||||
current_calls_before = getattr(summary, f"{provider_name}_calls", 0) or 0
|
||||
current_tokens_before = getattr(summary, f"{provider_name}_tokens", 0) or 0
|
||||
|
||||
# Get or create usage summary object (needed for ORM update)
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush() # Ensure summary is persisted before updating
|
||||
else:
|
||||
# CRITICAL: Update the ORM object with values from raw SQL query
|
||||
# This ensures the ORM object reflects the actual database state before we increment
|
||||
setattr(summary, f"{provider_name}_calls", current_calls_before)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
setattr(summary, f"{provider_name}_tokens", current_tokens_before)
|
||||
logger.debug(f"[llm_text_gen] Synchronized summary object with raw SQL values (fallback): calls={current_calls_before}, tokens={current_tokens_before}")
|
||||
|
||||
# Get "before" state for unified log (from raw SQL query)
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_calls from DB (fallback, raw SQL): {current_calls_before}")
|
||||
|
||||
# Update provider-specific counters (sync operation)
|
||||
new_calls = current_calls_before + 1
|
||||
setattr(summary, f"{provider_name}_calls", new_calls)
|
||||
|
||||
# Update token usage for LLM providers with safety check
|
||||
# Use current_tokens_before from raw SQL query (most reliable)
|
||||
if provider_enum in [APIProvider.GEMINI, APIProvider.OPENAI, APIProvider.ANTHROPIC, APIProvider.MISTRAL]:
|
||||
logger.debug(f"[llm_text_gen] Current {provider_name}_tokens from DB (fallback, raw SQL): {current_tokens_before}")
|
||||
|
||||
# SAFETY CHECK: Prevent exceeding token limit even if actual usage exceeds estimate
|
||||
# This prevents abuse where actual response tokens exceed pre-flight validation estimate
|
||||
projected_new_tokens = current_tokens_before + tokens_total
|
||||
|
||||
# If limit is set (> 0) and would be exceeded, cap at limit
|
||||
if token_limit > 0 and projected_new_tokens > token_limit:
|
||||
logger.warning(
|
||||
f"[llm_text_gen] ⚠️ ACTUAL token usage ({tokens_total}) exceeded estimate in fallback provider. "
|
||||
f"Would exceed limit: {projected_new_tokens} > {token_limit}. "
|
||||
f"Capping tracked tokens at limit to prevent abuse."
|
||||
)
|
||||
# Cap at limit to prevent abuse
|
||||
new_tokens = token_limit
|
||||
# Adjust tokens_total for accurate total tracking
|
||||
tokens_total = token_limit - current_tokens_before
|
||||
if tokens_total < 0:
|
||||
tokens_total = 0
|
||||
else:
|
||||
new_tokens = projected_new_tokens
|
||||
|
||||
setattr(summary, f"{provider_name}_tokens", new_tokens)
|
||||
else:
|
||||
current_tokens_before = 0
|
||||
new_tokens = 0
|
||||
|
||||
# Determine tracked tokens after any safety capping
|
||||
tracked_tokens_input = min(tokens_input, tokens_total)
|
||||
tracked_tokens_output = max(tokens_total - tracked_tokens_input, 0)
|
||||
|
||||
# Calculate and persist cost for this fallback call
|
||||
cost_total = 0.0
|
||||
try:
|
||||
cost_info = pricing.calculate_api_cost(
|
||||
provider=provider_enum,
|
||||
model_name=fallback_model,
|
||||
tokens_input=tracked_tokens_input,
|
||||
tokens_output=tracked_tokens_output,
|
||||
request_count=1
|
||||
)
|
||||
cost_total = cost_info.get('cost_total', 0.0) or 0.0
|
||||
except Exception as cost_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to calculate fallback cost: {cost_error}", exc_info=True)
|
||||
|
||||
if cost_total > 0:
|
||||
update_costs_query = text(f"""
|
||||
UPDATE usage_summaries
|
||||
SET {provider_name}_cost = COALESCE({provider_name}_cost, 0) + :cost,
|
||||
total_cost = COALESCE(total_cost, 0) + :cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_costs_query, {
|
||||
'cost': cost_total,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
setattr(summary, f"{provider_name}_cost", (getattr(summary, f"{provider_name}_cost", 0.0) or 0.0) + cost_total)
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost_total
|
||||
|
||||
# Update totals (using potentially capped tokens_total from safety check)
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.total_tokens = (summary.total_tokens or 0) + tokens_total
|
||||
|
||||
# Get plan details for unified log (limits already retrieved above)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
call_limit = limits['limits'].get(f"{provider_name}_calls", 0) if limits else 0
|
||||
|
||||
# Get image stats for unified log
|
||||
current_images_before = getattr(summary, "stability_calls", 0) or 0
|
||||
image_limit = limits['limits'].get("stability_calls", 0) if limits else 0
|
||||
|
||||
# Get image editing stats for unified log
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
|
||||
# Get video stats for unified log
|
||||
current_video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
|
||||
# Get audio stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
# Only show ∞ for Enterprise tier when limit is 0 (unlimited)
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# CRITICAL: Flush before commit to ensure changes are immediately visible to other sessions
|
||||
db_track.flush() # Flush to ensure changes are in DB (not just in transaction)
|
||||
db_track.commit() # Commit transaction to make changes visible to other sessions
|
||||
logger.info(f"[llm_text_gen] ✅ Successfully tracked fallback usage: user {user_id} -> provider {provider_name} -> {new_calls} calls, {new_tokens} tokens (committed)")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG for fallback
|
||||
# Use actual_provider_name (e.g., "huggingface") instead of enum value (e.g., "mistral")
|
||||
# Include image stats in the log
|
||||
print(f"""
|
||||
[SUBSCRIPTION] LLM Text Generation (Fallback)
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {actual_provider_name}
|
||||
├─ Model: {fallback_model}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {call_limit if call_limit > 0 else '∞'}
|
||||
├─ Tokens: {current_tokens_before} → {new_tokens} / {token_limit if token_limit > 0 else '∞'}
|
||||
├─ Images: {current_images_before} / {image_limit if image_limit > 0 else '∞'}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit if image_edit_limit > 0 else '∞'}
|
||||
├─ Videos: {current_video_calls} / {video_limit if video_limit > 0 else '∞'}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""")
|
||||
except Exception as track_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Error tracking fallback usage (non-blocking): {track_error}", exc_info=True)
|
||||
db_track.rollback()
|
||||
finally:
|
||||
db_track.close()
|
||||
from services.intelligence.agents.agent_usage_tracking import track_agent_usage_sync
|
||||
|
||||
# Estimate tokens
|
||||
tokens_input = int(len(prompt.split()) * 1.3)
|
||||
|
||||
track_agent_usage_sync(
|
||||
user_id=user_id,
|
||||
model_name=fallback_model,
|
||||
prompt=prompt,
|
||||
response_text=response_text,
|
||||
duration=0.5 # Approximate duration
|
||||
)
|
||||
except Exception as usage_error:
|
||||
logger.error(f"[llm_text_gen] ❌ Failed to track fallback usage: {usage_error}", exc_info=True)
|
||||
|
||||
|
||||
@@ -36,6 +36,172 @@ class VideoProviderNotImplemented(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _track_video_operation_usage(
|
||||
user_id: str,
|
||||
provider: str,
|
||||
model: str,
|
||||
operation_type: str,
|
||||
result_bytes: bytes,
|
||||
cost: float,
|
||||
prompt: Optional[str] = None,
|
||||
endpoint: str = "/video-generation",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
log_prefix: str = "[Video Generation]",
|
||||
response_time: float = 0.0
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Reusable usage tracking helper for all video operations.
|
||||
|
||||
Args:
|
||||
user_id: User ID for tracking
|
||||
provider: Provider name
|
||||
model: Model name used
|
||||
operation_type: Type of operation (for logging)
|
||||
result_bytes: Generated video bytes
|
||||
cost: Cost of the operation
|
||||
prompt: Optional prompt text
|
||||
endpoint: API endpoint path
|
||||
metadata: Optional additional metadata
|
||||
log_prefix: Logging prefix
|
||||
response_time: API response time
|
||||
|
||||
Returns:
|
||||
Dictionary with tracking information
|
||||
"""
|
||||
try:
|
||||
from services.database import get_session_for_user
|
||||
db_track = get_session_for_user(user_id)
|
||||
try:
|
||||
from models.subscription_models import UsageSummary, APIUsageLog, APIProvider
|
||||
from services.subscription import PricingService
|
||||
|
||||
pricing = PricingService(db_track)
|
||||
current_period = pricing.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Get or create usage summary
|
||||
summary = db_track.query(UsageSummary).filter(
|
||||
UsageSummary.user_id == user_id,
|
||||
UsageSummary.billing_period == current_period
|
||||
).first()
|
||||
|
||||
if not summary:
|
||||
summary = UsageSummary(
|
||||
user_id=user_id,
|
||||
billing_period=current_period
|
||||
)
|
||||
db_track.add(summary)
|
||||
db_track.flush()
|
||||
|
||||
# Get current values before update
|
||||
current_calls_before = getattr(summary, "video_calls", 0) or 0
|
||||
current_cost_before = getattr(summary, "video_cost", 0.0) or 0.0
|
||||
|
||||
# Update video calls and cost
|
||||
new_calls = current_calls_before + 1
|
||||
new_cost = current_cost_before + cost
|
||||
|
||||
# Use direct SQL UPDATE for dynamic attributes
|
||||
from sqlalchemy import text as sql_text
|
||||
update_query = sql_text("""
|
||||
UPDATE usage_summaries
|
||||
SET video_calls = :new_calls,
|
||||
video_cost = :new_cost
|
||||
WHERE user_id = :user_id AND billing_period = :period
|
||||
""")
|
||||
db_track.execute(update_query, {
|
||||
'new_calls': new_calls,
|
||||
'new_cost': new_cost,
|
||||
'user_id': user_id,
|
||||
'period': current_period
|
||||
})
|
||||
|
||||
# Update total cost
|
||||
summary.total_cost = (summary.total_cost or 0.0) + cost
|
||||
summary.total_calls = (summary.total_calls or 0) + 1
|
||||
summary.updated_at = datetime.utcnow()
|
||||
|
||||
# Create usage log
|
||||
request_size = len(prompt.encode("utf-8")) if prompt else 0
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=APIProvider.WAVESPEED, # Default for video
|
||||
endpoint=endpoint,
|
||||
method="POST",
|
||||
model_used=model or "unknown",
|
||||
actual_provider_name=provider,
|
||||
tokens_input=0,
|
||||
tokens_output=0,
|
||||
tokens_total=0,
|
||||
cost_input=0.0,
|
||||
cost_output=0.0,
|
||||
cost_total=cost,
|
||||
response_time=response_time,
|
||||
status_code=200,
|
||||
request_size=request_size,
|
||||
response_size=len(result_bytes) if result_bytes else 0,
|
||||
billing_period=current_period,
|
||||
)
|
||||
db_track.add(usage_log)
|
||||
|
||||
# Get plan details for unified log
|
||||
limits = pricing.get_user_limits(user_id)
|
||||
plan_name = limits.get('plan_name', 'unknown') if limits else 'unknown'
|
||||
tier = limits.get('tier', 'unknown') if limits else 'unknown'
|
||||
|
||||
# Get limits for display
|
||||
video_limit = limits['limits'].get("video_calls", 0) if limits else 0
|
||||
video_limit_display = video_limit if (video_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
# Get related stats for unified log
|
||||
current_audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_limit = limits['limits'].get("audio_calls", 0) if limits else 0
|
||||
audio_limit_display = audio_limit if (audio_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
current_image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_limit = limits['limits'].get("image_edit_calls", 0) if limits else 0
|
||||
image_edit_limit_display = image_edit_limit if (image_edit_limit > 0 or tier != 'enterprise') else '∞'
|
||||
|
||||
db_track.commit()
|
||||
logger.info(f"{log_prefix} ✅ Successfully tracked usage: user {user_id} -> {operation_type} -> {new_calls} calls, ${cost:.4f}")
|
||||
|
||||
# UNIFIED SUBSCRIPTION LOG
|
||||
operation_name = operation_type.replace("-", " ").title()
|
||||
print(f"""
|
||||
[SUBSCRIPTION] {operation_name}
|
||||
├─ User: {user_id}
|
||||
├─ Plan: {plan_name} ({tier})
|
||||
├─ Provider: {provider}
|
||||
├─ Actual Provider: {provider}
|
||||
├─ Model: {model or 'unknown'}
|
||||
├─ Calls: {current_calls_before} → {new_calls} / {video_limit_display}
|
||||
├─ Cost: ${current_cost_before:.4f} → ${new_cost:.4f}
|
||||
├─ Audio: {current_audio_calls} / {audio_limit_display}
|
||||
├─ Image Editing: {current_image_edit_calls} / {image_edit_limit_display}
|
||||
└─ Status: ✅ Allowed & Tracked
|
||||
""", flush=True)
|
||||
sys.stdout.flush()
|
||||
|
||||
return {
|
||||
"current_calls": new_calls,
|
||||
"cost": cost,
|
||||
"total_cost": new_cost,
|
||||
}
|
||||
|
||||
except Exception as track_error:
|
||||
logger.error(f"{log_prefix} ❌ Error tracking usage (non-blocking): {track_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
db_track.rollback()
|
||||
return {}
|
||||
finally:
|
||||
db_track.close()
|
||||
except Exception as usage_error:
|
||||
logger.error(f"{log_prefix} ❌ Failed to track usage: {usage_error}", exc_info=True)
|
||||
import traceback
|
||||
logger.error(f"{log_prefix} Full traceback: {traceback.format_exc()}")
|
||||
return {}
|
||||
|
||||
|
||||
def _get_api_key(provider: str) -> Optional[str]:
|
||||
try:
|
||||
manager = APIKeyManager()
|
||||
@@ -500,156 +666,74 @@ async def ai_video_generate(
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
logger.info(f"[Video Generation] ✅ Pre-flight validation passed - proceeding with {operation_type}")
|
||||
|
||||
# Progress callback: Initial submission
|
||||
if progress_callback:
|
||||
progress_callback(10.0, f"Submitting {operation_type} request to {provider}...")
|
||||
|
||||
# Generate video based on operation type
|
||||
model_name = kwargs.get("model", _get_default_model(operation_type, provider))
|
||||
|
||||
# Track response time for video generation
|
||||
|
||||
# Track response time
|
||||
import time
|
||||
from datetime import datetime
|
||||
start_time = time.time()
|
||||
|
||||
# Execute operation based on type
|
||||
result = {}
|
||||
try:
|
||||
if operation_type == "text-to-video":
|
||||
if provider == "huggingface":
|
||||
video_bytes = _generate_with_huggingface(
|
||||
prompt=prompt,
|
||||
**kwargs,
|
||||
)
|
||||
# For text-to-video, create metadata dict (HuggingFace doesn't return metadata)
|
||||
result_dict = {
|
||||
video_bytes = _generate_with_huggingface(prompt=prompt, **kwargs)
|
||||
result = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10, # Default cost, will be calculated in track_video_usage
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280, # Default, actual may vary
|
||||
"height": 720, # Default, actual may vary
|
||||
"metadata": {},
|
||||
"model_name": kwargs.get("model", "tencent/HunyuanVideo"),
|
||||
"provider": "huggingface",
|
||||
"cost": 0.0, # HuggingFace inference is free/low cost
|
||||
}
|
||||
elif provider == "wavespeed":
|
||||
# WaveSpeed text-to-video - use unified service
|
||||
result_dict = await _generate_text_to_video_wavespeed(
|
||||
result = await _generate_text_to_video_wavespeed(
|
||||
prompt=prompt,
|
||||
progress_callback=progress_callback,
|
||||
**kwargs,
|
||||
**kwargs
|
||||
)
|
||||
elif provider == "gemini":
|
||||
video_bytes = _generate_with_gemini(prompt=prompt, **kwargs)
|
||||
result_dict = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10,
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"metadata": {},
|
||||
}
|
||||
result = {"video_bytes": _generate_with_gemini(prompt, **kwargs)}
|
||||
elif provider == "openai":
|
||||
video_bytes = _generate_with_openai(prompt=prompt, **kwargs)
|
||||
result_dict = {
|
||||
"video_bytes": video_bytes,
|
||||
"prompt": prompt,
|
||||
"duration": kwargs.get("duration", 5.0),
|
||||
"model_name": model_name,
|
||||
"cost": 0.10,
|
||||
"provider": provider,
|
||||
"resolution": kwargs.get("resolution", "720p"),
|
||||
"width": 1280,
|
||||
"height": 720,
|
||||
"metadata": {},
|
||||
}
|
||||
result = {"video_bytes": _generate_with_openai(prompt, **kwargs)}
|
||||
else:
|
||||
raise RuntimeError(f"Unknown provider for text-to-video: {provider}")
|
||||
|
||||
raise ValueError(f"Unknown provider for text-to-video: {provider}")
|
||||
|
||||
elif operation_type == "image-to-video":
|
||||
if provider == "wavespeed":
|
||||
# Progress callback: Starting generation
|
||||
if progress_callback:
|
||||
progress_callback(20.0, "Video generation in progress...")
|
||||
|
||||
# Handle async call from sync context
|
||||
# Since ai_video_generate is sync, we need to run async function
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
# We're in an async context - use ThreadPoolExecutor to run in new event loop
|
||||
import concurrent.futures
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(
|
||||
asyncio.run,
|
||||
_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
result_dict = future.result()
|
||||
else:
|
||||
# Event loop exists but not running - use it
|
||||
result_dict = loop.run_until_complete(_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
))
|
||||
except RuntimeError:
|
||||
# No event loop exists, create a new one
|
||||
result_dict = asyncio.run(_generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or kwargs.get("prompt", ""),
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
))
|
||||
video_bytes = result_dict["video_bytes"]
|
||||
model_name = result_dict.get("model_name", model_name)
|
||||
|
||||
# Progress callback: Processing result
|
||||
if progress_callback:
|
||||
progress_callback(90.0, "Processing video result...")
|
||||
result = await _generate_image_to_video_wavespeed(
|
||||
image_data=image_data,
|
||||
image_base64=image_base64,
|
||||
prompt=prompt or "",
|
||||
progress_callback=progress_callback,
|
||||
**kwargs
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Unknown provider for image-to-video: {provider}. Only 'wavespeed' is supported.")
|
||||
raise ValueError(f"Unknown provider for image-to-video: {provider}")
|
||||
|
||||
# Track usage (same pattern as text generation)
|
||||
# Use cost from result_dict if available, otherwise calculate
|
||||
response_time = time.time() - start_time
|
||||
cost_override = result_dict.get("cost") if operation_type == "image-to-video" else kwargs.get("cost_override")
|
||||
track_video_usage(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
model_name=model_name,
|
||||
prompt=result_dict.get("prompt", prompt or ""),
|
||||
video_bytes=video_bytes,
|
||||
cost_override=cost_override,
|
||||
response_time=response_time,
|
||||
)
|
||||
|
||||
# Progress callback: Complete
|
||||
if progress_callback:
|
||||
progress_callback(100.0, "Video generation complete!")
|
||||
|
||||
return result_dict
|
||||
|
||||
except HTTPException:
|
||||
# Re-raise HTTPExceptions (e.g., from validation or API errors)
|
||||
raise
|
||||
# TRACK USAGE after successful API call
|
||||
video_bytes = result.get("video_bytes")
|
||||
if user_id and video_bytes:
|
||||
_track_video_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=result.get("provider", provider),
|
||||
model=result.get("model_name", kwargs.get("model", "unknown")),
|
||||
operation_type=operation_type,
|
||||
result_bytes=video_bytes,
|
||||
cost=result.get("cost", 0.0),
|
||||
prompt=prompt,
|
||||
endpoint="/video-generation",
|
||||
metadata=result.get("metadata"),
|
||||
log_prefix=f"[{operation_type.replace('-', ' ').title()}]",
|
||||
response_time=response_time
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[video_gen] Error during video generation: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail={"error": str(e)})
|
||||
# Log failure but don't track usage (no cost incurred)
|
||||
logger.error(f"[video_gen] Generation failed: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def _get_default_model(operation_type: str, provider: str) -> str:
|
||||
|
||||
@@ -46,6 +46,9 @@ class CorePersonaService:
|
||||
# Get schema for structured response
|
||||
persona_schema = self.prompt_builder.get_persona_schema()
|
||||
|
||||
# Extract user_id for tracking
|
||||
user_id = onboarding_data.get("session_info", {}).get("user_id")
|
||||
|
||||
try:
|
||||
# Generate structured response using Gemini
|
||||
response = gemini_structured_json_response(
|
||||
@@ -53,7 +56,8 @@ class CorePersonaService:
|
||||
schema=persona_schema,
|
||||
temperature=0.2, # Low temperature for consistent analysis
|
||||
max_tokens=8192,
|
||||
system_prompt="You are an expert writing style analyst and persona developer. Analyze the provided data to create a precise, actionable writing persona."
|
||||
system_prompt="You are an expert writing style analyst and persona developer. Analyze the provided data to create a precise, actionable writing persona.",
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if "error" in response:
|
||||
@@ -103,13 +107,17 @@ class CorePersonaService:
|
||||
# Get platform-specific schema
|
||||
platform_schema = self.prompt_builder.get_platform_schema()
|
||||
|
||||
# Extract user_id for tracking
|
||||
user_id = onboarding_data.get("session_info", {}).get("user_id")
|
||||
|
||||
try:
|
||||
response = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=platform_schema,
|
||||
temperature=0.2,
|
||||
max_tokens=4096,
|
||||
system_prompt=f"You are an expert in {platform} content strategy and platform-specific writing optimization."
|
||||
system_prompt=f"You are an expert in {platform} content strategy and platform-specific writing optimization.",
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -62,6 +62,9 @@ class FacebookPersonaService:
|
||||
|
||||
# Get Facebook-specific schema
|
||||
schema = self._get_enhanced_facebook_schema()
|
||||
|
||||
# Extract user_id for tracking
|
||||
user_id = onboarding_data.get("session_info", {}).get("user_id")
|
||||
|
||||
# Generate structured response using Gemini with optimized prompts
|
||||
response = gemini_structured_json_response(
|
||||
@@ -69,7 +72,8 @@ class FacebookPersonaService:
|
||||
schema=schema,
|
||||
temperature=0.2,
|
||||
max_tokens=4096,
|
||||
system_prompt=system_prompt
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if not response or "error" in response:
|
||||
|
||||
@@ -54,13 +54,17 @@ class LinkedInPersonaService:
|
||||
# Get LinkedIn-specific schema
|
||||
schema = self.schemas.get_enhanced_linkedin_schema()
|
||||
|
||||
# Extract user_id for tracking
|
||||
user_id = onboarding_data.get("session_info", {}).get("user_id")
|
||||
|
||||
# Generate structured response using Gemini with optimized prompts
|
||||
response = gemini_structured_json_response(
|
||||
prompt=prompt,
|
||||
schema=schema,
|
||||
temperature=0.2,
|
||||
max_tokens=4096,
|
||||
system_prompt=system_prompt
|
||||
system_prompt=system_prompt,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
if "error" in response:
|
||||
|
||||
@@ -56,6 +56,17 @@ async def check_and_execute_due_tasks(scheduler: 'TaskScheduler'):
|
||||
continue
|
||||
|
||||
try:
|
||||
# Check onboarding status first
|
||||
# Skip users who haven't completed onboarding to prevent premature agent initialization
|
||||
from services.onboarding.progress_service import OnboardingProgressService
|
||||
onboarding_service = OnboardingProgressService()
|
||||
status = onboarding_service.get_onboarding_status(user_id)
|
||||
|
||||
if not status.get("is_completed", False):
|
||||
# Skip logging for inactive users to reduce noise, unless debugging
|
||||
# logger.debug(f"[Scheduler Check] Skipping user {user_id} - Onboarding incomplete")
|
||||
continue
|
||||
|
||||
# Check active strategies for this user (for interval adjustment)
|
||||
try:
|
||||
from services.active_strategy_service import ActiveStrategyService
|
||||
|
||||
@@ -67,6 +67,27 @@ class SIFIndexingExecutor(TaskExecutor):
|
||||
# 2. Sync User Website Content (Deep Crawl / Snapshot)
|
||||
content_synced = await sif_service.sync_user_website_content(website_url)
|
||||
|
||||
# 3. Trigger Content Guardian Audit (Background Analysis)
|
||||
# This ensures the agent runs immediately after new data is indexed
|
||||
guardian_report = None
|
||||
if content_synced:
|
||||
try:
|
||||
from services.intelligence.agents.specialized_agents import ContentGuardianAgent
|
||||
# Re-use the intelligence service from sif_service
|
||||
guardian_agent = ContentGuardianAgent(
|
||||
intelligence_service=sif_service.intelligence_service,
|
||||
user_id=user_id,
|
||||
sif_service=sif_service
|
||||
)
|
||||
|
||||
logger.info("Triggering Content Guardian Site Audit...")
|
||||
guardian_report = await guardian_agent.perform_site_audit(website_url)
|
||||
|
||||
# Persist the audit report (optional, or rely on logs/alerts)
|
||||
# For now, we just include it in the task result
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to run Content Guardian audit: {e}")
|
||||
|
||||
# Determine overall success
|
||||
# We consider it a success if at least one operation worked, or if both were attempted without error
|
||||
# But ideally, content sync is the heavy lifter.
|
||||
@@ -91,6 +112,7 @@ class SIFIndexingExecutor(TaskExecutor):
|
||||
task_log.result_data = {
|
||||
"metadata_synced": metadata_synced,
|
||||
"content_synced": content_synced,
|
||||
"guardian_report": guardian_report,
|
||||
"website_url": website_url
|
||||
}
|
||||
task_log.execution_time_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
@@ -29,9 +29,10 @@ def load_due_sif_indexing_tasks(db: Session, user_id: str = None) -> List[SIFInd
|
||||
query = db.query(SIFIndexingTask).filter(
|
||||
or_(
|
||||
SIFIndexingTask.status == "pending",
|
||||
SIFIndexingTask.status == "active",
|
||||
SIFIndexingTask.status == "failed" # Retry failed tasks
|
||||
),
|
||||
SIFIndexingTask.next_run_at <= datetime.utcnow()
|
||||
SIFIndexingTask.next_execution <= datetime.utcnow()
|
||||
)
|
||||
|
||||
if user_id:
|
||||
|
||||
@@ -199,6 +199,24 @@ class PricingService:
|
||||
"cost_per_input_token": 0.0, # No additional token cost for grounding
|
||||
"cost_per_output_token": 0.0, # No additional token cost for grounding
|
||||
"description": "Grounding with Google Search - 1,500 RPD free, then $35/1K requests"
|
||||
},
|
||||
# Alwrity Voice Cloning - Qwen3
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
"model_name": "alwrity-ai/qwen3-tts/voice-clone",
|
||||
"cost_per_request": 0.10,
|
||||
"cost_per_input_token": 0.00001,
|
||||
"cost_per_output_token": 0.0,
|
||||
"description": "Alwrity Qwen3 Voice Clone (Efficient)"
|
||||
},
|
||||
# Alwrity Voice Cloning - CosyVoice
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
"model_name": "alwrity-ai/cosyvoice/voice-clone",
|
||||
"cost_per_request": 0.15,
|
||||
"cost_per_input_token": 0.00001,
|
||||
"cost_per_output_token": 0.0,
|
||||
"description": "Alwrity CosyVoice Clone (High Fidelity)"
|
||||
}
|
||||
]
|
||||
|
||||
@@ -402,11 +420,19 @@ class PricingService:
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
"model_name": "wavespeed-ai/qwen3-tts/voice-clone",
|
||||
"cost_per_request": 0.0,
|
||||
"cost_per_input_token": 0.0,
|
||||
"cost_per_request": 0.005,
|
||||
"cost_per_input_token": 0.00005,
|
||||
"cost_per_output_token": 0.0,
|
||||
"description": "Qwen3-TTS Voice Clone via WaveSpeed (cost depends on text length)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
"model_name": "wavespeed-ai/cosyvoice-tts/voice-clone",
|
||||
"cost_per_request": 0.005,
|
||||
"cost_per_input_token": 0.00005,
|
||||
"cost_per_output_token": 0.0,
|
||||
"description": "CosyVoice-TTS Voice Clone via WaveSpeed (cost depends on text length)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
"model_name": "default",
|
||||
@@ -429,8 +455,9 @@ class PricingService:
|
||||
|
||||
if existing:
|
||||
# Update existing pricing (especially for HuggingFace if env vars changed)
|
||||
if pricing_data["provider"] == APIProvider.MISTRAL:
|
||||
# Update HuggingFace pricing from env vars
|
||||
if pricing_data["provider"] in [APIProvider.MISTRAL, APIProvider.AUDIO]:
|
||||
# Update pricing
|
||||
existing.cost_per_request = pricing_data.get("cost_per_request", 0.0)
|
||||
existing.cost_per_input_token = pricing_data["cost_per_input_token"]
|
||||
existing.cost_per_output_token = pricing_data["cost_per_output_token"]
|
||||
existing.description = pricing_data["description"]
|
||||
|
||||
@@ -490,6 +490,32 @@ class UsageTrackingService:
|
||||
'cost': image_edit_cost
|
||||
}
|
||||
|
||||
# WaveSpeed (aggregated across Video, Audio, Image, Image Edit)
|
||||
# Query APIUsageLog directly to get accurate WaveSpeed-specific usage
|
||||
wavespeed_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.billing_period == billing_period,
|
||||
APIUsageLog.actual_provider_name == "wavespeed"
|
||||
).all()
|
||||
|
||||
if wavespeed_logs:
|
||||
wavespeed_calls = len(wavespeed_logs)
|
||||
wavespeed_tokens = sum((log.tokens_total or 0) for log in wavespeed_logs)
|
||||
wavespeed_cost = sum(float(log.cost_total or 0.0) for log in wavespeed_logs)
|
||||
|
||||
provider_breakdown['wavespeed'] = {
|
||||
'calls': wavespeed_calls,
|
||||
'tokens': wavespeed_tokens,
|
||||
'cost': wavespeed_cost
|
||||
}
|
||||
logger.info(f"[UsageStats] Calculated WaveSpeed usage: {wavespeed_calls} calls, ${wavespeed_cost:.6f}")
|
||||
else:
|
||||
provider_breakdown['wavespeed'] = {
|
||||
'calls': 0,
|
||||
'tokens': 0,
|
||||
'cost': 0.0
|
||||
}
|
||||
|
||||
# Search APIs
|
||||
tavily_calls = getattr(summary, "tavily_calls", 0) or 0
|
||||
tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0
|
||||
|
||||
@@ -12,6 +12,7 @@ from loguru import logger
|
||||
from services.image_studio.infinitetalk_adapter import InfiniteTalkService
|
||||
from services.video_studio.hunyuan_avatar_adapter import HunyuanAvatarService
|
||||
from utils.logger_utils import get_service_logger
|
||||
from services.llm_providers.main_video_generation import _track_video_operation_usage
|
||||
|
||||
logger = get_service_logger("video_studio.avatar")
|
||||
|
||||
@@ -58,6 +59,30 @@ class AvatarStudioService:
|
||||
f"[AvatarStudio] Creating talking avatar: user={user_id}, resolution={resolution}, model={model}"
|
||||
)
|
||||
|
||||
# PRE-FLIGHT VALIDATION: Validate video generation before API call
|
||||
# MUST happen BEFORE any API calls - return immediately if validation fails
|
||||
from services.database import get_db
|
||||
from services.subscription import PricingService
|
||||
from services.subscription.preflight_validator import validate_video_generation_operations
|
||||
|
||||
db = next(get_db())
|
||||
try:
|
||||
pricing_service = PricingService(db)
|
||||
# Raises HTTPException immediately if validation fails - frontend gets immediate response
|
||||
validate_video_generation_operations(
|
||||
pricing_service=pricing_service,
|
||||
user_id=user_id
|
||||
)
|
||||
except HTTPException:
|
||||
# Re-raise immediately - don't proceed with API call
|
||||
logger.error(f"[AvatarStudio] ❌ Pre-flight validation failed - blocking API call")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if model == "hunyuan-avatar":
|
||||
# Use Hunyuan Avatar (doesn't support mask_image)
|
||||
@@ -82,12 +107,32 @@ class AvatarStudioService:
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
response_time = time.time() - start_time
|
||||
|
||||
logger.info(
|
||||
f"[AvatarStudio] ✅ Talking avatar created: "
|
||||
f"model={model}, resolution={resolution}, duration={result.get('duration', 0)}s, "
|
||||
f"cost=${result.get('cost', 0):.2f}"
|
||||
)
|
||||
|
||||
# TRACK USAGE after successful API call
|
||||
# Use video_bytes if available, otherwise check if result itself is bytes (unlikely, dict expected)
|
||||
video_bytes = result.get("video_bytes")
|
||||
if user_id and video_bytes:
|
||||
_track_video_operation_usage(
|
||||
user_id=user_id,
|
||||
provider=model, # Use model name as provider/actual_provider for now
|
||||
model=model,
|
||||
operation_type="talking-avatar",
|
||||
result_bytes=video_bytes,
|
||||
cost=result.get("cost", 0.0),
|
||||
prompt=prompt,
|
||||
endpoint="/avatar-generation",
|
||||
metadata=result,
|
||||
log_prefix="[Avatar Generation]",
|
||||
response_time=response_time
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
|
||||
@@ -324,6 +324,39 @@ class WaveSpeedClient:
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
def voice_design(
    self,
    text: str,
    voice_description: str,
    language: str = "auto",
    timeout: int = 180,
) -> bytes:
    """Delegate a voice-design request to the speech generator.

    All arguments are forwarded unchanged, by keyword, to
    ``self.speech.voice_design`` and its result is returned as-is.

    Args:
        text: Text to synthesize.
        voice_description: Description of the desired voice.
        language: Language hint; defaults to ``"auto"``.
        timeout: Timeout value passed through to the speech generator.

    Returns:
        Whatever ``self.speech.voice_design`` returns (annotated as bytes).
    """
    forwarded = {
        "text": text,
        "voice_description": voice_description,
        "language": language,
        "timeout": timeout,
    }
    return self.speech.voice_design(**forwarded)
|
||||
|
||||
def cosyvoice_voice_clone(
    self,
    audio_bytes: bytes,
    text: str,
    *,
    model: str = "wavespeed-ai/cosyvoice-tts/voice-clone",
    audio_mime_type: str = "audio/wav",
    reference_text: Optional[str] = None,
    timeout: int = 180,
) -> bytes:
    """Delegate a CosyVoice voice-clone request to the speech generator.

    Every argument is forwarded unchanged, by keyword, to
    ``self.speech.cosyvoice_voice_clone`` and its result is returned as-is.

    Args:
        audio_bytes: Reference audio sample.
        text: Text to synthesize.
        model: Model path to use (keyword-only).
        audio_mime_type: MIME type of ``audio_bytes`` (keyword-only).
        reference_text: Optional reference text (keyword-only).
        timeout: Timeout value passed through (keyword-only).

    Returns:
        Whatever ``self.speech.cosyvoice_voice_clone`` returns
        (annotated as bytes).
    """
    forwarded = {
        "audio_bytes": audio_bytes,
        "text": text,
        "model": model,
        "audio_mime_type": audio_mime_type,
        "reference_text": reference_text,
        "timeout": timeout,
    }
    return self.speech.cosyvoice_voice_clone(**forwarded)
|
||||
|
||||
def generate_text_video(
|
||||
self,
|
||||
prompt: str,
|
||||
|
||||
@@ -146,14 +146,44 @@ class PromptGenerator:
|
||||
if isinstance(first_output, str):
|
||||
if first_output.startswith("http://") or first_output.startswith("https://"):
|
||||
logger.info(f"[WaveSpeed] Fetching optimized prompt from URL: {first_output}")
|
||||
url_response = requests.get(first_output, timeout=timeout)
|
||||
if url_response.status_code == 200:
|
||||
return url_response.text.strip()
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
|
||||
|
||||
# Use stream=True to avoid downloading large files into memory
|
||||
try:
|
||||
with requests.get(first_output, timeout=timeout, stream=True) as url_response:
|
||||
if url_response.status_code == 200:
|
||||
# Check Content-Length if available
|
||||
content_length = url_response.headers.get("Content-Length")
|
||||
if content_length and int(content_length) > 1024 * 1024: # 1MB limit for prompts
|
||||
logger.error(f"[WaveSpeed] Optimized prompt URL content too large: {content_length} bytes")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer returned a file that is too large",
|
||||
)
|
||||
|
||||
# Read content with limit
|
||||
content = ""
|
||||
for chunk in url_response.iter_content(chunk_size=8192, decode_unicode=True):
|
||||
if chunk:
|
||||
content += chunk
|
||||
if len(content) > 1024 * 1024: # Hard limit 1MB
|
||||
logger.error("[WaveSpeed] Optimized prompt URL content exceeded 1MB limit during download")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="WaveSpeed prompt optimizer returned a file that is too large",
|
||||
)
|
||||
|
||||
return content.strip()
|
||||
else:
|
||||
logger.error(f"[WaveSpeed] Failed to fetch prompt from URL: {url_response.status_code}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch optimized prompt from WaveSpeed URL",
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
logger.error(f"[WaveSpeed] Error fetching prompt from URL: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=502,
|
||||
detail="Failed to fetch optimized prompt from WaveSpeed URL",
|
||||
detail=f"Error fetching optimized prompt: {str(e)}",
|
||||
)
|
||||
else:
|
||||
# It's already the text
|
||||
|
||||
@@ -181,6 +181,102 @@ class SpeechGenerator:
|
||||
audio_url = self._extract_audio_url(outputs)
|
||||
return self._download_audio(audio_url, timeout)
|
||||
|
||||
def voice_design(
    self,
    text: str,
    voice_description: str,
    language: str = "auto",
    timeout: int = 180,
) -> bytes:
    """
    Generate speech using Qwen3 Voice Design (text + voice description).

    Submits the request to the voice-design endpoint. If the response
    carries a prediction id, the result is obtained by polling via
    ``self._poll_prediction_result``; if it already carries outputs, the
    first output is downloaded directly via ``self._download_audio``.

    Args:
        text: Text to synthesize.
        voice_description: Description of the desired voice.
        language: Language hint sent in the payload; defaults to "auto".
        timeout: Passed to the polling / download helpers. The initial
            POST uses a fixed (connect, read) timeout of (30, 90).

    Returns:
        The result of the polling or download helper (annotated as bytes).

    Raises:
        HTTPException: 504 if the POST times out or the connection fails;
            the upstream status code if the POST returns non-200; 500 if
            the response body cannot be parsed or has an unexpected shape.
            HTTPExceptions raised by the polling / download helpers
            propagate unchanged.
    """
    url = f"{self.base_url}/wavespeed-ai/qwen3-tts/voice-design"

    payload = {
        "text": text,
        "voice_description": voice_description,
        "language": language
    }

    logger.info(f"[WaveSpeed] Voice design via {url}")

    try:
        response = requests.post(
            url,
            headers=self._get_headers(),
            json=payload,
            timeout=(30, 90),
        )
    except requests_exceptions.Timeout as e:
        raise HTTPException(status_code=504, detail={"error": "WaveSpeed Voice Design timed out", "message": str(e)})
    except (requests_exceptions.ConnectionError, requests_exceptions.ConnectTimeout) as e:
        raise HTTPException(status_code=504, detail={"error": "WaveSpeed Voice Design connection failed", "message": str(e)})

    if response.status_code != 200:
        raise HTTPException(
            status_code=response.status_code,
            detail={"error": "WaveSpeed Voice Design failed", "message": response.text}
        )

    try:
        data = response.json()
        result = data.get("data") if isinstance(data, dict) else None

        if isinstance(result, dict):
            # Normal path: the API is asynchronous and returns a task id to poll.
            if "id" in result:
                return self._poll_prediction_result(result["id"], timeout=timeout)

            # Fallback: the response already contains outputs. Accept either a
            # plain URL string (the shape the polling endpoint returns) or a
            # dict carrying a "url" key.
            outputs = result.get("outputs")
            if outputs:
                first_output = outputs[0]
                audio_url = first_output if isinstance(first_output, str) else first_output["url"]
                return self._download_audio(audio_url, timeout)

        raise ValueError(f"Unexpected response format: {data}")

    except HTTPException:
        # Errors from polling/downloading already carry the right status
        # (e.g. 504 on timeout); do not relabel them as a parse failure.
        raise
    except Exception as e:
        logger.error(f"[WaveSpeed] Error parsing Voice Design response: {e}")
        raise HTTPException(status_code=500, detail={"error": "Failed to parse Voice Design response", "message": str(e)}) from e
|
||||
|
||||
def _poll_prediction_result(self, request_id: str, timeout: int = 180) -> bytes:
    """Poll a WaveSpeed prediction until it finishes, then download its audio.

    The prediction result endpoint is queried roughly once per second.
    Transient problems (non-200 responses, request errors, unparsable
    JSON) are logged and retried until the deadline. Terminal states end
    the loop immediately instead of being retried.

    Args:
        request_id: Prediction id returned by the submitting request.
        timeout: Maximum number of seconds to keep polling; also passed
            to ``self._download_audio``.

    Returns:
        The result of ``self._download_audio`` for the first output
        (annotated as bytes).

    Raises:
        HTTPException: 502 if the task reports status "failed" or
            "completed" without any outputs; 504 if no terminal status is
            seen before ``timeout`` elapses. Errors raised by
            ``self._download_audio`` propagate unchanged.
    """
    import time

    url = f"https://api.wavespeed.ai/api/v3/predictions/{request_id}/result"
    start_time = time.time()

    while time.time() - start_time < timeout:
        result = None
        try:
            response = requests.get(url, headers=self._get_headers(), timeout=10)
            if response.status_code == 200:
                body = response.json()
                data = body.get("data") if isinstance(body, dict) else None
                result = data if isinstance(data, dict) else {}
            else:
                logger.warning(f"Polling error {response.status_code}: {response.text}")
        except (requests_exceptions.RequestException, ValueError) as e:
            # Network hiccup or invalid JSON: treat as transient and retry.
            logger.error(f"Polling exception: {e}")

        # Terminal states are handled outside the try block so they are
        # never mistaken for transient polling errors and retried.
        if result is not None:
            status = result.get("status")

            if status == "completed":
                outputs = result.get("outputs")
                if not outputs:
                    raise HTTPException(
                        status_code=502,
                        detail={
                            "error": "WaveSpeed Voice Design failed",
                            "message": "Completed task has no output URLs",
                        },
                    )
                audio_url = outputs[0]  # It's a URL string in the array
                return self._download_audio(audio_url, timeout)

            if status == "failed":
                raise HTTPException(
                    status_code=502,
                    detail={
                        "error": "WaveSpeed Voice Design failed",
                        "message": f"Task failed: {result.get('error')}",
                    },
                )

        # Still processing/created (or a transient error): wait and retry.
        time.sleep(1)

    raise HTTPException(status_code=504, detail="Voice Design generation timed out")
|
||||
|
||||
def voice_clone(
|
||||
self,
|
||||
audio_bytes: bytes,
|
||||
@@ -320,6 +416,70 @@ class SpeechGenerator:
|
||||
audio_url = self._extract_audio_url(outputs)
|
||||
return self._download_audio(audio_url, timeout)
|
||||
|
||||
def cosyvoice_voice_clone(
    self,
    audio_bytes: bytes,
    text: str,
    *,
    model: str = "wavespeed-ai/cosyvoice-tts/voice-clone",
    audio_mime_type: str = "audio/wav",
    reference_text: Optional[str] = None,
    timeout: int = 180,
) -> bytes:
    """Clone a voice with the WaveSpeed CosyVoice model and return the audio.

    The reference audio is sent inline as a base64 ``data:`` URL. If the API
    answers with a pending prediction (no outputs yet), the prediction is
    polled through ``self.polling`` until it completes.

    Args:
        audio_bytes: Raw bytes of the reference audio sample.
        text: Text to synthesize.
        model: Model path appended to ``self.base_url``.
        audio_mime_type: MIME type used in the data URL; falls back to
            ``audio/wav`` when empty.
        reference_text: Optional text sent as ``reference_text``; omitted
            from the payload when falsy.
        timeout: Seconds allowed for polling and for the audio download.
            The initial POST uses its own fixed (connect, read) timeout.

    Returns:
        Whatever ``_download_audio`` returns for the extracted audio URL.

    Raises:
        HTTPException: 504 on request timeout or connection failure; 502 on
            a non-200 response, an unparseable or non-object response body,
            or a result without outputs.
    """
    url = f"{self.base_url}/{model}"

    mime = audio_mime_type or "audio/wav"
    audio_b64 = base64.b64encode(audio_bytes).decode("utf-8")
    payload = {
        "audio": f"data:{mime};base64,{audio_b64}",
        "text": text,
    }
    if reference_text:
        payload["reference_text"] = reference_text

    logger.info(f"[WaveSpeed] CosyVoice voice clone via {url}")

    try:
        response = requests.post(
            url,
            headers=self._get_headers(),
            json=payload,
            timeout=(30, 90),
        )
    except requests_exceptions.Timeout as e:
        raise HTTPException(status_code=504, detail={"error": "WaveSpeed CosyVoice voice clone timed out", "message": str(e)})
    except (requests_exceptions.ConnectionError, requests_exceptions.ConnectTimeout) as e:
        raise HTTPException(status_code=504, detail={"error": "WaveSpeed CosyVoice voice clone connection failed", "message": str(e)})

    if response.status_code != 200:
        raise HTTPException(
            status_code=502,
            detail={
                "error": "WaveSpeed CosyVoice voice clone failed",
                "status_code": response.status_code,
                "response": response.text,
            },
        )

    # A 200 response with a non-JSON body would otherwise escape as a raw
    # ValueError; report it as an upstream (502) error like the other paths.
    try:
        response_json = response.json()
    except ValueError as e:
        raise HTTPException(
            status_code=502,
            detail={"error": "WaveSpeed CosyVoice voice clone returned invalid JSON", "message": str(e)},
        )

    # The payload may be wrapped in a "data" envelope or returned bare.
    data = response_json.get("data") if isinstance(response_json, dict) else None
    data = data or response_json
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=502,
            detail={
                "error": "WaveSpeed CosyVoice voice clone returned an unexpected response",
                "response": response.text,
            },
        )

    outputs = data.get("outputs") or []
    status = data.get("status")
    prediction_id = data.get("id")

    # No outputs yet and the task is still pending: poll for the final result.
    if not outputs and prediction_id and status in {"created", "processing"}:
        result = self.polling.poll_until_complete(prediction_id, timeout_seconds=timeout, interval_seconds=0.8)
        outputs = result.get("outputs") or []

    if not outputs:
        raise HTTPException(status_code=502, detail="WaveSpeed CosyVoice voice clone returned no outputs")

    audio_url = self._extract_audio_url(outputs)
    return self._download_audio(audio_url, timeout)
|
||||
|
||||
def _extract_audio_url(self, outputs: list) -> str:
|
||||
"""Extract audio URL from outputs."""
|
||||
if not isinstance(outputs, list) or len(outputs) == 0:
|
||||
|
||||
Reference in New Issue
Block a user