diff --git a/api/app/google/GoogleClient.js b/api/app/google/GoogleClient.js index 819282b4d6..b867c2e81c 100644 --- a/api/app/google/GoogleClient.js +++ b/api/app/google/GoogleClient.js @@ -266,7 +266,7 @@ class GoogleAgent { const user = opts.user || null; const conversationId = opts.conversationId || crypto.randomUUID(); const parentMessageId = opts.parentMessageId || '00000000-0000-0000-0000-000000000000'; - const userMessageId = crypto.randomUUID(); + const userMessageId = opts.overrideParentMessageId || crypto.randomUUID(); const responseMessageId = crypto.randomUUID(); const messages = await this.loadHistory(conversationId, this.options?.parentMessageId); diff --git a/api/server/routes/ask/askGoogle.js b/api/server/routes/ask/askGoogle.js index f5c985d306..f49263392c 100644 --- a/api/server/routes/ask/askGoogle.js +++ b/api/server/routes/ask/askGoogle.js @@ -54,6 +54,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI let userMessageId; let responseMessageId; let lastSavedTimestamp = 0; + const { overrideParentMessageId = null } = req.body; try { const getIds = (data) => { @@ -74,7 +75,7 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI messageId: responseMessageId, sender: 'PaLM2', conversationId, - parentMessageId: userMessageId, + parentMessageId: overrideParentMessageId || userMessageId, text: partialText, unfinished: true, cancelled: false, @@ -113,12 +114,21 @@ const ask = async ({ text, endpointOption, parentMessageId = null, conversationI let response = await client.sendMessage(text, { getIds, user: req.user.id, - parentMessageId, conversationId, - onProgress: progressCallback.call(null, { res, text, parentMessageId: userMessageId }), + parentMessageId, + overrideParentMessageId, + onProgress: progressCallback.call(null, { + res, + text, + parentMessageId: overrideParentMessageId || userMessageId + }), abortController }); + if (overrideParentMessageId) { + response.parentMessageId = overrideParentMessageId; + } + await saveConvo(req.user.id, { ...endpointOption, ...endpointOption.modelOptions, diff --git a/client/src/components/MessageHandler/index.jsx b/client/src/components/MessageHandler/index.jsx index dda9cf8a3d..c4c4b725c8 100644 --- a/client/src/components/MessageHandler/index.jsx +++ b/client/src/components/MessageHandler/index.jsx @@ -1,8 +1,7 @@ -import { useEffect, useState } from 'react'; -import { useRecoilValue, useRecoilState, useResetRecoilState, useSetRecoilState } from 'recoil'; +import { useEffect } from 'react'; +import { useRecoilValue, useResetRecoilState, useSetRecoilState } from 'recoil'; import { SSE } from '~/data-provider/sse.mjs'; import createPayload from '~/data-provider/createPayload'; -import { useAbortRequestWithMessage } from '~/data-provider'; import store from '~/store'; import { useAuthContext } from '~/hooks/AuthContext'; @@ -117,8 +116,11 @@ export default function MessageHandler() { const { requestMessage, responseMessage, conversation } = data; // update the messages - if (isRegenerate) setMessages([...messages, responseMessage]); - else setMessages([...messages, requestMessage, responseMessage]); + if (isRegenerate) { + setMessages([...messages, responseMessage]); + } else { + setMessages([...messages, requestMessage, responseMessage]); + } setIsSubmitting(false); // refresh title