diff --git a/src/components/CreateCustomProviderDialog.tsx b/src/components/CreateCustomProviderDialog.tsx new file mode 100644 index 0000000..d8dcd19 --- /dev/null +++ b/src/components/CreateCustomProviderDialog.tsx @@ -0,0 +1,167 @@ +import React, { useState } from "react"; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, +} from "@/components/ui/dialog"; +import { Button } from "@/components/ui/button"; +import { Input } from "@/components/ui/input"; +import { Label } from "@/components/ui/label"; +import { Loader2 } from "lucide-react"; +import { useCustomLanguageModelProvider } from "@/hooks/useCustomLanguageModelProvider"; + +interface CreateCustomProviderDialogProps { + isOpen: boolean; + onClose: () => void; + onSuccess: () => void; +} + +export function CreateCustomProviderDialog({ + isOpen, + onClose, + onSuccess, +}: CreateCustomProviderDialogProps) { + const [id, setId] = useState(""); + const [name, setName] = useState(""); + const [apiBaseUrl, setApiBaseUrl] = useState(""); + const [envVarName, setEnvVarName] = useState(""); + const [errorMessage, setErrorMessage] = useState(""); + + const { createProvider, isCreating, error } = + useCustomLanguageModelProvider(); + + const handleSubmit = async (e: React.FormEvent) => { + e.preventDefault(); + setErrorMessage(""); + + try { + await createProvider({ + id: id.trim(), + name: name.trim(), + apiBaseUrl: apiBaseUrl.trim(), + envVarName: envVarName.trim() || undefined, + }); + + // Reset form + setId(""); + setName(""); + setApiBaseUrl(""); + setEnvVarName(""); + + onSuccess(); + } catch (error) { + setErrorMessage( + error instanceof Error + ? error.message + : "Failed to create custom provider", + ); + } + }; + + const handleClose = () => { + if (!isCreating) { + setErrorMessage(""); + onClose(); + } + }; + + return ( + + + + Add Custom Provider + + Connect to a custom language model provider API + + + +
+
+ + setId(e.target.value)} + placeholder="E.g., my-provider" + required + disabled={isCreating} + /> +

+ A unique identifier for this provider (no spaces). +

+
+ +
+ + setName(e.target.value)} + placeholder="E.g., My Provider" + required + disabled={isCreating} + /> +

+ The name that will be displayed in the UI. +

+
+ +
+ + setApiBaseUrl(e.target.value)} + placeholder="E.g., https://api.example.com/v1" + required + disabled={isCreating} + /> +

+ The base URL for the API endpoint. +

+
+ +
+ + setEnvVarName(e.target.value)} + placeholder="E.g., MY_PROVIDER_API_KEY" + disabled={isCreating} + /> +

+ Environment variable name for the API key. +

