AI Analysis and Content Strategy fixes. Enhanced Strategy Routes refactoring.

This commit is contained in:
ajaysi
2026-01-10 19:32:50 +05:30
parent 0b63ae7fc1
commit 8193cdba67
298 changed files with 45678 additions and 10952 deletions

View File

@@ -182,4 +182,4 @@ This package consolidates the following previously scattered files:
- `services.onboarding` - Onboarding and user setup
- `models.subscription_models` - Database models
- `api.subscription_api` - API endpoints
- `api.subscription` - API endpoints (modular structure with routes in `api/subscription/routes/`)

View File

@@ -1,7 +1,13 @@
"""
Log Wrapping Service
Intelligently wraps API usage logs when they exceed 5000 records.
Intelligently wraps API usage logs when they exceed limits (count or time-based).
Aggregates old logs into cumulative records while preserving historical data.
Features:
- Count-based retention: Keeps 4,000 most recent detailed logs
- Time-based retention: Aggregates logs older than 90 days
- Automatic aggregation: Triggered on log queries
- Context preservation: Maintains costs, tokens, counts, success rates
"""
from typing import Dict, Any, List, Optional
@@ -18,13 +24,18 @@ class LogWrappingService:
MAX_LOGS_PER_USER = 5000
AGGREGATION_THRESHOLD_DAYS = 30 # Aggregate logs older than 30 days
RETENTION_DAYS = 90 # Time-based retention: aggregate logs older than 90 days
def __init__(self, db: Session):
self.db = db
def check_and_wrap_logs(self, user_id: str) -> Dict[str, Any]:
"""
Check if user has exceeded log limit and wrap if necessary.
Check if user has exceeded log limit (count or time-based) and wrap if necessary.
Checks both:
1. Count-based: If user has more than MAX_LOGS_PER_USER logs
2. Time-based: If user has logs older than RETENTION_DAYS
Returns:
Dict with wrapping status and statistics
@@ -35,18 +46,42 @@ class LogWrappingService:
APIUsageLog.user_id == user_id
).scalar() or 0
if total_count <= self.MAX_LOGS_PER_USER:
# Check for logs older than retention period
retention_cutoff = datetime.utcnow() - timedelta(days=self.RETENTION_DAYS)
old_logs_count = self.db.query(func.count(APIUsageLog.id)).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.timestamp < retention_cutoff,
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate already aggregated logs
).scalar() or 0
# Determine if wrapping is needed
count_based_trigger = total_count > self.MAX_LOGS_PER_USER
time_based_trigger = old_logs_count > 0
if not count_based_trigger and not time_based_trigger:
return {
'wrapped': False,
'total_logs': total_count,
'old_logs': old_logs_count,
'max_logs': self.MAX_LOGS_PER_USER,
'message': f'Log count ({total_count}) is within limit ({self.MAX_LOGS_PER_USER})'
'retention_days': self.RETENTION_DAYS,
'message': f'Log count ({total_count}) and age are within limits'
}
# Need to wrap logs - aggregate old logs
logger.info(f"[LogWrapping] User {user_id} has {total_count} logs, exceeding limit of {self.MAX_LOGS_PER_USER}. Starting wrap...")
# Determine trigger reason
trigger_reasons = []
if count_based_trigger:
trigger_reasons.append(f'count limit ({total_count} > {self.MAX_LOGS_PER_USER})')
if time_based_trigger:
trigger_reasons.append(f'time-based retention ({old_logs_count} logs older than {self.RETENTION_DAYS} days)')
wrap_result = self._wrap_old_logs(user_id, total_count)
logger.info(
f"[LogWrapping] User {user_id} needs log wrapping. "
f"Total: {total_count}, Old logs: {old_logs_count}. "
f"Triggers: {', '.join(trigger_reasons)}"
)
wrap_result = self._wrap_old_logs(user_id, total_count, time_based=time_based_trigger)
return {
'wrapped': True,
@@ -54,6 +89,8 @@ class LogWrappingService:
'total_logs_after': wrap_result['logs_remaining'],
'aggregated_logs': wrap_result['aggregated_count'],
'aggregated_periods': wrap_result['periods'],
'trigger_reasons': trigger_reasons,
'old_logs_aggregated': wrap_result.get('old_logs_aggregated', 0),
'message': f'Wrapped {wrap_result["aggregated_count"]} logs into {len(wrap_result["periods"])} aggregated records'
}
@@ -65,30 +102,76 @@ class LogWrappingService:
'message': f'Error wrapping logs: {str(e)}'
}
def _wrap_old_logs(self, user_id: str, total_count: int) -> Dict[str, Any]:
def _wrap_old_logs(self, user_id: str, total_count: int, time_based: bool = False) -> Dict[str, Any]:
"""
Aggregate old logs into cumulative records.
Strategy:
1. Keep most recent 4000 logs (detailed)
2. Aggregate logs older than 30 days or oldest logs beyond 4000
3. Create aggregated records grouped by provider and billing period
4. Delete individual logs that were aggregated
1. Keep most recent 4000 logs (detailed) - count-based
2. Aggregate logs older than RETENTION_DAYS - time-based
3. Aggregate oldest logs beyond 4000 limit - count-based
4. Create aggregated records grouped by provider and billing period
5. Delete individual logs that were aggregated
Args:
user_id: User ID
total_count: Total number of logs for user
time_based: If True, prioritize time-based retention over count-based
"""
try:
# Calculate how many logs to keep (4000 detailed, rest aggregated)
# Calculate retention cutoff date
retention_cutoff = datetime.utcnow() - timedelta(days=self.RETENTION_DAYS)
aggregation_cutoff = datetime.utcnow() - timedelta(days=self.AGGREGATION_THRESHOLD_DAYS)
# Determine which logs to aggregate
logs_to_keep = 4000
logs_to_aggregate = total_count - logs_to_keep
logs_to_aggregate_count = max(0, total_count - logs_to_keep)
# Get cutoff date (30 days ago)
cutoff_date = datetime.utcnow() - timedelta(days=self.AGGREGATION_THRESHOLD_DAYS)
if time_based:
# Time-based: Aggregate all logs older than retention period
# (excluding already aggregated logs)
logs_to_process = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.timestamp < retention_cutoff,
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate
).order_by(APIUsageLog.timestamp.asc()).all()
logger.info(
f"[LogWrapping] Time-based aggregation: Found {len(logs_to_process)} logs "
f"older than {self.RETENTION_DAYS} days"
)
else:
# Count-based: Aggregate oldest logs beyond the keep limit
logs_to_process = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.endpoint != '[AGGREGATED]' # Don't re-aggregate
).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate_count).all()
logger.info(
f"[LogWrapping] Count-based aggregation: Processing {len(logs_to_process)} "
f"oldest logs beyond {logs_to_keep} limit"
)
# Get logs to aggregate: oldest logs beyond the keep limit
# Order by timestamp ascending to get oldest first
# We'll keep the most recent logs_to_keep logs, aggregate the rest
logs_to_process = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id
).order_by(APIUsageLog.timestamp.asc()).limit(logs_to_aggregate).all()
# Also check for time-based logs even if count-based is primary
# This ensures we don't keep very old logs just because they're within the count limit
if not time_based and logs_to_aggregate_count > 0:
# Get logs that are both old AND beyond count limit
old_logs_beyond_limit = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.timestamp < retention_cutoff,
APIUsageLog.endpoint != '[AGGREGATED]'
).order_by(APIUsageLog.timestamp.asc()).all()
# Merge with count-based logs, prioritizing old logs
existing_ids = {log.id for log in logs_to_process}
for old_log in old_logs_beyond_limit:
if old_log.id not in existing_ids:
logs_to_process.append(old_log)
logger.info(
f"[LogWrapping] Combined aggregation: {len(logs_to_process)} logs to process "
f"({logs_to_aggregate_count} count-based + {len(old_logs_beyond_limit)} time-based)"
)
if not logs_to_process:
return {
@@ -218,10 +301,18 @@ class LogWrappingService:
f"Remaining logs: {remaining_count}"
)
# Count how many old logs were aggregated (for reporting)
# Count logs that were aggregated based on time (not just count)
old_logs_aggregated = 0
for log in logs_to_process:
if log.timestamp and log.timestamp < retention_cutoff:
old_logs_aggregated += 1
return {
'aggregated_count': aggregated_count,
'logs_remaining': remaining_count,
'periods': periods_created
'periods': periods_created,
'old_logs_aggregated': old_logs_aggregated
}
except Exception as e:

View File

@@ -14,7 +14,7 @@ from collections import defaultdict, deque
import asyncio
from loguru import logger
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
from sqlalchemy import and_, func, case
import re
from models.api_monitoring import APIRequest, APIEndpointStats, SystemHealth, CachePerformance
@@ -369,19 +369,64 @@ async def get_monitoring_stats(minutes: int = 5) -> Dict[str, Any]:
db.close()
async def get_lightweight_stats() -> Dict[str, Any]:
"""Get lightweight stats for dashboard header."""
"""Get lightweight stats for dashboard header.
Optimized single-query approach using conditional aggregation for better performance.
"""
db = None
try:
db = _get_db_session()
# Minimal viable placeholder values
now = datetime.utcnow()
# Get stats from last 5 minutes
five_minutes_ago = now - timedelta(minutes=5)
# Optimized: Single query with conditional aggregation instead of two separate queries
# This is much faster as it only scans the table once
stats = db.query(
func.count(APIRequest.id).label('total_requests'),
func.sum(
case((APIRequest.status_code >= 400, 1), else_=0)
).label('total_errors')
).filter(
APIRequest.timestamp >= five_minutes_ago
).first()
recent_requests = stats.total_requests or 0 if stats else 0
recent_errors = int(stats.total_errors or 0) if stats else 0
# Calculate error rate
error_rate = (recent_errors / recent_requests * 100) if recent_requests > 0 else 0.0
# Determine status based on error rate
if error_rate > 10:
status = 'critical'
icon = '🔴'
elif error_rate > 5:
status = 'warning'
icon = '🟡'
else:
status = 'healthy'
icon = '🟢'
return {
'status': status,
'icon': icon,
'recent_requests': recent_requests,
'recent_errors': recent_errors,
'error_rate': round(error_rate, 2),
'timestamp': now.isoformat()
}
except Exception as e:
logger.error(f"Error getting lightweight stats: {e}", exc_info=True)
# Return default healthy state on error
return {
'status': 'healthy',
'icon': '🟢',
'recent_requests': 0,
'recent_errors': 0,
'error_rate': 0.0,
'timestamp': now.isoformat()
'timestamp': datetime.utcnow().isoformat()
}
finally:
if db is not None:

View File

@@ -290,6 +290,40 @@ class PricingService:
"cost_per_image": 0.04, # $0.04 per image
"description": "Stability AI Image Generation"
},
# WaveSpeed OSS Image Generation Models
{
"provider": APIProvider.STABILITY, # Using STABILITY provider for image generation
"model_name": "qwen-image",
"cost_per_image": 0.03, # $0.03 per image (OSS model via WaveSpeed)
"cost_per_request": 0.03, # Also support cost_per_request
"description": "WaveSpeed Qwen Image (OSS) - Fast generation, cost-effective"
},
{
"provider": APIProvider.STABILITY,
"model_name": "ideogram-v3-turbo",
"cost_per_image": 0.05, # $0.05 per image (OSS model via WaveSpeed)
"cost_per_request": 0.05, # Also support cost_per_request
"description": "WaveSpeed Ideogram V3 Turbo (OSS) - Photorealistic, text rendering"
},
# WaveSpeed OSS Image Editing Models
{
"provider": APIProvider.IMAGE_EDIT,
"model_name": "qwen-edit",
"cost_per_request": 0.02, # $0.02 per edit (OSS model via WaveSpeed)
"description": "WaveSpeed Qwen Image Edit (OSS) - Budget editing, bilingual"
},
{
"provider": APIProvider.IMAGE_EDIT,
"model_name": "qwen-edit-plus",
"cost_per_request": 0.02, # $0.02 per edit (OSS model via WaveSpeed)
"description": "WaveSpeed Qwen Image Edit Plus (OSS) - Multi-image editing"
},
{
"provider": APIProvider.IMAGE_EDIT,
"model_name": "flux-kontext-pro",
"cost_per_request": 0.04, # $0.04 per edit (OSS model via WaveSpeed)
"description": "WaveSpeed FLUX Kontext Pro (OSS) - Professional editing, typography"
},
{
"provider": APIProvider.EXA,
"model_name": "exa-search",
@@ -305,8 +339,8 @@ class PricingService:
{
"provider": APIProvider.VIDEO,
"model_name": "default",
"cost_per_request": 0.10, # $0.10 per video generation (estimated)
"description": "AI Video Generation default pricing"
"cost_per_request": 0.25, # UPDATED: Default to WAN 2.5 OSS model ($0.25)
"description": "AI Video Generation default pricing (OSS: WAN 2.5)"
},
{
"provider": APIProvider.VIDEO,
@@ -326,6 +360,25 @@ class PricingService:
"cost_per_request": 0.30,
"description": "WaveSpeed InfiniteTalk (image + audio to talking avatar video)"
},
# WaveSpeed OSS Video Generation Models
{
"provider": APIProvider.VIDEO,
"model_name": "wan-2.5",
"cost_per_request": 0.25, # $0.25 per video (~5 seconds, OSS model via WaveSpeed)
"description": "WaveSpeed WAN 2.5 (OSS) - Text-to-Video, Image-to-Video, cost-effective"
},
{
"provider": APIProvider.VIDEO,
"model_name": "alibaba/wan-2.5",
"cost_per_request": 0.25, # $0.25 per video (~5 seconds, OSS model via WaveSpeed)
"description": "WaveSpeed WAN 2.5 (OSS) - Alternative path, same model"
},
{
"provider": APIProvider.VIDEO,
"model_name": "seedance-1.5-pro",
"cost_per_request": 0.40, # $0.40 per video (~5 seconds, OSS model via WaveSpeed)
"description": "WaveSpeed Seedance 1.5 Pro (OSS) - Longer duration videos (10-30 sec)"
},
# Audio Generation Pricing (Minimax Speech 02 HD via WaveSpeed)
{
"provider": APIProvider.AUDIO,
@@ -404,7 +457,7 @@ class PricingService:
"tier": SubscriptionTier.BASIC,
"price_monthly": 29.0,
"price_yearly": 290.0,
"ai_text_generation_calls_limit": 10, # Unified limit for all LLM providers
"ai_text_generation_calls_limit": 50, # INCREASED: Unified limit for all LLM providers (OSS-focused strategy)
"gemini_calls_limit": 1000, # Legacy, kept for backwards compatibility (not used for enforcement)
"openai_calls_limit": 500,
"anthropic_calls_limit": 200,
@@ -413,18 +466,18 @@ class PricingService:
"serper_calls_limit": 200,
"metaphor_calls_limit": 100,
"firecrawl_calls_limit": 100,
"stability_calls_limit": 5,
"stability_calls_limit": 50, # INCREASED: Now includes WaveSpeed OSS models (Qwen Image $0.03)
"exa_calls_limit": 500,
"video_calls_limit": 20, # 20 videos/month for basic plan
"image_edit_calls_limit": 30, # 30 AI image editing calls/month
"audio_calls_limit": 50, # 50 AI audio generation calls/month
"gemini_tokens_limit": 20000, # Increased from 5000 for better stability
"openai_tokens_limit": 20000, # Increased from 5000 for better stability
"anthropic_tokens_limit": 20000, # Increased from 5000 for better stability
"mistral_tokens_limit": 20000, # Increased from 5000 for better stability
"monthly_cost_limit": 50.0,
"features": ["full_content_generation", "advanced_research", "basic_analytics"],
"description": "Great for individuals and small teams"
"video_calls_limit": 30, # INCREASED: 30 videos/month (WAN 2.5 OSS $0.25)
"image_edit_calls_limit": 50, # INCREASED: 50 AI image editing calls/month (Qwen Edit OSS $0.02)
"audio_calls_limit": 100, # INCREASED: 100 AI audio generation calls/month (Minimax Speech OSS)
"gemini_tokens_limit": 100000, # INCREASED: 100K tokens per provider (OSS-focused strategy)
"openai_tokens_limit": 100000, # INCREASED: 100K tokens per provider
"anthropic_tokens_limit": 100000, # INCREASED: 100K tokens per provider
"mistral_tokens_limit": 100000, # INCREASED: 100K tokens per provider
"monthly_cost_limit": 45.0, # ADJUSTED: $45 cap (aligns with $40-50 hard limit target)
"features": ["full_content_generation", "advanced_research", "basic_analytics", "all_tools_access", "oss_models_priority"],
"description": "Perfect for individuals and small teams. Access all ALwrity features with generous limits powered by OSS AI models."
},
{
"name": "Pro",

View File

@@ -0,0 +1,156 @@
"""
Provider Detection Utility
Detects the actual provider (WaveSpeed, Google, HuggingFace, etc.) from model names and endpoints.
"""
from typing import Optional
from models.subscription_models import APIProvider
from loguru import logger
def detect_actual_provider(provider_enum: APIProvider, model_name: Optional[str] = None, endpoint: Optional[str] = None) -> str:
"""
Detect the actual provider name from provider enum, model name, and endpoint.
Args:
provider_enum: The APIProvider enum value (e.g., APIProvider.VIDEO, APIProvider.GEMINI)
model_name: The model name (e.g., "alibaba/wan-2.5/text-to-video", "gemini-2.5-flash")
endpoint: The API endpoint (e.g., "/video-generation/wavespeed", "/image-generation/stability")
Returns:
Actual provider name: "wavespeed", "google", "huggingface", "stability", "openai", "anthropic", etc.
"""
# For LLM providers, use the enum value directly
if provider_enum in [APIProvider.GEMINI]:
return "google"
elif provider_enum == APIProvider.OPENAI:
return "openai"
elif provider_enum == APIProvider.ANTHROPIC:
return "anthropic"
elif provider_enum == APIProvider.MISTRAL:
# MISTRAL enum is used for HuggingFace models
return "huggingface"
# For search APIs, use the enum value
elif provider_enum in [APIProvider.TAVILY, APIProvider.SERPER, APIProvider.METAPHOR, APIProvider.FIRECRAWL, APIProvider.EXA]:
return provider_enum.value
# For media generation, detect from model name or endpoint
elif provider_enum == APIProvider.VIDEO:
# Check model name first
if model_name:
model_lower = model_name.lower()
# WaveSpeed models
if any(x in model_lower for x in ["wan-2.5", "seedance", "infinitetalk", "wavespeed", "alibaba"]):
return "wavespeed"
# HuggingFace models
elif any(x in model_lower for x in ["huggingface", "hf", "tencent", "hunyuan"]):
return "huggingface"
# Google models (future)
elif any(x in model_lower for x in ["veo", "gemini"]):
return "google"
# OpenAI models (future)
elif any(x in model_lower for x in ["sora", "openai"]):
return "openai"
# Check endpoint
if endpoint:
endpoint_lower = endpoint.lower()
if "wavespeed" in endpoint_lower:
return "wavespeed"
elif "huggingface" in endpoint_lower or "hf" in endpoint_lower:
return "huggingface"
elif "google" in endpoint_lower or "gemini" in endpoint_lower:
return "google"
elif "openai" in endpoint_lower:
return "openai"
# Default for video: WaveSpeed (most common)
return "wavespeed"
elif provider_enum == APIProvider.AUDIO:
# Check model name first
if model_name:
model_lower = model_name.lower()
# WaveSpeed models
if any(x in model_lower for x in ["minimax", "speech-02", "wavespeed"]):
return "wavespeed"
# Google models
elif any(x in model_lower for x in ["google", "gemini", "tts"]):
return "google"
# OpenAI models
elif any(x in model_lower for x in ["openai", "tts-1"]):
return "openai"
# ElevenLabs (future)
elif "elevenlabs" in model_lower:
return "elevenlabs"
# Check endpoint
if endpoint:
endpoint_lower = endpoint.lower()
if "wavespeed" in endpoint_lower:
return "wavespeed"
elif "google" in endpoint_lower:
return "google"
elif "openai" in endpoint_lower:
return "openai"
# Default for audio: WaveSpeed (most common)
return "wavespeed"
elif provider_enum == APIProvider.STABILITY:
# Check model name first
if model_name:
model_lower = model_name.lower()
# WaveSpeed OSS models
if any(x in model_lower for x in ["qwen", "ideogram", "flux", "wavespeed"]):
return "wavespeed"
# Stability AI models
elif any(x in model_lower for x in ["stability", "stable-diffusion", "sd-"]):
return "stability"
# HuggingFace models
elif any(x in model_lower for x in ["huggingface", "hf", "runway"]):
return "huggingface"
# Check endpoint
if endpoint:
endpoint_lower = endpoint.lower()
if "wavespeed" in endpoint_lower:
return "wavespeed"
elif "stability" in endpoint_lower:
return "stability"
elif "huggingface" in endpoint_lower or "hf" in endpoint_lower:
return "huggingface"
# Default: check if it's actually WaveSpeed based on common OSS models
if model_name and any(x in model_name.lower() for x in ["qwen", "ideogram", "flux"]):
return "wavespeed"
# Default for image generation: Stability (legacy)
return "stability"
elif provider_enum == APIProvider.IMAGE_EDIT:
# Check model name first
if model_name:
model_lower = model_name.lower()
# WaveSpeed OSS models
if any(x in model_lower for x in ["qwen", "flux", "kontext", "wavespeed"]):
return "wavespeed"
# Stability AI models
elif any(x in model_lower for x in ["stability", "stable-diffusion"]):
return "stability"
# Check endpoint
if endpoint:
endpoint_lower = endpoint.lower()
if "wavespeed" in endpoint_lower:
return "wavespeed"
elif "stability" in endpoint_lower:
return "stability"
# Default for image editing: WaveSpeed (OSS-first strategy)
return "wavespeed"
# Fallback: use enum value
logger.warning(f"Could not detect actual provider for {provider_enum.value}, using enum value")
return provider_enum.value

View File

@@ -0,0 +1,264 @@
"""
Renewal History Retention Service
Manages retention policies for subscription renewal history records.
Retention Policy:
- 0-12 months: Full records with usage snapshots
- 12-24 months: Full records (compressed/removed usage snapshots)
- 24-84 months: Summary records (no usage snapshots, payment data only)
- 84+ months: Mark for archive (payment data preserved indefinitely)
"""
from typing import Dict, Any, List, Optional
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from loguru import logger
import json
from models.subscription_models import SubscriptionRenewalHistory
class RenewalHistoryRetentionService:
"""Service for managing renewal history retention policies."""
# Retention periods (in days)
COMPRESS_SNAPSHOT_DAYS = 365 # 12 months - compress/remove usage snapshots
SUMMARY_RECORDS_DAYS = 730 # 24 months - create summary records
ARCHIVE_DAYS = 2555 # 84 months (7 years) - mark for archive
def __init__(self, db: Session):
self.db = db
def check_and_apply_retention(self, user_id: str) -> Dict[str, Any]:
"""
Check and apply retention policies for renewal history.
Applies retention in stages:
1. Compress usage snapshots for records 12-24 months old
2. Create summary records for records 24-84 months old
3. Mark records older than 84 months for archive
Returns:
Dict with retention status and statistics
"""
try:
now = datetime.utcnow()
compress_cutoff = now - timedelta(days=self.COMPRESS_SNAPSHOT_DAYS)
summary_cutoff = now - timedelta(days=self.SUMMARY_RECORDS_DAYS)
archive_cutoff = now - timedelta(days=self.ARCHIVE_DAYS)
# Count records in each retention tier
total_count = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
SubscriptionRenewalHistory.user_id == user_id
).scalar() or 0
records_to_compress = self.db.query(SubscriptionRenewalHistory).filter(
SubscriptionRenewalHistory.user_id == user_id,
SubscriptionRenewalHistory.created_at < compress_cutoff,
SubscriptionRenewalHistory.created_at >= summary_cutoff,
SubscriptionRenewalHistory.usage_before_renewal.isnot(None) # Has snapshot to compress
).all()
records_to_summarize = self.db.query(SubscriptionRenewalHistory).filter(
SubscriptionRenewalHistory.user_id == user_id,
SubscriptionRenewalHistory.created_at < summary_cutoff,
SubscriptionRenewalHistory.created_at >= archive_cutoff,
SubscriptionRenewalHistory.usage_before_renewal.isnot(None) # Has snapshot to remove
).all()
records_to_archive = self.db.query(SubscriptionRenewalHistory).filter(
SubscriptionRenewalHistory.user_id == user_id,
SubscriptionRenewalHistory.created_at < archive_cutoff
).all()
# Apply retention policies
compressed_count = self._compress_usage_snapshots(records_to_compress)
summarized_count = self._create_summary_records(records_to_summarize)
archived_count = self._mark_for_archive(records_to_archive)
total_processed = compressed_count + summarized_count + archived_count
if total_processed == 0:
return {
'retention_applied': False,
'total_records': total_count,
'records_to_compress': len(records_to_compress),
'records_to_summarize': len(records_to_summarize),
'records_to_archive': len(records_to_archive),
'message': 'No records require retention processing'
}
self.db.commit()
logger.info(
f"[RenewalRetention] Applied retention for user {user_id}: "
f"{compressed_count} compressed, {summarized_count} summarized, "
f"{archived_count} archived"
)
return {
'retention_applied': True,
'total_records': total_count,
'compressed_count': compressed_count,
'summarized_count': summarized_count,
'archived_count': archived_count,
'total_processed': total_processed,
'message': f'Processed {total_processed} records: {compressed_count} compressed, {summarized_count} summarized, {archived_count} archived'
}
except Exception as e:
self.db.rollback()
logger.error(f"[RenewalRetention] Error applying retention for user {user_id}: {e}", exc_info=True)
return {
'retention_applied': False,
'error': str(e),
'message': f'Error applying retention: {str(e)}'
}
def _compress_usage_snapshots(self, records: List[SubscriptionRenewalHistory]) -> int:
"""
Compress usage snapshots for records 12-24 months old.
Strategy: Replace detailed JSON snapshot with summary statistics only.
Keeps only essential metrics: total_calls, total_tokens, total_cost.
"""
compressed = 0
for record in records:
if record.usage_before_renewal:
try:
usage_data = record.usage_before_renewal
# Handle both dict (SQLAlchemy JSON) and string formats
if isinstance(usage_data, str):
try:
usage_data = json.loads(usage_data)
except json.JSONDecodeError:
# If it's not valid JSON, remove it
record.usage_before_renewal = None
compressed += 1
continue
elif not isinstance(usage_data, dict):
# If it's not a dict or string, remove it
record.usage_before_renewal = None
compressed += 1
continue
# Check if already compressed (has 'compressed_at' key)
if isinstance(usage_data, dict) and 'compressed_at' in usage_data:
# Already compressed, skip
continue
# Create compressed summary (keep only key metrics)
compressed_summary = {
'total_calls': usage_data.get('total_calls', 0),
'total_tokens': usage_data.get('total_tokens', 0),
'total_cost': usage_data.get('total_cost', 0.0),
'compressed_at': datetime.utcnow().isoformat(),
'note': 'Usage snapshot compressed after 12 months'
}
record.usage_before_renewal = compressed_summary
compressed += 1
except (TypeError, AttributeError, KeyError) as e:
logger.warning(f"[RenewalRetention] Failed to compress snapshot for record {record.id}: {e}")
# If compression fails, remove snapshot entirely
record.usage_before_renewal = None
compressed += 1
return compressed
def _create_summary_records(self, records: List[SubscriptionRenewalHistory]) -> int:
"""
Create summary records for records 24-84 months old.
Strategy: Remove usage snapshots, keep only payment and subscription data.
"""
summarized = 0
for record in records:
if record.usage_before_renewal is not None:
# Remove usage snapshot, keep payment and subscription data
record.usage_before_renewal = None
summarized += 1
return summarized
def _mark_for_archive(self, records: List[SubscriptionRenewalHistory]) -> int:
"""
Mark records older than 84 months for archive.
Strategy: Ensure usage snapshots are removed, payment data is preserved.
Note: In future, these could be moved to an archive table.
"""
archived = 0
for record in records:
# Ensure usage snapshot is removed (should already be done)
if record.usage_before_renewal is not None:
record.usage_before_renewal = None
archived += 1
else:
# Already processed, just count
archived += 1
return archived
def get_retention_stats(self, user_id: str) -> Dict[str, Any]:
"""
Get retention statistics for a user's renewal history.
Returns breakdown by retention tier.
"""
try:
now = datetime.utcnow()
compress_cutoff = now - timedelta(days=self.COMPRESS_SNAPSHOT_DAYS)
summary_cutoff = now - timedelta(days=self.SUMMARY_RECORDS_DAYS)
archive_cutoff = now - timedelta(days=self.ARCHIVE_DAYS)
total = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
SubscriptionRenewalHistory.user_id == user_id
).scalar() or 0
recent = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
SubscriptionRenewalHistory.user_id == user_id,
SubscriptionRenewalHistory.created_at >= compress_cutoff
).scalar() or 0
to_compress = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
SubscriptionRenewalHistory.user_id == user_id,
SubscriptionRenewalHistory.created_at < compress_cutoff,
SubscriptionRenewalHistory.created_at >= summary_cutoff,
SubscriptionRenewalHistory.usage_before_renewal.isnot(None)
).scalar() or 0
to_summarize = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
SubscriptionRenewalHistory.user_id == user_id,
SubscriptionRenewalHistory.created_at < summary_cutoff,
SubscriptionRenewalHistory.created_at >= archive_cutoff,
SubscriptionRenewalHistory.usage_before_renewal.isnot(None)
).scalar() or 0
to_archive = self.db.query(func.count(SubscriptionRenewalHistory.id)).filter(
SubscriptionRenewalHistory.user_id == user_id,
SubscriptionRenewalHistory.created_at < archive_cutoff
).scalar() or 0
return {
'total_records': total,
'recent_records': recent, # 0-12 months
'records_to_compress': to_compress, # 12-24 months
'records_to_summarize': to_summarize, # 24-84 months
'records_to_archive': to_archive, # 84+ months
'retention_policy': {
'compress_after_days': self.COMPRESS_SNAPSHOT_DAYS,
'summarize_after_days': self.SUMMARY_RECORDS_DAYS,
'archive_after_days': self.ARCHIVE_DAYS
}
}
except Exception as e:
logger.error(f"[RenewalRetention] Error getting retention stats for user {user_id}: {e}", exc_info=True)
return {
'error': str(e),
'total_records': 0
}

View File

@@ -6,6 +6,7 @@ from loguru import logger
_checked_subscription_plan_columns: bool = False
_checked_usage_summaries_columns: bool = False
_checked_api_usage_logs_columns: bool = False
def ensure_subscription_plan_columns(db: Session) -> None:
@@ -114,9 +115,58 @@ def ensure_usage_summaries_columns(db: Session) -> None:
raise
def ensure_api_usage_logs_columns(db: Session) -> None:
"""Ensure required columns exist on api_usage_logs for runtime safety.
This is a defensive guard for environments where migrations have not yet
been applied. If columns are missing (e.g., actual_provider_name), we add them
with a safe default so ORM queries do not fail.
"""
global _checked_api_usage_logs_columns
if _checked_api_usage_logs_columns:
return
try:
# Discover existing columns using PRAGMA
result = db.execute(text("PRAGMA table_info(api_usage_logs)"))
cols: Set[str] = {row[1] for row in result}
logger.debug(f"Schema check: Found {len(cols)} columns in api_usage_logs table")
# Columns we may reference in models but might be missing in older DBs
required_columns = {
"actual_provider_name": "VARCHAR(50) NULL",
}
for col_name, ddl in required_columns.items():
if col_name not in cols:
logger.info(f"Adding missing column {col_name} to api_usage_logs table")
try:
db.execute(text(f"ALTER TABLE api_usage_logs ADD COLUMN {col_name} {ddl}"))
db.commit()
logger.info(f"Successfully added column {col_name}")
except Exception as alter_err:
logger.error(f"Failed to add column {col_name}: {alter_err}")
db.rollback()
# Don't set flag on error - allow retry
raise
else:
logger.debug(f"Column {col_name} already exists")
# Only set flag if we successfully completed the check
_checked_api_usage_logs_columns = True
except Exception as e:
logger.error(f"Error ensuring api_usage_logs columns: {e}", exc_info=True)
db.rollback()
# Don't set the flag if there was an error, so we retry next time
_checked_api_usage_logs_columns = False
raise
def ensure_all_schema_columns(db: Session) -> None:
"""Ensure all required columns exist in subscription-related tables."""
ensure_subscription_plan_columns(db)
ensure_usage_summaries_columns(db)
ensure_api_usage_logs_columns(db)

View File

@@ -15,6 +15,7 @@ from models.subscription_models import (
UserSubscription, UsageStatus
)
from .pricing_service import PricingService
from .provider_detection import detect_actual_provider
class UsageTrackingService:
"""Service for tracking API usage and managing subscription limits."""
@@ -67,12 +68,21 @@ class UsageTrackingService:
# Create usage log entry
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
# Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.)
actual_provider_name = detect_actual_provider(
provider_enum=provider,
model_name=model_used,
endpoint=endpoint
)
usage_log = APIUsageLog(
user_id=user_id,
provider=provider,
endpoint=endpoint,
method=method,
model_used=model_used,
actual_provider_name=actual_provider_name, # Track actual provider
tokens_input=tokens_input,
tokens_output=tokens_output,
tokens_total=(tokens_input or 0) + (tokens_output or 0),
@@ -404,18 +414,128 @@ class UsageTrackingService:
'cost': mistral_cost
}
# Add other providers (Video, Audio, Image, Image Edit) for comprehensive breakdown
# Video (WaveSpeed, HuggingFace, etc.)
video_calls = getattr(summary, "video_calls", 0) or 0
video_cost = getattr(summary, "video_cost", 0.0) or 0.0
if video_calls > 0 and video_cost == 0.0:
video_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.VIDEO,
APIUsageLog.billing_period == billing_period
).all()
if video_logs:
video_cost = sum(float(log.cost_total or 0.0) for log in video_logs)
provider_breakdown['video'] = {
'calls': video_calls,
'tokens': 0,
'cost': video_cost
}
# Audio (WaveSpeed, etc.)
audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_cost = getattr(summary, "audio_cost", 0.0) or 0.0
if audio_calls > 0 and audio_cost == 0.0:
audio_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.AUDIO,
APIUsageLog.billing_period == billing_period
).all()
if audio_logs:
audio_cost = sum(float(log.cost_total or 0.0) for log in audio_logs)
provider_breakdown['audio'] = {
'calls': audio_calls,
'tokens': 0,
'cost': audio_cost
}
# Image Generation (Stability/WaveSpeed)
stability_calls = getattr(summary, "stability_calls", 0) or 0
stability_cost = getattr(summary, "stability_cost", 0.0) or 0.0
if stability_calls > 0 and stability_cost == 0.0:
stability_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.STABILITY,
APIUsageLog.billing_period == billing_period
).all()
if stability_logs:
stability_cost = sum(float(log.cost_total or 0.0) for log in stability_logs)
provider_breakdown['image'] = {
'calls': stability_calls,
'tokens': 0,
'cost': stability_cost
}
# Image Editing (WaveSpeed)
image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_cost = getattr(summary, "image_edit_cost", 0.0) or 0.0
if image_edit_calls > 0 and image_edit_cost == 0.0:
image_edit_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.IMAGE_EDIT,
APIUsageLog.billing_period == billing_period
).all()
if image_edit_logs:
image_edit_cost = sum(float(log.cost_total or 0.0) for log in image_edit_logs)
provider_breakdown['image_edit'] = {
'calls': image_edit_calls,
'tokens': 0,
'cost': image_edit_cost
}
# Search APIs
tavily_calls = getattr(summary, "tavily_calls", 0) or 0
tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0
provider_breakdown['tavily'] = {
'calls': tavily_calls,
'tokens': 0,
'cost': tavily_cost
}
serper_calls = getattr(summary, "serper_calls", 0) or 0
serper_cost = getattr(summary, "serper_cost", 0.0) or 0.0
provider_breakdown['serper'] = {
'calls': serper_calls,
'tokens': 0,
'cost': serper_cost
}
exa_calls = getattr(summary, "exa_calls", 0) or 0
exa_cost = getattr(summary, "exa_cost", 0.0) or 0.0
provider_breakdown['exa'] = {
'calls': exa_calls,
'tokens': 0,
'cost': exa_cost
}
# Calculate total cost from provider breakdown if summary total_cost is 0
calculated_total_cost = gemini_cost + mistral_cost
calculated_total_cost = (
gemini_cost + mistral_cost + video_cost + audio_cost +
stability_cost + image_edit_cost + tavily_cost + serper_cost + exa_cost
)
summary_total_cost = summary.total_cost or 0.0
# Use calculated cost if summary cost is 0, otherwise use summary cost (it's more accurate)
final_total_cost = summary_total_cost if summary_total_cost > 0 else calculated_total_cost
# If we calculated costs from logs, update the summary for future requests
if calculated_total_cost > 0 and summary_total_cost == 0.0:
logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}")
logger.info(f"[UsageStats] Updating summary costs: total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}, video_cost={video_cost:.6f}, audio_cost={audio_cost:.6f}, image_cost={stability_cost:.6f}")
summary.total_cost = final_total_cost
summary.gemini_cost = gemini_cost
summary.mistral_cost = mistral_cost
# Update other provider costs if they exist
if hasattr(summary, 'video_cost'):
summary.video_cost = video_cost
if hasattr(summary, 'audio_cost'):
summary.audio_cost = audio_cost
if hasattr(summary, 'stability_cost'):
summary.stability_cost = stability_cost
if hasattr(summary, 'image_edit_cost'):
summary.image_edit_cost = image_edit_cost
try:
self.db.commit()
except Exception as e: