allow creating and listing custom language model (#134)

This commit is contained in:
Will Chen
2025-05-12 16:00:16 -07:00
committed by GitHub
parent c63781d7cc
commit 477015b43d
13 changed files with 925 additions and 291 deletions

View File

@@ -1,22 +1,26 @@
import type { LanguageModelProvider } from "@/ipc/ipc_types";
import type {
LanguageModelProvider,
LanguageModel,
CreateCustomLanguageModelProviderParams,
CreateCustomLanguageModelParams,
} from "@/ipc/ipc_types";
import { createLoggedHandler } from "./safe_handle";
import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
import {
getLanguageModelProviders,
getLanguageModels,
} from "../shared/language_model_helpers";
import { db } from "@/db";
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
import {
language_model_providers as languageModelProvidersSchema,
language_models as languageModelsSchema,
} from "@/db/schema";
import { eq } from "drizzle-orm";
import { IpcMainInvokeEvent } from "electron";
const logger = log.scope("language_model_handlers");
const handle = createLoggedHandler(logger);
export interface CreateCustomLanguageModelProviderParams {
id: string;
name: string;
apiBaseUrl: string;
envVarName?: string;
}
export function registerLanguageModelHandlers() {
handle(
"get-language-model-providers",
@@ -47,7 +51,7 @@ export function registerLanguageModelHandlers() {
}
// Check if a provider with this ID already exists
const existingProvider = await db
const existingProvider = db
.select()
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, id))
@@ -75,4 +79,77 @@ export function registerLanguageModelHandlers() {
};
},
);
handle(
"create-custom-language-model",
async (
event: IpcMainInvokeEvent,
params: CreateCustomLanguageModelParams,
): Promise<void> => {
const {
id,
name,
providerId,
description,
maxOutputTokens,
contextWindow,
} = params;
// Validation
if (!id) {
throw new Error("Model ID is required");
}
if (!name) {
throw new Error("Model name is required");
}
if (!providerId) {
throw new Error("Provider ID is required");
}
// Check if provider exists
const provider = db
.select()
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, providerId))
.get();
if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`);
}
// Check if model ID already exists
const existingModel = db
.select()
.from(languageModelsSchema)
.where(eq(languageModelsSchema.id, id))
.get();
if (existingModel) {
throw new Error(`A model with ID "${id}" already exists`);
}
// Insert the new model
await db.insert(languageModelsSchema).values({
id,
name,
provider_id: providerId,
description: description || null,
max_output_tokens: maxOutputTokens || null,
context_window: contextWindow || null,
});
},
);
handle(
"get-language-models",
async (
event: IpcMainInvokeEvent,
params: { providerId: string },
): Promise<LanguageModel[]> => {
if (!params || typeof params.providerId !== "string") {
throw new Error("Invalid parameters: providerId (string) is required.");
}
return getLanguageModels({ providerId: params.providerId });
},
);
}