312 lines
9.2 KiB
TypeScript
312 lines
9.2 KiB
TypeScript
import type {
|
|
LanguageModelProvider,
|
|
LanguageModel,
|
|
CreateCustomLanguageModelProviderParams,
|
|
CreateCustomLanguageModelParams,
|
|
} from "@/ipc/ipc_types";
|
|
import { createLoggedHandler } from "./safe_handle";
|
|
import log from "electron-log";
|
|
import {
|
|
CUSTOM_PROVIDER_PREFIX,
|
|
getLanguageModelProviders,
|
|
getLanguageModels,
|
|
getLanguageModelsByProviders,
|
|
} from "../shared/language_model_helpers";
|
|
import { db } from "@/db";
|
|
import {
|
|
language_models,
|
|
language_model_providers as languageModelProvidersSchema,
|
|
language_models as languageModelsSchema,
|
|
} from "@/db/schema";
|
|
import { and, eq } from "drizzle-orm";
|
|
import { IpcMainInvokeEvent } from "electron";
|
|
|
|
const logger = log.scope("language_model_handlers");
|
|
const handle = createLoggedHandler(logger);
|
|
|
|
export function registerLanguageModelHandlers() {
|
|
handle(
|
|
"get-language-model-providers",
|
|
async (): Promise<LanguageModelProvider[]> => {
|
|
return getLanguageModelProviders();
|
|
},
|
|
);
|
|
|
|
handle(
|
|
"create-custom-language-model-provider",
|
|
async (
|
|
event: IpcMainInvokeEvent,
|
|
params: CreateCustomLanguageModelProviderParams,
|
|
): Promise<LanguageModelProvider> => {
|
|
const { id, name, apiBaseUrl, envVarName } = params;
|
|
|
|
// Validation
|
|
if (!id) {
|
|
throw new Error("Provider ID is required");
|
|
}
|
|
|
|
if (!name) {
|
|
throw new Error("Provider name is required");
|
|
}
|
|
|
|
if (!apiBaseUrl) {
|
|
throw new Error("API base URL is required");
|
|
}
|
|
|
|
// Check if a provider with this ID already exists
|
|
const existingProvider = db
|
|
.select()
|
|
.from(languageModelProvidersSchema)
|
|
.where(eq(languageModelProvidersSchema.id, id))
|
|
.get();
|
|
|
|
if (existingProvider) {
|
|
throw new Error(`A provider with ID "${id}" already exists`);
|
|
}
|
|
|
|
// Insert the new provider
|
|
await db.insert(languageModelProvidersSchema).values({
|
|
// Make sure we will never have accidental collisions with builtin providers
|
|
id: CUSTOM_PROVIDER_PREFIX + id,
|
|
name,
|
|
api_base_url: apiBaseUrl,
|
|
env_var_name: envVarName || null,
|
|
});
|
|
|
|
// Return the newly created provider
|
|
return {
|
|
id,
|
|
name,
|
|
apiBaseUrl,
|
|
envVarName,
|
|
type: "custom",
|
|
};
|
|
},
|
|
);
|
|
|
|
handle(
|
|
"create-custom-language-model",
|
|
async (
|
|
event: IpcMainInvokeEvent,
|
|
params: CreateCustomLanguageModelParams,
|
|
): Promise<void> => {
|
|
const {
|
|
apiName,
|
|
displayName,
|
|
providerId,
|
|
description,
|
|
maxOutputTokens,
|
|
contextWindow,
|
|
} = params;
|
|
|
|
// Validation
|
|
if (!apiName) {
|
|
throw new Error("Model API name is required");
|
|
}
|
|
if (!displayName) {
|
|
throw new Error("Model display name is required");
|
|
}
|
|
if (!providerId) {
|
|
throw new Error("Provider ID is required");
|
|
}
|
|
|
|
// Check if provider exists
|
|
const providers = await getLanguageModelProviders();
|
|
const provider = providers.find((p) => p.id === providerId);
|
|
if (!provider) {
|
|
throw new Error(`Provider with ID "${providerId}" not found`);
|
|
}
|
|
|
|
// Insert the new model
|
|
await db.insert(languageModelsSchema).values({
|
|
displayName,
|
|
apiName,
|
|
builtinProviderId: provider.type === "cloud" ? providerId : undefined,
|
|
customProviderId: provider.type === "custom" ? providerId : undefined,
|
|
description: description || null,
|
|
max_output_tokens: maxOutputTokens || null,
|
|
context_window: contextWindow || null,
|
|
});
|
|
},
|
|
);
|
|
|
|
handle(
|
|
"delete-custom-language-model",
|
|
async (
|
|
event: IpcMainInvokeEvent,
|
|
params: { modelId: string },
|
|
): Promise<void> => {
|
|
const { modelId: apiName } = params;
|
|
|
|
// Validation
|
|
if (!apiName) {
|
|
throw new Error("Model API name (modelId) is required");
|
|
}
|
|
|
|
logger.info(
|
|
`Handling delete-custom-language-model for apiName: ${apiName}`,
|
|
);
|
|
|
|
const existingModel = await db
|
|
.select()
|
|
.from(languageModelsSchema)
|
|
.where(eq(languageModelsSchema.apiName, apiName))
|
|
.get();
|
|
|
|
if (!existingModel) {
|
|
throw new Error(
|
|
`A model with API name (modelId) "${apiName}" was not found`,
|
|
);
|
|
}
|
|
|
|
await db
|
|
.delete(languageModelsSchema)
|
|
.where(eq(languageModelsSchema.apiName, apiName));
|
|
},
|
|
);
|
|
|
|
handle(
|
|
"delete-custom-model",
|
|
async (
|
|
_event: IpcMainInvokeEvent,
|
|
params: { providerId: string; modelApiName: string },
|
|
): Promise<void> => {
|
|
const { providerId, modelApiName } = params;
|
|
logger.info(
|
|
`Handling delete-custom-model for ${providerId} / ${modelApiName}`,
|
|
);
|
|
if (!providerId || !modelApiName) {
|
|
throw new Error("Provider ID and Model API Name are required.");
|
|
}
|
|
logger.info(
|
|
`Attempting to delete custom model ${modelApiName} for provider ${providerId}`,
|
|
);
|
|
|
|
const providers = await getLanguageModelProviders();
|
|
const provider = providers.find((p) => p.id === providerId);
|
|
if (!provider) {
|
|
throw new Error(`Provider with ID "${providerId}" not found`);
|
|
}
|
|
if (provider.type === "local") {
|
|
throw new Error("Local models cannot be deleted");
|
|
}
|
|
const result = db
|
|
.delete(language_models)
|
|
.where(
|
|
and(
|
|
provider.type === "cloud"
|
|
? eq(language_models.builtinProviderId, providerId)
|
|
: eq(language_models.customProviderId, providerId),
|
|
|
|
eq(language_models.apiName, modelApiName),
|
|
),
|
|
)
|
|
.run();
|
|
|
|
if (result.changes === 0) {
|
|
logger.warn(
|
|
`No custom model found matching providerId=${providerId} and apiName=${modelApiName} for deletion.`,
|
|
);
|
|
} else {
|
|
logger.info(
|
|
`Successfully deleted ${result.changes} custom model(s) with apiName=${modelApiName} for provider=${providerId}`,
|
|
);
|
|
}
|
|
},
|
|
);
|
|
|
|
handle(
|
|
"delete-custom-language-model-provider",
|
|
async (
|
|
event: IpcMainInvokeEvent,
|
|
params: { providerId: string },
|
|
): Promise<void> => {
|
|
const { providerId } = params;
|
|
|
|
// Validation
|
|
if (!providerId) {
|
|
throw new Error("Provider ID is required");
|
|
}
|
|
|
|
logger.info(
|
|
`Handling delete-custom-language-model-provider for providerId: ${providerId}`,
|
|
);
|
|
|
|
// Check if the provider exists before attempting deletion
|
|
const existingProvider = await db
|
|
.select({ id: languageModelProvidersSchema.id })
|
|
.from(languageModelProvidersSchema)
|
|
.where(eq(languageModelProvidersSchema.id, providerId))
|
|
.get();
|
|
|
|
if (!existingProvider) {
|
|
// If the provider doesn't exist, maybe it was already deleted. Log and return.
|
|
logger.warn(
|
|
`Provider with ID "${providerId}" not found. It might have been deleted already.`,
|
|
);
|
|
// Optionally, throw new Error(`Provider with ID "${providerId}" not found`);
|
|
// Deciding to return gracefully instead of throwing an error if not found.
|
|
return;
|
|
}
|
|
|
|
// Use a transaction to ensure atomicity
|
|
await db.transaction(async (tx) => {
|
|
// 1. Delete associated models
|
|
const deleteModelsResult = await tx
|
|
.delete(languageModelsSchema)
|
|
.where(eq(languageModelsSchema.customProviderId, providerId))
|
|
.run();
|
|
logger.info(
|
|
`Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`,
|
|
);
|
|
|
|
// 2. Delete the provider
|
|
const deleteProviderResult = await tx
|
|
.delete(languageModelProvidersSchema)
|
|
.where(eq(languageModelProvidersSchema.id, providerId))
|
|
.run();
|
|
|
|
if (deleteProviderResult.changes === 0) {
|
|
// This case should ideally not happen if existingProvider check passed,
|
|
// but adding safety check within transaction.
|
|
logger.error(
|
|
`Failed to delete provider with ID "${providerId}" during transaction, although it was found initially. Rolling back.`,
|
|
);
|
|
throw new Error(
|
|
`Failed to delete provider with ID "${providerId}" which should have existed.`,
|
|
);
|
|
}
|
|
logger.info(`Successfully deleted provider with ID "${providerId}".`);
|
|
});
|
|
},
|
|
);
|
|
|
|
handle(
|
|
"get-language-models",
|
|
async (
|
|
event: IpcMainInvokeEvent,
|
|
params: { providerId: string },
|
|
): Promise<LanguageModel[]> => {
|
|
if (!params || typeof params.providerId !== "string") {
|
|
throw new Error("Invalid parameters: providerId (string) is required.");
|
|
}
|
|
const providers = await getLanguageModelProviders();
|
|
const provider = providers.find((p) => p.id === params.providerId);
|
|
if (!provider) {
|
|
throw new Error(`Provider with ID "${params.providerId}" not found`);
|
|
}
|
|
if (provider.type === "local") {
|
|
throw new Error("Local models cannot be fetched");
|
|
}
|
|
return getLanguageModels({ providerId: params.providerId });
|
|
},
|
|
);
|
|
|
|
handle(
|
|
"get-language-models-by-providers",
|
|
async (): Promise<Record<string, LanguageModel[]>> => {
|
|
return getLanguageModelsByProviders();
|
|
},
|
|
);
|
|
}
|