From 477015b43d33c2b57cffa73d06e0c1f1e88fde76 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Mon, 12 May 2025 16:00:16 -0700 Subject: [PATCH] allow creating and listing custom language model (#134) --- .cursor/rules/ipc.mdc | 3 +- src/components/CreateCustomModelDialog.tsx | 199 +++++++++++ .../settings/ApiKeyConfiguration.tsx | 194 +++++++++++ src/components/settings/ModelsSection.tsx | 114 +++++++ .../settings/ProviderSettingsHeader.tsx | 115 +++++++ .../settings/ProviderSettingsPage.tsx | 313 +++--------------- src/hooks/useCustomLanguageModelProvider.ts | 12 +- src/hooks/useLanguageModelsForProvider.ts | 29 ++ src/ipc/handlers/language_model_handlers.ts | 99 +++++- src/ipc/ipc_client.ts | 22 +- src/ipc/ipc_types.ts | 27 ++ src/ipc/shared/language_model_helpers.ts | 87 ++++- src/preload.ts | 2 + 13 files changed, 925 insertions(+), 291 deletions(-) create mode 100644 src/components/CreateCustomModelDialog.tsx create mode 100644 src/components/settings/ApiKeyConfiguration.tsx create mode 100644 src/components/settings/ModelsSection.tsx create mode 100644 src/components/settings/ProviderSettingsHeader.tsx create mode 100644 src/hooks/useLanguageModelsForProvider.ts diff --git a/.cursor/rules/ipc.mdc b/.cursor/rules/ipc.mdc index d8d59e1..18a9505 100644 --- a/.cursor/rules/ipc.mdc +++ b/.cursor/rules/ipc.mdc @@ -7,7 +7,8 @@ You're building an Electron app following good security practices # IPC Structure: -- [ipc_client.ts](mdc:src/ipc/ipc_client.ts) - lives in the renderer process and is used to send IPCs to the main process +- [ipc_client.ts](mdc:src/ipc/ipc_client.ts) - lives in the renderer process and is used to send IPCs to the main process. + - to use it just do `IpcClient.getInstance()` - [preload.ts](mdc:src/preload.ts) - allowlist - [ipc_host.ts](mdc:src/ipc/ipc_host.ts) - contains the various IPC handlers attached which are: [app_handlers.ts](mdc:src/ipc/handlers/app_handlers.ts), [chat_stream_handlers.ts](mdc:src/ipc/handlers/chat_stream_handlers.ts), [settings_handlers.ts](mdc:src/ipc/handlers/settings_handlers.ts) etc. diff --git a/src/components/CreateCustomModelDialog.tsx b/src/components/CreateCustomModelDialog.tsx new file mode 100644 index 0000000..2aa7332 --- /dev/null +++ b/src/components/CreateCustomModelDialog.tsx @@ -0,0 +1,199 @@ +import React, { useState } from "react"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, + DialogFooter, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { IpcClient } from "@/ipc/ipc_client"; +import { useMutation } from "@tanstack/react-query"; +import { showError, showSuccess } from "@/lib/toast"; + +interface CreateCustomModelDialogProps { + isOpen: boolean; + onClose: () => void; + onSuccess: () => void; + providerId: string; +} + +export function CreateCustomModelDialog({ + isOpen, + onClose, + onSuccess, + providerId, +}: CreateCustomModelDialogProps) { + const [id, setId] = useState(""); + const [name, setName] = useState(""); + const [description, setDescription] = useState(""); + const [maxOutputTokens, setMaxOutputTokens] = useState(""); + const [contextWindow, setContextWindow] = useState(""); + + const ipcClient = IpcClient.getInstance(); + + const mutation = useMutation({ + mutationFn: async () => { + const params = { + id, + name, + providerId, + description: description || undefined, + maxOutputTokens: maxOutputTokens + ? parseInt(maxOutputTokens, 10) + : undefined, + contextWindow: contextWindow ? parseInt(contextWindow, 10) : undefined, + }; + + if (!params.id) throw new Error("Model ID is required"); + if (!params.name) throw new Error("Model Name is required"); + if (maxOutputTokens && isNaN(params.maxOutputTokens ?? NaN)) + throw new Error("Max Output Tokens must be a valid number"); + if (contextWindow && isNaN(params.contextWindow ?? NaN)) + throw new Error("Context Window must be a valid number"); + + await ipcClient.createCustomLanguageModel(params); + }, + onSuccess: () => { + showSuccess("Custom model created successfully!"); + resetForm(); + onSuccess(); // Refetch or update UI + onClose(); + }, + onError: (error) => { + showError(error); + }, + }); + + const resetForm = () => { + setId(""); + setName(""); + setDescription(""); + setMaxOutputTokens(""); + setContextWindow(""); + }; + + const handleSubmit = (e: React.FormEvent) => { + e.preventDefault(); + mutation.mutate(); + }; + + const handleClose = () => { + if (!mutation.isPending) { + resetForm(); + onClose(); + } + }; + + return ( + + + + Add Custom Model + + Configure a new language model for the selected provider. + + +
+
+
+ + ) => + setId(e.target.value) + } + className="col-span-3" + placeholder="This must match the model expected by the API" + required + disabled={mutation.isPending} + /> +
+
+ + ) => + setName(e.target.value) + } + className="col-span-3" + placeholder="Human-friendly name for the model" + required + disabled={mutation.isPending} + /> +
+
+ + ) => + setDescription(e.target.value) + } + className="col-span-3" + placeholder="Optional: Describe the model's capabilities" + disabled={mutation.isPending} + /> +
+
+ + ) => + setMaxOutputTokens(e.target.value) + } + className="col-span-3" + placeholder="Optional: e.g., 4096" + disabled={mutation.isPending} + /> +
+
+ + ) => + setContextWindow(e.target.value) + } + className="col-span-3" + placeholder="Optional: e.g., 8192" + disabled={mutation.isPending} + /> +
+
+ + + + +
+
+
+ ); +} diff --git a/src/components/settings/ApiKeyConfiguration.tsx b/src/components/settings/ApiKeyConfiguration.tsx new file mode 100644 index 0000000..f26e515 --- /dev/null +++ b/src/components/settings/ApiKeyConfiguration.tsx @@ -0,0 +1,194 @@ +import { Info, KeyRound, Trash2 } from "lucide-react"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { + Accordion, + AccordionContent, + AccordionItem, + AccordionTrigger, +} from "@/components/ui/accordion"; +import { Input } from "@/components/ui/input"; +import { Button } from "@/components/ui/button"; +import { UserSettings } from "@/lib/schemas"; + +// Helper function to mask ENV API keys (move or duplicate if needed elsewhere) +const maskEnvApiKey = (key: string | undefined): string => { + if (!key) return "Not Set"; + if (key.length < 8) return "****"; + return `${key.substring(0, 4)}...${key.substring(key.length - 4)}`; +}; + +interface ApiKeyConfigurationProps { + provider: string; + providerDisplayName: string; + settings: UserSettings | null | undefined; + envVars: Record; + envVarName?: string; + isSaving: boolean; + saveError: string | null; + apiKeyInput: string; + onApiKeyInputChange: (value: string) => void; + onSaveKey: () => Promise; + onDeleteKey: () => Promise; + isDyad: boolean; +} + +export function ApiKeyConfiguration({ + provider, + providerDisplayName, + settings, + envVars, + envVarName, + isSaving, + saveError, + apiKeyInput, + onApiKeyInputChange, + onSaveKey, + onDeleteKey, + isDyad, +}: ApiKeyConfigurationProps) { + const envApiKey = envVarName ? envVars[envVarName] : undefined; + const userApiKey = settings?.providerSettings?.[provider]?.apiKey?.value; + + const isValidUserKey = + !!userApiKey && + !userApiKey.startsWith("Invalid Key") && + userApiKey !== "Not Set"; + const hasEnvKey = !!envApiKey; + + const activeKeySource = isValidUserKey + ? "settings" + : hasEnvKey + ? "env" + : "none"; + + const defaultAccordionValue = []; + if (isValidUserKey || !hasEnvKey) { + defaultAccordionValue.push("settings-key"); + } + if (!isDyad && hasEnvKey) { + defaultAccordionValue.push("env-key"); + } + + return ( + + + + API Key from Settings + + + {isValidUserKey && ( + + + + Current Key (Settings) + + + +

