Add MCP support (#1028)

This commit is contained in:
Will Chen
2025-09-19 15:43:39 -07:00
committed by GitHub
parent 7b160b7d0b
commit 6d3c397d40
39 changed files with 3865 additions and 650 deletions

View File

@@ -1,5 +1,5 @@
import { v4 as uuidv4 } from "uuid";
import { ipcMain } from "electron";
import { ipcMain, IpcMainInvokeEvent } from "electron";
import {
ModelMessage,
TextPart,
@@ -7,7 +7,9 @@ import {
streamText,
ToolSet,
TextStreamPart,
stepCountIs,
} from "ai";
import { db } from "../../db";
import { chats, messages } from "../../db/schema";
import { and, eq, isNull } from "drizzle-orm";
@@ -42,6 +44,8 @@ import { getMaxTokens, getTemperature } from "../utils/token_utils";
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
import { validateChatContext } from "../utils/context_paths_utils";
import { GoogleGenerativeAIProviderOptions } from "@ai-sdk/google";
import { mcpServers } from "../../db/schema";
import { requireMcpToolConsent } from "../utils/mcp_consent";
import { getExtraProviderOptions } from "../utils/thinking_utils";
@@ -64,6 +68,7 @@ import { parseAppMentions } from "@/shared/parse_mention_apps";
import { prompts as promptsTable } from "../../db/schema";
import { inArray } from "drizzle-orm";
import { replacePromptReference } from "../utils/replacePromptReference";
import { mcpManager } from "../utils/mcp_manager";
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
@@ -103,6 +108,23 @@ function escapeXml(unsafe: string): string {
.replace(/"/g, "&quot;");
}
/**
 * Breaks a combined MCP tool key of the form `<server>__<tool>` into its
 * two parts.
 *
 * The split happens at the LAST "__" in the key, so everything before it is
 * treated as the server name and everything after it as the tool name. A key
 * that contains no "__" at all is returned as a tool name with an empty
 * server name.
 *
 * @param toolKey - Combined key to split.
 * @returns The server name and tool name extracted from the key.
 */
function parseMcpToolKey(toolKey: string): {
  serverName: string;
  toolName: string;
} {
  const delimiter = "__";
  const splitAt = toolKey.lastIndexOf(delimiter);

  return splitAt < 0
    ? { serverName: "", toolName: toolKey }
    : {
        serverName: toolKey.substring(0, splitAt),
        toolName: toolKey.substring(splitAt + delimiter.length),
      };
}
// Ensure the temp directory exists
if (!fs.existsSync(TEMP_DIR)) {
fs.mkdirSync(TEMP_DIR, { recursive: true });
@@ -129,11 +151,16 @@ async function processStreamChunks({
for await (const part of fullStream) {
let chunk = "";
if (
inThinkingBlock &&
!["reasoning-delta", "reasoning-end", "reasoning-start"].includes(
part.type,
)
) {
chunk = "</think>";
inThinkingBlock = false;
}
if (part.type === "text-delta") {
if (inThinkingBlock) {
chunk = "</think>";
inThinkingBlock = false;
}
chunk += part.text;
} else if (part.type === "reasoning-delta") {
if (!inThinkingBlock) {
@@ -142,6 +169,14 @@ async function processStreamChunks({
}
chunk += escapeDyadTags(part.text);
} else if (part.type === "tool-call") {
const { serverName, toolName } = parseMcpToolKey(part.toolName);
const content = escapeDyadTags(JSON.stringify(part.input));
chunk = `<dyad-mcp-tool-call server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-call>\n`;
} else if (part.type === "tool-result") {
const { serverName, toolName } = parseMcpToolKey(part.toolName);
const content = escapeDyadTags(part.output);
chunk = `<dyad-mcp-tool-result server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-result>\n`;
}
if (!chunk) {
@@ -496,7 +531,10 @@ ${componentSnippet}
let systemPrompt = constructSystemPrompt({
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
chatMode: settings.selectedChatMode,
chatMode:
settings.selectedChatMode === "agent"
? "build"
: settings.selectedChatMode,
});
// Add information about mentioned apps if any
@@ -603,19 +641,21 @@ This conversation includes one or more image attachments. When the user uploads
] as const)
: [];
const limitedHistoryChatMessages = limitedMessageHistory.map((msg) => ({
role: msg.role as "user" | "assistant" | "system",
// Why remove thinking tags?
// Thinking tags are generally not critical for the context
// and eat up extra tokens.
content:
settings.selectedChatMode === "ask"
? removeDyadTags(removeNonEssentialTags(msg.content))
: removeNonEssentialTags(msg.content),
}));
let chatMessages: ModelMessage[] = [
...codebasePrefix,
...otherCodebasePrefix,
...limitedMessageHistory.map((msg) => ({
role: msg.role as "user" | "assistant" | "system",
// Why remove thinking tags?
// Thinking tags are generally not critical for the context
// and eat up extra tokens.
content:
settings.selectedChatMode === "ask"
? removeDyadTags(removeNonEssentialTags(msg.content))
: removeNonEssentialTags(msg.content),
})),
...limitedHistoryChatMessages,
];
// Check if the last message should include attachments
@@ -654,9 +694,15 @@ This conversation includes one or more image attachments. When the user uploads
const simpleStreamText = async ({
chatMessages,
modelClient,
tools,
systemPromptOverride = systemPrompt,
dyadDisableFiles = false,
}: {
chatMessages: ModelMessage[];
modelClient: ModelClient;
tools?: ToolSet;
systemPromptOverride?: string;
dyadDisableFiles?: boolean;
}) => {
const dyadRequestId = uuidv4();
if (isEngineEnabled) {
@@ -671,6 +717,7 @@ This conversation includes one or more image attachments. When the user uploads
const providerOptions: Record<string, any> = {
"dyad-engine": {
dyadRequestId,
dyadDisableFiles,
},
"dyad-gateway": getExtraProviderOptions(
modelClient.builtinProviderId,
@@ -708,6 +755,7 @@ This conversation includes one or more image attachments. When the user uploads
},
} satisfies GoogleGenerativeAIProviderOptions;
}
return streamText({
headers: isAnthropic
? {
@@ -718,8 +766,10 @@ This conversation includes one or more image attachments. When the user uploads
temperature: await getTemperature(settings.selectedModel),
maxRetries: 2,
model: modelClient.model,
stopWhen: stepCountIs(3),
providerOptions,
system: systemPrompt,
system: systemPromptOverride,
tools,
messages: chatMessages.filter((m) => m.content),
onError: (error: any) => {
logger.error("Error streaming text:", error);
@@ -780,6 +830,38 @@ This conversation includes one or more image attachments. When the user uploads
return fullResponse;
};
if (settings.selectedChatMode === "agent") {
const tools = await getMcpTools(event);
const { fullStream } = await simpleStreamText({
chatMessages: limitedHistoryChatMessages,
modelClient,
tools,
systemPromptOverride: constructSystemPrompt({
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
chatMode: "agent",
}),
dyadDisableFiles: true,
});
const result = await processStreamChunks({
fullStream,
fullResponse,
abortController,
chatId: req.chatId,
processResponseChunkUpdate,
});
fullResponse = result.fullResponse;
chatMessages.push({
role: "assistant",
content: fullResponse,
});
chatMessages.push({
role: "user",
content: "OK.",
});
}
// When calling streamText, the messages need to be properly formatted for mixed content
const { fullStream } = await simpleStreamText({
chatMessages,
@@ -1316,3 +1398,48 @@ These are the other apps that I've mentioned in my prompt. These other apps' cod
${otherAppsCodebaseInfo}
`;
}
/**
 * Builds the tool set handed to the model from every enabled MCP server.
 *
 * Each tool is registered under the key `<server>__<tool>`, where both parts
 * have every character outside `[a-zA-Z0-9_-]` replaced by "-". The tool's
 * `execute` is wrapped so that `requireMcpToolConsent` is awaited first; if
 * consent is not granted the wrapper throws instead of running the tool.
 * Non-string results are returned as their JSON serialization.
 *
 * Failures are logged and never thrown. A failure while listing the servers
 * yields an empty tool set; a failure while connecting to or listing the
 * tools of one server skips only that server.
 *
 * @param event - IPC event forwarded to the consent prompt.
 * @returns The collected tool set (possibly empty).
 */
async function getMcpTools(event: IpcMainInvokeEvent): Promise<ToolSet> {
  const mcpToolSet: ToolSet = {};

  // Tool keys are restricted to [a-zA-Z0-9_-]; anything else becomes "-".
  const sanitize = (value: string | null | undefined): string =>
    String(value || "").replace(/[^a-zA-Z0-9_-]/g, "-");

  try {
    const servers = await db
      .select()
      .from(mcpServers)
      .where(eq(mcpServers.enabled, true as any));

    for (const s of servers) {
      // Isolate each server: one unreachable or misbehaving server must not
      // prevent the tools of the remaining servers from being registered.
      try {
        const client = await mcpManager.getClient(s.id);
        const toolSet = await client.tools();

        for (const [name, tool] of Object.entries(toolSet)) {
          const key = `${sanitize(s.name)}__${sanitize(name)}`;
          mcpToolSet[key] = {
            description: tool?.description,
            inputSchema: tool?.inputSchema,
            execute: async (args: any, execCtx: any) => {
              // JSON.stringify returns undefined for inputs such as
              // `undefined`; fall back to "" so building the preview cannot
              // throw before the user is even asked for consent.
              const inputPreview =
                typeof args === "string"
                  ? args
                  : Array.isArray(args)
                    ? args.join(" ")
                    : (JSON.stringify(args) ?? "").slice(0, 500);

              const ok = await requireMcpToolConsent(event, {
                serverId: s.id,
                serverName: s.name,
                toolName: name,
                toolDescription: tool?.description,
                inputPreview,
              });
              if (!ok) throw new Error(`User declined running tool ${key}`);

              const res = await tool.execute?.(args, execCtx);
              return typeof res === "string" ? res : JSON.stringify(res);
            },
          };
        }
      } catch (e) {
        logger.warn(`Failed loading MCP tools for server ${s.name}`, e);
      }
    }
  } catch (e) {
    logger.warn("Failed building MCP toolset", e);
  }

  return mcpToolSet;
}