AI Analysis and Content Strategy fixes. Enhanced Strategy Routes refactoring.
This commit is contained in:
@@ -182,4 +182,4 @@ This package consolidates the following previously scattered files:
|
||||
|
||||
- `services.onboarding` - Onboarding and user setup
|
||||
- `models.subscription_models` - Database models
|
||||
- `api.subscription_api` - API endpoints
|
||||
- `api.subscription` - API endpoints (modular structure with routes in `api/subscription/routes/`)
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
"""
|
||||
Log Wrapping Service
|
||||
Intelligently wraps API usage logs when they exceed 5000 records.
|
||||
Intelligently wraps API usage logs when they exceed limits (count or time-based).
|
||||
Aggregates old logs into cumulative records while preserving historical data.
|
||||
|
||||
Features:
|
||||
- Count-based retention: Keeps 4,000 most recent detailed logs
|
||||
- Time-based retention: Aggregates logs older than 90 days
|
||||
- Automatic aggregation: Triggered on log queries
|
||||
- Context preservation: Maintains costs, tokens, counts, success rates
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
@@ -18,13 +24,18 @@ class LogWrappingService:
|
||||
|
||||
MAX_LOGS_PER_USER = 5000
|
||||
AGGREGATION_THRESHOLD_DAYS = 30 # Aggregate logs older than 30 days
|
||||
RETENTION_DAYS = 90 # Time-based retention: aggregate logs older than 90 days
|
||||
|
||||
def __init__(self, db: Session):
|
||||
self.db = db
|
||||
|
||||
def check_and_wrap_logs(self, user_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Check if user has exceeded log limit and wrap if necessary.
|
||||
Check if user has exceeded log limit (count or time-based) and wrap if necessary.
|
||||
|
||||
Checks both:
|
||||
1. Count-based: If user has more than MAX_LOGS_PER_USER logs
|
||||
2. Time-based: If user has logs older than RETENTION_DAYS
|
||||
|
||||
Returns:
|
||||
Dict with wrapping status and statistics
|
||||
@@ -35,18 +46,42 @@ class LogWrappingService:
|
||||
APIUsageLog.user_id == user_id
|
||||
).scalar() or 0
|
||||
|
||||
if total_count <= self.MAX_LOGS_PER_USER:
|
||||
# Check for logs older than retention period
|
||||
retention_cutoff = datetime.utcnow() - timedelta(days=self.RETENTION_DAYS)
|
||||
old_logs_count = self.db.query(func.count(APIUsageLog.id)).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.timestamp < retention_cutoff,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate already aggregated logs
|
||||
).scalar() or 0
|
||||
|
||||
# Determine if wrapping is needed
|
||||
count_based_trigger = total_count > self.MAX_LOGS_PER_USER
|
||||
time_based_trigger = old_logs_count > 0
|
||||
|
||||
if not count_based_trigger and not time_based_trigger:
|
||||
return {
|
||||
'wrapped': False,
|
||||
'total_logs': total_count,
|
||||
'old_logs': old_logs_count,
|
||||
'max_logs': self.MAX_LOGS_PER_USER,
|
||||
'message': f'Log count ({total_count}) is within limit ({self.MAX_LOGS_PER_USER})'
|
||||
'retention_days': self.RETENTION_DAYS,
|
||||
'message': f'Log count ({total_count}) and age are within limits'
|
||||
}
|
||||
|
||||
# Need to wrap logs - aggregate old logs
|
||||
logger.info(f"[LogWrapping] User {user_id} has {total_count} logs, exceeding limit of {self.MAX_LOGS_PER_USER}. Starting wrap...")
|
||||
# Determine trigger reason
|
||||
trigger_reasons = []
|
||||
if count_based_trigger:
|
||||
trigger_reasons.append(f'count limit ({total_count} > {self.MAX_LOGS_PER_USER})')
|
||||
if time_based_trigger:
|
||||
trigger_reasons.append(f'time-based retention ({old_logs_count} logs older than {self.RETENTION_DAYS} days)')
|
||||
|
||||
wrap_result = self._wrap_old_logs(user_id, total_count)
|
||||
logger.info(
|
||||
f"[LogWrapping] User {user_id} needs log wrapping. "
|
||||
f"Total: {total_count}, Old logs: {old_logs_count}. "
|
||||
f"Triggers: {', '.join(trigger_reasons)}"
|
||||
)
|
||||
|
||||
wrap_result = self._wrap_old_logs(user_id, total_count, time_based=time_based_trigger)
|
||||
|
||||
return {
|
||||
'wrapped': True,
|
||||
@@ -54,6 +89,8 @@ class LogWrappingService:
|
||||
'total_logs_after': wrap_result['logs_remaining'],
|
||||
'aggregated_logs': wrap_result['aggregated_count'],
|
||||
'aggregated_periods': wrap_result['periods'],
|
||||
'trigger_reasons': trigger_reasons,
|
||||
'old_logs_aggregated': wrap_result.get('old_logs_aggregated', 0),
|
||||
'message': f'Wrapped {wrap_result["aggregated_count"]} logs into {len(wrap_result["periods"])} aggregated records'
|
||||
}
|
||||
|
||||
@@ -65,30 +102,76 @@ class LogWrappingService:
|
||||
'message': f'Error wrapping logs: {str(e)}'
|
||||
}
|
||||
|
||||
def _wrap_old_logs(self, user_id: str, total_count: int) -> Dict[str, Any]:
|
||||
def _wrap_old_logs(self, user_id: str, total_count: int, time_based: bool = False) -> Dict[str, Any]:
|
||||
"""
|
||||
Aggregate old logs into cumulative records.
|
||||
|
||||
Strategy:
|
||||
1. Keep most recent 4000 logs (detailed)
|
||||
2. Aggregate logs older than 30 days or oldest logs beyond 4000
|
||||
3. Create aggregated records grouped by provider and billing period
|
||||
4. Delete individual logs that were aggregated
|
||||
1. Keep most recent 4000 logs (detailed) - count-based
|
||||
2. Aggregate logs older than RETENTION_DAYS - time-based
|
||||
3. Aggregate oldest logs beyond 4000 limit - count-based
|
||||
4. Create aggregated records grouped by provider and billing period
|
||||
5. Delete individual logs that were aggregated
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
total_count: Total number of logs for user
|
||||
time_based: If True, prioritize time-based retention over count-based
|
||||
"""
|
||||
try:
|
||||
# Calculate how many logs to keep (4000 detailed, rest aggregated)
|
||||
# Calculate retention cutoff date
|
||||
retention_cutoff = datetime.utcnow() - timedelta(days=self.RETENTION_DAYS)
|
||||
aggregation_cutoff = datetime.utcnow() - timedelta(days=self.AGGREGATION_THRESHOLD_DAYS)
|
||||
|
||||
# Determine which logs to aggregate
|
||||
logs_to_keep = 4000
|
||||
logs_to_aggregate = total_count - logs_to_keep
|
||||
logs_to_aggregate_count = max(0, total_count - logs_to_keep)
|
||||
|
||||
# Get cutoff date (30 days ago)
|
||||
cutoff_date = datetime.utcnow() - timedelta(days=self.AGGREGATION_THRESHOLD_DAYS)
|
||||
if time_based:
|
||||
# Time-based: Aggregate all logs older than retention period
|
||||
# (excluding already aggregated logs)
|
||||
logs_to_process = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.timestamp < retention_cutoff,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate
|
||||
).order_by(APIUsageLog.timestamp.asc()).all()
|
||||
|
||||
logger.info(
|
||||
f"[LogWrapping] Time-based aggregation: Found {len(logs_to_process)} logs "
|
||||
f"older than {self.RETENTION_DAYS} days"
|
||||
)
|
||||
else:
|
||||
# Count-based: Aggregate oldest logs beyond the keep limit
|
||||
logs_to_process = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate
|
||||
).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate_count).all()
|
||||
|
||||
logger.info(
|
||||
f"[LogWrapping] Count-based aggregation: Processing {len(logs_to_process)} "
|
||||
f"oldest logs beyond {logs_to_keep} limit"
|
||||
)
|
||||
|
||||
# Get logs to aggregate: oldest logs beyond the keep limit
|
||||
# Order by timestamp ascending to get oldest first
|
||||
# We'll keep the most recent logs_to_keep logs, aggregate the rest
|
||||
logs_to_process = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id
|
||||
).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate).all()
|
||||
# Also check for time-based logs even if count-based is primary
|
||||
# This ensures we don't keep very old logs just because they're within the count limit
|
||||
if not time_based and logs_to_aggregate_count > 0:
|
||||
# Get logs that are both old AND beyond count limit
|
||||
old_logs_beyond_limit = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.timestamp < retention_cutoff,
|
||||
APIUsageLog.endpoint != '[AGGREGATED]'
|
||||
).order_by(APIUsageLog.timestamp.asc()).all()
|
||||
|
||||
# Merge with count-based logs, prioritizing old logs
|
||||
existing_ids = {log.id for log in logs_to_process}
|
||||
for old_log in old_logs_beyond_limit:
|
||||
if old_log.id not in existing_ids:
|
||||
logs_to_process.append(old_log)
|
||||
|
||||
logger.info(
|
||||
f"[LogWrapping] Combined aggregation: {len(logs_to_process)} logs to process "
|
||||
f"({logs_to_aggregate_count} count-based + {len(old_logs_beyond_limit)} time-based)"
|
||||
)
|
||||
|
||||
if not logs_to_process:
|
||||
return {
|
||||
@@ -218,10 +301,18 @@ class LogWrappingService:
|
||||
f"Remaining logs: {remaining_count}"
|
||||
)
|
||||
|
||||
# Count how many old logs were aggregated (for reporting)
|
||||
# Count logs that were aggregated based on time (not just count)
|
||||
old_logs_aggregated = 0
|
||||
for log in logs_to_process:
|
||||
if log.timestamp and log.timestamp < retention_cutoff:
|
||||
old_logs_aggregated += 1
|
||||
|
||||
return {
|
||||
'aggregated_count': aggregated_count,
|
||||
'logs_remaining': remaining_count,
|
||||
'periods': periods_created
|
||||
'periods': periods_created,
|
||||
'old_logs_aggregated': old_logs_aggregated
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -14,7 +14,7 @@ from collections import defaultdict, deque
|
||||
import asyncio
|
||||
from loguru import logger
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func
|
||||
from sqlalchemy import and_, func, case
|
||||
import re
|
||||
|
||||
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
|
||||
@@ -369,19 +369,64 @@ async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
|
||||
db.close()
|
||||
|
||||
async def get_lightweight_stats() -> Dict[str, Any]:
|
||||
"""Get lightweight stats for dashboard header."""
|
||||
"""Get lightweight stats for dashboard header.
|
||||
|
||||
Optimized single-query approach using conditional aggregation for better performance.
|
||||
"""
|
||||
db = None
|
||||
try:
|
||||
db = _get_db_session()
|
||||
# Minimal viable placeholder values
|
||||
now = datetime.utcnow()
|
||||
|
||||
# Get stats from last 5 minutes
|
||||
five_minutes_ago = now - timedelta(minutes=5)
|
||||
|
||||
# Optimized: Single query with conditional aggregation instead of two separate queries
|
||||
# This is much faster as it only scans the table once
|
||||
stats = db.query(
|
||||
func.count(APIRequest.id).label('total_requests'),
|
||||
func.sum(
|
||||
case((APIRequest.status_code >= 400, 1), else_=0)
|
||||
).label('total_errors')
|
||||
).filter(
|
||||
APIRequest.timestamp >= five_minutes_ago
|
||||
).first()
|
||||
|
||||
recent_requests = stats.total_requests or 0 if stats else 0
|
||||
recent_errors = int(stats.total_errors or 0) if stats else 0
|
||||
|
||||
# Calculate error rate
|
||||
error_rate = (recent_errors / recent_requests * 100) if recent_requests > 0 else 0.0
|
||||
|
||||
# Determine status based on error rate
|
||||
if error_rate > 10:
|
||||
status = 'critical'
|
||||
icon = '🔴'
|
||||
elif error_rate > 5:
|
||||
status = 'warning'
|
||||
icon = '🟡'
|
||||
else:
|
||||
status = 'healthy'
|
||||
icon = '🟢'
|
||||
|
||||
return {
|
||||
'status': status,
|
||||
'icon': icon,
|
||||
'recent_requests': recent_requests,
|
||||
'recent_errors': recent_errors,
|
||||
'error_rate': round(error_rate, 2),
|
||||
'timestamp': now.isoformat()
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting lightweight stats: {e}", exc_info=True)
|
||||
# Return default healthy state on error
|
||||
return {
|
||||
'status': 'healthy',
|
||||
'icon': '🟢',
|
||||
'recent_requests': 0,
|
||||
'recent_errors': 0,
|
||||
'error_rate': 0.0,
|
||||
'timestamp': now.isoformat()
|
||||
'timestamp': datetime.utcnow().isoformat()
|
||||
}
|
||||
finally:
|
||||
if db is not None:
|
||||
|
||||
@@ -290,6 +290,40 @@ class PricingService:
|
||||
"cost_per_image": 0.04, # $0.04 per image
|
||||
"description": "Stability AI Image Generation"
|
||||
},
|
||||
# WaveSpeed OSS Image Generation Models
|
||||
{
|
||||
"provider": APIProvider.STABILITY, # Using STABILITY provider for image generation
|
||||
"model_name": "qwen-image",
|
||||
"cost_per_image": 0.03, # $0.03 per image (OSS model via WaveSpeed)
|
||||
"cost_per_request": 0.03, # Also support cost_per_request
|
||||
"description": "WaveSpeed Qwen Image (OSS) - Fast generation, cost-effective"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.STABILITY,
|
||||
"model_name": "ideogram-v3-turbo",
|
||||
"cost_per_image": 0.05, # $0.05 per image (OSS model via WaveSpeed)
|
||||
"cost_per_request": 0.05, # Also support cost_per_request
|
||||
"description": "WaveSpeed Ideogram V3 Turbo (OSS) - Photorealistic, text rendering"
|
||||
},
|
||||
# WaveSpeed OSS Image Editing Models
|
||||
{
|
||||
"provider": APIProvider.IMAGE_EDIT,
|
||||
"model_name": "qwen-edit",
|
||||
"cost_per_request": 0.02, # $0.02 per edit (OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed Qwen Image Edit (OSS) - Budget editing, bilingual"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.IMAGE_EDIT,
|
||||
"model_name": "qwen-edit-plus",
|
||||
"cost_per_request": 0.02, # $0.02 per edit (OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed Qwen Image Edit Plus (OSS) - Multi-image editing"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.IMAGE_EDIT,
|
||||
"model_name": "flux-kontext-pro",
|
||||
"cost_per_request": 0.04, # $0.04 per edit (OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed FLUX Kontext Pro (OSS) - Professional editing, typography"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.EXA,
|
||||
"model_name": "exa-search",
|
||||
@@ -305,8 +339,8 @@ class PricingService:
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "default",
|
||||
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
|
||||
"description": "AI Video Generation default pricing"
|
||||
"cost_per_request": 0.25, # UPDATED: Default to WAN 2.5 OSS model ($0.25)
|
||||
"description": "AI Video Generation default pricing (OSS: WAN 2.5)"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
@@ -326,6 +360,25 @@ class PricingService:
|
||||
"cost_per_request": 0.30,
|
||||
"description": "WaveSpeed InfiniteTalk (image + audio to talking avatar video)"
|
||||
},
|
||||
# WaveSpeed OSS Video Generation Models
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "wan-2.5",
|
||||
"cost_per_request": 0.25, # $0.25 per video (~5 seconds, OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed WAN 2.5 (OSS) - Text-to-Video, Image-to-Video, cost-effective"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "alibaba/wan-2.5",
|
||||
"cost_per_request": 0.25, # $0.25 per video (~5 seconds, OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed WAN 2.5 (OSS) - Alternative path, same model"
|
||||
},
|
||||
{
|
||||
"provider": APIProvider.VIDEO,
|
||||
"model_name": "seedance-1.5-pro",
|
||||
"cost_per_request": 0.40, # $0.40 per video (~5 seconds, OSS model via WaveSpeed)
|
||||
"description": "WaveSpeed Seedance 1.5 Pro (OSS) - Longer duration videos (10-30 sec)"
|
||||
},
|
||||
# Audio Generation Pricing (Minimax Speech 02 HD via WaveSpeed)
|
||||
{
|
||||
"provider": APIProvider.AUDIO,
|
||||
@@ -404,7 +457,7 @@ class PricingService:
|
||||
"tier": SubscriptionTier.BASIC,
|
||||
"price_monthly": 29.0,
|
||||
"price_yearly": 290.0,
|
||||
"ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
|
||||
"ai_text_generation_calls_limit": 50, # INCREASED: Unified limit for all LLM providers (OSS-focused strategy)
|
||||
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
|
||||
"openai_calls_limit": 500,
|
||||
"anthropic_calls_limit": 200,
|
||||
@@ -413,18 +466,18 @@ class PricingService:
|
||||
"serper_calls_limit": 200,
|
||||
"metaphor_calls_limit": 100,
|
||||
"firecrawl_calls_limit": 100,
|
||||
"stability_calls_limit": 5,
|
||||
"stability_calls_limit": 50, # INCREASED: Now includes WaveSpeed OSS models (Qwen Image $0.03)
|
||||
"exa_calls_limit": 500,
|
||||
"video_calls_limit": 20, # 20 videos/month for basic plan
|
||||
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
|
||||
"audio_calls_limit": 50, # 50 AI audio generation calls/month
|
||||
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
|
||||
"monthly_cost_limit": 50.0,
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
|
||||
"description": "Great for individuals and small teams"
|
||||
"video_calls_limit": 30, # INCREASED: 30 videos/month (WAN 2.5 OSS $0.25)
|
||||
"image_edit_calls_limit": 50, # INCREASED: 50 AI image editing calls/month (Qwen Edit OSS $0.02)
|
||||
"audio_calls_limit": 100, # INCREASED: 100 AI audio generation calls/month (Minimax Speech OSS)
|
||||
"gemini_tokens_limit": 100000, # INCREASED: 100K tokens per provider (OSS-focused strategy)
|
||||
"openai_tokens_limit": 100000, # INCREASED: 100K tokens per provider
|
||||
"anthropic_tokens_limit": 100000, # INCREASED: 100K tokens per provider
|
||||
"mistral_tokens_limit": 100000, # INCREASED: 100K tokens per provider
|
||||
"monthly_cost_limit": 45.0, # ADJUSTED: $45 cap (aligns with $40-50 hard limit target)
|
||||
"features": ["full_content_generation", "advanced_research", "basic_analytics", "all_tools_access", "oss_models_priority"],
|
||||
"description": "Perfect for individuals and small teams. Access all ALwrity features with generous limits powered by OSS AI models."
|
||||
},
|
||||
{
|
||||
"name": "Pro",
|
||||
|
||||
156
backend/services/subscription/provider_detection.py
Normal file
156
backend/services/subscription/provider_detection.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""
|
||||
Provider Detection Utility
|
||||
Detects the actual provider (WaveSpeed, Google, HuggingFace, etc.) from model names and endpoints.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from models.subscription_models import APIProvider
|
||||
from loguru import logger
|
||||
|
||||
def _match_provider_keywords(text, rules):
    """Return the provider of the first rule whose keyword occurs in ``text``.

    Args:
        text: Model name or endpoint to inspect. ``None`` or an empty string
            never matches.
        rules: Ordered ``(provider, keywords)`` pairs. Rules are tried in
            order and the first one with any keyword contained in the
            lower-cased ``text`` wins.

    Returns:
        The matching provider name, or ``None`` when no rule matches.
    """
    if not text:
        return None
    lowered = text.lower()
    for provider, keywords in rules:
        if any(keyword in lowered for keyword in keywords):
            return provider
    return None


def detect_actual_provider(provider_enum: APIProvider, model_name: Optional[str] = None, endpoint: Optional[str] = None) -> str:
    """
    Detect the actual provider name from provider enum, model name, and endpoint.

    LLM and search enums map directly to a provider name. For the media enums
    (VIDEO, AUDIO, STABILITY, IMAGE_EDIT) the model name is inspected first,
    then the endpoint, and finally a per-category default is returned.
    Matching is case-insensitive substring matching.

    Args:
        provider_enum: The APIProvider enum value (e.g., APIProvider.VIDEO, APIProvider.GEMINI)
        model_name: The model name (e.g., "alibaba/wan-2.5/text-to-video", "gemini-2.5-flash")
        endpoint: The API endpoint (e.g., "/video-generation/wavespeed", "/image-generation/stability")

    Returns:
        Actual provider name: "wavespeed", "google", "huggingface", "stability", "openai", "anthropic", etc.
        For an enum value not handled here, a warning is logged and
        ``provider_enum.value`` is returned.
    """
    # LLM providers map one-to-one onto a provider name.
    if provider_enum == APIProvider.GEMINI:
        return "google"
    if provider_enum == APIProvider.OPENAI:
        return "openai"
    if provider_enum == APIProvider.ANTHROPIC:
        return "anthropic"
    if provider_enum == APIProvider.MISTRAL:
        # MISTRAL enum is used for HuggingFace models
        return "huggingface"

    # Search APIs: the enum value already is the provider name.
    if provider_enum in (
        APIProvider.TAVILY,
        APIProvider.SERPER,
        APIProvider.METAPHOR,
        APIProvider.FIRECRAWL,
        APIProvider.EXA,
    ):
        return provider_enum.value

    # NOTE(review): "hf" is a two-letter substring and may match unrelated
    # model names or endpoints - confirm this heuristic is acceptable.
    if provider_enum == APIProvider.VIDEO:
        model_rules = (
            ("wavespeed", ("wan-2.5", "seedance", "infinitetalk", "wavespeed", "alibaba")),
            ("huggingface", ("huggingface", "hf", "tencent", "hunyuan")),
            ("google", ("veo", "gemini")),  # future
            ("openai", ("sora", "openai")),  # future
        )
        endpoint_rules = (
            ("wavespeed", ("wavespeed",)),
            ("huggingface", ("huggingface", "hf")),
            ("google", ("google", "gemini")),
            ("openai", ("openai",)),
        )
        default_provider = "wavespeed"  # most common for video
    elif provider_enum == APIProvider.AUDIO:
        model_rules = (
            ("wavespeed", ("minimax", "speech-02", "wavespeed")),
            ("google", ("google", "gemini")),
            # Specific vendor markers must be tested before the generic "tts"
            # token: "tts-1" contains "tts", so checking "tts" first made the
            # OpenAI and ElevenLabs rules unreachable for such names.
            ("openai", ("openai", "tts-1")),
            ("elevenlabs", ("elevenlabs",)),  # future
            ("google", ("tts",)),
        )
        endpoint_rules = (
            ("wavespeed", ("wavespeed",)),
            ("google", ("google",)),
            ("openai", ("openai",)),
        )
        default_provider = "wavespeed"  # most common for audio
    elif provider_enum == APIProvider.STABILITY:
        model_rules = (
            ("wavespeed", ("qwen", "ideogram", "flux", "wavespeed")),
            ("stability", ("stability", "stable-diffusion", "sd-")),
            ("huggingface", ("huggingface", "hf", "runway")),
        )
        endpoint_rules = (
            ("wavespeed", ("wavespeed",)),
            ("stability", ("stability",)),
            ("huggingface", ("huggingface", "hf")),
        )
        # The OSS model names (qwen/ideogram/flux) are already covered by the
        # first model rule, so no second check is needed before the default.
        default_provider = "stability"  # legacy default for image generation
    elif provider_enum == APIProvider.IMAGE_EDIT:
        model_rules = (
            ("wavespeed", ("qwen", "flux", "kontext", "wavespeed")),
            ("stability", ("stability", "stable-diffusion")),
        )
        endpoint_rules = (
            ("wavespeed", ("wavespeed",)),
            ("stability", ("stability",)),
        )
        default_provider = "wavespeed"  # OSS-first strategy for image editing
    else:
        # Fallback: use enum value
        logger.warning(f"Could not detect actual provider for {provider_enum.value}, using enum value")
        return provider_enum.value

    # Model name takes precedence over endpoint; the default applies last.
    return (
        _match_provider_keywords(model_name, model_rules)
        or _match_provider_keywords(endpoint, endpoint_rules)
        or default_provider
    )
|
||||
264
backend/services/subscription/renewal_history_retention.py
Normal file
264
backend/services/subscription/renewal_history_retention.py
Normal file
@@ -0,0 +1,264 @@
|
||||
"""
|
||||
Renewal History Retention Service
|
||||
Manages retention policies for subscription renewal history records.
|
||||
|
||||
Retention Policy:
|
||||
- 0-12 months: Full records with usage snapshots
|
||||
- 12-24 months: Full records (compressed/removed usage snapshots)
|
||||
- 24-84 months: Summary records (no usage snapshots, payment data only)
|
||||
- 84+ months: Mark for archive (payment data preserved indefinitely)
|
||||
"""
|
||||
|
||||
from typing import Dict, Any, List, Optional
|
||||
from datetime import datetime, timedelta
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import func, desc
|
||||
from loguru import logger
|
||||
import json
|
||||
|
||||
from models.subscription_models import SubscriptionRenewalHistory
|
||||
|
||||
|
||||
class RenewalHistoryRetentionService:
|
||||
"""Service for managing renewal history retention policies."""
|
||||
|
||||
# Retention periods (in days)
|
||||
COMPRESS_SNAPSHOT_DAYS = 365 # 12 months - compress/remove usage snapshots
|
||||
SUMMARY_RECORDS_DAYS = 730 # 24 months - create summary records
|
||||
ARCHIVE_DAYS = 2555 # 84 months (7 years) - mark for archive
|
||||
|
||||
def __init__(self, db: Session):
    """Bind the service to a database session.

    Args:
        db: SQLAlchemy session used for the queries, commits and rollbacks
            issued by this service.
    """
    self.db = db
|
||||
|
||||
def check_and_apply_retention(self, user_id: str) -> Dict[str, Any]:
    """
    Apply the tiered retention policy to one user's renewal history.

    Records are bucketed by ``created_at`` relative to the current UTC time:
    1. Older than COMPRESS_SNAPSHOT_DAYS but not older than
       SUMMARY_RECORDS_DAYS, with a snapshot: snapshot is compressed
    2. Older than SUMMARY_RECORDS_DAYS but not older than ARCHIVE_DAYS,
       with a snapshot: snapshot is removed
    3. Older than ARCHIVE_DAYS: handed to the archive step

    Changes are committed only when at least one record was processed.
    Any exception rolls the session back and is reported in the result
    instead of being raised.

    Returns:
        Dict with a ``retention_applied`` flag, per-tier statistics and a
        human-readable ``message`` (plus ``error`` on failure)
    """
    try:
        history = SubscriptionRenewalHistory
        current_time = datetime.utcnow()
        compress_before = current_time - timedelta(days=self.COMPRESS_SNAPSHOT_DAYS)
        summarize_before = current_time - timedelta(days=self.SUMMARY_RECORDS_DAYS)
        archive_before = current_time - timedelta(days=self.ARCHIVE_DAYS)

        # Total number of renewal records owned by the user (reporting only).
        record_total = (
            self.db.query(func.count(history.id))
            .filter(history.user_id == user_id)
            .scalar()
        ) or 0

        # Tier 1: between the compress and summary cut-offs, snapshot present.
        compress_candidates = self.db.query(history).filter(
            history.user_id == user_id,
            history.created_at < compress_before,
            history.created_at >= summarize_before,
            history.usage_before_renewal.isnot(None),
        ).all()

        # Tier 2: between the summary and archive cut-offs, snapshot present.
        summary_candidates = self.db.query(history).filter(
            history.user_id == user_id,
            history.created_at < summarize_before,
            history.created_at >= archive_before,
            history.usage_before_renewal.isnot(None),
        ).all()

        # Tier 3: everything older than the archive cut-off.
        archive_candidates = self.db.query(history).filter(
            history.user_id == user_id,
            history.created_at < archive_before,
        ).all()

        # Apply each policy in order; every helper returns how many records it counted.
        compressed_count = self._compress_usage_snapshots(compress_candidates)
        summarized_count = self._create_summary_records(summary_candidates)
        archived_count = self._mark_for_archive(archive_candidates)
        total_processed = compressed_count + summarized_count + archived_count

        if total_processed:
            self.db.commit()

            logger.info(
                f"[RenewalRetention] Applied retention for user {user_id}: "
                f"{compressed_count} compressed, {summarized_count} summarized, "
                f"{archived_count} archived"
            )

            return {
                'retention_applied': True,
                'total_records': record_total,
                'compressed_count': compressed_count,
                'summarized_count': summarized_count,
                'archived_count': archived_count,
                'total_processed': total_processed,
                'message': f'Processed {total_processed} records: {compressed_count} compressed, {summarized_count} summarized, {archived_count} archived'
            }

        # Nothing was counted by any tier: report candidate sizes, skip the commit.
        return {
            'retention_applied': False,
            'total_records': record_total,
            'records_to_compress': len(compress_candidates),
            'records_to_summarize': len(summary_candidates),
            'records_to_archive': len(archive_candidates),
            'message': 'No records require retention processing'
        }

    except Exception as e:
        self.db.rollback()
        logger.error(f"[RenewalRetention] Error applying retention for user {user_id}: {e}", exc_info=True)
        return {
            'retention_applied': False,
            'error': str(e),
            'message': f'Error applying retention: {str(e)}'
        }
|
||||
|
||||
def _compress_usage_snapshots(self, records: List[SubscriptionRenewalHistory]) -> int:
    """
    Replace detailed usage snapshots with a three-metric summary.

    For every record with a truthy ``usage_before_renewal``:
    - a string snapshot is parsed as JSON; unparseable strings are cleared
    - a snapshot that is neither string nor dict is cleared
    - a dict that already has a ``compressed_at`` key is left untouched
    - otherwise the snapshot becomes a dict holding ``total_calls``,
      ``total_tokens``, ``total_cost``, ``compressed_at`` and ``note``
    If building the summary raises TypeError, AttributeError or KeyError,
    a warning is logged and the snapshot is cleared.

    Returns:
        Number of records whose snapshot was compressed or cleared
        (already-compressed and empty snapshots are not counted)
    """
    touched = 0
    for entry in records:
        snapshot = entry.usage_before_renewal
        if not snapshot:
            # Nothing stored (None or empty) - nothing to compress.
            continue

        try:
            if isinstance(snapshot, str):
                # Snapshot stored as raw text: decode it before summarizing.
                try:
                    snapshot = json.loads(snapshot)
                except json.JSONDecodeError:
                    # Invalid JSON cannot be summarized; drop it.
                    entry.usage_before_renewal = None
                    touched += 1
                    continue
            elif not isinstance(snapshot, dict):
                # Unsupported snapshot type; drop it.
                entry.usage_before_renewal = None
                touched += 1
                continue

            if isinstance(snapshot, dict) and 'compressed_at' in snapshot:
                # A previous run already compressed this record.
                continue

            # Keep only the headline metrics; missing keys default to zero.
            entry.usage_before_renewal = {
                'total_calls': snapshot.get('total_calls', 0),
                'total_tokens': snapshot.get('total_tokens', 0),
                'total_cost': snapshot.get('total_cost', 0.0),
                'compressed_at': datetime.utcnow().isoformat(),
                'note': 'Usage snapshot compressed after 12 months'
            }
            touched += 1

        except (TypeError, AttributeError, KeyError) as exc:
            logger.warning(f"[RenewalRetention] Failed to compress snapshot for record {entry.id}: {exc}")
            # Compression failed (e.g. decoded JSON was not a dict): remove the snapshot entirely.
            entry.usage_before_renewal = None
            touched += 1

    return touched
|
||||
|
||||
def _create_summary_records(self, records: List[SubscriptionRenewalHistory]) -> int:
    """
    Strip usage snapshots from the given records.

    Only ``usage_before_renewal`` is touched; every other attribute of the
    record is left as it is.

    Returns:
        Number of records that still had a snapshot and were cleared
    """
    with_snapshot = [
        entry for entry in records if entry.usage_before_renewal is not None
    ]
    for entry in with_snapshot:
        entry.usage_before_renewal = None
    return len(with_snapshot)
|
||||
|
||||
def _mark_for_archive(self, records: List[SubscriptionRenewalHistory]) -> int:
    """
    Process records that fall in the archive tier (older than 84 months).

    Any usage snapshot still present is cleared by setting
    ``usage_before_renewal`` to None; no other attribute is modified.
    Nothing is moved anywhere: a dedicated archive table is only a
    possible future step.

    Args:
        records: Renewal history records to process.

    Returns:
        Number of records visited. Every record counts, whether or not
        it still had a snapshot to clear.
    """
    processed = 0
    for entry in records:
        # Snapshots should already be gone at this age; clear leftovers.
        if entry.usage_before_renewal is not None:
            entry.usage_before_renewal = None
        processed += 1

    return processed
|
||||
|
||||
def get_retention_stats(self, user_id: str) -> Dict[str, Any]:
    """
    Report how a user's renewal history is spread across retention tiers.

    Tier boundaries are computed from the current UTC time and the
    class-level COMPRESS_SNAPSHOT_DAYS, SUMMARY_RECORDS_DAYS and
    ARCHIVE_DAYS settings. The compress and summarize tiers count only
    records that still hold a usage snapshot; the archive tier counts
    every record past the archive cutoff.

    Args:
        user_id: Owner of the renewal history records.

    Returns:
        Dict with per-tier counts and the active retention policy. If
        anything fails, a dict with ``error`` and ``total_records`` = 0
        is returned instead of raising.
    """
    try:
        reference_time = datetime.utcnow()
        compress_cutoff = reference_time - timedelta(days=self.COMPRESS_SNAPSHOT_DAYS)
        summary_cutoff = reference_time - timedelta(days=self.SUMMARY_RECORDS_DAYS)
        archive_cutoff = reference_time - timedelta(days=self.ARCHIVE_DAYS)

        history = SubscriptionRenewalHistory

        def count_where(*criteria) -> int:
            # Count this user's records matching the extra criteria.
            query = self.db.query(func.count(history.id)).filter(
                history.user_id == user_id,
                *criteria
            )
            return query.scalar() or 0

        total = count_where()

        # 0-12 months
        recent = count_where(history.created_at >= compress_cutoff)

        # 12-24 months, snapshot still present
        to_compress = count_where(
            history.created_at < compress_cutoff,
            history.created_at >= summary_cutoff,
            history.usage_before_renewal.isnot(None)
        )

        # 24-84 months, snapshot still present
        to_summarize = count_where(
            history.created_at < summary_cutoff,
            history.created_at >= archive_cutoff,
            history.usage_before_renewal.isnot(None)
        )

        # 84+ months
        to_archive = count_where(history.created_at < archive_cutoff)

        retention_policy = {
            'compress_after_days': self.COMPRESS_SNAPSHOT_DAYS,
            'summarize_after_days': self.SUMMARY_RECORDS_DAYS,
            'archive_after_days': self.ARCHIVE_DAYS
        }
        return {
            'total_records': total,
            'recent_records': recent,
            'records_to_compress': to_compress,
            'records_to_summarize': to_summarize,
            'records_to_archive': to_archive,
            'retention_policy': retention_policy
        }

    except Exception as e:
        logger.error(f"[RenewalRetention] Error getting retention stats for user {user_id}: {e}", exc_info=True)
        return {
            'error': str(e),
            'total_records': 0
        }
|
||||
@@ -6,6 +6,7 @@ from loguru import logger
|
||||
|
||||
_checked_subscription_plan_columns: bool = False
|
||||
_checked_usage_summaries_columns: bool = False
|
||||
_checked_api_usage_logs_columns: bool = False
|
||||
|
||||
|
||||
def ensure_subscription_plan_columns(db: Session) -> None:
|
||||
@@ -114,9 +115,58 @@ def ensure_usage_summaries_columns(db: Session) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def ensure_api_usage_logs_columns(db: Session) -> None:
    """Add any missing required columns to the api_usage_logs table.

    Defensive guard for databases where migrations have not been applied:
    columns the ORM models reference (e.g., actual_provider_name) are
    added with a nullable definition so queries do not fail.

    The check runs at most once per process after it succeeds; a
    module-level flag short-circuits later calls. On any failure the
    session is rolled back, the flag stays unset so the next call
    retries, and the exception is re-raised.

    NOTE(review): column discovery relies on ``PRAGMA table_info``,
    which is SQLite syntax - confirm other backends are not expected.
    """
    global _checked_api_usage_logs_columns
    if _checked_api_usage_logs_columns:
        return

    try:
        # Index 1 of each PRAGMA table_info row is the column name.
        pragma_rows = db.execute(text("PRAGMA table_info(api_usage_logs)"))
        cols: Set[str] = set()
        for row in pragma_rows:
            cols.add(row[1])

        logger.debug(f"Schema check: Found {len(cols)} columns in api_usage_logs table")

        # Column name -> DDL fragment for columns older DBs may lack.
        required_columns = {
            "actual_provider_name": "VARCHAR(50) NULL",
        }

        for col_name in required_columns:
            if col_name in cols:
                logger.debug(f"Column {col_name} already exists")
                continue

            ddl = required_columns[col_name]
            logger.info(f"Adding missing column {col_name} to api_usage_logs table")
            try:
                db.execute(text(f"ALTER TABLE api_usage_logs ADD COLUMN {col_name} {ddl}"))
                db.commit()
                logger.info(f"Successfully added column {col_name}")
            except Exception as alter_err:
                logger.error(f"Failed to add column {col_name}: {alter_err}")
                db.rollback()
                # Propagate so the outer handler keeps the flag unset.
                raise

        # Reached only when every required column was verified or added.
        _checked_api_usage_logs_columns = True
    except Exception as e:
        logger.error(f"Error ensuring api_usage_logs columns: {e}", exc_info=True)
        db.rollback()
        # Leave the flag cleared so the check is retried on the next call.
        _checked_api_usage_logs_columns = False
        raise
|
||||
|
||||
|
||||
def ensure_all_schema_columns(db: Session) -> None:
    """Run every column check for the subscription-related tables.

    Checks run in a fixed order: subscription plans, usage summaries,
    then API usage logs. An exception from one check propagates and
    the remaining checks are not run.
    """
    column_checks = (
        ensure_subscription_plan_columns,
        ensure_usage_summaries_columns,
        ensure_api_usage_logs_columns,
    )
    for run_check in column_checks:
        run_check(db)
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from models.subscription_models import (
|
||||
UserSubscription, UsageStatus
|
||||
)
|
||||
from .pricing_service import PricingService
|
||||
from .provider_detection import detect_actual_provider
|
||||
|
||||
class UsageTrackingService:
|
||||
"""Service for tracking API usage and managing subscription limits."""
|
||||
@@ -67,12 +68,21 @@ class UsageTrackingService:
|
||||
|
||||
# Create usage log entry
|
||||
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||
|
||||
# Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.)
|
||||
actual_provider_name = detect_actual_provider(
|
||||
provider_enum=provider,
|
||||
model_name=model_used,
|
||||
endpoint=endpoint
|
||||
)
|
||||
|
||||
usage_log = APIUsageLog(
|
||||
user_id=user_id,
|
||||
provider=provider,
|
||||
endpoint=endpoint,
|
||||
method=method,
|
||||
model_used=model_used,
|
||||
actual_provider_name=actual_provider_name, # Track actual provider
|
||||
tokens_input=tokens_input,
|
||||
tokens_output=tokens_output,
|
||||
tokens_total=(tokens_input or 0) + (tokens_output or 0),
|
||||
@@ -404,18 +414,128 @@ class UsageTrackingService:
|
||||
'cost': mistral_cost
|
||||
}
|
||||
|
||||
# Add other providers (Video, Audio, Image, Image Edit) for comprehensive breakdown
|
||||
# Video (WaveSpeed, HuggingFace, etc.)
|
||||
video_calls = getattr(summary, "video_calls", 0) or 0
|
||||
video_cost = getattr(summary, "video_cost", 0.0) or 0.0
|
||||
if video_calls > 0 and video_cost == 0.0:
|
||||
video_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.VIDEO,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if video_logs:
|
||||
video_cost = sum(float(log.cost_total or 0.0) for log in video_logs)
|
||||
|
||||
provider_breakdown['video'] = {
|
||||
'calls': video_calls,
|
||||
'tokens': 0,
|
||||
'cost': video_cost
|
||||
}
|
||||
|
||||
# Audio (WaveSpeed, etc.)
|
||||
audio_calls = getattr(summary, "audio_calls", 0) or 0
|
||||
audio_cost = getattr(summary, "audio_cost", 0.0) or 0.0
|
||||
if audio_calls > 0 and audio_cost == 0.0:
|
||||
audio_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.AUDIO,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if audio_logs:
|
||||
audio_cost = sum(float(log.cost_total or 0.0) for log in audio_logs)
|
||||
|
||||
provider_breakdown['audio'] = {
|
||||
'calls': audio_calls,
|
||||
'tokens': 0,
|
||||
'cost': audio_cost
|
||||
}
|
||||
|
||||
# Image Generation (Stability/WaveSpeed)
|
||||
stability_calls = getattr(summary, "stability_calls", 0) or 0
|
||||
stability_cost = getattr(summary, "stability_cost", 0.0) or 0.0
|
||||
if stability_calls > 0 and stability_cost == 0.0:
|
||||
stability_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.STABILITY,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if stability_logs:
|
||||
stability_cost = sum(float(log.cost_total or 0.0) for log in stability_logs)
|
||||
|
||||
provider_breakdown['image'] = {
|
||||
'calls': stability_calls,
|
||||
'tokens': 0,
|
||||
'cost': stability_cost
|
||||
}
|
||||
|
||||
# Image Editing (WaveSpeed)
|
||||
image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
|
||||
image_edit_cost = getattr(summary, "image_edit_cost", 0.0) or 0.0
|
||||
if image_edit_calls > 0 and image_edit_cost == 0.0:
|
||||
image_edit_logs = self.db.query(APIUsageLog).filter(
|
||||
APIUsageLog.user_id == user_id,
|
||||
APIUsageLog.provider == APIProvider.IMAGE_EDIT,
|
||||
APIUsageLog.billing_period == billing_period
|
||||
).all()
|
||||
if image_edit_logs:
|
||||
image_edit_cost = sum(float(log.cost_total or 0.0) for log in image_edit_logs)
|
||||
|
||||
provider_breakdown['image_edit'] = {
|
||||
'calls': image_edit_calls,
|
||||
'tokens': 0,
|
||||
'cost': image_edit_cost
|
||||
}
|
||||
|
||||
# Search APIs
|
||||
tavily_calls = getattr(summary, "tavily_calls", 0) or 0
|
||||
tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0
|
||||
provider_breakdown['tavily'] = {
|
||||
'calls': tavily_calls,
|
||||
'tokens': 0,
|
||||
'cost': tavily_cost
|
||||
}
|
||||
|
||||
serper_calls = getattr(summary, "serper_calls", 0) or 0
|
||||
serper_cost = getattr(summary, "serper_cost", 0.0) or 0.0
|
||||
provider_breakdown['serper'] = {
|
||||
'calls': serper_calls,
|
||||
'tokens': 0,
|
||||
'cost': serper_cost
|
||||
}
|
||||
|
||||
exa_calls = getattr(summary, "exa_calls", 0) or 0
|
||||
exa_cost = getattr(summary, "exa_cost", 0.0) or 0.0
|
||||
provider_breakdown['exa'] = {
|
||||
'calls': exa_calls,
|
||||
'tokens': 0,
|
||||
'cost': exa_cost
|
||||
}
|
||||
|
||||
# Calculate total cost from provider breakdown if summary total_cost is 0
|
||||
calculated_total_cost = gemini_cost + mistral_cost
|
||||
calculated_total_cost = (
|
||||
gemini_cost + mistral_cost + video_cost + audio_cost +
|
||||
stability_cost + image_edit_cost + tavily_cost + serper_cost + exa_cost
|
||||
)
|
||||
summary_total_cost = summary.total_cost or 0.0
|
||||
# Use calculated cost if summary cost is 0, otherwise use summary cost (it's more accurate)
|
||||
final_total_cost = summary_total_cost if summary_total_cost > 0 else calculated_total_cost
|
||||
|
||||
# If we calculated costs from logs, update the summary for future requests
|
||||
if calculated_total_cost > 0 and summary_total_cost == 0.0:
|
||||
logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}")
|
||||
logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}, video_cost={video_cost:.6f}, audio_cost={audio_cost:.6f}, image_cost={stability_cost:.6f}")
|
||||
summary.total_cost = final_total_cost
|
||||
summary.gemini_cost = gemini_cost
|
||||
summary.mistral_cost = mistral_cost
|
||||
# Update other provider costs if they exist
|
||||
if hasattr(summary, 'video_cost'):
|
||||
summary.video_cost = video_cost
|
||||
if hasattr(summary, 'audio_cost'):
|
||||
summary.audio_cost = audio_cost
|
||||
if hasattr(summary, 'stability_cost'):
|
||||
summary.stability_cost = stability_cost
|
||||
if hasattr(summary, 'image_edit_cost'):
|
||||
summary.image_edit_cost = image_edit_cost
|
||||
try:
|
||||
self.db.commit()
|
||||
except Exception as e:
|
||||
|
||||
Reference in New Issue
Block a user