Simplify provider logic and migrate getContextWindow (#142)

This commit is contained in:
Will Chen
2025-05-12 22:18:49 -07:00
committed by GitHub
parent 11ba46db38
commit 877c8f7f4f
7 changed files with 48 additions and 67 deletions

View File

@@ -15,7 +15,7 @@ import { extractCodebase } from "../../utils/codebase";
import { processFullResponseActions } from "../processors/response_processor"; import { processFullResponseActions } from "../processors/response_processor";
import { streamTestResponse } from "./testing_chat_handlers"; import { streamTestResponse } from "./testing_chat_handlers";
import { getTestResponse } from "./testing_chat_handlers"; import { getTestResponse } from "./testing_chat_handlers";
import { getMaxTokens, getModelClient } from "../utils/get_model_client"; import { getModelClient } from "../utils/get_model_client";
import log from "electron-log"; import log from "electron-log";
import { import {
getSupabaseContext, getSupabaseContext,
@@ -27,6 +27,7 @@ import * as path from "path";
import * as os from "os"; import * as os from "os";
import * as crypto from "crypto"; import * as crypto from "crypto";
import { readFile, writeFile, unlink } from "fs/promises"; import { readFile, writeFile, unlink } from "fs/promises";
import { getMaxTokens } from "../utils/token_utils";
const logger = log.scope("chat_stream_handlers"); const logger = log.scope("chat_stream_handlers");
@@ -332,7 +333,7 @@ This conversation includes one or more image attachments. When the user uploads
// When calling streamText, the messages need to be properly formatted for mixed content // When calling streamText, the messages need to be properly formatted for mixed content
const { textStream } = streamText({ const { textStream } = streamText({
maxTokens: getMaxTokens(settings.selectedModel), maxTokens: await getMaxTokens(settings.selectedModel),
temperature: 0, temperature: 0,
model: modelClient, model: modelClient,
system: systemPrompt, system: systemPrompt,

View File

@@ -7,6 +7,7 @@ import type {
import { createLoggedHandler } from "./safe_handle"; import { createLoggedHandler } from "./safe_handle";
import log from "electron-log"; import log from "electron-log";
import { import {
CUSTOM_PROVIDER_PREFIX,
getLanguageModelProviders, getLanguageModelProviders,
getLanguageModels, getLanguageModels,
getLanguageModelsByProviders, getLanguageModelsByProviders,
@@ -66,7 +67,7 @@ export function registerLanguageModelHandlers() {
// Insert the new provider // Insert the new provider
await db.insert(languageModelProvidersSchema).values({ await db.insert(languageModelProvidersSchema).values({
// Make sure we will never have accidental collisions with builtin providers // Make sure we will never have accidental collisions with builtin providers
id: "custom::" + id, id: CUSTOM_PROVIDER_PREFIX + id,
name, name,
api_base_url: apiBaseUrl, api_base_url: apiBaseUrl,
env_var_name: envVarName || null, env_var_name: envVarName || null,
@@ -297,11 +298,7 @@ export function registerLanguageModelHandlers() {
if (provider.type === "local") { if (provider.type === "local") {
throw new Error("Local models cannot be fetched"); throw new Error("Local models cannot be fetched");
} }
return getLanguageModels( return getLanguageModels({ providerId: params.providerId });
provider.type === "cloud"
? { builtinProviderId: params.providerId }
: { customProviderId: params.providerId },
);
}, },
); );

View File

@@ -277,7 +277,7 @@ const getProposalHandler = async (
); );
const totalTokens = messagesTokenCount + codebaseTokenCount; const totalTokens = messagesTokenCount + codebaseTokenCount;
const contextWindow = Math.min(getContextWindow(), 100_000); const contextWindow = Math.min(await getContextWindow(), 100_000);
logger.log( logger.log(
`Token usage: ${totalTokens}/${contextWindow} (${ `Token usage: ${totalTokens}/${contextWindow} (${
(totalTokens / contextWindow) * 100 (totalTokens / contextWindow) * 100

View File

@@ -88,7 +88,7 @@ export function registerTokenCountHandlers() {
codebaseTokens, codebaseTokens,
inputTokens, inputTokens,
systemPromptTokens, systemPromptTokens,
contextWindow: getContextWindow(), contextWindow: await getContextWindow(),
}; };
}, },
); );

View File

@@ -139,26 +139,16 @@ export async function getLanguageModelProviders(): Promise<
* @param obj An object containing the providerId. * @param obj An object containing the providerId.
* @returns A promise that resolves to an array of LanguageModel objects. * @returns A promise that resolves to an array of LanguageModel objects.
*/ */
export async function getLanguageModels( export async function getLanguageModels({
obj: providerId,
| { }: {
customProviderId: string; providerId: string;
// builtinProviderId?: undefined; }): Promise<LanguageModel[]> {
}
| {
builtinProviderId: string;
// customProviderId?: undefined;
},
): Promise<LanguageModel[]> {
const allProviders = await getLanguageModelProviders(); const allProviders = await getLanguageModelProviders();
const provider = allProviders.find( const provider = allProviders.find((p) => p.id === providerId);
(p) =>
p.id === (obj as { customProviderId: string }).customProviderId ||
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
);
if (!provider) { if (!provider) {
console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`); console.warn(`Provider with ID "${providerId}" not found.`);
return []; return [];
} }
@@ -177,9 +167,9 @@ export async function getLanguageModels(
}) })
.from(languageModelsSchema) .from(languageModelsSchema)
.where( .where(
"customProviderId" in obj isCustomProvider({ providerId })
? eq(languageModelsSchema.customProviderId, obj.customProviderId) ? eq(languageModelsSchema.customProviderId, providerId)
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId), : eq(languageModelsSchema.builtinProviderId, providerId),
); );
customModels = customModelsDb.map((model) => ({ customModels = customModelsDb.map((model) => ({
@@ -192,7 +182,7 @@ export async function getLanguageModels(
})); }));
} catch (error) { } catch (error) {
console.error( console.error(
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`, `Error fetching custom models for provider "${providerId}" from DB:`,
error, error,
); );
// Continue with empty custom models array // Continue with empty custom models array
@@ -200,7 +190,6 @@ export async function getLanguageModels(
// If it's a cloud provider, also get the hardcoded models // If it's a cloud provider, also get the hardcoded models
let hardcodedModels: LanguageModel[] = []; let hardcodedModels: LanguageModel[] = [];
const providerId = provider.id;
if (provider.type === "cloud") { if (provider.type === "cloud") {
if (providerId in MODEL_OPTIONS) { if (providerId in MODEL_OPTIONS) {
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || []; const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
@@ -245,11 +234,7 @@ export async function getLanguageModelsByProviders(): Promise<
const modelPromises = providers const modelPromises = providers
.filter((p) => p.type !== "local") .filter((p) => p.type !== "local")
.map(async (provider) => { .map(async (provider) => {
const models = await getLanguageModels( const models = await getLanguageModels({ providerId: provider.id });
provider.type === "cloud"
? { builtinProviderId: provider.id }
: { customProviderId: provider.id },
);
return { providerId: provider.id, models }; return { providerId: provider.id, models };
}); });
@@ -264,3 +249,9 @@ export async function getLanguageModelsByProviders(): Promise<
return record; return record;
} }
export function isCustomProvider({ providerId }: { providerId: string }) {
return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
}
export const CUSTOM_PROVIDER_PREFIX = "custom::";

View File

@@ -5,7 +5,7 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { createOllama } from "ollama-ai-provider"; import { createOllama } from "ollama-ai-provider";
import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas"; import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models"; import { AUTO_MODELS } from "../../constants/models";
import { getEnvVar } from "./read_env"; import { getEnvVar } from "./read_env";
import log from "electron-log"; import log from "electron-log";
import { getLanguageModelProviders } from "../shared/language_model_helpers"; import { getLanguageModelProviders } from "../shared/language_model_helpers";
@@ -140,19 +140,3 @@ export async function getModelClient(
} }
} }
} }
// Most models support at least 8000 output tokens so we use it as a default value.
const DEFAULT_MAX_TOKENS = 8_000;
export function getMaxTokens(model: LargeLanguageModel) {
if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
logger.warn(
`Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`,
);
return DEFAULT_MAX_TOKENS;
}
const modelOption = MODEL_OPTIONS[
model.provider as keyof typeof MODEL_OPTIONS
].find((m) => m.name === model.name);
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
}

View File

@@ -1,9 +1,8 @@
import { LargeLanguageModel } from "@/lib/schemas";
import { readSettings } from "../../main/settings"; import { readSettings } from "../../main/settings";
import { Message } from "../ipc_types"; import { Message } from "../ipc_types";
import { MODEL_OPTIONS } from "../../constants/models";
import log from "electron-log";
const logger = log.scope("token_utils"); import { getLanguageModels } from "../shared/language_model_helpers";
// Estimate tokens (4 characters per token) // Estimate tokens (4 characters per token)
export const estimateTokens = (text: string): number => { export const estimateTokens = (text: string): number => {
@@ -19,17 +18,26 @@ export const estimateMessagesTokens = (messages: Message[]): number => {
const DEFAULT_CONTEXT_WINDOW = 128_000; const DEFAULT_CONTEXT_WINDOW = 128_000;
export function getContextWindow() { export async function getContextWindow() {
const settings = readSettings(); const settings = readSettings();
const model = settings.selectedModel; const model = settings.selectedModel;
if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
logger.warn( const models = await getLanguageModels({
`Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`, providerId: model.provider,
); });
return DEFAULT_CONTEXT_WINDOW;
} const modelOption = models.find((m) => m.apiName === model.name);
const modelOption = MODEL_OPTIONS[
model.provider as keyof typeof MODEL_OPTIONS
].find((m) => m.name === model.name);
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW; return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
} }
// Most models support at least 8000 output tokens so we use it as a default value.
const DEFAULT_MAX_TOKENS = 8_000;
export async function getMaxTokens(model: LargeLanguageModel) {
const models = await getLanguageModels({
providerId: model.provider,
});
const modelOption = models.find((m) => m.apiName === model.name);
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
}