diff --git a/package-lock.json b/package-lock.json index ec722f4..7310b9c 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12,6 +12,7 @@ "@ai-sdk/anthropic": "^1.2.8", "@ai-sdk/google": "^1.2.10", "@ai-sdk/openai": "^1.3.7", + "@ai-sdk/openai-compatible": "^0.2.13", "@biomejs/biome": "^1.9.4", "@dyad-sh/supabase-management-js": "v1.0.0", "@monaco-editor/react": "^4.7.0-rc.0", @@ -104,6 +105,51 @@ "node": ">=20" } }, + "node_modules/@ai-sdk/openai-compatible": { + "version": "0.2.13", + "resolved": "https://registry.npmjs.org/@ai-sdk/openai-compatible/-/openai-compatible-0.2.13.tgz", + "integrity": "sha512-tB+lL8Z3j0qDod/mvxwjrPhbLUHp/aQW+NvMoJaqeTtP+Vmv5qR800pncGczxn5WN0pllQm+7aIRDnm69XeSbg==", + "dev": true, + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "@ai-sdk/provider-utils": "2.2.7" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.0.0" + } + }, + "node_modules/@ai-sdk/openai-compatible/node_modules/@ai-sdk/provider": { + "version": "1.1.3", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider/-/provider-1.1.3.tgz", + "integrity": "sha512-qZMxYJ0qqX/RfnuIaab+zp8UAeJn/ygXXAffR5I4N0n1IrvA6qBsjc8hXLmBiMV2zoXlifkacF7sEFnYnjBcqg==", + "dev": true, + "dependencies": { + "json-schema": "^0.4.0" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/@ai-sdk/openai-compatible/node_modules/@ai-sdk/provider-utils": { + "version": "2.2.7", + "resolved": "https://registry.npmjs.org/@ai-sdk/provider-utils/-/provider-utils-2.2.7.tgz", + "integrity": "sha512-kM0xS3GWg3aMChh9zfeM+80vEZfXzR3JEUBdycZLtbRZ2TRT8xOj3WodGHPb06sUK5yD7pAXC/P7ctsi2fvUGQ==", + "dev": true, + "dependencies": { + "@ai-sdk/provider": "1.1.3", + "nanoid": "^3.3.8", + "secure-json-parse": "^2.7.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "zod": "^3.23.8" + } + }, "node_modules/@ai-sdk/anthropic": { "version": "1.2.8", "resolved": "https://registry.npmjs.org/@ai-sdk/anthropic/-/anthropic-1.2.8.tgz", diff --git a/package.json b/package.json index 979106e..8b73f04 100644 --- a/package.json +++ b/package.json @@ -63,6 +63,7 @@ "vitest": "^3.1.1" }, "dependencies": { + "@ai-sdk/openai-compatible": "^0.2.13", "@ai-sdk/anthropic": "^1.2.8", "@ai-sdk/google": "^1.2.10", "@ai-sdk/openai": "^1.3.7", @@ -125,4 +126,4 @@ "update-electron-app": "^3.1.1", "uuid": "^11.1.0" } -} \ No newline at end of file +} diff --git a/src/atoms/localModelsAtoms.ts b/src/atoms/localModelsAtoms.ts index 4afddde..f783f0c 100644 --- a/src/atoms/localModelsAtoms.ts +++ b/src/atoms/localModelsAtoms.ts @@ -4,3 +4,7 @@ import { type LocalModel } from "@/ipc/ipc_types"; export const localModelsAtom = atom([]); export const localModelsLoadingAtom = atom(false); export const localModelsErrorAtom = atom(null); + +export const lmStudioModelsAtom = atom([]); +export const lmStudioModelsLoadingAtom = atom(false); +export const lmStudioModelsErrorAtom = atom(null); diff --git a/src/components/ModelPicker.tsx b/src/components/ModelPicker.tsx index ce15a30..24e8ff3 100644 --- a/src/components/ModelPicker.tsx +++ b/src/components/ModelPicker.tsx @@ -19,8 +19,9 @@ import { import { useEffect, useState } from "react"; import { MODEL_OPTIONS } from "@/constants/models"; import { useLocalModels } from "@/hooks/useLocalModels"; +import { useLocalLMSModels } from "@/hooks/useLMStudioModels"; import { ChevronDown } from "lucide-react"; - +import { LocalModel } from "@/ipc/ipc_types"; interface ModelPickerProps { selectedModel: LargeLanguageModel; onModelSelect: (model: LargeLanguageModel) => void; @@ 
-31,29 +32,47 @@ export function ModelPicker({ onModelSelect, }: ModelPickerProps) { const [open, setOpen] = useState(false); + + // Ollama Models Hook const { - models: localModels, - loading: localModelsLoading, - error: localModelsError, - loadModels, + models: ollamaModels, + loading: ollamaLoading, + error: ollamaError, + loadModels: loadOllamaModels, } = useLocalModels(); - // Load local models when the component mounts or the dropdown opens + // LM Studio Models Hook + const { + models: lmStudioModels, + loading: lmStudioLoading, + error: lmStudioError, + loadModels: loadLMStudioModels, + } = useLocalLMSModels(); + + // Load models when the dropdown opens useEffect(() => { if (open) { - loadModels(); + loadOllamaModels(); + loadLMStudioModels(); } - }, [open, loadModels]); + }, [open, loadOllamaModels, loadLMStudioModels]); // Get display name for the selected model const getModelDisplayName = () => { if (selectedModel.provider === "ollama") { return ( - localModels.find((model) => model.modelName === selectedModel.name) + ollamaModels.find((model: LocalModel) => model.modelName === selectedModel.name) ?.displayName || selectedModel.name ); } + if (selectedModel.provider === "lmstudio") { + return ( + lmStudioModels.find((model: LocalModel) => model.modelName === selectedModel.name) + ?.displayName || selectedModel.name // Fallback to path if not found + ); + } + // Fallback for cloud models return ( MODEL_OPTIONS[selectedModel.provider]?.find( (model) => model.name === selectedModel.name @@ -63,8 +82,8 @@ export function ModelPicker({ const modelDisplayName = getModelDisplayName(); - // Flatten the model options into a single array with provider information - const allModels = Object.entries(MODEL_OPTIONS).flatMap( + // Flatten the cloud model options + const cloudModels = Object.entries(MODEL_OPTIONS).flatMap( ([provider, models]) => models.map((model) => ({ ...model, @@ -72,9 +91,9 @@ export function ModelPicker({ })) ); - // Determine if we have local models available - const hasLocalModels = - !localModelsLoading && !localModelsError && localModels.length > 0; + // Determine availability of local models + const hasOllamaModels = !ollamaLoading && !ollamaError && ollamaModels.length > 0; + const hasLMStudioModels = !lmStudioLoading && !lmStudioError && lmStudioModels.length > 0; return ( @@ -91,12 +110,12 @@ export function ModelPicker({ - + {/* Increased width slightly */} Cloud Models {/* Cloud models */} - {allModels.map((model) => ( + {cloudModels.map((model) => ( - {/* Ollama Models Dropdown */} + {/* Ollama Models SubMenu */}
Local models (Ollama) - {localModelsLoading ? ( + {ollamaLoading ? ( Loading... - ) : !hasLocalModels ? ( + ) : ollamaError ? ( + + Error loading + + ) : !hasOllamaModels ? ( None available ) : ( - {localModels.length} models + {ollamaModels.length} models )}
@@ -162,27 +185,32 @@ export function ModelPicker({ Ollama Models - {localModelsLoading ? ( + {ollamaLoading && ollamaModels.length === 0 ? ( // Show loading only if no models are loaded yet
Loading models...
- ) : localModelsError ? ( -
- Error loading models -
- ) : localModels.length === 0 ? ( + ) : ollamaError ? ( +
+
+ Error loading models + + Is Ollama running? + +
+
+ ) : !hasOllamaModels ? (
- No local models available + No local models found - Start Ollama to use local models + Ensure Ollama is running and models are pulled.
) : ( - localModels.map((model) => ( + ollamaModels.map((model: LocalModel) => (
+ + {/* LM Studio Models SubMenu */} + + +
+ Local models (LM Studio) + {lmStudioLoading ? ( + + Loading... + + ) : lmStudioError ? ( + + Error loading + + ) : !hasLMStudioModels ? ( + + None available + + ) : ( + + {lmStudioModels.length} models + + )} +
+
+ + LM Studio Models + + + {lmStudioLoading && lmStudioModels.length === 0 ? ( // Show loading only if no models are loaded yet +
+ Loading models... +
+ ) : lmStudioError ? ( +
+
+ Error loading models + + {lmStudioError.message} {/* Display specific error */} + +
+
+ ) : !hasLMStudioModels ? ( +
+
+ No loaded models found + + Ensure LM Studio is running and models are loaded. + +
+
+ ) : ( + lmStudioModels.map((model: LocalModel) => ( + { + onModelSelect({ + name: model.modelName, + provider: "lmstudio", + }); + setOpen(false); + }} + > +
+ {/* Display the user-friendly name */} + {model.displayName} + {/* Show the path as secondary info */} + + {model.modelName} + +
+
+ )) + )} +
+
+
); diff --git a/src/constants/models.ts b/src/constants/models.ts index 29e33f9..6394b7c 100644 --- a/src/constants/models.ts +++ b/src/constants/models.ts @@ -8,7 +8,7 @@ export interface ModelOption { contextWindow?: number; } -type RegularModelProvider = Exclude; +type RegularModelProvider = Exclude; export const MODEL_OPTIONS: Record = { openai: [ // https://platform.openai.com/docs/models/gpt-4.1 diff --git a/src/hooks/useLMStudioModels.ts b/src/hooks/useLMStudioModels.ts new file mode 100644 index 0000000..0d812c3 --- /dev/null +++ b/src/hooks/useLMStudioModels.ts @@ -0,0 +1,43 @@ +import { useCallback } from "react"; +import { useAtom } from "jotai"; +import { + lmStudioModelsAtom, + lmStudioModelsLoadingAtom, + lmStudioModelsErrorAtom, +} from "@/atoms/localModelsAtoms"; +import { IpcClient } from "@/ipc/ipc_client"; + +export function useLocalLMSModels() { + const [models, setModels] = useAtom(lmStudioModelsAtom); + const [loading, setLoading] = useAtom(lmStudioModelsLoadingAtom); + const [error, setError] = useAtom(lmStudioModelsErrorAtom); + + const ipcClient = IpcClient.getInstance(); + + /** + * Load local models from LM Studio + */ + const loadModels = useCallback(async () => { + setLoading(true); + try { + const modelList = await ipcClient.listLocalLMStudioModels(); + setModels(modelList); + setError(null); + + return modelList; + } catch (error) { + console.error("Error loading local LM Studio models:", error); + setError(error instanceof Error ? error : new Error(String(error))); + return []; + } finally { + setLoading(false); + } + }, [ipcClient, setModels, setError, setLoading]); + + return { + models, + loading, + error, + loadModels, + }; +} diff --git a/src/hooks/useLocalModels.ts b/src/hooks/useLocalModels.ts index 75e1996..0d18323 100644 --- a/src/hooks/useLocalModels.ts +++ b/src/hooks/useLocalModels.ts @@ -20,13 +20,13 @@ export function useLocalModels() { const loadModels = useCallback(async () => { setLoading(true); try { - const modelList = await ipcClient.listLocalModels(); + const modelList = await ipcClient.listLocalOllamaModels(); setModels(modelList); setError(null); return modelList; } catch (error) { - console.error("Error loading local models:", error); + console.error("Error loading local Ollama models:", error); setError(error instanceof Error ?
error : new Error(String(error))); return []; } finally { diff --git a/src/ipc/handlers/local_model_handlers.ts b/src/ipc/handlers/local_model_handlers.ts index b0234ff..a1ccffd 100644 --- a/src/ipc/handlers/local_model_handlers.ts +++ b/src/ipc/handlers/local_model_handlers.ts @@ -1,80 +1,7 @@ -import { ipcMain } from "electron"; -import log from "electron-log"; -import { LocalModelListResponse, LocalModel } from "../ipc_types"; - -const logger = log.scope("local_model_handlers"); -const OLLAMA_API_URL = "http://localhost:11434"; - -interface OllamaModel { - name: string; - modified_at: string; - size: number; - digest: string; - details: { - format: string; - family: string; - families: string[]; - parameter_size: string; - quantization_level: string; - }; -} +import { registerOllamaHandlers } from "./local_model_ollama_handler"; +import { registerLMStudioHandlers } from "./local_model_lmstudio_handler"; export function registerLocalModelHandlers() { - // Get list of models from Ollama - ipcMain.handle( - "local-models:list", - async (): Promise => { - try { - const response = await fetch(`${OLLAMA_API_URL}/api/tags`); - - if (!response.ok) { - throw new Error(`Failed to fetch models: ${response.statusText}`); - } - - const data = await response.json(); - const ollamaModels: OllamaModel[] = data.models || []; - - // Transform the data to return just what we need - const models: LocalModel[] = ollamaModels.map((model) => { - // Extract display name by cleaning up the model name - // For names like "llama2:latest" we want to show "Llama 2" - let displayName = model.name.split(":")[0]; // Remove tags like ":latest" - - // Capitalize and add spaces for readability - displayName = displayName - .replace(/-/g, " ") - .replace(/(\d+)/, " $1 ") // Add spaces around numbers - .split(" ") - .map((word) => word.charAt(0).toUpperCase() + word.slice(1)) - .join(" ") - .trim(); - - return { - modelName: model.name, // The actual model name used for API calls - displayName, // The user-friendly name - }; - }); - - logger.info( - `Successfully fetched ${models.length} local models from Ollama` - ); - return { models, error: null }; - } catch (error) { - if ( - error instanceof TypeError && - (error as Error).message.includes("fetch failed") - ) { - logger.error("Could not connect to Ollama. Is it running?"); - return { - models: [], - error: - "Could not connect to Ollama. 
Make sure it's running at http://localhost:11434", - }; - } - - logger.error("Error fetching local models:", error); - return { models: [], error: "Failed to fetch models from Ollama" }; - } - } - ); + registerOllamaHandlers(); + registerLMStudioHandlers(); } diff --git a/src/ipc/handlers/local_model_lmstudio_handler.ts b/src/ipc/handlers/local_model_lmstudio_handler.ts new file mode 100644 index 0000000..dbca1d2 --- /dev/null +++ b/src/ipc/handlers/local_model_lmstudio_handler.ts @@ -0,0 +1,47 @@ +import { ipcMain } from "electron"; +import log from "electron-log"; +import type { LocalModelListResponse, LocalModel } from "../ipc_types"; + +const logger = log.scope("lmstudio_handler"); + +export interface LMStudioModel { + type: "llm" | "embedding" | string; + id: string; + object: string; + publisher: string; + state: "loaded" | "not-loaded"; + max_context_length: number; + quantization: string; + compatibility_type: string; + arch: string; + [key: string]: any; +} + +export async function fetchLMStudioModels(): Promise { + try { + const modelsResponse: Response = await fetch("http://localhost:1234/api/v0/models"); + if (!modelsResponse.ok) { + throw new Error("Failed to fetch models from LM Studio"); + } + const modelsJson = await modelsResponse.json(); + const downloadedModels = modelsJson.data as LMStudioModel[]; + const models: LocalModel[] = downloadedModels + .filter((model: LMStudioModel) => model.type === "llm") + .map((model: LMStudioModel) => ({ + modelName: model.id, + displayName: model.id, + provider: "lmstudio" + })); + + logger.info(`Successfully fetched ${models.length} models from LM Studio`); + return { models, error: null }; + } catch (error) { + return { models: [], error: "Failed to fetch models from LM Studio" }; + } +} + +export function registerLMStudioHandlers() { + ipcMain.handle('local-models:list-lmstudio', async (): Promise => { + return fetchLMStudioModels(); + }); +} \ No newline at end of file diff --git a/src/ipc/handlers/local_model_ollama_handler.ts b/src/ipc/handlers/local_model_ollama_handler.ts new file mode 100644 index 0000000..6648585 --- /dev/null +++ b/src/ipc/handlers/local_model_ollama_handler.ts @@ -0,0 +1,66 @@ +import { ipcMain } from "electron"; +import log from "electron-log"; +import { LocalModelListResponse, LocalModel } from "../ipc_types"; + +const logger = log.scope("ollama_handler"); + +const OLLAMA_API_URL = "http://localhost:11434"; + +interface OllamaModel { + name: string; + modified_at: string; + size: number; + digest: string; + details: { + format: string; + family: string; + families: string[]; + parameter_size: string; + quantization_level: string; + }; +} + +export async function fetchOllamaModels(): Promise { + try { + const response = await fetch(`${OLLAMA_API_URL}/api/tags`); + if (!response.ok) { + throw new Error(`Failed to fetch models: ${response.statusText}`); + } + + const data = await response.json(); + const ollamaModels: OllamaModel[] = data.models || []; + + const models: LocalModel[] = ollamaModels.map((model: OllamaModel) => { + const displayName = model.name.split(':')[0] + .replace(/-/g, ' ') + .replace(/(\d+)/, ' $1 ') + .split(' ') + .map(word => word.charAt(0).toUpperCase() + word.slice(1)) + .join(' ') + .trim(); + + return { + modelName: model.name, + displayName, + provider: "ollama", + }; + }); + logger.info(`Successfully fetched ${models.length} models from Ollama`); + return { models, error: null }; + } catch (error) { + if (error instanceof TypeError && (error as Error).message.includes('fetch failed')) { +
logger.error("Could not connect to Ollama"); + return { + models: [], + error: "Could not connect to Ollama. Make sure it's running at http://localhost:11434" + }; + } + return { models: [], error: "Failed to fetch models from Ollama" }; + } +} + +export function registerOllamaHandlers() { + ipcMain.handle('local-models:list-ollama', async (): Promise => { + return fetchOllamaModels(); + }); +} \ No newline at end of file diff --git a/src/ipc/ipc_client.ts b/src/ipc/ipc_client.ts index a37f159..ab55a47 100644 --- a/src/ipc/ipc_client.ts +++ b/src/ipc/ipc_client.ts @@ -785,14 +785,28 @@ export class IpcClient { } } - public async listLocalModels(): Promise { - const { models, error } = (await this.ipcRenderer.invoke( - "local-models:list" - )) as LocalModelListResponse; - if (error) { - throw new Error(error); + public async listLocalOllamaModels(): Promise { + try { + const response = await this.ipcRenderer.invoke("local-models:list-ollama"); + return response?.models || []; + } catch (error) { + if (error instanceof Error) { + throw new Error(`Failed to fetch Ollama models: ${error.message}`); + } + throw new Error('Failed to fetch Ollama models: Unknown error occurred'); + } + } + + public async listLocalLMStudioModels(): Promise { + try { + const response = await this.ipcRenderer.invoke("local-models:list-lmstudio"); + return response?.models || []; + } catch (error) { + if (error instanceof Error) { + throw new Error(`Failed to fetch LM Studio models: ${error.message}`); + } + throw new Error('Failed to fetch LM Studio models: Unknown error occurred'); } - return models; } // Listen for deep link events diff --git a/src/ipc/ipc_types.ts b/src/ipc/ipc_types.ts index 09ae89e..7fe7c90 100644 --- a/src/ipc/ipc_types.ts +++ b/src/ipc/ipc_types.ts @@ -94,6 +94,7 @@ export interface SystemDebugInfo { } export interface LocalModel { + provider: "ollama" | "lmstudio"; modelName: string; // Name used for API calls (e.g., "llama2:latest") displayName: string; // User-friendly name (e.g., "Llama 2") } diff --git a/src/ipc/utils/get_model_client.ts b/src/ipc/utils/get_model_client.ts index a62edf8..d4030f6 100644 --- a/src/ipc/utils/get_model_client.ts +++ b/src/ipc/utils/get_model_client.ts @@ -3,7 +3,7 @@ import { createGoogleGenerativeAI as createGoogle } from "@ai-sdk/google"; import { createAnthropic } from "@ai-sdk/anthropic"; import { createOpenRouter } from "@openrouter/ai-sdk-provider"; import { createOllama } from "ollama-ai-provider"; - +import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import type { LargeLanguageModel, UserSettings } from "../../lib/schemas"; import { PROVIDER_TO_ENV_VAR, @@ -82,8 +82,14 @@ export function getModelClient( case "ollama": { const provider = createOllama(); return provider(model.name); - } - default: { + } + case "lmstudio": { + // Using LM Studio's OpenAI compatible API + const baseURL = "http://localhost:1234/v1"; // Default LM Studio OpenAI API URL + const provider = createOpenAICompatible({ name: "lmstudio", baseURL }); + return provider(model.name); + } + default: { // Ensure exhaustive check if more providers are added const _exhaustiveCheck: never = model.provider; throw new Error(`Unsupported model provider: ${model.provider}`); diff --git a/src/lib/schemas.ts b/src/lib/schemas.ts index d01a389..0c12b4c 100644 --- a/src/lib/schemas.ts +++ b/src/lib/schemas.ts @@ -36,6 +36,7 @@ export const ModelProviderSchema = z.enum([ "auto", "openrouter", "ollama", + "lmstudio", ]); /** diff --git a/src/preload.ts b/src/preload.ts index 
9471fdd..14a73fa 100644 --- a/src/preload.ts +++ b/src/preload.ts @@ -49,7 +49,8 @@ const validInvokeChannels = [ "supabase:list-projects", "supabase:set-app-project", "supabase:unset-app-project", - "local-models:list", + "local-models:list-ollama", + "local-models:list-lmstudio", "window:minimize", "window:maximize", "window:close",