Show token bar at bottom of chat input (#33)
This commit is contained in:
127
src/ipc/handlers/token_count_handlers.ts
Normal file
127
src/ipc/handlers/token_count_handlers.ts
Normal file
@@ -0,0 +1,127 @@
|
||||
import { ipcMain } from "electron";
|
||||
import { db } from "../../db";
|
||||
import { chats, messages } from "../../db/schema";
|
||||
import { eq } from "drizzle-orm";
|
||||
import { SYSTEM_PROMPT } from "../../prompts/system_prompt";
|
||||
import {
|
||||
SUPABASE_AVAILABLE_SYSTEM_PROMPT,
|
||||
SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT,
|
||||
} from "../../prompts/supabase_prompt";
|
||||
import { getDyadAppPath } from "../../paths/paths";
|
||||
import log from "electron-log";
|
||||
import { extractCodebase } from "../../utils/codebase";
|
||||
import { getSupabaseContext } from "../../supabase_admin/supabase_context";
|
||||
import { readSettings } from "../../main/settings";
|
||||
import { MODEL_OPTIONS } from "../../constants/models";
|
||||
import { TokenCountParams } from "../ipc_types";
|
||||
import { TokenCountResult } from "../ipc_types";
|
||||
|
||||
const logger = log.scope("token_count_handlers");
|
||||
|
||||
// Estimate tokens (4 characters per token)
|
||||
const estimateTokens = (text: string): number => {
|
||||
return Math.ceil(text.length / 4);
|
||||
};
|
||||
|
||||
export function registerTokenCountHandlers() {
|
||||
ipcMain.handle(
|
||||
"chat:count-tokens",
|
||||
async (event, req: TokenCountParams): Promise<TokenCountResult> => {
|
||||
try {
|
||||
// Get the chat with messages
|
||||
const chat = await db.query.chats.findFirst({
|
||||
where: eq(chats.id, req.chatId),
|
||||
with: {
|
||||
messages: {
|
||||
orderBy: (messages, { asc }) => [asc(messages.createdAt)],
|
||||
},
|
||||
app: true,
|
||||
},
|
||||
});
|
||||
|
||||
if (!chat) {
|
||||
throw new Error(`Chat not found: ${req.chatId}`);
|
||||
}
|
||||
|
||||
// Prepare message history for token counting
|
||||
const messageHistory = chat.messages
|
||||
.map((message) => message.content)
|
||||
.join("");
|
||||
const messageHistoryTokens = estimateTokens(messageHistory);
|
||||
|
||||
// Count input tokens
|
||||
const inputTokens = estimateTokens(req.input);
|
||||
|
||||
// Count system prompt tokens
|
||||
let systemPrompt = SYSTEM_PROMPT;
|
||||
let supabaseContext = "";
|
||||
|
||||
if (chat.app?.supabaseProjectId) {
|
||||
systemPrompt += "\n\n" + SUPABASE_AVAILABLE_SYSTEM_PROMPT;
|
||||
supabaseContext = await getSupabaseContext({
|
||||
supabaseProjectId: chat.app.supabaseProjectId,
|
||||
});
|
||||
} else {
|
||||
systemPrompt += "\n\n" + SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT;
|
||||
}
|
||||
|
||||
const systemPromptTokens = estimateTokens(
|
||||
systemPrompt + supabaseContext
|
||||
);
|
||||
|
||||
// Extract codebase information if app is associated with the chat
|
||||
let codebaseInfo = "";
|
||||
let codebaseTokens = 0;
|
||||
|
||||
if (chat.app) {
|
||||
const appPath = getDyadAppPath(chat.app.path);
|
||||
try {
|
||||
codebaseInfo = await extractCodebase(appPath);
|
||||
codebaseTokens = estimateTokens(codebaseInfo);
|
||||
logger.log(
|
||||
`Extracted codebase information from ${appPath}, tokens: ${codebaseTokens}`
|
||||
);
|
||||
} catch (error) {
|
||||
logger.error("Error extracting codebase:", error);
|
||||
}
|
||||
}
|
||||
|
||||
// Calculate total tokens
|
||||
const totalTokens =
|
||||
messageHistoryTokens +
|
||||
inputTokens +
|
||||
systemPromptTokens +
|
||||
codebaseTokens;
|
||||
|
||||
return {
|
||||
totalTokens,
|
||||
messageHistoryTokens,
|
||||
codebaseTokens,
|
||||
inputTokens,
|
||||
systemPromptTokens,
|
||||
contextWindow: getContextWindow(),
|
||||
};
|
||||
} catch (error) {
|
||||
logger.error("Error counting tokens:", error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
const DEFAULT_CONTEXT_WINDOW = 128_000;
|
||||
|
||||
function getContextWindow() {
|
||||
const settings = readSettings();
|
||||
const model = settings.selectedModel;
|
||||
if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
|
||||
logger.warn(
|
||||
`Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`
|
||||
);
|
||||
return DEFAULT_CONTEXT_WINDOW;
|
||||
}
|
||||
const modelOption = MODEL_OPTIONS[
|
||||
model.provider as keyof typeof MODEL_OPTIONS
|
||||
].find((m) => m.name === model.name);
|
||||
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
|
||||
}
|
||||
@@ -19,6 +19,8 @@ import type {
|
||||
SystemDebugInfo,
|
||||
LocalModel,
|
||||
LocalModelListResponse,
|
||||
TokenCountParams,
|
||||
TokenCountResult,
|
||||
} from "./ipc_types";
|
||||
import type { CodeProposal, ProposalResult } from "@/lib/schemas";
|
||||
import { showError } from "@/lib/toast";
|
||||
@@ -747,4 +749,17 @@ export class IpcClient {
|
||||
this.ipcRenderer.removeListener("deep-link-received", listener);
|
||||
};
|
||||
}
|
||||
|
||||
// Count tokens for a chat and input
|
||||
public async countTokens(
|
||||
params: TokenCountParams
|
||||
): Promise<TokenCountResult> {
|
||||
try {
|
||||
const result = await this.ipcRenderer.invoke("chat:count-tokens", params);
|
||||
return result as TokenCountResult;
|
||||
} catch (error) {
|
||||
showError(error);
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import { registerProposalHandlers } from "./handlers/proposal_handlers";
|
||||
import { registerDebugHandlers } from "./handlers/debug_handlers";
|
||||
import { registerSupabaseHandlers } from "./handlers/supabase_handlers";
|
||||
import { registerLocalModelHandlers } from "./handlers/local_model_handlers";
|
||||
import { registerTokenCountHandlers } from "./handlers/token_count_handlers";
|
||||
|
||||
export function registerIpcHandlers() {
|
||||
// Register all IPC handlers by category
|
||||
@@ -25,4 +26,5 @@ export function registerIpcHandlers() {
|
||||
registerDebugHandlers();
|
||||
registerSupabaseHandlers();
|
||||
registerLocalModelHandlers();
|
||||
registerTokenCountHandlers();
|
||||
}
|
||||
|
||||
@@ -101,3 +101,16 @@ export type LocalModelListResponse = {
|
||||
models: LocalModel[];
|
||||
error: string | null;
|
||||
};
|
||||
|
||||
export interface TokenCountParams {
|
||||
chatId: number;
|
||||
input: string;
|
||||
}
|
||||
export interface TokenCountResult {
|
||||
totalTokens: number;
|
||||
messageHistoryTokens: number;
|
||||
codebaseTokens: number;
|
||||
inputTokens: number;
|
||||
systemPromptTokens: number;
|
||||
contextWindow: number;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user