Files
moreminimore-vibe/src/ipc/shared/language_model_helpers.ts
2025-05-23 15:45:20 -07:00

370 lines
11 KiB
TypeScript

import { db } from "@/db";
import {
language_model_providers as languageModelProvidersSchema,
language_models as languageModelsSchema,
} from "@/db/schema";
import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types";
import { eq } from "drizzle-orm";
/**
 * Static description of a selectable model, as listed in `MODEL_OPTIONS`.
 */
export interface ModelOption {
  /** Model identifier; `getLanguageModels` copies it into `apiName`. */
  name: string;
  /** Human-readable label for the model. */
  displayName: string;
  /** Short free-text description of the model. */
  description: string;
  /**
   * Context window size, presumably in tokens. For Gemini entries the
   * catalogue notes it counts input + output tokens.
   */
  contextWindow?: number;
  /** Upper bound on output tokens for this model, when known. */
  maxOutputTokens?: number;
  /** Whether the model is flagged as supporting turbo edits. */
  supportsTurboEdits?: boolean;
  /** Optional badge text (the "auto" entry uses "Default"). */
  tag?: string;
}
/**
 * Built-in model catalogue, keyed by provider ID.
 *
 * `getLanguageModels` exposes these entries (with `apiName` set to `name`)
 * only for providers whose type is "cloud". `getLanguageModelProviders`
 * assigns that type to the entries of `CLOUD_PROVIDERS`.
 */
export const MODEL_OPTIONS: Record<string, ModelOption[]> = {
  openai: [
    // https://platform.openai.com/docs/models/gpt-4.1
    {
      name: "gpt-4.1",
      displayName: "GPT 4.1",
      description: "OpenAI's flagship model",
      maxOutputTokens: 32_768,
      contextWindow: 1_047_576,
      supportsTurboEdits: true,
    },
    // https://platform.openai.com/docs/models/gpt-4.1-mini
    {
      name: "gpt-4.1-mini",
      displayName: "GPT 4.1 Mini",
      description: "OpenAI's lightweight, but intelligent model",
      maxOutputTokens: 32_768,
      contextWindow: 1_047_576,
    },
    // https://platform.openai.com/docs/models/o3-mini
    {
      name: "o3-mini",
      displayName: "o3 mini",
      description: "Reasoning model",
      maxOutputTokens: 100_000,
      contextWindow: 200_000,
    },
    // https://platform.openai.com/docs/models/o4-mini
    {
      name: "o4-mini",
      displayName: "o4 mini",
      description: "Reasoning model",
      maxOutputTokens: 100_000,
      contextWindow: 200_000,
    },
  ],
  // https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison-table
  anthropic: [
    {
      name: "claude-sonnet-4-20250514",
      displayName: "Claude 4 Sonnet",
      description: "Excellent coder",
      maxOutputTokens: 64_000,
      contextWindow: 200_000,
      supportsTurboEdits: true,
    },
    {
      name: "claude-3-7-sonnet-latest",
      displayName: "Claude 3.7 Sonnet",
      description: "Excellent coder",
      maxOutputTokens: 64_000,
      contextWindow: 200_000,
      supportsTurboEdits: true,
    },
    {
      name: "claude-3-5-sonnet-20241022",
      displayName: "Claude 3.5 Sonnet",
      description: "Good coder, excellent at following instructions",
      maxOutputTokens: 8_000,
      contextWindow: 200_000,
      supportsTurboEdits: true,
    },
    {
      name: "claude-3-5-haiku-20241022",
      displayName: "Claude 3.5 Haiku",
      description: "Lightweight coder",
      maxOutputTokens: 8_000,
      contextWindow: 200_000,
    },
  ],
  google: [
    // https://ai.google.dev/gemini-api/docs/models#gemini-2.5-pro-preview-03-25
    // NOTE(review): the anchor names the 03-25 preview while this entry is the
    // 05-06 preview — confirm the link still points at the intended model.
    {
      name: "gemini-2.5-pro-preview-05-06",
      displayName: "Gemini 2.5 Pro",
      description: "Preview version of Google's Gemini 2.5 Pro model",
      // See Flash 2.5 comment below (go 1 below just to be safe, even though it seems OK now).
      maxOutputTokens: 65_536 - 1,
      // Gemini context window = input token + output token
      contextWindow: 1_048_576,
      supportsTurboEdits: true,
    },
    // https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview
    {
      name: "gemini-2.5-flash-preview-05-20",
      displayName: "Gemini 2.5 Flash",
      description:
        "Preview version of Google's Gemini 2.5 Flash model (free tier available)",
      // Weirdly for Vertex AI, the output token limit is *exclusive* of the stated limit.
      // Hence one less than the documented 65_536 ceiling.
      maxOutputTokens: 65_536 - 1,
      // Gemini context window = input token + output token
      contextWindow: 1_048_576,
    },
  ],
  openrouter: [
    // https://openrouter.ai/deepseek/deepseek-chat-v3-0324:free
    {
      name: "deepseek/deepseek-chat-v3-0324:free",
      displayName: "DeepSeek v3 (free)",
      description: "Use for free (data may be used for training)",
      maxOutputTokens: 32_000,
      contextWindow: 128_000,
    },
  ],
  // Pseudo-provider whose single model carries the "Default" tag; no token
  // limits are specified for it.
  auto: [
    {
      name: "auto",
      displayName: "Auto",
      description: "Automatically selects the best model",
      tag: "Default",
    },
  ],
};
/**
 * Environment-variable name for each built-in cloud provider's API key,
 * keyed by provider ID.
 *
 * `getLanguageModelProviders` surfaces this as `envVarName`. Providers
 * without an entry (e.g. "auto") get no env var.
 */
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
  "openai": "OPENAI_API_KEY",
  "anthropic": "ANTHROPIC_API_KEY",
  // The Google provider's key is named after its Gemini API.
  "google": "GEMINI_API_KEY",
  "openrouter": "OPENROUTER_API_KEY",
};
/**
 * Built-in hosted providers, keyed by provider ID.
 *
 * `getLanguageModelProviders` lists each of these with type "cloud". Key
 * order here is the order in which they are listed.
 */
