Add MCP support (#1028)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { ipcMain } from "electron";
|
||||
import { ipcMain, IpcMainInvokeEvent } from "electron";
|
||||
import {
|
||||
ModelMessage,
|
||||
TextPart,
|
||||
@@ -7,7 +7,9 @@ import {
|
||||
streamText,
|
||||
ToolSet,
|
||||
TextStreamPart,
|
||||
stepCountIs,
|
||||
} from "ai";
|
||||
|
||||
import { db } from "../../db";
|
||||
import { chats, messages } from "../../db/schema";
|
||||
import { and, eq, isNull } from "drizzle-orm";
|
||||
@@ -42,6 +44,8 @@ import { getMaxTokens, getTemperature } from "../utils/token_utils";
|
||||
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
|
||||
import { validateChatContext } from "../utils/context_paths_utils";
|
||||
import { GoogleGenerativeAIProviderOptions } from "@ai-sdk/google";
|
||||
import { mcpServers } from "../../db/schema";
|
||||
import { requireMcpToolConsent } from "../utils/mcp_consent";
|
||||
|
||||
import { getExtraProviderOptions } from "../utils/thinking_utils";
|
||||
|
||||
@@ -64,6 +68,7 @@ import { parseAppMentions } from "@/shared/parse_mention_apps";
|
||||
import { prompts as promptsTable } from "../../db/schema";
|
||||
import { inArray } from "drizzle-orm";
|
||||
import { replacePromptReference } from "../utils/replacePromptReference";
|
||||
import { mcpManager } from "../utils/mcp_manager";
|
||||
|
||||
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
|
||||
|
||||
@@ -103,6 +108,23 @@ function escapeXml(unsafe: string): string {
|
||||
.replace(/"/g, """);
|
||||
}
|
||||
|
||||
// Safely parse an MCP tool key that combines server and tool names.
|
||||
// We split on the LAST occurrence of "__" to avoid ambiguity if either
|
||||
// side contains "__" as part of its sanitized name.
|
||||
function parseMcpToolKey(toolKey: string): {
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
} {
|
||||
const separator = "__";
|
||||
const lastIndex = toolKey.lastIndexOf(separator);
|
||||
if (lastIndex === -1) {
|
||||
return { serverName: "", toolName: toolKey };
|
||||
}
|
||||
const serverName = toolKey.slice(0, lastIndex);
|
||||
const toolName = toolKey.slice(lastIndex + separator.length);
|
||||
return { serverName, toolName };
|
||||
}
|
||||
|
||||
// Ensure the temp directory exists
|
||||
if (!fs.existsSync(TEMP_DIR)) {
|
||||
fs.mkdirSync(TEMP_DIR, { recursive: true });
|
||||
@@ -129,11 +151,16 @@ async function processStreamChunks({
|
||||
|
||||
for await (const part of fullStream) {
|
||||
let chunk = "";
|
||||
if (
|
||||
inThinkingBlock &&
|
||||
!["reasoning-delta", "reasoning-end", "reasoning-start"].includes(
|
||||
part.type,
|
||||
)
|
||||
) {
|
||||
chunk = "</think>";
|
||||
inThinkingBlock = false;
|
||||
}
|
||||
if (part.type === "text-delta") {
|
||||
if (inThinkingBlock) {
|
||||
chunk = "</think>";
|
||||
inThinkingBlock = false;
|
||||
}
|
||||
chunk += part.text;
|
||||
} else if (part.type === "reasoning-delta") {
|
||||
if (!inThinkingBlock) {
|
||||
@@ -142,6 +169,14 @@ async function processStreamChunks({
|
||||
}
|
||||
|
||||
chunk += escapeDyadTags(part.text);
|
||||
} else if (part.type === "tool-call") {
|
||||
const { serverName, toolName } = parseMcpToolKey(part.toolName);
|
||||
const content = escapeDyadTags(JSON.stringify(part.input));
|
||||
chunk = `<dyad-mcp-tool-call server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-call>\n`;
|
||||
} else if (part.type === "tool-result") {
|
||||
const { serverName, toolName } = parseMcpToolKey(part.toolName);
|
||||
const content = escapeDyadTags(part.output);
|
||||
chunk = `<dyad-mcp-tool-result server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-result>\n`;
|
||||
}
|
||||
|
||||
if (!chunk) {
|
||||
@@ -496,7 +531,10 @@ ${componentSnippet}
|
||||
|
||||
let systemPrompt = constructSystemPrompt({
|
||||
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
|
||||
chatMode: settings.selectedChatMode,
|
||||
chatMode:
|
||||
settings.selectedChatMode === "agent"
|
||||
? "build"
|
||||
: settings.selectedChatMode,
|
||||
});
|
||||
|
||||
// Add information about mentioned apps if any
|
||||
@@ -603,19 +641,21 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
] as const)
|
||||
: [];
|
||||
|
||||
const limitedHistoryChatMessages = limitedMessageHistory.map((msg) => ({
|
||||
role: msg.role as "user" | "assistant" | "system",
|
||||
// Why remove thinking tags?
|
||||
// Thinking tags are generally not critical for the context
|
||||
// and eats up extra tokens.
|
||||
content:
|
||||
settings.selectedChatMode === "ask"
|
||||
? removeDyadTags(removeNonEssentialTags(msg.content))
|
||||
: removeNonEssentialTags(msg.content),
|
||||
}));
|
||||
|
||||
let chatMessages: ModelMessage[] = [
|
||||
...codebasePrefix,
|
||||
...otherCodebasePrefix,
|
||||
...limitedMessageHistory.map((msg) => ({
|
||||
role: msg.role as "user" | "assistant" | "system",
|
||||
// Why remove thinking tags?
|
||||
// Thinking tags are generally not critical for the context
|
||||
// and eats up extra tokens.
|
||||
content:
|
||||
settings.selectedChatMode === "ask"
|
||||
? removeDyadTags(removeNonEssentialTags(msg.content))
|
||||
: removeNonEssentialTags(msg.content),
|
||||
})),
|
||||
...limitedHistoryChatMessages,
|
||||
];
|
||||
|
||||
// Check if the last message should include attachments
|
||||
@@ -654,9 +694,15 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
const simpleStreamText = async ({
|
||||
chatMessages,
|
||||
modelClient,
|
||||
tools,
|
||||
systemPromptOverride = systemPrompt,
|
||||
dyadDisableFiles = false,
|
||||
}: {
|
||||
chatMessages: ModelMessage[];
|
||||
modelClient: ModelClient;
|
||||
tools?: ToolSet;
|
||||
systemPromptOverride?: string;
|
||||
dyadDisableFiles?: boolean;
|
||||
}) => {
|
||||
const dyadRequestId = uuidv4();
|
||||
if (isEngineEnabled) {
|
||||
@@ -671,6 +717,7 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
const providerOptions: Record<string, any> = {
|
||||
"dyad-engine": {
|
||||
dyadRequestId,
|
||||
dyadDisableFiles,
|
||||
},
|
||||
"dyad-gateway": getExtraProviderOptions(
|
||||
modelClient.builtinProviderId,
|
||||
@@ -708,6 +755,7 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
},
|
||||
} satisfies GoogleGenerativeAIProviderOptions;
|
||||
}
|
||||
|
||||
return streamText({
|
||||
headers: isAnthropic
|
||||
? {
|
||||
@@ -718,8 +766,10 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
temperature: await getTemperature(settings.selectedModel),
|
||||
maxRetries: 2,
|
||||
model: modelClient.model,
|
||||
stopWhen: stepCountIs(3),
|
||||
providerOptions,
|
||||
system: systemPrompt,
|
||||
system: systemPromptOverride,
|
||||
tools,
|
||||
messages: chatMessages.filter((m) => m.content),
|
||||
onError: (error: any) => {
|
||||
logger.error("Error streaming text:", error);
|
||||
@@ -780,6 +830,38 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
return fullResponse;
|
||||
};
|
||||
|
||||
if (settings.selectedChatMode === "agent") {
|
||||
const tools = await getMcpTools(event);
|
||||
|
||||
const { fullStream } = await simpleStreamText({
|
||||
chatMessages: limitedHistoryChatMessages,
|
||||
modelClient,
|
||||
tools,
|
||||
systemPromptOverride: constructSystemPrompt({
|
||||
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
|
||||
chatMode: "agent",
|
||||
}),
|
||||
dyadDisableFiles: true,
|
||||
});
|
||||
|
||||
const result = await processStreamChunks({
|
||||
fullStream,
|
||||
fullResponse,
|
||||
abortController,
|
||||
chatId: req.chatId,
|
||||
processResponseChunkUpdate,
|
||||
});
|
||||
fullResponse = result.fullResponse;
|
||||
chatMessages.push({
|
||||
role: "assistant",
|
||||
content: fullResponse,
|
||||
});
|
||||
chatMessages.push({
|
||||
role: "user",
|
||||
content: "OK.",
|
||||
});
|
||||
}
|
||||
|
||||
// When calling streamText, the messages need to be properly formatted for mixed content
|
||||
const { fullStream } = await simpleStreamText({
|
||||
chatMessages,
|
||||
@@ -1316,3 +1398,48 @@ These are the other apps that I've mentioned in my prompt. These other apps' cod
|
||||
${otherAppsCodebaseInfo}
|
||||
`;
|
||||
}
|
||||
|
||||
async function getMcpTools(event: IpcMainInvokeEvent): Promise<ToolSet> {
|
||||
const mcpToolSet: ToolSet = {};
|
||||
try {
|
||||
const servers = await db
|
||||
.select()
|
||||
.from(mcpServers)
|
||||
.where(eq(mcpServers.enabled, true as any));
|
||||
for (const s of servers) {
|
||||
const client = await mcpManager.getClient(s.id);
|
||||
const toolSet = await client.tools();
|
||||
for (const [name, tool] of Object.entries(toolSet)) {
|
||||
const key = `${String(s.name || "").replace(/[^a-zA-Z0-9_-]/g, "-")}__${String(name).replace(/[^a-zA-Z0-9_-]/g, "-")}`;
|
||||
const original = tool;
|
||||
mcpToolSet[key] = {
|
||||
description: original?.description,
|
||||
inputSchema: original?.inputSchema,
|
||||
execute: async (args: any, execCtx: any) => {
|
||||
const inputPreview =
|
||||
typeof args === "string"
|
||||
? args
|
||||
: Array.isArray(args)
|
||||
? args.join(" ")
|
||||
: JSON.stringify(args).slice(0, 500);
|
||||
const ok = await requireMcpToolConsent(event, {
|
||||
serverId: s.id,
|
||||
serverName: s.name,
|
||||
toolName: name,
|
||||
toolDescription: original?.description,
|
||||
inputPreview,
|
||||
});
|
||||
|
||||
if (!ok) throw new Error(`User declined running tool ${key}`);
|
||||
const res = await original.execute?.(args, execCtx);
|
||||
|
||||
return typeof res === "string" ? res : JSON.stringify(res);
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn("Failed building MCP toolset", e);
|
||||
}
|
||||
return mcpToolSet;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user