{userApiKey}

+ {activeKeySource === "settings" && ( +

+ This key is currently active. +

+ )} +
+
+ )} + +
+ +
+ onApiKeyInputChange(e.target.value)} + placeholder={`Enter new ${providerDisplayName} API Key here`} + className={`flex-grow ${saveError ? "border-red-500" : ""}`} + /> + +
+ {saveError &&

{saveError}

} +

+ Setting a key here will override the environment variable (if + set). +

+
+
+
+ + {!isDyad && envVarName && ( + + + API Key from Environment Variable + + + {hasEnvKey ? ( + + + Environment Variable Key ({envVarName}) + +

+ {maskEnvApiKey(envApiKey)} +

+ {activeKeySource === "env" && ( +

+ This key is currently active (no settings key set). +

+ )} + {activeKeySource === "settings" && ( +

+ This key is currently being overridden by the key set in + Settings. +

+ )} +
+
+ ) : ( + + + Environment Variable Not Set + + The{" "} + + {envVarName} + {" "} + environment variable is not set. + + + )} +

+ This key is set outside the application. If present, it will be + used only if no key is configured in the Settings section above. + Requires app restart to detect changes. +

+
+
+ )} +
+ ); +} diff --git a/src/components/settings/ModelsSection.tsx b/src/components/settings/ModelsSection.tsx new file mode 100644 index 0000000..e555f30 --- /dev/null +++ b/src/components/settings/ModelsSection.tsx @@ -0,0 +1,114 @@ +import { useState } from "react"; +import { AlertTriangle, PlusIcon } from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { Skeleton } from "@/components/ui/skeleton"; +import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; +import { CreateCustomModelDialog } from "@/components/CreateCustomModelDialog"; +import { useLanguageModelsForProvider } from "@/hooks/useLanguageModelsForProvider"; // Use the hook directly here + +interface ModelsSectionProps { + providerId: string; +} + +export function ModelsSection({ providerId }: ModelsSectionProps) { + const [isCustomModelDialogOpen, setIsCustomModelDialogOpen] = useState(false); + + // Fetch custom models within this component now + const { + data: models, + isLoading: modelsLoading, + error: modelsError, + refetch: refetchModels, + } = useLanguageModelsForProvider(providerId); + + return ( +
+

