Extract usage trends and reset logic into usage_tracking_helpers

This commit is contained in:
ي
2026-03-12 07:32:59 +05:30
parent 01881bb405
commit ad1756aaa2
5 changed files with 573 additions and 469 deletions

View File

@@ -0,0 +1,29 @@
"""Helper utilities for usage tracking service."""
from .usage_reset_helpers import reset_usage_summary_counters
from .usage_stats_helpers import (
build_default_usage_percentages,
build_empty_usage_response,
build_provider_breakdown,
calculate_final_total_cost,
maybe_persist_reconciled_costs,
)
from .usage_trends_helpers import (
build_billing_periods,
build_usage_trends_response,
query_usage_summaries,
self_heal_summaries_from_logs,
)
__all__ = [
"build_default_usage_percentages",
"build_empty_usage_response",
"build_provider_breakdown",
"calculate_final_total_cost",
"maybe_persist_reconciled_costs",
"build_billing_periods",
"query_usage_summaries",
"self_heal_summaries_from_logs",
"build_usage_trends_response",
"reset_usage_summary_counters",
]

View File

@@ -0,0 +1,73 @@
"""Helpers extracted from UsageTrackingService.reset_current_billing_period."""
from __future__ import annotations
from datetime import datetime
from typing import Any
from models.subscription_models import UsageStatus
_CALL_FIELDS = [
"gemini_calls",
"openai_calls",
"anthropic_calls",
"mistral_calls",
"wavespeed_calls",
"tavily_calls",
"serper_calls",
"metaphor_calls",
"firecrawl_calls",
"stability_calls",
"exa_calls",
"video_calls",
"audio_calls",
"image_edit_calls",
]
_TOKEN_FIELDS = [
"gemini_tokens",
"openai_tokens",
"anthropic_tokens",
"mistral_tokens",
"wavespeed_tokens",
]
_COST_FIELDS = [
"gemini_cost",
"openai_cost",
"anthropic_cost",
"mistral_cost",
"wavespeed_cost",
"tavily_cost",
"serper_cost",
"metaphor_cost",
"firecrawl_cost",
"stability_cost",
"exa_cost",
"video_cost",
"image_edit_cost",
"audio_cost",
]
def reset_usage_summary_counters(summary: Any) -> None:
"""Reset all known usage counters to baseline values."""
summary.usage_status = UsageStatus.ACTIVE
for field in _CALL_FIELDS:
if hasattr(summary, field):
setattr(summary, field, 0)
for field in _TOKEN_FIELDS:
if hasattr(summary, field):
setattr(summary, field, 0)
for field in _COST_FIELDS:
if hasattr(summary, field):
setattr(summary, field, 0.0)
summary.total_calls = 0
summary.total_tokens = 0
summary.total_cost = 0.0
summary.updated_at = datetime.utcnow()

View File

