🌿 fix: forking a long conversation breaks chat structure (#4778)

* fix: branching and forking sometimes break conversation structure

* fix test for forking.

* chore: message type issues

* test: add conversation structure tests for message handling

---------

Co-authored-by: xyqyear <xyqyear@gmail.com>
This commit is contained in:
Danny Avila 2024-11-22 16:10:59 -05:00 committed by GitHub
parent 7d5be68747
commit c87a51eaab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 248 additions and 16 deletions

View file

@ -73,15 +73,17 @@ async function saveMessage(req, params, metadata) {
* @async * @async
* @function bulkSaveMessages * @function bulkSaveMessages
* @param {Object[]} messages - An array of message objects to save. * @param {Object[]} messages - An array of message objects to save.
* @param {boolean} [overrideTimestamp=false] - Indicates whether to override the timestamps of the messages. Defaults to false.
* @returns {Promise<Object>} The result of the bulk write operation. * @returns {Promise<Object>} The result of the bulk write operation.
* @throws {Error} If there is an error in saving messages in bulk. * @throws {Error} If there is an error in saving messages in bulk.
*/ */
async function bulkSaveMessages(messages) { async function bulkSaveMessages(messages, overrideTimestamp=false) {
try { try {
const bulkOps = messages.map((message) => ({ const bulkOps = messages.map((message) => ({
updateOne: { updateOne: {
filter: { messageId: message.messageId }, filter: { messageId: message.messageId },
update: message, update: message,
timestamps: !overrideTimestamp,
upsert: true, upsert: true,
}, },
})); }));

View file

@ -0,0 +1,223 @@
const mongoose = require('mongoose');
const { MongoMemoryServer } = require('mongodb-memory-server');
const { Message, getMessages, bulkSaveMessages } = require('./Message');
// Original version of buildTree function
function buildTree({ messages, fileMap }) {
if (messages === null) {
return null;
}
const messageMap = {};
const rootMessages = [];
const childrenCount = {};
messages.forEach((message) => {
const parentId = message.parentMessageId ?? '';
childrenCount[parentId] = (childrenCount[parentId] || 0) + 1;
const extendedMessage = {
...message,
children: [],
depth: 0,
siblingIndex: childrenCount[parentId] - 1,
};
if (message.files && fileMap) {
extendedMessage.files = message.files.map((file) => fileMap[file.file_id ?? ''] ?? file);
}
messageMap[message.messageId] = extendedMessage;
const parentMessage = messageMap[parentId];
if (parentMessage) {
parentMessage.children.push(extendedMessage);
extendedMessage.depth = parentMessage.depth + 1;
} else {
rootMessages.push(extendedMessage);
}
});
return rootMessages;
}
let mongod;
beforeAll(async () => {
mongod = await MongoMemoryServer.create();
const uri = mongod.getUri();
await mongoose.connect(uri);
});
afterAll(async () => {
await mongoose.disconnect();
await mongod.stop();
});
beforeEach(async () => {
await Message.deleteMany({});
});
describe('Conversation Structure Tests', () => {
test('Conversation folding/corrupting with inconsistent timestamps', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';
// Create messages with inconsistent timestamps
const messages = [
{
messageId: 'message0',
parentMessageId: null,
text: 'Message 0',
createdAt: new Date('2023-01-01T00:00:00Z'),
},
{
messageId: 'message1',
parentMessageId: 'message0',
text: 'Message 1',
createdAt: new Date('2023-01-01T00:02:00Z'),
},
{
messageId: 'message2',
parentMessageId: 'message1',
text: 'Message 2',
createdAt: new Date('2023-01-01T00:01:00Z'),
}, // Note: Earlier than its parent
{
messageId: 'message3',
parentMessageId: 'message1',
text: 'Message 3',
createdAt: new Date('2023-01-01T00:03:00Z'),
},
{
messageId: 'message4',
parentMessageId: 'message2',
text: 'Message 4',
createdAt: new Date('2023-01-01T00:04:00Z'),
},
];
// Add common properties to all messages
messages.forEach((msg) => {
msg.conversationId = conversationId;
msg.user = userId;
msg.isCreatedByUser = false;
msg.error = false;
msg.unfinished = false;
});
// Save messages with overrideTimestamp omitted (default is false)
await bulkSaveMessages(messages, true);
// Retrieve messages (this will sort by createdAt)
const retrievedMessages = await getMessages({ conversationId, user: userId });
// Build tree
const tree = buildTree({ messages: retrievedMessages });
// Check if the tree is incorrect (folded/corrupted)
expect(tree.length).toBeGreaterThan(1); // Should have multiple root messages, indicating corruption
});
test('Fix: Conversation structure maintained with more than 16 messages', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';
// Create more than 16 messages
const messages = Array.from({ length: 20 }, (_, i) => ({
messageId: `message${i}`,
parentMessageId: i === 0 ? null : `message${i - 1}`,
conversationId,
user: userId,
text: `Message ${i}`,
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 500000 : -i * 500000)),
}));
// Save messages with new timestamps being generated (message objects ignored)
await bulkSaveMessages(messages);
// Retrieve messages (this will sort by createdAt, but it shouldn't matter now)
const retrievedMessages = await getMessages({ conversationId, user: userId });
// Build tree
const tree = buildTree({ messages: retrievedMessages });
// Check if the tree is correct
expect(tree.length).toBe(1); // Should have only one root message
let currentNode = tree[0];
for (let i = 1; i < 20; i++) {
expect(currentNode.children.length).toBe(1);
currentNode = currentNode.children[0];
expect(currentNode.text).toBe(`Message ${i}`);
}
expect(currentNode.children.length).toBe(0); // Last message should have no children
});
test('Simulate MongoDB ordering issue with more than 16 messages and close timestamps', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';
// Create more than 16 messages with very close timestamps
const messages = Array.from({ length: 20 }, (_, i) => ({
messageId: `message${i}`,
parentMessageId: i === 0 ? null : `message${i - 1}`,
conversationId,
user: userId,
text: `Message ${i}`,
createdAt: new Date(Date.now() + (i % 2 === 0 ? i * 1 : -i * 1)),
}));
// Add common properties to all messages
messages.forEach((msg) => {
msg.isCreatedByUser = false;
msg.error = false;
msg.unfinished = false;
});
await bulkSaveMessages(messages, true);
const retrievedMessages = await getMessages({ conversationId, user: userId });
const tree = buildTree({ messages: retrievedMessages });
expect(tree.length).toBeGreaterThan(1);
});
test('Fix: Preserve order with more than 16 messages by maintaining original timestamps', async () => {
const userId = 'testUser';
const conversationId = 'testConversation';
// Create more than 16 messages with distinct timestamps
const messages = Array.from({ length: 20 }, (_, i) => ({
messageId: `message${i}`,
parentMessageId: i === 0 ? null : `message${i - 1}`,
conversationId,
user: userId,
text: `Message ${i}`,
createdAt: new Date(Date.now() + i * 1000), // Ensure each message has a distinct timestamp
}));
// Add common properties to all messages
messages.forEach((msg) => {
msg.isCreatedByUser = false;
msg.error = false;
msg.unfinished = false;
});
// Save messages with overriding timestamps (preserve original timestamps)
await bulkSaveMessages(messages, true);
// Retrieve messages (this will sort by createdAt)
const retrievedMessages = await getMessages({ conversationId, user: userId });
// Build tree
const tree = buildTree({ messages: retrievedMessages });
// Check if the tree is correct
expect(tree.length).toBe(1); // Should have only one root message
let currentNode = tree[0];
for (let i = 1; i < 20; i++) {
expect(currentNode.children.length).toBe(1);
currentNode = currentNode.children[0];
expect(currentNode.text).toBe(`Message ${i}`);
}
expect(currentNode.children.length).toBe(0); // Last message should have no children
});
});

View file

@ -104,7 +104,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith( expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining( expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })), expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
), ), true,
); );
}); });
@ -122,7 +122,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith( expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining( expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })), expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
), ), true,
); );
}); });
@ -141,7 +141,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith( expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining( expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })), expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
), ), true,
); );
}); });
@ -160,7 +160,7 @@ describe('forkConversation', () => {
expect(bulkSaveMessages).toHaveBeenCalledWith( expect(bulkSaveMessages).toHaveBeenCalledWith(
expect.arrayContaining( expect.arrayContaining(
expectedMessagesTexts.map((text) => expect.objectContaining({ text })), expectedMessagesTexts.map((text) => expect.objectContaining({ text })),
), ), true,
); );
}); });