Models

+

+ Manage specific models available through this provider. +

+ + {/* Custom Models List Area */} + {modelsLoading && ( +
+ + +
+ )} + {modelsError && ( + + + Error Loading Models + {modelsError.message} + + )} + {!modelsLoading && !modelsError && models && models.length > 0 && ( +
+ {models.map((model) => ( +
+
+

+ {model.displayName} +

+ {/* Optional: Add an edit/delete button here later */} +
+

+ {model.name} +

+ {model.description && ( +

+ {model.description} +

+ )} +
+ {model.contextWindow && ( + + Context: {model.contextWindow.toLocaleString()} tokens + + )} + {model.maxOutputTokens && ( + + Max Output: {model.maxOutputTokens.toLocaleString()} tokens + + )} +
+ {model.tag && ( + + {model.tag} + + )} +
+ ))} +
+ )} + {!modelsLoading && !modelsError && (!models || models.length === 0) && ( +

+ No custom models have been added for this provider yet. +

+ )} + {/* End Custom Models List Area */} + + + + {/* Render the dialog */} + setIsCustomModelDialogOpen(false)} + onSuccess={() => { + setIsCustomModelDialogOpen(false); + refetchModels(); // Refetch models on success + }} + providerId={providerId} + /> +
+ ); +} diff --git a/src/components/settings/ProviderSettingsHeader.tsx b/src/components/settings/ProviderSettingsHeader.tsx new file mode 100644 index 0000000..9b90e7d --- /dev/null +++ b/src/components/settings/ProviderSettingsHeader.tsx @@ -0,0 +1,115 @@ +import { + ArrowLeft, + Circle, + ExternalLink, + GiftIcon, + KeyRound, + Settings as SettingsIcon, +} from "lucide-react"; +import { Button } from "@/components/ui/button"; +import { Skeleton } from "@/components/ui/skeleton"; +import { IpcClient } from "@/ipc/ipc_client"; + +interface ProviderSettingsHeaderProps { + providerDisplayName: string; + isConfigured: boolean; + isLoading: boolean; + hasFreeTier?: boolean; + providerWebsiteUrl?: string; + isDyad: boolean; + onBackClick: () => void; +} + +function getKeyButtonText({ + isConfigured, + isDyad, +}: { + isConfigured: boolean; + isDyad: boolean; +}) { + if (isDyad) { + return isConfigured + ? "Manage Dyad Pro Subscription" + : "Setup Dyad Pro Subscription"; + } + return isConfigured ? "Manage API Keys" : "Setup API Key"; +} + +export function ProviderSettingsHeader({ + providerDisplayName, + isConfigured, + isLoading, + hasFreeTier, + providerWebsiteUrl, + isDyad, + onBackClick, +}: ProviderSettingsHeaderProps) { + const handleGetApiKeyClick = (e: React.MouseEvent) => { + e.preventDefault(); + if (providerWebsiteUrl) { + IpcClient.getInstance().openExternalUrl(providerWebsiteUrl); + } + }; + + return ( + <> + + +
+
+

+ Configure {providerDisplayName} +

+ {isLoading ? ( + + ) : ( + + )} + + {isLoading + ? "Loading..." + : isConfigured + ? "Setup Complete" + : "Not Setup"} + +
+ {!isLoading && hasFreeTier && ( + + + Free tier available + + )} +
+ + {providerWebsiteUrl && !isLoading && ( + + )} + + ); +} diff --git a/src/components/settings/ProviderSettingsPage.tsx b/src/components/settings/ProviderSettingsPage.tsx index db05990..32df229 100644 --- a/src/components/settings/ProviderSettingsPage.tsx +++ b/src/components/settings/ProviderSettingsPage.tsx @@ -1,44 +1,26 @@ import { useState, useEffect } from "react"; import { useRouter } from "@tanstack/react-router"; -import { - ArrowLeft, - ExternalLink, - KeyRound, - Info, - Circle, - Settings as SettingsIcon, - GiftIcon, - Trash2, - AlertTriangle, -} from "lucide-react"; +import { ArrowLeft, AlertTriangle } from "lucide-react"; import { useSettings } from "@/hooks/useSettings"; import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders"; + import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert"; import { Skeleton } from "@/components/ui/skeleton"; -import { - Accordion, - AccordionContent, - AccordionItem, - AccordionTrigger, -} from "@/components/ui/accordion"; -import { Input } from "@/components/ui/input"; +import {} from "@/components/ui/accordion"; + import { Button } from "@/components/ui/button"; -import { IpcClient } from "@/ipc/ipc_client"; import { Switch } from "@/components/ui/switch"; import { showError } from "@/lib/toast"; import { UserSettings } from "@/lib/schemas"; +import { ProviderSettingsHeader } from "./ProviderSettingsHeader"; +import { ApiKeyConfiguration } from "./ApiKeyConfiguration"; +import { ModelsSection } from "./ModelsSection"; + interface ProviderSettingsPageProps { provider: string; } -// Helper function to mask ENV API keys (still needed for env vars) -const maskEnvApiKey = (key: string | undefined): string => { - if (!key) return "Not Set"; - if (key.length < 8) return "****"; - return `${key.substring(0, 4)}...${key.substring(key.length - 4)}`; -}; - export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { const { settings, @@ -55,6 +37,11 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { error: providersError, } = useLanguageModelProviders(); + // Find the specific provider data from the fetched list + const providerData = allProviders?.find((p) => p.id === provider); + const supportsCustomModels = + providerData?.type === "custom" || providerData?.type === "cloud"; + const isDyad = provider === "auto"; const [apiKeyInput, setApiKeyInput] = useState(""); @@ -62,9 +49,6 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { const [saveError, setSaveError] = useState(null); const router = useRouter(); - // Find the specific provider data from the fetched list - const providerData = allProviders?.find((p) => p.id === provider); - // Use fetched data (or defaults for Dyad) const providerDisplayName = isDyad ? "Dyad" @@ -74,7 +58,6 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { : providerData?.websiteUrl; const hasFreeTier = isDyad ? false : providerData?.hasFreeTier; const envVarName = isDyad ? undefined : providerData?.envVarName; - const envApiKey = envVarName ? envVars[envVarName] : undefined; // Use provider ID (which is the 'provider' prop) const userApiKey = settings?.providerSettings?.[provider]?.apiKey?.value; @@ -84,25 +67,9 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { !!userApiKey && !userApiKey.startsWith("Invalid Key") && userApiKey !== "Not Set"; - const hasEnvKey = !!envApiKey; + const hasEnvKey = !!(envVarName && envVars[envVarName]); const isConfigured = isValidUserKey || hasEnvKey; // Configured if either is set - // Settings key takes precedence if it's valid - const activeKeySource = isValidUserKey - ? "settings" - : hasEnvKey - ? "env" - : "none"; - - // --- Accordion Logic --- - const defaultAccordionValue = []; - if (isValidUserKey || !hasEnvKey) { - // If user key is set OR env key is NOT set, open the settings accordion item - defaultAccordionValue.push("settings-key"); - } - if (hasEnvKey) { - defaultAccordionValue.push("env-key"); - } // --- Save Handler --- const handleSaveKey = async () => { @@ -182,24 +149,23 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { } }, [apiKeyInput]); - // --- Loading State for Providers --- (Added) + // --- Loading State for Providers --- if (providersLoading) { return (
- {/* Back button */} - {/* Title */} - {/* Get Key button */} -
- - + + + +
+
); } - // --- Error State for Providers --- (Added) + // --- Error State for Providers --- if (providersError) { return (
@@ -260,71 +226,19 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { return (
- - -
-
-

- Configure {providerDisplayName} -

- {settingsLoading ? ( - - ) : ( - - )} - - {settingsLoading - ? "Loading..." - : isConfigured - ? "Setup Complete" - : "Not Setup"} - -
- {!settingsLoading && hasFreeTier && ( - - - Free tier available - - )} -
- - {providerWebsiteUrl && !settingsLoading && ( - - )} + router.history.back()} + /> {settingsLoading ? (
- - +
) : settingsError ? ( @@ -334,136 +248,20 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { ) : ( - - - - API Key from Settings - - - {isValidUserKey && ( - - - - Current Key (Settings) - - - -

{userApiKey}

- {activeKeySource === "settings" && ( -

- This key is currently active. -

- )} -
-
- )} - -
- -
- setApiKeyInput(e.target.value)} - placeholder={`Enter new ${providerDisplayName} API Key here`} - className={`flex-grow ${ - saveError ? "border-red-500" : "" - }`} - /> - -
- {saveError && ( -

{saveError}

- )} -

- Setting a key here will override the environment variable - (if set). -

-
-
-
- - {!isDyad && envVarName && ( - - - API Key from Environment Variable - - - {hasEnvKey ? ( - - - - Environment Variable Key ({envVarName}) - - -

- {maskEnvApiKey(envApiKey)} -

- {activeKeySource === "env" && ( -

- This key is currently active (no settings key set). -

- )} - {activeKeySource === "settings" && ( -

- This key is currently being overridden by the key - set in Settings. -

- )} -
-
- ) : ( - - - Environment Variable Not Set - - The{" "} - - {envVarName} - {" "} - environment variable is not set. - - - )} -

- This key is set outside the application. If present, it will - be used only if no key is configured in the Settings section - above. Requires app restart to detect changes. -

-
-
- )} -
+ )} {isDyad && !settingsLoading && ( @@ -481,22 +279,13 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { />
)} + + {/* Conditionally render CustomModelsSection */} + {supportsCustomModels && providerData && ( + + )} +
); } - -function getKeyButtonText({ - isConfigured, - isDyad, -}: { - isConfigured: boolean; - isDyad: boolean; -}) { - if (isDyad) { - return isConfigured - ? "Manage Dyad Pro Subscription" - : "Setup Dyad Pro Subscription"; - } - return isConfigured ? "Manage API Keys" : "Setup API Key"; -} diff --git a/src/hooks/useCustomLanguageModelProvider.ts b/src/hooks/useCustomLanguageModelProvider.ts index d8c00f4..35ddbd5 100644 --- a/src/hooks/useCustomLanguageModelProvider.ts +++ b/src/hooks/useCustomLanguageModelProvider.ts @@ -1,15 +1,11 @@ import { useMutation, useQueryClient } from "@tanstack/react-query"; import { IpcClient } from "@/ipc/ipc_client"; -import type { LanguageModelProvider } from "@/ipc/ipc_types"; +import type { + CreateCustomLanguageModelProviderParams, + LanguageModelProvider, +} from "@/ipc/ipc_types"; import { showError } from "@/lib/toast"; -export interface CreateCustomLanguageModelProviderParams { - id: string; - name: string; - apiBaseUrl: string; - envVarName?: string; -} - export function useCustomLanguageModelProvider() { const queryClient = useQueryClient(); const ipcClient = IpcClient.getInstance(); diff --git a/src/hooks/useLanguageModelsForProvider.ts b/src/hooks/useLanguageModelsForProvider.ts new file mode 100644 index 0000000..9840eae --- /dev/null +++ b/src/hooks/useLanguageModelsForProvider.ts @@ -0,0 +1,29 @@ +import { useQuery } from "@tanstack/react-query"; +import { IpcClient } from "@/ipc/ipc_client"; +import type { LanguageModel } from "@/ipc/ipc_types"; + +/** + * Fetches the list of available language models for a specific provider. + * + * @param providerId The ID of the language model provider. + * @returns TanStack Query result object for the language models. + */ +export function useLanguageModelsForProvider(providerId: string | undefined) { + const ipcClient = IpcClient.getInstance(); + + return useQuery< + LanguageModel[], + Error // Specify Error type for better error handling + >({ + queryKey: ["language-models", providerId], + queryFn: async () => { + if (!providerId) { + // Avoid calling IPC if providerId is not set + // Return an empty array as it's a query, not an error state + return []; + } + return ipcClient.getLanguageModels({ providerId }); + }, + enabled: !!providerId, + }); +} diff --git a/src/ipc/handlers/language_model_handlers.ts b/src/ipc/handlers/language_model_handlers.ts index 0c45725..83b69d4 100644 --- a/src/ipc/handlers/language_model_handlers.ts +++ b/src/ipc/handlers/language_model_handlers.ts @@ -1,22 +1,26 @@ -import type { LanguageModelProvider } from "@/ipc/ipc_types"; +import type { + LanguageModelProvider, + LanguageModel, + CreateCustomLanguageModelProviderParams, + CreateCustomLanguageModelParams, +} from "@/ipc/ipc_types"; import { createLoggedHandler } from "./safe_handle"; import log from "electron-log"; -import { getLanguageModelProviders } from "../shared/language_model_helpers"; +import { + getLanguageModelProviders, + getLanguageModels, +} from "../shared/language_model_helpers"; import { db } from "@/db"; -import { language_model_providers as languageModelProvidersSchema } from "@/db/schema"; +import { + language_model_providers as languageModelProvidersSchema, + language_models as languageModelsSchema, +} from "@/db/schema"; import { eq } from "drizzle-orm"; import { IpcMainInvokeEvent } from "electron"; const logger = log.scope("language_model_handlers"); const handle = createLoggedHandler(logger); -export interface CreateCustomLanguageModelProviderParams { - id: string; - name: string; - apiBaseUrl: string; - envVarName?: string; -} - export function registerLanguageModelHandlers() { handle( "get-language-model-providers", @@ -47,7 +51,7 @@ export function registerLanguageModelHandlers() { } // Check if a provider with this ID already exists - const existingProvider = await db + const existingProvider = db .select() .from(languageModelProvidersSchema) .where(eq(languageModelProvidersSchema.id, id)) @@ -75,4 +79,77 @@ export function registerLanguageModelHandlers() { }; }, ); + + handle( + "create-custom-language-model", + async ( + event: IpcMainInvokeEvent, + params: CreateCustomLanguageModelParams, + ): Promise => { + const { + id, + name, + providerId, + description, + maxOutputTokens, + contextWindow, + } = params; + + // Validation + if (!id) { + throw new Error("Model ID is required"); + } + if (!name) { + throw new Error("Model name is required"); + } + if (!providerId) { + throw new Error("Provider ID is required"); + } + + // Check if provider exists + const provider = db + .select() + .from(languageModelProvidersSchema) + .where(eq(languageModelProvidersSchema.id, providerId)) + .get(); + + if (!provider) { + throw new Error(`Provider with ID "${providerId}" not found`); + } + + // Check if model ID already exists + const existingModel = db + .select() + .from(languageModelsSchema) + .where(eq(languageModelsSchema.id, id)) + .get(); + + if (existingModel) { + throw new Error(`A model with ID "${id}" already exists`); + } + + // Insert the new model + await db.insert(languageModelsSchema).values({ + id, + name, + provider_id: providerId, + description: description || null, + max_output_tokens: maxOutputTokens || null, + context_window: contextWindow || null, + }); + }, + ); + + handle( + "get-language-models", + async ( + event: IpcMainInvokeEvent, + params: { providerId: string }, + ): Promise => { + if (!params || typeof params.providerId !== "string") { + throw new Error("Invalid parameters: providerId (string) is required."); + } + return getLanguageModels({ providerId: params.providerId }); + }, + ); } diff --git a/src/ipc/ipc_client.ts b/src/ipc/ipc_client.ts index 49817a0..f8a0f87 100644 --- a/src/ipc/ipc_client.ts +++ b/src/ipc/ipc_client.ts @@ -22,6 +22,9 @@ import type { ChatLogsData, BranchResult, LanguageModelProvider, + LanguageModel, + CreateCustomLanguageModelProviderParams, + CreateCustomLanguageModelParams, } from "./ipc_types"; import type { ProposalResult } from "@/lib/schemas"; import { showError } from "@/lib/toast"; @@ -732,17 +735,18 @@ export class IpcClient { return this.ipcRenderer.invoke("get-language-model-providers"); } + public async getLanguageModels(params: { + providerId: string; + }): Promise { + return this.ipcRenderer.invoke("get-language-models", params); + } + public async createCustomLanguageModelProvider({ id, name, apiBaseUrl, envVarName, - }: { - id: string; - name: string; - apiBaseUrl: string; - envVarName?: string; - }): Promise { + }: CreateCustomLanguageModelProviderParams): Promise { return this.ipcRenderer.invoke("create-custom-language-model-provider", { id, name, @@ -751,5 +755,11 @@ export class IpcClient { }); } + public async createCustomLanguageModel( + params: CreateCustomLanguageModelParams, + ): Promise { + await this.ipcRenderer.invoke("create-custom-language-model", params); + } + // --- End window control methods --- } diff --git a/src/ipc/ipc_types.ts b/src/ipc/ipc_types.ts index 29ea347..d7863c3 100644 --- a/src/ipc/ipc_types.ts +++ b/src/ipc/ipc_types.ts @@ -143,3 +143,30 @@ export interface LanguageModelProvider { apiBaseUrl?: string; type: "custom" | "local" | "cloud"; } + +export interface LanguageModel { + id: string; + name: string; + displayName: string; + description: string; + tag?: string; + maxOutputTokens?: number; + contextWindow?: number; + type: "local" | "cloud" | "custom"; +} + +export interface CreateCustomLanguageModelProviderParams { + id: string; + name: string; + apiBaseUrl: string; + envVarName?: string; +} + +export interface CreateCustomLanguageModelParams { + id: string; + name: string; + providerId: string; + description?: string; + maxOutputTokens?: number; + contextWindow?: number; +} diff --git a/src/ipc/shared/language_model_helpers.ts b/src/ipc/shared/language_model_helpers.ts index 2e2699d..a9e78fe 100644 --- a/src/ipc/shared/language_model_helpers.ts +++ b/src/ipc/shared/language_model_helpers.ts @@ -1,7 +1,11 @@ import { db } from "@/db"; -import { language_model_providers as languageModelProvidersSchema } from "@/db/schema"; -import { RegularModelProvider } from "@/constants/models"; -import type { LanguageModelProvider } from "@/ipc/ipc_types"; +import { + language_model_providers as languageModelProvidersSchema, + language_models as languageModelsSchema, +} from "@/db/schema"; +import { MODEL_OPTIONS, RegularModelProvider } from "@/constants/models"; +import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types"; +import { eq } from "drizzle-orm"; export const PROVIDER_TO_ENV_VAR: Record = { openai: "OPENAI_API_KEY", @@ -129,3 +133,80 @@ export async function getLanguageModelProviders(): Promise< return Array.from(mergedProvidersMap.values()); } + +/** + * Fetches language models for a specific provider. + * @param obj An object containing the providerId. + * @returns A promise that resolves to an array of LanguageModel objects. + */ +export async function getLanguageModels(obj: { + providerId: string; +}): Promise { + const { providerId } = obj; + const allProviders = await getLanguageModelProviders(); + const provider = allProviders.find((p) => p.id === providerId); + + if (!provider) { + console.warn(`Provider with ID "${providerId}" not found.`); + return []; + } + + if (provider.type === "cloud") { + // Check if providerId is a valid key for MODEL_OPTIONS + if (providerId in MODEL_OPTIONS) { + const models = MODEL_OPTIONS[providerId as RegularModelProvider] || []; + return models.map((model) => ({ + ...model, + id: model.name, + type: "cloud", + })); + } else { + console.warn( + `Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`, + ); + return []; + } + } else if (provider.type === "custom") { + // Fetch models from the database for this custom provider + // Assuming a language_models table with necessary columns and provider_id foreign key + try { + const customModelsDb = await db + .select({ + id: languageModelsSchema.id, + // Map DB columns to LanguageModel fields + name: languageModelsSchema.name, + // No display_name in DB, use name instead + description: languageModelsSchema.description, + // No tag in DB + maxOutputTokens: languageModelsSchema.max_output_tokens, + contextWindow: languageModelsSchema.context_window, + }) + .from(languageModelsSchema) + .where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available + + return customModelsDb.map((model) => ({ + ...model, + displayName: model.name, // Use name as displayName for custom models + // Ensure possibly null fields are handled, provide defaults or undefined if needed + description: model.description ?? "", + tag: undefined, // No tag for custom models from DB + maxOutputTokens: model.maxOutputTokens ?? undefined, + contextWindow: model.contextWindow ?? undefined, + type: "custom", + })); + } catch (error) { + console.error( + `Error fetching custom models for provider "${providerId}" from DB:`, + error, + ); + // Depending on desired behavior, could throw, return empty, or return a specific error state + return []; + } + } else { + // Handle other types like "local" if necessary, currently ignored + console.warn( + `Provider type "${provider.type}" not handled for model fetching.`, + ); + return []; + } +} diff --git a/src/preload.ts b/src/preload.ts index 872dadf..b22ad26 100644 --- a/src/preload.ts +++ b/src/preload.ts @@ -5,6 +5,8 @@ import { contextBridge, ipcRenderer } from "electron"; // Whitelist of valid channels const validInvokeChannels = [ + "get-language-models", + "create-custom-language-model", "get-language-model-providers", "create-custom-language-model-provider", "chat:add-dep",