Fix DB schemas (#138)

This commit is contained in:
Will Chen
2025-05-12 21:51:08 -07:00
committed by GitHub
parent f5a6a1abca
commit 993c5417e3
10 changed files with 273 additions and 128 deletions

View File

@@ -9,6 +9,7 @@ import log from "electron-log";
import {
getLanguageModelProviders,
getLanguageModels,
getLanguageModelsByProviders,
} from "../shared/language_model_helpers";
import { db } from "@/db";
import {
@@ -64,7 +65,8 @@ export function registerLanguageModelHandlers() {
// Insert the new provider
await db.insert(languageModelProvidersSchema).values({
id,
// Make sure we will never have accidental collisions with builtin providers
id: "custom::" + id,
name,
api_base_url: apiBaseUrl,
env_var_name: envVarName || null,
@@ -108,12 +110,8 @@ export function registerLanguageModelHandlers() {
}
// Check if provider exists
const provider = db
.select()
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, providerId))
.get();
const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === providerId);
if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`);
}
@@ -122,7 +120,8 @@ export function registerLanguageModelHandlers() {
await db.insert(languageModelsSchema).values({
displayName,
apiName,
provider_id: providerId,
builtinProviderId: provider.type === "cloud" ? providerId : undefined,
customProviderId: provider.type === "custom" ? providerId : undefined,
description: description || null,
max_output_tokens: maxOutputTokens || null,
context_window: contextWindow || null,
@@ -182,11 +181,22 @@ export function registerLanguageModelHandlers() {
`Attempting to delete custom model ${modelApiName} for provider ${providerId}`,
);
const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === providerId);
if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`);
}
if (provider.type === "local") {
throw new Error("Local models cannot be deleted");
}
const result = db
.delete(language_models)
.where(
and(
eq(language_models.provider_id, providerId),
provider.type === "cloud"
? eq(language_models.builtinProviderId, providerId)
: eq(language_models.customProviderId, providerId),
eq(language_models.apiName, modelApiName),
),
)
@@ -243,7 +253,7 @@ export function registerLanguageModelHandlers() {
// 1. Delete associated models
const deleteModelsResult = await tx
.delete(languageModelsSchema)
.where(eq(languageModelsSchema.provider_id, providerId))
.where(eq(languageModelsSchema.customProviderId, providerId))
.run();
logger.info(
`Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`,
@@ -279,7 +289,26 @@ export function registerLanguageModelHandlers() {
if (!params || typeof params.providerId !== "string") {
throw new Error("Invalid parameters: providerId (string) is required.");
}
return getLanguageModels({ providerId: params.providerId });
const providers = await getLanguageModelProviders();
const provider = providers.find((p) => p.id === params.providerId);
if (!provider) {
throw new Error(`Provider with ID "${params.providerId}" not found`);
}
if (provider.type === "local") {
throw new Error("Local models cannot be fetched");
}
return getLanguageModels(
provider.type === "cloud"
? { builtinProviderId: params.providerId }
: { customProviderId: params.providerId },
);
},
);
handle(
"get-language-models-by-providers",
async (): Promise<Record<string, LanguageModel[]>> => {
return getLanguageModelsByProviders();
},
);
}

View File