@@ -0,0 +1,250 @@
"""Helper utilities extracted from UsageTrackingService.get_user_usage_stats."""
from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Iterable, Tuple
from loguru import logger
from models.subscription_models import APIProvider, APIUsageLog
def _safe_getattr(obj: Any, attr: str, default: Any) -> Any:
"""Return object attribute using a defensive default fallback."""
value = getattr(obj, attr, default)
return default if value is None else value
def build_default_usage_percentages(providers: Iterable[APIProvider]) -> Dict[str, float]:
"""Create zeroed usage percentages for all provider call limits plus cost."""
usage_percentages = {f"{provider.value}_calls": 0 for provider in providers}
usage_percentages["cost"] = 0
return usage_percentages
def build_empty_usage_response(
*,
billing_period: str,
limits: Dict[str, Any] | None,
providers: Iterable[APIProvider],
) -> Dict[str, Any]:
"""Build a no-usage response payload with complete shape."""
provider_breakdown = {
provider.value: {"calls": 0, "tokens": 0, "cost": 0.0} for provider in providers
}
return {
"billing_period": billing_period,
"usage_status": "active",
"total_calls": 0,
"total_tokens": 0,
"total_cost": 0.0,
"avg_response_time": 0.0,
"error_rate": 0.0,
"last_updated": datetime.now().isoformat(),
"limits": limits,
"provider_breakdown": provider_breakdown,
"alerts": [],
"usage_percentages": build_default_usage_percentages(providers),
}
def _resolve_cost_from_logs(
*,
db: Any,
user_id: str,
billing_period: str,
provider: APIProvider,
current_calls: int,
current_cost: float,
debug_label: str,
) -> float:
"""Backfill provider cost from usage logs only when needed."""
if current_calls <= 0 or current_cost != 0.0:
return current_cost
logs = (
db.query(APIUsageLog)
.filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == provider,
APIUsageLog.billing_period == billing_period,
)
.all()
)
if not logs:
return current_cost
calculated_cost = sum(float(log.cost_total or 0.0) for log in logs)
logger.info(
f"[UsageStats] Calculated {debug_label} cost from {len(logs)} logs: ${calculated_cost:.6f}"
)
return calculated_cost
def build_provider_breakdown(
*,
db: Any,
user_id: str,
billing_period: str,
summary: Any,
) -> Tuple[Dict[str, Dict[str, float | int]], Dict[str, float], Dict[str, int]]:
"""Build provider breakdown while preserving existing backfill behavior."""
provider_breakdown: Dict[str, Dict[str, float | int]] = {}
gemini_calls = int(_safe_getattr(summary, "gemini_calls", 0) or 0)
gemini_tokens = int(_safe_getattr(summary, "gemini_tokens", 0) or 0)
gemini_cost = float(_safe_getattr(summary, "gemini_cost", 0.0) or 0.0)
gemini_cost = _resolve_cost_from_logs(
db=db,
user_id=user_id,
billing_period=billing_period,
provider=APIProvider.GEMINI,
current_calls=gemini_calls,
current_cost=gemini_cost,
debug_label="gemini",
)
provider_breakdown["gemini"] = {"calls": gemini_calls, "tokens": gemini_tokens, "cost": gemini_cost}
mistral_calls = int(_safe_getattr(summary, "mistral_calls", 0) or 0)
mistral_tokens = int(_safe_getattr(summary, "mistral_tokens", 0) or 0)
mistral_cost = float(_safe_getattr(summary, "mistral_cost", 0.0) or 0.0)
mistral_cost = _resolve_cost_from_logs(
db=db,
user_id=user_id,
billing_period=billing_period,
provider=APIProvider.MISTRAL,
current_calls=mistral_calls,
current_cost=mistral_cost,
debug_label="mistral (HuggingFace)",
)
provider_breakdown["huggingface"] = {
"calls": mistral_calls,
"tokens": mistral_tokens,
"cost": mistral_cost,
}
mapped_providers = {
"video": ("video_calls", "video_cost", APIProvider.VIDEO),
"audio": ("audio_calls", "audio_cost", APIProvider.AUDIO),
"image": ("stability_calls", "stability_cost", APIProvider.STABILITY),
"image_edit": ("image_edit_calls", "image_edit_cost", APIProvider.IMAGE_EDIT),
}
resolved_costs = {
"gemini_cost": gemini_cost,
"mistral_cost": mistral_cost,
}
for key, (calls_attr, cost_attr, provider_enum) in mapped_providers.items():
calls = int(_safe_getattr(summary, calls_attr, 0) or 0)
cost = float(_safe_getattr(summary, cost_attr, 0.0) or 0.0)
cost = _resolve_cost_from_logs(
db=db,
user_id=user_id,
billing_period=billing_period,
provider=provider_enum,
current_calls=calls,
current_cost=cost,
debug_label=key,
)
provider_breakdown[key] = {"calls": calls, "tokens": 0, "cost": cost}
resolved_costs[cost_attr] = cost
wavespeed_logs = (
db.query(APIUsageLog)
.filter(
APIUsageLog.user_id == user_id,
APIUsageLog.billing_period == billing_period,
APIUsageLog.actual_provider_name == "wavespeed",
)
.all()
)
if wavespeed_logs:
wavespeed_calls = len(wavespeed_logs)
wavespeed_tokens = sum((log.tokens_total or 0) for log in wavespeed_logs)
wavespeed_cost = sum(float(log.cost_total or 0.0) for log in wavespeed_logs)
provider_breakdown["wavespeed"] = {
"calls": wavespeed_calls,
"tokens": wavespeed_tokens,
"cost": wavespeed_cost,
}
logger.info(
f"[UsageStats] Calculated WaveSpeed usage: {wavespeed_calls} calls, ${wavespeed_cost:.6f}"
)
else:
provider_breakdown["wavespeed"] = {"calls": 0, "tokens": 0, "cost": 0.0}
for search_provider in ("tavily", "serper", "exa"):
calls = int(_safe_getattr(summary, f"{search_provider}_calls", 0) or 0)
cost = float(_safe_getattr(summary, f"{search_provider}_cost", 0.0) or 0.0)
provider_breakdown[search_provider] = {"calls": calls, "tokens": 0, "cost": cost}
resolved_costs[f"{search_provider}_cost"] = cost
core_counts = {
"gemini_calls": gemini_calls,
"mistral_calls": mistral_calls,
}
return provider_breakdown, resolved_costs, core_counts
def calculate_final_total_cost(summary_total_cost: float, resolved_costs: Dict[str, float]) -> Tuple[float, float]:
"""Return calculated and chosen total cost values."""
calculated_total_cost = float(
resolved_costs.get("gemini_cost", 0.0)
+ resolved_costs.get("mistral_cost", 0.0)
+ resolved_costs.get("video_cost", 0.0)
+ resolved_costs.get("audio_cost", 0.0)
+ resolved_costs.get("stability_cost", 0.0)
+ resolved_costs.get("image_edit_cost", 0.0)
+ resolved_costs.get("tavily_cost", 0.0)
+ resolved_costs.get("serper_cost", 0.0)
+ resolved_costs.get("exa_cost", 0.0)
)
final_total_cost = calculated_total_cost if calculated_total_cost > (summary_total_cost or 0.0) else (summary_total_cost or 0.0)
return calculated_total_cost, final_total_cost
def maybe_persist_reconciled_costs(
*,
db: Any,
summary: Any,
summary_total_cost: float,
calculated_total_cost: float,
final_total_cost: float,
resolved_costs: Dict[str, float],
) -> None:
"""Persist summary cost reconciliation when calculated values are more complete."""
if not (
calculated_total_cost > 0
and (summary_total_cost == 0.0 or calculated_total_cost > summary_total_cost)
):
return
logger.info(
"[UsageStats] Updating summary costs (was {}): total_cost={:.6f}, gemini_cost={:.6f}, "
"mistral_cost={:.6f}, video_cost={:.6f}, audio_cost={:.6f}, image_cost={:.6f}".format(
summary_total_cost,
final_total_cost,
resolved_costs.get("gemini_cost", 0.0),
resolved_costs.get("mistral_cost", 0.0),
resolved_costs.get("video_cost", 0.0),
resolved_costs.get("audio_cost", 0.0),
resolved_costs.get("stability_cost", 0.0),
)
)
summary.total_cost = final_total_cost
summary.gemini_cost = resolved_costs.get("gemini_cost", 0.0)
summary.mistral_cost = resolved_costs.get("mistral_cost", 0.0)
for summary_attr in ("video_cost", "audio_cost", "stability_cost", "image_edit_cost"):
if hasattr(summary, summary_attr):
setattr(summary, summary_attr, resolved_costs.get(summary_attr, 0.0))
try:
db.commit()
except Exception as e:
logger.error(f"[UsageStats] Error updating summary costs: {e}")
db.rollback()

