From 0d441b15cae85bf16ac5e94e386bb5d06c449fd6 Mon Sep 17 00:00:00 2001
From: Will Chen
Date: Mon, 28 Apr 2025 14:45:54 -0700
Subject: [PATCH] Show token bar at bottom of chat input (#33)

---
 src/components/chat/ChatInput.tsx        |  33 ++--
 src/components/chat/TokenBar.tsx         | 130 +++++++++++++++++++++
 src/constants/models.ts                  |   8 ++
 src/hooks/useCountTokens.ts              |  43 ++++++++
 src/hooks/useStreamChat.ts               |   5 +
 src/ipc/handlers/token_count_handlers.ts | 127 ++++++++++++++++++++++
 src/ipc/ipc_client.ts                    |  15 +++
 src/ipc/ipc_host.ts                      |   2 +
 src/ipc/ipc_types.ts                     |  13 +++
 src/preload.ts                           |   1 +
 10 files changed, 369 insertions(+), 8 deletions(-)
 create mode 100644 src/components/chat/TokenBar.tsx
 create mode 100644 src/hooks/useCountTokens.ts
 create mode 100644 src/ipc/handlers/token_count_handlers.ts

diff --git a/src/components/chat/ChatInput.tsx b/src/components/chat/ChatInput.tsx
index ae1cb83..9ccf0b8 100644
--- a/src/components/chat/ChatInput.tsx
+++ b/src/components/chat/ChatInput.tsx
@@ -15,6 +15,7 @@ import {
   Database,
   ChevronsUpDown,
   ChevronsDownUp,
+  BarChart2,
 } from "lucide-react";
 import type React from "react";
 import { useCallback, useEffect, useRef, useState } from "react";
@@ -22,7 +23,7 @@ import { ModelPicker } from "@/components/ModelPicker";
 import { useSettings } from "@/hooks/useSettings";
 import { IpcClient } from "@/ipc/ipc_client";
 import { chatInputValueAtom, chatMessagesAtom } from "@/atoms/chatAtoms";
-import { useAtom, useSetAtom } from "jotai";
+import { atom, useAtom, useSetAtom } from "jotai";
 import { useStreamChat } from "@/hooks/useStreamChat";
 import { useChats } from "@/hooks/useChats";
 import { selectedAppIdAtom } from "@/atoms/appAtoms";
@@ -46,6 +47,9 @@ import { useRunApp } from "@/hooks/useRunApp";
 import { AutoApproveSwitch } from "../AutoApproveSwitch";
 import { usePostHog } from "posthog-js/react";
 import { CodeHighlight } from "./CodeHighlight";
+import { TokenBar } from "./TokenBar";
+
+const showTokenBarAtom = atom(false);
 
 export function ChatInput({ chatId }: { chatId?: number }) {
   const posthog = usePostHog();
@@ -60,6 +64,7 @@ export function ChatInput({ chatId }: { chatId?: number }) {
   const [isRejecting, setIsRejecting] = useState(false); // State for rejecting
   const [messages, setMessages] = useAtom(chatMessagesAtom);
   const setIsPreviewOpen = useSetAtom(isPreviewOpenAtom);
+  const [showTokenBar, setShowTokenBar] = useAtom(showTokenBarAtom);
 
   const { refreshAppIframe } = useRunApp();
 
@@ -274,14 +279,26 @@ export function ChatInput({ chatId }: { chatId?: number }) {
         )}
-        <ModelPicker
-          selectedModel={settings?.selectedModel}
-          onModelSelect={(model) =>
-            updateSettings({ selectedModel: model })
-          }
-        />
+        <div>
+          <div>
+            <ModelPicker
+              selectedModel={settings?.selectedModel}
+              onModelSelect={(model) =>
+                updateSettings({ selectedModel: model })
+              }
+            />
+            {/* Toggles the token bar; exact markup was lost, reconstructed minimally */}
+            <button onClick={() => setShowTokenBar(!showTokenBar)}>
+              <BarChart2 />
+            </button>
+          </div>
+          {/* TokenBar is only displayed when showTokenBar is true */}
+          {showTokenBar && <TokenBar chatId={chatId} />}
+        </div>
diff --git a/src/components/chat/TokenBar.tsx b/src/components/chat/TokenBar.tsx
new file mode 100644
index 0000000..d249094
--- /dev/null
+++ b/src/components/chat/TokenBar.tsx
@@ -0,0 +1,130 @@
+import React, { useEffect, useState } from "react";
+import {
+  Tooltip,
+  TooltipContent,
+  TooltipProvider,
+  TooltipTrigger,
+} from "@/components/ui/tooltip";
+import { useCountTokens } from "@/hooks/useCountTokens";
+import { MessageSquare, Code, Bot, AlignLeft } from "lucide-react";
+import { chatInputValueAtom } from "@/atoms/chatAtoms";
+import { useAtom } from "jotai";
+import { useSettings } from "@/hooks/useSettings";
+
+interface TokenBarProps {
+  chatId?: number;
+}
+
+export function TokenBar({ chatId }: TokenBarProps) {
+  const [inputValue] = useAtom(chatInputValueAtom);
+  const { countTokens, result } = useCountTokens();
+  const [error, setError] = useState<string | null>(null);
+  const { settings } = useSettings();
+  useEffect(() => {
+    if (!chatId) return;
+    // Mark this as used, we need to re-trigger token count
+    // when selected model changes.
+    void settings?.selectedModel;
+
+    // Debounce: only recount 500ms after the user stops typing.
+    const debounceTimer = setTimeout(() => {
+      countTokens(chatId, inputValue).catch((err) => {
+        setError("Failed to count tokens");
+        console.error("Token counting error:", err);
+      });
+    }, 500);
+
+    return () => clearTimeout(debounceTimer);
+  }, [chatId, inputValue, countTokens, settings?.selectedModel]);
+
+  if (!chatId || !result) {
+    return null;
+  }
+
+  const {
+    totalTokens,
+    messageHistoryTokens,
+    codebaseTokens,
+    systemPromptTokens,
+    inputTokens,
+    contextWindow,
+  } = result;
+
+  const percentUsed = Math.min((totalTokens / contextWindow) * 100, 100);
+
+  // Calculate widths for each token type
+  const messageHistoryPercent = (messageHistoryTokens / contextWindow) * 100;
+  const codebasePercent = (codebaseTokens / contextWindow) * 100;
+  const systemPromptPercent = (systemPromptTokens / contextWindow) * 100;
+  const inputPercent = (inputTokens / contextWindow) * 100;
+
+  return (
+    <div>
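+      {/* Hovering the bar reveals the per-source breakdown in a tooltip */}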
+      <TooltipProvider>
+        <Tooltip>
+          <TooltipTrigger asChild>
+            <div>
+              <div>
+                <span>Tokens: {totalTokens.toLocaleString()}</span>
+                <span>
+                  {Math.round(percentUsed)}% of{" "}
+                  {(contextWindow / 1000).toFixed(0)}K
+                </span>
+              </div>
+              <div>
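+                {/* Stacked segments: each width is that source's share of the context window */}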
+                {/* Message history tokens */}
+                <div style={{ width: `${messageHistoryPercent}%` }} />
+                {/* Codebase tokens */}
+                <div style={{ width: `${codebasePercent}%` }} />
+                {/* System prompt tokens */}
+                <div style={{ width: `${systemPromptPercent}%` }} />
+                {/* Input tokens */}
+                <div style={{ width: `${inputPercent}%` }} />
+              </div>
+            </div>
+          </TooltipTrigger>
+          <TooltipContent>
+            <div>
+              <div>Token Usage Breakdown</div>
+              <div>
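+                {/* Counts are rough estimates (~4 characters per token); see token_count_handlers.ts */}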
+                <div>
+                  <span>
+                    <MessageSquare />
+                    Message History
+                  </span>
+                  <span>{messageHistoryTokens.toLocaleString()}</span>
+                </div>
+                <div>
+                  <span>
+                    <Code />
+                    Codebase
+                  </span>
+                  <span>{codebaseTokens.toLocaleString()}</span>
+                </div>
+                <div>
+                  <span>
+                    <Bot />
+                    System Prompt
+                  </span>
+                  <span>{systemPromptTokens.toLocaleString()}</span>
+                </div>
+                <div>
+                  <span>
+                    <AlignLeft />
+                    Current Input
+                  </span>
+                  <span>{inputTokens.toLocaleString()}</span>
+                </div>
+              </div>
+              <div>
+                <span>Total</span>
+                <span>{totalTokens.toLocaleString()}</span>
+              </div>
+            </div>
+          </TooltipContent>
+        </Tooltip>
+      </TooltipProvider>
+      {error && (
+        <div>{error}</div>
+      )}
+    </div>
+  );
+}
diff --git a/src/constants/models.ts b/src/constants/models.ts
index 8fcd199..fd51900 100644
--- a/src/constants/models.ts
+++ b/src/constants/models.ts
@@ -5,6 +5,7 @@ export interface ModelOption {
   description: string;
   tag?: string;
   maxOutputTokens?: number;
+  contextWindow?: number;
 }
 
 type RegularModelProvider = Exclude<ModelProvider, "auto">;
@@ -16,6 +17,7 @@ export const MODEL_OPTIONS: Record<ModelProvider, ModelOption[]> = {
     displayName: "GPT 4.1",
     description: "OpenAI's flagship model",
     maxOutputTokens: 32_768,
+    contextWindow: 1_047_576,
   },
   // https://platform.openai.com/docs/models/gpt-4.1-mini
   {
@@ -23,6 +25,7 @@ export const MODEL_OPTIONS: Record<ModelProvider, ModelOption[]> = {
     name: "gpt-4.1-mini",
     displayName: "GPT 4.1 Mini",
     description: "OpenAI's lightweight, but intelligent model",
     maxOutputTokens: 32_768,
+    contextWindow: 1_047_576,
   },
   // https://platform.openai.com/docs/models/o3-mini
   {
@@ -30,6 +33,7 @@ export const MODEL_OPTIONS: Record<ModelProvider, ModelOption[]> = {
     displayName: "o3 mini",
     description: "Reasoning model",
     maxOutputTokens: 100_000,
+    contextWindow: 200_000,
   },
 ],
 // https://docs.anthropic.com/en/docs/about-claude/models/all-models#model-comparison-table
@@ -39,6 +43,7 @@ export const MODEL_OPTIONS: Record<ModelProvider, ModelOption[]> = {
     displayName: "Claude 3.7 Sonnet",
     description: "Excellent coder",
     maxOutputTokens: 64_000,
+    contextWindow: 200_000,
   },
 ],
 google: [
@@ -49,6 +54,8 @@ export const MODEL_OPTIONS: Record<ModelProvider, ModelOption[]> = {
     description: "Experimental version of Google's Gemini 2.5 Pro model",
     tag: "Recommended",
     maxOutputTokens: 65_536,
+    // Gemini context window = input token + output token
+    contextWindow: 1_048_576,
   },
 ],
 openrouter: [
@@ -58,6 +65,7 @@ export const MODEL_OPTIONS: Record<ModelProvider, ModelOption[]> = {
     displayName: "DeepSeek v3 (free)",
     description: "Use for free (data may be used for training)",
     maxOutputTokens: 32_000,
+    contextWindow: 128_000,
   },
 ],
 auto: [
diff --git a/src/hooks/useCountTokens.ts b/src/hooks/useCountTokens.ts
new file mode 100644
index 0000000..ffbf8c8
--- /dev/null
+++ b/src/hooks/useCountTokens.ts
@@ -0,0 +1,43 @@
+import { useCallback } from "react";
+import { atom, useAtom } from "jotai";
+import { IpcClient } from "@/ipc/ipc_client";
+import type { TokenCountResult } from "@/ipc/ipc_types";
+
+// Create atoms to store the token count state
+export const tokenCountResultAtom = atom<TokenCountResult | null>(null);
+export const tokenCountLoadingAtom = atom(false);
+export const tokenCountErrorAtom = atom<Error | null>(null);
+
+export function useCountTokens() {
+  const [result, setResult] = useAtom(tokenCountResultAtom);
+  const [loading, setLoading] = useAtom(tokenCountLoadingAtom);
+  const [error, setError] = useAtom(tokenCountErrorAtom);
+
+  const countTokens = useCallback(
+    async (chatId: number, input: string) => {
+      setLoading(true);
+      setError(null);
+
+      try {
+        const ipcClient = IpcClient.getInstance();
+        const tokenResult = await ipcClient.countTokens({ chatId, input });
+        setResult(tokenResult);
+        return tokenResult;
+      } catch (error) {
+        console.error("Error counting tokens:", error);
+        setError(error instanceof Error ? error : new Error(String(error)));
+        throw error;
+      } finally {
+        setLoading(false);
+      }
+    },
+    [setLoading, setError, setResult]
+  );
+
+  return {
+    countTokens,
+    result,
+    loading,
+    error,
+  };
+}
diff --git a/src/hooks/useStreamChat.ts b/src/hooks/useStreamChat.ts
index 09d9cf3..d18cf03 100644
--- a/src/hooks/useStreamChat.ts
+++ b/src/hooks/useStreamChat.ts
@@ -18,6 +18,7 @@ import { showError } from "@/lib/toast";
 import { useProposal } from "./useProposal";
 import { useSearch } from "@tanstack/react-router";
 import { useRunApp } from "./useRunApp";
+import { useCountTokens } from "./useCountTokens";
 
 export function getRandomNumberId() {
   return Math.floor(Math.random() * 1_000_000_000_000_000);
 }
@@ -36,6 +37,8 @@ export function useStreamChat({
   const setStreamCount = useSetAtom(chatStreamCountAtom);
   const { refreshVersions } = useLoadVersions(selectedAppId);
   const { refreshAppIframe } = useRunApp();
+  const { countTokens } = useCountTokens();
+
   let chatId: number | undefined;
   if (hasChatId) {
@@ -111,6 +114,7 @@
           refreshChats();
           refreshApp();
           refreshVersions();
+          countTokens(chatId, "");
         },
         onError: (errorMessage: string) => {
           console.error(`[CHAT] Stream error for ${chatId}:`, errorMessage);
@@ -121,6 +125,7 @@
           refreshChats();
           refreshApp();
           refreshVersions();
+          countTokens(chatId, "");
         },
       });
     } catch (error) {
diff --git a/src/ipc/handlers/token_count_handlers.ts b/src/ipc/handlers/token_count_handlers.ts
new file mode 100644
index 0000000..8ea8257
--- /dev/null
+++ b/src/ipc/handlers/token_count_handlers.ts
@@ -0,0 +1,127 @@
+import { ipcMain } from "electron";
+import { db } from "../../db";
+import { chats, messages } from "../../db/schema";
+import { eq } from "drizzle-orm";
+import { SYSTEM_PROMPT } from "../../prompts/system_prompt";
+import {
+  SUPABASE_AVAILABLE_SYSTEM_PROMPT,
+  SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT,
+} from "../../prompts/supabase_prompt";
+import { getDyadAppPath } from "../../paths/paths";
+import log from "electron-log";
+import { extractCodebase } from "../../utils/codebase";
+import { getSupabaseContext } from "../../supabase_admin/supabase_context";
+import { readSettings } from "../../main/settings";
+import { MODEL_OPTIONS } from "../../constants/models";
+import { TokenCountParams, TokenCountResult } from "../ipc_types";
+
+const logger = log.scope("token_count_handlers");
+
+// Estimate tokens (4 characters per token); a rough heuristic, not a real tokenizer
+const estimateTokens = (text: string): number => {
+  return Math.ceil(text.length / 4);
+};
+
+export function registerTokenCountHandlers() {
+  ipcMain.handle(
+    "chat:count-tokens",
+    async (event, req: TokenCountParams): Promise<TokenCountResult> => {
+      try {
+        // Get the chat with messages
+        const chat = await db.query.chats.findFirst({
+          where: eq(chats.id, req.chatId),
+          with: {
+            messages: {
+              orderBy: (messages, { asc }) => [asc(messages.createdAt)],
+            },
+            app: true,
+          },
+        });
+
+        if (!chat) {
+          throw new Error(`Chat not found: ${req.chatId}`);
+        }
+
+        // Prepare message history for token counting
+        const messageHistory = chat.messages
+          .map((message) => message.content)
+          .join("");
+        const messageHistoryTokens = estimateTokens(messageHistory);
+
+        // Count input tokens
+        const inputTokens = estimateTokens(req.input);
+
+        // Count system prompt tokens
+        let systemPrompt = SYSTEM_PROMPT;
+        let supabaseContext = "";
+
+        if (chat.app?.supabaseProjectId) {
+          systemPrompt += "\n\n" + SUPABASE_AVAILABLE_SYSTEM_PROMPT;
+          supabaseContext = await getSupabaseContext({
+            supabaseProjectId: chat.app.supabaseProjectId,
+          });
+        } else {
+          systemPrompt += "\n\n" + SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT;
+        }
+
+        const systemPromptTokens = estimateTokens(
+          systemPrompt + supabaseContext
+        );
+
+        // Extract codebase information if app is associated with the chat
+        let codebaseInfo = "";
+        let codebaseTokens = 0;
+
+        if (chat.app) {
+          const appPath = getDyadAppPath(chat.app.path);
+          try {
+            codebaseInfo = await extractCodebase(appPath);
+            codebaseTokens = estimateTokens(codebaseInfo);
+            logger.log(
+              `Extracted codebase information from ${appPath}, tokens: ${codebaseTokens}`
+            );
+          } catch (error) {
+            logger.error("Error extracting codebase:", error);
+          }
+        }
+
+        // Calculate total tokens
+        const totalTokens =
+          messageHistoryTokens +
+          inputTokens +
+          systemPromptTokens +
+          codebaseTokens;
+
+        return {
+          totalTokens,
+          messageHistoryTokens,
+          codebaseTokens,
+          inputTokens,
+          systemPromptTokens,
+          contextWindow: getContextWindow(),
+        };
+      } catch (error) {
+        logger.error("Error counting tokens:", error);
+        throw error;
+      }
+    }
+  );
+}
+
+const DEFAULT_CONTEXT_WINDOW = 128_000;
+
+function getContextWindow() {
+  const settings = readSettings();
+  const model = settings.selectedModel;
+  if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
+    logger.warn(
+      `Model provider ${model.provider} not found in MODEL_OPTIONS. Using default context window.`
+    );
+    return DEFAULT_CONTEXT_WINDOW;
+  }
+  const modelOption = MODEL_OPTIONS[
+    model.provider as keyof typeof MODEL_OPTIONS
+  ].find((m) => m.name === model.name);
+  return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
+}
diff --git a/src/ipc/ipc_client.ts b/src/ipc/ipc_client.ts
index 8307bd0..71dd2da 100644
--- a/src/ipc/ipc_client.ts
+++ b/src/ipc/ipc_client.ts
@@ -19,6 +19,8 @@ import type {
   SystemDebugInfo,
   LocalModel,
   LocalModelListResponse,
+  TokenCountParams,
+  TokenCountResult,
 } from "./ipc_types";
 import type { CodeProposal, ProposalResult } from "@/lib/schemas";
 import { showError } from "@/lib/toast";
@@ -747,4 +749,17 @@
       this.ipcRenderer.removeListener("deep-link-received", listener);
     };
   }
+
+  // Count tokens for a chat and input
+  public async countTokens(
+    params: TokenCountParams
+  ): Promise<TokenCountResult> {
+    try {
+      const result = await this.ipcRenderer.invoke("chat:count-tokens", params);
+      return result as TokenCountResult;
+    } catch (error) {
+      showError(error);
+      throw error;
+    }
+  }
 }
diff --git a/src/ipc/ipc_host.ts b/src/ipc/ipc_host.ts
index 08d7575..2f443c4 100644
--- a/src/ipc/ipc_host.ts
+++ b/src/ipc/ipc_host.ts
@@ -10,6 +10,7 @@ import { registerProposalHandlers } from "./handlers/proposal_handlers";
 import { registerDebugHandlers } from "./handlers/debug_handlers";
 import { registerSupabaseHandlers } from "./handlers/supabase_handlers";
 import { registerLocalModelHandlers } from "./handlers/local_model_handlers";
+import { registerTokenCountHandlers } from "./handlers/token_count_handlers";
 
 export function registerIpcHandlers() {
   // Register all IPC handlers by category
@@ -25,4 +26,5 @@
   registerDebugHandlers();
   registerSupabaseHandlers();
   registerLocalModelHandlers();
+  registerTokenCountHandlers();
 }
diff --git a/src/ipc/ipc_types.ts b/src/ipc/ipc_types.ts
index 67d6919..a69e0c1 100644
--- a/src/ipc/ipc_types.ts
+++ b/src/ipc/ipc_types.ts
@@ -101,3 +101,16 @@ export type LocalModelListResponse = {
   models: LocalModel[];
   error: string | null;
 };
+
+export interface TokenCountParams {
+  chatId: number;
+  input: string;
+}
+export interface TokenCountResult {
+  totalTokens: number;
+  messageHistoryTokens: number;
+  codebaseTokens: number;
+  inputTokens: number;
+  systemPromptTokens: number;
+  contextWindow: number;
+}
diff --git a/src/preload.ts b/src/preload.ts
index 279105f..29d5efb 100644
--- a/src/preload.ts
+++ b/src/preload.ts
@@ -9,6 +9,7 @@ const validInvokeChannels = [
   "chat:message",
   "chat:cancel",
   "chat:stream",
+  "chat:count-tokens",
   "create-chat",
   "create-app",
   "get-chat",