🚤 refactor: Optimize Request Lifecycle Speeds (#3222)

* refactor: optimize backend operations for client requests

* fix: message styling

* refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js

* refactor: Improve handleKeyUp logic in StreamRunManager.js and handleText.js

* fix: clear new convo messages on clear all convos

* fix: forgot to pass userId to getConvo

* refactor: update getPartialText to send basePayload.text
This commit is contained in:
Danny Avila 2024-06-28 08:44:47 -04:00 committed by GitHub
parent 83619de158
commit a2fd975cd5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
17 changed files with 115 additions and 65 deletions

View file

@ -1,7 +1,7 @@
const crypto = require('crypto'); const crypto = require('crypto');
const fetch = require('node-fetch'); const fetch = require('node-fetch');
const { supportsBalanceCheck, Constants } = require('librechat-data-provider'); const { supportsBalanceCheck, Constants } = require('librechat-data-provider');
const { getConvo, getMessages, saveMessage, updateMessage, saveConvo } = require('~/models'); const { getMessages, saveMessage, updateMessage, saveConvo } = require('~/models');
const { addSpaceIfNeeded, isEnabled } = require('~/server/utils'); const { addSpaceIfNeeded, isEnabled } = require('~/server/utils');
const checkBalance = require('~/models/checkBalance'); const checkBalance = require('~/models/checkBalance');
const { getFiles } = require('~/models/File'); const { getFiles } = require('~/models/File');
@ -23,6 +23,10 @@ class BaseClient {
this.skipSaveConvo = false; this.skipSaveConvo = false;
/** @type {boolean} */ /** @type {boolean} */
this.skipSaveUserMessage = false; this.skipSaveUserMessage = false;
/** @type {ClientDatabaseSavePromise} */
this.userMessagePromise;
/** @type {ClientDatabaseSavePromise} */
this.responsePromise;
} }
setOptions() { setOptions() {
@ -481,7 +485,12 @@ class BaseClient {
} }
if (!isEdited && !this.skipSaveUserMessage) { if (!isEdited && !this.skipSaveUserMessage) {
await this.saveMessageToDatabase(userMessage, saveOptions, user); this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise: this.userMessagePromise,
});
}
} }
if ( if (
@ -530,15 +539,11 @@ class BaseClient {
const completionTokens = this.getTokenCount(completion); const completionTokens = this.getTokenCount(completion);
await this.recordTokenUsage({ promptTokens, completionTokens }); await this.recordTokenUsage({ promptTokens, completionTokens });
} }
await this.saveMessageToDatabase(responseMessage, saveOptions, user); this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount; delete responseMessage.tokenCount;
return responseMessage; return responseMessage;
} }
async getConversation(conversationId, user = null) {
return await getConvo(user, conversationId);
}
async loadHistory(conversationId, parentMessageId = null) { async loadHistory(conversationId, parentMessageId = null) {
logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId }); logger.debug('[BaseClient] Loading history:', { conversationId, parentMessageId });
@ -593,7 +598,7 @@ class BaseClient {
* @param {string | null} user * @param {string | null} user
*/ */
async saveMessageToDatabase(message, endpointOptions, user = null) { async saveMessageToDatabase(message, endpointOptions, user = null) {
await saveMessage({ const savedMessage = await saveMessage({
...message, ...message,
endpoint: this.options.endpoint, endpoint: this.options.endpoint,
unfinished: false, unfinished: false,
@ -601,14 +606,16 @@ class BaseClient {
}); });
if (this.skipSaveConvo) { if (this.skipSaveConvo) {
return; return { message: savedMessage };
} }
await saveConvo(user, { const conversation = await saveConvo(user, {
conversationId: message.conversationId, conversationId: message.conversationId,
endpoint: this.options.endpoint, endpoint: this.options.endpoint,
endpointType: this.options.endpointType, endpointType: this.options.endpointType,
...endpointOptions, ...endpointOptions,
}); });
return { message: savedMessage, conversation };
} }
async updateMessageInDatabase(message) { async updateMessageInDatabase(message) {

View file

@ -238,7 +238,7 @@ class PluginsClient extends OpenAIClient {
await this.recordTokenUsage(responseMessage); await this.recordTokenUsage(responseMessage);
} }
await this.saveMessageToDatabase(responseMessage, saveOptions, user); this.responsePromise = this.saveMessageToDatabase(responseMessage, saveOptions, user);
delete responseMessage.tokenCount; delete responseMessage.tokenCount;
return { ...responseMessage, ...result }; return { ...responseMessage, ...result };
} }
@ -303,7 +303,12 @@ class PluginsClient extends OpenAIClient {
} }
if (!this.skipSaveUserMessage) { if (!this.skipSaveUserMessage) {
await this.saveMessageToDatabase(userMessage, saveOptions, user); this.userMessagePromise = this.saveMessageToDatabase(userMessage, saveOptions, user);
if (typeof opts?.getReqData === 'function') {
opts.getReqData({
userMessagePromise: this.userMessagePromise,
});
}
} }
if (isEnabled(process.env.CHECK_BALANCE)) { if (isEnabled(process.env.CHECK_BALANCE)) {

View file

@ -30,7 +30,7 @@ module.exports = {
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, { return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true, new: true,
upsert: true, upsert: true,
}); }).lean();
} catch (error) { } catch (error) {
logger.error('[saveConvo] Error saving conversation', error); logger.error('[saveConvo] Error saving conversation', error);
return { message: 'Error saving conversation' }; return { message: 'Error saving conversation' };

View file

@ -57,18 +57,11 @@ module.exports = {
if (files) { if (files) {
update.files = files; update.files = files;
} }
// may also need to update the conversation here
await Message.findOneAndUpdate({ messageId }, update, { upsert: true, new: true });
return { return await Message.findOneAndUpdate({ messageId }, update, {
messageId, upsert: true,
conversationId, new: true,
parentMessageId, }).lean();
sender,
text,
isCreatedByUser,
tokenCount,
};
} catch (err) { } catch (err) {
logger.error('Error saving message:', err); logger.error('Error saving message:', err);
throw new Error('Failed to save message.'); throw new Error('Failed to save message.');

View file

@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider'); const { getResponseSender, Constants, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware'); const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils'); const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models'); const { saveMessage } = require('~/models');
const { logger } = require('~/config'); const { logger } = require('~/config');
const AskController = async (req, res, next, initializeClient, addTitle) => { const AskController = async (req, res, next, initializeClient, addTitle) => {
@ -18,6 +18,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
logger.debug('[AskController]', { text, conversationId, ...endpointOption }); logger.debug('[AskController]', { text, conversationId, ...endpointOption });
let userMessage; let userMessage;
let userMessagePromise;
let promptTokens; let promptTokens;
let userMessageId; let userMessageId;
let responseMessageId; let responseMessageId;
@ -34,6 +35,8 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
if (key === 'userMessage') { if (key === 'userMessage') {
userMessage = data[key]; userMessage = data[key];
userMessageId = data[key].messageId; userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') { } else if (key === 'responseMessageId') {
responseMessageId = data[key]; responseMessageId = data[key];
} else if (key === 'promptTokens') { } else if (key === 'promptTokens') {
@ -74,6 +77,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
const getAbortData = () => ({ const getAbortData = () => ({
sender, sender,
conversationId, conversationId,
userMessagePromise,
messageId: responseMessageId, messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId, parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(), text: getPartialText(),
@ -121,7 +125,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
response.endpoint = endpointOption.endpoint; response.endpoint = endpointOption.endpoint;
const conversation = await getConvo(user, conversationId); const { conversation = {} } = await client.responsePromise;
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -2,7 +2,7 @@ const throttle = require('lodash/throttle');
const { getResponseSender, EModelEndpoint } = require('librechat-data-provider'); const { getResponseSender, EModelEndpoint } = require('librechat-data-provider');
const { createAbortController, handleAbortError } = require('~/server/middleware'); const { createAbortController, handleAbortError } = require('~/server/middleware');
const { sendMessage, createOnProgress } = require('~/server/utils'); const { sendMessage, createOnProgress } = require('~/server/utils');
const { saveMessage, getConvo } = require('~/models'); const { saveMessage } = require('~/models');
const { logger } = require('~/config'); const { logger } = require('~/config');
const EditController = async (req, res, next, initializeClient) => { const EditController = async (req, res, next, initializeClient) => {
@ -27,6 +27,7 @@ const EditController = async (req, res, next, initializeClient) => {
}); });
let userMessage; let userMessage;
let userMessagePromise;
let promptTokens; let promptTokens;
const sender = getResponseSender({ const sender = getResponseSender({
...endpointOption, ...endpointOption,
@ -40,6 +41,8 @@ const EditController = async (req, res, next, initializeClient) => {
for (let key in data) { for (let key in data) {
if (key === 'userMessage') { if (key === 'userMessage') {
userMessage = data[key]; userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') { } else if (key === 'responseMessageId') {
responseMessageId = data[key]; responseMessageId = data[key];
} else if (key === 'promptTokens') { } else if (key === 'promptTokens') {
@ -73,6 +76,7 @@ const EditController = async (req, res, next, initializeClient) => {
const getAbortData = () => ({ const getAbortData = () => ({
conversationId, conversationId,
userMessagePromise,
messageId: responseMessageId, messageId: responseMessageId,
sender, sender,
parentMessageId: overrideParentMessageId ?? userMessageId, parentMessageId: overrideParentMessageId ?? userMessageId,
@ -120,7 +124,7 @@ const EditController = async (req, res, next, initializeClient) => {
}, },
}); });
const conversation = await getConvo(user, conversationId); const { conversation = {} } = await client.responsePromise;
conversation.title = conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat'; conversation && !conversation.title ? null : conversation?.title || 'New Chat';

View file

@ -1,9 +1,9 @@
const { isAssistantsEndpoint } = require('librechat-data-provider'); const { isAssistantsEndpoint } = require('librechat-data-provider');
const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils'); const { sendMessage, sendError, countTokens, isEnabled } = require('~/server/utils');
const { truncateText, smartTruncateText } = require('~/app/clients/prompts'); const { truncateText, smartTruncateText } = require('~/app/clients/prompts');
const { saveMessage, getConvo, getConvoTitle } = require('~/models');
const clearPendingReq = require('~/cache/clearPendingReq'); const clearPendingReq = require('~/cache/clearPendingReq');
const abortControllers = require('./abortControllers'); const abortControllers = require('./abortControllers');
const { saveMessage, getConvo } = require('~/models');
const spendTokens = require('~/models/spendTokens'); const spendTokens = require('~/models/spendTokens');
const { abortRun } = require('./abortRun'); const { abortRun } = require('./abortRun');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -90,7 +90,8 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
abortController.abortCompletion = async function () { abortController.abortCompletion = async function () {
abortController.abort(); abortController.abort();
const { conversationId, userMessage, promptTokens, ...responseData } = getAbortData(); const { conversationId, userMessage, userMessagePromise, promptTokens, ...responseData } =
getAbortData();
const completionTokens = await countTokens(responseData?.text ?? ''); const completionTokens = await countTokens(responseData?.text ?? '');
const user = req.user.id; const user = req.user.id;
@ -114,10 +115,20 @@ const createAbortController = (req, res, getAbortData, getReqData) => {
saveMessage({ ...responseMessage, user }); saveMessage({ ...responseMessage, user });
let conversation;
if (userMessagePromise) {
const resolved = await userMessagePromise;
conversation = resolved?.conversation;
}
if (!conversation) {
conversation = await getConvo(req.user.id, conversationId);
}
return { return {
title: await getConvoTitle(user, conversationId), title: conversation && !conversation.title ? null : conversation?.title || 'New Chat',
final: true, final: true,
conversation: await getConvo(user, conversationId), conversation,
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: responseMessage, responseMessage: responseMessage,
}; };

View file

@ -1,4 +1,4 @@
const { getConvo } = require('../../models'); const { getConvo } = require('~/models');
// Middleware to validate conversationId and user relationship // Middleware to validate conversationId and user relationship
const validateMessageReq = async (req, res, next) => { const validateMessageReq = async (req, res, next) => {

View file

@ -2,9 +2,9 @@ const express = require('express');
const throttle = require('lodash/throttle'); const throttle = require('lodash/throttle');
const { getResponseSender, Constants } = require('librechat-data-provider'); const { getResponseSender, Constants } = require('librechat-data-provider');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils'); const { sendMessage, createOnProgress } = require('~/server/utils');
const { addTitle } = require('~/server/services/Endpoints/openAI'); const { addTitle } = require('~/server/services/Endpoints/openAI');
const { saveMessage } = require('~/models');
const { const {
handleAbort, handleAbort,
createAbortController, createAbortController,
@ -41,6 +41,7 @@ router.post(
logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption }); logger.debug('[/ask/gptPlugins]', { text, conversationId, ...endpointOption });
let userMessage; let userMessage;
let userMessagePromise;
let promptTokens; let promptTokens;
let userMessageId; let userMessageId;
let responseMessageId; let responseMessageId;
@ -58,6 +59,8 @@ router.post(
if (key === 'userMessage') { if (key === 'userMessage') {
userMessage = data[key]; userMessage = data[key];
userMessageId = data[key].messageId; userMessageId = data[key].messageId;
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') { } else if (key === 'responseMessageId') {
responseMessageId = data[key]; responseMessageId = data[key];
} else if (key === 'promptTokens') { } else if (key === 'promptTokens') {
@ -151,6 +154,7 @@ router.post(
const getAbortData = () => ({ const getAbortData = () => ({
sender, sender,
conversationId, conversationId,
userMessagePromise,
messageId: responseMessageId, messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId, parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(), text: getPartialText(),
@ -207,10 +211,14 @@ router.post(
response.plugins = plugins.map((p) => ({ ...p, loading: false })); response.plugins = plugins.map((p) => ({ ...p, loading: false }));
await saveMessage({ ...response, user }); await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, { sendMessage(res, {
title: await getConvoTitle(user, conversationId), title: conversation.title,
final: true, final: true,
conversation: await getConvo(user, conversationId), conversation,
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: response, responseMessage: response,
}); });

View file

@ -13,7 +13,7 @@ const {
} = require('~/server/middleware'); } = require('~/server/middleware');
const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils'); const { sendMessage, createOnProgress, formatSteps, formatAction } = require('~/server/utils');
const { initializeClient } = require('~/server/services/Endpoints/gptPlugins'); const { initializeClient } = require('~/server/services/Endpoints/gptPlugins');
const { saveMessage, getConvoTitle, getConvo } = require('~/models'); const { saveMessage } = require('~/models');
const { validateTools } = require('~/app'); const { validateTools } = require('~/app');
const { logger } = require('~/config'); const { logger } = require('~/config');
@ -49,6 +49,7 @@ router.post(
}); });
let userMessage; let userMessage;
let userMessagePromise;
let promptTokens; let promptTokens;
const sender = getResponseSender({ const sender = getResponseSender({
...endpointOption, ...endpointOption,
@ -68,6 +69,8 @@ router.post(
for (let key in data) { for (let key in data) {
if (key === 'userMessage') { if (key === 'userMessage') {
userMessage = data[key]; userMessage = data[key];
} else if (key === 'userMessagePromise') {
userMessagePromise = data[key];
} else if (key === 'responseMessageId') { } else if (key === 'responseMessageId') {
responseMessageId = data[key]; responseMessageId = data[key];
} else if (key === 'promptTokens') { } else if (key === 'promptTokens') {
@ -119,6 +122,7 @@ router.post(
const getAbortData = () => ({ const getAbortData = () => ({
sender, sender,
conversationId, conversationId,
userMessagePromise,
messageId: responseMessageId, messageId: responseMessageId,
parentMessageId: overrideParentMessageId ?? userMessageId, parentMessageId: overrideParentMessageId ?? userMessageId,
text: getPartialText(), text: getPartialText(),
@ -179,10 +183,14 @@ router.post(
response.plugin = { ...plugin, loading: false }; response.plugin = { ...plugin, loading: false };
await saveMessage({ ...response, user }); await saveMessage({ ...response, user });
const { conversation = {} } = await client.responsePromise;
conversation.title =
conversation && !conversation.title ? null : conversation?.title || 'New Chat';
sendMessage(res, { sendMessage(res, {
title: await getConvoTitle(user, conversationId), title: conversation.title,
final: true, final: true,
conversation: await getConvo(user, conversationId), conversation,
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: response, responseMessage: response,
}); });

View file

@ -427,7 +427,7 @@ class StreamRunManager {
const toolCallDelta = toolCall[toolCall.type]; const toolCallDelta = toolCall[toolCall.type];
const progressCallback = this.progressCallbacks.get(stepKey); const progressCallback = this.progressCallbacks.get(stepKey);
await progressCallback(toolCallDelta); progressCallback(toolCallDelta);
} }
} }

View file

@ -12,33 +12,35 @@ const citationRegex = /\[\^\d+?\^]/g;
const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text); const addSpaceIfNeeded = (text) => (text.length > 0 && !text.endsWith(' ') ? text + ' ' : text);
const base = { message: true, initial: true };
const createOnProgress = ({ generation = '', onProgress: _onProgress }) => { const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
let i = 0; let i = 0;
let tokens = addSpaceIfNeeded(generation); let tokens = addSpaceIfNeeded(generation);
const progressCallback = async (partial, { res, text, bing = false, ...rest }) => { const basePayload = Object.assign({}, base, { text: tokens || '' });
const progressCallback = (partial, { res, text, ...rest }) => {
let chunk = partial === text ? '' : partial; let chunk = partial === text ? '' : partial;
tokens += chunk; basePayload.text = basePayload.text + chunk;
tokens = tokens.replaceAll('[DONE]', '');
if (bing) { const payload = Object.assign({}, basePayload, rest);
tokens = citeText(tokens, true); sendMessage(res, payload);
if (_onProgress) {
_onProgress(payload);
}
if (i === 0) {
basePayload.initial = false;
} }
const payload = { text: tokens, message: true, initial: i === 0, ...rest };
sendMessage(res, { ...payload, text: tokens });
_onProgress && _onProgress(payload);
i++; i++;
}; };
const sendIntermediateMessage = (res, payload, extraTokens = '') => { const sendIntermediateMessage = (res, payload, extraTokens = '') => {
tokens += extraTokens; basePayload.text = basePayload.text + extraTokens;
sendMessage(res, { const message = Object.assign({}, basePayload, payload);
text: tokens?.length === 0 ? '' : tokens, sendMessage(res, message);
message: true, if (i === 0) {
initial: i === 0, basePayload.initial = false;
...payload, }
});
i++; i++;
}; };
@ -47,7 +49,7 @@ const createOnProgress = ({ generation = '', onProgress: _onProgress }) => {
}; };
const getPartialText = () => { const getPartialText = () => {
return tokens; return basePayload.text;
}; };
return { onProgress, getPartialText, sendIntermediateMessage }; return { onProgress, getPartialText, sendIntermediateMessage };

View file

@ -1442,3 +1442,10 @@
* @typedef {import('librechat-data-provider').TForkConvoRequest} TForkConvoRequest * @typedef {import('librechat-data-provider').TForkConvoRequest} TForkConvoRequest
* @memberof typedefs * @memberof typedefs
*/ */
/** Clients */
/**
* @typedef {Promise<{ message: TMessage, conversation: TConversation }> | undefined} ClientDatabaseSavePromise
* @memberof typedefs
*/

View file

@ -122,7 +122,7 @@ const MessageRender = React.memo(
</div> </div>
</div> </div>
{!msg?.children?.length && (isSubmittingFamily || isSubmitting) ? ( {!msg?.children?.length && (isSubmittingFamily || isSubmitting) ? (
<PlaceholderRow isLast={isLast} isCard={isCard} /> <PlaceholderRow isCard={isCard} />
) : ( ) : (
<SubRow classes="text-xs"> <SubRow classes="text-xs">
<SiblingSwitch <SiblingSwitch

View file

@ -1,12 +1,9 @@
import { memo } from 'react'; import { memo } from 'react';
const PlaceholderRow = memo(({ isLast, isCard }: { isLast: boolean; isCard?: boolean }) => { const PlaceholderRow = memo(({ isCard }: { isCard?: boolean }) => {
if (!isCard) { if (!isCard) {
return null; return null;
} }
if (!isLast) {
return null;
}
return <div className="mt-1 h-[27px] bg-transparent" />; return <div className="mt-1 h-[27px] bg-transparent" />;
}); });

View file

@ -1,9 +1,9 @@
import { useState } from 'react'; import { useState } from 'react';
import { Dialog } from '~/components/ui/';
import DialogTemplate from '~/components/ui/DialogTemplate';
import { ClearChatsButton } from './SettingsTabs/';
import { useClearConversationsMutation } from 'librechat-data-provider/react-query'; import { useClearConversationsMutation } from 'librechat-data-provider/react-query';
import { useLocalize, useConversation, useConversations } from '~/hooks'; import { useLocalize, useConversation, useConversations } from '~/hooks';
import DialogTemplate from '~/components/ui/DialogTemplate';
import { ClearChatsButton } from './SettingsTabs';
import { Dialog } from '~/components/ui';
const ClearConvos = ({ open, onOpenChange }) => { const ClearConvos = ({ open, onOpenChange }) => {
const { newConversation } = useConversation(); const { newConversation } = useConversation();

View file

@ -1,5 +1,7 @@
import { useCallback } from 'react'; import { useCallback } from 'react';
import { useNavigate } from 'react-router-dom'; import { useNavigate } from 'react-router-dom';
import { QueryKeys } from 'librechat-data-provider';
import { useQueryClient } from '@tanstack/react-query';
import { useSetRecoilState, useResetRecoilState, useRecoilCallback } from 'recoil'; import { useSetRecoilState, useResetRecoilState, useRecoilCallback } from 'recoil';
import { useGetEndpointsQuery, useGetModelsQuery } from 'librechat-data-provider/react-query'; import { useGetEndpointsQuery, useGetModelsQuery } from 'librechat-data-provider/react-query';
import type { import type {
@ -15,6 +17,7 @@ import store from '~/store';
const useConversation = () => { const useConversation = () => {
const navigate = useNavigate(); const navigate = useNavigate();
const queryClient = useQueryClient();
const setConversation = useSetRecoilState(store.conversation); const setConversation = useSetRecoilState(store.conversation);
const resetLatestMessage = useResetRecoilState(store.latestMessage); const resetLatestMessage = useResetRecoilState(store.latestMessage);
const setMessages = useSetRecoilState<TMessagesAtom>(store.messages); const setMessages = useSetRecoilState<TMessagesAtom>(store.messages);
@ -59,6 +62,7 @@ const useConversation = () => {
resetLatestMessage(); resetLatestMessage();
if (conversation.conversationId === 'new' && !modelsData) { if (conversation.conversationId === 'new' && !modelsData) {
queryClient.invalidateQueries([QueryKeys.messages, 'new']);
navigate('/c/new'); navigate('/c/new');
} }
}, },