352 lines
11 KiB
TypeScript
352 lines
11 KiB
TypeScript
import { db } from "@/db";
|
|
import {
|
|
language_model_providers as languageModelProvidersSchema,
|
|
language_models as languageModelsSchema,
|
|
} from "@/db/schema";
|
|
import type { LanguageModelProvider, LanguageModel } from "@/ipc/ipc_types";
|
|
import { eq } from "drizzle-orm";
|
|
import { ModelProvider } from "@/lib/schemas";
|
|
|
|
/**
 * Static description of one built-in model, as listed in `MODEL_OPTIONS`.
 */
export interface ModelOption {
  /** Model identifier; `getLanguageModels` copies it into `apiName`. */
  name: string;
  /** Human-readable label, e.g. "GPT 4.1". */
  displayName: string;
  /** Short blurb describing the model. */
  description: string;
  /** Optional badge text; the built-in entries use "Recommended" and "Default". */
  tag?: string;
  /** Upper bound on output tokens, when one is declared for the model. */
  maxOutputTokens?: number;
  /** Context window size in tokens, when one is declared for the model. */
  contextWindow?: number;
}
|
|
|
|
/**
 * Every `ModelProvider` except "ollama" and "lmstudio".
 *
 * This is the key type of `MODEL_OPTIONS` and `PROVIDERS`, so the two excluded
 * providers have no hardcoded entries in either table.
 */
export type RegularModelProvider = Exclude<
  ModelProvider,
  "ollama" | "lmstudio"
>;
|
|
|
|
/**
 * Built-in model catalogue, keyed by provider id.
 *
 * `getLanguageModels` turns these entries into `LanguageModel` objects of type
 * "cloud", using each entry's `name` as its `apiName`.
 */
export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
  openai: [
    // https://platform.openai.com/docs/models/gpt-4.1
    {
      name: "gpt-4.1",
      displayName: "GPT 4.1",
      description: "OpenAI's flagship model",
      maxOutputTokens: 32_768,
      contextWindow: 1_047_576,
    },
    // https://platform.openai.com/docs/models/gpt-4.1-mini
    {
      name: "gpt-4.1-mini",
      displayName: "GPT 4.1 Mini",
      description: "OpenAI's lightweight, but intelligent model",
      maxOutputTokens: 32_768,
      contextWindow: 1_047_576,
    },
    // https://platform.openai.com/docs/models/o3-mini
    {
      name: "o3-mini",
      displayName: "o3 mini",
      description: "Reasoning model",
      maxOutputTokens: 100_000,
      contextWindow: 200_000,
    },
  ],
  // https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison-table
  anthropic: [
    {
      name: "claude-3-7-sonnet-latest",
      displayName: "Claude 3.7 Sonnet",
      description: "Excellent coder",
      maxOutputTokens: 64_000,
      contextWindow: 200_000,
    },
  ],
  google: [
    // https://ai.google.dev/gemini-api/docs/models#gemini-2.5-pro-preview-03-25
    {
      name: "gemini-2.5-pro-exp-03-25",
      displayName: "Gemini 2.5 Pro",
      description: "Experimental version of Google's Gemini 2.5 Pro model",
      tag: "Recommended",
      // See Flash 2.5 comment below (go 1 below just to be safe, even though it seems OK now).
      maxOutputTokens: 65_536 - 1,
      // Gemini context window = input token + output token
      contextWindow: 1_048_576,
    },
    // https://ai.google.dev/gemini-api/docs/models#gemini-2.5-flash-preview
    {
      name: "gemini-2.5-flash-preview-04-17",
      displayName: "Gemini 2.5 Flash",
      description: "Preview version of Google's Gemini 2.5 Flash model",
      // Weirdly for Vertex AI, the output token limit is *exclusive* of the stated limit.
      maxOutputTokens: 65_536 - 1,
      // Gemini context window = input token + output token
      contextWindow: 1_048_576,
    },
  ],
  openrouter: [
    // https://openrouter.ai/deepseek/deepseek-chat-v3-0324:free
    {
      name: "deepseek/deepseek-chat-v3-0324:free",
      displayName: "DeepSeek v3 (free)",
      description: "Use for free (data may be used for training)",
      maxOutputTokens: 32_000,
      contextWindow: 128_000,
    },
  ],
  // Single placeholder entry; it declares no token limits.
  auto: [
    {
      name: "auto",
      displayName: "Auto",
      description: "Automatically selects the best model",
      tag: "Default",
    },
  ],
};
|
|
|
|
/**
 * Maps a built-in provider id to the name of the environment variable
 * associated with it. `getLanguageModelProviders` exposes the value as
 * `envVarName`.
 *
 * Providers without an entry (e.g. "auto") resolve to `undefined`.
 */