View file

@ -99,7 +99,7 @@ class ImportBatchBuilder {
async saveBatch() { async saveBatch() {
try { try {
await bulkSaveConvos(this.conversations); await bulkSaveConvos(this.conversations);
await bulkSaveMessages(this.messages); await bulkSaveMessages(this.messages, true);
logger.debug( logger.debug(
`user: ${this.requestUserId} | Added ${this.conversations.length} conversations and ${this.messages.length} messages to the DB.`, `user: ${this.requestUserId} | Added ${this.conversations.length} conversations and ${this.messages.length} messages to the DB.`,
); );

View file

@ -24,10 +24,10 @@ export default function Message({ message }: Pick<TMessageProps, 'message'>) {
let messageLabel = ''; let messageLabel = '';
if (isCreatedByUser) { if (isCreatedByUser) {
messageLabel = UsernameDisplay messageLabel = UsernameDisplay
? (user?.name ?? '') || user?.username ? (user?.name ?? '') || (user?.username ?? '')
: localize('com_user_message'); : localize('com_user_message');
} else { } else {
messageLabel = message.sender; messageLabel = message.sender || '';
} }
return ( return (

View file

@ -28,13 +28,20 @@ export default function Message(props: TMessageProps) {
return null; return null;
} }
const { text, children, messageId = null, isCreatedByUser, error, unfinished } = message ?? {}; const {
text = '',
children,
messageId = null,
isCreatedByUser = true,
error = false,
unfinished = false,
} = message;
let messageLabel = ''; let messageLabel = '';
if (isCreatedByUser) { if (isCreatedByUser) {
messageLabel = 'anonymous'; messageLabel = 'anonymous';
} else { } else {
messageLabel = message.sender; messageLabel = message.sender || '';
} }
return ( return (
@ -67,12 +74,12 @@ export default function Message(props: TMessageProps) {
error={error} error={error}
isLast={false} isLast={false}
ask={() => ({})} ask={() => ({})}
text={text ?? ''} text={text}
message={message} message={message}
isSubmitting={false} isSubmitting={false}
enterEdit={() => ({})} enterEdit={() => ({})}
unfinished={!!unfinished} unfinished={!!unfinished}
isCreatedByUser={isCreatedByUser ?? true} isCreatedByUser={isCreatedByUser}
siblingIdx={siblingIdx ?? 0} siblingIdx={siblingIdx ?? 0}
setSiblingIdx={setSiblingIdx ?? (() => ({}))} setSiblingIdx={setSiblingIdx ?? (() => ({}))}
/> />

View file

@ -67,11 +67,11 @@ export default function useExportConversation({
}; };
if (!message.content) { if (!message.content) {
return formatText(message.sender, message.text); return formatText(message.sender || '', message.text);
} }
return message.content return message.content
.map((content) => getMessageContent(message.sender, content)) .map((content) => getMessageContent(message.sender || '', content))
.map((text) => { .map((text) => {
return formatText(text[0], text[1]); return formatText(text[0], text[1]);
}) })

View file

@ -445,12 +445,12 @@ export const tMessageSchema = z.object({
bg: z.string().nullable().optional(), bg: z.string().nullable().optional(),
model: z.string().nullable().optional(), model: z.string().nullable().optional(),
title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'), title: z.string().nullable().or(z.literal('New Chat')).default('New Chat'),
sender: z.string(), sender: z.string().optional(),
text: z.string(), text: z.string(),
generation: z.string().nullable().optional(), generation: z.string().nullable().optional(),
isEdited: z.boolean().optional(), isEdited: z.boolean().optional(),
isCreatedByUser: z.boolean(), isCreatedByUser: z.boolean(),
error: z.boolean(), error: z.boolean().optional(),
createdAt: z createdAt: z
.string() .string()
.optional() .optional()