From cf4c08ff7cee1f17e00da49e05e0606b85bb7b37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=D9=8A?= Date: Wed, 4 Mar 2026 20:42:44 +0530 Subject: [PATCH] Align usage period keys with subscription window and reset audio counters --- .../subscription/usage_tracking_service.py | 64 ++++++++++++++++--- 1 file changed, 54 insertions(+), 10 deletions(-) diff --git a/backend/services/subscription/usage_tracking_service.py b/backend/services/subscription/usage_tracking_service.py index 3590e8b1..a228921d 100644 --- a/backend/services/subscription/usage_tracking_service.py +++ b/backend/services/subscription/usage_tracking_service.py @@ -30,6 +30,40 @@ class UsageTrackingService: # TTL cache (30s) for enforcement results to cut DB chatter # key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime } self._enforce_cache: Dict[str, Dict[str, Any]] = {} + + def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]: + """Return authoritative billing period lookup keys anchored to subscription period boundaries.""" + subscription = self.db.query(UserSubscription).filter( + UserSubscription.user_id == user_id + ).first() + + # If caller explicitly requested a billing period, keep it authoritative for that read. + if billing_period: + return { + "billing_period": billing_period, + "lookup_periods": [billing_period], + "period_start": subscription.current_period_start if subscription else None, + "period_end": subscription.current_period_end if subscription else None, + } + + if subscription and subscription.current_period_start and subscription.current_period_end: + start_key = subscription.current_period_start.strftime("%Y-%m") + end_key = subscription.current_period_end.strftime("%Y-%m") + lookup_periods = [start_key] if start_key == end_key else [start_key, end_key] + return { + "billing_period": start_key, + "lookup_periods": lookup_periods, + "period_start": subscription.current_period_start, + "period_end": subscription.current_period_end, + } + + resolved_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + return { + "billing_period": resolved_period, + "lookup_periods": [resolved_period], + "period_start": None, + "period_end": None, + } async def track_api_usage(self, user_id: str, provider: APIProvider, endpoint: str, method: str, model_used: str = None, @@ -71,7 +105,8 @@ class UsageTrackingService: ) # Create usage log entry - billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + period_keys = self._get_authoritative_billing_period_keys(user_id) + billing_period = period_keys["billing_period"] # Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.) actual_provider_name = detect_actual_provider( @@ -145,15 +180,16 @@ class UsageTrackingService: """Update the usage summary for a user.""" # Get or create usage summary + period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period) summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, - UsageSummary.billing_period == billing_period + UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() if not summary: summary = UsageSummary( user_id=user_id, - billing_period=billing_period + billing_period=period_keys["billing_period"] ) self.db.add(summary) @@ -233,9 +269,10 @@ class UsageTrackingService: """Check if usage alerts should be sent.""" # Get current usage + period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period) summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, - UsageSummary.billing_period == billing_period + UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() if not summary: @@ -319,13 +356,14 @@ class UsageTrackingService: def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]: """Get comprehensive usage statistics for a user.""" - if not billing_period: - billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m") + requested_billing_period = billing_period + period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period) + billing_period = period_keys["billing_period"] # Get usage summary summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, - UsageSummary.billing_period == billing_period + UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() # Get user limits @@ -341,7 +379,7 @@ class UsageTrackingService: if not summary: # If no summary exists for current period, we should initialize it # This handles the "start of month" case where a user logs in but hasn't made calls yet - if billing_period == datetime.now().strftime("%Y-%m"): + if not requested_billing_period: logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}") summary = UsageSummary( user_id=user_id, @@ -843,10 +881,11 @@ class UsageTrackingService: async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]: """Reset usage status and counters for the current billing period (after plan renewal/change).""" try: - billing_period = datetime.now().strftime("%Y-%m") + period_keys = self._get_authoritative_billing_period_keys(user_id) + billing_period = period_keys["billing_period"] summary = self.db.query(UsageSummary).filter( UsageSummary.user_id == user_id, - UsageSummary.billing_period == billing_period + UsageSummary.billing_period.in_(period_keys["lookup_periods"]) ).first() if not summary: @@ -877,9 +916,13 @@ class UsageTrackingService: # Reset image generation counters summary.stability_calls = 0 + summary.exa_calls = 0 # Reset video generation counters summary.video_calls = 0 + + # Reset audio generation counters + summary.audio_calls = 0 # Reset image editing counters summary.image_edit_calls = 0 @@ -897,6 +940,7 @@ class UsageTrackingService: summary.exa_cost = 0.0 summary.video_cost = 0.0 summary.image_edit_cost = 0.0 + summary.audio_cost = 0.0 # Reset totals summary.total_calls = 0