🚤 refactor: Optimize Request Lifecycle Speeds (#3222)

* refactor: optimize backend operations for client requests

* fix: message styling

* refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js

* refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js

* fix: clear new convo messages on clear all convos

* fix: forgot to pass userId to getConvo

* refactor: update getPartialText to send basePayload.text
This commit is contained in:
Danny Avila 2024-06-28 08:44:47 -04:00 committed by GitHub
parent 83619de158
commit a2fd975cd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 115 additions and 65 deletions

View file

@ -1,7 +1,7 @@
const crypto = require('crypto');
const fetch = require('node-fetch');
const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File');
@ -23,6 +23,10 @@ class BaseClient {
this.skipSaveConvo = false;
/** @type {boolean} */
this.skipSaveUserMessage = false;
/** @type {ClientDatabaseSavePromise} */
this.userMessagePromise;
/** @type {ClientDatabaseSavePromise} */
this.responsePromise;
}
setOptions() {
@ -481,7 +485,12 @@ class BaseClient {
}
if (!isEdited && !this.skipSaveUserMessage) {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise: this.userMessagePromise,
});
}
}
if (
@ -530,15 +539,11 @@ class BaseClient {
const completionTokens = this.getTokenCount(completion);
await this.recordTokenUsage({ promptTokens, completionTokens });
}
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return responseMessage;
}
async getConversation(conversationId, user = null) {
return await getConvo(user, conversationId);
}
async loadHistory(conversationId, parentMessageId = null) {
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
@ -593,7 +598,7 @@ class BaseClient {
* @param {string | null} user
*/
async saveMessageToDatabase(message, endpointOptions, user = null) {
await saveMessage({
const savedMessage = await saveMessage({
...message,
endpoint: this.options.endpoint,
unfinished: false,
@ -601,14 +606,16 @@ class BaseClient {
});
if (this.skipSaveConvo) {
return;
return { message: savedMessage };
}
await saveConvo(user, {
const conversation = await saveConvo(user, {
conversationId: message.conversationId,
endpoint: this.options.endpoint,
endpointType: this.options.endpointType,
...endpointOptions,
});
return { message: savedMessage, conversation };
}
async updateMessageInDatabase(message) {

View file

@ -238,7 +238,7 @@ class PluginsClient extends OpenAIClient {
await this.recordTokenUsage(responseMessage);
}
await this.saveMessageToDatabase(responseMessage, saveOptions, user);
this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount;
return { ...responseMessage, ...result };
}
@ -303,7 +303,12 @@ class PluginsClient extends OpenAIClient {
}
if (!this.skipSaveUserMessage) {
await this.saveMessageToDatabase(userMessage, saveOptions, user);
this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise: this.userMessagePromise,
});
}
}
if (isEnabled(process.env.CHECK_BALANCE)) {

View file

@ -30,7 +30,7 @@ module.exports = {
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true,
upsert: true,
});
}).lean();
} catch (error) {
logger.error('[saveConvo] Error saving conversation', error);
return { message: 'Error saving conversation' };

View file

@ -57,18 +57,11 @@ module.exports = {
if (files) {
update.files = files;
}
// may also need to update the conversation here
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
return {
messageId,
conversationId,
parentMessageId,
sender,
text,
isCreatedByUser,
tokenCount,
};
return await Message.findOneAndUpdate({ messageId }, update, {
upsert: true,
new: true,
}).lean();
} catch (err) {
logger.error('Error saving message:', err);
throw new Error('Failed to save message.');

View file

@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AskController = async (req, res, next, initializeClient, addTitle) => {
@ -18,6 +18,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
@ -34,6 +35,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -74,6 +77,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@ -121,7 +125,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
response.endpoint = endpointOption.endpoint;
const conversation = await getConvo(user, conversationId);
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => {
@ -27,6 +27,7 @@ const EditController = async (req, res, next, initializeClient) => {
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
@ -40,6 +41,8 @@ const EditController = async (req, res, next, initializeClient) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -73,6 +76,7 @@ const EditController = async (req, res, next, initializeClient) => {
const getAbortData = () => ({
conversationId,
userMessagePromise,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
@ -120,7 +124,7 @@ const EditController = async (req, res, next, initializeClient) => {
},
});
const conversation = await getConvo(user, conversationId);
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -1,9 +1,9 @@
const { isAssistantsEndpoint } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
const clearPendingReq = require('~/cache/clearPendingReq');
const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const spendTokens = require('~/models/spendTokens');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
@ -90,7 +90,8 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
abortController.abortCompletion = async function () {
abortController.abort();
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
getAbortData();
const completionTokens = await countTokens(responseData?.text ?? '');
const user = req.user.id;
@ -114,10 +115,20 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
saveMessage({ ...responseMessage, user });
let conversation;
if (userMessagePromise) {
const resolved = await userMessagePromise;
conversation = resolved?.conversation;
}
if (!conversation) {
conversation = await getConvo(req.user.id, conversationId);
}
return {
title: await getConvoTitle(user, conversationId),
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
final: true,
conversation: await getConvo(user, conversationId),
conversation,
requestMessage: userMessage,
responseMessage: responseMessage,
};

View file

@ -1,4 +1,4 @@
const { getConvo } = require('../../models');
const { getConvo } = require('~/models');
// Middleware to validate conversationId and user relationship
const validateMessageReq = async (req, res, next) => {

View file

@ -2,9 +2,9 @@ const express = require('express');
const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage } = require('~/models');
const {
handleAbort,
createAbortController,
@ -41,6 +41,7 @@ router.post(
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
@ -58,6 +59,8 @@ router.post(
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -151,6 +154,7 @@ router.post(
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@ -207,10 +211,14 @@ router.post(
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
title: conversation.title,
final: true,
conversation: await getConvo(user, conversationId),
conversation,
requestMessage: userMessage,
responseMessage: response,
});

View file

@ -13,7 +13,7 @@ const {
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { saveMessage } = require('~/models');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
@ -49,6 +49,7 @@ router.post(
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
@ -68,6 +69,8 @@ router.post(
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -119,6 +122,7 @@ router.post(
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@ -179,10 +183,14 @@ router.post(
response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
title: conversation.title,
final: true,
conversation: await getConvo(user, conversationId),
conversation,
requestMessage: userMessage,
responseMessage: response,
});

View file

@ -427,7 +427,7 @@ class StreamRunManager {
const toolCallDelta = toolCall[toolCall.type];
const progressCallback = this.progressCallbacks.get(stepKey);
await progressCallback(toolCallDelta);
progressCallback(toolCallDelta);
}
}

View file

@ -12,33 +12,35 @@ const citationRegex = /\[\^\d+?\^]/g;
const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);
const base = { message: true, initial: true };
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
let i = 0;
let tokens = addSpaceIfNeeded(generation);
const progressCallback = async (partial, { res, text, bing = false, ...rest }) => {
const basePayload = Object.assign({}, base, { text: tokens || '' });
const progressCallback = (partial, { res, text, ...rest }) => {
let chunk = partial === text ? '' : partial;
tokens += chunk;
tokens = tokens.replaceAll('[DONE]', '');
basePayload.text = basePayload.text + chunk;
if (bing) {
tokens = citeText(tokens, true);
const payload = Object.assign({}, basePayload, rest);
sendMessage(res, payload);
if (_onProgress) {
_onProgress(payload);
}
if (i === 0) {
basePayload.initial = false;
}
const payload = { text: tokens, message: true, initial: i === 0, ...rest };
sendMessage(res, { ...payload, text: tokens });
_onProgress && _onProgress(payload);
i++;
};
const sendIntermediateMessage = (res, payload, extraTokens = '') => {
tokens += extraTokens;
sendMessage(res, {
text: tokens?.length === 0 ? '' : tokens,
message: true,
initial: i === 0,
...payload,
});
basePayload.text = basePayload.text + extraTokens;
const message = Object.assign({}, basePayload, payload);
sendMessage(res, message);
if (i === 0) {
basePayload.initial = false;
}
i++;
};
@ -47,7 +49,7 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
};
const getPartialText = () => {
return tokens;
return basePayload.text;
};
return { onProgress, getPartialText, sendIntermediateMessage };

View file

@ -1442,3 +1442,10 @@
* @typedef {import('librechat-data-provider').TForkConvoRequest} TForkConvoRequest
* @memberof typedefs
*/
/** Clients */
/**
* @typedef {Promise<{ message: TMessage, conversation?: TConversation }> | undefined} ClientDatabaseSavePromise
* @memberof typedefs
*/

View file

@ -122,7 +122,7 @@ const MessageRender = React.memo(
</div>
</div>
{!msg?.children?.length && (isSubmittingFamily || isSubmitting) ? (
<PlaceholderRow isLast={isLast} isCard={isCard} />
<PlaceholderRow isCard={isCard} />
) : (
<SubRow classes="text-xs">
<SiblingSwitch

View file

@ -1,12 +1,9 @@
import { memo } from 'react';
const PlaceholderRow = memo(({ isLast, isCard }: { isLast: boolean; isCard?: boolean }) => {
const PlaceholderRow = memo(({ isCard }: { isCard?: boolean }) => {
if (!isCard) {
return null;
}
if (!isLast) {
return null;
}
return <div className="mt-1 h-[27px] bg-transparent" />;
});

View file

@ -1,9 +1,9 @@
import { useState } from 'react';
import { Dialog } from '~/components/ui/';
import DialogTemplate from '~/components/ui/DialogTemplate';
import { ClearChatsButton } from './SettingsTabs/';
import { useClearConversationsMutation } from 'librechat-data-provider/react-query';
import { useLocalize, useConversation, useConversations } from '~/hooks';
import DialogTemplate from '~/components/ui/DialogTemplate';
import { ClearChatsButton } from './SettingsTabs';
import { Dialog } from '~/components/ui';
const ClearConvos = ({ open, onOpenChange }) => {
const { newConversation } = useConversation();

View file

@ -1,5 +1,7 @@
import { useCallback } from 'react';
import { useNavigate } from 'react-router-dom';
import { QueryKeys } from 'librechat-data-provider';
import { useQueryClient } from '@tanstack/react-query';
import { useSetRecoilState, useResetRecoilState, useRecoilCallback } from 'recoil';
import { useGetEndpointsQuery, useGetModelsQuery } from 'librechat-data-provider/react-query';
import type {
@ -15,6 +17,7 @@ import store from '~/store';
const useConversation = () => {
const navigate = useNavigate();
const queryClient = useQueryClient();
const setConversation = useSetRecoilState(store.conversation);
const resetLatestMessage = useResetRecoilState(store.latestMessage);
const setMessages = useSetRecoilState<TMessagesAtom>(store.messages);
@ -59,6 +62,7 @@ const useConversation = () => {
resetLatestMessage();
if (conversation.conversationId === 'new' && !modelsData) {
queryClient.invalidateQueries([QueryKeys.messages, 'new']);
navigate('/c/new');
}
},