View File

@@ -0,0 +1,171 @@
"""Helpers extracted from UsageTrackingService.get_usage_trends."""
from __future__ import annotations
from datetime import datetime, timedelta
from typing import Any, Dict, List
from loguru import logger
from sqlalchemy import func
from models.subscription_models import APIProvider, APIUsageLog, UsageStatus, UsageSummary
def build_billing_periods(months: int) -> List[str]:
"""Build billing period keys (YYYY-MM) from oldest to newest."""
end_date = datetime.now()
periods: List[str] = []
for i in range(months):
period_date = end_date - timedelta(days=30 * i)
periods.append(period_date.strftime("%Y-%m"))
periods.reverse()
return periods
def query_usage_summaries(db: Any, user_id: str, periods: List[str]) -> Dict[str, Any]:
"""Load usage summaries for requested periods keyed by billing period."""
summaries = (
db.query(UsageSummary)
.filter(UsageSummary.user_id == user_id, UsageSummary.billing_period.in_(periods))
.all()
)
return {summary.billing_period: summary for summary in summaries}
def self_heal_summaries_from_logs(db: Any, user_id: str, periods: List[str], summary_dict: Dict[str, Any]) -> None:
"""Backfill/create usage summaries from aggregated API usage logs."""
try:
log_stats = (
db.query(
APIUsageLog.billing_period,
APIUsageLog.provider,
func.count(APIUsageLog.id).label("calls"),
func.sum(APIUsageLog.cost_total).label("cost"),
func.sum(APIUsageLog.tokens_total).label("tokens"),
)
.filter(APIUsageLog.user_id == user_id, APIUsageLog.billing_period.in_(periods))
.group_by(APIUsageLog.billing_period, APIUsageLog.provider)
.all()
)
log_data_by_period: Dict[str, Dict[str, Dict[str, float | int]]] = {}
for period, provider_enum, calls, cost, tokens in log_stats:
if period not in log_data_by_period:
log_data_by_period[period] = {}
provider_name = provider_enum.value if hasattr(provider_enum, "value") else str(provider_enum).lower()
if "." in provider_name:
provider_name = provider_name.split(".")[-1].lower()
if provider_name not in log_data_by_period[period]:
log_data_by_period[period][provider_name] = {"calls": 0, "cost": 0.0, "tokens": 0}
log_data_by_period[period][provider_name]["calls"] += calls or 0
log_data_by_period[period][provider_name]["cost"] += float(cost or 0.0)
log_data_by_period[period][provider_name]["tokens"] += tokens or 0
for period in periods:
period_logs = log_data_by_period.get(period, {})
summary = summary_dict.get(period)
if not summary and period_logs:
logger.info(f"[UsageStats] Self-healing: Creating missing summary for {period}")
summary = UsageSummary(
user_id=user_id,
billing_period=period,
usage_status=UsageStatus.ACTIVE,
total_calls=0,
total_cost=0.0,
total_tokens=0,
)
db.add(summary)
summary_dict[period] = summary
if summary and period_logs:
total_calls_calc = 0
total_cost_calc = 0.0
total_tokens_calc = 0
for provider_name, data in period_logs.items():
total_calls_calc += int(data["calls"])
total_cost_calc += float(data["cost"])
total_tokens_calc += int(data["tokens"])
calls_attr = f"{provider_name}_calls"
cost_attr = f"{provider_name}_cost"
tokens_attr = f"{provider_name}_tokens"
if hasattr(summary, calls_attr):
current_val = getattr(summary, calls_attr, 0)
if current_val < data["calls"]:
setattr(summary, calls_attr, data["calls"])
if hasattr(summary, cost_attr):
current_val = getattr(summary, cost_attr, 0.0)
if (float(data["cost"]) - current_val) > 0.000001:
setattr(summary, cost_attr, data["cost"])
if hasattr(summary, tokens_attr):
current_val = getattr(summary, tokens_attr, 0)
if current_val < data["tokens"]:
setattr(summary, tokens_attr, data["tokens"])
if (summary.total_cost or 0.0) < total_cost_calc:
logger.info(
f"[UsageStats] Self-healing cost for {period}: {summary.total_cost} -> {total_cost_calc}"
)
summary.total_cost = total_cost_calc
if (summary.total_calls or 0) < total_calls_calc:
summary.total_calls = total_calls_calc
if (summary.total_tokens or 0) < total_tokens_calc:
summary.total_tokens = total_tokens_calc
db.commit()
except Exception as e:
logger.error(f"Failed to self-heal usage trends: {e}")
db.rollback()
def build_usage_trends_response(periods: List[str], summary_dict: Dict[str, Any]) -> Dict[str, Any]:
"""Build trends response payload from summaries."""
trends = {
"periods": periods,
"total_calls": [],
"total_cost": [],
"total_tokens": [],
"provider_trends": {},
}
for provider in APIProvider:
provider_name = provider.value
trends["provider_trends"][provider_name] = {"calls": [], "cost": [], "tokens": []}
for period in periods:
summary = summary_dict.get(period)
if summary:
trends["total_calls"].append(summary.total_calls or 0)
trends["total_cost"].append(summary.total_cost or 0.0)
trends["total_tokens"].append(summary.total_tokens or 0)
for provider in APIProvider:
provider_name = provider.value
trends["provider_trends"][provider_name]["calls"].append(
getattr(summary, f"{provider_name}_calls", 0) or 0
)
trends["provider_trends"][provider_name]["cost"].append(
getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
)
trends["provider_trends"][provider_name]["tokens"].append(
getattr(summary, f"{provider_name}_tokens", 0) or 0
)
else:
trends["total_calls"].append(0)
trends["total_cost"].append(0.0)
trends["total_tokens"].append(0)
for provider in APIProvider:
provider_name = provider.value
trends["provider_trends"][provider_name]["calls"].append(0)
trends["provider_trends"][provider_name]["cost"].append(0.0)
trends["provider_trends"][provider_name]["tokens"].append(0)
return trends

