mirror of
https://github.com/danny-avila/LibreChat.git
synced 2026-01-29 22:05:18 +01:00
🌿 feat: Multi-response Streaming (#3191)
* chore: comment back handlePlusCommand * chore: ignore .git dir * refactor: pass newConversation to `useSelectMention` refactor: pass newConversation to Mention component refactor: useChatFunctions for modular use of `ask` and `regenerate` refactor: set latest message only for the first index in useChatFunctions refactor: pass setLatestMessage to useChatFunctions refactor: Pass setSubmission to useChatFunctions for submission handling refactor: consolidate event handlers to separate hook from useSSE WIP: additional response handlers feat: responsive added convo, clears on new chat/navigating to chat, assistants excluded feat: Add conversationByKeySelector to select any conversation by index WIP: handle second submission with messages paired to root * style: surface-primary-contrast * refactor: remove unnecessary console.log statement in useChatFunctions * refactor: Consolidate imports in ChatForm and Input hooks * refactor: compositional usage of useSSE for multiple streams * WIP: set latest 'multi' message * WIP: first pass, added response streaming * pass: performant multi-message stream * fix: styling and message render * second pass: modular, performant multi-stream * fix: align parentMessageId of multiMessage * refactor: move resetting latestMultiMessage * chore: update footer text in Chat component * fix: stop button styling * fix: handle abortMessage request for multi-response * clear messages but bug with latest message reset present * fix: add delay for additional message generation * fix: access LAST_CONVO_SETUP by index * style: add div to prevent layout shift before hover buttons render * chore: Update Message component styling for card messages * chore: move hook use order * fix: abort middleware using unsent field from req.body * feat: support multi-response stream from initial message * refactor: buildTree function to improve readability and remove unused code * feat: add logger for frontend dev * refactor: use depth to track if message is really last in its branch * fix(buildTree): default export * fix: share parent message Id and avoid duplication error for multi-response streams * fix: prevent addedConvo reset to response convo * feat: allow setting multi message as latest message to control which to respond to * chore: wrap setSiblingIdxRev with useCallback * chore: styling and allow editing messages * style: styling fixes * feat: Add "AddMultiConvo" component to Chat Header * feat: prevent clearing added convos on endpoint, preset, mention, or modelSpec switch * fix: message styling fixes, mainly related to code blocks * fix: stop button visibility logic * fix: Handle edge case in abortMiddleware for non-existant `abortControllers` * refactor: optimize/memoize icons * chore(GoogleClient): change info to debug logs * style: active message styling * style: prevent layout shift due to placeholder row * chore: remove unused code * fix: Update BaseClient to handle optional request body properties * fix(ci): `onStart` now accepts 2 args, the 2nd being responseMessageId * chore: bump data-provider
This commit is contained in:
parent
eef894e608
commit
156c52e293
72 changed files with 2697 additions and 1326 deletions
|
|
@ -19,6 +19,10 @@ class BaseClient {
|
|||
day: 'numeric',
|
||||
});
|
||||
this.fetch = this.fetch.bind(this);
|
||||
/** @type {boolean} */
|
||||
this.skipSaveConvo = false;
|
||||
/** @type {boolean} */
|
||||
this.skipSaveUserMessage = false;
|
||||
}
|
||||
|
||||
setOptions() {
|
||||
|
|
@ -84,19 +88,45 @@ class BaseClient {
|
|||
await stream.processTextStream(onProgress);
|
||||
}
|
||||
|
||||
/**
|
||||
* @returns {[string|undefined, string|undefined]}
|
||||
*/
|
||||
processOverideIds() {
|
||||
/** @type {Record<string, string | undefined>} */
|
||||
let { overrideConvoId, overrideUserMessageId } = this.options?.req?.body ?? {};
|
||||
if (overrideConvoId) {
|
||||
const [conversationId, index] = overrideConvoId.split(Constants.COMMON_DIVIDER);
|
||||
overrideConvoId = conversationId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveConvo = true;
|
||||
}
|
||||
}
|
||||
if (overrideUserMessageId) {
|
||||
const [userMessageId, index] = overrideUserMessageId.split(Constants.COMMON_DIVIDER);
|
||||
overrideUserMessageId = userMessageId;
|
||||
if (index !== '0') {
|
||||
this.skipSaveUserMessage = true;
|
||||
}
|
||||
}
|
||||
|
||||
return [overrideConvoId, overrideUserMessageId];
|
||||
}
|
||||
|
||||
async setMessageOptions(opts = {}) {
|
||||
if (opts && opts.replaceOptions) {
|
||||
this.setOptions(opts);
|
||||
}
|
||||
|
||||
const [overrideConvoId, overrideUserMessageId] = this.processOverideIds();
|
||||
const { isEdited, isContinued } = opts;
|
||||
const user = opts.user ?? null;
|
||||
this.user = user;
|
||||
const saveOptions = this.getSaveOptions();
|
||||
this.abortController = opts.abortController ?? new AbortController();
|
||||
const conversationId = opts.conversationId ?? crypto.randomUUID();
|
||||
const conversationId = overrideConvoId ?? opts.conversationId ?? crypto.randomUUID();
|
||||
const parentMessageId = opts.parentMessageId ?? Constants.NO_PARENT;
|
||||
const userMessageId = opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
const userMessageId =
|
||||
overrideUserMessageId ?? opts.overrideParentMessageId ?? crypto.randomUUID();
|
||||
let responseMessageId = opts.responseMessageId ?? crypto.randomUUID();
|
||||
let head = isEdited ? responseMessageId : parentMessageId;
|
||||
this.currentMessages = (await this.loadHistory(conversationId, head)) ?? [];
|
||||
|
|
@ -160,7 +190,7 @@ class BaseClient {
|
|||
}
|
||||
|
||||
if (typeof opts?.onStart === 'function') {
|
||||
opts.onStart(userMessage);
|
||||
opts.onStart(userMessage, responseMessageId);
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -450,7 +480,7 @@ class BaseClient {
|
|||
this.handleTokenCountMap(tokenCountMap);
|
||||
}
|
||||
|
||||
if (!isEdited) {
|
||||
if (!isEdited && !this.skipSaveUserMessage) {
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
|
|
@ -569,6 +599,10 @@ class BaseClient {
|
|||
unfinished: false,
|
||||
user,
|
||||
});
|
||||
|
||||
if (this.skipSaveConvo) {
|
||||
return;
|
||||
}
|
||||
await saveConvo(user, {
|
||||
conversationId: message.conversationId,
|
||||
endpoint: this.options.endpoint,
|
||||
|
|
|
|||
|
|
@ -737,7 +737,7 @@ class GoogleClient extends BaseClient {
|
|||
|
||||
let clientOptions = { ...parameters, maxRetries: 2 };
|
||||
|
||||
logger.info('Initialized title client options');
|
||||
logger.debug('Initialized title client options');
|
||||
|
||||
if (this.project_id) {
|
||||
clientOptions['authOptions'] = {
|
||||
|
|
@ -764,7 +764,7 @@ class GoogleClient extends BaseClient {
|
|||
|
||||
const modelName = clientOptions.modelName ?? clientOptions.model ?? '';
|
||||
if (modelName?.includes('1.5') && !this.project_id) {
|
||||
logger.info('Identified titling model as 1.5 version');
|
||||
logger.debug('Identified titling model as 1.5 version');
|
||||
/** @type {GenerativeModel} */
|
||||
const client = model;
|
||||
const requestOptions = {
|
||||
|
|
@ -790,7 +790,7 @@ class GoogleClient extends BaseClient {
|
|||
|
||||
return reply;
|
||||
} else {
|
||||
logger.info('Beginning titling');
|
||||
logger.debug('Beginning titling');
|
||||
const safetySettings = _payload.safetySettings;
|
||||
|
||||
const titleResponse = await model.invoke(messages, {
|
||||
|
|
@ -840,7 +840,7 @@ class GoogleClient extends BaseClient {
|
|||
} catch (e) {
|
||||
logger.error('[GoogleClient] There was an issue generating the title', e);
|
||||
}
|
||||
logger.info(`Title response: ${title}`);
|
||||
logger.debug(`Title response: ${title}`);
|
||||
return title;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -301,7 +301,10 @@ class PluginsClient extends OpenAIClient {
|
|||
if (payload) {
|
||||
this.currentMessages = payload;
|
||||
}
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
|
||||
if (!this.skipSaveUserMessage) {
|
||||
await this.saveMessageToDatabase(userMessage, saveOptions, user);
|
||||
}
|
||||
|
||||
if (isEnabled(process.env.CHECK_BALANCE)) {
|
||||
await checkBalance({
|
||||
|
|
|
|||
|
|
@ -576,7 +576,11 @@ describe('BaseClient', () => {
|
|||
const onStart = jest.fn();
|
||||
const opts = { onStart };
|
||||
await TestClient.sendMessage('Hello, world!', opts);
|
||||
expect(onStart).toHaveBeenCalledWith(expect.objectContaining({ text: 'Hello, world!' }));
|
||||
|
||||
expect(onStart).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ text: 'Hello, world!' }),
|
||||
expect.any(String),
|
||||
);
|
||||
});
|
||||
|
||||
test('saveMessageToDatabase is called with the correct arguments', async () => {
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[AskController] Request closed');
|
||||
|
|
@ -144,7 +144,9 @@ const AskController = async (req, res, next, initializeClient, addTitle) => {
|
|||
await saveMessage({ ...response, user });
|
||||
}
|
||||
|
||||
await saveMessage(userMessage);
|
||||
if (!client.skipSaveUserMessage) {
|
||||
await saveMessage(userMessage);
|
||||
}
|
||||
|
||||
if (addTitle && parentMessageId === Constants.NO_PARENT && newConvo) {
|
||||
addTitle(req, {
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ const EditController = async (req, res, next, initializeClient) => {
|
|||
promptTokens,
|
||||
});
|
||||
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
res.on('close', () => {
|
||||
logger.debug('[EditController] Request closed');
|
||||
|
|
|
|||
|
|
@ -9,23 +9,28 @@ const { abortRun } = require('./abortRun');
|
|||
const { logger } = require('~/config');
|
||||
|
||||
async function abortMessage(req, res) {
|
||||
let { abortKey, conversationId, endpoint } = req.body;
|
||||
|
||||
if (!abortKey && conversationId) {
|
||||
abortKey = conversationId;
|
||||
}
|
||||
let { abortKey, endpoint } = req.body;
|
||||
|
||||
if (isAssistantsEndpoint(endpoint)) {
|
||||
return await abortRun(req, res);
|
||||
}
|
||||
|
||||
const conversationId = abortKey?.split(':')?.[0] ?? req.user.id;
|
||||
|
||||
if (!abortControllers.has(abortKey) && abortControllers.has(conversationId)) {
|
||||
abortKey = conversationId;
|
||||
}
|
||||
|
||||
if (!abortControllers.has(abortKey) && !res.headersSent) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
|
||||
const { abortController } = abortControllers.get(abortKey);
|
||||
const { abortController } = abortControllers.get(abortKey) ?? {};
|
||||
if (!abortController) {
|
||||
return res.status(204).send({ message: 'Request not found' });
|
||||
}
|
||||
const finalEvent = await abortController.abortCompletion();
|
||||
logger.debug('[abortMessage] Aborted request', { abortKey });
|
||||
logger.info('[abortMessage] Aborted request', { abortKey });
|
||||
abortControllers.delete(abortKey);
|
||||
|
||||
if (res.headersSent && finalEvent) {
|
||||
|
|
@ -50,12 +55,32 @@ const handleAbort = () => {
|
|||
};
|
||||
};
|
||||
|
||||
const createAbortController = (req, res, getAbortData) => {
|
||||
const createAbortController = (req, res, getAbortData, getReqData) => {
|
||||
const abortController = new AbortController();
|
||||
const { endpointOption } = req.body;
|
||||
const onStart = (userMessage) => {
|
||||
|
||||
abortController.getAbortData = function () {
|
||||
return getAbortData();
|
||||
};
|
||||
|
||||
/**
|
||||
* @param {TMessage} userMessage
|
||||
* @param {string} responseMessageId
|
||||
*/
|
||||
const onStart = (userMessage, responseMessageId) => {
|
||||
sendMessage(res, { message: userMessage, created: true });
|
||||
const abortKey = userMessage?.conversationId ?? req.user.id;
|
||||
const prevRequest = abortControllers.get(abortKey);
|
||||
if (prevRequest && prevRequest?.abortController) {
|
||||
const data = prevRequest.abortController.getAbortData();
|
||||
getReqData({ userMessage: data?.userMessage });
|
||||
const addedAbortKey = `${abortKey}:${responseMessageId}`;
|
||||
abortControllers.set(addedAbortKey, { abortController, ...endpointOption });
|
||||
res.on('finish', function () {
|
||||
abortControllers.delete(addedAbortKey);
|
||||
});
|
||||
return;
|
||||
}
|
||||
abortControllers.set(abortKey, { abortController, ...endpointOption });
|
||||
|
||||
res.on('finish', function () {
|
||||
|
|
|
|||
|
|
@ -148,15 +148,6 @@ router.post(
|
|||
}
|
||||
};
|
||||
|
||||
const onChainEnd = () => {
|
||||
saveMessage({ ...userMessage, user });
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
const getAbortData = () => ({
|
||||
sender,
|
||||
conversationId,
|
||||
|
|
@ -167,12 +158,23 @@ router.post(
|
|||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
const onChainEnd = () => {
|
||||
if (!client.skipSaveUserMessage) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugins,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
conversationId,
|
||||
|
|
|
|||
|
|
@ -103,21 +103,6 @@ router.post(
|
|||
},
|
||||
});
|
||||
|
||||
const onAgentAction = (action, start = false) => {
|
||||
const formattedAction = formatAction(action);
|
||||
plugin.inputs.push(formattedAction);
|
||||
plugin.latest = formattedAction.plugin;
|
||||
if (!start) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||
};
|
||||
|
||||
const onChainEnd = (data) => {
|
||||
let { intermediateSteps: steps } = data;
|
||||
plugin.outputs = steps && steps[0].action ? formatSteps(steps) : 'An error occurred.';
|
||||
|
|
@ -141,12 +126,27 @@ router.post(
|
|||
userMessage,
|
||||
promptTokens,
|
||||
});
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData);
|
||||
const { abortController, onStart } = createAbortController(req, res, getAbortData, getReqData);
|
||||
|
||||
try {
|
||||
endpointOption.tools = await validateTools(user, endpointOption.tools);
|
||||
const { client } = await initializeClient({ req, res, endpointOption });
|
||||
|
||||
const onAgentAction = (action, start = false) => {
|
||||
const formattedAction = formatAction(action);
|
||||
plugin.inputs.push(formattedAction);
|
||||
plugin.latest = formattedAction.plugin;
|
||||
if (!start && !client.skipSaveUserMessage) {
|
||||
saveMessage({ ...userMessage, user });
|
||||
}
|
||||
sendIntermediateMessage(res, {
|
||||
plugin,
|
||||
parentMessageId: userMessage.messageId,
|
||||
messageId: responseMessageId,
|
||||
});
|
||||
// logger.debug('PLUGIN ACTION', formattedAction);
|
||||
};
|
||||
|
||||
let response = await client.sendMessage(text, {
|
||||
user,
|
||||
generation,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue