allow creating and listing custom language model (#134)

This commit is contained in:
Will Chen
2025-05-12 16:00:16 -07:00
committed by GitHub
parent c63781d7cc
commit 477015b43d
13 changed files with 925 additions and 291 deletions

View File

@@ -1,22 +1,26 @@
import type { LanguageModelProvider } from "@/ipc/ipc_types";
import type {
LanguageModelProvider,
LanguageModel,
CreateCustomLanguageModelProviderParams,
CreateCustomLanguageModelParams,
} from "@/ipc/ipc_types";
import { createLoggedHandler } from "./safe_handle";
import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
import {
getLanguageModelProviders,
getLanguageModels,
} from "../shared/language_model_helpers";
import { db } from "@/db";
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
import {
language_model_providers as languageModelProvidersSchema,
language_models as languageModelsSchema,
} from "@/db/schema";
import { eq } from "drizzle-orm";
import { IpcMainInvokeEvent } from "electron";
const logger = log.scope("language_model_handlers");
const handle = createLoggedHandler(logger);
export interface CreateCustomLanguageModelProviderParams {
id: string;
name: string;
apiBaseUrl: string;
envVarName?: string;
}
export function registerLanguageModelHandlers() {
handle(
"get-language-model-providers",
@@ -47,7 +51,7 @@ export function registerLanguageModelHandlers() {
}
// Check if a provider with this ID already exists
const existingProvider = await db
const existingProvider = db
.select()
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, id))
@@ -75,4 +79,77 @@ export function registerLanguageModelHandlers() {
};
},
);
handle(
"create-custom-language-model",
async (
event: IpcMainInvokeEvent,
params: CreateCustomLanguageModelParams,
): Promise<void> => {
const {
id,
name,
providerId,
description,
maxOutputTokens,
contextWindow,
} = params;
// Validation
if (!id) {
throw new Error("Model ID is required");
}
if (!name) {
throw new Error("Model name is required");
}
if (!providerId) {
throw new Error("Provider ID is required");
}
// Check if provider exists
const provider = db
.select()
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, providerId))
.get();
if (!provider) {
throw new Error(`Provider with ID "${providerId}" not found`);
}
// Check if model ID already exists
const existingModel = db
.select()
.from(languageModelsSchema)
.where(eq(languageModelsSchema.id, id))
.get();
if (existingModel) {
throw new Error(`A model with ID "${id}" already exists`);
}
// Insert the new model
await db.insert(languageModelsSchema).values({
id,
name,
provider_id: providerId,
description: description || null,
max_output_tokens: maxOutputTokens || null,
context_window: contextWindow || null,
});
},
);
handle(
"get-language-models",
async (
event: IpcMainInvokeEvent,
params: { providerId: string },
): Promise<LanguageModel[]> => {
if (!params || typeof params.providerId !== "string") {
throw new Error("Invalid parameters: providerId (string) is required.");
}
return getLanguageModels({ providerId: params.providerId });
},
);
}

View File

@@ -22,6 +22,9 @@ import type {
ChatLogsData,
BranchResult,
LanguageModelProvider,
LanguageModel,
CreateCustomLanguageModelProviderParams,
CreateCustomLanguageModelParams,
} from "./ipc_types";
import type { ProposalResult } from "@/lib/schemas";
import { showError } from "@/lib/toast";
@@ -732,17 +735,18 @@ export class IpcClient {
return this.ipcRenderer.invoke("get-language-model-providers");
}
public async getLanguageModels(params: {
providerId: string;
}): Promise<LanguageModel[]> {
return this.ipcRenderer.invoke("get-language-models", params);
}
public async createCustomLanguageModelProvider({
id,
name,
apiBaseUrl,
envVarName,
}: {
id: string;
name: string;
apiBaseUrl: string;
envVarName?: string;
}): Promise<LanguageModelProvider> {
}: CreateCustomLanguageModelProviderParams): Promise<LanguageModelProvider> {
return this.ipcRenderer.invoke("create-custom-language-model-provider", {
id,
name,
@@ -751,5 +755,11 @@ export class IpcClient {
});
}
public async createCustomLanguageModel(
params: CreateCustomLanguageModelParams,
): Promise<void> {
await this.ipcRenderer.invoke("create-custom-language-model", params);
}
// --- End window control methods ---
}

View File

@@ -143,3 +143,30 @@ export interface LanguageModelProvider {
apiBaseUrl?: string;
type: "custom" | "local" | "cloud";
}
export interface LanguageModel {
id: string;
name: string;
displayName: string;
description: string;
tag?: string;
maxOutputTokens?: number;
contextWindow?: number;
type: "local" | "cloud" | "custom";
}
export interface CreateCustomLanguageModelProviderParams {
id: string;
name: string;
apiBaseUrl: string;
envVarName?: string;
}
export interface CreateCustomLanguageModelParams {
id: string;
name: string;
providerId: string;
description?: string;
maxOutputTokens?: number;
contextWindow?: number;
}

View File

@@ -1,7 +1,11 @@
import { db } from "@/db";
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
import { RegularModelProvider } from "@/constants/models";
import type { LanguageModelProvider } from "@/ipc/ipc_types";
import {
language_model_providers as languageModelProvidersSchema,
language_models as languageModelsSchema,
} from "@/db/schema";
import { MODEL_OPTIONS, RegularModelProvider } from "@/constants/models";
import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types";
import { eq } from "drizzle-orm";
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
openai: "OPENAI_API_KEY",
@@ -129,3 +133,80 @@ export async function getLanguageModelProviders(): Promise<
return Array.from(mergedProvidersMap.values());
}
/**
* Fetches language models for a specific provider.
* @param obj An object containing the providerId.
* @returns A promise that resolves to an array of LanguageModel objects.
*/
export async function getLanguageModels(obj: {
providerId: string;
}): Promise<LanguageModel[]> {
const { providerId } = obj;
const allProviders = await getLanguageModelProviders();
const provider = allProviders.find((p) => p.id === providerId);
if (!provider) {
console.warn(`Provider with ID "${providerId}" not found.`);
return [];
}
if (provider.type === "cloud") {
// Check if providerId is a valid key for MODEL_OPTIONS
if (providerId in MODEL_OPTIONS) {
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
return models.map((model) => ({
...model,
id: model.name,
type: "cloud",
}));
} else {
console.warn(
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
);
return [];
}
} else if (provider.type === "custom") {
// Fetch models from the database for this custom provider
// Assuming a language_models table with necessary columns and provider_id foreign key
try {
const customModelsDb = await db
.select({
id: languageModelsSchema.id,
// Map DB columns to LanguageModel fields
name: languageModelsSchema.name,
// No display_name in DB, use name instead
description: languageModelsSchema.description,
// No tag in DB
maxOutputTokens: languageModelsSchema.max_output_tokens,
contextWindow: languageModelsSchema.context_window,
})
.from(languageModelsSchema)
.where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available
return customModelsDb.map((model) => ({
...model,
displayName: model.name, // Use name as displayName for custom models
// Ensure possibly null fields are handled, provide defaults or undefined if needed
description: model.description ?? "",
tag: undefined, // No tag for custom models from DB
maxOutputTokens: model.maxOutputTokens ?? undefined,
contextWindow: model.contextWindow ?? undefined,
type: "custom",
}));
} catch (error) {
console.error(
`Error fetching custom models for provider "${providerId}" from DB:`,
error,
);
// Depending on desired behavior, could throw, return empty, or return a specific error state
return [];
}
} else {
// Handle other types like "local" if necessary, currently ignored
console.warn(
`Provider type "${provider.type}" not handled for model fetching.`,
);
return [];
}
}