From bffa9ad0161d4a6505c4dd1f143905acf3ba274c Mon Sep 17 00:00:00 2001 From: Daniel Avila Date: Wed, 14 Jun 2023 13:23:02 -0400 Subject: [PATCH] refactor(handleTools.js): change loadTools function signature to include functions parameter feat(handleTools.test.js): add test for loading StructuredSD tool with functions parameter --- api/app/langchain/ChatAgent.js | 1 + api/app/langchain/tools/util/handleTools.js | 3 +-- api/app/langchain/tools/util/handleTools.test.js | 14 +++++++++++++- 3 files changed, 15 insertions(+), 3 deletions(-) diff --git a/api/app/langchain/ChatAgent.js b/api/app/langchain/ChatAgent.js index be4635c553..b17c55ae7e 100644 --- a/api/app/langchain/ChatAgent.js +++ b/api/app/langchain/ChatAgent.js @@ -451,6 +451,7 @@ Only respond with your conversational reply to the following User Message: user, model, tools: this.options.tools, + functions: this.functionsAgent, options: { openAIApiKey: this.openAIApiKey } diff --git a/api/app/langchain/tools/util/handleTools.js b/api/app/langchain/tools/util/handleTools.js index 2e71946839..ed7d39dc47 100644 --- a/api/app/langchain/tools/util/handleTools.js +++ b/api/app/langchain/tools/util/handleTools.js @@ -72,8 +72,7 @@ const loadToolWithAuth = async (user, authFields, ToolConstructor, options = {}) }; }; -const loadTools = async ({ user, model, tools = [], options = {} }) => { - const { functions } = options; +const loadTools = async ({ user, model, functions = null, tools = [], options = {} }) => { const toolConstructors = { calculator: Calculator, google: GoogleSearchAPI, diff --git a/api/app/langchain/tools/util/handleTools.test.js b/api/app/langchain/tools/util/handleTools.test.js index 20c792c847..2c79b700bc 100644 --- a/api/app/langchain/tools/util/handleTools.test.js +++ b/api/app/langchain/tools/util/handleTools.test.js @@ -24,7 +24,7 @@ const { validateTools, loadTools } = require('./'); const PluginService = require('../../../../server/services/PluginService'); const { BaseChatModel } = require('langchain/chat_models/openai'); const { Calculator } = require('langchain/tools/calculator'); -const { availableTools, OpenAICreateImage, GoogleSearchAPI } = require('../'); +const { availableTools, OpenAICreateImage, GoogleSearchAPI, StructuredSD } = require('../'); describe('Tool Handlers', () => { let fakeUser; @@ -174,5 +174,17 @@ describe('Tool Handlers', () => { }); expect(toolFunctions).toEqual({}); }); + it('should return the StructuredTool version when using functions', async () => { + process.env.SD_WEBUI_URL = mockCredential; + toolFunctions = await loadTools({ + user: fakeUser._id, + model: BaseChatModel, + tools: ['stable-diffusion'], + functions: true + }); + const structuredTool = await toolFunctions['stable-diffusion'](); + expect(structuredTool).toBeInstanceOf(StructuredSD); + delete process.env.SD_WEBUI_URL; + }); }); });