653 lines
21 KiB
TypeScript
653 lines
21 KiB
TypeScript
import { ipcMain } from "electron";
|
|
import { CoreMessage, TextPart, ImagePart, streamText } from "ai";
|
|
import { db } from "../../db";
|
|
import { chats, messages } from "../../db/schema";
|
|
import { and, eq, isNull } from "drizzle-orm";
|
|
import { SYSTEM_PROMPT } from "../../prompts/system_prompt";
|
|
import {
|
|
SUPABASE_AVAILABLE_SYSTEM_PROMPT,
|
|
SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT,
|
|
} from "../../prompts/supabase_prompt";
|
|
import { getDyadAppPath } from "../../paths/paths";
|
|
import { readSettings } from "../../main/settings";
|
|
import type { ChatResponseEnd, ChatStreamParams } from "../ipc_types";
|
|
import { extractCodebase } from "../../utils/codebase";
|
|
import { processFullResponseActions } from "../processors/response_processor";
|
|
import { streamTestResponse } from "./testing_chat_handlers";
|
|
import { getTestResponse } from "./testing_chat_handlers";
|
|
import { getMaxTokens, getModelClient } from "../utils/get_model_client";
|
|
import log from "electron-log";
|
|
import {
|
|
getSupabaseContext,
|
|
getSupabaseClientCode,
|
|
} from "../../supabase_admin/supabase_context";
|
|
import { SUMMARIZE_CHAT_SYSTEM_PROMPT } from "../../prompts/summarize_chat_system_prompt";
|
|
import * as fs from "fs";
|
|
import * as path from "path";
|
|
import * as os from "os";
|
|
import * as crypto from "crypto";
|
|
import { readFile, writeFile, unlink } from "fs/promises";
|
|
|
|
/** Logger scoped to this module so its output can be told apart in the logs. */
const logger = log.scope("chat_stream_handlers");

/** Abort controllers of in-flight streams, keyed by chat id (used for cancellation). */
const activeStreams: Map<number, AbortController> = new Map();

/** Response text accumulated so far for each streaming chat, keyed by chat id. */
const partialResponses: Map<number, string> = new Map();

/** Folder under the OS temp directory where uploaded attachments are written. */
const TEMP_DIR = path.join(os.tmpdir(), "dyad-attachments");
|
|
|
|
// Common helper functions
|
|
/** File extensions (lower-case, including the dot) handled as plain-text attachments. */
const TEXT_FILE_EXTENSIONS = [
  ".md",
  ".txt",
  ".json",
  ".csv",
  ".js",
  ".ts",
  ".html",
  ".css",
];

/**
 * Tells whether a path names a text file, judged only by its extension.
 * The file itself is never opened; the comparison ignores letter case.
 *
 * @param filePath Path (or bare file name) to classify.
 * @returns `true` when the extension is listed in `TEXT_FILE_EXTENSIONS`.
 */
async function isTextFile(filePath: string): Promise<boolean> {
  const extension = path.extname(filePath).toLowerCase();
  return TEXT_FILE_EXTENSIONS.some((known) => known === extension);
}
|
|
|
|
// At module load, create the attachment directory (and any missing parents)
// unless something already exists at that path.
if (fs.existsSync(TEMP_DIR) === false) {
  fs.mkdirSync(TEMP_DIR, { recursive: true });
}
|
|
|
|
/** How long an uploaded attachment is kept on disk before it is deleted. */
const ATTACHMENT_CLEANUP_DELAY_MS = 30 * 60 * 1000;

/**
 * Schedules delayed deletion of temporary attachment files.
 * Files are not removed immediately because they might still be needed for
 * reference; deletion failures are logged and never thrown.
 */
function scheduleAttachmentCleanup(filePaths: string[]): void {
  for (const filePath of filePaths) {
    setTimeout(async () => {
      // Errors are handled inside the callback: a rejection here would
      // otherwise surface as an unhandled promise rejection.
      try {
        if (fs.existsSync(filePath)) {
          await unlink(filePath);
          logger.log(`Deleted temporary file: ${filePath}`);
        }
      } catch (error) {
        logger.error(`Error deleting temporary file ${filePath}: ${error}`);
      }
    }, ATTACHMENT_CLEANUP_DELAY_MS);
  }
}