@@ -746,6 +746,12 @@ export class IpcClient {
return this.ipcRenderer.invoke("get-language-models", params);
}
public async getLanguageModelsByProviders(): Promise<
Record<string, LanguageModel[]>
> {
return this.ipcRenderer.invoke("get-language-models-by-providers");
}
public async createCustomLanguageModelProvider({
id,
name,

View File

@@ -139,23 +139,72 @@ export async function getLanguageModelProviders(): Promise<
* @param obj An object containing the providerId.
* @returns A promise that resolves to an array of LanguageModel objects.
*/
export async function getLanguageModels(obj: {
providerId: string;
}): Promise<LanguageModel[]> {
const { providerId } = obj;
export async function getLanguageModels(
obj:
| {
customProviderId: string;
// builtinProviderId?: undefined;
}
| {
builtinProviderId: string;
// customProviderId?: undefined;
},
): Promise<LanguageModel[]> {
const allProviders = await getLanguageModelProviders();
const provider = allProviders.find((p) => p.id === providerId);
const provider = allProviders.find(
(p) =>
p.id === (obj as { customProviderId: string }).customProviderId ||
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
);
if (!provider) {
console.warn(`Provider with ID "${providerId}" not found.`);
console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
return [];
}
// Get custom models from DB for all provider types
let customModels: LanguageModel[] = [];
try {
const customModelsDb = await db
.select({
id: languageModelsSchema.id,
displayName: languageModelsSchema.displayName,
apiName: languageModelsSchema.apiName,
description: languageModelsSchema.description,
maxOutputTokens: languageModelsSchema.max_output_tokens,
contextWindow: languageModelsSchema.context_window,
})
.from(languageModelsSchema)
.where(
"customProviderId" in obj
? eq(languageModelsSchema.customProviderId, obj.customProviderId)
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
);
customModels = customModelsDb.map((model) => ({
...model,
description: model.description ?? "",
tag: undefined,
maxOutputTokens: model.maxOutputTokens ?? undefined,
contextWindow: model.contextWindow ?? undefined,
type: "custom",
}));
} catch (error) {
console.error(
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
error,
);
// Continue with empty custom models array
}
// If it's a cloud provider, also get the hardcoded models
let hardcodedModels: LanguageModel[] = [];
const providerId = provider.id;
if (provider.type === "cloud") {
// Check if providerId is a valid key for MODEL_OPTIONS
if (providerId in MODEL_OPTIONS) {
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
return models.map((model) => ({
hardcodedModels = models.map((model) => ({
...model,
apiName: model.name,
type: "cloud",
@@ -164,49 +213,54 @@ export async function getLanguageModels(obj: {
console.warn(
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
);
return [];
}
} else if (provider.type === "custom") {
// Fetch models from the database for this custom provider
// Assuming a language_models table with necessary columns and provider_id foreign key
try {
const customModelsDb = await db
.select({
id: languageModelsSchema.id,
// Map DB columns to LanguageModel fields
displayName: languageModelsSchema.displayName,
apiName: languageModelsSchema.apiName,
// No display_name in DB, use name instead
description: languageModelsSchema.description,
// No tag in DB
maxOutputTokens: languageModelsSchema.max_output_tokens,
contextWindow: languageModelsSchema.context_window,
})
.from(languageModelsSchema)
.where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available
return customModelsDb.map((model) => ({
...model,
// Ensure possibly null fields are handled, provide defaults or undefined if needed
description: model.description ?? "",
tag: undefined, // No tag for custom models from DB
maxOutputTokens: model.maxOutputTokens ?? undefined,
contextWindow: model.contextWindow ?? undefined,
type: "custom",
}));
} catch (error) {
console.error(
`Error fetching custom models for provider "${providerId}" from DB:`,
error,
);
// Depending on desired behavior, could throw, return empty, or return a specific error state
return [];
}
} else {
// Handle other types like "local" if necessary, currently ignored
console.warn(
`Provider type "${provider.type}" not handled for model fetching.`,
);
return [];
}
// Merge the models, with custom models taking precedence over hardcoded ones
const mergedModelsMap = new Map<string, LanguageModel>();
// Add hardcoded models first
for (const model of hardcodedModels) {
mergedModelsMap.set(model.apiName, model);
}
// Then override with custom models
for (const model of customModels) {
mergedModelsMap.set(model.apiName, model);
}
return Array.from(mergedModelsMap.values());
}
/**
* Fetches all language models grouped by their provider IDs.
* @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
*/
export async function getLanguageModelsByProviders(): Promise<
Record<string, LanguageModel[]>
> {
const providers = await getLanguageModelProviders();
// Fetch all models concurrently
const modelPromises = providers
.filter((p) => p.type !== "local")
.map(async (provider) => {
const models = await getLanguageModels(
provider.type === "cloud"
? { builtinProviderId: provider.id }
: { customProviderId: provider.id },
);
return { providerId: provider.id, models };
});
// Wait for all requests to complete
const results = await Promise.all(modelPromises);
// Convert the array of results to a record
const record: Record<string, LanguageModel[]> = {};
for (const result of results) {
record[result.providerId] = result.models;
}
return record;
}