merge from dannya

feat: support unfinished messages.
This commit is contained in:
Wentao Lyu 2023-04-11 03:26:38 +08:00
parent bbf2f8a6ca
commit a5a0eab7f7
15 changed files with 308 additions and 221 deletions

View file

@ -13,45 +13,21 @@ const getConvo = async (user, conversationId) => {
module.exports = {
Conversation,
saveConvo: async (user, { conversationId, newConversationId, title, ...convo }) => {
saveConvo: async (user, { conversationId, newConversationId, ...convo }) => {
try {
const messages = await getMessages({ conversationId });
const update = { ...convo, messages };
if (title) {
update.title = title;
update.user = user;
}
const update = { ...convo, messages, user };
if (newConversationId) {
update.conversationId = newConversationId;
}
if (!update.jailbreakConversationId) {
update.jailbreakConversationId = null;
}
return await Conversation.findOneAndUpdate(
{ conversationId: conversationId, user },
{ $set: update },
{ new: true, upsert: true }
).exec();
} catch (error) {
console.log(error);
return { message: 'Error saving conversation' };
}
},
updateConvo: async (user, { conversationId, oldConvoId, ...update }) => {
try {
let convoId = conversationId;
if (oldConvoId) {
convoId = oldConvoId;
update.conversationId = conversationId;
}
return await Conversation.findOneAndUpdate({ conversationId: convoId, user }, update, {
new: true
return await Conversation.findOneAndUpdate({ conversationId: conversationId, user }, update, {
new: true,
upsert: true
}).exec();
} catch (error) {
console.log(error);
return { message: 'Error updating conversation' };
return { message: 'Error saving conversation' };
}
},
getConvosByPage: async (user, pageNumber = 1, pageSize = 12) => {
@ -82,7 +58,7 @@ module.exports = {
// will handle a syncing solution soon
const deletedConvoIds = [];
convoIds.forEach(convo =>
convoIds.forEach((convo) =>
promises.push(
Conversation.findOne({
user,
@ -145,7 +121,7 @@ module.exports = {
},
deleteConvos: async (user, filter) => {
let toRemove = await Conversation.find({ ...filter, user }).select('conversationId');
const ids = toRemove.map(instance => instance.conversationId);
const ids = toRemove.map((instance) => instance.conversationId);
let deleteCount = await Conversation.deleteMany({ ...filter, user }).exec();
deleteCount.messages = await deleteMessages({ conversationId: { $in: ids } });
return deleteCount;

View file

@ -9,7 +9,9 @@ module.exports = {
sender,
text,
isCreatedByUser = false,
error
error,
unfinished,
cancelled
}) => {
try {
// may also need to update the conversation here
@ -22,7 +24,9 @@ module.exports = {
sender,
text,
isCreatedByUser,
error
error,
unfinished,
cancelled
},
{ upsert: true, new: true }
);
@ -45,7 +49,7 @@ module.exports = {
return { message: 'Error deleting messages' };
}
},
getMessages: async filter => {
getMessages: async (filter) => {
try {
return await Message.find(filter).sort({ createdAt: 1 }).exec();
} catch (error) {
@ -53,7 +57,7 @@ module.exports = {
return { message: 'Error getting messages' };
}
},
deleteMessages: async filter => {
deleteMessages: async (filter) => {
try {
return await Message.deleteMany(filter).exec();
} catch (error) {

View file

@ -1,5 +1,5 @@
const { getMessages, saveMessage, deleteMessagesSince, deleteMessages } = require('./Message');
const { getConvoTitle, getConvo, saveConvo, updateConvo } = require('./Conversation');
const { getConvoTitle, getConvo, saveConvo } = require('./Conversation');
const { getPreset, getPresets, savePreset, deletePresets } = require('./Preset');
module.exports = {
@ -11,7 +11,6 @@ module.exports = {
getConvoTitle,
getConvo,
saveConvo,
updateConvo,
getPreset,
getPresets,

View file

@ -43,6 +43,14 @@ const messageSchema = mongoose.Schema(
required: true,
default: false
},
unfinished: {
type: Boolean,
default: false
},
cancelled: {
type: Boolean,
default: false
},
error: {
type: Boolean,
default: false

View file

@ -2,14 +2,24 @@ const Keyv = require('keyv');
const { KeyvFile } = require('keyv-file');
const { saveMessage } = require('../../../models');
const addToCache = async ({ endpointOption, userMessage, latestMessage }) => {
const addToCache = async ({ endpoint, endpointOption, userMessage, responseMessage }) => {
try {
const conversationsCache = new Keyv({
store: new KeyvFile({ filename: './data/cache.json' }),
namespace: 'chatgpt' // should be 'bing' for bing/sydney
});
const { conversationId, messageId, parentMessageId, text } = latestMessage;
const {
conversationId,
messageId: userMessageId,
parentMessageId: userParentMessageId,
text: userText
} = userMessage;
const {
messageId: responseMessageId,
parentMessageId: responseParentMessageId,
text: responseText
} = responseMessage;
let conversation = await conversationsCache.get(conversationId);
// used to generate a title for the conversation if none exists
@ -22,10 +32,7 @@ const addToCache = async ({ endpointOption, userMessage, latestMessage }) => {
// isNewConversation = true;
}
// const shouldGenerateTitle = opts.shouldGenerateTitle && isNewConversation;
const roles = (options) => {
const { endpoint } = options;
if (endpoint === 'openAI') {
return options?.chatGptLabel || 'ChatGPT';
} else if (endpoint === 'bingAI') {
@ -33,24 +40,21 @@ const addToCache = async ({ endpointOption, userMessage, latestMessage }) => {
}
};
// const messageId = crypto.randomUUID();
let responseMessage = {
id: messageId,
parentMessageId,
role: roles(endpointOption),
message: text
let _userMessage = {
id: userMessageId,
parentMessageId: userParentMessageId,
role: 'User',
message: userText
};
await saveMessage({
...responseMessage,
conversationId,
messageId,
sender: responseMessage.role,
text
});
let _responseMessage = {
id: responseMessageId,
parentMessageId: responseParentMessageId,
role: roles(endpointOption),
message: responseText
};
conversation.messages.push(userMessage, responseMessage);
conversation.messages.push(_userMessage, _responseMessage);
await conversationsCache.set(conversationId, conversation);
} catch (error) {

View file

@ -2,7 +2,7 @@ const express = require('express');
const crypto = require('crypto');
const router = express.Router();
const { titleConvo, askBing } = require('../../../app');
const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
router.post('/', async (req, res) => {
@ -95,6 +95,8 @@ const ask = async ({
}) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
let responseMessageId = crypto.randomUUID();
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
@ -106,9 +108,26 @@ const ask = async ({
if (preSendRequest) sendMessage(res, { message: userMessage, created: true });
try {
const progressCallback = createOnProgress();
let lastSavedTimestamp = 0;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: text,
unfinished: true,
cancelled: false,
error: false
});
}
}
});
const abortController = new AbortController();
res.on('close', () => abortController.abort());
let response = await askBing({
text,
parentMessageId: userParentMessageId,
@ -135,14 +154,20 @@ const ask = async ({
let responseMessage = {
conversationId: newConversationId,
messageId: newResponseMessageId,
messageId: responseMessageId,
newMessageId: newResponseMessageId,
parentMessageId: overrideParentMessageId || newUserMassageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
text: await handleText(response, true),
suggestions: response.details.suggestedResponses && response.details.suggestedResponses.map(s => s.text)
suggestions:
response.details.suggestedResponses && response.details.suggestedResponses.map((s) => s.text),
unfinished: false,
cancelled: false,
error: false
};
await saveMessage(responseMessage);
responseMessage.messageId = newResponseMessageId;
// STEP2 update the convosation.
@ -204,7 +229,7 @@ const ask = async ({
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
await updateConvo(req?.session?.user?.username, {
await saveConvo(req?.session?.user?.username, {
conversationId: conversationId,
title
});
@ -212,10 +237,12 @@ const ask = async ({
} catch (error) {
console.log(error);
const errorMessage = {
messageId: crypto.randomUUID(),
messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message
};

View file

@ -3,7 +3,7 @@ const crypto = require('crypto');
const router = express.Router();
const { getChatGPTBrowserModels } = require('../endpoints');
const { browserClient } = require('../../../app/');
const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
router.post('/', async (req, res) => {
@ -38,7 +38,7 @@ router.post('/', async (req, res) => {
};
const availableModels = getChatGPTBrowserModels();
if (availableModels.find(model => model === endpointOption.model) === undefined)
if (availableModels.find((model) => model === endpointOption.model) === undefined)
return handleError(res, { text: 'Illegal request: model' });
console.log('ask log', {
@ -92,10 +92,29 @@ const ask = async ({
if (preSendRequest) sendMessage(res, { message: userMessage, created: true });
let responseMessageId = crypto.randomUUID();
try {
const progressCallback = createOnProgress();
let lastSavedTimestamp = 0;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: endpointOption?.jailbreak ? 'Sydney' : 'BingAI',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: text,
unfinished: true,
cancelled: false,
error: false
});
}
}
});
const abortController = new AbortController();
res.on('close', () => abortController.abort());
let response = await browserClient({
text,
parentMessageId: userParentMessageId,
@ -116,13 +135,18 @@ const ask = async ({
let responseMessage = {
conversationId: newConversationId,
messageId: newResponseMessageId,
messageId: responseMessageId,
newMessageId: newResponseMessageId,
parentMessageId: overrideParentMessageId || newUserMassageId,
text: await handleText(response),
sender: endpointOption?.chatGptLabel || 'ChatGPT'
sender: endpointOption?.chatGptLabel || 'ChatGPT',
unfinished: false,
cancelled: false,
error: false
};
await saveMessage(responseMessage);
responseMessage.messageId = newResponseMessageId;
// STEP2 update the conversation
@ -168,17 +192,19 @@ const ask = async ({
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
// const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
const title = await response.details.title;
await updateConvo(req?.session?.user?.username, {
await saveConvo(req?.session?.user?.username, {
conversationId: conversationId,
title
});
}
} catch (error) {
const errorMessage = {
messageId: crypto.randomUUID(),
messageId: responseMessageId,
sender: 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message
};

View file

@ -4,29 +4,26 @@ const router = express.Router();
const addToCache = require('./addToCache');
const { getOpenAIModels } = require('../endpoints');
const { titleConvo, askClient } = require('../../../app/');
const { saveMessage, getConvoTitle, saveConvo, updateConvo, getConvo } = require('../../../models');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { handleError, sendMessage, createOnProgress, handleText } = require('./handlers');
const abortControllers = new Map();
router.post('/abort', async (req, res) => {
const { abortKey, latestMessage } = req.body;
const { abortKey } = req.body;
console.log(`req.body`, req.body);
if (!abortControllers.has(abortKey)) {
return res.status(404).send('Request not found');
}
const { abortController, userMessage, endpointOption } = abortControllers.get(abortKey);
if (!endpointOption.endpoint) {
endpointOption.endpoint = req.originalUrl.replace('/api/ask/','').split('/abort')[0];
}
abortController.abort();
const { abortController } = abortControllers.get(abortKey);
abortControllers.delete(abortKey);
console.log('Aborted request', abortKey, userMessage, endpointOption);
await addToCache({ endpointOption, userMessage, latestMessage });
res.status(200).send('Aborted');
const ret = await abortController.abortAsk();
console.log('Aborted request', abortKey);
console.log('Aborted message:', ret);
res.send(JSON.stringify(ret));
});
router.post('/', async (req, res) => {
@ -66,7 +63,7 @@ router.post('/', async (req, res) => {
};
const availableModels = getOpenAIModels();
if (availableModels.find(model => model === endpointOption.model) === undefined)
if (availableModels.find((model) => model === endpointOption.model) === undefined)
return handleError(res, { text: 'Illegal request: model' });
console.log('ask log', {
@ -110,6 +107,8 @@ const ask = async ({
}) => {
let { text, parentMessageId: userParentMessageId, messageId: userMessageId } = userMessage;
let responseMessageId = crypto.randomUUID();
res.writeHead(200, {
Connection: 'keep-alive',
'Content-Type': 'text/event-stream',
@ -121,16 +120,55 @@ const ask = async ({
if (preSendRequest) sendMessage(res, { message: userMessage, created: true });
try {
const progressCallback = createOnProgress();
const abortController = new AbortController();
const abortKey = conversationId;
console.log('conversationId -----> ', conversationId);
abortControllers.set(abortKey, { abortController, userMessage, endpointOption });
res.on('close', () => {
abortController.abort();
return res.end();
let lastSavedTimestamp = 0;
const { onProgress: progressCallback, getPartialText } = createOnProgress({
onProgress: ({ text }) => {
const currentTimestamp = Date.now();
if (currentTimestamp - lastSavedTimestamp > 500) {
lastSavedTimestamp = currentTimestamp;
saveMessage({
messageId: responseMessageId,
sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: text,
unfinished: true,
cancelled: false,
error: false
});
}
}
});
let abortController = new AbortController();
abortController.abortAsk = async function () {
this.abort();
const responseMessage = {
messageId: responseMessageId,
sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
text: getPartialText(),
unfinished: false,
cancelled: true,
error: false
};
saveMessage(responseMessage);
await addToCache({ endpoint: 'openAI', endpointOption, userMessage, responseMessage });
return {
title: await getConvoTitle(req?.session?.user?.username, conversationId),
final: true,
conversation: await getConvo(req?.session?.user?.username, conversationId),
requestMessage: userMessage,
responseMessage: responseMessage
};
};
const abortKey = conversationId;
abortControllers.set(abortKey, { abortController, ...endpointOption });
let response = await askClient({
text,
parentMessageId: userParentMessageId,
@ -144,6 +182,7 @@ const ask = async ({
abortController
});
abortControllers.delete(abortKey);
console.log('CLIENT RESPONSE', response);
const newConversationId = response.conversationId || conversationId;
@ -155,13 +194,18 @@ const ask = async ({
let responseMessage = {
conversationId: newConversationId,
messageId: newResponseMessageId,
messageId: responseMessageId,
newMessageId: newResponseMessageId,
parentMessageId: overrideParentMessageId || newUserMassageId,
text: await handleText(response),
sender: endpointOption?.chatGptLabel || 'ChatGPT'
sender: endpointOption?.chatGptLabel || 'ChatGPT',
unfinished: false,
cancelled: false,
error: false
};
await saveMessage(responseMessage);
responseMessage.messageId = newResponseMessageId;
// STEP2 update the conversation
let conversationUpdate = { conversationId: newConversationId, endpoint: 'openAI' };
@ -200,12 +244,11 @@ const ask = async ({
requestMessage: userMessage,
responseMessage: responseMessage
});
abortControllers.delete(abortKey);
res.end();
if (userParentMessageId == '00000000-0000-0000-0000-000000000000') {
const title = await titleConvo({ endpoint: endpointOption?.endpoint, text, response: responseMessage });
await updateConvo(req?.session?.user?.username, {
await saveConvo(req?.session?.user?.username, {
conversationId: conversationId,
title
});
@ -213,10 +256,12 @@ const ask = async ({
} catch (error) {
console.error(error);
const errorMessage = {
messageId: crypto.randomUUID(),
messageId: responseMessageId,
sender: endpointOption?.chatGptLabel || 'ChatGPT',
conversationId,
parentMessageId: overrideParentMessageId || userMessageId,
unfinished: false,
cancelled: false,
error: true,
text: error.message
};

View file

@ -17,7 +17,7 @@ const sendMessage = (res, message) => {
res.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`);
};
const createOnProgress = () => {
const createOnProgress = ({ onProgress: _onProgress }) => {
let i = 0;
let code = '';
let tokens = '';
@ -65,14 +65,21 @@ const createOnProgress = () => {
}
sendMessage(res, { text: tokens + cursor, message: true, initial: i === 0, ...rest });
_onProgress && _onProgress({ text: tokens, message: true, initial: i === 0, ...rest });
i++;
};
const onProgress = opts => {
const onProgress = (opts) => {
return _.partialRight(progressCallback, opts);
};
return onProgress;
const getPartialText = () => {
return tokens;
};
return { onProgress, getPartialText };
};
const handleText = async (response, bing = false) => {

View file

@ -2,7 +2,7 @@ const express = require('express');
const router = express.Router();
const { titleConvo } = require('../../app/');
const { getConvo, saveConvo, getConvoTitle } = require('../../models');
const { getConvosByPage, deleteConvos, updateConvo } = require('../../models/Conversation');
const { getConvosByPage, deleteConvos } = require('../../models/Conversation');
const { getMessages } = require('../../models/Message');
router.get('/', async (req, res) => {
@ -44,7 +44,7 @@ router.post('/update', async (req, res) => {
const update = req.body.arg;
try {
const dbResponse = await updateConvo(req?.session?.user?.username, update);
const dbResponse = await saveConvo(req?.session?.user?.username, update);
res.status(201).send(dbResponse);
} catch (error) {
console.error(error);