/**
 * Registers the IPC handlers that drive chat streaming.
 *
 * - `chat:stream`: stores the user prompt (plus attachment info) and an empty
 *   assistant placeholder, streams the model (or canned test) response to the
 *   renderer through `chat:response:chunk`, persists the final text and, when
 *   `autoApproveChanges` is enabled, applies the response actions. Resolves to
 *   the chat id, or to the string `"error"` when processing failed (the
 *   failure is also reported through `chat:response:error`).
 * - `chat:cancel`: aborts the active stream of a chat, if any, and emits
 *   `chat:response:end`.
 */
export function registerChatStreamHandlers() {
  ipcMain.handle("chat:stream", async (event, req: ChatStreamParams) => {
    // Created outside the try block so the finally clause can release them
    // on every exit path (success, cancellation and error).
    const abortController = new AbortController();
    activeStreams.set(req.chatId, abortController);
    const attachmentPaths: string[] = [];

    try {
      // Get the chat to check for existing messages
      const chat = await db.query.chats.findFirst({
        where: eq(chats.id, req.chatId),
        with: {
          messages: {
            orderBy: (messages, { asc }) => [asc(messages.createdAt)],
          },
          app: true, // Include app information
        },
      });

      if (!chat) {
        throw new Error(`Chat not found: ${req.chatId}`);
      }

      // Handle redo option: remove the most recent user message and the
      // assistant reply that directly follows it, if any.
      if (req.redo) {
        const chatMessages = [...chat.messages];

        // Find the most recent user message
        let lastUserMessageIndex = chatMessages.length - 1;
        while (
          lastUserMessageIndex >= 0 &&
          chatMessages[lastUserMessageIndex].role !== "user"
        ) {
          lastUserMessageIndex--;
        }

        if (lastUserMessageIndex >= 0) {
          // Delete the user message
          await db
            .delete(messages)
            .where(eq(messages.id, chatMessages[lastUserMessageIndex].id));

          // If there's an assistant message after the user message, delete it too
          if (
            lastUserMessageIndex < chatMessages.length - 1 &&
            chatMessages[lastUserMessageIndex + 1].role === "assistant"
          ) {
            await db
              .delete(messages)
              .where(
                eq(messages.id, chatMessages[lastUserMessageIndex + 1].id),
              );
          }
        }
      }

      // Process attachments if any: write each one to the temp directory and
      // describe it in the text appended to the user prompt.
      let attachmentInfo = "";

      if (req.attachments && req.attachments.length > 0) {
        attachmentInfo = "\n\nAttachments:\n";

        for (const attachment of req.attachments) {
          // Random file name: hashing name + timestamp collides when two
          // attachments share a name and are written in the same millisecond.
          const uniqueId = crypto.randomBytes(16).toString("hex");
          const fileExtension = path.extname(attachment.name);
          const filename = `${uniqueId}${fileExtension}`;
          const filePath = path.join(TEMP_DIR, filename);

          // Extract the base64 data (remove the data:mime/type;base64, prefix)
          const base64Data = attachment.data.split(";base64,").pop() || "";

          await writeFile(filePath, Buffer.from(base64Data, "base64"));
          attachmentPaths.push(filePath);
          attachmentInfo += `- ${attachment.name} (${attachment.type})\n`;

          // Text-based files get a placeholder tag; the tag is swapped for
          // the file content right before the prompt is sent to the model.
          if (await isTextFile(filePath)) {
            attachmentInfo += `<dyad-text-attachment filename="${attachment.name}" type="${attachment.type}" path="${filePath}">
</dyad-text-attachment>
\n\n`;
          }
        }
      }

      // Add user message to database with attachment info
      const userPrompt = req.prompt + attachmentInfo;
      await db
        .insert(messages)
        .values({
          chatId: req.chatId,
          role: "user",
          content: userPrompt,
        })
        .returning();

      // Add a placeholder assistant message immediately
      const [placeholderAssistantMessage] = await db
        .insert(messages)
        .values({
          chatId: req.chatId,
          role: "assistant",
          content: "", // Start with empty content
        })
        .returning();

      // Persists whatever was streamed so far, marked as cancelled.
      // Best effort: a database failure is logged, not thrown.
      const saveCancelledResponse = async (): Promise<void> => {
        const partialResponse = partialResponses.get(req.chatId);
        if (!partialResponse) {
          return;
        }
        try {
          await db
            .update(messages)
            .set({
              content: `${partialResponse}\n\n[Response cancelled by user]`,
            })
            .where(eq(messages.id, placeholderAssistantMessage.id));

          logger.log(
            `Updated cancelled response for placeholder message ${placeholderAssistantMessage.id} in chat ${req.chatId}`,
          );
        } catch (error) {
          logger.error(
            `Error saving partial response for chat ${req.chatId}:`,
            error,
          );
        }
      };

      // Fetch updated chat data after possible deletions and additions
      const updatedChat = await db.query.chats.findFirst({
        where: eq(chats.id, req.chatId),
        with: {
          messages: {
            orderBy: (messages, { asc }) => [asc(messages.createdAt)],
          },
          app: true, // Include app information
        },
      });

      if (!updatedChat) {
        throw new Error(`Chat not found: ${req.chatId}`);
      }

      // Send the messages right away so that the loading state is shown for the message.
      event.sender.send("chat:response:chunk", {
        chatId: req.chatId,
        messages: updatedChat.messages,
      });

      let fullResponse = "";

      // Check if this is a test prompt
      const testResponse = getTestResponse(req.prompt);

      if (testResponse) {
        // For test prompts, use the dedicated function
        fullResponse = await streamTestResponse(
          event,
          req.chatId,
          testResponse,
          abortController,
          updatedChat,
        );
      } else {
        // Normal AI processing for non-test prompts
        const settings = readSettings();
        const modelClient = await getModelClient(
          settings.selectedModel,
          settings,
        );

        // Extract codebase information if app is associated with the chat
        let codebaseInfo = "";
        if (updatedChat.app) {
          const appPath = getDyadAppPath(updatedChat.app.path);
          try {
            codebaseInfo = await extractCodebase(appPath);
            logger.log(`Extracted codebase information from ${appPath}`);
          } catch (error) {
            logger.error("Error extracting codebase:", error);
          }
        }
        logger.log(
          "codebaseInfo: length",
          codebaseInfo.length,
          "estimated tokens",
          codebaseInfo.length / 4,
        );

        // Build the system prompt, with or without Supabase context
        let systemPrompt = SYSTEM_PROMPT;
        if (
          updatedChat.app?.supabaseProjectId &&
          settings.supabase?.accessToken?.value
        ) {
          systemPrompt +=
            "\n\n" +
            SUPABASE_AVAILABLE_SYSTEM_PROMPT +
            "\n\n" +
            (await getSupabaseContext({
              supabaseProjectId: updatedChat.app.supabaseProjectId,
            }));
        } else {
          systemPrompt += "\n\n" + SUPABASE_NOT_AVAILABLE_SYSTEM_PROMPT;
        }
        const isSummarizeIntent = req.prompt.startsWith(
          "Summarize from chat-id=",
        );
        if (isSummarizeIntent) {
          systemPrompt = SUMMARIZE_CHAT_SYSTEM_PROMPT;
        }

        // Update the system prompt for images if there are image attachments
        const hasImageAttachments =
          req.attachments &&
          req.attachments.some((attachment) =>
            attachment.type.startsWith("image/"),
          );

        if (hasImageAttachments) {
          systemPrompt += `

# Image Analysis Capabilities
This conversation includes one or more image attachments. When the user uploads images:
1. If the user explicitly asks for analysis, description, or information about the image, please analyze the image content.
2. Describe what you see in the image if asked.
3. You can use images as references when the user has coding or design-related questions.
4. For diagrams or wireframes, try to understand the content and structure shown.
5. For screenshots of code or errors, try to identify the issue or explain the code.
`;
        }

        // Prepare message history for the AI: a synthetic codebase exchange
        // followed by the stored conversation.
        let chatMessages: CoreMessage[] = [
          {
            role: "user",
            content: "This is my codebase. " + codebaseInfo,
          },
          {
            role: "assistant",
            content: "OK, got it. I'm ready to help",
          },
          ...updatedChat.messages.map((message) => ({
            role: message.role as "user" | "assistant" | "system",
            content: message.content,
          })),
        ];

        // The last entry is the empty assistant placeholder, so the message
        // before it is the user message that may carry attachments.
        if (chatMessages.length >= 2 && attachmentPaths.length > 0) {
          const lastUserIndex = chatMessages.length - 2;
          const lastUserMessage = chatMessages[lastUserIndex];

          if (lastUserMessage.role === "user") {
            // Replace the last message with one that includes attachments
            chatMessages[lastUserIndex] = await prepareMessageWithAttachments(
              lastUserMessage,
              attachmentPaths,
            );
          }
        }

        if (isSummarizeIntent) {
          // The prompt has the form "Summarize from chat-id=<id>"
          const previousChatId = Number.parseInt(
            req.prompt.split("=")[1],
            10,
          );
          const previousChat = await db.query.chats.findFirst({
            where: eq(chats.id, previousChatId),
            with: {
              messages: {
                orderBy: (messages, { asc }) => [asc(messages.createdAt)],
              },
            },
          });
          chatMessages = [
            {
              role: "user",
              content:
                "Summarize the following chat: " +
                formatMessages(previousChat?.messages ?? []),
            } satisfies CoreMessage,
          ];
        }

        // When calling streamText, the messages need to be properly formatted for mixed content
        const { textStream } = streamText({
          maxTokens: getMaxTokens(settings.selectedModel),
          temperature: 0,
          model: modelClient,
          system: systemPrompt,
          // Drops messages with empty content (e.g. the assistant placeholder)
          messages: chatMessages.filter((m) => m.content),
          onError: (error) => {
            logger.error("Error streaming text:", error);
            const message =
              (error as any)?.error?.message || JSON.stringify(error);
            event.sender.send(
              "chat:response:error",
              `Sorry, there was an error from the AI: ${message}`,
            );
            // Clean up the abort controller
            activeStreams.delete(req.chatId);
          },
          abortSignal: abortController.signal,
        });

        try {
          for await (const textPart of textStream) {
            fullResponse += textPart;
            if (
              fullResponse.includes("$$SUPABASE_CLIENT_CODE$$") &&
              updatedChat.app?.supabaseProjectId
            ) {
              const supabaseClientCode = await getSupabaseClientCode({
                projectId: updatedChat.app.supabaseProjectId,
              });
              // Replacer function: a plain replacement string would have its
              // "$" sequences ($&, $1, $$ ...) interpreted by String.replace.
              fullResponse = fullResponse.replace(
                "$$SUPABASE_CLIENT_CODE$$",
                () => supabaseClientCode,
              );
            }
            // Store the current partial response
            partialResponses.set(req.chatId, fullResponse);

            // Update the placeholder assistant message content in the messages array
            const currentMessages = [...updatedChat.messages];
            if (
              currentMessages.length > 0 &&
              currentMessages[currentMessages.length - 1].role === "assistant"
            ) {
              currentMessages[currentMessages.length - 1].content =
                fullResponse;
            }

            // Push the updated messages to the renderer
            event.sender.send("chat:response:chunk", {
              chatId: req.chatId,
              messages: currentMessages,
            });

            // If the stream was aborted, exit early
            if (abortController.signal.aborted) {
              logger.log(`Stream for chat ${req.chatId} was aborted`);
              break;
            }
          }
        } catch (streamError) {
          // Errors raised because of the abort are expected; anything else
          // is a real failure handled by the outer catch.
          if (!abortController.signal.aborted) {
            throw streamError;
          }
        }

        // An abort may end the loop with or without an exception; save the
        // partial response in both cases.
        if (abortController.signal.aborted) {
          await saveCancelledResponse();
          return req.chatId;
        }
      }

      // Only save the response and process it if we weren't aborted
      if (!abortController.signal.aborted && fullResponse) {
        // Scrape from: <dyad-chat-summary>Renaming profile file</dyad-chat-summary>
        const chatTitle = fullResponse.match(
          /<dyad-chat-summary>(.*?)<\/dyad-chat-summary>/,
        );
        if (chatTitle) {
          // Only chats that have no title yet receive the scraped one
          await db
            .update(chats)
            .set({ title: chatTitle[1] })
            .where(and(eq(chats.id, req.chatId), isNull(chats.title)));
        }
        const chatSummary = chatTitle?.[1];

        // Update the placeholder assistant message with the full response
        await db
          .update(messages)
          .set({ content: fullResponse })
          .where(eq(messages.id, placeholderAssistantMessage.id));

        if (readSettings().autoApproveChanges) {
          const status = await processFullResponseActions(
            fullResponse,
            req.chatId,
            { chatSummary, messageId: placeholderAssistantMessage.id }, // Use placeholder ID
          );

          // Re-read the chat and send its stored messages to the renderer
          const refreshedChat = await db.query.chats.findFirst({
            where: eq(chats.id, req.chatId),
            with: {
              messages: {
                orderBy: (messages, { asc }) => [asc(messages.createdAt)],
              },
            },
          });

          if (!refreshedChat) {
            throw new Error(`Chat not found: ${req.chatId}`);
          }

          event.sender.send("chat:response:chunk", {
            chatId: req.chatId,
            messages: refreshedChat.messages,
          });

          if (status.error) {
            event.sender.send(
              "chat:response:error",
              `Sorry, there was an error applying the AI's changes: ${status.error}`,
            );
          }

          // Signal that the stream has completed
          event.sender.send("chat:response:end", {
            chatId: req.chatId,
            updatedFiles: status.updatedFiles ?? false,
            uncommittedFiles: status.uncommittedFiles,
          } satisfies ChatResponseEnd);
        } else {
          event.sender.send("chat:response:end", {
            chatId: req.chatId,
            updatedFiles: false,
          } satisfies ChatResponseEnd);
        }
      }

      // Return the chat ID for backwards compatibility
      return req.chatId;
    } catch (error) {
      logger.error("Error calling LLM:", error);
      event.sender.send(
        "chat:response:error",
        `Sorry, there was an error processing your request: ${error}`,
      );
      return "error";
    } finally {
      // Release per-stream state on every exit path. The controller is only
      // removed when it is still ours, so a newer stream for the same chat
      // keeps its registration.
      if (activeStreams.get(req.chatId) === abortController) {
        activeStreams.delete(req.chatId);
      }
      // Prevents a stale partial from being saved for a later stream
      partialResponses.delete(req.chatId);
      scheduleAttachmentCleanup(attachmentPaths);
    }
  });

  // Handler to cancel an ongoing stream
  ipcMain.handle("chat:cancel", async (event, chatId: number) => {
    const abortController = activeStreams.get(chatId);

    if (abortController) {
      // Abort the stream
      abortController.abort();
      activeStreams.delete(chatId);
      logger.log(`Aborted stream for chat ${chatId}`);
    } else {
      logger.warn(`No active stream found for chat ${chatId}`);
    }

    // Send the end event to the renderer
    event.sender.send("chat:response:end", {
      chatId,
      updatedFiles: false,
    } satisfies ChatResponseEnd);

    return true;
  });
}
|
|
|
|
/**
 * Serializes chat messages into newline-separated
 * `<message role="...">...</message>` tags, in input order.
 *
 * @param messages Messages to serialize; `content` may be undefined.
 * @returns One tag per message joined with "\n"; an empty string for an
 *   empty list. A message without content yields an empty tag body instead
 *   of the literal text "undefined".
 */
