Allow creating custom providers (#132)
This commit is contained in:
167
src/components/CreateCustomProviderDialog.tsx
Normal file
167
src/components/CreateCustomProviderDialog.tsx
Normal file
@@ -0,0 +1,167 @@
|
||||
import React, { useState } from "react";
|
||||
import {
|
||||
Dialog,
|
||||
DialogContent,
|
||||
DialogHeader,
|
||||
DialogTitle,
|
||||
DialogDescription,
|
||||
} from "@/components/ui/dialog";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { Label } from "@/components/ui/label";
|
||||
import { Loader2 } from "lucide-react";
|
||||
import { useCustomLanguageModelProvider } from "@/hooks/useCustomLanguageModelProvider";
|
||||
|
||||
interface CreateCustomProviderDialogProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
onSuccess: () => void;
|
||||
}
|
||||
|
||||
export function CreateCustomProviderDialog({
|
||||
isOpen,
|
||||
onClose,
|
||||
onSuccess,
|
||||
}: CreateCustomProviderDialogProps) {
|
||||
const [id, setId] = useState("");
|
||||
const [name, setName] = useState("");
|
||||
const [apiBaseUrl, setApiBaseUrl] = useState("");
|
||||
const [envVarName, setEnvVarName] = useState("");
|
||||
const [errorMessage, setErrorMessage] = useState("");
|
||||
|
||||
const { createProvider, isCreating, error } =
|
||||
useCustomLanguageModelProvider();
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setErrorMessage("");
|
||||
|
||||
try {
|
||||
await createProvider({
|
||||
id: id.trim(),
|
||||
name: name.trim(),
|
||||
apiBaseUrl: apiBaseUrl.trim(),
|
||||
envVarName: envVarName.trim() || undefined,
|
||||
});
|
||||
|
||||
// Reset form
|
||||
setId("");
|
||||
setName("");
|
||||
setApiBaseUrl("");
|
||||
setEnvVarName("");
|
||||
|
||||
onSuccess();
|
||||
} catch (error) {
|
||||
setErrorMessage(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: "Failed to create custom provider",
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
const handleClose = () => {
|
||||
if (!isCreating) {
|
||||
setErrorMessage("");
|
||||
onClose();
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Dialog open={isOpen} onOpenChange={handleClose}>
|
||||
<DialogContent className="sm:max-w-md">
|
||||
<DialogHeader>
|
||||
<DialogTitle>Add Custom Provider</DialogTitle>
|
||||
<DialogDescription>
|
||||
Connect to a custom language model provider API
|
||||
</DialogDescription>
|
||||
</DialogHeader>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-4 pt-4">
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="id">Provider ID</Label>
|
||||
<Input
|
||||
id="id"
|
||||
value={id}
|
||||
onChange={(e) => setId(e.target.value)}
|
||||
placeholder="E.g., my-provider"
|
||||
required
|
||||
disabled={isCreating}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
A unique identifier for this provider (no spaces).
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="name">Display Name</Label>
|
||||
<Input
|
||||
id="name"
|
||||
value={name}
|
||||
onChange={(e) => setName(e.target.value)}
|
||||
placeholder="E.g., My Provider"
|
||||
required
|
||||
disabled={isCreating}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
The name that will be displayed in the UI.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="apiBaseUrl">API Base URL</Label>
|
||||
<Input
|
||||
id="apiBaseUrl"
|
||||
value={apiBaseUrl}
|
||||
onChange={(e) => setApiBaseUrl(e.target.value)}
|
||||
placeholder="E.g., https://api.example.com/v1"
|
||||
required
|
||||
disabled={isCreating}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
The base URL for the API endpoint.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<Label htmlFor="envVarName">Environment Variable (Optional)</Label>
|
||||
<Input
|
||||
id="envVarName"
|
||||
value={envVarName}
|
||||
onChange={(e) => setEnvVarName(e.target.value)}
|
||||
placeholder="E.g., MY_PROVIDER_API_KEY"
|
||||
disabled={isCreating}
|
||||
/>
|
||||
<p className="text-xs text-muted-foreground">
|
||||
Environment variable name for the API key.
|
||||
</p>
|
||||
</div>
|
||||
|
||||
{(errorMessage || error) && (
|
||||
<div className="text-sm text-red-500">
|
||||
{errorMessage ||
|
||||
(error instanceof Error
|
||||
? error.message
|
||||
: "Failed to create custom provider")}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="flex justify-end gap-2">
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
onClick={handleClose}
|
||||
disabled={isCreating}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button type="submit" disabled={isCreating}>
|
||||
{isCreating && <Loader2 className="mr-2 h-4 w-4 animate-spin" />}
|
||||
{isCreating ? "Adding..." : "Add Provider"}
|
||||
</Button>
|
||||
</div>
|
||||
</form>
|
||||
</DialogContent>
|
||||
</Dialog>
|
||||
);
|
||||
}
|
||||
@@ -9,19 +9,24 @@ import { providerSettingsRoute } from "@/routes/settings/providers/$provider";
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
|
||||
import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders";
|
||||
import { GiftIcon } from "lucide-react";
|
||||
import { GiftIcon, PlusIcon } from "lucide-react";
|
||||
import { Skeleton } from "./ui/skeleton";
|
||||
import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
|
||||
import { AlertTriangle } from "lucide-react";
|
||||
import { useState } from "react";
|
||||
|
||||
import { CreateCustomProviderDialog } from "./CreateCustomProviderDialog";
|
||||
|
||||
export function ProviderSettingsGrid() {
|
||||
const navigate = useNavigate();
|
||||
const [isDialogOpen, setIsDialogOpen] = useState(false);
|
||||
|
||||
const {
|
||||
data: providers,
|
||||
isLoading,
|
||||
error,
|
||||
isProviderSetup,
|
||||
refetch,
|
||||
} = useLanguageModelProviders();
|
||||
|
||||
const handleProviderClick = (providerId: string) => {
|
||||
@@ -100,7 +105,32 @@ export function ProviderSettingsGrid() {
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
|
||||
{/* Add custom provider button */}
|
||||
<Card
|
||||
className="cursor-pointer transition-all hover:shadow-md border-border border-dashed hover:border-primary/70"
|
||||
onClick={() => setIsDialogOpen(true)}
|
||||
>
|
||||
<CardHeader className="p-4 flex flex-col items-center justify-center h-full">
|
||||
<PlusIcon className="h-10 w-10 text-muted-foreground mb-2" />
|
||||
<CardTitle className="text-xl text-center">
|
||||
Add custom provider
|
||||
</CardTitle>
|
||||
<CardDescription className="text-center">
|
||||
Connect to a custom LLM API endpoint
|
||||
</CardDescription>
|
||||
</CardHeader>
|
||||
</Card>
|
||||
</div>
|
||||
|
||||
<CreateCustomProviderDialog
|
||||
isOpen={isDialogOpen}
|
||||
onClose={() => setIsDialogOpen(false)}
|
||||
onSuccess={() => {
|
||||
setIsDialogOpen(false);
|
||||
refetch();
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
58
src/hooks/useCustomLanguageModelProvider.ts
Normal file
58
src/hooks/useCustomLanguageModelProvider.ts
Normal file
@@ -0,0 +1,58 @@
|
||||
import { useMutation, useQueryClient } from "@tanstack/react-query";
|
||||
import { IpcClient } from "@/ipc/ipc_client";
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
import { showError } from "@/lib/toast";
|
||||
|
||||
export interface CreateCustomLanguageModelProviderParams {
|
||||
id: string;
|
||||
name: string;
|
||||
apiBaseUrl: string;
|
||||
envVarName?: string;
|
||||
}
|
||||
|
||||
export function useCustomLanguageModelProvider() {
|
||||
const queryClient = useQueryClient();
|
||||
const ipcClient = IpcClient.getInstance();
|
||||
|
||||
const createProviderMutation = useMutation({
|
||||
mutationFn: async (
|
||||
params: CreateCustomLanguageModelProviderParams,
|
||||
): Promise<LanguageModelProvider> => {
|
||||
if (!params.id.trim()) {
|
||||
throw new Error("Provider ID is required");
|
||||
}
|
||||
if (!params.name.trim()) {
|
||||
throw new Error("Provider name is required");
|
||||
}
|
||||
if (!params.apiBaseUrl.trim()) {
|
||||
throw new Error("API base URL is required");
|
||||
}
|
||||
|
||||
return ipcClient.createCustomLanguageModelProvider({
|
||||
id: params.id.trim(),
|
||||
name: params.name.trim(),
|
||||
apiBaseUrl: params.apiBaseUrl.trim(),
|
||||
envVarName: params.envVarName?.trim() || undefined,
|
||||
});
|
||||
},
|
||||
onSuccess: () => {
|
||||
// Invalidate and refetch
|
||||
queryClient.invalidateQueries({ queryKey: ["languageModelProviders"] });
|
||||
},
|
||||
onError: (error) => {
|
||||
showError(error);
|
||||
},
|
||||
});
|
||||
|
||||
const createProvider = async (
|
||||
params: CreateCustomLanguageModelProviderParams,
|
||||
): Promise<LanguageModelProvider> => {
|
||||
return createProviderMutation.mutateAsync(params);
|
||||
};
|
||||
|
||||
return {
|
||||
createProvider,
|
||||
isCreating: createProviderMutation.isPending,
|
||||
error: createProviderMutation.error,
|
||||
};
|
||||
}
|
||||
@@ -2,10 +2,21 @@ import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
import log from "electron-log";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
import { db } from "@/db";
|
||||
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { IpcMainInvokeEvent } from "electron";
|
||||
|
||||
const logger = log.scope("language_model_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
|
||||
export interface CreateCustomLanguageModelProviderParams {
|
||||
id: string;
|
||||
name: string;
|
||||
apiBaseUrl: string;
|
||||
envVarName?: string;
|
||||
}
|
||||
|
||||
export function registerLanguageModelHandlers() {
|
||||
handle(
|
||||
"get-language-model-providers",
|
||||
@@ -13,4 +24,55 @@ export function registerLanguageModelHandlers() {
|
||||
return getLanguageModelProviders();
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"create-custom-language-model-provider",
|
||||
async (
|
||||
event: IpcMainInvokeEvent,
|
||||
params: CreateCustomLanguageModelProviderParams,
|
||||
): Promise<LanguageModelProvider> => {
|
||||
const { id, name, apiBaseUrl, envVarName } = params;
|
||||
|
||||
// Validation
|
||||
if (!id) {
|
||||
throw new Error("Provider ID is required");
|
||||
}
|
||||
|
||||
if (!name) {
|
||||
throw new Error("Provider name is required");
|
||||
}
|
||||
|
||||
if (!apiBaseUrl) {
|
||||
throw new Error("API base URL is required");
|
||||
}
|
||||
|
||||
// Check if a provider with this ID already exists
|
||||
const existingProvider = await db
|
||||
.select()
|
||||
.from(languageModelProvidersSchema)
|
||||
.where(eq(languageModelProvidersSchema.id, id))
|
||||
.get();
|
||||
|
||||
if (existingProvider) {
|
||||
throw new Error(`A provider with ID "${id}" already exists`);
|
||||
}
|
||||
|
||||
// Insert the new provider
|
||||
await db.insert(languageModelProvidersSchema).values({
|
||||
id,
|
||||
name,
|
||||
api_base_url: apiBaseUrl,
|
||||
env_var_name: envVarName || null,
|
||||
});
|
||||
|
||||
// Return the newly created provider
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
apiBaseUrl,
|
||||
envVarName,
|
||||
type: "custom",
|
||||
};
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
@@ -732,5 +732,24 @@ export class IpcClient {
|
||||
return this.ipcRenderer.invoke("get-language-model-providers");
|
||||
}
|
||||
|
||||
public async createCustomLanguageModelProvider({
|
||||
id,
|
||||
name,
|
||||
apiBaseUrl,
|
||||
envVarName,
|
||||
}: {
|
||||
id: string;
|
||||
name: string;
|
||||
apiBaseUrl: string;
|
||||
envVarName?: string;
|
||||
}): Promise<LanguageModelProvider> {
|
||||
return this.ipcRenderer.invoke("create-custom-language-model-provider", {
|
||||
id,
|
||||
name,
|
||||
apiBaseUrl,
|
||||
envVarName,
|
||||
});
|
||||
}
|
||||
|
||||
// --- End window control methods ---
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ import { contextBridge, ipcRenderer } from "electron";
|
||||
// Whitelist of valid channels
|
||||
const validInvokeChannels = [
|
||||
"get-language-model-providers",
|
||||
"create-custom-language-model-provider",
|
||||
"chat:add-dep",
|
||||
"chat:message",
|
||||
"chat:cancel",
|
||||
|
||||
@@ -32,6 +32,9 @@ const queryClient = new QueryClient({
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
mutations: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
queryCache: new QueryCache({
|
||||
onError: (error, query) => {
|
||||
|
||||
Reference in New Issue
Block a user