Files
moreminimore-vibe/src/ipc/shared/language_model_helpers.ts
2025-05-12 22:20:16 -07:00

352 lines
11 KiB
TypeScript

import { db } from "@/db";
import {
language_model_providers as languageModelProvidersSchema,
language_models as languageModelsSchema,
} from "@/db/schema";
import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types";
import { eq } from "drizzle-orm";
import { ModelProvider } from "@/lib/schemas";
/**
 * Describes one hardcoded model entry offered by a provider.
 */
export interface ModelOption {
// Model identifier; hardcoded models reuse it as their `apiName`.
name: string;
// Human-readable label for the model.
displayName: string;
// Short blurb about the model.
description: string;
// Optional badge text; the entries in this file use "Recommended" and "Default".
tag?: string;
// Optional output-token limit for the model.
maxOutputTokens?: number;
// Optional context-window size for the model.
contextWindow?: number;
}
/**
 * Every `ModelProvider` except "ollama" and "lmstudio".
 * Used as the key type of `MODEL_OPTIONS` and `PROVIDERS`.
 */
export type RegularModelProvider = Exclude<
ModelProvider,
"ollama" | "lmstudio"
>;
/**
 * Hardcoded model catalog, keyed by provider id.
 * `getLanguageModels` exposes these entries as "cloud" models, using each
 * entry's `name` as the model's `apiName`.
 */
export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
openai: [
// https://platform.openai.com/docs/models/gpt-4.1
{
name: "gpt-4.1",
displayName: "GPT 4.1",
description: "OpenAI's flagship model",
maxOutputTokens: 32_768,
contextWindow: 1_047_576,
},
// https://platform.openai.com/docs/models/gpt-4.1-mini
{
name: "gpt-4.1-mini",
displayName: "GPT 4.1 Mini",
description: "OpenAI's lightweight, but intelligent model",
maxOutputTokens: 32_768,
contextWindow: 1_047_576,
},
// https://platform.openai.com/docs/models/o3-mini
{
name: "o3-mini",
displayName: "o3 mini",
description: "Reasoning model",
maxOutputTokens: 100_000,
contextWindow: 200_000,
},
],
// https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison-table
anthropic: [
{
name: "claude-3-7-sonnet-latest",
displayName: "Claude 3.7 Sonnet",
description: "Excellent coder",
maxOutputTokens: 64_000,
contextWindow: 200_000,
},
],
google: [
// https://ai.google.dev/gemini-api/docs/models#gemini-2.5-pro-preview-03-25
{
name: "gemini-2.5-pro-exp-03-25",
displayName: "Gemini 2.5 Pro",
description: "Experimental version of Google's Gemini 2.5 Pro model",
tag: "Recommended",
// See Flash 2.5 comment below (go 1 below just to be safe, even though it seems OK now).
maxOutputTokens: 65_536 - 1,
// Gemini context window = input token + output token
contextWindow: 1_048_576,
},
// https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview
{
name: "gemini-2.5-flash-preview-04-17",
displayName: "Gemini 2.5 Flash",
description: "Preview version of Google's Gemini 2.5 Flash model",
// Weirdly for Vertex AI, the output token limit is *exclusive* of the stated limit.
maxOutputTokens: 65_536 - 1,
// Gemini context window = input token + output token
contextWindow: 1_048_576,
},
],
openrouter: [
// https://openrouter.ai/deepseek/deepseek-chat-v3-0324:free
{
name: "deepseek/deepseek-chat-v3-0324:free",
displayName: "DeepSeek v3 (free)",
description: "Use for free (data may be used for training)",
maxOutputTokens: 32_000,
contextWindow: 128_000,
},
],
// The "auto" entry declares no token limits or context window.
auto: [
{
name: "auto",
displayName: "Auto",
description: "Automatically selects the best model",
tag: "Default",
},
],
};
/**
 * Maps a provider id to the name of its environment variable
 * (e.g. "OPENAI_API_KEY"); the value is exposed as `envVarName` on hardcoded
 * providers. Providers without an entry here, such as "auto", get no
 * `envVarName`.
 */
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
openai: "OPENAI_API_KEY",
anthropic: "ANTHROPIC_API_KEY",
google: "GEMINI_API_KEY",
openrouter: "OPENROUTER_API_KEY",
};
/**
 * Static metadata for each hardcoded provider, keyed by provider id.
 * `getLanguageModelProviders` turns every entry into a provider of type "cloud",
 * in the order the keys are declared here.
 */
