feat: ConversationSummaryBufferMemory (#973)

* refactor: pass model in message edit payload, use encoder in standalone util function

* feat: add summaryBuffer helper

* refactor(api/messages): use new countTokens helper and add auth middleware at top

* wip: ConversationSummaryBufferMemory

* refactor: move pre-generation helpers to prompts dir

* chore: remove console log

* chore: remove test as payload will no longer carry tokenCount

* chore: update getMessagesWithinTokenLimit JSDoc

* refactor: optimize getMessagesForConversation and also break on summary, feat(ci): getMessagesForConversation tests

* refactor(getMessagesForConvo): count '00000000-0000-0000-0000-000000000000' as root message

* chore: add newer model to token map

* fix: condition was pointing to prop of array instead of message prop

* refactor(BaseClient): use object for refineMessages param, rename 'summary' to 'summaryMessage', add previous_summary
refactor(getMessagesWithinTokenLimit): replace text and tokenCount when summarizing and both summary and summaryTokenCount are present
fix/refactor(handleContextStrategy): use the right comparison length for context diff, and replace the payload's first message when a summary is present

* chore: log previous_summary if debugging

* refactor(formatMessage): assume if role is defined that it's a valid value

* refactor(getMessagesWithinTokenLimit): remove summary logic
refactor(handleContextStrategy): add usePrevSummary logic in case only the summary was pruned
refactor(loadHistory): initial message query will return all ordered messages but keep track of the latest summary
refactor(getMessagesForConversation): use object for single param, edit jsdoc, edit all files using the method
refactor(ChatGPTClient): order messages before buildPrompt is called, TODO: add convoSumBuffMemory logic

* fix: undefined handling and summarizing only when shouldRefineContext is true

* chore(BaseClient): fix test results omitting system role for summaries and test edge case

* chore: export summaryBuffer from index file

* refactor(OpenAIClient/BaseClient): move refineMessages to subclass, implement LLM initialization for summaryBuffer

* feat: add OPENAI_SUMMARIZE to enable summarizing, refactor: rename client prop 'shouldRefineContext' to 'shouldSummarize', change contextStrategy value to 'summarize' from 'refine'

* refactor: rename refineMessages method to summarizeMessages for clarity

* chore: clarify summary future intent in .env.example

* refactor(initializeLLM): handle case for either 'model' or 'modelName' being passed

* feat(gptPlugins): enable summarization for plugins

* refactor(gptPlugins): utilize new initializeLLM method and formatting methods for messages, use payload array for currentMessages and assign pastMessages sooner

* refactor(agents): use ConversationSummaryBufferMemory for both agent types

* refactor(formatMessage): optimize original method for langchain, add helper function for langchain messages, add JSDocs and tests

* refactor(summaryBuffer): add helper to createSummaryBufferMemory, and use new formatting helpers

* fix: forgot to spread formatMessages; also took the opportunity to pluralize the filename

* refactor: pass memory to tools, namely OpenAPI specs; not used (and may never be used) by the new method, but added for testing

* ci(formatMessages): add more exhaustive checks for langchain messages

* feat: add debug env var for OpenAI

* chore: delete unnecessary comments

* chore: add extra note about summary feature

* fix: remove tokenCount from payload instructions

* fix: failing test

* fix: only pass instructions to payload when defined and not an empty object

* refactor: fromPromptMessages is deprecated, use renamed method fromMessages

* refactor: use 'includes' instead of 'startsWith' for extended OpenRouter compatibility

* fix(PluginsClient.buildPromptBody): handle undefined message strings

* chore: log langchain titling error

* feat: getModelMaxTokens helper

* feat: tokenSplit helper

* feat: summary prompts updated

* fix: optimize _CUT_OFF_SUMMARIZER prompt

* refactor(summaryBuffer): use custom summary prompt, allow prompt to be passed, pass humanPrefix and aiPrefix to memory (along with any future variables), rename messagesToRefine to context

* fix(summaryBuffer): handle edge case where messagesToRefine exceeds summary context,
refactor(BaseClient): allow custom maxContextTokens to be passed to getMessagesWithinTokenLimit, add defined check before unshifting summaryMessage, update shouldSummarize based on this
refactor(OpenAIClient): use getModelMaxTokens, use cut-off message method for summary if no messages were left after pruning

* fix(handleContextStrategy): handle case where incoming prompt is bigger than model context

* chore: rename refinedContent to splitText

* chore: remove unnecessary debug log
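
The summaryBuffer commits above build on langchain's ConversationSummaryBufferMemory. A minimal sketch of that pattern follows; the helper names (createSummaryBufferMemory, summaryBuffer), option handling, and return shape are assumptions for illustration, not the PR's exact code.

```js
// Hedged sketch only: helper and option names are assumptions, not the PR's exact implementation.
const { ConversationSummaryBufferMemory, ChatMessageHistory } = require('langchain/memory');
const { ChatOpenAI } = require('langchain/chat_models/openai');
const { HumanMessage, AIMessage } = require('langchain/schema');

// Wrap the pruned messages in a chat history and hand them to the summary memory.
const createSummaryBufferMemory = ({ llm, prompt, messages, ...rest }) =>
  new ConversationSummaryBufferMemory({
    llm,
    prompt, // custom summary prompt, if provided
    chatHistory: new ChatMessageHistory(messages),
    returnMessages: true,
    ...rest, // e.g. humanPrefix, aiPrefix
  });

// Fold the pruned context into a rolling summary, seeded with any previous summary.
const summaryBuffer = async ({ llm, context, previous_summary = '', prompt, ...rest }) => {
  const memory = createSummaryBufferMemory({ llm, prompt, messages: context, ...rest });
  const messages = await memory.chatHistory.getMessages();
  const predictedSummary = await memory.predictNewSummary(messages, previous_summary);
  return { summaryMessage: { role: 'system', content: predictedSummary } };
};

// Usage: summarize two pruned turns with a deterministic model.
(async () => {
  const llm = new ChatOpenAI({ modelName: 'gpt-3.5-turbo', temperature: 0 });
  const context = [new HumanMessage('What does this PR add?'), new AIMessage('Summary buffer memory.')];
  console.log(await summaryBuffer({ llm, context }));
})();
```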
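
Likewise, the getModelMaxTokens helper and the 'includes' match for OpenRouter compatibility can be pictured roughly as below; the token values and the fallback are illustrative assumptions, not the project's actual map.

```js
// Illustrative sketch: values and fallback are assumptions, not LibreChat's actual token map.
const maxTokensMap = {
  'gpt-4': 8191,
  'gpt-4-32k': 32767,
  'gpt-3.5-turbo': 4095,
  'gpt-3.5-turbo-16k': 15999,
};

// Resolve a model's max context size; match by inclusion so prefixed names
// (e.g. 'openai/gpt-3.5-turbo' from OpenRouter) still resolve.
function getModelMaxTokens(modelName, fallback = 4095) {
  // Check longer keys first so 'gpt-3.5-turbo-16k' does not fall through to 'gpt-3.5-turbo'.
  const key = Object.keys(maxTokensMap)
    .sort((a, b) => b.length - a.length)
    .find((model) => modelName.includes(model));
  return key ? maxTokensMap[key] : fallback;
}

console.log(getModelMaxTokens('openai/gpt-3.5-turbo-16k')); // 15999
```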
Danny Avila 2023-09-26 21:02:28 -04:00 committed by GitHub
parent be73deddcc
commit 317a1bd8da
46 changed files with 1410 additions and 440 deletions


@@ -1,11 +1,7 @@
const crypto = require('crypto');
const TextStream = require('./TextStream');
const { RecursiveCharacterTextSplitter } = require('langchain/text_splitter');
const { ChatOpenAI } = require('langchain/chat_models/openai');
const { loadSummarizationChain } = require('langchain/chains');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('../../models');
const { addSpaceIfNeeded } = require('../../server/utils');
const { refinePrompt } = require('./prompts');
class BaseClient {
constructor(apiKey, options = {}) {
@@ -39,6 +35,10 @@ class BaseClient {
throw new Error('Subclasses must implement buildMessages');
}
async summarizeMessages() {
throw new Error('Subclasses attempted to call summarizeMessages without implementing it');
}
getBuildMessagesOptions() {
throw new Error('Subclasses must implement getBuildMessagesOptions');
}
@@ -137,9 +137,18 @@ class BaseClient {
};
}
/**
* Adds instructions to the messages array. If the instructions object is empty or undefined,
* the original messages array is returned. Otherwise, the instructions are added to the messages
* array, preserving the last message at the end.
*
* @param {Array} messages - An array of messages.
* @param {Object} instructions - An object containing instructions to be added to the messages.
* @returns {Array} An array containing messages and instructions, or the original messages if instructions are empty.
*/
addInstructions(messages, instructions) {
const payload = [];
if (!instructions) {
if (!instructions || Object.keys(instructions).length === 0) {
return messages;
}
if (messages.length > 1) {
@@ -170,19 +179,15 @@ class BaseClient {
const { messageId } = message;
const update = {};
if (messageId === tokenCountMap.refined?.messageId) {
if (this.options.debug) {
console.debug(`Adding refined props to ${messageId}.`);
}
if (messageId === tokenCountMap.summaryMessage?.messageId) {
this.options.debug && console.debug(`Adding summary props to ${messageId}.`);
update.refinedMessageText = tokenCountMap.refined.content;
update.refinedTokenCount = tokenCountMap.refined.tokenCount;
update.summary = tokenCountMap.summaryMessage.content;
update.summaryTokenCount = tokenCountMap.summaryMessage.tokenCount;
}
if (message.tokenCount && !update.refinedTokenCount) {
if (this.options.debug) {
console.debug(`Skipping ${messageId}: already had a token count.`);
}
if (message.tokenCount && !update.summaryTokenCount) {
this.options.debug && console.debug(`Skipping ${messageId}: already had a token count.`);
continue;
}
@@ -202,193 +207,141 @@ class BaseClient {
}, '');
}
async refineMessages(messagesToRefine, remainingContextTokens) {
const model = new ChatOpenAI({ temperature: 0 });
const chain = loadSummarizationChain(model, {
type: 'refine',
verbose: this.options.debug,
refinePrompt,
});
const splitter = new RecursiveCharacterTextSplitter({
chunkSize: 1500,
chunkOverlap: 100,
});
const userMessages = this.concatenateMessages(
messagesToRefine.filter((m) => m.role === 'user'),
);
const assistantMessages = this.concatenateMessages(
messagesToRefine.filter((m) => m.role !== 'user'),
);
const userDocs = await splitter.createDocuments([userMessages], [], {
chunkHeader: 'DOCUMENT NAME: User Message\n\n---\n\n',
appendChunkOverlapHeader: true,
});
const assistantDocs = await splitter.createDocuments([assistantMessages], [], {
chunkHeader: 'DOCUMENT NAME: Assistant Message\n\n---\n\n',
appendChunkOverlapHeader: true,
});
// const chunkSize = Math.round(concatenatedMessages.length / 512);
const input_documents = userDocs.concat(assistantDocs);
if (this.options.debug) {
console.debug('Refining messages...');
}
try {
const res = await chain.call({
input_documents,
signal: this.abortController.signal,
});
const refinedMessage = {
role: 'assistant',
content: res.output_text,
tokenCount: this.getTokenCount(res.output_text),
};
if (this.options.debug) {
console.debug('Refined messages', refinedMessage);
console.debug(
`remainingContextTokens: ${remainingContextTokens}, after refining: ${
remainingContextTokens - refinedMessage.tokenCount
}`,
);
}
return refinedMessage;
} catch (e) {
console.error('Error refining messages');
console.error(e);
return null;
}
}
/**
* This method processes an array of messages and returns a context of messages that fit within a token limit.
* This method processes an array of messages and returns a context of messages that fit within a specified token limit.
* It iterates over the messages from newest to oldest, adding them to the context until the token limit is reached.
* If the token limit would be exceeded by adding a message, that message and possibly the previous one are added to a separate array of messages to refine.
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the arrays at the end to maintain the original order of the messages.
* The method also includes a mechanism to avoid blocking the event loop by waiting for the next tick after each iteration.
* If the token limit would be exceeded by adding a message, that message is not added to the context and remains in the original array.
* The method uses `push` and `pop` operations for efficient array manipulation, and reverses the context array at the end to maintain the original order of the messages.
*
* @param {Array} messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
* @returns {Object} An object with three properties: `context`, `remainingContextTokens`, and `messagesToRefine`. `context` is an array of messages that fit within the token limit. `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context. `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
* @param {Array} _messages - An array of messages, each with a `tokenCount` property. The messages should be ordered from oldest to newest.
* @param {number} [maxContextTokens] - The max number of tokens allowed in the context. If not provided, defaults to `this.maxContextTokens`.
* @returns {Object} An object with four properties: `context`, `summaryIndex`, `remainingContextTokens`, and `messagesToRefine`.
* `context` is an array of messages that fit within the token limit.
* `summaryIndex` is the index of the first message in the `messagesToRefine` array.
* `remainingContextTokens` is the number of tokens remaining within the limit after adding the messages to the context.
* `messagesToRefine` is an array of messages that were not added to the context because they would have exceeded the token limit.
*/
async getMessagesWithinTokenLimit(messages) {
async getMessagesWithinTokenLimit(_messages, maxContextTokens) {
// Every reply is primed with <|start|>assistant<|message|>, so we
// start with 3 tokens for the label after all messages have been counted.
let currentTokenCount = 3;
let context = [];
let messagesToRefine = [];
let refineIndex = -1;
let remainingContextTokens = this.maxContextTokens;
let summaryIndex = -1;
let remainingContextTokens = maxContextTokens ?? this.maxContextTokens;
const messages = [..._messages];
for (let i = messages.length - 1; i >= 0; i--) {
const message = messages[i];
const newTokenCount = currentTokenCount + message.tokenCount;
const exceededLimit = newTokenCount > this.maxContextTokens;
let shouldRefine = exceededLimit && this.shouldRefineContext;
let refineNextMessage = i !== 0 && i !== 1 && context.length > 0;
const context = [];
if (currentTokenCount < remainingContextTokens) {
while (messages.length > 0 && currentTokenCount < remainingContextTokens) {
const poppedMessage = messages.pop();
const { tokenCount } = poppedMessage;
if (shouldRefine) {
messagesToRefine.push(message);
if (refineIndex === -1) {
refineIndex = i;
if (poppedMessage && currentTokenCount + tokenCount <= remainingContextTokens) {
context.push(poppedMessage);
currentTokenCount += tokenCount;
} else {
messages.push(poppedMessage);
break;
}
if (refineNextMessage) {
refineIndex = i + 1;
const removedMessage = context.pop();
messagesToRefine.push(removedMessage);
currentTokenCount -= removedMessage.tokenCount;
remainingContextTokens = this.maxContextTokens - currentTokenCount;
refineNextMessage = false;
}
continue;
} else if (exceededLimit) {
break;
}
context.push(message);
currentTokenCount = newTokenCount;
remainingContextTokens = this.maxContextTokens - currentTokenCount;
await new Promise((resolve) => setImmediate(resolve));
}
const prunedMemory = messages;
summaryIndex = prunedMemory.length - 1;
remainingContextTokens -= currentTokenCount;
return {
context: context.reverse(),
remainingContextTokens,
messagesToRefine: messagesToRefine.reverse(),
refineIndex,
messagesToRefine: prunedMemory,
summaryIndex,
};
}
async handleContextStrategy({ instructions, orderedMessages, formattedMessages }) {
let payload = this.addInstructions(formattedMessages, instructions);
let _instructions;
let tokenCount;
if (instructions) {
({ tokenCount, ..._instructions } = instructions);
}
this.options.debug && _instructions && console.debug('instructions tokenCount', tokenCount);
let payload = this.addInstructions(formattedMessages, _instructions);
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
let { context, remainingContextTokens, messagesToRefine, refineIndex } =
await this.getMessagesWithinTokenLimit(payload);
payload = context;
let refinedMessage;
let { context, remainingContextTokens, messagesToRefine, summaryIndex } =
await this.getMessagesWithinTokenLimit(orderedWithInstructions);
// if (messagesToRefine.length > 0) {
// refinedMessage = await this.refineMessages(messagesToRefine, remainingContextTokens);
// payload.unshift(refinedMessage);
// remainingContextTokens -= refinedMessage.tokenCount;
// }
// if (remainingContextTokens <= instructions?.tokenCount) {
// if (this.options.debug) {
// console.debug(`Remaining context (${remainingContextTokens}) is less than instructions token count: ${instructions.tokenCount}`);
// }
// ({ context, remainingContextTokens, messagesToRefine, refineIndex } = await this.getMessagesWithinTokenLimit(payload));
// payload = context;
// }
// Calculate the difference in length to determine how many messages were discarded if any
let diff = orderedWithInstructions.length - payload.length;
if (this.options.debug) {
console.debug('<---------------------------------DIFF--------------------------------->');
console.debug(
`Difference between payload (${payload.length}) and orderedWithInstructions (${orderedWithInstructions.length}): ${diff}`,
);
this.options.debug &&
console.debug(
'remainingContextTokens, this.maxContextTokens (1/2)',
remainingContextTokens,
this.maxContextTokens,
);
}
// If the difference is positive, slice the orderedWithInstructions array
let summaryMessage;
let summaryTokenCount;
let { shouldSummarize } = this;
// Calculate the difference in length to determine how many messages were discarded if any
const { length } = payload;
const diff = length - context.length;
const firstMessage = orderedWithInstructions[0];
const usePrevSummary =
shouldSummarize &&
diff === 1 &&
firstMessage?.summary &&
this.previous_summary.messageId === firstMessage.messageId;
if (diff > 0) {
orderedWithInstructions = orderedWithInstructions.slice(diff);
payload = payload.slice(diff);
this.options.debug &&
console.debug(
`Difference between original payload (${length}) and context (${context.length}): ${diff}`,
);
}
if (messagesToRefine.length > 0) {
refinedMessage = await this.refineMessages(messagesToRefine, remainingContextTokens);
payload.unshift(refinedMessage);
remainingContextTokens -= refinedMessage.tokenCount;
const latestMessage = orderedWithInstructions[orderedWithInstructions.length - 1];
if (payload.length === 0 && !shouldSummarize && latestMessage) {
throw new Error(
`Prompt token count of ${latestMessage.tokenCount} exceeds max token count of ${this.maxContextTokens}.`,
);
}
if (this.options.debug) {
if (usePrevSummary) {
summaryMessage = { role: 'system', content: firstMessage.summary };
summaryTokenCount = firstMessage.summaryTokenCount;
payload.unshift(summaryMessage);
remainingContextTokens -= summaryTokenCount;
} else if (shouldSummarize && messagesToRefine.length > 0) {
({ summaryMessage, summaryTokenCount } = await this.summarizeMessages({
messagesToRefine,
remainingContextTokens,
}));
summaryMessage && payload.unshift(summaryMessage);
remainingContextTokens -= summaryTokenCount;
}
// Make sure to only continue summarization logic if the summary message was generated
shouldSummarize = summaryMessage && shouldSummarize;
this.options.debug &&
console.debug(
'remainingContextTokens, this.maxContextTokens (2/2)',
remainingContextTokens,
this.maxContextTokens,
);
}
let tokenCountMap = orderedWithInstructions.reduce((map, message, index) => {
if (!message.messageId) {
const { messageId } = message;
if (!messageId) {
return map;
}
if (index === refineIndex) {
map.refined = { ...refinedMessage, messageId: message.messageId };
if (shouldSummarize && index === summaryIndex && !usePrevSummary) {
map.summaryMessage = { ...summaryMessage, messageId, tokenCount: summaryTokenCount };
}
map[message.messageId] = payload[index].tokenCount;
map[messageId] = orderedWithInstructions[index].tokenCount;
return map;
}, {});
@@ -396,9 +349,16 @@ class BaseClient {
if (this.options.debug) {
console.debug('<-------------------------PAYLOAD/TOKEN COUNT MAP------------------------->');
// console.debug('Payload:', payload);
console.debug('Payload:', payload);
console.debug('Token Count Map:', tokenCountMap);
console.debug('Prompt Tokens', promptTokens, remainingContextTokens, this.maxContextTokens);
console.debug(
'Prompt Tokens',
promptTokens,
'remainingContextTokens',
remainingContextTokens,
'this.maxContextTokens',
this.maxContextTokens,
);
}
return { payload, tokenCountMap, promptTokens, messages: orderedWithInstructions };
@@ -445,11 +405,6 @@ class BaseClient {
this.getBuildMessagesOptions(opts),
);
if (this.options.debug) {
console.debug('payload');
console.debug(payload);
}
if (tokenCountMap) {
console.dir(tokenCountMap, { depth: null });
if (tokenCountMap[userMessage.messageId]) {
@@ -458,11 +413,6 @@ class BaseClient {
console.log('userMessage', userMessage);
}
payload = payload.map((message) => {
const messageWithoutTokenCount = message;
delete messageWithoutTokenCount.tokenCount;
return messageWithoutTokenCount;
});
this.handleTokenCountMap(tokenCountMap);
}
@@ -511,7 +461,30 @@ class BaseClient {
mapMethod = this.getMessageMapMethod();
}
return this.constructor.getMessagesForConversation(messages, parentMessageId, mapMethod);
const orderedMessages = this.constructor.getMessagesForConversation({
messages,
parentMessageId,
mapMethod,
});
if (!this.shouldSummarize) {
return orderedMessages;
}
// Find the latest message with a 'summary' property
for (let i = orderedMessages.length - 1; i >= 0; i--) {
if (orderedMessages[i]?.summary) {
this.previous_summary = orderedMessages[i];
break;
}
}
if (this.options.debug && this.previous_summary) {
const { messageId, summary, tokenCount, summaryTokenCount } = this.previous_summary;
console.debug('Previous summary:', { messageId, summary, tokenCount, summaryTokenCount });
}
return orderedMessages;
}
async saveMessageToDatabase(message, endpointOptions, user = null) {
@@ -529,30 +502,79 @@ class BaseClient {
/**
* Iterate through messages, building an array based on the parentMessageId.
* Each message has an id and a parentMessageId. The parentMessageId is the id of the message that this message is a reply to.
* @param messages
* @param parentMessageId
* @returns {*[]} An array containing the messages in the order they should be displayed, starting with the root message.
*
* This function constructs a conversation thread by traversing messages from a given parentMessageId up to the root message.
* It handles cyclic references by ensuring that a message is not processed more than once.
* If the 'summary' option is set to true and a message has a 'summary' property:
* - The message's 'role' is set to 'system'.
* - The message's 'text' is set to its 'summary'.
* - If the message has a 'summaryTokenCount', the message's 'tokenCount' is set to 'summaryTokenCount'.
* The traversal stops at the message with the 'summary' property.
*
* Each message object should have an 'id' or 'messageId' property and may have a 'parentMessageId' property.
* The 'parentMessageId' is the ID of the message that the current message is a reply to.
* If 'parentMessageId' is not present, null, or is '00000000-0000-0000-0000-000000000000',
* the message is considered a root message.
*
* @param {Object} options - The options for the function.
* @param {Array} options.messages - An array of message objects. Each object should have either an 'id' or 'messageId' property, and may have a 'parentMessageId' property.
* @param {string} options.parentMessageId - The ID of the parent message to start the traversal from.
* @param {Function} [options.mapMethod] - An optional function to map over the ordered messages. If provided, it will be applied to each message in the resulting array.
* @param {boolean} [options.summary=false] - If set to true, the traversal modifies messages with 'summary' and 'summaryTokenCount' properties and stops at the message with a 'summary' property.
* @returns {Array} An array containing the messages in the order they should be displayed, starting with the most recent message with a 'summary' property if the 'summary' option is true, and ending with the message identified by 'parentMessageId'.
*/
static getMessagesForConversation(messages, parentMessageId, mapMethod = null) {
static getMessagesForConversation({
messages,
parentMessageId,
mapMethod = null,
summary = false,
}) {
if (!messages || messages.length === 0) {
return [];
}
const orderedMessages = [];
let currentMessageId = parentMessageId;
const visitedMessageIds = new Set();
while (currentMessageId) {
if (visitedMessageIds.has(currentMessageId)) {
break;
}
const message = messages.find((msg) => {
const messageId = msg.messageId ?? msg.id;
return messageId === currentMessageId;
});
visitedMessageIds.add(currentMessageId);
if (!message) {
break;
}
orderedMessages.unshift(message);
currentMessageId = message.parentMessageId;
if (summary && message.summary) {
message.role = 'system';
message.text = message.summary;
}
if (summary && message.summaryTokenCount) {
message.tokenCount = message.summaryTokenCount;
}
orderedMessages.push(message);
if (summary && message.summary) {
break;
}
currentMessageId =
message.parentMessageId === '00000000-0000-0000-0000-000000000000'
? null
: message.parentMessageId;
}
orderedMessages.reverse();
if (mapMethod) {
return orderedMessages.map(mapMethod);
}
@@ -565,6 +587,7 @@ class BaseClient {
* https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
*
* An additional 3 tokens need to be added for assistant label priming after all messages have been counted.
* In our implementation, this is accounted for in the getMessagesWithinTokenLimit method.
*
* @param {Object} message
*/