Add MCP support (#1028)

This commit is contained in:
Will Chen
2025-09-19 15:43:39 -07:00
committed by GitHub
parent 7b160b7d0b
commit 6d3c397d40
39 changed files with 3865 additions and 650 deletions

View File

@@ -1,5 +1,5 @@
import { v4 as uuidv4 } from "uuid";
import { ipcMain } from "electron";
import { ipcMain, IpcMainInvokeEvent } from "electron";
import {
ModelMessage,
TextPart,
@@ -7,7 +7,9 @@ import {
streamText,
ToolSet,
TextStreamPart,
stepCountIs,
} from "ai";
import { db } from "../../db";
import { chats, messages } from "../../db/schema";
import { and, eq, isNull } from "drizzle-orm";
@@ -42,6 +44,8 @@ import { getMaxTokens, getTemperature } from "../utils/token_utils";
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
import { validateChatContext } from "../utils/context_paths_utils";
import { GoogleGenerativeAIProviderOptions } from "@ai-sdk/google";
import { mcpServers } from "../../db/schema";
import { requireMcpToolConsent } from "../utils/mcp_consent";
import { getExtraProviderOptions } from "../utils/thinking_utils";
@@ -64,6 +68,7 @@ import { parseAppMentions } from "@/shared/parse_mention_apps";
import { prompts as promptsTable } from "../../db/schema";
import { inArray } from "drizzle-orm";
import { replacePromptReference } from "../utils/replacePromptReference";
import { mcpManager } from "../utils/mcp_manager";
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
@@ -103,6 +108,23 @@ function escapeXml(unsafe: string): string {
.replace(/"/g, "&quot;");
}
// Safely parse an MCP tool key that combines server and tool names.
// We split on the LAST occurrence of "__" to avoid ambiguity if either
// side contains "__" as part of its sanitized name.
function parseMcpToolKey(toolKey: string): {
serverName: string;
toolName: string;
} {
const separator = "__";
const lastIndex = toolKey.lastIndexOf(separator);
if (lastIndex === -1) {
return { serverName: "", toolName: toolKey };
}
const serverName = toolKey.slice(0, lastIndex);
const toolName = toolKey.slice(lastIndex + separator.length);
return { serverName, toolName };
}
// Ensure the temp directory exists
if (!fs.existsSync(TEMP_DIR)) {
fs.mkdirSync(TEMP_DIR, { recursive: true });
@@ -129,11 +151,16 @@ async function processStreamChunks({
for await (const part of fullStream) {
let chunk = "";
if (
inThinkingBlock &&
!["reasoning-delta", "reasoning-end", "reasoning-start"].includes(
part.type,
)
) {
chunk = "</think>";
inThinkingBlock = false;
}
if (part.type === "text-delta") {
if (inThinkingBlock) {
chunk = "</think>";
inThinkingBlock = false;
}
chunk += part.text;
} else if (part.type === "reasoning-delta") {
if (!inThinkingBlock) {
@@ -142,6 +169,14 @@ async function processStreamChunks({
}
chunk += escapeDyadTags(part.text);
} else if (part.type === "tool-call") {
const { serverName, toolName } = parseMcpToolKey(part.toolName);
const content = escapeDyadTags(JSON.stringify(part.input));
chunk = `<dyad-mcp-tool-call server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-call>\n`;
} else if (part.type === "tool-result") {
const { serverName, toolName } = parseMcpToolKey(part.toolName);
const content = escapeDyadTags(part.output);
chunk = `<dyad-mcp-tool-result server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-result>\n`;
}
if (!chunk) {
@@ -496,7 +531,10 @@ ${componentSnippet}
let systemPrompt = constructSystemPrompt({
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
chatMode: settings.selectedChatMode,
chatMode:
settings.selectedChatMode === "agent"
? "build"
: settings.selectedChatMode,
});
// Add information about mentioned apps if any
@@ -603,19 +641,21 @@ This conversation includes one or more image attachments. When the user uploads
] as const)
: [];
const limitedHistoryChatMessages = limitedMessageHistory.map((msg) => ({
role: msg.role as "user" | "assistant" | "system",
// Why remove thinking tags?
// Thinking tags are generally not critical for the context
// and eats up extra tokens.
content:
settings.selectedChatMode === "ask"
? removeDyadTags(removeNonEssentialTags(msg.content))
: removeNonEssentialTags(msg.content),
}));
let chatMessages: ModelMessage[] = [
...codebasePrefix,
...otherCodebasePrefix,
...limitedMessageHistory.map((msg) => ({
role: msg.role as "user" | "assistant" | "system",
// Why remove thinking tags?
// Thinking tags are generally not critical for the context
// and eats up extra tokens.
content:
settings.selectedChatMode === "ask"
? removeDyadTags(removeNonEssentialTags(msg.content))
: removeNonEssentialTags(msg.content),
})),
...limitedHistoryChatMessages,
];
// Check if the last message should include attachments
@@ -654,9 +694,15 @@ This conversation includes one or more image attachments. When the user uploads
const simpleStreamText = async ({
chatMessages,
modelClient,
tools,
systemPromptOverride = systemPrompt,
dyadDisableFiles = false,
}: {
chatMessages: ModelMessage[];
modelClient: ModelClient;
tools?: ToolSet;
systemPromptOverride?: string;
dyadDisableFiles?: boolean;
}) => {
const dyadRequestId = uuidv4();
if (isEngineEnabled) {
@@ -671,6 +717,7 @@ This conversation includes one or more image attachments. When the user uploads
const providerOptions: Record<string, any> = {
"dyad-engine": {
dyadRequestId,
dyadDisableFiles,
},
"dyad-gateway": getExtraProviderOptions(
modelClient.builtinProviderId,
@@ -708,6 +755,7 @@ This conversation includes one or more image attachments. When the user uploads
},
} satisfies GoogleGenerativeAIProviderOptions;
}
return streamText({
headers: isAnthropic
? {
@@ -718,8 +766,10 @@ This conversation includes one or more image attachments. When the user uploads
temperature: await getTemperature(settings.selectedModel),
maxRetries: 2,
model: modelClient.model,
stopWhen: stepCountIs(3),
providerOptions,
system: systemPrompt,
system: systemPromptOverride,
tools,
messages: chatMessages.filter((m) => m.content),
onError: (error: any) => {
logger.error("Error streaming text:", error);
@@ -780,6 +830,38 @@ This conversation includes one or more image attachments. When the user uploads
return fullResponse;
};
if (settings.selectedChatMode === "agent") {
const tools = await getMcpTools(event);
const { fullStream } = await simpleStreamText({
chatMessages: limitedHistoryChatMessages,
modelClient,
tools,
systemPromptOverride: constructSystemPrompt({
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
chatMode: "agent",
}),
dyadDisableFiles: true,
});
const result = await processStreamChunks({
fullStream,
fullResponse,
abortController,
chatId: req.chatId,
processResponseChunkUpdate,
});
fullResponse = result.fullResponse;
chatMessages.push({
role: "assistant",
content: fullResponse,
});
chatMessages.push({
role: "user",
content: "OK.",
});
}
// When calling streamText, the messages need to be properly formatted for mixed content
const { fullStream } = await simpleStreamText({
chatMessages,
@@ -1316,3 +1398,48 @@ These are the other apps that I've mentioned in my prompt. These other apps' cod
${otherAppsCodebaseInfo}
`;
}
async function getMcpTools(event: IpcMainInvokeEvent): Promise<ToolSet> {
const mcpToolSet: ToolSet = {};
try {
const servers = await db
.select()
.from(mcpServers)
.where(eq(mcpServers.enabled, true as any));
for (const s of servers) {
const client = await mcpManager.getClient(s.id);
const toolSet = await client.tools();
for (const [name, tool] of Object.entries(toolSet)) {
const key = `${String(s.name || "").replace(/[^a-zA-Z0-9_-]/g, "-")}__${String(name).replace(/[^a-zA-Z0-9_-]/g, "-")}`;
const original = tool;
mcpToolSet[key] = {
description: original?.description,
inputSchema: original?.inputSchema,
execute: async (args: any, execCtx: any) => {
const inputPreview =
typeof args === "string"
? args
: Array.isArray(args)
? args.join(" ")
: JSON.stringify(args).slice(0, 500);
const ok = await requireMcpToolConsent(event, {
serverId: s.id,
serverName: s.name,
toolName: name,
toolDescription: original?.description,
inputPreview,
});
if (!ok) throw new Error(`User declined running tool ${key}`);
const res = await original.execute?.(args, execCtx);
return typeof res === "string" ? res : JSON.stringify(res);
},
};
}
}
} catch (e) {
logger.warn("Failed building MCP toolset", e);
}
return mcpToolSet;
}

View File

@@ -0,0 +1,163 @@
import { IpcMainInvokeEvent } from "electron";
import log from "electron-log";
import { db } from "../../db";
import { mcpServers, mcpToolConsents } from "../../db/schema";
import { eq, and } from "drizzle-orm";
import { createLoggedHandler } from "./safe_handle";
import { resolveConsent } from "../utils/mcp_consent";
import { getStoredConsent } from "../utils/mcp_consent";
import { mcpManager } from "../utils/mcp_manager";
import { CreateMcpServer, McpServerUpdate, McpTool } from "../ipc_types";
const logger = log.scope("mcp_handlers");
const handle = createLoggedHandler(logger);
type ConsentDecision = "accept-once" | "accept-always" | "decline";
export function registerMcpHandlers() {
// CRUD for MCP servers
handle("mcp:list-servers", async () => {
return await db.select().from(mcpServers);
});
handle(
"mcp:create-server",
async (_event: IpcMainInvokeEvent, params: CreateMcpServer) => {
const { name, transport, command, args, envJson, url, enabled } = params;
const result = await db
.insert(mcpServers)
.values({
name,
transport,
command: command || null,
args: args || null,
envJson: envJson || null,
url: url || null,
enabled: !!enabled,
})
.returning();
return result[0];
},
);
handle(
"mcp:update-server",
async (_event: IpcMainInvokeEvent, params: McpServerUpdate) => {
const update: any = {};
if (params.name !== undefined) update.name = params.name;
if (params.transport !== undefined) update.transport = params.transport;
if (params.command !== undefined) update.command = params.command;
if (params.args !== undefined) update.args = params.args || null;
if (params.cwd !== undefined) update.cwd = params.cwd;
if (params.envJson !== undefined) update.envJson = params.envJson || null;
if (params.url !== undefined) update.url = params.url;
if (params.enabled !== undefined) update.enabled = !!params.enabled;
const result = await db
.update(mcpServers)
.set(update)
.where(eq(mcpServers.id, params.id))
.returning();
// If server config changed, dispose cached client to be recreated on next use
try {
mcpManager.dispose(params.id);
} catch {}
return result[0];
},
);
handle(
"mcp:delete-server",
async (_event: IpcMainInvokeEvent, id: number) => {
try {
mcpManager.dispose(id);
} catch {}
await db.delete(mcpServers).where(eq(mcpServers.id, id));
return { success: true };
},
);
// Tools listing (dynamic)
handle(
"mcp:list-tools",
async (
_event: IpcMainInvokeEvent,
serverId: number,
): Promise<McpTool[]> => {
try {
const client = await mcpManager.getClient(serverId);
const remoteTools = await client.tools();
const tools = await Promise.all(
Object.entries(remoteTools).map(async ([name, tool]) => ({
name,
description: tool.description ?? null,
consent: await getStoredConsent(serverId, name),
})),
);
return tools;
} catch (e) {
logger.error("Failed to list tools", e);
return [];
}
},
);
// Consents
handle("mcp:get-tool-consents", async () => {
return await db.select().from(mcpToolConsents);
});
handle(
"mcp:set-tool-consent",
async (
_event: IpcMainInvokeEvent,
params: {
serverId: number;
toolName: string;
consent: "ask" | "always" | "denied";
},
) => {
const existing = await db
.select()
.from(mcpToolConsents)
.where(
and(
eq(mcpToolConsents.serverId, params.serverId),
eq(mcpToolConsents.toolName, params.toolName),
),
);
if (existing.length > 0) {
const result = await db
.update(mcpToolConsents)
.set({ consent: params.consent })
.where(
and(
eq(mcpToolConsents.serverId, params.serverId),
eq(mcpToolConsents.toolName, params.toolName),
),
)
.returning();
return result[0];
} else {
const result = await db
.insert(mcpToolConsents)
.values({
serverId: params.serverId,
toolName: params.toolName,
consent: params.consent,
})
.returning();
return result[0];
}
},
);
// Tool consent request/response handshake
// Receive consent response from renderer
handle(
"mcp:tool-consent-response",
async (_event, data: { requestId: string; decision: ConsentDecision }) => {
resolveConsent(data.requestId, data.decision);
},
);
}