diff --git a/api/models/Conversation.js b/api/models/Conversation.js index 6e7e1ceaab..580a07f6e1 100644 --- a/api/models/Conversation.js +++ b/api/models/Conversation.js @@ -13,45 +13,21 @@ const getConvo = async (user, conversationId) => { module.exports = { Conversation, - saveConvo: async (user, { conversationId, newConversationId, title, ...convo }) => { + saveConvo: async (user, { conversationId, newConversationId, ...convo }) => { try { const messages = await getMessages({ conversationId }); - const update = { ...convo, messages }; - if (title) { - update.title = title; - update.user = user; - } + const update = { ...convo, messages, user }; if (newConversationId) { update.conversationId = newConversationId; } - if (!update.jailbreakConversationId) { - update.jailbreakConversationId = null; - } - return await Conversation.findOneAndUpdate( - { conversationId: conversationId, user }, - { $set: update }, - { new: true, upsert: true } - ).exec(); - } catch (error) { - console.log(error); - return { message: 'Error saving conversation' }; - } - }, - updateConvo: async (user, { conversationId, oldConvoId, ...update }) => { - try { - let convoId = conversationId; - if (oldConvoId) { - convoId = oldConvoId; - update.conversationId = conversationId; - } - - return await Conversation.findOneAndUpdate({ conversationId: convoId, user }, update, { - new: true + return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, { + new: true, + upsert: true }).exec(); } catch (error) { console.log(error); - return { message: 'Error updating conversation' }; + return { message: 'Error saving conversation' }; } }, getConvosByPage: async (user, pageNumber = 1, pageSize = 12) => { @@ -82,7 +58,7 @@ module.exports = { // will handle a syncing solution soon const deletedConvoIds = []; - convoIds.forEach(convo => + convoIds.forEach((convo) => promises.push( Conversation.findOne({ user, @@ -145,7 +121,7 @@ module.exports = { }, deleteConvos: async (user, filter) => { let toRemove = await Conversation.find({ ...filter, user }).select('conversationId'); - const ids = toRemove.map(instance => instance.conversationId); + const ids = toRemove.map((instance) => instance.conversationId); let deleteCount = await Conversation.deleteMany({ ...filter, user }).exec(); deleteCount.messages = await deleteMessages({ conversationId: { $in: ids } }); return deleteCount; diff --git a/api/models/Message.js b/api/models/Message.js index af3d454980..89dde7b128 100644 --- a/api/models/Message.js +++ b/api/models/Message.js @@ -9,7 +9,9 @@ module.exports = { sender, text, isCreatedByUser = false, - error + error, + unfinished, + cancelled }) => { try { // may also need to update the conversation here @@ -22,7 +24,9 @@ module.exports = { sender, text, isCreatedByUser, - error + error, + unfinished, + cancelled }, { upsert: true, new: true } ); @@ -45,7 +49,7 @@ module.exports = { return { message: 'Error deleting messages' }; } }, - getMessages: async filter => { + getMessages: async (filter) => { try { return await Message.find(filter).sort({ createdAt: 1 }).exec(); } catch (error) { @@ -53,7 +57,7 @@ module.exports = { return { message: 'Error getting messages' }; } }, - deleteMessages: async filter => { + deleteMessages: async (filter) => { try { return await Message.deleteMany(filter).exec(); } catch (error) { diff --git a/api/models/index.js b/api/models/index.js index 8ae2eba854..bda6239d7b 100644 --- a/api/models/index.js +++ b/api/models/index.js @@ -1,5 +1,5 @@ const { getMessages, saveMessage, deleteMessagesSince, deleteMessages } = require('./Message'); -const { getConvoTitle, getConvo, saveConvo, updateConvo } = require('./Conversation'); +const { getConvoTitle, getConvo, saveConvo } = require('./Conversation'); const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset'); module.exports = { @@ -11,7 +11,6 @@ module.exports = { getConvoTitle, getConvo, saveConvo, - updateConvo, getPreset, getPresets, diff --git a/api/models/schema/messageSchema.js b/api/models/schema/messageSchema.js index 9d21dd5b92..a8de751278 100644 --- a/api/models/schema/messageSchema.js +++ b/api/models/schema/messageSchema.js @@ -43,6 +43,14 @@ const messageSchema = mongoose.Schema( required: true, default: false }, + unfinished: { + type: Boolean, + default: false + }, + cancelled: { + type: Boolean, + default: false + }, error: { type: Boolean, default: false diff --git a/api/server/routes/ask/addToCache.js b/api/server/routes/ask/addToCache.js index d53e98cbe3..e2214ba178 100644 --- a/api/server/routes/ask/addToCache.js +++ b/api/server/routes/ask/addToCache.js @@ -2,14 +2,24 @@ const Keyv = require('keyv'); const { KeyvFile } = require('keyv-file'); const { saveMessage } = require('../../../models'); -const addToCache = async ({ endpointOption, userMessage, latestMessage }) => { +const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => { try { const conversationsCache = new Keyv({ store: new KeyvFile({ filename: './data/cache.json' }), namespace: 'chatgpt' // should be 'bing' for bing/sydney }); - const { conversationId, messageId, parentMessageId, text } = latestMessage; + const { + conversationId, + messageId: userMessageId, + parentMessageId: userParentMessageId, + text: userText + } = userMessage; + const { + messageId: responseMessageId, + parentMessageId: responseParentMessageId, + text: responseText + } = responseMessage; let conversation = await conversationsCache.get(conversationId); // used to generate a title for the conversation if none exists @@ -22,10 +32,7 @@ const addToCache = async ({ endpointOption, userMessage, latestMessage }) => { // isNewConversation = true; } - // const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation; - const roles = (options) => { - const { endpoint } = options; if (endpoint === 'openAI') { return options?.chatGptLabel || 'ChatGPT'; } else if (endpoint === 'bingAI') { @@ -33,24 +40,21 @@ const addToCache = async ({ endpointOption, userMessage, latestMessage }) => { } }; - // const messageId = crypto.randomUUID(); - - let responseMessage = { - id: messageId, - parentMessageId, - role: roles(endpointOption), - message: text + let _userMessage = { + id: userMessageId, + parentMessageId: userParentMessageId, + role: 'User', + message: userText }; - await saveMessage({ - ...responseMessage, - conversationId, - messageId, - sender: responseMessage.role, - text - }); + let _responseMessage = { + id: responseMessageId, + parentMessageId: responseParentMessageId, + role: roles(endpointOption), + message: responseText + }; - conversation.messages.push(userMessage, responseMessage); + conversation.messages.push(_userMessage, _responseMessage); await conversationsCache.set(conversationId, conversation); } catch (error) { diff --git a/api/server/routes/ask/askBingAI.js b/api/server/routes/ask/askBingAI.js index 9f8d9bc878..5d21ec4d27 100644 --- a/api/server/routes/ask/askBingAI.js +++ b/api/server/routes/ask/askBingAI.js @@ -2,7 +2,7 @@ const express = require('express'); const crypto = require('crypto'); const router = express.Router(); const { titleConvo, askBing } = require('../../../app'); -const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models'); +const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); router.post('/', async (req, res) => { @@ -95,6 +95,8 @@ const ask = async ({ }) => { let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; + let responseMessageId = crypto.randomUUID(); + res.writeHead(200, { Connection: 'keep-alive', 'Content-Type': 'text/event-stream', @@ -106,9 +108,26 @@ const ask = async ({ if (preSendRequest) sendMessage(res, { message: userMessage, created: true }); try { - const progressCallback = createOnProgress(); + let lastSavedTimestamp = 0; + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text }) => { + const currentTimestamp = Date.now(); + if (currentTimestamp - lastSavedTimestamp > 500) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI', + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: text, + unfinished: true, + cancelled: false, + error: false + }); + } + } + }); const abortController = new AbortController(); - res.on('close', () => abortController.abort()); let response = await askBing({ text, parentMessageId: userParentMessageId, @@ -135,14 +154,20 @@ const ask = async ({ let responseMessage = { conversationId: newConversationId, - messageId: newResponseMessageId, + messageId: responseMessageId, + newMessageId: newResponseMessageId, parentMessageId: overrideParentMessageId || newUserMassageId, sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI', text: await handleText(response, true), - suggestions: response.details.suggestedResponses && response.details.suggestedResponses.map(s => s.text) + suggestions: + response.details.suggestedResponses && response.details.suggestedResponses.map((s) => s.text), + unfinished: false, + cancelled: false, + error: false }; await saveMessage(responseMessage); + responseMessage.messageId = newResponseMessageId; // STEP2 update the convosation. @@ -204,7 +229,7 @@ const ask = async ({ if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); - await updateConvo(req?.session?.user?.username, { + await saveConvo(req?.session?.user?.username, { conversationId: conversationId, title }); @@ -212,10 +237,12 @@ const ask = async ({ } catch (error) { console.log(error); const errorMessage = { - messageId: crypto.randomUUID(), + messageId: responseMessageId, sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI', conversationId, parentMessageId: overrideParentMessageId || userMessageId, + unfinished: false, + cancelled: false, error: true, text: error.message }; diff --git a/api/server/routes/ask/askChatGPTBrowser.js b/api/server/routes/ask/askChatGPTBrowser.js index 4e416e7c4a..25b6f6b737 100644 --- a/api/server/routes/ask/askChatGPTBrowser.js +++ b/api/server/routes/ask/askChatGPTBrowser.js @@ -3,7 +3,7 @@ const crypto = require('crypto'); const router = express.Router(); const { getChatGPTBrowserModels } = require('../endpoints'); const { browserClient } = require('../../../app/'); -const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models'); +const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); router.post('/', async (req, res) => { @@ -38,7 +38,7 @@ router.post('/', async (req, res) => { }; const availableModels = getChatGPTBrowserModels(); - if (availableModels.find(model => model === endpointOption.model) === undefined) + if (availableModels.find((model) => model === endpointOption.model) === undefined) return handleError(res, { text: 'Illegal request: model' }); console.log('ask log', { @@ -92,10 +92,29 @@ const ask = async ({ if (preSendRequest) sendMessage(res, { message: userMessage, created: true }); + let responseMessageId = crypto.randomUUID(); + try { - const progressCallback = createOnProgress(); + let lastSavedTimestamp = 0; + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text }) => { + const currentTimestamp = Date.now(); + if (currentTimestamp - lastSavedTimestamp > 500) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI', + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: text, + unfinished: true, + cancelled: false, + error: false + }); + } + } + }); const abortController = new AbortController(); - res.on('close', () => abortController.abort()); let response = await browserClient({ text, parentMessageId: userParentMessageId, @@ -116,13 +135,18 @@ const ask = async ({ let responseMessage = { conversationId: newConversationId, - messageId: newResponseMessageId, + messageId: responseMessageId, + newMessageId: newResponseMessageId, parentMessageId: overrideParentMessageId || newUserMassageId, text: await handleText(response), - sender: endpointOption?.chatGptLabel || 'ChatGPT' + sender: endpointOption?.chatGptLabel || 'ChatGPT', + unfinished: false, + cancelled: false, + error: false }; await saveMessage(responseMessage); + responseMessage.messageId = newResponseMessageId; // STEP2 update the conversation @@ -168,17 +192,19 @@ const ask = async ({ if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { // const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); const title = await response.details.title; - await updateConvo(req?.session?.user?.username, { + await saveConvo(req?.session?.user?.username, { conversationId: conversationId, title }); } } catch (error) { const errorMessage = { - messageId: crypto.randomUUID(), + messageId: responseMessageId, sender: 'ChatGPT', conversationId, parentMessageId: overrideParentMessageId || userMessageId, + unfinished: false, + cancelled: false, error: true, text: error.message }; diff --git a/api/server/routes/ask/askOpenAI.js b/api/server/routes/ask/askOpenAI.js index b21c7aba1f..550518c2c0 100644 --- a/api/server/routes/ask/askOpenAI.js +++ b/api/server/routes/ask/askOpenAI.js @@ -4,29 +4,26 @@ const router = express.Router(); const addToCache = require('./addToCache'); const { getOpenAIModels } = require('../endpoints'); const { titleConvo, askClient } = require('../../../app/'); -const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models'); +const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); const abortControllers = new Map(); router.post('/abort', async (req, res) => { - const { abortKey, latestMessage } = req.body; + const { abortKey } = req.body; console.log(`req.body`, req.body); if (!abortControllers.has(abortKey)) { return res.status(404).send('Request not found'); } - - const { abortController, userMessage, endpointOption } = abortControllers.get(abortKey); - if (!endpointOption.endpoint) { - endpointOption.endpoint = req.originalUrl.replace('/api/ask/','').split('/abort')[0]; - } - abortController.abort(); + const { abortController } = abortControllers.get(abortKey); + abortControllers.delete(abortKey); - console.log('Aborted request', abortKey, userMessage, endpointOption); - await addToCache({ endpointOption, userMessage, latestMessage }); - - res.status(200).send('Aborted'); + const ret = await abortController.abortAsk(); + console.log('Aborted request', abortKey); + console.log('Aborted message:', ret); + + res.send(JSON.stringify(ret)); }); router.post('/', async (req, res) => { @@ -66,7 +63,7 @@ router.post('/', async (req, res) => { }; const availableModels = getOpenAIModels(); - if (availableModels.find(model => model === endpointOption.model) === undefined) + if (availableModels.find((model) => model === endpointOption.model) === undefined) return handleError(res, { text: 'Illegal request: model' }); console.log('ask log', { @@ -110,6 +107,8 @@ const ask = async ({ }) => { let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; + let responseMessageId = crypto.randomUUID(); + res.writeHead(200, { Connection: 'keep-alive', 'Content-Type': 'text/event-stream', @@ -121,16 +120,55 @@ const ask = async ({ if (preSendRequest) sendMessage(res, { message: userMessage, created: true }); try { - const progressCallback = createOnProgress(); - const abortController = new AbortController(); - const abortKey = conversationId; - console.log('conversationId -----> ', conversationId); - abortControllers.set(abortKey, { abortController, userMessage, endpointOption }); - - res.on('close', () => { - abortController.abort(); - return res.end(); + let lastSavedTimestamp = 0; + const { onProgress: progressCallback, getPartialText } = createOnProgress({ + onProgress: ({ text }) => { + const currentTimestamp = Date.now(); + if (currentTimestamp - lastSavedTimestamp > 500) { + lastSavedTimestamp = currentTimestamp; + saveMessage({ + messageId: responseMessageId, + sender: endpointOption?.chatGptLabel || 'ChatGPT', + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: text, + unfinished: true, + cancelled: false, + error: false + }); + } + } }); + + let abortController = new AbortController(); + abortController.abortAsk = async function () { + this.abort(); + + const responseMessage = { + messageId: responseMessageId, + sender: endpointOption?.chatGptLabel || 'ChatGPT', + conversationId, + parentMessageId: overrideParentMessageId || userMessageId, + text: getPartialText(), + unfinished: false, + cancelled: true, + error: false + }; + + saveMessage(responseMessage); + await addToCache({ endpoint: 'openAI', endpointOption, userMessage, responseMessage }); + + return { + title: await getConvoTitle(req?.session?.user?.username, conversationId), + final: true, + conversation: await getConvo(req?.session?.user?.username, conversationId), + requestMessage: userMessage, + responseMessage: responseMessage + }; + }; + const abortKey = conversationId; + abortControllers.set(abortKey, { abortController, ...endpointOption }); + let response = await askClient({ text, parentMessageId: userParentMessageId, @@ -144,6 +182,7 @@ const ask = async ({ abortController }); + abortControllers.delete(abortKey); console.log('CLIENT RESPONSE', response); const newConversationId = response.conversationId || conversationId; @@ -155,13 +194,18 @@ const ask = async ({ let responseMessage = { conversationId: newConversationId, - messageId: newResponseMessageId, + messageId: responseMessageId, + newMessageId: newResponseMessageId, parentMessageId: overrideParentMessageId || newUserMassageId, text: await handleText(response), - sender: endpointOption?.chatGptLabel || 'ChatGPT' + sender: endpointOption?.chatGptLabel || 'ChatGPT', + unfinished: false, + cancelled: false, + error: false }; await saveMessage(responseMessage); + responseMessage.messageId = newResponseMessageId; // STEP2 update the conversation let conversationUpdate = { conversationId: newConversationId, endpoint: 'openAI' }; @@ -200,12 +244,11 @@ const ask = async ({ requestMessage: userMessage, responseMessage: responseMessage }); - abortControllers.delete(abortKey); res.end(); if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); - await updateConvo(req?.session?.user?.username, { + await saveConvo(req?.session?.user?.username, { conversationId: conversationId, title }); @@ -213,10 +256,12 @@ const ask = async ({ } catch (error) { console.error(error); const errorMessage = { - messageId: crypto.randomUUID(), + messageId: responseMessageId, sender: endpointOption?.chatGptLabel || 'ChatGPT', conversationId, parentMessageId: overrideParentMessageId || userMessageId, + unfinished: false, + cancelled: false, error: true, text: error.message }; diff --git a/api/server/routes/ask/handlers.js b/api/server/routes/ask/handlers.js index 9efcd292c9..08b9e8008e 100644 --- a/api/server/routes/ask/handlers.js +++ b/api/server/routes/ask/handlers.js @@ -17,7 +17,7 @@ const sendMessage = (res, message) => { res.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`); }; -const createOnProgress = () => { +const createOnProgress = ({ onProgress: _onProgress }) => { let i = 0; let code = ''; let tokens = ''; @@ -65,14 +65,21 @@ const createOnProgress = () => { } sendMessage(res, { text: tokens + cursor, message: true, initial: i === 0, ...rest }); + + _onProgress && _onProgress({ text: tokens, message: true, initial: i === 0, ...rest }); + i++; }; - const onProgress = opts => { + const onProgress = (opts) => { return _.partialRight(progressCallback, opts); }; - return onProgress; + const getPartialText = () => { + return tokens; + }; + + return { onProgress, getPartialText }; }; const handleText = async (response, bing = false) => { diff --git a/api/server/routes/convos.js b/api/server/routes/convos.js index ed3f6db15c..03ba85a1a6 100644 --- a/api/server/routes/convos.js +++ b/api/server/routes/convos.js @@ -2,7 +2,7 @@ const express = require('express'); const router = express.Router(); const { titleConvo } = require('../../app/'); const { getConvo, saveConvo, getConvoTitle } = require('../../models'); -const { getConvosByPage, deleteConvos, updateConvo } = require('../../models/Conversation'); +const { getConvosByPage, deleteConvos } = require('../../models/Conversation'); const { getMessages } = require('../../models/Message'); router.get('/', async (req, res) => { @@ -44,7 +44,7 @@ router.post('/update', async (req, res) => { const update = req.body.arg; try { - const dbResponse = await updateConvo(req?.session?.user?.username, update); + const dbResponse = await saveConvo(req?.session?.user?.username, update); res.status(201).send(dbResponse); } catch (error) { console.error(error); diff --git a/client/src/components/Input/index.jsx b/client/src/components/Input/index.jsx index 0cda35af2b..8f98a4339e 100644 --- a/client/src/components/Input/index.jsx +++ b/client/src/components/Input/index.jsx @@ -33,7 +33,7 @@ export default function TextChat({ isSearchView = false }) { // const bingStylesRef = useRef(null); const [showBingToneSetting, setShowBingToneSetting] = useState(false); - const isNotAppendable = latestMessage?.cancelled || latestMessage?.error; + const isNotAppendable = latestMessage?.unfinished || latestMessage?.error; // auto focus to input, when enter a conversation. useEffect(() => { diff --git a/client/src/components/MessageHandler/index.jsx b/client/src/components/MessageHandler/index.jsx index b7736a4374..9f79c3aac6 100644 --- a/client/src/components/MessageHandler/index.jsx +++ b/client/src/components/MessageHandler/index.jsx @@ -3,7 +3,6 @@ import { useRecoilValue, useRecoilState, useResetRecoilState, useSetRecoilState import { SSE } from '~/data-provider/sse.mjs'; import createPayload from '~/data-provider/createPayload'; import { useAbortRequestWithMessage } from '~/data-provider'; -import { v4 } from 'uuid'; import store from '~/store'; export default function MessageHandler() { @@ -12,12 +11,6 @@ export default function MessageHandler() { const setMessages = useSetRecoilState(store.messages); const setConversation = useSetRecoilState(store.conversation); const resetLatestMessage = useResetRecoilState(store.latestMessage); - const [lastResponse, setLastResponse] = useRecoilState(store.lastResponse); - const setLatestMessage = useSetRecoilState(store.latestMessage); - const setSubmission = useSetRecoilState(store.submission); - const [source, setSource] = useState(null); - // const [abortKey, setAbortKey] = useState(null); - const [currentParent, setCurrentParent] = useState(null); const { refreshConversations } = store.useConversations(); @@ -32,7 +25,8 @@ export default function MessageHandler() { text: data, parentMessageId: message?.overrideParentMessageId, messageId: message?.overrideParentMessageId + '_', - submitting: true + submitting: true, + unfinished: true } ]); else @@ -44,45 +38,38 @@ export default function MessageHandler() { text: data, parentMessageId: message?.messageId, messageId: message?.messageId + '_', - submitting: true + submitting: true, + unfinished: true } ]); }; const cancelHandler = (data, submission) => { - const { messages, message, initialResponse, isRegenerate = false } = submission; - const { text, messageId, parentMessageId } = data; + const { messages, isRegenerate = false } = submission; - if (isRegenerate) { - setMessages([ - ...messages, - { - ...initialResponse, - text: data, - parentMessageId: message?.overrideParentMessageId, - messageId: message?.overrideParentMessageId + '_', - cancelled: true - } - ]); - } else { - console.log('cancelHandler, isRegenerate = false'); - setMessages([ - ...messages, - message, - { - ...initialResponse, - text, - parentMessageId: message?.messageId, - messageId, - // cancelled: true - } - ]); - setLastResponse(''); - setSource(null); - setIsSubmitting(false); - setSubmission(null); - setLatestMessage(data); + const { requestMessage, responseMessage, conversation } = data; + + // update the messages + if (isRegenerate) setMessages([...messages, responseMessage]); + else setMessages([...messages, requestMessage, responseMessage]); + setIsSubmitting(false); + + // refresh title + if (requestMessage.parentMessageId == '00000000-0000-0000-0000-000000000000') { + setTimeout(() => { + refreshConversations(); + }, 2000); + + // in case it takes too long. + setTimeout(() => { + refreshConversations(); + }, 5000); } + + setConversation(prevState => ({ + ...prevState, + ...conversation + })); }; const createdHandler = (data, submission) => { @@ -160,50 +147,37 @@ export default function MessageHandler() { return; }; + const abortConversation = conversationId => { + console.log(submission); + const { endpoint } = submission?.conversation || {}; + + fetch(`/api/ask/${endpoint}/abort`, { + method: 'POST', + headers: { + 'Content-Type': 'application/json' + }, + body: JSON.stringify({ + abortKey: conversationId + }) + }) + .then(response => response.json()) + .then(data => { + console.log('aborted', data); + cancelHandler(data, submission); + }) + .catch(error => { + console.error('Error aborting request'); + console.error(error); + // errorHandler({ text: 'Error aborting request' }, { ...submission, message }); + }); + return; + }; + useEffect(() => { if (submission === null) return; if (Object.keys(submission).length === 0) return; - let { message, cancel } = submission; - - if (cancel && source) { - console.log('message aborted', submission); - source.close(); - const { endpoint } = submission.conversation; - - // splitting twice because the cursor may or may not be wrapped in a span - const latestMessageText = lastResponse.split('█')[0].split('')[0]; - const latestMessage = { - text: latestMessageText, - messageId: v4(), - parentMessageId: currentParent.messageId, - conversationId: currentParent.conversationId - }; - - fetch(`/api/ask/${endpoint}/abort`, { - method: 'POST', - headers: { - 'Content-Type': 'application/json' - }, - body: JSON.stringify({ - abortKey: currentParent.conversationId, - latestMessage - }) - }) - .then(response => { - if (response.ok) { - console.log('Request aborted'); - } else { - console.error('Error aborting request'); - } - }) - .catch(error => { - console.error(error); - }); - console.log('source closed, got this far'); - cancelHandler(latestMessage, { ...submission, message }); - return; - } + let { message } = submission; const { server, payload } = createPayload(submission); @@ -212,9 +186,6 @@ export default function MessageHandler() { headers: { 'Content-Type': 'application/json' } }); - setSource(events); - - // let latestResponseText = ''; events.onmessage = e => { const data = JSON.parse(e.data); @@ -229,24 +200,19 @@ export default function MessageHandler() { }; createdHandler(data, { ...submission, message }); console.log('created', message); - // setAbortKey(message?.conversationId); - setCurrentParent(message); } else { let text = data.text || data.response; if (data.initial) console.log(data); if (data.message) { - // latestResponseText = text; - setLastResponse(text); messageHandler(text, { ...submission, message }); } - // console.log('dataStream', data); } }; events.onopen = () => console.log('connection is opened'); - // events.oncancel = () => cancelHandler(latestResponseText, { ...submission, message }); + events.oncancel = () => abortConversation(message?.conversationId || submission?.conversationId); events.onerror = function (e) { console.log('error in opening conn.'); @@ -263,7 +229,7 @@ export default function MessageHandler() { return () => { const isCancelled = events.readyState <= 1; events.close(); - setSource(null); + // setSource(null); if (isCancelled) { const e = new Event('cancel'); events.dispatchEvent(e); diff --git a/client/src/components/Messages/Message.jsx b/client/src/components/Messages/Message.jsx index eec7f96bd0..802883a484 100644 --- a/client/src/components/Messages/Message.jsx +++ b/client/src/components/Messages/Message.jsx @@ -1,4 +1,4 @@ -import { useState, useEffect, useRef } from 'react'; +import React, { useState, useEffect, useRef } from 'react'; import { useRecoilValue, useSetRecoilState } from 'recoil'; import copy from 'copy-to-clipboard'; import SubRow from './Content/SubRow'; @@ -22,7 +22,7 @@ export default function Message({ siblingCount, setSiblingIdx }) { - const { text, searchResult, isCreatedByUser, error, submitting } = message; + const { text, searchResult, isCreatedByUser, error, submitting, unfinished, cancelled } = message; const isSubmitting = useRecoilValue(store.isSubmitting); const setLatestMessage = useSetRecoilState(store.latestMessage); const [abortScroll, setAbort] = useState(false); @@ -98,7 +98,7 @@ export default function Message({ const clickSearchResult = async () => { if (!searchResult) return; - getConversationQuery.refetch(message.conversationId).then((response) => { + getConversationQuery.refetch(message.conversationId).then(response => { switchToConversation(response.data); }); }; @@ -170,23 +170,39 @@ export default function Message({ ) : ( -
- {/*
*/} -
- {!isCreatedByUser ? ( - <> - - - ) : ( - <>{text} + <> +
+ {/*
*/} +
+ {!isCreatedByUser ? ( + <> + + + ) : ( + <>{text} + )} +
-
+ {!submitting && cancelled ? ( +
+
+ {`This is a cancelled message.`} +
+
+ ) : null} + {!submitting && unfinished ? ( +
+
+ {`This is an unfinished message. It might because the AI is still generating or it has been aborted. Refresh later to see more updates.`} +
+
+ ) : null} + )}
{ setFileName(filenamify(String(conversation?.title || 'file'))); - setType('text'); + setType('screenshot'); setIncludeOptions(true); setExportBranches(false); setRecursive(true); @@ -144,6 +144,8 @@ export default function ExportModel({ open, onOpenChange }) { fieldValues: entries.find(e => e.fieldName == 'isCreatedByUser').fieldValues }, { fieldName: 'error', fieldValues: entries.find(e => e.fieldName == 'error').fieldValues }, + { fieldName: 'unfinished', fieldValues: entries.find(e => e.fieldName == 'unfinished').fieldValues }, + { fieldName: 'cancelled', fieldValues: entries.find(e => e.fieldName == 'cancelled').fieldValues }, { fieldName: 'messageId', fieldValues: entries.find(e => e.fieldName == 'messageId').fieldValues }, { fieldName: 'parentMessageId', @@ -181,7 +183,11 @@ export default function ExportModel({ open, onOpenChange }) { data += `\n## History\n`; for (const message of messages) { - data += `**${message?.sender}:**\n${message?.text}\n\n`; + data += `**${message?.sender}:**\n${message?.text}\n`; + if (message.error) data += `*(This is an error message)*\n`; + if (message.unfinished) data += `*(This is an unfinished message)*\n`; + if (message.cancelled) data += `*(This is a cancelled message)*\n`; + data += '\n\n'; } exportFromJSON({ @@ -220,7 +226,11 @@ export default function ExportModel({ open, onOpenChange }) { data += `\nHistory\n########################\n`; for (const message of messages) { - data += `${message?.sender}:\n${message?.text}\n\n`; + data += `>> ${message?.sender}:\n${message?.text}\n`; + if (message.error) data += `(This is an error message)\n`; + if (message.unfinished) data += `(This is an unfinished message)\n`; + if (message.cancelled) data += `(This is a cancelled message)\n`; + data += '\n\n'; } exportFromJSON({ diff --git a/client/src/utils/handleSubmit.js b/client/src/utils/handleSubmit.js index 7f25d74768..4001aff409 100644 --- a/client/src/utils/handleSubmit.js +++ b/client/src/utils/handleSubmit.js @@ -142,8 +142,7 @@ const useMessageHandler = () => { }; const stopGenerating = () => { - // setSubmission(null); - setSubmission(prev => ({ ...prev, cancel: true })); + setSubmission(null); }; return { ask, regenerate, stopGenerating };