Simplify provider logic and migrate getContextWindow (#142)

This commit is contained in:
Will Chen
2025-05-12 22:18:49 -07:00
committed by GitHub
parent 11ba46db38
commit 877c8f7f4f
7 changed files with 48 additions and 67 deletions

View File

@@ -139,26 +139,16 @@ export async function getLanguageModelProviders(): Promise<
* @param obj An object containing the providerId.
* @returns A promise that resolves to an array of LanguageModel objects.
*/
export async function getLanguageModels(
obj:
| {
customProviderId: string;
// builtinProviderId?: undefined;
}
| {
builtinProviderId: string;
// customProviderId?: undefined;
},
): Promise<LanguageModel[]> {
export async function getLanguageModels({
providerId,
}: {
providerId: string;
}): Promise<LanguageModel[]> {
const allProviders = await getLanguageModelProviders();
const provider = allProviders.find(
(p) =>
p.id === (obj as { customProviderId: string }).customProviderId ||
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
);
const provider = allProviders.find((p) => p.id === providerId);
if (!provider) {
console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
console.warn(`Provider with ID "${providerId}" not found.`);
return [];
}
@@ -177,9 +167,9 @@ export async function getLanguageModels(
})
.from(languageModelsSchema)
.where(
"customProviderId" in obj
? eq(languageModelsSchema.customProviderId, obj.customProviderId)
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
isCustomProvider({ providerId })
? eq(languageModelsSchema.customProviderId, providerId)
: eq(languageModelsSchema.builtinProviderId, providerId),
);
customModels = customModelsDb.map((model) => ({
@@ -192,7 +182,7 @@ export async function getLanguageModels(
}));
} catch (error) {
console.error(
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
`Error fetching custom models for provider "${providerId}" from DB:`,
error,
);
// Continue with empty custom models array
@@ -200,7 +190,6 @@ export async function getLanguageModels(
// If it's a cloud provider, also get the hardcoded models
let hardcodedModels: LanguageModel[] = [];
const providerId = provider.id;
if (provider.type === "cloud") {
if (providerId in MODEL_OPTIONS) {
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
@@ -245,11 +234,7 @@ export async function getLanguageModelsByProviders(): Promise<
const modelPromises = providers
.filter((p) => p.type !== "local")
.map(async (provider) => {
const models = await getLanguageModels(
provider.type === "cloud"
? { builtinProviderId: provider.id }
: { customProviderId: provider.id },
);
const models = await getLanguageModels({ providerId: provider.id });
return { providerId: provider.id, models };
});
@@ -264,3 +249,9 @@ export async function getLanguageModelsByProviders(): Promise<
return record;
}
export function isCustomProvider({ providerId }: { providerId: string }) {
return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
}
export const CUSTOM_PROVIDER_PREFIX = "custom::";