mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-16 16:30:15 +01:00
fix: Enhance Test Coverage and Fix Compatibility Issues 👷♂️ (#1363)
* refactor: only remove conversation states from localStorage on login/logout but not on refresh * chore: add debugging log for azure completion url * chore: add api-key to redact regex * fix: do not show endpoint selector if endpoint is falsy * chore: remove logger from genAzureChatCompletion * feat(ci): mock fetchEventSource * refactor(ci): mock all model methods in BaseClient.test, as well as mock the implementation for getCompletion in FakeClient * fix(OpenAIClient): consider chatCompletion if model name includes `gpt` as opposed to `gpt-` * fix(ChatGPTClient/azureOpenAI): Remove 'model' option for Azure compatibility (cannot be sent in payload body) * feat(ci): write new test suite that significantly increase test coverage for OpenAIClient and BaseClient by covering most of the real implementation of the `sendMessage` method - test for the azure edge case where model option is appended to modelOptions, ensuring removal before sent to the azure endpoint - test for expected azure url being passed to SSE POST request - test for AZURE_OPENAI_DEFAULT_MODEL being set, but is not included in the URL deployment name as expected - test getCompletion method to have correct payload fix(ci/OpenAIClient.test.js): correctly mock hanging/async methods * refactor(addTitle): allow azure to title as it aborts signal on completion
This commit is contained in:
parent
072a7e5f05
commit
0958db3825
12 changed files with 189 additions and 35 deletions
|
|
@ -166,6 +166,12 @@ class ChatGPTClient extends BaseClient {
|
||||||
console.debug(modelOptions);
|
console.debug(modelOptions);
|
||||||
console.debug();
|
console.debug();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (this.azure || this.options.azure) {
|
||||||
|
// Azure does not accept `model` in the body, so we need to remove it.
|
||||||
|
delete modelOptions.model;
|
||||||
|
}
|
||||||
|
|
||||||
const opts = {
|
const opts = {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
|
|
|
||||||
|
|
@ -104,7 +104,7 @@ class OpenAIClient extends BaseClient {
|
||||||
|
|
||||||
const { model } = this.modelOptions;
|
const { model } = this.modelOptions;
|
||||||
|
|
||||||
this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt-');
|
this.isChatCompletion = this.useOpenRouter || !!reverseProxy || model.includes('gpt');
|
||||||
this.isChatGptModel = this.isChatCompletion;
|
this.isChatGptModel = this.isChatCompletion;
|
||||||
if (
|
if (
|
||||||
model.includes('text-davinci') ||
|
model.includes('text-davinci') ||
|
||||||
|
|
|
||||||
|
|
@ -1,19 +1,33 @@
|
||||||
const { initializeFakeClient } = require('./FakeClient');
|
const { initializeFakeClient } = require('./FakeClient');
|
||||||
|
|
||||||
jest.mock('../../../lib/db/connectDb');
|
jest.mock('../../../lib/db/connectDb');
|
||||||
jest.mock('../../../models', () => {
|
jest.mock('~/models', () => ({
|
||||||
return function () {
|
User: jest.fn(),
|
||||||
return {
|
Key: jest.fn(),
|
||||||
save: jest.fn(),
|
Session: jest.fn(),
|
||||||
deleteConvos: jest.fn(),
|
Balance: jest.fn(),
|
||||||
getConvo: jest.fn(),
|
Transaction: jest.fn(),
|
||||||
getMessages: jest.fn(),
|
getMessages: jest.fn().mockResolvedValue([]),
|
||||||
saveMessage: jest.fn(),
|
saveMessage: jest.fn(),
|
||||||
updateMessage: jest.fn(),
|
updateMessage: jest.fn(),
|
||||||
saveConvo: jest.fn(),
|
deleteMessagesSince: jest.fn(),
|
||||||
};
|
deleteMessages: jest.fn(),
|
||||||
};
|
getConvoTitle: jest.fn(),
|
||||||
});
|
getConvo: jest.fn(),
|
||||||
|
saveConvo: jest.fn(),
|
||||||
|
deleteConvos: jest.fn(),
|
||||||
|
getPreset: jest.fn(),
|
||||||
|
getPresets: jest.fn(),
|
||||||
|
savePreset: jest.fn(),
|
||||||
|
deletePresets: jest.fn(),
|
||||||
|
findFileById: jest.fn(),
|
||||||
|
createFile: jest.fn(),
|
||||||
|
updateFile: jest.fn(),
|
||||||
|
deleteFile: jest.fn(),
|
||||||
|
deleteFiles: jest.fn(),
|
||||||
|
getFiles: jest.fn(),
|
||||||
|
updateFileUsage: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
jest.mock('langchain/chat_models/openai', () => {
|
jest.mock('langchain/chat_models/openai', () => {
|
||||||
return {
|
return {
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,6 @@ class FakeClient extends BaseClient {
|
||||||
|
|
||||||
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 4097;
|
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 4097;
|
||||||
}
|
}
|
||||||
getCompletion() {}
|
|
||||||
buildMessages() {}
|
buildMessages() {}
|
||||||
getTokenCount(str) {
|
getTokenCount(str) {
|
||||||
return str.length;
|
return str.length;
|
||||||
|
|
@ -86,6 +85,19 @@ const initializeFakeClient = (apiKey, options, fakeMessages) => {
|
||||||
return 'Mock response text';
|
return 'Mock response text';
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// eslint-disable-next-line no-unused-vars
|
||||||
|
TestClient.getCompletion = jest.fn().mockImplementation(async (..._args) => {
|
||||||
|
return {
|
||||||
|
choices: [
|
||||||
|
{
|
||||||
|
message: {
|
||||||
|
content: 'Mock response text',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
|
TestClient.buildMessages = jest.fn(async (messages, parentMessageId) => {
|
||||||
const orderedMessages = TestClient.constructor.getMessagesForConversation({
|
const orderedMessages = TestClient.constructor.getMessagesForConversation({
|
||||||
messages,
|
messages,
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,46 @@
|
||||||
require('dotenv').config();
|
require('dotenv').config();
|
||||||
|
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
|
||||||
|
const { genAzureChatCompletion } = require('~/utils/azureUtils');
|
||||||
const OpenAIClient = require('../OpenAIClient');
|
const OpenAIClient = require('../OpenAIClient');
|
||||||
|
|
||||||
jest.mock('meilisearch');
|
jest.mock('meilisearch');
|
||||||
|
|
||||||
|
jest.mock('~/lib/db/connectDb');
|
||||||
|
jest.mock('~/models', () => ({
|
||||||
|
User: jest.fn(),
|
||||||
|
Key: jest.fn(),
|
||||||
|
Session: jest.fn(),
|
||||||
|
Balance: jest.fn(),
|
||||||
|
Transaction: jest.fn(),
|
||||||
|
getMessages: jest.fn().mockResolvedValue([]),
|
||||||
|
saveMessage: jest.fn(),
|
||||||
|
updateMessage: jest.fn(),
|
||||||
|
deleteMessagesSince: jest.fn(),
|
||||||
|
deleteMessages: jest.fn(),
|
||||||
|
getConvoTitle: jest.fn(),
|
||||||
|
getConvo: jest.fn(),
|
||||||
|
saveConvo: jest.fn(),
|
||||||
|
deleteConvos: jest.fn(),
|
||||||
|
getPreset: jest.fn(),
|
||||||
|
getPresets: jest.fn(),
|
||||||
|
savePreset: jest.fn(),
|
||||||
|
deletePresets: jest.fn(),
|
||||||
|
findFileById: jest.fn(),
|
||||||
|
createFile: jest.fn(),
|
||||||
|
updateFile: jest.fn(),
|
||||||
|
deleteFile: jest.fn(),
|
||||||
|
deleteFiles: jest.fn(),
|
||||||
|
getFiles: jest.fn(),
|
||||||
|
updateFileUsage: jest.fn(),
|
||||||
|
}));
|
||||||
|
|
||||||
|
jest.mock('langchain/chat_models/openai', () => {
|
||||||
|
return {
|
||||||
|
ChatOpenAI: jest.fn().mockImplementation(() => {
|
||||||
|
return {};
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
|
||||||
describe('OpenAIClient', () => {
|
describe('OpenAIClient', () => {
|
||||||
let client, client2;
|
let client, client2;
|
||||||
const model = 'gpt-4';
|
const model = 'gpt-4';
|
||||||
|
|
@ -12,6 +50,21 @@ describe('OpenAIClient', () => {
|
||||||
{ role: 'assistant', sender: 'Assistant', text: 'Hi', messageId: '2' },
|
{ role: 'assistant', sender: 'Assistant', text: 'Hi', messageId: '2' },
|
||||||
];
|
];
|
||||||
|
|
||||||
|
const defaultOptions = {
|
||||||
|
// debug: true,
|
||||||
|
openaiApiKey: 'new-api-key',
|
||||||
|
modelOptions: {
|
||||||
|
model,
|
||||||
|
temperature: 0.7,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
const defaultAzureOptions = {
|
||||||
|
azureOpenAIApiInstanceName: 'your-instance-name',
|
||||||
|
azureOpenAIApiDeploymentName: 'your-deployment-name',
|
||||||
|
azureOpenAIApiVersion: '2020-07-01-preview',
|
||||||
|
};
|
||||||
|
|
||||||
beforeAll(() => {
|
beforeAll(() => {
|
||||||
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
jest.spyOn(console, 'warn').mockImplementation(() => {});
|
||||||
});
|
});
|
||||||
|
|
@ -21,14 +74,7 @@ describe('OpenAIClient', () => {
|
||||||
});
|
});
|
||||||
|
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
const options = {
|
const options = { ...defaultOptions };
|
||||||
// debug: true,
|
|
||||||
openaiApiKey: 'new-api-key',
|
|
||||||
modelOptions: {
|
|
||||||
model,
|
|
||||||
temperature: 0.7,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
client = new OpenAIClient('test-api-key', options);
|
client = new OpenAIClient('test-api-key', options);
|
||||||
client2 = new OpenAIClient('test-api-key', options);
|
client2 = new OpenAIClient('test-api-key', options);
|
||||||
client.summarizeMessages = jest.fn().mockResolvedValue({
|
client.summarizeMessages = jest.fn().mockResolvedValue({
|
||||||
|
|
@ -40,6 +86,7 @@ describe('OpenAIClient', () => {
|
||||||
.fn()
|
.fn()
|
||||||
.mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
|
.mockResolvedValue({ prompt: messages.map((m) => m.text).join('\n') });
|
||||||
client.constructor.freeAndResetAllEncoders();
|
client.constructor.freeAndResetAllEncoders();
|
||||||
|
client.getMessages = jest.fn().mockResolvedValue([]);
|
||||||
});
|
});
|
||||||
|
|
||||||
describe('setOptions', () => {
|
describe('setOptions', () => {
|
||||||
|
|
@ -408,4 +455,46 @@ describe('OpenAIClient', () => {
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
describe('sendMessage/getCompletion', () => {
|
||||||
|
afterEach(() => {
|
||||||
|
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
|
||||||
|
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
|
||||||
|
});
|
||||||
|
|
||||||
|
it('[Azure OpenAI] should call getCompletion and fetchEventSource with correct args', async () => {
|
||||||
|
// Set a default model
|
||||||
|
process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt4-turbo';
|
||||||
|
|
||||||
|
const onProgress = jest.fn().mockImplementation(() => ({}));
|
||||||
|
client.azure = defaultAzureOptions;
|
||||||
|
const getCompletion = jest.spyOn(client, 'getCompletion');
|
||||||
|
await client.sendMessage('Hi mom!', {
|
||||||
|
replaceOptions: true,
|
||||||
|
...defaultOptions,
|
||||||
|
onProgress,
|
||||||
|
azure: defaultAzureOptions,
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(getCompletion).toHaveBeenCalled();
|
||||||
|
expect(getCompletion.mock.calls.length).toBe(1);
|
||||||
|
expect(getCompletion.mock.calls[0][0][0].role).toBe('user');
|
||||||
|
expect(getCompletion.mock.calls[0][0][0].content).toBe('Hi mom!');
|
||||||
|
|
||||||
|
expect(fetchEventSource).toHaveBeenCalled();
|
||||||
|
expect(fetchEventSource.mock.calls.length).toBe(1);
|
||||||
|
|
||||||
|
// Check if the first argument (url) is correct
|
||||||
|
const expectedURL = genAzureChatCompletion(defaultAzureOptions);
|
||||||
|
const firstCallArgs = fetchEventSource.mock.calls[0];
|
||||||
|
|
||||||
|
expect(firstCallArgs[0]).toBe(expectedURL);
|
||||||
|
// Should not have model in the deployment name
|
||||||
|
expect(firstCallArgs[0]).not.toContain('gpt4-turbo');
|
||||||
|
|
||||||
|
// Should not include the model in request body
|
||||||
|
const requestBody = JSON.parse(firstCallArgs[1].body);
|
||||||
|
expect(requestBody).not.toHaveProperty('model');
|
||||||
|
});
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ const winston = require('winston');
|
||||||
const traverse = require('traverse');
|
const traverse = require('traverse');
|
||||||
const { klona } = require('klona/full');
|
const { klona } = require('klona/full');
|
||||||
|
|
||||||
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/];
|
const sensitiveKeys = [/^sk-\w+$/, /Bearer \w+/, /api-key: \w+/];
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Determines if a given key string is sensitive.
|
* Determines if a given key string is sensitive.
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ module.exports = {
|
||||||
'./test/jestSetup.js',
|
'./test/jestSetup.js',
|
||||||
'./test/__mocks__/KeyvMongo.js',
|
'./test/__mocks__/KeyvMongo.js',
|
||||||
'./test/__mocks__/logger.js',
|
'./test/__mocks__/logger.js',
|
||||||
|
'./test/__mocks__/fetchEventSource.js',
|
||||||
],
|
],
|
||||||
moduleNameMapper: {
|
moduleNameMapper: {
|
||||||
'~/(.*)': '<rootDir>/$1',
|
'~/(.*)': '<rootDir>/$1',
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,8 @@ const addTitle = async (req, { text, response, client }) => {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the request was aborted, don't generate the title.
|
// If the request was aborted and is not azure, don't generate the title.
|
||||||
if (client.abortController.signal.aborted) {
|
if (!client.azure && client.abortController.signal.aborted) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
27
api/test/__mocks__/fetchEventSource.js
Normal file
27
api/test/__mocks__/fetchEventSource.js
Normal file
|
|
@ -0,0 +1,27 @@
|
||||||
|
jest.mock('@waylaidwanderer/fetch-event-source', () => ({
|
||||||
|
fetchEventSource: jest
|
||||||
|
.fn()
|
||||||
|
.mockImplementation((url, { onopen, onmessage, onclose, onerror, error }) => {
|
||||||
|
// Simulating the onopen event
|
||||||
|
onopen && onopen({ status: 200 });
|
||||||
|
|
||||||
|
// Simulating a few onmessage events
|
||||||
|
onmessage &&
|
||||||
|
onmessage({ data: JSON.stringify({ message: 'First message' }), event: 'message' });
|
||||||
|
onmessage &&
|
||||||
|
onmessage({ data: JSON.stringify({ message: 'Second message' }), event: 'message' });
|
||||||
|
onmessage &&
|
||||||
|
onmessage({ data: JSON.stringify({ message: 'Third message' }), event: 'message' });
|
||||||
|
|
||||||
|
// Simulate the onclose event
|
||||||
|
onclose && onclose();
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
// Simulate the onerror event
|
||||||
|
onerror && onerror({ status: 500 });
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a Promise that resolves to simulate async behavior
|
||||||
|
return Promise.resolve();
|
||||||
|
}),
|
||||||
|
}));
|
||||||
|
|
@ -6,7 +6,7 @@
|
||||||
* @property {string} azureOpenAIApiVersion - The Azure OpenAI API version.
|
* @property {string} azureOpenAIApiVersion - The Azure OpenAI API version.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const { isEnabled } = require('../server/utils');
|
const { isEnabled } = require('~/server/utils');
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Sanitizes the model name to be used in the URL by removing or replacing disallowed characters.
|
* Sanitizes the model name to be used in the URL by removing or replacing disallowed characters.
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
import { Content, Portal, Root } from '@radix-ui/react-popover';
|
|
||||||
import { alternateName } from 'librechat-data-provider';
|
import { alternateName } from 'librechat-data-provider';
|
||||||
|
import { Content, Portal, Root } from '@radix-ui/react-popover';
|
||||||
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
|
import { useGetEndpointsQuery } from 'librechat-data-provider/react-query';
|
||||||
import type { FC } from 'react';
|
import type { FC } from 'react';
|
||||||
import EndpointItems from './Endpoints/MenuItems';
|
import EndpointItems from './Endpoints/MenuItems';
|
||||||
|
|
@ -14,9 +14,14 @@ const EndpointsMenu: FC = () => {
|
||||||
|
|
||||||
const { conversation } = useChatContext();
|
const { conversation } = useChatContext();
|
||||||
const selected = conversation?.endpoint ?? '';
|
const selected = conversation?.endpoint ?? '';
|
||||||
|
|
||||||
|
if (!selected) {
|
||||||
|
console.warn('No endpoint selected');
|
||||||
|
return null;
|
||||||
|
}
|
||||||
return (
|
return (
|
||||||
<Root>
|
<Root>
|
||||||
<TitleButton primaryText={alternateName[selected] + ' '} />
|
<TitleButton primaryText={(alternateName[selected] ?? '') + ' '} />
|
||||||
<Portal>
|
<Portal>
|
||||||
<div
|
<div
|
||||||
style={{
|
style={{
|
||||||
|
|
|
||||||
|
|
@ -347,6 +347,11 @@ export const useLoginUserMutation = (): UseMutationResult<
|
||||||
return useMutation((payload: t.TLoginUser) => dataService.login(payload), {
|
return useMutation((payload: t.TLoginUser) => dataService.login(payload), {
|
||||||
onMutate: () => {
|
onMutate: () => {
|
||||||
queryClient.removeQueries();
|
queryClient.removeQueries();
|
||||||
|
localStorage.removeItem('lastConversationSetup');
|
||||||
|
localStorage.removeItem('lastSelectedModel');
|
||||||
|
localStorage.removeItem('lastSelectedTools');
|
||||||
|
localStorage.removeItem('filesToDelete');
|
||||||
|
localStorage.removeItem('lastAssistant');
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
@ -375,11 +380,6 @@ export const useRefreshTokenMutation = (): UseMutationResult<
|
||||||
return useMutation(() => request.refreshToken(), {
|
return useMutation(() => request.refreshToken(), {
|
||||||
onMutate: () => {
|
onMutate: () => {
|
||||||
queryClient.removeQueries();
|
queryClient.removeQueries();
|
||||||
localStorage.removeItem('lastConversationSetup');
|
|
||||||
localStorage.removeItem('lastSelectedModel');
|
|
||||||
localStorage.removeItem('lastSelectedTools');
|
|
||||||
localStorage.removeItem('filesToDelete');
|
|
||||||
localStorage.removeItem('lastAssistant');
|
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
};
|
};
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue