Show token bar at bottom of chat input (#33)

This commit is contained in:
Will Chen
2025-04-28 14:45:54 -07:00
committed by GitHub
parent aec5882c8d
commit 0d441b15ca
10 changed files with 369 additions and 8 deletions

View File

@@ -0,0 +1,127 @@
import { ipcMain } from "electron";
import { db } from "../../db";
import { chats, messages } from "../../db/schema";
import { eq } from "drizzle-orm";
import { SYSTEM_PROMPT } from "../../prompts/system_prompt";
import {
SUPABASE_AVAILABLE_SYSTEM_PROMPT,
SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT,
} from "../../prompts/supabase_prompt";
import { getDyadAppPath } from "../../paths/paths";
import log from "electron-log";
import { extractCodebase } from "../../utils/codebase";
import { getSupabaseContext } from "../../supabase_admin/supabase_context";
import { readSettings } from "../../main/settings";
import { MODEL_OPTIONS } from "../../constants/models";
import { TokenCountParams } from "../ipc_types";
import { TokenCountResult } from "../ipc_types";
const logger = log.scope("token_count_handlers");
// Estimate tokens (4 characters per token)
const estimateTokens = (text: string): number => {
return Math.ceil(text.length / 4);
};
export function registerTokenCountHandlers() {
ipcMain.handle(
"chat:count-tokens",
async (event, req: TokenCountParams): Promise<TokenCountResult> => {
try {
// Get the chat with messages
const chat = await db.query.chats.findFirst({
where: eq(chats.id, req.chatId),
with: {
messages: {
orderBy: (messages, { asc }) => [asc(messages.createdAt)],
},
app: true,
},
});
if (!chat) {
throw new Error(`Chat not found: ${req.chatId}`);
}
// Prepare message history for token counting
const messageHistory = chat.messages
.map((message) => message.content)
.join("");
const messageHistoryTokens = estimateTokens(messageHistory);
// Count input tokens
const inputTokens = estimateTokens(req.input);
// Count system prompt tokens
let systemPrompt = SYSTEM_PROMPT;
let supabaseContext = "";
if (chat.app?.supabaseProjectId) {
systemPrompt += "\n\n" + SUPABASE_AVAILABLE_SYSTEM_PROMPT;
supabaseContext = await getSupabaseContext({
supabaseProjectId: chat.app.supabaseProjectId,
});
} else {
systemPrompt += "\n\n" + SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT;
}
const systemPromptTokens = estimateTokens(
systemPrompt + supabaseContext
);
// Extract codebase information if app is associated with the chat
let codebaseInfo = "";
let codebaseTokens = 0;
if (chat.app) {
const appPath = getDyadAppPath(chat.app.path);
try {
codebaseInfo = await extractCodebase(appPath);
codebaseTokens = estimateTokens(codebaseInfo);
logger.log(
`Extracted codebase information from ${appPath}, tokens: ${codebaseTokens}`
);
} catch (error) {
logger.error("Error extracting codebase:", error);
}
}
// Calculate total tokens
const totalTokens =
messageHistoryTokens +
inputTokens +
systemPromptTokens +
codebaseTokens;
return {
totalTokens,
messageHistoryTokens,
codebaseTokens,
inputTokens,
systemPromptTokens,
contextWindow: getContextWindow(),
};
} catch (error) {
logger.error("Error counting tokens:", error);
throw error;
}
}
);
}
const DEFAULT_CONTEXT_WINDOW = 128_000;
function getContextWindow() {
const settings = readSettings();
const model = settings.selectedModel;
if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
logger.warn(
`Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`
);
return DEFAULT_CONTEXT_WINDOW;
}
const modelOption = MODEL_OPTIONS[
model.provider as keyof typeof MODEL_OPTIONS
].find((m) => m.name === model.name);
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
}

View File

@@ -19,6 +19,8 @@ import type {
SystemDebugInfo,
LocalModel,
LocalModelListResponse,
TokenCountParams,
TokenCountResult,
} from "./ipc_types";
import type { CodeProposal, ProposalResult } from "@/lib/schemas";
import { showError } from "@/lib/toast";
@@ -747,4 +749,17 @@ export class IpcClient {
this.ipcRenderer.removeListener("deep-link-received", listener);
};
}
// Count tokens for a chat and input
public async countTokens(
params: TokenCountParams
): Promise<TokenCountResult> {
try {
const result = await this.ipcRenderer.invoke("chat:count-tokens", params);
return result as TokenCountResult;
} catch (error) {
showError(error);
throw error;
}
}
}

View File

@@ -10,6 +10,7 @@ import { registerProposalHandlers } from "./handlers/proposal_handlers";
import { registerDebugHandlers } from "./handlers/debug_handlers";
import { registerSupabaseHandlers } from "./handlers/supabase_handlers";
import { registerLocalModelHandlers } from "./handlers/local_model_handlers";
import { registerTokenCountHandlers } from "./handlers/token_count_handlers";
export function registerIpcHandlers() {
// Register all IPC handlers by category
@@ -25,4 +26,5 @@ export function registerIpcHandlers() {
registerDebugHandlers();
registerSupabaseHandlers();
registerLocalModelHandlers();
registerTokenCountHandlers();
}

View File

@@ -101,3 +101,16 @@ export type LocalModelListResponse = {
models: LocalModel[];
error: string | null;
};
export interface TokenCountParams {
chatId: number;
input: string;
}
export interface TokenCountResult {
totalTokens: number;
messageHistoryTokens: number;
codebaseTokens: number;
inputTokens: number;
systemPromptTokens: number;
contextWindow: number;
}