feat(experimental): FunctionsAgent, uses new function payload for tooling
Parent: 550e566097
Commit: 3caddd6854
8 changed files with 227 additions and 52 deletions
@@ -6,11 +6,11 @@ const {
 } = require('@dqbd/tiktoken');
 const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
 const { Agent, ProxyAgent } = require('undici');
-const TextStream = require('../stream');
+// const TextStream = require('../stream');
 const { ChatOpenAI } = require('langchain/chat_models/openai');
 const { CallbackManager } = require('langchain/callbacks');
 const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
-const { initializeCustomAgent } = require('./agents/CustomAgent/initializeCustomAgent');
+const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/');
 const { getMessages, saveMessage, saveConvo } = require('../../models');
 const { loadTools, SelfReflectionTool } = require('./tools');
 const {
@@ -106,10 +106,10 @@ class ChatAgent {
     const preliminaryAnswer =
       result.output?.length > 0 ? `Preliminary Answer: "${result.output.trim()}"` : '';
     const prefix = preliminaryAnswer
-      ? `review and improve the answer you generated using plugins in response to the User Message below. The answer hasn't been sent to the user yet.`
+      ? `review and improve the answer you generated using plugins in response to the User Message below. The user hasn't seen your answer or thoughts yet.`
       : 'respond to the User Message below based on your preliminary thoughts & actions.';

-    return `As ChatGPT, ${prefix}${errorMessage}\n${internalActions}
+    return `As a helpful AI Assistant, ${prefix}${errorMessage}\n${internalActions}
 ${preliminaryAnswer}
 Reply conversationally to the User based on your ${
   preliminaryAnswer ? 'preliminary answer, ' : ''
@@ -145,8 +145,7 @@ Only respond with your conversational reply to the following User Message:
       this.options = options;
     }

-    this.agentOptions = this.options.agentOptions || {};
-    this.agentIsGpt3 = this.agentOptions.model.startsWith('gpt-3');
+
     const modelOptions = this.options.modelOptions || {};
     this.modelOptions = {
       ...modelOptions,
@@ -160,10 +159,27 @@ Only respond with your conversational reply to the following User Message:
       stop: modelOptions.stop
     };

+    this.agentOptions = this.options.agentOptions || {};
+    this.functionsAgent = this.agentOptions.agent === 'functions';
+    this.agentIsGpt3 = this.agentOptions.model.startsWith('gpt-3');
+    if (this.functionsAgent) {
+      this.agentOptions.model = this.getFunctionModelName(this.agentOptions.model);
+    }
+
     this.isChatGptModel = this.modelOptions.model.startsWith('gpt-');
     this.isGpt3 = this.modelOptions.model.startsWith('gpt-3');
-    this.maxContextTokens = this.modelOptions.model === 'gpt-4-32k' ? 32767 : this.modelOptions.model.startsWith('gpt-4') ? 8191 : 4095;
+
+    const maxTokensMap = {
+      'gpt-4': 8191,
+      'gpt-4-0613': 8191,
+      'gpt-4-32k': 32767,
+      'gpt-4-32k-0613': 32767,
+      'gpt-3.5-turbo': 4095,
+      'gpt-3.5-turbo-0613': 4095,
+      'gpt-3.5-turbo-0301': 4095,
+      'gpt-3.5-turbo-16k': 15999,
+    };
+
+    this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4095; // 1 less than maximum
     // Reserve 1024 tokens for the response.
     // The max prompt tokens is determined by the max context tokens minus the max response tokens.
     // Earlier messages will be dropped until the prompt is within the limit.
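A minimal sketch of the budget math this hunk sets up, assuming the 1024-token reserve described in the comments; the sample model name and variable names below are illustrative, not part of the commit. Note the new lookup is an exact match, which is why the dated variants ('-0613', '-0301') each need their own entry, unlike the old startsWith chain.

// Sketch only: illustrates the map lookup, fallback, and prompt budget.
const maxTokensMap = {
  'gpt-4': 8191,
  'gpt-4-32k': 32767,
  'gpt-3.5-turbo': 4095,
  'gpt-3.5-turbo-16k': 15999
};

const model = 'gpt-3.5-turbo-16k';
const maxContextTokens = maxTokensMap[model] ?? 4095; // unknown models fall back to 4095
const maxResponseTokens = 1024;                       // reserved for the reply, per the comments above
const maxPromptTokens = maxContextTokens - maxResponseTokens;
console.log(maxPromptTokens); // 14975 — earlier messages are dropped past this budget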
@@ -180,7 +196,7 @@ Only respond with your conversational reply to the following User Message:
     }

     this.userLabel = this.options.userLabel || 'User';
-    this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
+    this.chatGptLabel = this.options.chatGptLabel || 'Assistant';

     // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
     // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
@@ -388,6 +404,26 @@ Only respond with your conversational reply to the following User Message:
     this.actions.push(action);
   }

+  getFunctionModelName(input) {
+    const prefixMap = {
+      'gpt-4': 'gpt-4-0613',
+      'gpt-4-32k': 'gpt-4-32k-0613',
+      'gpt-3.5-turbo': 'gpt-3.5-turbo-0613'
+    };
+
+    const prefix = Object.keys(prefixMap).find(key => input.startsWith(key));
+    return prefix ? prefixMap[prefix] : 'gpt-3.5-turbo-0613';
+  }
+
+  createLLM(modelOptions, configOptions) {
+    let credentials = { openAIApiKey: this.openAIApiKey };
+    if (this.azure) {
+      credentials = { ...this.azure };
+    }
+
+    return new ChatOpenAI({ credentials, ...modelOptions }, configOptions);
+  }
+
   async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
     const modelOptions = {
       modelName: this.agentOptions.model,
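A standalone copy of the new helper, to show what it returns for a few inputs; the calls below are illustrative. One subtlety: Object.keys iterates in insertion order and find stops at the first prefix hit, so 'gpt-4-32k' matches the 'gpt-4' key before its own entry; putting longer keys first would make the 32k mapping take effect.

// Sketch only: expected results of getFunctionModelName as written in the hunk above.
function getFunctionModelName(input) {
  const prefixMap = {
    'gpt-4': 'gpt-4-0613',
    'gpt-4-32k': 'gpt-4-32k-0613',
    'gpt-3.5-turbo': 'gpt-3.5-turbo-0613'
  };
  const prefix = Object.keys(prefixMap).find(key => input.startsWith(key));
  return prefix ? prefixMap[prefix] : 'gpt-3.5-turbo-0613';
}

console.log(getFunctionModelName('gpt-4'));             // 'gpt-4-0613'
console.log(getFunctionModelName('gpt-3.5-turbo-16k')); // 'gpt-3.5-turbo-0613' (prefix match)
console.log(getFunctionModelName('text-davinci-003'));  // 'gpt-3.5-turbo-0613' (fallback)
console.log(getFunctionModelName('gpt-4-32k'));         // 'gpt-4-0613' — 'gpt-4' wins first
                                                        // in key order, so the 32k entry never wins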
@@ -400,21 +436,7 @@ Only respond with your conversational reply to the following User Message:
       configOptions.basePath = this.langchainProxy;
     }

-    const model = this.azure
-      ? new ChatOpenAI({
-        ...this.azure,
-        ...modelOptions
-      })
-      : new ChatOpenAI(
-        {
-          openAIApiKey: this.openAIApiKey,
-          ...modelOptions
-        },
-        configOptions
-        // {
-        //   basePath: 'http://localhost:8080/v1'
-        // }
-      );
+    const model = this.createLLM(modelOptions, configOptions);

     if (this.options.debug) {
       console.debug(`<-----Agent Model: ${model.modelName} | Temp: ${model.temperature}----->`);
@@ -466,7 +488,8 @@ Only respond with your conversational reply to the following User Message:
     };

     // initialize agent
-    this.executor = await initializeCustomAgent({
+    const initializer = this.options.agentOptions?.agent === 'functions' ? initializeFunctionsAgent : initializeCustomAgent;
+    this.executor = await initializer({
       model,
       signal,
       tools: this.tools,
@@ -594,7 +617,7 @@ Only respond with your conversational reply to the following User Message:
     console.log('sendMessage', message, opts);

     const user = opts.user || null;
-    const { onAgentAction, onChainEnd, onProgress } = opts;
+    const { onAgentAction, onChainEnd } = opts;
     const conversationId = opts.conversationId || crypto.randomUUID();
     const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
     const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
@@ -658,13 +681,13 @@ Only respond with your conversational reply to the following User Message:
       return { ...responseMessage, ...this.result };
     }

-    if (!this.agentIsGpt3 && this.result.output) {
-      responseMessage.text = this.result.output;
-      await this.saveMessageToDatabase(responseMessage, user);
-      const textStream = new TextStream(this.result.output);
-      await textStream.processTextStream(onProgress);
-      return { ...responseMessage, ...this.result };
-    }
+    // if (!this.agentIsGpt3 && this.result.output) {
+    //   responseMessage.text = this.result.output;
+    //   await this.saveMessageToDatabase(responseMessage, user);
+    //   const textStream = new TextStream(this.result.output);
+    //   await textStream.processTextStream(opts.onProgress);
+    //   return { ...responseMessage, ...this.result };
+    // }

     if (this.options.debug) {
       console.debug('this.result', this.result);
@@ -871,7 +894,7 @@ Only respond with your conversational reply to the following User Message:
     return orderedMessages.map((msg) => ({
       messageId: msg.messageId,
       parentMessageId: msg.parentMessageId,
-      role: msg.isCreatedByUser ? 'User' : 'ChatGPT',
+      role: msg.isCreatedByUser ? 'User' : 'Assistant',
       text: msg.text
     }));
   }
@@ -51,6 +51,4 @@ Query: {input}
   return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest });
 };

-module.exports = {
-  initializeCustomAgent
-};
+module.exports = initializeCustomAgent;
122  api/app/langchain/agents/Functions/FunctionsAgent.js  (new file)
@@ -0,0 +1,122 @@
+const { Agent } = require('langchain/agents');
+const { LLMChain } = require('langchain/chains');
+const { FunctionChatMessage, AIChatMessage } = require('langchain/schema');
+const {
+  ChatPromptTemplate,
+  MessagesPlaceholder,
+  SystemMessagePromptTemplate,
+  HumanMessagePromptTemplate
+} = require('langchain/prompts');
+const PREFIX = `You are a helpful AI assistant. Objective: Understand the human's query with available functions.
+The user is expecting a function response to the query; if only part of the query involves a function, prioritize the function response.`;
+
+function parseOutput(message) {
+  if (message.additional_kwargs.function_call) {
+    const function_call = message.additional_kwargs.function_call;
+    return {
+      tool: function_call.name,
+      toolInput: function_call.arguments ? JSON.parse(function_call.arguments) : {},
+      log: message.text
+    };
+  } else {
+    return { returnValues: { output: message.text }, log: message.text };
+  }
+}
+
+class FunctionsAgent extends Agent {
+  constructor(input) {
+    super({ ...input, outputParser: undefined });
+    this.tools = input.tools;
+  }
+
+  lc_namespace = ['langchain', 'agents', 'openai'];
+
+  _agentType() {
+    return 'openai-functions';
+  }
+
+  observationPrefix() {
+    return 'Observation: ';
+  }
+
+  llmPrefix() {
+    return 'Thought:';
+  }
+
+  _stop() {
+    return ['Observation:'];
+  }
+
+  static createPrompt(_tools, fields) {
+    const { prefix = PREFIX, currentDateString } = fields || {};
+
+    return ChatPromptTemplate.fromPromptMessages([
+      SystemMessagePromptTemplate.fromTemplate(`Date: ${currentDateString}\n${prefix}`),
+      HumanMessagePromptTemplate.fromTemplate(`{chat_history}
+Query: {input}
+{agent_scratchpad}`),
+      new MessagesPlaceholder('agent_scratchpad')
+    ]);
+  }
+
+  static fromLLMAndTools(llm, tools, args) {
+    FunctionsAgent.validateTools(tools);
+    const prompt = FunctionsAgent.createPrompt(tools, args);
+    const chain = new LLMChain({
+      prompt,
+      llm,
+      callbacks: args?.callbacks
+    });
+    return new FunctionsAgent({
+      llmChain: chain,
+      allowedTools: tools.map((t) => t.name),
+      tools
+    });
+  }
+
+  async constructScratchPad(steps) {
+    return steps.flatMap(({ action, observation }) => [
+      new AIChatMessage('', {
+        function_call: {
+          name: action.tool,
+          arguments: JSON.stringify(action.toolInput)
+        }
+      }),
+      new FunctionChatMessage(observation, action.tool)
+    ]);
+  }
+
+  async plan(steps, inputs, callbackManager) {
+    // Add scratchpad and stop to inputs
+    var thoughts = await this.constructScratchPad(steps);
+    var newInputs = Object.assign({}, inputs, { agent_scratchpad: thoughts });
+    if (this._stop().length !== 0) {
+      newInputs.stop = this._stop();
+    }
+
+    // Split inputs between prompt and llm
+    var llm = this.llmChain.llm;
+    var valuesForPrompt = Object.assign({}, newInputs);
+    var valuesForLLM = {
+      tools: this.tools
+    };
+    for (var i = 0; i < this.llmChain.llm.callKeys.length; i++) {
+      var key = this.llmChain.llm.callKeys[i];
+      if (key in inputs) {
+        valuesForLLM[key] = inputs[key];
+        delete valuesForPrompt[key];
+      }
+    }
+
+    var promptValue = await this.llmChain.prompt.formatPromptValue(valuesForPrompt);
+    var message = await llm.predictMessages(
+      promptValue.toChatMessages(),
+      valuesForLLM,
+      callbackManager
+    );
+    console.log('message', message);
+    return parseOutput(message);
+  }
+}
+
+module.exports = FunctionsAgent;
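How the parsing side of this file behaves, sketched with hand-written message objects (shapes modeled on langchain's chat messages, not captured library output):

// Sketch only: standalone copy of parseOutput from the file above.
function parseOutput(message) {
  if (message.additional_kwargs.function_call) {
    const { name, arguments: args } = message.additional_kwargs.function_call;
    return { tool: name, toolInput: args ? JSON.parse(args) : {}, log: message.text };
  }
  return { returnValues: { output: message.text }, log: message.text };
}

// A function_call in additional_kwargs becomes the next tool step...
console.log(parseOutput({
  text: '',
  additional_kwargs: { function_call: { name: 'calculator', arguments: '{"input":"2+2"}' } }
}));
// → { tool: 'calculator', toolInput: { input: '2+2' }, log: '' }

// ...and a plain message ends the loop as the final answer.
console.log(parseOutput({ text: 'The answer is 4.', additional_kwargs: {} }));
// → { returnValues: { output: 'The answer is 4.' }, log: 'The answer is 4.' }

On the next turn, constructScratchPad replays each executed step as an assistant message carrying the function_call followed by a FunctionChatMessage with the observation, so the model sees its own earlier calls before planning again.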
33  api/app/langchain/agents/Functions/initializeFunctionsAgent.js  (new file)

@@ -0,0 +1,33 @@
+const FunctionsAgent = require('./FunctionsAgent');
+const { AgentExecutor } = require('langchain/agents');
+const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
+
+const initializeFunctionsAgent = async ({
+  tools,
+  model,
+  pastMessages,
+  currentDateString,
+  ...rest
+}) => {
+  const agent = FunctionsAgent.fromLLMAndTools(
+    model,
+    tools,
+    {
+      currentDateString,
+    });
+
+  const memory = new BufferMemory({
+    chatHistory: new ChatMessageHistory(pastMessages),
+    // returnMessages: true, // commenting this out retains memory
+    memoryKey: 'chat_history',
+    humanPrefix: 'User',
+    aiPrefix: 'Assistant',
+    inputKey: 'input',
+    outputKey: 'output'
+  });
+
+  return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest });
+};
+
+module.exports = initializeFunctionsAgent;
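A sketch of how ChatAgent is expected to call this initializer; every value below is an assumption (the option names mirror the destructured parameters above, and extra keys ride along via ...rest to AgentExecutor):

// Sketch only: hypothetical wiring, assumed to run inside an async function
// such as ChatAgent.initialize; model/tools/pastMessages come from the caller.
const executor = await initializeFunctionsAgent({
  model,         // a ChatOpenAI instance, e.g. from ChatAgent's createLLM()
  tools,         // langchain Tool instances, e.g. loaded via loadTools()
  pastMessages,  // prior HumanChatMessage/AIChatMessage history
  currentDateString: new Date().toLocaleDateString(),
  verbose: true  // part of ...rest, forwarded to AgentExecutor.fromAgentAndTools
});

const result = await executor.call({ input: 'What is 2 + 2?' });
console.log(result.output);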
7  api/app/langchain/agents/index.js  (new file)
@@ -0,0 +1,7 @@
+const initializeCustomAgent = require('./CustomAgent/initializeCustomAgent');
+const initializeFunctionsAgent = require('./Functions/initializeFunctionsAgent');
+
+module.exports = {
+  initializeCustomAgent,
+  initializeFunctionsAgent
+};
@@ -46,7 +46,7 @@
     "jsonwebtoken": "^9.0.0",
     "keyv": "^4.5.2",
     "keyv-file": "^0.2.0",
-    "langchain": "^0.0.92",
+    "langchain": "^0.0.94",
     "lodash": "^4.17.21",
     "meilisearch": "^0.33.0",
     "mongoose": "^7.1.1",
@@ -39,8 +39,8 @@ router.post('/', requireJwtAuth, async (req, res) => {
   if (endpoint !== 'gptPlugins') return handleError(res, { text: 'Illegal request' });

   const agentOptions = req.body?.agentOptions ?? {
+    agent: 'classic',
     model: 'gpt-3.5-turbo',
     // model: 'gpt-4', // for agent model
     temperature: 0,
     // top_p: 1,
     // presence_penalty: 0,
@@ -60,20 +60,12 @@ router.post('/', requireJwtAuth, async (req, res) => {
       presence_penalty: req.body?.presence_penalty ?? 0,
       frequency_penalty: req.body?.frequency_penalty ?? 0
     },
-    agentOptions
+    agentOptions: {
+      ...agentOptions,
+      // agent: 'functions'
+    }
   };

-  // const availableModels = getOpenAIModels();
-  // if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
-  //   return handleError(res, { text: `Illegal request: model` });
-  // }
-
-  // console.log('ask log', {
-  //   text,
-  //   conversationId,
-  //   endpointOption
-  // });
-
   console.log('ask log');
   console.dir({ text, conversationId, endpointOption }, { depth: null });
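The net effect on the request payload, sketched with the defaults from the hunk above; the empty body is an assumption standing in for a client that sent no agentOptions, and the commented agent: 'functions' line is the server-side switch for forcing the new agent:

// Sketch only: what endpointOption.agentOptions resolves to with the defaults above.
const body = {}; // hypothetical request body with no agentOptions
const agentOptions = body?.agentOptions ?? {
  agent: 'classic',
  model: 'gpt-3.5-turbo',
  temperature: 0
};
const endpointOption = {
  agentOptions: { ...agentOptions /*, agent: 'functions' would force the new agent */ }
};
console.log(endpointOption.agentOptions.agent); // 'classic' unless overridden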
@@ -53,7 +53,7 @@ router.get('/', async function (req, res) {
     ? { availableModels: getOpenAIModels(), userProvide: apiKey === 'user_provided' }
     : false;
   const gptPlugins = apiKey
-    ? { availableModels: getPluginModels(), availableTools }
+    ? { availableModels: getPluginModels(), availableTools, availableAgents: ['classic', 'functions'] }
     : false;
   const bingAI = process.env.BINGAI_TOKEN
     ? { userProvide: process.env.BINGAI_TOKEN == 'user_provided' }
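What the client now sees from this endpoint, sketched under the assumption that an API key and some tools are configured; the model and tool values are placeholders:

// Sketch only: approximate shape of the gptPlugins entry in the response.
const availableTools = ['calculator', 'google']; // placeholder tool keys
const gptPlugins = {
  availableModels: ['gpt-3.5-turbo', 'gpt-4'],   // placeholder model list
  availableTools,
  availableAgents: ['classic', 'functions']      // new: lets the UI expose the agent choice
};
console.log(gptPlugins.availableAgents);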