Add MCP support (#1028)

This commit is contained in:
Will Chen
2025-09-19 15:43:39 -07:00
committed by GitHub
parent 7b160b7d0b
commit 6d3c397d40
39 changed files with 3865 additions and 650 deletions

View File

@@ -1,5 +1,5 @@
import { v4 as uuidv4 } from "uuid";
import { ipcMain } from "electron";
import { ipcMain, IpcMainInvokeEvent } from "electron";
import {
ModelMessage,
TextPart,
@@ -7,7 +7,9 @@ import {
streamText,
ToolSet,
TextStreamPart,
stepCountIs,
} from "ai";
import { db } from "../../db";
import { chats, messages } from "../../db/schema";
import { and, eq, isNull } from "drizzle-orm";
@@ -42,6 +44,8 @@ import { getMaxTokens, getTemperature } from "../utils/token_utils";
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
import { validateChatContext } from "../utils/context_paths_utils";
import { GoogleGenerativeAIProviderOptions } from "@ai-sdk/google";
import { mcpServers } from "../../db/schema";
import { requireMcpToolConsent } from "../utils/mcp_consent";
import { getExtraProviderOptions } from "../utils/thinking_utils";
@@ -64,6 +68,7 @@ import { parseAppMentions } from "@/shared/parse_mention_apps";
import { prompts as promptsTable } from "../../db/schema";
import { inArray } from "drizzle-orm";
import { replacePromptReference } from "../utils/replacePromptReference";
import { mcpManager } from "../utils/mcp_manager";
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
@@ -103,6 +108,23 @@ function escapeXml(unsafe: string): string {
.replace(/"/g, "&quot;");
}
// Safely parse an MCP tool key that combines server and tool names.
// We split on the LAST occurrence of "__" to avoid ambiguity if either
// side contains "__" as part of its sanitized name.
function parseMcpToolKey(toolKey: string): {
serverName: string;
toolName: string;
} {
const separator = "__";
const lastIndex = toolKey.lastIndexOf(separator);
if (lastIndex === -1) {
return { serverName: "", toolName: toolKey };
}
const serverName = toolKey.slice(0, lastIndex);
const toolName = toolKey.slice(lastIndex + separator.length);
return { serverName, toolName };
}
// Ensure the temp directory exists
if (!fs.existsSync(TEMP_DIR)) {
fs.mkdirSync(TEMP_DIR, { recursive: true });
@@ -129,11 +151,16 @@ async function processStreamChunks({
for await (const part of fullStream) {
let chunk = "";
if (
inThinkingBlock &&
!["reasoning-delta", "reasoning-end", "reasoning-start"].includes(
part.type,
)
) {
chunk = "</think>";
inThinkingBlock = false;
}
if (part.type === "text-delta") {
if (inThinkingBlock) {
chunk = "</think>";
inThinkingBlock = false;
}
chunk += part.text;
} else if (part.type === "reasoning-delta") {
if (!inThinkingBlock) {
@@ -142,6 +169,14 @@ async function processStreamChunks({
}
chunk += escapeDyadTags(part.text);
} else if (part.type === "tool-call") {
const { serverName, toolName } = parseMcpToolKey(part.toolName);
const content = escapeDyadTags(JSON.stringify(part.input));
chunk = `<dyad-mcp-tool-call server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-call>\n`;
} else if (part.type === "tool-result") {
const { serverName, toolName } = parseMcpToolKey(part.toolName);
const content = escapeDyadTags(part.output);
chunk = `<dyad-mcp-tool-result server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-result>\n`;
}
if (!chunk) {
@@ -496,7 +531,10 @@ ${componentSnippet}
let systemPrompt = constructSystemPrompt({
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
chatMode: settings.selectedChatMode,
chatMode:
settings.selectedChatMode === "agent"
? "build"
: settings.selectedChatMode,
});
// Add information about mentioned apps if any
@@ -603,19 +641,21 @@ This conversation includes one or more image attachments. When the user uploads
] as const)
: [];
const limitedHistoryChatMessages = limitedMessageHistory.map((msg) => ({
role: msg.role as "user" | "assistant" | "system",
// Why remove thinking tags?
// Thinking tags are generally not critical for the context
// and eats up extra tokens.
content:
settings.selectedChatMode === "ask"
? removeDyadTags(removeNonEssentialTags(msg.content))
: removeNonEssentialTags(msg.content),
}));
let chatMessages: ModelMessage[] = [
...codebasePrefix,
...otherCodebasePrefix,
...limitedMessageHistory.map((msg) => ({
role: msg.role as "user" | "assistant" | "system",
// Why remove thinking tags?
// Thinking tags are generally not critical for the context
// and eats up extra tokens.
content:
settings.selectedChatMode === "ask"
? removeDyadTags(removeNonEssentialTags(msg.content))
: removeNonEssentialTags(msg.content),
})),
...limitedHistoryChatMessages,
];
// Check if the last message should include attachments
@@ -654,9 +694,15 @@ This conversation includes one or more image attachments. When the user uploads
const simpleStreamText = async ({
chatMessages,
modelClient,
tools,
systemPromptOverride = systemPrompt,
dyadDisableFiles = false,
}: {
chatMessages: ModelMessage[];
modelClient: ModelClient;
tools?: ToolSet;
systemPromptOverride?: string;
dyadDisableFiles?: boolean;
}) => {
const dyadRequestId = uuidv4();
if (isEngineEnabled) {
@@ -671,6 +717,7 @@ This conversation includes one or more image attachments. When the user uploads
const providerOptions: Record<string, any> = {
"dyad-engine": {
dyadRequestId,
dyadDisableFiles,
},
"dyad-gateway": getExtraProviderOptions(
modelClient.builtinProviderId,
@@ -708,6 +755,7 @@ This conversation includes one or more image attachments. When the user uploads
},
} satisfies GoogleGenerativeAIProviderOptions;
}
return streamText({
headers: isAnthropic
? {
@@ -718,8 +766,10 @@ This conversation includes one or more image attachments. When the user uploads
temperature: await getTemperature(settings.selectedModel),
maxRetries: 2,
model: modelClient.model,
stopWhen: stepCountIs(3),
providerOptions,
system: systemPrompt,
system: systemPromptOverride,
tools,
messages: chatMessages.filter((m) => m.content),
onError: (error: any) => {
logger.error("Error streaming text:", error);
@@ -780,6 +830,38 @@ This conversation includes one or more image attachments. When the user uploads
return fullResponse;
};
if (settings.selectedChatMode === "agent") {
const tools = await getMcpTools(event);
const { fullStream } = await simpleStreamText({
chatMessages: limitedHistoryChatMessages,
modelClient,
tools,
systemPromptOverride: constructSystemPrompt({
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
chatMode: "agent",
}),
dyadDisableFiles: true,
});
const result = await processStreamChunks({
fullStream,
fullResponse,
abortController,
chatId: req.chatId,
processResponseChunkUpdate,
});
fullResponse = result.fullResponse;
chatMessages.push({
role: "assistant",
content: fullResponse,
});
chatMessages.push({
role: "user",
content: "OK.",
});
}
// When calling streamText, the messages need to be properly formatted for mixed content
const { fullStream } = await simpleStreamText({
chatMessages,
@@ -1316,3 +1398,48 @@ These are the other apps that I've mentioned in my prompt. These other apps' cod
${otherAppsCodebaseInfo}
`;
}
async function getMcpTools(event: IpcMainInvokeEvent): Promise<ToolSet> {
const mcpToolSet: ToolSet = {};
try {
const servers = await db
.select()
.from(mcpServers)
.where(eq(mcpServers.enabled, true as any));
for (const s of servers) {
const client = await mcpManager.getClient(s.id);
const toolSet = await client.tools();
for (const [name, tool] of Object.entries(toolSet)) {
const key = `${String(s.name || "").replace(/[^a-zA-Z0-9_-]/g, "-")}__${String(name).replace(/[^a-zA-Z0-9_-]/g, "-")}`;
const original = tool;
mcpToolSet[key] = {
description: original?.description,
inputSchema: original?.inputSchema,
execute: async (args: any, execCtx: any) => {
const inputPreview =
typeof args === "string"
? args
: Array.isArray(args)
? args.join(" ")
: JSON.stringify(args).slice(0, 500);
const ok = await requireMcpToolConsent(event, {
serverId: s.id,
serverName: s.name,
toolName: name,
toolDescription: original?.description,
inputPreview,
});
if (!ok) throw new Error(`User declined running tool ${key}`);
const res = await original.execute?.(args, execCtx);
return typeof res === "string" ? res : JSON.stringify(res);
},
};
}
}
} catch (e) {
logger.warn("Failed building MCP toolset", e);
}
return mcpToolSet;
}

View File

@@ -0,0 +1,163 @@
import { IpcMainInvokeEvent } from "electron";
import log from "electron-log";
import { db } from "../../db";
import { mcpServers, mcpToolConsents } from "../../db/schema";
import { eq, and } from "drizzle-orm";
import { createLoggedHandler } from "./safe_handle";
import { resolveConsent } from "../utils/mcp_consent";
import { getStoredConsent } from "../utils/mcp_consent";
import { mcpManager } from "../utils/mcp_manager";
import { CreateMcpServer, McpServerUpdate, McpTool } from "../ipc_types";
const logger = log.scope("mcp_handlers");
const handle = createLoggedHandler(logger);
type ConsentDecision = "accept-once" | "accept-always" | "decline";
export function registerMcpHandlers() {
// CRUD for MCP servers
handle("mcp:list-servers", async () => {
return await db.select().from(mcpServers);
});
handle(
"mcp:create-server",
async (_event: IpcMainInvokeEvent, params: CreateMcpServer) => {
const { name, transport, command, args, envJson, url, enabled } = params;
const result = await db
.insert(mcpServers)
.values({
name,
transport,
command: command || null,
args: args || null,
envJson: envJson || null,
url: url || null,
enabled: !!enabled,
})
.returning();
return result[0];
},
);
handle(
"mcp:update-server",
async (_event: IpcMainInvokeEvent, params: McpServerUpdate) => {
const update: any = {};
if (params.name !== undefined) update.name = params.name;
if (params.transport !== undefined) update.transport = params.transport;
if (params.command !== undefined) update.command = params.command;
if (params.args !== undefined) update.args = params.args || null;
if (params.cwd !== undefined) update.cwd = params.cwd;
if (params.envJson !== undefined) update.envJson = params.envJson || null;
if (params.url !== undefined) update.url = params.url;
if (params.enabled !== undefined) update.enabled = !!params.enabled;
const result = await db
.update(mcpServers)
.set(update)
.where(eq(mcpServers.id, params.id))
.returning();
// If server config changed, dispose cached client to be recreated on next use
try {
mcpManager.dispose(params.id);
} catch {}
return result[0];
},
);
handle(
"mcp:delete-server",
async (_event: IpcMainInvokeEvent, id: number) => {
try {
mcpManager.dispose(id);
} catch {}
await db.delete(mcpServers).where(eq(mcpServers.id, id));
return { success: true };
},
);
// Tools listing (dynamic)
handle(
"mcp:list-tools",
async (
_event: IpcMainInvokeEvent,
serverId: number,
): Promise<McpTool[]> => {
try {
const client = await mcpManager.getClient(serverId);
const remoteTools = await client.tools();
const tools = await Promise.all(
Object.entries(remoteTools).map(async ([name, tool]) => ({
name,
description: tool.description ?? null,
consent: await getStoredConsent(serverId, name),
})),
);
return tools;
} catch (e) {
logger.error("Failed to list tools", e);
return [];
}
},
);
// Consents
handle("mcp:get-tool-consents", async () => {
return await db.select().from(mcpToolConsents);
});
handle(
"mcp:set-tool-consent",
async (
_event: IpcMainInvokeEvent,
params: {
serverId: number;
toolName: string;
consent: "ask" | "always" | "denied";
},
) => {
const existing = await db
.select()
.from(mcpToolConsents)
.where(
and(
eq(mcpToolConsents.serverId, params.serverId),
eq(mcpToolConsents.toolName, params.toolName),
),
);
if (existing.length > 0) {
const result = await db
.update(mcpToolConsents)
.set({ consent: params.consent })
.where(
and(
eq(mcpToolConsents.serverId, params.serverId),
eq(mcpToolConsents.toolName, params.toolName),
),
)
.returning();
return result[0];
} else {
const result = await db
.insert(mcpToolConsents)
.values({
serverId: params.serverId,
toolName: params.toolName,
consent: params.consent,
})
.returning();
return result[0];
}
},
);
// Tool consent request/response handshake
// Receive consent response from renderer
handle(
"mcp:tool-consent-response",
async (_event, data: { requestId: string; decision: ConsentDecision }) => {
resolveConsent(data.requestId, data.decision);
},
);
}

View File

@@ -63,6 +63,8 @@ import type {
PromptDto,
CreatePromptParamsDto,
UpdatePromptParamsDto,
McpServerUpdate,
CreateMcpServer,
} from "./ipc_types";
import type { Template } from "../shared/templates";
import type {
@@ -119,11 +121,13 @@ export class IpcClient {
onError: (error: string) => void;
}
>;
private mcpConsentHandlers: Map<string, (payload: any) => void>;
private constructor() {
this.ipcRenderer = (window as any).electron.ipcRenderer as IpcRenderer;
this.chatStreams = new Map();
this.appStreams = new Map();
this.helpStreams = new Map();
this.mcpConsentHandlers = new Map();
// Set up listeners for stream events
this.ipcRenderer.on("chat:response:chunk", (data) => {
if (
@@ -238,6 +242,12 @@ export class IpcClient {
this.helpStreams.delete(sessionId);
}
});
// MCP tool consent request from main
this.ipcRenderer.on("mcp:tool-consent-request", (payload) => {
const handler = this.mcpConsentHandlers.get("consent");
if (handler) handler(payload);
});
}
public static getInstance(): IpcClient {
@@ -814,6 +824,67 @@ export class IpcClient {
return result.version as string;
}
// --- MCP Client Methods ---
public async listMcpServers() {
return this.ipcRenderer.invoke("mcp:list-servers");
}
public async createMcpServer(params: CreateMcpServer) {
return this.ipcRenderer.invoke("mcp:create-server", params);
}
public async updateMcpServer(params: McpServerUpdate) {
return this.ipcRenderer.invoke("mcp:update-server", params);
}
public async deleteMcpServer(id: number) {
return this.ipcRenderer.invoke("mcp:delete-server", id);
}
public async listMcpTools(serverId: number) {
return this.ipcRenderer.invoke("mcp:list-tools", serverId);
}
// Removed: upsertMcpTools and setMcpToolActive tools are fetched dynamically at runtime
public async getMcpToolConsents() {
return this.ipcRenderer.invoke("mcp:get-tool-consents");
}
public async setMcpToolConsent(params: {
serverId: number;
toolName: string;
consent: "ask" | "always" | "denied";
}) {
return this.ipcRenderer.invoke("mcp:set-tool-consent", params);
}
public onMcpToolConsentRequest(
handler: (payload: {
requestId: string;
serverId: number;
serverName: string;
toolName: string;
toolDescription?: string | null;
inputPreview?: string | null;
}) => void,
) {
this.mcpConsentHandlers.set("consent", handler as any);
return () => {
this.mcpConsentHandlers.delete("consent");
};
}
public respondToMcpConsentRequest(
requestId: string,
decision: "accept-once" | "accept-always" | "decline",
) {
this.ipcRenderer.invoke("mcp:tool-consent-response", {
requestId,
decision,
});
}
// Get proposal details
public async getProposal(chatId: number): Promise<ProposalResult | null> {
try {

View File

@@ -30,6 +30,7 @@ import { registerTemplateHandlers } from "./handlers/template_handlers";
import { registerPortalHandlers } from "./handlers/portal_handlers";
import { registerPromptHandlers } from "./handlers/prompt_handlers";
import { registerHelpBotHandlers } from "./handlers/help_bot_handlers";
import { registerMcpHandlers } from "./handlers/mcp_handlers";
export function registerIpcHandlers() {
// Register all IPC handlers by category
@@ -65,4 +66,5 @@ export function registerIpcHandlers() {
registerPortalHandlers();
registerPromptHandlers();
registerHelpBotHandlers();
registerMcpHandlers();
}

View File

@@ -450,3 +450,37 @@ export interface HelpChatResponseError {
sessionId: string;
error: string;
}
// --- MCP Types ---
export interface McpServer {
id: number;
name: string;
transport: string;
command?: string | null;
args?: string[] | null;
cwd?: string | null;
envJson?: Record<string, string> | null;
url?: string | null;
enabled: boolean;
createdAt: number;
updatedAt: number;
}
export interface CreateMcpServer
extends Omit<McpServer, "id" | "createdAt" | "updatedAt"> {}
export type McpServerUpdate = Partial<McpServer> & Pick<McpServer, "id">;
export type McpToolConsentType = "ask" | "always" | "denied";
export interface McpTool {
name: string;
description?: string | null;
consent: McpToolConsentType;
}
export interface McpToolConsent {
id: number;
serverId: number;
toolName: string;
consent: McpToolConsentType;
updatedAt: number;
}

View File

@@ -137,6 +137,10 @@ export function createDyadEngine(
if ("dyadRequestId" in parsedBody) {
delete parsedBody.dyadRequestId;
}
const dyadDisableFiles = parsedBody.dyadDisableFiles;
if ("dyadDisableFiles" in parsedBody) {
delete parsedBody.dyadDisableFiles;
}
// Track and modify requestId with attempt number
let modifiedRequestId = requestId;
@@ -147,7 +151,7 @@ export function createDyadEngine(
}
// Add files to the request if they exist
if (files?.length) {
if (files?.length && !dyadDisableFiles) {
parsedBody.dyad_options = {
files,
enable_lazy_edits: options.dyadOptions.enableLazyEdits,

View File

@@ -0,0 +1,108 @@
import { db } from "../../db";
import { mcpToolConsents } from "../../db/schema";
import { and, eq } from "drizzle-orm";
import { IpcMainInvokeEvent } from "electron";
export type Consent = "ask" | "always" | "denied";
const pendingConsentResolvers = new Map<
string,
(d: "accept-once" | "accept-always" | "decline") => void
>();
export function waitForConsent(
requestId: string,
): Promise<"accept-once" | "accept-always" | "decline"> {
return new Promise((resolve) => {
pendingConsentResolvers.set(requestId, resolve);
});
}
export function resolveConsent(
requestId: string,
decision: "accept-once" | "accept-always" | "decline",
) {
const resolver = pendingConsentResolvers.get(requestId);
if (resolver) {
pendingConsentResolvers.delete(requestId);
resolver(decision);
}
}
export async function getStoredConsent(
serverId: number,
toolName: string,
): Promise<Consent> {
const rows = await db
.select()
.from(mcpToolConsents)
.where(
and(
eq(mcpToolConsents.serverId, serverId),
eq(mcpToolConsents.toolName, toolName),
),
);
if (rows.length === 0) return "ask";
return (rows[0].consent as Consent) ?? "ask";
}
export async function setStoredConsent(
serverId: number,
toolName: string,
consent: Consent,
): Promise<void> {
const rows = await db
.select()
.from(mcpToolConsents)
.where(
and(
eq(mcpToolConsents.serverId, serverId),
eq(mcpToolConsents.toolName, toolName),
),
);
if (rows.length > 0) {
await db
.update(mcpToolConsents)
.set({ consent })
.where(
and(
eq(mcpToolConsents.serverId, serverId),
eq(mcpToolConsents.toolName, toolName),
),
);
} else {
await db.insert(mcpToolConsents).values({ serverId, toolName, consent });
}
}
export async function requireMcpToolConsent(
event: IpcMainInvokeEvent,
params: {
serverId: number;
serverName: string;
toolName: string;
toolDescription?: string | null;
inputPreview?: string | null;
},
): Promise<boolean> {
const current = await getStoredConsent(params.serverId, params.toolName);
if (current === "always") return true;
if (current === "denied") return false;
// Ask renderer for a decision via event bridge
const requestId = `${params.serverId}:${params.toolName}:${Date.now()}`;
(event.sender as any).send("mcp:tool-consent-request", {
requestId,
...params,
});
const response = await waitForConsent(requestId);
if (response === "accept-always") {
await setStoredConsent(params.serverId, params.toolName, "always");
return true;
}
if (response === "decline") {
return false;
}
return response === "accept-once";
}

View File

@@ -0,0 +1,59 @@
import { db } from "../../db";
import { mcpServers } from "../../db/schema";
import { experimental_createMCPClient, experimental_MCPClient } from "ai";
import { eq } from "drizzle-orm";
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
class McpManager {
private static _instance: McpManager;
static get instance(): McpManager {
if (!this._instance) this._instance = new McpManager();
return this._instance;
}
private clients = new Map<number, experimental_MCPClient>();
async getClient(serverId: number): Promise<experimental_MCPClient> {
const existing = this.clients.get(serverId);
if (existing) return existing;
const server = await db
.select()
.from(mcpServers)
.where(eq(mcpServers.id, serverId));
const s = server.find((x) => x.id === serverId);
if (!s) throw new Error(`MCP server not found: ${serverId}`);
let transport: StdioClientTransport | StreamableHTTPClientTransport;
if (s.transport === "stdio") {
const args = s.args ?? [];
const env = s.envJson ?? undefined;
if (!s.command) throw new Error("MCP server command is required");
transport = new StdioClientTransport({
command: s.command,
args,
env,
});
} else if (s.transport === "http") {
if (!s.url) throw new Error("HTTP MCP requires url");
transport = new StreamableHTTPClientTransport(new URL(s.url as string));
} else {
throw new Error(`Unsupported MCP transport: ${s.transport}`);
}
const client = await experimental_createMCPClient({
transport,
});
this.clients.set(serverId, client);
return client;
}
dispose(serverId: number) {
const c = this.clients.get(serverId);
if (c) {
c.close();
this.clients.delete(serverId);
}
}
}
export const mcpManager = McpManager.instance;