Precise custom model selection & simplify language model/provider log… (#147)

…ic (no merging)
This commit is contained in:
Will Chen
2025-05-12 23:24:39 -07:00
committed by GitHub
parent b45dff3862
commit ee5865dcf8
4 changed files with 29 additions and 52 deletions

View File

@@ -84,6 +84,13 @@ export function ModelPicker() {
// For cloud models, look up in the modelsByProviders data
if (modelsByProviders && modelsByProviders[selectedModel.provider]) {
const customFoundModel = modelsByProviders[selectedModel.provider].find(
(model) =>
model.type === "custom" && model.id === selectedModel.customModelId,
);
if (customFoundModel) {
return customFoundModel.displayName;
}
const foundModel = modelsByProviders[selectedModel.provider].find(
(model) => model.apiName === selectedModel.name,
);
@@ -227,9 +234,12 @@ export function ModelPicker() {
: ""
}
onClick={() => {
const customModelId =
model.type === "custom" ? model.id : undefined;
onModelSelect({
name: model.apiName,
provider: providerId,
customModelId,
});
setOpen(false);
}}

View File

@@ -191,35 +191,7 @@ export async function getLanguageModelProviders(): Promise<
}
}
// Merge lists: custom providers take precedence
const mergedProvidersMap = new Map<string, LanguageModelProvider>();
// Add all hardcoded providers first
for (const hp of hardcodedProviders) {
mergedProvidersMap.set(hp.id, hp);
}
// Add/overwrite with custom providers from DB
for (const [id, cp] of customProvidersMap) {
const existingProvider = mergedProvidersMap.get(id);
if (existingProvider) {
// If exists, merge. Custom fields take precedence.
mergedProvidersMap.set(id, {
...existingProvider, // start with hardcoded
...cp, // override with custom where defined
id: cp.id, // ensure custom id is used
name: cp.name, // ensure custom name is used
type: "custom", // explicitly set type to custom
apiBaseUrl: cp.apiBaseUrl ?? existingProvider.apiBaseUrl,
envVarName: cp.envVarName ?? existingProvider.envVarName,
});
} else {
// If it doesn't exist in hardcoded, just add the custom one
mergedProvidersMap.set(id, cp);
}
}
return Array.from(mergedProvidersMap.values());
return [...hardcodedProviders, ...customProvidersMap.values()];
}
/**
@@ -293,20 +265,7 @@ export async function getLanguageModels({
}
}
// Merge the models, with custom models taking precedence over hardcoded ones
const mergedModelsMap = new Map<string, LanguageModel>();
// Add hardcoded models first
for (const model of hardcodedModels) {
mergedModelsMap.set(model.apiName, model);
}
// Then override with custom models
for (const model of customModels) {
mergedModelsMap.set(model.apiName, model);
}
return Array.from(mergedModelsMap.values());
return [...hardcodedModels, ...customModels];
}
/**

View File

@@ -20,13 +20,7 @@ const DEFAULT_CONTEXT_WINDOW = 128_000;
export async function getContextWindow() {
const settings = readSettings();
const model = settings.selectedModel;
const models = await getLanguageModels({
providerId: model.provider,
});
const modelOption = models.find((m) => m.apiName === model.name);
const modelOption = await findLanguageModel(settings.selectedModel);
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
}
@@ -34,10 +28,23 @@ export async function getContextWindow() {
const DEFAULT_MAX_TOKENS = 8_000;
export async function getMaxTokens(model: LargeLanguageModel) {
const modelOption = await findLanguageModel(model);
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
}
async function findLanguageModel(model: LargeLanguageModel) {
const models = await getLanguageModels({
providerId: model.provider,
});
const modelOption = models.find((m) => m.apiName === model.name);
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
if (model.customModelId) {
const customModel = models.find(
(m) => m.type === "custom" && m.id === model.customModelId,
);
if (customModel) {
return customModel;
}
}
return models.find((m) => m.apiName === model.name);
}

View File

@@ -46,6 +46,7 @@ export const cloudProviders = providers.filter(
export const LargeLanguageModelSchema = z.object({
name: z.string(),
provider: z.string(),
customModelId: z.number().optional(),
});
/**