Add MCP support (#1028)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { ipcMain } from "electron";
|
||||
import { ipcMain, IpcMainInvokeEvent } from "electron";
|
||||
import {
|
||||
ModelMessage,
|
||||
TextPart,
|
||||
@@ -7,7 +7,9 @@ import {
|
||||
streamText,
|
||||
ToolSet,
|
||||
TextStreamPart,
|
||||
stepCountIs,
|
||||
} from "ai";
|
||||
|
||||
import { db } from "../../db";
|
||||
import { chats, messages } from "../../db/schema";
|
||||
import { and, eq, isNull } from "drizzle-orm";
|
||||
@@ -42,6 +44,8 @@ import { getMaxTokens, getTemperature } from "../utils/token_utils";
|
||||
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
|
||||
import { validateChatContext } from "../utils/context_paths_utils";
|
||||
import { GoogleGenerativeAIProviderOptions } from "@ai-sdk/google";
|
||||
import { mcpServers } from "../../db/schema";
|
||||
import { requireMcpToolConsent } from "../utils/mcp_consent";
|
||||
|
||||
import { getExtraProviderOptions } from "../utils/thinking_utils";
|
||||
|
||||
@@ -64,6 +68,7 @@ import { parseAppMentions } from "@/shared/parse_mention_apps";
|
||||
import { prompts as promptsTable } from "../../db/schema";
|
||||
import { inArray } from "drizzle-orm";
|
||||
import { replacePromptReference } from "../utils/replacePromptReference";
|
||||
import { mcpManager } from "../utils/mcp_manager";
|
||||
|
||||
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
|
||||
|
||||
@@ -103,6 +108,23 @@ function escapeXml(unsafe: string): string {
|
||||
.replace(/"/g, """);
|
||||
}
|
||||
|
||||
// Safely parse an MCP tool key that combines server and tool names.
|
||||
// We split on the LAST occurrence of "__" to avoid ambiguity if either
|
||||
// side contains "__" as part of its sanitized name.
|
||||
function parseMcpToolKey(toolKey: string): {
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
} {
|
||||
const separator = "__";
|
||||
const lastIndex = toolKey.lastIndexOf(separator);
|
||||
if (lastIndex === -1) {
|
||||
return { serverName: "", toolName: toolKey };
|
||||
}
|
||||
const serverName = toolKey.slice(0, lastIndex);
|
||||
const toolName = toolKey.slice(lastIndex + separator.length);
|
||||
return { serverName, toolName };
|
||||
}
|
||||
|
||||
// Ensure the temp directory exists
|
||||
if (!fs.existsSync(TEMP_DIR)) {
|
||||
fs.mkdirSync(TEMP_DIR, { recursive: true });
|
||||
@@ -129,11 +151,16 @@ async function processStreamChunks({
|
||||
|
||||
for await (const part of fullStream) {
|
||||
let chunk = "";
|
||||
if (
|
||||
inThinkingBlock &&
|
||||
!["reasoning-delta", "reasoning-end", "reasoning-start"].includes(
|
||||
part.type,
|
||||
)
|
||||
) {
|
||||
chunk = "</think>";
|
||||
inThinkingBlock = false;
|
||||
}
|
||||
if (part.type === "text-delta") {
|
||||
if (inThinkingBlock) {
|
||||
chunk = "</think>";
|
||||
inThinkingBlock = false;
|
||||
}
|
||||
chunk += part.text;
|
||||
} else if (part.type === "reasoning-delta") {
|
||||
if (!inThinkingBlock) {
|
||||
@@ -142,6 +169,14 @@ async function processStreamChunks({
|
||||
}
|
||||
|
||||
chunk += escapeDyadTags(part.text);
|
||||
} else if (part.type === "tool-call") {
|
||||
const { serverName, toolName } = parseMcpToolKey(part.toolName);
|
||||
const content = escapeDyadTags(JSON.stringify(part.input));
|
||||
chunk = `<dyad-mcp-tool-call server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-call>\n`;
|
||||
} else if (part.type === "tool-result") {
|
||||
const { serverName, toolName } = parseMcpToolKey(part.toolName);
|
||||
const content = escapeDyadTags(part.output);
|
||||
chunk = `<dyad-mcp-tool-result server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-result>\n`;
|
||||
}
|
||||
|
||||
if (!chunk) {
|
||||
@@ -496,7 +531,10 @@ ${componentSnippet}
|
||||
|
||||
let systemPrompt = constructSystemPrompt({
|
||||
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
|
||||
chatMode: settings.selectedChatMode,
|
||||
chatMode:
|
||||
settings.selectedChatMode === "agent"
|
||||
? "build"
|
||||
: settings.selectedChatMode,
|
||||
});
|
||||
|
||||
// Add information about mentioned apps if any
|
||||
@@ -603,19 +641,21 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
] as const)
|
||||
: [];
|
||||
|
||||
const limitedHistoryChatMessages = limitedMessageHistory.map((msg) => ({
|
||||
role: msg.role as "user" | "assistant" | "system",
|
||||
// Why remove thinking tags?
|
||||
// Thinking tags are generally not critical for the context
|
||||
// and eats up extra tokens.
|
||||
content:
|
||||
settings.selectedChatMode === "ask"
|
||||
? removeDyadTags(removeNonEssentialTags(msg.content))
|
||||
: removeNonEssentialTags(msg.content),
|
||||
}));
|
||||
|
||||
let chatMessages: ModelMessage[] = [
|
||||
...codebasePrefix,
|
||||
...otherCodebasePrefix,
|
||||
...limitedMessageHistory.map((msg) => ({
|
||||
role: msg.role as "user" | "assistant" | "system",
|
||||
// Why remove thinking tags?
|
||||
// Thinking tags are generally not critical for the context
|
||||
// and eats up extra tokens.
|
||||
content:
|
||||
settings.selectedChatMode === "ask"
|
||||
? removeDyadTags(removeNonEssentialTags(msg.content))
|
||||
: removeNonEssentialTags(msg.content),
|
||||
})),
|
||||
...limitedHistoryChatMessages,
|
||||
];
|
||||
|
||||
// Check if the last message should include attachments
|
||||
@@ -654,9 +694,15 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
const simpleStreamText = async ({
|
||||
chatMessages,
|
||||
modelClient,
|
||||
tools,
|
||||
systemPromptOverride = systemPrompt,
|
||||
dyadDisableFiles = false,
|
||||
}: {
|
||||
chatMessages: ModelMessage[];
|
||||
modelClient: ModelClient;
|
||||
tools?: ToolSet;
|
||||
systemPromptOverride?: string;
|
||||
dyadDisableFiles?: boolean;
|
||||
}) => {
|
||||
const dyadRequestId = uuidv4();
|
||||
if (isEngineEnabled) {
|
||||
@@ -671,6 +717,7 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
const providerOptions: Record<string, any> = {
|
||||
"dyad-engine": {
|
||||
dyadRequestId,
|
||||
dyadDisableFiles,
|
||||
},
|
||||
"dyad-gateway": getExtraProviderOptions(
|
||||
modelClient.builtinProviderId,
|
||||
@@ -708,6 +755,7 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
},
|
||||
} satisfies GoogleGenerativeAIProviderOptions;
|
||||
}
|
||||
|
||||
return streamText({
|
||||
headers: isAnthropic
|
||||
? {
|
||||
@@ -718,8 +766,10 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
temperature: await getTemperature(settings.selectedModel),
|
||||
maxRetries: 2,
|
||||
model: modelClient.model,
|
||||
stopWhen: stepCountIs(3),
|
||||
providerOptions,
|
||||
system: systemPrompt,
|
||||
system: systemPromptOverride,
|
||||
tools,
|
||||
messages: chatMessages.filter((m) => m.content),
|
||||
onError: (error: any) => {
|
||||
logger.error("Error streaming text:", error);
|
||||
@@ -780,6 +830,38 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
return fullResponse;
|
||||
};
|
||||
|
||||
if (settings.selectedChatMode === "agent") {
|
||||
const tools = await getMcpTools(event);
|
||||
|
||||
const { fullStream } = await simpleStreamText({
|
||||
chatMessages: limitedHistoryChatMessages,
|
||||
modelClient,
|
||||
tools,
|
||||
systemPromptOverride: constructSystemPrompt({
|
||||
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
|
||||
chatMode: "agent",
|
||||
}),
|
||||
dyadDisableFiles: true,
|
||||
});
|
||||
|
||||
const result = await processStreamChunks({
|
||||
fullStream,
|
||||
fullResponse,
|
||||
abortController,
|
||||
chatId: req.chatId,
|
||||
processResponseChunkUpdate,
|
||||
});
|
||||
fullResponse = result.fullResponse;
|
||||
chatMessages.push({
|
||||
role: "assistant",
|
||||
content: fullResponse,
|
||||
});
|
||||
chatMessages.push({
|
||||
role: "user",
|
||||
content: "OK.",
|
||||
});
|
||||
}
|
||||
|
||||
// When calling streamText, the messages need to be properly formatted for mixed content
|
||||
const { fullStream } = await simpleStreamText({
|
||||
chatMessages,
|
||||
@@ -1316,3 +1398,48 @@ These are the other apps that I've mentioned in my prompt. These other apps' cod
|
||||
${otherAppsCodebaseInfo}
|
||||
`;
|
||||
}
|
||||
|
||||
async function getMcpTools(event: IpcMainInvokeEvent): Promise<ToolSet> {
|
||||
const mcpToolSet: ToolSet = {};
|
||||
try {
|
||||
const servers = await db
|
||||
.select()
|
||||
.from(mcpServers)
|
||||
.where(eq(mcpServers.enabled, true as any));
|
||||
for (const s of servers) {
|
||||
const client = await mcpManager.getClient(s.id);
|
||||
const toolSet = await client.tools();
|
||||
for (const [name, tool] of Object.entries(toolSet)) {
|
||||
const key = `${String(s.name || "").replace(/[^a-zA-Z0-9_-]/g, "-")}__${String(name).replace(/[^a-zA-Z0-9_-]/g, "-")}`;
|
||||
const original = tool;
|
||||
mcpToolSet[key] = {
|
||||
description: original?.description,
|
||||
inputSchema: original?.inputSchema,
|
||||
execute: async (args: any, execCtx: any) => {
|
||||
const inputPreview =
|
||||
typeof args === "string"
|
||||
? args
|
||||
: Array.isArray(args)
|
||||
? args.join(" ")
|
||||
: JSON.stringify(args).slice(0, 500);
|
||||
const ok = await requireMcpToolConsent(event, {
|
||||
serverId: s.id,
|
||||
serverName: s.name,
|
||||
toolName: name,
|
||||
toolDescription: original?.description,
|
||||
inputPreview,
|
||||
});
|
||||
|
||||
if (!ok) throw new Error(`User declined running tool ${key}`);
|
||||
const res = await original.execute?.(args, execCtx);
|
||||
|
||||
return typeof res === "string" ? res : JSON.stringify(res);
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn("Failed building MCP toolset", e);
|
||||
}
|
||||
return mcpToolSet;
|
||||
}
|
||||
|
||||
163
src/ipc/handlers/mcp_handlers.ts
Normal file
163
src/ipc/handlers/mcp_handlers.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
import { IpcMainInvokeEvent } from "electron";
|
||||
import log from "electron-log";
|
||||
import { db } from "../../db";
|
||||
import { mcpServers, mcpToolConsents } from "../../db/schema";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
|
||||
import { resolveConsent } from "../utils/mcp_consent";
|
||||
import { getStoredConsent } from "../utils/mcp_consent";
|
||||
import { mcpManager } from "../utils/mcp_manager";
|
||||
import { CreateMcpServer, McpServerUpdate, McpTool } from "../ipc_types";
|
||||
|
||||
const logger = log.scope("mcp_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
|
||||
type ConsentDecision = "accept-once" | "accept-always" | "decline";
|
||||
|
||||
export function registerMcpHandlers() {
|
||||
// CRUD for MCP servers
|
||||
handle("mcp:list-servers", async () => {
|
||||
return await db.select().from(mcpServers);
|
||||
});
|
||||
|
||||
handle(
|
||||
"mcp:create-server",
|
||||
async (_event: IpcMainInvokeEvent, params: CreateMcpServer) => {
|
||||
const { name, transport, command, args, envJson, url, enabled } = params;
|
||||
const result = await db
|
||||
.insert(mcpServers)
|
||||
.values({
|
||||
name,
|
||||
transport,
|
||||
command: command || null,
|
||||
args: args || null,
|
||||
envJson: envJson || null,
|
||||
url: url || null,
|
||||
enabled: !!enabled,
|
||||
})
|
||||
.returning();
|
||||
return result[0];
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"mcp:update-server",
|
||||
async (_event: IpcMainInvokeEvent, params: McpServerUpdate) => {
|
||||
const update: any = {};
|
||||
if (params.name !== undefined) update.name = params.name;
|
||||
if (params.transport !== undefined) update.transport = params.transport;
|
||||
if (params.command !== undefined) update.command = params.command;
|
||||
if (params.args !== undefined) update.args = params.args || null;
|
||||
if (params.cwd !== undefined) update.cwd = params.cwd;
|
||||
if (params.envJson !== undefined) update.envJson = params.envJson || null;
|
||||
if (params.url !== undefined) update.url = params.url;
|
||||
if (params.enabled !== undefined) update.enabled = !!params.enabled;
|
||||
|
||||
const result = await db
|
||||
.update(mcpServers)
|
||||
.set(update)
|
||||
.where(eq(mcpServers.id, params.id))
|
||||
.returning();
|
||||
// If server config changed, dispose cached client to be recreated on next use
|
||||
try {
|
||||
mcpManager.dispose(params.id);
|
||||
} catch {}
|
||||
return result[0];
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"mcp:delete-server",
|
||||
async (_event: IpcMainInvokeEvent, id: number) => {
|
||||
try {
|
||||
mcpManager.dispose(id);
|
||||
} catch {}
|
||||
await db.delete(mcpServers).where(eq(mcpServers.id, id));
|
||||
return { success: true };
|
||||
},
|
||||
);
|
||||
|
||||
// Tools listing (dynamic)
|
||||
handle(
|
||||
"mcp:list-tools",
|
||||
async (
|
||||
_event: IpcMainInvokeEvent,
|
||||
serverId: number,
|
||||
): Promise<McpTool[]> => {
|
||||
try {
|
||||
const client = await mcpManager.getClient(serverId);
|
||||
const remoteTools = await client.tools();
|
||||
const tools = await Promise.all(
|
||||
Object.entries(remoteTools).map(async ([name, tool]) => ({
|
||||
name,
|
||||
description: tool.description ?? null,
|
||||
consent: await getStoredConsent(serverId, name),
|
||||
})),
|
||||
);
|
||||
return tools;
|
||||
} catch (e) {
|
||||
logger.error("Failed to list tools", e);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
);
|
||||
// Consents
|
||||
handle("mcp:get-tool-consents", async () => {
|
||||
return await db.select().from(mcpToolConsents);
|
||||
});
|
||||
|
||||
handle(
|
||||
"mcp:set-tool-consent",
|
||||
async (
|
||||
_event: IpcMainInvokeEvent,
|
||||
params: {
|
||||
serverId: number;
|
||||
toolName: string;
|
||||
consent: "ask" | "always" | "denied";
|
||||
},
|
||||
) => {
|
||||
const existing = await db
|
||||
.select()
|
||||
.from(mcpToolConsents)
|
||||
.where(
|
||||
and(
|
||||
eq(mcpToolConsents.serverId, params.serverId),
|
||||
eq(mcpToolConsents.toolName, params.toolName),
|
||||
),
|
||||
);
|
||||
if (existing.length > 0) {
|
||||
const result = await db
|
||||
.update(mcpToolConsents)
|
||||
.set({ consent: params.consent })
|
||||
.where(
|
||||
and(
|
||||
eq(mcpToolConsents.serverId, params.serverId),
|
||||
eq(mcpToolConsents.toolName, params.toolName),
|
||||
),
|
||||
)
|
||||
.returning();
|
||||
return result[0];
|
||||
} else {
|
||||
const result = await db
|
||||
.insert(mcpToolConsents)
|
||||
.values({
|
||||
serverId: params.serverId,
|
||||
toolName: params.toolName,
|
||||
consent: params.consent,
|
||||
})
|
||||
.returning();
|
||||
return result[0];
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// Tool consent request/response handshake
|
||||
// Receive consent response from renderer
|
||||
handle(
|
||||
"mcp:tool-consent-response",
|
||||
async (_event, data: { requestId: string; decision: ConsentDecision }) => {
|
||||
resolveConsent(data.requestId, data.decision);
|
||||
},
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user