diff --git a/src/ipc/handlers/chat_stream_handlers.ts b/src/ipc/handlers/chat_stream_handlers.ts
index 19b3479..9d66953 100644
--- a/src/ipc/handlers/chat_stream_handlers.ts
+++ b/src/ipc/handlers/chat_stream_handlers.ts
@@ -15,7 +15,7 @@ import { extractCodebase } from "../../utils/codebase";
 import { processFullResponseActions } from "../processors/response_processor";
 import { streamTestResponse } from "./testing_chat_handlers";
 import { getTestResponse } from "./testing_chat_handlers";
-import { getMaxTokens, getModelClient } from "../utils/get_model_client";
+import { getModelClient } from "../utils/get_model_client";
 import log from "electron-log";
 import {
   getSupabaseContext,
@@ -27,6 +27,7 @@ import * as path from "path";
 import * as os from "os";
 import * as crypto from "crypto";
 import { readFile, writeFile, unlink } from "fs/promises";
+import { getMaxTokens } from "../utils/token_utils";
 
 const logger = log.scope("chat_stream_handlers");
 
@@ -332,7 +333,7 @@ This conversation includes one or more image attachments. When the user uploads
 
         // When calling streamText, the messages need to be properly formatted for mixed content
         const { textStream } = streamText({
-          maxTokens: getMaxTokens(settings.selectedModel),
+          maxTokens: await getMaxTokens(settings.selectedModel),
           temperature: 0,
           model: modelClient,
           system: systemPrompt,
diff --git a/src/ipc/handlers/language_model_handlers.ts b/src/ipc/handlers/language_model_handlers.ts
index 729fd54..381faa1 100644
--- a/src/ipc/handlers/language_model_handlers.ts
+++ b/src/ipc/handlers/language_model_handlers.ts
@@ -7,6 +7,7 @@ import type {
 import { createLoggedHandler } from "./safe_handle";
 import log from "electron-log";
 import {
+  CUSTOM_PROVIDER_PREFIX,
   getLanguageModelProviders,
   getLanguageModels,
   getLanguageModelsByProviders,
@@ -66,7 +67,7 @@ export function registerLanguageModelHandlers() {
       // Insert the new provider
       await db.insert(languageModelProvidersSchema).values({
         // Make sure we will never have accidental collisions with builtin providers
-        id: "custom::" + id,
+        id: CUSTOM_PROVIDER_PREFIX + id,
         name,
         api_base_url: apiBaseUrl,
         env_var_name: envVarName || null,
@@ -297,11 +298,7 @@ export function registerLanguageModelHandlers() {
       if (provider.type === "local") {
         throw new Error("Local models cannot be fetched");
       }
-      return getLanguageModels(
-        provider.type === "cloud"
-          ? { builtinProviderId: params.providerId }
-          : { customProviderId: params.providerId },
-      );
+      return getLanguageModels({ providerId: params.providerId });
     },
   );
 
diff --git a/src/ipc/handlers/proposal_handlers.ts b/src/ipc/handlers/proposal_handlers.ts
index 6e44a60..9a5bd87 100644
--- a/src/ipc/handlers/proposal_handlers.ts
+++ b/src/ipc/handlers/proposal_handlers.ts
@@ -277,7 +277,7 @@ const getProposalHandler = async (
   );
   const totalTokens = messagesTokenCount + codebaseTokenCount;
 
-  const contextWindow = Math.min(getContextWindow(), 100_000);
+  const contextWindow = Math.min(await getContextWindow(), 100_000);
   logger.log(
     `Token usage: ${totalTokens}/${contextWindow} (${
       (totalTokens / contextWindow) * 100
diff --git a/src/ipc/handlers/token_count_handlers.ts b/src/ipc/handlers/token_count_handlers.ts
index 111495b..198f4d0 100644
--- a/src/ipc/handlers/token_count_handlers.ts
+++ b/src/ipc/handlers/token_count_handlers.ts
@@ -88,7 +88,7 @@ export function registerTokenCountHandlers() {
         codebaseTokens,
         inputTokens,
         systemPromptTokens,
-        contextWindow: getContextWindow(),
+        contextWindow: await getContextWindow(),
       };
     },
   );
diff --git a/src/ipc/shared/language_model_helpers.ts b/src/ipc/shared/language_model_helpers.ts
index 310336e..df18726 100644
--- a/src/ipc/shared/language_model_helpers.ts
+++ b/src/ipc/shared/language_model_helpers.ts
@@ -139,26 +139,16 @@ export async function getLanguageModelProviders(): Promise<
  * @param obj An object containing the providerId.
  * @returns A promise that resolves to an array of LanguageModel objects.
  */
-export async function getLanguageModels(
-  obj:
-    | {
-        customProviderId: string;
-        // builtinProviderId?: undefined;
-      }
-    | {
-        builtinProviderId: string;
-        // customProviderId?: undefined;
-      },
-): Promise<LanguageModel[]> {
+export async function getLanguageModels({
+  providerId,
+}: {
+  providerId: string;
+}): Promise<LanguageModel[]> {
   const allProviders = await getLanguageModelProviders();
-  const provider = allProviders.find(
-    (p) =>
-      p.id === (obj as { customProviderId: string }).customProviderId ||
-      p.id === (obj as { builtinProviderId: string }).builtinProviderId,
-  );
+  const provider = allProviders.find((p) => p.id === providerId);
   if (!provider) {
-    console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
+    console.warn(`Provider with ID "${providerId}" not found.`);
     return [];
   }
 
@@ -177,9 +167,9 @@
       })
       .from(languageModelsSchema)
       .where(
-        "customProviderId" in obj
-          ? eq(languageModelsSchema.customProviderId, obj.customProviderId)
-          : eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
+        isCustomProvider({ providerId })
+          ? eq(languageModelsSchema.customProviderId, providerId)
+          : eq(languageModelsSchema.builtinProviderId, providerId),
       );
 
     customModels = customModelsDb.map((model) => ({
@@ -192,7 +182,7 @@
     }));
   } catch (error) {
     console.error(
-      `Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
+      `Error fetching custom models for provider "${providerId}" from DB:`,
       error,
     );
     // Continue with empty custom models array
@@ -200,7 +190,6 @@
   }
 
   // If it's a cloud provider, also get the hardcoded models
   let hardcodedModels: LanguageModel[] = [];
-  const providerId = provider.id;
   if (provider.type === "cloud") {
     if (providerId in MODEL_OPTIONS) {
       const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
@@ -245,11 +234,7 @@ export async function getLanguageModelsByProviders(): Promise<
   const modelPromises = providers
     .filter((p) => p.type !== "local")
     .map(async (provider) => {
-      const models = await getLanguageModels(
-        provider.type === "cloud"
-          ? { builtinProviderId: provider.id }
-          : { customProviderId: provider.id },
-      );
+      const models = await getLanguageModels({ providerId: provider.id });
       return { providerId: provider.id, models };
     });
 
@@ -264,3 +249,9 @@
 
   return record;
 }
+
+export function isCustomProvider({ providerId }: { providerId: string }) {
+  return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
+}
+
+export const CUSTOM_PROVIDER_PREFIX = "custom::";
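Reviewer note: a minimal sketch of the unified lookup added above, assuming only the exports introduced in this diff. The listProviderModels helper and the "custom::my-proxy" ID are hypothetical, and the import path assumes a caller under src/ipc/handlers/.

```ts
// Hypothetical caller -- not part of this diff.
import {
  CUSTOM_PROVIDER_PREFIX,
  getLanguageModels,
  isCustomProvider,
} from "../shared/language_model_helpers";

async function listProviderModels(providerId: string) {
  // One entry point for builtin and custom IDs alike: the "custom::"
  // prefix alone decides which DB column getLanguageModels() queries.
  const models = await getLanguageModels({ providerId });
  return {
    custom: isCustomProvider({ providerId }),
    names: models.map((m) => m.apiName),
  };
}

// The same constant the create-provider handler now prepends on insert:
void listProviderModels(CUSTOM_PROVIDER_PREFIX + "my-proxy"); // "custom::my-proxy"
```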
diff --git a/src/ipc/utils/get_model_client.ts b/src/ipc/utils/get_model_client.ts
index a681d0a..3c89efa 100644
--- a/src/ipc/utils/get_model_client.ts
+++ b/src/ipc/utils/get_model_client.ts
@@ -5,7 +5,7 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
 import { createOllama } from "ollama-ai-provider";
 import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
 import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
-import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models";
+import { AUTO_MODELS } from "../../constants/models";
 import { getEnvVar } from "./read_env";
 import log from "electron-log";
 import { getLanguageModelProviders } from "../shared/language_model_helpers";
@@ -140,19 +140,3 @@ export async function getModelClient(
       }
     }
   }
 }
-
-// Most models support at least 8000 output tokens so we use it as a default value.
-const DEFAULT_MAX_TOKENS = 8_000;
-
-export function getMaxTokens(model: LargeLanguageModel) {
-  if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
-    logger.warn(
-      `Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`,
-    );
-    return DEFAULT_MAX_TOKENS;
-  }
-  const modelOption = MODEL_OPTIONS[
-    model.provider as keyof typeof MODEL_OPTIONS
-  ].find((m) => m.name === model.name);
-  return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
-}
diff --git a/src/ipc/utils/token_utils.ts b/src/ipc/utils/token_utils.ts
index b5b8685..9b34470 100644
--- a/src/ipc/utils/token_utils.ts
+++ b/src/ipc/utils/token_utils.ts
@@ -1,9 +1,8 @@
+import { LargeLanguageModel } from "@/lib/schemas";
 import { readSettings } from "../../main/settings";
 import { Message } from "../ipc_types";
-import { MODEL_OPTIONS } from "../../constants/models";
-import log from "electron-log";
-
-const logger = log.scope("token_utils");
+
+import { getLanguageModels } from "../shared/language_model_helpers";
 
 // Estimate tokens (4 characters per token)
 export const estimateTokens = (text: string): number => {
@@ -19,17 +18,26 @@ export const estimateMessagesTokens = (messages: Message[]): number => {
 const DEFAULT_CONTEXT_WINDOW = 128_000;
 
-export function getContextWindow() {
+export async function getContextWindow() {
   const settings = readSettings();
   const model = settings.selectedModel;
-  if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
-    logger.warn(
-      `Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`,
-    );
-    return DEFAULT_CONTEXT_WINDOW;
-  }
-  const modelOption = MODEL_OPTIONS[
-    model.provider as keyof typeof MODEL_OPTIONS
-  ].find((m) => m.name === model.name);
+
+  const models = await getLanguageModels({
+    providerId: model.provider,
+  });
+
+  const modelOption = models.find((m) => m.apiName === model.name);
   return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
 }
+
+// Most models support at least 8000 output tokens so we use it as a default value.
+const DEFAULT_MAX_TOKENS = 8_000;
+
+export async function getMaxTokens(model: LargeLanguageModel) {
+  const models = await getLanguageModels({
+    providerId: model.provider,
+  });
+
+  const modelOption = models.find((m) => m.apiName === model.name);
+  return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
+}
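Reviewer note: getMaxTokens and getContextWindow are now async because they resolve limits through getLanguageModels, which may query the DB for custom models. A sketch of the call-site impact, mirroring the streamText change in the first hunk; the resolveLimits wrapper is hypothetical, and the import paths assume a caller under src/ipc/handlers/.

```ts
// Hypothetical caller -- not part of this diff.
import type { LargeLanguageModel } from "@/lib/schemas";
import { getContextWindow, getMaxTokens } from "../utils/token_utils";

async function resolveLimits(model: LargeLanguageModel) {
  // Previously synchronous; every call site now needs an await.
  const maxTokens = await getMaxTokens(model); // falls back to 8_000
  // Still reads settings.selectedModel internally rather than taking a model.
  const contextWindow = await getContextWindow(); // falls back to 128_000
  return { maxTokens, contextWindow };
}
```

One behavioral difference worth flagging: the deleted implementations warned via logger.warn when the provider was missing from MODEL_OPTIONS, while the new ones fall back to the defaults silently (only the console.warn inside getLanguageModels fires, and only for an unknown provider).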