From cd7eaa8ece7daba65e3d3e8cbfe65ca2d31ed9a0 Mon Sep 17 00:00:00 2001 From: Will Chen Date: Mon, 12 May 2025 14:52:48 -0700 Subject: [PATCH] Prep for custom models: support reading custom providers (#131) --- .cursor/rules/ipc.mdc | 22 +- drizzle/0005_left_thor.sql | 20 + drizzle/meta/0005_snapshot.json | 356 ++++++++++++++++++ drizzle/meta/_journal.json | 7 + src/components/ProviderSettings.tsx | 65 +++- src/components/SetupBanner.tsx | 6 +- src/components/chat/HomeChatInput.tsx | 7 +- src/components/chat/MessagesList.tsx | 6 +- .../settings/ProviderSettingsPage.tsx | 110 +++++- src/constants/models.ts | 56 +-- src/db/schema.ts | 51 +++ src/hooks/useLanguageModelProviders.ts | 42 +++ src/hooks/useSettings.ts | 26 +- src/ipc/handlers/app_handlers.ts | 9 +- src/ipc/handlers/chat_stream_handlers.ts | 5 +- src/ipc/handlers/language_model_handlers.ts | 16 + src/ipc/ipc_client.ts | 13 +- src/ipc/ipc_host.ts | 3 + src/ipc/ipc_types.ts | 11 + src/ipc/shared/language_model_helpers.ts | 131 +++++++ src/ipc/utils/get_model_client.ts | 105 ++++-- src/preload.ts | 1 + vite.main.config.mts | 6 + 23 files changed, 901 insertions(+), 173 deletions(-) create mode 100644 drizzle/0005_left_thor.sql create mode 100644 drizzle/meta/0005_snapshot.json create mode 100644 src/hooks/useLanguageModelProviders.ts create mode 100644 src/ipc/handlers/language_model_handlers.ts create mode 100644 src/ipc/shared/language_model_helpers.ts diff --git a/.cursor/rules/ipc.mdc b/.cursor/rules/ipc.mdc index 2b88b96..d8d59e1 100644 --- a/.cursor/rules/ipc.mdc +++ b/.cursor/rules/ipc.mdc @@ -66,27 +66,7 @@ The pattern involves a client-side React hook interacting with main process IPC * Contains the core business logic, interacting with databases (e.g., `db`), file system (`fs`), or other main-process services (e.g., `git`). * **Error Handling (Crucial):** * **Handlers MUST `throw new Error("Descriptive error message")` when an operation fails or an invalid state is encountered.** This is the preferred pattern over returning objects like `{ success: false, errorMessage: "..." }`. - * Use `try...catch` blocks to handle errors from underlying operations (e.g., database queries, file system access, git commands). - * Inside the `catch` block, log the original error for debugging purposes and then `throw` a new, often more user-friendly or context-specific, `Error`. - * Example: - ```typescript - ipcMain.handle("list-entities", async (_, { parentId }) => { - if (!parentId) { - throw new Error("Parent ID is required to list entities."); - } - try { - const entities = await db.query.entities.findMany({ where: eq(entities.parentId, parentId) }); - if (!entities) { - // Or handle as empty list depending on requirements - throw new Error(`No entities found for parent ID: ${parentId}`); - } - return entities; - } catch (error: any) { - logger.error(`Error listing entities for parent ${parentId}:`, error); - throw new Error(`Failed to list entities: ${error.message}`); - } - }); - ``` + * **Concurrency (If Applicable):** * For operations that modify shared resources related to a specific entity (like an `appId`), use a locking mechanism (e.g., `withLock(appId, async () => { ... })`) to prevent race conditions. diff --git a/drizzle/0005_left_thor.sql b/drizzle/0005_left_thor.sql new file mode 100644 index 0000000..8942bb5 --- /dev/null +++ b/drizzle/0005_left_thor.sql @@ -0,0 +1,20 @@ +CREATE TABLE `language_model_providers` ( + `id` text PRIMARY KEY NOT NULL, + `name` text NOT NULL, + `api_base_url` text NOT NULL, + `env_var_name` text, + `created_at` integer DEFAULT (unixepoch()) NOT NULL, + `updated_at` integer DEFAULT (unixepoch()) NOT NULL +); +--> statement-breakpoint +CREATE TABLE `language_models` ( + `id` integer PRIMARY KEY AUTOINCREMENT NOT NULL, + `name` text NOT NULL, + `provider_id` text NOT NULL, + `description` text, + `max_output_tokens` integer, + `context_window` integer, + `created_at` integer DEFAULT (unixepoch()) NOT NULL, + `updated_at` integer DEFAULT (unixepoch()) NOT NULL, + FOREIGN KEY (`provider_id`) REFERENCES `language_model_providers`(`id`) ON UPDATE no action ON DELETE cascade +); diff --git a/drizzle/meta/0005_snapshot.json b/drizzle/meta/0005_snapshot.json new file mode 100644 index 0000000..3fdb55f --- /dev/null +++ b/drizzle/meta/0005_snapshot.json @@ -0,0 +1,356 @@ +{ + "version": "6", + "dialect": "sqlite", + "id": "29ca03c0-a5d6-4db2-a84a-03721206fdb4", + "prevId": "ceedb797-6aa3-4a50-b42f-bc85ee08b3df", + "tables": { + "apps": { + "name": "apps", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "path": { + "name": "path", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + }, + "updated_at": { + "name": "updated_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + }, + "github_org": { + "name": "github_org", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "github_repo": { + "name": "github_repo", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "supabase_project_id": { + "name": "supabase_project_id", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "chats": { + "name": "chats", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "app_id": { + "name": "app_id", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "title": { + "name": "title", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "initial_commit_hash": { + "name": "initial_commit_hash", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + } + }, + "indexes": {}, + "foreignKeys": { + "chats_app_id_apps_id_fk": { + "name": "chats_app_id_apps_id_fk", + "tableFrom": "chats", + "tableTo": "apps", + "columnsFrom": [ + "app_id" + ], + "columnsTo": [ + "id" + ], + "onDelete": "cascade", + "onUpdate": "no action" + } + }, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "language_model_providers": { + "name": "language_model_providers", + "columns": { + "id": { + "name": "id", + "type": "text", + "primaryKey": true, + "notNull": true, + "autoincrement": false + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "api_base_url": { + "name": "api_base_url", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "env_var_name": { + "name": "env_var_name", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + }, + "updated_at": { + "name": "updated_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + } + }, + "indexes": {}, + "foreignKeys": {}, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "language_models": { + "name": "language_models", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "name": { + "name": "name", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "provider_id": { + "name": "provider_id", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "description": { + "name": "description", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "max_output_tokens": { + "name": "max_output_tokens", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "context_window": { + "name": "context_window", + "type": "integer", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + }, + "updated_at": { + "name": "updated_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + } + }, + "indexes": {}, + "foreignKeys": { + "language_models_provider_id_language_model_providers_id_fk": { + "name": "language_models_provider_id_language_model_providers_id_fk", + "tableFrom": "language_models", + "tableTo": "language_model_providers", + "columnsFrom": [ + "provider_id" + ], + "columnsTo": [ + "id" + ], + "onDelete": "cascade", + "onUpdate": "no action" + } + }, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + }, + "messages": { + "name": "messages", + "columns": { + "id": { + "name": "id", + "type": "integer", + "primaryKey": true, + "notNull": true, + "autoincrement": true + }, + "chat_id": { + "name": "chat_id", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "role": { + "name": "role", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "content": { + "name": "content", + "type": "text", + "primaryKey": false, + "notNull": true, + "autoincrement": false + }, + "approval_state": { + "name": "approval_state", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "commit_hash": { + "name": "commit_hash", + "type": "text", + "primaryKey": false, + "notNull": false, + "autoincrement": false + }, + "created_at": { + "name": "created_at", + "type": "integer", + "primaryKey": false, + "notNull": true, + "autoincrement": false, + "default": "(unixepoch())" + } + }, + "indexes": {}, + "foreignKeys": { + "messages_chat_id_chats_id_fk": { + "name": "messages_chat_id_chats_id_fk", + "tableFrom": "messages", + "tableTo": "chats", + "columnsFrom": [ + "chat_id" + ], + "columnsTo": [ + "id" + ], + "onDelete": "cascade", + "onUpdate": "no action" + } + }, + "compositePrimaryKeys": {}, + "uniqueConstraints": {}, + "checkConstraints": {} + } + }, + "views": {}, + "enums": {}, + "_meta": { + "schemas": {}, + "tables": {}, + "columns": {} + }, + "internal": { + "indexes": {} + } +} \ No newline at end of file diff --git a/drizzle/meta/_journal.json b/drizzle/meta/_journal.json index 1c3c5a2..9a45e80 100644 --- a/drizzle/meta/_journal.json +++ b/drizzle/meta/_journal.json @@ -36,6 +36,13 @@ "when": 1746556241557, "tag": "0004_flawless_jigsaw", "breakpoints": true + }, + { + "idx": 5, + "version": "6", + "when": 1747083562867, + "tag": "0005_left_thor", + "breakpoints": true } ] } \ No newline at end of file diff --git a/src/components/ProviderSettings.tsx b/src/components/ProviderSettings.tsx index 9a32e59..71e8d60 100644 --- a/src/components/ProviderSettings.tsx +++ b/src/components/ProviderSettings.tsx @@ -1,4 +1,3 @@ -import { PROVIDERS } from "@/constants/models"; import { Card, CardHeader, @@ -7,42 +6,84 @@ import { } from "@/components/ui/card"; import { useNavigate } from "@tanstack/react-router"; import { providerSettingsRoute } from "@/routes/settings/providers/$provider"; -import type { ModelProvider } from "@/lib/schemas"; -import { useSettings } from "@/hooks/useSettings"; +import type { LanguageModelProvider } from "@/ipc/ipc_types"; + +import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders"; import { GiftIcon } from "lucide-react"; +import { Skeleton } from "./ui/skeleton"; +import { Alert, AlertDescription, AlertTitle } from "./ui/alert"; +import { AlertTriangle } from "lucide-react"; export function ProviderSettingsGrid() { const navigate = useNavigate(); - const handleProviderClick = (provider: ModelProvider) => { + const { + data: providers, + isLoading, + error, + isProviderSetup, + } = useLanguageModelProviders(); + + const handleProviderClick = (providerId: string) => { navigate({ to: providerSettingsRoute.id, - params: { provider }, + params: { provider: providerId }, }); }; - const { isProviderSetup } = useSettings(); + if (isLoading) { + return ( +
+

AI Providers

+
+ {[1, 2, 3, 4, 5].map((i) => ( + + + + + + + ))} +
+
+ ); + } + + if (error) { + return ( +
+

AI Providers

+ + + Error + + Failed to load AI providers: {error.message} + + +
+ ); + } return (

AI Providers

- {Object.entries(PROVIDERS).map(([key, provider]) => { + {providers?.map((provider: LanguageModelProvider) => { return ( handleProviderClick(key as ModelProvider)} + onClick={() => handleProviderClick(provider.id)} > - {provider.displayName} - {isProviderSetup(key) ? ( + {provider.name} + {isProviderSetup(provider.id) ? ( Ready ) : ( - + Needs Setup )} diff --git a/src/components/SetupBanner.tsx b/src/components/SetupBanner.tsx index 1496d26..0d3d34e 100644 --- a/src/components/SetupBanner.tsx +++ b/src/components/SetupBanner.tsx @@ -11,7 +11,7 @@ import { } from "lucide-react"; import { providerSettingsRoute } from "@/routes/settings/providers/$provider"; import { settingsRoute } from "@/routes/settings"; -import { useSettings } from "@/hooks/useSettings"; + import { useState, useEffect, useCallback } from "react"; import { IpcClient } from "@/ipc/ipc_client"; import { @@ -24,6 +24,7 @@ import { Button } from "@/components/ui/button"; import { cn } from "@/lib/utils"; import { NodeSystemInfo } from "@/ipc/ipc_types"; import { usePostHog } from "posthog-js/react"; +import { useLanguageModelProviders } from "@/hooks/useLanguageModelProviders"; type NodeInstallStep = | "install" | "waiting-for-continue" @@ -33,7 +34,8 @@ type NodeInstallStep = export function SetupBanner() { const posthog = usePostHog(); const navigate = useNavigate(); - const { isAnyProviderSetup, loading } = useSettings(); + const { isAnyProviderSetup, isLoading: loading } = + useLanguageModelProviders(); const [nodeSystemInfo, setNodeSystemInfo] = useState( null, ); diff --git a/src/components/chat/HomeChatInput.tsx b/src/components/chat/HomeChatInput.tsx index d38129e..86626c5 100644 --- a/src/components/chat/HomeChatInput.tsx +++ b/src/components/chat/HomeChatInput.tsx @@ -20,7 +20,7 @@ export function HomeChatInput({ const posthog = usePostHog(); const [inputValue, setInputValue] = useAtom(homeChatInputValueAtom); const textareaRef = useRef(null); - const { settings, updateSettings, isAnyProviderSetup } = useSettings(); + const { settings, updateSettings } = useSettings(); const { isStreaming } = useStreamChat({ hasChatId: false, }); // eslint-disable-line @typescript-eslint/no-unused-vars @@ -137,10 +137,7 @@ export function HomeChatInput({ ) : ( +

+ Configure Provider +

+ + + Error Loading Provider Details + + Could not load provider data: {providersError.message} + + +
+
+ ); + } + + // Handle case where provider is not found (e.g., invalid ID in URL) + if (!providerData && !isDyad) { + return ( +
+
+ +

+ Provider Not Found +

+ + + Error + + The provider with ID "{provider}" could not be found. + + +
+
+ ); + } + return (
@@ -322,7 +410,7 @@ export function ProviderSettingsPage({ provider }: ProviderSettingsPageProps) { - {!isDyad && ( + {!isDyad && envVarName && ( ; +export type RegularModelProvider = Exclude< + ModelProvider, + "ollama" | "lmstudio" +>; export const MODEL_OPTIONS: Record = { openai: [ // https://platform.openai.com/docs/models/gpt-4.1 @@ -89,57 +92,6 @@ export const MODEL_OPTIONS: Record = { ], }; -export const PROVIDERS: Record< - RegularModelProvider, - { - displayName: string; - hasFreeTier?: boolean; - websiteUrl?: string; - gatewayPrefix: string; - } -> = { - openai: { - displayName: "OpenAI", - hasFreeTier: false, - websiteUrl: "https://platform.openai.com/api-keys", - gatewayPrefix: "", - }, - anthropic: { - displayName: "Anthropic", - hasFreeTier: false, - websiteUrl: "https://console.anthropic.com/settings/keys", - gatewayPrefix: "anthropic/", - }, - google: { - displayName: "Google", - hasFreeTier: true, - websiteUrl: "https://aistudio.google.com/app/apikey", - gatewayPrefix: "gemini/", - }, - openrouter: { - displayName: "OpenRouter", - hasFreeTier: true, - websiteUrl: "https://openrouter.ai/settings/keys", - gatewayPrefix: "openrouter/", - }, - auto: { - displayName: "Dyad", - websiteUrl: "https://academy.dyad.sh/settings", - gatewayPrefix: "", - }, -}; - -export const PROVIDER_TO_ENV_VAR: Record = { - openai: "OPENAI_API_KEY", - anthropic: "ANTHROPIC_API_KEY", - google: "GEMINI_API_KEY", - openrouter: "OPENROUTER_API_KEY", -}; - -export const ALLOWED_ENV_VARS = Object.keys(PROVIDER_TO_ENV_VAR).map( - (provider) => PROVIDER_TO_ENV_VAR[provider], -); - export const AUTO_MODELS = [ { provider: "google", diff --git a/src/db/schema.ts b/src/db/schema.ts index a1e4df8..1f25712 100644 --- a/src/db/schema.ts +++ b/src/db/schema.ts @@ -64,3 +64,54 @@ export const messagesRelations = relations(messages, ({ one }) => ({ references: [chats.id], }), })); + +export const language_model_providers = sqliteTable( + "language_model_providers", + { + id: text("id").primaryKey(), + name: text("name").notNull(), + api_base_url: text("api_base_url").notNull(), + env_var_name: text("env_var_name"), + createdAt: integer("created_at", { mode: "timestamp" }) + .notNull() + .default(sql`(unixepoch())`), + updatedAt: integer("updated_at", { mode: "timestamp" }) + .notNull() + .default(sql`(unixepoch())`), + }, +); + +export const language_models = sqliteTable("language_models", { + id: integer("id").primaryKey({ autoIncrement: true }), + name: text("name").notNull(), + provider_id: text("provider_id") + .notNull() + .references(() => language_model_providers.id, { onDelete: "cascade" }), + description: text("description"), + max_output_tokens: integer("max_output_tokens"), + context_window: integer("context_window"), + createdAt: integer("created_at", { mode: "timestamp" }) + .notNull() + .default(sql`(unixepoch())`), + updatedAt: integer("updated_at", { mode: "timestamp" }) + .notNull() + .default(sql`(unixepoch())`), +}); + +// Define relations for new tables +export const languageModelProvidersRelations = relations( + language_model_providers, + ({ many }) => ({ + languageModels: many(language_models), + }), +); + +export const languageModelsRelations = relations( + language_models, + ({ one }) => ({ + provider: one(language_model_providers, { + fields: [language_models.provider_id], + references: [language_model_providers.id], + }), + }), +); diff --git a/src/hooks/useLanguageModelProviders.ts b/src/hooks/useLanguageModelProviders.ts new file mode 100644 index 0000000..68793a9 --- /dev/null +++ b/src/hooks/useLanguageModelProviders.ts @@ -0,0 +1,42 @@ +import { useQuery } from "@tanstack/react-query"; +import { IpcClient } from "@/ipc/ipc_client"; +import type { LanguageModelProvider } from "@/ipc/ipc_types"; +import { useSettings } from "./useSettings"; +import { cloudProviders } from "@/lib/schemas"; + +export function useLanguageModelProviders() { + const ipcClient = IpcClient.getInstance(); + const { settings, envVars } = useSettings(); + + const queryResult = useQuery({ + queryKey: ["languageModelProviders"], + queryFn: async () => { + return ipcClient.getLanguageModelProviders(); + }, + }); + + const isProviderSetup = (provider: string) => { + const providerSettings = settings?.providerSettings[provider]; + if (queryResult.isLoading) { + return false; + } + if (providerSettings?.apiKey?.value) { + return true; + } + const providerData = queryResult.data?.find((p) => p.id === provider); + if (providerData?.envVarName && envVars[providerData.envVarName]) { + return true; + } + return false; + }; + + const isAnyProviderSetup = () => { + return cloudProviders.some((provider) => isProviderSetup(provider)); + }; + + return { + ...queryResult, + isProviderSetup, + isAnyProviderSetup, + }; +} diff --git a/src/hooks/useSettings.ts b/src/hooks/useSettings.ts index 53e2167..3994a78 100644 --- a/src/hooks/useSettings.ts +++ b/src/hooks/useSettings.ts @@ -2,15 +2,9 @@ import { useState, useEffect, useCallback } from "react"; import { useAtom } from "jotai"; import { userSettingsAtom, envVarsAtom } from "@/atoms/appAtoms"; import { IpcClient } from "@/ipc/ipc_client"; -import { cloudProviders, type UserSettings } from "@/lib/schemas"; +import { type UserSettings } from "@/lib/schemas"; import { usePostHog } from "posthog-js/react"; -const PROVIDER_TO_ENV_VAR: Record = { - openai: "OPENAI_API_KEY", - anthropic: "ANTHROPIC_API_KEY", - google: "GEMINI_API_KEY", -}; - const TELEMETRY_CONSENT_KEY = "dyadTelemetryConsent"; const TELEMETRY_USER_ID_KEY = "dyadTelemetryUserId"; @@ -81,17 +75,6 @@ export function useSettings() { } }; - const isProviderSetup = (provider: string) => { - const providerSettings = settings?.providerSettings[provider]; - if (providerSettings?.apiKey?.value) { - return true; - } - if (envVars[PROVIDER_TO_ENV_VAR[provider]]) { - return true; - } - return false; - }; - return { settings, envVars, @@ -99,13 +82,6 @@ export function useSettings() { error, updateSettings, - isProviderSetup, - isAnyProviderSetup: () => { - // Technically we should check for ollama and lmstudio being setup, but - // practically most users will want to use a cloud provider (at least - // some of the time) - return cloudProviders.some((provider) => isProviderSetup(provider)); - }, refreshSettings: () => { return loadInitialData(); }, diff --git a/src/ipc/handlers/app_handlers.ts b/src/ipc/handlers/app_handlers.ts index 92cf0f7..0b49fd6 100644 --- a/src/ipc/handlers/app_handlers.ts +++ b/src/ipc/handlers/app_handlers.ts @@ -22,7 +22,6 @@ import { killProcess, removeAppIfCurrentProcess, } from "../utils/process_manager"; -import { ALLOWED_ENV_VARS } from "../../constants/models"; import { getEnvVar } from "../utils/read_env"; import { readSettings } from "../../main/settings"; @@ -33,6 +32,7 @@ import util from "util"; import log from "electron-log"; import { getSupabaseProjectName } from "../../supabase_admin/supabase_management_client"; import { createLoggedHandler } from "./safe_handle"; +import { getLanguageModelProviders } from "../shared/language_model_helpers"; const logger = log.scope("app_handlers"); const handle = createLoggedHandler(logger); @@ -291,8 +291,11 @@ export function registerAppHandlers() { // Do NOT use handle for this, it contains sensitive information. ipcMain.handle("get-env-vars", async () => { const envVars: Record = {}; - for (const key of ALLOWED_ENV_VARS) { - envVars[key] = getEnvVar(key); + const providers = await getLanguageModelProviders(); + for (const provider of providers) { + if (provider.envVarName) { + envVars[provider.envVarName] = getEnvVar(provider.envVarName); + } } return envVars; }); diff --git a/src/ipc/handlers/chat_stream_handlers.ts b/src/ipc/handlers/chat_stream_handlers.ts index d777903..19b3479 100644 --- a/src/ipc/handlers/chat_stream_handlers.ts +++ b/src/ipc/handlers/chat_stream_handlers.ts @@ -212,7 +212,10 @@ export function registerChatStreamHandlers() { } else { // Normal AI processing for non-test prompts const settings = readSettings(); - const modelClient = getModelClient(settings.selectedModel, settings); + const modelClient = await getModelClient( + settings.selectedModel, + settings, + ); // Extract codebase information if app is associated with the chat let codebaseInfo = ""; diff --git a/src/ipc/handlers/language_model_handlers.ts b/src/ipc/handlers/language_model_handlers.ts new file mode 100644 index 0000000..6765bfc --- /dev/null +++ b/src/ipc/handlers/language_model_handlers.ts @@ -0,0 +1,16 @@ +import type { LanguageModelProvider } from "@/ipc/ipc_types"; +import { createLoggedHandler } from "./safe_handle"; +import log from "electron-log"; +import { getLanguageModelProviders } from "../shared/language_model_helpers"; + +const logger = log.scope("language_model_handlers"); +const handle = createLoggedHandler(logger); + +export function registerLanguageModelHandlers() { + handle( + "get-language-model-providers", + async (): Promise => { + return getLanguageModelProviders(); + }, + ); +} diff --git a/src/ipc/ipc_client.ts b/src/ipc/ipc_client.ts index f0e12e5..0416655 100644 --- a/src/ipc/ipc_client.ts +++ b/src/ipc/ipc_client.ts @@ -21,6 +21,7 @@ import type { TokenCountResult, ChatLogsData, BranchResult, + LanguageModelProvider, } from "./ipc_types"; import type { ProposalResult } from "@/lib/schemas"; import { showError } from "@/lib/toast"; @@ -724,13 +725,11 @@ export class IpcClient { // Get system platform (win32, darwin, linux) public async getSystemPlatform(): Promise { - try { - const platform = await this.ipcRenderer.invoke("window:get-platform"); - return platform; - } catch (error) { - showError(error); - throw error; - } + return this.ipcRenderer.invoke("get-system-platform"); + } + + public async getLanguageModelProviders(): Promise { + return this.ipcRenderer.invoke("get-language-model-providers"); } // --- End window control methods --- diff --git a/src/ipc/ipc_host.ts b/src/ipc/ipc_host.ts index 0275a01..03cf4a0 100644 --- a/src/ipc/ipc_host.ts +++ b/src/ipc/ipc_host.ts @@ -14,6 +14,8 @@ import { registerTokenCountHandlers } from "./handlers/token_count_handlers"; import { registerWindowHandlers } from "./handlers/window_handlers"; import { registerUploadHandlers } from "./handlers/upload_handlers"; import { registerVersionHandlers } from "./handlers/version_handlers"; +import { registerLanguageModelHandlers } from "./handlers/language_model_handlers"; + export function registerIpcHandlers() { // Register all IPC handlers by category registerAppHandlers(); @@ -32,4 +34,5 @@ export function registerIpcHandlers() { registerWindowHandlers(); registerUploadHandlers(); registerVersionHandlers(); + registerLanguageModelHandlers(); } diff --git a/src/ipc/ipc_types.ts b/src/ipc/ipc_types.ts index 2942abe..29ea347 100644 --- a/src/ipc/ipc_types.ts +++ b/src/ipc/ipc_types.ts @@ -132,3 +132,14 @@ export interface ChatLogsData { chat: Chat; codebase: string; } + +export interface LanguageModelProvider { + id: string; + name: string; + hasFreeTier?: boolean; + websiteUrl?: string; + gatewayPrefix?: string; + envVarName?: string; + apiBaseUrl?: string; + type: "custom" | "local" | "cloud"; +} diff --git a/src/ipc/shared/language_model_helpers.ts b/src/ipc/shared/language_model_helpers.ts new file mode 100644 index 0000000..2e2699d --- /dev/null +++ b/src/ipc/shared/language_model_helpers.ts @@ -0,0 +1,131 @@ +import { db } from "@/db"; +import { language_model_providers as languageModelProvidersSchema } from "@/db/schema"; +import { RegularModelProvider } from "@/constants/models"; +import type { LanguageModelProvider } from "@/ipc/ipc_types"; + +export const PROVIDER_TO_ENV_VAR: Record = { + openai: "OPENAI_API_KEY", + anthropic: "ANTHROPIC_API_KEY", + google: "GEMINI_API_KEY", + openrouter: "OPENROUTER_API_KEY", +}; + +export const PROVIDERS: Record< + RegularModelProvider, + { + displayName: string; + hasFreeTier?: boolean; + websiteUrl?: string; + gatewayPrefix: string; + } +> = { + openai: { + displayName: "OpenAI", + hasFreeTier: false, + websiteUrl: "https://platform.openai.com/api-keys", + gatewayPrefix: "", + }, + anthropic: { + displayName: "Anthropic", + hasFreeTier: false, + websiteUrl: "https://console.anthropic.com/settings/keys", + gatewayPrefix: "anthropic/", + }, + google: { + displayName: "Google", + hasFreeTier: true, + websiteUrl: "https://aistudio.google.com/app/apikey", + gatewayPrefix: "gemini/", + }, + openrouter: { + displayName: "OpenRouter", + hasFreeTier: true, + websiteUrl: "https://openrouter.ai/settings/keys", + gatewayPrefix: "openrouter/", + }, + auto: { + displayName: "Dyad", + websiteUrl: "https://academy.dyad.sh/settings", + gatewayPrefix: "", + }, +}; + +/** + * Fetches language model providers from both the database (custom) and hardcoded constants (cloud), + * merging them with custom providers taking precedence. + * @returns A promise that resolves to an array of LanguageModelProvider objects. + */ +export async function getLanguageModelProviders(): Promise< + LanguageModelProvider[] +> { + // Fetch custom providers from the database + const customProvidersDb = await db + .select() + .from(languageModelProvidersSchema); + + const customProvidersMap = new Map(); + for (const cp of customProvidersDb) { + customProvidersMap.set(cp.id, { + id: cp.id, + name: cp.name, + apiBaseUrl: cp.api_base_url, + envVarName: cp.env_var_name ?? undefined, + type: "custom", + // hasFreeTier, websiteUrl, gatewayPrefix are not in the custom DB schema + // They will be undefined unless overridden by hardcoded values if IDs match + }); + } + + // Get hardcoded cloud providers + const hardcodedProviders: LanguageModelProvider[] = []; + for (const providerKey in PROVIDERS) { + if (Object.prototype.hasOwnProperty.call(PROVIDERS, providerKey)) { + // Ensure providerKey is a key of PROVIDERS + const key = providerKey as keyof typeof PROVIDERS; + const providerDetails = PROVIDERS[key]; + if (providerDetails) { + // Ensure providerDetails is not undefined + hardcodedProviders.push({ + id: key, + name: providerDetails.displayName, + hasFreeTier: providerDetails.hasFreeTier, + websiteUrl: providerDetails.websiteUrl, + gatewayPrefix: providerDetails.gatewayPrefix, + envVarName: PROVIDER_TO_ENV_VAR[key] ?? undefined, + type: "cloud", + // apiBaseUrl is not directly in PROVIDERS + }); + } + } + } + + // Merge lists: custom providers take precedence + const mergedProvidersMap = new Map(); + + // Add all hardcoded providers first + for (const hp of hardcodedProviders) { + mergedProvidersMap.set(hp.id, hp); + } + + // Add/overwrite with custom providers from DB + for (const [id, cp] of customProvidersMap) { + const existingProvider = mergedProvidersMap.get(id); + if (existingProvider) { + // If exists, merge. Custom fields take precedence. + mergedProvidersMap.set(id, { + ...existingProvider, // start with hardcoded + ...cp, // override with custom where defined + id: cp.id, // ensure custom id is used + name: cp.name, // ensure custom name is used + type: "custom", // explicitly set type to custom + apiBaseUrl: cp.apiBaseUrl ?? existingProvider.apiBaseUrl, + envVarName: cp.envVarName ?? existingProvider.envVarName, + }); + } else { + // If it doesn't exist in hardcoded, just add the custom one + mergedProvidersMap.set(id, cp); + } + } + + return Array.from(mergedProvidersMap.values()); +} diff --git a/src/ipc/utils/get_model_client.ts b/src/ipc/utils/get_model_client.ts index df36abf..a681d0a 100644 --- a/src/ipc/utils/get_model_client.ts +++ b/src/ipc/utils/get_model_client.ts @@ -5,36 +5,38 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider"; import { createOllama } from "ollama-ai-provider"; import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import type { LargeLanguageModel, UserSettings } from "../../lib/schemas"; -import { - PROVIDER_TO_ENV_VAR, - AUTO_MODELS, - PROVIDERS, - MODEL_OPTIONS, -} from "../../constants/models"; +import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models"; import { getEnvVar } from "./read_env"; import log from "electron-log"; +import { getLanguageModelProviders } from "../shared/language_model_helpers"; const logger = log.scope("getModelClient"); -export function getModelClient( +export async function getModelClient( model: LargeLanguageModel, settings: UserSettings, ) { + const allProviders = await getLanguageModelProviders(); + const dyadApiKey = settings.providerSettings?.auto?.apiKey?.value; // Handle 'auto' provider by trying each model in AUTO_MODELS until one works if (model.provider === "auto") { - // Try each model in AUTO_MODELS in order until finding one with an API key for (const autoModel of AUTO_MODELS) { + const providerInfo = allProviders.find( + (p) => p.id === autoModel.provider, + ); + const envVarName = providerInfo?.envVarName; + const apiKey = dyadApiKey || - settings.providerSettings?.[autoModel.provider]?.apiKey || - getEnvVar(PROVIDER_TO_ENV_VAR[autoModel.provider]); + settings.providerSettings?.[autoModel.provider]?.apiKey?.value || + (envVarName ? getEnvVar(envVarName) : undefined); if (apiKey) { logger.log( `Using provider: ${autoModel.provider} model: ${autoModel.name}`, ); - // Use the first model that has an API key - return getModelClient( + // Recursively call with the specific model found + return await getModelClient( { provider: autoModel.provider, name: autoModel.name, @@ -43,27 +45,48 @@ export function getModelClient( ); } } - // If no models have API keys, throw an error - throw new Error("No API keys available for any model in AUTO_MODELS"); + throw new Error( + "No API keys available for any model supported by the 'auto' provider.", + ); } + // --- Handle specific provider --- + const providerConfig = allProviders.find((p) => p.id === model.provider); + + if (!providerConfig) { + throw new Error(`Configuration not found for provider: ${model.provider}`); + } + + // Handle Dyad Pro override if (dyadApiKey && settings.enableDyadPro) { - const provider = createOpenAI({ - apiKey: dyadApiKey, - baseURL: "https://llm-gateway.dyad.sh/v1", - }); - const providerInfo = PROVIDERS[model.provider as keyof typeof PROVIDERS]; - logger.info("Using Dyad Pro API key"); - // Do not use free variant (for openrouter). - const modelName = model.name.split(":free")[0]; - return provider(`${providerInfo.gatewayPrefix}${modelName}`); + // Check if the selected provider supports Dyad Pro (has a gateway prefix) + if (providerConfig.gatewayPrefix) { + const provider = createOpenAI({ + apiKey: dyadApiKey, + baseURL: "https://llm-gateway.dyad.sh/v1", + }); + logger.info("Using Dyad Pro API key via Gateway"); + // Do not use free variant (for openrouter). + const modelName = model.name.split(":free")[0]; + return provider(`${providerConfig.gatewayPrefix}${modelName}`); + } else { + logger.warn( + `Dyad Pro enabled, but provider ${model.provider} does not have a gateway prefix defined. Falling back to direct provider connection.`, + ); + // Fall through to regular provider logic if gateway prefix is missing + } } + // Get API key for the specific provider const apiKey = settings.providerSettings?.[model.provider]?.apiKey?.value || - getEnvVar(PROVIDER_TO_ENV_VAR[model.provider]); - switch (model.provider) { + (providerConfig.envVarName + ? getEnvVar(providerConfig.envVarName) + : undefined); + + // Create client based on provider ID or type + switch (providerConfig.id) { case "openai": { const provider = createOpenAI({ apiKey }); return provider(model.name); @@ -81,18 +104,38 @@ export function getModelClient( return provider(model.name); } case "ollama": { - const provider = createOllama(); + // Ollama typically runs locally and doesn't require an API key in the same way + const provider = createOllama({ + baseURL: providerConfig.apiBaseUrl, + }); return provider(model.name); } case "lmstudio": { - // Using LM Studio's OpenAI compatible API - const baseURL = "http://localhost:1234/v1"; // Default LM Studio OpenAI API URL - const provider = createOpenAICompatible({ name: "lmstudio", baseURL }); + // LM Studio uses OpenAI compatible API + const baseURL = providerConfig.apiBaseUrl || "http://localhost:1234/v1"; + const provider = createOpenAICompatible({ + name: "lmstudio", + baseURL, + }); return provider(model.name); } default: { - // Ensure exhaustive check if more providers are added - const _exhaustiveCheck: never = model.provider; + // Handle custom providers + if (providerConfig.type === "custom") { + if (!providerConfig.apiBaseUrl) { + throw new Error( + `Custom provider ${model.provider} is missing the API Base URL.`, + ); + } + // Assume custom providers are OpenAI compatible for now + const provider = createOpenAICompatible({ + name: providerConfig.id, + baseURL: providerConfig.apiBaseUrl, + apiKey: apiKey, + }); + return provider(model.name); + } + // If it's not a known ID and not type 'custom', it's unsupported throw new Error(`Unsupported model provider: ${model.provider}`); } } diff --git a/src/preload.ts b/src/preload.ts index cfbfd68..e304d96 100644 --- a/src/preload.ts +++ b/src/preload.ts @@ -5,6 +5,7 @@ import { contextBridge, ipcRenderer } from "electron"; // Whitelist of valid channels const validInvokeChannels = [ + "get-language-model-providers", "chat:add-dep", "chat:message", "chat:cancel", diff --git a/vite.main.config.mts b/vite.main.config.mts index 18045b0..db7944c 100644 --- a/vite.main.config.mts +++ b/vite.main.config.mts @@ -1,7 +1,13 @@ import { defineConfig } from "vite"; +import path from "path"; // https://vitejs.dev/config export default defineConfig({ + resolve: { + alias: { + "@": path.resolve(__dirname, "./src"), + }, + }, build: { rollupOptions: { external: ["better-sqlite3"],