Align usage period keys with subscription window and reset audio counters
This commit is contained in:
@@ -31,6 +31,40 @@ class UsageTrackingService:
|
|||||||
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
|
# key: f"{user_id}:{provider}", value: { 'result': (bool,str,dict), 'expires_at': datetime }
|
||||||
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
|
self._enforce_cache: Dict[str, Dict[str, Any]] = {}
|
||||||
|
|
||||||
|
def _get_authoritative_billing_period_keys(self, user_id: str, billing_period: Optional[str] = None) -> Dict[str, Any]:
|
||||||
|
"""Return authoritative billing period lookup keys anchored to subscription period boundaries."""
|
||||||
|
subscription = self.db.query(UserSubscription).filter(
|
||||||
|
UserSubscription.user_id == user_id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
# If caller explicitly requested a billing period, keep it authoritative for that read.
|
||||||
|
if billing_period:
|
||||||
|
return {
|
||||||
|
"billing_period": billing_period,
|
||||||
|
"lookup_periods": [billing_period],
|
||||||
|
"period_start": subscription.current_period_start if subscription else None,
|
||||||
|
"period_end": subscription.current_period_end if subscription else None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if subscription and subscription.current_period_start and subscription.current_period_end:
|
||||||
|
start_key = subscription.current_period_start.strftime("%Y-%m")
|
||||||
|
end_key = subscription.current_period_end.strftime("%Y-%m")
|
||||||
|
lookup_periods = [start_key] if start_key == end_key else [start_key, end_key]
|
||||||
|
return {
|
||||||
|
"billing_period": start_key,
|
||||||
|
"lookup_periods": lookup_periods,
|
||||||
|
"period_start": subscription.current_period_start,
|
||||||
|
"period_end": subscription.current_period_end,
|
||||||
|
}
|
||||||
|
|
||||||
|
resolved_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
||||||
|
return {
|
||||||
|
"billing_period": resolved_period,
|
||||||
|
"lookup_periods": [resolved_period],
|
||||||
|
"period_start": None,
|
||||||
|
"period_end": None,
|
||||||
|
}
|
||||||
|
|
||||||
async def track_api_usage(self, user_id: str, provider: APIProvider,
|
async def track_api_usage(self, user_id: str, provider: APIProvider,
|
||||||
endpoint: str, method: str, model_used: str = None,
|
endpoint: str, method: str, model_used: str = None,
|
||||||
tokens_input: int = 0, tokens_output: int = 0,
|
tokens_input: int = 0, tokens_output: int = 0,
|
||||||
@@ -71,7 +105,8 @@ class UsageTrackingService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create usage log entry
|
# Create usage log entry
|
||||||
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
period_keys = self._get_authoritative_billing_period_keys(user_id)
|
||||||
|
billing_period = period_keys["billing_period"]
|
||||||
|
|
||||||
# Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.)
|
# Detect actual provider name (WaveSpeed, Google, HuggingFace, etc.)
|
||||||
actual_provider_name = detect_actual_provider(
|
actual_provider_name = detect_actual_provider(
|
||||||
@@ -145,15 +180,16 @@ class UsageTrackingService:
|
|||||||
"""Update the usage summary for a user."""
|
"""Update the usage summary for a user."""
|
||||||
|
|
||||||
# Get or create usage summary
|
# Get or create usage summary
|
||||||
|
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||||
summary = self.db.query(UsageSummary).filter(
|
summary = self.db.query(UsageSummary).filter(
|
||||||
UsageSummary.user_id == user_id,
|
UsageSummary.user_id == user_id,
|
||||||
UsageSummary.billing_period == billing_period
|
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not summary:
|
if not summary:
|
||||||
summary = UsageSummary(
|
summary = UsageSummary(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
billing_period=billing_period
|
billing_period=period_keys["billing_period"]
|
||||||
)
|
)
|
||||||
self.db.add(summary)
|
self.db.add(summary)
|
||||||
|
|
||||||
@@ -233,9 +269,10 @@ class UsageTrackingService:
|
|||||||
"""Check if usage alerts should be sent."""
|
"""Check if usage alerts should be sent."""
|
||||||
|
|
||||||
# Get current usage
|
# Get current usage
|
||||||
|
period_keys = self._get_authoritative_billing_period_keys(user_id, billing_period)
|
||||||
summary = self.db.query(UsageSummary).filter(
|
summary = self.db.query(UsageSummary).filter(
|
||||||
UsageSummary.user_id == user_id,
|
UsageSummary.user_id == user_id,
|
||||||
UsageSummary.billing_period == billing_period
|
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not summary:
|
if not summary:
|
||||||
@@ -319,13 +356,14 @@ class UsageTrackingService:
|
|||||||
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
|
||||||
"""Get comprehensive usage statistics for a user."""
|
"""Get comprehensive usage statistics for a user."""
|
||||||
|
|
||||||
if not billing_period:
|
requested_billing_period = billing_period
|
||||||
billing_period = self.pricing_service.get_current_billing_period(user_id) or datetime.now().strftime("%Y-%m")
|
period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period)
|
||||||
|
billing_period = period_keys["billing_period"]
|
||||||
|
|
||||||
# Get usage summary
|
# Get usage summary
|
||||||
summary = self.db.query(UsageSummary).filter(
|
summary = self.db.query(UsageSummary).filter(
|
||||||
UsageSummary.user_id == user_id,
|
UsageSummary.user_id == user_id,
|
||||||
UsageSummary.billing_period == billing_period
|
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
# Get user limits
|
# Get user limits
|
||||||
@@ -341,7 +379,7 @@ class UsageTrackingService:
|
|||||||
if not summary:
|
if not summary:
|
||||||
# If no summary exists for current period, we should initialize it
|
# If no summary exists for current period, we should initialize it
|
||||||
# This handles the "start of month" case where a user logs in but hasn't made calls yet
|
# This handles the "start of month" case where a user logs in but hasn't made calls yet
|
||||||
if billing_period == datetime.now().strftime("%Y-%m"):
|
if not requested_billing_period:
|
||||||
logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}")
|
logger.info(f"Initializing empty UsageSummary for user {user_id} in period {billing_period}")
|
||||||
summary = UsageSummary(
|
summary = UsageSummary(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -843,10 +881,11 @@ class UsageTrackingService:
|
|||||||
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
async def reset_current_billing_period(self, user_id: str) -> Dict[str, Any]:
|
||||||
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
|
"""Reset usage status and counters for the current billing period (after plan renewal/change)."""
|
||||||
try:
|
try:
|
||||||
billing_period = datetime.now().strftime("%Y-%m")
|
period_keys = self._get_authoritative_billing_period_keys(user_id)
|
||||||
|
billing_period = period_keys["billing_period"]
|
||||||
summary = self.db.query(UsageSummary).filter(
|
summary = self.db.query(UsageSummary).filter(
|
||||||
UsageSummary.user_id == user_id,
|
UsageSummary.user_id == user_id,
|
||||||
UsageSummary.billing_period == billing_period
|
UsageSummary.billing_period.in_(period_keys["lookup_periods"])
|
||||||
).first()
|
).first()
|
||||||
|
|
||||||
if not summary:
|
if not summary:
|
||||||
@@ -877,10 +916,14 @@ class UsageTrackingService:
|
|||||||
|
|
||||||
# Reset image generation counters
|
# Reset image generation counters
|
||||||
summary.stability_calls = 0
|
summary.stability_calls = 0
|
||||||
|
summary.exa_calls = 0
|
||||||
|
|
||||||
# Reset video generation counters
|
# Reset video generation counters
|
||||||
summary.video_calls = 0
|
summary.video_calls = 0
|
||||||
|
|
||||||
|
# Reset audio generation counters
|
||||||
|
summary.audio_calls = 0
|
||||||
|
|
||||||
# Reset image editing counters
|
# Reset image editing counters
|
||||||
summary.image_edit_calls = 0
|
summary.image_edit_calls = 0
|
||||||
|
|
||||||
@@ -897,6 +940,7 @@ class UsageTrackingService:
|
|||||||
summary.exa_cost = 0.0
|
summary.exa_cost = 0.0
|
||||||
summary.video_cost = 0.0
|
summary.video_cost = 0.0
|
||||||
summary.image_edit_cost = 0.0
|
summary.image_edit_cost = 0.0
|
||||||
|
summary.audio_cost = 0.0
|
||||||
|
|
||||||
# Reset totals
|
# Reset totals
|
||||||
summary.total_calls = 0
|
summary.total_calls = 0
|
||||||
|
|||||||
Reference in New Issue
Block a user