🚤 refactor: Optimize Request Lifecycle Speeds (#3222)

* refactor: optimize backend operations for client requests

* fix: message styling

* refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js

* refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js

* fix: clear new convo messages on clear all convos

* fix: forgot to pass userId to getConvo

* refactor: update getPartialText to send basePayload.text
This commit is contained in:
Danny Avila 2024-06-28 08:44:47 -04:00 committed by GitHub
parent 83619de158
commit a2fd975cd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 115 additions and 65 deletions

View file

@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const AskController = async (req, res, next, initializeClient, addTitle) => {
@ -18,6 +18,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
logger.debug('[AskController]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
@ -34,6 +35,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -74,6 +77,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@ -121,7 +125,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
response.endpoint = endpointOption.endpoint;
const conversation = await getConvo(user, conversationId);
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models');
const { saveMessage } = require('~/models');
const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => {
@ -27,6 +27,7 @@ const EditController = async (req, res, next, initializeClient) => {
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
@ -40,6 +41,8 @@ const EditController = async (req, res, next, initializeClient) => {
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -73,6 +76,7 @@ const EditController = async (req, res, next, initializeClient) => {
const getAbortData = () => ({
conversationId,
userMessagePromise,
messageId: responseMessageId,
sender,
parentMessageId: overrideParentMessageId ?? userMessageId,
@ -120,7 +124,7 @@ const EditController = async (req, res, next, initializeClient) => {
},
});
const conversation = await getConvo(user, conversationId);
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -1,9 +1,9 @@
const { isAssistantsEndpoint } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
const clearPendingReq = require('~/cache/clearPendingReq');
const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const spendTokens = require('~/models/spendTokens');
const { abortRun } = require('./abortRun');
const { logger } = require('~/config');
@ -90,7 +90,8 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
abortController.abortCompletion = async function () {
abortController.abort();
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData();
const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
getAbortData();
const completionTokens = await countTokens(responseData?.text ?? '');
const user = req.user.id;
@ -114,10 +115,20 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
saveMessage({ ...responseMessage, user });
let conversation;
if (userMessagePromise) {
const resolved = await userMessagePromise;
conversation = resolved?.conversation;
}
if (!conversation) {
conversation = await getConvo(req.user.id, conversationId);
}
return {
title: await getConvoTitle(user, conversationId),
title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
final: true,
conversation: await getConvo(user, conversationId),
conversation,
requestMessage: userMessage,
responseMessage: responseMessage,
};

View file

@ -1,4 +1,4 @@
const { getConvo } = require('../../models');
const { getConvo } = require('~/models');
// Middleware to validate conversationId and user relationship
const validateMessageReq = async (req, res, next) => {

View file

@ -2,9 +2,9 @@ const express = require('express');
const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage } = require('~/models');
const {
handleAbort,
createAbortController,
@ -41,6 +41,7 @@ router.post(
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
let userMessage;
let userMessagePromise;
let promptTokens;
let userMessageId;
let responseMessageId;
@ -58,6 +59,8 @@ router.post(
if (key === 'userMessage') {
userMessage = data[key];
userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -151,6 +154,7 @@ router.post(
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@ -207,10 +211,14 @@ router.post(
response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
title: conversation.title,
final: true,
conversation: await getConvo(user, conversationId),
conversation,
requestMessage: userMessage,
responseMessage: response,
});

View file

@ -13,7 +13,7 @@ const {
} = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { saveMessage } = require('~/models');
const { validateTools } = require('~/app');
const { logger } = require('~/config');
@ -49,6 +49,7 @@ router.post(
});
let userMessage;
let userMessagePromise;
let promptTokens;
const sender = getResponseSender({
...endpointOption,
@ -68,6 +69,8 @@ router.post(
for (let key in data) {
if (key === 'userMessage') {
userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') {
responseMessageId = data[key];
} else if (key === 'promptTokens') {
@ -119,6 +122,7 @@ router.post(
const getAbortData = () => ({
sender,
conversationId,
userMessagePromise,
messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(),
@ -179,10 +183,14 @@ router.post(
response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
title: conversation.title,
final: true,
conversation: await getConvo(user, conversationId),
conversation,
requestMessage: userMessage,
responseMessage: response,
});

View file

@ -427,7 +427,7 @@ class StreamRunManager {
const toolCallDelta = toolCall[toolCall.type];
const progressCallback = this.progressCallbacks.get(stepKey);
await progressCallback(toolCallDelta);
progressCallback(toolCallDelta);
}
}

View file

@ -12,33 +12,35 @@ const citationRegex = /\[\^\d+?\^]/g;
const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);
const base = { message: true, initial: true };
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
let i = 0;
let tokens = addSpaceIfNeeded(generation);
const progressCallback = async (partial, { res, text, bing = false, ...rest }) => {
const basePayload = Object.assign({}, base, { text: tokens || '' });
const progressCallback = (partial, { res, text, ...rest }) => {
let chunk = partial === text ? '' : partial;
tokens += chunk;
tokens = tokens.replaceAll('[DONE]', '');
basePayload.text = basePayload.text + chunk;
if (bing) {
tokens = citeText(tokens, true);
const payload = Object.assign({}, basePayload, rest);
sendMessage(res, payload);
if (_onProgress) {
_onProgress(payload);
}
if (i === 0) {
basePayload.initial = false;
}
const payload = { text: tokens, message: true, initial: i === 0, ...rest };
sendMessage(res, { ...payload, text: tokens });
_onProgress && _onProgress(payload);
i++;
};
const sendIntermediateMessage = (res, payload, extraTokens = '') => {
tokens += extraTokens;
sendMessage(res, {
text: tokens?.length === 0 ? '' : tokens,
message: true,
initial: i === 0,
...payload,
});
basePayload.text = basePayload.text + extraTokens;
const message = Object.assign({}, basePayload, payload);
sendMessage(res, message);
if (i === 0) {
basePayload.initial = false;
}
i++;
};
@ -47,7 +49,7 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
};
const getPartialText = () => {
return tokens;
return basePayload.text;
};
return { onProgress, getPartialText, sendIntermediateMessage };