Fix DB schemas (#138)

This commit is contained in:
Will Chen
2025-05-12 21:51:08 -07:00
committed by GitHub
parent f5a6a1abca
commit 993c5417e3
10 changed files with 273 additions and 128 deletions

View File

@@ -9,6 +9,7 @@ import log from "electron-log";
import {
getLanguageModelProviders,
getLanguageModels,
getLanguageModelsByProviders,
} from "../shared/language_model_helpers";
import { db } from "@/db";
import {
@@ -64,7 +65,8 @@ export function registerLanguageModelHandlers() {
// Insert the new provider
await db.insert(languageModelProvidersSchema).values({
id,
// Make sure we will never have accidental collisions with builtin providers
id: "custom::" + id,
name,
api_base_url: apiBaseUrl,
env_var_name: envVarName || null,
@@ -108,12 +110,8 @@ export function registerLanguageModelHandlers() {
}
// Check if provider exists
const provider = db
.select()
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, providerId))
.get();
const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === providerId);
if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`);
}
@@ -122,7 +120,8 @@ export function registerLanguageModelHandlers() {
await db.insert(languageModelsSchema).values({
displayName,
apiName,
provider_id: providerId,
builtinProviderId: provider.type === "cloud" ? providerId : undefined,
customProviderId: provider.type === "custom" ? providerId : undefined,
description: description || null,
max_output_tokens: maxOutputTokens || null,
context_window: contextWindow || null,
@@ -182,11 +181,22 @@ export function registerLanguageModelHandlers() {
`Attempting to delete custom model ${modelApiName} for provider ${providerId}`,
);
const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === providerId);
if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`);
}
if (provider.type === "local") {
throw new Error("Local models cannot be deleted");
}
const result = db
.delete(language_models)
.where(
and(
eq(language_models.provider_id, providerId),
provider.type === "cloud"
? eq(language_models.builtinProviderId, providerId)
: eq(language_models.customProviderId, providerId),
eq(language_models.apiName, modelApiName),
),
)
@@ -243,7 +253,7 @@ export function registerLanguageModelHandlers() {
// 1. Delete associated models
const deleteModelsResult = await tx
.delete(languageModelsSchema)
.where(eq(languageModelsSchema.provider_id, providerId))
.where(eq(languageModelsSchema.customProviderId, providerId))
.run();
logger.info(
`Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`,
@@ -279,7 +289,26 @@ export function registerLanguageModelHandlers() {
if (!params || typeof params.providerId !== "string") {
throw new Error("Invalid parameters: providerId (string) is required.");
}
return getLanguageModels({ providerId: params.providerId });
const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === params.providerId);
if (!provider) {
throw new Error(`Provider with ID "${params.providerId}" not found`);
}
if (provider.type === "local") {
throw new Error("Local models cannot be fetched");
}
return getLanguageModels(
provider.type === "cloud"
? { builtinProviderId: params.providerId }
: { customProviderId: params.providerId },
);
},
);
handle(
"get-language-models-by-providers",
async (): Promise<Record<string, LanguageModel[]>> => {
return getLanguageModelsByProviders();
},
);
}