From d1027622b4ca7f704a93719530571b387f72cf6a Mon Sep 17 00:00:00 2001 From: Will Chen Date: Mon, 12 May 2025 22:29:36 -0700 Subject: [PATCH] Loosen provider type to a string (#144) --- src/components/ModelPicker.tsx | 4 ++-- src/ipc/shared/language_model_helpers.ts | 12 +++--------- src/lib/schemas.ts | 11 +---------- 3 files changed, 6 insertions(+), 21 deletions(-) diff --git a/src/components/ModelPicker.tsx b/src/components/ModelPicker.tsx index f3a8f73..1054483 100644 --- a/src/components/ModelPicker.tsx +++ b/src/components/ModelPicker.tsx @@ -1,4 +1,4 @@ -import type { LargeLanguageModel, ModelProvider } from "@/lib/schemas"; +import type { LargeLanguageModel } from "@/lib/schemas"; import { Button } from "@/components/ui/button"; import { Tooltip, @@ -216,7 +216,7 @@ export function ModelPicker({ onClick={() => { onModelSelect({ name: model.apiName, - provider: providerId as ModelProvider, + provider: providerId, }); setOpen(false); }} diff --git a/src/ipc/shared/language_model_helpers.ts b/src/ipc/shared/language_model_helpers.ts index dfbaf73..f2ae288 100644 --- a/src/ipc/shared/language_model_helpers.ts +++ b/src/ipc/shared/language_model_helpers.ts @@ -5,7 +5,6 @@ import { } from "@/db/schema"; import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types"; import { eq } from "drizzle-orm"; -import { ModelProvider } from "@/lib/schemas"; export interface ModelOption { name: string; @@ -16,12 +15,7 @@ export interface ModelOption { contextWindow?: number; } -export type RegularModelProvider = Exclude< - ModelProvider, - "ollama" | "lmstudio" ->; - -export const MODEL_OPTIONS: Record = { +export const MODEL_OPTIONS: Record = { openai: [ // https://platform.openai.com/docs/models/gpt-4.1 { @@ -109,7 +103,7 @@ export const PROVIDER_TO_ENV_VAR: Record = { }; export const PROVIDERS: Record< - RegularModelProvider, + string, { displayName: string; hasFreeTier?: boolean; @@ -286,7 +280,7 @@ export async function getLanguageModels({ let hardcodedModels: LanguageModel[] = []; if (provider.type === "cloud") { if (providerId in MODEL_OPTIONS) { - const models = MODEL_OPTIONS[providerId as RegularModelProvider] || []; + const models = MODEL_OPTIONS[providerId] || []; hardcodedModels = models.map((model) => ({ ...model, apiName: model.name, diff --git a/src/lib/schemas.ts b/src/lib/schemas.ts index 422e3ce..84140f8 100644 --- a/src/lib/schemas.ts +++ b/src/lib/schemas.ts @@ -35,26 +35,17 @@ const providers = [ "ollama", "lmstudio", ] as const; -/** - * Zod schema for model provider - */ -export const ModelProviderSchema = z.enum(providers); export const cloudProviders = providers.filter( (provider) => provider !== "ollama" && provider !== "lmstudio", ); -/** - * Type derived from the ModelProviderSchema - */ -export type ModelProvider = z.infer; - /** * Zod schema for large language model configuration */ export const LargeLanguageModelSchema = z.object({ name: z.string(), - provider: ModelProviderSchema, + provider: z.string(), }); /**