diff --git a/api/app/langchain/ChatAgent.js b/api/app/langchain/ChatAgent.js
index ac2b418fa..1f255425d 100644
--- a/api/app/langchain/ChatAgent.js
+++ b/api/app/langchain/ChatAgent.js
@@ -6,11 +6,11 @@ const {
 } = require('@dqbd/tiktoken');
 const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
 const { Agent, ProxyAgent } = require('undici');
-const TextStream = require('../stream');
+// const TextStream = require('../stream');
 const { ChatOpenAI } = require('langchain/chat_models/openai');
 const { CallbackManager } = require('langchain/callbacks');
 const { HumanChatMessage, AIChatMessage } = require('langchain/schema');
-const { initializeCustomAgent } = require('./agents/CustomAgent/initializeCustomAgent');
+const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/');
 const { getMessages, saveMessage, saveConvo } = require('../../models');
 const { loadTools, SelfReflectionTool } = require('./tools');
 const {
@@ -106,10 +106,10 @@ class ChatAgent {
     const preliminaryAnswer =
       result.output?.length > 0 ? `Preliminary Answer: "${result.output.trim()}"` : '';
     const prefix = preliminaryAnswer
-      ? `review and improve the answer you generated using plugins in response to the User Message below. The answer hasn't been sent to the user yet.`
+      ? `review and improve the answer you generated using plugins in response to the User Message below. The user hasn't seen your answer or thoughts yet.`
       : 'respond to the User Message below based on your preliminary thoughts & actions.';
 
-    return `As ChatGPT, ${prefix}${errorMessage}\n${internalActions}
+    return `As a helpful AI Assistant, ${prefix}${errorMessage}\n${internalActions}
 ${preliminaryAnswer}
 Reply conversationally to the User based on your ${
   preliminaryAnswer ? 'preliminary answer, ' : ''
@@ -145,8 +145,7 @@ Only respond with your conversational reply to the following User Message:
       this.options = options;
     }
 
-    this.agentOptions = this.options.agentOptions || {};
-    this.agentIsGpt3 = this.agentOptions.model.startsWith('gpt-3');
+
     const modelOptions = this.options.modelOptions || {};
     this.modelOptions = {
       ...modelOptions,
@@ -160,10 +159,27 @@ Only respond with your conversational reply to the following User Message:
       stop: modelOptions.stop
     };
 
+    this.agentOptions = this.options.agentOptions || {};
+    this.functionsAgent = this.agentOptions.agent === 'functions';
+    this.agentIsGpt3 = this.agentOptions.model.startsWith('gpt-3');
+    if (this.functionsAgent) {
+      this.agentOptions.model = this.getFunctionModelName(this.agentOptions.model);
+    }
+
     this.isChatGptModel = this.modelOptions.model.startsWith('gpt-');
     this.isGpt3 = this.modelOptions.model.startsWith('gpt-3');
-    this.maxContextTokens = this.modelOptions.model === 'gpt-4-32k' ? 32767 : this.modelOptions.model.startsWith('gpt-4') ? 8191 : 4095,
-
+    const maxTokensMap = {
+      'gpt-4': 8191,
+      'gpt-4-0613': 8191,
+      'gpt-4-32k': 32767,
+      'gpt-4-32k-0613': 32767,
+      'gpt-3.5-turbo': 4095,
+      'gpt-3.5-turbo-0613': 4095,
+      'gpt-3.5-turbo-0301': 4095,
+      'gpt-3.5-turbo-16k': 15999,
+    };
+
+    this.maxContextTokens = maxTokensMap[this.modelOptions.model] ?? 4095; // 1 less than maximum
     // Reserve 1024 tokens for the response.
     // The max prompt tokens is determined by the max context tokens minus the max response tokens.
     // Earlier messages will be dropped until the prompt is within the limit.
@@ -180,7 +196,7 @@ Only respond with your conversational reply to the following User Message:
     }
 
     this.userLabel = this.options.userLabel || 'User';
-    this.chatGptLabel = this.options.chatGptLabel || 'ChatGPT';
+    this.chatGptLabel = this.options.chatGptLabel || 'Assistant';
 
     // Use these faux tokens to help the AI understand the context since we are building the chat log ourselves.
     // Trying to use "<|im_start|>" causes the AI to still generate "<" or "<|" at the end sometimes for some reason,
@@ -388,6 +404,26 @@ Only respond with your conversational reply to the following User Message:
     this.actions.push(action);
   }
 
+  getFunctionModelName(input) {
+    const prefixMap = {
+      'gpt-4': 'gpt-4-0613',
+      'gpt-4-32k': 'gpt-4-32k-0613',
+      'gpt-3.5-turbo': 'gpt-3.5-turbo-0613'
+    };
+
+    const prefix = Object.keys(prefixMap).find(key => input.startsWith(key));
+    return prefix ? prefixMap[prefix] : 'gpt-3.5-turbo-0613';
+  }
+
+  createLLM(modelOptions, configOptions) {
+    let credentials = { openAIApiKey: this.openAIApiKey };
+    if (this.azure) {
+      credentials = { ...this.azure };
+    }
+
+    return new ChatOpenAI({ credentials, ...modelOptions }, configOptions);
+  }
+
   async initialize({ user, message, onAgentAction, onChainEnd, signal }) {
     const modelOptions = {
       modelName: this.agentOptions.model,
@@ -400,21 +436,7 @@ Only respond with your conversational reply to the following User Message:
       configOptions.basePath = this.langchainProxy;
     }
 
-    const model = this.azure
-      ? new ChatOpenAI({
-        ...this.azure,
-        ...modelOptions
-      })
-      : new ChatOpenAI(
-        {
-          openAIApiKey: this.openAIApiKey,
-          ...modelOptions
-        },
-        configOptions
-        // {
-        //   basePath: 'http://localhost:8080/v1'
-        // }
-      );
+    const model = this.createLLM(modelOptions, configOptions);
 
     if (this.options.debug) {
       console.debug(`<-----Agent Model: ${model.modelName} | Temp: ${model.temperature}----->`);
@@ -466,7 +488,8 @@ Only respond with your conversational reply to the following User Message:
     };
 
     // initialize agent
-    this.executor = await initializeCustomAgent({
+    const initializer = this.options.agentOptions?.agent === 'functions' ? initializeFunctionsAgent : initializeCustomAgent;
+    this.executor = await initializer({
       model,
       signal,
       tools: this.tools,
@@ -594,7 +617,7 @@ Only respond with your conversational reply to the following User Message:
     console.log('sendMessage', message, opts);
 
     const user = opts.user || null;
-    const { onAgentAction, onChainEnd, onProgress } = opts;
+    const { onAgentAction, onChainEnd } = opts;
     const conversationId = opts.conversationId || crypto.randomUUID();
     const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000';
     const userMessageId = opts.overrideParentMessageId || crypto.randomUUID();
@@ -658,13 +681,13 @@ Only respond with your conversational reply to the following User Message:
       return { ...responseMessage, ...this.result };
     }
 
-    if (!this.agentIsGpt3 && this.result.output) {
-      responseMessage.text = this.result.output;
-      await this.saveMessageToDatabase(responseMessage, user);
-      const textStream = new TextStream(this.result.output);
-      await textStream.processTextStream(onProgress);
-      return { ...responseMessage, ...this.result };
-    }
+    // if (!this.agentIsGpt3 && this.result.output) {
+    //   responseMessage.text = this.result.output;
+    //   await this.saveMessageToDatabase(responseMessage, user);
+    //   const textStream = new TextStream(this.result.output);
+    //   await textStream.processTextStream(opts.onProgress);
+    //   return { ...responseMessage, ...this.result };
+    // }
 
     if (this.options.debug) {
       console.debug('this.result', this.result);
@@ -871,7 +894,7 @@ Only respond with your conversational reply to the following User Message:
     return orderedMessages.map((msg) => ({
       messageId: msg.messageId,
       parentMessageId: msg.parentMessageId,
-      role: msg.isCreatedByUser ? 'User' : 'ChatGPT',
+      role: msg.isCreatedByUser ? 'User' : 'Assistant',
       text: msg.text
     }));
   }
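Note on the ChatAgent changes above: the nested context-size ternary becomes an explicit per-model map (each value one token below the model's advertised maximum), and when agentOptions.agent === 'functions' the agent model is pinned to a function-calling snapshot. A minimal standalone sketch of the two lookups, using only names and values visible in this diff:

// Exact-match lookup; anything unlisted falls back to 4095 via `??`.
const maxTokensMap = {
  'gpt-4': 8191,
  'gpt-4-32k': 32767,
  'gpt-3.5-turbo': 4095,
  'gpt-3.5-turbo-16k': 15999
};
console.log(maxTokensMap['gpt-4-32k'] ?? 4095); // 32767
console.log(maxTokensMap['some-custom-model'] ?? 4095); // 4095 (fallback)

// Snapshot pinning: the first key whose prefix matches wins. Note that
// 'gpt-4-32k' also starts with 'gpt-4', so under this key order it resolves
// to 'gpt-4-0613' rather than 'gpt-4-32k-0613'.
const prefixMap = {
  'gpt-4': 'gpt-4-0613',
  'gpt-4-32k': 'gpt-4-32k-0613',
  'gpt-3.5-turbo': 'gpt-3.5-turbo-0613'
};
const getFunctionModelName = (input) => {
  const prefix = Object.keys(prefixMap).find((key) => input.startsWith(key));
  return prefix ? prefixMap[prefix] : 'gpt-3.5-turbo-0613';
};
console.log(getFunctionModelName('gpt-3.5-turbo-16k')); // 'gpt-3.5-turbo-0613'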
diff --git a/api/app/langchain/agents/CustomAgent/initializeCustomAgent.js b/api/app/langchain/agents/CustomAgent/initializeCustomAgent.js
index 88b5bb289..4639feaa5 100644
--- a/api/app/langchain/agents/CustomAgent/initializeCustomAgent.js
+++ b/api/app/langchain/agents/CustomAgent/initializeCustomAgent.js
@@ -51,6 +51,4 @@ Query: {input}
   return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest });
 };
 
-module.exports = {
-  initializeCustomAgent
-};
+module.exports = initializeCustomAgent;
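The export change above turns initializeCustomAgent from a named export into the module itself, so require sites pick the initializer up through the new agents barrel (api/app/langchain/agents/index.js, added further down) instead of reaching into the CustomAgent directory:

// Before: destructured from the module's export object.
// const { initializeCustomAgent } = require('./agents/CustomAgent/initializeCustomAgent');

// After: each initializer module exports the function directly,
// and the agents barrel re-exports both.
const { initializeCustomAgent, initializeFunctionsAgent } = require('./agents/');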
diff --git a/api/app/langchain/agents/Functions/FunctionsAgent.js b/api/app/langchain/agents/Functions/FunctionsAgent.js
new file mode 100644
index 000000000..7a4d45d05
--- /dev/null
+++ b/api/app/langchain/agents/Functions/FunctionsAgent.js
@@ -0,0 +1,122 @@
+const { Agent } = require('langchain/agents');
+const { LLMChain } = require('langchain/chains');
+const { FunctionChatMessage, AIChatMessage } = require('langchain/schema');
+const {
+  ChatPromptTemplate,
+  MessagesPlaceholder,
+  SystemMessagePromptTemplate,
+  HumanMessagePromptTemplate
+} = require('langchain/prompts');
+const PREFIX = `You are a helpful AI assistant. Objective: Understand the human's query with available functions.
+The user is expecting a function response to the query; if only part of the query involves a function, prioritize the function response.`;
+
+function parseOutput(message) {
+  if (message.additional_kwargs.function_call) {
+    const function_call = message.additional_kwargs.function_call;
+    return {
+      tool: function_call.name,
+      toolInput: function_call.arguments ? JSON.parse(function_call.arguments) : {},
+      log: message.text
+    };
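parseOutput above is the pivot of the new agent: a reply carrying an OpenAI function_call becomes a tool action, while a plain reply is treated as the final answer. A sketch of both branches inside this module, with a hypothetical tool name (the message shape follows OpenAI's chat completions API, where arguments arrive as a JSON string):

// Hypothetical ChatOpenAI response that calls a function.
const withFunctionCall = {
  text: '',
  additional_kwargs: {
    function_call: { name: 'calculator', arguments: '{"input": "2 + 2"}' }
  }
};
console.log(parseOutput(withFunctionCall));
// -> { tool: 'calculator', toolInput: { input: '2 + 2' }, log: '' }

// A plain reply (no function_call) becomes the final answer.
const plainReply = { text: '4', additional_kwargs: {} };
console.log(parseOutput(plainReply));
// -> { returnValues: { output: '4' }, log: '4' }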
+  } else {
+    return { returnValues: { output: message.text }, log: message.text };
+  }
+}
+
+class FunctionsAgent extends Agent {
+  constructor(input) {
+    super({ ...input, outputParser: undefined });
+    this.tools = input.tools;
+  }
+
+  lc_namespace = ['langchain', 'agents', 'openai'];
+
+  _agentType() {
+    return 'openai-functions';
+  }
+
+  observationPrefix() {
+    return 'Observation: ';
+  }
+
+  llmPrefix() {
+    return 'Thought:';
+  }
+
+  _stop() {
+    return ['Observation:'];
+  }
+
+  static createPrompt(_tools, fields) {
+    const { prefix = PREFIX, currentDateString } = fields || {};
+
+    return ChatPromptTemplate.fromPromptMessages([
+      SystemMessagePromptTemplate.fromTemplate(`Date: ${currentDateString}\n${prefix}`),
+      HumanMessagePromptTemplate.fromTemplate(`{chat_history}
+Query: {input}
+{agent_scratchpad}`),
+      new MessagesPlaceholder('agent_scratchpad')
+    ]);
+  }
+
+  static fromLLMAndTools(llm, tools, args) {
+    FunctionsAgent.validateTools(tools);
+    const prompt = FunctionsAgent.createPrompt(tools, args);
+    const chain = new LLMChain({
+      prompt,
+      llm,
+      callbacks: args?.callbacks
+    });
+    return new FunctionsAgent({
+      llmChain: chain,
+      allowedTools: tools.map((t) => t.name),
+      tools
+    });
+  }
+
+  async constructScratchPad(steps) {
+    return steps.flatMap(({ action, observation }) => [
+      new AIChatMessage('', {
+        function_call: {
+          name: action.tool,
+          arguments: JSON.stringify(action.toolInput)
+        }
+      }),
+      new FunctionChatMessage(observation, action.tool)
+    ]);
+  }
+
+  async plan(steps, inputs, callbackManager) {
+    // Add scratchpad and stop to inputs
+    var thoughts = await this.constructScratchPad(steps);
+    var newInputs = Object.assign({}, inputs, { agent_scratchpad: thoughts });
+    if (this._stop().length !== 0) {
+      newInputs.stop = this._stop();
+    }
+
+    // Split inputs between prompt and llm
+    var llm = this.llmChain.llm;
+    var valuesForPrompt = Object.assign({}, newInputs);
+    var valuesForLLM = {
+      tools: this.tools
+    };
+    for (var i = 0; i < this.llmChain.llm.callKeys.length; i++) {
+      var key = this.llmChain.llm.callKeys[i];
+      if (key in inputs) {
+        valuesForLLM[key] = inputs[key];
+        delete valuesForPrompt[key];
+      }
+    }
+
+    var promptValue = await this.llmChain.prompt.formatPromptValue(valuesForPrompt);
+    var message = await llm.predictMessages(
+      promptValue.toChatMessages(),
+      valuesForLLM,
+      callbackManager
+    );
+    console.log('message', message);
+    return parseOutput(message);
+  }
+}
+
+module.exports = FunctionsAgent;
diff --git a/api/app/langchain/agents/Functions/initializeFunctionsAgent.js b/api/app/langchain/agents/Functions/initializeFunctionsAgent.js
new file mode 100644
index 000000000..e3ae04989
--- /dev/null
+++ b/api/app/langchain/agents/Functions/initializeFunctionsAgent.js
@@ -0,0 +1,33 @@
+const FunctionsAgent = require('./FunctionsAgent');
+const { AgentExecutor } = require('langchain/agents');
+const { BufferMemory, ChatMessageHistory } = require('langchain/memory');
+
+const initializeFunctionsAgent = async ({
+  tools,
+  model,
+  pastMessages,
+  currentDateString,
+  ...rest
+}) => {
+  const agent = FunctionsAgent.fromLLMAndTools(
+    model,
+    tools,
+    {
+      currentDateString,
+    });
+
+  const memory = new BufferMemory({
+    chatHistory: new ChatMessageHistory(pastMessages),
+    // returnMessages: true, // commenting this out retains memory
+    memoryKey: 'chat_history',
+    humanPrefix: 'User',
+    aiPrefix: 'Assistant',
+    inputKey: 'input',
+    outputKey: 'output'
+  });
+
+  return AgentExecutor.fromAgentAndTools({ agent, tools, memory, ...rest });
+};
+
+module.exports = initializeFunctionsAgent;
+
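ChatAgent.initialize picks this initializer whenever agentOptions.agent === 'functions' and hands it the model, tools, and message history. A minimal sketch of that wiring, not code from this PR (assumes an async context, and that `tools` and `pastMessages` were built elsewhere, as ChatAgent does via loadTools and its message helpers):

const { ChatOpenAI } = require('langchain/chat_models/openai');
const initializeFunctionsAgent = require('./initializeFunctionsAgent');

const model = new ChatOpenAI({
  openAIApiKey: process.env.OPENAI_API_KEY,
  modelName: 'gpt-3.5-turbo-0613', // a function-calling snapshot, per getFunctionModelName
  temperature: 0
});
const executor = await initializeFunctionsAgent({
  model,
  tools, // langchain Tool instances
  pastMessages, // prior HumanChatMessage/AIChatMessage instances
  currentDateString: new Date().toLocaleDateString(),
  verbose: true // forwarded through ...rest to AgentExecutor.fromAgentAndTools
});
// AgentExecutor extends the chain interface, so it runs via call().
const result = await executor.call({ input: 'What is 2 + 2?' });
console.log(result.output);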
diff --git a/api/app/langchain/agents/index.js b/api/app/langchain/agents/index.js
new file mode 100644
index 000000000..3dc7d299e
--- /dev/null
+++ b/api/app/langchain/agents/index.js
@@ -0,0 +1,7 @@
+const initializeCustomAgent = require('./CustomAgent/initializeCustomAgent');
+const initializeFunctionsAgent = require('./Functions/initializeFunctionsAgent');
+
+module.exports = {
+  initializeCustomAgent,
+  initializeFunctionsAgent
+};
\ No newline at end of file
diff --git a/api/package.json b/api/package.json
index 873a5a62f..f3992c865 100644
--- a/api/package.json
+++ b/api/package.json
@@ -46,7 +46,7 @@
     "jsonwebtoken": "^9.0.0",
     "keyv": "^4.5.2",
     "keyv-file": "^0.2.0",
-    "langchain": "^0.0.92",
+    "langchain": "^0.0.94",
     "lodash": "^4.17.21",
     "meilisearch": "^0.33.0",
     "mongoose": "^7.1.1",
diff --git a/api/server/routes/ask/askGPTPlugins.js b/api/server/routes/ask/askGPTPlugins.js
index 1c0315237..56b18364c 100644
--- a/api/server/routes/ask/askGPTPlugins.js
+++ b/api/server/routes/ask/askGPTPlugins.js
@@ -39,8 +39,8 @@ router.post('/', requireJwtAuth, async (req, res) => {
   if (endpoint !== 'gptPlugins') return handleError(res, { text: 'Illegal request' });
 
   const agentOptions = req.body?.agentOptions ?? {
+    agent: 'classic',
     model: 'gpt-3.5-turbo',
-    // model: 'gpt-4', // for agent model
     temperature: 0,
     // top_p: 1,
     // presence_penalty: 0,
@@ -60,20 +60,12 @@ router.post('/', requireJwtAuth, async (req, res) => {
       presence_penalty: req.body?.presence_penalty ?? 0,
       frequency_penalty: req.body?.frequency_penalty ?? 0
     },
-    agentOptions
+    agentOptions: {
+      ...agentOptions,
+      // agent: 'functions'
+    }
   };
 
-  // const availableModels = getOpenAIModels();
-  // if (availableModels.find((model) => model === endpointOption.modelOptions.model) === undefined) {
-  //   return handleError(res, { text: `Illegal request: model` });
-  // }
-
-  // console.log('ask log', {
-  //   text,
-  //   conversationId,
-  //   endpointOption
-  // });
-
   console.log('ask log');
   console.dir({ text, conversationId, endpointOption }, { depth: null });
 
diff --git a/api/server/routes/endpoints.js b/api/server/routes/endpoints.js
index 4b7402505..2e54c0b42 100644
--- a/api/server/routes/endpoints.js
+++ b/api/server/routes/endpoints.js
@@ -53,7 +53,7 @@ router.get('/', async function (req, res) {
     ? { availableModels: getOpenAIModels(), userProvide: apiKey === 'user_provided' }
     : false;
   const gptPlugins = apiKey
-    ? { availableModels: getPluginModels(), availableTools }
+    ? { availableModels: getPluginModels(), availableTools, availableAgents: ['classic', 'functions'] }
     : false;
   const bingAI = process.env.BINGAI_TOKEN
     ? { userProvide: process.env.BINGAI_TOKEN == 'user_provided' }
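On the API surface: endpoints.js now advertises availableAgents: ['classic', 'functions'], and askGPTPlugins.js defaults agent to 'classic' while spreading any client-supplied agentOptions through. A request opting into the functions agent would look roughly like this (a sketch; the mount path and any fields beyond those visible in this diff are assumptions):

// Hypothetical client payload for the gptPlugins ask route.
const payload = {
  endpoint: 'gptPlugins', // anything else is rejected with 'Illegal request'
  text: 'Search for the latest langchain release notes',
  agentOptions: {
    agent: 'functions', // 'classic' keeps the previous CustomAgent behavior
    model: 'gpt-4', // getFunctionModelName() pins this to 'gpt-4-0613'
    temperature: 0
  }
};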