diff --git a/src/components/ModelPicker.tsx b/src/components/ModelPicker.tsx index 5ae28f6..4e033e0 100644 --- a/src/components/ModelPicker.tsx +++ b/src/components/ModelPicker.tsx @@ -84,6 +84,13 @@ export function ModelPicker() { // For cloud models, look up in the modelsByProviders data if (modelsByProviders && modelsByProviders[selectedModel.provider]) { + const customFoundModel = modelsByProviders[selectedModel.provider].find( + (model) => + model.type === "custom" && model.id === selectedModel.customModelId, + ); + if (customFoundModel) { + return customFoundModel.displayName; + } const foundModel = modelsByProviders[selectedModel.provider].find( (model) => model.apiName === selectedModel.name, ); @@ -227,9 +234,12 @@ export function ModelPicker() { : "" } onClick={() => { + const customModelId = + model.type === "custom" ? model.id : undefined; onModelSelect({ name: model.apiName, provider: providerId, + customModelId, }); setOpen(false); }} diff --git a/src/ipc/shared/language_model_helpers.ts b/src/ipc/shared/language_model_helpers.ts index f2ae288..dd12988 100644 --- a/src/ipc/shared/language_model_helpers.ts +++ b/src/ipc/shared/language_model_helpers.ts @@ -191,35 +191,7 @@ export async function getLanguageModelProviders(): Promise< } } - // Merge lists: custom providers take precedence - const mergedProvidersMap = new Map(); - - // Add all hardcoded providers first - for (const hp of hardcodedProviders) { - mergedProvidersMap.set(hp.id, hp); - } - - // Add/overwrite with custom providers from DB - for (const [id, cp] of customProvidersMap) { - const existingProvider = mergedProvidersMap.get(id); - if (existingProvider) { - // If exists, merge. Custom fields take precedence. - mergedProvidersMap.set(id, { - ...existingProvider, // start with hardcoded - ...cp, // override with custom where defined - id: cp.id, // ensure custom id is used - name: cp.name, // ensure custom name is used - type: "custom", // explicitly set type to custom - apiBaseUrl: cp.apiBaseUrl ?? existingProvider.apiBaseUrl, - envVarName: cp.envVarName ?? existingProvider.envVarName, - }); - } else { - // If it doesn't exist in hardcoded, just add the custom one - mergedProvidersMap.set(id, cp); - } - } - - return Array.from(mergedProvidersMap.values()); + return [...hardcodedProviders, ...customProvidersMap.values()]; } /** @@ -293,20 +265,7 @@ export async function getLanguageModels({ } } - // Merge the models, with custom models taking precedence over hardcoded ones - const mergedModelsMap = new Map(); - - // Add hardcoded models first - for (const model of hardcodedModels) { - mergedModelsMap.set(model.apiName, model); - } - - // Then override with custom models - for (const model of customModels) { - mergedModelsMap.set(model.apiName, model); - } - - return Array.from(mergedModelsMap.values()); + return [...hardcodedModels, ...customModels]; } /** diff --git a/src/ipc/utils/token_utils.ts b/src/ipc/utils/token_utils.ts index 9b34470..0118ffd 100644 --- a/src/ipc/utils/token_utils.ts +++ b/src/ipc/utils/token_utils.ts @@ -20,13 +20,7 @@ const DEFAULT_CONTEXT_WINDOW = 128_000; export async function getContextWindow() { const settings = readSettings(); - const model = settings.selectedModel; - - const models = await getLanguageModels({ - providerId: model.provider, - }); - - const modelOption = models.find((m) => m.apiName === model.name); + const modelOption = await findLanguageModel(settings.selectedModel); return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW; } @@ -34,10 +28,23 @@ export async function getContextWindow() { const DEFAULT_MAX_TOKENS = 8_000; export async function getMaxTokens(model: LargeLanguageModel) { + const modelOption = await findLanguageModel(model); + return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS; +} + +async function findLanguageModel(model: LargeLanguageModel) { const models = await getLanguageModels({ providerId: model.provider, }); - const modelOption = models.find((m) => m.apiName === model.name); - return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS; + if (model.customModelId) { + const customModel = models.find( + (m) => m.type === "custom" && m.id === model.customModelId, + ); + if (customModel) { + return customModel; + } + } + + return models.find((m) => m.apiName === model.name); } diff --git a/src/lib/schemas.ts b/src/lib/schemas.ts index 84140f8..485253f 100644 --- a/src/lib/schemas.ts +++ b/src/lib/schemas.ts @@ -46,6 +46,7 @@ export const cloudProviders = providers.filter( export const LargeLanguageModelSchema = z.object({ name: z.string(), provider: z.string(), + customModelId: z.number().optional(), }); /**