allow creating and listing custom language model (#134)
This commit is contained in:
@@ -1,22 +1,26 @@
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
import type {
|
||||
LanguageModelProvider,
|
||||
LanguageModel,
|
||||
CreateCustomLanguageModelProviderParams,
|
||||
CreateCustomLanguageModelParams,
|
||||
} from "@/ipc/ipc_types";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
import log from "electron-log";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
import {
|
||||
getLanguageModelProviders,
|
||||
getLanguageModels,
|
||||
} from "../shared/language_model_helpers";
|
||||
import { db } from "@/db";
|
||||
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
|
||||
import {
|
||||
language_model_providers as languageModelProvidersSchema,
|
||||
language_models as languageModelsSchema,
|
||||
} from "@/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { IpcMainInvokeEvent } from "electron";
|
||||
|
||||
const logger = log.scope("language_model_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
|
||||
export interface CreateCustomLanguageModelProviderParams {
|
||||
id: string;
|
||||
name: string;
|
||||
apiBaseUrl: string;
|
||||
envVarName?: string;
|
||||
}
|
||||
|
||||
export function registerLanguageModelHandlers() {
|
||||
handle(
|
||||
"get-language-model-providers",
|
||||
@@ -47,7 +51,7 @@ export function registerLanguageModelHandlers() {
|
||||
}
|
||||
|
||||
// Check if a provider with this ID already exists
|
||||
const existingProvider = await db
|
||||
const existingProvider = db
|
||||
.select()
|
||||
.from(languageModelProvidersSchema)
|
||||
.where(eq(languageModelProvidersSchema.id, id))
|
||||
@@ -75,4 +79,77 @@ export function registerLanguageModelHandlers() {
|
||||
};
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"create-custom-language-model",
|
||||
async (
|
||||
event: IpcMainInvokeEvent,
|
||||
params: CreateCustomLanguageModelParams,
|
||||
): Promise<void> => {
|
||||
const {
|
||||
id,
|
||||
name,
|
||||
providerId,
|
||||
description,
|
||||
maxOutputTokens,
|
||||
contextWindow,
|
||||
} = params;
|
||||
|
||||
// Validation
|
||||
if (!id) {
|
||||
throw new Error("Model ID is required");
|
||||
}
|
||||
if (!name) {
|
||||
throw new Error("Model name is required");
|
||||
}
|
||||
if (!providerId) {
|
||||
throw new Error("Provider ID is required");
|
||||
}
|
||||
|
||||
// Check if provider exists
|
||||
const provider = db
|
||||
.select()
|
||||
.from(languageModelProvidersSchema)
|
||||
.where(eq(languageModelProvidersSchema.id, providerId))
|
||||
.get();
|
||||
|
||||
if (!provider) {
|
||||
throw new Error(`Provider with ID "${providerId}" not found`);
|
||||
}
|
||||
|
||||
// Check if model ID already exists
|
||||
const existingModel = db
|
||||
.select()
|
||||
.from(languageModelsSchema)
|
||||
.where(eq(languageModelsSchema.id, id))
|
||||
.get();
|
||||
|
||||
if (existingModel) {
|
||||
throw new Error(`A model with ID "${id}" already exists`);
|
||||
}
|
||||
|
||||
// Insert the new model
|
||||
await db.insert(languageModelsSchema).values({
|
||||
id,
|
||||
name,
|
||||
provider_id: providerId,
|
||||
description: description || null,
|
||||
max_output_tokens: maxOutputTokens || null,
|
||||
context_window: contextWindow || null,
|
||||
});
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"get-language-models",
|
||||
async (
|
||||
event: IpcMainInvokeEvent,
|
||||
params: { providerId: string },
|
||||
): Promise<LanguageModel[]> => {
|
||||
if (!params || typeof params.providerId !== "string") {
|
||||
throw new Error("Invalid parameters: providerId (string) is required.");
|
||||
}
|
||||
return getLanguageModels({ providerId: params.providerId });
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -22,6 +22,9 @@ import type {
|
||||
ChatLogsData,
|
||||
BranchResult,
|
||||
LanguageModelProvider,
|
||||
LanguageModel,
|
||||
CreateCustomLanguageModelProviderParams,
|
||||
CreateCustomLanguageModelParams,
|
||||
} from "./ipc_types";
|
||||
import type { ProposalResult } from "@/lib/schemas";
|
||||
import { showError } from "@/lib/toast";
|
||||
@@ -732,17 +735,18 @@ export class IpcClient {
|
||||
return this.ipcRenderer.invoke("get-language-model-providers");
|
||||
}
|
||||
|
||||
public async getLanguageModels(params: {
|
||||
providerId: string;
|
||||
}): Promise<LanguageModel[]> {
|
||||
return this.ipcRenderer.invoke("get-language-models", params);
|
||||
}
|
||||
|
||||
public async createCustomLanguageModelProvider({
|
||||
id,
|
||||
name,
|
||||
apiBaseUrl,
|
||||
envVarName,
|
||||
}: {
|
||||
id: string;
|
||||
name: string;
|
||||
apiBaseUrl: string;
|
||||
envVarName?: string;
|
||||
}): Promise<LanguageModelProvider> {
|
||||
}: CreateCustomLanguageModelProviderParams): Promise<LanguageModelProvider> {
|
||||
return this.ipcRenderer.invoke("create-custom-language-model-provider", {
|
||||
id,
|
||||
name,
|
||||
@@ -751,5 +755,11 @@ export class IpcClient {
|
||||
});
|
||||
}
|
||||
|
||||
public async createCustomLanguageModel(
|
||||
params: CreateCustomLanguageModelParams,
|
||||
): Promise<void> {
|
||||
await this.ipcRenderer.invoke("create-custom-language-model", params);
|
||||
}
|
||||
|
||||
// --- End window control methods ---
|
||||
}
|
||||
|
||||
@@ -143,3 +143,30 @@ export interface LanguageModelProvider {
|
||||
apiBaseUrl?: string;
|
||||
type: "custom" | "local" | "cloud";
|
||||
}
|
||||
|
||||
export interface LanguageModel {
|
||||
id: string;
|
||||
name: string;
|
||||
displayName: string;
|
||||
description: string;
|
||||
tag?: string;
|
||||
maxOutputTokens?: number;
|
||||
contextWindow?: number;
|
||||
type: "local" | "cloud" | "custom";
|
||||
}
|
||||
|
||||
export interface CreateCustomLanguageModelProviderParams {
|
||||
id: string;
|
||||
name: string;
|
||||
apiBaseUrl: string;
|
||||
envVarName?: string;
|
||||
}
|
||||
|
||||
export interface CreateCustomLanguageModelParams {
|
||||
id: string;
|
||||
name: string;
|
||||
providerId: string;
|
||||
description?: string;
|
||||
maxOutputTokens?: number;
|
||||
contextWindow?: number;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import { db } from "@/db";
|
||||
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
|
||||
import { RegularModelProvider } from "@/constants/models";
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
import {
|
||||
language_model_providers as languageModelProvidersSchema,
|
||||
language_models as languageModelsSchema,
|
||||
} from "@/db/schema";
|
||||
import { MODEL_OPTIONS, RegularModelProvider } from "@/constants/models";
|
||||
import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
@@ -129,3 +133,80 @@ export async function getLanguageModelProviders(): Promise<
|
||||
|
||||
return Array.from(mergedProvidersMap.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches language models for a specific provider.
|
||||
* @param obj An object containing the providerId.
|
||||
* @returns A promise that resolves to an array of LanguageModel objects.
|
||||
*/
|
||||
export async function getLanguageModels(obj: {
|
||||
providerId: string;
|
||||
}): Promise<LanguageModel[]> {
|
||||
const { providerId } = obj;
|
||||
const allProviders = await getLanguageModelProviders();
|
||||
const provider = allProviders.find((p) => p.id === providerId);
|
||||
|
||||
if (!provider) {
|
||||
console.warn(`Provider with ID "${providerId}" not found.`);
|
||||
return [];
|
||||
}
|
||||
|
||||
if (provider.type === "cloud") {
|
||||
// Check if providerId is a valid key for MODEL_OPTIONS
|
||||
if (providerId in MODEL_OPTIONS) {
|
||||
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
|
||||
return models.map((model) => ({
|
||||
...model,
|
||||
id: model.name,
|
||||
type: "cloud",
|
||||
}));
|
||||
} else {
|
||||
console.warn(
|
||||
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
|
||||
);
|
||||
return [];
|
||||
}
|
||||
} else if (provider.type === "custom") {
|
||||
// Fetch models from the database for this custom provider
|
||||
// Assuming a language_models table with necessary columns and provider_id foreign key
|
||||
try {
|
||||
const customModelsDb = await db
|
||||
.select({
|
||||
id: languageModelsSchema.id,
|
||||
// Map DB columns to LanguageModel fields
|
||||
name: languageModelsSchema.name,
|
||||
// No display_name in DB, use name instead
|
||||
description: languageModelsSchema.description,
|
||||
// No tag in DB
|
||||
maxOutputTokens: languageModelsSchema.max_output_tokens,
|
||||
contextWindow: languageModelsSchema.context_window,
|
||||
})
|
||||
.from(languageModelsSchema)
|
||||
.where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available
|
||||
|
||||
return customModelsDb.map((model) => ({
|
||||
...model,
|
||||
displayName: model.name, // Use name as displayName for custom models
|
||||
// Ensure possibly null fields are handled, provide defaults or undefined if needed
|
||||
description: model.description ?? "",
|
||||
tag: undefined, // No tag for custom models from DB
|
||||
maxOutputTokens: model.maxOutputTokens ?? undefined,
|
||||
contextWindow: model.contextWindow ?? undefined,
|
||||
type: "custom",
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error fetching custom models for provider "${providerId}" from DB:`,
|
||||
error,
|
||||
);
|
||||
// Depending on desired behavior, could throw, return empty, or return a specific error state
|
||||
return [];
|
||||
}
|
||||
} else {
|
||||
// Handle other types like "local" if necessary, currently ignored
|
||||
console.warn(
|
||||
`Provider type "${provider.type}" not handled for model fetching.`,
|
||||
);
|
||||
return [];
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user