diff --git a/api/server/routes/ask/askOpenAI.js b/api/server/routes/ask/askOpenAI.js index 9c5d121639..4d95cc4c8d 100644 --- a/api/server/routes/ask/askOpenAI.js +++ b/api/server/routes/ask/askOpenAI.js @@ -6,6 +6,22 @@ const { titleConvo, askClient } = require('../../../app/'); const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models'); const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); +const abortControllers = new Map(); + +router.get('/abort', (req, res) => { + const requestId = req.query.requestId; + + if (abortControllers.has(requestId)) { + const abortController = abortControllers.get(requestId); + abortController.abort(); + abortControllers.delete(requestId); + console.log('Aborted request', requestId); + res.status(200).send('Aborted'); + } else { + res.status(404).send('Request not found'); + } +}); + router.post('/', async (req, res) => { const { endpoint, @@ -100,7 +116,14 @@ const ask = async ({ try { const progressCallback = createOnProgress(); const abortController = new AbortController(); - res.on('close', () => abortController.abort()); + const abortKey = userMessage.messageId; + abortControllers.set(abortKey, abortController); + + res.on('close', () => { + console.log('stopped message, aborting'); + abortController.abort(); + return res.end(); + }); let response = await askClient({ text, parentMessageId: userParentMessageId, @@ -170,6 +193,7 @@ const ask = async ({ requestMessage: userMessage, responseMessage: responseMessage }); + abortControllers.delete(abortKey); res.end(); if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { diff --git a/client/src/components/Input/index.jsx b/client/src/components/Input/index.jsx index e87f305ac7..067f1bcda9 100644 --- a/client/src/components/Input/index.jsx +++ b/client/src/components/Input/index.jsx @@ -62,7 +62,8 @@ export default function TextChat({ isSearchView = false }) { setText(''); }; - const handleStopGenerating = () => { + const handleStopGenerating = (e) => { + e.preventDefault(); stopGenerating(); }; diff --git a/client/src/components/MessageHandler/index.jsx b/client/src/components/MessageHandler/index.jsx index 24ec7c3dc4..c53b5bcd4c 100644 --- a/client/src/components/MessageHandler/index.jsx +++ b/client/src/components/MessageHandler/index.jsx @@ -1,5 +1,5 @@ -import { useEffect } from 'react'; -import { useRecoilValue, useResetRecoilState, useSetRecoilState } from 'recoil'; +import { useEffect, useState } from 'react'; +import { useRecoilValue, useRecoilState, useResetRecoilState, useSetRecoilState } from 'recoil'; import { SSE } from '~/data-provider/sse.mjs'; import createPayload from '~/data-provider/createPayload'; @@ -11,6 +11,10 @@ export default function MessageHandler() { const setMessages = useSetRecoilState(store.messages); const setConversation = useSetRecoilState(store.conversation); const resetLatestMessage = useResetRecoilState(store.latestMessage); + const [lastResponse, setLastResponse] = useRecoilState(store.lastResponse); + const setSubmission = useSetRecoilState(store.submission); + const [source, setSource] = useState(null); + const [abortKey, setAbortKey] = useState(null); const { refreshConversations } = store.useConversations(); @@ -45,7 +49,7 @@ export default function MessageHandler() { const cancelHandler = (data, submission) => { const { messages, message, initialResponse, isRegenerate = false } = submission; - if (isRegenerate) + if (isRegenerate) { setMessages([ ...messages, { @@ -56,7 +60,7 @@ export default function MessageHandler() { cancelled: true } ]); - else + } else { setMessages([ ...messages, message, @@ -65,9 +69,12 @@ export default function MessageHandler() { text: data, parentMessageId: message?.messageId, messageId: message?.messageId + '_', - cancelled: true + // cancelled: true } ]); + setLastResponse(''); + setSource(null); + } }; const createdHandler = (data, submission) => { @@ -149,7 +156,33 @@ export default function MessageHandler() { if (submission === null) return; if (Object.keys(submission).length === 0) return; - let { message } = submission; + let { message, cancel } = submission; + + if (cancel && source) { + console.log('message aborted', submission); + source.close(); + const { endpoint } = submission.conversation; + const latestMessage = lastResponse.replaceAll('█', ''); + + fetch(`/api/ask/${endpoint}/abort?requestId=${abortKey}`) + .then(response => { + if (response.ok) { + console.log('Request aborted'); + } else { + console.error('Error aborting request'); + } + }) + .catch(error => { + console.error(error); + }); + console.log('source closed, got this far'); + cancelHandler(latestMessage, { ...submission, message }); + setIsSubmitting(false); + setSubmission(null); + return; + } + + // events.oncancel = () => cancelHandler(latestResponseText, { ...submission, message }); const { server, payload } = createPayload(submission); @@ -158,7 +191,9 @@ export default function MessageHandler() { headers: { 'Content-Type': 'application/json' } }); - let latestResponseText = ''; + setSource(events); + + // let latestResponseText = ''; events.onmessage = e => { const data = JSON.parse(e.data); @@ -173,12 +208,14 @@ export default function MessageHandler() { }; createdHandler(data, { ...submission, message }); console.log('created', message); + setAbortKey(message?.messageId); } else { let text = data.text || data.response; if (data.initial) console.log(data); if (data.message) { - latestResponseText = text; + // latestResponseText = text; + setLastResponse(text); messageHandler(text, { ...submission, message }); } // console.log('dataStream', data); @@ -187,7 +224,7 @@ export default function MessageHandler() { events.onopen = () => console.log('connection is opened'); - events.oncancel = () => cancelHandler(latestResponseText, { ...submission, message }); + // events.oncancel = () => cancelHandler(latestResponseText, { ...submission, message }); events.onerror = function (e) { console.log('error in opening conn.'); @@ -204,6 +241,7 @@ export default function MessageHandler() { return () => { const isCancelled = events.readyState <= 1; events.close(); + setSource(null); if (isCancelled) { const e = new Event('cancel'); events.dispatchEvent(e); diff --git a/client/src/store/submission.js b/client/src/store/submission.js index 90e08b6dff..a83a118236 100644 --- a/client/src/store/submission.js +++ b/client/src/store/submission.js @@ -31,7 +31,19 @@ const isSubmitting = atom({ default: false, }); +const lastResponse = atom({ + key: "lastResponse", + default: '', +}); + +const source = atom({ + key: "source", + default: null, +}); + export default { submission, isSubmitting, + lastResponse, + source, }; diff --git a/client/src/utils/handleSubmit.js b/client/src/utils/handleSubmit.js index 3c0db371ef..8514977eef 100644 --- a/client/src/utils/handleSubmit.js +++ b/client/src/utils/handleSubmit.js @@ -138,7 +138,8 @@ const useMessageHandler = () => { }; const stopGenerating = () => { - setSubmission(null); + // setSubmission(null); + setSubmission(prev => ({ ...prev, cancel: true })); }; return { ask, regenerate, stopGenerating };