Add MCP support (#1028)
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
import { v4 as uuidv4 } from "uuid";
|
||||
import { ipcMain } from "electron";
|
||||
import { ipcMain, IpcMainInvokeEvent } from "electron";
|
||||
import {
|
||||
ModelMessage,
|
||||
TextPart,
|
||||
@@ -7,7 +7,9 @@ import {
|
||||
streamText,
|
||||
ToolSet,
|
||||
TextStreamPart,
|
||||
stepCountIs,
|
||||
} from "ai";
|
||||
|
||||
import { db } from "../../db";
|
||||
import { chats, messages } from "../../db/schema";
|
||||
import { and, eq, isNull } from "drizzle-orm";
|
||||
@@ -42,6 +44,8 @@ import { getMaxTokens, getTemperature } from "../utils/token_utils";
|
||||
import { MAX_CHAT_TURNS_IN_CONTEXT } from "@/constants/settings_constants";
|
||||
import { validateChatContext } from "../utils/context_paths_utils";
|
||||
import { GoogleGenerativeAIProviderOptions } from "@ai-sdk/google";
|
||||
import { mcpServers } from "../../db/schema";
|
||||
import { requireMcpToolConsent } from "../utils/mcp_consent";
|
||||
|
||||
import { getExtraProviderOptions } from "../utils/thinking_utils";
|
||||
|
||||
@@ -64,6 +68,7 @@ import { parseAppMentions } from "@/shared/parse_mention_apps";
|
||||
import { prompts as promptsTable } from "../../db/schema";
|
||||
import { inArray } from "drizzle-orm";
|
||||
import { replacePromptReference } from "../utils/replacePromptReference";
|
||||
import { mcpManager } from "../utils/mcp_manager";
|
||||
|
||||
type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
|
||||
|
||||
@@ -103,6 +108,23 @@ function escapeXml(unsafe: string): string {
|
||||
.replace(/"/g, """);
|
||||
}
|
||||
|
||||
// Safely parse an MCP tool key that combines server and tool names.
|
||||
// We split on the LAST occurrence of "__" to avoid ambiguity if either
|
||||
// side contains "__" as part of its sanitized name.
|
||||
function parseMcpToolKey(toolKey: string): {
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
} {
|
||||
const separator = "__";
|
||||
const lastIndex = toolKey.lastIndexOf(separator);
|
||||
if (lastIndex === -1) {
|
||||
return { serverName: "", toolName: toolKey };
|
||||
}
|
||||
const serverName = toolKey.slice(0, lastIndex);
|
||||
const toolName = toolKey.slice(lastIndex + separator.length);
|
||||
return { serverName, toolName };
|
||||
}
|
||||
|
||||
// Ensure the temp directory exists
|
||||
if (!fs.existsSync(TEMP_DIR)) {
|
||||
fs.mkdirSync(TEMP_DIR, { recursive: true });
|
||||
@@ -129,11 +151,16 @@ async function processStreamChunks({
|
||||
|
||||
for await (const part of fullStream) {
|
||||
let chunk = "";
|
||||
if (
|
||||
inThinkingBlock &&
|
||||
!["reasoning-delta", "reasoning-end", "reasoning-start"].includes(
|
||||
part.type,
|
||||
)
|
||||
) {
|
||||
chunk = "</think>";
|
||||
inThinkingBlock = false;
|
||||
}
|
||||
if (part.type === "text-delta") {
|
||||
if (inThinkingBlock) {
|
||||
chunk = "</think>";
|
||||
inThinkingBlock = false;
|
||||
}
|
||||
chunk += part.text;
|
||||
} else if (part.type === "reasoning-delta") {
|
||||
if (!inThinkingBlock) {
|
||||
@@ -142,6 +169,14 @@ async function processStreamChunks({
|
||||
}
|
||||
|
||||
chunk += escapeDyadTags(part.text);
|
||||
} else if (part.type === "tool-call") {
|
||||
const { serverName, toolName } = parseMcpToolKey(part.toolName);
|
||||
const content = escapeDyadTags(JSON.stringify(part.input));
|
||||
chunk = `<dyad-mcp-tool-call server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-call>\n`;
|
||||
} else if (part.type === "tool-result") {
|
||||
const { serverName, toolName } = parseMcpToolKey(part.toolName);
|
||||
const content = escapeDyadTags(part.output);
|
||||
chunk = `<dyad-mcp-tool-result server="${serverName}" tool="${toolName}">\n${content}\n</dyad-mcp-tool-result>\n`;
|
||||
}
|
||||
|
||||
if (!chunk) {
|
||||
@@ -496,7 +531,10 @@ ${componentSnippet}
|
||||
|
||||
let systemPrompt = constructSystemPrompt({
|
||||
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
|
||||
chatMode: settings.selectedChatMode,
|
||||
chatMode:
|
||||
settings.selectedChatMode === "agent"
|
||||
? "build"
|
||||
: settings.selectedChatMode,
|
||||
});
|
||||
|
||||
// Add information about mentioned apps if any
|
||||
@@ -603,19 +641,21 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
] as const)
|
||||
: [];
|
||||
|
||||
const limitedHistoryChatMessages = limitedMessageHistory.map((msg) => ({
|
||||
role: msg.role as "user" | "assistant" | "system",
|
||||
// Why remove thinking tags?
|
||||
// Thinking tags are generally not critical for the context
|
||||
// and eats up extra tokens.
|
||||
content:
|
||||
settings.selectedChatMode === "ask"
|
||||
? removeDyadTags(removeNonEssentialTags(msg.content))
|
||||
: removeNonEssentialTags(msg.content),
|
||||
}));
|
||||
|
||||
let chatMessages: ModelMessage[] = [
|
||||
...codebasePrefix,
|
||||
...otherCodebasePrefix,
|
||||
...limitedMessageHistory.map((msg) => ({
|
||||
role: msg.role as "user" | "assistant" | "system",
|
||||
// Why remove thinking tags?
|
||||
// Thinking tags are generally not critical for the context
|
||||
// and eats up extra tokens.
|
||||
content:
|
||||
settings.selectedChatMode === "ask"
|
||||
? removeDyadTags(removeNonEssentialTags(msg.content))
|
||||
: removeNonEssentialTags(msg.content),
|
||||
})),
|
||||
...limitedHistoryChatMessages,
|
||||
];
|
||||
|
||||
// Check if the last message should include attachments
|
||||
@@ -654,9 +694,15 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
const simpleStreamText = async ({
|
||||
chatMessages,
|
||||
modelClient,
|
||||
tools,
|
||||
systemPromptOverride = systemPrompt,
|
||||
dyadDisableFiles = false,
|
||||
}: {
|
||||
chatMessages: ModelMessage[];
|
||||
modelClient: ModelClient;
|
||||
tools?: ToolSet;
|
||||
systemPromptOverride?: string;
|
||||
dyadDisableFiles?: boolean;
|
||||
}) => {
|
||||
const dyadRequestId = uuidv4();
|
||||
if (isEngineEnabled) {
|
||||
@@ -671,6 +717,7 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
const providerOptions: Record<string, any> = {
|
||||
"dyad-engine": {
|
||||
dyadRequestId,
|
||||
dyadDisableFiles,
|
||||
},
|
||||
"dyad-gateway": getExtraProviderOptions(
|
||||
modelClient.builtinProviderId,
|
||||
@@ -708,6 +755,7 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
},
|
||||
} satisfies GoogleGenerativeAIProviderOptions;
|
||||
}
|
||||
|
||||
return streamText({
|
||||
headers: isAnthropic
|
||||
? {
|
||||
@@ -718,8 +766,10 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
temperature: await getTemperature(settings.selectedModel),
|
||||
maxRetries: 2,
|
||||
model: modelClient.model,
|
||||
stopWhen: stepCountIs(3),
|
||||
providerOptions,
|
||||
system: systemPrompt,
|
||||
system: systemPromptOverride,
|
||||
tools,
|
||||
messages: chatMessages.filter((m) => m.content),
|
||||
onError: (error: any) => {
|
||||
logger.error("Error streaming text:", error);
|
||||
@@ -780,6 +830,38 @@ This conversation includes one or more image attachments. When the user uploads
|
||||
return fullResponse;
|
||||
};
|
||||
|
||||
if (settings.selectedChatMode === "agent") {
|
||||
const tools = await getMcpTools(event);
|
||||
|
||||
const { fullStream } = await simpleStreamText({
|
||||
chatMessages: limitedHistoryChatMessages,
|
||||
modelClient,
|
||||
tools,
|
||||
systemPromptOverride: constructSystemPrompt({
|
||||
aiRules: await readAiRules(getDyadAppPath(updatedChat.app.path)),
|
||||
chatMode: "agent",
|
||||
}),
|
||||
dyadDisableFiles: true,
|
||||
});
|
||||
|
||||
const result = await processStreamChunks({
|
||||
fullStream,
|
||||
fullResponse,
|
||||
abortController,
|
||||
chatId: req.chatId,
|
||||
processResponseChunkUpdate,
|
||||
});
|
||||
fullResponse = result.fullResponse;
|
||||
chatMessages.push({
|
||||
role: "assistant",
|
||||
content: fullResponse,
|
||||
});
|
||||
chatMessages.push({
|
||||
role: "user",
|
||||
content: "OK.",
|
||||
});
|
||||
}
|
||||
|
||||
// When calling streamText, the messages need to be properly formatted for mixed content
|
||||
const { fullStream } = await simpleStreamText({
|
||||
chatMessages,
|
||||
@@ -1316,3 +1398,48 @@ These are the other apps that I've mentioned in my prompt. These other apps' cod
|
||||
${otherAppsCodebaseInfo}
|
||||
`;
|
||||
}
|
||||
|
||||
async function getMcpTools(event: IpcMainInvokeEvent): Promise<ToolSet> {
|
||||
const mcpToolSet: ToolSet = {};
|
||||
try {
|
||||
const servers = await db
|
||||
.select()
|
||||
.from(mcpServers)
|
||||
.where(eq(mcpServers.enabled, true as any));
|
||||
for (const s of servers) {
|
||||
const client = await mcpManager.getClient(s.id);
|
||||
const toolSet = await client.tools();
|
||||
for (const [name, tool] of Object.entries(toolSet)) {
|
||||
const key = `${String(s.name || "").replace(/[^a-zA-Z0-9_-]/g, "-")}__${String(name).replace(/[^a-zA-Z0-9_-]/g, "-")}`;
|
||||
const original = tool;
|
||||
mcpToolSet[key] = {
|
||||
description: original?.description,
|
||||
inputSchema: original?.inputSchema,
|
||||
execute: async (args: any, execCtx: any) => {
|
||||
const inputPreview =
|
||||
typeof args === "string"
|
||||
? args
|
||||
: Array.isArray(args)
|
||||
? args.join(" ")
|
||||
: JSON.stringify(args).slice(0, 500);
|
||||
const ok = await requireMcpToolConsent(event, {
|
||||
serverId: s.id,
|
||||
serverName: s.name,
|
||||
toolName: name,
|
||||
toolDescription: original?.description,
|
||||
inputPreview,
|
||||
});
|
||||
|
||||
if (!ok) throw new Error(`User declined running tool ${key}`);
|
||||
const res = await original.execute?.(args, execCtx);
|
||||
|
||||
return typeof res === "string" ? res : JSON.stringify(res);
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
} catch (e) {
|
||||
logger.warn("Failed building MCP toolset", e);
|
||||
}
|
||||
return mcpToolSet;
|
||||
}
|
||||
|
||||
163
src/ipc/handlers/mcp_handlers.ts
Normal file
163
src/ipc/handlers/mcp_handlers.ts
Normal file
@@ -0,0 +1,163 @@
|
||||
import { IpcMainInvokeEvent } from "electron";
|
||||
import log from "electron-log";
|
||||
import { db } from "../../db";
|
||||
import { mcpServers, mcpToolConsents } from "../../db/schema";
|
||||
import { eq, and } from "drizzle-orm";
|
||||
import { createLoggedHandler } from "./safe_handle";
|
||||
|
||||
import { resolveConsent } from "../utils/mcp_consent";
|
||||
import { getStoredConsent } from "../utils/mcp_consent";
|
||||
import { mcpManager } from "../utils/mcp_manager";
|
||||
import { CreateMcpServer, McpServerUpdate, McpTool } from "../ipc_types";
|
||||
|
||||
const logger = log.scope("mcp_handlers");
|
||||
const handle = createLoggedHandler(logger);
|
||||
|
||||
type ConsentDecision = "accept-once" | "accept-always" | "decline";
|
||||
|
||||
export function registerMcpHandlers() {
|
||||
// CRUD for MCP servers
|
||||
handle("mcp:list-servers", async () => {
|
||||
return await db.select().from(mcpServers);
|
||||
});
|
||||
|
||||
handle(
|
||||
"mcp:create-server",
|
||||
async (_event: IpcMainInvokeEvent, params: CreateMcpServer) => {
|
||||
const { name, transport, command, args, envJson, url, enabled } = params;
|
||||
const result = await db
|
||||
.insert(mcpServers)
|
||||
.values({
|
||||
name,
|
||||
transport,
|
||||
command: command || null,
|
||||
args: args || null,
|
||||
envJson: envJson || null,
|
||||
url: url || null,
|
||||
enabled: !!enabled,
|
||||
})
|
||||
.returning();
|
||||
return result[0];
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"mcp:update-server",
|
||||
async (_event: IpcMainInvokeEvent, params: McpServerUpdate) => {
|
||||
const update: any = {};
|
||||
if (params.name !== undefined) update.name = params.name;
|
||||
if (params.transport !== undefined) update.transport = params.transport;
|
||||
if (params.command !== undefined) update.command = params.command;
|
||||
if (params.args !== undefined) update.args = params.args || null;
|
||||
if (params.cwd !== undefined) update.cwd = params.cwd;
|
||||
if (params.envJson !== undefined) update.envJson = params.envJson || null;
|
||||
if (params.url !== undefined) update.url = params.url;
|
||||
if (params.enabled !== undefined) update.enabled = !!params.enabled;
|
||||
|
||||
const result = await db
|
||||
.update(mcpServers)
|
||||
.set(update)
|
||||
.where(eq(mcpServers.id, params.id))
|
||||
.returning();
|
||||
// If server config changed, dispose cached client to be recreated on next use
|
||||
try {
|
||||
mcpManager.dispose(params.id);
|
||||
} catch {}
|
||||
return result[0];
|
||||
},
|
||||
);
|
||||
|
||||
handle(
|
||||
"mcp:delete-server",
|
||||
async (_event: IpcMainInvokeEvent, id: number) => {
|
||||
try {
|
||||
mcpManager.dispose(id);
|
||||
} catch {}
|
||||
await db.delete(mcpServers).where(eq(mcpServers.id, id));
|
||||
return { success: true };
|
||||
},
|
||||
);
|
||||
|
||||
// Tools listing (dynamic)
|
||||
handle(
|
||||
"mcp:list-tools",
|
||||
async (
|
||||
_event: IpcMainInvokeEvent,
|
||||
serverId: number,
|
||||
): Promise<McpTool[]> => {
|
||||
try {
|
||||
const client = await mcpManager.getClient(serverId);
|
||||
const remoteTools = await client.tools();
|
||||
const tools = await Promise.all(
|
||||
Object.entries(remoteTools).map(async ([name, tool]) => ({
|
||||
name,
|
||||
description: tool.description ?? null,
|
||||
consent: await getStoredConsent(serverId, name),
|
||||
})),
|
||||
);
|
||||
return tools;
|
||||
} catch (e) {
|
||||
logger.error("Failed to list tools", e);
|
||||
return [];
|
||||
}
|
||||
},
|
||||
);
|
||||
// Consents
|
||||
handle("mcp:get-tool-consents", async () => {
|
||||
return await db.select().from(mcpToolConsents);
|
||||
});
|
||||
|
||||
handle(
|
||||
"mcp:set-tool-consent",
|
||||
async (
|
||||
_event: IpcMainInvokeEvent,
|
||||
params: {
|
||||
serverId: number;
|
||||
toolName: string;
|
||||
consent: "ask" | "always" | "denied";
|
||||
},
|
||||
) => {
|
||||
const existing = await db
|
||||
.select()
|
||||
.from(mcpToolConsents)
|
||||
.where(
|
||||
and(
|
||||
eq(mcpToolConsents.serverId, params.serverId),
|
||||
eq(mcpToolConsents.toolName, params.toolName),
|
||||
),
|
||||
);
|
||||
if (existing.length > 0) {
|
||||
const result = await db
|
||||
.update(mcpToolConsents)
|
||||
.set({ consent: params.consent })
|
||||
.where(
|
||||
and(
|
||||
eq(mcpToolConsents.serverId, params.serverId),
|
||||
eq(mcpToolConsents.toolName, params.toolName),
|
||||
),
|
||||
)
|
||||
.returning();
|
||||
return result[0];
|
||||
} else {
|
||||
const result = await db
|
||||
.insert(mcpToolConsents)
|
||||
.values({
|
||||
serverId: params.serverId,
|
||||
toolName: params.toolName,
|
||||
consent: params.consent,
|
||||
})
|
||||
.returning();
|
||||
return result[0];
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
// Tool consent request/response handshake
|
||||
// Receive consent response from renderer
|
||||
handle(
|
||||
"mcp:tool-consent-response",
|
||||
async (_event, data: { requestId: string; decision: ConsentDecision }) => {
|
||||
resolveConsent(data.requestId, data.decision);
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -63,6 +63,8 @@ import type {
|
||||
PromptDto,
|
||||
CreatePromptParamsDto,
|
||||
UpdatePromptParamsDto,
|
||||
McpServerUpdate,
|
||||
CreateMcpServer,
|
||||
} from "./ipc_types";
|
||||
import type { Template } from "../shared/templates";
|
||||
import type {
|
||||
@@ -119,11 +121,13 @@ export class IpcClient {
|
||||
onError: (error: string) => void;
|
||||
}
|
||||
>;
|
||||
private mcpConsentHandlers: Map<string, (payload: any) => void>;
|
||||
private constructor() {
|
||||
this.ipcRenderer = (window as any).electron.ipcRenderer as IpcRenderer;
|
||||
this.chatStreams = new Map();
|
||||
this.appStreams = new Map();
|
||||
this.helpStreams = new Map();
|
||||
this.mcpConsentHandlers = new Map();
|
||||
// Set up listeners for stream events
|
||||
this.ipcRenderer.on("chat:response:chunk", (data) => {
|
||||
if (
|
||||
@@ -238,6 +242,12 @@ export class IpcClient {
|
||||
this.helpStreams.delete(sessionId);
|
||||
}
|
||||
});
|
||||
|
||||
// MCP tool consent request from main
|
||||
this.ipcRenderer.on("mcp:tool-consent-request", (payload) => {
|
||||
const handler = this.mcpConsentHandlers.get("consent");
|
||||
if (handler) handler(payload);
|
||||
});
|
||||
}
|
||||
|
||||
public static getInstance(): IpcClient {
|
||||
@@ -814,6 +824,67 @@ export class IpcClient {
|
||||
return result.version as string;
|
||||
}
|
||||
|
||||
// --- MCP Client Methods ---
|
||||
public async listMcpServers() {
|
||||
return this.ipcRenderer.invoke("mcp:list-servers");
|
||||
}
|
||||
|
||||
public async createMcpServer(params: CreateMcpServer) {
|
||||
return this.ipcRenderer.invoke("mcp:create-server", params);
|
||||
}
|
||||
|
||||
public async updateMcpServer(params: McpServerUpdate) {
|
||||
return this.ipcRenderer.invoke("mcp:update-server", params);
|
||||
}
|
||||
|
||||
public async deleteMcpServer(id: number) {
|
||||
return this.ipcRenderer.invoke("mcp:delete-server", id);
|
||||
}
|
||||
|
||||
public async listMcpTools(serverId: number) {
|
||||
return this.ipcRenderer.invoke("mcp:list-tools", serverId);
|
||||
}
|
||||
|
||||
// Removed: upsertMcpTools and setMcpToolActive – tools are fetched dynamically at runtime
|
||||
|
||||
public async getMcpToolConsents() {
|
||||
return this.ipcRenderer.invoke("mcp:get-tool-consents");
|
||||
}
|
||||
|
||||
public async setMcpToolConsent(params: {
|
||||
serverId: number;
|
||||
toolName: string;
|
||||
consent: "ask" | "always" | "denied";
|
||||
}) {
|
||||
return this.ipcRenderer.invoke("mcp:set-tool-consent", params);
|
||||
}
|
||||
|
||||
public onMcpToolConsentRequest(
|
||||
handler: (payload: {
|
||||
requestId: string;
|
||||
serverId: number;
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
toolDescription?: string | null;
|
||||
inputPreview?: string | null;
|
||||
}) => void,
|
||||
) {
|
||||
this.mcpConsentHandlers.set("consent", handler as any);
|
||||
return () => {
|
||||
this.mcpConsentHandlers.delete("consent");
|
||||
};
|
||||
}
|
||||
|
||||
public respondToMcpConsentRequest(
|
||||
requestId: string,
|
||||
decision: "accept-once" | "accept-always" | "decline",
|
||||
) {
|
||||
this.ipcRenderer.invoke("mcp:tool-consent-response", {
|
||||
requestId,
|
||||
decision,
|
||||
});
|
||||
}
|
||||
|
||||
// Get proposal details
|
||||
public async getProposal(chatId: number): Promise<ProposalResult | null> {
|
||||
try {
|
||||
|
||||
@@ -30,6 +30,7 @@ import { registerTemplateHandlers } from "./handlers/template_handlers";
|
||||
import { registerPortalHandlers } from "./handlers/portal_handlers";
|
||||
import { registerPromptHandlers } from "./handlers/prompt_handlers";
|
||||
import { registerHelpBotHandlers } from "./handlers/help_bot_handlers";
|
||||
import { registerMcpHandlers } from "./handlers/mcp_handlers";
|
||||
|
||||
export function registerIpcHandlers() {
|
||||
// Register all IPC handlers by category
|
||||
@@ -65,4 +66,5 @@ export function registerIpcHandlers() {
|
||||
registerPortalHandlers();
|
||||
registerPromptHandlers();
|
||||
registerHelpBotHandlers();
|
||||
registerMcpHandlers();
|
||||
}
|
||||
|
||||
@@ -450,3 +450,37 @@ export interface HelpChatResponseError {
|
||||
sessionId: string;
|
||||
error: string;
|
||||
}
|
||||
|
||||
// --- MCP Types ---
|
||||
export interface McpServer {
|
||||
id: number;
|
||||
name: string;
|
||||
transport: string;
|
||||
command?: string | null;
|
||||
args?: string[] | null;
|
||||
cwd?: string | null;
|
||||
envJson?: Record<string, string> | null;
|
||||
url?: string | null;
|
||||
enabled: boolean;
|
||||
createdAt: number;
|
||||
updatedAt: number;
|
||||
}
|
||||
|
||||
export interface CreateMcpServer
|
||||
extends Omit<McpServer, "id" | "createdAt" | "updatedAt"> {}
|
||||
export type McpServerUpdate = Partial<McpServer> & Pick<McpServer, "id">;
|
||||
export type McpToolConsentType = "ask" | "always" | "denied";
|
||||
|
||||
export interface McpTool {
|
||||
name: string;
|
||||
description?: string | null;
|
||||
consent: McpToolConsentType;
|
||||
}
|
||||
|
||||
export interface McpToolConsent {
|
||||
id: number;
|
||||
serverId: number;
|
||||
toolName: string;
|
||||
consent: McpToolConsentType;
|
||||
updatedAt: number;
|
||||
}
|
||||
|
||||
@@ -137,6 +137,10 @@ export function createDyadEngine(
|
||||
if ("dyadRequestId" in parsedBody) {
|
||||
delete parsedBody.dyadRequestId;
|
||||
}
|
||||
const dyadDisableFiles = parsedBody.dyadDisableFiles;
|
||||
if ("dyadDisableFiles" in parsedBody) {
|
||||
delete parsedBody.dyadDisableFiles;
|
||||
}
|
||||
|
||||
// Track and modify requestId with attempt number
|
||||
let modifiedRequestId = requestId;
|
||||
@@ -147,7 +151,7 @@ export function createDyadEngine(
|
||||
}
|
||||
|
||||
// Add files to the request if they exist
|
||||
if (files?.length) {
|
||||
if (files?.length && !dyadDisableFiles) {
|
||||
parsedBody.dyad_options = {
|
||||
files,
|
||||
enable_lazy_edits: options.dyadOptions.enableLazyEdits,
|
||||
|
||||
108
src/ipc/utils/mcp_consent.ts
Normal file
108
src/ipc/utils/mcp_consent.ts
Normal file
@@ -0,0 +1,108 @@
|
||||
import { db } from "../../db";
|
||||
import { mcpToolConsents } from "../../db/schema";
|
||||
import { and, eq } from "drizzle-orm";
|
||||
import { IpcMainInvokeEvent } from "electron";
|
||||
|
||||
export type Consent = "ask" | "always" | "denied";
|
||||
|
||||
const pendingConsentResolvers = new Map<
|
||||
string,
|
||||
(d: "accept-once" | "accept-always" | "decline") => void
|
||||
>();
|
||||
|
||||
export function waitForConsent(
|
||||
requestId: string,
|
||||
): Promise<"accept-once" | "accept-always" | "decline"> {
|
||||
return new Promise((resolve) => {
|
||||
pendingConsentResolvers.set(requestId, resolve);
|
||||
});
|
||||
}
|
||||
|
||||
export function resolveConsent(
|
||||
requestId: string,
|
||||
decision: "accept-once" | "accept-always" | "decline",
|
||||
) {
|
||||
const resolver = pendingConsentResolvers.get(requestId);
|
||||
if (resolver) {
|
||||
pendingConsentResolvers.delete(requestId);
|
||||
resolver(decision);
|
||||
}
|
||||
}
|
||||
|
||||
export async function getStoredConsent(
|
||||
serverId: number,
|
||||
toolName: string,
|
||||
): Promise<Consent> {
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(mcpToolConsents)
|
||||
.where(
|
||||
and(
|
||||
eq(mcpToolConsents.serverId, serverId),
|
||||
eq(mcpToolConsents.toolName, toolName),
|
||||
),
|
||||
);
|
||||
if (rows.length === 0) return "ask";
|
||||
return (rows[0].consent as Consent) ?? "ask";
|
||||
}
|
||||
|
||||
export async function setStoredConsent(
|
||||
serverId: number,
|
||||
toolName: string,
|
||||
consent: Consent,
|
||||
): Promise<void> {
|
||||
const rows = await db
|
||||
.select()
|
||||
.from(mcpToolConsents)
|
||||
.where(
|
||||
and(
|
||||
eq(mcpToolConsents.serverId, serverId),
|
||||
eq(mcpToolConsents.toolName, toolName),
|
||||
),
|
||||
);
|
||||
if (rows.length > 0) {
|
||||
await db
|
||||
.update(mcpToolConsents)
|
||||
.set({ consent })
|
||||
.where(
|
||||
and(
|
||||
eq(mcpToolConsents.serverId, serverId),
|
||||
eq(mcpToolConsents.toolName, toolName),
|
||||
),
|
||||
);
|
||||
} else {
|
||||
await db.insert(mcpToolConsents).values({ serverId, toolName, consent });
|
||||
}
|
||||
}
|
||||
|
||||
export async function requireMcpToolConsent(
|
||||
event: IpcMainInvokeEvent,
|
||||
params: {
|
||||
serverId: number;
|
||||
serverName: string;
|
||||
toolName: string;
|
||||
toolDescription?: string | null;
|
||||
inputPreview?: string | null;
|
||||
},
|
||||
): Promise<boolean> {
|
||||
const current = await getStoredConsent(params.serverId, params.toolName);
|
||||
if (current === "always") return true;
|
||||
if (current === "denied") return false;
|
||||
|
||||
// Ask renderer for a decision via event bridge
|
||||
const requestId = `${params.serverId}:${params.toolName}:${Date.now()}`;
|
||||
(event.sender as any).send("mcp:tool-consent-request", {
|
||||
requestId,
|
||||
...params,
|
||||
});
|
||||
const response = await waitForConsent(requestId);
|
||||
|
||||
if (response === "accept-always") {
|
||||
await setStoredConsent(params.serverId, params.toolName, "always");
|
||||
return true;
|
||||
}
|
||||
if (response === "decline") {
|
||||
return false;
|
||||
}
|
||||
return response === "accept-once";
|
||||
}
|
||||
59
src/ipc/utils/mcp_manager.ts
Normal file
59
src/ipc/utils/mcp_manager.ts
Normal file
@@ -0,0 +1,59 @@
|
||||
import { db } from "../../db";
|
||||
import { mcpServers } from "../../db/schema";
|
||||
import { experimental_createMCPClient, experimental_MCPClient } from "ai";
|
||||
import { eq } from "drizzle-orm";
|
||||
|
||||
import { StreamableHTTPClientTransport } from "@modelcontextprotocol/sdk/client/streamableHttp.js";
|
||||
import { StdioClientTransport } from "@modelcontextprotocol/sdk/client/stdio.js";
|
||||
|
||||
class McpManager {
|
||||
private static _instance: McpManager;
|
||||
static get instance(): McpManager {
|
||||
if (!this._instance) this._instance = new McpManager();
|
||||
return this._instance;
|
||||
}
|
||||
|
||||
private clients = new Map<number, experimental_MCPClient>();
|
||||
|
||||
async getClient(serverId: number): Promise<experimental_MCPClient> {
|
||||
const existing = this.clients.get(serverId);
|
||||
if (existing) return existing;
|
||||
const server = await db
|
||||
.select()
|
||||
.from(mcpServers)
|
||||
.where(eq(mcpServers.id, serverId));
|
||||
const s = server.find((x) => x.id === serverId);
|
||||
if (!s) throw new Error(`MCP server not found: ${serverId}`);
|
||||
let transport: StdioClientTransport | StreamableHTTPClientTransport;
|
||||
if (s.transport === "stdio") {
|
||||
const args = s.args ?? [];
|
||||
const env = s.envJson ?? undefined;
|
||||
if (!s.command) throw new Error("MCP server command is required");
|
||||
transport = new StdioClientTransport({
|
||||
command: s.command,
|
||||
args,
|
||||
env,
|
||||
});
|
||||
} else if (s.transport === "http") {
|
||||
if (!s.url) throw new Error("HTTP MCP requires url");
|
||||
transport = new StreamableHTTPClientTransport(new URL(s.url as string));
|
||||
} else {
|
||||
throw new Error(`Unsupported MCP transport: ${s.transport}`);
|
||||
}
|
||||
const client = await experimental_createMCPClient({
|
||||
transport,
|
||||
});
|
||||
this.clients.set(serverId, client);
|
||||
return client;
|
||||
}
|
||||
|
||||
dispose(serverId: number) {
|
||||
const c = this.clients.get(serverId);
|
||||
if (c) {
|
||||
c.close();
|
||||
this.clients.delete(serverId);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const mcpManager = McpManager.instance;
|
||||
Reference in New Issue
Block a user