feat: ConversationSummaryBufferMemory (#973)

* refactor: pass model in message edit payload, use encoder in standalone util function

* feat: add summaryBuffer helper

* refactor(api/messages): use new countTokens helper and add auth middleware at top

* wip: ConversationSummaryBufferMemory

* refactor: move pre-generation helpers to prompts dir

* chore: remove console log

* chore: remove test as payload will no longer carry tokenCount

* chore: update getMessagesWithinTokenLimit JSDoc

* refactor: optimize getMessagesForConversation and also break on summary, feat(ci): getMessagesForConversation tests

* refactor(getMessagesForConvo): count '00000000-0000-0000-0000-000000000000' as root message

* chore: add newer model to token map

* fix: condition was point to prop of array instead of message prop

* refactor(BaseClient): use object for refineMessages param, rename 'summary' to 'summaryMessage', add previous_summary
refactor(getMessagesWithinTokenLimit): replace text and tokenCount if should summarize, summary, and summaryTokenCount are present
fix/refactor(handleContextStrategy): use the right comparison length for context diff, and replace payload first message when a summary is present

* chore: log previous_summary if debugging

* refactor(formatMessage): assume if role is defined that it's a valid value

* refactor(getMessagesWithinTokenLimit): remove summary logic
refactor(handleContextStrategy): add usePrevSummary logic in case only summary was pruned
refactor(loadHistory): initial message query will return all ordered messages but keep track of the latest summary
refactor(getMessagesForConversation): use object for single param, edit jsdoc, edit all files using the method
refactor(ChatGPTClient): order messages before buildPrompt is called, TODO: add convoSumBuffMemory logic

* fix: undefined handling and summarizing only when shouldRefineContext is true

* chore(BaseClient): fix test results omitting system role for summaries and test edge case

* chore: export summaryBuffer from index file

* refactor(OpenAIClient/BaseClient): move refineMessages to subclass, implement LLM initialization for summaryBuffer

* feat: add OPENAI_SUMMARIZE to enable summarizing, refactor: rename client prop 'shouldRefineContext' to 'shouldSummarize', change contextStrategy value to 'summarize' from 'refine'

* refactor: rename refineMessages method to summarizeMessages for clarity

* chore: clarify summary future intent in .env.example

* refactor(initializeLLM): handle case for either 'model' or 'modelName' being passed

* feat(gptPlugins): enable summarization for plugins

* refactor(gptPlugins): utilize new initializeLLM method and formatting methods for messages, use payload array for currentMessages and assign pastMessages sooner

* refactor(agents): use ConversationSummaryBufferMemory for both agent types

* refactor(formatMessage): optimize original method for langchain, add helper function for langchain messages, add JSDocs and tests

* refactor(summaryBuffer): add helper to createSummaryBufferMemory, and use new formatting helpers

* fix: forgot to spread formatMessages also took opportunity to pluralize filename

* refactor: pass memory to tools, namely openapi specs. not used and may never be used by new method but added for testing

* ci(formatMessages): add more exhaustive checks for langchain messages

* feat: add debug env var for OpenAI

* chore: delete unnecessary comments

* chore: add extra note about summary feature

* fix: remove tokenCount from payload instructions

* fix: test fail

* fix: only pass instructions to payload when defined or not empty object

* refactor: fromPromptMessages is deprecated, use renamed method fromMessages

* refactor: use 'includes' instead of 'startsWith' for extended OpenRouter compatibility

* fix(PluginsClient.buildPromptBody): handle undefined message strings

* chore: log langchain titling error

* feat: getModelMaxTokens helper

* feat: tokenSplit helper

* feat: summary prompts updated

* fix: optimize _CUT_OFF_SUMMARIZER prompt

* refactor(summaryBuffer): use custom summary prompt, allow prompt to be passed, pass humanPrefix and aiPrefix to memory, along with any future variables, rename messagesToRefine to context

* fix(summaryBuffer): handle edge case where messagesToRefine exceeds summary context,
refactor(BaseClient): allow custom maxContextTokens to be passed to getMessagesWithinTokenLimit, add defined check before unshifting summaryMessage, update shouldSummarize based on this
refactor(OpenAIClient): use getModelMaxTokens, use cut-off message method for summary if no messages were left after pruning

* fix(handleContextStrategy): handle case where incoming prompt is bigger than model context

* chore: rename refinedContent to splitText

* chore: remove unnecessary debug log
This commit is contained in:
Danny Avila 2023-09-26 21:02:28 -04:00 committed by GitHub
parent be73deddcc
commit 317a1bd8da
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 1410 additions and 440 deletions

View file

@ -1,11 +1,11 @@
const OpenAIClient = require('./OpenAIClient');
const { CallbackManager } = require('langchain/callbacks');
const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents');
const { addImages, buildErrorInput, buildPromptPrefix } = require('./output_parsers');
// const { createSummaryBufferMemory } = require('./memory');
const { formatLangChainMessages } = require('./prompts');
const { SelfReflectionTool } = require('./tools');
const { loadTools } = require('./tools/util');
const { createLLM } = require('./llm');
class PluginsClient extends OpenAIClient {
constructor(apiKey, options = {}) {
@ -50,9 +50,9 @@ class PluginsClient extends OpenAIClient {
}
getFunctionModelName(input) {
if (input.startsWith('gpt-3.5-turbo')) {
if (input.includes('gpt-3.5-turbo')) {
return 'gpt-3.5-turbo';
} else if (input.startsWith('gpt-4')) {
} else if (input.includes('gpt-4')) {
return 'gpt-4';
} else {
return 'gpt-3.5-turbo';
@ -73,28 +73,7 @@ class PluginsClient extends OpenAIClient {
temperature: this.agentOptions.temperature,
};
const configOptions = {};
if (this.langchainProxy) {
configOptions.basePath = this.langchainProxy;
}
if (this.useOpenRouter) {
configOptions.basePath = 'https://openrouter.ai/api/v1';
configOptions.baseOptions = {
headers: {
'HTTP-Referer': 'https://librechat.ai',
'X-Title': 'LibreChat',
},
};
}
const model = createLLM({
modelOptions,
configOptions,
openAIApiKey: this.openAIApiKey,
azure: this.azure,
});
const model = this.initializeLLM(modelOptions);
if (this.options.debug) {
console.debug(
@ -102,12 +81,22 @@ class PluginsClient extends OpenAIClient {
);
}
// Map Messages to Langchain format
const pastMessages = formatLangChainMessages(this.currentMessages.slice(0, -1), {
userName: this.options?.name,
});
this.options.debug && console.debug('pastMessages: ', pastMessages);
// TODO: implement new token efficient way of processing openAPI plugins so they can "share" memory with agent
// const memory = createSummaryBufferMemory({ llm: this.initializeLLM(modelOptions), messages: pastMessages });
this.tools = await loadTools({
user,
model,
tools: this.options.tools,
functions: this.functionsAgent,
options: {
// memory,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
debug: this.options?.debug,
@ -140,15 +129,6 @@ class PluginsClient extends OpenAIClient {
}
};
// Map Messages to Langchain format
const pastMessages = this.currentMessages
.slice(0, -1)
.map((msg) =>
msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanChatMessage(msg.text)
: new AIChatMessage(msg.text),
);
// initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
this.executor = await initializer({
@ -272,7 +252,6 @@ class PluginsClient extends OpenAIClient {
prompt: payload,
tokenCountMap,
promptTokens,
messages,
} = await this.buildMessages(
this.currentMessages,
userMessage.messageId,
@ -288,17 +267,12 @@ class PluginsClient extends OpenAIClient {
userMessage.tokenCount = tokenCountMap[userMessage.messageId];
console.log('userMessage.tokenCount', userMessage.tokenCount);
}
payload = payload.map((message) => {
const messageWithoutTokenCount = message;
delete messageWithoutTokenCount.tokenCount;
return messageWithoutTokenCount;
});
this.handleTokenCountMap(tokenCountMap);
}
this.result = {};
if (messages) {
this.currentMessages = messages;
if (payload) {
this.currentMessages = payload;
}
await this.saveMessageToDatabase(userMessage, saveOptions, user);
const responseMessage = {
@ -431,7 +405,9 @@ class PluginsClient extends OpenAIClient {
const message = orderedMessages.pop();
const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
let messageString = `${this.startToken}${roleLabel}:\n${message.text}${this.endToken}\n`;
let messageString = `${this.startToken}${roleLabel}:\n${
message.text ?? message.content ?? ''
}${this.endToken}\n`;
let newPromptBody = `${messageString}${promptBody}`;
const tokenCountForMessage = this.getTokenCount(messageString);