Simplify provider logic and migrate getContextWindow (#142)
This commit is contained in:
@@ -15,7 +15,7 @@ import { extractCodebase } from "../../utils/codebase";
|
|||||||
import { processFullResponseActions } from "../processors/response_processor";
|
import { processFullResponseActions } from "../processors/response_processor";
|
||||||
import { streamTestResponse } from "./testing_chat_handlers";
|
import { streamTestResponse } from "./testing_chat_handlers";
|
||||||
import { getTestResponse } from "./testing_chat_handlers";
|
import { getTestResponse } from "./testing_chat_handlers";
|
||||||
import { getMaxTokens, getModelClient } from "../utils/get_model_client";
|
import { getModelClient } from "../utils/get_model_client";
|
||||||
import log from "electron-log";
|
import log from "electron-log";
|
||||||
import {
|
import {
|
||||||
getSupabaseContext,
|
getSupabaseContext,
|
||||||
@@ -27,6 +27,7 @@ import * as path from "path";
|
|||||||
import * as os from "os";
|
import * as os from "os";
|
||||||
import * as crypto from "crypto";
|
import * as crypto from "crypto";
|
||||||
import { readFile, writeFile, unlink } from "fs/promises";
|
import { readFile, writeFile, unlink } from "fs/promises";
|
||||||
|
import { getMaxTokens } from "../utils/token_utils";
|
||||||
|
|
||||||
const logger = log.scope("chat_stream_handlers");
|
const logger = log.scope("chat_stream_handlers");
|
||||||
|
|
||||||
@@ -332,7 +333,7 @@ This conversation includes one or more image attachments. When the user uploads
|
|||||||
|
|
||||||
// When calling streamText, the messages need to be properly formatted for mixed content
|
// When calling streamText, the messages need to be properly formatted for mixed content
|
||||||
const { textStream } = streamText({
|
const { textStream } = streamText({
|
||||||
maxTokens: getMaxTokens(settings.selectedModel),
|
maxTokens: await getMaxTokens(settings.selectedModel),
|
||||||
temperature: 0,
|
temperature: 0,
|
||||||
model: modelClient,
|
model: modelClient,
|
||||||
system: systemPrompt,
|
system: systemPrompt,
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import type {
|
|||||||
import { createLoggedHandler } from "./safe_handle";
|
import { createLoggedHandler } from "./safe_handle";
|
||||||
import log from "electron-log";
|
import log from "electron-log";
|
||||||
import {
|
import {
|
||||||
|
CUSTOM_PROVIDER_PREFIX,
|
||||||
getLanguageModelProviders,
|
getLanguageModelProviders,
|
||||||
getLanguageModels,
|
getLanguageModels,
|
||||||
getLanguageModelsByProviders,
|
getLanguageModelsByProviders,
|
||||||
@@ -66,7 +67,7 @@ export function registerLanguageModelHandlers() {
|
|||||||
// Insert the new provider
|
// Insert the new provider
|
||||||
await db.insert(languageModelProvidersSchema).values({
|
await db.insert(languageModelProvidersSchema).values({
|
||||||
// Make sure we will never have accidental collisions with builtin providers
|
// Make sure we will never have accidental collisions with builtin providers
|
||||||
id: "custom::" + id,
|
id: CUSTOM_PROVIDER_PREFIX + id,
|
||||||
name,
|
name,
|
||||||
api_base_url: apiBaseUrl,
|
api_base_url: apiBaseUrl,
|
||||||
env_var_name: envVarName || null,
|
env_var_name: envVarName || null,
|
||||||
@@ -297,11 +298,7 @@ export function registerLanguageModelHandlers() {
|
|||||||
if (provider.type === "local") {
|
if (provider.type === "local") {
|
||||||
throw new Error("Local models cannot be fetched");
|
throw new Error("Local models cannot be fetched");
|
||||||
}
|
}
|
||||||
return getLanguageModels(
|
return getLanguageModels({ providerId: params.providerId });
|
||||||
provider.type === "cloud"
|
|
||||||
? { builtinProviderId: params.providerId }
|
|
||||||
: { customProviderId: params.providerId },
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
@@ -277,7 +277,7 @@ const getProposalHandler = async (
|
|||||||
);
|
);
|
||||||
|
|
||||||
const totalTokens = messagesTokenCount + codebaseTokenCount;
|
const totalTokens = messagesTokenCount + codebaseTokenCount;
|
||||||
const contextWindow = Math.min(getContextWindow(), 100_000);
|
const contextWindow = Math.min(await getContextWindow(), 100_000);
|
||||||
logger.log(
|
logger.log(
|
||||||
`Token usage: ${totalTokens}/${contextWindow} (${
|
`Token usage: ${totalTokens}/${contextWindow} (${
|
||||||
(totalTokens / contextWindow) * 100
|
(totalTokens / contextWindow) * 100
|
||||||
|
|||||||
@@ -88,7 +88,7 @@ export function registerTokenCountHandlers() {
|
|||||||
codebaseTokens,
|
codebaseTokens,
|
||||||
inputTokens,
|
inputTokens,
|
||||||
systemPromptTokens,
|
systemPromptTokens,
|
||||||
contextWindow: getContextWindow(),
|
contextWindow: await getContextWindow(),
|
||||||
};
|
};
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -139,26 +139,16 @@ export async function getLanguageModelProviders(): Promise<
|
|||||||
* @param obj An object containing the providerId.
|
* @param obj An object containing the providerId.
|
||||||
* @returns A promise that resolves to an array of LanguageModel objects.
|
* @returns A promise that resolves to an array of LanguageModel objects.
|
||||||
*/
|
*/
|
||||||
export async function getLanguageModels(
|
export async function getLanguageModels({
|
||||||
obj:
|
providerId,
|
||||||
| {
|
}: {
|
||||||
customProviderId: string;
|
providerId: string;
|
||||||
// builtinProviderId?: undefined;
|
}): Promise<LanguageModel[]> {
|
||||||
}
|
|
||||||
| {
|
|
||||||
builtinProviderId: string;
|
|
||||||
// customProviderId?: undefined;
|
|
||||||
},
|
|
||||||
): Promise<LanguageModel[]> {
|
|
||||||
const allProviders = await getLanguageModelProviders();
|
const allProviders = await getLanguageModelProviders();
|
||||||
const provider = allProviders.find(
|
const provider = allProviders.find((p) => p.id === providerId);
|
||||||
(p) =>
|
|
||||||
p.id === (obj as { customProviderId: string }).customProviderId ||
|
|
||||||
p.id === (obj as { builtinProviderId: string }).builtinProviderId,
|
|
||||||
);
|
|
||||||
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
console.warn(`Provider with ID "${JSON.stringify(obj)}" not found.`);
|
console.warn(`Provider with ID "${providerId}" not found.`);
|
||||||
return [];
|
return [];
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -177,9 +167,9 @@ export async function getLanguageModels(
|
|||||||
})
|
})
|
||||||
.from(languageModelsSchema)
|
.from(languageModelsSchema)
|
||||||
.where(
|
.where(
|
||||||
"customProviderId" in obj
|
isCustomProvider({ providerId })
|
||||||
? eq(languageModelsSchema.customProviderId, obj.customProviderId)
|
? eq(languageModelsSchema.customProviderId, providerId)
|
||||||
: eq(languageModelsSchema.builtinProviderId, obj.builtinProviderId),
|
: eq(languageModelsSchema.builtinProviderId, providerId),
|
||||||
);
|
);
|
||||||
|
|
||||||
customModels = customModelsDb.map((model) => ({
|
customModels = customModelsDb.map((model) => ({
|
||||||
@@ -192,7 +182,7 @@ export async function getLanguageModels(
|
|||||||
}));
|
}));
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error(
|
console.error(
|
||||||
`Error fetching custom models for provider "${JSON.stringify(obj)}" from DB:`,
|
`Error fetching custom models for provider "${providerId}" from DB:`,
|
||||||
error,
|
error,
|
||||||
);
|
);
|
||||||
// Continue with empty custom models array
|
// Continue with empty custom models array
|
||||||
@@ -200,7 +190,6 @@ export async function getLanguageModels(
|
|||||||
|
|
||||||
// If it's a cloud provider, also get the hardcoded models
|
// If it's a cloud provider, also get the hardcoded models
|
||||||
let hardcodedModels: LanguageModel[] = [];
|
let hardcodedModels: LanguageModel[] = [];
|
||||||
const providerId = provider.id;
|
|
||||||
if (provider.type === "cloud") {
|
if (provider.type === "cloud") {
|
||||||
if (providerId in MODEL_OPTIONS) {
|
if (providerId in MODEL_OPTIONS) {
|
||||||
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
|
const models = MODEL_OPTIONS[providerId as RegularModelProvider] || [];
|
||||||
@@ -245,11 +234,7 @@ export async function getLanguageModelsByProviders(): Promise<
|
|||||||
const modelPromises = providers
|
const modelPromises = providers
|
||||||
.filter((p) => p.type !== "local")
|
.filter((p) => p.type !== "local")
|
||||||
.map(async (provider) => {
|
.map(async (provider) => {
|
||||||
const models = await getLanguageModels(
|
const models = await getLanguageModels({ providerId: provider.id });
|
||||||
provider.type === "cloud"
|
|
||||||
? { builtinProviderId: provider.id }
|
|
||||||
: { customProviderId: provider.id },
|
|
||||||
);
|
|
||||||
return { providerId: provider.id, models };
|
return { providerId: provider.id, models };
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -264,3 +249,9 @@ export async function getLanguageModelsByProviders(): Promise<
|
|||||||
|
|
||||||
return record;
|
return record;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function isCustomProvider({ providerId }: { providerId: string }) {
|
||||||
|
return providerId.startsWith(CUSTOM_PROVIDER_PREFIX);
|
||||||
|
}
|
||||||
|
|
||||||
|
export const CUSTOM_PROVIDER_PREFIX = "custom::";
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import { createOpenRouter } from "@openrouter/ai-sdk-provider";
|
|||||||
import { createOllama } from "ollama-ai-provider";
|
import { createOllama } from "ollama-ai-provider";
|
||||||
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
import { createOpenAICompatible } from "@ai-sdk/openai-compatible";
|
||||||
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
|
import type { LargeLanguageModel, UserSettings } from "../../lib/schemas";
|
||||||
import { AUTO_MODELS, MODEL_OPTIONS } from "../../constants/models";
|
import { AUTO_MODELS } from "../../constants/models";
|
||||||
import { getEnvVar } from "./read_env";
|
import { getEnvVar } from "./read_env";
|
||||||
import log from "electron-log";
|
import log from "electron-log";
|
||||||
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
import { getLanguageModelProviders } from "../shared/language_model_helpers";
|
||||||
@@ -140,19 +140,3 @@ export async function getModelClient(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Most models support at least 8000 output tokens so we use it as a default value.
|
|
||||||
const DEFAULT_MAX_TOKENS = 8_000;
|
|
||||||
|
|
||||||
export function getMaxTokens(model: LargeLanguageModel) {
|
|
||||||
if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
|
|
||||||
logger.warn(
|
|
||||||
`Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`,
|
|
||||||
);
|
|
||||||
return DEFAULT_MAX_TOKENS;
|
|
||||||
}
|
|
||||||
const modelOption = MODEL_OPTIONS[
|
|
||||||
model.provider as keyof typeof MODEL_OPTIONS
|
|
||||||
].find((m) => m.name === model.name);
|
|
||||||
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,9 +1,8 @@
|
|||||||
|
import { LargeLanguageModel } from "@/lib/schemas";
|
||||||
import { readSettings } from "../../main/settings";
|
import { readSettings } from "../../main/settings";
|
||||||
import { Message } from "../ipc_types";
|
import { Message } from "../ipc_types";
|
||||||
import { MODEL_OPTIONS } from "../../constants/models";
|
|
||||||
import log from "electron-log";
|
|
||||||
|
|
||||||
const logger = log.scope("token_utils");
|
import { getLanguageModels } from "../shared/language_model_helpers";
|
||||||
|
|
||||||
// Estimate tokens (4 characters per token)
|
// Estimate tokens (4 characters per token)
|
||||||
export const estimateTokens = (text: string): number => {
|
export const estimateTokens = (text: string): number => {
|
||||||
@@ -19,17 +18,26 @@ export const estimateMessagesTokens = (messages: Message[]): number => {
|
|||||||
|
|
||||||
const DEFAULT_CONTEXT_WINDOW = 128_000;
|
const DEFAULT_CONTEXT_WINDOW = 128_000;
|
||||||
|
|
||||||
export function getContextWindow() {
|
export async function getContextWindow() {
|
||||||
const settings = readSettings();
|
const settings = readSettings();
|
||||||
const model = settings.selectedModel;
|
const model = settings.selectedModel;
|
||||||
if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
|
|
||||||
logger.warn(
|
const models = await getLanguageModels({
|
||||||
`Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`,
|
providerId: model.provider,
|
||||||
);
|
});
|
||||||
return DEFAULT_CONTEXT_WINDOW;
|
|
||||||
}
|
const modelOption = models.find((m) => m.apiName === model.name);
|
||||||
const modelOption = MODEL_OPTIONS[
|
|
||||||
model.provider as keyof typeof MODEL_OPTIONS
|
|
||||||
].find((m) => m.name === model.name);
|
|
||||||
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
|
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Most models support at least 8000 output tokens so we use it as a default value.
|
||||||
|
const DEFAULT_MAX_TOKENS = 8_000;
|
||||||
|
|
||||||
|
export async function getMaxTokens(model: LargeLanguageModel) {
|
||||||
|
const models = await getLanguageModels({
|
||||||
|
providerId: model.provider,
|
||||||
|
});
|
||||||
|
|
||||||
|
const modelOption = models.find((m) => m.apiName === model.name);
|
||||||
|
return modelOption?.maxOutputTokens || DEFAULT_MAX_TOKENS;
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user