🛠️ fix: Tool Filtering in PluginsClient (#3266)

* feat(plugins): implement tool filtering in PluginsClient

Add functionality to filter tools based on filteredTools and includedTools
arrays in the request's app locals. This allows for dynamic tool selection
on a per-request basis, enhancing the flexibility of the plugin system.

* test(plugins): add unit tests for tool filtering in PluginsClient

Introduce comprehensive test suite for the new tool filtering feature
in PluginsClient. Cover scenarios including filtering out tools,
including specific tools, prioritization of includedTools over
filteredTools, and behavior when no filters are provided.

* chore: Remove unused legacy Conversation component and update imports
This commit is contained in:
Danny Avila 2024-07-04 10:34:28 -04:00 committed by GitHub
parent 04654014b2
commit 5d985746cb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 102 additions and 157 deletions

View file

@ -244,6 +244,17 @@ class PluginsClient extends OpenAIClient {
} }
async sendMessage(message, opts = {}) { async sendMessage(message, opts = {}) {
/** @type {{ filteredTools: string[], includedTools: string[] }} */
const { filteredTools = [], includedTools = [] } = this.options.req.app.locals;
if (includedTools.length > 0) {
const tools = this.options.tools.filter((plugin) => includedTools.includes(plugin));
this.options.tools = tools;
} else {
const tools = this.options.tools.filter((plugin) => !filteredTools.includes(plugin));
this.options.tools = tools;
}
// If a message is edited, no tools can be used. // If a message is edited, no tools can be used.
const completionMode = this.options.tools.length === 0 || opts.isEdited; const completionMode = this.options.tools.length === 0 || opts.isEdited;
if (completionMode) { if (completionMode) {

View file

@ -194,6 +194,7 @@ describe('PluginsClient', () => {
expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo'); expect(client.getFunctionModelName('')).toBe('gpt-3.5-turbo');
}); });
}); });
describe('Azure OpenAI tests specific to Plugins', () => { describe('Azure OpenAI tests specific to Plugins', () => {
// TODO: add more tests for Azure OpenAI integration with Plugins // TODO: add more tests for Azure OpenAI integration with Plugins
// let client; // let client;
@ -220,4 +221,94 @@ describe('PluginsClient', () => {
spy.mockRestore(); spy.mockRestore();
}); });
}); });
describe('sendMessage with filtered tools', () => {
let TestAgent;
const apiKey = 'fake-api-key';
const mockTools = [{ name: 'tool1' }, { name: 'tool2' }, { name: 'tool3' }, { name: 'tool4' }];
beforeEach(() => {
TestAgent = new PluginsClient(apiKey, {
tools: mockTools,
modelOptions: {
model: 'gpt-3.5-turbo',
temperature: 0,
max_tokens: 2,
},
agentOptions: {
model: 'gpt-3.5-turbo',
},
});
TestAgent.options.req = {
app: {
locals: {},
},
};
TestAgent.sendMessage = jest.fn().mockImplementation(async () => {
const { filteredTools = [], includedTools = [] } = TestAgent.options.req.app.locals;
if (includedTools.length > 0) {
const tools = TestAgent.options.tools.filter((plugin) =>
includedTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
} else {
const tools = TestAgent.options.tools.filter(
(plugin) => !filteredTools.includes(plugin.name),
);
TestAgent.options.tools = tools;
}
return {
text: 'Mocked response',
tools: TestAgent.options.tools,
};
});
});
test('should filter out tools when filteredTools is provided', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should only include specified tools when includedTools is provided', async () => {
TestAgent.options.req.app.locals.includedTools = ['tool2', 'tool4'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool2' }),
expect.objectContaining({ name: 'tool4' }),
]),
);
});
test('should prioritize includedTools over filteredTools', async () => {
TestAgent.options.req.app.locals.filteredTools = ['tool1', 'tool3'];
TestAgent.options.req.app.locals.includedTools = ['tool1', 'tool2'];
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(2);
expect(response.tools).toEqual(
expect.arrayContaining([
expect.objectContaining({ name: 'tool1' }),
expect.objectContaining({ name: 'tool2' }),
]),
);
});
test('should not modify tools when no filters are provided', async () => {
const response = await TestAgent.sendMessage('Test message');
expect(response.tools).toHaveLength(4);
expect(response.tools).toEqual(expect.arrayContaining(mockTools));
});
});
}); });

View file

