mirror of
https://github.com/danny-avila/LibreChat.git
synced 2025-12-17 08:50:15 +01:00
* refactor(Chains/llms): allow passing callbacks * refactor(BaseClient): accurately count completion tokens as generation only * refactor(OpenAIClient): remove unused getTokenCountForResponse, pass streaming var and callbacks in initializeLLM * wip: summary prompt tokens * refactor(summarizeMessages): new cut-off strategy that generates a better summary by adding context from beginning, truncating the middle, and providing the end wip: draft out relevant providers and variables for token tracing * refactor(createLLM): make streaming prop false by default * chore: remove use of getTokenCountForResponse * refactor(agents): use BufferMemory as ConversationSummaryBufferMemory token usage not easy to trace * chore: remove passing of streaming prop, also console log useful vars for tracing * feat: formatFromLangChain helper function to count tokens for ChatModelStart * refactor(initializeLLM): add role for LLM tracing * chore(formatFromLangChain): update JSDoc * feat(formatMessages): formats langChain messages into OpenAI payload format * chore: install openai-chat-tokens * refactor(formatMessage): optimize conditional langChain logic fix(formatFromLangChain): fix destructuring * feat: accurate prompt tokens for ChatModelStart before generation * refactor(handleChatModelStart): move to callbacks dir, use factory function * refactor(initializeLLM): rename 'role' to 'context' * feat(Balance/Transaction): new schema/models for tracking token spend refactor(Key): factor out model export to separate file * refactor(initializeClient): add req,res objects to client options * feat: add-balance script to add to an existing users' token balance refactor(Transaction): use multiplier map/function, return balance update * refactor(Tx): update enum for tokenType, return 1 for multiplier if no map match * refactor(Tx): add fair fallback value multiplier incase the config result is undefined * refactor(Balance): rename 'tokens' to 'tokenCredits' * feat: balance check, add tx.js for new 
tx-related methods and tests * chore(summaryPrompts): update prompt token count * refactor(callbacks): pass req, res wip: check balance * refactor(Tx): make convoId a String type, fix(calculateTokenValue) * refactor(BaseClient): add conversationId as client prop when assigned * feat(RunManager): track LLM runs with manager, track token spend from LLM, refactor(OpenAIClient): use RunManager to create callbacks, pass user prop to langchain api calls * feat(spendTokens): helper to spend prompt/completion tokens * feat(checkBalance): add helper to check, log, deny request if balance doesn't have enough funds refactor(Balance): static check method to return object instead of boolean now wip(OpenAIClient): implement use of checkBalance * refactor(initializeLLM): add token buffer to assure summary isn't generated when subsequent payload is too large refactor(OpenAIClient): add checkBalance refactor(createStartHandler): add checkBalance * chore: remove prompt and completion token logging from route handler * chore(spendTokens): add JSDoc * feat(logTokenCost): record transactions for basic api calls * chore(ask/edit): invoke getResponseSender only once per API call * refactor(ask/edit): pass promptTokens to getIds and include in abort data * refactor(getIds -> getReqData): rename function * refactor(Tx): increase value if incomplete message * feat: record tokenUsage when message is aborted * refactor: subtract tokens when payload includes function_call * refactor: add namespace for token_balance * fix(spendTokens): only execute if corresponding token type amounts are defined * refactor(checkBalance): throws Error if not enough token credits * refactor(runTitleChain): pass and use signal, spread object props in create helpers, and use 'call' instead of 'run' * fix(abortMiddleware): circular dependency, and default to empty string for completionTokens * fix: properly cancel title requests when there isn't enough tokens to generate * feat(predictNewSummary): custom chain for 
summaries to allow signal passing refactor(summaryBuffer): use new custom chain * feat(RunManager): add getRunByConversationId method, refactor: remove run and throw llm error on handleLLMError * refactor(createStartHandler): if summary, add error details to runs * fix(OpenAIClient): support aborting from summarization & showing error to user refactor(summarizeMessages): remove unnecessary operations counting summaryPromptTokens and note for alternative, pass signal to summaryBuffer * refactor(logTokenCost -> recordTokenUsage): rename * refactor(checkBalance): include promptTokens in errorMessage * refactor(checkBalance/spendTokens): move to models dir * fix(createLanguageChain): correctly pass config * refactor(initializeLLM/title): add tokenBuffer of 150 for balance check * refactor(openAPIPlugin): pass signal and memory, filter functions by the one being called * refactor(createStartHandler): add error to run if context is plugins as well * refactor(RunManager/handleLLMError): throw error immediately if plugins, don't remove run * refactor(PluginsClient): pass memory and signal to tools, cleanup error handling logic * chore: use absolute equality for addTitle condition * refactor(checkBalance): move checkBalance to execute after userMessage and tokenCounts are saved, also make conditional * style: icon changes to match official * fix(BaseClient): getTokenCountForResponse -> getTokenCount * fix(formatLangChainMessages): add kwargs as fallback prop from lc_kwargs, update JSDoc * refactor(Tx.create): does not update balance if CHECK_BALANCE is not enabled * fix(e2e/cleanUp): cleanup new collections, import all model methods from index * fix(config/add-balance): add uncaughtException listener * fix: circular dependency * refactor(initializeLLM/checkBalance): append new generations to errorMessage if cost exceeds balance * fix(handleResponseMessage): only record token usage in this method if not error and completion is not skipped * fix(createStartHandler): correct 
condition for generations * chore: bump postcss due to moderate severity vulnerability * chore: bump zod due to low severity vulnerability * chore: bump openai & data-provider version * feat(types): OpenAI Message types * chore: update bun lockfile * refactor(CodeBlock): add error block formatting * refactor(utils/Plugin): factor out formatJSON and cn to separate files (json.ts and cn.ts), add extractJSON * chore(logViolation): delete user_id after error is logged * refactor(getMessageError -> Error): change to React.FC, add token_balance handling, use extractJSON to determine JSON instead of regex * fix(DALL-E): use latest openai SDK * chore: reorganize imports, fix type issue * feat(server): add balance route * fix(api/models): add auth * feat(data-provider): /api/balance query * feat: show balance if checking is enabled, refetch on final message or error * chore: update docs, .env.example with token_usage info, add balance script command * fix(Balance): fallback to empty obj for balance query * style: slight adjustment of balance element * docs(token_usage): add PR notes
616 lines
24 KiB
JavaScript
616 lines
24 KiB
JavaScript
const { initializeFakeClient } = require('./FakeClient');
|
|
|
|
// Stub the DB connection so tests never touch a real database.
jest.mock('../../../lib/db/connectDb');

