fix: Enhance Test Coverage and Fix Compatibility Issues 👷‍♂️ (#1363)

* refactor: only remove conversation states from localStorage on login/logout but not on refresh

* chore: add debugging log for azure completion url

* chore: add api-key to redact regex

* fix: do not show endpoint selector if endpoint is falsy

* chore: remove logger from genAzureChatCompletion

* feat(ci): mock fetchEventSource

* refactor(ci): mock all model methods in BaseClient.test, as well as mock the implementation for getCompletion in FakeClient

* fix(OpenAIClient): consider chatCompletion if model name includes `gpt` as opposed to `gpt-`

* fix(ChatGPTClient/azureOpenAI): Remove 'model' option for Azure compatibility (cannot be sent in payload body)

* feat(ci): write new test suite that significantly increase test coverage for OpenAIClient and BaseClient by covering most of the real implementation of the `sendMessage` method
- test for the azure edge case where model option is appended to modelOptions, ensuring removal before sent to the azure endpoint
- test for expected azure url being passed to SSE POST request
- test for AZURE_OPENAI_DEFAULT_MODEL being set, but is not included in the URL deployment name as expected
- test getCompletion method to have correct payload
fix(ci/OpenAIClient.test.js): correctly mock hanging/async methods

* refactor(addTitle): allow azure to title as it aborts signal on completion
This commit is contained in:
Danny Avila 2023-12-15 13:27:13 -05:00 committed by GitHub
parent 072a7e5f05
commit 0958db3825
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 189 additions and 35 deletions

View file

@ -166,6 +166,12 @@ class ChatGPTClient extends BaseClient {
console.debug(modelOptions);
console.debug();
}
if (this.azure || this.options.azure) {
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;
}
const opts = {
method: 'POST',
headers: {

View file

@ -104,7 +104,7 @@ class OpenAIClient extends BaseClient {
const { model } = this.modelOptions;
this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt-');
this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt');
this.isChatGptModel = this.isChatCompletion;
if (
model.includes('text-davinci') ||

View file

@ -1,19 +1,33 @@
const { initializeFakeClient } = require('./FakeClient');
jest.mock('../../../lib/db/connectDb');
jest.mock('../../../models', () => {
return function () {
return {
save: jest.fn(),
deleteConvos: jest.fn(),
getConvo: jest.fn(),
getMessages: jest.fn(),
saveMessage: jest.fn(),
updateMessage: jest.fn(),
saveConvo: jest.fn(),
};
};
});
jest.mock('~/models', () => ({
User: jest.fn(),
Key: jest.fn(),
Session: jest.fn(),
Balance: jest.fn(),
Transaction: jest.fn(),
getMessages: jest.fn().mockResolvedValue([]),
saveMessage: jest.fn(),
updateMessage: jest.fn(),
deleteMessagesSince: jest.fn(),
deleteMessages: jest.fn(),
getConvoTitle: jest.fn(),
getConvo: jest.fn(),
saveConvo: jest.fn(),
deleteConvos: jest.fn(),
getPreset: jest.fn(),
getPresets: jest.fn(),
savePreset: jest.fn(),
deletePresets: jest.fn(),
findFileById: jest.fn(),
createFile: jest.fn(),
updateFile: jest.fn(),
deleteFile: jest.fn(),
deleteFiles: jest.fn(),
getFiles: jest.fn(),
updateFileUsage: jest.fn(),
}));
jest.mock('langchain/chat_models/openai', () => {
return {

View file

@ -42,7 +42,6 @@ class FakeClient extends BaseClient {
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 4097;
}
getCompletion() {}
buildMessages() {}
getTokenCount(str) {
return str.length;
@ -86,6 +85,19 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
return 'Mock response text';
});
// eslint-disable-next-line no-unused-vars
TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => {
return {
choices: [
{
message: {
content: 'Mock response text',
},
},
],
};
});
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
const orderedMessages = TestClient.constructor.getMessagesForConversation({
messages,

View file

@ -1,8 +1,46 @@
require('dotenv').config();
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { genAzureChatCompletion } = require('~/utils/azureUtils');
const OpenAIClient = require('../OpenAIClient');
jest.mock('meilisearch');
jest.mock('~/lib/db/connectDb');
jest.mock('~/models', () => ({
User: jest.fn(),
Key: jest.fn(),
Session: jest.fn(),
Balance: jest.fn(),
Transaction: jest.fn(),
getMessages: jest.fn().mockResolvedValue([]),
saveMessage: jest.fn(),
updateMessage: jest.fn(),
deleteMessagesSince: jest.fn(),
deleteMessages: jest.fn(),
getConvoTitle: jest.fn(),
getConvo: jest.fn(),
saveConvo: jest.fn(),
deleteConvos: jest.fn(),
getPreset: jest.fn(),
getPresets: jest.fn(),
savePreset: jest.fn(),
deletePresets: jest.fn(),
findFileById: jest.fn(),
createFile: jest.fn(),
updateFile: jest.fn(),
deleteFile: jest.fn(),
deleteFiles: jest.fn(),
getFiles: jest.fn(),
updateFileUsage: jest.fn(),
}));
jest.mock('langchain/chat_models/openai', () => {
return {
ChatOpenAI: jest.fn().mockImplementation(() => {
return {};
}),
};
});
describe('OpenAIClient', () => {
let client, client2;
const model = 'gpt-4';
@ -12,6 +50,21 @@ describe('OpenAIClient', () => {
{ role: 'assistant', sender: 'Assistant', text: 'Hi', messageId: '2' },
];
const defaultOptions = {
// debug: true,
openaiApiKey: 'new-api-key',
modelOptions: {
model,
temperature: 0.7,
},
};
const defaultAzureOptions = {
azureOpenAIApiInstanceName: 'your-instance-name',
azureOpenAIApiDeploymentName: 'your-deployment-name',
azureOpenAIApiVersion: '2020-07-01-preview',
};
beforeAll(() => {
jest.spyOn(console, 'warn').mockImplementation(() => {});
});
@ -21,14 +74,7 @@ describe('OpenAIClient', () => {
});
beforeEach(() => {
const options = {
// debug: true,
openaiApiKey: 'new-api-key',
modelOptions: {
model,
temperature: 0.7,
},
};
const options = { ...defaultOptions };
client = new OpenAIClient('test-api-key', options);
client2 = new OpenAIClient('test-api-key', options);
client.summarizeMessages = jest.fn().mockResolvedValue({
@ -40,6 +86,7 @@ describe('OpenAIClient', () => {
.fn()
.mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
client.constructor.freeAndResetAllEncoders();
client.getMessages = jest.fn().mockResolvedValue([]);
});
describe('setOptions', () => {
@ -408,4 +455,46 @@ describe('OpenAIClient', () => {
});
});
});
describe('sendMessage/getCompletion', () => {
afterEach(() => {
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
});
it('[Azure OpenAI] should call getCompletion and fetchEventSource with correct args', async () => {
// Set a default model
process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt4-turbo';
const onProgress = jest.fn().mockImplementation(() => ({}));
client.azure = defaultAzureOptions;
const getCompletion = jest.spyOn(client, 'getCompletion');
await client.sendMessage('Hi mom!', {
replaceOptions: true,
...defaultOptions,
onProgress,
azure: defaultAzureOptions,
});
expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1);
expect(getCompletion.mock.calls[0][0][0].role).toBe('user');
expect(getCompletion.mock.calls[0][0][0].content).toBe('Hi mom!');
expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1);
// Check if the first argument (url) is correct
const expectedURL = genAzureChatCompletion(defaultAzureOptions);
const firstCallArgs = fetchEventSource.mock.calls[0];
expect(firstCallArgs[0]).toBe(expectedURL);
// Should not have model in the deployment name
expect(firstCallArgs[0]).not.toContain('gpt4-turbo');
// Should not include the model in request body
const requestBody = JSON.parse(firstCallArgs[1].body);
expect(requestBody).not.toHaveProperty('model');
});
});
});

View file

@ -3,7 +3,7 @@ const winston = require('winston');
const traverse = require('traverse');
const { klona } = require('klona/full');
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/];
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/, /api-key: \w+/];
/**
* Determines if a given key string is sensitive.

View file

@ -7,6 +7,7 @@ module.exports = {
'./test/jestSetup.js',
'./test/__mocks__/KeyvMongo.js',
'./test/__mocks__/logger.js',
'./test/__mocks__/fetchEventSource.js',
],
moduleNameMapper: {
'~/(.*)': '<rootDir>/$1',

View file

@ -7,8 +7,8 @@ const addTitle = async (req, { text, response, client }) => {
return;
}
// If the request was aborted, don't generate the title.
if (client.abortController.signal.aborted) {
// If the request was aborted and is not azure, don't generate the title.
if (!client.azure && client.abortController.signal.aborted) {
return;
}

View file

@ -0,0 +1,27 @@
jest.mock('@waylaidwanderer/fetch-event-source', () => ({
fetchEventSource: jest
.fn()
.mockImplementation((url, { onopen, onmessage, onclose, onerror, error }) => {
// Simulating the onopen event
onopen && onopen({ status: 200 });
// Simulating a few onmessage events
onmessage &&
onmessage({ data: JSON.stringify({ message: 'First message' }), event: 'message' });
onmessage &&
onmessage({ data: JSON.stringify({ message: 'Second message' }), event: 'message' });
onmessage &&
onmessage({ data: JSON.stringify({ message: 'Third message' }), event: 'message' });
// Simulate the onclose event
onclose && onclose();
if (error) {
// Simulate the onerror event
onerror && onerror({ status: 500 });
}
// Return a Promise that resolves to simulate async behavior
return Promise.resolve();
}),
}));

View file

@ -6,7 +6,7 @@
* @property {string} azureOpenAIApiVersion - The Azure OpenAI API version.
*/
const { isEnabled } = require('../server/utils');
const { isEnabled } = require('~/server/utils');
/**
* Sanitizes the model name to be used in the URL by removing or replacing disallowed characters.

View file

@ -1,5 +1,5 @@
import { Content, Portal, Root } from '@radix-ui/react-popover';
import { alternateName } from 'librechat-data-provider';
import { Content, Portal, Root } from '@radix-ui/react-popover';
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
import type { FC } from 'react';
import EndpointItems from './Endpoints/MenuItems';
@ -14,9 +14,14 @@ const EndpointsMenu: FC = () => {
const { conversation } = useChatContext();
const selected = conversation?.endpoint ?? '';
if (!selected) {
console.warn('No endpoint selected');
return null;
}
return (
<Root>
<TitleButton primaryText={alternateName[selected] + ' '} />
<TitleButton primaryText={(alternateName[selected] ?? '') + ' '} />
<Portal>
<div
style={{

View file

@ -347,6 +347,11 @@ export const useLoginUserMutation = (): UseMutationResult<
return useMutation((payload: t.TLoginUser) => dataService.login(payload), {
onMutate: () => {
queryClient.removeQueries();
localStorage.removeItem('lastConversationSetup');
localStorage.removeItem('lastSelectedModel');
localStorage.removeItem('lastSelectedTools');
localStorage.removeItem('filesToDelete');
localStorage.removeItem('lastAssistant');
},
});
};
@ -375,11 +380,6 @@ export const useRefreshTokenMutation = (): UseMutationResult<
return useMutation(() => request.refreshToken(), {
onMutate: () => {
queryClient.removeQueries();
localStorage.removeItem('lastConversationSetup');
localStorage.removeItem('lastSelectedModel');
localStorage.removeItem('lastSelectedTools');
localStorage.removeItem('filesToDelete');
localStorage.removeItem('lastAssistant');
},
});
};