export const CLOUD_PROVIDERS: Record<
  string,
  {
    /** Label shown for the provider. */
    displayName: string;
    /** Free-tier flag; left unset for "auto". */
    hasFreeTier?: boolean;
    /** Provider web page (the API-key / settings page, judging by the URLs below). */
    websiteUrl?: string;
    /**
     * Provider-specific prefix string (e.g. "anthropic/"). Presumably
     * prepended to model names when routed through a gateway — confirm at
     * the call site.
     */
    gatewayPrefix: string;
  }
> = {
  "openai": {
    displayName: "OpenAI",
    hasFreeTier: false,
    websiteUrl: "https://platform.openai.com/api-keys",
    gatewayPrefix: "",
  },
  "anthropic": {
    displayName: "Anthropic",
    hasFreeTier: false,
    websiteUrl: "https://console.anthropic.com/settings/keys",
    gatewayPrefix: "anthropic/",
  },
  "google": {
    displayName: "Google",
    hasFreeTier: true,
    websiteUrl: "https://aistudio.google.com/app/apikey",
    gatewayPrefix: "gemini/",
  },
  "openrouter": {
    displayName: "OpenRouter",
    hasFreeTier: true,
    websiteUrl: "https://openrouter.ai/settings/keys",
    gatewayPrefix: "openrouter/",
  },
  // "auto" deliberately has no hasFreeTier key.
  "auto": {
    displayName: "Dyad",
    websiteUrl: "https://academy.dyad.sh/settings",
    gatewayPrefix: "",
  },
};
/**
 * Built-in providers listed with type "local" by `getLanguageModelProviders`.
 *
 * `getLanguageModelsByProviders` skips providers of this type.
 */
const LOCAL_PROVIDERS: Record<
  string,
  {
    /** Label shown for the provider. */
    displayName: string;
    /** Surfaced as `hasFreeTier` on the provider entry. */
    hasFreeTier: boolean;
  }
> = {
  "ollama": { displayName: "Ollama", hasFreeTier: true },
  "lmstudio": { displayName: "LM Studio", hasFreeTier: true },
};
/**
 * Lists every language model provider known to the app. This combines the
 * built-in cloud providers (`CLOUD_PROVIDERS`), the built-in local providers
 * (`LOCAL_PROVIDERS`) and the custom providers stored in the database.
 *
 * The result is ordered cloud, then local, then custom. The groups are simply
 * concatenated; nothing is merged or overridden. If a custom provider's ID
 * collides with a built-in one, both entries are returned, and callers using
 * `Array.prototype.find` will see the built-in entry first.
 *
 * @returns A promise resolving to the combined list of providers.
 */
export async function getLanguageModelProviders(): Promise<
  LanguageModelProvider[]
