Fix DB schemas (#138)
This commit is contained in:
@@ -11,11 +11,12 @@ CREATE TABLE `language_models` (
|
|||||||
`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL,
|
`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||||
`display_name` text NOT NULL,
|
`display_name` text NOT NULL,
|
||||||
`api_name` text NOT NULL,
|
`api_name` text NOT NULL,
|
||||||
`provider_id` text NOT NULL,
|
`builtin_provider_id` text,
|
||||||
|
`custom_provider_id` text,
|
||||||
`description` text,
|
`description` text,
|
||||||
`max_output_tokens` integer,
|
`max_output_tokens` integer,
|
||||||
`context_window` integer,
|
`context_window` integer,
|
||||||
`created_at` integer DEFAULT (unixepoch()) NOT NULL,
|
`created_at` integer DEFAULT (unixepoch()) NOT NULL,
|
||||||
`updated_at` integer DEFAULT (unixepoch()) NOT NULL,
|
`updated_at` integer DEFAULT (unixepoch()) NOT NULL,
|
||||||
FOREIGN KEY (`provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade
|
FOREIGN KEY (`custom_provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade
|
||||||
);
|
);
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"version": "6",
|
"version": "6",
|
||||||
"dialect": "sqlite",
|
"dialect": "sqlite",
|
||||||
"id": "424973e5-a102-4b71-8d5d-6160f1609b8c",
|
"id": "0a47ec41-9477-4457-b3e8-e5ecb3e3a855",
|
||||||
"prevId": "ceedb797-6aa3-4a50-b42f-bc85ee08b3df",
|
"prevId": "ceedb797-6aa3-4a50-b42f-bc85ee08b3df",
|
||||||
"tables": {
|
"tables": {
|
||||||
"apps": {
|
"apps": {
|
||||||
@@ -210,11 +210,18 @@
|
|||||||
"notNull": true,
|
"notNull": true,
|
||||||
"autoincrement": false
|
"autoincrement": false
|
||||||
},
|
},
|
||||||
"provider_id": {
|
"builtin_provider_id": {
|
||||||
"name": "provider_id",
|
"name": "builtin_provider_id",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"primaryKey": false,
|
"primaryKey": false,
|
||||||
"notNull": true,
|
"notNull": false,
|
||||||
|
"autoincrement": false
|
||||||
|
},
|
||||||
|
"custom_provider_id": {
|
||||||
|
"name": "custom_provider_id",
|
||||||
|
"type": "text",
|
||||||
|
"primaryKey": false,
|
||||||
|
"notNull": false,
|
||||||
"autoincrement": false
|
"autoincrement": false
|
||||||
},
|
},
|
||||||
"description": {
|
"description": {
|
||||||
@@ -257,12 +264,12 @@
|
|||||||
},
|
},
|
||||||
"indexes": {},
|
"indexes": {},
|
||||||
"foreignKeys": {
|
"foreignKeys": {
|
||||||
"language_models_provider_id_language_model_providers_id_fk": {
|
"language_models_custom_provider_id_language_model_providers_id_fk": {
|
||||||
"name": "language_models_provider_id_language_model_providers_id_fk",
|
"name": "language_models_custom_provider_id_language_model_providers_id_fk",
|
||||||
"tableFrom": "language_models",
|
"tableFrom": "language_models",
|
||||||
"tableTo": "language_model_providers",
|
"tableTo": "language_model_providers",
|
||||||
"columnsFrom": [
|
"columnsFrom": [
|
||||||
"provider_id"
|
"custom_provider_id"
|
||||||
],
|
],
|
||||||
"columnsTo": [
|
"columnsTo": [
|
||||||
"id"
|
"id"
|
||||||
|
|||||||
@@ -40,8 +40,8 @@
|
|||||||
{
|
{
|
||||||
"idx": 5,
|
"idx": 5,
|
||||||
"version": "6",
|
"version": "6",
|
||||||
"when": 1747091036229,
|
"when": 1747095436506,
|
||||||
"tag": "0005_superb_lady_mastermind",
|
"tag": "0005_clumsy_namor",
|
||||||
"breakpoints": true
|
"breakpoints": true
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -17,9 +17,9 @@ import {
|
|||||||
DropdownMenuSubContent,
|
DropdownMenuSubContent,
|
||||||
} from "@/components/ui/dropdown-menu";
|
} from "@/components/ui/dropdown-menu";
|
||||||
import { useEffect, useState } from "react";
|
import { useEffect, useState } from "react";
|
||||||
import { MODEL_OPTIONS } from "@/constants/models";
|
|
||||||
import { useLocalModels } from "@/hooks/useLocalModels";
|
import { useLocalModels } from "@/hooks/useLocalModels";
|
||||||
import { useLocalLMSModels } from "@/hooks/useLMStudioModels";
|
import { useLocalLMSModels } from "@/hooks/useLMStudioModels";
|
||||||
|
import { useLanguageModelsByProviders } from "@/hooks/useLanguageModelsByProviders";
|
||||||
import { ChevronDown } from "lucide-react";
|
import { ChevronDown } from "lucide-react";
|
||||||
import { LocalModel } from "@/ipc/ipc_types";
|
import { LocalModel } from "@/ipc/ipc_types";
|
||||||
interface ModelPickerProps {
|
interface ModelPickerProps {
|
||||||
@@ -33,6 +33,10 @@ export function ModelPicker({
|
|||||||
}: ModelPickerProps) {
|
}: ModelPickerProps) {
|
||||||
const [open, setOpen] = useState(false);
|
const [open, setOpen] = useState(false);
|
||||||
|
|
||||||
|
// Cloud models from providers
|
||||||
|
const { data: modelsByProviders, isLoading: providersLoading } =
|
||||||
|
useLanguageModelsByProviders();
|
||||||
|
|
||||||
// Ollama Models Hook
|
// Ollama Models Hook
|
||||||
const {
|
const {
|
||||||
models: ollamaModels,
|
models: ollamaModels,
|
||||||
@@ -74,24 +78,35 @@ export function ModelPicker({
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback for cloud models
|
// For cloud models, look up in the modelsByProviders data
|
||||||
return (
|
if (modelsByProviders && modelsByProviders[selectedModel.provider]) {
|
||||||
MODEL_OPTIONS[selectedModel.provider]?.find(
|
const foundModel = modelsByProviders[selectedModel.provider].find(
|
||||||
(model) => model.name === selectedModel.name,
|
(model) => model.apiName === selectedModel.name,
|
||||||
)?.displayName || selectedModel.name
|
|
||||||
);
|
);
|
||||||
|
if (foundModel) {
|
||||||
|
return foundModel.displayName;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback if not found
|
||||||
|
return selectedModel.name;
|
||||||
};
|
};
|
||||||
|
|
||||||
const modelDisplayName = getModelDisplayName();
|
const modelDisplayName = getModelDisplayName();
|
||||||
|
|
||||||
// Flatten the cloud model options
|
// Flatten the cloud models from all providers
|
||||||
const cloudModels = Object.entries(MODEL_OPTIONS).flatMap(
|
const cloudModels =
|
||||||
([provider, models]) =>
|
!providersLoading && modelsByProviders
|
||||||
|
? Object.entries(modelsByProviders).flatMap(([providerId, models]) =>
|
||||||
models.map((model) => ({
|
models.map((model) => ({
|
||||||
...model,
|
name: model.apiName,
|
||||||
provider: provider as ModelProvider,
|
displayName: model.displayName,
|
||||||
|
description: model.description || "",
|
||||||
|
tag: model.tag,
|
||||||
|
provider: providerId as ModelProvider,
|
||||||
})),
|
})),
|
||||||
);
|
)
|
||||||
|
: [];
|
||||||
|
|
||||||
// Determine availability of local models
|
// Determine availability of local models
|
||||||
const hasOllamaModels =
|
const hasOllamaModels =
|
||||||
@@ -119,8 +134,18 @@ export function ModelPicker({
|
|||||||
{/* Increased width slightly */}
|
{/* Increased width slightly */}
|
||||||
<DropdownMenuLabel>Cloud Models</DropdownMenuLabel>
|
<DropdownMenuLabel>Cloud Models</DropdownMenuLabel>
|
||||||
<DropdownMenuSeparator />
|
<DropdownMenuSeparator />
|
||||||
{/* Cloud models */}
|
{/* Cloud models - loading state */}
|
||||||
{cloudModels.map((model) => (
|
{providersLoading ? (
|
||||||
|
<div className="text-xs text-center py-2 text-muted-foreground">
|
||||||
|
Loading models...
|
||||||
|
</div>
|
||||||
|
) : cloudModels.length === 0 ? (
|
||||||
|
<div className="text-xs text-center py-2 text-muted-foreground">
|
||||||
|
No cloud models available
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
/* Cloud models loaded */
|
||||||
|
cloudModels.map((model) => (
|
||||||
<Tooltip key={`${model.provider}-${model.name}`}>
|
<Tooltip key={`${model.provider}-${model.name}`}>
|
||||||
<TooltipTrigger asChild>
|
<TooltipTrigger asChild>
|
||||||
<DropdownMenuItem
|
<DropdownMenuItem
|
||||||
@@ -155,7 +180,8 @@ export function ModelPicker({
|
|||||||
</TooltipTrigger>
|
</TooltipTrigger>
|
||||||
<TooltipContent side="right">{model.description}</TooltipContent>
|
<TooltipContent side="right">{model.description}</TooltipContent>
|
||||||
</Tooltip>
|
</Tooltip>
|
||||||
))}
|
))
|
||||||
|
)}
|
||||||
<DropdownMenuSeparator />
|
<DropdownMenuSeparator />
|
||||||
{/* Local Models Parent SubMenu */}
|
{/* Local Models Parent SubMenu */}
|
||||||
<DropdownMenuSub>
|
<DropdownMenuSub>
|
||||||
|
|||||||
@@ -85,9 +85,11 @@ export const language_models = sqliteTable("language_models", {
|
|||||||
id: integer("id").primaryKey({ autoIncrement: true }),
|
id: integer("id").primaryKey({ autoIncrement: true }),
|
||||||
displayName: text("display_name").notNull(),
|
displayName: text("display_name").notNull(),
|
||||||
apiName: text("api_name").notNull(),
|
apiName: text("api_name").notNull(),
|
||||||
provider_id: text("provider_id")
|
builtinProviderId: text("builtin_provider_id"),
|
||||||
.notNull()
|
customProviderId: text("custom_provider_id").references(
|
||||||
.references(() => language_model_providers.id, { onDelete: "cascade" }),
|
() => language_model_providers.id,
|
||||||
|
{ onDelete: "cascade" },
|
||||||
|
),
|
||||||
description: text("description"),
|
description: text("description"),
|
||||||
max_output_tokens: integer("max_output_tokens"),
|
max_output_tokens: integer("max_output_tokens"),
|
||||||
context_window: integer("context_window"),
|
context_window: integer("context_window"),
|
||||||
@@ -111,7 +113,7 @@ export const languageModelsRelations = relations(
|
|||||||
language_models,
|
language_models,
|
||||||
({ one }) => ({
|
({ one }) => ({
|
||||||
provider: one(language_model_providers, {
|
provider: one(language_model_providers, {
|
||||||
fields: [language_models.provider_id],
|
fields: [language_models.customProviderId],
|
||||||
references: [language_model_providers.id],
|
references: [language_model_providers.id],
|
||||||
}),
|
}),
|
||||||
}),
|
}),
|
||||||
|
|||||||
19
src/hooks/useLanguageModelsByProviders.ts
Normal file
19
src/hooks/useLanguageModelsByProviders.ts
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import { useQuery } from "@tanstack/react-query";
|
||||||
|
import { IpcClient } from "@/ipc/ipc_client";
|
||||||
|
import type { LanguageModel } from "@/ipc/ipc_types";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetches all available language models grouped by their provider IDs.
|
||||||
|
*
|
||||||
|
* @returns TanStack Query result object for the language models organized by provider.
|
||||||
|
*/
|
||||||
|
export function useLanguageModelsByProviders() {
|
||||||
|
const ipcClient = IpcClient.getInstance();
|
||||||
|
|
||||||
|
return useQuery<Record<string, LanguageModel[]>, Error>({
|
||||||
|
queryKey: ["language-models-by-providers"],
|
||||||
|
queryFn: async () => {
|
||||||
|
return ipcClient.getLanguageModelsByProviders();
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import log from "electron-log";
|
|||||||
import {
|
import {
|
||||||
getLanguageModelProviders,
|
getLanguageModelProviders,
|
||||||
getLanguageModels,
|
getLanguageModels,
|
||||||
|
getLanguageModelsByProviders,
|
||||||
} from "../shared/language_model_helpers";
|
} from "../shared/language_model_helpers";
|
||||||
import { db } from "@/db";
|
import { db } from "@/db";
|
||||||
import {
|
import {
|
||||||
@@ -64,7 +65,8 @@ export function registerLanguageModelHandlers() {
|
|||||||
|
|
||||||
// Insert the new provider
|
// Insert the new provider
|
||||||
await db.insert(languageModelProvidersSchema).values({
|
await db.insert(languageModelProvidersSchema).values({
|
||||||
id,
|
// Make sure we will never have accidental collisions with builtin providers
|
||||||
|
id: "custom::" + id,
|
||||||
name,
|
name,
|
||||||
api_base_url: apiBaseUrl,
|
api_base_url: apiBaseUrl,
|
||||||
env_var_name: envVarName || null,
|
env_var_name: envVarName || null,
|
||||||
@@ -108,12 +110,8 @@ export function registerLanguageModelHandlers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if provider exists
|
// Check if provider exists
|
||||||
const provider = db
|
const providers = await getLanguageModelProviders();
|
||||||
.select()
|
const provider = providers.find((p) => p.id === providerId);
|
||||||
.from(languageModelProvidersSchema)
|
|
||||||
.where(eq(languageModelProvidersSchema.id, providerId))
|
|
||||||
.get();
|
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
throw new Error(`Provider with ID "${providerId}" not found`);
|
throw new Error(`Provider with ID "${providerId}" not found`);
|
||||||
}
|
}
|
||||||
@@ -122,7 +120,8 @@ export function registerLanguageModelHandlers() {
|
|||||||
await db.insert(languageModelsSchema).values({
|
await db.insert(languageModelsSchema).values({
|
||||||
displayName,
|
displayName,
|
||||||
apiName,
|
apiName,
|
||||||
provider_id: providerId,
|
builtinProviderId: provider.type === "cloud" ? providerId : undefined,
|
||||||
|
customProviderId: provider.type === "custom" ? providerId : undefined,
|
||||||
description: description || null,
|
description: description || null,
|
||||||
max_output_tokens: maxOutputTokens || null,
|
max_output_tokens: maxOutputTokens || null,
|
||||||
context_window: contextWindow || null,
|
context_window: contextWindow || null,
|
||||||
@@ -182,11 +181,22 @@ export function registerLanguageModelHandlers() {
|
|||||||
`Attempting to delete custom model ${modelApiName} for provider ${providerId}`,
|
`Attempting to delete custom model ${modelApiName} for provider ${providerId}`,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const providers = await getLanguageModelProviders();
|
||||||
|
const provider = providers.find((p) => p.id === providerId);
|
||||||
|
if (!provider) {
|
||||||
|
throw new Error(`Provider with ID "${providerId}" not found`);
|
||||||
|
}
|
||||||
|
if (provider.type === "local") {
|
||||||
|
throw new Error("Local models cannot be deleted");
|
||||||
|
}
|
||||||
const result = db
|
const result = db
|
||||||
.delete(language_models)
|
.delete(language_models)
|
||||||
.where(
|
.where(
|
||||||
and(
|
and(
|
||||||
eq(language_models.provider_id, providerId),
|
provider.type === "cloud"
|
||||||
|
? eq(language_models.builtinProviderId, providerId)
|
||||||
|
: eq(language_models.customProviderId, providerId),
|
||||||
|
|
||||||
eq(language_models.apiName, modelApiName),
|
eq(language_models.apiName, modelApiName),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@@ -243,7 +253,7 @@ export function registerLanguageModelHandlers() {
|
|||||||
// 1. Delete associated models
|
// 1. Delete associated models
|
||||||
const deleteModelsResult = await tx
|
const deleteModelsResult = await tx
|
||||||
.delete(languageModelsSchema)
|
.delete(languageModelsSchema)
|
||||||
.where(eq(languageModelsSchema.provider_id, providerId))
|
.where(eq(languageModelsSchema.customProviderId, providerId))
|
||||||
.run();
|
.run();
|
||||||
logger.info(
|
logger.info(
|
||||||
`Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`,
|
`Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`,
|
||||||
@@ -279,7 +289,26 @@ export function registerLanguageModelHandlers() {
|
|||||||
if (!params || typeof params.providerId !== "string") {
|
if (!params || typeof params.providerId !== "string") {
|
||||||
throw new Error("Invalid parameters: providerId (string) is required.");
|
throw new Error("Invalid parameters: providerId (string) is required.");
|
||||||
}
|
}
|
||||||
return getLanguageModels({ providerId: params.providerId });
|
const providers = await getLanguageModelProviders();
|
||||||
|
const provider = providers.find((p) => p.id === params.providerId);
|
||||||
|
if (!provider) {
|
||||||
|
throw new Error(`Provider with ID "${params.providerId}" not found`);
|
||||||
|
}
|
||||||
|
if (provider.type === "local") {
|
||||||
|
throw new Error("Local models cannot be fetched");
|
||||||
|
}
|
||||||
|
return getLanguageModels(
|
||||||
|
provider.type === "cloud"
|
||||||
|
? { builtinProviderId: params.providerId }
|
||||||
|
: { customProviderId: params.providerId },
|
||||||
|
);
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
handle(
|
||||||
|
"get-language-models-by-providers",
|
||||||
|
async (): Promise<Record<string, LanguageModel[]>> => {
|
||||||
|
return getLanguageModelsByProviders();
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -746,6 +746,12 @@ export class IpcClient {
|
|||||||
return this.ipcRenderer.invoke("get-language-models", params);
|
return this.ipcRenderer.invoke("get-language-models", params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public async getLanguageModelsByProviders(): Promise<
|
||||||
|
Record<string, LanguageModel[]>
|
||||||
|
> {
|
||||||
|
return this.ipcRenderer.invoke("get-language-models-by-providers");
|
||||||
|
}
|
||||||
|
|
||||||
public async createCustomLanguageModelProvider({
|
public async createCustomLanguageModelProvider({
|
||||||
id,
|
id,
|
||||||
name,
|
name,
|
||||||
|
|||||||
@@ -139,23 +139,72 @@ export async function getLanguageModelProviders(): Promise<
|
|||||||
* @param obj An object containing the providerId.
|
* @param obj An object containing the providerId.
|
||||||
* @returns A promise that resolves to an array of LanguageModel objects.
|
* @returns A promise that resolves to an array of LanguageModel objects.
|
||||||
*/
|
*/
|
||||||
export async function getLanguageModels(obj: {
|
export async function getLanguageModels(
|
||||||
providerId: string;
|
obj:
|
||||||
}): Promise<LanguageModel[]> {
|
| {
|
||||||
const { providerId } = obj;
|
customProviderId: string;
|
||||||
|
// builtinProviderId?: undefined;
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
builtinProviderId: string;
|
||||||
|
// customProviderId?: undefined;
|
||||||
|
},
|
||||||
|
): Promise<LanguageModel[]> {
|
||||||
const allProviders = await getLanguageModelProviders();
|
const allProviders = await getLanguageModelProviders();
|
||||||
const provider = allProviders.find((p) => p.id === providerId);
|
const provider = allProviders.find(
|
||||||
|
(p) =>
|
||||||
|
p.id === (obj as { customProviderId: string }).customProviderId ||
|
||||||
|
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
|
||||||
|
);
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
console.warn(`Provider with ID "${providerId}" not found.`);
|
console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get custom models from DB for all provider types
|
||||||
|
let customModels: LanguageModel[] = [];
|
||||||
|
|
||||||
|
try {
|
||||||
|
const customModelsDb = await db
|
||||||
|
.select({
|
||||||
|
id: languageModelsSchema.id,
|
||||||
|
displayName: languageModelsSchema.displayName,
|
||||||
|
apiName: languageModelsSchema.apiName,
|
||||||
|
description: languageModelsSchema.description,
|
||||||
|
maxOutputTokens: languageModelsSchema.max_output_tokens,
|
||||||
|
contextWindow: languageModelsSchema.context_window,
|
||||||
|
})
|
||||||
|
.from(languageModelsSchema)
|
||||||
|
.where(
|
||||||
|
"customProviderId" in obj
|
||||||
|
? eq(languageModelsSchema.customProviderId, obj.customProviderId)
|
||||||
|
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
|
||||||
|
);
|
||||||
|
|
||||||
|
customModels = customModelsDb.map((model) => ({
|
||||||
|
...model,
|
||||||
|
description: model.description ?? "",
|
||||||
|
tag: undefined,
|
||||||
|
maxOutputTokens: model.maxOutputTokens ?? undefined,
|
||||||
|
contextWindow: model.contextWindow ?? undefined,
|
||||||
|
type: "custom",
|
||||||
|
}));
|
||||||
|
} catch (error) {
|
||||||
|
console.error(
|
||||||
|
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
|
||||||
|
error,
|
||||||
|
);
|
||||||
|
// Continue with empty custom models array
|
||||||
|
}
|
||||||
|
|
||||||
|
// If it's a cloud provider, also get the hardcoded models
|
||||||
|
let hardcodedModels: LanguageModel[] = [];
|
||||||
|
const providerId = provider.id;
|
||||||
if (provider.type === "cloud") {
|
if (provider.type === "cloud") {
|
||||||
// Check if providerId is a valid key for MODEL_OPTIONS
|
|
||||||
if (providerId in MODEL_OPTIONS) {
|
if (providerId in MODEL_OPTIONS) {
|
||||||
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
|
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
|
||||||
return models.map((model) => ({
|
hardcodedModels = models.map((model) => ({
|
||||||
...model,
|
...model,
|
||||||
apiName: model.name,
|
apiName: model.name,
|
||||||
type: "cloud",
|
type: "cloud",
|
||||||
@@ -164,49 +213,54 @@ export async function getLanguageModels(obj: {
|
|||||||
console.warn(
|
console.warn(
|
||||||
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
|
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
|
||||||
);
|
);
|
||||||
return [];
|
|
||||||
}
|
}
|
||||||
} else if (provider.type === "custom") {
|
}
|
||||||
// Fetch models from the database for this custom provider
|
|
||||||
// Assuming a language_models table with necessary columns and provider_id foreign key
|
|
||||||
try {
|
|
||||||
const customModelsDb = await db
|
|
||||||
.select({
|
|
||||||
id: languageModelsSchema.id,
|
|
||||||
// Map DB columns to LanguageModel fields
|
|
||||||
displayName: languageModelsSchema.displayName,
|
|
||||||
apiName: languageModelsSchema.apiName,
|
|
||||||
// No display_name in DB, use name instead
|
|
||||||
description: languageModelsSchema.description,
|
|
||||||
// No tag in DB
|
|
||||||
maxOutputTokens: languageModelsSchema.max_output_tokens,
|
|
||||||
contextWindow: languageModelsSchema.context_window,
|
|
||||||
})
|
|
||||||
.from(languageModelsSchema)
|
|
||||||
.where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available
|
|
||||||
|
|
||||||
return customModelsDb.map((model) => ({
|
// Merge the models, with custom models taking precedence over hardcoded ones
|
||||||
...model,
|
const mergedModelsMap = new Map<string, LanguageModel>();
|
||||||
// Ensure possibly null fields are handled, provide defaults or undefined if needed
|
|
||||||
description: model.description ?? "",
|
// Add hardcoded models first
|
||||||
tag: undefined, // No tag for custom models from DB
|
for (const model of hardcodedModels) {
|
||||||
maxOutputTokens: model.maxOutputTokens ?? undefined,
|
mergedModelsMap.set(model.apiName, model);
|
||||||
contextWindow: model.contextWindow ?? undefined,
|
}
|
||||||
type: "custom",
|
|
||||||
}));
|
// Then override with custom models
|
||||||
} catch (error) {
|
for (const model of customModels) {
|
||||||
console.error(
|
mergedModelsMap.set(model.apiName, model);
|
||||||
`Error fetching custom models for provider "${providerId}" from DB:`,
|
}
|
||||||
error,
|
|
||||||
|
return Array.from(mergedModelsMap.values());
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetches all language models grouped by their provider IDs.
|
||||||
|
* @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
|
||||||
|
*/
|
||||||
|
export async function getLanguageModelsByProviders(): Promise<
|
||||||
|
Record<string, LanguageModel[]>
|
||||||
|
> {
|
||||||
|
const providers = await getLanguageModelProviders();
|
||||||
|
|
||||||
|
// Fetch all models concurrently
|
||||||
|
const modelPromises = providers
|
||||||
|
.filter((p) => p.type !== "local")
|
||||||
|
.map(async (provider) => {
|
||||||
|
const models = await getLanguageModels(
|
||||||
|
provider.type === "cloud"
|
||||||
|
? { builtinProviderId: provider.id }
|
||||||
|
: { customProviderId: provider.id },
|
||||||
);
|
);
|
||||||
// Depending on desired behavior, could throw, return empty, or return a specific error state
|
return { providerId: provider.id, models };
|
||||||
return [];
|
});
|
||||||
}
|
|
||||||
} else {
|
// Wait for all requests to complete
|
||||||
// Handle other types like "local" if necessary, currently ignored
|
const results = await Promise.all(modelPromises);
|
||||||
console.warn(
|
|
||||||
`Provider type "${provider.type}" not handled for model fetching.`,
|
// Convert the array of results to a record
|
||||||
);
|
const record: Record<string, LanguageModel[]> = {};
|
||||||
return [];
|
for (const result of results) {
|
||||||
|
record[result.providerId] = result.models;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return record;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import { contextBridge, ipcRenderer } from "electron";
|
|||||||
// Whitelist of valid channels
|
// Whitelist of valid channels
|
||||||
const validInvokeChannels = [
|
const validInvokeChannels = [
|
||||||
"get-language-models",
|
"get-language-models",
|
||||||
|
"get-language-models-by-providers",
|
||||||
"create-custom-language-model",
|
"create-custom-language-model",
|
||||||
"get-language-model-providers",
|
"get-language-model-providers",
|
||||||
"delete-custom-language-model-provider",
|
"delete-custom-language-model-provider",
|
||||||
|
|||||||
Reference in New Issue
Block a user