From a953fc9f2b3022fbf4a20150fd3ffdfef61e14f2 Mon Sep 17 00:00:00 2001 From: Daniel Avila Date: Sun, 9 Apr 2023 22:21:27 -0400 Subject: [PATCH] wip: feat: abort messages and continue conversation fix(addToCache.js): remove unused variables and parameters feat(addToCache.js): add message to cache with id, parentMessageId, role, and text fix(askOpenAI.js): remove parentMessageId parameter from addToCache call feat(MessageHandler.jsx): add latestMessage to store on cancel of submission, and generate messageId and parentMessageId for latestMessage --- api/server/routes/ask/addToCache.js | 19 ++++++--------- api/server/routes/ask/askOpenAI.js | 4 ++-- .../src/components/MessageHandler/index.jsx | 24 ++++++++++++------- 3 files changed, 25 insertions(+), 22 deletions(-) diff --git a/api/server/routes/ask/addToCache.js b/api/server/routes/ask/addToCache.js index 869f47dc7f..d53e98cbe3 100644 --- a/api/server/routes/ask/addToCache.js +++ b/api/server/routes/ask/addToCache.js @@ -1,21 +1,16 @@ const Keyv = require('keyv'); const { KeyvFile } = require('keyv-file'); -const crypto = require('crypto'); const { saveMessage } = require('../../../models'); -const addToCache = async ({ - endpointOption, - conversationId, - userMessage, - latestMessage, - parentMessageId -}) => { +const addToCache = async ({ endpointOption, userMessage, latestMessage }) => { try { const conversationsCache = new Keyv({ store: new KeyvFile({ filename: './data/cache.json' }), - namespace: 'chatgpt', // should be 'bing' for bing/sydney + namespace: 'chatgpt' // should be 'bing' for bing/sydney }); + const { conversationId, messageId, parentMessageId, text } = latestMessage; + let conversation = await conversationsCache.get(conversationId); // used to generate a title for the conversation if none exists // let isNewConversation = false; @@ -38,13 +33,13 @@ const addToCache = async ({ } }; - const messageId = crypto.randomUUID(); + // const messageId = crypto.randomUUID(); let responseMessage = { id: messageId, parentMessageId, role: roles(endpointOption), - message: latestMessage + message: text }; await saveMessage({ @@ -52,7 +47,7 @@ const addToCache = async ({ conversationId, messageId, sender: responseMessage.role, - text: latestMessage + text }); conversation.messages.push(userMessage, responseMessage); diff --git a/api/server/routes/ask/askOpenAI.js b/api/server/routes/ask/askOpenAI.js index 56f4aee897..b21c7aba1f 100644 --- a/api/server/routes/ask/askOpenAI.js +++ b/api/server/routes/ask/askOpenAI.js @@ -10,7 +10,7 @@ const { handleError, sendMessage, createOnProgress, handleText } = require('./ha const abortControllers = new Map(); router.post('/abort', async (req, res) => { - const { abortKey, latestMessage, parentMessageId } = req.body; + const { abortKey, latestMessage } = req.body; console.log(`req.body`, req.body); if (!abortControllers.has(abortKey)) { return res.status(404).send('Request not found'); @@ -24,7 +24,7 @@ router.post('/abort', async (req, res) => { abortController.abort(); abortControllers.delete(abortKey); console.log('Aborted request', abortKey, userMessage, endpointOption); - await addToCache({ endpointOption, conversationId: abortKey, userMessage, latestMessage, parentMessageId }); + await addToCache({ endpointOption, userMessage, latestMessage }); res.status(200).send('Aborted'); }); diff --git a/client/src/components/MessageHandler/index.jsx b/client/src/components/MessageHandler/index.jsx index dd792e7ed0..b7736a4374 100644 --- a/client/src/components/MessageHandler/index.jsx +++ b/client/src/components/MessageHandler/index.jsx @@ -3,7 +3,7 @@ import { useRecoilValue, useRecoilState, useResetRecoilState, useSetRecoilState import { SSE } from '~/data-provider/sse.mjs'; import createPayload from '~/data-provider/createPayload'; import { useAbortRequestWithMessage } from '~/data-provider'; - +import { v4 } from 'uuid'; import store from '~/store'; export default function MessageHandler() { @@ -13,6 +13,7 @@ export default function MessageHandler() { const setConversation = useSetRecoilState(store.conversation); const resetLatestMessage = useResetRecoilState(store.latestMessage); const [lastResponse, setLastResponse] = useRecoilState(store.lastResponse); + const setLatestMessage = useSetRecoilState(store.latestMessage); const setSubmission = useSetRecoilState(store.submission); const [source, setSource] = useState(null); // const [abortKey, setAbortKey] = useState(null); @@ -50,6 +51,7 @@ export default function MessageHandler() { const cancelHandler = (data, submission) => { const { messages, message, initialResponse, isRegenerate = false } = submission; + const { text, messageId, parentMessageId } = data; if (isRegenerate) { setMessages([ @@ -63,14 +65,15 @@ export default function MessageHandler() { } ]); } else { + console.log('cancelHandler, isRegenerate = false'); setMessages([ ...messages, message, { ...initialResponse, - text: data, + text, parentMessageId: message?.messageId, - messageId: message?.messageId + '_' + messageId, // cancelled: true } ]); @@ -78,6 +81,7 @@ export default function MessageHandler() { setSource(null); setIsSubmitting(false); setSubmission(null); + setLatestMessage(data); } }; @@ -168,7 +172,14 @@ export default function MessageHandler() { const { endpoint } = submission.conversation; // splitting twice because the cursor may or may not be wrapped in a span - const latestMessage = lastResponse.split('█')[0].split('')[0]; + const latestMessageText = lastResponse.split('█')[0].split('')[0]; + const latestMessage = { + text: latestMessageText, + messageId: v4(), + parentMessageId: currentParent.messageId, + conversationId: currentParent.conversationId + }; + fetch(`/api/ask/${endpoint}/abort`, { method: 'POST', headers: { @@ -176,8 +187,7 @@ export default function MessageHandler() { }, body: JSON.stringify({ abortKey: currentParent.conversationId, - latestMessage, - parentMessageId: currentParent.messageId, + latestMessage }) }) .then(response => { @@ -195,8 +205,6 @@ export default function MessageHandler() { return; } - // events.oncancel = () => cancelHandler(latestResponseText, { ...submission, message }); - const { server, payload } = createPayload(submission); const events = new SSE(server, {