View File

@@ -10,7 +10,7 @@ import asyncio
from typing import Dict, Any, List, Tuple
from datetime import datetime, timedelta
from sqlalchemy.orm import Session
from sqlalchemy import func, desc
from sqlalchemy import desc
from loguru import logger
import json
@@ -20,6 +20,18 @@ from models.subscription_models import (
)
from .pricing_service import PricingService
from .provider_detection import detect_actual_provider
from .usage_tracking_helpers import (
build_billing_periods,
build_default_usage_percentages,
build_empty_usage_response,
build_provider_breakdown,
build_usage_trends_response,
calculate_final_total_cost,
maybe_persist_reconciled_costs,
query_usage_summaries,
reset_usage_summary_counters,
self_heal_summaries_from_logs,
)
class UsageTrackingService:
"""Service for tracking API usage and managing subscription limits."""
@@ -356,6 +368,10 @@ class UsageTrackingService:
def get_user_usage_stats(self, user_id: str, billing_period: str = None) -> Dict[str, Any]:
"""Get comprehensive usage statistics for a user."""
if not user_id:
logger.error("get_user_usage_stats called without user_id")
raise ValueError("user_id is required")
requested_billing_period = billing_period
period_keys = self._get_authoritative_billing_period_keys(user_id, requested_billing_period)
@@ -401,269 +417,54 @@ class UsageTrackingService:
pass
if not summary: # Still no summary after attempt
# No usage this period - return complete structure with zeros
provider_breakdown = {}
usage_percentages = {}
# Initialize provider breakdown with zeros
for provider in APIProvider:
provider_name = provider.value
provider_breakdown[provider_name] = {
'calls': 0,
'tokens': 0,
'cost': 0.0
}
usage_percentages[f"{provider_name}_calls"] = 0
usage_percentages['cost'] = 0
return {
'billing_period': billing_period,
'usage_status': 'active',
'total_calls': 0,
'total_tokens': 0,
'total_cost': 0.0,
'avg_response_time': 0.0,
'error_rate': 0.0,
'last_updated': datetime.now().isoformat(),
'limits': limits,
'provider_breakdown': provider_breakdown,
'alerts': [],
'usage_percentages': usage_percentages
}
return build_empty_usage_response(
billing_period=billing_period,
limits=limits,
providers=APIProvider,
)
# Provider breakdown - calculate costs first, then use for percentages
# Only include Gemini and HuggingFace (HuggingFace is stored under MISTRAL enum)
provider_breakdown = {}
# Gemini
gemini_calls = getattr(summary, "gemini_calls", 0) or 0
gemini_tokens = getattr(summary, "gemini_tokens", 0) or 0
gemini_cost = getattr(summary, "gemini_cost", 0.0) or 0.0
# If gemini cost is 0 but there are calls, calculate from usage logs
if gemini_calls > 0 and gemini_cost == 0.0:
gemini_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.GEMINI,
APIUsageLog.billing_period == billing_period
).all()
if gemini_logs:
gemini_cost = sum(float(log.cost_total or 0.0) for log in gemini_logs)
logger.info(f"[UsageStats] Calculated gemini cost from {len(gemini_logs)} logs: ${gemini_cost:.6f}")
provider_breakdown['gemini'] = {
'calls': gemini_calls,
'tokens': gemini_tokens,
'cost': gemini_cost
}
# HuggingFace (stored as MISTRAL in database)
mistral_calls = getattr(summary, "mistral_calls", 0) or 0
mistral_tokens = getattr(summary, "mistral_tokens", 0) or 0
mistral_cost = getattr(summary, "mistral_cost", 0.0) or 0.0
# If mistral (HuggingFace) cost is 0 but there are calls, calculate from usage logs
if mistral_calls > 0 and mistral_cost == 0.0:
mistral_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.MISTRAL,
APIUsageLog.billing_period == billing_period
).all()
if mistral_logs:
mistral_cost = sum(float(log.cost_total or 0.0) for log in mistral_logs)
logger.info(f"[UsageStats] Calculated mistral (HuggingFace) cost from {len(mistral_logs)} logs: ${mistral_cost:.6f}")
provider_breakdown['huggingface'] = {
'calls': mistral_calls,
'tokens': mistral_tokens,
'cost': mistral_cost
}
# Add other providers (Video, Audio, Image, Image Edit) for comprehensive breakdown
# Video (WaveSpeed, HuggingFace, etc.)
video_calls = getattr(summary, "video_calls", 0) or 0
video_cost = getattr(summary, "video_cost", 0.0) or 0.0
if video_calls > 0 and video_cost == 0.0:
video_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.VIDEO,
APIUsageLog.billing_period == billing_period
).all()
if video_logs:
video_cost = sum(float(log.cost_total or 0.0) for log in video_logs)
provider_breakdown['video'] = {
'calls': video_calls,
'tokens': 0,
'cost': video_cost
}
# Audio (WaveSpeed, etc.)
audio_calls = getattr(summary, "audio_calls", 0) or 0
audio_cost = getattr(summary, "audio_cost", 0.0) or 0.0
if audio_calls > 0 and audio_cost == 0.0:
audio_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.AUDIO,
APIUsageLog.billing_period == billing_period
).all()
if audio_logs:
audio_cost = sum(float(log.cost_total or 0.0) for log in audio_logs)
provider_breakdown['audio'] = {
'calls': audio_calls,
'tokens': 0,
'cost': audio_cost
}
# Image Generation (Stability/WaveSpeed)
stability_calls = getattr(summary, "stability_calls", 0) or 0
stability_cost = getattr(summary, "stability_cost", 0.0) or 0.0
if stability_calls > 0 and stability_cost == 0.0:
stability_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.STABILITY,
APIUsageLog.billing_period == billing_period
).all()
if stability_logs:
stability_cost = sum(float(log.cost_total or 0.0) for log in stability_logs)
provider_breakdown['image'] = {
'calls': stability_calls,
'tokens': 0,
'cost': stability_cost
}
# Image Editing (WaveSpeed)
image_edit_calls = getattr(summary, "image_edit_calls", 0) or 0
image_edit_cost = getattr(summary, "image_edit_cost", 0.0) or 0.0
if image_edit_calls > 0 and image_edit_cost == 0.0:
image_edit_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.provider == APIProvider.IMAGE_EDIT,
APIUsageLog.billing_period == billing_period
).all()
if image_edit_logs:
image_edit_cost = sum(float(log.cost_total or 0.0) for log in image_edit_logs)
provider_breakdown['image_edit'] = {
'calls': image_edit_calls,
'tokens': 0,
'cost': image_edit_cost
}
# WaveSpeed (aggregated across Video, Audio, Image, Image Edit)
# Query APIUsageLog directly to get accurate WaveSpeed-specific usage
wavespeed_logs = self.db.query(APIUsageLog).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.billing_period == billing_period,
APIUsageLog.actual_provider_name == "wavespeed"
).all()
if wavespeed_logs:
wavespeed_calls = len(wavespeed_logs)
wavespeed_tokens = sum((log.tokens_total or 0) for log in wavespeed_logs)
wavespeed_cost = sum(float(log.cost_total or 0.0) for log in wavespeed_logs)
provider_breakdown['wavespeed'] = {
'calls': wavespeed_calls,
'tokens': wavespeed_tokens,
'cost': wavespeed_cost
}
logger.info(f"[UsageStats] Calculated WaveSpeed usage: {wavespeed_calls} calls, ${wavespeed_cost:.6f}")
else:
provider_breakdown['wavespeed'] = {
'calls': 0,
'tokens': 0,
'cost': 0.0
}
# Search APIs
tavily_calls = getattr(summary, "tavily_calls", 0) or 0
tavily_cost = getattr(summary, "tavily_cost", 0.0) or 0.0
provider_breakdown['tavily'] = {
'calls': tavily_calls,
'tokens': 0,
'cost': tavily_cost
}
serper_calls = getattr(summary, "serper_calls", 0) or 0
serper_cost = getattr(summary, "serper_cost", 0.0) or 0.0
provider_breakdown['serper'] = {
'calls': serper_calls,
'tokens': 0,
'cost': serper_cost
}
exa_calls = getattr(summary, "exa_calls", 0) or 0
exa_cost = getattr(summary, "exa_cost", 0.0) or 0.0
provider_breakdown['exa'] = {
'calls': exa_calls,
'tokens': 0,
'cost': exa_cost
}
# Calculate total cost from provider breakdown if summary total_cost is 0
calculated_total_cost = (
gemini_cost + mistral_cost + video_cost + audio_cost +
stability_cost + image_edit_cost + tavily_cost + serper_cost + exa_cost
provider_breakdown, resolved_costs, core_counts = build_provider_breakdown(
db=self.db,
user_id=user_id,
billing_period=billing_period,
summary=summary,
)
summary_total_cost = summary.total_cost or 0.0
# Determine the best cost value to use
# If summary cost is 0 but we have calculated cost, use calculated cost
# If summary cost exists but is less than calculated cost (out of sync), use calculated cost
if calculated_total_cost > summary_total_cost:
final_total_cost = calculated_total_cost
else:
final_total_cost = summary_total_cost
# If we found a discrepancy (summary cost is 0 or less than calculated), update the DB
if calculated_total_cost > 0 and (summary_total_cost == 0.0 or calculated_total_cost > summary_total_cost):
logger.info(f"[UsageStats] Updating summary costs (was {summary_total_cost}): total_cost={final_total_cost:.6f}, gemini_cost={gemini_cost:.6f}, mistral_cost={mistral_cost:.6f}, video_cost={video_cost:.6f}, audio_cost={audio_cost:.6f}, image_cost={stability_cost:.6f}")
summary.total_cost = final_total_cost
summary.gemini_cost = gemini_cost
summary.mistral_cost = mistral_cost
# Update other provider costs if they exist
if hasattr(summary, 'video_cost'):
summary.video_cost = video_cost
if hasattr(summary, 'audio_cost'):
summary.audio_cost = audio_cost
if hasattr(summary, 'stability_cost'):
summary.stability_cost = stability_cost
if hasattr(summary, 'image_edit_cost'):
summary.image_edit_cost = image_edit_cost
try:
self.db.commit()
except Exception as e:
logger.error(f"[UsageStats] Error updating summary costs: {e}")
self.db.rollback()
calculated_total_cost, final_total_cost = calculate_final_total_cost(
summary_total_cost=summary_total_cost,
resolved_costs=resolved_costs,
)
maybe_persist_reconciled_costs(
db=self.db,
summary=summary,
summary_total_cost=summary_total_cost,
calculated_total_cost=calculated_total_cost,
final_total_cost=final_total_cost,
resolved_costs=resolved_costs,
)
# Calculate usage percentages - only for Gemini and HuggingFace
# Use the calculated costs for accurate percentages
usage_percentages = {}
usage_percentages = build_default_usage_percentages(APIProvider)
if limits:
# Gemini
gemini_call_limit = limits['limits'].get("gemini_calls", 0) or 0
if gemini_call_limit > 0:
usage_percentages['gemini_calls'] = (gemini_calls / gemini_call_limit) * 100
else:
usage_percentages['gemini_calls'] = 0
usage_percentages['gemini_calls'] = (core_counts['gemini_calls'] / gemini_call_limit) * 100
# HuggingFace (stored as mistral in database)
mistral_call_limit = limits['limits'].get("mistral_calls", 0) or 0
if mistral_call_limit > 0:
usage_percentages['mistral_calls'] = (mistral_calls / mistral_call_limit) * 100
else:
usage_percentages['mistral_calls'] = 0
usage_percentages['mistral_calls'] = (core_counts['mistral_calls'] / mistral_call_limit) * 100
# Cost usage percentage - use final_total_cost (calculated from logs if needed)
cost_limit = limits['limits'].get('monthly_cost', 0) or 0
if cost_limit > 0:
usage_percentages['cost'] = (final_total_cost / cost_limit) * 100
else:
usage_percentages['cost'] = 0
return {
'billing_period': billing_period,
@@ -692,171 +493,10 @@ class UsageTrackingService:
def get_usage_trends(self, user_id: str, months: int = 6) -> Dict[str, Any]:
"""Get usage trends over time with self-healing from logs."""
# Calculate billing periods
end_date = datetime.now()
periods = []
for i in range(months):
period_date = end_date - timedelta(days=30 * i)
periods.append(period_date.strftime("%Y-%m"))
periods.reverse() # Oldest first
# 1. Fetch existing summaries
summaries = self.db.query(UsageSummary).filter(
UsageSummary.user_id == user_id,
UsageSummary.billing_period.in_(periods)
).all()
summary_dict = {s.billing_period: s for s in summaries}
# 2. Fetch aggregated logs for self-healing
# Group by (billing_period, provider) to fix provider breakdowns too
try:
log_stats = self.db.query(
APIUsageLog.billing_period,
APIUsageLog.provider,
func.count(APIUsageLog.id).label('calls'),
func.sum(APIUsageLog.cost_total).label('cost'),
func.sum(APIUsageLog.tokens_total).label('tokens')
).filter(
APIUsageLog.user_id == user_id,
APIUsageLog.billing_period.in_(periods)
).group_by(APIUsageLog.billing_period, APIUsageLog.provider).all()
# Organize log stats by period -> provider
log_data_by_period = {}
for period, provider_enum, calls, cost, tokens in log_stats:
if period not in log_data_by_period:
log_data_by_period[period] = {}
# Handle provider enum or string
provider_name = provider_enum.value if hasattr(provider_enum, 'value') else str(provider_enum).lower()
# Normalize provider names (e.g. 'GEMINI' -> 'gemini')
if '.' in provider_name:
provider_name = provider_name.split('.')[-1].lower()
if provider_name not in log_data_by_period[period]:
log_data_by_period[period][provider_name] = {'calls': 0, 'cost': 0.0, 'tokens': 0}
log_data_by_period[period][provider_name]['calls'] += (calls or 0)
log_data_by_period[period][provider_name]['cost'] += float(cost or 0.0)
log_data_by_period[period][provider_name]['tokens'] += (tokens or 0)
# 3. Update/Create Summaries based on logs
for period in periods:
period_logs = log_data_by_period.get(period, {})
summary = summary_dict.get(period)
# If no summary exists but logs do, create one
if not summary and period_logs:
logger.info(f"[UsageStats] Self-healing: Creating missing summary for {period}")
summary = UsageSummary(
user_id=user_id,
billing_period=period,
usage_status=UsageStatus.ACTIVE,
total_calls=0,
total_cost=0.0,
total_tokens=0
)
self.db.add(summary)
summary_dict[period] = summary
if summary and period_logs:
total_calls_calc = 0
total_cost_calc = 0.0
total_tokens_calc = 0
for prov, data in period_logs.items():
total_calls_calc += data['calls']
total_cost_calc += data['cost']
total_tokens_calc += data['tokens']
# Update provider specific fields if logs > summary
calls_attr = f"{prov}_calls"
cost_attr = f"{prov}_cost"
tokens_attr = f"{prov}_tokens"
if hasattr(summary, calls_attr):
current_val = getattr(summary, calls_attr, 0)
if current_val < data['calls']:
setattr(summary, calls_attr, data['calls'])
if hasattr(summary, cost_attr):
current_val = getattr(summary, cost_attr, 0.0)
# Use significant difference to avoid float noise
if (data['cost'] - current_val) > 0.000001:
setattr(summary, cost_attr, data['cost'])
if hasattr(summary, tokens_attr):
current_val = getattr(summary, tokens_attr, 0)
if current_val < data['tokens']:
setattr(summary, tokens_attr, data['tokens'])
# Update totals if under-reported
if (summary.total_cost or 0.0) < total_cost_calc:
logger.info(f"[UsageStats] Self-healing cost for {period}: {summary.total_cost} -> {total_cost_calc}")
summary.total_cost = total_cost_calc
if (summary.total_calls or 0) < total_calls_calc:
summary.total_calls = total_calls_calc
if (summary.total_tokens or 0) < total_tokens_calc:
summary.total_tokens = total_tokens_calc
self.db.commit()
except Exception as e:
logger.error(f"Failed to self-heal usage trends: {e}")
self.db.rollback()
# 4. Construct Return Data
trends = {
'periods': periods,
'total_calls': [],
'total_cost': [],
'total_tokens': [],
'provider_trends': {}
}
# Initialize provider trends structure
for provider in APIProvider:
provider_name = provider.value
trends['provider_trends'][provider_name] = {
'calls': [],
'cost': [],
'tokens': []
}
for period in periods:
summary = summary_dict.get(period)
if summary:
trends['total_calls'].append(summary.total_calls or 0)
trends['total_cost'].append(summary.total_cost or 0.0)
trends['total_tokens'].append(summary.total_tokens or 0)
# Provider-specific trends
for provider in APIProvider:
provider_name = provider.value
trends['provider_trends'][provider_name]['calls'].append(
getattr(summary, f"{provider_name}_calls", 0) or 0
)
trends['provider_trends'][provider_name]['cost'].append(
getattr(summary, f"{provider_name}_cost", 0.0) or 0.0
)
trends['provider_trends'][provider_name]['tokens'].append(
getattr(summary, f"{provider_name}_tokens", 0) or 0
)
else:
# No data for this period
trends['total_calls'].append(0)
trends['total_cost'].append(0.0)
trends['total_tokens'].append(0)
for provider in APIProvider:
provider_name = provider.value
trends['provider_trends'][provider_name]['calls'].append(0)
trends['provider_trends'][provider_name]['cost'].append(0.0)
trends['provider_trends'][provider_name]['tokens'].append(0)
return trends
periods = build_billing_periods(months)
summary_dict = query_usage_summaries(self.db, user_id, periods)
self_heal_summaries_from_logs(self.db, user_id, periods, summary_dict)
return build_usage_trends_response(periods, summary_dict)
async def enforce_usage_limits(self, user_id: str, provider: APIProvider,
tokens_requested: int = 0) -> Tuple[bool, str, Dict[str, Any]]:
@@ -890,70 +530,11 @@ class UsageTrackingService:
).first()
if not summary:
# Nothing to reset
return {"reset": False, "reason": "no_summary"}
# CRITICAL: Reset ALL usage counters to 0 so user gets fresh limits with new/renewed plan
# Clear LIMIT_REACHED status
summary.usage_status = UsageStatus.ACTIVE
# Reset all LLM provider call counters
summary.gemini_calls = 0
summary.openai_calls = 0
summary.anthropic_calls = 0
summary.mistral_calls = 0
summary.wavespeed_calls = 0
# Reset all LLM provider token counters
summary.gemini_tokens = 0
summary.openai_tokens = 0
summary.anthropic_tokens = 0
summary.mistral_tokens = 0
summary.wavespeed_tokens = 0
# Reset search/research provider counters
summary.tavily_calls = 0
summary.serper_calls = 0
summary.metaphor_calls = 0
summary.firecrawl_calls = 0
# Reset image generation counters
summary.stability_calls = 0
summary.exa_calls = 0
# Reset video generation counters
summary.video_calls = 0
# Reset audio generation counters
summary.audio_calls = 0
# Reset image editing counters
summary.image_edit_calls = 0
# Reset cost counters
summary.gemini_cost = 0.0
summary.openai_cost = 0.0
summary.anthropic_cost = 0.0
summary.mistral_cost = 0.0
summary.wavespeed_cost = 0.0
summary.tavily_cost = 0.0
summary.serper_cost = 0.0
summary.metaphor_cost = 0.0
summary.firecrawl_cost = 0.0
summary.stability_cost = 0.0
summary.exa_cost = 0.0
summary.video_cost = 0.0
summary.image_edit_cost = 0.0
summary.audio_cost = 0.0
# Reset totals
summary.total_calls = 0
summary.total_tokens = 0
summary.total_cost = 0.0
summary.updated_at = datetime.utcnow()
reset_usage_summary_counters(summary)
self.db.commit()
logger.info(f"Reset usage counters for user {user_id} in billing period {billing_period} after renewal")
return {"reset": True, "counters_reset": True}
except Exception as e: