hotfix(Plugins): retain message history, improve completion prompt handling (#597)

This commit is contained in:
Danny Avila 2023-07-06 16:45:39 -04:00 committed by GitHub
parent 12e2826d39
commit fabd85ff40
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 50 additions and 59 deletions

View file

@ -19,15 +19,15 @@ class BaseClient {
} }
setOptions() { setOptions() {
throw new Error("Method 'setOptions' must be implemented."); throw new Error('Method \'setOptions\' must be implemented.');
} }
getCompletion() { getCompletion() {
throw new Error("Method 'getCompletion' must be implemented."); throw new Error('Method \'getCompletion\' must be implemented.');
} }
sendCompletion() { async sendCompletion() {
throw new Error("Method 'sendCompletion' must be implemented."); throw new Error('Method \'sendCompletion\' must be implemented.');
} }
getSaveOptions() { getSaveOptions() {
@ -112,11 +112,6 @@ class BaseClient {
opts.onStart(userMessage); opts.onStart(userMessage);
} }
if (this.options.debug) {
console.debug('options');
console.debug(this.options);
}
return { return {
...opts, ...opts,
user, user,
@ -202,17 +197,17 @@ class BaseClient {
const userMessages = this.concatenateMessages(messagesToRefine.filter(m => m.role === 'user')); const userMessages = this.concatenateMessages(messagesToRefine.filter(m => m.role === 'user'));
const assistantMessages = this.concatenateMessages(messagesToRefine.filter(m => m.role !== 'user')); const assistantMessages = this.concatenateMessages(messagesToRefine.filter(m => m.role !== 'user'));
const userDocs = await splitter.createDocuments([userMessages],[],{ const userDocs = await splitter.createDocuments([userMessages],[],{
chunkHeader: `DOCUMENT NAME: User Message\n\n---\n\n`, chunkHeader: 'DOCUMENT NAME: User Message\n\n---\n\n',
appendChunkOverlapHeader: true, appendChunkOverlapHeader: true,
}); });
const assistantDocs = await splitter.createDocuments([assistantMessages],[],{ const assistantDocs = await splitter.createDocuments([assistantMessages],[],{
chunkHeader: `DOCUMENT NAME: Assistant Message\n\n---\n\n`, chunkHeader: 'DOCUMENT NAME: Assistant Message\n\n---\n\n',
appendChunkOverlapHeader: true, appendChunkOverlapHeader: true,
}); });
// const chunkSize = Math.round(concatenatedMessages.length / 512); // const chunkSize = Math.round(concatenatedMessages.length / 512);
const input_documents = userDocs.concat(assistantDocs); const input_documents = userDocs.concat(assistantDocs);
if (this.options.debug ) { if (this.options.debug ) {
console.debug(`Refining messages...`); console.debug('Refining messages...');
} }
try { try {
const res = await chain.call({ const res = await chain.call({
@ -290,13 +285,23 @@ class BaseClient {
await new Promise(resolve => setImmediate(resolve)); await new Promise(resolve => setImmediate(resolve));
} }
return { context: context.reverse(), remainingContextTokens, messagesToRefine: messagesToRefine.reverse(), refineIndex }; return {
context: context.reverse(),
remainingContextTokens,
messagesToRefine: messagesToRefine.reverse(),
refineIndex
};
} }
async handleContextStrategy({instructions, orderedMessages, formattedMessages}) { async handleContextStrategy({instructions, orderedMessages, formattedMessages}) {
let payload = this.addInstructions(formattedMessages, instructions); let payload = this.addInstructions(formattedMessages, instructions);
let orderedWithInstructions = this.addInstructions(orderedMessages, instructions); let orderedWithInstructions = this.addInstructions(orderedMessages, instructions);
let { context, remainingContextTokens, messagesToRefine, refineIndex } = await this.getMessagesWithinTokenLimit(payload); let {
context,
remainingContextTokens,
messagesToRefine,
refineIndex
} = await this.getMessagesWithinTokenLimit(payload);
payload = context; payload = context;
let refinedMessage; let refinedMessage;
@ -380,6 +385,8 @@ class BaseClient {
let { prompt: payload, tokenCountMap, promptTokens } = await this.buildMessages( let { prompt: payload, tokenCountMap, promptTokens } = await this.buildMessages(
this.currentMessages, this.currentMessages,
// When the userMessage is pushed to currentMessages, the parentMessage is the userMessageId.
// this only matters when buildMessages is utilizing the parentMessageId, and may vary on implementation
userMessage.messageId, userMessage.messageId,
this.getBuildMessagesOptions(opts), this.getBuildMessagesOptions(opts),
); );
@ -390,13 +397,16 @@ class BaseClient {
} }
if (tokenCountMap) { if (tokenCountMap) {
payload = payload.map((message, i) => { console.dir(tokenCountMap, { depth: null })
const { tokenCount, ...messageWithoutTokenCount } = message; if (tokenCountMap[userMessage.messageId]) {
// userMessage is always the last one in the payload userMessage.tokenCount = tokenCountMap[userMessage.messageId];
if (i === payload.length - 1) { console.log('userMessage.tokenCount', userMessage.tokenCount);
userMessage.tokenCount = message.tokenCount; console.log('userMessage', userMessage);
console.debug(`Token count for user message: ${tokenCount}`, `Instruction Tokens: ${tokenCountMap.instructions || 'N/A'}`); }
}
payload = payload.map((message) => {
const messageWithoutTokenCount = message;
delete messageWithoutTokenCount.tokenCount;
return messageWithoutTokenCount; return messageWithoutTokenCount;
}); });
this.handleTokenCountMap(tokenCountMap); this.handleTokenCountMap(tokenCountMap);

View file

@ -69,7 +69,7 @@ class OpenAIClient extends BaseClient {
} }
this.userLabel = this.options.userLabel || 'User'; this.userLabel = this.options.userLabel || 'User';
this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT'; this.chatGptLabel = this.options.chatGptLabel || 'Assistant';
this.setupTokens(); this.setupTokens();
this.setupTokenizer(); this.setupTokenizer();

View file

@ -4,7 +4,7 @@ const { CallbackManager } = require('langchain/callbacks');
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/'); const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/');
const { loadTools } = require('./tools/util'); const { loadTools } = require('./tools/util');
const { SelfReflectionTool } = require('./tools/'); const { SelfReflectionTool } = require('./tools/');
const { HumanChatMessage, AIChatMessage } = require('langchain/schema'); const { HumanMessage, AIMessage } = require('langchain/schema');
const { const {
instructions, instructions,
imageInstructions, imageInstructions,
@ -235,17 +235,10 @@ Only respond with your conversational reply to the following User Message:
}; };
// Map Messages to Langchain format // Map Messages to Langchain format
const pastMessages = this.currentMessages.map( const pastMessages = this.currentMessages.slice(0, -1).map(
msg => msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user' msg => msg?.isCreatedByUser || msg?.role?.toLowerCase() === 'user'
? new HumanChatMessage(msg.text) ? new HumanMessage(msg.text)
: new AIChatMessage(msg.text)); : new AIMessage(msg.text));
if (this.options.debug) {
console.debug('Current Messages');
console.debug(this.currentMessages);
console.debug('Past Messages');
console.debug(pastMessages);
}
// initialize agent // initialize agent
const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent; const initializer = this.functionsAgent ? initializeFunctionsAgent : initializeCustomAgent;
@ -351,6 +344,8 @@ Only respond with your conversational reply to the following User Message:
onChainEnd, onChainEnd,
} = await this.handleStartMethods(message, opts); } = await this.handleStartMethods(message, opts);
this.currentMessages.push(userMessage);
let { prompt: payload, tokenCountMap, promptTokens, messages } = await this.buildMessages( let { prompt: payload, tokenCountMap, promptTokens, messages } = await this.buildMessages(
this.currentMessages, this.currentMessages,
userMessage.messageId, userMessage.messageId,
@ -360,19 +355,15 @@ Only respond with your conversational reply to the following User Message:
}), }),
); );
if (this.options.debug) {
console.debug('buildMessages: Messages');
console.debug(messages);
}
if (tokenCountMap) { if (tokenCountMap) {
payload = payload.map((message, i) => { console.dir(tokenCountMap, { depth: null })
const { tokenCount, ...messageWithoutTokenCount } = message; if (tokenCountMap[userMessage.messageId]) {
// userMessage is always the last one in the payload userMessage.tokenCount = tokenCountMap[userMessage.messageId];
if (i === payload.length - 1) { console.log('userMessage.tokenCount', userMessage.tokenCount);
userMessage.tokenCount = message.tokenCount; }
console.debug(`Token count for user message: ${tokenCount}`, `Instruction Tokens: ${tokenCountMap.instructions || 'N/A'}`); payload = payload.map((message) => {
} const messageWithoutTokenCount = message;
delete messageWithoutTokenCount.tokenCount;
return messageWithoutTokenCount; return messageWithoutTokenCount;
}); });
this.handleTokenCountMap(tokenCountMap); this.handleTokenCountMap(tokenCountMap);
@ -482,25 +473,15 @@ Only respond with your conversational reply to the following User Message:
let promptBody = ''; let promptBody = '';
const maxTokenCount = this.maxPromptTokens; const maxTokenCount = this.maxPromptTokens;
// Iterate backwards through the messages, adding them to the prompt until we reach the max token count. // Iterate backwards through the messages, adding them to the prompt until we reach the max token count.
// Do this within a recursive async function so that it doesn't block the event loop for too long. // Do this within a recursive async function so that it doesn't block the event loop for too long.
const buildPromptBody = async () => { const buildPromptBody = async () => {
if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) { if (currentTokenCount < maxTokenCount && orderedMessages.length > 0) {
const message = orderedMessages.pop(); const message = orderedMessages.pop();
// const roleLabel = message.role === 'User' ? this.userLabel : this.chatGptLabel; const isCreatedByUser = message.isCreatedByUser || message.role?.toLowerCase() === 'user';
const roleLabel = message.role; const roleLabel = isCreatedByUser ? this.userLabel : this.chatGptLabel;
let messageString = `${this.startToken}${roleLabel}:\n${message.text}${this.endToken}\n`; let messageString = `${this.startToken}${roleLabel}:\n${message.text}${this.endToken}\n`;
let newPromptBody; let newPromptBody = `${messageString}${promptBody}`;
if (promptBody) {
newPromptBody = `${messageString}${promptBody}`;
} else {
// Always insert prompt prefix before the last user message, if not gpt-3.5-turbo.
// This makes the AI obey the prompt instructions better, which is important for custom instructions.
// After a bunch of testing, it doesn't seem to cause the AI any confusion, even if you ask it things
// like "what's the last thing I wrote?".
newPromptBody = `${promptPrefix}${messageString}${promptBody}`;
}
const tokenCountForMessage = this.getTokenCount(messageString); const tokenCountForMessage = this.getTokenCount(messageString);
const newTokenCount = currentTokenCount + tokenCountForMessage; const newTokenCount = currentTokenCount + tokenCountForMessage;
@ -549,7 +530,7 @@ Only respond with your conversational reply to the following User Message:
const result = [messagePayload, instructionsPayload]; const result = [messagePayload, instructionsPayload];
if (this.functionsAgent && !this.isGpt3) { if (this.functionsAgent && !this.isGpt3) {
result[1].content = `${result[1].content}\nSure thing! Here is the output you requested:\n`; result[1].content = `${result[1].content}\n${this.startToken}${this.chatGptLabel}:\nSure thing! Here is the output you requested:\n`;
} }
return result.filter((message) => message.content.length > 0); return result.filter((message) => message.content.length > 0);