diff --git a/src/__tests__/formatMessagesForSummary.test.ts b/src/__tests__/formatMessagesForSummary.test.ts new file mode 100644 index 0000000..39bcdb7 --- /dev/null +++ b/src/__tests__/formatMessagesForSummary.test.ts @@ -0,0 +1,166 @@ +import { formatMessagesForSummary } from "../ipc/handlers/chat_stream_handlers"; + +describe("formatMessagesForSummary", () => { + it("should return all messages when there are 8 or fewer messages", () => { + const messages = [ + { role: "user", content: "Hello" }, + { role: "assistant", content: "Hi there!" }, + { role: "user", content: "How are you?" }, + { role: "assistant", content: "I'm doing well, thanks!" }, + ]; + + const result = formatMessagesForSummary(messages); + const expected = [ + 'Hello', + 'Hi there!', + 'How are you?', + 'I\'m doing well, thanks!', + ].join("\n"); + + expect(result).toBe(expected); + }); + + it("should return all messages when there are exactly 8 messages", () => { + const messages = Array.from({ length: 8 }, (_, i) => ({ + role: i % 2 === 0 ? "user" : "assistant", + content: `Message ${i + 1}`, + })); + + const result = formatMessagesForSummary(messages); + const expected = messages + .map((m) => `${m.content}`) + .join("\n"); + + expect(result).toBe(expected); + }); + + it("should truncate messages when there are more than 8 messages", () => { + const messages = Array.from({ length: 12 }, (_, i) => ({ + role: i % 2 === 0 ? "user" : "assistant", + content: `Message ${i + 1}`, + })); + + const result = formatMessagesForSummary(messages); + + // Should contain first 2 messages + expect(result).toContain('Message 1'); + expect(result).toContain('Message 2'); + + // Should contain omission indicator + expect(result).toContain( + '[... 4 messages omitted ...]', + ); + + // Should contain last 6 messages + expect(result).toContain('Message 7'); + expect(result).toContain('Message 8'); + expect(result).toContain('Message 9'); + expect(result).toContain('Message 10'); + expect(result).toContain('Message 11'); + expect(result).toContain('Message 12'); + + // Should not contain middle messages + expect(result).not.toContain('Message 3'); + expect(result).not.toContain( + 'Message 4', + ); + expect(result).not.toContain('Message 5'); + expect(result).not.toContain( + 'Message 6', + ); + }); + + it("should handle messages with undefined content", () => { + const messages = [ + { role: "user", content: "Hello" }, + { role: "assistant", content: undefined }, + { role: "user", content: "Are you there?" }, + ]; + + const result = formatMessagesForSummary(messages); + const expected = [ + 'Hello', + 'undefined', + 'Are you there?', + ].join("\n"); + + expect(result).toBe(expected); + }); + + it("should handle empty messages array", () => { + const messages: { role: string; content: string | undefined }[] = []; + const result = formatMessagesForSummary(messages); + expect(result).toBe(""); + }); + + it("should handle single message", () => { + const messages = [{ role: "user", content: "Hello world" }]; + const result = formatMessagesForSummary(messages); + expect(result).toBe('Hello world'); + }); + + it("should correctly calculate omitted messages count", () => { + const messages = Array.from({ length: 20 }, (_, i) => ({ + role: i % 2 === 0 ? "user" : "assistant", + content: `Message ${i + 1}`, + })); + + const result = formatMessagesForSummary(messages); + + // Should indicate 12 messages omitted (20 total - 2 first - 6 last = 12) + expect(result).toContain( + '[... 12 messages omitted ...]', + ); + }); + + it("should handle messages with special characters in content", () => { + const messages = [ + { role: "user", content: 'Hello & "friends"' }, + { role: "assistant", content: "Hi there! content" }, + ]; + + const result = formatMessagesForSummary(messages); + + // Should preserve special characters as-is (no HTML escaping) + expect(result).toContain( + 'Hello & "friends"', + ); + expect(result).toContain( + 'Hi there! content', + ); + }); + + it("should maintain message order in truncated output", () => { + const messages = Array.from({ length: 15 }, (_, i) => ({ + role: i % 2 === 0 ? "user" : "assistant", + content: `Message ${i + 1}`, + })); + + const result = formatMessagesForSummary(messages); + const lines = result.split("\n"); + + // Should have exactly 9 lines (2 first + 1 omission + 6 last) + expect(lines).toHaveLength(9); + + // Check order: first 2, then omission, then last 6 + expect(lines[0]).toBe('Message 1'); + expect(lines[1]).toBe('Message 2'); + expect(lines[2]).toBe( + '[... 7 messages omitted ...]', + ); + + // Last 6 messages are messages 10-15 (indices 9-14) + // Message 10 (index 9): 9 % 2 === 1, so "assistant" + // Message 11 (index 10): 10 % 2 === 0, so "user" + // Message 12 (index 11): 11 % 2 === 1, so "assistant" + // Message 13 (index 12): 12 % 2 === 0, so "user" + // Message 14 (index 13): 13 % 2 === 1, so "assistant" + // Message 15 (index 14): 14 % 2 === 0, so "user" + expect(lines[3]).toBe('Message 10'); + expect(lines[4]).toBe('Message 11'); + expect(lines[5]).toBe('Message 12'); + expect(lines[6]).toBe('Message 13'); + expect(lines[7]).toBe('Message 14'); + expect(lines[8]).toBe('Message 15'); + }); +}); diff --git a/src/ipc/handlers/chat_stream_handlers.ts b/src/ipc/handlers/chat_stream_handlers.ts index cda6416..a2620cd 100644 --- a/src/ipc/handlers/chat_stream_handlers.ts +++ b/src/ipc/handlers/chat_stream_handlers.ts @@ -446,7 +446,7 @@ This conversation includes one or more image attachments. When the user uploads role: "user", content: "Summarize the following chat: " + - formatMessages(previousChat?.messages ?? []), + formatMessagesForSummary(previousChat?.messages ?? []), } satisfies CoreMessage, ]; } @@ -775,10 +775,31 @@ This conversation includes one or more image attachments. When the user uploads }); } -export function formatMessages( +export function formatMessagesForSummary( messages: { role: string; content: string | undefined }[], ) { - return messages + if (messages.length <= 8) { + // If we have 8 or fewer messages, include all of them + return messages + .map((m) => `${m.content}`) + .join("\n"); + } + + // Take first 2 messages and last 6 messages + const firstMessages = messages.slice(0, 2); + const lastMessages = messages.slice(-6); + + // Combine them with an indicator of skipped messages + const combinedMessages = [ + ...firstMessages, + { + role: "system", + content: `[... ${messages.length - 8} messages omitted ...]`, + }, + ...lastMessages, + ]; + + return combinedMessages .map((m) => `${m.content}`) .join("\n"); }