export const PROVIDERS: Record<
RegularModelProvider,
{
// Exposed as the provider's `name`.
displayName: string;
// Whether the provider offers a free tier; not set for "auto".
hasFreeTier?: boolean;
// Link associated with the provider (API-key / settings pages in the entries below).
websiteUrl?: string;
// NOTE(review): presumably prepended to model names when routing through a gateway — confirm against the caller.
gatewayPrefix: string;
}
> = {
openai: {
displayName: "OpenAI",
hasFreeTier: false,
websiteUrl: "https://platform.openai.com/api-keys",
gatewayPrefix: "",
},
anthropic: {
displayName: "Anthropic",
hasFreeTier: false,
websiteUrl: "https://console.anthropic.com/settings/keys",
gatewayPrefix: "anthropic/",
},
google: {
displayName: "Google",
hasFreeTier: true,
websiteUrl: "https://aistudio.google.com/app/apikey",
gatewayPrefix: "gemini/",
},
openrouter: {
displayName: "OpenRouter",
hasFreeTier: true,
websiteUrl: "https://openrouter.ai/settings/keys",
gatewayPrefix: "openrouter/",
},
auto: {
displayName: "Dyad",
websiteUrl: "https://academy.dyad.sh/settings",
gatewayPrefix: "",
},
};
/**
 * Builds the complete list of language model providers.
 *
 * Hardcoded cloud providers (from `PROVIDERS`) come first, in declaration
 * order; custom providers stored in the database follow. When a database row
 * shares its id with a hardcoded provider, the two are merged into a single
 * "custom" provider in which the database values win.
 * @returns A promise that resolves to an array of LanguageModelProvider objects.
 */
export async function getLanguageModelProviders(): Promise<
  LanguageModelProvider[]
> {
  // Read the database first so a query failure surfaces before any other work.
  const rows = await db.select().from(languageModelProvidersSchema);

  // Index database-defined providers by id.
  // hasFreeTier, websiteUrl and gatewayPrefix are not part of the DB schema,
  // so they stay unset unless a hardcoded provider with the same id supplies them.
  const customById = new Map<string, LanguageModelProvider>(
    rows.map((row): [string, LanguageModelProvider] => [
      row.id,
      {
        id: row.id,
        name: row.name,
        apiBaseUrl: row.api_base_url,
        envVarName: row.env_var_name ?? undefined,
        type: "custom",
      },
    ]),
  );

  // Seed the result with the hardcoded cloud providers (no apiBaseUrl here).
  const merged = new Map<string, LanguageModelProvider>();
  const builtinIds = Object.keys(PROVIDERS) as Array<keyof typeof PROVIDERS>;
  for (const builtinId of builtinIds) {
    const details = PROVIDERS[builtinId];
    if (!details) {
      continue;
    }
    merged.set(builtinId, {
      id: builtinId,
      name: details.displayName,
      hasFreeTier: details.hasFreeTier,
      websiteUrl: details.websiteUrl,
      gatewayPrefix: details.gatewayPrefix,
      envVarName: PROVIDER_TO_ENV_VAR[builtinId] ?? undefined,
      type: "cloud",
    });
  }

  // Layer the database providers on top.
  for (const [id, custom] of customById) {
    const builtin = merged.get(id);
    if (!builtin) {
      // No hardcoded counterpart: the custom provider is added as-is.
      merged.set(id, custom);
      continue;
    }
    // Same id as a hardcoded provider: start from the hardcoded fields and
    // let the custom ones win, falling back to hardcoded values only where
    // the custom provider leaves apiBaseUrl / envVarName unset.
    merged.set(id, {
      ...builtin,
      ...custom,
      id: custom.id,
      name: custom.name,
      type: "custom",
      apiBaseUrl: custom.apiBaseUrl ?? builtin.apiBaseUrl,
      envVarName: custom.envVarName ?? builtin.envVarName,
    });
  }

  return Array.from(merged.values());
}
/**
 * Lists the language models available for one provider.
 *
 * Models stored in the database for the provider are combined with the
 * hardcoded `MODEL_OPTIONS` entries (cloud providers only). Entries are keyed
 * by `apiName`, so a database model replaces a hardcoded one with the same
 * `apiName`. An unknown provider yields an empty array, and a failed database
 * read is logged and treated as "no custom models".
 * @param obj An object containing the providerId.
 * @returns A promise that resolves to an array of LanguageModel objects.
 */
