Allow manual context management (#376)

This commit is contained in:
Will Chen
2025-06-10 13:52:20 -07:00
committed by GitHub
parent e7941bc6f7
commit 534cbad909
55 changed files with 3296 additions and 114 deletions

View File

@@ -33,6 +33,7 @@ import { readFile, writeFile, unlink } from "fs/promises";
import { getMaxTokens } from "../utils/token_utils";
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
import { streamTextWithBackup } from "../utils/stream_utils";
import { validateChatContext } from "../utils/context_paths_utils";
const logger = log.scope("chat_stream_handlers");
@@ -226,7 +227,10 @@ export function registerChatStreamHandlers() {
if (updatedChat.app) {
const appPath = getDyadAppPath(updatedChat.app.path);
try {
const out = await extractCodebase(appPath);
const out = await extractCodebase({
appPath,
chatContext: validateChatContext(updatedChat.app.chatContext),
});
codebaseInfo = out.formattedOutput;
files = out.files;
logger.log(`Extracted codebase information from ${appPath}`);

View File

@@ -0,0 +1,98 @@
import { db } from "@/db";
import { apps } from "@/db/schema";
import { eq } from "drizzle-orm";
import { z } from "zod";
import {
AppChatContext,
AppChatContextSchema,
ContextPathResults,
} from "@/lib/schemas";
import { estimateTokens } from "../utils/token_utils";
import { createLoggedHandler } from "./safe_handle";
import log from "electron-log";
import { getDyadAppPath } from "@/paths/paths";
import { extractCodebase } from "@/utils/codebase";
import { validateChatContext } from "../utils/context_paths_utils";
const logger = log.scope("context_paths_handlers");
const handle = createLoggedHandler(logger);
export function registerContextPathsHandlers() {
handle(
"get-context-paths",
async (_, { appId }: { appId: number }): Promise<ContextPathResults> => {
z.object({ appId: z.number() }).parse({ appId });
const app = await db.query.apps.findFirst({
where: eq(apps.id, appId),
});
if (!app) {
throw new Error("App not found");
}
if (!app.path) {
throw new Error("App path not set");
}
const appPath = getDyadAppPath(app.path);
const results: ContextPathResults = {
contextPaths: [],
smartContextAutoIncludes: [],
};
const { contextPaths, smartContextAutoIncludes } = validateChatContext(
app.chatContext,
);
for (const contextPath of contextPaths) {
const { formattedOutput, files } = await extractCodebase({
appPath,
chatContext: {
contextPaths: [contextPath],
smartContextAutoIncludes: [],
},
});
const totalTokens = estimateTokens(formattedOutput);
results.contextPaths.push({
...contextPath,
files: files.length,
tokens: totalTokens,
});
}
for (const contextPath of smartContextAutoIncludes) {
const { formattedOutput, files } = await extractCodebase({
appPath,
chatContext: {
contextPaths: [contextPath],
smartContextAutoIncludes: [],
},
});
const totalTokens = estimateTokens(formattedOutput);
results.smartContextAutoIncludes.push({
...contextPath,
files: files.length,
tokens: totalTokens,
});
}
return results;
},
);
handle(
"set-context-paths",
async (
_,
{ appId, chatContext }: { appId: number; chatContext: AppChatContext },
) => {
const schema = z.object({
appId: z.number(),
chatContext: AppChatContextSchema,
});
schema.parse({ appId, chatContext });
await db.update(apps).set({ chatContext }).where(eq(apps.id, appId));
},
);
}

View File

@@ -13,6 +13,7 @@ import { chats, apps } from "../../db/schema";
import { eq } from "drizzle-orm";
import { getDyadAppPath } from "../../paths/paths";
import { LargeLanguageModel } from "@/lib/schemas";
import { validateChatContext } from "../utils/context_paths_utils";
// Shared function to get system debug info
async function getSystemDebugInfo({
@@ -175,7 +176,12 @@ export function registerDebugHandlers() {
// Extract codebase
const appPath = getDyadAppPath(app.path);
const codebase = (await extractCodebase(appPath)).formattedOutput;
const codebase = (
await extractCodebase({
appPath,
chatContext: validateChatContext(app.chatContext),
})
).formattedOutput;
return {
debugInfo,

View File

@@ -31,6 +31,7 @@ import { getDyadAppPath } from "../../paths/paths";
import { withLock } from "../utils/lock_utils";
import { createLoggedHandler } from "./safe_handle";
import { ApproveProposalResult } from "../ipc_types";
import { validateChatContext } from "../utils/context_paths_utils";
const logger = log.scope("proposal_handlers");
const handle = createLoggedHandler(logger);
@@ -41,6 +42,7 @@ interface CodebaseTokenCache {
messageContent: string;
tokenCount: number;
timestamp: number;
chatContext: string;
}
// Cache expiration time (5 minutes)
@@ -74,6 +76,7 @@ async function getCodebaseTokenCount(
messageId: number,
messageContent: string,
appPath: string,
chatContext: unknown,
): Promise<number> {
// Clean up expired cache entries first
cleanupExpiredCacheEntries();
@@ -86,6 +89,7 @@ async function getCodebaseTokenCount(
cacheEntry &&
cacheEntry.messageId === messageId &&
cacheEntry.messageContent === messageContent &&
cacheEntry.chatContext === JSON.stringify(chatContext) &&
now - cacheEntry.timestamp < CACHE_EXPIRATION_MS
) {
logger.log(`Using cached codebase token count for chatId: ${chatId}`);
@@ -94,8 +98,12 @@ async function getCodebaseTokenCount(
// Calculate and cache the token count
logger.log(`Calculating codebase token count for chatId: ${chatId}`);
const codebase = (await extractCodebase(getDyadAppPath(appPath)))
.formattedOutput;
const codebase = (
await extractCodebase({
appPath: getDyadAppPath(appPath),
chatContext: validateChatContext(chatContext),
})
).formattedOutput;
const tokenCount = estimateTokens(codebase);
// Store in cache
@@ -105,6 +113,7 @@ async function getCodebaseTokenCount(
messageContent,
tokenCount,
timestamp: now,
chatContext: JSON.stringify(chatContext),
});
return tokenCount;
@@ -277,6 +286,7 @@ const getProposalHandler = async (
latestAssistantMessage.id,
latestAssistantMessage.content || "",
chat.app.path,
chat.app.chatContext,
);
const totalTokens = messagesTokenCount + codebaseTokenCount;

View File

@@ -18,6 +18,7 @@ import { TokenCountParams } from "../ipc_types";
import { TokenCountResult } from "../ipc_types";
import { estimateTokens, getContextWindow } from "../utils/token_utils";
import { createLoggedHandler } from "./safe_handle";
import { validateChatContext } from "../utils/context_paths_utils";
const logger = log.scope("token_count_handlers");
@@ -73,7 +74,12 @@ export function registerTokenCountHandlers() {
if (chat.app) {
const appPath = getDyadAppPath(chat.app.path);
codebaseInfo = (await extractCodebase(appPath)).formattedOutput;
codebaseInfo = (
await extractCodebase({
appPath,
chatContext: validateChatContext(chat.app.chatContext),
})
).formattedOutput;
codebaseTokens = estimateTokens(codebaseInfo);
logger.log(
`Extracted codebase information from ${appPath}, tokens: ${codebaseTokens}`,

View File

@@ -3,9 +3,9 @@ import {
type ChatSummary,
ChatSummariesSchema,
type UserSettings,
type ContextPathResults,
} from "../lib/schemas";
import type {
App,
AppOutput,
Chat,
ChatResponseEnd,
@@ -32,8 +32,9 @@ import type {
RenameBranchParams,
UserBudgetInfo,
CopyAppParams,
App,
} from "./ipc_types";
import type { ProposalResult } from "@/lib/schemas";
import type { AppChatContext, ProposalResult } from "@/lib/schemas";
import { showError } from "@/lib/toast";
export interface ChatStreamCallbacks {
@@ -847,4 +848,17 @@ export class IpcClient {
public async getUserBudget(): Promise<UserBudgetInfo | null> {
return this.ipcRenderer.invoke("get-user-budget");
}
public async getChatContextResults(params: {
appId: number;
}): Promise<ContextPathResults> {
return this.ipcRenderer.invoke("get-context-paths", params);
}
public async setChatContext(params: {
appId: number;
chatContext: AppChatContext;
}): Promise<void> {
return this.ipcRenderer.invoke("set-context-paths", params);
}
}

View File

@@ -19,6 +19,7 @@ import { registerReleaseNoteHandlers } from "./handlers/release_note_handlers";
import { registerImportHandlers } from "./handlers/import_handlers";
import { registerSessionHandlers } from "./handlers/session_handlers";
import { registerProHandlers } from "./handlers/pro_handlers";
import { registerContextPathsHandlers } from "./handlers/context_paths_handlers";
export function registerIpcHandlers() {
// Register all IPC handlers by category
@@ -43,4 +44,5 @@ export function registerIpcHandlers() {
registerImportHandlers();
registerSessionHandlers();
registerProHandlers();
registerContextPathsHandlers();
}

View File

@@ -0,0 +1,25 @@
import { AppChatContext, AppChatContextSchema } from "@/lib/schemas";
import log from "electron-log";
const logger = log.scope("context_paths_utils");
export function validateChatContext(chatContext: unknown): AppChatContext {
if (!chatContext) {
return {
contextPaths: [],
smartContextAutoIncludes: [],
};
}
try {
// Validate that the contextPaths data matches the expected schema
return AppChatContextSchema.parse(chatContext);
} catch (error) {
logger.warn("Invalid contextPaths data:", error);
// Return empty array as fallback if validation fails
return {
contextPaths: [],
smartContextAutoIncludes: [],
};
}
}