Allow creating custom providers (#132)

This commit is contained in:
Will Chen
2025-05-12 15:04:42 -07:00
committed by GitHub
parent cd7eaa8ece
commit 642895f0ba
7 changed files with 341 additions and 1 deletions

View File

@@ -0,0 +1,167 @@
import React, { useState } from "react";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
DialogDescription,
} from "@/components/ui/dialog";
import { Button } from "@/components/ui/button";
import { Input } from "@/components/ui/input";
import { Label } from "@/components/ui/label";
import { Loader2 } from "lucide-react";
import { useCustomLanguageModelProvider } from "@/hooks/useCustomLanguageModelProvider";
interface CreateCustomProviderDialogProps {
isOpen: boolean;
onClose: () => void;
onSuccess: () => void;
}
export function CreateCustomProviderDialog({
isOpen,
onClose,
onSuccess,
}: CreateCustomProviderDialogProps) {
const [id, setId] = useState("");
const [name, setName] = useState("");
const [apiBaseUrl, setApiBaseUrl] = useState("");
const [envVarName, setEnvVarName] = useState("");
const [errorMessage, setErrorMessage] = useState("");
const { createProvider, isCreating, error } =
useCustomLanguageModelProvider();
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setErrorMessage("");
try {
await createProvider({
id: id.trim(),
name: name.trim(),
apiBaseUrl: apiBaseUrl.trim(),
envVarName: envVarName.trim() || undefined,
});
// Reset form
setId("");
setName("");
setApiBaseUrl("");
setEnvVarName("");
onSuccess();
} catch (error) {
setErrorMessage(
error instanceof Error
? error.message
: "Failed to create custom provider",
);
}
};
const handleClose = () => {
if (!isCreating) {
setErrorMessage("");
onClose();
}
};
return (
<Dialog open={isOpen} onOpenChange={handleClose}>
<DialogContent className="sm:max-w-md">
<DialogHeader>
<DialogTitle>Add Custom Provider</DialogTitle>
<DialogDescription>
Connect to a custom language model provider API
</DialogDescription>
</DialogHeader>
<form onSubmit={handleSubmit} className="space-y-4 pt-4">
<div className="space-y-2">
<Label htmlFor="id">Provider ID</Label>
<Input
id="id"
value={id}
onChange={(e) => setId(e.target.value)}
placeholder="E.g., my-provider"
required
disabled={isCreating}
/>
<p className="text-xs text-muted-foreground">
A unique identifier for this provider (no spaces).
</p>
</div>
<div className="space-y-2">
<Label htmlFor="name">Display Name</Label>
<Input
id="name"
value={name}
onChange={(e) => setName(e.target.value)}
placeholder="E.g., My Provider"
required
disabled={isCreating}
/>
<p className="text-xs text-muted-foreground">
The name that will be displayed in the UI.
</p>
</div>
<div className="space-y-2">
<Label htmlFor="apiBaseUrl">API Base URL</Label>
<Input
id="apiBaseUrl"
value={apiBaseUrl}
onChange={(e) => setApiBaseUrl(e.target.value)}
placeholder="E.g., https://api.example.com/v1"
required
disabled={isCreating}
/>
<p className="text-xs text-muted-foreground">
The base URL for the API endpoint.
</p>
</div>
<div className="space-y-2">
<Label htmlFor="envVarName">Environment Variable (Optional)</Label>
<Input
id="envVarName"
value={envVarName}
onChange={(e) => setEnvVarName(e.target.value)}
placeholder="E.g., MY_PROVIDER_API_KEY"
disabled={isCreating}
/>
<p className="text-xs text-muted-foreground">
Environment variable name for the API key.
</p>
</div>
{(errorMessage || error) && (
<div className="text-sm text-red-500">
{errorMessage ||
(error instanceof Error
? error.message
: "Failed to create custom provider")}
</div>
)}
<div className="flex justify-end gap-2">
<Button
type="button"
variant="outline"
onClick={handleClose}
disabled={isCreating}
>
Cancel
</Button>
<Button type="submit" disabled={isCreating}>
{isCreating && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
{isCreating ? "Adding..." : "Add Provider"}
</Button>
</div>
</form>
</DialogContent>
</Dialog>
);
}

View File

@@ -9,19 +9,24 @@ import { providerSettingsRoute } from "@/routes/settings/providers/$provider";
import type { LanguageModelProvider } from "@/ipc/ipc_types";
import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders";
import { GiftIcon } from "lucide-react";
import { GiftIcon, PlusIcon } from "lucide-react";
import { Skeleton } from "./ui/skeleton";
import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
import { AlertTriangle } from "lucide-react";
import { useState } from "react";
import { CreateCustomProviderDialog } from "./CreateCustomProviderDialog";
export function ProviderSettingsGrid() {
const navigate = useNavigate();
const [isDialogOpen, setIsDialogOpen] = useState(false);
const {
data: providers,
isLoading,
error,
isProviderSetup,
refetch,
} = useLanguageModelProviders();
const handleProviderClick = (providerId: string) => {
@@ -100,7 +105,32 @@ export function ProviderSettingsGrid() {
</Card>
);
})}
{/* Add custom provider button */}
<Card
className="cursor-pointer transition-all hover:shadow-md border-border border-dashed hover:border-primary/70"
onClick={() => setIsDialogOpen(true)}
>
<CardHeader className="p-4 flex flex-col items-center justify-center h-full">
<PlusIcon className="h-10 w-10 text-muted-foreground mb-2" />
<CardTitle className="text-xl text-center">
Add custom provider
</CardTitle>
<CardDescription className="text-center">
Connect to a custom LLM API endpoint
</CardDescription>
</CardHeader>
</Card>
</div>
<CreateCustomProviderDialog
isOpen={isDialogOpen}
onClose={() => setIsDialogOpen(false)}
onSuccess={() => {
setIsDialogOpen(false);
refetch();
}}
/>
</div>
);
}