export const PROVIDER_TO_ENV_VAR: Record<string, string> = {
  openai: "OPENAI_API_KEY",
  anthropic: "ANTHROPIC_API_KEY",
  google: "GEMINI_API_KEY",
  openrouter: "OPENROUTER_API_KEY",
};
|
|
|
|
/**
 * Display metadata for each built-in provider, keyed by provider id.
 *
 * - `displayName`: label exposed as the provider's `name`.
 * - `hasFreeTier`: whether the provider is flagged as having a free tier
 *   (left unset for "auto").
 * - `websiteUrl`: link associated with the provider; for most entries this is
 *   the page where API keys are managed.
 * - `gatewayPrefix`: provider-specific prefix string such as "anthropic/".
 *   NOTE(review): presumably prepended to model names when routing through a
 *   gateway — confirm against the call sites.
 */
export const PROVIDERS: Record<
  RegularModelProvider,
  {
    displayName: string;
    hasFreeTier?: boolean;
    websiteUrl?: string;
    gatewayPrefix: string;
  }
> = {
  openai: {
    displayName: "OpenAI",
    hasFreeTier: false,
    websiteUrl: "https://platform.openai.com/api-keys",
    gatewayPrefix: "",
  },
  anthropic: {
    displayName: "Anthropic",
    hasFreeTier: false,
    websiteUrl: "https://console.anthropic.com/settings/keys",
    gatewayPrefix: "anthropic/",
  },
  google: {
    displayName: "Google",
    hasFreeTier: true,
    websiteUrl: "https://aistudio.google.com/app/apikey",
    gatewayPrefix: "gemini/",
  },
  openrouter: {
    displayName: "OpenRouter",
    hasFreeTier: true,
    websiteUrl: "https://openrouter.ai/settings/keys",
    gatewayPrefix: "openrouter/",
  },
  auto: {
    displayName: "Dyad",
    websiteUrl: "https://academy.dyad.sh/settings",
    gatewayPrefix: "",
  },
};
|
|
|
|
/**
 * Fetches language model providers from both the database (custom) and the
 * hardcoded `PROVIDERS` table (cloud), merging them with custom providers
 * taking precedence.
 *
 * Result order: built-in providers in `PROVIDERS` declaration order (an
 * overridden built-in keeps its position), followed by database-only providers
 * in the order the query returned them.
 *
 * @returns A promise that resolves to an array of LanguageModelProvider objects.
 */
export async function getLanguageModelProviders(): Promise<
  LanguageModelProvider[]
> {
  // Custom providers from the database, de-duplicated by id (last row wins).
  const customRows = await db.select().from(languageModelProvidersSchema);
  const customById = new Map<string, LanguageModelProvider>();
  for (const row of customRows) {
    customById.set(row.id, {
      id: row.id,
      name: row.name,
      apiBaseUrl: row.api_base_url,
      envVarName: row.env_var_name ?? undefined,
      type: "custom",
      // hasFreeTier, websiteUrl and gatewayPrefix are not stored in the DB;
      // they stay undefined unless inherited from a built-in with the same id.
    });
  }

  // Seed the result with the built-in cloud providers.
  const merged = new Map<string, LanguageModelProvider>();
  for (const [id, details] of Object.entries(PROVIDERS)) {
    merged.set(id, {
      id,
      name: details.displayName,
      hasFreeTier: details.hasFreeTier,
      websiteUrl: details.websiteUrl,
      gatewayPrefix: details.gatewayPrefix,
      // Providers without an entry (e.g. "auto") simply get undefined here.
      envVarName: PROVIDER_TO_ENV_VAR[id],
      type: "cloud",
      // apiBaseUrl is not part of PROVIDERS.
    });
  }

  // Layer custom providers on top: override a built-in in place, or append.
  for (const [id, custom] of customById) {
    const builtin = merged.get(id);
    if (!builtin) {
      merged.set(id, custom);
      continue;
    }
    merged.set(id, {
      ...builtin, // start from the built-in metadata
      ...custom, // custom fields win
      id: custom.id,
      name: custom.name,
      type: "custom",
      // Fall back to the built-in value when the custom row leaves these unset.
      apiBaseUrl: custom.apiBaseUrl ?? builtin.apiBaseUrl,
      envVarName: custom.envVarName ?? builtin.envVarName,
    });
  }

  return Array.from(merged.values());
}
|
|
|
|
/**
 * Fetches language models for a specific provider.
 *
 * Models stored in the database are combined with the built-in entries from
 * `MODEL_OPTIONS` (only for providers of type "cloud"). When both sources
 * define the same `apiName`, the database model replaces the built-in one.
 * A failing database query is logged and treated as "no custom models".
 *
 * @param obj An object containing the providerId.
 * @returns A promise that resolves to an array of LanguageModel objects; the
 *   array is empty when no provider with that id exists.
 */