export function formatMessages(
  messages: { role: string; content: string | undefined }[],
): string {
  return messages
    .map((m) => `<message role="${m.role}">${m.content ?? ""}</message>`)
    .join("\n");
}
|
|
|
|
/**
 * Replaces the `<dyad-text-attachment ... path="<filePath>">` placeholder
 * tag(s) in `text` with the full content of the file at `filePath`.
 *
 * @param text Message text that may contain placeholder tags.
 * @param filePath Path of the attachment on disk; it identifies the tag to
 *   replace and the file to read.
 * @param fileName Name shown in the "Full content of ..." header.
 * @returns The text with every matching tag replaced. The input text is
 *   returned unchanged when the file is not a text file (by extension) or
 *   when reading it fails (the failure is logged).
 */
async function replaceTextAttachmentWithContent(
  text: string,
  filePath: string,
  fileName: string,
): Promise<string> {
  try {
    if (!(await isTextFile(filePath))) {
      return text;
    }

    // Read the full content
    const fullContent = await readFile(filePath, "utf-8");

    // Escape regex metacharacters so the path is matched literally
    const escapedPath = filePath.replace(/[.*+?^${}()|[\]\\]/g, "\\$&");
    const tagPattern = new RegExp(
      `<dyad-text-attachment filename="[^"]*" type="[^"]*" path="${escapedPath}">\\s*<\\/dyad-text-attachment>`,
      "g",
    );

    const replacement = `Full content of ${fileName}:\n\`\`\`\n${fullContent}\n\`\`\``;

    // Replacer function instead of a replacement string: String.replace
    // expands "$&", "$1", "$$" etc. in a string replacement, which would
    // corrupt file content containing such sequences.
    const replacedText = text.replace(tagPattern, () => replacement);

    logger.log(
      `Replaced text attachment content for: ${fileName} - length before: ${text.length} - length after: ${replacedText.length}`,
    );
    return replacedText;
  } catch (error) {
    logger.error(`Error processing text file: ${error}`);
    return text;
  }
}
|
|
|
|
/**
 * Builds a multi-part user message from a plain-text message and the
 * attachment files saved on disk.
 *
 * The text part comes first, with text-attachment placeholders expanded to
 * the file contents. One image part follows for every attachment whose
 * extension is .jpg, .jpeg, .png, .gif or .webp. Images that cannot be read
 * are logged and skipped.
 *
 * @param message Message whose content is expected to be a string; any other
 *   content is returned untouched after a warning.
 * @param attachmentPaths Paths of the attachment files on disk.
 * @returns A message with role "user" and an array of content parts.
 */
