Summarize into new chat suggested action (#34)

Will Chen
2025-04-28 16:14:12 -07:00
committed by GitHub
parent 0d441b15ca
commit 9fb5439ecf
8 changed files with 189 additions and 48 deletions

View File

@@ -22,8 +22,12 @@ import { useCallback, useEffect, useRef, useState } from "react";
 import { ModelPicker } from "@/components/ModelPicker";
 import { useSettings } from "@/hooks/useSettings";
 import { IpcClient } from "@/ipc/ipc_client";
-import { chatInputValueAtom, chatMessagesAtom } from "@/atoms/chatAtoms";
-import { atom, useAtom, useSetAtom } from "jotai";
+import {
+  chatInputValueAtom,
+  chatMessagesAtom,
+  selectedChatIdAtom,
+} from "@/atoms/chatAtoms";
+import { atom, useAtom, useSetAtom, useAtomValue } from "jotai";
 import { useStreamChat } from "@/hooks/useStreamChat";
 import { useChats } from "@/hooks/useChats";
 import { selectedAppIdAtom } from "@/atoms/appAtoms";
@@ -48,6 +52,13 @@ import { AutoApproveSwitch } from "../AutoApproveSwitch";
 import { usePostHog } from "posthog-js/react";
 import { CodeHighlight } from "./CodeHighlight";
 import { TokenBar } from "./TokenBar";
+import {
+  Tooltip,
+  TooltipContent,
+  TooltipProvider,
+  TooltipTrigger,
+} from "../ui/tooltip";
+import { useNavigate } from "@tanstack/react-router";

 const showTokenBarAtom = atom(false);
@@ -305,8 +316,43 @@ export function ChatInput({ chatId }: { chatId?: number }) {
   );
 }

+function SummarizeInNewChatButton() {
+  const [chatId] = useAtom(selectedChatIdAtom);
+  const appId = useAtomValue(selectedAppIdAtom);
+  const { streamMessage } = useStreamChat();
+  const navigate = useNavigate();
+  const onClick = async () => {
+    if (!appId) {
+      console.error("No app id found");
+      return;
+    }
+    const newChatId = await IpcClient.getInstance().createChat(appId);
+    // navigate to new chat
+    await navigate({ to: "/chat", search: { id: newChatId } });
+    await streamMessage({
+      prompt: "Summarize from chat-id=" + chatId,
+      chatId: newChatId,
+    });
+  };
+  return (
+    <TooltipProvider>
+      <Tooltip>
+        <TooltipTrigger asChild>
+          <Button variant="outline" size="sm" onClick={onClick}>
+            Summarize to new chat
+          </Button>
+        </TooltipTrigger>
+        <TooltipContent>
+          <p>Creating a new chat makes the AI more focused and efficient</p>
+        </TooltipContent>
+      </Tooltip>
+    </TooltipProvider>
+  );
+}
+
 function mapActionToButton(action: SuggestedAction) {
   switch (action.id) {
+    case "summarize-in-new-chat":
+      return <SummarizeInNewChatButton />;
     default:
       console.error(`Unsupported action: ${action.id}`);
       return (
@@ -323,7 +369,6 @@ function ActionProposalActions({ proposal }: { proposal: ActionProposal }) {
       <div className="flex items-center space-x-2">
         {proposal.actions.map((action) => mapActionToButton(action))}
       </div>
-      <AutoApproveSwitch />
     </div>
   );
 }

View File

@@ -7,7 +7,6 @@ export function useProposal(chatId?: number | undefined) {
   const [proposalResult, setProposalResult] = useAtom(proposalResultAtom);
   const [isLoading, setIsLoading] = useState<boolean>(false);
   const [error, setError] = useState<string | null>(null);
-
   const fetchProposal = useCallback(
     async (overrideChatId?: number) => {
       chatId = overrideChatId ?? chatId;
@@ -19,7 +18,6 @@ export function useProposal(chatId?: number | undefined) {
       }
       setIsLoading(true);
       setError(null);
-      setProposalResult(null); // Reset on new fetch
       try {
         // Type assertion might be needed depending on how IpcClient is typed
         const result = (await IpcClient.getInstance().getProposal(
@@ -39,7 +37,7 @@ export function useProposal(chatId?: number | undefined) {
         setIsLoading(false);
       }
     },
-    [chatId]
+    [chatId] // Only depend on chatId, setProposalResult is stable
   ); // Depend on chatId

   useEffect(() => {

View File

@@ -1,5 +1,5 @@
 import { ipcMain } from "electron";
-import { streamText } from "ai";
+import { CoreMessage, streamText } from "ai";
 import { db } from "../../db";
 import { chats, messages } from "../../db/schema";
 import { and, eq, isNull } from "drizzle-orm";
@@ -21,6 +21,7 @@ import {
   getSupabaseContext,
   getSupabaseClientCode,
 } from "../../supabase_admin/supabase_context";
+import { SUMMARIZE_CHAT_SYSTEM_PROMPT } from "../../prompts/summarize_chat_system_prompt";

 const logger = log.scope("chat_stream_handlers");
@@ -165,12 +166,13 @@ export function registerChatStreamHandlers() {
       } else {
         systemPrompt += "\n\n" + SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT;
       }
-      const { textStream } = streamText({
-        maxTokens: getMaxTokens(settings.selectedModel),
-        temperature: 0,
-        model: modelClient,
-        system: systemPrompt,
-        messages: [
+      const isSummarizeIntent = req.prompt.startsWith(
+        "Summarize from chat-id="
+      );
+      if (isSummarizeIntent) {
+        systemPrompt = SUMMARIZE_CHAT_SYSTEM_PROMPT;
+      }
+      let chatMessages = [
         {
           role: "user",
           content: "This is my codebase. " + codebaseInfo,
@@ -180,7 +182,31 @@ export function registerChatStreamHandlers() {
content: "OK, got it. I'm ready to help", content: "OK, got it. I'm ready to help",
}, },
...messageHistory, ...messageHistory,
], ] satisfies CoreMessage[];
if (isSummarizeIntent) {
const previousChat = await db.query.chats.findFirst({
where: eq(chats.id, parseInt(req.prompt.split("=")[1])),
with: {
messages: {
orderBy: (messages, { asc }) => [asc(messages.createdAt)],
},
},
});
chatMessages = [
{
role: "user",
content:
"Summarize the following chat: " +
formatMessages(previousChat?.messages ?? []),
} satisfies CoreMessage,
];
}
const { textStream } = streamText({
maxTokens: getMaxTokens(settings.selectedModel),
temperature: 0,
model: modelClient,
system: systemPrompt,
messages: chatMessages,
onError: (error) => { onError: (error) => {
logger.error("Error streaming text:", error); logger.error("Error streaming text:", error);
const message = const message =
@@ -362,3 +388,11 @@ export function registerChatStreamHandlers() {
     return true;
   });
 }
+
+export function formatMessages(
+  messages: { role: string; content: string | undefined }[]
+) {
+  return messages
+    .map((m) => `<message role="${m.role}">${m.content}</message>`)
+    .join("\n");
+}
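
Aside: the renderer and this handler coordinate through a sentinel prompt rather than a dedicated IPC message. A minimal sketch of that round trip, with a hypothetical chat id (illustrative only, not part of the commit):

    // Renderer side (SummarizeInNewChatButton): encode the source chat id in the prompt.
    const oldChatId = 42; // hypothetical
    const prompt = "Summarize from chat-id=" + oldChatId;

    // Handler side, as above: detect the intent and recover the id.
    const isSummarizeIntent = prompt.startsWith("Summarize from chat-id=");
    const recoveredId = parseInt(prompt.split("=")[1], 10);
    console.log(isSummarizeIntent, recoveredId); // true 42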

View File

@@ -4,9 +4,10 @@ import type {
   FileChange,
   ProposalResult,
   SqlQuery,
+  ActionProposal,
 } from "../../lib/schemas";
 import { db } from "../../db";
-import { messages } from "../../db/schema";
+import { messages, chats } from "../../db/schema";
 import { desc, eq, and, Update } from "drizzle-orm";
 import path from "node:path"; // Import path for basename
 // Import tag parsers
@@ -21,6 +22,7 @@ import {
} from "../processors/response_processor"; } from "../processors/response_processor";
import log from "electron-log"; import log from "electron-log";
import { isServerFunction } from "../../supabase_admin/supabase_utils"; import { isServerFunction } from "../../supabase_admin/supabase_utils";
import { estimateMessagesTokens, getContextWindow } from "../utils/token_utils";
const logger = log.scope("proposal_handlers"); const logger = log.scope("proposal_handlers");
@@ -60,10 +62,45 @@ const getProposalHandler = async (
     },
   });

-  if (latestAssistantMessage?.approvalState === "rejected") {
-    return null;
-  }
-
-  if (latestAssistantMessage?.approvalState === "approved") {
+  if (
+    latestAssistantMessage?.approvalState === "rejected" ||
+    latestAssistantMessage?.approvalState === "approved"
+  ) {
+    // Get all chat messages to calculate token usage
+    const chat = await db.query.chats.findFirst({
+      where: eq(chats.id, chatId),
+      with: {
+        messages: {
+          orderBy: (messages, { asc }) => [asc(messages.createdAt)],
+        },
+      },
+    });
+
+    if (chat) {
+      // Calculate total tokens from message history
+      const totalTokens = estimateMessagesTokens(chat.messages);
+      const contextWindow = Math.min(getContextWindow(), 100_000);
+      logger.log(
+        `Token usage: ${totalTokens}/${contextWindow} (${
+          (totalTokens / contextWindow) * 100
+        }%)`
+      );
+      // If we're using more than 80% of the context window, suggest summarizing
+      if (totalTokens > contextWindow * 0.8) {
+        logger.log(
+          `Token usage high (${totalTokens}/${contextWindow}), suggesting summarize action`
+        );
+        return {
+          proposal: {
+            type: "action-proposal",
+            actions: [{ id: "summarize-in-new-chat" }],
+          },
+          chatId,
+          messageId: latestAssistantMessage.id,
+        };
+      }
+    }
+
     return null;
   }
@@ -131,7 +168,12 @@ const getProposalHandler = async (
"packages=", "packages=",
proposal.packagesAdded.length proposal.packagesAdded.length
); );
return { proposal, chatId, messageId }; // Return proposal and messageId
return {
proposal: proposal,
chatId,
messageId,
};
} else { } else {
logger.log( logger.log(
"No relevant tags found in the latest assistant message content." "No relevant tags found in the latest assistant message content."
@@ -228,7 +270,7 @@ const rejectProposalHandler = async (
         eq(messages.chatId, chatId),
         eq(messages.role, "assistant")
       ),
-      columns: { id: true }, // Only need to confirm existence
+      columns: { id: true },
     });

     if (!messageToReject) {
View File

@@ -15,14 +15,10 @@ import { readSettings } from "../../main/settings";
 import { MODEL_OPTIONS } from "../../constants/models";
 import { TokenCountParams } from "../ipc_types";
 import { TokenCountResult } from "../ipc_types";
+import { estimateTokens, getContextWindow } from "../utils/token_utils";

 const logger = log.scope("token_count_handlers");

-// Estimate tokens (4 characters per token)
-const estimateTokens = (text: string): number => {
-  return Math.ceil(text.length / 4);
-};
-
 export function registerTokenCountHandlers() {
   ipcMain.handle(
     "chat:count-tokens",
@@ -108,20 +104,3 @@ export function registerTokenCountHandlers() {
     }
   );
 }
-
-const DEFAULT_CONTEXT_WINDOW = 128_000;
-function getContextWindow() {
-  const settings = readSettings();
-  const model = settings.selectedModel;
-  if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
-    logger.warn(
-      `Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`
-    );
-    return DEFAULT_CONTEXT_WINDOW;
-  }
-  const modelOption = MODEL_OPTIONS[
-    model.provider as keyof typeof MODEL_OPTIONS
-  ].find((m) => m.name === model.name);
-  return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
-}

View File

@@ -0,0 +1,35 @@
+import { readSettings } from "../../main/settings";
+import { Message } from "../ipc_types";
+import { MODEL_OPTIONS } from "../../constants/models";
+import log from "electron-log";
+
+const logger = log.scope("token_utils");
+
+// Estimate tokens (4 characters per token)
+export const estimateTokens = (text: string): number => {
+  return Math.ceil(text.length / 4);
+};
+
+export const estimateMessagesTokens = (messages: Message[]): number => {
+  return messages.reduce(
+    (acc, message) => acc + estimateTokens(message.content),
+    0
+  );
+};
+
+const DEFAULT_CONTEXT_WINDOW = 128_000;
+
+export function getContextWindow() {
+  const settings = readSettings();
+  const model = settings.selectedModel;
+  if (!MODEL_OPTIONS[model.provider as keyof typeof MODEL_OPTIONS]) {
+    logger.warn(
+      `Model provider ${model.provider} not found in MODEL_OPTIONS. Using default max tokens.`
+    );
+    return DEFAULT_CONTEXT_WINDOW;
+  }
+  const modelOption = MODEL_OPTIONS[
+    model.provider as keyof typeof MODEL_OPTIONS
+  ].find((m) => m.name === model.name);
+  return modelOption?.contextWindow || DEFAULT_CONTEXT_WINDOW;
+}

View File

@@ -156,7 +156,7 @@ export interface CodeProposal {
 }

 export interface SuggestedAction {
-  id: "restart-app";
+  id: "restart-app" | "summarize-in-new-chat";
 }

 export interface ActionProposal {
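
With the widened id union, the value the proposal handler returns for a high-token chat looks like this (the id numbers here are hypothetical; only the string literals come from the diff):

    const suggestion = {
      proposal: {
        type: "action-proposal",
        actions: [{ id: "summarize-in-new-chat" }],
      },
      chatId: 42, // hypothetical
      messageId: 123, // hypothetical
    };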

View File

@@ -0,0 +1,8 @@
+export const SUMMARIZE_CHAT_SYSTEM_PROMPT = `
+You are a helpful assistant that understands long conversations and can summarize them in a few bullet points.
+
+I want you to write down the gist of the conversation in a few bullet points, focusing on the major changes, particularly
+at the end of the conversation.
+
+Use <dyad-chat-summary> for setting the chat summary (put this at the end). The chat summary should be less than a sentence, but more than a few words. YOU SHOULD ALWAYS INCLUDE EXACTLY ONE CHAT TITLE
+`;
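
For illustration, a hypothetical assistant response this prompt is meant to elicit (the bullet content is invented; the single trailing tag is what the prompt requires):

    - Scaffolded the login page and wired it to Supabase auth
    - Added a token-usage bar to the chat input
    - At the end, fixed the redirect loop after sign-in
    <dyad-chat-summary>Login flow and redirect fix</dyad-chat-summary>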