Fix DB schemas (#138)
This commit is contained in:
@@ -139,23 +139,72 @@ export async function getLanguageModelProviders(): Promise<
|
||||
* @param obj An object containing the providerId.
|
||||
* @returns A promise that resolves to an array of LanguageModel objects.
|
||||
*/
|
||||
export async function getLanguageModels(obj: {
|
||||
providerId: string;
|
||||
}): Promise<LanguageModel[]> {
|
||||
const { providerId } = obj;
|
||||
export async function getLanguageModels(
|
||||
obj:
|
||||
| {
|
||||
customProviderId: string;
|
||||
// builtinProviderId?: undefined;
|
||||
}
|
||||
| {
|
||||
builtinProviderId: string;
|
||||
// customProviderId?: undefined;
|
||||
},
|
||||
): Promise<LanguageModel[]> {
|
||||
const allProviders = await getLanguageModelProviders();
|
||||
const provider = allProviders.find((p) => p.id === providerId);
|
||||
const provider = allProviders.find(
|
||||
(p) =>
|
||||
p.id === (obj as { customProviderId: string }).customProviderId ||
|
||||
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
|
||||
);
|
||||
|
||||
if (!provider) {
|
||||
console.warn(`Provider with ID "${providerId}" not found.`);
|
||||
console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
|
||||
return [];
|
||||
}
|
||||
|
||||
// Get custom models from DB for all provider types
|
||||
let customModels: LanguageModel[] = [];
|
||||
|
||||
try {
|
||||
const customModelsDb = await db
|
||||
.select({
|
||||
id: languageModelsSchema.id,
|
||||
displayName: languageModelsSchema.displayName,
|
||||
apiName: languageModelsSchema.apiName,
|
||||
description: languageModelsSchema.description,
|
||||
maxOutputTokens: languageModelsSchema.max_output_tokens,
|
||||
contextWindow: languageModelsSchema.context_window,
|
||||
})
|
||||
.from(languageModelsSchema)
|
||||
.where(
|
||||
"customProviderId" in obj
|
||||
? eq(languageModelsSchema.customProviderId, obj.customProviderId)
|
||||
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
|
||||
);
|
||||
|
||||
customModels = customModelsDb.map((model) => ({
|
||||
...model,
|
||||
description: model.description ?? "",
|
||||
tag: undefined,
|
||||
maxOutputTokens: model.maxOutputTokens ?? undefined,
|
||||
contextWindow: model.contextWindow ?? undefined,
|
||||
type: "custom",
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
|
||||
error,
|
||||
);
|
||||
// Continue with empty custom models array
|
||||
}
|
||||
|
||||
// If it's a cloud provider, also get the hardcoded models
|
||||
let hardcodedModels: LanguageModel[] = [];
|
||||
const providerId = provider.id;
|
||||
if (provider.type === "cloud") {
|
||||
// Check if providerId is a valid key for MODEL_OPTIONS
|
||||
if (providerId in MODEL_OPTIONS) {
|
||||
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
|
||||
return models.map((model) => ({
|
||||
hardcodedModels = models.map((model) => ({
|
||||
...model,
|
||||
apiName: model.name,
|
||||
type: "cloud",
|
||||
@@ -164,49 +213,54 @@ export async function getLanguageModels(obj: {
|
||||
console.warn(
|
||||
`Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
|
||||
);
|
||||
return [];
|
||||
}
|
||||
} else if (provider.type === "custom") {
|
||||
// Fetch models from the database for this custom provider
|
||||
// Assuming a language_models table with necessary columns and provider_id foreign key
|
||||
try {
|
||||
const customModelsDb = await db
|
||||
.select({
|
||||
id: languageModelsSchema.id,
|
||||
// Map DB columns to LanguageModel fields
|
||||
displayName: languageModelsSchema.displayName,
|
||||
apiName: languageModelsSchema.apiName,
|
||||
// No display_name in DB, use name instead
|
||||
description: languageModelsSchema.description,
|
||||
// No tag in DB
|
||||
maxOutputTokens: languageModelsSchema.max_output_tokens,
|
||||
contextWindow: languageModelsSchema.context_window,
|
||||
})
|
||||
.from(languageModelsSchema)
|
||||
.where(eq(languageModelsSchema.provider_id, providerId)); // Assuming eq is imported or available
|
||||
|
||||
return customModelsDb.map((model) => ({
|
||||
...model,
|
||||
// Ensure possibly null fields are handled, provide defaults or undefined if needed
|
||||
description: model.description ?? "",
|
||||
tag: undefined, // No tag for custom models from DB
|
||||
maxOutputTokens: model.maxOutputTokens ?? undefined,
|
||||
contextWindow: model.contextWindow ?? undefined,
|
||||
type: "custom",
|
||||
}));
|
||||
} catch (error) {
|
||||
console.error(
|
||||
`Error fetching custom models for provider "${providerId}" from DB:`,
|
||||
error,
|
||||
);
|
||||
// Depending on desired behavior, could throw, return empty, or return a specific error state
|
||||
return [];
|
||||
}
|
||||
} else {
|
||||
// Handle other types like "local" if necessary, currently ignored
|
||||
console.warn(
|
||||
`Provider type "${provider.type}" not handled for model fetching.`,
|
||||
);
|
||||
return [];
|
||||
}
|
||||
|
||||
// Merge the models, with custom models taking precedence over hardcoded ones
|
||||
const mergedModelsMap = new Map<string, LanguageModel>();
|
||||
|
||||
// Add hardcoded models first
|
||||
for (const model of hardcodedModels) {
|
||||
mergedModelsMap.set(model.apiName, model);
|
||||
}
|
||||
|
||||
// Then override with custom models
|
||||
for (const model of customModels) {
|
||||
mergedModelsMap.set(model.apiName, model);
|
||||
}
|
||||
|
||||
return Array.from(mergedModelsMap.values());
|
||||
}
|
||||
|
||||
/**
|
||||
* Fetches all language models grouped by their provider IDs.
|
||||
* @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
|
||||
*/
|
||||
export async function getLanguageModelsByProviders(): Promise<
|
||||
Record<string, LanguageModel[]>
|
||||
> {
|
||||
const providers = await getLanguageModelProviders();
|
||||
|
||||
// Fetch all models concurrently
|
||||
const modelPromises = providers
|
||||
.filter((p) => p.type !== "local")
|
||||
.map(async (provider) => {
|
||||
const models = await getLanguageModels(
|
||||
provider.type === "cloud"
|
||||
? { builtinProviderId: provider.id }
|
||||
: { customProviderId: provider.id },
|
||||
);
|
||||
return { providerId: provider.id, models };
|
||||
});
|
||||
|
||||
// Wait for all requests to complete
|
||||
const results = await Promise.all(modelPromises);
|
||||
|
||||
// Convert the array of results to a record
|
||||
const record: Record<string, LanguageModel[]> = {};
|
||||
for (const result of results) {
|
||||
record[result.providerId] = result.models;
|
||||
}
|
||||
|
||||
return record;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user