fix: Resolve APIProvider enum mismatch causing dashboard errors

- Fix import path in subscriptions.py (pricing_service location)
- Add values_callable to APIUsageLog.provider enum column
- Normalize provider values to lowercase in usage trends helpers
- Add migration script for existing databases
This commit is contained in:
ajaysi
2026-03-29 12:50:50 +05:30
parent d6ad903e3d
commit 557f700f68
4 changed files with 176 additions and 9 deletions

View File

@@ -170,7 +170,7 @@ async def get_subscription_status(
if getattr(subscription, 'auto_renew', False):
# advance period
try:
from services.pricing_service import PricingService
from services.subscription.pricing_service import PricingService
pricing = PricingService(db)
# reuse helper to ensure current
pricing._ensure_subscription_current(subscription)
@@ -245,7 +245,7 @@ async def get_subscription_status(
if subscription.current_period_end < now:
if getattr(subscription, 'auto_renew', False):
try:
from services.pricing_service import PricingService
from services.subscription.pricing_service import PricingService
pricing = PricingService(db)
pricing._ensure_subscription_current(subscription)
except Exception as e2:

View File

@@ -155,7 +155,7 @@ class APIUsageLog(Base):
user_id = Column(String(100), nullable=False)
# API Details
provider = Column(Enum(APIProvider), nullable=False)
provider = Column(Enum(APIProvider, values_callable=lambda obj: [e.value for e in obj]), nullable=False)
endpoint = Column(String(200), nullable=False)
method = Column(String(10), nullable=False)
model_used = Column(String(100), nullable=True) # e.g., "gemini-2.5-flash"

View File

@@ -0,0 +1,143 @@
#!/usr/bin/env python3
"""
Migration script to normalize APIProvider enum values to lowercase.
This fixes the issue where the database has uppercase values like "VIDEO", "MISTRAL"
but the enum expects lowercase values like "video", "mistral".
Run this script once to migrate existing data.
"""
import os
import sys
from pathlib import Path
# Add backend to path
backend_dir = Path(__file__).parent
sys.path.insert(0, str(backend_dir))
# Load env first
from dotenv import load_dotenv
load_dotenv(backend_dir / '.env')
from loguru import logger
import sqlite3
from glob import glob
# Provider mapping: uppercase -> lowercase
PROVIDER_MAP = {
'GEMINI': 'gemini',
'OPENAI': 'openai',
'ANTHROPIC': 'anthropic',
'MISTRAL': 'mistral',
'WAVESPEED': 'wavespeed',
'TAVILY': 'tavily',
'SERPER': 'serper',
'METAPHOR': 'metaphor',
'FIRECRAWL': 'firecrawl',
'STABILITY': 'stability',
'EXA': 'exa',
'VIDEO': 'video',
'IMAGE_EDIT': 'image_edit',
'AUDIO': 'audio',
}
def normalize_provider_value(value: str) -> str:
"""Convert provider value to lowercase if it's uppercase."""
if not value:
return value
upper_value = value.upper()
if upper_value in PROVIDER_MAP:
return PROVIDER_MAP[upper_value]
# If already lowercase, return as-is
return value
def migrate_database(db_path: str) -> tuple[int, int]:
"""Migrate a single database file. Returns (total_rows, updated_rows)."""
try:
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Check if api_usage_logs table exists
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='api_usage_logs'")
if not cursor.fetchone():
conn.close()
return 0, 0
# Get all unique provider values
cursor.execute("SELECT DISTINCT provider FROM api_usage_logs")
unique_providers = [row[0] for row in cursor.fetchall()]
total_rows = 0
updated_rows = 0
# Count total rows
cursor.execute("SELECT COUNT(*) FROM api_usage_logs")
total_rows = cursor.fetchone()[0]
# Update each provider value that needs normalization
for provider in unique_providers:
normalized = normalize_provider_value(provider)
if provider != normalized:
cursor.execute(
"UPDATE api_usage_logs SET provider = ? WHERE provider = ?",
(normalized, provider)
)
count = cursor.rowcount
updated_rows += count
logger.info(f" {db_path}: Updated {count} rows with provider '{provider}' -> '{normalized}'")
conn.commit()
conn.close()
return total_rows, updated_rows
except Exception as e:
logger.error(f"Error migrating {db_path}: {e}")
return 0, 0
def main():
"""Main migration function."""
logger.info("=" * 60)
logger.info("APIProvider Enum Normalization Migration")
logger.info("=" * 60)
# Workspace directory where user databases are stored
workspace_dir = backend_dir / "workspace"
# Find all database files in workspaces
all_dbs = set()
if workspace_dir.exists():
for workspace in workspace_dir.iterdir():
if workspace.is_dir():
# Look for database files in workspace subdirectories (database/, or root)
for db_pattern in ["**/*.db", "**/alwrity*.db"]:
all_dbs.update(workspace.glob(db_pattern))
logger.info(f"Found {len(all_dbs)} database files to check")
total_dbs = 0
total_rows = 0
total_updated = 0
for db_path in sorted(all_dbs):
total_dbs += 1
rows, updated = migrate_database(str(db_path))
total_rows += rows
total_updated += updated
logger.info("=" * 60)
logger.info(f"Migration complete!")
logger.info(f" Databases checked: {total_dbs}")
logger.info(f" Total rows: {total_rows}")
logger.info(f" Rows updated: {total_updated}")
logger.info("=" * 60)
if total_updated > 0:
logger.warning("Please restart the backend server to ensure changes take effect.")
else:
logger.info("No updates needed - all provider values are already normalized.")
if __name__ == "__main__":
main()

View File

@@ -11,6 +11,26 @@ from sqlalchemy import func
from models.subscription_models import APIProvider, APIUsageLog, UsageStatus, UsageSummary
def _normalize_provider_name(provider_input: Any) -> str | None:
"""Safely extract provider name from enum or string, handling both name and value formats."""
valid_providers = {'gemini', 'openai', 'anthropic', 'mistral', 'wavespeed',
'tavily', 'serper', 'metaphor', 'firecrawl', 'stability',
'exa', 'video', 'image_edit', 'audio'}
try:
if hasattr(provider_input, "value"):
return provider_input.value
elif isinstance(provider_input, str):
name = provider_input.lower()
if "." in name:
name = name.split(".")[-1].lower()
return name
else:
return str(provider_input).lower()
except Exception:
return None
def build_billing_periods(months: int) -> List[str]:
"""Build billing period keys (YYYY-MM) from oldest to newest."""
end_date = datetime.now()
@@ -35,27 +55,31 @@ def query_usage_summaries(db: Any, user_id: str, periods: List[str]) -> Dict[str
def self_heal_summaries_from_logs(db: Any, user_id: str, periods: List[str], summary_dict: Dict[str, Any]) -> None:
"""Backfill/create usage summaries from aggregated API usage logs."""
try:
from sqlalchemy import cast, String
log_stats = (
db.query(
APIUsageLog.billing_period,
APIUsageLog.provider,
cast(APIUsageLog.provider, String).label("provider"),
func.count(APIUsageLog.id).label("calls"),
func.sum(APIUsageLog.cost_total).label("cost"),
func.sum(APIUsageLog.tokens_total).label("tokens"),
)
.filter(APIUsageLog.user_id == user_id, APIUsageLog.billing_period.in_(periods))
.group_by(APIUsageLog.billing_period, APIUsageLog.provider)
.group_by(APIUsageLog.billing_period, cast(APIUsageLog.provider, String))
.all()
)
log_data_by_period: Dict[str, Dict[str, Dict[str, float | int]]] = {}
for period, provider_enum, calls, cost, tokens in log_stats:
for period, provider_str, calls, cost, tokens in log_stats:
if period not in log_data_by_period:
log_data_by_period[period] = {}
provider_name = provider_enum.value if hasattr(provider_enum, "value") else str(provider_enum).lower()
if "." in provider_name:
provider_name = provider_name.split(".")[-1].lower()
provider_name = _normalize_provider_name(provider_str)
if not provider_name:
logger.warning(f"[UsageStats] Could not normalize provider: '{provider_str}', skipping")
continue
if provider_name not in log_data_by_period[period]:
log_data_by_period[period][provider_name] = {"calls": 0, "cost": 0.0, "tokens": 0}