// Mock the models module as a factory whose product exposes jest spies
// for every persistence method the client may call.
jest.mock('../../../models', () => () => ({
  save: jest.fn(),
  deleteConvos: jest.fn(),
  getConvo: jest.fn(),
  getMessages: jest.fn(),
  saveMessage: jest.fn(),
  updateMessage: jest.fn(),
  saveConvo: jest.fn(),
}));

// Replace the LangChain chat model with an inert constructor.
jest.mock('langchain/chat_models/openai', () => ({
  ChatOpenAI: jest.fn().mockImplementation(() => ({})),
}));
|
|
|
|
// Ids threaded between the sequential sendMessage tests below.
let parentMessageId;
let conversationId;

// Shared fixtures.
const fakeMessages = [];
const userMessage = 'Hello, ChatGPT!';
const apiKey = 'fake-api-key';

// A short three-turn conversation used by the history tests.
const messageHistory = [
  { role: 'user', isCreatedByUser: true, text: 'Hello', messageId: '1' },
  { role: 'assistant', isCreatedByUser: false, text: 'Hi', messageId: '2', parentMessageId: '1' },
  {
    role: 'user',
    isCreatedByUser: true,
    text: 'What\'s up',
    messageId: '3',
    parentMessageId: '2',
  },
];
|
|
|
|
describe('BaseClient', () => {
|
|
let TestClient;
|
|
const options = {
|
|
// debug: true,
|
|
modelOptions: {
|
|
model: 'gpt-3.5-turbo',
|
|
temperature: 0,
|
|
},
|
|
};
|
|
|
|
beforeEach(() => {
|
|
TestClient = initializeFakeClient(apiKey, options, fakeMessages);
|
|
TestClient.summarizeMessages = jest.fn().mockResolvedValue({
|
|
summaryMessage: {
|
|
role: 'system',
|
|
content: 'Refined answer',
|
|
},
|
|
summaryTokenCount: 5,
|
|
});
|
|
});
|
|
|
|
test('returns the input messages without instructions when addInstructions() is called with empty instructions', () => {
|
|
const messages = [{ content: 'Hello' }, { content: 'How are you?' }, { content: 'Goodbye' }];
|
|
const instructions = '';
|
|
const result = TestClient.addInstructions(messages, instructions);
|
|
expect(result).toEqual(messages);
|
|
});
|
|
|
|
test('returns the input messages with instructions properly added when addInstructions() is called with non-empty instructions', () => {
|
|
const messages = [{ content: 'Hello' }, { content: 'How are you?' }, { content: 'Goodbye' }];
|
|
const instructions = { content: 'Please respond to the question.' };
|
|
const result = TestClient.addInstructions(messages, instructions);
|
|
const expected = [
|
|
{ content: 'Hello' },
|
|
{ content: 'How are you?' },
|
|
{ content: 'Please respond to the question.' },
|
|
{ content: 'Goodbye' },
|
|
];
|
|
expect(result).toEqual(expected);
|
|
});
|
|
|
|
test('concats messages correctly in concatenateMessages()', () => {
|
|
const messages = [
|
|
{ name: 'User', content: 'Hello' },
|
|
{ name: 'Assistant', content: 'How can I help you?' },
|
|
{ name: 'User', content: 'I have a question.' },
|
|
];
|
|
const result = TestClient.concatenateMessages(messages);
|
|
const expected =
|
|
'User:\nHello\n\nAssistant:\nHow can I help you?\n\nUser:\nI have a question.\n\n';
|
|
expect(result).toBe(expected);
|
|
});
|
|
|
|
test('refines messages correctly in summarizeMessages()', async () => {
|
|
const messagesToRefine = [
|
|
{ role: 'user', content: 'Hello', tokenCount: 10 },
|
|
{ role: 'assistant', content: 'How can I help you?', tokenCount: 20 },
|
|
];
|
|
const remainingContextTokens = 100;
|
|
const expectedRefinedMessage = {
|
|
role: 'system',
|
|
content: 'Refined answer',
|
|
};
|
|
|
|
const result = await TestClient.summarizeMessages({ messagesToRefine, remainingContextTokens });
|
|
expect(result.summaryMessage).toEqual(expectedRefinedMessage);
|
|
});
|
|
|
|
test('gets messages within token limit (under limit) correctly in getMessagesWithinTokenLimit()', async () => {
|
|
TestClient.maxContextTokens = 100;
|
|
TestClient.shouldSummarize = true;
|
|
|
|
const messages = [
|
|
{ role: 'user', content: 'Hello', tokenCount: 5 },
|
|
{ role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
|
|
{ role: 'user', content: 'I have a question.', tokenCount: 18 },
|
|
];
|
|
const expectedContext = [
|
|
{ role: 'user', content: 'Hello', tokenCount: 5 }, // 'Hello'.length
|
|
{ role: 'assistant', content: 'How can I help you?', tokenCount: 19 },
|
|
{ role: 'user', content: 'I have a question.', tokenCount: 18 },
|
|
];
|
|
// Subtract 3 tokens for Assistant Label priming after all messages have been counted.
|
|
const expectedRemainingContextTokens = 58 - 3; // (100 - 5 - 19 - 18) - 3
|
|
const expectedMessagesToRefine = [];
|
|
|
|
const lastExpectedMessage =
|
|
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
|
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
|
|
|
const result = await TestClient.getMessagesWithinTokenLimit(messages);
|
|
|
|
expect(result.context).toEqual(expectedContext);
|
|
expect(result.summaryIndex).toEqual(expectedIndex);
|
|
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
|
|
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
|
});
|
|
|
|
test('gets result over token limit correctly in getMessagesWithinTokenLimit()', async () => {
|
|
TestClient.maxContextTokens = 50; // Set a lower limit
|
|
TestClient.shouldSummarize = true;
|
|
|
|
const messages = [
|
|
{ role: 'user', content: 'Hello', tokenCount: 30 },
|
|
{ role: 'assistant', content: 'How can I help you?', tokenCount: 30 },
|
|
{ role: 'user', content: 'I have a question.', tokenCount: 5 },
|
|
{ role: 'user', content: 'I need a coffee, stat!', tokenCount: 19 },
|
|
{ role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 18 },
|
|
];
|
|
|
|
// Subtract 3 tokens for Assistant Label priming after all messages have been counted.
|
|
const expectedRemainingContextTokens = 5; // (50 - 18 - 19 - 5) - 3
|
|
const expectedMessagesToRefine = [
|
|
{ role: 'user', content: 'Hello', tokenCount: 30 },
|
|
{ role: 'assistant', content: 'How can I help you?', tokenCount: 30 },
|
|
];
|
|
const expectedContext = [
|
|
{ role: 'user', content: 'I have a question.', tokenCount: 5 },
|
|
{ role: 'user', content: 'I need a coffee, stat!', tokenCount: 19 },
|
|
{ role: 'assistant', content: 'Sure, I can help with that.', tokenCount: 18 },
|
|
];
|
|
|
|
const lastExpectedMessage =
|
|
expectedMessagesToRefine?.[expectedMessagesToRefine.length - 1] ?? {};
|
|
const expectedIndex = messages.findIndex((msg) => msg.content === lastExpectedMessage?.content);
|
|
|
|
const result = await TestClient.getMessagesWithinTokenLimit(messages);
|
|
|
|
expect(result.context).toEqual(expectedContext);
|
|
expect(result.summaryIndex).toEqual(expectedIndex);
|
|
expect(result.remainingContextTokens).toBe(expectedRemainingContextTokens);
|
|
expect(result.messagesToRefine).toEqual(expectedMessagesToRefine);
|
|
});
|
|
|
|
test('handles context strategy correctly in handleContextStrategy()', async () => {
|
|
TestClient.addInstructions = jest
|
|
.fn()
|
|
.mockReturnValue([
|
|
{ content: 'Hello' },
|
|
{ content: 'How can I help you?' },
|
|
{ content: 'Please provide more details.' },
|
|
{ content: 'I can assist you with that.' },
|
|
]);
|
|
TestClient.getMessagesWithinTokenLimit = jest.fn().mockReturnValue({
|
|
context: [
|
|
{ content: 'How can I help you?' },
|
|
{ content: 'Please provide more details.' },
|
|
{ content: 'I can assist you with that.' },
|
|
],
|
|
remainingContextTokens: 80,
|
|
messagesToRefine: [{ content: 'Hello' }],
|
|
summaryIndex: 3,
|
|
});
|
|
|
|
TestClient.getTokenCount = jest.fn().mockReturnValue(40);
|
|
|
|
const instructions = { content: 'Please provide more details.' };
|
|
const orderedMessages = [
|
|
{ content: 'Hello' },
|
|
{ content: 'How can I help you?' },
|
|
{ content: 'Please provide more details.' },
|
|
{ content: 'I can assist you with that.' },
|
|
];
|
|
const formattedMessages = [
|
|
{ content: 'Hello' },
|
|
{ content: 'How can I help you?' },
|
|
{ content: 'Please provide more details.' },
|
|
{ content: 'I can assist you with that.' },
|
|
];
|
|
const expectedResult = {
|
|
payload: [
|
|
{
|
|
role: 'system',
|
|
content: 'Refined answer',
|
|
},
|
|
{ content: 'How can I help you?' },
|
|
{ content: 'Please provide more details.' },
|
|
{ content: 'I can assist you with that.' },
|
|
],
|
|
promptTokens: expect.any(Number),
|
|
tokenCountMap: {},
|
|
messages: expect.any(Array),
|
|
};
|
|
|
|
TestClient.shouldSummarize = true;
|
|
const result = await TestClient.handleContextStrategy({
|
|
instructions,
|
|
orderedMessages,
|
|
formattedMessages,
|
|
});
|
|
|
|
expect(result).toEqual(expectedResult);
|
|
});
|
|
|
|
describe('getMessagesForConversation', () => {
|
|
it('should return an empty array if the parentMessageId does not exist', () => {
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: unorderedMessages,
|
|
parentMessageId: '999',
|
|
});
|
|
expect(result).toEqual([]);
|
|
});
|
|
|
|
it('should handle messages with messageId property', () => {
|
|
const messagesWithMessageId = [
|
|
{ messageId: '1', parentMessageId: null, text: 'Message 1' },
|
|
{ messageId: '2', parentMessageId: '1', text: 'Message 2' },
|
|
];
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: messagesWithMessageId,
|
|
parentMessageId: '2',
|
|
});
|
|
expect(result).toEqual([
|
|
{ messageId: '1', parentMessageId: null, text: 'Message 1' },
|
|
{ messageId: '2', parentMessageId: '1', text: 'Message 2' },
|
|
]);
|
|
});
|
|
|
|
const messagesWithNullParent = [
|
|
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
|
{ id: '2', parentMessageId: null, text: 'Message 2' },
|
|
];
|
|
|
|
it('should handle messages with null parentMessageId that are not root', () => {
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: messagesWithNullParent,
|
|
parentMessageId: '2',
|
|
});
|
|
expect(result).toEqual([{ id: '2', parentMessageId: null, text: 'Message 2' }]);
|
|
});
|
|
|
|
const cyclicMessages = [
|
|
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
|
{ id: '1', parentMessageId: '3', text: 'Message 1' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
|
];
|
|
|
|
it('should handle cyclic references without going into an infinite loop', () => {
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: cyclicMessages,
|
|
parentMessageId: '3',
|
|
});
|
|
expect(result).toEqual([
|
|
{ id: '1', parentMessageId: '3', text: 'Message 1' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
|
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
|
]);
|
|
});
|
|
|
|
const unorderedMessages = [
|
|
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
|
{ id: '1', parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' },
|
|
];
|
|
|
|
it('should return ordered messages based on parentMessageId', () => {
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: unorderedMessages,
|
|
parentMessageId: '3',
|
|
});
|
|
expect(result).toEqual([
|
|
{ id: '1', parentMessageId: '00000000-0000-0000-0000-000000000000', text: 'Message 1' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
|
{ id: '3', parentMessageId: '2', text: 'Message 3' },
|
|
]);
|
|
});
|
|
|
|
const unorderedBranchedMessages = [
|
|
{ id: '4', parentMessageId: '2', text: 'Message 4', summary: 'Summary for Message 4' },
|
|
{ id: '10', parentMessageId: '7', text: 'Message 10' },
|
|
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
|
{ id: '6', parentMessageId: '5', text: 'Message 7' },
|
|
{ id: '7', parentMessageId: '5', text: 'Message 7' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2' },
|
|
{ id: '8', parentMessageId: '6', text: 'Message 8' },
|
|
{ id: '5', parentMessageId: '3', text: 'Message 5' },
|
|
{ id: '3', parentMessageId: '1', text: 'Message 3' },
|
|
{ id: '6', parentMessageId: '4', text: 'Message 6' },
|
|
{ id: '8', parentMessageId: '7', text: 'Message 9' },
|
|
{ id: '9', parentMessageId: '7', text: 'Message 9' },
|
|
{ id: '11', parentMessageId: '2', text: 'Message 11', summary: 'Summary for Message 11' },
|
|
];
|
|
|
|
it('should return ordered messages from a branched array based on parentMessageId', () => {
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: unorderedBranchedMessages,
|
|
parentMessageId: '10',
|
|
summary: true,
|
|
});
|
|
expect(result).toEqual([
|
|
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
|
{ id: '3', parentMessageId: '1', text: 'Message 3' },
|
|
{ id: '5', parentMessageId: '3', text: 'Message 5' },
|
|
{ id: '7', parentMessageId: '5', text: 'Message 7' },
|
|
{ id: '10', parentMessageId: '7', text: 'Message 10' },
|
|
]);
|
|
});
|
|
|
|
it('should return an empty array if no messages are provided', () => {
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: [],
|
|
parentMessageId: '3',
|
|
});
|
|
expect(result).toEqual([]);
|
|
});
|
|
|
|
it('should map over the ordered messages if mapMethod is provided', () => {
|
|
const mapMethod = (msg) => msg.text;
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: unorderedMessages,
|
|
parentMessageId: '3',
|
|
mapMethod,
|
|
});
|
|
expect(result).toEqual(['Message 1', 'Message 2', 'Message 3']);
|
|
});
|
|
|
|
let unorderedMessagesWithSummary = [
|
|
{ id: '4', parentMessageId: '3', text: 'Message 4' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2', summary: 'Summary for Message 2' },
|
|
{ id: '3', parentMessageId: '2', text: 'Message 3', summary: 'Summary for Message 3' },
|
|
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
|
];
|
|
|
|
it('should start with the message that has a summary property and continue until the specified parentMessageId', () => {
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: unorderedMessagesWithSummary,
|
|
parentMessageId: '4',
|
|
summary: true,
|
|
});
|
|
expect(result).toEqual([
|
|
{
|
|
id: '3',
|
|
parentMessageId: '2',
|
|
role: 'system',
|
|
text: 'Summary for Message 3',
|
|
summary: 'Summary for Message 3',
|
|
},
|
|
{ id: '4', parentMessageId: '3', text: 'Message 4' },
|
|
]);
|
|
});
|
|
|
|
it('should handle multiple summaries and return the branch from the latest to the parentMessageId', () => {
|
|
unorderedMessagesWithSummary = [
|
|
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2', summary: 'Summary for Message 2' },
|
|
{ id: '3', parentMessageId: '2', text: 'Message 3', summary: 'Summary for Message 3' },
|
|
{ id: '4', parentMessageId: '3', text: 'Message 4', summary: 'Summary for Message 4' },
|
|
{ id: '1', parentMessageId: null, text: 'Message 1' },
|
|
];
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: unorderedMessagesWithSummary,
|
|
parentMessageId: '5',
|
|
summary: true,
|
|
});
|
|
expect(result).toEqual([
|
|
{
|
|
id: '4',
|
|
parentMessageId: '3',
|
|
role: 'system',
|
|
text: 'Summary for Message 4',
|
|
summary: 'Summary for Message 4',
|
|
},
|
|
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
|
]);
|
|
});
|
|
|
|
it('should handle summary at root edge case and continue until the parentMessageId', () => {
|
|
unorderedMessagesWithSummary = [
|
|
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
|
{ id: '1', parentMessageId: null, text: 'Message 1', summary: 'Summary for Message 1' },
|
|
{ id: '4', parentMessageId: '3', text: 'Message 4', summary: 'Summary for Message 4' },
|
|
{ id: '2', parentMessageId: '1', text: 'Message 2', summary: 'Summary for Message 2' },
|
|
{ id: '3', parentMessageId: '2', text: 'Message 3', summary: 'Summary for Message 3' },
|
|
];
|
|
const result = TestClient.constructor.getMessagesForConversation({
|
|
messages: unorderedMessagesWithSummary,
|
|
parentMessageId: '5',
|
|
summary: true,
|
|
});
|
|
expect(result).toEqual([
|
|
{
|
|
id: '4',
|
|
parentMessageId: '3',
|
|
role: 'system',
|
|
text: 'Summary for Message 4',
|
|
summary: 'Summary for Message 4',
|
|
},
|
|
{ id: '5', parentMessageId: '4', text: 'Message 5' },
|
|
]);
|
|
});
|
|
});
|
|
|
|
describe('sendMessage', () => {
|
|
test('sendMessage should return a response message', async () => {
|
|
const expectedResult = expect.objectContaining({
|
|
sender: TestClient.sender,
|
|
text: expect.any(String),
|
|
isCreatedByUser: false,
|
|
messageId: expect.any(String),
|
|
parentMessageId: expect.any(String),
|
|
conversationId: expect.any(String),
|
|
});
|
|
|
|
const response = await TestClient.sendMessage(userMessage);
|
|
parentMessageId = response.messageId;
|
|
conversationId = response.conversationId;
|
|
expect(response).toEqual(expectedResult);
|
|
});
|
|
|
|
test('sendMessage should work with provided conversationId and parentMessageId', async () => {
|
|
const userMessage = 'Second message in the conversation';
|
|
const opts = {
|
|
conversationId,
|
|
parentMessageId,
|
|
getReqData: jest.fn(),
|
|
onStart: jest.fn(),
|
|
};
|
|
|
|
const expectedResult = expect.objectContaining({
|
|
sender: TestClient.sender,
|
|
text: expect.any(String),
|
|
isCreatedByUser: false,
|
|
messageId: expect.any(String),
|
|
parentMessageId: expect.any(String),
|
|
conversationId: opts.conversationId,
|
|
});
|
|
|
|
const response = await TestClient.sendMessage(userMessage, opts);
|
|
parentMessageId = response.messageId;
|
|
expect(response.conversationId).toEqual(conversationId);
|
|
expect(response).toEqual(expectedResult);
|
|
expect(opts.getReqData).toHaveBeenCalled();
|
|
expect(opts.onStart).toHaveBeenCalled();
|
|
expect(TestClient.getBuildMessagesOptions).toHaveBeenCalled();
|
|
expect(TestClient.getSaveOptions).toHaveBeenCalled();
|
|
});
|
|
|
|
test('should return chat history', async () => {
|
|
TestClient = initializeFakeClient(apiKey, options, messageHistory);
|
|
const chatMessages = await TestClient.loadHistory(conversationId, '2');
|
|
expect(TestClient.currentMessages).toHaveLength(2);
|
|
expect(chatMessages[0].text).toEqual('Hello');
|
|
|
|
const chatMessages2 = await TestClient.loadHistory(conversationId, '3');
|
|
expect(TestClient.currentMessages).toHaveLength(3);
|
|
expect(chatMessages2[chatMessages2.length - 1].text).toEqual('What\'s up');
|
|
});
|
|
|
|
/* Most of the new sendMessage logic revolving around edited/continued AI messages
|
|
* can be summarized by the following test. The condition will load the entire history up to
|
|
* the message that is being edited, which will trigger the AI API to 'continue' the response.
|
|
* The 'userMessage' is only passed by convention and is not necessary for the generation.
|
|
*/
|
|
it('should not push userMessage to currentMessages when isEdited is true and vice versa', async () => {
|
|
const overrideParentMessageId = 'user-message-id';
|
|
const responseMessageId = 'response-message-id';
|
|
const newHistory = messageHistory.slice();
|
|
newHistory.push({
|
|
role: 'assistant',
|
|
isCreatedByUser: false,
|
|
text: 'test message',
|
|
messageId: responseMessageId,
|
|
parentMessageId: '3',
|
|
});
|
|
|
|
TestClient = initializeFakeClient(apiKey, options, newHistory);
|
|
const sendMessageOptions = {
|
|
isEdited: true,
|
|
overrideParentMessageId,
|
|
parentMessageId: '3',
|
|
responseMessageId,
|
|
};
|
|
|
|
await TestClient.sendMessage('test message', sendMessageOptions);
|
|
const currentMessages = TestClient.currentMessages;
|
|
expect(currentMessages[currentMessages.length - 1].messageId).not.toEqual(
|
|
overrideParentMessageId,
|
|
);
|
|
|
|
// Test the opposite case
|
|
sendMessageOptions.isEdited = false;
|
|
await TestClient.sendMessage('test message', sendMessageOptions);
|
|
const currentMessages2 = TestClient.currentMessages;
|
|
expect(currentMessages2[currentMessages2.length - 1].messageId).toEqual(
|
|
overrideParentMessageId,
|
|
);
|
|
});
|
|
|
|
test('setOptions is called with the correct arguments', async () => {
|
|
TestClient.setOptions = jest.fn();
|
|
const opts = { conversationId: '123', parentMessageId: '456' };
|
|
await TestClient.sendMessage('Hello, world!', opts);
|
|
expect(TestClient.setOptions).toHaveBeenCalledWith(opts);
|
|
TestClient.setOptions.mockClear();
|
|
});
|
|
|
|
test('loadHistory is called with the correct arguments', async () => {
|
|
const opts = { conversationId: '123', parentMessageId: '456' };
|
|
await TestClient.sendMessage('Hello, world!', opts);
|
|
expect(TestClient.loadHistory).toHaveBeenCalledWith(
|
|
opts.conversationId,
|
|
opts.parentMessageId,
|
|
);
|
|
});
|
|
|
|
test('getReqData is called with the correct arguments', async () => {
|
|
const getReqData = jest.fn();
|
|
const opts = { getReqData };
|
|
const response = await TestClient.sendMessage('Hello, world!', opts);
|
|
expect(getReqData).toHaveBeenCalledWith({
|
|
userMessage: expect.objectContaining({ text: 'Hello, world!' }),
|
|
conversationId: response.conversationId,
|
|
responseMessageId: response.messageId,
|
|
});
|
|
});
|
|
|
|
test('onStart is called with the correct arguments', async () => {
|
|
const onStart = jest.fn();
|
|
const opts = { onStart };
|
|
await TestClient.sendMessage('Hello, world!', opts);
|
|
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
|
|
});
|
|
|
|
test('saveMessageToDatabase is called with the correct arguments', async () => {
|
|
const saveOptions = TestClient.getSaveOptions();
|
|
const user = {}; // Mock user
|
|
const opts = { user };
|
|
await TestClient.sendMessage('Hello, world!', opts);
|
|
expect(TestClient.saveMessageToDatabase).toHaveBeenCalledWith(
|
|
expect.objectContaining({
|
|
sender: expect.any(String),
|
|
text: expect.any(String),
|
|
isCreatedByUser: expect.any(Boolean),
|
|
messageId: expect.any(String),
|
|
parentMessageId: expect.any(String),
|
|
conversationId: expect.any(String),
|
|
}),
|
|
saveOptions,
|
|
user,
|
|
);
|
|
});
|
|
|
|
test('sendCompletion is called with the correct arguments', async () => {
|
|
const payload = {}; // Mock payload
|
|
TestClient.buildMessages.mockReturnValue({ prompt: payload, tokenCountMap: null });
|
|
const opts = {};
|
|
await TestClient.sendMessage('Hello, world!', opts);
|
|
expect(TestClient.sendCompletion).toHaveBeenCalledWith(payload, opts);
|
|
});
|
|
|
|
test('getTokenCount for response is called with the correct arguments', async () => {
|
|
const tokenCountMap = {}; // Mock tokenCountMap
|
|
TestClient.buildMessages.mockReturnValue({ prompt: [], tokenCountMap });
|
|
TestClient.getTokenCount = jest.fn();
|
|
const response = await TestClient.sendMessage('Hello, world!', {});
|
|
expect(TestClient.getTokenCount).toHaveBeenCalledWith(response.text);
|
|
});
|
|
|
|
test('returns an object with the correct shape', async () => {
|
|
const response = await TestClient.sendMessage('Hello, world!', {});
|
|
expect(response).toEqual(
|
|
expect.objectContaining({
|
|
sender: expect.any(String),
|
|
text: expect.any(String),
|
|
isCreatedByUser: expect.any(Boolean),
|
|
messageId: expect.any(String),
|
|
parentMessageId: expect.any(String),
|
|
conversationId: expect.any(String),
|
|
}),
|
|
);
|
|
});
|
|
});
|
|
});
|