feat(experimental): FunctionsAgent, uses new function payload for tooling

Daniel Avila authored on 2023-06-13 23:39:22 -04:00 (committed by Danny Avila)
parent 550e566097
commit 3caddd6854
8 changed files with 227 additions and 52 deletions

View file

@@ -6,11 +6,11 @@ const {
 } = require('@dqbd/tiktoken');
 const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
 const { Agent, ProxyAgent } = require('undici');
-const TextStream = require('../stream');
+// const TextStream = require('../stream');
 const { ChatOpenAI } = require('langchain/chat_models/openai');
 const { CallbackManager } = require('langchain/callbacks');
 const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
-const { initializeCustomAgent } = require('./agents/CustomAgent/initializeCustomAgent');
+const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/');
 const { getMessages, saveMessage, saveConvo } = require('../../models');
 const { loadTools, SelfReflectionTool } = require('./tools');
 const {
@@ -106,10 +106,10 @@ class ChatAgent {
     const preliminaryAnswer =
       result.output?.length > 0 ? `Preliminary Answer: "${result.output.trim()}"` : '';
     const prefix = preliminaryAnswer
-      ? `review and improve the answer you generated using plugins in response to the User Message below. The answer hasn't been sent to the user yet.`
+      ? `review and improve the answer you generated using plugins in response to the User Message below. The user hasn't seen your answer or thoughts yet.`
       : 'respond to the User Message below based on your preliminary thoughts & actions.';

-    return `As ChatGPT, ${prefix}${errorMessage}\n${internalActions}
+    return `As a helpful AI Assistant, ${prefix}${errorMessage}\n${internalActions}
 ${preliminaryAnswer}
 Reply conversationally to the User based on your ${
   preliminaryAnswer ? 'preliminary answer, ' : ''
@@ -145,8 +145,7 @@ Only respond with your conversational reply to the following User Message:
       this.options = options;
     }

-    this.agentOptions = this.options.agentOptions || {};
-    this.agentIsGpt3 = this.agentOptions.model.startsWith('gpt-3');
     const modelOptions = this.options.modelOptions || {};
     this.modelOptions = {
       ...modelOptions,
@@ -160,10 +159,27 @@ Only respond with your conversational reply to the following User Message:
       stop: modelOptions.stop
     };

+    this.agentOptions = this.options.agentOptions || {};
+    this.functionsAgent = this.agentOptions.agent === 'functions';
+    this.agentIsGpt3 = this.agentOptions.model.startsWith('gpt-3');
+    if (this.functionsAgent) {
+      this.agentOptions.model = this.getFunctionModelName(this.agentOptions.model);
+    }
+
     this.isChatGptModel = this.modelOptions.model.startsWith('gpt-');
     this.isGpt3 = this.modelOptions.model.startsWith('gpt-3');
-    this.maxContextTokens = this.modelOptions.model === 'gpt-4-32k' ? 32767 : this.modelOptions.model.startsWith('gpt-4') ? 8191 : 4095;
+    const maxTokensMap = {
+      'gpt-4': 8191,
+      'gpt-4-0613': 8191,
+      'gpt-4-32k': 32767,
+      'gpt-4-32k-0613': 32767,
+      'gpt-3.5-turbo': 4095,
+      'gpt-3.5-turbo-0613': 4095,
+      'gpt-3.5-turbo-0301': 4095,
+      'gpt-3.5-turbo-16k': 15999,
+    };
+    this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4095; // 1 less than maximum
     // Reserve 1024 tokens for the response.
     // The max prompt tokens is determined by the max context tokens minus the max response tokens.
     // Earlier messages will be dropped until the prompt is within the limit.
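
For reference, the budgeting these comments describe reduces to simple arithmetic. A minimal sketch follows; the getTokenBudget name and the 1024 default are assumptions drawn from the comments, not code from this commit:

// Sketch only: mirrors the budgeting described in the comments above.
const maxTokensMap = {
  'gpt-4': 8191,
  'gpt-4-32k': 32767,
  'gpt-3.5-turbo': 4095,
  'gpt-3.5-turbo-16k': 15999
};

function getTokenBudget(model, maxResponseTokens = 1024) {
  const maxContextTokens = maxTokensMap[model] ?? 4095; // 1 less than the model maximum
  // Prompt space is whatever the context window leaves after the reserved response.
  const maxPromptTokens = maxContextTokens - maxResponseTokens;
  return { maxContextTokens, maxResponseTokens, maxPromptTokens };
}

// e.g. getTokenBudget('gpt-3.5-turbo')
// -> { maxContextTokens: 4095, maxResponseTokens: 1024, maxPromptTokens: 3071 }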
@@ -180,7 +196,7 @@ Only respond with your conversational reply to the following User Message:
     }
     this.userLabel = this.options.userLabel || 'User';
-    this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
+    this.chatGptLabel = this.options.chatGptLabel || 'Assistant';

     // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
     // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
@@ -388,6 +404,26 @@ Only respond with your conversational reply to the following User Message:
     this.actions.push(action);
   }

+  getFunctionModelName(input) {
+    const prefixMap = {
+      'gpt-4': 'gpt-4-0613',
+      'gpt-4-32k': 'gpt-4-32k-0613',
+      'gpt-3.5-turbo': 'gpt-3.5-turbo-0613'
+    };
+
+    const prefix = Object.keys(prefixMap).find(key => input.startsWith(key));
+    return prefix ? prefixMap[prefix] : 'gpt-3.5-turbo-0613';
+  }
+
+  createLLM(modelOptions, configOptions) {
+    let credentials = { openAIApiKey: this.openAIApiKey };
+    if (this.azure) {
+      credentials = { ...this.azure };
+    }
+
+    return new ChatOpenAI({ credentials, ...modelOptions }, configOptions);
+  }
+
   async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
     const modelOptions = {
       modelName: this.agentOptions.model,
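
A quick sketch of how getFunctionModelName resolves names, given the prefixMap above (illustrative calls, not part of the commit). Since Object.keys preserves insertion order and find returns the first match, inputs beginning with 'gpt-4-32k' match the 'gpt-4' key first:

// Sketch: expected lookups given the insertion order of prefixMap above.
getFunctionModelName('gpt-3.5-turbo');      // 'gpt-3.5-turbo-0613'
getFunctionModelName('gpt-4');              // 'gpt-4-0613'
getFunctionModelName('text-davinci-003');   // no matching prefix -> 'gpt-3.5-turbo-0613'
// 'gpt-4-32k' also starts with 'gpt-4', so it resolves to 'gpt-4-0613'
// rather than 'gpt-4-32k-0613':
getFunctionModelName('gpt-4-32k');          // 'gpt-4-0613'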
@@ -400,21 +436,7 @@ Only respond with your conversational reply to the following User Message:
       configOptions.basePath = this.langchainProxy;
     }

-    const model = this.azure
-      ? new ChatOpenAI({
-        ...this.azure,
-        ...modelOptions
-      })
-      : new ChatOpenAI(
-        {
-          openAIApiKey: this.openAIApiKey,
-          ...modelOptions
-        },
-        configOptions
-        // {
-        //   basePath: 'http://localhost:8080/v1'
-        // }
-      );
+    const model = this.createLLM(modelOptions, configOptions);

     if (this.options.debug) {
       console.debug(`<-----Agent Model: ${model.modelName} | Temp: ${model.temperature}----->`);
@@ -466,7 +488,8 @@ Only respond with your conversational reply to the following User Message:
     };

     // initialize agent
-    this.executor = await initializeCustomAgent({
+    const initializer = this.options.agentOptions?.agent === 'functions' ? initializeFunctionsAgent : initializeCustomAgent;
+    this.executor = await initializer({
       model,
       signal,
       tools: this.tools,
@@ -594,7 +617,7 @@ Only respond with your conversational reply to the following User Message:
     console.log('sendMessage', message, opts);

     const user = opts.user || null;
-    const { onAgentAction, onChainEnd, onProgress } = opts;
+    const { onAgentAction, onChainEnd } = opts;
     const conversationId = opts.conversationId || crypto.randomUUID();
     const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
     const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
@@ -658,13 +681,13 @@ Only respond with your conversational reply to the following User Message:
       return { ...responseMessage, ...this.result };
     }

-    if (!this.agentIsGpt3 && this.result.output) {
-      responseMessage.text = this.result.output;
-      await this.saveMessageToDatabase(responseMessage, user);
-      const textStream = new TextStream(this.result.output);
-      await textStream.processTextStream(onProgress);
-      return { ...responseMessage, ...this.result };
-    }
+    // if (!this.agentIsGpt3 && this.result.output) {
+    //   responseMessage.text = this.result.output;
+    //   await this.saveMessageToDatabase(responseMessage, user);
+    //   const textStream = new TextStream(this.result.output);
+    //   await textStream.processTextStream(opts.onProgress);
+    //   return { ...responseMessage, ...this.result };
+    // }

     if (this.options.debug) {
       console.debug('this.result', this.result);
@@ -871,7 +894,7 @@ Only respond with your conversational reply to the following User Message:
     return orderedMessages.map((msg) => ({
       messageId: msg.messageId,
       parentMessageId: msg.parentMessageId,
-      role: msg.isCreatedByUser ? 'User' : 'ChatGPT',
+      role: msg.isCreatedByUser ? 'User' : 'Assistant',
       text: msg.text
     }));
   }

View file

@@ -51,6 +51,4 @@ Query: {input}
   return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest });
 };

-module.exports = {
-  initializeCustomAgent
-};
+module.exports = initializeCustomAgent;

View file

@@ -0,0 +1,122 @@
+const { Agent } = require('langchain/agents');
+const { LLMChain } = require('langchain/chains');
+const { FunctionChatMessage, AIChatMessage } = require('langchain/schema');
+const {
+  ChatPromptTemplate,
+  MessagesPlaceholder,
+  SystemMessagePromptTemplate,
+  HumanMessagePromptTemplate
+} = require('langchain/prompts');
+
+const PREFIX = `You are a helpful AI assistant. Objective: Understand the human's query with available functions.
+The user is expecting a function response to the query; if only part of the query involves a function, prioritize the function response.`;
+
+function parseOutput(message) {
+  if (message.additional_kwargs.function_call) {
+    const function_call = message.additional_kwargs.function_call;
+    return {
+      tool: function_call.name,
+      toolInput: function_call.arguments ? JSON.parse(function_call.arguments) : {},
+      log: message.text
+    };
+  } else {
+    return { returnValues: { output: message.text }, log: message.text };
+  }
+}
+
+class FunctionsAgent extends Agent {
+  constructor(input) {
+    super({ ...input, outputParser: undefined });
+    this.tools = input.tools;
+  }
+
+  lc_namespace = ['langchain', 'agents', 'openai'];
+
+  _agentType() {
+    return 'openai-functions';
+  }
+
+  observationPrefix() {
+    return 'Observation: ';
+  }
+
+  llmPrefix() {
+    return 'Thought:';
+  }
+
+  _stop() {
+    return ['Observation:'];
+  }
+
+  static createPrompt(_tools, fields) {
+    const { prefix = PREFIX, currentDateString } = fields || {};
+
+    return ChatPromptTemplate.fromPromptMessages([
+      SystemMessagePromptTemplate.fromTemplate(`Date: ${currentDateString}\n${prefix}`),
+      HumanMessagePromptTemplate.fromTemplate(`{chat_history}
+Query: {input}
+{agent_scratchpad}`),
+      new MessagesPlaceholder('agent_scratchpad')
+    ]);
+  }
+
+  static fromLLMAndTools(llm, tools, args) {
+    FunctionsAgent.validateTools(tools);
+    const prompt = FunctionsAgent.createPrompt(tools, args);
+    const chain = new LLMChain({
+      prompt,
+      llm,
+      callbacks: args?.callbacks
+    });
+    return new FunctionsAgent({
+      llmChain: chain,
+      allowedTools: tools.map((t) => t.name),
+      tools
+    });
+  }
+
+  async constructScratchPad(steps) {
+    return steps.flatMap(({ action, observation }) => [
+      new AIChatMessage('', {
+        function_call: {
+          name: action.tool,
+          arguments: JSON.stringify(action.toolInput)
+        }
+      }),
+      new FunctionChatMessage(observation, action.tool)
+    ]);
+  }
+
+  async plan(steps, inputs, callbackManager) {
+    // Add scratchpad and stop to inputs
+    var thoughts = await this.constructScratchPad(steps);
+    var newInputs = Object.assign({}, inputs, { agent_scratchpad: thoughts });
+    if (this._stop().length !== 0) {
+      newInputs.stop = this._stop();
+    }
+
+    // Split inputs between prompt and llm
+    var llm = this.llmChain.llm;
+    var valuesForPrompt = Object.assign({}, newInputs);
+    var valuesForLLM = {
+      tools: this.tools
+    };
+    for (var i = 0; i < this.llmChain.llm.callKeys.length; i++) {
+      var key = this.llmChain.llm.callKeys[i];
+      if (key in inputs) {
+        valuesForLLM[key] = inputs[key];
+        delete valuesForPrompt[key];
+      }
+    }
+
+    var promptValue = await this.llmChain.prompt.formatPromptValue(valuesForPrompt);
+    var message = await llm.predictMessages(
+      promptValue.toChatMessages(),
+      valuesForLLM,
+      callbackManager
+    );
+    console.log('message', message);
+    return parseOutput(message);
+  }
+}
+
+module.exports = FunctionsAgent;
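
To illustrate parseOutput above: a message whose additional_kwargs carry a function_call becomes a tool action, anything else becomes a final answer. Hypothetical message shapes, not from the commit:

// Sketch: message objects shaped like the ones parseOutput expects.
const actionMessage = {
  text: '',
  additional_kwargs: {
    function_call: { name: 'calculator', arguments: '{"input":"2+2"}' }
  }
};
// parseOutput(actionMessage)
// -> { tool: 'calculator', toolInput: { input: '2+2' }, log: '' }

const finishMessage = { text: 'The answer is 4.', additional_kwargs: {} };
// parseOutput(finishMessage)
// -> { returnValues: { output: 'The answer is 4.' }, log: 'The answer is 4.' }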

View file

@@ -0,0 +1,33 @@
+const FunctionsAgent = require('./FunctionsAgent');
+const { AgentExecutor } = require('langchain/agents');
+const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
+
+const initializeFunctionsAgent = async ({
+  tools,
+  model,
+  pastMessages,
+  currentDateString,
+  ...rest
+}) => {
+  const agent = FunctionsAgent.fromLLMAndTools(
+    model,
+    tools,
+    {
+      currentDateString,
+    });
+
+  const memory = new BufferMemory({
+    chatHistory: new ChatMessageHistory(pastMessages),
+    // returnMessages: true, // commenting this out retains memory
+    memoryKey: 'chat_history',
+    humanPrefix: 'User',
+    aiPrefix: 'Assistant',
+    inputKey: 'input',
+    outputKey: 'output'
+  });
+
+  return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest });
+};
+
+module.exports = initializeFunctionsAgent;
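
A minimal usage sketch for this initializer, assuming an async context; the model settings, tool list, and messages are placeholders, not values from the commit:

// Sketch only: illustrative inputs for initializeFunctionsAgent.
const { ChatOpenAI } = require('langchain/chat_models/openai');
const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
const initializeFunctionsAgent = require('./initializeFunctionsAgent');

const executor = await initializeFunctionsAgent({
  model: new ChatOpenAI({ openAIApiKey: process.env.OPENAI_API_KEY, modelName: 'gpt-3.5-turbo-0613' }),
  tools: [], // normally the plugin tools loaded elsewhere (e.g. by loadTools)
  pastMessages: [new HumanChatMessage('hi'), new AIChatMessage('Hello! How can I help?')],
  currentDateString: new Date().toLocaleDateString(),
  verbose: true // forwarded to AgentExecutor.fromAgentAndTools via ...rest
});
const result = await executor.call({ input: 'What is 2 + 2?' });
console.log(result.output);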

View file

@@ -0,0 +1,7 @@
+const initializeCustomAgent = require('./CustomAgent/initializeCustomAgent');
+const initializeFunctionsAgent = require('./Functions/initializeFunctionsAgent');
+
+module.exports = {
+  initializeCustomAgent,
+  initializeFunctionsAgent
+};

View file

@@ -46,7 +46,7 @@
     "jsonwebtoken": "^9.0.0",
     "keyv": "^4.5.2",
     "keyv-file": "^0.2.0",
-    "langchain": "^0.0.92",
+    "langchain": "^0.0.94",
     "lodash": "^4.17.21",
     "meilisearch": "^0.33.0",
     "mongoose": "^7.1.1",

View file

@@ -39,8 +39,8 @@ router.post('/', requireJwtAuth, async (req, res) => {
   if (endpoint !== 'gptPlugins') return handleError(res, { text: 'Illegal request' });

   const agentOptions = req.body?.agentOptions ?? {
+    agent: 'classic',
     model: 'gpt-3.5-turbo',
-    // model: 'gpt-4', // for agent model
     temperature: 0,
     // top_p: 1,
     // presence_penalty: 0,
@@ -60,20 +60,12 @@ router.post('/', requireJwtAuth, async (req, res) => {
       presence_penalty: req.body?.presence_penalty ?? 0,
       frequency_penalty: req.body?.frequency_penalty ?? 0
     },
-    agentOptions
+    agentOptions: {
+      ...agentOptions,
+      // agent: 'functions'
+    }
   };

-  // const availableModels = getOpenAIModels();
-  // if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
-  //   return handleError(res, { text: `Illegal request: model` });
-  // }
-
-  // console.log('ask log', {
-  //   text,
-  //   conversationId,
-  //   endpointOption
-  // });
+  console.log('ask log');
+  console.dir({ text, conversationId, endpointOption }, { depth: null });
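
With this change a client can opt into the functions agent per request; a hypothetical request body for this ask route, with illustrative field values:

// Sketch: 'functions' selects the new agent, 'classic' (the default above)
// keeps the previous custom agent.
const body = {
  text: 'What is the weather in Boston?',
  endpoint: 'gptPlugins',
  agentOptions: {
    agent: 'functions',
    model: 'gpt-3.5-turbo',
    temperature: 0
  }
};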

View file

@@ -53,7 +53,7 @@ router.get('/', async function (req, res) {
       ? { availableModels: getOpenAIModels(), userProvide: apiKey === 'user_provided' }
       : false;
   const gptPlugins = apiKey
-    ? { availableModels: getPluginModels(), availableTools }
+    ? { availableModels: getPluginModels(), availableTools, availableAgents: ['classic', 'functions'] }
     : false;
   const bingAI = process.env.BINGAI_TOKEN
     ? { userProvide: process.env.BINGAI_TOKEN == 'user_provided' }
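
Clients can now discover the supported agent types from this config route. A sketch of the gptPlugins fragment of the response, assuming an OpenAI key is set; the model and tool values are illustrative:

{
  "gptPlugins": {
    "availableModels": ["gpt-3.5-turbo", "gpt-4"],
    "availableTools": [],
    "availableAgents": ["classic", "functions"]
  }
}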