View File

@@ -0,0 +1,58 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { IpcClient } from "@/ipc/ipc_client";
import type { LanguageModelProvider } from "@/ipc/ipc_types";
import { showError } from "@/lib/toast";
export interface CreateCustomLanguageModelProviderParams {
id: string;
name: string;
apiBaseUrl: string;
envVarName?: string;
}
export function useCustomLanguageModelProvider() {
const queryClient = useQueryClient();
const ipcClient = IpcClient.getInstance();
const createProviderMutation = useMutation({
mutationFn: async (
params: CreateCustomLanguageModelProviderParams,
): Promise<LanguageModelProvider> => {
if (!params.id.trim()) {
throw new Error("Provider ID is required");
}
if (!params.name.trim()) {
throw new Error("Provider name is required");
}
if (!params.apiBaseUrl.trim()) {
throw new Error("API base URL is required");
}
return ipcClient.createCustomLanguageModelProvider({
id: params.id.trim(),
name: params.name.trim(),
apiBaseUrl: params.apiBaseUrl.trim(),
envVarName: params.envVarName?.trim() || undefined,
});
},
onSuccess: () => {
// Invalidate and refetch
queryClient.invalidateQueries({ queryKey: ["languageModelProviders"] });
},
onError: (error) => {
showError(error);
},
});
const createProvider = async (
params: CreateCustomLanguageModelProviderParams,
): Promise<LanguageModelProvider> => {
return createProviderMutation.mutateAsync(params);
};
return {
createProvider,
isCreating: createProviderMutation.isPending,
error: createProviderMutation.error,
};
}

View File

@@ -2,10 +2,21 @@ import type { LanguageModelProvider } from "@/ipc/ipc_types";
import { createLoggedHandler } from "./safe_handle";
import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers";
import { db } from "@/db";
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
import { eq } from "drizzle-orm";
import { IpcMainInvokeEvent } from "electron";
const logger = log.scope("language_model_handlers");
const handle = createLoggedHandler(logger);
export interface CreateCustomLanguageModelProviderParams {
id: string;
name: string;
apiBaseUrl: string;
envVarName?: string;
}
export function registerLanguageModelHandlers() {
handle(
"get-language-model-providers",
@@ -13,4 +24,55 @@ export function registerLanguageModelHandlers() {
return getLanguageModelProviders();
},
);
handle(
"create-custom-language-model-provider",
async (
event: IpcMainInvokeEvent,
params: CreateCustomLanguageModelProviderParams,
): Promise<LanguageModelProvider> => {
const { id, name, apiBaseUrl, envVarName } = params;
// Validation
if (!id) {
throw new Error("Provider ID is required");
}
if (!name) {
throw new Error("Provider name is required");
}
if (!apiBaseUrl) {
throw new Error("API base URL is required");
}
// Check if a provider with this ID already exists
const existingProvider = await db
.select()
.from(languageModelProvidersSchema)
.where(eq(languageModelProvidersSchema.id, id))
.get();
if (existingProvider) {
throw new Error(`A provider with ID "${id}" already exists`);
}
// Insert the new provider
await db.insert(languageModelProvidersSchema).values({
id,
name,
api_base_url: apiBaseUrl,
env_var_name: envVarName || null,
});
// Return the newly created provider
return {
id,
name,
apiBaseUrl,
envVarName,
type: "custom",
};
},
);
}

View File

@@ -732,5 +732,24 @@ export class IpcClient {
return this.ipcRenderer.invoke("get-language-model-providers");
}
public async createCustomLanguageModelProvider({
id,
name,
apiBaseUrl,
envVarName,
}: {
id: string;
name: string;
apiBaseUrl: string;
envVarName?: string;
}): Promise<LanguageModelProvider> {
return this.ipcRenderer.invoke("create-custom-language-model-provider", {
id,
name,
apiBaseUrl,
envVarName,
});
}
// --- End window control methods ---
}

View File

@@ -6,6 +6,7 @@ import { contextBridge, ipcRenderer } from "electron";
// Whitelist of valid channels
const validInvokeChannels = [
"get-language-model-providers",
"create-custom-language-model-provider",
"chat:add-dep",
"chat:message",
"chat:cancel",

View File

@@ -32,6 +32,9 @@ const queryClient = new QueryClient({
queries: {
retry: false,
},
mutations: {
retry: false,
},
},
queryCache: new QueryCache({
onError: (error, query) => {