diff --git a/drizzle/0005_superb_lady_mastermind.sql b/drizzle/0005_clumsy_namor.sql similarity index 77% rename from drizzle/0005_superb_lady_mastermind.sql rename to drizzle/0005_clumsy_namor.sql index 1d5dbd5..4bbb033 100644 --- a/drizzle/0005_superb_lady_mastermind.sql +++ b/drizzle/0005_clumsy_namor.sql @@ -11,11 +11,12 @@ CREATE TABLE `language_models` ( `id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, `display_name` text NOT NULL, `api_name` text NOT NULL, - `provider_id` text NOT NULL, + `builtin_provider_id` text, + `custom_provider_id` text, `description` text, `max_output_tokens` integer, `context_window` integer, `created_at` integer DEFAULT (unixepoch()) NOT NULL, `updated_at` integer DEFAULT (unixepoch()) NOT NULL, - FOREIGN KEY (`provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade + FOREIGN KEY (`custom_provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade ); diff --git a/drizzle/meta/0005_snapshot.json b/drizzle/meta/0005_snapshot.json index 568b2aa..d8b47c6 100644 --- a/drizzle/meta/0005_snapshot.json +++ b/drizzle/meta/0005_snapshot.json @@ -1,7 +1,7 @@ { "version": "6", "dialect": "sqlite", - "id": "424973e5-a102-4b71-8d5d-6160f1609b8c", + "id": "0a47ec41-9477-4457-b3e8-e5ecb3e3a855", "prevId": "ceedb797-6aa3-4a50-b42f-bc85ee08b3df", "tables": { "apps": { @@ -210,11 +210,18 @@ "notNull": true, "autoincrement": false }, - "provider_id": { - "name": "provider_id", + "builtin_provider_id": { + "name": "builtin_provider_id", "type": "text", "primaryKey": false, - "notNull": true, + "notNull": false, + "autoincrement": false + }, + "custom_provider_id": { + "name": "custom_provider_id", + "type": "text", + "primaryKey": false, + "notNull": false, "autoincrement": false }, "description": { @@ -257,12 +264,12 @@ }, "indexes": {}, "foreignKeys": { - "language_models_provider_id_language_model_providers_id_fk": { - "name": "language_models_provider_id_language_model_providers_id_fk", + "language_models_custom_provider_id_language_model_providers_id_fk": { + "name": "language_models_custom_provider_id_language_model_providers_id_fk", "tableFrom": "language_models", "tableTo": "language_model_providers", "columnsFrom": [ - "provider_id" + "custom_provider_id" ], "columnsTo": [ "id" diff --git a/drizzle/meta/_journal.json b/drizzle/meta/_journal.json index 1d98390..0fdec92 100644 --- a/drizzle/meta/_journal.json +++ b/drizzle/meta/_journal.json @@ -40,8 +40,8 @@ { "idx": 5, "version": "6", - "when": 1747091036229, - "tag": "0005_superb_lady_mastermind", + "when": 1747095436506, + "tag": "0005_clumsy_namor", "breakpoints": true } ] diff --git a/src/components/ModelPicker.tsx b/src/components/ModelPicker.tsx index e7b1165..6c414fa 100644 --- a/src/components/ModelPicker.tsx +++ b/src/components/ModelPicker.tsx @@ -17,9 +17,9 @@ import { DropdownMenuSubContent, } from "@/components/ui/dropdown-menu"; import { useEffect, useState } from "react"; -import { MODEL_OPTIONS } from "@/constants/models"; import { useLocalModels } from "@/hooks/useLocalModels"; import { useLocalLMSModels } from "@/hooks/useLMStudioModels"; +import { useLanguageModelsByProviders } from "@/hooks/useLanguageModelsByProviders"; import { ChevronDown } from "lucide-react"; import { LocalModel } from "@/ipc/ipc_types"; interface ModelPickerProps { @@ -33,6 +33,10 @@ export function ModelPicker({ }: ModelPickerProps) { const [open, setOpen] = useState(false); + // Cloud models from providers + const { data: modelsByProviders, isLoading: providersLoading } = + useLanguageModelsByProviders(); + // Ollama Models Hook const { models: ollamaModels, @@ -74,24 +78,35 @@ export function ModelPicker({ ); } - // Fallback for cloud models - return ( - MODEL_OPTIONS[selectedModel.provider]?.find( - (model) => model.name === selectedModel.name, - )?.displayName || selectedModel.name - ); + // For cloud models, look up in the modelsByProviders data + if (modelsByProviders && modelsByProviders[selectedModel.provider]) { + const foundModel = modelsByProviders[selectedModel.provider].find( + (model) => model.apiName === selectedModel.name, + ); + if (foundModel) { + return foundModel.displayName; + } + } + + // Fallback if not found + return selectedModel.name; }; const modelDisplayName = getModelDisplayName(); - // Flatten the cloud model options - const cloudModels = Object.entries(MODEL_OPTIONS).flatMap( - ([provider, models]) => - models.map((model) => ({ - ...model, - provider: provider as ModelProvider, - })), - ); + // Flatten the cloud models from all providers + const cloudModels = + !providersLoading && modelsByProviders + ? Object.entries(modelsByProviders).flatMap(([providerId, models]) => + models.map((model) => ({ + name: model.apiName, + displayName: model.displayName, + description: model.description || "", + tag: model.tag, + provider: providerId as ModelProvider, + })), + ) + : []; // Determine availability of local models const hasOllamaModels = @@ -119,43 +134,54 @@ export function ModelPicker({ {/* Increased width slightly */} Cloud Models - {/* Cloud models */} - {cloudModels.map((model) => ( - - - { - onModelSelect({ - name: model.name, - provider: model.provider, - }); - setOpen(false); - }} - > - - - {model.displayName} - - {model.provider} + {/* Cloud models - loading state */} + {providersLoading ? ( + + Loading models... + + ) : cloudModels.length === 0 ? ( + + No cloud models available + + ) : ( + /* Cloud models loaded */ + cloudModels.map((model) => ( + + + { + onModelSelect({ + name: model.name, + provider: model.provider, + }); + setOpen(false); + }} + > + + + {model.displayName} + + {model.provider} + - - {model.tag && ( - - {model.tag} - - )} - - - - {model.description} - - ))} + {model.tag && ( + + {model.tag} + + )} + + + + {model.description} + + )) + )} {/* Local Models Parent SubMenu */} diff --git a/src/db/schema.ts b/src/db/schema.ts index ff9785c..7847cc6 100644 --- a/src/db/schema.ts +++ b/src/db/schema.ts @@ -85,9 +85,11 @@ export const language_models = sqliteTable("language_models", { id: integer("id").primaryKey({ autoIncrement: true }), displayName: text("display_name").notNull(), apiName: text("api_name").notNull(), - provider_id: text("provider_id") - .notNull() - .references(() => language_model_providers.id, { onDelete: "cascade" }), + builtinProviderId: text("builtin_provider_id"), + customProviderId: text("custom_provider_id").references( + () => language_model_providers.id, + { onDelete: "cascade" }, + ), description: text("description"), max_output_tokens: integer("max_output_tokens"), context_window: integer("context_window"), @@ -111,7 +113,7 @@ export const languageModelsRelations = relations( language_models, ({ one }) => ({ provider: one(language_model_providers, { - fields: [language_models.provider_id], + fields: [language_models.customProviderId], references: [language_model_providers.id], }), }), diff --git a/src/hooks/useLanguageModelsByProviders.ts b/src/hooks/useLanguageModelsByProviders.ts new file mode 100644 index 0000000..1da517d --- /dev/null +++ b/src/hooks/useLanguageModelsByProviders.ts @@ -0,0 +1,19 @@ +import { useQuery } from "@tanstack/react-query"; +import { IpcClient } from "@/ipc/ipc_client"; +import type { LanguageModel } from "@/ipc/ipc_types"; + +/** + * Fetches all available language models grouped by their provider IDs. + * + * @returns TanStack Query result object for the language models organized by provider. + */ +export function useLanguageModelsByProviders() { + const ipcClient = IpcClient.getInstance(); + + return useQuery, Error>({ + queryKey: ["language-models-by-providers"], + queryFn: async () => { + return ipcClient.getLanguageModelsByProviders(); + }, + }); +} diff --git a/src/ipc/handlers/language_model_handlers.ts b/src/ipc/handlers/language_model_handlers.ts index 296d5c5..729fd54 100644 --- a/src/ipc/handlers/language_model_handlers.ts +++ b/src/ipc/handlers/language_model_handlers.ts @@ -9,6 +9,7 @@ import log from "electron-log"; import { getLanguageModelProviders, getLanguageModels, + getLanguageModelsByProviders, } from "../shared/language_model_helpers"; import { db } from "@/db"; import { @@ -64,7 +65,8 @@ export function registerLanguageModelHandlers() { // Insert the new provider await db.insert(languageModelProvidersSchema).values({ - id, + // Make sure we will never have accidental collisions with builtin providers + id: "custom::" + id, name, api_base_url: apiBaseUrl, env_var_name: envVarName || null, @@ -108,12 +110,8 @@ export function registerLanguageModelHandlers() { } // Check if provider exists - const provider = db - .select() - .from(languageModelProvidersSchema) - .where(eq(languageModelProvidersSchema.id, providerId)) - .get(); - + const providers = await getLanguageModelProviders(); + const provider = providers.find((p) => p.id === providerId); if (!provider) { throw new Error(`Provider with ID "${providerId}" not found`); } @@ -122,7 +120,8 @@ export function registerLanguageModelHandlers() { await db.insert(languageModelsSchema).values({ displayName, apiName, - provider_id: providerId, + builtinProviderId: provider.type === "cloud" ? providerId : undefined, + customProviderId: provider.type === "custom" ? providerId : undefined, description: description || null, max_output_tokens: maxOutputTokens || null, context_window: contextWindow || null, @@ -182,11 +181,22 @@ export function registerLanguageModelHandlers() { `Attempting to delete custom model ${modelApiName} for provider ${providerId}`, ); + const providers = await getLanguageModelProviders(); + const provider = providers.find((p) => p.id === providerId); + if (!provider) { + throw new Error(`Provider with ID "${providerId}" not found`); + } + if (provider.type === "local") { + throw new Error("Local models cannot be deleted"); + } const result = db .delete(language_models) .where( and( - eq(language_models.provider_id, providerId), + provider.type === "cloud" + ? eq(language_models.builtinProviderId, providerId) + : eq(language_models.customProviderId, providerId), + eq(language_models.apiName, modelApiName), ), ) @@ -243,7 +253,7 @@ export function registerLanguageModelHandlers() { // 1. Delete associated models const deleteModelsResult = await tx .delete(languageModelsSchema) - .where(eq(languageModelsSchema.provider_id, providerId)) + .where(eq(languageModelsSchema.customProviderId, providerId)) .run(); logger.info( `Deleted ${deleteModelsResult.changes} model(s) associated with provider ${providerId}`, @@ -279,7 +289,26 @@ export function registerLanguageModelHandlers() { if (!params || typeof params.providerId !== "string") { throw new Error("Invalid parameters: providerId (string) is required."); } - return getLanguageModels({ providerId: params.providerId }); + const providers = await getLanguageModelProviders(); + const provider = providers.find((p) => p.id === params.providerId); + if (!provider) { + throw new Error(`Provider with ID "${params.providerId}" not found`); + } + if (provider.type === "local") { + throw new Error("Local models cannot be fetched"); + } + return getLanguageModels( + provider.type === "cloud" + ? { builtinProviderId: params.providerId } + : { customProviderId: params.providerId }, + ); + }, + ); + + handle( + "get-language-models-by-providers", + async (): Promise> => { + return getLanguageModelsByProviders(); }, ); } diff --git a/src/ipc/ipc_client.ts b/src/ipc/ipc_client.ts index 24f5192..542628d 100644 --- a/src/ipc/ipc_client.ts +++ b/src/ipc/ipc_client.ts @@ -746,6 +746,12 @@ export class IpcClient { return this.ipcRenderer.invoke("get-language-models", params); } + public async getLanguageModelsByProviders(): Promise< + Record + > { + return this.ipcRenderer.invoke("get-language-models-by-providers"); + } + public async createCustomLanguageModelProvider({ id, name, diff --git a/src/ipc/shared/language_model_helpers.ts b/src/ipc/shared/language_model_helpers.ts index f53b4fd..310336e 100644 --- a/src/ipc/shared/language_model_helpers.ts +++ b/src/ipc/shared/language_model_helpers.ts @@ -139,23 +139,72 @@ export async function getLanguageModelProviders(): Promise< * @param obj An object containing the providerId. * @returns A promise that resolves to an array of LanguageModel objects. */ -export async function getLanguageModels(obj: { - providerId: string; -}): Promise { - const { providerId } = obj; +export async function getLanguageModels( + obj: + | { + customProviderId: string; + // builtinProviderId?: undefined; + } + | { + builtinProviderId: string; + // customProviderId?: undefined; + }, +): Promise { const allProviders = await getLanguageModelProviders(); - const provider = allProviders.find((p) => p.id === providerId); + const provider = allProviders.find( + (p) => + p.id === (obj as { customProviderId: string }).customProviderId || + p.id === (obj as { builtinProviderId: string }).builtinProviderId, + ); if (!provider) { - console.warn(`Provider with ID "${providerId}" not found.`); + console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`); return []; } + // Get custom models from DB for all provider types + let customModels: LanguageModel[] = []; + + try { + const customModelsDb = await db + .select({ + id: languageModelsSchema.id, + displayName: languageModelsSchema.displayName, + apiName: languageModelsSchema.apiName, + description: languageModelsSchema.description, + maxOutputTokens: languageModelsSchema.max_output_tokens, + contextWindow: languageModelsSchema.context_window, + }) + .from(languageModelsSchema) + .where( + "customProviderId" in obj + ? eq(languageModelsSchema.customProviderId, obj.customProviderId) + : eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId), + ); + + customModels = customModelsDb.map((model) => ({ + ...model, + description: model.description ?? "", + tag: undefined, + maxOutputTokens: model.maxOutputTokens ?? undefined, + contextWindow: model.contextWindow ?? undefined, + type: "custom", + })); + } catch (error) { + console.error( + `Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`, + error, + ); + // Continue with empty custom models array + } + + // If it's a cloud provider, also get the hardcoded models + let hardcodedModels: LanguageModel[] = []; + const providerId = provider.id; if (provider.type === "cloud") { - // Check if providerId is a valid key for MODEL_OPTIONS if (providerId in MODEL_OPTIONS) { const models = MODEL_OPTIONS[providerId as RegularModelProvider] || []; - return models.map((model) => ({ + hardcodedModels = models.map((model) => ({ ...model, apiName: model.name, type: "cloud", @@ -164,49 +213,54 @@ export async function getLanguageModels(obj: { console.warn( `Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`, ); - return []; } - } else if (provider.type === "custom") { - // Fetch models from the database for this custom provider - // Assuming a language_models table with necessary columns and provider_id foreign key - try { - const customModelsDb = await db - .select({ - id: languageModelsSchema.id, - // Map DB columns to LanguageModel fields - displayName: languageModelsSchema.displayName, - apiName: languageModelsSchema.apiName, - // No display_name in DB, use name instead - description: languageModelsSchema.description, - // No tag in DB - maxOutputTokens: languageModelsSchema.max_output_tokens, - contextWindow: languageModelsSchema.context_window, - }) - .from(languageModelsSchema) - .where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available - - return customModelsDb.map((model) => ({ - ...model, - // Ensure possibly null fields are handled, provide defaults or undefined if needed - description: model.description ?? "", - tag: undefined, // No tag for custom models from DB - maxOutputTokens: model.maxOutputTokens ?? undefined, - contextWindow: model.contextWindow ?? undefined, - type: "custom", - })); - } catch (error) { - console.error( - `Error fetching custom models for provider "${providerId}" from DB:`, - error, - ); - // Depending on desired behavior, could throw, return empty, or return a specific error state - return []; - } - } else { - // Handle other types like "local" if necessary, currently ignored - console.warn( - `Provider type "${provider.type}" not handled for model fetching.`, - ); - return []; } + + // Merge the models, with custom models taking precedence over hardcoded ones + const mergedModelsMap = new Map(); + + // Add hardcoded models first + for (const model of hardcodedModels) { + mergedModelsMap.set(model.apiName, model); + } + + // Then override with custom models + for (const model of customModels) { + mergedModelsMap.set(model.apiName, model); + } + + return Array.from(mergedModelsMap.values()); +} + +/** + * Fetches all language models grouped by their provider IDs. + * @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects. + */ +export async function getLanguageModelsByProviders(): Promise< + Record +> { + const providers = await getLanguageModelProviders(); + + // Fetch all models concurrently + const modelPromises = providers + .filter((p) => p.type !== "local") + .map(async (provider) => { + const models = await getLanguageModels( + provider.type === "cloud" + ? { builtinProviderId: provider.id } + : { customProviderId: provider.id }, + ); + return { providerId: provider.id, models }; + }); + + // Wait for all requests to complete + const results = await Promise.all(modelPromises); + + // Convert the array of results to a record + const record: Record = {}; + for (const result of results) { + record[result.providerId] = result.models; + } + + return record; } diff --git a/src/preload.ts b/src/preload.ts index 82ac7c1..77b0e44 100644 --- a/src/preload.ts +++ b/src/preload.ts @@ -6,6 +6,7 @@ import { contextBridge, ipcRenderer } from "electron"; // Whitelist of valid channels const validInvokeChannels = [ "get-language-models", + "get-language-models-by-providers", "create-custom-language-model", "get-language-model-providers", "delete-custom-language-model-provider",