> {
  const customProvidersDb = await db
    .select()
    .from(languageModelProvidersSchema);

  // Keyed by ID: a repeated ID keeps its first position but the last row's
  // values. Custom providers only carry the DB columns, so hasFreeTier,
  // websiteUrl and gatewayPrefix stay undefined for them.
  const customProvidersById = new Map<string, LanguageModelProvider>();
  for (const cp of customProvidersDb) {
    customProvidersById.set(cp.id, {
      id: cp.id,
      name: cp.name,
      apiBaseUrl: cp.api_base_url,
      envVarName: cp.env_var_name ?? undefined,
      type: "custom",
    });
  }

  // Object.entries yields own enumerable keys in the same order a guarded
  // for...in loop would. Cloud providers have no apiBaseUrl.
  const cloudProviders = Object.entries(CLOUD_PROVIDERS).map(
    ([id, details]): LanguageModelProvider => ({
      id,
      name: details.displayName,
      hasFreeTier: details.hasFreeTier,
      websiteUrl: details.websiteUrl,
      gatewayPrefix: details.gatewayPrefix,
      // Undefined for providers without a mapped env var (e.g. "auto").
      envVarName: PROVIDER_TO_ENV_VAR[id],
      type: "cloud",
    }),
  );

  const localProviders = Object.entries(LOCAL_PROVIDERS).map(
    ([id, details]): LanguageModelProvider => ({
      id,
      name: details.displayName,
      hasFreeTier: details.hasFreeTier,
      type: "local",
    }),
  );

  return [
    ...cloudProviders,
    ...localProviders,
    ...customProvidersById.values(),
  ];
}
/**
 * Returns the models available for a single provider.
 *
 * Models stored in the database are included for every provider type. For
 * providers of type "cloud", the static entries from `MODEL_OPTIONS` are
 * placed ahead of them.
 *
 * Failure handling:
 * - An unknown `providerId` logs a warning and resolves to `[]`.
 * - A failed DB read is logged and treated as "no custom models".
 *
 * @param obj.providerId ID of the provider whose models should be listed.
 * @returns A promise resolving to built-in models followed by DB-stored models.
 */
export async function getLanguageModels({
  providerId,
}: {
  providerId: string;
}): Promise<LanguageModel[]> {
  const knownProviders = await getLanguageModelProviders();
  const provider = knownProviders.find(
    (candidate) => candidate.id === providerId,
  );
  if (provider === undefined) {
    console.warn(`Provider with ID "${providerId}" not found.`);
    return [];
  }

  // DB-stored models; a read failure degrades to an empty list.
  let customModels: LanguageModel[] = [];
  try {
    // Custom-prefixed provider IDs own their models via customProviderId;
    // all others via builtinProviderId.
    const ownerFilter = isCustomProvider({ providerId })
      ? eq(languageModelsSchema.customProviderId, providerId)
      : eq(languageModelsSchema.builtinProviderId, providerId);
    const rows = await db
      .select({
        id: languageModelsSchema.id,
        displayName: languageModelsSchema.displayName,
        apiName: languageModelsSchema.apiName,
        description: languageModelsSchema.description,
        maxOutputTokens: languageModelsSchema.max_output_tokens,
        contextWindow: languageModelsSchema.context_window,
      })
      .from(languageModelsSchema)
      .where(ownerFilter);
    // Normalise nullable DB columns: description -> "", limits -> undefined.
    customModels = rows.map((row) => ({
      ...row,
      description: row.description ?? "",
      tag: undefined,
      maxOutputTokens: row.maxOutputTokens ?? undefined,
      contextWindow: row.contextWindow ?? undefined,
      type: "custom",
    }));
  } catch (error) {
    console.error(
      `Error fetching custom models for provider "${providerId}" from DB:`,
      error,
    );
  }

  // Static catalogue entries exist only for cloud providers.
  let builtinModels: LanguageModel[] = [];
  if (provider.type === "cloud") {
    if (!(providerId in MODEL_OPTIONS)) {
      console.warn(
        `Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
      );
    } else {
      const catalogue = MODEL_OPTIONS[providerId] ?? [];
      builtinModels = catalogue.map((option) => ({
        ...option,
        apiName: option.name,
        type: "cloud",
      }));
    }
  }

  return [...builtinModels, ...customModels];
}
/**
 * Collects the model list of every non-local provider, keyed by provider ID.
 *
 * Providers of type "local" are skipped, so they have no key in the result.
 * The per-provider lookups run concurrently via `Promise.all`. Each one calls
 * `getLanguageModels`, which reads the provider list again on its own.
 *
 * @returns A promise resolving to a record from provider ID to its models.
 */
export async function getLanguageModelsByProviders(): Promise<
  Record<string, LanguageModel[]>
> {
  const allProviders = await getLanguageModelProviders();
  const nonLocalProviders = allProviders.filter(
    ({ type }) => type !== "local",
  );

  const entries = await Promise.all(
    nonLocalProviders.map(
      async ({ id }) =>
        [id, await getLanguageModels({ providerId: id })] as const,
    ),
  );

  // Later duplicates (if any) overwrite earlier ones.
  const byProvider: Record<string, LanguageModel[]> = {};
  for (const [id, models] of entries) {
    byProvider[id] = models;
  }
  return byProvider;
}
/**
 * Prefix that marks a provider ID as a user-defined (custom) provider.
 *
 * Declared before `isCustomProvider` so the function can never observe this
 * binding in its temporal dead zone, even if it runs during module
 * initialization (for example through a circular import).
 */
export const CUSTOM_PROVIDER_PREFIX = "custom::";

/**
 * Reports whether `providerId` names a custom provider, i.e. starts with
 * `CUSTOM_PROVIDER_PREFIX`.
 *
 * @param obj.providerId Provider ID to classify.
 * @returns `true` for custom-provider IDs, `false` otherwise.
 */
export function isCustomProvider({
  providerId,
}: {
  providerId: string;
}): boolean {
  return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
}