export async function getLanguageModels({
  providerId,
}: {
  providerId: string;
}): Promise<LanguageModel[]> {
  const knownProviders = await getLanguageModelProviders();
  const provider = knownProviders.find(({ id }) => id === providerId);
  if (!provider) {
    console.warn(`Provider with ID "${providerId}" not found.`);
    return [];
  }

  // Database models are looked up for every provider type; the column that
  // links a model to its provider depends on whether the id is a custom one.
  let dbModels: LanguageModel[] = [];
  try {
    const ownerFilter = isCustomProvider({ providerId })
      ? eq(languageModelsSchema.customProviderId, providerId)
      : eq(languageModelsSchema.builtinProviderId, providerId);
    const rows = await db
      .select({
        id: languageModelsSchema.id,
        displayName: languageModelsSchema.displayName,
        apiName: languageModelsSchema.apiName,
        description: languageModelsSchema.description,
        maxOutputTokens: languageModelsSchema.max_output_tokens,
        contextWindow: languageModelsSchema.context_window,
      })
      .from(languageModelsSchema)
      .where(ownerFilter);
    dbModels = rows.map((row) => ({
      ...row,
      description: row.description ?? "",
      tag: undefined,
      maxOutputTokens: row.maxOutputTokens ?? undefined,
      contextWindow: row.contextWindow ?? undefined,
      type: "custom",
    }));
  } catch (error) {
    // Best effort: log and carry on without any database models.
    console.error(
      `Error fetching custom models for provider "${providerId}" from DB:`,
      error,
    );
  }

  // Keyed by apiName; insertion order decides precedence (later wins).
  const byApiName = new Map<string, LanguageModel>();

  // Hardcoded models go in first and only apply to cloud providers.
  if (provider.type === "cloud") {
    if (!(providerId in MODEL_OPTIONS)) {
      console.warn(
        `Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
      );
    } else {
      const options = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
      for (const option of options) {
        byApiName.set(option.name, {
          ...option,
          apiName: option.name,
          type: "cloud",
        });
      }
    }
  }

  // Database models go in last so they override hardcoded ones.
  for (const dbModel of dbModels) {
    byApiName.set(dbModel.apiName, dbModel);
  }

  return [...byApiName.values()];
}
/**
 * Collects the language models of every non-local provider.
 *
 * The per-provider lookups run concurrently; if any of them rejects, the
 * returned promise rejects as well.
 * @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
 */
export async function getLanguageModelsByProviders(): Promise<
  Record<string, LanguageModel[]>
> {
  const allProviders = await getLanguageModelProviders();
  const nonLocalProviders = allProviders.filter(
    ({ type }) => type !== "local",
  );

  // One lookup per provider, all in flight at the same time.
  const entries = await Promise.all(
    nonLocalProviders.map(
      async ({ id }): Promise<[string, LanguageModel[]]> => [
        id,
        await getLanguageModels({ providerId: id }),
      ],
    ),
  );

  // Fold the [providerId, models] pairs into a plain record.
  return entries.reduce<Record<string, LanguageModel[]>>(
    (byProvider, [id, models]) => {
      byProvider[id] = models;
      return byProvider;
    },
    {},
  );
}
/** Prefix that marks a provider id as a custom-provider id. */
export const CUSTOM_PROVIDER_PREFIX = "custom::";

/**
 * Tells whether a provider id carries the custom-provider prefix.
 * @param obj An object containing the providerId.
 * @returns `true` when `providerId` begins with `CUSTOM_PROVIDER_PREFIX`.
 */
export function isCustomProvider({
  providerId,
}: {
  providerId: string;
}): boolean {
  return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
}