export async function getLanguageModels({
  providerId,
}: {
  providerId: string;
}): Promise<LanguageModel[]> {
  const providers = await getLanguageModelProviders();
  const provider = providers.find((p) => p.id === providerId);
  if (!provider) {
    console.warn(`Provider with ID "${providerId}" not found.`);
    return [];
  }

  // Custom models are looked up for every provider type; the column to match
  // on depends on whether the id carries the custom-provider prefix.
  let customModels: LanguageModel[] = [];
  try {
    const providerFilter = isCustomProvider({ providerId })
      ? eq(languageModelsSchema.customProviderId, providerId)
      : eq(languageModelsSchema.builtinProviderId, providerId);

    const rows = await db
      .select({
        id: languageModelsSchema.id,
        displayName: languageModelsSchema.displayName,
        apiName: languageModelsSchema.apiName,
        description: languageModelsSchema.description,
        maxOutputTokens: languageModelsSchema.max_output_tokens,
        contextWindow: languageModelsSchema.context_window,
      })
      .from(languageModelsSchema)
      .where(providerFilter);

    customModels = rows.map((row) => ({
      ...row,
      // Replace missing column values with the defaults used for LanguageModel.
      description: row.description ?? "",
      tag: undefined,
      maxOutputTokens: row.maxOutputTokens ?? undefined,
      contextWindow: row.contextWindow ?? undefined,
      type: "custom",
    }));
  } catch (error) {
    console.error(
      `Error fetching custom models for provider "${providerId}" from DB:`,
      error,
    );
    // Best effort: continue with an empty custom model list.
  }

  // Cloud providers additionally contribute their hardcoded models.
  let hardcodedModels: LanguageModel[] = [];
  if (provider.type === "cloud") {
    // Own-property check so inherited keys such as "constructor" never match.
    if (Object.prototype.hasOwnProperty.call(MODEL_OPTIONS, providerId)) {
      hardcodedModels = MODEL_OPTIONS[providerId as RegularModelProvider].map(
        (option) => ({
          ...option,
          apiName: option.name,
          type: "cloud",
        }),
      );
    } else {
      console.warn(
        `Provider "${providerId}" is cloud type but not found in MODEL_OPTIONS.`,
      );
    }
  }

  // Key by apiName: hardcoded models go in first, so a custom model with the
  // same apiName overwrites it while keeping the original position.
  const modelsByApiName = new Map<string, LanguageModel>();
  for (const model of [...hardcodedModels, ...customModels]) {
    modelsByApiName.set(model.apiName, model);
  }

  return Array.from(modelsByApiName.values());
}
|
|
|
|
/**
 * Fetches all language models grouped by their provider IDs.
 *
 * Providers whose type is "local" are skipped. The per-provider lookups run
 * concurrently; if any of them rejects, the returned promise rejects too.
 *
 * NOTE(review): each `getLanguageModels` call re-fetches the provider list, so
 * the providers table is read once per provider in addition to the initial
 * read here.
 *
 * @returns A promise that resolves to a Record mapping provider IDs to arrays of LanguageModel objects.
 */
export async function getLanguageModelsByProviders(): Promise<
  Record<string, LanguageModel[]>
> {
  const providers = await getLanguageModelProviders();

  // One [providerId, models] pair per non-local provider, fetched in parallel.
  const entries = await Promise.all(
    providers
      .filter((provider) => provider.type !== "local")
      .map(
        async (provider): Promise<[string, LanguageModel[]]> => [
          provider.id,
          await getLanguageModels({ providerId: provider.id }),
        ],
      ),
  );

  return Object.fromEntries(entries);
}
|
|
|
|
/** Prefix that marks a provider id as belonging to a custom provider. */
export const CUSTOM_PROVIDER_PREFIX = "custom::";

/**
 * Reports whether a provider id denotes a custom provider, i.e. whether it
 * starts with `CUSTOM_PROVIDER_PREFIX`.
 *
 * @param obj An object containing the providerId to inspect.
 * @returns `true` when the id starts with the custom prefix, otherwise `false`.
 */
export function isCustomProvider({
  providerId,
}: {
  providerId: string;
}): boolean {
  return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
}
|