merge from dannya

feat: support unfinished messages.
This commit is contained in:
Wentao Lyu 2023-04-11 03:26:38 +08:00
parent bbf2f8a6ca
commit a5a0eab7f7
15 changed files with 308 additions and 221 deletions

View file

@ -13,45 +13,21 @@ const getConvo = async (user, conversationId) => {
module.exports = { module.exports = {
Conversation, Conversation,
saveConvo: async (user, { conversationId, newConversationId, title, ...convo }) => { saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
try { try {
const messages = await getMessages({ conversationId }); const messages = await getMessages({ conversationId });
const update = { ...convo, messages }; const update = { ...convo, messages, user };
if (title) {
update.title = title;
update.user = user;
}
if (newConversationId) { if (newConversationId) {
update.conversationId = newConversationId; update.conversationId = newConversationId;
} }
if (!update.jailbreakConversationId) {
update.jailbreakConversationId = null;
}
return await Conversation.findOneAndUpdate( return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
{ conversationId: conversationId, user }, new: true,
{ $set: update }, upsert: true
{ new: true, upsert: true }
).exec();
} catch (error) {
console.log(error);
return { message: 'Error saving conversation' };
}
},
updateConvo: async (user, { conversationId, oldConvoId, ...update }) => {
try {
let convoId = conversationId;
if (oldConvoId) {
convoId = oldConvoId;
update.conversationId = conversationId;
}
return await Conversation.findOneAndUpdate({ conversationId: convoId, user }, update, {
new: true
}).exec(); }).exec();
} catch (error) { } catch (error) {
console.log(error); console.log(error);
return { message: 'Error updating conversation' }; return { message: 'Error saving conversation' };
} }
}, },
getConvosByPage: async (user, pageNumber = 1, pageSize = 12) => { getConvosByPage: async (user, pageNumber = 1, pageSize = 12) => {
@ -82,7 +58,7 @@ module.exports = {
// will handle a syncing solution soon // will handle a syncing solution soon
const deletedConvoIds = []; const deletedConvoIds = [];
convoIds.forEach(convo => convoIds.forEach((convo) =>
promises.push( promises.push(
Conversation.findOne({ Conversation.findOne({
user, user,
@ -145,7 +121,7 @@ module.exports = {
}, },
deleteConvos: async (user, filter) => { deleteConvos: async (user, filter) => {
let toRemove = await Conversation.find({ ...filter, user }).select('conversationId'); let toRemove = await Conversation.find({ ...filter, user }).select('conversationId');
const ids = toRemove.map(instance => instance.conversationId); const ids = toRemove.map((instance) => instance.conversationId);
let deleteCount = await Conversation.deleteMany({ ...filter, user }).exec(); let deleteCount = await Conversation.deleteMany({ ...filter, user }).exec();
deleteCount.messages = await deleteMessages({ conversationId: { $in: ids } }); deleteCount.messages = await deleteMessages({ conversationId: { $in: ids } });
return deleteCount; return deleteCount;

View file

@ -9,7 +9,9 @@ module.exports = {
sender, sender,
text, text,
isCreatedByUser = false, isCreatedByUser = false,
error error,
unfinished,
cancelled
}) => { }) => {
try { try {
// may also need to update the conversation here // may also need to update the conversation here
@ -22,7 +24,9 @@ module.exports = {
sender, sender,
text, text,
isCreatedByUser, isCreatedByUser,
error error,
unfinished,
cancelled
}, },
{ upsert: true, new: true } { upsert: true, new: true }
); );
@ -45,7 +49,7 @@ module.exports = {
return { message: 'Error deleting messages' }; return { message: 'Error deleting messages' };
} }
}, },
getMessages: async filter => { getMessages: async (filter) => {
try { try {
return await Message.find(filter).sort({ createdAt: 1 }).exec(); return await Message.find(filter).sort({ createdAt: 1 }).exec();
} catch (error) { } catch (error) {
@ -53,7 +57,7 @@ module.exports = {
return { message: 'Error getting messages' }; return { message: 'Error getting messages' };
} }
}, },
deleteMessages: async filter => { deleteMessages: async (filter) => {
try { try {
return await Message.deleteMany(filter).exec(); return await Message.deleteMany(filter).exec();
} catch (error) { } catch (error) {

View file

@ -1,5 +1,5 @@
const { getMessages, saveMessage, deleteMessagesSince, deleteMessages } = require('./Message'); const { getMessages, saveMessage, deleteMessagesSince, deleteMessages } = require('./Message');
const { getConvoTitle, getConvo, saveConvo, updateConvo } = require('./Conversation'); const { getConvoTitle, getConvo, saveConvo } = require('./Conversation');
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset'); const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
module.exports = { module.exports = {
@ -11,7 +11,6 @@ module.exports = {
getConvoTitle, getConvoTitle,
getConvo, getConvo,
saveConvo, saveConvo,
updateConvo,
getPreset, getPreset,
getPresets, getPresets,

View file

@ -43,6 +43,14 @@ const messageSchema = mongoose.Schema(
required: true, required: true,
default: false default: false
}, },
unfinished: {
type: Boolean,
default: false
},
cancelled: {
type: Boolean,
default: false
},
error: { error: {
type: Boolean, type: Boolean,
default: false default: false

View file

@ -2,14 +2,24 @@ const Keyv = require('keyv');
const { KeyvFile } = require('keyv-file'); const { KeyvFile } = require('keyv-file');
const { saveMessage } = require('../../../models'); const { saveMessage } = require('../../../models');
const addToCache = async ({ endpointOption, userMessage, latestMessage }) => { const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => {
try { try {
const conversationsCache = new Keyv({ const conversationsCache = new Keyv({
store: new KeyvFile({ filename: './data/cache.json' }), store: new KeyvFile({ filename: './data/cache.json' }),
namespace: 'chatgpt' // should be 'bing' for bing/sydney namespace: 'chatgpt' // should be 'bing' for bing/sydney
}); });
const { conversationId, messageId, parentMessageId, text } = latestMessage; const {
conversationId,
messageId: userMessageId,
parentMessageId: userParentMessageId,
text: userText
} = userMessage;
const {
messageId: responseMessageId,
parentMessageId: responseParentMessageId,
text: responseText
} = responseMessage;
let conversation = await conversationsCache.get(conversationId); let conversation = await conversationsCache.get(conversationId);
// used to generate a title for the conversation if none exists // used to generate a title for the conversation if none exists
@ -22,10 +32,7 @@ const addToCache = async ({ endpointOption, userMessage, latestMessage }) => {
// isNewConversation = true; // isNewConversation = true;
} }
// const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
const roles = (options) => { const roles = (options) => {
const { endpoint } = options;
if (endpoint === 'openAI') { if (endpoint === 'openAI') {
return options?.chatGptLabel || 'ChatGPT'; return options?.chatGptLabel || 'ChatGPT';
} else if (endpoint === 'bingAI') { } else if (endpoint === 'bingAI') {
@ -33,24 +40,21 @@ const addToCache = async ({ endpointOption, userMessage, latestMessage }) => {
} }
}; };
// const messageId = crypto.randomUUID(); let _userMessage = {
id: userMessageId,
let responseMessage = { parentMessageId: userParentMessageId,
id: messageId, role: 'User',
parentMessageId, message: userText
role: roles(endpointOption),
message: text
}; };
await saveMessage({ let _responseMessage = {
...responseMessage, id: responseMessageId,
conversationId, parentMessageId: responseParentMessageId,
messageId, role: roles(endpointOption),
sender: responseMessage.role, message: responseText
text };
});
conversation.messages.push(userMessage, responseMessage); conversation.messages.push(_userMessage, _responseMessage);
await conversationsCache.set(conversationId, conversation); await conversationsCache.set(conversationId, conversation);
} catch (error) { } catch (error) {

View file

@ -2,7 +2,7 @@ const express = require('express');
const crypto = require('crypto'); const crypto = require('crypto');
const router = express.Router(); const router = express.Router();
const { titleConvo, askBing } = require('../../../app'); const { titleConvo, askBing } = require('../../../app');
const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
@ -95,6 +95,8 @@ const ask = async ({
}) => { }) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
let responseMessageId = crypto.randomUUID();
res.writeHead(200, { res.writeHead(200, {
Connection: 'keep-alive', Connection: 'keep-alive',
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -106,9 +108,26 @@ const ask = async ({
if (preSendRequest) sendMessage(res, { message: userMessage, created: true }); if (preSendRequest) sendMessage(res, { message: userMessage, created: true });
try { try {
const progressCallback = createOnProgress(); let lastSavedTimestamp = 0;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: text,
unfinished: true,
cancelled: false,
error: false
});
}
}
});
const abortController = new AbortController(); const abortController = new AbortController();
res.on('close', () => abortController.abort());
let response = await askBing({ let response = await askBing({
text, text,
parentMessageId: userParentMessageId, parentMessageId: userParentMessageId,
@ -135,14 +154,20 @@ const ask = async ({
let responseMessage = { let responseMessage = {
conversationId: newConversationId, conversationId: newConversationId,
messageId: newResponseMessageId, messageId: responseMessageId,
newMessageId: newResponseMessageId,
parentMessageId: overrideParentMessageId || newUserMassageId, parentMessageId: overrideParentMessageId || newUserMassageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI', sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
text: await handleText(response, true), text: await handleText(response, true),
suggestions: response.details.suggestedResponses && response.details.suggestedResponses.map(s => s.text) suggestions:
response.details.suggestedResponses && response.details.suggestedResponses.map((s) => s.text),
unfinished: false,
cancelled: false,
error: false
}; };
await saveMessage(responseMessage); await saveMessage(responseMessage);
responseMessage.messageId = newResponseMessageId;
// STEP2 update the convosation. // STEP2 update the convosation.
@ -204,7 +229,7 @@ const ask = async ({
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
await updateConvo(req?.session?.user?.username, { await saveConvo(req?.session?.user?.username, {
conversationId: conversationId, conversationId: conversationId,
title title
}); });
@ -212,10 +237,12 @@ const ask = async ({
} catch (error) { } catch (error) {
console.log(error); console.log(error);
const errorMessage = { const errorMessage = {
messageId: crypto.randomUUID(), messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI', sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId, conversationId,
parentMessageId: overrideParentMessageId || userMessageId, parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true, error: true,
text: error.message text: error.message
}; };

View file

@ -3,7 +3,7 @@ const crypto = require('crypto');
const router = express.Router(); const router = express.Router();
const { getChatGPTBrowserModels } = require('../endpoints'); const { getChatGPTBrowserModels } = require('../endpoints');
const { browserClient } = require('../../../app/'); const { browserClient } = require('../../../app/');
const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
@ -38,7 +38,7 @@ router.post('/', async (req, res) => {
}; };
const availableModels = getChatGPTBrowserModels(); const availableModels = getChatGPTBrowserModels();
if (availableModels.find(model => model === endpointOption.model) === undefined) if (availableModels.find((model) => model === endpointOption.model) === undefined)
return handleError(res, { text: 'Illegal request: model' }); return handleError(res, { text: 'Illegal request: model' });
console.log('ask log', { console.log('ask log', {
@ -92,10 +92,29 @@ const ask = async ({
if (preSendRequest) sendMessage(res, { message: userMessage, created: true }); if (preSendRequest) sendMessage(res, { message: userMessage, created: true });
let responseMessageId = crypto.randomUUID();
try { try {
const progressCallback = createOnProgress(); let lastSavedTimestamp = 0;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: text,
unfinished: true,
cancelled: false,
error: false
});
}
}
});
const abortController = new AbortController(); const abortController = new AbortController();
res.on('close', () => abortController.abort());
let response = await browserClient({ let response = await browserClient({
text, text,
parentMessageId: userParentMessageId, parentMessageId: userParentMessageId,
@ -116,13 +135,18 @@ const ask = async ({
let responseMessage = { let responseMessage = {
conversationId: newConversationId, conversationId: newConversationId,
messageId: newResponseMessageId, messageId: responseMessageId,
newMessageId: newResponseMessageId,
parentMessageId: overrideParentMessageId || newUserMassageId, parentMessageId: overrideParentMessageId || newUserMassageId,
text: await handleText(response), text: await handleText(response),
sender: endpointOption?.chatGptLabel || 'ChatGPT' sender: endpointOption?.chatGptLabel || 'ChatGPT',
unfinished: false,
cancelled: false,
error: false
}; };
await saveMessage(responseMessage); await saveMessage(responseMessage);
responseMessage.messageId = newResponseMessageId;
// STEP2 update the conversation // STEP2 update the conversation
@ -168,17 +192,19 @@ const ask = async ({
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
// const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); // const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
const title = await response.details.title; const title = await response.details.title;
await updateConvo(req?.session?.user?.username, { await saveConvo(req?.session?.user?.username, {
conversationId: conversationId, conversationId: conversationId,
title title
}); });
} }
} catch (error) { } catch (error) {
const errorMessage = { const errorMessage = {
messageId: crypto.randomUUID(), messageId: responseMessageId,
sender: 'ChatGPT', sender: 'ChatGPT',
conversationId, conversationId,
parentMessageId: overrideParentMessageId || userMessageId, parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true, error: true,
text: error.message text: error.message
}; };

View file

@ -4,29 +4,26 @@ const router = express.Router();
const addToCache = require('./addToCache'); const addToCache = require('./addToCache');
const { getOpenAIModels } = require('../endpoints'); const { getOpenAIModels } = require('../endpoints');
const { titleConvo, askClient } = require('../../../app/'); const { titleConvo, askClient } = require('../../../app/');
const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models'); const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers'); const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
const abortControllers = new Map(); const abortControllers = new Map();
router.post('/abort', async (req, res) => { router.post('/abort', async (req, res) => {
const { abortKey, latestMessage } = req.body; const { abortKey } = req.body;
console.log(`req.body`, req.body); console.log(`req.body`, req.body);
if (!abortControllers.has(abortKey)) { if (!abortControllers.has(abortKey)) {
return res.status(404).send('Request not found'); return res.status(404).send('Request not found');
} }
const { abortController, userMessage, endpointOption } = abortControllers.get(abortKey);
if (!endpointOption.endpoint) {
endpointOption.endpoint = req.originalUrl.replace('/api/ask/','').split('/abort')[0];
}
abortController.abort(); const { abortController } = abortControllers.get(abortKey);
abortControllers.delete(abortKey); abortControllers.delete(abortKey);
console.log('Aborted request', abortKey, userMessage, endpointOption); const ret = await abortController.abortAsk();
await addToCache({ endpointOption, userMessage, latestMessage }); console.log('Aborted request', abortKey);
console.log('Aborted message:', ret);
res.status(200).send('Aborted');
res.send(JSON.stringify(ret));
}); });
router.post('/', async (req, res) => { router.post('/', async (req, res) => {
@ -66,7 +63,7 @@ router.post('/', async (req, res) => {
}; };
const availableModels = getOpenAIModels(); const availableModels = getOpenAIModels();
if (availableModels.find(model => model === endpointOption.model) === undefined) if (availableModels.find((model) => model === endpointOption.model) === undefined)
return handleError(res, { text: 'Illegal request: model' }); return handleError(res, { text: 'Illegal request: model' });
console.log('ask log', { console.log('ask log', {
@ -110,6 +107,8 @@ const ask = async ({
}) => { }) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage; let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
let responseMessageId = crypto.randomUUID();
res.writeHead(200, { res.writeHead(200, {
Connection: 'keep-alive', Connection: 'keep-alive',
'Content-Type': 'text/event-stream', 'Content-Type': 'text/event-stream',
@ -121,16 +120,55 @@ const ask = async ({
if (preSendRequest) sendMessage(res, { message: userMessage, created: true }); if (preSendRequest) sendMessage(res, { message: userMessage, created: true });
try { try {
const progressCallback = createOnProgress(); let lastSavedTimestamp = 0;
const abortController = new AbortController(); const { onProgress: progressCallback, getPartialText } = createOnProgress({
const abortKey = conversationId; onProgress: ({ text }) => {
console.log('conversationId -----> ', conversationId); const currentTimestamp = Date.now();
abortControllers.set(abortKey, { abortController, userMessage, endpointOption }); if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
res.on('close', () => { saveMessage({
abortController.abort(); messageId: responseMessageId,
return res.end(); sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: text,
unfinished: true,
cancelled: false,
error: false
});
}
}
}); });
let abortController = new AbortController();
abortController.abortAsk = async function () {
this.abort();
const responseMessage = {
messageId: responseMessageId,
sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: getPartialText(),
unfinished: false,
cancelled: true,
error: false
};
saveMessage(responseMessage);
await addToCache({ endpoint: 'openAI', endpointOption, userMessage, responseMessage });
return {
title: await getConvoTitle(req?.session?.user?.username, conversationId),
final: true,
conversation: await getConvo(req?.session?.user?.username, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage
};
};
const abortKey = conversationId;
abortControllers.set(abortKey, { abortController, ...endpointOption });
let response = await askClient({ let response = await askClient({
text, text,
parentMessageId: userParentMessageId, parentMessageId: userParentMessageId,
@ -144,6 +182,7 @@ const ask = async ({
abortController abortController
}); });
abortControllers.delete(abortKey);
console.log('CLIENT RESPONSE', response); console.log('CLIENT RESPONSE', response);
const newConversationId = response.conversationId || conversationId; const newConversationId = response.conversationId || conversationId;
@ -155,13 +194,18 @@ const ask = async ({
let responseMessage = { let responseMessage = {
conversationId: newConversationId, conversationId: newConversationId,
messageId: newResponseMessageId, messageId: responseMessageId,
newMessageId: newResponseMessageId,
parentMessageId: overrideParentMessageId || newUserMassageId, parentMessageId: overrideParentMessageId || newUserMassageId,
text: await handleText(response), text: await handleText(response),
sender: endpointOption?.chatGptLabel || 'ChatGPT' sender: endpointOption?.chatGptLabel || 'ChatGPT',
unfinished: false,
cancelled: false,
error: false
}; };
await saveMessage(responseMessage); await saveMessage(responseMessage);
responseMessage.messageId = newResponseMessageId;
// STEP2 update the conversation // STEP2 update the conversation
let conversationUpdate = { conversationId: newConversationId, endpoint: 'openAI' }; let conversationUpdate = { conversationId: newConversationId, endpoint: 'openAI' };
@ -200,12 +244,11 @@ const ask = async ({
requestMessage: userMessage, requestMessage: userMessage,
responseMessage: responseMessage responseMessage: responseMessage
}); });
abortControllers.delete(abortKey);
res.end(); res.end();
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') { if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage }); const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
await updateConvo(req?.session?.user?.username, { await saveConvo(req?.session?.user?.username, {
conversationId: conversationId, conversationId: conversationId,
title title
}); });
@ -213,10 +256,12 @@ const ask = async ({
} catch (error) { } catch (error) {
console.error(error); console.error(error);
const errorMessage = { const errorMessage = {
messageId: crypto.randomUUID(), messageId: responseMessageId,
sender: endpointOption?.chatGptLabel || 'ChatGPT', sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId, conversationId,
parentMessageId: overrideParentMessageId || userMessageId, parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true, error: true,
text: error.message text: error.message
}; };

View file

@ -17,7 +17,7 @@ const sendMessage = (res, message) => {
res.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`); res.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`);
}; };
const createOnProgress = () => { const createOnProgress = ({ onProgress: _onProgress }) => {
let i = 0; let i = 0;
let code = ''; let code = '';
let tokens = ''; let tokens = '';
@ -65,14 +65,21 @@ const createOnProgress = () => {
} }
sendMessage(res, { text: tokens + cursor, message: true, initial: i === 0, ...rest }); sendMessage(res, { text: tokens + cursor, message: true, initial: i === 0, ...rest });
_onProgress && _onProgress({ text: tokens, message: true, initial: i === 0, ...rest });
i++; i++;
}; };
const onProgress = opts => { const onProgress = (opts) => {
return _.partialRight(progressCallback, opts); return _.partialRight(progressCallback, opts);
}; };
return onProgress; const getPartialText = () => {
return tokens;
};
return { onProgress, getPartialText };
}; };
const handleText = async (response, bing = false) => { const handleText = async (response, bing = false) => {

View file

@ -2,7 +2,7 @@ const express = require('express');
const router = express.Router(); const router = express.Router();
const { titleConvo } = require('../../app/'); const { titleConvo } = require('../../app/');
const { getConvo, saveConvo, getConvoTitle } = require('../../models'); const { getConvo, saveConvo, getConvoTitle } = require('../../models');
const { getConvosByPage, deleteConvos, updateConvo } = require('../../models/Conversation'); const { getConvosByPage, deleteConvos } = require('../../models/Conversation');
const { getMessages } = require('../../models/Message'); const { getMessages } = require('../../models/Message');
router.get('/', async (req, res) => { router.get('/', async (req, res) => {
@ -44,7 +44,7 @@ router.post('/update', async (req, res) => {
const update = req.body.arg; const update = req.body.arg;
try { try {
const dbResponse = await updateConvo(req?.session?.user?.username, update); const dbResponse = await saveConvo(req?.session?.user?.username, update);
res.status(201).send(dbResponse); res.status(201).send(dbResponse);
} catch (error) { } catch (error) {
console.error(error); console.error(error);

View file

@ -33,7 +33,7 @@ export default function TextChat({ isSearchView = false }) {
// const bingStylesRef = useRef(null); // const bingStylesRef = useRef(null);
const [showBingToneSetting, setShowBingToneSetting] = useState(false); const [showBingToneSetting, setShowBingToneSetting] = useState(false);
const isNotAppendable = latestMessage?.cancelled || latestMessage?.error; const isNotAppendable = latestMessage?.unfinished || latestMessage?.error;
// auto focus to input, when enter a conversation. // auto focus to input, when enter a conversation.
useEffect(() => { useEffect(() => {

View file

@ -3,7 +3,6 @@ import { useRecoilValue, useRecoilState, useResetRecoilState, useSetRecoilState
import { SSE } from '~/data-provider/sse.mjs'; import { SSE } from '~/data-provider/sse.mjs';
import createPayload from '~/data-provider/createPayload'; import createPayload from '~/data-provider/createPayload';
import { useAbortRequestWithMessage } from '~/data-provider'; import { useAbortRequestWithMessage } from '~/data-provider';
import { v4 } from 'uuid';
import store from '~/store'; import store from '~/store';
export default function MessageHandler() { export default function MessageHandler() {
@ -12,12 +11,6 @@ export default function MessageHandler() {
const setMessages = useSetRecoilState(store.messages); const setMessages = useSetRecoilState(store.messages);
const setConversation = useSetRecoilState(store.conversation); const setConversation = useSetRecoilState(store.conversation);
const resetLatestMessage = useResetRecoilState(store.latestMessage); const resetLatestMessage = useResetRecoilState(store.latestMessage);
const [lastResponse, setLastResponse] = useRecoilState(store.lastResponse);
const setLatestMessage = useSetRecoilState(store.latestMessage);
const setSubmission = useSetRecoilState(store.submission);
const [source, setSource] = useState(null);
// const [abortKey, setAbortKey] = useState(null);
const [currentParent, setCurrentParent] = useState(null);
const { refreshConversations } = store.useConversations(); const { refreshConversations } = store.useConversations();
@ -32,7 +25,8 @@ export default function MessageHandler() {
text: data, text: data,
parentMessageId: message?.overrideParentMessageId, parentMessageId: message?.overrideParentMessageId,
messageId: message?.overrideParentMessageId + '_', messageId: message?.overrideParentMessageId + '_',
submitting: true submitting: true,
unfinished: true
} }
]); ]);
else else
@ -44,45 +38,38 @@ export default function MessageHandler() {
text: data, text: data,
parentMessageId: message?.messageId, parentMessageId: message?.messageId,
messageId: message?.messageId + '_', messageId: message?.messageId + '_',
submitting: true submitting: true,
unfinished: true
} }
]); ]);
}; };
const cancelHandler = (data, submission) => { const cancelHandler = (data, submission) => {
const { messages, message, initialResponse, isRegenerate = false } = submission; const { messages, isRegenerate = false } = submission;
const { text, messageId, parentMessageId } = data;
if (isRegenerate) { const { requestMessage, responseMessage, conversation } = data;
setMessages([
...messages, // update the messages
{ if (isRegenerate) setMessages([...messages, responseMessage]);
...initialResponse, else setMessages([...messages, requestMessage, responseMessage]);
text: data, setIsSubmitting(false);
parentMessageId: message?.overrideParentMessageId,
messageId: message?.overrideParentMessageId + '_', // refresh title
cancelled: true if (requestMessage.parentMessageId == '00000000-0000-0000-0000-000000000000') {
} setTimeout(() => {
]); refreshConversations();
} else { }, 2000);
console.log('cancelHandler, isRegenerate = false');
setMessages([ // in case it takes too long.
...messages, setTimeout(() => {
message, refreshConversations();
{ }, 5000);
...initialResponse,
text,
parentMessageId: message?.messageId,
messageId,
// cancelled: true
}
]);
setLastResponse('');
setSource(null);
setIsSubmitting(false);
setSubmission(null);
setLatestMessage(data);
} }
setConversation(prevState => ({
...prevState,
...conversation
}));
}; };
const createdHandler = (data, submission) => { const createdHandler = (data, submission) => {
@ -160,50 +147,37 @@ export default function MessageHandler() {
return; return;
}; };
const abortConversation = conversationId => {
console.log(submission);
const { endpoint } = submission?.conversation || {};
fetch(`/api/ask/${endpoint}/abort`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
abortKey: conversationId
})
})
.then(response => response.json())
.then(data => {
console.log('aborted', data);
cancelHandler(data, submission);
})
.catch(error => {
console.error('Error aborting request');
console.error(error);
// errorHandler({ text: 'Error aborting request' }, { ...submission, message });
});
return;
};
useEffect(() => { useEffect(() => {
if (submission === null) return; if (submission === null) return;
if (Object.keys(submission).length === 0) return; if (Object.keys(submission).length === 0) return;
let { message, cancel } = submission; let { message } = submission;
if (cancel && source) {
console.log('message aborted', submission);
source.close();
const { endpoint } = submission.conversation;
// splitting twice because the cursor may or may not be wrapped in a span
const latestMessageText = lastResponse.split('█')[0].split('<span className="result-streaming">')[0];
const latestMessage = {
text: latestMessageText,
messageId: v4(),
parentMessageId: currentParent.messageId,
conversationId: currentParent.conversationId
};
fetch(`/api/ask/${endpoint}/abort`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
abortKey: currentParent.conversationId,
latestMessage
})
})
.then(response => {
if (response.ok) {
console.log('Request aborted');
} else {
console.error('Error aborting request');
}
})
.catch(error => {
console.error(error);
});
console.log('source closed, got this far');
cancelHandler(latestMessage, { ...submission, message });
return;
}
const { server, payload } = createPayload(submission); const { server, payload } = createPayload(submission);
@ -212,9 +186,6 @@ export default function MessageHandler() {
headers: { 'Content-Type': 'application/json' } headers: { 'Content-Type': 'application/json' }
}); });
setSource(events);
// let latestResponseText = '';
events.onmessage = e => { events.onmessage = e => {
const data = JSON.parse(e.data); const data = JSON.parse(e.data);
@ -229,24 +200,19 @@ export default function MessageHandler() {
}; };
createdHandler(data, { ...submission, message }); createdHandler(data, { ...submission, message });
console.log('created', message); console.log('created', message);
// setAbortKey(message?.conversationId);
setCurrentParent(message);
} else { } else {
let text = data.text || data.response; let text = data.text || data.response;
if (data.initial) console.log(data); if (data.initial) console.log(data);
if (data.message) { if (data.message) {
// latestResponseText = text;
setLastResponse(text);
messageHandler(text, { ...submission, message }); messageHandler(text, { ...submission, message });
} }
// console.log('dataStream', data);
} }
}; };
events.onopen = () => console.log('connection is opened'); events.onopen = () => console.log('connection is opened');
// events.oncancel = () => cancelHandler(latestResponseText, { ...submission, message }); events.oncancel = () => abortConversation(message?.conversationId || submission?.conversationId);
events.onerror = function (e) { events.onerror = function (e) {
console.log('error in opening conn.'); console.log('error in opening conn.');
@ -263,7 +229,7 @@ export default function MessageHandler() {
return () => { return () => {
const isCancelled = events.readyState <= 1; const isCancelled = events.readyState <= 1;
events.close(); events.close();
setSource(null); // setSource(null);
if (isCancelled) { if (isCancelled) {
const e = new Event('cancel'); const e = new Event('cancel');
events.dispatchEvent(e); events.dispatchEvent(e);

View file

@ -1,4 +1,4 @@
import { useState, useEffect, useRef } from 'react'; import React, { useState, useEffect, useRef } from 'react';
import { useRecoilValue, useSetRecoilState } from 'recoil'; import { useRecoilValue, useSetRecoilState } from 'recoil';
import copy from 'copy-to-clipboard'; import copy from 'copy-to-clipboard';
import SubRow from './Content/SubRow'; import SubRow from './Content/SubRow';
@ -22,7 +22,7 @@ export default function Message({
siblingCount, siblingCount,
setSiblingIdx setSiblingIdx
}) { }) {
const { text, searchResult, isCreatedByUser, error, submitting } = message; const { text, searchResult, isCreatedByUser, error, submitting, unfinished, cancelled } = message;
const isSubmitting = useRecoilValue(store.isSubmitting); const isSubmitting = useRecoilValue(store.isSubmitting);
const setLatestMessage = useSetRecoilState(store.latestMessage); const setLatestMessage = useSetRecoilState(store.latestMessage);
const [abortScroll, setAbort] = useState(false); const [abortScroll, setAbort] = useState(false);
@ -98,7 +98,7 @@ export default function Message({
const clickSearchResult = async () => { const clickSearchResult = async () => {
if (!searchResult) return; if (!searchResult) return;
getConversationQuery.refetch(message.conversationId).then((response) => { getConversationQuery.refetch(message.conversationId).then(response => {
switchToConversation(response.data); switchToConversation(response.data);
}); });
}; };
@ -170,23 +170,39 @@ export default function Message({
</div> </div>
</div> </div>
) : ( ) : (
<div <>
className={cn( <div
'flex min-h-[20px] flex-grow flex-col items-start gap-4 ', className={cn(
isCreatedByUser ? 'whitespace-pre-wrap' : '' 'flex min-h-[20px] flex-grow flex-col items-start gap-4 ',
)} isCreatedByUser ? 'whitespace-pre-wrap' : ''
>
{/* <div className={`${blinker ? 'result-streaming' : ''} markdown prose dark:prose-invert light w-full break-words`}> */}
<div className="markdown prose dark:prose-invert light w-full break-words">
{!isCreatedByUser ? (
<>
<Content content={text} />
</>
) : (
<>{text}</>
)} )}
>
{/* <div className={`${blinker ? 'result-streaming' : ''} markdown prose dark:prose-invert light w-full break-words`}> */}
<div className="markdown prose dark:prose-invert light w-full break-words">
{!isCreatedByUser ? (
<>
<Content content={text} />
</>
) : (
<>{text}</>
)}
</div>
</div> </div>
</div> {!submitting && cancelled ? (
<div className="flex flex min-h-[20px] flex-grow flex-col items-start gap-2 gap-4 text-red-500">
<div className="rounded-md border border-blue-400 bg-blue-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-100">
{`This is a cancelled message.`}
</div>
</div>
) : null}
{!submitting && unfinished ? (
<div className="flex flex min-h-[20px] flex-grow flex-col items-start gap-2 gap-4 text-red-500">
<div className="rounded-md border border-blue-400 bg-blue-500/10 px-3 py-2 text-sm text-gray-600 dark:text-gray-100">
{`This is an unfinished message. It might because the AI is still generating or it has been aborted. Refresh later to see more updates.`}
</div>
</div>
) : null}
</>
)} )}
</div> </div>
<HoverButtons <HoverButtons

View file

@ -37,16 +37,16 @@ export default function ExportModel({ open, onOpenChange }) {
); );
const typeOptions = [ const typeOptions = [
{ value: 'screenshot', display: 'screenshot (.png)' },
{ value: 'text', display: 'text (.txt)' }, { value: 'text', display: 'text (.txt)' },
{ value: 'markdown', display: 'markdown (.md)' }, { value: 'markdown', display: 'markdown (.md)' },
{ value: 'csv', display: 'csv (.csv)' },
{ value: 'json', display: 'json (.json)' }, { value: 'json', display: 'json (.json)' },
{ value: 'screenshot', display: 'screenshot (.png)' } { value: 'csv', display: 'csv (.csv)' }
]; //,, 'webpage']; ]; //,, 'webpage'];
useEffect(() => { useEffect(() => {
setFileName(filenamify(String(conversation?.title || 'file'))); setFileName(filenamify(String(conversation?.title || 'file')));
setType('text'); setType('screenshot');
setIncludeOptions(true); setIncludeOptions(true);
setExportBranches(false); setExportBranches(false);
setRecursive(true); setRecursive(true);
@ -144,6 +144,8 @@ export default function ExportModel({ open, onOpenChange }) {
fieldValues: entries.find(e => e.fieldName == 'isCreatedByUser').fieldValues fieldValues: entries.find(e => e.fieldName == 'isCreatedByUser').fieldValues
}, },
{ fieldName: 'error', fieldValues: entries.find(e => e.fieldName == 'error').fieldValues }, { fieldName: 'error', fieldValues: entries.find(e => e.fieldName == 'error').fieldValues },
{ fieldName: 'unfinished', fieldValues: entries.find(e => e.fieldName == 'unfinished').fieldValues },
{ fieldName: 'cancelled', fieldValues: entries.find(e => e.fieldName == 'cancelled').fieldValues },
{ fieldName: 'messageId', fieldValues: entries.find(e => e.fieldName == 'messageId').fieldValues }, { fieldName: 'messageId', fieldValues: entries.find(e => e.fieldName == 'messageId').fieldValues },
{ {
fieldName: 'parentMessageId', fieldName: 'parentMessageId',
@ -181,7 +183,11 @@ export default function ExportModel({ open, onOpenChange }) {
data += `\n## History\n`; data += `\n## History\n`;
for (const message of messages) { for (const message of messages) {
data += `**${message?.sender}:**\n${message?.text}\n\n`; data += `**${message?.sender}:**\n${message?.text}\n`;
if (message.error) data += `*(This is an error message)*\n`;
if (message.unfinished) data += `*(This is an unfinished message)*\n`;
if (message.cancelled) data += `*(This is a cancelled message)*\n`;
data += '\n\n';
} }
exportFromJSON({ exportFromJSON({
@ -220,7 +226,11 @@ export default function ExportModel({ open, onOpenChange }) {
data += `\nHistory\n########################\n`; data += `\nHistory\n########################\n`;
for (const message of messages) { for (const message of messages) {
data += `${message?.sender}:\n${message?.text}\n\n`; data += `>> ${message?.sender}:\n${message?.text}\n`;
if (message.error) data += `(This is an error message)\n`;
if (message.unfinished) data += `(This is an unfinished message)\n`;
if (message.cancelled) data += `(This is a cancelled message)\n`;
data += '\n\n';
} }
exportFromJSON({ exportFromJSON({

View file

@ -142,8 +142,7 @@ const useMessageHandler = () => {
}; };
const stopGenerating = () => { const stopGenerating = () => {
// setSubmission(null); setSubmission(null);
setSubmission(prev => ({ ...prev, cancel: true }));
}; };
return { ask, regenerate, stopGenerating }; return { ask, regenerate, stopGenerating };