async function prepareMessageWithAttachments(
  message: CoreMessage,
  attachmentPaths: string[],
): Promise<CoreMessage> {
  const originalContent = message.content;
  if (typeof originalContent !== "string") {
    logger.warn(
      "Message content is not a string - shouldn't happen but using message as-is",
    );
    return message;
  }

  // Expand every text-attachment placeholder, one file after the other.
  let expandedText: string = originalContent;
  for (const attachmentPath of attachmentPaths) {
    const displayName = path.basename(attachmentPath);
    expandedText = await replaceTextAttachmentWithContent(
      expandedText,
      attachmentPath,
      displayName,
    );
  }

  // The text part always leads; image parts are appended after it.
  const parts: (TextPart | ImagePart)[] = [
    { type: "text", text: expandedText },
  ];

  const imageExtensions = [".jpg", ".jpeg", ".png", ".gif", ".webp"];
  for (const attachmentPath of attachmentPaths) {
    const extension = path.extname(attachmentPath).toLowerCase();
    if (!imageExtensions.includes(extension)) {
      continue;
    }

    try {
      // The raw bytes are handed over as the image payload.
      const bytes = await readFile(attachmentPath);
      parts.push({ type: "image", image: bytes });
      logger.log(`Added image attachment: ${attachmentPath}`);
    } catch (error) {
      logger.error(`Error reading image file: ${error}`);
    }
  }

  return { role: "user", content: parts };
}
|