Show token bar at bottom of chat input (#33)

This commit is contained in:
Will Chen
2025-04-28 14:45:54 -07:00
committed by GitHub
parent aec5882c8d
commit 0d441b15ca
10 changed files with 369 additions and 8 deletions

View File

@@ -15,6 +15,7 @@ import {
Database,
ChevronsUpDown,
ChevronsDownUp,
BarChart2,
} from "lucide-react";
import type React from "react";
import { useCallback, useEffect, useRef, useState } from "react";
@@ -22,7 +23,7 @@ import { ModelPicker } from "@/components/ModelPicker";
import { useSettings } from "@/hooks/useSettings";
import { IpcClient } from "@/ipc/ipc_client";
import { chatInputValueAtom, chatMessagesAtom } from "@/atoms/chatAtoms";
import { useAtom, useSetAtom } from "jotai";
import { atom, useAtom, useSetAtom } from "jotai";
import { useStreamChat } from "@/hooks/useStreamChat";
import { useChats } from "@/hooks/useChats";
import { selectedAppIdAtom } from "@/atoms/appAtoms";
@@ -46,6 +47,9 @@ import { useRunApp } from "@/hooks/useRunApp";
import { AutoApproveSwitch } from "../AutoApproveSwitch";
import { usePostHog } from "posthog-js/react";
import { CodeHighlight } from "./CodeHighlight";
import { TokenBar } from "./TokenBar";
const showTokenBarAtom = atom(false);
export function ChatInput({ chatId }: { chatId?: number }) {
const posthog = usePostHog();
@@ -60,6 +64,7 @@ export function ChatInput({ chatId }: { chatId?: number }) {
const [isRejecting, setIsRejecting] = useState(false); // State for rejecting
const [messages, setMessages] = useAtom<Message[]>(chatMessagesAtom);
const setIsPreviewOpen = useSetAtom(isPreviewOpenAtom);
const [showTokenBar, setShowTokenBar] = useAtom(showTokenBarAtom);
const { refreshAppIframe } = useRunApp();
@@ -274,14 +279,26 @@ export function ChatInput({ chatId }: { chatId?: number }) {
</button>
)}
</div>
<div className="px-2 pb-2">
<ModelPicker
selectedModel={settings.selectedModel}
onModelSelect={(model) =>
updateSettings({ selectedModel: model })
}
/>
<div className="pl-2 pr-1 flex items-center justify-between">
<div className="pb-2">
<ModelPicker
selectedModel={settings.selectedModel}
onModelSelect={(model) =>
updateSettings({ selectedModel: model })
}
/>
</div>
<button
onClick={() => setShowTokenBar(!showTokenBar)}
className="flex items-center px-2 py-1 text-xs text-muted-foreground hover:bg-muted rounded"
title={showTokenBar ? "Hide token usage" : "Show token usage"}
>
<BarChart2 size={14} className="mr-1" />
{showTokenBar ? "Hide tokens" : "Tokens"}
</button>
</div>
{/* TokenBar is only displayed when showTokenBar is true */}
{showTokenBar && <TokenBar chatId={chatId} />}
</div>
</div>
</>

View File

@@ -0,0 +1,130 @@
import React, { useEffect, useState } from "react";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { useCountTokens } from "@/hooks/useCountTokens";
import { MessageSquare, Code, Bot, AlignLeft } from "lucide-react";
import { chatInputValueAtom } from "@/atoms/chatAtoms";
import { useAtom } from "jotai";
import { useSettings } from "@/hooks/useSettings";
interface TokenBarProps {
chatId?: number;
}
export function TokenBar({ chatId }: TokenBarProps) {
const [inputValue] = useAtom(chatInputValueAtom);
const { countTokens, result } = useCountTokens();
const [error, setError] = useState<string | null>(null);
const { settings } = useSettings();
useEffect(() => {
if (!chatId) return;
// Mark this as used, we need to re-trigger token count
// when selected model changes.
void settings?.selectedModel;
const debounceTimer = setTimeout(() => {
countTokens(chatId, inputValue).catch((err) => {
setError("Failed to count tokens");
console.error("Token counting error:", err);
});
}, 500);
return () => clearTimeout(debounceTimer);
}, [chatId, inputValue, countTokens, settings?.selectedModel]);
if (!chatId || !result) {
return null;
}
const {
totalTokens,
messageHistoryTokens,
codebaseTokens,
systemPromptTokens,
inputTokens,
contextWindow,
} = result;
const percentUsed = Math.min((totalTokens / contextWindow) * 100, 100);
// Calculate widths for each token type
const messageHistoryPercent = (messageHistoryTokens / contextWindow) * 100;
const codebasePercent = (codebaseTokens / contextWindow) * 100;
const systemPromptPercent = (systemPromptTokens / contextWindow) * 100;
const inputPercent = (inputTokens / contextWindow) * 100;
return (
<div className="px-4 pb-2 text-xs">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div className="w-full">
<div className="flex justify-between mb-1 text-xs text-muted-foreground">
<span>Tokens: {totalTokens.toLocaleString()}</span>
<span>
{Math.round(percentUsed)}% of{" "}
{(contextWindow / 1000).toFixed(0)}K
</span>
</div>
<div className="w-full h-1.5 bg-muted rounded-full overflow-hidden flex">
{/* Message history tokens */}
<div
className="h-full bg-blue-400"
style={{ width: `${messageHistoryPercent}%` }}
/>
{/* Codebase tokens */}
<div
className="h-full bg-green-400"
style={{ width: `${codebasePercent}%` }}
/>
{/* System prompt tokens */}
<div
className="h-full bg-purple-400"
style={{ width: `${systemPromptPercent}%` }}
/>
{/* Input tokens */}
<div
className="h-full bg-yellow-400"
style={{ width: `${inputPercent}%` }}
/>
</div>
</div>
</TooltipTrigger>
<TooltipContent side="top" className="w-64 p-2">
<div className="space-y-1">
<div className="font-medium">Token Usage Breakdown</div>
<div className="grid grid-cols-[20px_1fr_auto] gap-x-2 items-center">
<MessageSquare size={12} className="text-blue-500" />
<span>Message History</span>
<span>{messageHistoryTokens.toLocaleString()}</span>
<Code size={12} className="text-green-500" />
<span>Codebase</span>
<span>{codebaseTokens.toLocaleString()}</span>
<Bot size={12} className="text-purple-500" />
<span>System Prompt</span>
<span>{systemPromptTokens.toLocaleString()}</span>
<AlignLeft size={12} className="text-yellow-500" />
<span>Current Input</span>
<span>{inputTokens.toLocaleString()}</span>
</div>
<div className="pt-1 border-t border-border">
<div className="flex justify-between font-medium">
<span>Total</span>
<span>{totalTokens.toLocaleString()}</span>
</div>
</div>
</div>
</TooltipContent>
</Tooltip>
</TooltipProvider>
{error && <div className="text-red-500 text-xs mt-1">{error}</div>}
</div>
);
}

View File

@@ -5,6 +5,7 @@ export interface ModelOption {
description: string;
tag?: string;
maxOutputTokens?: number;
contextWindow?: number;
}
type RegularModelProvider = Exclude<ModelProvider, "ollama">;
@@ -16,6 +17,7 @@ export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
displayName: "GPT 4.1",
description: "OpenAI's flagship model",
maxOutputTokens: 32_768,
contextWindow: 1_047_576,
},
// https://platform.openai.com/docs/models/gpt-4.1-mini
{
@@ -23,6 +25,7 @@ export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
displayName: "GPT 4.1 Mini",
description: "OpenAI's lightweight, but intelligent model",
maxOutputTokens: 32_768,
contextWindow: 1_047_576,
},
// https://platform.openai.com/docs/models/o3-mini
{
@@ -30,6 +33,7 @@ export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
displayName: "o3 mini",
description: "Reasoning model",
maxOutputTokens: 100_000,
contextWindow: 200_000,
},
],
// https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison-table
@@ -39,6 +43,7 @@ export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
displayName: "Claude 3.7 Sonnet",
description: "Excellent coder",
maxOutputTokens: 64_000,
contextWindow: 200_000,
},
],
google: [
@@ -49,6 +54,8 @@ export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
description: "Experimental version of Google's Gemini 2.5 Pro model",
tag: "Recommended",
maxOutputTokens: 65_536,
// Gemini context window = input token + output token
contextWindow: 1_048_576,
},
],
openrouter: [
@@ -58,6 +65,7 @@ export const MODEL_OPTIONS: Record<RegularModelProvider, ModelOption[]> = {
displayName: "DeepSeek v3 (free)",
description: "Use for free (data may be used for training)",
maxOutputTokens: 32_000,
contextWindow: 128_000,
},
],
auto: [

View File

@@ -0,0 +1,43 @@
import { useCallback } from "react";
import { atom, useAtom } from "jotai";
import { IpcClient } from "@/ipc/ipc_client";
import type { TokenCountResult } from "@/ipc/ipc_types";
// Create atoms to store the token count state
export const tokenCountResultAtom = atom<TokenCountResult | null>(null);
export const tokenCountLoadingAtom = atom<boolean>(false);
export const tokenCountErrorAtom = atom<Error | null>(null);
export function useCountTokens() {
const [result, setResult] = useAtom(tokenCountResultAtom);
const [loading, setLoading] = useAtom(tokenCountLoadingAtom);
const [error, setError] = useAtom(tokenCountErrorAtom);
const countTokens = useCallback(
async (chatId: number, input: string) => {
setLoading(true);
setError(null);
try {
const ipcClient = IpcClient.getInstance();
const tokenResult = await ipcClient.countTokens({ chatId, input });
setResult(tokenResult);
return tokenResult;
} catch (error) {
console.error("Error counting tokens:", error);
setError(error instanceof Error ? error : new Error(String(error)));
throw error;
} finally {
setLoading(false);
}
},
[setLoading, setError, setResult]
);
return {
countTokens,
result,
loading,
error,
};
}

View File

@@ -18,6 +18,7 @@ import { showError } from "@/lib/toast";
import { useProposal } from "./useProposal";
import { useSearch } from "@tanstack/react-router";
import { useRunApp } from "./useRunApp";
import { useCountTokens } from "./useCountTokens";
export function getRandomNumberId() {
return Math.floor(Math.random() * 1_000_000_000_000_000);
@@ -36,6 +37,8 @@ export function useStreamChat({
const setStreamCount = useSetAtom(chatStreamCountAtom);
const { refreshVersions } = useLoadVersions(selectedAppId);
const { refreshAppIframe } = useRunApp();
const { countTokens } = useCountTokens();
let chatId: number | undefined;
if (hasChatId) {
@@ -111,6 +114,7 @@ export function useStreamChat({
refreshChats();
refreshApp();
refreshVersions();
countTokens(chatId, "");
},
onError: (errorMessage: string) => {
console.error(`[CHAT] Stream error for ${chatId}:`, errorMessage);
@@ -121,6 +125,7 @@ export function useStreamChat({
refreshChats();
refreshApp();
refreshVersions();
countTokens(chatId, "");
},
});
} catch (error) {

View File

@@ -0,0 +1,127 @@
import { ipcMain } from "electron";
import { db } from "../../db";
import { chats, messages } from "../../db/schema";
import { eq } from "drizzle-orm";
import { SYSTEM_PROMPT } from "../../prompts/system_prompt";
import {
SUPABASE_AVAILABLE_SYSTEM_PROMPT,
SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT,
} from "../../prompts/supabase_prompt";
import { getDyadAppPath } from "../../paths/paths";
import log from "electron-log";
import { extractCodebase } from "../../utils/codebase";
import { getSupabaseContext } from "../../supabase_admin/supabase_context";
import { readSettings } from "../../main/settings";
import { MODEL_OPTIONS } from "../../constants/models";
import { TokenCountParams } from "../ipc_types";
import { TokenCountResult } from "../ipc_types";
const logger = log.scope("token_count_handlers");
// Estimate tokens (4 characters per token)
const estimateTokens = (text: string): number => {
return Math.ceil(text.length / 4);
};
export function registerTokenCountHandlers() {
ipcMain.handle(
"chat:count-tokens",
async (event, req: TokenCountParams): Promise<TokenCountResult> => {
try {
// Get the chat with messages
const chat = await db.query.chats.findFirst({
where: eq(chats.id, req.chatId),
with: {
messages: {
orderBy: (messages, { asc }) => [asc(messages.createdAt)],
},
app: true,
},
});
if (!chat) {
throw new Error(`Chat not found: ${req.chatId}`);
}
// Prepare message history for token counting
const messageHistory = chat.messages
.map((message) => message.content)
.join("");
const messageHistoryTokens = estimateTokens(messageHistory);
// Count input tokens
const inputTokens = estimateTokens(req.input);
// Count system prompt tokens
let systemPrompt = SYSTEM_PROMPT;
let supabaseContext = "";
if (chat.app?.supabaseProjectId) {
systemPrompt += "\n\n" + SUPABASE_AVAILABLE_SYSTEM_PROMPT;
supabaseContext = await getSupabaseContext({
supabaseProjectId: chat.app.supabaseProjectId,
});
} else {
systemPrompt += "\n\n" + SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT;
}
const systemPromptTokens = estimateTokens(
systemPrompt + supabaseContext
);
// Extract codebase information if app is associated with the chat
let codebaseInfo = "";
let codebaseTokens = 0;
if (chat.app) {
const appPath = getDyadAppPath(chat.app.path);
try {
codebaseInfo = await extractCodebase(appPath);
codebaseTokens = estimateTokens(codebaseInfo);
logger.log(
`Extracted codebase information from ${appPath}, tokens: ${codebaseTokens}`
);
} catch (error) {
logger.error("Error extracting codebase:", error);
}
}
// Calculate total tokens
const totalTokens =
messageHistoryTokens +
inputTokens +
systemPromptTokens +
codebaseTokens;
return {
totalTokens,
messageHistoryTokens,
codebaseTokens,
inputTokens,
systemPromptTokens,
contextWindow: getContextWindow(),
};
} catch (error) {
logger.error("Error counting tokens:", error);
throw error;
}
}
);
}
const DEFAULT_CONTEXT_WINDOW = 128_000;
function getContextWindow() {
const settings = readSettings();
const model = settings.selectedModel;
if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
logger.warn(
`Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`
);
return DEFAULT_CONTEXT_WINDOW;
}
const modelOption = MODEL_OPTIONS[
model.provider as keyof typeof MODEL_OPTIONS
].find((m) => m.name === model.name);
return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
}

View File

@@ -19,6 +19,8 @@ import type {
SystemDebugInfo,
LocalModel,
LocalModelListResponse,
TokenCountParams,
TokenCountResult,
} from "./ipc_types";
import type { CodeProposal, ProposalResult } from "@/lib/schemas";
import { showError } from "@/lib/toast";
@@ -747,4 +749,17 @@ export class IpcClient {
this.ipcRenderer.removeListener("deep-link-received", listener);
};
}
// Count tokens for a chat and input
public async countTokens(
params: TokenCountParams
): Promise<TokenCountResult> {
try {
const result = await this.ipcRenderer.invoke("chat:count-tokens", params);
return result as TokenCountResult;
} catch (error) {
showError(error);
throw error;
}
}
}

View File

@@ -10,6 +10,7 @@ import { registerProposalHandlers } from "./handlers/proposal_handlers";
import { registerDebugHandlers } from "./handlers/debug_handlers";
import { registerSupabaseHandlers } from "./handlers/supabase_handlers";
import { registerLocalModelHandlers } from "./handlers/local_model_handlers";
import { registerTokenCountHandlers } from "./handlers/token_count_handlers";
export function registerIpcHandlers() {
// Register all IPC handlers by category
@@ -25,4 +26,5 @@ export function registerIpcHandlers() {
registerDebugHandlers();
registerSupabaseHandlers();
registerLocalModelHandlers();
registerTokenCountHandlers();
}

View File

@@ -101,3 +101,16 @@ export type LocalModelListResponse = {
models: LocalModel[];
error: string | null;
};
export interface TokenCountParams {
chatId: number;
input: string;
}
export interface TokenCountResult {
totalTokens: number;
messageHistoryTokens: number;
codebaseTokens: number;
inputTokens: number;
systemPromptTokens: number;
contextWindow: number;
}

View File

@@ -9,6 +9,7 @@ const validInvokeChannels = [
"chat:message",
"chat:cancel",
"chat:stream",
"chat:count-tokens",
"create-chat",
"create-app",
"get-chat",