Prep for custom models: support reading custom providers (#131)
This commit is contained in:
@@ -66,27 +66,7 @@ The pattern involves a client-side React hook interacting with main process IPC
|
||||
* Contains the core business logic, interacting with databases (e.g., `db`), file system (`fs`), or other main-process services (e.g., `git`).
|
||||
* **Error Handling (Crucial):**
|
||||
* **Handlers MUST `throw new Error("Descriptive error message")` when an operation fails or an invalid state is encountered.** This is the preferred pattern over returning objects like `{ success: false, errorMessage: "..." }`.
|
||||
* Use `try...catch` blocks to handle errors from underlying operations (e.g., database queries, file system access, git commands).
|
||||
* Inside the `catch` block, log the original error for debugging purposes and then `throw` a new, often more user-friendly or context-specific, `Error`.
|
||||
* Example:
|
||||
```typescript
|
||||
ipcMain.handle("list-entities", async (_, { parentId }) => {
|
||||
if (!parentId) {
|
||||
throw new Error("Parent ID is required to list entities.");
|
||||
}
|
||||
try {
|
||||
const entities = await db.query.entities.findMany({ where: eq(entities.parentId, parentId) });
|
||||
if (!entities) {
|
||||
// Or handle as empty list depending on requirements
|
||||
throw new Error(`No entities found for parent ID: ${parentId}`);
|
||||
}
|
||||
return entities;
|
||||
} catch (error: any) {
|
||||
logger.error(`Error listing entities for parent ${parentId}:`, error);
|
||||
throw new Error(`Failed to list entities: ${error.message}`);
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
* **Concurrency (If Applicable):**
|
||||
* For operations that modify shared resources related to a specific entity (like an `appId`), use a locking mechanism (e.g., `withLock(appId, async () => { ... })`) to prevent race conditions.
|
||||
|
||||
|
||||
20
drizzle/0005_left_thor.sql
Normal file
20
drizzle/0005_left_thor.sql
Normal file
@@ -0,0 +1,20 @@
|
||||
CREATE TABLE `language_model_providers` (
|
||||
`id` text PRIMARY KEY NOT NULL,
|
||||
`name` text NOT NULL,
|
||||
`api_base_url` text NOT NULL,
|
||||
`env_var_name` text,
|
||||
`created_at` integer DEFAULT (unixepoch()) NOT NULL,
|
||||
`updated_at` integer DEFAULT (unixepoch()) NOT NULL
|
||||
);
|
||||
--> statement-breakpoint
|
||||
CREATE TABLE `language_models` (
|
||||
`id` integer PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
`name` text NOT NULL,
|
||||
`provider_id` text NOT NULL,
|
||||
`description` text,
|
||||
`max_output_tokens` integer,
|
||||
`context_window` integer,
|
||||
`created_at` integer DEFAULT (unixepoch()) NOT NULL,
|
||||
`updated_at` integer DEFAULT (unixepoch()) NOT NULL,
|
||||
FOREIGN KEY (`provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade
|
||||
);
|
||||
356
drizzle/meta/0005_snapshot.json
Normal file
356
drizzle/meta/0005_snapshot.json
Normal file
@@ -0,0 +1,356 @@
|
||||
{
|
||||
"version": "6",
|
||||
"dialect": "sqlite",
|
||||
"id": "29ca03c0-a5d6-4db2-a84a-03721206fdb4",
|
||||
"prevId": "ceedb797-6aa3-4a50-b42f-bc85ee08b3df",
|
||||
"tables": {
|
||||
"apps": {
|
||||
"name": "apps",
|
||||
"columns": {
|
||||
"id": {
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"primaryKey": true,
|
||||
"notNull": true,
|
||||
"autoincrement": true
|
||||
},
|
||||
"name": {
|
||||
"name": "name",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"path": {
|
||||
"name": "path",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
},
|
||||
"updated_at": {
|
||||
"name": "updated_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
},
|
||||
"github_org": {
|
||||
"name": "github_org",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"github_repo": {
|
||||
"name": "github_repo",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"supabase_project_id": {
|
||||
"name": "supabase_project_id",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
}
|
||||
},
|
||||
"indexes": {},
|
||||
"foreignKeys": {},
|
||||
"compositePrimaryKeys": {},
|
||||
"uniqueConstraints": {},
|
||||
"checkConstraints": {}
|
||||
},
|
||||
"chats": {
|
||||
"name": "chats",
|
||||
"columns": {
|
||||
"id": {
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"primaryKey": true,
|
||||
"notNull": true,
|
||||
"autoincrement": true
|
||||
},
|
||||
"app_id": {
|
||||
"name": "app_id",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"title": {
|
||||
"name": "title",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"initial_commit_hash": {
|
||||
"name": "initial_commit_hash",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
}
|
||||
},
|
||||
"indexes": {},
|
||||
"foreignKeys": {
|
||||
"chats_app_id_apps_id_fk": {
|
||||
"name": "chats_app_id_apps_id_fk",
|
||||
"tableFrom": "chats",
|
||||
"tableTo": "apps",
|
||||
"columnsFrom": [
|
||||
"app_id"
|
||||
],
|
||||
"columnsTo": [
|
||||
"id"
|
||||
],
|
||||
"onDelete": "cascade",
|
||||
"onUpdate": "no action"
|
||||
}
|
||||
},
|
||||
"compositePrimaryKeys": {},
|
||||
"uniqueConstraints": {},
|
||||
"checkConstraints": {}
|
||||
},
|
||||
"language_model_providers": {
|
||||
"name": "language_model_providers",
|
||||
"columns": {
|
||||
"id": {
|
||||
"name": "id",
|
||||
"type": "text",
|
||||
"primaryKey": true,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"name": {
|
||||
"name": "name",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"api_base_url": {
|
||||
"name": "api_base_url",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"env_var_name": {
|
||||
"name": "env_var_name",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
},
|
||||
"updated_at": {
|
||||
"name": "updated_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
}
|
||||
},
|
||||
"indexes": {},
|
||||
"foreignKeys": {},
|
||||
"compositePrimaryKeys": {},
|
||||
"uniqueConstraints": {},
|
||||
"checkConstraints": {}
|
||||
},
|
||||
"language_models": {
|
||||
"name": "language_models",
|
||||
"columns": {
|
||||
"id": {
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"primaryKey": true,
|
||||
"notNull": true,
|
||||
"autoincrement": true
|
||||
},
|
||||
"name": {
|
||||
"name": "name",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"provider_id": {
|
||||
"name": "provider_id",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"description": {
|
||||
"name": "description",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"max_output_tokens": {
|
||||
"name": "max_output_tokens",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"context_window": {
|
||||
"name": "context_window",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
},
|
||||
"updated_at": {
|
||||
"name": "updated_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
}
|
||||
},
|
||||
"indexes": {},
|
||||
"foreignKeys": {
|
||||
"language_models_provider_id_language_model_providers_id_fk": {
|
||||
"name": "language_models_provider_id_language_model_providers_id_fk",
|
||||
"tableFrom": "language_models",
|
||||
"tableTo": "language_model_providers",
|
||||
"columnsFrom": [
|
||||
"provider_id"
|
||||
],
|
||||
"columnsTo": [
|
||||
"id"
|
||||
],
|
||||
"onDelete": "cascade",
|
||||
"onUpdate": "no action"
|
||||
}
|
||||
},
|
||||
"compositePrimaryKeys": {},
|
||||
"uniqueConstraints": {},
|
||||
"checkConstraints": {}
|
||||
},
|
||||
"messages": {
|
||||
"name": "messages",
|
||||
"columns": {
|
||||
"id": {
|
||||
"name": "id",
|
||||
"type": "integer",
|
||||
"primaryKey": true,
|
||||
"notNull": true,
|
||||
"autoincrement": true
|
||||
},
|
||||
"chat_id": {
|
||||
"name": "chat_id",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"role": {
|
||||
"name": "role",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"content": {
|
||||
"name": "content",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false
|
||||
},
|
||||
"approval_state": {
|
||||
"name": "approval_state",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"commit_hash": {
|
||||
"name": "commit_hash",
|
||||
"type": "text",
|
||||
"primaryKey": false,
|
||||
"notNull": false,
|
||||
"autoincrement": false
|
||||
},
|
||||
"created_at": {
|
||||
"name": "created_at",
|
||||
"type": "integer",
|
||||
"primaryKey": false,
|
||||
"notNull": true,
|
||||
"autoincrement": false,
|
||||
"default": "(unixepoch())"
|
||||
}
|
||||
},
|
||||
"indexes": {},
|
||||
"foreignKeys": {
|
||||
"messages_chat_id_chats_id_fk": {
|
||||
"name": "messages_chat_id_chats_id_fk",
|
||||
"tableFrom": "messages",
|
||||
"tableTo": "chats",
|
||||
"columnsFrom": [
|
||||
"chat_id"
|
||||
],
|
||||
"columnsTo": [
|
||||
"id"
|
||||
],
|
||||
"onDelete": "cascade",
|
||||
"onUpdate": "no action"
|
||||
}
|
||||
},
|
||||
"compositePrimaryKeys": {},
|
||||
"uniqueConstraints": {},
|
||||
"checkConstraints": {}
|
||||
}
|
||||
},
|
||||
"views": {},
|
||||
"enums": {},
|
||||
"_meta": {
|
||||
"schemas": {},
|
||||
"tables": {},
|
||||
"columns": {}
|
||||
},
|
||||
"internal": {
|
||||
"indexes": {}
|
||||
}
|
||||
}
|
||||
@@ -36,6 +36,13 @@
|
||||
"when": 1746556241557,
|
||||
"tag": "0004_flawless_jigsaw",
|
||||
"breakpoints": true
|
||||
},
|
||||
{
|
||||
"idx": 5,
|
||||
"version": "6",
|
||||
"when": 1747083562867,
|
||||
"tag": "0005_left_thor",
|
||||
"breakpoints": true
|
||||
}
|
||||
]
|
||||
}
|
||||
@@ -1,4 +1,3 @@
|
||||
import { PROVIDERS } from "@/constants/models";
|
||||
import {
|
||||
Card,
|
||||
CardHeader,
|
||||
@@ -7,42 +6,84 @@ import {
|
||||
} from "@/components/ui/card";
|
||||
import { useNavigate } from "@tanstack/react-router";
|
||||
import { providerSettingsRoute } from "@/routes/settings/providers/$provider";
|
||||
import type { ModelProvider } from "@/lib/schemas";
|
||||
import { useSettings } from "@/hooks/useSettings";
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
|
||||
import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders";
|
||||
import { GiftIcon } from "lucide-react";
|
||||
import { Skeleton } from "./ui/skeleton";
|
||||
import { Alert, AlertDescription, AlertTitle } from "./ui/alert";
|
||||
import { AlertTriangle } from "lucide-react";
|
||||
|
||||
export function ProviderSettingsGrid() {
|
||||
const navigate = useNavigate();
|
||||
|
||||
const handleProviderClick = (provider: ModelProvider) => {
|
||||
const {
|
||||
data: providers,
|
||||
isLoading,
|
||||
error,
|
||||
isProviderSetup,
|
||||
} = useLanguageModelProviders();
|
||||
|
||||
const handleProviderClick = (providerId: string) => {
|
||||
navigate({
|
||||
to: providerSettingsRoute.id,
|
||||
params: { provider },
|
||||
params: { provider: providerId },
|
||||
});
|
||||
};
|
||||
|
||||
const { isProviderSetup } = useSettings();
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h2 className="text-2xl font-bold mb-6">AI Providers</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{[1, 2, 3, 4, 5].map((i) => (
|
||||
<Card key={i} className="border-border">
|
||||
<CardHeader className="p-4">
|
||||
<Skeleton className="h-6 w-3/4 mb-2" />
|
||||
<Skeleton className="h-4 w-1/2" />
|
||||
</CardHeader>
|
||||
</Card>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h2 className="text-2xl font-bold mb-6">AI Providers</h2>
|
||||
<Alert variant="destructive">
|
||||
<AlertTriangle className="h-4 w-4" />
|
||||
<AlertTitle>Error</AlertTitle>
|
||||
<AlertDescription>
|
||||
Failed to load AI providers: {error.message}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="p-6">
|
||||
<h2 className="text-2xl font-bold mb-6">AI Providers</h2>
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-4">
|
||||
{Object.entries(PROVIDERS).map(([key, provider]) => {
|
||||
{providers?.map((provider: LanguageModelProvider) => {
|
||||
return (
|
||||
<Card
|
||||
key={key}
|
||||
key={provider.id}
|
||||
className="cursor-pointer transition-all hover:shadow-md border-border"
|
||||
onClick={() => handleProviderClick(key as ModelProvider)}
|
||||
onClick={() => handleProviderClick(provider.id)}
|
||||
>
|
||||
<CardHeader className="p-4">
|
||||
<CardTitle className="text-xl flex items-center justify-between">
|
||||
{provider.displayName}
|
||||
{isProviderSetup(key) ? (
|
||||
{provider.name}
|
||||
{isProviderSetup(provider.id) ? (
|
||||
<span className="ml-3 text-sm font-medium text-green-500 bg-green-50 dark:bg-green-900/30 border border-green-500/50 dark:border-green-500/50 px-2 py-1 rounded-full">
|
||||
Ready
|
||||
</span>
|
||||
) : (
|
||||
<span className="text-sm text-gray-500 bg-gray-50 dark:bg-gray-900 dark:text-gray-300 px-2 py-1 rounded-full">
|
||||
<span className="text-sm text-gray-500 bg-gray-50 dark:bg-gray-900 dark:text-gray-300 px-2 py-1 rounded-full">
|
||||
Needs Setup
|
||||
</span>
|
||||
)}
|
||||
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
} from "lucide-react";
|
||||
import { providerSettingsRoute } from "@/routes/settings/providers/$provider";
|
||||
import { settingsRoute } from "@/routes/settings";
|
||||
import { useSettings } from "@/hooks/useSettings";
|
||||
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { IpcClient } from "@/ipc/ipc_client";
|
||||
import {
|
||||
@@ -24,6 +24,7 @@ import { Button } from "@/components/ui/button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { NodeSystemInfo } from "@/ipc/ipc_types";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders";
|
||||
type NodeInstallStep =
|
||||
| "install"
|
||||
| "waiting-for-continue"
|
||||
@@ -33,7 +34,8 @@ type NodeInstallStep =
|
||||
export function SetupBanner() {
|
||||
const posthog = usePostHog();
|
||||
const navigate = useNavigate();
|
||||
const { isAnyProviderSetup, loading } = useSettings();
|
||||
const { isAnyProviderSetup, isLoading: loading } =
|
||||
useLanguageModelProviders();
|
||||
const [nodeSystemInfo, setNodeSystemInfo] = useState<NodeSystemInfo | null>(
|
||||
null,
|
||||
);
|
||||
|
||||
@@ -20,7 +20,7 @@ export function HomeChatInput({
|
||||
const posthog = usePostHog();
|
||||
const [inputValue, setInputValue] = useAtom(homeChatInputValueAtom);
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
const { settings, updateSettings, isAnyProviderSetup } = useSettings();
|
||||
const { settings, updateSettings } = useSettings();
|
||||
const { isStreaming } = useStreamChat({
|
||||
hasChatId: false,
|
||||
}); // eslint-disable-line @typescript-eslint/no-unused-vars
|
||||
@@ -137,10 +137,7 @@ export function HomeChatInput({
|
||||
) : (
|
||||
<button
|
||||
onClick={handleCustomSubmit}
|
||||
disabled={
|
||||
(!inputValue.trim() && attachments.length === 0) ||
|
||||
!isAnyProviderSetup()
|
||||
}
|
||||
disabled={!inputValue.trim() && attachments.length === 0}
|
||||
className="px-2 py-2 mt-1 mr-2 hover:bg-(--background-darkest) text-(--sidebar-accent-fg) rounded-lg disabled:opacity-50"
|
||||
title="Start new chat"
|
||||
>
|
||||
|
||||
@@ -3,7 +3,7 @@ import type { Message } from "@/ipc/ipc_types";
|
||||
import { forwardRef, useState } from "react";
|
||||
import ChatMessage from "./ChatMessage";
|
||||
import { SetupBanner } from "../SetupBanner";
|
||||
import { useSettings } from "@/hooks/useSettings";
|
||||
|
||||
import { useStreamChat } from "@/hooks/useStreamChat";
|
||||
import { selectedChatIdAtom } from "@/atoms/chatAtoms";
|
||||
import { useAtomValue, useSetAtom } from "jotai";
|
||||
@@ -14,7 +14,7 @@ import { selectedAppIdAtom } from "@/atoms/appAtoms";
|
||||
import { showError, showWarning } from "@/lib/toast";
|
||||
import { IpcClient } from "@/ipc/ipc_client";
|
||||
import { chatMessagesAtom } from "@/atoms/chatAtoms";
|
||||
|
||||
import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders";
|
||||
interface MessagesListProps {
|
||||
messages: Message[];
|
||||
messagesEndRef: React.RefObject<HTMLDivElement | null>;
|
||||
@@ -25,7 +25,7 @@ export const MessagesList = forwardRef<HTMLDivElement, MessagesListProps>(
|
||||
const appId = useAtomValue(selectedAppIdAtom);
|
||||
const { versions, revertVersion } = useVersions(appId);
|
||||
const { streamMessage, isStreaming } = useStreamChat();
|
||||
const { isAnyProviderSetup } = useSettings();
|
||||
const { isAnyProviderSetup } = useLanguageModelProviders();
|
||||
|
||||
const setMessages = useSetAtom(chatMessagesAtom);
|
||||
const [isUndoLoading, setIsUndoLoading] = useState(false);
|
||||
|
||||
@@ -9,9 +9,10 @@ import {
|
||||
Settings as SettingsIcon,
|
||||
GiftIcon,
|
||||
Trash2,
|
||||
AlertTriangle,
|
||||
} from "lucide-react";
|
||||
import { useSettings } from "@/hooks/useSettings";
|
||||
import { PROVIDER_TO_ENV_VAR, PROVIDERS } from "@/constants/models";
|
||||
import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders";
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import {
|
||||
@@ -47,6 +48,13 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) {
|
||||
updateSettings,
|
||||
} = useSettings();
|
||||
|
||||
// Fetch all providers
|
||||
const {
|
||||
data: allProviders,
|
||||
isLoading: providersLoading,
|
||||
error: providersError,
|
||||
} = useLanguageModelProviders();
|
||||
|
||||
const isDyad = provider === "auto";
|
||||
|
||||
const [apiKeyInput, setApiKeyInput] = useState("");
|
||||
@@ -54,16 +62,21 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) {
|
||||
const [saveError, setSaveError] = useState<string | null>(null);
|
||||
const router = useRouter();
|
||||
|
||||
// Find provider details
|
||||
const providerInfo = PROVIDERS[provider as keyof typeof PROVIDERS];
|
||||
const providerDisplayName =
|
||||
providerInfo?.displayName ||
|
||||
provider.charAt(0).toUpperCase() + provider.slice(1);
|
||||
const providerWebsiteUrl = providerInfo?.websiteUrl;
|
||||
const hasFreeTier = providerInfo?.hasFreeTier;
|
||||
// Find the specific provider data from the fetched list
|
||||
const providerData = allProviders?.find((p) => p.id === provider);
|
||||
|
||||
const envVarName = PROVIDER_TO_ENV_VAR[provider];
|
||||
const envApiKey = envVars[envVarName];
|
||||
// Use fetched data (or defaults for Dyad)
|
||||
const providerDisplayName = isDyad
|
||||
? "Dyad"
|
||||
: (providerData?.name ?? "Unknown Provider");
|
||||
const providerWebsiteUrl = isDyad
|
||||
? "https://academy.dyad.sh/settings"
|
||||
: providerData?.websiteUrl;
|
||||
const hasFreeTier = isDyad ? false : providerData?.hasFreeTier;
|
||||
const envVarName = isDyad ? undefined : providerData?.envVarName;
|
||||
const envApiKey = envVarName ? envVars[envVarName] : undefined;
|
||||
|
||||
// Use provider ID (which is the 'provider' prop)
|
||||
const userApiKey = settings?.providerSettings?.[provider]?.apiKey?.value;
|
||||
|
||||
// --- Configuration Logic --- Updated Priority ---
|
||||
@@ -169,6 +182,81 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) {
|
||||
}
|
||||
}, [apiKeyInput]);
|
||||
|
||||
// --- Loading State for Providers --- (Added)
|
||||
if (providersLoading) {
|
||||
return (
|
||||
<div className="min-h-screen px-8 py-4">
|
||||
<div className="max-w-4xl mx-auto">
|
||||
<Skeleton className="h-8 w-24 mb-4" /> {/* Back button */}
|
||||
<Skeleton className="h-10 w-1/2 mb-6" /> {/* Title */}
|
||||
<Skeleton className="h-10 w-48 mb-4" /> {/* Get Key button */}
|
||||
<div className="space-y-4">
|
||||
<Skeleton className="h-20 w-full" />
|
||||
<Skeleton className="h-20 w-full" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// --- Error State for Providers --- (Added)
|
||||
if (providersError) {
|
||||
return (
|
||||
<div className="min-h-screen px-8 py-4">
|
||||
<div className="max-w-4xl mx-auto">
|
||||
<Button
|
||||
onClick={() => router.history.back()}
|
||||
variant="outline"
|
||||
size="sm"
|
||||
className="flex items-center gap-2 mb-4 bg-(--background-lightest) py-5"
|
||||
>
|
||||
<ArrowLeft className="h-4 w-4" />
|
||||
Go Back
|
||||
</Button>
|
||||
<h1 className="text-3xl font-bold text-gray-900 dark:text-white mr-3 mb-6">
|
||||
Configure Provider
|
||||
</h1>
|
||||
<Alert variant="destructive">
|
||||
<AlertTriangle className="h-4 w-4" />
|
||||
<AlertTitle>Error Loading Provider Details</AlertTitle>
|
||||
<AlertDescription>
|
||||
Could not load provider data: {providersError.message}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Handle case where provider is not found (e.g., invalid ID in URL)
|
||||
if (!providerData && !isDyad) {
|
||||
return (
|
||||
<div className="min-h-screen px-8 py-4">
|
||||
<div className="max-w-4xl mx-auto">
|
||||
<Button
|
||||
onClick={() => router.history.back()}
|
||||
variant="outline"
|
||||
size="sm"
|
||||
className="flex items-center gap-2 mb-4 bg-(--background-lightest) py-5"
|
||||
>
|
||||
<ArrowLeft className="h-4 w-4" />
|
||||
Go Back
|
||||
</Button>
|
||||
<h1 className="text-3xl font-bold text-gray-900 dark:text-white mr-3 mb-6">
|
||||
Provider Not Found
|
||||
</h1>
|
||||
<Alert variant="destructive">
|
||||
<AlertTriangle className="h-4 w-4" />
|
||||
<AlertTitle>Error</AlertTitle>
|
||||
<AlertDescription>
|
||||
The provider with ID "{provider}" could not be found.
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-screen px-8 py-4">
|
||||
<div className="max-w-4xl mx-auto">
|
||||
@@ -322,7 +410,7 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) {
|
||||
</AccordionContent>
|
||||
</AccordionItem>
|
||||
|
||||
{!isDyad && (
|
||||
{!isDyad && envVarName && (
|
||||
<AccordionItem
|
||||
value="env-key"
|
||||
className="border rounded-lg px-4 bg-(--background-lightest)"
|
||||
|
||||
@@ -8,7 +8,10 @@ export interface ModelOption {
|
||||
contextWindow?: number;
|
||||
}
|
||||
|
||||
type RegularModelProvider = Exclude<ModelProvider, "ollama" | "lmstudio">;
|
||||
export type RegularModelProvider = Exclude<
|
||||
ModelProvider,
|
||||
"ollama" | "lmstudio"
|
||||
>;
|
||||
export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
|
||||
openai: [
|
||||
// https://platform.openai.com/docs/models/gpt-4.1
|
||||
@@ -89,57 +92,6 @@ export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
|
||||
],
|
||||
};
|
||||
|
||||
export const PROVIDERS: Record<
|
||||
RegularModelProvider,
|
||||
{
|
||||
displayName: string;
|
||||
hasFreeTier?: boolean;
|
||||
websiteUrl?: string;
|
||||
gatewayPrefix: string;
|
||||
}
|
||||
> = {
|
||||
openai: {
|
||||
displayName: "OpenAI",
|
||||
hasFreeTier: false,
|
||||
websiteUrl: "https://platform.openai.com/api-keys",
|
||||
gatewayPrefix: "",
|
||||
},
|
||||
anthropic: {
|
||||
displayName: "Anthropic",
|
||||
hasFreeTier: false,
|
||||
websiteUrl: "https://console.anthropic.com/settings/keys",
|
||||
gatewayPrefix: "anthropic/",
|
||||
},
|
||||
google: {
|
||||
displayName: "Google",
|
||||
hasFreeTier: true,
|
||||
websiteUrl: "https://aistudio.google.com/app/apikey",
|
||||
gatewayPrefix: "gemini/",
|
||||
},
|
||||
openrouter: {
|
||||
displayName: "OpenRouter",
|
||||
hasFreeTier: true,
|
||||
websiteUrl: "https://openrouter.ai/settings/keys",
|
||||
gatewayPrefix: "openrouter/",
|
||||
},
|
||||
auto: {
|
||||
displayName: "Dyad",
|
||||
websiteUrl: "https://academy.dyad.sh/settings",
|
||||
gatewayPrefix: "",
|
||||
},
|
||||
};
|
||||
|
||||
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
};
|
||||
|
||||
export const ALLOWED_ENV_VARS = Object.keys(PROVIDER_TO_ENV_VAR).map(
|
||||
(provider) => PROVIDER_TO_ENV_VAR[provider],
|
||||
);
|
||||
|
||||
export const AUTO_MODELS = [
|
||||
{
|
||||
provider: "google",
|
||||
|
||||
@@ -64,3 +64,54 @@ export const messagesRelations = relations(messages, ({ one }) => ({
|
||||
references: [chats.id],
|
||||
}),
|
||||
}));
|
||||
|
||||
export const language_model_providers = sqliteTable(
|
||||
"language_model_providers",
|
||||
{
|
||||
id: text("id").primaryKey(),
|
||||
name: text("name").notNull(),
|
||||
api_base_url: text("api_base_url").notNull(),
|
||||
env_var_name: text("env_var_name"),
|
||||
createdAt: integer("created_at", { mode: "timestamp" })
|
||||
.notNull()
|
||||
.default(sql`(unixepoch())`),
|
||||
updatedAt: integer("updated_at", { mode: "timestamp" })
|
||||
.notNull()
|
||||
.default(sql`(unixepoch())`),
|
||||
},
|
||||
);
|
||||
|
||||
export const language_models = sqliteTable("language_models", {
|
||||
id: integer("id").primaryKey({ autoIncrement: true }),
|
||||
name: text("name").notNull(),
|
||||
provider_id: text("provider_id")
|
||||
.notNull()
|
||||
.references(() => language_model_providers.id, { onDelete: "cascade" }),
|
||||
description: text("description"),
|
||||
max_output_tokens: integer("max_output_tokens"),
|
||||
context_window: integer("context_window"),
|
||||
createdAt: integer("created_at", { mode: "timestamp" })
|
||||
.notNull()
|
||||
.default(sql`(unixepoch())`),
|
||||
updatedAt: integer("updated_at", { mode: "timestamp" })
|
||||
.notNull()
|
||||
.default(sql`(unixepoch())`),
|
||||
});
|
||||
|
||||
// Define relations for new tables
|
||||
export const languageModelProvidersRelations = relations(
|
||||
language_model_providers,
|
||||
({ many }) => ({
|
||||
languageModels: many(language_models),
|
||||
}),
|
||||
);
|
||||
|
||||
export const languageModelsRelations = relations(
|
||||
language_models,
|
||||
({ one }) => ({
|
||||
provider: one(language_model_providers, {
|
||||
fields: [language_models.provider_id],
|
||||
references: [language_model_providers.id],
|
||||
}),
|
||||
}),
|
||||
);
|
||||
|
||||
42
src/hooks/useLanguageModelProviders.ts
Normal file
42
src/hooks/useLanguageModelProviders.ts
Normal file
@@ -0,0 +1,42 @@
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { IpcClient } from "@/ipc/ipc_client";
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
import { useSettings } from "./useSettings";
|
||||
import { cloudProviders } from "@/lib/schemas";
|
||||
|
||||
export function useLanguageModelProviders() {
|
||||
const ipcClient = IpcClient.getInstance();
|
||||
const { settings, envVars } = useSettings();
|
||||
|
||||
const queryResult = useQuery<LanguageModelProvider[], Error>({
|
||||
queryKey: ["languageModelProviders"],
|
||||
queryFn: async () => {
|
||||
return ipcClient.getLanguageModelProviders();
|
||||
},
|
||||
});
|
||||
|
||||
const isProviderSetup = (provider: string) => {
|
||||
const providerSettings = settings?.providerSettings[provider];
|
||||
if (queryResult.isLoading) {
|
||||
return false;
|
||||
}
|
||||
if (providerSettings?.apiKey?.value) {
|
||||
return true;
|
||||
}
|
||||
const providerData = queryResult.data?.find((p) => p.id === provider);
|
||||
if (providerData?.envVarName && envVars[providerData.envVarName]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
const isAnyProviderSetup = () => {
|
||||
return cloudProviders.some((provider) => isProviderSetup(provider));
|
||||
};
|
||||
|
||||
return {
|
||||
...queryResult,
|
||||
isProviderSetup,
|
||||
isAnyProviderSetup,
|
||||
};
|
||||
}
|
||||
@@ -2,15 +2,9 @@ import { useState, useEffect, useCallback } from "react";
|
||||
import { useAtom } from "jotai";
|
||||
import { userSettingsAtom, envVarsAtom } from "@/atoms/appAtoms";
|
||||
import { IpcClient } from "@/ipc/ipc_client";
|
||||
import { cloudProviders, type UserSettings } from "@/lib/schemas";
|
||||
import { type UserSettings } from "@/lib/schemas";
|
||||
import { usePostHog } from "posthog-js/react";
|
||||
|
||||
const PROVIDER_TO_ENV_VAR: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
};
|
||||
|
||||
const TELEMETRY_CONSENT_KEY = "dyadTelemetryConsent";
|
||||
const TELEMETRY_USER_ID_KEY = "dyadTelemetryUserId";
|
||||
|
||||
@@ -81,17 +75,6 @@ export function useSettings() {
|
||||
}
|
||||
};
|
||||
|
||||
const isProviderSetup = (provider: string) => {
|
||||
const providerSettings = settings?.providerSettings[provider];
|
||||
if (providerSettings?.apiKey?.value) {
|
||||
return true;
|
||||
}
|
||||
if (envVars[PROVIDER_TO_ENV_VAR[provider]]) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
return {
|
||||
settings,
|
||||
envVars,
|
||||
@@ -99,13 +82,6 @@ export function useSettings() {
|
||||
error,
|
||||
updateSettings,
|
||||
|
||||
isProviderSetup,
|
||||
isAnyProviderSetup: () => {
|
||||
// Technically we should check for ollama and lmstudio being setup, but
|
||||
// practically most users will want to use a cloud provider (at least
|
||||
// some of the time)
|
||||
return cloudProviders.some((provider) => isProviderSetup(provider));
|
||||
},
|
||||
refreshSettings: () => {
|
||||
return loadInitialData();
|
||||
},
|
||||
|
||||
@@ -22,7 +22,6 @@ import {
|
||||
killProcess,
|
||||
removeAppIfCurrentProcess,
|
||||
} from "../utils/process_manager";
|
||||
import { ALLOWED_ENV_VARS } from "../../constants/models";
|
||||
import { getEnvVar } from "../utils/read_env";
|
||||
import { readSettings } from "../../main/settings";
|
||||
|
||||
@@ -33,6 +32,7 @@ import util from "util";
|
||||
import log from "electron-log";
|
||||
import { getSupabaseProjectName } from "../../supabase_admin/supabase_management_client";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
|
||||
const logger = log.scope("app_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
@@ -291,8 +291,11 @@ export function registerAppHandlers() {
|
||||
// Do NOT use handle for this, it contains sensitive information.
|
||||
ipcMain.handle("get-env-vars", async () => {
|
||||
const envVars: Record<string, string | undefined> = {};
|
||||
for (const key of ALLOWED_ENV_VARS) {
|
||||
envVars[key] = getEnvVar(key);
|
||||
const providers = await getLanguageModelProviders();
|
||||
for (const provider of providers) {
|
||||
if (provider.envVarName) {
|
||||
envVars[provider.envVarName] = getEnvVar(provider.envVarName);
|
||||
}
|
||||
}
|
||||
return envVars;
|
||||
});
|
||||
|
||||
@@ -212,7 +212,10 @@ export function registerChatStreamHandlers() {
|
||||
} else {
|
||||
// Normal AI processing for non-test prompts
|
||||
const settings = readSettings();
|
||||
const modelClient = getModelClient(settings.selectedModel, settings);
|
||||
const modelClient = await getModelClient(
|
||||
settings.selectedModel,
|
||||
settings,
|
||||
);
|
||||
|
||||
// Extract codebase information if app is associated with the chat
|
||||
let codebaseInfo = "";
|
||||
|
||||
16
src/ipc/handlers/language_model_handlers.ts
Normal file
16
src/ipc/handlers/language_model_handlers.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
import log from "electron-log";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
|
||||
const logger = log.scope("language_model_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
|
||||
export function registerLanguageModelHandlers() {
|
||||
handle(
|
||||
"get-language-model-providers",
|
||||
async (): Promise<LanguageModelProvider[]> => {
|
||||
return getLanguageModelProviders();
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -21,6 +21,7 @@ import type {
|
||||
TokenCountResult,
|
||||
ChatLogsData,
|
||||
BranchResult,
|
||||
LanguageModelProvider,
|
||||
} from "./ipc_types";
|
||||
import type { ProposalResult } from "@/lib/schemas";
|
||||
import { showError } from "@/lib/toast";
|
||||
@@ -724,13 +725,11 @@ export class IpcClient {
|
||||
|
||||
// Get system platform (win32, darwin, linux)
|
||||
public async getSystemPlatform(): Promise<string> {
|
||||
try {
|
||||
const platform = await this.ipcRenderer.invoke("window:get-platform");
|
||||
return platform;
|
||||
} catch (error) {
|
||||
showError(error);
|
||||
throw error;
|
||||
}
|
||||
return this.ipcRenderer.invoke("get-system-platform");
|
||||
}
|
||||
|
||||
public async getLanguageModelProviders(): Promise<LanguageModelProvider[]> {
|
||||
return this.ipcRenderer.invoke("get-language-model-providers");
|
||||
}
|
||||
|
||||
// --- End window control methods ---
|
||||
|
||||
@@ -14,6 +14,8 @@ import { registerTokenCountHandlers } from "./handlers/token_count_handlers";
|
||||
import { registerWindowHandlers } from "./handlers/window_handlers";
|
||||
import { registerUploadHandlers } from "./handlers/upload_handlers";
|
||||
import { registerVersionHandlers } from "./handlers/version_handlers";
|
||||
import { registerLanguageModelHandlers } from "./handlers/language_model_handlers";
|
||||
|
||||
export function registerIpcHandlers() {
|
||||
// Register all IPC handlers by category
|
||||
registerAppHandlers();
|
||||
@@ -32,4 +34,5 @@ export function registerIpcHandlers() {
|
||||
registerWindowHandlers();
|
||||
registerUploadHandlers();
|
||||
registerVersionHandlers();
|
||||
registerLanguageModelHandlers();
|
||||
}
|
||||
|
||||
@@ -132,3 +132,14 @@ export interface ChatLogsData {
|
||||
chat: Chat;
|
||||
codebase: string;
|
||||
}
|
||||
|
||||
export interface LanguageModelProvider {
|
||||
id: string;
|
||||
name: string;
|
||||
hasFreeTier?: boolean;
|
||||
websiteUrl?: string;
|
||||
gatewayPrefix?: string;
|
||||
envVarName?: string;
|
||||
apiBaseUrl?: string;
|
||||
type: "custom" | "local" | "cloud";
|
||||
}
|
||||
|
||||
131
src/ipc/shared/language_model_helpers.ts
Normal file
131
src/ipc/shared/language_model_helpers.ts
Normal file
@@ -0,0 +1,131 @@
|
||||
import { db } from "@/db";
|
||||
import { language_model_providers as languageModelProvidersSchema } from "@/db/schema";
|
||||
import { RegularModelProvider } from "@/constants/models";
|
||||
import type { LanguageModelProvider } from "@/ipc/ipc_types";
|
||||
|
||||
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
|
||||
openai: "OPENAI_API_KEY",
|
||||
anthropic: "ANTHROPIC_API_KEY",
|
||||
google: "GEMINI_API_KEY",
|
||||
openrouter: "OPENROUTER_API_KEY",
|
||||
};
|
||||
|
||||
export const PROVIDERS: Record<
|
||||
RegularModelProvider,
|
||||
{
|
||||
displayName: string;
|
||||
hasFreeTier?: boolean;
|
||||
websiteUrl?: string;
|
||||
gatewayPrefix: string;
|
||||
}
|
||||
> = {
|
||||
openai: {
|
||||
displayName: "OpenAI",
|
||||
hasFreeTier: false,
|
||||
websiteUrl: "https://platform.openai.com/api-keys",
|
||||
gatewayPrefix: "",
|
||||
},
|
||||
anthropic: {
|
||||
displayName: "Anthropic",
|
||||
hasFreeTier: false,
|
||||
websiteUrl: "https://console.anthropic.com/settings/keys",
|
||||
gatewayPrefix: "anthropic/",
|
||||
},
|
||||
google: {
|
||||
displayName: "Google",
|
||||
hasFreeTier: true,
|
||||
websiteUrl: "https://aistudio.google.com/app/apikey",
|
||||
gatewayPrefix: "gemini/",
|
||||
},
|
||||
openrouter: {
|
||||
displayName: "OpenRouter",
|
||||
hasFreeTier: true,
|
||||
websiteUrl: "https://openrouter.ai/settings/keys",
|
||||
gatewayPrefix: "openrouter/",
|
||||
},
|
||||
auto: {
|
||||
displayName: "Dyad",
|
||||
websiteUrl: "https://academy.dyad.sh/settings",
|
||||
gatewayPrefix: "",
|
||||
},
|
||||
};
|
||||
|
||||
/**
|
||||
* Fetches language model providers from both the database (custom) and hardcoded constants (cloud),
|
||||
* merging them with custom providers taking precedence.
|
||||
* @returns A promise that resolves to an array of LanguageModelProvider objects.
|
||||
*/
|
||||
export async function getLanguageModelProviders(): Promise<
|
||||
LanguageModelProvider[]
|
||||
> {
|
||||
// Fetch custom providers from the database
|
||||
const customProvidersDb = await db
|
||||
.select()
|
||||
.from(languageModelProvidersSchema);
|
||||
|
||||
const customProvidersMap = new Map<string, LanguageModelProvider>();
|
||||
for (const cp of customProvidersDb) {
|
||||
customProvidersMap.set(cp.id, {
|
||||
id: cp.id,
|
||||
name: cp.name,
|
||||
apiBaseUrl: cp.api_base_url,
|
||||
envVarName: cp.env_var_name ?? undefined,
|
||||
type: "custom",
|
||||
// hasFreeTier, websiteUrl, gatewayPrefix are not in the custom DB schema
|
||||
// They will be undefined unless overridden by hardcoded values if IDs match
|
||||
});
|
||||
}
|
||||
|
||||
// Get hardcoded cloud providers
|
||||
const hardcodedProviders: LanguageModelProvider[] = [];
|
||||
for (const providerKey in PROVIDERS) {
|
||||
if (Object.prototype.hasOwnProperty.call(PROVIDERS, providerKey)) {
|
||||
// Ensure providerKey is a key of PROVIDERS
|
||||
const key = providerKey as keyof typeof PROVIDERS;
|
||||
const providerDetails = PROVIDERS[key];
|
||||
if (providerDetails) {
|
||||
// Ensure providerDetails is not undefined
|
||||
hardcodedProviders.push({
|
||||
id: key,
|
||||
name: providerDetails.displayName,
|
||||
hasFreeTier: providerDetails.hasFreeTier,
|
||||
websiteUrl: providerDetails.websiteUrl,
|
||||
gatewayPrefix: providerDetails.gatewayPrefix,
|
||||
envVarName: PROVIDER_TO_ENV_VAR[key] ?? undefined,
|
||||
type: "cloud",
|
||||
// apiBaseUrl is not directly in PROVIDERS
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Merge lists: custom providers take precedence
|
||||
const mergedProvidersMap = new Map<string, LanguageModelProvider>();
|
||||
|
||||
// Add all hardcoded providers first
|
||||
for (const hp of hardcodedProviders) {
|
||||
mergedProvidersMap.set(hp.id, hp);
|
||||
}
|
||||
|
||||
// Add/overwrite with custom providers from DB
|
||||
for (const [id, cp] of customProvidersMap) {
|
||||
const existingProvider = mergedProvidersMap.get(id);
|
||||
if (existingProvider) {
|
||||
// If exists, merge. Custom fields take precedence.
|
||||
mergedProvidersMap.set(id, {
|
||||
...existingProvider, // start with hardcoded
|
||||
...cp, // override with custom where defined
|
||||
id: cp.id, // ensure custom id is used
|
||||
name: cp.name, // ensure custom name is used
|
||||
type: "custom", // explicitly set type to custom
|
||||
apiBaseUrl: cp.apiBaseUrl ?? existingProvider.apiBaseUrl,
|
||||
envVarName: cp.envVarName ?? existingProvider.envVarName,
|
||||
});
|
||||
} else {
|
||||
// If it doesn't exist in hardcoded, just add the custom one
|
||||
mergedProvidersMap.set(id, cp);
|
||||
}
|
||||
}
|
||||
|
||||
return Array.from(mergedProvidersMap.values());
|
||||
}
|
||||
@@ -5,36 +5,38 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
|
||||
import { createOllama } from "ollama-ai-provider";
|
||||
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
||||
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
|
||||
import {
|
||||
PROVIDER_TO_ENV_VAR,
|
||||
AUTO_MODELS,
|
||||
PROVIDERS,
|
||||
MODEL_OPTIONS,
|
||||
} from "../../constants/models";
|
||||
import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models";
|
||||
import { getEnvVar } from "./read_env";
|
||||
import log from "electron-log";
|
||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||
|
||||
const logger = log.scope("getModelClient");
|
||||
export function getModelClient(
|
||||
export async function getModelClient(
|
||||
model: LargeLanguageModel,
|
||||
settings: UserSettings,
|
||||
) {
|
||||
const allProviders = await getLanguageModelProviders();
|
||||
|
||||
const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value;
|
||||
// Handle 'auto' provider by trying each model in AUTO_MODELS until one works
|
||||
if (model.provider === "auto") {
|
||||
// Try each model in AUTO_MODELS in order until finding one with an API key
|
||||
for (const autoModel of AUTO_MODELS) {
|
||||
const providerInfo = allProviders.find(
|
||||
(p) => p.id === autoModel.provider,
|
||||
);
|
||||
const envVarName = providerInfo?.envVarName;
|
||||
|
||||
const apiKey =
|
||||
dyadApiKey ||
|
||||
settings.providerSettings?.[autoModel.provider]?.apiKey ||
|
||||
getEnvVar(PROVIDER_TO_ENV_VAR[autoModel.provider]);
|
||||
settings.providerSettings?.[autoModel.provider]?.apiKey?.value ||
|
||||
(envVarName ? getEnvVar(envVarName) : undefined);
|
||||
|
||||
if (apiKey) {
|
||||
logger.log(
|
||||
`Using provider: ${autoModel.provider} model: ${autoModel.name}`,
|
||||
);
|
||||
// Use the first model that has an API key
|
||||
return getModelClient(
|
||||
// Recursively call with the specific model found
|
||||
return await getModelClient(
|
||||
{
|
||||
provider: autoModel.provider,
|
||||
name: autoModel.name,
|
||||
@@ -43,27 +45,48 @@ export function getModelClient(
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// If no models have API keys, throw an error
|
||||
throw new Error("No API keys available for any model in AUTO_MODELS");
|
||||
throw new Error(
|
||||
"No API keys available for any model supported by the 'auto' provider.",
|
||||
);
|
||||
}
|
||||
|
||||
// --- Handle specific provider ---
|
||||
const providerConfig = allProviders.find((p) => p.id === model.provider);
|
||||
|
||||
if (!providerConfig) {
|
||||
throw new Error(`Configuration not found for provider: ${model.provider}`);
|
||||
}
|
||||
|
||||
// Handle Dyad Pro override
|
||||
if (dyadApiKey && settings.enableDyadPro) {
|
||||
const provider = createOpenAI({
|
||||
apiKey: dyadApiKey,
|
||||
baseURL: "https://llm-gateway.dyad.sh/v1",
|
||||
});
|
||||
const providerInfo = PROVIDERS[model.provider as keyof typeof PROVIDERS];
|
||||
logger.info("Using Dyad Pro API key");
|
||||
// Do not use free variant (for openrouter).
|
||||
const modelName = model.name.split(":free")[0];
|
||||
return provider(`${providerInfo.gatewayPrefix}${modelName}`);
|
||||
// Check if the selected provider supports Dyad Pro (has a gateway prefix)
|
||||
if (providerConfig.gatewayPrefix) {
|
||||
const provider = createOpenAI({
|
||||
apiKey: dyadApiKey,
|
||||
baseURL: "https://llm-gateway.dyad.sh/v1",
|
||||
});
|
||||
logger.info("Using Dyad Pro API key via Gateway");
|
||||
// Do not use free variant (for openrouter).
|
||||
const modelName = model.name.split(":free")[0];
|
||||
return provider(`${providerConfig.gatewayPrefix}${modelName}`);
|
||||
} else {
|
||||
logger.warn(
|
||||
`Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`,
|
||||
);
|
||||
// Fall through to regular provider logic if gateway prefix is missing
|
||||
}
|
||||
}
|
||||
|
||||
// Get API key for the specific provider
|
||||
const apiKey =
|
||||
settings.providerSettings?.[model.provider]?.apiKey?.value ||
|
||||
getEnvVar(PROVIDER_TO_ENV_VAR[model.provider]);
|
||||
switch (model.provider) {
|
||||
(providerConfig.envVarName
|
||||
? getEnvVar(providerConfig.envVarName)
|
||||
: undefined);
|
||||
|
||||
// Create client based on provider ID or type
|
||||
switch (providerConfig.id) {
|
||||
case "openai": {
|
||||
const provider = createOpenAI({ apiKey });
|
||||
return provider(model.name);
|
||||
@@ -81,18 +104,38 @@ export function getModelClient(
|
||||
return provider(model.name);
|
||||
}
|
||||
case "ollama": {
|
||||
const provider = createOllama();
|
||||
// Ollama typically runs locally and doesn't require an API key in the same way
|
||||
const provider = createOllama({
|
||||
baseURL: providerConfig.apiBaseUrl,
|
||||
});
|
||||
return provider(model.name);
|
||||
}
|
||||
case "lmstudio": {
|
||||
// Using LM Studio's OpenAI compatible API
|
||||
const baseURL = "http://localhost:1234/v1"; // Default LM Studio OpenAI API URL
|
||||
const provider = createOpenAICompatible({ name: "lmstudio", baseURL });
|
||||
// LM Studio uses OpenAI compatible API
|
||||
const baseURL = providerConfig.apiBaseUrl || "http://localhost:1234/v1";
|
||||
const provider = createOpenAICompatible({
|
||||
name: "lmstudio",
|
||||
baseURL,
|
||||
});
|
||||
return provider(model.name);
|
||||
}
|
||||
default: {
|
||||
// Ensure exhaustive check if more providers are added
|
||||
const _exhaustiveCheck: never = model.provider;
|
||||
// Handle custom providers
|
||||
if (providerConfig.type === "custom") {
|
||||
if (!providerConfig.apiBaseUrl) {
|
||||
throw new Error(
|
||||
`Custom provider ${model.provider} is missing the API Base URL.`,
|
||||
);
|
||||
}
|
||||
// Assume custom providers are OpenAI compatible for now
|
||||
const provider = createOpenAICompatible({
|
||||
name: providerConfig.id,
|
||||
baseURL: providerConfig.apiBaseUrl,
|
||||
apiKey: apiKey,
|
||||
});
|
||||
return provider(model.name);
|
||||
}
|
||||
// If it's not a known ID and not type 'custom', it's unsupported
|
||||
throw new Error(`Unsupported model provider: ${model.provider}`);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import { contextBridge, ipcRenderer } from "electron";
|
||||
|
||||
// Whitelist of valid channels
|
||||
const validInvokeChannels = [
|
||||
"get-language-model-providers",
|
||||
"chat:add-dep",
|
||||
"chat:message",
|
||||
"chat:cancel",
|
||||
|
||||
@@ -1,7 +1,13 @@
|
||||
import { defineConfig } from "vite";
|
||||
import path from "path";
|
||||
|
||||
// https://vitejs.dev/config
|
||||
export default defineConfig({
|
||||
resolve: {
|
||||
alias: {
|
||||
"@": path.resolve(__dirname, "./src"),
|
||||
},
|
||||
},
|
||||
build: {
|
||||
rollupOptions: {
|
||||
external: ["better-sqlite3"],
|
||||
|
||||
Reference in New Issue
Block a user