chore: refactor progressCB to one place, fix sydney, and sanitize html

This commit is contained in:
Danny Avila 2023-03-14 15:42:59 -04:00
parent 9a17e94f8f
commit 2e20b28c4d
12 changed files with 351 additions and 69 deletions

View file

@ -11,7 +11,7 @@ const {
detectCode
} = require('../../app/');
const { getConvo, saveMessage, getConvoTitle, saveConvo } = require('../../models');
const { handleError, sendMessage } = require('./handlers');
const { handleError, sendMessage, createOnProgress } = require('./handlers');
const { getMessages } = require('../../models/Message');
router.use('/bing', askBing);
@ -138,37 +138,10 @@ const ask = async ({
sendMessage(res, { message: userMessage, created: true });
try {
let i = 0;
let tokens = '';
const progressCallback = async (partial) => {
if (i === 0 && typeof partial === 'object') {
userMessage.conversationId = conversationId ? conversationId : partial.conversationId;
await saveMessage(userMessage);
sendMessage(res, { ...partial, parentMessageId: overrideParentMessageId || userMessageId, initial: true });
i++;
}
if (typeof partial === 'object') {
sendMessage(res, { ...partial, parentMessageId: overrideParentMessageId || userMessageId, message: true });
} else {
tokens += partial === text ? '' : partial;
if (tokens.match(/^\n/)) {
tokens = tokens.replace(/^\n/, '');
}
if (tokens.includes('[DONE]')) {
tokens = tokens.replace('[DONE]', '');
}
// tokens = await detectCode(tokens);
sendMessage(res, { text: tokens, message: true, initial: i === 0 ? true : false });
i++;
}
};
const progressCallback = createOnProgress();
let gptResponse = await client({
text,
progressCallback,
onProgress: progressCallback.call(null, model, {res, text }),
convo: {
parentMessageId: userParentMessageId,
conversationId,

View file

@ -3,7 +3,7 @@ const crypto = require('crypto');
const router = express.Router();
const { titleConvo, getCitations, citeText, askBing } = require('../../app/');
const { saveMessage, getConvoTitle, saveConvo } = require('../../models');
const { handleError, sendMessage } = require('./handlers');
const { handleError, sendMessage, createOnProgress } = require('./handlers');
const citationRegex = /\[\^\d+?\^]/g;
router.post('/', async (req, res) => {
@ -68,17 +68,10 @@ const ask = async ({
sendMessage(res, { message: userMessage, created: true });
try {
let tokens = '';
const progressCallback = async (partial) => {
tokens += partial === text ? '' : partial;
// tokens = appendCode(tokens);
tokens = citeText(tokens, true);
sendMessage(res, { text: tokens, message: true, parentMessageId: overrideParentMessageId || userMessageId });
};
const progressCallback = createOnProgress();
let response = await askBing({
text,
progressCallback,
onProgress: progressCallback.call(null, model, {res, text, parentMessageId: overrideParentMessageId || userMessageId }),
convo: {
...convo,
parentMessageId: userParentMessageId,

View file

@ -3,7 +3,7 @@ const crypto = require('crypto');
const router = express.Router();
const { titleConvo, getCitations, citeText, askSydney } = require('../../app/');
const { saveMessage, saveConvo, getConvoTitle } = require('../../models');
const { handleError, sendMessage } = require('./handlers');
const { handleError, sendMessage, createOnProgress } = require('./handlers');
const citationRegex = /\[\^\d+?\^]/g;
router.post('/', async (req, res) => {
@ -68,17 +68,10 @@ const ask = async ({
sendMessage(res, { message: userMessage, created: true });
try {
let tokens = '';
const progressCallback = async (partial) => {
tokens += partial === text ? '' : partial;
// tokens = appendCode(tokens);
tokens = citeText(tokens, true);
sendMessage(res, { text: tokens, message: true, parentMessageId: overrideParentMessageId || userMessageId });
};
const progressCallback = createOnProgress();
let response = await askSydney({
text,
progressCallback,
onProgress: progressCallback.call(null, model, {res, text, parentMessageId: overrideParentMessageId || userMessageId }),
convo: {
parentMessageId: userParentMessageId,
conversationId,

View file

@ -1,3 +1,7 @@
const { citeText } = require('../../app/');
const _ = require('lodash');
const sanitizeHtml = require('sanitize-html');
const handleError = (res, message) => {
res.write(`event: error\ndata: ${JSON.stringify(message)}\n\n`);
res.end();
@ -10,4 +14,32 @@ const sendMessage = (res, message) => {
res.write(`event: message\ndata: ${JSON.stringify(message)}\n\n`);
};
module.exports = { handleError, sendMessage };
const createOnProgress = () => {
let i = 0;
let tokens = '';
const progressCallback = async (partial, { res, text, bing = false, ...rest }) => {
tokens += partial === text ? '' : partial;
tokens = tokens.trim();
tokens = tokens.replaceAll('[DONE]', '');
if (tokens.includes('```')) {
tokens = sanitizeHtml(tokens);
}
if (bing) {
tokens = citeText(tokens, true);
}
sendMessage(res, { text: tokens, message: true, initial: i === 0, ...rest });
i++;
};
const onProgress = (model, opts) => {
const bingModels = new Set(['bingai', 'sydney']);
return _.partialRight(progressCallback, { ...opts, bing: bingModels.has(model) });
};
return onProgress;
};
module.exports = { handleError, sendMessage, createOnProgress };