Precise custom model selection & simplify language model/provider log… (#147)

…ic (no merging)
This commit is contained in:
Will Chen
2025-05-12 23:24:39 -07:00
committed by GitHub
parent b45dff3862
commit ee5865dcf8
4 changed files with 29 additions and 52 deletions

View File

@@ -84,6 +84,13 @@ export function ModelPicker() {
// For cloud models, look up in the modelsByProviders data // For cloud models, look up in the modelsByProviders data
if (modelsByProviders && modelsByProviders[selectedModel.provider]) { if (modelsByProviders && modelsByProviders[selectedModel.provider]) {
const customFoundModel = modelsByProviders[selectedModel.provider].find(
(model) =>
model.type === "custom" && model.id === selectedModel.customModelId,
);
if (customFoundModel) {
return customFoundModel.displayName;
}
const foundModel = modelsByProviders[selectedModel.provider].find( const foundModel = modelsByProviders[selectedModel.provider].find(
(model) => model.apiName === selectedModel.name, (model) => model.apiName === selectedModel.name,
); );
@@ -227,9 +234,12 @@ export function ModelPicker() {
: "" : ""
} }
onClick={() => { onClick={() => {
const customModelId =
model.type === "custom" ? model.id : undefined;
onModelSelect({ onModelSelect({
name: model.apiName, name: model.apiName,
provider: providerId, provider: providerId,
customModelId,
}); });
setOpen(false); setOpen(false);
}} }}

View File

@@ -191,35 +191,7 @@ export async function getLanguageModelProviders(): Promise<
} }
} }
// Merge lists: custom providers take precedence return [...hardcodedProviders, ...customProvidersMap.values()];
const mergedProvidersMap = new Map<string, LanguageModelProvider>();
// Add all hardcoded providers first
for (const hp of hardcodedProviders) {
mergedProvidersMap.set(hp.id, hp);
}
// Add/overwrite with custom providers from DB
for (const [id, cp] of customProvidersMap) {
const existingProvider = mergedProvidersMap.get(id);
if (existingProvider) {
// If exists, merge. Custom fields take precedence.
mergedProvidersMap.set(id, {
...existingProvider, // start with hardcoded
...cp, // override with custom where defined
id: cp.id, // ensure custom id is used
name: cp.name, // ensure custom name is used
type: "custom", // explicitly set type to custom
apiBaseUrl: cp.apiBaseUrl ?? existingProvider.apiBaseUrl,
envVarName: cp.envVarName ?? existingProvider.envVarName,
});
} else {
// If it doesn't exist in hardcoded, just add the custom one
mergedProvidersMap.set(id, cp);
}
}
return Array.from(mergedProvidersMap.values());
} }
/** /**
@@ -293,20 +265,7 @@ export async function getLanguageModels({
} }
} }
// Merge the models, with custom models taking precedence over hardcoded ones return [...hardcodedModels, ...customModels];
const mergedModelsMap = new Map<string, LanguageModel>();
// Add hardcoded models first
for (const model of hardcodedModels) {
mergedModelsMap.set(model.apiName, model);
}
// Then override with custom models
for (const model of customModels) {
mergedModelsMap.set(model.apiName, model);
}
return Array.from(mergedModelsMap.values());
} }
/** /**

View File

@@ -20,13 +20,7 @@ const DEFAULT_CONTEXT_WINDOW = 128_000;
export async function getContextWindow() { export async function getContextWindow() {
const settings = readSettings(); const settings = readSettings();
const model = settings.selectedModel; const modelOption = await findLanguageModel(settings.selectedModel);
const models = await getLanguageModels({
providerId: model.provider,
});
const modelOption = models.find((m) => m.apiName === model.name);
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW; return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
} }
@@ -34,10 +28,23 @@ export async function getContextWindow() {
const DEFAULT_MAX_TOKENS = 8_000; const DEFAULT_MAX_TOKENS = 8_000;
export async function getMaxTokens(model: LargeLanguageModel) { export async function getMaxTokens(model: LargeLanguageModel) {
const modelOption = await findLanguageModel(model);
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
}
async function findLanguageModel(model: LargeLanguageModel) {
const models = await getLanguageModels({ const models = await getLanguageModels({
providerId: model.provider, providerId: model.provider,
}); });
const modelOption = models.find((m) => m.apiName === model.name); if (model.customModelId) {
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS; const customModel = models.find(
(m) => m.type === "custom" && m.id === model.customModelId,
);
if (customModel) {
return customModel;
}
}
return models.find((m) => m.apiName === model.name);
} }

View File

@@ -46,6 +46,7 @@ export const cloudProviders = providers.filter(
export const LargeLanguageModelSchema = z.object({ export const LargeLanguageModelSchema = z.object({
name: z.string(), name: z.string(),
provider: z.string(), provider: z.string(),
customModelId: z.number().optional(),
}); });
/** /**