@ -1,156 +0,0 @@
import { useState, useRef } from 'react';
import { useRecoilState, useSetRecoilState } from 'recoil';
import { useUpdateConversationMutation } from '~/data-provider';
import { useConversations, useConversation } from '~/hooks';
import { MinimalIcon } from '~/components/Endpoints';
import { NotificationSeverity } from '~/common';
import { useToastContext } from '~/Providers';
import DeleteButton from './DeleteButton';
import RenameButton from './RenameButton';
import store from '~/store';
export default function Conversation({ conversation, retainView }) {
const { showToast } = useToastContext();
const [currentConversation, setCurrentConversation] = useRecoilState(store.conversation);
const setSubmission = useSetRecoilState(store.submission);
const { refreshConversations } = useConversations();
const { switchToConversation } = useConversation();
const updateConvoMutation = useUpdateConversationMutation(currentConversation?.conversationId);
const [renaming, setRenaming] = useState(false);
const inputRef = useRef(null);
const { conversationId, title } = conversation;
const [titleInput, setTitleInput] = useState(title);
const clickHandler = async () => {
if (currentConversation?.conversationId === conversationId) {
return;
}
// stop existing submission
setSubmission(null);
// set document title
document.title = title;
// set conversation to the new conversation
if (conversation?.endpoint === 'gptPlugins') {
const lastSelectedTools = JSON.parse(localStorage.getItem('lastSelectedTools')) || [];
switchToConversation({ ...conversation, tools: lastSelectedTools });
} else {
switchToConversation(conversation);
}
};
const renameHandler = (e) => {
e.preventDefault();
setTitleInput(title);
setRenaming(true);
setTimeout(() => {
inputRef.current.focus();
}, 25);
};
const cancelHandler = (e) => {
e.preventDefault();
setRenaming(false);
};
const onRename = (e) => {
e.preventDefault();
setRenaming(false);
if (titleInput === title) {
return;
}
updateConvoMutation.mutate(
{ conversationId, title: titleInput },
{
onSuccess: () => {
refreshConversations();
if (conversationId == currentConversation?.conversationId) {
setCurrentConversation((prevState) => ({
...prevState,
title: titleInput,
}));
}
},
onError: () => {
setTitleInput(title);
showToast({
message: 'Failed to rename conversation',
severity: NotificationSeverity.ERROR,
showIcon: true,
});
},
},
);
};
const icon = MinimalIcon({
size: 20,
endpoint: conversation.endpoint,
model: conversation.model,
error: false,
className: 'mr-0',
});
const handleKeyDown = (e) => {
if (e.key === 'Enter') {
onRename(e);
}
};
const aProps = {
className:
'animate-flash group relative flex cursor-pointer items-center gap-3 break-all rounded-md bg-gray-300 dark:bg-gray-800 py-3 px-3 pr-14',
};
if (currentConversation?.conversationId !== conversationId) {
aProps.className =
'group relative flex cursor-pointer items-center gap-3 break-all rounded-md py-3 px-3 hover:bg-gray-200 dark:hover:bg-gray-800 hover:pr-4';
}
return (
<a data-testid="convo-item" onClick={() => clickHandler()} {...aProps}>
{icon}
<div className="relative max-h-5 flex-1 overflow-hidden text-ellipsis break-all">
{renaming === true ? (
<input
ref={inputRef}
type="text"
className="m-0 mr-0 w-full border border-blue-500 bg-transparent p-0 text-sm leading-tight outline-none"
value={titleInput}
onChange={(e) => setTitleInput(e.target.value)}
onBlur={onRename}
onKeyDown={handleKeyDown}
/>
) : (
title
)}
</div>
{currentConversation?.conversationId === conversationId ? (
<div className="visible absolute right-1 z-10 flex text-gray-300">
<RenameButton
conversationId={conversationId}
renaming={renaming}
renameHandler={renameHandler}
onRename={onRename}
/>
<DeleteButton
conversationId={conversationId}
renaming={renaming}
cancelHandler={cancelHandler}
retainView={retainView}
title={title}
/>
</div>
) : (
<div className="absolute inset-y-0 right-0 z-10 w-8 rounded-r-md bg-gradient-to-l from-gray-50 group-hover:from-gray-50 dark:from-gray-900 dark:group-hover:from-gray-800" />
)}
</a>
);
}

View file

@ -1,5 +1,4 @@
export { default as Fork } from './Fork'; export { default as Fork } from './Fork';
export { default as Pages } from './Pages'; export { default as Pages } from './Pages';
export { default as Conversation } from './Conversation';
export { default as RenameButton } from './RenameButton'; export { default as RenameButton } from './RenameButton';
export { default as Conversations } from './Conversations'; export { default as Conversations } from './Conversations';