+
+ + {(errorMessage || error) && ( +
+ {errorMessage || + (error instanceof Error + ? error.message + : "Failed to create custom provider")} +
+ )} + +
+ + +
+
+
+
+ ); +} diff --git a/src/components/ProviderSettings.tsx b/src/components/ProviderSettings.tsx index 71e8d60..c757752 100644 --- a/src/components/ProviderSettings.tsx +++ b/src/components/ProviderSettings.tsx @@ -9,19 +9,24 @@ import { providerSettingsRoute } from "@/routes/settings/providers/$provider"; import type { LanguageModelProvider } from "@/ipc/ipc_types"; import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders"; -import { GiftIcon } from "lucide-react"; +import { GiftIcon, PlusIcon } from "lucide-react"; import { Skeleton } from "./ui/skeleton"; import { Alert, AlertDescription, AlertTitle } from "./ui/alert"; import { AlertTriangle } from "lucide-react"; +import { useState } from "react"; + +import { CreateCustomProviderDialog } from "./CreateCustomProviderDialog"; export function ProviderSettingsGrid() { const navigate = useNavigate(); + const [isDialogOpen, setIsDialogOpen] = useState(false); const { data: providers, isLoading, error, isProviderSetup, + refetch, } = useLanguageModelProviders(); const handleProviderClick = (providerId: string) => { @@ -100,7 +105,32 @@ export function ProviderSettingsGrid() { ); })} + + {/* Add custom provider button */} + setIsDialogOpen(true)} + > + + + + Add custom provider + + + Connect to a custom LLM API endpoint + + + + + setIsDialogOpen(false)} + onSuccess={() => { + setIsDialogOpen(false); + refetch(); + }} + /> ); } diff --git a/src/hooks/useCustomLanguageModelProvider.ts b/src/hooks/useCustomLanguageModelProvider.ts new file mode 100644 index 0000000..d8c00f4 --- /dev/null +++ b/src/hooks/useCustomLanguageModelProvider.ts @@ -0,0 +1,58 @@ +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { IpcClient } from "@/ipc/ipc_client"; +import type { LanguageModelProvider } from "@/ipc/ipc_types"; +import { showError } from "@/lib/toast"; + +export interface CreateCustomLanguageModelProviderParams { + id: string; + name: string; + apiBaseUrl: string; + envVarName?: string; +} + +export function useCustomLanguageModelProvider() { + const queryClient = useQueryClient(); + const ipcClient = IpcClient.getInstance(); + + const createProviderMutation = useMutation({ + mutationFn: async ( + params: CreateCustomLanguageModelProviderParams, + ): Promise => { + if (!params.id.trim()) { + throw new Error("Provider ID is required"); + } + if (!params.name.trim()) { + throw new Error("Provider name is required"); + } + if (!params.apiBaseUrl.trim()) { + throw new Error("API base URL is required"); + } + + return ipcClient.createCustomLanguageModelProvider({ + id: params.id.trim(), + name: params.name.trim(), + apiBaseUrl: params.apiBaseUrl.trim(), + envVarName: params.envVarName?.trim() || undefined, + }); + }, + onSuccess: () => { + // Invalidate and refetch + queryClient.invalidateQueries({ queryKey: ["languageModelProviders"] }); + }, + onError: (error) => { + showError(error); + }, + }); + + const createProvider = async ( + params: CreateCustomLanguageModelProviderParams, + ): Promise => { + return createProviderMutation.mutateAsync(params); + }; + + return { + createProvider, + isCreating: createProviderMutation.isPending, + error: createProviderMutation.error, + }; +} diff --git a/src/ipc/handlers/language_model_handlers.ts b/src/ipc/handlers/language_model_handlers.ts index 6765bfc..0c45725 100644 --- a/src/ipc/handlers/language_model_handlers.ts +++ b/src/ipc/handlers/language_model_handlers.ts @@ -2,10 +2,21 @@ import type { LanguageModelProvider } from "@/ipc/ipc_types"; import { createLoggedHandler } from "./safe_handle"; import log from "electron-log"; import { getLanguageModelProviders } from "../shared/language_model_helpers"; +import { db } from "@/db"; +import { language_model_providers as languageModelProvidersSchema } from "@/db/schema"; +import { eq } from "drizzle-orm"; +import { IpcMainInvokeEvent } from "electron"; const logger = log.scope("language_model_handlers"); const handle = createLoggedHandler(logger); +export interface CreateCustomLanguageModelProviderParams { + id: string; + name: string; + apiBaseUrl: string; + envVarName?: string; +} + export function registerLanguageModelHandlers() { handle( "get-language-model-providers", @@ -13,4 +24,55 @@ export function registerLanguageModelHandlers() { return getLanguageModelProviders(); }, ); + + handle( + "create-custom-language-model-provider", + async ( + event: IpcMainInvokeEvent, + params: CreateCustomLanguageModelProviderParams, + ): Promise => { + const { id, name, apiBaseUrl, envVarName } = params; + + // Validation + if (!id) { + throw new Error("Provider ID is required"); + } + + if (!name) { + throw new Error("Provider name is required"); + } + + if (!apiBaseUrl) { + throw new Error("API base URL is required"); + } + + // Check if a provider with this ID already exists + const existingProvider = await db + .select() + .from(languageModelProvidersSchema) + .where(eq(languageModelProvidersSchema.id, id)) + .get(); + + if (existingProvider) { + throw new Error(`A provider with ID "${id}" already exists`); + } + + // Insert the new provider + await db.insert(languageModelProvidersSchema).values({ + id, + name, + api_base_url: apiBaseUrl, + env_var_name: envVarName || null, + }); + + // Return the newly created provider + return { + id, + name, + apiBaseUrl, + envVarName, + type: "custom", + }; + }, + ); } diff --git a/src/ipc/ipc_client.ts b/src/ipc/ipc_client.ts index 0416655..49817a0 100644 --- a/src/ipc/ipc_client.ts +++ b/src/ipc/ipc_client.ts @@ -732,5 +732,24 @@ export class IpcClient { return this.ipcRenderer.invoke("get-language-model-providers"); } + public async createCustomLanguageModelProvider({ + id, + name, + apiBaseUrl, + envVarName, + }: { + id: string; + name: string; + apiBaseUrl: string; + envVarName?: string; + }): Promise { + return this.ipcRenderer.invoke("create-custom-language-model-provider", { + id, + name, + apiBaseUrl, + envVarName, + }); + } + // --- End window control methods --- } diff --git a/src/preload.ts b/src/preload.ts index e304d96..872dadf 100644 --- a/src/preload.ts +++ b/src/preload.ts @@ -6,6 +6,7 @@ import { contextBridge, ipcRenderer } from "electron"; // Whitelist of valid channels const validInvokeChannels = [ "get-language-model-providers", + "create-custom-language-model-provider", "chat:add-dep", "chat:message", "chat:cancel", diff --git a/src/renderer.tsx b/src/renderer.tsx index e504416..c36f8c1 100644 --- a/src/renderer.tsx +++ b/src/renderer.tsx @@ -32,6 +32,9 @@ const queryClient = new QueryClient({ queries: { retry: false, }, + mutations: { + retry: false, + }, }, queryCache: new QueryCache({ onError: (error, query) => {