diff --git a/api/app/clients/BaseClient.js b/api/app/clients/BaseClient.js index 20b31a5e3e..a2dfaf9907 100644 --- a/api/app/clients/BaseClient.js +++ b/api/app/clients/BaseClient.js @@ -938,6 +938,7 @@ class BaseClient { throw new Error('User mismatch.'); } + const hasAddedConvo = this.options?.req?.body?.addedConvo != null; const savedMessage = await saveMessage( this.options?.req, { @@ -945,6 +946,7 @@ class BaseClient { endpoint: this.options.endpoint, unfinished: false, user, + ...(hasAddedConvo && { addedConvo: true }), }, { context: 'api/app/clients/BaseClient.js - saveMessageToDatabase #saveMessage' }, ); @@ -1025,7 +1027,8 @@ class BaseClient { * @param {Object} options - The options for the function. * @param {TMessage[]} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property. * @param {string} options.parentMessageId - The ID of the parent message to start the traversal from. - * @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. If provided, it will be applied to each message in the resulting array. + * @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. Applied conditionally based on mapCondition. + * @param {(message: TMessage) => boolean} [options.mapCondition] - An optional function to determine whether mapMethod should be applied to a given message. If not provided and mapMethod is set, mapMethod applies to all messages. * @param {boolean} [options.summary=false] - If set to true, the traversal modifies messages with 'summary' and 'summaryTokenCount' properties and stops at the message with a 'summary' property. * @returns {TMessage[]} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'. */ @@ -1033,6 +1036,7 @@ class BaseClient { messages, parentMessageId, mapMethod = null, + mapCondition = null, summary = false, }) { if (!messages || messages.length === 0) { @@ -1067,7 +1071,9 @@ class BaseClient { message.tokenCount = message.summaryTokenCount; } - orderedMessages.push(message); + const shouldMap = mapMethod != null && (mapCondition != null ? mapCondition(message) : true); + const processedMessage = shouldMap ? mapMethod(message) : message; + orderedMessages.push(processedMessage); if (summary && message.summary) { break; @@ -1078,11 +1084,6 @@ class BaseClient { } orderedMessages.reverse(); - - if (mapMethod) { - return orderedMessages.map(mapMethod); - } - return orderedMessages; } diff --git a/api/server/controllers/agents/client.js b/api/server/controllers/agents/client.js index b29f974485..2f0b576dbb 100644 --- a/api/server/controllers/agents/client.js +++ b/api/server/controllers/agents/client.js @@ -361,14 +361,13 @@ class AgentClient extends BaseClient { { instructions = null, additional_instructions = null }, opts, ) { - const hasAddedConvo = this.options.req?.body?.addedConvo != null; + /** Always pass mapMethod; getMessagesForConversation applies it only to messages with addedConvo flag */ const orderedMessages = this.constructor.getMessagesForConversation({ messages, parentMessageId, summary: this.shouldSummarize, - mapMethod: hasAddedConvo - ? createMultiAgentMapper(this.options.agent, this.agentConfigs) - : undefined, + mapMethod: createMultiAgentMapper(this.options.agent, this.agentConfigs), + mapCondition: (message) => message.addedConvo === true, }); let payload; diff --git a/api/server/controllers/agents/client.test.js b/api/server/controllers/agents/client.test.js index 0ce59c5fbc..f8abf60955 100644 --- a/api/server/controllers/agents/client.test.js +++ b/api/server/controllers/agents/client.test.js @@ -1611,4 +1611,223 @@ describe('AgentClient - titleConvo', () => { expect(mockProcessMemory).not.toHaveBeenCalled(); }); }); + + describe('getMessagesForConversation - mapMethod and mapCondition', () => { + const createMessage = (id, parentId, text, extras = {}) => ({ + messageId: id, + parentMessageId: parentId, + text, + isCreatedByUser: false, + ...extras, + }); + + it('should apply mapMethod to all messages when mapCondition is not provided', () => { + const messages = [ + createMessage('msg-1', null, 'First message'), + createMessage('msg-2', 'msg-1', 'Second message'), + createMessage('msg-3', 'msg-2', 'Third message'), + ]; + + const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true })); + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-3', + mapMethod, + }); + + expect(result).toHaveLength(3); + expect(mapMethod).toHaveBeenCalledTimes(3); + result.forEach((msg) => { + expect(msg.mapped).toBe(true); + }); + }); + + it('should apply mapMethod only to messages where mapCondition returns true', () => { + const messages = [ + createMessage('msg-1', null, 'First message', { addedConvo: false }), + createMessage('msg-2', 'msg-1', 'Second message', { addedConvo: true }), + createMessage('msg-3', 'msg-2', 'Third message', { addedConvo: true }), + createMessage('msg-4', 'msg-3', 'Fourth message', { addedConvo: false }), + ]; + + const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true })); + const mapCondition = (msg) => msg.addedConvo === true; + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-4', + mapMethod, + mapCondition, + }); + + expect(result).toHaveLength(4); + expect(mapMethod).toHaveBeenCalledTimes(2); + + expect(result[0].mapped).toBeUndefined(); + expect(result[1].mapped).toBe(true); + expect(result[2].mapped).toBe(true); + expect(result[3].mapped).toBeUndefined(); + }); + + it('should not apply mapMethod when mapCondition returns false for all messages', () => { + const messages = [ + createMessage('msg-1', null, 'First message', { addedConvo: false }), + createMessage('msg-2', 'msg-1', 'Second message', { addedConvo: false }), + ]; + + const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true })); + const mapCondition = (msg) => msg.addedConvo === true; + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-2', + mapMethod, + mapCondition, + }); + + expect(result).toHaveLength(2); + expect(mapMethod).not.toHaveBeenCalled(); + result.forEach((msg) => { + expect(msg.mapped).toBeUndefined(); + }); + }); + + it('should not call mapMethod when mapMethod is null', () => { + const messages = [ + createMessage('msg-1', null, 'First message'), + createMessage('msg-2', 'msg-1', 'Second message'), + ]; + + const mapCondition = jest.fn(() => true); + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-2', + mapMethod: null, + mapCondition, + }); + + expect(result).toHaveLength(2); + expect(mapCondition).not.toHaveBeenCalled(); + }); + + it('should handle mapCondition with complex logic', () => { + const messages = [ + createMessage('msg-1', null, 'User message', { isCreatedByUser: true, addedConvo: true }), + createMessage('msg-2', 'msg-1', 'Assistant response', { addedConvo: true }), + createMessage('msg-3', 'msg-2', 'Another user message', { isCreatedByUser: true }), + createMessage('msg-4', 'msg-3', 'Another response', { addedConvo: true }), + ]; + + const mapMethod = jest.fn((msg) => ({ ...msg, processed: true })); + const mapCondition = (msg) => msg.addedConvo === true && !msg.isCreatedByUser; + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-4', + mapMethod, + mapCondition, + }); + + expect(result).toHaveLength(4); + expect(mapMethod).toHaveBeenCalledTimes(2); + + expect(result[0].processed).toBeUndefined(); + expect(result[1].processed).toBe(true); + expect(result[2].processed).toBeUndefined(); + expect(result[3].processed).toBe(true); + }); + + it('should preserve message order after applying mapMethod with mapCondition', () => { + const messages = [ + createMessage('msg-1', null, 'First', { addedConvo: true }), + createMessage('msg-2', 'msg-1', 'Second', { addedConvo: false }), + createMessage('msg-3', 'msg-2', 'Third', { addedConvo: true }), + ]; + + const mapMethod = (msg) => ({ ...msg, text: `[MAPPED] ${msg.text}` }); + const mapCondition = (msg) => msg.addedConvo === true; + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-3', + mapMethod, + mapCondition, + }); + + expect(result[0].text).toBe('[MAPPED] First'); + expect(result[1].text).toBe('Second'); + expect(result[2].text).toBe('[MAPPED] Third'); + }); + + it('should work with summary option alongside mapMethod and mapCondition', () => { + const messages = [ + createMessage('msg-1', null, 'First', { addedConvo: false }), + createMessage('msg-2', 'msg-1', 'Second', { + summary: 'Summary of conversation', + addedConvo: true, + }), + createMessage('msg-3', 'msg-2', 'Third', { addedConvo: true }), + createMessage('msg-4', 'msg-3', 'Fourth', { addedConvo: false }), + ]; + + const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true })); + const mapCondition = (msg) => msg.addedConvo === true; + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-4', + mapMethod, + mapCondition, + summary: true, + }); + + /** Traversal stops at msg-2 (has summary), so we get msg-4 -> msg-3 -> msg-2 */ + expect(result).toHaveLength(3); + expect(result[0].text).toBe('Summary of conversation'); + expect(result[0].role).toBe('system'); + expect(result[0].mapped).toBe(true); + expect(result[1].mapped).toBe(true); + expect(result[2].mapped).toBeUndefined(); + }); + + it('should handle empty messages array', () => { + const mapMethod = jest.fn(); + const mapCondition = jest.fn(); + + const result = AgentClient.getMessagesForConversation({ + messages: [], + parentMessageId: 'msg-1', + mapMethod, + mapCondition, + }); + + expect(result).toHaveLength(0); + expect(mapMethod).not.toHaveBeenCalled(); + expect(mapCondition).not.toHaveBeenCalled(); + }); + + it('should handle undefined mapCondition explicitly', () => { + const messages = [ + createMessage('msg-1', null, 'First'), + createMessage('msg-2', 'msg-1', 'Second'), + ]; + + const mapMethod = jest.fn((msg) => ({ ...msg, mapped: true })); + + const result = AgentClient.getMessagesForConversation({ + messages, + parentMessageId: 'msg-2', + mapMethod, + mapCondition: undefined, + }); + + expect(result).toHaveLength(2); + expect(mapMethod).toHaveBeenCalledTimes(2); + result.forEach((msg) => { + expect(msg.mapped).toBe(true); + }); + }); + }); }); diff --git a/packages/data-schemas/src/schema/message.ts b/packages/data-schemas/src/schema/message.ts index 8bfdb1b39e..f960194541 100644 --- a/packages/data-schemas/src/schema/message.ts +++ b/packages/data-schemas/src/schema/message.ts @@ -140,6 +140,10 @@ const messageSchema: Schema = new Schema( expiredAt: { type: Date, }, + addedConvo: { + type: Boolean, + default: undefined, + }, }, { timestamps: true }, ); diff --git a/packages/data-schemas/src/types/message.ts b/packages/data-schemas/src/types/message.ts index f69bcff6b9..2ca262a6bb 100644 --- a/packages/data-schemas/src/types/message.ts +++ b/packages/data-schemas/src/types/message.ts @@ -37,6 +37,7 @@ export interface IMessage extends Document { content?: unknown[]; thread_id?: string; iconURL?: string; + addedConvo?: boolean; metadata?: Record; attachments?: unknown[]; expiredAt?: Date;