- Fix the import path in subscriptions.py (pricing_service location)
- Add values_callable to the APIUsageLog.provider enum column
- Normalize provider values to lowercase in the usage-trends helpers
- Add a migration script for existing databases
144 lines
4.4 KiB
Python
144 lines
4.4 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
Migration script to normalize APIProvider enum values to lowercase.
|
|
|
|
This fixes the issue where the database has uppercase values like "VIDEO", "MISTRAL"
|
|
but the enum expects lowercase values like "video", "mistral".
|
|
|
|
Run this script once to migrate existing data.
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Add backend to path: backend_dir is the directory containing this script.
# It is used below to locate the .env file and by main() to find workspace/.
# NOTE(review): nothing visible in this script imports a backend module, so the
# sys.path insertion may be unnecessary — confirm before removing.
backend_dir = Path(__file__).parent
sys.path.insert(0, str(backend_dir))

# Load env first: variables from <backend_dir>/.env are loaded before the loguru
# import that follows, presumably so environment-driven settings are already
# present when that module initializes. Keep this statement order.
from dotenv import load_dotenv
load_dotenv(backend_dir / '.env')
|
|
|
|
from loguru import logger
|
|
import sqlite3
|
|
from glob import glob
|
|
|
|
# Legacy upper-case provider spellings found in old rows, mapped to the
# lower-case values the APIProvider enum expects. Every target is simply the
# lower-cased key, so the table is generated from the list of names.
PROVIDER_MAP = {
    legacy_name: legacy_name.lower()
    for legacy_name in (
        'GEMINI',
        'OPENAI',
        'ANTHROPIC',
        'MISTRAL',
        'WAVESPEED',
        'TAVILY',
        'SERPER',
        'METAPHOR',
        'FIRECRAWL',
        'STABILITY',
        'EXA',
        'VIDEO',
        'IMAGE_EDIT',
        'AUDIO',
    )
}
|
|
|
|
def normalize_provider_value(value: str) -> str:
    """Map a known provider spelling to its lower-case form.

    The lookup against ``PROVIDER_MAP`` is case-insensitive, so 'VIDEO',
    'Video' and 'video' all yield 'video'. Anything not listed in
    ``PROVIDER_MAP`` is returned unchanged whatever its case (an unknown
    'FOO' stays 'FOO'). Falsy inputs such as ``None`` or ``""`` are also
    returned untouched.
    """
    # Guard falsy values first so .upper() is never called on None.
    return PROVIDER_MAP.get(value.upper(), value) if value else value
|
|
|
|
def migrate_database(db_path: str) -> tuple[int, int]:
    """Rewrite legacy provider spellings in one SQLite database file.

    If ``db_path`` contains an ``api_usage_logs`` table, every distinct
    ``provider`` value that ``normalize_provider_value`` maps to a different
    string is updated in place. All updates are committed together at the end.

    Args:
        db_path: Filesystem path of the SQLite database to migrate.

    Returns:
        ``(total_rows, updated_rows)``: the number of rows in
        ``api_usage_logs`` (counted before updating) and the number of rows
        whose provider was rewritten. ``(0, 0)`` is returned when the table
        does not exist, and also when any error occurs. Errors are logged
        rather than raised, so one bad file does not abort the whole run.
    """
    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Check if api_usage_logs table exists; files without it are skipped.
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='api_usage_logs'")
        if not cursor.fetchone():
            return 0, 0

        # Get all unique provider values
        cursor.execute("SELECT DISTINCT provider FROM api_usage_logs")
        unique_providers = [row[0] for row in cursor.fetchall()]

        # Count total rows
        cursor.execute("SELECT COUNT(*) FROM api_usage_logs")
        total_rows = cursor.fetchone()[0]

        # Update each provider value that needs normalization
        updated_rows = 0
        for provider in unique_providers:
            normalized = normalize_provider_value(provider)
            if provider == normalized:
                continue
            cursor.execute(
                "UPDATE api_usage_logs SET provider = ? WHERE provider = ?",
                (normalized, provider),
            )
            count = cursor.rowcount
            updated_rows += count
            logger.info(f" {db_path}: Updated {count} rows with provider '{provider}' -> '{normalized}'")

        conn.commit()
        return total_rows, updated_rows

    except Exception as e:
        # Deliberate best-effort boundary: log and report nothing migrated.
        # The connection is closed below without a commit, so any UPDATEs
        # already issued are discarded and this file is left unchanged.
        logger.error(f"Error migrating {db_path}: {e}")
        return 0, 0

    finally:
        # Close on every path. Previously an exception skipped close() and
        # leaked the handle, keeping the file open until garbage collection.
        if conn is not None:
            conn.close()
|
|
|
|
def main():
    """Normalize provider values in every workspace database and log a summary.

    Each subdirectory of ``backend_dir / "workspace"`` is searched recursively
    for ``*.db`` files. Files placed directly in ``workspace/`` itself are not
    considered. Every file found is passed to ``migrate_database`` in sorted
    order, and the per-file row counts are totalled and logged.
    """
    logger.info("=" * 60)
    logger.info("APIProvider Enum Normalization Migration")
    logger.info("=" * 60)

    # Workspace directory where user databases are stored
    workspace_dir = backend_dir / "workspace"

    # Find all database files in workspaces. "**/*.db" matches at any depth
    # (the workspace root, database/, ...). It already covers alwrity*.db, so
    # one recursive walk per workspace is enough.
    all_dbs = set()
    if workspace_dir.exists():
        for workspace in workspace_dir.iterdir():
            if workspace.is_dir():
                all_dbs.update(workspace.glob("**/*.db"))
    else:
        # Surface a wrong or missing location. Otherwise the run would end
        # with "No updates needed" and look like a success.
        logger.warning(f"Workspace directory not found: {workspace_dir}")

    logger.info(f"Found {len(all_dbs)} database files to check")

    total_dbs = len(all_dbs)
    total_rows = 0
    total_updated = 0
    for db_path in sorted(all_dbs):
        rows, updated = migrate_database(str(db_path))
        total_rows += rows
        total_updated += updated

    logger.info("=" * 60)
    logger.info("Migration complete!")
    logger.info(f" Databases checked: {total_dbs}")
    logger.info(f" Total rows: {total_rows}")
    logger.info(f" Rows updated: {total_updated}")
    logger.info("=" * 60)

    # NOTE(review): migrate_database returns (0, 0) both for files without an
    # api_usage_logs table and for files it failed to migrate, so failures
    # are not distinguishable in these totals.
    if total_updated > 0:
        logger.warning("Please restart the backend server to ensure changes take effect.")
    else:
        logger.info("No updates needed - all provider values are already normalized.")


if __name__ == "__main__":
    main()
|