diff --git a/api/server/middleware/abortMiddleware.js b/api/server/middleware/abortMiddleware.js index 08f614286a..a2be50ee82 100644 --- a/api/server/middleware/abortMiddleware.js +++ b/api/server/middleware/abortMiddleware.js @@ -110,7 +110,7 @@ const handleAbortError = async (res, req, error, data) => { } const respondWithError = async (partialText) => { - const options = { + let options = { sender, messageId, conversationId, @@ -121,7 +121,8 @@ const handleAbortError = async (res, req, error, data) => { }; if (partialText) { - options.overrideProps = { + options = { + ...options, error: false, unfinished: true, text: partialText, diff --git a/api/server/middleware/abortRun.js b/api/server/middleware/abortRun.js index b93eb8c213..fd3f5353f5 100644 --- a/api/server/middleware/abortRun.js +++ b/api/server/middleware/abortRun.js @@ -37,7 +37,7 @@ async function abortRun(req, res) { try { await cache.set(cacheKey, 'cancelled'); const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); - logger.debug('Cancelled run:', cancelledRun); + logger.debug('[abortRun] Cancelled run:', cancelledRun); } catch (error) { logger.error('[abortRun] Error cancelling run', error); if ( diff --git a/api/server/routes/assistants/chat.js b/api/server/routes/assistants/chat.js index 8796bbcf30..73cf0628f2 100644 --- a/api/server/routes/assistants/chat.js +++ b/api/server/routes/assistants/chat.js @@ -11,10 +11,10 @@ const { } = require('~/server/services/Threads'); const { runAssistant, createOnTextProgress } = require('~/server/services/AssistantService'); const { addTitle, initializeClient } = require('~/server/services/Endpoints/assistant'); +const { sendResponse, sendMessage } = require('~/server/utils'); const { createRun, sleep } = require('~/server/services/Runs'); const { getConvo } = require('~/models/Conversation'); const getLogStores = require('~/cache/getLogStores'); -const { sendMessage } = require('~/server/utils'); const { logger } = require('~/config'); const router = express.Router(); @@ -101,32 +101,52 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res let completedRun; const handleError = async (error) => { + const messageData = { + thread_id, + assistant_id, + conversationId, + parentMessageId, + sender: 'System', + user: req.user.id, + shouldSaveMessage: false, + messageId: responseMessageId, + endpoint: EModelEndpoint.assistants, + }; + if (error.message === 'Run cancelled') { return res.end(); - } - if (error.message === 'Request closed' && completedRun) { + } else if (error.message === 'Request closed' && completedRun) { return; } else if (error.message === 'Request closed') { logger.debug('[/assistants/chat/] Request aborted on close'); + } else { + logger.error('[/assistants/chat/]', error); } - logger.error('[/assistants/chat/]', error); - if (!openai || !thread_id || !run_id) { - return res.status(500).json({ error: 'The Assistant run failed to initialize' }); + return sendResponse(res, messageData, 'The Assistant run failed to initialize'); } + await sleep(3000); + try { + const status = await cache.get(cacheKey); + if (status === 'cancelled') { + logger.debug('[/assistants/chat/] Run already cancelled'); + return res.end(); + } await cache.delete(cacheKey); const cancelledRun = await openai.beta.threads.runs.cancel(thread_id, run_id); - logger.debug('Cancelled run:', cancelledRun); + logger.debug('[/assistants/chat/] Cancelled run:', cancelledRun); } catch (error) { - logger.error('[abortRun] Error cancelling run', error); + logger.error('[/assistants/chat/] Error cancelling run', error); } await sleep(2000); + + let run; try { - const run = await openai.beta.threads.runs.retrieve(thread_id, run_id); + run = await openai.beta.threads.runs.retrieve(thread_id, run_id); await recordUsage({ ...run.usage, model: run.model, @@ -137,6 +157,7 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res logger.error('[/assistants/chat/] Error fetching or processing run', error); } + let finalEvent; try { const runMessages = await checkMessageGaps({ openai, @@ -146,22 +167,18 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res latestMessageId: responseMessageId, }); - const finalEvent = { + finalEvent = { title: 'New Chat', final: true, conversation: await getConvo(req.user.id, conversationId), runMessages, }; - - if (res.headersSent && finalEvent) { - return sendMessage(res, finalEvent); - } - - res.json(finalEvent); } catch (error) { logger.error('[/assistants/chat/] Error finalizing error process', error); - return res.status(500).json({ error: 'The Assistant run failed' }); + return sendResponse(res, messageData, 'The Assistant run failed'); } + + return sendResponse(res, finalEvent); }; try { @@ -172,10 +189,12 @@ router.post('/', validateModel, buildEndpointOption, setHeaders, async (req, res }); if (convoId && !_thread_id) { + completedRun = true; throw new Error('Missing thread_id for existing conversation'); } if (!assistant_id) { + completedRun = true; throw new Error('Missing assistant_id'); } diff --git a/api/server/utils/streamResponse.js b/api/server/utils/streamResponse.js index 109a407463..b7a691d91a 100644 --- a/api/server/utils/streamResponse.js +++ b/api/server/utils/streamResponse.js @@ -32,6 +32,13 @@ const sendMessage = (res, message, event = 'message') => { * @async * @param {object} res - The server response. * @param {object} options - The options for handling the error containing message properties. + * @param {object} options.user - The user ID. + * @param {string} options.sender - The sender of the message. + * @param {string} options.conversationId - The conversation ID. + * @param {string} options.messageId - The message ID. + * @param {string} options.parentMessageId - The parent message ID. + * @param {string} options.text - The error message. + * @param {boolean} options.shouldSaveMessage - [Optional] Whether the message should be saved. Default is true. * @param {function} callback - [Optional] The callback function to be executed. */ const sendError = async (res, options, callback) => { @@ -43,7 +50,7 @@ const sendError = async (res, options, callback) => { parentMessageId, text, shouldSaveMessage, - overrideProps = {}, + ...rest } = options; const errorMessage = { sender, @@ -55,7 +62,7 @@ const sendError = async (res, options, callback) => { final: true, text, isCreatedByUser: false, - ...overrideProps, + ...rest, }; if (callback && typeof callback === 'function') { await callback(); @@ -88,7 +95,28 @@ const sendError = async (res, options, callback) => { handleError(res, errorMessage); }; +/** + * Sends the response based on whether headers have been sent or not. + * @param {Express.Response} res - The server response. + * @param {Object} data - The data to be sent. + * @param {string} [errorMessage] - The error message, if any. + */ +const sendResponse = (res, data, errorMessage) => { + if (!res.headersSent) { + if (errorMessage) { + return res.status(500).json({ error: errorMessage }); + } + return res.json(data); + } + + if (errorMessage) { + return sendError(res, { ...data, text: errorMessage }); + } + return sendMessage(res, data); +}; + module.exports = { + sendResponse, handleError, sendMessage, sendError, diff --git a/client/src/hooks/Input/useTextarea.ts b/client/src/hooks/Input/useTextarea.ts index 9fa3b04808..92cfb9fd57 100644 --- a/client/src/hooks/Input/useTextarea.ts +++ b/client/src/hooks/Input/useTextarea.ts @@ -45,7 +45,9 @@ export default function useTextarea({ const localize = useLocalize(); const { conversationId, jailbreak, endpoint = '', assistant_id } = conversation || {}; - const isNotAppendable = (latestMessage?.unfinished && !isSubmitting) || latestMessage?.error; + const isNotAppendable = + ((latestMessage?.unfinished && !isSubmitting) || latestMessage?.error) && + endpoint !== EModelEndpoint.assistants; // && (conversationId?.length ?? 0) > 6; // also ensures that we don't show the wrong placeholder const assistant = endpoint === EModelEndpoint.assistants && assistantMap?.[assistant_id ?? '']; diff --git a/client/src/hooks/SSE/useSSE.ts b/client/src/hooks/SSE/useSSE.ts index 638e9476fc..dc95602667 100644 --- a/client/src/hooks/SSE/useSSE.ts +++ b/client/src/hooks/SSE/useSSE.ts @@ -418,7 +418,6 @@ export default function useSSE(submission: TSubmission | null, index = 0) { const abortConversation = useCallback( async (conversationId = '', submission: TSubmission) => { - console.log(submission); let runAbortKey = ''; try { const conversation = (JSON.parse(localStorage.getItem('lastConversationSetup') ?? '') ?? diff --git a/client/src/hooks/useChatHelpers.ts b/client/src/hooks/useChatHelpers.ts index cf6a54e74b..46b8fc9ddd 100644 --- a/client/src/hooks/useChatHelpers.ts +++ b/client/src/hooks/useChatHelpers.ts @@ -140,7 +140,10 @@ export default function useChatHelpers(index = 0, paramId: string | undefined) { (msg) => msg.messageId === latestMessage?.parentMessageId, ); - const thread_id = parentMessage?.thread_id ?? latestMessage?.thread_id; + let thread_id = parentMessage?.thread_id ?? latestMessage?.thread_id; + if (!thread_id) { + thread_id = currentMessages.find((message) => message.thread_id)?.thread_id; + } const endpointsConfig = queryClient.getQueryData([QueryKeys.endpoints]); const endpointType = getEndpointField(endpointsConfig, endpoint, 'type');