From 5d985746cb0429a06f51d31129a2614bc43ad82a Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Thu, 4 Jul 2024 10:34:28 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=9B=A0=EF=B8=8F=20fix:=20Tool=20Filtering?= =?UTF-8?q?=20in=20PluginsClient=20(#3266)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat(plugins): implement tool filtering in PluginsClient Add functionality to filter tools based on filteredTools and includedTools arrays in the request's app locals. This allows for dynamic tool selection on a per-request basis, enhancing the flexibility of the plugin system. * test(plugins): add unit tests for tool filtering in PluginsClient Introduce comprehensive test suite for the new tool filtering feature in PluginsClient. Cover scenarios including filtering out tools, including specific tools, prioritization of includedTools over filteredTools, and behavior when no filters are provided. * chore: Remove unused legacy Conversation component and update imports --- api/app/clients/PluginsClient.js | 11 ++ api/app/clients/specs/PluginsClient.test.js | 91 ++++++++++ .../components/Conversations/Conversation.jsx | 156 ------------------ client/src/components/Conversations/index.ts | 1 - 4 files changed, 102 insertions(+), 157 deletions(-) delete mode 100644 client/src/components/Conversations/Conversation.jsx diff --git a/api/app/clients/PluginsClient.js b/api/app/clients/PluginsClient.js index 86931c449b..2ce0ece4e7 100644 --- a/api/app/clients/PluginsClient.js +++ b/api/app/clients/PluginsClient.js @@ -244,6 +244,17 @@ class PluginsClient extends OpenAIClient { } async sendMessage(message, opts = {}) { + /** @type {{ filteredTools: string[], includedTools: string[] }} */ + const { filteredTools = [], includedTools = [] } = this.options.req.app.locals; + + if (includedTools.length > 0) { + const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin)); + this.options.tools = tools; + } else { + const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin)); + this.options.tools = tools; + } + // If a message is edited, no tools can be used. const completionMode = this.options.tools.length === 0 || opts.isEdited; if (completionMode) { diff --git a/api/app/clients/specs/PluginsClient.test.js b/api/app/clients/specs/PluginsClient.test.js index dfd57b23b9..57064cf8e6 100644 --- a/api/app/clients/specs/PluginsClient.test.js +++ b/api/app/clients/specs/PluginsClient.test.js @@ -194,6 +194,7 @@ describe('PluginsClient', () => { expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo'); }); }); + describe('Azure OpenAI tests specific to Plugins', () => { // TODO: add more tests for Azure OpenAI integration with Plugins // let client; @@ -220,4 +221,94 @@ describe('PluginsClient', () => { spy.mockRestore(); }); }); + + describe('sendMessage with filtered tools', () => { + let TestAgent; + const apiKey = 'fake-api-key'; + const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }]; + + beforeEach(() => { + TestAgent = new PluginsClient(apiKey, { + tools: mockTools, + modelOptions: { + model: 'gpt-3.5-turbo', + temperature: 0, + max_tokens: 2, + }, + agentOptions: { + model: 'gpt-3.5-turbo', + }, + }); + + TestAgent.options.req = { + app: { + locals: {}, + }, + }; + + TestAgent.sendMessage = jest.fn().mockImplementation(async () => { + const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals; + + if (includedTools.length > 0) { + const tools = TestAgent.options.tools.filter((plugin) => + includedTools.includes(plugin.name), + ); + TestAgent.options.tools = tools; + } else { + const tools = TestAgent.options.tools.filter( + (plugin) => !filteredTools.includes(plugin.name), + ); + TestAgent.options.tools = tools; + } + + return { + text: 'Mocked response', + tools: TestAgent.options.tools, + }; + }); + }); + + test('should filter out tools when filteredTools is provided', async () => { + TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3']; + const response = await TestAgent.sendMessage('Test message'); + expect(response.tools).toHaveLength(2); + expect(response.tools).toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'tool2' }), + expect.objectContaining({ name: 'tool4' }), + ]), + ); + }); + + test('should only include specified tools when includedTools is provided', async () => { + TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4']; + const response = await TestAgent.sendMessage('Test message'); + expect(response.tools).toHaveLength(2); + expect(response.tools).toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'tool2' }), + expect.objectContaining({ name: 'tool4' }), + ]), + ); + }); + + test('should prioritize includedTools over filteredTools', async () => { + TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3']; + TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2']; + const response = await TestAgent.sendMessage('Test message'); + expect(response.tools).toHaveLength(2); + expect(response.tools).toEqual( + expect.arrayContaining([ + expect.objectContaining({ name: 'tool1' }), + expect.objectContaining({ name: 'tool2' }), + ]), + ); + }); + + test('should not modify tools when no filters are provided', async () => { + const response = await TestAgent.sendMessage('Test message'); + expect(response.tools).toHaveLength(4); + expect(response.tools).toEqual(expect.arrayContaining(mockTools)); + }); + }); }); diff --git a/client/src/components/Conversations/Conversation.jsx b/client/src/components/Conversations/Conversation.jsx deleted file mode 100644 index cfa3e3c401..0000000000 --- a/client/src/components/Conversations/Conversation.jsx +++ /dev/null @@ -1,156 +0,0 @@ -import { useState, useRef } from 'react'; -import { useRecoilState, useSetRecoilState } from 'recoil'; -import { useUpdateConversationMutation } from '~/data-provider'; -import { useConversations, useConversation } from '~/hooks'; -import { MinimalIcon } from '~/components/Endpoints'; -import { NotificationSeverity } from '~/common'; -import { useToastContext } from '~/Providers'; -import DeleteButton from './DeleteButton'; -import RenameButton from './RenameButton'; -import store from '~/store'; - -export default function Conversation({ conversation, retainView }) { - const { showToast } = useToastContext(); - const [currentConversation, setCurrentConversation] = useRecoilState(store.conversation); - const setSubmission = useSetRecoilState(store.submission); - - const { refreshConversations } = useConversations(); - const { switchToConversation } = useConversation(); - - const updateConvoMutation = useUpdateConversationMutation(currentConversation?.conversationId); - - const [renaming, setRenaming] = useState(false); - const inputRef = useRef(null); - - const { conversationId, title } = conversation; - - const [titleInput, setTitleInput] = useState(title); - - const clickHandler = async () => { - if (currentConversation?.conversationId === conversationId) { - return; - } - - // stop existing submission - setSubmission(null); - - // set document title - document.title = title; - - // set conversation to the new conversation - if (conversation?.endpoint === 'gptPlugins') { - const lastSelectedTools = JSON.parse(localStorage.getItem('lastSelectedTools')) || []; - switchToConversation({ ...conversation, tools: lastSelectedTools }); - } else { - switchToConversation(conversation); - } - }; - - const renameHandler = (e) => { - e.preventDefault(); - setTitleInput(title); - setRenaming(true); - setTimeout(() => { - inputRef.current.focus(); - }, 25); - }; - - const cancelHandler = (e) => { - e.preventDefault(); - setRenaming(false); - }; - - const onRename = (e) => { - e.preventDefault(); - setRenaming(false); - if (titleInput === title) { - return; - } - updateConvoMutation.mutate( - { conversationId, title: titleInput }, - { - onSuccess: () => { - refreshConversations(); - if (conversationId == currentConversation?.conversationId) { - setCurrentConversation((prevState) => ({ - ...prevState, - title: titleInput, - })); - } - }, - onError: () => { - setTitleInput(title); - showToast({ - message: 'Failed to rename conversation', - severity: NotificationSeverity.ERROR, - showIcon: true, - }); - }, - }, - ); - }; - - const icon = MinimalIcon({ - size: 20, - endpoint: conversation.endpoint, - model: conversation.model, - error: false, - className: 'mr-0', - }); - - const handleKeyDown = (e) => { - if (e.key === 'Enter') { - onRename(e); - } - }; - - const aProps = { - className: - 'animate-flash group relative flex cursor-pointer items-center gap-3 break-all rounded-md bg-gray-300 dark:bg-gray-800 py-3 px-3 pr-14', - }; - - if (currentConversation?.conversationId !== conversationId) { - aProps.className = - 'group relative flex cursor-pointer items-center gap-3 break-all rounded-md py-3 px-3 hover:bg-gray-200 dark:hover:bg-gray-800 hover:pr-4'; - } - - return ( - clickHandler()} {...aProps}> - {icon} -
- {renaming === true ? ( - setTitleInput(e.target.value)} - onBlur={onRename} - onKeyDown={handleKeyDown} - /> - ) : ( - title - )} -
- {currentConversation?.conversationId === conversationId ? ( -
- - -
- ) : ( -
- )} - - ); -} diff --git a/client/src/components/Conversations/index.ts b/client/src/components/Conversations/index.ts index 72e8babd44..1e68c0bae3 100644 --- a/client/src/components/Conversations/index.ts +++ b/client/src/components/Conversations/index.ts @@ -1,5 +1,4 @@ export { default as Fork } from './Fork'; export { default as Pages } from './Pages'; -export { default as Conversation } from './Conversation'; export { default as RenameButton } from './RenameButton'; export { default as Conversations } from './Conversations';