diff --git a/package-lock.json b/package-lock.json index fa1092c..ec722f4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -52,6 +52,7 @@ "kill-port": "^2.0.1", "lucide-react": "^0.487.0", "monaco-editor": "^0.52.2", + "ollama-ai-provider": "^1.2.0", "openai": "^4.91.1", "posthog-js": "^1.236.3", "react": "^19.0.0", @@ -16841,6 +16842,28 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/ollama-ai-provider": { + "version": "1.2.0", + "resolved": "https://registry.npmjs.org/ollama-ai-provider/-/ollama-ai-provider-1.2.0.tgz", + "integrity": "sha512-jTNFruwe3O/ruJeppI/quoOUxG7NA6blG3ZyQj3lei4+NnJo7bi3eIRWqlVpRlu/mbzbFXeJSBuYQWF6pzGKww==", + "license": "Apache-2.0", + "dependencies": { + "@ai-sdk/provider": "^1.0.0", + "@ai-sdk/provider-utils": "^2.0.0", + "partial-json": "0.1.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + } + } + }, "node_modules/on-finished": { "version": "2.4.1", "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz", @@ -17236,6 +17259,12 @@ "node": ">= 0.8" } }, + "node_modules/partial-json": { + "version": "0.1.7", + "resolved": "https://registry.npmjs.org/partial-json/-/partial-json-0.1.7.tgz", + "integrity": "sha512-Njv/59hHaokb/hRUjce3Hdv12wd60MtM9Z5Olmn+nehe0QDAsRtRbJPvJ0Z91TusF0SuZRIvnM+S4l6EIP8leA==", + "license": "MIT" + }, "node_modules/path-browserify": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/path-browserify/-/path-browserify-1.0.1.tgz", diff --git a/package.json b/package.json index 4b569bd..f5c37b4 100644 --- a/package.json +++ b/package.json @@ -67,6 +67,7 @@ "@ai-sdk/google": "^1.2.10", "@ai-sdk/openai": "^1.3.7", "@biomejs/biome": "^1.9.4", + "@dyad-sh/supabase-management-js": "v1.0.0", "@monaco-editor/react": "^4.7.0-rc.0", "@openrouter/ai-sdk-provider": "^0.4.5", "@radix-ui/react-accordion": "^1.2.4", @@ -105,6 +106,7 @@ "kill-port": "^2.0.1", "lucide-react": "^0.487.0", "monaco-editor": "^0.52.2", + "ollama-ai-provider": "^1.2.0", "openai": "^4.91.1", "posthog-js": "^1.236.3", "react": "^19.0.0", @@ -116,7 +118,6 @@ "shell-env": "^4.0.1", "shiki": "^3.2.1", "sonner": "^2.0.3", - "@dyad-sh/supabase-management-js": "v1.0.0", "tailwind-merge": "^3.1.0", "tailwindcss": "^4.1.3", "tree-kill": "^1.2.2", @@ -124,4 +125,4 @@ "update-electron-app": "^3.1.1", "uuid": "^11.1.0" } -} \ No newline at end of file +} diff --git a/src/atoms/localModelsAtoms.ts b/src/atoms/localModelsAtoms.ts new file mode 100644 index 0000000..4afddde --- /dev/null +++ b/src/atoms/localModelsAtoms.ts @@ -0,0 +1,6 @@ +import { atom } from "jotai"; +import { type LocalModel } from "@/ipc/ipc_types"; + +export const localModelsAtom = atom([]); +export const localModelsLoadingAtom = atom(false); +export const localModelsErrorAtom = atom(null); diff --git a/src/components/ModelPicker.tsx b/src/components/ModelPicker.tsx index ae921aa..ce15a30 100644 --- a/src/components/ModelPicker.tsx +++ b/src/components/ModelPicker.tsx @@ -1,17 +1,25 @@ import type { LargeLanguageModel, ModelProvider } from "@/lib/schemas"; import { Button } from "@/components/ui/button"; -import { - Popover, - PopoverContent, - PopoverTrigger, -} from "@/components/ui/popover"; import { Tooltip, TooltipContent, TooltipTrigger, } from "@/components/ui/tooltip"; -import { useState } from "react"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuLabel, + DropdownMenuSeparator, + DropdownMenuTrigger, + DropdownMenuSub, 
+ DropdownMenuSubTrigger, + DropdownMenuSubContent, +} from "@/components/ui/dropdown-menu"; +import { useEffect, useState } from "react"; import { MODEL_OPTIONS } from "@/constants/models"; +import { useLocalModels } from "@/hooks/useLocalModels"; +import { ChevronDown } from "lucide-react"; interface ModelPickerProps { selectedModel: LargeLanguageModel; @@ -23,9 +31,37 @@ export function ModelPicker({ onModelSelect, }: ModelPickerProps) { const [open, setOpen] = useState(false); - const modelDisplayName = MODEL_OPTIONS[selectedModel.provider].find( - (model) => model.name === selectedModel.name - )?.displayName; + const { + models: localModels, + loading: localModelsLoading, + error: localModelsError, + loadModels, + } = useLocalModels(); + + // Load local models when the component mounts or the dropdown opens + useEffect(() => { + if (open) { + loadModels(); + } + }, [open, loadModels]); + + // Get display name for the selected model + const getModelDisplayName = () => { + if (selectedModel.provider === "ollama") { + return ( + localModels.find((model) => model.modelName === selectedModel.name) + ?.displayName || selectedModel.name + ); + } + + return ( + MODEL_OPTIONS[selectedModel.provider]?.find( + (model) => model.name === selectedModel.name + )?.displayName || selectedModel.name + ); + }; + + const modelDisplayName = getModelDisplayName(); // Flatten the model options into a single array with provider information const allModels = Object.entries(MODEL_OPTIONS).flatMap( @@ -36,9 +72,13 @@ export function ModelPicker({ })) ); + // Determine if we have local models available + const hasLocalModels = + !localModelsLoading && !localModelsError && localModels.length > 0; + return ( - - + + - - -
- {allModels.map((model) => ( - - - - - {model.description} - - ))} -
-
-
+ + )) + )} + + + + ); } diff --git a/src/components/ui/dropdown-menu.tsx b/src/components/ui/dropdown-menu.tsx index 0d6741b..f80a794 100644 --- a/src/components/ui/dropdown-menu.tsx +++ b/src/components/ui/dropdown-menu.tsx @@ -1,13 +1,13 @@ -import * as React from "react" -import * as DropdownMenuPrimitive from "@radix-ui/react-dropdown-menu" -import { CheckIcon, ChevronRightIcon, CircleIcon } from "lucide-react" +import * as React from "react"; +import * as DropdownMenuPrimitive from "@radix-ui/react-dropdown-menu"; +import { CheckIcon, ChevronRightIcon, CircleIcon } from "lucide-react"; -import { cn } from "@/lib/utils" +import { cn } from "@/lib/utils"; function DropdownMenu({ ...props }: React.ComponentProps) { - return + return ; } function DropdownMenuPortal({ @@ -15,7 +15,7 @@ function DropdownMenuPortal({ }: React.ComponentProps) { return ( - ) + ); } function DropdownMenuTrigger({ @@ -26,7 +26,7 @@ function DropdownMenuTrigger({ data-slot="dropdown-menu-trigger" {...props} /> - ) + ); } function DropdownMenuContent({ @@ -46,7 +46,7 @@ function DropdownMenuContent({ {...props} /> - ) + ); } function DropdownMenuGroup({ @@ -54,7 +54,7 @@ function DropdownMenuGroup({ }: React.ComponentProps) { return ( - ) + ); } function DropdownMenuItem({ @@ -63,8 +63,8 @@ function DropdownMenuItem({ variant = "default", ...props }: React.ComponentProps & { - inset?: boolean - variant?: "default" | "destructive" + inset?: boolean; + variant?: "default" | "destructive"; }) { return ( - ) + ); } function DropdownMenuCheckboxItem({ @@ -103,7 +103,7 @@ function DropdownMenuCheckboxItem({ {children} - ) + ); } function DropdownMenuRadioGroup({ @@ -114,7 +114,7 @@ function DropdownMenuRadioGroup({ data-slot="dropdown-menu-radio-group" {...props} /> - ) + ); } function DropdownMenuRadioItem({ @@ -138,7 +138,7 @@ function DropdownMenuRadioItem({ {children} - ) + ); } function DropdownMenuLabel({ @@ -146,7 +146,7 @@ function DropdownMenuLabel({ inset, ...props }: React.ComponentProps & { - inset?: boolean + inset?: boolean; }) { return ( - ) + ); } function DropdownMenuSeparator({ @@ -171,7 +171,7 @@ function DropdownMenuSeparator({ className={cn("bg-border -mx-1 my-1 h-px", className)} {...props} /> - ) + ); } function DropdownMenuShortcut({ @@ -187,13 +187,13 @@ function DropdownMenuShortcut({ )} {...props} /> - ) + ); } function DropdownMenuSub({ ...props }: React.ComponentProps) { - return + return ; } function DropdownMenuSubTrigger({ @@ -202,7 +202,7 @@ function DropdownMenuSubTrigger({ children, ...props }: React.ComponentProps & { - inset?: boolean + inset?: boolean; }) { return ( - ) + ); } function DropdownMenuSubContent({ @@ -233,7 +233,7 @@ function DropdownMenuSubContent({ )} {...props} /> - ) + ); } export { @@ -252,4 +252,4 @@ export { DropdownMenuSub, DropdownMenuSubTrigger, DropdownMenuSubContent, -} +}; diff --git a/src/constants/models.ts b/src/constants/models.ts index 0074877..0c83b51 100644 --- a/src/constants/models.ts +++ b/src/constants/models.ts @@ -6,7 +6,8 @@ export interface ModelOption { tag?: string; } -export const MODEL_OPTIONS: Record = { +type RegularModelProvider = Exclude; +export const MODEL_OPTIONS: Record = { openai: [ { name: "gpt-4.1", @@ -52,7 +53,7 @@ export const MODEL_OPTIONS: Record = { }; export const PROVIDERS: Record< - ModelProvider, + RegularModelProvider, { name: string; displayName: string; diff --git a/src/hooks/useLocalModels.ts b/src/hooks/useLocalModels.ts new file mode 100644 index 0000000..75e1996 --- /dev/null +++ b/src/hooks/useLocalModels.ts 
@@ -0,0 +1,43 @@
+import { useCallback } from "react";
+import { useAtom } from "jotai";
+import {
+  localModelsAtom,
+  localModelsLoadingAtom,
+  localModelsErrorAtom,
+} from "@/atoms/localModelsAtoms";
+import { IpcClient } from "@/ipc/ipc_client";
+
+export function useLocalModels() {
+  const [models, setModels] = useAtom(localModelsAtom);
+  const [loading, setLoading] = useAtom(localModelsLoadingAtom);
+  const [error, setError] = useAtom(localModelsErrorAtom);
+
+  const ipcClient = IpcClient.getInstance();
+
+  /**
+   * Load local models from Ollama
+   */
+  const loadModels = useCallback(async () => {
+    setLoading(true);
+    try {
+      const modelList = await ipcClient.listLocalModels();
+      setModels(modelList);
+      setError(null);
+
+      return modelList;
+    } catch (error) {
+      console.error("Error loading local models:", error);
+      setError(error instanceof Error ? error : new Error(String(error)));
+      return [];
+    } finally {
+      setLoading(false);
+    }
+  }, [ipcClient, setModels, setError, setLoading]);
+
+  return {
+    models,
+    loading,
+    error,
+    loadModels,
+  };
+}
diff --git a/src/ipc/handlers/local_model_handlers.ts b/src/ipc/handlers/local_model_handlers.ts
new file mode 100644
index 0000000..b0234ff
--- /dev/null
+++ b/src/ipc/handlers/local_model_handlers.ts
@@ -0,0 +1,80 @@
+import { ipcMain } from "electron";
+import log from "electron-log";
+import { LocalModelListResponse, LocalModel } from "../ipc_types";
+
+const logger = log.scope("local_model_handlers");
+const OLLAMA_API_URL = "http://localhost:11434";
+
+interface OllamaModel {
+  name: string;
+  modified_at: string;
+  size: number;
+  digest: string;
+  details: {
+    format: string;
+    family: string;
+    families: string[];
+    parameter_size: string;
+    quantization_level: string;
+  };
+}
+
+export function registerLocalModelHandlers() {
+  // Get list of models from Ollama
+  ipcMain.handle(
+    "local-models:list",
+    async (): Promise<LocalModelListResponse> => {
+      try {
+        const response = await fetch(`${OLLAMA_API_URL}/api/tags`);
+
+        if (!response.ok) {
+          throw new Error(`Failed to fetch models: ${response.statusText}`);
+        }
+
+        const data = await response.json();
+        const ollamaModels: OllamaModel[] = data.models || [];
+
+        // Transform the data to return just what we need
+        const models: LocalModel[] = ollamaModels.map((model) => {
+          // Extract display name by cleaning up the model name
+          // For names like "llama2:latest" we want to show "Llama 2"
+          let displayName = model.name.split(":")[0]; // Remove tags like ":latest"
+
+          // Capitalize and add spaces for readability
+          displayName = displayName
+            .replace(/-/g, " ")
+            .replace(/(\d+)/, " $1 ") // Add spaces around numbers
+            .split(" ")
+            .map((word) => word.charAt(0).toUpperCase() + word.slice(1))
+            .join(" ")
+            .trim();
+
+          return {
+            modelName: model.name, // The actual model name used for API calls
+            displayName, // The user-friendly name
+          };
+        });
+
+        logger.info(
+          `Successfully fetched ${models.length} local models from Ollama`
+        );
+        return { models, error: null };
+      } catch (error) {
+        if (
+          error instanceof TypeError &&
+          (error as Error).message.includes("fetch failed")
+        ) {
+          logger.error("Could not connect to Ollama. Is it running?");
+          return {
+            models: [],
+            error:
+              "Could not connect to Ollama. Make sure it's running at http://localhost:11434",
+          };
+        }
+
+        logger.error("Error fetching local models:", error);
+        return { models: [], error: "Failed to fetch models from Ollama" };
+      }
+    }
+  );
+}
diff --git a/src/ipc/ipc_client.ts b/src/ipc/ipc_client.ts
index 0dd9948..8141629 100644
--- a/src/ipc/ipc_client.ts
+++ b/src/ipc/ipc_client.ts
@@ -17,6 +17,8 @@ import type {
   Message,
   Version,
   SystemDebugInfo,
+  LocalModel,
+  LocalModelListResponse,
 } from "./ipc_types";
 import type { CodeProposal, ProposalResult } from "@/lib/schemas";
 import { showError } from "@/lib/toast";
@@ -729,14 +731,24 @@ export class IpcClient {
   // Get system debug information
   public async getSystemDebugInfo(): Promise<SystemDebugInfo> {
     try {
-      const result = await this.ipcRenderer.invoke("get-system-debug-info");
-      return result;
+      const data = await this.ipcRenderer.invoke("get-system-debug-info");
+      return data;
     } catch (error) {
       showError(error);
       throw error;
     }
   }
 
+  public async listLocalModels(): Promise<LocalModel[]> {
+    const { models, error } = (await this.ipcRenderer.invoke(
+      "local-models:list"
+    )) as LocalModelListResponse;
+    if (error) {
+      throw new Error(error);
+    }
+    return models;
+  }
+
   // Listen for deep link events
   public onDeepLinkReceived(
     callback: (data: DeepLinkData) => void
diff --git a/src/ipc/ipc_host.ts b/src/ipc/ipc_host.ts
index 64c1e69..08d7575 100644
--- a/src/ipc/ipc_host.ts
+++ b/src/ipc/ipc_host.ts
@@ -9,6 +9,7 @@ import { registerNodeHandlers } from "./handlers/node_handlers";
 import { registerProposalHandlers } from "./handlers/proposal_handlers";
 import { registerDebugHandlers } from "./handlers/debug_handlers";
 import { registerSupabaseHandlers } from "./handlers/supabase_handlers";
+import { registerLocalModelHandlers } from "./handlers/local_model_handlers";
 
 export function registerIpcHandlers() {
   // Register all IPC handlers by category
@@ -23,4 +24,5 @@
   registerProposalHandlers();
   registerDebugHandlers();
   registerSupabaseHandlers();
+  registerLocalModelHandlers();
 }
diff --git a/src/ipc/ipc_types.ts b/src/ipc/ipc_types.ts
index e71e3d9..67d6919 100644
--- a/src/ipc/ipc_types.ts
+++ b/src/ipc/ipc_types.ts
@@ -91,3 +91,13 @@ export interface SystemDebugInfo {
   architecture: string;
   logs: string;
 }
+
+export interface LocalModel {
+  modelName: string; // Name used for API calls (e.g., "llama2:latest")
+  displayName: string; // User-friendly name (e.g., "Llama 2")
+}
+
+export type LocalModelListResponse = {
+  models: LocalModel[];
+  error: string | null;
+};
diff --git a/src/ipc/utils/get_model_client.ts b/src/ipc/utils/get_model_client.ts
index d02bb70..4f089c4 100644
--- a/src/ipc/utils/get_model_client.ts
+++ b/src/ipc/utils/get_model_client.ts
@@ -2,6 +2,8 @@ import { createOpenAI } from "@ai-sdk/openai";
 import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google";
 import { createAnthropic } from "@ai-sdk/anthropic";
 import { createOpenRouter } from "@openrouter/ai-sdk-provider";
+import { createOllama } from "ollama-ai-provider";
+
 import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
 import { PROVIDER_TO_ENV_VAR, AUTO_MODELS } from "../../constants/models";
 import { getEnvVar } from "./read_env";
@@ -56,6 +58,10 @@ export function getModelClient(
       const provider = createOpenRouter({ apiKey });
       return provider(model.name);
     }
+    case "ollama": {
+      const provider = createOllama();
+      return provider(model.name);
+    }
     default: {
       // Ensure exhaustive check if more providers are added
       const _exhaustiveCheck: never = model.provider;
diff --git a/src/lib/schemas.ts b/src/lib/schemas.ts
index e29e250..0bbe639 100644
--- a/src/lib/schemas.ts
+++ b/src/lib/schemas.ts
@@ -35,6 +35,7 @@ export const ModelProviderSchema = z.enum([
   "google",
   "auto",
   "openrouter",
+  "ollama",
 ]);
 
 /**
diff --git a/src/preload.ts b/src/preload.ts
index 4ec1f19..279105f 100644
--- a/src/preload.ts
+++ b/src/preload.ts
@@ -45,6 +45,7 @@ const validInvokeChannels = [
   "supabase:list-projects",
   "supabase:set-app-project",
   "supabase:unset-app-project",
+  "local-models:list",
 ] as const;
 
 // Add valid receive channels